In [None]:
# =========================
# 0) CONFIG (edit if needed)
# =========================
import os, math, random, textwrap
import pandas as pd
import torch
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

# Base model and where you saved the LoRA adapter.
# ADAPTER_DIR is also used as the tokenizer source when the model is loaded.
BASE_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
ADAPTER_DIR = r".\qwen25_1p5b_medqa_lora_cpu\adapter"   # <- from your logs
# CSV read by the overfitting check; it must contain the columns
# "qtype", "Question" and "Answer".
DATA_PATH = r"YOUR DIRECTORY to the Dataset\Medical_QA_Dataset.csv"

# System message placed in front of every user question.
SYSTEM_PROMPT = (
    "You are a careful medical information assistant. Provide general educational information, "
    "not personal medical advice. Encourage consulting qualified clinicians for diagnosis and treatment. "
    "If symptoms suggest an emergency, advise seeking urgent care. If unsure, say you don't know."
)

# CPU performance: use all cores (optional, best-effort).
# Falls back to 8 threads when os.cpu_count() cannot determine the core count.
try:
    torch.set_num_threads(os.cpu_count() or 8)
except Exception as exc:
    # Thread tuning is only an optimisation, so keep going -- but report the
    # failure instead of hiding it, so slow inference can be diagnosed.
    print(f"[WARN] Could not set torch thread count: {exc}")

# Reproducibility for generation (optional): seeds Python's and torch's RNGs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)


# =========================
# 1) LOAD Tokenizer + Model
# =========================
def load_model_cpu(model_name: str):
    """
    Load a causal-LM checkpoint in float32 without remote code.

    The precision keyword was renamed between Transformers releases, so the
    newer spelling (dtype=...) is attempted first and the older one
    (torch_dtype=...) is used as a fallback if that attempt raises.

    NOTE(review): the fallback triggers on *any* exception from the first
    attempt (not only an unsupported-keyword error), so a failed download is
    retried once before the error surfaces.
    """
    common = dict(trust_remote_code=False)
    try:
        loaded = AutoModelForCausalLM.from_pretrained(model_name, dtype=torch.float32, **common)
    except Exception:
        loaded = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32, **common)
    return loaded

# Load tokenizer from adapter folder (so it matches what you saved)
tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR, trust_remote_code=False)
# Reuse EOS as the padding token when the tokenizer does not define one;
# pad_token_id is passed to generate() later on.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# NOTE(review): generate_answer() tokenizes one prompt at a time, so no padding
# is actually inserted there; revisit this setting before batching prompts.
tokenizer.padding_side = "right"

# Load base model + attach adapter
base_model = load_model_cpu(BASE_MODEL_NAME)
model_ft = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
# Switch to evaluation mode; this script only runs inference.
model_ft.eval()

print("[OK] Loaded fine-tuned model (base + LoRA adapter).")


# =========================
# 2) PROMPT + GENERATION
# =========================
def build_user_content(question: str, qtype: str | None = None) -> str:
    if qtype and str(qtype).strip() and str(qtype).lower() not in {"nan", "none"}:
        return f"Question type: {qtype}\n\nQuestion: {question}"
    return question

