<a href="https://colab.research.google.com/github/SGSunil/simclr/blob/master/SimCLR_Notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SimCLR
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al. With support for the LARS (Layer-wise Adaptive Rate Scaling) optimizer and global batch norm.

[Link to paper](https://arxiv.org/pdf/2002.05709.pdf)


## Setup the repository

In [2]:
!git clone https://github.com/spijkervet/SimCLR.git
%cd SimCLR
!mkdir -p logs && cd logs && wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar && cd ../
!wget https://raw.githubusercontent.com/Spijkervet/SimCLR/master/requirements.txt
!sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
!pip install  pyyaml --upgrade

Cloning into 'SimCLR'...
remote: Enumerating objects: 531, done.[K
remote: Counting objects: 100% (4/4), done.[K
remote: Compressing objects: 100% (4/4), done.[K
remote: Total 531 (delta 0), reused 0 (delta 0), pack-reused 527[K
Receiving objects: 100% (531/531), 327.97 KiB | 10.93 MiB/s, done.
Resolving deltas: 100% (292/292), done.
/content/SimCLR/SimCLR
--2021-07-25 11:07:48--  https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
Resolving github.com (github.com)... 192.30.255.112
Connecting to github.com (github.com)|192.30.255.112|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-releases.githubusercontent.com/246276098/8ae3c180-64bd-11ea-91fe-0f47017fe9be?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20210725%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20210725T110711Z&X-Amz-Expires=300&X-Amz-Signature=e06d4a21060b20f87ec05ba9fb5e6513bc29a5b69f29ff5f5599bc9db96fc7c5&X-Amz-SignedHead

# Part 1:
## SimCLR pre-training

In [39]:
# whether to use a TPU or not (set in Runtime -> Change Runtime Type)
use_tpu = False
!mkdir save

#### Install PyTorch/XLA

In [4]:
if use_tpu:
  VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
  !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
  !python pytorch-xla-env-setup.py --version $VERSION

In [40]:
import os
import torch
import numpy as np

if use_tpu:
  # imports the torch_xla package for TPU support
  import torch_xla
  import torch_xla.core.xla_model as xm
  dev = xm.xla_device()
  print(dev)
  
import torchvision
import argparse

from torch.utils.tensorboard import SummaryWriter

apex = False
try:
    from apex import amp
    apex = True
except ImportError:
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )

from model import save_model, load_optimizer
from simclr import SimCLR
from simclr.modules import get_resnet, NT_Xent
from simclr.modules.transformations import TransformsSimCLR

Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training


### Load arguments from `config/config.yaml`

In [41]:
from pprint import pprint
import argparse
from utils import yaml_config_hook

parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

args = parser.parse_args([])
args.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [42]:
args.device

device(type='cuda')

In [43]:
### override any configuration parameters here, e.g. to adjust for use on GPUs on the Colab platform:
args.batch_size = 128
args.resnet = "resnet18"
pprint(vars(args))

{'batch_size': 128,
 'dataparallel': 0,
 'dataset': 'CIFAR10',
 'dataset_dir': './datasets',
 'device': device(type='cuda'),
 'epoch_num': 100,
 'epochs': 100,
 'gpus': 1,
 'image_size': 224,
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'save',
 'nodes': 1,
 'nr': 0,
 'optimizer': 'Adam',
 'pretrain': True,
 'projection_dim': 64,
 'reload': False,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 8}


### Load dataset into train loader

In [44]:
torch.manual_seed(args.seed)
np.random.seed(args.seed)

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="unlabeled",
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        download=True,
        transform=TransformsSimCLR(size=args.image_size),
    )
else:
    raise NotImplementedError

if args.nodes > 1:
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.world_size, rank=rank, shuffle=True
    )
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=(train_sampler is None),
    drop_last=True,
    num_workers=args.workers,
    sampler=train_sampler,
)

Files already downloaded and verified


  cpuset_checked))


### Load the SimCLR model, optimizer and learning rate scheduler

In [46]:
# initialize ResNet
encoder = get_resnet(args.resnet, pretrained=False)
n_features = encoder.fc.in_features  # get dimensions of fc layer

# initialize model
model = SimCLR(encoder, args.projection_dim, n_features)
if args.reload:
    model_fp = os.path.join(
        args.model_path, "checkpoint_{}.tar".format(args.epoch_num)
    )
    model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
model = model.to(args.device)

# optimizer / loss
optimizer, scheduler = load_optimizer(args, model)

### Initialize the criterion (NT-Xent loss)

In [47]:
criterion = NT_Xent(args.batch_size, args.temperature, world_size=1)

### Setup TensorBoard for logging experiments

In [48]:
writer = SummaryWriter()

### Train function

In [49]:
def train(args, train_loader, model, criterion, optimizer, writer):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):
        optimizer.zero_grad()
        x_i = x_i.cuda(non_blocking=True)
        x_j = x_j.cuda(non_blocking=True)

        # positive pair, with encoding
        h_i, h_j, z_i, z_j = model(x_i, x_j)

        loss = criterion(z_i, z_j)
        loss.backward()

        optimizer.step()

        print("step: ", step)
        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        loss_epoch += loss.item()
        args.global_step += 1
    return loss_epoch


