In [None]:
# Fine-tuning MedGemma-4B for Breast Cancer Histopathology Classification
# This notebook demonstrates how to fine-tune Google's MedGemma vision-language model
# on the BreakHis breast cancer dataset using LoRA (Low-Rank Adaptation)

# ============================================================================
# 0. SETUP AND INSTALLATIONS
# ============================================================================

# Install required packages
!pip install --upgrade --quiet transformers bitsandbytes datasets evaluate peft trl scikit-learn 
# Then, reinstall it, forcing a build from source
# This will take a few minutes as it compiles the code

import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate

# Hugging Face authentication
# The access token is read from the HF_TOKEN environment variable instead of
# being hard-coded, so it is never saved or committed together with the notebook.
# When the variable is not set, HF_TOKEN is None and `login` falls back to
# prompting for the token interactively.
from huggingface_hub import login
HF_TOKEN = os.environ.get("HF_TOKEN")
login(token=HF_TOKEN)


In [13]:

# ============================================================================
# 1. LOAD AND PREPARE DATA
# ============================================================================
print("=" * 80, "STEP 1: Loading and Preparing Dataset", "=" * 80, sep="\n")

# Dataset configuration
DATASET_NAME = "sarath2003/BreakHis"
TRAIN_SIZE = 500   # rows used for fine-tuning
EVAL_SIZE = 1000   # rows held out for evaluation

# Load dataset
# BreakHis: histopathological images of breast tumours with 8 classes
# (4 benign, 4 malignant); reportedly captured at 200X magnification.
# The single "train" split is shuffled once with a fixed seed, then two
# disjoint index ranges are carved out of it:
#   [0, TRAIN_SIZE)                      -> training
#   [TRAIN_SIZE, TRAIN_SIZE + EVAL_SIZE) -> evaluation
dataset = load_dataset(DATASET_NAME, split="train").shuffle(seed=42)
train_data = dataset.select(range(0, TRAIN_SIZE))
eval_data = dataset.select(range(TRAIN_SIZE, TRAIN_SIZE + EVAL_SIZE))

print(f"Training samples: {len(train_data)}")
print(f"Evaluation samples: {len(eval_data)}")

# Class names are read from the dataset's "label" feature; the position of a
# name in this list is its integer label.
CANCER_CLASSES = train_data.features["label"].names
print(f"\nClasses: {CANCER_CLASSES}")

STEP 1: Loading and Preparing Dataset
Training samples: 500
Evaluation samples: 1000

Classes: ['benign_adenosis', 'benign_fibroadenoma', 'benign_phyllodes_tumor', 'benign_tubular_adenoma', 'malignant_ductal_carcinoma', 'malignant_lobular_carcinoma', 'malignant_mucinous_carcinoma', 'malignant_papillary_carcinoma']


In [14]:

# ============================================================================
# 2. CREATE PROMPT AND FORMAT DATA
# ============================================================================
print("\n" + "="*80)
print("STEP 2: Formatting Data with Prompts")
print("="*80)


def build_prompt(class_names):
    """
    Build the classification instruction shown to the model.

    The numbered class list is generated from `class_names`, so the number
    printed next to each class is always its index in that list. This keeps the
    prompt aligned with the dataset's integer labels (critical for training)
    instead of relying on a hand-maintained copy of the class list.

    Args:
        class_names: Sequence of class-name strings, ordered by label index.

    Returns:
        The prompt string: task description, one "<index>: <name>" line per
        class, and a request to answer with only the number.
    """
    last_index = len(class_names) - 1
    class_lines = "\n".join(f"{idx}: {name}" for idx, name in enumerate(class_names))
    return (
        "Analyze this breast tissue histopathology image and classify it.\n\n"
        f"Classes (0-{last_index}):\n"
        f"{class_lines}\n\n"
        f"Answer with only the number (0-{last_index}):"
    )


