# SimCLR
PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al. With support for the LARS (Layer-wise Adaptive Rate Scaling) optimizer.

[Link to paper](https://arxiv.org/pdf/2002.05709.pdf)


## Setup the repository

In [1]:
# !git clone https://github.com/spijkervet/SimCLR.git
%cd SimCLR
# !wget https://github.com/Spijkervet/SimCLR/releases/download/1.2/checkpoint_100.tar
# !sh setup.sh || python3 -m pip install -r requirements.txt || exit 1
# !pip install  pyyaml --upgrade

/home/ubuntu/cassava_disease_classification/salomon_exp/SimCLR


In [2]:
import torch
import torchvision
import torchvision.transforms as transforms
import argparse

from experiment import ex
from model import load_model
from utils import post_config_hook

from modules import LogisticRegression

from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
import glob
import numpy as np

from PIL import Image

# Part 1:
## SimCLR pre-training

In [3]:
# whether to use a TPU or not (set in Runtime -> Change Runtime Type)
use_tpu = False

#### Install PyTorch/XLA

In [4]:
if use_tpu:
    VERSION = "20200220" #@param ["20200220","nightly", "xrt==1.15.0"]
    !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
    !python pytorch-xla-env-setup.py --version $VERSION

In [5]:
import os
import torch

if use_tpu:
    # imports the torch_xla package for TPU support
    import torch_xla
    import torch_xla.core.xla_model as xm
    dev = xm.xla_device()
    print(dev)

import torchvision
import argparse

from torch.utils.tensorboard import SummaryWriter

apex = False
try:
    from apex import amp
    apex = True
except ImportError:
    print(
        "Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training"
    )

from model import load_model, save_model
from modules import NT_Xent
from modules.transformations import TransformsSimCLR
from utils import post_config_hook

Install the apex package from https://www.github.com/nvidia/apex to use fp16 for training


In [6]:
# !pip install ipdb

In [7]:
import ipdb

### Load arguments from `config/config.yaml`

In [8]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

config = yaml_config_hook("./config/config.yaml")
args = argparse.Namespace(**config)

if use_tpu:
    args.device = dev
else:
    args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

args.out_dir = "logs"
if not os.path.exists("logs"):
    os.makedirs("logs")

In [9]:
### override any configuration parameters here, e.g. to adjust for use on GPUs on the Colab platform:
args.batch_size = 64
args.resnet = "resnet18"
pprint(vars(args))

{'batch_size': 64,
 'dataset': 'CIFAR10',
 'device': device(type='cuda', index=0),
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'out_dir': 'logs',
 'pretrain': True,
 'projection_dim': 64,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


In [10]:
args.dataset = 'casava'

### Load dataset into train loader

In [11]:
data_path = "../data/train/train"
test_path = "../data/test/test"
extraimage_path = "../data/extraimages/extraimages"

In [12]:
print('Train set:')
class_distrbution = {}
for cls in os.listdir(data_path):
    print('{}:{}'.format(cls, len(os.listdir(os.path.join(data_path, cls)))))
    class_distrbution[cls] =  len(os.listdir(os.path.join(data_path, cls)))
im = Image.open(data_path+'/cgm/train-cgm-738.jpg')
print(im.size)
class_distrbution

Train set:
cmd:2658
healthy:316
cbsd:1443
cbb:466
cgm:773
(500, 500)


{'cmd': 2658, 'healthy': 316, 'cbsd': 1443, 'cbb': 466, 'cgm': 773}

In [13]:

class CassavaDataset(Dataset):
    def __init__(self, path, size, s=1, mutation = False):
        self.classes = os.listdir(path)
        self.path = [f"{path}/{className}" for className in self.classes]
        self.file_list = [glob.glob(f"{x}/*") for x in self.path]
        self.mutation = mutation
        color_jitter = torchvision.transforms.ColorJitter(
            0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s
        )
        
        self.train_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomResizedCrop(size),
                torchvision.transforms.RandomHorizontalFlip(),  # with 0.5 probability
                torchvision.transforms.RandomApply([color_jitter], p=0.8),
                torchvision.transforms.RandomGrayscale(p=0.2),
                torchvision.transforms.ToTensor(),
            ]
        )
        
        self.test_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize((size, size)),
                torchvision.transforms.ToTensor()
            ]
        )
        

        files = []
        class_names = {}
        for i, className in enumerate(self.classes):
            for fileName in self.file_list[i]:
                files.append([i, className, fileName])

                name = str(i)+'-'+className
                if name not in class_names:
                    class_names[name] = 1
                else:
                    class_names[name] += 1
        self.file_list = files
        files = None
    
    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        fileName = self.file_list[idx][2]
        classCategory = self.file_list[idx][0]
        image = Image.open(fileName)

        if self.mutation:
            image1 = self.train_transform(image)
            image2 = self.train_transform(image)
            
            sample = [[image1, image2], classCategory]
        else:
            
            image = self.test_transform(image)
            sample = [image, classCategory]

        return sample

