In [None]:
# Cell 1: Import Dependencies & Load Llama-2 (4-bit)
import json
import os
import re
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

LLAMA2_PATH   = r"/LLM/llama2"
PRIMEKG_PATH  = r"kg.csv"
MEDMCQA_FILE  = r"/LLM/data/medmcqa/dev.json"
OUTPUT_DIR    = "eval_results_llama2"   

# RAG Parameters
MAX_RAG_ENTITIES = 3    
MAX_K_EDGES      = 5    
MAX_CTX_CHARS    = 2000 

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Config loaded. Running on {DEVICE}")

# ---------------- Model Loading ----------------
def load_llama2_4bit(model_path):
    print(f"Loading LLaMA2 4-bit from: {model_path}")
    
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )
    
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # Llama2 has no pad token; usually set to eos or unk, here we set it to eos
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto", 
    )
    print("LLaMA2 4-bit loaded!")
    return tokenizer, model

llama2_tokenizer, llama2_model = load_llama2_4bit(LLAMA2_PATH)

Config loaded. Running on cuda
Loading LLaMA2 4-bit from: /LLM/llama2


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# Cell 2: Knowledge Graph Loading and RAG Functions
def load_kg(path):
    print(f"Loading PrimeKG from {path} ...")
    if not os.path.exists(path):
        return nx.Graph()
    df = pd.read_csv(path, low_memory=False)
    G = nx.from_pandas_edgelist(df, "x_name", "y_name", edge_attr=True)
    print(f"KG Loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    return G

G = load_kg(PRIMEKG_PATH)
node_index = {str(n).lower(): n for n in G.nodes()}

BLACKLIST = {
    "rest", "left", "right", "lateral", "medial", "distal", "proximal",
    "central", "group", "zone", "growth", "tube", "head", "neck", 
    "body", "hand", "foot", "wing", "type", "part", "high", "low",
    "blood", "cell", "tissue", "disease", "drug", "pathway", "human"
}

def extract_rag_entities(question, max_entities=MAX_RAG_ENTITIES):
    q_lower = question.lower()
    ents = []
    candidates = [n for n in node_index.keys() if len(n) >= 4]
    for cand in candidates:
        if cand in BLACKLIST: continue
        if re.search(r'\b' + re.escape(cand) + r'\b', q_lower):
            ents.append(node_index[cand])
            if len(ents) >= max_entities: break
    return ents

def get_knowledge_context_single(entity, max_edges=MAX_K_EDGES):
    node = node_index.get(entity.lower())
    if not node: return None
    try:
        edges = list(G.edges(node, data=True))
    except: return None
    if not edges: return None
    sorted_edges = sorted(edges, key=lambda x: x[2].get('relation', ''), reverse=True)
    lines = []
    for u, v, attr in sorted_edges[:max_edges]:
        rel = attr.get("display_relation", attr.get("relation", "related to"))
        nbr = v if node == u else u
        lines.append(f"{node} is {rel} {nbr}.")
    return " ".join(lines)

def get_knowledge_context_multi(question):
    entities = extract_rag_entities(question)
    if not entities: return [], ""
    all_facts = []
    current_len = 0
    for ent in entities:
        ctx = get_knowledge_context_single(ent)
        if not ctx: continue
        block = f"Fact about {ent}: {ctx} "
        if current_len + len(block) > MAX_CTX_CHARS: break
        all_facts.append(block)
        current_len += len(block)
    return entities, "".join(all_facts)

Loading PrimeKG from kg.csv ...
KG Loaded: 129262 nodes, 4049405 edges


In [None]:

# Cell 3: Load MedMCQA Example

def load_medmcqa_example(idx, file_path=MEDMCQA_FILE):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
        
    with open(file_path, "r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id < idx: continue
            if line_id > idx: break 
            
            line = line.strip()
            if not line: continue
            
            data = json.loads(line)
            q = data.get("question", "")
            options_dict = {}
            mapping = {"opa":"A", "opb":"B", "opc":"C", "opd":"D"}
            for key, lab in mapping.items():
                if key in data and data[key]:
                    options_dict[lab] = str(data[key]).strip()
            
            return {
                "id": idx,
                "question": q,
                "options": options_dict,
                "answer": data.get("cop")
            }
    raise IndexError("Index out of range")

In [None]:
# Cell 4: Llama 2 PPL Scoring Function (with length normalization)
def get_llama2_score(context, question, option_text):
    """
    Calculate the plausibility score of (Context + Question + Option).
    Use labels=-100 to mask the Question, calculating Loss only for the Option.
    """
    # 1. Construct Prompt
    parts = []
    if context:
        parts.append(f"Context: {context}") # RAG knowledge as prerequisite
    parts.append(f"Question: {question}")
    parts.append("Answer:")
    
    prompt = "\n".join(parts)
    # Llama2 is sensitive to spaces; it is recommended to add a space after Answer:
    full_text = f"{prompt} {option_text}"
    
    # 2. Encoding
    prompt_enc = llama2_tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    full_enc = llama2_tokenizer(full_text, return_tensors="pt", truncation=True, max_length=2048)
    
    input_ids = full_enc.input_ids.to(DEVICE)
    attention_mask = full_enc.attention_mask.to(DEVICE)
    
    # 3. Masking (Mask the Prompt part, only calculate the score for the Option)
    labels = input_ids.clone()
    prompt_len = prompt_enc.input_ids.shape[1]
    
    if prompt_len < labels.shape[1]:
        labels[:, :prompt_len] = -100
    else:
        labels[:, :] = -100 # Exception safety
        
    # 4. Calculate Loss
    with torch.no_grad():
        outputs = llama2_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        
        # HuggingFace returns Mean Loss by default (average loss per token)
        # This naturally has a length normalization effect (much better than Sum Loss)
        mean_loss = outputs.loss.item()
    
    if torch.isnan(outputs.loss):
        return -float('inf')
    # Return negative Loss (higher score is better)
    return -mean_loss

def solve_by_scoring(question, options_dict, context=""):
    """
    Iterate through all options and select the one with the highest score.
    """
    scores = {}
    for label in ["A", "B", "C", "D"]:
        if label not in options_dict:
            scores[label] = -float('inf')
            continue
            
        opt_text = options_dict[label]
        scores[label] = get_llama2_score(context, question, opt_text)
    # Select the maximum value
    best_label = max(scores, key=scores.get)
    return best_label, scores

In [None]:
# Cell 5: Run Evaluation
def evaluate_and_save(start_idx=0, end_idx=50):
    if not os.path.exists(OUTPUT_DIR):
        os.makedirs(OUTPUT_DIR)
        
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(OUTPUT_DIR, f"llama2_scoring_{timestamp}.jsonl")
    print(f"Results will be saved to: {save_path}")
    print(f"Starting SCORING evaluation {start_idx} -> {end_idx}...\n")
    
    total = 0
    correct_base = 0
    correct_rag = 0
    
    # Answer formatting tool
    def format_gt(val):
        s = str(val).strip().upper()
        mapping = {'1':'A', '2':'B', '3':'C', '4':'D'}
        return mapping.get(s, s) if s in mapping else (s if s in ['A','B','C','D'] else None)

    with open(save_path, "w", encoding="utf-8") as f:
        for i in range(start_idx, end_idx):
            try:
                # Load
                ex = load_medmcqa_example(i)
                gt = format_gt(ex['answer'])
                if not gt: continue
                
                q_text = ex['question']
                opts = ex['options']
                
                # 1. Base Inference
                pred_base, scores_base = solve_by_scoring(q_text, opts, context="")
                
                # 2. RAG Inference
                ents, ctx = get_knowledge_context_multi(q_text)
                pred_rag, scores_rag = solve_by_scoring(q_text, opts, context=ctx)
                
                # 3. Statistics
                total += 1
                if pred_base == gt: correct_base += 1
                if pred_rag == gt: correct_rag += 1
                
                # 4. Write
                record = {
                    "id": i,
                    "question": q_text,
                    "ground_truth": gt,
                    "base_prediction": pred_base,
                    "rag_prediction": pred_rag,
                    "rag_context": ctx,
                    "is_base_correct": (pred_base == gt),
                    "is_rag_correct": (pred_rag == gt)
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                
                # Print summary
                status = ""
                if pred_rag == gt and pred_base != gt: status = "O RAG FIX"
                elif pred_rag != gt and pred_base == gt: status = "X RAG HURT"
                print(f"[{i}] GT:{gt} | Base:{pred_base} | RAG:{pred_rag} | {status}")
                
            except Exception as e:
                print(f"Skipping {i}: {e}")
                continue
            
    if total > 0:
        print(f"Total: {total}")
        print(f"Base Acc: {correct_base/total:.2%}")
        print(f"RAG  Acc: {correct_rag/total:.2%}")

# Start execution
evaluate_and_save(0, 4183)

Results will be saved to: eval_results_llama2/llama2_scoring_20251222_233932.jsonl
Starting SCORING evaluation 0 -> 4183...

[0] GT:A | Base:A | RAG:A | 
[1] GT:A | Base:B | RAG:B | 
[2] GT:C | Base:B | RAG:B | 
[3] GT:C | Base:B | RAG:B | 
[4] GT:A | Base:C | RAG:C | 
[5] GT:A | Base:B | RAG:A | ✅ RAG FIX
[6] GT:A | Base:A | RAG:A | 
[7] GT:B | Base:A | RAG:A | 
[8] GT:B | Base:C | RAG:C | 
[9] GT:B | Base:D | RAG:C | 
[10] GT:B | Base:B | RAG:B | 
[11] GT:A | Base:A | RAG:A | 
[12] GT:B | Base:A | RAG:A | 
[13] GT:A | Base:B | RAG:B | 
[14] GT:A | Base:D | RAG:D | 
[15] GT:B | Base:A | RAG:A | 
[16] GT:A | Base:C | RAG:C | 
[17] GT:D | Base:D | RAG:D | 
[18] GT:A | Base:A | RAG:A | 
[19] GT:A | Base:C | RAG:D | 
[20] GT:B | Base:A | RAG:A | 
[21] GT:C | Base:A | RAG:A | 
[22] GT:A | Base:D | RAG:D | 
[23] GT:A | Base:A | RAG:A | 
[24] GT:D | Base:C | RAG:C | 
[25] GT:B | Base:B | RAG:C | ❌ RAG HURT
[26] GT:A | Base:C | RAG:C | 
[27] GT:D | Base:D | RAG:D | 
[28] GT:A | Base:D | RAG:B

In [None]:
import json

# Configuration Section
# Replace the path below with your actual generated jsonl file path
file_path = r"eval_results_llama2/llama2_scoring_20251212_130733.jsonl"

def analyze_negative_transfer(file_path):
    total_samples = 0
    negative_cases = [] # Base Correct, RAG Incorrect
    positive_cases = [] # Base Incorrect, RAG Correct
    neutral_cases = 0   # Results are the same
    
    print(f"Analyzing file: {file_path} ...\n")

    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                try:
                    data = json.loads(line)
                    total_samples += 1
                    
                    # Assuming jsonl contains the following fields (based on typical eval output assumptions)
                    # You may need to fine-tune based on actual json keys, e.g., 'label' or 'answer_idx'
                    idx = data.get('id', total_samples)
                    question = data.get('question', '')
                    label = data.get('label', '') # Correct answer index (e.g., 0, 1, 2, 3 or A, B, C, D)
                    
                    # Get prediction results
                    base_pred = data.get('base_pred', '') 
                    rag_pred = data.get('rag_pred', '')
                    
                    # Get retrieved context (Key point)
                    context = data.get('context', 'No Context Content')
                    
                    # Simple standardized comparison (convert to string comparison)
                    is_base_correct = str(base_pred) == str(label)
                    is_rag_correct = str(rag_pred) == str(label)

                    if is_base_correct and not is_rag_correct:
                        # Record negative case: Base correct, RAG incorrect
                        negative_cases.append({
                            'id': idx,
                            'question': question,
                            'label': label,
                            'base_pred': base_pred,
                            'rag_pred': rag_pred,
                            'context': context
                        })
                    elif not is_base_correct and is_rag_correct:
                        positive_cases.append(idx)
                    else:
                        neutral_cases += 1
                        
                except json.JSONDecodeError:
                    continue

    except FileNotFoundError:
        print(f"Error: File not found {file_path}")
        return

    #Output Statistical Results
    print(f"=== Analysis Results (Total Samples: {total_samples}) ===")
    print(f"RAG Effective Fix (Base Wrong -> RAG Correct): {len(positive_cases)} cases")
    print(f"RAG Introduced Noise (Base Correct -> RAG Wrong): {len(negative_cases)} cases (Key Focus)")
    print(f"No Change or Both Wrong: {neutral_cases} cases")
    
    rag_drop_rate = (len(negative_cases) / total_samples) * 100 if total_samples > 0 else 0
    print(f"Negative Transfer Rate: {rag_drop_rate:.2f}%")
    
    print("\n" + "="*50)
    print("Detailed Check: Why did RAG cause errors? (Showing first 5 cases)")
    print("="*50)

    for i, case in enumerate(negative_cases[:5]): # Print only the first 5 bad cases
        print(f"\n[Case ID: {case['id']}]")
        print(f"Question: {case['question']}")
        print(f"Correct Answer: {case['label']}")
        print(f"Base Prediction: {case['base_pred']} (Correct)")
        print(f"RAG  Prediction: {case['rag_pred']} (Incorrect)")
        print(f"Retrieved Context (Possible Noise):")
        print(f"   > {case['context'][:300]}...") # Show only first 300 characters
        print("-" * 50)

if __name__ == "__main__":
    analyze_negative_transfer(file_path)

In [None]:
import os
import json
import re
import pickle
import numpy as np
import pandas as pd
from datetime import datetime
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

import faiss
import pyarrow.parquet as pq
from sentence_transformers import SentenceTransformer

In [None]:
LLAMA2_PATH = r"llama2"

In [None]:
# MedMCQA dev.json (Change this to your actual path)
MEDMCQA_FILE = r"medmcqa/dev.json"

# PubMedQA train.parquet (Used to build the paper knowledge base)
PARQUET_PATH = r"pubmedqa_hf/pqa_labeled_splits/train.parquet"

# Your specified new output path (Knowledge Base files)
OUT_INDEX = r"PrimeKG/pubmed_qa.index"
OUT_DOCS  = r"PrimeKG/pubmed_documents.pkl"

# Evaluation output directory (Saves llama2_scoring_xxx.jsonl)
OUTPUT_DIR = r"PrimeKG/eval_results_llama2_paperRAG"

In [None]:
# RAG Parameters (Paper Repository)
TOP_K_DOCS = 1          # Strongly recommended to start with 1
MAX_CTX_CHARS = 800     # Strongly recommended to start with 600~900; do not make it too long
# Embedding Model (Must match the one used during index construction)
# If the model isn't cached in your offline environment, the download will fail; 
# in that case, replace this with the local path to your existing SentenceTransformer model.
EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("DEVICE:", DEVICE)

DEVICE: cuda


In [None]:
# Cell 2: load Llama-2 4bit
def load_llama2_4bit(model_path):
    print(f"Loading LLaMA2 4-bit from: {model_path}")

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
    )

    print("LLaMA2 4-bit loaded!")
    return tokenizer, model

llama2_tokenizer, llama2_model = load_llama2_4bit(LLAMA2_PATH)

Loading LLaMA2 4-bit from: /LLM/llama2


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

LLaMA2 4-bit loaded!


In [None]:
# Cell 3: load pubmed_qa.index + pubmed_documents.pkl
def build_paper_kb(parquet_path, out_index, out_docs, embed_model_name):
    print("[KB] Building paper KB from:", parquet_path)
    df = pq.read_table(parquet_path).to_pandas()

    documents = []
    for raw_ctx in df["context"]:
        try:
            if hasattr(raw_ctx, "get"):  # dict: {"contexts":[...]}
                text_list = raw_ctx.get("contexts", [])
            else:                        # list
                text_list = raw_ctx
            full_text = " ".join([str(x) for x in text_list if x])
            documents.append(full_text)
        except:
            documents.append("")

    print(f"[KB] Extracted docs: {len(documents)}")

    print("[KB] Embedding docs...")
    embed_model = SentenceTransformer(embed_model_name)
    emb = embed_model.encode(documents, convert_to_numpy=True, show_progress_bar=True)

    # cosine：normalize + IP
    faiss.normalize_L2(emb)
    index = faiss.IndexFlatIP(emb.shape[1])
    index.add(emb)

    os.makedirs(os.path.dirname(out_index), exist_ok=True)
    faiss.write_index(index, out_index)
    with open(out_docs, "wb") as f:
        pickle.dump(documents, f)

    print("[KB] Saved index:", out_index)
    print("[KB] Saved docs :", out_docs)


[KB] Loading index/docs...
[KB] Loaded. docs: 800


In [8]:
def load_paper_kb(out_index, out_docs, embed_model_name):
    print("[KB] Loading index/docs...")
    if not os.path.exists(out_index) or not os.path.exists(out_docs):
        return None

    index = faiss.read_index(out_index)
    with open(out_docs, "rb") as f:
        documents = pickle.load(f)
    embed_model = SentenceTransformer(embed_model_name)

    print("[KB] Loaded. docs:", len(documents))
    return index, documents, embed_model

kb = load_paper_kb(OUT_INDEX, OUT_DOCS, EMBED_MODEL_NAME)
if kb is None:
    build_paper_kb(PARQUET_PATH, OUT_INDEX, OUT_DOCS, EMBED_MODEL_NAME)
    kb = load_paper_kb(OUT_INDEX, OUT_DOCS, EMBED_MODEL_NAME)

paper_index, paper_docs, paper_embed = kb

[KB] Loading index/docs...
[KB] Loaded. docs: 800


In [None]:

# Cell 4: RAG Functions for Paper Retrieval

def retrieve_papers(query, top_k=TOP_K_DOCS):
    q_emb = paper_embed.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    scores, idxs = paper_index.search(q_emb, top_k)
    idxs = idxs[0].tolist()
    return [paper_docs[i] for i in idxs if 0 <= i < len(paper_docs)]

def get_paper_context(question, options_dict=None, top_k=TOP_K_DOCS, max_chars=MAX_CTX_CHARS):
    queries = [question]
    if options_dict:
        for lab, txt in options_dict.items():
            if txt:
                queries.append(f"{question} {txt}")

    seen = set()
    blocks = []
    cur = 0

    for q in queries:
        for d in retrieve_papers(q, top_k=top_k):
            d = (d or "").strip()
            if not d:
                continue
            key = d[:200]
            if key in seen:
                continue
            seen.add(key)

            block = f"Paper: {d}\n"
            if cur + len(block) > max_chars:
                return [], "".join(blocks)
            blocks.append(block)
            cur += len(block)

        if cur >= max_chars:
            break

    return [], "".join(blocks)


In [None]:

# Cell 5: load MedMCQA (return options dict)
def load_medmcqa_example(idx, file_path=MEDMCQA_FILE):
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "r", encoding="utf-8") as f:
        for line_id, line in enumerate(f):
            if line_id < idx:
                continue
            if line_id > idx:
                break

            line = line.strip()
            if not line:
                continue

            data = json.loads(line)
            q = data.get("question", "")

            options_dict = {}
            mapping = {"opa": "A", "opb": "B", "opc": "C", "opd": "D"}
            for key, lab in mapping.items():
                if key in data and data[key]:
                    options_dict[lab] = str(data[key]).strip()

            return {
                "id": idx,
                "question": q,
                "options": options_dict,
                "answer": data.get("cop")
            }

    raise IndexError("Index out of range")


In [None]:
# Cell 6:  Llama 2 PPL Scoring Function (with length normalization)
def get_llama2_score(context, question, option_text):
    sys_msg = "You are a medical expert. Choose the correct option."

    if context:
        user_content = f"Context: {context}\n\nQuestion: {question}"
    else:
        user_content = f"Question: {question}"

    prompt_prefix = (
        f"[INST] <<SYS>>\n{sys_msg}\n<</SYS>>\n\n{user_content} [/INST] "
        f"The correct answer is"
    )
    full_text = f"{prompt_prefix} {option_text}"

    full_enc = llama2_tokenizer(full_text, return_tensors="pt", truncation=True, max_length=2048)
    prompt_enc = llama2_tokenizer(prompt_prefix, return_tensors="pt", add_special_tokens=True)

    input_ids = full_enc.input_ids.to(DEVICE)
    attention_mask = full_enc.attention_mask.to(DEVICE)

    labels = input_ids.clone()
    prompt_len = prompt_enc.input_ids.shape[1]

    if prompt_len >= input_ids.shape[1]:
        return -float("inf")

    labels[:, :prompt_len] = -100

    with torch.no_grad():
        outputs = llama2_model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss

    if torch.isnan(loss):
        return -float("inf")

    return -loss.item()

def solve_by_scoring(question, options_dict, context=""):
    scores = {}
    for label in ["A", "B", "C", "D"]:
        if label not in options_dict:
            scores[label] = -float("inf")
            continue
        scores[label] = get_llama2_score(context, question, options_dict[label])

    best_label = max(scores, key=scores.get)
    return best_label, scores


In [None]:

# Cell 7: # Run Evaluation with Paper-RAG

def evaluate_and_save(start_idx=0, end_idx=50):
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(OUTPUT_DIR, f"llama2_scoring_{timestamp}.jsonl")
    print(f"Results will be saved to: {save_path}")
    print(f"Starting SCORING evaluation {start_idx} -> {end_idx}...\n")

    total = 0
    correct_base = 0
    correct_rag = 0

    def format_gt(val):
        s = str(val).strip().upper()
        mapping = {"1": "A", "2": "B", "3": "C", "4": "D"}
        if s in mapping:
            return mapping[s]
        if s in ["A", "B", "C", "D"]:
            return s
        return None

    with open(save_path, "w", encoding="utf-8") as f:
        for i in range(start_idx, end_idx):
            try:
                ex = load_medmcqa_example(i)
                gt = format_gt(ex["answer"])
                if not gt:
                    continue

                q_text = ex["question"]
                opts = ex["options"]

                # Base
                pred_base, scores_base = solve_by_scoring(q_text, opts, context="")

                # Paper-RAG
                ents, ctx = get_paper_context(q_text, opts, top_k=TOP_K_DOCS, max_chars=MAX_CTX_CHARS)
                pred_rag, scores_rag = solve_by_scoring(q_text, opts, context=ctx)

                total += 1
                if pred_base == gt:
                    correct_base += 1
                if pred_rag == gt:
                    correct_rag += 1

                status = ""
                if pred_rag == gt and pred_base != gt:
                    status = "O RAG FIX"
                elif pred_rag != gt and pred_base == gt:
                    status = "X RAG HURT"

                if i < start_idx + 3:
                    print(f"\n[DEBUG ID {i}]")
                    print("Q:", q_text[:120], "...")
                    print("GT:", gt, "| Base:", pred_base, "| RAG:", pred_rag, status)
                    print("Base scores:", {k: round(v, 4) for k, v in scores_base.items()})
                    print("RAG  scores:", {k: round(v, 4) for k, v in scores_rag.items()})
                    print("CTX(head):", (ctx[:300] + "...") if ctx else "<EMPTY>")

                record = {
                    "id": i,
                    "question": q_text,
                    "ground_truth": gt,
                    "base_prediction": pred_base,
                    "rag_prediction": pred_rag,
                    "rag_context": ctx,
                    "scores_base": scores_base,
                    "scores_rag": scores_rag,
                    "is_base_correct": (pred_base == gt),
                    "is_rag_correct": (pred_rag == gt),
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")

                if i >= start_idx + 3:
                    print(f"[{i}] GT:{gt} | Base:{pred_base} | RAG:{pred_rag} | {status}")

            except Exception as e:
                print(f"Skipping {i}: {e}")
                continue

    if total > 0:
        print("\n" + "=" * 30)
        print(f"Total    : {total}")
        print(f"Base Acc : {correct_base/total:.2%}")
        print(f"RAG  Acc : {correct_rag/total:.2%}")
        print("=" * 30 + "\n")

    return save_path


In [None]:

# Cell 8: Start Execution
save_path = evaluate_and_save(0, 1500)
print("Saved:", save_path)


Results will be saved to: /LLM/PrimeKG/eval_results_llama2_paperRAG/llama2_scoring_20251222_172524.jsonl
Starting SCORING evaluation 0 -> 1500...


[DEBUG ID 0]
Q: Which of the following is not true for myelinated nerve fibers: ...
GT: A | Base: A | RAG: A 
Base scores: {'A': -2.7494, 'B': -4.2843, 'C': -3.8867, 'D': -3.4827}
RAG  scores: {'A': -2.7494, 'B': -4.2843, 'C': -3.8867, 'D': -3.4827}
CTX(head): <EMPTY>

[DEBUG ID 1]
Q: Which of the following is not true about glomerular capillaries') ...
GT: A | Base: D | RAG: D 
Base scores: {'A': -3.988, 'B': -3.7157, 'C': -4.1374, 'D': -3.5972}
RAG  scores: {'A': -3.988, 'B': -3.7157, 'C': -4.1374, 'D': -3.5972}
CTX(head): <EMPTY>

[DEBUG ID 2]
Q: A 29 yrs old woman with a pregnancy of 17 week has a 10 years old boy with down syndrome. She does not want another down ...
GT: C | Base: A | RAG: A 
Base scores: {'A': -4.3582, 'B': -4.4772, 'C': -5.0548, 'D': -6.5388}
RAG  scores: {'A': -4.3582, 'B': -4.4772, 'C': -5.0548, 'D': -6.5388}
CTX(h