# Sanity Check (Cascading Randomization) for Saliency Maps: ResNet-18 ImageNet Example

In [1]:
%%capture
from torchvision import transforms, models
import torch
import numpy as np
import matplotlib.pyplot as plt
from src import util
from captum.attr import IntegratedGradients, Saliency, InputXGradient, GuidedBackprop
from src import util
import PIL
from ntpath import basename
import os
import random

# Pretrained ResNet-18; eval() switches BatchNorm/Dropout to inference behaviour
# (and returns the module itself, so it can be chained).
resnet = models.resnet18(pretrained=True).eval()

In [2]:
def preprocess_image(file_path):
    """Load an image file and turn it into a normalized ResNet input batch.

    The image is converted to RGB, resized (torchvision ``Resize(256)``),
    center-cropped to 224x224, converted to a tensor and normalized with the
    ImageNet channel mean/std.

    Args:
        file_path: path to the image file.

    Returns:
        The preprocessed image tensor with a leading batch dimension of size 1.
    """
    # A bare ``import PIL`` does not load the ``PIL.Image`` submodule; the
    # module-level ``PIL.Image`` only resolves because another import
    # (presumably torchvision) loaded it first. Import it explicitly.
    from PIL import Image

    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Context manager releases the file handle; convert() returns a loaded copy.
    with Image.open(file_path) as src:
        img = src.convert('RGB')
    return preprocess(img).unsqueeze(0)

def get_image_and_label(file_path, image_net_cls, normalize=False):
    img = PIL.Image.open(file_path).convert('RGB')
    img_name = basename(file_path)
    cls_name = img_name[img_name.find('_') + 1:img_name.find('.')].replace('_', ' ')
    label = image_net_cls[cls_name]
    if not normalize:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        img = preprocess(img)
        return img, label
    else:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        img = preprocess(img)
        return img, label


class MyDataset(torch.utils.data.Dataset):
    """ImageNet sample images from a directory, returned normalized for the model.

    Items are ``(image_tensor, label)`` pairs produced by ``get_image_and_label``
    with ``normalize=True``. Files are the ``.JPEG`` entries of ``dir_path`` in
    sorted order, so an index always refers to the same image.
    """

    def __init__(self, dir_path, classes_path="imagenet_classes.txt"):
        """
        Args:
            dir_path: directory containing the ``.JPEG`` sample images.
            classes_path: text file with one class name per line; the 0-based
                line number is used as the label.
        """
        # Zero-argument super(); the previous ``super(MyDataset).__init__()``
        # created an unbound super object and never reached the base __init__.
        super().__init__()
        self.dir_path = dir_path
        self.files = sorted(name for name in os.listdir(dir_path) if name.endswith(".JPEG"))
        with open(classes_path, "r") as f:
            # NOTE(review): if a class name appears on more than one line, the
            # later index wins — labels for such classes may be ambiguous.
            self.image_net_cls = {line.strip(): idx for idx, line in enumerate(f)}

    def __getitem__(self, idx):
        file_name = self.files[idx]
        return get_image_and_label(os.path.join(self.dir_path, file_name), self.image_net_cls, normalize=True)

    def __len__(self):
        return len(self.files)


class MyOriginalImages(MyDataset):
    """Same files and labels as ``MyDataset``, but without ``Normalize``.

    The constructor is inherited unchanged; only item loading differs, passing
    ``normalize=False`` so the tensors are suitable for display next to their
    saliency maps.
    """

    def __getitem__(self, idx):
        path = os.path.join(self.dir_path, self.files[idx])
        return get_image_and_label(path, self.image_net_cls, normalize=False)


# Reverse lookup for plot titles: class index (0-based line number) -> class name.
with open("imagenet_classes.txt", "r") as class_file:
    cls_index_to_name = dict(enumerate(line.strip() for line in class_file))

In [3]:
IMAGE_DIR = 'imagenet-sample-images'
N_BIG_SAMPLES = 20  # images used for the SSIM comparison
BIG_SAMPLE_SEED = 123

dataset = MyDataset(IMAGE_DIR)            # normalized: model input
originals = MyOriginalImages(IMAGE_DIR)   # unnormalized: for display

# batch_size must stay 1 and shuffle must stay False: the normalized and
# unnormalized loaders are zipped pairwise, so their order has to match, and
# the util helpers presumably expect one image per batch (TODO confirm).
full_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

samples = [340, 4, 13, 430, 339]  # zebra, hammerhead, junco, basketball, sorrel
dataset_subset = torch.utils.data.Subset(dataset, samples)
originals_subset = torch.utils.data.Subset(originals, samples)

random.seed(BIG_SAMPLE_SEED)
# Sample from the actual dataset size instead of a hard-coded 1000.
big_sample_indices = random.sample(range(len(originals)), N_BIG_SAMPLES)
# NOTE(review): this subset draws from `originals` (no Normalize) while the other
# model-input loaders use `dataset`. If util.ssim_saliency_comparison feeds these
# images to `resnet`, this should probably be `dataset` — confirm.
big_dataset_subset = torch.utils.data.Subset(originals, big_sample_indices)

