# FAISS Retrieval Evaluation (CPU)

This notebook evaluates FAISS retrieval using pre-built indices.

**Can run on CPU** - indices are loaded from disk.

Evaluates:
- 3 embedding models (Legal-BERT, GTE-Large, BGE-Large)
- 2 fields (content, metadata)
- All 10 legal queries

Total: 3 models × 2 fields = 6 retrieval configurations

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import json
from pathlib import Path

from src.data_loader import load_data, prepare_data, get_documents_by_field
from src.queries import get_all_queries
from src.faiss_retriever import FAISSRetriever, evaluate_faiss
from src.evaluation import (
    print_query_results, 
    create_comparison_table,
    calculate_overlap,
    analyze_retrieval_diversity
)
from src.config import EMBEDDING_MODELS, INDICES_DIR, RESULTS_DIR, TOP_K

# Use seaborn's "whitegrid" theme for the plots drawn later on.
sns.set_style("whitegrid")
print("✓ Imports successful")

## 1. Load and Prepare Data

In [None]:
# Load the documents and pass them straight through the preparation step.
df = prepare_data(load_data())

# Fetch the evaluation queries used by the rest of the notebook.
queries = get_all_queries()
print(f"Loaded {len(df)} documents and {len(queries)} queries")

## 2. Verify Indices Exist

In [None]:
# Check which indices are available.
# Each (model, field) pair is looked up at INDICES_DIR/<field>_<model>/index.faiss.
fields = ['content', 'metadata']
available_indices = []

print("Available FAISS Indices:")
print("="*80)

for model_key in EMBEDDING_MODELS:
    for field in fields:
        index_dir = INDICES_DIR / f"{field}_{model_key}"
        index_file = index_dir / "index.faiss"

        if index_file.exists():
            available_indices.append((model_key, field))
            print(f"✓ {field}_{model_key}")
        else:
            print(f"✗ {field}_{model_key} - NOT FOUND")

# Derive the expected count from the configuration rather than hard-coding 6,
# so the summary stays correct if EMBEDDING_MODELS gains or loses entries.
expected_total = len(EMBEDDING_MODELS) * len(fields)
print(f"\nTotal available: {len(available_indices)} / {expected_total}")

if not available_indices:
    print("\n⚠️ ERROR: No indices found!")
    print("Please run notebook 03 (FAISS Index Builder) first.")

## 3. Run FAISS Retrieval for All Models

In [None]:
# Store all results, keyed by "faiss_<model>_<field>"
all_results = {}

# Create the output directory up front in case it does not exist yet;
# open(..., 'w') below does not create missing parent directories.
Path(RESULTS_DIR).mkdir(parents=True, exist_ok=True)

for model_key in EMBEDDING_MODELS.keys():
    for field in ['content', 'metadata']:
        index_dir = INDICES_DIR / f"{field}_{model_key}"
        index_file = index_dir / "index.faiss"

        # Only evaluate configurations whose index file is present on disk.
        if not index_file.exists():
            print(f"Skipping {field}_{model_key} - index not found")
            continue

        print(f"\n{'='*80}")
        print(f"Evaluating: {model_key} on {field}")
        print(f"{'='*80}")

        # Evaluate
        results = evaluate_faiss(
            model_key=model_key,
            queries=queries,
            field_name=field,
            top_k=TOP_K
        )

        # Store results
        key = f"faiss_{model_key}_{field}"
        all_results[key] = results

        # Save individual results (explicit encoding keeps the file
        # independent of the platform's default locale).
        output_file = RESULTS_DIR / f"{key}_results.json"
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"Results saved to: {output_file}")

print(f"\n✓ Completed {len(all_results)} evaluations")

## 4. Compare Models on Content Field

In [None]:
# Compare retrieval times for content field.
# Result keys have the form "faiss_<model>_<field>", so match the field suffix
# exactly; a plain substring test would also pick up any model key that
# happens to contain "content".
content_results = {
    k: v for k, v in all_results.items() if k.endswith('_content')
}

