<a href="https://colab.research.google.com/github/JonathanSum/Happy-Sugar-Life-Weekly-Training/blob/master/Simclr_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%% capture
! pip install git+https://github.com/PytorchLightning/pytorch-lightning-bolts.git@master --upgrade
! pip install pytorch-lightning==0.9.1rc1 --upgrade

UsageError: Cell magic `%%` not found.


In [None]:
import cv2
import numpy as np
import torchvision.transforms as transforms

from typing import Optional

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.optim import Adam

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateLogger

from pl_bolts.models.self_supervised.resnets import resnet50_bn
from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
from pl_bolts.datamodules import CIFAR10DataModule, STL10DataModule, ImagenetDataModule
from pl_bolts.metrics import mean, accuracy

from pl_bolts.models.self_supervised.evaluator import Flatten
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from pl_bolts.transforms.dataset_normalizations import imagenet_normalization
from pl_bolts.optimizers import LARSWrapper

In [None]:
class SimCLRTrainDataTransform(object):
    """Produce two independently augmented views of one image for SimCLR training.

    The pipeline built in ``__init__`` is, in order: random resized crop,
    random horizontal flip, colour jitter (applied with probability 0.8),
    random grayscale, optional Gaussian blur, ``ToTensor`` and an optional
    normalisation transform.

    Args:
        input_height: target size handed to ``RandomResizedCrop``.
        gaussian_blur: when true, append a ``GaussianBlur`` step whose kernel
            is roughly 10% of ``input_height``.
        jitter_strength: multiplier applied to every colour-jitter range.
        normalize: optional transform appended after ``ToTensor``.
    """

    def __init__(
        self,
        input_height: int = 224,
        gaussian_blur: bool = False,
        jitter_strength: float = 1.,
        normalize: Optional[transforms.Normalize] = None
    ) -> None:

        self.jitter_strength = jitter_strength
        self.input_height = input_height
        self.gaussian_blur = gaussian_blur
        self.normalize = normalize

        # Positional order is brightness, contrast, saturation, hue.
        self.color_jitter = transforms.ColorJitter(
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.8 * self.jitter_strength,
            0.2 * self.jitter_strength
        )

        data_transforms = [
            transforms.RandomResizedCrop(size=self.input_height),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply([self.color_jitter], p=0.8),
            transforms.RandomGrayscale(p=0.2)
        ]

        if self.gaussian_blur:
            # OpenCV only accepts odd kernel sizes, so round an even value up.
            kernel_size = int(0.1 * self.input_height)
            if kernel_size % 2 == 0:
                kernel_size += 1
            data_transforms.append(GaussianBlur(kernel_size=kernel_size, p=0.5))

        data_transforms.append(transforms.ToTensor())

        if self.normalize is not None:
            data_transforms.append(self.normalize)

        self.train_transform = transforms.Compose(data_transforms)

    def __call__(self, sample):
        """Return a pair ``(xi, xj)``: two separate random augmentations of ``sample``."""
        transform = self.train_transform

        # Each call re-samples the random augmentations, giving two distinct views.
        xi = transform(sample)
        xj = transform(sample)

        return xi, xj


class SimCLREvalDataTransform(object):
    """Evaluation-time counterpart of the SimCLR training transform.

    Resizes the image, converts it to a tensor and optionally normalises it.
    Calling the object returns a pair so that evaluation batches have the same
    ``(view_1, view_2)`` structure as training batches.

    Args:
        input_height: size handed to ``transforms.Resize``.
        normalize: optional transform appended after ``ToTensor``.
    """

    def __init__(
        self,
        input_height: int = 224,
        normalize: Optional[transforms.Normalize] = None
    ):
        self.input_height = input_height
        self.normalize = normalize

        data_transforms = [
            transforms.Resize(self.input_height),
            transforms.ToTensor()
        ]

        if self.normalize is not None:
            data_transforms.append(self.normalize)

        self.test_transform = transforms.Compose(data_transforms)

    def __call__(self, sample):
        """Return ``(xi, xj)``, the test transform applied twice to ``sample``."""
        transform = self.test_transform

        # No random step is used here, so both outputs hold the same values;
        # they are still produced separately so the two tensors do not alias.
        xi = transform(sample)
        xj = transform(sample)

        return xi, xj

In [None]:
class GaussianBlur(object):
    """Random Gaussian blur, as described in the SimCLR paper.

    Args:
        kernel_size: side length of the square blur kernel passed to OpenCV
            (OpenCV expects an odd, positive value).
        p: probability of blurring a given sample.
        min: lower bound of the uniformly sampled blur sigma.
        max: upper bound of the uniformly sampled blur sigma.
    """

    def __init__(self, kernel_size, p=0.5, min=0.1, max=2.0):
        self.min = min
        self.max = max

        # The training transform in this notebook passes about 10% of the image height.
        self.kernel_size = kernel_size
        self.p = p

    def __call__(self, sample):
        """Return ``sample`` as a NumPy array, blurred with probability ``self.p``."""
        sample = np.array(sample)

        # Draw once to decide whether this sample is blurred at all.
        prob = np.random.random_sample()

        if prob < self.p:
            # Sigma is sampled uniformly from [min, max).
            sigma = (self.max - self.min) * np.random.random_sample() + self.min
            sample = cv2.GaussianBlur(sample, (self.kernel_size, self.kernel_size), sigma)
        return sample

