# Evaluate Saliency Maps for 3-Way Macrophage/Monocyte Classifier

This notebook, adapted from the Captum CIFAR-10 tutorial, demonstrates how to apply model interpretability algorithms from the Captum library to a trained 3-way macrophage/monocyte classifier, using samples from its test set.

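The original Captum tutorial builds a simple model as described in:
https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar10-tutorial-py
Here we instead load a classifier that has already been trained (see `MODEL_PATH` below).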

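We then use the `DeepLift` attribution algorithm to attribute the label of each image to its input pixels and visualize the result. Captum also provides other attribution algorithms, such as `IntegratedGradients`, `Saliency` and `NoiseTunnel`, which can be swapped in the same way.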
  
  **Note:** Before running this notebook, please install the `torch`, `torchvision`, `captum` and `matplotlib` packages, and make sure the project-local `utils` module is importable.

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pickle
from utils import *

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
from captum.attr import DeepLift

import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Artifacts from the `mac_polar_run` training run, fold 0 (per the path names).
DATA_PATH, MODEL_PATH = (
    './runs/mac_polar_run/test_data_fold_0',  # pickled test split
    './runs/mac_polar_run/model_fold_0',      # pickled model object
)

# Data-loading settings.
# NOTE(review): NUM_WORKERS is not referenced by the DataLoader cell, which
# passes num_workers=0 — confirm which value is intended.
BATCH_SIZE, NUM_WORKERS = 4, 2

In the cells below we define the image transform and load the test dataset (fold 0), wrapping it in a `DataLoader` that uses `equal_classes_sampler`.

In [None]:
# Preprocessing pipeline with a single step: `standardize_input` from the
# project-local `utils` module.
# NOTE(review): this rebinds the name `transforms`, shadowing the
# `torchvision.transforms` module imported above, and the composed pipeline
# is not passed to the DataLoader built below — confirm whether it is needed.
transforms = transforms.Compose([standardize_input()])

In [None]:
# NOTE(review): raw_images / raw_labels are not used anywhere in this notebook.
raw_images = []
raw_labels = []

# Deserialize the held-out test split. The `with` block closes the file handle
# deterministically (the original bare open() left it to the garbage collector).
# SECURITY NOTE: pickle.load can execute arbitrary code — only load files
# produced by a trusted training run.
with open(DATA_PATH, 'rb') as data_file:
    data = pickle.load(data_file)

# Sampler built from the dataset's labels (presumably draws classes in equal
# proportion, per its name — confirm in `utils`).
test_sampler = equal_classes_sampler(data.labels)
# shuffle must stay False: DataLoader rejects shuffle=True together with an
# explicit sampler.
# NOTE(review): num_workers is hard-coded to 0 (single-process loading) rather
# than NUM_WORKERS — confirm intent.
testloader = DataLoader(data, batch_size=BATCH_SIZE, sampler=test_sampler,
                        shuffle=False, num_workers=0)


In [None]:
import itertools

# Sanity check: print the input and label tensor shapes of the first two test
# batches. The original `if batch < 2: break` fired on the very first batch,
# and its loop variable `data` overwrote the dataset bound in the loading cell;
# islice + tuple unpacking fixes both.
for img, label in itertools.islice(testloader, 2):
    print(img.shape, label.shape)

In [None]:
# Load the full pickled model object (not a state_dict) for fold 0.
# map_location="cpu" lets a model saved with GPU tensors load on a CPU-only
# machine; the explicit .to("cpu") is then a harmless no-op.
# SECURITY NOTE: torch.load unpickles arbitrary objects — only load trusted files.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which
# rejects full-model pickles; pass weights_only=False there if loading fails.
print("Using existing trained model")
net = torch.load(MODEL_PATH, map_location="cpu")
net.to("cpu")
# Switch layers such as dropout / batch-norm to inference behaviour.
net.eval()

In the cell below we define `imshow`, a helper for displaying image tensors; it is used further down, where we load some images from the test dataset and perform predictions.

In [None]:
def imshow(img, transpose=True):
    """Display an image tensor with matplotlib and show the figure.

    Args:
        img: image tensor. When ``transpose`` is True it is expected to be
            channel-first (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``.
        transpose: if True (default), move the channel axis last
            ((C, H, W) -> (H, W, C)), as ``plt.imshow`` expects; if False, the
            array is shown with its axes unchanged (e.g. a 2-D image).
    """
    # NOTE: the CIFAR-style `img / 2 + 0.5` unnormalization from the original
    # tutorial is intentionally not applied here.
    # detach()/cpu() so tensors that require grad or live on a GPU convert cleanly.
    npimg = img.detach().cpu().numpy()
    if transpose:
        npimg = np.transpose(npimg, (1, 2, 0))
    plt.imshow(npimg)
    plt.show()


A generic helper that calls `attribute` on whichever attribution algorithm is passed in.

In [None]:
def attribute_image_features(algorithm, input, target, **kwargs):
    """Run ``algorithm.attribute`` on ``input`` for class ``target``.

    Parameter gradients on the global ``net`` are cleared first, so stale
    ``.grad`` values from earlier calls do not linger on the model. Extra
    keyword arguments (e.g. ``baselines``) are forwarded unchanged.

    Returns:
        Whatever ``algorithm.attribute`` returns — for Captum algorithms, the
        attribution tensor.
    """
    net.zero_grad()
    return algorithm.attribute(input, target=target, **kwargs)
        

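Apply DeepLift to a batch of test images. DeepLift assigns an attribution to each input pixel by comparing the difference between the model's output and its output on a reference (baseline) input with the difference between the input and that reference. Here the reference is an all-zero image, and the attributed class is each image's ground-truth label.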

In [None]:
def _scale_to_unit_max(arr):
    """Return ``arr`` divided by its maximum, so the largest value becomes 1.

    The array is returned unscaled when its maximum is not positive: dividing
    by 0 would produce inf/NaN, and dividing by a negative maximum would flip
    the sign of every value.
    """
    peak = arr.max()
    return arr / peak if peak > 0 else arr


# Grab one batch of test images and keep only the first two channels
# (presumably the channels the classifier was trained on — TODO confirm).
dataiter = iter(testloader)
# Use the builtin next(): DataLoader iterators no longer expose `.next()` in
# recent PyTorch releases.
images, labels = next(dataiter)
images = images[:, 0:2, :, :]
num_imgs = len(images)

# Show the first channel of each image, then the ground-truth labels.
imshow(torchvision.utils.make_grid(images[:, [0], :, :]))
print('GroundTruth: ', ' '.join('%5s' % labels[j].item() for j in range(num_imgs)))

# Prediction only — no gradients needed for this forward pass.
with torch.no_grad():
    outputs = net(images.float())
predicted = torch.argmax(outputs, 1).to(torch.double)

print('Predicted: ', ' '.join('%5s' % predicted[j].item()
                              for j in range(num_imgs)))

dl = DeepLift(net)

attr_dl = []
org_imgs = []

for i in range(num_imgs):
    # Add a batch dimension of 1; gradients w.r.t. the input are required.
    inp = images[i].unsqueeze(0).float()
    inp.requires_grad = True

    # Attribute the ground-truth class against an all-zero baseline.
    attr = attribute_image_features(dl, inp, int(labels[i].item()),
                                    baselines=inp * 0)
    attr = attr.squeeze(0).cpu().detach().numpy()
    attr_dl.append(_scale_to_unit_max(attr))

    # CIFAR-tutorial-style unnormalization (x / 2 + 0.5), channels moved last
    # for plotting. NOTE(review): confirm this matches `standardize_input`.
    org = np.transpose((images[i].cpu().detach().numpy() / 2) + 0.5, (1, 2, 0))
    org_imgs.append(_scale_to_unit_max(org))

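In the cell below we visualize the `DeepLift` attributions. The top row shows the first channel of each input image with its label and prediction, and the bottom row shows the first channel of the corresponding attribution map. (The original Captum tutorial also visualizes `Saliency Maps`, `Integrated Gradients` and `Integrated Gradients with SmoothGrad`.)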

In [None]:
# Class names indexed by integer class id.
# NOTE(review): assumes labels 0/1/2 correspond to M0/M1/M2 — confirm.
phenos = ["M0", "M1", "M2"]
rows = 2
cols = num_imgs
# squeeze=False keeps `axes` 2-D even when there is a single column, so the
# axes[row, col] indexing below works for any batch size.
fig, axes = plt.subplots(nrows=rows, ncols=cols, figsize=(cols*2.5 + .5, 5),
                         squeeze=False)

for i in range(cols):
    axes[0, i].set_axis_off()
    axes[1, i].set_axis_off()
    # Pin every panel to the same [0, 1] color range (each image and
    # attribution map was scaled to a maximum of 1 in the DeepLift cell), so
    # the single shared colorbar below is valid for all panels. Negative
    # attributions are clipped to the lowest color by vmin=0.
    im1 = axes[0, i].imshow(org_imgs[i][:, :, 0], cmap='viridis', vmin=0, vmax=1)
    axes[1, i].imshow(attr_dl[i][0], cmap='viridis', vmin=0, vmax=1)
    lbl = phenos[int(labels[i].item())]
    prd = phenos[int(predicted[i].item())]
    axes[0, i].set_title(f"label: {lbl}\npred: {prd}")

fig.subplots_adjust(bottom=0, top=0.8, left=0.1, right=0.8,
                    wspace=0.05, hspace=0.02)

# Add a colorbar axes: lower-left corner at [0.83, 0.1] in figure coordinates,
# width 0.02 and height 0.7.
cb_ax = fig.add_axes([0.83, 0.1, 0.02, 0.7])
cbar = fig.colorbar(im1, cax=cb_ax)

plt.show()