## Prerequisites

1. ✅ foundation/00-setup-postgres-schema.ipynb
2. ✅ foundation/02-rag-postgresql-persistent.ipynb
3. ✅ evaluation-lab/01-create-ground-truth-human-in-loop.ipynb

## Configuration

In [None]:
import ollama
import psycopg2
import psycopg2.extras
import json
import math
import numpy as np
import pandas as pd
from datetime import datetime
from typing import List, Dict, Tuple, Optional
import hashlib

# Configuration
# Alias used to look the model up in embedding_registry and to derive the
# name of the embeddings table queried during retrieval.
EMBEDDING_MODEL_ALIAS = "bge_base_en_v1_5"
# Model identifier handed to ollama.embed when embedding query text.
EMBEDDING_MODEL = "hf.co/CompendiumLabs/bge-base-en-v1.5-gguf"
LLM_MODEL = "llama3.2:1b"  # Ollama model for query expansion (also recorded in the experiment config)
NUM_EXPANSIONS = 4  # How many query variants to request from the LLM
TOP_K_PER_QUERY = 5  # Chunks retrieved for each individual query variant
TOP_K_FINAL = 5  # Chunks kept after merging results from all variants

# Name and technique tags stored with the experiment row when tracking results.
EXPERIMENT_NAME = "query-expansion-llm"
TECHNIQUES_APPLIED = ["vector_retrieval", "llm_query_expansion"]

# PostgreSQL configuration
# NOTE(review): credentials are hard-coded; presumably these are local
# development defaults — confirm before pointing this at a shared database.
POSTGRES_CONFIG = {
    'host': 'localhost',
    'port': 5432,
    'database': 'rag_db',
    'user': 'postgres',
    'password': 'postgres',
}

## Load Embeddings from Registry

In [None]:
# Connect to PostgreSQL
# A single module-level connection is opened here and reused by the rest of
# the notebook (registry lookup, retrieval queries, experiment tracking).
db_connection = psycopg2.connect(
    host=POSTGRES_CONFIG['host'],
    port=POSTGRES_CONFIG['port'],
    database=POSTGRES_CONFIG['database'],
    user=POSTGRES_CONFIG['user'],
    password=POSTGRES_CONFIG['password']
)

print(f"Connected to PostgreSQL at {POSTGRES_CONFIG['host']}:{POSTGRES_CONFIG['port']}")

# List available embeddings in registry
def list_available_embeddings(db_conn) -> pd.DataFrame:
    """Read the embedding registry into a DataFrame, newest entries first.

    Args:
        db_conn: Open database connection that pandas can run SQL against.

    Returns:
        One row per registry entry with the columns model_alias, model_name,
        dimension, embedding_count, chunk_source_dataset, chunk_size_config,
        created_at and last_accessed, ordered by created_at descending.
    """
    registry_sql = '''
        SELECT 
            model_alias,
            model_name,
            dimension,
            embedding_count,
            chunk_source_dataset,
            chunk_size_config,
            created_at,
            last_accessed
        FROM embedding_registry
        ORDER BY created_at DESC
    '''
    registry_frame = pd.read_sql(registry_sql, con=db_conn)
    return registry_frame

# Load or verify embeddings exist
# Fail fast: everything below needs at least one registered embedding set.
print("\nLooking for embeddings in registry...")
available_embeddings = list_available_embeddings(db_connection)

if available_embeddings.empty:
    print("✗ No embeddings found in registry")
    print("Please run foundation/02-rag-postgresql-persistent.ipynb first")
    raise ValueError("No embeddings available. Run foundation/02 first.")

print("\nAvailable embeddings:")
print(available_embeddings[['model_alias', 'embedding_count', 'dimension']])

# Check if our target embedding model is available
embedding_found = available_embeddings[available_embeddings['model_alias'] == EMBEDDING_MODEL_ALIAS]

if embedding_found.empty:
    print(f"\n✗ Embedding model '{EMBEDDING_MODEL_ALIAS}' not found")
    print("Please regenerate embeddings with foundation/02 using this model alias")
    raise ValueError(f"Embedding model {EMBEDDING_MODEL_ALIAS} not found")

# Load embedding metadata
# The registry listing is ordered by created_at DESC, so the first matching
# row is the most recently created entry for this alias.
embedding_meta = embedding_found.iloc[0]
embedding_count = embedding_meta['embedding_count']
embedding_dimension = embedding_meta['dimension']

print(f"\n✓ Found embedding model: {EMBEDDING_MODEL_ALIAS}")
print(f"  Count: {embedding_count:,} embeddings")
print(f"  Dimension: {embedding_dimension}")
print(f"  Created: {embedding_meta['created_at']}")