In [None]:
def nt_xent_loss(out_1, out_2, temperature):
    """Compute the NT-Xent (normalised temperature-scaled cross entropy) loss.

    Args:
        out_1: projections of the first view, one row per sample.
        out_2: projections of the second view, row-aligned with ``out_1``.
        temperature: divisor applied to the similarities before exponentiation.

    Returns:
        A scalar tensor: the mean of ``-log(pos / neg)`` over all ``2 * batch``
        rows, where ``pos`` is the exponentiated similarity of a row with its
        partner view and ``neg`` sums the exponentiated similarities of the
        row with every other row (the partner included, itself excluded).
    """
    out = torch.cat([out_1, out_2], dim=0)
    n_samples = len(out)

    # Full similarity matrix: pairwise dot products between all rows.
    cov = torch.mm(out, out.t().contiguous())
    sim = torch.exp(cov / temperature)

    # Drop the diagonal so a row is never compared with itself.
    mask = ~torch.eye(n_samples, device=sim.device).bool()
    neg = sim.masked_select(mask).view(n_samples, -1).sum(dim=-1)

    # Positive similarity: each row against its counterpart in the other view,
    # duplicated so it lines up with the concatenated batch.
    pos = torch.exp(torch.sum(out_1 * out_2, dim=-1) / temperature)
    pos = torch.cat([pos, pos], dim=0)

    loss = -torch.log(pos / neg).mean()
    return loss


# Backward-compatible alias for the original, misspelled function name.
nt_xnet_loss = nt_xent_loss

In [None]:
class Projection(nn.Module):
    """Projection head mapping encoder feature maps to unit-length embeddings.

    The head pools the feature map to 1x1, flattens it, then applies
    Linear -> BatchNorm1d -> ReLU -> Linear (the last layer has no bias).

    Args:
        input_dim: number of channels in the incoming feature map.
        hidden_dim: width of the hidden linear layer.
        output_dim: size of the returned embedding.
    """

    def __init__(self, input_dim=2048, hidden_dim=2048, output_dim=128):
        super().__init__()
        self.output_dim = output_dim
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        self.model = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            nn.Linear(self.input_dim, self.hidden_dim, bias=True),
            nn.BatchNorm1d(self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.output_dim, bias=False))

    def forward(self, x):
        """Project ``x`` and L2-normalise the result along dimension 1."""
        x = self.model(x)
        return F.normalize(x, dim=1)

