# 🏥 Triage Category Benchmark (Local Unsloth Model)

**Verify your trained model immediately in Colab.**

This notebook loads your locally trained adapter (`nursesim_lora_llama3_robust`) from **Google Drive** and runs it against the 15 gold-standard test cases.

### ✅ Updated: now uses the exact training prompt and history format.

### Prerequisites
- You must have run the training notebook first.
- The folder `nursesim_lora_llama3_robust` must exist in your Google Drive.

In [None]:
%%capture
# Install dependencies (same as the training notebook).
# NOTE: %%capture hides this cell's output, so install errors are not shown;
# remove it temporarily if a later import fails.
!pip install --upgrade "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install --no-deps trl peft accelerate bitsandbytes xformers
!pip install pandas matplotlib

## 1. Load Your Trained Model from Drive

In [None]:
from unsloth import FastLanguageModel
import torch
import re
import os
from google.colab import drive

# 1. Mount Google Drive so the adapter saved by the training notebook is reachable.
drive.mount('/content/drive')

# 2. Path to your saved LoRA adapter in Drive.
adapter_path = "/content/drive/MyDrive/nursesim_lora_llama3_robust"

print(f"🔄 Loading adapter from: {adapter_path}...")

if not os.path.exists(adapter_path):
    # Fall back to a folder of the same name in the current working directory.
    print(f"⚠️ Warning: Path not found: {adapter_path}")
    print("Trying local folder 'nursesim_lora_llama3_robust' instead...")
    adapter_path = "nursesim_lora_llama3_robust"
    if not os.path.exists(adapter_path):
        # Fail fast: passed a bare, non-existent folder name, from_pretrained
        # would likely try to resolve it as a Hugging Face Hub repo id and fail
        # with a misleading download error instead.
        raise FileNotFoundError(
            "❌ Adapter folder 'nursesim_lora_llama3_robust' not found in Drive "
            "or in the current directory. Did you run the training notebook?"
        )

try:
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = adapter_path,
        max_seq_length = 2048,
        dtype = None,          # None = let Unsloth auto-detect the dtype
        load_in_4bit = True,
    )
except Exception as e:
    print(f"❌ Error loading model: {e}")
    print("Did you run the training notebook? Check if the folder exists.")
    # Re-raise: every later cell needs `model`/`tokenizer`, so continuing would
    # only surface later as a confusing NameError.
    raise

# Switch the model into Unsloth's inference mode before generating.
FastLanguageModel.for_inference(model)
print("✅ Model loaded successfully!")

In [None]:
from dataclasses import dataclass
import pandas as pd
import matplotlib.pyplot as plt

@dataclass
class TriageCase:
    """A single benchmark patient: presentation, vital signs, history and gold answer.

    Field order is significant because the benchmark table constructs
    instances positionally.
    """

    # Identification and presentation
    id: str          # case label, e.g. "IMM_01"
    complaint: str   # chief complaint, shown to the model verbatim

    # Vital signs (units as rendered in the prompt)
    hr: int          # heart rate, bpm
    bp_sys: int      # systolic blood pressure, mmHg
    bp_dia: int      # diastolic blood pressure, mmHg
    spo2: int        # oxygen saturation, %
    rr: int          # respiratory rate, /min
    temp: float      # temperature, degrees C
    avpu: str        # AVPU consciousness level letter, e.g. "A", "V", "P"

    # Background and expected outcome
    history: str     # relevant past medical history, free text
    expected: int    # gold-standard triage category: 1 (immediate) .. 5 (non-urgent)

# The 15 gold-standard benchmark cases (complete clinical data): three per
# triage category, from 1 (IMMEDIATE) to 5 (NON-URGENT).
# Rows are positional and follow TriageCase's field order:
#   (id, complaint, hr, bp_sys, bp_dia, spo2, rr, temp, avpu, history, expected)
CASES = [
    # IMMEDIATE (1)
    TriageCase("IMM_01", "Crushing chest pain radiating to left arm, sweating, nausea", 110, 160, 95, 94, 24, 37.2, "A", "HTN, T2DM, MI 2019", 1),
    TriageCase("IMM_02", "Severe headache worst of life, sudden onset, neck stiffness", 88, 150, 90, 99, 16, 38.2, "A", "Migraine History", 1),
    TriageCase("IMM_03", "Unresponsive after seizure, still postictal", 120, 140, 85, 92, 12, 37.5, "P", "Epilepsy", 1),
    
    # VERY URGENT (2)
    TriageCase("VU_01", "Confusion and productive cough, green sputum, weak", 102, 105, 65, 92, 22, 38.9, "V", "COPD", 2),
    TriageCase("VU_02", "Vague malaise 2 days, something is wrong, epigastric discomfort", 72, 138, 84, 96, 18, 36.8, "A", "HTN", 2),
    TriageCase("VU_03", "Difficulty breathing, worsening over 4 hours", 105, 130, 80, 91, 26, 37.0, "A", "Asthma", 2),
    
    # URGENT (3)
    # NOTE(review): URG_01's history "Appendectomy 22yo (No)" is ambiguous
    # (prior appendectomy or not?) and is sent to the model verbatim — confirm intent.
    TriageCase("URG_01", "RLQ abdominal pain 12 hours, worsening, vomiting once", 98, 128, 82, 98, 18, 38.6, "A", "Appendectomy 22yo (No)", 3),
    TriageCase("URG_02", "Non-healing foot wound 2 weeks, redness and discharge", 92, 145, 88, 97, 16, 37.4, "A", "T2DM", 3),
    TriageCase("URG_03", "Severe back pain sudden onset, radiating to flank", 95, 155, 95, 98, 20, 37.2, "A", "None", 3),
    
    # STANDARD (4)
    TriageCase("STD_01", "Twisted ankle playing football, swelling, can bear weight", 75, 125, 80, 99, 14, 36.8, "A", "None", 4),
    TriageCase("STD_02", "Cut on hand from kitchen knife, bleeding controlled", 78, 120, 75, 99, 12, 37.0, "A", "None", 4),
    TriageCase("STD_03", "Earache for 2 days, mild fever, child otherwise well", 90, 100, 65, 99, 18, 38.0, "A", "None", 4),
    
    # NON-URGENT (5)
    TriageCase("NU_01", "Sore throat 3 days, mild difficulty swallowing", 78, 118, 72, 99, 14, 37.8, "A", "None", 5),
    TriageCase("NU_02", "Runny nose and mild cough for 5 days, no fever", 72, 115, 70, 99, 12, 36.9, "A", "None", 5),
    TriageCase("NU_03", "Wants medication refill, no acute symptoms", 70, 120, 78, 99, 12, 36.8, "A", "T2DM", 5),
]
print(f"Loaded {len(CASES)} test cases.")