if content_results:
    # One row per model evaluated on this field.
    comparison_data = [
        {
            'Model': results['model'],
            'Avg Time (s)': results['avg_retrieval_time'],
            'Model Name': results['model_name'],
        }
        for results in content_results.values()
    ]

    comparison_df = pd.DataFrame(comparison_data)

    print("Content Field - Model Comparison:")
    print(comparison_df.to_string(index=False))

    # Visualize
    plt.figure(figsize=(10, 6))
    plt.bar(comparison_df['Model'], comparison_df['Avg Time (s)'])
    plt.xlabel('Model')
    plt.ylabel('Average Retrieval Time (seconds)')
    plt.title('FAISS Retrieval Time - Content Field')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

## 5. Compare Models on Metadata Field

In [None]:
# Compare retrieval times for metadata field.
# Result keys have the form "faiss_<model>_<field>", so match the field suffix
# exactly; a plain substring test would also pick up any model key that
# happens to contain "metadata".
metadata_results = {
    k: v for k, v in all_results.items() if k.endswith('_metadata')
}

if metadata_results:
    # One row per model evaluated on this field.
    comparison_data = [
        {
            'Model': results['model'],
            'Avg Time (s)': results['avg_retrieval_time'],
            'Model Name': results['model_name'],
        }
        for results in metadata_results.values()
    ]

    comparison_df = pd.DataFrame(comparison_data)

    print("Metadata Field - Model Comparison:")
    print(comparison_df.to_string(index=False))

    # Visualize
    plt.figure(figsize=(10, 6))
    plt.bar(comparison_df['Model'], comparison_df['Avg Time (s)'])
    plt.xlabel('Model')
    plt.ylabel('Average Retrieval Time (seconds)')
    plt.title('FAISS Retrieval Time - Metadata Field')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

## 6. Analyze Query-Specific Results

In [None]:
# Compare all models for a specific query
def compare_models_for_query(query_id, field='content'):
    """Print the top results of every evaluated model for one query.

    Models without an entry in ``all_results`` for the requested field are
    skipped silently.

    Args:
        query_id: 1-based position of the query in ``queries``.
        field: Field whose results are shown; it is used to build the
            ``faiss_<model>_<field>`` key looked up in ``all_results``.

    Raises:
        IndexError: If ``query_id`` is outside ``1..len(queries)``.
    """
    # Reject 0 and negative ids explicitly: ``queries[query_id - 1]`` would
    # otherwise wrap around and silently display a different query.
    if not 1 <= query_id <= len(queries):
        raise IndexError(
            f"query_id must be between 1 and {len(queries)}, got {query_id}"
        )

    position = query_id - 1
    query = queries[position]

    print(f"\n{'='*80}")
    print(f"QUERY {query_id}: {query['query']}")
    print(f"Field: {field}")
    print(f"{'='*80}\n")

    for model_key in EMBEDDING_MODELS.keys():
        key = f"faiss_{model_key}_{field}"
        if key not in all_results:
            continue

        # Per-query entries are indexed by the same 0-based position
        # as ``queries``.
        query_result = all_results[key]['results'][position]

        print(f"\n--- {model_key.upper()} ---")
        print_query_results(
            df=df,
            query=query['query'],
            indices=query_result['retrieved_indices'],
            scores=query_result['scores'],
            top_n=3,
            field=field
        )

# Example: Compare models for query 1 on content
compare_models_for_query(1, 'content')

