# Sanity Check (Cascading Randomization) for Saliency Maps: ResNet-18 ImageNet Example

In [1]:
%%capture
# do not display output on this cell
from torchvision import transforms, models
import torch
import numpy as np
import matplotlib.pyplot as plt
from src import util
from captum.attr import IntegratedGradients, Saliency, InputXGradient, GuidedBackprop
from src import util
import PIL
from ntpath import basename
import os

# Pretrained ResNet-18 put straight into inference mode; Module.eval() returns
# the module itself, so the call can be chained onto the constructor.
resnet = models.resnet18(pretrained=True).eval()

In [2]:
# def show_image(file_path, resize=True, sztple=(299, 299)):
#     img = PIL.Image.open(file_path).convert('RGB')
#     if resize:
#         img = img.resize(sztple, PIL.Image.ANTIALIAS)
#     img = np.asarray(img)
#     return img

def preprocess_image(file_path):
    """Load an image file and turn it into a normalized, batched model input.

    The image is converted to RGB, resized so its shorter side is 256 px,
    center-cropped to 224x224, converted to a tensor and normalized with the
    standard ImageNet per-channel mean/std.

    Args:
        file_path: path to an image file readable by PIL.

    Returns:
        A float tensor of shape (1, 3, 224, 224).
    """
    pipeline = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    rgb = PIL.Image.open(file_path).convert('RGB')
    # Add the leading batch dimension expected by the model.
    return transforms.Compose(pipeline)(rgb).unsqueeze(0)

def get_image_and_label(file_path, image_net_cls, normalize=False):
    """Load an image and look up its class index from the file name.

    File names are expected to look like ``<prefix>_<class name>.<ext>``
    (e.g. ``n01440764_tench.JPEG``); underscores in the class part are turned
    into spaces before the lookup.

    Args:
        file_path: path to the image file.
        image_net_cls: mapping from class name to class index.
        normalize: if True, also apply mean/std normalization (the input the
            model expects); if False, return the un-normalized tensor in
            [0, 1], suitable for display.

    Returns:
        ``(img, label)`` where ``img`` is a (3, 224, 224) float tensor and
        ``label`` is ``image_net_cls[class_name]``.

    Raises:
        KeyError: if the parsed class name is not a key of ``image_net_cls``.
    """
    # Use the file as a context manager so the handle is released; convert()
    # returns an independent image, so it stays valid after the block.
    with PIL.Image.open(file_path) as raw:
        img = raw.convert('RGB')

    # Strip the extension at the *last* dot, then drop everything up to the
    # first underscore. Unlike slicing to the first '.', this keeps class names
    # that contain a dot intact and does not chop off a character when the
    # file has no extension.
    stem = os.path.splitext(basename(file_path))[0]
    cls_name = stem.split('_', 1)[-1].replace('_', ' ')
    label = image_net_cls[cls_name]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]
    if normalize:
        steps.append(
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        )
    return transforms.Compose(steps)(img), label


class MyDataset(torch.utils.data.Dataset):
    """Normalized images from a directory, labelled from their file names.

    Items are the ``(img, label)`` pairs returned by ``get_image_and_label``
    with ``normalize=True``. Only files ending in ``.JPEG`` are used, in sorted
    file-name order, so a given index always refers to the same file.

    Args:
        dir_path: directory containing the images.
        classes_path: text file with one class name per line; the 0-based line
            number is the class index. Defaults to ``"imagenet_classes.txt"``
            relative to the current working directory.
    """

    def __init__(self, dir_path, classes_path="imagenet_classes.txt"):
        # Zero-argument super() binds to this instance; the previous
        # `super(MyDataset)` form created an unbound super object and never
        # reached the base-class initializer.
        super().__init__()
        self.dir_path = dir_path
        self.files = [name for name in sorted(os.listdir(dir_path)) if name.endswith(".JPEG")]
        with open(classes_path, "r", encoding="utf-8") as f:
            # NOTE(review): if the class file repeats a name, the last index wins.
            self.image_net_cls = {line.strip(): idx for idx, line in enumerate(f)}

    def __getitem__(self, idx):
        file_name = self.files[idx]
        return get_image_and_label(os.path.join(self.dir_path, file_name), self.image_net_cls, normalize=True)

    def __len__(self):
        return len(self.files)


