In [None]:
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import evaluate
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig



class Config:
    """Static settings for the closed-book vs. open-book (RAG) comparison run."""

    # PubMedQA test split (parquet) and local Llama-2 checkpoint directory.
    parquet_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet"
    model_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2" 
    
    # Parameters
    limit = 100               # number of rows kept after loading (passed to DataFrame.head)
    use_chat_template = True  # build_prompt tries tok.apply_chat_template when this is True
    max_ctx_chars = 4000      # character cap applied to the joined context passages
    seed = 2025               # seed handed to set_seed()

args = Config()

# seed
def set_seed(seed):
    """Seed NumPy and PyTorch (CPU, plus every CUDA device when available)."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

set_seed(args.seed)


# 1.build Prompt

def join_context(c, max_chars):
    """Join the passages stored under ``c["contexts"]`` into one string.

    Args:
        c: Mapping with a ``"contexts"`` entry (list, tuple or NumPy array of
            passages). ``None`` or any object without ``.get`` yields ``""``.
        max_chars: Maximum number of characters kept from the joined text.

    Returns:
        The passages joined by single spaces and truncated to ``max_chars``,
        or ``""`` when no passages are available.
    """
    try:
        contexts = c.get("contexts")
    except AttributeError:
        # None, NaN or any other non-mapping value: no context available.
        return ""
    if contexts is None:
        return ""
    if isinstance(contexts, str):
        # A bare string is one passage, not a sequence of characters.
        contexts = [contexts]
    try:
        # Avoid truth-testing `contexts`: a NumPy array with more than one
        # element raises ValueError on bool(), which previously got swallowed
        # and silently produced an empty context.
        passages = [str(p) for p in contexts]
    except TypeError:
        return ""
    return " ".join(passages)[:max_chars]

def build_prompt(tok, q, ctx, use_rag=False):
    """Build the generation prompt for one question.

    Args:
        tok: Tokenizer; its ``apply_chat_template`` is used when available and
            ``args.use_chat_template`` is true.
        q: Question text.
        ctx: Retrieved context text; only inserted when ``use_rag`` is true.
        use_rag: ``True`` for the open-book prompt (context included),
            ``False`` for the closed-book prompt (question only).

    Returns:
        The prompt string, either rendered by the tokenizer's chat template or
        in the hand-written ``[INST] <<SYS>>`` fallback format.
    """
    import warnings

    # System Prompt
    system_msg = (
        "You are a biomedical expert. Answer the question directly and concisely. "
        "Do NOT start with 'Thank you'. Just provide the medical answer."
    )

    if use_rag:
        # With RAG
        user_content = (
            f"Retrieved Documents:\n{ctx}\n\n"
            f"Question: {q}\n\n"
            "Answer the question based strictly on the retrieved documents above."
        )
    else:
        # No RAG
        user_content = (
            f"Question: {q}\n\n"
            "Answer the question to the best of your knowledge. If you don't know, make an educated guess."
        )

    # Chat Template
    if args.use_chat_template and hasattr(tok, "apply_chat_template"):
        msgs = [
            {"role": "system", "content": system_msg},
            {"role": "user", "content": user_content}
        ]
        try:
            return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        except Exception as err:
            # Deliberate best-effort fallback, but no longer silent and no
            # longer catching KeyboardInterrupt/SystemExit like a bare except.
            warnings.warn(f"apply_chat_template failed ({err!r}); using manual [INST] prompt.")

    # Fallback
    return f"[INST] <<SYS>>\n{system_msg}\n<</SYS>>\n\n{user_content} [/INST]"

def clean_output(full_text, prompt_len):
    """Return the text after the first ``prompt_len`` characters, cleaned up.

    Any ``[/INST]`` marker is removed and surrounding whitespace stripped; the
    literal string ``"Error"`` is returned when nothing remains.
    """
    completion = full_text[prompt_len:].strip().replace("[/INST]", "").strip()
    if not completion:
        return "Error"
    return completion


# 2. load data

# Produces three parallel lists used by the rest of the cell:
# `questions`, `ctx_list` (joined context text) and `refs` (reference answers).
print(f"[Info] Reading data: {args.parquet_path}")

if not os.path.exists(args.parquet_path):
    # No dataset on disk: fall back to two mock samples repeated 25 times
    # (50 entries), then cut to args.limit.
    print("[Warning] Data file not found. Using mock data for demonstration.")
    questions = ["Does study X show that drug A is better than B?", "What is the primary outcome of this specific trial?"] * 25
    ctx_list = ["Study X shows Drug A reduced symptoms by 50% compared to B.", "The primary outcome was reduction in mortality."] * 25
    refs = ["Yes, drug A is better.", "Reduction in mortality."] * 25
    questions, ctx_list, refs = questions[:args.limit], ctx_list[:args.limit], refs[:args.limit]
else:
    try:
        tbl = pq.read_table(args.parquet_path)
        # dropna() without arguments removes a row if ANY column is missing.
        df = tbl.to_pandas().dropna().head(args.limit)
        
        questions = df["question"].tolist()
        ctx_list = df["context"].map(lambda c: join_context(c, args.max_ctx_chars)).tolist()
        
        # Reference column preference: long_answer, then final_decision,
        # then whatever the last column of the table is.
        target_col = "long_answer" if "long_answer" in df.columns else "final_decision"
        if target_col not in df.columns: target_col = df.columns[-1]
        refs = df[target_col].tolist()
        
        print(f"[Info] Data loaded successfully. Count: {len(questions)}")
    except Exception as e:
        # Any read/parse failure aborts the run.
        print(f"[Error] Failed to read parquet file: {e}")
        exit()


# 3. 4-bit

print(f"[Info] Loading model: {args.model_path}")

# Try the files already on disk first; retry without local_files_only on failure.
try:
    tok = AutoTokenizer.from_pretrained(args.model_path, use_fast=True, local_files_only=True)
except Exception as e:
    # `Exception` rather than a bare except, so Ctrl-C / kernel interrupts
    # are not mistaken for a missing tokenizer.
    print("[Info] Local tokenizer not found, downloading...")
    print(f"[Info] Local load error was: {e}")
    tok = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

# Reuse EOS as the padding token when the tokenizer defines none.
if tok.pad_token is None: tok.pad_token = tok.eos_token
tok.padding_side = "right"

# 4-bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

try:
    print("[Info] Enabling 4-bit quantization for speed optimization...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
except Exception as e:
    # Without a model nothing below can run, so report and stop.
    print(f"[Error] Model loading failed: {e}")
    print("Please ensure 'bitsandbytes' is installed: pip install bitsandbytes accelerate")
    exit()



# Each run writes into its own timestamped output directory.
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
run_dir = os.path.join("eval_out_notebook", f"run_{time_str}_fixed_comparison")
os.makedirs(run_dir, exist_ok=True)

# Generation settings are the same for every prompt, so build them once
# instead of on every loop iteration.
gen_kwargs = {
    "max_new_tokens": 128,
    "pad_token_id": tok.pad_token_id,
    "eos_token_id": tok.eos_token_id,
    "do_sample": False,
    "repetition_penalty": 1.1
}


def _generate_answer(prompt):
    """Generate a completion for ``prompt`` and return only the cleaned new text."""
    inputs = tok(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, **gen_kwargs)
    # Cut the prompt off at the token level. Decoding the prompt tokens on
    # their own and using that string's length as a character offset is
    # fragile: decode(prefix) is not guaranteed to be an exact prefix of
    # decode(whole sequence).
    new_tokens = out[0][inputs["input_ids"].shape[1]:]
    return clean_output(tok.decode(new_tokens, skip_special_tokens=True), 0)


preds_no = []
preds_rag = []

print(f"\n[Info] Starting comparison experiment (Limit={len(questions)})...")
print(" - No RAG: Closed Book (No context provided)")
print(" - With RAG: Open Book (Context provided)")

for q, ctx in tqdm(zip(questions, ctx_list), total=len(questions), desc="Inference"):
    # A. No RAG (question only)
    preds_no.append(_generate_answer(build_prompt(tok, q, ctx, use_rag=False)))
    # B. With RAG (question + context)
    preds_rag.append(_generate_answer(build_prompt(tok, q, ctx, use_rag=True)))



print("\n[Info] Calculating ROUGE scores...")
rouge_ok = True
try:
    rouge = evaluate.load("rouge")
    # use_aggregator=False returns one score per sample instead of a mean.
    scores_no = rouge.compute(predictions=preds_no, references=refs, use_aggregator=False)["rougeLsum"]
    scores_rag = rouge.compute(predictions=preds_rag, references=refs, use_aggregator=False)["rougeLsum"]

    avg_no = np.mean(scores_no)
    avg_rag = np.mean(scores_rag)

    print("\n" + "="*40)
    print("FINAL RESULTS SUMMARY")
    print("="*40)
    print(f"No RAG (Closed Book): {avg_no:.4f}")
    print(f"With RAG (Open Book): {avg_rag:.4f}")
    # ':+' prints the real sign; a hard-coded '+' rendered losses as '+-0.0123'.
    print(f"Net Gain            : {(avg_rag - avg_no):+.4f}")
    print("="*40 + "\n")
except Exception as e:
    print(f"[Error] ROUGE calculation failed: {e}")
    rouge_ok = False
    # NaN marks "not computed". Zeros would look like genuine (terrible)
    # scores in the CSV and make the ranking below meaningless.
    scores_no = [float("nan")] * len(questions)
    scores_rag = [float("nan")] * len(questions)

# save to csv (predictions are kept even when scoring failed)
df_res = pd.DataFrame({
    "question": questions,
    "ref": refs,
    "pred_no_rag": preds_no,
    "pred_with_rag": preds_rag,
    "score_no": scores_no,
    "score_rag": scores_rag,
    "diff": np.array(scores_rag) - np.array(scores_no)
})

csv_path = os.path.join(run_dir, "results_fixed.csv")
df_res.to_csv(csv_path, index=False)
print(f"[Info] Detailed results saved to: {csv_path}")

# print sample
if not rouge_ok:
    print("\n[Info] Skipping 'Most Improved' ranking because ROUGE scores are unavailable.")
elif len(df_res) > 0:
    print("\n[Info] Top 3 Most Improved Cases:")
    top_improved = df_res.sort_values(by="diff", ascending=False).head(3)
    for idx, row in top_improved.iterrows():
        print(f"\n[Case {idx}] Score Diff: {row['diff']:+.4f}")
        print(f"Question: {row['question']}")
        print(f"No RAG Answer: {str(row['pred_no_rag'])[:100]}...")
        print(f"With RAG Answer: {str(row['pred_with_rag'])[:100]}...")
        print("-" * 50)

[Info] Reading data: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet
[Info] Data loaded successfully. Count: 100
[Info] Loading model: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2
[Info] Enabling 4-bit quantization for speed optimization...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]


[Info] Starting comparison experiment (Limit=100)...
 - No RAG: Closed Book (No context provided)
 - With RAG: Open Book (Context provided)


Inference: 100%|██████████| 100/100 [24:46<00:00, 14.86s/it]



[Info] Calculating ROUGE scores...


Using the latest cached version of the module from /home/miaoen/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--rouge/6e5315f72865c2eaa764c8361360bb938740b9c120a2cf3a7ad218aa0ce452ed (last modified on Sat Nov 15 19:23:49 2025) since it couldn't be found locally at evaluate-metric--rouge, or remotely on the Hugging Face Hub.


[Error] ROUGE calculation failed: No module named 'absl'
[Info] Detailed results saved to: eval_out_notebook/run_20251222_164019_fixed_comparison/results_fixed.csv

[Info] Top 3 Most Improved Cases:

[Case 99] Score Diff: +0.0000
Question: Does the treatment of amblyopia normalise subfoveal choroidal thickness in amblyopic children?
No RAG Answer: No, the treatment of amblyopia does not necessarily normalize subfoveal choroidal thickness in ambly...
With RAG Answer: According to the retrieved documents, the answer is no. The study found that amblyopia treatment did...
--------------------------------------------------

[Case 0] Score Diff: +0.0000
Question: Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
No RAG Answer: Malnutrition has been identified as a potential inducer of arterial calcification in hemodialysis pa...
With RAG Answer: According to the retrieved documents, there is evidence to suggest that malnutrition may be a new in...
-------------

LLAMA2 BERTScore

In [None]:
import os
import pandas as pd
import numpy as np
import evaluate
import torch


# 1. Configure Path
#   Absolute path to the results CSV written by the comparison run
#   (it must contain the columns pred_no_rag, pred_with_rag and ref).
csv_path = r"C:\LLM\PrimeKG\eval_out_notebook\run_20251210_133244_fixed_comparison\results_fixed.csv"


# 2. Load Data
print(f"[Info] Reading file: {csv_path}")

if not os.path.exists(csv_path):
    print(f"[Error] File not found. Please check the path: {csv_path}")
    exit()

try:
    df = pd.read_csv(csv_path)
    print(f"[Success] Data loaded successfully. Total samples: {len(df)}")
except Exception as e:
    print(f"[Error] Failed to read CSV: {e}")
    exit()

# Extract key columns (prevent null values, convert to string)
# Note: only these lists are cleaned; `df` itself keeps any NaN cells.
preds_no = df["pred_no_rag"].fillna("").astype(str).tolist()
preds_rag = df["pred_with_rag"].fillna("").astype(str).tolist()
refs = df["ref"].fillna("").astype(str).tolist()


# 3. Load BERTScore Model
print("\n[Info] Loading BERTScore model (First time run will download roberta-large, ~1.4GB)...")
try:
    bertscore = evaluate.load("bertscore")
except Exception as e:
    print(f"[Error] Loading failed: {e}")
    print("Please check network connection, or run: pip install bert_score")
    exit()

# Use GPU acceleration
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"[Info] Compute device: {device}")


# 4. Start Calculation
# batch_size=32 fits RTX 3060 (12G), adjust to 16 or 8 if VRAM overflows
print("\n[Status] Calculating BERTScore for No RAG (1/2)...")
res_no = bertscore.compute(predictions=preds_no, references=refs, lang="en", device=device, batch_size=32)
f1_no = np.array(res_no['f1'])

print("[Status] Calculating BERTScore for With RAG (2/2)...")
res_rag = bertscore.compute(predictions=preds_rag, references=refs, lang="en", device=device, batch_size=32)
f1_rag = np.array(res_rag['f1'])


# 5. Results Summary
avg_no = f1_no.mean()
avg_rag = f1_rag.mean()

print(f"\n{'='*50}")
print(f"BERTScore Evaluation Results (F1 Score)")
print(f"{'='*50}")
print(f"   (1.0 = Perfect semantic match, 0.0 = Completely irrelevant)")
print(f"{'-'*50}")
print(f"No RAG (Closed Book) : {avg_no:.5f}")
print(f"With RAG (Open Book) : {avg_rag:.5f}")
# ':+' prints the real sign; a hard-coded '+' rendered a loss as '+-0.00059'.
print(f"Net Semantic Gain    : {(avg_rag - avg_no):+.5f}")
print(f"{'='*50}\n")


# 6. Case Analysis and Saving
# Merge scores back into DataFrame
df["bert_score_no"] = f1_no
df["bert_score_rag"] = f1_rag
df["bert_diff"] = f1_rag - f1_no

# Find Top 3 cases with highest semantic improvement
top_improved = df.sort_values(by="bert_diff", ascending=False).head(3)

print("Top 3 Cases with Highest Semantic Improvement:\n")
for idx, row in top_improved.iterrows():
    print(f"[Index: {idx}] Semantic Score Gain: {row['bert_diff']:+.4f}")
    print(f"Question: {row['question']}")
    print(f"Reference Answer: {row['ref']}")
    # str(): an empty CSV cell is read back as NaN (a float), which cannot be sliced.
    print(f"No RAG Prediction: {str(row['pred_no_rag'])[:100]}...")
    print(f"With RAG Prediction: {str(row['pred_with_rag'])[:100]}...")
    print("-" * 60)

# Save new results
# splitext only touches the final extension. str.replace(".csv", ...) would
# also rewrite a ".csv" inside a directory name and, for a path that does not
# contain ".csv", return the path unchanged and overwrite the input file.
_root, _ext = os.path.splitext(csv_path)
output_path = f"{_root}_bertscore_evaluated{_ext}"
df.to_csv(output_path, index=False)
print(f"\n[Info] Results saved to: {output_path}")

[Info] Reading file: C:\LLM\PrimeKG\eval_out_notebook\run_20251209_203948_gpt2\results.csv
[Success] Data loaded successfully. Total samples: 50

[Info] Loading BERTScore model (First time run will download roberta-large, ~1.4GB)...
[Info] Compute device: cuda

[Status] Calculating BERTScore for No RAG (1/2)...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[Status] Calculating BERTScore for With RAG (2/2)...

BERTScore Evaluation Results (F1 Score)
   (1.0 = Perfect semantic match, 0.0 = Completely irrelevant)
--------------------------------------------------
No RAG (Closed Book) : 0.81608
With RAG (Open Book) : 0.81667
Net Semantic Gain    : +0.00059

Top 3 Cases with Highest Semantic Improvement:

[Index: 47] Semantic Score Gain: +0.0259
Question: Human papillomavirus and pterygium. Is the virus a risk factor?
Reference Answer: The low presence of HPV DNA in pterygia does not support the hypothesis that HPV is involved in the development of pterygia in Denmark.
No RAG Prediction: Human Papiloma Virus (HPV) is an infectious disease of human origin, which has been identified as be...
With RAG Prediction: The human Papilioid Virus (HPV) is an infectious viral disease that causes severe, chronic inflammat...
------------------------------------------------------------
[Index: 11] Semantic Score Gain: +0.0096
Question: Does laparoscopic su

In [None]:
import os
import json
import random
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.feature_extraction.text import CountVectorizer
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    BitsAndBytesConfig,
    logging as hf_logging
)
import evaluate

# Global Configuration Class
class Config:
    """Settings for the batched Llama-2 base-vs-RAG evaluation in this cell."""

    # Local data path
    parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled\train-00000-of-00001.parquet"
    
    # Use the permission-free Llama-2-7b-chat mirror/model
    model_name = "NousResearch/Llama-2-7b-chat-hf"
    
    # Enable Chat Template
    # NOTE(review): build_prompt in this cell always tries the chat template
    # and does not appear to read this flag - confirm before relying on it.
    use_chat_template = True 
    
    # Number of test samples (None to run the full dataset)
    limit = 10 
    
    # Generation parameters
    max_new_tokens = 128
    # Character cap on each context; run_inference also derives the tokenizer
    # max_length from it (max_ctx_chars + 512).
    max_ctx_chars = 4000
    batch_size = 2
    
    output_dir = "eval_results_llama2"
    seed = 42

cfg = Config()

def set_seed(seed: int):
    """Seed Python hashing, ``random``, NumPy and PyTorch with ``seed``."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        # Covers every visible CUDA device.
        torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)
