# Agentic Model: RATS AI Triage Classifier

Combat Triage AI - Complete Implementation with Quantization and SALT Protocol

Targeted at CPU-only deployment on a Raspberry Pi

## Step 1: Setup

#### Step 1A: Installation

In [1]:
!pip install -U transformers
!pip install optimum[onnxruntime]
!pip install torch torchaudio

Collecting transformers
  Using cached transformers-4.57.1-py3-none-any.whl.metadata (43 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers)
  Using cached tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Using cached transformers-4.57.1-py3-none-any.whl (12.0 MB)
Using cached tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.3 MB)
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.21.4
    Uninstalling tokenizers-0.21.4:
      Successfully uninstalled tokenizers-0.21.4
  Attempting uninstall: transformers
    Found existing installation: transformers 4.55.4
    Uninstalling transformers-4.55.4:
      Successfully uninstalled transformers-4.55.4

#### Step 1B: Imports

In [2]:
import os
import re
import time

import numpy as np
import torch
import torchaudio
from transformers import pipeline, WhisperForConditionalGeneration, AutoProcessor



#### Step 1C: Configuration

In [3]:
MODEL_ID = "openai/whisper-tiny.en"
DEVICE = "cpu"  # Raspberry Pi has no CUDA-capable GPU

## Step 2: Medical Tools

In [4]:
# Combat medical vocabulary, passed to Whisper as a prompt to bias decoding toward these terms
COMBAT_MEDICAL_LEXICON = """
tourniquet, hemorrhage, massive hemorrhage, capillary refill, 
obey commands, airway patent, airway obstructed, 
respirations per minute, respiratory rate, breathing adequately,
radial pulse present, radial pulse absent, carotid pulse,
shock, hypotensive, pale, clammy, cold,
GSW, gunshot wound, blast injury, shrapnel, amputation,
conscious, unconscious, alert, verbal, pain, unresponsive,
chest seal, needle decompression, nasopharyngeal airway,
combat gauze, hemostatic agent, pressure dressing,
walking wounded, litter urgent, urgent surgical,
can walk, cannot walk, ambulatory, unable to walk
"""
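To check whether priming actually surfaces these terms, a transcript can be scanned for them. The sketch below assumes the lexicon is a comma/newline separated string like `COMBAT_MEDICAL_LEXICON`; `parse_lexicon` and `lexicon_hits` are hypothetical helpers, not part of transformers.

```python
import re

def parse_lexicon(lexicon: str) -> list:
    """Split a comma/newline separated vocabulary string into unique lowercase terms."""
    seen = set()
    terms = []
    for raw in lexicon.replace("\n", ",").split(","):
        term = raw.strip().lower()
        if term and term not in seen:  # skip blanks and duplicates, keep first-seen order
            seen.add(term)
            terms.append(term)
    return terms

def lexicon_hits(transcript: str, terms: list) -> list:
    """Return the terms that occur as whole words or phrases in the transcript."""
    text = transcript.lower()
    return [t for t in terms if re.search(r"\b" + re.escape(t) + r"\b", text)]

# Usage with a small stand-in vocabulary
terms = parse_lexicon("tourniquet, GSW,\ncan walk, tourniquet")
print(terms)  # ['tourniquet', 'gsw', 'can walk']
print(lexicon_hits("Tourniquet applied, patient can walk", terms))  # ['tourniquet', 'can walk']
```

Whole-word matching keeps "can walk" from matching inside "cannot walk".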

## Step 3: Model Quantization

In [5]:
# ========== MODEL LOADING WITH QUANTIZATION ==========
print("Loading and quantizing model...")

# Load model and processor separately
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID)
processor = AutoProcessor.from_pretrained(MODEL_ID)

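# Dynamic int8 quantization: nn.Linear weights are stored as 8-bit integers
# (about 4x smaller than float32 for those layers) and activations are quantized
# on the fly. Embeddings and convolutions stay float32; any speedup depends on the CPU.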
quantized_model = torch.quantization.quantize_dynamic(
    model, 
    {torch.nn.Linear},  # Quantize linear layers
    dtype=torch.qint8   # Use 8-bit integers
)

# Create ASR pipeline with quantized model AND processor components
asr = pipeline(
    "automatic-speech-recognition",
    model=quantized_model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    device=DEVICE,
    chunk_length_s=30,
    stride_length_s=5,
    return_timestamps=True
)

print(f"Model loaded and quantized successfully")
print(f"Device: {DEVICE}")

Loading and quantizing model...


Device set to use cpu


Model loaded and quantized successfully
Device: cpu
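Where the "4x smaller" figure comes from can be seen with plain NumPy. The sketch below applies a generic per-tensor symmetric int8 scheme (scale = max|w| / 127) to one weight matrix. It is an illustration, not PyTorch's internal code: PyTorch's observers may choose the scale and range slightly differently, `quantize_dynamic` also quantizes activations at run time, and the saving applies only to the `nn.Linear` weights.

```python
import numpy as np

def quantize_int8_symmetric(w: np.ndarray):
    """Per-tensor symmetric int8 quantization: w is approximated by scale * q."""
    max_abs = float(np.abs(w).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # map the largest |w| to 127
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Recover an approximate float32 tensor from int8 values."""
    return q.astype(np.float32) * scale

# A random 384x384 matrix, the shape of one whisper-tiny attention projection
rng = np.random.default_rng(0)
w = rng.standard_normal((384, 384)).astype(np.float32)
q, scale = quantize_int8_symmetric(w)
max_err = float(np.abs(dequantize(q, scale) - w).max())

print(w.nbytes / q.nbytes)  # 4.0: float32 uses 4 bytes per weight, int8 uses 1
print(max_err <= scale / 2 + 1e-5)  # rounding error stays within half a quantization step
```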


## Step 4: Audio Processing

In [6]:
def preprocess_combat_audio(audio_path):
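    """Clean up noisy battlefield audio.

    Note: transcribe() passes the file path straight to the pipeline, so this
    function is not applied below unless it is called explicitly.
    """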
    wav, sr = torchaudio.load(audio_path)
    
    # Resample to 16kHz (Whisper requirement)
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    
    # Convert to mono if stereo
    if wav.shape[0] > 1:
        wav = torch.mean(wav, dim=0, keepdim=True)
    
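    # High-pass filter at 200 Hz to attenuate low-frequency wind/vehicle rumble
    wav = torchaudio.functional.highpass_biquad(wav, 16000, cutoff_freq=200)
    
    # Peak-normalize to [-1, 1]; this evens out loudness but cannot repair
    # clipping (e.g., from gunfire) already present in the recording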
    max_val = wav.abs().max()
    if max_val > 0:
        wav = wav / max_val
    
    return wav, 16000
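`preprocess_combat_audio` returns a `(channels, samples)` tensor, while `transcribe` passes a file path to the pipeline. The transformers ASR pipeline also accepts in-memory audio as a dict with `"raw"` and `"sampling_rate"` keys, so one way to feed the filtered audio in is the hypothetical helper below. It is a sketch; only the dict format comes from the pipeline.

```python
import numpy as np

def to_pipeline_input(wav, sr: int) -> dict:
    """Convert a (channels, samples) waveform (NumPy array or CPU tensor) into the
    {"raw": 1-D float32 array, "sampling_rate": sr} form accepted by the ASR pipeline."""
    arr = np.asarray(wav, dtype=np.float32)
    if arr.ndim == 2:
        arr = arr.mean(axis=0)  # collapse the channel axis to a mono 1-D signal
    return {"raw": arr, "sampling_rate": sr}

# Usage in the notebook (assumes the asr pipeline built in Step 3):
# wav, sr = preprocess_combat_audio(AUDIO)
# result = asr(to_pipeline_input(wav, sr), return_timestamps=True)
```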

In [7]:
def simple_vad_chunks(wav_path, min_speech_len=0.6):
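    """Energy-based voice activity check (placeholder).

    Both branches currently return the original path, so no silence is
    removed yet; the check only counts samples above an energy threshold.
    """
    wav, sr = torchaudio.load(wav_path)
    
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    
    # Per-sample energy averaged across channels (not frame-based)
    energy = wav.pow(2).mean(dim=0)
    threshold = energy.mean() * 0.3
    
    voiced_mask = energy > threshold
    if voiced_mask.sum() < 16000 * min_speech_len:
        return [wav_path]  # Too little speech detected: fall back to the full file
    
    # Enough speech detected; chunking is not implemented, so return the full file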
    return [wav_path]
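`simple_vad_chunks` never actually splits the audio. A frame-based version that returns voiced segments as (start, end) times could look like the sketch below. The 25 ms frame length is an assumption; the 0.3 threshold ratio and 0.6 s minimum length mirror the values above. It works on a mono NumPy array, so a torchaudio tensor would first need `wav[0].numpy()`.

```python
import numpy as np

def energy_vad_segments(wav, sr=16000, frame_ms=25, threshold_ratio=0.3, min_speech_len=0.6):
    """Return (start_sec, end_sec) for runs of frames whose energy exceeds
    threshold_ratio * mean frame energy and that last at least min_speech_len."""
    frame_len = int(sr * frame_ms / 1000)
    n_frames = len(wav) // frame_len
    if n_frames == 0:
        return []
    # Drop the trailing partial frame and compute mean energy per frame
    frames = np.asarray(wav[: n_frames * frame_len], dtype=np.float64).reshape(n_frames, frame_len)
    energy = (frames ** 2).mean(axis=1)
    voiced = energy > threshold_ratio * energy.mean()

    # Collect runs of consecutive voiced frames as [start, end) frame indices
    runs, start = [], None
    for i, is_voiced in enumerate(voiced):
        if is_voiced and start is None:
            start = i
        elif not is_voiced and start is not None:
            runs.append((start, i))
            start = None
    if start is not None:
        runs.append((start, n_frames))

    frame_sec = frame_len / sr
    return [(s * frame_sec, e * frame_sec) for s, e in runs
            if (e - s) * frame_sec >= min_speech_len]

# Usage: 1 s silence, 1 s tone, 1 s silence
sr = 16000
t = np.arange(sr) / sr
signal = np.concatenate([np.zeros(sr), 0.5 * np.sin(2 * np.pi * 440 * t), np.zeros(sr)])
print(energy_vad_segments(signal, sr))  # one segment from about 1.0 s to 2.0 s
```

Each segment could then be sliced out and passed to the pipeline as in-memory audio instead of re-reading the whole file.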

In [8]:
# ========== TRANSCRIPTION ==========
def transcribe(path: str) -> dict:
    """Transcribe audio with medical vocabulary priming"""
    
    # Build prompt token IDs from the lexicon to bias decoding toward medical terms
    prompt_ids = processor.get_prompt_ids(text=COMBAT_MEDICAL_LEXICON, return_tensors="pt").to(DEVICE)
    
    # Clear any old forced decoder IDs
    try:
        asr.model.generation_config.forced_decoder_ids = None
    except AttributeError:
        pass
    
    result = asr(
        path,
        generate_kwargs={
            "prompt_ids": prompt_ids,
            "temperature": 0.0,
            "num_beams": 5,
            "do_sample": False,
        },
        return_timestamps=True
    )
    
    return result

def transcribe_with_vad(path):
    """Transcribe with voice activity detection"""
    out = {"text": "", "chunks": []}
    
    for chunk in simple_vad_chunks(path):
        r = transcribe(chunk)
        out["text"] = (out["text"] + " " + r["text"]).strip()  # keep a space between chunks
        if "chunks" in r:
            out["chunks"].extend(r["chunks"])
    
    return out

# ========== ENTITY EXTRACTION ==========
def extract_triage_entities(transcription_text):
    """Extract SALT-relevant medical information from transcription"""
    text_lower = transcription_text.lower()
    
    entities = {
        "can_walk": None,
        "bleeding_severe": False,
        "obeys_commands": None,
        "resp_rate": None,
        "radial_pulse": None,
        "mental_status": None,
        "cap_refill_sec": None  # not extracted yet; always None
    }
    
    evidence = []
    
    
    # Walking ability: check negative phrases first, because some positive phrases
    # are substrings of negative ones ("able to walk" in "unable to walk",
    # "walking" in "not walking")
    walk_yes = ["can walk", "walking", "ambulatory", "able to walk"]
    walk_no = ["cannot walk", "can't walk", "unable to walk", "not walking"]
    
    for phrase in walk_no:
        if phrase in text_lower:
            entities["can_walk"] = False
            evidence.append(f"Not walking: '{phrase}' detected")
            break
    else:  # no negative phrase found
        for phrase in walk_yes:
            if phrase in text_lower:
                entities["can_walk"] = True
                evidence.append(f"Walking: '{phrase}' detected")
                break
    
    # Severe bleeding (no negation handling: "no hemorrhage" also matches)
    bleeding_phrases = ["severe bleeding", "hemorrhage", "massive hemorrhage", 
                       "tourniquet applied", "massive bleeding", "heavy bleeding"]
    for phrase in bleeding_phrases:
        if phrase in text_lower:
            entities["bleeding_severe"] = True
            evidence.append(f"Severe bleeding: '{phrase}' detected")
            break
    
    # Command response: check negative phrases first ("responding" is a
    # substring of "not responding")
    obey_yes = ["obeys commands", "follows commands", "responsive to commands", "responding"]
    obey_no = ["does not obey", "doesn't obey", "unresponsive", "no response", 
               "not responding", "not obeying"]
    
    for phrase in obey_no:
        if phrase in text_lower:
            entities["obeys_commands"] = False
            evidence.append(f"Does not obey: '{phrase}' detected")
            break
    else:  # no negative phrase found
        for phrase in obey_yes:
            if phrase in text_lower:
                entities["obeys_commands"] = True
                evidence.append(f"Obeys commands: '{phrase}' detected")
                break
    
    # Respiratory rate extraction (digits only; spelled-out numbers such as "thirty" are not matched)
    resp_patterns = [
        r'(\d+)\s*breaths?\s*(?:per\s*minute)?',
        r'(\d+)\s*respirations?\s*(?:per\s*minute)?',
        r'respiratory\s*rate\s*(?:of\s*)?(\d+)',
        r'breathing\s*(?:at\s*)?(\d+)',
        r'(\d+)\s*rpm'
    ]
    
    for pattern in resp_patterns:
        match = re.search(pattern, text_lower)
        if match:
            entities["resp_rate"] = int(match.group(1))
            evidence.append(f"Respiratory rate: {match.group(1)} detected")
            break
    
    # Radial pulse
    pulse_yes = ["radial pulse present", "has radial pulse", "pulse present"]
    pulse_no = ["no radial pulse", "radial pulse absent", "no pulse"]
    
    for phrase in pulse_yes:
        if phrase in text_lower:
            entities["radial_pulse"] = True
            evidence.append(f"Radial pulse: '{phrase}' detected")
            break
    
    for phrase in pulse_no:
        if phrase in text_lower:
            entities["radial_pulse"] = False
            evidence.append(f"No radial pulse: '{phrase}' detected")
            break
    
    # Mental status (AVPU). Use specific phrases: a bare "pain" would also
    # match complaints such as "my leg is in pain"
    if re.search(r"\balert\b", text_lower) and "not alert" not in text_lower:
        entities["mental_status"] = "alert"
        evidence.append("Mental status: alert")
    elif "verbal" in text_lower:
        entities["mental_status"] = "verbal"
        evidence.append("Mental status: verbal")
    elif "responds to pain" in text_lower or "response to pain" in text_lower:
        entities["mental_status"] = "pain"
        evidence.append("Mental status: pain")
    elif "unresponsive" in text_lower:
        entities["mental_status"] = "unresponsive"
        evidence.append("Mental status: unresponsive")
    
    return entities, evidence
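The phrase lists above match substrings without regard to negation, so "no hemorrhage" still sets `bleeding_severe`. A crude window-based negation check could look like the sketch below; the cue list, the three-word window, and the function name are assumptions, not part of the original pipeline.

```python
import re

# Hypothetical negation cues; extend for the phrasing medics actually use
NEGATION_CUES = ("no", "not", "without", "denies", "negative for")

def phrase_status(text: str, phrase: str, window: int = 3):
    """Return "present", "negated", or None for the first occurrence of phrase.

    An occurrence counts as negated when a cue appears among the `window`
    words immediately before it."""
    text = text.lower()
    match = re.search(r"\b" + re.escape(phrase.lower()) + r"\b", text)
    if match is None:
        return None
    preceding = " ".join(re.findall(r"[a-z']+", text[: match.start()])[-window:])
    for cue in NEGATION_CUES:
        if re.search(r"\b" + re.escape(cue) + r"\b", preceding):
            return "negated"
    return "present"

print(phrase_status("No sign of hemorrhage", "hemorrhage"))           # negated
print(phrase_status("Massive hemorrhage, left thigh", "hemorrhage"))  # present
print(phrase_status("Airway patent", "hemorrhage"))                   # None
```

In `extract_triage_entities`, a "negated" result could leave the field at its default instead of setting it to True.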

## Step 5: Implement SALT Protocol

In [9]:
# ========== SALT TRIAGE RULES ==========
def salt_rules(entities, sensors=None):
    """
    Simplified SALT (Sort, Assess, Lifesaving interventions, Treatment/Transport) triage
    
    Single-pass approximation of SALT, not the full algorithm: e.g., walking
    status maps straight to Minimal instead of only setting assessment order.
    
    Categories:
    - Immediate (Red): Life-threatening injuries, needs immediate care
    - Delayed (Yellow): Serious injuries, can wait for treatment
    - Minimal (Green): Minor injuries, walking wounded
    - Expectant (Gray): Unlikely to survive given current resources (not assigned here)
    - Dead (Black): Not breathing
    """
    s = sensors or {}
    
    def first_known(*values):
        # Return the first value that is not None, so False and 0 are kept
        return next((v for v in values if v is not None), None)
    
    # Merge sensor data (plain `or` would turn a reported False/0 into None)
    can_walk = first_known(entities.get("can_walk"), s.get("can_walk"))
    severe_bleed = bool(entities.get("bleeding_severe") or s.get("bleeding_detected"))
    resp = first_known(entities.get("resp_rate"), s.get("resp_rate"))
    obeys = first_known(entities.get("obeys_commands"), s.get("obeys_commands"))
    radial_pulse = first_known(entities.get("radial_pulse"), s.get("radial_pulse"))
    
    # SALT Algorithm
    # Step 1: Can the patient walk?
    if can_walk is True:
        return "Minimal"
    
    # Step 2: Assess for life-threatening bleeding
    if severe_bleed:
        return "Immediate"
    
    # Step 3: Check respirations
    if resp is None:
        return "Unknown"  # Need more data
    
    if resp == 0:
        return "Dead"  # Not breathing (SALT checks this after opening the airway)
    
    if resp >= 30:
        return "Immediate"  # Proxy for respiratory distress (START uses > 30/min)
    
    # Step 4: Check mental status / obeys commands
    if obeys is False:
        return "Immediate"  # Altered mental status
    
    # Step 5: Check radial pulse (perfusion)
    if radial_pulse is False:
        return "Immediate"  # Poor perfusion
    
    # Default: injuries present but stable
    return "Delayed"
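For comparison, the individual-assessment step of the published SALT algorithm (applied after lifesaving interventions) can be written as a small decision function. This is a sketch: each argument is a clinician judgment that the notebook does not currently extract (for example, "likely to survive given current resources"), and the parameter names are illustrative.

```python
def salt_individual_assessment(
    breathing: bool,
    obeys_or_purposeful: bool,
    peripheral_pulse: bool,
    respiratory_distress: bool,
    hemorrhage_controlled: bool,
    minor_injuries_only: bool,
    likely_to_survive: bool,
) -> str:
    """Sketch of SALT individual assessment; breathing is judged after opening the airway."""
    if not breathing:
        return "Dead"
    # All four checks must pass for Minimal/Delayed
    stable = (
        obeys_or_purposeful
        and peripheral_pulse
        and not respiratory_distress
        and hemorrhage_controlled
    )
    if stable:
        return "Minimal" if minor_injuries_only else "Delayed"
    # Any failed check: Immediate if salvageable with current resources
    return "Immediate" if likely_to_survive else "Expectant"

print(salt_individual_assessment(
    breathing=True, obeys_or_purposeful=True, peripheral_pulse=True,
    respiratory_distress=True, hemorrhage_controlled=True,
    minor_injuries_only=False, likely_to_survive=True,
))  # Immediate
```

Unlike `salt_rules`, this version never maps walking directly to a category and only reaches Expectant through an explicit survivability judgment.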

In [10]:
# ========== CONFIDENCE & NEXT QUESTION ==========
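def calculate_confidence(entities):
    """Fraction of entity fields holding a positive finding (data completeness, not model confidence).
    
    False values are not counted because bleeding_severe defaults to False when
    nothing is said; as a side effect, a confirmed "no" (e.g., can_walk=False)
    does not raise the score either.
    """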
    total_fields = len(entities)
    filled_fields = sum(1 for v in entities.values() if v is not None and v is not False)
    return filled_fields / total_fields

def suggest_next_question(entities):
    """Ask medic for missing critical SALT info"""
    
    if entities["can_walk"] is None:
        return "Can the patient walk?"
    
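    # NOTE: extract_triage_entities initializes bleeding_severe to False (not None),
    # so this question is currently never asked; a None default would enable it
    if entities.get("bleeding_severe") is None:
        return "Is there severe bleeding or hemorrhage?"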
    
    if entities["resp_rate"] is None:
        return "What is the respiratory rate per minute?"
    
    if entities["obeys_commands"] is None:
        return "Does the patient obey commands?"
    
    if entities["radial_pulse"] is None:
        return "Is there a radial pulse present?"
    
    return None  # All critical data collected

# ========== SENSOR FUSION ==========
def fuse_sensor_data(audio_entities, drone_sensors):
    """Combine voice transcription with drone sensor data"""
    final_entities = audio_entities.copy()
    
    # Sensor data fills in (or overrides) uncertain voice data
    if drone_sensors.get("thermal_bleeding_detected") is not None:
        # bleeding_severe defaults to False, so this applies unless audio reported severe bleeding
        if not audio_entities["bleeding_severe"]:
            final_entities["bleeding_severe"] = drone_sensors["thermal_bleeding_detected"]
    
    if drone_sensors.get("movement_detected") is not None:
        # Movement is used as a proxy for walking ability
        if audio_entities["can_walk"] is None:
            final_entities["can_walk"] = drone_sensors["movement_detected"]
    
    if drone_sensors.get("heart_rate") is not None:
        # Estimate respiratory rate from heart rate if not available
        if audio_entities["resp_rate"] is None:
            # Heuristic only: assumes a ~4:1 pulse-to-respiration ratio. The estimate
            # can decide the category (HR 120 -> RR 30 -> Immediate in Test 2)
            final_entities["resp_rate"] = int(drone_sensors["heart_rate"] / 4)
    
    return final_entities
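Because an estimated respiratory rate can decide the triage category, it may help to track where each fused value came from, so a downstream step can ask the medic to confirm estimates. The sketch below applies the same rules as `fuse_sensor_data` and additionally returns a `sources` dict; the function name and source labels are hypothetical.

```python
def fuse_with_provenance(audio_entities: dict, drone_sensors: dict):
    """Apply the same rules as fuse_sensor_data, but also return a dict that
    records where each fused value came from."""
    fused = dict(audio_entities)
    sources = {k: ("audio" if v is not None else "missing") for k, v in audio_entities.items()}

    thermal = drone_sensors.get("thermal_bleeding_detected")
    if thermal is not None and not audio_entities.get("bleeding_severe"):
        fused["bleeding_severe"] = thermal
        sources["bleeding_severe"] = "sensor:thermal"

    movement = drone_sensors.get("movement_detected")
    if movement is not None and audio_entities.get("can_walk") is None:
        fused["can_walk"] = movement
        sources["can_walk"] = "sensor:movement (proxy for walking)"

    heart_rate = drone_sensors.get("heart_rate")
    if heart_rate is not None and audio_entities.get("resp_rate") is None:
        fused["resp_rate"] = int(heart_rate / 4)
        sources["resp_rate"] = "estimated:heart_rate/4"

    return fused, sources

audio = {"can_walk": None, "bleeding_severe": False, "resp_rate": None, "obeys_commands": None}
sensors = {"thermal_bleeding_detected": False, "movement_detected": False, "heart_rate": 120}
fused, sources = fuse_with_provenance(audio, sensors)
print(fused["resp_rate"], sources["resp_rate"])  # 30 estimated:heart_rate/4
```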



## Step 6: Test the model

In [11]:
# ========== COMPLETE TRIAGE PIPELINE ==========
def triage_patient(audio_path, sensor_data=None):
    """
    Complete combat triage pipeline
    
    Args:
        audio_path: Path to audio file of medic assessment
        sensor_data: Optional dict of drone sensor readings
        
    Returns:
        Full triage assessment with category and confidence
    """
    start_time = time.time()
    
    print(f"\n{'='*60}")
    print(f"TRIAGE ASSESSMENT INITIATED")
    print(f"{'='*60}\n")
    
    # Step 1: Transcribe audio
    print("📝 Transcribing audio...")
    transcription = transcribe_with_vad(audio_path)
    print(f"✓ Transcription: {transcription['text'][:100]}...")
    
    # Step 2: Extract medical entities
    print("\n🔍 Extracting medical information...")
    entities, evidence = extract_triage_entities(transcription["text"])
    
    # Step 3: Fuse with sensor data if available
    if sensor_data:
        print("🤖 Fusing with sensor data...")
        entities = fuse_sensor_data(entities, sensor_data)
    
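    # Step 4: Apply SALT triage rules (sensor readings are already merged into
    # entities; the raw drone keys differ from the keys salt_rules looks up)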
    print("\n🏥 Applying SALT triage protocol...")
    triage_category = salt_rules(entities, sensor_data)
    
    # Step 5: Calculate confidence and suggest next question
    confidence = calculate_confidence(entities)
    next_question = suggest_next_question(entities)
    
    processing_time = time.time() - start_time
    
    # Format results
    result = {
        "patient_id": os.path.basename(audio_path),
        "transcription": transcription["text"],
        "entities": entities,
        "evidence": evidence,
        "triage_category": triage_category,
        "confidence": confidence,
        "next_question": next_question,
        "processing_time_sec": round(processing_time, 2),
        "timestamp": transcription.get("chunks", [])
    }
    
    # Print results
    print(f"\n{'='*60}")
    print(f"TRIAGE RESULTS")
    print(f"{'='*60}")
    print(f"🚑 Category: {triage_category}")
    print(f"📊 Confidence: {confidence*100:.0f}%")
    print(f"⏱️  Processing Time: {processing_time:.2f}s")
    print(f"\n📋 Extracted Information:")
    for key, value in entities.items():
        if value is not None:
            print(f"  • {key}: {value}")
    
    if next_question:
        print(f"\n❓ Recommended Question: {next_question}")
    
    if evidence:
        print(f"\n📝 Evidence:")
        for item in evidence:
            print(f"  • {item}")
    
    print(f"{'='*60}\n")
    
    return result

# ========== TESTING ==========
# Test with your audio file
AUDIO = "EnglishTriageTest 1.mp3"

# Example 1: Audio only
print("TEST 1: Audio transcription only")
result1 = triage_patient(AUDIO)

# Example 2: Audio + sensor data
print("\n\nTEST 2: Audio + sensor fusion")
mock_sensor_data = {
    "thermal_bleeding_detected": False,
    "movement_detected": False,
    "heart_rate": 120
}
result2 = triage_patient(AUDIO, sensor_data=mock_sensor_data)

TEST 1: Audio transcription only

TRIAGE ASSESSMENT INITIATED

📝 Transcribing audio...


Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


✓ Transcription: Ow, my leg hurts, and I can't breathe....

🔍 Extracting medical information...

🏥 Applying SALT triage protocol...

TRIAGE RESULTS
🚑 Category: Unknown
📊 Confidence: 0%
⏱️  Processing Time: 4.86s

📋 Extracted Information:
  • bleeding_severe: False

❓ Recommended Question: Can the patient walk?



TEST 2: Audio + sensor fusion

TRIAGE ASSESSMENT INITIATED

📝 Transcribing audio...


Whisper did not predict an ending timestamp, which can happen if audio is cut off in the middle of a word. Also make sure WhisperTimeStampLogitsProcessor was used during generation.


✓ Transcription: Ow, my leg hurts, and I can't breathe....

🔍 Extracting medical information...
🤖 Fusing with sensor data...

🏥 Applying SALT triage protocol...

TRIAGE RESULTS
🚑 Category: Immediate
📊 Confidence: 14%
⏱️  Processing Time: 2.32s

📋 Extracted Information:
  • can_walk: False
  • bleeding_severe: False
  • resp_rate: 30

❓ Recommended Question: Does the patient obey commands?



## Step 7: Performance metrics of the model

In [12]:
# ========== PERFORMANCE METRICS ==========
print("\n" + "="*60)
print("MODEL PERFORMANCE METRICS")
print("="*60)
print(f"Model: {MODEL_ID}")
print(f"Quantized: Yes (8-bit)")
print(f"Device: {DEVICE}")
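# parameters() skips the packed int8 weights of quantized nn.Linear layers, so this
# counts only the remaining float tensors (embeddings, convolutions, LayerNorms);
# whisper-tiny.en has roughly 39M parameters in total
print(f"Model parameters: {sum(p.numel() for p in asr.model.parameters()) / 1e6:.1f}M")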
print("="*60)


MODEL PERFORMANCE METRICS
Model: openai/whisper-tiny.en
Quantized: Yes (8-bit)
Device: cpu
Model parameters: 21.2M
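The two test runs above took 4.86 s and 2.32 s on the same file, so a single timing is not a stable performance number; the first call likely includes one-time warm-up cost. A small stdlib-only timing helper that discards warm-up calls and reports the median over several runs is sketched below (the function name is hypothetical). Timing the quantized and unquantized pipelines this way on the target Pi would give a measured speedup rather than an assumed one.

```python
import statistics
import time

def benchmark(fn, *args, warmup=1, runs=5, **kwargs):
    """Time fn(*args, **kwargs): discard `warmup` calls, then summarize `runs` timed calls."""
    for _ in range(warmup):
        fn(*args, **kwargs)
    times = []
    for _ in range(runs):
        start = time.perf_counter()
        fn(*args, **kwargs)
        times.append(time.perf_counter() - start)
    return {
        "runs": runs,
        "median_s": statistics.median(times),
        "min_s": min(times),
        "max_s": max(times),
    }

# Usage in the notebook:
# stats = benchmark(transcribe_with_vad, AUDIO, warmup=1, runs=5)
print(benchmark(sum, range(1000), runs=3)["runs"])  # 3
```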
