In [None]:
# Importing required libraries
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms

In [178]:
# Defining device (using GPU)
device = torch.device("cuda")
print("Using device:", device)

Using device: cuda


#### Loading the data and preprocessing

In [None]:
# Defining transformations
# Implementing Jittering and random cropping with rescaling
transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),  
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2470, 0.2435, 0.2616))
])

# Load CIFAR-10 dataset
train = datasets.CIFAR10(root='./data', train=True, download=False, transform=transform)
test = datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)

# Function to create DataLoaders
train_cifar = torch.utils.data.DataLoader(train, batch_size=64, shuffle=True)
test_cifar = torch.utils.data.DataLoader(test, batch_size=64)

In [202]:
# Confirming the test and train datasets' size x 64
print(len(train_cifar))
print(len(test_cifar))

782
157


#### Using a ResNet18 model

In [208]:
import torch.nn as nn
import torchvision.models as models

class ResNetBackbone(nn.Module):
    def __init__(self):
        super(ResNetBackbone, self).__init__()
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # Remove the fully connected layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # Feature output size: 512
        self.feature_dim = 256  # Reduced to 256 for better generalization

        self.flatten = nn.Flatten()

        # New bottleneck fully connected layer
        self.fc = nn.Linear(512, self.feature_dim)

        # Batch normalization
        self.batch_norm = nn.BatchNorm1d(self.feature_dim, momentum=0.5)

        # Stronger dropout
        self.dropout = nn.Dropout(p=0.6)

    def forward(self, x):
        x = self.backbone(x)
        x = self.flatten(x)
        x = self.fc(x)  # Bottleneck layer
        x = self.batch_norm(x)
        x = self.dropout(x)
        return x


#### Calculating the theta value

In [None]:
import torch.nn.functional as F

class Cosine_similarity(nn.Module):
    def __init__(self, feature_dim, num_classes):
        super(Cosine_similarity, self).__init__()
        
        # Learnable class weight vectors (initialized randomly)
        self.class_weights = nn.Parameter(torch.randn(num_classes, feature_dim))
        
    def forward(self, features):
        # Normalize the feature vectors (L2 normalization)
        features = F.normalize(features, p=2, dim=1)  
        
        # Normalize the class weight vectors
        class_weights = F.normalize(self.class_weights, p=2, dim=1)  
        
        # Compute cosine similarity (dot product since vectors are normalized)
        cosine_sim = torch.matmul(features, class_weights.T)  
        
        return cosine_sim

#### Calculating mean and standard deviation of features to calculate standardized feature norm

In [None]:
# Create an instance of ResNetBackbone and Cosine_similarity
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10
model_backbone = ResNetBackbone().to(device)
model_head = Cosine_similarity(model_backbone.feature_dim, num_classes).to(device)

# Run a few batches and calculate feature norm statistics
model_backbone.eval() 
with torch.no_grad():
    norms_list = []
    for i, (images, labels) in enumerate(train_cifar):
        if i >= 5:
            break
        images = images.to(device)
        features = model_backbone(images)
        norms = torch.norm(features, p=2, dim=1)
        norms_list.extend(norms.tolist())

    # Calculating mean norm and standard deviation norm of features
    mean_norm = torch.tensor(norms_list).mean().item()
    std_norm = torch.tensor(norms_list).std().item()

    print(f"Observed mean norm: {mean_norm}")
    print(f"Observed std norm: {std_norm}")

Observed mean norm: 7.902091979980469
Observed std norm: 1.1947855949401855


