# Text to Text Explanation: Open Ended Text Generation Using GPT2

This notebook demonstrates how to generate model explanations for open-ended text generation with GPT-2. We use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2), pass it custom generation settings, and explain how each token of an initial prompt contributes to the text the model generates.

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

### Load model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
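# Use the GPU when one is available; otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)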

Next, we configure the model. We need to specify whether the model is a decoder or an encoder-decoder.
This is set through the 'is_decoder' or 'is_encoder_decoder' attribute of the model's config.
We can also set custom generation parameters, which are used when decoding the output text.

In [3]:
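# Mark the model as a decoder and set the text generation parameters
model.config.is_decoder = True
model.config.text_generation_params = {
    "do_sample": True,   # sample tokens instead of decoding greedily
    "max_length": 50,    # maximum total length in tokens, prompt included
    "temperature": 0.7,  # values below 1 sharpen the sampling distribution
    "top_k": 0,          # 0 disables top-k filtering
}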
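
To make the roles of "temperature" and "top_k" concrete, here is a minimal pure-Python sketch of temperature scaling and top-k filtering applied to a toy logit vector. The function names and the toy logits are illustrative assumptions; this is not the Hugging Face implementation, only the general idea of how these two parameters reshape the distribution that tokens are sampled from.

```python
import math
import random


def sampling_distribution(logits, temperature=0.7, top_k=0):
    """Return the token probabilities that sampling would draw from."""
    # Dividing by a temperature below 1 widens the gaps between logits.
    scaled = [x / temperature for x in logits]
    if top_k > 0:
        # Keep the top_k highest logits and mask out the rest
        # (ties at the cutoff are all kept in this simple sketch).
        cutoff = sorted(scaled, reverse=True)[min(top_k, len(scaled)) - 1]
        scaled = [x if x >= cutoff else -math.inf for x in scaled]
    # Numerically stable softmax over the remaining logits.
    peak = max(scaled)
    weights = [math.exp(x - peak) for x in scaled]
    total = sum(weights)
    return [w / total for w in weights]


def sample_token(logits, temperature=0.7, top_k=0, rng=random):
    """Draw one token index from the adjusted distribution."""
    probs = sampling_distribution(logits, temperature, top_k)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


toy_logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
plain = sampling_distribution(toy_logits, temperature=1.0)
sharp = sampling_distribution(toy_logits, temperature=0.7)
top2 = sampling_distribution(toy_logits, temperature=0.7, top_k=2)
```

With "top_k" set to 0, as in the cell above, no tokens are filtered out, so only the temperature changes the distribution. Because sampling is random, repeated runs of the notebook can generate different continuations and therefore different explanations.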

### Define initial text

In [4]:
s = ["I enjoy walking with my cute dog"]

### Create an explainer object

In [5]:
explainer = shap.Explainer(model, tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


### Compute shap values

In [6]:
shap_values = explainer(s)

Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence


### Visualize shap explanations

In [7]:
shap.plots.text(shap_values)

Output token,I,enjoy,walking,with,my,cute,dog
.,-0.059,-0.537,-0.846,0.841,-0.352,-0.469,3.642
I,0.987,0.113,0.032,0.067,0.244,0.264,0.975
walk,0.532,-0.3,2.858,-0.163,0.388,-0.621,0.412
with,0.474,-0.133,1.357,1.399,0.129,-0.248,0.395
a,0.148,0.403,0.133,-0.169,-0.583,-0.152,0.107
dog,0.238,-0.097,0.37,0.644,-0.069,-0.388,2.897
from,-0.281,0.029,-0.217,0.209,-0.299,0.304,0.404
time,0.71,1.401,0.369,0.187,0.285,-0.483,1.009
to,0.05,-0.017,-0.041,-0.178,0.001,0.052,0.313
time,0.021,0.067,0.109,-0.048,0.005,0.013,0.107
.,-0.078,-0.036,0.189,-0.206,0.011,-0.261,0.195
He,0.049,-0.143,-0.091,-0.027,-0.184,0.297,0.612
will,-0.145,0.15,0.143,-0.069,-0.046,-0.146,0.263
talk,0.007,-0.231,-0.065,0.016,-0.107,-0.143,0.104
to,0.198,-0.113,0.078,-0.015,0.035,0.024,-0.03
me,0.299,0.18,-0.104,0.147,0.314,0.202,0.028
and,0.13,0.154,0.077,0.043,0.013,0.063,0.197
my,-0.066,0.108,-0.056,-0.086,0.141,0.241,-0.048
dogs,0.37,0.257,0.16,0.109,-0.232,-0.318,0.411
and,0.215,0.019,0.115,0.056,0.014,-0.081,-0.003
I,0.181,0.044,-0.026,-0.007,0.068,0.012,-0.084
can,0.014,0.263,-0.053,0.071,0.015,0.095,-0.04
chat,-0.306,0.376,-0.148,0.064,-0.111,0.43,0.045
with,0.055,0.044,-0.026,0.041,0.035,0.052,0.136
them,-0.087,-0.067,-0.056,-0.034,-0.046,0.05,-0.112
.,-0.209,-0.011,0.026,-0.023,-0.005,0.052,0.18
I,0.116,0.07,-0.04,-0.041,0.07,0.028,-0.023
don,0.169,-0.344,0.086,0.072,0.024,-0.282,-0.044
't,0.119,-0.179,0.048,0.004,-0.221,0.034,0.242
really,0.094,-0.021,0.068,0.043,-0.004,0.123,-0.061
know,-0.028,-0.308,0.042,-0.02,-0.003,-0.089,-0.023
where,0.099,-0.005,0.097,-0.043,-0.012,-0.043,-0.001
I,0.068,-0.006,0.037,-0.015,0.065,-0.086,-0.144
stand,0.143,-0.066,0.158,-0.006,-0.008,-0.237,0.169
with,-0.075,0.15,-0.179,0.12,-0.028,0.126,0.119
the,0.0,-0.006,-0.041,-0.018,0.007,-0.131,-0.07
dog,0.233,-0.018,0.118,0.061,-0.073,-0.162,0.128
.,0.035,-0.105,0.033,-0.048,0.011,-0.19,0.169
I,0.084,-0.024,-0.024,-0.003,0.057,-0.04,0.031
walked,-0.144,-0.416,0.624,0.031,-0.158,-0.083,0.02
my,0.113,0.18,-0.387,-0.018,0.273,0.059,0.129
dog,0.004,-0.04,-0.092,0.096,0.038,-0.17,-0.077
at,0.042,0.053,0.101,-0.059,-0.006,-0.025,-0.051
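
One way to read the table above is row by row: each row is a generated token, and each column holds the attribution of one prompt token to it. The sketch below copies three rows from the table and reports, for each, the prompt token with the largest positive attribution together with the row sum. The helper name `strongest_input` is hypothetical, and interpreting the row sum as the net shift relative to the explainer's baseline assumes the usual additive property of SHAP values.

```python
input_tokens = ["I", "enjoy", "walking", "with", "my", "cute", "dog"]

# Rows copied from the table above: output token -> SHAP value per input token.
# "dog" is the first of the "dog" rows in the table.
rows = {
    "walk": [0.532, -0.3, 2.858, -0.163, 0.388, -0.621, 0.412],
    "dog": [0.238, -0.097, 0.37, 0.644, -0.069, -0.388, 2.897],
    "walked": [-0.144, -0.416, 0.624, 0.031, -0.158, -0.083, 0.02],
}


def strongest_input(values, tokens):
    """Return the input token with the largest attribution and its value."""
    best = max(range(len(values)), key=lambda i: values[i])
    return tokens[best], values[best]


for output_token, values in rows.items():
    token, value = strongest_input(values, input_tokens)
    print(f"{output_token!r}: pushed up most by {token!r} ({value:+.3f}), "
          f"net attribution {sum(values):+.3f}")
```

In these rows the generated tokens "walk" and "walked" are driven mostly by "walking" in the prompt, and the generated "dog" by the prompt's "dog", which matches what the highlighted text plot conveys visually.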


### Another example...

In [8]:
s = ["Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth"]

In [9]:
explainer = shap.Explainer(model, tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


In [10]:
shap_values = explainer(s)

Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence


In [11]:
shap.plots.text(shap_values)

Output token,Scientists,confirmed,the,worst,possible,outcome,:,the,massive,asteroid,will,collide,with,Earth
via,0.207,0.631,-0.414,-0.414,0.018,0.018,-0.125,-0.125,-0.195,-0.195,0.454,0.71,0.839,0.128
a,0.167,0.037,-0.009,-0.009,0.2,0.2,0.314,0.314,0.149,0.149,0.666,0.685,0.693,1.365
massive,0.406,0.074,0.194,0.194,-0.034,-0.034,-0.144,-0.144,0.883,0.883,0.55,0.857,-0.08,1.303
collision,0.84,0.581,-0.091,-0.091,0.138,0.138,0.126,0.126,0.447,0.447,0.912,2.054,-0.088,1.084
with,-0.103,0.02,0.042,0.042,0.018,0.018,-0.044,-0.044,0.264,0.264,-0.146,-0.243,-0.246,0.056
the,0.009,-0.081,0.019,0.019,0.014,0.014,0.073,0.073,-0.035,-0.035,0.109,0.054,0.034,0.564
Sun,0.813,0.309,-0.061,-0.061,0.053,0.053,-0.075,-0.075,0.468,0.468,-0.013,-0.026,-0.303,2.13
.,-0.222,0.105,0.094,0.094,0.099,0.099,0.087,0.087,-0.031,-0.031,0.044,0.015,0.233,0.079
Ċ,0.263,0.452,0.107,0.107,0.112,0.112,0.161,0.161,0.037,0.037,-0.084,-0.025,0.221,-0.388
Ċ,2.436,1.267,-0.066,-0.066,-0.069,-0.069,-0.002,-0.002,0.054,0.054,-0.271,0.048,-0.101,-0.44
The,0.362,0.211,0.09,0.09,0.045,0.045,0.017,0.017,-0.002,-0.002,-0.027,0.008,0.035,-0.021
collision,1.064,0.507,0.061,0.061,0.037,0.037,0.06,0.06,0.155,0.155,0.275,0.558,-0.146,-0.526
has,0.169,-0.047,-0.047,-0.047,-0.008,-0.008,-0.002,-0.002,-0.069,-0.069,-0.009,-0.027,-0.101,0.055
already,-0.109,-0.123,0.087,0.087,0.195,0.195,0.092,0.092,0.171,0.171,0.546,0.429,0.022,-0.033
occurred,-0.54,-0.33,-0.192,-0.192,0.191,0.191,0.111,0.111,-0.167,-0.167,0.623,0.582,0.058,0.187
",",-0.38,-0.187,-0.047,-0.047,0.069,0.069,0.074,0.074,-0.026,-0.026,0.032,0.012,0.008,0.055
sending,0.55,0.392,0.052,0.052,-0.135,-0.135,-0.123,-0.123,0.277,0.277,-0.215,-0.098,-0.103,0.203
a,0.149,0.071,-0.034,-0.034,0.025,0.025,-0.006,-0.006,0.123,0.123,0.059,0.089,-0.036,-0.134
huge,-0.071,-0.044,0.051,0.051,-0.025,-0.025,-0.036,-0.036,0.285,0.285,0.035,0.037,0.017,0.14
fireball,0.239,0.197,-0.009,-0.009,-0.062,-0.062,-0.068,-0.068,0.135,0.135,0.008,0.077,-0.045,-0.113
into,0.114,0.046,0.004,0.004,0.007,0.007,0.006,0.006,0.061,0.061,-0.015,0.016,-0.007,0.027
space,0.615,0.176,-0.051,-0.051,0.081,0.081,0.011,0.011,0.322,0.322,0.018,0.102,-0.13,0.521
to,0.093,0.039,-0.03,-0.03,0.011,0.011,-0.015,-0.015,0.167,0.167,0.015,-0.002,0.065,0.205
slam,0.314,0.083,0.038,0.038,0.031,0.031,-0.008,-0.008,0.208,0.208,0.047,0.103,-0.056,0.153
into,0.26,0.146,0.021,0.021,0.035,0.035,-0.01,-0.01,0.065,0.065,0.021,0.09,-0.018,-0.028
Earth,0.31,0.107,-0.069,-0.069,0.04,0.04,-0.052,-0.052,0.16,0.16,-0.016,-0.023,-0.025,0.417
and,-0.046,-0.035,0.023,0.023,0.009,0.009,-0.009,-0.009,0.029,0.029,-0.015,-0.02,0.054,0.025
destroying,-0.307,-0.161,0.018,0.018,-0.025,-0.025,-0.02,-0.02,-0.005,-0.005,0.016,0.016,0.04,0.158
the,-0.142,-0.043,-0.007,-0.007,-0.0,-0.0,0.012,0.012,-0.059,-0.059,-0.01,-0.007,0.039,0.046
Earth,-0.116,-0.142,-0.049,-0.049,-0.015,-0.015,-0.044,-0.044,0.124,0.124,0.064,0.054,-0.048,0.252
's,0.525,0.262,0.013,0.013,0.045,0.045,0.037,0.037,-0.008,-0.008,-0.089,-0.059,-0.063,-0.209
atmosphere,0.15,0.05,-0.037,-0.037,0.045,0.045,0.026,0.026,-0.004,-0.004,0.054,0.07,0.05,0.029
.,-0.03,-0.012,-0.041,-0.041,-0.033,-0.033,-0.014,-0.014,-0.026,-0.026,0.007,0.038,-0.027,0.041
Ċ,0.433,0.304,-0.029,-0.029,-0.075,-0.075,0.024,0.024,-0.02,-0.02,0.032,0.067,0.012,-0.099
Ċ,-0.011,-0.044,0.012,0.012,0.001,0.001,0.004,0.004,0.023,0.023,-0.018,0.023,-0.002,-0.063
At,-0.117,-0.097,0.001,0.001,0.027,0.027,0.022,0.022,-0.029,-0.029,-0.011,-0.03,-0.009,0.059
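
In the table above, several adjacent prompt tokens carry identical values, for example "the" and "worst". This is consistent with the Partition explainer attributing such a pair as a single unit and splitting the credit evenly between its tokens, although the notebook itself does not state this, so treat it as an assumption. The sketch below uses three rows copied from the table and a hypothetical helper, `shared_attribution_groups`, to find these groups.

```python
input_tokens = ["Scientists", "confirmed", "the", "worst", "possible", "outcome", ":",
                "the", "massive", "asteroid", "will", "collide", "with", "Earth"]

# Rows copied from the table above (output tokens "via", "a", "massive").
rows = [
    [0.207, 0.631, -0.414, -0.414, 0.018, 0.018, -0.125, -0.125,
     -0.195, -0.195, 0.454, 0.71, 0.839, 0.128],
    [0.167, 0.037, -0.009, -0.009, 0.2, 0.2, 0.314, 0.314,
     0.149, 0.149, 0.666, 0.685, 0.693, 1.365],
    [0.406, 0.074, 0.194, 0.194, -0.034, -0.034, -0.144, -0.144,
     0.883, 0.883, 0.55, 0.857, -0.08, 1.303],
]


def shared_attribution_groups(rows, tokens):
    """Group adjacent input tokens whose values match in every given row."""
    groups = [[tokens[0]]]
    for j in range(1, len(tokens)):
        if all(row[j] == row[j - 1] for row in rows):
            groups[-1].append(tokens[j])  # same values: extend the current group
        else:
            groups.append([tokens[j]])    # values differ: start a new group
    return groups


groups = shared_attribution_groups(rows, input_tokens)
paired = [g for g in groups if len(g) > 1]
print(paired)
```

When reading such a table, a value shared by a group is best treated as the group's joint contribution divided among its members, not as independent evidence about each token.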
