In [1]:
import types
import math
from typing import Callable

import torch
import torch.nn as nn
import torch.nn.functional as F

import timm

In [2]:
# Load the timm model registered as 'vit_small_patch16_224_dino' with pretrained weights.
dino = timm.create_model('vit_small_patch16_224_dino', pretrained=True)

  model = create_fn(


In [3]:
patch_size = [16, 16]   # later stored on the model as `patch_size`
hooks = [2, 5, 8, 11]   # indices into `model.blocks` whose outputs are captured by forward hooks
hook_patch = True       # when True, the output of `pos_drop` is captured as well
start_index = 1         # index of the first patch token; earlier tokens are prefix tokens (e.g. CLS)

In [81]:
# Plain container module; the ViT is attached as its `model` attribute.
pretrained = nn.Module()
pretrained.model = dino

In [82]:
activations = {}


def get_activation(name: str) -> Callable:
    """Build a forward hook that records a module's output in ``activations[name]``."""
    def _record(module, args, result):
        # A later forward pass overwrites the entry stored under the same name.
        activations[name] = result
    return _record

In [83]:
# Capture the output of each selected transformer block under the keys '0', '1', ...
for slot, block_idx in enumerate(hooks):
    pretrained.model.blocks[block_idx].register_forward_hook(get_activation(str(slot)))
# Optionally capture the output of `pos_drop` under the key '4'.
if hook_patch:
    pretrained.model.pos_drop.register_forward_hook(get_activation('4'))

In [85]:
dino.eval()  # switch the model to evaluation mode before the forward pass
output = dino(torch.rand(5, 3, 224, 224))  # the forward pass also fires the registered hooks, filling `activations`
output.shape

torch.Size([5, 384])

In [8]:
# Report the shape of every captured output.
for slot, block_idx in enumerate(hooks):
    print(f"Block {block_idx} Shape:", list(activations[str(slot)].shape))
if hook_patch:
    print("Dropout Block Shape:", list(activations['4'].shape))

Block 2 Shape: [5, 197, 384]
Block 5 Shape: [5, 197, 384]
Block 8 Shape: [5, 197, 384]
Block 11 Shape: [5, 197, 384]
Dropout Block Shape: [5, 197, 384]


In [9]:
# Shape of the output captured under key '4' (the `pos_drop` hook).
list(activations['4'].shape)

[5, 197, 384]

In [10]:
# B = batch size
# N = number of tokens
# C = embedding dim

In [11]:
B, N, C = 5, 77, 768
start_index = 1

x = torch.rand(B, N, C)  # [B, N, C]

# Readout vector: CLS alone, or the mean of CLS and DIST when there are two prefix tokens.
readout = (x[:, 0] + x[:, 1]) / 2 if start_index == 2 else x[:, 0]   # [B, C]

# Drop the prefix token(s) and add the readout to every remaining patch token.
patch_tokens = x[:, start_index:]                                     # [B, N - start_index, C]
out = patch_tokens + readout[:, None, :]                              # broadcast over the token axis
out.shape

torch.Size([5, 76, 768])

In [12]:
class Readout(nn.Module):
    """
    Folds the prefix token(s) (CLS, or CLS and DIST) into the patch tokens.

    The readout vector is added to every patch token and the prefix
    tokens themselves are dropped from the output.
    """
    def __init__(self, start_index = 1):
        super().__init__()
        # Index of the first patch token, i.e. the number of prefix tokens.
        self.start_index = start_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, N, C]
        first = x[:, 0]
        if self.start_index == 2:
            summary = (first + x[:, 1]) / 2   # mean of CLS and DIST, [B, C]
        else:
            summary = first                   # CLS only, [B, C]

        patches = x[:, self.start_index:]     # [B, N - start_index, C]
        return patches + summary[:, None, :]  # broadcast the summary over the token axis

# Smoke test: with start_index=1 the first token is dropped, so 77 tokens become 76.
readout = Readout(start_index=1)
readout(torch.rand(5, 77, 768)).shape

torch.Size([5, 76, 768])

