# SHAP to explain MNIST classifier

In [None]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import shap

In [None]:
# Run configuration shared by the cells below.

# Data: dataset root directory and mini-batch size used by both loaders.
PATH = 'data'
BATCHSIZE = 64

# Optimisation: number of passes over the training set and Adam step size.
NUM_EPOCHS = 2
LR = 0.001

## load data

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Every sample is converted to a tensor before it reaches the model.
transform = torchvision.transforms.ToTensor()

# Both MNIST splits share the same root directory, transform and download flag.
_mnist_args = dict(root=PATH, transform=transform, download=True)
train_data = torchvision.datasets.MNIST(train=True, **_mnist_args)
test_data = torchvision.datasets.MNIST(train=False, **_mnist_args)

# Both loaders shuffle, so the first batch drawn from either one is a random sample.
_loader_args = dict(batch_size=BATCHSIZE, shuffle=True)
train_loader = torch.utils.data.DataLoader(train_data, **_loader_args)
test_loader = torch.utils.data.DataLoader(test_data, **_loader_args)

## Classifier

In [None]:
class Classifier(nn.Module):
    """Small convolutional network producing 10 class scores per image.

    The layer sizes are chosen for single-channel 28x28 inputs: the
    convolutional stack reduces them to 8 feature maps of 10x10, which the
    fully connected head maps to 10 outputs.

    The outputs are raw logits. No softmax is applied here because
    ``nn.CrossEntropyLoss`` (used for training in this file) applies the
    log-softmax itself.
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. Shape comments give the output of the layers
        # above them for an input of shape (N, 1, 28, 28).
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            # -> (N, 4, 26, 26)
            nn.Conv2d(in_channels=4, out_channels=16, kernel_size=4, stride=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            # -> (N, 16, 12, 12)
            nn.Conv2d(in_channels=16, out_channels=8, kernel_size=3),
            nn.ReLU(),
            # -> (N, 8, 10, 10)
        )

        # Classification head on the flattened feature maps.
        self.fc = nn.Sequential(
            nn.Linear(8 * 10 * 10, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        )

    def forward(self, x):
        """Return the logits of shape (N, 10) for a batch ``x`` of shape (N, 1, 28, 28).

        Raises:
            RuntimeError: if the spatial size of ``x`` does not reduce to the
                8*10*10 features expected by the first linear layer.
        """
        x = self.conv(x)
        # Flatten per sample and keep the batch dimension fixed. Reshaping with
        # view(-1, 800) would silently fold extra features into the batch
        # dimension when the input is not 28x28.
        x = x.flatten(start_dim=1)
        return self.fc(x)

## initialize classifier

In [None]:
# Loss on the classifier outputs; it does not depend on the model, so it is created first.
criterion = nn.CrossEntropyLoss()

# Build the network, then move it to the selected device.
classifier = Classifier()
classifier = classifier.to(device)

# Adam optimises every parameter of the classifier with learning rate LR.
optimizer = torch.optim.Adam(params=classifier.parameters(), lr=LR)

## training

In [None]:
# Number of batches between two progress reports; also the averaging window.
LOG_EVERY = 200

classifier.train()

# Loss of every training batch, in order, across all epochs.
losses = []

print('Start training classifier...')
for epoch in range(NUM_EPOCHS):

    # Sum of the batch losses since the last progress report.
    running_loss = 0.0
    for i, batch in enumerate(train_loader):
        imgs, labels = batch
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()

        # The criterion is applied directly to the classifier outputs.
        preds = classifier(imgs)
        loss = criterion(preds, labels)

        loss.backward()
        optimizer.step()

        # Read the scalar once and reuse it: each .item() call converts the
        # loss tensor to a Python number again.
        loss_value = loss.item()
        losses.append(loss_value)
        running_loss += loss_value

        # Report the mean loss over the last LOG_EVERY batches.
        if (i + 1) % LOG_EVERY == 0:
            print(f'[{epoch+1}/{NUM_EPOCHS}] [{i+1}/{len(train_loader)}] Loss classifier: {running_loss / LOG_EVERY}')
            running_loss = 0.0

## plot losses

In [None]:
# Plot the recorded per-batch training loss on the current axes
# (pyplot creates a figure and axes on demand if none exist yet).
loss_axes = plt.gca()
loss_axes.plot(losses)
loss_axes.set_title('training loss')
loss_axes.set_xlabel('batches')
loss_axes.set_ylabel('loss')
plt.show()

## test classifier

In [None]:
import math

classifier.eval()

# No gradients are needed for evaluation or for the preview below.
with torch.no_grad():
    correct = 0.0
    num_test_imgs = 0

    # Accuracy over the whole test set.
    for batch in test_loader:
        imgs, labels = batch

        imgs, labels = imgs.to(device), labels.to(device)

        preds_raw = classifier(imgs)
        # The predicted class is the index of the largest output.
        preds = torch.argmax(preds_raw, dim=1)

        correct += (preds == labels).sum().item()
        num_test_imgs += len(labels)

    print(f'The accuracy of the classifier is: {correct / num_test_imgs:.3f}')

    # Visual check on one batch. It is drawn from the test loader so that the
    # preview shows held-out digits, matching the accuracy reported above.
    imgs, labels = next(iter(test_loader))
    imgs, labels = imgs.to(device), labels.to(device)

    out_raw = classifier(imgs)
    out = torch.argmax(out_raw, dim=1)

    # Smallest square grid that holds the whole batch (8x8 for 64 images), so
    # the preview keeps working when BATCHSIZE is changed.
    grid_side = math.ceil(math.sqrt(len(imgs)))

    fig = plt.figure(figsize=(8, 8))
    fig.suptitle('(label, prediction)')

    for i, img in enumerate(imgs):
        plt.subplot(grid_side, grid_side, i + 1)
        plt.axis('off')
        plt.imshow(img.squeeze().detach().cpu().numpy())
        plt.title(f'{labels[i].item(), out[i].item()}')

    plt.tight_layout()
    plt.show()

## SHAP $ \text{\scriptsize (see example in SHAP documentation)} $

In [None]:
# Number of images used as the explainer background and number of images explained.
NUM_BACKGROUND = 50
NUM_EXPLAINED = 3

# since shuffle=True, this is a random sample of test data
batch = next(iter(test_loader))
images, _ = batch

# Both slices below come from this single batch; fail loudly instead of
# silently explaining fewer (or zero) images when the batch is too small.
if len(images) < NUM_BACKGROUND + NUM_EXPLAINED:
    raise ValueError(
        f'Need a batch of at least {NUM_BACKGROUND + NUM_EXPLAINED} images, '
        f'got {len(images)}; increase BATCHSIZE.'
    )

images = images.to(device)

background = images[:NUM_BACKGROUND]
test_images = images[NUM_BACKGROUND:NUM_BACKGROUND + NUM_EXPLAINED]

# Make sure the model is in eval mode even if this cell is run right after
# training, so the BatchNorm layers use their running statistics.
classifier.eval()

e = shap.DeepExplainer(classifier, background)
shap_values = e.shap_values(test_images)

In [None]:
import numpy as np

# Normalise the explainer output to a list with one (N, C, H, W) array per class.
# NOTE(review): older SHAP releases return such a list directly, while newer
# ones appear to return a single array with the class axis last - confirm
# against the installed SHAP version.
if isinstance(shap_values, np.ndarray) and shap_values.ndim == test_images.dim() + 1:
    per_class_values = list(np.moveaxis(shap_values, -1, 0))
else:
    per_class_values = list(shap_values)

# Move the channel axis to the end: (N, C, H, W) -> (N, H, W, C).
shap_numpy = [np.transpose(s, (0, 2, 3, 1)) for s in per_class_values]
test_numpy = np.transpose(test_images.cpu().numpy(), (0, 2, 3, 1))

In [None]:
# Plot the feature attributions; the pixel values are passed negated.
negated_pixels = -test_numpy
shap.image_plot(shap_numpy, negated_pixels)