# SimCLR implementation #

Implementation following: https://theaisummer.com/simclr/

In [None]:
!pip install torch torchvision pytorch-lightning lightning-bolts

In [4]:
import os

import torch
import torchvision.models as models
from torchvision.datasets import STL10, EuroSAT
from torch.utils.data import DataLoader
from torch.multiprocessing import cpu_count
from pytorch_lightning.callbacks import GradientAccumulationScheduler, ModelCheckpoint
from pytorch_lightning import Trainer
import torchvision.transforms as T
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from pl_bolts.optimizers import LinearWarmupCosineAnnealingLR

import numpy as np

import matplotlib.pyplot as plt

In [None]:
# Per-channel statistics as originally recorded. The magnitudes (tens to
# ~100) show they are expressed on the 0-255 pixel-intensity scale
# (presumably measured on the training images -- TODO confirm the source).
_RAW_MEANS = [87.81586935763889, 96.97416420717593, 103.98142336697049]
_RAW_STDS = [51.67849701591506, 34.908630837585186, 29.465280593587384]

# These values are handed to T.Normalize, which runs *after* T.ToTensor().
# ToTensor rescales image data to [0, 1], so the statistics have to live on
# the same scale; otherwise every normalised pixel collapses to roughly
# -mean/std with almost no variance.
means = [m / 255.0 for m in _RAW_MEANS]
stds = [s / 255.0 for s in _RAW_STDS]


def imshow(img, norm_means, norm_stds):
    """
    Undo a per-channel normalisation and show the image on the screen.

    img is a tensor that was normalised with norm_means / norm_stds; it is
    treated as channel-first, since axis 0 is moved to the end for plotting.
    """
    channel_std = torch.tensor(norm_stds, dtype=torch.float32)
    channel_mean = torch.tensor(norm_means, dtype=torch.float32)

    # Normalize computes (x - m) / s. Picking m = -mean/std and s = 1/std
    # turns that into x * std + mean, i.e. the inverse of the forward step.
    inverse_mean = (-channel_mean / channel_std).tolist()
    inverse_std = (1.0 / channel_std).tolist()
    restored = T.Normalize(inverse_mean, inverse_std)(img)

    # matplotlib expects the channel axis last.
    pixels = np.transpose(restored.numpy(), (1, 2, 0))
    plt.imshow(pixels)
    plt.show()


def reproducibility(config):
    """
    Seed the torch and numpy random generators and pin cuDNN settings.

    Reads config.seed (converted with int()) and config.cuda; when
    config.cuda is truthy the CUDA generator is seeded explicitly as well.
    """
    seed_value = int(config.seed)

    # Ask cuDNN for deterministic kernels and switch off its auto-tuner.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if config.cuda:
        torch.cuda.manual_seed(seed_value)


def device_as(t1, t2):
    """
    Return t1 placed on the device that t2 lives on
    """
    target_device = t2.device
    return t1.to(target_device)


def define_parameter_groups(model, weight_decay, optimizer_name):
    """
    Split the model's parameters into two optimizer parameter groups.

    A parameter is exempt from weight decay and layer adaptation when its
    name contains "bn", or -- only if optimizer_name is "lars" -- when its
    name contains "bias". Everything else lands in the first group, which
    uses the given weight_decay. Parameter order within each group follows
    model.named_parameters().
    """
    skip_bias = optimizer_name == "lars"

    regularized = []
    exempt = []
    for param_name, param in model.named_parameters():
        if "bn" in param_name or (skip_bias and "bias" in param_name):
            exempt.append(param)
        else:
            regularized.append(param)

    decayed_group = {
        "params": regularized,
        "weight_decay": weight_decay,
        "layer_adaptation": True,
    }
    exempt_group = {
        "params": exempt,
        "weight_decay": 0.0,
        "layer_adaptation": False,
    }
    return [decayed_group, exempt_group]


def default(val, def_val):
    """
    Return val unless it is None, in which case return def_val
    """
    if val is None:
        return def_val
    return val

## Augmentation

In [None]:
class Augment:
    """
    Stochastic augmentation pipeline for contrastive training.

    Calling an instance on one example runs the same random pipeline twice
    and returns both results: two correlated views of the example, denoted
    x_i and x_j, which are treated as a positive pair.
    """

    def __init__(self, img_size, norm_means, norm_stds, s=1):
        # s scales the strength of the colour distortion.
        jitter = T.ColorJitter(
            brightness=0.8 * s,
            contrast=0.8 * s,
            saturation=0.8 * s,
            hue=0.2 * s,
        )
        gaussian_blur = T.GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 0.2))

        pipeline = [
            # Random crop (torchvision's default scale range), resized to img_size
            T.RandomResizedCrop(size=img_size),
            # Horizontal flip half of the time
            T.RandomHorizontalFlip(p=0.5),
            # Colour jitter is applied in 80% of the cases
            T.RandomApply([jitter], p=0.8),
            # Gaussian blur is applied in 50% of the cases
            T.RandomApply([gaussian_blur], p=0.5),
            # Grayscale conversion in 20% of the cases
            T.RandomGrayscale(p=0.2),
            T.ToTensor(),
            T.Normalize(mean=norm_means, std=norm_stds),
        ]
        self.train_transform = T.Compose(pipeline)

    def __call__(self, x):
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view