# Define the instruction prompt
# WHY THIS PROMPT:
# - Clear, concise task description
# - Numbering is derived from CANCER_CLASSES, so it matches the label indices
# - Asks for number only to simplify output parsing
# NOTE: postprocess_prediction only parses single digits 0-7, so the pipeline
# as a whole still assumes at most 8 classes.
PROMPT = build_prompt(CANCER_CLASSES)



def format_data(example):
    """
    Attach a two-turn chat transcript to a dataset row.

    The row gains a "messages" entry holding:
      1. a user turn whose content is an image placeholder followed by the
         instruction text in PROMPT (image first, then text), and
      2. an assistant turn whose only content is the row's label rendered as
         a string.

    The row is modified in place and also returned, which is the shape
    `Dataset.map` expects.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},  # placeholder; the actual image is supplied separately
            {"type": "text", "text": PROMPT},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": str(example["label"])}],
    }
    example["messages"] = [user_turn, assistant_turn]
    return example


# Apply formatting: add the "messages" column to both splits (train first, then eval)
formatted_train, formatted_eval = (
    split.map(format_data) for split in (train_data, eval_data)
)

print("✓ Data formatted with instruction prompts")



STEP 2: Formatting Data with Prompts
✓ Data formatted with instruction prompts


In [15]:

# ============================================================================
# 3. LOAD MODEL AND PROCESSOR
# ============================================================================
print("", "=" * 80, "STEP 3: Loading MedGemma Model", "=" * 80, sep="\n")

MODEL_ID = "google/medgemma-4b-it"

# Model configuration
# - bfloat16: same memory footprint as float16 but with a wider exponent range,
#   which makes it less prone to overflow/NaN during training.
# - device_map="auto": let the loader place the weights on the available device(s).
# - attn_implementation="sdpa": use PyTorch's scaled-dot-product attention.
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "attn_implementation": "sdpa",
}

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)


# Configure tokenizer for training
# Right padding is the conventional layout for training batches. It is NOT the
# layout batched generation wants: decoder-only models generally need left
# padding at inference time, so anything that generates with this processor
# must account for that.
processor.tokenizer.padding_side = "right"

print(f"✓ Model loaded: {MODEL_ID}")
print("✓ Using dtype: bfloat16")





STEP 3: Loading MedGemma Model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Model loaded: google/medgemma-4b-it
✓ Using dtype: bfloat16


In [16]:

# ============================================================================
# 4. EVALUATE BASELINE MODEL (BEFORE FINE-TUNING)
# ============================================================================
print("", "=" * 80, "STEP 4: Evaluating Baseline Model", "=" * 80, sep="\n")

# Setup evaluation metrics: both metric objects are loaded once here and reused
# by compute_metrics for the baseline and the fine-tuned model.
accuracy_metric = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")

def compute_metrics(predictions, references):
    """
    Score predictions against references.

    Returns one dict that merges the output of the accuracy metric with the
    output of the F1 metric (computed with average="weighted"); F1 entries are
    merged last, so they win on any key collision.
    """
    scores = {}
    scores.update(
        accuracy_metric.compute(predictions=predictions, references=references)
    )
    scores.update(
        f1_metric.compute(
            predictions=predictions, references=references, average="weighted"
        )
    )
    return scores

def postprocess_prediction(text):
    """
    Turn raw generated text into a class index.

    The stripped text is scanned left to right for digits 0-7 that stand alone
    (a word boundary on each side). "5" and "Classification: 5" both yield 5,
    while a digit embedded in a longer token such as "12" is ignored.

    Returns:
        The first standalone digit as an int, or -1 when there is none
        (which then counts as a wrong prediction).
    """
    standalone_digits = re.findall(r'\b([0-7])\b', text.strip())
    if not standalone_digits:
        return -1
    return int(standalone_digits[0])

def batch_predict(model, processor, prompts, images, batch_size=8, max_new_tokens=40):
    """
    Run batched greedy generation and parse one class prediction per prompt.

    Args:
        model: Model exposing `generate`, `device` and `dtype`.
        processor: Processor that is callable on text + images and exposes
            `decode` and a `tokenizer` with `pad_token_id` / `padding_side`.
        prompts: Chat-templated prompt strings, one per sample.
        images: Images aligned with `prompts` (one image per prompt).
        batch_size: Samples per generate call. 8 balances speed and memory
            with bfloat16; raise it if more VRAM is available.
        max_new_tokens: Generation budget. Only 1-2 tokens are needed for the
            answer; 40 leaves room for any extra text the model emits.

    Returns:
        A list with one int per prompt: the parsed class index, or -1 when the
        generated text contains no valid digit.

    Notes:
        - The tokenizer is switched to LEFT padding for the duration of the
          call and restored afterwards. With right padding, the last position
          of a shorter row is a pad token, so a decoder-only model continues
          from padding instead of from the end of the prompt.
        - Generated tokens are sliced off at the padded input width. The
          returned sequences all start with the full padded input, so this is
          correct for every row. Slicing at the per-row attention-mask sum is
          only correct when no row is padded.
    """
    predictions = []
    tokenizer = processor.tokenizer

    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        for start in range(0, len(prompts), batch_size):
            batch_texts = prompts[start:start + batch_size]
            # The processor expects one list of images per prompt
            batch_images = [[img] for img in images[start:start + batch_size]]

            # Process inputs; follow the model's own device/dtype rather than
            # assuming CUDA + bfloat16
            inputs = processor(
                text=batch_texts,
                images=batch_images,
                padding=True,
                return_tensors="pt",
            ).to(model.device, model.dtype)

            # Width of the (padded) prompt block shared by every row
            prompt_width = inputs["input_ids"].shape[1]

            # Generate
            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # Greedy decoding for deterministic results
                    pad_token_id=tokenizer.pad_token_id,
                )

            # Decode only the generated part (not the prompt)
            for seq in outputs:
                generated = processor.decode(seq[prompt_width:], skip_special_tokens=True)
                predictions.append(postprocess_prediction(generated))
    finally:
        # Leave the processor exactly as the caller configured it (e.g. right
        # padding for training)
        tokenizer.padding_side = original_padding_side

    return predictions

# Prepare evaluation data
# Each prompt is rendered from the user turn alone (the assistant turn holds the
# answer) with the generation prompt appended, so the model continues with its
# own reply.
eval_prompts = [
    processor.apply_chat_template(
        [conversation[0]],
        tokenize=False,
        add_generation_prompt=True,
    )
    for conversation in formatted_eval["messages"]
]
eval_labels = formatted_eval["label"]
eval_images = formatted_eval["image"]

# Run baseline evaluation (model before any fine-tuning)
print("Running baseline evaluation...")
baseline_preds = batch_predict(model, processor, eval_prompts, eval_images)
baseline_metrics = compute_metrics(baseline_preds, eval_labels)

print(f"\n{'BASELINE RESULTS':-^80}")
print(f"Accuracy: {baseline_metrics['accuracy']:.1%}")
print(f"F1 Score: {baseline_metrics['f1']:.3f}")
print("-" * 80)



STEP 4: Evaluating Baseline Model
Running baseline evaluation...

--------------------------------BASELINE RESULTS--------------------------------
Accuracy: 37.5%
F1 Score: 0.283
--------------------------------------------------------------------------------


In [17]:

# ============================================================================
# 5. CONFIGURE AND RUN FINE-TUNING
# ============================================================================
print("", "=" * 80, "STEP 5: Fine-tuning with LoRA", "=" * 80, sep="\n")

# LoRA Configuration
#
# Why LoRA: only small low-rank adapter matrices are trained while the base
# weights stay frozen, which is much cheaper in time and memory than full
# fine-tuning and often reaches comparable quality.
#
# Settings used below:
#   r=8                  Rank of the adapter matrices. Lower ranks mean fewer
#                        parameters but less capacity (risk of underfitting);
#                        higher ranks (e.g. 64) are slower and can overfit a
#                        small dataset. 8 is a middle ground for 500 samples.
#   lora_alpha=16        Scaling applied to the adapter output; 2*r is a common
#                        rule of thumb. Controls how strongly the adapters
#                        influence the base model.
#   lora_dropout=0.1     Regularisation on the adapter path. Higher (0.2) means
#                        more regularisation but may underfit; lower (0.05)
#                        means less regularisation but may overfit.
#   bias="none"          No bias terms are trained.
#   target_modules       "all-linear" attaches adapters to every linear layer;
#                        an explicit list such as ["q_proj", "v_proj"] is the
#                        narrower alternative.
#   task_type            "CAUSAL_LM".
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

# Custom data collator for vision-language training
def collate_fn(examples):
    """
    Collate dataset rows into one training batch of images + text.

    For every row, the full conversation in "messages" (user and assistant
    turns) is rendered with the chat template and stripped, and the row's image
    is wrapped in a single-element list. The module-level `processor` then
    tokenizes and pads the texts and processes the images.

    Labels are a copy of `input_ids` in which the following ids are replaced by
    -100 so they are excluded from the loss:
      - the tokenizer's pad token,
      - the tokenizer's "boi_token" special token,
      - id 262144 (presumably the image soft-token id for this model; confirm
        against the tokenizer/config).

    Returns:
        The processor's batch with an added "labels" entry.
    """
    images = [[example["image"]] for example in examples]
    texts = [
        processor.apply_chat_template(
            example["messages"],
            add_generation_prompt=False,
            tokenize=False,
        ).strip()
        for example in examples
    ]

    # Tokenize the texts and process the images together
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    boi_token_id = processor.tokenizer.convert_tokens_to_ids(
        processor.tokenizer.special_tokens_map["boi_token"]
    )
    ignored_ids = (processor.tokenizer.pad_token_id, boi_token_id, 262144)

    # Labels start as the inputs; every ignored id is then masked out of the loss
    labels = batch["input_ids"].clone()
    for token_id in ignored_ids:
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

training_args = SFTConfig(
    # --- Output, checkpointing and logging ---
    output_dir="medgemma-breastcancer-finetuned",
    save_strategy="epoch",   # checkpoint once per epoch
    eval_strategy="epoch",   # evaluate once per epoch
    logging_steps=10,
    push_to_hub=False,
    report_to="none",
    # --- Optimisation schedule ---
    num_train_epochs=5,
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,       # warm up the LR over the first 3% of training
    max_grad_norm=0.3,       # gradient clipping to prevent instability
    optim="paged_adamw_8bit",
    # --- Batch size and memory ---
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,   # 8 accumulated micro-batches per optimizer step
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,               # bfloat16 precision
    # --- Dataset handling ---
    # The custom collator builds the model inputs, so the raw columns must
    # survive until collation and the trainer's own dataset preparation is
    # skipped.
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)



STEP 5: Fine-tuning with LoRA


In [None]:
import time

# Initialize trainer
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_train,
    eval_dataset=formatted_eval,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Train the model
print("Starting training...")
# Rough optimizer-step estimate, derived from training_args so it stays correct
# when epochs, batch size or gradient accumulation are changed (assumes a
# single training process).
samples_per_step = (
    training_args.per_device_train_batch_size
    * training_args.gradient_accumulation_steps
)
approx_total_steps = int(TRAIN_SIZE * training_args.num_train_epochs) // samples_per_step
print(f"Total training steps: ~{approx_total_steps}")
start_time = time.perf_counter()

trainer.train()
end_time = time.perf_counter()

# Save the fine-tuned model
trainer.save_model()
print(f"✓ Model saved to {training_args.output_dir}")

print("Model training duration: ", (end_time - start_time)/60, " minutes")

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Starting training...
Total training steps: ~312


Epoch,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
1,0.0269,0.028996,0.074295,187500.0,0.986325
2,0.0253,0.026479,0.048433,375000.0,0.987034
3,0.023,0.024792,0.070041,562500.0,0.987188
4,0.0206,0.022377,0.181104,750000.0,0.987573
5,0.0182,0.022093,0.164053,937500.0,0.987838


✓ Model saved to medgemma-breastcancer-finetuned


In [None]:

# ============================================================================
# 6. EVALUATE FINE-TUNED MODEL
# ============================================================================
print("\n" + "="*80)
print("STEP 6: Evaluating Fine-tuned Model")
print("="*80)

# Clear memory before loading a fresh copy of the base model.
# The trainer was constructed with `model` and also owns the optimizer state,
# so deleting `model` alone leaves the training-time weights referenced and
# resident on the GPU. Drop both names, collect unreachable objects first, and
# only then return the freed blocks from PyTorch's cache to the driver.
del model, trainer
gc.collect()
torch.cuda.empty_cache()

# Load base model
base_model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

# Load LoRA adapters from the training output directory and merge them into
# the base weights, leaving a plain model without the adapter wrapper
finetuned_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
finetuned_model = finetuned_model.merge_and_unload()

# Load processor from fine-tuned checkpoint
processor_finetuned = AutoProcessor.from_pretrained(training_args.output_dir)

# Configure for generation
# NOTE: batch_predict passes max_new_tokens explicitly, so the value set here
# only applies to generate() calls that do not override it.
finetuned_model.generation_config.max_new_tokens = 50
finetuned_model.generation_config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
finetuned_model.config.pad_token_id = processor_finetuned.tokenizer.pad_token_id

print("✓ Fine-tuned model loaded")

# Run evaluation on the same prompts/images/labels used for the baseline
print("Running fine-tuned evaluation...")
finetuned_preds = batch_predict(
    finetuned_model,
    processor_finetuned,
    eval_prompts,
    eval_images,
    batch_size=4  # Smaller batch size for safety
)
finetuned_metrics = compute_metrics(finetuned_preds, eval_labels)



STEP 6: Evaluating Fine-tuned Model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Fine-tuned model loaded
Running fine-tuned evaluation...


In [None]:

# ============================================================================
# 7. COMPARE RESULTS
# ============================================================================
print("", "=" * 80, "FINAL RESULTS COMPARISON", "=" * 80, sep="\n")
print(f"\n{'Model':<20} {'Accuracy':<12} {'F1 Score':<12}")
print("-" * 44)
print(f"{'Baseline':<20} {baseline_metrics['accuracy']:>10.1%}  {baseline_metrics['f1']:>10.3f}")
print(f"{'Fine-tuned':<20} {finetuned_metrics['accuracy']:>10.1%}  {finetuned_metrics['f1']:>10.3f}")
print("-" * 44)

# Calculate improvement
# Accuracy delta is expressed in percentage points; F1 delta is the raw difference.
f1_improvement = finetuned_metrics['f1'] - baseline_metrics['f1']
accuracy_improvement = 100 * (finetuned_metrics['accuracy'] - baseline_metrics['accuracy'])

print(f"\n{'Improvement':<20} {accuracy_improvement:>+9.1f}%  {f1_improvement:>+10.3f}")
print("=" * 80)

# Success indicators: strictly higher accuracy than the baseline counts as success
if finetuned_metrics['accuracy'] > baseline_metrics['accuracy']:
    print("\n✓ Fine-tuning successful! Accuracy improved.")
else:
    print(
        "\n⚠ Fine-tuning did not improve accuracy. Consider:",
        "  - Training for more epochs",
        "  - Using more training data",
        "  - Adjusting learning rate or LoRA rank",
        sep="\n",
    )

print("\nTraining complete!")


FINAL RESULTS COMPARISON

Model                Accuracy     F1 Score    
--------------------------------------------
Baseline                  37.5%       0.283
Fine-tuned                57.6%       0.511
--------------------------------------------

Improvement              +20.1%      +0.228

✓ Fine-tuning successful! Accuracy improved.

Training complete!
