In [1]:
import torch
import torchvision
from torch.nn import ReLU, Conv2d, BatchNorm2d, Sequential, AdaptiveAvgPool2d, Linear, MaxPool2d, Flatten, CrossEntropyLoss
from typing import Tuple, List, Type, Dict, Any
from torch.autograd import Variable

In [4]:
class SphereEmbedding(torch.nn.Linear):
    """Bias-free linear layer whose outputs are cosine similarities.

    Both the input vectors and the rows of ``weight`` are rescaled to unit
    L2 norm before the matrix product, so every output entry is the cosine of
    the angle between one input vector and one weight row.

    Args:
        in_features: size of each input vector (last dimension of the input).
        out_features: number of weight rows (last dimension of the output).
    """

    def __init__(self, in_features, out_features):
        super(SphereEmbedding, self).__init__(in_features, out_features, bias = False)

    def forward(self, input):
        """Return the cosine similarity of each input vector to each weight row.

        Args:
            input: tensor whose last dimension has ``in_features`` entries;
                any number of leading dimensions (including none) is accepted.

        Returns:
            Tensor with the same leading dimensions as ``input`` and a last
            dimension of size ``out_features``.
        """
        functional = torch.nn.functional
        # Normalise along the LAST axis: that is the axis `linear` contracts.
        # The default dim=1 only coincides with it for 2-D input; it raises for
        # 1-D input and picks the wrong axis for inputs with more dimensions.
        unit_input = functional.normalize(input, dim=-1)
        unit_weight = functional.normalize(self.weight, dim=-1)
        return functional.linear(unit_input, unit_weight)

In [18]:
class Combined_loss(torch.nn.Module):
    """Weighted MSE between two embeddings plus a CosFace loss on each prediction.

    Args:
        MSE_weight: multiplier applied to the mean squared difference between
            the anchor and positive embeddings.
        scale: factor the margin-adjusted cosines are multiplied by before the
            cross-entropy.
        margin: value subtracted from the cosine of the target class.
    """

    def __init__(self, MSE_weight, scale=64, margin=0.05):
        super(Combined_loss, self).__init__()
        self.MSE_weight = MSE_weight
        self.scale = scale
        self.margin = margin

        self.CE = CrossEntropyLoss()

    def _MSE(self, anchor_embedding, positive_embedding):
        """Return ``MSE_weight`` times the mean squared element-wise difference."""
        return self.MSE_weight*torch.mean((anchor_embedding - positive_embedding).pow(2))

    def _CosFaceLoss(self, cosines, target):
        """Cross-entropy over scaled cosines with the margin taken off the target class.

        Args:
            cosines: 2-D tensor, one row per sample and one column per class.
            target: 1-D tensor with one class index per row of ``cosines``.

        Returns:
            Scalar loss tensor.
        """
        labels = target.long()
        one_hot = torch.zeros_like(cosines)
        # The index must be a column (batch, 1) so that row i marks column
        # target[i]. A (1, batch) index would write every target into row 0
        # and leave all the other rows without a margin.
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        logits = self.scale * (cosines - self.margin * one_hot)
        return self.CE(logits, labels)

    def forward(self, anchor_pred, anchor_embedding, positive_pred, positive_embedding,
                target):
        """Sum the embedding MSE term and the CosFace terms of both predictions.

        Args:
            anchor_pred: per-class cosines for the anchor samples.
            anchor_embedding: embeddings of the anchor samples.
            positive_pred: per-class cosines for the positive samples.
            positive_embedding: embeddings of the positive samples.
            target: class indices, used for both predictions.

        Returns:
            Scalar loss tensor.
        """
        mse_term = self._MSE(anchor_embedding, positive_embedding)
        anchor_term = self._CosFaceLoss(anchor_pred, target)
        positive_term = self._CosFaceLoss(positive_pred, target)
        return mse_term + anchor_term + positive_term

