# 🩺 MILK10k Skin Lesion Classification - Training on Google Colab

This notebook trains a deep learning model for skin lesion classification using the MILK10k dataset on Google Colab with GPU acceleration.

## 📋 Overview

- **Model**: EfficientNet-B3 with metadata fusion
- **Dataset**: MILK10k (4,192 train + 1,048 validation)
- **Task**: Multi-label classification (11 diagnosis categories)
- **Loss**: Focal Loss with class weights
- **Training**: Mixed precision (AMP) + Early stopping

## 🚀 Before Running

1. **Set Runtime to GPU**: Runtime → Change runtime type → GPU (T4 or better)
2. **Upload to Google Drive**:
   - `preprocessed_data/` folder (train_data.csv, val_data.csv, class_weights.json)
   - Dataset images (or download from source)
   - Project code files

---

## 1️⃣ Setup Google Colab Environment

Check GPU availability and system specifications.

In [None]:
# Check GPU availability
!nvidia-smi

# Check CUDA version
!nvcc --version

# Check disk space
!df -h | grep -E 'Filesystem|/content'

# Check RAM
!free -h

## 2️⃣ Mount Google Drive

Mount Google Drive to access datasets and save results.

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Set up project paths (adjust these to match your Google Drive structure)
import os
DRIVE_ROOT = '/content/drive/MyDrive/MILK10k_Project'
os.makedirs(DRIVE_ROOT, exist_ok=True)

print(f"✅ Google Drive mounted!")
print(f"📁 Project root: {DRIVE_ROOT}")

## 3️⃣ Install Dependencies

Install PyTorch with CUDA support and other required packages.

In [None]:
# Install required packages
!pip install -q timm albumentations tensorboard scikit-learn

# Verify installations
import torch
import torchvision
import timm
import albumentations as A

print(f"✅ PyTorch: {torch.__version__}")
print(f"✅ TorchVision: {torchvision.__version__}")
print(f"✅ Timm: {timm.__version__}")
print(f"✅ Albumentations: {A.__version__}")
print(f"✅ CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"✅ GPU: {torch.cuda.get_device_name(0)}")
    print(f"✅ CUDA Version: {torch.version.cuda}")

## 4️⃣ Setup Project Files

Clone repository or upload project files to Colab workspace.

In [None]:
# Option 1: Clone from GitHub (if repository is public)
# !git clone https://github.com/YOUR_USERNAME/YOUR_REPO.git
# %cd YOUR_REPO

# Option 2: Copy from Google Drive (recommended)
import shutil
from pathlib import Path

# Create working directory
WORK_DIR = '/content/MILK10k'
os.makedirs(WORK_DIR, exist_ok=True)
%cd {WORK_DIR}

# Copy source code from Drive
SRC_DRIVE = f'{DRIVE_ROOT}/src'
if os.path.exists(SRC_DRIVE):
    !cp -r {SRC_DRIVE} {WORK_DIR}/
    print("✅ Copied src/ from Google Drive")
else:
    print("⚠️ src/ not found in Google Drive. Please upload it first!")

# Copy preprocessed data
PREPROCESSED_DRIVE = f'{DRIVE_ROOT}/preprocessed_data'
if os.path.exists(PREPROCESSED_DRIVE):
    !cp -r {PREPROCESSED_DRIVE} {WORK_DIR}/
    print("✅ Copied preprocessed_data/ from Google Drive")
else:
    print("⚠️ preprocessed_data/ not found. Please upload it first!")

# Option A: Use dataset directly from Google Drive (slower I/O but saves local disk space)
DATASET_DRIVE = f'{DRIVE_ROOT}/dataset/MILK10k_Training_Input'
if os.path.exists(DATASET_DRIVE):
    # The parent directory must exist before the symlink can be created
    os.makedirs(f'{WORK_DIR}/dataset', exist_ok=True)
    # Symlink the Drive folder instead of copying (quick to set up;
    # -sfn replaces an existing link so the cell can be re-run)
    !ln -sfn {DATASET_DRIVE} {WORK_DIR}/dataset/MILK10k_Training_Input
    print(f"✅ Linked dataset from Google Drive (no copy needed)")
else:
    print(f"⚠️ Dataset not found at: {DATASET_DRIVE}")
    print("   Please upload MILK10k_Training_Input/ to your Google Drive!")

# Option B: Copy dataset to local Colab storage (faster but uses ~5-10GB)
# Uncomment if you want faster I/O during training:
# DATASET_DRIVE = f'{DRIVE_ROOT}/dataset'
# if os.path.exists(DATASET_DRIVE):
#     print("Copying dataset to Colab storage (this may take 10-15 minutes)...")
#     !cp -r {DATASET_DRIVE} {WORK_DIR}/
#     print("‚úÖ Copied dataset to Colab local storage")

# Create necessary directories
os.makedirs('models', exist_ok=True)
os.makedirs('logs', exist_ok=True)
os.makedirs('results', exist_ok=True)

print(f"\n📁 Working directory: {WORK_DIR}")
print(f"📂 Contents:")
!ls -la
print(f"\n📸 Dataset path will be: {WORK_DIR}/dataset/MILK10k_Training_Input")
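
The symlink step in Option A can also be written in plain Python. The sketch below is illustrative only: `link_dataset` is a hypothetical helper (not part of the project code), and temporary directories stand in for the Drive and Colab paths. It creates the parent directory first and replaces a stale link, so it can be called repeatedly.

```python
import os
import tempfile

def link_dataset(source_dir, link_path):
    """Create (or refresh) a symlink at link_path that points to source_dir.

    Hypothetical helper for illustration. Raises FileExistsError if link_path
    already exists as a real directory rather than a symlink.
    """
    os.makedirs(os.path.dirname(link_path), exist_ok=True)  # parent must exist
    if os.path.islink(link_path):
        os.unlink(link_path)                                 # drop a stale link
    os.symlink(source_dir, link_path)
    return link_path

# Temporary directories stand in for Google Drive and the Colab workspace.
with tempfile.TemporaryDirectory() as tmp:
    drive_copy = os.path.join(tmp, "drive", "MILK10k_Training_Input")
    os.makedirs(drive_copy)
    link = os.path.join(tmp, "work", "dataset", "MILK10k_Training_Input")
    link_dataset(drive_copy, link)
    link_dataset(drive_copy, link)  # calling it twice is safe
    print(os.path.islink(link))
```

Reading through a symlink still goes to Drive on every access, so this keeps setup quick but does not speed up training I/O; copying (Option B) is the trade-off in the other direction.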

### Alternative: Extract Dataset from ZIP (Faster Setup)

If you uploaded a ZIP file to Google Drive, use this to extract it.

In [None]:
# Uncomment and run this if you uploaded a ZIP file to Google Drive

# ZIP_PATH = f'{DRIVE_ROOT}/MILK10k_Training_Input.zip'
# 
# if os.path.exists(ZIP_PATH):
#     print(f"Extracting dataset from ZIP (this may take 5-10 minutes)...")
#     !unzip -q {ZIP_PATH} -d {WORK_DIR}/dataset/
#     print(f"✅ Dataset extracted to: {WORK_DIR}/dataset/")
# else:
#     print(f"⚠️ ZIP file not found at: {ZIP_PATH}")

print("ℹ️ Skipped (using direct Drive access or already extracted)")

### Alternative: Download from Kaggle (Best for Large Datasets)

If dataset is hosted on Kaggle, use this method.

In [None]:
# Uncomment and configure if downloading from Kaggle

# # Install Kaggle API
# !pip install -q kaggle
# 
# # Upload your kaggle.json to Colab or use the file uploader
# from google.colab import files
# files.upload()  # Upload kaggle.json
# 
# !mkdir -p ~/.kaggle
# !cp kaggle.json ~/.kaggle/
# !chmod 600 ~/.kaggle/kaggle.json
# 
# # Download dataset (replace with actual dataset name)
# !kaggle datasets download -d YOUR_USERNAME/milk10k-dataset
# !unzip -q milk10k-dataset.zip -d {WORK_DIR}/dataset/
# 
# print("✅ Dataset downloaded from Kaggle")

print("ℹ️ Skipped (using Google Drive or already downloaded)")

## 5️⃣ Load Configuration

Import configuration and verify settings.

In [None]:
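import sys
# The project root makes `src.*` importable regardless of the current directory;
# src/ itself stays on the path in case the modules import each other directly.
sys.path.append('/content/MILK10k')
sys.path.append('/content/MILK10k/src')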

# Import all modules
from src.config import *
from src.utils import *
from src.dataset import *
from src.models import *
from src.evaluate import *

# Display configuration
print("="*60)
print("TRAINING CONFIGURATION")
print("="*60)
print(f"\n📊 Model Config:")
for key, value in MODEL_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\n🎯 Training Config:")
for key, value in TRAIN_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\n🖼️ Image Config:")
for key, value in IMAGE_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\n⚖️ Loss Config:")
for key, value in LOSS_CONFIG.items():
    print(f"  {key}: {value}")

print(f"\n📂 Diagnosis Categories ({len(DIAGNOSIS_CATEGORIES)}):")
for cat in DIAGNOSIS_CATEGORIES:
    print(f"  - {cat}")

## 6️⃣ Load Dataset

Load preprocessed training and validation data.

In [None]:
import pandas as pd
import json

# Load preprocessed data
print("Loading preprocessed data...")
train_df = pd.read_csv('preprocessed_data/train_data.csv')
val_df = pd.read_csv('preprocessed_data/val_data.csv')

print(f"✅ Training samples: {len(train_df):,}")
print(f"✅ Validation samples: {len(val_df):,}")

# Load class weights
with open('preprocessed_data/class_weights.json', 'r') as f:
    class_weights = json.load(f)

print(f"\n⚖️ Class Weights:")
for cat, weight in class_weights.items():
    print(f"  {cat}: {weight:.4f}")

# Display sample data
print(f"\n📊 Training Data Sample:")
print(train_df.head(3))

# Check class distribution
print(f"\n📈 Label Distribution (Training):")
label_counts = train_df[DIAGNOSIS_CATEGORIES].sum()
for cat in DIAGNOSIS_CATEGORIES:
    count = label_counts[cat]
    pct = (count / len(train_df)) * 100
    print(f"  {cat}: {count:,} ({pct:.2f}%)")

In [None]:
import os
from pathlib import Path
import re

def fix_image_paths(df, dataset_root):
    """
    Fix Windows absolute paths to work with Colab's dataset location.
    
    Extracts only the relative path (lesion_id/image.jpg) and reconstructs
    with the correct dataset root path.
    
    Example:
        D:\\PYTHON\\DEEP_LEARNING\\dataset\\MILK10k_Training_Input\\IL_8583674\\ISIC_8570261.jpg
        -> /content/MILK10k/dataset/MILK10k_Training_Input/IL_8583674/ISIC_8570261.jpg
    """
    df = df.copy()
    
    for col in ['clinical_image_path', 'dermoscopic_image_path']:
        if col in df.columns:
            # Extract just the lesion_id and filename from Windows paths
            # Use regex to find pattern: IL_XXXXXXX/ISIC_XXXXXXX.jpg
            df[col] = df[col].apply(lambda x: 
                extract_relative_path(str(x), dataset_root)
            )
    
    return df

def extract_relative_path(path_str, dataset_root):
    """
    Extract lesion_id/image.jpg from any path format (Windows or Linux).
    
    Examples:
        D:\\PYTHON\\...\\IL_8583674\\ISIC_8570261.jpg -> IL_8583674/ISIC_8570261.jpg
        /some/path/IL_8583674/ISIC_8570261.jpg -> IL_8583674/ISIC_8570261.jpg
    """
    # Use regex to extract the lesion_id and image filename
    # Pattern: IL_XXXXXXX (lesion ID) followed by / or \ and then ISIC_XXXXXXX.jpg (image file)
    match = re.search(r'(IL_\d+)[/\\](ISIC_\d+\.jpg)', path_str)
    
    if match:
        lesion_id = match.group(1)
        image_file = match.group(2)
        # Build the correct path
        return os.path.join(dataset_root, 'MILK10k_Training_Input', lesion_id, image_file)
    else:
        # If pattern doesn't match, try to extract last 2 parts using both separators
        parts = re.split(r'[/\\]', path_str)
        parts = [p for p in parts if p]  # Remove empty parts
        if len(parts) >= 2:
            return os.path.join(dataset_root, 'MILK10k_Training_Input', parts[-2], parts[-1])
        else:
            raise ValueError(f"Cannot extract lesion_id and image from path: {path_str}")

# Dataset is in Google Drive, accessed via symlink
# The symlink points to: DRIVE_ROOT/dataset/MILK10k_Training_Input
# But we reference it as: /content/MILK10k/dataset/MILK10k_Training_Input
DATASET_ROOT = f'{WORK_DIR}/dataset'

print("Fixing image paths for Colab environment...")
print(f"Dataset root: {DATASET_ROOT}")

train_df = fix_image_paths(train_df, DATASET_ROOT)
val_df = fix_image_paths(val_df, DATASET_ROOT)

print(f"\n✅ Image paths updated!")
print(f"\n📸 Example corrected paths:")
print(f"  Clinical: {train_df['clinical_image_path'].iloc[0]}")
print(f"  Dermoscopic: {train_df['dermoscopic_image_path'].iloc[0]}")

# Save corrected CSV files for future use
print(f"\n💾 Saving corrected CSV files...")
train_df.to_csv('preprocessed_data/train_data.csv', index=False)
val_df.to_csv('preprocessed_data/val_data.csv', index=False)
print(f"✅ Saved corrected CSVs to preprocessed_data/")

# Also save to Google Drive for persistence
os.makedirs(f'{DRIVE_ROOT}/preprocessed_data', exist_ok=True)
train_df.to_csv(f'{DRIVE_ROOT}/preprocessed_data/train_data.csv', index=False)
val_df.to_csv(f'{DRIVE_ROOT}/preprocessed_data/val_data.csv', index=False)
print(f"✅ Saved corrected CSVs to Google Drive for future use")

# Verify paths exist
sample_clinical = train_df['clinical_image_path'].iloc[0]
sample_dermoscopic = train_df['dermoscopic_image_path'].iloc[0]

print(f"\n🔍 Verifying image files...")
if os.path.exists(sample_clinical):
    print(f"✅ Sample clinical image exists!")
    print(f"   Path: {sample_clinical}")
else:
    print(f"⚠️ WARNING: Clinical image not found at: {sample_clinical}")
    # Try to diagnose the issue
    print(f"\n🔧 Debugging:")
    print(f"  - WORK_DIR: {WORK_DIR}")
    symlink_path = f'{WORK_DIR}/dataset/MILK10k_Training_Input'
    if os.path.islink(symlink_path):
        print(f"  - Dataset symlink target: {os.readlink(symlink_path)}")
    else:
        print(f"  - Not a symlink, checking directory: {os.path.exists(symlink_path)}")
    
    # Check lesion directory
    lesion_id = sample_clinical.split('/')[-2]
    lesion_dir = f'{WORK_DIR}/dataset/MILK10k_Training_Input/{lesion_id}'
    print(f"  - Lesion directory exists: {os.path.exists(lesion_dir)}")
    if os.path.exists(lesion_dir):
        print(f"  - Files in lesion dir: {os.listdir(lesion_dir)}")
    
if os.path.exists(sample_dermoscopic):
    print(f"✅ Sample dermoscopic image exists!")
    print(f"   Path: {sample_dermoscopic}")
else:
    print(f"⚠️ WARNING: Dermoscopic image not found at: {sample_dermoscopic}")
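
To see what the path rewrite does without loading any data, the regex from `extract_relative_path` can be exercised on the two example paths from its docstring. This is a standalone sketch: `to_colab_path` is a hypothetical, simplified stand-in that omits the fallback branch and uses `posixpath` so the result is the same on any operating system.

```python
import posixpath
import re

# Same pattern as in extract_relative_path: lesion folder, separator, image file.
PATTERN = re.compile(r'(IL_\d+)[/\\](ISIC_\d+\.jpg)')

def to_colab_path(path_str, dataset_root):
    """Rebuild an image path under dataset_root (simplified, no fallback branch)."""
    match = PATTERN.search(path_str)
    if match is None:
        raise ValueError(f"Cannot extract lesion_id and image from path: {path_str}")
    lesion_id, image_file = match.groups()
    return posixpath.join(dataset_root, 'MILK10k_Training_Input', lesion_id, image_file)

root = '/content/MILK10k/dataset'
windows = r'D:\PYTHON\DEEP_LEARNING\dataset\MILK10k_Training_Input\IL_8583674\ISIC_8570261.jpg'
linux = '/some/path/IL_8583674/ISIC_8570261.jpg'

print(to_colab_path(windows, root))
print(to_colab_path(linux, root))
# Both print: /content/MILK10k/dataset/MILK10k_Training_Input/IL_8583674/ISIC_8570261.jpg

# Re-running on an already converted path leaves it unchanged.
converted = to_colab_path(windows, root)
print(to_colab_path(converted, root) == converted)  # True
```

The last check matters because the cell above overwrites the CSV files with converted paths: applying the fix again in a later session does not corrupt them.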

## 7️⃣ Create DataLoaders

Create training and validation dataloaders with augmentation.

In [None]:
# Adjust batch size and workers for Colab Pro with A100
BATCH_SIZE = 32  # Increase for A100 (reduce to 16 or 8 if OOM)
NUM_WORKERS = 4  # A100 instance has more CPU cores

print(f"Creating dataloaders...")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Num workers: {NUM_WORKERS}")

# Create dataloaders
train_loader, val_loader = get_dataloaders(
    train_df,
    val_df,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    image_size=IMAGE_CONFIG['image_size'],
    fusion_strategy=IMAGE_CONFIG['fusion_strategy'],
    use_metadata=MODEL_CONFIG['use_metadata']
)

print(f"\n✅ Train DataLoader: {len(train_loader)} batches")
print(f"✅ Val DataLoader: {len(val_loader)} batches")

# Test dataloader
print(f"\nTesting dataloader...")
for batch in train_loader:
    if len(batch) == 3:
        images, labels, metadata = batch
        print(f"  Images shape: {images.shape}")
        print(f"  Labels shape: {labels.shape}")
        print(f"  Metadata shape: {metadata.shape}")
    else:
        images, labels = batch
        print(f"  Images shape: {images.shape}")
        print(f"  Labels shape: {labels.shape}")
    break
print("✅ DataLoader test successful!")
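
As a sanity check on the batch counts printed above, the expected numbers follow directly from the dataset sizes in the Overview and `BATCH_SIZE`. The sketch below assumes the sizes 4,192 and 1,048 and shows both cases for `drop_last`, since the notebook does not show which setting `get_dataloaders` uses; `batches_per_epoch` is a hypothetical helper.

```python
import math

def batches_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches a loader yields per epoch."""
    if drop_last:
        return num_samples // batch_size       # incomplete final batch is discarded
    return math.ceil(num_samples / batch_size)  # incomplete final batch is kept

print(batches_per_epoch(4192, 32))                  # 131
print(batches_per_epoch(1048, 32))                  # 33
print(batches_per_epoch(1048, 32, drop_last=True))  # 32
```

If the printed counts differ from these, the CSV files likely contain a different number of rows than the Overview states.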

## 8️⃣ Create Model

Initialize EfficientNet-B3 model with metadata fusion.

In [None]:
# Get device
device = get_device()

# Create model
print("Creating model...")
model = create_model(
    architecture=MODEL_CONFIG['architecture'],
    num_classes=len(DIAGNOSIS_CATEGORIES),
    pretrained=MODEL_CONFIG['pretrained'],
    fusion_strategy=IMAGE_CONFIG['fusion_strategy'],
    use_metadata=MODEL_CONFIG['use_metadata'],
    metadata_dim=MODEL_CONFIG['metadata_dim'],
    dropout=MODEL_CONFIG['dropout']
)

model = model.to(device)

# Count parameters
total_params, trainable_params = count_parameters(model)

print(f"\n✅ Model: {MODEL_CONFIG['architecture']}")
print(f"✅ Fusion: {IMAGE_CONFIG['fusion_strategy']}")
print(f"✅ Metadata: {MODEL_CONFIG['use_metadata']}")
print(f"✅ Device: {device}")
print(f"✅ Parameters: {total_params:,} (Trainable: {trainable_params:,})")

## 9️⃣ Initialize Training Components

Setup loss function, optimizer, scheduler, and trainer.

In [None]:
from src.train import Trainer

# Set random seed for reproducibility
set_seed(TRAIN_CONFIG['random_seed'])

# Update checkpoint and log directories to save in Google Drive
TRAIN_CONFIG['checkpoint_dir'] = f'{DRIVE_ROOT}/models'
TRAIN_CONFIG['log_dir'] = f'{DRIVE_ROOT}/logs'

# Create directories
os.makedirs(TRAIN_CONFIG['checkpoint_dir'], exist_ok=True)
os.makedirs(TRAIN_CONFIG['log_dir'], exist_ok=True)

print(f"Creating trainer...")
print(f"  Checkpoint dir: {TRAIN_CONFIG['checkpoint_dir']}")
print(f"  Log dir: {TRAIN_CONFIG['log_dir']}")

# Create trainer
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=class_weights,
    device=device
)

print(f"✅ Trainer initialized successfully!")
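
The Overview names Focal Loss with class weights as the training loss. The actual implementation lives in the project's `src` code and is not shown here, so the following is only a pure-Python sketch of the general idea for a multi-label setup, assuming sigmoid outputs, one weight per class, and a focusing parameter `gamma`. The function name and the example values are hypothetical.

```python
import math

def focal_loss_multilabel(probs, targets, class_weights, gamma=2.0, eps=1e-7):
    """Mean weighted focal loss over the labels of one sample.

    probs         -- per-class sigmoid outputs in (0, 1)
    targets       -- per-class 0/1 labels
    class_weights -- per-class weights (illustrative values, not the real ones)
    """
    total = 0.0
    for p, y, w in zip(probs, targets, class_weights):
        p = min(max(p, eps), 1.0 - eps)   # clamp to avoid log(0)
        p_t = p if y == 1 else 1.0 - p    # probability given to the true label
        # (1 - p_t) ** gamma shrinks the loss of well-classified labels
        total += -w * (1.0 - p_t) ** gamma * math.log(p_t)
    return total / len(probs)

confident = focal_loss_multilabel([0.9, 0.1], [1, 0], [1.0, 1.0])
wrong = focal_loss_multilabel([0.1, 0.9], [1, 0], [1.0, 1.0])
print(f"confident: {confident:.5f}, wrong: {wrong:.5f}")
```

With `gamma=0` and unit weights the expression reduces to ordinary binary cross-entropy; larger `gamma` puts more of the gradient on hard examples, and the class weights raise the contribution of rare diagnoses.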

## 🔟 Load TensorBoard (Optional)

Load TensorBoard extension to monitor training in real-time.

In [None]:
# Load TensorBoard extension
%load_ext tensorboard

# Start TensorBoard (will update during training)
%tensorboard --logdir {TRAIN_CONFIG['log_dir']}

print("✅ TensorBoard loaded! View metrics above during training.")

## 1️⃣1️⃣ Start Training 🚀

**⚠️ IMPORTANT**: This will take several hours. Colab free tier may disconnect after ~12 hours, so a run longer than that has to be resumed from a saved checkpoint.

Expected training time on Colab:
- **T4 GPU**: ~15-20 hours for 100 epochs
- **A100 GPU** (Colab Pro): ~5-8 hours

The training will:
- Save best model to Google Drive automatically
- Save checkpoints every 5 epochs
- Stop early if no improvement for 15 epochs
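
The early-stopping rule in the list above can be sketched as follows. This is an illustration, not the project's `Trainer` logic: the `EarlyStopping` class is hypothetical, and it assumes a metric where higher is better (for example validation macro F1). The notebook does not state which metric the real trainer monitors.

```python
class EarlyStopping:
    """Signal a stop after `patience` consecutive epochs without improvement."""

    def __init__(self, patience=15, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float('-inf')
        self.bad_epochs = 0

    def step(self, metric):
        """Record one epoch's metric; return True when training should stop."""
        if metric > self.best + self.min_delta:
            self.best = metric
            self.bad_epochs = 0      # improvement resets the counter
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Short demonstration with patience=3 instead of 15.
stopper = EarlyStopping(patience=3)
scores = [0.50, 0.55, 0.54, 0.55, 0.53, 0.60]
for epoch, score in enumerate(scores, start=1):
    if stopper.step(score):
        print(f"Stopping at epoch {epoch}, best score {stopper.best:.2f}")
        break
# prints: Stopping at epoch 5, best score 0.55
```

Note that the run stops before the 0.60 at epoch 6 is ever seen, which is the usual cost of a short patience; a patience of 15 makes this much less likely.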

In [None]:
# Start training
print("🚀 Starting training...")
print("⚠️ This will take several hours. Don't close the browser tab!")
print("="*60)

# Train the model
history = trainer.train()

print("\n" + "="*60)
print("🎉 TRAINING COMPLETED!")
print("="*60)

## 1️⃣2️⃣ View Training Results

Analyze training history and visualize performance.

In [None]:
import matplotlib.pyplot as plt

# Load training history
history_df = pd.read_csv(f'{TRAIN_CONFIG["checkpoint_dir"]}/training_history.csv')

print("="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"\nTotal epochs: {len(history_df)}")
print(f"Best Macro F1: {history_df['val_f1_macro'].max():.4f}")
print(f"Best Micro F1: {history_df['val_f1_micro'].max():.4f}")
print(f"Final Train Loss: {history_df['train_loss'].iloc[-1]:.4f}")
print(f"Final Val Loss: {history_df['val_loss'].iloc[-1]:.4f}")

# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Loss curves
axes[0, 0].plot(history_df['train_loss'], label='Train Loss', linewidth=2)
axes[0, 0].plot(history_df['val_loss'], label='Val Loss', linewidth=2)
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].set_title('Training and Validation Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# F1 scores
axes[0, 1].plot(history_df['val_f1_macro'], label='Macro F1', linewidth=2)
axes[0, 1].plot(history_df['val_f1_micro'], label='Micro F1', linewidth=2)
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('F1 Score')
axes[0, 1].set_title('Validation F1 Scores')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# Learning rate
axes[1, 0].plot(history_df['learning_rate'], linewidth=2, color='green')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Learning Rate')
axes[1, 0].set_title('Learning Rate Schedule')
axes[1, 0].grid(True, alpha=0.3)

# Loss comparison
best_epoch = history_df['val_f1_macro'].idxmax()
axes[1, 1].bar(['Train Loss', 'Val Loss'], 
               [history_df.loc[best_epoch, 'train_loss'], 
                history_df.loc[best_epoch, 'val_loss']])
axes[1, 1].set_title(f'Loss at Best Epoch ({best_epoch+1})')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.savefig(f'{DRIVE_ROOT}/training_curves.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✅ Training curves saved to: {DRIVE_ROOT}/training_curves.png")
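
The summary reports both macro and micro F1, which can differ a lot on imbalanced multi-label data. The sketch below is a small pure-Python illustration of the difference on made-up labels; it is not the project's evaluation code in `src/evaluate.py`, and the helper names are hypothetical.

```python
def f1_from_counts(tp, fp, fn):
    """F1 from true positives, false positives and false negatives."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0

def macro_micro_f1(y_true, y_pred):
    """y_true, y_pred: lists of equal-length 0/1 label rows, one row per sample."""
    num_classes = len(y_true[0])
    counts = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 1)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 0 and p[c] == 1)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t[c] == 1 and p[c] == 0)
        counts.append((tp, fp, fn))
    # Macro: average the per-class F1 scores, every class counts equally.
    macro = sum(f1_from_counts(*c) for c in counts) / num_classes
    # Micro: pool the counts over all classes, frequent classes dominate.
    micro = f1_from_counts(*(sum(col) for col in zip(*counts)))
    return macro, micro

# Class 0 is common and always right; class 1 is rare and always missed.
y_true = [[1, 0], [1, 0], [1, 0], [0, 1]]
y_pred = [[1, 0], [1, 0], [1, 0], [0, 0]]
macro, micro = macro_micro_f1(y_true, y_pred)
print(f"macro={macro:.3f} micro={micro:.3f}")  # macro=0.500 micro=0.857
```

A large gap between the two curves in the F1 plot therefore usually points at weak performance on the rare diagnosis categories.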

## 1️⃣3️⃣ Save Model Info

Document training results for team reference.

In [None]:
from datetime import datetime

# Create model info file
best_epoch = history_df['val_f1_macro'].idxmax()
best_macro_f1 = history_df['val_f1_macro'].max()
best_micro_f1 = history_df.loc[best_epoch, 'val_f1_micro']

model_info = f"""# Training Results

**Date**: {datetime.now().strftime('%Y-%m-%d %H:%M')}
**GPU**: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}
**Training Time**: Check notebook execution time
**Total Epochs**: {len(history_df)}
**Best Epoch**: {best_epoch + 1}
**Best Macro F1**: {best_macro_f1:.4f}
**Micro F1 (at best epoch)**: {best_micro_f1:.4f}

## Model Configuration
- Architecture: {MODEL_CONFIG['architecture']}
- Fusion Strategy: {IMAGE_CONFIG['fusion_strategy']}
- Use Metadata: {MODEL_CONFIG['use_metadata']}
- Image Size: {IMAGE_CONFIG['image_size']}
- Batch Size: {BATCH_SIZE}

## Files
- Best Model: `{TRAIN_CONFIG['checkpoint_dir']}/best_model.pth`
- Training History: `{TRAIN_CONFIG['checkpoint_dir']}/training_history.csv`
- Training Curves: `{DRIVE_ROOT}/training_curves.png`

## Next Steps
1. Download best_model.pth from Google Drive
2. Run inference on test set
3. Generate submission file
4. Share results with team
"""

# Save model info
info_path = f'{DRIVE_ROOT}/MODEL_INFO.md'
with open(info_path, 'w') as f:
    f.write(model_info)

print(model_info)
print(f"\n✅ Model info saved to: {info_path}")

## 1️⃣4️⃣ Download Trained Model

Download the trained model and results to your local machine.

In [None]:
from google.colab import files

# Option 1: Download directly (may be slow for large files)
print("Downloading files...")
print("⚠️ This may take a while for large model files")

# Download best model
try:
    files.download(f'{TRAIN_CONFIG["checkpoint_dir"]}/best_model.pth')
    print("✅ Downloaded: best_model.pth")
except Exception as e:
    print(f"⚠️ Could not download model: {e}")
    print(f"📁 Access it in Google Drive: {TRAIN_CONFIG['checkpoint_dir']}/best_model.pth")

# Download training history
try:
    files.download(f'{TRAIN_CONFIG["checkpoint_dir"]}/training_history.csv')
    print("✅ Downloaded: training_history.csv")
except Exception as e:
    print(f"⚠️ Could not download history: {e}")

# Download training curves
try:
    files.download(f'{DRIVE_ROOT}/training_curves.png')
    print("✅ Downloaded: training_curves.png")
except Exception as e:
    print(f"⚠️ Could not download curves: {e}")

print("\n" + "="*60)
print("📦 All files are also saved in Google Drive:")
print(f"  📁 {DRIVE_ROOT}/")
print("="*60)

---

## 🎉 Training Complete!

### What to do next:

1. **Download the trained model** from Google Drive (`best_model.pth`)
2. **Share with your team** via Drive link or upload to repository
3. **Run inference** on test set using `src/inference.py` or `src/generate_submission.py`
4. **Document results** in README and team chat

### Tips for Colab Training:

- **Runtime disconnections**: Colab free tier may disconnect after ~12 hours. Use Colab Pro for longer sessions.
- **Checkpoints**: Models are saved every 5 epochs to Google Drive, so you can resume if disconnected.
- **GPU memory**: If you get OOM errors, reduce `BATCH_SIZE` to 8 or `IMAGE_CONFIG['image_size']` to 224.
- **Cost**: Colab Pro (~$10/month) gives faster GPUs (A100) and longer runtime.

### Resume Training (if interrupted):

```python
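# Load checkpoint and continue training.
# Replace XX with the epoch number of the checkpoint to resume from.
checkpoint_path = f'{TRAIN_CONFIG["checkpoint_dir"]}/checkpoint_epoch_XX.pth'
# NOTE: `optimizer` is not defined in this notebook; pass the optimizer that the
# Trainer uses (how to reach it depends on the Trainer implementation in src/train.py).
checkpoint = load_checkpoint(model, optimizer, checkpoint_path, device=device)
# Then run trainer.train() again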
```
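
To avoid looking up the `XX` by hand, the most recent checkpoint can be picked automatically. The sketch below assumes the files are named `checkpoint_epoch_<N>.pth`, which is inferred from the placeholder above rather than confirmed by the trainer code; `latest_checkpoint` is a hypothetical helper.

```python
import re
import tempfile
from pathlib import Path

def latest_checkpoint(checkpoint_dir):
    """Return the checkpoint_epoch_<N>.pth path with the highest N, or None."""
    pattern = re.compile(r'checkpoint_epoch_(\d+)\.pth$')
    best_epoch, best_path = -1, None
    for path in Path(checkpoint_dir).glob('checkpoint_epoch_*.pth'):
        match = pattern.search(path.name)
        # Compare epochs as integers: as strings, "5" would sort after "15".
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path

# Demonstration with empty placeholder files in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    for epoch in (5, 10, 15):
        (Path(tmp) / f'checkpoint_epoch_{epoch}.pth').touch()
    print(latest_checkpoint(tmp).name)  # checkpoint_epoch_15.pth
```

In the notebook this would be called with `TRAIN_CONFIG['checkpoint_dir']`, and the returned path passed on as `checkpoint_path`.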

---

**Happy Training! 🚀**