## Prerequisites

1. ✅ foundation/00-setup-postgres-schema.ipynb
2. ✅ foundation/02-rag-postgresql-persistent.ipynb
3. ✅ evaluation-lab/01-create-ground-truth-human-in-loop.ipynb

## Configuration

In [None]:
# Alias looked up in embedding_registry; the same string is also passed to
# Ollama as the model used to embed each query.
EMBEDDING_MODEL_ALIAS = "all-minilm-l6-v2"
# Number of chunks retrieved per query (each becomes one citable source).
TOP_K = 5
# Recorded in the experiment configuration. NOTE(review): the confidence
# buckets in the evaluation report use their own literal 0.5 / 0.7 cut-offs
# rather than this constant.
CONFIDENCE_THRESHOLD = 0.5

# Experiment metadata; presumably handed to start_experiment() later in the
# notebook — TODO confirm, that call is outside the part of the file shown here.
EXPERIMENT_NAME = "citation-tracking-transparency"
TECHNIQUES_APPLIED = ["vector_retrieval", "citation_tracking", "confidence_scoring"]

## Load Embeddings from Registry

In [None]:
import psycopg2
import psycopg2.extras
import ollama
import json
import pandas as pd
import numpy as np
import hashlib
import re
from datetime import datetime
from typing import List, Dict, Tuple, Optional
import os

# Connection settings for the RAG database; reused by every helper below
# that needs to open its own connection.
POSTGRES_CONFIG = dict(
    host='localhost',
    port=5432,
    database='rag_db',
    user='postgres',
    password='postgres',
)

# Open the notebook-wide connection. A connection failure is reported and
# then re-raised so the cell stops instead of continuing without a database.
try:
    db_connection = psycopg2.connect(**POSTGRES_CONFIG)
except psycopg2.OperationalError as e:
    print(f"✗ Failed to connect to PostgreSQL: {e}")
    raise
else:
    print("✓ Connected to PostgreSQL")

# ============================================================================
# PART 1: LOAD EMBEDDINGS FROM REGISTRY
# ============================================================================

def list_available_embeddings(db_connection) -> pd.DataFrame:
    """Return the contents of embedding_registry as a DataFrame, newest first.

    Args:
        db_connection: Open database connection handed to ``pd.read_sql``.

    Returns:
        One row per registered embedding model with its alias, name,
        dimension, embedding count, chunk source/size settings and the
        created/last-accessed timestamps, ordered by ``created_at`` descending.
    """
    registry_sql = '''
        SELECT
            model_alias,
            model_name,
            dimension,
            embedding_count,
            chunk_source_dataset,
            chunk_size_config,
            created_at,
            last_accessed
        FROM embedding_registry
        ORDER BY created_at DESC
    '''
    registry_frame = pd.read_sql(registry_sql, db_connection)
    return registry_frame


