## Install Libraries


In [None]:
!pip install --upgrade pyyaml roman python-Levenshtein transformers scipy langchain langchain-core langchain-community

In [None]:
!pip install -q bitsandbytes accelerate

In [None]:
!pip install -U bitsandbytes

## preparing LLM

In [1]:
import os
import logging
import json
import re
import signal
from contextlib import contextmanager
from pathlib import Path
from typing import Any

import yaml
import roman
import Levenshtein
from transformers import AutoTokenizer
from scipy.special import log_softmax
from langchain_core.prompts import PromptTemplate


logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
print("Initial imports and logging configured.")

Initial imports and logging configured.


In [2]:
LOCALHOST = 'http://localhost'
DEFAULT_PORT = 8000

class ServerConfig:
    def __init__(self, engine, host, port, server_type, tensor_parallel_size):
        self.engine = engine
        self.host = host
        self.port = port
        self.server_type = server_type
        self.tensor_parallel_size = tensor_parallel_size

    @staticmethod
    def from_config(config):
        return ServerConfig(
            engine=config['engine'],
            host=config['host'],
            port=config.get('port', DEFAULT_PORT),
            server_type=config['server_type'],
            tensor_parallel_size=config['tensor_parallel_size']
        )

    def __getitem__(self, key):
        return getattr(self, key)

    def __hash__(self):
        return hash((self.engine, self.host, self.port,
                     self.server_type, self.tensor_parallel_size))

    def __eq__(self, other):
        return (self.engine, self.host, self.port, self.server_type, self.tensor_parallel_size) == (other.engine, other.host, other.port, other.server_type, self.tensor_parallel_size)

In [3]:
def recursive_lowercase_keys(d):
    if type(d) is dict:
        new_d = {}
        for key in d:
            new_d[key.lower()] = recursive_lowercase_keys(d[key])
        return new_d
    else:
        return d

class Config:
    def __init__(self, config, parent=None):
        self.parent_config = parent
        self.config = config
        for key in self.config:
            if type(self.config[key]) is dict:
                self.config[key] = Config(self.config[key], self)

    @staticmethod
    def load_from_dict(all_confs, config_names):
        config = {}
        for name in config_names:
            if ',' in name:
                for n in name.split(','):
                    config.update(all_confs[n])
            else:
                config.update(all_confs[name])

        return Config(config, None)

    def __getattr__(self, name):
        if name in self.config:
            return self.config[name]
        elif self.parent_config is not None:
            return getattr(self.parent_config, name)
        else:
            raise AttributeError(f"Config has no attribute {name}.")

    def __getitem__(self, name):
        return getattr(self, name)

    def __contains__(self, name):
        return name in self.config

    def get(self, name, default=None):
        try:
            return self[name]
        except AttributeError:
            return default

In [4]:
tokenizers = {}

def init_logging(logging_level):
    logging_level = logging_level.upper()
    assert logging_level in ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL']
    logging.getLogger().setLevel(logging_level)

class TimeoutException(Exception): pass

@contextmanager
def time_limit(seconds):
    def signal_handler(signum, frame):
        raise TimeoutException("Timed out!")
    try:
        yield
    finally:
        pass

class Filter:
    def __init__(self, filter_func):
        self.filter_func = filter_func

    @staticmethod
    def wrap_preprocessor(preprocessor, filter):
        return Filter(lambda s: filter(preprocessor(s)))

    def __call__(self, *args, **kwargs):
        try:
            return self.filter_func(*args, **kwargs)
        except:
            return self.filter_func(*args)

    def __add__(self, other):
        return Filter(lambda s: self.filter_func(s) and other.filter_func(s))

def min_max_tokens_filter(min_tokens, max_tokens, tokenizer_model_string='gpt2', filter_empty=True):
    global tokenizers
    if tokenizer_model_string not in tokenizers:
        try:
            tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_string)
            tokenizers[tokenizer_model_string] = tokenizer
        except Exception as e:
            logging.warning(f"Could not load tokenizer {tokenizer_model_string}. Token filtering will be skipped. Error: {e}")
            return Filter(lambda s: True)
    else:
        tokenizer = tokenizers[tokenizer_model_string]

    filter_func = Filter(lambda s: min_tokens <= len(tokenizer.encode(s.strip())) <= max_tokens)
    if filter_empty:
        filter_func = filter_func + Filter(lambda s: len(s.strip()) > 0)
    return filter_func

def levenshtein_ratio_filter(passages_to_match, threshold=0.8):
    return Filter(lambda s: all([all([Levenshtein.ratio(sub_s, passage) < threshold for passage in passages_to_match]) for sub_s in s.split()]))

def word_filter(word_list):
    return Filter(lambda s: all([word not in s for word in word_list]))

def list_next_number_format_filter():
    bad_regex = re.compile(r'[^]\d+\.')
    return Filter(lambda s: not bad_regex.search(s))

def extract_choice_logprobs(full_completion, choices=['yes', 'no'], default_logprobs=[-1e8, -1e8], case_sensitive=False):
    batch_logprobs = []
    for choice in full_completion['choices']:
        all_logprobs = choice['logprobs']['top_logprobs']
        found = False
        logprobs = [l for l in default_logprobs]
        for token_logprobs in all_logprobs:
            for key, value in token_logprobs.items():
                for i, choice in enumerate(choices):
                    if choice in key or (not case_sensitive and choice.lower() in key.lower()):
                        found = True
                        logprobs[i] = value
            if found:
                break
        batch_logprobs.append(log_softmax(logprobs))
    return batch_logprobs

In [5]:
warned_prompt_format = {'openai_response_prefix': False}

def format_langchain_prompt(langchain_prompt, **kwargs):
    kwargs = {k: v for k, v in kwargs.items() if k in langchain_prompt.input_variables}
    return langchain_prompt.format(**kwargs)

class TemplatePromptBuilder:
    def __init__(self, base_dict):
        self.instruction = PromptTemplate.from_template(template=base_dict['instruction'],)
        self.system_message = PromptTemplate.from_template(template=base_dict['system_message'],) if 'system_message' in base_dict else None
        self.response_prefix = PromptTemplate.from_template(template=base_dict['response_prefix'],) if 'response_prefix' in base_dict else None
        self.output_prefix = PromptTemplate.from_template(template=base_dict['output_prefix'],) if 'output_prefix' in base_dict else None

    def format(self, **kwargs):
        return PromptBuilder(self, **kwargs)

