# HW2P2: Image Recognition and Verification

# Libraries

In [None]:
# Shell escape: list the NVIDIA GPU(s) visible to this runtime before choosing batch size / workers.
!nvidia-smi # Run this to see what GPU you have

In [None]:
# Install third-party dependencies into the environment of the *running kernel*.
# `%pip` (unlike `!pip`) always targets the interpreter behind this notebook.
# wandb was previously installed twice; one resolver call installs everything together.
%pip install --quiet wandb pytorch_metric_learning torchsummaryX==1.1.0 torchvision
# Pin the Kaggle CLI; --no-deps keeps the forced reinstall from touching its dependencies.
%pip install --upgrade kaggle==1.6.17 --force-reinstall --no-deps

In [None]:
import torch
import torch.nn as nn
from torchsummaryX import summary
import torchvision
from torchvision.utils import make_grid
from torchvision import transforms
import torchvision.transforms.v2 as T
import torchvision.models as models
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
import os
import gc
from tqdm import tqdm
from PIL import Image
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn import metrics as mt
from scipy.optimize import brentq
from scipy.interpolate import interp1d
import glob
import wandb
import matplotlib.pyplot as plt
from pytorch_metric_learning import samplers, losses

import csv
import random

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
print("Device: ", DEVICE)


In [240]:
# from google.colab import drive # Link to your drive if you are not using Colab with GCP
# drive.mount('/content/drive') # Models in this HW take a long time to get trained and make sure to save it here

# Folder

In [241]:
# # change in kaggle
# data_dir = '/kaggle/input/11785-hw-2-p-2-face-verification-spring-2025/HW2p2_S25'

In [242]:
# Local (non-Kaggle) setup: uncomment the three shell commands below once to fetch and unpack the data.
# NOTE(review): `mkdir` without `-p` fails if the parent 'content/' directory does not exist yet.
# # run on mac
# !mkdir 'content/data'
# !kaggle competitions download -c 11785-hw-2-p-2-face-verification-spring-2025
# !unzip -qo '11785-hw-2-p-2-face-verification-spring-2025' -d 'content/data'
# Root of the extracted dataset, relative to the notebook's working directory.
# On Kaggle, use the /kaggle/input/... path shown in the commented-out cell above instead.
data_dir = 'content/data/HW2p2_S25' # change in kaggle

# Config

In [None]:
# Every tunable setting for the run lives in this one dict.
config = {
    # machine resources
    'num_workers': 5, # change in kaggle
    'pin_memory': True, # change in kaggle

    # data paths
    'cls_data_dir': os.path.join(data_dir, "cls_data"), # classification data (train/dev/test sub-folders)
    'ver_data_dir': os.path.join(data_dir, "ver_data"), # verification images
    'val_pairs_file': os.path.join(data_dir,"val_pairs.txt"), # labelled validation pairs for verification
    'test_pairs_file': os.path.join(data_dir,"test_pairs.txt"), # unlabelled test pairs for verification

    # checkpoint paths
    'checkpoint_dir': "checkpoint", # created below if missing

    # data size, change in kaggle
    'num_classes': 8631, #Dataset contains 8631 classes for classification, reduce this number if you want to train on a subset, but only for train dataset and not on val dataset

    # data augmentation
    'augment': True,
    # NOTE(review): RandomResizedCrop's `scale` is a fraction of the original image AREA, so an
    # upper bound above 1.0 cannot be sampled as a crop — confirm (0.8, 1.2) is intended.
    'crop': {'if_on': True, 'scale': (0.8, 1.2)},
    'flip': {'if_on': True, 'probability': 0.5},
    'rotation': {'if_on': True, 'degrees': 15},
    'colorjitter': {'if_on': False, 'brightness': (0.9, 1.1), 'contrast': (0.9, 1.1), 'saturation': (0.9, 1.1), 'hue': (-0.02, 0.02)},
    'greyscale': {'if_on': False, 'probability': 0.05},
    'affine': {'if_on': False, 'degrees': 0, 'translate': (0.1, 0.1), 'scale': (0.9, 1.1)},
    'perspective': {'if_on': False, 'distortion_scale': 0.2, 'probability': 0.2},
    'normalize_mean': [0.5, 0.5, 0.5],
    'normalize_std': [0.5, 0.5, 0.5],

    # model
    # NOTE(review): 'MobileFaceNet' is not in the option list below (which has 'MobileFaceNetLike');
    # confirm it matches a model class defined later in the notebook.
    'model': 'MobileFaceNet', # ['Network', 'DeeperNetwork', 'ConvNeXt', 'ResNet18', 'ResNet34', 'MobileFaceNetLike', 'EfficientNetB5Face']
    'model_params': {'dropout_rate': 0.2, "activation": 'prelu', "use_attention": True, "channel_reduction": 4, "spatial_kernel_size": 7}, # {'dropout_rate': 0.2, "width_mult": 0.8},
    'model_train': {"use_adversarial_training": False, "adversarial_training_alpha": 0.05, "adversarial_training_beta": 0.05},
    'epochs': 60, # 20 epochs is recommended ONLY for the early submission - you will have to train for much longer typically.
    'batch_size': 256, # Increase this if your GPU can handle it

    # loss
    'loss': 'arcface', # [None, 'triplet', 'npair', 'cosface', 'arcface', 'combined']
    # cross entropy
    'label_smoothing': 0.1, # Label smoothing for classification loss
    'cross_entropy_weight': 1,
    'other_loss_weight': 1.5,
    # triplet
    'margin': 0.2,
    # npair
    # cosface (scale, additive_margin), arcface (scale, angular_margin), combined (scale, angular_margin, additive_margin)
    'scale': 30,
    'additive_margin': 0.5,
    'angular_margin': 0.5,
    'combined_arcface_weight': 0.5,

    # optimizer
    'lr': 1e-3,
    'weight_decay': 1e-4,
    # 'scheduler': 'CosineAnnealingWarmRestarts',
    # 'scheduler_params': {'T_0': 50, 'T_mult': 1, 'eta_min': 1e-8},
    'scheduler': 'CosineAnnealingLR',
    # NOTE(review): T_max (100) exceeds 'epochs' (60). If the scheduler is stepped once per epoch,
    # the cosine only covers 60% of its half-period and the LR never approaches eta_min.
    'scheduler_params': {'T_max': 100, 'eta_min': 1e-9},

    # wandb
    # NOTE(review): the run name says 'ce0.2' but cross_entropy_weight is 1 — confirm the name
    # still describes this configuration.
    'wandb_name': 'aug(norm0.5+rotation+flip0.5+crop1.2), model(MobileFaceNet,attention), loss(ce0.2+arcface30+0.5(1.5)), lr(3-9,l00), bs(256)',
    'wandb_init': True,
    'wandb_id': None,
}

# Create the checkpoint directory idempotently (a bare `!mkdir` errors when the cell is re-run).
os.makedirs(config['checkpoint_dir'], exist_ok=True)

# Dataset

### Transform

In [244]:
def create_transforms(image_size: int = 112, augment: bool = True, cfg: "dict | None" = None) -> T.Compose:
    """Create the transform pipeline for face recognition.

    Pipeline order: resize -> uint8 image tensor -> float32 in [0, 1]
    -> (optional) random augmentations -> per-channel normalization.

    Args:
        image_size: Side length, in pixels, of the square output image.
        augment: If True, append every augmentation whose ``'if_on'`` flag is set in ``cfg``.
        cfg: Settings dict providing the 'crop', 'flip', 'rotation', 'colorjitter',
            'greyscale', 'affine', 'perspective', 'normalize_mean' and 'normalize_std' keys.
            Defaults to the notebook-level ``config``.

    Returns:
        A ``torchvision.transforms.v2.Compose`` pipeline.
    """
    if cfg is None:
        cfg = config

    # Step 1: deterministic preprocessing.
    # v2.ToTensor is deprecated; ToImage + ToDtype(float32, scale=True) is its documented
    # replacement and yields the same [0, 1] float values.
    transform_list = [
        T.Resize((image_size, image_size)),
        T.ToImage(),
        T.ToDtype(torch.float32, scale=True),
    ]

    # Step 2: data augmentation. Everything uses the v2 API, so it is not mixed with v1 transforms.
    if augment:
        if cfg['crop']['if_on']:
            transform_list.append(T.RandomResizedCrop(size=image_size, scale=cfg['crop']['scale']))
        if cfg['flip']['if_on']:
            transform_list.append(T.RandomHorizontalFlip(p=cfg['flip']['probability']))
        if cfg['rotation']['if_on']:
            transform_list.append(T.RandomRotation(degrees=cfg['rotation']['degrees']))
        if cfg['colorjitter']['if_on']:
            jitter = cfg['colorjitter']
            transform_list.append(T.ColorJitter(brightness=jitter['brightness'], contrast=jitter['contrast'],
                                                saturation=jitter['saturation'], hue=jitter['hue']))
        if cfg['greyscale']['if_on']:
            transform_list.append(T.RandomGrayscale(p=cfg['greyscale']['probability']))
        if cfg['affine']['if_on']:
            affine = cfg['affine']
            transform_list.append(T.RandomAffine(degrees=affine['degrees'], translate=affine['translate'],
                                                 scale=affine['scale']))
        if cfg['perspective']['if_on']:
            transform_list.append(T.RandomPerspective(distortion_scale=cfg['perspective']['distortion_scale'],
                                                      p=cfg['perspective']['probability']))

    # Step 3: per-channel normalization; with mean = std = 0.5 this maps [0, 1] to [-1, 1].
    transform_list.append(T.Normalize(cfg['normalize_mean'], cfg['normalize_std']))

    return T.Compose(transform_list)


### Classification Datasets and Dataloaders

