# Topic Router

This notebook assigns topics to reviews using LLM-powered multi-label classification.


In [1]:
import sys
sys.path.append('../')

import polars as pl
import json
import os
from pathlib import Path
from datetime import datetime, timedelta, timezone
from utils.llm_client import LLMClient
from tqdm import tqdm
import hashlib

# Set up paths
DATA_DIR = Path("../data")
REVIEWS_FILE = DATA_DIR / "reviews_clean.parquet"
REGISTRY_FILE = Path("../registry/topic_registry.json")
OUTPUT_FILE = DATA_DIR / "labels_initial.parquet"

IST_TZ = timezone(timedelta(hours=5, minutes=30))
START_DATE = datetime(2024, 6, 1, tzinfo=IST_TZ)
TARGET_DATE = datetime.now(IST_TZ).date()
ROLLING_WINDOW_DAYS = 30
DAILY_REVIEWS_DIR = DATA_DIR / "daily_batches"
DAILY_LABELS_DIR = DATA_DIR / "daily_labels"

for path in [DATA_DIR, DAILY_REVIEWS_DIR, DAILY_LABELS_DIR]:
    path.mkdir(exist_ok=True)


print("‚úì Setup complete")


‚úì Setup complete


## Load Data and Registry


In [2]:
# Load reviews
reviews_df = pl.read_parquet(REVIEWS_FILE)
if 'created_at' not in reviews_df.columns:
    raise ValueError('Expected created_at column in reviews parquet.')
print(f"‚úì Loaded {len(reviews_df):,} reviews")

reviews_df = reviews_df.with_columns([
    pl.col('created_at').dt.convert_time_zone('Asia/Kolkata').alias('created_at'),
    pl.col('created_at').dt.date().alias('dt')
])


‚úì Loaded 225,918 reviews


## Initialize LLM Client


In [3]:
# Choose provider: 'openai' or 'ollama'
PROVIDER = 'ollama'  # Default to local LLM (switch to 'openai' if API available)
MODEL = 'qwen3:8b'  # Change to another Ollama tag or OpenAI model as needed

llm = LLMClient(provider=PROVIDER, model=MODEL)

print(f"‚úì Initialized LLM client: {PROVIDER} with model {MODEL}")


‚úì Initialized ollama client with model qwen3:8b
‚úì Initialized LLM client: ollama with model qwen3:8b


## Load Topic Registry

In [4]:
# Load topic registry
with open(REGISTRY_FILE) as f:
    registry = json.load(f)

registry_topics = registry.get("topics", [])
if not registry_topics:
    raise ValueError(f"No topics loaded from registry: {REGISTRY_FILE}")

topic_lookup = {topic["id"]: topic for topic in registry_topics}
print(f"‚úì Loaded {len(registry_topics)} topics from registry")

‚úì Loaded 32 topics from registry


## Routing Helpers

In [5]:
from textwrap import dedent
from typing import Any, Dict, List
import time

ROUTING_SYSTEM_PROMPT = (
    "You are a high-recall topic routing assistant for Swiggy reviews. "
    "Use the provided topic catalog to assign every relevant topic id. "
    "Return JSON with keys topic_ids (list[str]), is_novel (bool), novel (object or null). "
    "Prefer recall but avoid assigning unrelated topics."
)

MAX_TOPICS_PER_REVIEW = 4
ROUTING_TEMPERATURE = 0.2
MODEL_COST_USD_PER_1K = {
    "gpt-3.5-turbo": {"input": 0.0005, "output": 0.0015},
    "gpt-4o-mini": {"input": 0.0006, "output": 0.0024},
    "gpt-4o": {"input": 0.0025, "output": 0.01},
}
DEFAULT_COST = {"input": 0.0010, "output": 0.0020}
MAX_COST_USD = float(os.getenv('ROUTING_COST_BUDGET_USD', '12'))
OUTPUT_TOKEN_ESTIMATE = 180  # heuristic budget per call
PROMPT_BACKOFF = 1.5

def build_topic_catalog(topics: List[Dict[str, Any]]) -> str:
    lines = []
    for topic in topics:
        examples = topic.get('positive_examples') or []
        example_snippet = '; '.join(examples[:2]) if examples else ''
        definition = topic.get('definition', '')
        if example_snippet:
            lines.append(f"{topic['id']} :: {topic['name']} ‚Äî {definition} (e.g. {example_snippet})")
        else:
            lines.append(f"{topic['id']} :: {topic['name']} ‚Äî {definition}")
    return "\n".join(lines)


