In [36]:
# =========================
# 0) 环境准备
# =========================
# 建议环境：
# pip install -U torch transformers accelerate faiss-cpu pandas numpy scikit-learn tqdm nbformat

import os, json, math, time, pickle, random, re
from pathlib import Path
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

# Let CUDA matmuls and cuDNN convolutions use TF32 math (faster on supporting GPUs,
# at slightly reduced fp32 precision). No effect on CPU-only runs.
torch.backends.cuda.matmul.allow_tf32 = torch.backends.cudnn.allow_tf32 = True

def set_seed(seed: int = 2025):
    """Seed Python's `random`, NumPy's global RNG and torch (CPU and all CUDA devices).

    Args:
        seed: the value passed to every seeding function.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

# Global seed for the whole notebook: change it here rather than at call sites.
SEED = 2025
set_seed(SEED)
print("OK")


OK


In [37]:

# =========================
# 1) 配置区
# =========================
# Paths and model ids can be edited here, or overridden via environment variables
# so the notebook can run headless (e.g. papermill / CI) without editing this cell.
DEV_FILE = os.environ.get("MEDMCQA_DEV_FILE", "./data/medmcqa/dev.json")      # <-- edit here
KB_DIR   = os.environ.get("MEDMCQA_KB_DIR", "./rag_cache/medmcqa_train_50k")  # <-- edit here

# Causal LM whose next-token logits score the A/B/C/D choice.
MODEL_NAME_OR_PATH = os.environ.get("MEDMCQA_LM", "./gpt2")                   # <-- edit if needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Encoder used for retrieval embeddings (mean pooling).
# NOTE(review): this should be the same encoder + pooling that built the FAISS index
# in KB_DIR, otherwise query vectors live in a different space — TODO confirm.
EMBED_MODEL_NAME = os.environ.get(
    "MEDMCQA_EMBED_MODEL",
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
)

# Default RAG parameters (a grid search over them runs later in the notebook).
DEFAULT_RAG_K = 3
DEFAULT_CTX_MAX_CHARS = 500
DEFAULT_EVID_MAX_SENTS = 4
DEFAULT_SIM_THRESHOLD = None   # None = no similarity-threshold fallback (get a baseline first, then tune it via the grid)

# Every run writes into its own timestamped output directory.
RUN_DIR = Path("./eval_out") / f"medmcqa_rag_notebook_{time.strftime('%Y%m%d_%H%M%S')}"
RUN_DIR.mkdir(parents=True, exist_ok=True)
print("RUN_DIR =", RUN_DIR.resolve())


RUN_DIR = /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/eval_out/medmcqa_rag_notebook_20251219_131920


In [38]:
# =========================
# 2) 数据加载（兼容字段）
# =========================
def _get_first(d: dict, keys: List[str], default=None):
    """Return the value of the first key in `keys` that `d` holds with a non-None value.

    Non-dict inputs yield `default`. Without this check, a plain string record would be
    searched by substring and then indexed with a str key.
    """
    if not isinstance(d, dict):
        return default
    for k in keys:
        v = d.get(k)
        if v is not None:
            return v
    return default


_GOLD_LETTERS = ("A", "B", "C", "D")
_GOLD_KEYS = ["cop", "gold", "answer", "label", "correct_option", "correct"]


def _gold_to_int(g) -> Optional[int]:
    """Coerce an integer-like gold label to int.

    Accepts an int, a numpy int, an integral float or a decimal string. Anything else
    (including bools) returns None.
    """
    if isinstance(g, (bool, np.bool_)):
        return None
    if isinstance(g, (int, np.integer)):
        return int(g)
    if isinstance(g, (float, np.floating)):
        return int(g) if float(g).is_integer() else None
    if isinstance(g, str):
        s = g.strip()
        # isdecimal (not isdigit): "²".isdigit() is True but int("²") raises.
        return int(s) if s.isdecimal() else None
    return None


def normalize_gold(g, base: Optional[int] = None) -> Optional[str]:
    """Map a raw gold label to "A".."D", or None if it cannot be interpreted.

    Args:
        g: a letter ("A".."D", any case, surrounding whitespace ignored), or an
            integer-like value (int / numpy int / integral float / decimal string).
        base: 1 if integer labels are 1-based (1..4), 0 if they are 0-based (0..3).
            None keeps the legacy per-value heuristic: 1..4 are read as 1-based and
            0 as "A". That heuristic cannot tell 0-based 1/2/3 from 1-based 1/2/3,
            so prefer a dataset-level base (as `load_medmcqa` computes).

    Returns:
        "A", "B", "C", "D" or None.
    """
    if isinstance(g, str):
        letter = g.strip().upper()
        if letter in _GOLD_LETTERS:
            return letter
    idx = _gold_to_int(g)
    if idx is None:
        return None
    if base is None:
        if 1 <= idx <= 4:
            return _GOLD_LETTERS[idx - 1]
        return "A" if idx == 0 else None
    pos = idx - base
    return _GOLD_LETTERS[pos] if 0 <= pos < len(_GOLD_LETTERS) else None


def _infer_gold_base(raw_golds) -> Optional[int]:
    """Decide once per dataset whether integer gold labels are 0-based or 1-based.

    Returns 0 if a 0 is seen and no 4, and 1 if a 4 is seen and no 0.

    Returns None when no integer labels exist, or when the evidence is ambiguous
    (neither or both of 0 and 4 seen). The ambiguous case also emits a warning.
    """
    seen = {v for v in map(_gold_to_int, raw_golds) if v is not None}
    if not seen:
        return None
    if 0 in seen and 4 not in seen:
        return 0
    if 4 in seen and 0 not in seen:
        return 1
    import warnings
    warnings.warn(
        f"Cannot infer whether integer gold labels are 0- or 1-based (values seen: {sorted(seen)}); "
        "falling back to the per-value heuristic."
    )
    return None


def _read_records(path: Path) -> list:
    """Read raw records from a JSON or JSONL file.

    JSON may be a list, or a dict wrapping a list under "data" or "questions". Any
    other content is parsed as JSONL (one JSON value per non-empty line).
    """
    raw = path.read_text(encoding="utf-8", errors="ignore")
    try:
        parsed = json.loads(raw.strip())
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, dict):
        for key in ("data", "questions"):
            if isinstance(parsed.get(key), list):
                parsed = parsed[key]
                break
    if isinstance(parsed, list):
        return parsed

    records = []
    # read_text uses universal newlines, so splitting on "\n" matches file iteration.
    for line_no, line in enumerate(raw.split("\n"), 1):
        line = line.strip()
        if not line:
            continue
        try:
            records.append(json.loads(line))
        except json.JSONDecodeError as e:
            raise ValueError(
                f"JSONL parse failed at line {line_no}: {e}\n"
                f"Line content (first 200 chars): {line[:200]}"
            ) from e
    return records


def _extract_options(ex: dict) -> List[Any]:
    """Return the raw A/B/C/D option values.

    Missing values are filled from an "options" dict (upper- or lower-case keys) or an
    "options" list with at least 4 entries.
    """
    opts = [
        _get_first(ex, ["opa", "option_a", "A", "a"]),
        _get_first(ex, ["opb", "option_b", "B", "b"]),
        _get_first(ex, ["opc", "option_c", "C", "c"]),
        _get_first(ex, ["opd", "option_d", "D", "d"]),
    ]
    if any(v is None for v in opts) and "options" in ex:
        container = ex["options"]
        if isinstance(container, dict):
            opts = [v or container.get(ch) or container.get(ch.lower())
                    for v, ch in zip(opts, _GOLD_LETTERS)]
        elif isinstance(container, list) and len(container) >= 4:
            opts = [v or container[j] for j, v in enumerate(opts)]
    return opts


def load_medmcqa(path: str, gold_base: Optional[int] = None) -> List[dict]:
    """Load MedMCQA-style multiple-choice samples from a JSON or JSONL file.

    Several field-name variants are accepted for the question, the options and the
    gold label. Records that are not JSON objects, or that have no question, are
    skipped.

    Args:
        path: path to the .json / .jsonl file.
        gold_base: 0 or 1 to force how integer gold labels are read. None infers it
            once for the whole file, since per-value guessing mislabels 0-based data.

    Returns:
        A list of dicts with keys "id", "question", "options" ({"A".."D": str}),
        "gold" ("A".."D" or None) and "raw" (the original record).

    Raises:
        ValueError: if the file is neither valid JSON nor valid JSONL.
    """
    records = _read_records(Path(path))

    # First pass: collect the raw gold labels so the 0-/1-based convention is decided per dataset.
    raw_golds = [_get_first(ex, _GOLD_KEYS) for ex in records]
    base = _infer_gold_base(raw_golds) if gold_base is None else gold_base

    samples = []
    for i, (ex, gold_raw) in enumerate(zip(records, raw_golds)):
        if not isinstance(ex, dict):
            continue
        q = _get_first(ex, ["question", "ques", "query", "prompt"])
        if q is None:
            continue
        samples.append({
            "id": ex.get("id", i),
            "question": str(q).strip(),
            "options": {ch: "" if v is None else str(v).strip()
                        for ch, v in zip(_GOLD_LETTERS, _extract_options(ex))},
            "gold": normalize_gold(gold_raw, base=base),
            "raw": ex,
        })
    return samples

samples = load_medmcqa(DEV_FILE)
# Fail here with a clear message instead of an IndexError on samples[0] below.
assert samples, f"No samples with a question were loaded from {DEV_FILE}"
print("Loaded samples:", len(samples))
print("Samples without a usable gold label:", sum(s["gold"] is None for s in samples))
print("Example question:", samples[0]["question"][:120])
print("Options:", samples[0]["options"])
print("Gold:", samples[0]["gold"])


Loaded samples: 4183
Example question: Which of the following is not true for myelinated nerve fibers:
Options: {'A': 'Impulse through myelinated fibers is slower than non-myelinated fibers', 'B': 'Membrane currents are generated at nodes of Ranvier', 'C': 'Saltatory conduction of impulses is seen', 'D': 'Local anesthesia is effective only when the nerve is not covered by myelin sheath'}
Gold: A


In [39]:
from collections import Counter

# Sanity check: raw `cop` values next to the A-D labels they were normalized to.
raw_cop_counts = Counter(s["raw"].get("cop") for s in samples)
print("raw cop top:", raw_cop_counts.most_common(10))
gold_counts = Counter(s["gold"] for s in samples)
print("normalized gold top:", gold_counts.most_common(10))

raw cop top: [(1, 1348), (2, 1085), (3, 925), (4, 825)]
normalized gold top: [('A', 1348), ('B', 1085), ('C', 925), ('D', 825)]


In [40]:
# =========================
# 3) KB 加载（自动发现 index & docs）
# =========================
def find_first_file(d: Path, exts: Tuple[str, ...]) -> Optional[Path]:
    """Find a file under `d` (recursively) whose suffix is in `exts`.

    `exts` is treated as a priority list: any file with an earlier extension wins over
    any file with a later one. This stops a stray ``config.json`` from shadowing
    ``docs.jsonl`` just because it sorts first. Among files with the same extension,
    the lexicographically first path wins, so the choice is deterministic. Suffix
    matching is case-insensitive.

    Returns:
        The chosen path, or None if nothing matches.
    """
    wanted = [e.lower() for e in exts]
    candidates = sorted(
        p for p in Path(d).rglob("*") if p.is_file() and p.suffix.lower() in wanted
    )
    if not candidates:
        return None
    # min() keeps the first of equally-ranked items, i.e. the path-sorted order.
    return min(candidates, key=lambda p: wanted.index(p.suffix.lower()))

_DOC_TEXT_FIELDS = ("text", "content", "document", "doc", "passage")
_DOC_LIST_KEYS = ("documents", "docs", "texts", "corpus")


def _doc_sort_key(key) -> tuple:
    """Sort key for id-keyed doc dicts.

    Decimal-looking keys sort numerically, before all other keys, which sort as
    strings. Mixed int/str keys are never compared directly.
    """
    s = str(key)
    return (0, int(s), "") if s.isdecimal() else (1, 0, s)


def _record_text(rec) -> str:
    """Return the text of one JSON record.

    For a dict: the first truthy text-like field. If none is truthy, the last field's
    value when it is not None, else the JSON dump of the dict. Non-dicts go through str().
    """
    if not isinstance(rec, dict):
        return str(rec)
    txt = None
    for k in _DOC_TEXT_FIELDS:
        txt = rec.get(k)
        if txt:
            break
    return str(txt) if txt is not None else json.dumps(rec, ensure_ascii=False)


def _docs_from_named_list(obj: dict) -> Optional[List[str]]:
    """Return str() of every item in the first list-valued "documents"/"docs"/"texts"/"corpus" entry, or None."""
    for k in _DOC_LIST_KEYS:
        if k in obj and isinstance(obj[k], list):
            return [str(x) for x in obj[k]]
    return None


def load_docs_any(path: Path) -> List[str]:
    """Load the KB document texts, in index order, from a .pkl/.pickle, .jsonl or .json file.

    Supported layouts:
      * .pkl/.pickle: a list, a dict with a "documents"/"docs"/"texts"/"corpus" list,
        or an id -> text dict (numeric ids in numeric order, then other keys).
      * .jsonl: one record per line. Objects use their text-like field; other values use str().
      * .json: a list of records (same rule as .jsonl), or a dict with one of the list keys above.

    Raises:
        ValueError: for an unknown layout or an unsupported file suffix.
    """
    path = Path(path)
    suffix = path.suffix.lower()

    if suffix in (".pkl", ".pickle"):
        # SECURITY: unpickling can execute arbitrary code — only load KB files you built yourself.
        obj = pickle.loads(path.read_bytes())
        if isinstance(obj, list):
            return [str(x) for x in obj]
        if isinstance(obj, dict):
            docs = _docs_from_named_list(obj)
            if docs is not None:
                return docs
            if all(isinstance(k, (int, np.integer, str)) for k in obj):
                ordered = sorted(obj.items(), key=lambda kv: _doc_sort_key(kv[0]))
                return [str(v) for _, v in ordered]
        raise ValueError(f"Unknown pkl format: {type(obj)}")

    if suffix == ".jsonl":
        with path.open("r", encoding="utf-8") as f:
            return [_record_text(json.loads(line)) for line in map(str.strip, f) if line]

    if suffix == ".json":
        j = json.loads(path.read_text(encoding="utf-8"))
        if isinstance(j, list):
            return [_record_text(x) for x in j]
        if isinstance(j, dict):
            docs = _docs_from_named_list(j)
            if docs is not None:
                return docs
        raise ValueError("Unknown json format")

    raise ValueError("Unsupported docs file type: " + str(path))

def load_faiss_index_any(path: Path):
    """Read a FAISS index from `path`. faiss is imported at call time, not at module load."""
    import faiss  # local import: only needed when an index is actually read
    index_path = str(path)
    return faiss.read_index(index_path)

KB_DIR_PATH = Path(KB_DIR)
assert KB_DIR_PATH.exists(), f"KB_DIR not found: {KB_DIR_PATH}"

idx_file = find_first_file(KB_DIR_PATH, (".faiss",".index"))
docs_file = find_first_file(KB_DIR_PATH, (".pkl",".pickle",".jsonl",".json"))

print("Found index:", idx_file)
print("Found docs :", docs_file)
assert idx_file is not None, "未找到 FAISS index 文件（.faiss/.index）"
assert docs_file is not None, "未找到 documents 文件（.pkl/.jsonl/.json）"

docs = load_docs_any(docs_file)
# Guard before previewing docs[0], so an empty docs file fails with a clear message.
assert docs, f"No documents loaded from {docs_file}"
print("Docs:", len(docs), "| Example:", docs[0][:120].replace("\n"," "))

index = load_faiss_index_any(idx_file)
print("FAISS index ntotal:", index.ntotal)
# Index row i must correspond to docs[i]; include both counts to make a mismatch actionable.
assert index.ntotal == len(docs), (
    f"index.ntotal 与 docs 数量不一致：请检查 KB 对齐 (ntotal={index.ntotal}, docs={len(docs)})"
)


Found index: rag_cache/medmcqa_train_50k/index.faiss
Found docs : rag_cache/medmcqa_train_50k/docs.jsonl
Docs: 50000 | Example: Q: Chronic urethral obstruction due to benign prismatic hyperplasia can lead to the following change in kidney parenchym
FAISS index ntotal: 50000


In [41]:
# =========================
# 4) Embedding 模型
# =========================
@dataclass
class EmbedderConfig:
    """Settings consumed by `MeanPoolEmbedder`."""
    model_name: str              # model id or local path given to AutoTokenizer/AutoModel.from_pretrained
    device: str = DEVICE         # default is the module-level DEVICE, evaluated when this class is defined
    max_length: int = 256        # tokenizer truncation length, in tokens
    batch_size: int = 32         # number of texts per forward pass in encode()
    fp16: bool = True            # cast the model to half precision; only applied when device starts with "cuda"

class MeanPoolEmbedder:
    """Text embedder: attention-masked mean of the encoder's last hidden states, L2-normalised.

    Output vectors have unit length, so their inner products are cosine similarities.
    """

    def __init__(self, cfg: "EmbedderConfig"):
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
        self.model = AutoModel.from_pretrained(cfg.model_name)
        self.model.eval()
        self.model.to(cfg.device)
        if cfg.fp16 and cfg.device.startswith("cuda"):
            self.model.half()

    @torch.no_grad()
    def encode(self, texts: List[str]) -> np.ndarray:
        """Embed `texts` in batches of `cfg.batch_size`.

        Returns:
            A float32 array of shape [len(texts), hidden]. An empty `texts` gives a
            (0, hidden) array instead of failing in np.vstack. `hidden` is taken from
            the model config when available, else 0.
        """
        chunks = []
        bs = self.cfg.batch_size
        for start in range(0, len(texts), bs):
            batch = texts[start:start + bs]
            tok = self.tokenizer(batch, padding=True, truncation=True,
                                 max_length=self.cfg.max_length, return_tensors="pt")
            tok = {k: v.to(self.cfg.device) for k, v in tok.items()}
            out = self.model(**tok)
            # Pool in fp32 even when the model runs in fp16, so the masked sum over
            # up to max_length tokens is not accumulated in half precision.
            hidden = out.last_hidden_state.float()  # [B, T, H]
            mask = tok["attention_mask"].unsqueeze(-1).to(hidden.dtype)
            mean = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)
            chunks.append(torch.nn.functional.normalize(mean, p=2, dim=1).cpu().numpy())
        if not chunks:
            hidden_size = getattr(getattr(self.model, "config", None), "hidden_size", 0)
            return np.zeros((0, hidden_size), dtype=np.float32)
        return np.vstack(chunks)

embed_cfg = EmbedderConfig(model_name=EMBED_MODEL_NAME)
embedder = MeanPoolEmbedder(embed_cfg)
print("Embedder ready:", embed_cfg.model_name)


Embedder ready: microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext


In [42]:
# =========================
# 5) Retrieval & Evidence extraction
# =========================
# Sentence boundary: whitespace after ". ! ?", or directly after a CJK "。" / "！" / "？".
_SENT_SPLIT = re.compile(r"(?<=[\.!\?])\s+|(?<=[。！？])")

def split_sents(text: str) -> List[str]:
    """Collapse runs of whitespace in `text`, then split it into stripped, non-empty sentences."""
    normalized = re.sub(r"\s+", " ", text.strip())
    if normalized == "":
        return []
    pieces = (piece.strip() for piece in _SENT_SPLIT.split(normalized))
    return [piece for piece in pieces if piece]

def retrieve_topk(query: str, k: int = 3) -> Tuple[List[int], List[float]]:
    """Embed `query` and return (doc_ids, scores) of its top-`k` KB neighbours, best first.

    FAISS pads its results with id -1 when it finds fewer than `k` hits. Those entries
    are dropped here so callers never index `docs[-1]`, which would silently return the
    last document. A non-positive `k` returns two empty lists.
    """
    if k <= 0:
        return [], []
    qv = embedder.encode([query]).astype(np.float32)
    scores, ids = index.search(qv, k)
    hits = [(int(i), float(s)) for i, s in zip(ids[0], scores[0]) if i >= 0]
    return [i for i, _ in hits], [s for _, s in hits]

def extract_evidence(query: str, doc_ids: List[int], max_sents: int = 4) -> str:
    """Return the `max_sents` sentences from the given KB docs that are most similar to `query`.

    Steps:
      * Sentences come from `split_sents` over `docs[did]`; negative ids (FAISS padding)
        are skipped.
      * Duplicate sentences are removed, keeping the first occurrence, so repeats across
        near-identical docs do not eat the sentence budget.
      * Sentences are ranked by inner product with the query embedding. This is cosine
        similarity if `embedder` L2-normalises its output, as MeanPoolEmbedder does.

    Returns:
        The chosen sentences, best first, joined by single spaces. "" when there are no
        sentences or `max_sents <= 0`.
    """
    if max_sents <= 0:
        return ""
    sents = list(dict.fromkeys(
        s for did in doc_ids if did >= 0 for s in split_sents(docs[did])
    ))
    if not sents:
        return ""
    qv = embedder.encode([query]).astype(np.float32)[0]
    sv = embedder.encode(sents).astype(np.float32)
    sims = sv @ qv  # [num_sents]
    top_idx = np.argsort(-sims, kind="stable")[:max_sents]
    return " ".join(sents[i] for i in top_idx)

# Sanity check on the first dev question: its top-3 neighbours and the extracted evidence.
first_q = samples[0]["question"]
ids, sc = retrieve_topk(first_q, k=3)
print("Top ids:", ids)
print("Top scores:", [round(score, 4) for score in sc])
evidence_preview = extract_evidence(first_q, ids, max_sents=4)
print("Evidence preview:", evidence_preview[:200])


Top ids: [27773, 45434, 40482]
Top scores: [0.9015, 0.9006, 0.8966]
Evidence preview: Which of the following positions serve as most comfoable for a patient facing difficulty in bre A: All positions are same in respect to comfo. Q: Which of the following statements is true regarding ka


In [43]:
# =========================
# 6) LM scorer (next-token logits for A/B/C/D)
# =========================
@dataclass
class LMConfig:
    """Settings consumed by `ChoiceScorer`."""
    model_name_or_path: str      # model id or local path for the causal LM and its tokenizer
    device: str = DEVICE         # default is the module-level DEVICE, evaluated when this class is defined
    max_prompt_tokens: int = 512 # prompt length limit (tokens) used when tokenizing a prompt
    fp16: bool = True            # cast the model to half precision; only applied when device starts with "cuda"

class ChoiceScorer:
    """Score a multiple-choice prompt by the LM's next-token logits for "A", "B", "C" and "D"."""

    def __init__(self, cfg: "LMConfig"):
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True)
        if self.tokenizer.pad_token is None:
            # Some causal-LM tokenizers have no pad token; reuse EOS.
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(cfg.model_name_or_path)
        self.model.eval()
        self.model.to(cfg.device)
        if cfg.fp16 and cfg.device.startswith("cuda"):
            self.model.half()

        self.choice_token_id = {}
        for ch in ["A", "B", "C", "D"]:
            ids = self.tokenizer.encode(ch, add_special_tokens=False)
            # Use the LAST piece: a tokenizer that emits a standalone word-boundary
            # piece before the letter would otherwise give all four letters the same id.
            self.choice_token_id[ch] = ids[-1]
        if len(set(self.choice_token_id.values())) != len(self.choice_token_id):
            raise ValueError(f"Choice letters do not map to distinct token ids: {self.choice_token_id}")
        print("Choice token id:", self.choice_token_id)

    @torch.no_grad()
    def score_choices_next_token(self, prompt: str) -> Dict[str, float]:
        """Return {letter: raw next-token logit} at the position right after `prompt`.

        A prompt longer than `cfg.max_prompt_tokens` keeps only its LAST tokens. The
        options and the trailing answer cue are what the scored position depends on, so
        any excess is dropped from the front of the prompt.

        NOTE(review): the prompt builders end with "Answer: " (trailing space). With
        byte-level BPE tokenizers such as GPT-2's, that space may become its own token,
        which is an unusual context for predicting a bare "A". Consider comparing against
        "Answer:" followed by scoring " A".." D". TODO evaluate.
        """
        tok = self.tokenizer(prompt, return_tensors="pt")
        max_len = self.cfg.max_prompt_tokens
        # Left-truncate every per-token tensor (input_ids, attention_mask, ...) identically.
        tok = {k: v[:, -max_len:].to(self.cfg.device) for k, v in tok.items()}
        out = self.model(**tok)
        last_logits = out.logits[0, -1, :]  # [vocab]
        letters = list(self.choice_token_id)
        tids = torch.tensor([self.choice_token_id[ch] for ch in letters], device=last_logits.device)
        # One gather and one device->host copy instead of one per letter.
        picked = last_logits.index_select(0, tids).float().cpu().tolist()
        return dict(zip(letters, picked))