# Load ground truth test questions
# Only rows rated 'good' are used, capped at the first 20 by id.
print("\nLoading ground truth test questions...")
ground_truth_questions = []

with db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
    cur.execute('''
        SELECT 
            id,
            question,
            relevant_chunk_ids,
            quality_rating,
            source_type
        FROM evaluation_groundtruth
        WHERE quality_rating = 'good'
        ORDER BY id
        LIMIT 20
    ''')
    
    # Copy each row into a plain dict so later cells do not depend on the cursor.
    # NOTE(review): presumably relevant_chunk_ids arrives as a Python list of
    # chunk ids (e.g. from an array column) — confirm against the schema.
    for row in cur.fetchall():
        ground_truth_questions.append({
            'id': row['id'],
            'question': row['question'],
            'relevant_chunk_ids': row['relevant_chunk_ids'],
            'quality_rating': row['quality_rating'],
            'source_type': row['source_type']
        })

print(f"✓ Loaded {len(ground_truth_questions)} ground truth questions")
if ground_truth_questions:
    print(f"  Sample: {ground_truth_questions[0]['question'][:80]}...")

# Helper function to get embeddings table name
def get_embeddings_table_name(model_alias: str) -> str:
    """Build the embeddings table name for a model alias.

    Every "." in the alias becomes "_" and the result is prefixed with
    "embeddings_".
    """
    sanitized_alias = "_".join(model_alias.split("."))
    return "embeddings_" + sanitized_alias

# Resolve the table holding vectors for the configured model alias; this name
# is passed to the retrieval and evaluation functions below.
embeddings_table_name = get_embeddings_table_name(EMBEDDING_MODEL_ALIAS)
print(f"\nEmbeddings table: {embeddings_table_name}")

## Implement LLM-Based Query Expansion

In [None]:
def _parse_query_variants(content: str, query: str, num_expansions: int) -> List[str]:
    """Extract distinct question variants from raw LLM output.

    Each line is stripped of surrounding whitespace and of a leading list
    marker ("1.", "2)", "-", "*", "•"). A line is then accepted when it
    starts with "Q:" (the prefix is removed) or, failing that, when it
    contains a "?". Candidates of 10 characters or fewer and exact
    duplicates are ignored.

    Args:
        content: Text returned by the LLM.
        query: Original user question; always the first element returned.
        num_expansions: Maximum number of variants to keep.

    Returns:
        [query, variant1, ...] with at most num_expansions variants.
    """
    # Function-scope import: the notebook's import cell does not include `re`.
    import re

    # Small models frequently answer with numbered or bulleted lists even when
    # asked for "Q:" lines, so the decoration is removed before parsing rather
    # than used as a reason to drop (or keep with the marker attached) a line.
    list_marker = re.compile(r'^(?:\d+[.)]|[-*•])\s+')

    variants = [query]  # Start with original
    for raw_line in content.split('\n'):
        line = list_marker.sub('', raw_line.strip())
        if line.startswith('Q:'):
            candidate = line[2:].strip()
        elif '?' in line:
            # No "Q:" prefix, but it looks like a question.
            candidate = line
        else:
            continue
        if len(candidate) > 10 and candidate not in variants:
            variants.append(candidate)

    return variants[:num_expansions + 1]  # Original + up to num_expansions


def expand_query_with_llm(query: str, num_expansions: int = 4, llm_model: str = 'llama3.2:1b') -> List[str]:
    """
    Generate semantically similar query reformulations using an LLM.
    
    Args:
        query: Original user question
        num_expansions: How many variants to generate; values <= 0 skip the
            LLM call and return only the original query
        llm_model: Ollama model name
        
    Returns:
        List of [original_query, variant1, variant2, ...] (up to num_expansions + 1).
        If the LLM call or response handling fails, the error is printed and
        [query] is returned so retrieval can continue with the original only.
    """
    if num_expansions <= 0:
        # Nothing to generate; also keeps a negative count from slicing the
        # original query out of the result.
        return [query]

    prompt = f"""Generate {num_expansions} different ways to ask this question. Each variant should:
- Mean the same thing as the original
- Use different wording and phrasing
- Be a complete question that can stand alone

Original question: {query}

Generate exactly {num_expansions} variants. Format each on a new line starting with "Q:".
Example format:
Q: How is something done?
Q: What is the process of something?
Q: What are the steps involved in something?

Now generate {num_expansions} variants for the original question:"""
    
    try:
        response = ollama.chat(
            model=llm_model,
            messages=[{'role': 'user', 'content': prompt}]
        )
        return _parse_query_variants(response['message']['content'], query, num_expansions)
    
    except Exception as e:
        # Deliberately broad: expansion is best-effort and must never stop
        # retrieval, whatever the LLM client raises.
        print(f"  ✗ LLM expansion failed: {e}")
        return [query]  # Fallback to original only


