Installs

In [None]:
!pip -q install -U torch==2.10.0 torchaudio==2.10.0

!pip -q install -U langgraph langchain-core requests beautifulsoup4 lxml

!pip -q install -U transformers==4.41.2 accelerate

!pip -q install -U diffusers safetensors pydub soundfile

   ━━━━━━━━╸━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 205.4/915.7 MB 19.9 MB/s eta 0:00:36
ERROR: Operation cancelled by user
^C
^C
^C
^C


Imports

In [None]:
import os, re, json, random, time
import html
from typing import TypedDict, List, Dict, Any, Optional

import requests

# Output root; override with the WIKI_CASE_OUT_DIR env var when not running on Colab.
OUT_DIR = os.environ.get("WIKI_CASE_OUT_DIR", "/content/wiki_case_story_output")
os.makedirs(OUT_DIR, exist_ok=True)

# Reduce HF “helpful” prompts; still no token required for public models
os.environ["HF_HUB_DISABLE_IMPLICIT_TOKEN"] = "1"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

# One shared HTTP session so repeated Wikipedia API calls reuse the connection.
SESSION = requests.Session()
# Wikimedia's User-Agent policy asks API clients to identify the tool and give contact info.
# Browser-spoofing or generic agents may be throttled or blocked.
# TODO: replace the contact placeholder with a real URL or e-mail address.
HEADERS = {
    "User-Agent": "WikiCaseStoryNotebook/1.0 (educational notebook; contact: you@example.com) python-requests"
}

def clean_text(t: str) -> str:
    """Collapse every whitespace run to one space and trim the ends.

    Falsy input (None or "") yields "".
    """
    if not t:
        return ""
    collapsed = re.sub(r"\s+", " ", t)
    return collapsed.strip()

def safe_filename(name: str) -> str:
    """Turn an arbitrary title into a filesystem-friendly slug.

    Keeps ASCII letters, digits, "_", "-" and spaces, converts spaces to "_" and
    caps the slug at 80 characters. Returns "case" if nothing usable remains.
    """
    allowed_only = re.sub(r"[^a-zA-Z0-9_\- ]+", "", name or "")
    slug = allowed_only.strip().replace(" ", "_")
    if not slug:
        return "case"
    return slug[:80]

def extract_json(text: str) -> str:
    """
    Extract the FIRST valid JSON object found in a model output.
    This avoids 'Extra data' when the model prints multiple objects or commentary.

    Every "{" is tried as a candidate start with ``json.JSONDecoder.raw_decode``.
    Candidates that are not valid JSON (for example "{placeholder}" in prose, or an
    object truncated mid-output) are skipped rather than aborting the search.

    Returns:
        The exact substring of ``text`` holding the first decodable JSON object
        (the whole text when it is itself a single JSON object).

    Raises:
        ValueError: if ``text`` is empty, contains no "{", or contains no decodable object.
    """
    if not text:
        raise ValueError("Empty text")

    # Fast path: the whole output is already one JSON object.
    try:
        if isinstance(json.loads(text), dict):
            return text
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        pass

    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object start found")

    decoder = json.JSONDecoder()
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores trailing text.
            _, end = decoder.raw_decode(text, start)
        except ValueError:
            start = text.find("{", start + 1)
            continue
        return text[start:end]

    raise ValueError("No complete JSON object found")


def word_count(s: str) -> int:
    """Number of whitespace-separated tokens in ``s``; 0 for None or ""."""
    return len(s.split()) if s else 0


Model

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Smaller fallback if the GPU runs out of memory: "Qwen/Qwen2.5-0.5B-Instruct"
LLM_NAME = "Qwen/Qwen2.5-1.5B-Instruct"

# Half precision when a GPU is available; CPU inference uses float32.
_LLM_DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

tokenizer = AutoTokenizer.from_pretrained(LLM_NAME, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    LLM_NAME,
    torch_dtype=_LLM_DTYPE,
    device_map="auto",  # let accelerate decide weight placement
)

