## Plot: A company launches an AI chat agent, served through a REST API, that is good at writing code.
  
### Users seem to love it and use it consistently. The company becomes profitable as its API gains many paid users, but two concerns emerge:
- Users compare it with other services and find it consistently slower, especially for trivial queries.
- The company finds that its profitability is starting to look quite low compared with competitors.
  
**Diagnosis**: After some analysis, the company finds that the model it uses is quite accurate, but at the expense of speed and price. It also finds that only about 50% of user queries are code related; the rest are trivial factual questions.
  
**Solution**: The company concludes that roughly half of the queries can be answered faster, and without a drop in accuracy, by a weaker model. This improves speed (and therefore the UX) for many queries, while the longer wait times and higher infrastructure costs are reserved for code-related queries. To achieve this, the company adds a router that sends factual and general queries to a weaker model and code-related queries to a strong model. For the sake of illustration, let the weaker model be Llama 3.1 8B and the stronger model be Qwen3:coder.
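
To make the argument concrete, the expected per-query latency (or cost) of a routed system is a weighted average of the two models' figures plus whatever the router itself adds. The sketch below is only illustrative: `blended_cost`, the latencies of 0.4 s and 2.0 s, and the 0.05 s router overhead are made-up assumptions, not measurements from this project.

```python
def blended_cost(frac_weak: float, cost_weak: float, cost_strong: float,
                 router_overhead: float = 0.0) -> float:
    """Expected per-query cost when `frac_weak` of the traffic goes to the weak model."""
    if not 0.0 <= frac_weak <= 1.0:
        raise ValueError("frac_weak must be between 0 and 1")
    return router_overhead + frac_weak * cost_weak + (1.0 - frac_weak) * cost_strong

# Hypothetical latencies in seconds (illustrative only, not measured).
WEAK_LATENCY = 0.4
STRONG_LATENCY = 2.0

# Everything on the strong model versus a 50/50 split with a small router overhead.
baseline = blended_cost(0.0, WEAK_LATENCY, STRONG_LATENCY)
routed = blended_cost(0.5, WEAK_LATENCY, STRONG_LATENCY, router_overhead=0.05)
print(f"baseline: {baseline:.2f}s, routed: {routed:.2f}s")
```

The same function works for price per query; the router only pays off while its own overhead stays small relative to the gap between the two models.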


Generally, routing is done based on properties of the prompt/query. Some properties can be extracted semantically with an embedder, which is the approach in Part One (below). A property can also be the content category the prompt falls into, as judged by another LLM, which is the approach in Part Two (further below). Finally, depending on the target use of the router, a prompt property can be learned by training a model, which is done in part 3 (router_training_notebook.ipynb).
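
Seen this way, every router has the same shape: something extracts a label from the prompt, and a lookup table turns the label into a model. The following is a minimal sketch of that shape, not the code used in this notebook; `PropertyExtractor`, `route`, and `toy_extractor` are hypothetical names, and the keyword rule is only a stand-in for an embedder, an LLM, or a trained classifier.

```python
from typing import Callable, Dict

# A "property extractor" maps a prompt to a label; the three parts of this
# write-up differ only in how that label is produced.
PropertyExtractor = Callable[[str], str]

# Label -> model name, using the two models named in the scenario.
MODEL_FOR_LABEL: Dict[str, str] = {
    "coding": "Qwen3:coder",
    "factual": "Llama 3.1 8B",
}

def route(prompt: str, extract: PropertyExtractor, default_label: str = "coding") -> str:
    """Return the model name for a prompt; unknown labels fall back to the default."""
    label = extract(prompt)
    return MODEL_FOR_LABEL.get(label, MODEL_FOR_LABEL[default_label])

def toy_extractor(prompt: str) -> str:
    """Hypothetical stand-in for an embedder, LLM, or trained classifier."""
    return "coding" if "python" in prompt.lower() else "factual"

print(route("How do I reverse a list in Python?", toy_extractor))
print(route("What is the capital of France?", toy_extractor))
```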


The **datasets** used are described in the data_notebook.ipynb file.

# Part One

### There are multiple ways of building a router; we try three. In the first, we embed every user query and match it against the personalities of our available models. The following is the code for this approach.

In [None]:
import os
import time
import json
import csv
from pathlib import Path
import gc

import numpy as np
import torch
from groq import Groq
from dotenv import load_dotenv
from typing import TypedDict, Literal, Callable, Optional
from collections import defaultdict
from langgraph.graph import StateGraph, END
from sentence_transformers import SentenceTransformer


In [3]:
# Utility functions
def cosine_similarity(a, b):
    """Calculate cosine similarity between two vectors."""
    a = np.array(a)
    b = np.array(b)
    return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

def load_dataset(file_path):
    """Load the dataset from either JSONL or CSV format."""
    file_path = Path(file_path)
    
    if file_path.suffix == '.jsonl':
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                data.append(json.loads(line.strip()))
        return data
    
    elif file_path.suffix == '.csv':
        data = []
        with open(file_path, 'r', encoding='utf-8') as f:
            reader = csv.DictReader(f)
            for row in reader:
                data.append(row)
        return data
    
    else:
        raise ValueError(f"Unsupported file format: {file_path.suffix}")


In [5]:

# State definition for the graph
class GraphState(TypedDict):
    prompt: str
    response: str
    agent_type: str
    predicted_category: str
    actual_category: str

def embed(text, embedder) -> np.ndarray:
    """Embed text using sentence-transformers."""
    return embedder.encode(text, convert_to_numpy=True)

def make_router_node(
    embedder: SentenceTransformer,
    coding_emb: np.ndarray,
    chat_emb: np.ndarray,
) -> Callable[[GraphState], GraphState]:
    def router_node(state: GraphState) -> GraphState:
        prompt = state["prompt"]
        user_emb = embed(prompt, embedder)

        sim_coding = cosine_similarity(user_emb, coding_emb)
        sim_chat = cosine_similarity(user_emb, chat_emb)

        if sim_coding > sim_chat:
            state["agent_type"] = "coding"
            state["predicted_category"] = "coding"
        else:
            state["agent_type"] = "chat"
            state["predicted_category"] = "factual"

        return state
    return router_node

# Placeholder coding agent - returns a canned response instead of calling the strong model
def coding_agent_node(state: GraphState) -> GraphState:
    state["response"] = "[RESPONSE FROM QWEN3:CODER]"
    return state

# Placeholder chat agent - returns a canned response instead of calling the weak model
def chat_agent_node(state: GraphState) -> GraphState:
    state["response"] = "[RESPONSE FROM CHEAP-LLAMA-MODEL]"
    return state

# Conditional edge function to determine routing
def route_decision(state: GraphState) -> Literal["coding_agent", "chat_agent"]:
    return "coding_agent" if state["agent_type"] == "coding" else "chat_agent"

# Create and configure the graph
def create_routing_graph(embedder: SentenceTransformer, 
                         coding_emb: np.ndarray, 
                         chat_emb: np.ndarray) -> StateGraph:
    workflow = StateGraph(GraphState)
    
    # Add nodes
    workflow.add_node("router", make_router_node(embedder, coding_emb, chat_emb))
    workflow.add_node("coding_agent", coding_agent_node)
    workflow.add_node("chat_agent", chat_agent_node)
    
    # Set entry point
    workflow.set_entry_point("router")
    
    # Add conditional edges
    workflow.add_conditional_edges(
        "router",
        route_decision,
        {
            "coding_agent": "coding_agent",
            "chat_agent": "chat_agent"
        }
    )
    
    # Add edges to end
    workflow.add_edge("coding_agent", END)
    workflow.add_edge("chat_agent", END)
    
    # Compile the graph
    return workflow.compile()

# dataset_files = ['prompts_2000.jsonl']#, 'prompts_2000.csv']
def evaluate_routing_accuracy(dataset_file: str,
                              embedder: SentenceTransformer,
                              coding_emb: np.ndarray,
                              chat_emb: np.ndarray):
    """Evaluate the routing accuracy using the dataset."""
    
    # Try to load the dataset
    data = None
    
    if Path(dataset_file).exists():
        print(f"Loading dataset from {dataset_file}...")
        data = load_dataset(dataset_file)

    if data is None:
        print("Error: No dataset found. Please run dataset_create.py first to generate the dataset.")
        return
    
    print(f"Loaded {len(data)} prompts from dataset")
    
    # Create the routing graph
    app = create_routing_graph(embedder=embedder,
                               coding_emb=coding_emb,
                               chat_emb=chat_emb)

    # Statistics tracking
    stats = defaultdict(int)
    correct_predictions = 0
    total_predictions = 0
    
    # Detailed results
    results = []
    
    print("Evaluating routing decisions...")
    
    for i, item in enumerate(data):
        if i % 200 == 0:
            print(f"Processed {i}/{len(data)} prompts...")
        
        prompt = item['prompt']
        actual_category = item['category']  # 'coding' or 'factual'
        
        # Initial state
        initial_state = {
            "prompt": prompt,
            "response": "",
            "agent_type": "",
            "predicted_category": "",
            "actual_category": actual_category
        }
        
        # Run the graph
        result = app.invoke(initial_state)
        
        predicted_category = result['predicted_category']
        
        # Count statistics
        stats[f"actual_{actual_category}"] += 1
        stats[f"predicted_{predicted_category}"] += 1
        
        # Check accuracy
        is_correct = predicted_category == actual_category
        if is_correct:
            correct_predictions += 1
            stats[f"correct_{actual_category}"] += 1
        else:
            stats[f"incorrect_{actual_category}"] += 1
        
        total_predictions += 1
        
        # Store detailed result
        results.append({
            'id': item.get('id', f'item_{i}'),
            'prompt': prompt[:100] + '...' if len(prompt) > 100 else prompt,
            'actual': actual_category,
            'predicted': predicted_category,
            'correct': is_correct,
            'agent_routed_to': result['agent_type']
        })
    
    # Calculate metrics
    overall_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    
    # Calculate per-category metrics
    coding_correct = stats['correct_coding']
    coding_total = stats['actual_coding']
    coding_accuracy = coding_correct / coding_total if coding_total > 0 else 0
    
    factual_correct = stats['correct_factual']
    factual_total = stats['actual_factual']
    factual_accuracy = factual_correct / factual_total if factual_total > 0 else 0
    
    # Confusion matrix
    coding_predicted_as_coding = sum(1 for r in results if r['actual'] == 'coding' and r['predicted'] == 'coding')
    coding_predicted_as_factual = sum(1 for r in results if r['actual'] == 'coding' and r['predicted'] == 'factual')
    factual_predicted_as_coding = sum(1 for r in results if r['actual'] == 'factual' and r['predicted'] == 'coding')
    factual_predicted_as_factual = sum(1 for r in results if r['actual'] == 'factual' and r['predicted'] == 'factual')
    
    # Print results
    print("\n" + "="*60)
    print("ROUTING ACCURACY EVALUATION RESULTS")
    print("="*60)
    
    print(f"\nOverall Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_predictions})")
    
    print(f"\nPer-Category Accuracy:")
    print(f"  Coding: {coding_accuracy:.3f} ({coding_correct}/{coding_total})")
    print(f"  Factual: {factual_accuracy:.3f} ({factual_correct}/{factual_total})")
    
    print(f"\nConfusion Matrix:")
    print(f"                    Predicted")
    print(f"Actual      Coding    Factual")
    print(f"Coding      {coding_predicted_as_coding:6d}    {coding_predicted_as_factual:7d}")
    print(f"Factual     {factual_predicted_as_coding:6d}    {factual_predicted_as_factual:7d}")
    
    print(f"\nRouting Statistics:")
    print(f"  Total prompts routed to coding agent: {stats['predicted_coding']}")
    print(f"  Total prompts routed to chat agent: {stats['predicted_factual']}")
    
    # Show some examples of misclassifications
    print(f"\nSample Misclassifications (first 5):")
    misclassified = [r for r in results if not r['correct']][:5]
    for i, r in enumerate(misclassified, 1):
        print(f"  {i}. Actual: {r['actual']}, Predicted: {r['predicted']}")
        print(f"     Prompt: {r['prompt']}")
        print()
    
    # Save detailed results
    output_file = "routing_evaluation_results.csv"
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        fieldnames = ['id', 'prompt', 'actual', 'predicted', 'correct', 'agent_routed_to']
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    
    print(f"Detailed results saved to: {output_file}")


### Here we use a simple, fast embedder with very basic personalities to match against the user query.

We use LangGraph to build a pipeline that runs for every user query. The query is embedded and its cosine similarity is computed against the embedded personalities of the models; the query is routed to whichever personality is most similar.

This approach is fast (depending on the embedding model selected) and has some levers that can be adjusted according to the target (speed, quality, etc.).

We test the approach against a balanced dataset with 1000 factual queries and 1000 code-related queries. For more on dataset generation, see the data_notebook.
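
One possible lever is a margin on the similarity comparison, which biases borderline queries towards the strong model. The sketch below illustrates the idea with toy 2-D vectors so that it runs without downloading an embedder; `route_by_similarity` and its `margin` parameter are assumptions for illustration and are not part of the router defined above. With `margin=0` it makes the same decision as `router_node`.

```python
import numpy as np

def cosine(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity between two 1-D vectors."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

def route_by_similarity(query_emb: np.ndarray,
                        coding_emb: np.ndarray,
                        chat_emb: np.ndarray,
                        margin: float = 0.0) -> str:
    """Route to 'coding' when the coding similarity plus `margin` beats the chat similarity.

    A positive margin sends more borderline queries to the strong (coding) model.
    """
    sim_coding = cosine(query_emb, coding_emb)
    sim_chat = cosine(query_emb, chat_emb)
    return "coding" if sim_coding + margin > sim_chat else "chat"

# Toy vectors standing in for real embeddings (illustrative only).
coding_vec = np.array([1.0, 0.0])
chat_vec = np.array([0.0, 1.0])
query_vec = np.array([0.6, 0.8])  # slightly closer to the chat personality

print(route_by_similarity(query_vec, coding_vec, chat_vec))              # no bias
print(route_by_similarity(query_vec, coding_vec, chat_vec, margin=0.3))  # biased to coding
```

A larger margin trades cost for safety: fewer coding queries end up on the weak model, but more factual queries are sent to the expensive one.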

In [7]:
# Agent prompts for routing
CODING_AGENT_PROMPT = "I am a coding assistant. I help with programming questions, algorithms, debugging, and software development."
CHAT_AGENT_PROMPT = "I am a general chat assistant. I help with everyday questions, facts, explanations, and general knowledge."

# Embedding model - using sentence-transformers for local embeddings
embedder = SentenceTransformer('all-MiniLM-L6-v2')  # Lightweight and fast model
coding_emb = embed(CODING_AGENT_PROMPT, embedder)
chat_emb = embed(CHAT_AGENT_PROMPT, embedder)

# Evaluate routing accuracy
evaluate_routing_accuracy("./data/prompts_2000.jsonl", embedder, coding_emb, chat_emb)

Loading dataset from ./data/prompts_2000.jsonl...
Loaded 2000 prompts from dataset
Evaluating routing decisions...
Processed 0/2000 prompts...
Processed 200/2000 prompts...
Processed 400/2000 prompts...
Processed 600/2000 prompts...
Processed 800/2000 prompts...
Processed 1000/2000 prompts...
Processed 1200/2000 prompts...
Processed 1400/2000 prompts...
Processed 1600/2000 prompts...
Processed 1800/2000 prompts...

ROUTING ACCURACY EVALUATION RESULTS

Overall Accuracy: 0.683 (1367/2000)

Per-Category Accuracy:
  Coding: 0.820 (820/1000)
  Factual: 0.547 (547/1000)

Confusion Matrix:
                    Predicted
Actual      Coding    Factual
Coding         820        180
Factual        453        547

Routing Statistics:
  Total prompts routed to coding agent: 1273
  Total prompts routed to chat agent: 727

Sample Misclassifications (first 5):
  1. Actual: factual, Predicted: coding
     Prompt: What is the main difference between ORC and Parquet?

  2. Actual: factual, Predicted: codi

### The current model's low accuracy (68%) is understandable given our use of a fast, generic embedding model.

The first misclassification is a tough one but also understandable: it is a factual question about programming, so it could fall on either side.

At its core this is an embedding classification task, so let's look at the MTEB leaderboard and choose one of the top-performing embedding models on the classification task specifically.  
Among the open-source options sorted by classification score, the Qwen3 embedding model is one of the top scorers.  
Let's test it with our approach.  

In [8]:
# Agent prompts for routing
CODING_AGENT_PROMPT = "I am a coding assistant. I help with programming questions, algorithms, debugging, and software development."
CHAT_AGENT_PROMPT = "I am a general chat assistant. I help with everyday questions, facts, explanations, and general knowledge."

# Embedding model - using sentence-transformers for local embeddings
embedder = SentenceTransformer("Qwen/Qwen3-Embedding-4B")

coding_emb = embed(CODING_AGENT_PROMPT, embedder)
chat_emb = embed(CHAT_AGENT_PROMPT, embedder)

# Evaluate routing accuracy
evaluate_routing_accuracy("./data/prompts_2000.jsonl", embedder, coding_emb, chat_emb)

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Fetching 2 files: 100%|██████████| 2/2 [04:25<00:00, 132.83s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:29<00:00, 14.97s/it]


Loading dataset from ./data/prompts_2000.jsonl...
Loaded 2000 prompts from dataset
Evaluating routing decisions...
Processed 0/2000 prompts...
Processed 200/2000 prompts...
Processed 400/2000 prompts...
Processed 600/2000 prompts...
Processed 800/2000 prompts...
Processed 1000/2000 prompts...
Processed 1200/2000 prompts...
Processed 1400/2000 prompts...
Processed 1600/2000 prompts...
Processed 1800/2000 prompts...

ROUTING ACCURACY EVALUATION RESULTS

Overall Accuracy: 0.897 (1794/2000)

Per-Category Accuracy:
  Coding: 0.931 (931/1000)
  Factual: 0.863 (863/1000)

Confusion Matrix:
                    Predicted
Actual      Coding    Factual
Coding         931         69
Factual        137        863

Routing Statistics:
  Total prompts routed to coding agent: 1068
  Total prompts routed to chat agent: 932

Sample Misclassifications (first 5):
  1. Actual: factual, Predicted: coding
     Prompt: What is the main difference between ORC and Parquet?

  2. Actual: factual, Predicted: codi

### This score is a clear improvement, but to generalise to data outside the current distribution it would be smarter to use richer, broader prompts that cover more of the vocabulary of incoming queries. Let's expand the personalities of our agent prompts accordingly. This can also help set better guardrails.

In [9]:
# Agent prompts for routing - Enhanced with detailed descriptions
CODING_AGENT_PROMPT = """I am a specialized coding and programming assistant. I help with:
- Writing, debugging, and optimizing code in Python, JavaScript, Java, C++, Go, Rust, and other languages
- Algorithm design, data structures, and computational complexity analysis
- Software architecture, design patterns, and best practices
- Web development (HTML, CSS, React, Node.js, APIs, databases)
- Machine learning and data science (pandas, numpy, scikit-learn, TensorFlow, PyTorch)
- DevOps, CI/CD, Docker, Kubernetes, and cloud deployment
- Code reviews, refactoring, and performance optimization
- Programming concepts like functions, classes, loops, recursion, and OOP
- Framework-specific help (Django, Flask, FastAPI, Spring, Express)
- Testing, unit tests, integration tests, and TDD
- Version control with Git and collaboration workflows
- Technical problem solving and computational thinking
- Code documentation, commenting, and maintainability
- Package management, dependencies, and build systems"""

CHAT_AGENT_PROMPT = """I am a general knowledge and conversational assistant. I help with:
- Answering factual questions about history, science, geography, and current events
- Explaining concepts in physics, chemistry, biology, mathematics, and other academic subjects
- Providing information about culture, arts, literature, and entertainment
- Discussing philosophy, ethics, psychology, and social sciences
- Offering advice on personal development, relationships, and life decisions
- Explaining how things work in everyday life and natural phenomena
- Helping with writing, grammar, language learning, and communication
- Providing travel information, recommendations, and cultural insights
- Discussing business, economics, politics, and world affairs
- Offering creative ideas for hobbies, crafts, and recreational activities
- Explaining health, wellness, nutrition, and lifestyle topics
- Helping with educational support across various non-technical subjects
- Engaging in casual conversation and answering general curiosity questions
- Providing summaries and explanations of complex topics in simple terms
- Offering perspectives on ethical dilemmas and philosophical questions"""

coding_emb = embed(CODING_AGENT_PROMPT, embedder)
chat_emb = embed(CHAT_AGENT_PROMPT, embedder)

# Evaluate routing accuracy
evaluate_routing_accuracy("./data/prompts_2000.jsonl", embedder, coding_emb, chat_emb)

Loading dataset from ./data/prompts_2000.jsonl...
Loaded 2000 prompts from dataset
Evaluating routing decisions...
Processed 0/2000 prompts...
Processed 200/2000 prompts...
Processed 400/2000 prompts...
Processed 600/2000 prompts...
Processed 800/2000 prompts...
Processed 1000/2000 prompts...
Processed 1200/2000 prompts...
Processed 1400/2000 prompts...
Processed 1600/2000 prompts...
Processed 1800/2000 prompts...

ROUTING ACCURACY EVALUATION RESULTS

Overall Accuracy: 0.936 (1873/2000)

Per-Category Accuracy:
  Coding: 0.942 (942/1000)
  Factual: 0.931 (931/1000)

Confusion Matrix:
                    Predicted
Actual      Coding    Factual
Coding         942         58
Factual         69        931

Routing Statistics:
  Total prompts routed to coding agent: 1011
  Total prompts routed to chat agent: 989

Sample Misclassifications (first 5):
  1. Actual: factual, Predicted: coding
     Prompt: What is the main difference between ORC and Parquet?

  2. Actual: factual, Predicted: codi

### That's another big improvement. Let's stop here for now and explore other methods.
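
Before moving on, note that the two kinds of error are not equally costly in this scenario: a coding query sent to the weak model risks a worse answer, while a factual query sent to the strong model only costs time and money. Overall accuracy hides that difference, so it can help to derive precision and recall from the confusion matrices printed above. The sketch below does this with "coding" as the positive class; `binary_metrics` is a hypothetical helper, and the counts are copied from the three runs.

```python
def binary_metrics(tp: int, fn: int, fp: int, tn: int) -> dict:
    """Accuracy, precision, recall and F1 for the positive class from confusion-matrix counts."""
    total = tp + fn + fp + tn
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    accuracy = (tp + tn) / total if total else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# "coding" is the positive class:
#   tp = coding routed to coding,  fn = coding routed to factual,
#   fp = factual routed to coding, tn = factual routed to factual.
runs = {
    "all-MiniLM-L6-v2, short prompts": dict(tp=820, fn=180, fp=453, tn=547),
    "Qwen3-Embedding-4B, short prompts": dict(tp=931, fn=69, fp=137, tn=863),
    "Qwen3-Embedding-4B, detailed prompts": dict(tp=942, fn=58, fp=69, tn=931),
}

for name, counts in runs.items():
    m = binary_metrics(**counts)
    print(f"{name}: precision={m['precision']:.3f} "
          f"recall={m['recall']:.3f} f1={m['f1']:.3f}")
```

Coding recall tracks how often the product quality is protected, and coding precision tracks how much strong-model capacity is spent on queries that did not need it.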

In [None]:
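del embedder                      # drop the reference to the embedding model
gc.collect()                      # free unreferenced Python objects
if torch.cuda.is_available():     # only touch the CUDA allocator when a GPU is present
    torch.cuda.empty_cache()      # return cached blocks to the driver
    torch.cuda.ipc_collect()      # (optional) collect interprocess cached memory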

# Part Two
In Part Two we let an LLM determine the property by classifying the prompt. We use a cheap LLM (i.e. Llama 3.1 8B) and query it through Groq. I like Groq for the speed some of their served models offer.  
We give the LLM a detailed prompt asking it to decide whether the incoming query is of class coding or class factual (not coding related).
This approach may prove slower, but for the current case it could be a reliable one. We also don't need to train any model. However, cost may grow as input queries get longer, whereas in the previous approach cost could be capped by setting max_tokens (because in most cases, though not all, part of the query is enough to classify it).
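
Because the classifier is a generative model, its reply is free text that has to be mapped back onto the two labels. The cell below does this with substring checks; the following is a slightly stricter alternative, shown only as a sketch. `parse_label` is a hypothetical helper that falls back to a default whenever the reply is empty, names neither label, or names both.

```python
import re
from typing import Optional

VALID_LABELS = ("coding", "factual")

def parse_label(reply: Optional[str], default: str = "coding") -> str:
    """Map a free-text LLM reply onto exactly one label.

    Returns `default` when the reply is empty, contains neither label as a
    whole word, or mentions both labels (ambiguous).
    """
    words = set(re.findall(r"[a-z]+", (reply or "").lower()))
    found = [label for label in VALID_LABELS if label in words]
    return found[0] if len(found) == 1 else default

print(parse_label(' "Coding" '))        # quoted, mixed case
print(parse_label("factual."))          # trailing punctuation
print(parse_label("coding or factual")) # ambiguous, falls back to the default
```

Defaulting to "coding" mirrors the choice made in the classifier below: when in doubt, the query goes to the more expensive model so that answer quality does not suffer.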

In [12]:
# Load environment variables
load_dotenv()

True

In [14]:
Label = str  # "coding" | "factual"

CLASSIFICATION_PROMPT = """You are a routing classifier. Your job is to classify user queries into exactly one of two categories:

1. "coding" - Programming, software development, technical implementation questions including:
   - Writing, debugging, or optimizing code
   - Programming languages (Python, JavaScript, Java, C++, etc.)
   - Software development concepts (algorithms, data structures, OOP)
   - Web development, APIs, databases
   - DevOps, CI/CD, deployment
   - Code reviews, testing, debugging
   - Technical problem solving
   - Programming frameworks and libraries

2. "factual" - General knowledge, educational, conversational questions including:
   - History, science, geography, current events
   - Academic subjects (physics, chemistry, biology, math)
   - Culture, arts, literature, entertainment
   - Personal advice, relationships, life decisions
   - General explanations of how things work
   - Travel, recommendations, lifestyle
   - Creative ideas, hobbies, crafts
   - Health, wellness, nutrition
   - Philosophy, ethics, social topics

Analyze the following user query and respond with ONLY the category name: either "coding" or "factual"

User query: {query}

Category:"""

def classify_with_llm(
    prompt: str,
    model: str = "llama-3.1-8b-instant",# Cheap model to try out
    retries: int = 2,
    backoff: float = 0.5,               # seconds; grows exponentially
    on_error: Label = "coding",         # Expensive default to not let product impression suffer
    fallback: Optional[Callable[[str], Label]] = None,
) -> Label:
    """
    Classify using Groq, returning exactly 'coding' or 'factual'.
    Retries transient errors; optionally calls `fallback(prompt)` on failure.
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        # If even the key is missing, go straight to fallback/default.
        return fallback(prompt) if fallback else on_error

    client = Groq(api_key=api_key)

    for attempt in range(retries + 1):
        try:
            resp = client.chat.completions.create(
                model=model,
                temperature=0,
                max_tokens=5,
                messages=[
                    {"role": "system", "content": "Reply with exactly one word: coding or factual."},
                    {"role": "user", "content": CLASSIFICATION_PROMPT.format(query=prompt)},
                ],
            )
            text = (resp.choices[0].message.content or "").strip().lower()
            if "coding" in text:
                return "coding"
            if "factual" in text:
                return "factual"
            # Unexpected content: force a safe default
            return on_error
        except Exception:
            # Optional: log the exception to your logger/Sentry here
            if attempt < retries:
                time.sleep(backoff * (2 ** attempt))
            else:
                # Exhausted retries → use fallback or default
                return fallback(prompt) if fallback else on_error

# Router node - uses LLM to determine category
def llm_router_node(state: GraphState) -> GraphState:
    prompt = state["prompt"]
    
    # Use LLM to classify the prompt
    predicted_category = classify_with_llm(prompt)
    
    state["agent_type"] = "coding" if predicted_category == "coding" else "chat"
    state["predicted_category"] = predicted_category
    
    return state

# Create the LLM-based routing graph
def create_llm_routing_graph():
    workflow = StateGraph(GraphState)
    
    # Add nodes
    workflow.add_node("llm_router", llm_router_node)
    workflow.add_node("coding_agent", coding_agent_node)
    workflow.add_node("chat_agent", chat_agent_node)
    
    # Set entry point
    workflow.set_entry_point("llm_router")
    
    # Add conditional edges
    workflow.add_conditional_edges(
        "llm_router",
        route_decision,
        {
            "coding_agent": "coding_agent",
            "chat_agent": "chat_agent"
        }
    )
    
    # Add edges to end
    workflow.add_edge("coding_agent", END)
    workflow.add_edge("chat_agent", END)
    
    return workflow.compile()

def evaluate_llm_routing_accuracy(dataset_file: str):
    """Evaluate the LLM-based routing accuracy."""
    
    # Load dataset
    data = None
    if Path(dataset_file).exists():
        print(f"Loading dataset from {dataset_file}...")
        data = load_dataset(dataset_file)
    
    if data is None:
        print("Error: No dataset found. Please run dataset_create.py first.")
        return
    
    print(f"Loaded {len(data)} prompts from dataset")
    
    # Create the LLM routing graph
    app = create_llm_routing_graph()
    
    # Statistics tracking
    stats = defaultdict(int)
    correct_predictions = 0
    total_predictions = 0
    
    # Detailed results
    results = []
    
    print("Evaluating LLM-based routing decisions...")
    print("Note: This may take a while depending on LLM response times...")
    
    for i, item in enumerate(data):
        if i % 50 == 0:  # Less frequent updates due to slower LLM calls
            print(f"Processed {i}/{len(data)} prompts...")
        
        prompt = item['prompt']
        actual_category = item['category']
        
        # Initial state
        initial_state = {
            "prompt": prompt,
            "response": "",
            "agent_type": "",
            "predicted_category": "",
            "actual_category": actual_category
        }
        
        try:
            # Run the graph
            result = app.invoke(initial_state)
            predicted_category = result['predicted_category']
            
            # Count statistics
            stats[f"actual_{actual_category}"] += 1
            stats[f"predicted_{predicted_category}"] += 1
            
            # Check accuracy
            is_correct = predicted_category == actual_category
            if is_correct:
                correct_predictions += 1
                stats[f"correct_{actual_category}"] += 1
            else:
                stats[f"incorrect_{actual_category}"] += 1
            
            total_predictions += 1
            
            # Store result
            results.append({
                'id': item.get('id', f'item_{i}'),
                'prompt': prompt[:100] + '...' if len(prompt) > 100 else prompt,
                'actual': actual_category,
                'predicted': predicted_category,
                'correct': is_correct,
                'agent_routed_to': result['agent_type']
            })
            
        except Exception as e:
            print(f"Error processing item {i}: {e}")
            # Continue with next item
            continue
    
    # Calculate metrics
    overall_accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    
    coding_correct = stats['correct_coding']
    coding_total = stats['actual_coding']
    coding_accuracy = coding_correct / coding_total if coding_total > 0 else 0
    
    factual_correct = stats['correct_factual']
    factual_total = stats['actual_factual']
    factual_accuracy = factual_correct / factual_total if factual_total > 0 else 0
    
    # Confusion matrix
    coding_predicted_as_coding = sum(1 for r in results if r['actual'] == 'coding' and r['predicted'] == 'coding')
    coding_predicted_as_factual = sum(1 for r in results if r['actual'] == 'coding' and r['predicted'] == 'factual')
    factual_predicted_as_coding = sum(1 for r in results if r['actual'] == 'factual' and r['predicted'] == 'coding')
    factual_predicted_as_factual = sum(1 for r in results if r['actual'] == 'factual' and r['predicted'] == 'factual')
    
    # Print results
    print("\n" + "="*60)
    print("LLM-BASED ROUTING ACCURACY EVALUATION RESULTS")
    print("="*60)
    
    print(f"\nOverall Accuracy: {overall_accuracy:.3f} ({correct_predictions}/{total_predictions})")
    
    print(f"\nPer-Category Accuracy:")
    print(f"  Coding: {coding_accuracy:.3f} ({coding_correct}/{coding_total})")
    print(f"  Factual: {factual_accuracy:.3f} ({factual_correct}/{factual_total})")
    
    print(f"\nConfusion Matrix:")
    print(f"                    Predicted")
    print(f"Actual      Coding    Factual")
    print(f"Coding      {coding_predicted_as_coding:6d}    {coding_predicted_as_factual:7d}")
    print(f"Factual     {factual_predicted_as_coding:6d}    {factual_predicted_as_factual:7d}")
    
    print(f"\nRouting Statistics:")
    print(f"  Total prompts routed to coding agent: {stats['predicted_coding']}")
    print(f"  Total prompts routed to chat agent: {stats['predicted_factual']}")
    
    # Show misclassifications
    print(f"\nSample Misclassifications (first 5):")
    misclassified = [r for r in results if not r['correct']][:5]
    for i, r in enumerate(misclassified, 1):
        print(f"  {i}. Actual: {r['actual']}, Predicted: {r['predicted']}")
        print(f"     Prompt: {r['prompt']}")
        print()
    
    # Save results
    output_file = "llm_routing_evaluation_results.csv"
    with open(output_file, 'w', newline='', encoding='utf-8') as f:
        fieldnames = ['id', 'prompt', 'actual', 'predicted', 'correct', 'agent_routed_to']
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(results)
    
    print(f"Detailed results saved to: {output_file}")




In [15]:
evaluate_llm_routing_accuracy("./data/prompts_2000.jsonl")

Loading dataset from ./data/prompts_2000.jsonl...
Loaded 2000 prompts from dataset
Evaluating LLM-based routing decisions...
Note: This may take a while depending on LLM response times...
Processed 0/2000 prompts...
Processed 50/2000 prompts...
Processed 100/2000 prompts...
Processed 150/2000 prompts...
Processed 200/2000 prompts...
Processed 250/2000 prompts...
Processed 300/2000 prompts...
Processed 350/2000 prompts...
Processed 400/2000 prompts...
Processed 450/2000 prompts...
Processed 500/2000 prompts...
Processed 550/2000 prompts...
Processed 600/2000 prompts...
Processed 650/2000 prompts...
Proces

Alright, so our preliminary experiments show that categorical classification works better when the model is an LLM rather than an embedding model. But again, depending on what we're trying to optimise, either approach can be the better fit.
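
Since the choice depends on what is being optimised, it helps to measure routing latency alongside accuracy. The harness below is a minimal sketch for timing any routing function per prompt; `time_router` and `dummy_router` are hypothetical names, and the dummy only stands in for a real router so that the example runs offline.

```python
import statistics
import time
from typing import Callable, Iterable

def time_router(router: Callable[[str], str], prompts: Iterable[str]) -> dict:
    """Measure the per-prompt wall-clock latency of a routing function."""
    latencies = []
    for prompt in prompts:
        start = time.perf_counter()
        router(prompt)
        latencies.append(time.perf_counter() - start)
    if not latencies:
        raise ValueError("no prompts were provided")
    return {
        "n": len(latencies),
        "mean_s": statistics.mean(latencies),
        "median_s": statistics.median(latencies),
        "max_s": max(latencies),
    }

def dummy_router(prompt: str) -> str:
    """Hypothetical stand-in for the embedding-based or LLM-based router."""
    return "coding" if "code" in prompt.lower() else "factual"

report = time_router(dummy_router, ["Write code for quicksort", "Who painted the Mona Lisa?"])
print(report)
```

In this notebook, `classify_with_llm` could be passed directly as the `router` argument, and the embedding router could be wrapped in a small function that embeds the prompt and compares the two similarities.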