class MyOriginalImages(MyDataset):
    """Same files and labels as ``MyDataset``, but without normalization.

    Items are ``(img, label)`` pairs from ``get_image_and_label`` with
    ``normalize=False`` (plain tensors in [0, 1]). They are used as the
    un-normalized originals shown next to the saliency maps. The constructor
    is inherited unchanged, so any constructor arguments of the base class
    are accepted here too.
    """

    def __getitem__(self, idx):
        file_name = self.files[idx]
        return get_image_and_label(os.path.join(self.dir_path, file_name), self.image_net_cls, normalize=False)


# def show_image(im, title='', ax=None):
#     if ax is None:
#         plt.figure()
#     plt.axis('off')
#     im = ((im + 1) * 127.5).astype(np.uint8)
#     plt.imshow(im)
#     plt.title(title)

# img, label = get_image_and_label('imagenet-sample-images/n01440764_tench.JPEG')
# plt.imshow(img.squeeze(0).permute(1, 2, 0).numpy())
# with open('imagenet-sample-images/n01440764_tench.JPEG', '')
# Reverse lookup (class index -> class name), passed to the util
# visualization helpers below; line number in the file is the index.
with open("imagenet_classes.txt", "r") as f:
    cls_index_to_name = dict(enumerate(line.strip() for line in f))

In [3]:
# Normalized images (model input) and un-normalized copies of the same
# directory; both list the files in the same sorted order.
dataset, originals = (
    MyDataset('imagenet-sample-images'),
    MyOriginalImages('imagenet-sample-images'),
)

# Every DataLoader here must use batch_size=1 and shuffle=False: the normalized
# and original loaders are zipped together later and must stay aligned one
# image at a time.
full_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

# Fixed example indices into the sorted file list:
# zebra, hammerhead, junco, basketball, sorrel
# (earlier choice: [4, 13, 23] -> hammerhead, junco, vulture)
samples = [340, 4, 13, 430, 339]
# samples = np.random.choice(len(dataset), 2, replace=False)
dataset_subset, originals_subset = (
    torch.utils.data.Subset(dataset, samples),
    torch.utils.data.Subset(originals, samples),
)
dataset_loader, originals_loader = (
    torch.utils.data.DataLoader(subset, batch_size=1, shuffle=False)
    for subset in (dataset_subset, originals_subset)
)

# img, label = get_image_and_label('imagenet-sample-images/n01440764_tench.JPEG', image_net_cls, normalize=True)
# with torch.no_grad():
#     output = resnet(img)
    
# probabilities = torch.nn.functional.softmax(output[0], dim=0)
# # Read the categories
# with open("imagenet_classes.txt", "r") as f:
#     categories = [s.strip() for s in f.readlines()]
# # Show top categories per image
# top5_prob, top5_catid = torch.topk(probabilities, 5)
# for i in range(top5_prob.size(0)):
#     print(categories[top5_catid[i]], top5_prob[i].item())
# for img, label in dataset_loader:
#     output = resnet(img)
#     print(output)
#     break

In [4]:
# Module paths (attribute names into `resnet`) for cascading randomization,
# listed from the output end back toward the input: the `fc` head, then each
# residual block of layer4..layer1 (block '1' before block '0'), then the
# stem's `bn1` and `conv1`.
module_paths = (
    [['fc']]
    + [[f'layer{stage}', str(block)] for stage in (4, 3, 2, 1) for block in (1, 0)]
    + [['bn1'], ['conv1']]
)

In [5]:
%matplotlib agg
# Visualize Input x Gradient (flag False, i.e. the non-smoothed variant)
# under cascading randomization.
os.makedirs("figures", exist_ok=True)  # savefig does not create missing directories
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, False), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet_inputxgradient_cascading_randomization.png", bbox_inches="tight")

In [None]:
%matplotlib agg
# Input x Gradient with the method flag set to True -- presumably the smoothed
# (SmoothGrad-style) variant, judging by the output file name; confirm in src/util.
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, True), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet_inputxgradient_smoothing_cascading_randomization.png", bbox_inches="tight")

