# Contrastive Representation Learning
Amirabbas Asadi  
Implementation based on:
 - A Simple Framework for Contrastive Learning of Visual Representations (SimCLR), by Ting Chen, Simon Kornblith, Mohammad Norouzi, and Geoffrey Hinton

In [4]:
import torch
from torch import nn

## Preparing the dataset and augmentation pipeline

In [5]:
from torch.utils.data import Dataset
from torchvision.datasets import MNIST
from torchvision.transforms import RandomAffine
from torchvision.transforms import RandomPerspective
from torchvision.transforms import GaussianBlur
from torchvision.transforms import ToTensor
from torchvision.transforms import Compose, RandomApply


class ContrastiveMNIST(Dataset):
  """MNIST wrapper that yields two independently augmented views per image.

  The underlying dataset is stored in ``./mnist`` (downloaded if missing).
  Labels are discarded: each item is a pair of tensors produced by running
  the same stochastic augmentation pipeline twice on one source image.
  """

  def __init__(self):
    self.mnist = MNIST('./mnist', download=True)

    # The three distortions are applied together as one group (p=0.7) or
    # skipped together; ToTensor always runs last.
    distortions = [
        RandomPerspective(),
        RandomAffine(degrees=(-30, 30)),
        GaussianBlur(kernel_size=3),
    ]
    pipeline = [RandomApply(distortions, p=0.7), ToTensor()]
    self.augmentations = Compose(pipeline)

  def __len__(self):
    """Return the number of images in the wrapped dataset."""
    return len(self.mnist)

  def __getitem__(self, index):
    """Return a tuple of two augmented views of the image at ``index``."""
    image, _label = self.mnist[index]
    # Two separate passes so each view draws its own random augmentation.
    return tuple(self.augmentations(image) for _ in range(2))

In [None]:
from torch.utils.data import DataLoader

batch_size = 64
dataset = ContrastiveMNIST()
# Reshuffle every epoch so each image is contrasted against a different set of
# negatives over time, and drop the incomplete final batch so every step sees
# the same number of pairs.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=True)

In [7]:
class SimCLR(nn.Module):
  """Small convolutional encoder plus projection head with an NT-Xent loss.

  Args:
    h_dim: Width of the representation produced by ``encoder``.
    z_dim: Width of the projection-head output the loss is computed on.
    tau: Temperature that divides the cosine similarities before the softmax.
  """

  def __init__(self, h_dim=64, z_dim=8, tau=1.0):
    super().__init__()

    # Two conv/pool stages on single-channel input followed by a two-layer
    # MLP. LazyLinear infers its input width on the first forward pass, so the
    # flattened feature size is not hard-coded here.
    self.encoder = nn.Sequential(nn.Conv2d(1, 8, 3),
                                 nn.SiLU(),
                                 nn.MaxPool2d(2),
                                 nn.Conv2d(8, 16, 3),
                                 nn.SiLU(),
                                 nn.MaxPool2d(2),
                                 nn.Flatten(),
                                 nn.LazyLinear(h_dim),
                                 nn.SiLU(),
                                 nn.Linear(h_dim, h_dim),
                                 nn.SiLU())

    self.projection_head = nn.Sequential(nn.Linear(h_dim, z_dim),
                                         nn.SiLU(),
                                         nn.Linear(z_dim, z_dim))
    self.tau = tau

  def contrastive_loss(self, t1, t2):
    """Return the NT-Xent loss for two batches of augmented views.

    ``t1[i]`` and ``t2[i]`` must be two views of the same sample; every other
    row of the combined batch acts as a negative for that pair.

    Args:
      t1: First views, batch along dimension 0.
      t2: Second views, same shape as ``t1``.

    Returns:
      A scalar tensor: the loss averaged over both orderings of each pair.

    Raises:
      ValueError: If ``t1`` and ``t2`` do not have the same shape.
    """
    if t1.shape != t2.shape:
      raise ValueError(
          f"t1 and t2 must have the same shape, got {tuple(t1.shape)} "
          f"and {tuple(t2.shape)}")

    # Use the actual number of pairs rather than a module-level constant so a
    # shorter batch (e.g. the last one of an epoch) is sliced correctly.
    n = t1.shape[0]

    t = torch.vstack([t1, t2])
    h = self.encoder(t)
    z = self.projection_head(h)

    # Unit-normalise so the matrix product below is cosine similarity;
    # normalize() guards against division by a zero norm.
    z_norm = nn.functional.normalize(z, dim=1)
    S = torch.mm(z_norm, z_norm.transpose(0, 1)) / self.tau

    # A sample must not count as its own negative: -inf removes the
    # self-similarity term from every softmax denominator.
    S.fill_diagonal_(float("-inf"))
    L = -nn.functional.log_softmax(S, dim=1)

    # Row i (< n) has its positive at column i + n and vice versa, i.e. the
    # diagonals of the two off-diagonal quadrants.
    L1 = torch.diag(L[:n, n:]).mean()
    L2 = torch.diag(L[n:, :n]).mean()
    return 0.5 * (L1 + L2)

