In [6]:
# Importing required libraries

import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [7]:
# Defining device: use the GPU when one is available, otherwise fall back to
# the CPU (same selection rule as the later cells), so that the notebook still
# runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda


#### Data Loading and preprocessing

In [None]:

# Fetch the MNIST train and test splits (downloaded into ./data when missing),
# converting every image to a tensor
train = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
test = datasets.MNIST(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)

# A loader whose one and only batch contains the entire training split
train_loader = torch.utils.data.DataLoader(
    train,
    batch_size=len(train),
    shuffle=False,
)

# Pull that single batch; element 0 holds the images, the labels are not needed
full_batch = next(iter(train_loader))
data = full_batch[0]

# Pixel statistics taken over every pixel of every training image
mean, std = data.mean(), data.std()

# .item() unwraps the 0-dim tensors into plain Python numbers for printing
print("Computed Mean:", mean.item())
print("Computed Std:", std.item())


100%|██████████| 9.91M/9.91M [00:03<00:00, 3.29MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 749kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 1.75MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 289kB/s]


Computed Mean: 0.13066047430038452
Computed Std: 0.30810782313346863


In [None]:
# Define transformations for MNIST: convert to tensor, then standardize with
# the pixel statistics computed from the training split.
# Normalize is documented to take one value per channel, so the scalar
# statistics are passed as 1-tuples of plain floats (single-channel images)
# instead of 0-dim tensors.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(float(mean),), std=(float(std),))
])

# Applying normalization: swap the new pipeline into both existing datasets
train.transform = transform
test.transform = transform

# Mini-batch DataLoaders: the training loader reshuffles, the test loader
# keeps the dataset order
train_mnist = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True)
test_mnist = torch.utils.data.DataLoader(test, batch_size=64)

In [None]:
# Number of mini-batches produced by the train and test loaders
# (len() of a DataLoader counts batches, not individual samples)
for loader in (train_mnist, test_mnist):
    print(len(loader))

938
157


#### Building a simple convolutional neural network model

In [12]:
import torch
import torch.nn as nn
import torchvision.models as models