dataset_loader = torch.utils.data.DataLoader(dataset_subset, batch_size=1, shuffle=False)
originals_loader = torch.utils.data.DataLoader(originals_subset, batch_size=1, shuffle=False)
big_loader = torch.utils.data.DataLoader(big_dataset_subset, batch_size=1, shuffle=False)




In [4]:
# Module paths (attribute names inside torchvision's resnet18) randomized in
# cascade, from the output layer `fc` back to the first convolution `conv1`;
# each residual stage contributes its blocks '1' then '0'.
module_paths = (
    [['fc']]
    + [[stage, block_idx] for stage in ('layer4', 'layer3', 'layer2', 'layer1') for block_idx in ('1', '0')]
    + [['bn1'], ['conv1']]
)

In [7]:
%matplotlib agg
# visualize integrated gradients
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, False), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet-imagenet/inputxgradient_cascrand.png", bbox_inches="tight")

In [8]:
%matplotlib agg
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, True), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet-imagenet/inputxgradient_smoothing_cascrand.png", bbox_inches="tight")

In [6]:
%matplotlib agg
# multiple saliency maps for each example
for (image, label), (original, _) in zip(dataset_loader, originals_loader):
    fig, _ = util.visualize_cascading_randomization2(
        resnet,
        module_paths,
        [(Saliency, False), (Saliency, True), (InputXGradient, False), (InputXGradient, True), (GuidedBackprop, False), (IntegratedGradients, False), (IntegratedGradients, True)],
        ['Gradient', 'SmoothGrad', 'Gradient ⊙ Input', 'Gradient ⊙ Input-SG' 'Guided Back-propagation', 'Integrated Gradients', 'Integrated Gradients-SG'],
        (image, label),
        original,
        viz_method="heat_map"
    )
    fig.savefig("figures/resnet-imagenet/cascrand.png", bbox_inches="tight")
    break



In [5]:
# SSIM between each method's explanation on the trained model and its explanation
# after each cascading-randomization step, over the images in `big_loader`.
# Integrated Gradients-SG is left out because integrated gradients is very slow.
ssim_methods = [
    ((Saliency, False), 'Gradient'),
    ((Saliency, True), 'SmoothGrad'),
    ((InputXGradient, False), 'Gradient ⊙ Input'),
    ((InputXGradient, True), 'Gradient ⊙ Input-SG'),
    ((GuidedBackprop, False), 'Guided Back-propagation'),
    ((IntegratedGradients, False), 'Integrated Gradients'),
]
dic = util.ssim_saliency_comparison(
    resnet,
    module_paths,
    [method for method, _ in ssim_methods],
    [name for _, name in ssim_methods],
    big_loader,
)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


layer: 1 of 11


  ssim_score = compare_ssim(original_explanations[(img_id, sal_id)], attribution, multichannel=True, gaussian_weights=True) # calculate ssim score with original attribution and add to sum


layer: 2 of 11
layer: 3 of 11
layer: 4 of 11
layer: 5 of 11
layer: 6 of 11
SSIM score is nan
Error with image: 2, sal method: Guided Back-propagation, path: layer2_1
SSIM score is nan
Error with image: 11, sal method: Guided Back-propagation, path: layer2_1
SSIM score is nan
Error with image: 15, sal method: Guided Back-propagation, path: layer2_1
SSIM score is nan
Error with image: 17, sal method: Guided Back-propagation, path: layer2_1
layer: 7 of 11
SSIM score is nan
Error with image: 9, sal method: Guided Back-propagation, path: layer2_0
layer: 8 of 11
SSIM score is nan
Error with image: 0, sal method: Guided Back-propagation, path: layer1_1
SSIM score is nan
Error with image: 2, sal method: Guided Back-propagation, path: layer1_1
SSIM score is nan
Error with image: 4, sal method: Guided Back-propagation, path: layer1_1
SSIM score is nan
Error with image: 5, sal method: Guided Back-propagation, path: layer1_1
SSIM score is nan
Error with image: 6, sal method: Guided Back-propagatio

In [6]:
%matplotlib agg
fig = plt.figure(figsize=(15, 6))
ax = fig.subplots()
#plot similarities
colors = ['#c44e52', '#cf171d', '#ccb974', '#d4ae26', '#55a868', '#8172b3']
for (key, value), color in zip(dic.items(), colors):
    ax.plot(
        ['original'] + list(value.keys()), [1] + list(value.values()),
        label=key,
        linestyle='dashed', linewidth=3.5, marker='o', markersize=10, color=color
    )

ax.grid(True)
ax.legend()
fig.savefig("figures/resnet-imagenet/ssim_cascrand_20.png", bbox_inches="tight")