In [12]:
import os, re, json, pickle, random, math
from datetime import datetime
from typing import List, Dict, Any, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset

from datasets import load_dataset
from transformers import (
    AutoTokenizer, AutoModel, AutoModelForCausalLM,
    Trainer, TrainingArguments
)

try:
    import faiss
except Exception as e:
    raise RuntimeError("缺少 faiss：pip install faiss-cpu 或 faiss-gpu") from e

import torch

# 给所有 torch 优化器补上 train/eval（某些 transformers 版本会调用）
if not hasattr(torch.optim.Optimizer, "train"):
    torch.optim.Optimizer.train = lambda self: None
if not hasattr(torch.optim.Optimizer, "eval"):
    torch.optim.Optimizer.eval = lambda self: None


# ========= 路径 =========
BASE_DIR = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM"
OUT_DIR  = f"{BASE_DIR}/eval_out/medqa_train_letter_rag"
os.makedirs(OUT_DIR, exist_ok=True)

# 训练输出目录（新模型）
FT_OUT_DIR = f"{BASE_DIR}/out_gpt2_medqa_rag_letter"
os.makedirs(FT_OUT_DIR, exist_ok=True)

# ========= 两个模型（对比用） =========
MODEL_BASE = f"{BASE_DIR}/gpt2"
MODEL_OLD_FT = f"{BASE_DIR}/gpt2-medmcqa-raft-masked"  # 旧的（跨数据集会掉点正常）
# 你将训练的新模型：FT_OUT_DIR

# ========= Embedding 模型（建库/检索一致） =========
EMB_MODEL = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# ========= KB（WikiDoc） =========
KB_NAME  = "medalpaca/medical_meadow_wikidoc"
MAX_DOCS = 200000  # 先 20w（你机器够再加）
SAVE_DIR = f"{BASE_DIR}/rag_cache/medqa_{KB_NAME.replace('/','_')}_{MAX_DOCS}"
os.makedirs(SAVE_DIR, exist_ok=True)
RAG_INDEX_PATH = f"{SAVE_DIR}/kb.index"
RAG_DOCS_PATH  = f"{SAVE_DIR}/docs.pkl"

# ========= 设备/随机种子 =========
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
if torch.cuda.is_available(): torch.cuda.manual_seed_all(SEED)

# ========= Prompt/Token 上限 =========
MAX_INPUT_TOKENS = 512   # 和你之前保持一致
EVID_MAX_CHARS   = 180   # 证据更短更稳
EVID_KEEP_SENTS  = 1

# ========= RAG 参数 =========
RAG_K = 2               # question-only: top2
RAG_TAU = 0.25          # 门控阈值（后续可以扫 0.2/0.25/0.3）
P_RAG_TRAIN = 0.5       # 训练样本中，有证据比例（关键！）

# ========= 训练参数（GPT2 很小，别训太猛） =========
TRAIN_MAX = None        # 想快速试跑：比如 20000
VAL_SIZE = 2000         # 从 train 切出验证集
EPOCHS = 1              # 建议先 1，再看趋势
LR = 5e-5               # 小一点更稳
TRAIN_BS = 8
EVAL_BS  = 16
GRAD_ACC = 4
WARMUP_RATIO = 0.03

CHOICE_LETTERS = ["A","B","C","D"]

print("DEVICE:", DEVICE)
print("OUT_DIR:", OUT_DIR)
print("FT_OUT_DIR:", FT_OUT_DIR)
print("KB:", KB_NAME, "MAX_DOCS:", MAX_DOCS)
print("RAG: K=", RAG_K, "tau=", RAG_TAU, "P_RAG_TRAIN=", P_RAG_TRAIN)


DEVICE: cuda
OUT_DIR: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/eval_out/medqa_train_letter_rag
FT_OUT_DIR: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/out_gpt2_medqa_rag_letter
KB: medalpaca/medical_meadow_wikidoc MAX_DOCS: 200000
RAG: K= 2 tau= 0.25 P_RAG_TRAIN= 0.5


In [13]:
ds = load_dataset("GBaker/MedQA-USMLE-4-options")
print("Splits:", ds.keys())
print("Train example keys:", ds["train"][0].keys())