def estimate_tokens(text: str) -> int:
    if not text:
        return 0
    return max(1, int(len(text) / 4))

TOPIC_CATALOG = build_topic_catalog(registry_topics)
PROMPT_OVERHEAD_TOKENS = estimate_tokens(TOPIC_CATALOG) + 150

def estimate_cost(input_tokens: int, output_tokens: int, model: str) -> float:
    rates = MODEL_COST_USD_PER_1K.get(model, DEFAULT_COST)
    return (input_tokens * rates['input'] + output_tokens * rates['output']) / 1000.0

def build_user_prompt(review_text: str) -> str:
    return dedent(
        f"""
        Review text:
        \"\"\"{review_text.strip()}\"\"\"

    Topic catalog:
    {TOPIC_CATALOG}

    Instructions:
    - Reply strictly in JSON with keys:
      - topic_ids: up to {MAX_TOPICS_PER_REVIEW} topic IDs from the catalog that apply.
      - is_novel: true if the review exposes a new issue not covered in the catalog.
      - novel: when is_novel is true, include an object with keys label (<=5 words) and rationale.
    - Capture every relevant topic even for positive sentiment.
    - Prefer existing topics when the description is close to a catalog entry.
    """).strip()

def validate_topic_ids(topic_ids: Any) -> List[str]:
    cleaned: List[str] = []
    if isinstance(topic_ids, str):
        topic_ids = [topic_ids]
    if not isinstance(topic_ids, list):
        return cleaned
    for topic_id in topic_ids:
        if not isinstance(topic_id, str):
            continue
        normalized = topic_id.strip().upper()
        if normalized in topic_lookup and normalized not in cleaned:
            cleaned.append(normalized)
    return cleaned

def route_review(review_text: str, llm_client: LLMClient, max_retries: int = 3, retry_backoff: float = PROMPT_BACKOFF) -> Dict[str, Any]:
    if not review_text or not review_text.strip():
        return {
            'topic_ids': [],
            'is_novel': False,
            'novel': None,
            'input_tokens_est': 0,
            'output_tokens_est': OUTPUT_TOKEN_ESTIMATE,
            'routing_error': 'empty_review',
        }

    input_estimate = estimate_tokens(review_text) + PROMPT_OVERHEAD_TOKENS
    attempt = 0
    last_error = None

    while attempt < max_retries:
        try:
            response = llm_client.complete(
                system_prompt=ROUTING_SYSTEM_PROMPT,
                user_prompt=build_user_prompt(review_text),
                temperature=ROUTING_TEMPERATURE,
                response_format='json',
                use_cache=True,
            ) or {}
            topic_ids = response.get('topic_ids') or response.get('topics')
            cleaned_topics = validate_topic_ids(topic_ids)
            novel_payload = response.get('novel') if isinstance(response.get('novel'), dict) else None
            is_novel = bool(response.get('is_novel') and novel_payload)
            return {
                'topic_ids': cleaned_topics,
                'is_novel': is_novel,
                'novel': novel_payload,
                'input_tokens_est': input_estimate,
                'output_tokens_est': OUTPUT_TOKEN_ESTIMATE,
                'routing_error': None,
            }
        except Exception as exc:
            attempt += 1
            last_error = repr(exc)
            wait_seconds = min(8, retry_backoff ** attempt)
            print(f"  ‚ö†Ô∏è Routing failed (attempt {attempt}/{max_retries}): {exc}")
            time.sleep(wait_seconds)

    return {
        'topic_ids': [],
        'is_novel': False,
        'novel': None,
        'input_tokens_est': input_estimate,
        'output_tokens_est': OUTPUT_TOKEN_ESTIMATE,
        'routing_error': last_error or 'unknown_error',
    }

def route_batch(rows: List[Dict[str, Any]], llm_client: LLMClient, batch_desc: str, cache: Dict[str, Dict[str, Any]]) -> List[Dict[str, Any]]:
    results: List[Dict[str, Any]] = []
    for row in tqdm(rows, desc=f'Routing {batch_desc}', leave=False):
        review_text = row['content_raw']
        text_hash = hashlib.sha256(review_text.encode('utf-8')).hexdigest()
        if text_hash in cache:
            cached_result = cache[text_hash]
            routed = cached_result.copy()
            routed['from_cache'] = True
        else:
            routed = route_review(review_text, llm_client)
            cache[text_hash] = routed.copy()
            routed['from_cache'] = False
        results.append({**row, **routed})
    return results

## Create Router Function


## Batch Process Reviews

