# AEGIS — MedGemma Inference Server (Kaggle)

This notebook loads **MedGemma 1.5 4B IT** on Kaggle's free T4 GPU and serves it
as an OpenAI-compatible API via **ngrok**.

## Setup
1. **Accelerator**: GPU T4 x2 (Settings → Accelerator)
2. **Internet**: ON (Settings → Internet → Internet connected)
3. **Secrets**: Add `HF_TOKEN` and `NGROK_TOKEN` in Add-ons → Secrets
4. Run all cells, copy the ngrok URL into your `.env.local`

## Model
- `google/medgemma-1.5-4b-it` — 4-bit quantized, fits in ~5 GB VRAM
- Updated Jan 13, 2026: improved EHR, CT/MRI 3D imaging, longitudinal CXR, lab report extraction
- Part of Google's HAI-DEF collection (required for competition)


In [None]:
# Cell 1: Install dependencies
# `%pip` is an IPython magic, so this cell only runs inside a notebook kernel.
# Model stack, upgraded to the latest releases (-U). Cell 2 imports
# transformers and builds a BitsAndBytesConfig for the 4-bit load.
%pip install -q -U transformers accelerate bitsandbytes
# Serving stack: the FastAPI app (Cell 4) is run by uvicorn and exposed through
# a pyngrok tunnel (Cell 5); nest_asyncio is applied in Cell 5 as well.
%pip install -q fastapi uvicorn pyngrok nest_asyncio
print("âœ… Dependencies installed")

In [None]:
# Cell 2: Load MedGemma 1.5 4B IT (4-bit quantized)
import os
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText, BitsAndBytesConfig

# Kaggle Secrets (only available when running on Kaggle)
# Defines the module-level HF_TOKEN / NGROK_TOKEN used below and in Cell 5.
try:
    from kaggle_secrets import UserSecretsClient
    _secrets = UserSecretsClient()
    HF_TOKEN = _secrets.get_secret("HF_TOKEN")
    NGROK_TOKEN = _secrets.get_secret("NGROK_TOKEN")
except ImportError:
    # Fallback: read from environment variables (e.g. local testing)
    # Only a missing `kaggle_secrets` module lands here; any other exception
    # raised by UserSecretsClient / get_secret propagates and stops the cell.
    HF_TOKEN = os.environ.get("HF_TOKEN", "")
    NGROK_TOKEN = os.environ.get("NGROK_TOKEN", "")

# Hugging Face model id; also echoed back by the API responses in Cell 4.
MODEL_ID = "google/medgemma-1.5-4b-it"

# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,  # bfloat16 required — float16 causes NaN logits
    bnb_4bit_use_double_quant=True,
)

print(f"Loading {MODEL_ID} ...")
# `processor` and `model` are module-level globals consumed by generate()
# (Cell 3) and by the token counting in Cell 4.
processor = AutoProcessor.from_pretrained(MODEL_ID, token=HF_TOKEN)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate decide device placement
    torch_dtype=torch.bfloat16,
)
# Inference only: switch the module to eval mode.
model.eval()
# get_memory_footprint() is divided by 1e9 to display the size in GB.
print(f"âœ… {MODEL_ID} loaded â€” {model.get_memory_footprint()/1e9:.1f} GB VRAM")


In [None]:
# Cell 3: Inference helper
import json, re