In [None]:
# Instruction text copied verbatim from the training data, so the inference
# prompt matches what the adapter was fine-tuned on.
TRAINING_INSTRUCTION = "You are an expert A&E Triage Nurse using the Manchester Triage System. Assess the following patient and provide your triage decision with clinical reasoning."

def format_input(c):
    """Render a case as the text that goes under "### Input:" in the prompt.

    Sections (header, complaint, vitals, history, waiting-room status,
    question) are separated by one blank line. The history is rendered as a
    Python dict repr to mimic the training data's dictionary-style history.

    NOTE(review): the dict's 'note' entry is part of the text the model sees;
    confirm the training inputs carried the same entry.
    """
    pmh = {
        'relevant_PMH': c.history,
        'note': 'History structured as dict to match training data format',
    }
    vital_lines = [
        f"- HR: {c.hr} bpm",
        f"- BP: {c.bp_sys}/{c.bp_dia} mmHg",
        f"- SpO2: {c.spo2}%",
        f"- RR: {c.rr} /min",
        f"- Temp: {c.temp}C",
        f"- AVPU: {c.avpu}",
    ]
    sections = [
        "PATIENT PRESENTING TO A&E TRIAGE",
        f'Chief Complaint: "{c.complaint}"',
        "Vitals:\n" + "\n".join(vital_lines),
        f"History: {pmh}",
        "WAITING ROOM: 12 patients | AVAILABLE BEDS: 4",
        "What is your triage decision?",
    ]
    return "\n\n".join(sections)

def predict(case):
    # Create prompts using Alpaca format
    alpaca_prompt = f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{TRAINING_INSTRUCTION}

### Input:
{format_input(case)}

### Response:
"""
    
    inputs = tokenizer(
        [alpaca_prompt],
        return_tensors="pt",
    ).to("cuda")
    
    outputs = model.generate(**inputs, max_new_tokens=256, use_cache=True)
    response = tokenizer.batch_decode(outputs)[0]
    
    # Extract result
    response_clean = response.split("### Response:")[-1].strip()
    
    # Extract Category Number
    try:
        # Look for "Category: X"
        match = re.search(r"Category:\s*(\d)", response_clean)
        if match:
            return int(match.group(1)), response_clean
        
        # Fallback to looking for numbers
        match = re.search(r'\b([1-5])\b', response_clean)
        return (int(match.group(1)), response_clean) if match else (-1, response_clean)
    except:
        return -1, response_clean

In [None]:
print("üî¨ Running Benchmark on Local Model (Correct Trigger Prompt + History Dict)...\n")
results = []

for c in CASES:
    pred, full_resp = predict(c)
    is_correct = (pred == c.expected)
    icon = "‚úÖ" if is_correct else "‚ùå"
    print(f"{icon} {c.id}: Pred={pred} | Exp={c.expected} | {c.complaint[:40]}...")
    
    results.append({
        'id': c.id,
        'category': c.expected,
        'prediction': pred,
        'correct': is_correct,
        'within_1': abs(pred - c.expected) <= 1 if pred > 0 else False
    })

# Stats
df = pd.DataFrame(results)
acc = df['correct'].mean() * 100
acc_1 = df['within_1'].mean() * 100

print("\n" + "="*40)
print(f"üèÜ FINAL ACCURACY: {acc:.1f}%")
print(f"üéØ WITHIN ¬±1 CAT: {acc_1:.1f}%")
print("="*40)

# Breakdown by Category
for cat in range(1, 6):
    sub = df[df['category'] == cat]
    if len(sub) > 0:
        cat_acc = sub['correct'].mean() * 100
        print(f"Category {cat}: {cat_acc:.0f}%")