In [None]:
class ImageDataset(torch.utils.data.Dataset):
    """Image-classification dataset described by a ``labels.txt`` manifest.

    ``root`` must contain ``labels.txt`` (one ``<relative image path> <integer label>``
    entry per line) and an ``images/`` folder those paths are relative to.
    Images are read lazily in ``__getitem__``; only paths and labels are kept in memory.
    """
    # NOTE: the ``num_classes`` default is read from ``config`` once, when this class is defined.
    def __init__(self, root, transform, num_classes=config['num_classes']):
        """
        Args:
            root (str): Path to the directory containing ``labels.txt`` and the ``images`` folder.
            transform (callable): Transform to be applied to the images.
            num_classes (int, optional): Keep only the ``num_classes`` smallest label ids.
                If None, keep all classes.
        """
        self.root = root
        self.labels_file = os.path.join(self.root, "labels.txt")
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Parse each non-blank line exactly once; blank or trailing lines used to crash int('').
        with open(self.labels_file, 'r') as f:
            entries = [self._parse_line(line) for line in f if line.strip()]

        # Stable sort by label: samples are grouped by class, file order kept within each class.
        entries.sort(key=lambda entry: entry[1])

        # Select a subset of classes if requested (the smallest label ids).
        all_labels = sorted({label for _, label in entries})
        selected_classes = set(all_labels if num_classes is None else all_labels[:num_classes])

        for img_path, label in tqdm(entries, desc="Loading dataset"):
            if label in selected_classes:
                self.image_paths.append(os.path.join(self.root, 'images', img_path))
                self.labels.append(label)

        assert len(self.image_paths) == len(self.labels), "Images and labels mismatch!"

        # Sorted list of the class ids actually present after filtering.
        self.classes = sorted(set(self.labels))

    @staticmethod
    def _parse_line(line):
        """Split one manifest line into ``(relative_path, int_label)``; the label is the last field."""
        img_path, label = line.strip().rsplit(maxsplit=1)
        return img_path, int(label)

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (transformed image, label)
        """
        # Load and transform the image on the fly; the context manager closes the file handle.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        label = self.labels[idx]
        return image, label

gc.collect()

In [None]:
# Evaluation pipeline: no random augmentation, so val/test inputs are deterministic.
val_transforms = create_transforms(augment=False)

# Training pipeline: additionally applies the augmentations switched on in `config`.
train_transforms = create_transforms(augment=config['augment'])

In [None]:
def _cls_split(split, transform):
    """Build the ImageDataset for one sub-folder ('train', 'dev' or 'test') of config['cls_data_dir']."""
    return ImageDataset(root=os.path.join(config['cls_data_dir'], split),
                        transform=transform,
                        num_classes=config['num_classes'])


cls_train_dataset = _cls_split('train', train_transforms)
cls_val_dataset = _cls_split('dev', val_transforms)
cls_test_dataset = _cls_split('test', val_transforms)

# Report how many samples each split holds.
print(f"Training set size: {len(cls_train_dataset)}")
print(f"Validation set size: {len(cls_val_dataset)}")
print(f"Test set size: {len(cls_test_dataset)}")

# Batch / worker settings shared by all three loaders; only the training loader shuffles.
_cls_loader_kwargs = dict(batch_size=config['batch_size'],
                          num_workers=config['num_workers'],
                          pin_memory=config['pin_memory'])
cls_train_loader = DataLoader(cls_train_dataset, shuffle=True, **_cls_loader_kwargs)
cls_val_loader = DataLoader(cls_val_dataset, shuffle=False, **_cls_loader_kwargs)
cls_test_loader = DataLoader(cls_test_dataset, shuffle=False, **_cls_loader_kwargs)

# Smoke test: pull one training batch (if any) and show its shapes.
_first_cls_batch = next(iter(cls_train_loader), None)
if _first_cls_batch is not None:
    images, labels = _first_cls_batch
    print(images.shape, labels.shape)

### Verification Datasets and Dataloaders

In [248]:
class ImagePairDataset(torch.utils.data.Dataset):
    """Labelled verification pairs, decoded eagerly into memory as RGB PIL images."""
    def __init__(self, root, pairs_file, transform):
        """
        Args:
            root (str): Path to the directory the image paths in ``pairs_file`` are relative to.
            pairs_file (str): Text file with one ``<img_path1> <img_path2> <match>`` entry per line,
                where ``match`` is an integer label.
            transform (callable): Transform applied to each image in ``__getitem__``.
        """
        self.root      = root
        self.transform = transform

        self.matches     = []
        self.image1_list = []
        self.image2_list = []

        # A pairs file may list the same image more than once; decode each distinct path
        # only once and share the PIL object between pairs (transforms do not modify it).
        cache = {}

        # Skip blank lines (e.g. a trailing newline) instead of failing to unpack them.
        with open(pairs_file, 'r') as f:
            lines = [line for line in f if line.strip()]

        for line in tqdm(lines, desc="Loading image pairs"):
            img_path1, img_path2, match = line.split()
            self.image1_list.append(self._load(img_path1, cache))
            self.image2_list.append(self._load(img_path2, cache))
            self.matches.append(int(match))  # Convert match to integer

        assert len(self.image1_list) == len(self.image2_list) == len(self.matches), "Image pair mismatch"

    def _load(self, rel_path, cache):
        """Return the RGB PIL image at ``root/rel_path``, decoding it at most once per dataset."""
        img = cache.get(rel_path)
        if img is None:
            # The context manager closes the underlying file once the RGB copy is made.
            with Image.open(os.path.join(self.root, rel_path)) as src:
                img = src.convert('RGB')
            cache[rel_path] = img
        return img

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.image1_list)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            tuple: (transformed image1, transformed image2, match label)
        """
        img1 = self.image1_list[idx]
        img2 = self.image2_list[idx]
        match = self.matches[idx]
        return self.transform(img1), self.transform(img2), match

In [249]:
class TestImagePairDataset(torch.utils.data.Dataset):
    """Unlabelled verification pairs (test split), decoded eagerly into memory as RGB PIL images."""
    def __init__(self, root, pairs_file, transform):
        """
        Args:
            root (str): Path to the directory the image paths in ``pairs_file`` are relative to.
            pairs_file (str): Text file with one ``<img_path1> <img_path2>`` pair per line
                (no match label).
            transform (callable): Transform applied to each image in ``__getitem__``.
        """
        self.root      = root
        self.transform = transform

        self.image1_list = []
        self.image2_list = []

        # A pairs file may list the same image more than once; decode each distinct path
        # only once and share the PIL object between pairs (transforms do not modify it).
        cache = {}

        # Skip blank lines (e.g. a trailing newline) instead of failing to unpack them.
        with open(pairs_file, 'r') as f:
            lines = [line for line in f if line.strip()]

        for line in tqdm(lines, desc="Loading image pairs"):
            img_path1, img_path2 = line.split()
            self.image1_list.append(self._load(img_path1, cache))
            self.image2_list.append(self._load(img_path2, cache))

        assert len(self.image1_list) == len(self.image2_list), "Image pair mismatch"

    def _load(self, rel_path, cache):
        """Return the RGB PIL image at ``root/rel_path``, decoding it at most once per dataset."""
        img = cache.get(rel_path)
        if img is None:
            # The context manager closes the underlying file once the RGB copy is made.
            with Image.open(os.path.join(self.root, rel_path)) as src:
                img = src.convert('RGB')
            cache[rel_path] = img
        return img

    def __len__(self):
        """Returns the total number of pairs."""
        return len(self.image1_list)

    def __getitem__(self, idx):
        """
        Args:
            idx (int): Index of the pair to retrieve.

        Returns:
            tuple: (transformed image1, transformed image2); the test split has no match label.
        """
        img1 = self.image1_list[idx]
        img2 = self.image2_list[idx]
        return self.transform(img1), self.transform(img2)


In [None]:
# Both verification splits use the deterministic evaluation transforms.
ver_val_dataset = ImagePairDataset(root=config['ver_data_dir'],
                                   pairs_file=config['val_pairs_file'],
                                   transform=val_transforms)
ver_test_dataset = TestImagePairDataset(root=config['ver_data_dir'],
                                        pairs_file=config['test_pairs_file'],
                                        transform=val_transforms)

# Report how many pairs each split holds.
print(f"Validation set size: {len(ver_val_dataset)}")
print(f"Test set size: {len(ver_test_dataset)}")

# shuffle=False keeps batches in the same order as the pairs files.
_ver_loader_kwargs = dict(batch_size=config['batch_size'],
                          shuffle=False,
                          num_workers=config['num_workers'],
                          pin_memory=config['pin_memory'])
ver_val_loader = DataLoader(ver_val_dataset, **_ver_loader_kwargs)
ver_test_loader = DataLoader(ver_test_dataset, **_ver_loader_kwargs)

# Smoke test: pull one validation batch (if any) and show its shapes.
_first_ver_batch = next(iter(ver_val_loader), None)
if _first_ver_batch is not None:
    images1, images2, labels = _first_ver_batch
    print(images1.shape, images2.shape, labels.shape)

# EDA and Viz

### Classification Dataset Viz

In [None]:
def show_cls_dataset_samples(train_loader, val_loader, test_loader, samples_per_set=8, figsize=(10, 6)):
    """
    Display samples from the train, validation and test loaders, one row of images per split.

    The first ``samples_per_set`` items of each loader's first batch are drawn in a single
    row, with each sample's class id written underneath.

    Args:
        train_loader: Training data loader; its ``dataset.classes`` is used to name labels
        val_loader: Validation data loader
        test_loader: Test data loader
        samples_per_set: Number of samples to show from each dataset
        figsize: Figure size (width, height)
    """
    def denormalize(x):
        """Invert Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1], clamped for imshow."""
        return (x * 0.5 + 0.5).clamp(0, 1)

    def get_samples(loader, n):
        """Return the first ``n`` images and labels of the loader's first batch."""
        batch = next(iter(loader))
        return batch[0][:n], batch[1][:n]

    # Fetch all samples (train, val, test) before plotting anything.
    splits = [
        (*get_samples(train_loader, samples_per_set), 'Training Samples'),
        (*get_samples(val_loader, samples_per_set), 'Validation Samples'),
        (*get_samples(test_loader, samples_per_set), 'Test Samples'),
    ]
    class_names = train_loader.dataset.classes

    fig, axes = plt.subplots(3, 1, figsize=figsize)

    for ax, (imgs, labels, title) in zip(axes, splits):
        # The first batch may hold fewer than samples_per_set items; lay out what we have.
        n_shown = len(imgs)

        # Put every sample in ONE row. The label x-positions below assume a single row; a
        # fixed nrow would wrap extra samples onto rows that the labels do not account for.
        grid = make_grid(denormalize(imgs), nrow=n_shown, padding=2)

        ax.imshow(grid.permute(1, 2, 0).cpu())
        ax.axis('off')
        ax.set_title(title, fontsize=10)

        # Add class labels below the images (with a smaller font).
        img_width = grid.shape[2] / n_shown
        for col, label in enumerate(labels):
            label = int(label)  # plain int so the text reads "5", not "tensor(5)"
            if label < len(class_names):
                class_name = class_names[label]
            else:
                class_name = f"Class {label} (Unknown)"
            ax.text(col * img_width + img_width / 2,
                    grid.shape[1] + 5,
                    class_name,
                    ha='center',
                    va='top',
                    fontsize=6,
                    rotation=45)

    plt.tight_layout()
    plt.show()