def generate(messages: list[dict], max_tokens: int = 1024, temperature: float = 0.0) -> str:
    """
    Run one chat-completion style generation with MedGemma and return the reply.

    Uses processor.tokenizer.apply_chat_template which natively supports the
    'system' role with plain string content — unlike processor.apply_chat_template
    which is multimodal-only and does NOT support system role.

    Reference: official RL notebook (google-health/medgemma) uses this exact pattern.

    Args:
        messages: Chat messages, each a dict with a "role" and (optionally) a
            "content". Content may be a plain string, ``None``, or an
            OpenAI-style list of typed parts; only ``{"type": "text"}`` parts
            are kept, so image parts are silently dropped.
        max_tokens: Upper bound on the number of newly generated tokens.
        temperature: Sampling temperature. Values <= 0 select greedy decoding;
            values > 0 enable sampling with top_p=0.95.

    Returns:
        The decoded completion only (prompt tokens sliced off, special tokens
        skipped), with surrounding whitespace stripped.

    Raises:
        KeyError: if a message has no "role".
    """
    # Normalise content: typed-dict lists -> plain strings
    plain_messages = []
    for msg in messages:
        content = msg.get("content")
        if content is None:
            # Missing or null content (e.g. an assistant turn that carries no
            # text) must become an empty string; str(None) would inject the
            # literal word "None" into the prompt.
            content = ""
        elif isinstance(content, list):
            content = " ".join(
                # `or ""` guards against a part whose "text" is null, which
                # would otherwise make str.join raise TypeError.
                str(part.get("text") or "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        plain_messages.append({"role": msg["role"], "content": str(content)})

    # Tokenise using the *tokenizer's* apply_chat_template (supports system role)
    inputs = processor.tokenizer.apply_chat_template(
        plain_messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)

    # Sampling parameters are only passed when sampling is actually enabled.
    sampling = temperature > 0
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=sampling,
            temperature=temperature if sampling else None,
            top_p=0.95 if sampling else None,
            pad_token_id=processor.tokenizer.eos_token_id,
        )

    # The output sequence starts with the prompt; keep only the tokens after it.
    input_len = inputs["input_ids"].shape[-1]
    response = processor.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    return response.strip()


# Quick smoke test — also verify system role is handled correctly
# Runs a real generation as soon as this cell executes, so `processor` and
# `model` from Cell 2 must already be loaded.
test = generate([
    {"role": "system", "content": "You are a helpful medical assistant."},
    {"role": "user", "content": "What are 3 warning signs of a stroke?"},
], max_tokens=200)
# Only the first 400 characters of the reply are printed.
print("Smoke test:", test[:400])


In [None]:
# Cell 4: FastAPI server (OpenAI-compatible /v1/chat/completions)
import time, uuid
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse

# ASGI application object; the route decorators below register onto it and
# Cell 5 hands it to uvicorn.
app = FastAPI(title="AEGIS MedGemma Server")
# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


def _content_str(content) -> str:
    """Coerce message content to a plain string for token counting."""
    if isinstance(content, list):
        return " ".join(
            p.get("text", "") for p in content
            if isinstance(p, dict) and p.get("type") == "text"
        )
    return str(content) if content else ""


@app.get("/health")
async def health():
    """Liveness probe: report an "ok" status together with the served model id."""
    payload = {
        "status": "ok",
        "model": MODEL_ID,
    }
    return payload


@app.post("/v1/chat/completions")
async def chat_completions(request: Request):
    """OpenAI-compatible chat completion endpoint backed by generate().

    Request body (JSON object):
        messages: non-empty list of ``{"role": ..., "content": ...}`` dicts.
        max_tokens: optional int, defaults to 1024 when absent or null.
        temperature: optional float, defaults to 0.0 when absent or null.

    Returns:
        A ``chat.completion`` shaped dict with a single choice, a ``usage``
        section and an extra ``latency_seconds`` field. Malformed requests get
        a 400 response with an ``{"error": {...}}`` body instead of an
        unhandled exception.

    The ``usage`` figures are approximations: they count the tokens of each
    message's text content only, not the tokens the chat template adds.
    """

    def bad_request(message: str) -> JSONResponse:
        # Error envelope for client-side mistakes.
        return JSONResponse(
            status_code=400,
            content={"error": {"message": message, "type": "invalid_request_error"}},
        )

    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError (and UnicodeDecodeError) are ValueError subclasses.
        return bad_request("Request body must be valid JSON.")
    if not isinstance(body, dict):
        return bad_request("Request body must be a JSON object.")

    messages = body.get("messages")
    if not isinstance(messages, list) or not messages:
        return bad_request("'messages' must be a non-empty list.")
    if not all(isinstance(m, dict) and "role" in m for m in messages):
        return bad_request("Each message must be an object with a 'role'.")

    # Clients may send explicit nulls; dict.get's default only covers a
    # *missing* key, so None has to be mapped to the default by hand.
    max_tokens = body.get("max_tokens")
    if max_tokens is None:
        max_tokens = 1024
    temperature = body.get("temperature")
    if temperature is None:
        temperature = 0.0

    # NOTE: generate() is a synchronous call inside an async handler, so the
    # event loop is blocked (and other requests wait) until it returns.
    start = time.perf_counter()
    content = generate(messages, max_tokens=max_tokens, temperature=temperature)
    elapsed = time.perf_counter() - start

    # add_special_tokens=False keeps the tokenizer from adding special tokens
    # to every separately-encoded string, which would inflate the counts.
    tokenizer = processor.tokenizer
    prompt_tokens = sum(
        len(tokenizer.encode(_content_str(m.get("content", "")), add_special_tokens=False))
        for m in messages
    )
    completion_tokens = len(tokenizer.encode(content, add_special_tokens=False))

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": MODEL_ID,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": content},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
        "latency_seconds": round(elapsed, 2),
    }


