<a href="https://colab.research.google.com/github/ahmedsaalman/low-resource-rag-comparison/blob/main/NLP_Project_RAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Cell 1: Install required libraries (run this cell first and one by one all required libraries will be installed)
# - transformers: model + generation
# - sentence-transformers: dense embeddings / fine-tuning helpers
# - faiss-cpu (or faiss-gpu if GPU available)
# - rank_bm25: BM25 baseline
# - datasets: convenient JSONL loading
# - evaluate / sacrebleu: BLEU/chrF metrics
# - tqdm: progress bars
# - accelerate (optional) for distributed/faster training
!pip install -q transformers sentence-transformers faiss-cpu rank_bm25 datasets evaluate sacrebleu tqdm accelerate


[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.8/51.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m104.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m104.1/104.1 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!pip list # Optional to run this cell: To check which of the libraries/packages have been installed

In [2]:
# Cell 2: Imports and GPU check: Run this cell after the first cell
import os, json, time, math
from pathlib import Path
from tqdm.auto import tqdm
import numpy as np
import pandas as pd

# Transformers / sentence-transformers
import transformers
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModel
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
import sentence_transformers # Import the package itself to access __version__

# FAISS and BM25
import faiss
from rank_bm25 import BM25Okapi

# Datasets and metrics
from datasets import load_dataset, Dataset
import evaluate
import sacrebleu

# Print versions and GPU info
print("transformers:", transformers.__version__)
print("sentence-transformers:", sentence_transformers.__version__)
try:
    import torch
    print("torch:", torch.__version__, "cuda:", torch.cuda.is_available())
except Exception as e:
    print("torch not available:", e)


transformers: 4.57.3
sentence-transformers: 5.1.2
torch: 2.9.0+cu126 cuda: True


In [3]:
# Cell 3: Load JSONL/TSV files into Python structures
# There will be a content folder on left side bar, files panel. This is our root
# folder. Inside it create a data folder, if not already present. Upload all files
# there and then run this cell.

DATA_DIR = Path("drive/MyDrive/data")  # change if files are elsewhere

# Create the data directory if it doesn't exist
import os
os.makedirs(DATA_DIR, exist_ok=True)

def load_jsonl(path):
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                items.append(json.loads(line))
    return items

corpus_clean = load_jsonl(DATA_DIR / "urdu_covid_corpus_clean.jsonl")
passages_min = load_jsonl(DATA_DIR / "urdu_covid_passages_min.jsonl")
# TSV -> list of dicts
passages_tsv = []
with open(DATA_DIR / "urdu_covid_passages.tsv", "r", encoding="utf-8") as f:
    for line in f:
        # Use split(None, 1) to split on the first occurrence of any whitespace
        # This handles cases where the delimiter might be spaces instead of a tab.
        if line.strip(): # Ensure line is not empty after stripping whitespace
            parts = line.rstrip("\n").split(None, 1)
            if len(parts) == 2:
                pid, text = parts
                passages_tsv.append({"id": pid, "text": text})
            else:
                print(f"Skipping malformed line in urdu_covid_passages.tsv: {line.strip()}")

eval_queries = load_jsonl(DATA_DIR / "eval_queries.jsonl")
synthetic_pairs = load_jsonl(DATA_DIR / "synthetic_qa_pairs.jsonl")
hard_negatives = load_jsonl(DATA_DIR / "hard_negatives.jsonl")

print("Loaded:", len(corpus_clean), "corpus_clean; ", len(passages_min), "passages_min; ", len(eval_queries), "eval queries")


Loaded: 60 corpus_clean;  60 passages_min;  100 eval queries


In [4]:
# Cell 4: Validate IDs referenced in eval/synthetic/hard_negatives exist in corpus
# Run this after Cell 3.
passage_ids = {p["id"] for p in passages_min}
missing = []
for q in eval_queries:
    for pid in q.get("positive_ids", []):
        if pid not in passage_ids:
            missing.append(("eval", q["query_id"], pid))
for s in synthetic_pairs:
    if s["positive_id"] not in passage_ids:
        missing.append(("synthetic", s["synthetic_id"], s["positive_id"]))
for h in hard_negatives:
    for pid in h["hard_negatives"]:
        if pid not in passage_ids:
            missing.append(("hardneg", h["query_id"], pid))
print("Missing references (should be zero):", len(missing))
if missing:
    print(missing[:10])


Missing references (should be zero): 0


In [5]:
# Cell 5 (Run after Cell 4): BM25 baseline index (tokenize with simple whitespace; for Urdu this is OK as baseline)
# We'll store tokenized corpus and BM25 object for retrieval.
from nltk.tokenize import word_tokenize
# If nltk not installed, use simple split
try:
    import nltk
    nltk.download('punkt')
    nltk.download('punkt_tab') # Added to resolve LookupError for 'punkt_tab'
    tokenizer = lambda s: word_tokenize(s)
except Exception:
    tokenizer = lambda s: s.split()

corpus_texts = [p["text"] for p in passages_min]
corpus_ids = [p["id"] for p in passages_min]
tokenized_corpus = [tokenizer(t) for t in corpus_texts]
bm25 = BM25Okapi(tokenized_corpus)

# Example retrieval function
def bm25_retrieve(query, k=5):
    q_tokens = tokenizer(query)
    scores = bm25.get_scores(q_tokens)
    topk = np.argsort(scores)[::-1][:k]
    return [(corpus_ids[i], corpus_texts[i], float(scores[i])) for i in topk]

# Quick test
print("BM25 top-3 for sample:", bm25_retrieve(eval_queries[0]["query"], k=3))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


BM25 top-3 for sample: [('p0001', 'کورونا وائرس مرض 2019 (COVID-19) ایک متعدی بیماری ہے جس کی عام علامات میں بخار، کھانسی اور سانس لینے میں دشواری شامل ہیں۔', 5.810063974702894), ('p0024', 'بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔', 5.103496839739362), ('p0002', 'کووڈ-19 کی تشخیص کے لیے rRT-PCR سویب ٹیسٹ عام طور پر استعمال ہوتے ہیں اور یہ وائرس کی موجودگی کی تصدیق کرتے ہیں۔', 4.589270107579207)]


In [6]:
# Cell 5b: BM25-only retriever evaluation tool (run after Cell 5)
# Purpose: standalone evaluation harness for the independent BM25 retriever (bm25_retrieve)
# Metrics included (applicable to a retriever-only evaluation):
#   - Recall@1, Recall@5
#   - MRR (Mean Reciprocal Rank)
#   - Precision@k (k=1,5)
#   - Average / median retrieval latency
#   - Optional: match by gold_passage_id or by substring match of gold_answer
# Output:
#   - Per-query JSONL saved to bm25_eval_results.jsonl
#   - Printed summary with all metrics
#
# Requirements (must be available in the session):
#   - bm25_retrieve(query, k) -> list of (passage_id, passage_text, score)
#   - eval_queries: list of dicts with at least a query field and optionally:
#       * "question" or "query" or "q"  (the query text)
#       * "gold_passage_id" (optional) OR "answer"/"gold" (gold text to match)
#
# Usage:
#   - Run this cell after you build the BM25 index (Cell 5).
#   - Optionally pass a different eval list or k values to evaluate subsets.

# Use this evaluator if your eval_queries items contain "positive_ids" and "gold_answer"
import json, time, re, statistics
from typing import List, Dict

OUT_JSONL = "bm25_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s: str) -> str:
    if s is None: return ""
    s = str(s).strip()
    return re.sub(r"\s+", " ", s)

def get_query_text(item: Dict) -> str:
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_bm25_with_positive_ids(eval_items: List[Dict],
                                    out_jsonl: str = OUT_JSONL,
                                    k: int = DEFAULT_K,
                                    recall_ks = RECALL_KS,
                                    precision_ks = PRECISION_KS):
    per_query = []
    latencies = []
    rr_list = []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}
    total = 0

    for item in eval_items:
        total += 1
        q = get_query_text(item)
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        # normalize to list of strings
        if isinstance(positive_ids, str):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]

        gold_text = normalize_text(item.get("gold_answer") or item.get("answer") or "")

        t0 = time.time()
        try:
            hits = bm25_retrieve(q, k=k)   # (id, text, score)
        except Exception as e:
            hits = []
            print(f"[eval] bm25_retrieve error for query {q[:60]}... -> {e}")
        latency = time.time() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]

        # Reciprocal rank: first position among positives
        rr = 0.0
        for rank, pid in enumerate(retrieved_ids, start=1):
            if pid in positive_ids:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        # Recall@k and Precision@k (multiple positives supported)
        for rk in recall_ks:
            recall_counts[rk] += 1 if any(pid in positive_ids for pid in retrieved_ids[:rk]) else 0
        for pk in precision_ks:
            # precision@k = (# positives in top-k) / k
            num_pos_in_topk = sum(1 for pid in retrieved_ids[:pk] if pid in positive_ids)
            precision_sums[pk] += (num_pos_in_topk / pk)

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": positive_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [t[:300] for t in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    n = total if total else 1
    mrr = sum(rr_list) / n
    recall_at = {rk: recall_counts[rk] / n for rk in recall_ks}
    precision_at = {pk: precision_sums[pk] / n for pk in precision_ks}
    latency_mean = statistics.mean(latencies) if latencies else 0.0
    latency_median = statistics.median(latencies) if latencies else 0.0

    summary = {
        "n_queries": n,
        "MRR": mrr,
        **{f"Recall@{rk}": recall_at[rk] for rk in recall_ks},
        **{f"Precision@{pk}": precision_at[pk] for pk in precision_ks},
        "latency_mean_s": latency_mean,
        "latency_median_s": latency_median
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run it
if 'eval_queries' not in globals():
    # try to load from file if not in memory
    eval_queries = []
    with open("eval_queries.jsonl","r",encoding="utf-8") as f:
        for line in f:
            eval_queries.append(json.loads(line))

summary, records = evaluate_bm25_with_positive_ids(eval_queries, out_jsonl=OUT_JSONL, k=DEFAULT_K)
print("BM25 retrieval evaluation summary:")
for k,v in summary.items():
    print(f"  {k}: {v}")

# show a few examples where retrieval missed positives
misses = [r for r in records if r["reciprocal_rank"] == 0.0]
print(f"\nTotal misses: {len(misses)} / {len(records)}. Showing up to 5 misses:")
for r in misses[:5]:
    print("Query id:", r.get("query_id"), "Query:", r["query"][:80])
    print(" Positives:", r["positive_ids"])
    print(" Retrieved top ids:", r["retrieved_ids"][:8])
    print()


BM25 retrieval evaluation summary:
  n_queries: 100
  MRR: 0.8853333333333333
  Recall@1: 0.84
  Recall@5: 0.95
  Precision@1: 0.84
  Precision@5: 0.21599999999999964
  latency_mean_s: 0.00045699119567871094
  latency_median_s: 0.0004508495330810547

Total misses: 5 / 100. Showing up to 5 misses:
Query id: q007 Query: کووڈ-19 ویکسین کا بنیادی مقصد کیا ہے؟
 Positives: ['p0007']
 Retrieved top ids: ['p0028', 'p0050', 'p0051', 'p0027', 'p0039']

Query id: q019 Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
 Positives: ['p0020']
 Retrieved top ids: ['p0017', 'p0060', 'p0031', 'p0048', 'p0027']

Query id: q038 Query: ویکسین سائیڈ ایفیکٹس کی نگرانی کیسے کی جاتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0040', 'p0032', 'p0051', 'p0011']

Query id: q065 Query: ویکسین کی سائیڈ ایفیکٹس کی رپورٹنگ کیسے ہوتی ہے؟
 Positives: ['p0039']
 Retrieved top ids: ['p0058', 'p0047', 'p0032', 'p0022', 'p0025']

Query id: q095 Query: وبا کے دوران معاشی بحالی کے لیے کون سے اقدامات کیے جا سکتے ہیں؟
 

In [7]:
# Cell 6: Dense embeddings with a multilingual model (use a compact model for Colab)
# We use a multilingual SBERT model that supports Urdu reasonably (e.g., 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2')
embed_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
embedder = SentenceTransformer(embed_model_name)

# Compute embeddings for passages_min (batching)
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)

