In [None]:
import os
from PIL import Image
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import torch
import torch.nn as nn 
import cv2
import numpy as np

import vision_transformer as vits
from explanation_generator import Baselines

# Select GPU 0 for all subsequent CUDA work.
# torch.cuda.set_device() returns None, so the device handle is built
# explicitly rather than taken from that call's return value.
device = torch.device("cuda", 0)
torch.cuda.set_device(device)
print('GPU:', torch.cuda.get_device_name(device))

In [None]:
# parameter setting
n_last_blocks = 4  # number of intermediate backbone outputs whose first token is concatenated for the classifier
pretrained_weights = "../SNNA/ckp/backbone_200.pth" # pretrained weights for backbone
checkpoint_key = "teacher"  # entry of the checkpoint dict that holds the backbone weights
arch = "vit_small"  # name of the constructor looked up in `vision_transformer`
patch_size = 8
num_labels = 4
classifier_weights_dir = "../SNNA/ckp" # pretrained weights for linear classifier
image_size = 360 # The image short side is resized to 360

# construct backbone model
# The constructor is selected through `arch` so that the architecture built
# here always matches the one reported in the final message below.
backbone = vits.__dict__[arch](patch_size=patch_size, num_classes=0)
# the classifier input is `n_last_blocks` feature vectors of size backbone.embed_dim
embed_dim = backbone.embed_dim * n_last_blocks

# load backbone weights to evaluate
# NOTE(security): torch.load unpickles the file; only open trusted checkpoints.
state_dict = torch.load(pretrained_weights, map_location="cpu")
if checkpoint_key is not None and checkpoint_key in state_dict:
    print(f"Take key {checkpoint_key} in provided checkpoint dict")
    state_dict = state_dict[checkpoint_key]
# remove `module.` prefix
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
# remove `backbone.` prefix induced by multicrop wrapper
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
# load state dict for backbone
# strict=False tolerates key mismatches; they are reported through `msg` below
msg = backbone.load_state_dict(state_dict, strict=False)
print('Pretrained weights found at {} and loaded with msg: {}'.format(pretrained_weights, msg))
print(f"Model {arch} built.")

# construct classifier
class LinearClassifier(nn.Module):
    """Single fully connected layer applied to flattened input features.

    Args:
        dim: number of input features per sample once flattened.
        num_labels: number of output logits (defaults to 1000).
    """

    def __init__(self, dim, num_labels=1000):
        super().__init__()
        self.num_labels = num_labels
        self.linear = nn.Linear(dim, num_labels)
        # Gaussian weights with a small standard deviation, bias at zero.
        nn.init.normal_(self.linear.weight, mean=0.0, std=0.01)
        nn.init.zeros_(self.linear.bias)

    def forward(self, x):
        """Flatten every sample of ``x`` to one vector and return its logits."""
        flat = x.view(x.shape[0], -1)
        return self.linear(flat)

# Build the classification head that consumes the concatenated backbone features.
linear_classifier = LinearClassifier(embed_dim, num_labels)

# Read the head's checkpoint on the CPU and keep only its "state_dict" entry.
state_dict = torch.load(
    os.path.join(classifier_weights_dir, "classifier.pth.tar"),
    map_location="cpu",
)["state_dict"]
# Drop every `module.` occurrence from the parameter names.
state_dict = {name.replace("module.", ""): tensor for name, tensor in state_dict.items()}
# strict=True: every key must match the head exactly, otherwise loading raises.
linear_classifier.load_state_dict(state_dict, strict=True)