def _label_from_any(x):
    if x is None: 
        return None
    if isinstance(x, (int, np.integer)):
        v = int(x)
        if 0 <= v <= 3: return v
        if 1 <= v <= 4: return v-1
        return None
    if isinstance(x, str):
        s = x.strip().upper()
        if s in CHOICE_LETTERS: return CHOICE_LETTERS.index(s)
        if s.isdigit():
            v = int(s)
            if 0 <= v <= 3: return v
            if 1 <= v <= 4: return v-1
    return None

def normalize_medqa_example(ex: Dict[str, Any], idx: int) -> Dict[str, Any]:
    q = ex["question"]
    od = ex["options"]
    A,B,C,D = od.get("A",""), od.get("B",""), od.get("C",""), od.get("D","")
    y = _label_from_any(ex.get("answer_idx", None))
    if y is None:
        y = _label_from_any(ex.get("answer", None))
    if y is None:
        raise KeyError(f"无法解析 label：keys={list(ex.keys())[:20]}")
    uid = ex.get("id", ex.get("qid", ex.get("question_id", f"idx-{idx}")))
    return {"id": str(uid), "question": q, "opa": A, "opb": B, "opc": C, "opd": D, "label": int(y)}

# 生成 train/val（MedQA 只有 train/test，这里从 train 切 val）
train_raw = ds["train"]
idxs = list(range(len(train_raw)))
random.shuffle(idxs)

val_idxs = idxs[:VAL_SIZE]
trn_idxs = idxs[VAL_SIZE:]

if TRAIN_MAX is not None:
    trn_idxs = trn_idxs[:TRAIN_MAX]

print("Train size:", len(trn_idxs), "Val size:", len(val_idxs), "Test size:", len(ds["test"]))


Splits: dict_keys(['train', 'test'])
Train example keys: dict_keys(['question', 'answer', 'options', 'meta_info', 'answer_idx', 'metamap_phrases'])
Train size: 8178 Val size: 2000 Test size: 1273


In [14]:
_SENT_SPLIT = re.compile(r'(?<=[\.\?\!])\s+|\n+')

def _normalize(text: str) -> str:
    return re.sub(r"\s+", " ", (text or "").strip())

def _tokenize(text: str):
    text = (text or "").lower()
    text = re.sub(r"[^a-z0-9\s\-]", " ", text)
    return [t for t in text.split() if len(t) >= 3]

def _jaccard(a, b):
    sa, sb = set(a), set(b)
    if not sa or not sb: return 0.0
    return len(sa & sb) / max(1, len(sa | sb))

def compress_context_by_overlap(question: str, options: List[str], context: str,
                                keep_sents: int=1, max_chars: int=180, min_sent_chars: int=30) -> str:
    ctx = _normalize(context)
    if not ctx: return ""

    query = _normalize(question) + " " + " ".join([_normalize(x) for x in options if x])
    q_toks = _tokenize(query)
    if not q_toks:
        return ctx[:max_chars] if len(ctx) > max_chars else ctx

    sents = [s.strip() for s in _SENT_SPLIT.split(ctx) if s and len(s.strip()) >= min_sent_chars]
    if not sents:
        return ctx[:max_chars] if len(ctx) > max_chars else ctx

    scored = []
    for s in sents:
        s_toks = _tokenize(s)
        if not s_toks:
            continue
        jac = _jaccard(q_toks, s_toks)
        overlap = len(set(q_toks) & set(s_toks))
        score = jac * 2.0 + overlap * 0.05
        scored.append((score, s))

    if not scored:
        return ctx[:max_chars] if len(ctx) > max_chars else ctx

    scored.sort(key=lambda x: x[0], reverse=True)
    picked = [s for _, s in scored[:keep_sents]]

    # 保持原顺序
    order = {s: i for i, s in enumerate(sents)}
    picked.sort(key=lambda s: order.get(s, 10**9))

    evidence = _normalize(" ".join(picked))
    return evidence[:max_chars] if len(evidence) > max_chars else evidence

def extract_text_from_kb_example(ex: Dict[str, Any]) -> str:
    # 尽量从常见字段取正文
    for k in ["text","content","passage","snippet","abstract","output"]:
        if k in ex and isinstance(ex[k], str) and ex[k].strip():
            return ex[k].strip()
    parts = []
    for k in ["title","instruction","input","question"]:
        if k in ex and isinstance(ex[k], str) and ex[k].strip():
            parts.append(ex[k].strip())
    for k in ["output","answer","context"]:
        if k in ex and isinstance(ex[k], str) and ex[k].strip():
            parts.append(ex[k].strip())
    return _normalize(" ".join(parts))


