# 🏥 Triage Category Prediction Benchmark

**Head-to-Head: NurseSim-Triage vs Gemini 3 vs GPT-4o**

Testing which model most accurately predicts triage priority categories from patient presentations.

---

In [None]:
!pip install -q gradio_client google-generativeai openai pandas matplotlib

In [None]:
import json, re, time
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from dataclasses import dataclass
from typing import Dict
from gradio_client import Client
import google.generativeai as genai
import openai
from google.colab import userdata

# Read each API key from Colab's userdata store and hand it to its SDK.
openai_key = userdata.get('OPENAI_API_KEY')
openai.api_key = openai_key
google_key = userdata.get('GOOGLE_API_KEY')
genai.configure(api_key=google_key)
print("‚úÖ Setup complete")

In [None]:
# Test cases with expected triage categories
@dataclass
class TriageCase:
    """A single benchmark scenario: presentation, vital signs and expected category.

    ``expected`` uses the five-level scale
    1=Immediate, 2=Very Urgent, 3=Urgent, 4=Standard, 5=Non-Urgent.
    """

    # Short identifier used to label the case in printed and saved results.
    id: str
    # Free-text presenting complaint.
    complaint: str
    # Vital signs; blood pressure is kept as text such as "160/95".
    hr: int
    bp: str
    spo2: int
    temp: float
    # Reference triage category (1-5) that predictions are scored against.
    expected: int
    # Brief rationale for the expected category.
    reasoning: str

# Benchmark scenarios, three per triage category, grouped from most to least urgent.
# Positional fields follow TriageCase:
#   id, complaint, hr, bp, spo2, temp, expected category, reasoning.
CASES = [
    # IMMEDIATE (1)
    TriageCase("IMM_01", "Crushing chest pain radiating to left arm, sweating, nausea", 110, "160/95", 94, 37.2, 1, "Classic ACS"),
    TriageCase("IMM_02", "Severe headache worst of life, sudden onset, neck stiffness", 88, "150/90", 99, 38.2, 1, "Possible SAH/meningitis"),
    TriageCase("IMM_03", "Unresponsive after seizure, still postictal", 120, "140/85", 92, 37.5, 1, "Post-ictal state"),
    
    # VERY URGENT (2)
    TriageCase("VU_01", "Confusion and productive cough, green sputum, weak", 102, "105/65", 92, 38.9, 2, "Possible sepsis/CAP"),
    TriageCase("VU_02", "Vague malaise 2 days, something is wrong, epigastric discomfort, 78F", 72, "138/84", 96, 36.8, 2, "Atypical MI elderly"),
    TriageCase("VU_03", "Difficulty breathing, worsening over 4 hours, previous asthma", 105, "130/80", 91, 37.0, 2, "Asthma exacerbation"),
    
    # URGENT (3)
    TriageCase("URG_01", "RLQ abdominal pain 12 hours, worsening, vomiting once", 98, "128/82", 98, 38.6, 3, "Possible appendicitis"),
    TriageCase("URG_02", "Non-healing foot wound 2 weeks, redness and discharge, diabetic", 92, "145/88", 97, 37.4, 3, "Diabetic foot infection"),
    TriageCase("URG_03", "Severe back pain sudden onset, radiating to flank", 95, "155/95", 98, 37.2, 3, "Possible renal colic"),
    
    # STANDARD (4)
    TriageCase("STD_01", "Twisted ankle playing football, swelling, can bear weight", 75, "125/80", 99, 36.8, 4, "Ankle sprain"),
    TriageCase("STD_02", "Cut on hand from kitchen knife, bleeding controlled", 78, "120/75", 99, 37.0, 4, "Minor laceration"),
    TriageCase("STD_03", "Earache for 2 days, mild fever, child otherwise well", 90, "100/65", 99, 38.0, 4, "Otitis media"),
    
    # NON-URGENT (5)
    TriageCase("NU_01", "Sore throat 3 days, mild difficulty swallowing", 78, "118/72", 99, 37.8, 5, "Viral pharyngitis"),
    TriageCase("NU_02", "Runny nose and mild cough for 5 days, no fever", 72, "115/70", 99, 36.9, 5, "Common cold"),
    TriageCase("NU_03", "Wants medication refill, no acute symptoms", 70, "120/78", 99, 36.8, 5, "Medication refill"),
]

