In [None]:
#!/usr/bin/env python
# encoding: utf-8
'''
@author: wujiyang
@contact: wujiyang@hust.edu.cn
@file: ArcMarginProduct.py
@time: 2018/12/25 9:13
@desc: additive angular margin for arcface/insightface
'''

import math
import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Computes cosine similarities between L2-normalised input features and
    L2-normalised per-class weight vectors. It adds the angular margin ``m``
    to the angle of each sample's ground-truth class, then multiplies every
    logit by the scale ``s``. The returned logits are meant to be fed to a
    cross-entropy loss.

    Args:
        in_feature: dimensionality of the input embedding.
        out_feature: number of classes.
        s: scale applied to all output logits.
        m: additive angular margin, in radians.
        easy_margin: if True, the margin is applied only where cos(theta) > 0;
            otherwise the fallback ``cos(theta) - mm`` is used once
            theta + m would exceed pi.
    """

    def __init__(self, in_feature=16, out_feature=41, s=32.0, m=1.35, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_feature = in_feature
        self.out_feature = out_feature
        # s is the logit (feature) scale
        self.s = s
        # m is the additive angular margin, in radians
        self.m = m
        # Weight matrix of shape (out_feature, in_feature): one row per class,
        # i.e. one class-centre vector in embedding space per class.
        self.weight = Parameter(torch.empty(out_feature, in_feature))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # For theta in [0, pi]: cos(theta) > th  <=>  theta + m < pi.
        # Past that point cos(theta + m) would increase again, so forward()
        # switches to cos(theta) - mm to keep the target logit decreasing in theta.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        """Return scaled ArcFace logits.

        Args:
            x: input features; normalised along dim 1 and multiplied by the
                weight via ``F.linear``, so presumably shape
                (batch, in_feature).
            label: ground-truth class index for each row of ``x`` (one per row).

        Returns:
            Tensor of shape (batch, out_feature). The ground-truth column holds
            ``s * cos(theta + m)`` (or its fallback); all other columns hold
            ``s * cos(theta)``.
        """
        # cos(theta) between every normalised sample and every normalised class weight.
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))

        # sin(theta). Clamp because rounding can push |cos| slightly above 1,
        # which would make sqrt return NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0.0, 1.0))
        # cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m)
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # One-hot mask of the ground-truth class; scatter_ needs an int64 index.
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)

        # Margin-adjusted logit for the target class, plain cosine elsewhere,
        # then scale *all* logits by s so target and non-target logits share a range.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s



In [None]:
#criterion_classi = torch.nn.CrossEntropyLoss().to(device)
#margin = ArcMarginProduct(args.feature_dim, trainset.class_nums, s=args.scale_size)
#feature = net(img)
#output = margin(feature, label)
#loss_classi = criterion_classi(output, label)