## **TL;DR**

Pretrain a Vision Transformer using DINO ([`lightly`](https://github.com/lightly-ai/lightly) implementation) on an arbitrary dataset for Image Retrieval using [`faiss`](https://github.com/facebookresearch/faiss) as a vector database.

## 📦 Packages and Basic Setup
---

In [None]:
# @title
%%capture
!pip install faiss-gpu lightning lightly datasets supervision wandb ml-collections -q

import copy
import json
import warnings
from typing import Any, Callable, Dict, TypeAlias

import cv2
import faiss
import numpy as np
import pytorch_lightning as pl
import supervision as sv
import torch
import torchvision.transforms as T
from datasets import load_dataset
from faiss import read_index, write_index
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from PIL import Image, ImageOps
from tqdm.notebook import tqdm

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.transforms.dino_transform import DINOTransform
from lightly.utils.scheduler import cosine_schedule

# Alias for the batch dicts handled by `apply_transform` below
# (string keys such as "image" mapped to that column's values).
BATCH_TYPE: TypeAlias = Dict[str, Any]

import os

from google.colab import userdata

# Expose the W&B API key stored as a Colab secret so `wandb` can authenticate
# without an interactive login prompt.
key = userdata.get("W&B")
os.environ.update(WANDB_API_KEY=key)

# Output directory for checkpoints written during training.
os.makedirs("artifacts", exist_ok=True)
warnings.filterwarnings("ignore", category=UserWarning)

# Prefer the GPU when one is visible to PyTorch.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
import ml_collections

# @title ⚙️ Configuration
import wandb


def get_config() -> ml_collections.ConfigDict:
    """Build the experiment configuration.

    Each field carries a Colab ``# @param`` comment so it is rendered as an
    editable form widget; edit values through the form or directly here.
    The ``name: type`` annotations on attribute targets are evaluated but not
    stored anywhere. Only the assignment to the ConfigDict takes effect.

    Returns:
      ml_collections.ConfigDict holding model, dataset, optimisation,
      retrieval and experiment-tracking settings.
    """
    config = ml_collections.ConfigDict()
    # Backbone name passed to torch.hub ("facebookresearch/dino:main") and HF dataset id.
    config.model: str = "dino_vits8"  # @param {type: "string"}
    config.dataset: str = "ethz/food101"  # @param {type: "string"}
    # DINOProjectionHead sizes (shared by student and teacher heads).
    config.hidden_dim: int = 256  # @param {type: "number"}
    config.bottleneck_dim: int = 64  # @param {type: "number"}
    config.output_dim: int = 2048  # @param {type: "number"}
    # Passed to DINOLoss.
    config.warmup_teacher_temp_epochs: int = 5  # @param {type: "number"}
    # Optimisation.
    config.lr: float = 0.001  # @param {type: "number"}
    config.epochs: int = 10  # @param {type: "number"}
    config.batch_size: int = 16  # @param {type: "number"}
    # NOTE(review): top_k and the upload/entity fields are not read in the
    # training cells; presumably used by later retrieval/upload cells.
    config.top_k: int = 3  # @param {type: "number"}
    config.random_seed: int = 42  # @param {type: "number"}
    config.upload_to_wandb = True  # @param {type:"boolean"}
    config.upload_to_hf = True  # @param {type:"boolean"}
    config.hf_entity: str = "sauravmaheshkar"  # @param {type: "string"}
    config.wandb_entity: str = "sauravmaheshkar"  # @param {type: "string"}

    return config


config = get_config()

# Start the W&B run; the full config is recorded with it.
wandb.init(
    config=config.to_dict(),
    entity=config.wandb_entity,
    group=config.model,
    job_type="train",
    project="SSL-Image-Retrieval",
)
wandb_logger = WandbLogger()

# Seed Python/NumPy/torch RNGs for reproducibility.
pl.seed_everything(config.random_seed)

## 💿 The Dataset
---

In this example we use the `ethz/food101` dataset from the Hugging Face Hub. Its training split, which is the one used below, contains 75,750 images from 101 classes.

* This snippet shows how easily off-the-shelf `lightly` transforms ([`DINOTransform`](https://docs.lightly.ai/self-supervised-learning/lightly.transforms.html#lightly.transforms.dino_transform.DINOTransform)) plug into the Hugging Face ecosystem.
* There is no need to reorganize the images into a single folder; we can iterate over every image in the dataset object directly.

In [None]:
%%capture
dataset = load_dataset(path=config.dataset, trust_remote_code=True)  # DatasetDict keyed by split

In [None]:
transform = DINOTransform()


def apply_transform(
    example_batch: BATCH_TYPE, transform: Callable = transform
) -> BATCH_TYPE:
    """
    Apply the given transform to every image in a batch. Intended for
    batch callbacks such as ``Dataset.set_transform`` or ``map``.

    Args:
      example_batch (Dict): a batch of data; must contain the key 'image'
        whose value is an iterable of images supporting ``.convert("RGB")``
        (presumably PIL images decoded by ``datasets``)
      transform (Callable): transformation applied to each RGB image;
        defaults to the module-level ``DINOTransform`` instance

    Returns:
      the same ``example_batch`` dict, mutated in place so that 'image'
      holds a list with one transformed output per input image

    Raises:
      KeyError: if ``example_batch`` has no 'image' key
    """
    # Explicit raise instead of `assert`, so the check survives `python -O`.
    if "image" not in example_batch:
        raise KeyError("batch should be of type Dict[str, Any] with a key 'image'")

    # Normalise every image to 3-channel RGB before transforming, so that
    # grayscale/palette inputs do not reach the transform.
    example_batch["image"] = [
        transform(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch


# Register lazily on every split: images are transformed on access,
# not precomputed.
dataset.set_transform(apply_transform)

In [None]:
# Shuffled mini-batches over the training split; drop_last keeps every
# batch at exactly `config.batch_size` samples.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset["train"],
    num_workers=2,
    drop_last=True,
    shuffle=True,
    batch_size=config.batch_size,
)

## ✍️ Model Architecture & Training
---

![](https://github.com/SauravMaheshkar/SauravMaheshkar/blob/main/assets/DINO/DINO.png?raw=true)

In [None]:
# @title


class DINO(torch.nn.Module):
    """Student/teacher network pair for DINO self-supervised pretraining.

    The student (backbone + projection head) is trained by gradient descent.
    The teacher is a frozen copy whose weights are meant to be updated only
    by momentum (EMA) from the student.
    """

    def __init__(
        self,
        model: str = config.model,
        hidden_dim: int = config.hidden_dim,
        bottleneck_dim: int = config.bottleneck_dim,
        output_dim: int = config.output_dim,
        warmup_teacher_temp_epochs: int = config.warmup_teacher_temp_epochs,
    ) -> None:
        """
        Args:
          model: torch.hub entry name in "facebookresearch/dino:main"
          hidden_dim: hidden width of the projection heads
          bottleneck_dim: bottleneck width of the projection heads
          output_dim: output (prototype) dimension of the projection heads
          warmup_teacher_temp_epochs: not used by this module; kept for
            interface compatibility (the loss receives this value directly)
        """
        super().__init__()
        # Randomly initialised backbone (pretrained=False).
        backbone = torch.hub.load("facebookresearch/dino:main", model, pretrained=False)
        input_dim = backbone.embed_dim

        self.student_backbone = backbone
        self.student_head = DINOProjectionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            bottleneck_dim=bottleneck_dim,
            output_dim=output_dim,
            freeze_last_layer=1,
        )
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            bottleneck_dim=bottleneck_dim,
            output_dim=output_dim,
        )
        # Start the teacher head from the student head's weights, just as the
        # teacher backbone starts as a copy of the student backbone. The
        # state_dict is copied into a separately built head (rather than
        # deep-copying the module) so that the student-only
        # `freeze_last_layer` setting is not carried over.
        self.teacher_head.load_state_dict(self.student_head.state_dict())
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

    def forward(self, x):
        """Student projection for a batch of images `x`."""
        # flatten(start_dim=1) keeps dim 0 (batch) and merges the rest into
        # one feature axis for the head.
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        """Teacher projection for a batch of images `x`."""
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z


# Build the student/teacher pair from the configured architecture, then move
# all parameters to the selected device (nn.Module.to returns the module).
ssl_model = DINO(
    warmup_teacher_temp_epochs=config.warmup_teacher_temp_epochs,
    output_dim=config.output_dim,
    bottleneck_dim=config.bottleneck_dim,
    hidden_dim=config.hidden_dim,
    model=config.model,
)
ssl_model = ssl_model.to(device)

In [None]:
criterion = DINOLoss(
    output_dim=config.output_dim,
    warmup_teacher_temp_epochs=config.warmup_teacher_temp_epochs,
).to(device)

# Teacher parameters have requires_grad=False, so they receive no gradients
# and the optimizer leaves them untouched; they change only via update_momentum.
optimizer = torch.optim.Adam(ssl_model.parameters(), lr=config.lr)

# Fail fast instead of training on nothing, saving an untrained checkpoint and
# then dividing by zero when averaging the loss.
num_batches = len(dataloader)
if num_batches == 0:
    raise ValueError(
        "dataloader yields no batches; is the dataset smaller than "
        "batch_size with drop_last=True?"
    )

for epoch in range(config.epochs):
    total_loss = 0.0
    # Teacher EMA momentum follows a cosine schedule from 0.996 towards 1.
    momentum_val = cosine_schedule(epoch, config.epochs, 0.996, 1)
    for batch in dataloader:
        views = batch["image"]
        update_momentum(
            ssl_model.student_backbone, ssl_model.teacher_backbone, m=momentum_val
        )
        update_momentum(ssl_model.student_head, ssl_model.teacher_head, m=momentum_val)
        views = [view.to(device) for view in views]
        # The first two views are used as the global views seen by the teacher;
        # the student sees every view.
        global_views = views[:2]
        # The teacher is never trained by backprop, so build no graph for it.
        with torch.no_grad():
            teacher_out = [ssl_model.forward_teacher(view) for view in global_views]
        student_out = [ssl_model.forward(view) for view in views]
        loss = criterion(teacher_out, student_out, epoch=epoch)
        # Accumulate on-device to avoid a host sync per step for the average.
        total_loss += loss.detach()
        loss.backward()
        wandb.log({"loss": loss.item()})
        # Zero the student head's last-layer gradients while freeze_last_layer
        # is in effect for this epoch.
        ssl_model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
        optimizer.step()
        optimizer.zero_grad()

    torch.save(
        ssl_model.state_dict(),
        f"artifacts/dino-food101-{config.model}-{config.hidden_dim}",
    )

    # Convert once to a plain Python float for logging and printing.
    avg_loss = float(total_loss / num_batches)
    wandb.log({"avg_loss": avg_loss})
    print(f"epoch: {epoch:>02}, loss: {avg_loss:.5f}")

wandb.finish()