# MILK10k Skin Lesion Classification - EfficientNetV2-L Training

This notebook trains an **EfficientNetV2-L** model with **Cross-Modal Attention Fusion**, aiming to improve the Macro F1 score over the EfficientNet-B3 baseline.

## Overview

| Component | Value |
|-----------|-------|
| **Model** | EfficientNetV2-L (dual backbone) |
| **Fusion** | Cross-Modal Attention + Gated Fusion |
| **Image Size** | 480x480 |
| **Loss** | Asymmetric Loss + Label Smoothing |
| **Features** | Deep Supervision + EMA |
| **GPU** | A100 80GB (Colab Pro) |
| **Expected Time** | 4-5 hours |

## Improvements over EfficientNet-B3

1. **Larger backbone**: about 120M parameters vs. about 12M (roughly 10x more parameters)
2. **Cross-modal attention**: Clinical and dermoscopic images attend to each other
3. **Asymmetric loss**: Better handling of class imbalance
4. **Deep supervision**: Auxiliary heads for better gradient flow
5. **EMA**: Smoother weight updates for better generalization
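
The EMA item can be illustrated with a minimal, framework-free sketch. It assumes the usual shadow-weight formulation, `ema = decay * ema + (1 - decay) * weight`. The `update_ema` helper and the `decay` values are illustrative assumptions and are not taken from `src.train_v2`.

```python
def update_ema(ema_weights, weights, decay=0.999):
    """Return new shadow weights: decay * ema + (1 - decay) * current."""
    return [decay * e + (1.0 - decay) * w for e, w in zip(ema_weights, weights)]


# Toy run: the raw weight oscillates between 1 and 3, the EMA copy moves far less.
raw_trajectory = [[1.0], [3.0], [1.0], [3.0]]
ema = list(raw_trajectory[0])  # initialise the shadow copy from the first weights
for weights in raw_trajectory[1:]:
    ema = update_ema(ema, weights, decay=0.9)
    print(f"raw={weights[0]:.2f}  ema={ema[0]:.3f}")
```

In a real training loop the same update would typically run once per optimizer step over every model parameter, and the shadow copy would be the one evaluated on the validation set.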


## 1. Setup Environment

Check GPU and system specifications.


In [None]:
# Check GPU availability
!nvidia-smi

# Check CUDA version
!nvcc --version

# Check disk space and RAM
!df -h | grep -E 'Filesystem|/content'
!free -h

# Report GPU name and total memory (this notebook assumes an A100 80GB)
import subprocess
result = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total', '--format=csv'], capture_output=True, text=True)
print(f"\nGPU Info:\n{result.stdout}")
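
The cell above only prints the GPU details. A small parser along the following lines could turn that output into an explicit check. This is a sketch: `parse_gpu_csv` is a hypothetical helper, and the sample string is illustrative rather than captured from a real session.

```python
import csv
import io


def parse_gpu_csv(csv_text):
    """Parse `nvidia-smi --query-gpu=name,memory.total --format=csv` output.

    Returns a list of (name, memory_in_GiB) tuples, one per GPU.
    """
    rows = list(csv.reader(io.StringIO(csv_text.strip()), skipinitialspace=True))
    gpus = []
    for name, memory in rows[1:]:  # rows[0] is the header line
        mib = float(memory.split()[0])  # e.g. "81920 MiB" -> 81920.0
        gpus.append((name, mib / 1024))
    return gpus


# Illustrative sample text, not captured from a real session.
sample = "name, memory.total [MiB]\nNVIDIA A100-SXM4-80GB, 81920 MiB\n"
for name, gib in parse_gpu_csv(sample):
    status = "OK" if "A100" in name and gib >= 79 else "smaller GPU: reduce batch size"
    print(f"{name}: {gib:.0f} GiB ({status})")
```

In the notebook, `result.stdout` from the cell above would be passed in place of `sample`.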


## 2. Mount Google Drive


In [None]:
from google.colab import drive
drive.mount('/content/drive')

import os
DRIVE_ROOT = '/content/drive/MyDrive/MILK10k_Project'
os.makedirs(DRIVE_ROOT, exist_ok=True)

print(f"Google Drive mounted!")
print(f"Project root: {DRIVE_ROOT}")


## 3. Install Dependencies


In [None]:
# Install required packages
%pip install -q timm albumentations tensorboard scikit-learn

# Verify installations
import torch
import torchvision
import timm
import albumentations as A

print(f"PyTorch: {torch.__version__}")
print(f"TorchVision: {torchvision.__version__}")
print(f"Timm: {timm.__version__}")
print(f"Albumentations: {A.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")


## 4. Setup Project Files


In [None]:
import shutil
from pathlib import Path

# Create working directory
WORK_DIR = '/content/MILK10k'
os.makedirs(WORK_DIR, exist_ok=True)
%cd {WORK_DIR}

# Copy source code from Drive
SRC_DRIVE = f'{DRIVE_ROOT}/src'
if os.path.exists(SRC_DRIVE):
    !cp -r {SRC_DRIVE} {WORK_DIR}/
    print("Copied src/ from Google Drive")
else:
    print("WARNING: src/ not found in Google Drive. Please upload it first!")

# Copy preprocessed data
PREPROCESSED_DRIVE = f'{DRIVE_ROOT}/preprocessed_data'
if os.path.exists(PREPROCESSED_DRIVE):
    !cp -r {PREPROCESSED_DRIVE} {WORK_DIR}/
    print("Copied preprocessed_data/ from Google Drive")
else:
    print("WARNING: preprocessed_data/ not found. Please upload it first!")

# Link dataset from Google Drive (no copy needed)
DATASET_DRIVE = f'{DRIVE_ROOT}/dataset/MILK10k_Training_Input'
if os.path.exists(DATASET_DRIVE):
    os.makedirs(f'{WORK_DIR}/dataset', exist_ok=True)
    # -n replaces an existing link on re-run instead of nesting a new link inside it
    !ln -sfn {DATASET_DRIVE} {WORK_DIR}/dataset/MILK10k_Training_Input
    print(f"Linked dataset from Google Drive")
else:
    print(f"WARNING: Dataset not found at: {DATASET_DRIVE}")

# Create output directories
os.makedirs('models', exist_ok=True)
os.makedirs('logs', exist_ok=True)
os.makedirs('results', exist_ok=True)

print(f"\nWorking directory: {WORK_DIR}")
!ls -la
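
The dataset link step can also be written in pure Python, which makes it safe to re-run. The sketch below uses a hypothetical `link_dataset` helper and temporary directories in place of the Drive and working paths.

```python
import os
import tempfile


def link_dataset(source_dir, link_path):
    """Create or refresh a symlink at link_path that points to source_dir."""
    os.makedirs(os.path.dirname(link_path), exist_ok=True)
    if os.path.islink(link_path):
        os.remove(link_path)  # drop a stale link from a previous run
    elif os.path.exists(link_path):
        raise FileExistsError(f"{link_path} exists and is not a symlink")
    os.symlink(source_dir, link_path)
    return link_path


# Demo in a temporary directory; calling twice shows the helper is idempotent.
with tempfile.TemporaryDirectory() as tmp:
    source = os.path.join(tmp, "drive_dataset")
    os.makedirs(source)
    link = os.path.join(tmp, "work", "dataset", "MILK10k_Training_Input")
    link_dataset(source, link)
    link_dataset(source, link)
    linked_ok = os.path.realpath(link) == os.path.realpath(source)
    source_untouched = os.listdir(source) == []  # no nested link was created

print(f"link resolves to source: {linked_ok}, source untouched: {source_untouched}")
```

In the notebook this would be called as `link_dataset(DATASET_DRIVE, f'{WORK_DIR}/dataset/MILK10k_Training_Input')`.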


## 5. Import Modules and Configuration


In [None]:
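import sys
# `from src.config import ...` resolves against the parent of src/, so add the project root too
sys.path.append(WORK_DIR)
sys.path.append(f'{WORK_DIR}/src')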

# Import modules
from src.config import *
from src.utils import *
from src.dataset import *
from src.models_v2 import *
from src.losses import *
from src.train_v2 import TrainerV2
from src.evaluate import *

# Display V2 configuration
print("="*60)
print("EFFICIENTNETV2-L CONFIGURATION")
print("="*60)

print(f"\nModel Config (V2):")
for key, value in MODEL_CONFIG_V2.items():
    print(f"  {key}: {value}")

print(f"\nTraining Config (V2):")
for key, value in TRAIN_CONFIG_V2.items():
    print(f"  {key}: {value}")

print(f"\nImage Config (V2):")
for key, value in IMAGE_CONFIG_V2.items():
    print(f"  {key}: {value}")

print(f"\nLoss Config (V2):")
for key, value in LOSS_CONFIG_V2.items():
    print(f"  {key}: {value}")


## 6. Load and Prepare Data


In [None]:
import pandas as pd
import json
import re

# Load preprocessed data
print("Loading preprocessed data...")
train_df = pd.read_csv('preprocessed_data/train_data.csv')
val_df = pd.read_csv('preprocessed_data/val_data.csv')

print(f"Training samples: {len(train_df):,}")
print(f"Validation samples: {len(val_df):,}")

# Load class weights
with open('preprocessed_data/class_weights.json', 'r') as f:
    class_weights = json.load(f)

print(f"\nClass Weights:")
for cat, weight in class_weights.items():
    print(f"  {cat}: {weight:.4f}")

# Estimate samples per class
total_samples = len(train_df)
samples_per_class = get_samples_per_class(class_weights, total_samples)
print(f"\nEstimated samples per class: {samples_per_class}")
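
`get_samples_per_class` lives in `src.utils` and is not shown in this notebook. As a rough illustration of how counts can be recovered from weights, the sketch below assumes the weights follow the common "balanced" heuristic `w_c = N / (K * n_c)` for `N` samples and `K` classes. Both that assumption and the helper name `estimate_samples_per_class` are hypothetical, so the project's own function may behave differently.

```python
def estimate_samples_per_class(class_weights, total_samples):
    """Invert w_c = N / (K * n_c) to recover approximate per-class counts."""
    num_classes = len(class_weights)
    return {
        name: round(total_samples / (num_classes * weight))
        for name, weight in class_weights.items()
    }


# Toy weights for 100 samples split 80/20 across two hypothetical classes.
toy_weights = {"common": 100 / (2 * 80), "rare": 100 / (2 * 20)}
print(estimate_samples_per_class(toy_weights, total_samples=100))
```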


### Fix Image Paths for Colab


In [None]:
def fix_image_paths(df, dataset_root):
    """Fix Windows paths to work with Colab."""
    df = df.copy()
    
    for col in ['clinical_image_path', 'dermoscopic_image_path']:
        if col in df.columns:
            df[col] = df[col].apply(lambda x: extract_relative_path(str(x), dataset_root))
    
    return df

def extract_relative_path(path_str, dataset_root):
    """Extract lesion_id/image.jpg from any path format."""
    match = re.search(r'(IL_\d+)[/\\](ISIC_\d+\.jpg)', path_str)
    
    if match:
        lesion_id = match.group(1)
        image_file = match.group(2)
        return os.path.join(dataset_root, 'MILK10k_Training_Input', lesion_id, image_file)
    else:
        parts = re.split(r'[/\\]', path_str)
        parts = [p for p in parts if p]
        if len(parts) >= 2:
            return os.path.join(dataset_root, 'MILK10k_Training_Input', parts[-2], parts[-1])
        raise ValueError(f"Cannot extract path from: {path_str}")

# Fix paths
DATASET_ROOT = f'{WORK_DIR}/dataset'
print("Fixing image paths...")

train_df = fix_image_paths(train_df, DATASET_ROOT)
val_df = fix_image_paths(val_df, DATASET_ROOT)

print(f"\nExample paths:")
print(f"  Clinical: {train_df['clinical_image_path'].iloc[0]}")
print(f"  Dermoscopic: {train_df['dermoscopic_image_path'].iloc[0]}")

# Verify paths
sample_path = train_df['clinical_image_path'].iloc[0]
if os.path.exists(sample_path):
    print(f"\nImage paths verified!")
else:
    print(f"\nWARNING: Image not found at {sample_path}")
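
The check above tests only the first clinical image. A fuller check could scan every path in both columns, as in the sketch below. `find_missing_paths` is a hypothetical helper and the file names are made up for the demo.

```python
import os
import tempfile


def find_missing_paths(paths):
    """Return the paths that do not exist on disk, preserving order."""
    return [p for p in paths if not os.path.exists(p)]


# Demo with one real and one absent file in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    present = os.path.join(tmp, "ISIC_0000001.jpg")  # hypothetical file name
    open(present, "wb").close()
    absent = os.path.join(tmp, "ISIC_0000002.jpg")
    missing = find_missing_paths([present, absent])

print(f"{len(missing)} missing file(s)")
```

In the notebook, `find_missing_paths(train_df['clinical_image_path'])` and the same call on `dermoscopic_image_path` (and on `val_df`) would cover the whole dataset. Expect this to be slow on a Drive-backed symlink.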


## 7. Create DataLoaders (480x480 images)


In [None]:
# Optimized for A100 80GB with EfficientNetV2-L
BATCH_SIZE = TRAIN_CONFIG_V2['batch_size']  # 48 by default
NUM_WORKERS = TRAIN_CONFIG_V2['num_workers']  # 8
IMAGE_SIZE = IMAGE_CONFIG_V2['image_size']  # 480

print(f"Creating dataloaders for EfficientNetV2-L...")
print(f"  Image size: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Num workers: {NUM_WORKERS}")

# Create dataloaders with larger images
train_loader, val_loader = get_dataloaders(
    train_df,
    val_df,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    image_size=IMAGE_SIZE,
    fusion_strategy='late',  # Always late for V2
    use_metadata=MODEL_CONFIG_V2['use_metadata']
)

print(f"\nTrain DataLoader: {len(train_loader)} batches")
print(f"Val DataLoader: {len(val_loader)} batches")

# Test dataloader
print(f"\nTesting dataloader...")
for batch in train_loader:
    images, labels, metadata = batch
    clinical_img, dermoscopic_img = images
    print(f"  Clinical shape: {clinical_img.shape}")
    print(f"  Dermoscopic shape: {dermoscopic_img.shape}")
    print(f"  Labels shape: {labels.shape}")
    print(f"  Metadata shape: {metadata.shape}")
    break
print("DataLoader test successful!")


## 8. Create Model (EfficientNetV2-L with Cross-Modal Attention)


In [None]:
# Get device
device = get_device()

# Create EfficientNetV2-L model
print("Creating EfficientNetV2-L model with Cross-Modal Attention...")
model = create_model_v2(
    architecture=MODEL_CONFIG_V2['architecture'],
    num_classes=len(DIAGNOSIS_CATEGORIES),
    pretrained=MODEL_CONFIG_V2['pretrained'],
    use_metadata=MODEL_CONFIG_V2['use_metadata'],
    metadata_dim=MODEL_CONFIG_V2['metadata_dim'],
    dropout=MODEL_CONFIG_V2['dropout'],
    use_auxiliary_heads=MODEL_CONFIG_V2['use_auxiliary_heads']
)

model = model.to(device)

# Count parameters
total_params, trainable_params = count_parameters(model)

print(f"\nModel: {MODEL_CONFIG_V2['architecture']}")
print(f"Fusion: Cross-Modal Attention + Gated Fusion")
print(f"Auxiliary Heads: {MODEL_CONFIG_V2['use_auxiliary_heads']}")
print(f"Metadata: {MODEL_CONFIG_V2['use_metadata']}")
print(f"Device: {device}")
print(f"Parameters: {total_params:,} (Trainable: {trainable_params:,})")


## 9. Initialize Trainer


In [None]:
# Set random seed
set_seed(TRAIN_CONFIG_V2['random_seed'])

# Update paths to save in Google Drive
TRAIN_CONFIG_V2['checkpoint_dir'] = f'{DRIVE_ROOT}/models'
TRAIN_CONFIG_V2['log_dir'] = f'{DRIVE_ROOT}/logs'

os.makedirs(TRAIN_CONFIG_V2['checkpoint_dir'], exist_ok=True)
os.makedirs(TRAIN_CONFIG_V2['log_dir'], exist_ok=True)

print(f"Creating TrainerV2...")
print(f"  Checkpoint dir: {TRAIN_CONFIG_V2['checkpoint_dir']}")
print(f"  Log dir: {TRAIN_CONFIG_V2['log_dir']}")

# Create trainer
trainer = TrainerV2(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    samples_per_class=samples_per_class,
    device=device,
    config=TRAIN_CONFIG_V2
)

print(f"\nTrainer initialized!")
print(f"  Loss: {LOSS_CONFIG_V2['type']}")
print(f"  Use EMA: {TRAIN_CONFIG_V2.get('use_ema', True)}")
print(f"  Label Smoothing: {LOSS_CONFIG_V2.get('use_label_smoothing', True)}")
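
For readers unfamiliar with asymmetric loss, the sketch below shows the commonly used sigmoid-based formulation: positives get a focal-style term, while negatives are shifted by a probability margin and down-weighted more strongly. The hyperparameter values are illustrative, and the implementation in `src.losses` (not shown here) may differ, for example in how it combines with label smoothing.

```python
import math


def asymmetric_loss(probs, targets, gamma_pos=0.0, gamma_neg=4.0, margin=0.05, eps=1e-8):
    """Mean per-class asymmetric loss for one sample.

    probs   : per-class sigmoid probabilities in [0, 1]
    targets : per-class 0/1 labels
    """
    total = 0.0
    for p, y in zip(probs, targets):
        if y == 1:
            # Positive term: focal-style weighting controlled by gamma_pos.
            total += -((1.0 - p) ** gamma_pos) * math.log(max(p, eps))
        else:
            # Negative term: shift the probability down by the margin so that
            # easy negatives (p <= margin) contribute nothing.
            p_m = max(p - margin, 0.0)
            total += -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, eps))
    return total / len(probs)


easy_negative = asymmetric_loss([0.03], [0])  # below the margin
hard_negative = asymmetric_loss([0.90], [0])  # confident false positive
print(f"easy negative: {easy_negative:.4f}, hard negative: {hard_negative:.4f}")
```

Because easy negatives are zeroed out, the many "absent" classes in an imbalanced dataset contribute less to the gradient than the rare positives.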


## 10. TensorBoard (Optional)


In [None]:
%load_ext tensorboard
%tensorboard --logdir {TRAIN_CONFIG_V2['log_dir']}

print("TensorBoard loaded! View metrics above during training.")


## 10.5 Resume Training from Checkpoint (Optional)

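**Starting fresh?** Leave `RESUME_FROM_CHECKPOINT = False` (the default) and run the cell anyway, because the training cell in Section 11 reads this flag. Set it to `True` ONLY if you want to resume from a previous checkpoint.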

Checkpoints are saved:
- Every 5 epochs: `checkpoint_v2_epoch_5.pth`, `checkpoint_v2_epoch_10.pth`, etc.
- Best model: `best_model_v2.pth`


In [None]:
# ============================================================
# RESUME FROM CHECKPOINT - Only run if continuing training!
# ============================================================
# Leave RESUME_FROM_CHECKPOINT = False below if starting fresh training

RESUME_FROM_CHECKPOINT = False  # Set to True to resume

if RESUME_FROM_CHECKPOINT:
    # List available checkpoints
    checkpoint_dir = TRAIN_CONFIG_V2['checkpoint_dir']
    print("Available checkpoints:")
    import glob
    checkpoints = glob.glob(f"{checkpoint_dir}/*.pth")
    for cp in sorted(checkpoints):
        print(f"  - {os.path.basename(cp)}")
    
    # Choose checkpoint to resume from (latest epoch checkpoint)
    CHECKPOINT_FILE = "checkpoint_v2_epoch_10.pth"  # Change this to your checkpoint
    checkpoint_path = f"{checkpoint_dir}/{CHECKPOINT_FILE}"
    
    if os.path.exists(checkpoint_path):
        print(f"\nLoading checkpoint: {checkpoint_path}")
        
        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=device)
        
        # Load model weights
        model.load_state_dict(checkpoint['model_state_dict'])
        print(f"Model weights loaded from epoch {checkpoint['epoch'] + 1}")
        
        # Load optimizer state
        trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        print("Optimizer state loaded")
        
        # Load scheduler state if available
        if 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
            trainer.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            print("Scheduler state loaded")
        
        # Update trainer's starting epoch
        START_EPOCH = checkpoint['epoch'] + 1
        print(f"\nResuming training from epoch {START_EPOCH + 1}")
        print(f"Previous best metric: {checkpoint.get('metrics', 'N/A')}")
    else:
        print(f"Checkpoint not found: {checkpoint_path}")
        print("Starting fresh training instead.")
        START_EPOCH = 0
else:
    START_EPOCH = 0
    print("Starting fresh training (RESUME_FROM_CHECKPOINT = False)")
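
The listing above sorts checkpoint names as strings, so `checkpoint_v2_epoch_5.pth` sorts after `checkpoint_v2_epoch_15.pth`. A helper like the hypothetical `latest_epoch_checkpoint` below could pick the newest periodic checkpoint by epoch number instead. It assumes the `checkpoint_v2_epoch_N.pth` naming described above.

```python
import re


def latest_epoch_checkpoint(filenames):
    """Return the periodic checkpoint with the highest epoch number, or None."""
    pattern = re.compile(r"checkpoint_v2_epoch_(\d+)\.pth$")
    numbered = []
    for name in filenames:
        match = pattern.search(name)
        if match:
            numbered.append((int(match.group(1)), name))
    return max(numbered)[1] if numbered else None


names = ["best_model_v2.pth", "checkpoint_v2_epoch_5.pth",
         "checkpoint_v2_epoch_10.pth", "checkpoint_v2_epoch_15.pth"]
print("string sort picks: ", sorted(names)[-1])
print("numeric sort picks:", latest_epoch_checkpoint(names))
```

In the resume cell, `CHECKPOINT_FILE` could then be set from `latest_epoch_checkpoint(os.path.basename(cp) for cp in checkpoints)`.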


## 11. Start Training

**Expected Training Time**: 4-5 hours on A100 80GB

Features:
- Mixed precision (FP16) for memory efficiency
- EMA for smoother weights
- Deep supervision with auxiliary heads
- Checkpoints saved every 5 epochs


In [None]:
# Start training (supports resume from checkpoint)
print("Starting EfficientNetV2-L training...")
print("This will take approximately 4-5 hours on A100 80GB.")
print("="*60)

# Check if resuming from checkpoint. globals().get() keeps this cell working
# even if the optional resume cell (10.5) was skipped.
resume = globals().get('RESUME_FROM_CHECKPOINT', False)
start_epoch = globals().get('START_EPOCH', 0)

if resume and start_epoch > 0:
    # Get best F1 from checkpoint if available
    best_f1_resume = checkpoint.get('metrics', 0.0)
    if isinstance(best_f1_resume, dict):
        best_f1_resume = best_f1_resume.get('macro_f1', 0.0)
    history = trainer.train(start_epoch=start_epoch, best_f1=best_f1_resume)
else:
    history = trainer.train()

print("\n" + "="*60)
print("TRAINING COMPLETED!")
print("="*60)


## 12. View Training Results


In [None]:
import matplotlib.pyplot as plt

# Load training history
history_df = pd.read_csv(f'{TRAIN_CONFIG_V2["checkpoint_dir"]}/training_history_v2.csv')

print("="*60)
print("TRAINING SUMMARY - EfficientNetV2-L")
print("="*60)
print(f"\nTotal epochs: {len(history_df)}")
print(f"Best Macro F1: {history_df['val_f1_macro'].max():.4f}")
print(f"Best Micro F1: {history_df['val_f1_micro'].max():.4f}")
print(f"Final Train Loss: {history_df['train_loss'].iloc[-1]:.4f}")
print(f"Final Val Loss: {history_df['val_loss'].iloc[-1]:.4f}")

# Compare with B3 baseline (Macro F1 of about 0.50)
B3_BASELINE_F1 = 0.50
best_macro = history_df['val_f1_macro'].max()
print(f"\n--- Comparison with EfficientNet-B3 ---")
print(f"B3 Best Macro F1: ~{B3_BASELINE_F1:.2f}")
print(f"V2-L Best Macro F1: {best_macro:.4f}")
improvement = (best_macro - B3_BASELINE_F1) / B3_BASELINE_F1 * 100
print(f"Relative improvement: {improvement:+.1f}%")

# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
axes[0, 0].plot(history_df['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history_df['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# F1 scores
axes[0, 1].plot(history_df['val_f1_macro'], label='Macro F1', linewidth=2, color='green')
axes[0, 1].plot(history_df['val_f1_micro'], label='Micro F1', linewidth=2, color='blue')
axes[0, 1].axhline(y=0.50, color='red', linestyle='--', label='B3 Baseline')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('F1 Score')
axes[0, 1].set_title('Validation F1 Scores')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning rate
axes[1, 0].plot(history_df['learning_rate'], linewidth=2, color='orange')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].set_yscale('log')
axes[1, 0].grid(True, alpha=0.3)

# F1 improvement
best_so_far = history_df['val_f1_macro'].cummax()
axes[1, 1].fill_between(range(len(best_so_far)), 0.50, best_so_far, alpha=0.3, color='green')
axes[1, 1].plot(best_so_far, linewidth=2, color='green', label='Best Macro F1')
axes[1, 1].axhline(y=0.50, color='red', linestyle='--', label='B3 Baseline')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Macro F1')
axes[1, 1].set_title('Macro F1 Improvement')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{DRIVE_ROOT}/training_curves_v2.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\nTraining curves saved to: {DRIVE_ROOT}/training_curves_v2.png")
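
The summary reports both Macro and Micro F1. The sketch below shows why they can diverge on imbalanced data: Macro F1 averages per-class F1 scores equally, while Micro F1 pools all decisions. It assumes single-label predictions, and the class names are illustrative; `src.evaluate` may compute the metrics differently.

```python
def f1_scores(y_true, y_pred, labels):
    """Return (macro_f1, micro_f1) for single-label predictions."""
    per_class_f1 = []
    tp_sum = fp_sum = fn_sum = 0
    for label in labels:
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = sum(t == label and p != label for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        per_class_f1.append(2 * tp / denom if denom else 0.0)
        tp_sum, fp_sum, fn_sum = tp_sum + tp, fp_sum + fp, fn_sum + fn
    macro = sum(per_class_f1) / len(labels)
    micro = 2 * tp_sum / (2 * tp_sum + fp_sum + fn_sum)
    return macro, micro


# Imbalanced toy data: 8 "nevus", 2 "melanoma"; the model always predicts "nevus".
y_true = ["nevus"] * 8 + ["melanoma"] * 2
y_pred = ["nevus"] * 10
macro, micro = f1_scores(y_true, y_pred, labels=["nevus", "melanoma"])
print(f"macro F1 = {macro:.3f}, micro F1 = {micro:.3f}")
```

A model that ignores the rare class still scores well on Micro F1 but is penalised heavily on Macro F1, which is why Macro F1 is the headline metric here.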


## 13. Save Model Info


In [None]:
from datetime import datetime

best_epoch = history_df['val_f1_macro'].idxmax()
best_macro_f1 = history_df['val_f1_macro'].max()
best_micro_f1 = history_df.loc[best_epoch, 'val_f1_micro']

model_info = f"""# EfficientNetV2-L Training Results

**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M')}
**GPU**: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}
**Total Epochs**: {len(history_df)}
**Best Epoch**: {best_epoch + 1}
**Best Macro F1**: {best_macro_f1:.4f}
**Micro F1 (at best epoch)**: {best_micro_f1:.4f}

## Model Configuration
- Architecture: {MODEL_CONFIG_V2['architecture']}
- Fusion: Cross-Modal Attention + Gated Fusion
- Image Size: {IMAGE_CONFIG_V2['image_size']}x{IMAGE_CONFIG_V2['image_size']}
- Batch Size: {TRAIN_CONFIG_V2['batch_size']}
- Loss: {LOSS_CONFIG_V2['type']}

## Files
- Best Model: {TRAIN_CONFIG_V2['checkpoint_dir']}/best_model_v2.pth
- Training History: {TRAIN_CONFIG_V2['checkpoint_dir']}/training_history_v2.csv
"""

info_path = f'{DRIVE_ROOT}/MODEL_INFO_V2.md'
with open(info_path, 'w') as f:
    f.write(model_info)

print(model_info)
print(f"\nModel info saved to: {info_path}")

# Print file locations
print("\n" + "="*60)
print("Files saved in Google Drive:")
print(f"  - {TRAIN_CONFIG_V2['checkpoint_dir']}/best_model_v2.pth")
print(f"  - {TRAIN_CONFIG_V2['checkpoint_dir']}/training_history_v2.csv")
print(f"  - {DRIVE_ROOT}/training_curves_v2.png")
print("="*60)


---

## Training Complete!

### Key Results
- EfficientNetV2-L with Cross-Modal Attention
- Macro F1 compared against the EfficientNet-B3 baseline (~0.50) in Section 12
- Model saved to Google Drive

### Next Steps
1. Run inference with TTA (Test-Time Augmentation)
2. Optimize thresholds per class
3. Ensemble with EfficientNet-B3 for final submission
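
As an illustration of step 2, per-class threshold optimization can be sketched as a grid search over candidate thresholds on validation predictions, run once per class. The `best_threshold` helper, the candidate grid, and the toy scores are assumptions for this example and not part of the project code.

```python
def best_threshold(probs, targets, candidates=None):
    """Grid-search the threshold that maximises F1 for one class."""
    if candidates is None:
        candidates = [i / 20 for i in range(1, 20)]  # 0.05, 0.10, ..., 0.95
    best_t, best_f1 = 0.5, -1.0
    for t in candidates:
        preds = [p >= t for p in probs]
        tp = sum(pr and y == 1 for pr, y in zip(preds, targets))
        fp = sum(pr and y == 0 for pr, y in zip(preds, targets))
        fn = sum((not pr) and y == 1 for pr, y in zip(preds, targets))
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom else 0.0
        if f1 > best_f1:  # ties keep the lowest threshold
            best_t, best_f1 = t, f1
    return best_t, best_f1


# Toy validation scores for one under-confident class (illustrative values).
probs = [0.10, 0.20, 0.35, 0.40, 0.45]
targets = [0, 0, 1, 1, 1]
threshold, f1 = best_threshold(probs, targets)
print(f"threshold={threshold:.2f}, F1={f1:.2f}")
```

With the default threshold of 0.5 every positive in this toy example would be missed. Thresholds should be tuned on validation data only, since tuning on the test set inflates the reported score.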

---
