In [41]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class ArcFaceLoss(nn.Module):
    """Additive angular margin (ArcFace) classification loss.

    Holds one learnable weight vector per class.  Features and class
    weights are L2-normalised, so their dot product is the cosine of the
    angle ``theta`` between them.  For the ground-truth class the logit is
    ``s * cos(theta + m)``; for every other class it is ``s * cos(theta)``.
    The resulting logits are fed to softmax cross-entropy.

    Args:
        dimention: Size of the input feature vectors (spelling kept for
            backward compatibility with existing keyword callers).
        num_classes: Number of target classes.
        s: Scale applied to all logits before cross-entropy.
        m: Angular margin, in radians, added to the target-class angle.
    """

    # Floor for 1 - cos^2 before the square root.  At cos = +/-1 the
    # derivative of sqrt is infinite; clamping to a tiny positive value
    # zeroes that gradient path instead of producing inf/NaN.
    _SIN_SQ_EPS = 1e-12

    def __init__(self, dimention, num_classes, s=64.0, m=0.50):
        super().__init__()
        self.s = s  # Feature scale (often set to 64.0)
        self.m = m  # Angular margin (often set to 0.50)

        # One learnable weight vector per class, Xavier-initialised.
        self.weight = nn.Parameter(torch.empty(num_classes, dimention))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, input, labels):
        """Return the mean ArcFace loss for a batch.

        Args:
            input: Feature matrix with one row per sample and
                ``dimention`` columns.
            labels: Integer class index for each row of ``input``.

        Returns:
            Scalar tensor: cross-entropy over the margin-adjusted,
            scaled logits.
        """
        # Cosine similarity between every sample and every class centre.
        # The clamp guards against values drifting just outside [-1, 1]
        # through floating-point rounding.
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        cosine = cosine.clamp(-1.0, 1.0)

        target_index = labels.view(-1, 1)
        target_cosine = cosine.gather(1, target_index)

        # cos(theta + m) = cos(theta) cos(m) - sin(theta) sin(m).
        # theta lies in [0, pi], so sin(theta) >= 0 and the positive root
        # is the right one.
        sine = torch.sqrt(
            (1.0 - target_cosine.pow(2)).clamp(min=self._SIN_SQ_EPS)
        )
        phi = target_cosine * math.cos(self.m) - sine * math.sin(self.m)

        # When theta + m exceeds pi, cos(theta + m) starts increasing
        # again and would reward a worse angle.  In that region fall back
        # to a linear penalty that keeps the target logit monotonically
        # decreasing in theta.
        threshold = math.cos(math.pi - self.m)
        linear_penalty = math.sin(math.pi - self.m) * self.m
        phi = torch.where(
            target_cosine > threshold, phi, target_cosine - linear_penalty
        )

        # Replace only the target-class column with the margin-adjusted
        # value (out-of-place so autograd history stays intact), then
        # apply the feature scale.
        logits = cosine.scatter(1, target_index, phi) * self.s

        return F.cross_entropy(logits, labels)


tensor([1.9933])