## ⚠️ IMPORTANT: Runtime Restart Required

**After running Cell 1 (install dependencies), you MUST restart the runtime/kernel before running Cell 2.**

Why? PyJnius starts a JVM the first time pyserini is imported, and the Java version of that JVM cannot be changed once it is running. Restarting ensures Java 21 is used from the beginning.

**Steps:**
1. Run Cell 1 (Install dependencies) 
2. **Runtime → Restart runtime** (Colab) or **Kernel → Restart** (Jupyter)
3. Run Cell 2 and continue

In [None]:
# Install dependencies
# The lines starting with `!` are notebook shell magics, so this cell only
# runs inside Jupyter/Colab, not as a plain Python script.
!pip install -q sentence-transformers pyserini pandas matplotlib seaborn beir
# Install newer Java (class file version 65 requires Java 21 JDK)
# `|| true` lets the cell continue even when apt-get fails, so the message
# printed below does not by itself prove that the JDK was installed.
!apt-get -y install -qq openjdk-21-jdk-headless || true
print("‚úÖ Dependencies installed (Pyserini + Java 21 JDK for Lucene + BEIR)")

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import re
import subprocess
from tqdm.auto import tqdm

# Ensure Java 21 is used (needed for Pyserini/Lucene class version 65).
# The path is the Debian/Ubuntu amd64 install location used by the apt package.
java_home = "/usr/lib/jvm/java-21-openjdk-amd64"
_java_env_vars = ("JAVA_HOME", "JAVAHOME", "JDK_HOME")
if os.path.isdir(java_home):
    # Override unconditionally: a JAVA_HOME inherited from the environment may
    # point at an older JDK, and keeping it would defeat the purpose of this cell.
    for _var in _java_env_vars:
        os.environ[_var] = java_home
else:
    # JDK 21 is not where we expect it; keep whatever is configured and only
    # fill in variables that are missing.
    print(f"WARNING: {java_home} not found; keeping existing Java settings")
    for _var in _java_env_vars:
        os.environ.setdefault(_var, java_home)

# Put the selected JDK first on PATH, without stacking duplicates on re-runs.
_java_bin = f"{os.environ['JAVA_HOME']}/bin"
_current_path = os.environ.get("PATH", "")
if not _current_path.startswith(f"{_java_bin}:"):
    os.environ["PATH"] = f"{_java_bin}:" + _current_path

# Force alternatives to Java 21. Each tool is tried on its own so that a
# failure for `java` does not skip `javac`.
for _tool in ("java", "javac"):
    try:
        subprocess.run(
            ["update-alternatives", "--set", _tool, f"{java_home}/bin/{_tool}"],
            check=True,
        )
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"‚ö†Ô∏è update-alternatives failed: {e}")

# Verify Java version
try:
    subprocess.run(["java", "-version"], check=True)
except (subprocess.CalledProcessError, OSError) as e:
    print(f"‚ö†Ô∏è Java version check failed: {e}")

from sentence_transformers import SentenceTransformer

from beir import util
from beir.datasets.data_loader import GenericDataLoader

from pyserini.search.lucene import LuceneSearcher

# Plot defaults shared by every figure in this notebook.
sns.set_style('whitegrid')
plt.rcParams.update({'figure.figsize': (10, 6)})

print("‚úÖ Libraries imported (using Lucene/Pyserini; Java forced to 21)")

In [None]:
# Diagnostic: Check actual Java version being used
import subprocess
import sys

print("üîç Java Diagnostic:")
try:
    result = subprocess.run(["java", "-version"], capture_output=True, text=True)
    print("Java version:", result.stderr.split('\n')[0])
except Exception as e:
    print(f"‚ùå Java not found: {e}")

print("\nüîç Python JVM Status:")
try:
    import jnius_config
    if jnius_config.vm_running:
        print("‚ö†Ô∏è JVM already running - restart required to change Java version")
    else:
        print("‚úÖ JVM not started yet")
except:
    print("‚úÖ jnius not loaded yet")

## Dataset Selection
Choose a dataset. Defaults to FiQA for medium scale.

In [None]:
# Select dataset: 'scifact', 'fiqa', 'trec-covid', 'webis-touche2020', 'quora', 'robust04', 'trec-news', or 'nq'
dataset_name = 'fiqa'  # pick from the list above; 'nq' is very large

