# Visualizing the Attention from a Vision Transformers trained using DINO (self-distillation with no labels)


In [1]:
import os
import sys

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as pth_transforms

import matplotlib.pyplot as plt
from PIL import Image
import skimage

from facebookresearch_dino_main.visualize_attention import  display_instances

## Loading the model

In [2]:
ckpt_dir = './ckpts'
vits_dino_path = os.path.join(ckpt_dir, 'dino_deitsmall16_pretrain.pth')
state_dict = torch.load(vits_dino_path)

model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Using {device}')

model = model.to(device)
model.eval()

Using cache found in C:\Users\Zace VR/.cache\torch\hub\facebookresearch_dino_main


Using cuda


VisionTransformer(
  (patch_embed): PatchEmbed(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (pos_drop): Dropout(p=0.0, inplace=False)
  (blocks): ModuleList(
    (0): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=384, out_features=1152, bias=True)
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (drop_path): Identity()
      (norm2): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (mlp): Mlp(
        (fc1): Linear(in_features=384, out_features=1536, bias=True)
        (act): GELU()
        (fc2): Linear(in_features=1536, out_features=384, bias=True)
        (drop): Dropout(p=0.0, inplace=False)
      )
    )
    (1): Block(
      (norm1): LayerNorm((384,), eps=1e-06, elementwise_affine=True)
      (attn): Attention(
        (

In [None]:
# open image
DATA_PATH = 'D:/self_supervised_pathology/datasets/NCT-CRC-HE-100K-NONORM/LYM'
OUT_PATH = 'D:/self_supervised_pathology/outputs/attention/dino_base/NCT-CRC-HE-100K-NONORM/LYM'

print('Making output dir')
os.makedirs(OUT_PATH, exist_ok=True)

# im_names = os.scandir(DATA_PATH)
image_size = 224
patch_size = 16
threshold = 0.8


# for im in im_names:
im = 'LYM-AAAEAEME.tif'
image_path = os.path.join(DATA_PATH, im)
print(f'handling {im}')
if os.path.isfile(image_path):

    with open(image_path, 'rb') as f:
        img = Image.open(f)
        img = img.convert('RGB')
else:
    print(f"Provided image path {image_path} is non valid.")

transform = pth_transforms.Compose([
    pth_transforms.Resize(image_size),
    pth_transforms.ToTensor(),
    pth_transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])
img = transform(img)

# make the image divisible by the patch size
w, h = img.shape[1] - img.shape[1] % patch_size, img.shape[2] - img.shape[2] % patch_size
img = img[:, :w, :h].unsqueeze(0)

w_featmap = img.shape[-2] // patch_size
h_featmap = img.shape[-1] // patch_size

print('computing attention')
attentions = model.get_last_selfattention(img.to(device))

nh = attentions.shape[1] # number of head

# we keep only the output patch attention
attentions = attentions[0, :, 0, 1:].reshape(nh, -1)

if threshold is not None:
    print('Thresholding')
    # we keep only a certain percentage of the mass
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - threshold)
    idx2 = torch.argsort(idx)
    for head in range(nh):
        th_attn[head] = th_attn[head][idx2[head]]
    th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()
    # interpolate
    th_attn = nn.functional.interpolate(th_attn.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

attentions = attentions.reshape(nh, w_featmap, h_featmap)
attentions = nn.functional.interpolate(attentions.unsqueeze(0), scale_factor=patch_size, mode="nearest")[0].cpu().numpy()

# save attentions heatmaps
print('Saving heatmaps')
os.makedirs(OUT_PATH, exist_ok=True)
torchvision.utils.save_image(torchvision.utils.make_grid(img, normalize=True, scale_each=True), os.path.join(OUT_PATH, im))
for j in range(nh):
    fname = os.path.join(OUT_PATH, "attn-head" + str(j) + ".png")
    plt.imsave(fname=fname, arr=attentions[j], format='png')
    print(f"{fname} saved.")

if threshold is not None:
    print('displaying')
    image = skimage.io.imread(os.path.join(OUT_PATH, im))
    for j in range(nh):
        display_instances(image, th_attn[j], fname=os.path.join(OUT_PATH, "mask_th" + str(threshold) + "_head" + str(j) +".png"), blur=False)


In [9]:
print(model.state_dict())

OrderedDict([('cls_token', tensor([[[ 3.0054e-02, -3.0335e-05, -1.2904e-04, -2.7991e-03,  4.8935e-04,
          -5.6446e-03, -6.1489e-03, -1.0758e-02,  2.6034e-03,  1.0472e-02,
          -5.9286e-04,  4.5921e-04, -1.2401e-02, -1.1338e-02,  3.0300e-02,
           1.5348e-03, -2.2310e-02, -9.5671e-03, -1.2471e-02, -1.0681e-02,
           2.7507e-02,  5.4123e-04, -1.8679e-03, -5.2474e-03, -6.1964e-03,
           3.1101e-04, -3.0664e-03,  7.0276e-04, -9.9438e-03,  2.8920e-04,
           7.8951e-03, -1.1650e-02, -4.1362e-03,  3.6069e-03, -3.1683e-04,
           9.1873e-04,  3.8778e-03, -1.1693e-02,  1.8876e-03, -9.7881e-03,
          -2.2812e-03,  8.7846e-04,  2.3671e-04, -1.1528e-02,  1.4693e-02,
          -2.7103e-03, -5.7107e-04,  9.1824e-04, -1.0475e-02,  7.5951e-04,
           7.0727e-04,  2.7722e-02,  5.7413e-03,  2.1729e-03, -3.0501e-03,
          -4.4870e-03, -5.5557e-04, -2.0807e-03, -2.9152e-03,  1.5234e-03,
          -1.0835e-02,  8.9625e-03, -2.5785e-03,  1.0576e-03,  4.7416e-04