#The slippery slope of XAI evaluation

This notebook illustrates how the lack of ground truth explanations allows for manipulation of quantitative evaluation in explainable artificial intelligence (XAI). If you are running this notebook on Google Colab, remember to enable GPU support to speed up computation.

This example illustrates the basic concept on the MNIST dataset, where we optimize across a feasible set of perturbation functions to find a set of hyperparameters that give the best performance for a focus method.


In [1]:
#@title install packages

!pip install captum
!pip install quantus

from IPython.display import clear_output

clear_output()

In [2]:
#@title load packages

import torch
import quantus
import torchvision
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [3]:
#@title Download data and create dataset and dataloader

device = 'cuda'
BATCH_SIZE = 500

transformer = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(root='./sample_data', train=True, transform=transformer, download=True)
test_set = torchvision.datasets.MNIST(root='./sample_data', train=False, transform=transformer, download=True)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=BATCH_SIZE, pin_memory=True)

clear_output()

In [4]:
#@title Define the classification network and initialize network

class LeNet(torch.nn.Module):
    """Network architecture from: https://github.com/ChawDoe/LeNet5-MNIST-PyTorch."""
    def __init__(self):
        super().__init__()
        self.conv_1 = torch.nn.Conv2d(1, 6, 5)
        self.pool_1 = torch.nn.MaxPool2d(2, 2)
        self.relu_1 = torch.nn.ReLU()
        self.conv_2 = torch.nn.Conv2d(6, 16, 5)
        self.pool_2 = torch.nn.MaxPool2d(2, 2)
        self.relu_2 = torch.nn.ReLU()
        self.fc_1 = torch.nn.Linear(256, 120)
        self.relu_3 = torch.nn.ReLU()
        self.fc_2 = torch.nn.Linear(120, 84)
        self.relu_4 = torch.nn.ReLU()
        self.fc_3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool_1(self.relu_1(self.conv_1(x)))
        x = self.pool_2(self.relu_2(self.conv_2(x)))
        x = x.view(x.shape[0], -1)
        x = self.relu_3(self.fc_1(x))
        x = self.relu_4(self.fc_2(x))
        x = self.fc_3(x)
        return x

    def softmax_forward(self, x):
        x = x.unsqueeze(0)
        return torch.nn.functional.softmax(self.forward(x), dim=1)

model = LeNet()


In [5]:
#@title Define functions for training and evaluation

import torchvision

def train_model(model,
                train_data: torchvision.datasets,
                test_data: torchvision.datasets,
                device: torch.device,
                epochs: int = 20,
                criterion: torch.nn = torch.nn.CrossEntropyLoss(),
                optimizer: torch.optim = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9),
                evaluate: bool = False):
    """Train torch model."""

    model.train()

    for epoch in range(epochs):

        for images, labels in train_data:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()

            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

        # Evaluate model!
        if evaluate:
            predictions, labels = evaluate_model(model, test_data, device)
            test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())

        print(f"Epoch {epoch+1}/{epochs} - test accuracy: {(100 * test_acc):.2f}% and CE loss {loss.item():.2f}")

    return model

def evaluate_model(model, data, device):
    """Evaluate torch model."""
    model.eval()
    logits = torch.Tensor().to(device)
    targets = torch.LongTensor().to(device)

    with torch.no_grad():
        for images, labels in data:
            images, labels = images.to(device), labels.to(device)
            logits = torch.cat([logits, model(images)])
            targets = torch.cat([targets, labels])

    return torch.nn.functional.softmax(logits, dim=1), targets

In [6]:
#@title Train and evaluate model

model = train_model(model=model.to(device),
                    train_data=train_loader,
                    test_data=test_loader,
                    device=device,
                    epochs=10,
                    criterion=torch.nn.CrossEntropyLoss().to(device),
                    optimizer=torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9),
                    evaluate=True)

# Model to GPU and eval mode.
model.to(device)
model.eval()

# Check test set performance.
predictions, labels = evaluate_model(model, test_loader, device)
test_acc = np.mean(np.argmax(predictions.cpu().numpy(), axis=1) == labels.cpu().numpy())
print(f"Model test accuracy: {(100 * test_acc):.2f}%")

Epoch 1/10 - test accuracy: 19.62% and CE loss 2.28
Epoch 2/10 - test accuracy: 85.77% and CE loss 0.48
Epoch 3/10 - test accuracy: 93.39% and CE loss 0.25
Epoch 4/10 - test accuracy: 95.26% and CE loss 0.21
Epoch 5/10 - test accuracy: 96.49% and CE loss 0.19
Epoch 6/10 - test accuracy: 96.80% and CE loss 0.09
Epoch 7/10 - test accuracy: 97.38% and CE loss 0.07
Epoch 8/10 - test accuracy: 97.69% and CE loss 0.06
Epoch 9/10 - test accuracy: 97.62% and CE loss 0.10
Epoch 10/10 - test accuracy: 97.98% and CE loss 0.06
Model test accuracy: 97.98%


In [65]:
#@title Function for creating faithfulness curve

from quantus import gaussian_noise, uniform_noise, baseline_replacement_by_blur