In [None]:
# Compare overlap between models for a query
def analyze_model_overlap(query_id, field='content'):
    """Print the pairwise overlap of retrieved documents between models.

    Only models with an entry in ``all_results`` for the requested field
    take part in the comparison.

    Args:
        query_id: 1-based position of the query in ``queries``.
        field: Field whose results are compared; it is used to build the
            ``faiss_<model>_<field>`` key looked up in ``all_results``.

    Raises:
        IndexError: If ``query_id`` is outside ``1..len(queries)``.
    """
    from itertools import combinations

    # Reject 0 and negative ids explicitly: ``queries[query_id - 1]`` would
    # otherwise wrap around and silently analyze a different query.
    if not 1 <= query_id <= len(queries):
        raise IndexError(
            f"query_id must be between 1 and {len(queries)}, got {query_id}"
        )

    position = query_id - 1
    query = queries[position]

    print(f"\nQuery {query_id}: {query['query']}")
    print(f"Field: {field}\n")

    # Get results from all models
    model_results = {}
    for model_key in EMBEDDING_MODELS.keys():
        key = f"faiss_{model_key}_{field}"
        if key in all_results:
            query_result = all_results[key]['results'][position]
            model_results[model_key] = query_result['retrieved_indices']

    # Calculate pairwise overlap; combinations() yields each unordered pair
    # once, in the same order as the nested index loop it replaces.
    # NOTE(review): the heading calls this Jaccard similarity - confirm that
    # calculate_overlap actually computes that.
    print("Pairwise Overlap (Jaccard Similarity):\n")

    for model1, model2 in combinations(model_results, 2):
        overlap = calculate_overlap(model_results[model1], model_results[model2])
        print(f"{model1} vs {model2}: {overlap:.2%}")

# Example analysis
analyze_model_overlap(1, 'content')

## 7. Overall Performance Comparison

In [None]:
# Build one summary table covering every evaluated configuration.
comparison_table = create_comparison_table(all_results)

print("\nFAISS Retrieval Performance Summary:")
print("=" * 80)
summary_text = comparison_table.to_string(index=False)
print(summary_text)

# Write the table to disk as CSV and report where it went.
summary_path = RESULTS_DIR / "faiss_comparison.csv"
comparison_table.to_csv(summary_path, index=False)
print(f"\nComparison table saved to: {RESULTS_DIR / 'faiss_comparison.csv'}")

In [None]:
# Horizontal bar chart of the average retrieval time of every configuration,
# colored by the field that was searched.
if all_results:
    from matplotlib.patches import Patch

    # One entry per configuration, in the insertion order of all_results.
    methods = [
        f"{results['model']}\n({results['field']})"
        for results in all_results.values()
    ]
    times = [results['avg_retrieval_time'] for results in all_results.values()]
    colors = [
        'skyblue' if results['field'] == 'content' else 'lightcoral'
        for results in all_results.values()
    ]

    fig, ax = plt.subplots(figsize=(14, 6))
    ax.barh(methods, times, color=colors)
    ax.set_xlabel('Average Retrieval Time (seconds)')
    ax.set_title('FAISS Retrieval Performance - All Configurations')
    ax.grid(axis='x', alpha=0.3)

    # Explain the two bar colors with a manual legend.
    ax.legend(
        handles=[
            Patch(facecolor='skyblue', label='Content'),
            Patch(facecolor='lightcoral', label='Metadata'),
        ],
        loc='lower right',
    )

    plt.tight_layout()
    plt.show()

## 8. Retrieval Diversity Analysis

In [None]:
# Run the diversity analysis for each configuration and print its report.
from src.evaluation import analyze_retrieval_diversity, print_diversity_analysis

print("Retrieval Diversity Analysis:")
print("=" * 80)

for key, results in all_results.items():
    print(f"\n{key}:")
    print_diversity_analysis(analyze_retrieval_diversity(results, df))

## Summary

FAISS retrieval evaluation complete!

What this notebook covered:
- Tested 3 embedding models (Legal-BERT, GTE-Large, BGE-Large)
- Evaluated on both content and metadata fields
- All results saved for further analysis and comparison
- Retrieval times measured on CPU

Model characteristics:
- **Legal-BERT**: Domain-specific, may excel on legal terminology
- **GTE-Large**: General-purpose SOTA, strong cross-domain performance  
- **BGE-Large**: Top retrieval model, excellent for ranking tasks

Next steps:
- Compare with BM25 results (notebook 02)
- Apply reranking to improve results (notebook 05)
- Analyze which model works best for different query types