<a href="https://colab.research.google.com/github/arkeodev/XAI/blob/main/Layerwise_Relevance_Propagation_(LRP).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Layerwise Relevance Propagation (LRP)

LRP is a technique used to explain the predictions of complex models by tracing the prediction back through the layers of the network to the input features, thereby providing a visual map or a set of influential features that led to the decision.

LRP has a theoretical foundation in Taylor decomposition, which helps explain the contributions of individual components (input features) to a function's (network's) output near a point (the input data).
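
As a rough sketch of that connection, a first-order Taylor expansion of the network function $f$ around a reference point $\tilde{x}$ can be written as below. The symbols $\tilde{x}$ (a reference or "root" point) and $R_i$ (the relevance of input feature $i$) are notation assumed here for illustration.

$$
f(x) \approx f(\tilde{x}) + \sum_i \left.\frac{\partial f}{\partial x_i}\right|_{x=\tilde{x}} (x_i - \tilde{x}_i)
$$

If the reference point is chosen so that $f(\tilde{x}) = 0$, each first-order term can be read as the relevance of one input feature, and the relevances approximately sum to the output:

$$
R_i = \left.\frac{\partial f}{\partial x_i}\right|_{x=\tilde{x}} (x_i - \tilde{x}_i), \qquad \sum_i R_i \approx f(x)
$$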

<figure>
    <img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp.png" width="800" height="300" alt="LRP">
    <figcaption>LRP</figcaption>
</figure>

### How LRP Works

<figure>
    <img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp-how-it-runs.png" width="800" height="300" alt="How LRP Works?">
    <figcaption>How LRP Works?</figcaption>
</figure>

LRP helps us understand and visualize how much each input feature contributes to a neural network's final prediction. It is particularly useful for deep learning models, whose decision-making process can be quite opaque.

**Step 1: Forward Pass**

- Start with a trained neural network ready for making predictions.
- Input your data (like an image) into the network.
- Perform a forward pass through the network, where the input is processed layer by layer to arrive at a final prediction. This step is just like any other prediction process in neural networks.
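
The forward pass in Step 1 can be sketched on a toy network. This is a minimal illustration rather than the notebook's model: it assumes a small bias-free ReLU MLP, and the names `forward_with_activations`, `weights`, and `x` are hypothetical. The point is that LRP needs the activations of every layer, so they are stored during the forward pass.

```python
import numpy as np

def forward_with_activations(x, weights):
    """Run a bias-free ReLU MLP and return the activations of every layer."""
    activations = [x]
    for i, W in enumerate(weights):
        z = activations[-1] @ W
        # ReLU on hidden layers; the last layer returns raw scores
        is_last = i == len(weights) - 1
        activations.append(z if is_last else np.maximum(z, 0.0))
    return activations

# Toy network: 4 inputs -> 5 hidden units -> 3 output scores
rng = np.random.default_rng(0)
weights = [rng.normal(size=(4, 5)), rng.normal(size=(5, 3))]
x = rng.random(4)

activations = forward_with_activations(x, weights)
print([a.shape for a in activations])
```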

**Step 2: Prediction Output**

- Capture the output of the network. This could be a class label in classification tasks or a value in regression tasks.
- Select the output neuron(s) corresponding to the predicted class or value for relevance backpropagation.

**Step 3: Select Relevance**

- The relevance is initially set to the output of the network for the predicted class. For instance, if the network predicts "dog", the relevance is assigned to the output neuron that corresponds to "dog".
- All other output neurons are assigned zero relevance, because we want to explain only the predicted class.
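
A minimal sketch of this initialization, assuming a hypothetical vector of four output scores (the values are made up for illustration):

```python
import numpy as np

# Hypothetical output scores for a four-class problem
scores = np.array([0.3, 2.1, -0.7, 1.2])

# The predicted class is the output neuron with the highest score
predicted = int(scores.argmax())

# Initial relevance: the predicted class keeps its score, all others get zero
relevance = np.zeros_like(scores)
relevance[predicted] = scores[predicted]
print(predicted, relevance)
```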

**Step 4: Backward Pass with LRP**

- Perform a backward pass starting from the selected output neuron.
- At each layer, distribute the relevance from the layer above to the neurons in the current layer. This is done using LRP rules which define how to allocate relevance among neurons based on their contributions.
- Continue this process layer by layer, moving from the output layer towards the input layer.
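
The backward pass can be sketched on a tiny hand-made network. This is an illustration under simplifying assumptions (bias-free dense layers, ReLU activations, the basic LRP-0 rule with a small stabilizer); the function `lrp_backward` and the weights `W1` and `W2` are hypothetical and not part of the notebook's model.

```python
import numpy as np

def lrp_backward(activations, weights, relevance, eps=1e-9):
    """Propagate relevance from the output back to the input with LRP-0."""
    # Walk through the layers in reverse order, pairing each weight matrix
    # with the activations that were fed into it during the forward pass
    for a, W in zip(reversed(activations[:-1]), reversed(weights)):
        z = a @ W
        z = z + eps * np.where(z >= 0, 1.0, -1.0)  # avoid division by zero
        s = relevance / z                          # relevance per unit of z
        relevance = a * (W @ s)                    # redistribute to the inputs
    return relevance

# Tiny network: 3 inputs -> 2 hidden units (ReLU) -> 2 outputs
W1 = np.array([[1.0, -1.0], [0.5, 2.0], [-1.0, 1.0]])
W2 = np.array([[1.0, 0.0], [0.5, 1.0]])
x = np.array([1.0, 1.0, 0.5])

h = np.maximum(x @ W1, 0.0)
out = h @ W2

# Seed the relevance with the score of the predicted class only
seed = np.zeros_like(out)
seed[out.argmax()] = out.max()

R_input = lrp_backward([x, h, out], [W1, W2], seed)
print(R_input, R_input.sum(), out.max())
```

In this toy setting the input relevances sum to the score of the predicted class, which is the conservation property that LRP aims for.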

**Step 5: LRP Decomposition Rules**

- Apply specific LRP rules at each layer to decompose the relevance among the connected neurons. These rules are based on the contribution of each neuron to the activation of the neurons in the next layer.
- Commonly used rules include the LRP-ε, LRP-γ, and LRP-0 rules, each with different properties and applications.
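
For reference, the three rules can be written as follows, using the notation of the overview paper listed in the references: $a_j$ is the activation of neuron $j$ in the lower layer, $w_{jk}$ the weight connecting it to neuron $k$ in the layer above, $R_k$ the relevance of neuron $k$, and $w_{jk}^{+} = \max(0, w_{jk})$. The sum $\sum_{0,j}$ runs over all lower-layer neurons plus the bias (with $a_0 = 1$ and $w_{0k}$ as the bias).

$$
\begin{aligned}
\text{LRP-0:}\quad & R_j = \sum_k \frac{a_j w_{jk}}{\sum_{0,j} a_j w_{jk}} R_k \\
\text{LRP-}\epsilon\text{:}\quad & R_j = \sum_k \frac{a_j w_{jk}}{\epsilon + \sum_{0,j} a_j w_{jk}} R_k \\
\text{LRP-}\gamma\text{:}\quad & R_j = \sum_k \frac{a_j (w_{jk} + \gamma w_{jk}^{+})}{\sum_{0,j} a_j (w_{jk} + \gamma w_{jk}^{+})} R_k
\end{aligned}
$$

Setting $\epsilon = 0$ or $\gamma = 0$ recovers LRP-0.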

**Step 6: Relevance Conservation**

- Ensure that the total relevance is conserved at each layer. The sum of the relevance scores assigned to all neurons in one layer should equal the total relevance from the layer above.
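
Conservation can be checked numerically on a single layer. The sketch below assumes a bias-free dense layer and a hypothetical helper `lrp_dense`; it also shows that a non-zero ε absorbs part of the relevance, so conservation holds exactly only for the basic rule without bias.

```python
import numpy as np

def lrp_dense(a, W, R_out, eps=0.0):
    """Redistribute R_out through a bias-free dense layer z = a @ W."""
    z = a @ W
    z = z + eps * np.where(z >= 0, 1.0, -1.0)  # epsilon stabilizer (0 = LRP-0)
    s = R_out / z
    return a * (W @ s)

a = np.array([1.0, 2.0, 0.5])
W = np.array([[0.5, -0.2],
              [0.3,  0.8],
              [-0.4, 0.6]])
R_out = np.array([0.0, 1.7])  # only the second output neuron carries relevance

R_exact = lrp_dense(a, W, R_out)          # LRP-0
R_eps = lrp_dense(a, W, R_out, eps=0.3)   # LRP-epsilon

print(R_exact.sum(), R_out.sum())  # equal: relevance is conserved
print(R_eps.sum())                 # smaller: epsilon absorbs some relevance
```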

**Step 7: Generation of Heatmaps**

- Once the relevance has been backpropagated to the input layer, use the computed relevance scores to generate a heatmap.
- This heatmap visualizes the contribution of each input feature (like pixels in an image) to the network’s prediction.
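
One possible way to turn input relevances into a heatmap is sketched below. It assumes the relevance has shape `(channels, height, width)` and uses a hypothetical helper `relevance_to_heatmap`; summing over channels and scaling symmetrically to `[-1, 1]` keeps zero relevance in the middle of a diverging colormap. The random input only stands in for real relevance scores.

```python
import numpy as np

def relevance_to_heatmap(relevance):
    """Collapse a (C, H, W) relevance array into an (H, W) map in [-1, 1]."""
    heat = relevance.sum(axis=0)   # total relevance per pixel
    scale = np.abs(heat).max()     # symmetric scaling keeps 0 at the center
    return heat / scale if scale > 0 else heat

rng = np.random.default_rng(0)
relevance = rng.normal(size=(3, 8, 8))  # stand-in for real relevance scores
heatmap = relevance_to_heatmap(relevance)
print(heatmap.shape, heatmap.min(), heatmap.max())
```

The result could then be shown with, for example, `plt.imshow(heatmap, cmap="seismic", vmin=-1, vmax=1)`.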

**Step 8: Interpretation and Validation**

- Analyze the heatmap to interpret which features were most and least important for the model's prediction.
- Validate the explanation by techniques like pixel-perturbation analysis, where input features deemed important by LRP are altered to see if they affect the model's output.
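
The idea behind perturbation analysis can be sketched with a toy linear model, for which the relevance of each feature is known exactly (`w * x` with a zero baseline). Everything here (`score_drop`, `w`, `x`, the choice of `k`) is a hypothetical illustration, not an evaluation of the notebook's model: removing the most relevant features should reduce the score at least as much as removing randomly chosen ones.

```python
import numpy as np

def score_drop(x, w, indices):
    """Drop in a linear model's score after zeroing the given features."""
    perturbed = x.copy()
    perturbed[indices] = 0.0
    return float(w @ x - w @ perturbed)

rng = np.random.default_rng(0)
w = rng.normal(size=20)
x = rng.random(20)
relevance = w * x  # exact relevance for a linear model with a zero baseline

k = 5
top_k = np.argsort(relevance)[-k:]                # k most relevant features
random_k = rng.choice(20, size=k, replace=False)  # k random features

drop_top = score_drop(x, w, top_k)
drop_random = score_drop(x, w, random_k)
print(drop_top, drop_random)
```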

**Step 9: Evaluation of Explanations**

- Use quantitative methods to evaluate the quality of the explanations provided by LRP, such as comparing them with ground truth annotations when available, or assessing the impact of perturbations on the model’s prediction accuracy.

## Implementation

#### Introduction

Now, we will apply LRP to a brain cancer classification task using medical MRI scans.

MRI stands for Magnetic Resonance Imaging. A classifier can tell us whether a scan shows a tumor; LRP additionally helps us understand why the model made that prediction. LRP is predominantly applied to neural networks but can be used with support vector machines as well.

Our goal with LRP is to visually explain the decision of a model by showing which parts of an input, such as image pixels, contributed to a particular prediction.

In [1]:
# %% Imports
import torch
import torch.nn as nn
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
import copy
import pandas as pd

Before proceeding, it's important to check if a GPU is available for computation, as this will significantly speed up the training process.

In [2]:
# Set GPU device
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(device)

cuda:0


#### Data Downloading

We will use a dataset of brain MRI images. The images are organized into separate folders, one per class. The dataset contains both a training and a testing set, which we load for our model.

In [4]:
! pip install kaggle -q

In [6]:
import os
from getpass import getpass

# Prompt the user for API username and key input
kaggle_username = getpass('Enter your Kaggle username')
kaggle_key = getpass('Enter your Kaggle API key')

os.environ['KAGGLE_USERNAME'] = kaggle_username  # Sets the username as an environment variable
os.environ['KAGGLE_KEY'] = kaggle_key            # Sets the key as an environment variable

Enter your Kaggle username··········
Enter your Kaggle API key··········


In [7]:
! kaggle datasets download -d sartajbhuvaji/brain-tumor-classification-mri

Downloading brain-tumor-classification-mri.zip to /content
 98% 85.0M/86.8M [00:00<00:00, 181MB/s]
100% 86.8M/86.8M [00:00<00:00, 149MB/s]


In [14]:
TRAIN_ROOT = "./data/brain_mri/Training"
TEST_ROOT = "./data/brain_mri/Testing"

In [None]:
!mkdir -p "./data/brain_mri"  # Creates the directory and any necessary parent directories
!unzip -q brain-tumor-classification-mri.zip -d "./data/brain_mri"

In [15]:
# %% Load data
train_dataset = torchvision.datasets.ImageFolder(root=TRAIN_ROOT)
test_dataset = torchvision.datasets.ImageFolder(root=TEST_ROOT)

#### Building the Model

We construct our neural network model based on the pre-trained VGG16 architecture, modifying the output layer to suit our four-class problem - identifying different types of brain tumors. This model adaptation allows us to leverage the powerful feature extraction capabilities of VGG16 while fine-tuning it for our specific task.

In [16]:
# %% Building the model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        self.vgg16 = models.vgg16(pretrained=True)

        # Replace output layer according to our problem
        in_feats = self.vgg16.classifier[6].in_features
        self.vgg16.classifier[6] = nn.Linear(in_feats, 4)

    def forward(self, x):
        x = self.vgg16(x)
        return x

model = CNNModel()
model.to(device)
model

Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to /root/.cache/torch/hub/checkpoints/vgg16-397923af.pth
100%|██████████| 528M/528M [00:05<00:00, 102MB/s] 


CNNModel(
  (vgg16): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16):

#### Data Preprocessing

Here we apply the necessary transformations to our dataset images to fit the input requirements of the pre-trained VGG16 model. The images are resized and converted to tensors, ready for model consumption.

In [17]:
# %% Prepare data for pretrained model
train_dataset = torchvision.datasets.ImageFolder(
        root=TRAIN_ROOT,
        transform=transforms.Compose([
                      transforms.Resize((255,255)),
                      transforms.ToTensor()
        ])
)

test_dataset = torchvision.datasets.ImageFolder(
        root=TEST_ROOT,
        transform=transforms.Compose([
                      transforms.Resize((255,255)),
                      transforms.ToTensor()
        ])
)

#### Creating Data Loaders

Data loaders are an integral part of the training process, enabling efficient data manipulation and batching. Here, we define our data loaders for both the training and test sets with a batch size of 32.

In [18]:
# %% Create data loaders
batch_size = 32
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True
)


#### Training the Model

We will train our neural network model using the cross-entropy loss function and the Adam optimizer. The learning rate is set to a small value to fine-tune the pre-trained network. We will train the model for 10 epochs and print the loss for each batch to monitor the training progress.

In [None]:
# %% Train
cross_entropy_loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.00001)
epochs = 10

# Iterate x epochs over the train data
for epoch in range(epochs):
    for i, batch in enumerate(train_loader, 0):
        inputs, labels = batch
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # CrossEntropyLoss takes raw logits and integer class indices,
        # so the labels do not need to be one-hot encoded
        loss = cross_entropy_loss(outputs, labels)
        loss.backward()
        optimizer.step()
        print(loss)

tensor(1.4567, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.5099, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3743, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3706, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.4097, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.4625, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2630, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3791, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3515, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2479, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3732, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3679, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.3097, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2085, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.1745, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2401, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.2134, device='cuda:0', grad_fn=

#### Inspecting Predictions

After training, it's crucial to evaluate the model's performance. Here, we load a batch of test data and make predictions with the trained model. We calculate the batch accuracy by comparing the predicted labels with the ground truth and display the results in a dataframe for a clear comparison.

In [None]:
# %% Inspect predictions for first batch
# Switch to evaluation mode so that dropout is disabled for inference and LRP
model.eval()
inputs, labels = next(iter(test_loader))
inputs = inputs.to(device)
labels = labels.numpy()
outputs = model(inputs).max(1).indices.detach().cpu().numpy()
comparison = pd.DataFrame()
print("Batch accuracy: ", (labels==outputs).sum()/len(labels))
comparison["labels"] = labels
comparison["outputs"] = outputs
comparison
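
A single batch gives only a rough estimate. Accuracy over a whole data loader could be computed along the following lines. This is a sketch, not part of the original notebook: `evaluate_accuracy` is a hypothetical helper, and the toy model and random data only stand in for `model` and `test_loader` so that the example runs on its own.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model, loader, device="cpu"):
    """Return the fraction of correctly classified samples in a loader."""
    model.eval()  # disable dropout and similar training-only behavior
    correct, total = 0, 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for inputs, labels in loader:
            preds = model(inputs.to(device)).argmax(dim=1).cpu()
            correct += (preds == labels).sum().item()
            total += labels.numel()
    return correct / total

# Toy stand-ins: a linear 4-class classifier and 20 random samples
torch.manual_seed(0)
toy_model = nn.Linear(8, 4)
toy_data = TensorDataset(torch.randn(20, 8), torch.randint(0, 4, (20,)))
toy_loader = DataLoader(toy_data, batch_size=5)

accuracy = evaluate_accuracy(toy_model, toy_loader)
print(f"Accuracy: {accuracy:.2f}")
```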

#### Implementing Layerwise Relevance Propagation (LRP)

Layerwise Relevance Propagation (LRP) aims to explain the decision-making process of neural networks. We start by defining functions to adapt our VGG16 model for LRP by transforming dense layers to convolutional ones and cloning layers while applying specific relevance rules. We then compute the relevances for each layer in reverse, starting from the output and working back to the input, essentially performing a backward pass of relevance scores.
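
The dense-to-convolution conversion relies on the fact that a linear layer applied to a flattened feature map is equivalent to a convolution whose kernel covers the whole map. The sketch below checks this on deliberately small, made-up sizes (2 channels, a 3x3 map, 4 outputs) rather than VGG16's 512 channels and 7x7 map.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
channels, size, n_out = 2, 3, 4

# Dense layer operating on a flattened (channels, size, size) feature map
dense = nn.Linear(channels * size * size, n_out)

# Equivalent convolution: one kernel per output, covering the whole map
conv = nn.Conv2d(channels, n_out, kernel_size=size)
conv.weight = nn.Parameter(
    dense.weight.detach().reshape(n_out, channels, size, size))
conv.bias = nn.Parameter(dense.bias.detach().clone())

x = torch.randn(1, channels, size, size)
out_dense = dense(x.flatten(start_dim=1))  # shape (1, n_out)
out_conv = conv(x).flatten(start_dim=1)    # (1, n_out, 1, 1) -> (1, n_out)

print(torch.allclose(out_dense, out_conv, atol=1e-5))
```

Having every layer behave like a convolution lets the same relevance computation be applied uniformly to the feature extractor and the classifier.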

LRP assigns a relevance score to each neuron based on its contribution to the final decision. This score is backpropagated through the network's layers to the input layer. We treat different depths with different rules: for the lower layers we use LRP-γ, which favors positive contributions; for the middle layers we use LRP-ε, which suppresses weak or noisy contributions; and for the upper layers we apply the basic LRP-0 rule.

This process helps us understand which parts of an input image the model finds most relevant for its prediction, giving us insight into its decision-making process.
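
To see what "favoring positive contributions" means in practice, the sketch below compares LRP-0 and LRP-γ on a single toy neuron with one positive and one negative weight. The helper `lrp_gamma` and all values are hypothetical; γ = 0.25 is used here only as an example value.

```python
import numpy as np

def lrp_gamma(a, W, R_out, gamma):
    """LRP-gamma for a bias-free dense layer (gamma=0 gives LRP-0)."""
    W_mod = W + gamma * np.clip(W, 0.0, None)  # boost positive weights only
    z = a @ W_mod
    s = R_out / z
    return a * (W_mod @ s)

a = np.array([1.0, 1.0])
W = np.array([[1.0], [-0.5]])  # one positive and one negative connection
R_out = np.array([0.5])

R_zero = lrp_gamma(a, W, R_out, gamma=0.0)
R_gamma = lrp_gamma(a, W, R_out, gamma=0.25)

print(R_zero)   # [ 1.  -0.5]
print(R_gamma)  # approximately [ 0.8333 -0.3333]
```

Both results sum to the same total relevance, but with γ > 0 the negative contribution shrinks in magnitude, which tends to give less noisy heatmaps in the lower layers.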

In [None]:
# %% Layerwise relevance propagation for VGG16
# For other CNN architectures this code might become more complex
# Source: https://git.tu-berlin.de/gmontavon/lrp-tutorial
# http://iphome.hhi.de/samek/pdf/MonXAI19.pdf

def new_layer(layer, g):
    """Clone a layer and pass its parameters through the function g."""
    layer = copy.deepcopy(layer)
    try: layer.weight = torch.nn.Parameter(g(layer.weight))
    except AttributeError: pass
    try: layer.bias = torch.nn.Parameter(g(layer.bias))
    except AttributeError: pass
    return layer

def dense_to_conv(layers):
    """ Converts a dense layer to a conv layer """
    newlayers = []
    for i,layer in enumerate(layers):
        if isinstance(layer, nn.Linear):
            newlayer = None
            if i == 0:
                m, n = 512, layer.weight.shape[0]
                newlayer = nn.Conv2d(m,n,7)
                newlayer.weight = nn.Parameter(layer.weight.reshape(n,m,7,7))
            else:
                m,n = layer.weight.shape[1],layer.weight.shape[0]
                newlayer = nn.Conv2d(m,n,1)
                newlayer.weight = nn.Parameter(layer.weight.reshape(n,m,1,1))
            newlayer.bias = nn.Parameter(layer.bias)
            newlayers += [newlayer]
        else:
            newlayers += [layer]
    return newlayers

def get_linear_layer_indices(model):
    offset = len(model.vgg16._modules['features']) + 1
    indices = []
    for i, layer in enumerate(model.vgg16._modules['classifier']):
        if isinstance(layer, nn.Linear):
            indices.append(i)
    indices = [offset + val for val in indices]
    return indices

def apply_lrp_on_vgg16(model, image):
    image = torch.unsqueeze(image, 0)
    # >>> Step 1: Extract layers
    layers = list(model.vgg16._modules['features']) \
                + [model.vgg16._modules['avgpool']] \
                + dense_to_conv(list(model.vgg16._modules['classifier']))
    linear_layer_indices = get_linear_layer_indices(model)
    # >>> Step 2: Propagate image through layers and store activations
    n_layers = len(layers)
    activations = [image] + [None] * n_layers # list of activations

    for layer in range(n_layers):
        if layer in linear_layer_indices:
            if layer == 32:
                activations[layer] = activations[layer].reshape((1, 512, 7, 7))
        activation = layers[layer].forward(activations[layer])
        if isinstance(layers[layer], torch.nn.modules.pooling.AdaptiveAvgPool2d):
            activation = torch.flatten(activation, start_dim=1)
        activations[layer+1] = activation

    # >>> Step 3: Keep only the top-scoring class as the initial relevance
    output_activation = activations[-1].detach()
    # Mask that is 1 at the maximum output and 0 elsewhere (shape is preserved)
    one_hot_mask = (output_activation == output_activation.max()).float()
    activations[-1] = output_activation * one_hot_mask

    # >>> Step 4: Backpropagate relevance scores
    relevances = [None] * n_layers + [activations[-1]]
    # Iterate over the layers in reverse order
    for layer in range(0, n_layers)[::-1]:
        current = layers[layer]
        # Treat max pooling layers as avg pooling
        if isinstance(current, torch.nn.MaxPool2d):
            layers[layer] = torch.nn.AvgPool2d(2)
            current = layers[layer]
        if isinstance(current, torch.nn.Conv2d) or \
           isinstance(current, torch.nn.AvgPool2d) or\
           isinstance(current, torch.nn.Linear):
            activations[layer] = activations[layer].data.requires_grad_(True)

            # Apply variants of LRP depending on the depth
            # see: https://link.springer.com/chapter/10.1007%2F978-3-030-28954-6_10
            # Lower layers, LRP-gamma >> Favor positive contributions (activations)
            if layer <= 16:       rho = lambda p: p + 0.25*p.clamp(min=0); incr = lambda z: z+1e-9
            # Middle layers, LRP-epsilon >> Remove some noise / Only most salient factors survive
            if 17 <= layer <= 30: rho = lambda p: p;                       incr = lambda z: z+1e-9+0.25*((z**2).mean()**.5).data
            # Upper Layers, LRP-0 >> Basic rule
            if layer >= 31:       rho = lambda p: p;                       incr = lambda z: z+1e-9

            # Transform weights of layer and execute forward pass
            z = incr(new_layer(layers[layer],rho).forward(activations[layer]))
            # Element-wise division between relevance of the next layer and z
            s = (relevances[layer+1]/z).data
            # Calculate the gradient and multiply it by the activation
            (z * s).sum().backward()
            c = activations[layer].grad
            # Assign new relevance values
            relevances[layer] = (activations[layer]*c).data
        else:
            relevances[layer] = relevances[layer+1]

    # >>> Potential Step 5: Apply different propagation rule for pixels
    return relevances[0]
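
The four-step pattern used in the loop above (forward pass `z`, element-wise division `s`, backward pass `c`, element-wise product) uses autograd to compute the redistribution without writing out the sums. A minimal check on a toy bias-free linear layer, with made-up sizes and relevance values, suggests that the gradient-based result matches the explicit formula:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(3, 2, bias=False)
a = torch.rand(3, requires_grad=True)  # activations entering the layer
R_out = torch.tensor([0.2, 0.8])       # relevance arriving from above

# Gradient-based computation, as in the loop above
z = layer(a) + 1e-9                    # forward pass with a stabilizer
s = (R_out / z).detach()               # treated as a constant in backward
(z * s).sum().backward()               # a.grad becomes sum_k w_jk * s_k
R_grad = (a * a.grad).detach()

# Explicit computation of the same quantity
W = layer.weight.detach()              # shape (out_features, in_features)
R_explicit = a.detach() * (W.t() @ s)

print(torch.allclose(R_grad, R_explicit, rtol=1e-4))
```

Detaching `s` is what makes the trick work: the gradient of `(z * s).sum()` with respect to the input is then exactly the weighted sum needed by the LRP rule.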

#### Visualizing Relevance Scores

Using the LRP implementation, we can now visualize the relevance scores for individual images. This cell demonstrates how to apply LRP to a specific test image and generate a heatmap of relevances, which helps us see which pixels were most influential in the model's prediction.

We normalize the relevance scores to make them easier to visualize and use a color map to distinguish areas of high and low relevance. The resulting visualization can be compared with the original image to better understand the model's decision-making process. This kind of visualization is particularly useful in domains like medical imaging, where understanding the model's focus can be critical.

In [None]:
# %%
# Calculate relevances for one image (index image_id) in this test batch
image_id = 2
image_relevances = apply_lrp_on_vgg16(model, inputs[image_id])
image_relevances = image_relevances.permute(0,2,3,1).detach().cpu().numpy()[0]
image_relevances = np.interp(image_relevances, (image_relevances.min(),
                                                image_relevances.max()),
                                                (0, 1))
# Show relevances
# Look up the class name of the ground-truth label
true_label = test_dataset.classes[labels[image_id]]
if outputs[image_id] == labels[image_id]:
    print("Groundtruth for this image: ", true_label)

    # Plot images next to each other
    plt.axis('off')
    plt.subplot(1,2,1)
    plt.imshow(image_relevances[:,:,0], cmap="seismic")
    plt.subplot(1,2,2)
    plt.imshow(inputs[image_id].permute(1,2,0).detach().cpu().numpy())
    plt.show()
else:
    print("This image is not classified correctly.")

## Explainable AI Demos

[Explainable AI Demos](https://lrpserver.hhi.fraunhofer.de/) is an interactive demonstration site that makes AI more accessible and understandable by visually showing how models arrive at their conclusions. There are four types of demos:

- Handwriting Classification: This demo uses LRP to explain how a neural network trained on the MNIST dataset predicts handwritten digits. Users can input their own handwriting for the model to classify and explain.
- Image Classification: A more advanced LRP demo for image classification that uses a neural network implemented with Caffe, a deep learning framework. It illustrates which parts of an image the model relies on to determine its content.
- Text Classification: This demo classifies natural language documents. The neural network predicts the document's semantic category, and LRP is used to explain the classification.
- Visual Question Answering: This demo allows users to ask AI questions about an image and receive not only answers but also visual explanations that highlight relevant parts of the image involved in the AI's reasoning.

<table>
<tr>
    <td style="padding: 10px;"><img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp_text_recognition_and_classification.png" alt="Text Recognition and Classification" width="600" /></td>
    <td style="padding: 10px;"><img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp_image_classification.png" alt="Image Classification" width="600" /></td>
</tr>
<tr>
    <td style="padding: 10px;"><img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp_mnist_image_recognition.png" alt="MNIST Image Recognition" width="600" /></td>
    <td style="padding: 10px;"><img src="https://raw.githubusercontent.com/arkeodev/XAI/main/images/lrp_question_and_answering.png" alt="Question and Answering" width="600" /></td>
</tr>
</table>

## References

- Layer-Wise Relevance Propagation: An Overview (Paper): https://iphome.hhi.de/samek/pdf/MonXAI19.pdf
- Layer-Wise Relevance Propagation: https://www.hhi.fraunhofer.de/en/departments/ai/technologies-and-solutions/layer-wise-relevance-propagation.html
- Explainable AI Demos: https://lrpserver.hhi.fraunhofer.de/