<a href="https://colab.research.google.com/github/YoniSchirris/SimCLR-1/blob/master/SimCLR_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SimCLR
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al. With support for the LARS (Layer-wise Adaptive Rate Scaling) optimizer and global batch norm.

[Link to paper](https://arxiv.org/pdf/2002.05709.pdf)


## Setup the repository

In [0]:
!git clone https://github.com/spijkervet/SimCLR.git
%cd SimCLR
!mkdir -p logs && cd logs && wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar && cd ../
!sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
!pip install  pyyaml --upgrade

Cloning into 'SimCLR'...
remote: Enumerating objects: 63, done.[K
remote: Counting objects:   1% (1/63)[Kremote: Counting objects:   3% (2/63)[Kremote: Counting objects:   4% (3/63)[Kremote: Counting objects:   6% (4/63)[Kremote: Counting objects:   7% (5/63)[Kremote: Counting objects:   9% (6/63)[Kremote: Counting objects:  11% (7/63)[Kremote: Counting objects:  12% (8/63)[Kremote: Counting objects:  14% (9/63)[Kremote: Counting objects:  15% (10/63)[Kremote: Counting objects:  17% (11/63)[Kremote: Counting objects:  19% (12/63)[Kremote: Counting objects:  20% (13/63)[Kremote: Counting objects:  22% (14/63)[Kremote: Counting objects:  23% (15/63)[Kremote: Counting objects:  25% (16/63)[Kremote: Counting objects:  26% (17/63)[Kremote: Counting objects:  28% (18/63)[Kremote: Counting objects:  30% (19/63)[Kremote: Counting objects:  31% (20/63)[Kremote: Counting objects:  33% (21/63)[Kremote: Counting objects:  34% (22/63)[Kremote: Counting o

# Part 1:
## SimCLR pre-training

In [0]:
# whether to use a TPU or not (set in Runtime -> Change Runtime Type)
use_tpu = False

#### Install PyTorch/XLA

In [0]:
if use_tpu:
  VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION

In [0]:
import os
import torch

if use_tpu:
  # imports the torch_xla package for TPU support
  import torch_xla
  import torch_xla.core.xla_model as xm
  dev = xm.xla_device()
  print(dev)
  
import torchvision
import argparse

from torch.utils.tensorboard import SummaryWriter

apex = False
try:
    from apex import amp
    apex = True
except ImportError:
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )

from model import load_model, save_model
from modules import NT_Xent
from modules.transformations import TransformsSimCLR
from utils import post_config_hook


Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training


In [0]:
def train(args, train_loader, model, criterion, optimizer, writer):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):

        optimizer.zero_grad()
        x_i = x_i.to(args.device)
        x_j = x_j.to(args.device)

        # positive pair, with encoding
        h_i, z_i = model(x_i)
        h_j, z_j = model(x_j)

        loss = criterion(z_i, z_j)

        if apex and args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        loss_epoch += loss.item()
        args.global_step += 1

    return loss_epoch

### Load arguments from `config/config.yaml`

In [0]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

config = yaml_config_hook("./config/config.yaml")
args = argparse.Namespace(**config)

if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  
args.out_dir = "logs"
if not os.path.exists("logs"):
  os.makedirs("logs")

In [0]:
### override any configuration parameters here, e.g. to adjust for use on GPUs on the Colab platform:
args.batch_size = 64
args.resnet = "resnet18"
pprint(vars(args))

{'batch_size': 64,
 'dataset': 'CIFAR10',
 'device': device(type='cuda', index=0),
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'out_dir': 'logs',
 'pretrain': True,
 'projection_dim': 64,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


### Load dataset into train loader

In [0]:
root = "./datasets"

train_sampler = None

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root, split="unlabeled", download=True, transform=TransformsSimCLR(size=96) # 224 in the original paper
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root, download=True, transform=TransformsSimCLR(size=32) # 224 in the original paper
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./datasets/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./datasets/cifar-10-python.tar.gz to ./datasets


### Load the SimCLR model, optimizer and learning rate scheduler

In [0]:
model, optimizer, scheduler = load_model(args, train_loader)

### Setup TensorBoard for logging experiments

In [0]:
tb_dir = os.path.join(args.out_dir, "colab")
if not os.path.exists(tb_dir):
  os.makedirs(tb_dir)
writer = SummaryWriter(log_dir=tb_dir)

### Create the mask that will remove correlated samples from the negative examples

### Initialize the criterion (NT-Xent loss)

In [0]:
criterion = NT_Xent(args.batch_size, args.temperature, args.device)

### Start training

In [0]:
args.global_step = 0
args.current_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]['lr']
    loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

    if scheduler:
        scheduler.step()

    if epoch % 5 == 0:
        save_model(args, model, optimizer)

    writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)
    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

## end training
save_model(args, model, optimizer)

Step [0/781]	 Loss: 4.385220050811768
Step [50/781]	 Loss: 4.211308002471924
Step [100/781]	 Loss: 4.064598560333252
Step [150/781]	 Loss: 4.12453556060791
Step [200/781]	 Loss: 4.21726131439209
Step [250/781]	 Loss: 4.271238327026367
Step [300/781]	 Loss: 4.390628814697266
Step [350/781]	 Loss: 4.272647380828857
Step [400/781]	 Loss: 4.045458793640137
Step [450/781]	 Loss: 4.2279767990112305
Step [500/781]	 Loss: 4.2589592933654785
Step [550/781]	 Loss: 4.258835792541504
Step [600/781]	 Loss: 4.239417552947998
Step [650/781]	 Loss: 4.1034064292907715
Step [700/781]	 Loss: 4.461949348449707
Step [750/781]	 Loss: 4.1982645988464355
Epoch [0/100]	 Loss: 4.2084955882171355	 lr: 0.0003


## OPTIONAL: Download last checkpoint to local drive (replace `100` with `args.epochs`)

In [0]:
from google.colab import files
files.download('./logs/checkpoint_100.tar')

KeyboardInterrupt: ignored

# Part 2:
## Linear evaluation using logistic regression, using weights from frozen, pre-trained SimCLR model

In [0]:
import torch
import torchvision
import numpy as np
import argparse

from experiment import ex
from model import load_model
from utils import post_config_hook

from modules import LogisticRegression


In [0]:
def train(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
        # if step % 100 == 0:
        #     print(
        #         f"Step [{step}/{len(loader)}]\t Loss: {loss.item()}\t Accuracy: {acc}"
        #     )

    return loss_epoch, accuracy_epoch

In [0]:
def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch



In [0]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

config = yaml_config_hook("./config/config.yaml")
pprint(config)
args = argparse.Namespace(**config)

if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

{'batch_size': 128,
 'dataset': 'CIFAR10',
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'pretrain': True,
 'projection_dim': 64,
 'resnet': 'resnet50',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


In [0]:
args.batch_size = 64
args.dataset = "STL10" # make sure to check this with the (pre-)trained checkpoint
args.resnet = "resnet50" # make sure to check this with the (pre-)trained checkpoint
args.model_path = "logs"
args.epoch_num = 100
args.logistic_epochs = 400

### Load dataset into train/test dataloaders

In [0]:
root = "./datasets"
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root,
        split="train",
        download=True,
        transform=TransformsSimCLR(size=96).test_transform, # 224 in original paper
    )
    test_dataset = torchvision.datasets.STL10(
        root,
        split="test",
        download=True,
        transform=TransformsSimCLR(size=96).test_transform, # 224 in original paper
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root,
        train=True,
        download=True,
        transform=TransformsSimCLR(size=32).test_transform, # 224 in original paper
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root,
        train=False,
        download=True,
        transform=TransformsSimCLR(size=32).test_transform, # 224 in original paper
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=args.workers,
)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./datasets/stl10_binary.tar.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./datasets/stl10_binary.tar.gz to ./datasets
Files already downloaded and verified


### Load SimCLR model and load model weights

In [0]:
simclr_model, _, _ = load_model(args, train_loader, reload_model=True)
simclr_model = simclr_model.to(args.device)
simclr_model.eval()

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (

In [0]:
## Logistic Regression
n_classes = 10 # stl-10 / cifar-10
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(args.device)

In [0]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

### Helper functions to map all input data $X$ to their latent representations $h$ that are used in linear evaluation (they only have to be computed once)

In [0]:
def inference(loader, context_model, device):
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding
        with torch.no_grad():
            h, z = context_model(x)

        h = h.detach()

        feature_vector.extend(h.cpu().detach().numpy())
        labels_vector.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(context_model, train_loader, test_loader, device):
    train_X, train_y = inference(train_loader, context_model, device)
    test_X, test_y = inference(test_loader, context_model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader

In [0]:
print("### Creating features from pre-trained context model ###")
(train_X, train_y, test_X, test_y) = get_features(
    simclr_model, train_loader, test_loader, args.device
)

arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, args.logistic_batch_size
)

### Creating features from pre-trained context model ###
Step [0/19]	 Computing features...
Features shape (4864, 2048)
Step [0/31]	 Computing features...
Step [20/31]	 Computing features...
Features shape (7936, 2048)


In [0]:
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(args, arr_train_loader, simclr_model, model, criterion, optimizer)
    
    if epoch % 10 == 0:
      print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")


# final testing
loss_epoch, accuracy_epoch = test(
    args, arr_test_loader, simclr_model, model, criterion, optimizer
)
print(
    f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}"
)

Epoch [0/400]	 Loss: 1.3062220372651752	 Accuracy: 0.5746299342105263
Epoch [10/400]	 Loss: 0.5964312082842776	 Accuracy: 0.7793996710526315
Epoch [20/400]	 Loss: 0.5432291501446774	 Accuracy: 0.7987253289473685
Epoch [30/400]	 Loss: 0.5082822049918928	 Accuracy: 0.8102384868421053
Epoch [40/400]	 Loss: 0.4818095000166642	 Accuracy: 0.8201069078947368
Epoch [50/400]	 Loss: 0.46029093704725565	 Accuracy: 0.8285361842105263
Epoch [60/400]	 Loss: 0.4420443051739743	 Accuracy: 0.8357319078947368
Epoch [70/400]	 Loss: 0.4261354976578763	 Accuracy: 0.8427220394736842
Epoch [80/400]	 Loss: 0.41198996179982234	 Accuracy: 0.8490953947368421
Epoch [90/400]	 Loss: 0.3992275517237814	 Accuracy: 0.85546875
Epoch [100/400]	 Loss: 0.38758305813136856	 Accuracy: 0.8591694078947368
Epoch [110/400]	 Loss: 0.3768635000053205	 Accuracy: 0.8622532894736842
Epoch [120/400]	 Loss: 0.36692376983793157	 Accuracy: 0.8682154605263158
Epoch [130/400]	 Loss: 0.3576517983486778	 Accuracy: 0.873766447368421
Epoch [1