In [None]:
# Adapted from https://github.com/jeonsworld/ViT-pytorch.git with some small changes.
# Comments prefixed with "->" were added on top of the original code.

In [None]:
import os
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
from urllib.request import urlretrieve
from PIL import Image
from torchvision import transforms

from vit_models.modeling import VisionTransformer, CONFIGS

In [None]:
# Fetch the ImageNet label list and the pretrained checkpoint into
# vit_models/attention_data/, skipping any file that is already present.
model_name = "ViT-B_16-224"
attention_data_dir = os.path.join("vit_models", "attention_data")
os.makedirs(attention_data_dir, exist_ok=True)

labels_path = os.path.join(attention_data_dir, "ilsvrc2012_wordnet_lemmas.txt")
weights_path = os.path.join(attention_data_dir, f"{model_name}.npz")

# -> Save to the same paths that are checked here and read later; downloading
#    to a different directory meant the files were never found.
if not os.path.isfile(labels_path):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", labels_path)
if not os.path.isfile(weights_path):
    urlretrieve(f"https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/{model_name}.npz", weights_path)

# Map class index -> label line. Lines keep their trailing newline, exactly as
# iterating over the file yields them.
with open(labels_path) as labels_file:
    imagenet_labels = dict(enumerate(labels_file))

In [None]:
# Prepare Model
# -> CONFIGS is keyed by the architecture name without the trailing "-224" resolution suffix.
img_size = 224
config = CONFIGS[model_name.rsplit("-", 1)[0]]
model = VisionTransformer(config, num_classes=1000, zero_head=False, img_size=img_size, vis=True)
# -> Close the .npz archive once the weights are loaded
#    (assumes load_from copies every array it needs before returning — TODO confirm).
with np.load(f"vit_models/attention_data/{model_name}.npz") as weights:
    model.load_from(weights)
model.eval()

# Resize to the model's input resolution and map pixel values from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# -> Force 3-channel RGB: grayscale / RGBA / CMYK files would otherwise give a
#    tensor whose channel count does not match the 3-channel Normalize above.
im = Image.open("pics/corgi_image.jpg").convert("RGB")
x = transform(im)
print(f"input image shape: {x.shape}")

In [None]:
# Single forward pass; the model was built with vis=True and returns
# (logits, per-layer attention weights).
with torch.no_grad():
    logits, att_mat = model(x.unsqueeze(0))

# -> Stack attention maps across all layers and drop the batch dimension (batch size is 1).
#    Presumably (layers, 1, heads, tokens, tokens) -> (layers, heads, tokens, tokens) — TODO confirm.
att_mat = torch.stack(att_mat).squeeze(1)

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
# -> if q0 pays attention to all tokens and creates a weighted sum of values (sum of weights = 1)
# -> then the result is added to x0 from before self-attention (skip connection in the transformer architecture)
# -> so it's like paying attention to x0 (v0) as a separate token with weight = 1
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
# -> the weights for each query (row) should sum to 1
# -> all weights are positive so softmax isn't necessary, and also the distribution is still sharp from the last softmax
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

# Recursively multiply the weight matrices (attention rollout).
# -> With A = aug_att_mat: joint_attentions[n] = A[n] @ A[n-1] @ ... @ A[0],
# -> so the final entry is A[-1] @ (A[-2] @ (A[-3] @ (... @ A[0])))
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n - 1])

# Attention from the output token (row 0) to the input patches (columns 1:).
v = joint_attentions[-1]
# -> Token 0 is not an image patch, so the square grid is built from the remaining tokens only.
num_patch_tokens = aug_att_mat.size(-1) - 1
grid_size = int(round(np.sqrt(num_patch_tokens)))
assert grid_size * grid_size == num_patch_tokens, f"{num_patch_tokens} patch tokens do not form a square grid"
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
# -> Scale so the strongest weight is 1, then resize to the image; PIL's im.size is
#    (width, height), which is the dsize order cv2.resize expects.
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * np.asarray(im)).astype("uint8")

# -> cv2.imwrite returns False (instead of raising) when e.g. the target directory is missing.
os.makedirs("_results", exist_ok=True)
if not cv2.imwrite('_results/ViT_B_mask.jpg', cv2.applyColorMap(np.uint8(mask[:, :, 0] * 255), cv2.COLORMAP_INFERNO)):
    raise OSError("could not write _results/ViT_B_mask.jpg")

In [None]:
# Side-by-side view: the input image and its final-layer attention overlay.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
for panel, (panel_title, panel_image) in zip((ax1, ax2), (('Original', im), ('Attention Map', result))):
    panel.set_title(panel_title)
    _ = panel.imshow(panel_image)

# Rank every ImageNet class by predicted probability and report the best five.
# Each label line still ends in a newline, hence end='' on the print.
probs = torch.softmax(logits, dim=-1)
top5 = torch.argsort(probs, dim=-1, descending=True)
print("Prediction Label and Attention Map!\n")
for class_id in top5[0, :5].tolist():
    print(f'{probs[0, class_id]:.5f} : {imagenet_labels[class_id]}', end='')

In [None]:
# Visualize the rolled-out attention after every layer and save each mask as a
# colour-mapped JPEG.
# -> cv2.imwrite returns False (instead of raising) when the target directory is missing.
os.makedirs("_results", exist_ok=True)
for i, v in enumerate(joint_attentions):
    # Attention from the output token to the input space.
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
    result = (mask * np.asarray(im)).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map_%d Layer' % (i+1))
    _ = ax1.imshow(im)
    _ = ax2.imshow(result)
    # -> Render this figure now and release it, instead of keeping one 16x16
    #    figure open per layer (matplotlib warns once more than 20 are open).
    plt.show()
    plt.close(fig)

    # -> File names use the 0-based layer index, while the plot titles above are 1-based.
    out_path = f'_results/ViT_B_mask_layer{i:02d}.jpg'
    if not cv2.imwrite(out_path, cv2.applyColorMap(np.uint8(mask[:, :, 0] * 255.), cv2.COLORMAP_INFERNO)):
        raise OSError(f"could not write {out_path}")