def llm(prompt: str, max_new_tokens: int = 800, temperature: float = 0.6) -> str:
    """Generate a completion for ``prompt`` with the local instruct model.

    The prompt is sent as a user turn through the tokenizer's chat template (when the
    tokenizer defines one), which is the input format instruct-tuned models expect.
    Only the newly generated tokens are decoded.

    Args:
        prompt: user message text.
        max_new_tokens: generation budget, counting new tokens only.
        temperature: sampling temperature. ``<= 0`` selects greedy decoding, because
            transformers rejects ``temperature=0`` when ``do_sample=True``.

    Returns:
        The decoded completion with surrounding whitespace stripped.
    """
    if getattr(tokenizer, "chat_template", None):
        text = tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}],
            tokenize=False,
            add_generation_prompt=True,
        )
    else:
        text = prompt
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    gen_kwargs = {"max_new_tokens": max_new_tokens, "pad_token_id": pad_id}
    if temperature and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.9)
    else:
        gen_kwargs["do_sample"] = False

    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)

    # For decoder-only models generate() returns prompt + completion. Slice the prompt off.
    # The prompts embed JSON schema examples that extract_json would otherwise match
    # instead of the model's actual answer.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

print("Local LLM ready:", LLM_NAME)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


✅ Local LLM ready: Qwen/Qwen2.5-1.5B-Instruct


Image + TTS

In [None]:
import soundfile as sf
from pydub import AudioSegment
from transformers import pipeline
from diffusers import StableDiffusionPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"

# Fast few-step image model, loaded with StableDiffusionPipeline.
# NOTE: "stabilityai/sdxl-turbo" is an SDXL model and needs StableDiffusionXLPipeline
# (or diffusers.AutoPipelineForText2Image). Changing only IMG_MODEL will not load it.
IMG_MODEL = "stabilityai/sd-turbo"

img_pipe = StableDiffusionPipeline.from_pretrained(
    IMG_MODEL,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)
# generate_image runs once per story section; without this, each call prints its own progress bar.
img_pipe.set_progress_bar_config(disable=True)

def generate_image(prompt: str, out_path: str, num_inference_steps: int = 4, seed: Optional[int] = None):
    """Render ``prompt`` with ``img_pipe`` and save the first image to ``out_path``.

    Args:
        prompt: text prompt for the image.
        out_path: destination file; the parent directory must already exist.
        num_inference_steps: denoising steps (the default of 4 keeps the original behavior).
        seed: optional seed for reproducible images. A CPU ``torch.Generator`` is used so
            the same seed gives the same noise regardless of the pipeline's device.

    Returns:
        ``out_path``.
    """
    generator = torch.Generator(device="cpu").manual_seed(seed) if seed is not None else None
    image = img_pipe(
        prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=0.0,  # turbo models are used without classifier-free guidance (per model card)
        generator=generator,
    ).images[0]
    image.save(out_path)
    return out_path

# transformers pipelines take a device index: 0 = first GPU, -1 = CPU.
_TTS_DEVICE = 0 if device == "cuda" else -1
tts = pipeline(
    task="text-to-speech",
    model="facebook/mms-tts-eng",
    device=_TTS_DEVICE,
)

def tts_to_wav(text: str, wav_path: str):
    """Synthesize ``text`` with the ``tts`` pipeline and write it to ``wav_path`` as WAV.

    Returns:
        ``wav_path``.
    """
    out = tts(text)
    audio = out["audio"]
    # The text-to-speech pipeline may return channel-first 2-D audio; for VITS models such
    # as mms-tts this is typically shape (1, n_samples). soundfile reads 2-D data as
    # (frames, channels), so a (1, N) array would become one frame with N channels.
    # Transpose channel-first data to frames-first; 1-D mono audio passes through unchanged.
    if getattr(audio, "ndim", 1) == 2 and audio.shape[0] < audio.shape[1]:
        audio = audio.T
    sf.write(wav_path, audio, out["sampling_rate"])
    return wav_path

