In [32]:
#in this notebook we will implement MOCO paper:
#https://arxiv.org/pdf/1911.05722.pdf
#CNN architecture we change, but MOCO loss function is the same, and training is the same
#we use 5 different data gmentations (rotations, blur, color distortion, cropping and resizing) for defining the positive samples

#training is done on 1 GPU, training setting are:
#1. Use the entire training data to learn the representations.
#2. Once the representations are learned, use a linear and logistic layers and retrain with 10-50% of supervised training data.
#3. Experiment with two different sizes for the encoder dictionary.

In [49]:
#imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchsummary import summary

from functools import partial
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [34]:
from PIL import Image

In [115]:
import importlib

import models
importlib.reload(models)
from models import  pentaClassifier, binaryClassifier, baseClassifier


import config
from config import cfg
importlib.reload(config)



<module 'config' from '/home/lisa/bhartendu/adrl/A3/config.py'>

In [36]:
#device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### CIFAR10

In [37]:
from torchvision.datasets import CIFAR10

In [38]:
import argparse
import json
import math
import os
import pandas as pd

In [39]:
#seed
seed = 42
os.environ['PYTHONHASHSEED'] = str(seed)
# Torch RNG
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)



#### HYPER - PARAMETERS

In [40]:
results_dir = 'results'

In [41]:
lr = 0.06
epochs = 200
batch_size = 512
schedule = [120, 160]
wd = 5e-4
cos = True      # cosine lr schedule

In [42]:
moco_dim = 512
moco_k = 4096
moco_m = 0.99
moco_t = 0.07

In [43]:
bn_splits = 8

## data-loader

#### Adapted from pytorch code of Contrastive learning libs: http://github.com/zhirongw/lemniscate.pytorch

In [44]:
class CIFAR10Pair(CIFAR10):
    """CIFAR10 Dataset.
    """
    def __getitem__(self, index):
        img = self.data[index]
        img = Image.fromarray(img)

        if self.transform is not None:
            im_1 = self.transform(img)
            im_2 = self.transform(img)

        return im_1, im_2

# train_transform = transforms.Compose([
#     transforms.RandomResizedCrop(32),
#     transforms.RandomHorizontalFlip(p=0.5),
#     transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
#     transforms.RandomGrayscale(p=0.2),
#     transforms.ToTensor(),
#     transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
#we will aplly 5 different data gmentations (rotations, blur, color distortion, cropping and resizing) for defining the positive samples
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    #random rotation with random angle
    transforms.RandomRotation([-1* torch.rand(1)*360 , torch.rand(1)*360]),
    #random blur
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=[0.1, 2.0])], p=0.5),
    #random color distortion
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])
    

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])])

# data prepare
train_data = CIFAR10Pair(root='data', train=True, transform=train_transform, download=True)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True, drop_last=True)

memory_data = CIFAR10(root='data', train=True, transform=test_transform, download=True)
memory_loader = DataLoader(memory_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

test_data = CIFAR10(root='data', train=False, transform=test_transform, download=True)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)


Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Split Batch implementation adapted from: 
#### https://github.com/davidcpage/cifar10-fast/blob/master/torch_backend.py

In [116]:
class SplitBatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, num_splits= bn_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits
        
    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            running_mean_split = self.running_mean.repeat(self.num_splits)
            running_var_split = self.running_var.repeat(self.num_splits)
            outcome = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), running_mean_split, running_var_split, 
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean.data.copy_(running_mean_split.view(self.num_splits, C).mean(dim=0))
            self.running_var.data.copy_(running_var_split.view(self.num_splits, C).mean(dim=0))
            return outcome
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var, 
                self.weight, self.bias, False, self.momentum, self.eps)

#### the following class was adapted from facebookresearch/moco which replaces 

In [46]:
# #write a function to replace all the batchnorm layers with splitbatchnorm layers if bn_splits > 1
# def replace_bn(model, num_splits):
#     #if num_splits is 1, then return the model
#     if num_splits == 1:
#         return model
#     norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
#     for name, child in model.named_children():
#         if isinstance(child, nn.BatchNorm2d):
#             setattr(model, name, norm_layer(child.num_features))
#         else:
#             replace_bn(child, num_splits)
#     return model
    

            

