In [None]:
# Feel free to import anything you need

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import torchvision
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision import transforms
from torch.utils.data import DataLoader, SubsetRandomSampler
import torch.optim as optim
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
###########  Do not modify this cell ###########################
# ===============================================================

### seed everything for reproducibility
def seed_everything():
    """Seed NumPy and PyTorch (CPU and CUDA) with one fixed value.

    Calling this before drawing random tensors makes the demo cells in this
    notebook print the same numbers on every run.
    """
    seed = 24789
    # Same seed, same order as before: NumPy, torch CPU, torch CUDA.
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)

# 1. Contrastive learning loss functions

## 1(a) Implement the normalized temperature-scaled cross entropy loss (NT-Xent) based on SimCLR:
https://arxiv.org/pdf/2002.05709.pdf

\begin{equation}
\ell_{i, j}=-\log \frac{\exp \left(\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_j\right) / \tau\right)}{\sum_{k=1}^{2 N}  \mathbb{1}_{[k \neq i]} \exp \left(\operatorname{sim}\left(\boldsymbol{z}_i, \boldsymbol{z}_k\right) / \tau\right)}
\end{equation}

In [None]:
def nt_xent(x, t=0.5, clamp_min=1e-7):
    """Normalized temperature-scaled cross entropy loss (NT-Xent) from SimCLR.

    Rows of ``x`` must be interleaved so that rows ``2k`` and ``2k + 1`` hold
    the two views of the same sample; each row's positive is its neighbour and
    every other row acts as a negative.

    Args:
        x: 2-D tensor of projections with an even number of rows (2N).
        t: temperature that divides the cosine similarities.
        clamp_min: lower bound applied to the cosine similarities before the
            temperature scaling. The default (1e-7) keeps this notebook's
            historical behaviour, which flattens every negative similarity to
            ~0 and therefore removes the gradient for those pairs. Pass
            ``None`` to use the raw similarities, exactly as in the NT-Xent
            equation.

    Returns:
        Scalar tensor: the cross-entropy averaged over all 2N rows.

    Raises:
        ValueError: if ``x`` is not 2-D or has an odd number of rows (an odd
            row would be given a non-existent partner as its target).
    """
    if x.dim() != 2:
        raise ValueError(f"nt_xent expects a 2-D tensor, got {x.dim()} dimension(s)")
    n_rows = x.size(0)
    if n_rows % 2 != 0:
        raise ValueError(
            f"nt_xent expects an even number of rows (interleaved pairs), got {n_rows}"
        )

    # 1. L2-normalise each row so the dot products below are cosine similarities.
    x = F.normalize(x, dim=1)

    # 2. Pairwise cosine similarities, optionally clamped, then temperature-scaled.
    x_scores = x @ x.t()
    if clamp_min is not None:
        x_scores = x_scores.clamp(min=clamp_min)
    x_scaled = x_scores / t

    # Push the self-similarity logits to a huge negative value so they
    # contribute ~0 to the softmax denominator (the 1[k != i] indicator).
    x_scaled = x_scaled - torch.eye(n_rows, device=x_scaled.device) * 1e5

    # Target of row 2k is 2k + 1 and vice versa; built directly on the logits' device.
    targets = torch.arange(n_rows, device=x_scaled.device)
    targets[::2] += 1
    targets[1::2] -= 1

    # 3. Cross entropy between the scaled similarities and the partner indices.
    loss = F.cross_entropy(x_scaled, targets)

    return loss

In [None]:
# Sanity check: evaluate nt_xent on a small, reproducibly seeded input.
seed_everything()

# Four 3-dimensional embeddings; nt_xent treats rows (0, 1) and (2, 3) as positive pairs.
z = torch.randn(4,3)
t = 0.7  # temperature passed to nt_xent

loss = nt_xent(z, t)

print("z: \n", z)
print("Loss: \n", loss)

z: 
 tensor([[-0.2997, -0.6194, -0.4414],
        [-1.2945, -0.1358, -0.2589],
        [-0.7526,  0.4345, -0.9680],
        [ 0.8034, -2.5620,  0.9044]])
Loss: 
 tensor(1.2352)


## 1 (b) Implement Barlow twins loss function
https://arxiv.org/pdf/2103.03230.pdf


