In [16]:
# Fine-tuning MedGemma-4B for Breast Cancer Histopathology Classification
# This notebook demonstrates how to fine-tune Google's MedGemma vision-language model
# on the BreakHis breast cancer dataset using LoRA (Low-Rank Adaptation)

# ============================================================================
# 0. SETUP AND INSTALLATIONS
# ============================================================================

# Install required packages
!pip install --upgrade --quiet transformers datasets evaluate peft trl scikit-learn 
# Then, reinstall it, forcing a build from source
# This will take a few minutes as it compiles the code

import os
import re
import torch
import gc
from datasets import load_dataset, ClassLabel
from peft import LoraConfig, PeftModel
from transformers import AutoModelForImageTextToText, AutoProcessor
from trl import SFTTrainer, SFTConfig
import evaluate
import pandas as pd

# Hugging face authentication

# Important Security Note: You should never hardcode secrets like API keys or tokens directly into your code or notebooks, especially in a production environment. This practice is insecure and creates a significant security risk.

# The most secure, "enterprise-grade" way to handle secrets (such as a Hugging Face token) in Vertex AI Workbench is to use Google Cloud Secret Manager.

# To do that, you will need to:
# 1. In the Google Cloud console, go to Secret Manager and create your secret.
# 2. Grant your Workbench instance's service account permission to read the secret: find the service account email (click on the instance name), open the IAM & Admin page, locate that service account, and add the role "Secret Manager Secret Accessor".
# 3. Access the secret in your notebook (see https://docs.cloud.google.com/secret-manager/docs/reference/libraries#client-libraries-usage-python).

# If you are just experimenting and don't want to set up Secret Manager yet, you can use the interactive login widget. This saves the token temporarily in the instance's file system.


# Hugging Face authentication using the interactive login widget.
# notebook_login() renders a widget in the cell output where the access token is entered,
# so the token never has to appear in the notebook source.

from huggingface_hub import notebook_login
notebook_login()


# Hugging Face authentication using Google cloud secret manager

# from google.cloud import secretmanager
# from huggingface_hub import login

# def get_secret(secret_id, version_id="latest"):
#     # Create the Secret Manager client.
#     client = secretmanager.SecretManagerServiceClient()

#     # Build the resource name of the secret version.
#     # Replace 'YOUR_PROJECT_ID' with your actual project ID
#     project_id = "YOUR_PROJECT_ID" 
#     name = f"projects/{project_id}/secrets/{secret_id}/versions/{version_id}"

#     # Access the secret version.
#     response = client.access_secret_version(request={"name": name})
#     return response.payload.data.decode("UTF-8")

# # Retrieve the token
# hf_token = get_secret("hugging-face-token")

# # Login to Hugging Face
# login(token=hf_token)

# print("Successfully logged in to Hugging Face!")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [17]:
# ============================================================================
# 1. LOAD AND PREPARE DATA FROM KAGGLE
# ============================================================================

! pip install -q kagglehub

import kagglehub
import os
import pandas as pd  # <-- This import was missing
from PIL import Image

# Download the dataset and read the fold metadata shipped with it
path = kagglehub.dataset_download("ambarish/breakhis")
print("Path to dataset files:", path)
folds = pd.read_csv(os.path.join(path, 'Folds.csv'))

# Filter for 100X magnification from the first fold
folds_100x = folds[(folds['mag'] == 100) & (folds['fold'] == 1)]

# Get the train/test splits
folds_100x_test = folds_100x[folds_100x.grp == 'test']
folds_100x_train = folds_100x[folds_100x.grp == 'train']

# Get the lists of relative filenames
test_filenames = folds_100x_test.filename.values
train_filenames = folds_100x_train.filename.values

# Base directory for images, derived from the actual download location so the
# notebook does not depend on one user's home directory or one dataset version.
BASE_PATH = os.path.join(path, "BreaKHis_v1")
if not os.path.isdir(BASE_PATH):
    raise FileNotFoundError(f"Expected image directory not found: {BASE_PATH}")


