In [None]:
import colorsys
import io
import os
import typing
from urllib.request import urlretrieve

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch

# NOTE: the star import stays ahead of the imports below on purpose, so the
# explicit names that follow keep taking precedence over anything it exports.
from utils.dataset import *
from PIL import Image
from torchvision import transforms

from models.modeling import VisionTransformer, CONFIGS
from torch.utils.data import DataLoader, RandomSampler, DistributedSampler, SequentialSampler
import torchvision

In [None]:
# Fetch the label list and the pretrained ViT weights once; later runs reuse
# the files cached under attention_data/.
os.makedirs("attention_data", exist_ok=True)
if not os.path.isfile("attention_data/ilsvrc2012_wordnet_lemmas.txt"):
    urlretrieve("https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt", "attention_data/ilsvrc2012_wordnet_lemmas.txt")
if not os.path.isfile("attention_data/ViT-B_16-224.npz"):
    urlretrieve("https://storage.googleapis.com/vit_models/imagenet21k+imagenet2012/ViT-B_16-224.npz", "attention_data/ViT-B_16-224.npz")

# Map line index -> label text. Each label keeps its trailing newline, which
# the top-5 printing cell relies on (it prints with end='').
# The context manager closes the file handle instead of leaking it.
with open('attention_data/ilsvrc2012_wordnet_lemmas.txt') as labels_file:
    imagenet_labels = dict(enumerate(labels_file))

In [None]:
# Test Image

# Prepare Model
config = CONFIGS["ViT-B_16"]
model = VisionTransformer(config, num_classes=1938, zero_head=False, img_size=448, vis=True)
# NOTE(review): absolute, machine-specific checkpoint path; adjust for your environment.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine. The attention cells below build CPU tensors (torch.eye, torch.zeros)
# and call .numpy() directly, so the model is expected to live on the CPU.
checkpoint = torch.load(
    "/home/deep/junwang/MTGCV_V/soygbl_base_linear_ns290400_2e-2_bs4_cos.log_checkpoint.bin",
    map_location="cpu",
)
model.load_state_dict(checkpoint)
model.eval()

