In [1]:
import os
# Fixes potential memory fragmentation issues
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Optional: Add this to potentially reduce overall memory usage if needed
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [2]:
# =========================
# 0. INSTALL (run once)
# =========================
!pip install "transformers>=4.41.0" "datasets>=2.19.0" "accelerate>=0.30.0" \
              peft trl bitsandbytes pillow

# If you use Unsloth:
!pip install unsloth



Collecting trl
  Downloading trl-0.25.1-py3-none-any.whl.metadata (11 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading trl-0.25.1-py3-none-any.whl (465 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m465.5/465.5 kB[0m [31m22.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.4/59.4 MB[0m [31m13.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes, trl
Successfully installed bitsandbytes-0.48.2 trl-0.25.1
Collecting unsloth
  Downloading unsloth-2025.12.1-py3-none-any.whl.metadata (65 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚

In [4]:
from google.colab import drive
drive.mount('/content/drive')

# My Drive/medgemma_finetune/data
!cp -r "/content/drive/MyDrive/data" "/content/medgemma_finetune/"

Mounted at /content/drive


In [5]:
%cd /content/medgemma_finetune
!python make_instructions.py


/content/medgemma_finetune
[INFO] Dataset: Brain Tumor CT scan Images
  - Class: Healthy
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: mammography
  - Class: Malignant
    -> Used 60 images
  - Class: Benign
    -> Used 60 images
[INFO] Dataset: kidney CT
  - Class: Stone
    -> Used 60 images
  - Class: Normal
    -> Used 60 images
  - Class: Cyst
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: lung cancer
  - Class: Normal cases
    -> Used 60 images
  - Class: Bengin cases
    -> Used 60 images
  - Class: Malignant cases
    -> Used 60 images
[INFO] Dataset: Brain Tumor MRI images
  - Class: Healthy
    -> Used 60 images
  - Class: Tumor
    -> Used 60 images
[INFO] Dataset: Breast MRI
  - Class: Malignant
    -> Used 60 images
  - Class: Benign
    -> Used 60 images
[INFO] Collected 900 examples total.
[INFO] Train: 765 | Val: 135
[OK] Wrote train_instructions.jsonl and val_instructions.jsonl in /content/medgemma_finetune


In [6]:
import os
from pathlib import Path
from typing import Any

import torch
from datasets import load_dataset
from PIL import Image

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer



In [7]:
PROJECT_ROOT = Path(".")  # current folder
TRAIN_JSONL = "/content/medgemma_finetune/train_instructions.jsonl"
VAL_JSONL   = "/content/medgemma_finetune/val_instructions.jsonl"


MODEL_ID = "unsloth/medgemma-4b-it"   # <-- CHANGE THIS to your Unsloth model id if needed

# Tiny ‚Äúsafer than your GPA‚Äù training hyperparams
NUM_EPOCHS = 1
LEARNING_RATE = 1e-5       # you can go 5e-6 if you want to be extra safe
BATCH_SIZE = 4             # adjust by VRAM
GRAD_ACCUM = 2             # effective batch = BATCH_SIZE * GRAD_ACCUM




In [8]:
data = load_dataset(
    "json",
    data_files={
        "train": str(TRAIN_JSONL),
        "validation": str(VAL_JSONL),
    },
)
print(data)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 765
    })
    validation: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 135
    })
})


In [9]:
# =========================
# 3. FORMAT DATA: image + messages
# =========================

def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """
    Expects each JSONL row to have:
      - image_path: path to image
      - prompt: user prompt text
      - target: JSON string with {label_key, confidence, summary}
    Adds:
      - image: PIL image (RGB)
      - messages: chat-style messages for MedGemma
    """
    img_path = Path(example["image_path"])
    if not img_path.is_absolute():
        img_path = PROJECT_ROOT / img_path

    image = Image.open(img_path).convert("RGB")
    example["image"] = image

    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": example["prompt"]},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": example["target"]},
            ],
        },
    ]
    return example


data = data.map(format_data)
print(data["train"][0])

Map:   0%|          | 0/765 [00:00<?, ? examples/s]

Map:   0%|          | 0/135 [00:00<?, ? examples/s]

