# DeepLiftShap (Captum) applied to an MNIST classifier

In [None]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt

from captum.attr import DeepLiftShap
from captum.attr import visualization as viz

import numpy as np

In [None]:
# Settings shared by all cells below:
#   PATH       -- root directory MNIST is downloaded to / read from
#   BATCHSIZE  -- images per mini-batch (used by both train and test loaders)
#   LR         -- learning rate passed to the Adam optimizer
#   NUM_EPOCHS -- number of passes over the training set
PATH, BATCHSIZE, LR, NUM_EPOCHS = 'data', 64, 1e-3, 5

## load data

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Convert dataset images to tensors; no further normalization is applied.
transform = torchvision.transforms.ToTensor()

# MNIST train and test splits, downloaded into PATH if not already present.
train_data = torchvision.datasets.MNIST(
    root=PATH,
    train=True,
    transform=transform,
    download=True,
)
test_data = torchvision.datasets.MNIST(
    root=PATH,
    train=False,
    transform=transform,
    download=True,
)

# Both loaders shuffle, so every fresh `next(iter(test_loader))` in the cells
# below yields a different random batch.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data, batch_size=BATCHSIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_data, batch_size=BATCHSIZE, shuffle=True
)

## Classifier

In [None]:
class Classifier(nn.Module):
    """Small CNN that maps single-channel 28x28 images to 10 class logits.

    The spatial sizes below follow from a 28x28 input, which the first
    Linear layer's fixed input size (8*10*10) requires: 28 -> 26 -> 12 -> 10.
    The network returns raw logits (no softmax); the training cell pairs it
    with nn.CrossEntropyLoss.
    """

    def __init__(self):
        super().__init__()

        # Every ReLU is a separate module instance: Captum's DeepLift-based
        # methods hook nonlinearity modules and expect them not to be reused.
        layers = [
            # (N, 1, 28, 28) -> (N, 4, 26, 26)
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            # (N, 4, 26, 26) -> (N, 16, 12, 12)
            nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # (N, 16, 12, 12) -> (N, 8, 10, 10)
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3),
            nn.ReLU(),
            # (N, 8, 10, 10) -> (N, 800) -> (N, 100) -> (N, 10)
            nn.Flatten(),
            nn.Linear(8 * 10 * 10, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        ]
        # One Sequential keeps the positional layer indices 0..11, so
        # state_dict keys read "classify.<index>.<param>".
        self.classify = nn.Sequential(*layers)

    def forward(self, input):
        """Return class logits of shape (N, 10) for a batch of images."""
        return self.classify(input)

## initialize classifier

In [None]:
# Cross-entropy over the classifier's raw logits.
criterion = nn.CrossEntropyLoss()

# Build the model, then move its parameters and buffers to the chosen device.
classifier = Classifier()
classifier = classifier.to(device)

# Adam over every parameter of the (already device-resident) model.
optimizer = torch.optim.Adam(params=classifier.parameters(), lr=LR)

## training

In [None]:
# Train the classifier for NUM_EPOCHS epochs, recording every batch loss.
losses = []

# Explicitly switch to training mode: the evaluation cell calls
# classifier.eval(), and re-running this cell afterwards would otherwise train
# with BatchNorm using (and not updating) its running statistics.
classifier.train()

print('Start training classifier...')
for epoch in range(NUM_EPOCHS):

    running_loss = 0.0
    for i, (imgs, labels) in enumerate(train_loader):
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        preds = classifier(imgs)
        loss = criterion(preds, labels)

        loss.backward()
        optimizer.step()

        # collect stats: read the scalar loss once and reuse it
        loss_value = loss.item()
        losses.append(loss_value)
        running_loss += loss_value

        # print the mean loss over the last 200 batches
        if i % 200 == 199:
            print(f'[{epoch+1}/{NUM_EPOCHS}] [{i+1}/{len(train_loader)}] Loss classifier: {running_loss / 200}')
            running_loss = 0.0

## plot losses

In [None]:
# Per-batch training loss collected by the training cell.
ax = plt.gca()
ax.plot(losses)
ax.set_title('training loss')
ax.set_xlabel('batches')
ax.set_ylabel('loss')
plt.show()

## test classifier

In [None]:
# Evaluate on the test split: print overall accuracy, then show a grid of
# test images titled with (label, prediction).
classifier.eval()

