In [None]:
import torch as torch
import torch.optim as optim
import torch.nn as nn
import os
import torchvision
import torch.nn as nn
from torch.autograd import Variable as var
import logging as log
import gc
import numpy as np
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset
from PIL import Image
from glob import glob

In [None]:
# Device selection: use the first GPU when CUDA is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

# Turn on the cuDNN benchmark flag (see the PyTorch cuDNN backend notes).
torch.backends.cudnn.benchmark = True

In [None]:
## Constants
N = 128 # Batch size; also reused below as the encoder output width and the queue's feature dimension
T = 0.07 # Softmax temperature: the contrastive logits are divided by T
C = 3 # Number of image channels
m = 0.9 # Momentum coefficient for the key-encoder update: key = m*key + (1-m)*query
K = 4096 # Dictionary (queue) size, i.e. number of stored keys

In [None]:
# Dataset root (absolute Kaggle input path). The ImageNet dataset class below
# looks for `train/<dir>/<file>` and `val/<dir>/<file>` entries under it.
DATA = '/kaggle/input/imagenetmini-1000/imagenet-mini/'

In [None]:
class ImageNet(Dataset):
    """Unlabelled image dataset over `<root_dir>/<train|val>/<dir>/<file>`.

    Only the images are returned (no class labels).

    Args:
        root_dir: Directory that contains the `train` and `val` sub-directories.
        train: Read from `train` when True, otherwise from `val`.
        transform: Optional callable applied to each RGB PIL image.
    """

    def __init__(self, root_dir, train=False, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.sub_directory = 'train' if train else 'val'

        pattern = os.path.join(root_dir, self.sub_directory, "*", "*")
        # glob() yields entries in arbitrary filesystem order; sorting makes
        # the index -> file mapping reproducible across runs and machines.
        self.imgs = sorted(glob(pattern))

    def __len__(self):
        """Return the number of files matched under the chosen split."""
        return len(self.imgs)

    def __getitem__(self, idx):
        """Load file `idx` as RGB and return it, transformed if a transform is set."""
        # The context manager closes the underlying file as soon as the pixels
        # have been copied by convert(), so worker processes do not pile up
        # open file descriptors.
        with Image.open(self.imgs[idx]) as handle:
            img = handle.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img

In [None]:
## Augmentations

def get_random_augmentation():
    """Build the random augmentation pipeline applied to every image.

    Returns:
        A `transforms.Compose` that maps an RGB PIL image to a normalised
        3 x 224 x 224 float tensor.
    """
    return transforms.Compose([
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.)),
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
        transforms.RandomHorizontalFlip(),
        # Grayscale is applied only some of the time: converting every image
        # unconditionally would throw away the colour jitter applied just above.
        # The channel count is preserved, so RGB input stays 3-channel.
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        # ImageNet channel statistics (the data is an ImageNet subset); the
        # previous constants were the CIFAR-10 mean/std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

In [None]:
## Load Data
transform = get_random_augmentation()

# Both splits are read through the same random augmentation pipeline.
train_data = ImageNet(DATA, train=True, transform=transform)
test_data = ImageNet(DATA, train=False, transform=transform)

# Options shared by both loaders; drop_last discards a final incomplete batch,
# so every batch holds exactly N images.
loader_options = dict(batch_size=N, num_workers=4, pin_memory=True, drop_last=True)

train_set = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_options)
test_set = torch.utils.data.DataLoader(test_data, shuffle=False, **loader_options)

