In [4]:
# =========================
# MedMCQA RAG (E5 + FAISS + BM25 Hybrid) — FULL RUNNABLE NOTEBOOK CODE
# - KB: train -> (Q-only for retrieval, store Q/A)  ✅ no ABCD problem text in KB
# - Retrieval: Dense (E5) topK -> BM25 fuse -> near-dup rerank -> selective evidence (only add A when very similar)
# - Scoring: default = letter logits (A/B/C/D). Optional = option_text_logprob (stronger but slower)
# =========================

# ---- Cell 0: installs (run once) ----
# !pip install -q faiss-cpu rank-bm25 tqdm transformers

import os, re, json, time, pickle, math, random
from pathlib import Path
from dataclasses import dataclass
from typing import List, Dict, Tuple, Optional

import numpy as np
from tqdm import tqdm

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

import faiss
from rank_bm25 import BM25Okapi

In [3]:
%pip install -U rank-bm25

Collecting rank-bm25
  Using cached rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Using cached rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank-bm25
Successfully installed rank-bm25-0.2.2
Note: you may need to restart the kernel to use updated packages.


In [5]:
# ---- Cell 1: CONFIG ----

# ========== Data paths ==========
MEDMCQA_TRAIN_FILE = "./data/medmcqa/train.json"  # <-- change
MEDMCQA_DEV_FILE   = "./data/medmcqa/dev.json"    # <-- change

# ========== Model (for answering MCQ) ==========
# Example: GPT-2
BASE_MODEL_PATH = "./gpt2"  # or "gpt2" if you use HF hub (requires net/cache)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# ========== Retrieval embedding model ==========
E5_MODEL_NAME = "intfloat/e5-base-v2"  # you already built this successfully

# ========== KB build outputs ==========
KB_JSONL = "./kb/medmcqa_q_only_kb_train.filtered.jsonl"
KB_DIR   = "./rag_cache/medmcqa_q_only_kb_train.filtered_e5"  # will contain kb.index, docs.pkl

# ========== Retrieval hyperparams ==========
K_VEC   = 400    # dense candidates
K_FINAL = 50     # after fusion, keep this many for near-dup rerank
ALPHA   = 0.35   # fusion: alpha*dense + (1-alpha)*bm25. smaller -> more BM25

MIN_SIM_KEEP    = 0.18  # near-dup jaccard threshold to keep
MIN_SIM_FOR_ANS = 0.22  # only include A in evidence if near-dup >= this

EVID_MAX_ITEMS  = 3
CTX_MAX_CHARS   = 600

# ========== Scoring ==========
# "letter" (fast): score A/B/C/D token logits
# "option_logprob" (stronger but slower): score logP(option_text | prompt)
SCORING = "option_logprob"  # change to "option_logprob" if needed

# ========== Output ==========
OUT_DIR = "./eval_out"
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)


In [6]:
# ---- Cell 2: Helpers (dataset + normalization) ----

BAD_ANS = {
  "all of the above", "none of the above",
  "both a and b", "both b and c", "both a and c",
  "a and b", "b and c", "a and c",
  "a & b", "b & c", "a & c",
  "a, b and c", "a, b, c", "a,b and c",
}

def load_json_or_jsonl(path: str):
    p = Path(path)
    text = p.read_text(encoding="utf-8", errors="ignore").strip()
    try:
        obj = json.loads(text)
        return obj
    except json.JSONDecodeError:
        rows = []
        with p.open("r", encoding="utf-8", errors="ignore") as f:
            for line in f:
                line = line.strip()
                if line:
                    rows.append(json.loads(line))
        return rows

def normalize_gold_cop(cop: int) -> str:
    return "ABCD"[int(cop) - 1]

def is_bad_answer(a: str) -> bool:
    if a is None:
        return True
    x = " ".join(str(a).lower().split())
    if x in BAD_ANS:
        return True
    if len(x) <= 2:
        return True
    return False

def normalize_q_for_retrieval(q: str) -> str:
    """Make entities (drug/disease/material) more salient for retrieval."""
    q0 = " ".join(str(q).strip().split())
    ql = q0.lower().strip()

    # mechanism of action of X:
    m = re.match(r"^mechanism of action of\s+(.+?)[\:\?]?$", ql)
    if m:
        x = m.group(1).strip()
        return f"{x} mechanism of action"

    # features/findings of X except
    m = re.match(r"^all of the following are (?:features|findings) of\s+(.+?)\s+except.*$", ql)
    if m:
        x = m.group(1).strip()
        return f"{x} features except"

    return q0

def load_medmcqa_split(path: str) -> List[dict]:
    data = load_json_or_jsonl(path)
    if isinstance(data, dict) and "data" in data and isinstance(data["data"], list):
        data = data["data"]
    assert isinstance(data, list), f"Unexpected format in {path}"

    samples = []
    for i, ex in enumerate(data):
        q = ex.get("question") or ex.get("ques") or ex.get("query")
        if not q:
            continue
        opts = {"A": ex.get("opa"), "B": ex.get("opb"), "C": ex.get("opc"), "D": ex.get("opd")}
        cop = ex.get("cop")
        if cop is None:
            continue
        gold = normalize_gold_cop(cop)
        samples.append({
            "id": ex.get("id", i),
            "question": " ".join(str(q).strip().split()),
            "options": {k: " ".join(str(v).strip().split()) for k, v in opts.items()},
            "gold": gold,
            "raw": ex
        })
    return samples