def retrieve_multi_query(queries: List[str], 
                         db_conn, 
                         embeddings_table: str,
                         embedding_model: str,
                         top_k_per_query: int = 5) -> List[Tuple[str, float, int]]:
    """
    Run vector retrieval once per query variant and merge the hits.

    A chunk returned for several variants appears once in the output, carrying
    the highest similarity it obtained for any variant.
    
    Args:
        queries: List of query variants
        db_conn: PostgreSQL connection
        embeddings_table: Name of the embeddings table
        embedding_model: Model name for embeddings
        top_k_per_query: Results per variant
        
    Returns:
        List of (chunk_text, max_similarity_score, chunk_id) tuples, deduplicated and sorted by score
    """
    best_by_chunk = {}  # chunk_id → (chunk_text, best similarity seen so far)
    
    for variant in queries:
        # Embed this variant
        embedding = ollama.embed(model=embedding_model, input=variant)['embeddings'][0]
        
        # Nearest-neighbour lookup for the variant
        with db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f'''
                SELECT 
                    id,
                    chunk_text,
                    1 - (embedding <=> %s::vector) as similarity
                FROM {embeddings_table}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            ''', (embedding, embedding, top_k_per_query))
            
            rows = cur.fetchall()
        
        for row in rows:
            row_id = row['id']
            text = row['chunk_text']
            similarity = row['similarity']
            
            # Only overwrite an existing entry when this variant scored higher
            previous = best_by_chunk.get(row_id)
            if previous is None or similarity > previous[1]:
                best_by_chunk[row_id] = (text, similarity)
    
    # Flatten to tuples, best score first
    return sorted(
        ((text, similarity, row_id) for row_id, (text, similarity) in best_by_chunk.items()),
        key=lambda item: item[1],
        reverse=True,
    )


def retrieve_with_query_expansion(query: str, 
                                  db_conn,
                                  embeddings_table: str,
                                  embedding_model: str,
                                  num_expansions: int = 4, 
                                  top_k_per_query: int = 5,
                                  top_k_final: int = 5,
                                  verbose: bool = True,
                                  llm_model: Optional[str] = None) -> List[Tuple[str, float, int]]:
    """
    Complete retrieval pipeline with query expansion.

    Expands the query with the LLM, retrieves chunks for every variant, merges
    the hits and returns the best top_k_final of them.
    
    Args:
        query: Original user question
        db_conn: PostgreSQL connection
        embeddings_table: Name of embeddings table
        embedding_model: Model name for embeddings
        num_expansions: Number of variants to generate
        top_k_per_query: Results per variant
        top_k_final: Final results to return
        verbose: Whether to print debug info
        llm_model: Ollama model used to generate variants. When None, the
            notebook-level LLM_MODEL setting is used, so the model recorded in
            the experiment configuration is the one that actually runs.
        
    Returns:
        Top K (chunk_text, similarity, chunk_id) tuples after merging
        multi-query retrieval
    """
    # Resolve the expansion model at call time so edits to LLM_MODEL take effect.
    expansion_model = LLM_MODEL if llm_model is None else llm_model

    # Step 1: Expand query
    query_variants = expand_query_with_llm(query, num_expansions, expansion_model)
    
    if verbose:
        print(f"✓ Generated {len(query_variants)} query variants:")
        for i, v in enumerate(query_variants):
            prefix = "[Original]" if i == 0 else f"[Variant {i}]"
            print(f"  {prefix} {v[:80]}{'...' if len(v) > 80 else ''}")
    
    # Step 2: Retrieve for each variant
    merged_results = retrieve_multi_query(
        query_variants, 
        db_conn,
        embeddings_table,
        embedding_model,
        top_k_per_query
    )
    
    if verbose:
        print(f"✓ Retrieved {len(merged_results)} deduplicated results across {len(query_variants)} variants")
    
    # Step 3: Return top K
    final_results = merged_results[:top_k_final]
    
    if verbose:
        print(f"✓ Returning top {len(final_results)} results")
    
    return final_results