# Report how many cases were loaded, then the count per expected category.
print(f"‚úÖ {len(CASES)} test cases loaded")
names = {1:'IMMEDIATE', 2:'VERY URGENT', 3:'URGENT', 4:'STANDARD', 5:'NON-URGENT'}
for cat, label in names.items():
    n = sum(c.expected == cat for c in CASES)
    print(f"   {label}: {n} cases")

## 🤖 Connect Models

In [None]:
# NurseSim-Triage
# Open a Gradio client for the hosted NurseSim demo; on failure `nursesim`
# stays None so later cells can detect that the model is unavailable.
nursesim = None
print("Connecting to NurseSim-Triage...")
try:
    nursesim = Client("NurseCitizenDeveloper/NurseSim-Triage-Demo")
except Exception as e:
    print(f"‚ö†Ô∏è NurseSim: {e}")
else:
    print("‚úÖ NurseSim connected")

In [None]:
# Gemini
# Candidate Gemini model names, tried in order; the first one that answers a
# probe request is used for the benchmark.
GEMINI_MODELS = ['gemini-3-pro', 'gemini-3', 'gemini-2.0-flash-exp']
gemini = None
gemini_name = "Gemini"
for m in GEMINI_MODELS:
    try:
        # Bind to a temporary name: `gemini` must stay None unless the probe
        # succeeds, otherwise a failed model would still look "connected".
        candidate = genai.GenerativeModel(m)
        candidate.generate_content("test")
    except Exception as e:
        # `Exception` (not a bare except) so a kernel interrupt still stops the cell.
        print(f"   {m} unavailable: {str(e)[:50]}")
        continue
    gemini = candidate
    gemini_name = m
    print(f"‚úÖ {m}")
    break

if gemini is None:
    print("   Gemini: no candidate model responded")

# GPT
# Pass the key explicitly: a client built with no arguments looks for the
# OPENAI_API_KEY environment variable rather than the module-level
# `openai.api_key` attribute. If that attribute is None the client still falls
# back to the environment variable.
gpt = openai.OpenAI(api_key=openai.api_key)
print("‚úÖ GPT-4o")

In [None]:
# Query functions
# Prompt template for the general-purpose LLMs. It is filled with str.format
# using the keys complaint, hr, bp, spo2 and temp, and asks the model to reply
# with only the category number (1-5).
TRIAGE_PROMPT = """You are a triage nurse. Given this patient, assign a triage category:
1 = IMMEDIATE (life-threatening)
2 = VERY URGENT (severe condition)
3 = URGENT (needs prompt care)
4 = STANDARD (can wait)
5 = NON-URGENT (minor issue)

Patient: {complaint}
HR: {hr}, BP: {bp}, SpO2: {spo2}%, Temp: {temp}¬∞C

Reply with ONLY the number (1-5)."""

def extract_category(text):
    """Extract a triage category (1-5) from free-form model output.

    Category names are looked for first (case-insensitively); when none is
    present, the first standalone digit 1-5 in the text is used.

    Args:
        text: Model response. Any object is accepted and converted with ``str``.

    Returns:
        The category as an int in 1..5, or -1 when nothing recognisable is found.
    """
    text = str(text).lower()
    non_urgent = r'non[\s-]*urgent'

    # Check for category words first
    if 'immediate' in text or 'resuscitation' in text:
        return 1
    if re.search(r'very[\s-]+urgent', text):
        return 2
    # Blank out "non-urgent" phrases before testing for a plain "urgent", so
    # that unrelated words containing "non" ("none", "non-life-threatening")
    # no longer prevent an urgent answer from being recognised.
    if 'urgent' in re.sub(non_urgent, ' ', text):
        return 3
    if 'standard' in text:
        return 4
    if re.search(non_urgent, text) or 'minor' in text:
        return 5

    # Look for number
    match = re.search(r'\b([1-5])\b', text)
    return int(match.group(1)) if match else -1

