# Intermediate 04: Comparing Embedding Models

Learn how to objectively compare different embedding models using retrieval quality metrics.

## What You'll Learn

- **Discover** multiple embedding models stored in the registry
- **Load and retrieve** from multiple models side by side
- **Compute metrics** to measure retrieval quality (Precision, Recall, MRR, NDCG)
- **Visualize** comparisons with charts and tables
- **Analyze trade-offs** and make recommendations

## Prerequisites

1. Run `foundation/02-rag-postgresql-persistent.ipynb` to generate and store embeddings for the **base model**
2. (Optional) For comparison with a second model:
   - Install: `ollama pull hf.co/CompendiumLabs/bge-small-en-v1.5-gguf`
   - Edit foundation/02 to use the second model
   - Run foundation/02 again to generate embeddings for the second model
   - Return here to compare!

## Learning Outcomes

By the end of this notebook, you'll understand:

- How embedding models differ in retrieval quality
- How to measure quality using information retrieval metrics
- Speed vs. quality trade-offs in embedding selection
- How to make data-driven recommendations for model selection

---

## Setup and Configuration

In [None]:
import ollama
import psycopg2
import psycopg2.extras
import pandas as pd
import numpy as np
import time
import json
import math
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Tuple, Set
from collections import defaultdict

# Configure visualization style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

POSTGRES_CONFIG = {
    'host': 'localhost',
    'port': 5432,
    'database': 'rag_db',
    'user': 'postgres',
    'password': 'postgres',
}

print("Configuration loaded")

## Part 1: List Available Models

**What we're doing:** Query the embedding registry to find all available embedding models.

The registry table stores metadata about every embedding model that has been generated:
- `model_alias`: Short identifier (e.g., 'bge_base_en_v1_5')
- `model_name`: Full model name (e.g., 'hf.co/CompendiumLabs/bge-base-en-v1.5-gguf')
- `dimension`: Vector size (e.g., 768 or 384)
- `embedding_count`: Number of chunks stored
- `created_at`: When embeddings were generated
- `chunk_size_config`: Chunking configuration recorded when the embeddings were generated

We'll format this as a clean pandas DataFrame for easy reading.

In [None]:
def list_available_models(postgres_config: Dict) -> Tuple[pd.DataFrame, psycopg2.extensions.connection]:
    """Query embedding_registry to discover all available embedding models.
    
    Returns:
        Tuple of (DataFrame with models, PostgreSQL connection)
    """
    conn = psycopg2.connect(
        host=postgres_config['host'],
        port=postgres_config['port'],
        database=postgres_config['database'],
        user=postgres_config['user'],
        password=postgres_config['password']
    )
    
    query = '''
        SELECT 
            model_alias, 
            model_name, 
            dimension, 
            embedding_count, 
            created_at,
            chunk_size_config
        FROM embedding_registry
        ORDER BY created_at DESC
    '''
    
    available = pd.read_sql(query, conn)
    
    print("=" * 80)
    print("AVAILABLE EMBEDDING MODELS IN REGISTRY")
    print("=" * 80)
    
    if available.empty:
        print("\n⚠️  No embedding models found in registry!")
        print("\nTo populate the registry:")
        print("  1. Run foundation/02-rag-postgresql-persistent.ipynb")
        print("  2. For a second model for comparison:")
        print("     - ollama pull hf.co/CompendiumLabs/bge-small-en-v1.5-gguf")
        print("     - Edit foundation/02 to use the new model")
        print("     - Run foundation/02 again")
    else:
        print()
        print(available.to_string(index=False))
        print(f"\nTotal: {len(available)} model(s) available")
    
    print()
    return available, conn

# Discover available models
available_models, conn = list_available_models(POSTGRES_CONFIG)

### Check for Sufficient Models

For a meaningful comparison, we need at least 2 embedding models. If you only have 1, the Prerequisites section above explains how to generate a second one.

For now, we'll proceed with whatever models are available (the analysis will note when only 1 model is present).

In [None]:
if len(available_models) == 0:
    print("✗ STOP: No embedding models found. Run foundation/02 first.")
elif len(available_models) == 1:
    print(f"⚠️  Only 1 model available: {available_models.iloc[0]['model_alias']}")
    print("   For meaningful comparison, generate a second model (see instructions above)")
else:
    print(f"✓ {len(available_models)} models available for comparison")

## Part 2: Load and Compare Retrieval Results

**What we're doing:** Define a shared set of test queries, then for each embedding model in the registry:
1. Connect to the table holding that model's embeddings
2. Embed each test query with the same model
3. Retrieve the top-5 chunks for each query
4. Store results for metric computation
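To make the retrieval step concrete, here is a minimal in-memory sketch of what a top-N cosine-similarity search computes. The helper names (`cosine_similarity`, `top_n_chunks`) and the toy 3-dimensional vectors are illustrative assumptions, not part of this notebook. The real search runs inside PostgreSQL, where pgvector's `<=>` operator returns cosine distance (1 minus cosine similarity).

```python
import math
from typing import List, Tuple


def cosine_similarity(a: List[float], b: List[float]) -> float:
    """Cosine of the angle between two vectors (0.0 if either is all zeros)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)


def top_n_chunks(
    query_vec: List[float],
    chunks: List[Tuple[str, List[float]]],
    top_n: int = 5,
) -> List[Tuple[str, float]]:
    """Score every chunk against the query and keep the top_n most similar."""
    scored = [(text, cosine_similarity(query_vec, vec)) for text, vec in chunks]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:top_n]


# Toy data (made up for illustration): chunk text plus a tiny embedding
toy_chunks = [
    ("chunk about plants", [1.0, 0.0, 0.0]),
    ("chunk about history", [0.0, 1.0, 0.0]),
    ("chunk about biology", [0.8, 0.0, 0.6]),
]
toy_results = top_n_chunks([1.0, 0.0, 0.0], toy_chunks, top_n=2)
for text, score in toy_results:
    print(f"{score:.2f}  {text}")
```

A brute-force scan like this is only practical for small collections; the notebook delegates the same ranking to the database instead.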

### Create the Vector Database Class

In [None]:
class PostgreSQLVectorDB:
    """Helper class to query embeddings from PostgreSQL using pgvector.
    
    This class provides similarity search on pre-computed embeddings stored in
    a PostgreSQL table with pgvector extension.
    """
    
    def __init__(self, config: Dict, table_name: str):
        """Initialize database connection.
        
        Args:
            config: PostgreSQL config dict with host, port, database, user, password
            table_name: Name of table storing embeddings
        """
        self.config = config
        self.table_name = table_name
        self.conn = psycopg2.connect(
            host=config['host'],
            port=config['port'],
            database=config['database'],
            user=config['user'],
            password=config['password']
        )
    
    def get_chunk_count(self) -> int:
        """Get total number of chunks stored.
        
        Returns:
            Integer count of embeddings in table
        """
        with self.conn.cursor() as cur:
            cur.execute(f'SELECT COUNT(*) FROM {self.table_name}')
            return cur.fetchone()[0]
    
    def similarity_search(self, query_embedding: List[float], top_n: int = 5) -> List[Tuple[str, float]]:
        """Find most similar chunks using cosine similarity (pgvector).
        
        Args:
            query_embedding: Query vector (list of floats)
            top_n: Number of results to return
        
        Returns:
            List of (chunk_text, similarity_score) tuples, sorted by similarity descending
        """
        with self.conn.cursor() as cur:
            cur.execute(f'''
                SELECT chunk_text, 1 - (embedding <=> %s::vector) as similarity
                FROM {self.table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            ''', (query_embedding, query_embedding, top_n))
            return [(chunk, float(score)) for chunk, score in cur.fetchall()]
    
    def get_chunk_ids_for_similarity_search(self, query_embedding: List[float], top_n: int = 5) -> List[int]:
        """Find most similar chunks and return their IDs (for metrics computation).
        
        Args:
            query_embedding: Query vector
            top_n: Number of results
        
        Returns:
            List of chunk IDs in order of similarity
        """
        with self.conn.cursor() as cur:
            cur.execute(f'''
                SELECT id
                FROM {self.table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            ''', (query_embedding, top_n))
            return [row[0] for row in cur.fetchall()]
    
    def close(self):
        """Close database connection."""
        if self.conn:
            self.conn.close()

print("Vector DB class loaded")

### Define Test Queries

These are representative questions that we'll use to test every model with the same inputs.
Good test queries are:
- **Diverse**: Cover different topics (science, history, geography, etc.)
- **Specific enough**: Not too vague ("What is X?" not "Tell me about things")
- **Real use cases**: Questions users would actually ask

The quality of comparison depends on representative test queries!

In [None]:
TEST_QUERIES = [
    "What is photosynthesis?",
    "When was World War 2?",
    "What is the capital of France?",
    "How does the human heart work?",
    "What is quantum mechanics?",
    "Who invented the telephone?",
    "What is climate change?",
    "Where is Mount Everest?",
    "What is the Roman Empire?",
    "How does DNA work?",
]

print(f"Test Query Set: {len(TEST_QUERIES)} queries")
print("\nQueries:")
for i, query in enumerate(TEST_QUERIES, 1):
    print(f"  {i:2}. {query}")

### Retrieve from All Models

For each model in the registry:
1. Create a PostgreSQLVectorDB instance pointing to that model's table
2. For each test query:
   - Generate query embedding
   - Retrieve top-5 chunks
   - Store results for metric computation

We track:
- Chunks retrieved (text and score)
- Chunk IDs for later relevance labeling
- Query latency (time to compute embedding + retrieve)
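The latency tracking described above can be sketched in isolation as follows. This is a hedged illustration: `time_queries` and `fake_retrieval` are hypothetical names, and the sleep stands in for the real embedding and search calls. It uses `time.perf_counter()`, a monotonic clock that is generally better suited to short intervals than `time.time()`.

```python
import time
from statistics import mean
from typing import Callable, Dict, List


def time_queries(run_query: Callable[[str], object], queries: List[str]) -> Dict[str, float]:
    """Run each query once and record its wall-clock latency in milliseconds."""
    latencies_ms = {}
    for query in queries:
        start = time.perf_counter()
        run_query(query)
        latencies_ms[query] = (time.perf_counter() - start) * 1000
    return latencies_ms


def fake_retrieval(query: str) -> str:
    """Stand-in for 'embed the query, then search' (hypothetical)."""
    time.sleep(0.01)
    return query.upper()


toy_latencies = time_queries(fake_retrieval, ["first query", "second query"])
print(f"avg latency: {mean(toy_latencies.values()):.1f} ms")
```

With only one measurement per query, the numbers are noisy. The first request to a model may also include one-off load time, so a warm-up query before timing is worth considering.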

In [None]:
def retrieve_from_all_models(
    available_models: pd.DataFrame,
    postgres_config: Dict,
    test_queries: List[str],
    top_n: int = 5
) -> Tuple[Dict, Dict]:
    """Retrieve top-N chunks from all available models for all test queries.
    
    Args:
        available_models: DataFrame from list_available_models()
        postgres_config: PostgreSQL connection config
        test_queries: List of test queries
        top_n: Number of chunks to retrieve per query
    
    Returns:
        Tuple of:
        - results: Dict[model_alias][query] = [(chunk_text, score), ...]
        - latencies: Dict[model_alias][query] = elapsed_time_ms
    """
    results = {}
    latencies = defaultdict(dict)
    
    print("\n" + "=" * 80)
    print("RETRIEVING FROM ALL MODELS")
    print("=" * 80)
    
    for _, row in available_models.iterrows():
        model_alias = row['model_alias']
        model_name = row['model_name']
        
        # Connect to this model's table
        table_name = f"embeddings_{model_alias.replace('.', '_')}"
        
        db = None
        try:
            db = PostgreSQLVectorDB(postgres_config, table_name)
            chunk_count = db.get_chunk_count()
            
            print(f"\nModel: {model_alias}")
            print(f"  Table: {table_name}")
            print(f"  Chunks: {chunk_count:,}")
            print(f"  Retrieving from {len(test_queries)} queries...")
            
            results[model_alias] = {}
            
            for query in test_queries:
                # Time the retrieval (embedding + search)
                start = time.time()
                
                # Generate query embedding
                query_embedding = ollama.embed(model=model_name, input=query)['embeddings'][0]
                
                # Retrieve top-N chunks
                retrieved = db.similarity_search(query_embedding, top_n=top_n)
                
                elapsed_ms = (time.time() - start) * 1000
                latencies[model_alias][query] = elapsed_ms
                
                # Store results
                results[model_alias][query] = retrieved
            
            # Summary for this model
            avg_latency = np.mean(list(latencies[model_alias].values()))
            print(f"  ✓ Complete (avg latency: {avg_latency:.0f}ms)")
            
        except Exception as e:
            print(f"  ✗ Error loading model: {e}")
        finally:
            # Always release the connection, even if retrieval failed midway
            if db is not None:
                db.close()
    
    return results, dict(latencies)

# Retrieve from all models
if len(available_models) > 0:
    retrieval_results, retrieval_latencies = retrieve_from_all_models(
        available_models,
        POSTGRES_CONFIG,
        TEST_QUERIES,
        top_n=5
    )
    print(f"\n✓ Retrieved results from {len(retrieval_results)} model(s)")
else:
    retrieval_results = {}
    retrieval_latencies = {}
    print("⚠️  Skipping retrieval (no models available)")

### View Sample Results

Let's look at what one model retrieved for a sample query.

In [None]:
if retrieval_results:
    sample_model = list(retrieval_results.keys())[0]
    sample_query = TEST_QUERIES[0]
    
    print(f"Sample: Model '{sample_model}' retrieving for query '{sample_query}'")
    print()
    
    retrieved = retrieval_results[sample_model][sample_query]
    for i, (chunk, score) in enumerate(retrieved, 1):
        # Extract title from chunk
        lines = chunk.split('\n')
        title = lines[0] if lines[0].startswith('Article:') else 'Unknown'
        
        # Show preview
        preview = chunk[:250].replace('\n', ' ') + '...'
        print(f"[{i}] Similarity: {score:.4f}")
        print(f"    {title}")
        print(f"    {preview}")
        print()

## Part 3: Compute Comparison Metrics

**What we're doing:** Define metric functions to measure retrieval quality.

### Information Retrieval Metrics

Since we don't have ground truth relevance labels yet, we'll use **overlap-based evaluation**:
- Assume chunks that appear in multiple models' results are likely relevant
- Measure **coverage**: Do both models retrieve similar relevant chunks?
- Measure **ranking quality**: Does each model rank relevant chunks high?
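As a rough sketch of the overlap idea, the snippet below marks a chunk as pseudo-relevant when at least `min_agreement` models retrieved it. The function name `overlap_relevant`, the model names, and the chunk labels are all hypothetical.

```python
from collections import Counter
from typing import Dict, List, Set


def overlap_relevant(chunks_by_model: Dict[str, List[str]], min_agreement: int = 2) -> Set[str]:
    """Return chunks retrieved by at least min_agreement different models."""
    counts = Counter()
    for chunks in chunks_by_model.values():
        counts.update(set(chunks))  # count each chunk at most once per model
    return {chunk for chunk, n in counts.items() if n >= min_agreement}


# Toy retrieval results for one query (made up for illustration)
toy_retrievals = {
    "model_a": ["c1", "c2", "c3"],
    "model_b": ["c2", "c3", "c4"],
    "model_c": ["c3", "c5", "c6"],
}
pseudo_relevant = overlap_relevant(toy_retrievals)
print(sorted(pseudo_relevant))  # ['c2', 'c3']
```

One caveat: with exactly two models and an agreement threshold of 2, the pseudo-relevant set is simply the intersection of both result lists. Both models then get the same precision, and recall is 1.0 whenever any overlap exists, so the scores measure agreement between the models rather than correctness.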

In a production system, you would:
1. Have human annotators label which chunks are relevant to each query
2. Use those labels to compute true precision/recall
3. Compare against that ground truth
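For comparison, here is a small worked example of what the metrics look like once labels exist. It is a sketch under stated assumptions: the chunk IDs and the "annotator" labels are invented, and the helper names are illustrative rather than the notebook's own functions. IDCG is computed for an ideal ranking that places `min(k, number of relevant chunks)` relevant items first.

```python
import math
from typing import List, Set


def precision_at_k(retrieved: List[int], relevant: Set[int], k: int) -> float:
    """Fraction of the top-k results that are relevant."""
    if k <= 0:
        return 0.0
    return sum(1 for cid in retrieved[:k] if cid in relevant) / k


def recall_at_k(retrieved: List[int], relevant: Set[int], k: int) -> float:
    """Fraction of all relevant chunks that appear in the top-k results."""
    if not relevant:
        return 0.0
    return sum(1 for cid in retrieved[:k] if cid in relevant) / len(relevant)


def reciprocal_rank(retrieved: List[int], relevant: Set[int]) -> float:
    """1 / position of the first relevant result, or 0.0 if none is found."""
    for position, cid in enumerate(retrieved, start=1):
        if cid in relevant:
            return 1.0 / position
    return 0.0


def ndcg_at_k(retrieved: List[int], relevant: Set[int], k: int) -> float:
    """Binary-relevance NDCG: DCG of the ranking divided by the ideal DCG."""
    dcg = sum(
        1.0 / math.log2(i + 2)
        for i, cid in enumerate(retrieved[:k])
        if cid in relevant
    )
    ideal_hits = min(len(relevant), k)
    idcg = sum(1.0 / math.log2(i + 2) for i in range(ideal_hits))
    return dcg / idcg if idcg > 0 else 0.0


# Hypothetical labels: an annotator marked chunks 42, 3 and 55 as relevant
toy_retrieved = [10, 42, 7, 99, 3]
toy_relevant = {42, 3, 55}

toy_precision = precision_at_k(toy_retrieved, toy_relevant, 5)
toy_recall = recall_at_k(toy_retrieved, toy_relevant, 5)
toy_rr = reciprocal_rank(toy_retrieved, toy_relevant)
toy_ndcg = ndcg_at_k(toy_retrieved, toy_relevant, 5)
print(f"P@5={toy_precision:.3f} R@5={toy_recall:.3f} RR={toy_rr:.3f} NDCG@5={toy_ndcg:.3f}")
```

Two of the five results are relevant, so Precision@5 is 0.4. Two of the three relevant chunks were found, so Recall@5 is about 0.667. The first hit is at position 2, giving a reciprocal rank of 0.5, and NDCG@5 comes out near 0.478 because both hits sit below where an ideal ranking would place them.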

### Metric Definitions

In [None]:
def compute_precision_at_k(
    retrieved_ids: List[int],
    relevant_ids: Set[int],
    k: int = 5
) -> float:
    """Compute Precision@K: what % of top-K results are relevant?
    
    Formula: |{relevant results in top-K}| / K
    
    Interpretation:
    - Precision = 0.8 means 80% of top-5 results are relevant
    - Good for: "How many of the results I saw were useful?"
    
    Args:
        retrieved_ids: List of chunk IDs returned by retrieval (in order)
        relevant_ids: Set of chunk IDs that are actually relevant
        k: Top-K threshold
    
    Returns:
        Float between 0 and 1
    """
    retrieved_k = retrieved_ids[:k]
    matches = sum(1 for chunk_id in retrieved_k if chunk_id in relevant_ids)
    return matches / k if k > 0 else 0.0


def compute_recall_at_k(
    retrieved_ids: List[int],
    relevant_ids: Set[int],
    k: int = 5
) -> float:
    """Compute Recall@K: what % of all relevant chunks did we find?
    
    Formula: |{relevant results in top-K}| / |all relevant items|
    
    Interpretation:
    - Recall = 0.6 means we found 60% of all relevant chunks in top-5
    - Good for: "Did I find all the useful results?"
    
    Args:
        retrieved_ids: List of chunk IDs returned by retrieval
        relevant_ids: Set of chunk IDs that are actually relevant
        k: Top-K threshold
    
    Returns:
        Float between 0 and 1 (or 0 if no relevant items exist)
    """
    if len(relevant_ids) == 0:
        return 0.0
    
    retrieved_k = retrieved_ids[:k]
    matches = sum(1 for chunk_id in retrieved_k if chunk_id in relevant_ids)
    return matches / len(relevant_ids)


def compute_mrr(
    retrieved_ids: List[int],
    relevant_ids: Set[int]
) -> float:
    """Compute the reciprocal rank for one query: where is the first relevant result?
    
    Formula: 1 / (position of first relevant result)
    
    Mean Reciprocal Rank (MRR) is the average of this value across all queries.
    
    Interpretation:
    - RR = 1.0 means first result is relevant (position 1, 1/1 = 1.0)
    - RR = 0.5 means first relevant result is at position 2 (1/2 = 0.5)
    - RR = 0.0 means no relevant result found
    - Good for: "How quickly did I find a relevant result?"
    
    Args:
        retrieved_ids: List of chunk IDs returned by retrieval
        relevant_ids: Set of chunk IDs that are actually relevant
    
    Returns:
        Float between 0 and 1
    """
    for position, chunk_id in enumerate(retrieved_ids, 1):
        if chunk_id in relevant_ids:
            return 1.0 / position
    return 0.0


def compute_ndcg_at_k(
    retrieved_ids: List[int],
    relevant_ids: Set[int],
    k: int = 5
) -> float:
    """Compute Normalized Discounted Cumulative Gain: ranking quality metric.
    
    NDCG measures how well-ranked the results are:
    - Relevant items should appear earlier (higher discount for lower positions)
    - Normalized against an ideal ranking (best possible score = 1.0)
    
    Formula:
    - DCG = sum of (relevance / log2(position + 1)) for each result
    - IDCG = DCG of perfect ranking (all relevant items first)
    - NDCG = DCG / IDCG
    
    Interpretation:
    - NDCG = 0.9 means ranking is 90% as good as ideal
    - Penalizes relevant items appearing low in ranking more than Precision/Recall
    - Good for: "How well-ordered are the results?"
    
    Args:
        retrieved_ids: List of chunk IDs returned by retrieval
        relevant_ids: Set of chunk IDs that are actually relevant
        k: Top-K threshold
    
    Returns:
        Float between 0 and 1
    """
    def dcg_score(relevances: List[int]) -> float:
        """Compute Discounted Cumulative Gain."""
        return sum((rel) / math.log2(i + 2) for i, rel in enumerate(relevances))
    
    # Get top-K results
    retrieved_k = retrieved_ids[:k]
    
    # Binary relevance: 1 if relevant, 0 if not
    relevances = [1 if chunk_id in relevant_ids else 0 for chunk_id in retrieved_k]
    
    # Compute actual DCG
    dcg = dcg_score(relevances)
    
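    # Compute ideal DCG: a perfect ranking places every relevant item first,
    # so the ideal list holds min(k, number of relevant items) ones.
    # (Sorting only the retrieved relevances would ignore relevant items we missed.)
    ideal_relevances = [1] * min(len(relevant_ids), k)
    idcg = dcg_score(ideal_relevances)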
    
    # Normalize
    if idcg == 0:
        return 0.0
    
    return dcg / idcg


def compute_overlap_based_metrics(
    retrieval_results: Dict,
    top_n: int = 5,
    min_model_agreement: int = 2
) -> Dict:
    """Compute metrics using overlap-based relevance (chunks retrieved by multiple models).
    
    Since we don't have ground truth labels, we assume:
    - Chunks retrieved by 2+ models are likely relevant (high overlap = high quality)
    - Chunks retrieved by only 1 model are less reliable
    
    Args:
        retrieval_results: Dict[model][query] = [(chunk_text, score), ...]
        top_n: Top-N threshold for metrics
        min_model_agreement: Minimum number of models that must retrieve a chunk
    
    Returns:
        Dict of computed metrics per model per query
    """
    metrics = {}
    
    # Only relevant if we have multiple models
    if len(retrieval_results) < 2:
        print("⚠️  Overlap-based metrics require 2+ models. Using alternative evaluation...")
        # Single model: use keyword-based relevance instead
        return compute_keyword_based_metrics(retrieval_results)
    
    # For each query, find chunks retrieved by multiple models
    for query_idx, query in enumerate(TEST_QUERIES):
        # Collect chunk texts per model
        chunks_by_model = {}
        for model_alias in retrieval_results:
            retrieved = retrieval_results[model_alias].get(query, [])
            chunks_by_model[model_alias] = [chunk_text for chunk_text, _ in retrieved]
        
        # Find overlap: chunks retrieved by 2+ models
        relevant_chunks = set()
        chunk_counts = {}
        
        for model_alias, chunks in chunks_by_model.items():
            # Use a set so a chunk repeated within one model's results counts once
            for chunk in set(chunks):
                chunk_counts[chunk] = chunk_counts.get(chunk, 0) + 1
        
        # Mark chunks as relevant if retrieved by min_model_agreement or more models
        for chunk, count in chunk_counts.items():
            if count >= min_model_agreement:
                relevant_chunks.add(chunk)
        
        # Compute metrics for each model
        for model_alias in retrieval_results:
            if query not in metrics:
                metrics[query] = {}
            
            retrieved = retrieval_results[model_alias].get(query, [])
            retrieved_chunks = [chunk_text for chunk_text, _ in retrieved]
            
            # Count matches
            matches = sum(1 for chunk in retrieved_chunks if chunk in relevant_chunks)
            
            # Compute metrics
            precision = matches / top_n if top_n > 0 else 0.0
            recall = matches / len(relevant_chunks) if len(relevant_chunks) > 0 else 0.0
            
            if model_alias not in metrics[query]:
                metrics[query][model_alias] = {}
            
            metrics[query][model_alias] = {
                'precision@5': precision,
                'recall@5': recall,
                'overlap': matches
            }
    
    return metrics


def compute_keyword_based_metrics(retrieval_results: Dict) -> Dict:
    """Compute metrics using simple keyword matching as proxy for relevance.
    
    This is a fallback when we have only 1 model. We measure relevance as:
    - Does the retrieved chunk contain keywords from the query?
    - Is it from a specific/authoritative source?
    
    Args:
        retrieval_results: Dict[model][query] = [(chunk_text, score), ...]
    
    Returns:
        Dict of metrics per model per query
    """
    metrics = {}
    
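    # Common words to ignore; without this, words like "is" match almost every chunk
    stopwords = {'what', 'is', 'the', 'of', 'a', 'an', 'when', 'was',
                 'where', 'who', 'how', 'does', 'do', 'in', 'to'}
    
    for query in TEST_QUERIES:
        # Strip punctuation and drop stopwords to keep only content keywords
        query_words = {word.strip('?.,!').lower() for word in query.split()}
        query_words = query_words - stopwords - {''}
        metrics[query] = {}
        
        for model_alias in retrieval_results:
            retrieved = retrieval_results[model_alias].get(query, [])
            
            # Count chunks with query keyword matches
            matches = 0
            for chunk_text, score in retrieved:
                chunk_lower = chunk_text.lower()
                # Check if any content keyword appears in chunk (substring match)
                if any(word in chunk_lower for word in query_words):
                    matches += 1
            
            precision = matches / len(retrieved) if len(retrieved) > 0 else 0.0
            
            metrics[query][model_alias] = {
                'precision@5': precision,
                # True recall needs the full set of relevant chunks, which we
                # don't have here; precision is reused as a placeholder
                'recall@5': precision,
                'similarity_avg': float(np.mean([score for _, score in retrieved])) if retrieved else 0.0
            }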
    
    return metrics


print("Metric functions defined")

### Compute Metrics for All Models and Queries

Now we'll compute metrics using overlap-based relevance (for 2+ models) or keyword-based (for 1 model).

In [None]:
if retrieval_results:
    if len(retrieval_results) >= 2:
        print("Computing overlap-based metrics (2+ models available)...")
    else:
        print("Computing keyword-based metrics (single model)...")
    
    metrics_by_query = compute_overlap_based_metrics(retrieval_results, top_n=5, min_model_agreement=2)
    
    # Aggregate metrics across all queries per model
    metrics_by_model = defaultdict(lambda: defaultdict(list))
    
    for query, model_metrics in metrics_by_query.items():
        for model_alias, metrics_dict in model_metrics.items():
            for metric_name, value in metrics_dict.items():
                metrics_by_model[model_alias][metric_name].append(value)
    
    # Compute averages
    average_metrics = {}
    for model_alias, metrics_dict in metrics_by_model.items():
        average_metrics[model_alias] = {}
        for metric_name, values in metrics_dict.items():
            average_metrics[model_alias][metric_name] = np.mean(values) if values else 0.0
    
    print("Metrics computed")
else:
    metrics_by_query = {}
    average_metrics = {}
    print("⚠️  Skipping metrics (no retrieval results)")

### Display Per-Query Metrics

In [None]:
if metrics_by_query:
    print("\n" + "=" * 80)
    print("METRICS BY QUERY AND MODEL")
    print("=" * 80)
    
    for query in TEST_QUERIES[:3]:
        if query in metrics_by_query:
            print(f"\nQuery: '{query}'")
            query_metrics = metrics_by_query[query]
            
            for model_alias, metrics_dict in query_metrics.items():
                print(f"  {model_alias}:")
                for metric_name, value in metrics_dict.items():
                    if isinstance(value, float):
                        print(f"    {metric_name}: {value:.4f}")
                    else:
                        print(f"    {metric_name}: {value}")
    
    print(f"\n(Showing first 3 of {len(TEST_QUERIES)} queries)")

### Aggregate Metrics Summary

In [None]:
if average_metrics:
    print("\n" + "=" * 80)
    print("AVERAGE METRICS ACROSS ALL QUERIES")
    print("=" * 80)
    print()
    
    metrics_df = pd.DataFrame(average_metrics).T
    print(metrics_df.to_string())
    
    # Add latency info
    if retrieval_latencies:
        print("\n" + "=" * 80)
        print("QUERY LATENCY (ms)")
        print("=" * 80)
        print()
        
        # Columns are models and rows are queries, so mean() averages per model
        latency_df = pd.DataFrame(retrieval_latencies)
        latency_summary = latency_df.mean()
        
        for model_alias in latency_summary.index:
            avg_latency = latency_summary[model_alias]
            print(f"  {model_alias}: {avg_latency:.1f} ms")

## Part 4: Visualize Comparisons

**What we're doing:** Create visualizations to compare models clearly.

We'll create:
1. **Bar chart**: Metrics comparison (Precision@5, Recall@5)
2. **Speed vs Quality scatter**: Latency vs Precision trade-off
3. **Metrics table**: Side-by-side comparison

In [None]:
def plot_metrics_comparison(average_metrics: Dict, retrieval_latencies: Dict):
    """Create bar chart comparing key metrics across models.
    
    Args:
        average_metrics: Dict[model_alias] = {metric_name: value}
        retrieval_latencies: Dict[model_alias] = {query: latency_ms}
    """
    if not average_metrics:
        print("No metrics to visualize")
        return
    
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    
    models = list(average_metrics.keys())
    
    # Plot 1: Precision@5 comparison
    if 'precision@5' in list(average_metrics.values())[0]:
        precision_values = [average_metrics[m].get('precision@5', 0) for m in models]
        
        axes[0].bar(models, precision_values, color='steelblue', alpha=0.8)
        axes[0].set_ylabel('Precision@5', fontsize=12, fontweight='bold')
        axes[0].set_title('Retrieval Quality: Precision@5', fontsize=13, fontweight='bold')
        axes[0].set_ylim(0, 1.0)
        axes[0].grid(axis='y', alpha=0.3)
        
        # Add value labels on bars
        for i, v in enumerate(precision_values):
            axes[0].text(i, v + 0.02, f'{v:.3f}', ha='center', fontweight='bold')
    
    # Plot 2: Average Latency comparison
    if retrieval_latencies:
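        # Keep model names and latency values aligned even if a model has no latencies
        latency_models = [m for m in models if m in retrieval_latencies]
        latency_values = [
            np.mean(list(retrieval_latencies[m].values()))
            for m in latency_models
        ]
        
        axes[1].bar(latency_models, latency_values, color='coral', alpha=0.8)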
        axes[1].set_ylabel('Latency (ms)', fontsize=12, fontweight='bold')
        axes[1].set_title('Query Speed: Average Latency', fontsize=13, fontweight='bold')
        axes[1].grid(axis='y', alpha=0.3)
        
        # Add value labels
        for i, v in enumerate(latency_values):
            axes[1].text(i, v + 5, f'{v:.0f}ms', ha='center', fontweight='bold')
    
    plt.tight_layout()
    plt.show()

if average_metrics:
    plot_metrics_comparison(average_metrics, retrieval_latencies)

### Speed vs Quality Trade-off

Plot each model as a point showing the trade-off between latency and quality.
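One way to read such a plot numerically is Pareto dominance: a model is dominated if another model is at least as fast and at least as accurate, and strictly better on one of the two. This is an extra technique for illustration, not something the notebook computes. The function name `pareto_front` and the latency/precision values below are made up.

```python
from typing import Dict, List, Tuple


def pareto_front(points: Dict[str, Tuple[float, float]]) -> List[str]:
    """Return models not dominated by any other model.

    points maps model name -> (latency_ms, quality).
    Lower latency is better; higher quality is better.
    """
    front = []
    for name, (latency, quality) in points.items():
        dominated = any(
            other != name
            and o_latency <= latency
            and o_quality >= quality
            and (o_latency < latency or o_quality > quality)
            for other, (o_latency, o_quality) in points.items()
        )
        if not dominated:
            front.append(name)
    return front


# Invented numbers, purely for illustration
toy_points = {
    "model_small": (40.0, 0.70),
    "model_base": (90.0, 0.85),
    "model_slow": (120.0, 0.65),
}
toy_front = pareto_front(toy_points)
print(toy_front)  # ['model_small', 'model_base']
```

Models on the front represent genuine trade-offs (faster versus more accurate), while a dominated model such as `model_slow` here has no scenario in which it is the better choice on these two axes.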

In [None]:
def plot_speed_quality_tradeoff(average_metrics: Dict, retrieval_latencies: Dict):
    """Create scatter plot showing speed vs quality trade-off.
    
    Each model is a point:
    - X-axis: Average latency (ms) - lower is faster
    - Y-axis: Average Precision@5 - higher is better
    
    The ideal model is top-left (fast + accurate).
    """
    if not average_metrics or not retrieval_latencies:
        print("Need both metrics and latencies to plot trade-off")
        return
    
    models = list(average_metrics.keys())
    
    # Extract data
    latencies = []
    qualities = []
    
    for model in models:
        if model in retrieval_latencies:
            avg_latency = np.mean(list(retrieval_latencies[model].values()))
            latencies.append(avg_latency)
            
            quality = average_metrics[model].get('precision@5', 0)
            qualities.append(quality)
    
    # Create scatter plot
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Plot points
    scatter = ax.scatter(
        latencies,
        qualities,
        s=300,
        alpha=0.6,
        c=range(len(latencies)),
        cmap='viridis',
        edgecolors='black',
        linewidth=2
    )
    
    # Annotate with model names
    for i, model in enumerate(models):
        if i < len(latencies):
            ax.annotate(
                model,
                (latencies[i], qualities[i]),
                xytext=(10, 10),
                textcoords='offset points',
                fontsize=10,
                fontweight='bold',
                bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.3)
            )
    
    ax.set_xlabel('Latency (ms)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Quality (Precision@5)', fontsize=12, fontweight='bold')
    ax.set_title('Speed vs Quality Trade-off', fontsize=14, fontweight='bold')
    ax.grid(True, alpha=0.3)
    
    # Mark the ideal zone: top-left = low latency + high precision
    ax.text(
        0.02, 0.98,
        'IDEAL\n(Fast + Accurate)',
        transform=ax.transAxes,
        fontsize=10,
        ha='left',
        va='top',
        bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.5)
    )
    
    plt.tight_layout()
    plt.show()

if len(average_metrics) > 0 and retrieval_latencies:
    plot_speed_quality_tradeoff(average_metrics, retrieval_latencies)

### Comprehensive Metrics Table

Display a side-by-side comparison table.

In [None]:
def create_comparison_table(average_metrics: Dict, retrieval_latencies: Dict) -> pd.DataFrame:
    """Create comprehensive comparison table with all metrics.
    
    Args:
        average_metrics: Dict[model_alias] = {metric_name: value}
        retrieval_latencies: Dict[model_alias] = {query: latency_ms}
    
    Returns:
        DataFrame with models as rows, metrics as columns
    """
    rows = []
    
    for model_alias in average_metrics.keys():
        row = {'Model': model_alias}
        
        # Add metrics
        for metric_name, value in average_metrics[model_alias].items():
            if isinstance(value, float):
                row[metric_name.title()] = f"{value:.4f}"
            else:
                row[metric_name.title()] = value
        
        # Add latency
        if model_alias in retrieval_latencies:
            avg_latency = np.mean(list(retrieval_latencies[model_alias].values()))
            row['Avg Latency (ms)'] = f"{avg_latency:.1f}"
        
        rows.append(row)
    
    return pd.DataFrame(rows)

if average_metrics:
    print("\n" + "=" * 80)
    print("COMPREHENSIVE COMPARISON TABLE")
    print("=" * 80)
    print()
    
    comparison_df = create_comparison_table(average_metrics, retrieval_latencies)
    print(comparison_df.to_string(index=False))

## Part 5: Analysis and Recommendations

**What we're doing:** Analyze the comparison results and provide actionable insights.

Key questions to answer:
1. Which model has best quality (precision/recall)?
2. Which model is fastest?
3. What are the speed-quality trade-offs?
4. For different use cases, which model should we recommend?
5. Are there queries where models differ significantly?
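Question 5 can be approached by ranking queries by the gap between two models on a chosen metric. The sketch below assumes a per-query metrics dictionary shaped like `Dict[query][model_alias] = {metric_name: value}`. The function name `largest_gaps`, the model names, and the values are hypothetical.

```python
from typing import Dict, List, Tuple


def largest_gaps(
    metrics_by_query: Dict[str, Dict[str, Dict[str, float]]],
    model_a: str,
    model_b: str,
    metric: str = "precision@5",
    top: int = 3,
) -> List[Tuple[str, float]]:
    """Return the queries with the largest absolute metric difference (a minus b)."""
    gaps = []
    for query, per_model in metrics_by_query.items():
        if model_a in per_model and model_b in per_model:
            diff = per_model[model_a].get(metric, 0.0) - per_model[model_b].get(metric, 0.0)
            gaps.append((query, diff))
    gaps.sort(key=lambda item: abs(item[1]), reverse=True)
    return gaps[:top]


# Invented per-query scores, purely for illustration
toy_metrics = {
    "query one": {"model_a": {"precision@5": 0.8}, "model_b": {"precision@5": 0.6}},
    "query two": {"model_a": {"precision@5": 0.4}, "model_b": {"precision@5": 1.0}},
    "query three": {"model_a": {"precision@5": 0.6}, "model_b": {"precision@5": 0.6}},
}
toy_gaps = largest_gaps(toy_metrics, "model_a", "model_b", top=2)
for query, diff in toy_gaps:
    print(f"{diff:+.2f}  {query}")
```

A positive gap favors the first model and a negative gap favors the second. This kind of breakdown is most informative when relevance labels are independent of the models being compared; with purely overlap-derived labels from two models, the per-query precision values coincide and every gap is zero.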

In [None]:
def analyze_and_recommend(
    available_models: pd.DataFrame,
    average_metrics: Dict,
    retrieval_latencies: Dict,
    metrics_by_query: Dict
):
    """Analyze comparison results and provide recommendations.
    
    Args:
        available_models: DataFrame of available models
        average_metrics: Aggregated metrics
        retrieval_latencies: Query latencies
        metrics_by_query: Per-query metrics
    """
    print("\n" + "=" * 80)
    print("ANALYSIS & RECOMMENDATIONS")
    print("=" * 80)
    
    if not average_metrics:
        print("\n⚠️  Not enough data for analysis")
        return
    
    models = list(average_metrics.keys())
    print(f"\nSummary: {len(models)} model(s) compared on {len(TEST_QUERIES)} test queries\n")
    
    # 1. Quality ranking
    print("1. QUALITY RANKING (by Precision@5):")
    quality_ranking = sorted(
        [(m, average_metrics[m].get('precision@5', 0)) for m in models],
        key=lambda x: x[1],
        reverse=True
    )
    
    for rank, (model, precision) in enumerate(quality_ranking, 1):
        print(f"   {rank}. {model}: {precision:.4f}")
    
    # 2. Speed ranking (stays empty when no model has recorded latencies)
    speed_ranking = []
    if retrieval_latencies:
        speed_ranking = sorted(
            [
                (m, np.mean(list(retrieval_latencies[m].values())))
                for m in models
                if retrieval_latencies.get(m)
            ],
            key=lambda x: x[1]
        )
    
    if speed_ranking:
        print("\n2. SPEED RANKING (by latency, lower is better):")
        for rank, (model, latency) in enumerate(speed_ranking, 1):
            print(f"   {rank}. {model}: {latency:.1f}ms")
    
    # 3. Best model overall
    if len(models) > 1:
        best_quality = quality_ranking[0][0]
        fastest = speed_ranking[0][0] if speed_ranking else None
        
        print("\n3. OVERALL WINNERS:")
        print(f"   Best Quality: {best_quality} (Precision@5: {quality_ranking[0][1]:.4f})")
        if fastest:
            print(f"   Fastest: {fastest} ({speed_ranking[0][1]:.1f}ms per query)")
    
    # 4. Trade-off analysis
    if len(models) > 1:
        print(f"\n4. TRADE-OFF ANALYSIS:")
        
        # Calculate quality/latency ratio (efficiency)
        efficiency = {}
        for model in models:
            if model in retrieval_latencies:
                quality = average_metrics[model].get('precision@5', 0)
                latency = np.mean(list(retrieval_latencies[model].values()))
                if latency > 0:
                    efficiency[model] = quality / latency
        
        if efficiency:
            best_efficiency = max(efficiency.items(), key=lambda x: x[1])
            print(f"   Best Efficiency (Quality/Latency): {best_efficiency[0]}")
            print(f"     Precision@5 per millisecond: {best_efficiency[1]:.6f}")
    
    # 5. Per-query analysis
    print(f"\n5. PER-QUERY VARIABILITY:")
    
    if metrics_by_query and len(models) > 1:
        # Find queries where models differ most
        query_differences = []
        
        for query, query_metrics in metrics_by_query.items():
            if len(query_metrics) >= 2:
                precisions = [query_metrics[m].get('precision@5', 0) for m in models if m in query_metrics]
                if precisions:
                    diff = max(precisions) - min(precisions)
                    query_differences.append((query, diff, max(precisions), min(precisions)))
        
        if query_differences:
            # Sort by difference
            query_differences.sort(key=lambda x: x[1], reverse=True)
            
            print(f"\n   Queries with biggest model differences:")
            for query, diff, best, worst in query_differences[:3]:
                print(f"   - '{query}'")
                print(f"     Best: {best:.4f}, Worst: {worst:.4f}, Diff: {diff:.4f}")
        else:
            print("   No per-query differences to report")
    else:
        print("   Requires per-query metrics for at least two models")
    
    # 6. Recommendations
    print(f"\n6. RECOMMENDATIONS:")
    
    if len(models) == 1:
        print(f"   - Only 1 model available for comparison")
        print(f"   - Generate a second model to enable comparative analysis")
        print(f"   - See instructions at top of notebook")
    else:
        best_quality_model = quality_ranking[0][0]
        # Guard against an empty ranking (no latencies recorded for these models)
        fastest_model = speed_ranking[0][0] if retrieval_latencies and speed_ranking else None
        
        print(f"\n   Use-Case Recommendations:")
        print(f"\n   a) For Maximum Quality:")
        print(f"      - Use: {best_quality_model}")
        print(f"      - Precision@5: {quality_ranking[0][1]:.4f}")
        print(f"      - Best for: Applications requiring high accuracy (e.g., research, legal)")
        
        if fastest_model:
            print(f"\n   b) For Maximum Speed:")
            print(f"      - Use: {fastest_model}")
            print(f"      - Latency: {speed_ranking[0][1]:.1f}ms per query")
            print(f"      - Best for: Real-time applications, user-facing chatbots")
        
        if efficiency:
            best_eff_model = max(efficiency.items(), key=lambda x: x[1])[0]
            print(f"\n   c) For Best Efficiency (Quality/Latency):")
            print(f"      - Use: {best_eff_model}")
            print(f"      - Best for: Balanced production systems")
    
    print(f"\n" + "=" * 80)

analyze_and_recommend(available_models, average_metrics, retrieval_latencies, metrics_by_query)

## Next Steps

Now that you've compared embedding models:

1. **Advanced Techniques (intermediate/05+)**: Use your chosen model to apply retrieval improvements
   - Query expansion
   - Reranking
   - Hybrid search (vector + keyword)

2. **Evaluation Lab**: Compare your techniques against the baseline
   - Track metrics over time
   - A/B test different approaches
   - Measure real user satisfaction

3. **Model Fine-tuning**: If off-the-shelf models fall short on your domain
   - Collect domain-specific training data
   - Fine-tune an embedding model on your data
   - Compare against base models

4. **Production Deployment**: Use your recommendation
   - Implement the recommended model
   - Monitor quality metrics in production
   - Update if better models become available
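
As a starting point for comparing a later technique against the baseline, the sketch below computes per-metric deltas between two sets of averaged metrics. It assumes each run is a flat `{metric_name: value}` dict, like one model's entry in `average_metrics`. `metric_deltas`, the metric names other than `precision@5`, and all values are hypothetical.

```python
from typing import Dict


def metric_deltas(
    baseline: Dict[str, float], candidate: Dict[str, float]
) -> Dict[str, float]:
    """Return candidate minus baseline for every metric both runs report."""
    shared = sorted(set(baseline) & set(candidate))
    return {name: candidate[name] - baseline[name] for name in shared}


# Hypothetical averaged metrics, for illustration only
baseline_run = {"precision@5": 0.60, "recall@5": 0.50, "mrr": 0.70}
candidate_run = {"precision@5": 0.68, "recall@5": 0.50, "mrr": 0.75}

deltas = metric_deltas(baseline_run, candidate_run)
for name, delta in deltas.items():
    # Positive values mean the candidate improved on the baseline
    print(f"{name}: {delta:+.4f}")
```

Metrics reported by only one of the two runs are skipped, so the comparison stays valid if a later notebook adds or drops a metric.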

## Cleanup

In [None]:
# Close the database connection if one is still open
if 'conn' in globals() and conn is not None and not conn.closed:
    conn.close()
    print("Connection closed")
else:
    print("No open connection to close")