def concat_wavs(wav_paths: List[str], out_path: str):
    """Join WAV files end to end, with 200 ms of silence before, between and after them.

    The combined audio is exported to ``out_path`` as WAV, and ``out_path`` is returned.
    """
    gap = AudioSegment.silent(duration=200)
    combined = gap
    for path in wav_paths:
        clip_with_gap = AudioSegment.from_wav(path) + gap
        combined = combined + clip_with_gap
    combined.export(out_path, format="wav")
    return out_path

print("Image + TTS ready.")


Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Some weights of the model checkpoint at facebook/mms-tts-eng were not used when initializing VitsModel: ['flow.flows.0.wavenet.in_layers.0.weight_g', 'flow.flows.0.wavenet.in_layers.0.weight_v', 'flow.flows.0.wavenet.in_layers.1.weight_g', 'flow.flows.0.wavenet.in_layers.1.weight_v', 'flow.flows.0.wavenet.in_layers.2.weight_g', 'flow.flows.0.wavenet.in_layers.2.weight_v', 'flow.flows.0.wavenet.in_layers.3.weight_g', 'flow.flows.0.wavenet.in_layers.3.weight_v', 'flow.flows.0.wavenet.res_skip_layers.0.weight_g', 'flow.flows.0.wavenet.res_skip_layers.0.weight_v', 'flow.flows.0.wavenet.res_skip_layers.1.weight_g', 'flow.flows.0.wavenet.res_skip_layers.1.weight_v', 'flow.flows.0.wavenet.res_skip_layers.2.weight_g', 'flow.flows.0.wavenet.res_skip_layers.2.weight_v', 'flow.flows.0.wavenet.res_skip_layers.3.weight_g', 'flow.flows.0.wavenet.res_skip_layers.3.weight_v', 'flow.flows.1.wavenet.in_layers.0.weight_g', 'flow.flows.1.wavenet.in_layers.0.weight_v', 'flow.flows.1.wavenet.in_layers.1.wei

✅ Image + TTS ready.


Wikipedia API Tool

In [None]:
# MediaWiki Action API endpoint for English Wikipedia (used by wiki_search and wiki_extract).
WIKI_API = "https://en.wikipedia.org/w/api.php"

# Seed phrases: node_make_query picks one at random to steer the LLM's query, and
# falls back to it verbatim when the LLM output is unusable.
CASE_QUERY_SEEDS = [
    "notorious criminal case United States",
    "famous criminal case United States",
    "high-profile kidnapping case",
    "bank robbery manhunt case",
    "unsolved mystery case United States",
    "organized crime case United States",
    "serial killer investigation case",
    "famous trial case United States",
]

# wiki_search drops hits whose title starts with any of these prefixes
# (a tuple, because str.startswith accepts a tuple of prefixes) ...
BAD_TITLE_PREFIXES = ("List of", "Category:", "Template:", "Help:", "Portal:")
# ... or contains any of these substrings.
BAD_TITLE_CONTAINS = ("(disambiguation)",)

def wiki_search(query: str, limit: int = 10) -> List[Dict[str, Any]]:
    """Full-text search on Wikipedia via the MediaWiki ``list=search`` API.

    Args:
        query: search string.
        limit: maximum number of hits requested from the API (filtering may return fewer).

    Returns:
        One dict per kept hit with "title", "pageid", "snippet" (plain text) and
        "timestamp". Titles matching BAD_TITLE_PREFIXES / BAD_TITLE_CONTAINS are skipped.

    Raises:
        requests.HTTPError: on a non-2xx response.
    """
    params = {
        "action": "query",
        "list": "search",
        "srsearch": query,
        "srlimit": str(limit),
        "format": "json",
        "utf8": "1",
    }
    resp = SESSION.get(WIKI_API, params=params, headers=HEADERS, timeout=30)
    resp.raise_for_status()
    hits = resp.json().get("query", {}).get("search", [])
    results = []
    for item in hits:
        title = item.get("title", "")
        if not title or title.startswith(BAD_TITLE_PREFIXES) or any(bad in title for bad in BAD_TITLE_CONTAINS):
            continue
        # Snippets are HTML: drop the highlight tags, then decode entities such as &quot; / &amp;.
        snippet = html.unescape(re.sub(r"<.*?>", "", item.get("snippet", "")))
        results.append({
            "title": title,
            "pageid": item.get("pageid"),
            "snippet": clean_text(snippet),
            "timestamp": item.get("timestamp"),
        })
    return results

def wiki_extract(pageid: int) -> Dict[str, Any]:
    """Fetch the plain-text extract and canonical URL of one Wikipedia page.

    Returns a dict with "title", "fullurl" and "extract". Each falls back to ""
    when the API response omits it.
    """
    query_params = {
        "action": "query",
        "pageids": str(pageid),
        "prop": "extracts|info",  # TextExtracts body + page info
        "explaintext": "1",  # plain text instead of HTML
        "exsectionformat": "plain",
        "inprop": "url",  # include "fullurl" in the info block
        "format": "json",
        "utf8": "1",
    }
    resp = SESSION.get(WIKI_API, params=query_params, headers=HEADERS, timeout=30)
    resp.raise_for_status()
    pages_by_id = resp.json().get("query", {}).get("pages", {})
    page = pages_by_id.get(str(pageid), {})

    title = page.get("title", "")
    fullurl = page.get("fullurl", "")
    extract = page.get("extract", "") or ""
    return {"title": title, "fullurl": fullurl, "extract": extract}

def trim_wiki_text(text: str, max_chars: int = 9000) -> str:
    """Drop trailing reference-style sections, normalize whitespace and cap the length.

    The text is cut at the EARLIEST line consisting solely of one of the markers below.
    Such lines are presumably the section headings of a plain-format extract.
    Whitespace is then collapsed and the result truncated to ``max_chars``.
    None is treated as "".
    """
    # chop off references-ish sections to keep prompt clean
    cut_markers = ["References", "External links", "See also", "Further reading", "Bibliography", "Notes"]
    t = text or ""
    # Locate every marker in the untrimmed text and cut once at the minimum. Cutting
    # sequentially could remove the newline that terminates an adjacent earlier heading
    # (e.g. "\nNotes\nReferences\n"), leaving that heading dangling in the prompt.
    hits = [idx for idx in (t.find("\n" + m + "\n") for m in cut_markers) if idx != -1]
    if hits:
        t = t[: min(hits)]
    return clean_text(t)[:max_chars]


LangGraph Nodes

In [None]:
from langgraph.graph import StateGraph, END

class StoryState(TypedDict, total=False):
    """Shared LangGraph state passed between the pipeline nodes.

    ``total=False`` makes every key optional, so the graph can start from a partial
    dict (it is invoked with only ``{"attempts": 0}``).
    """
    genre: str  # NOTE(review): not read or written by any node in this notebook
    query: str  # Wikipedia search query chosen by node_make_query
    candidates: List[Dict[str, Any]]  # search hits (title/pageid/snippet/timestamp), at most 10
    chosen: Dict[str, Any]  # the candidate picked by node_choose_case
    source_title: str  # page title returned by wiki_extract
    source_url: str  # page URL ("fullurl") returned by wiki_extract
    source_text: str  # trimmed plain-text extract fed to the story prompt
    story_text: str  # generated narrative
    sections_raw: str  # raw LLM output from the split / fix nodes
    sections: List[Dict[str, str]]  # validated sections: section_title, section_text, image_prompt
    attempts: int  # number of JSON-fix passes performed so far
    json_error: str  # last validation error; "" when the sections JSON is valid
    out_dir: str  # folder where node_generate_assets wrote its files

def node_make_query(state: StoryState) -> StoryState:
    """Ask the LLM for one Wikipedia search query, seeded by a random CASE_QUERY_SEEDS entry.

    Falls back to the seed phrase when the LLM output cannot be parsed, is not a JSON
    object, or yields a query with no word characters (e.g. the "..." placeholder
    copied from the prompt's JSON example).
    """
    # Make it “agentic”: let LLM pick a strong query in the "notorious cases" space
    seed = random.choice(CASE_QUERY_SEEDS)
    prompt = f"""
Generate ONE Wikipedia search query to find a compelling "notorious criminal case" or "famous case" article.
Goal: documentary-style narrative potential (investigation, mystery, twist, trial, manhunt).
Return ONLY JSON: {{"query":"..."}}

Seed idea: {seed}
"""
    out = llm(prompt, max_new_tokens=120, temperature=0.3)
    try:
        obj = json.loads(extract_json(out))
        query = clean_text(obj.get("query", "")) if isinstance(obj, dict) else ""
    except Exception:
        query = ""
    # A query with no letters/digits (e.g. "...") cannot find a useful article.
    if not re.search(r"\w", query):
        query = seed
    return {**state, "query": query}

def node_search_wiki(state: StoryState) -> StoryState:
    """Search Wikipedia for ``state["query"]`` and store up to 10 candidate articles.

    Prefers "case-like" hits (keyword match in title or snippet) when at least 5 exist.
    If the query yields nothing, it retries with a broad fallback query.

    Raises:
        RuntimeError: if both searches return no usable candidates.
    """
    query = state["query"]
    cands = wiki_search(query, limit=12)

    # Lightweight quality filter: prefer likely “case-like” pages
    keywords = ["case", "murder", "kidnapping", "trial", "investigation", "robbery", "manhunt", "crime", "shooting"]
    filtered = [
        c for c in cands
        if any(k in (c["title"] + " " + c.get("snippet", "")).lower() for k in keywords)
    ]
    # Only narrow the pool when enough candidates survive the filter.
    if len(filtered) >= 5:
        cands = filtered

    if not cands:
        # fallback: broaden query a bit
        cands = wiki_search("notorious criminal case", limit=12)
    if not cands:
        raise RuntimeError(f"Wikipedia search returned no candidates for {query!r} or the fallback query.")
    return {**state, "candidates": cands[:10]}

def node_choose_case(state: StoryState) -> StoryState:
    """Let the LLM pick the most story-worthy article among the top 8 candidates.

    Falls back to the first candidate when the model's answer cannot be parsed (even
    after one repair pass) or matches no candidate by pageid or title.

    Raises:
        RuntimeError: if there are no candidates to choose from.
    """
    cands = state.get("candidates", [])[:8]
    if not cands:
        raise RuntimeError("No Wikipedia candidates to choose from; rerun to try another query.")
    prompt = f"""
Pick ONE Wikipedia article that will make the best true-crime / notorious-case narrative.
Prefer: clear timeline, investigation/manhunt, suspense, consequences, or unresolved mystery.
Avoid: purely biographical pages unless the case itself is central.

Return ONLY JSON: {{"title":"...","pageid":123,"reason":"..."}}

Candidates:
{json.dumps(cands, indent=2)}
"""
    out = llm(prompt, max_new_tokens=240, temperature=0.2)

    obj: Any = None
    try:
        obj = json.loads(extract_json(out))
    except Exception:
        # quick repair pass (temperature=0.0 asks for deterministic output); if the
        # repair call or its parse fails too, fall through to the top candidate below
        try:
            repair = llm(
                "Convert the following into ONLY valid JSON with keys "
                '"title", "pageid", "reason". No extra text.\n\n'
                f"{out}",
                max_new_tokens=200,
                temperature=0.0,
            )
            obj = json.loads(extract_json(repair))
        except Exception:
            obj = None
    if not isinstance(obj, dict):
        obj = {}

    want_id = obj.get("pageid")
    want_title = obj.get("title")

    def _matches(c: Dict[str, Any]) -> bool:
        # The model may echo pageid as a string ("123") while the API returns an int.
        if want_id is not None and str(c.get("pageid")) == str(want_id).strip():
            return True
        return bool(want_title) and c.get("title") == want_title

    picked = next((c for c in cands if _matches(c)), cands[0])
    return {**state, "chosen": picked}


def node_fetch_wiki(state: StoryState) -> StoryState:
    """Download the chosen article's plain text and store title, URL and trimmed text.

    Raises:
        RuntimeError: if the trimmed extract has fewer than 250 words.
    """
    page = wiki_extract(int(state["chosen"]["pageid"]))
    body = trim_wiki_text(page["extract"], max_chars=9000)
    if word_count(body) < 250:
        raise RuntimeError("Wikipedia extract is too short to build a story. Try rerun for a different case.")
    updated = dict(state)
    updated.update(
        source_title=page["title"],
        source_url=page["fullurl"],
        source_text=body,
    )
    return updated

def node_write_story(state: StoryState) -> StoryState:
    """Write a 900–1300 word documentary-style narrative grounded in the source text.

    The token budget leaves headroom above the requested length: 1300 English words is
    roughly 1,700 tokens for typical subword tokenizers, so a ~1600-token cap can cut
    the story off mid-sentence.
    """
    prompt = f"""
Write a compelling NON-FICTION documentary-style narrative from the SOURCE TEXT.
Rules:
- Use ONLY facts from SOURCE TEXT (no invented events, names, or details).
- High-school readable.
- 900–1300 words.
- No bullet points.
- Keep a clear timeline with tension.

TITLE: {state["source_title"]}
SOURCE URL: {state["source_url"]}

SOURCE TEXT:
{state["source_text"]}

Return ONLY the story text.
"""
    story = llm(prompt, max_new_tokens=2048, temperature=0.7)
    return {**state, "story_text": story.strip()}

def node_split_sections(state: StoryState) -> StoryState:
    """Ask the LLM to split the story into 6 titled sections with image prompts (raw JSON text).

    The output is stored unparsed in ``sections_raw``; node_validate parses and checks it.
    Token budget: 6 × 120–220 words of section text (720–1,320 words), plus titles,
    image prompts and JSON syntax, needs roughly 1,300–2,200 tokens. A cap of ~1,100
    truncates the JSON before it can close.
    """
    prompt = f"""
Split the story into exactly 6 sections.
Return ONLY valid JSON (no markdown) in this schema:

{{
  "sections": [
    {{
      "section_title": "...",
      "section_text": "...",
      "image_prompt": "..."
    }}
  ]
}}

Constraints:
- section_text must be 120–220 words each
- image_prompt should be cinematic documentary still, descriptive, NO text in image
- avoid graphic gore; keep it safe and documentary-like

STORY:
{state["story_text"]}
"""
    out = llm(prompt, max_new_tokens=2400, temperature=0.3)
    return {**state, "sections_raw": out, "json_error": "", "attempts": state.get("attempts", 0)}

def validate_sections(
    sections: Any,
    n_sections: int = 6,
    min_words: int = 120,
    max_words: int = 220,
) -> Optional[str]:
    """Check a parsed "sections" value against the schema the split/fix prompts request.

    Args:
        sections: value of the "sections" key from the model's JSON (any type).
        n_sections: required number of sections.
        min_words: inclusive lower bound on whitespace-separated words per section_text.
        max_words: inclusive upper bound on whitespace-separated words per section_text.

    Returns:
        None when valid. Otherwise, a short human-readable error describing the first
        problem found (this text is passed back to the fix prompt).
    """
    if not isinstance(sections, list) or len(sections) != n_sections:
        return f"sections must be a list of length {n_sections}"
    for i, s in enumerate(sections):
        if not isinstance(s, dict):
            return f"section {i} must be an object"
        for k in ("section_title", "section_text", "image_prompt"):
            value = s.get(k)
            if not isinstance(value, str) or not value.strip():
                return f"section {i} missing/invalid {k}"
        wc = len(s["section_text"].split())
        if not min_words <= wc <= max_words:
            return f"section {i} section_text word count {wc} not in {min_words}–{max_words}"
    return None

def node_validate(state: StoryState) -> StoryState:
    """Parse ``sections_raw`` and validate it. On success, store ``sections``.

    Sets ``json_error`` to "" on success, or to a message that node_fix_json feeds
    back to the model.
    """
    raw = state["sections_raw"]
    try:
        obj = json.loads(extract_json(raw))
    except Exception as e:
        return {**state, "json_error": f"JSON parse error: {e}"}
    # Without this check a JSON list/string surfaces as a confusing AttributeError message.
    if not isinstance(obj, dict):
        return {**state, "json_error": 'top-level JSON must be an object with a "sections" key'}
    err = validate_sections(obj.get("sections"))
    if err:
        return {**state, "json_error": err}
    return {**state, "sections": obj["sections"], "json_error": ""}

def node_fix_json(state: StoryState) -> StoryState:
    """Ask the LLM to repair ``sections_raw`` given the last validation error; bump ``attempts``.

    The story text is included so the model can rebuild missing or truncated sections
    instead of only reshaping a broken fragment. The token budget matches
    node_split_sections: six 120–220 word sections plus prompts and JSON need roughly
    1,300–2,200 tokens.
    """
    attempts = state.get("attempts", 0) + 1
    prompt = f"""
Fix the output into ONLY valid JSON matching this schema exactly:

{{
  "sections": [
    {{
      "section_title": "...",
      "section_text": "...",
      "image_prompt": "..."
    }}
  ]
}}

There must be exactly 6 sections.
Each section_text must be 120–220 words.

Problem:
{state.get("json_error","")}

Bad output:
{state["sections_raw"]}

STORY (take section_text from here if the bad output is incomplete):
{state.get("story_text", "")}

Return ONLY corrected JSON.
"""
    out = llm(prompt, max_new_tokens=2400, temperature=0.2)
    return {**state, "sections_raw": out, "attempts": attempts}

def route_after_validate(state: StoryState) -> str:
    """Choose the node after validation.

    Returns "assets" when there is no error, "fix" while fewer than 2 fix attempts
    have run, and "give_up" otherwise.
    """
    if state.get("json_error"):
        retries_left = state.get("attempts", 0) < 2
        return "fix" if retries_left else "give_up"
    return "assets"

def _write_story_json(payload: Dict[str, Any], out_dir: str) -> str:
    """Persist the run's inputs and outputs as ``story_data.json`` in ``out_dir``; return its path."""
    json_path = os.path.join(out_dir, "story_data.json")
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2, ensure_ascii=False)
    return json_path


