# Visualizing ViT Attention

In [None]:
try:
    from vit_rollout import VITAttentionRollout,rollout
    from vit_explain import show_mask_on_image
except ModuleNotFoundError as s:
    print('Installing required files...')
    url =f"https://raw.githubusercontent.com/lilloukas/vit-explain/main/vit_rollout.py"
    url2 = f"https://raw.githubusercontent.com/jacobgil/vit-explain/main/vit_explain.py"
    url3 = f"https://raw.githubusercontent.com/jacobgil/vit-explain/main/vit_grad_rollout.py"
    urls = url,url2,url3
    for url in urls:
        !wget --no-cache --backups=1 {url}
    from vit_rollout import VITAttentionRollout,rollout
    from vit_explain import show_mask_on_image
try:
    import timm
except:
    !pip install timm
    import timm
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import StratifiedKFold
from tqdm.notebook import tqdm
import glob
from PIL import Image
import cv2
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from mpl_toolkits.axes_grid1 import ImageGrid
import math
import os
import random
colab = False

# MODEL_PATH: fine-tuned ViT checkpoint to evaluate.
# IMGS_PATH: root directory of the images to be evaluated, one sub-directory per art style.
# NOTE(review): IMGS_PATH is only set in the non-Colab branch, so the cells that list
# IMGS_PATH raise NameError on Colab -- set it there before running them.
if colab:
    MODEL_PATH ='/content/gdrive/Shareddrives/520 Project/Saved Models/ViT/best_ViT_one_layer.pth'
else:
    MODEL_PATH ='/projectnb2/dl523/projects/Sarcasm/520 Project/Saved_Models/best_ViT_one_layer.pth'
    IMGS_PATH = '/projectnb/dl523/students/colejh/520/wikipaintings_small/wikipaintings_test' 

# Use the first GPU when available. Chosen before loading so the checkpoint can be mapped onto it.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the ViT-Huge/14 architecture with a 25-way head and load the fine-tuned weights.
# pretrained=False: load_state_dict (strict by default) overwrites every parameter, so
# downloading the ImageNet-21k weights first would be wasted work.
# map_location lets a checkpoint saved from a GPU load on a CPU-only machine.
our_ViT = timm.create_model('vit_huge_patch14_224_in21k', pretrained=False, num_classes=25)
our_ViT.load_state_dict(torch.load(MODEL_PATH, map_location=device))

# Input preprocessing matching the model's data config (input size, interpolation, mean/std).
config = resolve_data_config({}, model=our_ViT)
transform = create_transform(**config)

our_ViT.to(device)
our_ViT.eval()




In [None]:
# Attention-rollout helpers reused across visualize_attention calls, keyed by (model, head_fusion).
_ATTENTION_ROLLOUTS = {}


def visualize_attention(model, transform, img_path, img_size=224, show_image=False,
                        return_both=False, head_fusion='max'):
    '''
    Visualize the attention rollout of a vision transformer for a single image.

    The transformed image is run through `model` via VITAttentionRollout. The resulting
    attention mask is resized to the image and overlaid on it with show_mask_on_image.

    model: vision transformer model
    transform: transform applied to the image before it is input to the model
    img_path: path of the image file to visualize
    img_size: side length in pixels of the square image the heatmap is drawn on
    show_image: if true, plot the heatmap image with matplotlib
    return_both: if true, also return the resized original image
    head_fusion: how attention heads are fused, passed to VITAttentionRollout;
                 options are 'mean', 'max' and 'min'

    Returns (output, heatmap_img) or, if return_both, (output, heatmap_img, original_img).
    output is the model output produced by the rollout; callers here take argmax over dim 1
    as the predicted class. original_img is a uint8 RGB array of shape (img_size, img_size, 3);
    heatmap_img is the overlay converted to RGB for display.
    '''
    # Decode as 3-channel RGB so grayscale / RGBA / CMYK files work with the model transform
    # and the channel conversions below; the context manager closes the file handle.
    with Image.open(img_path) as raw:
        img = raw.convert('RGB')

    model_device = next(model.parameters()).device
    test = transform(img).unsqueeze(0).to(model_device)

    # Reuse one rollout helper per (model, head_fusion) instead of building one per image.
    # NOTE(review): assumes VITAttentionRollout registers forward hooks on the model in its
    # constructor without ever removing them, so a fresh instance per call would pile up hooks
    # (and their stored attention maps) with every image -- confirm against vit_rollout.py.
    key = (model, head_fusion)
    attention_rollout = _ATTENTION_ROLLOUTS.get(key)
    if attention_rollout is None:
        attention_rollout = VITAttentionRollout(model, head_fusion=head_fusion)
        _ATTENTION_ROLLOUTS[key] = attention_rollout
    output, mask = attention_rollout(test)

    np_img = np.array(img.resize((img_size, img_size)))  # uint8, RGB channel order
    mask = cv2.resize(mask, (np_img.shape[1], np_img.shape[0]))
    # NOTE(review): assumes show_mask_on_image (vit_explain.py) builds its heatmap with
    # cv2.applyColorMap, which is BGR; the image is therefore handed over in BGR as well, and
    # the blended BGR result is converted back to RGB for matplotlib.
    bgr_img = cv2.cvtColor(np_img, cv2.COLOR_RGB2BGR)
    result = show_mask_on_image(bgr_img, mask)
    rgb_img = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
    if show_image:
        plt.imshow(rgb_img)
        plt.show()
    if return_both:
        return output, rgb_img, np_img
    return output, rgb_img