# construct the model with the backbone and the linear classifier
class Model(nn.Module):
    """Backbone followed by a classification head, applied to one image at a time.

    Args:
        backbone: module exposing
            ``get_intermediate_layers(x, n, register_hook=...)``.
        head: module mapping the concatenated first-token features to logits.
        n_last_blocks: how many intermediate outputs to request from the
            backbone. ``None`` (default) keeps the original behaviour of
            reading the module-level ``n_last_blocks`` at call time.
    """

    def __init__(self, backbone, head, n_last_blocks=None):
        super(Model, self).__init__()
        self.backbone = backbone
        self.head = head
        self.n_last_blocks = n_last_blocks

    def forward(self, x, register_hook=None):
        """Return the logits for a single unbatched image tensor.

        Args:
            x: image tensor without a batch dimension; one is added here.
            register_hook: forwarded unchanged to the backbone.

        Returns:
            Output of ``head`` for the batch of one.
        """
        x = x.unsqueeze(0)  # (1, 3, w, h) add a batch dimension
        # Follow the module's own device instead of hard-coding CUDA, so the
        # same code works whether the model lives on a GPU or on the CPU.
        # A model without parameters leaves the input where it is.
        reference = next(self.parameters(), None)
        if reference is not None:
            x = x.to(reference.device)

        # `n_last_blocks` below is the module-level fallback, looked up at call time.
        n_blocks = n_last_blocks if self.n_last_blocks is None else self.n_last_blocks
        intermediate_output = self.backbone.get_intermediate_layers(x, n_blocks, register_hook=register_hook)
        # keep token 0 of every returned layer and join them along the feature axis
        output = torch.cat([tokens[:, 0] for tokens in intermediate_output], dim=-1)
        logits = self.head(output)
        return logits
    
# Assemble backbone + head, move the parameters to the GPU, then switch to
# evaluation mode.
model = Model(backbone, linear_classifier).cuda()
model.eval()

In [None]:
# load the image
# Force three channels: the transform applied later normalises with 3-channel
# statistics, which fails for grayscale, palette or RGBA files.
input_image = Image.open('../SNNA/data/1.jpg').convert('RGB')
plt.figure(figsize=(10, 10))
plt.axis('off')
plt.imshow(input_image)