## Model ##

In [None]:
class AddProjection(nn.Module):
    """
    Backbone encoder followed by an MLP projection head.

    Parameters:
        config: object exposing embedding_size, the output width of the
            projection head.
        model: optional backbone. It must have an fc attribute, which is
            replaced by nn.Identity so the backbone emits its features.
            When None, a randomly initialised torchvision ResNet-18 is used.
        mlp_dim: width of the features fed to the projection head. When
            None it is read from the backbone's fc.in_features.
    """

    def __init__(self, config, model=None, mlp_dim=512):
        super().__init__()
        embedding_size = config.embedding_size

        # Build the fallback ResNet-18 only when no backbone was supplied;
        # constructing it unconditionally wastes time and memory and
        # consumes random-number state for a network that is thrown away.
        if model is None:
            model = models.resnet18(weights=None, num_classes=embedding_size)
        self.backbone = model

        # Must be read before fc is swapped out below.
        if mlp_dim is None:
            mlp_dim = self.backbone.fc.in_features
        print("DIM MLP input:", mlp_dim)
        self.backbone.fc = nn.Identity()

        # add mlp projection head
        self.projection = nn.Sequential(
            nn.Linear(in_features=mlp_dim, out_features=mlp_dim),
            nn.BatchNorm1d(mlp_dim),
            nn.ReLU(),
            nn.Linear(in_features=mlp_dim, out_features=embedding_size),
            nn.BatchNorm1d(embedding_size),
        )

    def forward(self, x, return_embedding=False):
        """
        Return the projected output for x, or the raw backbone features
        when return_embedding is True.
        """
        embedding = self.backbone(x)
        if return_embedding:
            return embedding
        return self.projection(embedding)

## Training ##

In [None]:
class SimCLRTraining(pl.LightningModule):
    """
    LightningModule that trains an AddProjection network with InfoNceLoss.

    config supplies img_size, temperature, weight_decay, lr, epochs,
    batch_size and gradient_accumulation_steps.
    """

    def __init__(self, config, norm_means, norm_stds, model=None, feat_dim=512):
        super().__init__()
        self.config = config
        # Kept as an attribute; training_step expects batches that already
        # contain two augmented views and does not call this itself.
        self.augment = Augment(
            config.img_size,
            norm_means=norm_means,
            norm_stds=norm_stds,
        )
        self.model = AddProjection(config, model=model, mlp_dim=feat_dim)

        self.loss = InfoNceLoss(temperature=config.temperature)

    def forward(self, batch, *args, **kwargs) -> torch.Tensor:
        """Run the wrapped model (backbone + projection head) on batch."""
        return self.model(batch)

    def training_step(self, batch, batch_idx, *args, **kwargs) -> torch.Tensor:
        """Compute and log the loss for one batch of ((view_1, view_2), labels)."""
        views, _labels = batch  # labels are not used by the contrastive loss
        view_a, view_b = views
        projection_a = self.model(view_a)
        projection_b = self.model(view_b)
        loss = self.loss(projection_a, projection_b)
        self.log(
            "InfoNCE loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
        )
        return loss

    def configure_optimizers(self):
        """Return the Adam optimizer and its warmup + cosine LR scheduler."""
        cfg = self.config
        decay = cfg.weight_decay
        lr = cfg.lr
        total_epochs = int(cfg.epochs)

        grouped_params = define_parameter_groups(
            self.model, weight_decay=decay, optimizer_name="adam"
        )
        optimizer = Adam(grouped_params, lr=lr, weight_decay=decay)

        print(
            f"Optimizer Adam, "
            f"Learning Rate {lr}, "
            f"Effective batch size {self.config.batch_size * self.config.gradient_accumulation_steps}"
        )

        # NOTE(review): warmup_epochs is fixed at 10 regardless of
        # config.epochs -- confirm this is intended for short runs.
        lr_schedule = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=10,
            max_epochs=total_epochs,
            warmup_start_lr=0.0,
        )
        return [optimizer], [lr_schedule]

## Loss ##