lm_cfg = LMConfig(model_name_or_path=MODEL_NAME_OR_PATH)
scorer = ChoiceScorer(lm_cfg)
print("LM ready:", lm_cfg.model_name_or_path)


Choice token id: {'A': 32, 'B': 33, 'C': 34, 'D': 35}
LM ready: ./gpt2


In [44]:
# =========================
# 7) Prompt builders (no triple-quote nesting)
# =========================
# Shared prompt pieces, so the three prompt variants cannot drift apart.
# NOTE(review): the cue's trailing space is part of the scored context (the A/B/C/D
# logits are read right after it); see the scorer for the tokenization caveat.
_ANSWER_CUE = "Answer: "
_CHOICES = ("A", "B", "C", "D")


def _option_lines(options: Dict[str, str], ev_by_opt: Optional[Dict[str, str]] = None) -> List[str]:
    """Render one "X) option" line per choice in A-D order.

    When `ev_by_opt` is given, each option is followed by "Evidence for X: ..." if that
    option's evidence is non-empty after stripping.
    """
    lines = []
    for ch in _CHOICES:
        lines.append(f"{ch}) {options[ch]}")
        if ev_by_opt is not None:
            ev = (ev_by_opt.get(ch, "") or "").strip()
            if ev:
                lines.append(f"Evidence for {ch}: {ev}")
    return lines