class PromptBuilder:
    def __init__(self, template_prompt_builder, **kwargs):
        self.instruction = format_langchain_prompt(template_prompt_builder.instruction, **kwargs)
        self.system_message = format_langchain_prompt(template_prompt_builder.system_message, **kwargs) \
            if template_prompt_builder.system_message is not None else None
        self.response_prefix = format_langchain_prompt(template_prompt_builder.response_prefix, **kwargs) \
            if template_prompt_builder.response_prefix is not None else None
        self.output_prefix = format_langchain_prompt(template_prompt_builder.output_prefix, **kwargs) \
            if template_prompt_builder.output_prefix is not None else None

    def render_for_llm_format(self, prompt_format):
        if prompt_format not in ['openai-chat', 'llama2-chat', 'none']:
            raise NotImplementedError(f"Prompt format {prompt_format} not implemented.")

        prompt = self.instruction.format().lstrip()

        if prompt_format == 'openai-chat':
            if self.response_prefix is not None:
                global warned_prompt_format
                if warned_prompt_format['openai_response_prefix']:
                    logging.warning(f"Response prefix is not supported for prompt format {prompt_format}. Appending to end of instruction instead.")
                    warned_prompt_format['openai_response_prefix'] = True
                prompt += '\n\n\n\nThe output is already partially generated. Continue from:\n\n' + self.response_prefix.format()
            messages = [{'role': 'user', 'content': prompt}]
            if self.system_message is not None:
                messages = [{'role': 'system', 'content': self.system_message.format()}] + messages
            return messages

        else:
            if prompt_format == 'llama2-chat':
                prompt = '[INST]'
                if self.system_message is not None:
                    prompt += ' <<SYS>>\n' + self.system_message.format() + '\n<</SYS>>\n\n'
                else:
                    prompt += ' '
                prompt += self.instruction.format()
                prompt += '[/INST]' if self.instruction.format()[-1] == ' ' else ' [/INST]'
                if self.response_prefix is not None:
                    prompt += self.response_prefix.format() if self.response_prefix.format()[0] == ' ' else ' ' + self.response_prefix.format()
            else:
                if self.system_message is not None:
                    prompt = self.system_message.format() + '\n\n\n\n' + prompt
                if self.response_prefix is not None:
                    prompt = prompt + '\n\n\n\n' + self.response_prefix.format()
            return prompt

def _create_prompt_templates(prompts):
    for key in prompts:
        assert isinstance(prompts[key], dict)
        if 'instruction' not in prompts[key]:
            _create_prompt_templates(prompts[key])
        else:
            prompts[key] = TemplatePromptBuilder(prompts[key])

def load_prompts_from_dict(prompts_dict):
    prompts = prompts_dict.copy()
    _create_prompt_templates(prompts)
    return prompts

In [6]:
import time
import openai
models = {}

class SamplingConfig:
    def __init__(self,
                 server_config,
                 prompt_format,
                 max_tokens=None,
                 temperature=None,
                 top_p=None,
                 frequency_penalty=None,
                 presence_penalty=None,
                 stop=None,
                 n=None,
                 logit_bias=None,
                 logprobs=None):
        self.server_config = server_config
        self.prompt_format = prompt_format
        self.max_tokens = max_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.frequency_penalty = frequency_penalty
        self.presence_penalty = presence_penalty
        self.stop = stop
        self.n = n
        self.logit_bias = logit_bias
        self.logprobs = logprobs

    @staticmethod
    def from_config(config):
        return SamplingConfig(
            server_config=ServerConfig.from_config(config),
            prompt_format=config['prompt_format'],
            max_tokens=config.get('max_tokens', None),
            temperature=config.get('temperature', None),
            top_p=config.get('top_p', None),
            frequency_penalty=config.get('frequency_penalty', None),
            presence_penalty=config.get('presence_penalty', None),
            stop=config.get('stop', None),
            n=config.get('n', None),
            logit_bias=config.get('logit_bias', None),
            logprobs=config.get('logprobs', None)
        )

    def __getitem__(self, key):
        return getattr(self, key)

    def dict(self):
        d = {'model': self.server_config.engine}
        for attr in ['max_tokens', 'temperature', 'top_p', 'frequency_penalty', 'presence_penalty', 'stop', 'n', 'logit_bias', 'logprobs']:
            if getattr(self, attr) is not None:
                d[attr] = getattr(self, attr)
        return d


from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch

model_id = "mistralai/Mistral-7B-Instruct-v0.3"

tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto"
)
class LLMClient:

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def call_with_retry(self, prompt_builder, sampling_config, postprocessor=None,
                        filter=lambda s: True, max_attempts=5, **kwargs):
        for attempt in range(max_attempts):
            try:
                completions, full_obj = self(prompt_builder, sampling_config, **kwargs)

                if postprocessor:
                    completions = postprocessor(completions)

                completions = [c for c in completions if filter(c)]
                if completions:
                    return completions, full_obj

            except Exception as e:
                print(f"ERROR attempt {attempt+1}: {e}")

        raise RuntimeError("Failed after retries.")

    def __call__(self, prompt_builder, sampling_config, **kwargs):
        messages = prompt_builder.render_for_llm_format(sampling_config.prompt_format)

        prompt = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        max_tokens = sampling_config.max_tokens if sampling_config.max_tokens is not None else 512
        temperature = sampling_config.temperature if sampling_config.temperature is not None else 1.0
        top_p = sampling_config.top_p if sampling_config.top_p is not None else 1.0

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        generated_ids = outputs[0, inputs.input_ids.shape[-1]:]
        text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)

        return [text], None

llm_client = LLMClient(model, tokenizer)

2025-12-05 14:37:32.207416: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1764945452.365042     242 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1764945452.411622     242 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

2025-12-05 14:37:42 INFO     NumExpr defaulting to 4 threads.


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Extract Premise

In [52]:
class Premise:
    @staticmethod
    def load(path):
        with open(path, 'r') as f:
            data = json.load(f)
            return Premise(data['title'], data['premise'])

    def __init__(self, title=None, premise=None):
        self.title = title
        self.premise = premise

    def __str__(self):
        return f'Title: {self.title}\n\nPremise: {self.premise}'

    def save(self, path):
        with open(path, 'w') as f:
            json.dump({
                'title': self.title,
                'premise': self.premise
            }, f, indent=4)

In [53]:
config_yaml_content = """
defaults:
  output_path: output/premise.json
  logging_level: info
  MODEL:
    engine: mistralai/Mistral-7B-Instruct-v0.3
    tensor_parallel_size: 1
    server_type: vllm
    host: http://localhost
    port: 9741
    prompt_format: openai-chat
    temperature: 0.7
    top_p: 0.95
    frequency_penalty: 0
    presence_penalty: 0
    TITLE:
      max_tokens: 32
      stop: []
    PREMISE:
      max_tokens: 200
      stop: ["\n"]
"""


