# In [None]:
# Layer 1: Basic Imports & Configuration
import json
import os
import pickle
import numpy as np
import torch
import faiss
import pandas as pd
from datetime import datetime
from sentence_transformers import SentenceTransformer
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# [New] PEFT Library Import
from peft import PeftModel, PeftConfig


# [Configuration Area] Adjust the defaults below to your own paths, or override any
# path without editing this file by exporting an environment variable with the same
# name as the constant (e.g. `export GPT2_BASE_PATH=/path/to/gpt2`).

# Original GPT-2 base model directory (tokenizer + weights)
GPT2_BASE_PATH      = os.environ.get(
    "GPT2_BASE_PATH",
    "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/gpt2",
)

# Adapter directory produced by Prompt Tuning
ADAPTER_PATH        = os.environ.get(
    "ADAPTER_PATH",
    "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/out_gpt2_official_prompt_tuning_e1",
)

# MedMCQA dev split (JSON Lines: one JSON record per line)
MEDMCQA_FILE        = os.environ.get(
    "MEDMCQA_FILE",
    "/media/miaoen/ad4277ac-5cfe-47b0-a2cc-f9e50e0da444/LLM/data/medmcqa/dev.json",
)

# Knowledge base files (relative paths resolve against the current working directory)
FAISS_INDEX_PATH    = os.environ.get("FAISS_INDEX_PATH", "pubmed_qa.index")
DOCS_PKL_PATH       = os.environ.get("DOCS_PKL_PATH", "pubmed_documents.pkl")

# Sentence-transformers model used to embed queries.
# NOTE(review): presumably must be the same model that built the FAISS index — confirm.
EMBED_MODEL_NAME    = "all-MiniLM-L6-v2"

DEVICE              = "cuda" if torch.cuda.is_available() else "cpu"

# RAG parameters
TOP_K_DOCS          = 2     # number of nearest abstracts to retrieve
MAX_CTX_CHARS       = 2000  # character budget for the retrieved abstract bodies

print(f"Config OK. DEVICE = {DEVICE}")


# Layer 2: Model Layer (GPT-2 Loading & Fixed Generation Function)
def load_peft_model(base_path: str, adapter_path: str):
    """Load the GPT-2 tokenizer and base model, then attach the Prompt Tuning adapter.

    Args:
        base_path: Directory of the original GPT-2 checkpoint (tokenizer + weights).
        adapter_path: Directory of the PEFT adapter saved after Prompt Tuning.

    Returns:
        ``(tokenizer, model)`` where *model* is the ``PeftModel`` moved to ``DEVICE``
        and switched to eval mode, or ``(None, None)`` if any loading step raised
        (the error message is printed, not re-raised).
    """
    print(f"Loading Tokenizer & Base Model from {base_path} ...")
    try:
        tok = GPT2Tokenizer.from_pretrained(base_path)
        if tok.pad_token is None:
            # No pad token defined: reuse EOS so padding-aware generation works.
            tok.pad_token = tok.eos_token

        backbone = GPT2LMHeadModel.from_pretrained(base_path).to(DEVICE)

        print(f"Loading Prompt Tuning Adapter from {adapter_path} ...")
        peft_model = PeftModel.from_pretrained(backbone, adapter_path).to(DEVICE)
        # eval() disables dropout so greedy generation is deterministic.
        peft_model.eval()

        print("PEFT Model loaded successfully.")
    except Exception as err:
        print(f"Error loading models: {err}")
        return None, None
    else:
        return tok, peft_model


tokenizer, model = load_peft_model(base_path=GPT2_BASE_PATH, adapter_path=ADAPTER_PATH)

