In [None]:
import json
import os
import re
import pandas as pd
import networkx as nx
import torch
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

LLAMA2_PATH   = "C:\LLM\llama2"
PRIMEKG_PATH  = "kg.csv"
MEDMCQA_FILE  = "C:\LLM\data\medmcqa\dev.json"

#RAG parameters
MAX_RAG_ENTITIES = 3
MAX_K_EDGES      = 8
MAX_CTX_CHARS    = 3000

print("CONFIG OK. DEVICE =", DEVICE)

  from .autonotebook import tqdm as notebook_tqdm


CONFIG OK. DEVICE = cuda


In [None]:
# 4bit
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

def load_llama2_4bit(model_path: str = LLAMA2_PATH):
    print(f" Loading LLaMA2 4-bit from: {model_path}")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        quantization_config=bnb_config,
        device_map="auto",
    )

    print("LLaMA2 4-bit loaded!")
    return tokenizer, model

llama2_tokenizer, llama2_model = load_llama2_4bit()



def llama2_generate(prompt: str,
                    max_new_tokens: int = 100,
                    do_sample: bool = False,
                    temperature: float = 0.0,
                    top_p: float = 1.0):
    inputs = llama2_tokenizer(prompt, return_tensors="pt").to(llama2_model.device)

    if inputs["input_ids"].shape[1] > 900:
        for k in inputs:
            inputs[k] = inputs[k][:, -900:]

    with torch.no_grad():
        outputs = llama2_model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            no_repeat_ngram_size=3,
            pad_token_id=llama2_tokenizer.eos_token_id,
        )

    return llama2_tokenizer.decode(outputs[0], skip_special_tokens=True)


In [None]:

# 2. LOAD PRIMEKG GRAPH & OPTIMIZED RETRIEVAL

def load_kg(path):
    print(f"Loading PrimeKG from {path} ...")
    df = pd.read_csv(path, low_memory=False)
    G = nx.from_pandas_edgelist(df, "x_name", "y_name", edge_attr=True)
    print(f"KG Loaded: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")
    return G

G = load_kg(PRIMEKG_PATH)

# lower-case lookup
node_index = {str(n).lower(): n for n in G.nodes()}

def resolve_node(name):
    if not name: return None
    return node_index.get(name.lower())

def get_knowledge_context_single(entity, max_edges=MAX_K_EDGES):
    node = resolve_node(entity)
    if node is None: return None

    edges = list(G.edges(node, data=True))
    if not edges: return None


    priority_relations = {
        "indication": 0,
        "contraindication": 0,
        "disease_may_be_treated_by_drug": 0,
        "drug_may_treat_disease": 0,
        "causes": 1,
        "may_cause": 1,
        "synergistic_interaction": 2,
    }
    

    def get_sort_key(edge):
        attr = edge[2]
        rel = attr.get("display_relation", attr.get("relation", "related_to"))
        return priority_relations.get(rel, 99) 

    sorted_edges = sorted(edges, key=get_sort_key)

    lines = []
    for u, v, attr in sorted_edges[:max_edges]:
        rel = attr.get("display_relation", attr.get("relation", "related_to"))
        x_type = attr.get("x_type", "")
        y_type = attr.get("y_type", "")

        if node == u:
            nbr, t = v, y_type
        else:
            nbr, t = u, x_type

        if not t: t = "Entity"
        lines.append(f"{node} ({x_type}) is {rel} {nbr} ({t}).")

    return " ".join(lines)

In [None]:
import re

BLACKLIST = {
    "rest", "left", "right", "lateral", "medial", "distal", "proximal",
    "central", "group", "zone", "growth", "tube", "head", "neck", 
    "body", "hand", "foot", "wing", "type", "part", "high", "low",
    "blood", "cell", "tissue", "disease", "drug", "pathway", "membrane"
}

def extract_rag_entities(question, max_entities=MAX_RAG_ENTITIES):
    q_lower = question.lower()
    ents = []
    
    # Only consider entities with length >= 4
    candidates = [n for n in node_index.keys() if len(n) >= 4]
    
    for cand in candidates:
        # 1. Blacklist filtering
        if cand in BLACKLIST:
            continue

        # 2. Whole Word Matching
        if re.search(r'\b' + re.escape(cand) + r'\b', q_lower):
            ents.append(node_index[cand])
            if len(ents) >= max_entities:
                break
    return ents