# Preprocessing: resize to 600x600, center-crop to 448x448 (the img_size given
# to the model above), convert to a tensor and normalize each channel with
# mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.Resize((600, 600)),
    transforms.CenterCrop((448,448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

In [None]:
# NOTE(review): this cell rebinds `config`, `model` and `transform` from the
# previous cell, so whichever of the two cells runs last wins.
# NOTE(review): `make_model` has no explicit import in this notebook; presumably
# it is provided by `from utils.dataset import *` or another module - confirm.
config = CONFIGS["ViT-B_16"]
model = make_model(config, None, zero_head=True, num_classes=200, vis=True)
# NOTE(review): absolute, machine-specific checkpoint path; adjust for your environment.
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# machine; the attention cells below work with CPU tensors and .numpy().
checkpoint = torch.load(
    '/home/ubuntu/junwang/paper/aaai2021/inter_vit_jun/output/soybean200_InterViT_checkpoint.bin',
    map_location="cpu",
)
model.load_state_dict(checkpoint)
model.eval()

# Preprocessing: resize to 600x600, center-crop to 448x448, convert to a tensor
# and normalize with the per-channel mean/std below. The batch-visualization
# cell inverts exactly these statistics to recover a displayable image.
transform = transforms.Compose([
    transforms.Resize((600, 600)),
    transforms.CenterCrop((448,448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

In [None]:
data_name="soybean2000"
# Resolve the dataset class by name from the notebook namespace (it is expected
# to arrive via `from utils.dataset import *`). A plain lookup replaces eval(),
# so the string can only ever name an existing object, never run an expression.
dataset_cls = globals()[data_name]
dataset = dataset_cls(root='./data/'+data_name, is_train=False, transform=transform)
# Images are drawn in random order, one per batch. The name `train_sampler` is
# kept as-is even though it samples the is_train=False split.
train_sampler = RandomSampler(dataset)
test_loader = DataLoader(dataset,sampler=train_sampler,batch_size=1,num_workers=4,pin_memory=True)

In [None]:
# NOTE(review): `x` (a preprocessed image tensor) and `im` (the image the mask
# is resized to and multiplied with) are not assigned by any cell above this
# one in the notebook; they must already exist in the kernel for this cell to run.

# Inference only: skip autograd bookkeeping, matching the batch loop further down.
with torch.no_grad():
    logits, att_mat = model(x.unsqueeze(0))

# Stack the per-layer attention tensors and drop the batch dimension (batch of 1).
att_mat = torch.stack(att_mat).squeeze(1)

# Average the attention weights across all heads.
att_mat = torch.mean(att_mat, dim=1)

# To account for residual connections, we add an identity matrix to the
# attention matrix and re-normalize the weights.
# The helper tensors follow att_mat's device/dtype instead of assuming CPU float32.
residual_att = torch.eye(att_mat.size(1), device=att_mat.device, dtype=att_mat.dtype)
aug_att_mat = att_mat + residual_att
aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

# Recursively multiply the weight matrices
joint_attentions = torch.zeros_like(aug_att_mat)
joint_attentions[0] = aug_att_mat[0]

for n in range(1, aug_att_mat.size(0)):
    joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

# Attention from the output token to the input space.
v = joint_attentions[-1]
# The token axis holds grid_size**2 patch tokens plus one extra leading token,
# so truncating the square root recovers the patch-grid side length.
grid_size = int(np.sqrt(aug_att_mat.size(-1)))
# Row 0 without its first entry: attention of the first token over the patches.
mask = v[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
mask = cv2.resize(mask / mask.max(), im.size)[..., np.newaxis]
result = (mask * im).astype("uint8")

In [None]:
def generate_colors(num_colors):
    """
    Generate distinct colors by sampling evenly spaced hues in the HLS domain.

    Lightness is fixed at 0.5 and saturation at 0.9; only the hue varies,
    taking the values k / num_colors for k = 0 .. num_colors - 1.

    Parameters
    ----------
    num_colors: int
        Number of colors to generate. Non-positive values yield an empty
        result instead of raising.

    Returns
    ----------
    colors_np: np.array, [num_colors, 3]
        Numpy array with rows representing the RGB colors, scaled to [0, 255].

    """
    count = int(num_colors)
    if count <= 0:
        # Nothing to sample; keep the documented two-dimensional layout.
        return np.empty((0, 3))

    lightness = 0.5
    saturation = 0.9
    # Index-based hues guarantee exactly `count` samples. A float-step
    # np.arange over [0, 360) can produce one element too many through rounding.
    colors = [
        colorsys.hls_to_rgb(k / count, lightness, saturation)
        for k in range(count)
    ]
    colors_np = np.array(colors)*255.

    return colors_np

def show_att_on_image(img, mask, output):
    """
    Overlay an attention map on an image as a semi-transparent heatmap and save it.

    Parameters
    ----------
    img: PIL.Image.Image or np.array, [H, W, 3]
        Original colored image.
    mask: np.array, [H, W]
        Attention map. It is divided by its maximum inside this function
        before being drawn with the 'jet' colormap at 50% opacity.
    output: str
        Destination image (path) to save.

    Returns
    ----------
    None. The visualization is written to ``output``.

    """
    # Going through an array gives one (height, width) source of truth for both
    # PIL images (whose .size is (width, height)) and numpy arrays (whose .size
    # is a plain element count).
    img_arr = np.asarray(img)
    img_h, img_w = img_arr.shape[:2]

    # Matplotlib's figsize is (width, height) in inches.
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(0.02*img_w, 0.02*img_h))
    ax.axis('off')
    ax.imshow(img_arr, alpha=1)

    mask = np.asarray(mask)
    peak = mask.max()
    if peak > 0:
        normed_mask = mask / peak
    else:
        # An all-zero mask would otherwise divide by zero and yield NaNs.
        normed_mask = np.zeros(mask.shape, dtype=float)
    normed_mask = (normed_mask * 255).astype('uint8')
    ax.imshow(normed_mask, alpha=0.5, interpolation='nearest', cmap='jet')

    fig.savefig(output)
    # Close this specific figure so repeated calls in a loop do not pile up figures.
    plt.close(fig)

In [None]:
fig_rows = 1
fig_cols = 1
# NOTE(review): this figure is not used by any other statement in this cell;
# it is kept in case cells outside this view rely on f_assign / axarr_assign.
f_assign, axarr_assign = plt.subplots(fig_rows, fig_cols, figsize=(fig_cols*2,fig_rows*2))

# One output folder per sampled image: ./visualization/<data_name>/<i>/attention_map
root=os.path.join('./visualization',data_name)
with torch.no_grad():
    for i,data in enumerate(test_loader):
        # Only the first 500 sampled images are visualized.
        if i>=500:break
        dir_path_attn=os.path.join(root,str(i),'attention_map')
        os.makedirs(dir_path_attn, exist_ok=True)

        # The loader yields batches of one; drop the batch axis to get the image.
        x , label=data[0].squeeze(0),data[1]
        logits,  att_mat= model(x.unsqueeze(0))
        att_mat = torch.stack(att_mat).squeeze(1)

        ### visualize attention map
        # Average over heads, add the identity for the residual connection and
        # re-normalize each row. The helper tensors follow att_mat's
        # device/dtype instead of assuming CPU float32.
        att_mat = torch.mean(att_mat, dim=1)
        residual_att = torch.eye(att_mat.size(1), device=att_mat.device, dtype=att_mat.dtype)
        aug_att_mat = att_mat + residual_att
        aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1).unsqueeze(-1)

        # Recursively multiply the weight matrices
        joint_attentions = torch.zeros_like(aug_att_mat)
        joint_attentions[0] = aug_att_mat[0]

        for n in range(1, aug_att_mat.size(0)):
            joint_attentions[n] = torch.matmul(aug_att_mat[n], joint_attentions[n-1])

        # Attention from the output token to the input space.
        v = joint_attentions[-1]

        grid_size = int(np.sqrt(aug_att_mat.size(-1)))
        mask = v[0, 1:].reshape(grid_size, grid_size).detach().cpu().numpy()
        # cv2.resize already returns a 2-D (448, 448) array, so no extra axis
        # needs to be added and sliced off again.
        mask = cv2.resize(mask / mask.max(), (448,448))

        # Undo the Normalize step of `transform` in two stages:
        # first multiply by the std (dividing by 1/std), then add the mean back.
        save_input = transforms.Normalize(mean=(0, 0, 0),std=(1/0.229, 1/0.224, 1/0.225))(x.data.cpu())
        save_input = transforms.Normalize(mean=(-0.485, -0.456, -0.406),std=(1, 1, 1))(save_input)
        save_input = torch.nn.functional.interpolate(save_input.unsqueeze(0), size=(448, 448), mode='bilinear', align_corners=False).squeeze(0)
        img = torchvision.transforms.ToPILImage()(save_input)

        show_att_on_image(img, mask, os.path.join(dir_path_attn,'heatmap.png'))

        img.save(os.path.join(dir_path_attn, "input.png"))

In [None]:
# Two panels side by side: the input image on the left, the image weighted by
# the attention mask on the right. `im` and `result` are expected to exist
# from the single-image attention cell.
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))

ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(im)
_ = ax2.imshow(result)

# Class probabilities, and class indices ordered from most to least likely.
probs = torch.softmax(logits, dim=-1)
top5 = probs.argsort(dim=-1, descending=True)
print("Prediction Label and Attention Map!\n")
# NOTE(review): the names come from the ImageNet lemma file, while the models
# loaded above are built with 200 / 1938 classes - confirm that these labels
# actually correspond to the model's output indices.
for idx in top5[0, :5]:
    class_id = idx.item()
    # Each label string already ends with a newline, hence end=''.
    print(f'{probs[0, class_id]:.5f} : {imagenet_labels[class_id]}', end='')

### References
* [attention_flow](https://github.com/samiraabnar/attention_flow)
* [vit-keras](https://github.com/faustomorales/vit-keras)

In [None]:
# One figure per layer: entry i of joint_attentions is the attention accumulated
# up to and including layer i. `joint_attentions`, `grid_size` and `im` are
# expected to exist from earlier cells.
for i, v in enumerate(joint_attentions):
    # Attention from the output token to the input space.
    cls_to_patches = v[0, 1:]
    mask = cls_to_patches.reshape(grid_size, grid_size).detach().numpy()
    # Scale so the strongest patch is 1, then upsample to the image size and
    # add a trailing axis so the mask broadcasts over the color channels.
    mask = mask / mask.max()
    mask = cv2.resize(mask, im.size)[..., np.newaxis]
    result = (mask * im).astype("uint8")

    fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(16, 16))
    ax1.set_title('Original')
    _ = ax1.imshow(im)
    ax2.set_title('Attention Map_%d Layer' % (i+1))
    _ = ax2.imshow(result)