<h1>Sanity Check for saliency map</h1>

<h2>Introduction</h2>

<p>This notebook helps generate sanity checks for the PolyCAM methods.</p>
<p>It does not require any libraries beyond those listed in requirements.txt and Jupyter.</p>

<h2>Parameters</h2>

In [None]:
# --- Input locations ---
image_folder = "images"
image_list = "images.txt"
labels_list = "imagenet_validation_imagename_labels.txt"

# --- Execution settings ---
# Batch size handed to the saliency methods; tune it to the available GPU memory
batch_size = 1

# True -> run on CUDA (recommended if available), False -> stay on CPU
gpu = True

<h2>Preparation code</h2>

In [None]:
import torch
from torch.nn import Conv2d, Linear

from torchvision import models
from torchvision import transforms
from torchvision.transforms.functional import normalize, resize

from polycam.polycam import  PCAMp, PCAMm, PCAMpm

from benchmarks.utils import overlay

import PIL

from matplotlib import pyplot as plt

<h2>Load model</h2>

In [None]:
# Build VGG16 with pretrained weights. The flag is passed by keyword: a bare
# positional `True` is opaque to the reader and positional use is deprecated
# in recent torchvision releases.
model = models.vgg16(pretrained=True)

# Inference mode: the notebook only evaluates the network, it never trains it
model.eval()

# Move the weights to the GPU when requested in the parameters cell
if gpu:
    model = model.cuda()

<h2>Saliency maps selection</h2>

In [None]:
# Saliency methods to test, stored as [display name, saliency object] pairs.
# Comment out an entry to skip the corresponding method.
saliency_list = [
    ["PCAM+", PCAMp(model, batch_size=batch_size)],
    ["PCAM-", PCAMm(model, batch_size=batch_size)],
    ["PCAM+/-", PCAMpm(model, batch_size=batch_size)],
]

<h2>Image selection</h2>

In [None]:

# Load image
# Forget any `image` tensor left over from a previous run of the notebook.
# Only NameError is expected here (fresh kernel, nothing to delete); any other
# exception is a real problem and must not be silenced.
try:
    del image
except NameError:
    pass

# Uncomment desired image or add your own

#image_path = "sanity_images/ILSVRC2012_val_00015410.JPEG"
#image_path = "sanity_images/ILSVRC2012_val_00010495.JPEG"
image_path = "sanity_images/ILSVRC2012_val_00021206.JPEG"
#image_path = "sanity_images/ILSVRC2012_val_00032239.JPEG"
#image_path = "sanity_images/ILSVRC2012_val_00041179.JPEG"

# Force three channels: the pre-processing cell normalizes with 3-channel
# mean/std, which fails on grayscale or CMYK JPEG files.
img = PIL.Image.open(image_path).convert("RGB")


<h2>Image pre-processing</h2>

In [None]:
# Side length (in pixels) the image is resized to before being fed to the model
input_size = 224

totensor = transforms.ToTensor()

# PIL image -> tensor, resized to a square, with a leading batch axis added
image = resize(totensor(img), (input_size, input_size)).unsqueeze(0)

if gpu:
    image = image.cuda()

# Per-channel (mean, std) normalization; presumably the ImageNet statistics
# the pretrained weights expect -- confirm if another model is used.
# `image` itself is kept un-normalized because it is used for display later.
image_norm = normalize(image, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

# get result class
# A single forward pass is enough (the model used to be run twice here), and
# no gradient is needed just to read the predicted class.
with torch.no_grad():
    out = model(image_norm)

# Index of the highest-probability class, kept as a tensor with one entry per
# batch element
result_class = torch.softmax(out, dim=-1).max(1)[1]


<h2>Generate sanity check</h2>

In [None]:
# Index into the sequence returned by each saliency method; selects which of
# the returned maps is kept for the sanity check
nmap = 4

# keep all saliency maps generated:
# method name -> [map on the original model, map after each randomization step]
saliency_maps = {}

# NOTE: this cell re-initializes the weights of `model` in place. Reload the
# model (and rebuild the saliency methods) before running this cell again,
# otherwise the "original" maps are computed on an already randomized network.

# First generate original saliency map
for saliency_name, saliency in saliency_list:
    print(saliency_name)
    # Compute the map once and store it (it used to be computed twice)
    saliency_map = saliency(image_norm)[nmap].cpu().detach()
    saliency_maps[saliency_name] = [saliency_map]

# Generate cascading sanity check: walk the modules in reverse registration
# order and re-initialize every Conv2d / Linear layer, cumulatively
i = 1  # 1-based count of layers randomized so far
s = 1  # stride: record saliency maps every `s` randomized layers
for name, module in reversed(list(model.named_modules())):
    if isinstance(module, (Conv2d, Linear)):
        module.reset_parameters()
        if (i - 1) % s == 0:
            print(name)
            for saliency_name, saliency in saliency_list:
                print(saliency_name)
                saliency_map = saliency(image_norm)[nmap].cpu().detach()
                saliency_maps[saliency_name].append(saliency_map)
        i = i + 1

<h2>Generate Visualization</h2>

In [None]:
# Not referenced by the code below: label sizes are set explicitly on each call
fontsize = 32

# Column titles: the original map followed by one entry per randomized layer.
# The list is hard-coded and must have at least as many entries as there are
# maps per method.
col_names = ["Original saliency map", "classifier_fc3", "classifier_fc2", "classifier_fc1",
             "block5_conv3", "block5_conv2", "block5_conv1",
             "block4_conv3", "block4_conv2", "block4_conv1",
             "block3_conv3", "block3_conv2", "block3_conv1",
             "block2_conv2", "block2_conv1",
             "block1_conv2", "block1_conv1"
            ]

# One row per saliency method, one column per recorded map
n_rows = len(saliency_maps.keys())
n_columns = len(saliency_maps[saliency_list[0][0]])

# squeeze=False keeps `ax` two-dimensional even with a single row or column,
# so ax[row][col] indexing also works when only one saliency method is selected
fig, ax = plt.subplots(n_rows, n_columns,
                       figsize=((n_columns+2)*4, (n_rows+3)*4),
                       squeeze=False)

# Column titles only go on the first row; set them once instead of once per row
for idx_col in range(n_columns):
    ax[0][idx_col].set_title(col_names[idx_col], rotation=90, fontsize=42, pad=40)

for idx_row, (saliency_name, saliency) in enumerate(saliency_list):

    # Method name as a horizontal label to the left of the row
    ax[idx_row][0].set_ylabel(saliency_name, labelpad=160, rotation=0, fontsize=52)

    for idx_col, saliency_map in enumerate(saliency_maps[saliency_name]):
        cell = ax[idx_row][idx_col]

        # Hide tick labels and tick marks: only the picture matters
        cell.set(yticklabels=[])
        cell.set(xticklabels=[])
        cell.tick_params(left=False)
        cell.tick_params(bottom=False)

        # Reshape to (1, 1, H, W) float32, then pass the first batch element
        sh, sw = saliency_map.shape[-2:]
        saliency_map = saliency_map.view(1,1,sh,sw).type(torch.float32)

        cell.imshow(overlay(image[0], saliency_map[0], alpha=0, colormap="turbo"))

plt.tight_layout()
plt.show()