In [2]:
import os
# Fixes potential memory fragmentation issues
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Optional: Add this to potentially reduce overall memory usage if needed
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [3]:
# =========================
# 0. INSTALL (run once)
# =========================
!pip install "transformers>=4.41.0" "datasets>=2.19.0" "accelerate>=0.30.0" \
              peft trl bitsandbytes pillow

# If you use Unsloth:
!pip install unsloth



Collecting trl
  Downloading trl-0.25.1-py3-none-any.whl.metadata (11 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading trl-0.25.1-py3-none-any.whl (465 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m465.5/465.5 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl (59.4 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.4/59.4 MB[0m [31m14.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes, trl
Successfully installed bitsandbytes-0.48.2 trl-0.25.1
Collecting unsloth
  Downloading unsloth-2025.11.6-py3-none-any.whl.metadata (64 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚

In [5]:
from google.colab import drive
drive.mount('/content/drive')

# My Drive/medgemma_finetune/data
!cp -r "/content/drive/MyDrive/data" "/content/medgemma_finetune/"

Mounted at /content/drive


In [6]:
%cd /content/medgemma_finetune
!python make_instructions.py


/content/medgemma_finetune
[INFO] Dataset: kidney CT
  - Class: Cyst
    -> Used 100 images
  - Class: Normal
    -> Used 100 images
  - Class: Stone
    -> Used 100 images
  - Class: Tumor
    -> Used 100 images
[INFO] Dataset: Breast MRI
  - Class: Malignant
    -> Used 100 images
  - Class: Benign
    -> Used 100 images
[INFO] Dataset: Brain Tumor MRI images
  - Class: Healthy
    -> Used 100 images
  - Class: Tumor
    -> Used 100 images
[INFO] Dataset: mammography
  - Class: Malignant
    -> Used 100 images
  - Class: Benign
    -> Used 100 images
[INFO] Dataset: Brain Tumor CT scan Images
  - Class: Healthy
    -> Used 100 images
  - Class: Tumor
    -> Used 100 images
[INFO] Dataset: lung cancer
  - Class: Bengin cases
    -> Used 100 images
  - Class: Normal cases
    -> Used 100 images
  - Class: Malignant cases
    -> Used 100 images
[INFO] Collected 1500 examples total.
[INFO] Train: 1275 | Val: 225
[OK] Wrote train_instructions.jsonl and val_instructions.jsonl in /content/m

In [7]:
import os
from pathlib import Path
from typing import Any

import torch
from datasets import load_dataset
from PIL import Image

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer

In [8]:
PROJECT_ROOT = Path(".")  # current folder
TRAIN_JSONL = "/content/medgemma_finetune/train_instructions.jsonl"
VAL_JSONL   = "/content/medgemma_finetune/val_instructions.jsonl"


MODEL_ID = "unsloth/medgemma-4b-it"   # <-- CHANGE THIS to your Unsloth model id if needed

# Tiny ‚Äúsafer than your GPA‚Äù training hyperparams
NUM_EPOCHS = 1
LEARNING_RATE = 1e-5       # you can go 5e-6 if you want to be extra safe
BATCH_SIZE = 4             # adjust by VRAM
GRAD_ACCUM = 2             # effective batch = BATCH_SIZE * GRAD_ACCUM




In [9]:
data = load_dataset(
    "json",
    data_files={
        "train": str(TRAIN_JSONL),
        "validation": str(VAL_JSONL),
    },
)
print(data)

Generating train split: 0 examples [00:00, ? examples/s]

Generating validation split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 1275
    })
    validation: Dataset({
        features: ['image_path', 'dataset', 'class_name', 'prompt', 'target'],
        num_rows: 225
    })
})


In [10]:
# =========================
# 3. FORMAT DATA: image + messages
# =========================

def format_data(example: dict[str, Any]) -> dict[str, Any]:
    """
    Expects each JSONL row to have:
      - image_path: path to image
      - prompt: user prompt text
      - target: JSON string with {label_key, confidence, summary}
    Adds:
      - image: PIL image (RGB)
      - messages: chat-style messages for MedGemma
    """
    img_path = Path(example["image_path"])
    if not img_path.is_absolute():
        img_path = PROJECT_ROOT / img_path

    image = Image.open(img_path).convert("RGB")
    example["image"] = image

    example["messages"] = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": example["prompt"]},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": example["target"]},
            ],
        },
    ]
    return example


