# Phase 3 — MedGemma Inference Test (Kaggle GPU)

**Environment**: Kaggle, GPU T4×2 or P100 (≥15 GB VRAM)  
**Model**: `medgemma-4b-it-Q3_K_M.gguf` (~2.1 GB) via `llama-cpp-python`  
**Goal**: Validate that MedGemma correctly maps surgical transcriptions → structured JSON machine states

Steps:
1. Clone OR-SIM repo from GitHub
2. Install llama-cpp-python with CUDA
3. Copy GGUF model from Kaggle dataset
4. Run 10 transcription samples per surgery (30 total)
5. Measure inference latency
6. Validate JSON output correctness

In [None]:
# ── 0. Check GPU ──────────────────────────────────────────────────────────
import subprocess
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv,noheader'],
                        capture_output=True, text=True)
print('GPU:', result.stdout.strip() or 'No GPU found')
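
The check above only prints what `nvidia-smi` reports. A minimal sketch of turning that CSV output into a pass/fail check against the ≥15 GB requirement follows. The helper name `parse_gpu_csv`, the threshold constant, and the sample string are assumptions for illustration (the sample is not captured from a real run), and "15 GB" is read here as 15 GiB.

```python
def parse_gpu_csv(csv_text: str) -> list[tuple[str, int]]:
    """Parse `name, memory.total` CSV rows (noheader) into (name, MiB) pairs."""
    gpus = []
    for line in csv_text.strip().splitlines():
        # Split on the last comma so a GPU name containing commas stays intact
        name, mem = (field.strip() for field in line.rsplit(',', 1))
        gpus.append((name, int(mem.split()[0])))  # "15360 MiB" -> 15360
    return gpus

MIN_VRAM_MIB = 15 * 1024  # hypothetical threshold: "15 GB" interpreted as 15 GiB

# Illustrative sample only; in the notebook, pass result.stdout instead
sample = "Tesla T4, 15360 MiB\nTesla T4, 15360 MiB"
for name, mib in parse_gpu_csv(sample):
    status = 'OK' if mib >= MIN_VRAM_MIB else 'below requirement'
    print(f'{name}: {mib} MiB ({status})')
```

An empty string (no GPU found) yields an empty list, so the caller can fail early before attempting the CUDA build.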

In [None]:
# ── 1. Clone OR-SIM repository ────────────────────────────────────────────
# Replace <YOUR_GITHUB_USERNAME> with your actual username
!git clone https://github.com/<YOUR_GITHUB_USERNAME>/OR-SIM.git /kaggle/working/OR-SIM
%cd /kaggle/working/OR-SIM
!git checkout dev
!git log --oneline -5

In [None]:
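# ── 2. Install llama-cpp-python with CUDA support ─────────────────────────
# Builds from source with the CUDA backend; takes ~3-5 minutes on Kaggle.
# Recent llama-cpp-python releases expect -DGGML_CUDA=on; the older
# -DLLAMA_CUBLAS=on flag is deprecated and may be rejected by current builds.
import os
os.environ['CMAKE_ARGS'] = '-DGGML_CUDA=on'
os.environ['FORCE_CMAKE'] = '1'

!pip install llama-cpp-python --no-cache-dir -q

# Verify the import works and that the build supports GPU offload
import llama_cpp
from llama_cpp import Llama
print('llama-cpp-python', llama_cpp.__version__, 'imported successfully')
print('GPU offload supported:', llama_cpp.llama_supports_gpu_offload())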

In [None]:
# ── 3. Install OR-SIM Python dependencies ────────────────────────────────
!pip install -r requirements/base.txt -q
!pip install -r requirements/llm.txt -q
print('Dependencies installed')

In [None]:
# ── 4. Locate GGUF model ──────────────────────────────────────────────────
# Option A: Model is in a Kaggle dataset attached to this notebook
# Option B: Upload the GGUF file manually to /kaggle/input/
#
# Update MODEL_PATH to point to your GGUF file:
from pathlib import Path

MODEL_PATH = Path('/kaggle/input/medgemma/medgemma-4b-it-Q3_K_M.gguf')

# Fallback: search for any .gguf file
if not MODEL_PATH.exists():
    candidates = list(Path('/kaggle/input').rglob('*.gguf'))
    if candidates:
        MODEL_PATH = candidates[0]
        print(f'Found GGUF at: {MODEL_PATH}')
    else:
        raise FileNotFoundError(
            'No GGUF file found. Attach medgemma dataset to this notebook.'
        )

print(f'Model: {MODEL_PATH}  ({MODEL_PATH.stat().st_size / 1e9:.1f} GB)')

In [None]:
# ── 5. Add OR-SIM to Python path ─────────────────────────────────────────
import sys
sys.path.insert(0, '/kaggle/working/OR-SIM')

from backend.data.surgeries import SurgeryType, MACHINES
from backend.llm.prompt_builder import PromptBuilder
from backend.llm.output_parser  import parse_llm_output
from backend.llm.schemas        import LLMOutput
print('OR-SIM modules imported')

In [None]:
# ── 6. Load MedGemma model ────────────────────────────────────────────────
import time

t0 = time.time()
llm = Llama(
    model_path   = str(MODEL_PATH),
    n_gpu_layers = -1,       # All layers on GPU
    n_ctx        = 4096,
    n_threads    = 4,
    verbose      = False,
)
print(f'Model loaded in {time.time()-t0:.1f}s')

In [None]:
# ── 7. Test transcriptions per surgery ───────────────────────────────────
TEST_CASES = {
    SurgeryType.HEART_TRANSPLANT: [
        'turn on the ventilator',
        'activate bypass pump',
        'turn off the OR lights',
        'we need the defibrillator ready',
        'activate echocardiography',
        'turn off anesthesia machine',
        'turn on cell saver',
        'activate cardiac monitor',
        'turn everything off',
        'ventilator on, bypass pump on, OR lights on',
    ],
    SurgeryType.LIVER_RESECTION: [
        'turn on the laparoscopic tower',
        'activate the ultrasonic dissector',
        'turn off the harmonic scalpel',
        'we need the argon beam',
        'activate the fluoroscopy unit',
        'turn on the cell saver',
        'activate the bipolar electrosurgery unit',
        'turn off the OR lights',
        'activate patient warmer',
        'ultrasonic dissector and laparoscopic tower on',
    ],
    SurgeryType.KIDNEY_PCNL: [
        'turn on the fluoroscopy C-arm',
        'activate the nephroscope tower',
        'turn on the lithotripter',
        'we need the irrigation pump',
        'activate the ultrasound guidance',
        'turn off the electrosurgery unit',
        'turn on the suction unit',
        'activate OR lights',
        'everything off please',
        'fluoroscopy on, lithotripter on, nephroscope on',
    ],
}

print('Test cases defined:', sum(len(v) for v in TEST_CASES.values()), 'total')

In [None]:
# ── 8. Run inference on all test cases ───────────────────────────────────
import json

results = []

for surgery, transcriptions in TEST_CASES.items():
    builder = PromptBuilder(surgery)
    print('\n' + '=' * 60)
    print(f'Surgery: {surgery.value}')
    print('='*60)

    for transcription in transcriptions:
        messages = builder.build_messages(transcription, snapshot=None)

        t0 = time.time()
        response = llm.create_chat_completion(
            messages    = messages,
            max_tokens  = 256,
            temperature = 0.1,
            top_p       = 0.9,
        )
        latency_ms = (time.time() - t0) * 1000

        raw_text = response['choices'][0]['message']['content']
        parsed   = parse_llm_output(raw_text, surgery)

        results.append({
            'surgery':      surgery.value,
            'transcription': transcription,
            'raw_output':   raw_text,
            'parsed_on':    parsed.machine_states['1'],
            'parsed_off':   parsed.machine_states['0'],
            'reasoning':    parsed.reasoning,
            'latency_ms':   latency_ms,
        })

        print(f'\nTranscription : {transcription}')
        print(f'ON            : {parsed.machine_states["1"]}')
        print(f'OFF           : {parsed.machine_states["0"]}')
        # Truncate long reasoning strings to 80 characters for display
        reasoning = parsed.reasoning if len(parsed.reasoning) <= 80 else parsed.reasoning[:80] + '...'
        print(f'Reasoning     : {reasoning}')
        print(f'Latency       : {latency_ms:.0f} ms')

print(f'\nTotal inference calls: {len(results)}')
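
The loop above relies on `parse_llm_output` from the OR-SIM repo to turn raw model text into machine states. Chat models often wrap their JSON in extra prose, so a tolerant extraction step is useful before validation. The sketch below shows one way to do that with the standard library; `extract_json_object` is a hypothetical helper, not the repo's implementation, and the sample payload shape is an assumption based on the `machine_states['1']` / `machine_states['0']` fields used in this notebook.

```python
import json

def extract_json_object(raw_text: str) -> dict:
    """Return the first decodable JSON object embedded in raw_text."""
    decoder = json.JSONDecoder()
    start = raw_text.find('{')
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start`
            # and ignores any trailing text after it
            obj, _end = decoder.raw_decode(raw_text, start)
            return obj
        except json.JSONDecodeError:
            # Not valid JSON from this brace; try the next one
            start = raw_text.find('{', start + 1)
    raise ValueError('no JSON object found in model output')

# Hypothetical model output with prose around the JSON payload
raw = ('Here is the result: '
       '{"machine_states": {"1": ["Ventilator"], "0": []}, '
       '"reasoning": "Ventilator requested."} Done.')
payload = extract_json_object(raw)
print(payload['machine_states']['1'])
```

Counting how many raw outputs raise `ValueError` gives a direct parse-success rate, which the sign-off criteria in the summary depend on.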

In [None]:
# ── 9. Latency summary ───────────────────────────────────────────────────
import statistics

latencies = [r['latency_ms'] for r in results]
print('Latency Statistics (ms):')
print(f'  Min    : {min(latencies):.0f}')
print(f'  Max    : {max(latencies):.0f}')
print(f'  Mean   : {statistics.mean(latencies):.0f}')
print(f'  Median : {statistics.median(latencies):.0f}')
print(f'  P95    : {sorted(latencies)[int(len(latencies)*0.95)]:.0f}')

target_ms = 3000
under_target = sum(1 for lat in latencies if lat < target_ms)
print(f'\nUnder {target_ms}ms target: {under_target}/{len(latencies)} ({100*under_target/len(latencies):.0f}%)')
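
The P95 line above indexes into the sorted list with `int(len(latencies)*0.95)`, which is a quick approximation. For small samples such as 30 calls, it can help to make the percentile definition explicit. The sketch below uses the nearest-rank definition; `percentile_nearest_rank` is a hypothetical helper name, and the sample values are synthetic, not measured latencies.

```python
import math

def percentile_nearest_rank(values, p):
    """Nearest-rank percentile: the smallest value with at least p% of data at or below it."""
    if not values:
        raise ValueError('values must not be empty')
    if not 0 < p <= 100:
        raise ValueError('p must be in (0, 100]')
    ordered = sorted(values)
    rank = math.ceil(p * len(ordered) / 100)  # 1-based rank
    return ordered[rank - 1]

# Synthetic latencies in ms, for illustration only
sample_latencies = [900, 1100, 1250, 1300, 1400, 1550, 1700, 2100, 2600, 3200]
print('P50:', percentile_nearest_rank(sample_latencies, 50))
print('P95:', percentile_nearest_rank(sample_latencies, 95))
```

With only 30 samples, P95 is determined by the two slowest calls, so it is worth reading alongside the median rather than on its own.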

In [None]:
# ── 10. Correctness check ────────────────────────────────────────────────
# Heuristic: for each transcription containing a known keyword, at least one
# expected machine should appear in the ON list, or in the OFF list when the
# command asks to turn something off.
# (Not exact: MedGemma may interpret commands differently.)
import re

# Keys must be lowercase because transcriptions are lowercased before matching
keyword_to_machines = {
    'ventilator':    ['Ventilator'],
    'bypass pump':   ['Bypass Pump', 'Cardiopulmonary Bypass Machine'],
    'defibrillator': ['Defibrillator'],
    'or lights':     ['OR Lights'],
    'laparoscopic':  ['Laparoscopic Tower'],
    'lithotripter':  ['Lithotripter'],
    'flouroscop':    ['Fluoroscopy C-Arm'],   # common misspelling
    'fluoroscop':    ['Fluoroscopy C-Arm'],
    'c-arm':         ['Fluoroscopy C-Arm'],
}

correct   = 0
incorrect = 0

for r in results:
    tr = r['transcription'].lower()
    expected = [em.lower() for kw, machines in keyword_to_machines.items()
                if kw in tr for em in machines]
    if not expected:
        continue  # no keyword from the heuristic dict: skip the check

    # "turn off ..." commands should list the machine under OFF, not ON
    key    = 'parsed_off' if re.search(r'\boff\b', tr) else 'parsed_on'
    actual = [m.lower() for m in r[key]]

    if any(em in a for em in expected for a in actual):
        correct += 1
    else:
        incorrect += 1
        print(f'POTENTIAL MISS: "{r["transcription"]}" → {key}={r[key]}')

checked = correct + incorrect
if checked > 0:
    print(f'\nHeuristic correctness: {correct}/{checked} ({100*correct/checked:.0f}%)')
else:
    print('No keyword-matched heuristics to check')

In [None]:
# ── 11. Save results to JSON ─────────────────────────────────────────────
output_path = Path('/kaggle/working/phase3_results.json')
with open(output_path, 'w') as f:
    json.dump(results, f, indent=2)
print(f'Results saved to {output_path}')

## Summary

Expected results for Phase 3 sign-off:
- ✅ MedGemma loads successfully with `n_gpu_layers=-1`
- ✅ Median inference latency < 3000 ms on T4 GPU
- ✅ JSON output parses correctly for all 30 test cases
- ✅ Heuristic keyword correctness > 80%

If all pass → Phase 3 complete → proceed to Phase 4 (E2E Pipeline).
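
The measurable sign-off criteria listed above can also be evaluated programmatically at the end of the notebook. The following is a minimal sketch, assuming the notebook tracks a count of successfully parsed outputs; `phase3_signoff` and its parameter names are hypothetical, and the numbers passed in are synthetic, not measured results. Model loading is not checked here because a load failure stops the notebook earlier.

```python
import statistics

def phase3_signoff(latencies_ms, parsed_ok, heuristic_correct, heuristic_checked,
                   target_ms=3000, min_accuracy=0.80):
    """Evaluate the measurable sign-off criteria; returns {criterion: bool}."""
    accuracy = heuristic_correct / heuristic_checked if heuristic_checked else 0.0
    return {
        'median_latency_under_target': statistics.median(latencies_ms) < target_ms,
        'all_outputs_parsed':          parsed_ok == len(latencies_ms),
        'heuristic_accuracy_ok':       accuracy > min_accuracy,
    }

# Synthetic inputs for illustration only
checks = phase3_signoff(
    latencies_ms=[1200, 1500, 2900, 3400],
    parsed_ok=4,
    heuristic_correct=9,
    heuristic_checked=10,
)
for name, ok in checks.items():
    print(f"{'PASS' if ok else 'FAIL'}  {name}")
print('Phase 3 complete' if all(checks.values()) else 'Phase 3 not yet signed off')
```

In the notebook itself, `latencies_ms` would come from `results`, and the heuristic counts from the `correct` and `checked` variables in the correctness cell.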