In [None]:
class Resnet50Model(nn.Module):
    """Randomly initialised ResNet-50 backbone with a linear + ReLU head.

    The torchvision classifier layer is removed and replaced by a projection
    to `feature_dim` outputs.

    Args:
        feature_dim: Width of the output embedding. Defaults to the
            module-level constant `N`, which is what the original code used.
    """

    def __init__(self, feature_dim=None):
        super(Resnet50Model, self).__init__()

        if feature_dim is None:
            # Resolved at construction time so existing `Resnet50Model()`
            # call sites keep producing N-dimensional embeddings.
            feature_dim = N

        # No weights argument -> random initialisation. This avoids the
        # `pretrained=` keyword that newer torchvision releases deprecate.
        backbone = models.resnet50()
        in_features = backbone.fc.in_features  # 2048 for ResNet-50

        # Keep everything up to and including global average pooling.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        # NOTE(review): the trailing ReLU makes every embedding non-negative
        # before it is L2-normalised by the caller - confirm this is intended.
        self.fc = nn.Sequential(nn.Linear(in_features, feature_dim), nn.ReLU())

    def forward(self, x):
        """Map a batch of images to a (batch, feature_dim) embedding."""
        x = self.resnet(x)
        x = torch.flatten(x, 1)  # (batch, C, 1, 1) -> (batch, C)
        return self.fc(x)
    

In [None]:
# Query and key encoders share one architecture. They are placed on the
# `device` selected above instead of calling .cuda() unconditionally.
encoder_q = Resnet50Model().to(device)
encoder_k = Resnet50Model().to(device)

In [None]:
# SGD over the query encoder's parameters only; the key encoder is not
# registered with the optimizer.
optimizer = optim.SGD(
    encoder_q.parameters(),
    lr=0.03,
    momentum=0.9,
    weight_decay=0.0001,
)

In [None]:
# Cross-entropy over the [positive | negatives] logits, placed on the `device`
# selected above instead of calling .cuda() unconditionally.
cec = nn.CrossEntropyLoss().to(device)

In [None]:
# Initialise the key encoder as a copy of the query encoder and exclude its
# parameters from gradient computation; the training loop updates them with a
# momentum rule instead.
for param_q, param_k in zip(encoder_q.parameters(), encoder_k.parameters()):
    param_k.requires_grad_(False)
    param_k.data.copy_(param_q.data)

In [None]:
@torch.no_grad()
def dequeue(queue):
    """Return `queue` restricted to its first K - N columns (all rows kept)."""
    columns_kept = K - N
    return queue[:, :columns_kept]

@torch.no_grad()
def enqueue(queue,k):
    """Prepend a batch of keys to the feature-major queue.

    Args:
        queue: Tensor of shape (C, L); every column is one stored key.
        k: Tensor of shape (batch, C); every row is one new key, which is
            the layout the encoders return.

    Returns:
        Tensor of shape (C, batch + L) whose leading `batch` columns are the
        new keys, followed by the previous queue contents.
    """
    # Keys arrive one per row but the queue stores one per column, so they
    # must be transposed. Concatenating `k` as-is only ran because batch size
    # and feature dimension happened to be equal, and it stored per-feature
    # slices of the batch instead of key vectors.
    return torch.cat([k.t(), queue], dim=1)


In [None]:
@torch.no_grad()
def concat_all_gather(tensor):
    """Gather `tensor` from every distributed process and concatenate on dim 1.

    Args:
        tensor: Tensor to share; it must have the same shape on every process.

    Returns:
        The per-process tensors concatenated along dimension 1. When
        torch.distributed is unavailable or no process group has been
        initialised (single-process run), `tensor` is returned unchanged.
        No gradient flows through the gathered result.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        # Single process: there is nothing to gather.
        return tensor

    # Pre-allocated receive buffers, one per rank.
    tensors_gather = [torch.ones_like(tensor)
        for _ in range(dist.get_world_size())]
    # Fill the buffers. Without this call the function would just return ones.
    dist.all_gather(tensors_gather, tensor, async_op=False)

    output = torch.cat(tensors_gather, dim=1)
    return output

In [None]:
# Initial dictionary of random keys: N rows (the feature dimension) by K
# columns (one per key), with every column scaled to unit L2 norm. It is
# placed on the `device` selected above instead of calling .cuda().
with torch.no_grad():
    queue = torch.randn(N, K).to(device)
    queue = nn.functional.normalize(queue, dim=0)

In [None]:
def saveModel(epoch, model,optimizer,loss,path):
    """Write a training checkpoint to `path` with torch.save.

    The saved dictionary has the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict' and 'loss'.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(checkpoint, path)

