<a href="https://colab.research.google.com/github/ArthurCTLin/Workbook/blob/main/SimCLR/SimCLR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Self-supervised Learning with SimCLR (PyTorch)

### Paper Concept
* **Showed the importance of composition of data augmentation :**
    * Composing multiple data augmentation operations is crucial for defining a contrastive prediction task that yields useful representations.
    * Unsupervised contrastive learning benefits from stronger data augmentation than supervised learning.
* **Introduced a learnable nonlinear MLP projection head ($g(*)$ in the left figure below) between the representation and the contrastive loss, which improves the quality of the learned representation.**
* **Normalized embeddings** and an appropriately tuned **temperature parameter** improve representation learning with the contrastive cross-entropy loss.
* **Contrastive learning benefits from larger batch sizes and more training steps/epochs.**
<img src="https://i.imgur.com/hhSyTPd.png" width=50%><img src="https://i.imgur.com/G95735O.gif" width=50%>

$\qquad \qquad \qquad \quad \quad \qquad$**(Figure Source: [Paper](https://arxiv.org/pdf/2002.05709.pdf))** $\quad \qquad \qquad \qquad \qquad \qquad \quad \quad \qquad$**(Figure Source: [SimCLR Github](https://github.com/google-research/simclr))**

### Import Libraries

In [None]:
import numpy as np
import pandas as pd
import shutil, time, os, requests, random, copy

# Torch or Torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
from torchvision.datasets import STL10
from torchvision.datasets import DatasetFolder

# Plotting
import matplotlib.pyplot as plt
from PIL import Image
%matplotlib inline

from tqdm import tqdm

from sklearn.manifold import TSNE

# Reproduction: seed every RNG the notebook imports (Python `random`, NumPy, PyTorch)
# and make cuDNN pick deterministic kernels.
myseed = 42069
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(myseed)
np.random.seed(myseed)
torch.manual_seed(myseed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(myseed)

# Device
# os.cpu_count() returns None when the count cannot be determined; DataLoader
# needs an integer, so fall back to 0 (load in the main process).
NUM_WORKERS = os.cpu_count() or 0
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("Device: ", device)
print("Number of workers: ", NUM_WORKERS)

### Transformation
As mentioned in the paper's contributions, composing multiple data augmentation operations improves the performance of SimCLR. Therefore, the following operations are applied:
* [**ColorJitter:**](https://pytorch.org/vision/main/generated/torchvision.transforms.ColorJitter.html) brightness, contrast, saturation, hue 
* [**RandomResizedCrop**](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomResizedCrop.html#torchvision.transforms.RandomResizedCrop)
* [**RandomHorizontalFlip**](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomHorizontalFlip.html#torchvision.transforms.RandomHorizontalFlip)
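* [**RandomGrayscale**](https://pytorch.org/vision/main/generated/torchvision.transforms.RandomGrayscale.html#torchvision.transforms.RandomGrayscale)
* [**GaussianBlur**](https://pytorch.org/vision/main/generated/torchvision.transforms.GaussianBlur.html) (applied to every view in this implementation)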

$\qquad \qquad \qquad \qquad$![](https://i.imgur.com/w6WVWDj.png)

In [None]:
class TransformsSimCLR:
    """Stochastic augmentation that turns one image into a positive pair.

    Every call runs the same random pipeline twice, independently, and so
    produces two correlated views (x_i, x_j) of the input. SimCLR treats these
    two views as a positive pair.

    Args:
        size: output side length passed to ``RandomResizedCrop``.
        s: color-jitter strength. Brightness, contrast and saturation use
            ``0.8 * s``; hue uses ``0.2 * s``.
        prob: probability of applying color jitter. Grayscale conversion is
            applied with probability ``0.25 * prob``.
    """

    def __init__(self, size, s=1, prob=0.8):
        tf = torchvision.transforms
        jitter = tf.ColorJitter(
            brightness=0.8 * s,
            contrast=0.8 * s,
            saturation=0.8 * s,
            hue=0.2 * s,
        )
        pipeline = [
            tf.RandomResizedCrop(size=size),
            tf.RandomHorizontalFlip(),  # default flip probability is 0.5
            tf.RandomApply([jitter], p=prob),
            tf.RandomGrayscale(p=0.25 * prob),
            tf.GaussianBlur(kernel_size=9),  # not wrapped in RandomApply: every view is blurred
            tf.ToTensor(),
            # the commonly used ImageNet per-channel mean / std
            tf.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ]
        self.train_transform = tf.Compose(pipeline)

    def __call__(self, x):
        """Return two independently augmented views of ``x``."""
        first_view = self.train_transform(x)
        second_view = self.train_transform(x)
        return first_view, second_view


### Dataset

* **STL-10 dataset** is an image recognition dataset for developing unsupervised feature learning, deep learning, self-taught learning algorithms. 
* **10 classes:** airplane, bird, car, cat, deer, dog, horse, monkey, ship, truck.
* **Components:**
  * Training: 5000 images (500 per class and split into 10 pre-defined folds)
  * Testing: 8000 images (800 per class)
  * Unlabeled: 100000 images for unsupervised learning.

In [None]:

DATASET_PATH = './'
# Each item's image is replaced by a pair of augmented views (TransformsSimCLR.__call__).
unlabeled_data = STL10(root=DATASET_PATH, split='unlabeled', download=True,
                       transform=TransformsSimCLR(size=96, s=1, prob=0.8))
train_data_contrast = STL10(root=DATASET_PATH, split='train', download=True,
                            transform=TransformsSimCLR(size=96, s=1, prob=0.8))


batch_size = 256
# Self-supervised training uses the unlabeled split; the labeled 'train' split
# is used as a held-out contrastive validation set.
# drop_last keeps every batch at exactly `batch_size` samples, so every
# contrastive batch has the same number of negatives.
train_loader = DataLoader(unlabeled_data, batch_size=batch_size, shuffle=True,
                          drop_last=True, pin_memory=True, num_workers=NUM_WORKERS)
valid_loader = DataLoader(train_data_contrast, batch_size=batch_size, shuffle=False,
                          drop_last=True, pin_memory=True, num_workers=NUM_WORKERS)

In [None]:
NUM_IMAGES = 6
# STL10 items are (transform_output, label). With TransformsSimCLR the transform
# output is a pair of augmented views, so each sample contributes two images.
views = []
for idx in range(NUM_IMAGES):
    view_pair, _ = unlabeled_data[idx]
    views.extend(view_pair)
imgs = torch.stack(views, dim=0)

# make_grid returns a CHW tensor; matplotlib's imshow expects HWC.
img_grid = torchvision.utils.make_grid(
    imgs, nrow=6, normalize=True, pad_value=0.9
).permute(1, 2, 0)

plt.figure(figsize=(10, 5))
plt.title('Augmented image examples of the STL10 dataset')
plt.imshow(img_grid)
plt.axis('off')
plt.show()
plt.close()

### Model
* *Base encoder $f(*)$ :* ResNet-50
* *Projection head $g(*) :$* 2-layer MLP projection with 128-dimensional latent space.

In [None]:
class SimCLR(nn.Module):
    """Wrap an encoder that exposes an ``fc`` layer with a SimCLR projection head.

    The encoder's existing ``fc`` layer is kept as the output layer of a
    2-layer MLP head: ``Linear(d, d) -> ReLU -> original fc``. Here ``d`` is the
    ``in_features`` of the original ``fc``.
    """

    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
        original_fc = self.encoder.fc
        hidden = original_fc.in_features
        head_layers = [nn.Linear(hidden, hidden), nn.ReLU(), original_fc]
        self.encoder.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the encoder, whose final layer is now the projection head."""
        return self.encoder(x)

In [None]:
projection_dim = 128
# Base encoder f(*): ResNet-50 trained from scratch. `weights=None` replaces the
# deprecated `pretrained=False`. Its final fc layer outputs `projection_dim`
# features; SimCLR(...) keeps that layer as the last layer of the projection head.
encoder = models.resnet50(weights=None, num_classes=projection_dim)
model = SimCLR(encoder).to(device)

### Loss
Randomly sample a batch of N examples and obtain 2N data points (as mentioned above, one example is transformed into two correlated views of the given image).
* The pair from the same example is defined as **positive pair**.
* The **loss function** for a positive pair $(i, j)$ is defined as:
$$l_{i, j} = -\log \frac{\exp(sim(z_{i}, z_{j})/\tau)}{\sum^{2N}_{k=1} 1_{k\neq i}\exp(sim(z_{i}, z_{k})/\tau)}$$
  * $sim(x, y)$ is the cosine similarity, $sim(x, y)=\frac{x^{T}y}{\|x\| \|y\|}$.
  * $1_{k\neq i}$ is an indicator function that evaluates to 1 iff $k\neq i$.
  * $\tau$ : temperature parameter ($\tau = 0.5$ in this implementation).
  * This loss is the **Information Noise-Contrastive Estimation loss (InfoNCE loss)**, which the SimCLR paper calls NT-Xent (normalized temperature-scaled cross-entropy loss).

In [None]:
def info_nce_loss(features, temperature=0.5):
    """Build InfoNCE (NT-Xent) logits and targets for a batch of paired views.

    ``features`` holds 2N rows laid out as
    ``[view_1 of samples 0..N-1; view_2 of samples 0..N-1]``. This is the order
    produced by ``torch.cat(imgs, dim=0)`` in the training loop, so rows ``i``
    and ``(i + N) % 2N`` form a positive pair.

    Args:
        features: tensor of shape (2N, D). N is inferred from it, not taken
            from a global batch size, so partial batches also work.
        temperature: tau; cosine similarities are divided by it.

    Returns:
        (logits, labels):
            ``logits`` has shape (2N, 2N - 1). Column 0 holds the positive
            similarity; the remaining columns hold the 2N - 2 negatives.
            ``labels`` is a zero vector of length 2N, so
            ``CrossEntropyLoss()(logits, labels)`` is the InfoNCE loss.

    Raises:
        ValueError: if the number of rows is zero or odd.
    """
    n_rows = features.shape[0]
    if n_rows == 0 or n_rows % 2:
        raise ValueError(
            f"info_nce_loss expects an even, non-zero number of rows (2N), got {n_rows}"
        )
    n = n_rows // 2
    dev = features.device

    # pair_mask[a, b] is True when rows a and b are views of the same sample.
    sample_ids = torch.arange(n, device=dev).repeat(2)
    pair_mask = sample_ids.unsqueeze(0) == sample_ids.unsqueeze(1)

    # Cosine similarity = dot product of L2-normalised embeddings.
    features = F.normalize(features, dim=1)
    similarity = features @ features.T

    # Drop self-similarity (the diagonal), leaving 2N - 1 entries per row.
    off_diag = ~torch.eye(n_rows, dtype=torch.bool, device=dev)
    pair_mask = pair_mask[off_diag].view(n_rows, -1)
    similarity = similarity[off_diag].view(n_rows, -1)

    positives = similarity[pair_mask].view(n_rows, -1)   # (2N, 1)
    negatives = similarity[~pair_mask].view(n_rows, -1)  # (2N, 2N - 2)

    logits = torch.cat([positives, negatives], dim=1) / temperature
    # The positive sits in column 0 for every row.
    labels = torch.zeros(n_rows, dtype=torch.long, device=dev)
    return logits, labels

### Optimizer (Optional)
* **Layer-wise Adaptive Rate Scaling (LARS)** 
* This optimizer is better suited to training with large batch sizes.
* The source paper: [Large batch training of convolutional networks](https://arxiv.org/abs/1708.03888)
* ***However, due to the limited GPU memory on Kaggle, 256 is the largest batch size I can use, so LARS is not adopted in this implementation (Adam is used instead).***

In [None]:
from torch.optim.optimizer import Optimizer, required
import re

EETA_DEFAULT = 0.001


class LARS(Optimizer):
    """
    Layer-wise Adaptive Rate Scaling for large batch training.
    Introduced by "Large Batch Training of Convolutional Networks" by Y. You,
    I. Gitman, and B. Ginsburg. (https://arxiv.org/abs/1708.03888)

    Each parameter tensor's learning rate is scaled by the trust ratio
    ``eeta * ||w|| / ||g||`` (1.0 when either norm is zero) before a classic
    momentum update. Hyper-parameters are read from each parameter group, so
    groups with their own settings are honoured.

    Limitations:
        - Only classic momentum is implemented.
        - The ``exclude_from_*`` name filters are stored but not yet applied,
          because ``step`` has no access to parameter names.
    """

    def __init__(
        self,
        params,
        lr=required,
        momentum=0.9,
        use_nesterov=False,
        weight_decay=0.0,
        exclude_from_weight_decay=None,
        exclude_from_layer_adaptation=None,
        classic_momentum=True,
        eeta=EETA_DEFAULT,
    ):
        """Constructs a LARSOptimizer.
        Args:
        params: iterable of parameters (or parameter-group dicts) to optimize.
        lr: A `float` for learning rate.
        momentum: A `float` for momentum.
        use_nesterov: A 'Boolean' for whether to use nesterov momentum.
        weight_decay: A `float` for weight decay.
        exclude_from_weight_decay: A list of `string` for variable screening, if
            any of the string appears in a variable's name, the variable will be
            excluded for computing weight decay. For example, one could specify
            the list like ['batch_normalization', 'bias'] to exclude BN and bias
            from weight decay.
        exclude_from_layer_adaptation: Similar to exclude_from_weight_decay, but
            for layer adaptation. If it is None, it will be defaulted the same as
            exclude_from_weight_decay.
        classic_momentum: A `boolean` for whether to use classic (or popular)
            momentum. The learning rate is applied during momentum update in
            classic momentum, but after momentum for popular momentum.
        eeta: A `float` for scaling of learning rate when computing trust ratio.
        """

        self.epoch = 0
        defaults = dict(
            lr=lr,
            momentum=momentum,
            use_nesterov=use_nesterov,
            weight_decay=weight_decay,
            exclude_from_weight_decay=exclude_from_weight_decay,
            exclude_from_layer_adaptation=exclude_from_layer_adaptation,
            classic_momentum=classic_momentum,
            eeta=eeta,
        )

        super(LARS, self).__init__(params, defaults)
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.use_nesterov = use_nesterov
        self.classic_momentum = classic_momentum
        self.eeta = eeta
        self.exclude_from_weight_decay = exclude_from_weight_decay
        # exclude_from_layer_adaptation is set to exclude_from_weight_decay if the
        # arg is None.
        if exclude_from_layer_adaptation:
            self.exclude_from_layer_adaptation = exclude_from_layer_adaptation
        else:
            self.exclude_from_layer_adaptation = exclude_from_weight_decay

    def step(self, epoch=None, closure=None):
        """Perform one LARS update on every parameter that has a gradient.

        Args:
            epoch: optional epoch index. When omitted, the internal counter
                ``self.epoch`` is used and then incremented. The update itself
                does not depend on it.
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The value returned by ``closure``, or ``None``.

        Raises:
            NotImplementedError: if a group has ``classic_momentum=False``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        if epoch is None:
            epoch = self.epoch
            self.epoch += 1

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            eeta = group["eeta"]
            lr = group["lr"]
            use_nesterov = group["use_nesterov"]
            classic_momentum = group["classic_momentum"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                if not classic_momentum:
                    raise NotImplementedError

                param = p.data
                grad = p.grad.data

                # TODO: get param names
                # if self._use_weight_decay(param_name):
                if weight_decay:
                    # Out-of-place, so the caller's p.grad is left untouched.
                    grad = grad.add(param, alpha=weight_decay)

                # TODO: get param names
                # if self._do_layer_adaptation(param_name):
                w_norm = torch.norm(param)
                g_norm = torch.norm(grad)
                # Plain Python branching: no device lookup needed, so this
                # works for CPU tensors as well as CUDA ones.
                if w_norm > 0 and g_norm > 0:
                    trust_ratio = (eeta * w_norm / g_norm).item()
                else:
                    trust_ratio = 1.0
                scaled_lr = lr * trust_ratio

                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    param_state["momentum_buffer"] = torch.zeros_like(param)
                next_v = param_state["momentum_buffer"]
                # v <- momentum * v + scaled_lr * g   (lr applied inside momentum)
                next_v.mul_(momentum).add_(grad, alpha=scaled_lr)

                if use_nesterov:
                    update = momentum * next_v + scaled_lr * grad
                else:
                    update = next_v

                param.sub_(update)

        return loss

    def _use_weight_decay(self, param_name):
        """Whether to use L2 weight decay for `param_name`."""
        if not self.weight_decay:
            return False
        if self.exclude_from_weight_decay:
            for r in self.exclude_from_weight_decay:
                if re.search(r, param_name) is not None:
                    return False
        return True

    def _do_layer_adaptation(self, param_name):
        """Whether to do layer-wise learning rate adaptation for `param_name`."""
        if self.exclude_from_layer_adaptation:
            for r in self.exclude_from_layer_adaptation:
                if re.search(r, param_name) is not None:
                    return False
        return True

### Optimizer, Scheduler and Loss Function Declaration

In [None]:
# OPTIMIZER
for param in model.parameters():
    param.requires_grad = True
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-4)

# SCHEDULER: "decay the learning rate with the cosine decay schedule without restarts".
# The training loop holds the learning rate constant for the first warm-up
# epochs, then calls scheduler.step() once per epoch. T_max is therefore
# measured in epochs, not batches.
# Keep NUM_EPOCHS / WARMUP_EPOCHS in sync with the training cell's epoch count
# and warm-up threshold.
NUM_EPOCHS = 50
WARMUP_EPOCHS = 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=NUM_EPOCHS - WARMUP_EPOCHS,
    eta_min=0,
    last_epoch=-1,
)

criterion = torch.nn.CrossEntropyLoss()

### Training
Because of the limited running time (12 hours) and memory of Kaggle's GPU resources, the number of training epochs is set to 50. Results with more epochs were obtained by resuming training from the saved checkpoint and are shown afterwards.

In [None]:
model_path = "model.ckpt"
epochs = 50
# The cosine scheduler is left untouched for the first `warmup_epochs` epochs
# and stepped once per epoch afterwards.
warmup_epochs = 10

best_loss = float("inf")
train_loss_report = 0.0
valid_loss_report = 0.0
train_loss_record = []
valid_loss_record = []


def _run_epoch(loader, train):
    """Run one pass of `model` over `loader` and return the mean InfoNCE loss.

    When `train` is True, the model is in training mode and the optimizer steps
    after every batch. Otherwise the model is in eval mode with gradients
    disabled.
    """
    model.train(train)
    batch_losses = []
    with torch.set_grad_enabled(train):
        for views, _ in tqdm(loader):
            # `views` holds the two augmented copies of each image. Stacking
            # them gives the [view_1; view_2] batch that info_nce_loss pairs up.
            imgs = torch.cat(views, dim=0).to(device)
            logits, labels = info_nce_loss(model(imgs))
            loss = criterion(logits, labels)
            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            batch_losses.append(loss.item())
    return sum(batch_losses) / len(batch_losses)


for epoch in range(epochs):
    print(f"Epoch [{epoch}/{epochs}]\t")
    stime = time.time()

    train_loss_report = _run_epoch(train_loader, train=True)
    print(f"[ Train | {epoch + 1:03d}/{epochs:03d} ] loss = {train_loss_report:.5f}")

    if epoch >= warmup_epochs:
        scheduler.step()

    valid_loss_report = _run_epoch(valid_loader, train=False)
    print(f"[ Valid | {epoch + 1:03d}/{epochs:03d} ] loss = {valid_loss_report:.5f}")

    time_taken = (time.time() - stime) / 60
    print(f"Epoch [{epoch}/{epochs}]\t Time Taken: {time_taken} minutes")

    if valid_loss_report < best_loss:
        best_loss = valid_loss_report
        # Save everything needed to resume training. 'loss' is the epoch-mean
        # validation loss that made this checkpoint the best so far.
        torch.save({'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'loss': best_loss},
                   model_path)
        print("The model is save!")

    train_loss_record.append(train_loss_report)
    valid_loss_record.append(valid_loss_report)

### Plotting

In [None]:
import matplotlib.pyplot as plt

# One point per epoch for each recorded loss curve.
x = np.arange(len(train_loss_record))
curves = (
    (train_loss_record, "blue", "Train"),
    (valid_loss_record, "red", "Valid"),
)
for values, colour, label in curves:
    plt.plot(x, values, color=colour, label=label)
plt.legend(loc="upper right")
plt.show()

### Result
$\qquad \qquad \qquad \qquad$<img src="https://i.imgur.com/46oniYI.png" width=50%>

$\qquad\quad \qquad \quad \quad\qquad \qquad \qquad \qquad \quad$ **The loss of training and validation**