In [15]:
size = 224

train_data = CassavaDataset(data_path, size, s=1, mutation = False)

test_data = CassavaDataset(test_path, size, s=1, mutation = False)

extraimage_data = CassavaDataset(extraimage_path, size, s=1, mutation = True)

#######################################################################
validation_split = 0.2
shuffle_dataset = True
random_seed= 42 #42

# Creating data indices for training and validation splits:
dataset_size = len(train_data)
indices = list(range(dataset_size))
split = int(np.floor(validation_split * dataset_size))

if shuffle_dataset :
    np.random.seed(random_seed)
    np.random.shuffle(indices)

train_indices, val_indices = indices[split:], indices[:split]
########################################################################

train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)

batch_size_train = 64# 125
batch_size_eval = 64 #250
n_workers = 2

train_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_train,
                                             sampler = train_sampler, num_workers = n_workers)

valid_loader = torch.utils.data.DataLoader(train_data, batch_size = batch_size_eval,
                                             sampler = valid_sampler, num_workers = n_workers)

unlabeled_loader = torch.utils.data.DataLoader(extraimage_data, batch_size = batch_size_eval, 
                                              shuffle =shuffle_dataset, num_workers = n_workers)


In [16]:
# next(unlabeled_loader.__iter__())

In [17]:
# len(next(unlabeled_loader.__iter__())[1])

In [18]:
# root = "./datasets"

# train_sampler = None

# if args.dataset == "STL10":
#     train_dataset = torchvision.datasets.STL10(
#         root, split="unlabeled", download=True, transform=TransformsSimCLR(size=96)
#     )
# elif args.dataset == "CIFAR10":
#     train_dataset = torchvision.datasets.CIFAR10(
#         root, download=True, transform=TransformsSimCLR(size=32)
#     )
# else:
#     raise NotImplementedError

# train_loader = torch.utils.data.DataLoader(
#     train_dataset,
#     batch_size=args.batch_size,
#     shuffle=(train_sampler is None),
#     drop_last=True,
#     num_workers=args.workers,
#     sampler=train_sampler,
# )

### Load the SimCLR model, optimizer and learning rate scheduler

In [19]:
pprint(vars(args))

{'batch_size': 64,
 'dataset': 'casava',
 'device': device(type='cuda', index=0),
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 500,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'out_dir': 'logs',
 'pretrain': True,
 'projection_dim': 64,
 'resnet': 'resnet18',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


In [20]:
model, optimizer, scheduler = load_model(args, unlabeled_loader)

In [21]:
model

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [22]:
optimizer

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.0003
    weight_decay: 0
)

In [23]:
scheduler

### Setup TensorBoard for logging experiments

In [24]:
tb_dir = os.path.join(args.out_dir, "colab")
if not os.path.exists(tb_dir):
    os.makedirs(tb_dir)
    
writer = SummaryWriter(log_dir=tb_dir)

### Create the mask that will remove correlated samples from the negative examples

### Initialize the criterion (NT-Xent loss)

In [25]:
args.batch_size

64

In [26]:
criterion = NT_Xent(args.batch_size, args.temperature, args.device)

In [27]:
criterion

NT_Xent(
  (criterion): CrossEntropyLoss()
  (similarity_f): CosineSimilarity()
)

### Start training

In [30]:
def train(args, train_loader, model, criterion, optimizer, writer):
    loss_epoch = 0
    for step, ((x_i, x_j), _) in enumerate(train_loader):

        if x_i.shape[0] != args.batch_size:
            continue
        optimizer.zero_grad()
        x_i = x_i.to(args.device)
        x_j = x_j.to(args.device)

        # positive pair, with encoding
        h_i, z_i = model(x_i)
        h_j, z_j = model(x_j)

#         ipdb.set_trace()
        loss = criterion(z_i, z_j)

        if apex and args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        optimizer.step()

        if step % 50 == 0:
            print(f"Step [{step}/{len(train_loader)}]\t Loss: {loss.item()}")

        writer.add_scalar("Loss/train_epoch", loss.item(), args.global_step)
        loss_epoch += loss.item()
        args.global_step += 1

    return loss_epoch