all_confs = recursive_lowercase_keys(yaml.safe_load(config_yaml_content ))
config = Config.load_from_dict(all_confs, ['defaults'])

print("Configuration loaded")


Configuration loaded


In [54]:
def generate_title(premise_object, title_prompts, title_config, llm_client):
    title = llm_client.call_with_retry(
        title_prompts.format(educational_summary_input=educational_summary_input),
        SamplingConfig.from_config(title_config),
        filter=min_max_tokens_filter(0, title_config['max_tokens'])
    )[0]
    premise_object.title = title
    return premise_object


def generate_premise(premise_object, premise_prompts, premise_config, llm_client):
    premise = llm_client.call_with_retry(
        premise_prompts.format(
            title=premise_object.title,
            educational_summary_input=educational_summary_input
        ),
        SamplingConfig.from_config(premise_config),
        filter=min_max_tokens_filter(0, premise_config['max_tokens'])
    )[0]
    premise_object.premise = premise
    return premise_object


In [55]:
educational_summary_input = """
Plants are living things that need care to grow strong and healthy. Every plant starts as a tiny seed. When the seed is placed in soil and given water, it begins to wake up. Soon, small roots grow down into the soil to drink water and collect minerals. After that, a little stem grows upward, reaching for the sunlight.

Plants use sunlight to make their own food in a process called photosynthesis. This helps them grow leaves, flowers, and sometimes fruits or vegetables. Different plants need different amounts of water and sunlight, but all of them need love, attention, and patience. By taking care of plants, children learn responsibility and understand how nature works around them.
"""

prompts_json_content = """
{
   "title": {
     "instruction": "Write a fun, simple, and playful title for a children's story based on this summary: {educational_summary_input}. Keep it short and exciting.",
     "response_prefix": ""
   },
   "premise": {
     "instruction": "Write a one-paragraph story premise suitable for kids. Describe the world, the main character, and the adventure. Use simple words, short sentences, and fun imagery. Educational summary: {educational_summary_input}. Do NOT include the word 'Title' or any headings.",
     "response_prefix": ""
   }
}


"""
prompts_dict = json.loads(prompts_json_content)
prompts = load_prompts_from_dict(prompts_dict)
print("Prompts loaded and templates created.")

Prompts loaded and templates created.


In [56]:
try:
    init_logging(config.logging_level)
    logging.info("Starting premise generation...")


    premise = Premise()

    logging.info("Generating title...")
    generate_title(premise, prompts['title'], config['model']['title'], llm_client)
    logging.info(f'Generated title: {premise.title}')

    logging.info("Generating premise...")

    generate_premise(premise, prompts['premise'], config['model']['premise'], llm_client)
    logging.info(f'Generated premise: {premise.premise}')

    output_path = config['output_path']
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    premise.save(output_path)

    print("\n--- FINAL RESULT ---")
    print(premise)
    print(f"\nPremise object saved to: {output_path}")

except Exception as e:
    logging.error(f"An error occurred during execution. Please check your model configuration and ensure your LLM server is running and accessible. Error: {e}")
    raise e

