In [None]:

# 0. IMPORTS & CONFIG

import json
import os
import re 
import pandas as pd
import networkx as nx
import torch
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU when torch reports CUDA as available; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Paths: adjust to your environment ----------
# GPT2_PATH is passed to from_pretrained(): a local checkpoint directory,
# or a model name such as "gpt2", "gpt2-medium", "gpt2-xl".
GPT2_PATH     = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/gpt2-medmcqa-raft-masked"  
PRIMEKG_PATH  = "kg.csv" 
MEDMCQA_FILE  = r"/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/medmcqa/dev.json"  # Read line by line as JSON records; the raw-string prefix only matters for backslash (Windows) paths

# ---------- RAG Parameters ----------
MAX_RAG_ENTITIES = 3       # Maximum number of KG entities extracted per question
MAX_K_EDGES      = 3       # Maximum number of edges rendered per entity (kept small to save tokens)
MAX_CTX_CHARS    = 600     # Character budget for the concatenated KG context

print(f"Config loaded. Running on {DEVICE}")

Config loaded. Running on cuda


In [None]:

# Cell 2: Load GPT-2 Model

def load_gpt2(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    The tokenizer's pad token is set to its EOS token when no pad token is
    configured. The model is moved to ``DEVICE`` and switched to eval mode.

    Args:
        model_path: Value forwarded to both ``from_pretrained`` calls.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    print(f"Loading GPT-2 from: {model_path}")

    tok = AutoTokenizer.from_pretrained(model_path)
    missing_pad_token = tok.pad_token is None
    if missing_pad_token:
        # Reuse EOS for padding so padded calls do not fail.
        tok.pad_token = tok.eos_token

    lm = AutoModelForCausalLM.from_pretrained(model_path)
    lm.to(DEVICE)
    lm.eval()

    print("GPT-2 loaded.")
    return tok, lm

# Module-level tokenizer/model pair shared by the generation and scoring helpers below.
gpt2_tokenizer, gpt2_model = load_gpt2(GPT2_PATH)

def gpt2_generate(prompt, max_new_tokens=10):
    """Greedy-decode with the module-level GPT-2 model.

    The prompt is tokenized with truncation to at most 1000 tokens, and EOS is
    used both as the pad token and as the stop token.

    Args:
        prompt: Text to condition on.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        str: The first output sequence decoded with special tokens skipped.
    """
    encoded = gpt2_tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=1000,
    ).to(DEVICE)

    eos_id = gpt2_tokenizer.eos_token_id
    generation_args = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding
        pad_token_id=eos_id,
        eos_token_id=eos_id,
    )

    with torch.no_grad():
        sequences = gpt2_model.generate(**encoded, **generation_args)

    return gpt2_tokenizer.decode(sequences[0], skip_special_tokens=True)

Loading GPT-2 from: /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/gpt2-medmcqa-raft-masked
GPT-2 loaded.


In [None]:

# Cell 3: Load Knowledge Graph (PrimeKG)

def load_kg(path):
    """Build an undirected graph from the knowledge-graph CSV at ``path``.

    Edges connect the ``x_name`` and ``y_name`` columns; every other column is
    attached as an edge attribute. When the file does not exist a warning is
    printed and an empty graph is returned instead.

    Returns:
        networkx.Graph: The loaded (or empty) graph.
    """
    print(f"Loading PrimeKG from {path} ...")

    if os.path.exists(path):
        edge_table = pd.read_csv(path, low_memory=False)
        graph = nx.from_pandas_edgelist(edge_table, "x_name", "y_name", edge_attr=True)
        print(f"KG Loaded: {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges")
        return graph

    print("Warning: KG file not found. Creating empty graph.")
    return nx.Graph()

# Module-level knowledge graph used by the retrieval helpers below.
G = load_kg(PRIMEKG_PATH)

# Case-insensitive lookup table: lower-cased node name -> original node object.
# Nodes whose names differ only by case collapse onto the one iterated last.
node_index = {str(n).lower(): n for n in G.nodes()}

def resolve_node(name):
    """Map ``name`` to a graph node via the case-insensitive ``node_index``.

    Returns:
        The matching node, or ``None`` when ``name`` is falsy or unknown.
    """
    if not name:
        return None
    return node_index.get(name.lower())

def get_knowledge_context_single(entity, max_edges=MAX_K_EDGES):
    """Render up to ``max_edges`` graph edges of ``entity`` as plain sentences.

    Args:
        entity: Entity name; resolved to a graph node with ``resolve_node``.
        max_edges: Number of edges to keep after sorting.

    Returns:
        str | None: Sentences of the form ``"<node> is <relation> <neighbor>."``
        joined by single spaces, or ``None`` when the entity cannot be
        resolved, the edge lookup fails, or the node has no edges.
    """
    node = resolve_node(entity)
    if node is None:
        return None

    try:
        edges = list(G.edges(node, data=True))
    except Exception:
        # Best-effort lookup: a failing query yields "no context". Unlike a
        # bare ``except:``, this lets KeyboardInterrupt/SystemExit propagate,
        # so a long evaluation run can still be interrupted here.
        return None

    if not edges:
        return None

    # Edges are ordered by their ``relation`` attribute in descending
    # (reverse-alphabetical) order; this is a deterministic selection rule,
    # not a relevance ranking.
    sorted_edges = sorted(edges, key=lambda edge: edge[2].get('relation', ''), reverse=True)

    lines = []
    for u, v, attr in sorted_edges[:max_edges]:
        # Prefer the human-readable relation label when the edge carries one.
        rel = attr.get("display_relation", attr.get("relation", "related to"))
        # The graph is undirected, so the queried node may sit on either end.
        nbr = v if node == u else u
        lines.append(f"{node} is {rel} {nbr}.")

    return " ".join(lines)

Loading PrimeKG from kg.csv ...
KG Loaded: 129262 nodes, 4049405 edges


In [None]:

# Cell 4: RAG Retrieval Logic (Includes optimized entity extraction + context construction)

import string
import re

# --- 1. Preprocessing: Build fast lookup index ---
# Set of candidate node names for constant-time membership tests.
# Built once at module level so it is not recomputed per question.
# Only keys of node_index with at least 4 characters are kept.
valid_nodes_set = set(n.lower() for n in node_index.keys() if len(n) >= 4)

