# Explanations in AI: Methods, Stakeholders and Pitfalls
<h3 align="center">Text Data</h3>
<br>


---
This notebook shows how to use a pre-trained text classifier to predict the sentiment of movie reviews and how to generate explanations for these predictions using different methods.

---
__Dataset:__ 
The [IMDB movie review dataset](https://huggingface.co/datasets/imdb) is a collection of 50,000 movie reviews that were collected from the Internet Movie Database (IMDb). The reviews are labeled as either positive or negative, and they are all associated with a specific movie. The dataset was created by Maas et al. (2011) and is publicly available.

---
<a name="0">__Contents of Notebook:__</a>

1. <a href="#1">Downloading the Dataset</a>
2. <a href="#2">Loading and Inspecting the Model</a>
3. <a href="#3">Explanations</a> <br>
    3.1. <a href="#31">Kernel SHAP</a> <br>
    3.2. <a href="#32">Integrated Gradients</a> <br>
4. <a href="#4">Potential Issues</a>

---
Attribution: Andrew L. Maas, Raymond E. Daly, Peter T. Pham, Dan Huang, Andrew Y. Ng, and Christopher Potts. 2011. Learning Word Vectors for Sentiment Analysis. In Proceedings of the 49th Annual Meeting of the Association for Computational Linguistics.

This notebook uses modified code snippets from [Captum](https://captum.ai/) and [HuggingFace](https://huggingface.co/).

In [None]:
# Operational libraries
import sys
import os
from pathlib import Path

# Jupyter(lab) libraries
if not sys.warnoptions:
    import warnings

    warnings.filterwarnings("ignore")

# Reshaping/basic libraries
import numpy as np
import random
import pandas as pd

# Neural Net libraries
import torch
import torch.nn.functional as F
import transformers
import tensorflow as tf

# Globals
import logging

tf.get_logger().setLevel(logging.ERROR)

# Explainability libraries
import captum
from captum.attr import visualization

# Helper libraries
from IPython.display import display, HTML
import datasets

# Store pretrained models and datasets
cache_dir = Path(".cache")
cache_dir.mkdir(parents=True, exist_ok=True)


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


set_seed(1)

## 1. <a name="1">Downloading the Dataset</a>
(<a href="#0">Go to top</a>)

Let's download and unpack the dataset.

In [None]:
%%capture

# load the dataset
dataset = datasets.load_dataset("imdb", cache_dir=cache_dir/"datasets")

## 2. <a name="2">Loading and Inspecting the Model</a>
(<a href="#0">Go to top</a>)

Let's inspect the dataset before we load the model.

In [None]:
# add mapping from label class to true sentiment
label_to_sentiment = {0: "Negative", 1: "Positive"}


# create helper function to display example instances, labels (and predictions)
def display_instance_with_label(
    input_text: str, true_label: int, *, pred_label: int = None, pred_prob: float = None
):
    instance_html = (
        f"<b>Ground Truth Sentiment:</b> {label_to_sentiment[true_label]}<br><br>"
    )
    if pred_label is not None:
        pred_str = label_to_sentiment[pred_label]
        if pred_prob is not None:
            pred_str = f"{pred_str} (Probability: {pred_prob:0.2f})"
        instance_html += f"<b>Predicted Sentiment:</b> {pred_str}<br><br>"
    instance_html += f"<b>Review Text:</b> {input_text}"
    instance_html = f"<pre>{instance_html}</pre>"
    display(HTML(instance_html))

Let us inspect an input. Go ahead and select different inputs to show the text and the sentiment.

In [None]:
# specify index from dataset to inspect
idx_to_inspect = 99

# load the example based on index
instance = dataset["train"][idx_to_inspect]

# show the example review and sentiment
display_instance_with_label(instance["text"], instance["label"])

----

We will use a model trained as part of [this sequence classification tutorial](https://huggingface.co/docs/transformers/main/tasks/sequence_classification) on HuggingFace. The tutorial takes a pre-trained `distilbert-base-uncased` model and fine-tunes it on the IMDB sentiment classification task.


If you want to train / fine-tune your own model, follow the instructions in the tutorial.

In [None]:
%%capture

# specify model name to load from HuggingFace
model_name = "stevhliu/my_awesome_model"

# instantiate model
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    model_name, cache_dir=cache_dir / "models"
)

Before we start interacting with the model, we also need to load the _tokenizer_ associated with this model. You might wonder what the role of the tokenizer is.

The tokenizer converts any input text into features, known as tokens, that the model can recognize. It does this by breaking the input words down into sub-words. Let us load the tokenizer first:

In [None]:
%%capture

# load the tokenizer that was used to train the model
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name, cache_dir=cache_dir / "models"
)

Let us inspect a couple of examples.

In [None]:
# sample sentence
text1 = "What is amortization"
text2 = "This is quantization"

# apply tokenization
tokens1 = tokenizer.tokenize(text1)
tokens2 = tokenizer.tokenize(text2)

# show output
print(f"{text1:30} -> {tokens1}")
print(f"{text2:30} -> {tokens2}")

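Notice how the tokenizer breaks long words into sub-word pieces, and how both words end in the same tokens `ti` and `zation` (WordPiece tokenizers such as this one mark continuation pieces with a `##` prefix).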
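
To build some intuition for how such a split can come about, here is a minimal sketch of greedy longest-match-first sub-word splitting, the strategy used by WordPiece tokenizers. The function `wordpiece_tokenize` and the vocabulary `toy_vocab` are hypothetical and written for illustration only. The real tokenizer has a much larger vocabulary and extra preprocessing, so its output may differ.

```python
def wordpiece_tokenize(word, vocab, unk_token="[UNK]"):
    """Greedy longest-match-first split of a single lowercase word."""
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        # shrink the candidate from the right until it is in the vocabulary
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # mark continuation pieces
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            # no piece fits, so the whole word maps to the unknown token
            return [unk_token]
        pieces.append(match)
        start = end
    return pieces


# Hypothetical toy vocabulary, NOT the real DistilBERT vocabulary
toy_vocab = {"amor", "quan", "##ti", "##zation", "what", "is", "this"}

toy_tokens1 = wordpiece_tokenize("amortization", toy_vocab)
toy_tokens2 = wordpiece_tokenize("quantization", toy_vocab)
toy_tokens3 = wordpiece_tokenize("xyz", toy_vocab)

print(toy_tokens1)
print(toy_tokens2)
print(toy_tokens3)
```

With this toy vocabulary, both long words share their trailing pieces, and a word with no matching piece falls back to the unknown token.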

For convenience, each token is mapped onto a numeric ID. You can read more about tokenization [here](https://huggingface.co/docs/transformers/main/preprocessing).

In [None]:
# each token is assigned an ID, show IDs for tokens
tokenizer.convert_tokens_to_ids(tokens1)

Now, let us pass a real input through the model. The process is as follows: we tokenize the input text into token IDs (and an attention mask, which you can read about [here](https://huggingface.co/docs/transformers/glossary#attention-mask)). The model converts these token IDs into embeddings, which it then translates into a prediction.
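
The last step, going from the model's raw outputs (logits) to a predicted class and a probability, can be sketched in plain Python. The logits below are made-up numbers, not output from the notebook's model, and `toy_softmax` is a hand-written helper used only for illustration.

```python
import math


def toy_softmax(logits):
    """Numerically stable softmax over a list of logits."""
    shift = max(logits)  # subtracting the max avoids overflow in exp
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical logits for [Negative, Positive]; not real model output
toy_logits = [-1.2, 2.3]

toy_probs = toy_softmax(toy_logits)
# the predicted class is the index of the largest probability
toy_pred_idx = max(range(len(toy_probs)), key=toy_probs.__getitem__)

print(label := ["Negative", "Positive"][toy_pred_idx], round(toy_probs[toy_pred_idx], 3))
```

Note that the softmax has to be taken over the logits of all classes. Applying it to a single logit would always return a probability of 1.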

In [None]:
def pred_logit_from_token_ids(input_ids):
    """Given the input IDs, get the model output logits"""
    with torch.no_grad():
        return model(input_ids=input_ids).logits


def get_pred_score_prob_from_input_ids(input_ids):
    """Given the input IDs, get the model prediction and the predicted probability"""
    # pass tokens (as IDs) to the model to get the logits for negative/positive sentiment
    pred_score = pred_logit_from_token_ids(input_ids).flatten()
    # the predicted class is the one with the largest logit
    pred_class_idx = pred_score.argmax(dim=-1).item()
    # softmax over both logits, then pick the probability of the predicted class
    # (a softmax over a single logit would always return 1.0)
    pred_prob = F.softmax(pred_score, dim=-1)[pred_class_idx].item()
    return pred_class_idx, pred_prob


def display_instance_pred_from_token_ids(input_ids):
    """Given input IDs, display the input text and the model output"""

    pred_class_idx, pred_prob = get_pred_score_prob_from_input_ids(input_ids)

    # The tokenizer might truncate the text to fit within the model's limit and
    # additionally add special tokens to it. So let us obtain the text that the model sees
    tokenized_text = tokenizer.decode(input_ids.flatten())

    display_instance_with_label(
        tokenized_text,
        instance["label"],
        pred_label=pred_class_idx,
        pred_prob=pred_prob,
    )

In [None]:
# pick another sample instance
instance = dataset["test"][0]

# tokenize instance, this will give IDs and an attention mask
input_tokenized = tokenizer(instance["text"], return_tensors="pt")

# proceed with the IDs
input_ids = input_tokenized["input_ids"]

# Display the model input and output
display_instance_pred_from_token_ids(input_ids)

## 3. <a name="3">Explanations</a>
(<a href="#0">Go to top</a>)

We will now generate explanations with various methods.

### 3.1. <a name="31">Kernel SHAP</a>
(<a href="#3">Go to Explanations</a>)

As with tabular data, we have to set a `baseline` that mimics the absence of information. Recall that SHAP works by removing different combinations of features (tokens in this case) and monitoring the effect on the model output. We will simulate the removal of a token by replacing it with the baseline token. We set this baseline to the "unknown token ID" (`tokenizer.unk_token_id`), the default token used when the tokenizer encounters a token that is not contained in its vocabulary. You could select other baselines as well. For a discussion of baselines, see [section 5 in this paper](https://arxiv.org/pdf/2106.00786.pdf).
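
To illustrate the idea of removing combinations of tokens, here is a minimal sketch that computes exact Shapley values by enumerating every subset of tokens. This is not Kernel SHAP itself: Kernel SHAP approximates these values by sampling subsets and fitting a weighted linear model, because enumeration is infeasible for real reviews. The scoring function `toy_model`, the helper names, and the three-token input are all hypothetical.

```python
from itertools import combinations
from math import factorial

UNK = "[UNK]"  # stands in for the baseline token


def toy_model(tokens):
    """Hypothetical sentiment score, NOT the notebook's classifier."""
    score = 0.0
    if "great" in tokens:
        score += 2.0
    if "boring" in tokens:
        score -= 1.5
    if "not" in tokens and "great" in tokens:
        score -= 3.0  # interaction term for "not great"
    return score


def mask_tokens(tokens, keep):
    """Replace every token whose index is not in `keep` with the baseline."""
    return [t if i in keep else UNK for i, t in enumerate(tokens)]


def exact_shapley(tokens, model):
    n = len(tokens)
    values = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            # Shapley weight for subsets of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                keep = set(subset)
                with_i = model(mask_tokens(tokens, keep | {i}))
                without_i = model(mask_tokens(tokens, keep))
                values[i] += weight * (with_i - without_i)
    return values


toy_tokens = ["not", "great", "movie"]
toy_shap = exact_shapley(toy_tokens, toy_model)

for token, value in zip(toy_tokens, toy_shap):
    print(f"{token:8} {value:+.2f}")
```

The attributions add up to the difference between the score on the full input and the score on the all-baseline input, which is the property that Kernel SHAP tries to preserve approximately.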

In [None]:
# set parallelism to 'false' to avoid warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [None]:
pred_class_idx, pred_prob = get_pred_score_prob_from_input_ids(input_ids)
baseline_token = tokenizer.unk_token_id
n_samples = 200  # try higher values and monitor how much the results vary
perturbations_per_eval = 32
ks = captum.attr.KernelShap(pred_logit_from_token_ids)

set_seed(1)
attrs = ks.attribute(
    inputs=input_ids,
    baselines=torch.full_like(input_ids, fill_value=baseline_token),
    target=pred_class_idx,
    n_samples=n_samples,
    perturbations_per_eval=perturbations_per_eval,
    show_progress=True,
)

Now, let us visualize the explanations.

In [None]:
attrs = attrs.flatten()
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
viz_data = visualization.VisualizationDataRecord(
    word_attributions=attrs,
    pred_prob=pred_prob,
    pred_class=label_to_sentiment[pred_class_idx],
    true_class=label_to_sentiment[instance["label"]],
    attr_class=label_to_sentiment[pred_class_idx],
    attr_score=attrs.sum(),
    raw_input_ids=tokens,
    convergence_score=0,  # Captum KernelSHAP does not provide this info.
)
visualization.visualize_text([viz_data])

### 3.2. <a name="32">Integrated Gradients</a>
(<a href="#3">Go to Explanations</a>)

We can also use the Integrated Gradients explainer. Recall from our tabular data notebook that this method involves taking the gradient of the model output w.r.t. the inputs. Since we cannot compute gradients w.r.t. the input tokens (due to the discrete step where tokens are mapped to input embeddings), we will take the gradients w.r.t. the input embeddings. For details on how the tokens are mapped into input embeddings, see Figure 2 in the [BERT paper](https://arxiv.org/pdf/1810.04805.pdf).
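
As a reminder of what the method computes, here is a minimal sketch of Integrated Gradients on a tiny hand-written function with an analytic gradient. The function `toy_f`, its gradient, and the midpoint Riemann sum are assumptions made for illustration. The notebook itself relies on Captum and automatic differentiation.

```python
def toy_f(x):
    """Hypothetical differentiable 'model' on a 2-dimensional input."""
    return x[0] ** 2 + 2.0 * x[0] * x[1] - x[1]


def toy_grad_f(x):
    """Analytic gradient of toy_f."""
    return [2.0 * x[0] + 2.0 * x[1], 2.0 * x[0] - 1.0]


def toy_integrated_gradients(x, baseline, n_steps=10):
    diff = [xi - bi for xi, bi in zip(x, baseline)]
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        # point on the straight line from the baseline to the input
        point = [bi + alpha * di for bi, di in zip(baseline, diff)]
        for i, g in enumerate(toy_grad_f(point)):
            avg_grad[i] += g / n_steps
    # attribution = (input - baseline) * average gradient along the path
    return [di * gi for di, gi in zip(diff, avg_grad)]


toy_x = [1.0, 2.0]
toy_baseline = [0.0, 0.0]

toy_ig = toy_integrated_gradients(toy_x, toy_baseline)
# completeness check: attributions should sum to f(x) - f(baseline)
toy_delta = sum(toy_ig) - (toy_f(toy_x) - toy_f(toy_baseline))

print(toy_ig, toy_delta)
```

The gap between the sum of the attributions and the change in the output plays the same role as the convergence delta that Captum can return. A large gap suggests that more integration steps are needed.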

In [None]:
def pred_logit_from_embeddings(embeds):
    return model(inputs_embeds=embeds).logits

In [None]:
inputs_embeds = model.distilbert.embeddings(input_ids=input_ids)
baseline = model.distilbert.embeddings(
    input_ids=torch.full_like(input_ids, fill_value=tokenizer.unk_token_id)
)
n_steps = 10

for param in model.parameters():
    param.requires_grad = True

ig = captum.attr.IntegratedGradients(pred_logit_from_embeddings)

attrs, convergence_delta = ig.attribute(
    inputs=inputs_embeds,
    baselines=baseline,
    target=pred_class_idx,
    n_steps=n_steps,
    return_convergence_delta=True,
)

The embedding corresponding to each token is a 768-dimensional vector, which means that each token has 768 importance scores. We sum these scores to get a single scalar importance score per token. You could also use other options such as the $L_2$ norm.
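
The choice of aggregation matters. The following sketch compares the signed sum with the $L_2$ norm on made-up attribution vectors (4 dimensions instead of 768). The token names and numbers are hypothetical.

```python
import math

# Hypothetical per-dimension attributions for two tokens
toy_token_attrs = {
    "good": [0.5, 0.25, 0.25, 0.0],
    "plot": [0.5, -0.5, 0.25, -0.25],
}

# signed sum keeps the direction of the effect but lets scores cancel out
toy_summed = {tok: sum(vals) for tok, vals in toy_token_attrs.items()}
# L2 norm measures overall magnitude but discards the sign
toy_l2 = {
    tok: math.sqrt(sum(v * v for v in vals)) for tok, vals in toy_token_attrs.items()
}

for tok in toy_token_attrs:
    print(f"{tok}: sum={toy_summed[tok]:.3f}  l2={toy_l2[tok]:.3f}")
```

In this toy case the positive and negative scores of the second token cancel under the sum, while the norm still reports a sizeable magnitude. The sum, on the other hand, tells us whether a token pushes towards or away from the predicted class.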

In [None]:
attrs = attrs.sum(dim=-1)

Let us visualize the importance scores.

In [None]:
attrs = attrs.flatten()
tokens = tokenizer.convert_ids_to_tokens(input_ids.flatten())
viz_data = visualization.VisualizationDataRecord(
    word_attributions=attrs,
    pred_prob=pred_prob,
    pred_class=label_to_sentiment[pred_class_idx],
    true_class=label_to_sentiment[instance["label"]],
    attr_class=label_to_sentiment[pred_class_idx],
    attr_score=attrs.sum(),
    raw_input_ids=tokens,
    convergence_score=convergence_delta,  # approximation error returned by Integrated Gradients
)
visualization.visualize_text([viz_data])

## 4. <a name="4">Potential Issues</a>
(<a href="#0">Go to top</a>)

Text explanations suffer from issues similar to those we faced with tabular and image data. For instance, randomly initialized models might have explanations that are very similar to those of a trained model (see [On the Lack of Robust Interpretability of Neural Text Classifiers](https://arxiv.org/pdf/2106.04631.pdf)). Removing unimportant features may result in gibberish inputs that the model is still very confident about (see [Pathologies of Neural Models Make Interpretations Difficult](https://aclanthology.org/D18-1407.pdf)). Similarly, [The Out-of-Distribution Problem in Explainability and Search Methods for Feature Importance Explanations](https://arxiv.org/abs/2106.00786) shows that the feature replacement mechanism used by explainers like SHAP can produce out-of-distribution inputs on which the model outputs are unreliable. This in turn could reduce the utility of the explanations.

To get a quick insight into these issues, let us compare the model output on an input and its perturbed version.

First, we print the model output on the original input.

In [None]:
instance = dataset["test"][0]
input_tokenized = tokenizer(instance["text"], return_tensors="pt")
input_ids = input_tokenized["input_ids"]
display_instance_pred_from_token_ids(input_ids)

Now, let us replace a random subset of 50% of the tokens with the `unknown` token. We do this replacement to mimic the perturbations made by explainers like SHAP.

In [None]:
set_seed(1)
idx_replace = np.random.permutation(input_ids.shape[1])[: input_ids.shape[1] // 2]
perturbed_input_ids = input_ids.clone()
perturbed_input_ids[:, idx_replace] = tokenizer.unk_token_id
display_instance_pred_from_token_ids(perturbed_input_ids)

While the text looks unreadable to the human eye, the model is still quite confident about its sentiment.

Papers like [The Out-of-Distribution Problem in Explainability and Search Methods for Feature Importance Explanations](https://arxiv.org/abs/2106.00786) and [A Benchmark for Interpretability Methods in Deep Neural Networks](https://arxiv.org/pdf/1806.10758.pdf) argue that this out-of-distribution behavior could impact the utility of explanations. Others like [True to the Model or True to the Data?](https://arxiv.org/pdf/2006.16234.pdf) argue that the utility depends on exactly what we are trying to explain.

## Thank you for participating!