# Test query expansion on first question
# Smoke test: runs the full pipeline once, verbosely, on the first ground-truth
# question so the generated variants and merged results can be inspected.
if ground_truth_questions:
    test_query = ground_truth_questions[0]['question']
    print("\n" + "="*70)
    print("TESTING QUERY EXPANSION")
    print("="*70)
    print(f"\nOriginal query: {test_query}\n")
    
    test_results = retrieve_with_query_expansion(
        test_query,
        db_connection,
        embeddings_table_name,
        EMBEDDING_MODEL,
        num_expansions=NUM_EXPANSIONS,
        top_k_per_query=TOP_K_PER_QUERY,
        top_k_final=TOP_K_FINAL,
        verbose=True
    )
    
    print(f"\nTop {TOP_K_FINAL} results:")
    for i, (chunk_text, score, chunk_id) in enumerate(test_results):
        # Single-line preview: first 150 characters with newlines flattened.
        preview = chunk_text[:150].replace('\n', ' ') + '...'
        print(f"  [{i+1}] (score: {score:.4f}, id: {chunk_id}) {preview}")

## Evaluate Impact

In [None]:
# Metric computation functions

def precision_at_k(retrieved_chunk_ids: List[int], 
                   relevant_chunk_ids: List[int], 
                   k: int = 5) -> float:
    """Precision@K: fraction of the K result slots filled by relevant chunks.

    The denominator is always k, so returning fewer than k results lowers the
    score. Returns 0.0 when k is 0.
    """
    if k == 0:
        return 0.0
    
    top_slice = retrieved_chunk_ids[:k]
    relevant_lookup = set(relevant_chunk_ids)
    hits = len([candidate for candidate in top_slice if candidate in relevant_lookup])
    
    return hits / k


def recall_at_k(retrieved_chunk_ids: List[int], 
                relevant_chunk_ids: List[int], 
                k: int = 5) -> float:
    """Recall@K: share of the distinct relevant chunks present in the top K.

    Returns 0.0 when there are no relevant chunks to find.
    """
    if len(relevant_chunk_ids) == 0:
        return 0.0
    
    top_slice = retrieved_chunk_ids[:k]
    relevant_lookup = set(relevant_chunk_ids)
    found = len([candidate for candidate in top_slice if candidate in relevant_lookup])
    
    # Denominator counts distinct relevant ids, not the raw list length
    return found / len(relevant_lookup)


def mean_reciprocal_rank(retrieved_chunk_ids: List[int], 
                         relevant_chunk_ids: List[int]) -> float:
    """Reciprocal rank of the first relevant result (1-based); 0.0 if none is retrieved."""
    relevant_lookup = set(relevant_chunk_ids)
    
    first_hit_rank = next(
        (
            position
            for position, candidate in enumerate(retrieved_chunk_ids, start=1)
            if candidate in relevant_lookup
        ),
        None,
    )
    
    if first_hit_rank is None:
        return 0.0
    return 1.0 / first_hit_rank


def ndcg_at_k(retrieved_chunk_ids: List[int], 
              relevant_chunk_ids: List[int], 
              k: int = 5) -> float:
    """NDCG@K: How well-ranked are results? (rewards relevant results at top)

    Uses binary relevance (1 if the chunk id is in relevant_chunk_ids, else 0).
    The score is DCG of the top-K retrieved ids divided by the DCG of an ideal
    ranking, i.e. one that places min(k, number of distinct relevant chunks)
    relevant results first.

    Args:
        retrieved_chunk_ids: Retrieved chunk ids, best first
        relevant_chunk_ids: Ground-truth relevant chunk ids
        k: Rank cutoff

    Returns:
        NDCG@K in [0, 1] for duplicate-free retrieved ids; 0.0 when k <= 0 or
        there are no relevant chunks.
    """
    
    def dcg_score(relevance_scores: List[float]) -> float:
        """Compute DCG from relevance scores."""
        return sum(
            (2**rel - 1) / math.log2(rank + 2)
            for rank, rel in enumerate(relevance_scores)
        )
    
    if k <= 0 or len(relevant_chunk_ids) == 0:
        return 0.0
    
    # Get top-K retrieved
    retrieved_k = retrieved_chunk_ids[:k]
    relevant_set = set(relevant_chunk_ids)
    
    # Binary relevance: 1 if relevant, 0 if not
    relevance = [1 if chunk_id in relevant_set else 0 for chunk_id in retrieved_k]
    
    # Compute DCG for retrieved ranking
    dcg = dcg_score(relevance)
    
    # Ideal DCG must come from the ground truth, not from what was retrieved:
    # re-sorting the retrieved relevance labels would give a perfect score to
    # any result list whose hits sit at the top, however many relevant chunks
    # it missed.
    ideal_hits = min(k, len(relevant_set))
    idcg = dcg_score([1] * ideal_hits)
    
    # idcg > 0 here because k > 0 and relevant_set is non-empty.
    return dcg / idcg