def _write_storybook_html(payload: Dict[str, Any], html_path: str) -> str:
    """Write a standalone HTML storybook that references the section images and narration.

    All Wikipedia- and LLM-derived text is passed through ``html.escape``, so characters
    such as <, & or quotes in titles, prose or prompts cannot break the markup or inject tags.
    """
    esc = html.escape  # quote=True by default: ' and " are escaped too, safe inside attributes
    title = esc(payload["source_title"])
    url = esc(payload["source_url"])
    parts = [
        "<!DOCTYPE html>",
        # Declare UTF-8 explicitly: browsers opening a local file may otherwise guess a
        # legacy encoding and garble non-ASCII text (em dashes, curly quotes, accents).
        "<html><head><meta charset='utf-8'>",
        f"<title>{title}</title></head><body>",
        f"<h1>{title}</h1>",
        f"<p><b>Wikipedia source:</b> <a href='{url}' target='_blank' rel='noopener noreferrer'>{url}</a></p>",
        f"<p><b>Search query:</b> {esc(payload['query'])}</p>",
        "<hr>",
    ]
    for i, sec in enumerate(payload["sections"], start=1):
        sec_title = esc(sec["section_title"])
        parts += [
            f"<h2>{i}. {sec_title}</h2>",
            f"<img src='section_{i:02d}.png' alt='{sec_title}' style='max-width:900px;width:100%;border-radius:12px;'>",
            f"<p style='font-size:18px;line-height:1.5'>{esc(sec['section_text'])}</p>",
            f"<p><i>Prompt:</i> {esc(sec['image_prompt'])}</p>",
            "<hr>",
        ]
    parts += [
        "<h3>Full narration audio</h3>",
        "<audio controls src='full_narration.wav'></audio>",
        "</body></html>",
    ]
    with open(html_path, "w", encoding="utf-8") as f:
        f.write("\n".join(parts))
    return html_path