Process all reviews in batches with caching


In [None]:
# Configure daily processing window
target_date = TARGET_DATE
if target_date < START_DATE.date():
    raise ValueError('Target date precedes June 1 2024. Adjust TARGET_DATE.')
window_start = target_date - timedelta(days=ROLLING_WINDOW_DAYS)
print(f'Processing window: {window_start} to {target_date}')

# Ensure date column exists
if 'dt' not in reviews_df.columns:
    reviews_df = reviews_df.with_columns([
        pl.col('created_at').dt.convert_time_zone('Asia/Kolkata').dt.date().alias('dt')
    ])

unique_dates = sorted(set(reviews_df['dt'].to_list()))
if not unique_dates:
    raise ValueError('No review dates available for routing.')
print(f'Total candidate days: {len(unique_dates)}')

all_labels = []
cache: Dict[str, Dict[str, Any]] = {}
total_review_count = 0
total_assignments = 0
total_novel = 0
total_input_tokens = 0
total_output_tokens = 0
total_cost_est = 0.0
routing_errors = 0

for batch_date in unique_dates:
    if batch_date < window_start or batch_date > target_date:
        continue
    day_reviews = reviews_df.filter(pl.col('dt') == batch_date).sort('created_at')
    day_key = batch_date.isoformat()
    print(f'üìÖ Processing {day_key}: {len(day_reviews)} reviews')

    if len(day_reviews) == 0:
        print('  ‚Üí No reviews for this day.')
        continue

    day_reviews_path = DAILY_REVIEWS_DIR / f'reviews_{day_key}.parquet'
    day_reviews.write_parquet(day_reviews_path)

    day_rows = [
        {
            'review_id': row['review_id'],
            'content_raw': row['content_raw'],
            'created_at': row['created_at'],
            'dt': row['dt'],
        }
        for row in day_reviews.iter_rows(named=True)
    ]

    batch_results = route_batch(day_rows, llm, batch_desc=day_key, cache=cache)

    day_labels = []
    day_assignments = 0
    day_novel = 0
    day_errors = 0
    day_cost_est = 0.0

    for routed in batch_results:
        total_review_count += 1

        input_tokens = routed.get('input_tokens_est', 0)
        output_tokens = routed.get('output_tokens_est', OUTPUT_TOKEN_ESTIMATE)
        from_cache = routed.get('from_cache', False)
        call_cost = 0.0
        if not from_cache:
            call_cost = estimate_cost(input_tokens, output_tokens, MODEL)
            total_input_tokens += input_tokens
            total_output_tokens += output_tokens
            total_cost_est += call_cost
            day_cost_est += call_cost

            if total_cost_est > MAX_COST_USD:
                raise RuntimeError(
                    f"Estimated routing cost ${total_cost_est:.2f} exceeds budget ${MAX_COST_USD:.2f}. "
                    'Set ROUTING_COST_BUDGET_USD to raise the limit or reduce the date range.'
                )

        topic_ids = routed.get('topic_ids', [])
        if topic_ids:
            for topic_id in topic_ids:
                day_labels.append({
                    'review_id': routed['review_id'],
                    'topic_id': topic_id,
                    'is_novel': False,
                    'novel_label': None,
                    'novel_rationale': None,
                    'created_at': routed['created_at'],
                    'dt': routed['dt'],
                })
                day_assignments += 1
                total_assignments += 1

        if routed.get('is_novel') and isinstance(routed.get('novel'), dict):
            novel = routed['novel']
            day_labels.append({
                'review_id': routed['review_id'],
                'topic_id': 'NOVEL',
                'is_novel': True,
                'novel_label': novel.get('label'),
                'novel_rationale': novel.get('rationale'),
                'created_at': routed['created_at'],
                'dt': routed['dt'],
            })
            day_novel += 1
            total_novel += 1

        if routed.get('routing_error'):
            routing_errors += 1
            day_errors += 1

    print(f'  ‚Üí Routed {len(batch_results)} reviews | assignments: {day_assignments} | novel: {day_novel} | errors: {day_errors} | est cost: ${day_cost_est:.2f}')

    if day_labels:
        day_labels_df = pl.DataFrame(day_labels)
        day_labels_path = DAILY_LABELS_DIR / f'labels_{day_key}.parquet'
        day_labels_df.write_parquet(day_labels_path)
        print(f'    Saved {len(day_labels)} label rows to {day_labels_path.name}')
        all_labels.extend(day_labels)
    else:
        print('    No topics detected for this day.')