def _class_rate(split_df, keyword):
    """Return the rounded percentage of rows in `split_df` whose filename contains `keyword`.

    Returns 0 for an empty split instead of dividing by zero.
    """
    if len(split_df) == 0:
        return 0
    hits = int(split_df.filename.str.contains(keyword).sum())
    return round(100 * hits / len(split_df))


rate_benign_in_test = _class_rate(folds_100x_test, 'benign')
rate_malignant_in_test = _class_rate(folds_100x_test, 'malignant')
rate_benign_in_train = _class_rate(folds_100x_train, 'benign')
rate_malignant_in_train = _class_rate(folds_100x_train, 'malignant')

print(f"\nFor 100X magnificantion found: {len(train_filenames)} train files:  {rate_benign_in_train}% benign, {rate_malignant_in_train}% malignant, {len(test_filenames)} test files  with {rate_benign_in_test}% benign and {rate_malignant_in_test}% malignant.")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.2[0m[39;49m -> [0m[32;49m25.3[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Path to dataset files: /home/jupyter/.cache/kagglehub/datasets/ambarish/breakhis/versions/4

For 100X magnificantion found: 1321 train files:  29% benign, 71% malignant, 760 test files  with 34% benign and 66% malignant.


In [18]:

# ============================================================================
# 1.1 UNDERSAMPLE DATA TO GET BALANCED TRAIN AND TEST SETS
# ============================================================================

import pandas as pd

# Seed shared by every undersampling draw so the balanced sets are reproducible
_UNDERSAMPLE_SEED = 42


def _split_by_diagnosis(split_df):
    """Return (benign_rows, malignant_rows), selected by substring match on the filename."""
    names = split_df['filename']
    return split_df[names.str.contains('benign')], split_df[names.str.contains('malignant')]


def _draw_equal_samples(benign_df, malignant_df):
    """Sample both frames down to the size of the smaller one.

    Returns (benign_sample, malignant_sample, per_class_count).
    """
    per_class = min(len(benign_df), len(malignant_df))
    return (
        benign_df.sample(n=per_class, random_state=_UNDERSAMPLE_SEED),
        malignant_df.sample(n=per_class, random_state=_UNDERSAMPLE_SEED),
        per_class,
    )


# --- 1. Balanced TRAIN set: equal numbers of benign and malignant files ---
train_benign_df, train_malignant_df = _split_by_diagnosis(folds_100x_train)
balanced_train_benign, balanced_train_malignant, min_train_count = _draw_equal_samples(
    train_benign_df, train_malignant_df
)
balanced_train_df = pd.concat([balanced_train_benign, balanced_train_malignant])

# --- 2. Balanced TEST set: same procedure on the test split ---
test_benign_df, test_malignant_df = _split_by_diagnosis(folds_100x_test)
balanced_test_benign, balanced_test_malignant, min_test_count = _draw_equal_samples(
    test_benign_df, test_malignant_df
)
balanced_test_df = pd.concat([balanced_test_benign, balanced_test_malignant])

# --- 3. Final filename arrays ---
balanced_train_filenames = balanced_train_df['filename'].values
balanced_test_filenames = balanced_test_df['filename'].values

# --- 4. Summary ---
print("Original train files:", len(folds_100x_train))
print("Original test files:", len(folds_100x_test))
print("\n--- Balanced Sets Created (50/50) ---")
print(f"Balanced Train: {len(balanced_train_filenames)} files ({min_train_count} benign + {min_train_count} malignant)")
print(f"Balanced Test:  {len(balanced_test_filenames)} files ({min_test_count} benign + {min_test_count} malignant)")

# From here on the notebook works with the balanced filename lists
train_filenames = balanced_train_filenames
test_filenames = balanced_test_filenames

Original train files: 1321
Original test files: 760

--- Balanced Sets Created (50/50) ---
Balanced Train: 766 files (383 benign + 383 malignant)
Balanced Test:  522 files (261 benign + 261 malignant)


In [19]:

# ============================================================================
# 2. CREATE HUGGING FACE DATASETS
# ============================================================================
from datasets import Dataset, Image as HFImage, Features, ClassLabel

# Ordered class names: four benign subtypes followed by four malignant subtypes.
# The position of a name in this list is the integer label used for that class.
CLASS_NAMES = [
    'benign_adenosis',
    'benign_fibroadenoma',
    'benign_phyllodes_tumor',
    'benign_tubular_adenoma',
    'malignant_ductal_carcinoma',
    'malignant_lobular_carcinoma',
    'malignant_mucinous_carcinoma',
    'malignant_papillary_carcinoma',
]


# Resolve a BreakHis file path to its integer class label (0-7)
def get_label_from_filename(filename):
    """Return the class index encoded in a BreakHis file path.

    The path is normalised (backslashes become forward slashes, everything
    lower-cased) and then searched for a tumour-subtype folder segment.
    Markers are tried in label order, so the lowest matching label wins.

    Returns:
        int: 0-7 for a recognised subtype folder, -1 when none is present.
    """
    normalized = filename.replace('\\', '/').lower()

    # Index in this tuple == class label
    folder_markers = (
        '/adenosis/',
        '/fibroadenoma/',
        '/phyllodes_tumor/',
        '/tubular_adenoma/',
        '/ductal_carcinoma/',
        '/lobular_carcinoma/',
        '/mucinous_carcinoma/',
        '/papillary_carcinoma/',
    )
    for label, marker in enumerate(folder_markers):
        if marker in normalized:
            return label

    return -1

def _paths_and_labels(relative_filenames):
    """Build the column dict for one split: absolute image paths plus integer labels."""
    return {
        'image': [os.path.join(BASE_PATH, name) for name in relative_filenames],
        'label': [get_label_from_filename(name) for name in relative_filenames],
    }


# Column dictionaries for both splits
train_data_dict = _paths_and_labels(train_filenames)
test_data_dict = _paths_and_labels(test_filenames)

# Schema: an image column and a class-label column over CLASS_NAMES
features = Features({
    'image': HFImage(),
    'label': ClassLabel(names=CLASS_NAMES),
})


def _to_hf_dataset(columns):
    """Create a Dataset from a column dict; the image column is cast to HFImage so paths are loaded as images."""
    dataset = Dataset.from_dict(columns, features=features)
    return dataset.cast_column("image", HFImage())


train_dataset = _to_hf_dataset(train_data_dict)
eval_dataset = _to_hf_dataset(test_data_dict)

print("\n--- Created new Hugging Face Datasets ---")
print(train_dataset)
print(eval_dataset)

# ============================================================================
# 2.1 APPLY  FORMATTING FUNCTION
# ============================================================================

# Instruction shown to the model with every image: the eight classes and the
# requirement to answer with the class number only.
PROMPT = (
    "Analyze this breast tissue histopathology image and classify it.\n"
    "\n"
    "Classes (0-7):\n"
    "0: benign_adenosis\n"
    "1: benign_fibroadenoma\n"
    "2: benign_phyllodes_tumor\n"
    "3: benign_tubular_adenoma\n"
    "4: malignant_ductal_carcinoma\n"
    "5: malignant_lobular_carcinoma\n"
    "6: malignant_mucinous_carcinoma\n"
    "7: malignant_papillary_carcinoma\n"
    "\n"
    "Answer with only the number (0-7):"
)


def format_data(example):
    """Attach a two-turn chat transcript to a dataset example.

    The user turn holds an image placeholder followed by PROMPT; the assistant
    turn holds the example's label rendered as text. The example is modified
    in place (a "messages" key is added) and returned.
    """
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image"},  # placeholder; the actual image stays in the "image" column
            {"type": "text", "text": PROMPT},
        ],
    }
    assistant_turn = {
        "role": "assistant",
        "content": [{"type": "text", "text": str(example["label"])}],
    }
    example["messages"] = [user_turn, assistant_turn]
    return example

# Add the chat-style "messages" column to both splits, one example at a time
formatted_train, formatted_eval = (
    split.map(format_data, batched=False) for split in (train_dataset, eval_dataset)
)

print("\n--- Applied formatting ---")
print(formatted_train)
print(formatted_eval)

# Show one formatted record as a sanity check
print("\n--- Example from new formatted_train ---")
first_train_example = formatted_train[0]
print(first_train_example)


--- Created new Hugging Face Datasets ---
Dataset({
    features: ['image', 'label'],
    num_rows: 766
})
Dataset({
    features: ['image', 'label'],
    num_rows: 522
})


Map:   0%|          | 0/766 [00:00<?, ? examples/s]

Map:   0%|          | 0/522 [00:00<?, ? examples/s]


--- Applied formatting ---
Dataset({
    features: ['image', 'label', 'messages'],
    num_rows: 766
})
Dataset({
    features: ['image', 'label', 'messages'],
    num_rows: 522
})

--- Example from new formatted_train ---
{'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=700x460 at 0x7F220BCB6B60>, 'label': 2, 'messages': [{'content': [{'text': None, 'type': 'image'}, {'text': 'Analyze this breast tissue histopathology image and classify it.\n\nClasses (0-7):\n0: benign_adenosis\n1: benign_fibroadenoma\n2: benign_phyllodes_tumor\n3: benign_tubular_adenoma\n4: malignant_ductal_carcinoma\n5: malignant_lobular_carcinoma\n6: malignant_mucinous_carcinoma\n7: malignant_papillary_carcinoma\n\nAnswer with only the number (0-7):', 'type': 'text'}], 'role': 'user'}, {'content': [{'text': '2', 'type': 'text'}], 'role': 'assistant'}]}


In [20]:

# ============================================================================
# 3. LOAD MODEL AND PROCESSOR
# ============================================================================
print("\n" + "=" * 80, "STEP 3: Loading MedGemma Model", "=" * 80, sep="\n")

MODEL_ID = "google/medgemma-4b-it"

# Loading options
# - torch_dtype=bfloat16: same memory footprint as float16 but more numerically
#   stable (avoids NaN issues); a good fit for vision-language models.
# - device_map="auto": place the model on the available GPUs automatically.
# - attn_implementation="sdpa": select the SDPA attention implementation.
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
    "attn_implementation": "sdpa",
}

model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# # Override image size to something smaller
# processor.image_processor.size = {"height": 448, "width": 448}  # Half size

# Training setup: pad batches on the right-hand side of each sequence
processor.tokenizer.padding_side = "right"

print(f"✓ Model loaded: {MODEL_ID}")
print("✓ Using dtype: bfloat16")





STEP 3: Loading MedGemma Model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Model loaded: google/medgemma-4b-it
✓ Using dtype: bfloat16


In [21]:

# ============================================================================
# 4. EVALUATE BASELINE MODEL (BEFORE FINE-TUNING)
# ============================================================================
print("\n" + "=" * 80, "STEP 4: Evaluating Baseline Model", "=" * 80, sep="\n")

# Metric objects from the `evaluate` library, used when scoring predictions
accuracy_metric, f1_metric = (
    evaluate.load(metric_name) for metric_name in ("accuracy", "f1")
)

# def compute_metrics(predictions, references):
#     """Compute accuracy and weighted F1 score"""
#     return {
#         **accuracy_metric.compute(predictions=predictions, references=references),
#         **f1_metric.compute(predictions=predictions, references=references, average="weighted")
#     }


def compute_metrics(predictions, references):
    """Compute 8-class and binary (benign/malignant) metrics.

    Args:
        predictions: Predicted class indices; 0-7 for a parsed answer, -1 when
            the model output could not be parsed.
        references: True class indices (0-7).

    Returns:
        dict with keys "accuracy_8class", "f1_8class_weighted",
        "accuracy_binary" and "f1_binary_malignant".
    """
    # --- 1. 8-class (multi-class) metrics ---
    multi_class_acc = accuracy_metric.compute(predictions=predictions, references=references)
    multi_class_f1 = f1_metric.compute(predictions=predictions, references=references, average="weighted")

    # --- 2. Binary (benign/malignant) metrics ---
    # Labels 0-3 are 'benign' (0), labels 4-7 are 'malignant' (1).
    # An unparsed prediction (-1) stays -1 so it can never match a reference.
    binary_preds = [1 if p > 3 else 0 if p >= 0 else -1 for p in predictions]
    binary_refs = [1 if r > 3 else 0 for r in references]  # True labels are always 0-7

    binary_acc = accuracy_metric.compute(predictions=binary_preds, references=binary_refs)

    # F1 of the malignant class (1). average="binary" would raise a ValueError as
    # soon as a single -1 is present, because the predictions then contain three
    # distinct values. Restricting the score to labels=[1] yields the same number
    # when every prediction is valid and still works when some are -1.
    binary_f1 = f1_metric.compute(
        predictions=binary_preds,
        references=binary_refs,
        labels=[1],
        average="macro",
    )

    # --- 3. Return all metrics ---
    return {
        "accuracy_8class": multi_class_acc['accuracy'],
        "f1_8class_weighted": multi_class_f1['f1'],
        "accuracy_binary": binary_acc['accuracy'],
        "f1_binary_malignant": binary_f1['f1']
    }
def postprocess_prediction(text):
    """
    Turn raw model output into a class index.

    The output may be a bare number ("5") or a short phrase containing one
    ("Classification: 5"), so the first standalone digit in the range 0-7 is
    taken as the answer.

    Returns:
        int: the digit found, or -1 when the text contains no standalone
        0-7 digit (scored as a wrong prediction downstream).
    """
    found = re.search(r'\b([0-7])\b', text.strip())
    if found is None:
        return -1
    return int(found.group(1))

def batch_predict(model, processor, prompts, images, batch_size=8, max_new_tokens=40):
    """
    Run batched greedy generation and parse one class index per prompt.

    Args:
        model: Model exposing `generate`, `device` and `dtype`.
        processor: Processor paired with the model (tokenizer + image handling).
        prompts: Chat-templated prompt strings, one per image.
        images: Images aligned with `prompts`.
        batch_size: Prompts per generation call. 8 balances speed and memory
            with bfloat16; raise it if more VRAM is available.
        max_new_tokens: Generation budget. Only 1-2 tokens are needed for the
            answer; 40 leaves room for any extra text the model emits.

    Returns:
        list[int]: one parsed prediction per prompt (-1 when unparseable).
    """
    tokenizer = processor.tokenizer

    # Batched generation needs left padding so every prompt ends directly where
    # generation starts. The tokenizer may be configured for right padding (as it
    # is for training), so switch temporarily and restore the setting afterwards.
    original_padding_side = tokenizer.padding_side
    tokenizer.padding_side = "left"

    predictions = []
    try:
        for start in range(0, len(prompts), batch_size):
            batch_texts = prompts[start:start + batch_size]
            batch_images = [[img] for img in images[start:start + batch_size]]

            # Process inputs and move them to wherever the model lives
            inputs = processor(
                text=batch_texts,
                images=batch_images,
                padding=True,
                return_tensors="pt"
            ).to(model.device, model.dtype)

            # Every row of the padded batch has this width, so generated tokens
            # start at this offset regardless of how much padding a row carries.
            prompt_width = inputs["input_ids"].shape[1]

            with torch.inference_mode():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    do_sample=False,  # Greedy decoding for deterministic results
                    pad_token_id=tokenizer.pad_token_id
                )

            # Decode only the generated part (not the prompt)
            for seq in outputs:
                generated = processor.decode(seq[prompt_width:], skip_special_tokens=True)
                predictions.append(postprocess_prediction(generated))
    finally:
        tokenizer.padding_side = original_padding_side

    return predictions

def _user_turn_prompt(conversation):
    """Render only the first (user) message of a conversation, ending with the generation prompt."""
    return processor.apply_chat_template(
        [conversation[0]],  # drop the assistant answer so the model has to produce it
        add_generation_prompt=True,
        tokenize=False,
    )


# Evaluation inputs: prompts without the answer, plus the aligned images and labels
eval_prompts = [_user_turn_prompt(conversation) for conversation in formatted_eval["messages"]]
eval_images = formatted_eval["image"]
eval_labels = formatted_eval["label"]

# Score the untouched base model so fine-tuning has a reference point
print("Running baseline evaluation...")
baseline_preds = batch_predict(model, processor, eval_prompts, eval_images)
baseline_metrics = compute_metrics(baseline_preds, eval_labels)

# print(f"\n{'BASELINE RESULTS':-^80}")
# print(f"Accuracy: {baseline_metrics['accuracy']:.1%}")
# print(f"F1 Score: {baseline_metrics['f1']:.3f}")
# print("-"*80)



# Baseline report: 8-class and binary accuracy / F1
baseline_report_lines = [
    f"\n{'BASELINE RESULTS':-^80}",
    f"Accuracy (8-class):   {baseline_metrics['accuracy_8class']:.1%}",
    f"F1 Score (8-class):   {baseline_metrics['f1_8class_weighted']:.3f}",
    f"Accuracy (Binary):    {baseline_metrics['accuracy_binary']:.1%}",
    f"F1 Score (Binary):    {baseline_metrics['f1_binary_malignant']:.3f}",
    "-" * 80,
]
print("\n".join(baseline_report_lines))



STEP 4: Evaluating Baseline Model
Running baseline evaluation...

--------------------------------BASELINE RESULTS--------------------------------
Accuracy (8-class):   32.6%
F1 Score (8-class):   0.241
Accuracy (Binary):    59.6%
F1 Score (Binary):    0.639
--------------------------------------------------------------------------------


In [22]:

# ============================================================================
# 5. CONFIGURE AND RUN FINE-TUNING
# ============================================================================
print("\n" + "=" * 80, "STEP 5: Fine-tuning with LoRA", "=" * 80, sep="\n")

# LoRA configuration
#
# Why LoRA:
# - Trains only a small fraction of the parameters (~1% of the model)
# - Much faster and more memory-efficient than full fine-tuning
# - Often achieves comparable performance
#
# Parameters:
# - r=8: rank of the LoRA matrices (lower = fewer params, faster, less capacity)
#   - Too low (r=2): may underfit, can't learn complex patterns
#   - Too high (r=64): more params, slower, risk of overfitting on small datasets
#   - r=8 is a good balance for the few hundred training samples used here
#
# - lora_alpha=16: scaling factor for the LoRA weights
#   - Typically set to 2*r as a rule of thumb
#   - Controls how much the LoRA adapters affect the base model
#
# - lora_dropout=0.1: regularization to prevent overfitting
#   - Higher values (0.2) = more regularization but may underfit
#   - Lower values (0.05) = less regularization but may overfit
#
# - target_modules="all-linear": apply LoRA to all linear layers
#   - Alternative: name specific layers such as ["q_proj", "v_proj"]
#   - "all-linear" is simpler and works well for most cases

peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)

# Custom data collator for vision-language training
def collate_fn(examples):
    """
    Turn a list of dataset examples into one training batch.

    Each example contributes its image (wrapped in a one-element list) and its
    full conversation rendered through the chat template. The processor
    tokenizes and pads the batch; `labels` is a copy of `input_ids` in which
    padding and image-related token ids are replaced by -100 so they do not
    contribute to the loss.

    Returns:
        The processor output with an added "labels" tensor.
    """
    images = [[example["image"]] for example in examples]
    texts = [
        processor.apply_chat_template(
            example["messages"],
            add_generation_prompt=False,
            tokenize=False,
        ).strip()
        for example in examples
    ]

    # Tokenize text and preprocess images together
    batch = processor(text=texts, images=images, return_tensors="pt", padding=True)

    tokenizer = processor.tokenizer
    boi_token_id = tokenizer.convert_tokens_to_ids(
        tokenizer.special_tokens_map["boi_token"]
    )
    # Token ids excluded from the loss: padding, the begin-of-image marker, and
    # 262144 (hard-coded; presumably the image placeholder token id — confirm
    # against the model's tokenizer/config).
    ignored_token_ids = (tokenizer.pad_token_id, boi_token_id, 262144)

    labels = batch["input_ids"].clone()
    for token_id in ignored_token_ids:
        labels[labels == token_id] = -100

    batch["labels"] = labels
    return batch

training_args = SFTConfig(
    # Output and reporting
    output_dir="medgemma-breastcancer-finetuned",
    push_to_hub=False,
    report_to="none",
    logging_steps=10,
    # Schedule length and batch shape: 1 sample per device, 8 accumulation steps
    num_train_epochs=5,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,
    # Memory savers
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=True,  # bfloat16 precision
    # Optimizer and learning-rate schedule
    optim="paged_adamw_8bit",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,  # warm up the LR over the first 3% of training
    max_grad_norm=0.3,  # clip gradients to prevent instability
    # Checkpointing and evaluation cadence
    save_strategy="steps",
    save_steps=100,
    eval_strategy="epoch",
    # Dataset handling: the custom collator does all preprocessing, so the
    # trainer's own dataset preparation and column pruning are switched off
    dataset_kwargs={"skip_prepare_dataset": True},
    remove_unused_columns=False,
    label_names=["labels"],
)



STEP 5: Fine-tuning with LoRA


In [23]:
import time

# Initialize trainer: LoRA adapters on the base model, custom image+text collator
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=formatted_train,
    eval_dataset=formatted_eval,
    peft_config=peft_config,
    processing_class=processor,
    data_collator=collate_fn,
)