In [None]:
# ============ preparing data ... ============
# Resize (interpolation code 3 -- presumably PIL's bicubic filter, confirm),
# convert to a tensor, then normalise each channel with the given mean / std.
transform = transforms.Compose(
    [
        transforms.Resize(image_size, interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ]
)

image = transform(input_image)
# Crop both spatial axes down to the largest multiple of the patch size.
# NOTE: `w` is measured along tensor axis 1 and `h` along tensor axis 2.
w = (image.shape[1] // patch_size) * patch_size
h = (image.shape[2] // patch_size) * patch_size
img = image[:, :w, :h]
print(f"Image shape: {img.shape}")


# ============ forward ... ============
# This pass only reports class probabilities, so gradient tracking is switched
# off to avoid building an autograd graph that is never used.
with torch.no_grad():
    output = model(img)
print(torch.nn.functional.softmax(output, dim=-1))

# number of patches along each of the two spatial axes of the cropped image
w_featmap = img.shape[-2] // patch_size
h_featmap = img.shape[-1] // patch_size

def show_cam_on_image(img, mask):
    """Add a JET-coloured rendering of ``mask`` to ``img`` and rescale the sum.

    Args:
        img: image array; cast to float32 and added to the colour map.
        mask: array that is multiplied by 255 and cast to uint8 before being
            colour-mapped -- presumably expected to lie in [0, 1].

    Returns:
        float32 array, min-max normalised to [0, 1].
    """
    # turn the grayscale mask into a colour map
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    # min-max normalisation (rather than dividing by the maximum only)
    lowest, highest = np.min(overlay), np.max(overlay)
    return (overlay - lowest) / (highest - lowest)

# ============ generate attribution maps ... ============
# A single image is explained at a time, so maps are reshaped with a batch of 1.
batch_size = 1
# Explanation backend; all methods are compared side by side in one row.
attribution_generator = Baselines(model)

def generate_attribution(image, class_index=None):
    """Run every attribution method on ``image`` and return the detached maps.

    Args:
        image: input handed unchanged to each method of ``attribution_generator``.
        class_index: forwarded as ``index`` to every method except ``rawAttn``.

    Returns:
        Tuple ``(rawAttn, att_gradient, generic_att, norm_att, IGradient, SNNA)``.
    """
    explainer = attribution_generator
    # One thunk per method, in the order of the returned tuple; each result is
    # detached right after it is produced, before the next method runs.
    recipes = (
        lambda: explainer.rawAttn(image),
        lambda: explainer.att_gradient(image, index=class_index),
        lambda: explainer.generic_att(image, index=class_index),
        lambda: explainer.norm_att(image, index=class_index),
        lambda: explainer.IGradient(image, index=class_index, steps=20),
        lambda: explainer.SGradient_normAtt(image, index=class_index, magnitude=False),
    )
    return tuple(run().detach() for run in recipes)


def generate_visualization(attribution, image):
    """Upsample a patch-level attribution map and overlay it on ``image``.

    Relies on the notebook globals ``batch_size``, ``w_featmap``, ``h_featmap``,
    ``patch_size``, ``w`` and ``h``.

    Args:
        attribution: tensor that can be reshaped to
            ``(batch_size, 1, w_featmap, h_featmap)``.
        image: channel-last float array of shape ``(w, h, 3)`` -- assumed to be
            RGB, as produced from a PIL image.

    Returns:
        uint8 array of shape ``(w, h, 3)`` in RGB order, suitable for
        matplotlib's ``imshow``.
    """
    attribution = attribution.reshape(batch_size, 1, w_featmap, h_featmap)
    # bring the patch grid back to pixel resolution
    attribution = torch.nn.functional.interpolate(attribution, scale_factor=patch_size, mode='bilinear')
    attribution = attribution.reshape(w, h).detach().cpu().numpy()
    # min-max normalise to [0, 1]; a constant map becomes all zeros rather than
    # NaN from a division by zero
    span = attribution.max() - attribution.min()
    if span > 0:
        attribution = (attribution - attribution.min()) / span
    else:
        attribution = np.zeros_like(attribution)

    # OpenCV colour maps are BGR. Give the overlay a BGR image so the heat map
    # and the photo share one channel order; previously the RGB photo was mixed
    # with the BGR heat map and displayed with red and blue swapped.
    image_bgr = np.ascontiguousarray(image[..., ::-1])
    vis = show_cam_on_image(image_bgr, attribution)
    vis = np.uint8(255 * vis)
    # convert the all-BGR overlay to RGB for display
    vis = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)
    return vis

# Attribution maps of every method for the image (default class_index).
rawAttn, att_gradient, generic_att, norm_att, IGradient, SNNA = generate_attribution(img)

# Channel-last NumPy copy of the image, rescaled to [0, 1] for display.
# NOTE: `img` is rebound from the tensor to this array.
img = img.permute(1, 2, 0).data.cpu().numpy() # (3, w, h)->(w, h, 3)
img = (img - img.min()) / (img.max() - img.min())

# Overlay every map on the image, in the same order as above.
# NOTE: `SNNA` is rebound from the raw map to its visualisation.
(
    rawAttn_vis,
    att_gradient_vis,
    generic_att_vis,
    norm_att_vis,
    IGradient_vis,
    SNNA,
) = [
    generate_visualization(attribution_map, img)
    for attribution_map in (rawAttn, att_gradient, generic_att, norm_att, IGradient, SNNA)
]

vis_list = [rawAttn_vis, att_gradient_vis, generic_att_vis, norm_att_vis, IGradient_vis, SNNA]

# plot baseline figure: one panel per attribution method, left to right
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(102, 10))
# remove white space between subplots
fig.subplots_adjust(wspace=0.01, hspace=0)
for ax, vis in zip(axs, vis_list):
    ax.axis('off')
    ax.imshow(vis)
# plt.show()
# To save the figure instead:
# plt.savefig('../SNNA/output/campare.jpg', bbox_inches='tight', pad_inches=0.0)