In [None]:
from datasets import load_dataset
import os
import torch

In [None]:
# Report whether a CUDA device is usable. The device name is only queried
# when CUDA is present, because torch.cuda.get_device_name(0) raises on a
# machine without a CUDA device.
cuda_available = torch.cuda.is_available()
print(cuda_available)
if cuda_available:
    print(torch.cuda.get_device_name(0))
else:
    print("No CUDA device detected")

In [None]:
# Root folder that holds the image dataset
data_dir = "Brain-Tumor"

# Build the dataset with the generic "imagefolder" loader pointed at data_dir
dataset = load_dataset("imagefolder", data_dir=data_dir)

# Show the loaded dataset object as a quick sanity check
print(dataset)

In [None]:
# Class names as reported by the dataset's "label" feature
detected_classes = dataset["train"].features["label"].names
print("Detected classes:", detected_classes)

# Human-readable labels used in the prompt and as the target answers.
# They are indexed by the dataset's integer label, so they must line up
# position-for-position with the detected class names printed above.
BRAIN_TUMOR_CLASSES = [
    "glioma",
    "meningioma",
    "no tumour",
    "pituitary",
]

# Fail fast when the hand-written labels cannot cover the dataset's classes;
# otherwise the lookup by label index would mislabel examples or raise later.
if len(detected_classes) != len(BRAIN_TUMOR_CLASSES):
    raise ValueError(
        f"Dataset reports {len(detected_classes)} classes {detected_classes}, "
        f"but {len(BRAIN_TUMOR_CLASSES)} labels are configured: {BRAIN_TUMOR_CLASSES}"
    )

# One answer option per line, appended to the prompt
options = "\n".join(BRAIN_TUMOR_CLASSES)

In [None]:
# Question put to the model, followed by the answer options (one per line)
PROMPT = f"What is the most likely type of brain tumor shown in the MRI image?\n{options}"


def format_data(example: dict) -> dict:
    """Attach a two-turn chat transcript to a dataset example.

    The user turn carries an image placeholder plus the prompt text; the
    assistant turn carries the class name looked up from the example's label.

    Args:
        example: Record with a "label" key used to index BRAIN_TUMOR_CLASSES.

    Returns:
        The same dict object, with a "messages" key added.
    """
    answer = BRAIN_TUMOR_CLASSES[example["label"]]

    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": PROMPT},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [
            {"type": "text", "text": answer},
        ],
    }

    example["messages"] = [user_turn, assistant_turn]
    return example

# Add the chat-style "messages" field to every example in every split
formatted_dataset = dataset.map(format_data)

# Inspect the transcript generated for the first training example
first_example = formatted_dataset["train"][0]
print(first_example["messages"])

In [None]:
# Model and processor setup
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/medgemma-4b-it"

# The model is loaded in bfloat16; refuse to continue on GPUs whose major
# compute capability is below 8.
major_capability = torch.cuda.get_device_capability()[0]
if major_capability < 8:
    raise ValueError("GPU does not support bfloat16, please use a GPU that supports bfloat16.")

# Keyword arguments forwarded to from_pretrained
model_kwargs = {
    "attn_implementation": "eager",
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}

model = AutoModelForImageTextToText.from_pretrained(model_id, **model_kwargs)

processor = AutoProcessor.from_pretrained(model_id)
# Pad on the right-hand side of each sequence
processor.tokenizer.padding_side = "right"

In [None]:
# LoRA adapter configuration
from peft import LoraConfig

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    # Adapter placement: every linear layer gets an adapter
    target_modules="all-linear",
    # Adapter hyper-parameters
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    # Modules listed here are saved alongside the adapter weights
    modules_to_save=["lm_head", "embed_tokens"],
)

# Data collator
def collate_fn(examples):
    """Collate formatted examples into a padded batch with loss labels.

    Each example's chat messages are rendered to text with the processor's
    chat template and paired with its single image. The labels are a copy of
    the input ids in which padding, begin-of-image and image placeholder
    tokens are replaced by -100 so they are masked out of the loss.

    Args:
        examples: Records that provide "image" and "messages" keys.

    Returns:
        The processor output for the batch, with a "labels" entry added.
    """
    texts, images = [], []
    for example in examples:
        # One image per conversation, wrapped in its own list
        images.append([example["image"]])
        texts.append(
            processor.apply_chat_template(
                example["messages"], add_generation_prompt=False, tokenize=False
            ).strip()
        )
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    labels = batch["input_ids"].clone()

    # Keep the id as a plain scalar: comparing a tensor against a Python list
    # does not yield an element-wise mask, which would leave this token unmasked.
    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )

    # Mask special tokens for loss
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == boi_token_id] = -100
    # NOTE(review): 262144 is presumably the image placeholder token id of
    # this tokenizer — confirm against the processor before changing models.
    labels[labels == 262144] = -100

    batch["labels"] = labels
    return batch

In [None]:
# Turn off Weights & Biases via its environment switch before trl is imported
os.environ["WANDB_DISABLED"] = "true"

# Training configuration
from trl import SFTConfig

args = SFTConfig(
    # Output and checkpointing
    output_dir="medgemma-brain-tumor",
    save_strategy="epoch",
    push_to_hub=False,
    # Schedule and batch sizes (1 sample x 32 accumulation steps per update)
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=32,
    # Optimiser and learning-rate schedule
    optim="adamw_torch_fused",
    learning_rate=2e-4,
    lr_scheduler_type="linear",
    warmup_ratio=0.03,
    max_grad_norm=0.3,
    # Precision and memory
    bf16=True,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    # Evaluation cadence
    eval_strategy="steps",
    eval_steps=10,
    # Logging and progress reporting
    logging_steps=1,
    logging_first_step=True,
    report_to="none",
    disable_tqdm=False,
    # Dataset handling: batches are built by the custom collator, so the
    # trainer is told to skip its own dataset preparation and keep all columns
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)

# Trainer
from trl import SFTTrainer

# Wire the model, training arguments, data splits, LoRA config and the
# custom collator into a supervised fine-tuning trainer.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=formatted_dataset["train"],
    eval_dataset=formatted_dataset["validation"],
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Enable gradient checkpointing with the same kwargs as the training
# arguments, so this explicit call agrees with the use_reentrant=False
# setting requested there instead of using the library defaults.
model.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs
)

In [None]:
# Start fine-tuning using the trainer built in the previous cell
trainer.train()