## **TL;DR**

Use a pretrained ResNet-50 (trained with DINO and provided by [`lightly`](https://github.com/lightly-ai/lightly)) to build a vector index over an arbitrary dataset for image retrieval using [`faiss`](https://github.com/facebookresearch/faiss).

## 📦 Packages and Basic Setup
---

In [None]:
# @title
%%capture
!pip install faiss-gpu lightning lightly datasets supervision wandb ml-collections -q

import copy
import json
import warnings
from typing import Any, Callable, Dict, List, Tuple, TypeAlias, Union

import cv2
import faiss
import numpy as np
import supervision as sv
import torch
import torchvision.transforms as T
from datasets import load_dataset
from faiss import read_index, write_index
from PIL import Image, ImageOps
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.nn import Identity
from torch.optim import SGD
from torch.optim.optimizer import Optimizer
from torchvision.models import resnet50
from tqdm.notebook import tqdm

from lightly.loss import DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import (
    activate_requires_grad,
    deactivate_requires_grad,
    get_weight_decay_parameters,
    update_momentum,
)
from lightly.transforms import DINOTransform
from lightly.utils.benchmarking import OnlineLinearClassifier
from lightly.utils.scheduler import CosineWarmupScheduler, cosine_schedule

# Alias for a generic dict-shaped batch.
# NOTE(review): not referenced in the cells visible here; remove if unused elsewhere.
BATCH_TYPE: TypeAlias = Dict[str, Any]

import os

from google.colab import userdata

# Hand the W&B API key (stored as the Colab secret named "W&B") to the wandb client
# through its environment variable.
key = userdata.get("W&B")
os.environ["WANDB_API_KEY"] = key

# Silence UserWarnings emitted by the libraries used in this notebook.
warnings.filterwarnings("ignore", category=UserWarning)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

!wget https://lightly-ssl-checkpoints.s3.amazonaws.com/imagenet_resnet50_dino_2023-06-06_13-59-48/pretrain/version_0/checkpoints/epoch%3D99-step%3D1000900.ckpt

In [None]:
# @title ⚙️ Configuration
import ml_collections
import wandb


def get_config() -> ml_collections.ConfigDict:
    """Build the run configuration for this notebook.

    Each ``# @param`` comment exposes the field as a Colab form input. The
    ``# @title`` line must be the first line of the cell for Colab to render
    the form header, which is why it precedes the imports.

    Returns:
        A ``ConfigDict`` with dataset, optimisation, retrieval and
        experiment-tracking settings.
    """
    config = ml_collections.ConfigDict()
    # Data and SSL method.
    config.dataset: str = "ethz/food101"  # @param {type: "string"}
    config.method: str = "dino"  # @param {type: "string"}
    # Optimisation.
    config.lr: float = 0.001  # @param {type: "number"}
    config.epochs: int = 10  # @param {type: "number"}
    config.batch_size: int = 32  # @param {type: "number"}
    # Retrieval: number of neighbours returned per query.
    config.top_k: int = 3  # @param {type: "number"}
    config.random_seed: int = 42  # @param {type: "number"}
    # Experiment tracking / artifact sharing.
    config.upload_to_wandb: bool = True  # @param {type:"boolean"}
    config.fetch_from_wandb: bool = True  # @param {type:"boolean"}
    config.upload_to_hf: bool = True  # @param {type:"boolean"}
    config.hf_entity: str = "sauravmaheshkar"  # @param {type: "string"}
    config.wandb_entity: str = "sauravmaheshkar"  # @param {type: "string"}

    return config


config = get_config()

## 💿 The Dataset
---

In this particular example we use the `ethz/food101` dataset from the Hugging Face Hub. Its training split, which is the one indexed below, contains 75,750 images from 101 classes.

* This snippet highlights how easily off-the-shelf `lightly` transforms ([`DINOTransform`](https://docs.lightly.ai/self-supervised-learning/lightly.transforms.html#lightly.transforms.dino_transform.DINOTransform)) can be used with the Hugging Face ecosystem.
* There is no need to reorganize the images into a single folder; we can simply iterate over all images in the dataset object.

In [None]:
%%capture
# Load every split of the configured Hugging Face dataset (``config.dataset``).
# NOTE(review): ``trust_remote_code=True`` allows the dataset repository to run its
# own loading code; keep it only for repositories you trust.
dataset = load_dataset(path=config.dataset, trust_remote_code=True)

## ✍️ Model Architecture & Training
---

![](https://github.com/SauravMaheshkar/SauravMaheshkar/blob/main/assets/DINO/DINO.png?raw=true)

In [None]:
# @title


class DINO(LightningModule):
    """DINO self-distillation model with a ResNet-50 backbone.

    ``backbone`` / ``projection_head`` form the teacher: its parameters have
    ``requires_grad`` disabled during training and are only moved towards the
    student through ``update_momentum``. ``student_backbone`` /
    ``student_projection_head`` are the parameters handed to the optimizer. An
    ``OnlineLinearClassifier`` trained on detached teacher features monitors
    representation quality.

    Attribute names must stay as they are: the pretrained checkpoint is loaded
    with ``load_state_dict`` and its keys are derived from them.

    Args:
        batch_size_per_device: Per-device batch size; used to scale the learning
            rate in ``configure_optimizers``.
        num_classes: Number of output classes of the online linear classifier.
    """

    def __init__(self, batch_size_per_device: int, num_classes: int) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.batch_size_per_device = batch_size_per_device

        resnet = resnet50()
        resnet.fc = Identity()  # Ignore classification head; forward yields pooled features
        self.backbone = resnet  # teacher backbone
        self.projection_head = DINOProjectionHead()  # teacher head (EMA-updated only)
        self.student_backbone = copy.deepcopy(self.backbone)
        # The student head is the one whose gradients reach the optimizer, so the
        # "freeze the last layer during the first epoch" setting belongs here:
        # ``configure_gradient_clipping`` calls ``cancel_last_layer_gradients`` on
        # this head. (Previously the flag sat on the teacher head, which made that
        # call a no-op.)
        self.student_projection_head = DINOProjectionHead(freeze_last_layer=1)
        self.criterion = DINOLoss()

        self.online_classifier = OnlineLinearClassifier(num_classes=num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Return teacher-backbone features for a batch of images."""
        return self.backbone(x)

    def forward_student(self, x: Tensor) -> Tensor:
        """Return student projections: flattened student features fed to the student head."""
        features = self.student_backbone(x).flatten(start_dim=1)
        projections = self.student_projection_head(features)
        return projections

    def on_train_start(self) -> None:
        # The teacher is never optimised directly; it only follows the student via EMA.
        deactivate_requires_grad(self.backbone)
        deactivate_requires_grad(self.projection_head)

    def on_train_end(self) -> None:
        # Restore gradients so the teacher can be fine-tuned or probed afterwards.
        activate_requires_grad(self.backbone)
        activate_requires_grad(self.projection_head)

    def training_step(
        self, batch: Tuple[List[Tensor], Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Run one DINO update plus one online-classifier update.

        Args:
            batch: ``(views, targets, ...)``; ``views[:2]`` are the global crops and
                ``views[2:]`` the local crops, each a batched image tensor.
            batch_idx: Index of the batch within the epoch.

        Returns:
            Sum of the DINO loss and the online classification loss.
        """
        # Momentum update teacher. The EMA coefficient follows a cosine schedule
        # from 0.996 to 1.0 over all training steps; the teacher modules are passed
        # as lightly's ``model_ema`` argument.
        momentum = cosine_schedule(
            step=self.trainer.global_step,
            max_steps=self.trainer.estimated_stepping_batches,
            start_value=0.996,
            end_value=1.0,
        )
        update_momentum(self.student_backbone, self.backbone, m=momentum)
        update_momentum(self.student_projection_head, self.projection_head, m=momentum)

        views, targets = batch[0], batch[1]
        global_views = torch.cat(views[:2])
        local_views = torch.cat(views[2:])

        # Teacher only sees the two global views; the student sees all views.
        teacher_features = self.forward(global_views).flatten(start_dim=1)
        teacher_projections = self.projection_head(teacher_features)
        student_projections = torch.cat(
            [self.forward_student(global_views), self.forward_student(local_views)]
        )

        # Split the concatenated projections back into one chunk per view.
        loss = self.criterion(
            teacher_out=teacher_projections.chunk(2),
            student_out=student_projections.chunk(len(views)),
            epoch=self.current_epoch,
        )
        self.log_dict(
            {"train_loss": loss, "ema_momentum": momentum},
            prog_bar=True,
            sync_dist=True,
            batch_size=len(targets),
        )

        # Online classification on the (detached) teacher features of the first
        # global view, so the probe does not influence the SSL objective.
        cls_loss, cls_log = self.online_classifier.training_step(
            (teacher_features.chunk(2)[0].detach(), targets), batch_idx
        )
        self.log_dict(cls_log, sync_dist=True, batch_size=len(targets))
        return loss + cls_loss

    def validation_step(
        self, batch: Tuple[Tensor, Tensor, List[str]], batch_idx: int
    ) -> Tensor:
        """Evaluate the online classifier on teacher features of ``batch[0]``."""
        images, targets = batch[0], batch[1]
        features = self.forward(images).flatten(start_dim=1)
        cls_loss, cls_log = self.online_classifier.validation_step(
            (features.detach(), targets), batch_idx
        )
        self.log_dict(cls_log, prog_bar=True, sync_dist=True, batch_size=len(targets))
        return cls_loss

    def configure_optimizers(self):
        """Create SGD over the student and online classifier plus a per-step cosine schedule.

        The learning rate is scaled linearly with the global batch size
        (``0.03 * batch_size_per_device * world_size / 256``). Warmup spans the
        number of steps in 10 epochs.
        """
        # Don't use weight decay for batch norm, bias parameters, and classification
        # head to improve performance.
        params, params_no_weight_decay = get_weight_decay_parameters(
            [self.student_backbone, self.student_projection_head]
        )
        # For ResNet50 we use SGD instead of AdamW/LARS as recommended by the authors:
        # https://github.com/facebookresearch/dino#resnet-50-and-other-convnets-trainings
        optimizer = SGD(
            [
                {"name": "dino", "params": params},
                {
                    "name": "dino_no_weight_decay",
                    "params": params_no_weight_decay,
                    "weight_decay": 0.0,
                },
                {
                    "name": "online_classifier",
                    "params": self.online_classifier.parameters(),
                    "weight_decay": 0.0,
                },
            ],
            lr=0.03 * self.batch_size_per_device * self.trainer.world_size / 256,
            momentum=0.9,
            weight_decay=1e-4,
        )
        scheduler = {
            "scheduler": CosineWarmupScheduler(
                optimizer=optimizer,
                # steps-per-epoch * 10 -> 10 warmup epochs, expressed in steps.
                warmup_epochs=int(
                    self.trainer.estimated_stepping_batches
                    / self.trainer.max_epochs
                    * 10
                ),
                max_epochs=int(self.trainer.estimated_stepping_batches),
            ),
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def configure_gradient_clipping(
        self,
        optimizer: Optimizer,
        gradient_clip_val: Union[int, float, None] = None,
        gradient_clip_algorithm: Union[str, None] = None,
    ) -> None:
        """Clip gradients by norm at 3.0 and drop the student head's last-layer grads.

        The Trainer-provided ``gradient_clip_val`` / ``gradient_clip_algorithm``
        are intentionally ignored in favour of the fixed values below.
        """
        self.clip_gradients(
            optimizer=optimizer,
            gradient_clip_val=3.0,
            gradient_clip_algorithm="norm",
        )
        # With ``freeze_last_layer=1`` this discards the last-layer gradients while
        # ``current_epoch`` is below 1 (i.e. during the first epoch).
        self.student_projection_head.cancel_last_layer_gradients(self.current_epoch)


# Build the model (1000 classifier outputs, presumably matching the ImageNet
# checkpoint's online classifier) and load the weights fetched by the setup cell's
# ``wget``. ``map_location`` lets a GPU-saved checkpoint load on CPU-only runtimes.
ssl_model = DINO(batch_size_per_device=config.batch_size, num_classes=1000).to(device)
ssl_model.load_state_dict(
    torch.load("/content/epoch=99-step=1000900.ckpt", map_location=device)["state_dict"]
)

# The model is only used for inference from here on. eval() makes BatchNorm use its
# stored running statistics instead of per-call batch statistics. Without it, the
# one-image-at-a-time embeddings would be distorted, and the running stats would be
# overwritten on every forward pass.
ssl_model.eval()

## 🗂️ Creating the Image Index

In [None]:
# Deterministic preprocessing applied to every image before embedding.
# NOTE(review): Resize(244) looks like it may be a typo for the usual 256, and the
# [0.5]/[0.5] normalization should be checked against the one used during
# pretraining. Any change must be mirrored in the query cell, and the index rebuilt.
transforms = T.Compose(
    [
        T.ToTensor(),  # PIL image -> float tensor, channels first
        T.Resize(244),  # resize the shorter side to 244 px
        T.CenterCrop(224),  # keep the central 224x224 patch
        T.Normalize([0.5], [0.5]),  # single mean/std broadcast over all channels
    ]
)


def augment(img: Image.Image, transforms=transforms) -> torch.Tensor:
    """Preprocess a PIL image into a single-image batch for the backbone.

    Args:
        img: Input PIL image in any mode.
        transforms: Pipeline applied to the RGB image. Defaults to the module-level
            ``transforms`` as it was when this function was defined.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    if img.mode != "RGB":
        # The backbone expects 3 channels. For "L" images ``convert("RGB")``
        # replicates the single channel, giving the same output as
        # ``ImageOps.colorize(img, black="black", white="white")``. It also handles
        # modes such as "RGBA", "P" or "CMYK", which would otherwise reach the model
        # with the wrong number of channels.
        img = img.convert("RGB")

    return transforms(img).unsqueeze(0)


def create_index(hf_dataset) -> faiss.IndexFlatL2:
    """Embed every image in ``hf_dataset`` into an exact (flat) L2 FAISS index.

    Vectors are added in iteration order, so FAISS id ``i`` refers to the
    ``i``-th record of ``hf_dataset``; this is how search hits are mapped back
    to images. The index is also saved to ``f"{config.method}.index"`` in the
    current working directory.

    Args:
        hf_dataset: Iterable of records whose ``"image"`` entry is a PIL image
            (e.g. a Hugging Face ``datasets`` split).

    Returns:
        The populated 2048-dimensional ``faiss.IndexFlatL2``.
    """
    flat_index = faiss.IndexFlatL2(2048)

    with torch.no_grad():
        for record in tqdm(hf_dataset):
            model_input = augment(record["image"], transforms).to(device)
            vector = ssl_model(model_input)[0].cpu().numpy()
            flat_index.add(np.array(vector).reshape(1, -1))

    faiss.write_index(flat_index, f"{config.method}.index")
    return flat_index


data_index = create_index(dataset["train"])

In [None]:
if config.upload_to_wandb:
    # Using the run as a context manager guarantees it is finished even if one of
    # the steps inside raises, instead of leaving a dangling active run.
    with wandb.init(
        project="SSL-Image-Retrieval",
        job_type="upload-embeddings",
        group=config.method,
        config=config.to_dict(),
        entity=config.wandb_entity,
    ) as run:
        # Version the FAISS index file written by ``create_index`` as an artifact.
        embeddings_artifact = wandb.Artifact(
            name=f"{config.method}-embeddings", type="faiss-embeddings"
        )
        embeddings_artifact.add_file(local_path=f"{config.method}.index")
        run.log_artifact(embeddings_artifact)

## 🔎 Using the Image Index

In [None]:
if config.fetch_from_wandb:
    api = wandb.Api()
    artifact = api.artifact(
        f"{config.wandb_entity}/SSL-Image-Retrieval/{config.method}-embeddings:latest"
    )
    # ``download()`` returns the local directory the files were written to. Its name
    # encodes the version ``:latest`` resolved to (e.g. ``...:v3``), so it must be
    # used instead of a hard-coded ``:v0`` path, which would load a stale or
    # missing index.
    artifact_dir = artifact.download()

    data_index = read_index(os.path.join(artifact_dir, f"{config.method}.index"))

In [None]:
def search_index(
    index: faiss.IndexFlatL2, embedding: np.ndarray, k: int = config.top_k
) -> np.ndarray:
    """Return the ids of the ``k`` stored vectors nearest to a query embedding.

    Args:
        index: FAISS index to query. When built by ``create_index``, ids are
            dataset row numbers.
        embedding: Array-like (or CPU tensor) whose first row ``embedding[0]`` is
            the query vector.
        k: Number of neighbours to request (defaults to ``config.top_k``).

    Returns:
        1-D array of neighbour ids, nearest first. FAISS pads its result with
        ``-1`` when the index holds fewer than ``k`` vectors. Those placeholders
        are dropped so callers never look up ``dataset[-1]``, which would
        silently return the last item.
    """
    # FAISS requires a C-contiguous float32 (n_queries, d) array.
    query = np.ascontiguousarray(embedding[0].reshape(1, -1), dtype=np.float32)
    _, neighbour_ids = index.search(query, k)

    ids = neighbour_ids[0]
    return ids[ids != -1]

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torchvision.transforms as T
from PIL import ImageOps

# Row of the training split to use as the query (editable via the Colab form).
search_idx = 11234  # @param {type: "number"}
img = dataset["train"][search_idx]["image"]

# Re-declare the preprocessing pipeline (identical to the one in the indexing cell)
# so this cell can run after a kernel restart.
transforms = T.Compose(
    [
        T.ToTensor(),
        T.Resize(244),
        T.CenterCrop(224),
        T.Normalize([0.5], [0.5]),
    ]
)


def augment(img: Image.Image, transforms=transforms) -> torch.Tensor:
    """Preprocess a PIL image into a single-image batch for the backbone.

    Args:
        img: Input PIL image in any mode.
        transforms: Pipeline applied to the RGB image. Defaults to the module-level
            ``transforms`` as it was when this function was defined.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    if img.mode != "RGB":
        # The backbone expects 3 channels. For "L" images ``convert("RGB")``
        # replicates the single channel, giving the same output as
        # ``ImageOps.colorize(img, black="black", white="white")``. It also handles
        # modes such as "RGBA", "P" or "CMYK", which would otherwise reach the model
        # with the wrong number of channels.
        img = img.convert("RGB")
    return transforms(img).unsqueeze(0)


# Display images
# Embed the query image and retrieve its nearest neighbours (``config.top_k`` by
# default). NOTE(review): the query is taken from the indexed train split, so the
# first hit is usually the query image itself.
with torch.no_grad():
    embedding = ssl_model(augment(img, transforms).to(device))
    indices = search_index(data_index, np.array(embedding[0].cpu()).reshape(1, -1))

# One panel for the query plus one per retrieved image, 4 inches wide each.
# ``squeeze=False`` keeps the axes array indexable even when only one panel exists.
num_panels = 1 + len(indices)
fig, axs = plt.subplots(1, num_panels, figsize=(4 * num_panels, 4), squeeze=False)
axs = axs[0]

# Plot the source image
axs[0].imshow(img)
axs[0].set_title("Query")
axs[0].axis("off")

# Plot every retrieved image, nearest first.
for i, neighbour_id in enumerate(indices):
    retrieved_img = dataset["train"][int(neighbour_id)]["image"]
    axs[i + 1].imshow(retrieved_img)
    axs[i + 1].set_title(f"Retrieved image {i + 1}")
    axs[i + 1].axis("off")

plt.show()