{'image_path': '/content/medgemma_finetune/data/mammography/Benign/36_200779059_png.rf.3bc2f5bf2ef78d6f98d6ad7b0993bedd.jpg', 'dataset': 'mammography', 'class_name': 'Benign', 'prompt': 'You are an assistant radiologist. Modality: Mammogram. Task: Assess for malignant or benign breast changes on mammography. Analyze the given medical image and respond ONLY with valid JSON.\n\nThe JSON must contain the keys: "mammogram_finding", "confidence", "summary".\n', 'target': '{"mammogram_finding": "Benign", "confidence": 0.72, "summary": "No highly suspicious mammographic abnormality; features favor benign etiology. These findings are not a definitive diagnosis and should be confirmed by a doctor."}', 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=640x640 at 0x7AB744C0D310>, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'You are an assistant radiologist. Modality: Mammogram. Task: Assess for malignant or benign breast changes on mammography. Analyze the give

In [10]:
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

In [11]:
from google.colab import userdata
access_token = userdata.get('HF_TOKEN')

In [12]:
# =======================================================
# STEP 1: INSPECT MODEL STRUCTURE
# =======================================================
import torch.nn as nn

print("\n--- Model Submodule Names ---")
# List all top-level module names to find the Vision Encoder
for name, module in model.named_children():
    print(f"Top-level Module: {name} (Type: {type(module).__name__})")

print("\n--- Vision Encoder Check ---")
# If 'vision_tower' exists, show its structure
if hasattr(model, 'vision_tower'):
    print("\n'vision_tower' structure:")
    for name, param in model.vision_tower.named_parameters():
        print(f"  {name}: {param.requires_grad}")
        # Only print the first 5 params to avoid huge output
        if name.split('.')[0] == name.split('.')[-1] and len(list(model.vision_tower.named_parameters())) > 5:
            break


--- Model Submodule Names ---


NameError: name 'model' is not defined

In [13]:
import torch
import gc
# ‚ö†Ô∏è Ensure this is one of the first imports in your entire notebook!
from unsloth import FastLanguageModel
from peft import LoraConfig

# =======================================================
# 1. MODEL LOADING
# =======================================================
model_id = "google/medgemma-4b-it"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto"
)

processor = AutoProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.
ü¶• Unsloth Zoo will now patch everything to make training faster!


config.json:   0%|          | 0.00/2.47k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

In [15]:
from peft import get_peft_model

In [16]:
# === Step 5: LoRA configuration ===
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    target_modules="all-linear",
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"]
)
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

In [17]:
model = get_peft_model(model, peft_config)

In [18]:
model.print_trainable_parameters()

trainable params: 16,394,240 || all params: 4,316,473,712 || trainable%: 0.3798


In [19]:
# =========================
# 6. COLLATE FN (WITH MAX LENGTH FIX)
# =========================

def collate_fn(examples: list[dict[str, Any]]):
    texts = []
    images = []

    for example in examples:
        images.append([example["image"].convert("RGB")])
        texts.append(
            processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=False,
            ).strip()
        )

    # CRITICAL VRAM FIX: Add truncation and max_length here
    batch = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        # Max sequence length limits VRAM allocation for text

    )

    # Labels are input_ids with some tokens masked out
    labels = batch["input_ids"].clone()

    # Mask image tokens
    boi_id = processor.tokenizer.convert_tokens_to_ids("<_boi>")
    eoi_id = processor.tokenizer.convert_tokens_to_ids("<_eoi>")

    # The labels of image tokens need to be set to -100
    # The boi_token is the first token of the input, and can be at input_ids[:, 0]
    labels[:, 0] = -100

    # The eoi_token is the second token of the input, and can be at input_ids[:, 1]
    labels[:, 1] = -100

    # For non-image/non-label tokens, set labels to -100
    labels[labels == processor.tokenizer.pad_token_id] = -100

    batch["labels"] = labels
    return batch

In [20]:
output_dir = "medgemma-qlora-finetune"
os.makedirs(output_dir, exist_ok=True)

In [23]:
from trl import SFTConfig

args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=1,        # optimized for low VRAM
    gradient_accumulation_steps=8,
    gradient_checkpointing=True,
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    bf16=False,
    fp16=True,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="linear",
    save_strategy="epoch",
    push_to_hub=True,
    logging_steps=0.1,
    eval_strategy="no",  # train-only dataset
    report_to="none",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
    dataloader_num_workers=8,
    dataloader_pin_memory=True
)

In [None]:
# =========================
# 8. TRAINER & TRAIN
# =========================

# You can optionally subsample validation to speed up:
eval_dataset = data["validation"]  # .shuffle().select(range(200))

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    eval_dataset=eval_dataset,
    processing_class=processor,
    data_collator=collate_fn,
)

print("[INFO] Starting training...")
trainer.train()

[INFO] Starting training...


Step,Training Loss
10,
20,
30,
40,
50,


In [None]:
print("[INFO] Starting training...")
trainer.train()

In [None]:
# Save final adapter
trainer.save_model()  # saves LoRA weights into output_dir
print("[OK] Training complete. LoRA adapter saved in:", args.output_dir)

In [None]:
save_path = "medgemma-4b-finetuned123"

In [None]:
trainer.save_model(save_path, safe_serialization=True)
print(f"[OK] Complete model saved to: {save_path}")