hf_logging.set_verbosity_error()

# Data Loading
print(f"Reading data from: {cfg.parquet_path}")
if not os.path.exists(cfg.parquet_path):
    raise FileNotFoundError(f"File not found: {cfg.parquet_path}")

# Keep only rows that have both a question and a reference long answer.
df = pd.read_parquet(cfg.parquet_path)
df = df.dropna(subset=['question', 'long_answer'])
# A falsy limit (None or 0) means "use every row".
df = df.head(cfg.limit) if cfg.limit else df

def extract_context_text(c):
    """Flatten a PubMedQA ``context`` cell into plain text.

    Args:
        c: A dict with a ``"contexts"`` entry, a sequence of passages
            (list, tuple or NumPy array), a plain value, or ``None``/NaN.

    Returns:
        Passages joined by single spaces; ``""`` for missing values; otherwise
        ``str(c)``.
    """
    if isinstance(c, dict) and "contexts" in c:
        # Recurse so the nested value gets the same sequence/None handling.
        return extract_context_text(c["contexts"])
    if c is None or (isinstance(c, float) and c != c):
        # Missing cell (None or NaN). str() would produce the literal text
        # "None"/"nan", which the prompt builder would treat as real context.
        return ""
    if isinstance(c, (list, tuple, np.ndarray)):
        # pandas hands list-typed parquet fields back as NumPy arrays;
        # str(array) would embed brackets and quotes in the prompt.
        return " ".join(str(part) for part in c)
    return str(c)

