<a href="https://colab.research.google.com/github/Hamza-Ali0237/PyTorch-Projects/blob/main/Intermediate/PyTorch-CelebA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-Label Image Classification

## Setup & Configuration

In [None]:
!pip install torchmetrics datasets pillow

In [None]:
# Import Libraries
import torch
import torch.nn as nn
import torchvision as tv
import torchmetrics as tm
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader

import math
import random
import numpy as np
import seaborn as sns
from PIL import Image
from collections import Counter
import matplotlib.pyplot as plt

In [None]:
# Configuration Dictionary
# Central place for every tunable of the run; the cells below read their
# settings from here instead of hard-coding them.
cfg = {
    # Reproducibility
    "seed": 42,

    # Computing Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Data Settings
    "root_dir": "./data",
    "image_size": 224,
    "batch_size": 32,
    "num_workers": 2,
    # Pinned host memory only speeds up host-to-GPU copies. On a CPU-only
    # runtime DataLoader merely warns about it, so follow CUDA availability.
    "pin_memory": torch.cuda.is_available(),
    "num_classes": 40,

    # Training Hyperparameters
    "epochs": 25,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "label_smoothing": 0.0,
}

In [None]:
# Hyperparameters
# Unpack the frequently used settings from cfg into plain module-level names.
epochs, lr, weight_decay, num_classes, image_size = (
    cfg[setting]
    for setting in ("epochs", "lr", "weight_decay", "num_classes", "image_size")
)

In [None]:
# Setup Device
# Device string ("cuda" or "cpu") selected once in cfg; the model, the metrics
# and every batch are moved to it.
device = cfg["device"]

In [None]:
# Set Seed For Reproducibility
SEED = cfg["seed"]

# Same seed for the PyTorch (CPU), NumPy and Python random number generators
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# GPU generators only exist when CUDA is available
if torch.cuda.is_available():
  torch.cuda.manual_seed(SEED)
  torch.cuda.manual_seed_all(SEED)

# Ensure Deterministic Behaviour
# (no cuDNN auto-tuning; request deterministic cuDNN kernels)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

## Transform & Load Data

In [None]:
# Normalization Values
# Per-channel mean/std (these match the widely used ImageNet statistics)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
image_size = cfg["image_size"]

# Training Transforms (Augmentation + Normalization)
# Shorter edge -> 256, then a random image_size x image_size crop
train_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(
        brightness = 0.1, contrast = 0.1, saturation = 0.1, hue = 0.05
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = mean, std = std
    )
])

# Validation & Test Transforms
# Resize(int) only fixes the shorter edge and keeps the aspect ratio, so a
# non-square image stays non-square. CenterCrop makes every output exactly
# image_size x image_size, which is the input size the model's fully
# connected layer is sized for, and Resize(256) keeps the same scale as the
# training pipeline.
val_test_transforms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize(
        mean = mean, std = std
    )
])

In [None]:
# Load Data (Using HuggingFace Instead Of PyTorch)
# Fetches the "flwrlabs/celeba" dataset through `datasets.load_dataset` and
# prints the returned object; later cells index it with "train", "valid"
# and "test".
celeba_data = load_dataset("flwrlabs/celeba")
print(celeba_data)

In [None]:
# Define Custom Dataset Class
class CelebADataset(Dataset):
  """Expose a HuggingFace CelebA split as a PyTorch multi-label dataset.

  Every column of the split other than "image" and "celeb_id" is treated as
  one binary attribute; the label order follows ``hf_split.column_names``.

  Args:
    hf_split: Split object providing ``column_names``, ``len()`` and integer
      indexing that returns one row as a mapping (column name -> value).
    transform: Optional callable applied to the row's image.
  """

  def __init__(self, hf_split, transform = None):
    self.dataset = hf_split
    self.transform = transform
    self.attr_names = [col for col in self.dataset.column_names if col not in ["image", "celeb_id"]]

  def __len__(self):
    """Return the number of rows in the wrapped split."""
    return len(self.dataset)

  def __getitem__(self, idx):
    """Return ``(image, label)`` for row ``idx``.

    ``label`` is a float32 tensor with one entry per attribute: 1.0 where the
    attribute value equals 1 (or True), 0.0 otherwise.
    """
    # Fetch the row a single time: every lookup on the split builds the whole
    # row (for HuggingFace datasets this typically includes decoding the
    # image), so a per-attribute lookup would repeat that work dozens of times.
    row = self.dataset[idx]

    # Image
    image = row["image"]
    # Compare against None so a transform that happens to be falsy
    # (e.g. an empty container of transforms) is still applied.
    if self.transform is not None:
      image = self.transform(image)

    # Multi-label Target Vector (one entry per attribute column)
    label = torch.tensor([int(row[attr] == 1) for attr in self.attr_names], dtype = torch.float32)

    return image, label