In [None]:
class CosFace(torch.nn.Module):
    """Cross-entropy over scaled cosines with a margin subtracted from the target class.

    Args:
        scale: factor the margin-adjusted cosines are multiplied by before the
            cross-entropy.
        margin: value subtracted from the cosine of the target class.
    """

    def __init__(self, scale=64, margin=0.15):
        super(CosFace, self).__init__()
        self.scale = scale
        self.margin = margin
        self.CE = CrossEntropyLoss()

    def forward(self, cosines, target):
        """Compute the loss for one batch.

        Args:
            cosines: 2-D tensor, one row per sample and one column per class.
            target: 1-D tensor with one class index per row of ``cosines``.

        Returns:
            Scalar loss tensor.
        """
        labels = target.long()
        one_hot = torch.zeros_like(cosines)
        # The index must be a column (batch, 1) so that row i marks column
        # target[i]. A (1, batch) index would write every target into row 0
        # and leave all the other rows without a margin.
        one_hot.scatter_(1, labels.view(-1, 1), 1)
        logits = self.scale * (cosines - self.margin * one_hot)
        return self.CE(logits, labels)

In [17]:
class Embedder(torch.nn.Module):
    """Truncated pretrained ResNet-152 followed by a linear embedding layer.

    Only the stem and ``layer1``..``layer3`` of the ResNet are kept, followed
    by its average pooling and a new ``Linear(1024, embedding_size)`` head.

    Args:
        embedding_size: number of features in the produced embedding.
    """

    def __init__(self, embedding_size = 128):
        super().__init__()
        self.embedding_size = embedding_size
        self.make_resnet()

    def make_resnet(self):
        """Load the pretrained ResNet-152, freeze it and attach the pieces used here.

        All backbone parameters are frozen, then ``conv3`` and ``bn3`` of the
        last ``layer3`` block are made trainable again. The new ``fc`` layer is
        created with its default (trainable) parameters.
        """
        backbone = torchvision.models.resnet152(weights = torchvision.models.ResNet152_Weights.DEFAULT)
        # Module.requires_grad_ applies the flag to every parameter of the module.
        backbone.requires_grad_(False)

        # Re-register the reused sub-modules under their original names,
        # in the order they are applied in forward().
        for part in ("conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3",
                     "avgpool"):
            setattr(self, part, getattr(backbone, part))

        self.fc = Linear(1024, self.embedding_size)

        # layer3 of ResNet-152 holds 36 blocks, so index 35 is the last one.
        final_block = self.layer3[35]
        for trainable in (final_block.conv3, final_block.bn3):
            trainable.requires_grad_(True)

    def forward(self, x):
        """Run the truncated backbone, flatten the pooled features and project them.

        Args:
            x: input batch passed to the ResNet stem.

        Returns:
            Output of ``fc`` applied to the flattened pooled features.
        """
        pipeline = (
            self.conv1, self.bn1, self.relu, self.maxpool,
            self.layer1, self.layer2, self.layer3,
            self.avgpool,
        )
        for stage in pipeline:
            x = stage(x)

        # Keep the batch dimension, collapse everything else.
        return self.fc(torch.flatten(x, 1))

In [19]:
class Classificator(torch.nn.Module):
    """Embedder followed by a SphereEmbedding classification head.

    Args:
        num_classes: number of outputs of the classification head.
        embedding_size: size of the embedding produced by the Embedder and
            consumed by the head.
        return_embedding: when true, ``forward`` returns the embedding along
            with the prediction.
    """

    def __init__(self, num_classes,
                 embedding_size = 128,
                 return_embedding = True):
        super().__init__()
        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.return_embedding = return_embedding

        # Sub-modules are created in the same order as before: embedder, then head.
        self.embedder = Embedder(embedding_size = embedding_size)
        self.classifier = SphereEmbedding(embedding_size, num_classes)

    def forward(self, batch):
        """Embed the batch and classify the embeddings.

        Args:
            batch: input passed to the embedder.

        Returns:
            ``(prediction, embedding)`` when ``return_embedding`` is true,
            otherwise only ``prediction``.
        """
        features = self.embedder(batch)
        scores = self.classifier(features)

        if not self.return_embedding:
            return scores
        return scores, features