### What is SimpleCLR?
* **SimCLR** is a framework for contrastive learning of visual representations. It learns representations by maximizing agreement between differently augmented views of the same data example via a contrastive loss in the latent space.
* **SimCLR** was presented in the Paper “A Simple Framework for Contrastive Learning of Visual Representations” by Chen et al. from Google Research in 2020. 
* The ideas in this paper are relatively simple and intuitive, but there is also a novel loss function that is key for achieve great performance for self-supervised pre-training.  
![](https://miro.medium.com/max/700/1*NjdVYtL4C2HmV1r22XIweg.png)

In [16]:
# !pip install lightly av

In [7]:
import torch
from torch import nn
import torchvision

from lightly.data import LightlyDataset
from lightly.data import SimCLRCollateFunction
from lightly.loss import NTXentLoss
from lightly.models.modules import SimCLRProjectionHead

In [8]:
class SimCLR(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.projection_head = SimCLRProjectionHead(512, 512, 128)

    def forward(self, x):
        x = self.backbone(x).flatten(start_dim=1)
        z = self.projection_head(x)
        return z

In [9]:
resnet = torchvision.models.resnet34()
backbone = nn.Sequential(*list(resnet.children())[:-1])
model = SimCLR(backbone)

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

SimCLR(
  (backbone): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)

In [23]:
!pip install torch-summary

Collecting torch-summary
  Downloading torch_summary-1.4.5-py3-none-any.whl (16 kB)
Installing collected packages: torch-summary
Successfully installed torch-summary-1.4.5


In [40]:
from torchsummary import summary
print(summary(model, (3, 32, 32), depth=3))

Layer (type:depth-idx)                   Output Shape              Param #
├─Sequential: 1-1                        [-1, 512, 1, 1]           --
|    └─Conv2d: 2-1                       [-1, 64, 16, 16]          9,408
|    └─BatchNorm2d: 2-2                  [-1, 64, 16, 16]          128
|    └─ReLU: 2-3                         [-1, 64, 16, 16]          --
|    └─MaxPool2d: 2-4                    [-1, 64, 8, 8]            --
|    └─Sequential: 2-5                   [-1, 64, 8, 8]            --
|    |    └─BasicBlock: 3-1              [-1, 64, 8, 8]            73,984
|    |    └─BasicBlock: 3-2              [-1, 64, 8, 8]            73,984
|    └─Sequential: 2-6                   [-1, 128, 4, 4]           --
|    |    └─BasicBlock: 3-3              [-1, 128, 4, 4]           230,144
|    |    └─BasicBlock: 3-4              [-1, 128, 4, 4]           295,424
|    └─Sequential: 2-7                   [-1, 256, 2, 2]           --
|    |    └─BasicBlock: 3-5              [-1, 256, 2, 2]       

In [41]:
cifar10 = torchvision.datasets.CIFAR10("datasets/cifar10", download=True)
dataset = LightlyDataset.from_torch_dataset(cifar10)

Files already downloaded and verified


In [42]:
collate_fn = SimCLRCollateFunction(
    input_size=32,
    gaussian_blur=0.,
)

In [43]:
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
    num_workers=8,
)

  cpuset_checked))


In [44]:
criterion = NTXentLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.06)

In [45]:
print("Starting Training")
for epoch in range(3):
    total_loss = 0
    for (x0, x1), _, _ in dataloader:
        x0 = x0.to(device)
        x1 = x1.to(device)
        z0 = model(x0)
        z1 = model(x1)
        loss = criterion(z0, z1)
        total_loss += loss.detach()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    avg_loss = total_loss / len(dataloader)
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

Starting Training


  cpuset_checked))


epoch: 00, loss: 5.22765
epoch: 01, loss: 5.21985
epoch: 02, loss: 5.20385
