# Phase 3 — MedGemma Inference Test (Kaggle GPU)

**Environment**: Kaggle, GPU T4×2 or P100 (≥15 GB VRAM)  
**Model**: `medgemma-4b-it-Q3_K_M.gguf` (~2.1 GB) via `llama-cpp-python` (CUDA 12.4 prebuilt wheels — no compilation)  
**Goal**: Validate that MedGemma correctly maps surgical transcriptions → structured JSON machine states

Steps (in cell order):
1. Check GPU
2. Clone OR-SIM repo (`main` branch)
3. Install `llama-cpp-python` from prebuilt CUDA 12.4 wheels *(fast, ~30s, no compilation)* and verify GPU offload is compiled in
4. Install OR-SIM Python dependencies
5. Download MedGemma GGUF from HuggingFace (`unsloth/medgemma-4b-it-GGUF`)
6. Import OR-SIM modules and load the model with all layers on the GPU
7. Run 10 transcription samples per surgery (30 total)
8. Measure latency, check heuristic correctness, and save the results as JSON

In [None]:
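# ── 0. Check GPU ──────────────────────────────────────────────────────────
import shutil
import subprocess

# subprocess.run raises FileNotFoundError when nvidia-smi is missing, so check PATH first
if shutil.which('nvidia-smi') is None:
    print('GPU: No GPU found (nvidia-smi is not on PATH)')
else:
    result = subprocess.run(
        ['nvidia-smi', '--query-gpu=name,memory.total,driver_version', '--format=csv,noheader'],
        capture_output=True, text=True
    )
    print('GPU:', result.stdout.strip() or 'No GPU found')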
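
The environment note asks for at least 15 GB of VRAM, but the cell above only prints what `nvidia-smi` reports. A minimal sketch of turning that CSV text into a pass/fail check follows. The helper name `parse_gpu_csv`, the threshold of 15 × 1024 MiB, and the sample string (including its driver version) are assumptions for illustration, not output captured from Kaggle.

```python
def parse_gpu_csv(csv_text):
    """Parse `name, memory.total, driver_version` rows from nvidia-smi CSV output."""
    gpus = []
    for line in csv_text.strip().splitlines():
        parts = [p.strip() for p in line.split(',')]
        if len(parts) != 3:
            continue  # not a GPU row
        name, memory, driver = parts
        try:
            memory_mib = int(memory.split()[0])  # e.g. "15360 MiB" -> 15360
        except (ValueError, IndexError):
            continue  # memory not reported as a number
        gpus.append({'name': name, 'memory_mib': memory_mib, 'driver': driver})
    return gpus


MIN_VRAM_MIB = 15 * 1024  # assumed reading of the ">= 15 GB" requirement

# Illustrative sample only; on Kaggle, pass result.stdout instead.
sample = 'Tesla T4, 15360 MiB, 000.00\nTesla T4, 15360 MiB, 000.00'
gpus = parse_gpu_csv(sample)
enough_vram = bool(gpus) and all(g['memory_mib'] >= MIN_VRAM_MIB for g in gpus)
print(f'{len(gpus)} GPU(s) found, enough VRAM: {enough_vram}')
```

Skipping unparseable rows instead of raising keeps the check usable on CPU-only sessions, where it simply reports zero GPUs.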

In [None]:
# ── 1. Clone OR-SIM repository (main branch) ─────────────────────────────
# Replace <YOUR_GITHUB_USERNAME> with your actual GitHub username before running
!git clone https://github.com/<YOUR_GITHUB_USERNAME>/OR-SIM.git /kaggle/working/OR-SIM
%cd /kaggle/working/OR-SIM
!git log --oneline -5

In [None]:
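# ── 2. Install llama-cpp-python: CUDA 12.4 prebuilt wheels (no compilation)
# Uses Kaggle's T4 CUDA 12.4 driver; takes ~30s vs ~5 minutes if building from source.
# --prefer-binary keeps pip on a prebuilt wheel even when PyPI lists a newer source-only
# release, which would otherwise be compiled locally (slow, and CPU-only by default).
!pip install llama-cpp-python --prefer-binary \
    --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cu124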

# Verify GPU offload is compiled in — must be True
import llama_cpp
lib = getattr(llama_cpp, 'llama_cpp', None)
gpu_ok = lib.llama_supports_gpu_offload() if lib else False
print(f'llama_supports_gpu_offload() = {gpu_ok}')
if not gpu_ok:
    raise RuntimeError('GPU offload not available — check Kaggle CUDA version matches cu124 wheel')

from llama_cpp import Llama
print('llama-cpp-python imported successfully')

In [None]:
# ── 3. Install OR-SIM Python dependencies ────────────────────────────────
!pip install -q -r /kaggle/working/OR-SIM/requirements/base.txt
!pip install -q -r /kaggle/working/OR-SIM/requirements/llm.txt
print('OR-SIM dependencies installed')

In [None]:
# ── 4. Download MedGemma GGUF from HuggingFace ───────────────────────────
!pip install -q huggingface_hub

from huggingface_hub import hf_hub_download
from pathlib import Path
import os

os.makedirs('/kaggle/working/OR-SIM/models/medgemma', exist_ok=True)

MODEL_PATH = hf_hub_download(
    repo_id   = 'unsloth/medgemma-4b-it-GGUF',
    filename  = 'medgemma-4b-it-Q3_K_M.gguf',
    local_dir = '/kaggle/working/OR-SIM/models/medgemma',
)
MODEL_PATH = Path(MODEL_PATH)
print(f'Model downloaded to: {MODEL_PATH}')
print(f'File size          : {MODEL_PATH.stat().st_size / 1e9:.2f} GB')

In [None]:
# ── 5. Add OR-SIM to Python path ─────────────────────────────────────────
import sys
sys.path.insert(0, '/kaggle/working/OR-SIM')

from backend.data.surgeries import SurgeryType, MACHINES
from backend.llm.prompt_builder import PromptBuilder
from backend.llm.output_parser  import parse_llm_output
from backend.llm.schemas        import LLMOutput
print('OR-SIM modules imported')

In [None]:
# ── 6. Load MedGemma model ────────────────────────────────────────────────
import time

t0 = time.perf_counter()   # monotonic clock, safe for measuring intervals
llm = Llama(
    model_path   = str(MODEL_PATH),
    n_gpu_layers = -1,       # All layers on GPU
    n_ctx        = 4096,
    n_threads    = 4,
    verbose      = False,
)
print(f'Model loaded in {time.perf_counter()-t0:.1f}s')

In [None]:
# ── 7. Test transcriptions per surgery ───────────────────────────────────
TEST_CASES = {
    SurgeryType.HEART_TRANSPLANT: [
        'turn on the ventilator',
        'activate bypass pump',
        'turn off the OR lights',
        'we need the defibrillator ready',
        'activate the cardiac monitor',
        'turn off anesthesia machine',
        'turn on the perfusion pump',
        'activate the warming blanket',
        'turn everything off',
        'ventilator on, bypass pump on, OR lights on',
    ],
    SurgeryType.LIVER_RESECTION: [
        'turn on the laparoscopic tower',
        'activate the ultrasonic dissector',
        'turn off the harmonic scalpel',
        'we need the argon beam',
        'activate the fluoroscopy unit',
        'turn on the cell saver',
        'activate the bipolar electrosurgery unit',
        'turn off the OR lights',
        'activate patient warmer',
        'ultrasonic dissector and laparoscopic tower on',
    ],
    SurgeryType.KIDNEY_PCNL: [
        'turn on the fluoroscopy C-arm',
        'activate the nephroscope tower',
        'turn on the lithotripter',
        'we need the irrigation pump',
        'activate the ultrasound guidance',
        'turn off the electrosurgery unit',
        'turn on the suction unit',
        'activate OR lights',
        'everything off please',
        'fluoroscopy on, lithotripter on, nephroscope on',
    ],
}

print('Test cases defined:', sum(len(v) for v in TEST_CASES.values()), 'total')
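
The later sign-off criteria assume exactly 10 transcriptions per surgery, so it can help to check the test set before spending GPU time on it. The sketch below is one way to do that; `validate_test_cases` and the demo dictionary are hypothetical and use plain string keys so the example runs without the OR-SIM package.

```python
from collections import Counter


def validate_test_cases(cases, per_surgery=10):
    """Return a list of human-readable problems; an empty list means the set looks fine."""
    problems = []
    for surgery, transcriptions in cases.items():
        if len(transcriptions) != per_surgery:
            problems.append(
                f'{surgery}: expected {per_surgery} transcriptions, found {len(transcriptions)}'
            )
        # Case-insensitive duplicate detection within one surgery
        counts = Counter(t.strip().lower() for t in transcriptions)
        for text, n in counts.items():
            if n > 1:
                problems.append(f'{surgery}: duplicate transcription {text!r}')
    return problems


# Hypothetical demo data, not the notebook's TEST_CASES
demo = {
    'demo_surgery_a': ['turn on the ventilator', 'turn off the OR lights'],
    'demo_surgery_b': ['activate OR lights', 'Activate OR lights'],
}
problems = validate_test_cases(demo, per_surgery=2)
print(problems)
```

In the notebook this would be called as `validate_test_cases(TEST_CASES)`; enum keys format normally inside the f-strings.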

In [None]:
# ── 8. Run inference on all test cases ───────────────────────────────────
import json

results = []
sep = '=' * 60

for surgery, transcriptions in TEST_CASES.items():
    builder = PromptBuilder(surgery)
    print(f'\n{sep}')
    print(f'Surgery: {surgery.value}')
    print(sep)

    for transcription in transcriptions:
        messages = builder.build_messages(transcription, snapshot=None)

        t0 = time.perf_counter()   # monotonic clock, safe for measuring intervals
        response = llm.create_chat_completion(
            messages    = messages,
            max_tokens  = 256,
            temperature = 0.1,
            top_p       = 0.9,
        )
        latency_ms = (time.perf_counter() - t0) * 1000

        raw_text = response['choices'][0]['message']['content']
        parsed   = parse_llm_output(raw_text, surgery)

        results.append({
            'surgery':       surgery.value,
            'transcription': transcription,
            'raw_output':    raw_text,
            'parsed_on':     parsed.machine_states['1'],
            'parsed_off':    parsed.machine_states['0'],
            'reasoning':     parsed.reasoning,
            'latency_ms':    latency_ms,
        })

        reasoning_preview = (parsed.reasoning[:80] + '...') if len(parsed.reasoning) > 80 else parsed.reasoning
        print(f'\nTranscription : {transcription}')
        print(f'ON            : {parsed.machine_states["1"]}')
        print(f'OFF           : {parsed.machine_states["0"]}')
        print(f'Reasoning     : {reasoning_preview}')
        print(f'Latency       : {latency_ms:.0f} ms')

print(f'\nTotal inference calls: {len(results)}')
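
The sign-off table expects 30 / 30 non-empty `machine_states`, but the loop above only prints each result. A small sketch of counting that explicitly is shown below. It assumes that `parsed_on` and `parsed_off` are lists and that an unparseable response ends up with both lists empty; that behaviour of `parse_llm_output` is not confirmed here. The helper name and demo records are hypothetical.

```python
def find_empty_results(results):
    """Return the records whose ON and OFF lists are both empty."""
    return [r for r in results if not r['parsed_on'] and not r['parsed_off']]


# Hypothetical records shaped like the entries appended to `results`
demo_results = [
    {'transcription': 'turn on the ventilator', 'parsed_on': ['Ventilator'], 'parsed_off': []},
    {'transcription': 'turn everything off', 'parsed_on': [], 'parsed_off': ['Ventilator']},
    {'transcription': 'unparseable reply', 'parsed_on': [], 'parsed_off': []},
]

empty = find_empty_results(demo_results)
parsed_ok = len(demo_results) - len(empty)
print(f'{parsed_ok} / {len(demo_results)} non-empty machine_states')
for r in empty:
    print(f'EMPTY: "{r["transcription"]}"')
```

A record counts as non-empty when either list has entries, so an "everything off" command with no machines left ON still passes.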

In [None]:
# ── 9. Latency summary ───────────────────────────────────────────────────
import statistics

latencies = [r['latency_ms'] for r in results]
print('Latency Statistics (ms):')
print(f'  Min    : {min(latencies):.0f}')
print(f'  Max    : {max(latencies):.0f}')
print(f'  Mean   : {statistics.mean(latencies):.0f}')
print(f'  Median : {statistics.median(latencies):.0f}')
print(f'  P95    : {sorted(latencies)[int(len(latencies)*0.95)]:.0f}')

target_ms = 3000
under_target = sum(1 for lat in latencies if lat < target_ms)
print(f'\nUnder {target_ms}ms target: {under_target}/{len(latencies)} ({100*under_target/len(latencies):.0f}%)')
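
The P95 line above indexes the sorted list at `int(len * 0.95)`. One common alternative is the nearest-rank definition, sketched below with a hypothetical helper and synthetic latencies. This is an illustration of that definition, not a change to how the notebook reports results.

```python
import math


def percentile_nearest_rank(values, pct):
    """Nearest-rank percentile: the smallest value with at least pct% of the data at or below it."""
    if not values:
        raise ValueError('values must not be empty')
    if not 0 < pct <= 100:
        raise ValueError('pct must be in (0, 100]')
    ordered = sorted(values)
    rank = math.ceil(pct * len(ordered) / 100)  # 1-based rank
    return ordered[rank - 1]


# Synthetic latencies (100, 200, ..., 2000 ms), not measured values
demo_latencies = [float(ms) for ms in range(100, 2100, 100)]
p95 = percentile_nearest_rank(demo_latencies, 95)
print(f'P95 (nearest rank): {p95:.0f} ms')
```

For the 30 samples collected here both approaches pick index 28 of the sorted list. They diverge for other sizes: with 20 samples, `int(20 * 0.95)` selects the maximum, while nearest-rank selects the 19th value.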

In [None]:
# ── 10. Correctness check ────────────────────────────────────────────────
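# Heuristic: if a keyword appears in the transcription, expect its canonical machine to be ON.
# Only the first matching keyword (in dict order) is checked per transcription, so
# multi-machine commands are verified for one machine only.
# Skipped: keywords preceded by "off" within 15 characters, and transcriptions with no keyword.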
keyword_to_canonical = {
    'ventilator':       'Ventilator',
    'bypass pump':      'Cardiopulmonary Bypass Machine',
    'defibrillator':    'Defibrillator',
    'or lights':        'Surgical Lights',
    'cardiac monitor':  'Patient Monitor',
    'perfusion pump':   'Perfusion Pump',
    'warming blanket':  'Warming Blanket',
    'laparoscopic':     'Laparoscopic Tower',
    'ultrasonic':       'Ultrasonic Dissector',
    'lithotripter':     'Lithotripter',
    'c-arm':            'Fluoroscopy C-Arm',
    'fluoroscop':       'Fluoroscopy C-Arm',
    'nephroscope':      'Nephroscope Tower',
    'irrigation pump':  'Irrigation Pump',
    'suction':          'Suction Unit',
}

correct = incorrect = skipped = 0

for r in results:
    tr  = r['transcription'].lower()
    on  = [m.lower() for m in r['parsed_on']]

    hit_kw = False
    for kw, expected in keyword_to_canonical.items():
        if kw in tr:
            hit_kw = True
            # Skip if this keyword appears after a "turn off" / "off" command
            idx = tr.find(kw)
            prefix = tr[max(0, idx-15):idx]
            if 'off' in prefix:
                skipped += 1
            elif any(expected.lower() in o for o in on):
                correct += 1
            else:
                incorrect += 1
                print(f'MISS: "{r["transcription"]}"')
                print(f'      expected ON: {expected!r}, got: {r["parsed_on"]}')
            break

    if not hit_kw:
        skipped += 1

checked = correct + incorrect
if checked > 0:
    print(f'\nHeuristic correctness : {correct}/{checked} ({100*correct/checked:.0f}%)')
    print(f'Skipped (no heuristic): {skipped}')
else:
    print('No keyword-matched heuristics to evaluate')
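
The heuristic above stops at the first keyword it finds, so a command such as "ventilator on, bypass pump on, OR lights on" is checked for one machine only. A possible extension is sketched below. It splits the transcription on commas and assigns every keyword in a clause the state implied by that clause. The function name, the comma-splitting rule, and the three-entry demo map are assumptions for illustration.

```python
import re


def expected_states(transcription, keyword_map):
    """Map each canonical machine mentioned in the text to 'on' or 'off'."""
    expected = {}
    for clause in transcription.lower().split(','):
        # A standalone word "off" anywhere in the clause marks the whole clause as OFF
        state = 'off' if re.search(r'\boff\b', clause) else 'on'
        for keyword, canonical in keyword_map.items():
            if keyword in clause:
                expected[canonical] = state
    return expected


# Subset of the keyword map, copied here so the example is self-contained
demo_map = {
    'ventilator':  'Ventilator',
    'bypass pump': 'Cardiopulmonary Bypass Machine',
    'or lights':   'Surgical Lights',
}

print(expected_states('ventilator on, bypass pump on, OR lights off', demo_map))
print(expected_states('turn off the OR lights', demo_map))
```

The returned dictionary can then be compared against both `parsed_on` and `parsed_off`, which also lets "turn off X" commands be scored instead of skipped. Clauses that mix ON and OFF without a comma are still beyond this rule.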

In [None]:
# ── 11. Save results to JSON ─────────────────────────────────────────────
from pathlib import Path

output_path = Path('/kaggle/working/phase3_results.json')
with open(output_path, 'w', encoding='utf-8') as f:
    json.dump(results, f, indent=2)
print(f'Results saved to {output_path}')
print(f'Total calls  : {len(results)}')
print(f'Avg latency  : {statistics.mean(latencies):.0f} ms')
print(f'Median       : {statistics.median(latencies):.0f} ms')
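
Since the saved file is what gets reported, it may be worth reading it back and confirming that every record carries the fields written in cell 8. The sketch below does a round trip through a temporary file. `load_and_check` and the demo record are hypothetical, and the latency value is a placeholder rather than a measurement.

```python
import json
import tempfile
from pathlib import Path

# Field names taken from the dict appended to `results` in cell 8
REQUIRED_KEYS = {
    'surgery', 'transcription', 'raw_output',
    'parsed_on', 'parsed_off', 'reasoning', 'latency_ms',
}


def load_and_check(path):
    """Load a results file and raise if any record lacks a required key."""
    with open(path, encoding='utf-8') as f:
        records = json.load(f)
    for i, record in enumerate(records):
        missing = REQUIRED_KEYS - record.keys()
        if missing:
            raise ValueError(f'record {i} is missing keys: {sorted(missing)}')
    return records


# Placeholder record for the round trip
demo_record = {
    'surgery': 'demo_surgery', 'transcription': 'turn on the ventilator',
    'raw_output': '{}', 'parsed_on': ['Ventilator'], 'parsed_off': [],
    'reasoning': 'demo', 'latency_ms': 812.5,
}

with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / 'demo_results.json'
    demo_path.write_text(json.dumps([demo_record], indent=2), encoding='utf-8')
    reloaded = load_and_check(demo_path)

print(f'Reloaded {len(reloaded)} record(s)')
```

On Kaggle the same helper would be pointed at `/kaggle/working/phase3_results.json`.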

## Phase 3 Sign-off Criteria

| Check | Target | How verified |
|-------|--------|-------------|
| MedGemma loads without error | ✅ no exception | Cell 6 |
| `llama_supports_gpu_offload()` returns `True` | ✅ must be True | Cell 2 |
| All 30 JSON responses parse successfully | 30 / 30 non-empty `machine_states` | Cell 8 |
| Median inference latency | < 3 000 ms | Cell 9 |
| Heuristic keyword correctness | > 80 % | Cell 10 |

If all pass → Phase 3 complete → report results → **proceed to Phase 4: E2E Pipeline**  
(`LiveTranscriber → MedGemmaModel → StateManager → machine_states.json`)
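
The three numeric criteria in the table can be evaluated in one place once the earlier cells have produced their counts. The sketch below is one possible way to do that. The function name, its parameters, and the demo numbers are assumptions (the numbers are placeholders, not Phase 3 results); the thresholds are the ones listed in the table. The first two rows of the table are already enforced by exceptions in cells 2 and 6.

```python
def evaluate_signoff(parsed_ok, total_calls, median_ms,
                     heuristic_correct, heuristic_checked,
                     expected_calls=30, latency_target_ms=3000,
                     correctness_target=0.80):
    """Return a dict of criterion name -> bool for the numeric sign-off checks."""
    correctness = heuristic_correct / heuristic_checked if heuristic_checked else 0.0
    return {
        'all_responses_parsed': parsed_ok == total_calls == expected_calls,
        'median_latency_under_target': median_ms < latency_target_ms,
        'heuristic_correctness_over_target': correctness > correctness_target,
    }


# Placeholder inputs for illustration only
checks = evaluate_signoff(
    parsed_ok=30, total_calls=30, median_ms=1200.0,
    heuristic_correct=17, heuristic_checked=20,
)
for name, passed in checks.items():
    print(f'{"PASS" if passed else "FAIL"}  {name}')
print('Phase 3 sign-off:', 'complete' if all(checks.values()) else 'not yet')
```

In the notebook, `median_ms` would come from `statistics.median(latencies)` and the heuristic counts from `correct` and `checked` in cell 10.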