show_cls_dataset_samples(cls_train_loader, cls_val_loader, cls_test_loader)

### Ver Dataset Viz

In [None]:
import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid

def show_ver_dataset_samples(val_loader, samples_per_set=4, figsize=(12, 8)):
    """
    Display verification pairs from the validation loader.

    First images form the top row and their partners the row directly below; a
    match / non-match badge is drawn under each column.

    Args:
        val_loader: Validation data loader yielding (images1, images2, match) batches
        samples_per_set: Number of pairs to show from the dataset
        figsize: Figure size (width, height)
    """
    def denormalize(x):
        """Invert Normalize(mean=0.5, std=0.5): map [-1, 1] back to [0, 1], clamped for imshow."""
        return (x * 0.5 + 0.5).clamp(0, 1)

    def get_samples(loader, n):
        """Return the first ``n`` (images1, images2, labels) of the loader's first batch."""
        batch = next(iter(loader))
        return batch[0][:n], batch[1][:n], batch[2][:n]

    val_imgs1, val_imgs2, val_labels = get_samples(val_loader, samples_per_set)
    # The first batch may hold fewer than samples_per_set pairs. Size the grid and the label
    # columns from what was actually fetched, otherwise the labels drift away from their images.
    n_shown = len(val_imgs1)

    fig, ax = plt.subplots(1, 1, figsize=figsize)

    # One row per image of the pair, stacked vertically (dim=1 is the height of a CHW grid).
    grid1 = make_grid(denormalize(val_imgs1), nrow=n_shown, padding=2)
    grid2 = make_grid(denormalize(val_imgs2), nrow=n_shown, padding=2)
    combined_grid = torch.cat([grid1, grid2], dim=1)

    ax.imshow(combined_grid.permute(1, 2, 0).cpu())
    ax.axis('off')
    ax.set_title('Validation Pairs', fontsize=10)

    img_width = grid1.shape[2] / n_shown

    # Add match/non-match labels below each pair.
    for i, label in enumerate(val_labels):
        is_match = int(label) == 1
        match_text = "✓ Match" if is_match else "✗ Non-match"
        color = 'green' if is_match else 'red'

        # Rounded background box whose edge colour matches the label.
        bbox_props = dict(
            boxstyle="round,pad=0.3",
            fc="white",
            ec=color,
            alpha=0.8
        )

        ax.text(i * img_width + img_width / 2,
                combined_grid.shape[1] + 15,  # Position below the images
                match_text,
                ha='center',
                va='top',
                fontsize=8,
                color=color,
                bbox=bbox_props)

    plt.suptitle("Verification Pairs (Top: Image 1, Bottom: Image 2)", y=1.02)
    plt.tight_layout()
    plt.subplots_adjust(bottom=0.05)
    plt.show()

show_ver_dataset_samples(ver_val_loader)


# Model Architecture

### FAQ

**What's a very low early deadline architecture (mandatory early submission)**?

- The very low early deadline architecture is a 5-layer CNN. Keep in mind the parameter limit for this homework is 30M.
- The first convolutional layer has 64 channels, kernel size 7, and stride 4. The next four have 128, 256, 512, and 1024 channels, respectively; each has kernel size 3 and stride 2. Documentation for convolutional layers: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
- Recall from lecture that a strided convolution is equivalent to a stride-1 convolution followed by downsampling. What padding does the stride-1 convolution need to preserve the spatial resolution before downsampling? (Hint: padding = kernel_size // 2 — think about why.)
- Each convolutional layer is accompanied by a Batchnorm and ReLU layer.
- Finally, you want to average pool over the spatial dimensions to reduce them to 1 x 1. Use AdaptiveAvgPool2d. Documentation for AdaptiveAvgPool2d: https://pytorch.org/docs/stable/generated/torch.nn.AdaptiveAvgPool2d.html
- Then remove these trivial 1x1 dimensions (e.g. with `Flatten`).
For the full list of available layers, look through https://pytorch.org/docs/stable/nn.html


**Why does a very simple network have 5 convolutions**?

Input images are 112x112. Note that each of these convolutions downsamples. Downsampling by 2x effectively doubles the receptive field, increasing the spatial region from which each output pixel extracts features. Downsampling by 32x overall is standard for most image models.

**Why does a very simple network have high channel sizes**?

Every time you downsample by 2x, you do 4x less computation (at the same channel count). To maintain the same level of computation, you double the number of channels, which increases computation by 4x (both the input and output channel counts of the next convolution double), so the two effects balance out. Another intuition: as you downsample, you lose spatial information, and we want to preserve some of it in the channel dimension.

**What is return_feats?**

It essentially returns the second-to-last-layer features of a given image. It's a "feature encoding" of the input image, and you can use it for the verification task. You would use the outputs of the final classification layer for the classification task. You might also find that the classification outputs are sometimes better for verification too - try both.

### Baseline Model

In [253]:
class Network(torch.nn.Module):
    """Baseline 5-layer CNN (the early-submission architecture).

    Five Conv -> BatchNorm -> ReLU stages with 64/128/256/512/1024 channels. The first stage
    uses a 7x7 kernel with stride 4; the others use 3x3 kernels with stride 2. Every conv is
    padded by kernel_size // 2, so overall downsampling is 4 * 2**4 = 64x.
    AdaptiveAvgPool2d then collapses the remaining spatial extent to 1x1.
    """

    def __init__(self, num_classes, dropout_rate=0.2):
        super().__init__()

        # (in_channels, out_channels, kernel_size, stride) for each conv stage, in order.
        stage_specs = [
            (3, 64, 7, 4),
            (64, 128, 3, 2),
            (128, 256, 3, 2),
            (256, 512, 3, 2),
            (512, 1024, 3, 2),
        ]
        layers = []
        for c_in, c_out, kernel, stride in stage_specs:
            layers += [
                torch.nn.Conv2d(c_in, c_out, kernel_size=kernel, stride=stride, padding=kernel // 2),
                torch.nn.BatchNorm2d(c_out),
                torch.nn.ReLU(),
            ]
        layers.append(torch.nn.AdaptiveAvgPool2d(output_size=(1, 1)))
        self.backbone = torch.nn.Sequential(*layers)

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Flatten (B, 1024, 1, 1) -> (B, 1024), then project to class logits.
        self.cls_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(1024, num_classes),
        )

    def forward(self, x, return_feats=False):
        """Return class logits, or {"feats", "out"} when ``return_feats`` is True.

        ``feats`` is the pooled backbone output after dropout, shape (B, 1024, 1, 1);
        ``out`` holds the (B, num_classes) logits.
        """
        pooled = self.dropout(self.backbone(x))
        logits = self.cls_layer(pooled)
        return {"feats": pooled, "out": logits} if return_feats else logits

### Deeper Neural Network

In [254]:
class DeeperNetwork(torch.nn.Module):
    """Six-stage CNN: bias-free Conv-BN-ReLU layers with max-pooling, then a 2-layer MLP head.

    The backbone ends with AdaptiveAvgPool2d((1, 1)), so its features have shape (B, 1536, 1, 1).
    The head is Flatten -> Linear(1536, 768) -> LayerNorm -> ReLU -> Linear(768, num_classes).
    """

    def __init__(self, num_classes, dropout_rate=0.2):
        super().__init__()

        self.backbone = torch.nn.Sequential(

            torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=2, bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(),

            torch.nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.Conv2d(1024, 1536, kernel_size=3, stride=1, padding=1, bias=False),
            torch.nn.BatchNorm2d(1536),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(kernel_size=2, stride=2),

            torch.nn.AdaptiveAvgPool2d((1, 1)),
            )

        self.dropout = torch.nn.Dropout(p=dropout_rate)

        self.cls_layer = torch.nn.Sequential(
            torch.nn.Flatten(),

            torch.nn.Linear(1536, 768),
            torch.nn.LayerNorm(768),
            torch.nn.ReLU(),

            torch.nn.Linear(768, num_classes)
        )

        self.apply(self._init_weights)

    def forward(self, x, return_feats=False):
        """Return class logits, or {"feats", "out"} when ``return_feats`` is True.

        ``feats`` is the pooled backbone output after dropout, shape (B, 1536, 1, 1).
        """
        feats = self.backbone(x)
        feats = self.dropout(feats)
        out = self.cls_layer(feats)

        if return_feats:
            return {"feats": feats, "out": out}
        else:
            return out

    def _init_weights(self, m):
        """Kaiming-uniform (fan_out, ReLU gain) weights for Conv2d/Linear layers; zero biases.

        Biases start at zero. Sampling them from N(0, 1) would add unit-scale random
        offsets to the hidden activations and to every logit at initialization.
        """
        if isinstance(m, (torch.nn.Linear, torch.nn.Conv2d)):
            torch.nn.init.kaiming_uniform_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

### ConvNext

In [255]:
class ConvNeXtBlock(torch.nn.Module):
    """Residual block: 7x7 depthwise conv -> GroupNorm(1 group) -> pointwise MLP with GELU.

    Maps (B, C, H, W) -> (B, C, H, W) with C == in_channels; the MLP hidden width is
    ``in_channels * expansion``.
    """
    def __init__(self, in_channels, expansion=4):
        super().__init__()
        hidden_channels = in_channels * expansion

        # Depthwise conv (groups == channels); padding=3 keeps H and W unchanged for the 7x7 kernel.
        self.dwconv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=7, stride=1, padding=3, groups=in_channels)
        # A single group normalises over (C, H, W) per sample.
        self.norm = torch.nn.GroupNorm(1, in_channels)
        # 1x1 "convs" expressed as Linear layers acting on a channels-last tensor.
        self.pwconv1 = torch.nn.Linear(in_channels, hidden_channels)
        self.gelu = torch.nn.GELU()
        self.pwconv2 = torch.nn.Linear(hidden_channels, in_channels)
        # Parameter-free and unused by forward(), which adds the skip connection directly.
        self.residual = torch.nn.Identity()

    def forward(self, x):
        branch = self.norm(self.dwconv(x))
        branch = branch.permute(0, 2, 3, 1)      # (B, C, H, W) -> (B, H, W, C) so Linear acts on C
        branch = self.pwconv2(self.gelu(self.pwconv1(branch)))
        branch = branch.permute(0, 3, 1, 2)      # back to (B, C, H, W)
        return x + branch                        # skip connection