print('=== Routing Summary ===')
print(f'Total reviews routed: {total_review_count}')
print(f'Total topic assignments: {total_assignments}')
print(f'Novel reviews flagged: {total_novel}')
print(f'Estimated tokens (input/output): {total_input_tokens}/{total_output_tokens}')
print(f'Estimated cost: ${total_cost_est:.2f} (budget ${MAX_COST_USD:.2f})')
if total_cost_est > MAX_COST_USD * 0.9:
    print('‚ö†Ô∏è Estimated cost is approaching the configured budget.')
if routing_errors:
    print(f'‚ö†Ô∏è {routing_errors} reviews encountered routing errors. Consider rerunning with higher retry count.')
else:
    print('‚úì No routing errors detected.')

Processing window: 2025-09-28 to 2025-10-28
Total candidate days: 513
üìÖ Processing 2025-09-28: 776 reviews


Routing 2025-09-28:   0%|                               | 0/776 [00:00<?, ?it/s]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:   1%|                   | 4/776 [14:30<37:36:21, 175.37s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:   1%|‚ñè                  | 6/776 [23:34<45:50:41, 214.34s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 3/3): timed out


Routing 2025-09-28:   2%|‚ñé                 | 14/776 [52:57<34:58:03, 165.20s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:   9%|‚ñà‚ñå               | 72/776 [2:27:47<15:09:34, 77.52s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  10%|‚ñà‚ñå              | 78/776 [2:42:56<21:49:31, 112.57s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  11%|‚ñà‚ñã              | 83/776 [2:58:35<29:05:17, 151.11s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 3/3): timed out


Routing 2025-09-28:  13%|‚ñà‚ñà              | 103/776 [3:42:43<13:17:44, 71.12s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  15%|‚ñà‚ñà‚ñç             | 116/776 [4:08:58<14:28:47, 78.98s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  16%|‚ñà‚ñà‚ñç             | 121/776 [4:18:41<17:01:43, 93.59s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 3/3): timed out


Routing 2025-09-28:  16%|‚ñà‚ñà‚ñç            | 124/776 [4:36:25<31:50:33, 175.82s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 3/3): timed out


Routing 2025-09-28:  16%|‚ñà‚ñà‚ñç            | 126/776 [4:51:33<46:58:17, 260.15s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  22%|‚ñà‚ñà‚ñà‚ñé           | 170/776 [6:05:22<27:18:50, 162.26s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  23%|‚ñà‚ñà‚ñà‚ñã            | 181/776 [6:26:16<15:37:56, 94.58s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 3/3): timed out


Routing 2025-09-28:  26%|‚ñà‚ñà‚ñà‚ñâ           | 203/776 [7:21:06<22:46:57, 143.14s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  28%|‚ñà‚ñà‚ñà‚ñà‚ñè          | 215/776 [7:46:37<17:26:50, 111.96s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  30%|‚ñà‚ñà‚ñà‚ñà‚ñå          | 235/776 [8:25:54<19:20:43, 128.73s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  34%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã           | 262/776 [9:19:57<9:42:44, 68.02s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  34%|‚ñà‚ñà‚ñà‚ñà‚ñà          | 265/776 [9:30:58<16:51:46, 118.80s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  39%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè         | 302/776 [10:32:28<9:39:32, 73.36s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  39%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå        | 305/776 [10:44:06<17:05:27, 130.63s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  40%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå        | 310/776 [11:04:17<23:31:39, 181.76s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  41%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã        | 315/776 [11:18:01<19:02:17, 148.67s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out
  ‚ö†Ô∏è Routing failed (attempt 2/3): timed out


Routing 2025-09-28:  41%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã        | 317/776 [11:32:12<30:59:02, 243.01s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  43%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ        | 332/776 [12:00:24<12:55:15, 104.77s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


Routing 2025-09-28:  47%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå       | 363/776 [13:01:14<14:48:54, 129.14s/it]

  ‚ö†Ô∏è Routing failed (attempt 1/3): timed out


## Save Results and Show Distribution


In [None]:
if not all_labels:
    raise ValueError('No routing labels generated. Run the routing step before saving results.')

labels_df = pl.DataFrame(all_labels)

# Save to Parquet
labels_df.write_parquet(OUTPUT_FILE)
print(f"‚úì Saved labels to {OUTPUT_FILE} ({len(labels_df)} rows)")

# Show distribution
print("üìä Topic Distribution:")
print(labels_df.group_by('topic_id').agg(pl.len().alias('count')).sort('count', descending=True))