In [13]:
class Transpose(nn.Module):
    """
    Swaps two dimensions of the input and returns a contiguous tensor,
    e.g. Transpose(1, 2) maps [B, N, C] to [B, C, N].
    """
    def __init__(self, dim1: int, dim2: int):
        super().__init__()
        self.dim1, self.dim2 = dim1, dim2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        swapped = torch.transpose(x, self.dim1, self.dim2)
        # Return the swapped tensor in a contiguous layout.
        return swapped.contiguous()

# Smoke test: dims 1 and 2 are swapped, [5, 77, 768] -> [5, 768, 77].
transpose = Transpose(1, 2)
transpose(torch.rand(5, 77, 768)).shape

torch.Size([5, 768, 77])

In [14]:
# _resize_pos_embed, step by step on a dummy tensor

start_index = 1
pos_embed = torch.rand(5, 197, 384)
gs_h = gs_w = 24

# Split the prefix token(s) from the patch-grid embeddings (grid taken from the first batch item).
posemb_tok = pos_embed[:, :start_index]                    # [B, start_index, C]
posemb_grid = pos_embed[0, start_index:]                   # [N - start_index, C]

gs_old = int(math.sqrt(len(posemb_grid)))                  # side of the original square grid (14 here)

# Lay the tokens out as a channels-first image so they can be interpolated.
grid_image = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)   # [1, C, gs_old, gs_old]
grid_image = F.interpolate(grid_image, size=(gs_h, gs_w),
                           mode="bilinear",
                           align_corners=False)                               # [1, C, gs_h, gs_w]

# Back to a token sequence, expanded along the batch dimension.
posemb_grid = (
    grid_image.permute(0, 2, 3, 1)                         # [1, gs_h, gs_w, C]
    .reshape(1, gs_h * gs_w, -1)                           # [1, gs_h * gs_w, C]
    .expand(pos_embed.shape[0], -1, -1)                    # [B, gs_h * gs_w, C]
)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)       # [B, start_index + gs_h * gs_w, C]
posemb.shape


torch.Size([5, 577, 384])

In [15]:
def _resize_pos_embed(self, posemb: torch.Tensor, gs_h: int, gs_w: int) -> torch.Tensor:
    posemb_tok = pos_embed[:, : self.start_index]                    # posemb_tok:  [B,          1 or 2, C]
    posemb_grid = pos_embed[0, self.start_index :]                   # posemb_grid: [N - 1 or 2, C]

    gs_old = int(math.sqrt(len(posemb_grid)))                   # gs_old, 14 or 16 .....

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1)    # posemb_grid: [1, 14, 14, C]
    posemb_grid = posemb_grid.permute(0, 3, 1, 2)               # posemb_grid: [1, C, 14, 14]
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), 
                                mode="bilinear", 
                                align_corners=False)            # posemb_grid: [1, C, new_gs_w, new_gs_w]
    
    posemb_grid = posemb_grid.permute(0, 2, 3, 1)               # posemb_grid: [1, new_gs_w, new_gs_w, C]
    posemb_grid = posemb_grid.reshape(1, gs_h * gs_w, -1)       # posemb_grid: [1, new_gs_w x new_gs_w, C]
    posemb_grid = posemb_grid.expand(pos_embed.shape[0], -1, -1)# posemb_grid: [B, new_gs_w x new_gs_w, C]    # FIX
    posemb = torch.cat([posemb_tok, posemb_grid], dim = 1)      # posemb_grid: [B, new_gs_w x new_gs_w + 1 or 2, C] added Batch dim back + 1 or 2 CLS token
    return posemb

dino.start_index = 1
# 197 tokens = 1 CLS token + a 14 x 14 patch grid, the layout that matches start_index == 1.
out = _resize_pos_embed(self = dino, posemb = torch.rand(5, 197, 384), gs_h = 24, gs_w = 24)
out.shape

torch.Size([5, 577, 384])

In [16]:
# Store the patch size on the model instance.
dino.patch_size = patch_size

