# 02 — Model Training
Train EfficientNetV2-S and ResNet50 on cleaned Fitzpatrick17k data.
Two-phase training: frozen backbone then fine-tuned.

In [None]:
# Uncomment for Colab
# !pip install -q wandb timm

In [None]:
import sys
import torch
import pandas as pd
import numpy as np
import wandb
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from pathlib import Path

from src.data.dataset import FitzpatrickDataset
from src.data.transforms import get_train_transforms, get_eval_transforms
from src.models.classifier import SkinToneClassifier
from src.training.config import TrainingConfig
from src.training.trainer import Trainer, compute_class_weights
from src.utils.logging import init_wandb

In [None]:
# Configuration: hyperparameters and run settings for this notebook.
config = TrainingConfig(
    # --- model ---
    backbone="efficientnet_v2_s",
    num_classes=3,
    pretrained=True,
    # --- two-phase schedule: frozen backbone first, then fine-tuning ---
    # (the unfreeze step is presumably performed by Trainer -- confirm)
    freeze_backbone=True,
    unfreeze_after_epochs=5,
    # --- optimisation ---
    epochs=20,
    batch_size=32,
    learning_rate=1e-4,
    early_stopping_patience=5,
    use_class_weights=True,
    # --- experiment tracking ---
    wandb_project="skin-tone-classifier",
)

# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# Input locations, relative to the notebook's working directory.
IMAGE_DIR = "data/images"
DATA_DIR = "data/cleaned"

In [None]:
# Load the train/validation split CSVs stored under DATA_DIR.
split_root = Path(DATA_DIR)
train_df = pd.read_csv(split_root / "train.csv")
val_df = pd.read_csv(split_root / "val.csv")

print(f"Train: {len(train_df)}, Val: {len(val_df)}")
# Per-group row counts, ordered by group value, to eyeball class balance.
group_counts = train_df["skin_tone_group"].value_counts().sort_index()
print(f"Train distribution:\n{group_counts}")

In [None]:
# Create datasets and loaders.
# Training samples use the train-time transforms (presumably augmenting);
# validation samples use the eval transforms.
train_transform = get_train_transforms(config.image_size)
eval_transform = get_eval_transforms(config.image_size)

train_dataset = FitzpatrickDataset(train_df, IMAGE_DIR, transform=train_transform)
val_dataset = FitzpatrickDataset(val_df, IMAGE_DIR, transform=eval_transform)

# Pinned (page-locked) host memory only speeds up host->GPU copies. On a
# CPU-only machine PyTorch ignores it and emits a warning, so enable it
# only when training on CUDA.
use_pinned_memory = DEVICE == "cuda"

train_loader = DataLoader(
    train_dataset, batch_size=config.batch_size,
    shuffle=True,  # reshuffle the training data every epoch
    num_workers=config.num_workers, pin_memory=use_pinned_memory,
)
val_loader = DataLoader(
    val_dataset, batch_size=config.batch_size,
    shuffle=False,  # fixed order keeps validation passes comparable
    num_workers=config.num_workers, pin_memory=use_pinned_memory,
)

In [None]:
# Compute per-class loss weights from the training labels (the weighting
# scheme itself is defined by compute_class_weights).
# Use config.num_classes instead of a literal so the weight vector always
# has the same length as the classifier's output layer.
labels = train_df["skin_tone_label"].tolist()
weights = compute_class_weights(labels, num_classes=config.num_classes)
class_weights = torch.tensor(weights, dtype=torch.float32)
print(f"Class weights: {weights}")

In [None]:
# Initialize model; optionally freeze the pretrained backbone for phase one.
model = SkinToneClassifier(
    backbone_name=config.backbone,
    num_classes=config.num_classes,
    pretrained=config.pretrained,
)
if config.freeze_backbone:
    model.freeze_backbone()

# Count all parameters and the subset that will receive gradients,
# in a single pass over model.parameters().
total_params = 0
trainable_params = 0
for param in model.parameters():
    n_elems = param.numel()
    total_params += n_elems
    if param.requires_grad:
        trainable_params += n_elems

print(f"Total params: {total_params:,}")
print(f"Trainable params: {trainable_params:,} ({trainable_params/total_params:.1%})")

In [None]:
# Initialize W&B. The run name encodes backbone, learning rate and batch
# size so runs can be told apart at a glance.
run_name = f"{config.backbone}_lr{config.learning_rate}_bs{config.batch_size}"
run = init_wandb(
    project=config.wandb_project,
    config=vars(config),
    run_name=run_name,
    tags=["milestone1", config.backbone],
)

In [None]:
# Train. Class weights are handed to the Trainer only when enabled in config.
if config.use_class_weights:
    loss_class_weights = class_weights
else:
    loss_class_weights = None

trainer = Trainer(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
    class_weights=loss_class_weights,
    device=DEVICE,
    wandb_run=run,
)

# The plotting cell reads history["train"] / history["val"] as per-epoch
# metric dicts; the exact structure is defined by Trainer.train -- confirm.
history = trainer.train()

In [None]:
# Training curves: one panel per metric, train vs. validation over epochs.
panel_specs = (
    ("loss", "Loss"),
    ("accuracy", "Accuracy"),
    ("f1", "Macro F1"),
)
fig, axes = plt.subplots(1, len(panel_specs), figsize=(18, 5))

# x-axis is 1-based epoch numbers, sized from the training history.
epochs = range(1, len(history["train"]) + 1)

for ax, (metric_key, panel_title) in zip(axes, panel_specs):
    ax.plot(epochs, [m[metric_key] for m in history["train"]], label="Train")
    ax.plot(epochs, [m[metric_key] for m in history["val"]], label="Val")
    ax.set_title(panel_title)
    ax.set_xlabel("Epoch")
    ax.legend()

plt.tight_layout()
plt.show()

In [None]:
# Save model artifact: write the weights locally, then hand the file to W&B.
# NOTE(review): this is the model's state after trainer.train() returns;
# whether that is the best-validation epoch or simply the last one depends
# on Trainer -- confirm.
checkpoint_dir = Path("checkpoints")
checkpoint_dir.mkdir(exist_ok=True)

checkpoint_path = f"checkpoints/{config.backbone}_final.pt"
torch.save(model.state_dict(), checkpoint_path)
wandb.save(checkpoint_path)
print(f"Model saved to {checkpoint_path}")

In [None]:
# Finish W&B run: close the active run now that training and the artifact
# upload above are done.
wandb.finish()

print("Training complete!")

## Train ResNet50
To train ResNet50, set `backbone="resnet50"` in the configuration cell and re-run the cells above, or follow the commented template in the cell below.

In [None]:
# To train ResNet50, create a new config and repeat the training:
# config_resnet = TrainingConfig(backbone="resnet50", ...)
# Then re-run the model init, trainer, and training cells above.