def query_nursesim(c):
    """Ask the NurseSim Gradio endpoint to triage case ``c``.

    Returns the category parsed from the endpoint's reply by
    ``extract_category``, or -1 when no client is connected or the call fails
    (the first 50 characters of the error are printed in that case).
    """
    if not nursesim:
        return -1
    try:
        # The endpoint is called with the numeric vitals converted to float.
        vitals = {
            'hr': float(c.hr),
            'spo2': float(c.spo2),
            'temp': float(c.temp),
        }
        raw = nursesim.predict(
            complaint=c.complaint,
            bp=c.bp,
            api_name="/gradio_predict",
            **vitals,
        )
        return extract_category(str(raw))
    except Exception as err:
        print(f"   NurseSim error: {str(err)[:50]}")
        return -1

def query_gemini(c):
    """Ask the selected Gemini model to triage case ``c``.

    Returns the category parsed from the model's reply by
    ``extract_category``, or -1 when no Gemini model is available or the
    request fails (the first 50 characters of the error are printed).
    """
    if not gemini:
        return -1
    try:
        prompt = TRIAGE_PROMPT.format(complaint=c.complaint, hr=c.hr, bp=c.bp, spo2=c.spo2, temp=c.temp)
        result = gemini.generate_content(prompt)
        return extract_category(result.text)
    except Exception as e:
        # Catch `Exception` rather than everything so a kernel interrupt still
        # stops the run, and report the failure instead of hiding it.
        print(f"   Gemini error: {str(e)[:50]}")
        return -1

def query_gpt(c):
    """Ask GPT-4o to triage case ``c``.

    Returns the category parsed from the model's reply by
    ``extract_category``, or -1 when the request fails (the first 50
    characters of the error are printed).
    """
    try:
        prompt = TRIAGE_PROMPT.format(complaint=c.complaint, hr=c.hr, bp=c.bp, spo2=c.spo2, temp=c.temp)
        resp = gpt.chat.completions.create(
            model="gpt-4o",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=50
        )
        return extract_category(resp.choices[0].message.content)
    except Exception as e:
        # Catch `Exception` rather than everything so a kernel interrupt still
        # stops the run, and report the failure instead of hiding it.
        print(f"   GPT error: {str(e)[:50]}")
        return -1

# Confirmation that the cell defining the prompt and query helpers has run.
print("‚úÖ Query functions ready")

In [None]:
# Test NurseSim connection
# Smoke test: send the first case to NurseSim and compare with its expected category.
print("Testing NurseSim...")
first_case = CASES[0]
test = query_nursesim(first_case)
print(f"NurseSim returned: {test} (expected: {first_case.expected})")

## 🔬 Run Benchmark

In [None]:
# Query every model on every case and collect one result row per case.
print("üî¨ Running Triage Category Benchmark...\n")
results = []

for c in CASES:
    print(f"{c.id}: Expected={c.expected}", end=" ")

    ns, gm, gp = query_nursesim(c), query_gemini(c), query_gpt(c)
    print(f"| NS={ns} | Gem={gm} | GPT={gp}")

    predictions = {'nursesim': ns, 'gemini': gm, 'gpt': gp}
    prefixed = list(zip(('ns', 'gm', 'gp'), predictions.values()))

    # Column order: case info, raw predictions, exact flags, within-one flags.
    row = {'case': c.id, 'complaint': c.complaint[:40], 'expected': c.expected}
    row.update(predictions)
    row.update({f'{prefix}_exact': pred == c.expected for prefix, pred in prefixed})
    # A non-positive prediction (the -1 failure marker) never counts as "within one".
    row.update({
        f'{prefix}_within1': pred > 0 and abs(pred - c.expected) <= 1
        for prefix, pred in prefixed
    })
    results.append(row)

    # Brief pause between cases.
    time.sleep(0.5)

df = pd.DataFrame(results)
print("\n‚úÖ Complete!")

In [None]:
# Results
# Print per-model accuracy and build `summary` for the chart and report cells.
print("\nüìä TRIAGE CATEGORY PREDICTION ACCURACY")
print("=" * 55)

# (display name, exact-match column, within-one column, prediction column)
models = [
    ('NurseSim-Triage', 'ns_exact', 'ns_within1', 'nursesim'),
    (gemini_name, 'gm_exact', 'gm_within1', 'gemini'),
    ('GPT-4o', 'gp_exact', 'gp_within1', 'gpt')
]