In [15]:
def build_kb_if_needed():
    if os.path.exists(RAG_INDEX_PATH) and os.path.exists(RAG_DOCS_PATH):
        print("发现 KB 缓存，跳过建库。")
        return

    print("开始建 KB（首次会较久）...")
    emb_tok = AutoTokenizer.from_pretrained(EMB_MODEL, use_fast=True)
    emb_model = AutoModel.from_pretrained(EMB_MODEL).to(DEVICE)
    emb_model.eval()

    @torch.no_grad()
    def encode_texts(texts: List[str], batch_size: int=64) -> np.ndarray:
        all_vecs = []
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i+batch_size]
            enc = emb_tok(batch, padding=True, truncation=True, max_length=256, return_tensors="pt")
            enc = {k: v.to(DEVICE) for k, v in enc.items()}
            out = emb_model(**enc).last_hidden_state
            mask = enc["attention_mask"].unsqueeze(-1).float()
            pooled = (out * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1e-6)
            vec = pooled.detach().cpu().numpy().astype("float32")
            vec /= (np.linalg.norm(vec, axis=1, keepdims=True) + 1e-12)
            all_vecs.append(vec)
        return np.concatenate(all_vecs, axis=0)

    dim = 768
    index = faiss.IndexFlatIP(dim)
    docs_text = []
    buf = []
    BATCH = 64

    stream_ds = load_dataset(KB_NAME, split="train", streaming=True)

    cnt = 0
    for ex in stream_ds:
        text = extract_text_from_kb_example(ex)
        text = _normalize(text)
        if not text:
            continue
        docs_text.append(text)
        buf.append(text)
        cnt += 1

        if len(buf) >= BATCH:
            vecs = encode_texts(buf, batch_size=BATCH)
            index.add(vecs)
            buf = []

        if cnt >= MAX_DOCS:
            break

    if buf:
        vecs = encode_texts(buf, batch_size=BATCH)
        index.add(vecs)

    print("Built docs:", len(docs_text), "FAISS ntotal:", index.ntotal)

    faiss.write_index(index, RAG_INDEX_PATH)
    with open(RAG_DOCS_PATH, "wb") as f:
        pickle.dump(docs_text, f)

    print("Saved index:", RAG_INDEX_PATH)
    print("Saved docs :", RAG_DOCS_PATH)

build_kb_if_needed()

# --- 加载 KB ---
index = faiss.read_index(RAG_INDEX_PATH)
with open(RAG_DOCS_PATH, "rb") as f:
    docs_text = pickle.load(f)
print("KB loaded. docs:", len(docs_text), "dim:", index.d)

# --- 加载 embedding 模型（用于检索 query）---
emb_tok = AutoTokenizer.from_pretrained(EMB_MODEL, use_fast=True)
emb_model = AutoModel.from_pretrained(EMB_MODEL).to(DEVICE)
emb_model.eval()

@torch.no_grad()
def encode_query(text: str) -> np.ndarray:
    enc = emb_tok([text], padding=True, truncation=True, max_length=256, return_tensors="pt")
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    out = emb_model(**enc).last_hidden_state
    mask = enc["attention_mask"].unsqueeze(-1).float()
    pooled = (out * mask).sum(dim=1) / torch.clamp(mask.sum(dim=1), min=1e-6)
    vec = pooled.detach().cpu().numpy().astype("float32")
    vec /= (np.linalg.norm(vec, axis=1, keepdims=True) + 1e-12)
    return vec

def retrieve_topk_with_scores(query: str, k: int=2):
    qv = encode_query(query)
    D, I = index.search(qv, k)
    out = []
    for s, j in zip(D[0].tolist(), I[0].tolist()):
        if 0 <= j < len(docs_text):
            out.append((float(s), docs_text[j]))
    return out


发现 KB 缓存，跳过建库。
KB loaded. docs: 10000 dim: 768


In [16]:
# 简单缓存：训练集很大，避免同一题多次 encode_query（尤其你调参时）
_evidence_cache = {}

