In [50]:
# =========================
# PubMedQA Evaluation (BASE vs CTX)
# - BASE: question only
# - CTX : question + its own context (open-book "RAG")
# - Scoring: logprob over {yes,no,maybe} (NO format errors)
# =========================

import re, json, time, pickle
from pathlib import Path
from typing import List, Dict, Optional, Any

import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM


In [51]:
# ---- CONFIG ----

TEST_PARQUET = "./data/pubmedqa_hf/pqa_labeled_splits/test.parquet"

# 推荐先用 gpt2 跑通（显存小），llama2 可用 CPU 或量化
GEN_MODEL_PATH = "./gpt2"    # or "./llama2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# context 截断策略
CTX_MAX_CHARS = 700
CTX_N_SENTS = 3

# 评测数量
LIMIT = None  # None=全量（你test只有100），或 50/100

SEED = 2025
torch.manual_seed(SEED)

print("DEVICE:", DEVICE, "MODEL:", GEN_MODEL_PATH)


DEVICE: cuda MODEL: ./gpt2


In [52]:
# ---- Load data ----

test_df = pd.read_parquet(TEST_PARQUET)
print("len(test_df) =", len(test_df))
print("columns:", list(test_df.columns))
print(test_df[["question","final_decision"]].head(3))


len(test_df) = 100
columns: ['pubid', 'question', 'context', 'long_answer', 'final_decision']
                                            question final_decision
0  Malnutrition, a new inducer for arterial calci...            yes
1  Should temperature be monitorized during kidne...             no
2  Screening for gestational diabetes mellitus: a...            yes


In [53]:
# ---- Text helpers ----

def normalize_text(s: str) -> str:
    return " ".join(str(s).strip().split())

_SENT_SPLIT = re.compile(r'(?<=[\.\?\!])\s+')

def first_n_sents(text: str, n=3) -> str:
    text = normalize_text(text)
    if not text:
        return ""
    sents = _SENT_SPLIT.split(text)
    return " ".join(sents[:n])

def row_to_question(row: pd.Series) -> str:
    return normalize_text(row["question"])

def row_to_context(row: pd.Series) -> str:
    ctx = row["context"]
    # HF PubMedQA: context是dict {"contexts":[...]} 或者直接就是dict/str
    if isinstance(ctx, dict):
        contexts = ctx.get("contexts", [])
        if isinstance(contexts, list):
            return "\n".join([normalize_text(x) for x in contexts if x])
        return normalize_text(str(ctx))
    if isinstance(ctx, list):
        return "\n".join([normalize_text(x) for x in ctx if x])
    return normalize_text(str(ctx))

def row_to_gold(row: pd.Series) -> str:
    g = normalize_text(row["final_decision"]).lower()
    assert g in ["yes","no","maybe"], f"bad gold: {g}"
    return g


In [54]:
# ---- Safe load model (GPU fp16 if possible, else CPU) ----

def load_gen_model_safe(model_path: str, prefer_cuda: bool = True):
    tok = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    if prefer_cuda and torch.cuda.is_available():
        try:
            mdl = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
                low_cpu_mem_usage=True,
            ).to("cuda").eval()
            return tok, mdl, "cuda"
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()

    mdl = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float32,
        low_cpu_mem_usage=True,
    ).to("cpu").eval()
    return tok, mdl, "cpu"

gen_tok, gen_mdl, GEN_DEVICE = load_gen_model_safe(GEN_MODEL_PATH, prefer_cuda=(DEVICE=="cuda"))
print("Loaded on:", GEN_DEVICE)


Loaded on: cuda


In [40]:
# ---- Build KB docs from TRAIN ----
# 每条 doc = 一个训练样本的 context（可以加 question 作为索引增强）

kb_docs = []
kb_texts = []
for i in tqdm(range(len(train_df)), desc="build_kb"):
    r = train_df.iloc[i]
    q = row_to_question(r)
    ctx = row_to_context_str(r)
    gold = row_to_gold_decision(r)
    if not q or not ctx:
        continue

    # doc text 用于检索：question + context（更稳）
    doc_text = f"{q}\n{ctx}"
    kb_docs.append({"id": int(i), "q": q, "ctx": ctx, "gold": gold, "text": doc_text})
    kb_texts.append("passage: " + doc_text)