dev_samples = load_medmcqa_split(MEDMCQA_DEV_FILE)
print("Loaded dev samples:", len(dev_samples))
print("Example:", dev_samples[0]["question"])


Loaded dev samples: 4183
Example: Which of the following is not true for myelinated nerve fibers:


In [7]:
# ---- Cell 3: Build KB JSONL (train -> filtered Q-only retrieval docs) ----
# Output format per line:
# {"id":..., "q":..., "a":..., "text":...}  where text is used for retrieval (Q normalized), evidence uses q/a.

def build_kb_jsonl(train_file: str, out_jsonl: str):
    data = load_json_or_jsonl(train_file)
    if isinstance(data, dict) and "data" in data and isinstance(data["data"], list):
        data = data["data"]
    assert isinstance(data, list), f"Unexpected format in {train_file}"

    outp = Path(out_jsonl)
    outp.parent.mkdir(parents=True, exist_ok=True)

    kept = 0
    skipped_bad = 0
    skipped_missing = 0

    with outp.open("w", encoding="utf-8") as f:
        for i, ex in enumerate(data):
            q = ex.get("question") or ex.get("ques") or ex.get("query")
            if not q:
                skipped_missing += 1
                continue

            options = {"A": ex.get("opa"), "B": ex.get("opb"), "C": ex.get("opc"), "D": ex.get("opd")}
            cop = ex.get("cop")
            if cop is None:
                skipped_missing += 1
                continue

            gold = normalize_gold_cop(cop)
            a = options.get(gold)

            if not a or is_bad_answer(a):
                skipped_bad += 1
                continue

            q_clean = " ".join(str(q).strip().split())
            a_clean = " ".join(str(a).strip().split())

            doc = {
                "id": ex.get("id", i),
                "q": q_clean,
                "a": a_clean,
                "text": normalize_q_for_retrieval(q_clean),  # retrieval text only
            }
            f.write(json.dumps(doc, ensure_ascii=False) + "\n")
            kept += 1

    print("KB JSONL wrote:", kept, "->", outp)
    print("Skipped bad answers:", skipped_bad, "| skipped missing:", skipped_missing)

# Build if missing
if not Path(KB_JSONL).exists():
    build_kb_jsonl(MEDMCQA_TRAIN_FILE, KB_JSONL)
else:
    print("KB JSONL exists:", KB_JSONL)


KB JSONL exists: ./kb/medmcqa_q_only_kb_train.filtered.jsonl


In [8]:
# ---- Cell 4: Build E5 FAISS index (passage: text) + docs.pkl ----

def mean_pool(last_hidden, attn_mask):
    mask = attn_mask.unsqueeze(-1).to(last_hidden.dtype)
    summed = (last_hidden * mask).sum(dim=1)
    denom = mask.sum(dim=1).clamp(min=1e-6)
    return summed / denom

@torch.no_grad()
def encode_e5_texts(texts: List[str], tok, mdl, device, max_len=256, bs=64) -> np.ndarray:
    vecs = []
    for i in tqdm(range(0, len(texts), bs), desc="embed(e5)"):
        batch = texts[i:i+bs]
        t = tok(batch, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
        t = {k: v.to(device) for k, v in t.items()}
        out = mdl(**t)
        emb = mean_pool(out.last_hidden_state, t["attention_mask"])
        emb = F.normalize(emb, p=2, dim=1)
        vecs.append(emb.float().cpu().numpy())
    return np.vstack(vecs)

def build_faiss_e5(kb_jsonl: str, out_dir: str, e5_name: str):
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    docs = []
    passages = []
    with Path(kb_jsonl).open("r", encoding="utf-8") as f:
        for line in f:
            j = json.loads(line)
            docs.append(j)  # dict with q/a/text
            passages.append("passage: " + j["text"])

    dev = "cuda" if torch.cuda.is_available() else "cpu"
    tok = AutoTokenizer.from_pretrained(e5_name, use_fast=True)
    mdl = AutoModel.from_pretrained(e5_name).to(dev).eval()
    if dev == "cuda":
        mdl.half()

    X = encode_e5_texts(passages, tok, mdl, dev, max_len=256, bs=64).astype(np.float32)
    dim = X.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(X)

    faiss.write_index(index, str(out_dir / "kb.index"))
    (out_dir / "docs.pkl").write_bytes(pickle.dumps(docs))

    print("Saved:", out_dir / "kb.index")
    print("Saved:", out_dir / "docs.pkl")
    print("ntotal:", index.ntotal)

# Build if missing
if not (Path(KB_DIR) / "kb.index").exists() or not (Path(KB_DIR) / "docs.pkl").exists():
    build_faiss_e5(KB_JSONL, KB_DIR, E5_MODEL_NAME)
else:
    print("FAISS index exists:", KB_DIR)


FAISS index exists: ./rag_cache/medmcqa_q_only_kb_train.filtered_e5


In [9]:
# ---- Cell 5: Load KB (FAISS + docs) + Build BM25 ----

def load_kb(kb_dir: str):
    kb_dir = Path(kb_dir)
    index = faiss.read_index(str(kb_dir / "kb.index"))
    docs = pickle.loads((kb_dir / "docs.pkl").read_bytes())
    assert isinstance(docs, list) and isinstance(docs[0], dict), "docs.pkl must be list[dict]"
    return index, docs

index, kb_docs = load_kb(KB_DIR)
print("Loaded KB:", len(kb_docs), "docs")

def bm25_tokenize(s: str):
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s]+", " ", s)
    return [t for t in s.split() if len(t) >= 2]