def get_rag_evidence(ex: Dict[str, Any]) -> str:
    q = _normalize(ex["question"])
    if q in _evidence_cache:
        return _evidence_cache[q]

    opts = [ex["opa"], ex["opb"], ex["opc"], ex["opd"]]
    got = retrieve_topk_with_scores(q, k=RAG_K)

    lines = []
    for score, doc in got:
        if score < RAG_TAU:
            continue
        raw = doc[:1200]
        short = compress_context_by_overlap(q, opts, raw, keep_sents=EVID_KEEP_SENTS, max_chars=EVID_MAX_CHARS)
        if short:
            lines.append(short)

    # 最多拼 1~2 条（这里建议 1 条更稳）
    ev = _normalize(" ".join(lines[:1]))
    _evidence_cache[q] = ev
    return ev

def build_prompt_base_letter(ex: Dict[str, Any]) -> str:
    q = ex["question"]
    A,B,C,D = ex["opa"], ex["opb"], ex["opc"], ex["opd"]
    return (
        "You are a medical exam solver. Choose the single best option.\n"
        "Reply with ONLY one letter: A, B, C, or D.\n\n"
        f"Question:\n{q}\n\n"
        f"Options:\nA) {A}\nB) {B}\nC) {C}\nD) {D}\n\n"
        "Answer (A, B, C, or D):"
    )

def build_prompt_rag_letter(ex: Dict[str, Any], tok) -> Tuple[str, str]:
    base = build_prompt_base_letter(ex)
    cut = "Answer (A, B, C, or D):"
    assert base.endswith(cut)
    prefix = base[:-len(cut)]

    ev = get_rag_evidence(ex)
    evidence = ""
    if ev:
        # 控制 evidence token budget：保证不会挤爆 MAX_INPUT_TOKENS
        ev_prefix = "Evidence:\n"
        base_len = len(tok(base, add_special_tokens=False).input_ids)
        overhead = len(tok(ev_prefix + "\n", add_special_tokens=False).input_ids)
        budget = max(0, MAX_INPUT_TOKENS - base_len - overhead - 8)
        if budget > 0:
            evid_ids = tok(ev, add_special_tokens=False).input_ids[:budget]
            evidence = tok.decode(evid_ids)

    ev_block = (f"Evidence:\n{evidence}\n\n" if evidence else "")
    prompt = prefix + ev_block + cut
    return prompt, evidence


In [17]:
def load_reader_tokenizer(model_path: str):
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.truncation_side = "left"
    tok.padding_side = "left"
    return tok

tok_train = load_reader_tokenizer(MODEL_BASE)

def pack_one_example(ex: Dict[str, Any], use_rag: bool, tok) -> Dict[str, Any]:
    """
    生成 (input_ids, attention_mask, labels)：
    - prompt 部分 labels = -100（不算 loss）
    - answer 字母部分 labels = token ids（算 loss）
    """
    if not use_rag:
        prompt = build_prompt_base_letter(ex)
        evidence = ""
    else:
        prompt, evidence = build_prompt_rag_letter(ex, tok)

    # 目标：只输出字母（训练时给 " A" 这种形式更稳）
    y = ex["label"]
    target = " " + CHOICE_LETTERS[y]

    prompt_ids = tok(prompt, add_special_tokens=False).input_ids
    target_ids = tok(target, add_special_tokens=False).input_ids

    # 保证答案不被截掉：只截 prompt（左截）
    max_prompt_len = MAX_INPUT_TOKENS - len(target_ids)
    if max_prompt_len < 1:
        # 极端情况：目标都放不下
        prompt_ids = prompt_ids[-1:]
        max_prompt_len = MAX_INPUT_TOKENS - len(target_ids)

    if len(prompt_ids) > max_prompt_len:
        prompt_ids = prompt_ids[-max_prompt_len:]

    input_ids = prompt_ids + target_ids
    labels = [-100]*len(prompt_ids) + target_ids
    attention_mask = [1]*len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "use_rag": int(use_rag),
        "evidence_preview": evidence[:120] if evidence else ""
    }

