In [None]:
import torch
import torchvision
from torch.nn import ReLU, Conv2d, BatchNorm2d, Sequential, AdaptiveAvgPool2d, Linear, MaxPool2d, Flatten, CrossEntropyLoss
from typing import Tuple, List, Type, Dict, Any
from torch.autograd import Variable

In [None]:
class SphericalClassifier(torch.nn.Linear):
    """
    Last layer of the network: solves the metric-learning task via classification.

    Rows of ``self.weight`` act as class centres. Both the embeddings and the
    centres are L2-normalised, so the raw scores are cosine similarities.
    For the target class the combined margin

        cos(m1 * theta + m2) - m3,    theta = arccos(cosine to target centre)

    is applied (m1: multiplicative angular, m2: additive angular, m3: additive
    cosine margin). All scores are then multiplied by ``scale``.

    Args:
        in_features: dimension of the incoming embeddings.
        out_features: number of classes.
        scale: multiplier applied to all (margin-adjusted) cosines.
        margins: tuple (m1, m2, m3); the default (1.0, 0, 0) means "no margin".

    forward(data, target) takes embeddings and integer class targets and
    returns logits. Note: no "easy margin"/threshold handling is done, so with
    m2 > 0 the target logit is not monotonic in theta once m1*theta + m2 > pi.
    """
    # Keeps arccos away from +/-1: just outside that range (float rounding of
    # normalised dot products) arccos is NaN, and at +/-1 its gradient is infinite.
    _ACOS_EPS = 1e-7

    def __init__(self, in_features, out_features, scale = 64, margins=(1.0, 0, 0)):
        super(SphericalClassifier, self).__init__(in_features, out_features, bias=False)
        self._scale = scale
        self._margins = margins

    def __logits_calculator(self, cosines, target):
        """Apply the (m1, m2, m3) margins to the target-class cosines and scale."""
        m1, m2, m3 = self._margins
        # reshape (not view) so that non-contiguous target tensors also work
        index = target.reshape(-1, 1).long()
        one_hot = torch.zeros_like(cosines).scatter_(1, index, 1)

        target_cosines = cosines.gather(1, index)
        # The arccos/cos round-trip is only needed for angular margins; skipping
        # it otherwise avoids needless trig and the NaN risk entirely.
        if m1 != 1 or m2 != 0:
            eps = self._ACOS_EPS
            theta = torch.acos(target_cosines.clamp(-1.0 + eps, 1.0 - eps))
            target_cosines = torch.cos(m2 + m1 * theta)

        adjusted = cosines.scatter(dim=1, index=index, src=target_cosines) - one_hot * m3
        return self._scale * adjusted

    def calculate_cosines(self, data):
        """Cosine similarity between each row of ``data`` and each class centre."""
        return torch.nn.functional.linear(torch.nn.functional.normalize(data),
                                          torch.nn.functional.normalize(self.weight))

    def forward(self, data, target):
        """Return margin-adjusted, scaled logits for embeddings ``data`` and labels ``target``."""
        cosines = self.calculate_cosines(data)
        return self.__logits_calculator(cosines, target)