bm25 = BM25Okapi([bm25_tokenize(d["q"]) for d in kb_docs])
print("BM25 ready")


Loaded KB: 176438 docs
BM25 ready


In [10]:
# ---- Cell 6: Load E5 encoder for queries (must match index) ----

embed_device = "cuda" if torch.cuda.is_available() else "cpu"
e5_tok = AutoTokenizer.from_pretrained(E5_MODEL_NAME, use_fast=True)
e5_model = AutoModel.from_pretrained(E5_MODEL_NAME).to(embed_device).eval()
if embed_device == "cuda":
    e5_model.half()

@torch.no_grad()
def e5_encode_query(text: str) -> np.ndarray:
    t = "query: " + text
    tok = e5_tok(t, return_tensors="pt", truncation=True, max_length=256)
    tok = {k: v.to(embed_device) for k, v in tok.items()}
    out = e5_model(**tok)
    emb = mean_pool(out.last_hidden_state, tok["attention_mask"])
    emb = F.normalize(emb, p=2, dim=1)
    return emb[0].float().cpu().numpy().astype(np.float32)


In [11]:
# ---- Cell 7: Hybrid retrieval + near-dup rerank + selective evidence ----

def toks_for_jaccard(s: str):
    s = s.lower()
    s = re.sub(r"[^a-z0-9\s]+", " ", s)
    return [t for t in s.split() if len(t) >= 3]

def jaccard(a: List[str], b: List[str]) -> float:
    A, B = set(a), set(b)
    if not A or not B:
        return 0.0
    return len(A & B) / len(A | B)

def near_dup_score(q: str, hit_q: str) -> float:
    return jaccard(toks_for_jaccard(q), toks_for_jaccard(hit_q))

def faiss_search(qvec: np.ndarray, k: int) -> Tuple[List[int], List[float]]:
    D, I = index.search(qvec.reshape(1, -1), k)
    return I[0].tolist(), D[0].tolist()

def retrieve_topk_hybrid(question: str, k_vec=K_VEC, k_final=K_FINAL, alpha=ALPHA) -> List[int]:
    q_norm = normalize_q_for_retrieval(question)
    qvec = e5_encode_query(q_norm)
    cand_ids, dense_scores = faiss_search(qvec, k_vec)

    # BM25 scores over all, slice to candidates
    qtok = bm25_tokenize(question)
    bm25_scores_all = bm25.get_scores(qtok)
    bm25_scores = np.array([bm25_scores_all[i] for i in cand_ids], dtype=np.float32)
    dense_scores = np.array(dense_scores, dtype=np.float32)

    def norm(x):
        if float(x.max() - x.min()) < 1e-6:
            return np.zeros_like(x)
        return (x - x.min()) / (x.max() - x.min())

    fused = alpha * norm(dense_scores) + (1 - alpha) * norm(bm25_scores)
    order = np.argsort(-fused)[:k_final]
    return [cand_ids[i] for i in order]

def rerank_and_filter(question: str, cand_ids: List[int], min_sim=MIN_SIM_KEEP) -> List[int]:
    scored = []
    for did in cand_ids:
        s = near_dup_score(question, kb_docs[did]["q"])
        scored.append((s, did))
    scored.sort(reverse=True)
    kept = [did for s, did in scored if s >= min_sim]
    return kept if kept else [scored[0][1]]  # never empty

def build_evidence(question: str, hit_ids: List[int],
                   max_items=EVID_MAX_ITEMS, max_chars=CTX_MAX_CHARS,
                   min_sim_for_A=MIN_SIM_FOR_ANS) -> str:
    blocks = []
    for did in hit_ids[:max_items]:
        item = kb_docs[did]
        sim = near_dup_score(question, item["q"])
        if sim >= min_sim_for_A:
            blocks.append(f'Q: {item["q"]}\nA: {item["a"]}')
        else:
            blocks.append(f'Related Q: {item["q"]}')
    ev = "\n\n".join(blocks)
    return ev[:max_chars]

def sanity_show_hits(samples: List[dict], n=10):
    for s in random.sample(samples, n):
        q = s["question"]
        cand = retrieve_topk_hybrid(q)
        hit = rerank_and_filter(q, cand)
        top = kb_docs[hit[0]]
        print("="*120)
        print("Q:", q)
        print("Gold:", s["gold"])
        print("Hit Q:", top["q"][:200])
        print("Hit A:", top["a"][:200])

sanity_show_hits(dev_samples, n=8)


Q: Retraction of mandible is achieved by:
Gold: B
Hit Q: Maximum amount of incisor retraction achieved is:
Hit A: 7 mm
Q: Technique of root coverage called as:
Gold: B
Hit Q: Given technique is called as:
Hit A: Crown down technique
Q: MTA barrier in open apex is made up to?
Gold: B
Hit Q: Barrier method
Hit A: Condom
Q: %lost radio-resistant cells in retina
Gold: B
Hit Q: Cells are most radio-resistant in
Hit A: S phase
Q: In symphyseal fracture with lag screw fixation?
Gold: B
Hit Q: Long bone fracture fixation done with -
Hit A: Intramedullary nail
Q: A lady delivered a normal vaginal delivery and was discharged. On third day she came back with fever, tachycardia and seizures. Fundus showed papilledema with no focal deficits. What is the most likely diagnosis?
Gold: A
Hit Q: A child from West Bengal presents with fever & unconsciousness for 1 day and pallor with no focal neurodeficit. What is the most probable diagnosis?
Hit A: Cerebral malaria
Q: Euphemism pudding paste is used for

