In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch import Tensor
from typing import *

In [15]:
class NegativeCosineSimilarity(nn.Module):
    def __init__(self,
                 mode: str = 'simplified'
                ) -> None:
        super(NegativeCosineSimilarity,self).__init__()
        
        self.mode = mode
        assert self.mode in ['simplified', 'original'], \
        'loss mode must be either (simplified) or (original)'
        
        
    def _forward1(self,
                  p: Tensor,
                  z: Tensor,
                 ) -> Tensor:
        z = z.detach()
        p = F.normalize(p, dim=1)
        z = F.normalize(z, dim=1)
        loss = -(p*z).sum(dim=1).mean()
        return loss
        
    def _forward2(self,
                  p: Tensor,
                  z: Tensor,
                 ) -> Tensor:
        z = z.detach
        loss = - F.cosine_similarity(p, z, dim=-1).mean()
        return loss
        
    def forward(self,
                  p1: Tensor,
                  p2: Tensor,
                  z1: Tensor,
                  z2: Tensor,
                 ) -> Tensor:
        
        if self.mode == 'original':
            loss1 = self._forward1(p1,z2)
            loss2 = self._forward1(p2,z1)
            loss = loss1/2 +loss2/2
            return loss
        
        elif self.mode == 'simplified':
            loss1 = self._forward1(p1,z2)
            loss2 = self._forward1(p2,z1)
            loss = loss1/2 +loss2/2
            return loss

In [16]:
NegativeCosineSimilarity()
z1 = torch.randn((1, 10))
z2 = torch.rand_like(z1)
p1 = torch.randn((1, 10))
p2 = torch.rand_like(p1)
criterion = NegativeCosineSimilarity()
criterion(p1,p2,z1,z2)

tensor(0.2974)

In [17]:
criterion = NegativeCosineSimilarity('original')
criterion.forward(p1,p2,z1,z2)

tensor(0.2974)

In [18]:
class ProjectionMLP(nn.Module):
    def __init__(self,
                 input_dim: int,
                 hidden_dim: int = 2048,
                 output_dim: int = 2048,
                ) -> None:
        super(ProjectionMLP,self).__init__()
        
        

        self.layer1 = nn.Sequential(nn.Linear(in_features=input_dim, out_features= hidden_dim, bias=False ),
                                    nn.BatchNorm1d(hidden_dim),
                                    nn.ReLU(inplace=True)
                                   )

        self.layer2 = nn.Sequential(nn.Linear(in_features=hidden_dim, out_features=hidden_dim, bias=False),
                                    nn.BatchNorm1d(hidden_dim),
                                    nn.ReLU(inplace=True)
                                   )
        
        self.layer3 = nn.Sequential(nn.Linear(in_features=hidden_dim, out_features=output_dim, bias=False),
                                    nn.BatchNorm1d(hidden_dim)
                                   )
    
    def forward(self, x: Tensor) -> Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        
        return x

In [19]:
x = torch.randn(2,2048)
model = ProjectionMLP(2048)
model(x)

tensor([[-0.9948, -0.9908, -0.9999,  ..., -0.9918, -0.9999, -1.0000],
        [ 0.9948,  0.9908,  0.9999,  ...,  0.9918,  0.9999,  1.0000]],
       grad_fn=<NativeBatchNormBackward>)

In [20]:
class PredictionMLP(nn.Module):
    def __init__(self,
                 input_dim: int = 2048,
                 hidden_dim: int = 512,
                 output_dim: int = 2048,
                ) -> None:
        super(PredictionMLP,self).__init__()
        
        self.layer1 = nn.Sequential(nn.Linear(in_features=input_dim, out_features=hidden_dim, bias= False),
                                    nn.BatchNorm1d(hidden_dim),
                                    nn.ReLU(inplace=True)
                                   )
        
        self.layer2 = nn.Sequential(nn.Linear(in_features=hidden_dim, out_features=output_dim))
        
    def forward(self, x: Tensor) -> Tensor:
        x = self.layer1(x)
        x = self.layer2(x)
        
        return x
        

In [21]:
model  = PredictionMLP()
model(x)