In [None]:
class InfoNceLoss(nn.Module):
    """
    InfoNCE loss as in SimCLR paper

    For every one of the 2N projections the loss is
    -log( exp(sim(z, z_pos) / t) / sum_{k != self} exp(sim(z, z_k) / t) ),
    averaged over all 2N projections, where sim is cosine similarity and
    t is the temperature.
    """

    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def calc_similarity_batch(a, b):
        """
        Return the [2N, 2N] matrix of pairwise cosine similarities between
        the rows of cat([a, b]); a and b are [N, dim] tensors.
        """
        # Unit-normalising the rows and taking one matrix product gives the
        # same cosine similarities without materialising a [2N, 2N, dim]
        # broadcast intermediate.
        representations = F.normalize(torch.cat([a, b], dim=0), p=2, dim=1)
        return representations @ representations.T

    def forward(self, proj_1, proj_2):
        """
        proj_1 and proj_2 are batched embeddings [batch, embedding_dim]
        where corresponding indices are pairs
        z_i, z_j as in the SimCLR paper

        Returns a scalar tensor; lower values mean positive pairs are more
        similar to each other than to the other samples in the batch.
        """
        assert proj_1.shape == proj_2.shape, "Projections' shapes need to match"
        batch_size = proj_1.shape[0]

        # calc_similarity_batch L2-normalises its inputs, so the raw
        # projections can be passed straight in.
        logits = self.calc_similarity_batch(proj_1, proj_2) / self.temperature

        # The positive partner of row k sits batch_size columns away: on the
        # +batch_size diagonal for the first half of the rows and on the
        # -batch_size diagonal for the second half.
        positives = torch.cat(
            [torch.diag(logits, batch_size), torch.diag(logits, -batch_size)],
            dim=0,
        )

        # Exclude each sample's similarity with itself from the denominator.
        # The mask is created directly on the logits' device.
        self_mask = torch.eye(
            2 * batch_size, dtype=torch.bool, device=logits.device
        )
        log_denominator = torch.logsumexp(
            logits.masked_fill(self_mask, float("-inf")), dim=1
        )

        # -log(exp(pos) / sum(exp(others))) == log_denominator - pos.
        # The leading minus sign is what makes minimising this loss pull
        # positive pairs together.
        return (log_denominator - positives).mean()

In [None]:
# Machine setup: checkpoint directory under the current working directory
# and the number of CUDA devices torch can see.
save_model_path = os.path.join(os.getcwd(), "saved_models/")
available_gpus = torch.cuda.device_count()
print("available_gpus:", available_gpus)

# Run setup: base name for checkpoints and whether to resume a previous run.
resume_from_checkpoint = False
filename = "SimCLR_ResNet18_adam"
save_name = f"{filename}.ckpt"


# Model Setup
class Hparams:
    """
    Plain container for the hyper-parameters of a training run.
    """

    def __init__(self):
        # Run control
        self.seed = 1234  # randomness seed
        self.cuda = False  # use nvidia gpu
        self.epochs = 1  # number of training epochs
        self.save = "./saved_models/"  # save checkpoint

        # Data
        self.img_size = 64  # image shape
        self.batch_size = 64
        self.gradient_accumulation_steps = 1  # gradient accumulation steps

        # Optimisation
        self.lr = 1e-3
        self.weight_decay = 1e-6

        # Contrastive objective
        self.embedding_size = 128  # papers value is 128
        self.temperature = 0.5  # 0.1 or 0.5


train_config = Hparams()
reproducibility(train_config)

model = SimCLRTraining(
    config=train_config,
    model=models.resnet18(weights=None),
    feat_dim=512,
    norm_means=means,
    norm_stds=stds,
)

# Each dataset item becomes a pair of augmented views of the same image.
transform = Augment(train_config.img_size, norm_means=means, norm_stds=stds)

dataset = EuroSAT("./", transform=transform, download=True)
# Shuffle so every batch mixes samples from across the dataset. Without it
# batches follow the dataset's stored order (presumably grouped by class for
# a folder-per-class dataset such as this one), which would make most
# in-batch negatives come from the same class.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=train_config.batch_size,
    shuffle=True,
    num_workers=cpu_count(),
)

# Accumulating gradients over several batches simulates a larger batch size.
# The step count comes from the config so it matches the "Effective batch
# size" that the model reports.
accumulator = GradientAccumulationScheduler(
    scheduling={0: train_config.gradient_accumulation_steps}
)

checkpoint_callback = ModelCheckpoint(
    filename=filename,
    dirpath=save_model_path,
    every_n_epochs=2,
    save_last=True,
    save_top_k=2,
    monitor="InfoNCE loss_epoch",
    mode="min",
)

trainer_kwargs = {
    "callbacks": [accumulator, checkpoint_callback],
    "gpus": available_gpus,
    "max_epochs": train_config.epochs,
}
if resume_from_checkpoint:
    # Hparams does not define checkpoint_path, so fall back to the
    # "last.ckpt" file that the checkpoint callback (save_last=True) writes
    # into save_model_path instead of failing with AttributeError.
    trainer_kwargs["resume_from_checkpoint"] = getattr(
        train_config,
        "checkpoint_path",
        os.path.join(save_model_path, "last.ckpt"),
    )
trainer = Trainer(**trainer_kwargs)

trainer.fit(model, data_loader)
trainer.save_checkpoint(save_name)