In [12]:
# ---- Cell 8: Load MCQ Answering Model (base) ----

def load_mcq_model(model_path: str, device: str, dtype: torch.dtype):
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    mdl = AutoModelForCausalLM.from_pretrained(model_path)
    mdl.to(device)
    mdl.eval()
    if device == "cuda":
        mdl = mdl.half()
    return tok, mdl

mcq_tok, mcq_model = load_mcq_model(BASE_MODEL_PATH, DEVICE, DTYPE)
print("Loaded MCQ model:", BASE_MODEL_PATH)


Loaded MCQ model: ./gpt2


In [13]:
# ---- Cell 9: Prompt + Scoring ----

def build_prompt_base(question: str, opts: Dict[str, str]) -> str:
    # Keep prompt minimal (too long hurts small models)
    return (
        "Choose the correct option (A, B, C, or D).\n\n"
        f"Question: {question}\n"
        f"A) {opts['A']}\n"
        f"B) {opts['B']}\n"
        f"C) {opts['C']}\n"
        f"D) {opts['D']}\n"
        "Answer:"
    )

def build_prompt_rag(question: str, opts: Dict[str, str], evidence: str) -> str:
    return (
        "Use the evidence to answer. Choose A, B, C, or D.\n\n"
        f"Evidence:\n{evidence}\n\n"
        f"Question: {question}\n"
        f"A) {opts['A']}\n"
        f"B) {opts['B']}\n"
        f"C) {opts['C']}\n"
        f"D) {opts['D']}\n"
        "Answer:"
    )

def _choice_token_ids(tok, ch: str) -> List[int]:
    # Score multiple possible tokenizations
    candidates = [ch, " " + ch, "\n" + ch]
    ids = []
    for s in candidates:
        t = tok.encode(s, add_special_tokens=False)
        if len(t) == 1:
            ids.append(t[0])
    return list(dict.fromkeys(ids))  # unique

CHOICE_TOKEN_IDS = {c: _choice_token_ids(mcq_tok, c) for c in ["A","B","C","D"]}
print("Choice token ids:", CHOICE_TOKEN_IDS)

@torch.no_grad()
def score_letter_logits(prompt: str) -> Dict[str, float]:
    t = mcq_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
    t = {k: v.to(DEVICE) for k, v in t.items()}
    out = mcq_model(**t)
    logits = out.logits[0, -1]  # next-token logits
    scores = {}
    for c in ["A","B","C","D"]:
        ids = CHOICE_TOKEN_IDS[c]
        if not ids:
            scores[c] = -1e9
        else:
            scores[c] = float(torch.max(logits[ids]).item())
    return scores

@torch.no_grad()
def score_option_logprob(prompt: str, option_text: str, max_len: int = 512) -> float:
    # logP(option_text | prompt)
    full = prompt + " " + option_text
    tok_full = mcq_tok(full, return_tensors="pt", truncation=True, max_length=max_len)
    tok_full = {k:v.to(DEVICE) for k,v in tok_full.items()}
    input_ids = tok_full["input_ids"]
    attn = tok_full["attention_mask"]

    tok_p = mcq_tok(prompt, return_tensors="pt", truncation=True, max_length=max_len)
    prompt_len = tok_p["input_ids"].shape[1]

    out = mcq_model(input_ids=input_ids, attention_mask=attn)
    logits = out.logits  # [1,T,V]
    targets = input_ids[:, prompt_len:]
    if targets.numel() == 0:
        return -1e9

    logp = 0.0
    for j in range(targets.shape[1]):
        tpos = prompt_len + j
        if tpos == 0:
            continue
        lp = F.log_softmax(logits[0, tpos-1], dim=-1)[targets[0, j]].item()
        logp += lp
    return float(logp)

def predict_one(sample: dict, mode: str, ctx_max_chars: int = CTX_MAX_CHARS) -> dict:
    q = sample["question"]
    opts = sample["options"]
    gold = sample["gold"]

    if mode == "base":
        prompt = build_prompt_base(q, opts)
        evidence = ""
    elif mode == "rag_q":
        cand = retrieve_topk_hybrid(q)
        hit = rerank_and_filter(q, cand)
        evidence = build_evidence(q, hit, max_items=EVID_MAX_ITEMS, max_chars=ctx_max_chars)
        prompt = build_prompt_rag(q, opts, evidence)
    else:
        raise ValueError("mode must be 'base' or 'rag_q'")

    if SCORING == "letter":
        scores = score_letter_logits(prompt)
        pred = max(scores, key=scores.get)
    elif SCORING == "option_logprob":
        scores = {c: score_option_logprob(prompt, opts[c]) for c in ["A","B","C","D"]}
        pred = max(scores, key=scores.get)
    else:
        raise ValueError("SCORING must be 'letter' or 'option_logprob'")

    return {
        "id": sample["id"],
        "gold": gold,
        "pred": pred,
        "scores": scores,
        "rag_context": evidence,
    }