class PostgreSQLVectorDB:
    """Query handle on one pgvector embeddings table.

    Each instance opens its own psycopg2 connection and only runs SELECT
    statements against ``table_name``. It can be used as a context manager
    so the connection is closed deterministically::

        with PostgreSQLVectorDB(config, 'embeddings_x') as db:
            db.get_chunk_count()
    """

    def __init__(self, config, table_name, preserve_existing=True):
        """Connect to PostgreSQL and remember which table to query.

        Args:
            config: Mapping with 'host', 'port', 'database', 'user' and
                'password' entries.
            table_name: Name of the embeddings table. It is interpolated into
                SQL text unquoted, so it must come from trusted code (never
                from user input) and must be a valid unquoted identifier.
            preserve_existing: Accepted for call compatibility; this class
                never creates or drops anything, so the flag is not used.

        Raises:
            psycopg2.OperationalError: If the connection cannot be opened.
        """
        self.config = config
        self.table_name = table_name
        self.conn = psycopg2.connect(
            host=config['host'],
            port=config['port'],
            database=config['database'],
            user=config['user'],
            password=config['password']
        )
        # Only the connection is established here; the table itself is not
        # checked until the first query runs.
        print(f'✓ Connected to table: {table_name}')

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection on leaving the ``with`` block; never suppresses errors."""
        self.close()
        return False

    def _reset_transaction(self):
        """Roll back after a failed statement so later queries can still run."""
        # psycopg2 keeps a connection in an aborted transaction after an
        # error; every further statement fails until rollback() is called.
        if self.conn is not None and not self.conn.closed:
            self.conn.rollback()

    def get_chunk_count(self):
        """Return the number of rows (stored embeddings) in the table.

        Raises:
            psycopg2.Error: If the query fails; the transaction is rolled
                back first so the instance remains usable.
        """
        cur = self.conn.cursor()
        try:
            cur.execute(f'SELECT COUNT(*) FROM {self.table_name}')
            return cur.fetchone()[0]
        except psycopg2.Error:
            self._reset_transaction()
            raise
        finally:
            cur.close()

    def similarity_search(self, query_embedding, top_n=5):
        """Retrieve the chunks closest to *query_embedding* using pgvector.

        Rows are ordered by the ``<=>`` distance between the stored embedding
        and the query; the reported similarity is ``1 - distance``.

        Args:
            query_embedding: Query vector; it is cast to ``vector`` in SQL.
            top_n: Maximum number of rows to return.

        Returns:
            List of tuples: (chunk_text, similarity_score, chunk_id)

        Raises:
            psycopg2.Error: If the query fails; the transaction is rolled
                back first so the instance remains usable.
        """
        cur = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
        try:
            cur.execute(f'''
                SELECT id,
                       content as chunk_text,
                       1 - (embedding <=> %s::vector) as similarity
                FROM {self.table_name}
                ORDER BY embedding <=> %s::vector
                LIMIT %s
            ''', (query_embedding, query_embedding, top_n))

            rows = cur.fetchall()
        except psycopg2.Error:
            self._reset_transaction()
            raise
        finally:
            cur.close()
        return [(row['chunk_text'], row['similarity'], row['id']) for row in rows]

    def close(self):
        """Close the underlying connection if one was opened."""
        if self.conn:
            self.conn.close()


def load_or_generate(db_connection, embedding_model_alias, preserve_existing=True):
    """Load embeddings registered under an alias, or print how to create them.

    Looks *embedding_model_alias* up in ``embedding_registry``. When an entry
    exists and ``preserve_existing`` is true, a PostgreSQLVectorDB is opened
    (using the module-level POSTGRES_CONFIG) on the table
    ``embeddings_<alias>``, with dots in the alias replaced by underscores.
    Every failure is reported with printed hints instead of being raised.

    Args:
        db_connection: Open psycopg2 connection used for the registry lookup.
        embedding_model_alias: Registry alias of the embedding model.
        preserve_existing: When True, existing embeddings are loaded
            automatically. When False nothing is loaded and None is returned.

    Returns:
        A connected PostgreSQLVectorDB, or None when the registry cannot be
        queried, no entry exists, loading fails, or ``preserve_existing`` is
        False.
    """

    print(f"\n{'='*70}")
    print(f"Checking for embeddings: '{embedding_model_alias}'...")
    print(f"{'='*70}\n")

    try:
        with db_connection.cursor() as cur:
            cur.execute('''
                SELECT id, dimension, embedding_count, created_at, metadata_json
                FROM embedding_registry
                WHERE model_alias = %s
            ''', (embedding_model_alias,))
            registry_entry = cur.fetchone()
    except Exception as e:
        print(f"Could not query registry: {e}")
        print("Make sure foundation/00-setup-postgres-schema.ipynb has been run.")
        # A failed statement leaves the psycopg2 transaction aborted, and
        # every later query on this shared connection would fail until it is
        # rolled back.
        try:
            db_connection.rollback()
        except Exception:
            pass  # connection unusable (or not a connection): nothing to reset
        return None

    # Case B: No embeddings found
    if not registry_entry:
        print(f"✗ NO EMBEDDINGS FOUND")
        print(f"  Model: {embedding_model_alias}")
        print(f"\nTo create embeddings, run:")
        print(f"  foundation/02-rag-postgresql-persistent.ipynb")
        print(f"\nThen come back and re-run this cell.\n")
        return None

    # Case A: Embeddings exist
    reg_id, dimension, embedding_count, created_at, metadata_json = registry_entry

    print(f"✓ FOUND EXISTING EMBEDDINGS")
    print(f"  Model:      {embedding_model_alias}")
    print(f"  Count:      {embedding_count:,} embeddings")
    print(f"  Dimension:  {dimension}")
    print(f"  Created:    {created_at}")
    print(f"\n  TIME SAVINGS:")
    print(f"    Loading:       <1 second")
    print(f"    Regenerating:  ~50+ minutes")
    print(f"    ➜ You save 50+ minutes by loading!\n")

    if not preserve_existing:
        # Previously this path fell off the end of the function without a
        # word; make the outcome explicit for the caller and the reader.
        print("Existing embeddings were not loaded (preserve_existing=False).\n")
        return None

    # Auto-load (for scripts/notebooks)
    print("Loading existing embeddings...\n")

    try:
        # NOTE(review): only '.' is replaced. An alias containing '-' yields a
        # name that is not a valid unquoted identifier, and PostgreSQLVectorDB
        # interpolates it unquoted — confirm against how foundation/02 names
        # its tables.
        table_name = f'embeddings_{embedding_model_alias.replace(".", "_")}'

        db_instance = PostgreSQLVectorDB(
            config=POSTGRES_CONFIG,
            table_name=table_name,
            preserve_existing=True
        )

        count = db_instance.get_chunk_count()
        print(f"✓ LOADED SUCCESSFULLY")
        print(f"  Embeddings: {count:,}")
        print(f"  Table: {table_name}")
        print(f"  Status: Ready for retrieval\n")

        return db_instance

    except Exception as e:
        print(f"\n✗ Error loading embeddings: {e}")
        print(f"\nTroubleshooting:")
        print(f"  1. Verify PostgreSQL is running")
        print(f"  2. Check POSTGRES_CONFIG settings")
        print(f"  3. Run foundation/02 to generate embeddings first")
        return None


# Step 1 lists what the registry knows about; step 2 loads the configured model.
print("Step 1: Discovering available embeddings...\n")
available = list_available_embeddings(db_connection)

if not available.empty:
    print("Available embeddings:")
    print(available.to_string(index=False))
    print()
else:
    print("⚠️  No embeddings found in registry yet.")
    print("Run foundation/02-rag-postgresql-persistent.ipynb first.\n")

# Load the embeddings for EMBEDDING_MODEL_ALIAS (auto-load when they exist)
print("\nStep 2: Loading embeddings using load-or-generate pattern...\n")
embeddings_db = load_or_generate(
    db_connection=db_connection,
    embedding_model_alias=EMBEDDING_MODEL_ALIAS,
    preserve_existing=True,
)

if not embeddings_db:
    print("⚠️  Could not load embeddings. See instructions above.")
    # Later cells test `embeddings_db` for truthiness before using it.
    embeddings_db = None
else:
    print("✓ Success! Embeddings loaded and ready for retrieval.")

## Implement Citation Tracking

In [None]:
# ============================================================================
# PART 2: CITATION TRACKING IMPLEMENTATION
# ============================================================================

def retrieve_with_citations(query: str,
                           embeddings_db: PostgreSQLVectorDB,
                           embedding_model: str,
                           top_k: int = 5) -> List[Dict]:
    """
    Embed the query, fetch the closest chunks and tag each one for citation.

    Every retrieved chunk receives a 1-based rank, a matching citation marker
    ("[1]", "[2]", ...) and a human-readable source label derived from the
    chunk text itself.

    Args:
        query: User question
        embeddings_db: PostgreSQLVectorDB instance to search
        embedding_model: Ollama model name used to embed the query
        top_k: Number of results to retrieve

    Returns:
        List of dicts, one per retrieved chunk, in retrieval order:
        {
            'chunk_id': id of the chunk row,
            'chunk_text': str,
            'similarity_score': float,
            'citation_id': str,  # e.g., "[1]", "[2]"
            'source': str,       # e.g., "Article: Einstein"
            'rank': int          # Position in retrieval results
        }
    """
    query_vector = ollama.embed(model=embedding_model, input=query)['embeddings'][0]
    hits = embeddings_db.similarity_search(query_vector, top_n=top_k)

    tagged = []
    for rank, (chunk_text, score, chunk_id) in enumerate(hits, start=1):
        label = _citation_source_label(chunk_text)
        tagged.append({
            'chunk_id': chunk_id,
            'chunk_text': chunk_text,
            'similarity_score': float(score),
            'citation_id': f"[{rank}]",
            'source': label,
            'rank': rank
        })
    return tagged


def _citation_source_label(chunk_text: str) -> str:
    """Derive a short source label from the text of a chunk."""
    # Chunks that open with a heading line ("Article: Title" / "Section: ...")
    # use that whole first line as the label.
    if chunk_text.startswith(("Article:", "Section:")):
        return chunk_text.split('\n')[0]

    # Otherwise fall back to the first non-blank line shorter than 100
    # characters, capped at 80 characters.
    for raw_line in chunk_text.split('\n'):
        stripped = raw_line.strip()
        if stripped and len(stripped) < 100:
            return stripped[:80]

    return "Unknown Source"


def generate_with_citations(query: str,
                           citations: List[Dict],
                           llm_model: str = 'llama3.2:1b') -> Dict:
    """
    Ask the LLM to answer from the retrieved sources, citing them inline.

    Each citation contributes its marker, its source label and a preview of
    at most 500 characters of chunk text to the prompt. The model is told to
    answer only from those sources and to mark every fact with [1], [2], ...

    Args:
        query: User question
        citations: List of citation dicts from retrieve_with_citations()
        llm_model: Ollama model to use for generation

    Returns:
        dict with:
        {
            'answer': str,           # Model output, expected to contain [n] markers
            'confidence': float,     # Lowest similarity score among the citations (0.0 if none)
            'sources': List[str],    # "[n] source label" for every citation
            'citations_used': List   # The citation dicts that were passed in
        }
    """
    def _preview(text: str) -> str:
        # Long chunks are cut at 500 characters, then trimmed back to the
        # last space so the preview does not end mid-word.
        head = text[:500]
        if len(text) > 500:
            head = head.rsplit(' ', 1)[0] + '...'
        return head

    context = "\n\n".join(
        f"{c['citation_id']} {c['source']}\n{_preview(c['chunk_text'])}"
        for c in citations
    )

    # The first line deliberately ends with a space before the newline.
    prompt = (
        "Answer the following question using ONLY the provided sources. \n"
        "You MUST include citation markers [1], [2], [3], etc. in your answer "
        "to show which source each fact comes from.\n"
        "Reference the sources by their citation ID in square brackets.\n"
        "\n"
        f"Question: {query}\n"
        "\n"
        "Sources (use [1], [2], [3], etc. to cite):\n"
        f"{context}\n"
        "\n"
        "Answer (include citation markers for each fact):"
    )

    reply = ollama.chat(
        model=llm_model,
        messages=[{'role': 'user', 'content': prompt}]
    )
    answer_text = reply['message']['content']

    # Confidence is the weakest retrieval: the minimum similarity score.
    weakest_score = min((c['similarity_score'] for c in citations), default=0.0)

    source_lines = [f"{c['citation_id']} {c['source']}" for c in citations]

    return {
        'answer': answer_text,
        'confidence': weakest_score,
        'sources': source_lines,
        'citations_used': citations
    }


def validate_citations(answer: str, citations: List[Dict]) -> Dict:
    """
    Check the [n] markers in an answer against the sources that were provided.

    Markers are matched with the pattern ``[<digits>]``. A marker is valid
    when some citation carries exactly that ``citation_id``; any other marker
    is reported as invalid, which helps detect hallucinated citations.

    Args:
        answer: Generated answer text
        citations: List of citation dicts, each with a 'citation_id' such as "[1]"

    Returns:
        dict with:
        {
            'valid_citations': List[str],     # Markers in the answer that match a source
            'invalid_citations': List[str],   # Markers in the answer with no matching source
            'unused_sources': List[str],      # Source markers never used in the answer
            'citation_accuracy': float        # valid / distinct markers used (0.0 when none used)
        }
        All three lists are in ascending numeric order ("[2]" before "[10]").
    """
    def _numeric_key(marker: str):
        # Sort "[10]" after "[2]"; plain string sorting would put it first.
        # Non-numeric ids (possible only in caller-supplied citations) go last.
        digits = marker.strip('[]')
        if digits.isdecimal():
            return (0, int(digits), marker)
        return (1, 0, marker)

    # Distinct markers used in the answer, e.g. {"[1]", "[3]"}
    cited_ids = set(re.findall(r'\[\d+\]', answer))

    # Markers the answer was allowed to use
    available_ids = {c['citation_id'] for c in citations}

    valid = cited_ids & available_ids      # used and provided
    invalid = cited_ids - available_ids    # used but never provided
    unused = available_ids - cited_ids     # provided but never used

    citation_accuracy = len(valid) / len(cited_ids) if cited_ids else 0.0

    return {
        'valid_citations': sorted(valid, key=_numeric_key),
        'invalid_citations': sorted(invalid, key=_numeric_key),
        'unused_sources': sorted(unused, key=_numeric_key),
        'citation_accuracy': citation_accuracy
    }


# Pull the curated evaluation questions (only those rated 'good') from the database
print("\nLoading ground truth test set...")
ground_truth_questions = []

with db_connection.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
    cur.execute('''
        SELECT
            id,
            question,
            relevant_chunk_ids,
            quality_rating,
            source_type
        FROM evaluation_groundtruth
        WHERE quality_rating = 'good'
        ORDER BY id
    ''')

    # A NULL relevant_chunk_ids column is normalised to an empty list.
    ground_truth_questions.extend(
        {
            'id': row['id'],
            'question': row['question'],
            'relevant_chunk_ids': row['relevant_chunk_ids'] or [],
            'quality_rating': row['quality_rating'],
            'source_type': row['source_type']
        }
        for row in cur.fetchall()
    )

print(f"✓ Loaded {len(ground_truth_questions)} ground truth questions\n")

if ground_truth_questions:
    print(f"Sample question: {ground_truth_questions[0]['question'][:80]}...")
    print(f"Relevant chunks: {ground_truth_questions[0]['relevant_chunk_ids'][:3]}\n")

# Walk the first question through retrieve -> generate -> validate
if embeddings_db and ground_truth_questions:
    print("=" * 70 + "\nDEMONSTRATING CITATION TRACKING\n" + "=" * 70)

    test_question = ground_truth_questions[0]['question']
    print(f"\nQuery: {test_question}\n")

    # Step 1: retrieval, each chunk tagged with a citation marker
    print("Step 1: Retrieving chunks with citation data...")
    citations = retrieve_with_citations(
        test_question, embeddings_db, EMBEDDING_MODEL_ALIAS, top_k=TOP_K
    )

    print(f"✓ Retrieved {len(citations)} chunks with citation IDs:")
    for c in citations:
        print(f"  {c['citation_id']} (score: {c['similarity_score']:.4f}) {c['source'][:60]}")

    # Step 2: generation constrained to the retrieved sources
    print("\nStep 2: Generating answer with inline citations...")
    result = generate_with_citations(test_question, citations)

    print(f"\nGenerated Answer:\n{result['answer']}")
    print(f"\nConfidence (min similarity): {result['confidence']:.4f}")
    print(f"\nSources:")
    for source in result['sources']:
        print(f"  {source}")

    # Step 3: compare the markers in the answer with the provided sources
    print("\nStep 3: Validating citations...")
    validation = validate_citations(result['answer'], citations)

    print(f"  Valid citations: {validation['valid_citations']}")
    print(f"  Invalid citations: {validation['invalid_citations']}")
    print(f"  Unused sources: {validation['unused_sources']}")
    print(f"  Citation accuracy: {validation['citation_accuracy']:.2%}")
    print()

## Evaluate Citations

In [None]:
# ============================================================================
# PART 3: EVALUATE CITATIONS
# ============================================================================

def evaluate_citation_tracking(test_questions: List[Dict],
                              embeddings_db: 'PostgreSQLVectorDB',
                              embedding_model: str,
                              llm_model: str = 'llama3.2:1b') -> Tuple[Dict, List[Dict]]:
    """
    Evaluate citation quality across a test set.

    For each test question:
    1. Retrieve chunks with citations (the module-level TOP_K chunks per query)
    2. Generate answer with citations
    3. Validate citation accuracy
    4. Aggregate metrics

    A query that raises during any of the three steps is reported with a
    warning and left out of the results; evaluation continues with the next
    query.

    Args:
        test_questions: List of test question dicts, each with a 'question' key
        embeddings_db: PostgreSQLVectorDB instance
        embedding_model: Model for embedding queries
        llm_model: Model for answer generation

    Returns:
        Tuple of (aggregated_metrics, per_query_results)

        aggregated_metrics always has the same keys, whether or not any query
        was evaluated (all values are zero when none was):
        {
            'avg_confidence': float,           # Average min similarity score
            'avg_citation_accuracy': float,    # Average fraction of valid citations
            'avg_sources_used': float,         # Avg number of sources cited
            'avg_invalid_citations': float,    # Avg number of hallucinated citations
            'min_confidence': float,           # Lowest per-query confidence
            'max_confidence': float,           # Highest per-query confidence
            'total_queries_evaluated': int     # Queries that completed all steps
        }

        per_query_results list with per-query metrics
    """
    results = []
    total = len(test_questions)
    # Report progress roughly every 10% of the set (every query for small sets).
    progress_step = max(1, total // 10)

    print(f"\nEvaluating citation tracking on {total} test questions...")
    print("=" * 70)

    for i, q in enumerate(test_questions, 1):
        query = q['question']

        if i % progress_step == 0:
            print(f"Progress: {i}/{total} queries processed...")

        try:
            # Step 1: Retrieve with citations
            citations = retrieve_with_citations(query, embeddings_db, embedding_model, top_k=TOP_K)

            # Step 2: Generate with citations
            generated = generate_with_citations(query, citations, llm_model=llm_model)

            # Step 3: Validate citations
            validation = validate_citations(generated['answer'], citations)

            # Collect metrics
            results.append({
                'query': query[:100],  # Truncate for display
                'confidence': generated['confidence'],
                'citation_accuracy': validation['citation_accuracy'],
                'sources_used': len(validation['valid_citations']),
                'sources_provided': len(citations),
                'invalid_citations': len(validation['invalid_citations']),
                'unused_sources': len(validation['unused_sources']),
                'answer_preview': generated['answer'][:150]
            })

        except Exception as e:
            # Deliberate best-effort: one failing query must not abort the run.
            print(f"\n  Warning: Failed to evaluate query {i}: {e}")
            continue

    print("=" * 70)
    print(f"✓ Evaluation complete ({len(results)} queries)\n")

    # Aggregate metrics
    if not results:
        print("⚠️  No queries evaluated")
        # Same keys as the populated case, so callers can index the result
        # without first checking whether anything was evaluated.
        return {
            'avg_confidence': 0.0,
            'avg_citation_accuracy': 0.0,
            'avg_sources_used': 0.0,
            'avg_invalid_citations': 0.0,
            'min_confidence': 0.0,
            'max_confidence': 0.0,
            'total_queries_evaluated': 0
        }, []

    confidences = [r['confidence'] for r in results]
    metrics = {
        'avg_confidence': np.mean(confidences),
        'avg_citation_accuracy': np.mean([r['citation_accuracy'] for r in results]),
        'avg_sources_used': np.mean([r['sources_used'] for r in results]),
        'avg_invalid_citations': np.mean([r['invalid_citations'] for r in results]),
        'min_confidence': np.min(confidences),
        'max_confidence': np.max(confidences),
        'total_queries_evaluated': len(results)
    }

    return metrics, results


# Run evaluation
print("\n" + "=" * 70)
print("EVALUATING CITATION TRACKING")
print("=" * 70)

if embeddings_db and ground_truth_questions:
    evaluation_metrics, per_query_results = evaluate_citation_tracking(
        ground_truth_questions,
        embeddings_db,
        EMBEDDING_MODEL_ALIAS,
        llm_model='llama3.2:1b'
    )
else:
    print("⚠️  Cannot evaluate: embeddings or test questions not available")
    evaluation_metrics = None
    per_query_results = None

# The report indexes aggregate keys that exist only when at least one query
# was evaluated, and divides by the number of evaluated queries. When the
# evaluation ran but every query failed, say so instead of raising
# KeyError / ZeroDivisionError.
if evaluation_metrics is not None and not per_query_results:
    print("⚠️  Evaluation produced no per-query results; skipping the report")

if per_query_results:
    # Display results
    print("\n" + "=" * 70)
    print("CITATION TRACKING EVALUATION RESULTS")
    print("=" * 70)

    print(f"\nQueries Evaluated: {evaluation_metrics['total_queries_evaluated']}")
    print(f"\nConfidence (Minimum Similarity Score):")
    print(f"  Average:  {evaluation_metrics['avg_confidence']:.4f}")
    print(f"  Min:      {evaluation_metrics['min_confidence']:.4f}")
    print(f"  Max:      {evaluation_metrics['max_confidence']:.4f}")

    print(f"\nCitation Accuracy:")
    print(f"  Average:  {evaluation_metrics['avg_citation_accuracy']:.2%}")
    print(f"  (Fraction of citations that match provided sources)")

    print(f"\nSource Usage:")
    print(f"  Avg sources cited:    {evaluation_metrics['avg_sources_used']:.2f} / {TOP_K}")
    print(f"  Avg unused sources:   {TOP_K - evaluation_metrics['avg_sources_used']:.2f}")

    print(f"\nCitation Validity:")
    print(f"  Avg invalid citations: {evaluation_metrics['avg_invalid_citations']:.2f}")
    print(f"  (Hallucinated citations not in retrieved sources)")

    # Show confidence distribution
    print(f"\n" + "=" * 70)
    print("CONFIDENCE DISTRIBUTION ANALYSIS")
    print("=" * 70)

    confidences = [r['confidence'] for r in per_query_results]

    # Bucket each query by its confidence (minimum similarity score)
    high_confidence = sum(1 for c in confidences if c >= 0.7)
    medium_confidence = sum(1 for c in confidences if 0.5 <= c < 0.7)
    low_confidence = sum(1 for c in confidences if c < 0.5)

    print(f"\nHigh Confidence (>= 0.70):    {high_confidence} queries ({high_confidence/len(confidences)*100:.1f}%)")
    print(f"Medium Confidence (0.50-0.70): {medium_confidence} queries ({medium_confidence/len(confidences)*100:.1f}%)")
    print(f"Low Confidence (< 0.50):      {low_confidence} queries ({low_confidence/len(confidences)*100:.1f}%)")

    # Recommendations
    print(f"\n" + "=" * 70)
    print("TRANSPARENCY & USER TRUST ASSESSMENT")
    print("=" * 70)

    # Accuracy insights
    if evaluation_metrics['avg_citation_accuracy'] > 0.9:
        print(f"\n✓ Excellent: {evaluation_metrics['avg_citation_accuracy']:.0%} of citations are valid")
        print(f"  Users can trust the citations in answers")
    elif evaluation_metrics['avg_citation_accuracy'] > 0.75:
        print(f"\n+ Good: {evaluation_metrics['avg_citation_accuracy']:.0%} of citations are valid")
        print(f"  Most citations are accurate, minor hallucinations detected")
    else:
        print(f"\n- Poor: {evaluation_metrics['avg_citation_accuracy']:.0%} of citations are valid")
        print(f"  Consider adding citation validation to filter hallucinations")

    # Confidence insights
    if evaluation_metrics['avg_confidence'] > 0.7:
        print(f"\n✓ High confidence retrievals ({evaluation_metrics['avg_confidence']:.3f})")
        print(f"  Users can rely on the similarity scores")
    elif evaluation_metrics['avg_confidence'] > 0.5:
        print(f"\n+ Medium confidence ({evaluation_metrics['avg_confidence']:.3f})")
        print(f"  Some weaker retrievals - consider reranking or filtering")
    else:
        print(f"\n- Low confidence ({evaluation_metrics['avg_confidence']:.3f})")
        print(f"  Consider: lower TOP_K, apply reranking, or improve query expansion")

    # Coverage insights: share of the TOP_K retrieved sources that were cited
    avg_coverage = evaluation_metrics['avg_sources_used'] / TOP_K
    if avg_coverage > 0.8:
        print(f"\n✓ High source coverage ({avg_coverage:.0%} of retrieved sources cited)")
        print(f"  Answers use most available information")
    elif avg_coverage > 0.5:
        print(f"\n+ Moderate coverage ({avg_coverage:.0%} of sources cited)")
        print(f"  Some sources unused - check for redundancy")
    else:
        print(f"\n- Low coverage ({avg_coverage:.0%} of sources cited)")
        print(f"  Many unused sources - may indicate retrieval quality issues")

    print(f"\n" + "=" * 70)

## Track Experiment

In [None]:
# ============================================================================
# PART 4: EXPERIMENT TRACKING
# ============================================================================

def compute_config_hash(config_dict: Dict) -> str:
    """Fingerprint a configuration dictionary.

    The dictionary is serialised to JSON with its keys sorted, so two
    configurations with the same content hash identically regardless of
    insertion order. This makes it possible to find experiments that ran
    with identical settings.

    Args:
        config_dict: JSON-serialisable configuration parameters

    Returns:
        The first 12 hex characters of the SHA256 digest of that JSON text
    """
    canonical_json = json.dumps(config_dict, sort_keys=True).encode()
    return hashlib.sha256(canonical_json).hexdigest()[:12]


def start_experiment(db_connection, experiment_name: str,
                     notebook_path: Optional[str] = None,
                     embedding_model_alias: Optional[str] = None,
                     config: Optional[Dict] = None,
                     techniques: Optional[List[str]] = None,
                     notes: Optional[str] = None) -> int:
    """Start a new experiment and return its ID for tracking.

    Inserts one row into ``experiments`` with status 'running' and commits it.

    Args:
        db_connection: PostgreSQL connection
        experiment_name: Human-readable experiment name
        notebook_path: Path to the notebook running this experiment
        embedding_model_alias: Which embedding model is being used
        config: Dict of configuration parameters (defaults to an empty dict);
            stored as JSON together with its compute_config_hash() fingerprint
        techniques: List of techniques being applied (defaults to an empty list)
        notes: Optional notes about the experiment

    Returns:
        Experiment ID for use in save_metrics() and complete_experiment()

    Raises:
        Exception: Whatever the database driver raises. The transaction is
            rolled back before the error propagates, so the shared connection
            stays usable, matching save_metrics() and complete_experiment().
    """
    if config is None:
        config = {}
    if techniques is None:
        techniques = []

    config_hash = compute_config_hash(config)

    cur = db_connection.cursor()
    try:
        cur.execute('''
            INSERT INTO experiments (
                experiment_name, notebook_path, embedding_model_alias,
                config_hash, config_json, techniques_applied, notes, status
            )
            VALUES (%s, %s, %s, %s, %s, %s, %s, 'running')
            RETURNING id
        ''', (
            experiment_name,
            notebook_path,
            embedding_model_alias,
            config_hash,
            json.dumps(config),
            techniques,
            notes
        ))
        exp_id = cur.fetchone()[0]
        db_connection.commit()
    except Exception:
        # Without a rollback the connection stays in an aborted transaction
        # and every later statement on it fails.
        try:
            db_connection.rollback()
        except Exception:
            pass  # connection already unusable; surface the original error
        raise
    finally:
        cur.close()

    print(f"✓ Started experiment #{exp_id}: {experiment_name}")
    return exp_id


def save_metrics(db_connection, experiment_id: int, metrics_dict: Dict,
                 export_to_file: bool = True,
                 export_dir: str = 'data/experiment_results') -> Tuple[bool, str]:
    """Persist experiment metrics to the database and, optionally, a JSON file.

    One ``evaluation_results`` row is inserted per metric and the batch is
    committed. After the commit, the metrics can also be written to a
    timestamped JSON file under *export_dir*.

    Args:
        db_connection: PostgreSQL connection
        experiment_id: ID from start_experiment()
        metrics_dict: Dict of {metric_name: value, ...}; a value may be a
            plain number or a dict with 'value' and optional 'details'
        export_to_file: Whether to also save to filesystem JSON
        export_dir: Directory for JSON exports (created if missing)

    Returns:
        Tuple of (success: bool, message: str). On any exception the error is
        printed, the connection is rolled back and (False, message) is returned.
    """
    try:
        with db_connection.cursor() as cur:
            for name, payload in metrics_dict.items():
                # A metric is either a bare number or {'value': ..., 'details': ...}
                is_nested = isinstance(payload, dict)
                value = payload.get('value', 0.0) if is_nested else payload
                details = payload.get('details', {}) if is_nested else {}

                cur.execute('''
                    INSERT INTO evaluation_results (
                        experiment_id, metric_name, metric_value, metric_details_json
                    )
                    VALUES (%s, %s, %s, %s)
                ''', (
                    experiment_id,
                    name,
                    float(value),
                    json.dumps(details) if details else '{}'
                ))
        db_connection.commit()

        # Optional JSON export; runs after the rows have been committed
        file_path = None
        if export_to_file:
            os.makedirs(export_dir, exist_ok=True)
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            file_path = os.path.join(export_dir, f'experiment_{experiment_id}_{timestamp}.json')
            export_payload = {
                'experiment_id': experiment_id,
                'timestamp': timestamp,
                'metrics': metrics_dict
            }
            with open(file_path, 'w') as f:
                json.dump(export_payload, f, indent=2)

        msg = f"✓ Saved {len(metrics_dict)} metrics for experiment #{experiment_id}"
        if file_path:
            msg += f" to {file_path}"
        print(msg)
        return True, msg
    except Exception as e:
        msg = f"✗ Failed to save metrics: {e}"
        print(msg)
        db_connection.rollback()
        return False, msg


def complete_experiment(db_connection, experiment_id: int,
                        status: str = 'completed',
                        notes: Optional[str] = None) -> bool:
    """Mark an experiment as finished by updating its row in ``experiments``.

    Sets ``status`` and stamps ``completed_at`` with the database's
    ``CURRENT_TIMESTAMP``. The ``notes`` column is overwritten only when a
    non-empty value is supplied.

    Args:
        db_connection: PostgreSQL connection
        experiment_id: ID returned from start_experiment()
        status: 'completed' or 'failed'
        notes: Optional replacement for the notes field; ``None`` or an
            empty string leaves the stored notes untouched

    Returns:
        True if an experiment row was updated and the change committed.
        False if no row matched ``experiment_id`` or the database
        operation raised (the transaction is rolled back in both cases).
    """
    if notes:
        sql = '''
            UPDATE experiments
            SET status = %s, notes = %s, completed_at = CURRENT_TIMESTAMP
            WHERE id = %s
        '''
        params = (status, notes, experiment_id)
    else:
        sql = '''
            UPDATE experiments
            SET status = %s, completed_at = CURRENT_TIMESTAMP
            WHERE id = %s
        '''
        params = (status, experiment_id)

    try:
        with db_connection.cursor() as cur:
            cur.execute(sql, params)
            updated_rows = cur.rowcount

        # An UPDATE that matches nothing is not an error for the database,
        # but it means the experiment was never marked -- report it instead
        # of printing a misleading success message.
        if updated_rows == 0:
            db_connection.rollback()
            print(f"✗ Failed to complete experiment: no experiment with id {experiment_id}")
            return False

        db_connection.commit()
        print(f"✓ Experiment #{experiment_id} marked as {status}")
        return True
    except Exception as e:
        print(f"✗ Failed to complete experiment: {e}")
        # The rollback can itself fail (e.g. the connection is already
        # closed); keep the bool contract rather than raising from here.
        try:
            db_connection.rollback()
        except Exception as rollback_error:
            print(f"✗ Rollback failed: {rollback_error}")
        return False


# ============================================================================
# RUN EXPERIMENT TRACKING
# ============================================================================
# Registers this notebook run as an experiment, stores the aggregated
# citation metrics, marks the experiment completed/failed, and finally
# closes the database connections opened earlier in the notebook.

print("\n" + "=" * 70)
print("TRACKING EXPERIMENT")
print("=" * 70)

# Only track when the evaluation step produced at least one scored query.
if evaluation_metrics and evaluation_metrics.get('total_queries_evaluated', 0) > 0:

    # Prepare configuration
    config_dict = {
        'embedding_model_alias': EMBEDDING_MODEL_ALIAS,
        'top_k': TOP_K,
        'confidence_threshold': CONFIDENCE_THRESHOLD,
        'llm_model': 'llama3.2:1b',
        'num_test_queries': evaluation_metrics['total_queries_evaluated'],
    }

    config_hash = compute_config_hash(config_dict)

    print("\nExperiment Configuration:")
    print(f"  Name: {EXPERIMENT_NAME}")
    print(f"  Embedding Model: {EMBEDDING_MODEL_ALIAS}")
    print(f"  Top K: {TOP_K}")
    print(f"  Confidence Threshold: {CONFIDENCE_THRESHOLD}")
    print(f"  Config Hash: {config_hash}")
    print(f"  Test Queries: {evaluation_metrics['total_queries_evaluated']}\n")

    # Start experiment tracking
    experiment_id = start_experiment(
        db_connection,
        experiment_name=EXPERIMENT_NAME,
        notebook_path='advanced-techniques/09-citation-tracking.ipynb',
        embedding_model_alias=EMBEDDING_MODEL_ALIAS,
        config=config_dict,
        techniques=TECHNIQUES_APPLIED,
        notes=f'Citation tracking and source attribution on {evaluation_metrics["total_queries_evaluated"]} queries'
    )

    # Prepare metrics for storage
    metrics_to_store = {
        # Core citation metrics
        'avg_confidence': evaluation_metrics['avg_confidence'],
        'avg_citation_accuracy': evaluation_metrics['avg_citation_accuracy'],
        'avg_sources_used': evaluation_metrics['avg_sources_used'],
        'avg_invalid_citations': evaluation_metrics['avg_invalid_citations'],

        # Distribution metrics
        'min_confidence': evaluation_metrics['min_confidence'],
        'max_confidence': evaluation_metrics['max_confidence'],

        # Configuration and metadata
        'num_queries_evaluated': evaluation_metrics['total_queries_evaluated'],
        'top_k_used': TOP_K,

        # The config hash is an identifier, not a measurement. save_metrics()
        # coerces every plain metric value with float(), so a non-numeric
        # hash (presumably a hex digest -- confirm against
        # compute_config_hash) would raise there and roll back the whole
        # batch. Carry it in the details payload of a nested metric instead;
        # the 0.0 value is only a placeholder.
        'config_hash': {'value': 0.0, 'details': {'config_hash': config_hash}},
    }

    # Save metrics
    print("\nSaving metrics to database...\n")
    success, message = save_metrics(db_connection, experiment_id, metrics_to_store, export_to_file=True)

    # Complete experiment
    if success:
        notes = (
            f"Successfully evaluated citation tracking on {evaluation_metrics['total_queries_evaluated']} queries. "
            f"Citation accuracy: {evaluation_metrics['avg_citation_accuracy']:.2%}, "
            f"Avg confidence: {evaluation_metrics['avg_confidence']:.4f}"
        )

        complete_experiment(db_connection, experiment_id, status='completed', notes=notes)

        # Display results summary
        print("\n" + "=" * 70)
        print("EXPERIMENT RESULTS SUMMARY")
        print("=" * 70)

        print(f"\nExperiment ID: {experiment_id}")
        print(f"Experiment Name: {EXPERIMENT_NAME}")
        print("Status: Completed")
        print(f"Config Hash: {config_hash}")

        print("\nKey Metrics:")
        print(f"  Avg Confidence:        {evaluation_metrics['avg_confidence']:.4f}")
        print(f"  Citation Accuracy:     {evaluation_metrics['avg_citation_accuracy']:.2%}")
        print(f"  Avg Sources Used:      {evaluation_metrics['avg_sources_used']:.2f} / {TOP_K}")
        print(f"  Avg Invalid Citations: {evaluation_metrics['avg_invalid_citations']:.2f}")

        print("\nResults exported to:")
        print(f"  Database: evaluation_results table (experiment_id={experiment_id})")
        print(f"  JSON: data/experiment_results/experiment_{experiment_id}_*.json")

        print("\n" + "=" * 70)
        print("NEXT STEPS")
        print("=" * 70)
        print("\n1. Review citation accuracy metrics above")
        print("2. Compare results with other techniques:")
        print("   - evaluation-lab/03-compare-experiments.ipynb")
        print("   - evaluation-lab/04-plot-improvements.ipynb")
        print("3. Improve citation quality:")
        print("   - Combine with reranking (advanced-techniques/05)")
        print("   - Use query expansion (advanced-techniques/06)")
        print("   - Apply hybrid search (advanced-techniques/07)")
        print("4. Analyze per-query results for patterns:")
        print("   - Which questions have low citation accuracy?")
        print("   - Which retrieval confidence levels most reliable?")
        print("5. Implement citation filtering:")
        print("   - Only show answers with citation_accuracy > 0.8")
        print("   - Only show retrievals with confidence > CONFIDENCE_THRESHOLD")

    else:
        # Metrics could not be persisted; still close out the experiment row
        # so it does not stay in its initial state.
        print("\n✗ Failed to track experiment")
        complete_experiment(db_connection, experiment_id, status='failed', notes='Failed to save metrics')

else:
    print("⚠️  Cannot track experiment: evaluation results not available")

# Close database connection
print("\n\nClosing database connection...")
if embeddings_db:
    embeddings_db.close()
db_connection.close()
print("✓ All connections closed")