In [96]:
#function to use the class SplitBatchNorm to replace the batchnorm layers with splitbatchnorm layers if bn_splits > 1
#the init has num_features and num_splits as arguments, thus while changing the batchnorm layers to splitbatchnorm layers, we need to pass the num_splits as an argument
def replace_bn(model, num_splits):
    #if num_splits is 1, then return the model
    if num_splits == 1:
        return model
    #loop through the model and replace the batchnorm layers with splitbatchnorm layers using the class SplitBatchNorm
    norm_layer = partial(SplitBatchNorm, num_splits=bn_splits)
    for name, child in model.named_children():
        if isinstance(child, nn.BatchNorm2d):
            #change the batchnorm layers to splitbatchnorm layers by the return of the function : SplitBatchNorm(child.num_features, num_splits)
            #replace batchnorm with norm_layer
            setattr(model, name, norm_layer(child.num_features))
        else:
            replace_bn(child, num_splits)

## Model

In [122]:
class convNet(nn.Module):
    def __init__(self, cfg, device):
        super(convNet,self).__init__()

        self.config = cfg
        self.device = device
        layers =[]

        channel_set=self.config['channels']
        for i in range(self.config['num_conv_blocks']):
            if i==0:
                layers.append(nn.Conv2d(self.config['in_channels'],channel_set[i][1],3, bias=False))
            else:
                layers.append(nn.Conv2d(channel_set[i][0],channel_set[i][1],3, bias=False))
            
            layers.append(SplitBatchNorm(channel_set[i][1]))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(p=0.3))
            layers.append(nn.AdaptiveMaxPool2d(2))

        self.net = nn.ModuleList(layers)
        print('-'*59)
        print('ConvNet')
        print(self.net)
        print('-'*59)

    def forward(self,x):

        for _, l in enumerate(self.net):
            x = l(x)            

        return x

In [123]:
#create a model

test_model = convNet(cfg['model'], device)

