# KVEC Triage Clinical Evaluation

This notebook evaluates the triage models against a curated dataset of 45 symptom→specialty mappings (41 cases across 15 specialties plus 4 edge cases).

**Models:**
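- **SetFit (ONNX)**: Fast classification (target: roughly 100 ms or less per query) for specialty and condition routing; only the specialty model is scored in this notebook
- **MedGemma (GGUF)**: MedGemma 4B instruction-tuned model, IQ2_XXS-quantized and run with llama.cpp, for richer clinical assessment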

**Metrics:** overall accuracy and average inference time per model, plus per-specialty precision, recall, and F1
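Per-specialty precision, recall, and F1 come from the same expected/predicted label lists as accuracy. A minimal sketch with scikit-learn on toy labels (not results from this notebook) shows the quantities that `classification_report` prints later:

```python
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Toy labels for illustration only
expected = ["Urology", "Urology", "Cardiology", "Neurology"]
predicted = ["Urology", "Cardiology", "Cardiology", "Neurology"]

labels = sorted(set(expected) | set(predicted))
accuracy = accuracy_score(expected, predicted)
precision, recall, f1, support = precision_recall_fscore_support(
    expected, predicted, labels=labels, zero_division=0
)
for label, p, r, f, n in zip(labels, precision, recall, f1, support):
    print(f"{label:<11} precision={p:.2f} recall={r:.2f} f1={f:.2f} n={n}")
print(f"accuracy={accuracy:.2f}")  # 3 of 4 correct -> 0.75
```

Urology's recall is 0.5 because one of its two cases was routed to Cardiology, and that same miss halves Cardiology's precision.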

In [None]:
# Install dependencies
!pip install -q llama-cpp-python onnxruntime transformers huggingface_hub pyyaml pandas scikit-learn tqdm

In [None]:
import os
import json
import time
import yaml
import numpy as np
import pandas as pd
from pathlib import Path
from huggingface_hub import hf_hub_download
from sklearn.metrics import classification_report, confusion_matrix
import warnings
warnings.filterwarnings('ignore')

## 1. Load Test Cases

In [None]:
# Load test cases from YAML
test_cases_yaml = """
test_cases:
  # Urology
  - symptom: "burning pain when I pee, have to urinate frequently"
    expected_specialty: "Urology"
  - symptom: "blood in my urine and lower back pain"
    expected_specialty: "Urology"
  - symptom: "difficulty starting to urinate, weak stream, getting up multiple times at night"
    expected_specialty: "Urology"

  # Cardiology
  - symptom: "chest pain that gets worse with exertion, shortness of breath"
    expected_specialty: "Cardiology"
  - symptom: "heart racing for no reason, palpitations especially at night"
    expected_specialty: "Cardiology"
  - symptom: "swollen ankles and legs, difficulty breathing when lying down"
    expected_specialty: "Cardiology"

  # Gastroenterology
  - symptom: "chest pain after eating, burning in my throat, acid taste"
    expected_specialty: "Gastroenterology"
  - symptom: "trouble swallowing, food getting stuck in my throat"
    expected_specialty: "Gastroenterology"
  - symptom: "severe abdominal cramping, blood in stool, diarrhea for weeks"
    expected_specialty: "Gastroenterology"

  # Behavioral Health
  - symptom: "feeling anxious and can't sleep, constant worry"
    expected_specialty: "Behavioral Health"
  - symptom: "lost interest in activities, feeling hopeless, no energy"
    expected_specialty: "Behavioral Health"
  - symptom: "panic attacks, racing heart, feeling like I'm going to die"
    expected_specialty: "Behavioral Health"

  # Neurology
  - symptom: "severe headaches with visual aura, sensitivity to light"
    expected_specialty: "Neurology"
  - symptom: "numbness and tingling in hands and feet, weakness"
    expected_specialty: "Neurology"
  - symptom: "sudden confusion, difficulty speaking, one side weakness"
    expected_specialty: "Neurology"

  # Dermatology
  - symptom: "itchy red rash on arms that spreads, dry flaky skin"
    expected_specialty: "Dermatology"
  - symptom: "changing mole, irregular borders, getting darker"
    expected_specialty: "Dermatology"
  - symptom: "severe acne, painful cysts on face and back"
    expected_specialty: "Dermatology"

  # Orthopedic Surgery
  - symptom: "lower back pain radiating to leg, numbness in foot"
    expected_specialty: "Orthopedic Surgery"
  - symptom: "knee swelling after injury, can't bear weight"
    expected_specialty: "Orthopedic Surgery"
  - symptom: "shoulder pain, can't raise arm above head, grinding noise"
    expected_specialty: "Orthopedic Surgery"

  # Pulmonology
  - symptom: "chronic cough, wheezing, difficulty breathing especially at night"
    expected_specialty: "Pulmonology"
  - symptom: "shortness of breath getting worse, persistent cough with mucus"
    expected_specialty: "Pulmonology"
  - symptom: "coughing up blood, unexplained weight loss, chest pain"
    expected_specialty: "Pulmonology"

  # Rheumatology
  - symptom: "joint pain in multiple joints, morning stiffness lasting hours"
    expected_specialty: "Rheumatology"
  - symptom: "fatigue, joint pain, butterfly rash on face"
    expected_specialty: "Rheumatology"
  - symptom: "red swollen big toe, sudden severe pain, can't touch it"
    expected_specialty: "Rheumatology"

  # Women's Health
  - symptom: "irregular periods, severe cramping, heavy bleeding"
    expected_specialty: "Women's Health"
  - symptom: "pelvic pain, pain during intercourse, difficulty getting pregnant"
    expected_specialty: "Women's Health"
  - symptom: "hot flashes, night sweats, mood changes, missed periods"
    expected_specialty: "Women's Health"

  # Primary Care
  - symptom: "fatigue, always tired, no energy even with sleep"
    expected_specialty: "Primary Care"
  - symptom: "always thirsty, urinating frequently, unexplained weight loss"
    expected_specialty: "Primary Care"
  - symptom: "sore throat, runny nose, mild cough for a few days"
    expected_specialty: "Primary Care"

  # Pain Management
  - symptom: "chronic pain everywhere, tender points, fibromyalgia diagnosed"
    expected_specialty: "Pain Management"
  - symptom: "severe back pain for years, tried everything, need pain relief"
    expected_specialty: "Pain Management"

  # Oncology
  - symptom: "unexplained weight loss, night sweats, swollen lymph nodes"
    expected_specialty: "Oncology"
  - symptom: "lump in breast, nipple discharge, family history of cancer"
    expected_specialty: "Oncology"

  # Sports Medicine
  - symptom: "runner's knee pain, pain going down stairs, swelling"
    expected_specialty: "Sports Medicine"
  - symptom: "tennis elbow, pain on outside of elbow, worse with gripping"
    expected_specialty: "Sports Medicine"

  # Vascular Medicine
  - symptom: "leg pain when walking, relieved by rest, cold feet"
    expected_specialty: "Vascular Medicine"
  - symptom: "swollen leg, red and warm, pain in calf"
    expected_specialty: "Vascular Medicine"

  # Edge cases
  - symptom: "seeing floaters and flashing lights in vision"
    expected_specialty: "Neurology"
  - symptom: "ear ringing constant, hearing loss, dizziness"
    expected_specialty: "Primary Care"
  - symptom: "hair falling out in patches, brittle nails"
    expected_specialty: "Dermatology"
  - symptom: "difficulty concentrating, memory problems, brain fog"
    expected_specialty: "Neurology"
"""

data = yaml.safe_load(test_cases_yaml)
test_cases = data['test_cases']
print(f"Loaded {len(test_cases)} test cases")
print(f"Specialties: {sorted(set(tc['expected_specialty'] for tc in test_cases))}")
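Because some specialties have only two cases, a single miss moves that specialty's recall by 50 percentage points, so per-specialty counts are worth checking next to any per-specialty score. A small sketch (the helper `summarize_cases` is hypothetical) that works on the `test_cases` list loaded above:

```python
from collections import Counter

def summarize_cases(cases):
    """Count test cases per expected specialty."""
    return Counter(tc["expected_specialty"] for tc in cases)

# Tiny sample in the same shape as the YAML entries above
sample = [
    {"symptom": "burning pain when I pee", "expected_specialty": "Urology"},
    {"symptom": "blood in my urine and lower back pain", "expected_specialty": "Urology"},
    {"symptom": "changing mole, irregular borders", "expected_specialty": "Dermatology"},
]
counts = summarize_cases(sample)
print(counts.most_common())  # [('Urology', 2), ('Dermatology', 1)]
```

In the notebook, `summarize_cases(test_cases).most_common()` lists all 15 specialties with their case counts.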

## 2. Load SetFit ONNX Models

In [None]:
import onnxruntime as ort
from transformers import AutoTokenizer

# Download SetFit models from HuggingFace
print("Downloading SetFit models...")

# Specialty model
specialty_body_path = hf_hub_download("ekim1394/setfit-specialty-onnx", "body/model.onnx")
specialty_head_path = hf_hub_download("ekim1394/setfit-specialty-onnx", "model_head.onnx")
specialty_labels_path = hf_hub_download("ekim1394/setfit-specialty-onnx", "label_mapping.json")

# Condition model (loaded here but not scored in this notebook)
condition_body_path = hf_hub_download("ekim1394/setfit-condition-onnx", "body/model.onnx")
condition_head_path = hf_hub_download("ekim1394/setfit-condition-onnx", "model_head.onnx")
condition_labels_path = hf_hub_download("ekim1394/setfit-condition-onnx", "label_mapping.json")

print("Loading ONNX sessions...")
specialty_body = ort.InferenceSession(specialty_body_path)
specialty_head = ort.InferenceSession(specialty_head_path)
condition_body = ort.InferenceSession(condition_body_path)
condition_head = ort.InferenceSession(condition_head_path)

# Load label mappings
with open(specialty_labels_path) as f:
    specialty_labels = json.load(f)
with open(condition_labels_path) as f:
    condition_labels = json.load(f)

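# Load tokenizer. Assumption: the SetFit body was fine-tuned from all-MiniLM-L6-v2;
# a tokenizer that does not match the body's vocabulary would feed it wrong token IDs
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")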

print(f"Specialty labels: {len(specialty_labels)}")
print(f"Condition labels: {len(condition_labels)}")
print("SetFit models loaded!")
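`classify_with_setfit` hardcodes three body inputs, including `token_type_ids`. ONNX Runtime rejects feed names that the graph does not declare, so an export without `token_type_ids` would fail at `run()`. A minimal sketch (the helper `build_feed` is hypothetical) that builds the feed from the names the session reports via `session.get_inputs()`:

```python
import numpy as np

def build_feed(input_names, encoded):
    """Build an ONNX feed containing only the inputs the graph declares.

    `input_names` would come from [i.name for i in session.get_inputs()];
    `encoded` is a tokenizer output holding NumPy arrays.
    """
    feed = {}
    for name in input_names:
        if name in encoded:
            feed[name] = np.asarray(encoded[name], dtype=np.int64)
        elif name == "token_type_ids":
            # Single-segment input: all zeros, same shape as input_ids
            feed[name] = np.zeros_like(np.asarray(encoded["input_ids"]), dtype=np.int64)
        else:
            raise KeyError(f"model expects input {name!r}, which the tokenizer did not produce")
    return feed

encoded = {"input_ids": np.array([[101, 2023, 102]]), "attention_mask": np.array([[1, 1, 1]])}
feed = build_feed(["input_ids", "attention_mask"], encoded)
print(sorted(feed))  # ['attention_mask', 'input_ids']
```

Usage inside the classifier would be `specialty_body.run(None, build_feed([i.name for i in specialty_body.get_inputs()], inputs))`.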

In [None]:
def mean_pooling(hidden_states, attention_mask):
    """Mean pooling with attention mask"""
    input_mask_expanded = np.expand_dims(attention_mask, -1).astype(np.float32)
    sum_embeddings = np.sum(hidden_states * input_mask_expanded, axis=1)
    sum_mask = np.clip(input_mask_expanded.sum(axis=1), a_min=1e-9, a_max=None)
    return sum_embeddings / sum_mask

def classify_with_setfit(symptom: str) -> dict:
    """Run SetFit specialty classification on a symptom"""
    start_time = time.perf_counter()  # monotonic clock for latency measurement
    
    # Tokenize (fixed-length padding; padded positions are masked out during pooling)
    inputs = tokenizer(symptom, padding="max_length", truncation=True, max_length=128, return_tensors="np")
    
    # Run body model
    body_outputs = specialty_body.run(None, {
        "input_ids": inputs["input_ids"].astype(np.int64),
        "attention_mask": inputs["attention_mask"].astype(np.int64),
        "token_type_ids": np.zeros_like(inputs["input_ids"]).astype(np.int64)
    })
    
    # Mean pooling over real tokens -> one sentence embedding
    embeddings = mean_pooling(body_outputs[0], inputs["attention_mask"])
    
    # Run head model; its first output is treated as raw per-class scores
    # (if the export already emits probabilities, the argmax is unaffected)
    head_outputs = specialty_head.run(None, {"embedding": embeddings.astype(np.float32)})
    logits = head_outputs[0][0]
    
    # Numerically stable softmax: subtract the max so np.exp cannot overflow
    exp_logits = np.exp(logits - np.max(logits))
    probs = exp_logits / exp_logits.sum()
    predicted_idx = int(np.argmax(probs))
    confidence = float(probs[predicted_idx])
    
    # Map index to label (unknown indices silently fall back to Primary Care)
    specialty = specialty_labels.get(str(predicted_idx), "Primary Care")
    
    inference_time = (time.perf_counter() - start_time) * 1000
    
    return {
        "specialty": specialty,
        "confidence": confidence,
        "inference_time_ms": inference_time
    }

# Test it
test_result = classify_with_setfit("burning pain when I pee")
print(f"Test: {test_result}")
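Two numeric steps carry the SetFit prediction: masked mean pooling and a softmax over the head's scores. The toy check below (function names are illustrative, not from the SetFit library) shows that padded positions do not leak into the pooled embedding, and why the softmax should subtract the maximum score before exponentiating: a naive `np.exp` overflows to `inf`, and `inf / inf` is `nan`.

```python
import numpy as np

def masked_mean(hidden_states, attention_mask):
    """Average token embeddings, counting only positions where attention_mask == 1."""
    mask = np.expand_dims(attention_mask, -1).astype(np.float32)
    summed = np.sum(hidden_states * mask, axis=1)
    counts = np.maximum(mask.sum(axis=1), 1e-9)  # avoid division by zero
    return summed / counts

def stable_softmax(logits):
    """Softmax with the max subtracted first; same result, but exp never overflows."""
    exp = np.exp(logits - np.max(logits))
    return exp / exp.sum()

# One sequence of 3 positions with 2-dim embeddings; the last position is padding
hidden = np.array([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]])
mask = np.array([[1, 1, 0]])
print(masked_mean(hidden, mask))  # [[2. 3.]]

# Large scores: the naive formula overflows to inf / inf = nan
logits = np.array([1000.0, 1001.0, 999.0])
with np.errstate(over="ignore", invalid="ignore"):
    naive = np.exp(logits) / np.sum(np.exp(logits))
print(naive)                   # [nan nan nan]
print(stable_softmax(logits))  # approximately [0.2447 0.6652 0.0900]
```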

## 3. Load MedGemma LLM (GGUF)

In [None]:
from llama_cpp import Llama

print("Downloading MedGemma GGUF model... (this may take a few minutes)")
model_path = hf_hub_download(
    "ekim1394/medgemma-4b-iq2_xxs-gguf",
    "medgemma-4b-it-iq2_xxs.gguf"
)

print(f"Loading model from {model_path}...")
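llm = Llama(
    model_path=model_path,
    n_ctx=256,        # prompt plus up to 60 generated tokens must fit in this window
    n_batch=256,
    n_threads=8,
    n_gpu_layers=-1,  # offload all layers to the GPU; only effective in GPU-enabled builds
    verbose=False
)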
print("MedGemma loaded!")

In [None]:
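import re

SPECIALTIES = [
    'Behavioral Health', 'Cardiology', 'Dermatology', 'Gastroenterology',
    'Neurology', 'Oncology', 'Orthopedic Surgery', 'Pain Management',
    'Primary Care', 'Pulmonology', 'Rheumatology', 'Sports Medicine',
    'Urology', 'Vascular Medicine', "Women's Health"
]

def _specialty_pattern(name: str) -> str:
    """'Cardiology' matches any word starting with 'cardiolog' (e.g. 'cardiologist');
    other names must appear as whole phrases."""
    key = re.escape(name.lower())
    return rf"\b{key[:-1]}" if name.lower().endswith("ology") else rf"\b{key}\b"

def find_closest_specialty(text: str) -> str:
    """Match LLM output to a valid specialty (falls back to Primary Care)"""
    text_lower = text.lower()
    # Word-anchored matching, so 'urology' does not fire inside 'neurology';
    # if several specialties are named, the earliest mention wins
    hits = []
    for specialty in SPECIALTIES:
        match = re.search(_specialty_pattern(specialty), text_lower)
        if match:
            hits.append((match.start(), specialty))
    if hits:
        return min(hits)[1]
    # Aliases must start a word; r'gi\b' and r'heart\b' must also end one
    # (so 'cardiologist' is not read as 'gi' and 'heartburn' is not read as 'heart')
    aliases = {
        'mental health': 'Behavioral Health', 'psychiatr': 'Behavioral Health',
        r'heart\b': 'Cardiology', 'cardiac': 'Cardiology',
        'skin': 'Dermatology', r'gi\b': 'Gastroenterology',
        'neuro': 'Neurology', 'brain': 'Neurology',
        'cancer': 'Oncology', 'ortho': 'Orthopedic Surgery',
        'lung': 'Pulmonology', 'breathing': 'Pulmonology',
        'bladder': 'Urology', 'kidney': 'Urology',
        'gynecolog': "Women's Health", 'ob/gyn': "Women's Health"
    }
    for pattern, specialty in aliases.items():
        if re.search(r'\b' + pattern, text_lower):
            return specialty
    return 'Primary Care'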

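def classify_with_llm(symptom: str) -> dict:
    """Run MedGemma triage on a symptom"""
    start_time = time.perf_counter()
    
    # Gemma chat turn format. No literal <bos>: llama-cpp-python already prepends
    # the BOS token when tokenizing, so writing it here would double it.
    prompt = f"""<start_of_turn>user
Triage: "{symptom}"
Format: SPECIALTY|CONFIDENCE|URGENCY<end_of_turn>
<start_of_turn>model
"""
    
    response = llm(
        prompt,
        max_tokens=60,
        temperature=0.1,
        stop=["<end_of_turn>", "\n\n"]  # Gemma ends a turn with <end_of_turn>, not </s>
    )
    
    text = response["choices"][0]["text"].strip()
    specialty = find_closest_specialty(text)
    
    # Read the CONFIDENCE field (2nd of SPECIALTY|CONFIDENCE|URGENCY) when present,
    # matching whole words so 'thigh' or 'flow' are not read as high/low
    fields = text.lower().split("|")
    conf_text = fields[1] if len(fields) > 1 else fields[0]
    conf_words = [w.strip(".,;:*") for w in conf_text.split()]
    confidence = 0.8  # Default when no confidence word is found
    if "high" in conf_words:
        confidence = 0.9
    elif "low" in conf_words:
        confidence = 0.6
    
    inference_time = (time.perf_counter() - start_time) * 1000
    
    return {
        "specialty": specialty,
        "confidence": confidence,
        "inference_time_ms": inference_time,
        "raw_output": text
    }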

# Test it
test_result = classify_with_llm("burning pain when I pee")
print(f"Test: {test_result}")
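The prompt asks for `SPECIALTY|CONFIDENCE|URGENCY`, but `classify_with_llm` returns only a specialty and a coarse confidence, discarding urgency. A sketch of explicit field parsing follows; the helper names `parse_triage_line` and `confidence_to_float` are hypothetical, the 0.9/0.6/0.8 values mirror the notebook's own high/low/default mapping, and treating "medium"/"moderate" as 0.8 or accepting numeric and percentage confidences are assumptions about what the model may emit.

```python
import re

def parse_triage_line(text: str) -> dict:
    """Split the first line of 'SPECIALTY|CONFIDENCE|URGENCY' output into fields.

    Missing or empty fields come back as None instead of a guessed value.
    """
    lines = text.strip().splitlines()
    parts = [p.strip(" *") for p in lines[0].split("|")] if lines else []
    parts += [""] * (3 - len(parts))  # pad when the model drops fields
    specialty, confidence, urgency = (p or None for p in parts[:3])
    return {"specialty": specialty, "confidence": confidence, "urgency": urgency}

def confidence_to_float(field, default=0.8):
    """Map 'HIGH'/'LOW', '0.85' or '85%' to a float; unknown values get `default`."""
    if not field:
        return default
    word = field.lower()
    named = {"high": 0.9, "medium": 0.8, "moderate": 0.8, "low": 0.6}
    if word in named:
        return named[word]
    match = re.fullmatch(r"(\d+(?:\.\d+)?)\s*(%?)", word)
    if match:
        value = float(match.group(1))
        return value / 100 if match.group(2) or value > 1 else value
    return default

fields = parse_triage_line("Urology | HIGH | Routine")
print(fields)  # {'specialty': 'Urology', 'confidence': 'HIGH', 'urgency': 'Routine'}
print(confidence_to_float(fields["confidence"]), confidence_to_float("85%"))  # 0.9 0.85
```

Inside `classify_with_llm`, `find_closest_specialty(fields["specialty"] or text)` would keep keyword matching as a fallback while `fields["urgency"]` could be returned alongside the specialty.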

## 4. Run Evaluation

In [None]:
from tqdm import tqdm

# Run SetFit evaluation
print("=" * 50)
print("Running SetFit Evaluation...")
print("=" * 50)

setfit_results = []
for tc in tqdm(test_cases):
    result = classify_with_setfit(tc['symptom'])
    setfit_results.append({
        'symptom': tc['symptom'],
        'expected': tc['expected_specialty'],
        'predicted': result['specialty'],
        'confidence': result['confidence'],
        'inference_ms': result['inference_time_ms'],
        'correct': result['specialty'] == tc['expected_specialty']
    })

setfit_df = pd.DataFrame(setfit_results)
setfit_accuracy = setfit_df['correct'].mean()
setfit_avg_time = setfit_df['inference_ms'].mean()

print(f"\nSetFit Results:")
print(f"  Accuracy: {setfit_accuracy:.1%}")
print(f"  Avg Inference: {setfit_avg_time:.1f}ms")
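The introduction quotes a SetFit latency of roughly 100 ms. A mean over 45 calls can hide a slow first call or occasional stalls, so tail percentiles give a fairer check of that target. A sketch (the helper `latency_summary` is hypothetical) using NumPy's default linear-interpolation percentiles:

```python
import numpy as np

def latency_summary(times_ms):
    """Mean plus median, 95th percentile and max of per-call latencies (ms)."""
    t = np.asarray(times_ms, dtype=float)
    return {
        "mean": float(t.mean()),
        "p50": float(np.percentile(t, 50)),
        "p95": float(np.percentile(t, 95)),
        "max": float(t.max()),
    }

# Toy latencies: one slow call (e.g. a cold start) dominates the mean
summary = latency_summary([12, 11, 13, 12, 250])
print(summary)
```

In the notebook this would be `latency_summary(setfit_df['inference_ms'])`, and likewise for `llm_df`.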

In [None]:
# Run LLM evaluation
print("=" * 50)
print("Running MedGemma LLM Evaluation...")
print("=" * 50)

llm_results = []
for tc in tqdm(test_cases):
    result = classify_with_llm(tc['symptom'])
    llm_results.append({
        'symptom': tc['symptom'],
        'expected': tc['expected_specialty'],
        'predicted': result['specialty'],
        'confidence': result['confidence'],
        'inference_ms': result['inference_time_ms'],
        'correct': result['specialty'] == tc['expected_specialty'],
        'raw_output': result['raw_output']
    })

llm_df = pd.DataFrame(llm_results)
llm_accuracy = llm_df['correct'].mean()
llm_avg_time = llm_df['inference_ms'].mean()

print(f"\nMedGemma LLM Results:")
print(f"  Accuracy: {llm_accuracy:.1%}")
print(f"  Avg Inference: {llm_avg_time:.1f}ms")
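With 45 cases, each one moves accuracy by about 2.2 percentage points, so a bare point estimate overstates how precisely accuracy is known. One common option is the Wilson score interval; the sketch below assumes a 95% level (z = 1.96), and the 36/45 count is illustrative, not a result from this notebook:

```python
import math

def wilson_interval(correct: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%)."""
    if n == 0:
        return (0.0, 1.0)
    p = correct / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))

# Illustrative count: 36 of 45 correct (80%)
low, high = wilson_interval(36, 45)
print(f"{36 / 45:.1%} accuracy, 95% CI {low:.1%} to {high:.1%}")
```

For the real runs: `wilson_interval(int(setfit_df['correct'].sum()), len(setfit_df))` and the same for `llm_df`. Overlapping intervals are a hint that an accuracy gap may not be meaningful at this sample size.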

## 5. Generate Report

In [None]:
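# Comparison summary. The Fallback (Keyword) row is hardcoded: the keyword
# fallback is not evaluated in this notebook.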
print("\n" + "=" * 60)
print("CLINICAL EVALUATION SUMMARY")
print("=" * 60)

comparison = pd.DataFrame({
    'Model': ['Fallback (Keyword)', 'SetFit (ONNX)', 'MedGemma (LLM)'],
    'Accuracy': ['47%', f'{setfit_accuracy:.1%}', f'{llm_accuracy:.1%}'],
    'Avg Inference': ['0.02ms', f'{setfit_avg_time:.1f}ms', f'{llm_avg_time:.1f}ms'],
    'Use Case': ['Offline backup', 'Fast routing', 'Rich assessment']
})
print(comparison.to_string(index=False))
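Both models are scored on the same 45 cases, so the informative numbers are the cases where they disagree. A sketch (helper names hypothetical) that cross-tabulates per-case correctness and applies an exact McNemar test, which is a two-sided binomial test on the discordant pairs only:

```python
from math import comb

def paired_counts(a_correct, b_correct):
    """Cross-tabulate per-case correctness of two models on the same cases."""
    pairs = [(bool(a), bool(b)) for a, b in zip(a_correct, b_correct)]
    return {
        "both_right": sum(a and b for a, b in pairs),
        "only_a": sum(a and not b for a, b in pairs),
        "only_b": sum(b and not a for a, b in pairs),
        "both_wrong": sum(not a and not b for a, b in pairs),
    }

def mcnemar_exact_p(only_a, only_b):
    """Two-sided exact McNemar p-value (binomial test on discordant pairs, p=0.5)."""
    n = only_a + only_b
    if n == 0:
        return 1.0
    k = min(only_a, only_b)
    tail = sum(comb(n, i) for i in range(k + 1)) / 2**n
    return min(1.0, 2 * tail)

# Toy per-case correctness (not results from this notebook)
setfit_ok = [True, True, True, False, True]
llm_ok = [True, False, False, False, True]
paired = paired_counts(setfit_ok, llm_ok)
print(paired)  # {'both_right': 2, 'only_a': 2, 'only_b': 0, 'both_wrong': 1}
print(mcnemar_exact_p(paired["only_a"], paired["only_b"]))  # 0.5
```

With the real frames, `paired_counts(setfit_df['correct'], llm_df['correct'])` works because both follow `test_cases` order. With only a handful of discordant cases the test has little power, which is itself worth reporting.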

In [None]:
# Per-specialty breakdown for SetFit
print("\n" + "=" * 60)
print("SetFit Per-Specialty Performance")
print("=" * 60)

print(classification_report(
    setfit_df['expected'], 
    setfit_df['predicted'],
    zero_division=0
))

In [None]:
# Per-specialty breakdown for LLM
print("\n" + "=" * 60)
print("MedGemma Per-Specialty Performance")
print("=" * 60)

print(classification_report(
    llm_df['expected'], 
    llm_df['predicted'],
    zero_division=0
))
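`confusion_matrix` is imported at the top but never used. A labeled confusion matrix shows *which* specialties get confused, which matters here because the dataset deliberately includes chest-pain cases for both Cardiology and Gastroenterology. A sketch with toy labels (the helper `confusion_frame` is hypothetical):

```python
import pandas as pd
from sklearn.metrics import confusion_matrix

def confusion_frame(expected, predicted):
    """Confusion matrix as a DataFrame: rows = expected, columns = predicted."""
    labels = sorted(set(expected) | set(predicted))
    matrix = confusion_matrix(expected, predicted, labels=labels)
    return pd.DataFrame(matrix, index=labels, columns=labels)

# Toy labels: one Gastroenterology case routed to Cardiology
expected = ["Cardiology", "Gastroenterology", "Gastroenterology"]
predicted = ["Cardiology", "Cardiology", "Gastroenterology"]
cm_df = confusion_frame(expected, predicted)
print(cm_df)
```

For the real runs, `confusion_frame(setfit_df['expected'], setfit_df['predicted'])` and the `llm_df` equivalent; off-diagonal cells are the misroutes.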

In [None]:
# Show misclassifications
print("\n" + "=" * 60)
print("SetFit Misclassifications")
print("=" * 60)

misses = setfit_df[~setfit_df['correct']][['symptom', 'expected', 'predicted']]
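for _, row in misses.iterrows():
    # Truncate long symptoms, adding '...' only when text was actually cut
    symptom = row['symptom'] if len(row['symptom']) <= 50 else row['symptom'][:50] + '...'
    print(f"\n• '{symptom}'")
    print(f"  Expected: {row['expected']} → Got: {row['predicted']}")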
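The cell above lists only SetFit misses. For the LLM, the raw model output is the most useful clue, since a miss can come from the model or from the keyword matching in `find_closest_specialty`. A sketch (helper names hypothetical) that works for either results frame:

```python
import pandas as pd

def shorten(text: str, width: int = 50) -> str:
    """Truncate long text, adding '...' only when something was cut."""
    return text if len(text) <= width else text[:width] + "..."

def misclassified(df: pd.DataFrame, extra_cols=()) -> pd.DataFrame:
    """Rows the model got wrong, plus optional extra columns such as raw_output."""
    return df.loc[~df["correct"], ["symptom", "expected", "predicted", *extra_cols]]

# Toy frame in the same shape as llm_df (values are illustrative)
demo = pd.DataFrame({
    "symptom": ["chest pain after eating, burning in my throat, acid taste", "itchy red rash"],
    "expected": ["Gastroenterology", "Dermatology"],
    "predicted": ["Cardiology", "Dermatology"],
    "correct": [False, True],
    "raw_output": ["Cardiology|HIGH|URGENT", "Dermatology|MEDIUM|ROUTINE"],
})
misses_demo = misclassified(demo, extra_cols=["raw_output"])
for _, row in misses_demo.iterrows():
    print(f"• '{shorten(row['symptom'])}'")
    print(f"  Expected: {row['expected']} → Got: {row['predicted']} | raw: {row['raw_output']!r}")
```

On the real results: `misclassified(llm_df, extra_cols=["raw_output"])`.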

In [None]:
# Save results to JSON
results = {
    "evaluation_date": time.strftime("%Y-%m-%d %H:%M:%S"),
    "test_cases": len(test_cases),
    "setfit": {
        "accuracy": float(setfit_accuracy),
        "avg_inference_ms": float(setfit_avg_time),
        "results": setfit_results
    },
    "llm": {
        "accuracy": float(llm_accuracy),
        "avg_inference_ms": float(llm_avg_time),
        "results": [{k: v for k, v in r.items() if k != 'raw_output'} for r in llm_results]
    }
}

with open('evaluation_results_full.json', 'w') as f:
    json.dump(results, f, indent=2)
    
print("Results saved to evaluation_results_full.json")
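The cell above casts the aggregates with `float()`, and the per-case values appear to be Python types already. If a field ever holds a NumPy scalar or array (for example an unconverted `np.argmax` result), `json.dump` raises `TypeError`. A defensive `default=` hook, as a sketch (the name `to_json_compatible` is hypothetical):

```python
import json
import numpy as np

def to_json_compatible(obj):
    """`default=` hook for json.dump: turn NumPy scalars and arrays into Python types."""
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

record = {"confidence": np.float32(0.5), "correct": np.bool_(True), "probs": np.array([0.25, 0.75])}
print(json.dumps(record, default=to_json_compatible))
# {"confidence": 0.5, "correct": true, "probs": [0.25, 0.75]}
```

Usage: `json.dump(results, f, indent=2, default=to_json_compatible)`.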

## 6. Key Findings for Competition

| Model | Accuracy | Inference Time | Best For |
|-------|----------|----------------|----------|
| Fallback (Keywords) | 47% | <1ms | Offline backup |
| SetFit (ONNX) | __% | ~Xms | Fast mobile routing |
| MedGemma (LLM) | __% | ~Xms | Rich clinical assessment |

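**Tiered Architecture Value:**
- SetFit gives fast specialty classification (target: under 100 ms) for immediate routing
- MedGemma enriches the result with clinical context, urgency, and red flags (the evaluation prompt in this notebook requests only specialty, confidence, and urgency)
- Combining them gives instant feedback followed by richer detail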
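One way to wire the tiers together, as a sketch under stated assumptions: the function name `tiered_triage` and the callback are hypothetical, and `fast`/`rich` would be `classify_with_setfit` and `classify_with_llm` from this notebook. The SetFit answer is delivered first; the LLM result is attached afterwards, and an LLM failure leaves the fast answer intact.

```python
from typing import Callable, Optional

def tiered_triage(
    symptom: str,
    fast: Callable[[str], dict],
    rich: Callable[[str], dict],
    on_fast_result: Optional[Callable[[dict], None]] = None,
) -> dict:
    """Deliver the fast classification first, then attach the LLM assessment."""
    fast_result = fast(symptom)
    if on_fast_result is not None:
        on_fast_result(fast_result)  # e.g. show the routing in the UI right away
    try:
        rich_result = rich(symptom)
    except Exception as exc:  # LLM unavailable or failed: keep the fast answer
        return {**fast_result, "enriched": False, "error": str(exc)}
    return {**fast_result, "llm": rich_result, "enriched": True}

# Stub models for illustration
shown = []
result = tiered_triage(
    "burning pain when I pee",
    fast=lambda s: {"specialty": "Urology", "confidence": 0.93},
    rich=lambda s: {"specialty": "Urology", "urgency": "routine"},
    on_fast_result=shown.append,
)
print(shown[0]["specialty"], result["enriched"])  # Urology True
```

This version is synchronous for clarity; in an app the LLM call would typically run in the background after the fast result is shown.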