# SimCLR
PyTorch implementation of *SimCLR: A Simple Framework for Contrastive Learning of Visual Representations* by T. Chen et al., with support for the LARS (Layer-wise Adaptive Rate Scaling) optimizer.

[Link to paper](https://arxiv.org/pdf/2002.05709.pdf)


## Set up the repository

In [1]:
!git clone https://github.com/spijkervet/SimCLR.git
%cd SimCLR
!wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
!sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
!pip install  pyyaml --upgrade

fatal: destination path 'SimCLR' already exists and is not an empty directory.
/content/SimCLR
--2020-04-19 06:00:30--  https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
Resolving github.com (github.com)... 140.82.113.3
Connecting to github.com (github.com)|140.82.113.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/246276098/8ae3c180-64bd-11ea-91fe-0f47017fe9be?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200419%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200419T060005Z&X-Amz-Expires=300&X-Amz-Signature=ce2efe4ccc17f59f38c1d180a548b69cf7011831d5372ea79504e91bbd791302&X-Amz-SignedHeaders=host&actor_id=0&repo_id=246276098&response-content-disposition=attachment%3B%20filename%3Dcheckpoint_100.tar&response-content-type=application%2Foctet-stream [following]
--2020-04-19 06:00:30--  https://github-production-release-asset-2e65be.s3.amazo

# Part 1:
## SimCLR pre-training

In [0]:
# Whether to run on a TPU (set in Runtime -> Change Runtime Type).
# Later cells read this flag to decide whether to install/import torch_xla
# and which device to store in args.device.
use_tpu = False

#### Install PyTorch/XLA

In [0]:
if use_tpu:
  # PyTorch/XLA release to install; the trailing "#@param" annotation lists the
  # selectable values (presumably rendered as a Colab form field — confirm).
  VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
  # Fetch the PyTorch/XLA environment setup script and run it for the chosen version.
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION

In [0]:
#!pip install torch_xla

In [5]:
import os
import torch

if use_tpu:
  # imports the torch_xla package for TPU support
  import torch_xla
  import torch_xla.core.xla_model as xm
  # XLA device handle; stored in args.device by the configuration cells when use_tpu is set.
  dev = xm.xla_device()
  print(dev)
  
import torchvision
import argparse

from torch.utils.tensorboard import SummaryWriter

# Optional dependency: NVIDIA apex is only needed for fp16 training.
# `apex` records whether the import succeeded, so the training code can
# check it before touching `amp`.
try:
    from apex import amp
except ImportError:
    apex = False
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )
else:
    apex = True

from model import load_model, save_model
from modules import NT_Xent
from modules.transformations import TransformsSimCLR
from utils import mask_correlated_samples, post_config_hook


Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training


In [0]:
def train(args, train_loader, model, criterion, optimizer, writer):
    """Run one epoch of SimCLR pre-training and return the summed batch loss.

    Args:
        args: Namespace providing ``device`` and ``global_step``; ``fp16`` is
            read when present. ``global_step`` is incremented once per batch.
        train_loader: Sized iterable yielding ``((x_i, x_j), _)`` batches: a
            pair of inputs (the positive pair) plus a label that is ignored.
        model: Callable returning a pair ``(h, z)``; only ``z`` is passed to
            the loss.
        criterion: Callable ``criterion(z_i, z_j)`` returning the loss.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        writer: Logger exposing ``add_scalar(tag, value, step)``.

    Returns:
        Sum of the per-batch loss values over the epoch (not averaged).
    """
    loss_epoch = 0
    # Mixed precision is used only when fp16 was requested AND apex imported.
    # The flag is read first so apex is consulted only when it matters.
    use_amp = getattr(args, "fp16", False) and apex
    for step, ((x_i, x_j), _) in enumerate(train_loader):

        optimizer.zero_grad()
        x_i = x_i.to(args.device)
        x_j = x_j.to(args.device)

        # positive pair, with encoding; the first output (h) is not used here
        _, z_i = model(x_i)
        _, z_j = model(x_j)

        loss = criterion(z_i, z_j)

        if use_amp:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        # Convert the loss to a Python number once and reuse it for printing,
        # logging and accumulation instead of calling .item() repeatedly.
        loss_value = loss.item()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss_value}")

        # Logged once per batch, indexed by the global step counter.
        writer.add_scalar("Loss/train_epoch", loss_value, args.global_step)
        loss_epoch += loss_value
        args.global_step += 1

    return loss_epoch

