# 📊 IJHPM Manuscript Validation: 50-Scenario Benchmark

**Purpose:** Run NurseSim-Triage evaluation on 50 standardized clinical scenarios for the IJHPM manuscript.

**Output:** Accuracy metrics, category-level performance, and manuscript-ready tables.

---

In [None]:
# Install dependencies
!pip install -q gradio_client pandas matplotlib

In [None]:
import json
import re
import time
import pandas as pd
import matplotlib.pyplot as plt
from gradio_client import Client
from datetime import datetime

print("✅ Libraries loaded")

## 1. Load Validation Dataset

In [None]:
# Download val.jsonl from GitHub
!wget -q https://raw.githubusercontent.com/ClinyQAi/NurseSim-RL/main/data/val.jsonl -O val.jsonl

# Load scenarios
scenarios = []
with open('val.jsonl', 'r') as f:
    for line in f:
        if line.strip():
            scenarios.append(json.loads(line))

# Use first 50 for validation
scenarios = scenarios[:50]
print(f"✅ Loaded {len(scenarios)} scenarios")

# Show category distribution
cat_counts = {}
for s in scenarios:
    cat = s.get('category', 'Unknown')
    cat_counts[cat] = cat_counts.get(cat, 0) + 1

cat_names = {1:'Immediate', 2:'Very Urgent', 3:'Urgent', 4:'Standard', 5:'Non-Urgent'}
print("\nCategory Distribution:")
for cat in sorted(cat_counts, key=str):  # key=str avoids a TypeError when 'Unknown' is mixed with integer categories
    print(f"  Category {cat} ({cat_names.get(cat, 'Unknown')}): {cat_counts[cat]} cases")

## 2. Connect to NurseSim-Triage Model

In [None]:
# Connect to Hugging Face Space
print("Connecting to NurseSim-Triage...")
try:
    client = Client("NurseCitizenDeveloper/NurseSim-Triage-Demo")
    print("✅ Connected to NurseSim-Triage")
except Exception as e:
    print(f"❌ Connection failed: {e}")
    print("\nTroubleshooting:")
    print("1. Check if the Space is running: https://huggingface.co/spaces/NurseCitizenDeveloper/NurseSim-Triage-Demo")
    print("2. The Space may need to 'wake up' - try refreshing the page first")
    client = None
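
The troubleshooting note above mentions that the Space may need to wake up before it accepts connections. A minimal sketch of one way to handle that automatically is a retry with exponential backoff. The helper name `call_with_retry` and its parameters are hypothetical and not part of `gradio_client`; the delays are illustrative assumptions.

```python
import time


def call_with_retry(fn, attempts=4, base_delay=2.0, sleep=time.sleep):
    """Call fn(); on failure wait base_delay * 2**i seconds, then retry.

    Hypothetical helper: `attempts`, `base_delay`, and `sleep` are
    illustrative parameters, not part of any library API.
    """
    last_error = None
    for i in range(attempts):
        try:
            return fn()
        except Exception as e:  # broad on purpose: wake-up errors vary
            last_error = e
            if i < attempts - 1:
                sleep(base_delay * 2 ** i)
    # Every attempt failed: surface the most recent error
    raise last_error
```

In this notebook it could wrap the connection step, for example `client = call_with_retry(lambda: Client("NurseCitizenDeveloper/NurseSim-Triage-Demo"))`, so that a sleeping Space gets a few chances to start before the run is abandoned.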

In [None]:
def parse_scenario(scenario):
    """Extract vitals from scenario input text"""
    input_text = scenario['input']
    
    # Extract chief complaint
    complaint_match = re.search(r'Chief Complaint: "(.+?)"', input_text)
    complaint = complaint_match.group(1) if complaint_match else input_text[:100]
    
    # Extract vitals
    hr_match = re.search(r'HR: (\d+)', input_text)
    bp_match = re.search(r'BP: ([\d/]+)', input_text)
    spo2_match = re.search(r'SpO2: (\d+)', input_text)
    temp_match = re.search(r'Temp: ([\d.]+)', input_text)
    
    return {
        'complaint': complaint,
        'hr': int(hr_match.group(1)) if hr_match else 80,
        'bp': bp_match.group(1) if bp_match else '120/80',
        'spo2': int(spo2_match.group(1)) if spo2_match else 98,
        'temp': float(temp_match.group(1)) if temp_match else 37.0,
        'expected': scenario.get('category', -1)
    }

def extract_category(response_text):
    """Extract triage category 1-5 from model response"""
    text = str(response_text).lower()

    # Prefer an explicit number ("Category: 2", "category 2") over keywords,
    # so incidental words elsewhere in the response cannot override it
    match = re.search(r'category[:\s]*([1-5])', text)
    if match:
        return int(match.group(1))

    # Keyword fallback; word boundaries stop 'red' matching inside 'required'
    if re.search(r'\bimmediate\b', text) and re.search(r'\bred\b', text):
        return 1
    if re.search(r'\bvery urgent\b', text):
        return 2
    # Plain 'urgent' only: skip it when it is part of 'non-urgent' / 'non urgent'
    if re.search(r'(?<!non-)(?<!non )\burgent\b', text):
        return 3
    if re.search(r'\bstandard\b', text):
        return 4
    if re.search(r'\bnon[- ]urgent\b', text):
        return 5

    return -1

def query_model(parsed):
    """Query NurseSim-Triage model"""
    if client is None:
        return -1, "No client"
    
    try:
        result = client.predict(
            complaint=parsed['complaint'],
            hr=float(parsed['hr']),
            bp=parsed['bp'],
            spo2=float(parsed['spo2']),
            temp=float(parsed['temp']),
            api_name="/gradio_predict"
        )
        return extract_category(str(result)), str(result)[:200]
    except Exception as e:
        return -1, str(e)[:100]

print("✅ Functions ready")
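
`parse_scenario` falls back to normal-range defaults (HR 80, BP 120/80, SpO2 98, Temp 37.0) when a field is not found, so a parsing failure would be sent to the model as a healthy-looking vital sign. The sketch below lists which fields are missing from an input string, using the same regular expressions as above. The helper `find_missing_fields`, the `FIELD_PATTERNS` table, and the sample text are hypothetical and only for illustration.

```python
import re

# Same patterns as parse_scenario, keyed by the field they extract
FIELD_PATTERNS = {
    'complaint': r'Chief Complaint: "(.+?)"',
    'hr': r'HR: (\d+)',
    'bp': r'BP: ([\d/]+)',
    'spo2': r'SpO2: (\d+)',
    'temp': r'Temp: ([\d.]+)',
}


def find_missing_fields(input_text):
    """Return the names of fields whose pattern does not match input_text."""
    return [
        name
        for name, pattern in FIELD_PATTERNS.items()
        if re.search(pattern, input_text) is None
    ]


# Hypothetical scenario text with no temperature recorded
sample = 'Chief Complaint: "Chest pain" HR: 118, BP: 92/58, SpO2: 93'
print(find_missing_fields(sample))  # ['temp']
```

Running this over `scenarios` before the evaluation would show how many cases rely on a default value, which may be worth reporting alongside the accuracy figures.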

## 3. Run Evaluation

In [None]:
print("🔬 Running 50-Scenario Evaluation...\n")
print("="*60)

results = []
for i, scenario in enumerate(scenarios):
    parsed = parse_scenario(scenario)
    predicted, response = query_model(parsed)
    expected = parsed['expected']
    
    match = "✓" if predicted == expected else "✗"
    print(f"[{i+1:2d}/{len(scenarios)}] Expected: {expected} | Predicted: {predicted} {match}")
    
    results.append({
        'scenario_id': i + 1,
        'complaint': parsed['complaint'][:50],
        'expected': expected,
        'predicted': predicted,
        'exact_match': predicted == expected,
        'within_1': abs(predicted - expected) <= 1 if predicted > 0 else False,
        'under_triage': predicted > expected if predicted > 0 else False,
        'over_triage': predicted < expected if predicted > 0 else False
    })
    
    time.sleep(1.5)  # Rate limiting

df = pd.DataFrame(results)
print("\n" + "="*60)
print("✅ Evaluation Complete!")

## 4. Calculate Results

In [None]:
# Filter valid responses
valid = df[df['predicted'] > 0]
n_valid = len(valid)
n_total = len(df)

print("\n" + "="*60)
print("📊 NURSESIM-TRIAGE VALIDATION RESULTS")
print("="*60)
print(f"\nValid Responses: {n_valid}/{n_total} ({n_valid/n_total*100:.0f}%)\n")

# Overall Metrics
exact_accuracy = valid['exact_match'].mean() * 100
within_1_accuracy = valid['within_1'].mean() * 100
under_triage_rate = valid['under_triage'].mean() * 100
over_triage_rate = valid['over_triage'].mean() * 100

print("OVERALL PERFORMANCE:")
print(f"  Exact Match Accuracy:  {valid['exact_match'].sum()}/{n_valid} ({exact_accuracy:.1f}%)")
print(f"  Within ±1 Category:    {valid['within_1'].sum()}/{n_valid} ({within_1_accuracy:.1f}%)")
print(f"  Under-triage Rate:     {valid['under_triage'].sum()}/{n_valid} ({under_triage_rate:.1f}%)")
print(f"  Over-triage Rate:      {valid['over_triage'].sum()}/{n_valid} ({over_triage_rate:.1f}%)")
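
With only 50 scenarios, a point estimate of accuracy carries considerable uncertainty, so a confidence interval may be worth reporting next to each percentage. The sketch below computes a Wilson score interval for a proportion. The function name `wilson_interval` is hypothetical, `z=1.96` assumes a 95% level, and the counts in the example are illustrative rather than results from this evaluation.

```python
import math


def wilson_interval(successes, n, z=1.96):
    """Wilson score interval for a binomial proportion.

    Returns (lower, upper) as fractions in [0, 1].
    z=1.96 corresponds to an approximate 95% confidence level.
    """
    if n == 0:
        return (0.0, 0.0)
    p = successes / n
    denom = 1 + z ** 2 / n
    centre = (p + z ** 2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z ** 2 / (4 * n ** 2)) / denom
    return (max(0.0, centre - half), min(1.0, centre + half))


# Illustrative counts only, not results from this notebook
lower, upper = wilson_interval(40, 50)
print(f"40/50 correct: 95% CI {lower * 100:.1f}% to {upper * 100:.1f}%")
```

In this notebook the same call could take `valid['exact_match'].sum()` and `n_valid` as inputs. The Wilson interval is used here instead of the simpler normal approximation because it behaves better for small samples and for proportions near 0% or 100%.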

In [None]:
# Performance by Category (for manuscript Table 1)
print("\n" + "-"*60)
print("PERFORMANCE BY MTS CATEGORY:")
print("-"*60)

cat_names = {
    1: 'Immediate (Red)',
    2: 'Very Urgent (Orange)', 
    3: 'Urgent (Yellow)',
    4: 'Standard (Green)',
    5: 'Non-Urgent (Blue)'
}

cat_results = []
for cat in [1, 2, 3, 4, 5]:
    subset = valid[valid['expected'] == cat]
    if len(subset) > 0:
        accuracy = subset['exact_match'].mean() * 100
        n = len(subset)
        correct = subset['exact_match'].sum()
        cat_results.append({
            'Category': cat,
            'Name': cat_names.get(cat, 'Unknown'),
            'N': n,
            'Correct': correct,
            'Accuracy': accuracy
        })
        print(f"  Category {cat} ({cat_names.get(cat, 'Unknown')}): {correct}/{n} ({accuracy:.0f}%)")

cat_df = pd.DataFrame(cat_results)
print("\n✅ Category breakdown complete")

In [None]:
# Safety Analysis (Critical for manuscript)
print("\n" + "-"*60)
print("SAFETY ANALYSIS (Critical Category Detection):")
print("-"*60)

# Category 1 (Immediate) - most critical
cat1 = valid[valid['expected'] == 1]
cat1_correct = cat1['exact_match'].sum() if len(cat1) > 0 else 0
cat1_total = len(cat1)
cat1_sensitivity = (cat1_correct / cat1_total * 100) if cat1_total > 0 else 0

# Severe under-triage (predicting Cat 4-5 when actual is Cat 1-2)
critical_cases = valid[valid['expected'].isin([1, 2])]
severe_undertriage = critical_cases[critical_cases['predicted'].isin([4, 5])]
undertriage_rate = (len(severe_undertriage) / len(critical_cases) * 100) if len(critical_cases) > 0 else 0

print(f"  Category 1 Sensitivity: {cat1_correct}/{cat1_total} ({cat1_sensitivity:.0f}%)")
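print(f"  Severe Under-triage (Cat 1-2 → Cat 4-5): {len(severe_undertriage)}/{len(critical_cases)} ({undertriage_rate:.1f}%)")

if len(critical_cases) == 0:
    # A 0% rate is meaningless when no critical cases were evaluated
    print("  ⚠️ No Category 1-2 cases with valid predictions; severe under-triage could not be assessed")
elif len(severe_undertriage) == 0:
    print("  ✅ NO severe under-triage events detected")
else:
    print("  ⚠️ Severe under-triage events require review")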

## 5. Generate Manuscript-Ready Output

In [None]:
# Create visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Chart 1: Accuracy by Category
ax1 = axes[0]
colors = ['#dc2626', '#f97316', '#eab308', '#22c55e', '#3b82f6']
cats = [c['Category'] for c in cat_results]
accs = [c['Accuracy'] for c in cat_results]
bars = ax1.bar(cats, accs, color=[colors[c - 1] for c in cats])  # index by category so colours stay aligned if a category is absent
ax1.set_xlabel('MTS Category')
ax1.set_ylabel('Accuracy (%)')
ax1.set_title('Triage Accuracy by MTS Category')
ax1.set_ylim(0, 100)
ax1.set_xticks([1, 2, 3, 4, 5])
for bar, val in zip(bars, accs):
    ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2, 
             f'{val:.0f}%', ha='center', fontweight='bold')

# Chart 2: Overall Metrics
ax2 = axes[1]
metrics = ['Exact Match', 'Within ±1', 'Under-triage', 'Over-triage']
values = [exact_accuracy, within_1_accuracy, under_triage_rate, over_triage_rate]
colors2 = ['#22c55e', '#3b82f6', '#ef4444', '#f97316']
bars2 = ax2.bar(metrics, values, color=colors2)
ax2.set_ylabel('Percentage (%)')
ax2.set_title('Overall Performance Metrics')
ax2.set_ylim(0, 100)
for bar, val in zip(bars2, values):
    ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 2, 
             f'{val:.1f}%', ha='center', fontweight='bold')

plt.tight_layout()
plt.savefig('ijhpm_validation_results.png', dpi=150, bbox_inches='tight')
plt.show()
print("\n✅ Saved: ijhpm_validation_results.png")

In [None]:
# Generate Markdown Report for Manuscript
report = f"""# NurseSim-Triage Validation Results
**Generated:** {datetime.now().strftime('%Y-%m-%d %H:%M')}
**Dataset:** 50 standardized clinical scenarios from val.jsonl

## Summary

| Metric | Result |
|--------|--------|
| Sample Size | {n_valid} scenarios |
| Exact Match Accuracy | {exact_accuracy:.1f}% |
| Within ±1 Category | {within_1_accuracy:.1f}% |
| Under-triage Rate | {under_triage_rate:.1f}% |
| Over-triage Rate | {over_triage_rate:.1f}% |

## Table 1: Performance by MTS Category

| Category | Description | n | Correct | Accuracy |
|----------|-------------|---|---------|----------|
"""

for c in cat_results:
    report += f"| {c['Category']} | {c['Name']} | {c['N']} | {c['Correct']} | {c['Accuracy']:.0f}% |\n"

report += f"""
## Safety Analysis

| Metric | Result |
|--------|--------|
| Category 1 Sensitivity | {cat1_sensitivity:.0f}% ({cat1_correct}/{cat1_total}) |
| Severe Under-triage (Cat 1-2 → Cat 4-5) | {undertriage_rate:.1f}% ({len(severe_undertriage)}/{len(critical_cases)}) |

## Notes for Manuscript

- **Methodology:** Evaluated on {n_valid} standardized clinical scenarios from a held-out validation set.
- **Ground Truth:** Each scenario was assigned an expected MTS category based on clinical guidelines.
- **Safety Focus:** Under-triage of critical patients (Category 1-2) is penalized more heavily than over-triage.

---
*NurseSim-Triage | IJHPM Manuscript Validation*
"""

print(report)

with open('ijhpm_validation_report.md', 'w') as f:
    f.write(report)
print("\n✅ Saved: ijhpm_validation_report.md")

In [None]:
# Save raw results
df.to_csv('ijhpm_validation_raw.csv', index=False)
print("✅ Saved: ijhpm_validation_raw.csv")

# Download files
print("\n📥 Download these files for your manuscript:")
print("   1. ijhpm_validation_report.md - Results summary")
print("   2. ijhpm_validation_results.png - Charts")
print("   3. ijhpm_validation_raw.csv - Raw data")