# NOTE(review): BEIR's README appears to list robust04 and trec-news as not
# publicly downloadable; confirm these two URLs resolve before relying on them.
dataset_urls = {
    'scifact': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip',          # ~5k docs
    'fiqa': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip',            # ~57k docs
    'trec-covid': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip',  # ~171k docs
    'webis-touche2020': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/webis-touche2020.zip', # ~382k docs
    'quora': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip',          # ~523k docs
    'robust04': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/robust04.zip',    # ~528k docs
    'trec-news': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-news.zip',  # ~595k docs
    # Note: NQ is very large; ensure sufficient resources
    'nq': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip',                # ~2.6M docs
}

# Fail with an actionable message instead of a bare KeyError on a typo.
if dataset_name not in dataset_urls:
    raise ValueError(
        f"Unknown dataset {dataset_name!r}; choose one of: {', '.join(dataset_urls)}"
    )

url = dataset_urls[dataset_name]
print(f"Downloading {dataset_name} dataset...")
data_path = util.download_and_unzip(url, "datasets")

print("Loading dataset...")
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

# qrels maps query id -> {doc id: relevance}; len(qrels) would only count the
# judged queries, so sum the per-query dictionaries to count the judgments.
num_judgments = sum(len(judged) for judged in qrels.values())

print(f"\n‚úÖ Dataset loaded!\n   Documents: {len(corpus):,}\n   Queries: {len(queries):,}\n   Relevance judgments: {num_judgments:,}")

## Prepare Model and Texts
Memory-safe batching for larger corpora.

In [None]:
# Load embedding model
model_name = 'BAAI/bge-base-en-v1.5'
print(f"Loading model: {model_name}")
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()

# Prepare texts. doc_ids / query_ids fix the row order used by every later
# array (embeddings, retrieved indices), so they are materialised once here.
doc_ids = list(corpus)
# A missing or None title is treated as empty, and strip() removes the leading
# space that "title + ' ' + text" would otherwise leave for untitled documents.
doc_texts = [
    f"{corpus[did].get('title') or ''} {corpus[did].get('text') or ''}".strip()
    for did in doc_ids
]
query_ids = list(queries)
query_texts = [queries[qid] for qid in query_ids]

print(f"‚úÖ Model loaded (dim={dimension})")

In [None]:
# Options shared by the document and query encode() calls.
encode_options = dict(
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True,
)

# Encode documents; corpora above 100k documents use the smaller batch size.
batch_size_docs = 16 if len(doc_texts) > 100_000 else 32
print(f"Encoding {len(doc_texts):,} documents (batch_size={batch_size_docs})...")

doc_embeddings = model.encode(doc_texts, batch_size=batch_size_docs, **encode_options)

print(f"‚úÖ Documents encoded! Shape: {doc_embeddings.shape}, Memory: {doc_embeddings.nbytes / (1024**2):.2f} MB")

# Encode queries with the same options.
batch_size_queries = 32
print(f"Encoding {len(query_texts):,} queries (batch_size={batch_size_queries})...")
query_embeddings = model.encode(query_texts, batch_size=batch_size_queries, **encode_options)
print(f"‚úÖ Queries encoded! Shape: {query_embeddings.shape}")

## Build Indexes
Builds Lucene indexes for BM25, SPLADE (impact), and dense vectors (HNSW plus a flat index for exact search). Note: the cell below does not configure INT8 quantization.

In [None]:
# Build Lucene indexes (BM25 + SPLADE ED + HNSW vectors + flat vectors)
import sys  # local import: sys.executable is needed for the CLI calls below

print("Preparing Lucene inputs...")


def _write_jsonl(path, records):
    """Write each dict in `records` to `path` as one JSON object per line."""
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(record) + "\n" for record in records)


def _run_module(module, *cli_args):
    """Run `<python> -m module cli_args...`; raises CalledProcessError on failure.

    sys.executable replaces a bare 'python' so the child process uses the same
    interpreter (and therefore the same installed packages) as this notebook.
    """
    subprocess.run([sys.executable, '-m', module, *cli_args], check=True)


def _text_records():
    """Yield one {'id', 'contents'} record per document, in doc_ids order."""
    for did, text in zip(doc_ids, doc_texts):
        yield {'id': did, 'contents': text}


bm25_root = f'lucene_bm25_{dataset_name}'
bm25_docs_dir = os.path.join(bm25_root, 'docs')
bm25_index_dir = os.path.join(bm25_root, 'index')
os.makedirs(bm25_docs_dir, exist_ok=True)

bm25_jsonl = os.path.join(bm25_docs_dir, 'docs.jsonl')
_write_jsonl(bm25_jsonl, _text_records())

