In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import os
import torchmetrics
import datetime
from torch import einsum
from einops import rearrange, repeat
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg
from ConvFFN import CONVFFN
from OwnPVTpt2 import LearnablePatchAttentionModel
import albumentations as A
import torch.distributed as dist
import torchvision
from tqdm.notebook import tqdm 

In [15]:
def default(val, def_val):
    """Return ``val`` unless it is ``None``, in which case return ``def_val``.

    Only ``None`` triggers the fallback; other falsy values such as ``0``,
    ``False`` or ``""`` are returned unchanged.
    """
    if val is not None:
        return val
    return def_val

In [None]:
def MaybeSyncBatchnorm(is_distributed = None):
    """Pick the batch-norm class to use: ``nn.SyncBatchNorm`` or ``nn.BatchNorm1d``.

    Args:
        is_distributed: explicit choice. When ``None`` the choice is derived
            from ``torch.distributed``: a process group must be initialised
            and its world size must be greater than one.

    Returns:
        The class itself (not an instance); the caller instantiates it.
    """
    multi_process = dist.is_initialized() and dist.get_world_size() > 1
    use_sync = default(is_distributed, multi_process)
    if use_sync:
        return nn.SyncBatchNorm
    return nn.BatchNorm1d

In [3]:
def Byol_loss_fn(z, detached_grad):
    """Per-sample loss ``2 - 2 * cos(z, detached_grad)`` along the last dimension.

    Both inputs are L2-normalised over their last dimension, multiplied
    element-wise and summed over that dimension. No reduction over the
    remaining dimensions is applied, so the result has the input shape
    minus the last axis.
    """
    unit_z, unit_target = (F.normalize(t, p=2, dim=-1) for t in (z, detached_grad))
    cosine = (unit_z * unit_target).sum(dim=-1)
    return 2 - 2 * cosine


In [4]:
# Build a single backbone instance with default arguments; the next cell uses it for a shape check.
model = LearnablePatchAttentionModel()


[0.0, 0.011111111380159855, 0.02222222276031971, 0.03333333507180214, 0.04444444552063942, 0.0555555559694767, 0.06666666269302368, 0.07777778059244156, 0.08888889104127884, 0.10000000149011612]


In [5]:
# Smoke test: run one random (1, 3, 224, 224) tensor through the backbone and show the output shape.
model(torch.randn(1, 3, 224, 224)).shape


torch.Size([1, 64, 56, 56]) Projection
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 3136, 64]) Stagesout_0
torch.Size([1, 128, 28, 28]) Projection
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 784, 128]) Stagesout_1
torch.Size([1, 256, 14, 14]) Projection
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 196, 256]) Stagesout_2
torch.Size([1, 512, 7, 7]) Projection
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 49, 512]) Stagesout_3


torch.Size([1, 512])

In [6]:
# Scratch linear layer mapping 512 input features to 3 outputs.
a = nn.Linear(512, 3)


In [7]:
# Average a random (1, 49, 512) tensor over dim 1, feed it to the linear layer and show the output shape.
a(torch.randn(1, 49, 512).mean(dim=1)).shape


torch.Size([1, 3])

