In [1]:
# =========================
# 0. INSTALL (run once)
# =========================
!pip install "transformers>=4.41.0" "datasets>=2.19.0" "accelerate>=0.30.0" \
              peft trl bitsandbytes pillow

# If you use Unsloth:
!pip install unsloth





In [2]:
from google.colab import drive
drive.mount('/content/drive')

# My Drive/medgemma_finetune/data
!cp -r "/content/drive/MyDrive/data" "/content/medgemma_finetune/"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [3]:
# Switch the kernel's working directory to the project folder; the script
# below is invoked by its bare file name, so it is looked up here.
# NOTE(review): no cell in this notebook creates make_instructions.py, so it
# must already be present in this folder before running.
%cd /content/medgemma_finetune
# Build the instruction dataset; per the log below it writes
# train_instructions.jsonl and val_instructions.jsonl into this folder.
!python make_instructions.py


/content/medgemma_finetune
[INFO] Dataset: kidney CT
  - Class: Cyst
    -> Used 60 images
  - Class: Normal
    -> Used 60 images
  - Class: Stone
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: Breast MRI
  - Class: Malignant
    -> Used 60 images
  - Class: Benign
    -> Used 60 images
[INFO] Dataset: Brain Tumor MRI images
  - Class: Healthy
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: mammography
  - Class: Malignant
    -> Used 60 images
  - Class: Benign
    -> Used 60 images
[INFO] Dataset: Brain Tumor CT scan Images
  - Class: Healthy
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: lung cancer
  - Class: Bengin cases
    -> Used 60 images
  - Class: Normal cases
    -> Used 60 images
  - Class: Malignant cases
    -> Used 60 images
[INFO] Collected 900 examples total.
[INFO] Train: 765 | Val: 135
[OK] Wrote train_instructions.jsonl and val_instructions.jsonl in /content/medgemma_finetune


In [4]:
import os
from pathlib import Path
from typing import Any

import torch
from datasets import load_dataset
from PIL import Image

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer



In [19]:
# Absolute project folder. Relative image paths in the JSONL rows are resolved
# against it, so the result no longer depends on the kernel's current working
# directory (i.e. on whether the `%cd` cell was run first).
PROJECT_ROOT = Path("/content/medgemma_finetune")

# Instruction files written by make_instructions.py (kept as plain strings).
TRAIN_JSONL = str(PROJECT_ROOT / "train_instructions.jsonl")
VAL_JSONL   = str(PROJECT_ROOT / "val_instructions.jsonl")

# Hugging Face model repo to fine-tune; change this to use another checkpoint.
MODEL_ID = "unsloth/medgemma-4b-it"

# Conservative training hyperparameters.
NUM_EPOCHS = 1
LEARNING_RATE = 1e-5       # 5e-6 is an even more cautious choice
BATCH_SIZE = 4             # adjust by VRAM
GRAD_ACCUM = 2             # effective batch = BATCH_SIZE * GRAD_ACCUM
# NOTE(review): the SFTConfig cell hard-codes a per-device batch size of 1 and
# gradient accumulation of 4, so BATCH_SIZE and GRAD_ACCUM are not read there.




In [20]:
# Read both instruction files with the generic JSON loader; the result is a
# DatasetDict with a "train" and a "validation" split.
data = load_dataset(
    "json",
    data_files=dict(train=str(TRAIN_JSONL), validation=str(VAL_JSONL)),
)
print(data)

DatasetDict({
    train: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 765
    })
    validation: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 135
    })
})


In [21]:
# =========================
# 3. FORMAT DATA: image + messages
# =========================

def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """
    Attach the decoded image and a chat-style conversation to one dataset row.

    Expects each JSONL row to have:
      - image_path: path to the image; relative paths are resolved against
        PROJECT_ROOT
      - prompt: user prompt text
      - target: JSON string with {label_key, confidence, summary}
    Adds:
      - image: PIL image (RGB)
      - messages: chat-style messages for MedGemma (one user turn holding an
        image placeholder plus the prompt, one assistant turn holding the target)

    The row is updated in place and returned.
    """
    img_path = Path(example["image_path"])
    if not img_path.is_absolute():
        img_path = PROJECT_ROOT / img_path

    # convert() returns a new, fully loaded image, so the underlying file can
    # be closed as soon as the conversion is done instead of staying open
    # until garbage collection.
    with Image.open(img_path) as raw_image:
        example["image"] = raw_image.convert("RGB")

    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": example["prompt"]},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": example["target"]},
            ],
        },
    ]
    return example


# Add the "image" and "messages" columns to every row of both splits, then
# show the first training row as a sanity check.
data = data.map(function=format_data)
print(data["train"][0])

