# Sanity Check (Cascading Randomization) for Saliency Maps: ResNet-18 ImageNet Example

In [1]:
%%capture
from torchvision import transforms, models
import torch
import numpy as np
import matplotlib.pyplot as plt
from src import util
from captum.attr import IntegratedGradients, Saliency, InputXGradient, GuidedBackprop
from src import util
import PIL
from ntpath import basename
import os

resnet = models.resnet18(pretrained=True)
resnet.eval()

In [2]:
def preprocess_image(file_path):
    img = PIL.Image.open(file_path).convert('RGB')
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(img)
    return img.unsqueeze(0)

def get_image_and_label(file_path, image_net_cls, normalize=False):
    img = PIL.Image.open(file_path).convert('RGB')
    img_name = basename(file_path)
    cls_name = img_name[img_name.find('_') + 1:img_name.find('.')].replace('_', ' ')
    label = image_net_cls[cls_name]
    if not normalize:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor()
        ])
        img = preprocess(img)
        return img, label
    else:
        preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        img = preprocess(img)
        return img, label


class MyDataset(torch.utils.data.Dataset):
    def __init__(self, dir_path):
        super(MyDataset).__init__()
        self.dir_path = dir_path
        self.files = [img for img in sorted(os.listdir(dir_path)) if img.endswith(".JPEG")]
        with open("imagenet_classes.txt", "r") as f:
            self.image_net_cls = {s.strip(): idx for idx, s in enumerate(f.readlines())}

    def __getitem__(self, idx):
        file_name = self.files[idx]
        return get_image_and_label(os.path.join(self.dir_path, file_name), self.image_net_cls, normalize=True)

    def __len__(self):
        return len(self.files)


class MyOriginalImages(MyDataset):
    def __init__(self, dir_path):
        super().__init__(dir_path)


    def __getitem__(self, idx):
        file_name = self.files[idx]
        return get_image_and_label(os.path.join(self.dir_path, file_name), self.image_net_cls, normalize=False)


with open("imagenet_classes.txt", "r") as f:
    cls_index_to_name = {idx: s.strip() for idx, s in enumerate(f.readlines())}

In [3]:
dataset = MyDataset('imagenet-sample-images')
originals = MyOriginalImages('imagenet-sample-images')
# full dataset loader
# important that batch is 1!
# important that shuffle is False!
full_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

#samples = [4, 13 ,23] hammerhead, junco, vulture
samples = [340, 4, 13, 430, 339] # zebra, hammerhead, junco, basketball, sorrel
# samples = np.random.choice(len(dataset), 2, replace=False)
dataset_subset = torch.utils.data.Subset(dataset, samples)
originals_subset = torch.utils.data.Subset(originals, samples)

# important that batch is 1!
# important that shuffle is False!
dataset_loader = torch.utils.data.DataLoader(dataset_subset, batch_size=1, shuffle=False)
originals_loader = torch.utils.data.DataLoader(originals_subset, batch_size=1, shuffle=False)

# img, label = get_image_and_label('imagenet-sample-images/n01440764_tench.JPEG', image_net_cls, normalize=True)
# with torch.no_grad():
#     output = resnet(img)
    
# probabilities = torch.nn.functional.softmax(output[0], dim=0)
# # Read the categories
# with open("imagenet_classes.txt", "r") as f:
#     categories = [s.strip() for s in f.readlines()]
# # Show top categories per image
# top5_prob, top5_catid = torch.topk(probabilities, 5)
# for i in range(top5_prob.size(0)):
#     print(categories[top5_catid[i]], top5_prob[i].item())
# for img, label in dataset_loader:
#     output = resnet(img)
#     print(output)
#     break

In [4]:
# define module paths for cascading randomization
module_paths = [
    ['fc'],
    ['layer4', '1'], ['layer4', '0'],
    ['layer3', '1'], ['layer3', '0'],
    ['layer2', '1'], ['layer2', '0'],
    ['layer1', '1'], ['layer1', '0'],
    ['bn1'], ['conv1']
]