# SPLADE docs (same jsonl as BM25)
splade_root = f'lucene_splade_{dataset_name}'
splade_docs_dir = os.path.join(splade_root, 'docs')
splade_index_dir = os.path.join(splade_root, 'index')
splade_encoded_dir = os.path.join(splade_root, 'encoded')
os.makedirs(splade_docs_dir, exist_ok=True)
os.makedirs(splade_encoded_dir, exist_ok=True)

splade_jsonl = os.path.join(splade_docs_dir, 'docs.jsonl')
# Always rewritten, like the BM25 and dense inputs, so a file left over from
# an earlier run can never get out of sync with the current doc_texts.
_write_jsonl(splade_jsonl, _text_records())

dense_root = f'lucene_dense_{dataset_name}'
dense_vec_dir = os.path.join(dense_root, 'vectors')
dense_index_dir = os.path.join(dense_root, 'index_hnsw')
dense_flat_index_dir = os.path.join(dense_root, 'index_flat')
os.makedirs(dense_vec_dir, exist_ok=True)

dense_jsonl = os.path.join(dense_vec_dir, 'vectors.jsonl')
_write_jsonl(
    dense_jsonl,
    (
        {'id': did, 'contents': text, 'vector': vec.tolist()}
        for did, text, vec in zip(doc_ids, doc_texts, doc_embeddings)
    ),
)

# Match paper defaults
M = 16
ef_construction = 100
ef_search = 1000

threads = '16'

# Hugging Face checkpoint id; the published repository name uses hyphens.
splade_model = 'naver/splade-cocondenser-ensembledistil'

# Every --input below points at the folder that holds only the JSONL input.
# The *_root folders also contain the index directories being written.
print("Indexing BM25 (Lucene)...")
_run_module(
    'pyserini.index.lucene',
    '--collection', 'JsonCollection',
    '--input', bm25_docs_dir,
    '--index', bm25_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--threads', threads,
    '--storePositions',
    '--storeDocvectors',
    '--storeRaw',
)
print("‚úÖ BM25 index ready")

print(f"Encoding SPLADE ED ({splade_model})...")
# NOTE(review): Pyserini's documented `pyserini.encode` interface is organised
# as `input ... output ... encoder ...` sub-commands; confirm these flat flags
# are accepted by the installed version.
_run_module(
    'pyserini.encode',
    '--encoder', splade_model,
    '--fields', 'contents',
    '--input', splade_docs_dir,
    '--output', splade_encoded_dir,
    '--batch', '32',
    '--format', 'jsonl',
    '--device', 'cpu',
)
print("‚úÖ SPLADE encoding ready")

print("Indexing SPLADE impact (Lucene)...")
# NOTE(review): Pyserini's docs index encoder output with JsonVectorCollection,
# DefaultLuceneDocumentGenerator and `--impact --pretokenized`; confirm the
# collection/generator names used here exist.
_run_module(
    'pyserini.index.lucene',
    '--collection', 'JsonCollection',
    '--input', splade_encoded_dir,
    '--index', splade_index_dir,
    '--generator', 'ImpactLuceneDocumentGenerator',
    '--impact',
    '--threads', threads,
    '--storePositions',
    '--storeDocvectors',
    '--storeRaw',
)
print("‚úÖ SPLADE impact index ready")

print("Indexing Dense HNSW (Lucene vectors)...")
# NOTE(review): confirm that `pyserini.index.lucene` accepts dense float
# vectors through JsonVectorCollection and the --dim/--hnsw* options.
_run_module(
    'pyserini.index.lucene',
    '--collection', 'JsonVectorCollection',
    '--input', dense_vec_dir,
    '--index', dense_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--threads', threads,
    '--dim', str(dimension),
    '--hnswM', str(M),
    '--hnswefC', str(ef_construction),
    '--hnswefS', str(ef_search),
    '--storeRaw',
)
print("‚úÖ Dense HNSW index ready (Lucene)")

print("Indexing Dense FLAT (Lucene vectors)...")
_run_module(
    'pyserini.index.lucene',
    '--collection', 'JsonVectorCollection',
    '--input', dense_vec_dir,
    '--index', dense_flat_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--threads', threads,
    '--dim', str(dimension),
    '--vector-indexing-approach', 'flat',
    '--storeRaw',
)
print("‚úÖ Dense FLAT index ready (Lucene)")

## BM25 Baseline
Open searchers over the indexes built above: BM25 (the sparse baseline), SPLADE, and the dense HNSW and flat indexes.