class SimpleCNN(nn.Module):
    """Small two-stage convolutional feature extractor.

    Each stage applies conv -> batch norm -> ReLU -> 2x2 max-pool -> dropout.
    The flattened result is projected by a linear layer and passed through a
    final ReLU, giving a non-negative feature vector of size ``feature_dim``
    (128). The linear layer is sized for 1-channel 28x28 inputs, which become
    64 channels of 7x7 after the two poolings.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 -> 32 channels, spatial size preserved by padding=1
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        # Parameter-free layers reused by both stages
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(0.25)
        # Stage 2: 32 -> 64 channels
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        # Projection of the flattened 64 x 7 x 7 maps to the feature vector
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)

        # Size of the vector returned by forward(); read by the heads built on top
        self.feature_dim = 128

    def forward(self, x):
        """Map a batch of images to a batch of ``feature_dim``-sized features."""
        stage1 = self.dropout(self.pool(self.relu(self.bn1(self.conv1(x)))))
        stage2 = self.dropout(self.pool(self.relu(self.bn2(self.conv2(stage1)))))
        return self.relu(self.fc1(self.flatten(stage2)))


#### Calculating the cosine similarity (cos θ) between features and class weights

In [None]:
import torch.nn.functional as F

class Cosine_similarity(nn.Module):
    """Classification head that scores features by cosine similarity.

    Holds one learnable weight vector per class. The forward pass L2-normalizes
    both the incoming features and the class weights, so their dot product is
    the cosine of the angle between each feature and each class vector.
    """

    def __init__(self, feature_dim, num_classes):
        super().__init__()

        # One randomly initialized (standard normal) weight vector per class
        initial_weights = torch.randn(num_classes, feature_dim)
        self.class_weights = nn.Parameter(initial_weights)

    def forward(self, features):
        """Return the (batch, num_classes) matrix of cosine similarities."""
        # Unit-length rows for both operands
        unit_features = F.normalize(features, p=2, dim=1)
        unit_weights = F.normalize(self.class_weights, p=2, dim=1)

        # With unit vectors the dot product equals the cosine similarity
        return torch.matmul(unit_features, unit_weights.T)

#### Calculating the mean and standard deviation of the feature norms (used to standardize the feature norm)

In [None]:
# Build a fresh backbone and cosine head on the selected device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10
model_backbone = SimpleCNN().to(device)
model_head = Cosine_similarity(model_backbone.feature_dim, num_classes).to(device)

# Collect the L2 norms of the backbone features over the first 5 training
# batches; the backbone is put in evaluation mode and no gradients are tracked
model_backbone.eval()
norms_list = []
with torch.no_grad():
    for _batch_index, (images, labels) in zip(range(5), train_mnist):
        features = model_backbone(images.to(device))
        # One norm per sample (dim=1 runs over the feature dimension)
        norms_list.extend(torch.norm(features, p=2, dim=1).tolist())

# Mean and standard deviation of the collected feature norms
norms_tensor = torch.tensor(norms_list)
mean_norm = norms_tensor.mean().item()
std_norm = norms_tensor.std().item()

print(f"Observed mean norm: {mean_norm}")
print(f"Observed std norm: {std_norm}")

Observed mean norm: 1.5565310716629028
Observed std norm: 0.18902243673801422


#### Defining AdaLoss Function

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AdaLossModel(nn.Module):
    """Backbone + cosine head trained with a feature-norm-adaptive margin loss.

    The L2 norm of each feature vector is standardized with running batch
    statistics and turned into a per-sample margin scaler in [-1, 1]. That
    scaler sets an angular margin (added to the target-class angle) and an
    additive margin (subtracted from the target-class cosine) before the
    scaled logits are passed to cross-entropy.

    Args:
        num_classes: Number of output classes.
        alpha: Weight of the current batch in the exponential moving average
            of the feature-norm mean and standard deviation.
        backbone_type: ``'cnn'`` to build a ``SimpleCNN`` backbone, or an
            already constructed ``nn.Module`` exposing a ``feature_dim``
            attribute.
        mean_norm: Initial value of the running feature-norm mean.
        std_norm: Initial value of the running feature-norm standard deviation.
        m: Base margin (default 0.4).
        h: Factor applied to the standardized norm before clamping
            (default 0.333).
        s: Scale applied to the logits before cross-entropy (default 30).

    Raises:
        ValueError: If ``backbone_type`` is neither ``'cnn'`` nor a module.
    """

    def __init__(self, num_classes, alpha, backbone_type, mean_norm, std_norm,
                 m=0.4, h=0.333, s=30):
        super(AdaLossModel, self).__init__()
        # The backbone has to exist before the head is created, because the
        # head is sized from backbone.feature_dim.
        if isinstance(backbone_type, nn.Module):
            self.backbone = backbone_type
        elif backbone_type == 'cnn':
            self.backbone = SimpleCNN()
        else:
            raise ValueError(
                f"Unsupported backbone_type {backbone_type!r}; expected 'cnn' "
                "or an nn.Module exposing a feature_dim attribute"
            )
        self.head = Cosine_similarity(self.backbone.feature_dim, num_classes)
        self.num_classes = num_classes
        self.alpha = alpha
        self.m = m
        self.h = h
        self.s = s
        # Running statistics live in buffers: saved in the state_dict and
        # moved by .to(device), but never touched by the optimizer.
        self.register_buffer('batch_mean', torch.ones(1) * mean_norm)
        self.register_buffer('batch_std', torch.ones(1) * std_norm)

    def forward(self, x, labels=None):
        """Run the model.

        Args:
            x: Batch of inputs for the backbone.
            labels: Optional tensor of integer class indices, one per sample.

        Returns:
            The cosine-similarity logits when ``labels`` is None, otherwise a
            ``(loss, logits)`` tuple where ``logits`` are the margin-free
            cosine similarities.
        """
        features = self.backbone(x)
        norms = torch.norm(features, p=2, dim=1)

        # Update the running statistics only while training, so evaluation
        # batches cannot leak into them.
        if self.training:
            with torch.no_grad():
                self.batch_mean.copy_(
                    self.alpha * norms.mean() + (1 - self.alpha) * self.batch_mean
                )
                # std() of a single sample is NaN and would poison the buffer
                # for every later batch, so such batches are skipped.
                if norms.numel() > 1:
                    self.batch_std.copy_(
                        self.alpha * norms.std() + (1 - self.alpha) * self.batch_std
                    )
        # Clamping for numerical stability
        safe_norms = torch.clamp(norms, min=0.001, max=100)
        # Calculating standardized norms using z-score normalization
        # NOTE(review): gradients flow from the loss through `norms` into the
        # margins here; the reference AdaFace implementation appears to detach
        # the norms first - confirm which behaviour is intended.
        standardized_norms = (safe_norms - self.batch_mean) / (self.batch_std + 1e-5)
        # Calculating the cosine similarity between feature vector and weight vector
        cosine_sim = self.head(features)

        # Without labels there is nothing to compute a loss against: return logits only
        if labels is None:
            return cosine_sim

        # Per-sample margin scaler, clamped to [-1, 1] so that each margin
        # stays within a bounded range around the base margin m
        margin_scaler = torch.clamp(standardized_norms * self.h, -1, 1)

        # Calculating g_angular and g_additive values
        g_angular = -self.m * margin_scaler
        g_additive = self.m + (self.m * margin_scaler)

        # One-hot encoding of labels: 1 at the index of the correct class
        one_hot_labels = torch.zeros_like(cosine_sim).scatter_(1, labels.unsqueeze(1), 1)

        # Margins are applied to the target-class entry of each row only
        angular_term = one_hot_labels * g_angular.unsqueeze(1)
        additive_term = one_hot_labels * g_additive.unsqueeze(1)

        # Calculating theta; the clamp keeps acos away from +/-1 where its
        # gradient is unbounded
        theta = torch.acos(torch.clamp(cosine_sim, -1 + 1e-7, 1 - 1e-7))

        # Calculating (theta + angular_term), kept inside (0, pi)
        theta_m = torch.clamp(theta + angular_term, min=1e-5, max=math.pi - 1e-5)
        cosine_m = torch.cos(theta_m)

        final_cosine = cosine_m - additive_term

        # Scale logits before computing loss
        loss = F.cross_entropy(final_cosine * self.s, labels)
        # Return both loss and logits
        return loss, cosine_sim


#### Training and evaluating the model with the AdaLoss function

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import math

# Example usage: run configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10
alpha = 0.1

# Build the loss model, starting its running statistics from the measured
# feature-norm mean and standard deviation
model = AdaLossModel(
    num_classes=num_classes,
    alpha=alpha,
    backbone_type='cnn',
    mean_norm=mean_norm,
    std_norm=std_norm,
).to(device)

# Adam optimizer over all of the model's parameters
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 20

# Each epoch: one pass over the training loader, then one over the test loader
for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for inputs, labels in train_mnist:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Clear the gradients accumulated by the previous step
        optimizer.zero_grad()

        # With labels supplied the model returns (loss, logits)
        loss, logits = model(inputs, labels)
        loss.backward()
        optimizer.step()

        # Running sums for the epoch-level loss and accuracy
        train_loss += loss.item()
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # Average loss per batch and accuracy in percent for this epoch
    train_loss /= len(train_mnist)
    train_accuracy = 100. * correct / total
    # end=" " keeps the test metrics on the same output line
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%",end=" ")

    # ---- Test pass ----
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_mnist:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Labels are passed so that the loss is reported as well
            loss, logits = model(inputs, labels)

            test_loss += loss.item()
            pred = logits.argmax(dim=1)
            total += labels.size(0)
            correct += (pred == labels).sum().item()

    test_loss /= len(test_mnist)
    test_accuracy = 100. * correct / total
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")



Epoch 1/20, Train Loss: 4.1680, Train Accuracy: 89.79% Test Loss: 1.2448, Test Accuracy: 96.83%
Epoch 2/20, Train Loss: 1.5467, Train Accuracy: 96.51% Test Loss: 0.8185, Test Accuracy: 97.75%
Epoch 3/20, Train Loss: 1.0779, Train Accuracy: 97.38% Test Loss: 0.6034, Test Accuracy: 98.29%
Epoch 4/20, Train Loss: 0.8634, Train Accuracy: 97.81% Test Loss: 0.5553, Test Accuracy: 98.49%
Epoch 5/20, Train Loss: 0.7372, Train Accuracy: 98.04% Test Loss: 0.4637, Test Accuracy: 98.73%
Epoch 6/20, Train Loss: 0.6395, Train Accuracy: 98.32% Test Loss: 0.4686, Test Accuracy: 98.58%
Epoch 7/20, Train Loss: 0.5835, Train Accuracy: 98.44% Test Loss: 0.3972, Test Accuracy: 98.85%
Epoch 8/20, Train Loss: 0.5326, Train Accuracy: 98.53% Test Loss: 0.3697, Test Accuracy: 98.83%
Epoch 9/20, Train Loss: 0.4901, Train Accuracy: 98.64% Test Loss: 0.3714, Test Accuracy: 98.90%
Epoch 10/20, Train Loss: 0.4589, Train Accuracy: 98.72% Test Loss: 0.3237, Test Accuracy: 99.01%
Epoch 11/20, Train Loss: 0.4274, Train 