### Load arguments from `config/config.yaml`

In [0]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

# Load the experiment configuration and expose its entries as attributes of `args`.
config = yaml_config_hook("./config/config.yaml")
args = argparse.Namespace(**config)

# Select the compute device: the XLA device created earlier when running on a TPU,
# otherwise the first CUDA GPU if one is available, else the CPU.
if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Output directory (the TensorBoard run directory is created beneath it).
# exist_ok replaces the separate existence check and keeps the cell re-runnable.
args.out_dir = "logs"
os.makedirs(args.out_dir, exist_ok=True)

In [8]:
### override any configuration parameters here, e.g. to adjust for use on GPUs on the Colab platform:
# These assignments replace the values loaded from config/config.yaml; run this cell
# before the cells below that read args (e.g. args.batch_size for the data loader and the loss).
args.batch_size = 64
args.resnet = "resnet18"
# Show the effective configuration.
pprint(vars(args))

{'batch_size': 64,
 'dataset': 'STL10',
 'device': device(type='cuda', index=0),
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 100,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'out_dir': 'logs',
 'projection_dim': 64,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


### Load dataset into train loader

In [9]:
# Directory the torchvision datasets are downloaded to / read from.
root = "./datasets"

# No custom sampler is used, so the DataLoader below shuffles the data itself.
train_sampler = None

# Every sample is passed through TransformsSimCLR (imported from modules.transformations).
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root, split="unlabeled", download=True, transform=TransformsSimCLR()
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root, download=True, transform=TransformsSimCLR()
    )
else:
    # Name the rejected value so a typo in the config is easy to spot.
    raise NotImplementedError(
        f"Unsupported dataset {args.dataset!r}; expected 'STL10' or 'CIFAR10'"
    )

# drop_last=True keeps every batch at exactly args.batch_size samples.
# NOTE(review): presumably required because the NT_Xent loss is built for a
# fixed batch size — confirm.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./datasets/stl10_binary.tar.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting ./datasets/stl10_binary.tar.gz to ./datasets


### Load the SimCLR model, optimizer and learning rate scheduler

In [0]:
# load_model (from the repository's `model` module) returns the network together with
# an optimizer and a scheduler; the training loop below skips the scheduler when it is falsy.
model, optimizer, scheduler = load_model(args, train_loader)

### Setup TensorBoard for logging experiments

In [0]:
# TensorBoard event files for this run go to <out_dir>/colab.
tb_dir = os.path.join(args.out_dir, "colab")
# exist_ok replaces the separate existence check and keeps the cell re-runnable.
os.makedirs(tb_dir, exist_ok=True)
writer = SummaryWriter(log_dir=tb_dir)

### Create the mask that will remove correlated samples from the negative examples

In [0]:
# Built from the run configuration in `args`; handed to the NT_Xent loss constructor in a later cell.
mask = mask_correlated_samples(args)

### Initialize the criterion (NT-Xent loss)

In [0]:
# NT-Xent loss configured with the batch size, temperature, mask and device set up above.
# NOTE(review): the loss is constructed for args.batch_size, so batches of another size may not be supported — confirm.
criterion = NT_Xent(args.batch_size, args.temperature, mask, args.device)

### Start training

In [0]:
# Counters kept on args: train() advances global_step once per batch,
# this loop advances current_epoch once per epoch.
# NOTE(review): current_epoch starts at 0 even when args.start_epoch > 0 —
# confirm that save_model does not rely on it matching `epoch`.
args.global_step, args.current_epoch = 0, 0