### Start training

In [50]:
args.global_step = 0
args.current_epoch = 0
for epoch in range(args.start_epoch, args.epochs):
    lr = optimizer.param_groups[0]["lr"]
    loss_epoch = train(args, train_loader, model, criterion, optimizer, writer)

    if scheduler:
        scheduler.step()

    # save every 10 epochs
    if epoch % 10 == 0:
        save_model(args, model, optimizer)

    writer.add_scalar("Loss/train", loss_epoch / len(train_loader), epoch)
    writer.add_scalar("Misc/learning_rate", lr, epoch)
    print(
        f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(train_loader)}\t lr: {round(lr, 5)}"
    )
    args.current_epoch += 1

# end training
save_model(args, model, optimizer)

  cpuset_checked))


step:  0
Step [0/390]	 Loss: 5.5189290046691895
step:  1
step:  2
step:  3
step:  4
step:  5
step:  6
step:  7
step:  8
step:  9
step:  10
step:  11
step:  12
step:  13
step:  14
step:  15
step:  16
step:  17
step:  18
step:  19
step:  20
step:  21
step:  22
step:  23
step:  24
step:  25
step:  26
step:  27
step:  28
step:  29
step:  30
step:  31
step:  32
step:  33
step:  34
step:  35
step:  36
step:  37
step:  38
step:  39
step:  40
step:  41
step:  42
step:  43
step:  44
step:  45
step:  46
step:  47
step:  48
step:  49
step:  50
Step [50/390]	 Loss: 5.371453762054443
step:  51
step:  52
step:  53
step:  54
step:  55
step:  56
step:  57
step:  58
step:  59
step:  60
step:  61
step:  62
step:  63
step:  64
step:  65
step:  66
step:  67
step:  68
step:  69
step:  70
step:  71
step:  72
step:  73
step:  74
step:  75
step:  76
step:  77
step:  78
step:  79
step:  80
step:  81
step:  82
step:  83
step:  84
step:  85
step:  86
step:  87
step:  88
step:  89
step:  90
step:  91
step:  92
st

KeyboardInterrupt: ignored

## OPTIONAL: Download last checkpoint to local drive (replace `100` with `args.epochs`)

In [51]:
from google.colab import files
files.download('checkpoint_100.tar')

FileNotFoundError: ignored

# Part 2:
## Linear evaluation using logistic regression, using weights from frozen, pre-trained SimCLR model

In [52]:
import torch
import torchvision
import numpy as np
import argparse
from simclr.modules import LogisticRegression


In [53]:
def train(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
        # if step % 100 == 0:
        #     print(
        #         f"Step [{step}/{len(loader)}]\t Loss: {loss.item()}\t Accuracy: {acc}"
        #     )

    return loss_epoch, accuracy_epoch

In [54]:
def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        output = model(x)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()

    return loss_epoch, accuracy_epoch



In [None]:
from pprint import pprint
from utils import yaml_config_hook

parser = argparse.ArgumentParser(description="SimCLR")
config = yaml_config_hook("./config/config.yaml")
for k, v in config.items():
    parser.add_argument(f"--{k}", default=v, type=type(v))

args = parser.parse_args([])

if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [69]:
args.batch_size = 64
args.dataset = "STL10" # make sure to check this with the (pre-)trained checkpoint
args.resnet = "resnet50" # make sure to check this with the (pre-)trained checkpoint
args.model_path = "logs"
args.epoch_num = 100
args.logistic_epochs = 500

### Download a pre-trained model for demonstration purposes

In [None]:
!wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar

--2020-07-13 20:47:57--  https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
Resolving github.com (github.com)... 140.82.114.3
Connecting to github.com (github.com)|140.82.114.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://github-production-release-asset-2e65be.s3.amazonaws.com/246276098/8ae3c180-64bd-11ea-91fe-0f47017fe9be?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20200713%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20200713T204757Z&X-Amz-Expires=300&X-Amz-Signature=66ef1af62e159b36feeb4a5199ed257f16a98294da54b7ce2b06d84389026c56&X-Amz-SignedHeaders=host&actor_id=0&repo_id=246276098&response-content-disposition=attachment%3B%20filename%3Dcheckpoint_100.tar&response-content-type=application%2Foctet-stream [following]
--2020-07-13 20:47:58--  https://github-production-release-asset-2e65be.s3.amazonaws.com/246276098/8ae3c180-64bd-11ea-91fe-0f47017fe9be?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-

### Load dataset into train/test dataloaders

In [70]:
if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="train",
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
    test_dataset = torchvision.datasets.STL10(
        args.dataset_dir,
        split="test",
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        train=True,
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
    test_dataset = torchvision.datasets.CIFAR10(
        args.dataset_dir,
        train=False,
        download=True,
        transform=TransformsSimCLR(size=args.image_size).test_transform,
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)


Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./datasets/stl10_binary.tar.gz


HBox(children=(FloatProgress(value=0.0, max=2640397119.0), HTML(value='')))


Extracting ./datasets/stl10_binary.tar.gz to ./datasets
Files already downloaded and verified


  cpuset_checked))


