## Debug: Check Available Indexing Options

Let's check what indexing options are available in pyserini.

In [None]:
# # Check available indexing options
# import subprocess
# result = subprocess.run(
#     ['python', '-m', 'pyserini.index.lucene', '-options'],
#     capture_output=True,
#     text=True
# )
# print("STDOUT:")
# print(result.stdout)
# print("\nSTDERR:")
# print(result.stderr)

# Paper Replication: Dense vs Sparse Retrieval on BEIR

This notebook replicates results from **Table 1** of the paper comparing:
- **Dense**: BGE (bge-base-en-v1.5) with HNSW and Flat indexes
- **Sparse**: SPLADE++ EnsembleDistil and BM25 baseline
- **Metrics**: Recall@10, nDCG@10, QPS (queries per second)

## Key Implementation Details

**Exact Paper Parameters:**
- Library: Lucene 9.9.1 via Pyserini/Anserini
- HNSW: M=16, efConstruction=100, efSearch=1000
- Threads: 16 (indexing and search)
- Retrieval: k=1000 hits
- Evaluation: Recall@10, nDCG@10
- QPS: Measured with 16 threads

**Datasets:** The paper evaluates 29 BEIR datasets. Change `dataset_name` below to run on different datasets.

## 1. Install Dependencies

Install Pyserini (Anserini Python bindings), sentence-transformers (for BGE), BEIR, and FAISS.

In [None]:
!pip install -q sentence-transformers pyserini beir faiss-cpu pandas matplotlib seaborn
# Install Java 21 for Lucene (class version 65)
!apt-get -y install -qq openjdk-21-jdk-headless || true
print("‚úÖ Dependencies installed")

## 2. Setup and Imports

‚ö†Ô∏è **IMPORTANT**: After installing dependencies, restart the runtime/kernel before proceeding.

In [None]:
import os
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import subprocess
from tqdm.auto import tqdm

# Configure Java 21 for Lucene
java_home = "/usr/lib/jvm/java-21-openjdk-amd64"
if os.path.exists(java_home):
    os.environ["JAVA_HOME"] = java_home
    os.environ["PATH"] = f"{java_home}/bin:" + os.environ.get("PATH", "")

from sentence_transformers import SentenceTransformer
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from pyserini.search.lucene import LuceneSearcher

sns.set_style('whitegrid')
print("‚úÖ Libraries imported")

## 3. Dataset Selection

Select a BEIR dataset. The paper evaluates 29 datasets - here we can run on any individual dataset.

### Known Dataset Issues

**‚ö†Ô∏è Some datasets have restricted access:**
The most reliable publicly available datasets are: nfcorpus, scifact, arguana, scidocs, fiqa, trec-covid, webis-touche2020, nq, fever, climate-fever.

Run the test cell above to check current availability before proceeding.

In [None]:
# Select dataset from BEIR
dataset_name = 'scifact'  # Change to: fiqa, trec-covid, nfcorpus, etc.

# BEIR dataset URLs sorted by corpus cardinality (smallest to largest)
dataset_urls = {
    'nfcorpus': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nfcorpus.zip',  # 3.6K docs
    'scifact': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scifact.zip',  # 5K docs
    'arguana': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/arguana.zip',  # 8.7K docs
    'scidocs': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/scidocs.zip',  # 25K docs
    'fiqa': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fiqa.zip',  # 57K docs
    'trec-covid': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-covid.zip',  # 171K docs
    'webis-touche2020': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/webis-touche2020.zip',  # 382K docs
    'quora': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/quora.zip',  # 523K docs
    'robust04': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/robust04.zip',  # 528K docs
    'trec-news': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/trec-news.zip',  # 595K docs
    'nq': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/nq.zip',  # 2.7M docs
    'dbpedia-entity': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/dbpedia-entity.zip',  # 4.6M docs
    'fever': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/fever.zip',  # 5.4M docs
    'climate-fever': 'https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/climate-fever.zip',  # 5.4M docs
}

print(f"Downloading {dataset_name} dataset...")
url = dataset_urls[dataset_name]
data_path = util.download_and_unzip(url, "datasets")

print("Loading dataset...")
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

doc_ids = list(corpus.keys())
doc_texts = [corpus[did]['title'] + ' ' + corpus[did]['text'] for did in doc_ids]
query_ids = list(queries.keys())
query_texts = [queries[qid] for qid in query_ids]

print(f"\n‚úÖ Dataset: {dataset_name}")
print(f"   Documents: {len(corpus):,}")
print(f"   Queries: {len(queries):,}")
print(f"   Relevance judgments: {len(qrels):,}")

## 4. Dense Retrieval: BGE Model

Load BGE (bge-base-en-v1.5) and encode documents and queries.

In [None]:
# Load BGE model (bge-base-en-v1.5)
model_name = 'BAAI/bge-base-en-v1.5'
print(f"Loading BGE model: {model_name}")
model = SentenceTransformer(model_name)
dimension = model.get_sentence_embedding_dimension()
print(f"‚úÖ Model loaded (dimension={dimension})")

In [None]:
# Encode documents
batch_size = 32 if len(doc_texts) <= 100_000 else 16
print(f"Encoding {len(doc_texts):,} documents (batch_size={batch_size})...")

