## Prerequisites

Before running this notebook:

1. ✅ Run `foundation/00-setup-postgres-schema.ipynb` to create tables
2. ✅ Run `foundation/02-rag-postgresql-persistent.ipynb` to generate embeddings
3. ✅ Run `evaluation-lab/01-create-ground-truth-human-in-loop.ipynb` to create test queries

These setup notebooks populate the embedding registry and create the ground-truth test queries used for evaluation below.

## Configuration

Adjust these parameters to experiment with different strategies:

In [None]:
# Configuration
# NOTE(review): EMBEDDING_MODEL_ALIAS is used three ways in this notebook:
# as the embedding_registry lookup key, as the source of the embeddings
# table name (only '.' is replaced with '_'), and as the model name passed
# to ollama.embed(). Confirm one string is valid for all three — hyphens are
# not legal in an unquoted PostgreSQL table name.
EMBEDDING_MODEL_ALIAS = "all-minilm-l6-v2"  # From registry
# Cross-encoder used to rescore the dense-retrieval candidates.
RERANKER_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"  # HuggingFace model
# Candidates fetched by vector search before reranking.
TOP_K_INITIAL = 20  # How many initial results to get
# Results kept after reranking; also the k of the vector-only baseline.
TOP_K_FINAL = 5     # How many to rerank and return
# Intended cross-encoder batch size; recorded in the experiment config.
RERANKING_BATCH_SIZE = 32

# Experiment tracking: name and technique tags stored with each run.
EXPERIMENT_NAME = "reranking-cross-encoder"
TECHNIQUES_APPLIED = ["vector_retrieval", "cross_encoder_reranking"]

## Load Embeddings from Registry

This section demonstrates the registry discovery pattern: list the embeddings already registered, then load an existing set instead of regenerating it.

In [None]:
import psycopg2
import psycopg2.extras
import ollama
import json
import pandas as pd
import numpy as np
import hashlib
from datetime import datetime
from typing import List, Dict, Tuple, Optional
import os

# PostgreSQL connection settings. Each value can be overridden with an
# environment variable so credentials need not be edited into the notebook;
# the defaults match a local development server.
POSTGRES_CONFIG = {
    'host': os.environ.get('POSTGRES_HOST', 'localhost'),
    'port': int(os.environ.get('POSTGRES_PORT', '5432')),
    'database': os.environ.get('POSTGRES_DB', 'rag_db'),
    'user': os.environ.get('POSTGRES_USER', 'postgres'),
    'password': os.environ.get('POSTGRES_PASSWORD', 'postgres'),
}

# Create the shared database connection used by the registry, ground-truth
# and experiment-tracking cells. Re-raise on failure: nothing below works
# without it.
try:
    db_connection = psycopg2.connect(**POSTGRES_CONFIG)
    print("✓ Connected to PostgreSQL")
except psycopg2.OperationalError as e:
    print(f"✗ Failed to connect to PostgreSQL: {e}")
    raise

# ============================================================================
# PART 1: REGISTRY DISCOVERY & LOAD-OR-GENERATE PATTERN
# ============================================================================
# Inline from foundation/00-registry-and-tracking-utilities.ipynb

def list_available_embeddings(db_connection) -> pd.DataFrame:
    """Return every embedding_registry row as a DataFrame, newest first.

    Args:
        db_connection: An open connection that ``pd.read_sql`` accepts.

    Returns:
        DataFrame with one row per registered model and the columns
        model_alias, model_name, dimension, embedding_count,
        chunk_source_dataset, chunk_size_config, created_at and
        last_accessed, ordered by created_at descending.
    """
    registry_sql = '''
        SELECT
            model_alias,
            model_name,
            dimension,
            embedding_count,
            chunk_source_dataset,
            chunk_size_config,
            created_at,
            last_accessed
        FROM embedding_registry
        ORDER BY created_at DESC
    '''
    registry_frame = pd.read_sql(registry_sql, db_connection)
    return registry_frame


def get_embedding_metadata(db_connection, model_alias: str) -> Optional[Dict]:
    """Look up a single model's row in embedding_registry.

    Args:
        db_connection: psycopg2-style PostgreSQL connection (``%s``
            placeholders, cursor usable as a context manager).
        model_alias: Registry key of the model (e.g., 'bge_base_en_v1_5').

    Returns:
        None when no row matches ``model_alias``. Otherwise a dict with the
        keys dimension, embedding_count, chunk_source_dataset,
        chunk_size_config, created_at and metadata_json; metadata_json is
        replaced by an empty dict when the stored value is NULL/falsy.
    """
    field_names = (
        'dimension',
        'embedding_count',
        'chunk_source_dataset',
        'chunk_size_config',
        'created_at',
        'metadata_json',
    )

    with db_connection.cursor() as cur:
        cur.execute('''
            SELECT
                dimension,
                embedding_count,
                chunk_source_dataset,
                chunk_size_config,
                created_at,
                metadata_json
            FROM embedding_registry
            WHERE model_alias = %s
        ''', (model_alias,))
        row = cur.fetchone()

    if not row:
        return None

    record = dict(zip(field_names, row))
    record['metadata_json'] = record['metadata_json'] or {}
    return record


class PostgreSQLVectorDB:
    """Read-only access to one pgvector embeddings table.

    Each instance opens its own psycopg2 connection from ``config``; it does
    not reuse a connection the caller already holds. Only SELECT queries are
    issued, so stored embeddings are never modified or regenerated.

    Args:
        config: Mapping with 'host', 'port', 'database', 'user', 'password'.
        table_name: Embeddings table to query. It is interpolated directly
            into SQL, so it must be a trusted, valid identifier.
        preserve_existing: Accepted for call compatibility; not used.
    """

    def __init__(self, config, table_name, preserve_existing=True):
        self.config = config
        self.table_name = table_name
        connect_kwargs = {
            key: config[key]
            for key in ('host', 'port', 'database', 'user', 'password')
        }
        self.conn = psycopg2.connect(**connect_kwargs)
        print(f'✓ Connected to table: {table_name}')

    def get_chunk_count(self):
        """Return how many rows (stored embeddings) the table holds."""
        count_sql = f'SELECT COUNT(*) FROM {self.table_name}'
        with self.conn.cursor() as cur:
            cur.execute(count_sql)
            first_row = cur.fetchone()
        return first_row[0]

    def similarity_search(self, query_embedding, top_n=5):
        """Return the ``top_n`` chunks nearest to ``query_embedding``.

        Rows are ordered by pgvector's ``<=>`` (cosine distance) operator,
        and the reported similarity is ``1 - distance``, so higher means
        closer.

        Args:
            query_embedding: Query vector; cast to ``vector`` in SQL.
            top_n: Maximum number of rows to return.

        Returns:
            List of (chunk_text, similarity_score, chunk_id) tuples, nearest
            first.
        """
        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f'''
                SELECT id,
                       content as chunk_text,
                       1 - (embedding <=> %s::vector) as similarity
                FROM {self.table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            ''', (query_embedding, query_embedding, top_n))

            return [
                (record['chunk_text'], record['similarity'], record['id'])
                for record in cur.fetchall()
            ]

    def close(self):
        """Close the underlying connection if there is one."""
        if not self.conn:
            return
        self.conn.close()


def load_or_generate(db_connection, embedding_model_alias, preserve_existing=True):
    """Load embeddings from registry OR show instructions if not available.

    This is the CORE PATTERN for fast iteration: check the registry first and
    open the existing embeddings table instead of regenerating it (the banner
    printed below estimates <1 second to load vs. 50+ minutes to regenerate).

    Args:
        db_connection: PostgreSQL connection used to query embedding_registry.
            On a query error it is rolled back so it stays usable.
        embedding_model_alias: Registry key (e.g., 'all_minilm_l6_v2'). The
            table opened is ``embeddings_<alias>`` with '.' replaced by '_'.
        preserve_existing: If truthy, open the existing embeddings. If False
            or None, nothing is loaded: regeneration is a manual step
            (foundation/02-rag-postgresql-persistent.ipynb) and no interactive
            prompt is implemented, so instructions are printed and None is
            returned.

    Returns:
        A PostgreSQLVectorDB ready for retrieval, or None when the registry
        cannot be queried, has no entry for the alias, the table cannot be
        opened, or preserve_existing is falsy.
    """
    print(f"\n{'='*70}")
    print(f"Checking for embeddings: '{embedding_model_alias}'...")
    print(f"{'='*70}\n")

    try:
        with db_connection.cursor() as cur:
            cur.execute('''
                SELECT id, dimension, embedding_count, created_at, metadata_json
                FROM embedding_registry
                WHERE model_alias = %s
            ''', (embedding_model_alias,))
            registry_entry = cur.fetchone()
    except Exception as e:
        # A failed statement leaves a psycopg2 connection in an aborted
        # transaction, and every later query on it fails until a rollback.
        # This connection is shared with later cells, so reset it here.
        try:
            db_connection.rollback()
        except Exception as rollback_error:
            print(f"Rollback after registry error also failed: {rollback_error}")
        print(f"Could not query registry: {e}")
        print("Make sure foundation/00-setup-postgres-schema.ipynb has been run.")
        return None

    # Case B: no registry entry for this alias.
    if not registry_entry:
        print("✗ NO EMBEDDINGS FOUND")
        print(f"  Model: {embedding_model_alias}")
        print("\nTo create embeddings, run:")
        print("  foundation/02-rag-postgresql-persistent.ipynb")
        print("\nThen come back and re-run this cell.\n")
        return None

    # Case A: embeddings exist.
    _, dimension, embedding_count, created_at, _ = registry_entry
    # embedding_count may be NULL in the registry; the ',' format spec cannot
    # format None, so show a placeholder instead of crashing.
    count_text = f"{embedding_count:,}" if embedding_count is not None else "unknown"

    print("✓ FOUND EXISTING EMBEDDINGS")
    print(f"  Model:      {embedding_model_alias}")
    print(f"  Count:      {count_text} embeddings")
    print(f"  Dimension:  {dimension}")
    print(f"  Created:    {created_at}")
    print("\n  TIME SAVINGS:")
    print("    Loading:       <1 second")
    print("    Regenerating:  ~50+ minutes")
    print("    ➜ You save 50+ minutes by loading!\n")

    if not preserve_existing:
        # Previously this path returned None silently.
        print("preserve_existing is not set, so existing embeddings were not loaded.")
        print("Regeneration is a manual step: run foundation/02-rag-postgresql-persistent.ipynb,")
        print("or call load_or_generate(..., preserve_existing=True) to load them.\n")
        return None

    # Auto-load (for scripts/notebooks)
    print("Loading existing embeddings...\n")
    db_instance = None
    try:
        # NOTE(review): only '.' is replaced; an alias containing '-' yields a
        # table name that is not a valid unquoted PostgreSQL identifier.
        table_name = f'embeddings_{embedding_model_alias.replace(".", "_")}'

        db_instance = PostgreSQLVectorDB(
            config=POSTGRES_CONFIG,
            table_name=table_name,
            preserve_existing=True
        )
        count = db_instance.get_chunk_count()
    except Exception as e:
        # Do not leak the dedicated connection if it opened but counting failed.
        if db_instance is not None:
            db_instance.close()
        print(f"\n✗ Error loading embeddings: {e}")
        print("\nTroubleshooting:")
        print("  1. Verify PostgreSQL is running")
        print("  2. Check POSTGRES_CONFIG settings")
        print("  3. Run foundation/02 to generate embeddings first")
        return None

    print("✓ LOADED SUCCESSFULLY")
    print(f"  Embeddings: {count:,}")
    print(f"  Table: {table_name}")
    print("  Status: Ready for retrieval\n")

    return db_instance


# Step 1: list what the registry already holds, so the user can see which
# aliases are available before one is loaded.
print("Step 1: Discovering available embeddings...\n")
available = list_available_embeddings(db_connection)

if not available.empty:
    print("Available embeddings:")
    print(available.to_string(index=False))
    print()
else:
    print("⚠️  No embeddings found in registry yet.")
    print("Run foundation/02-rag-postgresql-persistent.ipynb first.\n")

# Step 2: open the configured alias (load-or-generate pattern).
print("\nStep 2: Loading embeddings using load-or-generate pattern...\n")
embeddings_db = load_or_generate(
    db_connection=db_connection,
    embedding_model_alias=EMBEDDING_MODEL_ALIAS,
    preserve_existing=True,  # Auto-load if available
)

if not embeddings_db:
    print("⚠️  Could not load embeddings. See instructions above.")
    # Later cells check `if embeddings_db` and skip their work when it is None.
    embeddings_db = None
else:
    print("✓ Success! Embeddings loaded and ready for retrieval.")

## Implement Cross-Encoder Reranking

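Core reranking implementation: fetch a broad candidate set with vector search, then rescore the candidates with a cross-encoder and keep the best ones.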

In [None]:
# ============================================================================
# PART 2: CROSS-ENCODER RERANKING IMPLEMENTATION
# ============================================================================

from sentence_transformers import CrossEncoder

# Loaded CrossEncoder models, keyed by model name. Constructing a
# CrossEncoder loads its weights, and rerank_with_crossencoder runs once per
# query during evaluation, so each model is loaded at most once per session.
_CROSS_ENCODER_CACHE: Dict[str, "CrossEncoder"] = {}


def _get_cross_encoder(model_name: str) -> "CrossEncoder":
    """Return the cached CrossEncoder for ``model_name``, loading it on first use."""
    model = _CROSS_ENCODER_CACHE.get(model_name)
    if model is None:
        print(f"Loading cross-encoder model: {model_name}")
        model = CrossEncoder(model_name)
        _CROSS_ENCODER_CACHE[model_name] = model
    return model


def rerank_with_crossencoder(query: str,
                             candidates: List[Tuple[str, float, int]],
                             reranker_model: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2',
                             top_k: int = 5,
                             batch_size: int = 32) -> List[Tuple[str, float, int]]:
    """Two-stage retrieval: dense retrieval → cross-encoder reranking.

    Stage 1: Dense retrieval (already done - these are candidates)
    Stage 2: Cross-encoder scoring of every (query, chunk) pair (more
    accurate, slower)

    Args:
        query: User question
        candidates: List of (chunk_text, initial_score, chunk_id) from vector search
        reranker_model: CrossEncoder model name; loaded once and then reused
            from a module-level cache
        top_k: Return top K after reranking
        batch_size: Pairs scored per batch, passed to ``CrossEncoder.predict``
            (pass RERANKING_BATCH_SIZE to honour the notebook configuration)

    Returns:
        Up to ``top_k`` (chunk_text, rerank_score, chunk_id) tuples sorted by
        rerank score, highest first. The initial vector scores are discarded.
        Returns [] without loading any model when ``candidates`` is empty.
    """
    if not candidates:
        return []

    model = _get_cross_encoder(reranker_model)

    # The cross-encoder scores each (query, document) pair jointly.
    pairs = [[query, chunk_text] for chunk_text, _, _ in candidates]

    print(f"Scoring {len(pairs)} candidates with cross-encoder...")
    scores = model.predict(pairs, batch_size=batch_size)

    # Replace the vector score with the rerank score, keeping text and id.
    reranked = [
        (chunk_text, float(score), chunk_id)
        for (chunk_text, _, chunk_id), score in zip(candidates, scores)
    ]

    # Stable sort: equal scores keep their original vector-search order.
    reranked.sort(key=lambda item: item[1], reverse=True)

    return reranked[:top_k]


def retrieve_with_reranking(query: str,
                           embeddings_db: PostgreSQLVectorDB,
                           embedding_model: str,
                           top_k_initial: int = 20,
                           top_k_final: int = 5,
                           reranker_model: str = 'cross-encoder/ms-marco-MiniLM-L-6-v2') -> List[Tuple[str, float, int]]:
    """Run the full two-stage pipeline for one query.

    Stage 1 embeds the query with Ollama and pulls ``top_k_initial``
    nearest chunks from pgvector. Stage 2 rescores those candidates with a
    cross-encoder and keeps the best ``top_k_final``.

    Args:
        query: User question
        embeddings_db: PostgreSQLVectorDB instance to search
        embedding_model: Model name passed to ``ollama.embed`` for the query
        top_k_initial: Size of the dense-retrieval candidate pool
        top_k_final: Number of results kept after reranking
        reranker_model: Cross-encoder model to use

    Returns:
        Up to ``top_k_final`` (chunk_text, rerank_score, chunk_id) tuples,
        best first, as produced by ``rerank_with_crossencoder``.
    """
    # Stage 1 - recall: wide net from vector similarity.
    print(f"\nStep 1: Dense retrieval (getting top {top_k_initial} candidates)...")
    embed_response = ollama.embed(model=embedding_model, input=query)
    query_vector = embed_response['embeddings'][0]
    candidate_pool = embeddings_db.similarity_search(query_vector, top_n=top_k_initial)
    print(f"✓ Retrieved {len(candidate_pool)} candidates via dense retrieval")

    # Stage 2 - precision: rescore the pool and trim it.
    print(f"\nStep 2: Cross-encoder reranking (selecting top {top_k_final})...")
    best = rerank_with_crossencoder(
        query,
        candidate_pool,
        reranker_model=reranker_model,
        top_k=top_k_final,
    )
    print(f"✓ Reranked to top {len(best)} results")

    return best


# Load test questions from ground truth. Only rows rated 'good' are used;
# a NULL relevant_chunk_ids value becomes an empty list.
print("\nLoading ground truth test set...")
ground_truth_questions = []

with db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
    cur.execute('''
        SELECT
            id,
            question,
            relevant_chunk_ids,
            quality_rating,
            source_type
        FROM evaluation_groundtruth
        WHERE quality_rating = 'good'
        ORDER BY id
    ''')

    for row in cur.fetchall():
        ground_truth_questions.append({
            'id': row['id'],
            'question': row['question'],
            'relevant_chunk_ids': row['relevant_chunk_ids'] or [],
            'quality_rating': row['quality_rating'],
            'source_type': row['source_type']
        })

print(f"✓ Loaded {len(ground_truth_questions)} ground truth questions\n")

if ground_truth_questions:
    print(f"Sample question: {ground_truth_questions[0]['question'][:80]}...")
    print(f"Relevant chunks: {ground_truth_questions[0]['relevant_chunk_ids'][:3]}\n")

# Demonstrate reranking on the first test question (if available): show the
# vector-only results next to the reranked results for the same query.
if embeddings_db and ground_truth_questions:
    print("=" * 70)
    print("DEMONSTRATING RERANKING")
    print("=" * 70)

    test_question = ground_truth_questions[0]['question']
    print(f"\nQuery: {test_question}\n")

    # Baseline: vector-only retrieval
    print("BASELINE (Vector Retrieval Only):")
    # NOTE(review): the registry alias is passed to Ollama as the model name;
    # confirm the two names match for this model.
    query_emb = ollama.embed(model=EMBEDDING_MODEL_ALIAS, input=test_question)['embeddings'][0]
    baseline_results = embeddings_db.similarity_search(query_emb, top_n=TOP_K_FINAL)

    # Header reports how many rows actually came back instead of a hard-coded
    # 5, so it stays correct when TOP_K_FINAL changes.
    print(f"Top {len(baseline_results)} results (vector similarity only):")
    for i, (chunk_text, score, chunk_id) in enumerate(baseline_results, 1):
        preview = chunk_text[:100].replace('\n', ' ') + '...'
        print(f"  [{i}] (score: {score:.4f}) {preview}")

    # With reranking
    print("\nWITH RERANKING (CrossEncoder):")
    reranked_results = retrieve_with_reranking(
        test_question,
        embeddings_db,
        EMBEDDING_MODEL_ALIAS,
        top_k_initial=TOP_K_INITIAL,
        top_k_final=TOP_K_FINAL,
        reranker_model=RERANKER_MODEL
    )

    print(f"Top {len(reranked_results)} results (after cross-encoder reranking):")
    for i, (chunk_text, score, chunk_id) in enumerate(reranked_results, 1):
        preview = chunk_text[:100].replace('\n', ' ') + '...'
        print(f"  [{i}] (rerank score: {score:.4f}) {preview}")

    print("\nNote: Reranking may reorder results, prioritizing semantic relevance over surface similarity.\n")


## Evaluate Impact

Compare reranked results to baseline vector retrieval:

In [None]:
# ============================================================================
# PART 3: EVALUATE IMPACT - BASELINE VS RERANKING
# ============================================================================
# Inline metric functions from evaluation-lab/02-evaluation-metrics-framework.ipynb

def precision_at_k(retrieved_chunk_ids: List[int], relevant_chunk_ids: List[int], k: int = 5) -> float:
    """Precision@K: share of the top-K slots filled by relevant chunks.

    Formula: |{relevant in top-K}| / K

    The denominator is always K, so a list shorter than K is penalised for
    its empty slots. Returns 0.0 when K is 0.

    Good for: Understanding result quality from user perspective
    """
    if k == 0:
        return 0.0

    wanted = set(relevant_chunk_ids)
    hits = len([cid for cid in retrieved_chunk_ids[:k] if cid in wanted])
    return hits / k


def recall_at_k(retrieved_chunk_ids: List[int], relevant_chunk_ids: List[int], k: int = 5) -> float:
    """Recall@K: share of all distinct relevant chunks that appear in the top K.

    Formula: |{relevant in top-K}| / |all relevant|

    Returns 0.0 when no relevant chunks are given.

    Good for: Understanding coverage
    """
    if len(relevant_chunk_ids) == 0:
        return 0.0

    gold = set(relevant_chunk_ids)
    found = sum(chunk_id in gold for chunk_id in retrieved_chunk_ids[:k])
    return found / len(gold)


def mean_reciprocal_rank(retrieved_chunk_ids: List[int], relevant_chunk_ids: List[int]) -> float:
    """Reciprocal rank of the first relevant result (averaged across queries -> MRR).

    Formula: 1 / (1-based rank of first relevant result)

    Scans the whole retrieved list (no K cutoff). Returns 0.0 when nothing
    relevant was retrieved.

    Good for: Understanding user satisfaction
    """
    gold = set(relevant_chunk_ids)
    first_hit = next(
        (position for position, chunk_id in enumerate(retrieved_chunk_ids, start=1)
         if chunk_id in gold),
        None,
    )
    return 0.0 if first_hit is None else 1.0 / first_hit


def ndcg_at_k(retrieved_chunk_ids: List[int], relevant_chunk_ids: List[int], k: int = 5) -> float:
    """NDCG@K: Normalized Discounted Cumulative Gain with binary relevance.

    How well-ranked are the results? Relevant results near the top earn more
    credit, and the score is normalised by the best ranking that was possible
    given how many relevant chunks exist.

    Formula: DCG@K / IDCG@K, where
        DCG@K  = sum over 1-based ranks i <= K of (2^rel_i - 1) / log2(i + 1)
        IDCG@K = DCG of an ideal list whose first min(K, |relevant|) entries
                 are all relevant.

    Good for: Understanding ranking quality

    Returns:
        Score in [0, 1]; 0.0 when k <= 0, no relevant chunks are given, or
        nothing relevant is retrieved.
    """
    def dcg_score(relevance_scores: List[float]) -> float:
        """Compute DCG from relevance scores (list index 0 is rank 1)."""
        return sum(
            (2**rel - 1) / np.log2(rank + 2)
            for rank, rel in enumerate(relevance_scores)
        )

    relevant_set = set(relevant_chunk_ids)
    if k <= 0 or not relevant_set:
        return 0.0

    # Binary relevance of the retrieved top-K: 1 if relevant, 0 if not.
    retrieved_k = retrieved_chunk_ids[:k]
    relevance = [1 if chunk_id in relevant_set else 0 for chunk_id in retrieved_k]
    dcg = dcg_score(relevance)

    # Ideal ranking: relevant chunks fill the first slots until either the K
    # slots or the relevant chunks run out. Sorting the *retrieved* relevance
    # instead would ignore relevant chunks that were never retrieved and
    # overstate the score (e.g. 1.0 when only 1 of 3 relevant chunks is found).
    ideal_relevance = [1] * min(k, len(relevant_set))
    idcg = dcg_score(ideal_relevance)

    return float(dcg / idcg)


def evaluate_retrieval_with_and_without_reranking(test_questions: List[Dict],
                                                   embeddings_db: PostgreSQLVectorDB,
                                                   embedding_model: str,
                                                   reranker_model: str = RERANKER_MODEL) -> Dict:
    """Compare vector-only retrieval with cross-encoder reranking on a test set.

    For every question that has at least one relevant chunk id, runs
    (1) a vector-only search for TOP_K_FINAL chunks and (2)
    retrieve_with_reranking (TOP_K_INITIAL candidates reranked down to
    TOP_K_FINAL), then scores both rankings with precision@5, recall@5, MRR
    and NDCG@5. Questions without relevant ids are skipped.

    Args:
        test_questions: Dicts with at least 'question' and
            'relevant_chunk_ids' (as built from evaluation_groundtruth)
        embeddings_db: PostgreSQLVectorDB instance
        embedding_model: Model name passed to ``ollama.embed`` for queries
        reranker_model: Cross-encoder model for reranking

    Returns:
        dict with keys:
            'baseline' / 'reranked': per-metric means over evaluated queries
                (all zeros when nothing was evaluated),
            'improvements_pct': relative change of reranked vs. baseline in
                percent (0 when the baseline mean is 0),
            'per_query': list of {'question' (first 80 chars), 'baseline',
                'reranked'} dicts,
            'num_queries_evaluated': number of questions actually scored.
    """
    def _score(ranked_ids, gold_ids):
        # The key order here fixes the key order of every aggregate dict.
        return {
            'precision@5': precision_at_k(ranked_ids, gold_ids, k=5),
            'recall@5': recall_at_k(ranked_ids, gold_ids, k=5),
            'mrr': mean_reciprocal_rank(ranked_ids, gold_ids),
            'ndcg@5': ndcg_at_k(ranked_ids, gold_ids, k=5),
        }

    def _mean_per_metric(rows):
        if not rows:
            return {'precision@5': 0, 'recall@5': 0, 'mrr': 0, 'ndcg@5': 0}
        return {name: np.mean([row[name] for row in rows]) for name in rows[0]}

    baseline_rows = []
    reranked_rows = []
    per_query = []
    total = len(test_questions)
    # Report progress roughly every 10% of the input (at least every query).
    progress_every = max(1, total // 10)

    print(f"\nEvaluating {total} test queries...")
    print("=" * 70)

    for position, item in enumerate(test_questions, start=1):
        question = item['question']
        gold_ids = item['relevant_chunk_ids']

        if not gold_ids:
            print(f"Skipping query {position} (no relevant chunks)")
            continue

        if position % progress_every == 0:
            print(f"Progress: {position}/{total} queries processed...")

        # Baseline: vector-only top TOP_K_FINAL.
        vector = ollama.embed(model=embedding_model, input=question)['embeddings'][0]
        baseline_hits = embeddings_db.similarity_search(vector, top_n=TOP_K_FINAL)
        baseline_ids = [hit[2] for hit in baseline_hits]

        # Reranked: TOP_K_INITIAL candidates rescored down to TOP_K_FINAL.
        reranked_hits = retrieve_with_reranking(
            question,
            embeddings_db,
            embedding_model,
            top_k_initial=TOP_K_INITIAL,
            top_k_final=TOP_K_FINAL,
            reranker_model=reranker_model,
        )
        reranked_ids = [hit[2] for hit in reranked_hits]

        base_scores = _score(baseline_ids, gold_ids)
        rerank_scores = _score(reranked_ids, gold_ids)

        baseline_rows.append(base_scores)
        reranked_rows.append(rerank_scores)
        per_query.append({
            'question': question[:80],
            'baseline': base_scores,
            'reranked': rerank_scores,
        })

    print("=" * 70)
    print(f"✓ Evaluation complete ({len(baseline_rows)} queries)\n")

    baseline_agg = _mean_per_metric(baseline_rows)
    reranked_agg = _mean_per_metric(reranked_rows)

    # Relative change in percent; a zero baseline is reported as 0 to avoid
    # dividing by zero.
    improvements = {}
    for name, base_value in baseline_agg.items():
        if base_value > 0:
            improvements[name] = (reranked_agg[name] - base_value) / base_value * 100
        else:
            improvements[name] = 0

    return {
        'baseline': baseline_agg,
        'reranked': reranked_agg,
        'improvements_pct': improvements,
        'per_query': per_query,
        'num_queries_evaluated': len(baseline_rows),
    }


# Run evaluation
# Runs the baseline-vs-reranking comparison over the ground-truth set, prints
# aggregate metrics, and plots per-query bars. Leaves evaluation_results set
# (None when embeddings or questions are missing) for the tracking cell.
print("\n" + "=" * 70)
print("EVALUATING RERANKING IMPACT")
print("=" * 70)

if embeddings_db and ground_truth_questions:
    evaluation_results = evaluate_retrieval_with_and_without_reranking(
        ground_truth_questions,
        embeddings_db,
        EMBEDDING_MODEL_ALIAS,
        reranker_model=RERANKER_MODEL
    )

    # Display results
    print("\n" + "=" * 70)
    print("RESULTS SUMMARY")
    print("=" * 70)

    # Means over evaluated queries, plus relative change in percent (the
    # evaluation function reports 0 when a baseline mean is 0).
    baseline = evaluation_results['baseline']
    reranked = evaluation_results['reranked']
    improvements = evaluation_results['improvements_pct']

    print(f"\nQueries Evaluated: {evaluation_results['num_queries_evaluated']}\n")

    print("BASELINE (Vector Retrieval Only):")
    print(f"  Precision@5: {baseline['precision@5']:.4f}")
    print(f"  Recall@5:    {baseline['recall@5']:.4f}")
    print(f"  MRR:         {baseline['mrr']:.4f}")
    print(f"  NDCG@5:      {baseline['ndcg@5']:.4f}")

    print("\nWITH RERANKING (CrossEncoder):")
    print(f"  Precision@5: {reranked['precision@5']:.4f}")
    print(f"  Recall@5:    {reranked['recall@5']:.4f}")
    print(f"  MRR:         {reranked['mrr']:.4f}")
    print(f"  NDCG@5:      {reranked['ndcg@5']:.4f}")

    print("\nIMPROVEMENTS:")
    for metric, improvement in improvements.items():
        # Explicit '+' for gains; negative values carry their own '-'.
        sign = "+" if improvement > 0 else ""
        print(f"  {metric:<15} {sign}{improvement:>6.2f}%")

    # Visualize per-query improvements
    print("\n" + "=" * 70)
    print("PER-QUERY ANALYSIS")
    print("=" * 70)

    # Imported here, so matplotlib is only needed when there is data to plot.
    import matplotlib.pyplot as plt

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Reranking Impact: Baseline vs CrossEncoder', fontsize=14, fontweight='bold')

    # Extract per-query metrics (skipped queries are not in per_query)
    baseline_p5 = [q['baseline']['precision@5'] for q in evaluation_results['per_query']]
    reranked_p5 = [q['reranked']['precision@5'] for q in evaluation_results['per_query']]

    baseline_recall = [q['baseline']['recall@5'] for q in evaluation_results['per_query']]
    reranked_recall = [q['reranked']['recall@5'] for q in evaluation_results['per_query']]

    baseline_mrr = [q['baseline']['mrr'] for q in evaluation_results['per_query']]
    reranked_mrr = [q['reranked']['mrr'] for q in evaluation_results['per_query']]

    baseline_ndcg = [q['baseline']['ndcg@5'] for q in evaluation_results['per_query']]
    reranked_ndcg = [q['reranked']['ndcg@5'] for q in evaluation_results['per_query']]

    # Each panel draws paired bars per query: baseline at x-0.2 and reranked
    # at x+0.2, each 0.4 wide, on a fixed 0-1 y-axis.
    # Plot 1: Precision@5
    ax = axes[0, 0]
    x = np.arange(len(baseline_p5))
    ax.bar(x - 0.2, baseline_p5, 0.4, label='Baseline', alpha=0.8, color='#2E86AB')
    ax.bar(x + 0.2, reranked_p5, 0.4, label='Reranked', alpha=0.8, color='#06A77D')
    ax.set_ylabel('Precision@5', fontweight='bold')
    ax.set_title('Precision@5 by Query')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 2: Recall@5
    ax = axes[0, 1]
    ax.bar(x - 0.2, baseline_recall, 0.4, label='Baseline', alpha=0.8, color='#2E86AB')
    ax.bar(x + 0.2, reranked_recall, 0.4, label='Reranked', alpha=0.8, color='#06A77D')
    ax.set_ylabel('Recall@5', fontweight='bold')
    ax.set_title('Recall@5 by Query')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 3: MRR
    ax = axes[1, 0]
    ax.bar(x - 0.2, baseline_mrr, 0.4, label='Baseline', alpha=0.8, color='#2E86AB')
    ax.bar(x + 0.2, reranked_mrr, 0.4, label='Reranked', alpha=0.8, color='#06A77D')
    ax.set_ylabel('MRR', fontweight='bold')
    ax.set_title('Mean Reciprocal Rank by Query')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 4: NDCG@5
    ax = axes[1, 1]
    ax.bar(x - 0.2, baseline_ndcg, 0.4, label='Baseline', alpha=0.8, color='#2E86AB')
    ax.bar(x + 0.2, reranked_ndcg, 0.4, label='Reranked', alpha=0.8, color='#06A77D')
    ax.set_ylabel('NDCG@5', fontweight='bold')
    ax.set_title('NDCG@5 by Query')
    ax.legend()
    ax.set_ylim(0, 1.0)
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    plt.show()

    print("\n✓ Visualizations displayed above")

else:
    # The tracking cell checks `if evaluation_results:` and skips when None.
    print("⚠️  Cannot evaluate: embeddings or test questions not available")
    evaluation_results = None


## Track Experiment

Store the configuration and metrics in the database so this run can be compared with other experiments:

In [None]:
# ============================================================================
# PART 4: EXPERIMENT TRACKING
# ============================================================================
# Inline from foundation/00-registry-and-tracking-utilities.ipynb

def compute_config_hash(config_dict: Dict) -> str:
    """Fingerprint a configuration so runs with identical settings can be grouped.

    Keys are serialised in sorted order, so dicts that differ only in
    insertion order produce the same hash.

    Args:
        config_dict: JSON-serialisable configuration parameters

    Returns:
        The first 12 hex characters of the SHA-256 digest of the JSON text
    """
    canonical_bytes = json.dumps(config_dict, sort_keys=True).encode()
    return hashlib.sha256(canonical_bytes).hexdigest()[:12]


def start_experiment(db_connection, experiment_name: str,
                     notebook_path: Optional[str] = None,
                     embedding_model_alias: Optional[str] = None,
                     config: Optional[Dict] = None,
                     techniques: Optional[List[str]] = None,
                     notes: Optional[str] = None) -> int:
    """Start a new experiment and return its ID for tracking.

    Inserts a row into ``experiments`` with status 'running', the config as
    JSON, and its hash from compute_config_hash().

    Args:
        db_connection: PostgreSQL connection
        experiment_name: Human-readable experiment name
        notebook_path: Path to the notebook running this experiment
        embedding_model_alias: Which embedding model is being used
        config: Dict of configuration parameters (defaults to {})
        techniques: List of techniques being applied (defaults to [])
        notes: Optional notes about the experiment

    Returns:
        Experiment ID for use in save_metrics() and complete_experiment()

    Raises:
        Any database error from the insert, re-raised after the connection
        is rolled back.
    """
    if config is None:
        config = {}
    if techniques is None:
        techniques = []

    config_hash = compute_config_hash(config)

    try:
        with db_connection.cursor() as cur:
            cur.execute('''
                INSERT INTO experiments (
                    experiment_name, notebook_path, embedding_model_alias,
                    config_hash, config_json, techniques_applied, notes, status
                )
                VALUES (%s, %s, %s, %s, %s, %s, %s, 'running')
                RETURNING id
            ''', (
                experiment_name,
                notebook_path,
                embedding_model_alias,
                config_hash,
                json.dumps(config),
                techniques,
                notes
            ))
            exp_id = cur.fetchone()[0]
        db_connection.commit()
    except Exception:
        # Like save_metrics/complete_experiment, never leave the shared
        # connection in an aborted transaction. Re-raise, because this
        # function has no failure return value to signal with.
        db_connection.rollback()
        raise

    print(f"✓ Started experiment #{exp_id}: {experiment_name}")
    return exp_id


def save_metrics(db_connection, experiment_id: int, metrics_dict: Dict,
                 export_to_file: bool = True,
                 export_dir: str = 'data/experiment_results') -> Tuple[bool, str]:
    """Save experiment metrics to database and optionally to JSON file.

    Each entry becomes one evaluation_results row, and all rows are committed
    together. The JSON export happens only after the commit, so an export
    failure is reported as such: the database rows are kept and not rolled
    back.

    Args:
        db_connection: PostgreSQL connection
        experiment_id: ID from start_experiment()
        metrics_dict: Dict of {metric_name: value} or
            {metric_name: {'value': float, 'details': dict}}
        export_to_file: Whether to also save to filesystem JSON
        export_dir: Directory for JSON exports (created if missing)

    Returns:
        Tuple of (success: bool, message: str). success is False if either
        the database insert or the requested file export failed. The
        message says which one failed.
    """
    # Phase 1: database insert, one transaction for all metrics.
    try:
        with db_connection.cursor() as cur:
            for metric_name, metric_data in metrics_dict.items():
                # Handle both simple floats and nested dicts with details
                if isinstance(metric_data, dict):
                    metric_value = metric_data.get('value', 0.0)
                    metric_details = metric_data.get('details', {})
                else:
                    metric_value = metric_data
                    metric_details = {}

                cur.execute('''
                    INSERT INTO evaluation_results (
                        experiment_id, metric_name, metric_value, metric_details_json
                    )
                    VALUES (%s, %s, %s, %s)
                ''', (
                    experiment_id,
                    metric_name,
                    float(metric_value),
                    json.dumps(metric_details) if metric_details else '{}'
                ))
        db_connection.commit()
    except Exception as e:
        msg = f"✗ Failed to save metrics: {e}"
        print(msg)
        db_connection.rollback()
        return False, msg

    # Phase 2: optional JSON export. The rows are already committed, so a
    # failure here must not roll back or claim that nothing was saved.
    file_path = None
    if export_to_file:
        try:
            os.makedirs(export_dir, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            file_path = os.path.join(export_dir, f'experiment_{experiment_id}_{timestamp}.json')
            with open(file_path, 'w', encoding='utf-8') as f:
                json.dump({
                    'experiment_id': experiment_id,
                    'timestamp': timestamp,
                    'metrics': metrics_dict
                }, f, indent=2)
        except Exception as e:
            msg = (f"✗ Saved {len(metrics_dict)} metrics for experiment #{experiment_id} "
                   f"to the database, but failed to export JSON: {e}")
            print(msg)
            return False, msg

    msg = f"✓ Saved {len(metrics_dict)} metrics for experiment #{experiment_id}"
    if file_path:
        msg += f" to {file_path}"
    print(msg)
    return True, msg


def complete_experiment(db_connection, experiment_id: int,
                       status: str = 'completed',
                       notes: Optional[str] = None) -> bool:
    """Mark an experiment as finished by updating its row in ``experiments``.

    Sets ``status`` and ``completed_at`` (the database's CURRENT_TIMESTAMP)
    for the row whose ``id`` equals ``experiment_id``. When ``notes`` is a
    non-empty string, the ``notes`` column is overwritten as well.

    Args:
        db_connection: Open PostgreSQL connection. It is committed after the
            UPDATE and rolled back if the UPDATE raises.
        experiment_id: ID returned from start_experiment().
        status: New status value, e.g. 'completed' or 'failed'.
        notes: Replacement text for the notes column. ``None`` or an empty
            string leaves the existing notes unchanged.

    Returns:
        True if an experiment row was updated. False if the UPDATE raised, or
        if no experiment with ``experiment_id`` exists.
    """
    try:
        with db_connection.cursor() as cur:
            if notes:
                cur.execute('''
                    UPDATE experiments
                    SET status = %s, notes = %s, completed_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                ''', (status, notes, experiment_id))
            else:
                cur.execute('''
                    UPDATE experiments
                    SET status = %s, completed_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                ''', (status, experiment_id))
            # Read before the cursor is closed by the context manager.
            updated_rows = cur.rowcount
        db_connection.commit()
    except Exception as e:
        print(f"✗ Failed to complete experiment: {e}")
        db_connection.rollback()
        return False

    # An UPDATE that matches no row is not an error in PostgreSQL. Without this
    # check, an unknown (or None) experiment_id would be reported as success.
    if updated_rows == 0:
        print(f"✗ Failed to complete experiment: no experiment with id {experiment_id}")
        return False

    print(f"✓ Experiment #{experiment_id} marked as {status}")
    return True


# ============================================================================
# RUN EXPERIMENT TRACKING
# ============================================================================
# Record this reranking run (configuration, metrics, final status) in
# PostgreSQL and print a summary. The database connections are always
# released at the end, even if tracking raises.

print("\n" + "=" * 70)
print("TRACKING EXPERIMENT")
print("=" * 70)

try:
    if evaluation_results:
        num_queries = evaluation_results['num_queries_evaluated']

        # Prepare configuration
        config_dict = {
            'embedding_model_alias': EMBEDDING_MODEL_ALIAS,
            'reranker_model': RERANKER_MODEL,
            'top_k_initial': TOP_K_INITIAL,
            'top_k_final': TOP_K_FINAL,
            'reranking_batch_size': RERANKING_BATCH_SIZE,
            'num_test_queries': num_queries,
        }

        config_hash = compute_config_hash(config_dict)

        print("\nExperiment Configuration:")
        print(f"  Name: {EXPERIMENT_NAME}")
        print(f"  Embedding Model: {EMBEDDING_MODEL_ALIAS}")
        print(f"  Reranker Model: {RERANKER_MODEL}")
        print(f"  Config Hash: {config_hash}")
        print(f"  Test Queries: {num_queries}\n")

        # Start experiment tracking
        experiment_id = start_experiment(
            db_connection,
            experiment_name=EXPERIMENT_NAME,
            notebook_path='advanced-techniques/05-reranking.ipynb',
            embedding_model_alias=EMBEDDING_MODEL_ALIAS,
            config=config_dict,
            techniques=TECHNIQUES_APPLIED,
            notes=f'Cross-encoder reranking evaluation on {num_queries} queries'
        )

        try:
            # Prepare metrics for storage. save_metrics() passes every value
            # through float(), so only numeric values may be stored directly.
            metrics_to_store = {}
            for prefix, section in (('baseline', 'baseline'),
                                    ('reranked', 'reranked'),
                                    ('improvement_pct', 'improvements_pct')):
                for metric_name, metric_value in evaluation_results[section].items():
                    metrics_to_store[f'{prefix}_{metric_name}'] = metric_value

            # The config hash must not be stored as a metric value.
            # compute_config_hash() presumably returns a hex digest string, and
            # float() on it raises. That raise rolls back the entire save.
            # Attach the hash as details on the numeric query-count row instead.
            metrics_to_store['num_queries_evaluated'] = {
                'value': num_queries,
                'details': {'config_hash': config_hash},
            }

            # Save metrics
            print("\nSaving metrics to database...\n")
            success, message = save_metrics(db_connection, experiment_id,
                                            metrics_to_store, export_to_file=True)
        except Exception as e:
            # Record a final status for the experiment row before propagating.
            complete_experiment(db_connection, experiment_id, status='failed',
                                notes=f'Failed to save metrics: {e}')
            raise

        improvements = evaluation_results['improvements_pct']

        # Complete experiment
        if success:
            notes = f"Successfully evaluated reranking on {num_queries} queries. "
            # 'precision@5' is not guaranteed to be among the improvement keys
            # (e.g. if TOP_K_FINAL changes; TODO confirm key naming). A missing
            # key must not stop the experiment from being marked complete.
            precision_at_5 = improvements.get('precision@5')
            if precision_at_5 is not None:
                notes += f"Precision@5 improved {precision_at_5:.2f}%"

            complete_experiment(db_connection, experiment_id, status='completed',
                                notes=notes.rstrip())

            # Display results summary
            print("\n" + "=" * 70)
            print("EXPERIMENT RESULTS SUMMARY")
            print("=" * 70)

            print(f"\nExperiment ID: {experiment_id}")
            print(f"Experiment Name: {EXPERIMENT_NAME}")
            print("Status: Completed")
            print(f"Config Hash: {config_hash}")

            print("\nKey Improvements:")
            for metric_name, improvement_pct in improvements.items():
                # Negative values already carry their own '-' when formatted.
                sign = "+" if improvement_pct > 0 else ""
                baseline_val = evaluation_results['baseline'][metric_name]
                reranked_val = evaluation_results['reranked'][metric_name]
                print(f"  {metric_name}:")
                print(f"    Baseline:  {baseline_val:.4f}")
                print(f"    Reranked:  {reranked_val:.4f}")
                print(f"    Improvement: {sign}{improvement_pct:.2f}%")

            print("\nResults exported to:")
            print(f"  Database: evaluation_results table (experiment_id={experiment_id})")
            print(f"  JSON: data/experiment_results/experiment_{experiment_id}_*.json")

            print("\n" + "=" * 70)
            print("NEXT STEPS")
            print("=" * 70)
            print("\n1. Review the visualizations above to understand the impact")
            print("2. Compare results with other techniques:")
            print("   - evaluation-lab/03-compare-experiments.ipynb")
            print("   - evaluation-lab/04-plot-improvements.ipynb")
            print("3. Try other reranking models:")
            print("   - cross-encoder/ms-marco-TinyBERT-L-2-v2 (faster)")
            print("   - cross-encoder/qnli-distilroberta-base (alternative)")
            print("4. Experiment with different TOP_K_INITIAL values")
            print("5. Combine with other techniques (query expansion, hybrid search)")

        else:
            print("\n✗ Failed to track experiment")
            complete_experiment(db_connection, experiment_id, status='failed', notes='Failed to save metrics')

    else:
        print("⚠️  Cannot track experiment: evaluation_results not available")

finally:
    # Close database connections even when tracking raised, so they are not
    # left open.
    print("\n\nClosing database connection...")
    if embeddings_db:
        embeddings_db.close()
    db_connection.close()
    print("✓ All connections closed")