def build_prompt_base(q: str, options: Dict[str,str]) -> str:
    """No-retrieval prompt: instruction, question, the four options, then the answer cue."""
    return "\n".join([
        "You are answering a multiple-choice medical question.",
        f"Question: {q}",
        *_option_lines(options),
        _ANSWER_CUE,
    ])


def build_prompt_rag_question(q: str, options: Dict[str,str], evidence: str) -> str:
    """Prompt with one shared evidence block.

    The "Evidence:" line is placed before the question and omitted when `evidence` is
    blank after stripping.
    """
    ev = evidence.strip()
    head = ["You are answering a multiple-choice medical question using evidence."]
    if ev:
        head.append(f"Evidence: {ev}")
    return "\n".join([*head, f"Question: {q}", *_option_lines(options), _ANSWER_CUE])


def build_prompt_rag_option(q: str, options: Dict[str,str], ev_by_opt: Dict[str,str]) -> str:
    """Prompt where each option is followed by its own (non-empty) evidence line."""
    return "\n".join([
        "You are answering a multiple-choice medical question using per-option evidence.",
        f"Question: {q}",
        *_option_lines(options, ev_by_opt),
        _ANSWER_CUE,
    ])


In [45]:
# =========================
# 8) Predict & evaluate
# =========================
LABELS = ["A","B","C","D"]


