In [41]:
import torch
import torch.nn as nn
from torchvision import transforms
from vision_transformer_pytorch import VisionTransformer

transform = transforms.Compose([
    transforms.Resize((384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

In [59]:
model = VisionTransformer.from_name('ViT-B_16', num_classes=2)

In [81]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from typing import List, Literal

def get_attention_map(img, get_mask=False):
    x = transform(img)
    x.size()

    logits, att_mat = model(x.unsqueeze(0))

    att_mat = torch.stack(att_mat).squeeze(1)

    # Average the attention weights across all heads.
    att_mat = torch.mean(att_mat, dim=1)

    # To account for residual connections, we add an identity matrix to the
    # attention matrix and re-normalize the weights.
    residual_att = torch.eye(att_mat.size(1))
    aug_att_mat = att_mat + residual_att
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

    # Recursively multiply the weight matrices
    joint_attentions = torch.zeros(aug_att_mat.size())
    joint_attentions[0] = aug_att_mat[0]

    for n in range(1, aug_att_mat.size(0)):
        joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

    v = joint_attentions[-1]
    grid_size = int(np.sqrt(aug_att_mat.size(-1)))
    mask = v[0, 1:].reshape(grid_size, grid_size).detach().numpy()
    if get_mask:
        result = cv2.resize(mask / mask.max(), img.size)
    else:        
        mask = cv2.resize(mask / mask.max(), img.size)[..., np.newaxis]
        result = (mask * img).astype("uint8")
    
    return result

def plot_attention_map(original_img, att_map):
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map Last Layer')
    _ = ax1.imshow(original_img)
    _ = ax2.imshow(att_map)
    
def rollout(attentions : List[torch.Tensor], discard_ratio : float, head_fusion : Literal['mean', 'max', 'min']):
    # attentions : List which consist of [1, channels, height, width] size of attention data
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention in attentions:
            if head_fusion == 'mean':
                attention_heads_fused = attention.mean(axis = 1)
            elif head_fusion == 'max':
                attention_heads_fused = attention.max(axis = 1)[0]
            elif head_fusion == 'min':
                attention_heads_fused = attention.min(axis = 1)[0]
            
            flat = attention_heads_fused.view(attention_heads_fused.size()[0], -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), dim = -1, largest = False)
            indices = indices[indices != 0]
            flat[0, indices] = 0
            
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2.0
            a = a / a.sum(dim = -1)
            
            result = torch.matmul(a, result)
    
    mask = result
    mask = mask.numpy()
    mask = mask / np.max(mask)
    return mask

In [92]:
from PIL import Image

img = Image.open("./dataset/dur21_dis0/test/disruption/21325_1369_1390/0000.jpg")

In [84]:
class VITAttentionRollout:
    def __init__(self, model : nn.Module, attention_layer_name = 'attn', head_fusion = 'mean', discard_ratio = 0.9):
        self.model = model
        self.head_fusion = head_fusion
        self.discard_ratio = discard_ratio
        
        for name, module in self.model.named_modules():
            for name_, module_ in module.named_children():
                if attention_layer_name in name_:
                    module.register_forward_hook(self.get_attention)
        
        self.attentions = []
    
    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())
    
    def __call__(self, input_tensor : torch.Tensor):
        self.attentions = []
        
        with torch.no_grad():
            output = self.model(input_tensor)
        
        return rollout(self.attentions, self.discard_ratio, self.head_fusion)

In [93]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

x = transform(img).to(device).unsqueeze(0)
model.to(device)

rollout_model = VITAttentionRollout(model)
att_mask = rollout_model(x)

In [102]:
img = Image.open("./dataset/dur21_dis0/test/disruption/21325_1369_1390/0000.jpg")
img = transform(img).numpy()

def plot_attention_map(original_img, att_map):
    att_map = cv2.resize(384, 384)
    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    ax2.set_title('Attention Map Last Layer')
    _ = ax1.imshow(original_img)
    _ = ax2.imshow(att_map)

plot_attention_map(img, att_mask)

error: OpenCV(4.5.2) :-1: error: (-5:Bad argument) in function 'resize'
> Overload resolution failed:
>  - Can't parse 'dsize'. Input argument doesn't provide sequence protocol
>  - Can't parse 'dsize'. Input argument doesn't provide sequence protocol


In [100]:
img.shape

(3, 384, 384)