In [None]:
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# Seed every random number generator the notebook relies on (Python, NumPy, PyTorch)
# so DataLoader shuffling and weight initialisation are repeatable across runs.
SEED = 42
for seed_everything in (random.seed, np.random.seed, torch.manual_seed):
    seed_everything(SEED)


In [None]:
# Define the Simple CNN model
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer fully connected head.

    The input width of ``fc1`` is derived by tracing a dummy input through the
    very same convolutional stack that ``forward`` uses (``_features``), so the
    two code paths cannot drift apart.

    Args:
        in_channels: number of input image channels (3 for RGB CIFAR-10).
        num_classes: length of the output logit vector (10 for CIFAR-10).
        input_size: height/width of the square input images (32 for CIFAR-10).
    """

    def __init__(self, in_channels=3, num_classes=10, input_size=32):
        super().__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.input_size = input_size

        # Creation order (conv1, conv2, fc1, output) is kept fixed so that seeded
        # weight initialisation and state_dict keys stay the same.
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)

        # fc1's input width depends on the conv stack and the image size.
        self._initialize_fc1()

    def _features(self, x):
        """Run conv1 -> ReLU -> 2x2 max-pool -> conv2 -> ReLU -> 2x2 max-pool, then flatten to (batch, features)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        return torch.flatten(x, 1)

    def _initialize_fc1(self):
        """Create ``fc1`` (sized to the flattened conv features) and the ``output`` layer."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.in_channels, self.input_size, self.input_size)
            self.flattened_size = self._features(dummy_input).size(1)

        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.output = nn.Linear(128, self.num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, input_size, input_size) image batch to (batch, num_classes) logits."""
        x = F.relu(self.fc1(self._features(x)))
        return self.output(x)

In [None]:
# Define ArcFace Loss
class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace-style) cross-entropy on L2-normalised logits.

    Each row of ``embeddings`` is L2-normalised, and its entries are treated as
    cosines cos(theta_j). The true class's angle is widened by ``margin``
    radians, cos(theta_y) -> cos(theta_y + m). The logits are then multiplied by
    ``scale``, and standard cross-entropy is applied.

    NOTE(review): canonical ArcFace computes cosines between normalised features
    and a learnable, normalised class-weight matrix. This module has no weight
    matrix; it normalises whatever tensor it is given (in this notebook, the
    model's per-class outputs).

    Args:
        margin: additive angular margin m, in radians.
        scale: logit scale s applied before cross-entropy. Normalised logits lie
            in [-1, 1], so the default s = 1.0 (the original behaviour) keeps the
            softmax far from confident. The ArcFace paper uses s = 64.
    """

    def __init__(self, margin=0.5, scale=1.0):
        super().__init__()
        self.margin = margin
        self.scale = scale

    def forward(self, embeddings, labels):
        """Return the mean margin-penalised cross-entropy.

        Args:
            embeddings: (batch, num_classes) tensor; each row is normalised along dim 1.
            labels: (batch,) int64 tensor of target class indices.
        """
        cosine = F.normalize(embeddings, dim=1)
        index = labels.view(-1, 1)
        cos_target = cosine.gather(1, index)

        # Keep acos strictly inside (-1, 1). Its derivative is infinite at +-1,
        # which would yield NaN gradients for a row that normalises to one-hot.
        eps = torch.finfo(cosine.dtype).eps
        theta = torch.acos(cos_target.clamp(-1.0 + eps, 1.0 - eps))
        phi = torch.cos(theta + self.margin)

        # cos(theta + m) only decreases while theta + m <= pi. Beyond that the
        # "penalty" would raise the target logit, so fall back to the linear form
        # cos(theta) - m*sin(m) (the standard ArcFace hard-margin fallback).
        threshold = math.cos(math.pi - self.margin)
        fallback = cos_target - self.margin * math.sin(self.margin)
        phi = torch.where(cos_target > threshold, phi, fallback)

        # Only the target-class logit receives the margin; the others stay cos(theta_j).
        logits = cosine.scatter(1, index, phi)
        return F.cross_entropy(self.scale * logits, labels)

In [None]:
# Load CIFAR-10: ToTensor scales pixels to [0, 1], then Normalize with
# mean=std=0.5 per channel maps them to [-1, 1].
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
)

cifar_kwargs = dict(root='./data', download=True, transform=transform)
train_dataset = datasets.CIFAR10(train=True, **cifar_kwargs)
test_dataset = datasets.CIFAR10(train=False, **cifar_kwargs)

# Only the training stream is shuffled; evaluation order does not affect the metrics.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Assemble the pieces: the CNN, the ArcFace-style criterion (which has no
# learnable parameters of its own), and Adam over the network's weights.
model = SimpleCNN()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)
criterion = ArcFaceLoss()

In [None]:
# Train the model
def train(model, train_loader, criterion, optimizer, epochs=5):
    """Optimise ``model`` on ``train_loader`` and print the mean loss of each epoch.

    Args:
        model: network mapping an image batch to class logits; switched to train mode.
        train_loader: re-iterable source of ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
        optimizer: optimiser over ``model``'s parameters.
        epochs: number of full passes over ``train_loader``.

    Returns:
        List holding the sample-weighted mean training loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no samples in an epoch.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        loss_sum = 0.0
        n_samples = 0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)  # per-class logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Assumes criterion averages over the batch (true for ArcFaceLoss, which
            # uses F.cross_entropy's default mean reduction). Re-weighting by batch
            # size keeps a short final batch from counting as much as a full one.
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

        if n_samples == 0:
            raise ValueError("train_loader yielded no samples")
        avg_train_loss = loss_sum / n_samples
        history.append(avg_train_loss)
        print(f'Epoch [{epoch + 1}/{epochs}], Train Loss: {avg_train_loss:.4f}')
    return history

In [None]:
# Test the model
def test(model, test_loader):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            total_loss += loss.item()

            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    avg_test_loss = total_loss / len(test_loader)
    accuracy = 100 * correct / total
    print(f'Test Loss: {avg_test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

In [None]:
# Train for a fixed number of epochs, then report held-out loss and accuracy.
NUM_EPOCHS = 5

print("Training on CIFAR-10 with ArcFace Loss:")
train(model, train_loader, criterion, optimizer, epochs=NUM_EPOCHS)

print("Testing on CIFAR-10 test set:")
test(model, test_loader)


Files already downloaded and verified
Files already downloaded and verified
Training on CIFAR-10 with ArcFace Loss:
Epoch [1/5], Train Loss: 2.2591
Epoch [2/5], Train Loss: 2.0707
Epoch [3/5], Train Loss: 2.0015
Epoch [4/5], Train Loss: 1.9513
Epoch [5/5], Train Loss: 1.9205
Testing on CIFAR-10 test set:
Test Loss: 2.0099, Test Accuracy: 67.64%
