# Text-to-Text Explanation: Abstractive Summarization Example

This notebook demonstrates how to generate model explanations for a text-to-text scenario using a pretrained transformer model. Below we walk through generating explanations for the pretrained model distilbart on the Extreme Summarization (XSum) dataset, provided by Hugging Face (https://huggingface.co/sshleifer/distilbart-xsum-12-6).

The first example needs only the model and tokenizer; we use the model's decoder to generate the log odds of the output tokens to be explained. The second example shows how to generate explanations for a model exposed as an API/function (text in, text out). In that case we must approximate the log odds using a text similarity model. The underlying explainer used to compute the SHAP values is the Partition explainer.
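
Both examples explain the log odds of each output token. As a minimal sketch of that quantity (plain Python, not SHAP's internal code; the helper name `log_odds` is hypothetical), the log odds of a token generated with probability p is log(p / (1 - p)):

```python
import math

def log_odds(p):
    """Return log(p / (1 - p)) for a probability strictly between 0 and 1."""
    if not 0.0 < p < 1.0:
        raise ValueError("p must be strictly between 0 and 1")
    # log1p(-p) computes log(1 - p) accurately when p is small
    return math.log(p) - math.log1p(-p)

# A probability of 0.5 gives log odds of 0; higher probabilities give positive values
for p in (0.1, 0.5, 0.9):
    print(p, round(log_odds(p), 3))
```

Working in log odds rather than raw probabilities keeps the explained quantity unbounded, so contributions from input tokens can add up without being clipped at 0 or 1.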

In [17]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5Tokenizer
import torch
from datasets import load_dataset  # needed for the load_dataset('xsum', ...) call below
import shap

### Load model and tokenizer

In [18]:
tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base").cuda()

### Load data

In [19]:
tokenizer

PreTrainedTokenizer(name_or_path='t5-base', vocab_size=32100, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'additional_special_tokens': ['<extra_id_0>', '<extra_id_1>', '<extra_id_2>', '<extra_id_3>', '<extra_id_4>', '<extra_id_5>', '<extra_id_6>', '<extra_id_7>', '<extra_id_8>', '<extra_id_9>', '<extra_id_10>', '<extra_id_11>', '<extra_id_12>', '<extra_id_13>', '<extra_id_14>', '<extra_id_15>', '<extra_id_16>', '<extra_id_17>', '<extra_id_18>', '<extra_id_19>', '<extra_id_20>', '<extra_id_21>', '<extra_id_22>', '<extra_id_23>', '<extra_id_24>', '<extra_id_25>', '<extra_id_26>', '<extra_id_27>', '<extra_id_28>', '<extra_id_29>', '<extra_id_30>', '<extra_id_31>', '<extra_id_32>', '<extra_id_33>', '<extra_id_34>', '<extra_id_35>', '<extra_id_36>', '<extra_id_37>', '<extra_id_38>', '<extra_id_39>', '<extra_id_40>', '<extra_id_41>', '<extra_id_42>', '<extra_id_43>', '<extra_id_44>', '<extra_id_45>',

In [14]:
from IPython.core.debugger import set_trace 

In [27]:
tokenizer('I hate you!!!')

{'input_ids': [27, 5591, 25, 3158, 1], 'attention_mask': [1, 1, 1, 1, 1]}

In [3]:
dataset = load_dataset('xsum', split='train')

Using custom data configuration default
Reusing dataset xsum (C:\Users\ryserrao\.cache\huggingface\datasets\xsum\default\1.1.0\128741c17b7a4c939dbf844a75a5e83deadd07deaf4b2eda2056ed8eebdb03ae)


In [4]:
# slice inputs from dataset to run model inference on
s = dataset['document'][0:1]

### Create an explainer object

In [20]:
# set_trace()
explainer = shap.Explainer(model, tokenizer)

### Compute shap values

In [24]:
s = ['Summarization: I enjoy walking with my cute dog']

In [25]:
shap_values = explainer(s)

### Visualize shap explanations

In [26]:
shap.plots.text(shap_values)

| Output token \ Input token | Sum mar | ization | : | I enjoy | walking | with | my | cute | dog | `</s>` |
|---|---|---|---|---|---|---|---|---|---|---|
| `<extra_id_0>` | -2.688 | -2.008 | -2.219 | 3.884 | 1.636 | -1.392 | 0.448 | 0.163 | -1.769 | -0.318 |
| : | 2.896 | 1.603 | 4.071 | 0.623 | 0.076 | 0.385 | 0.824 | -1.067 | 0.69 | 0.31 |
| I | -0.306 | 0.326 | -0.012 | 5.087 | -0.397 | -0.668 | 0.784 | -0.566 | -0.938 | 0.086 |
| enjoy | 0.452 | 0.015 | 0.206 | 7.111 | 0.033 | 0.134 | -0.217 | -0.169 | -0.565 | -0.018 |
| walking | -0.195 | 0.291 | 0.466 | 0.413 | 5.972 | 1.779 | -0.298 | -0.095 | 1.303 | 0.032 |
| with | 0.327 | 0.255 | 0.209 | 0.648 | 1.785 | 3.855 | -0.017 | 0.196 | 0.324 | 0.083 |
| my | 0.08 | 0.35 | 0.44 | -0.16 | 0.635 | 0.019 | 4.096 | 0.004 | 0.395 | 0.025 |
| cute | -0.12 | 0.29 | 0.367 | 0.106 | 0.208 | 0.803 | 2.99 | 6.651 | -0.533 | 0.027 |
| dog | 0.519 | 0.056 | 0.181 | -0.032 | 0.378 | 0.225 | 0.134 | 0.722 | 3.495 | 0.208 |
| . | -0.054 | 0.445 | 0.556 | 0.814 | 0.08 | 0.262 | -0.045 | -0.167 | 0.665 | 0.134 |
| `<extra_id_1>` | 0.039 | 0.81 | 0.648 | 0.716 | 0.131 | 0.242 | -0.207 | -0.247 | 0.805 | 0.147 |
| : | -0.225 | 2.064 | 4.062 | 0.409 | 0.83 | 0.247 | 0.354 | 0.512 | 3.572 | 0.037 |
| I | -1.327 | 0.394 | 0.388 | 4.25 | -0.092 | -0.435 | -0.369 | 0.003 | 0.446 | 0.04 |
| enjoy | -0.178 | -0.164 | 0.005 | 4.582 | 0.065 | 0.12 | -0.8 | -0.426 | -0.514 | 0.045 |
| walking | -0.464 | 0.07 | 0.25 | -1.138 | 3.015 | 0.815 | -0.217 | 0.533 | 0.333 | 0.148 |
| with | -0.341 | 0.049 | 0.198 | -0.352 | 0.703 | 2.017 | -0.204 | 0.784 | 0.496 | 0.184 |
| my | 0.006 | 0.341 | 0.438 | -0.72 | 0.494 | -0.224 | 2.496 | -0.375 | 0.682 | -0.12 |
| cute | -0.459 | 0.182 | 0.361 | -1.145 | -0.066 | 0.819 | 1.003 | 3.879 | -1.133 | -0.081 |
| dog | 0.017 | 0.014 | 0.139 | -0.216 | 0.305 | 0.214 | -0.021 | 0.614 | 1.43 | 0.07 |


### API

Below we demonstrate generating explanations for a model exposed as an API/function. Since this is a model-agnostic case, we use a text similarity model to approximate the log odds of generating the output text, and those log odds are then used to compute the SHAP explanations.

In [8]:
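# Define the function to explain: takes an array of input strings, returns an array of generated strings
def f(x):
    # Tokenize the batch (padded to equal length) and move the tensors to the GPU, where the model lives
    inputs = tokenizer(x.tolist(), return_tensors="pt", padding=True).to('cuda')
    with torch.no_grad():
        out = model.generate(**inputs)
    # Decode each generated sequence back to plain text
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out]
    return np.array(sentence)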

For the model-agnostic case, we wrap the model to be explained with the shap.models.TeacherForcing class and define the text similarity model and tokenizer. The TeacherForcing class uses the similarity model to approximate the log odds of generating the output text produced by the model (the function f).
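
To make the scoring idea concrete, here is a minimal sketch of teacher-forced scoring in plain Python. It is an illustration under assumptions, not the shap.models.TeacherForcing implementation: the helper names `softmax` and `target_log_odds` are hypothetical, and the logits are made-up numbers standing in for what a similarity model would produce at each decoding step when the target tokens are fed in.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities (numerically stable)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def target_log_odds(step_logits, target_ids):
    """For each decoding step, return the log odds of the forced target token."""
    scores = []
    for logits, tid in zip(step_logits, target_ids):
        p = softmax(logits)[tid]
        scores.append(math.log(p) - math.log1p(-p))
    return scores

# Toy vocabulary of 3 tokens and two decoding steps (made-up logits)
step_logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
target_ids = [0, 2]  # the output tokens we force the scorer to follow
scores = target_log_odds(step_logits, target_ids)
print(scores)
```

Each score says how strongly the scoring model favours the forced token at that step; masking parts of the input and re-scoring is what lets an explainer attribute changes in these values to input tokens.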

We also have to define a Text masker with mask_token="..." and collapse_mask_token=True, which cues the algorithm to use text infilling while masking.
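
As a rough illustration of what collapsing the mask token means, the sketch below replaces hidden tokens with "..." and optionally merges runs of adjacent masks into one. This is an assumption-labelled toy on whitespace-separated words, not the shap.maskers.Text implementation; the function `mask_tokens` and its parameters are hypothetical.

```python
def mask_tokens(tokens, keep, mask_token="...", collapse=True):
    """Replace tokens where keep is False with mask_token.

    With collapse=True, a run of consecutive masked tokens yields one mask_token.
    """
    out = []
    prev_masked = False
    for tok, k in zip(tokens, keep):
        if k:
            out.append(tok)
            prev_masked = False
        else:
            # Emit a mask unless we are collapsing and the previous token was also masked
            if not (collapse and prev_masked):
                out.append(mask_token)
            prev_masked = True
    return " ".join(out)

tokens = ["I", "enjoy", "walking", "with", "my", "cute", "dog"]
keep = [True, True, False, False, False, True, True]
print(mask_tokens(tokens, keep))                  # I enjoy ... cute dog
print(mask_tokens(tokens, keep, collapse=False))  # I enjoy ... ... ... cute dog
```

The collapsed form reads like a natural elision, which is closer to the kind of text an infilling-style model expects than a long run of repeated mask tokens.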

In [9]:
# wrap the function with the TeacherForcing class, using the model and tokenizer as the similarity model
teacher_forcing_model = shap.models.TeacherForcing(f, similarity_model=model, similarity_tokenizer=tokenizer)
# create a Text masker that collapses adjacent masked tokens into a single "..."
masker = shap.maskers.Text(tokenizer, mask_token="...", collapse_mask_token=True)

### Create an explainer object using wrapped model and Text masker

In [10]:
explainer_model_agnostic = shap.Explainer(teacher_forcing_model, masker)

explainers.Partition is still in an alpha state, so use with caution...


### Compute shap values

In [11]:
shap_values_model_agnostic = explainer_model_agnostic(s)

Partition explainer: 2it [00:31, 15.99s/it]


### Visualize shap explanations

In [12]:
shap.plots.text(shap_values_model_agnostic)

Input text chunks (table columns):

1. The problem is affecting people using the older versions of the PlayStation 3 ,
2. called the " Fat " model .
3. The problem isn 't affecting the newer PS
4. 3 Slim systems that have been on sale since September last year .
5. Sony have also said they are aiming to
6. have the problem fixed shortly but is advising
7. some users to avoid using their console for the time being ." We hope to resolve this problem within the next 24 hours ," a statement reads .
8. " In the meantime , if you have a model other than the new slim PS 3 ,
9. we advise that you do not use your PS 3 system ,
10. as doing so may result in errors in some functionality , such as recording obtained trophies ,
11. and not being able to restore certain data ." We believe we have identified that this
12. problem is being caused by a bug in the clock functionality incorporated in the system ." The PlayStation Network is used by millions of people around the world .
13. It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores .

| Output token | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Sony | 1.192 | 0.374 | 0.53 | 0.373 | 1.011 | 0.138 | -0.775 | 0.965 | 0.665 | 0.006 | 0.315 | 1.309 | 0.553 |
| has | 0.246 | 0.125 | 0.207 | -0.01 | 0.305 | 0.168 | 0.266 | 0.297 | 0.334 | 0.036 | 0.227 | 0.465 | 0.153 |
| said | 0.259 | 0.117 | 0.139 | 0.147 | 0.161 | 0.05 | 0.294 | 0.377 | 0.005 | -0.062 | 0.236 | 0.304 | -0.057 |
| that | 0.433 | 0.181 | 0.364 | 0.355 | 0.304 | 0.494 | 0.06 | 0.156 | 0.134 | 0.098 | 0.023 | 0.622 | -0.041 |
| a | 0.17 | 0.058 | 0.198 | 0.05 | 0.027 | 0.096 | -0.15 | 0.423 | -0.002 | -0.097 | 0.341 | 0.36 | 0.245 |
| bug | 0.802 | 0.005 | 0.443 | -0.28 | 0.279 | 0.792 | 0.48 | 0.579 | 0.163 | 0.625 | 1.695 | 2.734 | 0.04 |
| in | 0.6 | 0.209 | 0.036 | 0.077 | 0.643 | 0.309 | 0.22 | 0.37 | -0.088 | 0.018 | 0.526 | 0.613 | -0.065 |
| its | -0.498 | -0.162 | -0.281 | -0.178 | -0.664 | -0.839 | 0.039 | 0.311 | -0.072 | 0.095 | 0.167 | 0.497 | -0.325 |
| PlayStation | 1.288 | 0.147 | 0.959 | 0.916 | 0.104 | -0.329 | 0.726 | 1.03 | 0.969 | 0.544 | 0.638 | 2.045 | -0.192 |
| 3 | 0.536 | 0.382 | 0.439 | 0.785 | -0.557 | -0.527 | 0.623 | 0.568 | 0.924 | 0.568 | 0.758 | 0.955 | -1.279 |
| console | -0.106 | -0.064 | 0.128 | -0.064 | 0.089 | 0.195 | 1.363 | 0.728 | 0.385 | 0.201 | -0.133 | 0.163 | -0.264 |
| is | -0.029 | 0.058 | 0.019 | 0.096 | 0.227 | 0.287 | 0.466 | -0.264 | -0.157 | -0.255 | -0.084 | 0.258 | -0.112 |
| causing | 0.432 | 0.055 | 0.282 | 0.292 | 0.255 | 0.376 | 0.277 | 0.116 | -0.114 | -0.04 | 0.151 | 0.623 | -0.121 |
| some | 0.387 | 0.12 | 0.161 | -0.025 | 0.392 | 0.417 | 0.309 | -0.03 | -0.068 | -0.068 | 0.76 | 0.75 | -0.077 |
| users | 0.101 | 0.265 | -0.114 | 0.082 | 0.053 | 0.193 | 1.05 | 0.162 | -0.388 | -0.127 | 0.796 | 0.964 | 0.69 |
| to | -0.091 | -0.009 | 0.015 | 0.071 | 0.16 | 0.064 | 0.193 | -0.021 | -0.043 | -0.069 | -0.064 | -0.037 | 0.216 |
| lose | 0.51 | 0.071 | 0.095 | -0.038 | 0.913 | 1.022 | -0.087 | 0.006 | 0.012 | -0.101 | 0.713 | 1.038 | -0.134 |
| access | 0.58 | -0.036 | 0.489 | 0.364 | 0.092 | -0.13 | -0.148 | 0.582 | 0.268 | 0.809 | 1.171 | 2.335 | 0.18 |
| to | -0.015 | -0.002 | 0.01 | 0.007 | -0.015 | -0.057 | 0.021 | 0.008 | -0.004 | -0.0 | 0.011 | -0.011 | 0.011 |
| the | 0.203 | 0.141 | -0.125 | 0.118 | 0.271 | 0.236 | 0.463 | 0.405 | 0.065 | 0.017 | -0.563 | -0.468 | 0.301 |
| PlayStation | -0.017 | 0.091 | -0.024 | 0.182 | -0.063 | -0.155 | -0.166 | 0.442 | 0.454 | 0.428 | -0.167 | -0.053 | 0.397 |
| Network | 0.461 | 0.306 | 0.315 | 0.556 | 0.547 | 0.507 | 0.714 | 0.695 | 0.283 | 0.591 | 0.402 | 1.605 | 0.257 |
| . | -0.001 | 0.129 | 0.205 | 0.215 | 0.19 | 0.084 | 0.662 | 0.401 | 0.131 | -0.08 | -0.317 | -0.233 | 0.332 |