def _gated_evidence(query: str, rag_k: int, evid_max_sents: int,
                    sim_threshold: Optional[float], max_chars: int) -> str:
    """Retrieve the top-`rag_k` docs for `query` and return extracted evidence, cut to `max_chars`.

    Returns "" when `sim_threshold` is set and the best retrieval score is below it.
    """
    ids, top_scores = retrieve_topk(query, k=rag_k)
    if sim_threshold is not None and top_scores and top_scores[0] < sim_threshold:
        return ""
    return extract_evidence(query, ids, max_sents=evid_max_sents)[:max_chars]


def predict_one(sample: dict, mode: str,
                rag_k: int = DEFAULT_RAG_K,
                ctx_max_chars: int = DEFAULT_CTX_MAX_CHARS,
                evid_max_sents: int = DEFAULT_EVID_MAX_SENTS,
                sim_threshold: Optional[float] = DEFAULT_SIM_THRESHOLD) -> dict:
    """Predict one sample's answer letter under `mode` and return a loggable record.

    Modes:
      * "base": no retrieval.
      * "rag_q": evidence retrieved with the question as the query, cut to
        `ctx_max_chars` characters.
      * "rag_opt": one retrieval per option, using "question option" as the query.

    In "rag_opt", each option's evidence is cut to max(50, ctx_max_chars // 2)
    characters. Four options can therefore add up to about 2x `ctx_max_chars`.

    When `sim_threshold` is set, a query whose best retrieval score is below it
    contributes no evidence.

    Returns:
        A dict with id, gold, pred, the run parameters, per-letter scores, the evidence
        used ("rag_context": "" for base, str for rag_q, {letter: str} for rag_opt),
        the question and the options.

    Raises:
        ValueError: for an unknown mode.
    """
    q = sample["question"]
    opts = sample["options"]

    if mode == "base":
        context = ""
        prompt = build_prompt_base(q, opts)
    elif mode == "rag_q":
        context = _gated_evidence(q, rag_k, evid_max_sents, sim_threshold, ctx_max_chars)
        prompt = build_prompt_rag_question(q, opts, context)
    elif mode == "rag_opt":
        per_opt_chars = max(50, ctx_max_chars // 2)
        context = {
            ch: _gated_evidence(q + " " + opts[ch], rag_k, evid_max_sents, sim_threshold, per_opt_chars)
            for ch in LABELS
        }
        prompt = build_prompt_rag_option(q, opts, context)
    else:
        raise ValueError(f"Unknown mode: {mode}")

    scores = scorer.score_choices_next_token(prompt)
    # Key order matches the JSONL records written by earlier runs.
    return {
        "id": sample["id"],
        "gold": sample["gold"],
        "pred": max(scores, key=scores.get),  # ties resolve to the earliest letter
        "mode": mode,
        "rag_k": rag_k,
        "ctx_max_chars": ctx_max_chars,
        "evid_max_sents": evid_max_sents,
        "sim_threshold": sim_threshold,
        "scores": scores,
        "rag_context": context,
        "question": q,
        "options": opts,
    }

def eval_mode(samples: List[dict], mode: str, **kwargs) -> List[dict]:
    """Run `predict_one` on every sample, in order, behind a tqdm bar labelled "eval:<mode>".

    Extra keyword arguments are forwarded to `predict_one`.
    """
    progress = tqdm(samples, desc=f"eval:{mode}")
    return [predict_one(sample, mode=mode, **kwargs) for sample in progress]

def save_jsonl(records: List[dict], path: Path):
    """Write `records` to `path` as UTF-8 JSON Lines (one object per line, non-ASCII kept as-is)."""
    lines = (json.dumps(rec, ensure_ascii=False) + "\n" for rec in records)
    with path.open("w", encoding="utf-8") as fh:
        fh.writelines(lines)


In [46]:
# =========================
# 9) Metrics
# =========================
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, accuracy_score

def summarize(records: List[dict]) -> dict:
    """Compute accuracy, per-class, macro and weighted P/R/F1, distributions and the confusion matrix.

    Only records whose gold AND pred are both in LABELS are evaluated. The others are
    counted under "skipped". With no evaluable records every metric is 0.0 and every
    count is 0; otherwise sklearn/numpy would fail on the empty input (np.average
    raises when the weights sum to zero).
    """
    pairs = [(r["gold"], r["pred"]) for r in records
             if r["gold"] in LABELS and r["pred"] in LABELS]
    gold = [g for g, _ in pairs]
    pred = [p for _, p in pairs]
    skipped = len(records) - len(pairs)

    if not pairs:
        zeros = {"precision": 0.0, "recall": 0.0, "f1": 0.0}
        return {
            "evaluated": 0,
            "skipped": skipped,
            "acc": 0.0,
            "dist_pred": {lab: 0 for lab in LABELS},
            "dist_gold": {lab: 0 for lab in LABELS},
            "per_class": [{"class": lab, **zeros, "support": 0, "correct": 0} for lab in LABELS],
            "macro": dict(zeros),
            "weighted": dict(zeros),
            "confusion_matrix": [[0] * len(LABELS) for _ in LABELS],
        }

    acc = accuracy_score(gold, pred)
    cm = confusion_matrix(gold, pred, labels=LABELS)
    prec, rec, f1, sup = precision_recall_fscore_support(gold, pred, labels=LABELS, zero_division=0)

    per_class = [
        {
            "class": lab,
            "precision": float(prec[i]),
            "recall": float(rec[i]),
            "f1": float(f1[i]),
            "support": int(sup[i]),
            "correct": int(cm[i, i]),
        }
        for i, lab in enumerate(LABELS)
    ]
    macro = {"precision": float(np.mean(prec)), "recall": float(np.mean(rec)), "f1": float(np.mean(f1))}
    weighted = {
        "precision": float(np.average(prec, weights=sup)),
        "recall": float(np.average(rec, weights=sup)),
        "f1": float(np.average(f1, weights=sup)),
    }

    return {
        "evaluated": len(gold),
        "skipped": skipped,
        "acc": float(acc),
        "dist_pred": {lab: pred.count(lab) for lab in LABELS},
        "dist_gold": {lab: gold.count(lab) for lab in LABELS},
        "per_class": per_class,
        "macro": macro,
        "weighted": weighted,
        "confusion_matrix": cm.tolist(),
    }

def fix_hurt(base_records: List[dict], rag_records: List[dict]) -> dict:
    """Compare per-question correctness of a baseline run against a RAG run.

    Only ids present in both runs are compared, and "overlap" counts all of them. Ids
    whose baseline gold is not in LABELS are left out of the four buckets:
      fix          - baseline wrong, RAG right
      hurt         - baseline right, RAG wrong
      same_correct - both right
      same_wrong   - both wrong
    """
    base_by_id = {rec["id"]: rec for rec in base_records}
    rag_by_id = {rec["id"]: rec for rec in rag_records}
    shared_ids = base_by_id.keys() & rag_by_id.keys()

    # (baseline correct, rag correct) -> count
    tally = {(False, True): 0, (True, False): 0, (True, True): 0, (False, False): 0}
    for rid in shared_ids:
        gold = base_by_id[rid]["gold"]
        if gold not in LABELS:
            continue
        outcome = (base_by_id[rid]["pred"] == gold, rag_by_id[rid]["pred"] == gold)
        tally[outcome] += 1

    return {
        "fix": tally[(False, True)],
        "hurt": tally[(True, False)],
        "same_correct": tally[(True, True)],
        "same_wrong": tally[(False, False)],
        "overlap": len(shared_ids),
    }


In [47]:
# =========================
# 10) Run one comparison
# =========================
# Shared RAG settings for this single comparison (the grid search below varies them).
rag_kwargs = dict(
    rag_k=DEFAULT_RAG_K,
    ctx_max_chars=DEFAULT_CTX_MAX_CHARS,
    evid_max_sents=DEFAULT_EVID_MAX_SENTS,
    sim_threshold=DEFAULT_SIM_THRESHOLD,
)

# Save each run as soon as it finishes, so an interrupted cell keeps the completed work.
base_records = eval_mode(samples, "base")
save_jsonl(base_records, RUN_DIR / "base_records.jsonl")
ragq_records = eval_mode(samples, "rag_q", **rag_kwargs)
save_jsonl(ragq_records, RUN_DIR / "ragq_records.jsonl")
rago_records = eval_mode(samples, "rag_opt", **rag_kwargs)
save_jsonl(rago_records, RUN_DIR / "rago_records.jsonl")

base_sum = summarize(base_records)
ragq_sum = summarize(ragq_records)
rago_sum = summarize(rago_records)
# Compute each fix/hurt comparison once and reuse it for printing and for the summary file.
ragq_fh = fix_hurt(base_records, ragq_records)
rago_fh = fix_hurt(base_records, rago_records)

print("Base ACC :", base_sum["acc"])
print("RAG-Q ACC:", ragq_sum["acc"], "Fix/Hurt:", ragq_fh)
print("RAG-O ACC:", rago_sum["acc"], "Fix/Hurt:", rago_fh)

run_summary = {
    "config": {
        "DEV_FILE": DEV_FILE, "KB_DIR": KB_DIR, "MODEL": MODEL_NAME_OR_PATH,
        "EMBED_MODEL": EMBED_MODEL_NAME,
        "DEFAULT_RAG_K": DEFAULT_RAG_K,
        "DEFAULT_CTX_MAX_CHARS": DEFAULT_CTX_MAX_CHARS,
        "DEFAULT_EVID_MAX_SENTS": DEFAULT_EVID_MAX_SENTS,
        "DEFAULT_SIM_THRESHOLD": DEFAULT_SIM_THRESHOLD,
    },
    "base": base_sum,
    "rag_q": ragq_sum,
    "rag_opt": rago_sum,
    "fixhurt_rag_q": ragq_fh,
    "fixhurt_rag_opt": rago_fh,
}
(RUN_DIR / "summary.json").write_text(
    json.dumps(run_summary, ensure_ascii=False, indent=2), encoding="utf-8"
)

print("Saved to:", RUN_DIR)


eval:base:  44%|████▎     | 1823/4183 [00:13<00:16, 139.70it/s]


KeyboardInterrupt: 

In [None]:
# =========================
# 11) Grid search
# =========================
def run_grid(modes=("rag_q","rag_opt"),
             rag_k_list=(1,3,5),
             ctx_list=(300,600),
             evid_sents_list=(4,),
             sim_threshold_list=(None,),
             baseline: Optional[List[dict]] = None,
             save_records: bool = False) -> pd.DataFrame:
    """Evaluate every combination of RAG settings and rank the results.

    Combinations are iterated with mode outermost and sim_threshold innermost. For each
    combination the whole dev set is evaluated and summarized, and compared against
    `baseline` via fix/hurt. A `<key>.summary.json` is written to RUN_DIR, plus
    `<key>.records.jsonl` when `save_records` is True.

    Args:
        baseline: records of the "base" run to compare against. Defaults to the global
            `base_records` from the comparison cell. Passing it explicitly avoids
            depending on notebook state that may be missing or stale, e.g. when that
            cell was interrupted.
        save_records: also dump every configuration's per-sample records.

    Returns:
        One row per combination, sorted by acc then macro_f1 (descending). The table is
        also written to RUN_DIR/grid_results.csv.
    """
    from itertools import product

    if baseline is None:
        baseline = base_records  # NameError here means the comparison cell never completed

    rows = []
    for mode, rag_k, ctx, es, th in product(modes, rag_k_list, ctx_list,
                                            evid_sents_list, sim_threshold_list):
        recs = eval_mode(samples, mode,
                         rag_k=rag_k,
                         ctx_max_chars=ctx,
                         evid_max_sents=es,
                         sim_threshold=th)
        summ = summarize(recs)
        fh = fix_hurt(baseline, recs)
        params = {"mode": mode, "rag_k": rag_k, "ctx_max_chars": ctx,
                  "evid_max_sents": es, "sim_threshold": th}
        rows.append({
            **params,
            "acc": summ["acc"],
            "macro_f1": summ["macro"]["f1"],
            "weighted_f1": summ["weighted"]["f1"],
            "fix": fh["fix"],
            "hurt": fh["hurt"],
            "evaluated": summ["evaluated"],
        })
        key = f"{mode}_k{rag_k}_ctx{ctx}_es{es}_th{th}"
        (RUN_DIR / f"{key}.summary.json").write_text(json.dumps({
            "summary": summ,
            "fixhurt": fh,
            "params": params,
        }, ensure_ascii=False, indent=2), encoding="utf-8")
        if save_records:
            save_jsonl(recs, RUN_DIR / f"{key}.records.jsonl")

    df = pd.DataFrame(rows)
    if not df.empty:  # an empty frame has no "acc"/"macro_f1" columns to sort by
        df = df.sort_values(["acc", "macro_f1"], ascending=False).reset_index(drop=True)
    df.to_csv(RUN_DIR / "grid_results.csv", index=False)
    return df

# Search space for the grid (mode x rag_k x ctx x evidence sentences x threshold).
grid_space = {
    "modes": ("rag_q", "rag_opt"),
    "rag_k_list": (1, 3, 5),
    "ctx_list": (300, 600),
    "evid_sents_list": (4,),
    "sim_threshold_list": (None,),
}
grid_df = run_grid(**grid_space)
grid_df.head(20)


eval:rag_q: 100%|██████████| 4183/4183 [05:00<00:00, 13.90it/s]
eval:rag_q: 100%|██████████| 4183/4183 [05:02<00:00, 13.83it/s]
eval:rag_q: 100%|██████████| 4183/4183 [05:32<00:00, 12.57it/s]
eval:rag_q: 100%|██████████| 4183/4183 [05:32<00:00, 12.58it/s]
eval:rag_q: 100%|██████████| 4183/4183 [05:50<00:00, 11.94it/s]
eval:rag_q: 100%|██████████| 4183/4183 [05:49<00:00, 11.98it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [18:22<00:00,  3.79it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [18:26<00:00,  3.78it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [20:21<00:00,  3.42it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [20:24<00:00,  3.42it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [21:31<00:00,  3.24it/s]
eval:rag_opt: 100%|██████████| 4183/4183 [21:30<00:00,  3.24it/s]


Unnamed: 0,mode,rag_k,ctx_max_chars,evid_max_sents,sim_threshold,acc,macro_f1,weighted_f1,fix,hurt,evaluated
0,rag_opt,3,600,4,,0.322496,0.123023,0.158059,4,2,4183
1,rag_q,1,300,4,,0.322496,0.122442,0.157613,2,0,4183
2,rag_q,1,600,4,,0.322496,0.122442,0.157613,2,0,4183
3,rag_opt,5,600,4,,0.322257,0.122483,0.157582,3,2,4183
4,rag_q,3,300,4,,0.322257,0.121903,0.157136,2,1,4183
5,rag_q,3,600,4,,0.322257,0.121881,0.157107,2,1,4183
6,rag_q,5,300,4,,0.322257,0.121881,0.157107,2,1,4183
7,rag_q,5,600,4,,0.322257,0.121881,0.157107,2,1,4183
8,rag_opt,1,300,4,,0.322257,0.121881,0.157107,2,1,4183
9,rag_opt,3,300,4,,0.322257,0.121881,0.157107,2,1,4183


In [None]:
# =========================
# 12) Inspect retrieval score distribution
# =========================
def sample_top1_scores(n=200, mode="rag_q", seed: Optional[int] = None):
    """Top-1 retrieval score for up to `n` randomly drawn dev samples.

    mode="rag_q" queries with the question alone. Any other mode appends one randomly
    chosen option to the question, mirroring the per-option queries of "rag_opt".

    Args:
        seed: when given, draws come from a private random.Random(seed). The sample is
            then reproducible and the notebook's global RNG is not consumed. None uses
            the global `random` module.

    Returns:
        A float32 array. It may be shorter than n if the dev set is smaller or a query
        returns no hits.
    """
    rng = random if seed is None else random.Random(seed)
    picked = rng.sample(samples, max(0, min(n, len(samples))))
    scores = []
    for s in picked:
        if mode == "rag_q":
            q = s["question"]
        else:
            ch = rng.choice(["A","B","C","D"])
            q = s["question"] + " " + s["options"][ch]
        _, sc = retrieve_topk(q, k=1)
        if sc:
            scores.append(sc[0])
    return np.array(scores, dtype=np.float32)

# Quantiles reported for the top-1 retrieval-score distributions.
SCORE_QUANTILES = [0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95]

arr_q = sample_top1_scores(n=300, mode="rag_q")
arr_o = sample_top1_scores(n=300, mode="rag_opt")

# Empty arrays (no samples / no hits) would make .min()/np.quantile raise, so report them instead.
for label, arr in [("RAG-Q", arr_q), ("RAG-O", arr_o)]:
    if arr.size == 0:
        print(f"{label} top1 score: no retrieval hits")
        continue
    print(f"{label} top1 score: min/mean/max =", float(arr.min()), float(arr.mean()), float(arr.max()))

for name, arr in [("rag_q", arr_q), ("rag_opt", arr_o)]:
    if arr.size == 0:
        continue
    qs = np.quantile(arr, SCORE_QUANTILES)
    print(name, "quantiles:", {str(k): float(v) for k, v in zip(SCORE_QUANTILES, qs)})


RAG-Q top1 score: min/mean/max = 0.8763331770896912 0.908949613571167 0.9389159679412842
RAG-O top1 score: min/mean/max = 0.8507470488548279 0.9064990282058716 0.9410426020622253
rag_q quantiles: {'0.05': 0.888480469584465, '0.1': 0.8951842665672303, '0.25': 0.9004756808280945, '0.5': 0.9088247716426849, '0.75': 0.9172671139240265, '0.9': 0.9245349586009979, '0.95': 0.9296152234077454}
rag_opt quantiles: {'0.05': 0.8872941255569458, '0.1': 0.8918971240520477, '0.25': 0.8988986760377884, '0.5': 0.9067673981189728, '0.75': 0.9144319444894791, '0.9': 0.9226700484752655, '0.95': 0.9257474839687347}


In [34]:
import random

def show_ragq_examples(n=5, rag_k=3, max_sents: int = 4, seed: Optional[int] = None):
    """Print up to `n` random dev questions with their RAG-Q retrieval.

    For each question: the gold label, the top scores, a snippet of the top document
    and the extracted evidence.

    `n` is clipped to the number of samples. `max_sents` is passed to extract_evidence.
    `seed` (optional) makes the draw reproducible without consuming the global RNG.
    """
    rng = random if seed is None else random.Random(seed)
    for s in rng.sample(samples, max(0, min(n, len(samples)))):
        q = s["question"]
        ids, sc = retrieve_topk(q, k=rag_k)
        ev = extract_evidence(q, ids, max_sents=max_sents)
        print("="*120)
        print("Q:", q)
        print("Gold:", s["gold"])
        print("Top scores:", [round(x,4) for x in sc])
        if ids:  # retrieval can come back empty; avoid an IndexError on ids[0]
            print("Top doc[0] snippet:", docs[ids[0]][:300].replace("\n"," "))
        print("Evidence:", ev[:400])

show_ragq_examples(n=5, rag_k=3)  # eyeball five random RAG-Q retrievals

Q: A patient shows one or more of the following: advanced bone loss, grade II and III furcation involvements, tooth mobility, inaccessible areas, systemic/environmental factors represents:
Gold: A
Top scores: [0.8943, 0.8906, 0.8901]
Top doc[0] snippet: Q: Children with apathy, general weakness, loosening of the skin, marasmic features also has X3B Xerophthalmia features. Eye finding will be A: Corneal ulcer with full thickness
Evidence: Q: In osteoporosis there is -a) Decrease in absolute amount of bone massb) More common in malec) Radiographs show normal bone densityd) Hormonal replacement therapy A: ad Q: An 18 year old female patient complains of prominent upper front teeth. Extra-oral examination reveals an acute nasolabial angle and lip strain. Q: Children with apathy, general weakness, loosening of the skin, marasmic featur
Q: Acquired cause of pure red cell aplasia are all except:
Gold: C
Top scores: [0.9157, 0.9152, 0.9151]
Top doc[0] snippet: Q: A patient has long standing se