def build_chat_text(question: str, qtype: str | None = None) -> str:
    """
    Render the system + user messages into one prompt string.

    The tokenizer's chat template is applied without tokenizing and with the
    generation prompt appended, so the returned text ends where the
    assistant's reply should begin.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": build_user_content(question, qtype=qtype)}
    # IMPORTANT: same pattern as training (add_generation_prompt=True)
    return tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )

@torch.inference_mode()
def generate_answer(
    model,
    question: str,
    qtype: str | None = None,
    max_new_tokens: int = 256,
    do_sample: bool = False,      # False = deterministic (greedy)
    temperature: float = 0.7,     # used only if do_sample=True
    top_p: float = 0.9,           # used only if do_sample=True
    repetition_penalty: float = 1.05,
) -> str:
    """
    Generate a reply to a single question and return it as stripped text.

    The question is wrapped in the chat template, tokenized as a batch of
    one, moved to the CPU and passed to ``model.generate``. ``temperature``
    and ``top_p`` are forwarded only when ``do_sample`` is true. Only the
    tokens produced after the prompt are decoded, with special tokens removed.
    """
    prompt = build_chat_text(question, qtype=qtype)
    encoded = tokenizer([prompt], return_tensors="pt")
    # CPU inference
    inputs = {name: tensor.to("cpu") for name, tensor in encoded.items()}

    # Sampling knobs are only passed along when sampling is actually enabled.
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}

    generated = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        repetition_penalty=repetition_penalty,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        **sampling_kwargs,
    )

    # The output sequence starts with the prompt; keep only what follows it.
    prompt_length = inputs["input_ids"].shape[1]
    completion_ids = generated[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


# ============================================
# 3) QUICK MANUAL TESTS (YOUR MODEL ONLY)
# ============================================
# (qtype, question) pairs used as a smoke test of the fine-tuned model.
test_questions = [
    ("general", "What is hypertension?"),
    ("symptoms", "What are common symptoms of influenza (flu)?"),
    ("treatment", "How is type 2 diabetes commonly treated?"),
    ("safety", "I have severe chest pain and shortness of breath. What should I do?"),
    ("out_of_domain", "Write a Python function to sort a list."),  # should still answer, though quality may be lower
]

print("\n======================")
print("Fine-tuned model samples")
print("======================")
# Greedy decoding (do_sample=False) so repeated runs print the same answers.
for qt, q in test_questions:
    ans = generate_answer(model_ft, q, qtype=qt, max_new_tokens=200, do_sample=False)
    # One record per question, followed by an 80-character separator line.
    print(f"\nQTYPE: {qt}\nQ: {q}\nA: {ans}\n{'-'*80}")


# ==========================================================
# 4) BASE vs FINE-TUNED COMPARISON (NO EXTRA MEMORY)
#    We try to disable adapter temporarily for "base" output.
# ==========================================================
def can_disable_adapter(m) -> bool:
    """Report whether ``m`` exposes a ``disable_adapter`` attribute."""
    try:
        m.disable_adapter
    except AttributeError:
        return False
    return True

@torch.inference_mode()
def compare_base_vs_ft(
    question: str,
    qtype: str | None = None,
    max_new_tokens: int = 220,
):
    """
    Print the answer to one question with the adapter off and with it on.

    Both answers come from the same ``model_ft`` object using greedy
    decoding; the "base" answer is generated inside
    ``model_ft.disable_adapter()``. When that context manager is missing, or
    anything inside it raises, a bracketed explanatory message is printed in
    place of the base answer. Nothing is returned.
    """
    decode_opts = dict(qtype=qtype, max_new_tokens=max_new_tokens, do_sample=False)

    ft = generate_answer(model_ft, question, **decode_opts)

    if not can_disable_adapter(model_ft):
        base = "[This PEFT version does not support disable_adapter() easily.]"
    else:
        try:
            with model_ft.disable_adapter():
                base = generate_answer(model_ft, question, **decode_opts)
        except Exception as e:
            base = f"[Could not disable adapter in this PEFT version: {e}]"

    print("\n======================")
    print("BASE vs FINE-TUNED")
    print("======================")
    print(f"Q: {question}\n")
    print("---- BASE (adapter OFF) ----")
    print(base)
    print("\n---- FINE-TUNED (adapter ON) ----")
    print(ft)

# Example comparison: prints the base and fine-tuned answers for one question.
# This runs two full generations (adapter on, then off).
compare_base_vs_ft("What causes migraine headaches, and what are common treatments?", qtype="treatment")


# ==========================================================
# 5) QUICK OVERFITTING SANITY CHECK (Train vs Eval ROUGE-L)
#    - small sample size (CPU-friendly)
#    - ROUGE-L is not perfect, but helps detect big train>eval gap
# ==========================================================
def rouge_l_f1(pred: str, ref: str) -> float:
    """
    Return the ROUGE-L F1 score between a prediction and a reference.

    Both strings are lower-cased and split on whitespace. The score is the
    harmonic mean of LCS/len(pred tokens) (precision) and LCS/len(ref tokens)
    (recall), where LCS is the length of the longest common subsequence of
    the two token lists. Returns 0.0 when either side has no tokens or when
    the two share no token.
    """
    # Simple whitespace tokenization
    pred_tokens = pred.lower().split()
    ref_tokens = ref.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0

    # LCS by dynamic programming. Each row depends only on the row before it,
    # so two rolling rows are kept instead of the full table: time stays
    # O(n*m) but memory drops to O(m).
    prev_row = [0] * (len(ref_tokens) + 1)
    for pred_tok in pred_tokens:
        curr_row = [0]
        for j, ref_tok in enumerate(ref_tokens):
            if pred_tok == ref_tok:
                curr_row.append(prev_row[j] + 1)
            else:
                curr_row.append(max(prev_row[j + 1], curr_row[j]))
        prev_row = curr_row

    lcs = prev_row[-1]
    if lcs == 0:
        return 0.0

    # Both token lists are non-empty here, so the divisions are safe.
    prec = lcs / len(pred_tokens)
    rec = lcs / len(ref_tokens)
    return (2 * prec * rec) / (prec + rec)

def load_split_dataset(data_path: str, test_size: float = 0.05, seed: int = 42) -> tuple[Dataset, Dataset]:
    """
    Read the QA CSV and return a seeded ``(train, test)`` pair of datasets.

    Only the "qtype", "Question" and "Answer" columns are kept. Missing
    values become empty strings, every value is cast to ``str``, and rows
    whose question or answer is blank after stripping are dropped before the
    split is made.
    """
    columns = ["qtype", "Question", "Answer"]
    frame = pd.read_csv(data_path)[columns].copy()
    for column in columns:
        frame[column] = frame[column].fillna("").astype(str)

    has_question = frame["Question"].str.strip() != ""
    has_answer = frame["Answer"].str.strip() != ""
    frame = frame[has_question & has_answer].reset_index(drop=True)

    parts = Dataset.from_pandas(frame, preserve_index=False).train_test_split(
        test_size=test_size,
        seed=seed,
    )
    return parts["train"], parts["test"]

def quick_rouge_eval(ds: Dataset, n: int = 20, seed: int = 0, max_new_tokens: int = 180):
    """
    Score the fine-tuned model on a random sample of ``ds`` with ROUGE-L F1.

    Up to ``n`` examples are picked by shuffling all indices with a private
    ``random.Random(seed)`` and taking the first ``n``. Each one is answered
    by ``model_ft`` with greedy decoding and compared to its reference answer.

    Returns ``(mean_score, rows)`` where ``rows`` is a list of
    ``(score, qtype, question, prediction, reference)`` tuples sorted by
    ascending score; the mean is 0 when nothing was sampled.
    """
    order = list(range(len(ds)))
    random.Random(seed).shuffle(order)
    chosen = order[:min(n, len(ds))]

    records = []
    for index in chosen:
        example = ds[index]
        qtype = example.get("qtype", "")
        question = example.get("Question", "")
        reference = example.get("Answer", "")

        prediction = generate_answer(
            model_ft,
            question,
            qtype=qtype,
            max_new_tokens=max_new_tokens,
            do_sample=False,   # deterministic for eval
        )
        score = rouge_l_f1(prediction, reference)
        records.append((score, qtype, question, prediction, reference))

    # Average in sampling order, then order the rows worst-first for reporting.
    mean_score = sum(record[0] for record in records) / max(len(records), 1)
    records.sort(key=lambda record: record[0])

    return mean_score, records

# Split the CSV into train/eval parts with a fixed seed.
# NOTE(review): the comparison below is only meaningful if training used this
# same split (test_size=0.05, seed=42) -- confirm against the training script.
train_ds, eval_ds = load_split_dataset(DATA_PATH, test_size=0.05, seed=42)

N = 15  # keep small on CPU; increase if you want
# Different seeds, so the two samples are drawn independently.
train_mean, train_rows = quick_rouge_eval(train_ds, n=N, seed=1)
eval_mean, eval_rows = quick_rouge_eval(eval_ds, n=N, seed=2)

print("\n======================")
print("Quick Overfitting Check (ROUGE-L F1)")
print("======================")
# The printed n is the requested sample size; quick_rouge_eval uses fewer
# examples if a split has fewer than N rows.
print(f"Train sample (n={N}) mean ROUGE-L F1: {train_mean:.3f}")
print(f"Eval  sample (n={N}) mean ROUGE-L F1: {eval_mean:.3f}")
print("Rule of thumb: if train >> eval by a big margin, that suggests overfitting.\n")

# Rows come back sorted by ascending score, so the first three are the worst.
print("Worst 3 EVAL examples (lowest ROUGE-L):")
for s, qt, q, pred, ref in eval_rows[:3]:
    print("\n" + "-"*80)
    print(f"ROUGE-L F1: {s:.3f} | qtype: {qt}")
    print("Q:", q)
    # Collapse whitespace and truncate long texts to keep the report readable.
    print("PRED:", textwrap.shorten(pred, width=300, placeholder=" ..."))
    print("REF :", textwrap.shorten(ref,  width=300, placeholder=" ..."))


# ============================================
# 6) OPTIONAL: interactive Q&A loop in Jupyter
# ============================================
def qa():
    """
    Run an interactive question/answer loop against the fine-tuned model.

    Each non-empty line read from the prompt is answered by ``model_ft`` with
    sampling enabled (temperature 0.7, top_p 0.9, up to 256 new tokens).
    The loop ends when the user types "exit" or "quit" (any letter case),
    when input is closed (EOF), or when the prompt is interrupted.
    Blank lines are ignored rather than sent to the model.
    """
    print("\nInteractive Q&A (type 'exit' to stop)\n")
    while True:
        try:
            q = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or an interrupt at the prompt: leave the loop
            # quietly instead of surfacing a traceback.
            print()
            break
        if q.lower() in {"exit", "quit"}:
            break
        if not q:
            # Generation is slow on CPU; don't spend it on an empty question.
            continue
        a = generate_answer(model_ft, q, qtype=None, max_new_tokens=256, do_sample=True, temperature=0.7, top_p=0.9)
        print("\nModel:", a, "\n")

# To start interactive mode, run:
# qa()




Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


[OK] Loaded fine-tuned model (base + LoRA adapter).

Fine-tuned model samples

QTYPE: general
Q: What is hypertension?
A: Hypertension, also called high blood pressure, is a common condition in which the long-term force of the blood against your artery walls is higher than normal. Over time, high blood pressure can damage the walls of your arteries and make it harder for your heart to pump blood. It's the most common cause of heart disease and stroke, which are leading causes of death in the United States.    Hypertension often has no signs or symptoms. Many people who have it don't know they have it. That's why it's important to get your blood pressure checked regularly by a health care provider.    The main risk factor for developing hypertension is having a family history of the condition. Other risk factors include being overweight or obese; drinking too much alcohol; not getting enough potassium, calcium, or magnesium in your diet; and having diabetes or chronic kidney disease.   