class MedQARagTrainDataset(Dataset):
    def __init__(self, raw_ds, indices: List[int], tok, p_rag: float):
        self.raw_ds = raw_ds
        self.indices = indices
        self.tok = tok
        self.p_rag = p_rag

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, i):
        idx = self.indices[i]
        ex = normalize_medqa_example(self.raw_ds[idx], idx=idx)
        use_rag = (random.random() < self.p_rag)
        item = pack_one_example(ex, use_rag=use_rag, tok=self.tok)
        return item

def collate_fn(batch, pad_token_id: int):
    # 左 padding（与 tok.padding_side 一致）
    max_len = max(len(x["input_ids"]) for x in batch)
    input_ids, attn, labels = [], [], []
    use_rag = []
    for x in batch:
        L = len(x["input_ids"])
        pad = max_len - L
        input_ids.append([pad_token_id]*pad + x["input_ids"])
        attn.append([0]*pad + x["attention_mask"])
        labels.append([-100]*pad + x["labels"])
        use_rag.append(x.get("use_rag", 0))
    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attn, dtype=torch.long),
        "labels": torch.tensor(labels, dtype=torch.long),
        "use_rag": torch.tensor(use_rag, dtype=torch.long),
    }

train_ds = MedQARagTrainDataset(train_raw, trn_idxs, tok_train, p_rag=P_RAG_TRAIN)
val_ds   = MedQARagTrainDataset(train_raw, val_idxs, tok_train, p_rag=P_RAG_TRAIN)

print("train_ds:", len(train_ds), "val_ds:", len(val_ds))


train_ds: 8178 val_ds: 2000


In [19]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_BASE,
    torch_dtype=torch.float32,
    device_map=None
)
if torch.cuda.is_available():
    model.to("cuda")
model.train()

args = TrainingArguments(
    output_dir=FT_OUT_DIR,
    overwrite_output_dir=True,
    num_train_epochs=EPOCHS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    per_device_train_batch_size=TRAIN_BS,
    per_device_eval_batch_size=EVAL_BS,
    gradient_accumulation_steps=GRAD_ACC,

    eval_strategy="steps",
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=2,

    logging_steps=50,
    report_to=[],
    dataloader_num_workers=0,

    fp16=False,
    bf16=False,

    max_grad_norm=1.0,
    weight_decay=0.01,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=lambda b: collate_fn(b, pad_token_id=tok_train.pad_token_id),
)

trainer.train()
trainer.save_model(FT_OUT_DIR)
tok_train.save_pretrained(FT_OUT_DIR)
print("Saved to:", FT_OUT_DIR)


  self.scaler = torch.cuda.amp.GradScaler(**kwargs)


Step,Training Loss,Validation Loss


Saved to: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/out_gpt2_medqa_rag_letter


In [20]:
def load_reader(model_path: str):
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.truncation_side = "left"
    tok.padding_side = "left"
    m = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto" if torch.cuda.is_available() else None
    )
    m.eval()
    return tok, m

@torch.no_grad()
def score_continuation(tok, model, prompt_ids, cont_ids):
    input_ids = torch.cat([prompt_ids, cont_ids], dim=1)
    out = model(input_ids=input_ids)
    logits = out.logits
    prompt_len = prompt_ids.size(1)
    lp = 0.0
    for i in range(cont_ids.size(1)):
        pos = prompt_len + i - 1
        logp = F.log_softmax(logits[0, pos], dim=-1)
        tid = cont_ids[0, i].item()
        lp += float(logp[tid].item())
    return lp

@torch.no_grad()
def predict_by_letter_logits(tok, model, prompt: str):
    enc = tok(prompt, return_tensors="pt", truncation=True, max_length=MAX_INPUT_TOKENS)
    prompt_ids = enc["input_ids"].to(model.device)

    scores = []
    for L in ["A","B","C","D"]:
        c1 = tok.encode(" " + L, add_special_tokens=False)
        c2 = tok.encode(L, add_special_tokens=False)
        t1 = torch.tensor([c1], device=model.device, dtype=torch.long) if c1 else None
        t2 = torch.tensor([c2], device=model.device, dtype=torch.long) if c2 else None
        s1 = score_continuation(tok, model, prompt_ids, t1) if t1 is not None else -1e9
        s2 = score_continuation(tok, model, prompt_ids, t2) if t2 is not None else -1e9
        scores.append(max(s1, s2))
    pred = int(np.argmax(scores))
    return pred, scores

