In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [15]:
from __future__ import print_function
from __future__ import division
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from fastai.vision import *
from fastai.layers import *
import math

In [18]:
class ArcMarginProduct(nn.Module):
    """Additive angular margin (ArcFace) output layer for CelebA and VGGFace2.

    For the target class(es) the logit is ``s * cos(theta + m)``; for every other
    class it is ``s * cos(theta)``, where ``theta`` is the angle between the
    L2-normalised input feature and the L2-normalised class weight vector.

    CelebA: multi-label classification -> ``label`` is a one-hot (multi-hot) tensor
        expected to be broadcast-compatible with the output ``(batch, out_features)``.
    VGGFace2: single-label classification -> ``label`` holds one integer class
        index per sample.

    Args:
        in_features: size of each input sample (output size of the preceding layer).
        out_features: size of each output sample (number of classes).
        label: default target used by ``forward`` when no label is passed to it.
        dataset: ``'celeba'`` or ``'vggface2'`` (case-insensitive); selects how
            ``label`` is interpreted.
        s: scale applied to the logits (norm of the input feature, see the paper).
        m: additive angular margin in radians (see the paper).
        easy_margin: if True, the margin is only applied where ``cos(theta) > 0``.

    Raises:
        ValueError: from ``forward`` when ``dataset`` is neither CelebA nor VGGFace2.
    """

    def __init__(self, in_features, out_features, label, dataset='celeba', s=30.0, m=0.50, easy_margin=False):
        super(ArcMarginProduct, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.label = label
        self.dataset = dataset
        self.s = s
        self.m = m
        # Uninitialised storage, immediately filled by Xavier initialisation.
        self.weight = Parameter(torch.empty(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # cos(theta) > th  <=>  theta + m < pi. Past that point cos(theta + m) stops
        # decreasing in theta, so the linear fallback ``cosine - mm`` is used instead.
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, input, label=None):
        """Return the scaled ArcFace logits of shape ``(batch, out_features)``.

        Args:
            input: feature batch; normalised along dim 1 before use.
            label: optional target overriding the one given at construction.

        Raises:
            ValueError: if ``self.dataset`` is neither CelebA nor VGGFace2.
        """
        if label is None:
            label = self.label

        # --------------------------- cos(theta) & phi(theta) ---------------------------
        cosine = F.linear(F.normalize(input), F.normalize(self.weight))
        # Rounding can push cosine**2 slightly above 1; clamp so sqrt never yields NaN.
        sine = torch.sqrt((1.0 - torch.pow(cosine, 2)).clamp(0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m  # cos(theta + m)
        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where(cosine > self.th, phi, cosine - self.mm)

        # --------------------------- convert label to one-hot ---------------------------
        dataset = str(self.dataset).lower()
        if dataset == 'celeba':
            # Already one-hot; match the logits' dtype/device so the blend below is exact.
            one_hot = label.to(device=cosine.device, dtype=cosine.dtype)
        elif dataset == 'vggface2':
            # Allocate on the same device as the logits instead of assuming CUDA.
            one_hot = torch.zeros_like(cosine)
            one_hot.scatter_(1, label.view(-1, 1).long().to(cosine.device), 1)
        else:
            raise ValueError('Select the dataset - CelebA or VGGFace2')

        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

In [17]:
class Customhead():
    """Builder for an ArcFace classification head placed on top of a CNN backbone.

    ``head()`` returns, in order:
    AdaptiveConcatPool2d -> Flatten -> BatchNorm1d -> Dropout -> Linear(in_features, 512)
    -> ReLU -> BatchNorm1d -> Dropout -> Linear(512, 512) -> ArcMarginProduct(512, num_classes).

    Args:
        num_classes: Number of classes (CelebA: 40, VGGFace2: 9131).
        label: label of the dataset, forwarded to ``ArcMarginProduct``.
        p_dropout: dropout probability used by both Dropout layers.
        eps: ``eps`` of both BatchNorm1d layers.
        momentum: ``momentum`` of both BatchNorm1d layers.
        affine: ``affine`` of both BatchNorm1d layers.
        track_running_stats: ``track_running_stats`` of both BatchNorm1d layers.
        in_features: width of the pooled + flattened backbone features. The default
            1024 fits ResNet34. AdaptiveConcatPool2d concatenates average and max
            pooling, so this is presumably 2x the backbone's channels (e.g. 4096 for
            ResNet50) -- verify for your backbone.
        dataset: ``'celeba'`` or ``'vggface2'``, forwarded to ``ArcMarginProduct``.
    """
    def __init__(self, num_classes, label, p_dropout=0.5, eps=1e-05, momentum=0.1, affine=True,
                 track_running_stats=True, in_features=1024, dataset='celeba'):
        super(Customhead, self).__init__()
        self.num_classes = num_classes
        self.label = label
        self.p_dropout = p_dropout
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        self.in_features = in_features
        self.dataset = dataset

    def _batchnorm(self, num_features):
        """BatchNorm1d configured with this head's normalisation settings."""
        return nn.BatchNorm1d(num_features, eps=self.eps, momentum=self.momentum,
                              affine=self.affine, track_running_stats=self.track_running_stats)

    def head(self):
        """Build and return the head as an ``nn.Sequential``."""
        custom_head = nn.Sequential(
            AdaptiveConcatPool2d(1),
            Flatten(),
            self._batchnorm(self.in_features),
            nn.Dropout(p=self.p_dropout),
            nn.Linear(in_features=self.in_features, out_features=512, bias=True),
            nn.ReLU(inplace=True),
            self._batchnorm(512),
            nn.Dropout(p=self.p_dropout),
            nn.Linear(in_features=512, out_features=512, bias=True),
            ArcMarginProduct(in_features=512, out_features=self.num_classes,
                             label=self.label, dataset=self.dataset))
        return custom_head

Defining the model: uncomment the lines in the last cell below once `data` and `y` are defined.

In [19]:
# Please change num_classes and label to suit the dataset:
#   CelebA:
#       num_classes: number of classes = 40
#       label: dataset label. If using ImageDataBunch, the label is one-hot encoded
#              (see fast.ai lesson 3 for more information - the satellite challenge).
#   VGGFace2:
#       num_classes: number of classes = 9131
#       label: dataset label (one integer class index per sample).
NUM_CLASSES = {'celeba': 40, 'vggface2': 9131}

In [None]:
# arc_face_head = Customhead(num_classes = 40, label = y)
# learn = cnn_learner(data,models.resnet34,custom_head=arc_face_head.head(),metrics=[fbeta])
# learn.loss_func = F.cross_entropy  # fastai v1 Learner reads `loss_func`; `crit` was the fastai 0.7 name