In [None]:
class Embedder(torch.nn.Module):
    """
    Main part of the network: maps images to embeddings, based on torchvision ResNets.

    Args:
    embedding_dim -- size of the embedding space.

    cut_fc -- if False, the ResNet's fully-connected head is replaced with
    Linear(in_channels, embedding_dim), applied after the convolutions;
    if True, there is no fully-connected layer: the very last convolutional
    block of layer4 (BasicBlock or Bottleneck) is replaced with a plain 3x3
    Conv2d with out_channels=embedding_dim, followed by the ResNet's average
    pooling and a flatten.

    reference_resnet -- which base model to modify: '18', '34' or '50'.

    after_relu -- if True, applies a ReLU activation at the very end.

    pretrained -- if True (default), load torchvision's DEFAULT weights for the
    backbone; if False, start from random initialisation (no download needed).
    """
    # key -> (constructor, pretrained weights, index of the layer4 block that is
    #         replaced when cut_fc=True, number of feature channels feeding the head)
    __references = {'18' : (torchvision.models.resnet18, torchvision.models.ResNet18_Weights.DEFAULT, 1, 512),
                    '34' : (torchvision.models.resnet34, torchvision.models.ResNet34_Weights.DEFAULT, 2, 512),
                    '50' : (torchvision.models.resnet50, torchvision.models.ResNet50_Weights.DEFAULT, 2, 2048)}

    def __init__(self, embedding_dim = 128, cut_fc = False, reference_resnet = '18', after_relu=False,
                 pretrained=True):
        super(Embedder, self).__init__()
        if reference_resnet not in Embedder.__references:
            raise KeyError(f"Unknown reference_resnet {reference_resnet!r}; "
                           f"expected one of {sorted(Embedder.__references)}")
        self.cut_fc = cut_fc
        self.__embedding_dim = embedding_dim
        self.__reference_resnet = reference_resnet
        self.__pretrained = pretrained
        self.__after_relu = ReLU() if after_relu else None
        self.__build_resnet()

    def __build_resnet(self):
        """Instantiate the reference ResNet and replace its head according to ``cut_fc``."""
        ref_class, weights, layer, in_channels = Embedder.__references[self.__reference_resnet]

        self.__inner = ref_class(weights=weights if self.__pretrained else None)
        if not self.cut_fc:
            self.__inner.fc = Linear(in_channels, self.__embedding_dim)
        else:
            self.__inner.layer4[layer] = Conv2d(in_channels, self.__embedding_dim, kernel_size=3, padding=1, stride=1, bias=False)
            self.__flatten = torch.nn.Flatten()
            # The fc head is unused in this mode; forward() runs the stages manually.
            del(self.__inner.fc)

    @property
    def embedding_dim(self):
        """Size of the produced embeddings."""
        return self.__embedding_dim

    def __conv_features(self, x):
        """Run the ResNet stages up to average pooling and flatten (cut_fc mode)."""
        inner = self.__inner
        x = inner.relu(inner.bn1(inner.conv1(x)))
        x = inner.maxpool(x)
        x = inner.layer1(x)
        x = inner.layer2(x)
        x = inner.layer3(x)
        x = inner.layer4(x)
        x = inner.avgpool(x)
        return self.__flatten(x)

    def forward(self, x):
        """Map a batch of images ``x`` to embeddings of size ``embedding_dim``."""
        if self.cut_fc:
            embeddings = self.__conv_features(x)
        else:
            embeddings = self.__inner(x)
        if self.__after_relu is not None:
            embeddings = self.__after_relu(embeddings)
        return embeddings

In [None]:
class UnitedModel(torch.nn.Module):
    """
    End-to-end model: an ``Embedder`` backbone followed by a
    ``SphericalClassifier`` head.

    Args:
        out_features_classifier: number of classes of the classifier head.
        scale, margins: forwarded to ``SphericalClassifier``.
        embedding_dim, cut_fc, reference_resnet, after_relu: forwarded to
            ``Embedder``; ``embedding_dim`` is also the classifier's input size.
    """
    def __init__(self,
                 out_features_classifier,
                 scale=64,
                 margins=(1.0, 0, 0),
                 embedding_dim=128,
                 cut_fc=False,
                 reference_resnet='18',
                 after_relu=False):
        super(UnitedModel, self).__init__()
        # Backbone first, then head (keeps submodule registration order).
        self.embedder = Embedder(embedding_dim=embedding_dim,
                                 cut_fc=cut_fc,
                                 reference_resnet=reference_resnet,
                                 after_relu=after_relu)
        self.classifier = SphericalClassifier(in_features=embedding_dim,
                                              out_features=out_features_classifier,
                                              scale=scale,
                                              margins=margins)

    def forward(self, imgs, labels):
        """Embed ``imgs`` and return classifier logits; ``labels`` select where the margin is applied."""
        return self.classifier(self.embedder(imgs), labels)