def get_knowledge_context_multi(question):
    entities = extract_rag_entities(question)
    if not entities:
        return [], ""

    all_facts = []
    total_chars = 0

    for ent in entities:
        ctx = get_knowledge_context_single(ent)
        if not ctx: continue

        block = f"Fact about {ent}: {ctx}"
        if total_chars + len(block) > MAX_CTX_CHARS:
            break

        all_facts.append(block)
        total_chars += len(block)

    return entities, " ".join(all_facts)

In [None]:

# 4. LOAD MedMCQA
def load_medmcqa_example(idx, file_path=MEDMCQA_FILE):
    with open(file_path, "r", encoding="utf-8") as f:
        line_id = 0
        for line in f:
            line = line.strip()
            if not line:
                continue

            if line_id == idx:
                data = json.loads(line)

                q = data.get("question", "")
                options_text = []
                mapping = {"opa":"A", "opb":"B", "opc":"C", "opd":"D"}

                for key, lab in mapping.items():
                    if key in data:
                        options_text.append(f"{lab}) {data[key]}")

                question_text = q + "\nOptions:\n" + "\n".join(options_text)
                return {
                    "raw": data,
                    "question_text": question_text,
                    "answer": data.get("cop")
                }

            line_id += 1

    raise IndexError("Index out of range")

In [None]:

# 5. LLaMA2 QA: no RAG 
def qa_no_rag_llama2(question):
    # Llama 2 Chat
    sys_msg = (
        "You are a medical exam expert. \n"
        "You will receive a multiple-choice question. \n"
        "Think step by step and choose the single best answer. \n"
        "Output ONLY the option letter (A, B, C, or D) at the very end."
    )
    
    user_msg = (
        f"Question:\n{question}\n\n"
        "Choose the correct option. Reply with just the letter."
    )
    
    prompt = f"[INST] <<SYS>>\n{sys_msg}\n<</SYS>>\n\n{user_msg} [/INST]\nAnswer:"

    # Slightly increasing tokens allows the model to think briefly.
    out = llama2_generate(prompt, max_new_tokens=20) 

    # Analysis of Logical Optimization: Find the last option that appears
    matches = re.findall(r'\b([A-D])\b', out.upper())
    if matches:
        return matches[-1]
    
    return out.strip()[-1:] if out else "C" # Fallback

In [None]:

# 6. LLaMA2 QA: RAG
def qa_with_rag_all_llama2(question):
    entities, ctx = get_knowledge_context_multi(question)
    
    # If no entities or context found, fallback to no-RAG QA
    if not entities or not ctx:
        return [], "", qa_no_rag_llama2(question)

    sys_msg = (
        "You are a medical exam expert.\n"
        "You have access to the following retrieved biomedical knowledge facts.\n"
        "Use these facts to help answer the question if they are relevant.\n"
        "Think step by step.\n"
        "Output ONLY the option letter (A, B, C, or D) at the very end."
    )
    
    user_msg = (
        f"Retrieved Facts:\n{ctx}\n\n"
        f"Question:\n{question}\n\n"
        "Based on the facts and your knowledge, choose the correct option. Reply with just the letter."
    )
    
    prompt = f"[INST] <<SYS>>\n{sys_msg}\n<</SYS>>\n\n{user_msg} [/INST]\nAnswer:"

    out = llama2_generate(prompt, max_new_tokens=20)

    matches = re.findall(r'\b([A-D])\b', out.upper())
    if matches:
        return entities, ctx, matches[-1]

    return entities, ctx, out.strip()[-1:] if out else "C"

In [None]:

