# Universal Model Routing Experiment

**Paper**: [Universal Model Routing for Efficient LLM Inference](https://arxiv.org/pdf/2502.08773)  
**Authors**: Jitkrittum et al. (2025)

## Core Concept
Route each query to one of several LLMs using **cluster-based error profiles**, which also work for new, unseen models without retraining the router.

**Key Innovation**: Ψ(m) vectors — each model m is represented by a vector of its error rates, one entry per question cluster.

In [None]:
# Setup
!pip install -q openai scikit-learn numpy sentence-transformers datasets pandas pydantic matplotlib groq

import os
import json
import pickle
import numpy as np
import time
from typing import List, Dict, Optional, Literal
from collections import defaultdict
from sklearn.cluster import KMeans
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel
import matplotlib.pyplot as plt
from datasets import load_dataset

print("✅ Setup complete")

# Add your API keys as environment variables:
# export OPENAI_API_KEY='your-key'
# export GROQ_API_KEY='your-key'

# Provider name -> API key. os.getenv returns None when the variable is unset,
# so a missing key shows up as None here rather than raising.
API_KEYS = {
    'openai': os.getenv('OPENAI_API_KEY'),
    'groq': os.getenv('GROQ_API_KEY')
}

# Model pool with cost estimates
# Each entry names the model, the provider whose key it needs (a key of
# API_KEYS), and a relative cost figure (units not stated here — TODO confirm).
MODELS = [
    {'name': 'gpt-4o', 'provider': 'openai', 'cost': 6.25},
    {'name': 'gpt-4.1', 'provider': 'openai', 'cost': 5.00},
    {'name': 'gpt-4o-mini', 'provider': 'openai', 'cost': 0.375},
    {'name': 'meta-llama/llama-guard-4-12b', 'provider': 'groq', 'cost': 0.20},
    {'name': 'llama-3.1-8b-instant', 'provider': 'groq', 'cost': 0.065}
]

print(f"✅ {len(MODELS)} models configured")
print(f"📊 Cost range: ${min(m['cost'] for m in MODELS)} - ${max(m['cost'] for m in MODELS)}")

In [None]:
# Add your API keys
# Keys are read from the environment first (OPENAI_API_KEY / GROQ_API_KEY) so
# real credentials never have to be pasted into, and committed with, the
# notebook. The placeholder strings remain as the fallback and can still be
# edited in place.
API_KEYS = {
    'openai': os.getenv('OPENAI_API_KEY', 'your-openai-key-here'),
    'groq': os.getenv('GROQ_API_KEY', 'your-groq-key-here')
}

# Candidate model pool. 'provider' selects the entry of API_KEYS to use;
# 'cost' is a relative cost figure (units not stated here — TODO confirm).
# NOTE(review): 'meta-llama/llama-guard-4-12b' looks like a safety-classification
# model rather than a general question-answering model — confirm it belongs
# in this pool.
MODELS = [
    {'name': 'gpt-4o', 'provider': 'openai', 'cost': 6.25},
    {'name': 'gpt-4.1', 'provider': 'openai', 'cost': 5.00},
    {'name': 'gpt-4o-mini', 'provider': 'openai', 'cost': 0.375},
    {'name': 'meta-llama/llama-guard-4-12b', 'provider': 'groq', 'cost': 0.20},
    {'name': 'llama-3.1-8b-instant', 'provider': 'groq', 'cost': 0.065}
]

print(f"✅ {len(MODELS)} models")
print(f"📊 Cost: ${min(m['cost'] for m in MODELS)} - ${max(m['cost'] for m in MODELS)}")

## Validation Set + Clustering

**Paper Method**: Embed the questions of the MMLU validation set and group them with K-means clustering on those embeddings.

In [None]:
# Load MMLU validation
# Number of validation questions to sample, and the option labels used both
# in the prompt and to decode the integer answer index.
N_VALIDATION = 200
ANSWER_LETTERS = ['A', 'B', 'C', 'D']

dataset = load_dataset("cais/mmlu", "all", split="validation")
# Fixed seed keeps the sampled subset reproducible; the min() guard avoids an
# out-of-range selection if the split has fewer rows than requested.
dataset = dataset.shuffle(seed=42).select(range(min(N_VALIDATION, len(dataset))))

# Format for MCQ
# Each record holds the full multiple-choice prompt, the gold answer letter,
# and the bare question text (used separately from the prompt).
VALIDATION_SET = []
for item in dataset:
    options = "\n".join(
        f"{letter}) {choice}"
        for letter, choice in zip(ANSWER_LETTERS, item['choices'])
    )
    prompt = f"{item['question']}\n{options}\nAnswer:"
    VALIDATION_SET.append({
        'prompt': prompt,
        'answer': ANSWER_LETTERS[item['answer']],
        'question': item['question']
    })

print(f"✅ {len(VALIDATION_SET)} validation examples")

In [None]:
# Cluster questions
# Embed only the question text (not the answer options) with a sentence
# encoder, then partition the embeddings into K groups.
embedder = SentenceTransformer('all-MiniLM-L6-v2')
embeddings = embedder.encode([item['question'] for item in VALIDATION_SET])

K = 10
# n_init is set explicitly: its default differs between scikit-learn releases,
# so leaving it implicit makes the clustering depend on the installed version.
kmeans = KMeans(n_clusters=K, random_state=42, n_init=10)
clusters = kmeans.fit_predict(embeddings)

# Group by cluster
# Keys are converted to plain Python ints so the mapping can be serialised
# (e.g. with json) without tripping over numpy integer keys.
validation_clusters = defaultdict(list)
for example, cluster_id in zip(VALIDATION_SET, clusters):
    validation_clusters[int(cluster_id)].append(example)

print(f"✅ {K} clusters created")
for k in range(K):
    print(f"  Cluster {k}: {len(validation_clusters[k])} examples")