In [2]:
import typing
import io
import os

import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt

from urllib.request import urlretrieve

from PIL import Image
from torchvision import transforms

from VIT import VisionTransformer, CONFIGS

In [None]:
# Prepare Model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=200, zero_head=False, img_size=224, vis=True, extra_attention=None)
# model.load_from(np.load("pretrain_weights/imagenet21k_ViT-B_16.npz"))

weights = torch.load("Tiny ImageNet_origin_12_checkpoint.bin")

model.load_state_dict(weights)
model.eval()

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
im = Image.open("visualization/test.JPEG")
x = transform(im)
x.size()

torch.Size([3, 224, 224])

In [10]:
logits, att_mat = model(x.unsqueeze(0))

att_mat = torch.stack(att_mat).squeeze(1)

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
residual_att = torch.eye(att_mat.size(1))
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the weight matrices
joint_attentions = torch.zeros(aug_att_mat.size())
joint_attentions[0] = aug_att_mat[0]
for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])
    
# Attention from the output token to the input space.
v = joint_attentions[-1]
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()

mask_threshold = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]

mask = np.power(mask, 2.5)   
mask = (mask - mask.min()) / (mask.max() - mask.min())

mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * im).astype("uint8")

# print(result.shape)

threshold1 = 0.4
binary_mask = (mask_threshold > threshold1).astype(np.uint8) * 255

contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

im_marked1 = np.array(im.copy())  # 转换为NumPy数组
cv2.drawContours(im_marked1, contours, -1, (0, 255, 0), 2) 


threshold2 = 0.6
binary_mask = (mask_threshold > threshold2).astype(np.uint8) * 255

contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

im_marked2 = np.array(im.copy())  # 转换为NumPy数组
cv2.drawContours(im_marked2, contours, -1, (0, 255, 0), 2) 


threshold3 = 0.8
binary_mask = (mask_threshold > threshold3).astype(np.uint8) * 255

contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

im_marked3 = np.array(im.copy())  # 转换为NumPy数组
cv2.drawContours(im_marked3, contours, -1, (0, 255, 0), 2) 


plt.imshow(im)


RuntimeError: The size of tensor a (12) must match the size of tensor b (197) at non-singleton dimension 0

In [None]:
fig, (ax2, ax3, ax4, ax5) = plt.subplots(ncols=4, figsize=(20, 20))

# ax1.set_title('Input Image')
ax2.set_title(f'Attention Outline(threshold={threshold1})')
ax3.set_title(f'Attention Outline(threshold={threshold2})')
ax4.set_title(f'Attention Outline(threshold={threshold3})')
ax5.set_title('Attention Map * Input Image')
# _ = ax1.imshow(im)
_ = ax2.imshow(im_marked1)
_ = ax3.imshow(im_marked2)
_ = ax4.imshow(im_marked3)
_ = ax5.imshow(im)

ax5.imshow(im)
heatmap = ax5.imshow(mask.squeeze(), cmap='jet', alpha=0.5)  # 半透明热力图叠加

### Reference
* [attention_flow](https://github.com/samiraabnar/attention_flow)
* [vit-keras](https://github.com/faustomorales/vit-keras)