In [None]:
# Instantiate Lucene searchers
# SpladeQueryEncoder is provided by pyserini.encode; LuceneImpactSearcher by
# pyserini.search.lucene.
from pyserini.encode import SpladeQueryEncoder
from pyserini.search.lucene import LuceneImpactSearcher

bm25_searcher = LuceneSearcher(bm25_index_dir)
bm25_searcher.set_bm25(k1=0.9, b=0.4)

# NOTE(review): LuceneSearcher is the text searcher; confirm it can open the
# vector indexes and accepts the `query_vector=` keyword used by
# search_dense_lucene. A dedicated dense-vector searcher may be required.
dense_hnsw_searcher = LuceneSearcher(dense_index_dir)
dense_flat_searcher = LuceneSearcher(dense_flat_index_dir)

# Hugging Face checkpoint id; the published repository name uses hyphens.
splade_query_encoder = SpladeQueryEncoder('naver/splade-cocondenser-ensembledistil')
splade_searcher = LuceneImpactSearcher(splade_index_dir, splade_query_encoder)

print("‚úÖ Lucene searchers ready (BM25 + SPLADE ED + HNSW + FLAT)")

## Search Functions
Shared utilities to run and measure searches.

In [None]:
# Reverse lookup from document id to its row in doc_ids / doc_embeddings.
doc_id_to_idx = {did: i for i, did in enumerate(doc_ids)}


def _stack_rows(rows):
    """Stack per-query lists into one array without failing on ragged input.

    When every row has the same length the result is the same 2-D array that
    np.array(rows) gives. When lengths differ (a query returned fewer than k
    hits) a 1-D object array holding one array per query is returned, which
    still supports the `result[i][:k]` access used by the evaluation code.
    """
    if len({len(row) for row in rows}) <= 1:
        return np.array(rows)
    ragged = np.empty(len(rows), dtype=object)
    for i, row in enumerate(rows):
        ragged[i] = np.array(row)
    return ragged


def _timed_search(run_query, items, name):
    """Run `run_query` once per item and collect hits plus latency statistics.

    Args:
        run_query: callable taking one item and returning hits that expose
            `.docid` and `.score`.
        items: iterable of queries (texts or embedding vectors).
        name: label stored in the result and shown on the progress bar.

    Returns:
        dict with 'name', 'indices' (rows of positions into doc_ids), 'scores',
        'latencies' (ms per query) and median / p95 / p99 latency. The three
        statistics are NaN when `items` is empty.
    """
    latencies = []
    all_indices = []
    all_scores = []
    for item in tqdm(items, desc=f"{name} search"):
        # perf_counter is monotonic and high resolution, unlike time.time().
        start = time.perf_counter()
        hits = run_query(item)
        latencies.append((time.perf_counter() - start) * 1000)
        all_indices.append([doc_id_to_idx[h.docid] for h in hits])
        all_scores.append([h.score for h in hits])
    latencies = np.array(latencies)
    nan = float('nan')
    has_samples = latencies.size > 0
    return {
        'name': name,
        'indices': _stack_rows(all_indices),
        'scores': _stack_rows(all_scores),
        'latencies': latencies,
        'median_latency': np.median(latencies) if has_samples else nan,
        'p95_latency': np.percentile(latencies, 95) if has_samples else nan,
        'p99_latency': np.percentile(latencies, 99) if has_samples else nan,
    }


def search_dense_lucene(searcher, query_embeddings, k=1000, name="Lucene-HNSW"):
    """Search a dense index with one embedding per query; see _timed_search."""
    # NOTE(review): confirm that the searcher's search() accepts the
    # `query_vector=` keyword; it is passed through unchanged here.
    return _timed_search(
        lambda emb: searcher.search(query_vector=emb.tolist(), k=k),
        query_embeddings,
        name,
    )


def search_sparse_impact(searcher, query_texts, k=1000, name="SPLADE-ED"):
    """Search an impact (learned sparse) index with raw query strings."""
    return _timed_search(lambda q: searcher.search(q, k), query_texts, name)


def search_bm25_lucene(searcher, query_texts, k=1000):
    """Search the BM25 index with raw query strings; the result is named 'BM25'."""
    return _timed_search(lambda q: searcher.search(q, k), query_texts, 'BM25')