def node_generate_assets(state: StoryState) -> StoryState:
    """Render per-section images and narration, then write the JSON record and HTML storybook.

    Everything goes into ``OUT_DIR/<slug of source_title>``, which is stored as ``out_dir``.
    Files: section_XX.png / section_XX.wav per section, full_narration.wav,
    story_data.json and storybook.html.
    """
    slug = safe_filename(state["source_title"])
    out_dir = os.path.join(OUT_DIR, slug)
    os.makedirs(out_dir, exist_ok=True)

    wavs = []
    for i, sec in enumerate(state["sections"], start=1):
        img_path = os.path.join(out_dir, f"section_{i:02d}.png")
        wav_path = os.path.join(out_dir, f"section_{i:02d}.wav")

        generate_image(sec["image_prompt"], img_path)
        tts_to_wav(sec["section_text"], wav_path)
        wavs.append(wav_path)

    concat_wavs(wavs, os.path.join(out_dir, "full_narration.wav"))

    payload = {
        "query": state["query"],
        "source_title": state["source_title"],
        "source_url": state["source_url"],
        "source_text": state["source_text"],
        "story_text": state["story_text"],
        "sections": state["sections"],
    }
    _write_story_json(payload, out_dir)
    _write_storybook_html(payload, os.path.join(out_dir, "storybook.html"))

    return {**state, "out_dir": out_dir}