Choice token ids: {'A': [32, 317], 'B': [33, 347], 'C': [34, 327], 'D': [35, 360]}


In [14]:
# ---- Cell 10: Metrics ----

def confusion_matrix(records: List[dict]) -> np.ndarray:
    idx = {c:i for i,c in enumerate(["A","B","C","D"])}
    cm = np.zeros((4,4), dtype=int)
    for r in records:
        cm[idx[r["gold"]], idx[r["pred"]]] += 1
    return cm

def precision_recall_f1(cm: np.ndarray):
    # per-class
    eps = 1e-9
    tp = np.diag(cm).astype(np.float32)
    pred_sum = cm.sum(axis=0).astype(np.float32)
    gold_sum = cm.sum(axis=1).astype(np.float32)

    prec = tp / (pred_sum + eps)
    rec  = tp / (gold_sum + eps)
    f1   = 2*prec*rec / (prec+rec+eps)
    return prec, rec, f1, gold_sum

def summarize(records: List[dict]) -> dict:
    cm = confusion_matrix(records)
    prec, rec, f1, support = precision_recall_f1(cm)

    acc = float(np.trace(cm) / max(1, cm.sum()))
    macro_p = float(np.mean(prec))
    macro_r = float(np.mean(rec))
    macro_f = float(np.mean(f1))
    weighted_f = float(np.sum(f1 * support) / max(1, np.sum(support)))
    weighted_p = float(np.sum(prec * support) / max(1, np.sum(support)))
    weighted_r = float(np.sum(rec * support) / max(1, np.sum(support)))

    return {
        "acc": acc,
        "macro_p": macro_p, "macro_r": macro_r, "macro_f1": macro_f,
        "weighted_p": weighted_p, "weighted_r": weighted_r, "weighted_f1": weighted_f,
        "cm": cm,
        "per_class": {
            "A": {"p": float(prec[0]), "r": float(rec[0]), "f1": float(f1[0]), "support": int(support[0])},
            "B": {"p": float(prec[1]), "r": float(rec[1]), "f1": float(f1[1]), "support": int(support[1])},
            "C": {"p": float(prec[2]), "r": float(rec[2]), "f1": float(f1[2]), "support": int(support[2])},
            "D": {"p": float(prec[3]), "r": float(rec[3]), "f1": float(f1[3]), "support": int(support[3])},
        }
    }

def fix_hurt(base_records: List[dict], rag_records: List[dict]) -> dict:
    base_map = {r["id"]: r for r in base_records}
    fix = hurt = same_correct = same_wrong = 0
    for r in rag_records:
        b = base_map[r["id"]]
        base_ok = (b["pred"] == b["gold"])
        rag_ok = (r["pred"] == r["gold"])
        if (not base_ok) and rag_ok:
            fix += 1
        elif base_ok and (not rag_ok):
            hurt += 1
        elif base_ok and rag_ok:
            same_correct += 1
        else:
            same_wrong += 1
    return {"fix": fix, "hurt": hurt, "same_correct": same_correct, "same_wrong": same_wrong, "n": len(rag_records)}


In [15]:
# ---- Cell 11: Evaluate (Base vs RAG-Q) ----

def eval_mode(samples: List[dict], mode: str, ctx_max_chars: int = CTX_MAX_CHARS, limit: Optional[int] = None) -> List[dict]:
    out = []
    it = samples[:limit] if limit else samples
    for s in tqdm(it, desc=f"eval:{mode}"):
        out.append(predict_one(s, mode=mode, ctx_max_chars=ctx_max_chars))
    return out

# Quick smoke test (10 samples)
_ = eval_mode(dev_samples, "base", limit=10)
_ = eval_mode(dev_samples, "rag_q", limit=10)
print("Smoke test ok.")


eval:base: 100%|██████████| 10/10 [00:00<00:00, 22.77it/s]
eval:rag_q: 100%|██████████| 10/10 [00:04<00:00,  2.11it/s]

Smoke test ok.





In [16]:
# ---- Cell 12: Full run + Save ----

ts = time.strftime("%Y%m%d_%H%M%S")
run_name = f"medmcqa_rag_e5_hybrid_{SCORING}_{ts}"
out_path_base = Path(OUT_DIR) / f"{run_name}.base.jsonl"
out_path_rag  = Path(OUT_DIR) / f"{run_name}.ragq.jsonl"

base_records = eval_mode(dev_samples, "base", ctx_max_chars=CTX_MAX_CHARS)
ragq_records = eval_mode(dev_samples, "rag_q", ctx_max_chars=CTX_MAX_CHARS)

base_sum = summarize(base_records)
rag_sum  = summarize(ragq_records)
fh = fix_hurt(base_records, ragq_records)

print("\n==== RESULTS ====")
print("Base ACC:", base_sum["acc"], "MacroF1:", base_sum["macro_f1"], "WeightedF1:", base_sum["weighted_f1"])
print("RAGQ ACC:", rag_sum["acc"],  "MacroF1:", rag_sum["macro_f1"],  "WeightedF1:", rag_sum["weighted_f1"])
print("Fix/Hurt:", fh)

# Save JSONL
with out_path_base.open("w", encoding="utf-8") as f:
    for r in base_records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

with out_path_rag.open("w", encoding="utf-8") as f:
    for r in ragq_records:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("Saved:")
print(" -", out_path_base)
print(" -", out_path_rag)

