In [28]:
from sqlalchemy.orm import Session
from core.config import settings

from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import PydanticOutputParser

from core.prompts import STORY_PROMPT
from models.story import Story, StoryNode
from core.models import StoryLLMResponse, StoryNodeLLM


In [29]:
# Chat model used by the exploratory cells in this notebook.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", api_key=settings.GEMINI_API_KEY)

# Pydantic parser for the StoryLLMResponse schema; it also supplies the
# format instructions that are pre-filled into the prompt below.
story_parser = PydanticOutputParser(pydantic_object=StoryLLMResponse)


In [30]:
# Two-message chat prompt: STORY_PROMPT as the system message, the theme as the
# human message. format_instructions is pre-filled from the parser (presumably
# STORY_PROMPT contains a {format_instructions} placeholder), so callers only
# supply {theme}.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", STORY_PROMPT),
        ("human", "Create the story with this theme: {theme}"),
    ]
).partial(format_instructions=story_parser.get_format_instructions())

In [31]:
# Render the prompt for a sample theme to inspect what the model will receive.
prompt.format(theme="Hero adventures")

'System: \n                You are a creative story writer that creates engaging choose-your-own-adventure stories.\n                Generate a complete branching story with multiple paths and endings in the JSON format I\'ll specify.\n\n                The story should have:\n                1. A compelling title\n                2. A starting situation (root node) with 2-3 options\n                3. Each option should lead to another node with its own options\n                4. Some paths should lead to endings (both winning and losing)\n                5. At least one path should lead to a winning ending\n\n                Story structure requirements:\n                - Each node should have 2 options except for ending nodes\n                - The story should be 3 levels deep (including root node)\n                - Add variety in the path lengths (some end earlier, some later)\n                - Make sure there\'s at least one winning path\n\n                Output your story i

In [32]:
# Call the model with the same sample theme; the reply is a chat message object.
raw_response = llm.invoke(prompt.invoke({"theme": "Hero adventures"}))
raw_response