In [None]:
import pdb, traceback, sys

if __name__ == '__main__':
    try:
        
        args.global_step = 0
        args.current_epoch = 0
        for epoch in range(args.start_epoch, args.epochs):
            lr = optimizer.param_groups[0]['lr']
        #     ipdb.set_trace()
            loss_epoch = train(args, unlabeled_loader, model, criterion, optimizer, writer)

            if scheduler:
                scheduler.step()

            if epoch % 5 == 0:
                save_model(args, model, optimizer)

            writer.add_scalar("Loss/train", loss_epoch / len(unlabeled_loader), epoch)
            writer.add_scalar("Misc/learning_rate", lr, epoch)
            print(
                f"Epoch [{epoch}/{args.epochs}]\t Loss: {loss_epoch / len(unlabeled_loader)}\t lr: {round(lr, 5)}"
            )
            args.current_epoch += 1

        ## end training
        save_model(args, model, optimizer)
    except:
        extype, value, tb = sys.exc_info()
        traceback.print_exc()
        pdb.post_mortem(tb)

Step [0/197]	 Loss: 4.515799522399902
Step [50/197]	 Loss: 4.516963481903076
Step [100/197]	 Loss: 4.4952311515808105
Step [150/197]	 Loss: 4.5351481437683105
Epoch [0/100]	 Loss: 4.49011571274191	 lr: 0.0003
Step [0/197]	 Loss: 4.470291614532471
Step [50/197]	 Loss: 4.598147392272949
Step [100/197]	 Loss: 4.407454967498779
Step [150/197]	 Loss: 4.158841609954834
Epoch [1/100]	 Loss: 4.3127174353236475	 lr: 0.0003
Step [0/197]	 Loss: 4.162941932678223
Step [50/197]	 Loss: 4.267168998718262
Step [100/197]	 Loss: 4.181676864624023
Step [150/197]	 Loss: 4.039258003234863
Epoch [2/100]	 Loss: 4.127113335023677	 lr: 0.0003
Step [0/197]	 Loss: 4.064942359924316
Step [50/197]	 Loss: 3.9076011180877686
Step [100/197]	 Loss: 3.8971030712127686
Step [150/197]	 Loss: 4.091803073883057
Epoch [3/100]	 Loss: 4.026782896312965	 lr: 0.0003
Step [0/197]	 Loss: 4.051982879638672
Step [50/197]	 Loss: 3.994680881500244
Step [100/197]	 Loss: 3.9155969619750977
Step [150/197]	 Loss: 3.856182336807251
Epoch 

In [31]:
# import pdb, traceback, sys

# def bombs():
#     a = []
#     print (a[0])

# if __name__ == '__main__':
#     try:
#         bombs()
#     except:
#         extype, value, tb = sys.exc_info()
#         traceback.print_exc()
#         pdb.post_mortem(tb)

## Download last checkpoint to local drive (replace `100` with `args.epochs`)

In [None]:
from google.colab import files
files.download('./logs/checkpoint_args.epochs.tar') # checkpoint_100.tar

# Part 2:
## Linear evaluation using logistic regression, using weights from frozen, pre-trained SimCLR model