## Training the model

In [None]:
# Instantiate the network with a 32-wide encoder output and a 16-wide
# projection; the temperature is left at the constructor default.
model = SimCLR(h_dim=32, z_dim=16)

In [9]:
# Adam over every parameter of the model, learning rate 0.002.
optimizer = torch.optim.Adam(model.parameters(), lr=0.002)

In [None]:
# Train for a fixed number of passes over the contrastive dataset.
epochs = 16
for epoch in range(epochs):
  # Each batch is a pair (t1, t2) of augmented views of the same images.
  for i, (t1, t2) in enumerate(dataloader):
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss = model.contrastive_loss(t1, t2)
    loss.backward()
    optimizer.step()
    # Log the loss of every optimisation step.
    print(loss.item())

## Visualizing the latent space

In [11]:
class MNISTShowcase(Dataset):
  """MNIST (loaded with ``train=False``) yielding ``(tensor, label)`` items.

  Images are only passed through ``ToTensor``; no augmentation is applied.
  The data is read from ``./mnist``.
  """

  def __init__(self):
    self.mnist = MNIST('./mnist', train=False)
    self.to_tensor = ToTensor()

  def __len__(self):
    """Return the number of images in the wrapped dataset."""
    return len(self.mnist)

  def __getitem__(self, index):
    """Return the image at ``index`` as a tensor together with its label."""
    image, label = self.mnist[index]
    return self.to_tensor(image), label

# Labelled dataset and a shuffled loader with batches of up to 1024 images,
# used for the latent-space visualization.
mnist = MNISTShowcase()
mnist_loader = DataLoader(mnist, batch_size=1024, shuffle=True)

In [12]:
# Draw a single batch of images and labels. The built-in next() is the
# Python 3 iterator protocol; iterators are not required to expose .next().
sample, labels = next(iter(mnist_loader))

In [None]:
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt

# Embed the sampled images with the encoder only (the projection head is not
# applied here); gradient tracking is disabled for this forward pass.
with torch.no_grad():
  latent = model.encoder(sample)

# Reduce the encoder outputs with t-SNE (default settings) for plotting.
tsne = TSNE()
z = tsne.fit_transform(latent.detach().numpy())

In [15]:
def plot_representation(z, labels):
  """Scatter-plot an embedding, one series per digit class 0-9.

  Args:
    z: Array of embedded points; columns 0 and 1 are used as x and y.
    labels: Tensor of integer class labels aligned with the rows of ``z``.
  """
  plt.figure(figsize=(12, 8))
  plt.set_cmap("tab10")
  digit_of_point = labels.numpy()

  for digit in range(10):
    # Boolean mask picks the rows of z belonging to the current digit.
    points = z[digit_of_point == digit]
    xs, ys = points[:, 0], points[:, 1]
    plt.scatter(xs, ys, label=str(digit))

  plt.legend()
  plt.show()

In [None]:
# Plot the t-SNE embedding of the sampled batch, grouped by digit label.
plot_representation(z, labels)