In [3]:
from google.colab import userdata
import os

# Get token from Colab secrets
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

!pip install transformers accelerate pillow --quiet

In [4]:
from transformers import pipeline
import torch

# Load MedGemma 4B (instruction-tuned) as an image+text -> text pipeline on the GPU,
# in bfloat16 to halve memory versus float32.
# `dtype` replaces `torch_dtype`, which recent `transformers` releases report as
# deprecated ("`torch_dtype` is deprecated! Use `dtype` instead!").
# NOTE(review): very old `transformers` versions only accept `torch_dtype`; the
# install cell is unpinned, so it pulls a release that supports `dtype`.
pipe = pipeline(
    "image-text-to-text",
    model="google/medgemma-4b-it",
    dtype=torch.bfloat16,
    device="cuda",
)
print("Model loaded!")

config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Device set to use cuda


Model loaded!


In [5]:
from io import BytesIO

from PIL import Image
import requests

# Public domain chest X-ray
image_url = "https://upload.wikimedia.org/wikipedia/commons/c/c8/Chest_Xray_PA_3-8-2010.png"

# Same descriptive User-Agent and timeout as the pathological-image cell.
# raise_for_status() fails loudly on HTTP errors instead of handing an error page to PIL.
image_resp = requests.get(image_url, headers={"User-Agent": "MedGemmaMCP/1.0"}, timeout=15)
image_resp.raise_for_status()
# Decode from the fully downloaded bytes. This way no network stream stays open
# behind the lazily-loaded PIL image (the original used `stream=True` + `.raw`).
image = Image.open(BytesIO(image_resp.content))

# Vanilla baseline: a short step-by-step system prompt plus a generic question.
messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are an expert radiologist. Think step by step. First describe what you observe, then assess clinical significance, then provide your conclusion with confidence level."}]},
    {"role": "user", "content": [
        {"type": "text", "text": "Analyze this chest X-ray. Are there any abnormal findings?"},
        {"type": "image", "image": image}
    ]}
]

output = pipe(text=messages, max_new_tokens=512)
# The pipeline returns the whole chat; the last message is the model's reply.
print(output[0]["generated_text"][-1]["content"])

Okay, let's analyze this chest X-ray.

**1. Observation:**

*   **Bones:** The ribs, clavicles, and spine appear intact and without obvious fractures or significant degenerative changes.
*   **Lungs:** The lung fields are clear bilaterally. There is no evidence of consolidation, pleural effusion, or pneumothorax. The heart size appears within normal limits.
*   **Mediastinum:** The mediastinum is unremarkable, with the trachea midline.
*   **Soft Tissues:** No obvious masses or enlarged lymph nodes are seen.

**2. Clinical Significance:**

Based on the above observations, there are no obvious acute findings that would suggest a serious cardiopulmonary pathology. The lungs are clear, and the bones appear intact.

**3. Conclusion:**

The chest X-ray is unremarkable.

**Confidence Level:** 95%



In [None]:
# Install medgemma-mcp directly from GitHub (no clone needed)
!pip install git+https://github.com/Tom-R-Main/medgemma-mcp.git --quiet

from medgemma_mcp.prompts.templates import Modality, get_system_prompt, build_image_prompt

# Build CoT prompts for chest X-ray
system = get_system_prompt(Modality.CHEST_XRAY)
user_prompt = build_image_prompt(Modality.CHEST_XRAY, "Are there signs of pneumonia?")
print("=== SYSTEM PROMPT ===")
print(system)
print("\n=== USER PROMPT (first 500 chars) ===")
print(user_prompt[:500], "...")

In [None]:
# Re-run the same chest X-ray with the structured CoT prompts from medgemma-mcp,
# then compare the reply length against the vanilla reply from the earlier cell.
# NOTE(review): the vanilla run used max_new_tokens=512 and this one uses 1024,
# so part of any length difference may come from the larger budget.
cot_user_content = [
    {"type": "text", "text": user_prompt},
    {"type": "image", "image": image},
]
messages_cot = [
    {"role": "system", "content": [{"type": "text", "text": system}]},
    {"role": "user", "content": cot_user_content},
]

output_cot = pipe(text=messages_cot, max_new_tokens=1024)
cot_text = output_cot[0]["generated_text"][-1]["content"]  # last chat turn = model reply

print("=== CoT STRUCTURED OUTPUT ===")
print(cot_text)
vanilla_text = output[0]["generated_text"][-1]["content"]
print(f"\n--- Vanilla: {len(vanilla_text)} chars | CoT: {len(cot_text)} chars ---")

## Pathological Image Test

Test with a chest X-ray (CXR) showing actual pathology (pneumonia) to verify that the model detects the findings and that the CoT template structures its analysis properly.

In [None]:
# Pathological CXR images — URLs verified via Wikipedia API + curl + PIL
from io import BytesIO

from PIL import Image
import requests

# All URLs verified: HTTP 200, valid JPEG, open with PIL
VERIFIED_IMAGES = {
    "pneumonia": "https://upload.wikimedia.org/wikipedia/commons/0/01/Pneumonia_x_ray.jpg",        # 700x545 RGB
    "sars": "https://upload.wikimedia.org/wikipedia/commons/d/d2/SARS_xray.jpg",                   # 600x543 RGB
    "lobar_pneumonia": "https://upload.wikimedia.org/wikipedia/commons/5/51/X-ray_of_lobar_pneumonia.jpg",  # 3027x2407 RGB
}
DOWNLOAD_HEADERS = {"User-Agent": "MedGemmaMCP/1.0"}
DOWNLOAD_TIMEOUT_S = 15

# Validate that every image loads, and keep the decoded images so nothing is downloaded twice.
loaded_images = {}
for name, url in VERIFIED_IMAGES.items():
    try:
        resp = requests.get(url, headers=DOWNLOAD_HEADERS, timeout=DOWNLOAD_TIMEOUT_S)
        resp.raise_for_status()
        img = Image.open(BytesIO(resp.content))
        # Image.open only parses the header. load() decodes the pixels, so a
        # truncated or corrupt file is caught here rather than inside the model call.
        img.load()
        loaded_images[name] = img
        print(f"{name}: OK — {img.size} {img.mode} ({len(resp.content)//1024}KB)")
    except Exception as e:
        # Best-effort check: report the failure and keep validating the other URLs.
        print(f"{name}: FAILED — {e}")

# Use the lobar pneumonia (high-res, clear pathology) for the pipeline test
if "lobar_pneumonia" not in loaded_images:
    raise RuntimeError("lobar_pneumonia image failed to load; see the validation output above.")
pneumonia_img = loaded_images["lobar_pneumonia"]
pneumonia_img

## Full Tool Pipeline Test — Pathological Image

Run the pneumonia CXR through the **same functions the MCP tool calls** (minus its async wrapper and MCP context):
`run_image_inference()` → `extract_confidence()` → `ImageAnalysisResult` (Pydantic) → JSON

This tests everything except the MCP transport layer (stdio), which we test locally with MCP Inspector.

In [None]:
# Run the pneumonia CXR through the functions the MCP tool calls
# (minus the async wrapper and MCP context) and show the structured result.
import json

# Import the actual tool pipeline functions
from medgemma_mcp.model.inference import run_image_inference
from medgemma_mcp.safety.confidence import extract_confidence
from medgemma_mcp.safety.disclaimers import MEDICAL_DISCLAIMER
from medgemma_mcp.tools.analyze_image import ImageAnalysisResult

# Results whose extracted confidence falls below this value are flagged for human review.
# NOTE(review): intended to mirror the MCP tool's cutoff — keep the two in sync.
REVIEW_CONFIDENCE_THRESHOLD = 0.7
# Maximum number of characters of the serialized result to print.
JSON_PREVIEW_CHARS = 3000

# Extract model + processor from the pipeline (reuse already-loaded weights)
model = pipe.model
processor = pipe.processor

findings = run_image_inference(
    model=model,
    processor=processor,
    image=pneumonia_img,
    prompt=build_image_prompt(Modality.CHEST_XRAY, "Evaluate for pneumonia or other acute cardiopulmonary findings."),
    system_prompt=get_system_prompt(Modality.CHEST_XRAY),
    max_new_tokens=1024,
)

# Confidence extraction — same as the tool does
confidence = extract_confidence(findings)
requires_review = confidence < REVIEW_CONFIDENCE_THRESHOLD

# Build the Pydantic result object the MCP tool returns
result = ImageAnalysisResult(
    findings=findings,
    confidence=confidence,
    requires_review=requires_review,
    modality="chest_xray",
    disclaimer=MEDICAL_DISCLAIMER,
)

print("=== PNEUMONIA CXR — FULL TOOL PIPELINE ===")
print(f"Confidence:      {result.confidence}")
print(f"Requires review: {result.requires_review}")
print(f"Modality:        {result.modality}")
print(f"Findings length: {len(result.findings)} chars")
print("\n--- FINDINGS ---")
print(result.findings)
print("\n--- JSON (what Claude Desktop receives) ---")
json_out = result.model_dump_json(indent=2)
print(json_out[:JSON_PREVIEW_CHARS])
if len(json_out) > JSON_PREVIEW_CHARS:
    print("...")

## FHIR Parsing + Text Reasoning Pipeline

Test the non-image tools: FHIR Bundle → Python text summary → MedGemma reasoning.
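MedGemma was NOT trained on FHIR and scores below base Gemma on a FHIR-based evaluation (67.6% vs. 70.9%), so all FHIR parsing happens in Python and the model only sees the resulting plain-text summary.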

In [None]:
from medgemma_mcp.preprocessing.fhir import fhir_bundle_to_summary
from medgemma_mcp.prompts.templates import build_fhir_summary_prompt, build_text_prompt
from medgemma_mcp.model.inference import run_text_inference


def _entry(resource):
    """Wrap a FHIR resource dict as a Bundle entry."""
    return {"resource": resource}


def _active_condition(display):
    """Bundle entry for a Condition whose clinical status code is 'active'."""
    return _entry({
        "resourceType": "Condition",
        "code": {"coding": [{"display": display}]},
        "clinicalStatus": {"coding": [{"code": "active"}]},
    })


def _medication_request(medication, dosage):
    """Bundle entry for a MedicationRequest described by free text."""
    return _entry({
        "resourceType": "MedicationRequest",
        "medicationCodeableConcept": {"text": medication},
        "dosageInstruction": [{"text": dosage}],
    })


def _lab_observation(code, display, value, unit):
    """Bundle entry for an Observation carrying a single quantity."""
    return _entry({
        "resourceType": "Observation",
        "code": {"coding": [{"code": code, "display": display}]},
        "valueQuantity": {"value": value, "unit": unit},
    })


# Sample FHIR R4 Bundle: one patient, three active conditions, two medication
# requests and two lab observations.
fhir_bundle = {
    "resourceType": "Bundle",
    "type": "searchset",
    "entry": [
        _entry({
            "resourceType": "Patient",
            "gender": "female",
            "birthDate": "1958-03-15",
            "name": [{"given": ["Maria"], "family": "Garcia"}],
        }),
        *(
            _active_condition(display)
            for display in (
                "Type 2 diabetes mellitus",
                "Essential hypertension",
                "Chronic kidney disease, stage 3",
            )
        ),
        _medication_request("Metformin 1000mg", "twice daily"),
        _medication_request("Lisinopril 20mg", "once daily"),
        _lab_observation("4548-4", "HbA1c", 8.2, "%"),
        _lab_observation("2160-0", "Creatinine", 1.8, "mg/dL"),
    ],
}

# Stage 1 — pure Python: flatten the FHIR bundle into readable text.
clinical_summary = fhir_bundle_to_summary(fhir_bundle)
print("=== FHIR → TEXT SUMMARY (Python) ===")
print(clinical_summary)

# Stage 2 — wrap the summary and the clinical question in the CoT template.
fhir_prompt = build_fhir_summary_prompt(
    clinical_summary,
    "What are the key clinical concerns for this patient?",
)

# Stage 3 — text-only MedGemma inference on the prompt.
physician_system_prompt = "You are an expert physician providing evidence-based clinical reasoning."
fhir_analysis = run_text_inference(
    model=model,
    processor=processor,
    prompt=fhir_prompt,
    system_prompt=physician_system_prompt,
    max_new_tokens=1024,
)

fhir_confidence = extract_confidence(fhir_analysis)

print(f"\n=== MEDGEMMA FHIR ANALYSIS (confidence: {fhir_confidence}) ===")
print(fhir_analysis)

## What This Notebook Tests vs What Needs Local Testing

| Layer | Tested Here | How |
|-------|------------|-----|
| MedGemma inference | Yes | `run_image_inference()`, `run_text_inference()` |
| CoT prompt templates | Yes | `build_image_prompt()`, `build_fhir_summary_prompt()` |
| Confidence extraction | Yes | `extract_confidence()` parses model output |
| Structured Pydantic responses | Yes | `ImageAnalysisResult` with JSON serialization |
| FHIR R4 → text preprocessing | Yes | `fhir_bundle_to_summary()` |
| Medical disclaimers | Yes | Included in result objects |
| **MCP stdio transport** | **No** | Test locally with `npx @modelcontextprotocol/inspector medgemma-mcp` |
| **Claude Desktop integration** | **No** | Test locally with Claude Desktop config |
| **DICOM preprocessing** | **No** | Requires .dcm files, test with `pytest tests/` |
| **Base64 image loading** | **No** | Covered by unit tests (`pytest tests/test_preprocessing.py`) |

## MCP Server over SSE (Remote Access)

Run the actual MCP server on Colab's GPU, expose via ngrok, and connect Claude Desktop to it.
This tests the **full MCP protocol** — tool discovery, parameter validation, server lifespan, and the SSE transport (used here instead of stdio) — with the real model on a real GPU.

### Steps:
1. Install ngrok and authenticate
2. Start the MedGemma MCP server with SSE transport
3. Create ngrok tunnel
4. Configure Claude Desktop with the public URL

**Prerequisites:** Add your ngrok auth token to Colab Secrets as `NGROK_AUTH_TOKEN`. Get one free at https://ngrok.com

In [None]:
# Step 1: Install ngrok and update medgemma-mcp (needs --transport flag)
!pip install pyngrok --quiet
!pip install git+https://github.com/Tom-R-Main/medgemma-mcp.git --force-reinstall --quiet

In [None]:
# Step 2: Start MCP SSE server + ngrok tunnel
import gc
import os
import subprocess
import time
import urllib.error
import urllib.request

import torch
from google.colab import userdata
from pyngrok import ngrok

SERVER_PORT = 8000
SSE_URL = f"http://localhost:{SERVER_PORT}/sse"
STARTUP_TIMEOUT_S = 600   # model loading takes ~30-60s on T4; allow generous headroom
POLL_INTERVAL_S = 5
STDOUT_LOG = "medgemma_mcp_server_stdout.log"
# The server subprocess loads its own copy of MedGemma. The checkpoint is ~8.6 GB
# (4.96G + 3.64G shards), so two copies may not fit on a 16 GB T4. Set to False
# on larger GPUs if you want to keep the in-kernel model.
FREE_KERNEL_MODEL = True


def _server_responds(url: str, timeout: float = 2.0) -> bool:
    """Return True once something is listening and answering HTTP on `url`.

    A 200 (the SSE stream opened) and an HTTPError (the server rejected a bare GET)
    both prove the server is up. Any other URLError means the connection itself
    failed, e.g. refused because the server has not bound the port yet.
    """
    try:
        # Close the response right away: a 200 here is a long-lived SSE stream.
        with urllib.request.urlopen(url, timeout=timeout):
            return True
    except urllib.error.HTTPError:
        # Must come before URLError: HTTPError is a subclass of URLError.
        return True
    except urllib.error.URLError:
        return False
    except Exception:
        # Connection accepted but the response was cut off or timed out
        # (no proper SSE handshake) — the server is listening.
        return True


def _wait_for_server(proc, url: str, timeout_s: float, poll_s: float) -> bool:
    """Poll `url` until the server answers, its process exits, or `timeout_s` passes.

    Returns True if the server answered and False on timeout.
    Raises RuntimeError, including the tail of stderr, if the server process exits.
    """
    report_every = max(1, int(30 // poll_s))  # progress line roughly every 30s
    for attempt in range(max(1, int(timeout_s // poll_s))):
        time.sleep(poll_s)
        if proc.poll() is not None:
            # The process has exited, so reading stderr to EOF cannot block.
            err_tail = proc.stderr.read().decode(errors="replace")[-2000:]
            raise RuntimeError(f"MCP server exited with code {proc.returncode}:\n{err_tail}")
        if _server_responds(url):
            return True
        if attempt % report_every == 0:
            print(f"  ...loading ({(attempt + 1) * poll_s}s)")
    return False


# Authenticate ngrok (add NGROK_AUTH_TOKEN to Colab Secrets)
ngrok_token = userdata.get('NGROK_AUTH_TOKEN')
ngrok.set_auth_token(ngrok_token)

# Set HF token for the server process (child processes inherit os.environ)
os.environ["HF_TOKEN"] = userdata.get('HF_TOKEN')

# If this cell is re-run, stop the previous server so it releases the port.
_previous_proc = globals().get("server_proc")
if _previous_proc is not None and _previous_proc.poll() is None:
    _previous_proc.terminate()
    try:
        _previous_proc.wait(timeout=30)
    except subprocess.TimeoutExpired:
        _previous_proc.kill()

if FREE_KERNEL_MODEL:
    # Drop this kernel's references to the model so its GPU memory can be reclaimed.
    for _name in ("pipe", "model", "processor"):
        globals().pop(_name, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Start the MCP server with SSE transport in the background.
# stdout goes to a log file: nothing reads it, and an unread PIPE can fill up and
# block the server. stderr stays a PIPE because the log-monitor cell below reads it.
# NOTE(review): if stderr is left unread for a long time, its pipe buffer (typically
# 64 KiB on Linux) can also fill and stall the server — keep the monitor cell running.
with open(STDOUT_LOG, "wb") as stdout_log:
    server_proc = subprocess.Popen(
        ["medgemma-mcp", "--transport", "sse", "--host", "0.0.0.0", "--port", str(SERVER_PORT)],
        stdout=stdout_log,
        stderr=subprocess.PIPE,
    )
print(f"MCP server started (PID {server_proc.pid})")
print("Waiting for model to load...")

if not _wait_for_server(server_proc, SSE_URL, STARTUP_TIMEOUT_S, POLL_INTERVAL_S):
    # Don't read stderr here: the server is still running, so a blocking read would hang.
    print("WARNING: Server may not be ready yet. Run the log-monitor cell to inspect stderr.")

# Create ngrok tunnel.
# NOTE(review): the ngrok URL is public, and nothing in this cell puts authentication
# in front of the server. When done, call ngrok.disconnect(public_url) and
# server_proc.terminate().
tunnel = ngrok.connect(SERVER_PORT)
public_url = tunnel.public_url

rule = "=" * 60
print(f"\n{rule}")
print("MCP SERVER RUNNING")
print(rule)
print(f"Local:  {SSE_URL}")
print(f"Public: {public_url}/sse")
print("\nAdd this to Claude Desktop config:")
# NOTE(review): confirm your Claude Desktop version accepts a remote "url" entry in
# its config; otherwise bridge the URL through a local stdio proxy.
print(f"""
{{
  "mcpServers": {{
    "medgemma": {{
      "url": "{public_url}/sse"
    }}
  }}
}}
""")
print("Then restart Claude Desktop. The 4 MedGemma tools will appear.")
print(rule)

In [None]:
# Step 3: Monitor server logs (run this to see activity)
# The server will keep running as long as the Colab session is active.
# Check stderr for logs:
import select
import subprocess

# In Jupyter/Colab, the "interrupt execution" button raises KeyboardInterrupt.
print("Server logs (Ctrl+C to stop monitoring):\n")
try:
    # readline() blocks until a full line arrives and returns b"" only at EOF,
    # i.e. once the server has closed stderr (normally because it exited).
    # The original `else: sleep` branch therefore spun forever after the server died.
    for line in iter(server_proc.stderr.readline, b""):
        # errors="replace": don't crash the monitor on non-UTF-8 bytes.
        # rstrip() keeps leading indentation (e.g. traceback frames).
        print(line.decode(errors="replace").rstrip())
    try:
        exit_code = server_proc.wait(timeout=5)
        print(f"\nServer exited with code {exit_code}; no more logs.")
    except subprocess.TimeoutExpired:
        print("\nServer closed stderr but is still running; no more logs.")
except KeyboardInterrupt:
    print("\nStopped monitoring. Server still running.")