diff --git a/configs/scenes/test_scene_001.yaml b/configs/scenes/test_scene_001.yaml index cd08930..1463d01 100644 --- a/configs/scenes/test_scene_001.yaml +++ b/configs/scenes/test_scene_001.yaml @@ -14,7 +14,7 @@ speakers: - speaker_id: VIC_F_25-40_002 role: VIC -script_template: avdp/script/templates/she_proves/intimate_terror_coercive_control.j2 +script_template: synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2 script_slots: relationship: spouse setting: apartment_kitchen diff --git a/synthbanshee/script/__init__.py b/synthbanshee/script/__init__.py new file mode 100644 index 0000000..ce3f25a --- /dev/null +++ b/synthbanshee/script/__init__.py @@ -0,0 +1,12 @@ +"""Script generation: LLM-based dialogue generator, disfluency injection, validation.""" + +from synthbanshee.script.generator import ScriptGenerator, inject_disfluency, validate_script +from synthbanshee.script.types import DialogueTurn, MixedScene + +__all__ = [ + "DialogueTurn", + "MixedScene", + "ScriptGenerator", + "inject_disfluency", + "validate_script", +] diff --git a/synthbanshee/script/generator.py b/synthbanshee/script/generator.py new file mode 100644 index 0000000..d6f1766 --- /dev/null +++ b/synthbanshee/script/generator.py @@ -0,0 +1,361 @@ +"""LLM-based script generator for AVDP scenes. + +Renders a Jinja2 prompt template (from the scene config's script_template field) +and calls an LLM (Anthropic Claude or OpenAI GPT-4) to produce a structured Hebrew +dialogue. Results are cached on disk keyed by a SHA-256 of all generation inputs +so identical scene configs never hit the API twice. + +Cache key components: scene_id, script_template path, script_slots JSON, +intensity_arc, random_seed, provider, model, speaker IDs. +""" + +from __future__ import annotations + +import hashlib +import json +import math +import re +import unicodedata +from pathlib import Path + +from synthbanshee.script.types import DialogueTurn + +_DEFAULT_CACHE_DIR = Path("assets/scripts") +_DEFAULT_ANTHROPIC_MODEL = "claude-opus-4-6" +_DEFAULT_OPENAI_MODEL = "gpt-4o" + +# Hebrew filled-pause tokens inserted by inject_disfluency +_HE_FILLED_PAUSES = ["אממ", "אה", "אנ"] + + +def inject_disfluency( + text: str, + prob: float = 0.10, + rng_seed: int | None = None, +) -> str: + """Insert Hebrew filled-pause tokens between sentences with probability *prob*. + + Operates on sentence boundaries (splits on '.', '!', '?' followed by space). + The original text is never truncated — pauses are inserted, not substituted. + + Args: + text: Hebrew UTF-8 text. + prob: Probability of inserting a filled pause between any two sentences. + rng_seed: Optional seed for reproducibility. + + Returns: + Modified text with occasional Hebrew filled pauses. + """ + import random + + rng = random.Random(rng_seed) + # Split into sentences keeping the delimiters + parts = re.split(r"(?<=[.!?])\s+", text.strip()) + if len(parts) <= 1: + return text + + result_parts: list[str] = [parts[0]] + for part in parts[1:]: + if rng.random() < prob: + pause_token = rng.choice(_HE_FILLED_PAUSES) + result_parts.append(pause_token) + result_parts.append(part) + return " ".join(result_parts) + + +def validate_script( + turns: list[DialogueTurn], + known_speaker_ids: set[str], +) -> list[str]: + """Validate a generated script for spec compliance. + + Checks: + - All turns have non-empty Hebrew text + - All speaker_ids appear in known_speaker_ids + - Intensity values are 1–5 + - No 4+ consecutive identical tokens (LLM repetition artifact) + + Returns: + List of error message strings (empty → valid). + """ + errors: list[str] = [] + for i, turn in enumerate(turns): + prefix = f"turn[{i}]" + + if not turn.text.strip(): + errors.append(f"{prefix}: empty text") + continue + + if turn.speaker_id not in known_speaker_ids: + errors.append(f"{prefix}: speaker_id {turn.speaker_id!r} not in known speakers") + + if turn.intensity not in range(1, 6): + errors.append(f"{prefix}: intensity {turn.intensity} out of range 1–5") + + # Validate pause_before_s: must be finite and within [0.0, 1.5] s + if not math.isfinite(turn.pause_before_s) or not (0.0 <= turn.pause_before_s <= 1.5): + errors.append( + f"{prefix}: pause_before_s {turn.pause_before_s} must be finite and in [0.0, 1.5]" + ) + + # Detect repetition: 4+ consecutive identical whitespace-split tokens + tokens = turn.text.split() + run = 1 + for j in range(1, len(tokens)): + if tokens[j] == tokens[j - 1]: + run += 1 + if run >= 4: + errors.append(f"{prefix}: 4+ consecutive identical tokens ({tokens[j]!r})") + break + else: + run = 1 + + # Check text contains at least some Unicode Hebrew characters + has_hebrew = any( + unicodedata.name(ch, "").startswith("HEBREW") for ch in turn.text if ch.strip() + ) + if not has_hebrew: + errors.append(f"{prefix}: text contains no Hebrew characters") + + return errors + + +class ScriptGenerator: + """Generate a structured Hebrew dialogue from a scene config using an LLM. + + Supports Anthropic (Claude) and OpenAI (GPT-4o) providers. + Results are cached to ``cache_dir`` so identical scenes never re-call the API. + """ + + def __init__( + self, + provider: str = "anthropic", + model: str | None = None, + cache_dir: Path | str = _DEFAULT_CACHE_DIR, + ) -> None: + if provider not in {"anthropic", "openai"}: + raise ValueError(f"provider must be 'anthropic' or 'openai', got {provider!r}") + self._provider = provider + self._model = model or ( + _DEFAULT_ANTHROPIC_MODEL if provider == "anthropic" else _DEFAULT_OPENAI_MODEL + ) + self._cache_dir = Path(cache_dir) + + # ------------------------------------------------------------------ + # Cache + # ------------------------------------------------------------------ + + def _cache_key( + self, + scene_id: str, + script_template: str, + script_slots: dict, + intensity_arc: list[int], + random_seed: int, + speaker_ids: list[str], + ) -> str: + payload = json.dumps( + { + "scene_id": scene_id, + "script_template": script_template, + "script_slots": script_slots, + "intensity_arc": intensity_arc, + "random_seed": random_seed, + "provider": self._provider, + "model": self._model, + "speaker_ids": sorted(speaker_ids), + }, + sort_keys=True, + ) + return hashlib.sha256(payload.encode()).hexdigest() + + def _cache_path(self, key: str) -> Path: + return self._cache_dir / f"{key}.json" + + def _load_from_cache(self, key: str) -> list[DialogueTurn] | None: + p = self._cache_path(key) + if not p.exists(): + return None + raw = json.loads(p.read_text(encoding="utf-8")) + return [DialogueTurn(**t) for t in raw["turns"]] + + def _save_to_cache(self, key: str, turns: list[DialogueTurn]) -> None: + p = self._cache_path(key) + p.parent.mkdir(parents=True, exist_ok=True) + data = { + "turns": [ + { + "speaker_id": t.speaker_id, + "text": t.text, + "intensity": t.intensity, + "pause_before_s": t.pause_before_s, + "emotional_state": t.emotional_state, + } + for t in turns + ] + } + p.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8") + + # ------------------------------------------------------------------ + # Template rendering + # ------------------------------------------------------------------ + + def _render_prompt( + self, + template_path: str, + scene_id: str, + project: str, + violence_typology: str, + script_slots: dict, + intensity_arc: list[int], + target_duration_minutes: float, + speakers: list[dict], + ) -> str: + """Render the Jinja2 prompt template to a string.""" + from jinja2 import Environment, FileSystemLoader, StrictUndefined + + tpl_path = Path(template_path) + env = Environment( + loader=FileSystemLoader(str(tpl_path.parent)), + undefined=StrictUndefined, + keep_trailing_newline=True, + ) + template = env.get_template(tpl_path.name) + return template.render( + scene_id=scene_id, + project=project, + violence_typology=violence_typology, + script_slots=script_slots, + intensity_arc=intensity_arc, + target_duration_minutes=target_duration_minutes, + speakers=speakers, + ) + + # ------------------------------------------------------------------ + # LLM calls + # ------------------------------------------------------------------ + + def _call_anthropic(self, prompt: str) -> str: + import anthropic + + client = anthropic.Anthropic() + message = client.messages.create( + model=self._model, + max_tokens=4096, + messages=[{"role": "user", "content": prompt}], + ) + return message.content[0].text + + def _call_openai(self, prompt: str) -> str: + import openai + + client = openai.OpenAI() + response = client.chat.completions.create( + model=self._model, + messages=[{"role": "user", "content": prompt}], + max_tokens=4096, + ) + return response.choices[0].message.content or "" + + def _call_llm(self, prompt: str) -> str: + if self._provider == "anthropic": + return self._call_anthropic(prompt) + return self._call_openai(prompt) + + # ------------------------------------------------------------------ + # Response parsing + # ------------------------------------------------------------------ + + @staticmethod + def _parse_response(raw: str) -> list[DialogueTurn]: + """Extract dialogue turns from the LLM's JSON response. + + Accepts a raw response that may include markdown code fences. + """ + # Strip markdown code fences if present + stripped = raw.strip() + fence_match = re.search(r"```(?:json)?\s*([\s\S]+?)```", stripped) + if fence_match: + stripped = fence_match.group(1).strip() + + data = json.loads(stripped) + turns_raw = data.get("turns", []) + turns: list[DialogueTurn] = [] + for t in turns_raw: + turns.append( + DialogueTurn( + speaker_id=t["speaker_id"], + text=t["text"], + intensity=int(t.get("intensity", 1)), + pause_before_s=float(t.get("pause_before_s", 0.3)), + emotional_state=str(t.get("emotional_state", "neutral")), + ) + ) + return turns + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def generate( + self, + scene_id: str, + project: str, + violence_typology: str, + script_template: str, + script_slots: dict, + intensity_arc: list[int], + target_duration_minutes: float, + speakers: list[dict], + random_seed: int = 0, + ) -> list[DialogueTurn]: + """Generate a dialogue script for the given scene parameters. + + Args: + scene_id: Unique scene identifier. + project: 'she_proves' or 'elephant_in_the_room'. + violence_typology: e.g. 'IT', 'SV', 'NEU'. + script_template: Path to the Jinja2 prompt template file. + script_slots: Template slot values (e.g. relationship, setting). + intensity_arc: Sequence of 1–5 intensity levels for the scene. + target_duration_minutes: Desired total clip duration. + speakers: List of speaker dicts with speaker_id, role, gender, etc. + random_seed: Seed used in the cache key for reproducibility. + + Returns: + List of DialogueTurn objects forming the full scene script. + + Raises: + ValueError: If the LLM response cannot be parsed or fails validation. + """ + speaker_ids = [s["speaker_id"] for s in speakers] + key = self._cache_key( + scene_id, script_template, script_slots, intensity_arc, random_seed, speaker_ids + ) + + cached = self._load_from_cache(key) + if cached is not None: + return cached + + prompt = self._render_prompt( + template_path=script_template, + scene_id=scene_id, + project=project, + violence_typology=violence_typology, + script_slots=script_slots, + intensity_arc=intensity_arc, + target_duration_minutes=target_duration_minutes, + speakers=speakers, + ) + + raw_response = self._call_llm(prompt) + turns = self._parse_response(raw_response) + + errors = validate_script(turns, known_speaker_ids=set(speaker_ids)) + if errors: + raise ValueError( + f"Script validation failed for scene {scene_id}:\n" + "\n".join(errors) + ) + + self._save_to_cache(key, turns) + return turns diff --git a/synthbanshee/script/templates/elephant/base_scene.j2 b/synthbanshee/script/templates/elephant/base_scene.j2 new file mode 100644 index 0000000..cfb6113 --- /dev/null +++ b/synthbanshee/script/templates/elephant/base_scene.j2 @@ -0,0 +1,54 @@ +You are a professional Hebrew dialogue writer creating training data for a workplace-safety audio detection system. + +Write a realistic Hebrew dialogue for the following scene. The dialogue must be entirely in Hebrew (UTF-8). + +SCENE PARAMETERS +================ +Scene ID: {{ scene_id }} +Project: elephant_in_the_room (clinic/welfare office alert system) +Violence typology: {{ violence_typology }} +{% for key, val in script_slots.items() %}{{ key }}: {{ val }} +{% endfor %} +Target duration: {{ target_duration_minutes }} minutes +Intensity arc: {{ intensity_arc }} (1 = calm, 5 = peak threat/attack) + +SPEAKERS +======== +{% for sp in speakers %} +- ID: {{ sp.speaker_id }} Role: {{ sp.role }} Gender: {{ sp.gender }} +{% endfor %} + +SCENE STRUCTURE GUIDANCE +========================= +Clinic/welfare-office incidents typically follow this arc: + 1. Routine interaction — client and social worker / clinician in normal session. + 2. Frustration — client becomes impatient, dissatisfied, or agitated. + 3. Verbal escalation — raised voice, threats, verbal abuse. + 4. Peak — explicit threat or physical intimidation (the alert-trigger moment). +The alert should occur in the final 40% of the scene duration. + +DIALOGUE RULES +============== +1. Every utterance MUST be Hebrew text only — no transliteration, no English. +2. Follow the intensity arc chronologically across the turns. +3. Each turn is 1–4 sentences. Aim for roughly {{ (target_duration_minutes * 60 / 15) | int }} turns total. +4. Silence gaps (pause_before_s): 0.2–1.5 s; longer gaps at transitions. +5. emotional_state must be one of: neutral, angry, fearful, sad, pleading, threatening, calm. +6. The social worker / clinician (SW or CLIN role) should attempt de-escalation. +7. Do NOT write binary Violence/Non-Violence labels. + +OUTPUT FORMAT +============= +Return ONLY valid JSON — no prose, no markdown fences. + +{ + "turns": [ + { + "speaker_id": "", + "text": "", + "intensity": , + "pause_before_s": , + "emotional_state": "" + } + ] +} diff --git a/synthbanshee/script/templates/she_proves/base_scene.j2 b/synthbanshee/script/templates/she_proves/base_scene.j2 new file mode 100644 index 0000000..7c99fc2 --- /dev/null +++ b/synthbanshee/script/templates/she_proves/base_scene.j2 @@ -0,0 +1,47 @@ +You are a professional Hebrew dialogue writer creating training data for a domestic-violence audio detection system. + +Write a realistic Hebrew dialogue for the following scene. The dialogue must be entirely in Hebrew (UTF-8). + +SCENE PARAMETERS +================ +Scene ID: {{ scene_id }} +Project: she_proves (smartphone passive monitoring) +Violence typology: {{ violence_typology }} +{% for key, val in script_slots.items() %}{{ key }}: {{ val }} +{% endfor %} +Target duration: {{ target_duration_minutes }} minutes +Intensity arc: {{ intensity_arc }} (1 = calm, 5 = peak) + +SPEAKERS +======== +{% for sp in speakers %} +- ID: {{ sp.speaker_id }} Role: {{ sp.role }} Gender: {{ sp.gender }} +{% endfor %} + +DIALOGUE RULES +============== +1. Every utterance MUST be Hebrew text only — no transliteration, no English. +2. Follow the intensity arc chronologically: the dialogue must build from the first + intensity value to the last across the turns. +3. Each turn is 1–4 sentences. Aim for roughly {{ (target_duration_minutes * 60 / 15) | int }} turns total + (assuming ~15 s per turn on average). +4. Silence gaps (pause_before_s) between turns: 0.2–1.5 s; use longer gaps at + scene transitions or after confrontational outbursts. +5. emotional_state must be one of: neutral, angry, fearful, sad, pleading, threatening, calm. +6. Do NOT use binary Violence/Non-Violence labels anywhere. + +OUTPUT FORMAT +============= +Return ONLY valid JSON — no prose, no markdown fences. + +{ + "turns": [ + { + "speaker_id": "", + "text": "", + "intensity": , + "pause_before_s": , + "emotional_state": "" + } + ] +} diff --git a/synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2 b/synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2 new file mode 100644 index 0000000..54ee6db --- /dev/null +++ b/synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2 @@ -0,0 +1,61 @@ +{# Intimate-terror / coercive-control scene specialisation for she_proves. + Standalone template (not using Jinja2 {% extends %}) so the generator can + render it directly via FileSystemLoader without block definitions. +#} +You are a professional Hebrew dialogue writer creating training data for a domestic-violence audio detection system. + +Write a realistic Hebrew dialogue for the following scene. The dialogue must be entirely in Hebrew (UTF-8). + +SCENE PARAMETERS +================ +Scene ID: {{ scene_id }} +Project: she_proves (smartphone passive monitoring) +Violence typology: {{ violence_typology }} — Intimate Terror / Coercive Control +Relationship: {{ script_slots.get('relationship', 'spouses') }} +Setting: {{ script_slots.get('setting', 'apartment') }} +Grievance trigger: {{ script_slots.get('grievance', 'unspecified') }} +Target duration: {{ target_duration_minutes }} minutes +Intensity arc: {{ intensity_arc }} (1 = calm baseline, 5 = peak coercive episode) + +SPEAKERS +======== +{% for sp in speakers %} +- ID: {{ sp.speaker_id }} Role: {{ sp.role }} Gender: {{ sp.gender }} +{% endfor %} + +SCENE STRUCTURE GUIDANCE +========================= +Coercive-control scenes typically follow this arc: + 1. Baseline / deceptive calm — perpetrator initiates conversation normally. + 2. Tension building — subtle criticism, checking behaviour, jealousy cues. + 3. Grievance statement — perpetrator articulates the perceived transgression. + 4. Escalation — threats (implicit or explicit), voice rising, victim placates. + 5. Peak — overt coercive demand, intimidation, or emotional abuse. +Use the intensity_arc list to map phases to intensity levels. + +DIALOGUE RULES +============== +1. Every utterance MUST be Hebrew text only — no transliteration, no English. +2. Follow the intensity arc chronologically across the turns. +3. Each turn is 1–4 sentences. Aim for roughly {{ (target_duration_minutes * 60 / 15) | int }} turns total. +4. Silence gaps (pause_before_s): 0.2–1.5 s; longer gaps at transitions. +5. emotional_state must be one of: neutral, angry, fearful, sad, pleading, threatening, calm. +6. Do NOT write binary Violence/Non-Violence labels. +7. The victim (VIC role) should respond authentically — fear, appeasement, or quiet resistance. +8. The perpetrator (AGG role) should display controlling language patterns common in intimate terror. + +OUTPUT FORMAT +============= +Return ONLY valid JSON — no prose, no markdown fences. + +{ + "turns": [ + { + "speaker_id": "", + "text": "", + "intensity": , + "pause_before_s": , + "emotional_state": "" + } + ] +} diff --git a/synthbanshee/script/types.py b/synthbanshee/script/types.py new file mode 100644 index 0000000..beb4842 --- /dev/null +++ b/synthbanshee/script/types.py @@ -0,0 +1,47 @@ +"""Shared types for the script generation and TTS mixing pipeline.""" + +from __future__ import annotations + +from dataclasses import dataclass, field + +import numpy as np + + +@dataclass +class DialogueTurn: + """One utterance from a single speaker in a generated scene script. + + Attributes: + speaker_id: Matches a SpeakerConfig.speaker_id in the scene. + text: Hebrew UTF-8 utterance text. Never placed in filenames. + intensity: 1–5, drives SSML style selection in TTSRenderer. + pause_before_s: Silence gap (seconds) inserted before this turn in the mix. + emotional_state: LLM-generated hint; used as a secondary style cue. + """ + + speaker_id: str + text: str + intensity: int + pause_before_s: float = 0.3 + emotional_state: str = "neutral" + + +@dataclass +class MixedScene: + """Audio result of mixing multiple per-turn TTS segments into one scene. + + Attributes: + samples: Float32 numpy array, mono, 16 kHz. + sample_rate: Always 16000. + turn_onsets_s: Per-turn onset time in seconds (after silence pad). + turn_offsets_s: Per-turn offset time in seconds. + duration_s: Total scene duration in seconds. + speaker_ids: Speaker ID for each turn (parallel with onsets/offsets). + """ + + samples: np.ndarray + sample_rate: int + turn_onsets_s: list[float] + turn_offsets_s: list[float] + duration_s: float + speaker_ids: list[str] = field(default_factory=list) diff --git a/synthbanshee/tts/__init__.py b/synthbanshee/tts/__init__.py index 364848d..e740da3 100644 --- a/synthbanshee/tts/__init__.py +++ b/synthbanshee/tts/__init__.py @@ -1,6 +1,7 @@ -"""TTS rendering: SSML builder, Azure provider, render cache.""" +"""TTS rendering: SSML builder, Azure provider, render cache, scene mixer.""" +from synthbanshee.tts.mixer import SceneMixer from synthbanshee.tts.renderer import TTSRenderer from synthbanshee.tts.ssml_builder import SSMLBuilder -__all__ = ["SSMLBuilder", "TTSRenderer"] +__all__ = ["SceneMixer", "SSMLBuilder", "TTSRenderer"] diff --git a/synthbanshee/tts/mixer.py b/synthbanshee/tts/mixer.py new file mode 100644 index 0000000..d2e0f7c --- /dev/null +++ b/synthbanshee/tts/mixer.py @@ -0,0 +1,89 @@ +"""SceneMixer: concatenate per-speaker TTS WAV segments into a single audio scene. + +Each segment is a (wav_bytes, pause_before_s, speaker_id) triple. The mixer +decodes WAV bytes using soundfile, resamples to 16 kHz if needed, prepends the +requested silence gap, and concatenates all segments into a single float32 mono +array while preserving speaker IDs in the mix metadata. + +The output MixedScene carries per-turn onset/offset times so the label generator +can derive event timing from the mix log rather than re-estimating it from the +final waveform. + +Spec reference: docs/spec.md §3.1 +""" + +from __future__ import annotations + +import io + +import numpy as np +import soundfile as sf + +from synthbanshee.augment.preprocessing import _resample +from synthbanshee.script.types import MixedScene + +_TARGET_SR = 16_000 + + +class SceneMixer: + """Mix a sequence of TTS segments into a single-track 16 kHz scene.""" + + def mix_sequential( + self, + segments: list[tuple[bytes, float, str]], + ) -> MixedScene: + """Concatenate segments in order, separated by silence gaps. + + Args: + segments: List of (wav_bytes, pause_before_s, speaker_id) triples. + wav_bytes must be valid WAV data (any SR / channels). + pause_before_s is inserted *before* each segment. + speaker_id is stored in the MixedScene for labelling. + + Returns: + MixedScene with all segments concatenated at 16 kHz mono. + """ + all_samples: list[np.ndarray] = [] + turn_onsets: list[float] = [] + turn_offsets: list[float] = [] + speaker_ids: list[str] = [] + current_pos_s: float = 0.0 + + for wav_bytes, pause_s, speaker_id in segments: + # --- Decode WAV --- + with io.BytesIO(wav_bytes) as buf: + data, src_sr = sf.read(buf, dtype="float32", always_2d=True) + + # Downmix to mono + mono = data.mean(axis=1) if data.shape[1] > 1 else data[:, 0] + + # Resample to 16 kHz if needed + if src_sr != _TARGET_SR: + mono = _resample(mono, src_sr, _TARGET_SR) + + # Prepend silence gap + if pause_s > 0.0: + silence = np.zeros(int(pause_s * _TARGET_SR), dtype=np.float32) + all_samples.append(silence) + current_pos_s += pause_s + + onset_s = current_pos_s + turn_onsets.append(onset_s) + + all_samples.append(mono.astype(np.float32)) + seg_duration_s = len(mono) / _TARGET_SR + current_pos_s += seg_duration_s + + turn_offsets.append(current_pos_s) + speaker_ids.append(speaker_id) + + combined = np.concatenate(all_samples) if all_samples else np.zeros(0, dtype=np.float32) + + return MixedScene( + samples=combined, + sample_rate=_TARGET_SR, + turn_onsets_s=turn_onsets, + turn_offsets_s=turn_offsets, + duration_s=float(len(combined)) / _TARGET_SR, + speaker_ids=speaker_ids, + ) diff --git a/synthbanshee/tts/renderer.py b/synthbanshee/tts/renderer.py index 1c746c1..40b6252 100644 --- a/synthbanshee/tts/renderer.py +++ b/synthbanshee/tts/renderer.py @@ -14,6 +14,7 @@ from pathlib import Path from synthbanshee.config.speaker_config import SpeakerConfig +from synthbanshee.script.types import DialogueTurn, MixedScene from synthbanshee.tts.azure_provider import AzureProvider from synthbanshee.tts.ssml_builder import SSMLBuilder @@ -136,3 +137,62 @@ def render_utterance_to_file( output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_bytes(wav_bytes) return output_path + + def render_scene( + self, + turns: list[DialogueTurn], + speakers: dict[str, SpeakerConfig], + *, + randomize: bool = False, + rng_seed: int | None = None, + disfluency: bool = False, + ) -> MixedScene: + """Render a multi-speaker dialogue script to a MixedScene. + + Each turn is rendered individually (with caching) then mixed into a + single audio stream by SceneMixer. Turn onset/offset times are derived + from the mix, not from TTS durations, so the label generator should + always use MixedScene.turn_onsets_s / turn_offsets_s. + + Args: + turns: Ordered list of DialogueTurn objects from ScriptGenerator. + speakers: Mapping from speaker_id to SpeakerConfig. + randomize: Apply small random prosody variation per turn. + rng_seed: Seed for reproducible prosody variation. + disfluency: If True, inject Hebrew filled pauses into each turn's + text using the speaker's disfluency profile. + + Returns: + MixedScene with concatenated audio and per-turn timing metadata. + + Raises: + KeyError: If a turn references a speaker_id not in *speakers*. + """ + import random + + from synthbanshee.script.generator import inject_disfluency + from synthbanshee.tts.mixer import SceneMixer + + rng = random.Random(rng_seed) + mixer = SceneMixer() + + segments: list[tuple[bytes, float, str]] = [] + for turn in turns: + speaker = speakers[turn.speaker_id] + text = turn.text + if disfluency: + text = inject_disfluency( + text, + prob=speaker.disfluency.filled_pause_prob, + rng_seed=rng.randint(0, 2**31), + ) + wav_bytes, _ = self.render_utterance( + text, + speaker, + turn.intensity, + randomize=randomize, + rng_seed=rng.randint(0, 2**31) if randomize else None, + ) + segments.append((wav_bytes, turn.pause_before_s, turn.speaker_id)) + + return mixer.mix_sequential(segments) diff --git a/tests/integration/test_multi_speaker.py b/tests/integration/test_multi_speaker.py new file mode 100644 index 0000000..3db92a8 --- /dev/null +++ b/tests/integration/test_multi_speaker.py @@ -0,0 +1,178 @@ +"""Integration test: multi-speaker TTS rendering via TTSRenderer.render_scene(). + +Wires DialogueTurn → TTSRenderer.render_scene() → SceneMixer → MixedScene. +Azure TTS calls are mocked with synthetic audio — no API credentials needed. +""" + +from __future__ import annotations + +import io +import wave +from pathlib import Path +from unittest.mock import MagicMock + +import numpy as np +import pytest + +from synthbanshee.config.speaker_config import SpeakerConfig +from synthbanshee.script.types import DialogueTurn +from synthbanshee.tts.azure_provider import AzureProvider +from synthbanshee.tts.renderer import TTSRenderer + +EXAMPLES_DIR = Path(__file__).parent.parent.parent / "configs" / "examples" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_wav_bytes( + sample_rate: int = 24000, + duration_s: float = 2.0, + freq: float = 440.0, +) -> bytes: + n = int(sample_rate * duration_s) + t = np.linspace(0, duration_s, n, endpoint=False) + samples = (0.3 * np.sin(2 * np.pi * freq * t) * 32767).astype(np.int16) + buf = io.BytesIO() + with wave.open(buf, "w") as w: + w.setnchannels(1) + w.setsampwidth(2) + w.setframerate(sample_rate) + w.writeframes(samples.tobytes()) + return buf.getvalue() + + +def _mock_azure_factory(key, region): + synth = MagicMock() + mock_result = MagicMock() + mock_result.audio_data = _make_wav_bytes() + del mock_result.reason + synth.speak_ssml_async.return_value.get.return_value = mock_result + return synth + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def renderer(tmp_path): + provider = AzureProvider(sdk_factory=_mock_azure_factory) + return TTSRenderer(provider=provider, cache_dir=tmp_path / "tts_cache") + + +@pytest.fixture() +def speakers(): + agg = SpeakerConfig.from_yaml(EXAMPLES_DIR / "speaker_AGG_M_30-45_001.yaml") + vic = SpeakerConfig.from_yaml(EXAMPLES_DIR / "speaker_VIC_F_25-40_002.yaml") + return {agg.speaker_id: agg, vic.speaker_id: vic} + + +@pytest.fixture() +def dialogue_turns(): + return [ + DialogueTurn( + speaker_id="AGG_M_30-45_001", + text="שלום, בוא נדבר", + intensity=1, + pause_before_s=0.0, + emotional_state="neutral", + ), + DialogueTurn( + speaker_id="VIC_F_25-40_002", + text="בסדר, על מה?", + intensity=1, + pause_before_s=0.3, + emotional_state="neutral", + ), + DialogueTurn( + speaker_id="AGG_M_30-45_001", + text="למה יצאת בלי לשאול אותי?", + intensity=3, + pause_before_s=0.2, + emotional_state="angry", + ), + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRenderScene: + def test_returns_mixed_scene(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + assert scene.sample_rate == 16000 + assert len(scene.samples) > 0 + + def test_turn_count_matches(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + assert len(scene.turn_onsets_s) == len(dialogue_turns) + assert len(scene.turn_offsets_s) == len(dialogue_turns) + assert len(scene.speaker_ids) == len(dialogue_turns) + + def test_speaker_ids_preserved(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + expected_ids = [t.speaker_id for t in dialogue_turns] + assert scene.speaker_ids == expected_ids + + def test_onsets_are_non_decreasing(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + for i in range(1, len(scene.turn_onsets_s)): + assert scene.turn_onsets_s[i] >= scene.turn_onsets_s[i - 1] + + def test_offsets_after_onsets(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + for onset, offset in zip(scene.turn_onsets_s, scene.turn_offsets_s, strict=True): + assert offset > onset + + def test_total_duration_positive(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + assert scene.duration_s > 0.0 + + def test_samples_are_float32(self, renderer, speakers, dialogue_turns): + scene = renderer.render_scene(dialogue_turns, speakers) + assert scene.samples.dtype == np.float32 + + def test_first_pause_zero_means_immediate_start(self, renderer, speakers, dialogue_turns): + """First turn with pause_before_s=0 should start at t=0.""" + scene = renderer.render_scene(dialogue_turns, speakers) + assert scene.turn_onsets_s[0] == pytest.approx(0.0, abs=0.01) + + def test_second_turn_onset_reflects_pause(self, renderer, speakers, dialogue_turns): + """Second turn onset should be ~first-turn offset + pause_before_s.""" + scene = renderer.render_scene(dialogue_turns, speakers) + expected = scene.turn_offsets_s[0] + dialogue_turns[1].pause_before_s + assert scene.turn_onsets_s[1] == pytest.approx(expected, abs=0.05) + + def test_unknown_speaker_raises(self, renderer, dialogue_turns): + """render_scene should raise KeyError for unknown speaker_id.""" + with pytest.raises(KeyError): + renderer.render_scene(dialogue_turns, speakers={}) + + def test_disfluency_flag_accepted(self, renderer, speakers, dialogue_turns): + """disfluency=True should not raise (even if it changes text).""" + scene = renderer.render_scene(dialogue_turns, speakers, disfluency=True) + assert scene.sample_rate == 16000 + + def test_randomize_produces_different_cache_keys(self, renderer, speakers, tmp_path): + """Two randomize=True calls with different rng_seed should call TTS separately.""" + turns = [ + DialogueTurn( + speaker_id="AGG_M_30-45_001", + text="שלום", + intensity=2, + ) + ] + scene1 = renderer.render_scene(turns, speakers, randomize=True, rng_seed=1) + # Clear cache to force re-render + for f in (tmp_path / "tts_cache").glob("*.wav"): + f.unlink() + scene2 = renderer.render_scene(turns, speakers, randomize=True, rng_seed=99) + # Both should produce valid scenes (behaviour tested, not identity) + assert scene1.duration_s > 0.0 + assert scene2.duration_s > 0.0 diff --git a/tests/unit/test_mixer.py b/tests/unit/test_mixer.py new file mode 100644 index 0000000..b5859d9 --- /dev/null +++ b/tests/unit/test_mixer.py @@ -0,0 +1,148 @@ +"""Unit tests for synthbanshee.tts.mixer.SceneMixer.""" + +from __future__ import annotations + +import io +import wave + +import numpy as np +import pytest + +from synthbanshee.tts.mixer import _TARGET_SR, SceneMixer + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _sine_wav_bytes( + freq: float = 440.0, + duration_s: float = 1.0, + sample_rate: int = 16000, + amplitude: float = 0.3, +) -> bytes: + """Return WAV bytes for a mono sine wave at the given frequency.""" + n = int(sample_rate * duration_s) + t = np.linspace(0, duration_s, n, endpoint=False) + samples = (amplitude * np.sin(2 * np.pi * freq * t) * 32767).astype(np.int16) + buf = io.BytesIO() + with wave.open(buf, "w") as w: + w.setnchannels(1) + w.setsampwidth(2) + w.setframerate(sample_rate) + w.writeframes(samples.tobytes()) + return buf.getvalue() + + +def _stereo_wav_bytes(duration_s: float = 1.0, sample_rate: int = 16000) -> bytes: + """Return WAV bytes for a stereo file (to test downmix).""" + n = int(sample_rate * duration_s) + t = np.linspace(0, duration_s, n, endpoint=False) + ch1 = (0.3 * np.sin(2 * np.pi * 440 * t) * 32767).astype(np.int16) + ch2 = (0.3 * np.sin(2 * np.pi * 880 * t) * 32767).astype(np.int16) + stereo = np.stack([ch1, ch2], axis=1) + buf = io.BytesIO() + with wave.open(buf, "w") as w: + w.setnchannels(2) + w.setsampwidth(2) + w.setframerate(sample_rate) + w.writeframes(stereo.tobytes()) + return buf.getvalue() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestSceneMixer: + def test_empty_segments_returns_zero_length(self): + mixer = SceneMixer() + result = mixer.mix_sequential([]) + assert result.duration_s == 0.0 + assert len(result.samples) == 0 + assert result.turn_onsets_s == [] + assert result.turn_offsets_s == [] + + def test_single_segment_no_pause(self): + mixer = SceneMixer() + wav = _sine_wav_bytes(duration_s=1.0, sample_rate=_TARGET_SR) + result = mixer.mix_sequential([(wav, 0.0, "SPK_001")]) + + assert result.sample_rate == _TARGET_SR + assert len(result.turn_onsets_s) == 1 + assert result.turn_onsets_s[0] == pytest.approx(0.0) + assert result.turn_offsets_s[0] == pytest.approx(1.0, abs=0.05) + assert result.duration_s == pytest.approx(1.0, abs=0.05) + assert result.speaker_ids == ["SPK_001"] + + def test_pause_shifts_onset(self): + mixer = SceneMixer() + wav = _sine_wav_bytes(duration_s=0.5, sample_rate=_TARGET_SR) + pause_s = 0.3 + result = mixer.mix_sequential([(wav, pause_s, "SPK_001")]) + + assert result.turn_onsets_s[0] == pytest.approx(pause_s, abs=0.01) + assert result.duration_s == pytest.approx(pause_s + 0.5, abs=0.05) + + def test_two_segments_sequential(self): + mixer = SceneMixer() + wav1 = _sine_wav_bytes(freq=440, duration_s=1.0, sample_rate=_TARGET_SR) + wav2 = _sine_wav_bytes(freq=880, duration_s=0.5, sample_rate=_TARGET_SR) + result = mixer.mix_sequential( + [ + (wav1, 0.0, "SPK_A"), + (wav2, 0.2, "SPK_B"), + ] + ) + + assert len(result.turn_onsets_s) == 2 + # Second turn onset = duration of first + pause + expected_onset2 = result.turn_offsets_s[0] + 0.2 + assert result.turn_onsets_s[1] == pytest.approx(expected_onset2, abs=0.05) + assert result.speaker_ids == ["SPK_A", "SPK_B"] + + def test_total_duration_matches_samples(self): + mixer = SceneMixer() + segments = [ + (_sine_wav_bytes(duration_s=0.8, sample_rate=_TARGET_SR), 0.1, "S1"), + (_sine_wav_bytes(duration_s=0.6, sample_rate=_TARGET_SR), 0.2, "S2"), + (_sine_wav_bytes(duration_s=0.4, sample_rate=_TARGET_SR), 0.15, "S3"), + ] + result = mixer.mix_sequential(segments) + computed_duration = len(result.samples) / result.sample_rate + assert result.duration_s == pytest.approx(computed_duration, abs=1e-4) + + def test_resamples_24k_to_16k(self): + """Mixer should downsample 24 kHz input to 16 kHz output.""" + mixer = SceneMixer() + wav_24k = _sine_wav_bytes(duration_s=0.5, sample_rate=24000) + result = mixer.mix_sequential([(wav_24k, 0.0, "SPK_001")]) + + assert result.sample_rate == _TARGET_SR + # Duration should still be approximately 0.5 s + assert result.duration_s == pytest.approx(0.5, abs=0.05) + + def test_downmixes_stereo_to_mono(self): + mixer = SceneMixer() + stereo_wav = _stereo_wav_bytes(duration_s=1.0, sample_rate=_TARGET_SR) + result = mixer.mix_sequential([(stereo_wav, 0.0, "SPK_001")]) + + assert result.samples.ndim == 1 + assert result.duration_s == pytest.approx(1.0, abs=0.05) + + def test_output_samples_are_float32(self): + mixer = SceneMixer() + wav = _sine_wav_bytes(duration_s=0.5, sample_rate=_TARGET_SR) + result = mixer.mix_sequential([(wav, 0.0, "SPK_001")]) + assert result.samples.dtype == np.float32 + + def test_offsets_greater_than_onsets(self): + mixer = SceneMixer() + segments = [ + (_sine_wav_bytes(duration_s=0.5, sample_rate=_TARGET_SR), 0.1, "S1"), + (_sine_wav_bytes(duration_s=0.5, sample_rate=_TARGET_SR), 0.2, "S2"), + ] + result = mixer.mix_sequential(segments) + for onset, offset in zip(result.turn_onsets_s, result.turn_offsets_s, strict=True): + assert offset > onset diff --git a/tests/unit/test_script_generator.py b/tests/unit/test_script_generator.py new file mode 100644 index 0000000..73c6d22 --- /dev/null +++ b/tests/unit/test_script_generator.py @@ -0,0 +1,375 @@ +"""Unit tests for synthbanshee.script.generator and related utilities.""" + +from __future__ import annotations + +import json +import sys +from pathlib import Path +from types import ModuleType +from unittest.mock import MagicMock, patch + +import pytest + +from synthbanshee.script.generator import ( + ScriptGenerator, + inject_disfluency, + validate_script, +) +from synthbanshee.script.types import DialogueTurn + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +_SPEAKER_IDS = {"AGG_M_30-45_001", "VIC_F_25-40_002"} + +_VALID_TURNS = [ + DialogueTurn( + speaker_id="AGG_M_30-45_001", + text="שלום, מה שלומך היום?", + intensity=1, + pause_before_s=0.3, + emotional_state="neutral", + ), + DialogueTurn( + speaker_id="VIC_F_25-40_002", + text="בסדר, תודה. איך הולך לך?", + intensity=1, + pause_before_s=0.5, + emotional_state="neutral", + ), + DialogueTurn( + speaker_id="AGG_M_30-45_001", + text="אני לא מרוצה. צריך לדבר על משהו חשוב.", + intensity=3, + pause_before_s=0.3, + emotional_state="angry", + ), +] + +_VALID_TURNS_JSON = { + "turns": [ + { + "speaker_id": t.speaker_id, + "text": t.text, + "intensity": t.intensity, + "pause_before_s": t.pause_before_s, + "emotional_state": t.emotional_state, + } + for t in _VALID_TURNS + ] +} + + +# --------------------------------------------------------------------------- +# inject_disfluency +# --------------------------------------------------------------------------- + + +class TestInjectDisfluency: + def test_no_op_on_single_sentence(self): + text = "שלום מה שלומך" + result = inject_disfluency(text, prob=1.0, rng_seed=0) + assert result == text + + def test_inserts_pause_between_sentences(self): + text = "שלום. מה שלומך?" + result = inject_disfluency(text, prob=1.0, rng_seed=0) + # A Hebrew pause token should appear between the two sentences + assert any(p in result for p in ["אממ", "אה", "אנ"]) + # Original sentences are still present + assert "שלום" in result + assert "מה שלומך" in result + + def test_zero_prob_no_change(self): + text = "שלום. מה שלומך? טוב מאוד." + result = inject_disfluency(text, prob=0.0, rng_seed=42) + assert result == text + + def test_reproducible_with_seed(self): + text = "טוב. מה נשמע? הכל בסדר." + r1 = inject_disfluency(text, prob=0.8, rng_seed=7) + r2 = inject_disfluency(text, prob=0.8, rng_seed=7) + assert r1 == r2 + + def test_different_seeds_may_differ(self): + text = "טוב. מה נשמע? הכל בסדר. ומה איתך?" + results = {inject_disfluency(text, prob=0.5, rng_seed=s) for s in range(20)} + # With prob 0.5 over multiple seeds we expect at least two distinct outputs + assert len(results) >= 2 + + +# --------------------------------------------------------------------------- +# validate_script +# --------------------------------------------------------------------------- + + +class TestValidateScript: + def test_valid_turns_no_errors(self): + errors = validate_script(_VALID_TURNS, _SPEAKER_IDS) + assert errors == [] + + def test_empty_text_flagged(self): + bad = [DialogueTurn(speaker_id="AGG_M_30-45_001", text=" ", intensity=1)] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("empty text" in e for e in errors) + + def test_unknown_speaker_flagged(self): + bad = [DialogueTurn(speaker_id="UNKNOWN_001", text="שלום", intensity=1)] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("speaker_id" in e for e in errors) + + def test_invalid_intensity_flagged(self): + bad = [DialogueTurn(speaker_id="AGG_M_30-45_001", text="שלום", intensity=0)] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("intensity" in e for e in errors) + + def test_repetition_flagged(self): + repeated = "שלום " * 5 + bad = [DialogueTurn(speaker_id="AGG_M_30-45_001", text=repeated, intensity=1)] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("consecutive" in e for e in errors) + + def test_non_hebrew_text_flagged(self): + bad = [DialogueTurn(speaker_id="AGG_M_30-45_001", text="hello world", intensity=1)] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("Hebrew" in e for e in errors) + + def test_negative_pause_flagged(self): + bad = [ + DialogueTurn( + speaker_id="AGG_M_30-45_001", text="שלום", intensity=1, pause_before_s=-0.1 + ) + ] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("pause_before_s" in e for e in errors) + + def test_excessive_pause_flagged(self): + bad = [ + DialogueTurn(speaker_id="AGG_M_30-45_001", text="שלום", intensity=1, pause_before_s=2.0) + ] + errors = validate_script(bad, _SPEAKER_IDS) + assert any("pause_before_s" in e for e in errors) + + def test_valid_pause_no_error(self): + ok = [ + DialogueTurn(speaker_id="AGG_M_30-45_001", text="שלום", intensity=1, pause_before_s=1.0) + ] + errors = validate_script(ok, _SPEAKER_IDS) + assert not any("pause_before_s" in e for e in errors) + + +# --------------------------------------------------------------------------- +# ScriptGenerator — cache +# --------------------------------------------------------------------------- + + +class TestScriptGeneratorCache: + def _make_generator(self, tmp_path: Path) -> ScriptGenerator: + return ScriptGenerator(provider="anthropic", cache_dir=tmp_path / "scripts") + + def _scene_kwargs(self, scene_id: str = "TEST_001") -> dict: + return dict( + scene_id=scene_id, + project="she_proves", + violence_typology="IT", + script_template="synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2", + script_slots={"relationship": "spouse", "setting": "kitchen"}, + intensity_arc=[1, 2, 3], + target_duration_minutes=1.0, + speakers=[ + {"speaker_id": "AGG_M_30-45_001", "role": "AGG", "gender": "male"}, + {"speaker_id": "VIC_F_25-40_002", "role": "VIC", "gender": "female"}, + ], + random_seed=0, + ) + + def test_cache_hit_skips_llm(self, tmp_path: Path): + gen = self._make_generator(tmp_path) + # Pre-populate cache + key = gen._cache_key( + "TEST_001", + "synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2", + {"relationship": "spouse", "setting": "kitchen"}, + [1, 2, 3], + 0, + ["AGG_M_30-45_001", "VIC_F_25-40_002"], + ) + gen._save_to_cache(key, _VALID_TURNS) + + with patch.object(gen, "_call_llm") as mock_llm: + turns = gen.generate(**self._scene_kwargs()) + + mock_llm.assert_not_called() + assert len(turns) == len(_VALID_TURNS) + assert turns[0].speaker_id == _VALID_TURNS[0].speaker_id + + def test_cache_miss_calls_llm_and_saves(self, tmp_path: Path): + gen = self._make_generator(tmp_path) + + with patch.object(gen, "_call_llm", return_value=json.dumps(_VALID_TURNS_JSON)): + turns = gen.generate(**self._scene_kwargs("SCENE_NEW")) + + assert len(turns) == len(_VALID_TURNS) + # Cache file should now exist + key = gen._cache_key( + "SCENE_NEW", + "synthbanshee/script/templates/she_proves/intimate_terror_coercive_control.j2", + {"relationship": "spouse", "setting": "kitchen"}, + [1, 2, 3], + 0, + ["AGG_M_30-45_001", "VIC_F_25-40_002"], + ) + assert gen._cache_path(key).exists() + + def test_validation_error_raises(self, tmp_path: Path): + gen = self._make_generator(tmp_path) + bad_response = json.dumps( + { + "turns": [ + { + "speaker_id": "UNKNOWN_SPEAKER", + "text": "hello world", # no Hebrew, unknown speaker + "intensity": 1, + "pause_before_s": 0.3, + "emotional_state": "neutral", + } + ] + } + ) + + with ( + patch.object(gen, "_call_llm", return_value=bad_response), + pytest.raises(ValueError, match="Script validation failed"), + ): + gen.generate(**self._scene_kwargs("BAD_SCENE")) + + +# --------------------------------------------------------------------------- +# ScriptGenerator — _parse_response +# --------------------------------------------------------------------------- + + +class TestParseResponse: + def test_plain_json(self): + turns = ScriptGenerator._parse_response(json.dumps(_VALID_TURNS_JSON)) + assert len(turns) == len(_VALID_TURNS) + assert turns[0].speaker_id == _VALID_TURNS[0].speaker_id + + def test_json_in_markdown_fence(self): + raw = f"```json\n{json.dumps(_VALID_TURNS_JSON)}\n```" + turns = ScriptGenerator._parse_response(raw) + assert len(turns) == len(_VALID_TURNS) + + def test_json_in_plain_fence(self): + raw = f"```\n{json.dumps(_VALID_TURNS_JSON)}\n```" + turns = ScriptGenerator._parse_response(raw) + assert len(turns) == len(_VALID_TURNS) + + def test_invalid_json_raises(self): + with pytest.raises((json.JSONDecodeError, ValueError)): + ScriptGenerator._parse_response("not json") + + +# --------------------------------------------------------------------------- +# ScriptGenerator — provider validation +# --------------------------------------------------------------------------- + + +class TestScriptGeneratorInit: + def test_invalid_provider_raises(self): + with pytest.raises(ValueError, match="provider"): + ScriptGenerator(provider="gemini") + + def test_default_models(self): + gen_a = ScriptGenerator(provider="anthropic") + assert "claude" in gen_a._model + + gen_o = ScriptGenerator(provider="openai") + assert "gpt" in gen_o._model + + def test_custom_model(self): + gen = ScriptGenerator(provider="anthropic", model="claude-haiku-4-5-20251001") + assert gen._model == "claude-haiku-4-5-20251001" + + +# --------------------------------------------------------------------------- +# ScriptGenerator — LLM provider dispatch (_call_llm, _call_anthropic, _call_openai) +# --------------------------------------------------------------------------- + + +class TestLLMDispatch: + """Cover _call_anthropic, _call_openai, and _call_llm routing.""" + + def test_call_llm_routes_to_anthropic(self, tmp_path: Path): + gen = ScriptGenerator(provider="anthropic", cache_dir=tmp_path) + with patch.object(gen, "_call_anthropic", return_value="resp") as mock_a: + result = gen._call_llm("prompt") + mock_a.assert_called_once_with("prompt") + assert result == "resp" + + def test_call_llm_routes_to_openai(self, tmp_path: Path): + gen = ScriptGenerator(provider="openai", cache_dir=tmp_path) + with patch.object(gen, "_call_openai", return_value="resp") as mock_o: + result = gen._call_llm("prompt") + mock_o.assert_called_once_with("prompt") + assert result == "resp" + + def test_call_anthropic_uses_sdk(self, tmp_path: Path): + gen = ScriptGenerator(provider="anthropic", model="claude-test", cache_dir=tmp_path) + mock_message = MagicMock() + mock_message.content = [MagicMock(text="שלום")] + mock_client = MagicMock() + mock_client.messages.create.return_value = mock_message + + mock_anthropic_mod = ModuleType("anthropic") + mock_anthropic_mod.Anthropic = MagicMock(return_value=mock_client) # type: ignore[attr-defined] + + with patch.dict(sys.modules, {"anthropic": mock_anthropic_mod}): + result = gen._call_anthropic("test prompt") + + mock_client.messages.create.assert_called_once_with( + model="claude-test", + max_tokens=4096, + messages=[{"role": "user", "content": "test prompt"}], + ) + assert result == "שלום" + + def test_call_openai_uses_sdk(self, tmp_path: Path): + gen = ScriptGenerator(provider="openai", model="gpt-test", cache_dir=tmp_path) + mock_choice = MagicMock() + mock_choice.message.content = "שלום מהאי" + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + + mock_openai_mod = ModuleType("openai") + mock_openai_mod.OpenAI = MagicMock(return_value=mock_client) # type: ignore[attr-defined] + + with patch.dict(sys.modules, {"openai": mock_openai_mod}): + result = gen._call_openai("test prompt") + + mock_client.chat.completions.create.assert_called_once_with( + model="gpt-test", + messages=[{"role": "user", "content": "test prompt"}], + max_tokens=4096, + ) + assert result == "שלום מהאי" + + def test_call_openai_empty_content_returns_empty_string(self, tmp_path: Path): + gen = ScriptGenerator(provider="openai", model="gpt-test", cache_dir=tmp_path) + mock_choice = MagicMock() + mock_choice.message.content = None + mock_response = MagicMock() + mock_response.choices = [mock_choice] + mock_client = MagicMock() + mock_client.chat.completions.create.return_value = mock_response + + mock_openai_mod = ModuleType("openai") + mock_openai_mod.OpenAI = MagicMock(return_value=mock_client) # type: ignore[attr-defined] + + with patch.dict(sys.modules, {"openai": mock_openai_mod}): + result = gen._call_openai("prompt") + + assert result == ""