print("\nConfusion Matrix (rows=Gold, cols=Pred) RAGQ:")
print(rag_sum["cm"])


eval:base:   1%|          | 33/4183 [00:01<02:08, 32.32it/s]


KeyboardInterrupt: 

In [17]:
SCORING = "option_logprob"
CTX_MAX_CHARS = 600

In [None]:
base_records = eval_mode(dev_samples, "base", ctx_max_chars=600)
ragq_records = eval_mode(dev_samples, "rag_q", ctx_max_chars=600)

print("Base ACC:", summarize(base_records)["acc"])
print("RAGQ ACC:", summarize(ragq_records)["acc"])
print("Fix/Hurt:", fix_hurt(base_records, ragq_records))

eval:base: 100%|██████████| 4183/4183 [02:06<00:00, 33.13it/s]
eval:rag_q:  38%|███▊      | 1604/4183 [16:46<20:03,  2.14it/s]  

In [18]:
MIN_SIM_FOR_ANS = 0.35   # 或 0.40

In [19]:
ragq_035 = eval_mode(dev_samples, "rag_q", ctx_max_chars=600)
print("RAGQ ACC:", summarize(ragq_035)["acc"])
print("Fix/Hurt:", fix_hurt(base_records, ragq_035))

eval:rag_q: 100%|██████████| 4183/4183 [43:14<00:00,  1.61it/s]  

RAGQ ACC: 0.31173798709060485
Fix/Hurt: {'fix': 326, 'hurt': 225, 'same_correct': 978, 'same_wrong': 2654, 'n': 4183}





In [16]:
def build_evidence(question: str, hit_ids,
                   max_items=3, max_chars=600,
                   min_sim_for_A=0.22):
    blocks = []
    for did in hit_ids[:max_items]:
        item = kb_docs[did]
        sim = near_dup_score(question, item["q"])
        if sim >= min_sim_for_A:
            blocks.append(f'Q: {item["q"]}\nA: {item["a"]}')
        else:
            blocks.append(f'Related Q: {item["q"]}')
    ev = "\n\n".join(blocks)
    return ev[:max_chars]


In [17]:
# ---- Cell 9: Prompt + Scoring ----

def build_prompt_base(question: str, opts: Dict[str, str]) -> str:
    # Keep prompt minimal (too long hurts small models)
    return (
        "Choose the correct option (A, B, C, or D).\n\n"
        f"Question: {question}\n"
        f"A) {opts['A']}\n"
        f"B) {opts['B']}\n"
        f"C) {opts['C']}\n"
        f"D) {opts['D']}\n"
        "Answer:"
    )

def build_prompt_rag(question: str, opts: Dict[str, str], evidence: str) -> str:
    return (
        "Use the evidence to answer. Choose A, B, C, or D.\n\n"
        f"Evidence:\n{evidence}\n\n"
        f"Question: {question}\n"
        f"A) {opts['A']}\n"
        f"B) {opts['B']}\n"
        f"C) {opts['C']}\n"
        f"D) {opts['D']}\n"
        "Answer:"
    )

def _choice_token_ids(tok, ch: str) -> List[int]:
    # Score multiple possible tokenizations
    candidates = [ch, " " + ch, "\n" + ch]
    ids = []
    for s in candidates:
        t = tok.encode(s, add_special_tokens=False)
        if len(t) == 1:
            ids.append(t[0])
    return list(dict.fromkeys(ids))  # unique

CHOICE_TOKEN_IDS = {c: _choice_token_ids(mcq_tok, c) for c in ["A","B","C","D"]}
print("Choice token ids:", CHOICE_TOKEN_IDS)

@torch.no_grad()
def score_letter_logits(prompt: str) -> Dict[str, float]:
    t = mcq_tok(prompt, return_tensors="pt", truncation=True, max_length=512)
    t = {k: v.to(DEVICE) for k, v in t.items()}
    out = mcq_model(**t)
    logits = out.logits[0, -1]  # next-token logits
    scores = {}
    for c in ["A","B","C","D"]:
        ids = CHOICE_TOKEN_IDS[c]
        if not ids:
            scores[c] = -1e9
        else:
            scores[c] = float(torch.max(logits[ids]).item())
    return scores

@torch.no_grad()
def score_option_logprob(prompt: str, option_text: str, max_len: int = 512) -> float:
    # logP(option_text | prompt)
    full = prompt + " " + option_text
    tok_full = mcq_tok(full, return_tensors="pt", truncation=True, max_length=max_len)
    tok_full = {k:v.to(DEVICE) for k,v in tok_full.items()}
    input_ids = tok_full["input_ids"]
    attn = tok_full["attention_mask"]

    tok_p = mcq_tok(prompt, return_tensors="pt", truncation=True, max_length=max_len)
    prompt_len = tok_p["input_ids"].shape[1]

    out = mcq_model(input_ids=input_ids, attention_mask=attn)
    logits = out.logits  # [1,T,V]
    targets = input_ids[:, prompt_len:]
    if targets.numel() == 0:
        return -1e9

    logp = 0.0
    for j in range(targets.shape[1]):
        tpos = prompt_len + j
        if tpos == 0:
            continue
        lp = F.log_softmax(logits[0, tpos-1], dim=-1)[targets[0, j]].item()
        logp += lp
    return float(logp)