# Train the model
print("Starting training...")
# Estimate the number of optimizer steps from the training configuration itself
# so the figure stays correct when epochs or batch settings are changed
# (the estimate ignores data parallelism across several devices).
effective_batch_size = (
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
)
estimated_steps = int(len(formatted_train) * training_args.num_train_epochs) // effective_batch_size
print(f"Total training steps: ~{estimated_steps}")
start_time = time.perf_counter()

trainer.train()
# If you would like to continue training that stopped for some reason -
# trainer.train(resume_from_checkpoint=True)
end_time = time.perf_counter()

# Save the fine-tuned model
trainer.save_model()
print(f"✓ Model saved to {training_args.output_dir}")

print("Model training duration: ", (end_time - start_time)/60, " minutes")

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.


Starting training...
Total training steps: ~478


	save_steps: 100 (from args) != 500 (from trainer_state.json)


Epoch,Training Loss,Validation Loss


✓ Model saved to medgemma-breastcancer-finetuned
Model training duration:  0.026576538550095088  minutes


In [24]:

# ============================================================================
# 6. EVALUATE FINE-TUNED MODEL
# ============================================================================
print("\n" + "="*80)
print("STEP 6: Evaluating Fine-tuned Model")
print("="*80)

# Drop the notebook's reference to the training-time model before loading a
# fresh copy. pop() instead of `del` keeps this cell re-runnable, and collecting
# garbage BEFORE emptying the CUDA cache lets the cache release memory that was
# only held by unreachable objects.
# NOTE: `trainer` was constructed with this model and still references it;
# delete `trainer` as well if GPU memory is tight.
globals().pop("model", None)
gc.collect()
torch.cuda.empty_cache()