In [5]:
%matplotlib agg
# visualize integrated gradients
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, False), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet-imagenet/inputxgradient_cascrand.png", bbox_inches="tight")

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [6]:
%matplotlib agg
fig, _ = util.visualize_cascading_randomization(resnet, module_paths, (InputXGradient, True), dataset_loader, originals_loader, cls_index_to_name, viz_method="blended_heat_map")
fig.savefig("figures/resnet-imagenet/inputxgradient_smoothing_cascrand.png", bbox_inches="tight")

In [7]:
%matplotlib agg
# multiple saliency maps for each example
for (image, label), (original, _) in zip(dataset_loader, originals_loader):
    fig, _ = util.visualize_cascading_randomization2(
        resnet,
        module_paths,
        [(Saliency, False), (Saliency, True), (InputXGradient, False), (InputXGradient, True), (GuidedBackprop, False), (IntegratedGradients, False)],
        ['Gradient', 'SmoothGrad', 'Gradient ⊙ Input', 'Gradient ⊙ Input-SG', 'Guided Back-propagation', 'Integrated Gradients'],
        (image, label),
        original,
        viz_method="heat_map"
    )
    fig.savefig("figures/resnet-imagenet/cascrand.png", bbox_inches="tight")
    break

Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients




Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients




Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients
Working on Saliency
Working on Saliency
Working on InputXGradient
Working on InputXGradient
Working on GuidedBackprop
Working on IntegratedGradients


In [8]:
dic = util.ssim_saliency_comparison(
    resnet,
    module_paths,
        [(Saliency, False), (Saliency, True), (InputXGradient, False), (InputXGradient, True), (GuidedBackprop, False), (IntegratedGradients, False)],
        ['Gradient', 'SmoothGrad', 'Gradient ⊙ Input', 'Gradient ⊙ Input-SG', 'Guided Back-propagation', 'Integrated Gradients'], # integrated gradients takes a really long time
    dataset_loader
    )

  ssim_sum += compare_ssim(original_explanations[(img_id, sal_id)], attribution, multichannel=True, gaussian_weights=True) # calculate ssim score with original attribution and add to sum


In [9]:
dic

{'Gradient': {'fc': 0.6988636805595543,
  'layer4_1': 0.6243430175047996,
  'layer4_0': 0.5763960636408851,
  'layer3_1': 0.5236429602603492,
  'layer3_0': 0.5222669677398154,
  'layer2_1': 0.4995225375566238,
  'layer2_0': 0.48712593568497986,
  'layer1_1': 0.3663305949713293,
  'layer1_0': 0.31341853440773104,
  'bn1': 0.2888173423816648,
  'conv1': 0.38512609278961},
 'SmoothGrad': {'fc': 0.7888577028510921,
  'layer4_1': 0.6293241249616182,
  'layer4_0': 0.6220412260185424,
  'layer3_1': 0.6179409972797342,
  'layer3_0': 0.5328704413192613,
  'layer2_1': 0.5115607579962979,
  'layer2_0': 0.5808584332516571,
  'layer1_1': 0.5428970255790462,
  'layer1_0': 0.5478941158185873,
  'bn1': 0.5259110887664706,
  'conv1': 0.4466546961275455},
 'Gradient ⊙ Input': {'fc': 0.7876274863806909,
  'layer4_1': 0.8001470084497238,
  'layer4_0': 0.8047597977099563,
  'layer3_1': 0.800145250448771,
  'layer3_0': 0.7734568350632269,
  'layer2_1': 0.7612311797559872,
  'layer2_0': 0.7933035302104068,
 

In [10]:
%matplotlib agg
fig = plt.figure(figsize=(15, 6))
ax = fig.subplots()
#plot similarities
for key, value in dic.items():
    ax.plot(
        ['original'] + list(value.keys()), [1] + list(value.values()),
        label=key,
        linestyle='dashed', linewidth=3.5, marker='o', markersize=10
    )

ax.grid(True)
ax.legend()
fig.savefig("figures/resnet-imagenet/ssim_cascrand.png", bbox_inches="tight")

In [14]:
image = dic[(340, 0)]
plt.figure()
plt.axis('off')
plt.imshow(image, cmap="Reds")
print(image)
print(image.shape)
print(len(dic))


KeyError: (340, 0)