In [17]:
x = torch.rand(5, 3, 256, 256)
x = pretrained.model.patch_embed.proj(x) # x: [B, C, H // 16, W // 16]
# Target grid = the token grid actually produced for this input (not the patch size).
grid_h, grid_w = x.shape[-2:]
x = x.flatten(2)                         # x: [B, C, H // 16 x W // 16]
x = x.transpose(1,2)                     # x: [B, N, C]

pos_embed = _resize_pos_embed(self = dino, posemb = dino.pos_embed, gs_h = grid_h, gs_w = grid_w)

# Adding CLS tokens
cls_tokens = pretrained.model.cls_token             # cls_tokens: [1, 1, C]
cls_tokens = cls_tokens.expand(x.shape[0], -1, -1)  # cls_tokens: [B, 1, C]

x = torch.cat([cls_tokens, x], dim=1)               # x:   [B, N + 1, C]

# The position embedding may carry batch size 1 and broadcast over the batch.
assert (x.shape[1:] == pos_embed.shape[1:]
        and pos_embed.shape[0] in (1, x.shape[0])), f"x shape: {x.shape}, pos_embed: {pos_embed.shape}"
x = x + pos_embed
x = pretrained.model.pos_drop(x)                # x:   [B, N + 1, C]
for blk in pretrained.model.blocks:
    x = blk(x)                                  # x:   [B, N + 1, C]
x = pretrained.model.norm(x)                    # x:   [B, N + 1, C]
x.shape

torch.Size([5, 257, 384])