AIMessage(content='```json\n{\n  "title": "The Quest for the Sunstone",\n  "rootNode": {\n    "content": "You are Elara, a renowned adventurer. A dark blight spreads, threatening the kingdom. The elders believe the legendary Sunstone, hidden in the Whispering Peaks, can stop it. Do you accept the perilous quest?",\n    "isEnding": false,\n    "isWinningEnding": false,\n    "options": [\n      {\n        "text": "Accept the quest immediately, preparing for the mountains.",\n        "nextNode": {\n          "content": "You brave the treacherous ascent. A sudden blizzard engulfs you. You see a small, icy cave to your left and a steep, but clearer, ridge to your right.",\n          "isEnding": false,\n          "isWinningEnding": false,\n          "options": [\n            {\n              "text": "Seek shelter in the icy cave.",\n              "nextNode": {\n                "content": "Inside the cave, you find warmth but also a sleeping Frost Troll guarding a small treasure. It stirs. Do

In [33]:

def post_process_response(raw_output: str) -> str:
    """Strip a Markdown code fence from an LLM reply and pretty-print the JSON inside.

    Args:
        raw_output: Model reply text, optionally wrapped in a ```json ... ```
            (or bare ``` ... ```) fence and surrounding whitespace.

    Returns:
        The parsed JSON re-serialized with 4-space indentation. Non-ASCII text
        is preserved as-is rather than escaped or dropped.

    Raises:
        json.JSONDecodeError: If the cleaned text is not valid JSON.
    """
    import json
    import re

    text = raw_output.strip()  # strip first so leading whitespace can't hide the fence

    # Drop the opening fence, with or without a "json" language tag.
    if text.startswith("```"):
        text = text[3:]
        if text[:4].lower() == "json":
            text = text[4:]
        text = text.strip()
    # Drop the closing fence.
    if text.endswith("```"):
        text = text[:-3].strip()

    # `\'` is not a valid JSON escape sequence; unescape it.
    text = text.replace("\\'", "'")
    # Remove ASCII control characters (raw newlines/tabs are invalid inside JSON
    # strings) but keep printable non-ASCII characters such as accented letters.
    text = re.sub(r"[\x00-\x1F\x7F]", "", text)

    data = json.loads(text)
    return json.dumps(data, indent=4, ensure_ascii=False)

In [34]:
# post_process_response returns a str; print() renders its newlines and
# indentation instead of the escaped repr a bare expression would display.
ans = post_process_response(raw_response.content)
print(ans)

{
    "title": "The Quest for the Sunstone",
    "rootNode": {
        "content": "You are Elara, a renowned adventurer. A dark blight spreads, threatening the kingdom. The elders believe the legendary Sunstone, hidden in the Whispering Peaks, can stop it. Do you accept the perilous quest?",
        "isEnding": false,
        "isWinningEnding": false,
        "options": [
            {
                "text": "Accept the quest immediately, preparing for the mountains.",
                "nextNode": {
                    "content": "You brave the treacherous ascent. A sudden blizzard engulfs you. You see a small, icy cave to your left and a steep, but clearer, ridge to your right.",
                    "isEnding": false,
                    "isWinningEnding": false,
                    "options": [
                        {
                            "text": "Seek shelter in the icy cave.",
                            "nextNode": {
                                "content": "Inside the 

In [14]:
# NOTE(review): the next cell overwrites this value (with raw_response.content
# for the AIMessage shown above), so the cleaned JSON from post_process_response
# is not what gets parsed. This cell's execution count (14) also predates the
# cell that defines `ans` (34), i.e. it relied on kernel state from an earlier run.
response_text = ans

In [15]:
# Use the message text when the response is a chat message, otherwise the raw
# value. Derived only from raw_response so the cell doesn't depend on state
# left behind by other cells (matches StoryGenerator.generate_story).
response_text = getattr(raw_response, "content", raw_response)

story_structure = story_parser.parse(response_text)

In [16]:
# Inspect the parsed root node.
# NOTE(review): the saved output below is from an earlier run (execution count 16)
# and shows a different story than the response displayed above; re-run to refresh.
story_structure.rootNode

StoryNodeLLM(content='You are a brave adventurer, renowned throughout the land for your courage. You stand at the entrance of a dark cave, rumored to be the lair of a fearsome dragon. Before you lie two paths: a narrow, winding passage and a wider, more direct route. What do you do?', isEnding=False, isWinningEnding=False, options=[StoryOptionLLM(text='Venture down the narrow passage.', nextNode={'content': 'The narrow passage twists and turns, and you hear the sound of dripping water. After a while, you come to a fork in the path. One path leads deeper into the darkness, the other leads to a faint glimmer of light.', 'isEnding': False, 'isWinningEnding': False, 'options': [{'text': 'Follow the path into the deeper darkness.', 'nextNode': {'content': 'You stumble along in the dark, and suddenly, the ground gives way! You plummet into a deep pit, your adventure ending abruptly.', 'isEnding': True, 'isWinningEnding': False, 'options': None}}, {'text': 'Follow the glimmer of light.', 'nex

In [39]:
# Database session for the manual persistence test below; create_tables()
# presumably creates any missing tables (TODO confirm in db.database).
from db.database import SessionLocal, create_tables, get_db

db = SessionLocal()
create_tables()

In [40]:
# Parent Story row for the parsed structure; "abcd" looks like a placeholder
# session id for this manual run.
story_db = Story(session_id="abcd", title=story_structure.title)
db.add(story_db)
# Flush without committing so story_db.id is available to the node rows below.
db.flush()

In [41]:
# The root may arrive as a plain dict; coerce it to a StoryNodeLLM model.
root_node_data = (
    StoryNodeLLM.model_validate(story_structure.rootNode)
    if isinstance(story_structure.rootNode, dict)
    else story_structure.rootNode
)

In [42]:
# Display the normalized root node (a StoryNodeLLM).
root_node_data

StoryNodeLLM(content="You are a seasoned adventurer, renowned for your courage and cunning. You've received a cryptic map hinting at the location of the Sunken City of Eldoria, a place rumored to hold unimaginable treasures and forgotten magic. You stand at the edge of the Whispering Sea, ready to begin your quest. What do you do?", isEnding=False, isWinningEnding=False, options=[StoryOptionLLM(text='Set sail immediately, trusting your instincts and the map.', nextNode={'content': 'The journey across the Whispering Sea is fraught with peril. Violent storms, monstrous sea creatures, and treacherous currents test your resolve. After weeks at sea, you finally reach the coordinates marked on the map. You find a submerged city. What do you do?', 'isEnding': False, 'isWinningEnding': False, 'options': [{'text': 'Dive into the city, exploring the ruins in search of treasure.', 'nextNode': {'content': 'You descend into the depths, navigating the eerie, waterlogged streets of Eldoria. You disco

In [43]:
# Check the root's winning-ending flag; a starting node is expected to be False.
root_node_data.isWinningEnding

False

In [44]:
# Column names of the StoryNode table, shown via the cell's rich output
# rather than print().
StoryNode.__table__.columns.keys()


['id', 'story_id', 'content', 'is_root', 'is_ending', 'is_winning_ending', 'options']


In [45]:
def process_story_node(db: Session, story_id: int, node_data: StoryNodeLLM, is_root: bool = False) -> StoryNode:
    """Persist ``node_data`` and, recursively, every node reachable through its options.

    Args:
        db: SQLAlchemy session the nodes are added to (flushed, not committed).
        story_id: Id of the parent Story row.
        node_data: A ``StoryNodeLLM`` or an equivalent plain dict.
        is_root: Whether this node is the story's starting node.

    Returns:
        The persisted ``StoryNode``. For non-ending nodes, ``options`` holds a list of
        ``{"text": ..., "node_id": ...}`` dicts pointing at the child nodes.
    """
    # Normalize up front so dict input takes the same path as model input.
    # Previously a dict node had its options silently ignored, because
    # hasattr(dict, "options") is False.
    if isinstance(node_data, dict):
        node_data = StoryNodeLLM.model_validate(node_data)

    node = StoryNode(
        story_id=story_id,
        content=node_data.content,
        is_root=is_root,
        is_ending=node_data.isEnding,
        is_winning_ending=node_data.isWinningEnding,
        options=[],
    )
    db.add(node)
    db.flush()  # assign node.id before any children are created

    options = getattr(node_data, "options", None)
    if not node.is_ending and options:
        options_list = []
        for option_data in options:
            # nextNode may be a dict; the recursive call normalizes it.
            child_node = process_story_node(db, story_id, option_data.nextNode, False)
            options_list.append({
                "text": option_data.text,
                "node_id": child_node.id,
            })
        node.options = options_list

    db.flush()
    return node

In [46]:
# Persist the whole node tree under the story created above, then commit the
# Story and all StoryNode rows added to the session in one commit.
process_story_node(db, story_db.id, root_node_data, is_root=True)
db.commit()
story_db

<models.story.Story at 0x1e5c38d32c0>

In [None]:
class StoryGenerator:
    """Generate a branching story with the LLM and persist it as Story/StoryNode rows."""

    @classmethod
    def _get_llm(cls):
        """Return a new Gemini chat model client for story generation."""
        return ChatGoogleGenerativeAI(model="gemini-2.0-flash-lite", api_key=settings.GEMINI_API_KEY)

    @classmethod
    def generate_story(cls, db: Session, session_id: str, theme: str = "fantasy") -> Story:
        """Ask the LLM for a story on ``theme`` and store it in ``db``.

        Args:
            db: SQLAlchemy session; committed on success. If a database step fails,
                the session is rolled back (including any other pending changes in it)
                before the error is re-raised.
            session_id: Value stored on the new ``Story`` row.
            theme: Free-text story theme; passed as a template variable, so braces
                in it are treated literally.

        Returns:
            The committed ``Story``.
        """
        llm = cls._get_llm()
        story_parser = PydanticOutputParser(pydantic_object=StoryLLMResponse)

        # {theme} is a template variable filled at invoke time. Interpolating the
        # theme with an f-string would make braces in user input be parsed as
        # template placeholders (KeyError, or injecting e.g. {format_instructions}).
        prompt = ChatPromptTemplate.from_messages([
            ("system", STORY_PROMPT),
            ("human", "Create the story with this theme: {theme}"),
        ]).partial(format_instructions=story_parser.get_format_instructions())

        raw_response = llm.invoke(prompt.invoke({"theme": theme}))
        # Chat models return a message object; fall back to the raw value otherwise.
        response_text = getattr(raw_response, "content", raw_response)
        story_structure = story_parser.parse(response_text)

        try:
            story_db = Story(title=story_structure.title, session_id=session_id)
            db.add(story_db)
            db.flush()  # populate story_db.id for the node rows

            root_node_data = story_structure.rootNode
            if isinstance(root_node_data, dict):
                root_node_data = StoryNodeLLM.model_validate(root_node_data)

            cls._process_story_node(db, story_db.id, root_node_data, is_root=True)
            db.commit()
        except Exception:
            # Don't leave a half-written story (or a failed transaction) in the session.
            db.rollback()
            raise
        return story_db

    @classmethod
    def _process_story_node(cls, db: Session, story_id: int, node_data: StoryNodeLLM, is_root: bool = False) -> StoryNode:
        """Persist ``node_data`` and, recursively, every node reachable through its options.

        Args:
            db: SQLAlchemy session the nodes are added to (flushed, not committed).
            story_id: Id of the parent Story row.
            node_data: A ``StoryNodeLLM`` or an equivalent plain dict.
            is_root: Whether this node is the story's starting node.

        Returns:
            The persisted ``StoryNode``. For non-ending nodes, ``options`` holds a list of
            ``{"text": ..., "node_id": ...}`` dicts pointing at the child nodes.
        """
        # Normalize up front so dict input takes the same path as model input.
        # Previously a dict node had its options silently ignored.
        if isinstance(node_data, dict):
            node_data = StoryNodeLLM.model_validate(node_data)

        node = StoryNode(
            story_id=story_id,
            content=node_data.content,
            is_root=is_root,
            is_ending=node_data.isEnding,
            is_winning_ending=node_data.isWinningEnding,
            options=[],
        )
        db.add(node)
        db.flush()  # assign node.id before any children are created

        options = getattr(node_data, "options", None)
        if not node.is_ending and options:
            options_list = []
            for option_data in options:
                # nextNode may be a dict; the recursive call normalizes it.
                child_node = cls._process_story_node(db, story_id, option_data.nextNode, False)
                options_list.append({
                    "text": option_data.text,
                    "node_id": child_node.id,
                })
            node.options = options_list

        db.flush()
        return node