def merge_rankings(dense_indices, dense_scores, sparse_indices, sparse_scores, k=10, alpha=0.5, rrf_k=60):
    """Fuse two ranked lists with weighted Reciprocal Rank Fusion.

    Each document receives `weight / (rrf_k + rank)` from every list it appears
    in (rank starts at 1); the dense list is weighted by `alpha` and the sparse
    list by `1 - alpha`.

    Args:
        dense_indices: document indices from the dense run, best first.
        dense_scores: unused; fusion is rank-based. Kept for the call signature.
        sparse_indices: document indices from the sparse run, best first.
        sparse_scores: unused; kept for the call signature.
        k: number of fused results to return.
        alpha: weight of the dense list.
        rrf_k: rank-smoothing constant (previously hard-coded to 60).

    Returns:
        np.ndarray with at most `k` document indices, best first. Ties keep
        first-seen order (dense list before sparse list).
    """
    fused = {}
    for weight, ranking in ((alpha, dense_indices), (1 - alpha, sparse_indices)):
        for rank, idx in enumerate(ranking, 1):
            idx = int(idx)
            if idx < 0:
                # Negative values are padding for "no hit"; without this guard
                # they would silently address documents from the end of doc_ids.
                continue
            fused[idx] = fused.get(idx, 0) + weight / (rrf_k + rank)
    ranked = sorted(fused.items(), key=lambda item: item[1], reverse=True)[:k]
    return np.array([idx for idx, _ in ranked])

def hybrid_search(dense_res, sparse_res, alpha=0.5, k=10):
    """Fuse a dense and a sparse run query by query and report hybrid latency.

    Args:
        dense_res: result dict of the dense run ('indices', 'scores' and,
            optionally, per-query 'latencies' in ms).
        sparse_res: result dict of the sparse run, same layout.
        alpha: weight of the dense ranking passed to merge_rankings.
        k: number of fused results kept per query.

    Returns:
        dict with 'name', 'indices' (shape: queries x k, padded with -1 when a
        query has fewer than k fused results), 'latencies' and the median /
        p95 / p99 latency.

    The reported latency is the fusion time plus the per-query latency of both
    input runs, when they provide one of matching length. Timing only the
    fusion step would ignore the two retrievals a hybrid query needs.
    """
    fusion_ms = []
    fused_rows = []
    per_query = zip(dense_res['indices'], dense_res['scores'], sparse_res['indices'], sparse_res['scores'])
    for d_indices, d_scores, s_indices, s_scores in tqdm(per_query, total=len(dense_res['indices']), desc="Hybrid search"):
        start = time.perf_counter()
        merged = merge_rankings(d_indices, d_scores, s_indices, s_scores, k=k, alpha=alpha)
        fusion_ms.append((time.perf_counter() - start) * 1000)
        fused_rows.append(merged)

    latencies = np.array(fusion_ms, dtype=float)
    for upstream_res in (dense_res, sparse_res):
        upstream = upstream_res.get('latencies')
        if upstream is not None and len(upstream) == len(latencies):
            latencies = latencies + np.asarray(upstream, dtype=float)

    # Fixed-width output: -1 marks unused slots so the array is never ragged.
    indices = np.full((len(fused_rows), k), -1, dtype=int)
    for i, row in enumerate(fused_rows):
        indices[i, :len(row)] = row

    nan = float('nan')
    has_samples = latencies.size > 0
    return {
        'name': f"Hybrid (Œ±={alpha})",
        'indices': indices,
        'latencies': latencies,
        'median_latency': np.median(latencies) if has_samples else nan,
        'p95_latency': np.percentile(latencies, 95) if has_samples else nan,
        'p99_latency': np.percentile(latencies, 99) if has_samples else nan,
    }

## Run Searches
Retrieve the top 1000 hits per query (evaluation uses the top 10) and collect latency stats.

In [None]:
# Evaluation cut-off and retrieval depth.
k_eval = 10
k_retrieve = 1000

# One run per index, executed in this fixed order.
results_dense_flat = search_dense_lucene(
    dense_flat_searcher, query_embeddings, k=k_retrieve, name="Lucene-FLAT"
)
results_dense_hnsw = search_dense_lucene(
    dense_hnsw_searcher, query_embeddings, k=k_retrieve, name="Lucene-HNSW"
)
results_bm25 = search_bm25_lucene(bm25_searcher, query_texts, k=k_retrieve)
results_splade = search_sparse_impact(
    splade_searcher, query_texts, k=k_retrieve, name="SPLADE-ED"
)

# Hybrid runs with different α using HNSW dense + SPLADE
alpha_values = [0.3, 0.5, 0.7]
hybrid_results = [
    hybrid_search(results_dense_hnsw, results_splade, alpha=alpha, k=k_eval)
    for alpha in alpha_values
]
print("‚úÖ Searches complete (Flat + HNSW + BM25 + SPLADE)")

