<a href="https://colab.research.google.com/github/ahmedsaalman/low-resource-rag-comparison/blob/main/Retriever_Model_NLP_Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Cells 1 - 5b: Sparse Retriever Model (BM25)**

**Cells 6 - 8c: Dense Retriever Model and its fine-tuning (FAISS)**

**Cells 9 - 9b: Hybrid/Finalized Retriever Model (both BM25 and Dense fused together)**

In [None]:
# Cell 1: Install required libraries (run this cell first; it installs every package listed below in one step)
# - transformers: model + generation
# - sentence-transformers: dense embeddings / fine-tuning helpers
# - faiss-cpu (or faiss-gpu if GPU available)
# - rank_bm25: BM25 baseline
# - datasets: convenient JSONL loading
# - evaluate / sacrebleu: BLEU/chrF metrics
# - tqdm: progress bars
# - accelerate (optional) for distributed/faster training
!pip install -q transformers sentence-transformers faiss-cpu rank_bm25 datasets evaluate sacrebleu tqdm accelerate


Collecting faiss-cpu
  Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (7.7 kB)
Collecting rank-bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading faiss_cpu-1.13.0-cp39-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (23.6 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m23.6/23.6 MB[0m [31m76.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25, faiss-cpu
Successfully installed faiss-cpu-1.13.0 rank-bm25-0.2.2


In [None]:
# Mount Google Drive (skip if it is already mounted in this session).
# The data files are read from Drive (DATA_DIR in Cell 3) and the fine-tuned
# dense retriever is saved to / loaded from Drive (Cells 8b and 8c).
from google.colab import drive
drive.mount('/content/drive')

Saving urdu_covid_passages.tsv to urdu_covid_passages.tsv
Saving urdu_covid_corpus.jsonl to urdu_covid_corpus.jsonl
Saving eval_queries.jsonl to eval_queries.jsonl
Saving urdu_covid_passages_min.jsonl to urdu_covid_passages_min.jsonl


In [None]:
# Optional cell: upload files from the local machine through the browser.
# The value returned by files.upload() is kept in `uploaded`.
# NOTE(review): uploads presumably land in the current working directory, whereas
# Cell 3 reads from DATA_DIR - confirm the files end up where Cell 3 expects them.
from google.colab import files
uploaded = files.upload()

In [None]:
!pip list # Optional to run this cell: To check which of the libraries/packages have been installed

In [None]:
# Cell 2: Imports and GPU check: Run this cell after the first cell
import os, json, time, math
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# Transformers / sentence-transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
import sentence_transformers # Import the package itself to access __version__

# FAISS and BM25
import faiss
from rank_bm25 import BM25Okapi

# Datasets and metrics
from datasets import load_dataset, Dataset
import evaluate
import sacrebleu

# Print versions and GPU info
print("transformers:", transformers.__version__)
print("sentence-transformers:", sentence_transformers.__version__)
try:
    import torch
    print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
except Exception as e:
    print("torch not available:", e)


python3: can't open file '/content/train_retrievers.py': [Errno 2] No such file or directory


In [None]:
# Cell 3: Load JSONL/TSV files into Python structures
# DATA_DIR is a relative path, so it is resolved against the current working
# directory. NOTE(review): this presumably relies on Drive being mounted at
# /content/drive and the working directory being /content - confirm, or switch
# to an absolute path. Put all data files in that folder before running this cell.

DATA_DIR = Path("drive/MyDrive/data")  # change if files are elsewhere

# Create the data directory if it doesn't exist (no error if it already does)
import os
os.makedirs(DATA_DIR, exist_ok=True)

def load_jsonl(path):
    """Read a JSON Lines file and return its records as a list.

    Every line of the UTF-8 file at `path` that contains non-whitespace
    characters is parsed with json.loads; blank lines are skipped.
    """
    with open(path, "r", encoding="utf-8") as handle:
        return [json.loads(raw) for raw in handle if raw.strip()]

# JSONL inputs: the cleaned corpus records and the minimal passage list.
corpus_clean = load_jsonl(DATA_DIR / "urdu_covid_corpus_clean.jsonl")
passages_min = load_jsonl(DATA_DIR / "urdu_covid_passages_min.jsonl")

# TSV -> list of {"id", "text"} dicts.
# Each non-blank line is split on its first run of whitespace (tab or spaces),
# so files that use spaces instead of a tab as delimiter are still accepted.
passages_tsv = []
with open(DATA_DIR / "urdu_covid_passages.tsv", "r", encoding="utf-8") as tsv_file:
    for raw_line in tsv_file:
        if not raw_line.strip():
            continue  # ignore blank lines
        fields = raw_line.rstrip("\n").split(None, 1)
        if len(fields) != 2:
            print(f"Skipping malformed line in urdu_covid_passages.tsv: {raw_line.strip()}")
            continue
        passages_tsv.append({"id": fields[0], "text": fields[1]})

# Evaluation queries and training material for the dense retriever.
eval_queries, synthetic_pairs, hard_negatives = (
    load_jsonl(DATA_DIR / file_name)
    for file_name in ("eval_queries.jsonl", "synthetic_qa_pairs.jsonl", "hard_negatives.jsonl")
)

print("Loaded:", len(corpus_clean), "corpus_clean; ", len(passages_min), "passages_min; ", len(eval_queries), "eval queries")


In [None]:
# Cell 4: Validate IDs referenced in eval/synthetic/hard_negatives exist in corpus
# Run this after Cell 3.
# The check is tolerant of malformed records: a missing key is reported as a
# missing reference instead of aborting the cell with a KeyError, and a label
# field holding a single id (not a list) is treated as a one-element list.

def _as_id_list(value):
    """Return `value` as a list of ids: None -> [], list/tuple/set -> list, scalar -> [scalar]."""
    if value is None:
        return []
    if isinstance(value, (list, tuple, set)):
        return list(value)
    return [value]

passage_ids = {p["id"] for p in passages_min}
missing = []  # (origin, query/synthetic id, unknown passage id)

for q in eval_queries:
    for pid in _as_id_list(q.get("positive_ids")):
        if pid not in passage_ids:
            missing.append(("eval", q.get("query_id"), pid))

for s in synthetic_pairs:
    # Same id fallback as the training-pair preparation in Cell 7.
    ref_id = s.get("synthetic_id", s.get("query_id"))
    positive_id = s.get("positive_id")
    if positive_id not in passage_ids:
        missing.append(("synthetic", ref_id, positive_id))

for h in hard_negatives:
    for pid in _as_id_list(h.get("hard_negatives")):
        if pid not in passage_ids:
            missing.append(("hardneg", h.get("query_id"), pid))

print("Missing references (should be zero):", len(missing))
if missing:
    print(missing[:10])


In [None]:
# Cell 5 (Run after Cell 4): BM25 baseline index.
# Tokenization uses NLTK's word_tokenize when NLTK and its 'punkt' resources are
# usable, and plain whitespace splitting otherwise (acceptable for an Urdu baseline).
# We'll store tokenized corpus and BM25 object for retrieval.
try:
    # All NLTK access lives inside the try block so that a missing package or
    # missing tokenizer data really does trigger the whitespace fallback.
    import nltk
    from nltk.tokenize import word_tokenize
    nltk.download('punkt', quiet=True)
    nltk.download('punkt_tab', quiet=True)  # needed by newer NLTK releases
    # nltk.download() reports failure through its return value rather than by
    # raising, so probe the tokenizer once: it raises LookupError when the
    # resources are still unavailable.
    word_tokenize("resource check")
    tokenizer = lambda s: word_tokenize(s)
except Exception as e:
    print(f"NLTK tokenizer unavailable ({type(e).__name__}); falling back to whitespace split")
    tokenizer = lambda s: s.split()

# Parallel lists: position i in corpus_ids / corpus_texts / tokenized_corpus
# all refer to the same passage, which is also row i of the BM25 index.
corpus_ids, corpus_texts = [], []
for passage in passages_min:
    corpus_texts.append(passage["text"])
    corpus_ids.append(passage["id"])
tokenized_corpus = list(map(tokenizer, corpus_texts))
bm25 = BM25Okapi(tokenized_corpus)

# Example retrieval function
def bm25_retrieve(query, k=5):
    """Return the k highest-scoring BM25 passages for `query`.

    The query is tokenized with the same `tokenizer` used for the corpus.
    Returns a list of (passage_id, passage_text, score) tuples, best first.
    """
    all_scores = bm25.get_scores(tokenizer(query))
    best_first = np.argsort(all_scores)[::-1]
    hits = []
    for pos in best_first[:k]:
        hits.append((corpus_ids[pos], corpus_texts[pos], float(all_scores[pos])))
    return hits

# Quick test: print the top-3 BM25 hits for the first evaluation query.
print("BM25 top-3 for sample:", bm25_retrieve(eval_queries[0]["query"], k=3))


In [None]:
# Cell 5b: BM25-only retriever evaluation tool (run after Cell 5)
# Purpose: standalone evaluation harness for the independent BM25 retriever (bm25_retrieve)
# Metrics (relevance is decided by passage id only):
#   - Recall@1, Recall@5  (fraction of queries with at least one positive in the top-k)
#   - MRR (Mean Reciprocal Rank of the first positive)
#   - Precision@k (k=1,5)
#   - Mean / median retrieval latency
# Output:
#   - Per-query JSONL saved to bm25_eval_results.jsonl
#   - Summary dict with all metrics
#
# Requirements (must be available in the session):
#   - bm25_retrieve(query, k) -> list of (passage_id, passage_text, score)
#   - eval_queries: list of dicts with
#       * "query" or "question" or "q"        (the query text)
#       * "positive_ids" (or "positive_id")   (gold passage id(s))
#       * "gold_answer" or "answer"           (optional; only copied into the per-query log)
#
# Usage:
#   - Run this cell after you build the BM25 index (Cell 5).
#   - Optionally pass a different eval list or k values to evaluate subsets.
import json, time, re, statistics
from typing import List, Dict

OUT_JSONL = "bm25_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s: str) -> str:
    """Return `s` as a stripped string with whitespace runs collapsed to one space ("" for None)."""
    if s is None:
        return ""
    return re.sub(r"\s+", " ", str(s).strip())

def get_query_text(item: Dict) -> str:
    """Return the query text stored under "query", "question" or "q" ("" if none is set)."""
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_bm25_with_positive_ids(eval_items: List[Dict],
                                    out_jsonl: str = OUT_JSONL,
                                    k: int = DEFAULT_K,
                                    recall_ks = RECALL_KS,
                                    precision_ks = PRECISION_KS):
    """Evaluate bm25_retrieve against id-labelled queries.

    Args:
        eval_items: iterable of query dicts (see the cell header for the fields).
        out_jsonl: path of the per-query JSONL log (overwritten on every call).
        k: number of passages requested from bm25_retrieve per query.
        recall_ks: cut-offs for Recall@k.
        precision_ks: cut-offs for Precision@k (always divided by k).

    Returns:
        (summary, per_query): `summary` maps metric names to values and
        `per_query` is the list of records that was written to `out_jsonl`.
        For empty input every metric is 0.0 and "n_queries" is 0.

    A retrieval error for one query is printed and counted as an empty result.
    """
    per_query = []
    latencies = []
    rr_list = []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}

    for item in eval_items:
        q = get_query_text(item)

        # Accept a list of ids or a single id of any scalar type.
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        if not isinstance(positive_ids, (list, tuple, set)):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]
        positives = set(positive_ids)

        gold_text = normalize_text(item.get("gold_answer") or item.get("answer") or "")

        t0 = time.perf_counter()
        try:
            hits = bm25_retrieve(q, k=k)   # (id, text, score)
        except Exception as e:
            hits = []
            print(f"[eval] bm25_retrieve error for query {q[:60]}... -> {e}")
        latency = time.perf_counter() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]
        # Positives were stringified above, so compare retrieved ids as strings
        # too; otherwise integer passage ids could never match.
        is_positive = [str(pid) in positives for pid in retrieved_ids]

        # Reciprocal rank of the first positive (0.0 when none was retrieved).
        rr = 0.0
        for rank, hit in enumerate(is_positive, start=1):
            if hit:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        for rk in recall_ks:
            recall_counts[rk] += 1 if any(is_positive[:rk]) else 0
        for pk in precision_ks:
            # precision@k = (# positives in top-k) / k
            precision_sums[pk] += sum(is_positive[:pk]) / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": positive_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [t[:300] for t in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    denom = total or 1  # avoid division by zero; the reported count stays truthful

    summary = {
        "n_queries": total,
        "MRR": sum(rr_list) / denom,
        **{f"Recall@{rk}": recall_counts[rk] / denom for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk] / denom for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run it
if 'eval_queries' not in globals():
    # Not in memory: load from a local file, skipping blank lines
    # (same tolerance as load_jsonl in Cell 3).
    eval_queries = []
    with open("eval_queries.jsonl","r",encoding="utf-8") as f:
        for line in f:
            if line.strip():
                eval_queries.append(json.loads(line))

summary, records = evaluate_bm25_with_positive_ids(eval_queries, out_jsonl=OUT_JSONL, k=DEFAULT_K)
print("BM25 retrieval evaluation summary:")
for metric_name, metric_value in summary.items():
    print(f"  {metric_name}: {metric_value}")

# Show a few queries for which no positive passage was retrieved at all.
misses = [r for r in records if r["reciprocal_rank"] == 0.0]
print(f"\nTotal misses: {len(misses)} / {len(records)}. Showing up to 5 misses:")
for r in misses[:5]:
    print("Query id:", r.get("query_id"), "Query:", r["query"][:80])
    print(" Positives:", r["positive_ids"])
    print(" Retrieved top ids:", r["retrieved_ids"][:8])
    print()


In [None]:
# Cell 6: Dense embeddings with a multilingual model (compact enough for Colab)
# The multilingual SBERT checkpoint below is used as the baseline encoder.
embed_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embedder = SentenceTransformer(embed_model_name)

# Encode every passage text (encode() batches internally).
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)

# Cosine similarity = inner product of L2-normalized vectors, so normalize the
# passage matrix in place and store it in a flat inner-product index.
faiss.normalize_L2(passage_embeddings)
d = passage_embeddings.shape[1]  # embedding dimensionality
index = faiss.IndexFlatIP(d)
index.add(passage_embeddings)

# Row i of the FAISS index corresponds to corpus_ids[i] / corpus_texts[i].
# retrieval function
def dense_retrieve(query, k=5):
    """Return up to k nearest passages for `query` from the FAISS index.

    The query is encoded with `embedder`, L2-normalized, and searched against
    `index` (inner product on normalized vectors, i.e. cosine similarity).

    Returns a list of (passage_id, passage_text, score) tuples, best first.
    The list is shorter than k when the index holds fewer than k passages and
    empty when the index is empty or k <= 0.
    """
    k = min(int(k), int(index.ntotal))
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for idx, score in zip(I[0], D[0]):
        idx = int(idx)
        if idx < 0:
            # FAISS pads missing neighbours with label -1; indexing with it
            # would silently return the LAST passage instead of nothing.
            continue
        results.append((corpus_ids[idx], corpus_texts[idx], float(score)))
    return results

# Quick test: print the top-3 dense hits for the first evaluation query.
print("Dense top-3:", dense_retrieve(eval_queries[0]["query"], k=3))


In [None]:
# Cell 6b: Evaluation of dense retriever (run after Cell 6)
# Purpose: measure Recall@1, Recall@5, MRR, Precision@k, latency for dense_retrieve
# Relevance is decided by passage id ("positive_ids" / "positive_id");
# "gold_answer" is only copied into the per-query log.

import json, time, re, statistics

OUT_JSONL_DENSE = "dense_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s):
    """Return `s` as a stripped string with whitespace runs collapsed to one space ("" for None)."""
    if s is None:
        return ""
    return re.sub(r"\s+", " ", str(s).strip())

def get_query_text(item):
    """Return the query text stored under "query", "question" or "q" ("" if none is set)."""
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_dense(eval_items, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K,
                   recall_ks=RECALL_KS, precision_ks=PRECISION_KS):
    """Evaluate dense_retrieve against id-labelled queries.

    Args:
        eval_items: iterable of query dicts with a query text and
            "positive_ids" (list) or "positive_id" (single id).
        out_jsonl: path of the per-query JSONL log (overwritten on every call).
        k: number of passages requested from dense_retrieve per query.
        recall_ks: cut-offs for Recall@k (query counts as a hit when at least
            one positive is in the top-k).
        precision_ks: cut-offs for Precision@k (always divided by k).

    Returns:
        (summary, per_query). For empty input every metric is 0.0 and
        "n_queries" is 0.

    As in the BM25 and hybrid evaluators, a retrieval error for one query is
    printed and counted as an empty result.
    """
    per_query = []
    latencies, rr_list = [], []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}

    for item in eval_items:
        q = get_query_text(item)

        # Accept a list of ids or a single id of any scalar type.
        pos_ids = item.get("positive_ids") or item.get("positive_id") or []
        if not isinstance(pos_ids, (list, tuple, set)):
            pos_ids = [pos_ids]
        pos_ids = [str(x) for x in pos_ids]
        positives = set(pos_ids)

        gold_text = normalize_text(item.get("gold_answer") or "")

        t0 = time.perf_counter()
        try:
            hits = dense_retrieve(q, k=k)  # (id, text, score)
        except Exception as e:
            hits = []
            print(f"[eval] dense_retrieve error for query {q[:60]}... -> {e}")
        latency = time.perf_counter() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]
        # Positives were stringified above; compare retrieved ids as strings too.
        is_positive = [str(pid) in positives for pid in retrieved_ids]

        # Reciprocal rank of the first positive (0.0 when none was retrieved).
        rr = 0.0
        for rank, hit in enumerate(is_positive, start=1):
            if hit:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        for rk in recall_ks:
            recall_counts[rk] += 1 if any(is_positive[:rk]) else 0
        for pk in precision_ks:
            precision_sums[pk] += sum(is_positive[:pk]) / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": pos_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [txt[:300] for txt in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    denom = total or 1  # avoid division by zero; the reported count stays truthful
    summary = {
        "n_queries": total,
        "MRR": sum(rr_list)/denom,
        **{f"Recall@{rk}": recall_counts[rk]/denom for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk]/denom for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run evaluation over all evaluation queries and print the metric summary.
print("[dense_eval] Running dense retriever evaluation...")
summary_dense, records_dense = evaluate_dense(eval_queries, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K)
print("\nDense retriever evaluation summary:")
for k, v in summary_dense.items():
    print(f"  {k}: {v}")

# Show the first few per-query records (ids, reciprocal rank, latency).
print("\nExamples (first 5):")
for r in records_dense[:5]:
    preview_ids = r["retrieved_ids"][:6]
    rounded_latency = round(r["latency"], 4)
    print(" - Query:", r["query"][:80])
    print("   Retrieved ids:", preview_ids)
    print("   Reciprocal rank:", r["reciprocal_rank"], "Latency(s):", rounded_latency)
    print()


In [None]:
# Cell 7: Prepare InputExamples (query, positive, negative triplets) for
# fine-tuning the dense retriever, then split them 80/20 into train/validation.

from sentence_transformers import InputExample
import random

# Dedicated seeded RNG: random negatives and the shuffle below are reproducible
# across runs and do not disturb the global `random` state.
SPLIT_SEED = 42
rng = random.Random(SPLIT_SEED)

pid2text = {p["id"]: p["text"] for p in passages_min}
all_pids = list(pid2text)

# Index hard negatives by query id once (first entry wins, as with a linear
# first-match scan) instead of scanning the whole list for every pair.
hard_negs_by_query = {}
for h in hard_negatives:
    hard_negs_by_query.setdefault(h["query_id"], h)

examples = []
skipped = 0
for s in synthetic_pairs:
    q = s["query"]
    positive_id = s["positive_id"]
    pos = pid2text.get(positive_id)
    neg = None

    # Prefer a mined hard negative: the first one that is not the positive
    # passage and whose text is actually present in the corpus.
    hn = hard_negs_by_query.get(s.get("synthetic_id", s.get("query_id")))
    if hn:
        for nid in hn["hard_negatives"]:
            if nid != positive_id and nid in pid2text:
                neg = pid2text[nid]
                break

    if neg is None:
        # fallback: random negative (any passage other than the positive)
        candidates = [pid for pid in all_pids if pid != positive_id]
        if candidates:
            neg = pid2text[rng.choice(candidates)]

    if pos and neg:
        examples.append(InputExample(texts=[q, pos, neg]))
    else:
        skipped += 1

print("Prepared", len(examples), "triplet examples.")
if skipped:
    print("Skipped", skipped, "pairs without a usable positive or negative passage.")

# --- Split into train/validation (80/20) ---
rng.shuffle(examples)
split_idx = int(0.8 * len(examples))
train_examples = examples[:split_idx]
val_examples = examples[split_idx:]

print("Train examples:", len(train_examples))
print("Validation examples:", len(val_examples))


In [None]:
# Cell 8 (use in-memory model; do NOT reload): Fine-tune SBERT with triplet loss
# and IR validation on passages_min, then rebuild the FAISS index so that
# dense_retrieve searches embeddings produced by the fine-tuned model.
import os
# Disable Weights & Biases logging before training starts.
os.environ["WANDB_DISABLED"] = "true"
os.environ["WANDB_MODE"] = "disabled"

import faiss
import torch
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, evaluation

# Sanity checks
assert isinstance(train_examples, list) and len(train_examples) > 0, "train_examples must be a non-empty list"
assert 'passages_min' in globals(), "passages_min must be loaded"
assert 'eval_queries' in globals(), "eval_queries must be loaded"
assert len(eval_queries) > 0, "eval_queries must be non-empty"

# Build validation split against real corpus & labels:
# use eval_queries_val when it exists, otherwise the last 20% of eval_queries.
eval_val = eval_queries_val if 'eval_queries_val' in globals() else eval_queries[int(0.8*len(eval_queries)):]

val_queries_dict = {it["query_id"]: it["query"] for it in eval_val}
# positive_ids may be a list or a single id; always store a set of ids.
val_relevant_dict = {}
for it in eval_val:
    rel = it["positive_ids"]
    val_relevant_dict[it["query_id"]] = set(rel if isinstance(rel, list) else [rel])
val_corpus_dict = {p["id"]: p["text"] for p in passages_min}

# Warn if labels reference missing ids
missing = [
    (qid, pid)
    for qid, rels in val_relevant_dict.items()
    for pid in rels
    if pid not in val_corpus_dict
]
if missing:
    print(f"Warning: {len(missing)} relevant ids not found in corpus. Example:", missing[:3])

# Construct evaluator (defaults to cosine similarity)
retrieval_evaluator = evaluation.InformationRetrievalEvaluator(
    queries=val_queries_dict,
    corpus=val_corpus_dict,
    relevant_docs=val_relevant_dict,
    name="val_ir_passages"
)

# Continue from the `embedder` built in Cell 6; only load the baseline
# multilingual MiniLM when no embedder exists in this session.
if 'embedder' not in globals():
    embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
# Use the GPU when one is available instead of failing on CPU-only runtimes.
device = "cuda" if torch.cuda.is_available() else "cpu"
embedder.to(device)

# Triplet loss with conservative settings
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.TripletLoss(
    model=embedder,
    distance_metric=losses.TripletDistanceMetric.COSINE,
    triplet_margin=0.3
)

num_epochs = 2
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)  # 10% of all training steps
optimizer_params = {'lr': 2e-5}

print("Starting fine-tuning (WandB Disabled)...")

# Train with IR evaluator
embedder.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=retrieval_evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params=optimizer_params,
    show_progress_bar=True,
    output_path="fine_tuned_sbert_urdu_passages"
)

print("‚úÖ Fine-tuning complete. Using in-memory fine-tuned 'embedder' (no reload).")

# The index built in Cell 6 still holds passage embeddings from the model as it
# was BEFORE fine-tuning. Re-encode the corpus with the updated weights so that
# query and passage vectors come from the same model.
corpus_ids = [p["id"] for p in passages_min]
corpus_texts = [p["text"] for p in passages_min]
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(passage_embeddings)
index = faiss.IndexFlatIP(passage_embeddings.shape[1])
index.add(passage_embeddings)
print("FAISS index rebuilt with fine-tuned embeddings:", index.ntotal, "passages")

In [None]:
# Cell 8b: Save the Fine-Tuned Model to Drive (Run ONLY if satisfied with accuracy)
# Writes the in-memory `embedder` to MODEL_SAVE_PATH; Cell 8c loads it from the
# same path in later sessions.
import os

# Destination directory on the mounted Drive (requires the Drive mount cell to have run)
MODEL_SAVE_PATH = "/content/drive/MyDrive/models/urdu_dense_retriever_best"

print(f"üíæ Saving model to {MODEL_SAVE_PATH} ...")

# Create the directory (and any missing parents); no error if it already exists
os.makedirs(MODEL_SAVE_PATH, exist_ok=True)

# Save the model
embedder.save(MODEL_SAVE_PATH)

print(f"‚úÖ Model saved! You can now use Cell 8c in future sessions to skip training.")

In [None]:
# Cell 8c: FAST START - Load Model from Drive & Rebuild FAISS (Skips Training)
# Run this INSTEAD of Cells 6, 7, 8, 8b in future sessions.
# Because those cells are skipped, this cell also defines the corpus lookups
# they would have created (corpus_ids, corpus_texts, pid2text).

import os
import faiss
import torch
from sentence_transformers import SentenceTransformer

MODEL_SAVE_PATH = "/content/drive/MyDrive/models/urdu_dense_retriever_best"

# 1. Load the Model
if not os.path.exists(MODEL_SAVE_PATH):
    raise FileNotFoundError(f"‚ùå No saved model found at {MODEL_SAVE_PATH}. Please run Cell 8 & 8b once to create it!")

# Use the GPU when one is available instead of failing on CPU-only runtimes.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"üìÇ Loading saved model from: {MODEL_SAVE_PATH}")
embedder = SentenceTransformer(MODEL_SAVE_PATH).to(device)
print("‚úÖ Model loaded successfully.")

# 2. Rebuild FAISS Index (Critical Step)
# The passage embeddings must come from the model that was just loaded.
print("‚è≥ Generating embeddings for corpus...")
corpus_ids = [p["id"] for p in passages_min]      # row i of the index -> passage id
corpus_texts = [p["text"] for p in passages_min]  # row i of the index -> passage text
pid2text = {p["id"]: p["text"] for p in passages_min}  # required by the hybrid retriever (Cell 9)

# Generate embeddings
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)

# Build FAISS: inner product on L2-normalized vectors = cosine similarity
faiss.normalize_L2(passage_embeddings)
index = faiss.IndexFlatIP(passage_embeddings.shape[1])
index.add(passage_embeddings)

# 3. Define the Retrieval Function
# (Re-defined here because the earlier cell that defines it is skipped on the fast-start path)
def dense_retrieve(query, k=5):
    """Return up to k nearest passages for `query` from the FAISS index.

    The query is encoded with `embedder`, L2-normalized, and searched against
    `index` (inner product on normalized vectors, i.e. cosine similarity).

    Returns a list of (passage_id, passage_text, score) tuples, best first.
    The list is shorter than k when the index holds fewer than k passages and
    empty when the index is empty or k <= 0.
    """
    k = min(int(k), int(index.ntotal))
    if k <= 0:
        return []
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for idx, score in zip(I[0], D[0]):
        idx = int(idx)
        if idx < 0:
            # FAISS pads missing neighbours with label -1; indexing with it
            # would silently return the LAST passage instead of nothing.
            continue
        results.append((corpus_ids[idx], corpus_texts[idx], float(score)))
    return results

print("‚úÖ Dense Retriever System Restored & Ready for Hybrid Fusion (Cell 9).")

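We can now run Cell 6b again to measure how much fine-tuning improved the dense retriever. Make sure the FAISS index has been rebuilt with the fine-tuned `embedder` first (Cell 8c does this); otherwise queries are encoded with the new model but searched against passage embeddings produced by the old one.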

In [None]:
# Cell 9 (final): Retriever wrapper with true fusion modes (non-destructive)
# - Creates bm25_new only if not present
# - Supports modes: 'bm25', 'dense', 'hybrid_interleave' (legacy), 'hybrid_score', 'hybrid_rrf'
# - Returns list of (pid, text, score) tuples
# - Does NOT rebuild or overwrite dense/FAISS objects

from datetime import datetime
import numpy as np
import re

# ---------- Config ----------
# Tune these later on a small validation set
DEFAULT_RETRIEVE_POOL = 50 # candidates fetched from each retriever before fusion
SCORE_FUSION_ALPHA = 0.6   # alpha in [0,1] for score fusion: alpha * dense + (1-alpha) * bm25
RRF_K = 60                 # reciprocal rank fusion constant

# ---------- Sanity checks for canonical corpus ----------
assert 'passages_min' in globals() and isinstance(passages_min, list) and len(passages_min) > 0, "passages_min must be loaded"
# pid2text is normally created in Cell 7, which the fast-start path (Cell 8c)
# skips; derive it from passages_min instead of failing in that case.
if 'pid2text' not in globals() or not isinstance(pid2text, dict) or len(pid2text) == 0:
    pid2text = {p["id"]: p["text"] for p in passages_min}
assert 'dense_retrieve' in globals(), "dense_retrieve wrapper must be defined (fine-tuned dense retriever)"

# ---------- Build or reuse BM25 index (non-destructive) ----------
# Defined unconditionally: bm25_new_retrieve needs it even when an existing
# index is reused and the build branch below is skipped.
def _normalize_for_bm25(s: str) -> str:
    """Light normalization before whitespace tokenization.

    Replaces zero-width non-joiners with a space, collapses whitespace runs,
    strips the ends and lower-cases. None is treated as an empty string.
    """
    if s is None:
        return ""
    s = s.replace("\u200c", " ")  # zero-width non-joiner
    s = re.sub(r"\s+", " ", s).strip()
    return s.lower()

# Reuse an index from a previous run only when all of its parts are present
# and it was built from a corpus of the current size; otherwise row numbers
# returned by BM25 would no longer line up with bm25_new_ids / bm25_new_texts.
_bm25_new_ready = (
    all(name in globals() for name in ("bm25_new", "bm25_new_ids", "bm25_new_texts"))
    and len(bm25_new_ids) == len(passages_min)
)

if not _bm25_new_ready:
    try:
        from rank_bm25 import BM25Okapi
    except ImportError:
        # Package missing in this runtime: install it into the kernel's interpreter.
        import sys, subprocess
        subprocess.run([sys.executable, "-m", "pip", "install", "rank_bm25"], check=True)
        from rank_bm25 import BM25Okapi

    # Build tokenized corpus from passages_min (light normalization)
    bm25_new_ids = [p["id"] for p in passages_min]
    bm25_new_texts = [p["text"] for p in passages_min]
    bm25_new_tokenized = [_normalize_for_bm25(t).split() for t in bm25_new_texts]
    bm25_new = BM25Okapi(bm25_new_tokenized)

# BM25 wrapper used by the hybrid retriever; returns (pid, text, score) tuples
def bm25_new_retrieve(query: str, k: int = 5):
    """Return the k best BM25 matches for `query` from the bm25_new index.

    The query goes through the same normalization and whitespace tokenization
    as the indexed passages. Results are (pid, text, score) tuples, best first.
    """
    all_scores = bm25_new.get_scores(_normalize_for_bm25(query).split())
    best_rows = np.argsort(all_scores)[::-1][:k]
    return [
        (bm25_new_ids[row], bm25_new_texts[row], float(all_scores[row]))
        for row in map(int, best_rows)
    ]

# ---------- Fusion utilities ----------
def normalize_scores(score_map):
    """Min-max scale the values of `score_map` into [0, 1].

    An empty mapping yields {}. When every value is identical there is no
    range to scale by, so every key is mapped to 1.0.
    """
    if not score_map:
        return {}
    lo = min(score_map.values())
    hi = max(score_map.values())
    if hi == lo:
        return dict.fromkeys(score_map, 1.0)
    spread = hi - lo
    scaled = {}
    for key, value in score_map.items():
        scaled[key] = (value - lo) / spread
    return scaled

def rrf_rank(dense_list, bm25_list, k_rrf=RRF_K):
    """Reciprocal Rank Fusion of two ranked hit lists.

    Each list holds (pid, text, score) tuples ordered best first. A passage
    at 1-based rank r in a list contributes 1 / (k_rrf + r); contributions
    from both lists are summed per pid.

    Returns (sorted_pids, score): pids ordered by fused score, highest first,
    and the dict of fused scores.
    """
    score = {}
    for ranking in (dense_list, bm25_list):
        for position, (pid, _text, _raw) in enumerate(ranking, start=1):
            score[pid] = score.get(pid, 0.0) + 1.0 / (k_rrf + position)
    sorted_pids = sorted(score, key=score.__getitem__, reverse=True)
    return sorted_pids, score

# ---------- Metadata filter helper (unchanged semantics) ----------
# Passage id -> record lookup used by filter_by_metadata. Records come from
# corpus_clean when it is loaded, otherwise from passages_min.
meta_source = corpus_clean if 'corpus_clean' in globals() else passages_min
meta_map = {record["id"]: record for record in meta_source}

def filter_by_metadata(candidate_ids, min_date=None, max_date=None, allowed_sources=None, exclude_time_sensitive=None):
    """Return the ids from `candidate_ids` whose metadata passes every given filter.

    Metadata is looked up in `meta_map`; an id without metadata is treated as
    having an empty record.

    Args:
        candidate_ids: passage ids to check; input order is preserved.
        min_date / max_date: optional datetime bounds (inclusive) applied to the
            record's ISO-8601 "retrieved_at" value. Records with a missing or
            unparseable timestamp are kept.
        allowed_sources: optional collection; when non-empty, records whose
            "source" is not in it are dropped.
        exclude_time_sensitive: when not None, records whose "time_sensitive"
            value equals it are dropped.

    Timezone handling: timezone-aware values (bounds or record timestamps) are
    converted to UTC and compared as naive datetimes, so naive values are in
    effect interpreted as UTC. This avoids the TypeError Python raises when an
    aware datetime is compared with a naive one.
    """
    from datetime import timezone

    def _comparable(dt):
        # Aware datetime -> naive UTC; naive datetimes and None pass through.
        if dt is not None and getattr(dt, "tzinfo", None) is not None:
            return dt.astimezone(timezone.utc).replace(tzinfo=None)
        return dt

    lower = _comparable(min_date)
    upper = _comparable(max_date)

    out = []
    for pid in candidate_ids:
        m = meta_map.get(pid, {})
        ok = True
        if min_date or max_date:
            dt = None
            if "retrieved_at" in m:
                try:
                    raw = m["retrieved_at"]
                    # datetime.fromisoformat only accepts a trailing "Z" from
                    # Python 3.11 on; spell it as an explicit UTC offset.
                    if isinstance(raw, str) and raw.endswith("Z"):
                        raw = raw[:-1] + "+00:00"
                    dt = _comparable(datetime.fromisoformat(raw))
                except Exception:
                    dt = None  # best effort: an unparseable date does not exclude the passage
            if dt is not None:
                if lower and dt < lower: ok = False
                if upper and dt > upper: ok = False
        if allowed_sources and m.get("source") not in allowed_sources:
            ok = False
        if exclude_time_sensitive is not None and m.get("time_sensitive") == exclude_time_sensitive:
            ok = False
        if ok:
            out.append(pid)
    return out

# ---------- Main retrieve wrapper with fusion modes ----------
def retrieve(query: str, k: int = 5, mode: str = "hybrid_score", min_date=None, max_date=None, allowed_sources=None, exclude_time_sensitive=None):
    """
    retrieve(query, k, mode)
    Modes:
      - 'bm25' : BM25-only (bm25_new_retrieve)
      - 'dense' : dense-only (dense_retrieve)
      - 'hybrid_interleave' : legacy merge (all dense hits first, then unseen bm25 hits)
      - 'hybrid_score' : score fusion (min-max normalized dense + bm25, weighted by SCORE_FUSION_ALPHA)
      - 'hybrid_rrf' : reciprocal rank fusion (RRF)

    Optional metadata filters (min_date, max_date, allowed_sources,
    exclude_time_sensitive) are forwarded to filter_by_metadata and applied to
    the whole ranked candidate pool BEFORE the top-k cut, so up to k passing
    results are returned whenever the pool contains them.

    Returns: list of at most k (pid, text, score) tuples, best first.
    Raises: ValueError for an unknown mode (before any retrieval is run).
    """
    if mode not in ("bm25", "dense", "hybrid_interleave", "hybrid_score", "hybrid_rrf"):
        raise ValueError(f"Unknown retrieve mode: {mode}")

    # Candidate pools: only query the retriever(s) the mode actually uses, so
    # single-retriever modes do not pay for (or time) the other one.
    pool = max(DEFAULT_RETRIEVE_POOL, k)
    dense_hits = dense_retrieve(query, k=pool) if mode != "bm25" else []      # (pid, text, score)
    bm25_hits = bm25_new_retrieve(query, k=pool) if mode != "dense" else []   # (pid, text, score)

    # Texts as delivered by the retrievers; fallback for ids missing in pid2text.
    hit_text = {}
    for pid, text, _ in list(dense_hits) + list(bm25_hits):
        hit_text.setdefault(pid, text)

    def _text_for(pid):
        return pid2text[pid] if pid in pid2text else hit_text.get(pid, "")

    # Build the full ranked candidate list for the requested mode.
    if mode == "bm25":
        ranked = list(bm25_hits)
    elif mode == "dense":
        ranked = list(dense_hits)
    elif mode == "hybrid_interleave":
        # legacy behavior: dense hits first, then bm25 hits not seen yet
        seen = set()
        ranked = []
        for lst in (dense_hits, bm25_hits):
            for pid, text, score in lst:
                if pid not in seen:
                    seen.add(pid)
                    ranked.append((pid, text, float(score)))
    elif mode == "hybrid_score":
        # Score fusion: normalize each retriever's scores, then combine.
        d_norm = normalize_scores({pid: sc for pid, _, sc in dense_hits})
        b_norm = normalize_scores({pid: sc for pid, _, sc in bm25_hits})
        alpha = SCORE_FUSION_ALPHA
        combined = {}
        # Deterministic candidate order (dense first, then bm25-only ids) so
        # that ties are broken the same way on every run.
        for pid in list(d_norm) + [p for p in b_norm if p not in d_norm]:
            combined[pid] = alpha * d_norm.get(pid, 0.0) + (1 - alpha) * b_norm.get(pid, 0.0)
        order = sorted(combined, key=combined.get, reverse=True)
        ranked = [(pid, _text_for(pid), float(combined[pid])) for pid in order]
    else:  # "hybrid_rrf"
        sorted_pids, score_map = rrf_rank(dense_hits, bm25_hits, k_rrf=RRF_K)
        ranked = [(pid, _text_for(pid), float(score_map.get(pid, 0.0))) for pid in sorted_pids]

    # Apply metadata filters if requested (filter by pid only)
    if any([min_date, max_date, allowed_sources, exclude_time_sensitive is not None]):
        keep = set(filter_by_metadata([pid for pid, _, _ in ranked], min_date, max_date, allowed_sources, exclude_time_sensitive))
        ranked = [hit for hit in ranked if hit[0] in keep]

    return ranked[:k]

# ---------- Quick sample test (safe) ----------
# Uses the first evaluation query when eval_queries is loaded and non-empty,
# otherwise the hard-coded sample query below; prints the top-5 passage ids
# from each retriever and from both fusion modes.
q = eval_queries[0]["query"] if 'eval_queries' in globals() and len(eval_queries)>0 else "⁄©ŸàŸà⁄à-19 ⁄©€å ÿπÿßŸÖ ÿπŸÑÿßŸÖÿßÿ™ ⁄©€åÿß €Å€å⁄∫ÿü"
print("Sample dense top-5 ids:", [r[0] for r in dense_retrieve(q, k=5)])
print("Sample bm25_new top-5 ids:", [r[0] for r in bm25_new_retrieve(q, k=5)])
print("Sample hybrid_score top-5 ids:", [r[0] for r in retrieve(q, k=5, mode='hybrid_score')])
print("Sample hybrid_rrf top-5 ids:", [r[0] for r in retrieve(q, k=5, mode='hybrid_rrf')])


In [None]:
# Cell 9b: Run after Cell 9. Cell 9 creates the BM25 + Dense hybrid retriever and
# this cell evaluates its performance.
# Validation diagnostics: Recall@1, Recall@5, MRR, Precision@k for retrievers
# - Works with any mode supported by the Cell 9 wrapper: 'bm25', 'dense', 'hybrid_interleave', 'hybrid_score', 'hybrid_rrf'
# - Calls retrieve(...) and computes retrieval metrics
# - Outputs summary metrics and a few examples of misses

import time, statistics
from tqdm.auto import tqdm

DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def evaluate_retriever(eval_items, mode="hybrid_score", k=DEFAULT_K,
                       recall_ks=RECALL_KS, precision_ks=PRECISION_KS):
    """Evaluate retrieve(..., mode=mode) against id-labelled queries.

    Args:
        eval_items: sequence of query dicts with a query text ("query",
            "question" or "q") and "positive_ids" (list) or "positive_id".
        mode: retrieval mode forwarded to retrieve().
        k: number of passages requested per query.
        recall_ks: cut-offs for Recall@k (a query counts as a hit when at
            least one positive is in the top-k).
        precision_ks: cut-offs for Precision@k (always divided by k).

    Returns:
        (summary, per_query): metric dict and one record per query. For empty
        input every metric is 0.0 and "n_queries" is 0.

    A retrieval error for one query is printed and counted as an empty result.
    """
    per_query = []
    latencies = []
    rr_list = []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}

    for item in tqdm(eval_items, desc=f"Evaluating {mode} retriever"):
        q = item.get("query") or item.get("question") or item.get("q") or ""

        # Accept a list of ids or a single id of any scalar type.
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        if not isinstance(positive_ids, (list, tuple, set)):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]
        positives = set(positive_ids)

        t0 = time.perf_counter()
        try:
            hits = retrieve(q, k=k, mode=mode)   # (pid, text, score)
        except Exception as e:
            hits = []
            print(f"[eval] retrieve error for query {q[:60]}... -> {e}")
        latency = time.perf_counter() - t0
        latencies.append(latency)

        retrieved_ids = [r[0] for r in hits]
        # Positives were stringified above; compare retrieved ids as strings too.
        is_positive = [str(pid) in positives for pid in retrieved_ids]

        # Reciprocal rank of the first positive (0.0 when none was retrieved).
        rr = 0.0
        for rank, hit in enumerate(is_positive, start=1):
            if hit:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        # Recall@k and Precision@k
        for rk in recall_ks:
            recall_counts[rk] += 1 if any(is_positive[:rk]) else 0
        for pk in precision_ks:
            precision_sums[pk] += sum(is_positive[:pk]) / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": positive_ids,
            "retrieved_ids": retrieved_ids,
            "reciprocal_rank": rr,
            "latency": latency
        })

    total = len(per_query)
    denom = total or 1  # avoid division by zero; the reported count stays truthful

    summary = {
        "n_queries": total,
        "MRR": sum(rr_list) / denom,
        **{f"Recall@{rk}": recall_counts[rk] / denom for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk] / denom for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    return summary, per_query

# ---------- Run evaluation ----------
# Prefer a dedicated validation split when one exists in the session.
if 'eval_queries_val' in globals():
    eval_items = eval_queries_val
else:
    eval_items = eval_queries

# Evaluate every retriever mode on the same items and keep each summary.
modes = ["bm25", "dense", "hybrid_interleave", "hybrid_score", "hybrid_rrf"]
results = {}
for m in modes:
    summary, records = evaluate_retriever(eval_items, mode=m, k=DEFAULT_K)
    results[m] = summary
    print(f"\n{m} retriever evaluation summary:")
    for k, v in summary.items():
        # floats are shown with three decimals, everything else verbatim
        if isinstance(v, float):
            print(f"  {k}: {v:.3f}")
        else:
            print(f"  {k}: {v}")

    # Queries for which no positive passage was retrieved at all.
    misses = list(filter(lambda rec: rec["reciprocal_rank"] == 0.0, records))
    print(f"  Total misses: {len(misses)} / {len(records)}. Showing up to 3 misses:")
    for r in misses[:3]:
        print("   Query:", r["query"][:80])
        print("    Positives:", r["positive_ids"])
        print("    Retrieved top ids:", r["retrieved_ids"][:8])