# Build FAISS index (cosine similarity via normalized vectors)
d = passage_embeddings.shape[1]
index = faiss.IndexFlatIP(d)  # inner product
# normalize embeddings for cosine
faiss.normalize_L2(passage_embeddings)
index.add(passage_embeddings)

# Map index positions to ids
# retrieval function
def dense_retrieve(query, k=5):
    q_emb = embedder.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, k)
    results = []
    for idx, score in zip(I[0], D[0]):
        results.append((corpus_ids[idx], corpus_texts[idx], float(score)))
    return results

# Quick test
print("Dense top-3:", dense_retrieve(eval_queries[0]["query"], k=3))


modules.json:   0%|          | 0.00/229 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/122 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/645 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/471M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/480 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.08M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/2 [00:00<?, ?it/s]

Dense top-3: [('p0024', 'بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔', 0.7648560404777527), ('p0021', 'کووڈ-19 کے بعد بعض افراد میں طویل مدتی علامات (Long COVID) جیسے تھکن، سانس کی تکلیف اور دماغی دھند برقرار رہ سکتی ہیں؛ ریہیب پروگرامز مدد دیتے ہیں۔', 0.6735703349113464), ('p0036', 'کووڈ-19 کے مریضوں میں خون جمنے کے مسائل اور دیگر پیچیدگیاں بعض اوقات سامنے آئیں، اس لیے طبی نگرانی اور مناسب علاج ضروری ہے۔', 0.647298276424408)]


In [8]:
# Cell 6b: Evaluation of dense retriever (run after Cell 6)
# Purpose: measure Recall@1, Recall@5, MRR, Precision@k, latency for dense_retrieve
# Uses eval_queries with "positive_ids" and "gold_answer" fields

import json, time, re, statistics

OUT_JSONL_DENSE = "dense_eval_results.jsonl"
DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def normalize_text(s):
    if s is None: return ""
    return re.sub(r"\s+", " ", str(s).strip())

def get_query_text(item):
    return item.get("query") or item.get("question") or item.get("q") or ""

def evaluate_dense(eval_items, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K,
                   recall_ks=RECALL_KS, precision_ks=PRECISION_KS):
    per_query = []
    latencies, rr_list = [], []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}
    total = 0

    for item in eval_items:
        total += 1
        q = get_query_text(item)
        pos_ids = item.get("positive_ids") or []
        if isinstance(pos_ids, str): pos_ids = [pos_ids]
        pos_ids = [str(x) for x in pos_ids]

        gold_text = normalize_text(item.get("gold_answer") or "")

        t0 = time.time()
        hits = dense_retrieve(q, k=k)  # (id, text, score)
        latency = time.time() - t0
        latencies.append(latency)

        retrieved_ids = [h[0] for h in hits]
        retrieved_texts = [h[1] for h in hits]

        # Reciprocal rank
        rr = 0.0
        for rank, pid in enumerate(retrieved_ids, start=1):
            if pid in pos_ids:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        # Recall@k and Precision@k
        for rk in recall_ks:
            recall_counts[rk] += 1 if any(pid in pos_ids for pid in retrieved_ids[:rk]) else 0
        for pk in precision_ks:
            num_pos_in_topk = sum(1 for pid in retrieved_ids[:pk] if pid in pos_ids)
            precision_sums[pk] += num_pos_in_topk / pk

        per_query.append({
            "query_id": item.get("query_id"),
            "query": q,
            "positive_ids": pos_ids,
            "gold_text": gold_text,
            "retrieved_ids": retrieved_ids,
            "retrieved_texts_preview": [txt[:300] for txt in retrieved_texts],
            "reciprocal_rank": rr,
            "latency": latency
        })

    n = total if total else 1
    summary = {
        "n_queries": n,
        "MRR": sum(rr_list)/n,
        **{f"Recall@{rk}": recall_counts[rk]/n for rk in recall_ks},
        **{f"Precision@{pk}": precision_sums[pk]/n for pk in precision_ks},
        "latency_mean_s": statistics.mean(latencies) if latencies else 0.0,
        "latency_median_s": statistics.median(latencies) if latencies else 0.0
    }

    with open(out_jsonl, "w", encoding="utf-8") as f:
        for r in per_query:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

    return summary, per_query

# Run evaluation
print("[dense_eval] Running dense retriever evaluation...")
summary_dense, records_dense = evaluate_dense(eval_queries, out_jsonl=OUT_JSONL_DENSE, k=DEFAULT_K)
print("\nDense retriever evaluation summary:")
for k,v in summary_dense.items():
    print(f"  {k}: {v}")

# Show a few examples
print("\nExamples (first 5):")
for r in records_dense[:5]:
    print(" - Query:", r["query"][:80])
    print("   Retrieved ids:", r["retrieved_ids"][:6])
    print("   Reciprocal rank:", r["reciprocal_rank"], "Latency(s):", round(r["latency"], 4))
    print()


[dense_eval] Running dense retriever evaluation...

Dense retriever evaluation summary:
  n_queries: 100
  MRR: 0.7956666666666666
  Recall@1: 0.7
  Recall@5: 0.92
  Precision@1: 0.7
  Precision@5: 0.20199999999999968
  latency_mean_s: 0.02900353670120239
  latency_median_s: 0.022706270217895508

Examples (first 5):
 - Query: کووڈ-19 کی عام علامات کیا ہیں؟
   Retrieved ids: ['p0024', 'p0021', 'p0036', 'p0001', 'p0019']
   Reciprocal rank: 0.25 Latency(s): 0.0575

 - Query: کووڈ-19 کی تشخیص کے لیے کون سا ٹیسٹ عام طور پر استعمال ہوتا ہے؟
   Retrieved ids: ['p0018', 'p0002', 'p0036', 'p0055', 'p0024']
   Reciprocal rank: 0.5 Latency(s): 0.0394

 - Query: ہاتھوں کی صفائی وبا کے دوران کیوں ضروری ہے؟
   Retrieved ids: ['p0030', 'p0003', 'p0042', 'p0046', 'p0041']
   Reciprocal rank: 0.5 Latency(s): 0.0494

 - Query: ماسک پہننے کے کیا فوائد ہیں؟
   Retrieved ids: ['p0004', 'p0029', 'p0042', 'p0022', 'p0050']
   Reciprocal rank: 1.0 Latency(s): 0.0749

 - Query: سماجی فاصلہ رکھنے کی اہمیت کیا 

In [9]:
# Cell 7: Prepare InputExamples for sentence-transformers fine-tuning
# Now with an 80/20 train/validation split

from sentence_transformers import InputExample
import random

pid2text = {p["id"]: p["text"] for p in passages_min}

examples = []
for s in synthetic_pairs:
    q = s["query"]
    pos = pid2text.get(s["positive_id"])
    neg = None
    # Find hard negatives if available
    hn = next((h for h in hard_negatives if h["query_id"] == s.get("synthetic_id", s.get("query_id"))), None)
    if hn:
        for nid in hn["hard_negatives"]:
            if nid != s["positive_id"]:
                neg = pid2text.get(nid)
                break
    if neg is None:
        # fallback: random negative
        neg_id = random.choice([pid for pid in corpus_ids if pid != s["positive_id"]])
        neg = pid2text[neg_id]
    if pos and neg:
        examples.append(InputExample(texts=[q, pos, neg]))

print("Prepared", len(examples), "triplet examples.")

# --- Split into train/validation (80/20) ---
random.shuffle(examples)
split_idx = int(0.8 * len(examples))
train_examples = examples[:split_idx]
val_examples = examples[split_idx:]

print("Train examples:", len(train_examples))
print("Validation examples:", len(val_examples))


Prepared 500 triplet examples.
Train examples: 400
Validation examples: 100


In [10]:
# Cell 8 (use in-memory model; do NOT reload): Fine-tune SBERT with triplet loss and IR validation on passages_min

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, losses, evaluation
import faiss

# Sanity checks
assert isinstance(train_examples, list) and len(train_examples) > 0, "train_examples must be a non-empty list"
assert 'passages_min' in globals(), "passages_min must be loaded"
assert 'eval_queries' in globals(), "eval_queries must be loaded"

# Build validation split against real corpus & labels
eval_val = eval_queries_val if 'eval_queries_val' in globals() else eval_queries[int(0.8*len(eval_queries)):]
val_queries_dict = {it["query_id"]: it["query"] for it in eval_val}
val_relevant_dict = {it["query_id"]: set(it["positive_ids"]) for it in eval_val}
val_corpus_dict = {p["id"]: p["text"] for p in passages_min}

# Warn if labels reference missing ids
missing = []
for qid, rels in val_relevant_dict.items():
    for pid in rels:
        if pid not in val_corpus_dict:
            missing.append((qid, pid))
if missing:
    print(f"Warning: {len(missing)} relevant ids not found in corpus. Example:", missing[:3])

# Construct evaluator (defaults to cosine similarity)
retrieval_evaluator = evaluation.InformationRetrievalEvaluator(
    queries=val_queries_dict,
    corpus=val_corpus_dict,
    relevant_docs=val_relevant_dict,
    name="val_ir_passages"
)

# Start from baseline multilingual MiniLM
embedder = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2").to("cuda")

# Triplet loss with conservative settings
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=16)
train_loss = losses.TripletLoss(
    model=embedder,
    distance_metric=losses.TripletDistanceMetric.COSINE,
    triplet_margin=0.3
)

num_epochs = 2
warmup_steps = int(len(train_dataloader) * num_epochs * 0.1)
optimizer_params = {'lr': 2e-5}

# Train with IR evaluator
embedder.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=retrieval_evaluator,
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    optimizer_params=optimizer_params,
    show_progress_bar=True,
    output_path="fine_tuned_sbert_urdu_passages"
)