tensor([[-0.3141, -0.2764,  0.7324,  ...,  0.1383, -0.5016, -0.1909],
        [ 0.1296, -0.7277, -0.4366,  ..., -0.1094,  0.6023, -0.2056]],
       grad_fn=<AddmmBackward>)

In [22]:
class EncodProject(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 hidden_dim: int = 2048,
                 output_dim: int = 2048
                 ) -> None:
        super(EncodProject, self).__init__()
                
        self.encoder = nn.Sequential(*list(model.children())[:-1])
        
        self.projector = ProjectionMLP(input_dim=nn.Sequential(*list(model.children()))[-1].in_features,
                                       hidden_dim=hidden_dim,
                                       output_dim=output_dim
                                       )
    def forward(self, x: Tensor) -> Tensor:
        x = self.encoder(x)
        x = torch.flatten(x, 1)
        x = self.projector(x)
        return x

In [23]:
class SimSiam(nn.Module):
    def __init__(self,
                 model: nn.Module,
                 projector_hidden_dim: int = 2048,
                 projector_output_dim: int = 2048,
                 predictor_hidden_dim: int = 512,
                 predictor_output_dim: int = 2048
                ) -> None: 
        super(SimSiam, self).__init__()
        
        self.encode_project = EncodProject(model, 
                                           hidden_dim= projector_hidden_dim,
                                           output_dim= projector_hidden_dim
                                          )
        self.predictor = PredictionMLP(input_dim=projector_output_dim,
                                       hidden_dim=predictor_hidden_dim,
                                       output_dim=predictor_output_dim)
        
    def forward(self, 
                x1: Tensor,
                x2: Tensor
               ) -> Tuple[Tensor]:
        
        f, h = self.encode_project, self.predictor
        z1, z2 = f(x1), f(x2)
        p1, p2 = h(z1), h(z2)
        
        
        return {'p1': p1,
                'p2' : p2,
                'z1' : z1,
                'z2' : z2}

In [29]:
from torchvision.models import resnet18

model = resnet18()

learner = SimSiam(model)

opt = torch.optim.Adam(learner.parameters(), lr=0.001)

criterion = NegativeCosineSimilarity()

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images1 = sample_unlabelled_images()
    images2 = images1*0.9
    p1, p2, z1, z2 = learner(images1, images2).values()
    loss = criterion(p1, p2, z1, z2)
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(_+1,loss)

1 tensor(-0.0045, grad_fn=<AddBackward0>)
2 tensor(-0.0239, grad_fn=<AddBackward0>)
3 tensor(-0.0532, grad_fn=<AddBackward0>)
4 tensor(-0.0805, grad_fn=<AddBackward0>)
5 tensor(-0.1194, grad_fn=<AddBackward0>)
6 tensor(-0.1787, grad_fn=<AddBackward0>)
7 tensor(-0.2419, grad_fn=<AddBackward0>)
8 tensor(-0.3658, grad_fn=<AddBackward0>)
9 tensor(-0.4321, grad_fn=<AddBackward0>)
10 tensor(-0.5377, grad_fn=<AddBackward0>)
11 tensor(-0.5980, grad_fn=<AddBackward0>)
12 tensor(-0.6544, grad_fn=<AddBackward0>)
13 tensor(-0.7177, grad_fn=<AddBackward0>)
14 tensor(-0.7648, grad_fn=<AddBackward0>)
15 tensor(-0.7951, grad_fn=<AddBackward0>)
16 tensor(-0.8146, grad_fn=<AddBackward0>)
17 tensor(-0.8446, grad_fn=<AddBackward0>)
18 tensor(-0.8590, grad_fn=<AddBackward0>)
19 tensor(-0.8557, grad_fn=<AddBackward0>)
20 tensor(-0.8766, grad_fn=<AddBackward0>)
21 tensor(-0.8814, grad_fn=<AddBackward0>)
22 tensor(-0.8823, grad_fn=<AddBackward0>)
23 tensor(-0.8889, grad_fn=<AddBackward0>)
24 tensor(-0.8808, g

In [118]:
def test(x,y,z):
    dic = {'x':x,
           'y':y,
           'z':z}
    return dic
x,y,z = test(1,2,3).values()
z

3