data = data.map(format_data)
print(data["train"][0])

Map:   0%|          | 0/1275 [00:00<?, ? examples/s]

Map:   0%|          | 0/225 [00:00<?, ? examples/s]

{'image_path': '/content/medgemma_finetune/data/Breast MRI/Malignant/BREASTDX-01-0068_12666.jpg', 'dataset': 'Breast MRI', 'class_name': 'Malignant', 'prompt': 'You are an assistant radiologist. Modality: MRI breast. Task: Assess for breast malignancy and summarize key findings. Analyze the given medical image and respond ONLY with valid JSON.\n\nThe JSON must contain the keys: "breast_mri_finding", "confidence", "summary".\n', 'target': '{"breast_mri_finding": "Malignant", "confidence": 0.94, "summary": "Suspicious enhancing lesion highly suggestive of malignant breast tumor. These findings are not a definitive diagnosis and should be confirmed by a doctor."}', 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=224x224 at 0x7EB7CF9BC8C0>, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'You are an assistant radiologist. Modality: MRI breast. Task: Assess for breast malignancy and summarize key findings. Analyze the given medical image and respond ONLY wi

In [14]:
from huggingface_hub import login
login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

In [15]:
from google.colab import userdata
access_token = userdata.get('HF_TOKEN')

In [16]:
# =======================================================
# STEP 1: INSPECT MODEL STRUCTURE
# =======================================================
import torch.nn as nn

print("\n--- Model Submodule Names ---")
# List all top-level module names to find the Vision Encoder
for name, module in model.named_children():
    print(f"Top-level Module: {name} (Type: {type(module).__name__})")

print("\n--- Vision Encoder Check ---")
# If 'vision_tower' exists, show its structure
if hasattr(model, 'vision_tower'):
    print("\n'vision_tower' structure:")
    for name, param in model.vision_tower.named_parameters():
        print(f"  {name}: {param.requires_grad}")
        # Only print the first 5 params to avoid huge output
        if name.split('.')[0] == name.split('.')[-1] and len(list(model.vision_tower.named_parameters())) > 5:
            break


--- Model Submodule Names ---


NameError: name 'model' is not defined

In [17]:
import torch
import gc
# ‚ö†Ô∏è Ensure this is one of the first imports in your entire notebook!
from unsloth import FastLanguageModel
from peft import LoraConfig

# =======================================================
# 1. MODEL LOADING
# =======================================================
MODEL_ID = "unsloth/medgemma-4b-it"
# access_token is assumed to be defined previously

model, processor = FastLanguageModel.from_pretrained(
    model_name=MODEL_ID,
    max_seq_length=512,
    dtype=None,
    load_in_4bit=True,
    token=access_token,
)

model.resize_token_embeddings(len(processor.tokenizer))
torch.cuda.empty_cache()
gc.collect()

# =======================================================
# 2. MANUAL FREEZE (Safety Check)
# =======================================================
print("Manually freezing the Vision Encoder...")
if hasattr(model, 'vision_tower'):
    for param in model.vision_tower.parameters():
        param.requires_grad = False
    print("‚úÖ Vision Encoder successfully frozen.")
else:
    print("‚ö†Ô∏è Could not find 'vision_tower' to freeze.")

# =======================================================
# 3. LoRA CONFIGURATION (Type Fix for 'r')
# =======================================================
# Define the LLM target layers
target_linear_layers = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# *** DEFENSIVE TYPE CAST FOR THE RANK (r) PARAMETER ***
LORA_RANK = 4

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=int(LORA_RANK), # <--- **FIX: Explicitly cast 'r' to a Python integer**
    bias="none",
    target_modules=target_linear_layers,
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],
)

