In [None]:
# Install required packages if needed (%pip installs into the running kernel's environment)
# %pip install -q transformers medmnist pillow torch torchvision huggingface_hub matplotlib

In [None]:
import os
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
from torchvision import transforms
from medmnist import PneumoniaMNIST
import matplotlib.pyplot as plt

In [None]:
from huggingface_hub import login

# Authenticate with the Hugging Face Hub before any checkpoint is requested.
#
# Do not paste a token into the notebook: it would be saved with the file.
# If the HF_TOKEN environment variable is set, it is used so the notebook can
# run unattended ("Restart & Run All"). If it is unset or empty, token is None
# and login() falls back to its interactive prompt, exactly as a bare login().
login(token=os.environ.get("HF_TOKEN") or None)

In [None]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Directory handed to the dataset loaders as their `root`.
data_root = './data'

# exist_ok=True makes this cell safe to re-run when the folder is already there.
os.makedirs(name=data_root, exist_ok=True)

In [None]:
# Build the three PneumoniaMNIST splits in train -> val -> test order.
# download=True is passed for every split, as is the shared data_root.
train_dataset, val_dataset, test_dataset = (
    PneumoniaMNIST(root=data_root, split=split_name, download=True)
    for split_name in ('train', 'val', 'test')
)

# One line per split; captions are left-padded to a common width so the
# counts line up in a column.
print("\n".join(
    f"{caption:<19} {len(split)}"
    for caption, split in (
        ("Training samples:", train_dataset),
        ("Validation samples:", val_dataset),
        ("Test samples:", test_dataset),
    )
))

In [None]:
# Checkpoint identifier passed to both from_pretrained calls below.
model_id = "google/medgemma-4b-it"

# The processor is what later builds the model inputs from the chat messages.
processor = AutoProcessor.from_pretrained(model_id)

# .to() and .eval() each return the module itself, so the calls can be chained:
# move the weights to the selected device, then switch to evaluation mode.
model = AutoModelForImageTextToText.from_pretrained(model_id).to(device).eval()
print("Model loaded successfully.")

In [None]:
def preprocess_image(img):
    """
    Convert PneumoniaMNIST image to RGB 224x224 PIL image.

    Parameters
    ----------
    img : PIL.Image.Image or array-like with a ``squeeze()`` method
        A PIL image is used as-is. Anything else (NumPy array, tensor) is
        squeezed, converted to a NumPy array, and turned into a PIL image.
        Floating-point data whose values all lie in [0, 1] is assumed to be a
        normalised image and is rescaled to 0-255. This is a heuristic: a
        float image on a 0-255 scale that happens to contain only values <= 1
        would be brightened.

    Returns
    -------
    PIL.Image.Image
        The image converted to RGB mode and resized to 224x224.
    """
    # Local import: numpy is not part of the notebook's top-level import cell.
    import numpy as np

    if not isinstance(img, Image.Image):
        arr = img.squeeze()
        # Tensors on an accelerator or tracking gradients cannot be converted
        # with .numpy() directly, so detach and move to host memory first.
        if hasattr(arr, 'detach'):
            arr = arr.detach().cpu()
        if hasattr(arr, 'numpy'):
            arr = arr.numpy()
        arr = np.asarray(arr)

        if np.issubdtype(arr.dtype, np.floating):
            # Casting [0, 1] floats straight to uint8 truncates almost every
            # pixel to 0 (a black image), so rescale that case to 0-255.
            if arr.size and arr.min() >= 0.0 and arr.max() <= 1.0:
                arr = np.rint(arr * 255.0)
            # Keep the uint8 cast well-defined for out-of-range floats.
            arr = np.clip(arr, 0, 255)

        img = Image.fromarray(arr.astype(np.uint8))

    img = img.convert("RGB")
    img = transforms.Resize((224, 224))(img)
    return img

In [None]:
# Pick one training example; the dataset yields an (image, label) pair.
sample_index = 0
img, label = train_dataset[sample_index]

# From here on `img` is the RGB 224x224 PIL image used by the later cells.
img = preprocess_image(img)

# A label equal to 1 is reported as pneumonia; anything else as normal.
if label == 1:
    ground_truth = "Pneumonia"
else:
    ground_truth = "Normal"
print("Ground Truth: " + ground_truth)