print("Processing contexts...")
# RAG condition: flattened context text, capped at cfg.max_ctx_chars characters.
ctx_list_rag = [extract_context_text(raw)[:cfg.max_ctx_chars] for raw in df["context"]]
# Baseline condition: one empty context per sample.
ctx_list_base = ["" for _ in range(len(df))]
questions = list(df["question"])
refs = list(df["long_answer"])
print(f"Loaded {len(df)} samples.")

# Prompt Construction Function (Adapted for Llama 2)
def build_prompt(tok, q: str, ctx: str) -> str:
    """Build the Llama-2 chat prompt for one question.

    Args:
        tok: Tokenizer whose ``apply_chat_template`` renders the messages.
        q: Question text.
        ctx: Context text; empty or whitespace-only means "no context".

    Returns:
        The rendered chat prompt, or a hand-written ``[INST] <<SYS>>`` prompt
        when the tokenizer cannot apply a chat template.
    """
    sys_msg = "You are a helpful biomedical expert. Answer the question precisely based on the provided context if available."

    # The question/task/answer scaffold is the same with or without context.
    question_block = (
        f"Question: {q}\n"
        f"Task: Provide a concise medical reasoning (1-2 sentences) followed by a final decision: 'Yes', 'No', or 'Maybe'.\n"
        f"Answer:"
    )
    has_ctx = bool(ctx and ctx.strip())
    user_msg = f"Context:\n{ctx}\n\n{question_block}" if has_ctx else question_block

    messages = [
        {"role": "system", "content": sys_msg},
        {"role": "user", "content": user_msg}
    ]

    try:
        return tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Deliberate best-effort fallback (no template, or no such method).
        # `Exception` instead of a bare except keeps Ctrl-C working.
        return f"[INST] <<SYS>>\n{sys_msg}\n<</SYS>>\n\n{user_msg} [/INST]"

# Keyword Recall Calculation
def compute_keyword_recall(preds, refs):
    """Compute, per sample, the share of reference keywords found in the prediction.

    Keywords are the distinct tokens of the reference after CountVectorizer's
    default preprocessing (lowercasing, word tokenization) with English stop
    words removed. The prediction is tokenized the same way.

    Args:
        preds: Iterable of predicted answer strings.
        refs: Iterable of reference answer strings, aligned with ``preds``.

    Returns:
        List of floats in [0, 1], one per pair; 0.0 when the reference has no
        keywords or either side is not a string.
    """
    # One analyzer for both sides; nothing needs fitting per reference.
    analyzer = CountVectorizer(stop_words='english').build_analyzer()
    scores = []
    for p, r in zip(preds, refs):
        ref_keywords = set(analyzer(r)) if isinstance(r, str) else set()
        if not ref_keywords:
            scores.append(0.0)
            continue
        pred_tokens = set(analyzer(p)) if isinstance(p, str) else set()
        # Token-level matching. A substring test counted "art" as present in
        # "heart" and so overstated recall.
        scores.append(len(ref_keywords & pred_tokens) / len(ref_keywords))
    return scores

# Inference Function
def run_inference(model, tok, qs, ctxs, cfg, label):
    """Generate one answer per (question, context) pair in batches.

    Args:
        model: Causal LM exposing ``generate``, ``device`` and ``config``.
        tok: Matching tokenizer; its pad token and padding side are set here.
        qs: Questions.
        ctxs: Contexts aligned with ``qs`` (empty string means no context).
        cfg: Config providing ``batch_size``, ``max_new_tokens``, ``max_ctx_chars``.
        label: Name of the run, used only in progress output.

    Returns:
        List of stripped generated strings, one per question.
    """
    print(f"Inference: {label} ...")
    preds = []
    if tok.pad_token is None: tok.pad_token = tok.eos_token
    # Left padding puts every prompt flush against the generated tokens, so
    # one slice offset works for the whole batch.
    tok.padding_side = "left"

    prompts = [build_prompt(tok, q, c) for q, c in zip(qs, ctxs)]

    # max_length is counted in tokens. Keep the original character-derived
    # budget, but never exceed the model window minus room for the new tokens.
    max_prompt_tokens = cfg.max_ctx_chars + 512
    ctx_window = getattr(model.config, "max_position_embeddings", None)
    if ctx_window:
        max_prompt_tokens = min(max_prompt_tokens, ctx_window - cfg.max_new_tokens)

    for i in range(0, len(prompts), cfg.batch_size):
        batch = prompts[i : i + cfg.batch_size]
        # Follow the model's device instead of hard-coding "cuda".
        inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_prompt_tokens).to(model.device)

        with torch.no_grad():
            # Greedy decoding; temperature only applies when sampling, so it
            # is not passed here.
            out = model.generate(
                **inputs,
                max_new_tokens=cfg.max_new_tokens,
                pad_token_id=tok.pad_token_id,
                do_sample=False
            )

        input_len = inputs.input_ids.shape[1]
        decoded = tok.batch_decode(out[:, input_len:], skip_special_tokens=True)
        preds.extend([d.strip() for d in decoded])

        if (i // cfg.batch_size) % 2 == 0: print(f"Batch {i//cfg.batch_size + 1} done.")

    return preds

# Model Loading (4-bit Quantization)
print(f"Loading Model: {cfg.model_name} with 4-bit quantization...")

# NF4 4-bit weights with double quantization; compute runs in float16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

try:
    tok = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    # Reuse EOS for padding so batched tokenization works.
    tok.pad_token = tok.eos_token
    
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    model.eval()
    print("Model loaded successfully on GPU!")
except Exception as e:
    print(f"Load Error: {e}")
    # Bare `raise` re-raises the active exception with its traceback intact.
    raise

# Execute Inference
# Run both conditions over the same questions.
preds_base = run_inference(model, tok, questions, ctx_list_base, cfg, label="Base (No Context)")
preds_rag = run_inference(model, tok, questions, ctx_list_rag, cfg, label="RAG (With Context)")

# Evaluation and Saving
print("\nCalculating Metrics...")
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

r_base = rouge.compute(predictions=preds_base, references=refs)
r_rag = rouge.compute(predictions=preds_rag, references=refs)

print("Calculating BERTScore...")
bs_kwargs = {"references": refs, "lang": "en", "model_type": "distilbert-base-uncased"}
bs_base = bertscore.compute(predictions=preds_base, **bs_kwargs)
bs_rag = bertscore.compute(predictions=preds_rag, **bs_kwargs)

kw_base = compute_keyword_recall(preds_base, refs)
kw_rag = compute_keyword_recall(preds_rag, refs)

# Aggregate numbers for the console report.
summary = {
    "Model": cfg.model_name,
    "ROUGE-L": {"Base": r_base['rougeL'], "RAG": r_rag['rougeL']},
    "BERTScore-F1": {"Base": np.mean(bs_base['f1']), "RAG": np.mean(bs_rag['f1'])},
    "Keyword-Recall": {"Base": np.mean(kw_base), "RAG": np.mean(kw_rag)}
}

print("\nEVALUATION SUMMARY (Llama-2)")
print(json.dumps(summary, indent=2))

os.makedirs(cfg.output_dir, exist_ok=True)
stamp = datetime.now().strftime('%H%M%S')
save_path = os.path.join(cfg.output_dir, f"llama2_results_{stamp}.jsonl")

# One JSON object per sample, in dataset order.
per_sample = zip(questions, refs, preds_base, preds_rag, kw_base, kw_rag)
with open(save_path, "w", encoding="utf-8") as f:
    for i, (question, reference, base_pred, rag_pred, base_kw, rag_kw) in enumerate(per_sample):
        row = {
            "id": i,
            "question": question,
            "reference": reference,
            "base_pred": base_pred,
            "rag_pred": rag_pred,
            "base_kw": base_kw,
            "rag_kw": rag_kw
        }
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f"Saved to {save_path}")

In [None]:
import os
import pandas as pd

csv_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/eval_out_notebook/run_20251209_203948_gpt2/results_bertscore_evaluated.csv"

# Explicit raise rather than `assert`: assertions are stripped under
# `python -O`, which would let a missing file slip through to read_csv.
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"File not found: {csv_path}")

df = pd.read_csv(csv_path)
df.columns = [c.strip() for c in df.columns]  # Strip whitespace from column names to prevent errors
print("[Info] Loaded:", csv_path, "| rows:", len(df))
print("[Info] Columns:", list(df.columns))

