In [None]:
import json
import os
import gc
import torch
import pandas as pd
import networkx as nx
from transformers import GPT2LMHeadModel, GPT2Tokenizer

#Configuration Section (Please modify according to your actual paths)
MODEL_PATH = "C:\\LLM\\gpt2"
PRIMEKG_PATH = "kg.csv" 
PUBMEDQA_FILE = "C:\\LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# RAG Parameter Optimization
MAX_RAG_ENTITIES = 3    # Keep count low to avoid interference
MAX_K_EDGES      = 3    # Take only the top 3 most important edges per entity to reduce Token usage
MAX_TOTAL_TOKENS = 950  # GPT-2 limit is approx 1024, reserve some for generation/computation

# Blacklist: Filter out meaningless generic words from PrimeKG caused by tokenization or matching
ENTITY_BLACKLIST = {
    "stable", "group", "left", "right", "study", "human", "blood", 
    "comp", "case", "control", "severe", "disease", "level", "high", "low",
    "frequent", "risk", "measure", "ratio", "rate", "oxygen", "normal"
}

print(f"DEVICE: {DEVICE}")

DEVICE: cuda


In [None]:
# 1. model lodading
def load_gpt2(model_path: str = MODEL_PATH):
    print(f"Loading GPT-2 from {model_path} ...")
    tokenizer = GPT2Tokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    model = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
    model.eval()
    print("  GPT-2 loaded.")
    return tokenizer, model

tokenizer, model = load_gpt2()

In [None]:
# 2. kg loading and processing
def load_kg(path: str = PRIMEKG_PATH) -> nx.Graph:
    print(f" Loading KG from {path} ...")
    df = pd.read_csv(path, low_memory=False)
    G = nx.from_pandas_edgelist(
        df, source="x_name", target="y_name", edge_attr=True
    )
    print(f"KG loaded: {G.number_of_nodes()} nodes.")
    return G

G = load_kg()
node_index = {str(n).lower(): n for n in G.nodes()}

def get_knowledge_context_single(entity_name: str, max_edges: int = MAX_K_EDGES) -> str | None:
    node = node_index.get(entity_name.lower())
    if not node: return None

    edges = list(G.edges(node, data=True))
    if not edges: return None
    

    selected_edges = edges[:max_edges]
    
    lines = []
    for u, v, attr in selected_edges:
        rel = attr.get("display_relation", attr.get("relation", "related to"))
        # neighbor node
        neighbor = v if u == node else u
        lines.append(f"{node} is {rel} {neighbor}.")
        
    return " ".join(lines)

def extract_rag_entities(text: str, max_entities: int = MAX_RAG_ENTITIES):

    t = text.lower()
    ents = []
    candidates = []
    words = set(t.split())
    
    for word in words:
        if len(word) < 4: continue
        if word in ENTITY_BLACKLIST: continue
        
        # check if in node index
        if word in node_index:
            ents.append(node_index[word])
            if len(ents) >= max_entities:
                break
                
    return ents

def get_knowledge_context_multi(text: str):
    entities = extract_rag_entities(text)
    if not entities:
        return [], ""
    
    facts_all = []
    for ent in entities:
        ctx = get_knowledge_context_single(ent)
        if ctx:
            facts_all.append(ctx)
            
    return entities, " ".join(facts_all)

In [None]:
# 3. data loading
def load_pubmedqa_example(idx: int, file_path: str = PUBMEDQA_FILE):
    df = pd.read_parquet(file_path)
    row = df.iloc[idx]
    
    return {
        "pmid": str(row.get("pubid")),
        "question": str(row.get("question") or ""),
        "context": str(row.get("long_answer") or ""), # use long_answer as context
        "answer": str(row.get("final_decision")).lower().strip()
    }

In [None]:
# 4. middle core: scoring function and prompt construction (with truncation fix)
def score_label(prompt: str, label: str) -> float:
    # Prompt + Label
    full_text = prompt + " " + label
    enc = tokenizer(full_text, return_tensors="pt").to(DEVICE)
    input_ids = enc["input_ids"]

    if input_ids.shape[1] > 1024:
        input_ids = input_ids[:, -1024:] 
    
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)
        loss = outputs.loss.item()
    return -loss

def build_prompt(question: str, context: str, facts: str = "") -> str:
    instruction = (
        "Question Answering task.\n"
        "Fact: {facts}\n"
        "Abstract: {context}\n"
        "Question: {question}\n"
        "Answer (yes/no/maybe):"
    )
    
    # Estimated length (character-level rough estimate, 1 token ≈ 4 chars)
    safe_char_limit = 3000 
    
    current_len = len(question) + len(facts) + 100 # 100 is template buffer
    available_for_context = safe_char_limit - current_len
    
    if available_for_context < 100:
        available_for_context = 100 
        
    context_truncated = context[:available_for_context]
    # if no facts, do not show "Fact: ..." line
    facts_str = facts if facts else "None"
    
    return instruction.format(
        facts=facts_str,
        context=context_truncated,
        question=question
    )

In [None]:
def evaluate_pubmedqa_acc(start_idx=0, end_idx=50, save_path="pubmedqa_results.csv"):
    correct_no = 0    # No RAG
    correct_rag = 0   # RAG
    rag_triggered = 0 # search found knowledge
    improvements = 0  # RAG fixed errors
    hurts = 0         # RAG failed errors
    
    logs = [] # used to store each detailed result

    print(f"Starting Evaluation [{start_idx} to {end_idx}]...")

    for idx in range(start_idx, end_idx):
        try:
            ex = load_pubmedqa_example(idx)
        except Exception as e:
            print(f"Skipping {idx}: {e}")
            continue
            
        q, c, gt = ex["question"], ex["context"], ex["answer"]
        if gt not in ["yes", "no", "maybe"]: continue

        # 1. Baseline (no RAG)
        prompt_no = build_prompt(q, c, facts="")
        scores_no = {lab: score_label(prompt_no, lab) for lab in ["yes", "no", "maybe"]}
        pred_no = max(scores_no, key=scores_no.get)
        

        # 2. With RAG (RAG)
        ents, facts = get_knowledge_context_multi(q + " " + c)
        
        if ents and facts:
            rag_triggered += 1
            prompt_rag = build_prompt(q, c, facts=facts)
            scores_rag = {lab: score_label(prompt_rag, lab) for lab in ["yes", "no", "maybe"]}
            pred_rag = max(scores_rag, key=scores_rag.get)
        else:
            pred_rag = pred_no
            
        # 3. Statistics and Comparisons
        if pred_no == gt: correct_no += 1
        if pred_rag == gt: correct_rag += 1
        
        status = "SAME"
        if pred_no != gt and pred_rag == gt:
            improvements += 1
            status = "  FIXED"
        elif pred_no == gt and pred_rag != gt:
            hurts += 1
            status = "  BROKE"
            
        print(f"[{idx}] GT={gt} | No={pred_no} | RAG={pred_rag} | {status} | Ents={ents}")
    
        logs.append({
            "id": idx,
            "ground_truth": gt,
            "pred_no_rag": pred_no,
            "pred_rag": pred_rag,
            "status": status,
            "entities": str(ents),
            "rag_facts": facts[:200] if facts else ""
        })


    # 4. Accuracy
    total = len(logs)
    if total == 0:
        print("No valid samples processed.")
        return

    acc_no = correct_no / total
    acc_rag = correct_rag / total

    summary_lines = [
        "="*40,
        f"RESULTS (Total Samples: {total})",
        f"No RAG Accuracy   : {acc_no:.4f}  ({correct_no}/{total})",
        f"With RAG Accuracy : {acc_rag:.4f}  ({correct_rag}/{total})",
        "-" * 20,
        f"RAG Triggered     : {rag_triggered}",
        f"Improvements      : {improvements}",
        f"Hurts             : {hurts}",
        "="*40
    ]
    
    summary_text = "\n".join(summary_lines)

    print("\n" + summary_text)

    # save summary to TXT
    summary_path = "summary_stats.txt"
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write(summary_text)

    # save detailed logs to CSV
    df = pd.DataFrame(logs)
    df.to_csv(save_path, index=False)
    
    print(f"\nDetailed logs saved to: {save_path}")
    print(f"Summary stats saved to: {summary_path}")

In [None]:
# 运行
if __name__ == "__main__":
    evaluate_pubmedqa_acc(0, 50, "my_experiment_results.csv")

In [65]:
ex0 = load_pubmedqa_example(0)
ex0

{'pmid': '23506394',
 'question': 'Malnutrition, a new inducer for arterial calcification in hemodialysis patients?',
 'context': 'Malnutrition is prevalent in hemodialysis patients and is associated with arterial calcification and the expressions of BMP2 and MGP in calcified radial arteries. Malnutrition may be a new inducer candidate for arterial calcification in hemodialysis patients.',
 'answer': 'yes'}

In [None]:
res = evaluate_pubmedqa_acc(start_idx=0, end_idx=50)
print(res)

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import evaluate
from datetime import datetime
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM


# 0. Configuration Parameters
class Config:
    # Note: For Windows paths, it is recommended to add 'r' before the quotes to prevent escape character errors
    parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled_splits\test.parquet"
    
    # Model path (Ensure model files exist in this path, or fill in a huggingface hub id like "gpt2")
    model_path = r"C:\LLM\gpt2" 
    
    # Runtime parameters
    limit = 50              # Recommended to set small for testing (e.g., 10-50), set to 200 or None after successful run
    use_chat_template = False
    max_ctx_chars = 4000
    seed = 2025

args = Config()

# Set random seed
def set_seed(seed):
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(args.seed)


# 1. RAG Retrieval Interface (Mock)
def retrieve_knowledge(question: str) -> str:
    """
    Mock retrieval logic. In real scenarios, this should call VectorDB or Search API.
    """
    q_lower = question.lower()
    if "mutation" in q_lower:
        return "Fact: Specific gene mutations (e.g., BRCA1) are strong predictors for this condition."
    if "treatment" in q_lower:
        return "Fact: Standard treatment involves a combination of chemotherapy and radiation therapy."
    if "risk" in q_lower:
        return "Fact: High BMI and smoking are significant risk factors."
    return "" 