def evaluate_query_expansion(test_questions: List[Dict], 
                             db_conn,
                             embeddings_table: str,
                             embedding_model: str) -> Dict:
    """
    Compare baseline (single query) vs expansion (multi-query).
    
    Focus: Recall improvement (finding more relevant chunks)

    For each question the top 5 chunks are retrieved twice: once with the
    question alone and once through retrieve_with_query_expansion (configured
    by NUM_EXPANSIONS, TOP_K_PER_QUERY and TOP_K_FINAL). Both result lists are
    scored with precision@5, recall@5, MRR and NDCG@5.
    
    Args:
        test_questions: List of question dicts with 'question' and 'relevant_chunk_ids'
        db_conn: PostgreSQL connection
        embeddings_table: Name of embeddings table
        embedding_model: Model name for embeddings
        
    Returns:
        Dict with keys:
            'baseline' / 'expanded': metric name -> mean over all questions
                (all 0.0 when test_questions is empty)
            'improvements_pct': relative change of the expanded mean versus
                the baseline mean, in percent; 0.0 when the baseline mean is 0
                because a relative change is undefined there
            'per_query': per-question metrics and absolute recall/NDCG deltas
            'num_queries': number of questions evaluated
    """
    metric_names = ('precision@5', 'recall@5', 'mrr', 'ndcg@5')

    def score(retrieved_ids, relevant_ids):
        """Score one ranked id list against the ground truth."""
        return {
            'precision@5': precision_at_k(retrieved_ids, relevant_ids, k=5),
            'recall@5': recall_at_k(retrieved_ids, relevant_ids, k=5),
            'mrr': mean_reciprocal_rank(retrieved_ids, relevant_ids),
            'ndcg@5': ndcg_at_k(retrieved_ids, relevant_ids, k=5)
        }

    def aggregate(results):
        """Average each metric over all questions; zeros when there are none."""
        if not results:
            return {metric: 0.0 for metric in metric_names}
        return {
            metric: sum(r[metric] for r in results) / len(results)
            for metric in metric_names
        }
    
    baseline_results = []
    expanded_results = []
    query_details = []
    
    print(f"\nEvaluating query expansion on {len(test_questions)} test questions...")
    print("-" * 70)
    
    for q_idx, q in enumerate(test_questions, 1):
        query = q['question']
        relevant_ids = q['relevant_chunk_ids']
        if relevant_ids is None:
            # A missing ground-truth list is treated as "no relevant chunks".
            relevant_ids = []
        
        # Baseline: single query vector retrieval
        query_emb_response = ollama.embed(model=embedding_model, input=query)
        query_emb = query_emb_response['embeddings'][0]
        
        with db_conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            cur.execute(f'''
                SELECT 
                    id,
                    1 - (embedding <=> %s::vector) as similarity
                FROM {embeddings_table}
                ORDER BY embedding <=> %s::vector
                LIMIT 5
            ''', (query_emb, query_emb))
            
            baseline_chunks = cur.fetchall()
            baseline_ids = [chunk['id'] for chunk in baseline_chunks]
        
        # With expansion
        expanded_chunks = retrieve_with_query_expansion(
            query, 
            db_conn,
            embeddings_table,
            embedding_model, 
            num_expansions=NUM_EXPANSIONS, 
            top_k_per_query=TOP_K_PER_QUERY,
            top_k_final=TOP_K_FINAL,
            verbose=False
        )
        expanded_ids = [chunk_id for _, _, chunk_id in expanded_chunks]
        
        # Compute metrics
        baseline_metrics = score(baseline_ids, relevant_ids)
        expanded_metrics = score(expanded_ids, relevant_ids)
        
        baseline_results.append(baseline_metrics)
        expanded_results.append(expanded_metrics)
        
        # Track per-query details
        query_details.append({
            'question': query,
            'relevant_count': len(relevant_ids),
            'baseline': baseline_metrics,
            'expanded': expanded_metrics,
            'improvement_recall': expanded_metrics['recall@5'] - baseline_metrics['recall@5'],
            'improvement_ndcg': expanded_metrics['ndcg@5'] - baseline_metrics['ndcg@5']
        })
        
        # Progress output
        if q_idx % 5 == 0 or q_idx == len(test_questions):
            print(f"  [{q_idx}/{len(test_questions)}] Evaluated")
    
    # Aggregate metrics
    baseline_agg = aggregate(baseline_results)
    expanded_agg = aggregate(expanded_results)
    
    # Compute improvements
    improvements = {}
    for metric in metric_names:
        baseline_val = baseline_agg[metric]
        expanded_val = expanded_agg[metric]
        
        if baseline_val > 0:
            improvements[metric] = ((expanded_val - baseline_val) / baseline_val * 100)
        else:
            improvements[metric] = 0.0
    
    return {
        'baseline': baseline_agg,
        'expanded': expanded_agg,
        'improvements_pct': improvements,
        'per_query': query_details,
        'num_queries': len(test_questions)
    }