In [71]:

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=args.workers,
)

  cpuset_checked))


### Load ResNet encoder / SimCLR and load model weights

In [72]:
encoder = get_resnet(args.resnet, pretrained=False) # don't load a pre-trained model from PyTorch repo
n_features = encoder.fc.in_features  # get dimensions of fc layer

# load pre-trained model from checkpoint
simclr_model = SimCLR( encoder, args.projection_dim, n_features)
model_fp = os.path.join(
    "logs", "checkpoint_{}.tar".format(args.epoch_num)
)
simclr_model.load_state_dict(torch.load(model_fp, map_location=args.device.type))
simclr_model = simclr_model.to(args.device)
    

In [67]:
args.device.type

'cuda'

In [38]:
args.device.type

'cuda'

In [21]:
args.device

device(type='cuda')

In [73]:
## Logistic Regression
n_classes = 10 # stl-10 / cifar-10
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(args.device)

In [74]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

### Helper functions to map all input data $X$ to their latent representations $h$ that are used in linear evaluation (they only have to be computed once)

In [75]:
def inference(loader, simclr_model, device):
    feature_vector = []
    labels_vector = []
    for step, (x, y) in enumerate(loader):
        x = x.to(device)

        # get encoding
        with torch.no_grad():
            h, _, z, _ = simclr_model(x, x)

        h = h.detach()

        feature_vector.extend(h.cpu().detach().numpy())
        labels_vector.extend(y.numpy())

        if step % 20 == 0:
            print(f"Step [{step}/{len(loader)}]\t Computing features...")

    feature_vector = np.array(feature_vector)
    labels_vector = np.array(labels_vector)
    print("Features shape {}".format(feature_vector.shape))
    return feature_vector, labels_vector


def get_features(context_model, train_loader, test_loader, device):
    train_X, train_y = inference(train_loader, context_model, device)
    test_X, test_y = inference(test_loader, context_model, device)
    return train_X, train_y, test_X, test_y


def create_data_loaders_from_arrays(X_train, y_train, X_test, y_test, batch_size):
    train = torch.utils.data.TensorDataset(
        torch.from_numpy(X_train), torch.from_numpy(y_train)
    )
    train_loader = torch.utils.data.DataLoader(
        train, batch_size=batch_size, shuffle=False
    )

    test = torch.utils.data.TensorDataset(
        torch.from_numpy(X_test), torch.from_numpy(y_test)
    )
    test_loader = torch.utils.data.DataLoader(
        test, batch_size=batch_size, shuffle=False
    )
    return train_loader, test_loader

In [76]:
print("### Creating features from pre-trained context model ###")
(train_X, train_y, test_X, test_y) = get_features(
    simclr_model, train_loader, test_loader, args.device
)

arr_train_loader, arr_test_loader = create_data_loaders_from_arrays(
    train_X, train_y, test_X, test_y, args.logistic_batch_size
)

### Creating features from pre-trained context model ###


  cpuset_checked))


Step [0/19]	 Computing features...
Features shape (4864, 2048)
Step [0/31]	 Computing features...
Step [20/31]	 Computing features...
Features shape (7936, 2048)


In [77]:
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(args, arr_train_loader, simclr_model, model, criterion, optimizer)
    
    if epoch % 10 == 0:
      print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")


# final testing
loss_epoch, accuracy_epoch = test(
    args, arr_test_loader, simclr_model, model, criterion, optimizer
)
print(
    f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}"
)

Epoch [0/500]	 Loss: 1.3823571612960415	 Accuracy: 0.5575657894736842
Epoch [10/500]	 Loss: 0.5902091625489687	 Accuracy: 0.7888569078947368
Epoch [20/500]	 Loss: 0.545104128749747	 Accuracy: 0.8040707236842105
Epoch [30/500]	 Loss: 0.5162652919166967	 Accuracy: 0.8141447368421053
Epoch [40/500]	 Loss: 0.4943905378642835	 Accuracy: 0.8196957236842105
Epoch [50/500]	 Loss: 0.4765512786413494	 Accuracy: 0.826891447368421
Epoch [60/500]	 Loss: 0.46139819998490184	 Accuracy: 0.8330592105263158
Epoch [70/500]	 Loss: 0.4481848243035768	 Accuracy: 0.8384046052631579
Epoch [80/500]	 Loss: 0.43644678592681885	 Accuracy: 0.8429276315789473
Epoch [90/500]	 Loss: 0.4258712718361302	 Accuracy: 0.84765625
Epoch [100/500]	 Loss: 0.4162356100584331	 Accuracy: 0.8507401315789473
Epoch [110/500]	 Loss: 0.4073750972747803	 Accuracy: 0.8544407894736842
Epoch [120/500]	 Loss: 0.39916436295760305	 Accuracy: 0.858141447368421
Epoch [130/500]	 Loss: 0.391505944101434	 Accuracy: 0.86328125
Epoch [140/500]	 Los