# 7. LLaMA2 ACC 
def evaluate_medmcqa_acc_llama2_save(start_idx=0, end_idx=50, output_dir="eval_results_llama2"):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path = os.path.join(output_dir, f"llama2_results_{timestamp}.jsonl")
    
    print(f"Results will be saved to: {save_path}")
    print(f"Starting Llama 2 evaluation {start_idx} -> {end_idx}...\n")

    total = 0
    correct_no = 0
    correct_rag = 0
    improvements = 0
    hurts = 0


    def to_letter(x):
        if x is None: return None
        x = str(x).strip().upper()
        if x in ["A","B","C","D"]:
            return x
        if x.isdigit():
            i = int(x)
            if 1 <= i <= 4:
                return chr(ord("A")+i-1)
        return None


    with open(save_path, "w", encoding="utf-8") as f:
        for i in range(start_idx, end_idx):
            try:
                # load data
                ex = load_medmcqa_example(i)
            except Exception as e:
                print(f"Skipping index {i}: {e}")
                continue

            q = ex["question_text"]
            gt = to_letter(ex["answer"])
            
            if gt is None:
                continue

            pred_no = qa_no_rag_llama2(q)
            ents, ctx, pred_rag = qa_with_rag_all_llama2(q)


            total += 1
            is_correct_no = (pred_no == gt)
            is_correct_rag = (pred_rag == gt)

            if is_correct_no: correct_no += 1
            if is_correct_rag: correct_rag += 1

            if not is_correct_no and is_correct_rag:
                improvements += 1
            if is_correct_no and not is_correct_rag:
                hurts += 1

            status = ""
            if not is_correct_no and is_correct_rag: status = "Fixed"
            elif is_correct_no and not is_correct_rag: status = "Hurt"
            
            print(f"[{i}] GT={gt} | NO={pred_no} | RAG={pred_rag} | Ents={ents} | {status}")

            # save to JSONL
            record = {
                "id": i,
                "question": q,
                "ground_truth": gt,
                "pred_no_rag": pred_no,
                "pred_rag": pred_rag,
                "is_correct_no_rag": is_correct_no,
                "is_correct_rag": is_correct_rag,
                "retrieved_entities": ents,
                "rag_context": ctx
            }
            f.write(json.dumps(record, ensure_ascii=False) + "\n")

    if total == 0:
        print("no valid samples")
        return None

    #print final summary
    summary_text = (
        f"\n{'='*30}\n"
        f"Finished! Saved to: {save_path}\n"
        f"LLaMA2 Total: {total}\n"
        f"NO-RAG ACC  : {correct_no/total:.4f}\n"
        f"RAG ACC     : {correct_rag/total:.4f}\n"
        f"Fix by RAG  : {improvements}\n"
        f"Hurt by RAG : {hurts}\n"
        f"{'='*30}\n"
    )
    print(summary_text)

    return {
        "total": total,
        "acc_no": correct_no/total,
        "acc_rag": correct_rag/total,
        "improve": improvements,
        "hurt": hurts,
        "save_path": save_path
    }

In [15]:
evaluate_medmcqa_acc_llama2_save(0, 100)

Results will be saved to: eval_results_llama2\llama2_results_20251211_152250.jsonl
Starting Llama 2 evaluation 0 -> 100...

[0] GT=A | NO=B | RAG=D | Ents=['myelin sheath', 'myelin', 'nerve'] | 
[1] GT=A | NO=D | RAG=D | Ents=[] | 
[2] GT=C | NO=B | RAG=B | Ents=['Down syndrome', 'amniotic fluid'] | 
[3] GT=C | NO=B | RAG=B | Ents=['axonal transport', 'transport'] | 
[4] GT=A | NO=C | RAG=C | Ents=['Glucagon', 'Gluconeogenesis', 'Glycogen synthesis'] | 
[5] GT=A | NO=C | RAG=B | Ents=['Tropicamide'] | 
[6] GT=A | NO=C | RAG=B | Ents=['Oseltamivir', 'influenza', 'throat'] | 
[7] GT=B | NO=C | RAG=C | Ents=[] | 
[8] GT=B | NO=D | RAG=B | Ents=['Electrical alternans', 'P pulmonale'] | Fixed
[9] GT=B | NO=D | RAG=C | Ents=['Tetralogy of Fallot', 'Cyanosis', 'transposition'] | 
[10] GT=B | NO=C | RAG=B | Ents=['dental caries', 'enamel'] | Fixed
[11] GT=A | NO=B | RAG=B | Ents=['Heparin'] | 
[12] GT=B | NO=B | RAG=B | Ents=['pre-Botzinger complex'] | 
[13] GT=A | NO=B | RAG=B | Ents=['apraxi

{'total': 100,
 'acc_no': 0.27,
 'acc_rag': 0.3,
 'improve': 11,
 'hurt': 8,
 'save_path': 'eval_results_llama2\\llama2_results_20251211_152250.jsonl'}