Cross correlation matrix
\begin{equation}
\mathcal{C}_{i j} \triangleq \frac{\sum_b z_{b, i}^A z_{b, j}^B}{\sqrt{\sum_b\left(z_{b, i}^A\right)^2} \sqrt{\sum_b\left(z_{b, j}^B\right)^2}}
\end{equation}\
Barlow Twins loss

\begin{equation}
\mathcal{L}_{\mathcal{B} \mathcal{T}} \triangleq \sum_i\left(1-\mathcal{C}_{i i}\right)^2+\lambda \quad \sum_i \sum_{j \neq i} \mathcal{C}_{i j}^2
\end{equation}

In [None]:
def off_diag_elems(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order. The input must be square; anything
    else trips the assertion.
    """
    rows, cols = x.shape
    assert rows == cols
    # Drop the last element and re-chunk into rows of length n + 1: every
    # diagonal entry then lands in column 0, which the slice below discards.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

def barlow_twins(z1, z2, lamda, batch_size):
    """Barlow Twins loss between two batches of representations.

    Args:
        z1, z2: representations of the two views; both are standardised along
            dim 0 before being correlated.
        lamda: weight of the off-diagonal (redundancy-reduction) term.
        batch_size: divisor applied to the cross-correlation matrix.

    Returns:
        Scalar tensor: sum_i (1 - C_ii)^2 + lamda * sum_{i != j} C_ij^2.
    """
    def _standardize(z):
        # Zero mean and unit std per feature, using torch.std's default settings.
        return (z - torch.mean(z, dim=0)) / torch.std(z, dim=0)

    # Feature-by-feature cross-correlation of the two standardised views.
    cross_corr = torch.matmul(_standardize(z1).T, _standardize(z2)) / batch_size

    # Invariance term: pull the diagonal towards 1.
    invariance = (torch.diagonal(cross_corr) - 1).pow(2).sum()
    # Redundancy term: push everything off the diagonal towards 0.
    redundancy = off_diag_elems(cross_corr).pow(2).sum()

    return invariance + lamda * redundancy

In [None]:
# Sanity check: evaluate barlow_twins on two small, reproducibly seeded batches.
seed_everything()

N = 4  # batch size; also passed to barlow_twins as the divisor of the correlation matrix
z1 = torch.randn(N, 4)
z2 = torch.randn(N, 4)
lamda = 5e-3  # weight of the off-diagonal term

loss = barlow_twins(z1, z2, lamda, N)

print("z1: \n", z1)
print("z2: \n", z2)
print("Loss: \n", loss)

z1: 
 tensor([[-0.6637, -2.2035, -0.5497,  0.1591],
        [-0.8427, -0.3029, -1.2701, -0.8752],
        [ 0.1743, -0.6272, -0.5871, -1.7928],
        [-1.7089,  0.2134, -0.3203, -0.3736]])
z2: 
 tensor([[ 0.6968,  1.8388, -0.9387, -0.4800],
        [-1.4524,  1.3257,  0.9204, -0.5180],
        [-1.0516, -0.5895,  0.5390,  0.0944],
        [-1.7974,  0.0804,  0.7257,  0.9933]])
Loss: 
 tensor(5.3352)


# 2. Implement the SimCLR model 

In [None]:
class SimCLR(nn.Module):
    """Encoder plus MLP projection head used for SimCLR pre-training.

    Args:
        base_encoder: callable that accepts a ``weights`` keyword and returns a
            module exposing ``conv1``, ``maxpool`` and ``fc`` attributes
            (for example ``torchvision.models.resnet18``).
        projection_dim: output width of the projection head.

    Attributes:
        enc: the adapted encoder; its ``fc`` layer is replaced by an identity,
            so it returns backbone features directly.
        feature_dim: input width of the encoder's original ``fc`` layer.
        projector: two-layer MLP mapping features to ``projection_dim``.
    """

    def __init__(self, base_encoder, projection_dim):
        super().__init__()
        # `weights=None` is the supported way to request a randomly initialised
        # torchvision model; a bare boolean is a deprecated legacy value.
        self.enc = base_encoder(weights=None)
        self.feature_dim = self.enc.fc.in_features

        # Adapt the stem for small images as described in the SimCLR paper's
        # CIFAR-10 setup: 3x3 stride-1 first convolution, no max-pooling.
        self.enc.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.enc.maxpool = nn.Identity()
        # Drop the classifier so the encoder outputs features.
        self.enc.fc = nn.Identity()

        # MLP projection head; the contrastive loss is applied to its output.
        self.projection_dim = projection_dim
        self.projector = nn.Sequential(nn.Linear(self.feature_dim, 2048),
                                       nn.ReLU(),
                                       nn.Linear(2048, projection_dim))

    def forward(self, x):
        """Return ``(feature, projection)`` for a batch of images ``x``."""
        feature = self.enc(x)
        projection = self.projector(feature)
        return feature, projection


In [None]:
###########  Do not modify this cell ###########################
# ===============================================================

class AverageMeter(object):
    """Running weighted average that also remembers the most recent value."""

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Clear the last value, the running sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val` as an observation that stands for `n` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


In [None]:

# color distortion composed by color jittering and color dropping.
# See Section A of SimCLR: https://arxiv.org/abs/2002.05709
def get_color_distortion(s=0.5):  # 0.5 for CIFAR10 by default
    """Build the colour-distortion transform: random jitter, then random grayscale.

    `s` scales the jitter strength (brightness, contrast and saturation use
    0.8 * s, hue uses 0.2 * s). The jitter is applied with probability 0.8 and
    the grayscale conversion with probability 0.2.
    """
    strength = 0.8 * s
    jitter = transforms.ColorJitter(strength, strength, strength, 0.2 * s)
    return transforms.Compose([
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
    ])


Creating pairs of augmented images from the training set

In [None]:
###########  Do not modify this cell ###########################
# ===============================================================

class CIFAR10Pair(CIFAR10):
    """CIFAR10 variant whose items are a stacked pair of transformed views."""
    def __getitem__(self, idx):
        # Convert the stored array to a PIL image before applying the transform.
        pil_img = Image.fromarray(self.data[idx])  # .convert('RGB')
        label = self.targets[idx]
        # The transform is applied twice to the same image; with a stochastic
        # transform the two results form a positive pair.
        views = [self.transform(pil_img) for _ in range(2)]
        return torch.stack(views), label

# 3. Define the transformations, dataloaders, model, optimizer, scheduler and parameters required for training

In [None]:
############ Define the parameters for training ##############


batch_size = 512  # images per pre-training batch; CIFAR10Pair yields two views per image
projection_dim = 128  # output width of the SimCLR projection head
learning_rate =  0.6  # initial SGD learning rate for pre-training
momentum = 0.9  # SGD momentum; also reused by the fine-tuning optimizer further down
weight_decay = 1e-6  # weight decay of the pre-training optimizer
epochs = 100  # epoch count used to size the cosine schedules (T_max) and by the fine-tuning loop


In [None]:
# Pre-training below moves every batch to the GPU, so fail early without one.
assert torch.cuda.is_available(), "SimCLR pre-training in this notebook requires a CUDA device"
cudnn.benchmark = True

######### Define the transformation on the training set

# SimCLR augmentation pipeline: random crop + flip + colour distortion.
# The colour distortion helper defined above was previously never applied.
train_transform = transforms.Compose([transforms.RandomResizedCrop(32),
                                        transforms.RandomHorizontalFlip(p=0.5),
                                        get_color_distortion(s=0.5),
                                        transforms.ToTensor()])

##### get absolute path of data dir

data_dir = "/content"


########## Define the train set and train dataloader

train_set = CIFAR10Pair(root=data_dir,
                        train=True,
                        transform=train_transform,
                        download=True)

# drop_last keeps every batch at exactly `batch_size` pairs.
# pin_memory lets the `.cuda(non_blocking=True)` copy in train() overlap with compute.
train_loader = DataLoader(train_set,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=2,
                            pin_memory=True,
                            drop_last=True)


# Define the base encoder(resnet18) -- load from torchvision.models without pretraining

base_encoder = torchvision.models.resnet18
model = SimCLR(base_encoder, projection_dim=projection_dim)
model = model.cuda()

########### Define the optimizer and scheduler for training #######################

optimizer = torch.optim.SGD(
    model.parameters(),
    learning_rate,
    momentum=momentum,
    weight_decay=weight_decay,
    nesterov=True)


# The scheduler is stepped once per batch in train(), so T_max is the total number of batches.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(len(train_loader) * epochs))


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /content/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:03<00:00, 42680184.18it/s]


Extracting /content/cifar-10-python.tar.gz to /content




In [None]:
def train(args) -> None:
    """Pre-train the global SimCLR `model` with the NT-Xent loss.

    Uses the notebook-level `model`, `train_loader`, `optimizer` and
    `scheduler`. After every epoch whose average loss improves on the best so
    far, the model weights are written to 'simclr_best_epoch.pt'.

    Args:
        args: dict providing 'epochs' (number of passes over `train_loader`)
            and 'temperature' (NT-Xent temperature).
    """
    model.train()
    best_loss = np.inf
    for epoch in range(1, args['epochs'] + 1):
        loss_meter = AverageMeter("SimCLR_loss")
        train_bar = tqdm(train_loader)
        for x, _ in train_bar:  # class labels are not used for contrastive pre-training
            # (B, 2, C, H, W) -> (2B, C, H, W); the two views of a sample stay
            # adjacent (rows 2k and 2k + 1), which is the layout nt_xent expects.
            sizes = x.size()
            x = x.view(sizes[0] * 2, sizes[2], sizes[3], sizes[4]).cuda(non_blocking=True)
            optimizer.zero_grad()

            # Only the projection feeds the contrastive loss.
            _, rep = model(x)

            loss = nt_xent(rep, args['temperature'])
            loss.backward()
            optimizer.step()
            scheduler.step()  # the cosine schedule advances once per batch

            loss_meter.update(loss.item(), x.size(0))
            train_bar.set_description("Train epoch {}, SimCLR loss: {:.4f}".format(epoch, loss_meter.avg))

        # Checkpoint on the epoch average rather than on a single noisy
        # minibatch loss; this also keeps disk writes out of the inner loop.
        if loss_meter.count > 0 and loss_meter.avg < best_loss:
            best_loss = loss_meter.avg
            torch.save(model.state_dict(), 'simclr_best_epoch.pt')
     

In [None]:
# Settings handed to train(). Only 'epochs' and 'temperature' are read inside
# train(); the DataLoader batch size comes from the global `batch_size`.
args = {
    'epochs':  100,
    'batch_size': 512,
    'temperature': 0.5,
}

train(args)

Train epoch 1, SimCLR loss: 5.9324: 100%|██████████| 97/97 [01:38<00:00,  1.02s/it]
Train epoch 2, SimCLR loss: 5.7052: 100%|██████████| 97/97 [01:14<00:00,  1.30it/s]
Train epoch 3, SimCLR loss: 5.6161: 100%|██████████| 97/97 [01:15<00:00,  1.29it/s]
Train epoch 4, SimCLR loss: 5.5487: 100%|██████████| 97/97 [01:13<00:00,  1.31it/s]
Train epoch 5, SimCLR loss: 5.4926: 100%|██████████| 97/97 [01:13<00:00,  1.32it/s]
Train epoch 6, SimCLR loss: 5.4549: 100%|██████████| 97/97 [01:13<00:00,  1.32it/s]
Train epoch 7, SimCLR loss: 5.4258: 100%|██████████| 97/97 [01:13<00:00,  1.32it/s]
Train epoch 8, SimCLR loss: 5.4081: 100%|██████████| 97/97 [01:13<00:00,  1.31it/s]
Train epoch 9, SimCLR loss: 5.3890: 100%|██████████| 97/97 [01:13<00:00,  1.32it/s]
Train epoch 10, SimCLR loss: 5.3787: 100%|██████████| 97/97 [01:13<00:00,  1.31it/s]
Train epoch 11, SimCLR loss: 5.3678: 100%|██████████| 97/97 [01:13<00:00,  1.32it/s]
Train epoch 12, SimCLR loss: 5.3557: 100%|██████████| 97/97 [01:13<00:00, 

# Finetuning and testing the saved model

In [None]:
###########  Do not modify this cell ###########################
# ===============================================================

class LinModel(nn.Module):
    """An encoder followed by a single linear classification layer."""

    def __init__(self, encoder: nn.Module, feature_dim: int, n_classes: int):
        super().__init__()
        self.enc = encoder
        self.feature_dim, self.n_classes = feature_dim, n_classes
        # Linear head mapping encoder features to class logits.
        self.lin = nn.Linear(feature_dim, n_classes)

    def forward(self, x):
        features = self.enc(x)
        return self.lin(features)

In [None]:
###########  Do not modify this cell ###########################
# ===============================================================

def run_epoch(model, dataloader, epoch, optimizer=None, scheduler=None):
    """Run one pass over `dataloader`, training when an optimizer is supplied.

    With an optimizer the model is put in train mode and updated after every
    batch (the scheduler, if given, is stepped per batch as well). Without one
    the model is put in eval mode and only the metrics are computed.

    NOTE(review): the evaluation path does not disable gradient tracking.

    Returns:
        (average loss, average accuracy) over the pass, weighted by batch size.
    """
    if optimizer:
        model.train()
        print("training...........")
        description = "Train epoch {}, loss: {:.4f}, acc: {:.4f}"
    else:
        model.eval()
        print("eval...............")
        description = "Test epoch {}, loss: {:.4f}, acc: {:.4f}"

    loss_meter, acc_meter = AverageMeter('loss'), AverageMeter('acc')
    loader_bar = tqdm(dataloader)
    for x, y in loader_bar:
        x, y = x.cuda(), y.cuda()
        logits = model(x)
        loss = F.cross_entropy(logits, y)

        if optimizer:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if scheduler:
                scheduler.step()

        # Metrics for this batch, weighted by the number of samples it holds.
        n_samples = x.size(0)
        batch_acc = (logits.argmax(dim=1) == y).float().mean()
        loss_meter.update(loss.item(), n_samples)
        acc_meter.update(batch_acc.item(), n_samples)
        loader_bar.set_description(description.format(epoch, loss_meter.avg, acc_meter.avg))

    return loss_meter.avg, acc_meter.avg

In [None]:
############################################ Define train and test transforms #########################################################

train_transform = transforms.Compose([transforms.RandomResizedCrop(32),
                                          transforms.RandomHorizontalFlip(p=0.5),
                                          transforms.ToTensor()])
test_transform = transforms.ToTensor()

data_dir = '/content'

###################################### Define train and test dataloader #########################################

train_set = CIFAR10(root=data_dir, train=True, transform=train_transform, download=False)
test_set = CIFAR10(root=data_dir, train=False, transform=test_transform, download=False)

n_classes = 10

# Shuffle the training data every epoch. Evaluate in large batches: run_epoch
# weights its averages by batch size, so this is only a speed-up.
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, drop_last=True)
test_loader = DataLoader(test_set, batch_size=512, shuffle=False)

################################ Define the model and load the trained weights of the best model ##############################

base_encoder = torchvision.models.resnet18
pre_model = SimCLR(base_encoder, projection_dim=projection_dim).cuda()
pre_model.load_state_dict(torch.load('simclr_best_epoch.pt'))
# The classifier needs one output per class (10), not one per training sample.
model = LinModel(pre_model.enc, feature_dim=pre_model.feature_dim, n_classes=n_classes)
model = model.cuda()

# Fix encoder: requires_grad_() actually switches off gradients for every
# encoder parameter; assigning a `requires_grad` attribute on the module does not.
model.enc.requires_grad_(False)
parameters = [param for param in model.parameters() if param.requires_grad]  # trainable parameters.

############################# Define the optimizer and scheduler ###################################################

optimizer = torch.optim.SGD(
    parameters,
    0.2,   # lr = 0.1 * batch_size / 256, see section B.6 and B.7 of SimCLR paper.
    momentum=momentum,
    weight_decay=0.,
    nesterov=True)

# cosine annealing lr, stepped once per batch by run_epoch
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=(len(train_loader) * epochs))



In [None]:
def finetune() -> tuple:
    """Train and evaluate the global `model` for `epochs` epochs.

    Each epoch runs one training pass over `train_loader` and one evaluation
    pass over `test_loader` via `run_epoch`. Whenever the training loss
    improves on the best so far, the weights are saved to 'simclr_lin_best.pth'.

    Returns:
        (train_losses, train_accuracies, test_losses, test_accuracies): four
        lists holding one entry per epoch.
    """
    optimal_loss = 1e5
    train_losses, test_losses = [], []
    train_accuracies, test_accuracies = [], []
    for epoch in range(1, epochs + 1):
        train_loss, train_acc = run_epoch(model, train_loader, epoch, optimizer, scheduler)
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)
        test_loss, test_acc = run_epoch(model, test_loader, epoch)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)

        # Checkpoint selection uses the training loss only, never the test metrics.
        if train_loss < optimal_loss:
            optimal_loss = train_loss
            torch.save(model.state_dict(), 'simclr_lin_best.pth')

    return train_losses, train_accuracies, test_losses, test_accuracies

In [None]:
# Run the fine-tuning loop, then plot the per-epoch training and test loss.
train_losses, train_accuracies, test_losses, test_accuracies = finetune()
plt.plot(train_losses, label='train loss')  # average training loss of each epoch
plt.plot(test_losses, label='test loss')
plt.legend()
plt.show()

training...........


Train epoch 1, loss: 3.5564, acc: 0.2334: 100%|██████████| 97/97 [01:05<00:00,  1.47it/s]


eval...............


Test epoch 1, loss: 1.8549, acc: 0.3163: 100%|██████████| 10000/10000 [01:20<00:00, 124.42it/s]


training...........


Train epoch 2, loss: 1.8171, acc: 0.3257: 100%|██████████| 97/97 [00:57<00:00,  1.67it/s]


eval...............


Test epoch 2, loss: 1.6685, acc: 0.3973: 100%|██████████| 10000/10000 [01:19<00:00, 125.39it/s]


training...........


Train epoch 3, loss: 1.6869, acc: 0.3794: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 3, loss: 1.5265, acc: 0.4405: 100%|██████████| 10000/10000 [01:20<00:00, 124.07it/s]


training...........


Train epoch 4, loss: 1.6074, acc: 0.4145: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 4, loss: 1.5196, acc: 0.4548: 100%|██████████| 10000/10000 [01:20<00:00, 124.83it/s]


training...........


Train epoch 5, loss: 1.5420, acc: 0.4416: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 5, loss: 1.4734, acc: 0.4657: 100%|██████████| 10000/10000 [01:20<00:00, 124.65it/s]


training...........


Train epoch 6, loss: 1.4648, acc: 0.4694: 100%|██████████| 97/97 [00:58<00:00,  1.67it/s]


eval...............


Test epoch 6, loss: 1.3837, acc: 0.5036: 100%|██████████| 10000/10000 [01:20<00:00, 124.20it/s]


training...........


Train epoch 7, loss: 1.3820, acc: 0.5010: 100%|██████████| 97/97 [00:57<00:00,  1.67it/s]


eval...............


Test epoch 7, loss: 1.2399, acc: 0.5540: 100%|██████████| 10000/10000 [01:20<00:00, 124.09it/s]


training...........


Train epoch 8, loss: 1.3025, acc: 0.5344: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 8, loss: 1.1131, acc: 0.6025: 100%|██████████| 10000/10000 [01:19<00:00, 125.56it/s]


training...........


Train epoch 9, loss: 1.2306, acc: 0.5607: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 9, loss: 1.1738, acc: 0.5926: 100%|██████████| 10000/10000 [01:21<00:00, 123.38it/s]


training...........


Train epoch 10, loss: 1.1659, acc: 0.5853: 100%|██████████| 97/97 [00:58<00:00,  1.66it/s]


eval...............


Test epoch 10, loss: 1.0552, acc: 0.6269: 100%|██████████| 10000/10000 [01:20<00:00, 124.41it/s]


training...........


Train epoch 11, loss: 1.1084, acc: 0.6059: 100%|██████████| 97/97 [00:57<00:00,  1.68it/s]


eval...............


Test epoch 11, loss: 0.9714, acc: 0.6624: 100%|██████████| 10000/10000 [01:15<00:00, 131.88it/s]


training...........


Train epoch 12, loss: 1.0573, acc: 0.6258: 100%|██████████| 97/97 [00:57<00:00,  1.69it/s]


eval...............


Test epoch 12, loss: 0.9120, acc: 0.6816: 100%|██████████| 10000/10000 [01:12<00:00, 137.80it/s]


training...........


Train epoch 13, loss: 1.0145, acc: 0.6418: 100%|██████████| 97/97 [00:57<00:00,  1.69it/s]


eval...............


Test epoch 13, loss: 0.9766, acc: 0.6711: 100%|██████████| 10000/10000 [01:12<00:00, 137.41it/s]


training...........


Train epoch 14, loss: 0.9675, acc: 0.6627:  19%|█▊        | 18/97 [00:10<00:45,  1.72it/s]

In [None]:
plt.plot(train_accuracies, label='train accuracy')  # per-epoch training accuracy returned by finetune()
plt.plot(test_accuracies, label='test accuracy')
plt.legend()
plt.show()