In [7]:
%matplotlib agg
# Compare several saliency methods under cascading randomization for one example.
# Only the FIRST (image, original) pair is plotted: the loop breaks after one
# iteration, and the output file name is fixed, so removing the `break` would
# overwrite the same figure once per sample.
# Methods are (attribution class, flag) pairs; the display names in the second
# list suggest the flag selects the smoothed ("SmoothGrad"/"-SG") variant.
for (image, label), (original, _) in zip(dataset_loader, originals_loader):
    fig, _ = util.visualize_cascading_randomization2(
        resnet,
        module_paths,
        [(Saliency, False), (Saliency, True), (InputXGradient, False), (InputXGradient, True), (GuidedBackprop, False), (IntegratedGradients, False)],
        ['Gradient', 'SmoothGrad', 'Gradient ⊙ Input', 'Gradient ⊙ Input-SG', 'Guided Back-propagation', 'Integrated Gradients'],
        (image, label),
        original,
        viz_method="heat_map"
    )
    fig.savefig("figures/resnet_cascading_randomization.png", bbox_inches="tight")
    break

Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients




Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on Gui

In [11]:
# SSIM comparison of each method's saliency maps across the cascading
# randomization steps; per the printed output of `dic`, the result is a nested
# dict {method display name: {randomized layer: score}}.
dic = util.ssim_saliency_comparison(
    resnet,
    module_paths,
        [(Saliency, False), (Saliency, True), (InputXGradient, False), (InputXGradient, True), (GuidedBackprop, False), (IntegratedGradients, False)],
        ['Gradient', 'SmoothGrad', 'Gradient ⊙ Input', 'Gradient ⊙ Input-SG', 'Guided Back-propagation', 'Integrated Gradients'], # Integrated Gradients is by far the slowest of these methods
    dataset_loader
    )

In [12]:
# Display the SSIM results: {method display name: {randomized layer: score}}.
dic

{'Gradient': {'fc': 0.7007162118803472,
  'layer4_1': 0.5843336223542628,
  'layer4_0': 0.586383585026679,
  'layer3_1': 0.5460930649541866,
  'layer3_0': 0.5469156755935984,
  'layer2_1': 0.5016109011083099,
  'layer2_0': 0.4848122505316317,
  'layer1_1': 0.36857221626667397,
  'layer1_0': 0.2829490931925805,
  'bn1': 0.3849463556271995,
  'conv1': 0.378120144227316},
 'SmoothGrad': {'fc': 0.7751769180294782,
  'layer4_1': 0.6437133338100526,
  'layer4_0': 0.6017864010647115,
  'layer3_1': 0.6173261118545617,
  'layer3_0': 0.45944497687895786,
  'layer2_1': 0.4961996721413362,
  'layer2_0': 0.5191355010094192,
  'layer1_1': 0.5666334941430149,
  'layer1_0': 0.5414745766784212,
  'bn1': 0.5908051159945817,
  'conv1': 0.52257754807884},
 'Gradient ⊙ Input': {'fc': 0.8401561945275986,
  'layer4_1': 0.8152612303498616,
  'layer4_0': 0.7950904996460721,
  'layer3_1': 0.7925661351675649,
  'layer3_0': 0.7717537748259116,
  'layer2_1': 0.7707802096298083,
  'layer2_0': 0.7674233779167575,
  

In [13]:
%matplotlib agg
fig = plt.figure(figsize=(15, 6))
ax = fig.subplots()
# One dashed line per saliency method. The leading 'original' point is fixed
# at 1 (presumably the un-randomized map compared with itself); the remaining
# points follow the layer order of the SSIM results.
for method_name, per_layer in dic.items():
    ax.plot(
        ['original'] + list(per_layer.keys()), [1] + list(per_layer.values()),
        label=method_name,
        linestyle='dashed', linewidth=3.5, marker='o', markersize=10
    )

ax.set_xlabel('Randomized layer (cascading)')
ax.set_ylabel('SSIM')
ax.grid(True)
ax.legend()
# bbox_inches="tight", as for the other figures, so labels and legend are not clipped.
fig.savefig("figures/resnet_ssim_cascading_randomization_2.png", bbox_inches="tight")

In [None]:
# NOTE(review): the `dic` produced by util.ssim_saliency_comparison is keyed by
# method display name (see its printed output), so this tuple-key lookup raises
# KeyError. It presumably dates from an earlier version that stored per-sample
# saliency maps -- confirm what it should index before running this cell.
image = dic[(340, 0)]
plt.figure()
plt.axis('off')
plt.imshow(image, cmap="Reds")
print(image)
print(image.shape)
print(len(dic))