# Mandatory Column Check
need = ["question", "ref", "pred_no_rag", "pred_with_rag", "bert_score_no", "bert_score_rag", "bert_diff"]
missing = [c for c in need if c not in df.columns]
if missing:
    raise KeyError(f"Missing columns: {missing}\nAvailable: {list(df.columns)}")

# Convert to Numeric (unparseable cells become NaN and are dropped just below)
for c in ["bert_score_no", "bert_score_rag", "bert_diff"]:
    df[c] = pd.to_numeric(df[c], errors="coerce")

df = df.dropna(subset=["bert_score_rag", "bert_score_no", "bert_diff"]).reset_index(drop=True)

# Top/Bottom Export
# Three best / worst rows by RAG BERTScore and by RAG-vs-baseline gain.
top3_rag = df.sort_values("bert_score_rag", ascending=False).head(3)
bot3_rag = df.sort_values("bert_score_rag", ascending=True).head(3)
top3_gain = df.sort_values("bert_diff", ascending=False).head(3)
bot3_gain = df.sort_values("bert_diff", ascending=True).head(3)

rankings = {
    "top3_rag": top3_rag,
    "bot3_rag": bot3_rag,
    "top3_gain": top3_gain,
    "bot3_gain": bot3_gain,
}

show_cols = [
    "question", "ref", "pred_no_rag", "pred_with_rag",
    "bert_score_no", "bert_score_rag", "bert_diff"
]

rag_view = ["bert_score_rag","bert_diff","question"]
gain_view = ["bert_diff","bert_score_no","bert_score_rag","question"]

# (ranking key, console heading, columns shown on the console)
console_reports = [
    ("top3_rag", "\n[Top3] Highest RAG BERTScore (bert_score_rag):", rag_view),
    ("bot3_rag", "\n[Bottom3] Lowest RAG BERTScore (bert_score_rag):", rag_view),
    ("top3_gain", "\n[Top3] Largest Gain (bert_diff):", gain_view),
    ("bot3_gain", "\n[Bottom3] Worst Gain (bert_diff):", gain_view),
]
for key, heading, view in console_reports:
    print(heading)
    print(rankings[key][view].to_string(index=False))

# Save to the same directory
out_dir = os.path.dirname(csv_path) or "."
base = os.path.splitext(os.path.basename(csv_path))[0]

file_suffixes = {
    "top3_rag": "top3_bert_rag",
    "bot3_rag": "bottom3_bert_rag",
    "top3_gain": "top3_bert_gain",
    "bot3_gain": "bottom3_bert_gain",
}
paths = {
    key: os.path.join(out_dir, f"{base}_{suffix}.csv")
    for key, suffix in file_suffixes.items()
}

for key, target in paths.items():
    rankings[key][show_cols].to_csv(target, index=False, encoding="utf-8-sig")

print("\n[Saved]")
for k, p in paths.items():
    print(f" - {k}: {p}")

[Info] Loaded: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/eval_out_notebook/run_20251209_203948_gpt2/results_bertscore_evaluated.csv | rows: 50
[Info] Columns: ['question', 'retrieved_facts', 'ref', 'pred_no_rag', 'pred_with_rag', 'score_no', 'score_rag', 'diff', 'bert_score_no', 'bert_score_rag', 'bert_diff']

[Top3] Highest RAG BERTScore (bert_score_rag):
 bert_score_rag  bert_diff                                                                                              question
       0.852124   0.000000 Preoperative locoregional staging of gastric cancer: is there a place for magnetic resonance imaging?
       0.839119   0.009647             Does laparoscopic surgery decrease the risk of atrial fibrillation after foregut surgery?
       0.839060   0.000000                 Is there any evidence of a "July effect" in patients undergoing major cancer surgery?

[Bottom3] Lowest RAG BERTScore (bert_score_rag):
 bert_score_rag  bert_diff                            

: 

llama2_pubmedqa_primekg

In [None]:
import os
import re
import random
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional
from collections import Counter

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from sklearn.feature_extraction.text import TfidfVectorizer


# 1) Config (Only keep RAG evaluation)

@dataclass
class Config:
    """Settings for the PrimeKG-backed RAG evaluation on PubMedQA."""

    # Input data, knowledge-graph CSV and model checkpoint locations.
    parquet_path: str = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet"
    kg_csv_path: str = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/kg.csv"
    model_name_or_path: str = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2"

    # Inference config
    batch_size: int = 4
    max_new_tokens: int = 256
    fix_max_new_tokens: int = 16   # Second pass fix only needs a decision line, keep it short
    temperature: float = 0.0       # Default to deterministic output (do_sample=False)

    # Retrieval config
    top_k_nodes: int = 8           # nodes kept per question in retrieve_kg
    top_facts_per_node: int = 6    # facts taken from each retrieved node
    max_ctx_chars: int = 1500      # character cap on the assembled context

    # KG Construction (PrimeKG is large, read in streams using chunksize)
    kg_chunksize: int = 2_000_000
    max_facts_per_node_cache: int = 50   # per-node cap applied in build_kg_cache
    max_node_doc_chars: int = 2000
    use_display_relation: bool = True    # read "display_relation" instead of "relation"

    # Evaluation settings
    limit: Optional[int] = 100     # None keeps every row (see load_pubmedqa)
    seed: int = 42
    out_dir: str = "eval_out_rag_only"

cfg = Config()


# 2) Utils & Data Loading