## Evaluation
Compute Recall@10 and nDCG@10.

In [None]:
def calculate_recall(retrieved_indices, qrels, query_ids, doc_ids, k=10):
    """Mean Recall@k over the queries that have at least one relevant document.

    Args:
        retrieved_indices: per-query sequences of positions into `doc_ids`,
            best first; negative values are treated as padding and ignored.
        qrels: mapping query id -> {doc id: relevance grade}.
        query_ids: query ids aligned with the rows of `retrieved_indices`.
        doc_ids: list translating a position into a document id.
        k: cut-off rank.

    Returns:
        Mean recall as a float; 0.0 when no query could be scored.

    Only judgments with a grade above 0 count as relevant. Documents judged
    with grade 0 are explicit non-relevant judgments and must not enlarge the
    relevant set. Queries without any positive judgment are skipped.
    """
    recalls = []
    for i, qid in enumerate(query_ids):
        judgments = qrels.get(qid)
        if not judgments:
            continue
        relevant_docs = {doc_id for doc_id, grade in judgments.items() if grade > 0}
        if not relevant_docs:
            continue
        retrieved_docs = {doc_ids[idx] for idx in retrieved_indices[i][:k] if idx >= 0}
        recalls.append(len(relevant_docs & retrieved_docs) / len(relevant_docs))
    return np.mean(recalls) if recalls else 0.0

def calculate_ndcg(retrieved_indices, qrels, query_ids, doc_ids, k=10):
    """Mean nDCG@k with exponential gain (2**grade - 1) and log2 discount.

    Rows of `retrieved_indices` are aligned with `query_ids` and hold positions
    into `doc_ids`; negative positions are ignored. Queries absent from `qrels`
    are skipped, and a query whose ideal DCG is 0 contributes 0. Returns 0.0
    when no query was scored.
    """

    def _discounted_gain(grades):
        # The item at 0-based position p is discounted by log2(p + 2).
        return sum((2 ** grade - 1) / np.log2(pos + 2) for pos, grade in enumerate(grades))

    per_query = []
    for row, qid in enumerate(query_ids):
        if qid not in qrels:
            continue
        grades = qrels[qid]
        ranked_docs = [doc_ids[idx] for idx in retrieved_indices[row][:k] if idx >= 0]
        dcg = _discounted_gain(grades.get(doc_id, 0) for doc_id in ranked_docs)
        # Ideal ranking: the k highest judged grades in descending order.
        idcg = _discounted_gain(sorted(grades.values(), reverse=True)[:k])
        per_query.append(dcg / idcg if idcg > 0 else 0)
    if not per_query:
        return 0.0
    return np.mean(per_query)

# Evaluate all
# Score every run in place: the metric values are stored on the result dicts.
evaluated_runs = [results_dense_flat, results_dense_hnsw, results_bm25, results_splade, *hybrid_results]
for results in evaluated_runs:
    ranked = results['indices']
    results.update({
        'recall@10': calculate_recall(ranked, qrels, query_ids, doc_ids, k=k_eval),
        'ndcg@10': calculate_ndcg(ranked, qrels, query_ids, doc_ids, k=k_eval),
    })

print("‚úÖ Evaluation complete")

## Comparison Tables

In [None]:
def _summary_row(method, family, res):
    """Build one comparison-table row from an evaluated result dict."""
    return {
        'Method': method,
        'Type': family,
        'Recall@10': res['recall@10'],
        'nDCG@10': res['ndcg@10'],
        'Median Latency (ms)': res['median_latency'],
        'P95 Latency (ms)': res['p95_latency'],
    }


# Single-retriever rows, in display order.
comparison_df = pd.DataFrame([
    _summary_row(method, family, res)
    for method, family, res in (
        ('Lucene-FLAT', 'Dense', results_dense_flat),
        ('Lucene-HNSW', 'Dense', results_dense_hnsw),
        ('BM25', 'Sparse', results_bm25),
        ('SPLADE-ED', 'Sparse', results_splade),
    )
])

# Hybrid rows
hybrid_rows = [_summary_row(res['name'], 'Hybrid', res) for res in hybrid_results]
comparison_df = pd.concat([comparison_df, pd.DataFrame(hybrid_rows)], ignore_index=True)

rule = "=" * 100
print("\nüìä COMPARISON TABLE")
print(rule)
print(comparison_df.to_string(index=False))
print(rule)

## Visualizations

In [None]:
# Create temp directory for plots
import os
temp_plots_dir = 'temp_plots'
os.makedirs(temp_plots_dir, exist_ok=True)