In [None]:
def forward_flex(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the ViT on an image batch whose size may differ from the training size.

    The position embedding is resampled to the patch grid that the patch
    projection produces for this input, then the CLS token is prepended and
    the tokens go through `pos_drop`, every block in `blocks`, and `norm`.

    Args:
        self: model exposing `patch_embed.proj`, `pos_embed`, `cls_token`,
              `pos_drop`, `blocks`, `norm` and `start_index`.
        x:    image batch, [B, 3, H, W].

    Returns:
        Normalised token sequence, [B, N + 1, C], where N is the number of
        patches produced for the input.
    """
    x = self.patch_embed.proj(x)             # x: [B, C, H // patch, W // patch]
    # Target grid = the token grid of this input. Passing the patch size here
    # only matches when H // patch happens to equal the patch size.
    grid_h, grid_w = x.shape[-2:]
    x = x.flatten(2).transpose(1, 2)         # x: [B, N, C]

    pos_embed = _resize_pos_embed(self = self, posemb = self.pos_embed, gs_h = grid_h, gs_w = grid_w)

    # Prepend the CLS token to every item in the batch.
    cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # cls_tokens: [B, 1, C]
    x = torch.cat([cls_tokens, x], dim=1)                   # x: [B, N + 1, C]

    # The position embedding may carry batch size 1 and broadcast over the batch.
    assert (x.shape[1:] == pos_embed.shape[1:]
            and pos_embed.shape[0] in (1, x.shape[0])), f"x shape: {x.shape}, pos_embed: {pos_embed.shape}"
    x = x + pos_embed
    x = self.pos_drop(x)                     # x: [B, N + 1, C]
    for blk in self.blocks:
        x = blk(x)                           # x: [B, N + 1, C]
    return self.norm(x)                      # [B, N + 1, C]

# Smoke test on a 256 x 256 batch; forward_flex is called as a plain function with the model as `self`.
x = torch.rand(5, 3, 256, 256)
out = forward_flex(self = pretrained.model, x = x)
out.shape

torch.Size([5, 257, 384])

In [32]:
def forward_vit(pretrained: nn.Module, x: torch.Tensor) -> dict:
    """
    Run the backbone and return the hooked activations after `rearrange`.

    Args:
        pretrained: container exposing `model.forward_flex` and `rearrange`.
        x:          input batch passed to `model.forward_flex`.

    Returns:
        Dict mapping every key currently in the module-level `activations`
        dict to `pretrained.rearrange(value)`.
    """
    # The forward output itself is discarded; the pass is run for the entries
    # that the registered forward hooks write into `activations`.
    pretrained.model.forward_flex(x)
    return {k: pretrained.rearrange(v) for k, v in activations.items()}

# Attach both helpers to the model instance as bound methods, so they can be called as `model.forward_flex(x)`.
pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
pretrained.model._resize_pos_embed = types.MethodType(_resize_pos_embed, pretrained.model)

In [87]:
# Start from an empty store so entries written by earlier forward passes are not carried over.
activations = {}

In [None]:

def make_vit_backbone(model: nn.Module, 
                      patch_size = [16, 16],
                      hooks = [2, 5, 8, 11],
                      hook_patch = True,
                      start_index = 1):
    """
    Wrap a ViT in a container that exposes intermediate activations.

    Args:
        model:       ViT exposing `blocks` and `pos_drop`.
        patch_size:  stored (as a copy) on the model as `patch_size`.
        hooks:       four indices into `model.blocks`; the output of each is
                     captured under the keys '0'..'3'.
        hook_patch:  when True, the output of `pos_drop` is captured under '4'.
        start_index: number of prefix tokens; stored on the model and passed
                     to `Readout`.

    Returns:
        A new `nn.Module` with `model` (the given ViT, now carrying the
        `forward_flex` and `_resize_pos_embed` bound methods) and `rearrange`
        (Readout followed by Transpose(1, 2)).

    Note:
        Each call registers additional forward hooks on `model`.
    """
    assert len(hooks) == 4, "exactly four block hooks are expected"

    # Create a fresh container instance. Assigning to the `nn.Module` class
    # itself would attach `model` to every module in the process.
    pretrained = nn.Module()
    pretrained.model = model

    for slot, block_idx in enumerate(hooks):
        model.blocks[block_idx].register_forward_hook(get_activation(f'{slot}'))
    if hook_patch:
        model.pos_drop.register_forward_hook(get_activation('4'))

    pretrained.rearrange = nn.Sequential(
        Readout(start_index=start_index),
        Transpose(1, 2),                  # [B, C, N]
    )

    model.start_index = start_index
    model.patch_size = list(patch_size)   # copy, so the shared default list is never aliased

    model.forward_flex = types.MethodType(forward_flex, model)
    model._resize_pos_embed = types.MethodType(_resize_pos_embed, model)

    return pretrained
    

In [76]:
# Build the backbone container around `dino` and show the class name of the wrapped model.
pretained_dino = make_vit_backbone(model = dino)
pretained_dino.model.__class__.__name__

'VisionTransformer'

In [94]:
img = torch.rand(5, 3, 224, 224)

# Regular forward pass of the wrapped model in evaluation mode.
pretained_dino.model.eval()
out1 = pretained_dino.model(img)
out1.shape

torch.Size([5, 384])

In [95]:
# Report the shape of every captured output.
for slot, block_idx in enumerate(hooks):
    print(f"Block {block_idx} Shape:", list(activations[str(slot)].shape))
if hook_patch:
    print("Dropout Block Shape:", list(activations['4'].shape))

Block 2 Shape: [5, 197, 384]
Block 5 Shape: [5, 197, 384]
Block 8 Shape: [5, 197, 384]
Block 11 Shape: [5, 197, 384]
Dropout Block Shape: [5, 197, 384]


In [96]:
x = torch.rand(5, 77, 768)

# `rearrange` is Readout followed by Transpose(1, 2): [B, N, C] -> [B, C, N - start_index].
out2 = pretained_dino.rearrange(x)
out2.shape

torch.Size([5, 768, 76])

In [None]:
x = torch.rand(5, 3, 256, 256)

# Same check as before, this time through the method bound onto the model.
out2 = pretained_dino.model.forward_flex(x)
out2.shape

torch.Size([5, 257, 384])

In [None]:
dino.start_index = 1
# 197 tokens = 1 CLS token + a 14 x 14 patch grid, the layout that matches start_index == 1.
out = _resize_pos_embed(self = dino, posemb = torch.rand(5, 197, 384), gs_h = 24, gs_w = 24)

In [110]:
# 197 tokens = 1 CLS token + a 14 x 14 patch grid, the layout that matches start_index == 1.
x = torch.rand(5, 197, 384)

out2 = pretained_dino.model._resize_pos_embed(posemb = x, gs_h = 24, gs_w = 24)
out2.shape

torch.Size([5, 577, 384])