In [None]:
class SimCLR(pl.LightningModule):
    """SimCLR model: a ResNet-50 encoder followed by a projection head.

    Both views of each image are encoded, projected and compared with the
    NT-Xent loss. Optimisation uses Adam wrapped in ``LARSWrapper`` with a
    per-step linear-warmup / cosine-annealing learning-rate schedule.
    """

    def __init__(self,
                 batch_size,
                 num_samples,
                 warmup_epochs=10,
                 lr=1e-4,
                 opt_weight_decay=1e-6,
                 loss_temperature=0.5,
                 **kwargs):
        """
        Args:
            batch_size: the batch size
            num_samples: num samples in the dataset
            warmup_epochs: epochs to warmup the lr for
            lr: the optimizer learning rate
            opt_weight_decay: the optimizer weight decay
            loss_temperature: the loss temperature
        """
        super().__init__()
        self.save_hyperparameters()

        self.nt_xent_loss = nt_xent_loss
        self.encoder = self.init_encoder()

        # h -> || -> z
        self.projection = Projection()

    def init_encoder(self):
        """Build the ResNet-50 encoder with a 3x3, stride-1 first convolution."""
        encoder = resnet50_bn(return_all_feature_maps=False)

        # when using cifar10, replace the first conv so image doesn't shrink away
        encoder.conv1 = nn.Conv2d(
            3, 64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False
        )
        return encoder

    def exclude_from_wt_decay(self, named_params, weight_decay, skip_list=('bias', 'bn')):
        """Split trainable parameters into weight-decayed and non-decayed groups.

        Args:
            named_params: iterable of ``(name, parameter)`` pairs.
            weight_decay: decay applied to the regular group.
            skip_list: substrings; a parameter whose name contains any of them
                is placed in the group with zero weight decay.

        Returns:
            Two optimizer parameter-group dicts: the decayed group first, then
            the excluded group. Parameters with ``requires_grad=False`` are
            left out of both.
        """
        params = []
        excluded_params = []

        for name, param in named_params:
            if not param.requires_grad:
                continue
            elif any(layer_name in name for layer_name in skip_list):
                excluded_params.append(param)
            else:
                params.append(param)

        return [
            {'params': params, 'weight_decay': weight_decay},
            {'params': excluded_params, 'weight_decay': 0.}
        ]

    def setup(self, stage):
        """Derive the number of optimisation steps per epoch from the hparams."""
        global_batch_size = self.trainer.world_size * self.hparams.batch_size
        # Clamp to one step so a dataset smaller than one global batch cannot
        # produce a zero-length schedule in configure_optimizers.
        self.train_iters_per_epoch = max(1, self.hparams.num_samples // global_batch_size)

    def configure_optimizers(self):
        """Return the LARS-wrapped Adam optimizer and its per-step LR scheduler."""
        # TRICK 1 (Use lars + filter weights)
        # exclude certain parameters
        parameters = self.exclude_from_wt_decay(
            self.named_parameters(),
            weight_decay=self.hparams.opt_weight_decay
        )

        optimizer = LARSWrapper(Adam(parameters, lr=self.hparams.lr))

        # Trick 2 (after each step)
        # The scheduler is stepped once per batch ('interval': 'step'), so both
        # horizons are converted from epochs to steps. Locals are used so that
        # calling this method again does not compound the stored hparams.
        warmup_steps = self.hparams.warmup_epochs * self.train_iters_per_epoch
        max_steps = self.trainer.max_epochs * self.train_iters_per_epoch

        linear_warmup_cosine_decay = LinearWarmupCosineAnnealingLR(
            optimizer,
            warmup_epochs=warmup_steps,
            max_epochs=max_steps,
            warmup_start_lr=0,
            eta_min=0
        )

        scheduler = {
            'scheduler': linear_warmup_cosine_decay,
            'interval': 'step',
            'frequency': 1
        }

        return [optimizer], [scheduler]

    def forward(self, x):
        """Encode ``x`` (or its first element, if a list) and return the last feature map."""
        if isinstance(x, list):
            x = x[0]

        result = self.encoder(x)
        if isinstance(result, list):
            # get the last one
            result = result[-1]
        return result

    def training_step(self, batch, batch_idx):
        """Compute the contrastive loss for a training batch and log it."""
        loss = self.shared_step(batch, batch_idx)

        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, on_epoch=True)
        return result

    def validation_step(self, batch, batch_idx):
        """Compute the contrastive loss for a validation batch and log it."""
        loss = self.shared_step(batch, batch_idx)

        result = pl.EvalResult(checkpoint_on=loss)
        result.log('avg_val_loss', loss)
        return result

    def shared_step(self, batch, batch_idx):
        """Encode and project both views of the batch and return the NT-Xent loss.

        ``batch`` is unpacked as ``((img1, img2), y)``; the labels are unused.
        """
        (img1, img2), y = batch

        # ENCODE
        # encode -> representations
        # (b, 3, 32, 32) -> (b, 2048, 2, 2)
        h1 = self.encoder(img1)
        h2 = self.encoder(img2)

        # the bolts resnets return a list of feature maps
        if isinstance(h1, list):
            # again, get the last one -- for BOTH views
            h1 = h1[-1]
            h2 = h2[-1]

        # PROJECT
        # img -> E -> h -> || -> z
        # (b, 2048, 2, 2) -> (b, 128)
        z1 = self.projection(h1)
        z2 = self.projection(h2)

        loss = self.nt_xent_loss(z1, z2, self.hparams.loss_temperature)

        return loss

In [None]:
import os
from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator

# init callbacks
def to_device(batch, device):
    """Move the first view and the labels of a ``((img1, img2), y)`` batch to ``device``."""
    (img1, _), y = batch
    img1 = img1.to(device)
    y = y.to(device)
    return img1, y

# z_dim matches the flattened (2048, 2, 2) encoder output noted in SimCLR.shared_step.
online_finetuner = SSLOnlineEvaluator(z_dim=2048 * 2 * 2, num_classes=10)
online_finetuner.to_device = to_device

lr_logger = LearningRateLogger()

callbacks = [online_finetuner, lr_logger]

# pick data
cifar_height = 32
batch_size = 32
num_samples = 32

# init data
# os.getcwd must be called: the data module needs a path, not the function object.
dm = CIFAR10DataModule(os.getcwd(), num_workers=0, batch_size=batch_size)
dm.train_transforms = SimCLRTrainDataTransform(cifar_height)
dm.val_transforms = SimCLREvalDataTransform(cifar_height)

# realize the data
dm.prepare_data()
dm.setup()
# SimCLR expects the number of training SAMPLES; len(dataloader) would be the
# number of batches, which would shrink the LR schedule by a factor of batch_size.
train_samples = len(dm.train_dataloader().dataset)

model = SimCLR(batch_size=batch_size, num_samples=train_samples)
trainer = pl.Trainer(callbacks=callbacks, progress_bar_refresh_rate=10, gpus=1)
trainer.fit(model, dm)


In [None]:
# Load the TensorBoard notebook extension and open it on the lightning_logs/
# directory (presumably where the Trainer above writes its logs -- confirm).
from tensorflow.keras.callbacks import TensorBoard
%load_ext tensorboard
# !rm -rf ./logs/ #to delete previous runs
%tensorboard --logdir lightning_logs/
# NOTE(review): this Keras TensorBoard callback is not referenced anywhere in
# the visible notebook -- confirm whether it (and the TensorFlow import) is needed.
tensorboard = TensorBoard(log_dir="/content/lightning_logs")