# MedGemma 1.5 on Our CT Dataset (NIfTI)

This notebook mirrors the official MedGemma example `high_dimensional_ct_hugging_face.ipynb`, but loads CT volumes from our own dataset (NIfTI files) and uniformly samples slices from each volume as MedGemma image inputs.

## Setup
Install required dependencies (if not already installed).

In [None]:
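# If needed (`%pip` installs into this kernel's environment; `accelerate` is required for device_map="auto")
# %pip install nibabel pydicom transformers accelerate pillow tqdm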

## Configure paths and settings

In [None]:
import os
from pathlib import Path

# Base directory for all dataset files. Defaults to the relative path "3D_VLM_Spatial"
# (resolved against the kernel's working directory); set the CT_DATA_ROOT environment
# variable to point elsewhere without editing the notebook.
DATA_ROOT = Path(os.environ.get("CT_DATA_ROOT", "3D_VLM_Spatial"))

# Path to our dataset JSONL (QA): one JSON object per line
DATASET_JSONL = DATA_ROOT / "spatial_qa_processed.jsonl"

# Root of CT NIfTI volumes (valid_fixed)
NIFTI_ROOT = DATA_ROOT / "dataset" / "data_volumes" / "dataset" / "valid_fixed"

# MedGemma model ID (Hugging Face Hub)
MODEL_ID = "google/medgemma-1.5-4b-it"

# Number of slices to sample uniformly along the volume's slice axis
NUM_SLICES = 85

# Optional resize: length in pixels of the longer image side after resizing.
# Set 0 to disable resizing (padding is then skipped as well).
RESIZE_LONGEST = 384
# Pad each resized slice with black to a RESIZE_LONGEST x RESIZE_LONGEST square.
PAD_SQUARE = True

## Helpers: load volume, sample slices, windowing
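(Slice sampling and windowing are adapted from the official notebook; `derive_nifti_path` maps our dataset's case IDs onto its directory layout.)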

In [None]:
import json
import numpy as np
import nibabel as nib
from PIL import Image
import io, base64

def derive_nifti_path(nifti_root: Path, case_id: str) -> Path:
    """Map a case file name to its location under ``nifti_root``.

    The layout is ``<root>/<outer>/<inner>/<case_id>``. ``inner`` is the case stem
    (``.nii.gz`` removed) with its last ``_``-separated token dropped. ``outer`` is
    ``inner`` with its last token dropped. For example,
    ``valid_1_a_1.nii.gz`` maps to ``<root>/valid_1/valid_1_a/valid_1_a_1.nii.gz``.
    A name without underscores maps to ``<root>/<stem>/<stem>/<case_id>``.
    """
    def _drop_last_token(name: str) -> str:
        # rpartition yields ("", "", name) when there is no underscore.
        head, separator, _tail = name.rpartition("_")
        return head if separator else name

    stem = case_id.replace(".nii.gz", "")
    inner_dir = _drop_last_token(stem)
    outer_dir = _drop_last_token(inner_dir)
    return nifti_root / outer_dir / inner_dir / case_id

def sample_indices(n_slices: int, num_samples: int):
    """Choose up to ``num_samples`` evenly spaced indices from ``range(n_slices)``.

    If the volume has no more slices than requested, every index is returned.
    Otherwise positions come from ``np.linspace(0, n_slices - 1, num_samples)``,
    rounded half-to-even, so the first and last slice are included whenever
    ``num_samples >= 2``.

    Returns:
        A list of Python ints.
    """
    if n_slices > num_samples:
        positions = np.linspace(0, n_slices - 1, num_samples)
        return np.rint(positions).astype(int).tolist()
    return list(range(n_slices))

# Windowing as in the official notebook: one (lo, hi) intensity window per RGB channel.
# Width/level of each: 2048/0, 350/40 and 80/40. Values are presumably Hounsfield
# units, i.e. a wide window plus what are typically a soft-tissue and a brain window.
WINDOW_CLIPS = [(-1024, 1024), (-135, 215), (0, 80)]


def norm(ct_vol: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Clip ``ct_vol`` to [lo, hi] and rescale linearly to float32 values in [0, 255].

    ``astype`` makes a new array, so ``ct_vol`` is left untouched.
    Expects ``lo < hi``; equal bounds divide by zero.
    """
    scaled = np.clip(ct_vol, lo, hi).astype(np.float32)
    # The ufuncs write back into `scaled`, keeping it float32 at every step.
    np.subtract(scaled, lo, out=scaled)
    np.divide(scaled, hi - lo, out=scaled)
    np.multiply(scaled, 255.0, out=scaled)
    return scaled


def window(ct_slice: np.ndarray) -> np.ndarray:
    """Apply every window in WINDOW_CLIPS to ``ct_slice``.

    Stacks the results along a new last axis: a 2-D slice becomes an (H, W, 3)
    float32 array.
    """
    channels = []
    for lo, hi in WINDOW_CLIPS:
        channels.append(norm(ct_slice, lo, hi))
    return np.stack(channels, axis=-1)

def resize_rgb(rgb: np.ndarray, longest: int, pad_square: bool):
    """Bilinearly rescale an image so its longer side equals ``longest`` pixels.

    Args:
        rgb: Image array accepted by ``PIL.Image.fromarray``. In this notebook it is
            an (H, W, 3) uint8 slice.
        longest: Target length of the longer side. ``<= 0`` disables resizing: the
            input array is returned unchanged and ``pad_square`` is ignored.
        pad_square: If true, center the resized image on a black
            ``longest x longest`` RGB canvas.

    Returns:
        The resized (and optionally padded) image as a NumPy array.
    """
    if longest > 0:
        height, width = rgb.shape[:2]
        factor = longest / float(max(height, width))
        target_w = max(1, int(round(width * factor)))
        target_h = max(1, int(round(height * factor)))
        resized = Image.fromarray(rgb).resize((target_w, target_h), resample=Image.BILINEAR)
        if pad_square:
            # Center the image; when the margin is odd, the extra pixel goes right/bottom.
            canvas = Image.new("RGB", (longest, longest), (0, 0, 0))
            canvas.paste(resized, ((longest - target_w) // 2, (longest - target_h) // 2))
            resized = canvas
        return np.asarray(resized)
    return rgb

def load_ct_volume(nifti_path: Path, radiological: bool = True) -> np.ndarray:
    """Load a NIfTI CT volume as a float32 array indexed ``(z, y, x)``.

    With ``radiological=True`` (default), the image is reoriented to RAS+ with
    ``nib.as_closest_canonical`` and both in-plane axes are reversed. Every file then
    gets the same layout regardless of how it was stored on disk:
    - the slice index runs inferior -> superior;
    - rows run anterior -> posterior;
    - columns run from the patient's right to left.
    For files stored in LPS order this equals what a plain transpose produces.
    With ``radiological=False``, the on-disk axis order is used as-is.

    Raises:
        ValueError: if the image is not 3-D.
    """
    img = nib.load(str(nifti_path))
    if radiological:
        img = nib.as_closest_canonical(img)
    # float32 halves memory relative to get_fdata's float64 default; norm() works in float32 anyway.
    data = img.get_fdata(dtype=np.float32)
    if data.ndim != 3:
        raise ValueError(f"Expected a 3-D volume, got shape {data.shape} from {nifti_path}")
    if radiological:
        # RAS+ -> x: right->left, y: anterior->posterior (z stays inferior->superior).
        data = data[::-1, ::-1, :]
    # Stored as (x, y, z); reorder to (z, y, x) so vol[k] is one 2-D slice.
    return np.transpose(data, (2, 1, 0))


# Load the first QA record as a worked example.
with DATASET_JSONL.open(encoding="utf-8") as f:
    first_line = next(f, None)
if first_line is None:
    raise ValueError(f"{DATASET_JSONL} contains no records")
example = json.loads(first_line)

case_id = example["case_id"]
question = example["question"]

nifti_path = derive_nifti_path(NIFTI_ROOT, case_id)
print("Example NIfTI:", nifti_path)

vol = load_ct_volume(nifti_path)

idxs = sample_indices(vol.shape[0], NUM_SLICES)

slices = []
for i in idxs:
    # window() yields floats in [0, 255]; round before the uint8 cast instead of truncating.
    rgb = np.round(window(vol[i])).astype(np.uint8)
    slices.append(resize_rgb(rgb, RESIZE_LONGEST, PAD_SQUARE))

print("Slices:", len(slices), slices[0].shape)

## Build MedGemma prompt (chat template)

In [None]:
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

# Text placed before the first slice image.
# NOTE(review): slices are sampled uniformly across the whole volume (sample_indices),
# so "contiguous block" may not describe them accurately; wording kept unchanged.
instruction = (
    "You are an instructor teaching medical students. You are "
    "analyzing a contiguous block of CT slices. Please review the slices provided below carefully."
)
# Text placed after the last slice, ahead of the question. The leading blank line uses
# "\n\n" escapes: a raw line break inside an ordinary "..." literal is a SyntaxError.
query_suffix = (
    "\n\nBased on the visual evidence in the slices provided above, "
    "answer the question below. Provide concise reasoning and conclude with a final answer."
)

def _encode(rgb: np.ndarray, fmt: str = "jpeg") -> str:
    """Serialize an image array to a base64 ``data:`` URI string.

    Args:
        rgb: Array accepted by ``PIL.Image.fromarray`` (here an RGB uint8 slice).
        fmt: PIL format name, also used as the MIME subtype, e.g. "jpeg" -> "image/jpeg".
            JPEG is a lossy format.

    Returns:
        ``"data:image/<fmt>;base64,<payload>"``.
    """
    with io.BytesIO() as stream:
        Image.fromarray(rgb).save(stream, format=fmt)
        raw_bytes = stream.getvalue()
    payload = base64.b64encode(raw_bytes).decode("utf-8")
    return f"data:image/{fmt};base64,{payload}"

# Message content: the instruction, then each slice image followed by a 1-based
# "SLICE k" label, then the closing instructions and the question.
content = [{"type": "text", "text": instruction}]
for i, rgb in enumerate(slices, 1):
    content.append({"type": "image", "image": _encode(rgb)})
    content.append({"type": "text", "text": f"SLICE {i}"})
# "\n\n" escapes: a single-quoted f-string cannot contain a raw line break (SyntaxError).
content.append({"type": "text", "text": f"{query_suffix}\n\nQuestion: {question}"})

messages = [{"role": "user", "content": content}]

# Render the chat template and tokenize in one step. The result is dict-like,
# keyed at least by "input_ids" (presumably also attention mask and pixel values).
inputs = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    continue_final_message=False,
    return_tensors="pt",
    tokenize=True,
    return_dict=True,
)

print("Input tokens:", inputs["input_ids"].shape)

## Load MedGemma and run inference

In [None]:
import torch
from transformers import AutoModelForImageTextToText

model_kwargs = dict(
    dtype=torch.bfloat16,
    device_map="auto",  # requires the `accelerate` package
    offload_buffers=True,
)

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)

# Move tensors to the model's device. Floating-point tensors (e.g. pixel values) are cast
# to bfloat16 to match the model weights. Integer tensors (token ids, masks) keep their dtype.
inputs = {
    k: v.to(model.device, dtype=torch.bfloat16) if torch.is_floating_point(v) else v.to(model.device)
    for k, v in inputs.items()
}

with torch.inference_mode():
    generated = model.generate(**inputs, do_sample=False, max_new_tokens=512)

# `generated` starts with the prompt tokens, followed by the new tokens. Drop the prompt at
# the token level. Searching for the decoded prompt text in the decoded output only worked
# if it was found within the first 3 characters; otherwise the whole prompt stayed in `out`.
prompt_len = inputs["input_ids"].shape[-1]
raw_output = processor.post_process_image_text_to_text(generated, skip_special_tokens=True)[0]
out = processor.post_process_image_text_to_text(
    generated[:, prompt_len:], skip_special_tokens=True
)[0].strip()

print(out)