# LoopGuard: 8-Model Comparison Evaluation (v3)

**Key changes from v2:**
- Aggressive VRAM management between every model (explicit del + gc + empty_cache + synchronize)
- Better response parsing per model family (fewer false negatives)
- Real validation examples from training data (not synthetic test notes)
- Meditron-7B added (EPFL, pretrained on clinical guidelines)
- env var `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True` set globally
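
The cleanup sequence in the first bullet (drop references, collect garbage, empty the CUDA cache, synchronize) can be sketched as follows. This is a minimal illustration under stated assumptions, not the notebook's own `nuke_vram`: `release_models`, `DummyModel`, and the `namespace` argument are hypothetical names, and `torch` is treated as optional so the sketch also runs on a machine without a GPU.

```python
import gc
import weakref

try:
    import torch  # optional here so the sketch also runs on a CPU-only machine
except ImportError:
    torch = None


def release_models(namespace, names=("model", "base", "tokenizer")):
    """Drop named references from `namespace`, then flush the CUDA cache.

    Hypothetical helper for illustration. Returns the names actually removed.
    """
    removed = [n for n in names if namespace.pop(n, None) is not None]
    gc.collect()                      # break reference cycles that pin tensors
    if torch is not None and torch.cuda.is_available():
        torch.cuda.synchronize()      # wait for pending GPU work to finish
        torch.cuda.empty_cache()      # release unused cached blocks
    return removed


class DummyModel:
    """Stand-in for a loaded model so the sketch runs without a GPU."""


ns = {"model": DummyModel(), "tokenizer": DummyModel()}
probe = weakref.ref(ns["model"])      # lets us check the object is really gone
removed = release_models(ns)
print(removed, probe() is None)
```

The point of the design is that memory is only returned once no reference to the model remains, so the helper removes the names from the namespace that owns them rather than deleting a local alias.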

## Models (ordered by narrative arc)

| # | Model | Size | Training data | Task FT |
|---|-------|------|---------------|---------|
| 1 | `google/gemma-2-2b-it` | 2B | General web | ❌ |
| 2 | `meta-llama/Llama-3.2-3B-Instruct` | 3B | General web | ❌ |
| 3 | `BioMistral/BioMistral-7B` | 7B | PubMed papers | ❌ |
| 4 | `epfl-llm/meditron-7b` | 7B | PubMed + **clinical guidelines** | ❌ |
| 5 | `Qwen/Qwen2.5-7B-Instruct` | 7B | General web | ❌ |
| 6 | `deepseek-ai/DeepSeek-R1-Distill-Qwen-7B` | 7B | General + reasoning distill | ❌ |
| 7 | `google/medgemma-1.5-4b-it` | 4B | Medical multimodal | ❌ |
| 8 | **LoopGuard v2** | 4B | Medical + **421 clinical notes** | ✅ |
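
For a rough sense of why the 7B models dominate VRAM use, the sketch below estimates weight storage alone at 4 bits per parameter, the width set by `load_in_4bit=True` in Cell 2. Treat the figures as a lower bound under that assumption: they ignore quantization metadata, any layers kept in higher precision, activations, and the KV cache. `SIZES_B` and `weight_gb` are hypothetical names introduced for this example.

```python
# Nominal parameter counts in billions, read from the Size column above
SIZES_B = {
    "gemma-2-2b-it": 2,
    "Llama-3.2-3B-Instruct": 3,
    "BioMistral-7B": 7,
    "meditron-7b": 7,
    "Qwen2.5-7B-Instruct": 7,
    "DeepSeek-R1-Distill-Qwen-7B": 7,
    "medgemma-1.5-4b-it": 4,
}


def weight_gb(params_billion, bits_per_weight=4):
    """Approximate weight storage in GB (1e9 bytes) at the given bit width."""
    return params_billion * 1e9 * bits_per_weight / 8 / 1e9


estimates = {name: weight_gb(size) for name, size in SIZES_B.items()}
for name, gb in estimates.items():
    print(f"{name:<30} ~{gb:.1f} GB of weights")
```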

## Test set
10 real examples held out from training data (last 2 from 5 different specialty batches + 1 nephrology).
Ground truth urgency labels come directly from the gold-standard output JSON.
Urgency distribution: 5 high, 4 low, 1 medium.
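
With only 10 notes and a skewed label distribution, it helps to know what a trivial predictor would score. The sketch below works this out from the distribution stated above; `labels`, `baseline_acc`, and `step` are names introduced for this example only.

```python
from collections import Counter

# Ground-truth urgency labels for the 10 test notes (5 high, 4 low, 1 medium)
labels = ["high"] * 5 + ["low"] * 4 + ["medium"]

counts = Counter(labels)
majority_label, majority_count = counts.most_common(1)[0]

# Accuracy of a constant predictor that always answers the majority label
baseline_acc = 100 * majority_count / len(labels)
# Resolution of the metric: how far one note moves the accuracy
step = 100 / len(labels)

print(f"majority={majority_label} baseline={baseline_acc:.0f}% step={step:.0f}pp")
```

A model that always answers "high" would therefore reach 50% urgency accuracy, and each note shifts any reported percentage by 10 points, which is worth keeping in mind when comparing models on this set.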

In [None]:
# ============================================================
# CELL 1: Install + set env vars
# RESTART KERNEL after this cell
# ============================================================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

!pip uninstall -y -q transformers peft bitsandbytes accelerate 2>/dev/null
!pip install -q "transformers>=4.47.0"
!pip install -q "peft>=0.13.0"
!pip install -q "accelerate>=0.34.0"
!pip install -q "bitsandbytes>=0.46.1"
!pip install -q matplotlib

print("✅ Done. ⚠️  RESTART KERNEL, then run from Cell 2.")

In [None]:
# ============================================================
# CELL 2: Global setup - run once after kernel restart
# ============================================================
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import torch, json, re, gc, warnings
warnings.filterwarnings('ignore')
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

print(f"PyTorch: {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"GPU: {props.name} | Total VRAM: {props.total_memory/1e9:.1f} GB")

BNB_CONFIG = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

def vram_free():
    free = torch.cuda.get_device_properties(0).total_memory - torch.cuda.memory_allocated()
    return free / 1e9

def nuke_vram(*args):
    """Aggressively free all GPU memory."""
    # `del` on a loop variable would only unbind a local name, so release the
    # argument tuple as a whole; positional args are accepted so call sites
    # can list what they are discarding.
    del args
    # Clear any lingering global model/tokenizer references
    for name in ['model', 'base', 'tokenizer']:
        globals().pop(name, None)
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
    gc.collect()
    torch.cuda.empty_cache()
    print(f"   VRAM after clear: {vram_free():.2f} GB free")

ALL_RESULTS = {}
print("\n✅ Global setup complete")
print(f"   VRAM available: {vram_free():.2f} GB")

In [None]:
# ============================================================
# CELL 3: Test notes + scoring functions
# 10 REAL examples from training batches (last N of each file)
# Ground truth urgency from gold-standard outputs
# ============================================================

TEST_NOTES = [
    {
        "id": "SEP01", "urgency_gt": "high",
        "primary_hypothesis_gt": "Ascending cholangitis (post-operative complication)",
        "note": "62-year-old male with recent hospitalization for cholecystectomy (7 days ago) presents to ED with fever (102.1°F), right upper quadrant pain, and jaundice. Reports incision looks okay but has had worsening abdominal pain. Vitals: Temp 102.1°F, HR 106, BP 118/72, RR 18. Exam: Jaundiced, RUQ tenderness, surgical site clean without drainage. Murphy's sign negative (post-op). Labs show elevated bilirubin and transaminases. Concern for biliary leak or retained stone with cholangitis."
    },
    {
        "id": "SEP02", "urgency_gt": "high",
        "primary_hypothesis_gt": "Infective endocarditis (likely tricuspid valve)",
        "note": "41-year-old female injection drug user presents with 3 days of fever (103.6°F), rigors, and new heart murmur. Reports feeling weak and having night sweats. Denies recent hospitalizations but admits to IV heroin use. Vitals: Temp 103.6°F, HR 118, BP 102/58, RR 18, SpO2 96%. Exam: III/VI systolic murmur at apex (new per patient), scattered petechiae on palms, splinter hemorrhages in nails. Assessment: Likely infective endocarditis. Plan: 3 sets blood cultures from different sites, echocardiogram, admit for IV antibiotics."
    },
    {
        "id": "NEU01", "urgency_gt": "high",
        "primary_hypothesis_gt": "Fat embolism syndrome",
        "note": "72-year-old male post-op day 3 from hip replacement surgery presents with sudden dyspnea, confusion, and petechial rash on chest. Nurse reports he was fine 2 hours ago. Vitals: BP 94/58, HR 128, RR 32, SpO2 86% on 4L NC, Temp 100.8°F. Exam: Confused, tachypneic, petechiae on anterior chest and conjunctiva, decreased breath sounds bilaterally. ABG: pH 7.32, PaCO2 32, PaO2 58. Assessment: Fat embolism syndrome. Plan: Supportive care, mechanical ventilation likely needed, orthopedic and ICU consult, consider steroids."
    },
    {
        "id": "CAR01", "urgency_gt": "high",
        "primary_hypothesis_gt": "Critical aortic stenosis with exertional syncope",
        "note": "56-year-old male with known severe aortic stenosis presents with syncope while climbing stairs. Reports 3-month history of progressive exertional dyspnea and chest pressure. Vitals: BP 98/64, HR 62, RR 18, SpO2 96%. Exam: Delayed carotid upstroke, sustained PMI, harsh systolic crescendo-decrescendo murmur at right upper sternal border radiating to carotids, paradoxical S2 splitting. Echocardiogram 6 months ago: severe AS, valve area 0.7 cm², mean gradient 52 mmHg, EF 55%. Assessment: Critical aortic stenosis with syncope. Plan: Admit telemetry, cardiology and cardiac surgery consult for urgent valve replacement, avoid dehydration and vasodilators."
    },
    {
        "id": "CAR02", "urgency_gt": "high",
        "primary_hypothesis_gt": "Acute-on-chronic subdural hematoma with anticoagulation",
        "note": "81-year-old male on warfarin for atrial fibrillation fell down stairs 6 hours ago. Initially seemed fine but now increasingly confused and lethargic. Family reports worsening headache. PMH: Atrial fibrillation on warfarin (INR usually 2-3). Vitals: BP 158/88, HR 72. Exam: Lethargic, GCS 13, confused, left-sided weakness 3/5, pupils equal and reactive. CT head: Right-sided acute-on-chronic subdural hematoma, 1.5cm thickness, 8mm midline shift. INR: 3.8. Assessment: Acute subdural hematoma on chronic subdural with coagulopathy. Plan: Reverse anticoagulation (vitamin K, PCC or FFP), neurosurgery consult for evacuation, admit ICU."
    },
    {
        "id": "LOW01", "urgency_gt": "low",
        "primary_hypothesis_gt": "Vitamin D deficiency, corrected with supplementation",
        "note": "50-year-old male presents for vitamin D deficiency follow-up. Was started on vitamin D3 2000 IU daily 3 months ago. Repeat labs show 25-OH vitamin D level increased from 18 to 35 ng/mL (goal 30-50). Reports no bone pain, no muscle weakness. PMH: Vitamin D deficiency. Assessment: Vitamin D insufficiency, now replete. Plan: Continue vitamin D3 1000-2000 IU daily for maintenance, recheck level in 6-12 months, ensure adequate calcium intake, weight-bearing exercise for bone health."
    },
    {
        "id": "LOW02", "urgency_gt": "low",
        "primary_hypothesis_gt": "Polycystic ovary syndrome, stable on medical management",
        "note": "28-year-old female with polycystic ovary syndrome (PCOS) on metformin 1000mg twice daily and combined oral contraceptive presents for follow-up. Reports regular menstrual cycles on OCP, no hirsutism worsening, lost 8 lbs with diet and exercise. Glucose tolerance test normal. Assessment: PCOS, managed with lifestyle and medications. Plan: Continue metformin and OCP, continue weight loss efforts (goal BMI <25), screen for metabolic syndrome annually, follow-up in 6 months."
    },
    {
        "id": "SCR01", "urgency_gt": "low",
        "primary_hypothesis_gt": "Routine breast cancer screening with dense breast tissue",
        "note": "50-year-old female presents for annual mammogram. Last mammogram 1 year ago showed dense breast tissue, otherwise normal. No breast symptoms, no masses palpated. PMH: None. Family history: Maternal aunt with breast cancer at age 65. Vitals: Normal. Exam: Normal breast exam, no masses or lymphadenopathy. Assessment: Routine breast cancer screening, dense breasts. Plan: Screening mammogram, consider supplemental screening (ultrasound or MRI) if extremely dense (category D), breast self-awareness education, continue annual screening."
    },
    {
        "id": "SCR02", "urgency_gt": "low",
        "primary_hypothesis_gt": "Generalized anxiety disorder, mild",
        "note": "28-year-old male presents for anxiety screening. Reports feeling anxious at work, difficulty concentrating, occasional palpitations. No panic attacks. Sleep and appetite normal. PMH: None. Medications: None. Vitals: BP 128/82, HR 88. Exam: Normal. GAD-7 score: 9 (mild anxiety). Assessment: Mild anxiety symptoms, screening positive. Plan: Lifestyle counseling (exercise, sleep hygiene, caffeine reduction, stress management), cognitive behavioral therapy referral, relaxation techniques, consider SSRI if symptoms worsen or persist, follow-up in 1 month."
    },
    {
        "id": "NEP01", "urgency_gt": "medium",
        "primary_hypothesis_gt": "HIV-associated nephropathy (HIVAN)",
        "note": "38-year-old female with HIV (CD4 200, viral load 85,000, non-adherent to ART) presents with 4-week history of progressive lower extremity edema and foamy urine. Reports fatigue. Denies hematuria or dysuria. PMH: HIV diagnosed 3 years ago. Medications: None currently. Vitals: BP 148/92. Exam: Facial edema, significant lower extremity edema. Labs: Creatinine 2.6, albumin 2.0, CD4 200, urinalysis 4+ protein, urine protein 8.5 g/day. Kidney biopsy: Collapsing variant of FSGS. Assessment: HIV-associated nephropathy (HIVAN). Plan: Immediate initiation of antiretroviral therapy, ACE inhibitor, prednisone, nephrology and infectious disease consult."
    },
]

print(f"✅ {len(TEST_NOTES)} real validation notes loaded")
print(f"   high={sum(1 for n in TEST_NOTES if n['urgency_gt']=='high')}, "
      f"medium={sum(1 for n in TEST_NOTES if n['urgency_gt']=='medium')}, "
      f"low={sum(1 for n in TEST_NOTES if n['urgency_gt']=='low')}")

# ---- Field definitions ----
# We score 6 fields. The field names are flexible: any reasonable variant counts.
FIELD_PATTERNS = [
    r'primary[\s_-]*hypothesis',
    r'differential[\s_-]*diagnos',
    r'key[\s_-]*(supporting|symptoms|evidence|findings)',
    r'urgency[\s_-]*(level|classification|rating)?',
    r'tests?[\s_-]*(ordered|recommended|planned|to[\s_-]*order)',
    r'(clinical[\s_-]*)?reasoning',
]
REQUIRED_FIELD_COUNT = len(FIELD_PATTERNS)

def count_fields(text):
    """Count how many of the 6 required fields appear in the output."""
    text_lower = text.lower()
    return sum(1 for pattern in FIELD_PATTERNS if re.search(pattern, text_lower))

def extract_urgency(text):
    """Extract urgency level from model output."""
    text_lower = text.lower()
    # Pattern: 'urgency level: high' or 'urgency: medium' etc.
    m = re.search(r'urgency[\s_-]*(level|classification)?[:\s]+([\w]+)', text_lower)
    if m:
        val = m.group(2).strip()
        if val in ('high', 'medium', 'low'):
            return val
    # Fallback: first of high/medium/low (checked in that order) found after 'urgency' on the same line
    for u in ('high', 'medium', 'low'):
        if re.search(rf'urgency.*\b{u}\b', text_lower):
            return u
    return 'unknown'

def score_output(text, ground_truth_urgency):
    fields = count_fields(text)
    urgency = extract_urgency(text)
    return {
        'fields_present': fields,
        'completeness_pct': round(100 * fields / REQUIRED_FIELD_COUNT),
        'valid_structure': fields >= 5,
        'urgency_detected': urgency,
        'urgency_match': urgency == ground_truth_urgency.lower(),
    }

def run_eval(model, tokenizer, prompt_fn, display_name, decode_fn=None):
    """Run evaluation on all 10 test notes. prompt_fn(note_text) -> prompt string."""
    results = []
    for i, note in enumerate(TEST_NOTES):
        prompt = prompt_fn(note['note'])
        inputs = tokenizer(prompt, return_tensors='pt', truncation=True,
                           max_length=768).to(model.device)
        with torch.no_grad():
            out = model.generate(
                **inputs, max_new_tokens=450, do_sample=False,
                repetition_penalty=1.1,
                pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id
            )
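        # Decode only the newly generated tokens. Decoding the full sequence with
        # skip_special_tokens=True drops the chat-template markers, so the decoded
        # text no longer starts with the prompt and prefix stripping would cut
        # into the answer.
        new_tokens = out[0][inputs['input_ids'].shape[1]:]
        raw = tokenizer.decode(new_tokens, skip_special_tokens=True)
        # The prompt is already removed, so per-model decoders get an empty prompt
        generated = decode_fn(raw, '') if decode_fn else raw.strip()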
        scores = score_output(generated, note['urgency_gt'])
        results.append({'note_id': note['id'], 'urgency_gt': note['urgency_gt'], **scores})
        tick = '✅' if scores['urgency_match'] else '❌'
        print(f"  [{i+1:02d}/10] {note['id']} "
              f"fields={scores['fields_present']}/6 "
              f"urgency={scores['urgency_detected']} {tick}")
    avg_comp = sum(r['completeness_pct'] for r in results) / len(results)
    pct_valid = 100 * sum(1 for r in results if r['valid_structure']) / len(results)
    pct_urg = 100 * sum(1 for r in results if r['urgency_match']) / len(results)
    summary = {
        'model': display_name,
        'avg_completeness': round(avg_comp, 1),
        'pct_valid_structure': round(pct_valid, 1),
        'pct_urgency_correct': round(pct_urg, 1),
        'per_note': results,
    }
    ALL_RESULTS[display_name] = summary
    print(f"\n  📊 {display_name.replace(chr(10),' ')}")
    print(f"     Completeness: {avg_comp:.1f}% | Valid structure: {pct_valid:.1f}% | Urgency acc: {pct_urg:.1f}%")
    return summary

# Standard instruction used for all models except Meditron (base model, raw
# completion prompt) and LoopGuard (uses its own fine-tuning template)
STD_INSTRUCTION = (
    "Extract diagnostic information from this clinical note.\n\n"
    "Clinical Note:\n{note}\n\n"
    "Output ONLY these 6 fields:\n"
    "PRIMARY HYPOTHESIS: [main diagnosis]\n"
    "DIFFERENTIAL DIAGNOSES: [comma-separated alternatives]\n"
    "KEY SUPPORTING EVIDENCE: [comma-separated findings]\n"
    "URGENCY LEVEL: [high/medium/low]\n"
    "TESTS ORDERED: [comma-separated tests]\n"
    "CLINICAL REASONING: [brief explanation]"
)

print("\n✅ Scoring functions and test notes ready")

In [None]:
# ============================================================
# CELL 4: MODEL 1 - Gemma 2 2B-it
# Chat template: <start_of_turn>user/model
# ~10 min
# ============================================================
DN = "Gemma 2 2B\n(general, no medical)"
print(f"\n🔬 [1/8] google/gemma-2-2b-it")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-2b-it", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_gemma(note):
    msgs = [{"role": "user", "content": STD_INSTRUCTION.format(note=note)}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def decode_gemma(raw, prompt):
    if "<start_of_turn>model" in raw:
        text = raw.split("<start_of_turn>model")[-1]
        return text.split("<end_of_turn>")[0].strip()
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_gemma, DN, decode_gemma)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 5: MODEL 2 - Llama 3.2 3B Instruct
# Requires HF token. If not approved yet, skip and come back.
# Chat template: <|begin_of_text|><|start_header_id|> format
# ~10 min
# ============================================================
DN = "Llama 3.2 3B\n(general, no medical)"
print(f"\n🔬 [2/8] meta-llama/Llama-3.2-3B-Instruct")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-3B-Instruct", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_llama(note):
    msgs = [
        {"role": "system", "content": "You are a clinical documentation assistant. Output ONLY the structured format requested."},
        {"role": "user", "content": STD_INSTRUCTION.format(note=note)}
    ]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def decode_llama(raw, prompt):
    # Strip eot tokens, extract assistant content
    raw = raw.replace("<|eot_id|>", "").strip()
    # Split on assistant header
    parts = re.split(r'<\|start_header_id\|>assistant<\|end_header_id\|>', raw, flags=re.IGNORECASE)
    if len(parts) > 1:
        return parts[-1].strip()
    # Fallback: strip prompt
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_llama, DN, decode_llama)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 6: MODEL 3 - BioMistral 7B
# Mistral [INST]...[/INST] format, NO system message
# Note: loads pytorch_model.bin (not safetensors); the auto-conversion
# error raised in a background thread is harmless and the model loads fine
# ~12 min
# ============================================================
DN = "BioMistral 7B\n(medical, PubMed pretrain)"
print(f"\n🔬 [3/8] BioMistral/BioMistral-7B")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained("BioMistral/BioMistral-7B")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "BioMistral/BioMistral-7B", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_biomistral(note):
    # No system message for BioMistral
    content = STD_INSTRUCTION.format(note=note)
    msgs = [{"role": "user", "content": content}]
    try:
        return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    except Exception:
        return f"<s>[INST] {content} [/INST]"

def decode_biomistral(raw, prompt):
    if "[/INST]" in raw:
        return raw.split("[/INST]")[-1].strip()
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_biomistral, DN, decode_biomistral)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 7: MODEL 4 - Meditron 7B (EPFL)
# BASE model, no instruction tuning. Raw completion format.
# Pretrained on PubMed + 46K clinical guidelines (closest to
# clinical data without task fine-tuning)
# ~12 min
# ============================================================
DN = "Meditron 7B\n(medical, guidelines pretrain)"
print(f"\n🔬 [4/8] epfl-llm/meditron-7b")
print(f"   VRAM available: {vram_free():.2f} GB")
print("   Base model only: using raw completion format (no chat template)")

tokenizer = AutoTokenizer.from_pretrained("epfl-llm/meditron-7b")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "epfl-llm/meditron-7b", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Meditron is a base model: prompt as raw completion ending on first field label
# so the model continues the structured format
MEDITRON_PROMPT = (
    "### Clinical Note:\n{note}\n\n"
    "### Diagnostic Extraction:\n"
    "PRIMARY HYPOTHESIS:"
)

def prompt_meditron(note):
    return MEDITRON_PROMPT.format(note=note)

def decode_meditron(raw, prompt):
    # Strip prompt, then prepend the field label we used as prompt suffix
    generated = raw[len(prompt):].strip() if raw.startswith(prompt) else raw.strip()
    return "PRIMARY HYPOTHESIS:" + generated

run_eval(model, tokenizer, prompt_meditron, DN, decode_meditron)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 8: MODEL 5 - Qwen2.5 7B Instruct
# ChatML <|im_start|> format via apply_chat_template
# ~12 min
# ============================================================
DN = "Qwen2.5 7B\n(general, best <10B)"
print(f"\n🔬 [5/8] Qwen/Qwen2.5-7B-Instruct")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-7B-Instruct", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_qwen(note):
    msgs = [
        {"role": "system", "content": "You are a clinical documentation assistant. Output ONLY the structured format requested. No preamble."},
        {"role": "user", "content": STD_INSTRUCTION.format(note=note)}
    ]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def decode_qwen(raw, prompt):
    # Keep only the assistant turn if the ChatML markers survive decoding;
    # otherwise fall back to stripping the prompt prefix
    if "<|im_start|>assistant" in raw:
        text = raw.split("<|im_start|>assistant")[-1]
        return text.replace("<|im_end|>", "").strip()
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_qwen, DN, decode_qwen)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 9: MODEL 6 - DeepSeek-R1 Distill Qwen 7B
# Chat format applied via apply_chat_template (Qwen2.5-based tokenizer)
# Per DeepSeek docs: NO system prompt, all in user turn
# Produces <think>...</think> reasoning blocks, stripped before scoring
# ~15 min
# ============================================================
DN = "DeepSeek-R1 7B\n(reasoning, no medical)"
print(f"\n🔬 [6/8] deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-7B")
model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_deepseek(note):
    # No system message per DeepSeek-R1 official docs
    msgs = [{"role": "user", "content": STD_INSTRUCTION.format(note=note)}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def decode_deepseek(raw, prompt):
    if "<|im_start|>assistant" in raw:
        text = raw.split("<|im_start|>assistant")[-1]
        text = text.replace("<|im_end|>", "").strip()
    else:
        text = raw[len(prompt):].strip()
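    # Strip <think>...</think> reasoning blocks
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL).strip()
    # If the opening <think> is missing (e.g. it was part of the prompt or was
    # cut by prefix stripping), keep only the text after the last closing tag
    if '</think>' in text:
        text = text.split('</think>')[-1].strip()
    return text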

run_eval(model, tokenizer, prompt_deepseek, DN, decode_deepseek)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 10: MODEL 7 - Base MedGemma 1.5 4B-it (zero-shot)
# Same base model as LoopGuard, so the gap isolates the effect of fine-tuning
# ~10 min
# ============================================================
DN = "Base MedGemma 4B\n(medical, zero-shot)"
print(f"\n🔬 [7/8] google/medgemma-1.5-4b-it (zero-shot)")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained("google/medgemma-1.5-4b-it")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(
    "google/medgemma-1.5-4b-it", quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model.eval()
print(f"✅ Loaded | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

def prompt_medgemma(note):
    msgs = [{"role": "user", "content": STD_INSTRUCTION.format(note=note)}]
    return tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)

def decode_gemma(raw, prompt):
    if "<start_of_turn>model" in raw:
        text = raw.split("<start_of_turn>model")[-1]
        return text.split("<end_of_turn>")[0].strip()
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_medgemma, DN, decode_gemma)
nuke_vram(model, tokenizer)

In [None]:
# ============================================================
# CELL 11: MODEL 8 - LoopGuard v2 (fine-tuned MedGemma)
# Load base + LoRA adapter from Kaggle model or /kaggle/working
# ADAPTER_DIR: update path if you uploaded adapter as Kaggle model
# ~10 min
# ============================================================
BASE_ID = "google/medgemma-1.5-4b-it"
# Update this path if you uploaded adapter as a Kaggle model:
ADAPTER_DIR = "/kaggle/working/medgemma-hypothesis-extraction-v2"
# Fallback: check Kaggle model mount
if not os.path.exists(ADAPTER_DIR):
    import glob
    candidates = glob.glob("/kaggle/input/**/adapter_config.json", recursive=True)
    if candidates:
        ADAPTER_DIR = os.path.dirname(candidates[0])
        print(f"   Found adapter at: {ADAPTER_DIR}")
    else:
        raise FileNotFoundError(f"Adapter not found at {ADAPTER_DIR} or under /kaggle/input. Upload medgemma-loopguard-v2 as Kaggle model and add to notebook.")

DN = "LoopGuard v2\n(fine-tuned MedGemma)"
print(f"\n🔬 [8/8] LoopGuard v2")
print(f"   Base: {BASE_ID}")
print(f"   Adapter: {ADAPTER_DIR}")
print(f"   VRAM available: {vram_free():.2f} GB")

tokenizer = AutoTokenizer.from_pretrained(BASE_ID)
tokenizer.pad_token = tokenizer.eos_token
base = AutoModelForCausalLM.from_pretrained(
    BASE_ID, quantization_config=BNB_CONFIG,
    device_map="auto", dtype=torch.bfloat16)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)
model.eval()
print(f"✅ Loaded with adapter | VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB")

# Use the exact prompt template from fine-tuning
LOOPGUARD_TMPL = (
    "<start_of_turn>user\n"
    "Extract diagnostic information from this clinical note.\n\n"
    "Clinical Note:\n{note}<end_of_turn>\n"
    "<start_of_turn>model\n"
)

def prompt_loopguard(note):
    return LOOPGUARD_TMPL.format(note=note)

def decode_loopguard(raw, prompt):
    if "<start_of_turn>model" in raw:
        text = raw.split("<start_of_turn>model")[-1]
        return text.split("<end_of_turn>")[0].strip()
    return raw[len(prompt):].strip()

run_eval(model, tokenizer, prompt_loopguard, DN, decode_loopguard)
nuke_vram(model, base, tokenizer)

In [None]:
# ============================================================
# CELL 12: Score Table
# ============================================================
ORDERED_MODELS = [
    "Gemma 2 2B\n(general, no medical)",
    "Llama 3.2 3B\n(general, no medical)",
    "BioMistral 7B\n(medical, PubMed pretrain)",
    "Meditron 7B\n(medical, guidelines pretrain)",
    "Qwen2.5 7B\n(general, best <10B)",
    "DeepSeek-R1 7B\n(reasoning, no medical)",
    "Base MedGemma 4B\n(medical, zero-shot)",
    "LoopGuard v2\n(fine-tuned MedGemma)",
]

print("\n" + "=" * 82)
print(f"{'MODEL':<40} {'COMPLETENESS':>13} {'VALID STRUCT':>13} {'URGENCY ACC':>13}")
print("=" * 82)
for m in ORDERED_MODELS:
    if m not in ALL_RESULTS:
        print(f"{m.replace(chr(10),' '):<40} {'NOT RUN':>41}")
        continue
    s = ALL_RESULTS[m]
    name = m.replace('\n', ' ')
    mark = " ✅" if "LoopGuard" in m else ""
    print(f"{name:<40} {s['avg_completeness']:>12.1f}% {s['pct_valid_structure']:>12.1f}% {s['pct_urgency_correct']:>12.1f}%{mark}")
print("=" * 82)

if "LoopGuard v2\n(fine-tuned MedGemma)" in ALL_RESULTS:
    ft = ALL_RESULTS["LoopGuard v2\n(fine-tuned MedGemma)"]
    others = [v for k, v in ALL_RESULTS.items() if "LoopGuard" not in k]
    if others:
        best_comp = max(v['avg_completeness'] for v in others)
        best_urg  = max(v['pct_urgency_correct'] for v in others)
        print(f"\nLoopGuard vs best competitor:")
        print(f"  Completeness: {ft['avg_completeness']:.1f}% vs {best_comp:.1f}% ({ft['avg_completeness']-best_comp:+.1f}pp)")
        print(f"  Urgency acc:  {ft['pct_urgency_correct']:.1f}% vs {best_urg:.1f}% ({ft['pct_urgency_correct']-best_urg:+.1f}pp)")

# Save
with open('/kaggle/working/eval_comparison_8models.json', 'w') as f:
    json.dump({k: {"summary": {"model": v["model"],
                               "avg_completeness": v["avg_completeness"],
                               "pct_valid_structure": v["pct_valid_structure"],
                               "pct_urgency_correct": v["pct_urgency_correct"]},
                   "per_note": v["per_note"]} for k, v in ALL_RESULTS.items()}, f, indent=2)
print("\n✅ Saved → /kaggle/working/eval_comparison_8models.json")

In [None]:
# ============================================================
# CELL 13: Bar Chart → eval_comparison.png
# ============================================================
models_present = [m for m in ORDERED_MODELS if m in ALL_RESULTS]
if not models_present:
    print("No results to plot yet.")
else:
    summaries = [ALL_RESULTS[m] for m in models_present]
    labels = list(models_present)  # display names already contain the line break

    def get_color(m):
        if "LoopGuard" in m: return '#28a745'
        if any(x in m for x in ["MedGemma", "BioMistral", "Meditron"]): return '#4a90d9'
        return '#9e9e9e'
    colors = [get_color(m) for m in models_present]

    metrics = [
        ("avg_completeness",    "Field Completeness (%)",  "Mean share of the 6 fields present"),
        ("pct_valid_structure", "Valid Structure (%)",      "≥5/6 fields present"),
        ("pct_urgency_correct", "Urgency Accuracy (%)",     "Matches ground truth"),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    fig.patch.set_facecolor('#f8f9fa')

    for ax, (key, title, desc) in zip(axes, metrics):
        values = [s[key] for s in summaries]
        x = np.arange(len(labels))
        bars = ax.bar(x, values, color=colors, width=0.65, zorder=3,
                      edgecolor='white', linewidth=0.5)
        ax.set_xticks(x)
        ax.set_xticklabels(labels, fontsize=7)
        ax.set_ylim(0, 118)
        ax.set_ylabel("Score (%)", fontsize=10)
        ax.set_title(title, fontsize=12, fontweight='bold', pad=10)
        ax.set_facecolor('#ffffff')
        ax.grid(axis='y', alpha=0.35, zorder=0)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        for bar, val in zip(bars, values):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1.5,
                    f'{val:.0f}%', ha='center', va='bottom', fontsize=8.5, fontweight='bold')
        ax.text(0.5, -0.20, desc, transform=ax.transAxes,
                ha='center', fontsize=8, color='#666', style='italic')

    legend_items = [
        mpatches.Patch(color='#9e9e9e', label='General model (no medical pretraining)'),
        mpatches.Patch(color='#4a90d9', label='Medical model (no task fine-tuning)'),
        mpatches.Patch(color='#28a745', label='LoopGuard v2: MedGemma + clinical note fine-tuning'),
    ]
    fig.legend(handles=legend_items, loc='upper center',
               bbox_to_anchor=(0.5, 1.04), ncol=3, fontsize=9, frameon=True)
    fig.suptitle(
        'LoopGuard v2 vs 7 Comparable Open-Source Models (all ≤7B, locally deployable)\n'
        '10 Real Clinical Notes · 3 Metrics · Zero-shot except LoopGuard v2',
        fontsize=11, fontweight='bold', y=1.10)

    plt.tight_layout()
    plt.savefig('/kaggle/working/eval_comparison.png', dpi=150,
                bbox_inches='tight', facecolor=fig.get_facecolor())
    print("✅ Chart saved → /kaggle/working/eval_comparison.png")
    plt.show()

In [None]:
# ============================================================
# CELL 14: Writeup-ready text
# ============================================================
print("\n📝 COPY THIS INTO WRITEUP - Technical Details section\n")

print("| Model | Size | Completeness | Valid Structure | Urgency Accuracy |")
print("|-------|------|-------------|-----------------|------------------|")
for m in ORDERED_MODELS:
    if m not in ALL_RESULTS:
        continue
    s = ALL_RESULTS[m]
    name = m.replace('\n', ' ')
    # Extract size from display name
    size = "4B" if "MedGemma" in m or "LoopGuard" in m else ("2B" if "2B" in m else ("3B" if "3B" in m else "7B"))
    mark = " ✅" if "LoopGuard" in m else ""
    print(f"| {name}{mark} | {size} | {s['avg_completeness']:.0f}% | {s['pct_valid_structure']:.0f}% | {s['pct_urgency_correct']:.0f}% |")

print("\n📊 KEY CLAIMS FOR WRITEUP:")
if "LoopGuard v2\n(fine-tuned MedGemma)" in ALL_RESULTS:
    ft = ALL_RESULTS["LoopGuard v2\n(fine-tuned MedGemma)"]
    others = {k: v for k, v in ALL_RESULTS.items() if "LoopGuard" not in k}
    if others:
        best_k = max(others, key=lambda k: others[k]['avg_completeness'])
        print(f"  - LoopGuard completeness: {ft['avg_completeness']:.0f}% vs best competitor ({best_k.replace(chr(10),' ')}): {others[best_k]['avg_completeness']:.0f}%")
        print(f"  - LoopGuard urgency accuracy: {ft['pct_urgency_correct']:.0f}%")
        print(f"  - LoopGuard valid structure rate: {ft['pct_valid_structure']:.0f}%")
        med_models = {k: v for k, v in others.items() if any(x in k for x in ['MedGemma', 'BioMistral', 'Meditron'])}
        if med_models:
            best_med = max(med_models, key=lambda k: med_models[k]['avg_completeness'])
            print(f"  - vs best medical non-FT ({best_med.replace(chr(10),' ')}): completeness {med_models[best_med]['avg_completeness']:.0f}% → FT adds {ft['avg_completeness']-med_models[best_med]['avg_completeness']:+.0f}pp")

print("\n✅ Download: eval_comparison.png + eval_comparison_8models.json")