In [None]:
def loadModel(model,optimizer,path):
    """Restore `model` and `optimizer` in place from a checkpoint file.

    Args:
        model: Module whose state dict is loaded from 'model_state_dict'.
        optimizer: Optimizer whose state is loaded from 'optimizer_state_dict'.
        path: Checkpoint file holding the keys 'epoch', 'model_state_dict',
            'optimizer_state_dict' and 'loss'.

    Returns:
        Tuple `(model, optimizer, epoch, loss)`; the first two are the objects
        that were passed in.
    """
    # SECURITY: torch.load unpickles the file - only load trusted checkpoints.
    # Tensors are mapped to CPU first, so a checkpoint written on a GPU can
    # also be restored on a CPU-only machine; load_state_dict() then copies
    # the values onto whatever device the model and optimizer already use.
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']

    print('Epoch: ',epoch,'Loss: ',loss)
    return model, optimizer, epoch, loss

In [None]:

# Contrastive training loop.
#
# Per batch:
#   1. encode the query view with encoder_q (gradients enabled);
#   2. momentum-update encoder_k from encoder_q, then encode the key view
#      with it (no gradients);
#   3. build logits [positive | negatives-from-queue], scale by 1/T and apply
#      cross-entropy with the positive at index 0;
#   4. step the optimizer (query encoder only) and refresh the queue.
NUM_EPOCHS = 20   # full passes over train_set
LOG_EVERY = 50    # batches between progress prints / checkpoints

for e in range(NUM_EPOCHS):
    epoch_loss = 0.0
    running_loss = 0.0
    for i, images in enumerate(train_set):

        if isinstance(images, (list, tuple)):
            # Dataset yields a (query_view, key_view) pair per sample.
            images_q, images_k = (view.to(device, non_blocking=True) for view in images)
        else:
            # NOTE(review): a single tensor per batch means the query and key
            # encoders see the *same* augmented view - confirm this is
            # intended, or make the dataset return two augmentations.
            images_q = images_k = images.to(device, non_blocking=True)

        optimizer.zero_grad()

        q = nn.functional.normalize(encoder_q(images_q), dim=1)

        with torch.no_grad():
            # Momentum update of the key encoder: p_k <- m*p_k + (1-m)*p_q
            for p_k, p_q in zip(encoder_k.parameters(), encoder_q.parameters()):
                p_k.mul_(m).add_(p_q, alpha=1.0 - m)

            k = nn.functional.normalize(encoder_k(images_k), dim=1)

        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK (the queue is treated as a constant)
        l_neg = torch.einsum('nc,ck->nk', [q, queue.detach()])

        logits = torch.cat([l_pos, l_neg], dim=1) / T

        # The positive is always column 0. Labels are sized from the logits
        # so they always match the actual batch.
        labels = torch.zeros(logits.size(0), dtype=torch.long, device=logits.device)

        loss = cec(logits, labels)

        # updating query encoder
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        running_loss += batch_loss

        # Replace the oldest keys with the keys of this batch.
        queue = dequeue(queue)
        queue = enqueue(queue, k)

        if (i + 1) % LOG_EVERY == 0:
            print('Epoch :',e+1,'Batch :',(i+1),'Loss :',float(running_loss/LOG_EVERY))
            running_loss = 0.0
            # `optimizer` only tracks encoder_q; it is stored alongside both
            # encoders, as before. `epoch_loss` here is the running sum so far.
            saveModel(e,encoder_q,optimizer,epoch_loss,"encoder_query.pth")
            saveModel(e,encoder_k,optimizer,epoch_loss,"encoder_keys.pth")

    saveModel(e,encoder_q,optimizer,epoch_loss,"encoder_query.pth")
    saveModel(e,encoder_k,optimizer,epoch_loss,"encoder_keys.pth")
    print('Epoch :',e+1, 'Loss :',epoch_loss/len(train_set))

In [None]:

def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k

    Args:
        output: (batch, num_classes) tensor of scores.
        target: (batch,) tensor of integer class indices.
        topk: Iterable of k values to evaluate.

    Returns:
        List with one 1-element tensor per k, holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose, best prediction in row 0.
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # `correct` inherits the transposed (non-contiguous) layout, so
            # .view(-1) raises for k > 1; reshape copies when it has to.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
