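This notebook runs inference with a pretrained DINO (v1) model loaded from torch.hub and visualizes, for each attention head, the attention from the [CLS] token to every image patch in the model's last self-attention layer.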

In [None]:
from io import BytesIO
import GPUtil
import torch
from torch import nn
import torchvision
from torchvision import transforms
import numpy as np
import requests
from PIL import Image
import matplotlib.pyplot as plt

# Release cached CUDA allocations (e.g. from a previous run of this notebook), then pick the device.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"device: {device}")

# Report memory for each GPU that GPUtil detects; nothing is printed if it returns an empty list.
gpus = GPUtil.getGPUs()
for gpu in gpus:
    print(
        f"GPU ID: {gpu.id}, GPU Name: {gpu.name}",
        f"Total GPU memory: {gpu.memoryTotal} MB",
        f"Free GPU memory: {gpu.memoryFree} MB",
        f"Used GPU memory: {gpu.memoryUsed} MB",
        sep="\n",
    )

# Load DINO model

In [None]:
# DINO (v1) checkpoints available from the 'facebookresearch/dino:main' torch.hub repo.
# To try another backbone, set MODEL_NAME to one of:
#   ViT:    dino_vits16, dino_vits8, dino_vitb16, dino_vitb8
#   XCiT:   dino_xcit_small_12_p16, dino_xcit_small_12_p8,
#           dino_xcit_medium_24_p16, dino_xcit_medium_24_p8
#   ResNet: dino_resnet50
# NOTE(review): later cells read `model.patch_embed.proj` and call
# `model.get_last_selfattention`; this notebook only exercises them with a ViT
# backbone — confirm they exist before switching to an XCiT/ResNet model.
DINO_HUB_REPO = 'facebookresearch/dino:main'
MODEL_NAME = 'dino_vitb16'
model = torch.hub.load(DINO_HUB_REPO, MODEL_NAME)

print(f"Number of model parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Read the patch size off the patch-embedding convolution's kernel
# (assumes kernel_size == patch size, as for DINO ViTs — confirm for other backbones).
proj_kernel_hw = model.patch_embed.proj.kernel_size
patch_size = proj_kernel_hw[0]
print(f"patch_size: {patch_size}")

In [None]:
# prepare model for inference: freeze every parameter, switch to eval mode and
# move the weights to `device`. Each call returns the module itself, so they chain;
# the chain is the cell's last expression, so the model repr is still displayed.
model.requires_grad_(False).eval().to(device)

# Load an image to run inference on

In [None]:
def preprocess_image(img, size, preserve_aspect_ratio=True):
    """Turn an image into a normalized single-image batch tensor for the model.

    Args:
        img: a PIL image (or anything ``transforms.ToTensor`` accepts, e.g. an HxWxC
            numpy array). PIL images in a non-RGB mode (RGBA, L, P, CMYK, ...) are
            converted to RGB first so the result always has the 3 channels that the
            3-value normalization below expects.
        size: target size forwarded to ``transforms.Resize`` — an ``(height, width)``
            tuple, or an int for the shorter edge.
        preserve_aspect_ratio: currently unused and kept only so existing calls keep
            working; the aspect ratio is determined by the caller's choice of ``size``.

    Returns:
        A float tensor of shape (1, C, H, W) — C == 3 for PIL input — normalized with
        the commonly used ImageNet mean/std values.
    """
    # Without this, e.g. an RGBA PNG yields 4 channels and Normalize fails.
    if isinstance(img, Image.Image) and img.mode != "RGB":
        img = img.convert("RGB")

    transform = transforms.Compose([
        transforms.ToTensor(),
        # add the batch dimension up front so the output is directly a model input
        transforms.Lambda(lambda x: x[None, ...]),
        transforms.Resize(size, antialias=True),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    return transform(img)

In [None]:
# load an image from URL
url = 'https://m.media-amazon.com/images/M/MV5BMTM1MjQzMDA5NV5BMl5BanBnXkFtZTcwMDk5MDg3Mw@@._V1_.jpg'  # into the wild
# url = 'https://static.euronews.com/articles/stories/06/35/53/24/1440x810_cmsv2_548e11b5-0a57-53f4-88d9-5ea703413300-6355324.jpg'  # latest fra ID
# url = 'https://upload.wikimedia.org/wikipedia/commons/1/14/New_Estonian_ID_card_%282021%29%28back%29.jpg'  # EST ID back

# a timeout prevents the cell from hanging forever on an unresponsive server
response = requests.get(url, timeout=30)
# fail with a clear HTTP error instead of a cryptic "cannot identify image file" from PIL
response.raise_for_status()
img = Image.open(BytesIO(response.content))

# configure the new image height for input into the DINO model
new_height = 720
# round down to a multiple of patch_size (at least one patch) so the image tiles exactly into patches
new_height = max(patch_size, new_height // patch_size * patch_size)
# PIL's img.size is (width, height); choose the width that best preserves the aspect ratio,
# rounded to the nearest multiple of patch_size and never smaller than one patch
new_width = max(patch_size, round((img.size[0] / img.size[1]) * new_height / patch_size) * patch_size)
print(new_height, new_width)

In [None]:
# prepare image for inference
# Keep a handle on the original PIL image so this cell can be re-run: after the first run
# `img` holds the preprocessed tensor, which ToTensor inside preprocess_image rejects.
if isinstance(img, Image.Image):
    img_pil = img
img = preprocess_image(img_pil, (new_height, new_width))

In [None]:
# model inference: fetch the self-attention of the last attention layer for all tokens;
# the next cell selects the [CLS] token's row. Gradients are not needed, so skip autograd.
with torch.no_grad():
    attentions = model.get_last_selfattention(img.to(device))

In [None]:
# extract only the attention values of interest
# `attentions` is indexed below as (batch, head, query token, key token), token 0 being [CLS]

# number of heads
nh = attentions.shape[1]

# we keep only the output patch attention: attention from the [CLS] token (query 0)
# to each image patch (keys 1:) of the first image in the batch -> shape (nh, num_patches)
attentions = attentions[0, :, 0, 1:]

# reshape back to the image layout, one value per patch_size x patch_size patch.
# img is (batch, channels, height, width): shape[-2] is the height, shape[-1] the width.
h_featmap = img.shape[-2] // patch_size
w_featmap = img.shape[-1] // patch_size
assert attentions.shape[-1] == h_featmap * w_featmap, (
    f"expected {h_featmap}x{w_featmap} patch tokens, got {attentions.shape[-1]}"
)
attentions = attentions.reshape(nh, h_featmap, w_featmap)

# repeat each patch value patch_size times along both axes (nearest-neighbour upsampling)
# so the maps match the input image resolution; result is a numpy array of shape (nh, H, W)
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

# Plot the image with its attention maps

In [None]:
# visualize the input image with the attention maps

# configure number of columns for the attention maps
n_cols = 3
# one top row for the input image plus enough rows to hold every head (ceil division),
# so any head count works instead of only multiples of n_cols
n_rows = (nh + n_cols - 1) // n_cols + 1

h, w = img.shape[2:]
# squeeze=False keeps `axes` 2-D even when n_cols == 1
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5*n_cols, 5*(h/w)*n_rows), squeeze=False)

# plot the main image in the top row, spanning all columns
gs = axes[0, 0].get_gridspec()
# remove the underlying axes
for ax in axes[0, :]:
    ax.remove()
ax_top = fig.add_subplot(gs[0, :])
# make_grid(normalize=True) min-max rescales the normalized tensor into [0, 1] for display;
# permute (C, H, W) -> (H, W, C) as imshow expects
main_img = torchvision.utils.make_grid(img, normalize=True, scale_each=True).permute((1, 2, 0))
ax_top.imshow(main_img)
ax_top.set_title("input image", fontsize=13)

# plot the attention maps underneath the main image
axes = np.ravel(axes)
for i in range(nh):
    ax = axes[i + n_cols]
    ax.set_title(f"attention-head{i}", fontsize=13)
    ax.imshow(main_img)
    ax.imshow(attentions[i], alpha=0.85)

# hide the empty grid cells left over when nh is not a multiple of n_cols
for ax in axes[nh + n_cols:]:
    ax.axis("off")

# lay out only after every axes (including the spanning top one) exists
fig.tight_layout()
plt.show()