class ConvNeXt(torch.nn.Module):
    """ConvNeXt-style classifier.

    Layout: a patchify stem, then one stage of ConvNeXtBlocks per entry of ``depths``, with
    2x2 stride-2 downsampling convs between stages. Global average pooling feeds a 1024-d
    ReLU embedding, which a linear classification head turns into logits.

    Args:
        in_channels: Number of input image channels.
        num_classes: Size of the classification output.
        depths: Number of ConvNeXtBlocks in each stage.
        dims: Channel width of each stage (needs at least len(depths) entries).
        dropout_rate: Dropout probability applied to the embedding.
    """
    def __init__(self, in_channels=3, num_classes=1000, depths=[3, 3, 8, 3], dims=[80, 160, 320, 640], dropout_rate=0.2):
        super().__init__()

        # Stem: non-overlapping 4x4 patches (kernel == stride == 4), normalised with GroupNorm.
        self.stem = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, dims[0], kernel_size=4, stride=4),
            torch.nn.GroupNorm(1, dims[0]),
        )

        last_stage = len(depths) - 1
        self.stages = torch.nn.ModuleList()
        for stage_idx, depth in enumerate(depths):
            width = dims[stage_idx]
            blocks = [ConvNeXtBlock(width) for _ in range(depth)]
            # Every stage except the last ends with a 2x downsampling conv into the next stage's width.
            if stage_idx < last_stage:
                transition = torch.nn.Conv2d(width, dims[stage_idx + 1], kernel_size=2, stride=2)
            else:
                transition = torch.nn.Identity()
            self.stages.append(torch.nn.Sequential(*blocks, transition))

        self.global_avg = torch.nn.AdaptiveAvgPool2d(1)

        # Embedding: flatten the pooled map and project to a 1024-d ReLU feature vector.
        self.feature_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(dims[-1], 1024),
            torch.nn.ReLU()
        )
        self.dropout = torch.nn.Dropout(p=dropout_rate)

        # Classification head on top of the embedding.
        self.cls_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Linear(1024, num_classes),
        )

    def forward(self, x, return_feats=False):
        """Return class logits, or {"feats", "out"} when ``return_feats`` is True.

        ``feats`` is the (B, 1024) embedding after dropout.
        """
        h = self.stem(x)
        for stage in self.stages:
            h = stage(h)

        feats = self.dropout(self.feature_layer(self.global_avg(h)))
        out = self.cls_layer(feats)

        if return_feats:
            return {"feats": feats, "out": out}
        return out

### ResNet

In [256]:
class BasicBlock(torch.nn.Module):
    """ResNet basic block: two 3x3 Conv-BN layers plus an additive skip connection.

    ``stride`` applies to the first conv. When the block changes resolution or width,
    pass a ``downsample`` module that maps the input to the output shape.
    """
    expansion = 1  # output channels = out_channels * expansion (Bottleneck uses 4)

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(BasicBlock, self).__init__()

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = torch.nn.BatchNorm2d(out_channels)

        # Optional projection applied to the skip path.
        self.downsample = downsample

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))

        return F.relu(y + shortcut)

class Bottleneck(torch.nn.Module):
    """ResNet-50+ bottleneck: 1x1 reduce -> 3x3 (strided) -> 1x1 expand by ``expansion``.

    NOTE(review): a later cell defines another class named ``Bottleneck`` (the
    MobileFaceNet inverted residual), which rebinds this name once that cell runs.

    Args:
        in_channels: channels entering the block.
        out_channels: bottleneck width; the block emits ``out_channels * expansion`` channels.
        stride: stride of the 3x3 convolution.
        downsample: optional shortcut projection; ``None`` means an identity shortcut.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        wide = out_channels * self.expansion

        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = torch.nn.BatchNorm2d(out_channels)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn2 = torch.nn.BatchNorm2d(out_channels)
        self.conv3 = torch.nn.Conv2d(out_channels, wide, kernel_size=1, bias=False)
        self.bn3 = torch.nn.BatchNorm2d(wide)
        self.relu = torch.nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return self.relu(out)

class ResNet(torch.nn.Module):
    """ResNet backbone followed by a dropout + linear classifier.

    Args:
        block (nn.Module): residual block class (BasicBlock or Bottleneck). It must
            expose a class attribute ``expansion`` and accept
            ``(in_channels, out_channels, stride, downsample)``.
        layers (list): number of blocks in each of the four stages
            (e.g. [2, 2, 2, 2] for ResNet-18).
        num_classes (int): number of output classes.
        dropout_rate (float): dropout applied to the pooled features before the classifier.

    ``forward(x, return_feats=True)`` returns ``{"feats": pooled features of shape
    (N, 512 * block.expansion, 1, 1), "out": logits}`` by default, or only the logits
    when ``return_feats=False``.
    """

    def __init__(self, block, layers, num_classes=1000, dropout_rate=0.2):
        super(ResNet, self).__init__()

        # Running channel count; _make_layer reads and advances it stage by stage.
        self.in_channels = 64

        stem = [
            torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            torch.nn.BatchNorm2d(64),
            torch.nn.ReLU(inplace=True),
            torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]

        # Four residual stages; every stage after the first halves the resolution.
        stages = []
        for stage_idx, (width, stage_stride) in enumerate(zip((64, 128, 256, 512), (1, 2, 2, 2))):
            stages.append(self._make_layer(block, width, layers[stage_idx], stride=stage_stride))

        # Global average pooling ends the backbone; its (N, C, 1, 1) output is "feats".
        self.backbone = torch.nn.Sequential(*stem, *stages, torch.nn.AdaptiveAvgPool2d((1, 1)))

        self.cls_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Dropout(p=dropout_rate),
            torch.nn.Linear(512 * block.expansion, num_classes),
        )

        self._initialize_weights()

    def _make_layer(self, block, out_channels, blocks, stride):
        """Build one stage of ``blocks`` residual blocks.

        Only the first block may change stride or width, so only it receives a
        projection shortcut. Advances ``self.in_channels`` to the stage's output width.
        """
        stage_out = out_channels * block.expansion
        needs_projection = stride != 1 or self.in_channels != stage_out

        downsample = None
        if needs_projection:
            downsample = torch.nn.Sequential(
                torch.nn.Conv2d(self.in_channels, stage_out, kernel_size=1, stride=stride, bias=False),
                torch.nn.BatchNorm2d(stage_out),
            )

        first = block(self.in_channels, out_channels, stride, downsample)
        self.in_channels = stage_out
        rest = [block(self.in_channels, out_channels) for _ in range(blocks - 1)]
        return torch.nn.Sequential(first, *rest)

    def forward(self, x, return_feats=True):
        feats = self.backbone(x)
        logits = self.cls_layer(feats)
        return {"feats": feats, "out": logits} if return_feats else logits

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) for convolutions; BatchNorm set to identity."""
        for module in self.modules():
            if isinstance(module, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, torch.nn.BatchNorm2d):
                torch.nn.init.ones_(module.weight)
                torch.nn.init.zeros_(module.bias)

def ResNet18(num_classes=1000, dropout_rate=0.2):
    """ResNet-18: four BasicBlock stages of depth 2-2-2-2."""
    return ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, dropout_rate=dropout_rate)


def ResNet34(num_classes=1000, dropout_rate=0.2):
    """ResNet-34: four BasicBlock stages of depth 3-4-6-3."""
    return ResNet(block=BasicBlock, layers=[3, 4, 6, 3], num_classes=num_classes, dropout_rate=dropout_rate)


### MobileFaceNet

In [272]:
# channel attention, spatial attention