summary = {}
for name, exact_col, within1_col, pred_col in models:
    # Only rows with a positive prediction are scored; -1 marks a failed query.
    valid = df[df[pred_col] > 0]
    total = len(valid)
    if total == 0:
        continue

    exact_flags = valid[exact_col]
    within1_flags = valid[within1_col]
    exact_acc = exact_flags.mean() * 100
    within1_acc = within1_flags.mean() * 100

    summary[name] = {'exact': exact_acc, 'within1': within1_acc, 'n': total}
    print(f"\n{name}:")
    print(f"  Exact Match: {exact_flags.sum()}/{total} ({exact_acc:.1f}%)")
    print(f"  Within ¬±1: {within1_flags.sum()}/{total} ({within1_acc:.1f}%)")

In [None]:
# Visualization
# Two side-by-side bar charts: exact-match accuracy and within-one-category accuracy.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

names = list(summary.keys())
colors = ['#ef4444', '#10b981', '#3b82f6']

# One (summary key, chart title) pair per panel, left to right.
panels = [
    ('exact', 'Triage Category - Exact Match'),
    ('within1', 'Triage Category - Within ¬±1 Category'),
]

for ax, (metric, title) in zip(axes, panels):
    values = [summary[n][metric] for n in names]
    bars = ax.bar(names, values, color=colors[:len(names)])
    ax.set_ylabel('Accuracy %')
    ax.set_title(title)
    ax.set_ylim(0, 100)
    # Label each bar with its rounded percentage just above the bar top.
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2, f'{val:.0f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig('triage_benchmark.png', dpi=150)
plt.show()

In [None]:
# Detailed table
# Per-case table of the expected category next to each model's prediction.
print("\nüìã Detailed Results")
detail_columns = ['case', 'expected', 'nursesim', 'gemini', 'gpt']
detail_table = df[detail_columns].to_string(index=False)
print(detail_table)

In [None]:
# Generate Report
# Build a Markdown report (overall summary plus per-category breakdown),
# print it, and save it to triage_benchmark_report.md.
from datetime import datetime

# Model with the highest exact-match accuracy; "N/A" when `summary` is empty.
winner = max(summary, key=lambda x: summary[x]['exact']) if summary else "N/A"

report = f"""# Triage Category Prediction Benchmark
**Generated**: {datetime.now().strftime('%Y-%m-%d %H:%M')}

## Summary

| Model | Exact Match | Within ¬±1 |
|-------|-------------|------------|
"""
for name in summary:
    star = "‚≠ê" if name == winner else ""
    report += f"| {name} {star} | {summary[name]['exact']:.1f}% | {summary[name]['within1']:.1f}% |\n"

report += f"""
**Winner**: {winner} with {summary.get(winner, {}).get('exact', 0):.1f}% exact match accuracy

## Category Breakdown

| Category | NurseSim | Gemini | GPT |
|----------|----------|--------|-----|
"""

cat_names = {1:'IMMEDIATE', 2:'VERY URGENT', 3:'URGENT', 4:'STANDARD', 5:'NON-URGENT'}
exact_columns = ('ns_exact', 'gm_exact', 'gp_exact')
for cat, label in cat_names.items():
    subset = df[df['expected'] == cat]
    if len(subset) > 0:
        # Exact-match rate over all cases of this category; a missing column counts as 0.
        ns_acc, gm_acc, gp_acc = (
            subset[col].mean()*100 if col in subset else 0
            for col in exact_columns
        )
        report += f"| {label} | {ns_acc:.0f}% | {gm_acc:.0f}% | {gp_acc:.0f}% |\n"

report += "\n---\n*NurseSim-Triage Benchmark | practicedev.cloud*"

print(report)
# The report contains non-ASCII characters, so write it as UTF-8 explicitly
# instead of relying on the platform's default encoding.
with open('triage_benchmark_report.md', 'w', encoding='utf-8') as f:
    f.write(report)
print("\n‚úÖ Saved: triage_benchmark_report.md")

In [None]:
# Persist the per-case results table as CSV, without the index column.
csv_path = 'triage_benchmark_results.csv'
df.to_csv(csv_path, index=False)
print("‚úÖ Saved: triage_benchmark_results.csv")