In [81]:
import torch
import torch.nn.functional as F

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import transforms as T, datasets

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px

from tqdm.notebook import tqdm
from sklearn.manifold import TSNE

# Pick the best available accelerator: CUDA first, then Apple-silicon MPS, then CPU.
# The MPS check is guarded because `torch.backends.mps` does not exist on older torch builds.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

cpu


In [82]:
# MNIST train/test splits, converted to tensors with T.ToTensor and cached under DATA_ROOT.
DATA_ROOT = './sample_data'
BATCH_SIZE = 64

train_dataset = datasets.MNIST(root=DATA_ROOT, train=True, transform=T.ToTensor(), download=True)
test_dataset = datasets.MNIST(root=DATA_ROOT, train=False, transform=T.ToTensor(), download=True)

# Shuffle only the training data; a fixed test order keeps evaluation metrics
# and any embedding visualisations reproducible from run to run.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [87]:
import math

class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace) loss.

    Holds a learnable matrix of class centres. Each sample's L2-normalised
    embedding is compared (cosine similarity) with every L2-normalised centre.
    The angular margin ``m`` is added ONLY to the angle between a sample and
    its own class centre: the target logit becomes ``cos(theta_y + m)`` while
    every other logit stays ``cos(theta_j)``. All logits are multiplied by
    ``scale`` and passed to cross-entropy.

    Args:
        num_classes: number of classes (rows of the centre matrix).
        embedding_dim: length of each embedding vector.
        scale: logit scaling factor ``s``.
        margin: additive angular margin ``m`` in radians.
    """

    def __init__(self, num_classes, embedding_dim, scale=30.0, margin=0.50):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_dim = embedding_dim
        self.scale = scale  # Scaling factor s
        self.margin = margin  # Angular margin m (radians)

        # Learnable class centres, one row per class (initialised by xavier below).
        self.weights = nn.Parameter(torch.empty(num_classes, embedding_dim))
        nn.init.xavier_uniform_(self.weights)

    def forward(self, embeddings, labels):
        """Return the mean ArcFace cross-entropy loss.

        Args:
            embeddings: float tensor of shape (batch_size, embedding_dim).
            labels: int64 tensor of shape (batch_size,) holding class indices
                in [0, num_classes).

        Returns:
            Scalar loss tensor.
        """
        embeddings = F.normalize(embeddings, p=2, dim=1)
        weights = F.normalize(self.weights, p=2, dim=1)

        # Cosine similarity between every embedding and every class centre.
        cosine = embeddings @ weights.T  # Shape: (batch_size, num_classes)

        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        # Computing sin(theta) from cos(theta) avoids acos, whose gradient is
        # infinite at +-1; the clamp keeps sqrt's gradient finite there too.
        cos_m = math.cos(self.margin)
        sin_m = math.sin(self.margin)
        sine = (1.0 - cosine.pow(2)).clamp(min=1e-7).sqrt()
        phi = cosine * cos_m - sine * sin_m

        # Once theta + m would exceed pi, cos(theta + m) starts increasing
        # again with theta, which would reward larger angles. Beyond that
        # point (cos(theta) <= cos(pi - m)), use a linear penalty instead.
        th = math.cos(math.pi - self.margin)
        mm = math.sin(math.pi - self.margin) * self.margin
        phi = torch.where(cosine > th, phi, cosine - mm)

        # Apply the margin to the target class only; other logits keep cos(theta).
        target_mask = F.one_hot(labels, num_classes=self.num_classes).bool()
        logits = torch.where(target_mask, phi, cosine)

        # Scale logits and compute cross-entropy.
        return F.cross_entropy(self.scale * logits, labels)