In [1]:
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from timm.models import create_model
import models_mamba
import utils
import os
from xai_utils import *
from class_mapper import CLS2IDX
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


<h1>Load Model</h1>
Make sure to specify the model checkpoint path (`model_path` in the next cell).

In [2]:
model_type = 'vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2'
model_path = './vim_s_midclstok_80p5acc.pth'
num_classes = 1000
model = create_model(
    model_type,
    pretrained=False,
    num_classes=num_classes,
    drop_rate=0,
    drop_path_rate=0,
    drop_block_rate=None,
    img_size=224
)
# Explanations should come from the inference-mode network: eval() switches off
# any training-only behaviour, and the mode persists through load_state_dict below.
model.eval()

# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location='cpu')
# Training checkpoints wrap the weights under 'model'; also accept a bare state_dict.
state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint
model.load_state_dict(state_dict)

<All keys matched successfully>

<h1>Auxiliary Functions</h1>

In [3]:
from PIL import Image
import torchvision.transforms as transforms

IMAGENET_DEFAULT_MEAN = [0.485, 0.456, 0.406]
IMAGENET_DEFAULT_STD = [0.229, 0.224, 0.225]

def transform_for_eval(image_path, input_size=224, crop_pct=1.0):
    """Load an image from disk and preprocess it for ImageNet-style evaluation.

    The image is converted to RGB, its shorter side is resized to
    ``int(input_size / crop_pct)``, it is center-cropped to
    ``input_size`` x ``input_size``, converted to a tensor and normalized with
    the ImageNet mean/std constants above.

    Args:
        image_path: path of the image file to load.
        input_size: side length (pixels) of the square crop fed to the model.
        crop_pct: fraction of the resized image kept by the center crop, in
            (0, 1]. The default 1.0 resizes straight to ``input_size``;
            0.875 gives the common 256 -> 224 evaluation recipe.

    Returns:
        Normalized float tensor of shape (3, input_size, input_size).

    Raises:
        ValueError: if ``crop_pct`` is not in (0, 1].
    """
    if not 0 < crop_pct <= 1:
        raise ValueError('crop_pct must be in (0, 1], got {}'.format(crop_pct))
    transform_eval = transforms.Compose([
        transforms.Resize(int(input_size / crop_pct)),
        transforms.CenterCrop(input_size),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD),
    ])
    # Image.open is lazy; the context manager guarantees the file handle is
    # released once the RGB copy has been made instead of relying on GC.
    with Image.open(image_path) as img:
        rgb_img = img.convert('RGB')
    return transform_eval(rgb_img)

import cv2

# Inverse of the Normalize step in transform_for_eval: x_orig = x_norm * std + mean.
# Derived from the same constants so the forward and inverse transforms stay in sync.
invTrans = transforms.Compose([
    # Normalize computes (x - mean) / std, so mean=0, std=1/s multiplies by s ...
    transforms.Normalize(mean=[0.] * len(IMAGENET_DEFAULT_STD),
                         std=[1 / s for s in IMAGENET_DEFAULT_STD]),
    # ... and mean=-m, std=1 then adds m back.
    transforms.Normalize(mean=[-m for m in IMAGENET_DEFAULT_MEAN],
                         std=[1.] * len(IMAGENET_DEFAULT_MEAN)),
])

def show_cam_on_image(img, mask):
    """Blend a JET-colormapped saliency mask onto an image.

    Args:
        img: image array broadcast-compatible with an H x W x 3 heatmap,
            presumably scaled to [0, 1]. The heatmap from ``cv2.applyColorMap``
            is in OpenCV's BGR channel order, so ``img`` should be BGR as well
            for the layers to line up channel-wise.
        mask: H x W array, assumed to lie in [0, 1]; it is scaled by 255 and
            cast to uint8 before colormapping.

    Returns:
        float32 array: heatmap + image, rescaled so its maximum value is 1.
    """
    mask_u8 = np.uint8(255 * mask)
    heat = np.float32(cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)) / 255
    overlay = np.float32(img) + heat
    peak = np.max(overlay)
    return overlay / peak