class Conv_BN(nn.Module):
    """Conv2d (no bias) -> BatchNorm2d -> optional activation.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, groups: forwarded to nn.Conv2d.
        activation: 'prelu' for a per-channel PReLU, 'SiLU' for SiLU, anything else
            (e.g. None) for a linear conv + BN with no activation.
            NOTE(review): other blocks in this notebook treat every non-'prelu' value as
            SiLU, whereas this one matches the exact string 'SiLU' only.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1, activation='prelu'):
        super(Conv_BN, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        if activation == 'prelu':
            self.activ = nn.PReLU(out_channels)
        elif activation == 'SiLU':
            # nn.SiLU's only argument is `inplace`; passing a channel keyword raises TypeError.
            self.activ = nn.SiLU()
        else:
            self.activ = None

    def forward(self, x):
        x = self.bn(self.conv(x))
        if self.activ is not None:
            x = self.activ(x)
        return x


class DepthwiseConv(nn.Module):
    """Depthwise-separable convolution: depthwise conv + BN + activation, then pointwise 1x1 conv + BN.

    The pointwise projection has no activation (linear output).

    Args:
        in_channels: input channels (also the depthwise groups).
        out_channels: channels after the pointwise projection.
        kernel_size, stride, padding: depthwise convolution settings.
        activation: 'prelu' for a per-channel PReLU, anything else for SiLU.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, activation='prelu'):
        super(DepthwiseConv, self).__init__()
        self.conv = nn.Sequential(
            # Depthwise convolution
            nn.Conv2d(in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
            # nn.SiLU's only argument is `inplace`; passing a channel count there would
            # silently switch on in-place mode.
            nn.PReLU(in_channels) if activation == 'prelu' else nn.SiLU(),

            # Pointwise convolution
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.conv(x)


class ChannelAttention(nn.Module):
    """Squeeze-and-excitation style channel attention.

    Global-average-pools each channel, passes the (N, C) vector through a bottleneck MLP
    that ends in a sigmoid, and rescales the channels of ``x`` by the result.

    Args:
        channels: number of input channels C.
        reduction: bottleneck ratio; hidden width is ``channels // reduction``, floored at 1.
        activation: 'prelu' for a per-unit PReLU, anything else for SiLU.
    """
    def __init__(self, channels, reduction=16, activation='prelu'):
        super(ChannelAttention, self).__init__()
        # Floor at 1 so that channels < reduction still yields a usable MLP rather than a
        # zero-width hidden layer.
        hidden = max(1, channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channels, hidden, bias=False),
            # nn.SiLU's only argument is `inplace`; passing a width there would silently
            # switch on in-place mode.
            nn.PReLU(hidden) if activation == 'prelu' else nn.SiLU(),
            nn.Linear(hidden, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # x: (N, C, H, W)
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class SpatialAttention(nn.Module):
    """CBAM-style spatial attention: reweights every spatial location of ``x``.

    A single conv over the channel-wise mean and max maps produces an (N, 1, H, W) gate
    in (0, 1), which is broadcast across all channels.

    Args:
        kernel_size: size of the gating conv. With the ``(k - 1) // 2`` padding used
            here, an odd value keeps H and W unchanged.
    """
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=(kernel_size-1)//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        y = self.conv1(torch.cat([avg_out, max_out], dim=1))
        # The gate must come from the conv output. Applying the sigmoid to `x` would
        # discard conv1 entirely and reduce the block to x * sigmoid(x).
        y = self.sigmoid(y)
        return x * y


class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        channels: number of input channels.
        channel_reduction: reduction ratio of the channel-attention MLP.
        activation: activation name forwarded to ChannelAttention.
        spatial_kernel_size: kernel size of the spatial-attention conv.
    """
    def __init__(self, channels, channel_reduction=16, activation='prelu', spatial_kernel_size=7):
        super(CBAM, self).__init__()
        self.channel_att = ChannelAttention(channels, channel_reduction, activation)
        self.spatial_att = SpatialAttention(spatial_kernel_size)

    def forward(self, x):
        # Order matters: the spatial gate is computed from the channel-reweighted map.
        return self.spatial_att(self.channel_att(x))


class Bottleneck(nn.Module):
    """MobileNetV2-style inverted residual block with optional attention.

    The block expands with a 1x1 conv (skipped when ``expand_ratio == 1``), applies a
    3x3 depthwise conv with ``stride``, and then a linear 1x1 projection. The residual
    is added only when ``stride == 1`` and the channel count is unchanged.

    NOTE(review): this rebinds the name ``Bottleneck`` that an earlier cell uses for the
    ResNet bottleneck block.

    Args:
        in_channels, out_channels: block input/output channels.
        stride: depthwise conv stride.
        expand_ratio: hidden width = round(in_channels * expand_ratio).
        activation: 'prelu' for per-channel PReLU, anything else for SiLU.
        use_attention: 'channel', 'spatial', 'cbam', or None, applied to the projected output.
        channel_reduction, spatial_kernel_size: attention hyper-parameters.
    """
    def __init__(self, in_channels, out_channels, stride, expand_ratio, activation='prelu', use_attention=None, channel_reduction=4, spatial_kernel_size=7):
        super(Bottleneck, self).__init__()
        self.stride = stride
        self.use_res_connect = self.stride == 1 and in_channels == out_channels

        hidden_dim = round(in_channels * expand_ratio)

        def _act(channels):
            # nn.SiLU's only argument is `inplace`; passing a channel count there would
            # silently switch on in-place mode.
            return nn.PReLU(channels) if activation == 'prelu' else nn.SiLU()

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # Depthwise convolution
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                _act(hidden_dim),
                # Pointwise convolution
                nn.Conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.conv = nn.Sequential(
                # Pointwise convolution to expand channels
                nn.Conv2d(in_channels, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(hidden_dim),
                _act(hidden_dim),
                # Depthwise convolution
                nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=stride, padding=1, groups=hidden_dim, bias=False),
                nn.BatchNorm2d(hidden_dim),
                _act(hidden_dim),
                # Pointwise convolution to reduce channels (linear: no activation)
                nn.Conv2d(hidden_dim, out_channels, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(out_channels),
            )

        if use_attention == 'channel':
            self.attention = ChannelAttention(out_channels, channel_reduction, activation)
        elif use_attention == 'spatial':
            self.attention = SpatialAttention(spatial_kernel_size)
        elif use_attention == 'cbam':
            self.attention = CBAM(out_channels, channel_reduction, activation, spatial_kernel_size)
        else:
            self.attention = None

    def forward(self, x):
        out = self.conv(x)

        if self.attention is not None:
            out = self.attention(out)

        return x + out if self.use_res_connect else out


class MobileFaceNet(nn.Module):
    """MobileFaceNet-style face embedding network with optional attention.

    The spatial sizes in the comments below assume 3 x 112 x 112 inputs. ``forward``
    returns the logits, or ``{"feats": backbone output of shape (N, 512, 1, 1),
    "out": logits}`` when ``return_feats=True``.

    Args:
        num_classes: output width of the classification head.
        dropout_rate: dropout before the head's Linear layer.
        activation: activation name forwarded to every sub-block ('prelu' by default).
        use_attention: if truthy, stage-3 bottlenecks use channel attention and stage-5
            bottlenecks use CBAM; otherwise no attention is used.
        channel_reduction: channel-attention reduction ratio.
        spatial_kernel_size: kernel size of CBAM's spatial-attention conv.
    """

    def __init__(self, num_classes=128, dropout_rate=0.2, activation='prelu', use_attention=False, channel_reduction=4, spatial_kernel_size=7):
        super(MobileFaceNet, self).__init__()

        # The add_module names below become state_dict keys; keep them stable so that
        # existing checkpoints still load.
        self.backbone = torch.nn.Sequential()

        # Initial convolution layer, 3 x 112 -> 64 x 56
        self.backbone.add_module("conv3x3", Conv_BN(3, 64, kernel_size=3, stride=2, padding=1, activation=activation))

        # Depthwise convolution, 64 x 56 -> 64 x 56
        # NOTE(review): despite its name this is a 1x1 depthwise conv (kernel_size=1,
        # padding=0); the name suggests 3x3 with padding=1 was intended. Changing it
        # alters parameter shapes, so existing checkpoints would no longer load.
        self.backbone.add_module("depthwise conv3x3", DepthwiseConv(64, 64, kernel_size=1, stride=1, padding=0, activation=activation))

        # Stage 1: 5 bottlenecks, the first of which downsamples. 64 x 56 -> 64 x 28
        for i in range(5):
            self.backbone.add_module(f"bottleneck1_{i}", Bottleneck(64, 64, stride=2 if i == 0 else 1, expand_ratio=2, activation=activation, use_attention=None))

        # Downsample: 64 x 28 -> 128 x 14
        self.backbone.add_module("bottleneck2", Bottleneck(64, 128, stride=2, expand_ratio=4, activation=activation, use_attention=None))

        # Stage 3: 6 bottlenecks, 128 x 14 -> 128 x 14 (channel attention if enabled)
        for i in range(6):
            self.backbone.add_module(f"bottleneck3_{i}", Bottleneck(128, 128, stride=1, expand_ratio=2, activation=activation, use_attention='channel' if use_attention else None, channel_reduction=channel_reduction))

        # Downsample: 128 x 14 -> 128 x 7
        self.backbone.add_module("bottleneck4", Bottleneck(128, 128, stride=2, expand_ratio=4, activation=activation, use_attention=None))

        # Stage 5: 2 bottlenecks, 128 x 7 -> 128 x 7 (CBAM if enabled)
        for i in range(2):
            self.backbone.add_module(f"bottleneck5_{i}", Bottleneck(128, 128, stride=1, expand_ratio=2, activation=activation, use_attention='cbam' if use_attention else None, channel_reduction=channel_reduction, spatial_kernel_size=spatial_kernel_size))

        # Conv 1x1, 128 x 7 -> 512 x 7
        self.backbone.add_module("conv1x1", Conv_BN(128, 512, kernel_size=1, stride=1, padding=0, activation=activation))

        # Global depthwise conv (GDConv) 7x7 in place of average pooling, 512 x 7 -> 512 x 1
        # NOTE(review): the module is named "linear", but Conv_BN applies `activation` here.
        # It is left as-is because removing the activation (e.g. PReLU weights) would
        # change the state_dict.
        self.backbone.add_module("linear GDConv7x7", Conv_BN(512, 512, kernel_size=7, stride=1, padding=0, groups=512, activation=activation))

        # Linear Conv 1x1 (no BN / activation), 512 x 1 -> 512 x 1
        self.backbone.add_module("linear conv1x1", nn.Conv2d(512, 512, kernel_size=(1,1), stride=1))

        # Classification Layer
        self.cls_layer = torch.nn.Sequential(
            torch.nn.Flatten(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(512, num_classes),
            torch.nn.BatchNorm1d(num_classes)
        )

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-normal (fan_out, ReLU) for conv/linear weights, zero biases, identity norms."""
        for m in self.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    torch.nn.init.constant_(m.bias, 0)
            elif isinstance(m, (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d, torch.nn.LayerNorm)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x, return_feats=False):
        # x: (N, 3, 112, 112)
        feats = self.backbone(x)  # (N, 512, 1, 1) → flattened inside cls_layer
        out = self.cls_layer(feats)  # (N, num_classes)
        if return_feats:
            return {"feats": feats, "out": out}
        else:
            return out

### Model Configuration

In [None]:
# Initialize your model
MODEL = {
    'Network': Network,
    'DeeperNetwork': DeeperNetwork,
    'ConvNeXt': ConvNeXt,
    'ResNet18': ResNet18,
    'ResNet34': ResNet34,
    'MobileFaceNet': MobileFaceNet,
}

model = MODEL[config['model']](num_classes=config["num_classes"], **config['model_params']).to(DEVICE)

# Layer summary from a dummy 112x112 batch. Run it in eval mode without autograd,
# because a train-mode forward on random noise would update every BatchNorm layer's
# running statistics before training begins. The model goes back to train mode afterwards.
model.eval()
with torch.no_grad():
    summary(model, torch.randn(64, 3, 112, 112).to(DEVICE))
model.train()

# Loss

### CombinedMarginLoss

In [24]:
class CombinedMarginLoss(torch.nn.Module):
    """Convex combination of ArcFace and CosFace losses on the same embeddings.

    loss = arcface_weight * ArcFace + (1 - arcface_weight) * CosFace

    Args:
        num_classes: number of identities (rows of each loss's class-weight matrix).
        embedding_size: dimensionality of the embeddings passed to ``forward``.
        scale: logit scale shared by both losses.
        angular_margin: ArcFace margin, passed as-is to ``losses.ArcFaceLoss(margin=...)``.
            NOTE(review): pytorch_metric_learning documents this margin in degrees
            (library default 28.6, i.e. about 0.5 rad); confirm the unit of the value used.
        additive_margin: CosFace additive cosine margin.
        arcface_weight: weight of the ArcFace term, in [0, 1].
    """
    def __init__(self, num_classes, embedding_size, scale=30.0, angular_margin=0.50, additive_margin=0.35, arcface_weight=0.5):
        super(CombinedMarginLoss, self).__init__()
        # Use the constructor arguments rather than the global `config`, so that the
        # documented parameters actually take effect.
        self.arcface = losses.ArcFaceLoss(num_classes, embedding_size, scale=scale, margin=angular_margin)
        self.cosface = losses.CosFaceLoss(num_classes, embedding_size, scale=scale, margin=additive_margin)
        self.arcface_weight = arcface_weight

    def forward(self, embeddings, labels):
        arcface_loss = self.arcface(embeddings, labels)
        cosface_loss = self.cosface(embeddings, labels)
        combined_loss = self.arcface_weight * arcface_loss + (1 - self.arcface_weight) * cosface_loss
        return combined_loss

### Loss Configuration

In [25]:
# Probe the embedding width seen by the metric losses. The training and validation
# loops flatten `feats` per sample before the metric loss, so count every non-batch
# element. The probe runs in eval mode without autograd so the dummy batch neither
# builds a graph nor updates BatchNorm running statistics; the previous mode is restored.
_was_training = model.training
model.eval()
with torch.no_grad():
    embedding_size = model(torch.randn(2, 3, 112, 112).to(DEVICE), return_feats=True)["feats"].flatten(1).shape[1]
model.train(_was_training)

# NOTE(review): pytorch_metric_learning's ArcFaceLoss documents `margin` in degrees
# (library default 28.6); confirm the unit of config["angular_margin"].
LOSS_FUNCTIONS = {
    "crossentropy": lambda config: torch.nn.CrossEntropyLoss(
        label_smoothing=config["label_smoothing"]
    ),
    "triplet": lambda config: losses.TripletMarginLoss(margin=config["margin"]),
    "npair": lambda config: losses.NPairsLoss(),
    "arcface": lambda config: losses.ArcFaceLoss(
        config["num_classes"],
        embedding_size,
        scale=config["scale"],
        margin=config["angular_margin"],
    ),
    "cosface": lambda config: losses.CosFaceLoss(
        config["num_classes"],
        embedding_size,
        scale=config["scale"],
        margin=config["additive_margin"],
    ),
    "combined": lambda config: CombinedMarginLoss(
        num_classes=config["num_classes"],
        embedding_size=embedding_size,
        scale=config["scale"],
        angular_margin=config["angular_margin"],
        additive_margin=config["additive_margin"],
        arcface_weight=config["combined_arcface_weight"],
    ),
}

ce_criterion = LOSS_FUNCTIONS["crossentropy"](config)
# `criterion` (the metric loss) only exists when config["loss"] is set; the training
# and validation loops check config["loss"] before using it.
if config["loss"] is not None:
    criterion = (LOSS_FUNCTIONS[config["loss"]](config)).to(DEVICE)

# Optimizer and Scheduler

In [26]:
# --------------------------------------------------- #

# Defining Optimizer
# TODO: Feel free to pick an optimizer
# Margin-based metric losses (arcface / cosface / combined) carry learnable parameters,
# so they are optimised together with the model as a second parameter group.
if config["loss"] in ["arcface", "cosface", "combined"]:
    optimizer_params = [{"params": model.parameters()}, {"params": criterion.parameters()}]
else:
    optimizer_params = model.parameters()

optimizer = torch.optim.AdamW(
    optimizer_params,
    lr=config["lr"],
    weight_decay=config["weight_decay"],
)

# --------------------------------------------------- #

# Defining Scheduler (train_epoch steps it once, at the end of each epoch)
# TODO: Use a good scheduler such as ReduceLRonPlateau, StepLR, MultistepLR, CosineAnnealing, etc.
SCHEDULER = dict(
    CosineAnnealingLR=torch.optim.lr_scheduler.CosineAnnealingLR,
    CosineAnnealingWarmRestarts=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
)
scheduler = SCHEDULER[config["scheduler"]](optimizer, **config["scheduler_params"])

# --------------------------------------------------- #

# Mixed-precision (FP16) gradient scaling, enabled only when CUDA is available.
# It pays off on GPUs with fast FP16 support, such as T4/V100.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Metrics

In [27]:
class AverageMeter:
    """Running (optionally weighted) average of a scalar metric.

    Attributes:
        val: most recent value passed to ``update``.
        sum: accumulated ``val * n``.
        count: accumulated weight ``n``.
        avg: ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget everything accumulated so far."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` (e.g. batch size) and refresh ``avg``."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count

In [28]:
def accuracy(output, target, topk=(1,)):
    """Top-k accuracy, in percent, for each k in ``topk``.

    Args:
        output: scores of shape (batch, num_classes).
        target: class indices of shape (batch,).
        topk: k values; each is clamped to num_classes.

    Returns:
        List of 0-d tensors, one per k, in the order of ``topk``.
    """
    num_classes = output.size(1)
    maxk = min(max(topk), num_classes)
    batch_size = target.size(0)

    top_idx = output.topk(maxk, dim=1, largest=True, sorted=True).indices.t()  # (maxk, batch)
    hits = top_idx.eq(target.reshape(1, -1).expand_as(top_idx))

    results = []
    for k in topk:
        n_correct = hits[:min(k, maxk)].reshape(-1).float().sum(0)
        results.append(n_correct * 100. / batch_size)
    return results

In [29]:
def get_ver_metrics(labels, scores, FPRs):
    """Verification metrics from binary match labels and similarity scores.

    Args:
        labels: 0/1 ground truth (1 = positive / match), as a list or array.
        scores: similarity scores aligned with ``labels``; higher means more likely a match.
        FPRs: list of false-positive rates (str or float) at which to report the TPR.
            Any non-list value skips that report.

    Returns:
        dict with 'ACC' (best-threshold accuracy, %), 'EER' (%), 'AUC' (%), and 'TPRs'
        (list of ('TPR@FPR=<fpr>', TPR %) tuples).
    """
    # Accept any array-like (list, ndarray, ...) rather than only lists with .count().
    labels = np.asarray(labels)

    # eer and auc
    fpr, tpr, _ = mt.roc_curve(labels, scores, pos_label=1)
    roc_curve = interp1d(fpr, tpr)  # TPR as a continuous function of FPR
    # EER is where FPR == FNR (= 1 - TPR), i.e. the root of 1 - x - TPR(x) on [0, 1].
    EER = 100. * brentq(lambda x : 1. - x - roc_curve(x), 0., 1.)
    AUC = 100. * mt.auc(fpr, tpr)

    # Best-threshold accuracy: correctly accepted positives plus correctly rejected
    # negatives, maximised over the ROC thresholds.
    tnr = 1. - fpr
    pos_num = int(np.count_nonzero(labels == 1))
    neg_num = int(np.count_nonzero(labels == 0))
    ACC = 100. * max(tpr * pos_num + tnr * neg_num) / len(labels)

    # TPR @ FPR
    if isinstance(FPRs, list):
        TPRs = [
            ('TPR@FPR={}'.format(FPR), 100. * roc_curve(float(FPR)))
            for FPR in FPRs
        ]
    else:
        TPRs = []

    return {
        'ACC': ACC,
        'EER': EER,
        'AUC': AUC,
        'TPRs': TPRs,
    }

# Train and Validation Function

### Train

In [30]:
def _mixup_aware(fn, preds, labels):
    """Evaluate ``fn(preds, targets)``, with support for mixup/cutmix-style labels.

    ``labels`` is either a target tensor or a ``(targets1, targets2, lam)`` tuple/list
    (the format train_epoch moves to the device). For the tuple form the result is
    ``lam * fn(preds, targets1) + (1 - lam) * fn(preds, targets2)``.
    """
    if isinstance(labels, (tuple, list)):
        targets1, targets2, lam = labels
        return lam * fn(preds, targets1) + (1 - lam) * fn(preds, targets2)
    return fn(preds, labels)


def _train_objective(logits, embeddings, labels):
    """Return cross-entropy on ``logits``, blended with the configured metric loss on
    ``embeddings`` when ``config['loss']`` is set (weights are read from ``config``)."""
    total = _mixup_aware(ce_criterion, logits, labels)
    if config['loss'] is not None:
        other = _mixup_aware(criterion, embeddings, labels)
        total = total * config['cross_entropy_weight'] + other * config['other_loss_weight']
    return total


def train_epoch(
    model, dataloader, optimizer, lr_scheduler, scaler, device, epoch, loss
):
    """Train ``model`` for one pass over ``dataloader`` with mixed precision.

    Args:
        model: network whose forward accepts ``return_feats=True`` and returns
            ``{"feats", "out"}``. It must expose ``cls_layer`` when adversarial
            training is enabled.
        dataloader: yields ``(images, labels)``. ``labels`` is a tensor of class indices
            or a mixup-style ``(targets1, targets2, lam)`` tuple.
        optimizer, scaler: optimizer and GradScaler used for the AMP update.
        lr_scheduler: stepped once at the end of the epoch (skipped if None).
        device: device that images and labels are moved to.
        epoch: unused here; kept for the call signature.
        loss: unused. The notebook passes ``config`` here, but the loss setup is read
            from the globals ``config``, ``ce_criterion`` and ``criterion``.

    Returns:
        (average accuracy in %, average loss), both weighted by batch size.
    """
    model.train()

    # metric meters, weighted by batch size so a short final batch is not over-counted
    loss_m = AverageMeter()
    acc_m = AverageMeter()

    # Progress Bar
    batch_bar = tqdm(
        total=len(dataloader),
        dynamic_ncols=True,
        leave=False,
        position=0,
        desc="Train",
        ncols=5,
    )

    for images, labels in dataloader:

        optimizer.zero_grad()  # Zero gradients

        # send to device; mixup-style labels arrive as (targets1, targets2, lam)
        images = images.to(device, non_blocking=True)
        if isinstance(labels, (tuple, list)):
            targets1, targets2, lam = labels
            labels = (targets1.to(device), targets2.to(device), lam)
        else:
            labels = labels.to(device, non_blocking=True)

        # forward
        with torch.cuda.amp.autocast():  # This implements mixed precision. Thats it!
            outputs = model(images, return_feats=True)
            # Flattened, L2-normalised embeddings, used by the metric loss and by the
            # adversarial branch (computed unconditionally so both always have them).
            output_feats = F.normalize(outputs["feats"].flatten(1))
            batch_loss = _train_objective(outputs["out"], output_feats, labels)

            # Adversarial training: perturb the detached embeddings along the sign of the
            # loss gradient and blend in the loss on the perturbed embeddings. Because the
            # embeddings are detached, this term only trains cls_layer (and the metric loss).
            if config.get('use_adversarial_training', False):
                embeddings = output_feats.detach().clone().requires_grad_()
                # In the models defined here cls_layer is an nn.Sequential, whose forward
                # takes a single tensor (no return_feats keyword).
                logits = model.cls_layer(embeddings)
                embed_adv_loss = _train_objective(logits, embeddings, labels)
                grad = torch.autograd.grad(embed_adv_loss, embeddings)[0]

                adv_alpha = config['adversarial_training_alpha']
                perturbed_embeddings = F.normalize(embeddings + adv_alpha * grad.sign(), p=2, dim=1)
                perturbed_logits = model.cls_layer(perturbed_embeddings)
                perturbed_adv_loss = _train_objective(perturbed_logits, perturbed_embeddings, labels)

                adv_beta = config['adversarial_training_beta']
                batch_loss = (1 - adv_beta) * batch_loss + adv_beta * perturbed_adv_loss

        # Backprop
        scaler.scale(batch_loss).backward()  # This is a replacement for loss.backward()
        scaler.step(optimizer)  # This is a replacement for optimizer.step()
        scaler.update()

        # metrics (mixup batches report the lam-weighted accuracy over both targets)
        batch_size = images.size(0)
        acc = float(_mixup_aware(lambda logits, targets: accuracy(logits, targets)[0], outputs["out"], labels))
        loss_m.update(batch_loss.item(), batch_size)
        acc_m.update(acc, batch_size)

        # tqdm lets you add some details so you can monitor training as you train.
        batch_bar.set_postfix(
            acc="{:.04f}% ({:.04f})".format(acc, acc_m.avg),
            loss="{:.04f} ({:.04f})".format(batch_loss.item(), loss_m.avg),
            lr="{:.04f}".format(float(optimizer.param_groups[0]["lr"])),
        )

        batch_bar.update()  # Update tqdm bar

    # Epoch-level scheduler step
    if lr_scheduler is not None:
        lr_scheduler.step()

    batch_bar.close()

    return acc_m.avg, loss_m.avg

### Valid

In [31]:
@torch.no_grad()
def valid_epoch_cls(model, dataloader, device, config):
    """Evaluate classification loss and accuracy on ``dataloader``.

    The loss mirrors training: cross-entropy, blended with the global ``criterion`` on
    flattened, L2-normalised features when ``config['loss']`` is set.

    Args:
        model: network returning ``{"feats", "out"}`` with ``return_feats=True``.
        dataloader: yields ``(images, labels)`` with labels as class-index tensors.
        device: device that images and labels are moved to.
        config: dict supplying 'loss', 'cross_entropy_weight' and 'other_loss_weight'.

    Returns:
        (average accuracy in %, average loss), both weighted by batch size.
    """
    model.eval()
    batch_bar = tqdm(
        total=len(dataloader),
        dynamic_ncols=True,
        position=0,
        leave=False,
        desc="Val Cls.",
        ncols=5,
    )

    # metric meters, weighted by batch size so a short final batch is not over-counted
    loss_m = AverageMeter()
    acc_m = AverageMeter()

    for images, labels in dataloader:

        # Move images to device
        images, labels = images.to(device), labels.to(device)

        # Get model outputs
        with torch.inference_mode():
            outputs = model(images, return_feats=True)

            loss = ce_criterion(outputs["out"], labels)

            if config['loss'] is not None:
                output_feats = F.normalize(outputs["feats"].flatten(1))
                other_loss = criterion(output_feats, labels)
                loss = loss * config['cross_entropy_weight'] + other_loss * config['other_loss_weight']

        # metrics
        batch_size = images.size(0)
        loss_m.update(loss.item(), batch_size)
        acc = accuracy(outputs["out"], labels)[0].item()
        acc_m.update(acc, batch_size)

        batch_bar.set_postfix(
            acc="{:.04f}% ({:.04f})".format(acc, acc_m.avg),
            loss="{:.04f} ({:.04f})".format(loss.item(), loss_m.avg),
        )

        batch_bar.update()

    batch_bar.close()
    return acc_m.avg, loss_m.avg

In [32]:
# Free unreachable Python objects, then release PyTorch's cached-but-unused CUDA memory
# held by the caching allocator. Handy after a CUDA out-of-memory error.
gc.collect() # These commands help you when you face CUDA OOM error
torch.cuda.empty_cache()

# Verification Task

In [33]:
def valid_epoch_ver(model, pair_data_loader, device, config):
    """Evaluate face verification on labelled image pairs.

    Each pair is scored by the cosine similarity of its two flattened, L2-normalised
    embeddings (``feats``). The scores are then evaluated with get_ver_metrics.

    Args:
        model: network returning ``{"feats", ...}`` when called with ``return_feats=True``.
        pair_data_loader: yields ``(images1, images2, labels)``, where label 1 = match
            (get_ver_metrics uses pos_label=1).
        device: device that the images are moved to.
        config: unused; kept for call-site compatibility.

    Returns:
        (ACC, EER), both in percent.
    """
    model.eval()
    scores = []
    match_labels = []
    batch_bar = tqdm(total=len(pair_data_loader), dynamic_ncols=True, position=0, leave=False, desc='Val Veri.')
    for images1, images2, labels in pair_data_loader:

        # Both halves go through the model in one batch, then are split back apart.
        images = torch.cat([images1, images2], dim=0).to(device)
        with torch.inference_mode():
            outputs = model(images, return_feats=True)

        # Flatten so that (N, C, 1, 1) backbone features give exactly one score per pair.
        feats = F.normalize(outputs['feats'].flatten(1), dim=1)
        feats1, feats2 = feats.chunk(2)
        similarity = F.cosine_similarity(feats1, feats2)
        # reshape(-1), not squeeze(): a final batch with a single pair must stay 1-D, or
        # np.concatenate below fails on a 0-d array.
        scores.append(similarity.cpu().numpy().reshape(-1))
        match_labels.append(labels.cpu().numpy().reshape(-1))
        batch_bar.update()

    batch_bar.close()

    scores = np.concatenate(scores)
    match_labels = np.concatenate(match_labels)

    FPRs=['1e-4', '5e-4', '1e-3', '5e-3', '5e-2']
    metric_dict = get_ver_metrics(match_labels.tolist(), scores.tolist(), FPRs)

    return metric_dict['ACC'], metric_dict['EER']

# WandB

In [None]:
# Authenticate with Weights & Biases before the wandb.init cell below.
# Never hardcode the API key in the notebook.
wandb.login()

In [None]:
# Create your wandb run: a fresh run when config['wandb_init'] is truthy,
# otherwise resume the run identified by config['wandb_id'].
_wandb_kwargs = dict(
    name=config['wandb_name'],  ## Wandb creates random run names if you skip this field
    project="hw2p2",            ### Project should be created in your wandb account
    config=config,              ### Wandb Config for your run
)
if config['wandb_init']:
    _wandb_kwargs["reinit"] = True  ### Allows reinitalizing runs when you re-run this cell
else:
    # resume="must" is needed to resume a previous run (do not combine with reinit=True)
    _wandb_kwargs.update(id=config['wandb_id'], resume="must")

run = wandb.init(**_wandb_kwargs)

# Checkpointing and Loading Model

In [36]:
# NOTE(review): `os` is already imported in the notebook's first code cell; this
# re-import is redundant but harmless.
import os
checkpoint_dir = config['checkpoint_dir']

# Create the directory if it doesn't exist (exist_ok=True makes re-running this cell safe)
os.makedirs(checkpoint_dir, exist_ok=True)

In [37]:
def save_model(model, optimizer, scheduler, metrics, epoch, best_valid_cls_acc, best_valid_ret_acc, best_valid_ret_eer, path):
    """Write the training state to ``path`` with torch.save, then register it with W&B.

    The saved dict uses the keys that load_model reads back: model/optimizer/scheduler
    state dicts, the epoch's ``metric`` dict, ``epoch``, and the three running best scores.
    """
    state = dict(
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        scheduler_state_dict=scheduler.state_dict(),
        metric=metrics,
        epoch=epoch,
        best_valid_cls_acc=best_valid_cls_acc,
        best_valid_ret_acc=best_valid_ret_acc,
        best_valid_ret_eer=best_valid_ret_eer,
    )
    torch.save(state, path)
    wandb.save(path)


def load_model(model, optimizer=None, scheduler=None, path="./checkpoint.pth", wandb_run=True):
    """Restore a checkpoint written by save_model.

    Args:
        model: module that receives ``model_state_dict``.
        optimizer, scheduler: optional; their state is restored when provided.
        path: checkpoint path (a file name within the active W&B run when ``wandb_run``).
        wandb_run: fetch the file via ``wandb.restore`` instead of reading ``path`` directly.

    Returns:
        (model, optimizer, scheduler, epoch, metrics, best_valid_cls_acc,
         best_valid_ret_acc, best_valid_ret_eer)

    Raises:
        FileNotFoundError: if ``wandb.restore`` cannot find ``path`` in the run.
    """
    if wandb_run:
        restored = wandb.restore(path)
        if restored is None:
            raise FileNotFoundError(f"Could not restore {path!r} from the active W&B run")
        local_path = restored.name
        restored.close()  # only the path is needed; do not leak the open handle
    else:
        local_path = path

    # The checkpoint comes from save_model in this notebook. Besides tensors it holds plain
    # Python/NumPy values (the retrieval accuracy from get_ver_metrics can be a NumPy
    # scalar), which the weights_only=True loader used by newer torch versions rejects.
    # Only load trusted files this way.
    checkpoint = torch.load(local_path, map_location=DEVICE, weights_only=False)

    model.load_state_dict(checkpoint["model_state_dict"])
    if optimizer is not None:
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if scheduler is not None:
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
    epoch = checkpoint["epoch"]
    metrics = checkpoint["metric"]
    best_valid_cls_acc = checkpoint["best_valid_cls_acc"]
    best_valid_ret_acc = checkpoint["best_valid_ret_acc"]
    best_valid_ret_eer = checkpoint["best_valid_ret_eer"]

    print(f"Checkpoint loaded successfully (Epoch {epoch})")

    return model, optimizer, scheduler, epoch, metrics, best_valid_cls_acc, best_valid_ret_acc, best_valid_ret_eer

# Experiments

In [None]:
eval_cls = True
if config['wandb_init']:
    # Fresh run: start at epoch 0 with "nothing achieved yet" best scores.
    e = 0
    best_valid_cls_acc = 0.0
    best_valid_ret_acc = 0.0
    best_valid_ret_eer = float('inf')
else:
    # Resume from the same file this loop writes as its rolling "last" checkpoint.
    model, optimizer, scheduler, last_epoch, metrics, best_valid_cls_acc, best_valid_ret_acc, best_valid_ret_eer = load_model(
        model, optimizer, scheduler, os.path.join(config["checkpoint_dir"], "last.pth")
    )
    e = last_epoch + 1

print(f"loss: {config['loss']}")


def _save_checkpoint(filename):
    """Save model/optimizer/scheduler, the current ``metrics`` and ``epoch``, and the best
    scores (all read from this cell's globals at call time) to
    ``config['checkpoint_dir']/filename``. save_model also registers the file with W&B."""
    save_model(
        model,
        optimizer,
        scheduler,
        metrics,
        epoch,
        best_valid_cls_acc,
        best_valid_ret_acc,
        best_valid_ret_eer,
        os.path.join(config["checkpoint_dir"], filename),
    )


for epoch in range(e, config["epochs"]):
    # epoch
    print(f"\nEpoch {epoch + 1}/{config['epochs']}")
    metrics = {
        "epoch": epoch + 1,
    }

    # train
    train_cls_acc, train_loss = train_epoch(
        model, cls_train_loader, optimizer, scheduler, scaler, DEVICE, epoch, config
    )
    curr_lr = float(optimizer.param_groups[0]["lr"])
    print(
        f"\nTraining Metrics - Epoch {epoch + 1}/{config['epochs']}\n"
        f"Train Cls. Acc: {train_cls_acc:.4f}%\n"
        f"Train Cls. Loss: {train_loss:.4f}\n"
        f"Learning Rate: {curr_lr:.6f}"
    )
    metrics.update(
        {
            "train_cls_acc": train_cls_acc,
            "train_loss": train_loss,
            "lr": curr_lr,
        }
    )

    # classification validation
    if eval_cls:
        valid_cls_acc, valid_loss = valid_epoch_cls(
            model, cls_val_loader, DEVICE, config
        )
        print(
            f"\nClassification Validation - Epoch {epoch + 1}/{config['epochs']}\n"
            f"Val Cls. Acc: {valid_cls_acc:.4f}%\n"
            f"Val Cls. Loss: {valid_loss:.4f}"
        )
        metrics.update(
            {
                "valid_cls_acc": valid_cls_acc,
                "valid_loss": valid_loss,
            }
        )

    # retrieval validation
    valid_ret_acc, valid_ret_eer = valid_epoch_ver(
        model, ver_val_loader, DEVICE, config
    )
    print(
        f"\nRetrieval Validation - Epoch {epoch + 1}/{config['epochs']}\n"
        f"Val Ret. Acc: {valid_ret_acc:.4f}%\n"
        f"Val Ret. EER: {valid_ret_eer:.4f}"
    )
    metrics.update(
        {
            "valid_ret_acc": valid_ret_acc,
            "valid_ret_eer": valid_ret_eer,
        }
    )

    # save best models; each best score is updated just before its own save
    if eval_cls and valid_cls_acc >= best_valid_cls_acc:
        best_valid_cls_acc = valid_cls_acc
        _save_checkpoint("best_cls.pth")
        print("Saved best classification model")

    if valid_ret_acc >= best_valid_ret_acc:
        best_valid_ret_acc = valid_ret_acc
        _save_checkpoint("best_ret_acc.pth")
        print("Saved best retrieval acc model")

    if valid_ret_eer <= best_valid_ret_eer:
        best_valid_ret_eer = valid_ret_eer
        _save_checkpoint("best_ret_eer.pth")
        print("Saved best retrieval eer model")

    # rolling checkpoint used for resuming
    _save_checkpoint("last.pth")

    # log to tracker
    if run is not None:
        run.log(metrics)

# Testing and Kaggle Submission

In [36]:
def test_epoch_ver(model, pair_data_loader, config):
    """Score every image pair from ``pair_data_loader`` for face verification.

    Each batch yields ``(images1, images2)``. Both halves are concatenated and
    pushed through the model in a single forward pass. The cosine similarity
    between each pair's L2-normalised feature vectors is recorded.

    Args:
        model: network called as ``model(images, return_feats=True)`` that
            returns a mapping with a ``'feats'`` entry.
        pair_data_loader: sized iterable of ``(images1, images2)`` batches.
            Assumes both tensors in a batch have the same batch size, so the
            concatenated features split evenly back into the two halves.
            TODO confirm against the loader's collate behaviour.
        config: unused here; kept so existing call sites keep working.

    Returns:
        list[float]: one similarity score per pair, in loader order.
    """
    model.eval()
    scores = []
    # The context manager guarantees the progress bar is closed even if an
    # exception interrupts the loop; inference_mode covers the whole pass.
    with tqdm(total=len(pair_data_loader), dynamic_ncols=True, position=0,
              leave=False, desc='Val Veri.') as batch_bar, torch.inference_mode():
        for images1, images2 in pair_data_loader:
            # One forward pass over both halves instead of two separate calls.
            images = torch.cat([images1, images2], dim=0).to(DEVICE)
            outputs = model(images, return_feats=True)

            feats = F.normalize(outputs['feats'], dim=1)
            # First half of the rows came from images1, second half from images2.
            feats1, feats2 = feats.chunk(2)
            similarity = F.cosine_similarity(feats1, feats2)
            scores.extend(similarity.cpu().tolist())
            batch_bar.update()

    return scores

In [None]:
# Score the verification test set with each saved checkpoint, write one
# "ID,Label" CSV per checkpoint into `submissions/`, and log each CSV to W&B.
submission_dir = "submissions"
os.makedirs(submission_dir, exist_ok=True)

model_paths = ["best_ret_eer.pth", "best_ret_acc.pth", "last.pth"]

for model_path in model_paths:
    print(f"Loading {model_path}")
    # splitext strips only the final extension, so checkpoint names that
    # contain extra dots (e.g. "ckpt.v2.pth") keep their full stem.
    ckpt_name = os.path.splitext(model_path)[0]

    # load_model (defined earlier in the notebook) returns the model/optimizer/scheduler
    # together with the bookkeeping values unpacked here.
    model, optimizer, scheduler, last_epoch, metrics, best_valid_cls_acc, best_valid_ret_acc, best_valid_ret_eer = load_model(
        model, optimizer, scheduler, os.path.join(config['checkpoint_dir'], model_path)
    )
    print(f"best cls acc: {best_valid_cls_acc}, best ret acc: {best_valid_ret_acc}, best ret eer: {best_valid_ret_eer}")

    # test_epoch_ver already returns a flat list of Python floats, so it can be
    # written directly without a tensor round-trip.
    scores = test_epoch_ver(model, ver_test_loader, config)

    file_name = os.path.join(submission_dir, ckpt_name + '_submission.csv')

    # Write-only access is all that is needed ("w", not "w+").
    with open(file_name, "w") as f:
        f.write("ID,Label\n")
        for idx, score in enumerate(scores):
            f.write(f"{idx},{score}\n")

    artifact = wandb.Artifact("submission", type="dataset")
    artifact.add_file(file_name)
    wandb.log_artifact(artifact)
    print(ckpt_name + " submitted to WandB")

In [39]:
# ### Submit to the Kaggle competition using the Kaggle API (uncomment the line below to use it).
# ### Point -f at one of the CSVs written to submissions/ by the cell above, e.g. submissions/best_ret_eer_submission.csv.
# !kaggle competitions submit -c 11785-hw-2-p-2-face-verification-spring-2025 -f submissions/best_ret_eer_submission.csv -m "Test Submission"

# ### However, it's always safer to download the CSV file and upload it to Kaggle manually.

In [None]:
# Finish the W&B run now that checkpoints and submission artifacts have been logged to it.
wandb.finish()