def set_seed(seed: int):
    """Seed every random number generator this script touches."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    # Seed all CUDA devices as well when a GPU is present.
    torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def ensure_dir(d: str):
    """Create directory ``d`` (including parents) unless it already exists."""
    if not os.path.isdir(d):
        os.makedirs(d, exist_ok=True)

def load_pubmedqa(parquet_path: str, limit: Optional[int]) -> pd.DataFrame:
    """Read the PubMedQA parquet file and keep rows with a question and a decision.

    Args:
        parquet_path: Path of the parquet file.
        limit: Number of leading rows to keep; ``None`` keeps all rows.

    Returns:
        The filtered (and optionally truncated) DataFrame.
    """
    print(f"[Data] Loading {parquet_path}...")
    frame = pd.read_parquet(parquet_path).dropna(subset=["question", "final_decision"])
    return frame if limit is None else frame.head(limit)


# 3) KG & Retrieval (Use PrimeKG for retrieval context)

def make_fact(x, rel, y):
    """Render a (subject, relation, object) triple as ``"Fact: x rel y."``.

    Returns an empty string when the subject or the object is missing/empty.
    """
    return f"Fact: {x} {rel} {y}." if x and y else ""

def build_kg_cache(csv_path):
    """Stream the knowledge-graph CSV and build lookup tables for retrieval.

    Args:
        csv_path: Path to the edge list with columns ``x_id``, ``x_name``,
            ``y_id``, ``y_name`` and ``display_relation`` / ``relation``.

    Returns:
        Tuple ``(node_ids, node_info, adj)``: the node ids in first-seen order,
        a mapping node id -> node name, and a mapping node id -> list of fact
        sentences (at most ``cfg.max_facts_per_node_cache`` per node).
    """
    print(f"[KG] Building from {csv_path}...")
    node_info, adj = {}, {}
    relation_col = "display_relation" if cfg.use_display_relation else "relation"
    per_node_cap = cfg.max_facts_per_node_cache

    reader = pd.read_csv(csv_path, chunksize=cfg.kg_chunksize)
    for chunk in tqdm(reader, desc="KG Chunks"):
        chunk = chunk.fillna("")
        columns = [
            chunk[name].astype(str).values
            for name in ("x_id", "x_name", "y_id", "y_name", relation_col)
        ]

        for xid, xnm, yid, ynm, r in zip(*columns):
            # The first name seen for a node id wins.
            node_info.setdefault(xid, xnm)
            node_info.setdefault(yid, ynm)

            f = make_fact(xnm, r, ynm)
            if not f:
                continue

            # The fact is filed under both endpoints so either one can surface it.
            for nid in (xid, yid):
                bucket = adj.setdefault(nid, [])
                if len(bucket) < per_node_cap:
                    bucket.append(f)

    return list(node_info), node_info, adj

def retrieve_kg(questions, vectorizer, X, node_ids, adj):
    """Build a textual KG context for each question via TF-IDF retrieval.

    Args:
        questions: Question strings.
        vectorizer: Fitted vectorizer used to embed the questions.
        X: Sparse node-document matrix; row ``i`` belongs to ``node_ids[i]``.
        node_ids: Node ids aligned with the rows of ``X``.
        adj: Mapping node id -> list of fact sentences.

    Returns:
        One context string per question: facts of the best-matching nodes
        joined by newlines and cut to ``cfg.max_ctx_chars`` characters. The
        string is empty when no node shares any term with the question.
    """
    ctx_list = []
    Q = vectorizer.transform(questions)
    XT = X.T  # loop-invariant, hoisted out of the per-question loop

    for i in tqdm(range(Q.shape[0]), desc="Retrieving"):
        scores = (Q[i] @ XT).toarray().ravel()
        top_idx = np.argsort(-scores)[:cfg.top_k_nodes]

        facts, seen = [], set()
        for idx in top_idx:
            # Scores are sorted in descending order. A non-positive score
            # means no shared term, so this node and all that follow are
            # arbitrary picks that would only add unrelated facts.
            if scores[idx] <= 0:
                break
            nid = node_ids[idx]
            for fact in adj.get(nid, [])[:cfg.top_facts_per_node]:
                # An edge is stored under both endpoints; emit it only once
                # so duplicates do not use up the character budget.
                if fact not in seen:
                    seen.add(fact)
                    facts.append(fact)

        ctx = "\n".join(facts)
        ctx_list.append(ctx[:cfg.max_ctx_chars])

    return ctx_list


# 4) Metrics & Parsing (Only RAG evaluation)

DEC_LABELS = ("yes", "no", "maybe")

def norm_dec(text: str) -> str:
    """Extract a yes/no/maybe decision from model output.

    An explicit ``Final Decision: <label>`` line is tried first; failing that,
    any standalone yes/no/maybe word. In both cases the last occurrence wins.
    Returns ``"unknown"`` when neither is found.
    """
    lowered = text.lower() if text else ""
    patterns = (
        r"final decision\s*:\s*(yes|no|maybe)",
        r"\b(yes|no|maybe)\b",
    )
    for pattern in patterns:
        hits = re.findall(pattern, lowered)
        if hits:
            return hits[-1]
    return "unknown"

def calc_acc(gts, preds):
    """Return decision accuracy over samples whose ground truth is yes/no/maybe.

    Args:
        gts: Ground-truth labels (compared case-insensitively, stripped).
        preds: Raw model outputs; the decision is extracted with ``norm_dec``.

    Returns:
        Fraction of counted samples whose extracted decision equals the label,
        or ``0`` when no sample has a valid label.
    """
    outcomes = []
    for g, p in zip(gts, preds):
        gt = str(g).lower().strip()
        pred = norm_dec(p)
        if gt in DEC_LABELS:
            outcomes.append(gt == pred)
    return sum(outcomes) / len(outcomes) if outcomes else 0


# 5) Prompts (Only keep RAG prompt + fix prompt)

# System prompt for the main answering pass. It fixes the output format to a
# "Reasoning:" paragraph followed by a "Final Decision:" line, which is the
# pattern norm_dec() looks for first.
SYSTEM_PROMPT = (
    "You are a medical expert.\n"
    "Task: answer PubMed-style yes/no/maybe questions.\n"
    "Output format MUST be exactly:\n"
    "Reasoning: <one paragraph>\n"
    "Final Decision: yes|no|maybe\n"
    "If evidence is insufficient, choose maybe.\n"
)

# System prompt for the repair pass: asks for the decision line only.
FIX_SYSTEM_PROMPT = (
    "You are a strict formatter.\n"
    "Return ONLY one line in exactly this format:\n"
    "Final Decision: yes|no|maybe\n"
)

def format_prompt(tokenizer, question, context):
    """
    Build the RAG prompt: the model must answer from Context and follow the output format.

    Args:
        tokenizer: Tokenizer; its chat template is used when one is available.
        question: Question text.
        context: Retrieved context text placed before the question.

    Returns:
        The prompt string. It is rendered through the tokenizer's chat template
        when possible, otherwise a plain "User:/Assistant:" concatenation.
    """
    user_content = (
        f"Context:\n{context}\n\n"
        f"Question: {question}\n"
        "Answer using the Context.\n"
        "Follow the required output format."
    )

    # Checking only for the method is not enough: tokenizers expose
    # apply_chat_template even when no template is configured, and recent
    # transformers releases then raise ValueError instead of using a default.
    apply_template = getattr(tokenizer, "apply_chat_template", None)
    if callable(apply_template):
        msgs = [{"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content}]
        try:
            return apply_template(msgs, tokenize=False, add_generation_prompt=True)
        except ValueError:
            # No usable chat template -> use the plain-text fallback below.
            pass

    # Fallback: pure text concatenation
    return f"{SYSTEM_PROMPT}\nUser: {user_content}\nAssistant:"

def format_fix_prompt(tokenizer, question, answer_text):
    """
    Build the repair prompt used when no decision could be extracted from a model output.

    The model is shown the question and its own earlier answer and asked to
    return only the decision line.

    Args:
        tokenizer: Tokenizer; its chat template is used when one is available.
        question: Original question text.
        answer_text: The earlier model output from which no decision was parsed.

    Returns:
        The prompt string. It is rendered through the tokenizer's chat template
        when possible, otherwise a plain "User:/Assistant:" concatenation.
    """
    user_content = (
        f"Question: {question}\n\n"
        f"Model Answer:\n{answer_text}\n\n"
        "Extract the decision."
    )

    # Tokenizers expose apply_chat_template even without a configured template;
    # recent transformers releases then raise ValueError, so fall back on that too.
    apply_template = getattr(tokenizer, "apply_chat_template", None)
    if callable(apply_template):
        msgs = [{"role": "system", "content": FIX_SYSTEM_PROMPT},
                {"role": "user", "content": user_content}]
        try:
            return apply_template(msgs, tokenize=False, add_generation_prompt=True)
        except ValueError:
            # No usable chat template -> use the plain-text fallback below.
            pass

    return f"{FIX_SYSTEM_PROMPT}\nUser: {user_content}\nAssistant:"


# 6) Inference (slices generated tokens correctly under left padding)

@torch.inference_mode()
def generate_batch(model, tokenizer, prompts: List[str], max_new_tokens: int) -> List[str]:
    """
    Generate one continuation per prompt (no sampling) and return only the new text.

    Slicing note: with left padding every row shares the same padded prompt
    width and generate() appends new tokens after it, so the cut point is
    inputs.input_ids.shape[1] rather than attention_mask.sum().

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching the model.
        prompts: Prompt strings for one batch.
        max_new_tokens: Generation budget per prompt.

    Returns:
        Decoded continuations, stripped, in the same order as ``prompts``.
    """
    encoded = tokenizer(
        prompts,
        padding=True,
        truncation=True,
        max_length=2048,
        return_tensors="pt",
    ).to(DEVICE)

    generated = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # deterministic decoding
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Everything before this column is the (padded) prompt.
    prompt_width = encoded.input_ids.shape[1]
    return [
        tokenizer.decode(row[prompt_width:], skip_special_tokens=True).strip()
        for row in generated
    ]

@torch.inference_mode()
def run_eval_with_fix(model, tokenizer, questions, contexts):
    """
    Two-pass RAG inference.

    Pass 1 generates a full answer for every (question, context) pair.
    Pass 2 re-prompts only the outputs whose decision parses as "unknown" and
    appends a "Final Decision: ..." line to them; if the repair output is
    still unparseable, "maybe" is appended.

    Returns:
        One output string per question, each guaranteed to yield a decision via norm_dec.
    """
    step = cfg.batch_size

    # Pass 1: answer every question with its retrieved context.
    outputs = []
    for start in tqdm(range(0, len(questions), step), desc="Infer RAG"):
        batch_pairs = zip(questions[start:start + step], contexts[start:start + step])
        prompts = [format_prompt(tokenizer, q, c) for q, c in batch_pairs]
        outputs.extend(generate_batch(model, tokenizer, prompts, cfg.max_new_tokens))

    # Pass 2: repair the outputs from which no decision could be extracted.
    pending = [pos for pos, text in enumerate(outputs) if norm_dec(text) == "unknown"]
    if not pending:
        return outputs

    print(f"[Fix] RAG: repairing unknown={len(pending)}/{len(outputs)}")
    for start in tqdm(range(0, len(pending), step), desc="Fix RAG"):
        chunk = pending[start:start + step]
        fix_prompts = [format_fix_prompt(tokenizer, questions[k], outputs[k]) for k in chunk]
        repaired = generate_batch(model, tokenizer, fix_prompts, cfg.fix_max_new_tokens)

        # Append the decision to the original output so norm_dec can find it later.
        for k, repaired_text in zip(chunk, repaired):
            dec = norm_dec(repaired_text)
            if dec not in DEC_LABELS:
                dec = "maybe"
            outputs[k] = outputs[k].rstrip() + "\nFinal Decision: " + dec

    return outputs


# 7) Main (Only keep RAG path)

def main():
    """
    Run the RAG-only evaluation end to end.

    Steps: load PubMedQA, build the KG cache and per-node TF-IDF documents,
    retrieve a context per question, load the 4-bit model, generate answers
    (with the unknown-decision repair pass), report accuracy, and save a CSV.
    """
    ensure_dir(cfg.out_dir)

    # PubMedQA questions and gold decisions
    data = load_pubmedqa(cfg.parquet_path, cfg.limit)
    questions = data["question"].astype(str).tolist()
    gts = data["final_decision"].astype(str).tolist()

    # KG cache (node -> facts)
    node_ids, node_info, adj = build_kg_cache(cfg.kg_csv_path)

    # One TF-IDF "document" per node: node info followed by its facts, truncated.
    node_docs = []
    for nid in node_ids:
        doc = f"{node_info[nid]} {' '.join(adj.get(nid, []))}"
        node_docs.append(doc[:cfg.max_node_doc_chars])

    # TF-IDF modeling and retrieval
    vectorizer = TfidfVectorizer(stop_words="english", max_features=50000)
    node_matrix = vectorizer.fit_transform(node_docs)
    kg_contexts = retrieve_kg(questions, vectorizer, node_matrix, node_ids, adj)

    # 4-bit model
    print(f"[Model] Loading 4-bit Llama from {cfg.model_name_or_path}...")
    quant_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"

    # Single card: device_map={"": 0}; for multiple cards this can be changed to "auto".
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name_or_path,
        quantization_config=quant_cfg,
        device_map={"": 0},
    )
    model.eval()

    # RAG inference (with automatic repair of unknown decisions)
    preds_rag = run_eval_with_fix(model, tokenizer, questions, kg_contexts)

    # Parse each output once; reused for the distribution, the CSV and the samples.
    decisions = [norm_dec(pred) for pred in preds_rag]
    acc_rag = calc_acc(gts, preds_rag)

    print("\n=== Distribution Check ===")
    print("GT dist     :", Counter(str(gt).lower().strip() for gt in gts))
    print("RAG dec dist:", Counter(decisions))

    print(f"\nSummary:\nRAG ACC: {acc_rag:.4f}")

    # Save per-sample results
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    out_csv = f"{cfg.out_dir}/rag_only_{timestamp}.csv"
    columns = {
        "question": questions,
        "gt": gts,
        "pred_rag": preds_rag,
        "dec_rag": decisions,
        "ctx_rag": kg_contexts,
    }
    pd.DataFrame(columns).to_csv(out_csv, index=False)
    print("[Saved]", out_csv)

    # Quick look at the first few samples
    preview = zip(questions[:3], gts[:3], decisions[:3], preds_rag[:3], kg_contexts[:3])
    for k, (question, gt, dec, pred, ctx) in enumerate(preview):
        print("\n--- sample", k, "---")
        print("Q :", question)
        print("GT:", gt)
        print("RAG(dec):", dec)
        print("RAG out :", pred[:400])
        print("CTX     :", (ctx or "")[:300])

# Run the evaluation only when this module is the entry point, not on import.
if __name__ == "__main__":
    main()

[Data] Loading /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet...
[KG] Building from /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/kg.csv...


KG Chunks: 0it [00:00, ?it/s]

  for obj in iterable:
  for obj in iterable:
  for obj in iterable:
  for obj in iterable:
KG Chunks: 5it [00:24,  4.80s/it]
Retrieving: 100%|██████████| 100/100 [00:02<00:00, 43.50it/s]


[Model] Loading 4-bit Llama from /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Infer base: 100%|██████████| 25/25 [07:12<00:00, 17.29s/it]


[Fix] base: repairing unknown=15/100


Fix base: 100%|██████████| 4/4 [00:03<00:00,  1.29it/s]
Infer rag: 100%|██████████| 25/25 [07:24<00:00, 17.78s/it]


[Fix] rag: repairing unknown=33/100


Fix rag: 100%|██████████| 9/9 [00:06<00:00,  1.29it/s]


=== Distribution Check ===
GT dist      : Counter({'yes': 62, 'no': 30, 'maybe': 8})
Base dec dist: Counter({'maybe': 60, 'yes': 32, 'no': 8})
RAG  dec dist: Counter({'maybe': 55, 'yes': 37, 'no': 8})

Summary:
Base ACC: 0.2900
RAG ACC: 0.2800
[Saved] eval_out_closed_book_vs_rag/result_20251231_2250.csv

--- sample 0 ---
Q : Malnutrition, a new inducer for arterial calcification in hemodialysis patients?
GT: yes
BASE(dec): yes
BASE out : Reasoning:
Malnutrition has been increasingly recognized as a significant risk factor for cardiovascular disease in hemodialysis patients. Recent studies have suggested that malnutrition may also contribute to the development of arterial calcification in this population. For example, a study published in the Journal of the American Society of Nephrology found that malnourished hemodialysis patient
RAG (dec): yes
RAG  out : Reasoning:
Arterial calcification is a known complication in hemodialysis patients, and it has been established that malnutrition 




llama2_pumbmedqakg

In [None]:
# RAG ONLY Evaluation Script for PubMedQA Test Set + KB(pubmed_documents.pkl)
# Model: Llama-2-7b-chat (4-bit quantization loading)
# Metrics: Decision ACC + ROUGE-L + BERTScore


import os
import re
import json
import pickle
import random
import numpy as np
import pandas as pd
import torch
import pyarrow.parquet as pq
from datetime import datetime
from tqdm import tqdm

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig  # Used for 4-bit quantization loading
)
from sklearn.feature_extraction.text import TfidfVectorizer


# 0) Configuration Area

class Config:
    """Run settings for the KB-RAG PubMedQA evaluation; read through the module-level ``cfg``."""

    # ---- PubMedQA Test Set (HF parquet) ----
    parquet_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet"

    # ---- KB Documents (Your local "Knowledge Base") ----
    kb_docs_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/pubmed_documents.pkl"

    # ---- (Optional) FAISS Vector Index ----
    # When this file exists and loads, FAISS retrieval is used; otherwise TF-IDF.
    kb_index_path = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/pubmed_qa.index"
    embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"

    # ---- Generation Model (Llama-2) ----
    model_name_or_path = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2"

    # Llama2 consumes a lot of VRAM, small batch_size (1 or 2) is recommended
    batch_size = 2
    max_new_tokens = 512

    # ---- RAG Retrieval Parameters ----
    top_k_docs = 2          # Retrieve top_k KB documents for each question
    max_ctx_chars = 1200    # Max characters for concatenated Context (prevents prompt from being too long)

    # ---- TF-IDF Fallback Parameters (When FAISS is missing or fails to init) ----
    tfidf_max_docs = None        # None keeps every KB document; an integer keeps only the first N
    tfidf_max_features = 200000  # Vocabulary cap passed to TfidfVectorizer

    # ---- Evaluation Parameters ----
    limit = 100   # None means run all; integer means run only first N samples
    seed = 42

    # ---- Output Directory ----
    output_dir = "eval_results_pubmedqa_kbRAG_only_llama2"

# Single shared settings object referenced throughout the script.
cfg = Config()


# 1) Utility Functions

def set_seed(seed: int):
    """Seed PYTHONHASHSEED, ``random``, NumPy and PyTorch (CPU, plus all CUDA devices if available)."""
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

# Seed all RNGs once, before any data loading, retrieval or generation.
set_seed(cfg.seed)

# Target device string: CUDA when available, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("[Info] DEVICE:", DEVICE)

def accuracy(y_true, y_pred):
    """
    Fraction of positions where the prediction equals the label.

    Returns None when y_true is empty. Pairs are formed with zip, and the
    denominator is always len(y_true).
    """
    if not y_true:
        return None
    hits = 0
    for truth, guess in zip(y_true, y_pred):
        if truth == guess:
            hits += 1
    return hits / len(y_true)

def confusion_table(y_true, y_pred, labels=("yes", "no", "maybe", "unknown")):
    """
    Confusion matrix as a DataFrame (rows = ground truth, columns = prediction).

    Built with pandas crosstab, then reindexed so that both axes list exactly
    ``labels`` in order; label pairs that never occur are filled with 0.
    """
    gt_series = pd.Series(y_true, name="GT")
    pred_series = pd.Series(y_pred, name="Pred")
    counts = pd.crosstab(
        gt_series,
        pred_series,
        rownames=["GT"],
        colnames=["Pred"],
        dropna=False,
    )
    axis_order = list(labels)
    return counts.reindex(index=axis_order, columns=axis_order, fill_value=0)

# ---- Pure Python implementation of ROUGE-L (F1) ----
def _lcs_len(a_tokens, b_tokens):
    """Calculate LCS length of two token sequences."""
    n, m = len(a_tokens), len(b_tokens)
    dp = [0] * (m + 1)
    for i in range(1, n + 1):
        prev = 0
        for j in range(1, m + 1):
            tmp = dp[j]
            if a_tokens[i - 1] == b_tokens[j - 1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j - 1])
            prev = tmp
    return dp[m]

def rouge_l_f1(pred: str, ref: str) -> float:
    """
    ROUGE-L F1 between a prediction and a reference.

    Both texts are lower-cased and tokenized with the regex ``\\w+``; precision
    and recall are the LCS length divided by the prediction and reference
    token counts. Returns 0.0 when either side is empty or has no tokens.
    """
    hyp_text = (pred or "").strip()
    gold_text = (ref or "").strip()
    if not (hyp_text and gold_text):
        return 0.0

    hyp_tokens = re.findall(r"\w+", hyp_text.lower())
    gold_tokens = re.findall(r"\w+", gold_text.lower())
    if not (hyp_tokens and gold_tokens):
        return 0.0

    overlap = _lcs_len(hyp_tokens, gold_tokens)
    precision = overlap / len(hyp_tokens)
    recall = overlap / len(gold_tokens)
    denom = precision + recall
    return 2 * precision * recall / denom if denom != 0 else 0.0


# 2) Read PubMedQA TEST parquet

# Load the PubMedQA test split and derive the three parallel lists used below:
# questions, refs (reference text) and gt_decisions (gold labels, or None).
print(f"[Info] Reading test data: {cfg.parquet_path}")
if not os.path.exists(cfg.parquet_path):
    raise FileNotFoundError(f"Parquet not found: {cfg.parquet_path}")

tbl = pq.read_table(cfg.parquet_path)
df = tbl.to_pandas()

# Drop rows whose question is missing (NaN/None)
df = df.dropna(subset=["question"])

# Run only first N samples (Optional); a falsy limit (None or 0) keeps every row
if cfg.limit:
    df = df.head(cfg.limit)

questions = df["question"].astype(str).tolist()

# ref_col: Find a column that can serve as "Reference Answer Text"
# Note: In PubMedQA, long_answer might be long text; final_decision is yes/no/maybe; answer might also exist
# The first candidate present in the frame wins, so long_answer takes priority.
ref_col = None
for c in ["long_answer", "final_decision", "answer"]:
    if c in df.columns:
        ref_col = c
        break
if ref_col is None:
    # Fallback to the last column if not found
    ref_col = df.columns[-1]

refs = df[ref_col].fillna("").astype(str).tolist()

# If final_decision exists, use it for Decision ACC
# (labels are lower-cased; missing values become "" and are excluded later as invalid)
gt_decisions = None
if "final_decision" in df.columns:
    gt_decisions = df["final_decision"].fillna("").astype(str).str.lower().tolist()

print(f"[Info] Loaded samples: {len(questions)} | ref_col={ref_col} | has_final_decision={gt_decisions is not None}")


# 3) Load KB Documents (pubmed_documents.pkl)

# Load the knowledge-base documents into kb_docs (a list of strings after normalization).
print(f"[Info] Loading KB docs: {cfg.kb_docs_path}")
if not os.path.exists(cfg.kb_docs_path):
    raise FileNotFoundError(f"KB docs not found: {cfg.kb_docs_path}")

# NOTE(review): pickle.load executes arbitrary code from the file; only load a
# KB pickle that you produced yourself or otherwise trust.
with open(cfg.kb_docs_path, "rb") as f:
    kb_docs = pickle.load(f)

# If you want to use only the first N documents of the KB (for debugging/speedup)
# NOTE(review): a prebuilt FAISS index presumably covers the full KB; truncating
# here means some returned ids can fall outside kb_docs (retrieve_docs drops those).
if cfg.tfidf_max_docs is not None:
    kb_docs = kb_docs[:cfg.tfidf_max_docs]

# Ensure every doc is a string
kb_docs = [("" if d is None else str(d)) for d in kb_docs]
print(f"[Info] KB docs loaded: {len(kb_docs)}")


# 4) Build Retriever (Prioritize FAISS, otherwise TF-IDF)

# Retriever state read by retrieve_docs(). use_faiss selects the backend:
# True  -> faiss_index + embed_model are used,
# False -> tfidf_vectorizer + tfidf_X are used.
use_faiss = False
faiss_index = None
embed_model = None
tfidf_vectorizer = None
tfidf_X = None

def build_tfidf_retriever(docs):
    """
    Fit the fallback TF-IDF retriever on ``docs``.

    Uses English stop words, unigrams and bigrams, and a vocabulary capped at
    cfg.tfidf_max_features.

    Returns:
        (vectorizer, matrix): the fitted TfidfVectorizer and the document-term
        matrix produced by fit_transform.
    """
    print("[Info] Building TF-IDF retriever (fallback)...")
    tfidf = TfidfVectorizer(
        ngram_range=(1, 2),
        max_features=cfg.tfidf_max_features,
        stop_words="english",
    )
    doc_matrix = tfidf.fit_transform(docs)
    print("[Info] TF-IDF ready.")
    return tfidf, doc_matrix

# Attempt to initialize FAISS
# Try FAISS first. use_faiss only flips to True after BOTH the index and the
# embedding model have loaded, so any failure part-way leaves the TF-IDF path active.
try:
    if os.path.exists(cfg.kb_index_path):
        # Imported lazily so the script still runs without faiss / sentence-transformers
        # installed (an ImportError lands in the except below).
        import faiss
        from sentence_transformers import SentenceTransformer
        print(f"[Info] Found FAISS index: {cfg.kb_index_path}")
        faiss_index = faiss.read_index(cfg.kb_index_path)
        embed_model = SentenceTransformer(cfg.embed_model_name)
        use_faiss = True
        print("[Info] Using FAISS retriever.")
    else:
        print("[Info] FAISS index not found -> fallback TF-IDF.")
except Exception as e:
    # Deliberate best-effort: report the problem and continue with TF-IDF.
    print(f"[Warn] FAISS init failed -> fallback TF-IDF. Error: {e}")

# If not using FAISS, build TF-IDF
if not use_faiss:
    tfidf_vectorizer, tfidf_X = build_tfidf_retriever(kb_docs)

def retrieve_docs(query: str, top_k: int):
    """
    Return the texts of the top_k KB documents for ``query``.

    Uses the FAISS index when use_faiss is set, otherwise TF-IDF similarity.
    A missing or blank query yields an empty list.
    """
    cleaned = (query or "").strip()
    if not cleaned:
        return []

    if not use_faiss:
        # TF-IDF: score every document, then keep the top_k best in descending order.
        query_vec = tfidf_vectorizer.transform([cleaned])
        sims = (tfidf_X @ query_vec.T).toarray().ravel()

        if top_k < len(sims):
            # Partial selection first, then sort only the selected candidates.
            candidates = np.argpartition(-sims, top_k)[:top_k]
            ranked = candidates[np.argsort(-sims[candidates])]
        else:
            ranked = np.argsort(-sims)

        return [kb_docs[i] for i in ranked.tolist()]

    # FAISS: embed the query with SentenceTransformer, L2-normalize, then search.
    import faiss
    query_emb = embed_model.encode([cleaned], convert_to_numpy=True)
    faiss.normalize_L2(query_emb)
    _scores, hit_ids = faiss_index.search(query_emb, top_k)
    n_docs = len(kb_docs)
    # Keep only ids that point at an existing KB document.
    return [kb_docs[i] for i in hit_ids[0].tolist() if 0 <= i < n_docs]

def build_rag_context(question: str):
    """
    Build the Context string for one question.

    Retrieves cfg.top_k_docs documents, strips them, drops blank ones, joins
    the rest with blank lines, and truncates to cfg.max_ctx_chars so the
    prompt does not grow too long. Returns "" when nothing usable is retrieved.
    """
    hits = retrieve_docs(question, cfg.top_k_docs)
    pieces = []
    for doc in hits:
        if not doc:
            continue
        stripped = doc.strip()
        if stripped:
            pieces.append(stripped)
    if not pieces:
        return ""
    return "\n\n".join(pieces)[:cfg.max_ctx_chars]

# Retrieve once up front: ctx_list_rag[i] is the context for questions[i].
print("[Info] Retrieving contexts for all questions...")
ctx_list_rag = [build_rag_context(q) for q in tqdm(questions)]


# 5) Load Generator Model (Llama-2 4-bit)

# Load the tokenizer (tok) and the 4-bit quantized generator (model) used by run_inference_rag.
print(f"[Info] Loading generator: {cfg.model_name_or_path} with 4-bit Quantization...")

# 4-bit quantization config (NF4 + double quant), significantly saves VRAM
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tok = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True)

# Llama2 usually doesn't have a pad_token, needs manual setting
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Left padding is generally more suitable for generation tasks:
# run_inference_rag slices the output at the padded prompt width, which relies on it.
tok.padding_side = "left"

# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run; confirm it is actually needed for this model directory.
model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name_or_path,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
print("[Info] Model loaded.")


# 6) Prompt + Inference (RAG Only)

def build_prompt(q: str, ctx: str) -> str:
    """
    Assemble a Llama-2 chat prompt by hand:
    [INST] <<SYS>> system text <</SYS>> user text [/INST]

    The user text always carries the retrieved Context (RAG only), followed by
    the question and an instruction to end with 'Answer: Yes/No/Maybe'.
    """
    system_text = "".join([
        "You are a helpful biomedical expert. ",
        "Your task is to answer the question with a brief medical reasoning ",
        "followed by a final decision: 'Yes', 'No', or 'Maybe'.",
    ])

    closing_instruction = (
        "Based on the context above, provide reasoning and then conclude with "
        "'Answer: Yes', 'Answer: No', or 'Answer: Maybe'."
    )
    # Context, question and instruction are separated by blank lines.
    user_text = "\n\n".join([f"Context:\n{ctx}", f"Question: {q}", closing_instruction])

    return "[INST] <<SYS>>\n" + system_text + "\n<</SYS>>\n\n" + user_text + " [/INST]"

def run_inference_rag(qs, ctxs):
    """
    Batch inference (RAG only).

    Builds one prompt per (question, context) pair, generates with greedy
    decoding in batches of cfg.batch_size, and returns only the newly
    generated text for each prompt.

    Args:
        qs: Question strings.
        ctxs: Context strings, aligned with ``qs``.

    Returns:
        List of stripped model outputs, one per prompt, in input order.
    """
    print("[Info] Running inference: RAG (KB Retrieved Docs)")
    preds = []
    prompts = [build_prompt(q, c) for q, c in zip(qs, ctxs)]

    for i in range(0, len(prompts), cfg.batch_size):
        batch_prompts = prompts[i:i + cfg.batch_size]

        # tokenize + padding/truncation
        inputs = tok(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=4096  # Llama-2 usually supports 4k
        ).to(DEVICE)  # use the script-wide device instead of a hard-coded "cuda"

        # Generate. Greedy decoding (do_sample=False) is deterministic on its
        # own; no temperature is passed because sampling parameters are unused
        # in this mode and only trigger generation-config warnings.
        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=cfg.max_new_tokens,
                pad_token_id=tok.pad_token_id,
                do_sample=False,
            )

        # Only take the newly generated part: with left padding every row shares
        # the same prompt width, so one column index splits prompt from output.
        input_len = inputs["input_ids"].shape[1]
        decoded = tok.batch_decode(out[:, input_len:], skip_special_tokens=True)
        preds.extend([d.strip() for d in decoded])

        # Progress line every 5th batch
        if (i // cfg.batch_size) % 5 == 0:
            print(f"  batch {i // cfg.batch_size + 1} done")

    return preds

# preds_rag[i] is the raw generated answer for questions[i].
preds_rag = run_inference_rag(questions, ctx_list_rag)


# 7) Extract yes/no/maybe decision from output

def extract_decision_last(text: str) -> str:
    """
    Extract the yes/no/maybe decision from a model output.

    Prefers the last occurrence of 'Answer: Yes/No/Maybe' (case-insensitive);
    otherwise falls back to the last standalone yes/no/maybe word.

    Returns:
        "yes", "no" or "maybe", or "unknown" when nothing matches
        (including for None or empty input).
    """
    # Normalize once so both patterns tolerate None.
    text = text or ""

    # The trailing \b stops words that merely start with a label
    # ("Answer: Not enough evidence") from being read as "no".
    matches = re.findall(r"Answer:\s*(Yes|No|Maybe)\b", text, re.IGNORECASE)
    if matches:
        return matches[-1].lower()

    matches = re.findall(r"\b(yes|no|maybe)\b", text.lower())
    return matches[-1] if matches else "unknown"

# Parsed decision per sample: "yes" / "no" / "maybe" / "unknown".
pred_dec_rag = [extract_decision_last(p) for p in preds_rag]


# 8) Metric Calculation

# Both stay None when there is no usable ground truth.
acc_rag = None
conf_rag = None

# Decision ACC: Requires final_decision column
if gt_decisions is not None:
    # Score only samples whose gold label is a valid decision.
    valid_idx = [i for i, g in enumerate(gt_decisions) if g in ("yes", "no", "maybe")]
    if valid_idx:
        y_true = [gt_decisions[i] for i in valid_idx]
        yr = [pred_dec_rag[i] for i in valid_idx]
        # An "unknown" prediction never equals a gold label, so it counts as an error.
        acc_rag  = accuracy(y_true, yr)
        conf_rag = confusion_table(y_true, yr, labels=("yes", "no", "maybe", "unknown"))

print("\n" + "=" * 60)
print("Decision ACC (Yes/No/Maybe) - RAG ONLY")
if acc_rag is None:
    print("[Warn] final_decision not found/valid -> ACC skipped.")
else:
    print(f"RAG ACC: {acc_rag:.4f}")
    print("\nConfusion (RAG):")
    print(conf_rag)
print("=" * 60)

# ROUGE-L: Compare generated text vs refs
# (the full generated answer is scored against the reference column chosen as ref_col)
rougeL_rag_list = [rouge_l_f1(p, r) for p, r in zip(preds_rag, refs)]
rougeL_rag = float(np.mean(rougeL_rag_list))
print("\nROUGE-L (RAG):")
print(f"RAG: {rougeL_rag:.4f}")

# BERTScore: May require downloading extra model/dependencies; skip if failed
# bs_rag stays None when the metric cannot be computed.
bs_rag = None
try:
    import evaluate
    bertscore = evaluate.load("bertscore")
    bs_res_rag = bertscore.compute(
        predictions=preds_rag,
        references=refs,
        lang="en",
        model_type="distilbert-base-uncased"
    )
    # Corpus-level score = mean of the per-sample F1 values.
    bs_rag = float(np.mean(bs_res_rag["f1"]))
    print("\nBERTScore-F1 (RAG):")
    print(f"RAG: {bs_rag:.4f}")
except Exception as e:
    # Deliberate best-effort: report a shortened error and continue without BERTScore.
    print("\n[Warn] BERTScore skipped:", str(e)[:100])


# 9) Save Results
# Aggregate metrics; printed at the end (not written to the per-sample files).
summary = {
    "model": cfg.model_name_or_path,
    "decision_acc": {"rag": acc_rag},
    "rougeL": {"rag": rougeL_rag},
    "bertscore": {"rag": bs_rag}
}

# Per-sample results go to a timestamped JSONL file and a CSV with the same rows.
os.makedirs(cfg.output_dir, exist_ok=True)
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
jsonl_path = os.path.join(cfg.output_dir, f"llama2_rag_only_{time_str}.jsonl")
csv_path   = os.path.join(cfg.output_dir, f"llama2_rag_only_{time_str}.csv")

rows = []
with open(jsonl_path, "w", encoding="utf-8") as f:
    for i in range(len(questions)):
        row = {
            "id": i,
            "question": questions[i],
            "reference": refs[i],
            "gt_decision": None if gt_decisions is None else gt_decisions[i],
            "rag_pred": preds_rag[i],
            "rag_decision": pred_dec_rag[i],
            "rag_context": ctx_list_rag[i],  # Save the fully retrieved context
        }
        # One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
        rows.append(row)

pd.DataFrame(rows).to_csv(csv_path, index=False)

print("\n" + "=" * 60)
print(f"[Done] Saved to {jsonl_path}")
print(f"Summary: {json.dumps(summary, indent=2, ensure_ascii=False)}")

[Info] DEVICE: cuda
[Info] Reading test data: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet
[Info] Loaded samples: 100 | ref_col=long_answer | has_final_decision=True
[Info] Loading KB docs: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/pubmed_documents.pkl
[Info] KB docs loaded: 800
[Info] Found FAISS index: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/PrimeKG/pubmed_qa.index
[Info] Using FAISS retriever.
[Info] Retrieving contexts for all questions...


100%|██████████| 100/100 [00:00<00:00, 195.95it/s]


[Info] Loading generator: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2 with 4-bit Quantization...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

[Info] Model loaded.
[Info] Running inference: Base (No Context)
  batch 1 done
  batch 6 done
  batch 11 done
  batch 16 done
  batch 21 done
  batch 26 done
  batch 31 done
  batch 36 done
  batch 41 done
  batch 46 done
[Info] Running inference: RAG (KB Retrieved Docs)
  batch 1 done
  batch 6 done
  batch 11 done
  batch 16 done
  batch 21 done
  batch 26 done
  batch 31 done
  batch 36 done
  batch 41 done
  batch 46 done

 =
Decision ACC (Yes/No/Maybe)
Base ACC: 0.4800
RAG  ACC: 0.4700
Gain   : -0.0100
 =

ROUGE-L:
Base: 0.1064 | RAG: 0.1123

BERTScore-F1:
Base: 0.7582 | RAG: 0.7567

 =
[Done] Saved to eval_results_pubmedqa_kbRAG_test_llama2/llama2_results_20251231_191704.jsonl
Summary: {
  "model": "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/llama2",
  "decision_acc": {
    "base": 0.48,
    "rag": 0.47
  },
  "rougeL": {
    "base": 0.10636585174230512,
    "rag": 0.11232739917531231
  },
  "bertscore": {
    "base": 0.7582157564163208,
    "rag": 0.7566591346263886

In [None]:
# 7) Metrics

# --- Decision ACC ---
# Both stay None when there is no usable ground truth.
acc_rag = None
conf_rag = None

if gt_decisions is not None:
    # Score only samples whose gold label is a valid decision.
    valid_idx = [i for i, g in enumerate(gt_decisions) if g in ("yes", "no", "maybe")]
    if valid_idx:
        y_true = [gt_decisions[i] for i in valid_idx]
        yr = [pred_dec_rag[i] for i in valid_idx]

        acc_rag  = accuracy(y_true, yr)
        # Computed for inspection; this cell does not print it.
        conf_rag = confusion_table(y_true, yr, labels=("yes", "no", "maybe", "unknown"))

print("\n" + "=" * 60)
print("Decision ACC (Yes/No/Maybe)")
if acc_rag is None:
    print("[Warn] final_decision not found/valid -> ACC skipped.")
else:
    print(f"RAG  ACC: {acc_rag:.4f}")
print("=" * 60)

# --- ROUGE-L (Per sample list) ---
# Per-sample ROUGE-L F1 (kept as a list so it can be saved per row), then the mean.
rougeL_rag_list  = [rouge_l_f1(p, r) for p, r in zip(preds_rag, refs)]
rougeL_rag  = float(np.mean(rougeL_rag_list))

print("\nROUGE-L:")
print(f"RAG: {rougeL_rag:.4f}")

# --- BERTScore (Per sample list) ---
bs_rag = None
# Initialize per-item score lists to prevent errors if BERTScore fails
# (the save step indexes bs_list_rag[i] for every sample)
bs_list_rag  = [None] * len(questions)

try:
    import evaluate
    bertscore = evaluate.load("bertscore")
    
    # RAG
    bs_res_rag  = bertscore.compute(predictions=preds_rag,  references=refs, lang="en", model_type="distilbert-base-uncased")
    bs_list_rag = bs_res_rag["f1"]   # Get F1 for each sample
    bs_rag  = float(np.mean(bs_list_rag))
    
    print("\nBERTScore-F1:")
    print(f"RAG: {bs_rag:.4f}")
except Exception as e:
    # Deliberate best-effort: report a shortened error and continue without BERTScore.
    print("\n[Warn] BERTScore skipped:", str(e)[:100])



# 8) Save results (Include Per-Sample Scores)
# Aggregate metrics; printed at the end (not written to the per-sample files).
summary = {
    "model": cfg.model_name_or_path,
    "decision_acc": {"rag": acc_rag},
    "rougeL": {"rag": rougeL_rag},
    "bertscore": {"rag": bs_rag}
}

# Per-sample results go to a timestamped JSONL file and a CSV with the same rows.
os.makedirs(cfg.output_dir, exist_ok=True)
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
jsonl_path = os.path.join(cfg.output_dir, f"llama2_results_{time_str}.jsonl")
csv_path   = os.path.join(cfg.output_dir, f"llama2_results_{time_str}.csv")

rows = []
with open(jsonl_path, "w", encoding="utf-8") as f:
    for i in range(len(questions)):
        row = {
            "id": i,
            "question": questions[i],
            "reference": refs[i],
            "gt_decision": None if gt_decisions is None else gt_decisions[i],
            
            # --- Predictions ---
            "rag_pred": preds_rag[i],
            "rag_decision": pred_dec_rag[i],
            
            # --- Per-sample scores (BERTScore is None when that metric was skipped) ---
            "score_rougeL_rag":  rougeL_rag_list[i],
            "score_bertscore_rag":  bs_list_rag[i],
            
            # --- Retrieved context fed to the model ---
            "rag_context": ctx_list_rag[i], 
        }
        # One JSON object per line; ensure_ascii=False keeps non-ASCII text readable.
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
        rows.append(row)

pd.DataFrame(rows).to_csv(csv_path, index=False)

print("\n" + "=" * 60)
print(f"[Done] Saved to {jsonl_path}")
print(f"Summary: {json.dumps(summary, indent=2)}")