In [None]:
# Create Splits
# Only the training split is augmented; validation and test share the
# evaluation transforms.
train_dataset, val_dataset, test_dataset = (
    CelebADataset(celeba_data[split_name], transform=split_transform)
    for split_name, split_transform in (
        ("train", train_transforms),
        ("valid", val_test_transforms),
        ("test", val_test_transforms),
    )
)

In [None]:
# Define DataLoaders
batch_size, num_workers, pin_memory = (
    cfg["batch_size"], cfg["num_workers"], cfg["pin_memory"]
)

# All three loaders share the same settings; only the training loader
# reshuffles, validation and test keep the dataset order.
train_loader, val_loader, test_loader = (
    DataLoader(
        split_dataset,
        batch_size=batch_size,
        shuffle=reshuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

## Training & Evaluation Functions

In [None]:
# Define Evaluation Metrics
# Each metric is built for `num_classes` labels and moved to `device`; the
# dictionary is what the training / testing functions receive.
metrics_dict = {
    # Accuracy
    "accuracy": tm.classification.MultilabelAccuracy(
        num_labels=num_classes, threshold=0.5
    ).to(device),

    # F1 Score
    "f1_macro": tm.classification.F1Score(
        task="multilabel", num_labels=num_classes, average="macro", threshold=0.5
    ).to(device),
    "f1_micro": tm.classification.F1Score(
        task="multilabel", num_labels=num_classes, average="micro", threshold=0.5
    ).to(device),

    # Area Under ROC Curve (AUROC)
    "auroc_macro": tm.classification.MultilabelAUROC(
        num_labels=num_classes, average="macro"
    ).to(device),

    # Mean Average Precision (mAP)
    "map_macro": tm.classification.MultilabelAveragePrecision(
        num_labels=num_classes, average="macro"
    ).to(device),
    "map_micro": tm.classification.MultilabelAveragePrecision(
        num_labels=num_classes, average="micro"
    ).to(device),
}

# The individual metrics stay reachable under their own module-level names
accuracy = metrics_dict["accuracy"]
f1_macro = metrics_dict["f1_macro"]
f1_micro = metrics_dict["f1_micro"]
auroc_macro = metrics_dict["auroc_macro"]
map_macro = metrics_dict["map_macro"]
map_micro = metrics_dict["map_micro"]

In [None]:
# Model Training & Validation Function
def train_eval_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, metrics_dict):
  """Train `model` for `num_epochs` epochs, validating after every epoch.

  Args:
    model: Network that returns one logit per label.
    train_loader: Iterable of (images, labels) batches supporting len().
    val_loader: Iterable of (images, labels) batches supporting len().
    criterion: Loss, called as criterion(logits, labels).
    optimizer: Optimizer updating the model's parameters.
    scheduler: Learning-rate scheduler stepped once per training batch,
      or None to train without one.
    num_epochs: Number of passes over train_loader.
    device: Device the batches are moved to.
    metrics_dict: Mapping with the keys "accuracy", "f1_macro", "f1_micro",
      "auroc_macro", "map_macro" and "map_micro". Each value is called as
      metric(preds, targets) and must offer compute() and reset().

  Returns:
    Dict of lists with one entry per epoch: "train_loss", "train_acc",
    "val_loss", "val_acc", "val_f1_macro", "val_f1_micro",
    "val_auroc_macro", "val_map_macro", "val_map_micro".
  """
  # Metrics updated on every validation batch
  val_metric_names = (
      "accuracy", "f1_macro", "f1_micro",
      "auroc_macro", "map_macro", "map_micro"
  )

  # Store History
  history = {
      "train_loss": [], "val_loss": [],
      "train_acc": [], "val_acc": [],
      "val_f1_macro": [], "val_f1_micro": [],
      "val_auroc_macro": [], "val_map_macro": [], "val_map_micro": []
  }

  for epoch in range(num_epochs):

    # Training Phase
    model.train()
    epoch_train_loss = 0.0

    for images, labels in train_loader:
      images, labels = images.to(device), labels.to(device)

      optimizer.zero_grad()
      outputs = model(images)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()
      if scheduler is not None:
        scheduler.step()

      epoch_train_loss += loss.item()

      # Metric bookkeeping needs no autograd graph, hence detach().
      # The loss wants float labels, the metrics get integer targets.
      preds = torch.sigmoid(outputs.detach())
      metrics_dict["accuracy"](preds, labels.int())

    epoch_train_loss /= len(train_loader)
    epoch_train_acc = metrics_dict["accuracy"].compute().item()
    # The accuracy metric is reused for validation, so clear the train state
    metrics_dict["accuracy"].reset()

    history["train_loss"].append(epoch_train_loss)
    history["train_acc"].append(epoch_train_acc)

    # Validation Phase
    model.eval()
    epoch_val_loss = 0.0

    with torch.no_grad():
      for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        epoch_val_loss += loss.item()

        preds = torch.sigmoid(outputs)
        # The torchmetrics AUROC / average-precision metrics reject
        # floating-point targets, so hand every metric integer labels.
        targets = labels.int()

        # Update All Validation Metrics
        for name in val_metric_names:
          metrics_dict[name](preds, targets)

    epoch_val_loss /= len(val_loader)
    epoch_val_acc = metrics_dict["accuracy"].compute().item()
    epoch_val_f1_macro = metrics_dict["f1_macro"].compute().item()
    epoch_val_f1_micro = metrics_dict["f1_micro"].compute().item()
    epoch_val_auroc_macro = metrics_dict["auroc_macro"].compute().item()
    epoch_val_map_macro = metrics_dict["map_macro"].compute().item()
    epoch_val_map_micro = metrics_dict["map_micro"].compute().item()

    # Reset Metrics (so the next epoch starts from a clean state)
    for metric in metrics_dict.values():
      metric.reset()

    # Save History
    history["val_loss"].append(epoch_val_loss)
    history["val_acc"].append(epoch_val_acc)
    history["val_f1_macro"].append(epoch_val_f1_macro)
    history["val_f1_micro"].append(epoch_val_f1_micro)
    history["val_auroc_macro"].append(epoch_val_auroc_macro)
    history["val_map_macro"].append(epoch_val_map_macro)
    history["val_map_micro"].append(epoch_val_map_micro)

    # Print Progress
    print(
        f"Epoch [{epoch+1}/{num_epochs}] "
        f"Train Loss: {epoch_train_loss:.4f}, Acc: {epoch_train_acc:.4f} | "
        f"Val Loss: {epoch_val_loss:.4f}, Acc: {epoch_val_acc:.4f}, "
        f"F1_macro: {epoch_val_f1_macro:.4f}, F1_micro: {epoch_val_f1_micro:.4f}, "
        f"AUROC_macro: {epoch_val_auroc_macro:.4f}, "
        f"mAP_macro: {epoch_val_map_macro:.4f}, mAP_micro: {epoch_val_map_micro:.4f}"
    )

  # Return only after ALL epochs have run
  return history

In [None]:
# Plot Training History Function
def plot_train_history(history):
    """Show one figure per tracked quantity, with the epoch on the x-axis.

    The loss and accuracy figures are always drawn, so "train_loss",
    "val_loss", "train_acc" and "val_acc" must be present. The F1, AUROC and
    mAP figures are drawn only when "val_f1_macro", "val_auroc_macro" and
    "val_map_macro" respectively are keys of `history`.

    Args:
        history: Mapping from series name to its per-epoch values.
    """

    def _show_figure(curves, y_label, title):
        # Draw the given (history key, legend label) curves on a new figure
        plt.figure(figsize=(8, 6))
        for series_key, legend_label in curves:
            plt.plot(history[series_key], label=legend_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)
        plt.show()

    # (key that must exist for the figure, or None for "always";
    #  curves to draw; y-axis label; figure title)
    figures = (
        (
            None,
            (("train_loss", "Train Loss"), ("val_loss", "Validation Loss")),
            "Loss",
            "Training vs Validation Loss",
        ),
        (
            None,
            (("train_acc", "Train Accuracy"), ("val_acc", "Validation Accuracy")),
            "Accuracy",
            "Training vs Validation Accuracy",
        ),
        (
            "val_f1_macro",
            (("val_f1_macro", "Val F1 Macro"), ("val_f1_micro", "Val F1 Micro")),
            "F1 Score",
            "Validation F1 Scores",
        ),
        (
            "val_auroc_macro",
            (("val_auroc_macro", "Val AUROC Macro"),),
            "AUROC",
            "Validation AUROC",
        ),
        (
            "val_map_macro",
            (("val_map_macro", "Val mAP Macro"), ("val_map_micro", "Val mAP Micro")),
            "mAP",
            "Validation mAP",
        ),
    )

    for required_key, curves, y_label, title in figures:
        if required_key is not None and required_key not in history:
            continue
        _show_figure(curves, y_label, title)

In [None]:
# Model Testing Function
def test_model(model, test_loader, criterion, device, metrics_dict):
  """Evaluate `model` on `test_loader` and report the loss and all metrics.

  Args:
    model: Network that returns one logit per label.
    test_loader: Iterable of (images, labels) batches supporting len().
    criterion: Loss, called as criterion(logits, labels).
    device: Device the batches are moved to.
    metrics_dict: Mapping with the keys "accuracy", "f1_macro", "f1_micro",
      "auroc_macro", "map_macro" and "map_micro". Each value is called as
      metric(preds, targets) and must offer compute() and reset().

  Returns:
    Dict with the mean batch loss under "loss" and one float per metric
    under the metric's own key. Every metric in metrics_dict is reset
    before returning.
  """
  metric_names = (
      "accuracy", "f1_macro", "f1_micro",
      "auroc_macro", "map_macro", "map_micro"
  )

  model.eval()
  test_loss = 0.0

  with torch.no_grad():
    for images, labels in test_loader:
      images, labels = images.to(device), labels.to(device)

      outputs = model(images)
      loss = criterion(outputs, labels)
      test_loss += loss.item()

      preds = torch.sigmoid(outputs)
      # The loss wants float labels, but the torchmetrics AUROC /
      # average-precision metrics reject floating-point targets, so the
      # metrics get an integer copy.
      targets = labels.int()

      # Update Metrics
      for name in metric_names:
        metrics_dict[name](preds, targets)

  test_loss /= len(test_loader)

  # Compute Metrics
  test_results = {"loss": test_loss}
  for name in metric_names:
    test_results[name] = metrics_dict[name].compute().item()

  # Reset Metrics
  for metric in metrics_dict.values():
    metric.reset()

  # Print Results
  print(
      f"Test Results:\n"
      f"Loss: {test_results['loss']:.4f}\n"
      f"Accuracy: {test_results['accuracy']:.4f}\n"
      f"F1 Macro: {test_results['f1_macro']:.4f}, "
      f"F1 Micro: {test_results['f1_micro']:.4f}\n"
      f"AUROC Macro: {test_results['auroc_macro']:.4f}\n"
      f"mAP Macro: {test_results['map_macro']:.4f}, "
      f"mAP Micro: {test_results['map_micro']:.4f}"
  )

  return test_results

In [None]:
# Define Optimizer, Criterion, Scheduler
def train_utils(model, lr, weight_decay, train_loader, num_epochs):
  """Build the optimizer, loss function and LR scheduler for a training run.

  Args:
    model: Module whose parameters are optimized.
    lr: Learning rate for AdamW; also the peak (max_lr) of the one-cycle schedule.
    weight_decay: Weight decay passed to AdamW.
    train_loader: Training loader; its length is the number of scheduler
      steps per epoch.
    num_epochs: Number of epochs the schedule spans.

  Returns:
    Tuple (optimizer, criterion, scheduler): an AdamW optimizer, a
    BCEWithLogitsLoss and a OneCycleLR scheduler covering
    num_epochs * len(train_loader) steps.
  """
  adamw = torch.optim.AdamW(
      model.parameters(), weight_decay=weight_decay, lr=lr
  )
  bce_with_logits = nn.BCEWithLogitsLoss()

  # len() is already an integer count of batches, so it can be used directly
  batches_per_epoch = len(train_loader)
  one_cycle = torch.optim.lr_scheduler.OneCycleLR(
      adamw,
      max_lr=lr,
      epochs=num_epochs,
      steps_per_epoch=batches_per_epoch,
  )

  return adamw, bce_with_logits, one_cycle

## CNN Class

In [None]:
class MultiLabelCNN(nn.Module):
  """Small convolutional network that outputs one logit per label.

  Three Conv -> BatchNorm -> ReLU -> MaxPool(2) stages (32, 64, 128 channels)
  feed a two-layer fully connected classifier. The classifier's input width
  is derived from `image_size`, so inputs are expected to be
  image_size x image_size RGB images.

  Args:
    num_classes: Number of output logits.
    image_size: Side length of the square input images.
  """

  def __init__(self, num_classes, image_size):
    super().__init__()

    self.features = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size = 3, stride = 1, padding = 1),
        nn.BatchNorm2d(32),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(32, 64, kernel_size = 3, stride = 1, padding = 1),
        nn.BatchNorm2d(64),
        nn.ReLU(),
        nn.MaxPool2d(2),

        nn.Conv2d(64, 128, kernel_size = 3, stride = 1, padding = 1),
        nn.BatchNorm2d(128),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )

    # Dynamically Calculate Flattened Size
    # Probe in eval mode: in training mode BatchNorm would fold this all-zero
    # dummy batch into its running statistics (and bump num_batches_tracked)
    # before the model has seen any real data.
    self.features.eval()
    with torch.no_grad():
      dummy = torch.zeros(1, 3, image_size, image_size)
      feat_dim = self.features(dummy).flatten(1).shape[1]
    # Back to the default training mode of a freshly built module
    self.features.train()

    self.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(feat_dim, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )

  def forward(self, x):
    """Return the raw logits (no sigmoid applied) for a batch of images."""
    x = self.features(x)
    x = self.classifier(x)
    return x

In [None]:
# Instantiate Model Class
# Build the network and move it to the configured device in a single step
model = MultiLabelCNN(num_classes, image_size).to(device)

### Train MultiLabelCNN Model

In [None]:
# Instantiate Training Utils
# Optimizer, loss and LR scheduler for the model created above
optimizer, criterion, scheduler = train_utils(
    model=model,
    lr=lr,
    weight_decay=weight_decay,
    train_loader=train_loader,
    num_epochs=epochs,
)

In [None]:
# Train & Evaluate (On Validation Split) Model
# `history` holds the per-epoch losses and metrics returned by the training loop
history = train_eval_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=epochs,
    device=device,
    metrics_dict=metrics_dict,
)

In [None]:
# Plot Training History
# Draws the loss, accuracy, F1, AUROC and mAP curves recorded in `history`
plot_train_history(history)

### Evaluate MultiLabelCNN Model

In [None]:
# Evaluate Model On Test Data
# Reuses the training criterion and the shared metric objects on the test split
test_results = test_model(
    model=model,
    test_loader=test_loader,
    criterion=criterion,
    device=device,
    metrics_dict=metrics_dict,
)

In [None]:
# Show the raw results dictionary (loss plus one value per metric)
print(test_results)