# Run evaluation
print("\n" + "="*70)
print("EVALUATION: BASELINE VS QUERY EXPANSION")
print("="*70)

eval_results = evaluate_query_expansion(
    ground_truth_questions,
    db_connection,
    embeddings_table_name,
    EMBEDDING_MODEL
)

# Display results
# The summary table shows mean metrics and the relative change (in percent of
# the baseline mean) reported under 'improvements_pct'.
print("\n" + "="*70)
print("RESULTS SUMMARY")
print("="*70)

print(f"\nQueries evaluated: {eval_results['num_queries']}")
print(f"\n{'Metric':<20} {'Baseline':<15} {'Expanded':<15} {'Improvement':<15}")
print("-" * 65)

for metric in ['precision@5', 'recall@5', 'ndcg@5', 'mrr']:
    baseline_val = eval_results['baseline'][metric]
    expanded_val = eval_results['expanded'][metric]
    improvement = eval_results['improvements_pct'][metric]
    
    improvement_str = f"+{improvement:.1f}%" if improvement >= 0 else f"{improvement:.1f}%"
    print(f"{metric:<20} {baseline_val:<15.4f} {expanded_val:<15.4f} {improvement_str:<15}")

# Detailed per-query analysis
# Per-query figures below are absolute Recall@5 differences (expanded minus
# baseline) scaled by 100, i.e. percentage points, not relative change.
print("\n" + "="*70)
print("ANALYSIS: WHICH QUERIES BENEFIT MOST FROM EXPANSION?")
print("="*70)

# Sort by recall improvement
top_improvements = sorted(
    eval_results['per_query'],
    key=lambda x: x['improvement_recall'],
    reverse=True
)

print("\nTop 5 questions with highest Recall@5 improvement:")
for i, q in enumerate(top_improvements[:5], 1):
    baseline_recall = q['baseline']['recall@5']
    expanded_recall = q['expanded']['recall@5']
    improvement = q['improvement_recall'] * 100
    
    print(f"\n  [{i}] Improvement: {improvement:+.1f}% (Recall@5: {baseline_recall:.2f} → {expanded_recall:.2f})")
    print(f"      Q: {q['question'][:80]}...")
    print(f"      Relevant chunks to find: {q['relevant_count']}")

# Questions that don't benefit
# "No improvement" covers two different outcomes, so report them separately:
# an unchanged recall and a recall that got worse with expansion.
no_benefit = [q for q in eval_results['per_query'] if q['improvement_recall'] <= 0]
regressed = [q for q in no_benefit if q['improvement_recall'] < 0]
if no_benefit:
    print(f"\n  Queries with no Recall improvement: {len(no_benefit)}")
    print(f"  (unchanged: {len(no_benefit) - len(regressed)}, worse with expansion: {len(regressed)})")

# Distribution statistics
recall_improvements = [q['improvement_recall'] * 100 for q in eval_results['per_query']]
print(f"\nRecall@5 Improvement Statistics:")
if recall_improvements:
    print(f"  Mean: {np.mean(recall_improvements):+.1f}%")
    print(f"  Median: {np.median(recall_improvements):+.1f}%")
    print(f"  Min: {np.min(recall_improvements):+.1f}%")
    print(f"  Max: {np.max(recall_improvements):+.1f}%")
    print(f"  Std Dev: {np.std(recall_improvements):.1f}%")
else:
    # np.min / np.max raise on an empty sequence, so skip the statistics.
    print("  No queries evaluated; statistics unavailable")

## Track Experiment

In [None]:
# Experiment tracking utilities (inline from foundation/00)

def compute_config_hash(config_dict: Dict) -> str:
    """Return a short, order-independent fingerprint of a configuration.

    The dict is serialized to JSON with sorted keys, hashed with SHA-256 and
    the first 12 hex characters of the digest are returned.
    """
    canonical_json = json.dumps(config_dict, sort_keys=True)
    hasher = hashlib.sha256()
    hasher.update(canonical_json.encode())
    full_digest = hasher.hexdigest()
    return full_digest[:12]


