<a href="https://colab.research.google.com/github/SauravMaheshkar/Self-Supervised-Learning/blob/main/notebooks/DINO/PyTorch_DINO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Author:** [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)

* No Mixed Precision
* No Gradient Clipping
* Doesn't keep the output (last) layer of the head fixed

# 📦  Packages and Basic Setup
---

In [None]:
%%capture
!pip install -U rich

import os
import sys
import datetime
import time
import math
import random
from typing import Sequence, List, Iterable
from PIL import ImageFilter, ImageOps
from collections import defaultdict, deque

import numpy as np
from PIL import Image
from rich import print
from rich.progress import track
import torch
import torch.nn as nn
import torch.distributed as dist
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision import models as torchvision_models

In [None]:
#@title ⚙ Configuration
saveckp_freq = 2 #@param {type: "number"}
random_seed = 42 #@param {type: "number"}

# ============ Random Seed ... ==========
def seed_everything(seed: int) -> None:
    """
    Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    exports ``PYTHONHASHSEED``, and configures cuDNN for deterministic runs.
    Note that ``PYTHONHASHSEED`` only affects interpreters started afterwards
    (e.g. DataLoader worker subprocesses); the current process's hash seed is
    fixed at start-up.

    Args:
        seed: value used to seed every generator.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # seed all visible GPUs, not just the current device
    torch.cuda.manual_seed_all(seed)
    # `deterministic` alone is not enough: with benchmark mode on, cuDNN's
    # autotuner may select different algorithms from run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

seed_everything(seed=random_seed)

# 🆘 Utility Classes and Functions

---

## 🖖 Utilities for Data Augmentation

In [None]:
class GaussianBlur(object):
    """
    Randomly apply a Gaussian blur with a random radius to a PIL image.

    Args:
        p: probability of applying the blur. ``p=0`` never blurs and
            ``p=1`` always blurs.
        radius_min: lower bound of the uniformly drawn blur radius.
        radius_max: upper bound of the uniformly drawn blur radius.
    """
    def __init__(self, p: float=0.5, radius_min: float =0.1, radius_max: float =2.) -> None:
        self.prob = p
        self.radius_min = radius_min
        self.radius_max = radius_max

    def __call__(self, img):
        # random.random() lies in [0, 1), so a strict "<" fires with
        # probability exactly p (never for p=0), matching Solarization.
        if not random.random() < self.prob:
            return img

        radius = random.uniform(self.radius_min, self.radius_max)
        return img.filter(ImageFilter.GaussianBlur(radius=radius))

class Solarization(object):
    """
    Randomly solarize a PIL image.

    Args:
        p: probability of applying ``ImageOps.solarize`` on each call.
    """
    def __init__(self, p: float) -> None:
        self.p = p

    def __call__(self, img):
        should_solarize = random.random() < self.p
        return ImageOps.solarize(img) if should_solarize else img

## ✂️ Multi-Crop Strategy

In the multi-crop strategy, a set of $V$ different views of a given image $x$ is generated, i.e. $\large \{ x_i \}_{i=1}^{V}$. This set contains two "**global**" views (standard-resolution crops), say $\large x_1^g$ and $\large x_2^g$, and several "**local**" views (low-resolution crops). Using low-resolution crops ensures only a small increase in compute cost. All crops (global and local) are passed through the student, whereas only the global views are passed through the teacher, thereby encouraging "**local-to-global**" correspondences. The following loss is minimized:

$$
\huge \displaystyle\min_{\theta_s} \hspace{1em} \sum_{x \in \{ x_1^g, x_2^g \}} \hspace{0.5em} \sum_{\substack{x' \in \{ x_i \}_{i=1}^{V} \\ x' \neq x}} \hspace{0.75em} H(\, P_t(x), P_s(x')\,)
$$

In [None]:
class MultiCropWrapper(nn.Module):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.

    Args:
        backbone: feature extractor; its ``fc`` and ``head`` attributes are
            replaced by ``nn.Identity`` so it returns features, not logits.
        head: module applied once to the concatenated backbone features.
    """
    def __init__(self, backbone: nn.Module, head: nn.Module) -> None:
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        backbone.fc, backbone.head = nn.Identity(), nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x: List) -> torch.Tensor:
        """
        Args:
            x: a tensor, or a list/tuple of tensors (crops). Consecutive crops
                with the same last-dimension size are batched together.

        Returns:
            The head output for all crops, in input order, concatenated along
            the batch dimension.
        """
        if not isinstance(x, (list, tuple)):
            x = [x]
        # Lengths of runs of consecutive crops sharing the same width; each
        # run becomes one backbone forward pass.
        _, run_lengths = torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )
        end_indices = torch.cumsum(run_lengths, 0).tolist()

        # Collect per-resolution features and concatenate once: avoids
        # quadratic re-copying and keeps the backbone's output dtype
        # (a float32 accumulator would silently up-cast fp16/bf16 outputs).
        features = []
        start_idx = 0
        for end_idx in end_indices:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            # The output is a tuple with XCiT model. See:
            # https://github.com/facebookresearch/xcit/blob/master/xcit.py#L404-L405
            if isinstance(_out, tuple):
                _out = _out[0]
            features.append(_out)
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        return self.head(torch.cat(features))

## 📉 The Loss Function

In standard knowledge distillation, given an input image $x$, the teacher network produces a vector of scores $\large s^t(x) = [\,s^{t}_{1}(x),s^{t}_{2}(x), ..., s^{t}_{K}(x) \,]$, which are then converted into ("**soft**") probabilities $\large p^{t}_{k}(x) = \frac{e^{s^{t}_{k}(x)}}{\sum_j e^{s^{t}_{j}(x)}}$. These probabilities are usually softened using temperature scaling, and the loss the student is trained with is a linear combination of the cross-entropy loss $\mathbb{L}_{cls}$ and a Knowledge Distillation loss $\mathbb{L}_{KD}$. DINO, being self-supervised, uses no labels and keeps only a distillation-style cross-entropy between teacher and student outputs (see `DINOLoss` below). The two losses are, viz.:

$$
\huge \displaystyle \mathbb{L}_{KD} = - \tau^2 \sum_{k} \tilde{p}^{t}_{k} (x) \log \tilde{p}^{s}_{k}(x)
$$
$$
\huge \displaystyle \mathbb{L} = \alpha \mathbb{L}_{cls} + (1-\alpha) \mathbb{L}_{KD}
$$

In [None]:
class DINOLoss(nn.Module):
    """
    DINO self-distillation loss.

    For each of the two teacher chunks (global crops), computes the
    cross-entropy against the student distribution of every *other* crop and
    averages over all such pairs. Teacher outputs are centered and sharpened
    with a per-epoch temperature; the center is updated by EMA after each call.

    Args:
        out_dim: size of the head output (and of the ``center`` buffer).
        ncrops: number of crops the student output is split into.
        warmup_teacher_temp: teacher temperature at epoch 0.
        teacher_temp: teacher temperature once the warm-up is over.
        warmup_teacher_temp_epochs: number of epochs of linear temperature
            warm-up from ``warmup_teacher_temp`` to ``teacher_temp``.
        nepochs: total length of the per-epoch temperature schedule.
        student_temp: student softmax temperature.
        center_momentum: EMA momentum for the teacher center.
    """
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super(DINOLoss, self).__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # The teacher temperature is warmed up: a temperature that is too
        # high at the very beginning makes training unstable.
        warmup_part = np.linspace(warmup_teacher_temp, teacher_temp,
                                  warmup_teacher_temp_epochs)
        constant_part = np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        self.teacher_temp_schedule = np.concatenate((warmup_part, constant_part))

    def forward(self, student_output, teacher_output, epoch):
        """
        Cross-entropy between softmax outputs of the teacher and student networks.
        """
        student_chunks = (student_output / self.student_temp).chunk(self.ncrops)

        # teacher centering and sharpening
        temperature = self.teacher_temp_schedule[epoch]
        teacher_probs = F.softmax((teacher_output - self.center) / temperature, dim=-1)
        teacher_chunks = teacher_probs.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for t_idx, t_probs in enumerate(teacher_chunks):
            for s_idx, s_logits in enumerate(student_chunks):
                if s_idx != t_idx:  # never compare a view with itself
                    per_sample = torch.sum(-t_probs * F.log_softmax(s_logits, dim=-1), dim=-1)
                    total_loss += per_sample.mean()
                    n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        EMA update of the center subtracted from the teacher output.
        """
        # Single-process version: the cross-process all_reduce is disabled here.
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True) / len(teacher_output)
        keep = self.center * self.center_momentum
        self.center = keep + batch_center * (1 - self.center_momentum)

## 🐙 Head

In [None]:
class DINOHead(nn.Module):
    """
    Projection head applied on top of the backbone features.

    An MLP of ``nlayers`` Linear layers (GELU between them, optional
    BatchNorm1d after each hidden Linear) maps the input to a
    ``bottleneck_dim`` embedding. The embedding is L2-normalized and fed to a
    weight-normalized, bias-free ``Linear(bottleneck_dim, out_dim)``.

    Args:
        in_dim: size of the incoming features.
        out_dim: size of the head output.
        use_bn: insert ``BatchNorm1d`` after each hidden Linear layer.
        norm_last_layer: if True, the weight-norm magnitude ``weight_g`` of
            the last layer is frozen at 1.
        nlayers: number of Linear layers in the MLP (values below 1 act as 1).
        hidden_dim: width of the hidden layers.
        bottleneck_dim: size of the normalized embedding.
    """
    def __init__(self, in_dim: int, out_dim: int, use_bn: bool=False, norm_last_layer: bool=True, nlayers: int=3, hidden_dim: int=2048, bottleneck_dim: int=256):
        super().__init__()
        depth = max(nlayers, 1)

        def hidden_stage(fan_in: int) -> list:
            # Linear -> [BatchNorm1d] -> GELU, always hidden_dim wide
            stage = [nn.Linear(fan_in, hidden_dim)]
            if use_bn:
                stage.append(nn.BatchNorm1d(hidden_dim))
            stage.append(nn.GELU())
            return stage

        if depth == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            stages = hidden_stage(in_dim)
            for _ in range(depth - 2):
                stages += hidden_stage(hidden_dim)
            stages.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*stages)

        projection = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        projection.weight_g.data.fill_(1)
        if norm_last_layer:
            # keep the weight-norm magnitude fixed at 1
            projection.weight_g.requires_grad = False
        self.last_layer = projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embedding = nn.functional.normalize(self.mlp(x), dim=-1, p=2)
        return self.last_layer(embedding)

## ⛑️ Utility Functions

In [None]:
def get_params_groups(model: nn.Module) -> Iterable:
    """
    Split a model's trainable parameters into two optimizer groups.

    Biases and 1-D parameters (e.g. normalization weights) go into a group
    with ``weight_decay`` forced to 0; everything else goes into the first,
    regularized group. Frozen parameters are left out entirely.

    Returns:
        ``[{'params': decayed}, {'params': undecayed, 'weight_decay': 0.}]``
    """
    decayed, undecayed = [], []
    trainable = ((n, p) for n, p in model.named_parameters() if p.requires_grad)
    for name, param in trainable:
        exempt = name.endswith(".bias") or param.ndim == 1
        (undecayed if exempt else decayed).append(param)
    return [{'params': decayed}, {'params': undecayed, 'weight_decay': 0.}]

def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0):
    """
    Build a per-iteration schedule: optional linear warm-up, then cosine decay.

    The first ``warmup_epochs * niter_per_ep`` values go linearly from
    ``start_warmup_value`` to ``base_value``; the remaining values follow a
    half cosine from ``base_value`` towards ``final_value``.

    Args:
        base_value: value at the end of warm-up / start of the cosine.
        final_value: value the cosine approaches at the end of training.
        epochs: total number of epochs.
        niter_per_ep: iterations per epoch.
        warmup_epochs: number of warm-up epochs (0 disables warm-up).
        start_warmup_value: value at the very first iteration of warm-up.

    Returns:
        A NumPy array of length ``epochs * niter_per_ep``.

    Raises:
        ValueError: if ``warmup_epochs`` is negative or exceeds ``epochs``.
    """
    # Validate explicitly: previously an out-of-range warm-up only tripped
    # the trailing assert (which is stripped under `python -O`).
    if not 0 <= warmup_epochs <= epochs:
        raise ValueError(
            f"warmup_epochs must be in [0, epochs={epochs}], got {warmup_epochs}"
        )

    total_iters = epochs * niter_per_ep
    warmup_iters = warmup_epochs * niter_per_ep
    warmup_schedule = np.array([])
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(total_iters - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == total_iters  # internal invariant
    return schedule

# 💿 The Dataset
---

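For the purposes of this Colab, we use the CIFAR10 dataset to pre-train the model with DINO, with multi-class image classification as the downstream task. The class labels themselves are not used during self-supervised training.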

## 🖖 Data Augmentation Pipeline

The augmentation pipelines applied to the crops (both global and local).

In [None]:
class DataAugmentationDINO(object):
    """
    Multi-crop augmentation used for DINO pre-training.

    Each call turns one PIL image into ``2 + local_crops_number`` normalized
    tensors. The first two are global crops: the first is always blurred, and
    the second is blurred with p=0.1 and solarized with p=0.2. The rest are
    local crops, blurred with p=0.5. All crops use random flip, color jitter
    and grayscale.

    Args:
        global_crops_scale: ``scale`` range for the global ``RandomResizedCrop``.
        local_crops_scale: ``scale`` range for the local ``RandomResizedCrop``.
        local_crops_number: number of local crops produced per image.
        global_crops_size: output side length of the global crops (default 224).
        local_crops_size: output side length of the local crops (default 96).
    """
    def __init__(self, global_crops_scale: Sequence[float], local_crops_scale: Sequence[float], local_crops_number: int,
                 global_crops_size: int = 224, local_crops_size: int = 96) -> None:
        flip_and_color_jitter = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomApply(
                [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)],
                p=0.8
            ),
            transforms.RandomGrayscale(p=0.2),
        ])
        normalize = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
        ])

        # first global crop
        self.global_transfo1 = transforms.Compose([
            transforms.RandomResizedCrop(global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC),
            flip_and_color_jitter,
            GaussianBlur(1.0),
            normalize,
        ])
        # second global crop
        self.global_transfo2 = transforms.Compose([
            transforms.RandomResizedCrop(global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC),
            flip_and_color_jitter,
            GaussianBlur(0.1),
            Solarization(0.2),
            normalize,
        ])
        # transformation for the local small crops
        self.local_crops_number = local_crops_number
        self.local_transfo = transforms.Compose([
            transforms.RandomResizedCrop(local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC),
            flip_and_color_jitter,
            GaussianBlur(p=0.5),
            normalize,
        ])

    def __call__(self, image: "Image.Image") -> Sequence[torch.Tensor]:
        """
        Args:
            image: a PIL image. GaussianBlur and Solarization call PIL APIs
                (``img.filter`` / ``ImageOps.solarize``), so tensors are not
                accepted.

        Returns:
            ``[global_1, global_2, local_1, ..., local_n]`` as tensors.
        """
        crops = [self.global_transfo1(image), self.global_transfo2(image)]
        crops.extend(self.local_transfo(image) for _ in range(self.local_crops_number))
        return crops

## ⚙️ Dataloader

In [None]:
%%capture
# Multi-crop augmentation: each sample becomes a list of 2 global crops
# followed by `local_crops_number` local crops.
transform = DataAugmentationDINO(
    global_crops_scale=(0.4, 1.),   # scale range for the global RandomResizedCrop
    local_crops_scale=(0.05, 0.4),  # scale range for the local RandomResizedCrop
    local_crops_number=8,
)
# Labels are returned with each sample, but the training loop discards them.
dataset = datasets.CIFAR10("./data", transform=transform, download=True)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    # DataLoader's default is shuffle=False, which would feed identical batches
    # in dataset order every epoch; reshuffling is seeded by torch.manual_seed.
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,  # every batch has exactly 32 samples
)
print(f"Data loaded: there are {len(dataset)} images.")

# ✍️ Model Architecture & Training
---

## Building Student and Teacher Networks


In [None]:
# Instantiate Models
# Two ResNet-18 backbones built with torchvision's default constructor arguments.
# NOTE(review): construction order matters for reproducibility, since each
# constructor consumes the seeded RNG.
student = torchvision_models.resnet18()
teacher = torchvision_models.resnet18()
# nn.Linear stores weights as (out_features, in_features), so this is the
# feature size fed into ResNet-18's classifier, i.e. the head's input size.
embed_dim = student.fc.weight.shape[1]

# multi-crop wrapper handles "the forward method" with inputs of different resolutions
# (it also replaces the backbone's `fc` with nn.Identity, so the backbone
# returns features that DINOHead projects to 65536 dimensions)
student = MultiCropWrapper(student, DINOHead(
  in_dim = embed_dim,
  out_dim = 65536,
  use_bn = False,
  norm_last_layer = True
))
# norm_last_layer is left at DINOHead's default (True) for the teacher
teacher = MultiCropWrapper(
  teacher,
  DINOHead(in_dim = embed_dim, out_dim = 65536, use_bn = False),
)

# move networks to GPU
student, teacher = student.cuda(), teacher.cuda()
# No DistributedDataParallel wrapper is used in this notebook, so the
# "unwrapped" teacher is simply another name for `teacher`.
teacher_without_ddp = teacher

# teacher and student start with the same weights
teacher_without_ddp.load_state_dict(student.state_dict())

# Stop Gradient: there is no backpropagation through the teacher, so no need for gradients
# (the teacher is only updated by the EMA step in the training loop)
for p in teacher.parameters():
    p.requires_grad = False
print(f"Student and Teacher are built: they are both resnet18 network.")

## The Loss, Optimizers and Schedulers

In [None]:
# The Loss Function: the student sees 2 global + 8 local crops (ncrops = 10),
# and the teacher temperature stays at 0.04 for the whole 20-epoch schedule.
dino_loss = DINOLoss(
    out_dim=65536,
    ncrops=8 + 2,  # total number of crops = 2 global crops + local_crops_number
    warmup_teacher_temp=0.04,
    teacher_temp=0.04,
    warmup_teacher_temp_epochs=0,
    nepochs=20,
).cuda()

# Optimizer: AdamW over two parameter groups (biases / 1-D params get no
# weight decay). lr and weight decay are overwritten each iteration by the
# schedules below.
params_groups = get_params_groups(student)
optimizer = torch.optim.AdamW(params_groups)

# ============ init schedulers ... ============
# per-iteration learning rate: 10 warm-up epochs, then cosine decay to 1e-6
lr_schedule = cosine_scheduler(
    base_value=0.0005 * 32 / 256.,  # linear scaling rule: base lr * batch_size / 256
    final_value=1e-6,
    epochs=20,
    niter_per_ep=len(data_loader),
    warmup_epochs=10,
)
# per-iteration weight decay, increased from 0.04 to 0.4 along a cosine
wd_schedule = cosine_scheduler(
    base_value=0.04,
    final_value=0.4,
    epochs=20,
    niter_per_ep=len(data_loader),
)
# momentum parameter is increased to 1. during training with a cosine schedule
momentum_schedule = cosine_scheduler(
    base_value=0.996,
    final_value=1,
    epochs=20,
    niter_per_ep=len(data_loader),
)
print(f"Loss, optimizer and schedulers ready.")

## Training

![](https://github.com/SauravMaheshkar/infographics/blob/main/DINO/DINO.png?raw=true)



In [None]:
print("Starting DINO training !")
# NOTE: only one epoch is run here, although the schedules were built for
# 20 epochs (with a 10-epoch lr warm-up), so lr stays in its warm-up phase.
for epoch in range(0, 1):
    for it, (images, _) in track(enumerate(data_loader), total = len(data_loader)):
        # update weight decay and learning rate according to their schedule
        it = len(data_loader) * epoch + it  # global training iteration
        for i, param_group in enumerate(optimizer.param_groups):
            param_group["lr"] = lr_schedule[it]
            if i == 0:  # only the first group is regularized
                param_group["weight_decay"] = wd_schedule[it]

        # move images to gpu
        images = [im.cuda(non_blocking=True) for im in images]
        # teacher and student forward passes + compute dino loss
        teacher_output = teacher(images[:2])  # only the 2 global views pass through the teacher
        student_output = student(images)
        loss = dino_loss(student_output, teacher_output, epoch)

        # .item() copies the loss to the host, which also synchronizes with
        # the GPU every iteration (no explicit torch.cuda.synchronize needed).
        loss_value = loss.item()
        if not math.isfinite(loss_value):
            # rich's print() has no `force` keyword; passing it raised TypeError.
            print("Loss is {}, stopping training".format(loss_value))
            sys.exit(1)

        # student update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA update for the teacher
        with torch.no_grad():
            m = momentum_schedule[it]  # momentum parameter
            for param_q, param_k in zip(student.parameters(), teacher_without_ddp.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)

    # ============ writing logs ... ============
    save_dict = {
        'student': student.state_dict(),
        'teacher': teacher.state_dict(),
        'optimizer': optimizer.state_dict(),
        'epoch': epoch + 1,
        'dino_loss': dino_loss.state_dict(),
    }
    torch.save(save_dict, os.path.join("./", 'checkpoint.pth'))
    if saveckp_freq and epoch % saveckp_freq == 0:
        torch.save(save_dict, os.path.join("./", f'checkpoint{epoch:04}.pth'))