In [None]:
class MLP(nn.Module):
    """Three-layer perceptron used as a projection head.

    Layout: ``Linear -> BatchNorm -> GELU -> Dropout`` twice, followed by a
    final ``Linear``. The batch-norm class is chosen by
    ``MaybeSyncBatchnorm()`` at construction time.

    Args:
        dim: number of input features.
        hidden_dim: width of both hidden layers.
        out_dim: number of output features.
        dropout: drop probability for both dropout layers (default ``0.2``,
            the value that used to be hard-coded).
    """
    def __init__(self,dim,hidden_dim,out_dim,dropout=.2):
        super().__init__()
        # Attribute names and registration order are kept stable so existing
        # state_dicts and parameter ordering remain compatible.
        self.fc1 = nn.Linear(dim,hidden_dim)
        self.batchnorm1 = MaybeSyncBatchnorm()(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim,hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.batchnorm2 = MaybeSyncBatchnorm()(hidden_dim)
        self.fc3 = nn.Linear(hidden_dim,out_dim)
        self.act = nn.GELU()

        self.apply(self._init_weights)
    
    def _init_weights(self,m):
        """Initialise ``nn.Linear`` layers: truncated-normal weights (std 0.02), zero bias.

        Modules of any other type are left with their default initialisation.
        """
        if isinstance(m,nn.Linear):
            trunc_normal_(m.weight,std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias,0)
    
    def forward(self,x):
        """Apply the two hidden stages and the output layer to ``x``."""
        x = self.fc1(x)
        x = self.batchnorm1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.batchnorm2(x)
        x = self.act(x)
        x = self.dropout2(x)
        x = self.fc3(x)
        return x

In [8]:
class Model_Boyol(nn.Module):
    """Two-branch wrapper: a trainable branch (``modela``) and a frozen branch (``modelb``).

    Each branch is a backbone followed by its own ``MLP`` projection head.
    The frozen branch receives no gradients; its weights are moved towards
    the trainable branch by ``update_model_b`` using an exponential moving
    average.

    Args:
        modela: trainable backbone; must expose ``embeded_dimesion``, whose
            last entry is used as the projection head's input width.
        modelb: frozen backbone with the same final embedding dimension.
        projection_layer_size: output width of both projection heads.
        projection_hidden_layer_size: hidden width of both projection heads.

    Raises:
        AssertionError: if either model is ``None`` or the final embedding
            dimensions differ.
    """
    def __init__(self, modela, modelb, projection_layer_size=1024,projection_hidden_layer_size=4096):
        super().__init__()

        assert modela is not None and modelb is not None, "Both models must be defined"
        assert modela.embeded_dimesion[-1] == modelb.embeded_dimesion[-1], "Embedding dimensions of both models must be equal"
        
        self.modela = modela
        self.modelb = modelb
        for i in self.modelb.parameters():
            i.requires_grad = False

        self.projection_layera = MLP(self.modela.embeded_dimesion[-1],projection_hidden_layer_size,projection_layer_size)
        self.projection_layerb = MLP(self.modelb.embeded_dimesion[-1],projection_hidden_layer_size,projection_layer_size)
        # The frozen branch's head is only ever evaluated under no_grad, so
        # freeze it like the frozen backbone; it is updated by averaging instead.
        for i in self.projection_layerb.parameters():
            i.requires_grad = False

    def forward(self, x, y):
        """Return ``(z1, z2)``: ``x`` through the trainable branch, ``y`` through the frozen branch.

        ``z2`` is computed under ``torch.no_grad()`` and carries no gradient.
        """
        z1 = self.modela(x)
        z1 = self.projection_layera(z1)
        
        with torch.no_grad():
            z2 = self.modelb(y)
            z2 = self.projection_layerb(z2)

        return z1, z2 

    def update_average(self, tau, online_parameter, offline_parameter):
        """Return ``tau * offline_parameter + (1 - tau) * online_parameter``."""
        return tau*offline_parameter+(1-tau)*online_parameter

    def update_model_b(self, tau=0.99, verbose=False):
        """Move the frozen branch towards the trainable branch by moving average.

        Both the backbone and the projection head are updated. Previously
        only the backbone was, which left ``projection_layerb`` at its random
        initialisation for the whole run.

        Args:
            tau: weight kept from the frozen branch's current value
                (default ``0.99``, the value that used to be hard-coded).
            verbose: when true, print ``"Model B Updated"`` after the update.
                The per-parameter tensor dump was removed because it printed
                every parameter tensor on every step.
        """
        branch_pairs = (
            (self.modela, self.modelb),
            (self.projection_layera, self.projection_layerb),
        )
        for online_net, offline_net in branch_pairs:
            for online_parameters, offline_parameters in zip(online_net.parameters(), offline_net.parameters()):
                offline_parameters.data = self.update_average(
                    tau, online_parameters.data, offline_parameters.data)
        if verbose:
            print("Model B Updated")


In [9]:
# Two independent random (1, 3, 224, 224) inputs used to exercise the two-branch wrapper.
x = torch.randn(1, 3, 224, 224)
y = torch.randn(1, 3, 224, 224)

# modela and modelb are two separately constructed backbones (they do not share weights).
model_boyol = Model_Boyol(LearnablePatchAttentionModel(),
                          LearnablePatchAttentionModel())


[0.0, 0.011111111380159855, 0.02222222276031971, 0.03333333507180214, 0.04444444552063942, 0.0555555559694767, 0.06666666269302368, 0.07777778059244156, 0.08888889104127884, 0.10000000149011612]
[0.0, 0.011111111380159855, 0.02222222276031971, 0.03333333507180214, 0.04444444552063942, 0.0555555559694767, 0.06666666269302368, 0.07777778059244156, 0.08888889104127884, 0.10000000149011612]


In [10]:
# Run the pair in both orders so each input goes through modela once and modelb once.
# NOTE(review): both inputs have batch size 1; if the projection MLPs end up using
# BatchNorm1d in training mode, PyTorch rejects single-sample batches — confirm this cell completes.
z1,z2=model_boyol(x, y)
z3,z4=model_boyol(y, x)

torch.Size([1, 64, 56, 56]) Projection
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 3136, 64]) Stagesout_0
torch.Size([1, 128, 28, 28]) Projection
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 128, 28, 28]) torch.Size([1, 128, 28, 28])
torch.Size([1, 784, 128]) Stagesout_1
torch.Size([1, 256, 14, 14]) Projection
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 256, 14, 14]) torch.Size([1, 256, 14, 14])
torch.Size([1, 196, 256]) Stagesout_2
torch.Size([1, 512, 7, 7]) Projection
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 512, 7, 7]) torch.Size([1, 512, 7, 7])
torch.Size([1, 49, 512]) Stagesout_3
torch.Size([1, 64, 56, 56]) Projection
torch.Size([1, 64, 56, 56]) torch.Size([1, 64, 56, 56])
torch.Size([1, 64, 56, 56]) torch