print("kb size:", len(kb_docs))

build_kb: 100%|██████████| 800/800 [00:00<00:00, 3582.96it/s]

kb size: 800





In [55]:
# ---- Prompts ----
# 这里不需要强行让它“只输出一个词”，因为我们不生成；我们算 logprob 三分类。

def build_prompt_base(q: str) -> str:
    return (
        "Answer the question with one of: yes, no, maybe.\n"
        f"Question: {q}\n"
        "Answer:"
    )

def build_prompt_ctx(q: str, ctx: str) -> str:
    return (
        "Use the given context to answer with one of: yes, no, maybe.\n"
        f"Context: {ctx}\n"
        f"Question: {q}\n"
        "Answer:"
    )


In [56]:
# ---- Logprob scoring over choices (yes/no/maybe) ----
# 关键：对每个候选词，计算 logP(candidate_tokens | prompt)

@torch.no_grad()
def score_choice_logprob(prompt: str, choice: str) -> float:
    # 让choice带前置空格更稳定（GPT2/BPE）
    choice = choice if choice.startswith(" ") else (" " + choice)

    prompt_ids = gen_tok(prompt, return_tensors="pt", add_special_tokens=False).input_ids.to(GEN_DEVICE)
    choice_ids = gen_tok(choice, return_tensors="pt", add_special_tokens=False).input_ids.to(GEN_DEVICE)

    # 拼接后做 teacher forcing
    input_ids = torch.cat([prompt_ids, choice_ids], dim=1)
    out = gen_mdl(input_ids=input_ids)
    logits = out.logits  # [1, T, V]

    # choice部分每个token的logprob来自其前一位的logits
    # 例如 choice_ids 的第0个token，其概率来自 logits 在 prompt末位那个位置
    logp = F.log_softmax(logits, dim=-1)

    # 取出对应 token 的 logprob 并求和
    # positions: prompt_len-1 ... prompt_len+choice_len-2 对应预测 choice tokens
    prompt_len = prompt_ids.shape[1]
    choice_len = choice_ids.shape[1]

    total = 0.0
    for j in range(choice_len):
        pos = prompt_len + j - 1
        tok_id = int(choice_ids[0, j].item())
        total += float(logp[0, pos, tok_id].item())
    return total

def predict_decision_logprob(prompt: str) -> Dict[str, float]:
    scores = {
        "yes": score_choice_logprob(prompt, "yes"),
        "no": score_choice_logprob(prompt, "no"),
        "maybe": score_choice_logprob(prompt, "maybe"),
    }
    return scores

def argmax_key(d: Dict[str, float]) -> str:
    return max(d, key=d.get)


In [57]:
# ---- Eval ----

def eval_pubmedqa(df: pd.DataFrame, mode: str, limit: Optional[int] = None):
    rows = df.iloc[:limit] if limit else df
    records = []

    for i in tqdm(range(len(rows)), desc=f"eval:{mode}"):
        r = rows.iloc[i]
        q = row_to_question(r)
        gold = row_to_gold(r)

        if mode == "base":
            prompt = build_prompt_base(q)
            used_ctx = ""
        elif mode == "ctx":
            ctx_full = row_to_context(r)
            ctx = first_n_sents(ctx_full, n=CTX_N_SENTS)[:CTX_MAX_CHARS]
            used_ctx = ctx
            prompt = build_prompt_ctx(q, ctx)
        else:
            raise ValueError("mode must be 'base' or 'ctx'")

        scores = predict_decision_logprob(prompt)
        pred = argmax_key(scores)

        records.append({
            "i": int(i),
            "gold": gold,
            "pred": pred,
            "scores": scores,
            "ctx_used": used_ctx,
            "question": q,
        })
    return records

def decision_report(records: List[dict]) -> dict:
    n = len(records)
    correct = 0
    conf = {}
    for r in records:
        g, p = r["gold"], r["pred"]
        conf[f"{g}->{p}"] = conf.get(f"{g}->{p}", 0) + 1
        if g == p:
            correct += 1
    return {"acc": correct / n if n else 0.0, "support": n, "confusion": conf}