def run_eval_letter(model_path: str, name: str, use_rag: bool, limit=None):
    tok, model = load_reader(model_path)
    test_ds = ds["test"]
    n = len(test_ds) if limit is None else min(limit, len(test_ds))

    total=0; correct=0; trunc=0; lens=[]; ev_nonempty=0

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_jsonl = os.path.join(
        OUT_DIR,
        f"{name}_test_{'rag' if use_rag else 'no_rag'}_letter_{ts}.jsonl"
    )

    with open(out_jsonl, "w", encoding="utf-8") as wf:
        for i in range(n):
            ex = normalize_medqa_example(test_ds[i], idx=i)

            if not use_rag:
                prompt = build_prompt_base_letter(ex)
                evidence = ""
            else:
                prompt, evidence = build_prompt_rag_letter(ex, tok)

            L = len(tok(prompt, add_special_tokens=False).input_ids)
            lens.append(L)
            if L > MAX_INPUT_TOKENS: trunc += 1
            if (evidence or "").strip(): ev_nonempty += 1

            pred, scores = predict_by_letter_logits(tok, model, prompt)
            ok = (pred == ex["label"])

            total += 1
            correct += int(ok)

            wf.write(json.dumps({
                "idx": i,
                "gt": ex["label"],
                "pred": pred,
                "is_correct": bool(ok),
                "use_rag": bool(use_rag),
                "prompt_len": L,
                "evidence_preview": (evidence[:200] if evidence else ""),
                "scores": scores,
            }, ensure_ascii=False) + "\n")

    acc = correct/total
    trunc_rate = trunc/total
    arr = np.array(lens)
    stats = {"mean": float(arr.mean()), "p50": float(np.percentile(arr,50)), "p90": float(np.percentile(arr,90)), "p99": float(np.percentile(arr,99)), "max": int(arr.max())}

    print(f"[{name} | test | {'RAG' if use_rag else 'NO-RAG'} | letter_logits]")
    print(f"total={total} acc={acc:.4f} trunc_rate={trunc_rate:.4f} ev_nonempty={ev_nonempty}/{total}")
    print("prompt_len_stats:", stats)
    print(" ->", out_jsonl)
    return acc

# ========== 4 组对比 ==========
acc_base_no = run_eval_letter(MODEL_BASE, "gpt2_base", use_rag=False)
acc_base_rg = run_eval_letter(MODEL_BASE, "gpt2_base", use_rag=True)

acc_new_no  = run_eval_letter(FT_OUT_DIR, "gpt2_medqa_ft", use_rag=False)
acc_new_rg  = run_eval_letter(FT_OUT_DIR, "gpt2_medqa_ft", use_rag=True)

print("\n=== SUMMARY ===")
print("base  no-rag:", acc_base_no, "rag:", acc_base_rg, "gain:", acc_base_rg-acc_base_no)
print("newFT no-rag:", acc_new_no,  "rag:", acc_new_rg,  "gain:", acc_new_rg-acc_new_no)


[gpt2_base | test | NO-RAG | letter_logits]
total=1273 acc=0.2773 trunc_rate=0.0086 ev_nonempty=0/1273
prompt_len_stats: {'mean': 261.9858601728201, 'p50': 250.0, 'p90': 363.0, 'p99': 499.51999999999975, 'max': 945}
 -> /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/eval_out/medqa_train_letter_rag/gpt2_base_test_no_rag_letter_20251216_175526.jsonl
[gpt2_base | test | RAG | letter_logits]
total=1273 acc=0.2773 trunc_rate=0.0086 ev_nonempty=1260/1273
prompt_len_stats: {'mean': 296.53731343283584, 'p50': 284.0, 'p90': 399.0, 'p99': 506.0, 'max': 945}
 -> /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/eval_out/medqa_train_letter_rag/gpt2_base_test_rag_letter_20251216_175637.jsonl
[gpt2_medqa_ft | test | NO-RAG | letter_logits]
total=1273 acc=0.2584 trunc_rate=0.0086 ev_nonempty=0/1273
prompt_len_stats: {'mean': 261.9858601728201, 'p50': 250.0, 'p90': 363.0, 'p99': 499.51999999999975, 'max': 945}
 -> /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/eval_out/medqa_train