2025-12-05 15:11:13 INFO     Starting premise generation...
2025-12-05 15:11:13 INFO     Generating title...
2025-12-05 15:11:16 INFO     Generated title: ['"Seed Sprouts Surprise: A Sunlit Tale of Growth and Green Thumbs"']
2025-12-05 15:11:16 INFO     Generating premise...
2025-12-05 15:11:41 INFO     Generated premise: ["In a vibrant, sun-kissed Gardenia Village, where every house has a blooming garden, lives a curious and kind-hearted child named Marigold. One sunny morning, Marigold finds a small, sleepy seed in her grandmother's garden. With a twinkle in her eyes, she decides to nurture this tiny promise of a plant. Marigold carefully places the seed in a pot filled with nutrient-rich soil and waters it gently. As days pass, she watches and waits, watering and talking to her new friend, hoping and praying for it to grow. One magical day, a green shoot pokes through the soil, reaching up towards the sky. Marigold cheers, realizing that she's not just growing a plant, but also lear


--- FINAL RESULT ---
Title: ['"Seed Sprouts Surprise: A Sunlit Tale of Growth and Green Thumbs"']

Premise: ["In a vibrant, sun-kissed Gardenia Village, where every house has a blooming garden, lives a curious and kind-hearted child named Marigold. One sunny morning, Marigold finds a small, sleepy seed in her grandmother's garden. With a twinkle in her eyes, she decides to nurture this tiny promise of a plant. Marigold carefully places the seed in a pot filled with nutrient-rich soil and waters it gently. As days pass, she watches and waits, watering and talking to her new friend, hoping and praying for it to grow. One magical day, a green shoot pokes through the soil, reaching up towards the sky. Marigold cheers, realizing that she's not just growing a plant, but also learning the secrets of life - patience, love, and the magic of nature. This exciting adventure of growing a plant from a seed teaches Marigold about responsibility"]

Premise object saved to: output/premise.json


## Extract Plan

In [57]:
config_yaml_content = """
defaults:
  premise_path: output/premise.json
  output_path: output/plan.json
  logging_level: info
  MODEL:
    engine: mistralai/Mistral-7B-Instruct-v0.3
    tensor_parallel_size: 1
    server_type: vllm
    host: http://localhost
    port: 9741
    prompt_format: openai-chat
    temperature: 0.7
    top_p: 0.9
    frequency_penalty: 0
    presence_penalty: 0
    PLAN:
      SETTING:
        max_tokens: 256
        stop: []
      ENTITY:
        max_attempts: 2
        min_entities: 1
        max_entities: 4
        NAME:
          max_tokens: 16
          stop: ["\n", ",", ":", "("]
        DESCRIPTION:
          max_tokens: 60
      OUTLINE:
        max_attempts: 2
        expansion_policy: breadth-first
        max_depth: 1
        context: ancestors-with-siblings-children
        min_children: 1
        preferred_max_children: 2
        max_children: 3
        EVENT_DEPTH_0:
          max_tokens: 100
        EVENT:
          frequency_penalty: 0.3
          max_tokens: 100
        SCENE:
          context: ancestors-with-siblings
          max_tokens: 80
        ENTITY_DEPTH_0:
          max_tokens: 80
        ENTITY:
          max_tokens: 80

"""


all_confs = recursive_lowercase_keys(yaml.safe_load(config_yaml_content ))
config = Config.load_from_dict(all_confs, ['defaults'])
print("Configuration loaded")

Configuration loaded


In [72]:
prompts_json_content = """
{
  "plan": {
    "setting": {
      "instruction": "Create a fun, colorful, and simple setting for a children's story. Use short sentences and words kids understand. Show what it looks like, sounds like, and feels like. Title: {title}//Premise: {premise}",
      "response_prefix": ""
    },
    "entity": {
      "name": {
        "instruction": "Generate only the next character's name for a children's story. Use fun and simple names. Output only the name. Do not repeat previous names. Title: {title}//Premise: {premise}//Setting: {setting}//Existing Characters: {entity_list}",
        "response_prefix": ""
      },
      "description": {
        "instruction": "Describe the character in one simple sentence. Include what they look like and what makes them special. Avoid hard words. Title: {title}//Premise: {premise}//Setting: {setting}//Character: {entity_name}",
        "response_prefix": ""
      }
    },
    "outline": {
      "event_depth_0": {
        "instruction": "Write the first important event of the story in one short, clear sentence for kids. Only describe the event. Title: {title}//Premise: {premise}//Setting: {setting}//Characters: {entities}",
        "response_prefix": ""
      },
      "entity_depth_0": {
        "instruction": "List all main characters appearing in this top-level event. Output them as a comma-separated list. Only use names from the global entity list. Do not invent new names. Title: {title}//Premise: {premise}//Setting: {setting}//Event: {current_event}//Detected Entities: {detected_entities}",
        "response_prefix": ""
      },
      "event": {
        "instruction": "Write the next event in the story as one simple sentence. Keep it fun, clear, and easy to understand for kids. Title: {title}//Premise: {premise}//Setting: {setting}//Characters: {entities}//Outline so far://{context_prefix}",
        "response_prefix": ""
      },
      "scene": {
        "instruction": "Describe where this event happens in one sentence. Keep it easy to picture and kid-friendly. Title: {title}//Premise: {premise}//Setting: {setting}//Characters: {entities}//Event: {current_event}",
        "response_prefix": ""
      },
      "entity": {
        "instruction": "Identify all characters present in this event. Use only names from the main entity list. Return them as a comma-separated list. Title: {title}//Premise: {premise}//Setting: {setting}//Event: {current_event}//Scene: {current_scene}//Detected Entities: {detected_entities}",
        "response_prefix": ""
      }
    }
  }
}
"""

prompts_dict = json.loads(prompts_json_content)
prompts = load_prompts_from_dict(prompts_dict)
print("Prompts loaded and templates created.")


Prompts loaded and templates created.


In [73]:
import argparse
import os

from pathlib import Path
import string
from collections.abc import Sequence
from functools import partial
import string
import uuid

In [74]:
class Setting:
    def __init__(self, setting):
        if isinstance(setting, list) and setting:
            self.setting = setting[0]
        else:
            self.setting = setting

    def __str__(self):
        if isinstance(self.setting, str):
            return self.setting
        return str(self.setting)

class Plan:
    @staticmethod
    def load(path):
        with open(path, 'r') as f:
            data = json.load(f)

        premise = Premise(data['premise']['title'], data['premise']['premise'])
        setting = Setting(data['setting'])

        flat = []

        def add_item(obj):
            if obj is None:
                return
            if isinstance(obj, Entity):
                flat.append(obj)
            elif isinstance(obj, dict):
                if "name" in obj and "description" in obj:
                    flat.append(Entity(obj["name"], obj["description"]))
            elif isinstance(obj, list):
                for sub in obj:
                    add_item(sub)

        add_item(data["entities"])

        entity_list = EntityList(flat)

        outline = OutlineNode.from_dict(data['outline'])

        return Plan(premise, setting, entity_list, outline)


    def __init__(self, premise, setting=None, entity_list=None, outline=None):
        self.premise = premise
        self.setting = setting
        self.entity_list = entity_list
        self.outline = outline

    def __str__(self):
        premise_str = str(self.premise) if self.premise is not None else ""

        def flatten_list(x):
            if isinstance(x, list):
                flat = []
                for item in x:
                    if isinstance(item, list):
                        flat.extend(flatten_list(item))
                    else:
                        flat.append(item)
                return flat
            return [x]

        try:
            setting_items = flatten_list(self.setting)
            setting_str = "\n".join(str(s) for s in setting_items)
        except:
            setting_str = str(self.setting.setting) if self.setting and hasattr(self.setting, 'setting') else str(self.setting)

        try:
            if hasattr(self.entity_list, 'entities'):
                entities_str = str(self.entity_list)
            else:
                entities_str = "\n".join(
                    f"{i+1}. {str(e.name)}: {str(e.description)}"
                    for i, e in enumerate(self.entity_list)
                )
        except Exception as e:
                entities_str = str(self.entity_list)
                logging.error(f"Error formatting entities: {e}")


        try:
            outline_str = str(self.outline)
        except:
            outline_str = str(self.outline)


        return (
            f"{premise_str}\n\n"
            f"Setting:\n{setting_str}\n\n"
            f"Characters and Entities:\n{entities_str}\n\n"
            f"Outline:\n{outline_str}"
        )



    def save(self, path):
        with open(path, 'w') as f:
            json.dump({
                'premise': {
                    'title': self.premise.title,
                    'premise': self.premise.premise
                },
                'setting': self.setting.setting,
                'entities': [{
                    'name': entity.name,
                    'description': entity.description
                } for entity in self.entity_list],
                'outline': self.outline.to_dict()
            }, f, indent=4)

In [75]:
try:
    from nltk.corpus import stopwords
    _ = stopwords.words('english')
except:
    import nltk
    nltk.download('stopwords')
    from nltk.corpus import stopwords


class Entity:
    def __init__(self, name, description):

        if isinstance(name, list):
            name = name[0] if name else ""
        self.name = str(name).strip()

        if isinstance(description, list):
            description = description[0] if description else ""
        self.description = str(description).strip()



class EntityList:
    def __init__(self, entities=None):
        self.entities = entities if entities is not None else []

    def __len__(self):
        return len(self.entities)

    def __str__(self):
        lines = []
        for i, entity in enumerate(self.entities):
            name = entity.name
            desc = entity.description

            if isinstance(name, list):
                name = name[0] if name else ""
            if isinstance(desc, list):
                desc = desc[0] if desc else ""

            lines.append(f"{i+1}. {str(name)}: {str(desc)}")

        return "\n\n".join(lines)


    def print_with_full_names(self):
        return '\n\n'.join([f'{i+1}. Full Name: {entity.name}\n\nDescription: {entity.description}' for i, entity in enumerate(self.entities)])

    def __iter__(self):
        return iter(self.entities)

    def __getitem__(self, index):
        return self.entities[index]

    def get_entity_by_name(self, name):
        for entity in self.entities:
            if entity.name == name:
                return entity
        raise ValueError(f'EntityList has no entity named {name}.')


def detect_entities(event, entity_list):
    detected_entities = []

    if isinstance(event, list) and event:
        event = event[0]

    if not isinstance(event, str):
        logging.error(f"Entity detection failed: 'event' is not a string after normalization: {event}")
        return detected_entities

    event_lower = event.lower()

    stopwords_list = stopwords.words('english')

    for entity in entity_list:
        entity_name = entity.name
        if isinstance(entity_name, list) and entity_name:
            entity_name = entity_name[0]

        if not isinstance(entity_name, str):
            continue

        for name_part in entity_name.split():
            name_part_lower = name_part.lower()

            if name_part_lower in stopwords_list:
                continue

            if name_part_lower in event_lower:
                if entity.name not in detected_entities:
                    detected_entities.append(entity.name)
                break
    return detected_entities


In [76]:
def num_to_char(n):
    """Converts a number (1, 2, 3...) to a letter (A, B, C...)."""
    if n < 1:
        return '0'
    return chr(64 + n)
def num_to_roman(n):
     import roman
     return roman.toRoman(n)

class OutlineNode(Sequence):
    def pretty(self):
        """Generates a nicely formatted, hierarchical string representation of the outline."""
        lines = []
        for node in self.depth_first_traverse(include_self=False):
             lines.append(node.format_self())
        return '\n'.join(lines)

    @staticmethod
    def from_dict(d, parent=None):
        node = OutlineNode(d['text'], parent, d['scene'], d['entities'], d['id'])
        node.children = [OutlineNode.from_dict(child, node) for child in d['children']]
        return node

    @staticmethod
    def num_converter(depth):
        if depth == 0:
            return lambda num: ''
        if depth % 3 == 1:
            return str
        elif depth % 3 == 2:
            return num_to_char
        elif depth % 3 == 0:
            return num_to_roman

    @staticmethod
    def indent(depth):
        if depth == 0:
            return ''
        return '\t' * (depth-1)

    def __init__(self, text, parent, scene='', entities=None, id=None):

        if isinstance(text, list):
            if text:
                text = text[0]
            else:
                text = ""

        if not isinstance(text, str):
            text = str(text)

        self.text = text.strip()
        self.entities = entities if entities is not None else []
        self.scene = scene
        self.children = []
        self.parent = parent
        self.id = str(uuid.uuid4()) if id is None else id
        super().__init__()

    def __hash__(self):
        return hash(self.id)

    def __eq__(self, other):
        return self.id == other.id

    def to_dict(self):
        return {
            'text': self.text,
            'scene': self.scene,
            'entities': self.entities,
            'children': [child.to_dict() for child in self.children],
            'id': self.id
        }

    def format_self(self):
        if isinstance(self.text, list) and self.text:
            self.text = self.text[0]
        if not isinstance(self.text, str):
            self.text = ""

        scene_text = self.scene
        if isinstance(scene_text, list) and scene_text:
            scene_text = scene_text[0]
        if not isinstance(scene_text, str):
            scene_text = ""

        s = self.number() + self.text

        if len(scene_text) > 0:
            s += ' Scene: ' + scene_text

        return s


    def __str__(self):
        s = f"{self.number()}: {self.text}"

        for child in self.children:
            s += "\n" + str(child)
        return s

    def __len__(self):
        return len(self.children)

    def __getitem__(self, index):
        return self.children[index]

    def get_node_by_id(self, id):
        for node in self.root().depth_first_traverse():
            if node.id == id:
                return node
        return None

    def number(self, depth_shift=0, lookforward=0, convert=True):
        if self.parent is None:
            num = 1
        else:
            try:
                num = self.parent.children.index(self) + 1
            except ValueError:
                num = 1
                for i, child in enumerate(self.parent.children):
                    if child.id == self.id:
                        num = i + 1
                        break

        num += lookforward

        if convert:
            depth = self.depth() + depth_shift
            if depth == 0:
                return ''
            return '\t' * (depth-1) + OutlineNode.num_converter(depth)(num) + '. '

        return num

    def depth(self):
        if self.parent is None:
            return 0
        return 1 + self.parent.depth()

    def root(self):
        if self.parent is None:
            return self
        return self.parent.root()

    def predecessor(self, max_depth=1e8):
        nodes = list(self.root().depth_first_traverse(max_depth=max_depth))
        idx = nodes.index(self)
        return nodes[idx-1] if idx > 0 else None

    def successor(self, max_depth=1e8):
        nodes = list(self.root().depth_first_traverse(max_depth=max_depth))
        idx = nodes.index(self)
        return nodes[idx+1] if idx < len(nodes)-1 else None

    def ancestors(self, include_self=False):
        if self.parent is None:
            return [self] if include_self else []
        return self.parent.ancestors(include_self=True) + ([self] if include_self else [])

    def siblings(self, include_self=False):
        if self.parent is None:
            return []
        return [child for child in self.parent.children if (include_self or child != self)]

    def leaves(self):
        if len(self.children) == 0:
            return [self]
        return sum([child.leaves() for child in self.children], [])

    def depth_first_traverse(self, include_self=True, max_depth=1e8):
        if self.depth() <= max_depth and include_self:
            yield self
        for child in self.children:
            yield from child.depth_first_traverse(max_depth=max_depth)

    def breadth_first_traverse(self, include_self=True, max_depth=1e8):
        if self.depth() <= max_depth and include_self:
            yield self
        if self.depth() < max_depth:
            queue = [c for c in self.children]
            while queue:
                n = queue.pop(0)
                yield n
                if n.depth() < max_depth:
                    queue.extend(n.children)

    def context(self, context_type):
        if context_type == 'full':
            selected_nodes = set(list(self.root().depth_first_traverse(include_self=False)))
        elif context_type == 'ancestors':
            selected_nodes = set(self.ancestors(include_self=False))
        elif context_type == 'ancestors-with-siblings':
            ancestors = list(self.ancestors(include_self=True))
            selected_nodes = set(sum([a.siblings(include_self=True) for a in ancestors], []))
        elif context_type == 'ancestors-with-siblings-children':
            ancestors = list(self.ancestors(include_self=True))
            anc_sibs = sum([a.siblings(include_self=True) for a in ancestors], [])
            selected_nodes = set(anc_sibs + sum([node.children for node in anc_sibs], []))
        else:
            raise NotImplementedError()

        prefix = []
        suffix = []
        in_prefix = True

        for node in self.root().depth_first_traverse(include_self=False):
            if node == self:
                in_prefix = False
            elif node in selected_nodes:
                (prefix if in_prefix else suffix).append(node)

        return (
            '\n\n'.join([n.format_self() for n in prefix]),
            '\n\n'.join([n.format_self() for n in suffix])
        )

In [77]:
import time
import string
from functools import partial
import re

def split_numbered_items(text):

    items = re.split(r'\n?\s*\d+\.\s*', text)
    items = [item.strip() for item in items if item.strip()]
    return items


def generate_setting(plan, llm_client, setting_prompt, setting_config):
    plan.setting = Setting(
        llm_client.call_with_retry(
            setting_prompt.format(
                title=plan.premise.title,
                premise=plan.premise.premise
            ),
            SamplingConfig.from_config(setting_config),
            filter=Filter(lambda s: len(s.strip()) > 50),
            max_attempts=10
        )[0]
    )
    return plan

def generate_entities(plan, llm_client, entity_prompt, entity_config):

    normalized = []

    if plan.entity_list is not None:
        for item in plan.entity_list:
            if isinstance(item, Entity):
                normalized.append(item)

            elif isinstance(item, dict):
                normalized.append(Entity(item["name"], item["description"]))

            elif isinstance(item, list):
                for sub in item:
                    if isinstance(sub, Entity):
                        normalized.append(sub)
                    elif isinstance(sub, dict):
                        normalized.append(Entity(sub["name"], sub["description"]))

    plan.entity_list = EntityList(normalized)

    def postprocess_name(generated, **kwargs):
        if not isinstance(generated, (list, tuple)) or not generated:
            return [""]

        text = str(generated[0]).strip().split("\n")[0]
        text = re.sub(r'^\d+\.\s*', '', text)
        return [text.strip()]

    def postprocess_entity_description(descriptions, **kwargs):
        desc = descriptions[0].split("\n")[0].strip()
        return [desc]

    name_config = entity_config['name']
    desc_config = entity_config['description']

    while len(plan.entity_list) < entity_config['max_entities']:

        name = llm_client.call_with_retry(
            entity_prompt['name'].format(
                title=plan.premise.title,
                premise=plan.premise.premise,
                setting=plan.setting.setting,
                entity_list=", ".join(e.name for e in plan.entity_list if hasattr(e, 'name') and isinstance(e.name, str)),
            ),
            SamplingConfig.from_config(name_config),
            postprocessor=postprocess_name,
            filter=Filter(lambda s: len(s.strip()) > 1),
            max_attempts=10
        )[0]

        if isinstance(name, list):
           if name:
             name = name[0]
           else:
             name = ""
        name = str(name).strip()


        if name in [e.name for e in plan.entity_list]:
            break

        desc = llm_client.call_with_retry(
            entity_prompt['description'].format(
                title=plan.premise.title,
                premise=plan.premise.premise,
                setting=plan.setting.setting,
                entity_name=name
            ),
            SamplingConfig.from_config(desc_config),
            postprocessor=postprocess_entity_description,
            filter=Filter(lambda s: len(s.strip()) > 10),
            max_attempts=10
        )[0]

        if isinstance(desc, list):
           if desc:
             desc = desc[0]
           else:
             desc = ""
        desc = str(desc).strip()


        plan.entity_list.entities.append(Entity(name, desc))

    return plan

def generate_outline(plan, llm_client, outline_prompt, outline_config):
    plan.outline = OutlineNode('', None)
    max_nodes = 50
    while True:
        if len(list(plan.outline.depth_first_traverse())) > max_nodes:
            break

        try:
            node_to_expand = select_node_to_expand(plan.outline, outline_config)
        except StopIteration:
            break

        generate_node_subevents(node_to_expand, llm_client, outline_prompt, outline_config, plan, max_attempts=10)

    return plan

def generate_node_subevents(node, llm_client, outline_prompt, outline_config, plan, max_attempts=1):
    context_prefix = ""
    context_suffix = ""
    filter = Filter(lambda x: True)

    def event_postprocessor(events, **kwargs):
        responses = []
        for event in events:
            event = event.strip()
            event = re.sub(r'^\[[^\]]*\]\s*', '', event)
            event = event.split('\n')[0]
            event = event.split('Scene:')[0]
            event = event.split('Characters:')[0]
            event = event.strip()

            if not event:
                event = "Something happens."

            if event[-1] not in ".?!":
                event += "."

            responses.append(event)
        return responses

    if node.depth() == 0:
        event_config = outline_config['event_depth_0']
        event_prompt = outline_prompt['event_depth_0']
    else:
        event_config = outline_config['event']
        event_prompt = outline_prompt['event']

    for _ in range(outline_config['preferred_max_children']):
        new_child = OutlineNode('', node)

        event = llm_client.call_with_retry(
            event_prompt.format(
                title=plan.premise.title,
                premise=plan.premise.premise,
                setting=plan.setting.setting,
                entities=str(plan.entity_list),
                formatted_current_number=new_child.number().rstrip(),
                stripped_current_number=new_child.number().strip(),
                context_prefix=context_prefix,
                context_suffix=context_suffix,
                predecessor_info="",
                successor_info="",
                preferred_max_children=outline_config['preferred_max_children']
            ),
            SamplingConfig.from_config(event_config),
            postprocessor=partial(
                event_postprocessor,
                has_next_indicator="\n" + new_child.number(lookforward=1).strip(),
                current_number=new_child.number().strip()
            ),
            filter=filter,
            max_attempts=max_attempts
        )

        logging.warning(f"Raw LLM event: {event}")

        new_child.text = event[0]
        node.children.append(new_child)

        context_prefix, context_suffix = new_child.context(outline_config['context'])

        if len(node.children) >= outline_config['max_children']:
            break

        filter = Filter(lambda x: True)

        generate_node_scene(
            new_child, llm_client,
            outline_prompt['scene'], outline_config['scene'], plan
        )

        generate_node_entities(
            new_child, llm_client,
            outline_prompt['entity_depth_0'] if node.depth() == 0 else outline_prompt['entity'],
            outline_config['entity_depth_0'] if node.depth() == 0 else outline_config['entity'],
            plan
        )

def generate_node_scene(node, llm_client, scene_prompt, scene_config, plan):
    def scene_postprocessor(scenes, **kwargs):
        clean = []
        for sc in scenes:
            sc = sc.split('\n')[0].split('Characters:')[0].split('Scene:')[-1].strip()
            clean.append(sc)
        return clean

    context_prefix, context_suffix = node.context(scene_config['context'])

    node.scene = llm_client.call_with_retry(
        scene_prompt.format(
            title=plan.premise.title,
            premise=plan.premise.premise,
            setting=plan.setting.setting,
            entities=str(plan.entity_list),
            formatted_current_number=node.number().rstrip(),
            stripped_current_number=node.number().strip(),
            current_event=node.text,
            context_prefix=context_prefix,
            context_suffix=context_suffix
        ),
        SamplingConfig.from_config(scene_config),
        postprocessor=scene_postprocessor,
        filter=Filter(lambda s: len(s.strip()) > 0),
    )[0]

def generate_node_entities(node, llm_client, entity_prompt, entity_config, plan):
    detected = detect_entities(node.text, plan.entity_list)
    if detected:
       node.entities = detected
       return

    def entity_postprocessor(predicted_lists, entity_list, already_detected, **kwargs):
        out = []
        for ents in predicted_lists:
            ents = ents.split('\n')[0].strip().rstrip('.')
            ents = [e.strip() for e in ents.split(',')]
            ents = [e for e in ents if e in [x.name for x in entity_list]]
            ents = list(dict.fromkeys(ents))
            ents = [e for e in ents if e not in already_detected]
            out.append(already_detected + ents)
        return out

    detected = detect_entities(node.text[0], plan.entity_list)
    context_prefix, context_suffix = node.context(entity_config['context'])

    try:
        node.entities = llm_client.call_with_retry(
            entity_prompt.format(
                title=plan.premise.title,
                premise=plan.premise.premise,
                setting=plan.setting.setting,
                entities=str(plan.entity_list),
                formatted_current_number=node.number().rstrip(),
                stripped_current_number=node.number().strip(),
                current_event=node.text,
                current_scene=node.scene,
                context_prefix=context_prefix,
                context_suffix=context_suffix,
                detected_entities=", ".join(detected)
            ),
            SamplingConfig.from_config(entity_config),
            postprocessor=partial(entity_postprocessor,
                                  entity_list=plan.entity_list,
                                  already_detected=detected),
            filter=Filter(lambda l: len(l) > 0),
            max_attempts=20
        )[0]

    except Exception:
        node.entities = detected

def select_node_to_expand(outline, outline_config):
    if outline_config['expansion_policy'] == 'breadth-first':
        for node in outline.breadth_first_traverse(max_depth=outline_config['max_depth'] - 1):
            if len(node.children) == 0:
                return node
        raise StopIteration
    else:
        raise NotImplementedError


In [78]:
try:
    prompts_dict = json.loads(prompts_json_content)
    prompts = load_prompts_from_dict(prompts_dict)
    premise = Premise.load(config['premise_path'])

    plan = Plan(premise)

    plan_config = config['model']['plan']
    plan_prompts = prompts['plan']

    logging.info("Generating setting...")
    plan = generate_setting(
        plan,
        llm_client,
        plan_prompts['setting'],
        plan_config['setting']
    )
    logging.info(f"Generated setting: {plan.setting.setting}")
    torch.cuda.empty_cache()

    logging.info("Generating entities...")
    plan = generate_entities(
        plan,
        llm_client,
        plan_prompts['entity'],
        plan_config['entity']
    )
    logging.info(f"Generated entities: {plan.entity_list}")
    torch.cuda.empty_cache()

    logging.info("Generating outline...")
    plan = generate_outline(
        plan,
        llm_client,
        plan_prompts['outline'],
        plan_config['outline']
    )
    logging.info(f"Generated outline with {len(list(plan.outline.depth_first_traverse()))} nodes.")
    torch.cuda.empty_cache()


    output_path = config['output_path']
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    plan.save(output_path)

    print("\n--- FINAL RESULT ---")
    print(plan)
    print(f"\nPlan object saved to: {output_path}")

except Exception as e:
    logging.error(f"An error occurred during execution. Please check your model configuration and ensure your LLM server is running and accessible. Error: {e}")
    raise e

2025-12-05 15:15:29 INFO     Generating setting...
2025-12-05 15:16:02 INFO     Generated setting: "As the seed sprouts, Marigold's garden becomes a bustling, rainbow-colored haven filled with chirping birds, buzzing bees, and butterflies flitting about. The sun shines bright, casting a warm, golden glow on the blooming flowers and the growing plant.

'Sprout, little seed!' Marigold whispers, as she tenderly waters her new friend each day. She feels the cool water drip from the watering can, down onto the soil, and watches as the green shoot grows taller and stronger.

'Look at you grow!' Marigold exclaims, as the plant stretches its leaves towards the sky. The leaves start to unfurl, revealing tiny, delicate flowers that sparkle in the sunlight.

'You're beautiful!' Marigold says, as she admires her creation. She feels a sense of pride and happiness, knowing that she's helped something grow from a tiny seed into a beautiful plant.

One day, Marigold's plant blooms, filling her garden 


--- FINAL RESULT ---
Title: ['"Seed Sprouts Surprise: A Sunlit Tale of Growth and Green Thumbs"']

Premise: ["In a vibrant, sun-kissed Gardenia Village, where every house has a blooming garden, lives a curious and kind-hearted child named Marigold. One sunny morning, Marigold finds a small, sleepy seed in her grandmother's garden. With a twinkle in her eyes, she decides to nurture this tiny promise of a plant. Marigold carefully places the seed in a pot filled with nutrient-rich soil and waters it gently. As days pass, she watches and waits, watering and talking to her new friend, hoping and praying for it to grow. One magical day, a green shoot pokes through the soil, reaching up towards the sky. Marigold cheers, realizing that she's not just growing a plant, but also learning the secrets of life - patience, love, and the magic of nature. This exciting adventure of growing a plant from a seed teaches Marigold about responsibility"]

Setting:
"As the seed sprouts, Marigold's garden be

## Generate the final story

In [79]:
class Story:
    def __init__(self):
        self.passages = []

    def add_passage(self, passage_dict):
        self.passages.append(passage_dict)

    def save(self, path):
        os.makedirs(os.path.dirname(path), exist_ok=True)
        with open(path, "w", encoding="utf-8") as f:
            json.dump({
                "story": self.passages
            }, f, indent=2, ensure_ascii=False)

    def __str__(self):
        all_text = "\n\n".join([p["text"] for p in self.passages])
        return all_text

In [80]:
prompts_story_json_content = """
{
  "story": {
    "write": {
      "instruction": "Write the next passage of the children's story. Use short sentences, simple words, and fun descriptions. Keep the tone cheerful and engaging. Story so far://{story_so_far}//Outline Event://{outline_event}//Scene://{scene}//Characters://{entities}//Write a clear, kid-friendly passage:",
      "response_prefix": ""
    },
    "score": {
      "instruction": "Rate the quality of this passage for a children's story. Focus on clarity, fun, simplicity, and imagination. Return a number from 1 to 10. Passage://{passage}//Score:",
      "response_prefix": ""
    },
    "summarize": {
      "instruction": "Summarize this passage into 1â€“2 sentences for kids. Use simple language and keep it fun. Passage://{passage}//Summary:",
      "response_prefix": ""
    }
  }
}
"""

story_prompts_dict = json.loads(prompts_story_json_content)
story_prompts = load_prompts_from_dict(story_prompts_dict)
print("Story Prompts loaded and templates created.")


Story Prompts loaded and templates created.


In [81]:
class StoryWriter:
    def __init__(self, llm_client, prompts, config):
        self.llm = llm_client
        self.prompts = prompts
        self.config_write = SamplingConfig.from_config(config["write"])
        self.config_score = SamplingConfig.from_config(config["score"])
        self.config_summarize = SamplingConfig.from_config(config["summarize"])

    def generate_passage(self, story_so_far, outline_event, scene, entities):
        prompt_builder = self.prompts["write"].format(
            story_so_far=story_so_far,
            outline_event=outline_event,
            scene=scene,
            entities=", ".join(entities)
        )
        result = self.llm.call_with_retry(
            prompt_builder,
            self.config_write,
            max_attempts=3
        )[0]

        if isinstance(result, list):
            result = result[0]
        return result


    def score_passage(self, passage):
        prompt_builder = self.prompts["score"].format(
            passage=passage
        )
        result = self.llm.call_with_retry(
            prompt_builder,
            self.config_score,
            max_attempts=2
        )[0]

        if isinstance(result, list):
            result = result[0]
        return result.strip()


    def summarize_passage(self, passage):
        prompt_builder = self.prompts["summarize"].format(
            passage=passage
        )
        result = self.llm.call_with_retry(
            prompt_builder,
            self.config_summarize,
            max_attempts=2
        )[0]

        if isinstance(result, list):
            result = result[0]
        return result


In [82]:
story_config_yaml = """
model:
  engine: "mistralai/Mistral-7B-Instruct-v0.3"
  host: "http://localhost"
  port: 8000
  server_type: vllm
  tensor_parallel_size: 1

  story:
    write:
      max_tokens: 350
      temperature: 0.7
      top_p: 0.9
      prompt_format: openai-chat
    score:
      max_tokens: 20
      temperature: 0.1
      top_p: 0.5
      prompt_format: openai-chat
    summarize:
      max_tokens: 80
      temperature: 0.3
      top_p: 0.9
      prompt_format: openai-chat

output_path: "outputs/story.json"

"""

story_config = Config(yaml.safe_load(story_config_yaml), None)
print("Story configuration loaded.")

Story configuration loaded.


In [83]:
def generate_story(plan, llm_client, prompts, config):
    story_writer = StoryWriter(
        llm_client=llm_client,
        prompts=prompts["story"],
        config=config["model"]["story"]
    )
    story = Story()
    logging.info("Beginning story generation from outline...")
    outline_nodes = list(plan.outline.depth_first_traverse())

    story_text_so_far = ""
    for node in outline_nodes:
        logging.info(f"Generating passage for node {node.number()}: {node.text}")
        passage = story_writer.generate_passage(
            story_so_far=story_text_so_far,
            outline_event=node.text,
            scene=node.scene,
            entities=node.entities
        )
        summary = story_writer.summarize_passage(passage)
        score = story_writer.score_passage(passage)

        story.add_passage({
            "event_number": node.number(),
            "text": passage,
            "summary": summary,
            "score": score,
            "entities": node.entities,
            "scene": node.scene
        })
        story_text_so_far += "\n" + passage
    return story

if __name__ == "__main__":
    try:
        logging.info("Loading plan...")
        plan_path = "output/plan.json"

        plan = Plan.load(plan_path)

        logging.info("Configuration and prompts are already loaded globally.")

        logging.info("Generating final story...")
        story = generate_story(
            plan,
            llm_client,
            story_prompts,
            story_config["model"]["story"]
        )

        output_path = story_config["output_path"]
        story.save(output_path)

        print("\n--- FINAL STORY GENERATED SUCCESSFULLY ---")
        print(f"Saved to {output_path}")
        print("\n--- FINAL STORY ---")
        print(story)

    except Exception as e:
        logging.error(f"Story Generation failed: {e}")
        raise e

2025-12-05 15:18:00 INFO     Loading plan...
2025-12-05 15:18:00 INFO     Configuration and prompts are already loaded globally.
2025-12-05 15:18:00 INFO     Generating final story...
2025-12-05 15:18:00 INFO     Beginning story generation from outline...
2025-12-05 15:18:00 INFO     Generating passage for node : 
2025-12-05 15:18:58 INFO     Generating passage for node 1. : 5. "Grandma Rose": "Grandma Rose, a wise and gentle woman with a heart as warm as the sun, is Marigold's beloved grandmother. She is a master gardener, sharing her vast knowledge of plants and the magic of nature with Marigold. Her kindness and patience inspire Marigold to become a responsible and caring gardener like her.".
2025-12-05 15:19:59 INFO     Generating passage for node 2. : 4. Grandma Elara: "Grandma Elara" is a wise, loving, and nurturing figure in Marigold's life. She has a green thumb and is a well-respected gardener in Gardenia Village. Her garden, which is filled with a variety of colorful and exot


--- FINAL STORY GENERATED SUCCESSFULLY ---
Saved to outputs/story.json

--- FINAL STORY ---
In a bright, sunny meadow, where daisies danced and butterflies fluttered, lived a tiny, sprightly squirrel named Squeaky. Squeaky loved to race up and down the tallest trees, chasing the wind and collecting acorns.

One sunny afternoon, Squeaky found a shiny, golden acorn hidden beneath a patch of clover. It was bigger than any acorn he had ever seen! Squeaky picked it up and held it close, feeling its warm, golden glow.

"Oh, what a treasure!" Squeaky exclaimed. "I wonder what this could be!"

Just then, a friendly, talking owl named Oliver swooped down from a nearby tree. "Hello, Squeaky!" he greeted. "What have you got there?"

"I found this golden acorn!" Squeaky replied, showing it to Oliver. "I've never seen anything like it before!"

Oliver's eyes widened with curiosity. "A golden acorn? That's quite unusual! It's said that golden acorns hold magic within them. Perhaps it's a sign of a 