def start_experiment(db_conn, experiment_name: str, 
                    embedding_model_alias: str = None,
                    config: Dict = None,
                    techniques: List[str] = None,
                    notes: str = None) -> int:
    """Insert a new 'running' row into experiments and return its id.

    Args:
        db_conn: PostgreSQL connection; the insert is committed on it.
        experiment_name: Name stored with the experiment.
        embedding_model_alias: Alias of the embedding model in use.
        config: Configuration dict; stored as JSON together with its hash.
            None is treated as an empty dict.
        techniques: Technique tags; None is treated as an empty list.
        notes: Optional free-text notes.

    Returns:
        The id of the newly inserted experiment row.
    """
    config = {} if config is None else config
    techniques = [] if techniques is None else techniques
    
    row_values = (
        experiment_name,
        'advanced-techniques/06-query-expansion.ipynb',
        embedding_model_alias,
        compute_config_hash(config),
        json.dumps(config),
        techniques,
        notes
    )
    
    with db_conn.cursor() as cur:
        cur.execute('''
            INSERT INTO experiments (
                experiment_name, notebook_path, embedding_model_alias,
                config_hash, config_json, techniques_applied, notes, status
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, 'running')
            RETURNING id
        ''', row_values)
        new_row = cur.fetchone()
        exp_id = new_row[0]
    db_conn.commit()
    print(f"✓ Started experiment #{exp_id}: {experiment_name}")
    return exp_id