#### Defining AdaLoss function

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class AdaLossModel(nn.Module):
    def __init__(self, num_classes, alpha, backbone_type, mean_norm, std_norm):
        super(AdaLossModel, self).__init__()

        self.head = Cosine_similarity(self.backbone.feature_dim, num_classes)
        self.num_classes = num_classes
        self.alpha = alpha
        self.register_buffer('batch_mean', torch.ones(1) * mean_norm)
        self.register_buffer('batch_std', torch.ones(1) * std_norm)

    def forward(self, x, labels=None):
        features = self.backbone(x)
        norms = torch.norm(features, p=2, dim=1)

        # Updating batch statistics for each batch
        with torch.no_grad():
            batch_mean = norms.mean()
            batch_std = norms.std()
            self.batch_mean.data = self.alpha * batch_mean + (1 - self.alpha) * self.batch_mean
            self.batch_std.data = self.alpha * batch_std + (1 - self.alpha) * self.batch_std
        # Clamping for numerical stability
        safe_norms = torch.clamp(norms, min=0.001, max=100)
        # Calculating standardized norms using z-score normalization
        standardized_norms = (safe_norms - self.batch_mean) / (self.batch_std + 1e-5)
        # Calculating the cosine similarity between feature vector and weight vector
        cosine_sim = self.head(features)
        # Only return logits in test mode
        if labels is None:
            return cosine_sim  # Only return logits in test mode

        # Compute AdaLoss
        m = 0.4
        h = 0.333
        margin_scaler = standardized_norms * h
        margin_scaler = torch.clamp(margin_scaler, -1, 1)

        g_angular = -m * margin_scaler
        g_additive = m + (m * margin_scaler)

        one_hot_labels = torch.zeros_like(cosine_sim).scatter_(1, labels.unsqueeze(1), 1)
        angular_term = one_hot_labels * g_angular.unsqueeze(1)
        additive_term = one_hot_labels * g_additive.unsqueeze(1)

        theta = torch.acos(torch.clamp(cosine_sim, -1 + 1e-7, 1 - 1e-7))
        theta_m = torch.clamp(theta + angular_term, min=1e-5, max=math.pi - 1e-5)
        cosine_m = torch.cos(theta_m)

        final_cosine = cosine_m - additive_term

        # Scale logits before computing loss
        loss = F.cross_entropy(final_cosine * 30, labels)
        return loss, cosine_sim  # Return both loss and logits


#### Evaluating loss Function

In [213]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import math

# Example Usage
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 10
alpha = 0.1


# Instantiate the model
model = AdaLossModel(num_classes=num_classes, alpha=alpha, backbone_type='resnet', mean_norm=mean_norm, std_norm=std_norm).to(device)

# Define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
num_epochs = 5

# Training loop
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, labels) in enumerate(train_cifar):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        
        loss, logits = model(inputs, labels)  # Extract both loss and logits
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

        # Calculate training accuracy
        _, predicted = logits.max(1)  # Use logits directly instead of extracting features
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    # Calculate epoch training loss and accuracy
    train_loss /= len(train_cifar)
    train_accuracy = 100. * correct / total
    print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%")

    # Testing loop
    model.eval()
    test_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_cifar:
            inputs, labels = inputs.to(device), labels.to(device)

            loss, logits = model(inputs, labels)  # Ensure model returns both loss and logits
            test_loss += loss.item()

            pred = logits.argmax(dim=1)
            total += labels.size(0)
            correct += pred.eq(labels).sum().item()

    test_loss /= len(test_cifar)
    test_accuracy = 100. * correct / total
    print(f"Epoch {epoch + 1}/{num_epochs}, Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%")



Epoch 1/5, Train Loss: 9.4100, Train Accuracy: 60.78%
Epoch 1/5, Test Loss: 7.1904, Test Accuracy: 74.22%
Epoch 2/5, Train Loss: 6.4302, Train Accuracy: 75.06%
Epoch 2/5, Test Loss: 6.5076, Test Accuracy: 77.13%
Epoch 3/5, Train Loss: 5.5423, Train Accuracy: 78.69%
Epoch 3/5, Test Loss: 6.1183, Test Accuracy: 78.40%
Epoch 4/5, Train Loss: 5.0185, Train Accuracy: 80.75%
Epoch 4/5, Test Loss: 5.9002, Test Accuracy: 79.96%
Epoch 5/5, Train Loss: 4.5895, Train Accuracy: 82.53%
Epoch 5/5, Test Loss: 5.7319, Test Accuracy: 80.37%
