# Colab pretrain notebook

## How to use 

It should run as it is provided You are part of the shared folder named SAL. 

Before running make sure : 
* All of the paths match - I have marked them with ```#$01 ``` so just search for them in this file. 
* make sure that the path in the config file (stored in the drive) match the path of the checkpoint (file that ends  ```.pt``` ). 
* If Resuming from checkpoint set flag to ```true```
* If starting a new Run from zero chenge the name of the ```.pt``` file in the config.json file in order not to override past checkpoints

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#$01
path_to_checkpoint = "/content/drive/MyDrive/SAL/"

In [3]:
import argparse
import json
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
from torch.cuda.amp.grad_scaler import GradScaler
from torch.utils.data import DataLoader

In [4]:
class _args(): 
    ...
    #$01
    config = "/content/drive/MyDrive/SAL/pretrain.config.json"
    device = "gpu"
    model = "resnet18"
    num_epochs = 800
    encoder_dim = 2048
    pred_dim = 512
    batch_size = 512
    num_workers=1


args = _args()
device = "cuda"

In [5]:
class BackBoneEncoder(nn.Module) : 
  '''
  SimSiam - Simple Siamese https://arxiv.org/pdf/2011.10566.pdf  
  '''

  def __init__(self, base_encoder, dim, pred_dim, in_pretrain = True):

    super(BackBoneEncoder, self).__init__()
    self.in_pretrain = in_pretrain
    self.encoder = base_encoder(num_classes = dim, zero_init_residual=True) 

    self.last_dim = self.encoder.fc.in_features

    self.encoder.fc = nn.Sequential(
        nn.Linear(self.last_dim, self.last_dim, bias = False), 
        nn.BatchNorm1d(self.last_dim),
        nn.Mish(inplace = True),
        nn.Linear(self.last_dim, self.last_dim, bias = False), 
        nn.BatchNorm1d(self.last_dim),
        nn.Mish(inplace=True),
        self.encoder.fc,
        nn.BatchNorm1d(dim, affine=False)
    )
    self.encoder.fc[6].bias.requires_grad = False

    self.predictor = nn.Sequential(
        nn.Linear(dim, pred_dim, bias = False), 
        nn.BatchNorm1d(pred_dim),
        nn.Mish(inplace = True), 
        nn.Linear(pred_dim, dim) 
    )

  def forward(self, x1, x2 = None) : 

    if x2 is None and self.in_pretrain: 
      raise ValueError("Expected 2 images but got 1 -> Are you in train or pretrain")

    if self.in_pretrain: 
      z1 = self.encoder(x1)
      z2 = self.encoder(x2)

      p1 = self.predictor(z1)
      p2 = self.predictor(z2)

      # Detach as the stop gradient operation
      return p1, p2, z1.detach(), z2.detach()
    
    else :
      return self.encoder(x1)

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.history = dict()
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.history = dict()

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} Avg Over Epoch ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

In [6]:
# Apply augmentations twice for each image
class TwoCropTransform:
    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, image):
        q = self.base_transform(image)
        k = self.base_transform(image)

        return [q, k]


# Training Loop
def train(model, criterion, optimizer, loader, epoch, scaler, scheduler):
    losses = AverageMeter("Loss", ":.4f")

    for _, (images, _) in enumerate(loader):

        optimizer.zero_grad()

        # Conversion to fp16
        with torch.autocast("cuda", dtype=torch.float16):
            images[0] = images[0].to(device, non_blocking=True)
            images[1] = images[1].to(device, non_blocking=True)
            p1, p2, z1, z2 = model(images[0], images[1])
            
            # Loss as described in the original paper, note z's do not contribute to grad
            loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5

        losses.update(loss.item(), images[0].shape[0])

        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    print(f"Epoch : {epoch} , {losses}")
    return losses.avg


def main():

    with open(args.config, "r") as f:
        config = json.load(f)

    model = BackBoneEncoder(
        models.__dict__[args.model],
        args.encoder_dim,
        args.pred_dim,
        in_pretrain=True,
    ).to(device)

    augmentation = [
        transforms.RandomResizedCrop(32, scale=(0.2, 1.0)),
        transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.6),
        transforms.RandomHorizontalFlip(),
        transforms.RandomGrayscale(p=0.2),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ]

    batch_size = args.batch_size
    current_epoch = 0
    num_epochs = args.num_epochs

    train_data = datasets.CIFAR10(
        root="data",
        train=True,
        download=True,
        transform=TwoCropTransform(transforms.Compose(augmentation)),
    )

    loader = DataLoader(
        train_data, batch_size=batch_size, shuffle=True, num_workers=args.num_workers
    )

    criterion = nn.CosineSimilarity(dim=-1).to("cuda")

    optim_config = config["optimizer"]
    optimizer = optim.SGD(
        model.parameters(),
        lr=optim_config["lr"],
        momentum=optim_config["momentum"],
        weight_decay=optim_config["weight_decay"],
    )

    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=100, verbose=True
    )

    if config["checkpoint"]:

        checkpoint = torch.load(config["path_to_checkpoint"])
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        current_epoch = checkpoint["epoch"]
        loss_history = checkpoint["loss"]

        print(
            f"Resuming Training from Epoch {current_epoch}, Last Loss {loss_history[-1]}"
        )

    model.train()

    scaler = GradScaler()

    loss_history = list()

    for epoch in range(current_epoch, num_epochs):
        print(f"Epoch {epoch}")

        avg_epoch_loss = train(
            model, criterion, optimizer, loader, epoch, scaler, scheduler
        )

        scheduler.step(epoch)

        loss_history.append(avg_epoch_loss)

        torch.save(
            {
                "epoch": epoch,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "loss": loss_history,
                "scheduler": scheduler.state_dict(),
            },
            config["path_to_checkpoint"],
        )
        config["checkpoint"] = True

        with open(args.config, "w") as out:
            json.dump(config, out, indent=4)
        print("Checkpoint Saved")


In [7]:
main()

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting data/cifar-10-python.tar.gz to data
Epoch 00000: adjusting learning rate of group 0 to 6.0000e-02.
Epoch 0
Epoch : 0 , Loss -0.5249 Avg Over Epoch (-0.3119)
Epoch 00000: adjusting learning rate of group 0 to 6.0000e-02.
Checkpoint Saved
Epoch 1


KeyboardInterrupt: ignored