# =======================================================
# 4. APPLY PEFT MODEL (Unsloth Optimized)
# =======================================================
# *** CRUCIAL: Use the Unsloth wrapper for correct multimodal handling ***
# *** CORRECTED CODE: Unpack the peft_config into keyword arguments ***
model = FastLanguageModel.get_peft_model(model, **peft_config.to_dict())

# 5. Check trainable parameters - MUST be small
model.print_trainable_parameters()


Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


ü¶• Unsloth: Will patch your computer to enable 2x faster free finetuning.
ü¶• Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.6: Fast Gemma3 patching. Transformers: 4.57.2.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 7.5. CUDA Toolkit: 12.6. Triton: 3.5.1
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.
Unsloth: Gemma3 does not support SDPA - switching to fast eager.


model.safetensors:   0%|          | 0.00/4.12G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/210 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.json: 0.00B [00:00, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/670 [00:00<?, ?B/s]

Manually freezing the Vision Encoder...
‚úÖ Vision Encoder successfully frozen.


AssertionError: 

In [None]:
# =========================
# 5. LoRA CONFIG (OPTIMIZED FOR T4)
# =========================

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.05,
    r=8,
    bias="none",
    # OPTIMIZATION: Do not use "all-linear" on free Colab.
    # Target only the Language Model layers to save huge VRAM.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
# =========================
# 6. COLLATE FN (WITH MAX LENGTH FIX)
# =========================

def collate_fn(examples: list[dict[str, Any]]):
    texts = []
    images = []

    for example in examples:
        images.append([example["image"].convert("RGB")])
        texts.append(
            processor.apply_chat_template(
                example["messages"],
                add_generation_prompt=False,
                tokenize=False,
            ).strip()
        )

    # CRITICAL VRAM FIX: Add truncation and max_length here
    batch = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        # Max sequence length limits VRAM allocation for text

    )

    # Labels are input_ids with some tokens masked out
    labels = batch["input_ids"].clone()

    # Mask image tokens
    boi_id = processor.tokenizer.convert_tokens_to_ids("<_boi>")
    eoi_id = processor.tokenizer.convert_tokens_to_ids("<_eoi>")

    # The labels of image tokens need to be set to -100
    # The boi_token is the first token of the input, and can be at input_ids[:, 0]
    labels[:, 0] = -100

    # The eoi_token is the second token of the input, and can be at input_ids[:, 1]
    labels[:, 1] = -100

    # For non-image/non-label tokens, set labels to -100
    labels[labels == processor.tokenizer.pad_token_id] = -100

    batch["labels"] = labels
    return batch

In [None]:
from trl import SFTConfig

args = SFTConfig(
    output_dir="medgemma-4b-it-lora-medmulti",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,

    # CRITICAL VRAM FIX: Use Paged Optimizer to offload to CPU RAM
    optim="paged_adamw_8bit",

    logging_steps=10,
    save_strategy="epoch",
    eval_strategy="no",

    learning_rate=1e-5,
    max_grad_norm=0.3,
    warmup_ratio=0.03,
    lr_scheduler_type="linear",

    bf16=False,
    fp16=True,

    push_to_hub=False,
    report_to="none",
    gradient_checkpointing_kwargs={"use_reentrant": False},
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

In [None]:
# =========================
# 8. TRAINER & TRAIN
# =========================

# You can optionally subsample validation to speed up:
eval_dataset = data["validation"]  # .shuffle().select(range(200))

trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=data["train"],
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

print("[INFO] Starting training...")
trainer.train()



[INFO] Starting training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 765 | Num Epochs = 1 | Total steps = 96
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 1,358,576,640 of 5,658,494,832 (24.01% trained)


BackendCompilerFailed: backend='inductor' raised:
OutOfMemoryError: CUDA out of memory. Tried to allocate 1.25 GiB. GPU 0 has a total capacity of 14.74 GiB of which 10.12 MiB is free. Process 380905 has 14.73 GiB memory in use. Of the allocated memory 14.52 GiB is allocated by PyTorch, and 59.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"


In [None]:
# Save final adapter
trainer.save_model()  # saves LoRA weights into output_dir
print("[OK] Training complete. LoRA adapter saved in:", args.output_dir)