# MedGemma Quickstart Inference

This notebook demonstrates how to run inference with MedGemma 4B for chest X-ray triage.

**Requirements:**
- GPU with 16GB+ VRAM (or use quantization for smaller GPUs)
- `transformers>=4.50.0`
- `torch>=2.1.0`

**Time to complete:** ~10 minutes

## 1. Setup

In [None]:
# Install dependencies
# !pip install -U transformers accelerate torch Pillow requests

In [None]:
import torch
from transformers import pipeline, AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import requests
from io import BytesIO
import warnings
warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
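The requirements at the top pin minimum versions of `transformers` and `torch`. Below is a stdlib-only sketch for checking them before the model download starts. `parse_version`, `meets_minimum`, and `check_requirements` are hypothetical helpers. The parser compares only leading numeric components, so pre-release suffixes such as `rc1` or `dev0` are ignored. For strict PEP 440 comparisons, `packaging.version.Version` is the more rigorous choice.

```python
from importlib import metadata

# Minimum versions taken from the Requirements list at the top of the notebook.
REQUIREMENTS = {"transformers": "4.50.0", "torch": "2.1.0"}


def parse_version(text: str) -> tuple:
    """Turn '4.51.3.dev0' or '2.1.0+cu121' into a tuple of leading integers."""
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at suffixes like 'rc1', 'dev0', '+cu121'
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` >= `minimum`, padding missing components with 0."""
    a, b = parse_version(installed), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_requirements(requirements: dict) -> dict:
    """Map each package to (installed_version_or_None, meets_minimum)."""
    status = {}
    for package, minimum in requirements.items():
        try:
            installed = metadata.version(package)
        except metadata.PackageNotFoundError:
            status[package] = (None, False)
            continue
        status[package] = (installed, meets_minimum(installed, minimum))
    return status


for package, (installed, ok) in check_requirements(REQUIREMENTS).items():
    verdict = "OK" if ok else "below minimum or missing"
    print(f"{package}: {installed or 'not installed'} ({verdict})")
```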

## 2. Load MedGemma Model

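We use `google/medgemma-4b-it`, the instruction-tuned 4B-parameter multimodal (image + text) MedGemma model. The checkpoint is gated on Hugging Face, so accept its terms on the model page and authenticate (for example with `huggingface-cli login`) before loading it.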

In [None]:
MODEL_ID = "google/medgemma-4b-it"

# Simplest option: the high-level pipeline API.
# (AutoProcessor + AutoModelForImageTextToText, imported above, give finer control.)
print("Loading MedGemma via pipeline...")

pipe = pipeline(
    "image-text-to-text",
    model=MODEL_ID,
    torch_dtype=torch.bfloat16,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print("Model loaded successfully!")

## 3. Load Sample Images

We use CC0-licensed chest X-ray images for demonstration.

In [None]:
def load_image_from_url(url: str) -> Image.Image:
    """Load an image from URL."""
    # A timeout prevents the cell from hanging indefinitely on a stalled connection
    response = requests.get(url, headers={"User-Agent": "MedGemmaDemo"}, timeout=30)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# Sample CXR images (CC0 licensed from Wikimedia Commons)
SAMPLE_IMAGES = [
    {
        "name": "Normal CXR",
        "url": "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png",
        "expected": "Non-Urgent",
    },
]

# Load images
images = []
for sample in SAMPLE_IMAGES:
    try:
        img = load_image_from_url(sample["url"])
        images.append({"name": sample["name"], "image": img, "expected": sample["expected"]})
        print(f"✓ Loaded: {sample['name']} ({img.size})")
    except Exception as e:
        print(f"✗ Failed to load {sample['name']}: {e}")

print(f"\nLoaded {len(images)} images")

## 4. Define Triage Prompts

We build one structured prompt that asks for:
1. **Urgency Classification** - Urgent vs Non-Urgent
2. **Brief Explanation** - One-line rationale
3. **Key Findings** - 2-3 key observations

In [None]:
SYSTEM_PROMPT = """You are an expert radiologist assistant. Your role is to help triage chest X-rays by:
1. Classifying urgency (Urgent or Non-Urgent)
2. Providing a brief explanation
3. Highlighting key findings

⚠️ IMPORTANT: This is for clinical decision support only. All findings require verification by a qualified radiologist."""

TRIAGE_PROMPT = """Analyze this chest X-ray and provide:

1. URGENCY: [Urgent/Non-Urgent]
2. EXPLANATION: [One-line explanation for the urgency classification]
3. KEY FINDINGS: [List 2-3 key observations]

Be concise and focus on clinically significant findings."""

def create_triage_messages(image: Image.Image) -> list:
    """Create chat messages for triage inference."""
    return [
        {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}]
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": TRIAGE_PROMPT},
                {"type": "image", "image": image}
            ]
        }
    ]
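`create_triage_messages` uses the chat format that the `image-text-to-text` pipeline accepts: a list of role/content dicts whose content items are typed. A malformed message usually surfaces only after the slow model call, so a cheap structural check can help while editing prompts. The sketch below uses a hypothetical helper, `describe_messages`, which is not part of `transformers`. It assumes each item stores its payload under a key named after its type (`"text"` or `"image"`), as the messages above do. Other content shapes accepted by chat templates are not covered.

```python
def describe_messages(messages: list) -> list:
    """Return (role, [content types]) pairs, raising ValueError on malformed items."""
    summary = []
    for message in messages:
        role = message.get("role")
        if role not in {"system", "user", "assistant"}:
            raise ValueError(f"unexpected role: {role!r}")
        types = []
        for item in message["content"]:
            kind = item["type"]
            # Assumption: the payload lives under a key with the same name as the type.
            if kind not in item:
                raise ValueError(f"{kind!r} item is missing its {kind!r} key")
            types.append(kind)
        summary.append((role, types))
    return summary


placeholder_image = object()  # stands in for a PIL image in this sketch
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a radiologist assistant."}]},
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Analyze this chest X-ray."},
            {"type": "image", "image": placeholder_image},
        ],
    },
]
print(describe_messages(messages))
```

In the notebook, `describe_messages(create_triage_messages(img))` should report a system text turn followed by a user turn containing text and an image.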

## 5. Run Inference on Sample Images

In [None]:
import time

def run_triage_inference(pipe, image: Image.Image, max_tokens: int = 300) -> dict:
    """Run triage inference on a single image."""
    messages = create_triage_messages(image)
    
    start_time = time.perf_counter()  # monotonic clock, suited to measuring durations
    output = pipe(text=messages, max_new_tokens=max_tokens)
    inference_time = time.perf_counter() - start_time
    
    response = output[0]["generated_text"][-1]["content"]
    
    return {
        "response": response,
        "inference_time_ms": inference_time * 1000
    }


# Run inference on all samples
print("=" * 60)
print("MedGemma CXR Triage Inference")
print("=" * 60)

results = []
for i, sample in enumerate(images):
    print(f"\n[Image {i+1}/{len(images)}] {sample['name']}")
    print("-" * 40)
    
    result = run_triage_inference(pipe, sample["image"])
    results.append({
        "name": sample["name"],
        "expected": sample["expected"],
        **result
    })
    
    print(f"Inference time: {result['inference_time_ms']:.0f}ms")
    print(f"\nResponse:\n{result['response']}")
    print()

print("=" * 60)
print("Inference Complete")
print("=" * 60)
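Each entry in `results` carries an `inference_time_ms` value. With more than one image, an aggregate is easier to read than per-image lines. The first generation call is often slower because of one-time GPU initialization, so the hypothetical helper below can optionally skip a number of warm-up runs. The demo values are made-up numbers, not measurements.

```python
from statistics import mean


def summarize_latency(results: list, warmup: int = 0) -> dict:
    """Aggregate `inference_time_ms` across results, skipping the first `warmup` runs."""
    times = [r["inference_time_ms"] for r in results if "inference_time_ms" in r]
    times = times[warmup:]
    if not times:
        return {"count": 0}
    return {
        "count": len(times),
        "mean_ms": mean(times),
        "min_ms": min(times),
        "max_ms": max(times),
    }


# Illustrative values only; in the notebook, pass the real `results` list.
demo = [{"inference_time_ms": 4200.0}, {"inference_time_ms": 1800.0}, {"inference_time_ms": 2200.0}]
print(summarize_latency(demo, warmup=1))
```

In the notebook, `summarize_latency(results)` works directly on the list built in the loop above.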

## 6. Parse and Structure Results

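Extract the URGENCY, EXPLANATION and KEY FINDINGS fields from the free-text response with regular expressions. The model does not always follow the template exactly (for example, it may wrap labels in Markdown bold), so any field that cannot be found is left as `None` or an empty list.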

In [None]:
import re

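def parse_triage_response(response: str) -> dict:
    """Parse structured fields from a triage response.

    Tolerates optional Markdown bold (e.g. '**URGENCY:**') and optional
    square brackets around values; fields that are not found stay None / empty.
    """
    result = {
        "urgency": None,
        "explanation": None,
        "key_findings": []
    }
    
    # Extract urgency and normalize to the exact labels used in SAMPLE_IMAGES
    # ("Urgent" / "Non-Urgent"); str.capitalize() would turn "Non-Urgent" into "Non-urgent".
    urgency_match = re.search(r'URGENCY\**:\**\s*\[?\s*(Non[- ]?Urgent|Urgent)', response, re.IGNORECASE)
    if urgency_match:
        label = urgency_match.group(1).lower()
        result["urgency"] = "Non-Urgent" if label.startswith("non") else "Urgent"
    
    # Extract explanation: the rest of the line, without surrounding brackets
    explanation_match = re.search(r'EXPLANATION\**:\**\s*\[?([^\]\n]+)\]?', response, re.IGNORECASE)
    if explanation_match:
        result["explanation"] = explanation_match.group(1).strip()
    
    # Extract key findings: one finding per line, with leading bullets or list
    # numbers ("-", "•", "*", "1.", "2)") removed. Splitting on every digit,
    # period or hyphen would break text such as "left-sided" or "2.5 cm".
    findings_section = re.search(r'KEY FINDINGS\**:\**\s*(.+)', response, re.IGNORECASE | re.DOTALL)
    if findings_section:
        for line in findings_section.group(1).splitlines():
            finding = re.sub(r'^\s*(?:[-•*]|\d+[.)])\s*', '', line).strip()
            if finding:
                result["key_findings"].append(finding)
    
    return result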


# Parse all results
for result in results:
    parsed = parse_triage_response(result["response"])
    result.update(parsed)

# Display structured results
print("Structured Results:")
print("-" * 40)
for result in results:
    print(f"Image: {result['name']}")
    print(f"  Urgency: {result['urgency']}")
    print(f"  Explanation: {result['explanation']}")
    print(f"  Key Findings: {result['key_findings'][:3]}")
    print()
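Each sample also carries an `expected` label that the cells above never use. The hypothetical helper below compares it with the parsed urgency. Labels are compared case-insensitively, so casing differences such as `Non-urgent` vs `Non-Urgent` do not count as disagreements. Responses whose urgency could not be parsed are counted separately rather than as errors. With a single sample image this is a smoke test, not an accuracy metric.

```python
def urgency_agreement(results: list) -> dict:
    """Compare each parsed `urgency` with its `expected` label."""
    matched, unparsed, mismatches = 0, 0, []
    for r in results:
        predicted = r.get("urgency")  # missing if the parsing cell was not run
        if predicted is None:
            unparsed += 1
        elif predicted.lower() == r["expected"].lower():
            matched += 1
        else:
            mismatches.append(r["name"])
    total = len(results)
    return {
        "total": total,
        "matched": matched,
        "unparsed": unparsed,
        "mismatches": mismatches,
        "agreement": matched / total if total else None,
    }
```

In the notebook, `print(urgency_agreement(results))` after the parsing cell reports how many images matched their expected label.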

## 7. Summary & Next Steps

### What we demonstrated:
- Loading MedGemma 4B IT model
- Running multimodal inference (image + text)
- Parsing structured triage outputs

### Performance:
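- Inference time varies with the GPU and `max_new_tokens` (typically 1-5 seconds per image; compare the timings printed in Section 5)
- Memory usage: ~8GB VRAM for the model weights in bfloat16; activations and generation need additional headroom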
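The ~8GB weight figure is just parameter count × bytes per parameter: bfloat16 stores each weight in 2 bytes. Below is a back-of-envelope sketch, assuming a round 4e9 parameters (the exact count differs slightly) and decimal GB, matching the `1e9` divisor in the setup cell. The quantized rows ignore the small overhead of quantization scales. `weight_memory_gb` is a hypothetical helper.

```python
# Bytes needed to store one parameter in each format.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "int8": 1, "int4": 0.5}


def weight_memory_gb(num_params: float, dtype: str) -> float:
    """Approximate memory for the model weights alone, in decimal GB (1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: ~{weight_memory_gb(4e9, dtype):.1f} GB")
```

The same arithmetic shows why quantization helps on smaller GPUs: int4 weights need roughly a quarter of the bfloat16 footprint.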

### Next Steps:
1. **Fine-tuning**: See `02_fine_tune_medgemma.ipynb` for task-specific training
2. **Evaluation**: See `03_evaluation_and_metrics.ipynb` for comprehensive metrics
3. **Deployment**: See `demo_app/` for production inference

### ⚠️ Disclaimer
This is for **clinical decision support only**. All AI outputs require verification by qualified healthcare professionals.

In [None]:
# Save results for evaluation
import json
from pathlib import Path

output_dir = Path("../eval")
output_dir.mkdir(parents=True, exist_ok=True)

# Save inference results
output_file = output_dir / "quickstart_results.json"
with open(output_file, "w", encoding="utf-8") as f:
    # PIL images live in `images`, not `results`; drop any "image" key
    # defensively so only JSON-serializable fields are written.
    serializable_results = [
        {k: v for k, v in r.items() if k != "image"}
        for r in results
    ]
    json.dump(serializable_results, f, indent=2, ensure_ascii=False)

print(f"Results saved to: {output_file}")
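Before a later evaluation step consumes this file, it can be worth reloading it once to confirm it round-trips and still contains the fields written above. This is a sketch: `load_results` and `REQUIRED_KEYS` are hypothetical names, and the required keys are the ones the inference loop stores (the parsed fields are present only if the parsing cell ran).

```python
import json
from pathlib import Path

# Fields written by the inference loop for every image.
REQUIRED_KEYS = {"name", "expected", "response", "inference_time_ms"}


def load_results(path: Path) -> list:
    """Load saved triage results and check that each record has the core fields."""
    with open(path, encoding="utf-8") as f:
        records = json.load(f)
    if not isinstance(records, list):
        raise ValueError("expected a JSON list of result records")
    for i, record in enumerate(records):
        missing = REQUIRED_KEYS.difference(record)
        if missing:
            raise ValueError(f"record {i} is missing {sorted(missing)}")
    return records


results_path = Path("../eval/quickstart_results.json")
if results_path.exists():
    print(f"Reloaded {len(load_results(results_path))} records from {results_path}")
```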