def save_metrics(db_conn, experiment_id: int, metrics_dict: Dict,
                export_to_file: bool = True,
                export_dir: str = 'data/experiment_results') -> Tuple[bool, str]:
    """Persist experiment metrics to evaluation_results and, optionally, to a JSON file.

    Each entry of metrics_dict is either a plain number or a dict with an
    optional 'value' (default 0.0) and optional 'details' (default {}).

    Args:
        db_conn: PostgreSQL connection; inserts are committed on it.
        experiment_id: Experiment the metrics belong to.
        metrics_dict: Metric name -> number or {'value': ..., 'details': ...}.
        export_to_file: Also write the metrics to a timestamped JSON file.
        export_dir: Directory for the JSON export (created if missing).

    Returns:
        (True, message) on success; on any exception the message is printed,
        the connection is rolled back and (False, message) is returned.
    """
    import os
    
    try:
        with db_conn.cursor() as cur:
            for name, payload in metrics_dict.items():
                # Entries may be bare numbers or {'value': ..., 'details': ...}
                if isinstance(payload, dict):
                    raw_value = payload.get('value', 0.0)
                    details = payload.get('details', {})
                else:
                    raw_value, details = payload, {}
                
                numeric_value = float(raw_value)
                details_json = json.dumps(details) if details else '{}'
                
                cur.execute('''
                    INSERT INTO evaluation_results (
                        experiment_id, metric_name, metric_value, metric_details_json
                    )
                    VALUES (%s, %s, %s, %s)
                ''', (experiment_id, name, numeric_value, details_json))
        db_conn.commit()
        
        # Optional JSON export next to the database record
        file_path = None
        if export_to_file:
            os.makedirs(export_dir, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            file_path = os.path.join(export_dir, f'experiment_{experiment_id}_{timestamp}.json')
            export_payload = {
                'experiment_id': experiment_id,
                'timestamp': timestamp,
                'metrics': metrics_dict
            }
            with open(file_path, 'w') as f:
                json.dump(export_payload, f, indent=2)
        
        msg = f"✓ Saved {len(metrics_dict)} metrics for experiment #{experiment_id}"
        if file_path:
            msg += f" to {file_path}"
        print(msg)
        return True, msg
    except Exception as e:
        msg = f"✗ Failed to save metrics: {e}"
        print(msg)
        db_conn.rollback()
        return False, msg


def complete_experiment(db_conn, experiment_id: int, 
                       status: str = 'completed',
                       notes: str = None) -> bool:
    """Set an experiment's final status and completion timestamp.

    Args:
        db_conn: PostgreSQL connection; the update is committed (or rolled
            back on failure) on this connection.
        experiment_id: Id of the experiments row to update.
        status: Status value to store.
        notes: Replacement notes; the notes column is left untouched when
            this is None or empty.

    Returns:
        True if the update was committed; False if any exception occurred
        (the error is printed and the transaction rolled back).
    """
    try:
        with db_conn.cursor() as cur:
            if notes:
                cur.execute('''
                    UPDATE experiments
                    SET status = %s, notes = %s, completed_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                ''', [status, notes, experiment_id])
            else:
                cur.execute('''
                    UPDATE experiments
                    SET status = %s, completed_at = CURRENT_TIMESTAMP
                    WHERE id = %s
                ''', [status, experiment_id])
        # Commit on the connection that was passed in, not on a module global,
        # so the function works with any connection the caller supplies.
        db_conn.commit()
        print(f"✓ Experiment #{experiment_id} marked as {status}")
        return True
    except Exception as e:
        print(f"✗ Failed to complete experiment: {e}")
        db_conn.rollback()
        return False


# Start experiment tracking
print("\n" + "="*70)
print("TRACKING EXPERIMENT")
print("="*70)

# Prepare experiment configuration
# This dict is passed to start_experiment, which stores it as JSON together
# with a hash of its contents.
config_dict = {
    'embedding_model_alias': EMBEDDING_MODEL_ALIAS,
    'embedding_model': EMBEDDING_MODEL,
    'llm_model': LLM_MODEL,
    'num_expansions': NUM_EXPANSIONS,
    'top_k_per_query': TOP_K_PER_QUERY,
    'top_k_final': TOP_K_FINAL,
    'num_test_queries': len(ground_truth_questions),
}

print("\nExperiment Configuration:")
for key, value in config_dict.items():
    print(f"  {key}: {value}")

# Start experiment in database
# NOTE(review): in this notebook's cell order the evaluation has already run
# by the time the experiment row is created, so the 'running' status only
# covers the metric-saving steps that follow.
experiment_id = start_experiment(
    db_connection,
    experiment_name=EXPERIMENT_NAME,
    embedding_model_alias=EMBEDDING_MODEL_ALIAS,
    config=config_dict,
    techniques=TECHNIQUES_APPLIED,
    notes=f"Query expansion evaluation on {len(ground_truth_questions)} test questions. "
          f"Expects +15-30% Recall improvement over baseline."
)

# Prepare metrics for storage
# Flatten the evaluation output into one name -> value mapping; the prefixes
# keep baseline, expanded and improvement figures apart.
metrics_to_store = {}

# Store baseline metrics
for metric_name, value in eval_results['baseline'].items():
    metrics_to_store[f'baseline_{metric_name}'] = float(value)

# Store expanded metrics
for metric_name, value in eval_results['expanded'].items():
    metrics_to_store[f'expanded_{metric_name}'] = float(value)

# Store improvements
# These are the relative changes (percent of the baseline mean).
for metric_name, value in eval_results['improvements_pct'].items():
    metrics_to_store[f'improvement_pct_{metric_name}'] = float(value)

# Store test count
metrics_to_store['num_test_queries'] = len(ground_truth_questions)

# Store per-query details as JSON
# Per-query values are absolute metric differences (expanded minus baseline)
# multiplied by 100.
query_improvements = [
    {
        'question': q['question'],
        'recall_improvement_pct': q['improvement_recall'] * 100,
        'ndcg_improvement_pct': q['improvement_ndcg'] * 100,
    }
    for q in eval_results['per_query']
]

# save_metrics accepts a dict with 'value' and 'details'; the value is a
# placeholder here and the per-query list is carried in the details.
metrics_to_store['per_query_improvements'] = {
    'value': 0.0,
    'details': {'improvements': query_improvements}
}

# Save metrics to database and file
print("\nSaving metrics...")
success, msg = save_metrics(db_connection, experiment_id, metrics_to_store, export_to_file=True)

# Complete the experiment
# NOTE(review): `success` is not checked, so the experiment is marked
# 'completed' even when save_metrics reported a failure — confirm whether a
# different status should be recorded in that case.
print("\nMarking experiment as complete...")
complete_experiment(
    db_connection,
    experiment_id,
    status='completed',
    notes=f"Evaluation complete. Recall@5 improvement: {eval_results['improvements_pct']['recall@5']:.1f}%"
)

# Final summary
# Printed unconditionally; it does not reflect whether saving the metrics or
# completing the experiment succeeded.
print("\n" + "="*70)
print("EXPERIMENT COMPLETE")
print("="*70)

print(f"\nExperiment #{experiment_id}: {EXPERIMENT_NAME}")
print(f"Status: Completed")
print(f"Timestamp: {datetime.now().isoformat()}")

# Relative changes of the mean metrics (expanded vs. baseline), in percent.
print(f"\nKey Results:")
print(f"  Precision@5 improvement:  {eval_results['improvements_pct']['precision@5']:+.1f}%")
print(f"  Recall@5 improvement:     {eval_results['improvements_pct']['recall@5']:+.1f}%")
print(f"  NDCG@5 improvement:       {eval_results['improvements_pct']['ndcg@5']:+.1f}%")
print(f"  MRR improvement:          {eval_results['improvements_pct']['mrr']:+.1f}%")

# The export path shown is the default export_dir of save_metrics.
print(f"\nMetrics Stored:")
print(f"  {len(metrics_to_store)} metrics stored to database")
print(f"  Results exported to data/experiment_results/")

print(f"\nNext Steps:")
print(f"  1. Review the detailed per-query improvements above")
print(f"  2. Compare with other techniques using evaluation-lab/04")
print(f"  3. Combine query expansion with reranking for further improvements")
print(f"  4. Evaluate on more diverse question types in evaluation-lab/01")