# Multimodal Data Preparation for Melanoma Detection with Axolotl

<table align="left">
  <td style="text-align: center">
    <a href="https://colab.research.google.com/github/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/blob/main/data/DataPreparation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/colab-logo-32px.png" alt="Google Colaboratory logo"><br> Run in Colab
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2Fayoisio%2Fgke-multimodal-fine-tune-gemma-3-axolotl%2Fmain%2Fdata%2FDataPreparation.ipynb">
      <img width="32px" src="https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN" alt="Google Cloud Colab Enterprise logo"><br> Run in Colab Enterprise
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://github.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/blob/main/data/DataPreparation.ipynb">
      <img src="https://cloud.google.com/ml-engine/images/github-logo-32px.png" alt="GitHub logo"><br> View on GitHub
    </a>
  </td>
  <td style="text-align: center">
    <a href="https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl/main/data/DataPreparation.ipynb">
      <img src="https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32" alt="Vertex AI logo"><br> Open in Vertex AI Workbench
    </a>
  </td>
</table>

## Overview

This notebook demonstrates how to prepare the [SIIM-ISIC Melanoma Classification](https://challenge2020.isic-archive.com/) dataset for fine-tuning multimodal AI models using Axolotl on Google Kubernetes Engine (GKE). The process transforms over 33,000 dermoscopic images into the specific format required for multimodal fine-tuning.

### What you'll learn

- How to efficiently download and process large medical imaging datasets using Google Cloud Storage
- How to create stratified train/validation/test splits while maintaining class distribution
- How to format multimodal data for Axolotl's chat template format
- How to quantify class imbalance in medical imaging datasets

### Prerequisites

- Google Cloud Project with billing enabled
- Access to Google Cloud Storage
- Python 3.9+ environment (the pinned `numpy==1.26.4` does not support Python 3.8)
- Approximately 40GB of temporary storage for processing

### Time to complete

30-40 minutes (primarily due to downloading and processing 32GB of images)

---

## Introduction

The SIIM-ISIC dataset contains over 33,000 dermoscopic images of skin lesions with corresponding labels indicating whether each lesion is benign or malignant melanoma. The dataset was released as part of the ISIC 2020 Challenge to help improve melanoma detection algorithms. Each image comes with additional metadata including the patient's age, sex, and the anatomical site of the lesion.

Our goal is to transform this raw medical imaging dataset into the specific format required by Axolotl for multimodal fine-tuning. This involves:

1. **Downloading and exploring the dataset** - Understanding the data structure and class distribution
2. **Splitting the data** - Creating training (80%), validation (10%), and test (10%) sets while maintaining class distribution
3. **Processing images and labels** - Converting to Axolotl's chat template format
4. **Creating JSONL files** - Generating properly formatted files that maintain image-diagnosis relationships

The output will be a training-ready dataset that follows the [chat_template](https://docs.axolotl.ai/docs/multimodal.html#dataset-format) format for fine-tuning Gemma 3 on this melanoma classification task.
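
As a rough illustration of the target format, the sketch below builds one record with the same message structure that Step 8 writes to JSONL. The helper name `build_record` and the image path are placeholders invented for this example; only the `messages` layout mirrors the notebook.

```python
import json

SYSTEM_PROMPT = (
    "You are a dermatology assistant that helps identify potential melanoma "
    "from skin lesion images."
)

def build_record(image_path: str, is_malignant: bool) -> dict:
    """Build one chat-style training record (hypothetical helper for illustration)."""
    answer = (
        "Yes, this appears to be malignant melanoma."
        if is_malignant
        else "No, this does not appear to be malignant melanoma."
    )
    return {
        "messages": [
            {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": "Does this skin lesion appear to be malignant melanoma?"},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
    }

# Placeholder path: real paths depend on where the bucket is mounted.
record = build_record("/mnt/gcs/processed_images/train/EXAMPLE_IMAGE.jpg", is_malignant=False)
jsonl_line = json.dumps(record)  # a JSONL file holds one such JSON object per line
print(jsonl_line[:80])
```

Each line of the final JSONL files is one such object, so the image and its diagnosis always travel together in a single record.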

**⚠️ Note**: This notebook contains medical imagery. The content is intended for educational and research purposes only.

## Step 1: Install dependencies

First, let's install the required packages for data processing and Google Cloud Storage integration:

In [None]:
# Install required packages
!pip install google-cloud-storage numpy==1.26.4 pandas matplotlib seaborn scikit-learn tqdm -q

# Import to verify installation
import google.cloud.storage
print(f"✅ google-cloud-storage version: {google.cloud.storage.__version__}")

## Step 2: Set up your environment

Configure your Google Cloud authentication and project settings:

In [None]:
import os
import sys

# Check if we're running in Colab
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    from google.colab import auth
    auth.authenticate_user()
    print("✅ Authenticated via Colab")
else:
    # For Vertex AI Workbench or local environments
    print("ℹ️ Using Application Default Credentials")
    print("   If not authenticated, run: gcloud auth application-default login")

In [None]:
# Set your project ID
PROJECT_ID = "YOUR_PROJECT_ID"  # @param {type:"string"}

# Set project
!gcloud config set project {PROJECT_ID}

# Verify project is set
!echo "Current project: $(gcloud config get-value project)"

## Step 3: Import libraries and configure settings

Import all necessary libraries and set up configuration parameters:

In [None]:
import json
import os
import shutil
import tempfile
import warnings
import zipfile
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging

# Third-party imports
from google.cloud import storage
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore", category=FutureWarning)

# Configuration
# ⚠️ IMPORTANT: Update these values with your bucket name
GCS_BUCKET_NAME = f"{PROJECT_ID}-melanoma-dataset"  # @param {type:"string"}
SOURCE_BUCKET = GCS_BUCKET_NAME
DEST_BUCKET = GCS_BUCKET_NAME
SOURCE_FOLDER = "isic-challenge-data.s3.amazonaws.com/2020"
DEST_FOLDER = "axolotl-data"

print(f"📁 Using GCS bucket: gs://{GCS_BUCKET_NAME}")
print(f"📂 Source folder: {SOURCE_FOLDER}")
print(f"📂 Destination folder: {DEST_FOLDER}")

# Initialize Cloud Storage client
client = storage.Client(project=PROJECT_ID)
source_bucket = client.bucket(SOURCE_BUCKET)
dest_bucket = client.bucket(DEST_BUCKET)

# Create temp directories
os.makedirs("/tmp/train_images", exist_ok=True)
print("✅ Temporary directories created")

# Set thread pool parameters for parallel processing
MAX_WORKERS = 16  # Adjust based on your system's capabilities
print(f"🔧 Parallel processing with {MAX_WORKERS} workers")

### Verify bucket access

In [None]:
# Check if bucket exists and is accessible
try:
    source_bucket.reload()
    print(f"✅ Successfully connected to bucket: gs://{GCS_BUCKET_NAME}")
    print(f"   Location: {source_bucket.location}")
    print(f"   Storage class: {source_bucket.storage_class}")
except Exception as e:
    print(f"❌ Error accessing bucket: {e}")
    print(f"   Please ensure the bucket exists and you have access permissions")

## Step 4: Explore the dataset

Let's list and examine the files available in the SIIM-ISIC dataset:

In [None]:
print("📋 Listing files in the SIIM-ISIC dataset:")
print("=" * 60)

files = list(source_bucket.list_blobs(prefix=SOURCE_FOLDER))
for i, file in enumerate(files, 1):
    file_size_mb = file.size / (1024 * 1024)
    print(f"{i}. {file.name.split('/')[-1]} ({file_size_mb:.1f} MB)")

print(f"\n📊 Total files: {len(files)}")
total_size_gb = sum(f.size for f in files) / (1024 * 1024 * 1024)
print(f"💾 Total size: {total_size_gb:.1f} GB")

## Step 5: Download and visualize dataset metadata

Let's download the metadata CSV file and create comprehensive visualizations to understand the dataset characteristics:

In [None]:
# Download metadata CSV and examine structure
print("📥 Downloading metadata...")
metadata_blob = source_bucket.blob(f"{SOURCE_FOLDER}/ISIC_2020_Training_GroundTruth_v2.csv")
with tempfile.NamedTemporaryFile() as temp_file:
    metadata_blob.download_to_filename(temp_file.name)
    train_metadata = pd.read_csv(temp_file.name)

    # Print dataset statistics
    print("\n📊 Dataset Overview:")
    print(f"Total samples: {len(train_metadata):,}")
    print(f"\nColumns: {', '.join(train_metadata.columns)}")
    print(f"\nFirst few rows:")
    display(train_metadata.head())

    # Set plot style
    sns.set(style="whitegrid")
    plt.rcParams.update({'font.size': 12})

    # Figure 1: Key Distributions
    plt.figure(figsize=(18, 14))

    # 1. Anatomical Site Distribution
    plt.subplot(2, 2, 1)
    site_counts = train_metadata['anatom_site_general_challenge'].value_counts()
    ax = plt.barh(site_counts.index, site_counts.values, color=plt.cm.Reds(np.linspace(0.3, 0.7, len(site_counts))))
    plt.title('Anatomical Site Distribution', fontsize=16, pad=20)
    plt.xlabel('Count', fontsize=14)

    # Add count labels to bars
    for i, v in enumerate(site_counts.values):
        plt.text(v + 100, i, f'{v:,}', va='center')

    # 2. Age Distribution by Malignancy
    plt.subplot(2, 2, 2)
    ax = sns.boxplot(x='benign_malignant', y='age_approx', data=train_metadata,
                     palette={'benign': 'lightgreen', 'malignant': 'salmon'})
    plt.title('Age Distribution by Malignancy', fontsize=16, pad=20)
    plt.xlabel('Diagnosis Type', fontsize=14)
    plt.ylabel('Age', fontsize=14)

    # 3. Sex Distribution
    plt.subplot(2, 2, 3)
    sex_counts = train_metadata['sex'].value_counts()
    ax = plt.bar(sex_counts.index, sex_counts.values, color=plt.cm.Blues(np.linspace(0.4, 0.7, len(sex_counts))))
    plt.title('Sex Distribution', fontsize=16, pad=20)
    plt.ylabel('Count', fontsize=14)

    # Add count labels to bars
    for i, v in enumerate(sex_counts.values):
        plt.text(i, v + 100, f'{v:,}', ha='center')

    # 4. Benign vs Malignant Distribution
    plt.subplot(2, 2, 4)
    diagnosis_counts = train_metadata['benign_malignant'].value_counts()
    plt.pie(diagnosis_counts, labels=diagnosis_counts.index, autopct='%1.1f%%',
            colors=['lightgreen', 'salmon'], startangle=90, explode=(0, 0.1))
    plt.title('Benign vs Malignant Distribution', fontsize=16, pad=20)
    plt.axis('equal')

    # Add text annotations with counts
    plt.annotate(f"Benign: {diagnosis_counts['benign']:,}", xy=(-1.2, -0.8), fontsize=12)
    plt.annotate(f"Malignant: {diagnosis_counts['malignant']:,}", xy=(0.8, -0.8), fontsize=12)

    plt.tight_layout(pad=3.0)
    plt.show()

    # Figure 2: Relationships between variables
    plt.figure(figsize=(18, 14))

    # 1. Malignancy by Sex
    plt.subplot(2, 2, 1)
    sex_malig = pd.crosstab(train_metadata['sex'], train_metadata['benign_malignant'])
    ax = sex_malig.plot(kind='bar', color=['lightgreen', 'salmon'], ax=plt.gca())
    plt.title('Malignancy by Sex', fontsize=16, pad=20)
    plt.ylabel('Count', fontsize=14)
    plt.xlabel('Sex', fontsize=14)
    plt.xticks(rotation=0)
    plt.legend(title='Diagnosis')

    # Add text labels
    for p in ax.patches:
        ax.annotate(f'{int(p.get_height()):,}',
                   (p.get_x() + p.get_width() / 2., p.get_height()),
                   ha = 'center', va = 'bottom')

    # 2. Malignancy Rate by Sex
    plt.subplot(2, 2, 2)
    sex_malig_rate = pd.crosstab(train_metadata['sex'],
                               train_metadata['benign_malignant'],
                               normalize='index') * 100
    ax = sex_malig_rate.plot(kind='bar', color=['lightgreen', 'salmon'], ax=plt.gca())
    plt.title('Malignancy Rate by Sex (%)', fontsize=16, pad=20)
    plt.ylabel('Percentage', fontsize=14)
    plt.xlabel('Sex', fontsize=14)
    plt.xticks(rotation=0)
    plt.legend(title='Diagnosis')

    # Add text labels
    for p in ax.patches:
        ax.annotate(f"{p.get_height():.1f}%",
                   (p.get_x() + p.get_width() / 2., p.get_height() + 1),
                   ha = 'center')

    # 3. Malignancy Rate by Anatomical Site
    plt.subplot(2, 2, 3)
    # Filter out any NaN values and ensure proper calculation
    site_data = train_metadata.dropna(subset=['anatom_site_general_challenge'])
    site_malig = pd.crosstab(site_data['anatom_site_general_challenge'],
                           site_data['benign_malignant'],
                           normalize='index') * 100
    ax = site_malig.plot(kind='barh', color=['lightgreen', 'salmon'], ax=plt.gca())
    plt.title('Malignancy Rate by Anatomical Site (%)', fontsize=16, pad=20)
    plt.xlabel('Percentage', fontsize=14)
    plt.ylabel('Anatomical Site', fontsize=14)
    plt.legend(title='Diagnosis')

    # Add percentage labels for malignant cases
    for i, v in enumerate(site_malig['malignant']):
        plt.text(v + 0.5, i, f"{v:.1f}%", va='center')

    # 4. Missing Values Visualization
    plt.subplot(2, 2, 4)
    missing_values = train_metadata.isnull().sum().sort_values(ascending=False)
    # Only show columns with missing values
    missing_values = missing_values[missing_values > 0]
    if len(missing_values) > 0:
        ax = plt.barh(missing_values.index, missing_values.values,
                     color=plt.cm.Greys(np.linspace(0.3, 0.7, len(missing_values))))
        plt.title('Missing Values per Column', fontsize=16, pad=20)
        plt.xlabel('Count', fontsize=14)

        # Add count labels to bars
        for i, v in enumerate(missing_values.values):
            plt.text(v + 5, i, f'{v:,}', va='center')
    else:
        plt.text(0.5, 0.5, "No missing values!", ha='center', va='center', fontsize=14)
        plt.title('Missing Values Check', fontsize=16, pad=20)
        plt.axis('off')

    plt.tight_layout(pad=3.0)
    plt.show()

    # Figure 3: Additional visualizations
    plt.figure(figsize=(18, 8))

    # 1. Age distribution
    plt.subplot(1, 2, 1)
    ax = sns.histplot(train_metadata['age_approx'].dropna(), bins=15, kde=True)
    plt.title('Age Distribution', fontsize=16, pad=20)
    plt.xlabel('Age', fontsize=14)
    plt.ylabel('Count', fontsize=14)

    # Add statistics annotation
    age_mean = train_metadata['age_approx'].mean()
    age_median = train_metadata['age_approx'].median()
    plt.annotate(f"Mean: {age_mean:.1f}\nMedian: {age_median:.1f}",
                xy=(0.05, 0.95), xycoords='axes fraction',
                bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))

    # 2. Class imbalance (target variable)
    plt.subplot(1, 2, 2)
    target_counts = train_metadata['target'].value_counts()
    ax = plt.bar(['Non-Melanoma (0)', 'Melanoma (1)'], target_counts.values,
               color=['lightgreen', 'salmon'])
    plt.text(0, target_counts[0]/2, f"{target_counts[0]:,} samples",
           ha='center', va='center', color='black', fontweight='bold')
    plt.text(1, target_counts[1]/2, f"{target_counts[1]:,} samples",
           ha='center', va='center', color='black', fontweight='bold')
    plt.title('Class Distribution (Target Variable)', fontsize=16, pad=20)
    plt.ylabel('Count', fontsize=14)

    # Add imbalance ratio annotation
    imbalance_ratio = target_counts[0] / target_counts[1]
    plt.figtext(0.5, 0.01, f"Class imbalance ratio: {imbalance_ratio:.1f}:1 (Non-melanoma:Melanoma)",
              ha='center', fontsize=14, bbox={"facecolor":"orange", "alpha":0.2, "pad":5})

    plt.tight_layout(pad=3.0)
    plt.show()

    # Print additional textual information
    print("\n📊 Diagnosis distribution:")
    print(train_metadata['diagnosis'].value_counts())

    print("\n📊 Benign/Malignant distribution:")
    print(train_metadata['benign_malignant'].value_counts())

    # Check for missing values
    print("\n⚠️ Missing values per column:")
    missing_summary = train_metadata.isnull().sum()
    missing_summary = missing_summary[missing_summary > 0]
    if len(missing_summary) > 0:
        print(missing_summary)
    else:
        print("No missing values found!")

    # Class imbalance ratio
    imbalance_ratio = len(train_metadata[train_metadata['target'] == 0]) / len(train_metadata[train_metadata['target'] == 1])
    print(f"\n⚖️ Class imbalance ratio (Benign:Malignant): {imbalance_ratio:.2f}:1")
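
The cell above measures the class imbalance but does not act on it. One common mitigation is to weight classes inversely to their frequency. The sketch below uses the "balanced" heuristic `n_samples / (n_classes * class_count)`, the same formula scikit-learn documents for `class_weight="balanced"`. The label counts are made up for illustration, and whether a given training setup can consume such weights is an open assumption that this notebook does not cover.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Return {label: weight} using n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

# Illustrative labels only: 98 benign (0) and 2 melanoma (1)
toy_labels = [0] * 98 + [1] * 2
weights = balanced_class_weights(toy_labels)
print(weights)
```

In the notebook, the same function could be applied to `train_metadata['target'].tolist()` to see how strongly the rare melanoma class would be up-weighted.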

## Step 6: Download and extract image data

Now let's download the actual dermoscopic images. This is the most time-consuming step due to the large file size:

In [None]:
print("📥 Downloading training images ZIP (this may take 10-15 minutes)...")
train_zip_blob = source_bucket.blob(f"{SOURCE_FOLDER}/ISIC_2020_Training_JPEG.zip")

# Show download progress
with tempfile.NamedTemporaryFile(suffix=".zip") as temp_zip:
    # Download with progress tracking
    start_time = time.time()
    train_zip_blob.download_to_filename(temp_zip.name)
    download_time = time.time() - start_time

    print(f"✅ Download completed in {download_time/60:.1f} minutes")

    print("📦 Extracting training images...")
    extract_start = time.time()

    with zipfile.ZipFile(temp_zip.name, 'r') as zip_ref:
        # Get total number of files for progress tracking
        file_list = zip_ref.namelist()
        total_files = len(file_list)

        # Extract with progress bar
        for file in tqdm(file_list, desc="Extracting files"):
            zip_ref.extract(file, "/tmp/train_images")

    extract_time = time.time() - extract_start
    print(f"✅ Extraction completed in {extract_time/60:.1f} minutes")

# Count the number of image files extracted
image_files = [f for f in os.listdir("/tmp/train_images/train") if f.endswith('.jpg')]
image_count = len(image_files)
print(f"\n📸 Extracted {image_count:,} images to temporary directory")
print(f"💾 Total size: {sum(os.path.getsize(f'/tmp/train_images/train/{f}') for f in image_files) / (1024**3):.1f} GB")
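
Before moving on, it can be worth checking that every image named in the metadata was actually extracted, since later steps look images up by name. The helper below is a hypothetical sketch (the name `find_missing_images` and the toy file names are not part of the notebook) that compares a list of image names against the files in a directory.

```python
import os
import tempfile

def find_missing_images(image_names, image_dir, extension=".jpg"):
    """Return image names from the metadata that have no file in image_dir."""
    on_disk = {
        os.path.splitext(name)[0]
        for name in os.listdir(image_dir)
        if name.lower().endswith(extension)
    }
    return sorted(set(image_names) - on_disk)

# Toy demonstration with a throwaway directory and made-up names
with tempfile.TemporaryDirectory() as tmp_dir:
    for name in ("IMG_A", "IMG_B"):
        open(os.path.join(tmp_dir, f"{name}.jpg"), "wb").close()
    missing = find_missing_images(["IMG_A", "IMG_B", "IMG_C"], tmp_dir)
    print(missing)
```

In this notebook the call would look like `find_missing_images(train_metadata['image_name'], "/tmp/train_images/train")`, assuming the ZIP extracts into a `train/` folder as the cell above expects.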

## Step 7: Create stratified train/validation/test splits

We'll split the data into training (80%), validation (10%), and test (10%) sets while maintaining the class distribution:

In [None]:
print("🔀 Creating stratified train/validation/test splits...\n")

# First split: 80% train, 20% temp
train_df, temp_df = train_test_split(
    train_metadata,
    test_size=0.2,
    random_state=42,
    stratify=train_metadata['target']
)

# Second split: Split the temp set into validation and test (50/50)
val_df, test_df = train_test_split(
    temp_df,
    test_size=0.5,
    random_state=42,
    stratify=temp_df['target']
)

print("📊 Dataset splits:")
print(f"  • Training:   {len(train_df):,} examples ({len(train_df)/len(train_metadata)*100:.1f}%)")
print(f"  • Validation: {len(val_df):,} examples ({len(val_df)/len(train_metadata)*100:.1f}%)")
print(f"  • Test:       {len(test_df):,} examples ({len(test_df)/len(train_metadata)*100:.1f}%)")

# Verify class distribution in each split
print("\n⚖️ Class distribution (target=1 is melanoma):")
for split_name, split_df in [("Training", train_df), ("Validation", val_df), ("Test", test_df)]:
    counts = split_df['target'].value_counts().sort_index()
    melanoma_pct = counts[1] / len(split_df) * 100
    print(f"\n{split_name} set:")
    print(f"  • Benign (0):    {counts[0]:,} ({counts[0]/len(split_df)*100:.1f}%)")
    print(f"  • Melanoma (1):  {counts[1]:,} ({melanoma_pct:.1f}%)")

# Visualize the splits
plt.figure(figsize=(12, 6))

splits_data = {
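    # sort_index() keeps class 0 first so the bars line up with the tick labels
    'Training': train_df['target'].value_counts().sort_index(),
    'Validation': val_df['target'].value_counts().sort_index(),
    'Test': test_df['target'].value_counts().sort_index()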
}

x = np.arange(2)
width = 0.25

for i, (split_name, counts) in enumerate(splits_data.items()):
    plt.bar(x + i*width, counts.values, width, label=split_name)

plt.xlabel('Class')
plt.ylabel('Count')
plt.title('Class Distribution Across Splits')
plt.xticks(x + width, ['Benign (0)', 'Melanoma (1)'])
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()
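
To make the idea behind `stratify=` concrete, here is a standard-library sketch of stratified splitting: indices are grouped by label, shuffled within each group, and sliced by the requested fractions. It illustrates the principle only and is not scikit-learn's implementation. The function name `stratified_split` and the toy label counts are assumptions made for this example.

```python
import random
from collections import defaultdict

def stratified_split(labels, fractions=(0.8, 0.1, 0.1), seed=42):
    """Split indices into parts so each part keeps roughly the same label mix."""
    by_label = defaultdict(list)
    for idx, label in enumerate(labels):
        by_label[label].append(idx)

    rng = random.Random(seed)
    parts = [[] for _ in fractions]
    for indices in by_label.values():
        rng.shuffle(indices)
        start = 0
        for part_no, fraction in enumerate(fractions):
            # The last part takes whatever remains, so no index is lost to rounding
            if part_no == len(fractions) - 1:
                end = len(indices)
            else:
                end = start + round(len(indices) * fraction)
            parts[part_no].extend(indices[start:end])
            start = end
    return parts

# Toy labels: 950 negatives and 50 positives (made-up numbers)
toy_labels = [0] * 950 + [1] * 50
train_idx, val_idx, test_idx = stratified_split(toy_labels)
for name, idx in [("train", train_idx), ("val", val_idx), ("test", test_idx)]:
    positives = sum(toy_labels[i] for i in idx)
    print(f"{name}: {len(idx)} rows, {positives} positives")
```

Because each class is sliced separately, the rare positive class cannot end up concentrated in, or missing from, a single split, which is the property the two-stage `train_test_split` call relies on.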

## Step 8: Define data processing functions

Now we'll create functions to process the data into Axolotl's chat template format:

In [None]:
def process_single_example(row, split_name, dataset_type=None, include_metadata=False):
    """
    Process a single example from the dataset into Axolotl's chat template format.

    Args:
        row: Pandas DataFrame row with the data for this example
        split_name: String indicating the split type (train, val, test)
        dataset_type: String indicating the dataset type folder prefix (optional)
        include_metadata: Boolean indicating whether to include patient metadata

    Returns:
        Processed example in Axolotl's multimodal format
    """
    image_name = f"{row['image_name']}.jpg"

    # With GCS FUSE, the bucket itself is the mount point, so neither path includes the bucket name
    if dataset_type:
        # Path written to the JSONL, as seen through the GCS FUSE mount
        image_path = f"/mnt/gcs/processed_images/{dataset_type}/{split_name}/{image_name}"
        # Object name used for the upload (relative to the bucket root)
        upload_path = f"processed_images/{dataset_type}/{split_name}/{image_name}"
    else:
        # Path written to the JSONL, as seen through the GCS FUSE mount
        image_path = f"/mnt/gcs/processed_images/{split_name}/{image_name}"
        # Object name used for the upload (relative to the bucket root)
        upload_path = f"processed_images/{split_name}/{image_name}"

    # Upload image to processed_images folder
    image_local_path = f"/tmp/train_images/train/{image_name}"
    if os.path.exists(image_local_path):
        dest_blob = dest_bucket.blob(upload_path)
        try:
            dest_blob.upload_from_filename(image_local_path)
        except Exception as e:
            print(f"Failed to upload {image_name}: {str(e)}")
    else:
        # The record is still written below, so flag images that never reached GCS
        print(f"Local image not found, nothing uploaded: {image_local_path}")

    is_malignant = row['target'] == 1

    # Prepare the user prompt based on whether we include metadata
    if include_metadata:
        # Get metadata values (with error handling for missing or NaN values)
        sex = row['sex'] if pd.notna(row['sex']) else "unknown"
        age = row['age_approx'] if pd.notna(row['age_approx']) else "unknown"
        site = row['anatom_site_general_challenge'] if pd.notna(row['anatom_site_general_challenge']) else "unknown"

        user_text = f"This is a skin lesion from a {sex} patient, age {age}, located on the {site}. Does this appear to be malignant melanoma?"
    else:
        user_text = "Does this skin lesion appear to be malignant melanoma?"

    # Create the example in Axolotl's multimodal format
    example = {
        "messages": [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are a dermatology assistant that helps identify potential melanoma from skin lesion images."}
                ]
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "path": image_path},
                    {"type": "text", "text": user_text}
                ]
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": "Yes, this appears to be malignant melanoma." if is_malignant else
                                            "No, this does not appear to be malignant melanoma."}
                ]
            }
        ]
    }

    return example

def process_dataset_split(df, split_name, dataset_type=None, include_metadata=False):
    """
    Process a dataset split (train, validation, or test) and create examples.
    Uses parallel processing for efficiency.

    Args:
        df: Pandas DataFrame with the data for this split
        split_name: String indicating the split type (train, val, test)
        dataset_type: String indicating the dataset type folder prefix (optional)
        include_metadata: Boolean indicating whether to include metadata

    Returns:
        List of examples in Axolotl's multimodal format
    """
    print(f"\n🔄 Processing {len(df):,} examples for {split_name} split")
    # Pre-size the list so each result can be stored at its original row position
    examples = [None] * len(df)

    # Process examples in parallel
    with ThreadPoolExecutor(max_workers=MAX_WORKERS) as executor:
        # Map each future to its original position so output order matches the DataFrame
        future_to_idx = {
            executor.submit(
                process_single_example,
                row,
                split_name,
                dataset_type,
                include_metadata
            ): i for i, (_, row) in enumerate(df.iterrows())
        }

        # Track progress with tqdm
        with tqdm(total=len(df), desc=f"Processing {split_name} split") as pbar:
            for future in as_completed(future_to_idx):
                idx = future_to_idx[future]
                try:
                    examples[idx] = future.result()
                except Exception as e:
                    print(f"Error processing example {idx}: {str(e)}")
                pbar.update(1)

    # Drop failed examples while preserving the original order
    return [example for example in examples if example is not None]

def create_and_save_dataset(dataset_type=None, include_metadata=False):
    """
    Create and save a complete dataset (train, validation, test) in Axolotl format.

    Args:
        dataset_type: String name for the dataset type folder prefix (optional)
        include_metadata: Boolean indicating whether to include metadata

    Returns:
        Tuple of (train_count, val_count, test_count)
    """
    start_time = time.time()

    # Log dataset creation start
    if dataset_type:
        print(f"\n🚀 Starting to process {dataset_type} dataset...")
    else:
        print("\n🚀 Starting to process dataset...")

    # Create the datasets
    train_examples = process_dataset_split(train_df, "train", dataset_type, include_metadata)
    val_examples = process_dataset_split(val_df, "val", dataset_type, include_metadata)
    test_examples = process_dataset_split(test_df, "test", dataset_type, include_metadata)

    # Determine the folder path based on whether dataset_type is provided
    if dataset_type:
        print(f"\n💾 Writing {dataset_type} JSONL files...")
        folder_prefix = f"{DEST_FOLDER}/{dataset_type}"
    else:
        print("\n💾 Writing JSONL files...")
        folder_prefix = DEST_FOLDER

    # Write to JSONL files
    # Training set
    print(f"  • Writing training set ({len(train_examples):,} examples)...")
    train_output_file = f"{folder_prefix}/siim_isic_train.jsonl"
    train_blob = dest_bucket.blob(train_output_file)
    with train_blob.open("w") as f:
        for example in train_examples:
            f.write(json.dumps(example) + "\n")

    # Validation set
    print(f"  • Writing validation set ({len(val_examples):,} examples)...")
    val_output_file = f"{folder_prefix}/siim_isic_val.jsonl"
    val_blob = dest_bucket.blob(val_output_file)
    with val_blob.open("w") as f:
        for example in val_examples:
            f.write(json.dumps(example) + "\n")

    # Test set
    print(f"  • Writing test set ({len(test_examples):,} examples)...")
    test_output_file = f"{folder_prefix}/siim_isic_test.jsonl"
    test_blob = dest_bucket.blob(test_output_file)
    with test_blob.open("w") as f:
        for example in test_examples:
            f.write(json.dumps(example) + "\n")

    # Calculate and log total processing time
    total_time = time.time() - start_time
    hours, remainder = divmod(total_time, 3600)
    minutes, seconds = divmod(remainder, 60)
    print(f"\n⏱️ Dataset creation completed in {int(hours)}h {int(minutes)}m {int(seconds)}s")

    return len(train_examples), len(val_examples), len(test_examples)

print("✅ Data processing functions defined successfully")
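
A note on ordering: `as_completed` yields futures as they finish, not in the order they were submitted, so results gathered from it need to be placed by index if the JSONL rows should follow the DataFrame's row order. The toy sketch below shows that pattern with a made-up `slow_square` task standing in for an upload.

```python
import time
from concurrent.futures import ThreadPoolExecutor, as_completed

def slow_square(x):
    """Stand-in for an upload: later inputs sleep for less time."""
    time.sleep(0.05 * (3 - x))
    return x * x

inputs = [0, 1, 2]
results = [None] * len(inputs)

with ThreadPoolExecutor(max_workers=3) as executor:
    future_to_idx = {executor.submit(slow_square, x): i for i, x in enumerate(inputs)}
    for future in as_completed(future_to_idx):
        # Write each result into its original slot rather than appending
        results[future_to_idx[future]] = future.result()

print(results)
```

Storing by index keeps the output aligned with the input regardless of which task finishes first. `executor.map` would also preserve input order, at the cost of less direct per-task error handling.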

## Step 9: Create the multimodal dataset

Now let's process all the images and create the dataset in Axolotl's format. This step will:
1. Upload images to organized folders in GCS
2. Create JSONL files with proper chat template formatting

In [None]:
# Create the dataset (without patient metadata)
print("🏁 Starting dataset creation process...")
print("This will upload images to GCS and create JSONL files.")
print("Expected time: 15-20 minutes depending on your connection speed.\n")

train_count, val_count, test_count = create_and_save_dataset(include_metadata=False)

print("\n✅ Dataset created successfully!")
print("=" * 50)
print(f"📊 Final dataset statistics:")
print(f"  • Training examples:   {train_count:,}")
print(f"  • Validation examples: {val_count:,}")
print(f"  • Test examples:       {test_count:,}")
print(f"  • Total examples:      {train_count + val_count + test_count:,}")

# Print GCS paths for reference
print(f"\n📂 Dataset files created in GCS:")
print(f"  • gs://{GCS_BUCKET_NAME}/{DEST_FOLDER}/siim_isic_train.jsonl")
print(f"  • gs://{GCS_BUCKET_NAME}/{DEST_FOLDER}/siim_isic_val.jsonl")
print(f"  • gs://{GCS_BUCKET_NAME}/{DEST_FOLDER}/siim_isic_test.jsonl")
print(f"\n📸 Images uploaded to:")
print(f"  • gs://{GCS_BUCKET_NAME}/processed_images/train/")
print(f"  • gs://{GCS_BUCKET_NAME}/processed_images/val/")
print(f"  • gs://{GCS_BUCKET_NAME}/processed_images/test/")

# Note about metadata-enhanced dataset
print("\n💡 Note: You can also create a metadata-enhanced dataset by uncommenting the line below:")
print("# metadata_train_count, metadata_val_count, metadata_test_count = create_and_save_dataset('metadata', include_metadata=True)")
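
The `gs://` locations printed above and the `/mnt/gcs/...` paths stored in the JSONL records refer to the same objects. The sketch below shows the mapping, assuming the bucket root is mounted at `/mnt/gcs` as in Step 8. The helper names, the example object, and the bucket name are placeholders for illustration.

```python
import posixpath

def gcs_object_to_fuse_path(object_name, mount_point="/mnt/gcs"):
    """Map an object name inside the bucket to its path under the FUSE mount."""
    return posixpath.join(mount_point, object_name.lstrip("/"))

def fuse_path_to_gs_uri(fuse_path, bucket_name, mount_point="/mnt/gcs"):
    """Inverse mapping: rebuild the gs:// URI for a path under the mount."""
    relative = posixpath.relpath(fuse_path, mount_point)
    return f"gs://{bucket_name}/{relative}"

object_name = "processed_images/train/EXAMPLE_IMAGE.jpg"  # placeholder object
fuse_path = gcs_object_to_fuse_path(object_name)
gs_uri = fuse_path_to_gs_uri(fuse_path, "my-project-melanoma-dataset")  # placeholder bucket
print(fuse_path)
print(gs_uri)
```

If the training job mounts the bucket somewhere other than `/mnt/gcs`, the paths in the JSONL files would need to be regenerated with the matching mount point.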

## Step 10: Verify dataset creation

Let's verify that our dataset was created correctly:

In [None]:
print("🔍 Verifying dataset creation...\n")

# Check JSONL files
print("📄 JSONL files:")
for split in ['train', 'val', 'test']:
    blob_path = f"{DEST_FOLDER}/siim_isic_{split}.jsonl"
    blob = dest_bucket.blob(blob_path)
    if blob.exists():
        blob.reload()
        size_mb = blob.size / (1024 * 1024)
        print(f"  ✅ {blob_path} ({size_mb:.1f} MB)")
    else:
        print(f"  ❌ {blob_path} NOT FOUND")

# Sample a few examples from the training set
print("\n📝 Sample training examples:")
train_blob = dest_bucket.blob(f"{DEST_FOLDER}/siim_isic_train.jsonl")
with train_blob.open("r") as f:
    for i, line in enumerate(f):
        if i >= 2:  # Show first 2 examples
            break
        example = json.loads(line)
        print(f"\nExample {i+1}:")
        print(f"  System: {example['messages'][0]['content'][0]['text'][:50]}...")
        print(f"  User text: {example['messages'][1]['content'][1]['text']}")
        print(f"  Image path: {example['messages'][1]['content'][0]['path']}")
        print(f"  Assistant: {example['messages'][2]['content'][0]['text']}")

# Check image organization
print("\n📸 Checking image organization:")
for split in ['train', 'val', 'test']:
    prefix = f"processed_images/{split}/"
    # Count every object under the prefix (this lists the whole folder)
    count = sum(1 for _ in dest_bucket.list_blobs(prefix=prefix))
    print(f"  • {split}: {count:,} images")

print("\n✅ Dataset verification complete!")
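
Beyond spot-checking two examples, a structural check over every line can catch malformed records before training starts. The validator below is a minimal sketch: the function name `validate_record` and the specific rules (three roles in order, exactly one image part with a path) are assumptions derived from the record layout this notebook writes, not an official Axolotl schema.

```python
import json

EXPECTED_ROLES = ["system", "user", "assistant"]

def validate_record(line):
    """Return a list of problems found in one JSONL line (empty list means OK)."""
    problems = []
    try:
        record = json.loads(line)
    except json.JSONDecodeError as exc:
        return [f"invalid JSON: {exc}"]

    messages = record.get("messages", []) if isinstance(record, dict) else []
    roles = [m.get("role") for m in messages]
    if roles != EXPECTED_ROLES:
        problems.append(f"unexpected roles: {roles}")
        return problems

    # The user turn should carry exactly one image part with a non-empty path
    user_parts = messages[1].get("content", [])
    image_parts = [p for p in user_parts if p.get("type") == "image"]
    if len(image_parts) != 1:
        problems.append(f"expected exactly one image part, found {len(image_parts)}")
    elif not image_parts[0].get("path"):
        problems.append("image part has no path")
    return problems

good = json.dumps({"messages": [
    {"role": "system", "content": [{"type": "text", "text": "s"}]},
    {"role": "user", "content": [{"type": "image", "path": "/mnt/gcs/x.jpg"},
                                 {"type": "text", "text": "q"}]},
    {"role": "assistant", "content": [{"type": "text", "text": "a"}]},
]})
print(validate_record(good))
print(validate_record('{"messages": []}'))
```

The same function could be run over each line read from `train_blob.open("r")` to count how many records, if any, report problems.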

## Step 11: Clean up temporary files

Finally, let's clean up the temporary files to free up disk space:

In [None]:
# Clean up temp directories
print("🧹 Cleaning up temporary files...")
shutil.rmtree("/tmp/train_images", ignore_errors=True)
print("✅ Temporary files removed")

print("\n" + "=" * 60)
print("🎉 PROCESSING COMPLETE!")
print("=" * 60)
print("\n📊 Final Dataset Summary:")
print(f"  • Total examples: {train_count + val_count + test_count:,}")
print(f"    - Training:   {train_count:,} ({train_count/(train_count+val_count+test_count)*100:.1f}%)")
print(f"    - Validation: {val_count:,} ({val_count/(train_count+val_count+test_count)*100:.1f}%)")
print(f"    - Test:       {test_count:,} ({test_count/(train_count+val_count+test_count)*100:.1f}%)")
print(f"\n📁 Data location: gs://{GCS_BUCKET_NAME}/{DEST_FOLDER}/")
print(f"\n🚀 Next steps:")
print(f"  1. Update your gemma3-melanoma.yaml configuration file")
print(f"  2. Deploy the training job using: ./scripts/deploy-training.sh")
print(f"  3. Monitor training progress with TensorBoard")

## Summary

In this notebook, we successfully:

1. ✅ Downloaded and explored the SIIM-ISIC Melanoma dataset
2. ✅ Created stratified train/validation/test splits (80/10/10)
3. ✅ Processed 33,000+ dermoscopic images into Axolotl's chat template format
4. ✅ Uploaded organized images to Google Cloud Storage
5. ✅ Generated JSONL files ready for multimodal fine-tuning

The dataset is now ready to be used for fine-tuning Gemma 3 with Axolotl on GKE. The multimodal format allows the model to learn from both visual features and structured prompts, enabling it to make clinically relevant predictions.

### Key takeaways

- **Class imbalance**: Benign cases outnumber malignant ones by roughly 56 to 1 (Step 5 prints the exact ratio), reflecting how rare melanoma is among imaged lesions
- **Scalable processing**: Uploading images through a 16-worker thread pool overlaps network I/O, which shortens preparation compared with sequential uploads
- **Cloud-native approach**: Storing images and JSONL files in GCS lets the GKE training job read them through a GCS FUSE mount
- **Flexible format**: The chat template format supports prompts both with and without patient metadata

### What's next?

Check out the [main repository](https://github.com/ayoisio/gke-multimodal-fine-tune-gemma-3-axolotl) for:
- Setting up GKE cluster with GPU support
- Deploying the Axolotl training job
- Evaluating model performance
- Comparing with baseline models

---

**Remember**: This notebook is for educational and research purposes only. Any models trained on this data should not be used for actual medical diagnosis without proper validation and regulatory approval.