def extract_rag_entities(question, max_entities=MAX_RAG_ENTITIES):
    """Find knowledge-graph entities mentioned in ``question``.

    The question is lower-cased, stripped of ASCII punctuation and split on
    whitespace. Every 1-, 2- and 3-gram of the resulting tokens is intersected
    with ``valid_nodes_set``. Matches are ranked longest first, with
    alphabetical order as a tie-break so that the selection does not depend on
    set iteration order (which varies between interpreter runs because of
    string hash randomisation).

    Args:
        question: Question text; ``None`` or an empty string yields ``[]``.
        max_entities: Stop once this many entities have been collected.

    Returns:
        list: Graph nodes (values of ``node_index``) for the selected matches.
    """
    if not question:
        return []

    punctuation_stripper = str.maketrans('', '', string.punctuation)
    tokens = question.lower().translate(punctuation_stripper).split()

    # All contiguous n-grams for n = 1..3; ranges are empty when the question
    # has fewer than n tokens.
    candidates_in_q = {
        " ".join(tokens[start:start + n])
        for n in (1, 2, 3)
        for start in range(len(tokens) - n + 1)
    }

    # Longest match first; ties broken alphabetically for reproducibility.
    matched = sorted(
        candidates_in_q & valid_nodes_set,
        key=lambda name: (-len(name), name),
    )

    # Convert matched names back to graph nodes.
    ents = []
    for name in matched:
        node_id = node_index.get(name)
        if node_id:
            ents.append(node_id)
            if len(ents) >= max_entities:
                break
    return ents

def get_knowledge_context_multi(question):
    """Build a knowledge-graph context string for ``question``.

    Entities come from ``extract_rag_entities``; each one is expanded with
    ``get_knowledge_context_single`` into a ``"Fact: ... "`` block. Blocks are
    accumulated in entity order and accumulation stops at the first block that
    would push the total past ``MAX_CTX_CHARS`` characters.

    Returns:
        tuple: ``(entities, context)``. ``([], "")`` when no entity is found;
        ``context`` is ``""`` when no entity produced any facts.
    """
    entities = extract_rag_entities(question)
    if not entities:
        return [], ""

    fact_blocks = []
    used_chars = 0

    for entity in entities:
        sentences = get_knowledge_context_single(entity)
        if sentences:
            fact = f"Fact: {sentences} "
            # Enforce the character budget before accepting the block.
            if used_chars + len(fact) > MAX_CTX_CHARS:
                break
            fact_blocks.append(fact)
            used_chars += len(fact)

    return entities, "".join(fact_blocks)

# Marker confirming that the retrieval helpers in this cell have been (re)defined.
print("Cell 4 (RAG Logic) updated successfully.")

Cell 4 (RAG Logic) updated successfully.


In [None]:

# Cell 5: Load Data & Answer Parsing Tools