def node_give_up(state: StoryState) -> StoryState:
    """Terminal failure node: always raises ValueError carrying the last validation error."""
    last_error = state.get('json_error')
    message = "Failed to produce valid sections JSON after retries. Last error: " + str(last_error)
    raise ValueError(message)


Build LangGraph

In [None]:
g = StateGraph(StoryState)

# Register nodes in pipeline order (graph name -> node function).
for node_name, node_fn in (
    ("make_query", node_make_query),
    ("search", node_search_wiki),
    ("choose", node_choose_case),
    ("fetch", node_fetch_wiki),
    ("story", node_write_story),
    ("split", node_split_sections),
    ("validate", node_validate),
    ("fix", node_fix_json),
    ("assets", node_generate_assets),
    ("give_up", node_give_up),
):
    g.add_node(node_name, node_fn)

g.set_entry_point("make_query")

# Linear spine: make_query -> search -> choose -> fetch -> story -> split -> validate
spine = ["make_query", "search", "choose", "fetch", "story", "split", "validate"]
for src, dst in zip(spine, spine[1:]):
    g.add_edge(src, dst)

# validate branches to assets / fix / give_up; fix loops back for another validation.
g.add_conditional_edges(
    "validate",
    route_after_validate,
    {"assets": "assets", "fix": "fix", "give_up": "give_up"},
)
g.add_edge("fix", "validate")
g.add_edge("assets", END)

app = g.compile()
print("Graph compiled. Running...")

final_state = app.invoke({"attempts": 0})
print("\n Done!")
print("Query:", final_state["query"])
print("Chosen:", final_state["source_title"])
print("Output folder:", final_state["out_dir"])


✅ Graph compiled. Running...


ValueError: Failed to produce valid sections JSON after retries. Last error: sections must be a list of length 6