for epoch in range(args.start_epoch, args.epochs):
    # Learning rate in effect during this epoch (read before the scheduler advances).
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

    if scheduler:
        scheduler.step()

    # Periodic checkpoint on every fifth epoch.
    if not epoch % 5:
        save_model(args, model, optimizer)

    # train() returns the summed batch loss; report the per-batch mean.
    mean_loss = loss_epoch / len(train_loader)
    writer.add_scalar("Loss/train", mean_loss, epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)
    print(f"Epoch [{epoch}/{args.epochs}]\t Loss: {mean_loss}\t lr: {round(lr, 5)}")
    args.current_epoch += 1

## end training
# Final checkpoint after the last epoch.
save_model(args, model, optimizer)

Step [0/1562]	 Loss: 4.837167263031006
Step [50/1562]	 Loss: 4.757541656494141
Step [100/1562]	 Loss: 4.5447211265563965
Step [150/1562]	 Loss: 4.601693153381348
Step [200/1562]	 Loss: 4.481614112854004
Step [250/1562]	 Loss: 4.502140998840332
Step [300/1562]	 Loss: 4.440523147583008
Step [350/1562]	 Loss: 4.386401653289795
Step [400/1562]	 Loss: 4.013962745666504
Step [450/1562]	 Loss: 4.128726005554199
Step [500/1562]	 Loss: 4.110790729522705
Step [550/1562]	 Loss: 3.995676279067993
Step [600/1562]	 Loss: 4.009866714477539
Step [650/1562]	 Loss: 3.92177152633667
Step [700/1562]	 Loss: 4.08205509185791
Step [750/1562]	 Loss: 4.016186714172363
Step [800/1562]	 Loss: 3.942352056503296
Step [850/1562]	 Loss: 3.956094741821289
Step [900/1562]	 Loss: 4.050191879272461
Step [950/1562]	 Loss: 4.037079334259033
Step [1000/1562]	 Loss: 3.8202831745147705
Step [1050/1562]	 Loss: 3.9044482707977295
Step [1100/1562]	 Loss: 3.985013961791992
Step [1150/1562]	 Loss: 3.794029712677002
Step [1200/156

## Download last checkpoint to local drive (replace `100` with `args.epochs`)

In [0]:
# Colab-only helper module; this cell will not work outside Google Colab.
from google.colab import files
# The checkpoint name is hard-coded: change `100` if args.epochs differs (see the heading above).
files.download('./logs/checkpoint_100.tar')

# Part 2:
## Linear evaluation using logistic regression, using weights from frozen, pre-trained SimCLR model

In [0]:
import torch
import torchvision
import torchvision.transforms as transforms
import argparse

from experiment import ex
from model import load_model
from utils import post_config_hook

from modules import LogisticRegression


In [0]:
def train(args, loader, simclr_model, model, criterion, optimizer):
    """Train the classifier for one epoch on features from the frozen SimCLR model.

    Args:
        args: Namespace providing ``device``.
        loader: Sized iterable of ``(x, y)`` batches.
        simclr_model: Pre-trained model returning ``(h, z)``; it is called under
            ``torch.no_grad()`` so no gradients flow into it.
        model: Classifier trained on ``h``.
        criterion: Loss called as ``criterion(output, y)``.
        optimizer: Optimizer stepped once per batch.

    Returns:
        ``(loss_epoch, accuracy_epoch)``: sums of the per-batch loss and the
        per-batch accuracy; divide by ``len(loader)`` for epoch means.
    """
    loss_epoch = 0
    accuracy_epoch = 0
    # test() switches the classifier to eval mode and never switches it back;
    # make sure training always runs in train mode, even when this is re-run
    # after an evaluation.
    model.train()
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding (the projection z is not used for classification)
        with torch.no_grad():
            h, _ = simclr_model(x)
            # sizes noted by the original author: h = 512, z = 64

        output = model(h)
        loss = criterion(output, y)

        # Fraction of correctly classified samples in this batch.
        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        # Convert the loss to a Python number once and reuse it.
        loss_value = loss.item()
        loss_epoch += loss_value
        # Progress is reported for every batch.
        print(f"Step [{step}/{len(loader)}]\t Loss: {loss_value}\t Accuracy: {acc}")

    return loss_epoch, accuracy_epoch

In [0]:
def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding
        with torch.no_grad():
            h, z = simclr_model(x)
            # h = 512
            # z = 64

        output = model(h)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()


    return loss_epoch, accuracy_epoch

In [0]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

# Reload the configuration for the evaluation stage and show it.
config = yaml_config_hook("./config/config.yaml")
pprint(config)
args = argparse.Namespace(**config)

# Without a TPU, prefer the first CUDA GPU and fall back to the CPU;
# TPU runs reuse the XLA device created in Part 1.
if not use_tpu:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
  args.device = dev

In [0]:
# Same overrides as in Part 1, so the evaluation settings match the pre-training run.
args.batch_size = 64
args.resnet = "resnet18"
# Presumably tell load_model(..., reload_model=True) which checkpoint to restore
# (directory and epoch number) — confirm against model.load_model.
args.model_path = "logs"
args.epoch_num = 100

### Load dataset into train/test dataloaders

In [0]:
# Directory the torchvision datasets are downloaded to / read from.
root = "./datasets"
# NOTE(review): this normalising transform is applied to CIFAR10 only; the STL10
# branch uses a plain ToTensor(). Confirm which preprocessing the encoder expects.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root, split="train", download=True, transform=torchvision.transforms.ToTensor()
    )
    test_dataset = torchvision.datasets.STL10(
        root, split="test", download=True, transform=torchvision.transforms.ToTensor()
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root, train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root, train=False, download=True, transform=transform
    )