-----------------------------------------------------------
ConvNet
ModuleList(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (1): SplitBatchNorm(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.3, inplace=False)
  (4): AdaptiveMaxPool2d(output_size=2)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (6): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Dropout(p=0.3, inplace=False)
  (9): AdaptiveMaxPool2d(output_size=2)
  (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (11): SplitBatchNorm(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU()
  (13): Dropout(p=0.3, inplace=False)
  (14): AdaptiveMaxPool2d(output_size=2)
)
-----------------------------------------------------------


In [126]:
#summary of the model
summary(test_model.to(device), (3, 32, 32))

RuntimeError: shape '[-1, 256, 30, 30]' is invalid for input of size 57600

In [113]:
#create a class which will define the model: it  will take a base model as input and will replace all the batchnorm layers with splitbatchnorm layers if bn_splits > 1
class ModelBase(nn.Module):
    def __init__(self, ):
        super().__init__()
        #we get the model from the script models.py named convNet
        self.model = convNet( cfg=cfg['model'], device=device)
        #we replace the batchnorm layers with splitbatchnorm layers if bn_splits > 1
        norm_layer = partial(SplitBatchNorm, num_splits=bn_splits) if bn_splits > 1 else nn.BatchNorm2d
        temp_model = getattr(self.model, None)
        self.temp_model = temp_model(norm_layer = norm_layer)

      

    def forward(self, x):
        return self.temp_model(x)
        


In [114]:
#create an instance of the class ModelBase
model = ModelBase()


-----------------------------------------------------------
ConvNet
ModuleList(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.3, inplace=False)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Dropout(p=0.3, inplace=False)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU()
  (13): Dropout(p=0.3, inplace=False)
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
---------------------------------------------

TypeError: getattr(): attribute name must be string

In [108]:
#print summary of the model
temp = torch.randn(1, 3, 32, 32).to(device)
summary(model.to(device), input_size=(3, 32, 32), device='cuda')

TypeError: 'NoneType' object is not callable

In [100]:
#create a calss of convNet to test if it works
temp_cnn = convNet( cfg=cfg['model'], device=device)

-----------------------------------------------------------
ConvNet
ModuleList(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Dropout(p=0.3, inplace=False)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (6): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): ReLU()
  (8): Dropout(p=0.3, inplace=False)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), bias=False)
  (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (12): ReLU()
  (13): Dropout(p=0.3, inplace=False)
  (14): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
---------------------------------------------

In [101]:
#print summary of the model
temp = torch.randn(1, 3, 32, 32).to(device)
summary(temp_cnn.to(device), input_size=(3, 32, 32), device='cuda')

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 30, 30]             864
       BatchNorm2d-2           [-1, 32, 30, 30]              64
              ReLU-3           [-1, 32, 30, 30]               0
           Dropout-4           [-1, 32, 30, 30]               0
         MaxPool2d-5           [-1, 32, 15, 15]               0
            Conv2d-6           [-1, 64, 13, 13]          18,432
       BatchNorm2d-7           [-1, 64, 13, 13]             128
              ReLU-8           [-1, 64, 13, 13]               0
           Dropout-9           [-1, 64, 13, 13]               0
        MaxPool2d-10             [-1, 64, 6, 6]               0
           Conv2d-11            [-1, 128, 4, 4]          73,728
      BatchNorm2d-12            [-1, 128, 4, 4]             256
             ReLU-13            [-1, 128, 4, 4]               0
          Dropout-14            [-1, 12

In [102]:
#apply replace_bn function to the model
temp_cnn = replace_bn(temp_cnn, bn_splits)


In [103]:
#summary of the model
temp = torch.randn(1, 3, 32, 32).to(device)
summary(temp_cnn.to(device), input_size=(3, 32, 32), device='cuda')

AttributeError: 'NoneType' object has no attribute 'to'

## MOCO

In [None]:
class ModelMoCo(nn.Module):
    def __init__(self, dim=128, K=4096, m=0.99, T=0.1, arch='resnet18', bn_splits=8, symmetric=True):
        super(ModelMoCo, self).__init__()

        self.K = K
        self.m = m
        self.T = T
        self.symmetric = symmetric

        # create the encoders
        self.encoder_q = ModelBase()
        self.encoder_k = ModelBase()

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.t()  # transpose
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr

    @torch.no_grad()
    def _batch_shuffle_single_gpu(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        """
        # random shuffle index
        idx_shuffle = torch.randperm(x.shape[0]).cuda()

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        return x[idx_shuffle], idx_unshuffle

    @torch.no_grad()
    def _batch_unshuffle_single_gpu(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        """
        return x[idx_unshuffle]

    def contrastive_loss(self, im_q, im_k):
        # compute query features
        q = self.encoder_q(im_q)  # queries: NxC
        q = nn.functional.normalize(q, dim=1)  # already normalized

        # compute key features
        with torch.no_grad():  # no gradient to keys
            # shuffle for making use of BN
            im_k_, idx_unshuffle = self._batch_shuffle_single_gpu(im_k)

            k = self.encoder_k(im_k_)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)  # already normalized

            # undo shuffle
            k = self._batch_unshuffle_single_gpu(k, idx_unshuffle)

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # apply temperature
        logits /= self.T

        # labels: positive key indicators
        labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
        
        loss = nn.CrossEntropyLoss().cuda()(logits, labels)

        return loss, q, k

    def forward(self, im1, im2):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            loss
        """

        # update the key encoder
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()

        # compute loss
        if self.symmetric:  # asymmetric loss
            loss_12, q1, k2 = self.contrastive_loss(im1, im2)
            loss_21, q2, k1 = self.contrastive_loss(im2, im1)
            loss = loss_12 + loss_21
            k = torch.cat([k1, k2], dim=0)
        else:  # asymmetric loss
            loss, q, k = self.contrastive_loss(im1, im2)

        self._dequeue_and_enqueue(k)

        return loss

# create model
model = ModelMoCo(
        dim=args.moco_dim,
        K=args.moco_k,
        m=args.moco_m,
        T=args.moco_t,
        arch=args.arch,
        bn_splits=args.bn_splits,
        symmetric=args.symmetric,
    ).cuda()
print(model.encoder_q)