In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveMarginSoftmaxLoss(nn.Module):
    # TODO: implement 'reduction'
    def __init__(self, in_features, out_features, s=30.0, m=0.4):
        """ Additive Margin softmax loss """
        # Notes: last layer has no bias (not shifting from origin point)
        super(AdditiveMarginSoftmaxLoss, self).__init__()
        self.s = s
        self.m = m
        self.in_features = in_features
        self.out_features = out_features
        self.fc = nn.Linear(in_features, out_features, bias=False)
        
    def forward(self, x, labels):
        """ input shape (N, in_features) """
        
        assert len(x) == len(labels)
        assert torch.min(labels) >= 0
        assert torch.max(labels) < self.out_features       
       
        # for-loop unpacks generator only.
        #
        # FIXME: W is not normalized actually.
        #
        # See discussion: 
        # https://discuss.pytorch.org/t/how-to-do-weight-normalization-in-last-classification-layer/35193/4
        W = self.fc.weight / torch.norm(self.fc.weight, dim=1, keepdim=True)
        
        print(f'{self.fc.weight=}')
        print(f'{W=}')

        x = F.normalize(x, dim=1)
        print(f'{x=}')

        wf = torch.mm(W, x.transpose(0, 1)).transpose(0, 1)
        print(f'{wf=}')
        print(f'{wf.transpose(0, 1)[labels]=}')
        print(f'{torch.diagonal(wf.transpose(0, 1)[labels])=}')

        numerator = self.s * (torch.diagonal(wf.transpose(0, 1)[labels]) - self.m)
        print(f'{numerator=}')
        
        excl = torch.cat([torch.cat((wf[i, :y], wf[i, y+1:])).unsqueeze(0) for i, y in enumerate(labels)], dim=0)
        print(f'{excl=}')
        
        denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1)
        print(f'{torch.sum(torch.exp(self.s * excl), dim=1)=}')
        print(f'{denominator=}')
        
        L = numerator - torch.log(denominator)
        return -torch.mean(L)

In [3]:
torch.manual_seed(0)

# 3-dim feature maps to one-hot vector with 3-dim
loss = AdditiveMarginSoftmaxLoss(3, 3, m=0.4)

x = loss.fc.weight[0].detach().unsqueeze(0)
# x = torch.randn(1, 3)
y = torch.LongTensor([0])

print(f'{loss(x, y).item()=}')

self.fc.weight=Parameter containing:
tensor([[-0.0043,  0.3097, -0.4752],
        [-0.4249, -0.2224,  0.1548],
        [-0.0114,  0.4578, -0.0512]], requires_grad=True)
W=tensor([[-0.0076,  0.5460, -0.8377],
        [-0.8431, -0.4413,  0.3072],
        [-0.0248,  0.9935, -0.1112]], grad_fn=<DivBackward0>)
x=tensor([[-0.0076,  0.5460, -0.8377]])
wf=tensor([[ 1.0000, -0.4919,  0.6358]], grad_fn=<TransposeBackward0>)
wf.transpose(0, 1)[labels]=tensor([[1.0000]], grad_fn=<IndexBackward0>)
torch.diagonal(wf.transpose(0, 1)[labels])=tensor([1.0000], grad_fn=<DiagonalBackward0>)
numerator=tensor([18.0000], grad_fn=<MulBackward0>)
excl=tensor([[-0.4919,  0.6358]], grad_fn=<CatBackward0>)
torch.sum(torch.exp(self.s * excl), dim=1)=tensor([1.9223e+08], grad_fn=<SumBackward1>)
denominator=tensor([2.5789e+08], grad_fn=<AddBackward0>)
loss(x, y).item()=1.3680419921875