def predict_one(sample: dict, mode: str, ctx_max_chars: int = CTX_MAX_CHARS) -> dict:
    q = sample["question"]
    opts = sample["options"]
    gold = sample["gold"]

    if mode == "base":
        prompt = build_prompt_base(q, opts)
        evidence = ""
    elif mode == "rag_q":
        cand = retrieve_topk_hybrid(q)
        hit = rerank_and_filter(q, cand)
        evidence = build_evidence(q, hit, max_items=EVID_MAX_ITEMS, max_chars=ctx_max_chars, min_sim_for_A=MIN_SIM_FOR_ANS)
        prompt = build_prompt_rag(q, opts, evidence)
    else:
        raise ValueError("mode must be 'base' or 'rag_q'")

    if SCORING == "letter":
        scores = score_letter_logits(prompt)
        pred = max(scores, key=scores.get)
    elif SCORING == "option_logprob":
        scores = {c: score_option_logprob(prompt, opts[c]) for c in ["A","B","C","D"]}
        pred = max(scores, key=scores.get)
    else:
        raise ValueError("SCORING must be 'letter' or 'option_logprob'")

    return {
        "id": sample["id"],
        "gold": gold,
        "pred": pred,
        "scores": scores,
        "rag_context": evidence,
    }


Choice token ids: {'A': [32, 317], 'B': [33, 347], 'C': [34, 327], 'D': [35, 360]}


In [16]:
def eval_ragq_with_params(samples, base_records_subset, min_sim_for_ans, ctx_max_chars=600, limit=500):
    # 临时改全局阈值（predict_one 会读它）
    global MIN_SIM_FOR_ANS
    MIN_SIM_FOR_ANS = min_sim_for_ans

    rag = eval_mode(samples, "rag_q", ctx_max_chars=ctx_max_chars, limit=limit)
    acc = summarize(rag)["acc"]
    fh = fix_hurt(base_records_subset, rag)
    return acc, fh

limit = 500
base_500 = eval_mode(dev_samples, "base", ctx_max_chars=600, limit=limit)

cands = [0.22, 0.28, 0.32, 0.35, 0.38, 0.42]
results = []
for th in cands:
    acc, fh = eval_ragq_with_params(dev_samples, base_500, th, ctx_max_chars=600, limit=limit)
    results.append((th, acc, fh["fix"], fh["hurt"]))
    print(f"th={th:.2f}  acc={acc:.4f}  fix={fh['fix']}  hurt={fh['hurt']}")

results

eval:base: 100%|██████████| 500/500 [00:14<00:00, 33.72it/s]
eval:rag_q:   0%|          | 2/500 [00:01<06:35,  1.26it/s]


KeyboardInterrupt: 

In [None]:
MIN_SIM_FOR_ANS = 0.28

ragq_records_028 = eval_mode(dev_samples, "rag_q", ctx_max_chars=600)
print("RAGQ(0.28) ACC:", summarize(ragq_records_028)["acc"])
print("Fix/Hurt:", fix_hurt(base_records, ragq_records_028))

eval:rag_q: 100%|██████████| 4183/4183 [42:38<00:00,  1.63it/s]  

RAGQ(0.28) ACC: 0.31197704996414055
Fix/Hurt: {'fix': 310, 'hurt': 208, 'same_correct': 995, 'same_wrong': 2670, 'n': 4183}





In [18]:
MIN_SIM_FOR_ANS = 0.28

# 0) 如果你的 eval_mode 里是直接用全局变量 MIN_SIM_FOR_ANS，这句能保证它读到最新值
globals()["MIN_SIM_FOR_ANS"] = MIN_SIM_FOR_ANS

import os
import json
import inspect
from datetime import datetime

import numpy as np
import pandas as pd


def _call_eval_mode(samples, mode: str, ctx_max_chars: int, min_sim: float):
    """
    兼容不同 eval_mode 签名：
    - 有的写 eval_mode(samples, mode, ctx_max_chars=..., min_sim_for_ans=...)
    - 有的写 eval_mode(samples, mode, ctx_max_chars=..., MIN_SIM_FOR_ANS=...)
    - 有的根本不收阈值参数，只能靠全局 MIN_SIM_FOR_ANS
    """
    sig = inspect.signature(eval_mode)
    kw = {}
    if "ctx_max_chars" in sig.parameters:
        kw["ctx_max_chars"] = ctx_max_chars

    # 尝试把阈值作为参数传进去（优先常见名字）
    for name in ["min_sim_for_ans", "min_sim", "sim_threshold", "MIN_SIM_FOR_ANS"]:
        if name in sig.parameters:
            kw[name] = min_sim
            break

    return eval_mode(samples, mode, **kw)


def _jsonable(x):
    """把 numpy/torch 等对象转成 json 可序列化类型。"""
    # numpy 标量
    if isinstance(x, (np.integer, np.int64, np.int32)):
        return int(x)
    if isinstance(x, (np.floating, np.float32, np.float64)):
        return float(x)
    if isinstance(x, (np.bool_,)):
        return bool(x)
    # numpy 数组
    if isinstance(x, np.ndarray):
        return x.tolist()

    # torch tensor（如果你 record 里塞了）
    try:
        import torch
        if isinstance(x, torch.Tensor):
            return x.detach().cpu().tolist()
    except Exception:
        pass

    # dict / list 递归
    if isinstance(x, dict):
        return {str(k): _jsonable(v) for k, v in x.items()}
    if isinstance(x, (list, tuple)):
        return [_jsonable(v) for v in x]

    # 其它：保持原样（json dump 失败时会再抛）
    return x


