# Explanations in AI: Methods, Stakeholders and Pitfalls
<h3 align="center">Image Data</h3>
<br>

As machine learning models become increasingly adept at solving image tasks (e.g., object classification, detection, segmentation, ...), there is an emerging need to better understand the reasoning behind their predictions. Computer vision tasks differ from other machine learning tasks in that the base features (i.e., pixels - light intensity values in different color channels) are generally not influential individually. Instead, combinations of pixels form higher-level features, such as textures or patterns, that can be visually recognized and leveraged to create explanations.

Explainability methods for image tasks, also known as pixel attribution methods, can be used to identify the pixels that are most important for a model's prediction. These methods can be classified as either gradient-based methods, occlusion- or perturbation-based methods, or a combination of these approaches. 

Certain explainability methods for image data also require so-called baselines. These baselines help determine which pixels are important to the predicted label by helping simulate absence of information. A baseline should be composed of neutral or uninformative pixel values; for example, black images, white images, or random noise.
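
As a rough illustration, the three baseline types mentioned above could be constructed as follows. This is a minimal sketch, assuming an image stored as a float array of shape (channels, height, width) with pixel values in [0, 1]; the helper name `make_baselines` is hypothetical and not part of this notebook's utilities.

```python
import numpy as np


def make_baselines(image, seed=0):
    """Return black, white, and uniform-noise baselines shaped like `image`.

    Assumes `image` is a float array with pixel values in [0, 1].
    """
    rng = np.random.default_rng(seed)
    return {
        "black": np.zeros_like(image),
        "white": np.ones_like(image),
        "noise": rng.uniform(0.0, 1.0, size=image.shape),
    }


# hypothetical 3-channel, 4x4 image used only for illustration
image = np.full((3, 4, 4), 0.5)
baselines = make_baselines(image)
for name, baseline in baselines.items():
    print(name, baseline.shape, float(baseline.min()), float(baseline.max()))
```

Note that once an image has been normalized with per-channel means and standard deviations, an all-zero tensor corresponds to the per-channel mean color rather than to a black image, so the meaning of a "zero" baseline depends on the preprocessing.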

To visualize explanations for images, so-called saliency maps are used: A saliency map is a heatmap that shows which parts of an image are most important (salient) for a model's prediction. As described above, different methods exist to create saliency maps. Saliency maps are also called sensitivity maps, or pixel attribution maps. 

---
__Problem Statement:__ This notebook shows how to use a pre-trained image classifier to automatically classify animals and objects into different categories. This kind of classifier could be used to automatically index images for easier retrieval. The notebook covers four prevalent approaches for explaining image classification models: Saliency Maps (Vanilla Gradient), Integrated Gradients, SHAP (GradientShap and DeepLiftShap), and counterfactual (adversarial) examples.

---
__Dataset:__ 
[ImageNet](https://www.image-net.org/challenges/LSVRC/index.php) is a large-scale hierarchical image database organized according to the WordNet hierarchy. It spans 1,000 object classes and contains 1,281,167 training images, 50,000 validation images and 100,000 test images. For the purpose of this hands-on exercise, a subset of ImageNet is introduced that contains one sample image per class (for a total of 1,000 distinct images). 

Attribution: Olga Russakovsky*, Jia Deng*, Hao Su, Jonathan Krause, Sanjeev Satheesh, Sean Ma, Zhiheng Huang, Andrej Karpathy, Aditya Khosla, Michael Bernstein, Alexander C. Berg and Li Fei-Fei. (* = equal contribution) ImageNet Large Scale Visual Recognition Challenge. IJCV, 2015.

---
<a name="0">__Contents of Notebook:__</a>

1. <a href="#1">Loading Data and Model</a>
2. <a href="#2">Explanations</a> <br>
    2.1. <a href="#21">Saliency Maps</a> <br>
    2.2. <a href="#22">Integrated Gradients</a> <br>
    2.3. <a href="#23">SHAP Values</a> <br>
    2.4. <a href="#24">Counterfactual Explanations</a> <br>
3. <a href="#3">Summary</a> <br>

---

This notebook uses modified code snippets from [Captum](https://captum.ai/tutorials/Titanic_Basic_Interpret) and [PyTorch](https://pytorch.org/vision/stable/_modules/torchvision/models/resnet.html#resnet50.py).


In [None]:
# Operational libraries
import sys
from pathlib import Path
import copy

# Jupyter(lab) libraries
if not sys.warnoptions:
    import warnings

    warnings.filterwarnings("ignore")

# Reshaping/basic libraries
import numpy as np
import random

# Plotting libraries
import matplotlib.pyplot as plt
import seaborn as sns

# Explainability libraries
from captum.robust import FGSM
from captum.attr import IntegratedGradients, DeepLiftShap, GradientShap, Saliency
from omnixai.data.image import Image
from omnixai.explainers.vision import CounterfactualExplainer

# Visualization libraries
from utils.cv_utils import ImageNetCase
from utils.viz_utils import visualize_image_attr_multiple

# Store pretrained models and datasets
cache_dir = Path(".cache")
cache_dir.mkdir(parents=True, exist_ok=True)

# Neural Net libraries
import torch
from torch import nn
import torchvision.transforms as transforms
import tensorflow as tf

# Globals
import logging

tf.get_logger().setLevel(logging.ERROR)


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seed(1)

## 1. <a name="1">Loading Data and Model</a>
(<a href="#0">Go to top</a>)


As a first step, let's load the `ImageNetCase` class, which contains helper functions and methods that can be used to quickly load sample images, perform transformations and create predictions using a pre-trained version of ResNet that was customized to contain unique model names (which is a requirement for certain classes in Captum).

In [None]:
# set index for image to retrieve
img_index = 1

# instantiate image case
ic = ImageNetCase()

# load image based on selected index
ic.load_image(img_index)

In the next section, we will present several different explanation methods. To quickly extract the attribution/importance scores, let's define a generic function, `attribute_scores_img`, that we can reuse. The attribution score of a feature indicates how much that feature contributed to the model's prediction. A positive score means that the feature contributed positively, while a negative score means that the feature contributed negatively. The magnitude of the score indicates the strength of the contribution. A zero score means that the feature had no contribution.

In [None]:
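def attribute_scores_img(algorithm, input_instance, **kwargs):
    # reset gradients so they do not accumulate across attribution calls
    ic._model.zero_grad()
    # the global `img_index` is passed as the target class; this assumes that the
    # index of the sample image matches the index of its class
    tensor_attributions = algorithm.attribute(
        input_instance, target=img_index, **kwargs
    )
    return tensor_attributions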

Before investigating any of the explanations, let's have a look at the prediction that ResNet creates for this particular image.

In [None]:
# get logits for prediction - to return probabilities, use return_probabilities=True
logits = ic.predict()

print(
    "Predicted",
    # lookup for class name
    ic.class_names[np.argmax(logits).item()],
    "with logit:",
    # extract max logit value
    round(logits.squeeze()[np.argmax(logits).item()].item(), 2),
)

Let's take our current input image and apply the transformation that is required to prepare the image (e.g., resizing, converting to a tensor, ...) for ResNet.

In [None]:
# convert image to tensor
input = ic._transform(ic._current_image).unsqueeze(0)

During the transformation, the image was also cropped and normalized. Eventually we will want to plot the model explanation next to the original image. However, the original image will be of different size. Therefore, to be able to plot a cropped version of the image that is true to the color of the original picture, we want to perform an inverse transformation and transpose the color channels and dimensions of the image to obtain the expected format.
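
To make the idea of an inverse transformation concrete, here is a minimal sketch of undoing a per-channel normalization and reordering the dimensions for plotting. It assumes the ImageNet channel statistics that appear in the `Normalize` transform later in this notebook; how `ic._inverse_transform` is actually implemented is not shown here, so this is an illustration and not a description of that helper.

```python
import numpy as np

# ImageNet channel statistics (assumed to match the normalization used by the model)
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

rng = np.random.default_rng(0)
pixels = rng.uniform(0.0, 1.0, size=(3, 224, 224))  # stand-in for a cropped RGB image

normalized = (pixels - mean) / std  # what a Normalize step does, channel by channel
restored = normalized * std + mean  # inverse transformation
restored_hwc = np.transpose(restored, (1, 2, 0))  # (C, H, W) -> (H, W, C) for plotting

print(restored_hwc.shape, np.allclose(restored, pixels))
```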

In [None]:
# convert transformed image back to original (but maintain crop/size)
original_image = np.transpose(
    ic._inverse_transform(input).cpu().detach().numpy().squeeze(), (1, 2, 0)
)

## 2. <a name="2">Explanations</a>
(<a href="#0">Go to top</a>)


### 2.1. <a name="21">Saliency Maps</a>
(<a href="#2">Go to Explanations</a>)

Saliency maps are visualizations of images in which the most important (salient) pixels are highlighted. Saliency maps (as introduced by Simonyan et al. [1]) highlight the areas of a given image that are discriminative with respect to a given class. Because the expression "saliency maps" is often used to refer to the whole collection of approaches that create explanations for image problems, the method developed by Simonyan et al. is also called "Vanilla Gradient".

To obtain the Vanilla Gradient attribution scores, a back-propagation with respect to the input image is performed (computing the gradient of the output with respect to the input image). The gradient quantifies the amount by which the classification score will change if the pixel changes by a small amount. 

It is also possible to calculate the saliency map of an image to a class other than its label; in that case the saliency map highlights which pixels detract or add to the classification.
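
The gradient computation described above can be sketched on a tiny example. This is a minimal illustration, assuming a hypothetical two-layer ReLU network with hand-written weights (`W1`, `w2`) and an analytic gradient; it is not the ResNet used in this notebook. The finite-difference check mirrors the statement that the gradient quantifies how much the score changes when a pixel changes by a small amount.

```python
import numpy as np

rng = np.random.default_rng(0)

# hypothetical two-layer ReLU network with a single class score
W1 = rng.normal(size=(5, 8))  # 5 hidden units, 8 input "pixels"
w2 = rng.normal(size=5)  # output weights for the class score


def class_score(x):
    return float(w2 @ np.maximum(W1 @ x, 0.0))


def vanilla_gradient(x):
    """Gradient of the class score with respect to the input."""
    active = (W1 @ x > 0).astype(float)  # derivative of ReLU
    return W1.T @ (w2 * active)


x = rng.normal(size=8)
grad = vanilla_gradient(x)

# finite-difference check: nudging pixel i by eps changes the score by about grad[i] * eps
eps = 1e-6
numeric = np.array(
    [(class_score(x + eps * np.eye(8)[i]) - class_score(x)) / eps for i in range(8)]
)

saliency = np.abs(grad)  # saliency maps typically display the absolute gradient
print(np.allclose(grad, numeric, atol=1e-4))
```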

---
[1] Simonyan, Karen, Andrea Vedaldi, and Andrew Zisserman. "Deep inside convolutional networks: Visualising image classification models and saliency maps." (2013).

In [None]:
# attach gradients
input.requires_grad = True

In [None]:
# instantiate Saliency method
sal = Saliency(ic._model)

# calculate pixel attributions
attr_sal = attribute_scores_img(
    sal,
    input,
)

# reshape result for plotting
attr_sal = np.transpose(attr_sal.squeeze().cpu().detach().numpy(), (1, 2, 0))

In [None]:
# visualize attributions
_ = visualize_image_attr_multiple(
    attr_sal,
    original_image,
    ["original_image", "heat_map"],
    ["all", "positive"],
    show_colorbar=True,
    titles=["Original (cropped) image", "Saliency"],
    cmap=plt.cm.hot,
)

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<i class="fa fa-exclamation-circle" style="color:red"></i> Vanilla Gradient saliency maps can be problematic when certain activation functions, such as ReLU, are used. ReLU caps all activation values below zero at zero. This leads to the so-called saturation problem: once a unit is saturated, its gradient is zero and the pixels that feed into it are considered unimportant, even if they influenced the prediction. Integrated Gradients mitigates this problem by accumulating gradients along a path from a baseline to the input instead of relying only on the gradient at the input.

<i class="fa fa-exclamation-circle" style="color:red"></i> Ghorbani et al. [2] showed that introducing small (adversarial) perturbations to an image, which still lead to the same prediction, can lead to very different pixels being highlighted as explanations. To illustrate this, review the example below. The example illustrates that pixel attribution methods can be very fragile. The issue of fragility also applies to other widely-used interpretation methods such as relevance propagation, and DeepLIFT. 

<i class="fa fa-exclamation-circle" style="color:red"></i> Kindermans et al. [3] showed that several saliency methods can be unreliable when a constant shift is added to the input data. They compared two networks: the original network and a second network whose bias was adjusted to compensate for the constant shift. Both networks produced the same predictions and had the same gradients with respect to the input. However, attributions from methods that depend on the input values or on a reference point (for example, gradient times input, or Integrated Gradients with certain baselines) differed between the two networks. This shows that such attribution methods can be sensitive to transformations of the input that do not affect the model's predictions, which can make them unreliable.

<i class="fa fa-exclamation-circle" style="color:red"></i> The assumption is that the pixels highlighted in a saliency map are the evidence for why a certain prediction was made. If that is indeed the case, the explanation should change when the prediction changes. In particular, if the prediction is random, the explanation should change substantially. It is possible to test this behavior as outlined in [Sanity Checks for Saliency Maps](https://arxiv.org/abs/1810.03292) [4]. A simple test is to randomize the weights of a model successively, starting from the top layer and moving all the way to the bottom layer. This destroys the learned weights from the top layers to the bottom ones.

---
[2] Ghorbani, Amirata, Abubakar Abid, and James Zou. “Interpretation of neural networks is fragile.” Proceedings of the AAAI Conference on Artificial Intelligence. (2019). <br>
[3] Kindermans, Pieter-Jan, Sara Hooker, Julius Adebayo, Maximilian Alber, Kristof T. Schütt, Sven Dähne, Dumitru Erhan, and Been Kim. “The (un) reliability of saliency methods.” (2019). <br>
[4] Julius Adebayo, Justin Gilmer, Michael Muelly, Ian Goodfellow, Moritz Hardt, and Been Kim. 2018. Sanity checks for saliency maps. In Proceedings of the 32nd International Conference on Neural Information Processing Systems (NIPS'18). Curran Associates Inc., Red Hook, NY, USA, 9525–9536.

In [None]:
# create version of current image with added random noise
input_noise = ic._transform_noise(ic._current_image).unsqueeze(0)

# convert back to original colors and clip RGB range
original_input_noise = np.transpose(
    torch.clip(ic._inverse_transform(input_noise[0]), min=0, max=1).detach().numpy(),
    (1, 2, 0),
)

In [None]:
# calculate pixel attributions
attr_sal_noise = attribute_scores_img(
    sal,
    input_noise,
)

# reshape result for plotting
attr_sal_noise = np.transpose(
    attr_sal_noise.squeeze().cpu().detach().numpy(), (1, 2, 0)
)

# visualize attributions
_ = visualize_image_attr_multiple(
    attr_sal_noise,
    original_input_noise,
    ["original_image", "heat_map"],
    ["all", "positive"],
    show_colorbar=True,
    titles=["Original (cropped) image", "Saliency map (perturbation)"],
    cmap=plt.cm.hot,
)

Interestingly, adding noise can create a smoother-looking saliency map (especially when using a random noise baseline). This happens because images have a rich correlation structure between neighboring pixels. Once the model has learned the value of a pixel, it will not use nearby pixels, as those will have similar intensities.

However, introducing random noise from an independent Gaussian distribution breaks up this correlation structure. This means that the importance of each pixel is considered independently of the other pixel values, and the resulting saliency map can look less noisy.

Let's also have a look at whether the explanation is sensitive to the model itself by performing a model randomization test. For this test we randomize the weights of the model, starting with the top layer.

In [None]:
model_randomization = copy.deepcopy(ic.model)

# randomize model weights for top layer
with torch.no_grad():
    model_randomization.fc.weight = torch.nn.Parameter(
        torch.randn(model_randomization.fc.weight.size()) * 0.02
    )

In [None]:
# instantiate Saliency method
sal_rand = Saliency(model_randomization)

# calculate pixel attributions
attr_sal_model_rand = attribute_scores_img(
    sal_rand,
    input,
)

# reshape result for plotting
attr_sal_model_rand = np.transpose(
    attr_sal_model_rand.squeeze().cpu().detach().numpy(), (1, 2, 0)
)

# visualize attributions
_ = visualize_image_attr_multiple(
    attr_sal_model_rand,
    attr_sal,
    ["heat_map", "heat_map"],
    ["positive", "positive"],
    show_colorbar=True,
    titles=["Saliency map", "Saliency map (random predictions)"],
    cmap=plt.cm.hot,
)

In [None]:
def get_prediction(model, input):
    # forward pass to calculate predictions (logits)
    preds = model(input)

    # get max logit and the index of the corresponding class
    max_logit, indx = torch.max(preds, 1)

    return ic.class_names[indx], max_logit


get_prediction(model_randomization, input)[0]

We can clearly see that the model is now making random predictions; yet the saliency map remained mostly unchanged.
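
Visual comparison can be complemented with a number. Adebayo et al. [4] use rank correlation (among other metrics) to compare saliency maps before and after randomization. The following is a minimal sketch of a Spearman rank correlation on synthetic stand-in maps; the helper `spearman_rank_correlation` is hypothetical, ignores ties, and would need to be applied to maps such as `attr_sal` and `attr_sal_model_rand` to say anything about this notebook's model.

```python
import numpy as np


def spearman_rank_correlation(map_a, map_b):
    """Spearman correlation between two attribution maps (ties are not averaged)."""
    ranks_a = np.argsort(np.argsort(map_a.ravel()))
    ranks_b = np.argsort(np.argsort(map_b.ravel()))
    return float(np.corrcoef(ranks_a, ranks_b)[0, 1])


rng = np.random.default_rng(0)
reference = rng.random((8, 8))  # stand-in for the original saliency map
similar = reference + 0.01 * rng.normal(size=(8, 8))  # nearly identical map
unrelated = rng.random((8, 8))  # independent map

print(round(spearman_rank_correlation(reference, similar), 3))
print(round(spearman_rank_correlation(reference, unrelated), 3))
```

A correlation close to 1 after randomization would indicate that the explanation barely depends on the randomized weights.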

### 2.2. <a name="22">Integrated Gradients</a>
(<a href="#2">Go to Explanations</a>)

Integrated Gradients is a local attribution method that computes the integral of the gradients of the model output for the predicted class with respect to the input image pixels, along a path from a baseline image to the original input image. As the model gets more information along this path, the prediction score changes in a meaningful way. By accumulating gradients along the path, the model gradient can be used to determine which input features contribute most to the model prediction.

By integrating over a path, the Integrated Gradients method mitigates the saturation problem: The pixels' local gradients are accumulated when integrating along a straight line path from the baseline image to the input image.
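
The effect of saturation can be reproduced with a one-dimensional toy model. This is a minimal sketch, assuming the hypothetical function `f(x) = 1 - relu(1 - x)`, which rises with `x` and then stays flat for `x >= 1`; it is not related to the ResNet used in this notebook.

```python
def relu(z):
    return max(z, 0.0)


def f(x):
    # toy model: the output rises with x and then saturates at 1 for x >= 1
    return 1.0 - relu(1.0 - x)


def grad_f(x):
    # derivative of f: 1 while x < 1, 0 once the output is saturated
    return 1.0 if x < 1.0 else 0.0


x, baseline, steps = 2.0, 0.0, 1000

# vanilla gradient at the input: the output is saturated, so the gradient vanishes
vanilla = grad_f(x)

# integrated gradients: average the gradient along the straight path (midpoint rule)
# and scale by the distance between input and baseline
path = [baseline + (k + 0.5) / steps * (x - baseline) for k in range(steps)]
integrated = (x - baseline) * sum(grad_f(p) for p in path) / steps

print(vanilla, round(integrated, 3), f(x) - f(baseline))
```

The local gradient at the input suggests that `x` is irrelevant, whereas the accumulated gradients recover the full change in the output between baseline and input.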

In practice, this integration is approximated with $k$ linearly-spaced points between 0 and 1 for some value of $k$. We can use this property to better understand the difference between the total approximated and true integrated gradients (approximation error). To estimate the error, we calculate integrated gradients for different values of $k$, and measure the difference. If the difference is big (large approximation error), then a larger $k$ is required. In general, the lower the absolute value of the convergence delta, the better the approximation.

The difference between total approximated and true integrated gradients is also known as 'convergence delta'. The convergence delta can be returned by passing `return_convergence_delta = True` to the method. 
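
To see how the approximation error shrinks with $k$, consider the following minimal sketch. It assumes a hypothetical toy logistic model with hand-picked weights and a simple average over $k$ linearly spaced points; Captum's own integration scheme may differ, so the numbers are only illustrative. The delta is computed as the gap between the summed attributions and the difference in model output between input and baseline.

```python
import numpy as np

w = np.array([2.0, -1.0, 0.5])  # hypothetical weights of a toy logistic model


def f(x):
    return 1.0 / (1.0 + np.exp(-(w @ x)))


def grad_f(x):
    s = f(x)
    return s * (1.0 - s) * w


def integrated_gradients(x, baseline, k):
    """Approximate Integrated Gradients with k linearly spaced points between 0 and 1."""
    alphas = np.linspace(0.0, 1.0, k)
    grads = np.array([grad_f(baseline + a * (x - baseline)) for a in alphas])
    return (x - baseline) * grads.mean(axis=0)


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)

deltas = {}
for k in (5, 50, 500):
    attributions = integrated_gradients(x, baseline, k)
    # the attributions should sum to f(x) - f(baseline); the gap is the delta
    deltas[k] = abs(attributions.sum() - (f(x) - f(baseline)))
    print(k, deltas[k])
```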

In [None]:
# instantiate IntegratedGradients method
ig = IntegratedGradients(ic._model)

# calculate pixel attributions
attr_ig, delta = attribute_scores_img(
    ig, input, baselines=input * 0, return_convergence_delta=True
)

# reshape result for plotting
attr_ig = np.transpose(attr_ig.squeeze().cpu().detach().numpy(), (1, 2, 0))

In [None]:
# visualize attributions
_ = visualize_image_attr_multiple(
    attr_ig,
    original_image,
    ["original_image", "heat_map"],
    ["all", "all"],
    show_colorbar=True,
    titles=["Original (cropped) image", "Integrated Gradients"],
    cmap=sns.color_palette("bwr", as_cmap=True),
)

<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<i class="fa fa-exclamation-circle" style="color:red"></i> Integrated Gradients may produce spurious or noisy pixel attributions that aren’t related to the model’s predicted class. This is partly due to the accumulation of noise from regions of correlated, high-magnitude gradients for irrelevant pixels that occur along the straight line path that is used when computing Integrated Gradients.

<i class="fa fa-exclamation-circle" style="color:red"></i> It is important to choose a good baseline for Integrated Gradients to make sensible feature attributions. For example, if a black image is chosen as baseline, Integrated Gradients won’t attribute importance to a completely black pixel in an actual image. The baseline should both have a near-zero prediction and faithfully represent a complete absence of signal. Try a different baseline below to check how the attribution scores change.
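
The effect of the baseline on black pixels can be seen in a small example. This is a minimal sketch, assuming a hypothetical linear model over four "pixels"; for a linear model the gradient is constant, so Integrated Gradients reduces exactly to `(x - baseline) * weights`.

```python
import numpy as np

weights = np.array([0.8, -0.3, 0.5, 1.2])  # hypothetical linear model over four "pixels"
image = np.array([0.0, 0.9, 0.4, 0.0])  # pixels 0 and 3 are completely black


def integrated_gradients_linear(x, baseline):
    # for a linear model the gradient is constant, so IG is (x - baseline) * weights
    return (x - baseline) * weights


attr_black = integrated_gradients_linear(image, np.zeros(4))
attr_white = integrated_gradients_linear(image, np.ones(4))

print("black baseline:", attr_black)
print("white baseline:", attr_white)
```

With the black baseline, the black pixels receive zero attribution regardless of their weights; with the white baseline, the same pixels receive non-zero attributions.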

### 2.3. <a name="23">SHAP</a>
(<a href="#2">Go to Explanations</a>)

#### GradientShap
GradientShap is a gradient method to compute SHAP values. GradientShap combines ideas from Integrated Gradients, SHAP, and SmoothGrad - it can be viewed as an approximation of integrated gradients by computing the expectations of gradients for different baselines. Gaussian noise is added to each input sample multiple times, then a random point on the path between the baseline and the input is picked to determine the gradient of the outputs.
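
The sampling procedure described above can be sketched as follows. This is a minimal illustration, assuming a hypothetical toy function with a known gradient and a hand-written Monte Carlo loop; the function `gradient_shap_sketch` and its parameters are made up for this example and do not reproduce Captum's implementation.

```python
import numpy as np


def f(x):
    return float(np.sum(x ** 2))  # toy model with a known gradient


def grad_f(x):
    return 2.0 * x


def gradient_shap_sketch(x, baselines, n_samples=20000, stdev=0.01, seed=0):
    """Monte Carlo estimate of E[gradient * (noisy input - baseline)]."""
    rng = np.random.default_rng(seed)
    total = np.zeros_like(x)
    for _ in range(n_samples):
        baseline = baselines[rng.integers(len(baselines))]  # draw a baseline
        noisy_x = x + rng.normal(0.0, stdev, size=x.shape)  # add Gaussian noise
        alpha = rng.uniform()  # random point on the path
        point = baseline + alpha * (noisy_x - baseline)
        total += grad_f(point) * (noisy_x - baseline)
    return total / n_samples


x = np.array([1.0, -2.0, 0.5])
baselines = np.zeros((1, 3))  # a single all-zero baseline
attributions = gradient_shap_sketch(x, baselines)
print(attributions, attributions.sum(), f(x) - f(baselines[0]))
```

For this toy function the estimated attributions approach `x ** 2`, so their sum is close to the difference between the model output at the input and at the baseline.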
 
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<i class="fa fa-exclamation-circle" style="color:red"></i> GradientShap makes an assumption that the input features are independent and that the explanation model is linear, meaning that the explanations are modeled through the additive composition of feature effects.

In [None]:
# instantiate GradientShap method
gs = GradientShap(ic._model)

# calculate pixel attributions
attr_gs, delta = attribute_scores_img(
    gs, input, baselines=input * 0, return_convergence_delta=True
)

# reshape result for plotting
attr_gs = np.transpose(attr_gs.squeeze().cpu().detach().numpy(), (1, 2, 0))

In [None]:
# visualize attributions
_ = visualize_image_attr_multiple(
    attr_gs,
    original_image,
    ["original_image", "heat_map"],
    ["all", "all"],
    show_colorbar=True,
    titles=["Original (cropped) image", "GradientShap"],
    cmap=sns.color_palette("bwr", as_cmap=True),
)

#### DeepLiftShap
DeepLiftShap is a method extending DeepLIFT to approximate SHAP values, which are based on Shapley values proposed in cooperative game theory. DeepLIFT SHAP takes a distribution of baselines and computes the DeepLIFT attribution for each input-baseline pair and averages the resulting attributions per input example.

In [None]:
# instantiate method
dls = DeepLiftShap(ic._model)

# calculate pixel attributions
attr_dls, delta = attribute_scores_img(
    dls,
    input,
    baselines=torch.cat([input * 0, input * 1]),
    return_convergence_delta=True,
)

# reshape result for plotting
attr_dls = np.transpose(attr_dls.squeeze().cpu().detach().numpy(), (1, 2, 0))

In [None]:
# visualize attributions
_ = visualize_image_attr_multiple(
    attr_dls,
    original_image,
    ["original_image", "heat_map"],
    ["all", "all"],
    show_colorbar=True,
    titles=["Original (cropped) image", "DeepLiftShap"],
    cmap=sns.color_palette("bwr", as_cmap=True),
)

### 2.4. <a name="24">Counterfactual (and adversarial) examples</a>
(<a href="#2">Go to Explanations</a>)

Just like with tabular data, it is possible to create counterfactuals for images. Given a sample image, $x$, for which a model predicts a certain class $c$, a counterfactual visual explanation identifies how $x$ could change such that the model would output a different specified class $c^*$.

To find the counterfactual, we need to find a distractor image, $x^*$, for which the model predicts $c^*$ and identify regions in the distractor image such that replacing the corresponding regions in the original image, $x$, flips the prediction. Replacing regions can be achieved with a permutation matrix or by obfuscating certain parts of the image.

Put simply, we are trying to answer which parts of the image, if they were exchanged or hidden from the classifier, would most change the classifier's decision.
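
The region-replacement idea can be sketched with a toy example. This is a simplified illustration, assuming a hypothetical linear classifier with one template per class and a search that only swaps regions at the same position in both images; the functions `predict_scores` and `best_region_swap` are made up for this example and are not how the explainer used in this notebook works internally.

```python
import numpy as np


def predict_scores(image, templates):
    """Toy linear classifier: one score per class from a per-class template."""
    return np.array([np.sum(t * image) for t in templates])


def best_region_swap(query, distractor, templates, target, cell=4):
    """Replace each cell of `query` with the same cell of `distractor` and
    return the edit that yields the highest score for the `target` class."""
    best = None
    for row in range(0, query.shape[0], cell):
        for col in range(0, query.shape[1], cell):
            edited = query.copy()
            edited[row:row + cell, col:col + cell] = distractor[
                row:row + cell, col:col + cell
            ]
            score = predict_scores(edited, templates)[target]
            if best is None or score > best[0]:
                best = (score, (row, col), edited)
    return best


# hypothetical 8x8 grayscale setup: class 0 looks at the top-left cell,
# class 1 looks at the bottom-right cell
templates = np.zeros((2, 8, 8))
templates[0, :4, :4] = 1.0
templates[1, 4:, 4:] = 1.0

query = np.zeros((8, 8))
query[:4, :4] = 0.6  # evidence for class 0
distractor = np.zeros((8, 8))
distractor[4:, 4:] = 1.0  # evidence for class 1

score, location, counterfactual = best_region_swap(
    query, distractor, templates, target=1
)
before = int(np.argmax(predict_scores(query, templates)))
after = int(np.argmax(predict_scores(counterfactual, templates)))
print(location, before, after)
```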


Let's create the counterfactuals using the array of images created in the next cell; we will create a counterfactual for the image at index zero, which corresponds to the class `goldfish`.

In [None]:
# to stack images in array, we want to resize them first to avoid length mismatch
transform_stack = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.0, 0.0, 0.0], std=[1 / 229, 1 / 224, 1 / 225]),
    ]
)

# use current image as placeholder first image
img_array = transform_stack(ic._current_image).unsqueeze(0).detach().numpy().astype(int)

# loop through additional images
for i in range(5):
    # use a separate name for the image so the loop index is not overwritten
    img = transform_stack(ic.load_image(i)).unsqueeze(0).detach().numpy().astype(int)
    img_array = np.concatenate((img_array, img), 0)

# store data in PIL array
pil_array = Image(data=img_array, batched=True, channel_last=False)

# look at a cropped example
pil_array[0].to_pil()

The next cell can take about 15 minutes to run; to reduce the wait time, specify one index, `idx`, to explain in `explainer.explain(pil_array[idx])` rather than getting counterfactuals for all examples.

In [None]:
transform_process = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
pre = lambda ims: torch.stack([transform_process(im.to_pil()) for im in ims])


# initialize explainer
explainer = CounterfactualExplainer(
    model=ic.model, preprocess_function=pre, num_iterations=50
)

# explain the test image
explanations = explainer.explain(pil_array)

In [None]:
# index for image to explain
expl_idx = 5

# show explanation visual for the selected image in the sample array
explanations.ipython_plot(index=int(expl_idx), class_names=ic._class_names)

In [None]:
# show cropped counterfactual for comparison
Image(
    data=transform_stack(
        ic.load_image(int(explanations.explanations[expl_idx]["cf_label"]))
    )
    .detach()
    .numpy(),
    channel_last=False,
).to_pil()

Another method that can be used to create counterfactuals for image data is connected to the concept of adversarial examples: the same method that creates adversarial examples (AEs) to fool image classifiers can be used to generate counterfactual explanations (CEs) that explain algorithmic decisions.

We first utilize the Fast Gradient Sign Method (FGSM) to construct an adversarial example. FGSM utilizes the sign of the gradient to perturb the input: For a given input image, FGSM uses the gradients of the loss with respect to the input image to create a new image that maximizes the loss. This can be used to find counterfactuals. Because the gradients are used, this method is computationally very efficient.
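
The core FGSM update, stepping along the sign of the loss gradient, can be sketched on a small example. This is a minimal illustration, assuming a hypothetical logistic classifier with hand-picked weights and an analytic gradient; it is independent of the Captum `FGSM` class used in this notebook.

```python
import numpy as np

w = np.array([1.5, -2.0, 0.5])  # hypothetical weights of a toy logistic classifier
b = 0.1


def predict_proba(x):
    return 1.0 / (1.0 + np.exp(-(w @ x + b)))


def loss(x, y):
    p = predict_proba(x)
    return float(-(y * np.log(p) + (1 - y) * np.log(1 - p)))  # binary cross-entropy


def loss_gradient_wrt_input(x, y):
    return (predict_proba(x) - y) * w  # analytic gradient of the loss w.r.t. x


def fgsm(x, y, epsilon):
    # step in the direction of the gradient sign to increase the loss
    return x + epsilon * np.sign(loss_gradient_wrt_input(x, y))


x = np.array([0.4, -0.3, 0.2])
y = 1  # true label
x_adv = fgsm(x, y, epsilon=0.5)

print(loss(x, y), loss(x_adv, y))
print(predict_proba(x) > 0.5, predict_proba(x_adv) > 0.5)
```

Every feature moves by exactly `epsilon`, and in this toy setup that is enough to increase the loss and push the prediction across the decision boundary.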

In [None]:
# Construct FGSM attacker
fgsm = FGSM(ic.model, lower_bound=-1, upper_bound=1)

perturbed_image_fgsm = fgsm.perturb(input, epsilon=0.2, target=img_index)

new_pred_fgsm, score_fgsm = get_prediction(ic.model, perturbed_image_fgsm)

print(new_pred_fgsm + " " + str(score_fgsm.item()))

plt.imshow(
    ic._inverse_transform(perturbed_image_fgsm)
    .squeeze()
    .permute(1, 2, 0)
    .detach()
    .numpy()
)

## 3. <a name="3">Summary</a>
(<a href="#0">Go to top</a>)

In this notebook we looked at several different methods to create variations of pixel attribution maps. We found that pixel attribution maps are useful because they create visual explanations, which makes it easy to immediately recognize the important regions of an image. In particular, we reviewed examples of gradient-based methods (Vanilla Gradient, Integrated Gradients) and methods that approximate SHAP values (GradientShap, DeepLiftShap).

We found that various explanation methods also face certain challenges:
- Pixel attribution methods can be very fragile: adding random noise to an image did not change the model prediction but did change the highlighted areas.
- Saliency methods can be insensitive to the model and the data: after randomizing weights in our model, the prediction changed, yet the explanation did not.
- Baselines need to be chosen carefully, as different baselines in gradient-based methods yield different results.

When choosing a method to use, it is first and foremost important to understand whether or not you have access to the model itself. The methods presented in this notebook require direct access to the model. In addition, you should evaluate whether the method you want to use is sensitive to (various kinds of) perturbations, and whether or not changing the most salient pixels (insertion/deletion game) has an impact. For more details about metrics, have a look at "[Explaining Classifiers using Adversarial Perturbations on the Perceptual Ball](https://arxiv.org/pdf/1912.09405.pdf)" and "[Understanding Deep Networks via Extremal Perturbations and Smooth Masks](https://arxiv.org/pdf/1910.08485.pdf)".
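
The deletion game mentioned above can be sketched as follows. This is a minimal illustration, assuming a hypothetical linear "model" over a 4x4 image whose exact attributions are known; the helper `deletion_curve` is made up for this example. The most salient pixels are removed in chunks, and the model score is recorded after each step; a faithful explanation should make the score drop quickly.

```python
import numpy as np


def deletion_curve(image, attributions, score_fn, baseline_value=0.0, steps=4):
    """Remove the most salient pixels in chunks and record the score after each step."""
    order = np.argsort(-attributions.ravel())  # most important pixels first
    edited = image.copy().ravel()
    scores = [score_fn(edited.reshape(image.shape))]
    chunk = int(np.ceil(order.size / steps))
    for start in range(0, order.size, chunk):
        edited[order[start:start + chunk]] = baseline_value
        scores.append(score_fn(edited.reshape(image.shape)))
    return scores


# hypothetical linear "model"; its exact attributions are weights * pixels
rng = np.random.default_rng(0)
weights = rng.random((4, 4))
image = rng.random((4, 4))


def score_fn(img):
    return float(np.sum(weights * img))


good_attr = weights * image  # faithful attributions
random_attr = rng.random((4, 4))  # uninformative attributions

curve_good = deletion_curve(image, good_attr, score_fn)
curve_random = deletion_curve(image, random_attr, score_fn)
print(curve_good)
print(curve_random)
```

A smaller area under the deletion curve indicates that the attribution map ranks the truly influential pixels first.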

## Thank you for participating!