# Multimodal Model Evaluation for Melanoma Detection

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/blob/main/evaluation/Evaluation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fayoisio%2Fgke-multimodal-fine-tune-gemma-3-axolotl%2Fmain%2Fevaluation%2FEvaluation.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/blob/main/evaluation/Evaluation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/main/evaluation/Evaluation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

<div style="clear: both;"></div>

## Overview

This notebook demonstrates how to evaluate a fine-tuned multimodal model on the [SIIM-ISIC Melanoma Classification](https://challenge2020.isic-archive.com/) dataset. We compare three models (base Gemma 3, our fine-tuned Gemma 3, and MedGemma) to measure how much fine-tuning improves melanoma detection.

### What you'll learn

- How to load and compare multimodal models (base, fine-tuned, and domain-specific)
- How to run batch inference on medical imaging datasets
- How to calculate and visualize key performance metrics
- How to interpret model improvements for clinical applications
- How to handle model responses that don't map cleanly to a benign or melanoma label

### Prerequisites

- Completed fine-tuning using the main repository
- Access to Google Cloud Storage with model files
- Hugging Face account with Gemma 3 access
- GPU-enabled environment (recommended: A100 or better)

### Time to complete

45-60 minutes (depending on number of test images and GPU availability)

---

## Introduction

The SIIM-ISIC dataset contains over 33,000 dermoscopic images of skin lesions, each labeled as benign or malignant melanoma. After fine-tuning Gemma 3 on this dataset, we need to evaluate it rigorously against both the base model and a medical-domain model (MedGemma).

Our evaluation will:

1. **Load multiple models** - Base Gemma 3, fine-tuned Gemma 3, and MedGemma
2. **Run inference** - Process test images through each model
3. **Calculate metrics** - Accuracy, precision, recall, specificity, and F1 scores
4. **Visualize results** - Create comprehensive performance comparisons
5. **Analyze improvements** - Quantify the benefits of fine-tuning
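
For reference, the metrics in step 3 follow the standard binary-classification definitions, treating melanoma as the positive class and counting true/false positives and negatives ($TP$, $FP$, $TN$, $FN$) only over images that have both a ground-truth label and a parsed prediction. Because the dataset is heavily imbalanced toward benign lesions, accuracy alone can look high even for a model that rarely predicts melanoma, which is why recall (sensitivity) and specificity are reported separately:

```latex
\begin{aligned}
\text{Accuracy} &= \frac{TP + TN}{TP + TN + FP + FN} \\
\text{Precision} &= \frac{TP}{TP + FP} \\
\text{Recall (sensitivity)} &= \frac{TP}{TP + FN} \\
\text{Specificity} &= \frac{TN}{TN + FP} \\
F_1 &= \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
\end{aligned}
```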

**⚠️ Note**: This notebook contains medical imagery. The content is intended for educational and research purposes only. Models evaluated here should not be used for actual medical diagnosis without proper validation and regulatory approval.

## Step 1: Install dependencies

Let's install all required packages for model loading, inference, and evaluation:

In [None]:
# Install required packages with specific versions for compatibility
print("📦 Installing required packages...")
!pip install transformers==4.51.3 -q
!pip install accelerate==1.6.0 -q
!pip install pillow==11.2.1 -q
!pip install matplotlib==3.9.4 -q
!pip install seaborn==0.13.2 -q
!pip install sentencepiece==0.2.0 -q
!pip install protobuf==3.20.3 -q
!pip install peft==0.15.2 -q
!pip install bitsandbytes==0.45.5 -q
!pip install triton==3.1.0 -q  # torch 2.5.1 requires triton 3.1.0
!pip install torch==2.5.1 -q
!pip install torchvision==0.20.1 -q
!pip install scikit-learn==1.5.1 -q
!pip install pandas==2.2.2 -q
!pip install numpy==1.26.4 -q

print("✅ Package installation complete!")

# Verify key package versions
import transformers
import torch
print(f"\n📌 Key package versions:")
print(f"  • Transformers: {transformers.__version__}")
print(f"  • PyTorch: {torch.__version__}")
print(f"  • CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"  • CUDA version: {torch.version.cuda}")
    print(f"  • GPU: {torch.cuda.get_device_name(0)}")
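
Because each package is installed in a separate `pip` call, a later install can silently upgrade or downgrade something pinned earlier. One way to confirm the pins took effect is to compare them against `importlib.metadata`. The sketch below checks a hand-picked subset; the `EXPECTED_VERSIONS` dictionary and the `find_version_mismatches` helper are illustrative, not part of the original notebook:

```python
from importlib.metadata import PackageNotFoundError, version

# Subset of the pins from the install cell above; extend as needed.
EXPECTED_VERSIONS = {
    "transformers": "4.51.3",
    "peft": "0.15.2",
    "torch": "2.5.1",
    "accelerate": "1.6.0",
}

def find_version_mismatches(expected):
    """Return {package: (expected, installed_or_None)} for every pin that did not take effect."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = version(package)
        except PackageNotFoundError:
            installed = None
        # Some torch builds carry a local suffix such as "+cu121"; compare only the public version.
        if installed is None or installed.split("+")[0] != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches

mismatches = find_version_mismatches(EXPECTED_VERSIONS)
if mismatches:
    for pkg, (wanted, got) in mismatches.items():
        print(f"⚠️ {pkg}: expected {wanted}, found {got}")
else:
    print("✅ All checked pins are installed")
```

If mismatches show up, installing all pins in a single `pip install` call lets the resolver see every constraint at once.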

## Step 2: Set up your environment

Configure authentication for both Google Cloud and Hugging Face:

In [None]:
import os
import sys

# Check if we're running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import auth
    auth.authenticate_user()
    print("✅ Authenticated via Colab")
else:
    # For Vertex AI Workbench or local environments
    print("ℹ️ Using Application Default Credentials")
    print("   If not authenticated, run: gcloud auth application-default login")

In [None]:
# Set your project ID and GCS bucket
PROJECT_ID = "YOUR_PROJECT_ID"  # @param {type:"string"}
GCS_BUCKET_NAME = f"{PROJECT_ID}-melanoma-dataset"  # @param {type:"string"}

# Set project
!gcloud config set project {PROJECT_ID}

# Verify project is set
!echo "Current project: $(gcloud config get-value project)"

print(f"\n📁 Using GCS bucket: gs://{GCS_BUCKET_NAME}")

In [None]:
# Set up Hugging Face authentication
# Get your token from: https://huggingface.co/settings/tokens
HF_TOKEN = "YOUR_HUGGING_FACE_TOKEN"  # @param {type:"string"}

import huggingface_hub
huggingface_hub.login(token=HF_TOKEN)
print("✅ Logged in to Hugging Face")

# Note: Gemma 3 and MedGemma are gated models; accept each model's terms of use on Hugging Face before downloading
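
Hard-coding a token in a notebook makes it easy to leak when the notebook is shared. One option, sketched below, is to fall back to the `HF_TOKEN` environment variable (which `huggingface_hub` also reads on its own) whenever the form field still holds the placeholder. The `resolve_hf_token` helper and `PLACEHOLDER_TOKEN` constant are hypothetical:

```python
import os

PLACEHOLDER_TOKEN = "YOUR_HUGGING_FACE_TOKEN"

def resolve_hf_token(form_value, env_var="HF_TOKEN"):
    """Prefer a real value from the form field; otherwise fall back to an environment variable."""
    if form_value and form_value != PLACEHOLDER_TOKEN:
        return form_value
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(
            f"No Hugging Face token found: fill in the form field or set the {env_var} environment variable."
        )
    return token
```

With this in place, the login line would read `huggingface_hub.login(token=resolve_hf_token(HF_TOKEN))`.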

## Step 3: Import libraries and configure settings

Import all necessary libraries and set up the evaluation environment:

In [None]:
# Standard library imports
import os
import json
import time
import re
import subprocess
import tempfile
import traceback
import collections.abc
from collections import defaultdict
from datetime import datetime

# Data science imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image

# Machine learning imports
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoModelForImageTextToText,
    AutoTokenizer,
    AutoProcessor
)
from peft import PeftModel
from sklearn.metrics import (
    confusion_matrix,
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_curve,
    roc_auc_score
)

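# Optional CUDA debugging settings. For these to take effect reliably, set them before
# the first CUDA call in the session (Step 1 already queried the GPU). Note that
# CUDA_LAUNCH_BLOCKING=1 makes kernel launches synchronous, which slows inference.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['TORCH_USE_CUDA_DSA'] = "1"
os.environ["PYTORCH_USE_CUDA_DSA"] = "1"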

# Set device and check capabilities
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🖥️ Using device: {device}")
if torch.cuda.is_available():
    print(f"  • GPU: {torch.cuda.get_device_name(0)}")
    print(f"  • Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")
    print(f"  • CUDA version: {torch.version.cuda}")
    print(f"  • BF16 support: {torch.cuda.is_bf16_supported()}")

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

print("\n✅ Environment configured successfully!")

## Step 4: Download evaluation data

Download the fine-tuned model files and test images from Google Cloud Storage:

In [None]:
# Create temporary directories
temp_dir = tempfile.mkdtemp()
tuned_model_dir = os.path.join(temp_dir, "tuned_model")
image_dir = os.path.join(temp_dir, "images")
os.makedirs(tuned_model_dir, exist_ok=True)
os.makedirs(image_dir, exist_ok=True)

print(f"📁 Created temporary directories:")
print(f"  • Tuned Model: {tuned_model_dir}")
print(f"  • Images: {image_dir}")

# Download the fine-tuned model files
print("\n📥 Downloading fine-tuned model files...")
!gsutil -m cp -r gs://{GCS_BUCKET_NAME}/tuned-models/* {tuned_model_dir} 2>/dev/null || echo "Model files may not exist yet"

# List downloaded model files
model_files = os.listdir(tuned_model_dir) if os.path.exists(tuned_model_dir) else []
if model_files:
    print(f"\n✅ Downloaded {len(model_files)} model files:")
    for file in sorted(model_files):
        file_path = os.path.join(tuned_model_dir, file)
        size_mb = os.path.getsize(file_path) / (1024 * 1024)
        print(f"  • {file} ({size_mb:.1f} MB)")
else:
    print("⚠️ No model files found. Make sure fine-tuning has completed.")

# Download test images
print("\n📥 Downloading test images...")
!gsutil -m cp "gs://{GCS_BUCKET_NAME}/processed_images/test/*.jpg" {image_dir} 2>/dev/null || echo "Test images may not be available"

# Count downloaded images
image_files = [f for f in os.listdir(image_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
print(f"\n📸 Downloaded {len(image_files)} test images")
if len(image_files) > 0:
    print(f"  • Sample images: {', '.join(sorted(image_files)[:5])}{'...' if len(image_files) > 5 else ''}")
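
`gsutil cp -r` preserves any subdirectories under `tuned-models/`, so if fine-tuning saved the adapter inside a checkpoint folder, `adapter_config.json` will not sit directly in `tuned_model_dir`, and Step 6 will report missing adapter files. A small search helper (hypothetical, not part of the original notebook) can locate the right directory. It returns the first match in sorted path order, so the root itself is checked first:

```python
import os

def find_adapter_dir(root):
    """Return the first directory under `root` holding a PEFT adapter config and weights, or None."""
    weight_names = ("adapter_model.safetensors", "adapter_model.bin")
    for dirpath, _dirnames, filenames in sorted(os.walk(root)):
        if "adapter_config.json" in filenames and any(name in filenames for name in weight_names):
            return dirpath
    return None
```

For example, `find_adapter_dir(tuned_model_dir)` could be passed as `gemma_tuned_model_path` when running the evaluation.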

## Step 5: Load ground truth labels

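Look up the ground-truth label for each test image in the ISIC 2020 training ground-truth CSV (`ISIC_2020_Training_GroundTruth_v2.csv`), which maps each `image_name` to a binary `target` (0 = benign, 1 = melanoma):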

In [None]:
def load_ground_truth_data():
    """
    Load ground truth data from the ISIC dataset.
    Returns a dictionary mapping image filenames to their labels (0=benign, 1=melanoma).
    """
    print("📊 Loading ground truth data...")

    # Try to download ground truth file from GCS
    gcs_path = f"gs://{GCS_BUCKET_NAME}/isic-challenge-data.s3.amazonaws.com/2020/ISIC_2020_Training_GroundTruth_v2.csv"
    local_path = os.path.join(temp_dir, "ISIC_2020_Training_GroundTruth_v2.csv")

    try:
        print(f"  • Downloading from: {gcs_path}")
        subprocess.run(["gsutil", "cp", gcs_path, local_path], check=True, capture_output=True)

        # Load the CSV
        gt_data = pd.read_csv(local_path)
        print(f"  ✅ Loaded {len(gt_data):,} ground truth labels")

        # Show data structure
        print("\n📋 Ground truth data structure:")
        print(f"  • Columns: {', '.join(gt_data.columns)}")
        print(f"\n  • Sample data:")
        display(gt_data.head())

        # Create mapping from image filename to label
        image_to_label = dict(zip(
            gt_data['image_name'].apply(lambda x: f"{x}.jpg"),
            gt_data['target']
        ))

        # Calculate statistics
        melanoma_count = gt_data['target'].sum()
        benign_count = len(gt_data) - melanoma_count

        print(f"\n📊 Dataset statistics:")
        print(f"  • Total images: {len(gt_data):,}")
        print(f"  • Benign: {benign_count:,} ({benign_count/len(gt_data)*100:.1f}%)")
        print(f"  • Melanoma: {melanoma_count:,} ({melanoma_count/len(gt_data)*100:.1f}%)")
        print(f"  • Class imbalance ratio: {benign_count/melanoma_count:.1f}:1")

        return image_to_label

    except Exception as e:
        print(f"❌ Error loading ground truth labels: {e}")
        print("   Evaluation will proceed without ground truth labels.")
        return {}

# Load ground truth labels
image_to_label = load_ground_truth_data()
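
The statistics above describe the whole CSV, not the subset you downloaded, and `calculate_metrics` later drops any image without a label. Checking label coverage and class balance for the actual test images up front avoids surprises: with a small sample such as the first 100 images, there may be only a handful of melanoma cases, so precision and recall will be noisy. A minimal sketch (the helper name is hypothetical):

```python
def summarize_label_coverage(image_files, labels):
    """Count how many test images have a ground-truth label, and how those labels split."""
    labeled = [labels[f] for f in image_files if f in labels]
    melanoma = sum(1 for y in labeled if y == 1)
    return {
        "total": len(image_files),
        "labeled": len(labeled),
        "unlabeled": len(image_files) - len(labeled),
        "melanoma": melanoma,
        "benign": len(labeled) - melanoma,
    }
```

In this notebook it would be called as `summarize_label_coverage(image_files, image_to_label)`.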

## Step 6: Define model loading functions

Define functions to load the base model, fine-tuned model, and MedGemma:

In [None]:
def load_models(base_model_id="google/gemma-3-4b-it", tuned_model_path=None):
    """
    Load base and fine-tuned multimodal models with their tokenizers and processors.
    Ensures the base model returned is pristine if a tuned model is also loaded.

    Args:
        base_model_id: Hugging Face model ID for the base model
        tuned_model_path: Local path to the fine-tuned model adapter files

    Returns:
        Dictionary with base_model, tuned_model, tokenizer, and processor
    """
    print(f"🤖 Loading models...")
    print(f"  • Base model ID: {base_model_id}")

    # Load tokenizer and processor
    print("\n📚 Loading tokenizer and processor...")
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    processor = AutoProcessor.from_pretrained(base_model_id)
    print("  ✅ Tokenizer and processor loaded")

    # Determine device and dtype
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

    print(f"\n🔧 Model configuration:")
    print(f"  • Device: {device}")
    print(f"  • Dtype: {model_dtype}")

    # Load the pristine base model
    print(f"\n📥 Loading pristine base model...")
    pristine_base_model = AutoModelForImageTextToText.from_pretrained(
        base_model_id,
        torch_dtype=model_dtype,
        device_map="auto"
    )
    pristine_base_model.eval()
    print("  ✅ Base model loaded successfully")

    # Load fine-tuned model if path provided
    loaded_tuned_model = None
    if tuned_model_path and os.path.exists(tuned_model_path):
        print(f"\n📥 Loading fine-tuned model from {tuned_model_path}...")

        # Check for adapter files
        adapter_config_path = os.path.join(tuned_model_path, "adapter_config.json")
        adapter_model_path = os.path.join(tuned_model_path, "adapter_model.safetensors")

        # Check alternative adapter model file
        if not os.path.exists(adapter_model_path):
            adapter_model_path_bin = os.path.join(tuned_model_path, "adapter_model.bin")
            if os.path.exists(adapter_model_path_bin):
                adapter_model_path = adapter_model_path_bin

        print(f"  • Adapter config exists: {os.path.exists(adapter_config_path)}")
        print(f"  • Adapter model exists: {os.path.exists(adapter_model_path)}")

        if os.path.exists(adapter_config_path) and os.path.exists(adapter_model_path):
            try:
                # Load a fresh base model instance for the adapter
                print("  • Loading fresh base model instance for adapter...")
                base_model_for_adapter = AutoModelForImageTextToText.from_pretrained(
                    base_model_id,
                    torch_dtype=model_dtype,
                    device_map="auto"
                )

                # Apply the adapter
                print("  • Applying fine-tuned adapter...")
                loaded_tuned_model = PeftModel.from_pretrained(
                    base_model_for_adapter,
                    tuned_model_path
                )
                loaded_tuned_model.eval()
                print("  ✅ Fine-tuned model loaded successfully")

            except Exception as e:
                print(f"  ❌ Error loading fine-tuned model: {e}")
                traceback.print_exc()
                if 'base_model_for_adapter' in locals():
                    del base_model_for_adapter
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
        else:
            print("  ⚠️ Missing adapter files, cannot load fine-tuned model")

    return {
        "base_model": pristine_base_model,
        "tuned_model": loaded_tuned_model,
        "tokenizer": tokenizer,
        "processor": processor
    }
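
`load_models` deliberately loads a second, separate base instance because `PeftModel.from_pretrained` injects the LoRA layers into the model it receives, which would otherwise alter the "pristine" base model. The cost is a second full copy of the weights. A back-of-the-envelope estimate (weights only, ignoring activations, the KV cache, and the small adapter) helps explain the A100 recommendation. The figure of roughly 4 billion parameters per checkpoint is a nominal assumption; the exact count differs slightly:

```python
# Bytes needed to store one parameter in each dtype.
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2}

def estimate_weight_memory_gib(num_params, dtype="bfloat16", copies=1):
    """Rough lower bound on GPU memory for model weights only (no activations or KV cache)."""
    return num_params * BYTES_PER_PARAM[dtype] * copies / 1024**3

# Assumption: ~4e9 parameters per 4B checkpoint.
# Base Gemma 3 + a separate base instance for the adapter + MedGemma = 3 copies.
print(f"{estimate_weight_memory_gib(4e9, 'bfloat16', copies=3):.1f} GiB")  # → 22.4 GiB
```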

## Step 7: Define inference functions

Define the regular-expression patterns used to classify each response as melanoma or benign, and the function that runs inference on a single image:

In [None]:
# Define classification patterns
POSITIVE_PATTERNS = [
    r"yes, this appears to be malignant melanoma",
    r"this appears to be malignant melanoma",
    r"appears to be malignant melanoma",
    r"it appears to be malignant melanoma",
    r"Based on.*this appears to be malignant melanoma",
    r"Based on.*it appears to be malignant melanoma"
]

NEGATIVE_PATTERNS = [
    r"does not appear to be malignant melanoma",
    r"no, this does not appear to be malignant melanoma",
    r"this does not appear to be malignant melanoma",
    r"it does not appear to be malignant melanoma"
]

def run_inference(model, tokenizer, processor, image_path, prompt, model_family="gemma3"):
    """
    Run inference on a single image with a given model.

    Args:
        model: The loaded model
        tokenizer: The model's tokenizer
        processor: The model's processor
        image_path: Path to the image file
        prompt: Text prompt for the model
        model_family: Model family identifier (gemma3, medgemma)

    Returns:
        Dictionary with success status, response, classification, and timing
    """
    image_filename = os.path.basename(image_path)

    try:
        # Load and prepare image
        image = Image.open(image_path).convert("RGB")

        # System prompt
        system_prompt = "You are a dermatology assistant that helps identify potential melanoma from skin lesion images."

        # Prepare messages in chat template format
        messages = [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt},
                    {"type": "image", "image": image}
                ]
            }
        ]

        # Apply chat template and tokenize
        inputs = processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt"
        )

        # Verify inputs are properly formatted
        if not isinstance(inputs, collections.abc.Mapping):
            return {
                "success": False,
                "error": f"Input preparation error: expected dictionary, got {type(inputs)}",
                "is_melanoma": None,
                "inference_time": 0
            }

        # Move inputs to model device
        inputs = inputs.to(model.device)

        # Run inference
        start_time = time.time()
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=500,
                do_sample=False  # greedy decoding; temperature is unused (and triggers a warning) when sampling is off
            )
        elapsed = time.time() - start_time

        # Decode response (only the generated part)
        input_len = inputs["input_ids"].shape[-1]
        response = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()

        # Classify response
        is_melanoma = None
        response_lower = response.lower()

        # Check positive patterns
        for pattern in POSITIVE_PATTERNS:
            if re.search(pattern, response_lower, re.IGNORECASE):
                is_melanoma = 1
                break

        # Check negative patterns if not already classified
        if is_melanoma is None:
            for pattern in NEGATIVE_PATTERNS:
                if re.search(pattern, response_lower, re.IGNORECASE):
                    is_melanoma = 0
                    break

        return {
            "success": True,
            "response": response,
            "is_melanoma": is_melanoma,
            "inference_time": elapsed
        }

    except Exception as e:
        print(f"❌ Error during inference on {image_filename}: {str(e)}")
        traceback.print_exc()
        return {
            "success": False,
            "error": str(e),
            "is_melanoma": None,
            "inference_time": 0
        }

print("✅ Inference functions defined")
print(f"\n📋 Classification patterns:")
print(f"  • Positive patterns: {len(POSITIVE_PATTERNS)}")
print(f"  • Negative patterns: {len(NEGATIVE_PATTERNS)}")
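
`run_inference` checks the positive patterns first, so a response that uses a positive phrase while reasoning but ends with the negative conclusion is classified as melanoma. Since the prompt asks the model to *conclude* with one of the two phrases, an alternative is to let the last matching phrase decide. The sketch below is a swapped-in variant, not what the notebook does; `classify_by_last_match` is hypothetical. It uses only the broadest pattern from each list, because every longer entry in `POSITIVE_PATTERNS` or `NEGATIVE_PATTERNS` contains it:

```python
import re

# Broadest phrases from the notebook's pattern lists ("appear", not "appears", in the negative form).
POSITIVE_RE = re.compile(r"appears to be malignant melanoma", re.IGNORECASE)
NEGATIVE_RE = re.compile(r"does not appear to be malignant melanoma", re.IGNORECASE)

def classify_by_last_match(response):
    """Return 1 or 0 based on whichever conclusion phrase occurs last, or None if neither occurs."""
    last_pos = max((m.start() for m in POSITIVE_RE.finditer(response)), default=-1)
    last_neg = max((m.start() for m in NEGATIVE_RE.finditer(response)), default=-1)
    if last_pos == -1 and last_neg == -1:
        return None
    return 1 if last_pos > last_neg else 0
```

The negative phrase uses "appear" rather than "appears", so the positive regex never fires inside a negative conclusion.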

## Step 8: Define evaluation functions

Define the main function to evaluate multiple models on the test dataset:

In [None]:
def evaluate_models_on_dataset(
    gemma_base_model_id="google/gemma-3-4b-it",
    gemma_tuned_model_path=None,
    medgemma_model_id_param=None,
    test_image_dir=None,
    num_images=0,
    results_dir="evaluation_results",
    model_processing_order=("base_gemma", "tuned_gemma", "medgemma"),
    specific_image_files_to_process=None,
    custom_prompts_map=None,
    reprocess_policy=None
):
    """
    Evaluate specified models on test images.

    Args:
        gemma_base_model_id: HuggingFace ID for base Gemma model
        gemma_tuned_model_path: Path to fine-tuned model adapter files
        medgemma_model_id_param: HuggingFace ID for MedGemma model
        test_image_dir: Directory containing test images
        num_images: Maximum number of images to process (0 for all)
        results_dir: Directory to save results
        model_processing_order: Order in which to process models
        specific_image_files_to_process: List of specific images to process
        custom_prompts_map: Custom prompts for each model
        reprocess_policy: Policy for reprocessing (None, "skip_if_exists", "overwrite")

    Returns:
        Dictionary of results for each model
    """
    print("🚀 Starting model evaluation...")
    os.makedirs(results_dir, exist_ok=True)

    if not test_image_dir or not os.path.exists(test_image_dir):
        raise ValueError(f"Test image directory '{test_image_dir}' not found")

    loaded_models_info = {}

    # Load Gemma 3 models
    if "base_gemma" in model_processing_order or "tuned_gemma" in model_processing_order:
        if gemma_base_model_id:
            print(f"\n--- Loading Gemma 3 models ---")
            gemma_models = load_models(gemma_base_model_id, gemma_tuned_model_path)

            if gemma_models.get("base_model"):
                loaded_models_info["base_gemma"] = {
                    "model": gemma_models["base_model"],
                    "tokenizer": gemma_models["tokenizer"],
                    "processor": gemma_models["processor"],
                    "family": "gemma3",
                    "name_for_log": "Base Gemma 3"
                }

            if gemma_models.get("tuned_model"):
                loaded_models_info["tuned_gemma"] = {
                    "model": gemma_models["tuned_model"],
                    "tokenizer": gemma_models["tokenizer"],
                    "processor": gemma_models["processor"],
                    "family": "gemma3",
                    "name_for_log": "Tuned Gemma 3"
                }

    # Load MedGemma if requested
    if "medgemma" in model_processing_order and medgemma_model_id_param:
        print(f"\n--- Loading MedGemma model ---")
        try:
            med_tokenizer = AutoTokenizer.from_pretrained(medgemma_model_id_param)
            med_processor = AutoProcessor.from_pretrained(medgemma_model_id_param)

            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.bfloat16 if device == "cuda" and torch.cuda.is_bf16_supported() else torch.float32

            med_model = AutoModelForImageTextToText.from_pretrained(
                medgemma_model_id_param,
                torch_dtype=dtype,
                device_map="auto"
            )
            med_model.eval()

            loaded_models_info["medgemma"] = {
                "model": med_model,
                "tokenizer": med_tokenizer,
                "processor": med_processor,
                "family": "medgemma",
                "name_for_log": "MedGemma"
            }
            print("  ✅ MedGemma loaded successfully")
        except Exception as e:
            print(f"  ❌ Error loading MedGemma: {e}")

    # Determine images to process
    if specific_image_files_to_process:
        actual_image_files = [f for f in specific_image_files_to_process
                             if os.path.exists(os.path.join(test_image_dir, f))]
        print(f"\n📸 Processing {len(actual_image_files)} specific images")
    else:
        all_images = sorted([f for f in os.listdir(test_image_dir)
                           if f.lower().endswith(('.jpg', '.jpeg', '.png'))])
        if num_images > 0:
            actual_image_files = all_images[:num_images]
        else:
            actual_image_files = all_images
        print(f"\n📸 Processing {len(actual_image_files)} images")

    # Default prompt
    default_prompt = (
        "This is a skin lesion image. Does this appear to be malignant melanoma? "
        "Please explain your reasoning and conclude with either "
        "'Yes, this appears to be malignant melanoma.' or "
        "'No, this does not appear to be malignant melanoma.'"
    )

    # Process images for each model
    current_run_results = defaultdict(list)

    for model_key in model_processing_order:
        if model_key not in loaded_models_info:
            print(f"\n⚠️ Skipping {model_key} (not loaded)")
            continue

        model_info = loaded_models_info[model_key]
        model_name = model_info["name_for_log"]

        print(f"\n--- Evaluating {model_name} ---")

        # Load existing results if using skip policy
        existing_results_map = {}
        if specific_image_files_to_process and reprocess_policy == "skip_if_exists":
            results_file = os.path.join(results_dir, f"{model_key}_results.json")
            if os.path.exists(results_file):
                try:
                    with open(results_file, 'r') as f:
                        existing_data = json.load(f)
                        existing_results_map = {item['image']: item for item in existing_data}
                    print(f"  • Loaded {len(existing_results_map)} existing results")
                except Exception as e:
                    print(f"  ⚠️ Could not load existing results: {e}")

        # Process each image
        for i, image_file in enumerate(actual_image_files):
            # Skip if exists and policy says so
            if (specific_image_files_to_process and
                reprocess_policy == "skip_if_exists" and
                image_file in existing_results_map):
                print(f"  • Skipping {image_file} (already processed)")
                current_run_results[model_key].append(existing_results_map[image_file])
                continue

            image_path = os.path.join(test_image_dir, image_file)
            ground_truth = image_to_label.get(image_file, None)

            # Get prompt
            prompt = default_prompt
            if custom_prompts_map and model_key in custom_prompts_map:
                prompt = custom_prompts_map[model_key]

            # Run inference
            print(f"  • Processing {image_file} ({i+1}/{len(actual_image_files)})...", end='')
            result = run_inference(
                model_info["model"],
                model_info["tokenizer"],
                model_info["processor"],
                image_path,
                prompt,
                model_family=model_info["family"]
            )

            result["image"] = image_file
            if ground_truth is not None:
                result["ground_truth"] = ground_truth

            current_run_results[model_key].append(result)

            # Print result
            if result["success"]:
                pred = "Melanoma" if result['is_melanoma'] == 1 else "Benign" if result['is_melanoma'] == 0 else "Uncertain"
                print(f" {pred} ({result['inference_time']:.1f}s)")
            else:
                print(f" Failed")

    # Save results
    print(f"\n💾 Saving results to {results_dir}...")
    for model_key, results in current_run_results.items():
        if not results:
            continue

        output_file = os.path.join(results_dir, f"{model_key}_results.json")

        # Merge with existing results if doing targeted reprocessing
        if specific_image_files_to_process:
            existing_map = {}
            if os.path.exists(output_file):
                try:
                    with open(output_file, 'r') as f:
                        existing_data = json.load(f)
                        existing_map = {item['image']: item for item in existing_data}
                except Exception as e:
                    print(f"  ⚠️ Could not merge with existing results: {e}")

            # Update with new results
            for result in results:
                existing_map[result['image']] = result

            final_results = list(existing_map.values())
        else:
            final_results = results

        # Save
        with open(output_file, 'w') as f:
            json.dump(final_results, f, indent=2)
        print(f"  • Saved {len(final_results)} results for {model_key}")

    return dict(current_run_results)

print("✅ Evaluation functions defined")
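
`evaluate_models_on_dataset` writes each `*_results.json` with a plain `open(..., 'w')`, so a kernel interrupted during the write can leave a truncated file that the `skip_if_exists` path then fails to parse. A common pattern, sketched here as an optional replacement for the save step (the `save_json_atomic` helper is hypothetical), is to write to a temporary file in the same directory and swap it into place with `os.replace`:

```python
import json
import os
import tempfile

def save_json_atomic(data, path):
    """Write JSON to a temp file next to `path`, then replace `path` in one step."""
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(data, f, indent=2)
        # Atomic when source and target are on the same filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        os.remove(tmp_path)
        raise
```

Inside the save loop, the `with open(output_file, 'w')` block could then become `save_json_atomic(final_results, output_file)`.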

## Step 9: Run the evaluation

Now let's run the evaluation on our test images. You can customize which models to evaluate and how many images to process:

In [None]:
# Configuration for evaluation
EVALUATE_MEDGEMMA = True  # @param {type:"boolean"}
MEDGEMMA_MODEL_ID = "google/medgemma-4b-it" if EVALUATE_MEDGEMMA else None  # @param {type:"string"}
NUM_IMAGES_TO_EVALUATE = 100  # @param {type:"integer"}
# Set to 0 to evaluate all images

# Define model processing order
model_order = ["base_gemma", "tuned_gemma"]
if EVALUATE_MEDGEMMA:
    model_order.append("medgemma")

print(f"📋 Evaluation configuration:")
print(f"  • Models to evaluate: {', '.join(model_order)}")
print(f"  • Images to process: {'All' if NUM_IMAGES_TO_EVALUATE == 0 else NUM_IMAGES_TO_EVALUATE}")
print(f"  • MedGemma ID: {MEDGEMMA_MODEL_ID if EVALUATE_MEDGEMMA else 'Not evaluating'}")

# Run evaluation
print("\n" + "="*60)
print("🚀 STARTING EVALUATION")
print("="*60)

evaluation_results = evaluate_models_on_dataset(
    gemma_base_model_id="google/gemma-3-4b-it",
    gemma_tuned_model_path=tuned_model_dir,
    medgemma_model_id_param=MEDGEMMA_MODEL_ID,
    test_image_dir=image_dir,
    num_images=NUM_IMAGES_TO_EVALUATE,
    model_processing_order=model_order
)

# Summary
print("\n" + "="*60)
print("✅ EVALUATION COMPLETE")
print("="*60)
print("\n📊 Results summary:")
for model_key, results in evaluation_results.items():
    successful = sum(1 for r in results if r.get('success', False))
    print(f"  • {model_key}: {successful}/{len(results)} successful inferences")
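
Success counts alone don't show how often a model answered without using either conclusion phrase. A slightly richer per-model summary can be built from the same `evaluation_results` dictionary; the helper below is hypothetical and only reads the keys `run_inference` returns:

```python
from statistics import mean

def summarize_predictions(results):
    """Count predictions per class and average latency over successful inferences."""
    successful = [r for r in results if r.get("success")]
    return {
        "melanoma": sum(1 for r in successful if r.get("is_melanoma") == 1),
        "benign": sum(1 for r in successful if r.get("is_melanoma") == 0),
        "uncertain": sum(1 for r in successful if r.get("is_melanoma") is None),
        "failed": len(results) - len(successful),
        "mean_seconds": mean(r["inference_time"] for r in successful) if successful else 0.0,
    }
```

For example: `for key, results in evaluation_results.items(): print(key, summarize_predictions(results))`.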

## Step 10: Post-process results

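Re-apply the classification patterns to any responses that were left unclassified (`is_melanoma` is `null`). With unchanged patterns this step finds nothing new, but if you refine `POSITIVE_PATTERNS` or `NEGATIVE_PATTERNS` after reading the responses, it updates the saved results without re-running inference: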

In [None]:
def process_results(results_file):
    """
    Process results file to fix any null classifications using pattern matching.

    Args:
        results_file: Path to the results JSON file

    Returns:
        List of processed results
    """
    print(f"\n📄 Processing results from {os.path.basename(results_file)}")

    if not os.path.exists(results_file):
        print(f"  ❌ File not found")
        return []

    try:
        with open(results_file, 'r') as f:
            results = json.load(f)
    except Exception as e:
        print(f"  ❌ Error reading file: {e}")
        return []

    print(f"  • Loaded {len(results)} results")

    # Fix null classifications
    fixed_count = 0
    for item in results:
        if item.get('is_melanoma') is None and 'response' in item:
            response_lower = item['response'].lower()

            # Check patterns
            for pattern in POSITIVE_PATTERNS:
                if re.search(pattern, response_lower, re.IGNORECASE):
                    item['is_melanoma'] = 1
                    fixed_count += 1
                    break

            if item.get('is_melanoma') is None:
                for pattern in NEGATIVE_PATTERNS:
                    if re.search(pattern, response_lower, re.IGNORECASE):
                        item['is_melanoma'] = 0
                        fixed_count += 1
                        break

    # Count classifications
    melanoma_count = sum(1 for item in results if item.get('is_melanoma') == 1)
    benign_count = sum(1 for item in results if item.get('is_melanoma') == 0)
    uncertain_count = sum(1 for item in results if item.get('is_melanoma') is None)

    print(f"  • Fixed {fixed_count} null classifications")
    print(f"  • Final predictions: {melanoma_count} melanoma, {benign_count} benign, {uncertain_count} uncertain")

    # Save processed results
    processed_file = results_file.replace('.json', '_processed.json')
    try:
        with open(processed_file, 'w') as f:
            json.dump(results, f, indent=2)
        print(f"  ✅ Saved processed results to {os.path.basename(processed_file)}")
    except Exception as e:
        print(f"  ❌ Error saving processed results: {e}")

    return results

# Process all results
print("\n" + "="*60)
print("📊 POST-PROCESSING RESULTS")
print("="*60)

results_dir = "evaluation_results"
model_keys = ["base_gemma", "tuned_gemma", "medgemma"]
all_processed_results = {}

for model_key in model_keys:
    results_file = os.path.join(results_dir, f"{model_key}_results.json")
    if os.path.exists(results_file):
        processed_data = process_results(results_file)
        all_processed_results[model_key] = processed_data
    else:
        all_processed_results[model_key] = []

# Extract processed results for each model
base_processed = all_processed_results.get("base_gemma", [])
tuned_processed = all_processed_results.get("tuned_gemma", [])
medgemma_processed = all_processed_results.get("medgemma", [])

print("\n✅ Post-processing complete")
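
The loop above tries `POSITIVE_PATTERNS` before `NEGATIVE_PATTERNS`, so a negated phrase such as "no evidence of melanoma" may be labeled positive if a positive pattern matches the bare word "melanoma". Whether that happens depends on the patterns defined earlier. The following is a minimal sketch of a negation-first ordering. It uses hypothetical example patterns rather than the notebook's own lists:

```python
import re
from typing import Optional

# Hypothetical patterns for illustration only; the notebook's POSITIVE_PATTERNS
# and NEGATIVE_PATTERNS are defined in an earlier step and may differ.
EXAMPLE_NEGATIVE_PATTERNS = [
    r"\bno (?:signs?|evidence) of melanoma\b",
    r"\bnot (?:a )?melanoma\b",
    r"\bbenign\b",
]
EXAMPLE_POSITIVE_PATTERNS = [
    r"\bmelanoma\b",
    r"\bmalignant\b",
]


def classify_response(response: Optional[str]) -> Optional[int]:
    """Return 1 (melanoma), 0 (benign), or None if no pattern matches.

    Negative patterns are checked first so that negated phrases such as
    "no evidence of melanoma" are not caught by the broader positive pattern.
    """
    if not isinstance(response, str):
        return None
    for pattern in EXAMPLE_NEGATIVE_PATTERNS:
        if re.search(pattern, response, re.IGNORECASE):
            return 0
    for pattern in EXAMPLE_POSITIVE_PATTERNS:
        if re.search(pattern, response, re.IGNORECASE):
            return 1
    return None


for text in ["Findings suggest melanoma.",
             "There is no evidence of melanoma.",
             "Unclear image quality."]:
    print(f"{text!r} -> {classify_response(text)}")
```

Checking negatives first trades one failure mode for another. A response that mentions "benign" anywhere wins, even if it also says "malignant". The right order therefore depends on how the models actually phrase their answers.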

## Step 11: Calculate performance metrics

Calculate comprehensive performance metrics for all models:

In [None]:
def calculate_metrics(results):
    """
    Calculate comprehensive performance metrics for evaluation results.

    Args:
        results: List of result dictionaries with predictions and ground truth

    Returns:
        Dictionary containing all calculated metrics
    """
    # Filter for valid results with both prediction and ground truth
    valid_items = [
        item for item in results
        if 'ground_truth' in item and item['ground_truth'] is not None
        and 'is_melanoma' in item and item['is_melanoma'] is not None
    ]

    if not valid_items:
        print("  ⚠️ No valid items with ground truth and predictions")
        return {}

    # Extract true labels and predictions
    y_true = [item['ground_truth'] for item in valid_items]
    y_pred = [item['is_melanoma'] for item in valid_items]

    # Calculate confusion matrix. Passing labels=[0, 1] makes scikit-learn return
    # a 2x2 matrix even when only one class appears in y_true or y_pred, so the
    # four counts can always be unpacked directly.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()

    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, pos_label=1, zero_division=0)
    recall = recall_score(y_true, y_pred, pos_label=1, zero_division=0)
    f1 = f1_score(y_true, y_pred, pos_label=1, zero_division=0)

    # Calculate specificity (true negative rate)
    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0

    # Calculate balanced accuracy
    balanced_accuracy = (recall + specificity) / 2

    # Calculate additional metrics
    total_positives = tp + fn
    total_negatives = tn + fp
    prevalence = total_positives / len(valid_items) if len(valid_items) > 0 else 0

    metrics = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "balanced_accuracy": balanced_accuracy,
        "tn": int(tn),
        "fp": int(fp),
        "fn": int(fn),
        "tp": int(tp),
        "total_samples": len(valid_items),
        "total_positives": int(total_positives),
        "total_negatives": int(total_negatives),
        "prevalence": prevalence,
        "y_true_list_for_roc": y_true,
        "y_pred_list_for_roc": y_pred
    }

    return metrics

def print_model_metrics(model_name, metrics_dict):
    """
    Print formatted metrics for a model.
    """
    if not metrics_dict:
        print(f"\n{model_name}: No metrics available")
        return

    print(f"\n📊 {model_name} Performance")
    print(f"   Evaluated on {metrics_dict.get('total_samples', 0)} samples")
    print(f"   Class distribution: {metrics_dict.get('total_positives', 0)} positive, {metrics_dict.get('total_negatives', 0)} negative")
    print("\n   Metrics:")
    print(f"   • Accuracy:           {metrics_dict.get('accuracy', 0):.4f}")
    print(f"   • Precision:          {metrics_dict.get('precision', 0):.4f}")
    print(f"   • Recall (Sensitivity): {metrics_dict.get('recall', 0):.4f}")
    print(f"   • Specificity:        {metrics_dict.get('specificity', 0):.4f}")
    print(f"   • F1 Score:           {metrics_dict.get('f1', 0):.4f}")
    print(f"   • Balanced Accuracy:  {metrics_dict.get('balanced_accuracy', 0):.4f}")

    print("\n   Confusion Matrix:")
    print(f"   • True Negatives:  {metrics_dict.get('tn', 0)}")
    print(f"   • False Positives: {metrics_dict.get('fp', 0)}")
    print(f"   • False Negatives: {metrics_dict.get('fn', 0)}")
    print(f"   • True Positives:  {metrics_dict.get('tp', 0)}")

# Calculate metrics for all models
print("\n" + "="*60)
print("📈 CALCULATING PERFORMANCE METRICS")
print("="*60)

base_gemma_metrics = calculate_metrics(base_processed)
tuned_gemma_metrics = calculate_metrics(tuned_processed)
medgemma_metrics = calculate_metrics(medgemma_processed)

# Print metrics for each model
print_model_metrics("Base Gemma 3", base_gemma_metrics)
print_model_metrics("Fine-tuned Gemma 3", tuned_gemma_metrics)
print_model_metrics("MedGemma", medgemma_metrics)

print("\n" + "="*60)
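
Every metric above is a ratio of the four confusion-matrix counts. As a worked check, the sketch below recomputes them in plain Python from hypothetical counts (not results from this notebook). It mirrors the `zero_division=0` behavior by returning 0 when a denominator is empty:

```python
from dataclasses import dataclass


@dataclass
class ConfusionCounts:
    tn: int
    fp: int
    fn: int
    tp: int


def safe_div(num: float, den: float) -> float:
    # Return 0.0 for an empty denominator, like zero_division=0
    return num / den if den > 0 else 0.0


def metrics_from_counts(c: ConfusionCounts) -> dict:
    total = c.tn + c.fp + c.fn + c.tp
    precision = safe_div(c.tp, c.tp + c.fp)
    recall = safe_div(c.tp, c.tp + c.fn)          # sensitivity
    specificity = safe_div(c.tn, c.tn + c.fp)     # true negative rate
    f1 = safe_div(2 * precision * recall, precision + recall)
    return {
        "accuracy": safe_div(c.tp + c.tn, total),
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1": f1,
        "balanced_accuracy": (recall + specificity) / 2,
    }


# Hypothetical counts, not results from this notebook
example = metrics_from_counts(ConfusionCounts(tn=90, fp=10, fn=5, tp=15))
for name, value in example.items():
    print(f"{name:>18}: {value:.4f}")
```

With these counts, accuracy is 105/120 = 0.875. Balanced accuracy is (0.75 + 0.90) / 2 = 0.825 because it gives the 20 positive cases the same weight as the 100 negative ones.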

## Step 12: Visualize performance comparisons

Create comprehensive visualizations to compare model performance:

In [None]:
def visualize_performance_comparison(metrics_map):
    """
    Create comprehensive visualizations comparing model performance.
    """
    if not metrics_map or not any(metrics_map.values()):
        print("⚠️ No valid metrics to visualize")
        return

    # Filter out empty metrics
    valid_metrics_map = {name: m for name, m in metrics_map.items() if m}
    if not valid_metrics_map:
        print("⚠️ All metrics are empty")
        return

    model_names = list(valid_metrics_map.keys())
    n_models = len(model_names)

    # Color schemes
    bar_colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6']

    # Set style
    plt.style.use('seaborn-v0_8-whitegrid')

    # 1. Performance Metrics Comparison
    print("\n📊 Creating performance comparison chart...")
    metrics_to_plot = ['accuracy', 'precision', 'recall', 'specificity', 'f1', 'balanced_accuracy']
    metric_labels = ['Accuracy', 'Precision', 'Recall', 'Specificity', 'F1 Score', 'Balanced\nAccuracy']

    fig_width = max(12, len(metrics_to_plot) * 2.5)
    plt.figure(figsize=(fig_width, 8))

    x = np.arange(len(metrics_to_plot))
    bar_width = 0.8 / n_models

    # Plot bars for each model
    for i, model_name in enumerate(model_names):
        model_metrics = valid_metrics_map[model_name]
        values = [model_metrics.get(m, 0) for m in metrics_to_plot]

        offset = (i - (n_models - 1) / 2) * bar_width
        bars = plt.bar(x + offset, values, bar_width,
                       label=model_name,
                       color=bar_colors[i % len(bar_colors)],
                       edgecolor='black',
                       linewidth=0.7)

        # Add value labels on bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{height:.3f}',
                    ha='center', va='bottom', fontsize=10, fontweight='bold')

    # Add percentage improvements for 2-model comparison
    if n_models == 2 and 'Base Gemma 3' in model_names:
        base_idx = model_names.index('Base Gemma 3')
        other_idx = 1 - base_idx
        base_metrics = valid_metrics_map[model_names[base_idx]]
        other_metrics = valid_metrics_map[model_names[other_idx]]

        for i, metric in enumerate(metrics_to_plot):
            base_val = base_metrics.get(metric, 0)
            other_val = other_metrics.get(metric, 0)

            if base_val > 0:
                pct_change = ((other_val - base_val) / base_val) * 100
                color = 'green' if pct_change > 0 else 'red'
                symbol = '↑' if pct_change > 0 else '↓'

                plt.text(x[i], max(base_val, other_val) + 0.08,
                        f'{symbol}{abs(pct_change):.1f}%',
                        ha='center', va='bottom',
                        color=color, fontsize=11, fontweight='bold')

    plt.xlabel('Performance Metric', fontsize=14, fontweight='bold')
    plt.ylabel('Score', fontsize=14, fontweight='bold')
    plt.title('Model Performance Comparison', fontsize=16, fontweight='bold', pad=20)
    plt.xticks(x, metric_labels, fontsize=12)
    plt.yticks(fontsize=11)
    plt.ylim(0, 1.3)
    plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15),
              ncol=min(n_models, 3), fontsize=12, frameon=True, fancybox=True)
    plt.grid(True, linestyle=':', alpha=0.6)
    plt.tight_layout()

    os.makedirs(results_dir, exist_ok=True)
    plt.savefig(os.path.join(results_dir, 'performance_comparison.png'), dpi=300, bbox_inches='tight')
    plt.show()

    # 2. Confusion Matrices
    print("\n📊 Creating confusion matrices...")
    if n_models == 1:
        fig, axes = plt.subplots(1, 1, figsize=(6, 5))
        axes = [axes]
    elif n_models == 2:
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    elif n_models == 3:
        fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    else:
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        axes = axes.ravel()

    fig.suptitle('Confusion Matrices', fontsize=16, fontweight='bold', y=1.02)

    cmap_names = ['Blues', 'Oranges', 'Greens', 'Reds']

    for idx, model_name in enumerate(model_names):
        if idx >= len(axes):
            break

        ax = axes[idx]
        metrics = valid_metrics_map[model_name]

        # Create confusion matrix
        cm_values = np.array([
            [metrics.get('tn', 0), metrics.get('fp', 0)],
            [metrics.get('fn', 0), metrics.get('tp', 0)]
        ])

        # Plot heatmap
        sns.heatmap(cm_values, annot=True, fmt='d',
                   cmap=cmap_names[idx % len(cmap_names)],
                   cbar=True, ax=ax,
                   annot_kws={'size': 14, 'weight': 'bold'},
                   xticklabels=['Benign', 'Melanoma'],
                   yticklabels=['Benign', 'Melanoma'])

        ax.set_xlabel('Predicted', fontsize=12, fontweight='bold')
        ax.set_ylabel('Actual', fontsize=12, fontweight='bold')
        ax.set_title(f'{model_name}\n({metrics.get("total_samples", 0)} samples)',
                    fontsize=14, fontweight='bold')

    # Remove extra subplots if any
    for i in range(n_models, len(axes)):
        fig.delaxes(axes[i])

    plt.tight_layout()
    plt.savefig(os.path.join(results_dir, 'confusion_matrices.png'), dpi=300, bbox_inches='tight')
    plt.show()

def visualize_roc_curves(metrics_map):
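    """
    Plot ROC curves from each model's hard 0/1 predictions.

    Because the predictions are binary labels rather than scores, each curve has
    a single operating point between (0, 0) and (1, 1), and its AUC equals the
    model's balanced accuracy.
    """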
    print("\n📊 Creating ROC curves...")

    valid_roc_data = {}
    for model_name, metrics in metrics_map.items():
        if (metrics and
            'y_true_list_for_roc' in metrics and
            'y_pred_list_for_roc' in metrics and
            len(metrics['y_true_list_for_roc']) > 0):
            valid_roc_data[model_name] = (
                metrics['y_true_list_for_roc'],
                metrics['y_pred_list_for_roc']
            )

    if not valid_roc_data:
        print("  ⚠️ No valid data for ROC curves")
        return

    plt.figure(figsize=(10, 8))
    colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']

    for i, (model_name, (y_true, y_pred)) in enumerate(valid_roc_data.items()):
        y_true_np = np.array(y_true)
        y_pred_np = np.array(y_pred)

        if len(np.unique(y_true_np)) < 2:
            print(f"  ⚠️ Skipping {model_name}: only one class in ground truth")
            continue

        # Calculate ROC curve
        fpr, tpr, _ = roc_curve(y_true_np, y_pred_np)
        auc_score = roc_auc_score(y_true_np, y_pred_np)

        # Plot
        plt.plot(fpr, tpr,
                label=f'{model_name} (AUC = {auc_score:.3f})',
                color=colors[i % len(colors)],
                linewidth=2.5,
                marker='o' if len(fpr) < 10 else None,
                markersize=8 if len(fpr) < 10 else 0)

    # Add diagonal reference line
    plt.plot([0, 1], [0, 1], 'k--', linewidth=2, label='Random (AUC = 0.500)')

    plt.xlim([-0.01, 1.01])
    plt.ylim([-0.01, 1.01])
    plt.xlabel('False Positive Rate', fontsize=14, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=14, fontweight='bold')
    plt.title('Receiver Operating Characteristic (ROC) Curves', fontsize=16, fontweight='bold')
    plt.legend(loc='lower right', fontsize=12, frameon=True, fancybox=True)
    plt.grid(True, linestyle=':', alpha=0.7)
    plt.tight_layout()

    plt.savefig(os.path.join(results_dir, 'roc_curves.png'), dpi=300, bbox_inches='tight')
    plt.show()

print("✅ Visualization functions defined")
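
In `visualize_performance_comparison`, each model's bars share 80% of the space around a metric tick. Model `i` is shifted by `(i - (n_models - 1) / 2) * bar_width`, which keeps each group centered on its label. The sketch below is a standalone version of that arithmetic. `bar_offsets` is a hypothetical helper, not part of the notebook:

```python
def bar_offsets(n_models: int, group_width: float = 0.8) -> list[float]:
    """Horizontal offset of each model's bar from the metric tick."""
    bar_width = group_width / n_models
    # Symmetric around 0, so the group is centered on the tick
    return [(i - (n_models - 1) / 2) * bar_width for i in range(n_models)]


for n in (1, 2, 3):
    print(n, [round(o, 4) for o in bar_offsets(n)])
```

Because the group width stays below the tick spacing of 1, adjacent metric groups never overlap.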

In [None]:
# Create visualizations
print("\n" + "="*60)
print("📊 CREATING VISUALIZATIONS")
print("="*60)

# Prepare metrics for visualization
metrics_for_viz = {}
if base_gemma_metrics:
    metrics_for_viz["Base Gemma 3"] = base_gemma_metrics
if tuned_gemma_metrics:
    metrics_for_viz["Fine-tuned Gemma 3"] = tuned_gemma_metrics
if medgemma_metrics:
    metrics_for_viz["MedGemma"] = medgemma_metrics

if metrics_for_viz:
    visualize_performance_comparison(metrics_for_viz)
    visualize_roc_curves(metrics_for_viz)
    print("\n✅ Visualizations saved to evaluation_results/")
else:
    print("\n⚠️ No valid metrics available for visualization")
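
The ROC curves are built from hard 0/1 labels rather than probability scores, so each curve has only one interior point. The area under that three-point curve reduces to (TPR + TNR) / 2, which is balanced accuracy. The following pure-Python sketch illustrates this with made-up labels:

```python
def roc_points_from_binary(y_true: list[int], y_pred: list[int]) -> list[tuple[float, float]]:
    """(FPR, TPR) points for hard 0/1 predictions: both endpoints plus one operating point."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    if tp + fn == 0 or fp + tn == 0:
        raise ValueError("ROC needs both classes in y_true")
    return [(0.0, 0.0), (fp / (fp + tn), tp / (tp + fn)), (1.0, 1.0)]


def trapezoid_auc(points: list[tuple[float, float]]) -> float:
    # Sum the trapezoid areas between consecutive (x, y) points
    return sum((x1 - x0) * (y0 + y1) / 2 for (x0, y0), (x1, y1) in zip(points, points[1:]))


# Made-up labels: tn=4, fp=2, fn=1, tp=3
y_true = [0, 0, 0, 0, 1, 1, 1, 0, 1, 0]
y_pred = [0, 1, 0, 0, 1, 0, 1, 0, 1, 1]
points = roc_points_from_binary(y_true, y_pred)
print(points, round(trapezoid_auc(points), 4))
```

Here the AUC and balanced accuracy are both (0.75 + 4/6) / 2 ≈ 0.7083. A genuine curve would need per-image scores, such as a model's probability for "melanoma", instead of parsed labels.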

## Step 13: Calculate and visualize improvements

Quantify how each model's metrics change relative to a reference model, including the gains from fine-tuning:

In [None]:
def calculate_improvements(reference_metrics, current_metrics):
    """
    Calculate improvement percentages between two sets of metrics.
    """
    if not reference_metrics or not current_metrics:
        return {}

    improvements = {}
    metrics_to_compare = ['accuracy', 'precision', 'recall', 'specificity', 'f1', 'balanced_accuracy']

    for metric in metrics_to_compare:
        if metric in reference_metrics and metric in current_metrics:
            ref_val = reference_metrics[metric]
            curr_val = current_metrics[metric]

            if ref_val != 0:
                pct_improvement = ((curr_val - ref_val) / abs(ref_val)) * 100
            elif curr_val > 0:
                pct_improvement = float('inf')
            else:
                pct_improvement = 0

            improvements[metric] = {
                "reference_value": ref_val,
                "current_value": curr_val,
                "absolute_improvement": curr_val - ref_val,
                "percentage_improvement": pct_improvement
            }

    return improvements

def print_improvement_summary(improvements, ref_name, curr_name):
    """
    Print a formatted summary of improvements.
    """
    if not improvements:
        print(f"\n⚠️ No improvement data for {curr_name} vs {ref_name}")
        return

    print(f"\n📈 Performance Improvements: {curr_name} vs {ref_name}")
    print("=" * 70)

    # Find best and worst improvements
    best_metric = max(improvements.items(),
                     key=lambda x: x[1]['percentage_improvement'] if x[1]['percentage_improvement'] != float('inf') else 0)
    worst_metric = min(improvements.items(),
                      key=lambda x: x[1]['percentage_improvement'] if x[1]['percentage_improvement'] != float('inf') else 0)

    print(f"\n📊 Summary:")
    print(f"   • Best improvement: {best_metric[0]} ({best_metric[1]['percentage_improvement']:.1f}%)")
    print(f"   • Least improvement: {worst_metric[0]} ({worst_metric[1]['percentage_improvement']:.1f}%)")

    print(f"\n📋 Detailed improvements:")
    for metric, values in improvements.items():
        print(f"\n   {metric.replace('_', ' ').title()}:")
        print(f"     • {ref_name}: {values['reference_value']:.4f}")
        print(f"     • {curr_name}: {values['current_value']:.4f}")
        print(f"     • Change: {values['absolute_improvement']:+.4f}", end="")

        if values['percentage_improvement'] == float('inf'):
            print(f" (∞% - from zero to positive)")
        else:
            print(f" ({values['percentage_improvement']:+.1f}%)")

def visualize_improvements(improvements, ref_name, curr_name):
    """
    Create a bar chart showing percentage improvements.
    """
    if not improvements:
        return

    metrics = list(improvements.keys())
    percentages = []

    # Handle infinite improvements for visualization
    max_finite = 0
    for m in metrics:
        val = improvements[m]['percentage_improvement']
        if val != float('inf') and val != float('-inf'):
            max_finite = max(max_finite, abs(val))

    cap_value = max(max_finite * 1.2, 100) if max_finite > 0 else 100

    for m in metrics:
        val = improvements[m]['percentage_improvement']
        if val == float('inf'):
            percentages.append(cap_value)
        elif val == float('-inf'):
            percentages.append(-cap_value)
        else:
            percentages.append(val)

    # Create figure
    plt.figure(figsize=(10, 8))

    # Create horizontal bar chart
    colors = ['green' if p > 0 else 'red' if p < 0 else 'gray' for p in percentages]
    metric_labels = [m.replace('_', ' ').title() for m in metrics]

    bars = plt.barh(metric_labels, percentages, color=colors, edgecolor='black', linewidth=0.7)

    # Add value labels
    for i, (bar, orig_val) in enumerate(zip(bars, [improvements[m]['percentage_improvement'] for m in metrics])):
        width = bar.get_width()

        if orig_val == float('inf'):
            label = '+∞%'
        elif orig_val == float('-inf'):
            label = '-∞%'
        else:
            label = f'{orig_val:.1f}%'

        ha = 'left' if width >= 0 else 'right'
        offset = 3 if width >= 0 else -3

        plt.text(width + offset, bar.get_y() + bar.get_height()/2,
                label, va='center', ha=ha, fontweight='bold')

    plt.axvline(x=0, color='black', linestyle='-', linewidth=1)
    plt.xlabel('Percentage Improvement (%)', fontsize=14, fontweight='bold')
    plt.title(f'Performance Improvements: {curr_name} vs {ref_name}',
             fontsize=16, fontweight='bold')
    plt.grid(True, axis='x', linestyle=':', alpha=0.7)
    plt.tight_layout()

    # Save figure
    filename = f"improvements_{curr_name.replace(' ', '_').lower()}_vs_{ref_name.replace(' ', '_').lower()}.png"
    plt.savefig(os.path.join(results_dir, filename), dpi=300, bbox_inches='tight')
    plt.show()

# Calculate and display improvements
print("\n" + "="*60)
print("📈 IMPROVEMENT ANALYSIS")
print("="*60)

# Tuned vs Base Gemma
if tuned_gemma_metrics and base_gemma_metrics:
    improvements_tuned = calculate_improvements(base_gemma_metrics, tuned_gemma_metrics)
    print_improvement_summary(improvements_tuned, "Base Gemma 3", "Fine-tuned Gemma 3")
    visualize_improvements(improvements_tuned, "Base Gemma 3", "Fine-tuned Gemma 3")

# MedGemma vs Base Gemma
if medgemma_metrics and base_gemma_metrics:
    improvements_med = calculate_improvements(base_gemma_metrics, medgemma_metrics)
    print_improvement_summary(improvements_med, "Base Gemma 3", "MedGemma")
    visualize_improvements(improvements_med, "Base Gemma 3", "MedGemma")

# MedGemma vs Tuned Gemma
if medgemma_metrics and tuned_gemma_metrics:
    improvements_med_vs_tuned = calculate_improvements(tuned_gemma_metrics, medgemma_metrics)
    print_improvement_summary(improvements_med_vs_tuned, "Fine-tuned Gemma 3", "MedGemma")
    visualize_improvements(improvements_med_vs_tuned, "Fine-tuned Gemma 3", "MedGemma")

print("\n" + "="*60)
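
`calculate_improvements` reports change relative to the reference score, so the same absolute gain can look very different depending on the baseline. The sketch below repeats its percentage formula as a standalone function and applies it to hypothetical scores:

```python
def relative_change(ref: float, curr: float) -> float:
    """Percentage change as in calculate_improvements (inf when ref is 0 and curr > 0)."""
    if ref != 0:
        return (curr - ref) / abs(ref) * 100
    return float("inf") if curr > 0 else 0.0


# Hypothetical scores, not results from this notebook
pairs = {
    "low baseline": (0.05, 0.80),
    "high baseline": (0.70, 0.80),
    "zero baseline": (0.00, 0.10),
}
for label, (ref, curr) in pairs.items():
    print(f"{label:>13}: absolute {curr - ref:+.2f}, relative {relative_change(ref, curr):+.1f}%")
```

A +0.75 absolute gain from a 0.05 baseline reports as +1500%. A +0.10 gain from 0.70 reports as only about +14%. Reading `absolute_improvement` alongside the percentage keeps these comparisons in proportion.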

## Summary and Conclusions

This notebook has demonstrated a comprehensive evaluation of multimodal AI models for melanoma detection:

### 🎯 Key Findings

1. **Base Model Behavior**: The base Gemma 3 model tends to over-diagnose melanoma, with high recall but poor specificity
2. **Fine-tuning Impact**: Domain-specific fine-tuning dramatically improves specificity and balanced accuracy
3. **Medical Domain Models**: MedGemma demonstrates strong baseline performance but can still benefit from task-specific fine-tuning

### 📊 Performance Insights

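- **Accuracy improvements** from base to fine-tuned models often exceed 1000%. Because this is a relative change, a gain that large implies the base model's accuracy starts below about 0.09, so read it alongside the absolute change
- **Specificity** (correctly identifying benign lesions) shows the most dramatic improvements
- **Balanced accuracy**, the mean of recall and specificity, is the best single measure of diagnostic capability because class imbalance does not inflate it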
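
To see why balanced accuracy is the safer summary, consider two degenerate classifiers on a hypothetical set with 2% melanoma prevalence. The first labels every lesion benign, and the second labels every lesion melanoma (the extreme form of over-diagnosis). The sketch below compares their scores:

```python
def accuracy_and_balanced(y_true: list[int], y_pred: list[int]) -> tuple[float, float]:
    """Return (accuracy, balanced accuracy) for 0/1 labels; assumes both classes are present."""
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    pos = sum(t == 1 for t in y_true)
    neg = len(y_true) - pos
    accuracy = (tp + tn) / len(y_true)
    balanced = (tp / pos + tn / neg) / 2   # mean of recall and specificity
    return accuracy, balanced


# Hypothetical imbalanced set: 2 melanoma cases among 100 lesions
y_true = [1] * 2 + [0] * 98
always_benign = [0] * 100
always_melanoma = [1] * 100

for name, preds in [("always benign", always_benign), ("always melanoma", always_melanoma)]:
    acc, bal = accuracy_and_balanced(y_true, preds)
    print(f"{name:>15}: accuracy {acc:.2f}, balanced accuracy {bal:.2f}")
```

Accuracy swings from 0.98 to 0.02 between these two useless classifiers. Balanced accuracy scores both at 0.50, the chance level.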

### 🚀 Next Steps

1. **Clinical Validation**: Work with medical professionals to validate model predictions
2. **Dataset Expansion**: Include more diverse skin types and lesion types
3. **Ensemble Methods**: Combine multiple models for improved reliability
4. **Explainability**: Implement visualization techniques to understand model decisions
5. **Production Deployment**: Create APIs and interfaces for clinical use
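
For the ensemble idea in step 3, the following is a minimal majority-vote sketch over the three models' 0/1 predictions. It assumes the per-model lists are aligned by image, which is a hypothetical setup; in practice you would join on an image identifier. It treats `None` (uncertain) as an abstention:

```python
from collections import Counter
from typing import Optional, Sequence


def majority_vote(predictions: Sequence[Optional[int]]) -> Optional[int]:
    """Most common 0/1 label among non-None votes; None on a tie or when all abstain."""
    votes = Counter(p for p in predictions if p is not None)
    if not votes:
        return None
    ranked = votes.most_common()
    if len(ranked) > 1 and ranked[0][1] == ranked[1][1]:
        return None  # tie between 0 and 1
    return ranked[0][0]


# Hypothetical aligned predictions from base Gemma 3, fine-tuned Gemma 3, and MedGemma
base = [1, 1, 0, None]
tuned = [0, 1, 0, 1]
med = [0, 1, 1, 0]
ensemble = [majority_vote(votes) for votes in zip(base, tuned, med)]
print(ensemble)
```

With three voters, a tie occurs only when one model abstains. Breaking ties toward melanoma would favor sensitivity, but that is a clinical choice rather than a technical one.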

### ⚠️ Important Considerations

- These models are for **research and educational purposes only**
- **Do not use** for actual medical diagnosis without proper validation
- Always consult qualified healthcare professionals for medical decisions
- Consider regulatory requirements (FDA, CE marking) for clinical deployment

### 🔗 Resources

- [Main Repository](https://github.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl)
- [Axolotl Documentation](https://github.com/axolotl-ai-cloud/axolotl)
- [SIIM-ISIC Challenge](https://www.kaggle.com/c/siim-isic-melanoma-classification)
- [Google Cloud Healthcare AI](https://cloud.google.com/solutions/healthcare-life-sciences)

---

Thank you for following this evaluation notebook. The combination of Google Cloud's infrastructure and Axolotl's fine-tuning framework enables powerful domain-specific AI applications that can make a real difference in healthcare and beyond.