# Text-to-Text Explanation: Abstractive Summarization Example

This notebook demonstrates how to generate model explanations for a text-to-text scenario using a pretrained transformer model. Below, we generate explanations for a pretrained distilbart model for the Extreme Summarization (XSum) dataset, available from Hugging Face (https://huggingface.co/sshleifer/distilbart-xsum-12-6).

The first example needs only the model and tokenizer: we use the model's decoder to generate the log odds of the output tokens to be explained. The second example shows how to generate explanations for a model exposed as an API/function (text in, text out). In that case, we approximate the log odds using a text similarity model. The underlying explainer used to compute the SHAP values is the Partition explainer.
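To make "log odds of the output tokens" concrete, here is a minimal NumPy sketch. It assumes we already have per-step decoder logits for a fixed output sequence; the toy numbers and the helper name `target_log_odds` are illustrative assumptions, not shap's internal implementation.

```python
import numpy as np

def target_log_odds(logits, target_ids):
    """Return log(p / (1 - p)) for the target token at each decoding step.

    logits: array of shape (num_steps, vocab_size), hypothetical decoder outputs
    target_ids: array of shape (num_steps,), the token emitted at each step
    """
    logits = np.asarray(logits, dtype=float)
    target_ids = np.asarray(target_ids)
    # numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # log-probability of the token actually emitted at each step
    log_p = log_probs[np.arange(len(target_ids)), target_ids]
    # log odds = log p - log(1 - p)
    return log_p - np.log1p(-np.exp(log_p))

# toy example: 2 decoding steps over a 3-token vocabulary
toy_logits = np.array([[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
toy_targets = np.array([0, 1])
print(target_log_odds(toy_logits, toy_targets))
```

Each SHAP value in the examples below can then be read as the change in such a log-odds score attributed to a chunk of the input text.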

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import nlp
import shap

### Load model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-12-6")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-12-6").cuda()

### Load data

In [3]:
dataset = nlp.load_dataset('xsum',split='train')



In [4]:
# slice inputs from dataset to run model inference on
s = dataset['document'][0:1]

### Create an explainer object

In [5]:
explainer = shap.Explainer(model,tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


### Compute shap values

In [6]:
shap_values = explainer(s)

### Visualize shap explanations

In [7]:
shap.plots.text(shap_values)

Output token,"The problem is affecting people using the older versions of the PlayStation 3 ,","called the "" Fat "" model .",The problem isn 't affecting the newer PS,3 Slim systems that have been on sale since September last year .,Sony have also said they are aiming to have the problem fixed shortly but is advising,"some users to avoid using their console for the time being ."" We hope to resolve this problem within the next 24 hours ,"" a statement reads .",""" In the meantime ,","if you have a model other than the new slim PS 3 ,","we advise that you do not use your PS 3 system ,","as doing so may result in errors in some functionality , such as recording obtained trophies ,","and not being able to restore certain data ."" We believe we have identified that this","problem is being caused by a bug in the clock functionality incorporated in the system ."" The PlayStation Network is used by millions of people around the world .",It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores . Ċ
Sony,1.012,0.134,0.561,0.48,1.737,1.328,0.058,0.632,0.42,0.115,0.404,1.488,0.475
has,0.337,0.183,0.195,0.381,0.174,0.414,0.163,0.471,0.31,0.306,0.186,0.223,0.458
said,0.358,-0.112,0.202,-0.062,0.531,0.502,0.231,0.269,-0.117,-0.148,0.562,0.572,-0.447
that,0.292,0.14,0.204,0.346,-0.288,-0.169,-0.008,-0.024,0.16,0.234,0.235,0.292,0.033
a,0.354,0.148,0.193,-0.119,0.294,0.675,0.003,0.089,0.317,-0.161,0.307,1.146,-0.111
bug,1.176,-0.113,0.642,0.005,0.127,0.352,0.047,-0.355,0.211,0.712,1.785,3.215,-0.306
in,0.188,-0.051,0.214,0.195,0.056,0.051,0.133,0.312,0.196,0.59,0.257,1.152,0.092
its,-0.449,-0.282,-0.192,-0.148,-0.216,-0.055,0.009,-0.143,-0.188,-0.038,0.14,0.818,0.074
PlayStation,1.406,-0.024,0.782,0.318,1.238,1.292,0.113,0.958,0.943,0.115,0.254,3.228,0.472
3,1.005,0.478,0.66,0.789,0.679,0.17,0.29,1.001,1.764,0.435,-0.191,-1.34,-0.406
console,0.166,0.089,0.175,0.027,0.17,1.815,0.031,0.194,0.324,-0.021,-0.201,0.068,-0.345
is,0.246,0.095,0.087,-0.007,0.168,0.101,0.027,-0.079,-0.065,-0.177,0.208,0.788,-0.087
causing,0.311,0.025,0.312,-0.055,0.387,0.589,-0.028,-0.109,0.325,0.284,0.676,1.205,-0.182
some,0.296,0.089,0.216,0.163,0.405,0.304,-0.004,0.217,0.129,0.392,0.518,0.417,-0.043
users,0.438,0.091,-0.1,-0.073,0.148,0.659,-0.107,-0.271,-0.273,-0.081,0.68,0.599,0.877
to,0.105,0.097,-0.002,0.028,0.0,0.177,0.017,0.058,-0.088,0.015,-0.097,0.044,0.06
lose,0.205,0.086,0.267,0.233,0.725,-0.182,-0.063,-0.084,0.113,-0.367,1.017,0.57,0.089
access,0.66,-0.193,0.378,0.215,0.211,0.88,0.112,0.116,0.308,0.164,1.083,1.015,0.271
to,0.015,-0.009,0.029,0.012,0.06,0.01,-0.002,0.01,-0.016,-0.011,-0.056,0.031,-0.022
the,-0.511,-0.244,-0.184,-0.083,0.017,0.432,0.072,0.178,-0.015,0.038,-1.119,0.72,0.436
PlayStation,-0.054,-0.129,-0.123,-0.295,0.085,0.266,0.184,0.507,0.285,0.236,-0.167,4.178,0.41
Network,0.375,0.05,0.036,0.154,-0.493,-0.066,0.022,0.109,0.101,0.29,0.818,5.419,1.274
.,0.052,0.046,-0.044,0.106,0.163,0.209,-0.059,-0.135,-0.041,-0.188,0.329,0.825,0.054
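One way to read the table above programmatically is to find, for each output token, the input chunk with the largest attribution. The sketch below copies two rows ("bug" and "Network") from the table by hand; the chunk labels are abbreviated for readability, and the variable names are illustrative.

```python
import numpy as np

# Abbreviated labels for the 13 input chunks (the table's columns, in order).
chunks = [
    "The problem is affecting people ...",
    'called the " Fat " model .',
    "The problem isn 't affecting ...",
    "3 Slim systems ...",
    "Sony have also said ...",
    "some users to avoid ...",
    '" In the meantime ,',
    "if you have a model ...",
    "we advise that ...",
    "as doing so may result ...",
    "and not being able ...",
    "problem is being caused by a bug ...",
    "It allows users to play ...",
]

# Attribution rows copied from the table for two output tokens.
rows = {
    "bug": [1.176, -0.113, 0.642, 0.005, 0.127, 0.352, 0.047,
            -0.355, 0.211, 0.712, 1.785, 3.215, -0.306],
    "Network": [0.375, 0.05, 0.036, 0.154, -0.493, -0.066, 0.022,
                0.109, 0.101, 0.29, 0.818, 5.419, 1.274],
}

# For each output token, pick the input chunk with the largest SHAP value.
top = {tok: chunks[int(np.argmax(vals))] for tok, vals in rows.items()}
for tok, chunk in top.items():
    print(f"{tok!r} is most strongly attributed to: {chunk}")
```

For both tokens, the largest value falls on the chunk that mentions the bug and the PlayStation Network, which matches the content of the generated summary.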


### API

Below we demonstrate generating explanations for a model that is exposed as an API/function. Since this is a model-agnostic case, we use a text similarity model to approximate the log odds of generating the output text, which are then used to compute the SHAP explanations.

In [8]:
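# Define a text-in, text-out function that wraps the summarization model
def f(x):
    # encode the input text and move the token ids to the GPU
    input_ids = torch.tensor([tokenizer.encode(x)]).cuda()
    with torch.no_grad():
        out = model.generate(input_ids)
    # decode the first generated sequence back to text
    sentence = [tokenizer.decode(g, skip_special_tokens=True) for g in out][0]
    return sentence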

For the model-agnostic case, we wrap the model to be explained with the shap.models.TeacherForcingLogits class and define the text similarity model and tokenizer. The TeacherForcingLogits class uses the similarity model to approximate the log odds of generating the output text from the model (the function f).

We also have to create a Text masker with mask_token="..." and collapse_mask_token=True, which cues the algorithm to use text infilling while masking.
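To illustrate what collapsing the mask token means, the following pure-Python sketch replaces hidden tokens with "..." and merges each run of consecutive masked tokens into a single "...". It is a simplified illustration that splits on whitespace; the helper name `mask_and_collapse` is hypothetical, and this is not the implementation used by `shap.maskers.Text`.

```python
def mask_and_collapse(tokens, keep, mask_token="..."):
    """Hide tokens where keep is False, using one mask token per masked run."""
    out = []
    prev_masked = False
    for token, kept in zip(tokens, keep):
        if kept:
            out.append(token)
            prev_masked = False
        elif not prev_masked:
            # first masked token of a run: emit a single mask token
            out.append(mask_token)
            prev_masked = True
        # later masked tokens in the same run add nothing
    return " ".join(out)

tokens = "Sony have also said they are aiming to".split()
keep = [True, False, False, True, False, False, False, True]
print(mask_and_collapse(tokens, keep))  # Sony ... said ... to
```

The masked text then looks like a sentence with gaps to be filled in, rather than a long run of repeated mask tokens.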

In [9]:
# wrap model with TeacherForcingLogits class
wrapped_model = shap.models.TeacherForcingLogits(f, similarity_model=model, similarity_tokenizer=tokenizer)
# create a Text masker
masker = shap.maskers.Text(tokenizer, mask_token = "...", collapse_mask_token=True)

### Create an explainer object using wrapped model and Text masker

In [10]:
explainer_model_agnostic = shap.Explainer(wrapped_model,masker)

explainers.Partition is still in an alpha state, so use with caution...


### Compute shap values

In [11]:
shap_values_model_agnostic = explainer_model_agnostic(s)

Partition explainer: 2it [00:49, 24.71s/it]                                                                            


### Visualize shap explanations

In [12]:
shap.plots.text(shap_values_model_agnostic)

Output token,"The problem is affecting people using the older versions of the PlayStation 3 ,","called the "" Fat "" model .",The problem isn 't affecting the newer PS,3 Slim systems that have been on sale since September last year .,Sony have also said they are aiming to,have the problem fixed shortly but is advising,"some users to avoid using their console for the time being ."" We hope to resolve this problem within the next 24 hours ,"" a statement reads .",""" In the meantime , if you have a model other than the new slim PS 3 ,","we advise that you do not use your PS 3 system ,","as doing so may result in errors in some functionality , such as recording obtained trophies ,","and not being able to restore certain data ."" We believe we have identified that this","problem is being caused by a bug in the clock functionality incorporated in the system ."" The PlayStation Network is used by millions of people around the world .",It allows users to play their friends at games like Fifa over the internet and also do things like download software or visit online stores . Ċ
Sony,1.192,0.374,0.53,0.373,1.011,0.138,-0.775,0.965,0.665,0.006,0.315,1.309,0.553
has,0.246,0.125,0.207,-0.01,0.305,0.168,0.266,0.297,0.334,0.036,0.227,0.465,0.153
said,0.259,0.117,0.139,0.147,0.161,0.05,0.294,0.377,0.005,-0.062,0.236,0.304,-0.057
that,0.433,0.181,0.364,0.355,0.304,0.494,0.06,0.156,0.134,0.098,0.023,0.622,-0.041
a,0.17,0.058,0.198,0.05,0.027,0.096,-0.15,0.423,-0.002,-0.097,0.341,0.36,0.245
bug,0.802,0.005,0.443,-0.28,0.279,0.792,0.48,0.579,0.163,0.625,1.695,2.734,0.04
in,0.6,0.209,0.036,0.077,0.643,0.309,0.22,0.37,-0.088,0.018,0.526,0.613,-0.065
its,-0.498,-0.162,-0.281,-0.178,-0.664,-0.839,0.039,0.311,-0.072,0.095,0.167,0.497,-0.325
PlayStation,1.288,0.147,0.959,0.916,0.104,-0.329,0.726,1.03,0.969,0.544,0.638,2.045,-0.192
3,0.536,0.382,0.439,0.785,-0.557,-0.527,0.623,0.568,0.924,0.568,0.758,0.955,-1.279
console,-0.106,-0.064,0.128,-0.064,0.089,0.195,1.363,0.728,0.385,0.201,-0.133,0.163,-0.264
is,-0.029,0.058,0.019,0.096,0.227,0.287,0.466,-0.264,-0.157,-0.255,-0.084,0.258,-0.112
causing,0.432,0.055,0.282,0.292,0.255,0.376,0.277,0.116,-0.114,-0.04,0.151,0.623,-0.121
some,0.387,0.12,0.161,-0.025,0.392,0.417,0.309,-0.03,-0.068,-0.068,0.76,0.75,-0.077
users,0.101,0.265,-0.114,0.082,0.053,0.193,1.05,0.162,-0.388,-0.127,0.796,0.964,0.69
to,-0.091,-0.009,0.015,0.071,0.16,0.064,0.193,-0.021,-0.043,-0.069,-0.064,-0.037,0.216
lose,0.51,0.071,0.095,-0.038,0.913,1.022,-0.087,0.006,0.012,-0.101,0.713,1.038,-0.134
access,0.58,-0.036,0.489,0.364,0.092,-0.13,-0.148,0.582,0.268,0.809,1.171,2.335,0.18
to,-0.015,-0.002,0.01,0.007,-0.015,-0.057,0.021,0.008,-0.004,-0.0,0.011,-0.011,0.011
the,0.203,0.141,-0.125,0.118,0.271,0.236,0.463,0.405,0.065,0.017,-0.563,-0.468,0.301
PlayStation,-0.017,0.091,-0.024,0.182,-0.063,-0.155,-0.166,0.442,0.454,0.428,-0.167,-0.053,0.397
Network,0.461,0.306,0.315,0.556,0.547,0.507,0.714,0.695,0.283,0.591,0.402,1.605,0.257
.,-0.001,0.129,0.205,0.215,0.19,0.084,0.662,0.401,0.131,-0.08,-0.317,-0.233,0.332