# Marker colour per method type; types not listed here are drawn in steelblue.
type_colors = {'Sparse': 'orange', 'Hybrid': 'green'}

# Speed vs Quality: one labelled point per method.
fig, ax = plt.subplots(figsize=(10, 6))
for _, row in comparison_df.iterrows():
    point = (row['Median Latency (ms)'], row['nDCG@10'])
    ax.scatter(
        *point,
        s=200,
        alpha=0.75,
        color=type_colors.get(row['Type'], 'steelblue'),
        edgecolors='black',
    )
    ax.annotate(
        row['Method'],
        point,
        xytext=(8, 8),
        textcoords='offset points',
        fontsize=9,
        fontweight='bold',
    )
ax.set_xlabel('Median Latency (ms)')
ax.set_ylabel('nDCG@10')
ax.set_title(f'Speed vs Quality ‚Äî {dataset_name}')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plot1_path = os.path.join(temp_plots_dir, f'speed_vs_quality_{dataset_name}.pdf')
plt.savefig(plot1_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"‚úÖ Plot saved as {plot1_path}")

# Bar chart quality: nDCG@10 per method.
fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(
    comparison_df['Method'],
    comparison_df['nDCG@10'],
    color='skyblue',
    edgecolor='black',
)
ax.set_ylabel('nDCG@10')
ax.set_title(f'Quality Comparison ‚Äî {dataset_name}')
ax.grid(True, alpha=0.3, axis='y')
plt.xticks(rotation=30, ha='right')
plt.tight_layout()
plot2_path = os.path.join(temp_plots_dir, f'quality_comparison_{dataset_name}.pdf')
plt.savefig(plot2_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"‚úÖ Plot saved as {plot2_path}")

## Save Results

In [None]:
# Detect environment and set output directory
import os
import shutil

def detect_environment():
    """Return 'colab', 'kaggle', 'modal' or 'local' based on environment variables.

    The platforms are checked in that order and the first one with any of its
    marker variables present wins; 'local' is the fallback.
    """
    markers = (
        ('colab', ('COLAB_GPU', 'COLAB_TPU_ADDR')),
        ('kaggle', ('KAGGLE_KERNEL_RUN_TYPE',)),
        ('modal', ('MODAL_PROJECT_NAME',)),
    )
    for label, env_vars in markers:
        if any(var in os.environ for var in env_vars):
            return label
    return 'local'

environment = detect_environment()
print(f"üîç Detected environment: {environment.upper()}")

# Pick the output directory and its announcement per environment. Every
# directory except Kaggle's working directory is created when missing.
create_output_dir = True
if environment == 'colab':
    # Colab: results live under /content/results/
    output_dir = '/content/results'
    destination_msg = f"üíæ Saving to: {output_dir}"
elif environment == 'kaggle':
    # Kaggle: the working directory is used as-is
    output_dir = '/kaggle/working'
    destination_msg = f"üíæ Saving to: {output_dir}"
    create_output_dir = False
elif environment == 'modal':
    # Modal: a results directory (often mapped to a Volume)
    output_dir = f'/root/results/{dataset_name}'
    destination_msg = f"üíæ Saving to Modal Volume path: {output_dir}"
else:
    # Local: a results folder relative to the current directory
    output_dir = f'{dataset_name}_results'
    destination_msg = f"üíæ Saving to: {output_dir}/"

if create_output_dir:
    os.makedirs(output_dir, exist_ok=True)
print(destination_msg)

def _save_csv(frame, stem):
    """Write `frame` to '<output_dir>/<stem>_<dataset_name>.csv' and return the path."""
    path = os.path.join(output_dir, f'{stem}_{dataset_name}.csv')
    frame.to_csv(path, index=False)
    print(f"‚úÖ Saved: {path}")
    return path


# Comparison table
comparison_path = _save_csv(comparison_df, 'experiment_results')

# Per-query latencies of the four single-retriever runs, one column per method
latency_df = pd.DataFrame({
    label: run['latencies']
    for label, run in (
        ('Lucene-FLAT', results_dense_flat),
        ('Lucene-HNSW', results_dense_hnsw),
        ('BM25', results_bm25),
        ('SPLADE-ED', results_splade),
    )
})
latency_path = _save_csv(latency_df, 'latency_data')