# 2. Helper Functions Definition
def join_context(c, max_chars):
    try:
        # Handle potential None or dictionary structures
        ctxs = (c or {}).get("contexts", [])
        return " ".join(ctxs)[:max_chars] if ctxs else ""
    except:
        return ""

def build_prompt(tok, q, ctx, facts=""):
    system_msg = "You are a helpful biomedical QA assistant."
    
    if facts:
        #With RAG
        user_content = (
            f"Retrieval Knowledge: {facts}\n"
            f"Question: {q}\nContext: {ctx}\n"
            "Answer the question in detail."
        )
        text_prompt = (
            f"Retrieval Knowledge: {facts}\n"
            f"Question: {q}\nContext: {ctx}\n\n"
            "Answer the question in detail using the context and retrieved knowledge.\nAnswer: "
        )
    else:
        #No RAG
        user_content = (
            f"Question: {q}\nContext: {ctx}\n"
            "Answer the question in detail."
        )
        text_prompt = (
            f"Question: {q}\nContext: {ctx}\n\n"
            "Answer the question in detail using the context above.\nAnswer: "
        )

    # Prioritize using Chat Template (if model supports)
    if args.use_chat_template and hasattr(tok, "apply_chat_template"):
        msgs = [{"role": "system", "content": system_msg}, {"role": "user", "content": user_content}]
        try:
            return tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
        except Exception:
            return text_prompt # Fallback to simple concatenation
    
    return text_prompt

def clean_output(full_text, prompt_len):
    """Clean model generated output, remove Prompt part and redundant stop tokens"""
    gen_text = full_text[prompt_len:].strip()
    stop_tokens = ["\nQuestion:", "\nContext:", "Answer:", "Reference:"]
    for stop_t in stop_tokens:
        if stop_t in gen_text:
            gen_text = gen_text.split(stop_t)[0].strip()
    return gen_text if gen_text else "Error: No answer generated."


# 3. Data Loading
print(f" Reading data: {args.parquet_path}")