with torch.no_grad():

    correct = 0
    num_test_imgs = 0

    # accuracy over the whole test set
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        preds = torch.argmax(classifier(imgs), dim=1)

        correct += (preds == labels).sum().item()
        num_test_imgs += len(labels)

    print(f'The accuracy of the classifier is: {correct / num_test_imgs:.3f}')

    # plot some held-out test images next to label and prediction
    imgs, labels = next(iter(test_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    out = torch.argmax(classifier(imgs), dim=1)

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle('(label, prediction)')

    # the figure is an 8x8 grid, so at most 64 images are shown
    shown = zip(imgs[:8 * 8], labels, out)
    for i, (image, true_label, pred) in enumerate(shown):
        plt.subplot(8, 8, i + 1)
        plt.axis('off')
        plt.imshow(image.squeeze().cpu().numpy())
        plt.title(f'({true_label.item()}, {pred.item()})')

    plt.tight_layout()
    plt.show()

## Captum (DeepLiftShap)

In [None]:
# instantiate DeepLiftShap on model (*see Captum documentation) 
dl = DeepLiftShap(classifier)

# random integer for picking an image
num = torch.randint(BATCHSIZE, (1,)).item()

# load batch for image and background
imgs, labels = next(iter(test_loader))
img, label = imgs[num].unsqueeze(dim=1).to(device), labels[0].to(device)

background = imgs.to(device)

# check shapes
print(img.shape)
print(background.shape)

# predict image class (-> torch.Size([1, 10]))
outputs = classifier(img)

print(f'Original Image label: {labels[num].item()}')
print(f'Predicted: {torch.argmax(outputs).item()}, Probability: {torch.max(torch.nn.functional.softmax(outputs, 1)).item():.2f}')

# show image
orig_image = imgs[num].squeeze().cpu().detach().numpy()
plt.imshow(orig_image)
plt.axis('off')
plt.show()

# calculate attributions of pixels of input image for label (*)
attribution = dl.attribute(img, target=label, baselines=background) 
attribution = np.transpose(attribution.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

im = img.squeeze().unsqueeze(dim=0).cpu().detach().numpy() / 2 + 0.5
original_image = np.transpose(im, (1, 2, 0))

# visualization (*)
_ = viz.visualize_image_attr(attribution, original_image, method="blended_heat_map",sign="all", alpha_overlay=0.8, show_colorbar=True, 
                          title="Overlayed DeepLift")

## display an image next to the attributions for its predicted class

In [None]:
# Show a random test image next to its DeepLiftShap attribution for the class
# the model predicts, using the image's batch as baselines.
dl = DeepLiftShap(classifier)

imgs, labels = next(iter(test_loader))
background = imgs.to(device)

num = torch.randint(len(imgs), (1,)).item()

img = imgs[num].unsqueeze(dim=0).to(device)
# plain int, so the panel title reads "label: 7" rather than a tensor repr
label = labels[num].item()

outputs = classifier(img)
pred = torch.argmax(outputs).item()

# attribute with respect to the predicted class
attribution = dl.attribute(img, target=pred, baselines=background)
attribution = np.transpose(attribution.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

original_image = np.transpose(img.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
_ = viz.visualize_image_attr_multiple(attribution, original_image, methods=["original_image", "blended_heat_map"], signs=["all", "all"], titles=[f"label: {label}", f"predicted:{pred}"], alpha_overlay=0.8, show_colorbar=True,
                                      fig_size=(6, 8), use_pyplot=True)


## display attributions for all classes

In [None]:
# Show one random test image followed by its DeepLiftShap attribution map for
# each of the 10 classes, using the image's batch as baselines.
dl = DeepLiftShap(classifier)

imgs, labels = next(iter(test_loader))
background = imgs.to(device)

num = torch.randint(len(imgs), (1,)).item()

img = imgs[num].unsqueeze(dim=0).to(device)
label = labels[num].item()  # label of the selected image
pred = torch.argmax(classifier(img)).item()

# loop-invariant views of the image: 2-D for imshow, channels-last for Captum
im = img.squeeze().cpu().detach().numpy()
original_image = np.transpose(img.squeeze(0).cpu().detach().numpy(), (1, 2, 0))

fig = plt.figure(figsize=(15, 3), tight_layout=True)
st = fig.suptitle("DeepLiftShap", fontsize="x-large")

# first panel: the input image itself
ax = plt.subplot(1, 11, 1)
ax.axis('off')
ax.set_title(f'label: {label}\npredicted: {pred}')
ax.imshow(im)

# panels 2..11: attribution map for target class 0..9
for i in range(10):
    attribution = dl.attribute(img, target=i, baselines=background)
    attribution = np.transpose(attribution.squeeze(0).cpu().detach().numpy(), (1, 2, 0))
    ax = plt.subplot(1, 11, i + 2)
    viz.visualize_image_attr(attribution, original_image, method="blended_heat_map", sign="all", plt_fig_axis=(fig, ax), alpha_overlay=0.8, show_colorbar=True, title=f"class {i}", use_pyplot=False)

# put suptitle under plots
st.set_y(0)