base_recs = eval_pubmedqa(test_df, "base", limit=LIMIT)
ctx_recs  = eval_pubmedqa(test_df, "ctx",  limit=LIMIT)

print("BASE:", decision_report(base_recs))
print("CTX :", decision_report(ctx_recs))

eval:base: 100%|██████████| 100/100 [00:02<00:00, 44.53it/s]
eval:ctx: 100%|██████████| 100/100 [00:02<00:00, 43.56it/s]

BASE: {'acc': 0.6, 'support': 100, 'confusion': {'yes->yes': 59, 'no->yes': 29, 'yes->no': 3, 'maybe->yes': 8, 'no->no': 1}}
CTX : {'acc': 0.61, 'support': 100, 'confusion': {'yes->yes': 57, 'no->yes': 26, 'yes->no': 5, 'maybe->yes': 7, 'no->no': 4, 'maybe->no': 1}}





In [None]:
# ---- Save ----

out_dir = Path("./eval_out")
out_dir.mkdir(parents=True, exist_ok=True)
ts = time.strftime("%Y%m%d_%H%M%S")

base_path = out_dir / f"pubmedqa_base_logprob_{ts}.jsonl"
ctx_path  = out_dir / f"pubmedqa_ctx_logprob_{ts}.jsonl"

with base_path.open("w", encoding="utf-8") as f:
    for r in base_recs:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

with ctx_path.open("w", encoding="utf-8") as f:
    for r in ctx_recs:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("Saved:", base_path)
print("Saved:", ctx_path)


Saved: eval_out/pubmedqa_base_logprob_20251222_150027.jsonl
Saved: eval_out/pubmedqa_ctx_logprob_20251222_150027.jsonl


: 

: 

In [45]:
# ---- Generate ----

@torch.no_grad()
def generate_text(prompt: str, max_new_tokens=128) -> str:
    t = gen_tok(prompt, return_tensors="pt", truncation=True, max_length=2048)
    t = {k:v.to(DEVICE) for k,v in t.items()}
    out = gen_mdl.generate(
        **t,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        num_beams=1,
        pad_token_id=gen_tok.eos_token_id,
        eos_token_id=gen_tok.eos_token_id,
    )
    text = gen_tok.decode(out[0], skip_special_tokens=True)
    # 截掉 prompt 前缀，保留新增部分（更干净）
    if text.startswith(prompt):
        text = text[len(prompt):]
    return text.strip()

In [46]:
# ---- Eval loop: base vs rag ----

def eval_pubmedqa(df: pd.DataFrame, mode: str, limit: Optional[int] = None):
    rows = df.iloc[:limit] if limit else df
    records = []
    for i in tqdm(range(len(rows)), desc=f"eval:{mode}"):
        r = rows.iloc[i]
        q = row_to_question(r)
        gold_dec = row_to_gold_decision(r)
        gold_text = row_to_gold_text(r)

        if mode == "base":
            prompt = build_prompt_base(q)
            evid = ""
        elif mode == "rag":
            hits = retrieve(q, k=RAG_K)
            evid = build_evidence_from_hits(hits, max_chars=CTX_MAX_CHARS)
            prompt = build_prompt_rag(q, evid)
        else:
            raise ValueError("mode must be base or rag")

        pred_text = generate_text(prompt, max_new_tokens=MAX_NEW_TOKENS)
        pred_dec = parse_decision(pred_text)

        records.append({
            "i": int(i),
            "question": q,
            "gold_decision": gold_dec,
            "pred_decision": pred_dec,
            "gold_text": gold_text,
            "pred_text": pred_text,
            "rag_context": evid,
        })
    return records

base_recs = eval_pubmedqa(test_df, "base", limit=LIMIT_TEST)
rag_recs  = eval_pubmedqa(test_df, "rag",  limit=LIMIT_TEST)
len(base_recs), len(rag_recs)

eval:base: 100%|██████████| 100/100 [00:02<00:00, 45.49it/s]
eval:rag: 100%|██████████| 100/100 [00:02<00:00, 33.47it/s]


(100, 100)

In [48]:
# ---- Decision ACC ----