# Hybrid results table: column name -> key in each hybrid result dict
hybrid_columns = {
    'Method': 'name',
    'Recall@10': 'recall@10',
    'nDCG@10': 'ndcg@10',
    'Median Latency (ms)': 'median_latency',
    'P95 Latency (ms)': 'p95_latency',
}
hybrid_comparison_df = pd.DataFrame([
    {column: res[key] for column, key in hybrid_columns.items()}
    for res in hybrid_results
])
hybrid_path = _save_csv(hybrid_comparison_df, 'hybrid_results')

# Copy whichever of the two plot PDFs exist into the output directory.
temp_plots_dir = 'temp_plots'
if os.path.exists(temp_plots_dir):
    plot_files = [
        f'{stem}_{dataset_name}.pdf'
        for stem in ('speed_vs_quality', 'quality_comparison')
    ]
    for plot_file in plot_files:
        src = os.path.join(temp_plots_dir, plot_file)
        if not os.path.exists(src):
            continue
        dst = os.path.join(output_dir, plot_file)
        shutil.copy2(src, dst)
        print(f"‚úÖ Copied plot: {dst}")

print("\n" + "="*80, f"üìÅ All results saved to: {output_dir}/", "="*80, sep="\n")

# Environment-specific download instructions and automatic downloads
if environment == 'colab':
    print("\nüì• AUTO-DOWNLOADING FILES TO YOUR PC...")
    try:
        from google.colab import files
        # Trigger a browser download for every regular file in output_dir.
        for filename in os.listdir(output_dir):
            filepath = os.path.join(output_dir, filename)
            if not os.path.isfile(filepath):
                continue
            print(f"   üì¶ Downloading: {filename}")
            files.download(filepath)
        print("‚úÖ All files downloaded to your PC!")
    except Exception as e:
        # Best effort: fall back to manual instructions on any failure.
        print(
            f"‚ö†Ô∏è Auto-download failed: {e}",
            "\nüì• MANUAL DOWNLOAD INSTRUCTIONS:",
            "   1. Click the folder icon on the left sidebar",
            f"   2. Navigate to {output_dir}/",
            "   3. Right-click each file ‚Üí Download",
            sep="\n",
        )

elif environment == 'kaggle':
    print(
        "\nüì• FILES READY FOR DOWNLOAD:",
        "   1. Click 'Save Version' ‚Üí 'Save & Run All'",
        "   2. Once complete, go to the 'Output' tab",
        "   3. Download the CSV and PDF files directly",
        sep="\n",
    )

elif environment == 'modal':
    print(
        "\nüì• TO ACCESS FILES IN MODAL:",
        f"   1. Files are stored in the volume at: {output_dir}",
        "   2. Use 'modal volume get <volume_name> <remote_path> <local_path>' to download",
        sep="\n",
    )

else:
    print(
        f"\nüìÇ Files saved locally in: {os.path.abspath(output_dir)}/",
        "‚úÖ All files (CSVs and plots) are already on your PC!",
        sep="\n",
    )

# Summary report: dataset size plus the best-recall and lowest-latency methods.
summary_rule = "=" * 80
print("\n" + summary_rule)
print("üìä EXPERIMENT SUMMARY")
print(summary_rule)
print(f"Dataset: {dataset_name}", f"Documents: {len(corpus):,}", f"Queries: {len(queries):,}", sep="\n")
best_quality_row = comparison_df.loc[comparison_df['Recall@10'].idxmax()]
print(f"\nBest Quality Method: {best_quality_row['Method']}")
fastest_row = comparison_df.loc[comparison_df['Median Latency (ms)'].idxmin()]
print(f"Fastest Method: {fastest_row['Method']}")
print(summary_rule)

In [None]:
# Add QPS metrics derived from median/p95 latencies (single-thread approximation)
import numpy as np

def approx_qps_from_ms(ms):
    """Convert a per-query latency in milliseconds to queries per second.

    Returns 0.0 for falsy input (None, 0) and for any value that is not
    greater than zero, which includes negatives and NaN.
    """
    if not ms:
        return 0.0
    if ms > 0:
        return 1000.0 / ms
    return 0.0

# Add one QPS column per latency column of comparison_df.
for qps_column, latency_column in (
    ('QPS (approx, median)', 'Median Latency (ms)'),
    ('QPS (approx, p95)', 'P95 Latency (ms)'),
):
    comparison_df[qps_column] = comparison_df[latency_column].apply(approx_qps_from_ms)

qps_rule = "=" * 100
print("\nüìä COMPARISON TABLE WITH QPS (single-thread approx)")
print(qps_rule)
print(comparison_df.to_string(index=False))
print(qps_rule)

# Note: Paper's QPS (cached/ONNX) uses specific hardware & inference stacks.
# These approximations let you compare trends, not exact values.