In [1]:
import os
import time
import json
import re
import sys
import csv
 
import gradio as gr
from openai import OpenAI
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, field_validator, Field, RootModel
from enum import Enum
from typing import List, Optional, Dict, Tuple, Any
from datetime import date
from concurrent import futures
from tqdm import tqdm
from pathlib import Path
from openai.lib._pydantic import to_strict_json_schema

# Add parent directory to path to import from implementation package
# Notebooks are in implementation/notebooks/, so we go up two levels to project root
sys.path.insert(0, str(Path().resolve().parent.parent))

from implementation.prompts.vector_subquery_prompts import (
    VECTOR_QUERY_PROMPTS
)
from implementation.prompts.vector_weights_prompts import (
    VECTOR_WEIGHT_PROMPTS
)
from implementation.prompts.metadata_preferences_prompts import (
    ALL_METADATA_EXTRACTION_PROMPTS
)
from implementation.enums import Genre

# Load environment variables (for API key)
load_dotenv()

  from .autonotebook import tqdm as notebook_tqdm


True

## Generic Kimmi K Calling Function

In [2]:
kimi_api_key = os.environ.get("MOONSHOT_API_KEY")
if not kimi_api_key:
    raise ValueError(
        "MOONSHOT_API_KEY environment variable not set. "
        "Please set it before importing this module."
    )
else:
    print(f"Kimi API key loaded: {kimi_api_key[:8]}...")

kimi_client = OpenAI(
    api_key=kimi_api_key,
    base_url="https://api.moonshot.ai/v1",
)

# Get OpenAI API key from environment and initialize client once at module load
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError(
        "OPENAI_API_KEY environment variable not set. "
        "Please set it before importing this module."
    )
else:
    print(f"OpenAI API key loaded: {openai_api_key[:8]}...")

# Initialize OpenAI client - created once when module is loaded
openai_client = OpenAI(api_key=openai_api_key)

print("Successfully initialized both clients")

Kimi API key loaded: sk-PCwbW...
OpenAI API key loaded: sk-proj-...
Successfully initialized both clients


In [3]:
def generate_openai_response(
    user_prompt: str,
    system_prompt: str,
    response_format: BaseModel,
    model: str = "gpt-5-mini",
    reasoning_effort: str = "low",
    verbosity: str = "low"
):
    response = openai_client.chat.completions.parse(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        response_format=response_format,
        reasoning_effort=reasoning_effort,
        verbosity=verbosity
    )
    
    # Extract the parsed response - OpenAI automatically validates structure matches PlotMetadata
    message = response.choices[0].message
    if message.parsed:
        return message.parsed
    else:
        # Handle case where model refuses to generate output
        raise ValueError(f"OpenAI failed to generate response: {message.refusal}")

def generate_kimi_response(
    user_prompt: str,
    system_prompt: str,
    response_format: BaseModel,
    enable_thinking: bool = False,
):
    try:
        thinking_type = "enabled" if enable_thinking else "disabled"
        schema = to_strict_json_schema(response_format)

        response = kimi_client.chat.completions.create(
            model="kimi-k2.5",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            response_format={
                "type": "json_schema",
                "json_schema": {
                    "name": "test_metadata",
                    "strict": True,
                    "schema": schema,
                },
            },
            extra_body={
                "thinking": {"type": thinking_type}
            }
        )
        
        # Extract the parsed response - OpenAI automatically validates structure matches PlotMetadata
        raw = response.choices[0].message.content
        data = json.loads(raw)
        metadata = response_format.model_validate(data)
        return metadata
    except Exception as e:
        raise ValueError(f"Kimi failed to generate response: {e}")

In [None]:
class TestMetadata(BaseModel):
    favorite_color: str
    favorite_number: float

response = generate_kimi_response(
    user_prompt="Give me your favorite color and number",
    system_prompt="You are a helpful assistant that can answer questions and help with tasks. Your answer must follow the provided JSON schema.",
    response_format=TestMetadata
)

response

## Test Queries

In [4]:
routing_queries = [
  "manly action movies from the 80s",
  "movies like parasite but american and funnier not dumb",
  "1990s french psych thriller not slow not frantic nonlinear timeline unreliable narrator no gore but creepy critics said beautiful cinematography plot holes",
  "hand drawn animation not cgi spanish audio coming of age dramedy uplifting and hopeful iconic songs great dialogue",
  "two strangers handcuffed together escape the city in one night set in tokyo time loop twist ending",
  "date night movie to unwind after a long day funny but not gross no jump scares",
  "low budget indie filmed in new york directed by nolan?? (or similar vibe) mixed reviews overrated but still smart",
  "science fiction war epic intergalactic warfare morally gray lead ticking clock deadline red herrings",
  "love-to-hate villain redemption arc but also unreliable narrator fourth wall breaks",
  "cozy sick day comfort watch background at a party not too loud not overstimulating ear bursting sound avoid that",
  "a goofy movie but i mean goofy like silly not the title, 90s vibe, witty dialogue, not slow",
  "romcom about two rival bakers, light and flirty, 00s vibe",
  "doc about free solo climbers, inspiring but not preachy",
  "YA fantasy with a chosen one prophecy, not too dark, PG-13",
  "set in Boston during a blizzard, but filmed in Toronto",
  "something on Netflix under 90 minutes",
  "critics hated it but I love it anyway, fun guilty pleasure",
  "Oscar-winning cinematography, but the story is messy",
  "real-time thriller in one apartment, ticking clock deadline, no flashbacks",
  "found footage horror, no jump scares, creepy dread",
  "multiple POVs, unreliable narrator, twist ending explained at the end",
  "movies with Jack Sparrow energy but not Pirates, witty swashbuckling",
  "directed by Quinten Tarantino, snappy dialogue, violent but funny",
  "Her but not the one with Joaquin Phoenix",
  "ultra-gory body horror, disgusting, make me squirm",
  "background while coding, dialogue not important, chill visuals, low volume",
  "family movie night with kids, not babyish, jokes for adults too",
  "Korean audio with English subtitles, critics called it a slow-burn masterpiece",
  "adapted from a video game, big studio blockbuster, mixed reviews, amazing fight choreography",
  "set in ancient Rome, political betrayal, ends on a bleak note",
  "A24 vibe, but I want it less depressing, more hopeful"
]

## Lexical Entity

In [33]:
class EntityCategory(Enum):
    """Enum representing the various categories of lexical entities."""
    CHARACTER = "character"
    FRANCHISE = "franchise"
    MOVIE_TITLE = "movie_title"
    PERSON = "person"
    STUDIO = "studio"