def generate_visualization(original_image, transformer_attribution):
    """Overlay a patch-level attribution map on an image for display.

    Args:
        original_image: (3, H, W) tensor in RGB order (e.g. the output of
            ``invTrans``). It is min-max normalised here, so its absolute
            range does not matter.
        transformer_attribution: tensor holding one score per patch of a
            square ``side x side`` grid (e.g. 14 x 14 = 196 values for a
            224-px input with 16-px patches). May live on any device.

    Returns:
        (H, W, 3) uint8 RGB array suitable for ``plt.imshow``.

    Raises:
        ValueError: if the number of attribution scores is not a perfect square.
    """
    def _min_max(arr):
        # Rescale to [0, 1]; a constant array maps to zeros rather than 0/0 = NaN.
        lo, hi = arr.min(), arr.max()
        span = hi - lo
        return (arr - lo) / span if span > 0 else np.zeros_like(arr)

    height, width = original_image.shape[-2:]

    scores = transformer_attribution.detach().float().reshape(-1)
    side = int(round(scores.numel() ** 0.5))
    if side * side != scores.numel():
        raise ValueError(
            'expected a square grid of patch scores, got {} values'.format(scores.numel()))
    # Upsample the patch grid to pixel resolution. For a 14x14 grid and a 224x224
    # image this is the same bilinear resize as scale_factor=16.
    scores = torch.nn.functional.interpolate(
        scores.reshape(1, 1, side, side), size=(height, width), mode='bilinear')
    # Bring the result to host memory directly (no CUDA round-trip), so this
    # also works on CPU-only machines.
    mask = _min_max(scores.reshape(height, width).cpu().numpy())

    image_rgb = _min_max(original_image.detach().permute(1, 2, 0).cpu().numpy())
    # show_cam_on_image adds a cv2 JET heatmap, which is BGR. Give it a BGR image
    # too so both layers share channel order, then convert the blend to RGB for
    # matplotlib; otherwise the photo ends up with red and blue swapped.
    image_bgr = np.ascontiguousarray(image_rgb[..., ::-1])
    vis = show_cam_on_image(image_bgr, mask)
    vis = np.uint8(255 * vis)
    return cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)

def print_preds(logits, topk=5, idx_to_label=None):
    """Print the top-k predicted classes for the first sample in a batch.

    Each line shows the class index, its name (padded so the columns align),
    the raw logit and the softmax probability in percent.

    Args:
        logits: (batch, num_classes) tensor of raw scores; only row 0 is reported.
        topk: number of classes to list; clamped to ``num_classes``.
        idx_to_label: mapping from class index to display name. Defaults to
            ``CLS2IDX`` from ``class_mapper``.
    """
    labels = CLS2IDX if idx_to_label is None else idx_to_label
    scores = logits.detach()
    k = min(topk, scores.shape[1])
    prob = torch.softmax(scores, dim=1)
    class_indices = scores.topk(k, dim=1)[1][0].tolist()
    # Pad every name to the longest one so the value/prob columns line up.
    max_str_len = max((len(labels[idx]) for idx in class_indices), default=0)

    print('Top {} classes:'.format(k))
    for cls_idx in class_indices:
        name = labels[cls_idx]
        padding = ' ' * (max_str_len - len(name))
        print('\t{} : {}{}\t\tvalue = {:.3f}\t prob = {:.1f}%'.format(
            cls_idx, name, padding,
            scores[0, cls_idx].item(), 100 * prob[0, cls_idx].item()))

In [4]:
# Move every parameter and buffer to the current CUDA device; the cells below
# also place their inputs on the GPU with .cuda().
model = model.to('cuda')

In [5]:
image_path = './images/1.jpg'
image = transform_for_eval(image_path).unsqueeze(0).cuda()  # (1, 3, 224, 224) batch on GPU
raw_image = Image.open(image_path).convert('RGB')            # untransformed input, for reference

# Three explanation maps for the same input; the logits of the first call feed print_preds.
map_raw_atten, logits = generate_raw_attn(model, image)
map_mamba_attr, _ = generate_mamba_attr(model, image)
map_rollout, _ = generate_rollout(model, image)

# Undo the normalisation once and reuse it for every overlay.
image_denorm = invTrans(image.squeeze(0)).detach().cpu()
raw_attn = generate_visualization(image_denorm, map_raw_atten)
rollout = generate_visualization(image_denorm, map_rollout)
mamba_attr = generate_visualization(image_denorm, map_mamba_attr)

print_preds(logits)

panels = [('Input', raw_image), ('Raw attention', raw_attn),
          ('Attention rollout', rollout), ('Mamba attribution', mamba_attr)]
fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()


AttributeError: 'JpegImageFile' object has no attribute 'permute'

In [None]:
image_path = './images/2.jpg'
image = transform_for_eval(image_path).unsqueeze(0).cuda()  # (1, 3, 224, 224) batch on GPU
raw_image = Image.open(image_path).convert('RGB')            # untransformed input, for reference

# Three explanation maps for the same input; the logits of the first call feed print_preds.
map_raw_atten, logits = generate_raw_attn(model, image)
map_mamba_attr, _ = generate_mamba_attr(model, image)
map_rollout, _ = generate_rollout(model, image)

# Undo the normalisation once and reuse it for every overlay.
image_denorm = invTrans(image.squeeze(0)).detach().cpu()
raw_attn = generate_visualization(image_denorm, map_raw_atten)
rollout = generate_visualization(image_denorm, map_rollout)
mamba_attr = generate_visualization(image_denorm, map_mamba_attr)

print_preds(logits)

panels = [('Input', raw_image), ('Raw attention', raw_attn),
          ('Attention rollout', rollout), ('Mamba attribution', mamba_attr)]
fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()


In [None]:
image_path = './images/3.jpg'
image = transform_for_eval(image_path).unsqueeze(0).cuda()  # (1, 3, 224, 224) batch on GPU
raw_image = Image.open(image_path).convert('RGB')            # untransformed input, for reference

# Three explanation maps for the same input; the logits of the first call feed print_preds.
map_raw_atten, logits = generate_raw_attn(model, image)
map_mamba_attr, _ = generate_mamba_attr(model, image)
map_rollout, _ = generate_rollout(model, image)

# Undo the normalisation once and reuse it for every overlay.
image_denorm = invTrans(image.squeeze(0)).detach().cpu()
raw_attn = generate_visualization(image_denorm, map_raw_atten)
rollout = generate_visualization(image_denorm, map_rollout)
mamba_attr = generate_visualization(image_denorm, map_mamba_attr)

print_preds(logits)

panels = [('Input', raw_image), ('Raw attention', raw_attn),
          ('Attention rollout', rollout), ('Mamba attribution', mamba_attr)]
fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()


In [None]:
image_path = './images/4.jpg'
image = transform_for_eval(image_path).unsqueeze(0).cuda()  # (1, 3, 224, 224) batch on GPU
raw_image = Image.open(image_path).convert('RGB')            # untransformed input, for reference

# Three explanation maps for the same input; the logits of the first call feed print_preds.
map_raw_atten, logits = generate_raw_attn(model, image)
map_mamba_attr, _ = generate_mamba_attr(model, image)
map_rollout, _ = generate_rollout(model, image)

# Undo the normalisation once and reuse it for every overlay.
image_denorm = invTrans(image.squeeze(0)).detach().cpu()
raw_attn = generate_visualization(image_denorm, map_raw_atten)
rollout = generate_visualization(image_denorm, map_rollout)
mamba_attr = generate_visualization(image_denorm, map_mamba_attr)

print_preds(logits)

panels = [('Input', raw_image), ('Raw attention', raw_attn),
          ('Attention rollout', rollout), ('Mamba attribution', mamba_attr)]
fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()


In [None]:
image_path = './images/5.jpg'
image = transform_for_eval(image_path).unsqueeze(0).cuda()  # (1, 3, 224, 224) batch on GPU
raw_image = Image.open(image_path).convert('RGB')            # untransformed input, for reference

# Three explanation maps for the same input; the logits of the first call feed print_preds.
map_raw_atten, logits = generate_raw_attn(model, image)
map_mamba_attr, _ = generate_mamba_attr(model, image)
map_rollout, _ = generate_rollout(model, image)

# Undo the normalisation once and reuse it for every overlay.
image_denorm = invTrans(image.squeeze(0)).detach().cpu()
raw_attn = generate_visualization(image_denorm, map_raw_atten)
rollout = generate_visualization(image_denorm, map_rollout)
mamba_attr = generate_visualization(image_denorm, map_mamba_attr)

print_preds(logits)

panels = [('Input', raw_image), ('Raw attention', raw_attn),
          ('Attention rollout', rollout), ('Mamba attribution', mamba_attr)]
fig, axs = plt.subplots(1, len(panels), figsize=(16, 4))
for ax, (title, panel) in zip(axs, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')
plt.show()