def gpt2_generate(
    prompt: str,
    use_adapter: bool = True,
    max_new_tokens: int = 10,
    do_sample: bool = False,
    max_input_tokens: int = 900,
) -> str:
    """Generate a continuation of *prompt* with the prompt-tuned or the plain GPT-2 model.

    Args:
        prompt: Text to continue.
        use_adapter: True -> generate through the PeftModel, which prepends the learned
            soft prompt. False -> disable the adapter and call the wrapped base model.
        max_new_tokens: Maximum number of tokens to generate.
        do_sample: Sample instead of greedy decoding.
        max_input_tokens: Keep only the last this-many prompt tokens. Truncation is from
            the left, so the end of the prompt (the answer cue) survives. The default of
            900 presumably leaves room within GPT-2's 1024-position window for the
            soft-prompt and generated tokens; confirm against the adapter's
            ``num_virtual_tokens``. None or 0 disables truncation.

    Returns:
        ``tokenizer.decode`` of the first returned sequence, with special tokens skipped.
        For decoder-only models, transformers returns the prompt ids followed by the new
        ids, so this string normally starts with the (possibly truncated) prompt.
    """
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)

    if max_input_tokens and input_ids.shape[1] > max_input_tokens:
        input_ids = input_ids[:, -max_input_tokens:]

    gen_kwargs = {
        # input_ids MUST be passed by keyword. For prompt-learning adapters,
        # PeftModel.generate forwards only **kwargs to the base model, so a positional
        # tensor would be dropped (older PEFT versions reject positional args outright).
        "input_ids": input_ids,
        "attention_mask": torch.ones_like(input_ids),
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
    }

    with torch.no_grad():
        if use_adapter:
            # Prompt Tuning mode: the PeftModel prepends the soft prompt itself.
            outputs = model.generate(**gen_kwargs)
        else:
            # Raw base mode: disable the adapter and bypass the PeftModel wrapper.
            with model.disable_adapter():
                outputs = model.base_model.generate(**gen_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# Layer 3: Knowledge Base Layer (Load FAISS & Documents)
def load_retrieval_system():
    """Load the query embedder, the FAISS index and the pickled document list.

    Returns:
        ``(embed_model, index, documents)``, or ``(None, None, None)`` when either
        FAISS_INDEX_PATH or DOCS_PKL_PATH does not exist (an error is printed).
    """
    required_files = (FAISS_INDEX_PATH, DOCS_PKL_PATH)
    if not all(os.path.exists(path) for path in required_files):
        print(f"Error: Knowledge base files not found. Please check if {FAISS_INDEX_PATH} and {DOCS_PKL_PATH} exist.")
        return None, None, None

    print(f"Loading Embedding Model: {EMBED_MODEL_NAME} ...")
    encoder = SentenceTransformer(EMBED_MODEL_NAME)

    print("Loading FAISS Index ...")
    index = faiss.read_index(FAISS_INDEX_PATH)

    print("Loading Documents ...")
    # NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    # only load a documents pickle you produced yourself.
    with open(DOCS_PKL_PATH, "rb") as fh:
        documents = pickle.load(fh)

    print(f"Knowledge Base Loaded! Index size: {index.ntotal}, Docs count: {len(documents)}")
    return encoder, index, documents


embed_model, faiss_index, doc_store = load_retrieval_system()


# Layer 4: Vector Retrieval RAG Logic
def get_pubmed_context(question_text: str, top_k: int = TOP_K_DOCS) -> str:
    """Retrieve the abstracts nearest to *question_text* and join them into one string.

    Each abstract has its newlines flattened to spaces and is prefixed with
    "Abstract: ". Abstracts are separated by a blank line. The abstract bodies together
    are capped at MAX_CTX_CHARS characters: the abstract that would exceed the budget
    is cut to the remaining room (suffixed with "...") and retrieval stops there.

    Args:
        question_text: Query text to embed.
        top_k: Number of nearest neighbours to request from FAISS.

    Returns:
        The joined context, or "" when the retrieval system is not loaded, *top_k* < 1,
        or no usable abstract was found.
    """
    if faiss_index is None or embed_model is None or doc_store is None or top_k < 1:
        return ""

    # FAISS's Python API expects a C-contiguous float32 matrix of query vectors.
    query = np.ascontiguousarray(
        embed_model.encode([question_text], convert_to_numpy=True), dtype=np.float32
    )
    _, neighbours = faiss_index.search(query, top_k)

    chunks = []
    used_chars = 0
    for doc_idx in neighbours[0]:
        # FAISS pads missing neighbours with -1. Negative ids would also wrap around a
        # Python list, and ids beyond the document list cannot be resolved.
        if doc_idx < 0 or doc_idx >= len(doc_store):
            continue
        # NOTE(review): assumes every stored document is a plain string — confirm.
        text = doc_store[doc_idx].replace("\n", " ").strip()
        if not text:
            continue

        remaining = MAX_CTX_CHARS - used_chars
        if len(text) > remaining:
            if remaining > 0:  # don't append an empty "Abstract: ..." stub
                chunks.append(f"Abstract: {text[:remaining]}...")
            break

        chunks.append(f"Abstract: {text}")
        used_chars += len(text)

    return "\n\n".join(chunks)


# Layer 5: MedMCQA Data Loading
def load_medmcqa_example(idx: int = 0, file_path: str = MEDMCQA_FILE):
    """Load record *idx* from a JSON-Lines MedMCQA file and format it as a question prompt.

    Blank lines are ignored and do not count toward *idx*. Only the requested line is
    parsed as JSON, but the file is scanned from the start on every call.

    Args:
        idx: Zero-based index among the non-blank lines.
        file_path: Path to the JSON-Lines file.

    Returns:
        dict with
          - "raw": the parsed JSON object,
          - "question_text": the question, followed by an "Options:" block when options exist,
          - "answer": the first of "cop" / "answer" / "label" that is set (not None/""), else None.

    Raises:
        IndexError: if *idx* is negative or the file has fewer than idx + 1 non-blank lines.
    """
    if idx >= 0:
        with open(file_path, "r", encoding="utf-8") as f:
            non_blank = (line for line in map(str.strip, f) if line)
            for line_idx, line in enumerate(non_blank):
                if line_idx == idx:
                    return _format_medmcqa_record(json.loads(line))
    raise IndexError(f"Index {idx} out of range")


def _format_medmcqa_record(data: dict) -> dict:
    """Turn one parsed MedMCQA record into the dict returned by load_medmcqa_example."""
    question = data.get("question") or data.get("Question") or ""

    # MedMCQA-style records carry options as opa..opd; otherwise fall back to "options".
    option_fields = (("opa", "A"), ("opb", "B"), ("opc", "C"), ("opd", "D"))
    option_lines = [f"{label}) {data[key]}" for key, label in option_fields if key in data]
    if not option_lines and "options" in data:
        options = data["options"]
        if isinstance(options, dict):
            # Mapping form such as {"A": "...", "B": "..."}: keep the record's own labels.
            option_lines = [f"{label}) {text}" for label, text in options.items()]
        else:
            option_lines = [f"{chr(ord('A') + i)}) {opt}" for i, opt in enumerate(options)]

    question_text = question
    if option_lines:
        question_text += "\nOptions:\n" + "\n".join(option_lines)

    # Take the first answer field that is actually set. Unlike an `or` chain, this keeps
    # a legitimate falsy label such as 0 instead of silently falling through.
    answer = next(
        (data[key] for key in ("cop", "answer", "label") if data.get(key) not in (None, "")),
        None,
    )
    return {"raw": data, "question_text": question_text, "answer": answer}


# Layer 6: Comparison QA Interface
# 1. Base Model (Raw, No RAG, Adapter Disabled)
# 2. Prompt Tuning Model (Adapter Enabled) + RAG

def extract_answer(text: str) -> str:
    """Return the last option letter (A, B, C or D) occurring anywhere in *text*.

    The text is whitespace-stripped first. If it contains none of those four
    uppercase letters, the stripped text itself is returned unchanged.
    """
    stripped = text.strip()
    return next((ch for ch in reversed(stripped) if ch in "ABCD"), stripped)

def qa_baseline_raw(question: str) -> str:
    """
    [Baseline] Pure Base Model, No Adapter, No RAG

    Builds a zero-shot multiple-choice prompt and generates up to 5 tokens greedily
    with the Prompt Tuning adapter disabled.

    Returns:
        The result of extract_answer() applied to the generated completion only.
    """
    answer_cue = "Answer (A, B, C, or D):"
    prompt = (
        "You are a medical exam solver.\n"
        "Choose the single best option and reply with ONLY one capital letter: A, B, C, or D.\n\n"
        f"{question}\n\n"
        + answer_cue
    )
    # use_adapter=False -> Disable Prompt Tuning weights
    full_text = gpt2_generate(prompt, use_adapter=False, max_new_tokens=5, do_sample=False)

    # gpt2_generate decodes prompt + completion. Scanning all of it would let letters
    # from the prompt itself (e.g. the trailing "or D):") become the "prediction"
    # whenever the model emits no letter. Keep only the text after the last answer cue.
    # If the cue is not found, rsplit falls back to the full text.
    completion = full_text.rsplit(answer_cue, 1)[-1]
    return extract_answer(completion)

def qa_pt_rag(question: str):
    """
    [Experimental] Prompt Tuning Model + RAG Context

    Retrieves PubMed abstracts for *question*, builds a context-augmented prompt and
    generates up to 5 tokens greedily with the Prompt Tuning adapter enabled.

    Returns:
        (context, prediction): the retrieved context string (possibly ""), and the
        result of extract_answer() applied to the generated completion only.
    """
    answer_cue = "Answer (A, B, C, or D):"

    # 1. Retrieve
    context = get_pubmed_context(question, top_k=TOP_K_DOCS)

    # 2. Construct the RAG-augmented prompt; it must end with `answer_cue`.
    prompt = (
        "You are a medical exam solver.\n"
        "Below are relevant research abstracts retrieved from PubMed.\n"
        f"Context:\n{context}\n\n"
        "Using the context above, answer the question below with ONLY one capital letter: A, B, C, or D.\n\n"
        f"Question:\n{question}\n\n"
        + answer_cue
    )

    # use_adapter=True -> Enable Prompt Tuning weights
    full_text = gpt2_generate(prompt, use_adapter=True, max_new_tokens=5, do_sample=False)

    # gpt2_generate may return prompt + completion. Extract only from the text after the
    # last answer cue so letters in the prompt/context are never mistaken for the
    # prediction. If the cue is absent, rsplit falls back to the full text.
    completion = full_text.rsplit(answer_cue, 1)[-1]
    return context, extract_answer(completion)


# Layer 8: Batch Evaluation (Base vs Prompt Tuning+RAG)
def _parse_ground_truth(raw_ans):
    """Normalize a MedMCQA answer field to a letter "A"-"D", or None if unusable.

    Letters are accepted after stripping whitespace. Integers / digit strings are
    treated as 1-based option numbers (1 -> "A" ... 4 -> "D"); anything else,
    including out-of-range numbers, yields None.
    NOTE(review): some MedMCQA releases store `cop` 0-based — confirm this file is
    1-based, otherwise every label is shifted by one.
    """
    if raw_ans is None:
        return None
    text = str(raw_ans).strip()
    if text in ("A", "B", "C", "D"):
        return text
    if text.isdecimal() and 1 <= int(text) <= 4:
        return chr(ord("A") + int(text) - 1)
    return None


def evaluate_pt_rag_comparison(start_idx=0, end_idx=100, output_file=None):
    """Compare the raw base model against Prompt Tuning + RAG on MedMCQA items.

    Items with indices in [start_idx, end_idx) that load and have a usable ground-truth
    letter are answered by both systems. Each gets a per-question row and a one-line
    printed verdict. When at least one question was scored, two CSVs are written under
    ``results/``: the detailed rows (*output_file*) and ``<name>_summary<ext>``.

    Args:
        start_idx: First dataset index (inclusive).
        end_idx: Last dataset index (exclusive).
        output_file: CSV file name inside ``results/``; defaults to a timestamped name.

    Returns:
        The summary dict (Timestamp, Experiment, Range, Total, Acc_Base, Acc_PT_RAG,
        Improved, Worsened).
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if output_file is None:
        output_file = f"pt_rag_eval_{timestamp}.csv"

    output_dir = "results"
    os.makedirs(output_dir, exist_ok=True)
    full_output_path = os.path.join(output_dir, output_file)

    print("Start Evaluation: Base Model vs. Prompt Tuning + RAG")
    print(f"Range: {start_idx} -> {end_idx}")
    print(f"Results saved to: {full_output_path}")

    results = []
    total = 0
    correct_base = 0
    correct_pt_rag = 0
    # PT+RAG relative to Base: wrong -> right (improved) / right -> wrong (worsened)
    improved = 0
    worsened = 0

    for idx in range(start_idx, end_idx):
        try:
            ex = load_medmcqa_example(idx)
        except IndexError:
            if idx >= 0:
                # Past the end of the file: every later index fails the same way.
                print(f"[{idx}] Index out of range, stopping.")
                break
            continue
        except Exception as e:  # best-effort: skip malformed records, but report them
            print(f"[{idx}] Skipped: {type(e).__name__}: {e}")
            continue

        q = ex["question_text"]
        gt = _parse_ground_truth(ex["answer"])
        if gt is None:
            continue

        # Core prediction comparison
        pred_base = qa_baseline_raw(q)
        ctx_rag, pred_pt_rag = qa_pt_rag(q)

        is_correct_base = pred_base == gt
        is_correct_pt_rag = pred_pt_rag == gt
        correct_base += is_correct_base
        correct_pt_rag += is_correct_pt_rag
        total += 1

        if is_correct_pt_rag and not is_correct_base:
            status = "Improved"
            improved += 1
        elif is_correct_base and not is_correct_pt_rag:
            status = "Worsened"
            worsened += 1
        else:
            status = "Same"

        results.append({
            "Index": idx,
            "Question": q,
            "Ground_Truth": gt,
            "Pred_Base_Raw": pred_base,
            "Correct_Base_Raw": is_correct_base,
            "Pred_PT_RAG": pred_pt_rag,
            "Correct_PT_RAG": is_correct_pt_rag,
            "Status": status,
            "Retrieved_Context": ctx_rag,
        })

        print(f"[{idx}] GT:{gt} | Base:{pred_base} {'O' if is_correct_base else 'X'} | PT+RAG:{pred_pt_rag} {'O' if is_correct_pt_rag else 'X'} | {status}")

    acc_base = correct_base / total if total > 0 else 0
    acc_pt_rag = correct_pt_rag / total if total > 0 else 0

    print(f"\n=== Final Results ({total} questions) ===")
    print(f"Base GPT-2 (Raw) Accuracy   : {acc_base:.4f}")
    print(f"Prompt Tuning + RAG Accuracy: {acc_pt_rag:.4f}")
    print(f"Improved Cases: {improved}")
    print(f"Worsened Cases: {worsened}")

    summary = {
        "Timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        "Experiment": "Base_vs_PromptTuningRAG",
        "Range": f"{start_idx}-{end_idx}",
        "Total": total,
        "Acc_Base": acc_base,
        "Acc_PT_RAG": acc_pt_rag,
        "Improved": improved,
        "Worsened": worsened,
    }

    if not results:
        print("No results generated, skipping save.")
        return summary

    df = pd.DataFrame(results)
    df["Global_Acc_Base"] = f"{acc_base:.2%}"
    df["Global_Acc_PT_RAG"] = f"{acc_pt_rag:.2%}"
    df.to_csv(full_output_path, index=False, encoding="utf-8-sig")
    print(f"Detailed results saved: {full_output_path}")

    base, ext = os.path.splitext(output_file)
    full_summary_path = os.path.join(output_dir, f"{base}_summary{ext}")
    pd.DataFrame([summary]).to_csv(full_summary_path, index=False, encoding="utf-8-sig")
    print(f"Summary statistics saved: {full_summary_path}")
    return summary

# Execution Entry Point
if __name__ == "__main__":
    # Compare the base model against Prompt Tuning + RAG on MedMCQA items 0..99.
    evaluate_pt_rag_comparison(start_idx=0, end_idx=100)