# LoopGuard: 7-Model Comparison Evaluation

**Purpose:** Prove that MedGemma + task-specific fine-tuning is the correct approach for privacy-sensitive clinical NLP. Covers Criterion 1 (HAI-DEF Use, 20%) and Criterion 4 (Product Feasibility, 20%).

## Models (all ≤7B, all open-source, all locally deployable)

| # | Model | Size | Type | Narrative role |
|---|-------|------|------|----------------|
| 1 | `google/gemma-2-2b-it` | 2B | General | Minimum baseline |
| 2 | `meta-llama/Llama-3.2-3B-Instruct` | 3B | General | Best popular 3B |
| 3 | `BioMistral/BioMistral-7B` | 7B | Medical (PubMed) | Literature ≠ clinical notes |
| 4 | `Qwen/Qwen2.5-7B-Instruct` | 7B | General | Best general 7B |
| 5 | `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` | 7B | Reasoning | Reasoning ≠ domain knowledge |
| 6 | `google/medgemma-1.5-4b-it` | 4B | Medical | Right domain, needs fine-tuning |
| 7 | **LoopGuard v2** | 4B | Medical FT | ✅ Right domain + right task |

## Chat template notes (from official docs)
- **Gemma 2**: `<start_of_turn>user/model`, via `apply_chat_template()`
- **Llama 3.2**: `<|begin_of_text|><|start_header_id|>`, via `apply_chat_template()`
- **BioMistral**: `[INST]...[/INST]` (Mistral v0.1 format); **no system message** supported
- **Qwen2.5**: ChatML `<|im_start|>`, via `apply_chat_template()`
- **DeepSeek-R1 Distill**: ChatML (Qwen2.5 tokenizer), via `apply_chat_template()`; **no system prompt**; produces `<think>` tokens to strip
- **MedGemma / LoopGuard**: `<start_of_turn>user/model`, same as Gemma 2
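
The DeepSeek-R1 note above says its `<think>` output has to be stripped before scoring. A minimal sketch of that step follows, assuming the reasoning is wrapped in literal `<think>...</think>` tags. The helper name `strip_think` is hypothetical, and the two edge cases it handles are assumptions rather than documented DeepSeek behavior: an opening tag that is never closed (generation may stop mid-thought at the token limit) and a closing tag with no opening tag (a chat template may pre-fill the opening tag in the prompt).

```python
import re

# Closed reasoning blocks: <think> ... </think>, possibly spanning lines.
_CLOSED_THINK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)
# An opening tag that is never closed, e.g. generation stopped mid-thought.
_OPEN_THINK = re.compile(r"<think>.*\Z", flags=re.DOTALL)


def strip_think(text: str) -> str:
    """Return `text` with reasoning removed (hypothetical helper)."""
    text = _CLOSED_THINK.sub("", text)
    # Assumption: a stray closing tag means everything before it was reasoning.
    if "</think>" in text:
        text = text.rsplit("</think>", 1)[1]
    # Assumption: an unclosed block is all reasoning, so nothing is left to score.
    text = _OPEN_THINK.sub("", text)
    return text.strip()


closed = "<think>ST elevation in II, III, aVF...</think>\nURGENCY LEVEL: high"
truncated = "<think>The patient has fever and hypotension, so"
prefilled = "fever plus hypotension suggests sepsis</think>\nURGENCY LEVEL: high"

for sample in (closed, truncated, prefilled):
    print(repr(strip_think(sample)))
```

Treating a truncated block as empty output means such a note scores zero fields instead of picking up field names that happen to appear inside the reasoning.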

## Workflow
1. **Cell 1**: Install + restart kernel
2. **Cell 2**: Imports + shared config
3. **Cell 3**: Test notes + scoring functions
4. **Cells 4–10**: Run each model (load → infer → free VRAM)
5. **Cell 11**: Score all results + print table
6. **Cell 12**: Save bar chart PNG
7. **Cell 13**: Print writeup-ready text
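
Step 4 repeats the same load, infer, free pattern for every model. As a rough sketch of that shared loop (not the notebook's actual code), the hypothetical `run_model` below takes the model-specific pieces as callables. The lambdas are stubs that stand in for real model loading and generation, so the example runs without a GPU.

```python
from typing import Callable, Dict, List


def run_model(
    notes: List[Dict[str, str]],
    load: Callable[[], object],
    generate: Callable[[object, str], str],
    score: Callable[[str, str], Dict[str, object]],
    release: Callable[[object], None],
) -> List[Dict[str, object]]:
    """Load one model, score its output on every note, then release it."""
    handle = load()
    results = []
    try:
        for note in notes:
            output = generate(handle, note["note"])
            results.append({"note_id": note["id"], **score(output, note["urgency_gt"])})
    finally:
        # Release even if generation fails, so the next model has room to load.
        release(handle)
    return results


# Stubs standing in for a real model; they exist only to make the sketch runnable.
events = []
demo_notes = [{"id": "N01", "note": "Chest pain, ST elevation.", "urgency_gt": "high"}]
demo_results = run_model(
    demo_notes,
    load=lambda: events.append("load") or "fake-model",
    generate=lambda handle, text: "URGENCY LEVEL: high",
    score=lambda out, gt: {"urgency_match": gt in out.lower()},
    release=lambda handle: events.append("free"),
)
print(demo_results)
print(events)
```

Putting the release step in `finally` is a design choice: a failure partway through one model should not leave its weights on the GPU when the next model loads.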

In [None]:
# ============================================================
# CELL 1: Install Dependencies - restart kernel after
# ============================================================
# Version specifiers are quoted so the shell does not treat ">" as a redirect.
!pip uninstall -y -q transformers peft bitsandbytes accelerate
!pip install -q "transformers>=4.47.0"
!pip install -q "peft>=0.13.0"
!pip install -q "accelerate>=0.34.0"
!pip install -q "bitsandbytes>=0.46.1"
!pip install -q matplotlib

print("✅ Done. ⚠️ RESTART KERNEL, then run from Cell 2.")

In [None]:
# ============================================================
# CELL 2: Imports + Shared BnB Config
# ============================================================
import torch
import json
import re
import os
import gc
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

print("✅ Imports OK")
print(f"   PyTorch: {torch.__version__}")
print(f"   CUDA: {torch.cuda.is_available()} | GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'none'}")
if torch.cuda.is_available():  # get_device_properties raises without a GPU
    print(f"   VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Shared 4-bit quantization config used for all models
BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

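def free_vram(model=None, tokenizer=None):
    """Drop the notebook's references to model/tokenizer and free GPU memory."""
    targets = [obj for obj in (model, tokenizer) if obj is not None]
    # `del model` inside a function only removes the local name; the
    # notebook-level variables would still keep the weights alive, so
    # remove every global name bound to the same objects as well.
    stale = [name for name, value in globals().items()
             if any(value is obj for obj in targets)]
    for name in stale:
        del globals()[name]
    del model, tokenizer, targets
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        print(f"   VRAM freed → {torch.cuda.memory_allocated()/1e9:.2f} GB in use")

print("\n✅ Shared config ready")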

In [None]:
# ============================================================
# CELL 3: Test Notes, Scoring Functions, Results Store
# ============================================================

TEST_NOTES = [
    {"id": "N01", "specialty": "Oncology",       "urgency_gt": "high",
     "note": "58-year-old female presents with 3-month history of intermittent abdominal bloating and early satiety. Reports feeling full after eating small amounts. Postmenopausal (LMP 7 years ago). Abdomen: mildly distended, positive fluid wave. CA-125 ordered. Assessment: Rule out ovarian malignancy. Plan: CA-125, transvaginal ultrasound, follow-up 6 weeks."},
    {"id": "N02", "specialty": "Cardiology",     "urgency_gt": "high",
     "note": "58-year-old male presents with sudden onset crushing substernal chest pressure, 8/10, radiating to left arm and jaw. Started 45 minutes ago while shoveling snow. Diaphoresis, nausea, dyspnea. PMH: Hypertension, hyperlipidemia, smoking 1 PPD x 30 years. BP 160/95, HR 102. EKG shows ST elevations 2-3mm in leads II, III, aVF. Assessment: Acute inferior STEMI. Plan: Activate cath lab, aspirin, heparin."},
    {"id": "N03", "specialty": "Pulmonology",    "urgency_gt": "high",
     "note": "42-year-old female with sudden onset dyspnea and right-sided pleuritic chest pain. 6-hour flight from Europe yesterday. On oral contraceptives. BP 108/68, HR 118, RR 28, SpO2 88% on RA. Right leg with 2+ edema and calf tenderness. EKG: S1Q3T3 pattern. Assessment: High probability pulmonary embolism. Plan: CTPA, D-dimer, anticoagulation."},
    {"id": "N04", "specialty": "Neurology",      "urgency_gt": "high",
     "note": "67-year-old male presents with sudden onset left-sided weakness and facial droop, onset 2 hours ago. PMH: Atrial fibrillation, hypertension, not on anticoagulation. BP 188/104, HR 78 irregular. Neuro: Left hemiparesis, left facial droop, dysarthria. Assessment: Acute ischemic stroke. Plan: Stat CT head, labs, neurology consult, tPA evaluation."},
    {"id": "N05", "specialty": "Sepsis",         "urgency_gt": "high",
     "note": "72-year-old female presents with fever 102.8F, confusion, and burning urination x 3 days. PMH: Type 2 diabetes, recurrent UTIs. BP 88/52, HR 118, RR 24, SpO2 94%. Exam: Altered mental status, CVA tenderness. Labs: WBC 18.4, lactate 3.2. Assessment: Urosepsis. Plan: Blood cultures x2, urine culture, IV antibiotics, 30cc/kg fluid bolus."},
    {"id": "N06", "specialty": "Endocrinology",  "urgency_gt": "medium",
     "note": "45-year-old female with 6-month history of fatigue, weight gain (15 lbs), cold intolerance, constipation, and dry skin. Hair thinning noted. No family history of thyroid disease. Vitals normal. Exam: Delayed DTR relaxation, mild facial puffiness, dry skin. Assessment: Hypothyroidism. Plan: TSH, free T4, CBC, CMP."},
    {"id": "N07", "specialty": "Hematology",     "urgency_gt": "medium",
     "note": "34-year-old female presents with 3-month history of fatigue, exertional dyspnea, and palpitations. Heavy menstrual periods. Vitals: BP 110/70, HR 96. Exam: Pale conjunctivae, tachycardia. Assessment: Iron deficiency anemia secondary to menorrhagia. Plan: CBC, iron studies, ferritin, reticulocyte count, gynecology referral."},
    {"id": "N08", "specialty": "Gastroenterology","urgency_gt": "medium",
     "note": "52-year-old male with 2-month history of right upper quadrant pain, jaundice, and unintentional weight loss of 18 lbs. Pruritus. PMH: Chronic alcohol use. Vitals: BP 118/76, HR 88. Exam: Jaundice, RUQ tenderness, palpable gallbladder. Assessment: Obstructive jaundice, rule out cholangiocarcinoma or pancreatic head mass. Plan: LFTs, bilirubin, CT abdomen/pelvis, MRCP."},
    {"id": "N09", "specialty": "Nephrology",     "urgency_gt": "medium",
     "note": "61-year-old male with type 2 diabetes and hypertension presenting for routine follow-up. Creatinine has risen from 1.2 to 2.1 over 6 months. Foamy urine noted. BP 158/96 despite lisinopril 20mg. Exam: 1+ bilateral lower extremity edema. Assessment: Progressive diabetic nephropathy, stage 3b CKD. Plan: Nephrology referral, urine albumin/creatinine, renal ultrasound."},
    {"id": "N10", "specialty": "Preventive",     "urgency_gt": "low",
     "note": "45-year-old male presents for annual wellness exam. No complaints. PMH: None. Nonsmoker, social drinker. BMI 27. BP 124/78, HR 68. Exam normal. Assessment: Healthy male due for age-appropriate screening. Plan: Lipid panel, fasting glucose, colorectal cancer screening discussion, HIV screen, tetanus booster."},
]

REQUIRED_FIELDS = [
    "PRIMARY HYPOTHESIS",
    "DIFFERENTIAL DIAGNOSES",
    "KEY SUPPORTING EVIDENCE",
    "URGENCY LEVEL",
    "TESTS ORDERED",
    "CLINICAL REASONING",
]

# Storage for all model results
ALL_RESULTS = {}

def extract_urgency(text):
    match = re.search(r'urgency level[:\s]+([\w]+)', text.lower())
    if match:
        val = match.group(1).strip()
        if val in ("high", "medium", "low"):
            return val
    # Fallback: look for standalone urgency words
    for u in ["high", "medium", "low"]:
        if re.search(rf'\burgency\b.*\b{u}\b', text.lower()):
            return u
    return "unknown"

def score_output(generated_text, ground_truth_urgency):
    text_upper = generated_text.upper()
    fields_found = sum(1 for f in REQUIRED_FIELDS if f in text_upper)
    urgency_detected = extract_urgency(generated_text)
    return {
        "fields_present": fields_found,
        "completeness_pct": round(100 * fields_found / len(REQUIRED_FIELDS)),
        "valid_structure": fields_found >= 5,
        "urgency_detected": urgency_detected,
        "urgency_match": urgency_detected == ground_truth_urgency.lower(),
    }

def summarize_results(results, model_display_name):
    avg_completeness = sum(r['completeness_pct'] for r in results) / len(results)
    pct_valid = 100 * sum(1 for r in results if r['valid_structure']) / len(results)
    pct_urgency = 100 * sum(1 for r in results if r['urgency_match']) / len(results)
    return {
        "model": model_display_name,
        "avg_completeness": round(avg_completeness, 1),
        "pct_valid_structure": round(pct_valid, 1),
        "pct_urgency_correct": round(pct_urgency, 1),
        "per_note": results,
    }

print(f"✅ {len(TEST_NOTES)} test notes ready")
print(f"   Urgency breakdown: high={sum(1 for n in TEST_NOTES if n['urgency_gt']=='high')}, medium={sum(1 for n in TEST_NOTES if n['urgency_gt']=='medium')}, low={sum(1 for n in TEST_NOTES if n['urgency_gt']=='low')}")

In [None]:
# ============================================================
# CELL 4: MODEL 1 - Gemma 2 2B-it
# Chat template: <start_of_turn>user/model  (via apply_chat_template)
# ~10 min
# ============================================================
MODEL_ID = "google/gemma-2-2b-it"
DISPLAY_NAME = "Gemma 2 2B\n(general, no medical)"
print(f"\n🔬 [{1}/7] {MODEL_ID}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

results = []
for i, note in enumerate(TEST_NOTES):
    messages = [{"role": "user", "content": (
        "Extract diagnostic information from this clinical note.\n\n"
        f"Clinical Note:\n{note['note']}\n\n"
        "Output ONLY these 6 fields:\n"
        "PRIMARY HYPOTHESIS: [diagnosis]\n"
        "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
        "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
        "URGENCY LEVEL: [high/medium/low]\n"
        "TESTS ORDERED: [comma-separated list]\n"
        "CLINICAL REASONING: [brief reasoning]"
    )}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True can drop the turn markers, so neither the marker
    # split nor slicing by len(prompt) reliably isolates the model's answer.
    generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 5: MODEL 2 - Llama 3.2 3B Instruct
# Chat template: <|begin_of_text|><|start_header_id|>  (via apply_chat_template)
# Note: Llama requires HF token - make sure internet is enabled and you're logged in
# ~10 min
# ============================================================
MODEL_ID = "meta-llama/Llama-3.2-3B-Instruct"
DISPLAY_NAME = "Llama 3.2 3B\n(general, no medical)"
print(f"\n🔬 [{2}/7] {MODEL_ID}")
print("   Note: Requires HF token. If it fails, run: from huggingface_hub import notebook_login; notebook_login()")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

results = []
for i, note in enumerate(TEST_NOTES):
    # Llama 3.2 supports system messages via apply_chat_template
    messages = [
        {"role": "system", "content": "You are a clinical documentation assistant. Output ONLY the structured format requested. No additional text."},
        {"role": "user", "content": (
            "Extract diagnostic information from this clinical note.\n\n"
            f"Clinical Note:\n{note['note']}\n\n"
            "Output ONLY these 6 fields:\n"
            "PRIMARY HYPOTHESIS: [diagnosis]\n"
            "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
            "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
            "URGENCY LEVEL: [high/medium/low]\n"
            "TESTS ORDERED: [comma-separated list]\n"
            "CLINICAL REASONING: [brief reasoning]"
        )}
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens. With skip_special_tokens=True the
    # header and <|eot_id|> markers are removed from the decoded text, so the
    # header regex never matches and slicing by len(prompt) can cut into the answer.
    generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 6: MODEL 3 - BioMistral 7B
# Chat template: [INST]...[/INST]  (Mistral v0.1 format)
# IMPORTANT: BioMistral does NOT support system messages.
# Per official docs: user content only, no system role.
# ~12 min
# ============================================================
MODEL_ID = "BioMistral/BioMistral-7B"
DISPLAY_NAME = "BioMistral 7B\n(medical, PubMed pretrain)"
print(f"\n🔬 [{3}/7] {MODEL_ID}")
print("   Medical model pretrained on PubMed Central, but with no clinical note fine-tuning")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

results = []
for i, note in enumerate(TEST_NOTES):
    # BioMistral uses Mistral [INST] format, NO system message supported
    # We use apply_chat_template which handles this correctly
    messages = [{"role": "user", "content": (
        "Extract diagnostic information from this clinical note. "
        "Output ONLY the 6 structured fields below, no other text.\n\n"
        f"Clinical Note:\n{note['note']}\n\n"
        "PRIMARY HYPOTHESIS: [diagnosis]\n"
        "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
        "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
        "URGENCY LEVEL: [high/medium/low]\n"
        "TESTS ORDERED: [comma-separated list]\n"
        "CLINICAL REASONING: [brief reasoning]"
    )}]
    try:
        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    except Exception:
        # Fallback to manual Mistral format if apply_chat_template fails
        prompt = f"<s>[INST] {messages[0]['content']} [/INST]"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id)
    response = tokenizer.decode(out[0], skip_special_tokens=True)
    # Extract content after [/INST]
    if "[/INST]" in response:
        generated = response.split("[/INST]")[-1].strip()
    else:
        generated = response[len(prompt):].strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 7: MODEL 4 - Qwen2.5 7B Instruct
# Chat template: ChatML <|im_start|>  (via apply_chat_template)
# Arguably the strongest general <10B model currently
# ~12 min
# ============================================================
MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
DISPLAY_NAME = "Qwen2.5 7B\n(general, best <10B)"
print(f"\n🔬 [{4}/7] {MODEL_ID}")
print("   Strongest general open-source model under 10B, but still lacks medical domain knowledge")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

results = []
for i, note in enumerate(TEST_NOTES):
    messages = [
        {"role": "system", "content": "You are a clinical documentation assistant. Output ONLY the structured format requested. No preamble, no explanation."},
        {"role": "user", "content": (
            "Extract diagnostic information from this clinical note.\n\n"
            f"Clinical Note:\n{note['note']}\n\n"
            "Output ONLY these 6 fields:\n"
            "PRIMARY HYPOTHESIS: [diagnosis]\n"
            "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
            "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
            "URGENCY LEVEL: [high/medium/low]\n"
            "TESTS ORDERED: [comma-separated list]\n"
            "CLINICAL REASONING: [brief reasoning]"
        )}
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1)
    generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 8: MODEL 5 - DeepSeek-R1 Distill Qwen 7B
# Chat template: ChatML (Qwen2.5 tokenizer) via apply_chat_template
# IMPORTANT per DeepSeek docs:
#   - No system prompt: all instructions go in the user turn
#   - Produces <think>...</think> reasoning tokens, which must be stripped before scoring
#   - Recommended temperature 0.6 (we use greedy for fairness)
# ~15 min (generates think tokens before output)
# ============================================================
MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
DISPLAY_NAME = "DeepSeek-R1 7B\n(reasoning, no medical)"
print(f"\n🔬 [{5}/7] {MODEL_ID}")
print("   Reasoning model: will produce <think> tokens. We strip these before scoring.")
print("   Per DeepSeek docs: no system prompt, instructions in user turn only.")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def strip_think_tokens(text):
    """Remove DeepSeek-R1 chain-of-thought <think>...</think> blocks."""
    # Remove think blocks
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    return text.strip()

results = []
for i, note in enumerate(TEST_NOTES):
    # DeepSeek-R1: NO system prompt per official recommendation
    messages = [{"role": "user", "content": (
        "Extract diagnostic information from this clinical note. "
        "Output ONLY the 6 structured fields below. Do not include any other text or reasoning.\n\n"
        f"Clinical Note:\n{note['note']}\n\n"
        "PRIMARY HYPOTHESIS: [diagnosis]\n"
        "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
        "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
        "URGENCY LEVEL: [high/medium/low]\n"
        "TESTS ORDERED: [comma-separated list]\n"
        "CLINICAL REASONING: [brief reasoning]"
    )}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=600, do_sample=False,
                             repetition_penalty=1.1)
    raw = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    # Strip <think> reasoning blocks before scoring
    generated = strip_think_tokens(raw)
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores, "had_think_tokens": "<think>" in raw})
    think_note = " [had <think> tokens]" if "<think>" in raw else ""
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}{think_note}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 9: MODEL 6 - Base MedGemma 1.5 4B-it (zero-shot)
# Chat template: <start_of_turn>user/model  (via apply_chat_template)
# Same model as LoopGuard base - proves fine-tuning adds value
# ~10 min
# ============================================================
MODEL_ID = "google/medgemma-1.5-4b-it"
DISPLAY_NAME = "Base MedGemma 4B\n(medical, zero-shot)"
print(f"\n🔬 [{6}/7] {MODEL_ID}")
print("   Same base as LoopGuard - proves task fine-tuning adds value on top of domain pretraining")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

results = []
for i, note in enumerate(TEST_NOTES):
    messages = [{"role": "user", "content": (
        "Extract diagnostic information from this clinical note.\n\n"
        f"Clinical Note:\n{note['note']}\n\n"
        "Output ONLY these 6 fields:\n"
        "PRIMARY HYPOTHESIS: [diagnosis]\n"
        "DIFFERENTIAL DIAGNOSES: [comma-separated list]\n"
        "KEY SUPPORTING EVIDENCE: [comma-separated list]\n"
        "URGENCY LEVEL: [high/medium/low]\n"
        "TESTS ORDERED: [comma-separated list]\n"
        "CLINICAL REASONING: [brief reasoning]"
    )}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True can drop the turn markers, so neither the marker
    # split nor slicing by len(prompt) reliably isolates the model's answer.
    generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 10: MODEL 7 - LoopGuard v2 (fine-tuned MedGemma)
# Load base MedGemma + v2 LoRA adapter
# ~10 min
# ============================================================
BASE_MODEL_ID = "google/medgemma-1.5-4b-it"
ADAPTER_DIR = "/kaggle/working/medgemma-hypothesis-extraction-v2"
DISPLAY_NAME = "LoopGuard v2\n(fine-tuned MedGemma)"
print(f"\n🔬 [{7}/7] LoopGuard v2")
assert os.path.exists(ADAPTER_DIR), f"❌ Adapter not found at {ADAPTER_DIR}. Run fine-tuning notebook first."

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID, quantization_config=BNB_CONFIG, device_map="auto",
    trust_remote_code=True, dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)
model.eval()
print(f"✅ Loaded with adapter | VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# LoopGuard uses the exact prompt template from fine-tuning
LOOPGUARD_TEMPLATE = (
    "<start_of_turn>user\n"
    "Extract diagnostic information from this clinical note.\n\n"
    "Clinical Note:\n{note}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

results = []
for i, note in enumerate(TEST_NOTES):
    prompt = LOOPGUARD_TEMPLATE.format(note=note["note"])
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=768).to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=400, do_sample=False,
                             repetition_penalty=1.1, pad_token_id=tokenizer.eos_token_id)
    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True can drop the turn markers, so neither the marker
    # split nor slicing by len(prompt) reliably isolates the model's answer.
    generated = tokenizer.decode(out[0][inputs.input_ids.shape[1]:], skip_special_tokens=True).strip()
    scores = score_output(generated, note["urgency_gt"])
    results.append({"note_id": note["id"], **scores})
    print(f"  [{i+1}/10] {note['id']} fields={scores['fields_present']}/6 urgency_match={scores['urgency_match']}")

ALL_RESULTS[DISPLAY_NAME] = summarize_results(results, DISPLAY_NAME)
free_vram(model, tokenizer)

print(f"\n✅ All 7 models complete. {len(ALL_RESULTS)} result sets stored.")

In [None]:
# ============================================================
# CELL 11: Score Table
# ============================================================
ORDERED_MODELS = [
    "Gemma 2 2B\n(general, no medical)",
    "Llama 3.2 3B\n(general, no medical)",
    "BioMistral 7B\n(medical, PubMed pretrain)",
    "Qwen2.5 7B\n(general, best <10B)",
    "DeepSeek-R1 7B\n(reasoning, no medical)",
    "Base MedGemma 4B\n(medical, zero-shot)",
    "LoopGuard v2\n(fine-tuned MedGemma)",
]

print("\n" + "=" * 80)
print(f"{'MODEL':<38} {'COMPLETENESS':>13} {'VALID STRUCT':>13} {'URGENCY ACC':>13}")
print("=" * 80)
for model_name in ORDERED_MODELS:
    if model_name not in ALL_RESULTS:
        print(f"{model_name.replace(chr(10),' '):<38} {'NOT RUN':>13}")
        continue
    s = ALL_RESULTS[model_name]
    name = model_name.replace('\n', ' ')
    marker = " ✅" if "LoopGuard" in model_name else ""
    print(f"{name:<38} {s['avg_completeness']:>12.1f}% {s['pct_valid_structure']:>12.1f}% {s['pct_urgency_correct']:>12.1f}%{marker}")
print("=" * 80)

# Save to JSON
with open('/kaggle/working/eval_comparison_7models.json', 'w') as f:
    json.dump({k: {"summary": {"model": v["model"], "avg_completeness": v["avg_completeness"],
                               "pct_valid_structure": v["pct_valid_structure"],
                               "pct_urgency_correct": v["pct_urgency_correct"]},
                   "per_note": v["per_note"]} for k, v in ALL_RESULTS.items()}, f, indent=2)
print("\n✅ Results saved to /kaggle/working/eval_comparison_7models.json")

In [None]:
# ============================================================
# CELL 12: Bar Chart
# Saves eval_comparison.png to /kaggle/working/
# ============================================================
models_in_order = [m for m in ORDERED_MODELS if m in ALL_RESULTS]
summaries = [ALL_RESULTS[m] for m in models_in_order]
labels = list(models_in_order)  # display names already contain the line break used on the x-axis

# Color scheme: grays for general, blue for medical non-FT, green for LoopGuard
colors = []
for m in models_in_order:
    if "LoopGuard" in m:
        colors.append('#28a745')   # green - winner
    elif "MedGemma" in m or "BioMistral" in m:
        colors.append('#4a90d9')   # blue - medical
    else:
        colors.append('#9e9e9e')   # gray - general

metrics = [
    ("avg_completeness",    "Field Completeness (%)",  "Out of 6 required fields"),
    ("pct_valid_structure", "Valid Structure (%)",      "≥5/6 fields present"),
    ("pct_urgency_correct", "Urgency Accuracy (%)",     "Matches ground truth label"),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 7))
fig.patch.set_facecolor('#f8f9fa')

for ax, (metric_key, metric_title, metric_desc) in zip(axes, metrics):
    values = [s[metric_key] for s in summaries]
    x = np.arange(len(labels))
    bars = ax.bar(x, values, color=colors, width=0.65, zorder=3, edgecolor='white', linewidth=0.5)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=7.5)
    ax.set_ylim(0, 115)
    ax.set_ylabel("Score (%)", fontsize=10)
    ax.set_title(metric_title, fontsize=12, fontweight='bold', pad=10)
    ax.set_facecolor('#ffffff')
    ax.grid(axis='y', alpha=0.35, zorder=0)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1.5,
                f'{val:.0f}%', ha='center', va='bottom', fontsize=9, fontweight='bold')
    ax.text(0.5, -0.18, metric_desc, transform=ax.transAxes,
            ha='center', fontsize=8, color='#666', style='italic')

legend_items = [
    mpatches.Patch(color='#9e9e9e', label='General model (no medical pretraining)'),
    mpatches.Patch(color='#4a90d9', label='Medical model (literature pretrain, no task FT)'),
    mpatches.Patch(color='#28a745', label='LoopGuard v2 (MedGemma + clinical note fine-tuning)'),
]
fig.legend(handles=legend_items, loc='upper center', bbox_to_anchor=(0.5, 1.04),
           ncol=3, fontsize=9, frameon=True)

fig.suptitle(
    'LoopGuard v2 vs 6 Comparable Open-Source Models (all ≤7B, all locally deployable)\n'
    '10 Clinical Notes · 3 Metrics · Zero-shot except LoopGuard v2',
    fontsize=12, fontweight='bold', y=1.10)

plt.tight_layout()
plt.savefig('/kaggle/working/eval_comparison.png', dpi=150, bbox_inches='tight',
            facecolor=fig.get_facecolor())
print("✅ Chart saved → /kaggle/working/eval_comparison.png")
print("   Download this file for the competition writeup and video.")
plt.show()

In [None]:
# ============================================================
# CELL 13: Writeup-Ready Text
# Copy directly into competition writeup (Technical Details section)
# ============================================================
print("\n📝 WRITEUP-READY TABLE (paste into Technical Details section)")
print("=" * 70)

rows = []
for model_name in ORDERED_MODELS:
    if model_name not in ALL_RESULTS:
        continue
    s = ALL_RESULTS[model_name]
    name = model_name.replace('\n', ' ')
    rows.append(f"| {name:<38} | {s['avg_completeness']:>11.0f}% | {s['pct_valid_structure']:>14.0f}% | {s['pct_urgency_correct']:>15.0f}% |")

# Header and separator use the same column widths as the rows above.
header = f"| {'Model':<38} | {'Completeness':>12} | {'Valid Structure':>15} | {'Urgency Accuracy':>16} |"
separator = "|" + "-"*40 + "|" + "-"*14 + "|" + "-"*17 + "|" + "-"*18 + "|"
print(header)
print(separator)
for row in rows:
    print(row)

loopguard_key = "LoopGuard v2\n(fine-tuned MedGemma)"
competitors = [v for k, v in ALL_RESULTS.items() if "LoopGuard" not in k]
if loopguard_key in ALL_RESULTS and competitors:
    ft = ALL_RESULTS[loopguard_key]
    best_competitor_completeness = max(c['avg_completeness'] for c in competitors)
    # Signed format so a negative gap is not printed as "+-N".
    gap = ft['avg_completeness'] - best_competitor_completeness
    print(f"""
\n📊 KEY NUMBERS FOR WRITEUP:
   LoopGuard v2 completeness: {ft['avg_completeness']:.0f}%
   Best competitor completeness: {best_competitor_completeness:.0f}%
   Gap: {gap:+.0f} percentage points
   LoopGuard valid structure rate: {ft['pct_valid_structure']:.0f}%
   LoopGuard urgency accuracy: {ft['pct_urgency_correct']:.0f}%
""")

print("=" * 70)
print("✅ Download eval_comparison.png + eval_comparison_7models.json from /kaggle/working/")