In [None]:
def train(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    for step, (x, y) in enumerate(loader):
        optimizer.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding
        with torch.no_grad():
            h, z = simclr_model(x)
            # h = 512
            # z = 64

        output = model(h)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss.backward()
        optimizer.step()

        loss_epoch += loss.item()
        if step % 1 == 0:
            print(f"Step [{step}/{len(loader)}]\t Loss: {loss.item()}\t Accuracy: {acc}")

    return loss_epoch, accuracy_epoch

In [None]:
def test(args, loader, simclr_model, model, criterion, optimizer):
    loss_epoch = 0
    accuracy_epoch = 0
    model.eval()
    for step, (x, y) in enumerate(loader):
        model.zero_grad()

        x = x.to(args.device)
        y = y.to(args.device)

        # get encoding
        with torch.no_grad():
            h, z = simclr_model(x)
            # h = 512
            # z = 64

        output = model(h)
        loss = criterion(output, y)

        predicted = output.argmax(1)
        acc = (predicted == y).sum().item() / y.size(0)
        accuracy_epoch += acc

        loss_epoch += loss.item()


    return loss_epoch, accuracy_epoch

In [None]:
from pprint import pprint
from utils.yaml_config_hook import yaml_config_hook

config = yaml_config_hook("./config/config.yaml")
pprint(config)
args = argparse.Namespace(**config)

if use_tpu:
  args.device = dev
else:
  args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

{'batch_size': 256,
 'dataset': 'STL10',
 'epoch_num': 100,
 'epochs': 100,
 'fp16': False,
 'fp16_opt_level': 'O2',
 'logistic_batch_size': 256,
 'logistic_epochs': 100,
 'model_path': 'logs/0',
 'normalize': True,
 'optimizer': 'Adam',
 'projection_dim': 64,
 'resnet': 'resnet50',
 'seed': 42,
 'start_epoch': 0,
 'temperature': 0.5,
 'weight_decay': 1e-06,
 'workers': 16}


In [None]:
args.batch_size = 64
args.resnet = "resnet18"
args.model_path = "logs"
args.epoch_num = 100

### Load dataset into train/test dataloaders

In [None]:
root = "./datasets"
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

if args.dataset == "STL10":
    train_dataset = torchvision.datasets.STL10(
        root, split="train", download=True, transform=torchvision.transforms.ToTensor()
    )
    test_dataset = torchvision.datasets.STL10(
        root, split="test", download=True, transform=torchvision.transforms.ToTensor()
    )
elif args.dataset == "CIFAR10":
    train_dataset = torchvision.datasets.CIFAR10(
        root, train=True, download=True, transform=transform
    )
    test_dataset = torchvision.datasets.CIFAR10(
        root, train=False, download=True, transform=transform
    )
else:
    raise NotImplementedError

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=True,
    drop_last=True,
    num_workers=args.workers,
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.logistic_batch_size,
    shuffle=False,
    drop_last=True,
    num_workers=args.workers,
)

Files already downloaded and verified
Files already downloaded and verified


### Load SimCLR model and load model weights

In [None]:
simclr_model, _, _ = load_model(args, train_loader, reload_model=True)
simclr_model = simclr_model.to(args.device)
simclr_model.eval()

SimCLR(
  (encoder): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runn

In [None]:
## Logistic Regression
n_classes = 10 # stl-10
model = LogisticRegression(simclr_model.n_features, n_classes)
model = model.to(args.device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
for epoch in range(args.logistic_epochs):
    loss_epoch, accuracy_epoch = train(args, train_loader, simclr_model, model, criterion, optimizer)
    print(f"Epoch [{epoch}/{args.logistic_epochs}]\t Loss: {loss_epoch / len(train_loader)}\t Accuracy: {accuracy_epoch / len(train_loader)}")

# final testing
loss_epoch, accuracy_epoch = test(args, test_loader, simclr_model, model, criterion, optimizer)
print(f"[FINAL]\t Loss: {loss_epoch / len(test_loader)}\t Accuracy: {accuracy_epoch / len(test_loader)}")

Step [0/19]	 Loss: 2.544018268585205	 Accuracy: 0.078125
Step [1/19]	 Loss: 2.3547680377960205	 Accuracy: 0.15234375
Step [2/19]	 Loss: 2.3515422344207764	 Accuracy: 0.140625
Step [3/19]	 Loss: 2.2883174419403076	 Accuracy: 0.15234375
Step [4/19]	 Loss: 2.2646541595458984	 Accuracy: 0.15625
Step [5/19]	 Loss: 2.1859045028686523	 Accuracy: 0.16796875
Step [6/19]	 Loss: 2.098809242248535	 Accuracy: 0.25
Step [7/19]	 Loss: 2.1023709774017334	 Accuracy: 0.19140625
Step [8/19]	 Loss: 2.0565273761749268	 Accuracy: 0.23046875
Step [9/19]	 Loss: 2.0184667110443115	 Accuracy: 0.2421875
Step [10/19]	 Loss: 1.9349607229232788	 Accuracy: 0.265625
Step [11/19]	 Loss: 2.036658525466919	 Accuracy: 0.203125
Step [12/19]	 Loss: 1.8934686183929443	 Accuracy: 0.2890625
Step [13/19]	 Loss: 1.9491320848464966	 Accuracy: 0.29296875
Step [14/19]	 Loss: 1.9400029182434082	 Accuracy: 0.24609375
Step [15/19]	 Loss: 1.9001213312149048	 Accuracy: 0.26953125
Step [16/19]	 Loss: 1.903915524482727	 Accuracy: 0.28125

KeyboardInterrupt: ignored