def perturb_input(x, perturbation_index, perturbation_function):
    if perturbation_function == 'baseline_replacement_by_blur':
        x = baseline_replacement_by_blur(x, perturbation_index, [0], blur_kernel_size=7)
    elif perturbation_function == 'gaussian_noise':
        x = gaussian_noise(x, perturbation_index, [0], perturb_std=1.0)
    elif perturbation_function == 'uniform_noise':
        x = uniform_noise(x, perturbation_index, [0], lower_bound=-1.0, upper_bound=1.0)
    else:
        raise ValueError('Unknown perturbation function')
    return x

def create_faithfulness_curve(model, x_sample, a_sample, subset_size, perturbation_function):

  model.eval()
  _, H, W = x_sample.shape
  number_of_pixels = H*W
  number_of_subsets = number_of_pixels // subset_size

  a_sample = a_sample.flatten()
  sorted_idx = np.argsort(-a_sample)

  x_perturbed = x_sample.copy()

  initial_prediction_scores = torch.nn.functional.softmax(model(torch.tensor(x_perturbed, device='cuda').unsqueeze(0)), dim=1)
  initial_prediction_index = initial_prediction_scores.argmax()

  faithfulness_curve = []

  for subset_i in range(number_of_subsets):

      x_perturbed = x_perturbed.flatten()
      pert_idx = sorted_idx[subset_i*subset_size:(subset_i+1)*subset_size]
      x_perturbed = perturb_input(x_perturbed.flatten(), pert_idx, perturbation_function)
      x_perturbed = x_perturbed.reshape(1, 28, 28)

      prediction_scores = torch.nn.functional.softmax(model(torch.tensor(x_perturbed, device='cuda').unsqueeze(0)), dim=1).squeeze()

      faithfulness_curve.append(prediction_scores[initial_prediction_index].item())

  return faithfulness_curve


In [81]:
#@title Calculate faithfulness score across several XAI methods for different partition size


xai_methods = ['Saliency', 'LRP', 'KernelShap']
partition_size = 28
image_size = 28
number_of_channels = 1
number_of_analysis_samples = 100
feasible_set_of_perturbation_functions = ['gaussian_noise', 'uniform_noise', 'baseline_replacement_by_blur']
results = pd.DataFrame(index=xai_methods, columns=feasible_set_of_perturbation_functions)

analysis_set, _ = torch.utils.data.random_split(test_set, [number_of_analysis_samples, len(test_set)-number_of_analysis_samples])
analysis_loader = torch.utils.data.DataLoader(analysis_set, batch_size=number_of_analysis_samples, shuffle=False)

for perturbation_function in feasible_set_of_perturbation_functions:
    for xai_method in xai_methods:

        faithfulness_scores = []

        for x_batch, y_batch in analysis_loader:
            x_batch , y_batch = x_batch.to('cuda'), y_batch.to('cuda')
            a_batch = quantus.explain(model, x_batch, y_batch, method=xai_method,
                                      img_size=image_size, nr_channels=number_of_channels, normalise=False)

            with torch.no_grad():
                for idx, (x_in, a_in) in enumerate(zip(x_batch, a_batch)):

                    faithfulness_curve = create_faithfulness_curve(model, x_in.numpy(force=True), a_in, partition_size, perturbation_function)
                    faithfulness_scores.append(np.trapz(faithfulness_curve))

        results.at[xai_method, perturbation_function] = np.mean(faithfulness_scores)
        print(perturbation_function, xai_method)




gaussian_noise Saliency
gaussian_noise LRP




gaussian_noise KernelShap




uniform_noise Saliency
uniform_noise LRP




uniform_noise KernelShap




baseline_replacement_by_blur Saliency
baseline_replacement_by_blur LRP




baseline_replacement_by_blur KernelShap


#Manipulation output

The cell below calculates the scores for the different partition sizes. We assume that a partion size of 14 is the base option, as this is commonly used in the literature. Our manipulation is towards LRP, so we seek to optimize towards LRP. Note that since we are calculating the AUC of the faithfulness curve, we are looking for the minimal value.


In [83]:
#@title calculate maipulation scores and print base versus manipulated hyperparameters

focus_method = 'LRP'
non_focus_methods = ['Saliency', 'KernelShap']

print(f"Base option scores \n {results['uniform_noise']}")
print(f"Top method base option {results['uniform_noise'].idxmin()}")

adversarial_objective = pd.DataFrame(index=[focus_method], columns=feasible_set_of_perturbation_functions)

for perturbation_function in feasible_set_of_perturbation_functions:

    adversarial_objective.at[focus_method, perturbation_function] = results[perturbation_function][focus_method]-np.mean(results[perturbation_function][non_focus_methods])


top_perturbation_function_for_focus_method = adversarial_objective.idxmin(axis=1).values[0]

print(f"Manipulated option scores \n {results[top_perturbation_function_for_focus_method]}")
print(f"Top method base option {results[top_perturbation_function_for_focus_method].idxmin()}")

Base option scores 
 Saliency      23.277109
LRP           24.828824
KernelShap    24.425446
Name: uniform_noise, dtype: object
Top method base option Saliency
Manipulated option scores 
 Saliency       22.31418
LRP           19.626559
KernelShap    24.195569
Name: baseline_replacement_by_blur, dtype: object
Top method base option LRP