def decision_acc(records):
    ok = 0
    n = 0
    fmt_err = 0
    conf = {}
    for r in records:
        g = r["gold_decision"]
        p = r["pred_decision"]
        if p is None:
            fmt_err += 1
            continue
        n += 1
        conf[f"{g}->{p}"] = conf.get(f"{g}->{p}", 0) + 1
        if g == p:
            ok += 1
    acc = ok / n if n else 0.0
    return {"acc": acc, "support": n, "format_errors": fmt_err, "confusion": conf}

LIMIT_TEST = 100
base_recs = eval_pubmedqa(test_df, "base", limit=LIMIT_TEST)
rag_recs  = eval_pubmedqa(test_df, "rag",  limit=LIMIT_TEST)

print("BASE decision:", decision_acc(base_recs))
print("RAG  decision:", decision_acc(rag_recs))


eval:base: 100%|██████████| 100/100 [00:02<00:00, 45.28it/s]
eval:rag: 100%|██████████| 100/100 [00:02<00:00, 33.83it/s]

BASE decision: {'acc': 0.5714285714285714, 'support': 98, 'format_errors': 2, 'confusion': {'yes->yes': 55, 'no->yes': 28, 'yes->no': 6, 'maybe->yes': 7, 'no->no': 1, 'maybe->no': 1}}
RAG  decision: {'acc': 0.4186046511627907, 'support': 86, 'format_errors': 14, 'confusion': {'no->yes': 9, 'yes->no': 33, 'no->no': 17, 'yes->yes': 19, 'maybe->yes': 2, 'maybe->no': 6}}





In [28]:
print("Columns:", list(test_df.columns))

for col in ["answer", "final_decision", "decision", "label", "gold", "raw"]:
    if col in test_df.columns:
        print(col, "example:", test_df[col].head(3).tolist())

# 看前5条 gold_decision 实际解析结果
for i in range(5):
    r = test_df.iloc[i]
    print(i, "q:", row_to_question(r)[:80])
    print("gold_decision(parsed):", repr(row_to_gold_decision(r)))

Columns: ['pubid', 'question', 'context', 'long_answer', 'final_decision']
final_decision example: ['yes', 'no', 'yes']
0 q: Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
gold_decision(parsed): ''
1 q: Should temperature be monitorized during kidney allograft preservation?
gold_decision(parsed): ''
2 q: Screening for gestational diabetes mellitus: are the criteria proposed by the in
gold_decision(parsed): ''
3 q: Is resected stomach volume related to weight loss after laparoscopic sleeve gast
gold_decision(parsed): ''
4 q: Body perception: do parents, their children, and their children's physicians per
gold_decision(parsed): ''


In [26]:
# ---- BERTScore (optional but matches你之前实验) ----
# 如果没装： pip install bert-score
from bert_score import score as bertscore

def bertscore_mean(records):
    refs = [r["gold_text"] for r in records]
    hyps = [r["pred_text"] for r in records]
    P, R, F1 = bertscore(hyps, refs, lang="en", rescale_with_baseline=True)
    return {"precision": float(P.mean()), "recall": float(R.mean()), "f1": float(F1.mean())}

print("BASE BERTScore:", bertscore_mean(base_recs))
print("RAG  BERTScore:", bertscore_mean(rag_recs))

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


BASE BERTScore: {'precision': -0.12728960812091827, 'recall': 0.0066066887229681015, 'f1': -0.060143403708934784}


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RAG  BERTScore: {'precision': -0.17148424685001373, 'recall': 0.026392295956611633, 'f1': -0.07353765517473221}


In [27]:
# ---- Save results ----
ts = time.strftime("%Y%m%d_%H%M%S")
out_base = Path("./eval_out") / f"pubmedqa_base_{ts}.jsonl"
out_rag  = Path("./eval_out") / f"pubmedqa_rag_{ts}.jsonl"
out_base.parent.mkdir(parents=True, exist_ok=True)

with out_base.open("w", encoding="utf-8") as f:
    for r in base_recs:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")
with out_rag.open("w", encoding="utf-8") as f:
    for r in rag_recs:
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("Saved:", out_base)
print("Saved:", out_rag)

Saved: eval_out/pubmedqa_base_20251222_143546.jsonl
Saved: eval_out/pubmedqa_rag_20251222_143546.jsonl