In [16]:
def train_one_iter(train_loader,optimizer,model,criterion,device,epoch,log_interval=100,transform=None):
    """Run one pass over ``train_loader`` and return the mean per-batch loss.

    For every batch a second view is produced with ``transform``. The model is
    evaluated on both orderings of (batch, view) and the two losses are
    summed. After the optimizer step, ``model.update_model_b()`` is called.

    Args:
        train_loader: iterable of image batches; must support ``len()`` and
            expose ``.dataset`` (used for the progress log line).
        optimizer: optimizer stepped once per batch.
        model: module whose ``forward(a, b)`` returns ``(z1, z2)`` and which
            provides ``update_model_b()``.
        criterion: callable ``criterion(z1, z2)`` returning a scalar or a
            per-sample loss tensor; per-sample losses are averaged.
        device: device the batch and its transformed view are moved to.
        epoch: epoch number, used only in the log line.
        log_interval: print a progress line every this many batches.
        transform: callable invoked as ``transform(image=ndarray)`` returning
            a mapping with an ``'image'`` entry. Required despite the
            ``None`` default.

    Returns:
        Sum of the per-batch losses divided by the number of batches
        (``0`` when the loader is empty).

    Raises:
        AssertionError: if ``transform`` is ``None``.
    """
    assert transform is not None, "Transform must be defined"
    running_loss=0
    model.train()
    # enumerate objects have no len(); take the length from the loader itself.
    num_batches = len(train_loader)

    for batch_idx, image in tqdm(enumerate(train_loader),total=num_batches):

        # Tensor.numpy() only works on CPU tensors, so build the second view
        # from a CPU copy before anything is moved to `device`.
        # NOTE(review): the whole batch is handed to `transform` in one call, as
        # before; albumentations pipelines normally expect one HWC image per
        # call — confirm the transform in use handles batched input.
        transformed_image = transform(image=image.detach().cpu().numpy())
        # The transform may return an ndarray or a tensor; as_tensor accepts both.
        transformed_image = torch.as_tensor(transformed_image['image']).to(device)
        image=image.to(device)

        optimizer.zero_grad()

        z1,z2=model(image,transformed_image)
        loss1 = criterion(z1,z2.detach())

        z1,z2=model(transformed_image,image)
        loss2 = criterion(z1,z2.detach())

        # Reduce to a scalar: backward() and item() both require one, and a
        # criterion such as Byol_loss_fn returns one value per sample.
        # mean() leaves an already-scalar loss unchanged.
        loss = (loss1+loss2).mean()

        loss.backward()

        optimizer.step()

        model.update_model_b()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.8f}'.format(
                epoch, batch_idx * len(image), len(train_loader.dataset),
                100. * batch_idx / num_batches,
                loss.item()))
        running_loss+=loss.item()
    # Guard against an empty loader to avoid dividing by zero.
    return running_loss/max(num_batches, 1)