if not os.path.exists(args.parquet_path):
    print(f" Warning: File not found {args.parquet_path}.")
    print(">>> Creating mock data for demonstration...")
    questions = ["Is mutation X related to cancer?", "What is the treatment for flu?"] * (args.limit // 2 + 1)
    ctx_list = ["Context about mutation X.", "Context about flu."] * (args.limit // 2 + 1)
    refs = ["Yes, it is related.", "Rest and fluids."] * (args.limit // 2 + 1)
    
    # Truncate to limit
    questions = questions[:args.limit]
    ctx_list = ctx_list[:args.limit]
    refs = refs[:args.limit]
else:
    try:
        tbl = pq.read_table(args.parquet_path)
        df = tbl.to_pandas().dropna().head(args.limit)
        
        # Define key variables
        questions = df["question"].tolist()
        # Adjust column name reading logic based on dataset structure
        ctx_list = df["context"].map(lambda c: join_context(c, args.max_ctx_chars)).tolist()
        
        target_col = "long_answer" if "long_answer" in df.columns else "final_decision"
        if target_col not in df.columns:
            # Final fallback to prevent missing columns
            target_col = df.columns[-1]
            
        refs = df[target_col].tolist()
        print(f"  Data loading completed, total {len(questions)} samples.")
    except Exception as e:
        print(f"  Failed to read Parquet: {e}")
        exit()


# 4. Model Loading
print(f" Loading model: {args.model_path}")

try:
    # Prioritize local loading
    tok = AutoTokenizer.from_pretrained(args.model_path, use_fast=True, local_files_only=True)
except:
    print("Local load failed, trying to load tokenizer online...")
    tok = AutoTokenizer.from_pretrained(args.model_path, use_fast=True)

if tok.pad_token is None:
    tok.pad_token = tok.eos_token

# Automatically select precision
if torch.cuda.is_available():
    dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    print(f" Using GPU: {torch.cuda.get_device_name(0)} ({dtype})")
else:
    dtype = torch.float32
    print(" Using CPU")

try:
    model = AutoModelForCausalLM.from_pretrained(
        args.model_path,
        torch_dtype=dtype,
        device_map="auto",
        trust_remote_code=True
    )
except Exception as e:
    print(f"  Model load failed: {e}")
    # If local gpt2 is missing, provide an automatic download fallback here
    if args.model_path == r"C:\LLM\gpt2": 
        print("Attempting to download gpt2 from HuggingFace...")
        model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=dtype, device_map="auto")
        tok = AutoTokenizer.from_pretrained("gpt2")
        tok.pad_token = tok.eos_token
    else:
        exit()

model.eval()


# 5. Inference Loop
# Create experiment directory
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
model_tag = os.path.basename(args.model_path.rstrip("/\\")) # Compatible with Windows/Linux path separators
run_dir = os.path.join("eval_out_notebook", f"run_{time_str}_{model_tag}")
os.makedirs(run_dir, exist_ok=True)

preds_no = []
preds_rag = []
facts_log = []

print(f"\n Starting Inference (Limit={len(questions)})...")

for q, ctx in tqdm(zip(questions, ctx_list), total=len(questions), desc="Inference"):
    
    #Generation Parameters
    gen_kwargs = {
        "max_new_tokens": 128,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
        "do_sample": False,        # Greedy decoding
        "repetition_penalty": 1.2
    }

    #A. No RAG
    prompt_no = build_prompt(tok, q, ctx, facts="")
    inputs_no = tok(prompt_no, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        out_no = model.generate(**inputs_no, **gen_kwargs)
    
    input_len_no = inputs_no["input_ids"].shape[1]
    decoded_out_no = tok.decode(out_no[0], skip_special_tokens=True)
    decoded_in_no = tok.decode(out_no[0][:input_len_no], skip_special_tokens=True)
    
    pred_no_clean = clean_output(decoded_out_no, len(decoded_in_no))
    preds_no.append(pred_no_clean)
    
    #B. With RAG
    retrieved_facts = retrieve_knowledge(q)
    facts_log.append(retrieved_facts)
    
    prompt_rag = build_prompt(tok, q, ctx, facts=retrieved_facts)
    inputs_rag = tok(prompt_rag, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        out_rag = model.generate(**inputs_rag, **gen_kwargs)
        
    input_len_rag = inputs_rag["input_ids"].shape[1]
    decoded_out_rag = tok.decode(out_rag[0], skip_special_tokens=True)
    decoded_in_rag = tok.decode(out_rag[0][:input_len_rag], skip_special_tokens=True)
    
    pred_rag_clean = clean_output(decoded_out_rag, len(decoded_in_rag))
    preds_rag.append(pred_rag_clean)


# 6. Metrics Calculation and Saving
print("\n Calculating ROUGE scores...")
try:
    rouge = evaluate.load("rouge")
    
    # Calculate ROUGE
    scores_no = rouge.compute(predictions=preds_no, references=refs, use_aggregator=False)["rougeLsum"]
    scores_rag = rouge.compute(predictions=preds_rag, references=refs, use_aggregator=False)["rougeLsum"]

    avg_no = np.mean(scores_no)
    avg_rag = np.mean(scores_rag)

    print(f"\n{'='*40}")
    print(f" Results Summary (n={len(questions)})")
    print(f"{'='*40}")
    print(f" No RAG Avg ROUGE-Lsum : {avg_no:.4f}")
    print(f" w/ RAG Avg ROUGE-Lsum : {avg_rag:.4f}")
    print(f" Improvement           : {(avg_rag - avg_no):.4f}")
    print(f"{'='*40}\n")
except Exception as e:
    print(f" ROUGE calculation failed (possibly due to network issues downloading metric): {e}")
    scores_no = [0.0] * len(questions)
    scores_rag = [0.0] * len(questions)

# Save results
df_res = pd.DataFrame({
    "question": questions,
    "retrieved_facts": facts_log,
    "ref": refs,
    "pred_no_rag": preds_no,
    "pred_with_rag": preds_rag,
    "score_no": scores_no,
    "score_rag": scores_rag,
    "diff": np.array(scores_rag) - np.array(scores_no)
})

csv_path = os.path.join(run_dir, "results.csv")
df_res.to_csv(csv_path, index=False)
print(f"Detailed results saved: {csv_path}")

# Visualize Top 3
if len(df_res) > 0:
    print("\n Top 3 Most Improved Cases by RAG:")
    top_improved = df_res.sort_values(by="diff", ascending=False).head(3)

    for idx, row in top_improved.iterrows():
        print(f"\n Index: {idx} | Diff: +{row['diff']:.4f}")
        print(f" Q: {row['question']}")
        print(f" Facts: {row['retrieved_facts']}")
        print(f" No RAG: {str(row['pred_no_rag'])[:100]}...")
        print(f" w/ RAG: {str(row['pred_with_rag'])[:100]}...")
        print("-" * 50)

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import evaluate
import faiss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer


# 0. Configuration Parameters
class Config:
    parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled_splits\test.parquet"
    model_path = r"C:\LLM\gpt2" 
    embedding_model = "all-MiniLM-L6-v2" 
    limit = 50
    seed = 2025

args = Config()
torch.manual_seed(args.seed)

# 1. Vector Retrieval Module
class VectorDB:
    def __init__(self, model_name):
        print(f" Loading embedding model: {model_name}...")
        # Replaced random numbers in Kaggle code with real AI semantic understanding
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.docs = []
    
    def create_index(self, documents):
        """Step 1: Index documents with FAISS"""
        print("Building vector index (FAISS)...")
        self.docs = documents
        # Generate real semantic vectors
        embeddings = self.model.encode(documents, convert_to_numpy=True, show_progress_bar=True)
        
        # Initialize FAISS
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(embeddings)
        print(f" Index build completed, total {len(documents)} documents.")

    def search(self, query, top_k=1):
        """Step 2: Retrieve relevant documents"""
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        distances, indices = self.index.search(query_embedding, top_k)
        
        # Return the most relevant document found
        results = [self.docs[i] for i in indices[0]]
        return results[0] # Taking only Top 1 for simplicity


# 2. Data Processing and Helper Functions
def join_context(c):
    try:
        ctxs = (c or {}).get("contexts", [])
        return " ".join(ctxs) if ctxs else ""
    except:
        return ""

def clean_output(text):
    if "Answer:" in text:
        return text.split("Answer:")[-1].strip()
    return text.strip()


# 3. Load Data
print(f" Reading data: {args.parquet_path}")
if not os.path.exists(args.parquet_path):
    print(" File not found")
    exit()

tbl = pq.read_table(args.parquet_path)
df = tbl.to_pandas().dropna().head(args.limit)

# Prepare "Knowledge Base": Treat all Contexts in the test set as a huge database
# In real scenarios, this could be millions of papers
database_contexts = df["context"].map(join_context).tolist()
questions = df["question"].tolist()
refs = df["long_answer"].tolist() if "long_answer" in df.columns else df[df.columns[-1]].tolist()


# 4. Initialize FAISS Vector Database
vector_db = VectorDB(args.embedding_model)
vector_db.create_index(database_contexts)


# 5. Load GPT-2
print(f" Loading generation model: {args.model_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True, local_files_only=True)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto")
model.eval()


# 6. Run Real RAG Pipeline
preds = []
print("\n Starting Real RAG Inference (Search -> Generate)...")

for q in tqdm(questions, desc="RAG Pipeline"):
    
    # Step A: Retrieve
    # This time we don't provide the correct answer, let the model search in the library!
    retrieved_ctx = vector_db.search(q, top_k=1)
    
    # Step B: Augment
    # Use Prompt format from your Kaggle code
    # Note truncation length to prevent GPT-2 OOM (1024 token limit)
    truncated_ctx = retrieved_ctx[:800] 
    
    prompt = (
        f"Use the following context to answer the query:\n\n"
        f"Context: {truncated_ctx}\n\n"
        f"Query: {q}\n\n"
        f"Answer:"
    )
    
    # Step C: Generate
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Protection mechanism: If Prompt is too long, force truncate
    if inputs["input_ids"].shape[1] > 1024:
        inputs["input_ids"] = inputs["input_ids"][:, -1024:]
        
    with torch.no_grad():
        # max_new_tokens controls generation length
        outputs = model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id, do_sample=False)
        
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Clean result, keep only the part after Answer
    final_ans = clean_output(response[len(prompt):]) # Take only the newly generated part
    preds.append(final_ans)


# 7. Evaluation (BERTScore)
print("\n Calculating BERTScore...")
bertscore = evaluate.load("bertscore")
res = bertscore.compute(predictions=preds, references=refs, lang="en", device="cuda" if torch.cuda.is_available() else "cpu")
avg_score = np.mean(res['f1'])

print(f"\n{'='*40}")
print(f" GPT-2 + FAISS Real RAG Results")
print(f"{'='*40}")
print(f" Average BERTScore F1: {avg_score:.4f}")
print(f"{'='*40}\n")

# Save results
df_res = pd.DataFrame({"question": questions, "retrieved_context": [vector_db.search(q) for q in questions], "pred": preds, "ref": refs})
df_res.to_csv("results_gpt2_faiss.csv", index=False)
print(" Results saved to results_gpt2_faiss.csv")

# Show what was retrieved
print("\n Retrieval Effect Demonstration (First Entry):")
print(f"Q: {questions[0]}")
print(f" Retrieved: {df_res.iloc[0]['retrieved_context'][:100]}...")
print(f" GPT-2 Answer: {preds[0]}")

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import pyarrow.parquet as pq
import evaluate
import faiss
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer


# 0. Configuration Parameters
class Config:
    parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled_splits\test.parquet"
    model_path = r"C:\LLM\gpt2" 
    embedding_model = "all-MiniLM-L6-v2"
    limit = 50
    seed = 2025

args = Config()
torch.manual_seed(args.seed)


# 1. Vector Retrieval Module
class VectorDB:
    def __init__(self, model_name):
        print(f" Loading embedding model: {model_name}...")
        self.model = SentenceTransformer(model_name)
        self.index = None
        self.docs = []
    
    def create_index(self, documents):
        print(" Building FAISS index...")
        self.docs = documents
        embeddings = self.model.encode(documents, convert_to_numpy=True, show_progress_bar=True)
        self.index = faiss.IndexFlatL2(embeddings.shape[1])
        self.index.add(embeddings)

    def search(self, query, top_k=1):
        query_embedding = self.model.encode([query], convert_to_numpy=True)
        _, indices = self.index.search(query_embedding, top_k)
        return self.docs[indices[0][0]]


# 2. Prompt Construction (GPT-2 Specialized)
def build_prompt_native(q):
    """Native mode: Ask directly without providing materials"""
    return f"Question: {q}\nAnswer:"

def build_prompt_rag(q, ctx):
    """RAG mode: Provide retrieved materials + question"""
    # Simplify Prompt slightly, try to reduce GPT-2's repetitive behavior
    return (
        f"Context: {ctx[:800]}\n\n" # Truncate to prevent OOM
        f"Question: {q}\n"
        f"Answer:"
    )

def clean_output(full_text, prompt_len):
    gen = full_text[prompt_len:].strip()
    # Truncate common stop words
    for stop in ["\nQuestion:", "\nContext:", "Question:"]:
        if stop in gen:
            gen = gen.split(stop)[0].strip()
    return gen if gen else "Error"


# 3. Preparation
# Load data
if not os.path.exists(args.parquet_path):
    print(" Data file not found")
    exit()
tbl = pq.read_table(args.parquet_path)
df = tbl.to_pandas().dropna().head(args.limit)
database_contexts = df["context"].map(lambda c: " ".join((c or {}).get("contexts", []))).tolist()
questions = df["question"].tolist()
refs = df["long_answer"].tolist() if "long_answer" in df.columns else df[df.columns[-1]].tolist()

# Initialize FAISS
vector_db = VectorDB(args.embedding_model)
vector_db.create_index(database_contexts)

# Load GPT-2
print(f" Loading GPT-2: {args.model_path}")
tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=True, local_files_only=True)
if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(args.model_path, device_map="auto")
model.eval()


# 4. Comparative Inference Loop
print(f"\n Starting comparative experiment (Limit={len(questions)})...")
preds_native = []
preds_rag = []
retrieved_docs = []

for q in tqdm(questions, desc="Inference"):
    
    gen_kwargs = {
        "max_new_tokens": 64,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
        "do_sample": False # Greedy decoding
    }

    #A. Native GPT-2
    prompt_native = build_prompt_native(q)
    inputs_native = tokenizer(prompt_native, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        out_native = model.generate(**inputs_native, **gen_kwargs)
        
    res_native = clean_output(tokenizer.decode(out_native[0], skip_special_tokens=True), inputs_native["input_ids"].shape[1])
    preds_native.append(res_native)

    #B. RAG GPT-2
    # 1. Retrieve
    ctx = vector_db.search(q)
    retrieved_docs.append(ctx)
    
    # 2. Generate
    prompt_rag = build_prompt_rag(q, ctx)
    inputs_rag = tokenizer(prompt_rag, return_tensors="pt").to(model.device)
    
    # Length protection
    if inputs_rag["input_ids"].shape[1] > 1000:
        inputs_rag["input_ids"] = inputs_rag["input_ids"][:, -1000:]

    with torch.no_grad():
        out_rag = model.generate(**inputs_rag, **gen_kwargs)

    res_rag = clean_output(tokenizer.decode(out_rag[0], skip_special_tokens=True), inputs_rag["input_ids"].shape[1])
    preds_rag.append(res_rag)


# 5. BERTScore Showdown
print("\n Calculating BERTScore...")
bertscore = evaluate.load("bertscore")
device = "cuda" if torch.cuda.is_available() else "cpu"

print(" Calculating Native scores...")
f1_native = np.array(bertscore.compute(predictions=preds_native, references=refs, lang="en", device=device, batch_size=32)['f1'])

print(" Calculating RAG scores...")
f1_rag = np.array(bertscore.compute(predictions=preds_rag, references=refs, lang="en", device=device, batch_size=32)['f1'])

avg_native = f1_native.mean()
avg_rag = f1_rag.mean()

print(f"\n{'='*50}")
print(f" GPT-2 Native vs. RAG (Real World)")
print(f"{'='*50}")
print(f" Native Score : {avg_native:.4f}")
print(f" RAG Score    : {avg_rag:.4f}")
print(f" Net Impact       : {(avg_rag - avg_native):.4f}")
print(f"{'='*50}")

if avg_rag < avg_native:
    print(" Conclusion: RAG caused performance degradation! GPT-2 got confused by retrieved documents.")
else:
    print(" Conclusion: RAG improved performance.")

# Save comparison results
df_res = pd.DataFrame({
    "question": questions,
    "pred_native": preds_native,
    "pred_rag": preds_rag,
    "score_native": f1_native,
    "score_rag": f1_rag,
    "diff": f1_rag - f1_native
})
df_res.to_csv("results_gpt2_comparison.csv", index=False)
print("\n Detailed comparison saved to: results_gpt2_comparison.csv")

In [None]:
import os
import re
import pandas as pd
import networkx as nx
import torch
import spacy
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2LMHeadModel, GPT2Tokenizer
from tqdm import tqdm
from fuzzywuzzy import fuzz
import pyarrow.parquet as pq

# 0. Global Configuration
class Config:
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Path Configuration
    GPT2_PATH = r"C:\LLM\gpt2"       # Path to GPT-2
    LLAMA2_PATH = r"C:\LLM\llama2"   # Path to LLaMA-2
    
    # Note: If you don't have kg.csv, the code will automatically create a tiny mock graph to ensure it runs.
    PRIMEKG_PATH = "kg.csv"          
    
    PUBMEDQA_PATH = r"C:\LLM\data\pubmedqa_hf\pqa_labeled_splits\test.parquet"
    
    MAX_CTX_EDGES = 5   # Limit context length (GPT-2 context window is small, too many will cause overflow)
    EVAL_LIMIT = 20     # Number of test samples
    
# Load NLP Tools
print("Loading Spacy NER model...")
try:
    NER = spacy.load("en_core_web_sm")
except:
    print("Spacy model not found, downloading...")
    os.system("python -m spacy download en_core_web_sm")
    NER = spacy.load("en_core_web_sm")


# 1. Model Loader
def load_model(model_type="gpt2"):
    print(f"Loading model: {model_type}...")
    if model_type.lower() == "gpt2":
        tokenizer = AutoTokenizer.from_pretrained(Config.GPT2_PATH, use_fast=True, local_files_only=True)
        if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(Config.GPT2_PATH).to(Config.DEVICE).eval()
    elif model_type.lower() == "llama2":
        tokenizer = AutoTokenizer.from_pretrained(Config.LLAMA2_PATH, use_fast=True)
        model = AutoModelForCausalLM.from_pretrained(Config.LLAMA2_PATH, torch_dtype=torch.float16, device_map="auto").eval()
    else:
        raise ValueError("Unknown model type")
    return tokenizer, model

# 2. PrimeKG Graph Database Management
def load_primekg(path=Config.PRIMEKG_PATH):
    # If file does not exist, create a mock Knowledge Graph (Mock KG) for demonstration
    if not os.path.exists(path):
        print(f"   {path} not found. Creating Mock Knowledge Graph (Mock Mode)...")
        G = nx.Graph()
        # Add some medical common sense for testing
        G.add_edge("warfarin", "bleeding", relation="causes", x_type="Drug", y_type="Effect")
        G.add_edge("aspirin", "fever", relation="treats", x_type="Drug", y_type="Symptom")
        G.add_edge("smoking", "lung cancer", relation="causes", x_type="Behavior", y_type="Disease")
        G.add_edge("diabetes", "insulin", relation="treated_by", x_type="Disease", y_type="Drug")
        G.add_edge("hypertension", "stroke", relation="risk_factor_for", x_type="Disease", y_type="Disease")
        print("   Mock graph created.")
        return G
    
    print(f"Loading real PrimeKG: {path} ...")
    try:
        df = pd.read_csv(path, low_memory=False)
        # Assume CSV has x_name, y_name, relation columns
        G = nx.from_pandas_edgelist(df, source="x_name", target="y_name", edge_attr=["relation", "x_type", "y_type"])
        print(f"   PrimeKG Loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges.")
        return G
    except Exception as e:
        print(f"   Failed to load CSV: {e}")
        return nx.Graph()

# Node matching index (speed up lookup)
_node_cache = {}

def resolve_node(G, name):
    """
    Find the closest node name in the graph
    """
    name = name.lower().strip()
    if not _node_cache:
        # Build cache
        for n in G.nodes():
            _node_cache[str(n).lower()] = n
            
    # 1. Exact match
    if name in _node_cache:
        return _node_cache[name]
    
    # 2. Fuzzy match (Slow for large graphs, recommended only for Mock Mode or small graphs)
    if G.number_of_nodes() < 10000:
        choices = list(G.nodes())
        best_match, score = None, 0
        for node in choices:
            s = fuzz.ratio(name, str(node).lower())
            if s > score:
                score = s
                best_match = node
        
        if score > 80:
            return best_match
            
    return None


# 3. Core RAG Retrieval Logic (Key Fix)
def extract_keywords(question):
    """
    Improvement: Extract key nouns from the sentence, not just named entities
    """
    doc = NER(question)
    # Prioritize entities
    keywords = [ent.text for ent in doc.ents]
    # If no entities, take noun chunks
    if not keywords:
        keywords = [chunk.text for chunk in doc.noun_chunks]
    # If still nothing, take all words except stop words
    if not keywords:
        keywords = [token.text for token in doc if not token.is_stop and not token.is_punct]
        
    return keywords

def get_knowledge_context(G, question, max_edges=Config.MAX_CTX_EDGES):
    """
    Retrieve relevant triples from the Knowledge Graph
    """
    keywords = extract_keywords(question)
    found_facts = []
    
    for kw in keywords:
        node = resolve_node(G, kw)
        if node:
            # Get neighbor nodes
            edges = list(G.edges(node, data=True))
            # Limit quantity to prevent Prompt from being too long
            for u, v, attr in edges[:3]: # Find at most 3 edges per keyword
                rel = attr.get('relation', 'related_to')
                target = v if u == node else u
                fact = f"{node} --[{rel}]--> {target}"
                found_facts.append(fact)
    
    # Deduplicate and truncate
    unique_facts = list(set(found_facts))[:max_edges]
    
    if not unique_facts:
        return "No specific knowledge found in graph."
    
    return "; ".join(unique_facts)


# 4. Generation and Prompt (Optimized for GPT-2)
def generate_answer(tokenizer, model, model_type, context, question):
    # Strong Few-Shot Prompt to make GPT-2 behave
    if model_type == "gpt2":
        prompt = (
            "Determine the relationship based on medical knowledge.\n\n"
            "Context: Aspirin --[treats]--> fever\n"
            "Question: Is aspirin effective for fever?\n"
            "Answer: Yes\n\n"
            "Context: Smoking --[causes]--> lung cancer\n"
            "Question: Does smoking improve lung capacity?\n"
            "Answer: No\n\n"
            f"Context: {context}\n"
            f"Question: {question}\n"
            "Answer:"
        )
    else:
        # Llama 2 format
        prompt = f"[INST] Context: {context}\nQuestion: {question}\nAnswer with Yes, No, or Maybe. [/INST]"

    inputs = tokenizer(prompt, return_tensors="pt").to(Config.DEVICE)
    
    # Truncate to prevent overflow (GPT-2 limit 1024)
    if inputs.input_ids.shape[1] > 1000:
        inputs.input_ids = inputs.input_ids[:, -1000:]

    with torch.no_grad():
        outputs = model.generate(
            inputs.input_ids, 
            max_new_tokens=10, 
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False
        )
    
    full_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    
    # Extract answer part
    # GPT-2 might repeat, we need to capture content after "Answer:"
    if "Answer:" in full_text:
        # Take content after the last "Answer" label
        generated = full_text.split("Answer:")[-1].strip().lower()
    else:
        generated = full_text.lower()

    # Keyword matching
    if "yes" in generated: return "yes"
    if "no" in generated: return "no"
    if "maybe" in generated: return "maybe"
    
    return "maybe" # Fallback strategy


# 5. Evaluation Main Loop
def evaluate_on_pubmedqa(G, tokenizer, model, model_type, limit=Config.EVAL_LIMIT, use_rag=True):
    print(f"\n>>> Start Evaluation: {model_type.upper()} | RAG Mode: {use_rag}")
    
    # Load data
    if not os.path.exists(Config.PUBMEDQA_PATH):
        print(f"   Data file not found: {Config.PUBMEDQA_PATH}")
        return

    df = pq.read_table(Config.PUBMEDQA_PATH).to_pandas().head(limit)
    questions = df["question"].tolist()
    
    # Extract ground truth labels (yes/no/maybe)
    golds = []
    for x in df["final_decision"]:
        x = str(x).lower().strip()
        if x.startswith("y"): golds.append("yes")
        elif x.startswith("n"): golds.append("no")
        else: golds.append("maybe")

    preds = []
    correct_count = 0
    
    # Progress bar
    pbar = tqdm(total=limit)
    
    for i, q in enumerate(questions):
        # 1. Retrieve
        if use_rag:
            context = get_knowledge_context(G, q)
        else:
            context = "N/A"
            
        # 2. Generate
        pred = generate_answer(tokenizer, model, model_type, context, q)
        preds.append(pred)
        
        # 3. Real-time statistics
        if pred == golds[i]:
            correct_count += 1
            
        pbar.set_description(f"Acc: {correct_count/(i+1):.2%}")
        pbar.update(1)
        
        # Print the first few examples to check performance
        if i < 3:
            print(f"\n[Case {i}]")
            print(f"Q: {q}")
            print(f"KG Context: {context}")
            print(f"Pred: {pred} | Gold: {golds[i]}")
            
    pbar.close()
    
    acc = correct_count / limit
    print(f"\nFinal Accuracy ({model_type}, RAG={use_rag}): {acc:.2%}")

# 6. Program Entry Point
if __name__ == "__main__":
    # 1. Load KG
    G = load_primekg()
    
    # 2. Load Model
    tokenizer, model = load_model("gpt2")
    
    # 3. Run Comparative Evaluation
    # Test Native first (No RAG)
    evaluate_on_pubmedqa(G, tokenizer, model, "gpt2", use_rag=False)
    
    # Test KG-RAG next (With Knowledge Graph)
    evaluate_on_pubmedqa(G, tokenizer, model, "gpt2", use_rag=True)

bnuild_pubmed_index

In [None]:
import pandas as pd
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np


# 1. Prepare "Paper Repository" (Knowledge Base)

parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled_splits\test.parquet"

print("Loading dataset to build Knowledge Base...")
df = pq.read_table(parquet_path).to_pandas()
# Treat the 'context' field of each row as a "document"
# 'context' in parquet is usually a dictionary or list, needs to be joined into a string
documents = []
for raw_ctx in df["context"]:
    # Depending on your parquet structure, it could be a list or a dict
    # Assuming standard format here, join sentences in the list
    try:
        if hasattr(raw_ctx, 'get'): # If it is a dict
            text_list = raw_ctx.get('contexts', [])
        else: # If it is a list
            text_list = raw_ctx
        
        full_abstract = " ".join(text_list)
        documents.append(full_abstract)
    except:
        documents.append("")

print(f"  Extracted {len(documents)} paper abstracts.")


# 2. Build FAISS Index (Text Index)
print("Converting abstracts to vectors (Embeddings)...")
embed_model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = embed_model.encode(documents, convert_to_numpy=True, show_progress_bar=True)
# Build index
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
print("  FAISS index construction completed!")


# 3. Retrieval Function

def retrieve_paper(question, top_k=1):
    q_emb = embed_model.encode([question], convert_to_numpy=True)
    distances, indices = index.search(q_emb, top_k)
    
    # Return the most relevant abstract found
    doc_id = indices[0][0]
    return documents[doc_id]


# 4. Test it out
q = "Does histologic chorioamnionitis correspond to clinical chorioamnionitis?"
retrieved_text = retrieve_paper(q)

print(f"\nQuestion: {q}")
print(f"Retrieved paper abstract:\n{retrieved_text[:200]}...") # Only print the first 200 characters

In [None]:
import os
import re
import json
import ast
import random
import torch
import numpy as np
import pandas as pd
from datetime import datetime
from sklearn.feature_extraction.text import CountVectorizer
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.utils import logging as hf_logging
import evaluate


#  Global Configuration (Config)
class Config:
    #  Your local absolute path (prefix r prevents escape characters)
    parquet_path = r"C:\LLM\data\pubmedqa_hf\pqa_labeled\train-00000-of-00001.parquet"
    
    # Model name (use "gpt2" for testing; recommend switching to Llama-3-8B or Mistral for formal evaluation)
    model_name = "gpt2" 
    
    # Limit number of test samples (set to None to run the full 1000 entries)
    limit = 20  
    
    batch_size = 4
    max_new_tokens = 64
    max_ctx_chars = 4000
    
    # Output directory
    output_dir = "eval_results"
    seed = 42

cfg = Config()

# Set random seed
def set_seed(seed: int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)
hf_logging.set_verbosity_error()
print(f" Config loaded. Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")

  Config loaded. Device: cuda


In [None]:
# Data Loading
print(f" Reading data from: {cfg.parquet_path}")

if not os.path.exists(cfg.parquet_path):
    raise FileNotFoundError(f"  File not found: {cfg.parquet_path}")

# 1. Read Parquet file
df = pd.read_parquet(cfg.parquet_path)

# 2. Filter out invalid data
df = df.dropna(subset=['question', 'long_answer'])
if cfg.limit:
    df = df.head(cfg.limit)

# 3. Define Context extraction logic
def extract_context_text(c):
    """
    Specifically handles the PubMedQA context structure.
    The Context column is usually a dictionary: {'contexts': ['A...', 'B...'], 'labels': [...]}
    """
    try:
        # Case A: It is a dictionary (common case)
        if isinstance(c, dict) and 'contexts' in c:
            # 'contexts' is an array or list
            return " ".join(list(c['contexts']))
        
        # Case B: It is a Numpy array (sometimes pyarrow reads it as an array)
        if isinstance(c, np.ndarray): 
            # Assuming the array contains structs, try to get the 'contexts' field.
            # This is complex, and Case A usually covers it.
            # Simple handling here: if the array contains only strings, join directly.
            if c.dtype.kind in {'U', 'S'}: 
                return " ".join(c)
        
        # Case C: It is already a list
        if isinstance(c, list):
            return " ".join(c)

        return str(c)
    except Exception:
        return ""

print(" Processing contexts...")

# Prepare Context for RAG mode (extracted from dataset)
ctx_list_rag = df["context"].apply(extract_context_text).apply(lambda x: x[:cfg.max_ctx_chars]).tolist()

# Prepare Context for Base mode (set all to empty)
ctx_list_base = [""] * len(df)

questions = df["question"].tolist()
refs = df["long_answer"].tolist()
decisions = df["final_decision"].tolist() if "final_decision" in df else []

print(f" Data loaded successfully!")
print(f"- Total Samples: {len(df)}")
print(f"- First Question: {questions[0]}")
print(f"- First Ref Answer: {refs[0][:50]}...")

In [None]:
def build_prompt(tok, q: str, ctx: str) -> str:
    """
    Automatically switches between Open-Book / Closed-Book templates based on whether ctx is empty.
    """
    has_ctx = ctx and len(ctx.strip()) > 0
    
    if has_ctx:
        # RAG Mode
        return (
            f"Question: {q}\n"
            f"Context: {ctx}\n\n"
            f"Answer the question using the context. Provide a reasoning then a final Yes/No/Maybe.\n"
            f"Answer: "
        )
    else:
        # Base Mode (No Context)
        return (
            f"Question: {q}\n\n"
            f"Answer the question based on your knowledge. Provide a reasoning then a final Yes/No/Maybe.\n"
            f"Answer: "
        )

def compute_keyword_recall(preds, refs):
    """Calculates how many keywords from the reference answer are included in the generated answer."""
    scores = []
    vectorizer = CountVectorizer(stop_words='english')
    for p, r in zip(preds, refs):
        try:
            vectorizer.fit([r])
            ref_keywords = set(vectorizer.get_feature_names_out())
            if not ref_keywords:
                scores.append(0.0); continue
            
            p_lower = p.lower()
            hit = sum(1 for kw in ref_keywords if kw in p_lower)
            scores.append(hit / len(ref_keywords))
        except:
            scores.append(0.0)
    return scores

def extract_decision(text):
    t = text.lower()
    if 'yes' in t: return 'yes'
    if 'no' in t: return 'no'
    if 'maybe' in t: return 'maybe'
    return 'unknown'

def run_inference(model, tok, qs, ctxs, cfg, label):
    print(f" Running Inference: {label} ...")
    preds = []
    
    # Ensure pad_token exists
    if tok.pad_token is None: tok.pad_token = tok.eos_token
    tok.padding_side = "left"
    
    prompts = [build_prompt(tok, q, c) for q, c in zip(qs, ctxs)]
    
    total = len(prompts)
    for i in range(0, total, cfg.batch_size):
        batch = prompts[i : i + cfg.batch_size]
        inputs = tok(batch, return_tensors="pt", padding=True, truncation=True, max_length=cfg.max_ctx_chars+256).to(model.device)
        
        with torch.no_grad():
            out = model.generate(
                **inputs, 
                max_new_tokens=cfg.max_new_tokens, 
                pad_token_id=tok.pad_token_id, 
                do_sample=False # For reproducible results
            )
        
        # Decode, taking only the generated part
        input_len = inputs.input_ids.shape[1]
        decoded = tok.batch_decode(out[:, input_len:], skip_special_tokens=True)
        preds.extend([d.strip() for d in decoded])
        
        if (i // cfg.batch_size) % 5 == 0:
            print(f"   Batch {i // cfg.batch_size + 1} done.")
            
    return preds

In [None]:
# 1. Load Model
print(f" Loading Model: {cfg.model_name}...")
try:
    tok = AutoTokenizer.from_pretrained(cfg.model_name, use_fast=True)
    model = AutoModelForCausalLM.from_pretrained(
        cfg.model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )
    model.eval()
except Exception as e:
    print(f" Error loading model: {e}")
    raise e

# 2. Execute Two Rounds of Inference
# Round 1: Base (No Context)
preds_base = run_inference(model, tok, questions, ctx_list_base, cfg, label="Base (No Context)")

# Round 2: RAG (With Context)
preds_rag = run_inference(model, tok, questions, ctx_list_rag, cfg, label="RAG (With Context)")
print("\n Inference complete!")

In [None]:
print("\nCalculating Metrics (This may take a moment)...")

# 1. ROUGE (N-gram Overlap)
rouge = evaluate.load("rouge")
r_base = rouge.compute(predictions=preds_base, references=refs)
r_rag = rouge.compute(predictions=preds_rag, references=refs)

# 2. BERTScore (Semantic Similarity) - Will download a small model on first run
bertscore = evaluate.load("bertscore")
bs_base = bertscore.compute(predictions=preds_base, references=refs, lang="en", model_type="distilbert-base-uncased")
bs_rag = bertscore.compute(predictions=preds_rag, references=refs, lang="en", model_type="distilbert-base-uncased")

# 3. Keyword Recall
kw_base = compute_keyword_recall(preds_base, refs)
kw_rag = compute_keyword_recall(preds_rag, refs)

# 4. Summary of Results
summary = {
    "ROUGE-L": {
        "Base": round(r_base['rougeL'], 4), 
        "RAG": round(r_rag['rougeL'], 4)
    },
    "BERTScore-F1": {
        "Base": round(np.mean(bs_base['f1']), 4), 
        "RAG": round(np.mean(bs_rag['f1']), 4)
    },
    "Keyword-Recall": {
        "Base": round(np.mean(kw_base), 4), 
        "RAG": round(np.mean(kw_rag), 4)
    }
}

print("\n" + "="*40)
print(" EVALUATION SUMMARY")
print("="*40)
print(json.dumps(summary, indent=2))

# 5. Save detailed results to file
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
os.makedirs(cfg.output_dir, exist_ok=True)
save_path = os.path.join(cfg.output_dir, f"result_comparison_{time_str}.jsonl")

with open(save_path, "w", encoding="utf-8") as f:
    for i in range(len(preds_base)):
        row = {
            "id": i,
            "question": questions[i],
            "reference": refs[i],
            "base_pred": preds_base[i],
            "rag_pred": preds_rag[i],
            "base_kw_recall": kw_base[i],
            "rag_kw_recall": kw_rag[i],
            "rag_context_preview": ctx_list_rag[i][:100] # Save a bit of context for inspection
        }
        f.write(json.dumps(row, ensure_ascii=False) + "\n")

print(f"\n Detailed results saved to: {save_path}")


Calculating Metrics (This may take a moment)...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`



 EVALUATION SUMMARY
{
  "ROUGE-L": {
    "Base": 0.1033,
    "RAG": 0.135
  },
  "BERTScore-F1": {
    "Base": 0.6773,
    "RAG": 0.6994
  },
  "Keyword-Recall": {
    "Base": 0.0746,
    "RAG": 0.1068
  }
}

 Detailed results saved to: eval_results\result_comparison_20251210_165022.jsonl


In [None]:
# Full Script: PubMedQA test + KB(pubmed_documents.pkl) RAG
# Metrics: Decision ACC + ROUGE-L (pure python) + BERTScore (optional)

import os
import re
import json
import pickle
import random
import numpy as np
import pandas as pd
import torch
import pyarrow.parquet as pq
from datetime import datetime
from tqdm import tqdm

from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.feature_extraction.text import TfidfVectorizer



# 0) Config

class Config:
    # ---- Eval data (TEST) ----
    parquet_path = r"/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet"

    # ---- KB docs (your "kg") ----
    kb_docs_path = r"/LLM/PrimeKG/pubmed_documents.pkl"

    # (Optional) FAISS index if you have it. If missing -> TF-IDF fallback
    kb_index_path = r"/LLM/PrimeKG/pubmed_qa.index"
    embed_model_name = "sentence-transformers/all-MiniLM-L6-v2"  # only used if FAISS exists

    # ---- Generator model ----
    model_name_or_path = "gpt2"
    batch_size = 4
    max_new_tokens = 64

    # ---- RAG params ----
    top_k_docs = 2
    max_ctx_chars = 1200

    # ---- TF-IDF fallback params ----
    tfidf_max_docs = None 
    tfidf_max_features = 200000

    # ---- Eval params ----
    limit = 100  # None -> run all
    seed = 42

    # ---- Output ----
    output_dir = "eval_results_pubmedqa_kbRAG_test"

cfg = Config()



# 1) Utilities

def set_seed(seed: int):
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(cfg.seed)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("[Info] DEVICE:", DEVICE)


def extract_decision(text: str) -> str:
    """Extract final decision label from generated text (robust regex)."""
    t = (text or "").lower()
    if re.search(r"\byes\b", t): return "yes"
    if re.search(r"\bno\b", t): return "no"
    if re.search(r"\bmaybe\b", t): return "maybe"
    return "unknown"


def accuracy(y_true, y_pred):
    if not y_true:
        return None
    correct = sum(int(a == b) for a, b in zip(y_true, y_pred))
    return correct / len(y_true)


def confusion_table(y_true, y_pred, labels=("yes", "no", "maybe", "unknown")):
    """Return a pandas crosstab confusion table."""
    return pd.crosstab(
        pd.Series(y_true, name="GT"),
        pd.Series(y_pred, name="Pred"),
        rownames=["GT"], colnames=["Pred"],
        dropna=False
    ).reindex(index=list(labels), columns=list(labels), fill_value=0)


# ---- Pure Python ROUGE-L F1 ----
def _lcs_len(a_tokens, b_tokens):
    n, m = len(a_tokens), len(b_tokens)
    dp = [0] * (m + 1)
    for i in range(1, n + 1):
        prev = 0
        for j in range(1, m + 1):
            tmp = dp[j]
            if a_tokens[i - 1] == b_tokens[j - 1]:
                dp[j] = prev + 1
            else:
                dp[j] = max(dp[j], dp[j - 1])
            prev = tmp
    return dp[m]

def rouge_l_f1(pred: str, ref: str) -> float:
    pred = (pred or "").strip()
    ref = (ref or "").strip()
    if not pred or not ref:
        return 0.0
    pred_tokens = re.findall(r"\w+", pred.lower())
    ref_tokens  = re.findall(r"\w+", ref.lower())
    if not pred_tokens or not ref_tokens:
        return 0.0

    lcs = _lcs_len(pred_tokens, ref_tokens)
    prec = lcs / len(pred_tokens)
    rec  = lcs / len(ref_tokens)
    if prec + rec == 0:
        return 0.0
    return 2 * prec * rec / (prec + rec)



# 2) Load TEST parquet
print(f"[Info] Reading test data: {cfg.parquet_path}")
if not os.path.exists(cfg.parquet_path):
    raise FileNotFoundError(f"Parquet not found: {cfg.parquet_path}")

tbl = pq.read_table(cfg.parquet_path)
df = tbl.to_pandas()

df = df.dropna(subset=["question"])
if cfg.limit:
    df = df.head(cfg.limit)

questions = df["question"].astype(str).tolist()

# Pick reference column (for ROUGE/BERTScore)
ref_col = None
for c in ["long_answer", "final_decision", "answer"]:
    if c in df.columns:
        ref_col = c
        break
if ref_col is None:
    ref_col = df.columns[-1]

refs = df[ref_col].fillna("").astype(str).tolist()

# Decision GT for ACC (best choice is final_decision)
gt_decisions = None
if "final_decision" in df.columns:
    gt_decisions = df["final_decision"].fillna("").astype(str).str.lower().tolist()

print(f"[Info] Loaded samples: {len(questions)} | ref_col={ref_col} | has_final_decision={gt_decisions is not None}")



# 3) Load KB docs (pubmed_documents.pkl)
print(f"[Info] Loading KB docs: {cfg.kb_docs_path}")
if not os.path.exists(cfg.kb_docs_path):
    raise FileNotFoundError(f"KB docs not found: {cfg.kb_docs_path}")

with open(cfg.kb_docs_path, "rb") as f:
    kb_docs = pickle.load(f)

if cfg.tfidf_max_docs is not None:
    kb_docs = kb_docs[:cfg.tfidf_max_docs]

kb_docs = [("" if d is None else str(d)) for d in kb_docs]
print(f"[Info] KB docs loaded: {len(kb_docs)}")



# 4) Build Retriever (FAISS preferred else TF-IDF)
use_faiss = False
faiss_index = None
embed_model = None

tfidf_vectorizer = None
tfidf_X = None

def build_tfidf_retriever(docs):
    print("[Info] Building TF-IDF retriever (fallback)...")
    vectorizer = TfidfVectorizer(
        stop_words="english",
        max_features=cfg.tfidf_max_features,
        ngram_range=(1, 2),
    )
    X = vectorizer.fit_transform(docs)
    print("[Info] TF-IDF ready.")
    return vectorizer, X

try:
    if os.path.exists(cfg.kb_index_path):
        import faiss
        from sentence_transformers import SentenceTransformer

        print(f"[Info] Found FAISS index: {cfg.kb_index_path}")
        faiss_index = faiss.read_index(cfg.kb_index_path)
        embed_model = SentenceTransformer(cfg.embed_model_name)
        use_faiss = True
        print("[Info] Using FAISS retriever.")
    else:
        print("[Info] FAISS index not found -> fallback TF-IDF.")
except Exception as e:
    print(f"[Warn] FAISS init failed -> fallback TF-IDF. Error: {e}")

if not use_faiss:
    tfidf_vectorizer, tfidf_X = build_tfidf_retriever(kb_docs)

def retrieve_docs(query: str, top_k: int):
    query = (query or "").strip()
    if not query:
        return []

    if use_faiss:
        import faiss
        q_emb = embed_model.encode([query], convert_to_numpy=True)
        faiss.normalize_L2(q_emb)
        scores, idxs = faiss_index.search(q_emb, top_k)
        idxs = idxs[0].tolist()
        return [kb_docs[i] for i in idxs if 0 <= i < len(kb_docs)]
    else:
        qv = tfidf_vectorizer.transform([query])
        sims = (tfidf_X @ qv.T).toarray().ravel()
        if top_k >= len(sims):
            top_idx = np.argsort(-sims)
        else:
            top_idx = np.argpartition(-sims, top_k)[:top_k]
            top_idx = top_idx[np.argsort(-sims[top_idx])]
        return [kb_docs[i] for i in top_idx.tolist()]

def build_rag_context(question: str):
    docs = retrieve_docs(question, cfg.top_k_docs)
    ctx = "\n\n".join([d.strip() for d in docs if d and d.strip()])
    return ctx[:cfg.max_ctx_chars] if ctx else ""

ctx_list_rag = [build_rag_context(q) for q in questions]
ctx_list_base = [""] * len(questions)

print("[Info] ctx example head:", (ctx_list_rag[0][:200] + "...") if ctx_list_rag and ctx_list_rag[0] else "<EMPTY>")



# 5) Load Generator model
print(f"[Info] Loading generator: {cfg.model_name_or_path}")
tok = AutoTokenizer.from_pretrained(cfg.model_name_or_path, use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
tok.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    cfg.model_name_or_path,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    device_map="auto"
)
model.eval()
print("[Info] model.device:", model.device)



# 6) Prompt + Inference
def build_prompt(q: str, ctx: str) -> str:
    has_ctx = bool(ctx and ctx.strip())
    if has_ctx:
        return (
            f"Question: {q}\n\n"
            f"Retrieved Documents:\n{ctx}\n\n"
            "Answer the question using ONLY the retrieved documents. "
            "Provide brief reasoning then a final Yes/No/Maybe.\n"
            "Answer: "
        )
    else:
        return (
            f"Question: {q}\n\n"
            "Answer the question based on your knowledge. "
            "Provide brief reasoning then a final Yes/No/Maybe.\n"
            "Answer: "
        )

def run_inference(qs, ctxs, label: str):
    print(f"[Info] Running inference: {label}")
    preds = []
    prompts = [build_prompt(q, c) for q, c in zip(qs, ctxs)]

    for i in range(0, len(prompts), cfg.batch_size):
        batch_prompts = prompts[i:i + cfg.batch_size]
        inputs = tok(
            batch_prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=cfg.max_ctx_chars + 256
        ).to(model.device)

        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=cfg.max_new_tokens,
                pad_token_id=tok.pad_token_id,
                do_sample=False
            )

        input_len = inputs["input_ids"].shape[1]
        decoded = tok.batch_decode(out[:, input_len:], skip_special_tokens=True)
        preds.extend([d.strip() for d in decoded])

        if (i // cfg.batch_size) % 10 == 0:
            print(f"  batch {i // cfg.batch_size + 1} done")

    return preds

preds_base = run_inference(questions, ctx_list_base, "Base (No Context)")
preds_rag  = run_inference(questions, ctx_list_rag,  "RAG (KB Retrieved Docs)")



# Decision by logprob (robust)
@torch.no_grad()
def decision_by_logprob(prompt: str, tok, model, device: str):
    """
    Return best label in {yes,no,maybe} by scoring prompt + candidate.
    Uses negative NLL on candidate tokens only.
    """
    cand_map = {"yes": " yes", "no": " no", "maybe": " maybe"}

    prompt_ids = tok(prompt, add_special_tokens=False).input_ids
    prompt_len = len(prompt_ids)

    scores = {}
    for lab, suffix in cand_map.items():
        ids = tok(prompt + suffix, return_tensors="pt", add_special_tokens=False).input_ids.to(device)
        labels = ids.clone()
        labels[:, :prompt_len] = -100  # mask prompt tokens, only score suffix tokens

        out = model(input_ids=ids, labels=labels, use_cache=False)
        cand_len = ids.shape[1] - prompt_len
        # out.loss is mean NLL over scored tokens -> convert to total logprob
        logp = -float(out.loss.item()) * float(cand_len)
        scores[lab] = logp

    best = max(scores, key=scores.get)
    return best, scores


def extract_decision_last(text: str) -> str:
    """Extract last occurrence of decision label from text using regex."""
    matches = re.findall(r"\b(yes|no|maybe)\b", (text or "").lower())
    return matches[-1] if matches else "unknown"


def get_decision(prompt: str, gen_text: str, tok, model, device: str) -> str:
    """
    Get final decision from generated text; if unknown, use logprob scoring.
    """
    d = extract_decision_last(gen_text)
    if d != "unknown":
        return d
    best, _ = decision_by_logprob(prompt, tok, model, device)
    return best



# 7) Metrics: ACC + ROUGE-L + (optional) BERTScore
#Decision ACC
acc_base = acc_rag = None
conf_rag = None
prompts_base = [build_prompt(q, c) for q, c in zip(questions, ctx_list_base)]
prompts_rag  = [build_prompt(q, c) for q, c in zip(questions, ctx_list_rag)]

pred_dec_base = [
    get_decision(prompts_base[i], preds_base[i], tok, model, model.device)
    for i in range(len(preds_base))
]
pred_dec_rag = [
    get_decision(prompts_rag[i], preds_rag[i], tok, model, model.device)
    for i in range(len(preds_rag))
]


if gt_decisions is not None:
    # 只保留 GT ∈ {yes,no,maybe}
    valid_idx = [i for i, g in enumerate(gt_decisions) if g in ("yes", "no", "maybe")]
    y_true = [gt_decisions[i] for i in valid_idx]
    yb = [pred_dec_base[i] for i in valid_idx]
    yr = [pred_dec_rag[i] for i in valid_idx]

    acc_base = accuracy(y_true, yb)
    acc_rag  = accuracy(y_true, yr)
    conf_rag = confusion_table(y_true, yr, labels=("yes", "no", "maybe", "unknown"))

print("\n" + "=" * 60)
print("Decision ACC (Yes/No/Maybe)")
if acc_base is None:
    print("[Warn] final_decision not found in parquet -> ACC skipped.")
else:
    print(f"Base ACC: {acc_base:.4f}")
    print(f"RAG  ACC: {acc_rag:.4f}")
    print(f"Gain   : {acc_rag - acc_base:+.4f}")
    print("\nConfusion Matrix (RAG):")
    print(conf_rag)
print("=" * 60)

# ROUGE-L (pure python)
rougeL_base_list = [rouge_l_f1(p, r) for p, r in zip(preds_base, refs)]
rougeL_rag_list  = [rouge_l_f1(p, r) for p, r in zip(preds_rag, refs)]
rougeL_base = float(np.mean(rougeL_base_list))
rougeL_rag  = float(np.mean(rougeL_rag_list))

print("\nROUGE-L (pure python):")
print(f"Base: {rougeL_base:.4f} | RAG: {rougeL_rag:.4f} | Gain: {rougeL_rag - rougeL_base:+.4f}")

# BERTScore (optional; may require downloads/extra deps)
bs_base = bs_rag = None
try:
    import evaluate
    bertscore = evaluate.load("bertscore")
    bs_res_base = bertscore.compute(predictions=preds_base, references=refs, lang="en", model_type="distilbert-base-uncased")
    bs_res_rag  = bertscore.compute(predictions=preds_rag,  references=refs, lang="en", model_type="distilbert-base-uncased")
    bs_base = float(np.mean(bs_res_base["f1"]))
    bs_rag  = float(np.mean(bs_res_rag["f1"]))
    print("\nBERTScore-F1:")
    print(f"Base: {bs_base:.4f} | RAG: {bs_rag:.4f} | Gain: {bs_rag - bs_base:+.4f}")
except Exception as e:
    print("\n[Warn] BERTScore skipped (dependency/download issue). Error:", str(e)[:200])



# 8) Save results

summary = {
    "meta": {
        "parquet_path": cfg.parquet_path,
        "kb_docs_path": cfg.kb_docs_path,
        "kb_index_path": cfg.kb_index_path if os.path.exists(cfg.kb_index_path) else None,
        "retriever": "faiss" if use_faiss else "tfidf",
        "model": cfg.model_name_or_path,
        "n": len(questions),
        "ref_col": ref_col,
    },
    "decision_acc": {
        "base": None if acc_base is None else round(float(acc_base), 4),
        "rag":  None if acc_rag  is None else round(float(acc_rag), 4),
    },
    "rougeL": {
        "base": round(rougeL_base, 4),
        "rag":  round(rougeL_rag, 4),
    },
    "bertscore_f1": {
        "base": None if bs_base is None else round(bs_base, 4),
        "rag":  None if bs_rag  is None else round(bs_rag, 4),
    }
}

os.makedirs(cfg.output_dir, exist_ok=True)
time_str = datetime.now().strftime("%Y%m%d_%H%M%S")
jsonl_path = os.path.join(cfg.output_dir, f"kbRAG_test_{time_str}.jsonl")
csv_path   = os.path.join(cfg.output_dir, f"kbRAG_test_{time_str}.csv")
summary_path = os.path.join(cfg.output_dir, f"kbRAG_test_{time_str}_summary.json")

# Save per-sample
rows = []
with open(jsonl_path, "w", encoding="utf-8") as f:
    for i in range(len(questions)):
        row = {
            "id": i,
            "question": questions[i],
            "reference": refs[i],
            "gt_decision": None if gt_decisions is None else gt_decisions[i],
            "base_pred": preds_base[i],
            "rag_pred": preds_rag[i],
            "base_decision": pred_dec_base[i],
            "rag_decision": pred_dec_rag[i],
            "rougeL_base": rougeL_base_list[i],
            "rougeL_rag":  rougeL_rag_list[i],
            "rag_context_preview": (ctx_list_rag[i][:300] if ctx_list_rag[i] else ""),
        }
        f.write(json.dumps(row, ensure_ascii=False) + "\n")
        rows.append(row)

# Save CSV
pd.DataFrame(rows).to_csv(csv_path, index=False)

# Save summary JSON
with open(summary_path, "w", encoding="utf-8") as f:
    json.dump(summary, f, ensure_ascii=False, indent=2)

print("\n" + "=" * 60)
print("[Done] Saved:")
print("  JSONL  :", jsonl_path)
print("  CSV    :", csv_path)
print("  SUMMARY:", summary_path)
print("=" * 60)
print(json.dumps(summary, indent=2, ensure_ascii=False))


[Info] DEVICE: cuda
[Info] Reading test data: /LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet
[Info] Loaded samples: 100 | ref_col=long_answer | has_final_decision=True
[Info] Loading KB docs: /LLM/PrimeKG/pubmed_documents.pkl
[Info] KB docs loaded: 800
[Info] Found FAISS index: /LLM/PrimeKG/pubmed_qa.index
[Info] Using FAISS retriever.
[Info] ctx example head: The etiology of hemodialysis (HD)-induced hypotension and hypertension remains speculative. There is mounting evidence that endothelin-1 (ET-1) may play a vital role in these hemodynamic changes. We e...
[Info] Loading generator: gpt2
[Info] model.device: cuda:0
[Info] Running inference: Base (No Context)
  batch 1 done
  batch 11 done
  batch 21 done
[Info] Running inference: RAG (KB Retrieved Docs)
  batch 1 done
  batch 11 done
  batch 21 done

Decision ACC (Yes/No/Maybe)
Base ACC: 0.1000
RAG  ACC: 0.2800
Gain   : +0.1800

Confusion Matrix (RAG):
Pred     yes  no  maybe  unknown
GT                              
yes     

Using the latest cached version of the module from /home/miaoen/.cache/huggingface/modules/evaluate_modules/metrics/evaluate-metric--bertscore/cf4907b18f8f741f202232c0f8009a3bd49ff98802c245abcb6ea51a37a8c05b (last modified on Wed Dec 17 16:39:27 2025) since it couldn't be found locally at evaluate-metric--bertscore, or remotely on the Hugging Face Hub.



BERTScore-F1:
Base: 0.7041 | RAG: 0.6744 | Gain: -0.0297

[Done] Saved:
  JSONL  : eval_results_pubmedqa_kbRAG_test/kbRAG_test_20251222_225401.jsonl
  CSV    : eval_results_pubmedqa_kbRAG_test/kbRAG_test_20251222_225401.csv
  SUMMARY: eval_results_pubmedqa_kbRAG_test/kbRAG_test_20251222_225401_summary.json
{
  "meta": {
    "parquet_path": "/LLM/data/pubmedqa_hf/pqa_labeled_splits/test.parquet",
    "kb_docs_path": "/LLM/PrimeKG/pubmed_documents.pkl",
    "kb_index_path": "/LLM/PrimeKG/pubmed_qa.index",
    "retriever": "faiss",
    "model": "gpt2",
    "n": 100,
    "ref_col": "long_answer"
  },
  "decision_acc": {
    "base": 0.1,
    "rag": 0.28
  },
  "rougeL": {
    "base": 0.126,
    "rag": 0.1108
  },
  "bertscore_f1": {
    "base": 0.7041,
    "rag": 0.6744
  }
}


In [None]:
import numpy as np
import torch

LABELS = ["yes", "no", "maybe"]

def _score_completion(tok, model, prompt_text: str, completion_text: str, max_len=1024) -> float:
    """
    Compute log-probability score of completion_text given prompt_text.
    """
    full = prompt_text + " " + completion_text
    enc_full = tok(full, return_tensors="pt", truncation=True, max_length=max_len).to(model.device)
    enc_prompt = tok(prompt_text, return_tensors="pt", truncation=True, max_length=max_len).to(model.device)

    input_ids = enc_full["input_ids"]
    attn = enc_full["attention_mask"]

    # labels: mask prompt tokens only score completion tokens
    labels = input_ids.clone()
    prompt_len = enc_prompt["input_ids"].shape[1]
    if prompt_len >= input_ids.shape[1]:
        return -1e9
    labels[:, :prompt_len] = -100

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attn, labels=labels)
        loss = out.loss
        if torch.isnan(loss):
            return -1e9
        return float(-loss.item())

def build_decision_prompt(question: str, ctx: str = "") -> str:
    """
    Build prompt for decision scoring.
    """
    if ctx and ctx.strip():
        return (
            f"Question: {question}\n\n"
            f"Retrieved Documents:\n{ctx}\n\n"
            "Decide the final answer. Respond with ONE WORD only: yes / no / maybe.\n"
            "Final answer:"
        )
    else:
        return (
            f"Question: {question}\n\n"
            "Decide the final answer. Respond with ONE WORD only: yes / no / maybe.\n"
            "Final answer:"
        )

def predict_decision_by_scoring(tok, model, question: str, ctx: str = ""):
    prompt = build_decision_prompt(question, ctx)
    scores = {lab: _score_completion(tok, model, prompt, lab) for lab in LABELS}
    pred = max(scores, key=scores.get)
    return pred, scores


In [None]:
# Decision prediction by scoring (Base / RAG) 
pred_dec_base = []
pred_dec_rag = []
scores_dec_base = []
scores_dec_rag = []

for q, ctx in zip(questions, ctx_list_rag):
    p_b, s_b = predict_decision_by_scoring(tok, model, q, ctx="")
    p_r, s_r = predict_decision_by_scoring(tok, model, q, ctx=ctx)
    pred_dec_base.append(p_b); scores_dec_base.append(s_b)
    pred_dec_rag.append(p_r);  scores_dec_rag.append(s_r)

# ACC
valid_idx = [i for i,g in enumerate(gt_decisions) if g in ("yes","no","maybe")]
y_true = [gt_decisions[i] for i in valid_idx]
y_base = [pred_dec_base[i] for i in valid_idx]
y_rag  = [pred_dec_rag[i]  for i in valid_idx]

acc_base = sum(a==b for a,b in zip(y_true,y_base)) / len(y_true)
acc_rag  = sum(a==b for a,b in zip(y_true,y_rag )) / len(y_true)

print("Decision ACC (scoring):")
print("Base:", acc_base)
print("RAG :", acc_rag)


In [None]:
print("[Check] len(kb_docs) =", len(kb_docs))
try:
    import faiss
    print("[Check] faiss_index.ntotal =", faiss_index.ntotal)
    if faiss_index.ntotal != len(kb_docs):
        print("[FATAL] FAISS index and docs length mismatch! 。")
except Exception as e:
    print("[Check] faiss not available:", e)


[Check] len(kb_docs) = 800
[Check] faiss_index.ntotal = 800


In [None]:
import os
import numpy as np
import pandas as pd
import torch

# 0) Your input CSV (currently missing bertscore columns)
csv_path = r"eval_results_pubmedqa_kbRAG_test/kbRAG_test_20251222_175710.csv"
assert os.path.exists(csv_path), f"CSV not found: {csv_path}"

df = pd.read_csv(csv_path)
print("[Info] Loaded:", csv_path, "| rows:", len(df))
print("[Info] Columns:", list(df.columns))

# 1) Prepare prediction/reference text
# Prevent null values
preds_base = df["base_pred"].fillna("").astype(str).tolist()
preds_rag  = df["rag_pred"].fillna("").astype(str).tolist()
refs       = df["reference"].fillna("").astype(str).tolist()

# 2) Load BERTScore (evaluate)
# You previously reported rouge needed absl, bertscore might also depend on some packages here
# If ModuleNotFoundError: absl occurs again, install it: pip install absl-py
import evaluate

device = "cuda" if torch.cuda.is_available() else "cpu"
print("[Info] BERTScore device:", device)

bertscore = evaluate.load("bertscore")

# Note:
# - roberta-large is the most accurate but large; distilbert-base-uncased is lighter and faster
# - Your previous logs used the default (likely roberta-large), here I'm giving you the distilbert version, stable and saves VRAM
model_type = "distilbert-base-uncased"

print("[Info] Computing BERTScore for BASE ...")
res_base = bertscore.compute(
    predictions=preds_base,
    references=refs,
    lang="en",
    model_type=model_type,
    device=device,
    batch_size=16,   # Change to 8 if VRAM is insufficient
)
f1_base = np.array(res_base["f1"], dtype=float)

print("[Info] Computing BERTScore for RAG ...")
res_rag = bertscore.compute(
    predictions=preds_rag,
    references=refs,
    lang="en",
    model_type=model_type,
    device=device,
    batch_size=16,
)
f1_rag = np.array(res_rag["f1"], dtype=float)

# Write back to DF
df["bert_f1_base"] = f1_base
df["bert_f1_rag"]  = f1_rag
df["bert_gain"]    = df["bert_f1_rag"] - df["bert_f1_base"]

print("\n[Info] BERTScore summary:")
print("  base mean:", float(df["bert_f1_base"].mean()))
print("  rag  mean:", float(df["bert_f1_rag"].mean()))
print("  gain mean:", float(df["bert_gain"].mean()))

# 3) Export Top3 / Bottom3 (By RAG BERTScore) ======
top3_rag = df.sort_values("bert_f1_rag", ascending=False).head(3)
bot3_rag = df.sort_values("bert_f1_rag", ascending=True).head(3)

# Columns you want to save (can add more yourself)
keep_cols = [
    "id","question","gt_decision",
    "bert_f1_base","bert_f1_rag","bert_gain",
    "rougeL_base","rougeL_rag",
    "base_pred","rag_pred",
    "rag_context_preview"
]
keep_cols = [c for c in keep_cols if c in df.columns]

print("\n[Top3] Highest RAG BERTScore:")
print(top3_rag[["id","bert_f1_rag","bert_gain","question"]].to_string(index=False))

print("\n[Bottom3] Lowest RAG BERTScore:")
print(bot3_rag[["id","bert_f1_rag","bert_gain","question"]].to_string(index=False))

# 4) Optional: Export gain Top3 / Bottom3 (RAG-Base) ======
top3_gain = df.sort_values("bert_gain", ascending=False).head(3)
bot3_gain = df.sort_values("bert_gain", ascending=True).head(3)

print("\n[Top3] Highest BERT Gain (RAG-Base):")
print(top3_gain[["id","bert_gain","bert_f1_base","bert_f1_rag","question"]].to_string(index=False))

print("\n[Bottom3] Lowest BERT Gain (RAG-Base):")
print(bot3_gain[["id","bert_gain","bert_f1_base","bert_f1_rag","question"]].to_string(index=False))

# 5) Save files
out_dir = os.path.dirname(csv_path) or "."
base_name = os.path.splitext(os.path.basename(csv_path))[0]

csv_with_bert = os.path.join(out_dir, f"{base_name}_with_bertscore.csv")
top3_path     = os.path.join(out_dir, f"{base_name}_top3_bertscore_rag.csv")
bot3_path     = os.path.join(out_dir, f"{base_name}_bottom3_bertscore_rag.csv")
top3g_path    = os.path.join(out_dir, f"{base_name}_top3_bertscore_gain.csv")
bot3g_path    = os.path.join(out_dir, f"{base_name}_bottom3_bertscore_gain.csv")

df.to_csv(csv_with_bert, index=False, encoding="utf-8-sig")
top3_rag[keep_cols].to_csv(top3_path, index=False, encoding="utf-8-sig")
bot3_rag[keep_cols].to_csv(bot3_path, index=False, encoding="utf-8-sig")
top3_gain[keep_cols].to_csv(top3g_path, index=False, encoding="utf-8-sig")
bot3_gain[keep_cols].to_csv(bot3g_path, index=False, encoding="utf-8-sig")

print("\n[Saved]")
print("  Full (with bertscore):", csv_with_bert)
print("  Top3 rag:", top3_path)
print("  Bot3 rag:", bot3_path)
print("  Top3 gain:", top3g_path)
print("  Bot3 gain:", bot3g_path)

[Info] Loaded: eval_results_pubmedqa_kbRAG_test/kbRAG_test_20251222_175710.csv | rows: 20
[Info] Columns: ['id', 'question', 'reference', 'gt_decision', 'base_pred', 'rag_pred', 'base_decision', 'rag_decision', 'rougeL_base', 'rougeL_rag', 'rag_context_preview']
[Info] BERTScore device: cuda




[Info] Computing BERTScore for BASE ...
[Info] Computing BERTScore for RAG ...

[Info] BERTScore summary:
  base mean: 0.7157986968755722
  rag  mean: 0.6796941459178925
  gain mean: -0.03610455095767975

[Top3] Highest RAG BERTScore:
 id  bert_f1_rag  bert_gain                                                                                                      question
  4     0.763997   0.064787 Body perception: do parents, their children, and their children's physicians perceive body image differently?
 11     0.755544   0.009216                     Does laparoscopic surgery decrease the risk of atrial fibrillation after foregut surgery?
 14     0.748675   0.035051                                                       Did Chile's traffic law reform push police enforcement?

[Bottom3] Lowest RAG BERTScore:
 id  bert_f1_rag  bert_gain                                                                                                                                                         

: 