# Interactive Storybook Prompt Pipeline Tester

Use this notebook to run and validate each prompt stage (1-5) and downstream image generation with Gemini 2.5 Flash.

## Notebook Roadmap
- Environment setup
- Prompt loading utilities
- Stage-by-stage testers (Stages 1-5)
- Reference image prompt generation
- Style-guided image generation
- Full pipeline runner
- Validation helpers

> Stage 6 guardrail review is intentionally omitted for now.

---

## 1. Environment Setup

Configure library imports, load API credentials, and instantiate SDK clients.

> The notebook falls back to `~/.zshrc` if keys are not present in the current environment.

In [169]:
import os
import json
import base64
import textwrap
import datetime as dt
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import xml.etree.ElementTree as ET

from IPython.display import JSON, Markdown, display

from openai import OpenAI
from google import genai
from google.genai import types
from PIL import Image


In [170]:
NOTEBOOK_ROOT = Path.cwd()
PROMPT_DIR = NOTEBOOK_ROOT / 'prompts'
STYLE_GUIDE_PATH = PROMPT_DIR / 'image_style_guidance.json'
OUTPUT_ROOT = NOTEBOOK_ROOT / 'docs' / 'outputs'
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

def load_api_key(env_var: str) -> Optional[str]:
    key = os.getenv(env_var)
    if key:
        return key
    zshrc = Path.home() / '.zshrc'
    if zshrc.exists():
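        for line in zshrc.read_text().splitlines():
            line = line.strip()
            # Accept both `VAR=value` and `export VAR=value` assignments.
            if line.startswith('export '):
                line = line[len('export '):].lstrip()
            if line.startswith(f'{env_var}='):
                value = line.split('=', 1)[1].strip()
                # Strip one pair of matching surrounding quotes.
                if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                    value = value[1:-1]
                return value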
    return None

OPENAI_API_KEY = load_api_key('OPENAI_API_KEY')
if not OPENAI_API_KEY:
    raise EnvironmentError('OPENAI_API_KEY is required. Set it in the environment or ~/.zshrc.')

GEMINI_API_KEY = load_api_key('GEMINI_API_KEY')
if not GEMINI_API_KEY:
    raise EnvironmentError('GEMINI_API_KEY is required. Set it in the environment or ~/.zshrc.')

openai_client = OpenAI(api_key=OPENAI_API_KEY, max_retries=5)
gemini_client = genai.Client(api_key=GEMINI_API_KEY)

TIMESTAMP = dt.datetime.now().strftime('%Y%m%d-%H%M%S')
RUN_OUTPUT_DIR = OUTPUT_ROOT / TIMESTAMP
RUN_OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

import wave  # used later by write_pcm_to_wav (TTS helpers)

print('✅ OpenAI client ready (GPT-5 via Responses API)')
print('✅ Gemini client ready (Nano Banana / Gemini 2.5 Flash)')
print(f'📂 Outputs will be stored in: {RUN_OUTPUT_DIR}')


✅ OpenAI client ready (GPT-5 via Responses API)
✅ Gemini client ready (Nano Banana / Gemini 2.5 Flash)
📂 Outputs will be stored in: /Users/jacky/Tomo/Jacky/talkbookv1/docs/outputs/20251016-120035
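Each setup run writes to a new `docs/outputs/<timestamp>` folder, and because `TIMESTAMP` uses the zero-padded `%Y%m%d-%H%M%S` format, sorting folder names also sorts them chronologically. That makes it easy to locate an earlier run, for instance to inspect its saved outputs. A minimal sketch; `latest_run_dir` is a hypothetical helper, not part of the notebook:

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

# Matches folder names produced by strftime('%Y%m%d-%H%M%S'), e.g. 20251016-120035.
RUN_DIR_PATTERN = re.compile(r'\d{8}-\d{6}')


def latest_run_dir(output_root: Path, exclude: Optional[Path] = None) -> Optional[Path]:
    """Return the newest timestamped run folder under output_root, or None (hypothetical helper)."""
    candidates = sorted(
        p for p in output_root.iterdir()
        if p.is_dir() and p != exclude and RUN_DIR_PATTERN.fullmatch(p.name)
    )
    # Zero-padded timestamps sort lexicographically in chronological order.
    return candidates[-1] if candidates else None


with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    for name in ['20251014-180512', '20251016-120035', '20251015-093000', 'scratch']:
        (root / name).mkdir()
    print(latest_run_dir(root).name)  # 20251016-120035
    # Skip the current run to find the previous one:
    print(latest_run_dir(root, exclude=root / '20251016-120035').name)  # 20251015-093000
```

In the notebook, `latest_run_dir(OUTPUT_ROOT, exclude=RUN_OUTPUT_DIR)` would return the previous run's folder, if any.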


---

## 2. Prompt Management

Utility helpers for loading prompt templates, rendering them with runtime context, and storing raw model responses.

### Running & Saving Outputs

Run the environment setup cell before any stage testers so the notebook knows where to write files.

1. Execute the setup cell that loads the API keys and instantiates the GPT-5 and Gemini clients; it prints the active `RUN_OUTPUT_DIR` under `docs/outputs/`.
2. Keep `DRY_RUN = False` (set in the Stage Workers cell) when you want live responses. With `DRY_RUN = True`, `call_gpt5` makes no API call: it returns the cached `<stage>.<suffix>` file from `RUN_OUTPUT_DIR` if one exists, otherwise an empty string, and it never calls `write_stage_output`, so nothing new is saved. Because each setup run creates a new, empty folder, a dry run only finds cached files if `RUN_OUTPUT_DIR` points at an earlier run.
3. Re-run the setup cell whenever you need a fresh timestamped folder before running the stage cells or `run_full_pipeline`.
4. After the stages or the full pipeline finish, check the console messages (or `print(RUN_OUTPUT_DIR)`) to confirm where the files were saved.

In [171]:
def load_prompt_templates(prompt_dir: Path) -> Dict[str, str]:
    templates: Dict[str, str] = {}
    for path in sorted(prompt_dir.glob('*.txt')):
        templates[path.stem] = path.read_text().strip()
    return templates

def render_prompt(template: str, context: Dict[str, Any]) -> str:
    rendered = template
    for key, value in context.items():
        placeholder = '{{' + key + '}}'
        rendered = rendered.replace(placeholder, str(value))
    return rendered

def write_stage_output(stage: str, content: str, suffix: str = 'txt') -> Path:
    out_path = RUN_OUTPUT_DIR / f'{stage}.{suffix}'
    out_path.write_text(content.strip())
    return out_path

def ensure_directory(path: Path) -> Path:
    path.mkdir(parents=True, exist_ok=True)
    return path
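`render_prompt` performs a plain `str.replace` on each `{{key}}` placeholder, so a key missing from the context silently leaves the literal `{{key}}` in the prompt sent to GPT-5. A sketch of a post-render check; `render_placeholders` re-implements the same substitution so the snippet runs on its own, and `unrendered_placeholders` is a hypothetical helper that assumes placeholder names consist of word characters:

```python
import re
from typing import Any, Dict, List

# Assumption: placeholders look like {{name}} with name made of [A-Za-z0-9_].
PLACEHOLDER_RE = re.compile(r'\{\{(\w+)\}\}')


def render_placeholders(template: str, context: Dict[str, Any]) -> str:
    # Same {{key}} substitution idea as render_prompt: plain string replacement.
    for key, value in context.items():
        template = template.replace('{{' + key + '}}', str(value))
    return template


def unrendered_placeholders(text: str) -> List[str]:
    """List placeholder names still present after rendering (hypothetical check)."""
    return sorted(set(PLACEHOLDER_RE.findall(text)))


template = 'Write a story for a {{child_age}}-year-old about {{theme}}. Tone: {{tone}}.'
rendered = render_placeholders(template, {'child_age': 4, 'theme': 'sharing'})
print(rendered)  # Write a story for a 4-year-old about sharing. Tone: {{tone}}.
print(unrendered_placeholders(rendered))  # ['tone']
```

Raising an error when `unrendered_placeholders(prompt_text)` is non-empty would surface template/context mismatches before any tokens are spent.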


In [172]:
PROMPT_TEMPLATES = load_prompt_templates(PROMPT_DIR)
STYLE_GUIDE = json.loads(STYLE_GUIDE_PATH.read_text())
print(f'📄 Loaded {len(PROMPT_TEMPLATES)} prompt templates from {PROMPT_DIR}')
print('🎨 Style guide suffix:', STYLE_GUIDE.get('style_prompt_suffix', '')[:80] + '...')
print('🚫 Negative prompt:', STYLE_GUIDE.get('negative_prompt', '')[:80] + '...')


📄 Loaded 7 prompt templates from /Users/jacky/Tomo/Jacky/talkbookv1/prompts
🎨 Style guide suffix: Cheerful 2D digital illustration with soft pastel palette, gentle rounded linewo...
🚫 Negative prompt: No weapons, no frightening imagery, no mature themes, no harsh shadows, no cropp...


---

## 3. Stage Workers

Helpers to invoke GPT-5 for stages 1-5, parse their structured outputs, and bridge into Gemini image requests.
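With `stream=True`, `openai_client.responses.create` returns an iterator of typed server-sent events rather than a single response object; per the OpenAI Responses streaming documentation, output text arrives incrementally in `response.output_text.delta` events through their `delta` field. A minimal offline sketch of accumulating those deltas; `collect_stream_text` is hypothetical, and the `SimpleNamespace` objects only mimic the event shape:

```python
from types import SimpleNamespace
from typing import Any, Iterable


def collect_stream_text(events: Iterable[Any]) -> str:
    """Concatenate the text deltas from a Responses API event stream."""
    pieces = []
    for event in events:
        # Only text-delta events carry new output text; lifecycle events such as
        # response.created or response.completed are skipped.
        if getattr(event, 'type', '') == 'response.output_text.delta':
            pieces.append(event.delta)
    return ''.join(pieces)


# Offline stand-ins for SDK event objects (assumed shape: `.type` plus `.delta`).
fake_events = [
    SimpleNamespace(type='response.created'),
    SimpleNamespace(type='response.output_text.delta', delta='{"title": '),
    SimpleNamespace(type='response.output_text.delta', delta='"The Kind Fox"}'),
    SimpleNamespace(type='response.completed'),
]
print(collect_stream_text(fake_events))  # {"title": "The Kind Fox"}
```

Passing a real stream in place of `fake_events` should behave the same way, since only `type` and `delta` are read.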

In [173]:
import time

DRY_RUN = False  # Toggle to skip live API calls and rely on cached outputs
DEFAULT_GPT5_MODEL = 'gpt-5'
DEFAULT_REASONING_EFFORT = 'medium'
REASONING_MODEL_PREFIXES = ('gpt-5', 'o')

def call_gpt5(
    prompt: str,
    stage: str,
    *,
    model: str = DEFAULT_GPT5_MODEL,
    reasoning_effort: Optional[str] = DEFAULT_REASONING_EFFORT,
    instructions: Optional[str] = None,
    max_output_tokens: int = 15000,
    temperature: Optional[float] = None,
    output_suffix: str = 'txt',
    stream: bool = False,  # Default to False for more reliable operation
    max_retries: int = 3,
) -> str:
    """Invoke GPT-5 via the Responses API and persist the raw output.

    GPT-5 and other reasoning models do not support sampling controls such as
    `temperature`, so that parameter is only forwarded for non-reasoning models.
    """
    if DRY_RUN:
        cached_path = RUN_OUTPUT_DIR / f'{stage}.{output_suffix}'
        if cached_path.exists():
            print(f'[dry-run] Using cached response for {stage} -> {cached_path.name}')
            return cached_path.read_text()
        print(f'[dry-run] No cached response found for {stage}; returning empty string.')
        return ''

    input_payload = [
        {
            'role': 'user',
            'content': [
                {'type': 'input_text', 'text': prompt},
            ],
        }
    ]

    response_args = {
        'model': model,
        'input': input_payload,
        'max_output_tokens': max_output_tokens,
        'stream': stream,
    }
    if instructions:
        response_args['instructions'] = instructions
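    if reasoning_effort and model.startswith(REASONING_MODEL_PREFIXES):
        # Only reasoning models accept the `reasoning` parameter.
        response_args['reasoning'] = {'effort': reasoning_effort}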
    if temperature is not None:
        if not model.startswith(REASONING_MODEL_PREFIXES):
            response_args['temperature'] = temperature
        else:
            print(f"[info] temperature parameter ignored for reasoning model {model} per GPT-5 API guidance.")

    # Retry logic for connection issues
    for attempt in range(max_retries):
        try:
            if stream:
                print(f'🔄 Streaming response for {stage}... (attempt {attempt + 1})')
                full_text = ''
                # With stream=True the SDK yields typed events; output text arrives in
                # `response.output_text.delta` events via their `delta` field.
                for event in openai_client.responses.create(**response_args):
                    if getattr(event, 'type', '') == 'response.output_text.delta':
                        print(event.delta, end='', flush=True)
                        full_text += event.delta
                print()  # New line after streaming completes
                text = full_text
            else:
                print(f'🔄 Calling GPT-5 for {stage}... (attempt {attempt + 1})')
                response = openai_client.responses.create(**response_args)
                text = response.output_text or ''
            
            # If we get here, the request succeeded
            break
            
        except Exception as e:
            print(f'❌ Attempt {attempt + 1} failed: {e}')
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # Exponential backoff: 1s, 2s, ... (1s then 2s with the default max_retries=3)
                print(f'⏳ Retrying in {wait_time} seconds...')
                time.sleep(wait_time)
            else:
                print(f'💥 All {max_retries} attempts failed for {stage}')
                raise
    
    write_stage_output(stage, text, suffix=output_suffix)
    return text

def parse_json_response(raw_text: str, stage: str) -> Dict[str, Any]:
    raw = raw_text.strip()
    if not raw:
        raise ValueError(f'No JSON content returned for {stage}.')
    try:
        return json.loads(raw)
    except json.JSONDecodeError as err:
        raise ValueError(f'Failed to parse JSON for {stage}: {err}') from err

def parse_xml_response(raw_text: str, stage: str) -> ET.Element:
    raw = raw_text.strip()
    if not raw:
        raise ValueError(f'No XML content returned for {stage}.')
    try:
        return ET.fromstring(raw)
    except ET.ParseError as err:
        raise ValueError(f'Failed to parse XML for {stage}: {err}') from err

def stage_result(stage: str, prompt: str, raw: str, parsed: Any) -> Dict[str, Any]:
    return {
        'stage': stage,
        'prompt': prompt,
        'raw': raw,
        'parsed': parsed,
    }
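`parse_json_response` passes the raw text straight to `json.loads`, so a reply wrapped in a Markdown code fence (which chat models sometimes produce even when asked for bare JSON) ends in a `ValueError`. One hedged option is to strip a single surrounding fence first. A minimal sketch; `strip_code_fences` and `parse_json_lenient` are hypothetical helpers:

```python
import json
from typing import Any

FENCE = '`' * 3  # three backticks, built programmatically so this snippet stays fence-safe


def strip_code_fences(raw: str) -> str:
    """Remove one surrounding Markdown fence (with optional language tag), if present."""
    text = raw.strip()
    lines = text.splitlines()
    if len(lines) >= 2 and lines[0].startswith(FENCE) and lines[-1].strip() == FENCE:
        # Drop the opening fence line (which may carry a tag like "json") and the closing fence line.
        text = '\n'.join(lines[1:-1]).strip()
    return text


def parse_json_lenient(raw: str) -> Any:
    return json.loads(strip_code_fences(raw))


fenced_reply = FENCE + 'json\n{"nodes": 3}\n' + FENCE
print(parse_json_lenient(fenced_reply))    # {'nodes': 3}
print(parse_json_lenient('{"nodes": 3}'))  # {'nodes': 3}
```

Unfenced text passes through unchanged, so the same check works for well-behaved replies.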

In [174]:
def story_nodes_xml_to_records(root: ET.Element) -> List[Dict[str, Any]]:
    records: List[Dict[str, Any]] = []
    for node in root.findall('Node'):
        record: Dict[str, Any] = {
            'node_id': node.findtext('Id'),
            'type': node.findtext('Type') or 'linear',
            'next_node': node.findtext('NextNode'),
            'parent_choice_id': node.findtext('ParentChoiceId'),
            'continuity_notes': node.findtext('ContinuityNotes'),
            'content': node.findtext('NodeContent') or node.findtext('SceneText'),
            'retry_message': node.findtext('RetryMessage'),
        }
        choices_elem = node.find('Choices')
        if choices_elem is not None:
            record['choices'] = []
            for choice in choices_elem.findall('Choice'):
                record['choices'].append({
                    'id': choice.findtext('Id'),
                    'text': choice.findtext('ChoiceText'),
                    'is_correct': (choice.findtext('IsCorrect') or '').lower() == 'true',
                    'next_node': choice.findtext('NextNode'),
                })
        records.append(record)
    return records

def storyboard_nodes_xml_to_records(root: ET.Element) -> List[Dict[str, Any]]:
    records: List[Dict[str, Any]] = []
    for node in root.findall('Node'):
        scene_text = (node.findtext('SceneText') or '').strip()
        image_prompt = (node.findtext('ImagePrompt') or '').strip()
        character_ids: List[str] = []
        char_list_elem = node.find('CharacterList')
        if char_list_elem is not None:
            for character_id_elem in char_list_elem.findall('CharacterId'):
                character_id = (character_id_elem.text or '').strip()
                if character_id and character_id not in character_ids:
                    character_ids.append(character_id)
        record: Dict[str, Any] = {
            'node_id': node.findtext('Id'),
            'type': node.findtext('Type') or 'linear',
            'next_node': node.findtext('NextNode'),
            'parent_choice_id': node.findtext('ParentChoiceId'),
            'continuity_notes': node.findtext('ContinuityNotes'),
            'scene_text': scene_text,
            'image_prompt': image_prompt,
            'display_text': scene_text,
            'retry_message': (node.findtext('RetryMessage') or '').strip() or None,
            'character_ids': character_ids,
        }
        characters_data = [{'character_id': cid} for cid in character_ids]
        illustration_brief: Dict[str, Any] = {
            'prompt': image_prompt,
            'characters': characters_data,
        }
        record['illustration_brief'] = illustration_brief
        choices_elem = node.find('Choices')
        if choices_elem is not None:
            record['choices'] = []
            for choice in choices_elem.findall('Choice'):
                choice_entry: Dict[str, Any] = {
                    'id': choice.findtext('Id'),
                    'next_node': choice.findtext('NextNode'),
                    'text': choice.findtext('ChoiceText') or '',
                    'is_correct': (choice.findtext('IsCorrect') or '').lower() == 'true',
                }
                record['choices'].append(choice_entry)
        extra_fields: Dict[str, Any] = {}
        skip_tags = {'Id', 'Type', 'NextNode', 'ParentChoiceId', 'ContinuityNotes', 'SceneText', 'ImagePrompt', 'RetryMessage', 'Choices', 'CharacterList'}
        for child in list(node):
            if child.tag in skip_tags:
                continue
            extra_fields[child.tag] = (child.text or '').strip()
        if extra_fields:
            record['extra_fields'] = extra_fields
        records.append(record)
    return records

def extract_choice_node_set(root: ET.Element) -> Dict[str, Any]:
    question_nodes: List[Dict[str, Any]] = []
    outcome_nodes: List[Dict[str, Any]] = []
    for node in root.findall('Node'):
        node_type = (node.findtext('Type') or 'linear').lower()
        node_id = node.findtext('Id')
        next_node = node.findtext('NextNode')
        if node_type == 'choice_question':
            # Compare with None explicitly: Element truthiness reflects the child count
            # and is deprecated, so `node.find(...) or default` is fragile.
            choices_elem = node.find('Choices')
            choice_elems = choices_elem.findall('Choice') if choices_elem is not None else []
            question_nodes.append({
                'node_id': node_id,
                'question_text': node.findtext('NodeContent') or node.findtext('SceneText'),
                'retry_message': node.findtext('RetryMessage') or '',
                'choices': [
                    {
                        'id': choice.findtext('Id'),
                        'text': choice.findtext('ChoiceText'),
                        'is_correct': (choice.findtext('IsCorrect') or '').lower() == 'true',
                        'next_node': choice.findtext('NextNode'),
                    }
                    for choice in choice_elems
                ],
            })
        elif node_type == 'choice_outcome':
            outcome_nodes.append({
                'node_id': node_id,
                'parent_choice_id': node.findtext('ParentChoiceId'),
                'outcome_text': node.findtext('NodeContent') or node.findtext('SceneText'),
                'next_node': next_node,
            })
    return {
        'question_nodes': question_nodes,
        'outcome_nodes': outcome_nodes,
    }

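def to_pretty_json(data: Any) -> str:
    # Despite the name, this emits compact JSON (no extra whitespace), which keeps rendered prompts shorter.
    return json.dumps(data, separators=(',', ':'), ensure_ascii=False)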

def story_records_to_graph(records: List[Dict[str, Any]], story_id: Optional[str] = None) -> Dict[str, Any]:
    payload = {'nodes': records}
    if story_id:
        payload['story_id'] = story_id
    return payload
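The XML helpers above assume a flat list of `<Node>` elements linked by `NextNode`, plus a `NextNode` on each `<Choice>` for choice questions. A simple graph check in the spirit of the roadmap's validation helpers is to flag links that point at no known `Id`. A minimal sketch; `find_dangling_links`, the `<StoryNodes>` root tag, and treating blank, `END`, or `null` targets as terminal are all assumptions:

```python
import xml.etree.ElementTree as ET
from typing import List, Tuple

# Toy graph: n2 is a choice question whose second choice points at a missing node.
SAMPLE_XML = """
<StoryNodes>
  <Node><Id>n1</Id><Type>linear</Type><NextNode>n2</NextNode></Node>
  <Node>
    <Id>n2</Id><Type>choice_question</Type>
    <Choices>
      <Choice><Id>c1</Id><NextNode>n3</NextNode></Choice>
      <Choice><Id>c2</Id><NextNode>n9</NextNode></Choice>
    </Choices>
  </Node>
  <Node><Id>n3</Id><Type>linear</Type><NextNode>END</NextNode></Node>
</StoryNodes>
"""

# Values treated as "no next node" (an assumption mirroring the END/null convention).
TERMINAL_TARGETS = {'', 'END', 'NULL'}


def find_dangling_links(root: ET.Element) -> List[Tuple[str, str]]:
    """Return (source_id, target_id) pairs whose target matches no node Id."""
    nodes = root.findall('Node')
    node_ids = {(node.findtext('Id') or '').strip() for node in nodes}
    dangling: List[Tuple[str, str]] = []
    for node in nodes:
        source = (node.findtext('Id') or '').strip()
        # A node can link onward directly and/or through each of its choices.
        targets = [node.findtext('NextNode')]
        targets += [choice.findtext('NextNode') for choice in node.iter('Choice')]
        for target in targets:
            cleaned = (target or '').strip()
            if cleaned.upper() not in TERMINAL_TARGETS and cleaned not in node_ids:
                dangling.append((source, cleaned))
    return dangling


print(find_dangling_links(ET.fromstring(SAMPLE_XML.strip())))  # [('n2', 'n9')]
```

On live output, the root returned by `parse_xml_response(raw_text, 'stage4_choice_node_generator')` could be passed in the same way.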

In [None]:
def run_stage1_learning_framework(user_inputs: Dict[str, Any]) -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['stage1_educational_theme_integrator']
    context = {
        'child_age': user_inputs.get('child_age', ''),
        'theme': user_inputs.get('theme', ''),
        'parental_guidance': user_inputs.get('parental_guidance', ''),
        'character_preferences': user_inputs.get('character_preferences', ''),
        'plot_preferences': user_inputs.get('plot_preferences', ''),
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'stage1_learning_framework', output_suffix='json')
    parsed = parse_json_response(raw_text, 'stage1_learning_framework') if raw_text else {}
    return stage_result('stage1_learning_framework', prompt_text, raw_text, parsed)

def run_stage2_story(user_inputs: Dict[str, Any], learning_framework: Any) -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['stage2_story_generator']
    if isinstance(learning_framework, dict):
        framework_json = to_pretty_json(learning_framework)
    elif isinstance(learning_framework, str):
        framework_json = learning_framework
    else:
        raise ValueError('learning_framework must be dict or JSON string')
    context = {
        'learning_framework_json': framework_json,
        'child_age': user_inputs.get('child_age', ''),
        'theme': user_inputs.get('theme', ''),
        'parental_guidance': user_inputs.get('parental_guidance', ''),
        'character_preferences': user_inputs.get('character_preferences', ''),
        'plot_preferences': user_inputs.get('plot_preferences', ''),
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'stage2_story_generator', output_suffix='json')
    parsed = parse_json_response(raw_text, 'stage2_story_generator') if raw_text else {}
    return stage_result('stage2_story_generator', prompt_text, raw_text, parsed)


In [176]:
def run_stage3_nodes(story_draft: Any, node_count: int = 10) -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['stage3_story_orchestrator']
    if isinstance(story_draft, dict):
        story_json = to_pretty_json(story_draft)
    elif isinstance(story_draft, str):
        story_json = story_draft
    else:
        raise ValueError('story_draft must be dict or JSON string')
    context = {
        'story_draft_json': story_json,
        'node_count': node_count,
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'stage3_story_orchestrator', output_suffix='xml')
    parsed_records: List[Dict[str, Any]] = []
    if raw_text:
        root = parse_xml_response(raw_text, 'stage3_story_orchestrator')
        parsed_records = story_nodes_xml_to_records(root)
    return stage_result('stage3_story_orchestrator', prompt_text, raw_text, {'nodes': parsed_records})

def run_stage4_choices(story_node_graph_xml: str, learning_framework: Any, desired_choice_count: int = 4) -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['stage4_choice_node_generator']
    if isinstance(learning_framework, dict):
        framework_json = to_pretty_json(learning_framework)
    elif isinstance(learning_framework, str):
        framework_json = learning_framework
    else:
        raise ValueError('learning_framework must be dict or JSON string')
    context = {
        'story_node_graph_xml': story_node_graph_xml,
        'learning_framework_json': framework_json,
        'desired_choice_count': desired_choice_count,
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'stage4_choice_node_generator', output_suffix='xml')
    parsed: Dict[str, Any] = {'nodes': [], 'choice_node_set': {}}
    if raw_text:
        root = parse_xml_response(raw_text, 'stage4_choice_node_generator')
        parsed['nodes'] = story_nodes_xml_to_records(root)
        parsed['choice_node_set'] = extract_choice_node_set(root)
    return stage_result('stage4_choice_node_generator', prompt_text, raw_text, parsed)
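`run_stage4_choices` takes a `desired_choice_count`, and `extract_choice_node_set` exposes the resulting `question_nodes`. A quick structural check on those records can catch malformed Stage 4 output before Stage 5 consumes it. A minimal sketch; `check_choice_questions` is hypothetical, and the "exactly one correct choice" rule is an assumption about the Stage 4 prompt, not something the notebook enforces:

```python
from typing import Any, Dict, List


def check_choice_questions(question_nodes: List[Dict[str, Any]], desired_choice_count: int) -> List[str]:
    """Return readable problems found in choice_question records (hypothetical validator)."""
    problems: List[str] = []
    for question in question_nodes:
        node_id = question.get('node_id')
        choices = question.get('choices') or []
        if len(choices) != desired_choice_count:
            problems.append(f'{node_id}: expected {desired_choice_count} choices, got {len(choices)}')
        # Assumption: each question should have exactly one correct answer.
        correct = sum(1 for choice in choices if choice.get('is_correct'))
        if correct != 1:
            problems.append(f'{node_id}: expected exactly 1 correct choice, got {correct}')
        for choice in choices:
            if not choice.get('next_node'):
                problems.append(f"{node_id}/{choice.get('id')}: choice has no next_node")
    return problems


# Toy records shaped like extract_choice_node_set(...)['question_nodes'].
sample = [
    {'node_id': 'q1', 'choices': [
        {'id': 'a', 'is_correct': True, 'next_node': 'o1'},
        {'id': 'b', 'is_correct': False, 'next_node': 'o2'},
    ]},
    {'node_id': 'q2', 'choices': [
        {'id': 'a', 'is_correct': False, 'next_node': 'o3'},
        {'id': 'b', 'is_correct': False, 'next_node': None},
    ]},
]
for problem in check_choice_questions(sample, desired_choice_count=2):
    print(problem)
# q2: expected exactly 1 correct choice, got 0
# q2/b: choice has no next_node
```

With a hypothetical `stage4_result = run_stage4_choices(...)`, the input would be `stage4_result['parsed']['choice_node_set']['question_nodes']`.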


In [177]:
DEFAULT_IMAGE_MODEL = 'gemini-2.5-flash-image'


def run_reference_image_prompts(character_bible: Any, visual_style_defaults: str = '', style_guidance: str = '') -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['reference_image_generator']
    if isinstance(character_bible, (dict, list)):
        bible_json = to_pretty_json(character_bible)
    elif isinstance(character_bible, str):
        bible_json = character_bible
    else:
        raise ValueError('character_bible must be dict, list, or JSON string')
    context = {
        'character_bible_json': bible_json,
        'visual_style_defaults': visual_style_defaults,
        'style_guidance': style_guidance,
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'reference_image_generator', output_suffix='json')
    parsed = parse_json_response(raw_text, 'reference_image_generator') if raw_text else {}
    return stage_result('reference_image_generator', prompt_text, raw_text, parsed)


def run_stage5_storyboard(
    story_node_graph: Any,
    choice_node_set: Any,
    character_bible: Any,
    reference_images: Any,
    style_defaults: Any = None,
    *,
    story_node_graph_xml: Optional[str] = None,
) -> Dict[str, Any]:
    template = PROMPT_TEMPLATES['stage5_storyboarder']

    def to_json_payload(obj: Any) -> str:
        if isinstance(obj, str):
            return obj
        return to_pretty_json(obj)

    xml_payload = story_node_graph_xml
    if not xml_payload and isinstance(story_node_graph, str) and story_node_graph.strip().startswith('<'):
        xml_payload = story_node_graph

    context = {
        'story_node_graph_xml': xml_payload or '',
        'story_node_graph_json': to_json_payload(story_node_graph),
        'choice_node_set_json': to_json_payload(choice_node_set),
        'character_bible_json': to_json_payload(character_bible),
        'reference_images_json': to_json_payload(reference_images),
        'style_defaults': to_json_payload(style_defaults or STYLE_GUIDE),
    }
    prompt_text = render_prompt(template, context)
    raw_text = call_gpt5(prompt_text, 'stage5_storyboarder', output_suffix='xml')
    parsed: Dict[str, Any] = {}
    if raw_text:
        root = parse_xml_response(raw_text, 'stage5_storyboarder')
        nodes = storyboard_nodes_xml_to_records(root)
        parsed = {
            'storyboard_nodes': {
                'nodes': nodes,
                'raw_xml': raw_text,
            }
        }

    return stage_result('stage5_storyboarder', prompt_text, raw_text, parsed)

def save_inline_content_images(response: Any, output_dir: Path, stem: str) -> List[Path]:
    ensure_directory(output_dir)
    saved_paths: List[Path] = []
    if not response or not getattr(response, 'candidates', None):
        return saved_paths
    for candidate in response.candidates:
        content = getattr(candidate, 'content', None)
        if not content:
            continue
        # `content.parts` is optional in the SDK types and may be None.
        for part in content.parts or []:
            inline = getattr(part, 'inline_data', None)
            if inline and getattr(inline, 'data', None):
                data = inline.data
                if isinstance(data, str):
                    image_bytes = base64.b64decode(data)
                else:
                    image_bytes = data
                image = Image.open(BytesIO(image_bytes))
                mime_type = getattr(inline, 'mime_type', None) or 'image/png'
                extension = 'png' if 'png' in mime_type else 'jpg'
                base_path = output_dir / f"{stem}.{extension}"
                file_path = base_path
                if file_path.exists():
                    counter = 2
                    while True:
                        candidate_path = output_dir / f"{stem}_{counter}.{extension}"
                        if not candidate_path.exists():
                            file_path = candidate_path
                            break
                        counter += 1
                image.save(file_path)
                saved_paths.append(file_path)
    return saved_paths


def generate_reference_images(reference_prompts: Dict[str, Any], *, model: str = DEFAULT_IMAGE_MODEL, style_suffix: Optional[str] = None, negative_prompt: Optional[str] = None, limit: Optional[int] = None) -> List[Dict[str, Any]]:
    entries = (reference_prompts or {}).get('reference_prompts') or []
    if limit is not None:
        entries = entries[:limit]
    assets: List[Dict[str, Any]] = []
    output_dir = ensure_directory(RUN_OUTPUT_DIR / 'reference_images')
    for index, entry in enumerate(entries, start=1):
        prompt_text = entry.get('prompt_text') or entry.get('prompt') or entry.get('description')
        if not prompt_text:
            continue
        pieces = [prompt_text]
        if style_suffix:
            pieces.append(f"Style: {style_suffix}")
        if negative_prompt:
            pieces.append(f"Avoid: {negative_prompt}")
        composed_prompt = '\n'.join(pieces)
        if DRY_RUN:
            print(f"[dry-run] Skipping reference image generation for {entry.get('character_id')}")
            assets.append({'character_id': entry.get('character_id'), 'prompt': composed_prompt, 'image_paths': []})
            continue
        response = gemini_client.models.generate_content(model=model, contents=[composed_prompt])
        image_paths = save_inline_content_images(response, output_dir, entry.get('character_id') or f'reference_{index}')
        assets.append({
            'character_id': entry.get('character_id'),
            'prompt': composed_prompt,
            'image_paths': [str(p) for p in image_paths],
        })
    return assets


def build_reference_lookup(reference_assets: List[Dict[str, Any]]) -> Dict[str, str]:
    lookup: Dict[str, str] = {}
    for asset in reference_assets:
        paths = asset.get('image_paths') or []
        if asset.get('character_id') and paths:
            lookup[asset['character_id']] = paths[0]
    return lookup


def compose_story_prompt(node: Dict[str, Any], style_guide: Dict[str, Any]) -> str:
    parts: List[str] = []
    display_text = node.get('display_text') or node.get('scene_text')
    if display_text:
        parts.append(display_text)
    illustration = node.get('illustration_brief', {}) or {}
    setting = illustration.get('setting')
    if setting:
        parts.append(f"Setting details: {setting}")
    mood = illustration.get('mood_palette')
    if mood:
        parts.append(f"Mood palette: {mood}")
    safety = illustration.get('safety_notes')
    if safety:
        parts.append(f"Safety notes: {safety}")
    image_prompt = node.get('image_prompt') or illustration.get('prompt')
    if image_prompt:
        parts.append(f"Illustration focus: {image_prompt}")
    style_suffix = style_guide.get('style_prompt_suffix')
    if style_suffix:
        parts.append(f"Style: {style_suffix}")
    negative_prompt = style_guide.get('negative_prompt')
    if negative_prompt:
        parts.append(f"Please avoid: {negative_prompt}")
    return '\n'.join(parts)

def generate_story_images_with_references(storyboard: Dict[str, Any], reference_assets: List[Dict[str, Any]], *, model: str = DEFAULT_IMAGE_MODEL, style_guide: Optional[Dict[str, Any]] = None, limit: Optional[int] = None) -> List[Dict[str, Any]]:
    style = style_guide or STYLE_GUIDE
    nodes = storyboard.get('nodes', [])
    if limit is not None:
        nodes = nodes[:limit]
    output_dir = ensure_directory(RUN_OUTPUT_DIR / 'images')
    reference_lookup = build_reference_lookup(reference_assets)
    results: List[Dict[str, Any]] = []
    for node in nodes:
        node_id = node.get('node_id')
        illustration = node.get('illustration_brief', {})
        characters = illustration.get('characters', [])
        prompt_text = compose_story_prompt(node, style)
        used_references: List[Dict[str, Any]] = []
        if DRY_RUN:
            print(f"[dry-run] Skipping story image generation for node {node_id}")
            results.append({
                'node_id': node_id,
                'prompt': prompt_text,
                'image_paths': [],
                'references_used': used_references,
            })
            continue
        import mimetypes  # stdlib; guesses each reference image's MIME type from its extension
        parts: List[types.Part] = [types.Part.from_text(text=prompt_text)]
        for char in characters:
            char_id = char.get('character_id')
            ref_path = reference_lookup.get(char_id)
            if not ref_path:
                continue
            try:
                data = Path(ref_path).read_bytes()
            except FileNotFoundError:
                continue
            mime_type = mimetypes.guess_type(ref_path)[0] or 'image/png'
            parts.append(types.Part.from_bytes(data=data, mime_type=mime_type))
            used_references.append({'character_id': char_id, 'image_path': ref_path})
        contents = [types.Content(role='user', parts=parts)]
        response = gemini_client.models.generate_content(model=model, contents=contents)
        image_paths = save_inline_content_images(response, output_dir, node_id or 'node')
        results.append({
            'node_id': node_id,
            'prompt': prompt_text,
            'image_paths': [str(p) for p in image_paths],
            'references_used': used_references,
        })
    return results
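`save_inline_content_images` avoids overwriting earlier images by appending `_2`, `_3`, ... to the file stem. The same naming rule can be factored into a small helper; a sketch, where `next_free_path` is a hypothetical name:

```python
import tempfile
from pathlib import Path


def next_free_path(output_dir: Path, stem: str, extension: str) -> Path:
    """Return stem.ext, or stem_2.ext, stem_3.ext, ... for the first name not yet taken."""
    path = output_dir / f'{stem}.{extension}'
    counter = 2
    while path.exists():
        path = output_dir / f'{stem}_{counter}.{extension}'
        counter += 1
    return path


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp)
    for _ in range(3):
        target = next_free_path(out, 'node_01', 'png')
        target.write_bytes(b'')  # empty placeholder standing in for a saved image
        print(target.name)
# node_01.png
# node_01_2.png
# node_01_3.png
```

Like the original loop, this checks existence and writes in separate steps, which is fine for a single notebook process but not safe against concurrent writers.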


In [178]:

def embed_metadata_in_story_xml(
    story_metadata: Any,
    storyboard_xml: str,
    *,
    language: str = 'en-US',
    version: str = '2.0',
    read_time_minutes: int = 6,
    target_age_range: str = '3-5',
) -> str:
    if not storyboard_xml:
        return ''
    try:
        storyboard_root = ET.fromstring(storyboard_xml)
    except ET.ParseError as exc:
        raise ValueError('Invalid storyboard XML') from exc

    # Normalize NextNode values so missing, blank, or "END" targets become the literal 'null'.
    # Nodes without a NextNode element (such as choice questions) get one added.
    for node in storyboard_root.findall('Node'):
        next_elem = node.find('NextNode')
        if next_elem is None:
            next_elem = ET.SubElement(node, 'NextNode')
        if (next_elem.text or '').strip().upper() in ('', 'END'):
            next_elem.text = 'null'

    metadata_dict = story_metadata if isinstance(story_metadata, dict) else {}
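    # ElementTree can only serialize string attributes and text, so coerce values with str().
    story_id = str(metadata_dict.get('story_id') or 'unknown-story')
    title = str(metadata_dict.get('title') or '')
    theme = str(metadata_dict.get('theme') or '')
    moral = str(metadata_dict.get('moral') or '')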

    root = ET.Element(
        'InteractiveStorybook',
        {
            'story_id': story_id,
            'version': version,
            'language': language,
        },
    )

    metadata_elem = ET.SubElement(root, 'Metadata')
    ET.SubElement(metadata_elem, 'Title').text = title
    ET.SubElement(metadata_elem, 'Theme').text = theme
    ET.SubElement(metadata_elem, 'Moral').text = moral
    ET.SubElement(metadata_elem, 'ReadTime', {'minutes': str(read_time_minutes)})
    ET.SubElement(metadata_elem, 'TargetAgeRange').text = target_age_range

    root.append(storyboard_root)

    return ET.tostring(root, encoding='utf-8', xml_declaration=True).decode('utf-8')


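def save_story_with_metadata(
    story_metadata: Any,
    storyboard_xml: str,
    *,
    output_dir: Optional[Path] = None,
    filename: str = 'storyboard_with_metadata.xml',
) -> Tuple[str, Optional[Path]]:
    combined_xml = embed_metadata_in_story_xml(story_metadata, storyboard_xml)
    if not combined_xml:
        return '', None
    # Resolve the default at call time: a `RUN_OUTPUT_DIR` default argument would be frozen
    # when this cell runs and miss any fresh folder created by re-running the setup cell.
    output_path = ensure_directory(output_dir or RUN_OUTPUT_DIR) / filename
    # Write UTF-8 explicitly so the bytes match the XML declaration's encoding.
    output_path.write_text(combined_xml, encoding='utf-8')
    return combined_xml, output_path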

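`embed_metadata_in_story_xml` wraps the storyboard in an `<InteractiveStorybook>` root carrying `story_id`, `version`, and `language` attributes, a `<Metadata>` block, and the original storyboard root as a sibling. A reader-side sketch that summarizes such a document can double as a smoke test of the saved file. `read_storybook_metadata` is hypothetical, and the toy document below (including the `<StoryNodes>` wrapper tag) is made up for illustration:

```python
import xml.etree.ElementTree as ET
from typing import Any, Dict

# Toy document shaped like the output of embed_metadata_in_story_xml (values are illustrative).
SAMPLE = (
    '<InteractiveStorybook story_id="demo-001" version="2.0" language="en-US">'
    '<Metadata><Title>The Kind Fox</Title><Theme>Sharing</Theme>'
    '<Moral>Sharing makes friends</Moral><ReadTime minutes="6" />'
    '<TargetAgeRange>3-5</TargetAgeRange></Metadata>'
    '<StoryNodes>'
    '<Node><Id>n1</Id><NextNode>n2</NextNode></Node>'
    '<Node><Id>n2</Id><NextNode>null</NextNode></Node>'
    '</StoryNodes>'
    '</InteractiveStorybook>'
)


def read_storybook_metadata(xml_text: str) -> Dict[str, Any]:
    """Summarize a combined storybook document (hypothetical reader-side check)."""
    root = ET.fromstring(xml_text)
    meta = root.find('Metadata')
    if meta is None:
        raise ValueError('Missing <Metadata> element')
    read_time = meta.find('ReadTime')
    # The storyboard root is the sibling appended after <Metadata>; collect its nodes.
    nodes = [node for child in root if child.tag != 'Metadata' for node in child.findall('Node')]
    return {
        'story_id': root.get('story_id'),
        'title': meta.findtext('Title'),
        'read_time_minutes': int(read_time.get('minutes')) if read_time is not None else None,
        'node_count': len(nodes),
        'terminal_nodes': [n.findtext('Id') for n in nodes if n.findtext('NextNode') == 'null'],
    }


print(read_storybook_metadata(SAMPLE))
# {'story_id': 'demo-001', 'title': 'The Kind Fox', 'read_time_minutes': 6, 'node_count': 2, 'terminal_nodes': ['n2']}
```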

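The TTS cell that follows wraps Gemini's raw PCM audio in a WAV container using `TTS_SAMPLE_RATE = 24000`, `TTS_SAMPLE_WIDTH = 2` (16-bit samples), and `TTS_CHANNELS = 1`. With those parameters, clip length is simply byte count divided by bytes per second, which is handy for sanity-checking narration lengths. A self-contained sketch using silent PCM instead of a real TTS response; `pcm_duration_seconds` is a hypothetical helper:

```python
import io
import wave

SAMPLE_RATE = 24000  # mirrors TTS_SAMPLE_RATE
SAMPLE_WIDTH = 2     # bytes per sample (16-bit PCM), mirrors TTS_SAMPLE_WIDTH
CHANNELS = 1         # mono, mirrors TTS_CHANNELS


def pcm_duration_seconds(pcm: bytes, sample_rate: int = SAMPLE_RATE,
                         sample_width: int = SAMPLE_WIDTH, channels: int = CHANNELS) -> float:
    # Each second holds sample_rate frames of channels * sample_width bytes.
    return len(pcm) / (sample_rate * sample_width * channels)


silence = b'\x00' * (SAMPLE_RATE * SAMPLE_WIDTH * CHANNELS // 2)  # half a second of silence
print(pcm_duration_seconds(silence))  # 0.5

# Round-trip through an in-memory WAV file with the same header settings as write_pcm_to_wav.
buffer = io.BytesIO()
with wave.open(buffer, 'wb') as wav_file:
    wav_file.setnchannels(CHANNELS)
    wav_file.setsampwidth(SAMPLE_WIDTH)
    wav_file.setframerate(SAMPLE_RATE)
    wav_file.writeframes(silence)

buffer.seek(0)
with wave.open(buffer, 'rb') as wav_file:
    print(wav_file.getnframes() / wav_file.getframerate())  # 0.5
```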
In [179]:

DEFAULT_TTS_MODEL = 'gemini-2.5-flash-preview-tts'
DEFAULT_TTS_VOICE = 'Kore'
TTS_SAMPLE_RATE = 24000
TTS_SAMPLE_WIDTH = 2
TTS_CHANNELS = 1


def write_pcm_to_wav(pcm_bytes: bytes, path: Path, *, sample_rate: int = TTS_SAMPLE_RATE, channels: int = TTS_CHANNELS, sample_width: int = TTS_SAMPLE_WIDTH) -> Path:
    ensure_directory(path.parent)
    with wave.open(str(path), 'wb') as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
    return path


def extract_inline_audio_bytes(response: Any) -> bytes:
    if not getattr(response, 'candidates', None):
        return b''
    for candidate in response.candidates:
        content = getattr(candidate, 'content', None)
        if not content:
            continue
        for part in getattr(content, 'parts', []) or []:
            inline = getattr(part, 'inline_data', None)
            if inline and getattr(inline, 'data', None):
                data = inline.data
                if isinstance(data, str):
                    return base64.b64decode(data)
                return data
    return b''


def call_gemini_tts(text: str, *, model: str = DEFAULT_TTS_MODEL, voice_name: str = DEFAULT_TTS_VOICE) -> bytes:
    generation_config = types.GenerateContentConfig(
        response_modalities=['AUDIO'],
        speech_config=types.SpeechConfig(
            voice_config=types.VoiceConfig(
                prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name)
            )
        ),
    )
    response = gemini_client.models.generate_content(
        model=model,
        contents=[types.Part.from_text(text=text)],
        config=generation_config,
    )
    return extract_inline_audio_bytes(response)


def generate_storyboard_tts(
    storyboard: Dict[str, Any],
    *,
    model: str = DEFAULT_TTS_MODEL,
    voice_name: str = DEFAULT_TTS_VOICE,
    output_dir: Optional[Path] = None,
    limit: Optional[int] = None,
) -> List[Dict[str, Any]]:
    nodes = (storyboard or {}).get('nodes') if isinstance(storyboard, dict) else storyboard
    if not nodes:
        return []
    target_dir = ensure_directory(output_dir or (RUN_OUTPUT_DIR / 'audio'))
    results: List[Dict[str, Any]] = []
    for index, node in enumerate(nodes):
        if limit is not None and index >= limit:
            break
        node_id = node.get('node_id')
        if not node_id:
            continue
        scene_text = (node.get('scene_text') or node.get('display_text') or '').strip()
        retry_text = (node.get('retry_message') or '').strip()
        choices = node.get('choices') or []
        scene_audio_path = None
        retry_audio_path = None
        choice_audio_entries: List[Dict[str, Optional[str]]] = []

        if scene_text:
            filename = f"{node_id}.wav"
            if DRY_RUN:
                print(f"[dry-run] Skipping TTS for node {node_id}")
            else:
                try:
                    audio_bytes = call_gemini_tts(scene_text, model=model, voice_name=voice_name)
                    if audio_bytes:
                        scene_audio_path = write_pcm_to_wav(audio_bytes, target_dir / filename)
                except Exception as exc:
                    print(f"[warning] TTS failed for node {node_id}: {exc}")

        if retry_text:
            retry_filename = f"{node_id}_retry.wav"
            if DRY_RUN:
                print(f"[dry-run] Skipping retry TTS for node {node_id}")
            else:
                try:
                    retry_bytes = call_gemini_tts(retry_text, model=model, voice_name=voice_name)
                    if retry_bytes:
                        retry_audio_path = write_pcm_to_wav(retry_bytes, target_dir / retry_filename)
                except Exception as exc:
                    print(f"[warning] TTS retry failed for node {node_id}: {exc}")

        for choice in choices:
            choice_id = choice.get('id') or choice.get('choice_id')
            choice_text = (choice.get('text') or choice.get('ChoiceText') or '').strip()
            choice_audio_path = None
            if choice_id and choice_text:
                choice_filename = f"{node_id}_{choice_id}.wav"
                if DRY_RUN:
                    print(f"[dry-run] Skipping choice TTS for {node_id} option {choice_id}")
                else:
                    try:
                        choice_bytes = call_gemini_tts(choice_text, model=model, voice_name=voice_name)
                        if choice_bytes:
                            choice_audio_path = write_pcm_to_wav(choice_bytes, target_dir / choice_filename)
                    except Exception as exc:
                        print(f"[warning] TTS choice failed for {node_id} option {choice_id}: {exc}")
            choice_audio_entries.append(
                {
                    'choice_id': choice_id,
                    'choice_text': choice_text or None,
                    'audio': str(choice_audio_path) if choice_audio_path else None,
                }
            )

        results.append(
            {
                'node_id': node_id,
                'scene_text': scene_text or None,
                'scene_audio': str(scene_audio_path) if scene_audio_path else None,
                'retry_text': retry_text or None,
                'retry_audio': str(retry_audio_path) if retry_audio_path else None,
                'choices': choice_audio_entries if choice_audio_entries else None,
            }
        )
    return results
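Each node can yield up to three kinds of clip, named `{node_id}.wav`, `{node_id}_retry.wav`, and `{node_id}_{choice_id}.wav`. With `DRY_RUN` on, it can help to list exactly which files a storyboard would produce before spending API calls. This minimal sketch mirrors the naming and field fallbacks used in `generate_storyboard_tts` (`scene_text`/`display_text`, `retry_message`, `id`/`choice_id`, `text`/`ChoiceText`); `plan_tts_jobs` is a hypothetical helper that makes no network calls.

```python
from typing import Any, List, Tuple


def plan_tts_jobs(storyboard: Any) -> List[Tuple[str, str]]:
    """Hypothetical helper: (filename, text) pairs that generate_storyboard_tts would synthesize."""
    nodes = storyboard.get('nodes') if isinstance(storyboard, dict) else storyboard
    jobs: List[Tuple[str, str]] = []
    for node in nodes or []:
        node_id = node.get('node_id')
        if not node_id:
            continue  # nodes without an id are skipped, as in the notebook function
        scene_text = (node.get('scene_text') or node.get('display_text') or '').strip()
        if scene_text:
            jobs.append((f'{node_id}.wav', scene_text))
        retry_text = (node.get('retry_message') or '').strip()
        if retry_text:
            jobs.append((f'{node_id}_retry.wav', retry_text))
        for choice in node.get('choices') or []:
            choice_id = choice.get('id') or choice.get('choice_id')
            choice_text = (choice.get('text') or choice.get('ChoiceText') or '').strip()
            if choice_id and choice_text:
                jobs.append((f'{node_id}_{choice_id}.wav', choice_text))
    return jobs


example = {
    'nodes': [
        {'node_id': 'n1', 'scene_text': 'Scene one.', 'retry_message': 'Try again.',
         'choices': [{'id': 'a', 'text': 'Option A'}, {'choice_id': 'b', 'ChoiceText': 'Option B'}]},
        {'scene_text': 'No node_id, so this node is skipped.'},
    ]
}
for filename, text in plan_tts_jobs(example):
    print(filename, '->', text)
```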


---

## 4. Stage Test Harness

Define a sample user brief for ad-hoc testing. Modify values as needed before invoking individual stages.

In [180]:
sample_user_brief = {
    'child_age': 4,
    'theme': 'Be carefull when crossing the road',
    # Describe the characters you want to be in this story.
    'character_preferences': 'My child is Asian, his name is Jacky, he is a boy',
    'plot_preferences': 'take place in space',
    # Other requirements for this story
    'parental_guidance': '',
}
desired_choice_count = 3
target_node_count = 10
reference_image_limit = None
story_image_limit = None
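Because every stage consumes this brief, a quick shape check before Stage 1 can save a wasted API round-trip. This is a minimal sketch that assumes the five keys shown above are the ones the stage prompts read and that `child_age` should be a positive integer; `validate_user_brief` is hypothetical and is not called by any stage function.

```python
from typing import Any, Dict, List

# Assumption: these are the keys the stage prompts read from the brief.
EXPECTED_BRIEF_KEYS = ('child_age', 'theme', 'character_preferences', 'plot_preferences', 'parental_guidance')


def validate_user_brief(brief: Dict[str, Any]) -> List[str]:
    """Hypothetical helper: return readable problems; an empty list means the brief looks usable."""
    problems = [f'missing key: {key}' for key in EXPECTED_BRIEF_KEYS if key not in brief]
    age = brief.get('child_age')
    # bool is a subclass of int, so rule it out explicitly.
    if not isinstance(age, int) or isinstance(age, bool) or age <= 0:
        problems.append('child_age should be a positive integer')
    if not str(brief.get('theme') or '').strip():
        problems.append('theme should not be empty')
    return problems


brief = {
    'child_age': 4,
    'theme': 'Be careful when crossing the road',
    'character_preferences': 'My child is Asian, his name is Jacky, he is a boy',
    'plot_preferences': 'take place in space',
    'parental_guidance': '',
}
print(validate_user_brief(brief))
```

An empty `parental_guidance` passes because only the key's presence is checked; tighten the rules if a stage turns out to need that field filled in.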


### Stage 1 – Educational Theme Integrator

In [181]:
stage1_result = run_stage1_learning_framework(sample_user_brief)
print('Stage 1 raw output stored at:', (RUN_OUTPUT_DIR / 'stage1_learning_framework.json'))
if stage1_result['parsed']:
    display(JSON(stage1_result['parsed']))
else:
    print('No parsed JSON available.')



### Stage 2 – Story Generator

In [None]:
stage2_result = run_stage2_story(sample_user_brief, stage1_result['parsed'])
print('Stage 2 raw output stored at:', (RUN_OUTPUT_DIR / 'stage2_story_generator.json'))
if stage2_result['parsed']:
    display(JSON(stage2_result['parsed']))
else:
    print('No parsed JSON available.')


### Reference Image Prompt Generation

In [None]:
character_bible = stage2_result['parsed'].get('character_bible', []) if stage2_result['parsed'] else []
reference_prompts_result = run_reference_image_prompts(character_bible, STYLE_GUIDE.get('style_prompt_suffix', ''), '')
print('Reference prompt raw output stored at:', (RUN_OUTPUT_DIR / 'reference_image_generator.json'))
if reference_prompts_result['parsed']:
    display(JSON(reference_prompts_result['parsed']))
else:
    print('No reference prompt data available.')


In [None]:
reference_image_assets = generate_reference_images(
    reference_prompts_result['parsed'],
    model=DEFAULT_IMAGE_MODEL,
    style_suffix=STYLE_GUIDE.get('style_prompt_suffix'),
    negative_prompt=STYLE_GUIDE.get('negative_prompt'),
    limit=reference_image_limit,
)
reference_image_lookup = build_reference_lookup(reference_image_assets)
display(JSON({'reference_images': reference_image_assets}))


### Stage 3 – Story Orchestrator

In [None]:
stage3_result = run_stage3_nodes(stage2_result['parsed'], node_count=target_node_count)
print('Stage 3 raw output stored at:', (RUN_OUTPUT_DIR / 'stage3_story_orchestrator.xml'))
preview = (stage3_result['parsed'] or {}).get('nodes', [])[:3]
if preview:
    display(JSON({'preview_nodes': preview}))
else:
    print('No node previews available.')


### Stage 4 – Choice Node Generator

In [None]:
stage4_result = run_stage4_choices(stage3_result['raw'], stage1_result['parsed'], desired_choice_count)
print('Stage 4 raw output stored at:', (RUN_OUTPUT_DIR / 'stage4_choice_node_generator.xml'))
preview = (stage4_result['parsed'] or {}).get('choice_node_set', {}).get('question_nodes', [])
if preview:
    display(JSON({'question_nodes_preview': preview[:2]}))
else:
    print('No choice node preview available.')


### Stage 5 – Storyboarder

In [None]:
story_id = None
if stage2_result['parsed']:
    story_id = stage2_result['parsed'].get('story_metadata', {}).get('story_id')
stage4_parsed = stage4_result['parsed'] or {}
story_graph_payload = story_records_to_graph(stage4_parsed.get('nodes', []), story_id=story_id)
choice_node_set_payload = stage4_parsed.get('choice_node_set', {})
character_bible_payload = stage2_result['parsed'].get('character_bible', []) if stage2_result['parsed'] else []
reference_images_payload = reference_prompts_result['parsed'] if reference_prompts_result['parsed'] else {'reference_prompts': []}
stage5_result = run_stage5_storyboard(
    story_graph_payload,
    choice_node_set_payload,
    character_bible_payload,
    reference_images_payload,
    STYLE_GUIDE,
    story_node_graph_xml=stage4_result['raw'],
)
print('Stage 5 raw output stored at:', (RUN_OUTPUT_DIR / 'stage5_storyboarder.xml'))
story_metadata_payload = stage2_result['parsed'].get('story_metadata', {}) if stage2_result['parsed'] else {}
if stage5_result['raw']:
    combined_xml, combined_path = save_story_with_metadata(story_metadata_payload, stage5_result['raw'])
    if combined_path:
        print('Storyboard with metadata saved at:', combined_path)
    else:
        print('Storyboard metadata embedding skipped.')
else:
    print('Storyboard XML missing; skipping metadata embedding.')
if stage5_result['parsed']:
    display(JSON(stage5_result['parsed']))
else:
    print('No storyboard data available.')


### Text-to-Speech – Gemini 2.5 Flash TTS


In [None]:

tts_results = []
tts_output_dir = RUN_OUTPUT_DIR / 'audio'
if stage5_result['parsed']:
    storyboard_payload = stage5_result['parsed'].get('storyboard_nodes', {})
    if storyboard_payload:
        tts_results = generate_storyboard_tts(
            storyboard_payload,
            voice_name=DEFAULT_TTS_VOICE,
            output_dir=tts_output_dir,
        )
        print('TTS audio saved to:', tts_output_dir)
        if tts_results:
            display(JSON({'tts_audio': tts_results}))
    else:
        print('Storyboard payload missing; skipping TTS generation.')
else:
    print('No storyboard data available; skipping TTS generation.')


### Image Generation – Nano Banana / Gemini 2.5 Flash

In [None]:
storyboard_payload = stage5_result['parsed'].get('storyboard_nodes', {}) if stage5_result['parsed'] else {}
if storyboard_payload:
    story_image_results = generate_story_images_with_references(
        storyboard_payload,
        reference_image_assets,
        model=DEFAULT_IMAGE_MODEL,
        style_guide=STYLE_GUIDE,
        limit=story_image_limit,
    )
    display(JSON({'story_images': story_image_results}))
else:
    print('Storyboard payload missing; cannot generate images.')


---

## 5. Pipeline Runner

Execute the end-to-end flow in one call: Stages 1-5 plus reference image prompts and images, TTS audio, and story node images. Toggle `DRY_RUN` if you only want to render prompts without hitting the APIs.
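The `DRY_RUN` checks inside `generate_storyboard_tts` follow one pattern: print a skip notice and return a placeholder instead of calling the API. If more steps need the same behavior, the pattern can be factored into a single helper. This is only a sketch; `call_unless_dry_run` is hypothetical, and the notebook's own stage functions may handle `DRY_RUN` differently.

```python
from typing import Any, Callable, List, Optional


def call_unless_dry_run(label: str, fn: Callable[..., Any], *args: Any,
                        dry_run: bool, default: Optional[Any] = None, **kwargs: Any) -> Any:
    """Hypothetical helper: return default when dry_run is set, otherwise run fn(*args, **kwargs)."""
    if dry_run:
        print(f'[dry-run] Skipping {label}')
        return default
    return fn(*args, **kwargs)


calls: List[str] = []


def fake_api(text: str) -> bytes:
    # Stand-in for a real API call such as call_gemini_tts; records that it ran.
    calls.append(text)
    return text.encode('utf-8')


skipped = call_unless_dry_run('TTS for node n1', fake_api, 'hello', dry_run=True, default=b'')
ran = call_unless_dry_run('TTS for node n1', fake_api, 'hello', dry_run=False)
print(skipped, ran, calls)
```

Making `dry_run` keyword-only forces each call site to state its intent, which avoids a global flag silently leaking into helper functions.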

In [None]:

def run_full_pipeline(user_brief: Dict[str, Any], *, node_count: int = 15, choice_count: int = 4) -> Dict[str, Any]:
    stage1 = run_stage1_learning_framework(user_brief)
    stage2 = run_stage2_story(user_brief, stage1['parsed'])
    stage3 = run_stage3_nodes(stage2['parsed'], node_count=node_count)
    stage4 = run_stage4_choices(stage3['raw'], stage1['parsed'], choice_count)
    character_bible = stage2['parsed'].get('character_bible', []) if stage2['parsed'] else []
    reference_prompts = run_reference_image_prompts(
        character_bible,
        STYLE_GUIDE.get('style_prompt_suffix', ''),
        '',
    )
    reference_assets = generate_reference_images(
        reference_prompts['parsed'],
        model=DEFAULT_IMAGE_MODEL,
        style_suffix=STYLE_GUIDE.get('style_prompt_suffix'),
        negative_prompt=STYLE_GUIDE.get('negative_prompt'),
        limit=reference_image_limit,
    )
    story_id = stage2['parsed'].get('story_metadata', {}).get('story_id') if stage2['parsed'] else None
    stage4_parsed = stage4['parsed'] or {}
    story_graph_payload = story_records_to_graph(stage4_parsed.get('nodes', []), story_id=story_id)
    choice_node_set_payload = stage4_parsed.get('choice_node_set', {})
    reference_payload = reference_prompts['parsed'] if reference_prompts['parsed'] else {'reference_prompts': []}
    storyboard = run_stage5_storyboard(
        story_graph_payload,
        choice_node_set_payload,
        character_bible,
        reference_payload,
        STYLE_GUIDE,
        story_node_graph_xml=stage4['raw'],
    )
    storyboard_payload = storyboard['parsed'].get('storyboard_nodes', {}) if storyboard['parsed'] else {}
    tts_results = (
        generate_storyboard_tts(
            storyboard_payload,
            voice_name=DEFAULT_TTS_VOICE,
        )
        if storyboard_payload
        else []
    )
    story_metadata = stage2['parsed'].get('story_metadata', {}) if stage2['parsed'] else {}
    combined_storyboard_xml, combined_storyboard_path = '', None
    if storyboard['raw']:
        combined_storyboard_xml, combined_storyboard_path = save_story_with_metadata(
            story_metadata,
            storyboard['raw'],
            filename='storyboard_with_metadata_full_pipeline.xml',
        )
    story_images = generate_story_images_with_references(
        storyboard_payload,
        reference_assets,
        model=DEFAULT_IMAGE_MODEL,
        style_guide=STYLE_GUIDE,
        limit=story_image_limit,
    )
    return {
        'stage1': stage1,
        'stage2': stage2,
        'stage3': stage3,
        'stage4': stage4,
        'reference_prompts': reference_prompts,
        'reference_image_assets': reference_assets,
        'stage5': storyboard,
        'story_images': story_images,
        'storyboard_with_metadata': {
            'xml': combined_storyboard_xml,
            'path': str(combined_storyboard_path) if combined_storyboard_path else None,
        },
        'tts_audio': tts_results,
    }


In [None]:
full_pipeline_results = run_full_pipeline(sample_user_brief, node_count=target_node_count, choice_count=desired_choice_count)
storyboard_nodes_count = 0
if full_pipeline_results['stage5']['parsed']:
    storyboard_nodes_count = len(full_pipeline_results['stage5']['parsed'].get('storyboard_nodes', {}).get('nodes', []))
reference_image_count = len(full_pipeline_results.get('reference_image_assets', []))
story_image_count = len(full_pipeline_results.get('story_images', []))
tts_audio_count = len(full_pipeline_results.get('tts_audio', []))
print(
    'Pipeline finished. '
    f'Storyboard nodes: {storyboard_nodes_count}, '
    f'reference images: {reference_image_count}, '
    f'story node images: {story_image_count}, '
    f'tts clips: {tts_audio_count}'
)



---

## 6. Validation & Diagnostics

Lightweight checks on node counts, choice distribution, and storyboard completeness.

In [None]:
from collections import Counter


def summarize_choice_nodes(choice_node_set: Dict[str, Any]) -> Dict[str, int]:
    question_nodes = choice_node_set.get('question_nodes', [])
    outcome_nodes = choice_node_set.get('outcome_nodes', [])
    return {
        'question_count': len(question_nodes),
        'outcome_count': len(outcome_nodes),
    }


def validate_story_graph(nodes: List[Dict[str, Any]], expected_count: int) -> Dict[str, Any]:
    id_counter = Counter(node.get('node_id') for node in nodes)
    return {
        'total_nodes': len(nodes),
        'duplicate_ids': [node_id for node_id, count in id_counter.items() if count > 1],
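        'expected_nodes': expected_count,
        # True when the stage returned a different number of nodes than requested.
        'count_mismatch': len(nodes) != expected_count,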
    }

validation_summary = {
    'stage3_nodes': validate_story_graph((stage3_result['parsed'] or {}).get('nodes', []), target_node_count),
    'stage4_choices': summarize_choice_nodes((stage4_result['parsed'] or {}).get('choice_node_set', {})),
    'storyboard_nodes': len(stage5_result['parsed'].get('storyboard_nodes', {}).get('nodes', [])) if stage5_result['parsed'] else 0,
    'reference_images_generated': len(reference_image_assets),
    'story_images_generated': len(story_image_results) if 'story_image_results' in globals() else 0,
    'tts_audio_generated': len(tts_results) if 'tts_results' in globals() else 0,
}

display(JSON(validation_summary))
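Storyboard completeness can also be checked on the audio side. `generate_storyboard_tts` returns one entry per node with `scene_text`/`scene_audio`, `retry_text`/`retry_audio`, and a `choices` list of `choice_text`/`audio` (or `None` when a node has no choices). An entry with text but no audio path means synthesis was skipped by `DRY_RUN` or failed. A minimal sketch over that structure; `find_missing_audio` is a hypothetical helper.

```python
from typing import Any, Dict, List


def find_missing_audio(tts_results: List[Dict[str, Any]]) -> List[str]:
    """Hypothetical helper: list clips that had text to speak but no saved audio file."""
    missing: List[str] = []
    for entry in tts_results or []:
        node_id = entry.get('node_id')
        if entry.get('scene_text') and not entry.get('scene_audio'):
            missing.append(f'{node_id}: scene')
        if entry.get('retry_text') and not entry.get('retry_audio'):
            missing.append(f'{node_id}: retry')
        # 'choices' is None when the node had no choices.
        for choice in entry.get('choices') or []:
            if choice.get('choice_text') and not choice.get('audio'):
                missing.append(f"{node_id}: choice {choice.get('choice_id')}")
    return missing


sample_results = [
    {'node_id': 'n1', 'scene_text': 'Scene one.', 'scene_audio': 'audio/n1.wav',
     'retry_text': None, 'retry_audio': None,
     'choices': [{'choice_id': 'a', 'choice_text': 'Option A', 'audio': None}]},
    {'node_id': 'n2', 'scene_text': 'Scene two.', 'scene_audio': None,
     'retry_text': 'Try again.', 'retry_audio': None, 'choices': None},
]
print(find_missing_audio(sample_results))
```

In the notebook this would be called as `find_missing_audio(tts_results)` after the TTS cell has run.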


---

## 7. Troubleshooting & Next Steps
- Set `DRY_RUN = True` to rehearse prompts without incurring API calls.
- Adjust `desired_choice_count` and `target_node_count` to stress-test pacing variations.
- Replace `reference_images_payload` with actual Nano Banana reference image metadata once generated.
- Integrate Stage 6 guardrail prompts when ready and extend the pipeline runner accordingly.
- Consider caching stage outputs to disk (beyond the timestamped folder) for regression comparisons.
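For the caching idea in the last bullet, one option is to key each stage's output on a hash of the stage name plus its inputs, so identical inputs reuse the stored result. This is a minimal standard-library sketch that assumes the stage result is JSON-serializable; `cached_stage`, the cache directory, and the `{stage}_{hash}.json` file layout are all hypothetical choices.

```python
import hashlib
import json
import tempfile
from pathlib import Path
from typing import Any, Callable, Dict, List


def cached_stage(cache_dir: Path, stage_name: str, inputs: Dict[str, Any],
                 run: Callable[[], Any]) -> Any:
    """Hypothetical helper: return a cached stage result, or run the stage and store its output."""
    # sort_keys makes the hash independent of dict insertion order.
    payload = json.dumps({'stage': stage_name, 'inputs': inputs}, sort_keys=True, default=str)
    digest = hashlib.sha256(payload.encode('utf-8')).hexdigest()[:16]
    cache_path = cache_dir / f'{stage_name}_{digest}.json'
    if cache_path.exists():
        return json.loads(cache_path.read_text(encoding='utf-8'))
    result = run()
    cache_dir.mkdir(parents=True, exist_ok=True)
    cache_path.write_text(json.dumps(result, indent=2), encoding='utf-8')
    return result


stage_calls: List[int] = []


def fake_stage() -> Dict[str, Any]:
    # Stand-in for a real stage call such as run_stage1_learning_framework(brief).
    stage_calls.append(1)
    return {'raw': '{}', 'parsed': {'theme': 'road safety'}}


with tempfile.TemporaryDirectory() as tmp:
    cache = Path(tmp) / 'stage_cache'
    first = cached_stage(cache, 'stage1', {'child_age': 4}, fake_stage)
    second = cached_stage(cache, 'stage1', {'child_age': 4}, fake_stage)

print(first == second, len(stage_calls))
```

Including the stage name in the hashed payload keeps two stages with identical inputs from colliding; for regression comparisons, pointing `cache_dir` at a fixed folder outside the timestamped run directory lets later runs diff against the stored outputs.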