<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-Projects/blob/main/Intermediate/PyTorch-CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Label Image Classification (CelebA Facial Attributes)

## Setup & Configuration

In [None]:
# Install torchmetrics, which provides the multilabel metrics used later in this notebook
!pip install torchmetrics

In [None]:
# Import Libraries
import math
import random
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torchmetrics as tm
import torchvision as tv
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CelebA

In [None]:
# Configuration Dictionary: hyper-parameters, paths and runtime options that
# later cells read via cfg["..."].
cfg = dict(
    # --- Reproducibility ---
    seed=42,
    # --- Compute device: the GPU when CUDA is available, otherwise the CPU ---
    device=("cuda" if torch.cuda.is_available() else "cpu"),
    # --- Data pipeline ---
    root_dir="./data",
    image_size=224,
    batch_size=32,
    num_workers=2,
    pin_memory=True,
    num_classes=40,
    # --- Optimisation ---
    epochs=25,
    lr=1e-3,
    weight_decay=1e-4,
    label_smoothing=0.0,
)

In [None]:
# Shorthand for the compute device string ("cuda" or "cpu") selected in cfg
device = cfg["device"]

In [None]:
# Set Seed For Reproducibility: seed every RNG the notebook relies on
# (Python's `random`, NumPy, and PyTorch on CPU and all GPUs).
SEED = cfg["seed"]

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

if torch.cuda.is_available():
    # manual_seed_all seeds every visible GPU, so a separate
    # single-device torch.cuda.manual_seed call is unnecessary.
    torch.cuda.manual_seed_all(SEED)

# Ensure Deterministic Behaviour: request deterministic cuDNN kernels and
# disable the cuDNN autotuner, which may choose different algorithms per run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Transform & Load Data

In [None]:
# Normalization Values (the widely used ImageNet per-channel mean / std)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
image_size = cfg["image_size"]

# Shorter-side resize applied before cropping. It keeps the 256/224 ratio,
# which gives 256 for the default image_size of 224.
resize_size = int(round(image_size * 256 / 224))