In [None]:
# Instruction sent as the text part of the user turn.
# The fragments are joined with single spaces into one line of text.
prompt_text = " ".join([
    "You are an expert radiologist.",
    "Generate a chest X-ray report following RSNA pneumonia guidelines.",
    "Include only findings consistent with the standard criteria",
    "(consolidation, effusion, pneumothorax, lung opacity, heart size,",
    "mediastinum, pleura, bones) and provide a structured summary with",
    "Findings and Impression.",
])

In [None]:
# Two-turn conversation: a system turn that sets the persona, then a user turn
# whose content holds the text instruction followed by the image.
messages = [
    dict(
        role="system",
        content=[dict(type="text", text="You are an expert radiologist.")],
    ),
    dict(
        role="user",
        content=[
            dict(type="text", text=prompt_text),
            dict(type="image", image=img),
        ],
    ),
]

In [None]:
# Render the conversation through the processor's chat template, asking for
# tokenized PyTorch tensors in a dict-like batch with the generation prompt appended.
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    return_dict=True,
    return_tensors="pt",
)
# Place the batch on the same device as the model.
inputs = inputs.to(device)

# Prompt length in tokens; used later to strip the prompt from the output.
input_len = inputs["input_ids"].shape[-1]
print("Input token length:", input_len)

In [None]:
# Greedy decoding (do_sample=False), capped at 300 newly generated tokens.
print("Generating prediction... (CPU may take 1-3 minutes)")

# inference_mode: no autograd bookkeeping is needed for generation.
with torch.inference_mode():
    output_ids = model.generate(**inputs, max_new_tokens=300, do_sample=False)

# The first input_len positions of the returned sequence are sliced off so that
# only what follows the prompt is decoded.
generated_ids = output_ids[0][input_len:]
prediction = processor.decode(generated_ids, skip_special_tokens=True)
prediction = prediction.strip()
print("Generation complete.")

In [None]:
# Show the preprocessed image with its ground-truth label as the title.
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(img)
ax.set_title(f"Ground Truth: {ground_truth}", fontsize=13)
ax.axis("off")
fig.tight_layout()
plt.show()

# Print the generated report under a banner made of 60 '=' characters.
print("\n" + "=" * 60, "MODEL PREDICTION:", "=" * 60, prediction, sep="\n")

In [None]:
# --- Optional: Batch evaluation on multiple test samples ---

def run_inference(img_pil, prompt_text, model, processor, device, max_new_tokens=300):
    """Generate the model's text response for one image and one prompt.

    Args:
        img_pil: Image placed in the user turn as the "image" content item.
        prompt_text: Instruction placed in the user turn as the "text" item.
        model: Object whose ``generate`` method produces the output ids.
        processor: Object providing ``apply_chat_template`` and ``decode``.
        device: Target passed to ``.to()`` on the processed inputs.
        max_new_tokens: Forwarded to ``model.generate`` (default 300).

    Returns:
        The decoded text that follows the prompt, with special tokens skipped
        and surrounding whitespace stripped.
    """
    system_turn = {
        "role": "system",
        "content": [{"type": "text", "text": "You are an expert radiologist."}],
    }
    user_turn = {
        "role": "user",
        "content": [
            {"type": "text", "text": prompt_text},
            {"type": "image", "image": img_pil},
        ],
    }

    encoded = processor.apply_chat_template(
        [system_turn, user_turn],
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)

    # Remember the prompt length so the prompt can be sliced off the output.
    prompt_len = encoded["input_ids"].shape[-1]

    # Greedy decoding (do_sample=False) with autograd disabled.
    with torch.inference_mode():
        sequences = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
        )

    completion_ids = sequences[0][prompt_len:]
    decoded = processor.decode(completion_ids, skip_special_tokens=True)
    return decoded.strip()


# Run the same prompt over the first few test images and print each report
# next to its ground-truth label.
num_samples = 3
print(f"Running inference on {num_samples} test samples...\n")

for i in range(num_samples):
    raw_img, lbl = test_dataset[i]

    # A label equal to 1 is reported as pneumonia; anything else as normal.
    if lbl == 1:
        gt = "Pneumonia"
    else:
        gt = "Normal"

    pil_img = preprocess_image(raw_img)
    # This loop caps generation at 200 new tokens per sample.
    pred = run_inference(
        pil_img, prompt_text, model, processor, device, max_new_tokens=200
    )

    # Header line, the report, then a blank separator line.
    print(f"--- Sample {i} | Ground Truth: {gt} ---", pred, "", sep="\n")