In [None]:
# For dynamically adjusting the layout of the image grid
def largest_factor_pair(dim):
    '''
    Return the factor pair (a, b) of dim with a <= b and a as large as possible,
    i.e. the most nearly square a x b layout for dim items.

    dim: positive integer
    Returns a tuple of two ints, e.g. largest_factor_pair(50) == (5, 10).
    Raises ValueError if dim < 1.
    '''
    if dim < 1:
        raise ValueError(f'dim must be a positive integer, got {dim!r}')
    # Walk down from floor(sqrt(dim)); the first divisor met is the largest one <= sqrt(dim).
    # math.isqrt is exact for large ints, unlike int(math.sqrt(...)). 1 always divides, so
    # next() always finds a value.
    small = next(i for i in range(math.isqrt(dim), 0, -1) if dim % i == 0)
    # Floor division keeps the partner an int (true division would return a float).
    return small, dim // small


In [None]:
# When True, visualize one randomly chosen image from every art style; otherwise visualize
# every image of the styles in selected_styles.
RANDOM_CHOICES = True
# Art-style names are the sub-directories of IMGS_PATH, sorted.
# NOTE(review): predictions are mapped back through this sorted list, which assumes the
# model's class indices follow the same alphabetical order -- confirm against training.
styles = sorted(
    style for style in os.listdir(IMGS_PATH) if os.path.isdir(os.path.join(IMGS_PATH, style))
)

# Per-image results, kept index-aligned with each other
attentions = []       # attention heatmap overlays
originals = []        # resized original images
original_style = []   # true style (directory name)
attention_style = []  # style predicted by the model

# Styles visualized when RANDOM_CHOICES is False: up to two *distinct* random styles
# (two independent random.choice calls could pick the same style twice).
selected_styles = random.sample(styles, k=min(2, len(styles)))


def _style_images(style):
    '''Return the sorted paths of the .jpg files in the given style's directory.'''
    # Glob on the full path instead of os.chdir, so the notebook's working directory (where
    # the downloaded vit_* helper files live) is left unchanged. Sorting makes the order,
    # and therefore seeded random choices, reproducible.
    return sorted(glob.glob(os.path.join(IMGS_PATH, style, '*.jpg')))


def _record(img_path, style):
    '''Visualize one image and append its results to the per-image lists above.'''
    output, att, orig = visualize_attention(our_ViT, transform, img_path, return_both=True, head_fusion='max')
    attentions.append(att)
    originals.append(orig)
    original_style.append(style)
    predicted = torch.argmax(output, 1).cpu().item()
    attention_style.append(styles[predicted])


if RANDOM_CHOICES:
    for style in styles:
        images = _style_images(style)
        if not images:  # random.choice would raise IndexError on a style with no .jpg files
            continue
        _record(random.choice(images), style)
else:
    for style in styles:
        if style in selected_styles:
            for img_path in _style_images(style):
                _record(img_path, style)


In [None]:
# Plots a grid pairing each attention overlay (titled with the predicted style) with its
# original image (titled with the true style).
if originals:
    imgs = []
    labels = []
    for att, pred, orig, true in zip(attentions, attention_style, originals, original_style):
        imgs.extend((att, orig))
        labels.extend((pred, true))

    if RANDOM_CHOICES:
        # 5 image pairs (10 panels) per row. The row count grows with the number of panels,
        # so nothing is dropped; a fixed 5x10 grid only fits 25 styles.
        ncols = 10
        nrows_ncols = (math.ceil(len(imgs) / ncols), ncols)
        figsize = (35., 35.)
    else:
        # Most nearly square factorization of the panel count
        nrows, ncols = largest_factor_pair(len(imgs))
        nrows_ncols = (int(nrows), int(ncols))
        figsize = (25., 25.)

    fig = plt.figure(figsize=figsize)
    grid = ImageGrid(fig, 111, nrows_ncols=nrows_ncols, axes_pad=0.3)
    for ax, im, lab in zip(grid, imgs, labels):
        ax.imshow(im)
        ax.set_title(lab)
    plt.show()