# Load a clean base model with the same options used for the original load
base_model = AutoModelForImageTextToText.from_pretrained(MODEL_ID, **model_kwargs)

# Load LoRA adapters and merge them into the base weights
finetuned_model = PeftModel.from_pretrained(base_model, training_args.output_dir)
finetuned_model = finetuned_model.merge_and_unload()

# Load processor from fine-tuned checkpoint
processor_finetuned = AutoProcessor.from_pretrained(training_args.output_dir)

# Configure for generation
finetuned_model.generation_config.max_new_tokens = 50
finetuned_model.generation_config.pad_token_id = processor_finetuned.tokenizer.pad_token_id
finetuned_model.config.pad_token_id = processor_finetuned.tokenizer.pad_token_id

print("✓ Fine-tuned model loaded")

# Run evaluation on the same prompts, images and labels as the baseline
print("Running fine-tuned evaluation...")
finetuned_preds = batch_predict(
    finetuned_model,
    processor_finetuned,
    eval_prompts,
    eval_images,
    batch_size=4  # Smaller batch size for safety
)
finetuned_metrics = compute_metrics(finetuned_preds, eval_labels)



STEP 6: Evaluating Fine-tuned Model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Fine-tuned model loaded
Running fine-tuned evaluation...


In [25]:
# ============================================================================
# 7. COMPARE RESULTS (with binary)
# ============================================================================
def _print_comparison_table(title_line, f1_header, accuracy_key, f1_key):
    """Print one baseline-vs-fine-tuned table for the given accuracy/F1 metric keys."""
    rule = "-" * 47
    print(title_line)
    print(f"{'Model':<20} {'Accuracy':<12} {f1_header:<15}")
    print(rule)
    for model_name, metrics in (("Baseline", baseline_metrics), ("Fine-tuned", finetuned_metrics)):
        print(f"{model_name:<20} {metrics[accuracy_key]:>10.1%} {metrics[f1_key]:>13.3f}")
    print(rule)


