In [None]:
import torch 
import torch.nn as nn
from torchvision import transforms

import os
import numpy as np 

from typing import Optional, List, Union

import matplotlib.pyplot as plt 
import matplotlib

from tqdm.notebook import tqdm

from PIL import Image

Here are a few small examples to help understand the tensor shapes used by the attention visualization class below.

In [20]:
import torch
import torch.nn as nn

# Two ordinary (groups=1) convolutions; only their weight tensors are of interest.
# PyTorch stores Conv2d weights as (out_channels, in_channels // groups, kH, kW).
dwc = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=1, stride=1, padding=0)
second_dwc = nn.Conv2d(in_channels=3, out_channels=128, kernel_size=7, stride=1, padding=1)

weights, second_weights = dwc.weight, second_dwc.weight
print(weights.shape, second_weights.shape, sep='\n')

torch.Size([3, 3, 1, 1])
torch.Size([128, 3, 7, 7])


The convolution weight tensor has shape (out_channels, in_channels / groups, kernel_size, kernel_size). With the default groups=1, this is (out_channels, in_channels, kernel_size, kernel_size).

In [35]:
# Self-attention over a random batch: 32 sequences of 128 tokens with 512-dim embeddings.
Attention_1 = torch.nn.MultiheadAttention(embed_dim=512, num_heads=8, batch_first=True)
dummy_input = torch.rand(32, 128, 512)
# Self-attention: queries, keys and values are the same tensor.
q = k = v = dummy_input
with torch.no_grad():
    attn_output, attn_output_weights = Attention_1(q, k, v)

# attn_output is (batch, seq_len, embed_dim). attn_output_weights is averaged over the
# heads by default, giving (batch, seq_len, seq_len). The latter is what gets visualized.
print(attn_output.shape)
print(attn_output_weights.shape)

torch.Size([32, 128, 512])


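With batch_first=True, the attention output has shape (batch_size, seq_len, embed_dim). The attention weights, which are averaged over the heads by default, have shape (batch_size, seq_len, seq_len), here (32, 128, 128).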

In [None]:
class Attention_Visualization(nn.Module):
    """Save per-query attention maps, optionally with a DWC branch folded in, as images.

    Provide exactly one of ``qk`` or ``attn``:

    * ``qk``: a pair ``(q, k)``, either a list/tuple or a tensor stacked on dim 0.
      The attention is ``q @ k^T`` with every row divided by its sum (no softmax).
      This presumably expects non-negative, kernelized q/k features; TODO confirm.
    * ``attn``: a precomputed attention tensor.

    A 4-D attention is reduced to index 0 of dim 1 (presumably the head axis of a
    ``(batch, heads, N, M)`` tensor), leaving a 3-D ``(batch, N, M)`` tensor.
    """

    def __init__(self,
                 qk: Optional[Union[torch.Tensor, List[torch.Tensor]]],
                 attn: Optional[torch.Tensor],
                 kernel: Optional[torch.Tensor] = None,
                 name: Optional[str] = None,
                 save_dir: str = './visualize'):
        """
        Args:
            qk: ``(q, k)`` used to build the attention, or None when ``attn`` is given.
            attn: precomputed attention, or None when ``qk`` is given.
            kernel: the linear-attention DWC kernel. This is NOT the ReLU feature kernel
                used in the model code. It is indexed as ``kernel[0, 0]``, so it must be
                rank 4. It is presumably a depthwise conv weight ``(C, 1, k, k)``;
                verify against the caller.
            name: prefix for output folders. ``visualize_all_attentions`` requires it.
            save_dir: root folder for every output. It is created if missing.
        """
        super().__init__()
        assert (qk is None) != (attn is None), "pass exactly one of `qk` or `attn`"
        if qk is not None:
            # Linear attention: matrix product q @ k^T, then row-normalize.
            attn = qk[0] @ qk[1].transpose(-2, -1)
            attn = attn / attn.sum(dim=-1, keepdim=True)
        # Detach so tensors taken from a live autograd graph can later be turned into numpy arrays.
        attn = attn.detach()
        if attn.dim() == 4:
            attn = attn[:, 0, :, :]
        self.attn = attn
        self.kernel = kernel.detach() if kernel is not None else None
        self.name = name
        self.save_dir = save_dir

        os.makedirs(self.save_dir, exist_ok=True)

    @staticmethod
    def set_flag(path: str, flag) -> None:
        """Overwrite ``path/flag.txt`` with ``str(flag)`` (the run counter)."""
        with open(os.path.join(path, 'flag.txt'), mode='w') as f:
            f.write(str(flag))

    @staticmethod
    def get_flag(path: str) -> int:
        """Return the run counter stored in ``path/flag.txt``.

        Returns 0 if the file is missing or has no non-blank line.
        """
        flag_file = os.path.join(path, 'flag.txt')
        if not os.path.exists(flag_file):
            return 0
        with open(flag_file, mode='r') as f:
            lines = [line for line in f.read().splitlines() if line.strip()]
        return int(lines[-1]) if lines else 0

    @staticmethod
    def mask_image(image: Union[torch.Tensor, np.ndarray],
                   attn: Union[torch.Tensor, np.ndarray],
                   color=None,
                   alpha: float = 0.3):
        """Overlay ``attn`` on ``image`` and return an object with ``.save(path)``.

        Not implemented yet. ``visualize_all_attentions`` calls this as
        ``self.mask_image(...)`` and then ``.save(...)``s the result, so a PIL image is
        presumably expected. It also passes ``alpha=-1`` with a red ``color`` to mark
        a single query position.

        Raises:
            NotImplementedError: always. Raising makes the failure explicit instead of
                returning None and crashing later on ``.save``.
        """
        raise NotImplementedError(
            "mask_image is not implemented yet; call visualize_all_attentions with image=None")

    @staticmethod
    def _conv_as_matrix(kernel: torch.Tensor, side: int) -> torch.Tensor:
        """Express a 2-D convolution on a ``side x side`` grid as a dense matrix.

        Returns W of shape ``(side*side, side*side)`` such that ``W @ x.flatten()``
        equals ``conv2d(x, kernel, padding=((kh-1)//2, (kw-1)//2)).flatten()`` for a
        ``side x side`` map ``x``. Kernel sizes are assumed to be odd. Zero padding is
        honoured: taps that fall outside the grid contribute nothing and do not wrap
        around to the neighbouring row.
        """
        k_h, k_w = kernel.shape
        pad_h, pad_w = (k_h - 1) // 2, (k_w - 1) // 2
        num = side * side
        matrix = torch.zeros(num, num, dtype=kernel.dtype, device=kernel.device)
        positions = torch.arange(num, device=kernel.device)
        rows = torch.arange(side, device=kernel.device).repeat_interleave(side)
        cols = torch.arange(side, device=kernel.device).repeat(side)
        for j in range(k_h):
            for l in range(k_w):
                src_r = rows + (j - pad_h)
                src_c = cols + (l - pad_w)
                valid = (src_r >= 0) & (src_r < side) & (src_c >= 0) & (src_c < side)
                # Each (j, l) tap hits a distinct source pixel per query, so no index repeats here.
                matrix[positions[valid], src_r[valid] * side + src_c[valid]] += kernel[j, l]
        return matrix

    def get_attn_matrix(self, reference_tokens: int = 196, gain: float = 10.0) -> torch.Tensor:
        """Return a display-ready ``(N, M)`` attention map for batch element 0.

        If ``self.kernel`` is set, channel ``[0, 0]`` of the DWC kernel is added as its
        equivalent ``(N, N)`` matrix (see ``_conv_as_matrix``). Absolute values are then
        taken, because that sum can be negative. This path requires a square attention
        whose N is a perfect square (an ``a x a`` token grid).

        Rows are then normalized to sum to 1 and multiplied linearly by
        ``N / reference_tokens * gain``. This makes the mean entry
        ``gain / reference_tokens`` whatever N is; the default 196 is presumably a
        14x14 grid. Finally the values are clipped to at most 1.

        Raises:
            ValueError: if a kernel is set but the attention is not ``a*a x a*a``.
        """
        attn_matrix = self.attn[0, :, :].clone()
        if self.kernel is not None:
            num_queries, num_keys = attn_matrix.shape
            side = int(round(num_queries ** 0.5))
            if side * side != num_queries or num_keys != num_queries:
                raise ValueError(
                    "kernel overlay needs a square (a*a, a*a) attention, got "
                    + str(tuple(attn_matrix.shape)))
            kernel = self.kernel[0, 0, :, :].to(attn_matrix)
            attn_matrix = attn_matrix + self._conv_as_matrix(kernel, side)
            attn_matrix = torch.abs(attn_matrix)

        # Row normalization for display.
        attn_matrix = attn_matrix / attn_matrix.sum(dim=-1, keepdim=True)
        # Linear brightness boost, independent of the token count N.
        attn_matrix = attn_matrix * (attn_matrix.shape[0] / reference_tokens) * gain
        return attn_matrix.clamp(max=1)

    def get_all_attn(self, max_num: Optional[int] = None) -> List[np.ndarray]:
        """Return one square numpy map per (sub-sampled) query row of ``get_attn_matrix``.

        The last ``s*s`` key columns are reshaped into an ``s x s`` map, where
        ``s = floor(sqrt(M))``. Any leading extra key columns (for example a class
        token) are dropped.

        Args:
            max_num: if given, keep every ``ceil(N / max_num)``-th query, which is at
                most ``max_num`` maps. ``None`` keeps all N queries.

        Raises:
            ValueError: if ``max_num`` is not positive.
        """
        import math

        attn = self.get_attn_matrix()
        num_queries, num_keys = attn.shape
        side = int(num_keys ** 0.5)
        shape_remain = num_keys - side * side
        step, count = 1, num_queries
        if max_num is not None:
            if max_num <= 0:
                raise ValueError("max_num must be positive, got " + str(max_num))
            step = max(1, math.ceil(num_queries / max_num))
            count = num_queries // step
        return [attn[i * step, shape_remain:].reshape(side, side).cpu().numpy()
                for i in range(count)]

    def visualize_all_attentions(self, max_num: Optional[int] = None,
                                 image: Optional[str] = None, **kwargs) -> None:
        """Write one image per (sub-sampled) query attention map.

        Files go to ``<save_dir>/<name>_all/<name>_<run>/<i>.png``. Here ``<run>`` is a
        counter kept in ``<save_dir>/<name>_all/flag.txt`` and bumped after every call,
        so earlier runs are not overwritten.

        Args:
            max_num: forwarded to ``get_all_attn`` to sub-sample queries. ``None`` keeps all.
            image: path of an image to overlay the maps on via ``mask_image``. When it is
                None, each map is drawn as a matrix plot with a colour bar instead.
            **kwargs: forwarded to ``mask_image`` for the per-query overlays.

        Raises:
            ValueError: if ``self.name`` is None.
        """
        import math

        if self.name is None:
            raise ValueError("`name` must be set before calling visualize_all_attentions")
        path = os.path.join(self.save_dir, self.name + '_all')
        os.makedirs(path, exist_ok=True)
        # Sub-sample with the same max_num used for the query markers below. That way
        # <run>/<i>.png and query/<i>.png refer to the same query.
        all_attn = self.get_all_attn(max_num=max_num)
        flag = self.get_flag(path=path)
        run_dir = os.path.join(path, self.name + '_' + str(flag))
        os.makedirs(run_dir, exist_ok=True)

        if image is None:
            norm = matplotlib.colors.Normalize(vmin=0, vmax=1)
            for i, attn_map in enumerate(tqdm(all_attn)):
                # plt.matshow opens a new figure sized to the matrix aspect ratio.
                im = plt.matshow(attn_map, cmap='Blues', norm=norm)
                fig = im.figure
                fig.colorbar(im)
                im.axes.set_title('Attention Mask')
                fig.savefig(os.path.join(run_dir, str(i) + '.png'), dpi=600)
                plt.close(fig)
        else:
            # Use a context manager so the file handle is closed once the pixels are copied.
            with Image.open(image) as img:
                image_array = np.array(img)
            for i, attn_map in enumerate(tqdm(all_attn)):
                result = self.mask_image(image_array, attn_map, **kwargs)
                result.save(os.path.join(run_dir, str(i) + '.png'))
            # On the first run only, also save markers that show which query each map belongs to.
            if flag == 0 and max_num is not None and all_attn:
                height, width = all_attn[0].shape
                num_positions = height * width
                step = max(1, math.ceil(num_positions / max_num))
                query_dir = os.path.join(path, 'query')
                os.makedirs(query_dir, exist_ok=True)
                for i in range(num_positions // step):
                    query_mask = np.zeros(shape=(height, width), dtype=float)
                    query_mask[(i * step) // width, (i * step) % width] = 1.0
                    result = self.mask_image(image_array, query_mask, alpha=-1, color=[255., 33., 33.])
                    result.save(os.path.join(query_dir, str(i) + '.png'))
        self.set_flag(path=path, flag=flag + 1)

debug