In [None]:
import torch
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
from google.colab.patches import cv2_imshow
import matplotlib.pyplot as plt
import math

In [None]:
# Select the compute device name: the GPU when PyTorch can see one, the CPU otherwise.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
import kagglehub
# Download the "ambityga/imagenet100" dataset through kagglehub. The value it
# returns is kept in `path` and is used by the following cells as the dataset
# root directory.
path = kagglehub.dataset_download("ambityga/imagenet100")
print("Path to dataset files:", path)

In [None]:
# List the contents of the n01773549 folder under val.X so a sample file can be
# picked; IPython substitutes the Python variable `path` for `$path`.
!ls -l $path/val.X/n01773549

In [None]:
# Location of the sample image: <dataset root>/val.X/<class folder>/<file name>.
imagepath = '/'.join([path, 'val.X', 'n01773549', 'ILSVRC2012_val_00008316.JPEG'])

In [None]:
# Decode the sample image from disk (OpenCV returns the pixels in BGR order).
image = cv2.imread(imagepath)
# cv2.imread does not raise on a missing or undecodable file, it returns None.
# Fail here with a clear message instead of an obscure error in a later cell.
if image is None:
    raise FileNotFoundError(f"Could not read image: {imagepath}")
# Show the image inline (Colab replacement for cv2.imshow).
cv2_imshow(image)

In [None]:
# Load the model: fetch the 'dino_vitb8' entry point from the
# facebookresearch/dino repository (main branch) through torch.hub.
vitb8 = torch.hub.load(repo_or_dir='facebookresearch/dino:main', model='dino_vitb8')

In [None]:
# Switch the model to evaluation mode. eval() returns the module itself, so as
# the last expression of the cell it also displays the model structure.
vitb8.eval()

In [None]:
# preprocessing
print(image.shape)
# cv2.imread produced BGR; convert once to RGB, which is the channel order the
# mean/std values below are listed in.
# NOTE: this overwrites `image`, so re-run the image-loading cell before
# re-running this one (otherwise the channels get swapped back).
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
image_size = (224, 224)
# `image` is already RGB at this point, so swapRB must be False. Swapping again
# would return the blob to BGR order and apply the red statistics to the blue
# channel (and vice versa).
# blobFromImage scales to [0, 1], resizes to 224x224 (no crop) and returns NCHW.
blob = cv2.dnn.blobFromImage(image, 1.0/255, image_size, swapRB=False, crop=False)
# Per-channel standardisation with the usual ImageNet mean / std, order R, G, B.
blob[0][0] = (blob[0][0] - 0.485)/0.229
blob[0][1] = (blob[0][1] - 0.456)/0.224
blob[0][2] = (blob[0][2] - 0.406)/0.225
x = torch.tensor(blob) # 1 x 3 x 224 x 224
# Sanity check of the normalised value range and the blob layout.
print(x.min().item(),x.max().item())
print(blob.shape)

In [None]:
# Patch embedding.
# Keep the batch size and the two spatial sizes of the input image tensor; they
# are needed later, after `x` has been replaced by the token sequence.
B, _, h, w = x.size()
# Linear patch projection, i.e. Conv2d(3, 768, kernel_size=(8, 8), stride=(8, 8)).
x = vitb8.patch_embed(x)
# Expected: 1 x 784 x 768   (784 = 28 x 28, 224 = 28 x 8, 768 = 12 x 64)
print(x.size())

In [None]:
# Prepend the class token.
# Broadcast the model's cls_token along the batch dimension (no copy is made).
cls_tokens = vitb8.cls_token.expand(B, -1, -1)
# Put it in front of the patch tokens along the sequence axis -> 1 x 785 x 768.
x = torch.cat([cls_tokens, x], 1)
print(x.size())

In [None]:
# Positional encoding: add the model's (possibly interpolated) position
# embedding to every token. Out-of-place add, result is 1 x 785 x 768.
x = torch.add(x, vitb8.interpolate_pos_encoding(x, h, w))

In [None]:
# encode
# Run the transformer blocks by hand so the attention weights of every layer
# can be collected in `attn_maps`.
attn_maps = []
# Inference only: without no_grad autograd would record the whole forward pass
# and keep every activation (and each stored attention map) attached to a graph.
with torch.no_grad():
    for block in vitb8.blocks:
        # block.attn returns both the attention output and the attention weights.
        y, attn = block.attn(block.norm1(x))
        attn_maps.append(attn)
        # Residual connections, written out-of-place so no tensor that another
        # op still references is overwritten.
        x = x + y
        x = x + block.mlp(block.norm2(x))

    # Final layer norm applied to all tokens.
    x = vitb8.norm(x) # 1 x 785 x 768

In [None]:
# Class-token readout.
# Index 0 along the sequence axis is the class token that was concatenated in
# front of the patch tokens; its final embedding is taken as the image feature.
features = x[:, 0, :] # 1 x 768
print(features.size())
print(features)

In [None]:
def draw_mask(img, mask):
    """Weight every pixel of `img` by `mask` and return a uint8 image.

    `mask` is resized with bilinear interpolation to the height and width of
    `img`. When `img` is 3-dimensional the resized mask is replicated three
    times along a new last axis. The product is computed from a float32 copy of
    `img`, clipped to [0, 255] and cast to uint8.
    """
    height, width = img.shape[:2]
    # cv2.resize expects the target size as (width, height).
    weights = cv2.resize(mask, (width, height), interpolation=cv2.INTER_LINEAR)
    if img.ndim == 3:
        # One copy of the weight map per colour channel.
        weights = np.stack([weights] * 3, axis=2)
    weighted = img.astype(np.float32) * weights
    return np.clip(weighted, 0, 255).astype(np.uint8)

In [None]:
# visualize attention maps
# Grayscale version of the (RGB) input image; each head's map is multiplied onto it.
base = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
# Stack the per-layer maps and drop the batch axis (batch size 1):
# (num_layers, num_heads, tokens, tokens), e.g. (12, 12, 785, 785) for the
# 785-token sequence (1 class token + 28 x 28 patches) built above.
att_mats = torch.stack(attn_maps).squeeze(1).detach()
num_layers, num_heads, N, _ = att_mats.shape
# The N - 1 patch tokens are laid back out on a square grid.
grid_size = int(math.sqrt(N-1))
assert grid_size * grid_size == N - 1, "patch tokens do not form a square grid"

# One row of subplots per layer, one column per head.
fig, axes = plt.subplots(num_layers, num_heads,
                         figsize=(2 * num_heads, 2 * num_layers),
                         squeeze=False)
for t in range(num_layers):
    for i in range(num_heads):
        head = att_mats[t, i]      # shape (tokens, tokens)
        # Row 0 is the class token's attention; columns 1: are the patch tokens.
        mask = head[0, 1:].reshape(grid_size, grid_size)
        # Scale so the strongest patch has weight 1. Out-of-place on purpose:
        # an in-place division would write through the view into att_mats.
        mask = mask / mask.max()
        mask = mask.cpu().numpy()
        disp = draw_mask(base, mask)
        ax = axes[t, i]
        ax.imshow(disp, cmap='gray')
        ax.axis('off')
fig.tight_layout()
plt.show()