def _fallback_save(records, out_dir: str, prefix: str, extra_meta: dict | None = None):
    os.makedirs(out_dir, exist_ok=True)
    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    base = f"{prefix}_{ts}"

    jsonl_path = os.path.join(out_dir, base + ".jsonl")
    csv_path   = os.path.join(out_dir, base + ".csv")
    sum_path   = os.path.join(out_dir, base + "_summary.json")

    # jsonl
    with open(jsonl_path, "w", encoding="utf-8") as f:
        for r in records:
            f.write(json.dumps(_jsonable(r), ensure_ascii=False) + "\n")

    # csv（尽量展开）
    try:
        df = pd.DataFrame([_jsonable(r) for r in records])
        df.to_csv(csv_path, index=False)
    except Exception:
        csv_path = None

    # summary
    summ = summarize(records) if callable(summarize) else {}
    payload = {"summary": _jsonable(summ), "meta": _jsonable(extra_meta or {})}
    with open(sum_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    return {"jsonl": jsonl_path, "csv": csv_path, "summary": sum_path}


def safe_save_real_records(records, out_dir: str, prefix: str, extra_meta: dict | None = None):
    """
    先用你原来的 save_real_records；
    如果炸（最常见就是 ndarray 不可序列化），就自动 fallback 保存。
    """
    try:
        return save_real_records(records, out_dir=out_dir, prefix=prefix, extra_meta=extra_meta)
    except Exception as e:
        print(f"[Warn] save_real_records failed -> fallback save. Error: {type(e).__name__}: {str(e)[:200]}")
        return _fallback_save(records, out_dir=out_dir, prefix=prefix, extra_meta=extra_meta)


# 1) 先算 base（闭卷）
base_records = _call_eval_mode(dev_samples, "base", ctx_max_chars=600, min_sim=MIN_SIM_FOR_ANS)

# 2) 再算 rag_q（检索 + 阈值）
ragq_records_028 = _call_eval_mode(dev_samples, "rag_q", ctx_max_chars=600, min_sim=MIN_SIM_FOR_ANS)

# 3) 打印指标
base_acc = summarize(base_records)["acc"]
rag_acc  = summarize(ragq_records_028)["acc"]
fix_hurt_stats = fix_hurt(base_records, ragq_records_028)

print("Base ACC:", base_acc)
print(f"RAGQ({MIN_SIM_FOR_ANS}) ACC:", rag_acc)
print("Fix/Hurt:", fix_hurt_stats)

# 4) 保存结果（同时保存 base + ragq）
meta = {"MIN_SIM_FOR_ANS": float(MIN_SIM_FOR_ANS), "mode": "rag_q", "ctx_max_chars": 600}

save_base = safe_save_real_records(
    base_records,
    out_dir="eval_out_notebook/ragq_threshold_sweeps",
    prefix="base_closedbook",
    extra_meta={**meta, "mode": "base"}
)

save_ragq = safe_save_real_records(
    ragq_records_028,
    out_dir="eval_out_notebook/ragq_threshold_sweeps",
    prefix=f"ragq_sim{str(MIN_SIM_FOR_ANS).replace('.','')}",
    extra_meta=meta
)

print("\n[Saved] base:", save_base)
print("[Saved] ragq:", save_ragq)
print(f"\nGain: {rag_acc - base_acc:+.4f}")


eval:base:   0%|          | 0/4183 [00:00<?, ?it/s]

eval:base: 100%|██████████| 4183/4183 [02:22<00:00, 29.34it/s]
eval:rag_q: 100%|██████████| 4183/4183 [36:05<00:00,  1.93it/s] 


Base ACC: 0.2878316997370308
RAGQ(0.28) ACC: 0.31125986134353334
Fix/Hurt: {'fix': 305, 'hurt': 207, 'same_correct': 997, 'same_wrong': 2674, 'n': 4183}
[Warn] save_real_records failed -> fallback save. Error: NameError: name 'save_real_records' is not defined
[Warn] save_real_records failed -> fallback save. Error: NameError: name 'save_real_records' is not defined

[Saved] base: {'jsonl': 'eval_out_notebook/ragq_threshold_sweeps/base_closedbook_20251223_135115.jsonl', 'csv': 'eval_out_notebook/ragq_threshold_sweeps/base_closedbook_20251223_135115.csv', 'summary': 'eval_out_notebook/ragq_threshold_sweeps/base_closedbook_20251223_135115_summary.json'}
[Saved] ragq: {'jsonl': 'eval_out_notebook/ragq_threshold_sweeps/ragq_sim028_20251223_135115.jsonl', 'csv': 'eval_out_notebook/ragq_threshold_sweeps/ragq_sim028_20251223_135115.csv', 'summary': 'eval_out_notebook/ragq_threshold_sweeps/ragq_sim028_20251223_135115_summary.json'}

Gain: +0.0234


In [16]:
save_eval_records(
    ragq_records_028,
    out_dir="eval_out_notebook/ragq_threshold_sweeps",
    prefix="ragq_sim028",
    extra_meta={"MIN_SIM_FOR_ANS": MIN_SIM_FOR_ANS, "mode": "rag_q", "ctx_max_chars": 600}
)

TypeError: Object of type ndarray is not JSON serializable