class ExtractedEntity(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
    
    candidate_entity_phrase: str
    most_likely_category: EntityCategory
    exclude_from_results: bool
    corrected_and_normalized_entity: str

class ExtractedEntities(BaseModel):
    entity_candidates: List[ExtractedEntity]

In [34]:
EXTRACT_LEXICAL_ENTITIES_SYSTEM_PROMPT = """\
You are an expert at understanding movie search queries. Your job is to extract all \
lexical entities from the provided search query.

GOALS:
- Identify all lexical entities contained within the search query
- Correctly categorize each lexical entity (movie title, person, character, studio)
- Normalize each lexical entity to its canonical form

INPUT:
You will receive text representing the full movie search query entered by the user.

OUTPUT:
JSON schema. A list of JSON objects, each representing a single lexical entity.
- candidate_entity_phrase: The original verbatim word / phrase from the search query that represents a single lexical entity.
- most_likely_category: The category that best represents this lexical entity.
- exclude_from_results: Whether the user is trying to find movies that contain this entity or DON'T contain this entity (ex. "Not starring Tom Cruise" means DON'T contain Tom Cruise).
- corrected_and_normalized_entity: The MOST LIKELY corrected and normalized form of the typed entity. Represents how that entity would appear on an official movie website or movie poster.

ENTITY CATEGORIES:
- movie_title: Represents a substring or the entirety of a SPECIFIC movie title.
  - Case #1: The query contains a word or phrase that clearly and obviously is the title of a movie. (ex. "shawshank redemption", "fight club", "movies like dark knight")
  - Case #2: In the query the user is explicitly searching for movies with a given substring in the title (ex. "movies with the word 'clown' in the title)
- franchise: Represents a specific media brand (ex. "The Matrix", "Spongebob Squarepants", "Barbie")
- person: Represents the name of a real human who worked on this movie (actor, writer, composer, etc.).
- character: Represents the name of a character who appears in this movie.
- studio: Represents the name of a movie studio that produced this movie.

CORRECTIONS & NORMALIZATIONS:
- HIGH-CONFIDENCE (>95%) terms only
- Clear spelling mistakes (ex. "Leandro Dicaprio" -> "Leonardo DiCaprio")
- Normalized punctuation and numerical formats (ex. "rocky 2" --> "rocky ii", "seven" --> "se7en")
- Obvious acronym expansions (ex. "LOTR" -> "Lord of the Rings")
- NEVER introduce additional information not already present in the original query (ex. "star wars" -> "Star Wars: Episode IV - \
A New Hope" is BAD because the user never specified which specific Star Wars movie they are looking for)
- Introducing additional information not present in the original query is a catastrophic failure.

ADDITIONAL GUIDANCE:
- All values must be nonnull. Providing a null value is a catastrophic failure. Providing None as a value is a catastrophic failure.
- most_likely_category MUST be "movie_title", "franchise", "person", "character", or "studio"
- corrected_and_normalized_entity must be the highest confidence correction / normalization of the user-typed entity.
- Only extract words or phrases that are highly likely to be a lexical entity.
- DO NOT extract words or phrases that simply describe traits of the movie. They MUST be related to specific lexical entities.\
"""

In [196]:
entity_query = "john wick type movies" 

In [197]:
kimi_response = generate_kimi_response(
    user_prompt=f"User query: \"{entity_query}\"",
    system_prompt=EXTRACT_LEXICAL_ENTITIES_SYSTEM_PROMPT,
    response_format=ExtractedEntities,
)

kimi_response

ExtractedEntities(entity_candidates=[ExtractedEntity(candidate_entity_phrase='john wick', most_likely_category='franchise', exclude_from_results=False, corrected_and_normalized_entity='John Wick')])

## Metadata Preferences

In [5]:
class DateMatchOperation(Enum):
    EXACT = "exact"
    BEFORE = "before"
    AFTER = "after"
    BETWEEN = "between"

class NumericalMatchOperation(Enum):
    EXACT = "exact"
    BETWEEN = "between"
    LESS_THAN = "less_than"
    GREATER_THAN = "greater_than"

class RatingMatchOperation(Enum):
    EXACT = "exact"
    GREATER_THAN = "greater_than"
    LESS_THAN = "less_than"
    GREATER_THAN_OR_EQUAL = "greater_than_or_equal"
    LESS_THAN_OR_EQUAL = "less_than_or_equal"

class RatingPreference(Enum):
    CRITICALLY_ACCLAIMED = "critically_acclaimed"
    POORLY_RECEIVED = "poorly_received"

class StreamingAccessType(Enum):
    SUBSCRIPTION = "subscription"
    RENT = "rent"
    BUY = "buy"



class DatePreference(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    first_date: str = Field(
        ..., 
        pattern=r"^\d{4}-\d{2}-\d{2}$",
        description="Either the first date in the range or the exact date to match. ISO 8601 date: YYYY-MM-DD",
    )
    match_operation: DateMatchOperation = Field(..., description="Whether we want the date to be before, after, or exactly at the first date, or between the two provided dates.")
    second_date: Optional[str] = Field(
        default=None, 
        pattern=r"^\d{4}-\d{2}-\d{2}$",
        description="Optional second date in the range only if match_operation is BETWEEN. ISO 8601 date: YYYY-MM-DD"
    )
    

class NumericalPreference(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    first_value: float = Field(..., description="Either the first value in the range or the exact value to match.")
    match_operation: NumericalMatchOperation = Field(..., description="How we should evaluate the provided first and (maybe) second values.")
    second_value: Optional[float] = Field(default=None, description="Optional second value in the range only if match_operation is BETWEEN.")


class ListPreference(BaseModel):
    should_include: List[str] = Field(default=[], description="List of items that should be included in the movie's metadata.")
    should_exclude: List[str] = Field(default=[], description="List of items that should be excluded from the movie's metadata.")


class GenreListPreference(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    should_include: List[Genre] = Field(default=[], description="List of genres that the user's query wants the movie to fall under.")
    should_exclude: List[Genre] = Field(default=[], description="List of genres that the user's query wants to avoid in the movie.")


class MaturityPreference(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    rating: str = Field(..., description="Standard USA ratings: G, PG, PG-13, R, NC-17")
    match_operation: RatingMatchOperation = Field(..., description="Whether we prefer movies with this rating, greater (more mature), or less (less mature).")


class PopularTrendingPreference(BaseModel):
    prefers_trending_movies: bool
    prefers_popular_movies: bool


class WatchProvidersPreference(BaseModel):
    model_config = ConfigDict(use_enum_values=True)
    
    should_include: List[str]
    should_exclude: List[str]
    preferred_access_type: Optional[StreamingAccessType]



class MetadataPreferences(BaseModel):
    release_date_preference: Optional[DatePreference]
    duration_preference: Optional[NumericalPreference]
    genres_preference: Optional[GenreListPreference]
    audio_languages_preference: Optional[ListPreference]
    watch_providers_preference: Optional[WatchProvidersPreference]
    maturity_rating_preference: Optional[MaturityPreference]
    popular_trending_preference: PopularTrendingPreference
    rating_preference: Optional[RatingPreference]


In [6]:
metadata_queries = [
    "brisk 90s action flick, nothing plodding",
    "Portuguese thriller with English subtitles, won something at Sundance",
    "everyone saw it but critics were mixed, big summer tentpole",
    "leisurely paced drama, I have all afternoon",
    "Taiwanese coming-of-age, light and breezy, nothing heavy",
    "something trashy and fun, totally panned, perfect for wine night",
    "late 2010s superhero film, appropriate for my 12-year-old",
    "moody Nordic noir, could be Swedish or Danish",
    "tightly edited, under 100 minutes, no filler",
    "arthouse darling that flopped commercially",
    "streaming free on Tubi, campy 80s horror, the cheesier the better",
    "beautifully shot but narratively messy, visually stunning",
    "films from the silent era, slapstick preferred",
    "I can only rent tonight, nothing on my subscriptions has what I want",
    "British gangster film, stylish and quotable, Guy Ritchie vibes",
    "not looking for anything mainstream, obscure foreign gems only",
    "certified banger, everyone at work won't shut up about it",
    "exactly rated R, I want the hard stuff, uncut",
    "polarizing film, some call it genius others call it pretentious garbage",
    "Australian outback thriller, gritty and relentless, 2000s era",
]

In [None]:
# Wrapper models for optional metadata preferences. Use BaseModel with value: Optional[X]
# so the schema explicitly allows null. RootModel[Optional[T]] + json_schema_extra can
# produce schemas that OpenAI strict mode treats as non-optional (e.g. anyOf at root).
# These wrappers produce {"value": null} or {"value": {...}} - clean and reliably optional.
class ReleaseDateResponse(BaseModel):
    value: Optional[DatePreference] = None

class DurationResponse(BaseModel):
    value: Optional[NumericalPreference] = None

class GenresResponse(BaseModel):
    value: Optional[GenreListPreference] = None

class AudioLanguagesResponse(BaseModel):
    value: Optional[ListPreference] = None

class WatchProvidersResponse(BaseModel):
    value: Optional[WatchProvidersPreference] = None

class MaturityRatingResponse(BaseModel):
    value: Optional[MaturityPreference] = None

class RatingResponse(BaseModel):
    value: Optional[RatingPreference] = None

# Prompt prefix for optional fields: prompts say "Return null" but we need {"value": null}
OPTIONAL_RESPONSE_PREFIX = (
    "IMPORTANT: Your response must be a JSON object with a single 'value' key. "
    "When no preference applies, use {\"value\": null}. "
    "When you have a preference, use {\"value\": { ... your extraction ... }}.\n\n"
)

# Mapping: (prompt_key, field_name, response_schema, needs_optional_prefix)
METADATA_PREFERENCE_MAPPING = [
    ("release_date", "release_date_preference", ReleaseDateResponse, True),
    ("duration", "duration_preference", DurationResponse, True),
    ("genres", "genres_preference", GenresResponse, True),
    ("audio_languages", "audio_languages_preference", AudioLanguagesResponse, True),
    ("watch_providers", "watch_providers_preference", WatchProvidersResponse, True),
    ("maturity_rating", "maturity_rating_preference", MaturityRatingResponse, True),
    ("popularity", "popular_trending_preference", PopularTrendingPreference, False),
    ("rating", "rating_preference", RatingResponse, True),
]


def _process_single_metadata_preference(
    prompt_key: str,
    system_prompt: str,
    response_schema: type,
    needs_optional_prefix: bool,
    query: str,
) -> tuple[str, Any]:
    """
    Process a single metadata preference extraction for a query.

    Returns:
        Tuple of (MetadataPreferences field name, extracted value or None)
    """
    field_name = next(fn for pk, fn, _, _ in METADATA_PREFERENCE_MAPPING if pk == prompt_key)
    try:
        full_system_prompt = (
            (OPTIONAL_RESPONSE_PREFIX + system_prompt) if needs_optional_prefix else system_prompt
        )
        response = generate_openai_response(
            user_prompt=f"User query: \"{query}\"",
            system_prompt=full_system_prompt,
            response_format=response_schema,
            model="gpt-5-mini",
            reasoning_effort="minimal",
        )
        # response = generate_kimi_response(
        #     user_prompt=f"User query: \"{query}\"",
        #     system_prompt=full_system_prompt,
        #     response_format=response_schema,
        # )
        # Optional wrappers have .value; PopularTrendingPreference is the model directly
        if hasattr(response, "value"):
            value = response.value
        else:
            value = response
        return field_name, value
    except Exception as e:
        print(f"Error processing {prompt_key} for query: {e}")
        return field_name, None


def get_metadata_preferences_parallel(query: str) -> MetadataPreferences:
    """
    Extract all metadata preferences for a query by running each prompt in parallel,
    then combine into a single MetadataPreferences instance.
    """
    with futures.ThreadPoolExecutor(max_workers=8) as executor:
        future_to_key = {
            executor.submit(
                _process_single_metadata_preference,
                prompt_key,
                ALL_METADATA_EXTRACTION_PROMPTS[prompt_key],
                response_schema,
                needs_optional_prefix,
                query,
            ): prompt_key
            for prompt_key, _, response_schema, needs_optional_prefix in METADATA_PREFERENCE_MAPPING
        }
        results = {}
        for future in futures.as_completed(future_to_key):
            field_name, value = future.result()
            results[field_name] = value
    return MetadataPreferences(
        release_date_preference=results.get("release_date_preference"),
        duration_preference=results.get("duration_preference"),
        genres_preference=results.get("genres_preference"),
        audio_languages_preference=results.get("audio_languages_preference"),
        watch_providers_preference=results.get("watch_providers_preference"),
        maturity_rating_preference=results.get("maturity_rating_preference"),
        popular_trending_preference=results.get("popular_trending_preference") or PopularTrendingPreference(
            prefers_trending_movies=False, prefers_popular_movies=False
        ),
        rating_preference=results.get("rating_preference"),
    )


metadata_preferences_results = []
for query in tqdm(metadata_queries, desc="Processing metadata preferences"):
    try:
        results = get_metadata_preferences_parallel(query)
        metadata_preferences_results.append((query, results))
    except Exception as e:
        print(f"Error processing query '{query}': {e}")

# Save to metadata_preferences_results.csv
csv_path = Path("../generated_data/metadata_preferences_results.csv")
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["query", "results"])
    writer.writeheader()
    for query, results in metadata_preferences_results:
        writer.writerow({"query": query, "results": results.model_dump_json()})

print(f"Saved {len(metadata_preferences_results)} results to {csv_path.resolve()}")

Processing metadata preferences: 100%|██████████| 20/20 [00:30<00:00,  1.50s/it]

Saved 20 results to /Users/michaelkeohane/Documents/movie-finder-rag/implementation/notebooks/metadata_preferences_results.csv





In [228]:
# print(json.dumps(kimi_response.model_dump(), indent=4))
for k,v in kimi_response.model_dump().items():
    if v:
        print(f"{k}: {v}")

release_date_preference: {'first_date': '1980-01-01', 'match_operation': 'between', 'second_date': '1989-12-31'}
duration_preference: {'first_value': 90.0, 'match_operation': 'greater_than', 'second_value': None}
genres_preference: {'must_include': ['Action'], 'must_exclude': []}
audio_languages_preference: {'must_include': [], 'must_exclude': []}
watch_providers_preference: {'must_include': [], 'must_exclude': []}
maturity_rating_preference: {'rating': 'R', 'match_operation': 'less_than_or_equal'}


## Vector Subqueries

In [37]:
class VectorCollectionQueryData(BaseModel):
    justification: str
    relevant_subquery_text: Optional[str]

class VectorRoutingResponse(BaseModel):
    plot_events_data: VectorCollectionQueryData
    plot_analysis_data: VectorCollectionQueryData
    viewer_experience_data: VectorCollectionQueryData
    watch_context_data: VectorCollectionQueryData
    narrative_techniques_data: VectorCollectionQueryData
    production_data: VectorCollectionQueryData
    reception_data: VectorCollectionQueryData

In [None]:
# ACTUALLY GENERATING THE VECTOR SUBQUERIES

responses = []

def get_vector_queries(query):
    queries = [
        (VECTOR_QUERY_PROMPTS["plot_events"], "plot_events_data"),
        (VECTOR_QUERY_PROMPTS["plot_analysis"], "plot_analysis_data"),
        (VECTOR_QUERY_PROMPTS["viewer_experience"], "viewer_experience_data"),
        (VECTOR_QUERY_PROMPTS["watch_context"], "watch_context_data"),
        (VECTOR_QUERY_PROMPTS["narrative_techniques"], "narrative_techniques_data"),
        (VECTOR_QUERY_PROMPTS["production"], "production_data"),
        (VECTOR_QUERY_PROMPTS["reception"], "reception_data"),
    ]

    def process_query(args):
        system_prompt, key = args
        print(f"Generating for {key}...")
        try:
            response = generate_kimi_response(
                user_prompt=f"User query: \"{query}\"",
                system_prompt=system_prompt,
                response_format=VectorCollectionQueryData
            )
            return key, response
        except Exception as e:
            print(f"Error processing {key}: {e}")
            return key, None

    # Run queries in parallel
    results = {}
    with futures.ThreadPoolExecutor(max_workers=7) as executor:
        future_to_key = {executor.submit(process_query, item): item[1] for item in queries}
        for future in futures.as_completed(future_to_key):
            key, response = future.result()
            if response:
                results[key] = response

    # Construct response from parallel results
    return VectorRoutingResponse(
        plot_events_data=results.get("plot_events_data"),
        plot_analysis_data=results.get("plot_analysis_data"),
        viewer_experience_data=results.get("viewer_experience_data"),
        watch_context_data=results.get("watch_context_data"),
        narrative_techniques_data=results.get("narrative_techniques_data"),
        production_data=results.get("production_data"),
        reception_data=results.get("reception_data")
    )

for query in tqdm(routing_queries):
    vector_response = get_vector_queries(query)
    responses.append(vector_response)

print(f"Generated {len(responses)} responses")

In [None]:
# SAVE GENERATED VECTOR SUBQUERIES AS CSV

with open('../generated_data/results.csv', 'w', newline='', encoding='utf-8') as csvfile:
    # Create CSV writer
    fieldnames = ['query', 'result']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    
    # Write header row
    writer.writeheader()
    
    # Write data rows
    for i, response in enumerate(responses):
        row = {
            'query': routing_queries[i],
            'result': response.model_dump_json()
        }
        
        writer.writerow(row)

print(f"Saved {len(responses)} results to results.csv")


Saved 31 results to results.csv


In [None]:
# [GRADIO INTERFACE] Visualize vector subquery results


COLLECTION_LABELS = {
    "plot_events_data": "Plot Events",
    "plot_analysis_data": "Plot Analysis",
    "viewer_experience_data": "Viewer Experience",
    "watch_context_data": "Watch Context",
    "narrative_techniques_data": "Narrative Techniques",
    "production_data": "Production",
    "reception_data": "Reception",
}

def load_results_from_csv(filename: str) -> Dict[str, str]:
    """
    Load query results from a CSV file.

    Args:
        filename: Path to the CSV file with columns 'query' and 'json_result' (or 'result')

    Returns:
        Dictionary mapping query strings to their raw JSON result strings
    """
    results = {}
    try:
        with open(filename, 'r', encoding='utf-8') as csvfile:
            reader = csv.DictReader(csvfile)
            for row in reader:
                query = row.get('query', '').strip()
                # Handle both 'json_result' and 'result' column names
                json_result = row.get('json_result', row.get('result', '')).strip()
                if query:
                    results[query] = json_result
    except FileNotFoundError:
        print(f"Warning: {filename} not found. Skipping.")
    except Exception as e:
        print(f"Error reading {filename}: {e}")
    return results

def format_result_as_markdown(json_str: str) -> str:
    """
    Convert a JSON result string into clean, human-readable Markdown.

    Each vector collection becomes a section with a friendly heading,
    the extracted subquery (or a clear "not relevant" note), and
    a collapsed justification the reader can expand if they want detail.

    Args:
        json_str: Raw JSON string with per-collection results

    Returns:
        Markdown-formatted string ready for display
    """
    try:
        cleaned = re.sub(r'[\x00-\x1f\x7f-\x9f]', ' ', json_str)
        data = json.loads(cleaned)
    except (json.JSONDecodeError, TypeError) as e:
        print(f"Error parsing JSON: {e}")
        return json_str  # Return raw text when JSON is invalid

    sections = []

    for key, value in data.items():
        if value is None:
            continue

        # Use the friendly label, fall back to a cleaned-up key name
        label = COLLECTION_LABELS.get(key, key.replace("_", " ").title())
        subquery = value.get("relevant_subquery_text")
        justification = value.get("justification", "")

        # Build the section for this collection
        section_lines = [f"### {label}"]

        if subquery:
            section_lines.append(f"**Subquery:** {subquery}")
        else:
            section_lines.append("*Not relevant to this query*")

        # Show justification in a collapsible details block
        if justification:
            section_lines.append("")
            section_lines.append("<details>")
            section_lines.append("<summary>Justification</summary>")
            section_lines.append("")
            section_lines.append(justification)
            section_lines.append("</details>")

        sections.append("\n".join(section_lines))

    print(f"sections: {len(sections)}")

    return "\n\n---\n\n".join(sections)

# Load results from both CSV files
a_results = load_results_from_csv('../generated_data/ground_truth.csv')
b_results = load_results_from_csv('../generated_data/claude_v3.csv')

# Build a unified list combining results from both files
all_queries = set(a_results.keys()) | set(b_results.keys())
unified_results: List[Dict[str, str]] = []

for query in sorted(all_queries):
    unified_results.append({
        "query": query,
        "a_result": a_results.get(query, "No result available"),
        "b_result": b_results.get(query, "No result available"),
    })

print(f"Loaded {len(unified_results)} queries from both CSV files")

def display_results(selected_query: str) -> tuple:
    """
    Look up the selected query and return formatted Markdown for both results.

    Args:
        selected_query: The query string chosen from the dropdown

    Returns:
        Tuple of (a_markdown, b_markdown) for the two display columns
    """
    for result in unified_results:
        if result["query"] == selected_query:
            return (
                format_result_as_markdown(result["a_result"]),
                format_result_as_markdown(result["b_result"]),
            )
    return "Query not found", "Query not found"

# Dropdown choices
query_options = [r["query"] for r in unified_results]

# Build the Gradio interface
with gr.Blocks(title="Query Results Comparison", theme=gr.themes.Soft()) as interface:
    gr.Markdown("# Query Results Comparison")
    gr.Markdown("Select a query to compare results from **A** and **B**.")

    # Dropdown sits above the two result columns so it gets full width
    query_dropdown = gr.Dropdown(
        choices=query_options,
        label="Select Query",
        value=query_options[0] if query_options else None,
        interactive=True,
    )

    with gr.Row(equal_height=True):
        # Column A — rendered as Markdown for readable formatting
        with gr.Column(scale=1):
            gr.Markdown("## Result A")
            a_output = gr.Markdown(value="")

        # Column B — rendered as Markdown for readable formatting
        with gr.Column(scale=1):
            gr.Markdown("## Result B")
            b_output = gr.Markdown(value="")

    # Wire up the dropdown to update both columns
    query_dropdown.change(
        fn=display_results,
        inputs=query_dropdown,
        outputs=[a_output, b_output],
    )

    # Show the first query's results on load
    if query_options:
        interface.load(
            fn=lambda: display_results(query_options[0]),
            outputs=[a_output, b_output],
        )

interface.launch(share=False)

## Vector Weights

In [38]:
class RelevanceSize(Enum):
    NOT_RELEVANT = "not_relevant"
    SMALL = "small"
    MEDIUM = "medium"
    LARGE = "large"

class VectorWeightResponse(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    relevance: RelevanceSize
    justification: str

class VectorWeights(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    plot_events_data: VectorWeightResponse
    plot_analysis_data: VectorWeightResponse
    viewer_experience_data: VectorWeightResponse
    watch_context_data: VectorWeightResponse
    narrative_techniques_data: VectorWeightResponse
    production_data: VectorWeightResponse
    reception_data: VectorWeightResponse

In [39]:
# GENERATE VECTOR WEIGHTS

def generate_single_vector_weight(system_prompt: str, query: str):
    """
    Generic method for generating a single vector weight.
    Takes a system prompt and query, runs generate_kimi_response with VectorWeightResponse format.
    
    Args:
        system_prompt: The system prompt string for the LLM
        query: The user query string
        
    Returns:
        VectorWeightResponse from the LLM
    """
    return generate_kimi_response(
        user_prompt=f"User query: \"{query}\"",
        system_prompt=system_prompt,
        response_format=VectorWeightResponse
    )

def process_single_vector_weight_query(query: str) -> Tuple[str, Dict[str, Any]]:
    """
    Process a single query by running the generic method for each prompt in VECTOR_WEIGHT_PROMPTS in parallel.
    
    Args:
        query: The query string to process
        
    Returns:
        Tuple of (query, results_dict) where results_dict maps prompt_name -> VectorWeightResponse (or error string)
    """
    results_dict = {}
    # Run all prompts for this query in parallel
    with futures.ThreadPoolExecutor(max_workers=len(VECTOR_WEIGHT_PROMPTS)) as executor:
        future_to_prompt = {
            executor.submit(generate_single_vector_weight, prompt_str, query): prompt_name
            for prompt_name, prompt_str in VECTOR_WEIGHT_PROMPTS.items()
        }
        for future in futures.as_completed(future_to_prompt):
            prompt_name = future_to_prompt[future]
            try:
                results_dict[prompt_name] = future.result()
            except Exception as e:
                results_dict[prompt_name] = f"Error: {str(e)}"
    return query, results_dict

In [None]:
# Process all queries in parallel (each query runs its prompts in parallel internally)
results: List[Tuple[str, Dict[str, Any]]] = []

with futures.ThreadPoolExecutor(max_workers=7) as executor:
    # Submit process_single_query for each routing query
    future_to_query = {executor.submit(process_single_vector_weight_query, query): query for query in routing_queries}
    
    # Collect results as they complete, with progress bar
    for future in tqdm(futures.as_completed(future_to_query), total=len(routing_queries), desc="Processing queries"):
        query, results_dict = future.result()
        results.append((query, results_dict))

# Write results to CSV file
with open('../generated_data/weights_results.csv', 'w', newline='', encoding='utf-8') as csvfile:
    fieldnames = ['query', 'result']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()
    
    for query, results_dict in results:
        # Serialize each result (VectorWeightResponse.model_dump() or keep error string)
        serializable = {
            k: (v.model_dump() if hasattr(v, 'model_dump') else v)
            for k, v in results_dict.items()
        }
        writer.writerow({'query': query, 'result': json.dumps(serializable)})

print(f"Processed {len(results)} queries and saved results to weights_results.csv")

Processing queries: 100%|██████████| 31/31 [00:15<00:00,  2.03it/s]

Processed 31 queries and saved results to weights_results.csv





In [None]:
import pandas as pd
import json
import gradio as gr
from pathlib import Path

# Load the CSV file with query results
csv_path = Path("../generated_data/weights_results.csv")
df = pd.read_csv(csv_path)

def format_results_as_markdown(query: str) -> str:
    """
    Formats the query results as a nicely structured markdown string.
    Handles the JSON format: {category: {relevance, justification}} or {category: "Error: ..."}.
    
    Args:
        query: The selected query string
        
    Returns:
        Formatted markdown string with query, relevance scores, and justifications
    """
    # Find the row matching the selected query
    row = df[df['query'] == query]
    
    if row.empty:
        return "Query not found."
    
    # Parse the result string (stored as JSON)
    result_str = row.iloc[0]['result']
    try:
        result_dict = json.loads(result_str)
    except (json.JSONDecodeError, TypeError):
        return f"Error parsing results for query: {query}"
    
    # Build the markdown output
    markdown = f"## Query\n\n**{query}**\n\n"
    markdown += "## Relevance Scores\n\n"
    
    # Group categories by relevance level; each value is {relevance, justification} or error string
    relevance_levels = {
        'large': [],
        'medium': [],
        'small': [],
        'not_relevant': []
    }
    errors = []
    
    for category, value in result_dict.items():
        category_display = category.replace('_', ' ').title()
        if isinstance(value, str):
            # Error string from failed API call
            errors.append((category_display, value))
        elif isinstance(value, dict) and 'relevance' in value:
            relevance = value['relevance']
            justification = value.get('justification', '')
            if relevance in relevance_levels:
                relevance_levels[relevance].append((category_display, justification))
            else:
                errors.append((category_display, f"Unknown relevance: {relevance}"))
        else:
            errors.append((category_display, "Invalid result format"))
    
    # Display by relevance level (high to low)
    for level in ['large', 'medium', 'small', 'not_relevant']:
        items = relevance_levels[level]
        if items:
            level_display = level.replace('_', ' ').title()
            markdown += f"### {level_display}\n\n"
            for category_display, justification in items:
                markdown += f"- **{category_display}**\n"
                if justification:
                    markdown += f"  _{justification}_\n"
            markdown += "\n"
    
    # Show any errors at the end
    if errors:
        markdown += "### Errors\n\n"
        for category_display, error_msg in errors:
            markdown += f"- **{category_display}**: {error_msg}\n"
    
    return markdown

# Create the Gradio interface
def create_interface():
    """
    Creates and launches the Gradio interface for viewing query results.
    """
    # Get list of queries for the dropdown
    queries = df['query'].tolist()
    
    # Create the interface
    interface = gr.Interface(
        fn=format_results_as_markdown,
        inputs=gr.Dropdown(
            choices=queries,
            label="Select Query",
            value=queries[0] if queries else None
        ),
        outputs=gr.Markdown(label="Results"),
        title="Query Understanding Results Viewer",
        description="Select a query to view its relevance scores for different vector categories."
    )
    
    return interface

# Launch the interface
iface = create_interface()
iface.launch()

* Running on local URL:  http://127.0.0.1:7861

To create a public link, set `share=True` in `launch()`.




## Channel Weights

In [40]:
class ChannelWeights(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    lexical_relevance: RelevanceSize
    metadata_relevance: RelevanceSize
    vector_relevance: RelevanceSize

In [41]:
CHANNEL_WEIGHTS_SYSTEM_PROMPT = """\
You are an expert at understanding search query intentions. You are an intent-to-weight router for a movie search system.

TASK
Given a single user search query, estimate the RELATIVE importance of three search channels:
1) Lexical search (specific entities to find)
2) Metadata preferences (concrete attributes to filter on. Only attributes from the list below.)
3) Vector search (semantic intent to match)

INPUT
- The user's full movie search query (as typed into a search bar).

OUTPUT (STRICT)
JSON with these keys:
- "lexical_relevance"
- "metadata_relevance"
- "vector_relevance"

Value rules:
- lexical_relevance / metadata_relevance / vector_relevance must be one of:
  "not_relevant", "small", "medium", "large"

RELEVANCY DEFINITIONS
- "not_relevant": The query has absolutely no intent relevant to what this channel searches for.
- "small": A small portion of the query's intent / search features are relevant to what this channel searches for.
- "medium": A moderate portion of the query's intent / search features are relevant to what this channel searches for.
- "large": Nearly all of the query's intent / search features are relevant to what this channel searches for.

HOW TO THINK (HIGH LEVEL)
- Each query is looking for one or more distinct features for their movie, with each one applying to one or more channels.
- Overall a specific channel's relevance is what percentage of these distinct features are searched within this channel.

CHANNEL DEFINITIONS

A) Lexical search (lexical_relevance)
The user is explicitly searching for one of the following:
- character names
- franchises / series names
- real-world people (actors, directors, writers, composers, etc.)
- real-world studios / production companies
- movie titles

Rules:
- If the user likely misspelled a name/title but intent is clearly an entity, count it as lexical.
- If a phrase could be either an entity (e.g., title) OR a descriptive phrase, count it as lexical AND also count it for whichever other channel(s) it fits.
- There are no lexical entities beyond characters, franchises, people, studios, and movie titles. Do not make up new categories.

B) Metadata preferences (metadata_relevance)
ONLY the following attributes:
- release date / decade / year
- duration / runtime
- genres
- audio languages
- streaming platforms
- maturity rating
- trending status (binary true or false)
- popularity status (binary true or false)
- reception level ONLY when explicitly framed as “good/bad” (e.g., “acclaimed”, “bad reviews”, “overrated”)

Rules:
- Some metadata attributes overlap with semantics (ex. genre). In that case count it towards both channels.
- The attributes listed above are the only pieces of metadata we use. Only increase metadata_relevance if the query has parts that match these exact attributes.
- Never add new metadata attributes beyond what I've listed above.

C) Vector search (vector_relevance)
Use this for semantic intent that is not purely a lexical entity or a structured metadata preference, including:
- plot/story content (what happens, setting in-story, character motivations)
- themes, arcs, generalized “what it's about”
- viewer experience (tone, tension, intensity, disturbance, etc.)
- watch context (why/when to watch, scenarios, motivations) if present in the query
- storytelling techniques (unreliable narrator, nonlinear timeline, twist ending, etc.)
- any ambiguous phrases that could plausibly be semantic descriptors

Rules:
- Vector search covers all movie attributes so it should always be included. It's weight increases the more the query asks for "vibes" or attributes that are hard to evaluate concretely (ex. "Has jumpscares")
- Just because a part of the query applies to one channel doesn't mean it can't also apply to this one.

CONSTRAINTS
- Base your judgment ONLY on the raw query text.
- My lists are gospel. If a part of the query doesn't match the description I've provided for a given channel, it's not relevant to this channel.
- Do not assume access to any other extraction models or filters.
- Do not output absolute numeric weights—only the allowed T-shirt sizes.
- Double check: are you using metadata attributes not explicitly stated in my list above? If so, remove them.\
"""

In [42]:
def process_channel_weights_query(query: str):
    """Process a single query through generate_kimi_response for channel weights."""
    response = generate_kimi_response(
        user_prompt=f"User query: \"{query}\"",
        system_prompt=CHANNEL_WEIGHTS_SYSTEM_PROMPT,
        response_format=ChannelWeights,
    )
    return (query, response)

In [None]:
results: List[Tuple[str, Dict[str, Any]]] = []

with futures.ThreadPoolExecutor(max_workers=10) as executor:
    # Submit process_single_query for each routing query
    future_to_query = {executor.submit(process_channel_weights_query, query): query for query in routing_queries}
    
    # Collect results as they complete, with progress bar
    for future in tqdm(futures.as_completed(future_to_query), total=len(routing_queries), desc="Processing queries"):
        query, results_dict = future.result()
        results.append((query, results_dict))

print(f"Generated {len(results)} channel weights responses")

# Write results to CSV: query plus each response's relevance fields
with open('../generated_data/channel_weights.csv', 'w', newline='', encoding='utf-8') as csvfile:
    fieldnames = ['query', 'result']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    
    # Write header row
    writer.writeheader()
    
    # Write data rows
    for query, results_dict in results:
        row = {
            'query': query,
            'result': results_dict.model_dump_json()
        }
        
        writer.writerow(row)
    
print(f"Saved to {csv_path.resolve()}")

Processing queries: 100%|██████████| 31/31 [00:06<00:00,  5.03it/s]

Generated 31 channel weights responses





AttributeError: 'ChannelWeights' object has no attribute 'items'

# Putting it all together

In [None]:
def get_channel_weights(query):
    """Extract channel weights (lexical, metadata, vector relevance) from query."""
    return process_channel_weights_query(query)

def get_lexical_entities(query):
    """Extract lexical entities (characters, franchises, people, studios, titles) from query."""
    return generate_openai_response(
        user_prompt=f"User query: \"{query}\"",
        system_prompt=EXTRACT_LEXICAL_ENTITIES_SYSTEM_PROMPT,
        response_format=ExtractedEntities,
    )

def get_metadata_preferences(query):
    """Extract metadata preferences (date, duration, genres, etc.) from query."""
    return generate_kimi_response(
        user_prompt=f"User query: \"{query}\"",
        system_prompt=EXTRACT_METADATA_PREFERENCES_SYSTEM_PROMPT,
        response_format=MetadataPreferences,
    )

def _process_single_vector_query(query: str, system_prompt: str, key: str) -> tuple[str, VectorCollectionQueryData]:
    """
    Process a single vector routing query.
    
    Args:
        query: The user's search query
        system_prompt: The system prompt for this vector collection
        key: The key name for this vector collection
        
    Returns:
        Tuple of (key, VectorCollectionQueryData response)
    """
    try:
        response = generate_kimi_response(
            user_prompt=f"User query: \"{query}\"",
            system_prompt=system_prompt,
            response_format=VectorCollectionQueryData
        )
        return key, response
    except Exception as e:
        print(f"Error processing {key}: {e}")
        return key, None

def _process_single_vector_weight(prompt_name: str, prompt_str: str, query: str) -> tuple[str, Any]:
    """
    Process a single vector weight query.
    
    Returns:
        Tuple of (prompt_name, VectorWeightResponse or error string)
    """
    try:
        response = generate_single_vector_weight(prompt_str, query)
        return prompt_name, response
    except Exception as e:
        return prompt_name, f"Error: {str(e)}"

def get_query_understanding(query: str) -> Dict[str, Any]:
    """
    Extract all query understanding components in parallel.
    
    Runs 17 LLM calls in a single flat executor (no nested parallelization):
    - 1 lexical entity extraction
    - 1 metadata preferences
    - 1 channel weights
    - 7 vector subqueries (VECTOR_QUERY_PROMPTS style)
    - 7 vector weights (VECTOR_WEIGHT_PROMPTS)
    
    Args:
        query: The user's movie search query
        
    Returns:
        Dict with channel_weights, lexical_entities, metadata_preferences,
        vector_routing, vector_weights
    """
    # Vector subqueries: same structure as get_vector_queries (VECTOR_QUERY_PROMPTS)
    vector_subquery_tasks = [
        (VECTOR_QUERY_PROMPTS["plot_events"], "plot_events_data"),
        (VECTOR_QUERY_PROMPTS["plot_analysis"], "plot_analysis_data"),
        (VECTOR_QUERY_PROMPTS["viewer_experience"], "viewer_experience_data"),
        (VECTOR_QUERY_PROMPTS["watch_context"], "watch_context_data"),
        (VECTOR_QUERY_PROMPTS["narrative_techniques"], "narrative_techniques_data"),
        (VECTOR_QUERY_PROMPTS["production"], "production_data"),
        (VECTOR_QUERY_PROMPTS["reception"], "reception_data"),
    ]
    
    # Vector weight tasks: one per prompt in VECTOR_WEIGHT_PROMPTS
    vector_weight_tasks = [
        (prompt_name, prompt_str)
        for prompt_name, prompt_str in VECTOR_WEIGHT_PROMPTS.items()
    ]
    
    with futures.ThreadPoolExecutor(max_workers=17) as executor:
        # Submit all 17 tasks
        future_channel = executor.submit(get_channel_weights, query)
        future_lexical = executor.submit(get_lexical_entities, query)
        future_metadata = executor.submit(get_metadata_preferences, query)
        
        vector_subquery_futures = {
            executor.submit(_process_single_vector_query, query, system_prompt, key): key
            for system_prompt, key in vector_subquery_tasks
        }
        
        vector_weight_futures = {
            executor.submit(_process_single_vector_weight, prompt_name, prompt_str, query): prompt_name
            for prompt_name, prompt_str in vector_weight_tasks
        }
        
        # Collect vector subquery results
        vector_results = {}
        for future in futures.as_completed(vector_subquery_futures):
            key, response = future.result()
            if response:
                vector_results[key] = response
        
        # Collect vector weight results
        vector_weight_results = {}
        for future in futures.as_completed(vector_weight_futures):
            prompt_name, response = future.result()
            vector_weight_results[prompt_name] = response
        
        vector_routing = VectorRoutingResponse(
            plot_events_data=vector_results.get("plot_events_data"),
            plot_analysis_data=vector_results.get("plot_analysis_data"),
            viewer_experience_data=vector_results.get("viewer_experience_data"),
            watch_context_data=vector_results.get("watch_context_data"),
            narrative_techniques_data=vector_results.get("narrative_techniques_data"),
            production_data=vector_results.get("production_data"),
            reception_data=vector_results.get("reception_data")
        )
        
        return {
            "channel_weights": future_channel.result(),
            "lexical_entities": future_lexical.result(),
            "metadata_preferences": future_metadata.result(),
            "vector_routing": vector_routing,
            "vector_weights": vector_weight_results,
        }

def _serialize_for_json(obj: Any) -> Any:
    """Convert Pydantic models and nested structures to JSON-serializable dicts."""
    if hasattr(obj, "model_dump"):
        return obj.model_dump(mode="json")
    if isinstance(obj, dict):
        return {k: _serialize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_serialize_for_json(v) for v in obj]
    return obj

# Process each routing query and save to overall_results.csv
overall_results = []
for query in tqdm(routing_queries, desc="Processing routing queries"):
    result = get_query_understanding(query)
    # Build JSON-serializable dict of all fetched attributes
    serialized = {
        "channel_weights": _serialize_for_json(result["channel_weights"]),
        "lexical_entities": _serialize_for_json(result["lexical_entities"]),
        "metadata_preferences": _serialize_for_json(result["metadata_preferences"]),
        "vector_routing": _serialize_for_json(result["vector_routing"]),
        "vector_weights": {
            k: _serialize_for_json(v) if hasattr(v, "model_dump") else v
            for k, v in result["vector_weights"].items()
        },
    }
    overall_results.append({"query": query, "results": json.dumps(serialized)})

csv_path = Path("../generated_data/overall_results.csv")
with open(csv_path, "w", newline="", encoding="utf-8") as f:
    writer = csv.DictWriter(f, fieldnames=["query", "results"])
    writer.writeheader()
    writer.writerows(overall_results)

print(f"Saved {len(overall_results)} results to {csv_path.resolve()}")

Processing routing queries: 100%|██████████| 31/31 [03:39<00:00,  7.07s/it]

Saved 31 results to /Users/michaelkeohane/Documents/movie-finder-rag/implementation/notebooks/overall_results.csv





In [None]:
# Gradio interface for overall_results.csv

import pandas as pd
import json
import gradio as gr
from pathlib import Path


def _format_value(val, indent=0):
    """Recursively format a value for display (handles dicts, lists, None)."""
    if val is None:
        return "_none_"
    if isinstance(val, bool):
        return str(val)
    if isinstance(val, (int, float)):
        return str(val)
    if isinstance(val, str):
        return val
    if isinstance(val, list):
        if not val:
            return "[]"
        items = [_format_value(v, indent + 1) for v in val]
        return "\n" + "  " * (indent + 1) + ("\n" + "  " * (indent + 1)).join(f"- {x}" for x in items)
    if isinstance(val, dict):
        lines = []
        for k, v in val.items():
            if v is None or v == "" or v == [] or v == {}:
                continue
            lines.append(f"**{k.replace('_', ' ').title()}:** {_format_value(v, indent + 1)}")
        return "\n" + "  " * (indent + 1) + ("\n" + "  " * (indent + 1)).join(lines)
    return str(val)


def _format_metadata_preferences(mp: dict) -> str:
    """
    Format metadata preferences as a bulleted list grouped by high-level keys.
    Each top-level key (e.g., release_date_preference) becomes a bullet group
    with its sub-items as nested bullets.
    """
    lines = []
    for key, val in mp.items():
        label = key.replace("_", " ").title()
        if val is None:
            continue
        if isinstance(val, dict):
            sub_items = []
            for k, v in val.items():
                sub_label = k.replace("_", " ").title()
                if isinstance(v, list):
                    sub_val = ", ".join(str(x) for x in v) if v else "(none)"
                    sub_items.append(f"- **{sub_label}:** {sub_val}")
                elif v is None:
                    sub_items.append(f"- **{sub_label}:** (none)")
                else:
                    sub_items.append(f"- **{sub_label}:** {v}")
            if sub_items:
                lines.append(f"- **{label}**")
                lines.extend(f"  {s}" for s in sub_items)
        elif isinstance(val, list):
            lines.append(f"- **{label}:** {', '.join(str(x) for x in val) if val else '(none)'}")
        else:
            lines.append(f"- **{label}:** {val}")
    return "\n".join(lines) if lines else "_No metadata preferences._"


def _escape_html(text: str) -> str:
    """Escape HTML special chars so content doesn't break tags."""
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;").replace('"', "&quot;")

def _collapsible_section(title: str, content: str, open_by_default: bool = False) -> str:
    """Build HTML collapsible section for justifications."""
    open_attr = " open" if open_by_default else ""
    safe_content = _escape_html(content)
    return f'''
<details{open_attr}>
<summary><strong>{title}</strong></summary>
<p style="margin: 0.5em 0; line-height: 1.5;">{safe_content}</p>
</details>
'''


def format_overall_results(query: str, df: pd.DataFrame) -> str:
    """
    Format the query results from overall_results.csv for display.
    Justifications are placed in collapsible <details> sections.
    
    Args:
        query: The selected query string
        df: DataFrame loaded from overall_results.csv
        
    Returns:
        HTML/Markdown string with formatted results
    """
    row = df[df["query"] == query]
    if row.empty:
        return "Query not found."
    
    result_str = row.iloc[0]["results"]
    try:
        data = json.loads(result_str)
    except (json.JSONDecodeError, TypeError):
        return f"Error parsing results for query: {query}"
    
    parts = [f"# Query\n\n> **{query}**\n"]
    
    # Channel weights
    cw = data.get("channel_weights")
    if cw is not None:
        # Handle format: [query, dict] or just dict
        weights = cw[1] if isinstance(cw, list) and len(cw) > 1 else (cw if isinstance(cw, dict) else {})
        if weights:
            parts.append("## Channel Weights\n")
            parts.append(f"- **Lexical relevance:** {weights.get('lexical_relevance', 'N/A')}")
            parts.append(f"- **Metadata relevance:** {weights.get('metadata_relevance', 'N/A')}")
            parts.append(f"- **Vector relevance:** {weights.get('vector_relevance', 'N/A')}\n")
    
    # Lexical entities
    le = data.get("lexical_entities", {})
    entity_candidates = le.get("entity_candidates", [])
    parts.append("## Lexical Entities\n")
    if entity_candidates:
        for e in entity_candidates:
            parts.append(f"- **{e.get('corrected_and_normalized_entity', 'N/A')}** ({e.get('most_likely_category', '')})")
    else:
        parts.append("_No entities extracted._\n")
    
    # Metadata preferences (bulleted list grouped by high-level keys)
    mp = data.get("metadata_preferences", {})
    if mp:
        parts.append("## Metadata Preferences\n")
        parts.append(_format_metadata_preferences(mp) + "\n")
    
    # Vectors: combined weight + subquery + justifications per collection
    vr = data.get("vector_routing", {}) or {}
    vw = data.get("vector_weights", {}) or {}
    vector_collections = [
        ("plot_events_data", "plot_events", "Plot Events"),
        ("plot_analysis_data", "plot_analysis", "Plot Analysis"),
        ("viewer_experience_data", "viewer_experience", "Viewer Experience"),
        ("watch_context_data", "watch_context", "Watch Context"),
        ("narrative_techniques_data", "narrative_techniques", "Narrative Techniques"),
        ("production_data", "production", "Production"),
        ("reception_data", "reception", "Reception"),
    ]
    if vr or vw:
        parts.append("## Vectors\n")
        for routing_key, weight_key, label in vector_collections:
            routing_val = vr.get(routing_key)
            weight_val = vw.get(weight_key)
            if routing_val is None and weight_val is None:
                continue
            parts.append(f"### {label}\n")
            # Weight (relevance)
            if isinstance(weight_val, str):
                parts.append(f"**Weight:** _Error: {weight_val}_\n")
            elif isinstance(weight_val, dict):
                relevance = weight_val.get("relevance", "N/A")
                weight_just = weight_val.get("justification", "")
                parts.append(f"**Weight:** {relevance}\n")
                if weight_just:
                    parts.append(_collapsible_section("Weight justification", weight_just))
            else:
                parts.append("**Weight:** N/A\n")
            # Subquery
            if isinstance(routing_val, dict):
                subquery = routing_val.get("relevant_subquery_text") or "_none_"
                subquery_just = routing_val.get("justification", "")
                parts.append(f"**Subquery:** `{subquery}`\n")
                if subquery_just:
                    parts.append(_collapsible_section("Subquery justification", subquery_just))
            elif routing_val is None and isinstance(weight_val, dict):
                parts.append("**Subquery:** _none_\n")
            parts.append("\n")
    
    return "\n".join(parts)


# Load overall_results.csv
csv_path = Path("../generated_data/overall_results.csv")
if not csv_path.exists():
    raise FileNotFoundError(f"overall_results.csv not found at {csv_path.resolve()}")

df_overall = pd.read_csv(csv_path)
query_choices = df_overall["query"].tolist()

# Build Gradio interface
with gr.Blocks(title="Query Understanding Results", theme=gr.themes.Soft()) as overall_interface:
    gr.Markdown("# Query Understanding Results Viewer")
    gr.Markdown("Select a query to view its full extraction results. Justifications are in collapsible sections.")
    
    query_dropdown = gr.Dropdown(
        choices=query_choices,
        value=query_choices[0] if query_choices else None,
        label="Select Query",
        allow_custom_value=False,
    )
    
    results_output = gr.Markdown(
        value=format_overall_results(query_choices[0], df_overall) if query_choices else "No data loaded."
    )
    
    def on_query_change(query):
        if not query:
            return "Select a query."
        return format_overall_results(query, df_overall)
    
    query_dropdown.change(fn=on_query_change, inputs=[query_dropdown], outputs=[results_output])

# Launch (use share=False for local only)
overall_interface.launch()

* Running on local URL:  http://127.0.0.1:7864

To create a public link, set `share=True` in `launch()`.