else:
    # Name the rejected value so a typo in the config is easy to spot.
    raise NotImplementedError(
        f"Unsupported dataset {args.dataset!r}; expected 'STL10' or 'CIFAR10'"
    )

# The classifier uses its own batch size (args.logistic_batch_size).
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)

# NOTE(review): drop_last=True also discards the final partial batch of the test
# split, so the reported test metrics exclude those samples — confirm this is intended.
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=args.workers,
)

### Load SimCLR model and load model weights

In [0]:
# reload_model=True presumably restores the pre-trained weights selected by
# args.model_path / args.epoch_num — confirm in model.load_model.
# The two other return values (optimizer and scheduler in Part 1) are discarded here.
simclr_model, _, _ = load_model(args, train_loader, reload_model=True)
simclr_model = simclr_model.to(args.device)
# Put the encoder in eval mode; as the cell's last expression, the returned module is also displayed.
simclr_model.eval()

In [0]:
## Logistic Regression
# Classifier that takes simclr_model.n_features inputs and produces n_classes outputs.
# NOTE(review): n_classes is hard-coded; adjust it when using a dataset with a different number of classes.
# This rebinds the name `model`, which referred to the SimCLR network in Part 1.
n_classes = 10 # stl-10
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(args.device)

In [0]:
# Only the classifier's parameters are handed to the optimizer; the SimCLR model is not updated.
# These assignments rebind the Part 1 names `optimizer` and `criterion`.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [0]:
# Train the classifier; train() returns per-batch sums, so divide by the
# number of batches to report epoch means.
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(
        args, train_loader, simclr_model, model, criterion, optimizer
    )
    n_train_batches = len(train_loader)
    mean_loss = loss_epoch / n_train_batches
    mean_accuracy = accuracy_epoch / n_train_batches
    print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {mean_loss}\t Accuracy: {mean_accuracy}")

# final testing
loss_epoch, accuracy_epoch = test(
    args, test_loader, simclr_model, model, criterion, optimizer
)
n_test_batches = len(test_loader)
print(f"[FINAL]\t Loss: {loss_epoch / n_test_batches}\t Accuracy: {accuracy_epoch / n_test_batches}")