# Training Transforms (Augmentation + Normalization)
train_transforms = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Validation & Test Transforms: deterministic resize + center crop.
# Resize(int) only fixes the shorter side, so without the crop a non-square
# image would not come out image_size x image_size. That would mismatch the
# model's flattened feature size, which is computed for square inputs.
val_test_transforms = transforms.Compose([
    transforms.Resize(resize_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

In [None]:
# Load Data
# torchvision's CelebA accepts split in {"train", "valid", "test", "all"};
# "val" is rejected with a ValueError. target_type="attr" (the default, made
# explicit here) returns the per-image binary attribute vector that serves
# as the multi-label target.
root = cfg["root_dir"]

train_dataset = CelebA(
    root=root, split="train", target_type="attr",
    transform=train_transforms, download=True,
)

val_dataset = CelebA(
    root=root, split="valid", target_type="attr",
    transform=val_test_transforms, download=True,
)

test_dataset = CelebA(
    root=root, split="test", target_type="attr",
    transform=val_test_transforms, download=True,
)

In [None]:
# Define DataLoaders
batch_size = cfg["batch_size"]
num_workers = cfg["num_workers"]
# Pinned host memory only speeds up host-to-GPU copies. On a CPU-only machine
# it is useless and PyTorch warns about it, so enable it only when CUDA exists.
pin_memory = cfg["pin_memory"] and torch.cuda.is_available()

# Dedicated, seeded generator for the training loader. It fixes the shuffle
# order and the base seed handed to worker processes (which drive the random
# augmentations), independently of how much global RNG earlier cells consumed.
train_generator = torch.Generator().manual_seed(cfg["seed"])

train_loader = DataLoader(
    train_dataset, batch_size=batch_size,
    shuffle=True, num_workers=num_workers,
    pin_memory=pin_memory, generator=train_generator,
)

val_loader = DataLoader(
    val_dataset, batch_size=batch_size,
    shuffle=False, num_workers=num_workers,
    pin_memory=pin_memory,
)

test_loader = DataLoader(
    test_dataset, batch_size=batch_size,
    shuffle=False, num_workers=num_workers,
    pin_memory=pin_memory,
)

## CNN Class

In [None]:
class MultiLabelCNN(nn.Module):
    """Small three-stage CNN that emits one logit per label.

    Args:
        num_classes: number of independent binary labels (output logits).
        image_size: side length of the square input images. It is used once,
            at construction time, to size the first fully connected layer.

    ``forward`` returns raw logits of shape (batch, num_classes), with no
    sigmoid applied. Pair it with a logits-based loss such as BCEWithLogitsLoss.
    """

    def __init__(self, num_classes, image_size):
        super().__init__()

        # Each stage: 3x3 conv (padding=1 keeps H/W) -> BatchNorm sized to that
        # conv's output channels -> ReLU -> 2x2 max-pool (halves H/W).
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Dynamically calculate the flattened feature size with a dummy pass.
        # Run it in eval mode so the BatchNorm running statistics are not
        # updated by the all-zeros dummy input, then restore the previous mode.
        was_training = self.features.training
        self.features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, 3, image_size, image_size)
            feat_dim = self.features(dummy).flatten(1).shape[1]
        self.features.train(was_training)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feat_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Map images ``x`` of shape (batch, 3, image_size, image_size) to logits."""
        x = self.features(x)
        return self.classifier(x)

## Training & Evaluation Functions

In [None]:
# Define Evaluation Metrics
# The Multilabel* metric classes are defined in torchmetrics.classification
# (the F1 class is MultilabelF1Score; there is no MultilabelF1). They expect
# per-label probabilities (the training loop applies a sigmoid) and integer
# {0, 1} targets of shape (batch, num_classes).
num_classes = cfg["num_classes"]

# Accuracy (thresholded at 0.5)
accuracy = tm.classification.MultilabelAccuracy(
    num_labels=num_classes,
    threshold=0.5,
).to(device)

# F1 Score
f1_macro = tm.classification.MultilabelF1Score(
    num_labels=num_classes,
    average="macro",
    threshold=0.5,
).to(device)

f1_micro = tm.classification.MultilabelF1Score(
    num_labels=num_classes,
    average="micro",
    threshold=0.5,
).to(device)

# Area Under ROC Curve (AUROC)
auroc_macro = tm.classification.MultilabelAUROC(
    num_labels=num_classes,
    average="macro",
).to(device)

# Mean Average Precision (mAP)
map_macro = tm.classification.MultilabelAveragePrecision(
    num_labels=num_classes,
    average="macro",
).to(device)

map_micro = tm.classification.MultilabelAveragePrecision(
    num_labels=num_classes,
    average="micro",
).to(device)

# Keys are the names the training/validation function looks up.
metrics_dict = {
    "accuracy": accuracy,
    "f1_macro": f1_macro,
    "f1_micro": f1_micro,
    "auroc_macro": auroc_macro,
    "map_macro": map_macro,
    "map_micro": map_micro,
}

In [None]:
# Model Training & Validation Function
def _run_train_epoch(model, loader, criterion, optimizer, scheduler, device,
                     accuracy_metric, step_scheduler_per_batch):
    """Run one optimisation pass over ``loader``.

    Returns:
        (mean per-batch training loss, accuracy computed by ``accuracy_metric``)
    """
    model.train()
    accuracy_metric.reset()  # discard any state left over from earlier use
    running_loss = 0.0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # Multi-label losses such as BCEWithLogitsLoss need floating-point
        # targets; the metrics keep receiving the original (integer) labels.
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        if scheduler is not None and step_scheduler_per_batch:
            scheduler.step()

        running_loss += loss.item()
        # update() only accumulates state (no per-batch compute); detach()
        # keeps the metric bookkeeping out of the autograd graph.
        accuracy_metric.update(torch.sigmoid(outputs.detach()), labels)

    if scheduler is not None and not step_scheduler_per_batch:
        scheduler.step()

    epoch_acc = accuracy_metric.compute().item()
    accuracy_metric.reset()
    return running_loss / len(loader), epoch_acc


def _run_val_epoch(model, loader, criterion, device, metrics_dict):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean per-batch validation loss, {metric name: value}) for the six
        tracked metrics. Every metric in ``metrics_dict`` is reset before and
        after the pass.
    """
    names = ("accuracy", "f1_macro", "f1_micro",
             "auroc_macro", "map_macro", "map_micro")
    model.eval()
    for metric in metrics_dict.values():
        metric.reset()
    running_loss = 0.0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            running_loss += criterion(outputs, labels.float()).item()
            preds = torch.sigmoid(outputs)
            for name in names:
                metrics_dict[name].update(preds, labels)

    scores = {name: metrics_dict[name].compute().item() for name in names}
    for metric in metrics_dict.values():
        metric.reset()
    return running_loss / len(loader), scores


def train_val_model(model, train_loader, val_loader, criterion, optimizer,
                    scheduler, num_epochs, device, metrics_dict,
                    step_scheduler_per_batch=True):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Args:
        model: network producing one logit per label.
        train_loader / val_loader: iterables of (images, labels) batches;
            ``len()`` is used to average the loss per batch.
        criterion: loss taking (logits, float targets).
        optimizer: optimizer over ``model``'s parameters.
        scheduler: optional LR scheduler (may be None). By default it is
            stepped after every optimizer step. Pass
            ``step_scheduler_per_batch=False`` for epoch-based schedulers,
            which are then stepped once per epoch.
        num_epochs: number of epochs to run.
        device: device the batches are moved to.
        metrics_dict: metric objects exposing update/compute/reset, keyed by
            "accuracy", "f1_macro", "f1_micro", "auroc_macro", "map_macro",
            "map_micro". "accuracy" also measures training accuracy.

    Returns:
        dict of per-epoch lists: train_loss, val_loss, train_acc, val_acc,
        val_f1_macro, val_f1_micro, val_auroc_macro, val_map_macro,
        val_map_micro.
    """
    # Store History
    history = {
        "train_loss": [], "val_loss": [],
        "train_acc": [], "val_acc": [],
        "val_f1_macro": [], "val_f1_micro": [],
        "val_auroc_macro": [], "val_map_macro": [], "val_map_micro": [],
    }

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_train_epoch(
            model, train_loader, criterion, optimizer, scheduler, device,
            metrics_dict["accuracy"], step_scheduler_per_batch,
        )
        val_loss, val = _run_val_epoch(
            model, val_loader, criterion, device, metrics_dict,
        )

        # Save History
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val["accuracy"])
        for name in ("f1_macro", "f1_micro", "auroc_macro", "map_macro", "map_micro"):
            history[f"val_{name}"].append(val[name])

        # Print progress
        print(
            f"Epoch [{epoch+1}/{num_epochs}] "
            f"Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val['accuracy']:.4f}, "
            f"F1_macro: {val['f1_macro']:.4f}, F1_micro: {val['f1_micro']:.4f}, "
            f"AUROC_macro: {val['auroc_macro']:.4f}, "
            f"mAP_macro: {val['map_macro']:.4f}, mAP_micro: {val['map_micro']:.4f}"
        )

    # Return only once every epoch has run.
    return history


# Alias so callers using the original (misspelled) name keep working.
train_vaL_model = train_val_model