# RAFDB Metric Learning Training Notebook

This notebook demonstrates how to train a metric learning model on RAFDB with:
- Augmentation (crop, rotate, grayscale)
- Triplet margin mining
- MPerClassSampler for balanced batches
- wandb logging for experiment tracking
- Real-time triplet visualization

**Key differences from main.py:**
- Command-line arguments replaced with cell variables
- Interactive execution with visual feedback
- Better for debugging and experimentation
- Visualizations display inline

In [None]:
# ============================================================================
# SECTION 1: Import Required Modules
# ============================================================================

import sys
from pathlib import Path
import os

# Add current directory to path for relative imports
sys.path.insert(0, str(Path.cwd()))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

import wandb
from pytorch_metric_learning.miners import TripletMarginMiner
from pytorch_metric_learning.losses import TripletMarginLoss
from pytorch_metric_learning.samplers import MPerClassSampler
from tqdm import tqdm

# Import custom modules
from dataset import train_val_split, getdataset_from_imagefolder
from train import train_one_epoch
from evalute import evaluate_one_epoch, final_eval
from visualize import (
    compute_embeddings_and_predictions,
    compute_confusion_matrix,
    visualize_confusion_matrix,
    visualize_pca,
)
from model import EmbeddingModel, validate_model_and_augmentation

# Reaching this point means every import above succeeded; also report the
# installed PyTorch version and whether CUDA is usable on this machine.
print("✓ All imports successful")
print(f"✓ PyTorch version: {torch.__version__}")
print(f"✓ CUDA available: {torch.cuda.is_available()}")

## Configuration: Replace command-line arguments with variables

Instead of parsing `sys.argv`, define parameters directly below. Modify these cells to match your setup.

In [None]:
# ============================================================================
# SECTION 2: Convert Command-Line Arguments to Variables
# ============================================================================

# Dataset arguments
RAFDB_ROOT = "/path/to/rafdb"  # ← CHANGE THIS to your RAFDB path
VAL_FRACTION = 0.2
SEED = 42
# Number of expression classes assumed for RAFDB. Only used by the early sanity
# check in this cell; the DataLoader cell recounts classes from the real labels.
RAFDB_NUM_CLASSES = 7

# Augmentation arguments
NUM_AUGMENTATIONS = 2
CROP_SCALE = (0.8, 1.0)
ROTATION_DEGREES = 15
GRAYSCALE_PROB = 0.3

# Training arguments
NUM_EPOCHS = 50
BATCH_SIZE = 28  # ← IMPORTANT: For RAFDB (7 classes), use batch_size = m * num_classes
LEARNING_RATE = 1e-3
LR_STEP = 10
NUM_WORKERS = 4

# Model arguments
EMBEDDING_DIM = 128

# Metric learning arguments
MARGIN = 0.1
METRIC = "euclidean"  # or "cosine"
M_PER_CLASS = 4  # ← samples per class in each batch

# MPerClassSampler checks two things when it is given a batch_size:
#   1. batch_size <= m * num_classes
#   2. batch_size is a multiple of m
# For RAFDB with M_PER_CLASS=4: batch_size <= 4 * 7 = 28 and batch_size % 4 == 0
# If you want larger batch_size, either:
#   - Increase M_PER_CLASS (e.g., M_PER_CLASS=6 → batch_size <= 42)
#   - Use smaller batch_size (e.g., batch_size=20)

# Logging/output arguments
USE_WANDB = False  # Set to True if you have wandb configured
OUTPUT_DIR = "./outputs"

# Device setup
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Largest batch the sampler accepts under the assumed class count.
EXPECTED_MAX_BATCH_SIZE = M_PER_CLASS * RAFDB_NUM_CLASSES

print("Configuration:")
print(f"  RAFDB_ROOT: {RAFDB_ROOT}")
print(f"  Device: {DEVICE}")
print(f"  Output dir: {OUTPUT_DIR}")
print(f"  Use wandb: {USE_WANDB}")
print("\n  MPerClassSampler constraint:")
print(f"  - M_PER_CLASS: {M_PER_CLASS}")
print(f"  - Num classes (RAFDB): {RAFDB_NUM_CLASSES}")
print(f"  - Max batch_size: {EXPECTED_MAX_BATCH_SIZE}")
print(f"  - Your batch_size: {BATCH_SIZE}")
if BATCH_SIZE > EXPECTED_MAX_BATCH_SIZE:
    print(f"  ⚠ WARNING: batch_size ({BATCH_SIZE}) > m*num_classes ({EXPECTED_MAX_BATCH_SIZE})")
    print("     This will cause an AssertionError!")
elif BATCH_SIZE % M_PER_CLASS != 0:
    # Second sampler requirement: each batch is filled in groups of m samples.
    print(f"  ⚠ WARNING: batch_size ({BATCH_SIZE}) is not a multiple of M_PER_CLASS ({M_PER_CLASS})")
    print("     This will cause an AssertionError!")
else:
    print("  ✓ Valid configuration")

# Initialize wandb if enabled
if USE_WANDB:
    wandb.init(
        project="rafdb-metric-learning",
        name=f"batch={BATCH_SIZE}_m={M_PER_CLASS}_embed={EMBEDDING_DIM}",
        config={
            "batch_size": BATCH_SIZE,
            "learning_rate": LEARNING_RATE,
            "num_epochs": NUM_EPOCHS,
            "embedding_dim": EMBEDDING_DIM,
            "margin": MARGIN,
            "m_per_class": M_PER_CLASS,
            "metric": METRIC,
            "num_augmentations": NUM_AUGMENTATIONS,
        },
    )
    print("✓ wandb initialized")
else:
    print("⚠ wandb disabled (set USE_WANDB=True to enable)")

## Classes and Helper Functions

The class and function definitions below are copied from main.py. Run this cell before executing the training loop, which depends on them.

In [None]:
# ============================================================================
# SECTION 3: Class and Function Definitions
# ============================================================================

class RAFDBWithAugmentation(Dataset):
    """Image-folder dataset that yields several views of every RAFDB image.

    Each item is a stack of ``1 + num_augmentations`` tensors: first the
    resized and normalized image, then one randomly augmented version per
    augmentation pipeline. All views of an item share that item's label.
    """

    def __init__(
        self,
        root: str,
        split: str = "train",
        num_augmentations: int = 2,
        crop_scale: tuple = (0.8, 1.0),
        rotation_degrees: int = 15,
        grayscale_prob: float = 0.3,
    ):
        """Open ``<root>/<split>`` as an image folder and build the view pipelines.

        Args:
            root: Directory containing one sub-directory per split.
            split: Name of the split sub-directory to read.
            num_augmentations: How many random views to produce per image,
                in addition to the plain view.
            crop_scale: ``scale`` range passed to ``RandomResizedCrop``.
            rotation_degrees: Rotation range passed to ``RandomRotation``.
            grayscale_prob: Probability passed to ``RandomGrayscale``; the
                step is left out entirely when this is not positive.

        Raises:
            FileNotFoundError: If ``<root>/<split>`` does not exist.
        """
        self.root = Path(root)
        self.split = split
        self.num_augmentations = num_augmentations
        self.crop_scale = crop_scale
        self.rotation_degrees = rotation_degrees
        self.grayscale_prob = grayscale_prob

        split_dir = self.root / split
        if not split_dir.exists():
            raise FileNotFoundError(f"RAFDB split directory not found: {split_dir}")

        # Plain view: fixed resize, no randomness.
        resize = transforms.Resize((224, 224))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.base_transform = transforms.Compose([resize, to_tensor, normalize])

        # One independently constructed random pipeline per extra view.
        self.augmentation_transforms = []
        for _ in range(num_augmentations):
            self.augmentation_transforms.append(
                self._create_augmentation_transform()
            )

        # transform=None: the folder hands back untransformed images so that
        # __getitem__ can run every pipeline on the same source image.
        self.dataset = datasets.ImageFolder(str(split_dir), transform=None)

    def _create_augmentation_transform(self):
        """Build one random pipeline: crop, rotate, optional grayscale, normalize."""
        if self.grayscale_prob > 0:
            grayscale_steps = [transforms.RandomGrayscale(p=self.grayscale_prob)]
        else:
            grayscale_steps = []

        pipeline = [
            transforms.RandomResizedCrop(224, scale=self.crop_scale),
            transforms.RandomRotation(self.rotation_degrees),
            *grayscale_steps,
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ]
        return transforms.Compose(pipeline)

    def __len__(self):
        """Return the number of source images (not the number of views)."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``(views, label)`` where ``views`` stacks all views on dim 0."""
        image, target = self.dataset[idx]
        # The plain view always comes first, followed by the random views in
        # pipeline order.
        views = [self.base_transform(image)]
        views.extend(
            augment(image) for augment in self.augmentation_transforms
        )
        return torch.stack(views, dim=0), target


def collate_fn_with_augmentations(batch):
    """Flatten a batch of ``(views, label)`` pairs into one image batch.

    Every view of every sample becomes its own row in the output, and the
    sample's label is repeated once per view so rows and labels stay aligned.

    Returns:
        ``(images, labels)``: the stacked views, and a ``torch.long`` tensor
        holding one label per view.
    """
    # Iterating a tensor walks its first dimension, i.e. one view at a time.
    flat_views = [view for views, _ in batch for view in views]
    repeated_labels = [
        label for views, label in batch for _ in range(views.shape[0])
    ]

    images = torch.stack(flat_views, dim=0)
    labels = torch.tensor(repeated_labels, dtype=torch.long)
    return images, labels


def visualize_triplet_mining(
    model: nn.Module,
    train_loader: DataLoader,
    miner,
    device: torch.device,
    num_samples: int = 3,
):
    """Visualize triplet mining results."""
    model.eval()
    print("\n" + "=" * 70)
    print("VISUALIZING TRIPLET MINING RESULTS")
    print("=" * 70)

    with torch.no_grad():
        batch_images, batch_labels = next(iter(train_loader))
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        embeddings = model(batch_images)
        anchor_idx, positive_idx, negative_idx = miner(embeddings, batch_labels)

        if anchor_idx.numel() == 0:
            print("⚠ Miner returned 0 triplets. Skipping visualization.")
            return

        print(f"✓ Mined {anchor_idx.numel()} triplets from batch of {batch_images.shape[0]} images")

        num_to_show = min(num_samples, anchor_idx.numel())

        for triplet_num in range(num_to_show):
            a_idx = anchor_idx[triplet_num].item()
            p_idx = positive_idx[triplet_num].item()
            n_idx = negative_idx[triplet_num].item()

            a_label = batch_labels[a_idx].item()
            p_label = batch_labels[p_idx].item()
            n_label = batch_labels[n_idx].item()

            a_img = batch_images[a_idx]
            p_img = batch_images[p_idx]
            n_img = batch_images[n_idx]

            mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to(device)
            std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to(device)

            a_img_vis = (a_img * std + mean).clamp(0, 1).cpu().permute(1, 2, 0).numpy()
            p_img_vis = (p_img * std + mean).clamp(0, 1).cpu().permute(1, 2, 0).numpy()
            n_img_vis = (n_img * std + mean).clamp(0, 1).cpu().permute(1, 2, 0).numpy()

            fig = plt.figure(figsize=(14, 5))
            gs = GridSpec(1, 3, figure=fig, wspace=0.3)

            ax_a = fig.add_subplot(gs[0, 0])
            ax_a.imshow(a_img_vis)
            ax_a.set_title(f"Anchor\nLabel: {a_label}", fontsize=12, fontweight="bold", color="blue")
            ax_a.axis("off")

            ax_p = fig.add_subplot(gs[0, 1])
            ax_p.imshow(p_img_vis)
            status = "✓ Same" if p_label == a_label else "✗ Wrong"
            ax_p.set_title(
                f"Positive\nLabel: {p_label}\n{status}",
                fontsize=12,
                fontweight="bold",
                color="green" if p_label == a_label else "red",
            )
            ax_p.axis("off")

            ax_n = fig.add_subplot(gs[0, 2])
            ax_n.imshow(n_img_vis)
            status = "✓ Different" if n_label != a_label else "✗ Wrong"
            ax_n.set_title(
                f"Negative\nLabel: {n_label}\n{status}",
                fontsize=12,
                fontweight="bold",
                color="green" if n_label != a_label else "red",
            )
            ax_n.axis("off")

            fig.suptitle(
                f"Triplet {triplet_num + 1}/{num_to_show}",
                fontsize=14,
                fontweight="bold",
            )
            plt.tight_layout()
            plt.show()

            print(
                f"  Triplet {triplet_num + 1}: "
                f"Anchor(L={a_label}) -> Pos(L={p_label}) + Neg(L={n_label}) | "
                f"Valid: {p_label == a_label and n_label != a_label}"
            )

    print("=" * 70)


# Printed once the class and helper definitions above have been executed.
print("✓ All classes and functions defined")

## Step 1: Setup and Initialization

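Create the output directory, load the train and test datasets, validate the model and augmentation pipeline, and split the training data into train and validation sets.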

In [None]:
# ============================================================================
# SECTION 4: Setup and Load Dataset
# ============================================================================

# Make sure the results folder exists before anything writes into it
output_dir = Path(OUTPUT_DIR)
output_dir.mkdir(parents=True, exist_ok=True)
print(f"✓ Output directory: {output_dir}")

# Both splits are built with exactly the same augmentation settings.
# NOTE(review): this means the test split also yields randomly augmented
# views, so test metrics can vary from run to run — confirm this is intended.
augmentation_settings = {
    "num_augmentations": NUM_AUGMENTATIONS,
    "crop_scale": CROP_SCALE,
    "rotation_degrees": ROTATION_DEGREES,
    "grayscale_prob": GRAYSCALE_PROB,
}

# Training split (later divided into train/val)
print(
    "\n".join(
        [
            "\nLoading RAFDB dataset...",
            f"  RAFDB root: {RAFDB_ROOT}",
            f"  Augmentations: {NUM_AUGMENTATIONS} views per sample",
            f"  Crop scale: {CROP_SCALE}",
            f"  Rotation: ±{ROTATION_DEGREES}°",
            f"  Grayscale prob: {GRAYSCALE_PROB}",
        ]
    )
)
dataset = RAFDBWithAugmentation(root=RAFDB_ROOT, split="train", **augmentation_settings)
print(f"✓ Dataset loaded: {len(dataset)} samples")

# Held-out test split
print("\nLoading test dataset...")
test_dataset = RAFDBWithAugmentation(root=RAFDB_ROOT, split="test", **augmentation_settings)
print(f"✓ Test dataset loaded: {len(test_dataset)} samples")

# Validate augmentation and model
print("\n" + "=" * 70)
print("STEP 1: Validating augmentation and model...")
print("=" * 70)
# Throwaway model used only for this sanity check.
# NOTE(review): this builds a resnet50 without normalize=True, while the model
# that is actually trained later is a resnet18 with normalize=True — confirm
# whether the check should use the training configuration instead.
temp_model = EmbeddingModel(
    model_name="resnet50",
    embedding_dim=EMBEDDING_DIM,
    pretrained=True,
)
if not validate_model_and_augmentation(temp_model, dataset, num_samples=4, device=DEVICE):
    print("\n❌ Validation failed. Exiting.")
    raise RuntimeError("Model validation failed")
# Drop the reference so the validation model can be garbage-collected. The name
# must not be read after this point; the training model is created separately.
del temp_model
print("\n" + "=" * 70)
print("STEP 2: Preparing train/val split...")
print("=" * 70)

# Split the training dataset into train/val parts. VAL_FRACTION is the share
# held out for validation and SEED is passed through to the split helper.
# The DataLoader cell reads `train_ds.indices`, so the returned train part is
# expected to expose the indices it selected from `dataset`.
train_ds, val_ds = train_val_split(
    dataset, val_fraction=VAL_FRACTION, seed=SEED
)
print(f"✓ Train samples: {len(train_ds)}")
print(f"✓ Val samples: {len(val_ds)}")

## Step 2: Create DataLoaders and Model

In [None]:
# Create samplers and DataLoaders

# Labels of the training subset only, looked up through the subset's indices
# into the underlying image folder's target list.
train_labels = torch.tensor([dataset.dataset.targets[i] for i in train_ds.indices])
num_classes = len(torch.unique(train_labels))
max_batch_size = M_PER_CLASS * num_classes

print("\nDataLoader setup:")
print(f"  Num unique classes in train set: {num_classes}")
print(f"  M_PER_CLASS: {M_PER_CLASS}")
print(f"  Max allowed batch_size: {max_batch_size}")
print(f"  Requested batch_size: {BATCH_SIZE}")

# Check the sampler's requirements here so the failure comes with an
# explanation instead of a bare assertion from inside the library.
if BATCH_SIZE > max_batch_size:
    # Smallest m that makes m * num_classes >= BATCH_SIZE (ceiling division).
    suggested_m = -(-BATCH_SIZE // num_classes)
    print(f"\n⚠ ERROR: batch_size ({BATCH_SIZE}) > m*num_classes ({max_batch_size})")
    print("MPerClassSampler requires: batch_size <= m * (number of unique labels)")
    print("\nFix options:")
    print(f"  1. Reduce batch_size to {max_batch_size} or less")
    print(f"  2. Increase M_PER_CLASS (e.g., from {M_PER_CLASS} to {suggested_m})")
    print("  3. Keep original batch_size but use a different sampler (not MPerClassSampler)")
    raise AssertionError(f"batch_size ({BATCH_SIZE}) must be <= m*num_classes ({max_batch_size})")

if BATCH_SIZE % M_PER_CLASS != 0:
    print(f"\n⚠ ERROR: batch_size ({BATCH_SIZE}) is not a multiple of M_PER_CLASS ({M_PER_CLASS})")
    print("MPerClassSampler requires: batch_size % m == 0")
    raise AssertionError(
        f"batch_size ({BATCH_SIZE}) must be a multiple of M_PER_CLASS ({M_PER_CLASS})"
    )

sampler = MPerClassSampler(
    train_labels,
    m=M_PER_CLASS,
    length_before_new_iter=len(train_ds),
    batch_size=BATCH_SIZE,
)

# The sampler decides the order of training samples, so `shuffle` is not set.
train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn_with_augmentations,
)

# Validation and test loaders differ only in the dataset they read.
eval_loader_kwargs = {
    "batch_size": BATCH_SIZE,
    "shuffle": False,
    "num_workers": NUM_WORKERS,
    "collate_fn": collate_fn_with_augmentations,
}
val_loader = DataLoader(val_ds, **eval_loader_kwargs)
test_loader = DataLoader(test_dataset, **eval_loader_kwargs)

print(f"✓ Train DataLoader: {len(train_loader)} batches")
print(f"✓ Val DataLoader: {len(val_loader)} batches")
print(f"✓ Test DataLoader: {len(test_loader)} batches")


# Initialize model
print("\n" + "=" * 70)
print("STEP 3: Initializing embedding model...")
print("=" * 70)
model = EmbeddingModel(
    model_name="resnet18",
    embedding_dim=EMBEDDING_DIM,
    pretrained=True,
    normalize=True,
)
model.to(DEVICE)

# Count only the parameters that will receive gradients.
trainable_param_count = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print(
    "\n".join(
        [
            "✓ Model initialized:",
            "  - Backbone: ResNet18 (pretrained)",
            f"  - Embedding dim: {EMBEDDING_DIM}",
            "  - Normalization: Enabled (L2)",
            f"  - Device: {DEVICE}",
            f"  - Trainable parameters: {trainable_param_count:,}",
        ]
    )
)

# Setup loss and optimizer
loss_fn = TripletMarginLoss(margin=MARGIN, swap=True)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Multiply the learning rate by 0.1 every LR_STEP scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP, gamma=0.1)

# Miner for hard triplet selection; shares the loss margin.
# NOTE(review): METRIC is not forwarded to the loss or the miner here (the
# miner gets distance=None) — confirm that is intended when METRIC="cosine".
miner = TripletMarginMiner(
    margin=MARGIN,
    type_of_triplets="semihard",
    distance=None,
)

print("✓ Loss, optimizer, and miner initialized")

## Step 3: Visualize Triplet Mining

Before training, verify that the miner selects valid triplets: the positive must share the anchor's class and the negative must come from a different class.

In [None]:
# Visualize triplet mining
# Draws up to three mined triplets from the first training batch, using the
# freshly initialized model, so label handling can be checked before training.
print("\n" + "=" * 70)
print("STEP 4A: Visualizing triplet mining...")
print("=" * 70)
visualize_triplet_mining(model, train_loader, miner, DEVICE, num_samples=3)

## Step 4: Training Loop

Run the training loop. You can interrupt this cell to stop training at any time.

In [None]:
# ============================================================================
# SECTION 5: Training Loop
# ============================================================================

print("\n" + "=" * 70)
print("STEP 4B: Starting training...")
print("=" * 70)

# Best validation accuracy seen so far; a checkpoint is written on each new best.
best_accuracy = 0.0
# Per-epoch curves, consumed by the plotting cell.
history = {"train_loss": [], "val_recall_1": [], "val_accuracy": []}

for epoch in range(NUM_EPOCHS):
    # Train
    train_loss = train_one_epoch(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        loss_fn=loss_fn,
        device=DEVICE,
        miner=miner,
        epoch=epoch,
    )

    # Validate
    val_metrics = evaluate_one_epoch(
        model=model,
        val_loader=val_loader,
        device=DEVICE,
        metric=METRIC,
    )

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()

    # Metrics missing from the result dict are reported as 0.
    val_recall_at_1 = val_metrics.get("recall_at_1", 0)
    val_accuracy = val_metrics.get("accuracy", 0)

    # Track history
    history["train_loss"].append(train_loss)
    history["val_recall_1"].append(val_recall_at_1)
    history["val_accuracy"].append(val_accuracy)

    # Print summary
    print(
        f"Epoch {epoch + 1}/{NUM_EPOCHS} | "
        f"Loss: {train_loss:.4f} | "
        f"Recall@1: {val_recall_at_1:.4f} | "
        f"Accuracy: {val_accuracy:.4f}"
    )

    # Log to wandb
    if USE_WANDB:
        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "val_recall_at_1": val_recall_at_1,
            "val_recall_at_5": val_metrics.get("recall_at_5", 0),
            "val_recall_at_10": val_metrics.get("recall_at_10", 0),
            "val_accuracy": val_accuracy,
        })

    # Save best model based on validation accuracy
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        model_path = output_dir / "best_model.pth"
        torch.save(model.state_dict(), model_path)
        print("  ✓ Saved best model")
        if USE_WANDB:
            wandb.save(str(model_path))

print("\n" + "=" * 70)
print("Training complete!")
print("=" * 70)

# The wandb run is deliberately left open here: the evaluation cell that
# follows still calls wandb.log for the test metrics, which fails once the run
# has been finished. Call wandb.finish() only after all logging is done.

## Step 5: Visualization and Analysis

Plot the training history, visualize the validation embeddings (confusion matrix and PCA), and run the final evaluation on the test set.

In [None]:
# Plot training history
history_path = output_dir / "training_history.png"
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
loss_ax, metric_ax = axes[0], axes[1]

# Left panel: training loss per epoch
loss_ax.plot(history["train_loss"], label="Train Loss", marker="o")

# Right panel: validation metrics per epoch
metric_ax.plot(history["val_recall_1"], label="Recall@1", marker="o")
metric_ax.plot(history["val_accuracy"], label="Accuracy", marker="s")

# Decoration shared by both panels
for panel, y_label, panel_title in (
    (loss_ax, "Loss", "Training Loss"),
    (metric_ax, "Metric", "Validation Metrics"),
):
    panel.set_xlabel("Epoch")
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
# Save before show() so the written file contains the rendered figure.
plt.savefig(history_path, dpi=150, bbox_inches="tight")
plt.show()

print(f"✓ Saved training history to {history_path}")

# Final evaluation and visualization
# Embeddings are computed on the validation loader with the model in its
# current state (the weights after the last epoch, not the saved best model).
print("\nFinal evaluation and visualization...")
embeddings, labels = compute_embeddings_and_predictions(
    model, val_loader, device=DEVICE
)

cm = compute_confusion_matrix(embeddings, labels, metric=METRIC)
visualize_confusion_matrix(
    cm, output_path=str(output_dir / "confusion_matrix.png")
)
visualize_pca(embeddings, labels, output_path=str(output_dir / "pca.png"))

# Final test evaluation
print("\n" + "=" * 70)
print("STEP 6: Final evaluation on test set...")
print("=" * 70)
test_metrics = final_eval(
    model, test_loader, device=DEVICE, metric=METRIC
)
# Metrics missing from the result dict are reported as 0.
print(f"Test Recall@1: {test_metrics.get('recall_at_1', 0):.4f}")
print(f"Test Recall@5: {test_metrics.get('recall_at_5', 0):.4f}")
print(f"Test Recall@10: {test_metrics.get('recall_at_10', 0):.4f}")
print(f"Test Accuracy: {test_metrics.get('accuracy', 0):.4f}")

if USE_WANDB:
    if wandb.run is None:
        # wandb.log requires an active run; if the run was already finished
        # (or never started), skip logging instead of crashing this cell.
        print("⚠ No active wandb run; test metrics were not logged to wandb")
    else:
        wandb.log({
            "test_recall_at_1": test_metrics.get("recall_at_1", 0),
            "test_recall_at_5": test_metrics.get("recall_at_5", 0),
            "test_recall_at_10": test_metrics.get("recall_at_10", 0),
            "test_accuracy": test_metrics.get("accuracy", 0),
        })
        # This is the last thing logged in the notebook, so close the run here.
        wandb.finish()
        print("✓ wandb finished")

print(f"\n✓ All results saved to {output_dir}")