print("\n" + "=" * 80, "FINAL RESULTS COMPARISON", "=" * 80, sep="\n")

_print_comparison_table(
    "\n--- 8-Class Classification (0-7) ---",
    "F1 (Weighted)",
    "accuracy_8class",
    "f1_8class_weighted",
)
_print_comparison_table(
    "\n--- Binary (Benign/Malignant) Classification ---",
    "F1 (Malignant)",
    "accuracy_binary",
    "f1_binary_malignant",
)

print(f"\n{'=' * 80}")

# Success indicator, judged on 8-class accuracy only
eight_class_improved = finetuned_metrics['accuracy_8class'] > baseline_metrics['accuracy_8class']
if not eight_class_improved:
    print("\n⚠ Fine-tuning did not improve 8-class accuracy. Consider:")
    for suggestion in (
        "  - Training for more epochs",
        "  - Using more training data",
        "  - Adjusting learning rate or LoRA rank",
    ):
        print(suggestion)
else:
    print("\n✓ Fine-tuning successful! 8-class accuracy improved.")

print("\nTraining complete!")


FINAL RESULTS COMPARISON

--- 8-Class Classification (0-7) ---
Model                Accuracy     F1 (Weighted)  
-----------------------------------------------
Baseline                  32.6%         0.241
Fine-tuned                87.2%         0.865
-----------------------------------------------

--- Binary (Benign/Malignant) Classification ---
Model                Accuracy     F1 (Malignant) 
-----------------------------------------------
Baseline                  59.6%         0.639
Fine-tuned                99.0%         0.991
-----------------------------------------------


✓ Fine-tuning successful! 8-class accuracy improved.

Training complete!
