# ⚡ Notebook 4: SPLADE Custom Matrix Engine (Experimental)
**Author:** Gabriele Righi

**Project:** Dense vs Sparse Retrieval Reproducibility

## 🎯 Objective
This notebook implements a **custom "Pure Python" search engine** for the SPLADE model.
It serves as the engineering solution to the **"Sparse Indexing Challenge"**: standard tools (like Pyserini) introduced quantization artifacts that degraded SPLADE's performance. Here, we implement exact (unquantized) sparse retrieval to measure the model's retrieval quality without that loss.
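To make the quantization issue concrete, here is a toy sketch in which rounding weights to integers reverses the order of two documents. It assumes a hypothetical scheme that multiplies every weight by 100 and rounds to the nearest integer. This is an illustrative choice, not a statement of Pyserini's exact settings, and the two-term vectors are invented.

```python
import numpy as np

# Invented two-term SPLADE-style weights (float64 for clarity).
query = np.array([1.004, 1.0])
docs = {
    "doc_a": np.array([0.5049, 0.0]),
    "doc_b": np.array([0.0, 0.5051]),
}

def quantize(weights, scale=100):
    # Hypothetical impact quantization: scale, round to nearest integer, cast to int.
    return np.rint(weights * scale).astype(np.int64)

# Exact float dot products vs. dot products of the quantized integer weights.
exact = {name: float(query @ d) for name, d in docs.items()}
quant = {name: int(quantize(query) @ quantize(d)) for name, d in docs.items()}

print(exact)  # doc_a (about 0.50692) scores higher than doc_b (0.5051)
print(quant)  # {'doc_a': 5000, 'doc_b': 5100}: the order flips
```

The float scores differ by about 0.002. That gap is smaller than the error introduced by rounding each weight to the nearest 0.01, so the integer scores can rank the documents the other way.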

## 🛠️ Methodology (The "Matrix Engine")
Instead of using a traditional Inverted Index (Lucene), this notebook treats retrieval as a **Linear Algebra** problem:
1.  **Model:** Uses `naver/splade-cocondenser-ensembledistil` to generate sparse weights.
2.  **Structure:** Constructs a **SciPy CSR Matrix** ($D$) to store document vectors efficiently.
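3.  **Search:** Performs retrieval via **Sparse Matrix Multiplication** ($S = Q \times D^T$, where $S_{ij}$ is the score of query $i$ against document $j$).
    * *Advantage:* Stores and scores term weights in full `float32` precision, avoiding the quality loss of integer quantization. (The encoder forward pass itself runs under fp16 autocast on GPU.)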
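The following is a minimal NumPy/SciPy sketch of the steps above on toy data. It runs SPLADE-style max pooling over random stand-in logits, builds a CSR document matrix, and computes the product $S = Q \times D^T$. `splade_pool` is a hypothetical helper, not part of any library. The sketch relies on the pooled weights being non-negative (which `log1p(ReLU(·))` guarantees), so zeroing padding positions before taking the max is safe.

```python
import numpy as np
from scipy.sparse import csr_matrix

def splade_pool(logits, attention_mask):
    """Hypothetical helper: SPLADE max pooling over non-padding tokens.

    logits: (batch, seq_len, vocab) MLM logits; attention_mask: (batch, seq_len) of 0/1.
    Returns (batch, vocab) weights w_j = max_i log(1 + ReLU(logit_ij)).
    """
    weights = np.log1p(np.maximum(logits, 0.0))     # non-negative by construction
    weights = weights * attention_mask[:, :, None]  # padding positions become 0
    return weights.max(axis=1)

# Random stand-in logits: 2 documents, 3 token positions, a 5-term vocabulary.
rng = np.random.default_rng(0)
doc_logits = rng.normal(size=(2, 3, 5))
doc_mask = np.array([[1, 1, 1], [1, 1, 0]])  # document 2 has one padding token

# Store document vectors as a CSR matrix D (docs x vocab); zeros are not stored.
D = csr_matrix(splade_pool(doc_logits, doc_mask).astype(np.float32))

# A hand-written sparse query over the same vocabulary (1 x vocab).
Q = csr_matrix(np.array([[0.0, 1.2, 0.0, 0.7, 0.0]], dtype=np.float32))

# S = Q x D^T: one row per query, one column per document.
S = (Q @ D.T).toarray()
ranking = np.argsort(-S, axis=1)  # document indices, best first
print(S, ranking)
```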

## 📉 Scope & Constraints
Due to the computational cost of encoding **2.6 million documents** with a BERT-based model (5+ hours on a single T4 GPU), this notebook performs a **Fast Estimation** on a subset:
* **Subset Size:** First **100,000 documents** of Natural Questions (NQ).
* **Goal:** To obtain an **nDCG@10** estimate for the comparison table without the multi-hour encoding cost. Because the corpus is truncated, this estimate is not directly comparable to full-corpus nDCG@10.
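When evaluating on a subset, the ground truth should be restricted to the encoded documents. Otherwise, a query whose relevant passages all lie outside the first 100,000 documents can only score 0. Below is a sketch of that filter as a hypothetical helper. It assumes qrels are stored as `{query_id: {doc_id: relevance}}`, and the IDs are toy values.

```python
def restrict_qrels(qrels, subset_doc_ids):
    """Hypothetical helper: keep only positive judgments for documents in the subset.

    qrels: {query_id: {doc_id: relevance}}; subset_doc_ids: iterable of doc ids.
    Queries left with no relevant document are dropped entirely.
    """
    subset = set(subset_doc_ids)
    filtered = {}
    for qid, judged in qrels.items():
        kept = {did: rel for did, rel in judged.items() if did in subset and rel > 0}
        if kept:
            filtered[qid] = kept
    return filtered

toy_qrels = {"q1": {"doc0": 1, "doc999999": 1}, "q2": {"doc2000000": 1}}
print(restrict_qrels(toy_qrels, ["doc0", "doc1", "doc2"]))  # {'q1': {'doc0': 1}}
```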

## 📂 Inputs & Outputs
* **Input:** Raw text from NQ Dataset (`Cohere/beir-embed-english-v3`).
* **Output:** Evaluation metrics (nDCG, Recall) for SPLADE on the specified subset.
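Below is a self-contained sketch of the two metrics named above. It uses the same linear gain as this notebook's evaluation loop: each relevant document at 0-based rank $r$ contributes $rel / \log_2(r + 2)$. Some toolkits use the gain $2^{rel} - 1$ instead, which gives identical results for NQ's binary judgments. The function names and IDs are illustrative.

```python
import math

def ndcg_at_k(ranked_ids, relevant, k=10):
    """nDCG@k with linear gain rel / log2(rank + 2), rank starting at 0."""
    dcg = sum(relevant.get(d, 0) / math.log2(r + 2) for r, d in enumerate(ranked_ids[:k]))
    ideal = sorted(relevant.values(), reverse=True)[:k]
    idcg = sum(rel / math.log2(r + 2) for r, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0

def recall_at_k(ranked_ids, relevant, k=10):
    """Fraction of relevant (rel > 0) documents that appear in the top k."""
    positives = {d for d, rel in relevant.items() if rel > 0}
    if not positives:
        return 0.0
    return len(positives.intersection(ranked_ids[:k])) / len(positives)

ranked = ["d3", "d1", "d7"]
relevant = {"d1": 1, "d9": 1}
print(round(ndcg_at_k(ranked, relevant), 4))  # 0.3869
print(recall_at_k(ranked, relevant))          # 0.5
```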

---

import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer
from datasets import load_dataset
import numpy as np
from scipy.sparse import csr_matrix
from tqdm import tqdm
import gc
import collections
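# ==========================================
# 0. CONFIGURATION
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
MAX_DOCS = 100000
MODEL_ID = "naver/splade-cocondenser-ensembledistil"

# ==========================================
# 1. MODEL LOADING
# ==========================================
print(f"[1/4] Loading {MODEL_ID} on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForMaskedLM.from_pretrained(MODEL_ID).to(DEVICE)
model.eval()  # disable dropout for deterministic encoding
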
def encode_splade(texts, batch_size=32):
    """
    Encodes texts into a SPLADE sparse matrix of shape (len(texts), vocab_size).
    Memory optimizations: mixed precision (AMP) on GPU and mini-batching.
    """
    all_indices = []
    all_values = []
    all_indptr = [0]
    use_amp = DEVICE == "cuda"

    for i in tqdm(range(0, len(texts), batch_size), desc="Encoding Batch"):
        batch_texts = list(texts[i:i+batch_size])

        # Tokenization
        inputs = tokenizer(
            batch_texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=256
        ).to(DEVICE)

        with torch.no_grad():
            # Autocast only on CUDA; on CPU the model runs in float32
            with torch.amp.autocast(device_type="cuda", enabled=use_amp):
                logits = model(**inputs).logits

            # SPLADE pooling: w_j = max_i log(1 + ReLU(logit_ij)), computed in float32.
            # The attention mask zeroes padding positions so they cannot win the max.
            values = torch.log1p(torch.relu(logits.float()))
            values *= inputs["attention_mask"].unsqueeze(-1)
            batch_scores, _ = torch.max(values, dim=1)

        # Move to CPU immediately to free GPU VRAM
        batch_scores = batch_scores.cpu().numpy()

        # Drop references to large GPU tensors before the next batch
        del inputs, logits, values

        for score_vec in batch_scores:
            non_zero_indices = np.nonzero(score_vec)[0]
            all_indices.extend(non_zero_indices)
            all_values.extend(score_vec[non_zero_indices])
            all_indptr.append(len(all_indices))

    # Output dimension of the MLM head (= number of columns of the logits)
    vocab_size = model.config.vocab_size
    return csr_matrix(
        (all_values, all_indices, all_indptr),
        shape=(len(texts), vocab_size),
        dtype=np.float32
    )

# ==========================================
# 2. DATASET LOADING & PREPROCESSING
# ==========================================
print(f"\n[2/4] Loading NQ Dataset (First {MAX_DOCS} docs)...")
corpus_ds = load_dataset("Cohere/beir-embed-english-v3", "nq-corpus", split="train")
queries_ds = load_dataset("Cohere/beir-embed-english-v3", "nq-queries", split="test")
qrels_ds = load_dataset("Cohere/beir-embed-english-v3", "nq-qrels", split="test")

# Optimization: Slice the corpus to speed up the process
corpus_ds = corpus_ds.select(range(MAX_DOCS))

# Combining Title + Text for better recall (Standard SPLADE practice)
doc_texts = [f"{t} {txt}" for t, txt in zip(corpus_ds['title'], corpus_ds['text'])]
doc_ids = corpus_ds['_id']

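# --- QREL LOADING (column names differ across dataset versions) ---
print("Mapping QRELs (Ground Truth)...")
qrels = collections.defaultdict(dict)

# Automatically detect column names (handles 'query-id' vs 'query_id')
cols = qrels_ds.column_names
q_key = 'query_id' if 'query_id' in cols else 'query-id'
c_key = 'corpus_id' if 'corpus_id' in cols else 'corpus-id'
s_key = 'score'

print(f"Detected columns -> Query: '{q_key}', Corpus: '{c_key}'")

# Only documents inside the encoded subset can ever be retrieved, so keep only
# their judgments; otherwise queries whose relevant passages fall outside the
# first MAX_DOCS documents score 0 and IDCG counts unreachable documents.
doc_id_set = set(str(d) for d in doc_ids)

for row in tqdm(qrels_ds, desc="Processing QRELs"):
    qid = str(row[q_key])
    did = str(row[c_key])
    score = int(row[s_key])
    if did in doc_id_set and score > 0:
        qrels[qid][did] = score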

# ==========================================
# 3. ENCODING (The Heavy Lifting)
# ==========================================
print("\n[3/4] Encoding Corpus (Matrix D)...")
doc_matrix = encode_splade(doc_texts, batch_size=BATCH_SIZE)

print("Encoding Queries (Matrix Q)...")
query_matrix = encode_splade(queries_ds['text'], batch_size=BATCH_SIZE)

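# Transpose once and store D^T as CSR (Vocab x Docs): each row is one term's
# posting list, so q_vec.dot(doc_matrix_t) only touches the rows of the query's terms.
# SciPy's sparse product converts the right operand to the left operand's format
# (CSR here), so a CSC operand would be re-converted on every query.
doc_matrix_t = doc_matrix.T.tocsr()

# Free raw text and the model to save RAM/VRAM
del doc_texts, model
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()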

# ==========================================
# 4. SEARCH & EVALUATION
# ==========================================
print("\n[4/4] Running Search (Matrix Multiplication S = Q * D^T)...")

k = 10
ndcg_list = []
recall_list = []
doc_ids_lookup = np.array([str(d) for d in doc_ids])  # NumPy array for fast fancy indexing
query_ids = [str(q) for q in queries_ds['_id']]        # read the ID column once, not per row

for i in tqdm(range(query_matrix.shape[0]), desc="Evaluating"):
    qid = query_ids[i]

    # Skip queries with no relevant documents in the ground truth
    if qid not in qrels: continue

    # Current query vector (sparse, 1 x Vocab)
    q_vec = query_matrix[i]

    # Dot product: (1, Vocab) @ (Vocab, Docs) -> (1, Docs)
    # Scores this query against all MAX_DOCS documents
    scores = q_vec.dot(doc_matrix_t).toarray().ravel()

    # Fast Top-K Retrieval
    if len(scores) > k:
        # argpartition is O(n), much faster than an O(n log n) full sort
        top_k_idx = np.argpartition(scores, -k)[-k:]
        # Sort only the top k results, highest score first
        top_k_idx = top_k_idx[np.argsort(scores[top_k_idx])][::-1]
    else:
        top_k_idx = np.argsort(scores)[::-1]

    retrieved_ids = doc_ids_lookup[top_k_idx]

    # --- Metric Calculation (nDCG@10, Recall@10) ---
    relevant_docs = qrels[qid]

    dcg = 0.0
    idcg = 0.0
    hits = 0

    # DCG (Discounted Cumulative Gain) over the ranked top-k list
    for rank, doc_id in enumerate(retrieved_ids):
        if doc_id in relevant_docs:
            rel_score = relevant_docs[doc_id]
            dcg += rel_score / np.log2(rank + 2)
            hits += 1

    # IDCG (Ideal DCG): best possible ordering of the judged documents
    ideal_rels = sorted(relevant_docs.values(), reverse=True)[:k]
    for rank, rel_score in enumerate(ideal_rels):
        idcg += rel_score / np.log2(rank + 2)

    ndcg_list.append(dcg / idcg if idcg > 0 else 0.0)
    recall_list.append(hits / len(relevant_docs))

# ==========================================
# FINAL RESULTS
# ==========================================
mean_ndcg = np.mean(ndcg_list)
mean_recall = np.mean(recall_list)

print("\n" + "="*40)
print(f"🏆 SPLADE RESULT (NQ Subset {MAX_DOCS})")
print(f"Queries evaluated: {len(ndcg_list)}")
print(f"nDCG@{k}:    {mean_ndcg:.4f}")
print(f"Recall@{k}:  {mean_recall:.4f}")
print("="*40)