# Replace the load_medmcqa_example function in Cell 5
def load_medmcqa_example(idx, file_path=MEDMCQA_FILE):
    """Load the ``idx``-th (0-based) record of a MedMCQA JSON-lines file.

    Blank lines are ignored and do not consume an index, so a stray empty
    line no longer makes the record at that position unreachable.

    Args:
        idx: Zero-based position among the non-blank lines of the file.
        file_path: Path to the JSON-lines file.

    Returns:
        dict with keys
        ``question_text`` (question followed by the formatted options),
        ``pure_question`` (question only),
        ``options`` (mapping of label "A".."D" to stripped option text; options
        that are missing or empty are omitted) and
        ``answer`` (raw value of the ``cop`` field, or ``None``).

    Raises:
        FileNotFoundError: ``file_path`` does not exist.
        IndexError: The file has no record at position ``idx``.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    label_by_key = {"opa": "A", "opb": "B", "opc": "C", "opd": "D"}

    with open(file_path, "r", encoding="utf-8") as f:
        non_blank_lines = (stripped for stripped in map(str.strip, f) if stripped)
        for record_id, line in enumerate(non_blank_lines):
            if record_id < idx:
                continue
            if record_id > idx:
                # Only reachable for a negative idx; nothing can match.
                break

            data = json.loads(line)
            q = data.get("question", "")

            # Structured options (for scoring) plus a printable rendering.
            options_dict = {}
            options_text_list = []
            for key, lab in label_by_key.items():
                if data.get(key):
                    opt_val = str(data[key]).strip()
                    options_dict[lab] = opt_val
                    options_text_list.append(f"({lab}) {opt_val}")

            formatted_q = f"{q}\nOptions:\n" + "\n".join(options_text_list)

            return {
                "question_text": formatted_q,
                "pure_question": q,
                "options": options_dict,
                "answer": data.get("cop")
            }
    raise IndexError("Index out of range")

# The parsing function `parse_answer` is no longer needed in scoring mode; it can be kept or deleted.

def parse_answer(text):
    """Extract an option letter (A/B/C/D) from free-form generated text.

    Matching is case-insensitive (the text is upper-cased first) and proceeds
    in three steps:

    1. A keyword (``ANSWER``, ``OPTION`` or ``CHOICE``) followed by optional
       ``:``, whitespace or ``-`` and then a *standalone* letter. The word
       boundaries stop "Answer: Because ..." from parsing as B and
       "Optional" from parsing as A.
    2. A letter enclosed in parentheses, e.g. ``(C)``.
    3. The last standalone A/B/C/D anywhere in the text.

    Args:
        text: Generated text; ``None`` or an empty string is treated as
            unparseable.

    Returns:
        str: ``"A"``, ``"B"``, ``"C"`` or ``"D"``; the sentinel ``"Random"``
        when nothing can be parsed.
    """
    if not text:
        return "Random"

    text = text.upper()

    # 1. Explicit formats such as "Answer: A" or "Option A".
    match = re.search(r'\b(?:ANSWER|OPTION|CHOICE)[:\s\-]*([ABCD])\b', text)
    if match:
        return match.group(1)

    # 2. A letter in parentheses, e.g. "(A)".
    match = re.search(r'\(([ABCD])\)', text)
    if match:
        return match.group(1)

    # 3. Last resort: the final standalone letter (never a letter inside a word).
    candidates = re.findall(r'\b([ABCD])\b', text)
    if candidates:
        return candidates[-1]

    return "Random"  # Could not parse anything

In [None]:

# Cell 6: Inference Functions (Fixed Version - Only calculates Loss for the option part)

import torch
import torch.nn.functional as F

def get_lm_score(context, question, option_text):
    """Score ``option_text`` as the continuation of a question prompt.

    The prompt is ``"Question: ...\\nAnswer:"``, prefixed with a
    ``"Context: ..."`` line when ``context`` is non-empty. The option is
    appended after a single space. Labels covering the prompt tokens are set
    to -100 so that the loss is computed over the option tokens only.

    Args:
        context: Optional retrieved context; falsy values omit the context line.
        question: Question text.
        option_text: Candidate answer text.

    Returns:
        float: Negative loss (higher is better), or ``-inf`` when the loss is
        NaN (for example when every label ends up masked).
    """
    # 1. Prompt and full text; the option follows "Answer:" after one space.
    context_line = f"Context: {context}\n" if context else ""
    prompt = f"{context_line}Question: {question}\nAnswer:"
    scored_text = f"{prompt} {option_text}"

    # 2. Tokenize the prompt on its own to learn how many positions to mask.
    prompt_ids = gpt2_tokenizer(prompt, return_tensors="pt", add_special_tokens=False).input_ids
    encoded = gpt2_tokenizer(scored_text, return_tensors="pt", truncation=True, max_length=1024)

    input_ids = encoded.input_ids.to(DEVICE)
    attention_mask = encoded.attention_mask.to(DEVICE)

    # 3. Labels: a copy of the inputs with the prompt span ignored (-100).
    #    If the prompt is at least as long as the (possibly truncated)
    #    sequence, every position is masked.
    labels = input_ids.clone()
    masked_positions = min(prompt_ids.shape[1], labels.shape[1])
    labels[:, :masked_positions] = -100

    # 4. Forward pass without gradients; the model computes the loss from labels.
    with torch.no_grad():
        loss = gpt2_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).loss

    if torch.isnan(loss):
        return -float('inf')
    return -loss.item()

# The solve_by_scoring function does not need changes; keep it as is
def solve_by_scoring(question, options_dict, context=""):
    """Pick the option label with the highest ``get_lm_score``.

    Args:
        question: Question text.
        options_dict: Mapping of label ("A".."D") to option text.
        context: Optional context forwarded to ``get_lm_score``.

    Returns:
        tuple: ``(best_label, scores)`` where ``scores`` maps every label in
        A..D to its score; labels absent from ``options_dict`` score ``-inf``.
    """
    scores = {
        label: (
            get_lm_score(context, question, options_dict[label])
            if label in options_dict
            else -float('inf')
        )
        for label in ("A", "B", "C", "D")
    }
    best_label = max(scores, key=scores.get)
    return best_label, scores

In [None]:
# Replace evaluate_gpt2_and_save in Cell 7
def evaluate_gpt2_and_save(start_idx=0, end_idx=20, output_dir="eval_results"):
    """Evaluate base vs. KG-RAG option scoring and write a JSONL log.

    For every index in ``range(start_idx, end_idx)`` the example is loaded,
    scored once without context and once with the retrieved KG context, and a
    record is appended to ``<output_dir>/gpt2_scoring_results_<timestamp>.jsonl``.
    Examples that fail to load or whose ground truth cannot be mapped to
    A/B/C/D are skipped. A summary is printed at the end.

    Each record contains the original keys plus ``scores_base`` and
    ``scores_rag`` (label -> score). A score of ``-inf`` (option missing) is
    serialised by ``json`` as ``-Infinity``, which ``json.loads`` reads back.

    Args:
        start_idx: First example index (inclusive).
        end_idx: Last example index (exclusive).
        output_dir: Directory for the result file; created if necessary.
    """
    os.makedirs(output_dir, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(output_dir, f"gpt2_scoring_results_{timestamp}.jsonl")

    print(f"Results will be saved to: {save_path}")
    print(f"Starting SCORING evaluation {start_idx} -> {end_idx}...\n")

    total = 0
    correct_base = 0
    correct_rag = 0

    digit_to_letter = {'1': 'A', '2': 'B', '3': 'C', '4': 'D'}

    def format_gt(val):
        """Normalise a ground-truth value (1-4 or A-D) to a letter, else None."""
        if val is None:
            return None
        s = str(val).strip().upper()
        if s in digit_to_letter:
            return digit_to_letter[s]
        return s if s in ('A', 'B', 'C', 'D') else None

    with open(save_path, "w", encoding="utf-8") as f:
        for i in range(start_idx, end_idx):
            try:
                ex = load_medmcqa_example(i)
            except Exception as e:
                print(f"Skipping {i}: {e}")
                continue

            gt = format_gt(ex['answer'])
            if not gt:
                continue

            q_text = ex['pure_question']
            opts = ex['options']

            # --- 1. Base inference (scoring method) ---
            pred_base, scores_base = solve_by_scoring(q_text, opts, context="")

            # --- 2. RAG inference (scoring method) ---
            entities, ctx = get_knowledge_context_multi(q_text)
            if ctx:
                pred_rag, scores_rag = solve_by_scoring(q_text, opts, context=ctx)
            else:
                # Without context the prompt is identical to the base prompt,
                # so the base result is reused instead of scoring twice.
                pred_rag, scores_rag = pred_base, dict(scores_base)

            # --- 3. Statistics ---
            total += 1
            if pred_base == gt:
                correct_base += 1
            if pred_rag == gt:
                correct_rag += 1

            # --- 4. Recording ---
            record = {
                "id": i,
                "question": q_text,
                "ground_truth": gt,
                "base_prediction": pred_base,
                "rag_prediction": pred_rag,
                "rag_context": ctx,
                "is_base_correct": (pred_base == gt),
                "is_rag_correct": (pred_rag == gt),
                "scores_base": scores_base,
                "scores_rag": scores_rag,
            }

            f.write(json.dumps(record, ensure_ascii=False) + "\n")

            # Per-sample status line; flags cases where RAG changed the outcome.
            status = ""
            if pred_rag == gt and pred_base != gt:
                status = "RAG Fix"
            elif pred_rag != gt and pred_base == gt:
                status = "RAG Hurt"

            print(f"[{i}] GT:{gt} | Base:{pred_base} | RAG:{pred_rag} | {status}")

    if total > 0:
        print(f"\n{'='*30}")
        print(f"Total Samples: {total}")
        print(f"Base Accuracy: {correct_base/total:.2%}")
        print(f"RAG  Accuracy: {correct_rag/total:.2%}")
        print(f"{'='*30}\n")

# Run the evaluation over example indices 0..4182 (the end index is exclusive).
evaluate_gpt2_and_save(0, 4183)

Results will be saved to: eval_results/gpt2_scoring_results_20251215_154935.jsonl
Starting SCORING evaluation 0 -> 4183...

[0] GT:A | Base:A | RAG:A | 
[1] GT:A | Base:B | RAG:B | 
[2] GT:C | Base:C | RAG:A | RAG Hurt
[3] GT:C | Base:C | RAG:C | 
[4] GT:A | Base:C | RAG:C | 
[5] GT:A | Base:A | RAG:A | 
[6] GT:A | Base:A | RAG:A | 
[7] GT:B | Base:A | RAG:A | 
[8] GT:B | Base:D | RAG:D | 
[9] GT:B | Base:D | RAG:D | 
[10] GT:B | Base:A | RAG:A | 
[11] GT:A | Base:A | RAG:A | 
[12] GT:B | Base:A | RAG:A | 
[13] GT:A | Base:D | RAG:D | 
[14] GT:A | Base:C | RAG:C | 
[15] GT:B | Base:A | RAG:A | 
[16] GT:A | Base:A | RAG:A | 
[17] GT:D | Base:C | RAG:C | 
[18] GT:A | Base:B | RAG:B | 
[19] GT:A | Base:D | RAG:D | 
[20] GT:B | Base:B | RAG:B | 
[21] GT:C | Base:A | RAG:A | 
[22] GT:A | Base:D | RAG:D | 
[23] GT:A | Base:D | RAG:D | 
[24] GT:D | Base:C | RAG:C | 
[25] GT:B | Base:C | RAG:C | 
[26] GT:A | Base:C | RAG:C | 
[27] GT:D | Base:D | RAG:D | 
[28] GT:A | Base:D | RAG:D | 
[29] GT:

Overall Accuracy

In [None]:
import json
import os
import pandas as pd
from collections import defaultdict

# -------------------------
# 1. Set Correct File Paths
# -------------------------
# Fallback result file, used only when the auto-detection below finds nothing.
path = r"eval_results/gpt2_scoring_results_20251215_154935.jsonl"  # <--- Update the timestamp to match your evaluation run

# Automatically find the latest file (If you don't want to change the filename manually, use this code to find the latest one)
def get_latest_file(directory, prefix="gpt2_scoring_results"):
    """Return the most recently modified ``<prefix>*.jsonl`` file in ``directory``.

    Returns:
        str | None: Path joined with ``directory``, or ``None`` when the
        directory does not exist or contains no matching file.
    """
    if not os.path.exists(directory):
        return None

    candidates = []
    for entry in os.listdir(directory):
        if entry.startswith(prefix) and entry.endswith(".jsonl"):
            candidates.append(os.path.join(directory, entry))

    return max(candidates, key=os.path.getmtime) if candidates else None

# Prefer the newest result file; keep the manually configured path as fallback.
latest_path = get_latest_file("eval_results")
if latest_path:
    print(f"Auto-detected latest file: {latest_path}")
    path = latest_path
else:
    print(f"Using manual path: {path}")

# -------------------------
# 2. Read and Check Data
# -------------------------
rows = []
if os.path.exists(path):
    with open(path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # A partially written line (e.g. from an interrupted run)
                # should not discard every record that parsed correctly.
                print(f"  Warning: skipping malformed JSON on line {line_no}: {exc}")
else:
    print(f"  Error: File not found at {path}")

if not rows:
    print("  No data found.")
else:
    # --- Debug: Print all Keys of the first row ---
    print(f"  Loaded {len(rows)} rows.")
    first_row = rows[0]
    print("Keys in the first row:", list(first_row.keys()))

    # Check if key fields exist
    if "is_base_correct" not in first_row:
        print("\n   WARNING: 'is_base_correct' key is MISSING!")
        print("Did you run the old evaluation code? The new scoring code should produce this key.")
        # Attempt auto-repair (if ground_truth and base_prediction exist)
        if "ground_truth" in first_row and "base_prediction" in first_row:
            print("  Attempting to recalculate correctness on the fly...")
            for r in rows:
                # .get() keeps rows that lack a prediction from raising
                # KeyError; such rows simply count as incorrect.
                truth = r.get("ground_truth")
                r["is_base_correct"] = truth is not None and r.get("base_prediction") == truth
                r["is_rag_correct"] = truth is not None and r.get("rag_prediction") == truth
        else:
            print("  Cannot calculate accuracy. Missing prediction data.")

# -------------------------
# 3. Statistics and Analysis
# -------------------------
if rows and "is_base_correct" in rows[0]:
    # One (base_ok, rag_ok) pair per row; missing keys count as incorrect.
    outcome_flags = [
        (bool(r.get("is_base_correct", False)), bool(r.get("is_rag_correct", False)))
        for r in rows
    ]

    total = len(rows)
    base_correct = sum(base_ok for base_ok, _ in outcome_flags)
    rag_correct = sum(rag_ok for _, rag_ok in outcome_flags)

    print("\n========== Overall Accuracy====")
    print(f"Total samples : {total}")
    print(f"Base Accuracy : {base_correct/total:.2%}")
    print(f"RAG  Accuracy : {rag_correct/total:.2%}")
    print(f"Net Gain      : {rag_correct - base_correct} samples")

    # RAG effect: how often retrieval flipped the outcome in each direction.
    rag_fix = sum(rag_ok and not base_ok for base_ok, rag_ok in outcome_flags)
    rag_hurt = sum(base_ok and not rag_ok for base_ok, rag_ok in outcome_flags)

    print("\n========== RAG Effect====")
    print(f"RAG Fix  (Base Wrong -> RAG Right): {rag_fix}")
    print(f"RAG Hurt (Base Right -> RAG Wrong): {rag_hurt}")

Auto-detected latest file: eval_results/gpt2_scoring_results_20251215_154935.jsonl
  Loaded 4183 rows.
Keys in the first row: ['id', 'question', 'ground_truth', 'base_prediction', 'rag_prediction', 'rag_context', 'is_base_correct', 'is_rag_correct']

Total samples : 4183
Base Accuracy : 25.72%
RAG  Accuracy : 25.84%
Net Gain      : 5 samples

RAG Fix  (Base Wrong -> RAG Right): 70
RAG Hurt (Base Right -> RAG Wrong): 65


In [None]:
import json
import pandas as pd
import os
from collections import Counter

# 1. Automatically find the latest result file
def get_latest_file(directory, prefix="gpt2_scoring_results"):
    """Find the newest ``.jsonl`` file in ``directory`` whose name starts with ``prefix``.

    "Newest" is decided by modification time. ``None`` is returned when the
    directory is missing or holds no matching file.
    """
    if os.path.exists(directory):
        matches = [
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.startswith(prefix) and name.endswith(".jsonl")
        ]
        if matches:
            return max(matches, key=os.path.getmtime)
    return None

path = get_latest_file("eval_results")
print(f"Analyzing file: {path}")
if path is None:
    # Fail with an actionable message instead of the TypeError from open(None).
    raise FileNotFoundError(
        "No 'gpt2_scoring_results*.jsonl' file found in 'eval_results'; run the evaluation first."
    )

# 2. Read data (one JSON record per non-blank line)
data = []
with open(path, 'r', encoding='utf-8') as f:
    for line in f:
        if line.strip():
            data.append(json.loads(line))

df = pd.DataFrame(data)

# 3. Core verification: compare how often each label occurs in the ground
#    truth and in the two prediction columns.
print("\n========== Prediction Distribution (Check if 'only selecting A/B' issue is resolved)====")
gt_counts, base_counts, rag_counts = (
    df[column].value_counts().sort_index()
    for column in ('ground_truth', 'base_prediction', 'rag_prediction')
)

dist_df = pd.DataFrame({
    'Ground Truth': gt_counts,
    'Base Preds': base_counts,
    'RAG Preds': rag_counts
})
# Labels absent from a column appear as NaN after alignment; show them as 0.
dist_df = dist_df.fillna(0).astype(int)

print(dist_df)

# 4. Did the RAG run ever pick C or D?
c_d_selected = int(df['rag_prediction'].isin(['C', 'D']).sum())
print(f"\nCount of C or D selected in RAG mode: {c_d_selected} / {len(df)}")
if c_d_selected <= 0:
    print("  Warning: The model still has not selected C or D.")
else:
    print("  Success! The model has started selecting C and D.")

# 5. Error analysis: show the first three examples the RAG run got wrong.
print("\n========== Error Case Analysis (Examples where RAG was wrong)====")
rag_is_wrong = ~df['is_rag_correct']
wrong_samples = df[rag_is_wrong].head(3)

separator = "-" * 50
for idx, row in wrong_samples.iterrows():
    print(f"ID: {row['id']}")
    print(f"Question: {row['question'][:100]}...")
    print(f"Ground Truth: {row['ground_truth']} | Model Prediction: {row['rag_prediction']}")
    # Per-option scores are printed only when the result file recorded them.
    if 'scores_rag' in row:
        scores = row['scores_rag']
        sorted_scores = list(scores.items())
        sorted_scores.sort(key=lambda item: item[1], reverse=True)
        print(f"Score Ranking: {sorted_scores}")
    print(separator)

In [None]:
#
# Layer 1: Basic Imports & Configuration
#
import json
import os
import pickle
import numpy as np
import torch
import faiss
from sentence_transformers import SentenceTransformer
from transformers import GPT2LMHeadModel, GPT2Tokenizer

# ---------------------------------------------------------
# [Configuration Area] Adjust the variables below to your environment.
# NOTE: GPT2_PATH, MEDMCQA_FILE, DEVICE and MAX_CTX_CHARS rebind names that the
# first configuration cell of this notebook also defines; after this cell runs,
# earlier functions that read those globals at call time see the new values.
# ---------------------------------------------------------
GPT2_PATH           = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/gpt2"
MEDMCQA_FILE        = "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/medmcqa/dev.json"

# Knowledge-base files (relative to the working directory unless absolute paths are given)
FAISS_INDEX_PATH    = "pubmed_qa.index"
DOCS_PKL_PATH       = "pubmed_documents.pkl"

# Embedding model (must match the model used when the index was built)
EMBED_MODEL_NAME    = "all-MiniLM-L6-v2" 

DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"

# RAG Parameters
TOP_K_DOCS          = 2     # Number of nearest documents requested from the index per question
MAX_CTX_CHARS       = 2000  # Character budget for the concatenated retrieved text (keeps the prompt within GPT-2's context)

print(f"Config OK. DEVICE = {DEVICE}")

#
# Layer 2: Model Layer (GPT-2 Loading & Fixed Generation Function)
#
def load_gpt2(model_path: str = GPT2_PATH):
    """Load the GPT-2 tokenizer and LM head model from ``model_path``.

    The pad token falls back to the EOS token when the tokenizer has none, and
    the model is moved to ``DEVICE``. Any exception raised while loading is
    printed rather than propagated.

    Returns:
        tuple: ``(tokenizer, model)`` on success, ``(None, None)`` on failure.
    """
    print(f"Loading GPT-2 from {model_path} ...")

    loaded = (None, None)
    try:
        tok = GPT2Tokenizer.from_pretrained(model_path)
        if tok.pad_token is None:
            tok.pad_token = tok.eos_token
        lm = GPT2LMHeadModel.from_pretrained(model_path).to(DEVICE)
        print("GPT-2 loaded.")
        loaded = (tok, lm)
    except Exception as e:
        print(f"Error loading GPT-2: {e}")
    return loaded

# Module-level tokenizer/model used by gpt2_generate; both are None if loading failed.
tokenizer, model = load_gpt2()

def gpt2_generate(prompt: str, max_new_tokens: int = 120, do_sample: bool = True, temperature: float = 0.7, top_p: float = 0.9):
    """
    General generation function.

    Sampling-only arguments (``temperature`` and ``top_p``) are forwarded to
    ``generate`` only when ``do_sample`` is true, so a greedy call does not
    carry unused sampling flags.

    Args:
        prompt: Text to condition on. Only its last 900 tokens are kept.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample when true, decode greedily otherwise.
        temperature: Sampling temperature (used only when sampling).
        top_p: Nucleus-sampling threshold (used only when sampling).

    Returns:
        str: The first output sequence decoded with special tokens skipped.
    """
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    # Keep the tail of over-long prompts so the end of the prompt survives and
    # room is left for generation (GPT-2 context is usually 1024 tokens).
    if inputs.shape[1] > 900:
        inputs = inputs[:, -900:]

    # A single unpadded sequence: every position is a real token.
    attention_mask = torch.ones_like(inputs)

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "no_repeat_ngram_size": 3,
        "pad_token_id": tokenizer.eos_token_id,
        "attention_mask": attention_mask,
    }

    # temperature/top_p only affect sampling; pass them only in that mode.
    if do_sample:
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    with torch.no_grad():
        outputs = model.generate(inputs, **gen_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)

#
# Layer 3: Knowledge Base Layer (Load FAISS & Documents)
#
def load_retrieval_system():
    """Load the sentence embedder, the FAISS index and the pickled documents.

    Returns:
        tuple: ``(embed_model, index, documents)``, or ``(None, None, None)``
        when either ``FAISS_INDEX_PATH`` or ``DOCS_PKL_PATH`` does not exist.
    """
    required_files = (FAISS_INDEX_PATH, DOCS_PKL_PATH)
    if not all(os.path.exists(p) for p in required_files):
        print(f"Error: Knowledge base files not found. Please check if {FAISS_INDEX_PATH} and {DOCS_PKL_PATH} exist.")
        return None, None, None

    print(f"Loading Embedding Model: {EMBED_MODEL_NAME} ...")
    embedder = SentenceTransformer(EMBED_MODEL_NAME)

    print("Loading FAISS Index ...")
    vector_index = faiss.read_index(FAISS_INDEX_PATH)

    print("Loading Documents ...")
    # NOTE(review): unpickling can execute arbitrary code; only load a
    # documents file that comes from a trusted source.
    with open(DOCS_PKL_PATH, "rb") as fh:
        docs = pickle.load(fh)

    print(f"Knowledge Base Loaded! Index size: {vector_index.ntotal}, Docs count: {len(docs)}")
    return embedder, vector_index, docs

# Module-level retrieval components; all three are None when the knowledge-base files are missing.
embed_model, faiss_index, doc_store = load_retrieval_system()

#
# Layer 4: Vector Retrieval RAG -- Semantic Search Core Logic
#
def get_pubmed_context(question_text: str, top_k: int = TOP_K_DOCS) -> str:
    """
    Retrieve context for a question from the vector index.

    1. Encode the question to a vector.
    2. Search the FAISS index for the ``top_k`` nearest document ids.
    3. Fetch, clean and concatenate the document texts within
       ``MAX_CTX_CHARS`` characters.

    A document that does not fit into the remaining budget is cut off and
    marked with ``...``, and no further documents are added. When the budget
    is already used up, nothing more is appended (no empty ``Abstract: ...``).

    Returns:
        str: ``"Abstract: ..."`` blocks separated by blank lines, or ``""``
        when the retrieval system is not loaded, ``top_k`` is not positive, or
        nothing usable was retrieved.
    """
    if faiss_index is None or embed_model is None or doc_store is None or top_k <= 0:
        return ""

    # 1. Encode
    q_emb = embed_model.encode([question_text], convert_to_numpy=True)

    # 2. Search (the distances are not used)
    _, indices = faiss_index.search(q_emb, top_k)

    # 3. Fetch text
    retrieved_texts = []
    current_chars = 0

    for idx_in_store in indices[0]:
        # Skip the "no result" placeholder (-1) and ids outside the doc store.
        if idx_in_store < 0 or idx_in_store >= len(doc_store):
            continue

        # Simple cleaning; assumes each stored document is a string — TODO confirm.
        clean_content = doc_store[idx_in_store].replace("\n", " ").strip()
        if not clean_content:
            continue

        remaining = MAX_CTX_CHARS - current_chars
        if remaining <= 0:
            # Budget exhausted: stop rather than append an empty stub.
            break

        if len(clean_content) > remaining:
            retrieved_texts.append(f"Abstract: {clean_content[:remaining]}...")
            break

        retrieved_texts.append(f"Abstract: {clean_content}")
        current_chars += len(clean_content)

    return "\n\n".join(retrieved_texts)

#
# Layer 5: MedMCQA Data Loading
#
def load_medmcqa_example(idx: int = 0, file_path: str = MEDMCQA_FILE):
    """
    Read the idx-th sample (0-based, blank lines ignored) from a MedMCQA
    JSON-lines file.

    The question is taken from ``question`` (or ``Question``). Options come
    from ``opa``..``opd``, or from an ``options`` list when none of those keys
    is present. The answer is the first of ``cop``, ``answer``, ``label`` whose
    value is neither ``None`` nor an empty string, so a legitimate falsy value
    such as ``0`` is preserved.

    Returns:
        dict: ``{"raw": <parsed record>, "question_text": <question plus
        options>, "answer": <raw answer or None>}``.

    Raises:
        IndexError: The file has no sample at position ``idx``.
    """
    option_map = {"opa": "A", "opb": "B", "opc": "C", "opd": "D"}

    with open(file_path, "r", encoding="utf-8") as f:
        non_blank_lines = (stripped for stripped in map(str.strip, f) if stripped)
        for line_idx, line in enumerate(non_blank_lines):
            if line_idx != idx:
                continue

            data = json.loads(line)

            # Construct question stem
            q = data.get("question") or data.get("Question") or ""

            # Construct options
            options_lines = [
                f"{lab}) {data[k]}" for k, lab in option_map.items() if k in data
            ]
            if not options_lines and "options" in data:
                options_lines = [
                    f"{chr(ord('A') + i)}) {opt}"
                    for i, opt in enumerate(data["options"])
                ]

            question_text = q
            if options_lines:
                question_text += "\nOptions:\n" + "\n".join(options_lines)

            # Get answer: first key holding a real value (0 counts as a value).
            answer = None
            for key in ("cop", "answer", "label"):
                value = data.get(key)
                if value is not None and value != "":
                    answer = value
                    break

            return {"raw": data, "question_text": question_text, "answer": answer}
    raise IndexError(f"Index {idx} out of range")

#
# Layer 6: GPT-2 QA Interface (No RAG vs Vector RAG)
#
def qa_no_rag(question: str) -> str:
    """
    Baseline: No database query, ask GPT-2 directly.

    The answer is parsed only from the text that follows the final
    "Answer (A, B, C, or D):" cue. The text returned by ``gpt2_generate``
    begins with the prompt, and the prompt itself ends with "... or D):", so
    scanning the whole text would pick a letter from the instructions whenever
    the continuation contains none.

    Returns:
        str: The first standalone capital letter A-D in the continuation, or
        the stripped continuation when no such letter is present.
    """
    import re  # local import: the import block of this cell does not include ``re``

    answer_cue = "Answer (A, B, C, or D):"
    prompt = (
        "You are a medical exam solver.\n"
        "You will be given one multiple-choice question with options A, B, C, and D.\n"
        "Choose the single best option and reply with ONLY one capital letter: A, B, C, or D.\n"
        "Do not output anything else.\n\n"
        f"{question}\n\n"
        f"{answer_cue}"
    )
    # do_sample=False indicates greedy search (deterministic results)
    full_text = gpt2_generate(prompt, max_new_tokens=8, do_sample=False)

    # Isolate the model's continuation; fall back to the full text if the cue
    # cannot be located in the decoded output.
    cue_start = full_text.rfind(answer_cue)
    continuation = full_text[cue_start + len(answer_cue):] if cue_start != -1 else full_text

    # Standalone letter only, so capitals inside words (e.g. "Diabetes") are ignored.
    letter = re.search(r"\b([ABCD])\b", continuation)
    if letter:
        return letter.group(1)
    return continuation.strip()

def qa_with_rag_vector(question: str):
    """
    Vector RAG: Query Database -> Concatenate Context -> Ask GPT-2.

    When no context is retrieved the question is answered by ``qa_no_rag``.
    Otherwise the answer is parsed only from the text that follows the final
    "Answer (A, B, C, or D):" cue. The text returned by ``gpt2_generate``
    begins with the prompt (instructions, retrieved abstracts and question),
    so scanning the whole text would pick up capital letters from the prompt.

    Returns:
        tuple: ``(context, answer)`` where ``answer`` is the first standalone
        capital letter A-D in the continuation, or the stripped continuation
        when no such letter is present. ``context`` is ``""`` when nothing was
        retrieved.
    """
    import re  # local import: the import block of this cell does not include ``re``

    # 1. Get context
    context = get_pubmed_context(question, top_k=TOP_K_DOCS)

    if not context:
        return "", qa_no_rag(question)

    # 2. Construct Prompt
    answer_cue = "Answer (A, B, C, or D):"
    prompt = (
        "You are a medical exam solver.\n"
        "Below are some relevant research abstracts retrieved from PubMed.\n"
        "Use this context to help answer the question.\n"
        "You will be given one multiple-choice question with options A, B, C, and D.\n"
        "Choose the single best option and reply with ONLY one capital letter: A, B, C, or D.\n"
        "Do not output anything else.\n\n"
        f"Context:\n{context}\n\n"
        f"Question:\n{question}\n\n"
        f"{answer_cue}"
    )

    full_text = gpt2_generate(prompt, max_new_tokens=8, do_sample=False)

    # 3. Isolate the model's continuation; fall back to the full text if the
    #    cue cannot be located in the decoded output.
    cue_start = full_text.rfind(answer_cue)
    continuation = full_text[cue_start + len(answer_cue):] if cue_start != -1 else full_text

    # Standalone letter only, so capitals inside words are ignored.
    letter = re.search(r"\b([ABCD])\b", continuation)
    if letter:
        return context, letter.group(1)

    return context, continuation.strip()

#
# Layer 7: Comparison Test Function (Single Item)
#
def compare_rag_medmcqa_vector(idx: int = 0, max_print_chars: int = 600, save_dir: str = "rag_logs_vector"):
    """
    Answer one MedMCQA sample with and without vector RAG and report both.

    The report is printed and a summary is written to
    ``<save_dir>/medmcqa_idx_<idx>.txt``. When ``idx`` is out of range a
    message is printed and nothing else happens.

    Args:
        idx: Sample index passed to ``load_medmcqa_example``.
        max_print_chars: Printed question/context are cut to this many
            characters (followed by ``...``).
        save_dir: Directory for the per-sample log file; created if needed.
    """
    def _to_letter(x):
        """Map 1-4 or A-D to a letter; other values are returned as stripped text."""
        if x is None:
            return None
        s = str(x).strip()
        if s in ("A", "B", "C", "D"):
            return s
        if s.isdigit() and 1 <= int(s) <= 4:
            return "ABCD"[int(s) - 1]
        return s

    def _short(s):
        return s[:max_print_chars] + "..." if len(s) > max_print_chars else s

    try:
        example = load_medmcqa_example(idx)
    except IndexError:
        print(f"Index {idx} out of range.")
        return

    question = example["question_text"]
    gt = _to_letter(example["answer"])

    # 1) No RAG, 2) Vector RAG
    ans_no = qa_no_rag(question)
    retrieved_ctx, ans_rag = qa_with_rag_vector(question)

    # Compare as strings so a None ground truth never matches a letter.
    correct_no = str(ans_no) == str(gt)
    correct_rag = str(ans_rag) == str(gt)

    os.makedirs(save_dir, exist_ok=True)
    out_path = os.path.join(save_dir, f"medmcqa_idx_{idx}.txt")

    heavy_rule = "=" * 60
    light_rule = "-" * 60
    report_lines = [
        heavy_rule,
        f"MedMCQA Sample idx = {idx}",
        "Question:",
        _short(question),
        f"\nCorrect Answer (GT): {gt}",
        light_rule,
        f"[No RAG] Prediction: {ans_no} | Correct? {correct_no}",
        light_rule,
        f"[Vector RAG] Prediction: {ans_rag} | Correct? {correct_rag}",
        "\n[Retrieved Context (Top 2)]:",
        _short(retrieved_ctx) if retrieved_ctx else "(No relevant documents)",
        heavy_rule,
    ]
    for report_line in report_lines:
        print(report_line)

    # Save to file
    with open(out_path, "w", encoding="utf-8") as f:
        f.write(f"Q: {question}\nGT: {gt}\n\nCTX:\n{retrieved_ctx}\n\nPred_No: {ans_no}\nPred_RAG: {ans_rag}")

Config OK. DEVICE = cuda
Loading GPT-2 from /media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/gpt2 ...
GPT-2 loaded.
Loading Embedding Model: all-MiniLM-L6-v2 ...
Loading FAISS Index ...
Loading Documents ...
Knowledge Base Loaded! Index size: 800, Docs count: 800


In [None]:
import pandas as pd
import os
from datetime import datetime # Import datetime module

#
# Layer 8: Batch Evaluation (Auto Timestamp + Directory Management)
#
def _parse_gt_letter(raw_ans):
    """Normalize a raw answer to "A"/"B"/"C"/"D"; return None if it cannot be mapped.

    Accepts the letters A-D or a 1-based option index 1-4 (as int or str),
    ignoring surrounding whitespace. Everything else (None, empty, 0,
    out-of-range numbers, other text) yields None so the caller can skip it.
    """
    if raw_ans is None:
        return None
    text = str(raw_ans).strip()
    if text in ("A", "B", "C", "D"):
        return text
    # isdecimal() (unlike isdigit()) only accepts characters int() can parse.
    if text.isdecimal():
        option = int(text)
        if 1 <= option <= 4:
            return chr(ord("A") + option - 1)
    return None


def evaluate_medmcqa_acc(start_idx=0, end_idx=100, output_file=None):
    """
    Batch test and save results.

    For every index in [start_idx, end_idx) the sample is answered by
    `qa_no_rag` and by `qa_with_rag_vector`; both predictions are compared
    with the ground-truth letter. Samples that cannot be loaded or whose
    ground truth cannot be mapped to A-D are skipped and do not count
    towards the totals.

    Arguments:
        start_idx: First sample index (inclusive).
        end_idx: Last sample index (exclusive).
        output_file: (Optional) Specify filename. If not provided, automatically
            generated based on current time. The file is always placed inside
            the "results" directory.

    Outputs:
        results/<output_file>            per-question details
        results/<base>_summary<ext>      one-row summary (accuracies, improved/worsened counts)
        Nothing is written when no sample was evaluated.

    Returns:
        None.
    """

    # --- 1. Automatically build filename with timestamp ---
    if output_file is None:
        # Get current time, format: YYYYMMDD_HHMMSS
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_file = f"rag_eval_{timestamp}.csv"

    # --- 2. Automatically create results folder (Optional, for tidiness) ---
    output_dir = "results"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        print(f"Output directory created: {output_dir}")

    # Combine full path: results/rag_eval_2025xxxx.csv
    full_output_path = os.path.join(output_dir, output_file)

    print(f"Start batch evaluation: Index {start_idx} -> {end_idx}")
    print(f"Results will be saved to: {full_output_path}")

    results = []
    total = 0
    correct_no = 0
    correct_rag = 0
    improved = 0
    worsened = 0

    for idx in range(start_idx, end_idx):
        try:
            ex = load_medmcqa_example(idx)
        except IndexError:
            # Index beyond the dataset: skip quietly, as before.
            continue
        except Exception as exc:
            # Keep the best-effort behaviour, but do not hide real loader errors
            # and do not swallow KeyboardInterrupt/SystemExit like a bare except.
            print(f"[{idx}] Skipped (load failed: {type(exc).__name__}: {exc})")
            continue

        q = ex["question_text"]
        raw_ans = ex["answer"]

        # Parse Ground Truth (GT); skip samples without a usable A-D label
        gt = _parse_gt_letter(raw_ans)
        if not gt:
            continue

        # === Core Prediction ===
        pred_no = qa_no_rag(q)
        ctx_rag, pred_rag = qa_with_rag_vector(q)

        # === Statistics ===
        is_correct_no = (pred_no == gt)
        is_correct_rag = (pred_rag == gt)

        if is_correct_no:
            correct_no += 1
        if is_correct_rag:
            correct_rag += 1
        total += 1

        # "Improved"/"Worsened" describe the effect of adding RAG on this sample
        status = "Same"
        if not is_correct_no and is_correct_rag:
            status = "Improved"
            improved += 1
        elif is_correct_no and not is_correct_rag:
            status = "Worsened"
            worsened += 1

        results.append({
            "Index": idx,
            "Question": q,
            "Ground_Truth": gt,
            "Pred_No_RAG": pred_no,
            "Correct_No_RAG": is_correct_no,
            "Pred_Vector_RAG": pred_rag,
            "Correct_Vector_RAG": is_correct_rag,
            "Status": status,
            "Retrieved_Context": ctx_rag
        })

        print(f"[{idx}] GT:{gt} | NoRAG:{pred_no} {'O' if is_correct_no else 'X'} | RAG:{pred_rag} {'O' if is_correct_rag else 'X'} | {status}")

    # === Calculate Final Metrics ===
    acc_no = correct_no / total if total > 0 else 0
    acc_rag = correct_rag / total if total > 0 else 0

    print(f"\nFinal Results ({total} questions):")
    print(f"Pure GPT-2 Accuracy : {acc_no:.4f}")
    print(f"RAG (Vector) Accuracy: {acc_rag:.4f}")

    # === Save File ===
    if results:
        df = pd.DataFrame(results)
        # Global accuracies are repeated on every row so the detail file is self-contained
        df["Global_Acc_No_RAG"] = f"{acc_no:.2%}"
        df["Global_Acc_Vector_RAG"] = f"{acc_rag:.2%}"

        # 1. Save detailed CSV
        df.to_csv(full_output_path, index=False, encoding="utf-8-sig")
        print(f"Detailed results saved: {full_output_path}")

        # 2. Save Summary CSV (Filename automatically adds _summary)
        base, ext = os.path.splitext(output_file)  # Separate filename and extension
        summary_filename = f"{base}_summary{ext}"
        full_summary_path = os.path.join(output_dir, summary_filename)

        summary_data = [{
            "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),  # Record exact time
            "Range": f"{start_idx}-{end_idx}",
            "Total_Questions": total,
            "Acc_No_RAG": acc_no,
            "Acc_Vector_RAG": acc_rag,
            "Improved_Count": improved,
            "Worsened_Count": worsened
        }]
        pd.DataFrame(summary_data).to_csv(full_summary_path, index=False, encoding="utf-8-sig")
        print(f"Summary statistics saved: {full_summary_path}")

    else:
        print("No results generated, skipping save.")

#
# Run Example
#
if __name__ == "__main__":
    # Default usage: no output_file given, so the CSV name is derived from the
    # current timestamp. Evaluates indices 0 .. 4182 (end_idx is exclusive);
    # presumably this is the whole dev split -- confirm against the data file.
    evaluate_medmcqa_acc(start_idx=0, end_idx=4183)

    # Alternative: pass an explicit filename (it is still written under "results/").
    # evaluate_medmcqa_acc(0, 50, output_file="my_custom_experiment.csv")