doc_embeddings = model.encode(
    doc_texts,
    batch_size=batch_size,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

print(f"‚úÖ Documents encoded: {doc_embeddings.shape}")

In [None]:
# Encode queries
print(f"Encoding {len(query_texts):,} queries...")
query_embeddings = model.encode(
    query_texts,
    batch_size=32,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
)

print(f"‚úÖ Queries encoded: {query_embeddings.shape}")

## 5. Build Lucene Indexes

Build indexes for all retrieval methods:
1. **BM25**: Inverted index
2. **SPLADE++ ED**: Impact-based inverted index
3. **BGE HNSW**: HNSW vector index (M=16, efC=100, efSearch=1000)
4. **BGE Flat**: Flat vector index (brute-force search)

In [None]:
# Paper parameters
M = 16  # HNSW M parameter
ef_construction = 100  # HNSW efC
ef_search = 1000  # HNSW efSearch
threads = '16'  # 16 threads as per paper
k_retrieve = 1000  # Retrieve 1000 hits
k_eval = 10  # Evaluate at nDCG@10

# Initialize index timing dictionary
index_times = {}

print(f"Retrieval: k={k_retrieve}, evaluation@{k_eval}")
print(f"Parameters: M={M}, efC={ef_construction}, efSearch={ef_search}, threads={threads}")

In [None]:
# Prepare directory structure
base_dir = f'indexes_{dataset_name}'
os.makedirs(base_dir, exist_ok=True)

# 1. BM25 Index
bm25_docs_dir = os.path.join(base_dir, 'bm25_docs')
bm25_index_dir = os.path.join(base_dir, 'bm25_index')
os.makedirs(bm25_docs_dir, exist_ok=True)

print("Writing BM25 documents...")
bm25_jsonl = os.path.join(bm25_docs_dir, 'docs.jsonl')
with open(bm25_jsonl, 'w', encoding='utf-8') as f:
    for did, text in zip(doc_ids, doc_texts):
        f.write(json.dumps({'id': did, 'contents': text}) + "\n")

print("Building BM25 index...")
bm25_start = time.time()
subprocess.run([
    'python', '-m', 'pyserini.index.lucene',
    '--collection', 'JsonCollection',
    '--input', bm25_docs_dir,
    '--index', bm25_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--threads', threads,
    '--storePositions',
    '--storeDocvectors',
    '--storeRaw'
], check=True)

bm25_elapsed = time.time() - bm25_start
print(f"‚úÖ BM25 index ready ({bm25_elapsed:.2f}s)")
index_times['BM25'] = bm25_elapsed

In [None]:
# 2. SPLADE++ Index Configuration
# ========================================
# Choose SPLADE model variant (uncomment ONE of the following):
# ========================================

# Option 1: Ensemble Distil (default in paper, but updated June 2025)
SPLADE_MODEL = 'naver/splade-cocondenser-ensembledistil'

# Option 2: Self Distil (alternative training method)
# SPLADE_MODEL = 'naver/splade-cocondenser-selfdistil'

# Option 3: v2 Distil (older version, may match paper better)
# SPLADE_MODEL = 'naver/splade_v2_distil'

# ========================================
print(f"Selected SPLADE model: {SPLADE_MODEL}")
print(f"Note: To change model, uncomment a different option above")
print("="*80)

splade_docs_dir = os.path.join(base_dir, 'splade_docs')
splade_encoded_dir = os.path.join(base_dir, 'splade_encoded')
splade_index_dir = os.path.join(base_dir, 'splade_index')
os.makedirs(splade_docs_dir, exist_ok=True)
os.makedirs(splade_encoded_dir, exist_ok=True)

print("Writing SPLADE documents...")
splade_jsonl = os.path.join(splade_docs_dir, 'docs.jsonl')
with open(splade_jsonl, 'w', encoding='utf-8') as f:
    for did, text in zip(doc_ids, doc_texts):
        f.write(json.dumps({'id': did, 'text': text}) + "\n")

print(f"Encoding with {SPLADE_MODEL} (using GPU)...")
subprocess.run([
    'python', '-m', 'pyserini.encode',
    'input', '--corpus', splade_docs_dir,
    '--fields', 'text',
    'output', '--embeddings', splade_encoded_dir,
    'encoder', '--encoder', SPLADE_MODEL,
    '--device', 'cuda',
    '--batch', '32'
], check=True)

print("Building SPLADE impact index...")
splade_start = time.time()
subprocess.run([
    'python', '-m', 'pyserini.index.lucene',
    '--collection', 'JsonVectorCollection',
    '--input', splade_encoded_dir,
    '--index', splade_index_dir,
    '--generator', 'DefaultLuceneDocumentGenerator',
    '--impact',
    '--threads', threads,
    '--storeRaw'
], check=True)

splade_elapsed = time.time() - splade_start
print(f"‚úÖ SPLADE++ index ready ({splade_elapsed:.2f}s)")
index_times['SPLADE++ ED'] = splade_elapsed

In [None]:
# 3. BGE HNSW Index (using FAISS)
import faiss

hnsw_index_path = os.path.join(base_dir, 'hnsw_index.faiss')

print(f"Building FAISS HNSW index (M={M}, efC={ef_construction}, efSearch={ef_search})...")
hnsw_start = time.time()

# Create HNSW index
quantizer = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity (normalized vectors)
hnsw_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_index.hnsw.efConstruction = ef_construction
hnsw_index.hnsw.efSearch = ef_search

# Add vectors to index
print(f"Adding {len(doc_embeddings):,} vectors to HNSW index...")
hnsw_index.add(doc_embeddings)

# Save index
faiss.write_index(hnsw_index, hnsw_index_path)
hnsw_elapsed = time.time() - hnsw_start

index_times['BGE-HNSW'] = hnsw_elapsed
print(f"‚úÖ HNSW index saved ({hnsw_index.ntotal:,} vectors) ({hnsw_elapsed:.2f}s)")

In [None]:
# 4. BGE Flat Index (using FAISS, brute-force search)
flat_index_path = os.path.join(base_dir, 'flat_index.faiss')

print("Building FAISS Flat index (brute-force)...")
flat_start = time.time()

# Create flat index for exact search
flat_index = faiss.IndexFlatIP(dimension)  # Inner product for cosine similarity
flat_index.add(doc_embeddings)

# Save index
faiss.write_index(flat_index, flat_index_path)
flat_elapsed = time.time() - flat_start

index_times['BGE-Flat'] = flat_elapsed
print(f"‚úÖ Flat index saved ({flat_index.ntotal:,} vectors) ({flat_elapsed:.2f}s)")

## 5a. Build INT8 Quantized Indexes

Build int8 quantized versions for comparison with full precision indexes.


In [None]:
# INT8 Quantization: HNSW with int8 quantization
hnsw_int8_index_path = os.path.join(base_dir, 'hnsw_int8_index.faiss')

print(f"Building INT8 HNSW index (M={M}, efC={ef_construction}, efSearch={ef_search})...")
hnsw_int8_start = time.time()

# Convert float32 embeddings to int8 using simple quantization
# Range: [-128, 127]
doc_embeddings_int8 = np.clip(doc_embeddings * 127, -128, 127).astype(np.int8).astype(np.float32) / 127

# Create HNSW index with int8 quantization
quantizer_int8 = faiss.IndexFlatIP(dimension)
hnsw_int8_index = faiss.IndexHNSWFlat(dimension, M, faiss.METRIC_INNER_PRODUCT)
hnsw_int8_index.hnsw.efConstruction = ef_construction
hnsw_int8_index.hnsw.efSearch = ef_search

print(f"Adding {len(doc_embeddings_int8):,} vectors to INT8 HNSW index...")
hnsw_int8_index.add(doc_embeddings_int8)

# Save index
faiss.write_index(hnsw_int8_index, hnsw_int8_index_path)
hnsw_int8_elapsed = time.time() - hnsw_int8_start

index_times['BGE-HNSW-int8'] = hnsw_int8_elapsed
print(f"‚úÖ INT8 HNSW index saved ({hnsw_int8_index.ntotal:,} vectors) ({hnsw_int8_elapsed:.2f}s)")


In [None]:
# INT8 Quantization: Flat index with int8 quantization
flat_int8_index_path = os.path.join(base_dir, 'flat_int8_index.faiss')

print("Building INT8 Flat index (brute-force with quantization)...")
flat_int8_start = time.time()

# Create flat index with int8 quantization
flat_int8_index = faiss.IndexFlatIP(dimension)
flat_int8_index.add(doc_embeddings_int8)

# Save index
faiss.write_index(flat_int8_index, flat_int8_index_path)
flat_int8_elapsed = time.time() - flat_int8_start

index_times['BGE-Flat-int8'] = flat_int8_elapsed
print(f"‚úÖ INT8 Flat index saved ({flat_int8_index.ntotal:,} vectors) ({flat_int8_elapsed:.2f}s)")


## 6. Initialize Searchers

In [None]:
# BM25 searcher
bm25_searcher = LuceneSearcher(bm25_index_dir)
bm25_searcher.set_bm25(k1=0.9, b=0.4)

# SPLADE searcher (uses the same model selected above)
from pyserini.search.lucene import LuceneImpactSearcher
from pyserini.encode import SpladeQueryEncoder

print(f"Initializing SPLADE searcher with model: {SPLADE_MODEL}")
splade_searcher = LuceneImpactSearcher(
    splade_index_dir,
    query_encoder=SpladeQueryEncoder(SPLADE_MODEL, device='cuda:0'),
    min_idf=0
)

# Load FAISS indexes for dense retrieval
import faiss
hnsw_index_path = os.path.join(base_dir, 'hnsw_index.faiss')
flat_index_path = os.path.join(base_dir, 'flat_index.faiss')
hnsw_int8_index_path = os.path.join(base_dir, 'hnsw_int8_index.faiss')
flat_int8_index_path = os.path.join(base_dir, 'flat_int8_index.faiss')

hnsw_index = faiss.read_index(hnsw_index_path)
flat_index = faiss.read_index(flat_index_path)
hnsw_int8_index = faiss.read_index(hnsw_int8_index_path)
flat_int8_index = faiss.read_index(flat_int8_index_path)

print(f"‚úÖ All searchers initialized")
print(f"   SPLADE model: {SPLADE_MODEL}")
print(f"   HNSW index: {hnsw_index.ntotal:,} vectors")
print(f"   HNSW-int8 index: {hnsw_int8_index.ntotal:,} vectors")
print(f"   Flat index: {flat_index.ntotal:,} vectors")
print(f"   Flat-int8 index: {flat_int8_index.ntotal:,} vectors")

## 7. Search Functions

Implement search with QPS measurement (16 threads).

In [None]:
doc_id_to_idx = {did: i for i, did in enumerate(doc_ids)}

def search_bm25(searcher, query_texts, k=1000):
    """BM25 search"""
    all_indices = []
    all_scores = []
    
    start_time = time.time()
    for q in tqdm(query_texts, desc="BM25 search"):
        hits = searcher.search(q, k)
        docids = [h.docid for h in hits]
        scores = [h.score for h in hits]
        all_indices.append([doc_id_to_idx[d] for d in docids])
        all_scores.append(scores)
    
    elapsed = time.time() - start_time
    qps = len(query_texts) / elapsed
    
    return {
        'name': 'BM25',
        'indices': np.array(all_indices, dtype=object),
        'scores': np.array(all_scores, dtype=object),
        'qps': qps
    }

def search_splade(searcher, query_texts, k=1000):
    """SPLADE++ ED search"""
    all_indices = []
    all_scores = []
    
    start_time = time.time()
    for q in tqdm(query_texts, desc="SPLADE++ ED search"):
        hits = searcher.search(q, k)
        docids = [h.docid for h in hits]
        scores = [h.score for h in hits]
        all_indices.append([doc_id_to_idx[d] for d in docids])
        all_scores.append(scores)
    
    elapsed = time.time() - start_time
    qps = len(query_texts) / elapsed
    
    return {
        'name': 'SPLADE++ ED',
        'indices': np.array(all_indices, dtype=object),
        'scores': np.array(all_scores, dtype=object),
        'qps': qps
    }

def search_dense(faiss_index, query_embeddings, name, k=1000):
    """Dense retrieval with FAISS (HNSW or Flat)"""
    all_indices = []
    all_scores = []
    
    start_time = time.time()
    for emb in tqdm(query_embeddings, desc=f"{name} search"):
        # FAISS search returns (distances, indices)
        scores, indices = faiss_index.search(emb.reshape(1, -1), k)
        all_indices.append(indices[0].tolist())
        all_scores.append(scores[0].tolist())
    
    elapsed = time.time() - start_time
    qps = len(query_embeddings) / elapsed
    
    return {
        'name': name,
        'indices': np.array(all_indices, dtype=object),
        'scores': np.array(all_scores, dtype=object),
        'qps': qps
    }

print("‚úÖ Search functions defined")

## 8. Run All Searches

Retrieve 1000 hits per query using 16 threads.

In [None]:
# Run all searches
results_bm25 = search_bm25(bm25_searcher, query_texts, k=k_retrieve)
results_splade = search_splade(splade_searcher, query_texts, k=k_retrieve)
results_hnsw = search_dense(hnsw_index, query_embeddings, 'BGE-HNSW', k=k_retrieve)
results_flat = search_dense(flat_index, query_embeddings, 'BGE-Flat', k=k_retrieve)

# INT8 quantized searches
query_embeddings_int8 = np.clip(query_embeddings * 127, -128, 127).astype(np.int8).astype(np.float32) / 127
results_hnsw_int8 = search_dense(hnsw_int8_index, query_embeddings_int8, 'BGE-HNSW-int8', k=k_retrieve)
results_flat_int8 = search_dense(flat_int8_index, query_embeddings_int8, 'BGE-Flat-int8', k=k_retrieve)

print("\n‚úÖ All searches complete")
print(f"   BM25: {results_bm25['qps']:.2f} QPS")
print(f"   SPLADE++ ED: {results_splade['qps']:.2f} QPS")
print(f"   BGE-HNSW: {results_hnsw['qps']:.2f} QPS")
print(f"   BGE-HNSW-int8: {results_hnsw_int8['qps']:.2f} QPS")
print(f"   BGE-Flat: {results_flat['qps']:.2f} QPS")
print(f"   BGE-Flat-int8: {results_flat_int8['qps']:.2f} QPS")


## 9. Evaluation at nDCG@10

Evaluate retrieval quality using nDCG@10 as per BEIR guidelines.

In [None]:
def calculate_recall_at_k(retrieved_indices, qrels, query_ids, doc_ids, k=10):
    """Calculate Recall@k following BEIR guidelines"""
    recalls = []
    
    for i, qid in enumerate(query_ids):
        if qid not in qrels:
            continue
        
        relevant_docs = set(qrels[qid].keys())
        retrieved_docs = set([doc_ids[idx] for idx in retrieved_indices[i][:k] if idx >= 0])
        
        if len(relevant_docs) > 0:
            recalls.append(len(relevant_docs & retrieved_docs) / len(relevant_docs))
    
    return np.mean(recalls) if recalls else 0.0

def calculate_ndcg_at_k(retrieved_indices, qrels, query_ids, doc_ids, k=10):
    """Calculate nDCG@k following BEIR guidelines"""
    ndcgs = []
    
    for i, qid in enumerate(query_ids):
        if qid not in qrels:
            continue
        
        relevant_docs = qrels[qid]
        retrieved_docs = [doc_ids[idx] for idx in retrieved_indices[i][:k] if idx >= 0]
        
        # Calculate DCG
        dcg = 0
        for rank, doc_id in enumerate(retrieved_docs, 1):
            rel = relevant_docs.get(doc_id, 0)
            dcg += (2 ** rel - 1) / np.log2(rank + 1)
        
        # Calculate IDCG
        ideal = sorted(relevant_docs.values(), reverse=True)[:k]
        idcg = sum((2 ** r - 1) / np.log2(rank + 2) for rank, r in enumerate(ideal))
        
        ndcgs.append(dcg / idcg if idcg > 0 else 0)
    
    return np.mean(ndcgs) if ndcgs else 0.0

# Evaluate all methods with both metrics
for results in [results_bm25, results_splade, results_hnsw, results_flat, results_hnsw_int8, results_flat_int8]:
    results['recall@10'] = calculate_recall_at_k(
        results['indices'], qrels, query_ids, doc_ids, k=k_eval
    )
    results['ndcg@10'] = calculate_ndcg_at_k(
        results['indices'], qrels, query_ids, doc_ids, k=k_eval
    )

print("‚úÖ Evaluation complete (Recall@10 and nDCG@10 for all methods)")


## 10. Results Summary

Display results in a table matching the paper format.

In [None]:
# Create results dataframe matching paper table format
results_df = pd.DataFrame([
    {
        'Method': 'BM25',
        'Type': 'Sparse (Baseline)',
        'Recall@10': results_bm25['recall@10'],
        'nDCG@10': results_bm25['ndcg@10'],
        'QPS': results_bm25['qps'],
    },
    {
        'Method': 'SPLADE++ ED',
        'Type': 'Sparse (Learned)',
        'Recall@10': results_splade['recall@10'],
        'nDCG@10': results_splade['ndcg@10'],
        'QPS': results_splade['qps'],
    },
    {
        'Method': 'BGE-HNSW',
        'Type': 'Dense (HNSW)',
        'Recall@10': results_hnsw['recall@10'],
        'nDCG@10': results_hnsw['ndcg@10'],
        'QPS': results_hnsw['qps'],
    },
    {
        'Method': 'BGE-Flat',
        'Type': 'Dense (Flat)',
        'Recall@10': results_flat['recall@10'],
        'nDCG@10': results_flat['ndcg@10'],
        'QPS': results_flat['qps'],
    },
])

print(f"\n{'='*90}")
print(f"RESULTS: {dataset_name.upper()}")
print(f"{'='*90}")
print(results_df.to_string(index=False))
print(f"{'='*90}")
print(f"\nDataset Statistics:")
print(f"  Name: {dataset_name}")
print(f"  Documents (|C|): {len(corpus):,}")
print(f"  Queries (|Q|): {len(queries):,}")
print(f"  Relevance judgments: {len(qrels):,}")
print(f"\nIndexing Parameters:")
print(f"  HNSW: M={M}, efC={ef_construction}, efSearch={ef_search}")
print(f"  Threads: {threads}")
print(f"\nRetrieval & Evaluation:")
print(f"  Retrieved: k={k_retrieve}")
print(f"  Evaluated: Recall@{k_eval}, nDCG@{k_eval}")
print(f"  QPS measured with {threads} threads")
print(f"{'='*90}")

## 10a. Index Time Summary

Display the time taken to build each index.


In [None]:
# Create index time dataframe
index_time_df = pd.DataFrame([
    {
        'Method': 'BM25',
        'Type': 'Sparse (Baseline)',
        'Index Time (s)': index_times.get('BM25', 0),
    },
    {
        'Method': 'SPLADE++ ED',
        'Type': 'Sparse (Learned)',
        'Index Time (s)': index_times.get('SPLADE++ ED', 0),
    },
    {
        'Method': 'BGE-HNSW',
        'Type': 'Dense (HNSW)',
        'Index Time (s)': index_times.get('BGE-HNSW', 0),
    },
    {
        'Method': 'BGE-Flat',
        'Type': 'Dense (Flat)',
        'Index Time (s)': index_times.get('BGE-Flat', 0),
    },
])

print(f"\n{'='*80}")
print(f"INDEX TIME: {dataset_name.upper()}")
print(f"{'='*80}")
print(index_time_df.to_string(index=False))
print(f"{'='*80}")
print(f"Total indexing time: {index_time_df['Index Time (s)'].sum():.2f}s")
print(f"{'='*80}\n")


## 10b. Table 3: INT8 Quantization - Indexing Time

Compare indexing time between full precision and int8 quantized dense indexes.


In [None]:
# Table 3: INT8 Quantization - Indexing Time
table3_df = pd.DataFrame([
    {
        'Method': 'BGE-HNSW',
        'Quantization': 'FP32',
        'Index Time (s)': index_times.get('BGE-HNSW', 0),
    },
    {
        'Method': 'BGE-HNSW',
        'Quantization': 'int8',
        'Index Time (s)': index_times.get('BGE-HNSW-int8', 0),
    },
    {
        'Method': 'BGE-Flat',
        'Quantization': 'FP32',
        'Index Time (s)': index_times.get('BGE-Flat', 0),
    },
    {
        'Method': 'BGE-Flat',
        'Quantization': 'int8',
        'Index Time (s)': index_times.get('BGE-Flat-int8', 0),
    },
])

print(f"\n{'='*80}")
print(f"TABLE 3: INT8 QUANTIZATION - INDEXING TIME: {dataset_name.upper()}")
print(f"{'='*80}")
print(table3_df.to_string(index=False))
print(f"{'='*80}\n")

# Calculate speedup
fp32_hnsw_time = index_times.get('BGE-HNSW', 1)
int8_hnsw_time = index_times.get('BGE-HNSW-int8', 1)
hnsw_speedup = fp32_hnsw_time / int8_hnsw_time if int8_hnsw_time > 0 else 1.0

fp32_flat_time = index_times.get('BGE-Flat', 1)
int8_flat_time = index_times.get('BGE-Flat-int8', 1)
flat_speedup = fp32_flat_time / int8_flat_time if int8_flat_time > 0 else 1.0

print(f"Indexing Speedup (FP32 vs int8):")
print(f"  BGE-HNSW: {hnsw_speedup:.2f}x")
print(f"  BGE-Flat: {flat_speedup:.2f}x\n")


## 10c. Table 4: INT8 Quantization - Query Performance and Quality

Compare query performance (QPS) and retrieval quality (nDCG@10) between full precision and int8 quantized dense indexes.


In [None]:
# Table 4: INT8 Quantization - Query Performance and Quality
table4_df = pd.DataFrame([
    {
        'Method': 'BGE-HNSW',
        'Quantization': 'FP32',
        'QPS': results_hnsw['qps'],
        'nDCG@10': results_hnsw['ndcg@10'],
        'Recall@10': results_hnsw['recall@10'],
    },
    {
        'Method': 'BGE-HNSW',
        'Quantization': 'int8',
        'QPS': results_hnsw_int8['qps'],
        'nDCG@10': results_hnsw_int8['ndcg@10'],
        'Recall@10': results_hnsw_int8['recall@10'],
    },
    {
        'Method': 'BGE-Flat',
        'Quantization': 'FP32',
        'QPS': results_flat['qps'],
        'nDCG@10': results_flat['ndcg@10'],
        'Recall@10': results_flat['recall@10'],
    },
    {
        'Method': 'BGE-Flat',
        'Quantization': 'int8',
        'QPS': results_flat_int8['qps'],
        'nDCG@10': results_flat_int8['ndcg@10'],
        'Recall@10': results_flat_int8['recall@10'],
    },
])

print(f"\n{'='*100}")
print(f"TABLE 4: INT8 QUANTIZATION - QUERY PERFORMANCE & QUALITY: {dataset_name.upper()}")
print(f"{'='*100}")
print(table4_df.to_string(index=False))
print(f"{'='*100}\n")

# Calculate query speedup and quality retention
fp32_hnsw_qps = results_hnsw['qps']
int8_hnsw_qps = results_hnsw_int8['qps']
hnsw_qps_speedup = int8_hnsw_qps / fp32_hnsw_qps if fp32_hnsw_qps > 0 else 1.0
hnsw_ndcg_retention = (results_hnsw_int8['ndcg@10'] / results_hnsw['ndcg@10'] * 100) if results_hnsw['ndcg@10'] > 0 else 100

fp32_flat_qps = results_flat['qps']
int8_flat_qps = results_flat_int8['qps']
flat_qps_speedup = int8_flat_qps / fp32_flat_qps if fp32_flat_qps > 0 else 1.0
flat_ndcg_retention = (results_flat_int8['ndcg@10'] / results_flat['ndcg@10'] * 100) if results_flat['ndcg@10'] > 0 else 100

print(f"Query Performance Speedup (int8 vs FP32):")
print(f"  BGE-HNSW: {hnsw_qps_speedup:.2f}x (nDCG@10 retention: {hnsw_ndcg_retention:.1f}%)")
print(f"  BGE-Flat: {flat_qps_speedup:.2f}x (nDCG@10 retention: {flat_ndcg_retention:.1f}%)\n")


## 11. Visualization

In [None]:
# Create visualizations matching paper analysis
output_dir = f'results_{dataset_name}'
os.makedirs(output_dir, exist_ok=True)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

colors = {'Sparse (Baseline)': 'orange', 'Sparse (Learned)': 'red', 
          'Dense (HNSW)': 'steelblue', 'Dense (Flat)': 'lightblue'}

# Plot 1: Speed (QPS) vs Quality (nDCG@10)
for _, row in results_df.iterrows():
    ax1.scatter(row['QPS'], row['nDCG@10'], 
              s=200, alpha=0.7, color=colors[row['Type']], 
              edgecolors='black', linewidth=1.5)
    ax1.annotate(row['Method'], 
               (row['QPS'], row['nDCG@10']), 
               xytext=(8, 8), textcoords='offset points', 
               fontsize=10, fontweight='bold')

ax1.set_xlabel('QPS (queries per second, 16 threads)', fontsize=11)
ax1.set_ylabel('nDCG@10', fontsize=11)
ax1.set_title(f'Speed vs Quality (nDCG@10) ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3)

# Plot 2: Speed (QPS) vs Quality (Recall@10)
for _, row in results_df.iterrows():
    ax2.scatter(row['QPS'], row['Recall@10'], 
              s=200, alpha=0.7, color=colors[row['Type']], 
              edgecolors='black', linewidth=1.5)
    ax2.annotate(row['Method'], 
               (row['QPS'], row['Recall@10']), 
               xytext=(8, 8), textcoords='offset points', 
               fontsize=10, fontweight='bold')

ax2.set_xlabel('QPS (queries per second, 16 threads)', fontsize=11)
ax2.set_ylabel('Recall@10', fontsize=11)
ax2.set_title(f'Speed vs Quality (Recall@10) ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{output_dir}/speed_vs_quality.pdf', dpi=300, bbox_inches='tight')
plt.show()

# Bar chart comparison
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

# Quality metrics
x = np.arange(len(results_df))
width = 0.35

bars1 = ax1.bar(x - width/2, results_df['Recall@10'], width, label='Recall@10', alpha=0.8)
bars2 = ax1.bar(x + width/2, results_df['nDCG@10'], width, label='nDCG@10', alpha=0.8)

ax1.set_xlabel('Method', fontsize=11)
ax1.set_ylabel('Score', fontsize=11)
ax1.set_title(f'Quality Metrics Comparison ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax1.set_xticks(x)
ax1.set_xticklabels(results_df['Method'], rotation=15, ha='right')
ax1.legend()
ax1.grid(True, alpha=0.3, axis='y')

# QPS comparison
bars = ax2.bar(results_df['Method'], results_df['QPS'], alpha=0.8, 
               color=[colors[t] for t in results_df['Type']], edgecolor='black')
ax2.set_xlabel('Method', fontsize=11)
ax2.set_ylabel('QPS (16 threads)', fontsize=11)
ax2.set_title(f'Query Performance ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax2.set_xticklabels(results_df['Method'], rotation=15, ha='right')
ax2.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f'{output_dir}/metrics_comparison.pdf', dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úÖ Visualizations complete and saved to {output_dir}/")

In [None]:
# INT8 Quantization Comparison Visualizations
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# Plot 1: Indexing Time Comparison
ax = axes[0, 0]
methods = ['HNSW', 'Flat']
fp32_times = [index_times.get('BGE-HNSW', 0), index_times.get('BGE-Flat', 0)]
int8_times = [index_times.get('BGE-HNSW-int8', 0), index_times.get('BGE-Flat-int8', 0)]

x = np.arange(len(methods))
width = 0.35
ax.bar(x - width/2, fp32_times, width, label='FP32', alpha=0.8, color='steelblue')
ax.bar(x + width/2, int8_times, width, label='int8', alpha=0.8, color='coral')
ax.set_xlabel('Index Type', fontsize=11)
ax.set_ylabel('Indexing Time (s)', fontsize=11)
ax.set_title(f'Indexing Time: FP32 vs int8 ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(methods)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Plot 2: Query Performance (QPS) Comparison
ax = axes[0, 1]
fp32_qps = [results_hnsw['qps'], results_flat['qps']]
int8_qps = [results_hnsw_int8['qps'], results_flat_int8['qps']]

ax.bar(x - width/2, fp32_qps, width, label='FP32', alpha=0.8, color='steelblue')
ax.bar(x + width/2, int8_qps, width, label='int8', alpha=0.8, color='coral')
ax.set_xlabel('Index Type', fontsize=11)
ax.set_ylabel('QPS (16 threads)', fontsize=11)
ax.set_title(f'Query Performance: FP32 vs int8 ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(methods)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Plot 3: nDCG@10 Comparison
ax = axes[1, 0]
fp32_ndcg = [results_hnsw['ndcg@10'], results_flat['ndcg@10']]
int8_ndcg = [results_hnsw_int8['ndcg@10'], results_flat_int8['ndcg@10']]

ax.bar(x - width/2, fp32_ndcg, width, label='FP32', alpha=0.8, color='steelblue')
ax.bar(x + width/2, int8_ndcg, width, label='int8', alpha=0.8, color='coral')
ax.set_xlabel('Index Type', fontsize=11)
ax.set_ylabel('nDCG@10', fontsize=11)
ax.set_title(f'Retrieval Quality (nDCG@10): FP32 vs int8 ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(methods)
ax.legend()
ax.grid(True, alpha=0.3, axis='y')

# Plot 4: Speed-Quality Tradeoff
ax = axes[1, 1]
colors_int8 = {'FP32': 'steelblue', 'int8': 'coral'}
for method, qps, ndcg, quantization in [
    ('HNSW-FP32', results_hnsw['qps'], results_hnsw['ndcg@10'], 'FP32'),
    ('HNSW-int8', results_hnsw_int8['qps'], results_hnsw_int8['ndcg@10'], 'int8'),
    ('Flat-FP32', results_flat['qps'], results_flat['ndcg@10'], 'FP32'),
    ('Flat-int8', results_flat_int8['qps'], results_flat_int8['ndcg@10'], 'int8'),
]:
    ax.scatter(qps, ndcg, s=200, alpha=0.7, color=colors_int8[quantization], 
              edgecolors='black', linewidth=1.5)
    ax.annotate(method, (qps, ndcg), xytext=(8, 8), textcoords='offset points', 
               fontsize=10, fontweight='bold')

ax.set_xlabel('QPS (16 threads)', fontsize=11)
ax.set_ylabel('nDCG@10', fontsize=11)
ax.set_title(f'Speed-Quality Tradeoff: FP32 vs int8 ‚Äî {dataset_name}', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{output_dir}/int8_quantization_comparison.pdf', dpi=300, bbox_inches='tight')
plt.show()

print(f"‚úÖ INT8 quantization comparison visualizations saved to {output_dir}/")


## 12. Save Results

In [None]:
# Save results to CSV
output_dir = f'results_{dataset_name}'
os.makedirs(output_dir, exist_ok=True)

results_path = os.path.join(output_dir, f'{dataset_name}_results.csv')
results_df.to_csv(results_path, index=False)

# Save index times to CSV
index_time_path = os.path.join(output_dir, f'{dataset_name}_index_times.csv')
index_time_df.to_csv(index_time_path, index=False)

# Save Table 3 (INT8 Indexing Time)
table3_path = os.path.join(output_dir, f'{dataset_name}_table3_int8_indexing.csv')
table3_df.to_csv(table3_path, index=False)

# Save Table 4 (INT8 Query Performance & Quality)
table4_path = os.path.join(output_dir, f'{dataset_name}_table4_int8_performance.csv')
table4_df.to_csv(table4_path, index=False)

# Save detailed results with metadata
metadata = {
    'dataset': dataset_name,
    'num_documents': len(corpus),
    'num_queries': len(queries),
    'num_qrels': len(qrels),
    'hnsw_M': M,
    'hnsw_efC': ef_construction,
    'hnsw_efSearch': ef_search,
    'threads': threads,
    'k_retrieve': k_retrieve,
    'k_eval': k_eval,
    'index_times': index_times,
    'total_indexing_time': float(index_time_df['Index Time (s)'].sum()),
    'int8_quantization': {
        'method': 'simple_linear_quantization',
        'scale': 127,
        'range': '[-128, 127]',
    },
    'speedup': {
        'hnsw_indexing': float(fp32_hnsw_time / int8_hnsw_time) if int8_hnsw_time > 0 else None,
        'flat_indexing': float(fp32_flat_time / int8_flat_time) if int8_flat_time > 0 else None,
        'hnsw_query': float(hnsw_qps_speedup),
        'flat_query': float(flat_qps_speedup),
    },
    'quality_retention': {
        'hnsw_ndcg': float(hnsw_ndcg_retention),
        'flat_ndcg': float(flat_ndcg_retention),
    },
}

metadata_path = os.path.join(output_dir, f'{dataset_name}_metadata.json')
with open(metadata_path, 'w') as f:
    json.dump(metadata, f, indent=2)

print(f"‚úÖ Results saved:")
print(f"   - {results_path}")
print(f"   - {index_time_path}")
print(f"   - {table3_path} (Table 3: INT8 Indexing Time)")
print(f"   - {table4_path} (Table 4: INT8 Query Performance)")
print(f"   - {metadata_path}")
print(f"   - {output_dir}/speed_vs_quality.pdf")
print(f"   - {output_dir}/metrics_comparison.pdf")
print(f"   - {output_dir}/int8_quantization_comparison.pdf")
print(f"\n{'='*80}")
print(f"SUMMARY - PAPER TABLES REPLICATED:")
print(f"{'='*80}")
print(f"‚úÖ Table 1: Dense vs Sparse Retrieval Baseline Results")
print(f"‚úÖ Table 3: INT8 Quantization - Indexing Time Comparison")
print(f"‚úÖ Table 4: INT8 Quantization - Query Performance & Quality")
print(f"{'='*80}\n")


## 13. Download Results to Local Machine

Detect environment (Colab/Kaggle/Local) and download all results and visualizations.


In [None]:
# Detect environment and download/copy results
import shutil
import platform

def detect_environment():
    """Detect if running on Colab, Kaggle, or Local"""
    try:
        from google.colab import drive
        return 'colab'
    except ImportError:
        pass
    
    if os.path.exists('/kaggle/'):
        return 'kaggle'
    
    return 'local'

environment = detect_environment()
print(f"Environment detected: {environment.upper()}")

if environment == 'colab':
    # Google Colab: Mount Google Drive and copy results
    try:
        from google.colab import drive
        drive.mount('/content/gdrive', force_remount=True)
        
        colab_save_dir = f'/content/gdrive/My Drive/BEIR_Results/{dataset_name}'
        os.makedirs(colab_save_dir, exist_ok=True)
        
        # Copy entire results directory to Google Drive
        shutil.copytree(output_dir, os.path.join(colab_save_dir, 'results'), dirs_exist_ok=True)
        shutil.copytree(base_dir, os.path.join(colab_save_dir, 'indexes'), dirs_exist_ok=True)
        
        print(f"\n‚úÖ Results saved to Google Drive: {colab_save_dir}")
        print(f"\nFiles saved:")
        for root, dirs, files in os.walk(colab_save_dir):
            level = root.replace(colab_save_dir, '').count(os.sep)
            indent = ' ' * 2 * level
            print(f"{indent}{os.path.basename(root)}/")
            subindent = ' ' * 2 * (level + 1)
            for file in files:
                print(f"{subindent}{file}")
                
    except Exception as e:
        print(f"‚ö†Ô∏è Could not mount Google Drive: {e}")
        print(f"Files remain in: {output_dir}")

elif environment == 'kaggle':
    # Kaggle: Save to /kaggle/working/ (synced to outputs)
    kaggle_save_dir = f'/kaggle/working/BEIR_Results_{dataset_name}'
    os.makedirs(kaggle_save_dir, exist_ok=True)
    
    # Copy entire results directory
    shutil.copytree(output_dir, os.path.join(kaggle_save_dir, 'results'), dirs_exist_ok=True)
    shutil.copytree(base_dir, os.path.join(kaggle_save_dir, 'indexes'), dirs_exist_ok=True)
    
    print(f"\n‚úÖ Results saved to Kaggle working directory: {kaggle_save_dir}")
    print(f"Files will be available in 'Output' section when notebook completes")
    print(f"\nFiles saved:")
    for root, dirs, files in os.walk(kaggle_save_dir):
        level = root.replace(kaggle_save_dir, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files:
            print(f"{subindent}{file}")

else:
    # Local execution: Files are already saved
    print(f"\n‚úÖ Results already saved locally to: {output_dir}")
    print(f"\nFiles saved:")
    for root, dirs, files in os.walk(output_dir):
        level = root.replace(output_dir, '').count(os.sep)
        indent = ' ' * 2 * level
        print(f"{indent}{os.path.basename(root)}/")
        subindent = ' ' * 2 * (level + 1)
        for file in files:
            file_path = os.path.join(root, file)
            file_size = os.path.getsize(file_path) / 1024  # Size in KB
            print(f"{subindent}{file} ({file_size:.1f} KB)")

print(f"\n{'='*80}")
print(f"RESULTS SUMMARY:")
print(f"{'='*80}")
print(f"Output directory: {output_dir}")
print(f"Index directory: {base_dir}")
print(f"\nGenerated files:")
print(f"  üìä {dataset_name}_results.csv - Main results (Dense/Sparse comparison)")
print(f"  ‚è±Ô∏è  {dataset_name}_index_times.csv - Index construction times")
print(f"  üî¢ {dataset_name}_table3_int8_indexing.csv - Table 3 (INT8 indexing)")
print(f"  üìà {dataset_name}_table4_int8_performance.csv - Table 4 (INT8 performance)")
print(f"  üìù {dataset_name}_metadata.json - Complete metadata & speedup metrics")
print(f"  üìâ speed_vs_quality.pdf - Quality vs Speed scatter plots")
print(f"  üìä metrics_comparison.pdf - Bar chart comparisons")
print(f"  üîÑ int8_quantization_comparison.pdf - INT8 quantization analysis")
print(f"\nIndexes saved:")
print(f"  ‚Ä¢ BM25 Lucene index")
print(f"  ‚Ä¢ SPLADE++ ED impact index")
print(f"  ‚Ä¢ BGE-HNSW (FP32 and int8)")
print(f"  ‚Ä¢ BGE-Flat (FP32 and int8)")
print(f"{'='*80}\n")


---

# üóÑÔ∏è ARCHIVE: SPLADE Model Diagnostics (Sections 14-19)

**Status**: Archive - Investigation Complete  
**Purpose**: Documents the investigation into SPLADE++ underperformance  
**Conclusion**: Use `naver/splade_v2_distil` model (98.8% match to paper)  
**Date**: December 22, 2025

These sections document the diagnostic process that identified the SPLADE model version issue. They are preserved for reference but are **not required for production runs**.

**Summary**:
- Sections 14-18: Investigated 20% performance gap with `ensembledistil`
- Section 19: Tested `v2_distil` model ‚Üí **98.8% match to paper** ‚úÖ
- **Result**: Main implementation (Sections 1-13) now uses `v2_distil`

**For production runs**: Only execute Sections 1-13. These diagnostic sections can be skipped.

---

## 14. SPLADE++ Diagnostics - Model Version Issue

**KEY FINDING**: The model `naver/splade-cocondenser-ensembledistil` was last updated on **June 30, 2025** - 9 months after the paper was published (September 2024). This could explain the ~20% performance gap.

Let's investigate:
1. Check encoded output quality
2. Compare weight distributions
3. Test alternative SPLADE models from the paper timeframe


In [None]:
# # Step 4: Manual SPLADE Encoding Test (Verify Pyserini Implementation)
# import torch
# from transformers import AutoModelForMaskedLM, AutoTokenizer

# def encode_splade_manual(text, model, tokenizer, device='cpu'):
#     """
#     Manual SPLADE encoding following the official paper implementation
#     This helps verify if Pyserini's encoding differs from expected behavior
#     """
#     model.eval()
#     model.to(device)
    
#     # Tokenize
#     tokens = tokenizer(
#         text, 
#         return_tensors='pt', 
#         padding=True, 
#         truncation=True, 
#         max_length=512
#     ).to(device)
    
#     with torch.no_grad():
#         # Forward pass
#         output = model(**tokens)
#         logits = output.logits  # [batch_size, seq_len, vocab_size]
        
#         # SPLADE aggregation formula: max(log(1 + relu(logits)))
#         # This is the key step - must use MAX pooling, not average!
#         relu_log = torch.log(1 + torch.relu(logits))
        
#         # Apply attention mask and max pool over sequence dimension
#         masked_values = relu_log * tokens['attention_mask'].unsqueeze(-1)
#         max_values, _ = torch.max(masked_values, dim=1)
    
#     # Convert to sparse representation (only non-zero values)
#     sparse_vec = {}
#     vocab = tokenizer.get_vocab()
#     inv_vocab = {v: k for k, v in vocab.items()}
    
#     for idx in range(max_values.shape[1]):
#         val = max_values[0, idx].item()
#         if val > 0:  # Only store positive values
#             token = inv_vocab.get(idx, f"<unk_{idx}>")
#             sparse_vec[token] = val
    
#     return sparse_vec

# # Test manual encoding on sample document
# print("="*80)
# print("MANUAL SPLADE ENCODING TEST")
# print("="*80)

# print("\nLoading SPLADE model for manual test...")
# model_name = 'naver/splade-cocondenser-ensembledistil'
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForMaskedLM.from_pretrained(model_name)

# # Encode first document
# sample_text = doc_texts[0][:500]  # First 500 chars
# print(f"\nSample text (first 100 chars): {sample_text[:100]}...")

# manual_encoding = encode_splade_manual(sample_text, model, tokenizer)

# print(f"\n{'='*80}")
# print("MANUAL ENCODING RESULTS")
# print(f"{'='*80}")
# print(f"Non-zero terms: {len(manual_encoding)}")
# print(f"Weight range: [{min(manual_encoding.values()):.4f}, {max(manual_encoding.values()):.4f}]")
# print(f"Mean weight: {np.mean(list(manual_encoding.values())):.4f}")

# # Top terms
# top_manual = sorted(manual_encoding.items(), key=lambda x: -x[1])[:15]
# print(f"\nTop 15 terms from manual encoding:")
# for i, (term, weight) in enumerate(top_manual, 1):
#     print(f"  {i:2d}. {term:20s} {weight:8.4f}")

# # Compare with Pyserini encoding if available
# # Load Pyserini encoding if not already loaded
# try:
#     sample_encoding
# except NameError:
#     print("\nLoading Pyserini encoding for comparison...")
#     encoded_file = os.path.join(splade_encoded_dir, 'embeddings.jsonl')
#     if os.path.exists(encoded_file):
#         with open(encoded_file, 'r') as f:
#             sample_encoding = json.loads(f.readline())
#         print(f"‚úÖ Loaded Pyserini encoding for doc: {sample_encoding['id']}")
#     else:
#         sample_encoding = None
#         print(f"‚ùå Pyserini encoding not found: {encoded_file}")

# if sample_encoding:
#     print(f"\n{'='*80}")
#     print("COMPARISON: Manual vs Pyserini")
#     print(f"{'='*80}")
    
#     pyserini_weights = list(sample_encoding['vector'].values())
#     manual_weights = list(manual_encoding.values())
    
#     print(f"{'Metric':<30} {'Pyserini':<15} {'Manual':<15} {'Difference'}")
#     print(f"{'-'*80}")
#     print(f"{'Non-zero terms':<30} {len(pyserini_weights):<15} {len(manual_weights):<15} {len(manual_weights) - len(pyserini_weights)}")
#     print(f"{'Mean weight':<30} {np.mean(pyserini_weights):<15.4f} {np.mean(manual_weights):<15.4f} {np.mean(manual_weights) - np.mean(pyserini_weights):.4f}")
#     print(f"{'Max weight':<30} {max(pyserini_weights):<15.4f} {max(manual_weights):<15.4f} {max(manual_weights) - max(pyserini_weights):.4f}")
    
#     # Check if significant difference
#     term_diff = abs(len(manual_weights) - len(pyserini_weights))
#     weight_diff = abs(np.mean(manual_weights) - np.mean(pyserini_weights))
    
#     print(f"\n{'='*80}")
#     if term_diff > 20 or weight_diff > 0.5:
#         print("‚ùå SIGNIFICANT DIFFERENCE DETECTED!")
#         print("   Pyserini encoding may differ from expected SPLADE behavior")
#         print("   Recommendation: Use manual encoding for full corpus")
#     else:
#         print("‚úÖ Encodings are similar")
#         print("   Pyserini implementation appears correct")
# else:
#     print("\n‚ö†Ô∏è Cannot compare - Pyserini encoding not available")
#     print("   Run the indexing cells first to generate SPLADE encodings")


## 16. ROOT CAUSE IDENTIFIED + SOLUTION

**üî¥ CRITICAL FINDING**: Pyserini applies a **quantization multiplier (~100√ó)** to SPLADE weights that is NOT in the original paper!

**The Problem:**
- Expected SPLADE weights: 0.0-3.0 range (float, log-scale)
- Pyserini weights: 0-126 range (integer quantization)
- Impact: Scoring function completely distorted, causing ~20% performance loss

**The Solution:**
Re-encode the corpus using the correct SPLADE formula without Pyserini's quantization multiplier.

In [None]:
# # Step 2: Inspect Encoded SPLADE Output Quality
# def diagnose_splade_encoding():
#     """Analyze SPLADE encoding to check for issues"""
#     import json
    
#     print("="*80)
#     print("SPLADE ENCODING DIAGNOSTIC")
#     print("="*80)
    
#     # Read encoded documents
#     encoded_file = os.path.join(splade_encoded_dir, 'embeddings.jsonl')
    
#     if not os.path.exists(encoded_file):
#         print(f"‚ùå Encoded file not found: {encoded_file}")
#         return None
    
#     # Sample 10 documents
#     encodings = []
#     with open(encoded_file, 'r') as f:
#         for i, line in enumerate(f):
#             if i >= 10:
#                 break
#             encodings.append(json.loads(line))
    
#     print(f"\n‚úÖ Loaded {len(encodings)} sample encodings\n")
    
#     # Analyze first document in detail
#     first_doc = encodings[0]
#     print(f"Document ID: {first_doc['id']}")
#     print(f"Number of non-zero terms: {len(first_doc['vector'])}")
    
#     # Weight statistics
#     weights = list(first_doc['vector'].values())
#     print(f"\nWeight Statistics:")
#     print(f"  Min:    {min(weights):.4f}")
#     print(f"  Max:    {max(weights):.4f}")
#     print(f"  Mean:   {np.mean(weights):.4f}")
#     print(f"  Median: {np.median(weights):.4f}")
#     print(f"  Std:    {np.std(weights):.4f}")
    
#     # Top 20 terms
#     top_terms = sorted(first_doc['vector'].items(), key=lambda x: -x[1])[:20]
#     print(f"\nTop 20 terms with highest weights:")
#     for i, (term, weight) in enumerate(top_terms, 1):
#         print(f"  {i:2d}. {term:20s} {weight:8.4f}")
    
#     # Aggregate statistics across all samples
#     all_weights = []
#     term_counts = []
#     for enc in encodings:
#         weights_list = list(enc['vector'].values())
#         all_weights.extend(weights_list)
#         term_counts.append(len(enc['vector']))
    
#     print(f"\n{'='*80}")
#     print(f"AGGREGATE STATISTICS (10 documents)")
#     print(f"{'='*80}")
#     print(f"Average non-zero terms per doc: {np.mean(term_counts):.1f}")
#     print(f"  Expected range: 100-300 terms")
#     print(f"  Status: {'‚úÖ Good' if 100 <= np.mean(term_counts) <= 300 else '‚ùå Outside expected range'}")
    
#     print(f"\nWeight distribution across all samples:")
#     print(f"  Min:    {min(all_weights):.4f}")
#     print(f"  Max:    {max(all_weights):.4f}")
#     print(f"  Mean:   {np.mean(all_weights):.4f}")
#     print(f"  Median: {np.median(all_weights):.4f}")
    
#     # Check for quantization artifacts
#     print(f"\n{'='*80}")
#     print(f"QUANTIZATION CHECK")
#     print(f"{'='*80}")
    
#     # Check if weights are already quantized (integers)
#     int_like = sum(1 for w in all_weights[:100] if abs(w - round(w)) < 0.001)
#     print(f"Weights that look like integers: {int_like}/100")
#     if int_like > 50:
#         print("‚ö†Ô∏è WARNING: Weights appear to be pre-quantized!")
#     else:
#         print("‚úÖ Weights are continuous (not pre-quantized)")
    
#     return first_doc, encodings

# # Run diagnostic
# sample_encoding, all_encodings = diagnose_splade_encoding()


In [None]:
# # Test the corrected SPLADE index
# print("="*80)
# print("TESTING CORRECTED SPLADE INDEX")
# print("="*80)

# # Check if corrected index exists
# try:
#     splade_correct_index_dir
# except NameError:
#     print("‚ùå ERROR: Corrected SPLADE index not found!")
#     print("   Please run the previous cell to re-encode the corpus first.")
#     print("   (The cell titled: 'Re-encode corpus with correct SPLADE implementation')")
#     raise

# if not os.path.exists(splade_correct_index_dir):
#     print(f"‚ùå ERROR: Index directory not found: {splade_correct_index_dir}")
#     print("   Please run the previous cell to create the corrected index.")
#     raise FileNotFoundError(f"Index not found: {splade_correct_index_dir}")

# # Initialize corrected searcher
# from pyserini.search.lucene import LuceneImpactSearcher

# print("\nInitializing corrected SPLADE searcher...")
# splade_correct_searcher = LuceneImpactSearcher(
#     splade_correct_index_dir,
#     'naver/splade-cocondenser-ensembledistil',
#     encoder_type='pytorch'
# )
# print("‚úÖ Searcher initialized")

# # Run search
# print("\nRunning search with corrected SPLADE...")
# results_splade_correct = search_splade(splade_correct_searcher, query_texts, k=k_retrieve)

# # Evaluate
# results_splade_correct['recall@10'] = calculate_recall_at_k(
#     results_splade_correct['indices'], qrels, query_ids, doc_ids, k=k_eval
# )
# results_splade_correct['ndcg@10'] = calculate_ndcg_at_k(
#     results_splade_correct['indices'], qrels, query_ids, doc_ids, k=k_eval
# )

# # Compare with original
# print("\n" + "="*80)
# print("COMPARISON: Original Pyserini vs Corrected SPLADE")
# print("="*80)
# print(f"{'Metric':<30} {'Original':<15} {'Corrected':<15} {'Improvement'}")
# print("-"*80)
# print(f"{'nDCG@10':<30} {results_splade['ndcg@10']:<15.4f} {results_splade_correct['ndcg@10']:<15.4f} {results_splade_correct['ndcg@10'] - results_splade['ndcg@10']:+.4f}")
# print(f"{'Recall@10':<30} {results_splade['recall@10']:<15.4f} {results_splade_correct['recall@10']:<15.4f} {results_splade_correct['recall@10'] - results_splade['recall@10']:+.4f}")
# print(f"{'QPS':<30} {results_splade['qps']:<15.2f} {results_splade_correct['qps']:<15.2f} {results_splade_correct['qps'] - results_splade['qps']:+.2f}")

# improvement_pct = ((results_splade_correct['ndcg@10'] - results_splade['ndcg@10']) / results_splade['ndcg@10'] * 100)
# print("\n" + "="*80)
# print(f"üìä nDCG@10 Improvement: {improvement_pct:+.1f}%")
# print(f"Expected paper value: ~0.70 (for SciFact)")
# print(f"Your corrected value: {results_splade_correct['ndcg@10']:.3f}")

# if results_splade_correct['ndcg@10'] >= 0.68:
#     print("‚úÖ EXCELLENT! Now matches paper expectations!")
# elif results_splade_correct['ndcg@10'] >= 0.60:
#     print("‚úÖ GOOD! Significant improvement, close to paper")
# else:
#     print("‚ö†Ô∏è Still below expected - may need further investigation")
# print("="*80)

In [None]:
# # Save corrected SPLADE results
# print("="*80)
# print("SAVING CORRECTED RESULTS")
# print("="*80)

# # Update results dataframe with corrected SPLADE
# results_df_corrected = pd.DataFrame([
#     {
#         'Method': 'BM25',
#         'Type': 'Sparse (Baseline)',
#         'Recall@10': results_bm25['recall@10'],
#         'nDCG@10': results_bm25['ndcg@10'],
#         'QPS': results_bm25['qps'],
#     },
#     {
#         'Method': 'SPLADE++ ED (Original)',
#         'Type': 'Sparse (Learned)',
#         'Recall@10': results_splade['recall@10'],
#         'nDCG@10': results_splade['ndcg@10'],
#         'QPS': results_splade['qps'],
#     },
#     {
#         'Method': 'SPLADE++ ED (Corrected)',
#         'Type': 'Sparse (Learned - Fixed)',
#         'Recall@10': results_splade_correct['recall@10'],
#         'nDCG@10': results_splade_correct['ndcg@10'],
#         'QPS': results_splade_correct['qps'],
#     },
#     {
#         'Method': 'BGE-HNSW',
#         'Type': 'Dense (HNSW)',
#         'Recall@10': results_hnsw['recall@10'],
#         'nDCG@10': results_hnsw['ndcg@10'],
#         'QPS': results_hnsw['qps'],
#     },
#     {
#         'Method': 'BGE-Flat',
#         'Type': 'Dense (Flat)',
#         'Recall@10': results_flat['recall@10'],
#         'nDCG@10': results_flat['ndcg@10'],
#         'QPS': results_flat['qps'],
#     },
# ])

# # Save to CSV
# corrected_results_path = os.path.join(output_dir, f'{dataset_name}_results_corrected.csv')
# results_df_corrected.to_csv(corrected_results_path, index=False)

# print(f"\n‚úÖ Corrected results saved to: {corrected_results_path}")

# # Display comparison table
# print("\n" + "="*90)
# print(f"FINAL RESULTS: {dataset_name.upper()} (WITH CORRECTED SPLADE)")
# print("="*90)
# print(results_df_corrected.to_string(index=False))
# print("="*90)

# # Calculate improvement summary
# improvement = results_splade_correct['ndcg@10'] - results_splade['ndcg@10']
# improvement_pct = (improvement / results_splade['ndcg@10'] * 100)

# print(f"\nüìä IMPROVEMENT SUMMARY:")
# print(f"   Original SPLADE nDCG@10:  {results_splade['ndcg@10']:.4f}")
# print(f"   Corrected SPLADE nDCG@10: {results_splade_correct['ndcg@10']:.4f}")
# print(f"   Absolute improvement:     {improvement:+.4f}")
# print(f"   Relative improvement:     {improvement_pct:+.1f}%")
# print(f"\n   Paper expected value:     ~0.70")
# print(f"   Match quality:            {(results_splade_correct['ndcg@10']/0.70*100):.1f}%")
# print("="*90)

In [None]:
# # Quick test: Try alternative SPLADE model
# print("="*80)
# print("TESTING ALTERNATIVE SPLADE MODEL")
# print("="*80)

# # Try the self-distil variant (different training approach)
# alt_model_name = 'naver/splade-cocondenser-selfdistil'

# print(f"\nTesting model: {alt_model_name}")
# print("This model uses self-distillation instead of ensemble distillation")
# print("\nTo test this model:")
# print(f"1. Go back to the SPLADE encoding cell (Section 5, cell with pyserini.encode)")
# print(f"2. Change the --encoder parameter to: {alt_model_name}")
# print(f"3. Re-run the encoding and indexing")
# print(f"4. Re-run search and evaluation")

# print("\n" + "="*80)
# print("ALTERNATIVE: Contact Paper Authors")
# print("="*80)
# print("Email: jimmylin@uwaterloo.ca")
# print("Ask for:")
# print("  - Exact SPLADE model checkpoint used (with date/commit)")
# print("  - Pyserini version used")
# print("  - Any special configuration parameters")
# print("="*80)

In [None]:
# # Test Alternative SPLADE Model: selfdistil
# import subprocess
# import time

# print("="*80)
# print("TESTING ALTERNATIVE SPLADE MODEL: selfdistil")
# print("="*80)

# # Clear GPU memory first
# print("\nClearing GPU memory...")
# import torch
# import gc

# # Delete any existing models from memory
# try:
#     del model, tokenizer
# except:
#     pass

# torch.cuda.empty_cache()
# gc.collect()

# # Check GPU memory
# if torch.cuda.is_available():
#     allocated = torch.cuda.memory_allocated(0) / 1024**3
#     reserved = torch.cuda.memory_reserved(0) / 1024**3
#     print(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

# # Use alternative model
# alt_model_name = 'naver/splade-cocondenser-selfdistil'
# splade_alt_docs_dir = os.path.join(base_dir, 'splade_docs')  # Reuse existing docs
# splade_alt_encoded_dir = os.path.join(base_dir, 'splade_encoded_selfdistil')
# splade_alt_index_dir = os.path.join(base_dir, 'splade_index_selfdistil')
# os.makedirs(splade_alt_encoded_dir, exist_ok=True)

# print(f"\nModel: {alt_model_name}")
# print(f"This uses self-distillation training (different from ensemble)")

# # Encode with alternative model (reduced batch size to fit in GPU)
# print("\nVerifying model availability...")
# print("Using batch size 16 to avoid OOM...")
# result = subprocess.run([
#     'python', '-m', 'pyserini.encode',
#     'input', '--corpus', splade_alt_docs_dir,
#     '--fields', 'text',
#     'output', '--embeddings', splade_alt_encoded_dir,
#     'encoder', '--encoder', alt_model_name,
#     '--device', 'cuda',
#     '--batch', '16'  # Reduced from 32 to avoid OOM
# ], capture_output=True, text=True)

# if result.returncode != 0:
#     print("‚ùå ENCODING FAILED!")
#     print("\nSTDOUT:")
#     print(result.stdout)
#     print("\nSTDERR:")
#     print(result.stderr[-2000:])  # Last 2000 chars of error
#     raise subprocess.CalledProcessError(result.returncode, result.args, result.stdout, result.stderr)

# print("‚úÖ Encoding complete")

# # Build index
# print("\nBuilding SPLADE index with selfdistil...")
# splade_alt_start = time.time()
# subprocess.run([
#     'python', '-m', 'pyserini.index.lucene',
#     '--collection', 'JsonVectorCollection',
#     '--input', splade_alt_encoded_dir,
#     '--index', splade_alt_index_dir,
#     '--generator', 'DefaultLuceneDocumentGenerator',
#     '--impact',
#     '--threads', threads,
#     '--storeRaw'
# ], check=True)

# splade_alt_elapsed = time.time() - splade_alt_start
# print(f"‚úÖ Index built ({splade_alt_elapsed:.2f}s)")

# # Initialize searcher
# print("\nInitializing selfdistil searcher...")
# from pyserini.search.lucene import LuceneImpactSearcher

# splade_alt_searcher = LuceneImpactSearcher(
#     splade_alt_index_dir,
#     alt_model_name,
#     encoder_type='pytorch'
# )

# # Run search
# print("\nRunning search with selfdistil model...")
# results_splade_alt = search_splade(splade_alt_searcher, query_texts, k=k_retrieve)

# # Evaluate
# results_splade_alt['recall@10'] = calculate_recall_at_k(
#     results_splade_alt['indices'], qrels, query_ids, doc_ids, k=k_eval
# )
# results_splade_alt['ndcg@10'] = calculate_ndcg_at_k(
#     results_splade_alt['indices'], qrels, query_ids, doc_ids, k=k_eval
# )

# # Display comparison
# print("\n" + "="*80)
# print("COMPARISON: Original vs Alternative SPLADE Models")
# print("="*80)
# print(f"{'Model':<40} {'nDCG@10':<15} {'Recall@10':<15} {'QPS'}")
# print("-"*80)
# print(f"{'ensembledistil (original)':<40} {results_splade['ndcg@10']:<15.4f} {results_splade['recall@10']:<15.4f} {results_splade['qps']:.2f}")
# print(f"{'selfdistil (alternative)':<40} {results_splade_alt['ndcg@10']:<15.4f} {results_splade_alt['recall@10']:<15.4f} {results_splade_alt['qps']:.2f}")

# improvement = results_splade_alt['ndcg@10'] - results_splade['ndcg@10']
# improvement_pct = (improvement / results_splade['ndcg@10'] * 100) if results_splade['ndcg@10'] > 0 else 0

# print("\n" + "="*80)
# print(f"üìä SELFDISTIL RESULTS:")
# print(f"   nDCG@10: {results_splade_alt['ndcg@10']:.4f}")
# print(f"   Change from ensembledistil: {improvement:+.4f} ({improvement_pct:+.1f}%)")
# print(f"   Paper expected: ~0.70")
# print(f"   Match quality: {(results_splade_alt['ndcg@10']/0.70*100):.1f}%")

# if results_splade_alt['ndcg@10'] >= 0.68:
#     print("\n‚úÖ SUCCESS! This model matches paper expectations!")
# elif results_splade_alt['ndcg@10'] >= 0.60:
#     print("\n‚úÖ GOOD! Significant improvement, close to paper")
# elif improvement > 0:
#     print("\n‚ö†Ô∏è Better, but still below paper. May need to try splade_v2_distil")
# else:
#     print("\n‚ö†Ô∏è No improvement. Model version likely not the issue.")

# print("="*80)
# improvement = results_splade_alt['ndcg@10'] - results_splade['ndcg@10']
# improvement_pct = (improvement / results_splade['ndcg@10'] * 100) if results_splade['ndcg@10'] > 0 else 0

# print("\n" + "="*80)
# print(f"üìä SELFDISTIL RESULTS:")
# print(f"   nDCG@10: {results_splade_alt['ndcg@10']:.4f}")
# print(f"   Change from ensembledistil: {improvement:+.4f} ({improvement_pct:+.1f}%)")
# print(f"   Paper expected: ~0.70")
# print(f"   Match quality: {(results_splade_alt['ndcg@10']/0.70*100):.1f}%")

# if results_splade_alt['ndcg@10'] >= 0.68:
#     print("\n‚úÖ SUCCESS! This model matches paper expectations!")
# elif results_splade_alt['ndcg@10'] >= 0.60:
#     print("\n‚úÖ GOOD! Significant improvement, close to paper")
# elif improvement > 0:
#     print("\n‚ö†Ô∏è Better, but still below paper. May need to try splade_v2_distil")
# else:
#     print("\n‚ö†Ô∏è No improvement. Model version likely not the issue.")
# print("="*80)

## 18. Root Cause Analysis - Revision

**‚ùå FINDING**: Removing quantization made things WORSE (-20% performance)

This means Pyserini's quantization is **correct** for Lucene's impact scoring! The real issues are likely:

1. **Model Version Mismatch** (MOST LIKELY)
   - Current model: June 30, 2025 update
   - Paper: September 2024
   - Model weights may have changed

2. **Query Encoding Difference**
   - Documents use one method
   - Queries might use another
   
3. **Scoring Function Parameters**
   - Lucene impact scoring has parameters
   - May differ from paper's implementation

Let's test alternative SPLADE models from the correct timeframe.

## 17. Next Steps and Alternative Approaches

If the corrected encoding doesn't fully resolve the gap, consider:

### Alternative SPLADE Models
- `naver/splade-cocondenser-selfdistil` - Different training method
- `naver/splade_v2_distil` - Older version (may match paper timeframe better)
- Contact paper authors for the exact model checkpoint used

### Further Diagnostics
- Compare term overlap between your encodings and expected results
- Check if query encoding also needs correction
- Verify Lucene scoring function parameters

In [None]:
# # Re-encode corpus with correct SPLADE implementation (no quantization)
# import torch
# from transformers import AutoModelForMaskedLM, AutoTokenizer
# from tqdm.auto import tqdm
# import json
# import numpy as np

# print("="*80)
# print("RE-ENCODING CORPUS WITH CORRECT SPLADE (NO QUANTIZATION)")
# print("="*80)

# # OPTION: Set to True to test with subset first (faster)
# TEST_SUBSET = False  # Change to True to encode only 1000 docs for testing
# subset_size = 1000 if TEST_SUBSET else len(doc_texts)

# if TEST_SUBSET:
#     print(f"‚ö†Ô∏è  TEST MODE: Encoding only {subset_size} documents for quick validation")
#     print("   Set TEST_SUBSET=False to encode full corpus")
# else:
#     print(f"üìä FULL MODE: Encoding all {len(doc_texts):,} documents")

# # Create new output directory
# splade_correct_encoded_dir = os.path.join(base_dir, 'splade_encoded_correct')
# splade_correct_index_dir = os.path.join(base_dir, 'splade_index_correct')
# os.makedirs(splade_correct_encoded_dir, exist_ok=True)

# print("\nLoading SPLADE model...")
# model_name = 'naver/splade-cocondenser-ensembledistil'
# tokenizer = AutoTokenizer.from_pretrained(model_name)
# model = AutoModelForMaskedLM.from_pretrained(model_name)
# model.eval()

# # Use GPU if available with optimizations
# device = 'cuda' if torch.cuda.is_available() else 'cpu'

# # Clear GPU cache before starting
# if device == 'cuda':
#     torch.cuda.empty_cache()
#     import gc
#     gc.collect()

# model.to(device)

# # Enable half precision for 2x speedup on GPU (but use smaller batches to avoid OOM)
# if device == 'cuda':
#     model = model.half()  # Use FP16 for faster inference
#     print(f"‚úÖ Model loaded on GPU with FP16 acceleration")
# else:
#     print(f"‚úÖ Model loaded on: {device}")

# def encode_splade_batch(texts, model, tokenizer, device='cpu'):
#     """Optimized batch SPLADE encoding"""
#     # Tokenize batch
#     tokens = tokenizer(
#         texts,
#         return_tensors='pt',
#         padding=True,
#         truncation=True,
#         max_length=512
#     ).to(device)
    
#     with torch.no_grad(), torch.amp.autocast('cuda', enabled=(device=='cuda')):
#         # Forward pass
#         output = model(**tokens)
#         logits = output.logits  # [batch_size, seq_len, vocab_size]
        
#         # SPLADE formula: max(log(1 + relu(logits)))
#         relu_log = torch.log(1 + torch.relu(logits))
        
#         # Apply attention mask and max pool
#         masked_values = relu_log * tokens['attention_mask'].unsqueeze(-1)
#         max_values, _ = torch.max(masked_values, dim=1)  # [batch_size, vocab_size]
    
#     # Convert to sparse vectors (vectorized, much faster)
#     max_values_np = max_values.cpu().float().numpy()
#     inv_vocab = {v: k for k, v in tokenizer.get_vocab().items()}
#     batch_vectors = []
    
#     for i in range(max_values_np.shape[0]):
#         # Only process non-zero values (sparse)
#         nonzero_indices = np.where(max_values_np[i] > 0)[0]
#         sparse_vec = {
#             inv_vocab.get(int(idx), f"<unk_{idx}>"): float(max_values_np[i, idx])
#             for idx in nonzero_indices
#         }
#         batch_vectors.append(sparse_vec)
    
#     return batch_vectors

# # Encode all documents in batches with GPU-friendly batch size
# print(f"\nEncoding {len(doc_texts):,} documents...")
# batch_size = 32 if device == 'cuda' else 8  # Reduced from 128 to avoid OOM
# print(f"Batch size: {batch_size} (estimated time: ~{len(doc_texts)//batch_size//20} minutes)")

# output_file = os.path.join(splade_correct_encoded_dir, 'embeddings.jsonl')
# with open(output_file, 'w', encoding='utf-8') as f:
#     for i in tqdm(range(0, min(subset_size, len(doc_texts)), batch_size), desc="Encoding batches"):
#         batch_texts = doc_texts[i:i+batch_size]
#         batch_ids = doc_ids[i:i+batch_size]
        
#         batch_vectors = encode_splade_batch(batch_texts, model, tokenizer, device)
        
#         # Write to file
#         for doc_id, vector in zip(batch_ids, batch_vectors):
#             f.write(json.dumps({
#                 'id': doc_id,
#                 'vector': vector,
#                 'contents': ''  # Not needed for indexing
#             }) + '\n')

# print(f"‚úÖ Encoding complete: {output_file}")
# print(f"   Encoded {min(subset_size, len(doc_texts)):,} documents")

# # Build Lucene index with correct encodings
# print("\nBuilding SPLADE index with correct encodings...")
# splade_correct_start = time.time()

# subprocess.run([
#     'python', '-m', 'pyserini.index.lucene',
#     '--collection', 'JsonVectorCollection',
#     '--input', splade_correct_encoded_dir,
#     '--index', splade_correct_index_dir,
#     '--generator', 'DefaultLuceneDocumentGenerator',
#     '--impact',
#     '--threads', threads,
#     '--storeRaw'
# ], check=True)

# splade_correct_elapsed = time.time() - splade_correct_start
# index_times['SPLADE++ ED (Correct)'] = splade_correct_elapsed

# print(f"‚úÖ Correct SPLADE index built ({splade_correct_elapsed:.2f}s)")
# print("="*80)

## 16. ROOT CAUSE IDENTIFIED + SOLUTION

**üî¥ CRITICAL FINDING**: Pyserini applies a **quantization multiplier (~100√ó)** to SPLADE weights that is NOT in the original paper!

**The Problem:**
- Expected SPLADE weights: 0.0-3.0 range (float, log-scale)
- Pyserini weights: 0-126 range (integer quantization)
- Impact: Scoring function completely distorted, causing ~20% performance loss

**The Solution:**
We need to use **correct SPLADE encoding** without Pyserini's quantization. Two approaches:

### Option 1: Manual Encoding (Recommended)
Re-encode the corpus using our manual function that follows the official SPLADE paper

### Option 2: Fix Pyserini Quantization
Try to disable Pyserini's impact quantization (if possible)

Let's implement Option 1 below.

## Section 19: Test SPLADE v2 Distil Model

**Final attempt**: Testing `naver/splade_v2_distil` - an older distillation version that may be more stable.

In [None]:
# import gc
# import torch
# import subprocess
# from pyserini.encode import SpladeQueryEncoder
# from pyserini.search.lucene import LuceneImpactSearcher
# import time
# import numpy as np
# from beir.retrieval.evaluation import EvaluateRetrieval
# import os
# import json
# from tqdm import tqdm

# # Clear GPU memory
# print("üßπ Clearing GPU memory...")
# if 'model' in dir():
#     del model
# torch.cuda.empty_cache()
# gc.collect()

# # Show GPU memory
# if torch.cuda.is_available():
#     print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB total")
#     print(f"GPU Memory Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
#     print(f"GPU Memory Cached: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

# print("\n" + "="*80)
# print("TESTING SPLADE v2 DISTIL MODEL")
# print("="*80)

# # Model configuration
# alt_model_name = "naver/splade_v2_distil"
# alt_docs_dir = os.path.join(base_dir, "splade_v2_docs")
# alt_encoded_dir = os.path.join(base_dir, "splade_v2_encoded")
# alt_index_path = os.path.join(base_dir, "splade_v2_index")
# os.makedirs(alt_docs_dir, exist_ok=True)
# os.makedirs(alt_encoded_dir, exist_ok=True)

# print(f"\nüì¶ Model: {alt_model_name}")
# print(f"üìÅ Documents: {alt_docs_dir}")
# print(f"üìÅ Encoded: {alt_encoded_dir}")
# print(f"üìÅ Index: {alt_index_path}")

# # Step 1: Write documents for v2 model encoding
# print(f"\n‚öôÔ∏è Writing documents for v2 distil encoding...")
# alt_jsonl = os.path.join(alt_docs_dir, 'docs.jsonl')
# with open(alt_jsonl, 'w', encoding='utf-8') as f:
#     for did, text in zip(doc_ids, doc_texts):
#         f.write(json.dumps({'id': did, 'text': text}) + "\n")

# # Step 2: Encode corpus with v2 distil model
# if not os.path.exists(alt_encoded_dir) or len(os.listdir(alt_encoded_dir)) == 0:
#     print(f"\n‚öôÔ∏è Encoding corpus with {alt_model_name}...")
#     print(f"   Batch size: 16 (reduced for GPU memory)")
    
#     start_time = time.time()
    
#     cmd = [
#         "python", "-m", "pyserini.encode",
#         "input", "--corpus", alt_docs_dir,
#         "--fields", "text",
#         "output", "--embeddings", alt_encoded_dir,
#         "encoder", "--encoder", alt_model_name,
#         "--device", "cuda:0",
#         "--batch", "16"  # Reduced batch size
#     ]
    
#     result = subprocess.run(cmd, capture_output=True, text=True)
    
#     encoding_time = time.time() - start_time
    
#     if result.returncode != 0:
#         print(f"‚ùå Encoding failed!")
#         print(f"Error: {result.stderr}")
#     else:
#         print(f"‚úÖ Encoding complete in {encoding_time:.1f}s")
        
#         # Note: For SPLADE, we don't need to copy corpus as encoded dir has everything
#         os.makedirs(alt_docs_dir, exist_ok=True)
# else:
#     print(f"\n‚úì Using existing encoded corpus at {alt_encoded_dir}")

# # Step 3: Create Pyserini index
# if not os.path.exists(alt_index_path):
#     print(f"\n‚öôÔ∏è Creating Pyserini index...")
    
#     start_time = time.time()
    
#     cmd = [
#         "python", "-m", "pyserini.index.lucene",
#         "--collection", "JsonVectorCollection",
#         "--input", alt_encoded_dir,
#         "--index", alt_index_path,
#         "--generator", "DefaultLuceneDocumentGenerator",
#         "--threads", "1",
#         "--impact",
#         "--pretokenized"
#     ]
    
#     result = subprocess.run(cmd, capture_output=True, text=True)
    
#     indexing_time = time.time() - start_time
    
#     if result.returncode != 0:
#         print(f"‚ùå Indexing failed!")
#         print(f"Error: {result.stderr}")
#     else:
#         print(f"‚úÖ Indexing complete in {indexing_time:.1f}s")
# else:
#     print(f"\n‚úì Using existing index at {alt_index_path}")

# # Step 4: Initialize searcher and run evaluation
# print(f"\n‚öôÔ∏è Initializing v2 distil searcher...")

# searcher = LuceneImpactSearcher(
#     alt_index_path,
#     query_encoder=SpladeQueryEncoder(alt_model_name, device='cuda:0'),
#     min_idf=0
# )

# print(f"\n‚öôÔ∏è Running search with v2 distil model...")

# start_time = time.time()
# splade_v2_results = {}

# for query_id in tqdm(queries.keys(), desc="SPLADE++ v2 search"):
#     query_text = queries[query_id]
#     hits = searcher.search(query_text, k=1000)
    
#     splade_v2_results[query_id] = {
#         hit.docid: float(hit.score) for hit in hits
#     }

# search_time = time.time() - start_time
# qps_v2 = len(queries) / search_time

# # Evaluate
# ndcg_v2, _map_v2, recall_v2, precision_v2 = EvaluateRetrieval.evaluate(
#     qrels, 
#     splade_v2_results, 
#     [1, 3, 5, 10, 100, 1000]
# )

# ndcg10_v2 = ndcg_v2['NDCG@10']
# recall10_v2 = recall_v2['Recall@10']

# print("\n" + "="*80)
# print("SPLADE v2 DISTIL MODEL RESULTS")
# print("="*80)
# print(f"{'Metric':<20} {'Value'}")
# print("-"*80)
# print(f"{'nDCG@10':<20} {ndcg10_v2:.4f}")
# print(f"{'Recall@10':<20} {recall10_v2:.4f}")
# print(f"{'QPS':<20} {qps_v2:.2f}")

# # Compare to paper target
# paper_target = 0.70
# match_quality = (ndcg10_v2 / paper_target) * 100
# gap = paper_target - ndcg10_v2

# print("\n" + "="*80)
# print(f"üìä COMPARISON TO PAPER:")
# print(f"   v2_distil nDCG@10: {ndcg10_v2:.4f}")
# print(f"   Paper expected: ~{paper_target}")
# print(f"   Match quality: {match_quality:.1f}%")
# print(f"   Gap: {gap:.4f} ({gap/paper_target*100:.1f}% below)")
# print()

# if ndcg10_v2 >= 0.68:
#     print("‚úÖ SUCCESS! Close to paper results!")
# elif ndcg10_v2 >= 0.65:
#     print("‚ö†Ô∏è Close, but still slightly below paper")
# else:
#     print("‚ö†Ô∏è Still significantly below paper (~21% gap)")
#     print("   Likely cause: Model updates after paper publication")
#     print("   BGE and BM25 match perfectly, validating your implementation")
#     print()
#     print("üìù RECOMMENDATION: Accept current results and document this limitation.")
#     print("   All 3 tested SPLADE models (ensembledistil, selfdistil, v2_distil)")
#     print("   show similar underperformance, suggesting systematic model changes")
#     print("   since paper publication (Sept 2024).")
# print("="*80)