print("✅ Fine-tuning complete. Using in-memory fine-tuned 'embedder' (no reload).")


Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: (1) Create a W&B account
[34m[1mwandb[0m: (2) Use an existing W&B account
[34m[1mwandb[0m: (3) Don't visualize my results
[34m[1mwandb[0m: Enter your choice:

 


[34m[1mwandb[0m: Enter your choice:

 2


[34m[1mwandb[0m: You chose 'Use an existing W&B account'
[34m[1mwandb[0m: Logging into https://api.wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: Find your API key here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmwaqarsaleem1[0m ([33mmwaqarsaleem1-national-university-of-computing-and-emerg[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Val Ir Passages Cosine Accuracy@1,Val Ir Passages Cosine Accuracy@3,Val Ir Passages Cosine Accuracy@5,Val Ir Passages Cosine Accuracy@10,Val Ir Passages Cosine Precision@1,Val Ir Passages Cosine Precision@3,Val Ir Passages Cosine Precision@5,Val Ir Passages Cosine Precision@10,Val Ir Passages Cosine Recall@1,Val Ir Passages Cosine Recall@3,Val Ir Passages Cosine Recall@5,Val Ir Passages Cosine Recall@10,Val Ir Passages Cosine Ndcg@10,Val Ir Passages Cosine Mrr@10,Val Ir Passages Cosine Map@100
25,No log,No log,0.8,0.9,0.95,0.95,0.8,0.4,0.26,0.14,0.575,0.775,0.825,0.875,0.816226,0.8625,0.767717
50,No log,No log,0.85,0.9,0.9,1.0,0.85,0.4,0.25,0.14,0.6,0.775,0.8,0.875,0.821244,0.88,0.774146


✅ Fine-tuning complete. Using in-memory fine-tuned 'embedder' (no reload).


In [11]:
# Cell 8b: Rebuild FAISS with fine-tuned in-memory embedder

import faiss

# Use the current in-memory fine-tuned 'embedder'
corpus_texts = [p["text"] for p in passages_min]
passage_embeddings = embedder.encode(corpus_texts, show_progress_bar=True, convert_to_numpy=True)

faiss.normalize_L2(passage_embeddings)
index = faiss.IndexFlatIP(passage_embeddings.shape[1])
index.add(passage_embeddings)

print("✅ FAISS rebuilt with fine-tuned embeddings (in-memory model).")


Batches:   0%|          | 0/2 [00:00<?, ?it/s]

✅ FAISS rebuilt with fine-tuned embeddings (in-memory model).


In [12]:
# Cell 9 (final): Retriever wrapper with true fusion modes (non-destructive)
# - Creates bm25_new only if not present
# - Supports modes: 'bm25', 'dense', 'hybrid_interleave' (legacy), 'hybrid_score', 'hybrid_rrf'
# - Returns list of (pid, text, score) tuples
# - Does NOT rebuild or overwrite dense/FAISS objects

from datetime import datetime
import numpy as np
import re

# ---------- Config ----------
# Tune these later on a small validation set
DEFAULT_RETRIEVE_POOL = 50
SCORE_FUSION_ALPHA = 0.6   # alpha in [0,1] for score fusion: alpha * dense + (1-alpha) * bm25
RRF_K = 60                 # reciprocal rank fusion constant

# ---------- Sanity checks for canonical corpus ----------
assert 'passages_min' in globals() and isinstance(passages_min, list) and len(passages_min) > 0, "passages_min must be loaded"
assert 'pid2text' in globals() and isinstance(pid2text, dict) and len(pid2text) > 0, "pid2text must be available"
assert 'dense_retrieve' in globals(), "dense_retrieve wrapper must be defined (fine-tuned dense retriever)"

# ---------- Build or reuse BM25 index (non-destructive) ----------
try:
    # If bm25_new already exists from a previous run, reuse it
    bm25_new  # noqa: F821
except Exception:
    try:
        from rank_bm25 import BM25Okapi
    except Exception:
        import sys, subprocess
        subprocess.run([sys.executable, "-m", "pip", "install", "rank_bm25"], check=True)
        from rank_bm25 import BM25Okapi

    # Build tokenized corpus from passages_min (light normalization)
    def _normalize_for_bm25(s: str) -> str:
        if s is None:
            return ""
        s = s.replace("\u200c", " ")  # zero-width non-joiner
        s = re.sub(r"\s+", " ", s).strip()
        return s.lower()

    bm25_new_ids = [p["id"] for p in passages_min]
    bm25_new_texts = [p["text"] for p in passages_min]
    bm25_new_tokenized = [_normalize_for_bm25(t).split() for t in bm25_new_texts]
    bm25_new = BM25Okapi(bm25_new_tokenized)

# Safe wrapper for BM25 that returns (pid, text, score)
def bm25_new_retrieve(query: str, k: int = 5):
    q_tok = _normalize_for_bm25(query).split()
    scores = bm25_new.get_scores(q_tok)
    top_idx = np.argsort(scores)[::-1][:k]
    results = []
    for i in top_idx:
        i = int(i)
        pid = bm25_new_ids[i]
        text = bm25_new_texts[i]
        score = float(scores[i])
        results.append((pid, text, score))
    return results

# ---------- Fusion utilities ----------
def normalize_scores(score_map):
    """Min-max normalize a dict of scores to [0,1]."""
    if not score_map:
        return {}
    vals = list(score_map.values())
    lo, hi = min(vals), max(vals)
    if hi == lo:
        return {k: 1.0 for k in score_map}
    return {k: (v - lo) / (hi - lo) for k, v in score_map.items()}

def rrf_rank(dense_list, bm25_list, k_rrf=RRF_K):
    """Reciprocal Rank Fusion: returns sorted list of pids by RRF score."""
    score = {}
    for rank, (pid, _, _) in enumerate(dense_list, start=1):
        score[pid] = score.get(pid, 0.0) + 1.0 / (k_rrf + rank)
    for rank, (pid, _, _) in enumerate(bm25_list, start=1):
        score[pid] = score.get(pid, 0.0) + 1.0 / (k_rrf + rank)
    sorted_pids = sorted(score.keys(), key=lambda p: score[p], reverse=True)
    return sorted_pids, score

# ---------- Metadata filter helper (unchanged semantics) ----------
# If you have meta_map from corpus_clean, it will be used; otherwise fallback to passages_min metadata
if 'corpus_clean' in globals():
    meta_map = {p["id"]: p for p in corpus_clean}
else:
    meta_map = {p["id"]: p for p in passages_min}

def filter_by_metadata(candidate_ids, min_date=None, max_date=None, allowed_sources=None, exclude_time_sensitive=None):
    out = []
    for pid in candidate_ids:
        m = meta_map.get(pid, {})
        ok = True
        if min_date or max_date:
            dt = None
            if "retrieved_at" in m:
                try:
                    dt = datetime.fromisoformat(m["retrieved_at"])
                except Exception:
                    dt = None
            if dt:
                if min_date and dt < min_date: ok = False
                if max_date and dt > max_date: ok = False
        if allowed_sources and m.get("source") not in allowed_sources:
            ok = False
        if exclude_time_sensitive is not None and m.get("time_sensitive") == exclude_time_sensitive:
            ok = False
        if ok:
            out.append(pid)
    return out

# ---------- Main retrieve wrapper with fusion modes ----------
def retrieve(query: str, k: int = 5, mode: str = "hybrid_score", min_date=None, max_date=None, allowed_sources=None, exclude_time_sensitive=None):
    """
    retrieve(query, k, mode)
    Modes:
      - 'bm25' : BM25-only (bm25_new_retrieve)
      - 'dense' : dense-only (dense_retrieve)
      - 'hybrid_interleave' : legacy interleave (dense first, then bm25)
      - 'hybrid_score' : score fusion (normalized dense + bm25)
      - 'hybrid_rrf' : reciprocal rank fusion (RRF)
    Returns: list of (pid, text, score)
    """
    # Get candidate pools (pool size configurable)
    pool = max(DEFAULT_RETRIEVE_POOL, k)
    dense_hits = dense_retrieve(query, k=pool)   # expected (pid, text, score)
    bm25_hits = bm25_new_retrieve(query, k=pool) # (pid, text, score)

    # Mode-specific behavior
    if mode == "bm25":
        results = bm25_hits[:k]
    elif mode == "dense":
        results = dense_hits[:k]
    elif mode == "hybrid_interleave":
        # preserve dense-first interleaving (legacy behavior)
        seen = set()
        cands = []
        for lst in (dense_hits, bm25_hits):
            for pid, text, score in lst:
                if pid not in seen:
                    seen.add(pid)
                    cands.append((pid, text, float(score)))
        results = cands[:k]
    elif mode == "hybrid_score":
        # Score fusion: normalize and combine
        dense_scores = {pid: sc for pid, _, sc in dense_hits}
        bm25_scores = {pid: sc for pid, _, sc in bm25_hits}
        d_norm = normalize_scores(dense_scores)
        b_norm = normalize_scores(bm25_scores)
        alpha = SCORE_FUSION_ALPHA
        combined = {}
        for pid in set(list(d_norm.keys()) + list(b_norm.keys())):
            combined[pid] = alpha * d_norm.get(pid, 0.0) + (1 - alpha) * b_norm.get(pid, 0.0)
        # sort by combined score
        sorted_pids = sorted(combined.keys(), key=lambda p: combined[p], reverse=True)
        results = []
        for pid in sorted_pids[:k]:
            text = pid2text.get(pid, next((p["text"] for p in passages_min if p["id"] == pid), ""))
            results.append((pid, text, float(combined[pid])))
    elif mode == "hybrid_rrf":
        sorted_pids, score_map = rrf_rank(dense_hits, bm25_hits, k_rrf=RRF_K)
        results = []
        for pid in sorted_pids[:k]:
            text = pid2text.get(pid, next((p["text"] for p in passages_min if p["id"] == pid), ""))
            results.append((pid, text, float(score_map.get(pid, 0.0))))
    else:
        raise ValueError(f"Unknown retrieve mode: {mode}")

    # Apply metadata filters if requested (filter by pid only)
    if any([min_date, max_date, allowed_sources, exclude_time_sensitive is not None]):
        filtered_ids = filter_by_metadata([pid for pid,_,_ in results], min_date, max_date, allowed_sources, exclude_time_sensitive)
        results = [(pid, pid2text.get(pid, ""), score) for pid,_,score in results if pid in filtered_ids]

    return results

# ---------- Quick sample test (safe) ----------
q = eval_queries[0]["query"] if 'eval_queries' in globals() and len(eval_queries)>0 else "کووڈ-19 کی عام علامات کیا ہیں؟"
print("Sample dense top-5 ids:", [r[0] for r in dense_retrieve(q, k=5)])
print("Sample bm25_new top-5 ids:", [r[0] for r in bm25_new_retrieve(q, k=5)])
print("Sample hybrid_score top-5 ids:", [r[0] for r in retrieve(q, k=5, mode='hybrid_score')])
print("Sample hybrid_rrf top-5 ids:", [r[0] for r in retrieve(q, k=5, mode='hybrid_rrf')])


Sample dense top-5 ids: ['p0024', 'p0001', 'p0021', 'p0044', 'p0008']
Sample bm25_new top-5 ids: ['p0001', 'p0024', 'p0002', 'p0044', 'p0021']
Sample hybrid_score top-5 ids: ['p0001', 'p0024', 'p0044', 'p0021', 'p0002']
Sample hybrid_rrf top-5 ids: ['p0024', 'p0001', 'p0021', 'p0044', 'p0002']


In [13]:
# Cell 9b: Run after above cell 9. Cell 9 creates B2M5 + Dense hybrid and below
# cell evaluates its performance:
# Validation diagnostics — Recall@1, Recall@5, MRR, Precision@k for retrievers
# - Works with any mode supported by your Cell 9 wrapper: 'bm25', 'dense', 'hybrid_interleave', 'hybrid_score', 'hybrid_rrf'
# - Calls retrieve(...) and computes retrieval metrics
# - Outputs summary metrics and a few examples of misses

import time, statistics
from tqdm.auto import tqdm

DEFAULT_K = 5
RECALL_KS = [1, 5]
PRECISION_KS = [1, 5]

def evaluate_retriever(eval_items, mode="hybrid_score", k=DEFAULT_K,
                       recall_ks=RECALL_KS, precision_ks=PRECISION_KS):
    per_query = []
    latencies = []
    rr_list = []
    recall_counts = {rk: 0 for rk in recall_ks}
    precision_sums = {pk: 0.0 for pk in precision_ks}
    total = 0

    for item in tqdm(eval_items, desc=f"Evaluating {mode} retriever"):
        total += 1
        q = item.get("query") or item.get("question") or item.get("q") or ""
        positive_ids = item.get("positive_ids") or item.get("positive_id") or []
        if isinstance(positive_ids, str):
            positive_ids = [positive_ids]
        positive_ids = [str(x) for x in positive_ids]

        t0 = time.time()
        try:
            hits = retrieve(q, k=k, mode=mode)   # (pid, text, score)
        except Exception as e:
            hits = []
            print(f"[eval] retrieve error for query {q[:60]}... -> {e}")
        latency = time.time() - t0
        latencies.append(latency)

        retrieved_ids = [r[0] for r in hits]

        # Reciprocal rank: first position among positives
        rr = 0.0
        for rank, pid in enumerate(retrieved_ids, start=1):
            if pid in positive_ids:
                rr = 1.0 / rank
                break
        rr_list.append(rr)

        # Recall@k and Precision@k
        for rk in recall_ks:
            recall_counts[rk] += 1 if any(pid in positive_ids for pid in retrieved_ids[:rk]) else 0
        for pk in precision_ks:
            num_pos_in_topk = sum(1 for pid in retrieved_ids[:pk] if pid in positive_ids)
            precision_sums[pk] += (num_pos_in_topk / pk)

        per_query.append({
            "query": q,
            "positive_ids": positive_ids,
            "retrieved_ids": retrieved_ids,
            "reciprocal_rank": rr,
            "latency": latency
        })

    n = total if total else 1
    mrr = sum(rr_list) / n
    recall_at = {rk: recall_counts[rk] / n for rk in recall_ks}
    precision_at = {pk: precision_sums[pk] / n for pk in precision_ks}
    latency_mean = statistics.mean(latencies) if latencies else 0.0
    latency_median = statistics.median(latencies) if latencies else 0.0

    summary = {
        "n_queries": n,
        "MRR": mrr,
        **{f"Recall@{rk}": recall_at[rk] for rk in recall_ks},
        **{f"Precision@{pk}": precision_at[pk] for pk in precision_ks},
        "latency_mean_s": latency_mean,
        "latency_median_s": latency_median
    }

    return summary, per_query

# ---------- Run evaluation ----------
# Use eval_queries_val if defined, else fall back to eval_queries
eval_items = eval_queries_val if 'eval_queries_val' in globals() else eval_queries

# Evaluate all retriever modes
modes = ["bm25", "dense", "hybrid_interleave", "hybrid_score", "hybrid_rrf"]
results = {}
for m in modes:
    summary, records = evaluate_retriever(eval_items, mode=m, k=DEFAULT_K)
    results[m] = summary
    print(f"\n{m} retriever evaluation summary:")
    for k,v in summary.items():
        print(f"  {k}: {v:.3f}" if isinstance(v,float) else f"  {k}: {v}")

    # Show a few misses
    misses = [r for r in records if r["reciprocal_rank"] == 0.0]
    print(f"  Total misses: {len(misses)} / {len(records)}. Showing up to 3 misses:")
    for r in misses[:3]:
        print("   Query:", r["query"][:80])
        print("    Positives:", r["positive_ids"])
        print("    Retrieved top ids:", r["retrieved_ids"][:8])


Evaluating bm25 retriever:   0%|          | 0/100 [00:00<?, ?it/s]


bm25 retriever evaluation summary:
  n_queries: 100
  MRR: 0.880
  Recall@1: 0.830
  Recall@5: 0.950
  Precision@1: 0.830
  Precision@5: 0.216
  latency_mean_s: 0.033
  latency_median_s: 0.033
  Total misses: 5 / 100. Showing up to 3 misses:
   Query: کووڈ-19 ویکسین کا بنیادی مقصد کیا ہے؟
    Positives: ['p0007']
    Retrieved top ids: ['p0028', 'p0050', 'p0051', 'p0027', 'p0039']
   Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
    Positives: ['p0020']
    Retrieved top ids: ['p0017', 'p0060', 'p0031', 'p0048', 'p0027']
   Query: ویکسین سائیڈ ایفیکٹس کی نگرانی کیسے کی جاتی ہے؟
    Positives: ['p0039']
    Retrieved top ids: ['p0058', 'p0040', 'p0032', 'p0051', 'p0011']


Evaluating dense retriever:   0%|          | 0/100 [00:00<?, ?it/s]


dense retriever evaluation summary:
  n_queries: 100
  MRR: 0.857
  Recall@1: 0.780
  Recall@5: 0.940
  Precision@1: 0.780
  Precision@5: 0.214
  latency_mean_s: 0.029
  latency_median_s: 0.030
  Total misses: 6 / 100. Showing up to 3 misses:
   Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
    Positives: ['p0020']
    Retrieved top ids: ['p0005', 'p0035', 'p0043', 'p0057', 'p0022']
   Query: ویکسین کی افادیت وقت کے ساتھ کیوں کم ہو سکتی ہے؟
    Positives: ['p0032']
    Retrieved top ids: ['p0008', 'p0003', 'p0025', 'p0045', 'p0052']
   Query: ویکسین کے خلاف جھجک کم کرنے کے عملی طریقے کیا ہیں؟
    Positives: ['p0037', 'p0056']
    Retrieved top ids: ['p0007', 'p0008', 'p0052', 'p0060', 'p0011']


Evaluating hybrid_interleave retriever:   0%|          | 0/100 [00:00<?, ?it/s]


hybrid_interleave retriever evaluation summary:
  n_queries: 100
  MRR: 0.857
  Recall@1: 0.780
  Recall@5: 0.940
  Precision@1: 0.780
  Precision@5: 0.214
  latency_mean_s: 0.011
  latency_median_s: 0.010
  Total misses: 6 / 100. Showing up to 3 misses:
   Query: وینٹیلیشن وبا کے دوران کیوں اہم ہے؟
    Positives: ['p0020']
    Retrieved top ids: ['p0005', 'p0035', 'p0043', 'p0057', 'p0022']
   Query: ویکسین کی افادیت وقت کے ساتھ کیوں کم ہو سکتی ہے؟
    Positives: ['p0032']
    Retrieved top ids: ['p0008', 'p0003', 'p0025', 'p0045', 'p0052']
   Query: ویکسین کے خلاف جھجک کم کرنے کے عملی طریقے کیا ہیں؟
    Positives: ['p0037', 'p0056']
    Retrieved top ids: ['p0007', 'p0008', 'p0052', 'p0060', 'p0011']


Evaluating hybrid_score retriever:   0%|          | 0/100 [00:00<?, ?it/s]


hybrid_score retriever evaluation summary:
  n_queries: 100
  MRR: 0.928
  Recall@1: 0.860
  Recall@5: 1.000
  Precision@1: 0.860
  Precision@5: 0.230
  latency_mean_s: 0.011
  latency_median_s: 0.010
  Total misses: 0 / 100. Showing up to 3 misses:


Evaluating hybrid_rrf retriever:   0%|          | 0/100 [00:00<?, ?it/s]


hybrid_rrf retriever evaluation summary:
  n_queries: 100
  MRR: 0.894
  Recall@1: 0.820
  Recall@5: 0.990
  Precision@1: 0.820
  Precision@5: 0.228
  latency_mean_s: 0.011
  latency_median_s: 0.010
  Total misses: 1 / 100. Showing up to 3 misses:
   Query: کووڈ-19 ویکسین کا بنیادی مقصد کیا ہے؟
    Positives: ['p0007']
    Retrieved top ids: ['p0051', 'p0033', 'p0039', 'p0028', 'p0040']


In [None]:
# Cell 6c: Report top-3 from current fine-tuned dense retriever (no re-init)
print("Dense top-3 (fine-tuned):", dense_retrieve(eval_queries[0]["query"], k=3))


In [14]:
# Cell 9b (fixed): Validation diagnostics — Recall@1 and Recall@5 on validation examples for Dense Model only
import numpy as np
from tqdm.auto import tqdm

# Map passage text back to IDs
text2pid = {p["text"]: p["id"] for p in passages_min}

# Build validation pairs (query, gold passage ID)
val_pairs = []
for ex in val_examples[:200]:  # cap to 200 for speed; remove cap for full set
    q = ex.texts[0]
    pos_text = ex.texts[1]
    gold_pid = text2pid.get(pos_text)
    if gold_pid:
        val_pairs.append((q, gold_pid))

def recall_at_k_pairs(pairs, k=1, mode="dense"):
    hits = 0
    for q, gold_pid in tqdm(pairs):
        # retrieve(...) returns (pid, text, score) tuples; take the first element as pid
        retrieved = [r[0] for r in retrieve(q, k=k, mode=mode)]
        if gold_pid in retrieved:
            hits += 1
    return hits / len(pairs) if pairs else 0.0

r1 = recall_at_k_pairs(val_pairs, k=1, mode="dense")
r5 = recall_at_k_pairs(val_pairs, k=5, mode="dense")
print(f"Validation Recall@1 (dense): {r1:.3f}")
print(f"Validation Recall@5 (dense): {r5:.3f}")


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Validation Recall@1 (dense): 0.650
Validation Recall@5 (dense): 0.970


In [None]:
# Optional Cell:
from sentence_transformers import SentenceTransformer
embedder = SentenceTransformer("fine_tuned_sbert_urdu").to("cuda")


In [None]:
# Optional cell - only need to be run if generator model not present as unzipped in drive

# 1. Mount Drive (if not already mounted)
from google.colab import drive
drive.mount('/content/drive')

# 2. Paths — update to your actual Drive path
DRIVE_ZIP = "/content/drive/MyDrive/mbart_urdu_covid.zip"   # path to the zip in Drive
UNZIP_TARGET = "/content/drive/MyDrive/fine_tuned_mbart_urdu"  # folder to create in Drive

# 3. Make sure target folder does not already exist (optional safety)
import os, shutil
if os.path.exists(UNZIP_TARGET):
    print("Warning: target folder already exists:", UNZIP_TARGET)
else:
    # 4. Copy zip into runtime (optional) and unzip directly into Drive
    !unzip -q "{DRIVE_ZIP}" -d "{UNZIP_TARGET}"
    print("Unzipped to:", UNZIP_TARGET)

# 5. List files to confirm
for root, dirs, files in os.walk(UNZIP_TARGET):
    print(root)
    print("  dirs:", dirs)
    print("  files:", files[:10])
    break


In [15]:
# Cell 10: Load fine-tuned MBART generator from Drive (or local path)
# - Paste this cell into your RAG notebook before the RAG inference cell.
# - Assumes you have mounted Google Drive earlier in the notebook:
#     from google.colab import drive
#     drive.mount('/content/drive')
# - Replace `MBART_ARCHIVE_PATH` with the folder path that contains the saved model
#   (the folder should contain config.json, model.safetensors, tokenizer files, etc.)

import os
import torch
from transformers import MBartForConditionalGeneration, MBart50Tokenizer # Changed to MBart50Tokenizer

# ---------- Configuration ----------
# Path to the unzipped fine-tuned model folder (update to your Drive path)
# Corrected path to the nested directory where the model files are actually located
MBART_ARCHIVE_PATH = "/content/drive/MyDrive/fine_tuned_mbart_urdu/fine_tuned_mbart_urdu"  # <- Corrected path

# Device
device = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Load tokenizer and model ----------
print(f"Loading MBART generator from: {MBART_ARCHIVE_PATH}")
tokenizer = MBart50Tokenizer.from_pretrained(MBART_ARCHIVE_PATH) # Changed to MBart50Tokenizer
# Ensure tokenizer uses Urdu language code for both source and target
tokenizer.src_lang = "ur_PK"
tokenizer.tgt_lang = "ur_PK"

gen_model = MBartForConditionalGeneration.from_pretrained(MBART_ARCHIVE_PATH).to(device)
gen_model.eval()

# Optional: set forced BOS token to ensure Urdu generation if not already set
if "ur_PK" in tokenizer.lang_code_to_id:
    gen_model.config.forced_bos_token_id = tokenizer.lang_code_to_id["ur_PK"]

print(f"✅ MBART loaded on {device}. Tokenizer and model ready for generation.")

Loading MBART generator from: /content/drive/MyDrive/fine_tuned_mbart_urdu/fine_tuned_mbart_urdu
✅ MBART loaded on cuda. Tokenizer and model ready for generation.


In [16]:
# Cell 10b (memory-optimized): Fine-tune MBART generator for RAG-style inputs
# - Uses eval_queries.jsonl as training data
# - Saves fine-tuned model to OUTPUT_DIR on Drive
# - Adjusted for Colab GPU memory limits

import os, json, torch, gc
from pathlib import Path
from datasets import Dataset
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq

# Clear GPU cache before training
gc.collect()
torch.cuda.empty_cache()

# ---------- Config ----------
DATA_DIR = Path("/content/drive/MyDrive/data")
TRAIN_JSONL = DATA_DIR / "eval_queries.jsonl"
OUTPUT_DIR = Path("/content/drive/MyDrive/models/mbart_rag_finetuned")

# Smaller lengths to reduce memory
MAX_INPUT_LENGTH = 256
MAX_TARGET_LENGTH = 64
BATCH_SIZE = 1              # keep tiny batch size
EPOCHS = 2                  # fewer epochs to fit memory
LEARNING_RATE = 2e-5
GRAD_ACCUM_STEPS = 8        # simulate larger batch via accumulation

# ---------- Preconditions ----------
assert 'tokenizer' in globals() and 'gen_model' in globals(), "tokenizer and gen_model must be loaded (Cell 10)"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Training device:", device)

# Enable gradient checkpointing to save memory
gen_model.gradient_checkpointing_enable()

# ---------- Load training examples ----------
train_examples = []
with open(TRAIN_JSONL, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        if not line:
            continue
        try:
            train_examples.append(json.loads(line))
        except Exception:
            continue

assert len(train_examples) > 0, "No training examples found in eval_queries.jsonl"

# ---------- Helper to build RAG-style input text ----------
def build_rag_input(example, top_k=3, mode="hybrid_score"):
    q = example.get("query") or example.get("question") or example.get("q")
    hits = retrieve(q, k=top_k, mode=mode)
    passages = [f"[حوالہ] {ptext}" for _, ptext, _ in hits]
    context = "\n\n".join(passages)
    instruction = "حوالہ شدہ معلومات کی بنیاد پر مختصر اور درست جواب لکھیں۔"
    return f"سوال: {q}\n\nحوالہ شدہ معلومات:\n{context}\n\n{instruction}"

# ---------- Build dataset ----------
inputs, targets = [], []
for ex in train_examples:
    inp = build_rag_input(ex, top_k=3, mode="hybrid_score")
    tgt = ex.get("gold_answer") or ""
    if not tgt:
        continue
    inputs.append(inp)
    targets.append(tgt)

assert len(inputs) > 0, "No valid (input,target) pairs constructed"

def tokenize_batch(batch):
    model_inputs = tokenizer(batch["input"], max_length=MAX_INPUT_LENGTH, truncation=True, padding="max_length")
    labels = tokenizer(batch["target"], max_length=MAX_TARGET_LENGTH, truncation=True, padding="max_length", text_target=batch["target"])
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

hf_ds = Dataset.from_dict({"input": inputs, "target": targets})
hf_ds = hf_ds.map(tokenize_batch, batched=True, remove_columns=["input", "target"])

# ---------- Training arguments ----------
training_args = Seq2SeqTrainingArguments(
    output_dir=str(OUTPUT_DIR),
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    fp16=torch.cuda.is_available(),
    save_total_limit=1,
    save_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    predict_with_generate=False,
    remove_unused_columns=True,
    push_to_hub=False,
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=gen_model, label_pad_token_id=tokenizer.pad_token_id)

# ---------- Trainer ----------
trainer = Seq2SeqTrainer(
    model=gen_model,
    args=training_args,
    train_dataset=hf_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("Starting fine-tuning (memory-optimized).")
trainer.train()

# ---------- Save fine-tuned model ----------
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
trainer.save_model(str(OUTPUT_DIR))
tokenizer.save_pretrained(str(OUTPUT_DIR))
print(f"Fine-tuned generator saved to {OUTPUT_DIR}")


Training device: cuda


Map:   0%|          | 0/100 [00:00<?, ? examples/s]

  trainer = Seq2SeqTrainer(


Starting fine-tuning (memory-optimized).


Step,Training Loss




Fine-tuned generator saved to /content/drive/MyDrive/models/mbart_rag_finetuned


In [17]:
# Cell 10c: Evaluate fine-tuned generator alone (no retrieval context)
# - Loads eval queries and gold answers
# - Generates answers directly from the fine-tuned MBART model
# - Computes BLEU, chrF, token-level F1, latency
# - Provides summary metrics for comparison with RAG and retrievers

import time, statistics
import numpy as np
import sacrebleu

# ---------- Preconditions ----------
assert 'tokenizer' in globals() and 'gen_model' in globals(), "Fine-tuned generator must be loaded (Cell 10b)"
assert 'eval_queries' in globals(), "eval_queries must be loaded"

# ---------- Helper: token-level F1 ----------
def token_f1(pred, refs):
    def toks(s):
        return [t for t in s.strip().split() if t]
    pred_t = set(toks(pred))
    if not refs:
        return 0.0
    best = 0.0
    for r in refs:
        ref_t = set(toks(r))
        if not ref_t:
            continue
        tp = len(pred_t & ref_t)
        prec = tp / max(1, len(pred_t))
        rec = tp / max(1, len(ref_t))
        f1 = 0.0 if (prec + rec) == 0 else (2 * prec * rec) / (prec + rec)
        best = max(best, f1)
    return best

# ---------- Build references ----------
def get_references_for_item(item):
    refs = []
    if "gold_answer" in item and item["gold_answer"]:
        refs = [item["gold_answer"]]
    elif "answers" in item and item["answers"]:
        val = item["answers"]
        if isinstance(val, str):
            refs = [val]
        elif isinstance(val, list):
            refs = [x for x in val if isinstance(x, str) and x.strip()]
    return [r.strip() for r in refs if r and r.strip()]

# ---------- Evaluation loop ----------
preds = []
refs_all = []
f1s = []
latencies = []

for it in eval_queries:
    q = it["query"]
    refs = get_references_for_item(it)

    # Build simple prompt: question only (no retrieval context)
    input_text = f"سوال: {q}\n\nبراہ کرم مختصر اور درست جواب دیں۔"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=256, truncation=True).to(gen_model.device)

    t0 = time.time()
    outputs = gen_model.generate(**inputs, max_length=64, num_beams=4)
    latency = time.time() - t0
    latencies.append(latency)

    pred = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    preds.append(pred)
    refs_all.append(refs if refs else [""])
    f1s.append(token_f1(pred, refs))

# ---------- BLEU / chrF ----------
max_refs = max(len(r) for r in refs_all) if refs_all else 1
ref_sets = []
for i in range(max_refs):
    ref_sets.append([ (refs[i] if i < len(refs) else "") for refs in refs_all ])

bleu = sacrebleu.corpus_bleu(preds, ref_sets)
chrf = sacrebleu.corpus_chrf(preds, ref_sets)

# ---------- Summary ----------
summary_gen = {
    "BLEU": float(bleu.score),
    "chrF": float(chrf.score),
    "F1_mean": float(np.mean(f1s)) if f1s else 0.0,
    "F1_median": float(statistics.median(f1s)) if f1s else 0.0,
    "latency_mean_s": float(np.mean(latencies)) if latencies else 0.0,
    "latency_median_s": float(statistics.median(latencies)) if latencies else 0.0,
    "n": len(eval_queries)
}

print("\n=== Generator-only Evaluation (no retrieval) ===")
for k,v in summary_gen.items():
    print(f"- {k}: {v}")


`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`...
Caching is incompatible with gradient checkpointing in MBartDecoderLayer. Setting `past_key_values=None`.



=== Generator-only Evaluation (no retrieval) ===
- BLEU: 0.0238927947466256
- chrF: 3.3246001371286993
- F1_mean: 0.07335472491045246
- F1_median: 0.10818713450292397
- latency_mean_s: 2.434983456134796
- latency_median_s: 2.2022024393081665
- n: 100


In [None]:
# Optional to clear GPU cache:
import torch, gc
gc.collect()
torch.cuda.empty_cache()


In [None]:
# Optional cell to just check if our imported generative model above performs fine independantly.
# Expectation: Just generate a fluent urdu answer since this is just a smoke test.

prompt = "کووڈ-19 کی عام علامات کیا ہیں؟"
inputs = tokenizer(prompt, return_tensors="pt").to(device)
with torch.no_grad():
    output_ids = gen_model.generate(
        **inputs,
        max_length=64,
        num_beams=4,
        length_penalty=1.0
    )
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))


In [18]:
# Cell 11 (final, fixed): RAG inference — hybrid_score fusion + fine-tuned generator
# - Loads tokenizer from base MBART ("facebook/mbart-large-50") for robust config
# - Loads fine-tuned generator weights from Drive path
# - Sets Urdu language codes (src_lang, forced_bos_token_id)
# - Uses retrieve(..., mode="hybrid_score") by default
# - Strict grounding, intent-aware expansion, corpus rescue, extractive QA fallback

import time
import json
import re
import os
import torch
from typing import List, Tuple
from pathlib import Path

# ---------- Preconditions ----------
assert 'retrieve' in globals(), "retrieve(...) must be defined (Cell 9)"
assert 'pid2text' in globals(), "pid2text must be loaded"
assert 'passages_min' in globals() and len(passages_min) > 0, "passages_min must be loaded"

# ---------- Device ----------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Load tokenizer (base) and fine-tuned generator (weights) ----------
from transformers import MBart50TokenizerFast, MBartForConditionalGeneration

GEN_PATH = "/content/drive/MyDrive/models/mbart_rag_finetuned"
BASE_TOKENIZER = "facebook/mbart-large-50"
print(f"Loading tokenizer from base: {BASE_TOKENIZER}")
tokenizer = MBart50TokenizerFast.from_pretrained(BASE_TOKENIZER)

print(f"Loading fine-tuned generator weights from: {GEN_PATH}")
gen_model = MBartForConditionalGeneration.from_pretrained(GEN_PATH).to(device)
gen_model.eval()

# Set Urdu language codes for MBART
# Use 'ur_PK' (Urdu, Pakistan) if available; fallback to 'ur_IN' if needed
src_lang = "ur_PK" if "ur_PK" in tokenizer.lang_code_to_id else "ur_IN"
tokenizer.src_lang = src_lang
if hasattr(tokenizer, "lang_code_to_id"):
    forced_bos_id = tokenizer.lang_code_to_id.get(src_lang)
    if forced_bos_id is not None:
        # Configure BOS token for generation
        if hasattr(gen_model, "generation_config"):
            gen_model.generation_config.forced_bos_token_id = forced_bos_id
        else:
            gen_model.config.forced_bos_token_id = forced_bos_id

print(f"Fine-tuned generator loaded. src_lang={src_lang}, forced_bos_token_id={getattr(gen_model.config, 'forced_bos_token_id', None)}")

# ---------- Config ----------
DATA_DIR = Path("/content/drive/MyDrive/data")
CORPUS_JSONL_PATH = DATA_DIR / "urdu_covid_corpus_clean.jsonl"
DEFAULT_RETRIEVE_K = 50      # larger candidate pool for rescue/rerank
FINAL_CONTEXT_K = 5          # passages kept for generation
DEFAULT_RETRIEVE_MODE = "hybrid_score"  # use true fusion by default

SYMPTOM_EXPANSION = " علامات بخار کھانسی سانس ذائقہ بو تھکن"
SYMPTOM_TOKENS = {t.lower() for t in ["بخار","کھانسی","سانس","علامات","ذائقہ","بو","تھکن","سانس لینے","سانس پھولنا","گلے"]}

# ---------- Load corpus (keyword rescue; read-only) ----------
def load_corpus_jsonl(path: str):
    docs = []
    try:
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                try:
                    docs.append(json.loads(line))
                except Exception:
                    continue
    except FileNotFoundError:
        docs = []
    return docs

_CORPUS_DOCS = load_corpus_jsonl(str(CORPUS_JSONL_PATH))
print(f"Loaded _CORPUS_DOCS with {len(_CORPUS_DOCS)} items.")

# ---------- Intent detection ----------
INTENT_KEYWORDS = {
    "symptoms": ["علامات", "بخار", "کھانسی", "سانس", "تھکن", "ذائقہ", "بو"],
    "diagnosis": ["PCR", "RT-PCR", "اینٹیجن", "سویب", "ٹیسٹ", "تشخیص"],
    "prevention": ["ماسک", "ہاتھ", "فاصلے", "سینیٹائزر", "صفائی"],
    "vaccination": ["ویکسین", "ٹیکہ", "بوسٹر", "ڈوز"]
}

def infer_intent(query: str) -> str:
    q = query.lower()
    if any(k.lower() in q for k in INTENT_KEYWORDS["diagnosis"]): return "diagnosis"
    if "علامات" in q or any(k.lower() in q for k in INTENT_KEYWORDS["symptoms"]): return "symptoms"
    if any(k.lower() in q for k in INTENT_KEYWORDS["vaccination"]): return "vaccination"
    if any(k.lower() in q for k in INTENT_KEYWORDS["prevention"]): return "prevention"
    return "general"

def expanded_query_for_intent(query: str, intent: str) -> str:
    if intent == "symptoms":
        return query + " " + SYMPTOM_EXPANSION
    return query

# ---------- Retrieval helpers ----------
def dedupe_candidates(cands: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """
    Accepts iterable items:
      - (pid, text)
      - (pid, text, score)
      - any sequence with pid at [0] and text at [1]
    Returns deduplicated list of (pid, text) preserving first-seen order.
    """
    seen = set()
    out = []
    for item in cands:
        try:
            pid = item[0]
            txt = item[1]
        except Exception:
            continue
        if pid in seen:
            continue
        seen.add(pid)
        out.append((pid, txt))
    return out

def filter_retrieved_by_intent(retrieved_candidates: List[Tuple[str, str]], intent: str, keep: int) -> List[Tuple[str, str]]:
    if intent == "general":
        return dedupe_candidates(retrieved_candidates)[:keep]
    intent_keywords = INTENT_KEYWORDS.get(intent, [])
    prioritized, other = [], []
    for pid, text in retrieved_candidates:
        (prioritized if any(keyword.lower() in text.lower() for keyword in intent_keywords) else other).append((pid, text))
    return dedupe_candidates(prioritized + other)[:keep]

# ---------- Keyword rescue ----------
def corpus_keyword_rescue(docs, tokens, limit=10):
    hits = []
    for d in docs:
        txt = d.get("text", "").lower()
        if any(tok in txt for tok in tokens):
            hits.append((d.get("id"), d.get("text")))
            if len(hits) >= limit:
                break
    return hits

# ---------- Grounding utilities ----------
INSTRUCTION_TOKENS = ["براہ کرم", "جواب", "ہدایات", "instruction", "Answer:", "Response:", "حوالہ شدہ معلومات"]

def strip_instruction_echoes(text: str) -> str:
    t = text.strip()
    for tok in INSTRUCTION_TOKENS:
        t = t.replace(tok, "")
    t = re.sub(r"\s+", " ", t).strip()
    return t

def content_words(s: str):
    stop = {"ہے", "ہیں", "میں", "کی", "کے", "اور", "سے", "پر", "کہ", "ہی", "بھی"}
    toks = [w for w in re.split(r"\W+", s) if len(w) >= 3 and w not in stop]
    return set(toks)

def has_overlap_with_context(ans: str, context_texts):
    ans_words = content_words(ans.lower())
    ctx_words = set()
    for txt in context_texts:
        ctx_words |= content_words(txt.lower())
    return len(ans_words & ctx_words) >= 2

def looks_like_echo(query: str, ans: str) -> bool:
    q = re.sub(r"\s+", " ", query).strip()
    a = re.sub(r"\s+", " ", ans).strip()
    shared = os.path.commonprefix([q, a])
    long_overlap = len(shared) >= max(8, int(0.3 * len(q)))
    starts_like_q = a.startswith(q[: max(12, len(q)//2)])
    return long_overlap or starts_like_q

def is_vague(ans: str) -> bool:
    a = ans.strip()
    if len(a) < 12:
        return True
    hedges = ["منحصر", "ممکن", "عام طور", "ضروری", "اہم", "کامیابی"]
    return any(h in a for h in hedges) and len(a.split()) < 16

# ---------- Generation helper (fine-tuned MBART) ----------
def generate_answer_with_mbart(query, retrieved_passages, max_length=160, num_beams=6, intent="general"):
    def short(p, limit=350):
        p = " ".join(p.split())
        return (p[:limit] + "…") if len(p) > limit else p

    context = "\n\n".join([f"[حوالہ] {short(p)}" for _, p in retrieved_passages])

    if intent == "symptoms":
        instruction = (
            "صرف انہی حوالہ شدہ معلومات کی بنیاد پر علامات کی فہرست لکھیں۔ "
            "اضافی مشورے یا عمومی صحت کی معلومات شامل نہ کریں۔ سوال کو دوبارہ نہ لکھیں۔"
        )
    else:
        instruction = (
            "صرف انہی حوالہ شدہ معلومات کی بنیاد پر مختصر اور درست جواب لکھیں۔ "
            "اضافی معلومات شامل نہ کریں اور سوال کو دوبارہ نہ لکھیں۔"
        )

    prompt = f"سوال: {query}\n\nحوالہ شدہ معلومات:\n{context}\n\n{instruction}"

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512).to(device)
    with torch.no_grad():
        outputs = gen_model.generate(
            **inputs,
            max_length=max_length,
            min_length=8,
            num_beams=num_beams,
            length_penalty=1.0,
            repetition_penalty=1.5,
            no_repeat_ngram_size=2,
            do_sample=False,
            early_stopping=True
        )
    ans = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    ans = strip_instruction_echoes(ans)
    if ans and not ans.endswith("۔"):
        ans += "۔"
    return ans, prompt

# ---------- Extractive QA fallback (multi-passage) ----------
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
qa_model_name = "deepset/xlm-roberta-base-squad2"
print(f"Loading extractive QA model: {qa_model_name}")
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name).to(device)
qa_model.eval()
print("Extractive QA model loaded successfully.")

def extractive_answer_multi(query, passages, top_k=3, min_span_len=3, intent="general"):
    best_ans, best_score = "", float("-inf")
    for passage in passages[:top_k]:
        inputs = qa_tokenizer(query, passage, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            outputs = qa_model(**inputs)
        start_logits = outputs.start_logits[0]
        end_logits = outputs.end_logits[0]
        score = float(torch.max(start_logits).item() + torch.max(end_logits).item())
        start = int(torch.argmax(start_logits).item())
        end = int(torch.argmax(end_logits).item()) + 1
        tokens = inputs["input_ids"][0][start:end]
        ans = qa_tokenizer.decode(tokens, skip_special_tokens=True).strip()

        if not ans or len(ans.split()) < min_span_len:
            continue
        if intent == "symptoms":
            lower = ans.lower()
            if not any(tok in lower for tok in SYMPTOM_TOKENS):
                continue

        if score > best_score:
            best_ans, best_score = ans, score
    return best_ans.strip()

# ---------- RAG wrapper ----------
def rag_answer(query, k=5, mode=DEFAULT_RETRIEVE_MODE, metadata_filters=None, max_length=160, num_beams=6, debug=True):
    t0 = time.time()
    intent = infer_intent(query)
    q_for_retrieval = expanded_query_for_intent(query, intent)

    # Start with candidates; prepend p0001 for symptoms if available
    initial_candidates = []
    if intent == "symptoms" and ('pid2text' in globals()) and ('p0001' in pid2text):
        p0001_text = pid2text['p0001']
        initial_candidates.append(('p0001', p0001_text))
        if debug:
            print("DEBUG: Explicitly pre-pending p0001 for symptom query.")

    # Retrieve via fusion retriever
    retrieve_k = DEFAULT_RETRIEVE_K if intent == "symptoms" else max(k, DEFAULT_RETRIEVE_K // 5)
    retrieved = retrieve(q_for_retrieval, k=retrieve_k, mode=mode, **(metadata_filters or {}))
    initial_candidates.extend(retrieved)

    # Deduplicate and intent-filter
    retrieved_raw = dedupe_candidates(initial_candidates)
    retrieved_filtered = filter_retrieved_by_intent(retrieved_raw, intent, keep=FINAL_CONTEXT_K)

    # Corpus rescue (symptoms only) if filtered lacks strong symptom signals
    rescued_hits = []
    if intent == "symptoms":
        has_signal = any(any(tok in txt.lower() for tok in SYMPTOM_TOKENS) for _, txt in retrieved_filtered)
        if not has_signal:
            pool_hits = [(pid, txt) for pid, txt in retrieved_raw if any(tok in txt.lower() for tok in SYMPTOM_TOKENS)]
            if pool_hits:
                retrieved_filtered = dedupe_candidates(pool_hits + retrieved_filtered)[:FINAL_CONTEXT_K]
            else:
                rescued_hits = corpus_keyword_rescue(_CORPUS_DOCS, SYMPTOM_TOKENS, limit=10)
                if rescued_hits:
                    retrieved_filtered = dedupe_candidates(rescued_hits + retrieved_filtered)[:FINAL_CONTEXT_K]

    # Empty context → safe return; no hallucinations
    if not retrieved_filtered:
        latency = time.time() - t0
        return {
            "query": query,
            "answer": "سیاق میں متعلقہ معلومات دستیاب نہیں۔",
            "provenance": [],
            "latency": latency,
            "used_fallback": True,
            "generator_failed": True,
            "intent": intent,
            "debug": {
                "retrieved_top": [{"id": pid, "text": txt[:300]} for pid, txt in retrieved_raw[:k]],
                "retrieved_filtered": [],
                "rescued_hits": [{"id": pid, "text": txt[:300]} for pid, txt in rescued_hits],
                "prompt_context_preview": ""
            }
        }

    # Try generator
    gen_ans = ""
    generator_failed = False
    used_fallback = False
    try:
        gen_ans, prompt_used = generate_answer_with_mbart(
            query, retrieved_filtered, max_length=max_length, num_beams=num_beams, intent=intent
        )
    except Exception:
        gen_ans = ""
        generator_failed = True
        prompt_used = ""

    # Grounding heuristics
    context_texts = [p for _, p in retrieved_filtered]
    if (not gen_ans) or looks_like_echo(query, gen_ans) or is_vague(gen_ans) or not has_overlap_with_context(gen_ans, context_texts):
        generator_failed = True

    # Fallback to extractive QA
    if generator_failed:
        used_fallback = True
        def rephrase_from_question(q, a): return a
        span = extractive_answer_multi(query, context_texts, top_k=min(3, len(context_texts)), intent=intent)
        final_ans = rephrase_from_question(query, span) if span else "کوئی جواب نہیں ملا۔"
    else:
        final_ans = gen_ans

    latency = time.time() - t0

    # Provenance
    if 'corpus_clean' in globals():
        meta_map = {p["id"]: p for p in corpus_clean}
    else:
        meta_map = {p["id"]: p for p in passages_min}
    provenance = []
    for pid, text in retrieved_filtered:
        meta = meta_map.get(pid, {})
        provenance.append({"id": pid, "source": meta.get("source"), "retrieved_at": meta.get("retrieved_at")})

    # Debug
    debug_block = None
    if debug:
        debug_block = {
            "retrieved_top": [{"id": pid, "text": txt[:300]} for pid, txt in retrieved_raw[:min(10, len(retrieved_raw))]],
            "retrieved_filtered": [{"id": pid, "text": txt[:300]} for pid, txt in retrieved_filtered],
            "rescued_hits": [{"id": pid, "text": txt[:300]} for pid, txt in rescued_hits],
            "prompt_context_preview": context_texts[0][:400] if context_texts else "",
            "prompt_used": prompt_used
        }

    return {
        "query": query,
        "answer": final_ans,
        "provenance": provenance,
        "latency": latency,
        "used_fallback": used_fallback,
        "generator_failed": generator_failed,
        "intent": intent,
        "debug": debug_block
    }

# ---------- Quick smoke test ----------
res = rag_answer("کووڈ-19 کی عام علامات کیا ہیں؟", k=5, mode=DEFAULT_RETRIEVE_MODE, debug=True)
print("Answer:", res["answer"])
print("Used fallback:", res["used_fallback"], "Generator failed:", res["generator_failed"], "Intent:", res["intent"])
if res["debug"]:
    print("\n--- Debug: Top retrieved (raw) ---")
    for i, it in enumerate(res["debug"]["retrieved_top"][:5], 1):
        print(f"{i}. [{it['id']}] {it['text']}")
    print("\n--- Debug: Filtered context used for generation ---")
    for i, it in enumerate(res["debug"]["retrieved_filtered"][:5], 1):
        print(f"{i}. [{it['id']}] {it['text']}")
    print("\n--- Debug: Rescued hits (corpus scan) ---")
    for i, it in enumerate(res["debug"]["rescued_hits"][:5], 1):
        print(f"{i}. [{it['id']}] {it['text']}")


Loading tokenizer from base: facebook/mbart-large-50


tokenizer_config.json:   0%|          | 0.00/531 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Loading fine-tuned generator weights from: /content/drive/MyDrive/models/mbart_rag_finetuned
Fine-tuned generator loaded. src_lang=ur_PK, forced_bos_token_id=None
Loaded _CORPUS_DOCS with 60 items.
Loading extractive QA model: deepset/xlm-roberta-base-squad2


tokenizer_config.json:   0%|          | 0.00/79.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.11G [00:00<?, ?B/s]

Extractive QA model loaded successfully.
DEBUG: Explicitly pre-pending p0001 for symptom query.
Answer: عام علامات میں بخار، کھانسی اور سانس لینے میں مشکل شامل ہیں؛۔
Used fallback: False Generator failed: False Intent: symptoms

--- Debug: Top retrieved (raw) ---
1. [p0001] کورونا وائرس مرض 2019 (COVID-19) ایک متعدی بیماری ہے جس کی عام علامات میں بخار، کھانسی اور سانس لینے میں دشواری شامل ہیں۔
2. [p0024] بچوں میں کووڈ-19 عام طور پر ہلکا ہوتا ہے مگر بعض نادر معاملات میں شدید علامات سامنے آ سکتی ہیں؛ بچوں کے لیے مخصوص رہنمائی مختلف ہو سکتی ہے۔
3. [p0008] ویکسین کے عام مضر اثرات میں انجیکشن سائٹ پر درد، ہلکا بخار اور تھکن شامل ہو سکتے ہیں جو عموماً چند دنوں میں ختم ہو جاتے ہیں؛ سنگین ضمنی اثرات نایاب ہیں۔
4. [p0044] کووڈ-19 کے مریضوں کے لیے ریہیب پروگرامز طویل علامات کے انتظام میں مدد دیتے ہیں، جیسے سانس کی ورزشیں اور توانائی کی بحالی کے پروگرام۔
5. [p0021] کووڈ-19 کے بعد بعض افراد میں طویل مدتی علامات (Long COVID) جیسے تھکن، سانس کی تکلیف اور دماغی دھند برقرار رہ سکتی ہیں؛ ریہیب پروگرامز

In [None]:
# Dummy cell so that notebook doesn`t reset
print("Notebook don`t reset!!")

In [None]:
# Quick smoke-run: 10 eval items to sanity-check Cell 12
sample_eval = eval_items[:10]  # eval_items is defined inside Cell 12; if not, use eval_queries[:10]
print("Running quick RAG eval on 10 items (hybrid) to estimate runtime...")
_ = rag_generation_metrics(sample_eval, mode="hybrid", k=5, max_length=120, num_beams=4)
print("Quick run complete.")


In [19]:
# Cell 12: Comprehensive evaluation — retrieval, RAG generation, ablations, human sampling
# - Evaluates bm25, dense, hybrid_interleave, hybrid_score, hybrid_rrf
# - Computes Retrieval: Recall@1, Recall@5, MRR
# - Computes Generation (RAG): BLEU, chrF, token F1, latency, fallback rates
# - Ablation: disable extractive QA fallback; compare retrieval modes
# - Exports summary JSON and human-eval samples

import time
import json
import numpy as np
import statistics
from pathlib import Path

# ---------- Preconditions ----------
assert 'retrieve' in globals(), "retrieve(...) must be defined (Cell 9)"
assert 'rag_answer' in globals(), "rag_answer(...) must be defined (Cell 11)"
assert 'eval_queries' in globals(), "eval_queries must be loaded"
assert isinstance(eval_queries, list) and len(eval_queries) > 0, "eval_queries must be a non-empty list"

# ---------- Config ----------
EVAL_ITEMS = eval_queries_val if 'eval_queries_val' in globals() else eval_queries
MODES = ["bm25", "dense", "hybrid_interleave", "hybrid_score", "hybrid_rrf"]
K_RETRIEVAL = 5
RAG_K = 5
RAG_MAX_LEN = 160
RAG_BEAMS = 6

# Output paths
OUT_DIR = Path("/content/drive/MyDrive/eval_outputs")
OUT_DIR.mkdir(parents=True, exist_ok=True)
SUMMARY_JSON = OUT_DIR / "rag_eval_summary.json"
PER_QUERY_JSON = OUT_DIR / "rag_eval_per_query.jsonl"
HUMAN_SAMPLES_JSON = OUT_DIR / "rag_human_samples.jsonl"

# ---------- Helpers ----------
def get_references_for_item(item):
    refs = []
    if "gold_answer" in item and item["gold_answer"]:
        refs = [item["gold_answer"]]
    elif "answers" in item and item["answers"]:
        val = item["answers"]
        if isinstance(val, str):
            refs = [val]
        elif isinstance(val, list):
            refs = [x for x in val if isinstance(x, str) and x.strip()]
    return [r.strip() for r in refs if r and r.strip()]

def token_f1(pred, refs):
    def toks(s):
        return [t for t in s.strip().split() if t]
    pred_t = set(toks(pred))
    if not refs:
        return 0.0
    best = 0.0
    for r in refs:
        ref_t = set(toks(r))
        if not ref_t:
            continue
        tp = len(pred_t & ref_t)
        prec = tp / max(1, len(pred_t))
        rec = tp / max(1, len(ref_t))
        f1 = 0.0 if (prec + rec) == 0 else (2 * prec * rec) / (prec + rec)
        best = max(best, f1)
    return best

# ---------- Retrieval evaluation ----------
def retrieval_metrics(eval_items, mode="hybrid_score", k=K_RETRIEVAL):
    rr_list = []
    r1 = 0
    r5 = 0
    latencies = []
    for it in eval_items:
        q = it.get("query") or it.get("question") or it.get("q") or ""
        pos_ids = set(map(str, it.get("positive_ids", [])))
        t0 = time.time()
        hits = retrieve(q, k=k, mode=mode)
        latencies.append(time.time() - t0)
        retrieved_ids = [r[0] for r in hits]
        # MRR
        rr = 0.0
        for rank, pid in enumerate(retrieved_ids, start=1):
            if pid in pos_ids:
                rr = 1.0 / rank
                break
        rr_list.append(rr)
        # Recall@1 / Recall@5
        r1 += 1 if (len(retrieved_ids) >= 1 and retrieved_ids[0] in pos_ids) else 0
        r5 += 1 if any(pid in pos_ids for pid in retrieved_ids[:5]) else 0
    n = len(eval_items)
    return {
        "MRR": float(np.mean(rr_list)) if rr_list else 0.0,
        "Recall@1": r1 / n if n else 0.0,
        "Recall@5": r5 / n if n else 0.0,
        "latency_mean_s": float(np.mean(latencies)) if latencies else 0.0,
        "latency_median_s": float(statistics.median(latencies)) if latencies else 0.0,
        "n": n
    }

# ---------- SacreBLEU setup ----------
try:
    import sacrebleu
except ImportError:
    import sys, subprocess
    subprocess.run([sys.executable, "-m", "pip", "install", "sacrebleu"], check=True)
    import sacrebleu

# ---------- RAG generation evaluation ----------
def rag_generation_metrics(eval_items, mode="hybrid_score", k=RAG_K, max_length=RAG_MAX_LEN, num_beams=RAG_BEAMS, allow_fallback=True):
    preds = []
    refs_all = []
    f1s = []
    latencies = []
    fallback_count = 0
    gen_fail_count = 0
    per_query = []

    for it in eval_items:
        q = it["query"]
        refs = get_references_for_item(it)
        res = rag_answer(q, k=k, mode=mode, max_length=max_length, num_beams=num_beams, debug=False)
        # ablation: disable fallback (use generator answer only; if generator_failed, treat as empty)
        if not allow_fallback and res.get("generator_failed"):
            pred = ""
            used_fallback = False
            generator_failed = True
        else:
            pred = res.get("answer", "").strip()
            used_fallback = res.get("used_fallback", False)
            generator_failed = res.get("generator_failed", False)

        preds.append(pred)
        refs_all.append(refs if refs else [""])
        f1s.append(token_f1(pred, refs))
        latencies.append(res.get("latency", 0.0))
        fallback_count += 1 if used_fallback else 0
        gen_fail_count += 1 if generator_failed else 0

        per_query.append({
            "query": q,
            "intent": res.get("intent"),
            "pred": pred,
            "refs": refs,
            "used_fallback": used_fallback,
            "generator_failed": generator_failed,
            "latency_s": res.get("latency", 0.0),
            "provenance": res.get("provenance", [])
        })

    # SacreBLEU / chrF formatting: transpose references to list-of-reference-sets
    max_refs = max(len(r) for r in refs_all) if refs_all else 1
    ref_sets = []
    for i in range(max_refs):
        ref_sets.append([ (refs[i] if i < len(refs) else "") for refs in refs_all ])

    bleu = sacrebleu.corpus_bleu(preds, ref_sets)
    chrf = sacrebleu.corpus_chrf(preds, ref_sets)

    summary = {
        "BLEU": float(bleu.score),
        "chrF": float(chrf.score),
        "F1_mean": float(np.mean(f1s)) if f1s else 0.0,
        "F1_median": float(statistics.median(f1s)) if f1s else 0.0,
        "latency_mean_s": float(np.mean(latencies)) if latencies else 0.0,
        "latency_median_s": float(statistics.median(latencies)) if latencies else 0.0,
        "fallback_rate": fallback_count / len(eval_items) if eval_items else 0.0,
        "generator_fail_rate": gen_fail_count / len(eval_items) if eval_items else 0.0,
        "n": len(eval_items)
    }
    return summary, per_query

# ---------- Human factuality sampling export ----------
def export_human_samples(per_query_records, sample_size=25, path=HUMAN_SAMPLES_JSON):
    # Select diverse samples: prioritize no-fallback, then fallback, mix intents
    # Simple strategy: take first N no-fallback, then fill with fallback cases
    no_fb = [r for r in per_query_records if not r["used_fallback"]]
    fb = [r for r in per_query_records if r["used_fallback"]]
    sample = (no_fb[:sample_size//2]) + (fb[:sample_size - len(no_fb[:sample_size//2])])
    with open(path, "w", encoding="utf-8") as f:
        for r in sample:
            payload = {
                "query": r["query"],
                "intent": r.get("intent"),
                "prediction": r["pred"],
                "references": r["refs"],
                "provenance": r.get("provenance", []),
                "human_judgment": {
                    "factual_consistency": "TBD_true/partial/false",
                    "helpfulness": "TBD_1-5",
                    "notes": ""
                }
            }
            f.write(json.dumps(payload, ensure_ascii=False) + "\n")
    return len(sample)

# ---------- Run retrieval evaluation across modes ----------
retrieval_results = {}
print("=== Retrieval Quality (Recall@1 / Recall@5 / MRR) ===")
for m in MODES:
    metrics = retrieval_metrics(EVAL_ITEMS, mode=m, k=K_RETRIEVAL)
    retrieval_results[m] = metrics
    print(f"- {m}: Recall@1={metrics['Recall@1']:.3f} | Recall@5={metrics['Recall@5']:.3f} | MRR={metrics['MRR']:.3f}")

# ---------- RAG generation & end-to-end evaluation (by mode) ----------
rag_results_by_mode = {}
per_query_all = {}

print("\n=== RAG Generation & End-to-End (with fallback) ===")
for m in MODES:
    summary, per_query = rag_generation_metrics(EVAL_ITEMS, mode=m, k=RAG_K, max_length=RAG_MAX_LEN, num_beams=RAG_BEAMS, allow_fallback=True)
    rag_results_by_mode[m] = summary
    per_query_all[f"{m}_with_fb"] = per_query
    print(f"- {m}: BLEU={summary['BLEU']:.2f}, chrF={summary['chrF']:.2f}, F1_mean={summary['F1_mean']:.3f}, latency_mean_s={summary['latency_mean_s']:.3f}, fallback_rate={summary['fallback_rate']:.2f}")

print("\n=== Ablation: RAG without extractive QA fallback (generator-only on retrieved context) ===")
rag_results_no_fb = {}
for m in MODES:
    summary, per_query = rag_generation_metrics(EVAL_ITEMS, mode=m, k=RAG_K, max_length=RAG_MAX_LEN, num_beams=RAG_BEAMS, allow_fallback=False)
    rag_results_no_fb[m] = summary
    per_query_all[f"{m}_no_fb"] = per_query
    print(f"- {m}: BLEU={summary['BLEU']:.2f}, chrF={summary['chrF']:.2f}, F1_mean={summary['F1_mean']:.3f}, latency_mean_s={summary['latency_mean_s']:.3f}, generator_fail_rate={summary['generator_fail_rate']:.2f}")

# ---------- Compact summary and export ----------
summary = {
    "retrieval": retrieval_results,
    "rag_with_fallback": rag_results_by_mode,
    "rag_without_fallback": rag_results_no_fb,
    "config": {
        "modes": MODES,
        "retrieval_k": K_RETRIEVAL,
        "rag_k": RAG_K,
        "rag_max_length": RAG_MAX_LEN,
        "rag_num_beams": RAG_BEAMS,
        "n_eval": len(EVAL_ITEMS)
    }
}

with open(SUMMARY_JSON, "w", encoding="utf-8") as f:
    json.dump(summary, f, ensure_ascii=False, indent=2)

with open(PER_QUERY_JSON, "w", encoding="utf-8") as f:
    for key, records in per_query_all.items():
        for r in records:
            r_out = dict(r)
            r_out["mode"] = key
            f.write(json.dumps(r_out, ensure_ascii=False) + "\n")

# ---------- Human factuality sampling ----------
sample_count = export_human_samples(per_query_all.get("hybrid_score_with_fb", []), sample_size=25, path=HUMAN_SAMPLES_JSON)

print("\n=== Summary (saved) ===")
print(f"- Summary JSON: {SUMMARY_JSON}")
print(f"- Per-query logs: {PER_QUERY_JSON}")
print(f"- Human samples for factuality (n={sample_count}): {HUMAN_SAMPLES_JSON}")


=== Retrieval Quality (Recall@1 / Recall@5 / MRR) ===
- bm25: Recall@1=0.830 | Recall@5=0.950 | MRR=0.880
- dense: Recall@1=0.780 | Recall@5=0.940 | MRR=0.857
- hybrid_interleave: Recall@1=0.780 | Recall@5=0.940 | MRR=0.857
- hybrid_score: Recall@1=0.860 | Recall@5=1.000 | MRR=0.928
- hybrid_rrf: Recall@1=0.820 | Recall@5=0.990 | MRR=0.894

=== RAG Generation & End-to-End (with fallback) ===
- bm25: BLEU=12.48, chrF=28.41, F1_mean=0.278, latency_mean_s=0.907, fallback_rate=0.29
- dense: BLEU=14.48, chrF=30.60, F1_mean=0.313, latency_mean_s=0.626, fallback_rate=0.32
- hybrid_interleave: BLEU=14.48, chrF=30.60, F1_mean=0.313, latency_mean_s=0.616, fallback_rate=0.32
- hybrid_score: BLEU=14.75, chrF=31.82, F1_mean=0.328, latency_mean_s=0.629, fallback_rate=0.27
- hybrid_rrf: BLEU=14.59, chrF=31.26, F1_mean=0.320, latency_mean_s=0.623, fallback_rate=0.29

=== Ablation: RAG without extractive QA fallback (generator-only on retrieved context) ===
- bm25: BLEU=11.26, chrF=26.12, F1_mean=0.249

In [None]:
# Check that bm25 index object and corpus mapping exist
print("bm25_retrieve exists:", 'bm25_retrieve' in globals())
if 'bm25_index' in globals():
    try:
        print("bm25_index type:", type(bm25_index))
        # If your BM25 implementation exposes doc count:
        print("bm25_index doc count (if available):", getattr(bm25_index, 'doc_count', 'unknown'))
    except Exception as e:
        print("bm25_index introspect error:", e)
else:
    print("bm25_index not found in globals.")


In [None]:
import inspect
if 'bm25_retrieve' in globals():
    print(inspect.getsource(bm25_retrieve))
else:
    print("bm25_retrieve not defined in this session.")


In [None]:
sample_qs = [eval_queries[0]["query"], eval_queries[1]["query"], "کووڈ-19 کی عام علامات کیا ہیں؟"]
for q in sample_qs:
    print("\nQuery:", q)
    try:
        hits = bm25_retrieve(q, k=10)  # adjust if your wrapper signature differs
        for i, (pid, score, txt) in enumerate(hits[:10], 1):
            print(f"{i}. {pid} score={score} text_preview={txt[:120]}")
    except Exception as e:
        print("bm25_retrieve call error:", e)


In [None]:
# If your BM25 corpus mapping is named bm25_corpus or similar, inspect keys
if 'bm25_corpus' in globals():
    sample_keys = list(bm25_corpus.keys())[:10]
    print("bm25_corpus sample keys:", sample_keys)
else:
    print("No bm25_corpus variable found; check your BM25 build step.")
# Check overlap with pid2text
if 'pid2text' in globals() and 'bm25_corpus' in globals():
    overlap = set(pid2text.keys()) & set(bm25_corpus.keys())
    print("Overlap count between pid2text and bm25_corpus:", len(overlap))


In [None]:
q = eval_queries[0]["query"]
print("Dense top-5:", [pid for pid,_,_ in dense_retrieve(q, k=5)])
print("BM25 top-5:", [pid for pid,_,_ in bm25_retrieve(q, k=5)])


In [None]:
# Cell 13: Full evaluation run (may take long). Use max_examples to limit.
results = evaluate_rag(eval_queries, k=5, mode="hybrid", max_examples=100)  # set None to run all
print("Results:", results)


In [None]:
# Cell 14: Ablation experiments: compare retrieval modes and cross-lingual fallback
modes = ["bm25", "dense", "hybrid"]
ablation_results = {}
for mode in modes:
    print("Evaluating mode:", mode)
    ablation_results[mode] = evaluate_rag(eval_queries, k=5, mode=mode, max_examples=100)
print("Ablation summary:", ablation_results)