# AEGIS-specific endpoint (simpler interface)
@app.post("/analyze")
async def analyze(request: Request):
    """Triage a free-text patient presentation with a single greedy generation.

    Request body (JSON object):
        symptoms: free-text presentation; absent or null is treated as "".
        systemPrompt: optional system prompt override; absent or null selects
            the built-in AEGIS triage prompt.

    Returns:
        ``{"response": <raw model text>, "model": MODEL_ID}``. The model text
        is returned as-is and is not parsed or validated as JSON here.
        A body that is not a JSON object yields a 400 ``{"error": ...}``.
    """
    default_prompt = (
        "You are AEGIS, a clinical triage AI. Analyze the patient presentation "
        "and return JSON with: symptoms, severity (low/medium/high), summary, "
        "differential (list of {condition, probability, recommendation}), "
        "recommendations, reasoning, confidence (0-1)."
    )

    try:
        body = await request.json()
    except ValueError:
        # json.JSONDecodeError is a ValueError subclass.
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be valid JSON."},
        )
    if not isinstance(body, dict):
        return JSONResponse(
            status_code=400,
            content={"error": "Request body must be a JSON object."},
        )

    # dict.get's default only applies to a missing key; an explicit JSON null
    # would otherwise be formatted into the prompt as the text "None".
    symptoms = body.get("symptoms")
    if symptoms is None:
        symptoms = ""
    system_prompt = body.get("systemPrompt")
    if system_prompt is None:
        system_prompt = default_prompt

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": f"Patient presentation: {symptoms}"},
    ]

    content = generate(messages, max_tokens=1024, temperature=0.0)
    return {"response": content, "model": MODEL_ID}


# Confirmation only: the routes are registered, but nothing is served until Cell 5.
print("âœ… FastAPI app defined")


In [None]:
# Cell 5: Launch server with ngrok tunnel
import nest_asyncio
import threading
import uvicorn

# Re-raise a missing pyngrok with a hint pointing at the install cell.
try:
    from pyngrok import ngrok
except ImportError:
    raise ImportError("pyngrok not found â€” run Cell 1 first to install it.")

# NOTE(review): presumably needed so uvicorn can run while the notebook
# kernel's own event loop is already running — confirm.
nest_asyncio.apply()
# NGROK_TOKEN comes from Cell 2 (Kaggle secret or environment variable).
ngrok.set_auth_token(NGROK_TOKEN)

# Local port shared by uvicorn and the ngrok tunnel.
PORT = 8000

# Start uvicorn in background thread
# daemon=True: the thread does not keep the process alive on its own.
thread = threading.Thread(
    target=uvicorn.run,
    args=(app,),
    kwargs={"host": "0.0.0.0", "port": PORT, "log_level": "info"},
    daemon=True,
)
thread.start()

# Open ngrok tunnel
# `public_url` is a module-level global reused by Cells 6 and 7.
public_url = ngrok.connect(PORT, "http").public_url
print("=" * 60)
print("ðŸš€ AEGIS MedGemma Server is LIVE!")
print(f"   Public URL: {public_url}")
print(f"   Health:     {public_url}/health")
print(f"   Analyze:    {public_url}/analyze")
print(f"   OpenAI API: {public_url}/v1/chat/completions")
print("=" * 60)
print()
print("ðŸ“‹ Add this to your .env.local:")
print(f"   KAGGLE_MEDGEMMA_URL={public_url}")
print()
print("Keep this notebook running while using AEGIS.")


In [None]:
# Cell 6: (Optional) Test the endpoint
import requests

# Test health
# requests has no default timeout, so an unreachable tunnel would hang this
# cell forever; raise_for_status() turns an HTTP error into an immediate,
# readable exception instead of a JSON/KeyError further down.
r = requests.get(f"{public_url}/health", timeout=10)
r.raise_for_status()
print("Health:", r.json())

# Test analyze
# Generous timeout: this request waits for a full model generation.
r = requests.post(
    f"{public_url}/analyze",
    json={
        "symptoms": "55 year old male with crushing chest pain radiating to left arm, diaphoresis, nausea"
    },
    timeout=600,
)
r.raise_for_status()
print("\nAnalysis result:")
# Only the first 500 characters of the model's reply are shown.
print(r.json()["response"][:500])

In [None]:
# Cell 7: Keep alive — run this cell to prevent notebook timeout
import time
print("Keeping server alive... (Interrupt kernel to stop)")
# Endless heartbeat: wake once a minute and print the tunnel URL from Cell 5.
# The loop has no exit condition, so interrupting the kernel is the only way out.
while True:
    time.sleep(60)
    print(f"[{time.strftime('%H:%M:%S')}] Server running at {public_url}")