{'image_path': '/content/medgemma_finetune/data/Brain Tumor MRI images/Tumor/glioma (431).jpg', 'dataset': 'Brain Tumor MRI images', 'class_name': 'Tumor', 'prompt': 'You are an assistant radiologist. Modality: MRI brain. Task: Determine whether the MRI shows a brain tumor and summarize key findings. Analyze the given medical image and respond ONLY with valid JSON.\n\nThe JSON must contain the keys: "brain_mri_finding", "confidence", "summary".\n', 'target': '{"brain_mri_finding": "Tumor", "confidence": 0.93, "summary": "Abnormal mass lesion with altered signal intensity, consistent with a brain tumor. Recommend correlation with clinical history and further evaluation as needed."}', 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7CBD21BC36B0>, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'You are an assistant radiologist. Modality: MRI brain. Task: Determine whether the MRI shows a brain tumor and summarize key findings. Analyze the gi

In [13]:
# Authenticate with the Hugging Face Hub; login() is called without
# arguments and, per the widget output below, prompts interactively.
# NOTE(review): presumably required so the next cell can download MODEL_ID —
# confirm whether that repository actually needs authentication.
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
# =========================
# 4. LOAD 4-BIT MEDGEMMA
# =========================

# A CUDA device is mandatory; stop right away when none is visible.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA GPU not found. You need a GPU for MedGemma fine-tuning.")

# Compute capability below 8.0 only triggers a warning; loading continues in
# float16 regardless of the GPU generation.
if torch.cuda.get_device_capability() < (8, 0):
    print("[WARN] GPU may not support bfloat16 well; official notebook requires compute 8.0+ (A100, H100, etc).")

# 4-bit NF4 quantization with double quantization; compute and storage dtypes
# are float16 to match the dtype the model is loaded with below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_storage=torch.float16,
)

# Quantized base model; device placement is left to device_map="auto".
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
)

# Matching processor; padding is applied on the right-hand side.
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"

[WARN] GPU may not support bfloat16 well; official notebook requires compute 8.0+ (A100, H100, etc).


config.json: 0.00B [00:00, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

In [None]:
# =========================
# 5. LoRA CONFIG (SAFE MODE)
# =========================

# Low-rank adapters on every linear layer. Rank 8 keeps the adapter light;
# the LM head and token embeddings are listed in modules_to_save in addition
# to the adapter weights.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
# =========================
# 6. COLLATE FN (FROM GOOGLE NOTEBOOK)
# =========================

def collate_fn(examples: list[dict[str, Any]]):
    """
    Turn a list of formatted rows into one padded training batch.

    Each example must provide:
      - "image": an image object supporting ``.convert("RGB")``
      - "messages": the chat messages rendered via the processor's chat template

    Returns the processor output for the batch with an extra "labels" entry:
    a copy of ``input_ids`` in which padding tokens, the begin-of-image token
    and token id 262144 are replaced by -100 so they are excluded from the loss.
    No other tokens are masked, so the user-prompt tokens still contribute to
    the loss.
    """
    texts = []
    images = []

    for example in examples:
        # One image per conversation, wrapped in its own list.
        images.append([example["image"].convert("RGB")])
        texts.append(
            processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=False,
            ).strip()
        )

    # Tokenize text + process images
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    # Labels are input_ids with some tokens masked out
    labels = batch["input_ids"].clone()

    # Scalar id of the begin-of-image token. It must be compared as a scalar:
    # comparing the tensor against a Python list does not broadcast, so the
    # previous `labels == [boi_id]` never selected any position.
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )

    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    # Hard-coded id carried over from the reference notebook; presumably the
    # image placeholder token — TODO confirm against the processor.
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch

In [None]:
# =========================
# 7. TRAINING ARGUMENTS
# =========================
from trl import SFTConfig

args = SFTConfig(
    output_dir="medgemma-4b-it-lora-medmulti",

    # Schedule: read from the config cell so the tunables declared there are
    # actually honoured instead of being duplicated as literals.
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="linear",

    # Memory: batch size is deliberately hard-coded to 1 (dropped from 4) with
    # 4 accumulation steps, i.e. an effective batch of 1 * 4 = 4. This
    # intentionally does not use BATCH_SIZE / GRAD_ACCUM from the config cell.
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    optim="adamw_torch_fused",
    logging_steps=10,

    # Checkpoint once per epoch; evaluation during training is switched off
    # to save VRAM.
    save_strategy="epoch",
    eval_strategy="no",

    # Mixed precision: float16 on T4 (bf16 disabled), matching the dtype the
    # model is loaded with.
    bf16=False,
    fp16=True,

    push_to_hub=False,
    report_to="none",

    # The custom collator consumes the raw "image"/"messages" columns, so the
    # trainer must neither pre-process the dataset nor drop those columns.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],

    # Intended to cap sequence length so long prompts do not blow up memory.
    # NOTE(review): with skip_prepare_dataset=True and a custom collator the
    # trainer may not apply this cap, and newer TRL releases may use a
    # different name for this option — confirm against the installed version.
    max_seq_length=512,
)


In [None]:
# =========================
# 8. TRAINER & TRAIN
# =========================

# The full validation split is handed to the trainer. To speed things up it
# could be subsampled, e.g. data["validation"].shuffle().select(range(50)).
eval_dataset = data["validation"]

# Wire model, LoRA config, processor and the custom collator together.
trainer = SFTTrainer(
    model=model,
    processing_class=processor,
    peft_config=peft_config,
    args=args,
    data_collator=collate_fn,
    train_dataset=data["train"],
    eval_dataset=eval_dataset,
)

print("[INFO] Starting training...")
trainer.train()

In [None]:
# Save the final adapter through the trainer. No path is passed here, so the
# trainer chooses the location itself.
# NOTE(review): the processor is not saved by this cell — confirm whether it
# needs to be stored next to the adapter for later inference.
trainer.save_model()  # presumably writes the LoRA weights into args.output_dir
# Report the configured output directory to the user.
print("[OK] Training complete. LoRA adapter saved in:", args.output_dir)