# Explainability

We will explore [Grad-CAM](https://arxiv.org/abs/1610.02391) for explaining the predictions of image classification models. It gives us visual insight into which image regions drive a model towards a specific class label. Explainability techniques are important for understanding model behaviour and analysing failure cases.

We will start by implementing Grad-CAM and applying it to our LeNet-like MNIST classification model. We then analyse a [VGG-16 model](https://arxiv.org/abs/1409.1556) pre-trained on the ImageNet dataset.

**Objective:** Use Grad-CAM to analyse deep convolutional neural networks for image classification.

In [None]:
# On Google Colab uncomment the following line to install PyTorch Lightning
# ! pip install lightning

In [None]:
import numpy as np
import pandas as pd
import seaborn as sns
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib
import matplotlib.pyplot as plt

from torch.utils.data import random_split, DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
from pytorch_lightning import LightningModule, LightningDataModule, Trainer, seed_everything
from torchmetrics.functional import accuracy

## Data

We use a [LightningDataModule](https://lightning.ai/docs/pytorch/stable/data/datamodule.html) for handling the MNIST dataset.

In [None]:
class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = './data', batch_size: int = 32, num_workers: int = 4, transform = transforms.ToTensor()):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transform        

        self.test_set = MNIST(self.data_dir, train=False, transform=self.transform, download=True)
        dev_set = MNIST(self.data_dir, train=True, transform=self.transform, download=True)
        self.train_set, self.val_set = random_split(dev_set, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers, persistent_workers=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers, persistent_workers=True)

## Model

We use a [LightningModule](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html) for implementing the model and its training and testing steps.

In [None]:
class ImageClassifier(LightningModule):
    def __init__(self, input_dim: tuple[int, int] = (28,28), output_dim: int = 10, learning_rate: float = 0.001):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.learning_rate = learning_rate
        
        # LeNet
        self.conv = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
            )        
        self.fc = nn.Sequential(
            nn.Linear(16 * 4 * 4, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, self.output_dim)
            )

    def forward(self, x):
        # first pass x through the conv layers
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        # then pass linearised x through the fully connected layers
        return self.fc(x)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
    
    def process_batch(self, batch):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        probs = torch.softmax(logits, dim=1)
        preds = torch.argmax(probs, dim=1)        
        acc = accuracy(preds, y, task='multiclass', num_classes=self.output_dim)
        return loss, acc

    def training_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)
        if batch_idx == 0:
            grid = torchvision.utils.make_grid(batch[0][0:16, ...], nrow=4, normalize=True)
            self.logger.experiment.add_image('train_images', grid, self.global_step)
        return loss

    def validation_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)

    def test_step(self, batch, batch_idx):
        loss, acc = self.process_batch(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)

## Load pre-trained model

Let's start by loading the LeNet-like model previously trained for MNIST image classification.

**Task:** Add the path to your pre-trained MNIST classification model.

In [None]:
seed_everything(42, workers=True)

data = MNISTDataModule(data_dir='./data', batch_size=32)

model_dir = '<path_to_model_checkpoint>' # for example: './lightning_logs/classification/mnist-lenet/version_0/checkpoints/epoch=5-step=10314.ckpt'
model_dir = '../lightning_logs/classification/mnist-lenet/version_0/checkpoints/epoch=5-step=10314.ckpt'
model = ImageClassifier.load_from_checkpoint(model_dir, input_dim=(28,28), output_dim=10)

trainer = Trainer() # dummy trainer for running test() on the loaded model

## Testing

Evaluate the pre-trained model on the test data and confirm the classification accuracy.

In [None]:
trainer.test(model=model, datamodule=data)

## Class activation maps

For generating class activation maps, we first need to decide which convolutional layer we want to use for this.

Let's look at a print-out of the model architecture.

```
ImageClassifier(
  (conv): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (fc): Sequential(
    (0): Linear(in_features=256, out_features=120, bias=True)
    (1): ReLU()
    (2): Linear(in_features=120, out_features=84, bias=True)
    (3): ReLU()
    (4): Linear(in_features=84, out_features=10, bias=True)
  )
)
```

Our classification model consists of two sub-models: the convolutional part, with two convolutional and two max-pooling layers, and the fully connected part. For Grad-CAM, we usually want to compute the class activations at the output of the last convolutional layer (after its ReLU), before performing max-pooling.
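
To see which spatial resolution the Grad-CAM heatmap will have, it can help to trace the feature map size through the convolutional part. The following is a minimal sketch: the helper `conv_out` is a hypothetical name introduced here for illustration, and it assumes unpadded layers as in the print-out above.

```python
def conv_out(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output size of a convolution or pooling layer along one spatial axis."""
    return (size + 2 * padding - kernel) // stride + 1

size = 28                                   # MNIST input is 28x28
size = conv_out(size, kernel=5)             # conv[0]: 28 -> 24
size = conv_out(size, kernel=2, stride=2)   # conv[2]: 24 -> 12
size = conv_out(size, kernel=5)             # conv[3]: 12 -> 8
gradcam_size = size                         # activations used for Grad-CAM: 16 x 8 x 8
size = conv_out(size, kernel=2, stride=2)   # conv[5]: 8 -> 4
print(gradcam_size, size, 16 * size * size) # 8 4 256
```

The final value of 256 matches `in_features` of the first fully connected layer, and the heatmap for this model will be 8x8.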

We now implement a modified model in which we separate out the first five layers of the convolutional part, register a hook on their output for the Grad-CAM calculations, and then apply the final max-pooling before passing the output to the fully connected layers. For this, we modify the forward pass accordingly.

In [None]:
class ImageClassifierCAM(ImageClassifier):
    def __init__(self, input_dim: tuple[int, int] = (28,28), output_dim: int = 10, learning_rate: float = 0.001):
        super().__init__(input_dim, output_dim, learning_rate)       
        
        # placeholder for the gradients
        self.gradients = None

    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad
        
    def forward(self, x):
        # dissect the network to access its last convolutional layer
        features_conv = self.conv[:5]

        # first pass x through the conv layers
        x = features_conv(x)
        
        # register the hook
        x.register_hook(self.activations_hook)
        
        # apply the remaining pooling
        x = F.max_pool2d(x, kernel_size=2)
        
        x = x.view(x.size(0), -1)
        # then pass linearised x through the fully connected layers
        x = self.fc(x)
        return x
    
    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients
    
    # method for the activation extraction
    def get_activations(self, x):
        features_conv = self.conv[:5]
        return features_conv(x)
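
The hook mechanism used here can be tried in isolation. The following is a small sketch on a toy tensor, not part of the model; the names `captured` and `save_grad` are made up for illustration. Gradients of intermediate (non-leaf) tensors are not kept after `backward()`, which is why a hook is used to capture them.

```python
import torch

captured = {}

def save_grad(grad):
    # called during backward() with the gradient of the output w.r.t. `a`
    captured['grad'] = grad

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
a = x * 2                     # intermediate (non-leaf) tensor
a.register_hook(save_grad)    # hook fires when the gradient w.r.t. `a` is computed
y = (a ** 2).sum()
y.backward()

print(captured['grad'])       # dy/da = 2 * a
```

Note that `register_hook` raises an error on tensors that do not require gradients, so the modified forward pass is meant to be called with gradient tracking enabled.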

With this modified model we can now run inference on a selected test image.

In [None]:
model_modified = ImageClassifierCAM.load_from_checkpoint(model_dir, input_dim=(28,28), output_dim=10)

In [None]:
test_batch, _ = next(iter(data.test_dataloader()))
test_batch = test_batch[0:16,...]

grid = torchvision.utils.make_grid(test_batch, nrow=4, normalize=True).numpy()[0,...].squeeze()

plt.imshow(grid, cmap=matplotlib.cm.gray)
plt.axis('off')
plt.title('example test images')

In [None]:
test_index = 1
test_image = test_batch[test_index,...].unsqueeze(0)

plt.imshow(test_image.squeeze(), cmap=matplotlib.cm.gray)
plt.axis('off')
plt.title('test image')

We run the modified model on the selected test image and store the output logits and the predicted class label.

In [None]:
logits = model_modified(test_image)
predicted_label = F.softmax(logits, dim=1).argmax(dim=1)
print(predicted_label)
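
As a side note, softmax is monotonic, so the predicted label is the same whether we take the argmax of the probabilities or of the raw logits. The toy values below are arbitrary and only serve as a quick check.

```python
import torch

logits_example = torch.tensor([[1.5, -0.3, 4.2, 0.0]])   # arbitrary example logits
probs_example = torch.softmax(logits_example, dim=1)

# both ways of picking the class agree
label_from_probs = probs_example.argmax(dim=1)
label_from_logits = logits_example.argmax(dim=1)
print(label_from_probs.item(), label_from_logits.item())
```

Grad-CAM itself backpropagates from the class score before the softmax, which is why we keep the logits around.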

The following function implements the Grad-CAM algorithm. For more details, please have a look at this [blog post](https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82).
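
For reference, the two central expressions of the Grad-CAM paper can be written as follows, where $A^k$ is the $k$-th feature map of the chosen layer, $y^c$ is the score (logit) for class $c$, and $Z$ is the number of spatial positions in the feature map:

$$
\alpha_k^c = \frac{1}{Z} \sum_i \sum_j \frac{\partial y^c}{\partial A^k_{ij}},
\qquad
L^c_{\text{Grad-CAM}} = \mathrm{ReLU}\left( \sum_k \alpha_k^c A^k \right)
$$

The implementation in this notebook averages over the channels instead of summing them. This only changes the result by a constant factor, which disappears once the heatmap is normalised by its maximum.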

In [None]:
def run_gradcam(model, image, logits, class_label, num_feature_channels):
    # get the gradient of the output with respect to the parameters of the model
    logits[:, class_label].backward()

    # pull the gradients out of the model
    gradients = model.get_activations_gradient()

    # pool the gradients over the batch and spatial dimensions (one weight per channel)
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3])

    # get the activations of the last convolutional layer
    activations = model.get_activations(image).detach()

    # weight the channels by corresponding gradients
    for i in range(num_feature_channels):
        activations[:, i, :, :] *= pooled_gradients[i]
        
    # average the channels of the activations
    heatmap = torch.mean(activations, dim=1).squeeze()

    # relu on top of the heatmap
    # expression (2) in https://arxiv.org/abs/1610.02391
    heatmap = torch.relu(heatmap)

    # normalize the heatmap
    heatmap /= torch.max(heatmap)

    return heatmap
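
The channel-weighting loop can also be expressed with broadcasting. The sketch below uses random stand-in tensors rather than real model outputs and checks that both variants agree. The small `clamp` on the maximum is an addition of this sketch, to avoid dividing by zero when the heatmap is all zeros.

```python
import torch

torch.manual_seed(0)
activations = torch.rand(1, 16, 8, 8)    # stand-in for the conv activations
gradients = torch.randn(1, 16, 8, 8)     # stand-in for the hooked gradients

pooled = gradients.mean(dim=(0, 2, 3))   # one weight per channel, shape (16,)

# loop version, as in run_gradcam
looped = activations.clone()
for i in range(looped.shape[1]):
    looped[:, i, :, :] *= pooled[i]

# broadcast version: reshape the weights to (1, C, 1, 1)
vectorised = activations * pooled.view(1, -1, 1, 1)

heatmap_example = torch.relu(vectorised.mean(dim=1)).squeeze()
heatmap_example = heatmap_example / heatmap_example.max().clamp(min=1e-8)
```

With broadcasting, the number of feature channels no longer has to be passed in separately, since it is read from the tensor shape.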

Now let's run Grad-CAM and visualise the generated heatmap.

In [None]:
heatmap = run_gradcam(model=model_modified, image=test_image, logits=logits, class_label=predicted_label, num_feature_channels=16)

# draw the heatmap
plt.matshow(heatmap, cmap=matplotlib.cm.bwr)
plt.colorbar()
plt.title('Grad-CAM heatmap')

The heatmap will have the same resolution as the output of the selected convolutional layer. In order to overlay the heatmap on the original test image, we need to upsample the heatmap.

In [None]:
import matplotlib

img = test_image.squeeze()
hmp = transforms.functional.resize(heatmap[None, None, ...], test_image.shape[2::], antialias=True).squeeze()

f, ax = plt.subplots(1,3, figsize=(15, 15))

ax[0].imshow(img, cmap=matplotlib.cm.gray)
ax[0].axis('off')
ax[0].set_title('test image')

ax[1].imshow(hmp, cmap=matplotlib.cm.bwr)
ax[1].axis('off')
ax[1].set_title('Grad-CAM heatmap')

ax[2].imshow(hmp*img, cmap=matplotlib.cm.bwr)
ax[2].axis('off')
ax[2].set_title('multiplication')
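
As an alternative to `transforms.functional.resize`, the upsampling can be done directly with `torch.nn.functional.interpolate`. This is a minimal sketch on a random 8x8 stand-in heatmap; the choice of bilinear interpolation is an assumption, and other modes would work as well.

```python
import torch
import torch.nn.functional as F

heatmap_small = torch.rand(8, 8)       # stand-in for the 8x8 Grad-CAM heatmap
upsampled = F.interpolate(
    heatmap_small[None, None, ...],    # interpolate expects (N, C, H, W)
    size=(28, 28),                     # target resolution of the MNIST image
    mode='bilinear',
    align_corners=False,
).squeeze()
print(upsampled.shape)
```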

## Grad-CAM with VGG-16

Now that we have seen how to adapt our own model for Grad-CAM, in the following we will do this with a more powerful model trained on a much more difficult task than MNIST classification.

Here, we will consider a VGG-16 model trained on ImageNet classification, inspired by this [blog post](https://medium.com/@stepanulyanin/implementing-grad-cam-in-pytorch-ea0937c31e82).

In [None]:
! wget https://www.doc.ic.ac.uk/~bglocker/teaching/mli/gradcam.zip
! unzip gradcam.zip

In [None]:
import torch
import torch.nn as nn
from torch.utils import data
from torchvision.models import vgg16
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
import numpy as np

# use the ImageNet transformation
transform = transforms.Compose([transforms.Resize((224, 224)), 
                                transforms.ToTensor(),
                                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

# define the image dataset
dataset = datasets.ImageFolder(root='./data/gradcam/', transform=transform)

# define the dataloader to load that single image
dataloader = data.DataLoader(dataset=dataset, shuffle=False, batch_size=1)

# ImageNet class names
with open('./data/gradcam/imagenet1000_clsidx_to_labels.txt') as f:
    idx2label = eval(f.read())

Here is a print-out of the VGG-16 architecture.
```
VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU(inplace=True)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU(inplace=True)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU(inplace=True)
    (23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU(inplace=True)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU(inplace=True)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU(inplace=True)
    (30): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
  (classifier): Sequential(
    (0): Linear(in_features=25088, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.5, inplace=False)
    (3): Linear(in_features=4096, out_features=4096, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.5, inplace=False)
    (6): Linear(in_features=4096, out_features=1000, bias=True)
  )
)
```

### Modifying VGG

We can see that the last convolutional layer sits at index 28 in the `features` submodel (which is the convolutional part), followed by its ReLU at index 29 and the final max-pooling at index 30. Similar to before, we implement a modified VGG model that separates out the different parts of the model, so that we can register a hook during the forward pass and store the gradients at the output of the selected layer (index 29) during the backward pass.

This time, we recommend separating out the model components already in the `__init__` function, and then adjusting the `forward` function accordingly.

**Task:** Implement the modified VGG model for Grad-CAM.

In [None]:
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        
        # get the pretrained VGG16 network
        self.vgg = vgg16(weights=torchvision.models.VGG16_Weights.IMAGENET1K_V1)
        
        # dissect the network: keep the convolutional part up to, but excluding,
        # the last max-pooling layer (indices 0 to 29 of `features`)
        self.features_conv = self.vgg.features[:30]
        
        # get the max pool of the features stem
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        
        # get the average pooling part
        self.avg_pool = self.vgg.avgpool
        
        # get the classifier of the VGG-16
        self.classifier = self.vgg.classifier
        
        # placeholder for the gradients
        self.gradients = None
    
    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad
        
    def forward(self, x):
        # apply the convolutional part
        x = self.features_conv(x)
        
        # register the hook
        x.register_hook(self.activations_hook)
        
        # apply the remaining max pooling
        x = self.max_pool(x)

        # apply the average pooling
        x = self.avg_pool(x)
        
        # apply the fully connected layers
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x
    
    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients
    
    # method for the activation extraction
    def get_activations(self, x):
        return self.features_conv(x)

Let's visualise our test image.

In [None]:
image_index = 0
image = cv2.cvtColor(cv2.imread('./data/gradcam/test_images/image_' + str(image_index) + '.jpg'), cv2.COLOR_BGR2RGB)

plt.imshow(image)
plt.axis('off')
plt.title('test image')

We run inference on the test image with our modified VGG model.

In [None]:
# initialize the VGG model
vgg = VGG()

# set the evaluation mode
vgg.eval()

# get the image from the dataloader
# img, _ = next(iter(dataloader))
img, _ = list(dataloader)[image_index]

# get the most likely prediction of the model
logits = vgg(img)

predicted_label = logits.softmax(dim=1).argmax(dim=1)

print(idx2label[predicted_label.item()])

Now we run Grad-CAM to generate the corresponding heatmap.

In [None]:
heatmap = run_gradcam(model=vgg, image=img, logits=logits, class_label=predicted_label, num_feature_channels=512)

# draw the heatmap
plt.matshow(heatmap, cmap=matplotlib.cm.bwr)
plt.colorbar()
plt.title('Grad-CAM heatmap')

The heatmap will again have the same size as the output of the selected convolutional layer, which is 14x14 for the VGG-16.
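
The 14x14 resolution follows from the architecture print-out: all convolutions use 3x3 kernels with padding 1 and therefore keep the spatial size, while each max-pooling halves it. A short arithmetic sketch, assuming the 224x224 input produced by the transform above:

```python
input_size = 224
# MaxPool2d layers at indices 4, 9, 16 and 23 come before the hooked output;
# the one at index 30 is applied after the hook
num_pools_before_hook = 4
heatmap_size = input_size // (2 ** num_pools_before_hook)
pooled_size = heatmap_size // 2
print(heatmap_size, pooled_size, 512 * pooled_size * pooled_size)  # 14 7 25088
```

The last number matches `in_features` of the first linear layer in the classifier.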

To visualise the heatmap and overlay it onto the test image, we use OpenCV for resizing, smoothing and generating a colour overlay.

In [None]:
gradcam = cv2.resize(heatmap.numpy(), (image.shape[1], image.shape[0]))
gradcam = cv2.blur(gradcam,(50,50))
gradcam = np.uint8(255 * gradcam)
gradcam = cv2.cvtColor(cv2.applyColorMap(gradcam, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)

superimposed_img = cv2.addWeighted(gradcam, 0.5, image, 1.0, 0)

f, ax = plt.subplots(1,3, figsize=(15, 15))

ax[0].imshow(image)
ax[0].axis('off')
ax[0].set_title('test image')

ax[1].imshow(gradcam)
ax[1].axis('off')
ax[1].set_title('Grad-CAM heatmap')

ax[2].imshow(superimposed_img)
ax[2].axis('off')
ax[2].set_title('overlay')
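
To make the blending step more transparent, here is a NumPy sketch of a saturating weighted sum in the spirit of `cv2.addWeighted`. The function `blend` is a hypothetical helper written for this illustration, and its rounding may differ slightly from OpenCV.

```python
import numpy as np

def blend(overlay, alpha, base, beta, gamma=0.0):
    """Weighted sum of two uint8 images, clipped to the valid range [0, 255]."""
    mixed = overlay.astype(np.float32) * alpha + base.astype(np.float32) * beta + gamma
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)

overlay = np.full((2, 2, 3), 200, dtype=np.uint8)   # stand-in for the colour heatmap
base = np.full((2, 2, 3), 180, dtype=np.uint8)      # stand-in for the test image
out = blend(overlay, 0.5, base, 1.0)
print(out[0, 0])   # 0.5 * 200 + 1.0 * 180 = 280, clipped to 255
```

With weights of 0.5 and 1.0, bright image regions saturate quickly. Using weights that sum to one, for example 0.4 and 0.6, is one way to keep the overlay within range.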

## Other classification models

**Optional task:** Choose one of the [torchvision classification models](https://pytorch.org/vision/stable/models.html#classification) pre-trained on ImageNet and implement your own modified model in order to run Grad-CAM. Feel free to use other test images which can be found under `./data/gradcam/test_images/` (or add your own). Note, you will need to re-instantiate the dataset when changing the test images.
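
One possible starting point for this task, as an alternative to rewriting the forward pass, is to attach hooks to an existing layer. The sketch below is only an illustration under assumptions: `GradCAMHook` is a hypothetical helper, and the tiny CNN is a stand-in so the example runs without downloading weights. For a torchvision model you would pass the layer of your choice instead.

```python
import torch
import torch.nn as nn

class GradCAMHook:
    """Hypothetical helper: records activations and their gradients for one layer."""
    def __init__(self, layer: nn.Module):
        self.activations = None
        self.gradients = None
        layer.register_forward_hook(self._forward_hook)

    def _forward_hook(self, module, inputs, output):
        # store the activations and register a gradient hook on them
        self.activations = output.detach()
        output.register_hook(self._save_gradient)

    def _save_gradient(self, grad):
        self.gradients = grad

# tiny stand-in CNN (not a torchvision model)
toy_model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 5),
).eval()
hook = GradCAMHook(toy_model[1])          # hook the ReLU after the last conv

toy_logits = toy_model(torch.rand(1, 3, 16, 16))
toy_logits[0, toy_logits.argmax(dim=1).item()].backward()

weights = hook.gradients.mean(dim=(0, 2, 3))
cam = torch.relu((hook.activations * weights.view(1, -1, 1, 1)).mean(dim=1)).squeeze()
print(cam.shape)
```

Models that use in-place operations directly after the hooked layer may need extra care, since the stored activations could be modified after the hook has run.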