# Text to Text Explanation: Open Ended Text Generation Using GPT2

This notebook demonstrates how to generate model explanations for open-ended text generation with GPT-2. We use the pretrained gpt2 model provided by Hugging Face (https://huggingface.co/gpt2) and explain the text it generates from an initial prompt under custom generation settings.

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM
import shap
import torch

### Load model and tokenizer

In [2]:
tokenizer = AutoTokenizer.from_pretrained("gpt2")
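# Use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)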

Next, we set a few model configuration options. We need to specify whether the model is a decoder or an encoder-decoder.
This is done through the 'is_decoder' or 'is_encoder_decoder' parameter in the model's config.
We can also set custom generation parameters, which are used when decoding the output text.

In [3]:
# Mark the model as a decoder and set the generation parameters
model.config.is_decoder = True
model.config.text_generation_params = {
    "do_sample": True,
    "max_length": 50,
    "temperature": 0.7,
    "top_k": 0,
}
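
To make the effect of `temperature` and `top_k` concrete, here is a minimal NumPy sketch of how such parameters are commonly applied to next-token scores before sampling. It is an illustration, not the Hugging Face implementation: the function name `next_token_probs` and the four logits are made up, and it assumes the usual convention that `top_k=0` disables top-k filtering.

```python
import numpy as np

def next_token_probs(logits, temperature=1.0, top_k=0):
    """Turn raw next-token logits into a sampling distribution (illustrative)."""
    scaled = np.asarray(logits, dtype=float) / temperature
    if top_k > 0:
        # Keep the top_k largest logits and mask out the rest
        # (ties at the cutoff are all kept in this simple version).
        kth_largest = np.sort(scaled)[-top_k]
        scaled = np.where(scaled >= kth_largest, scaled, -np.inf)
    # Numerically stable softmax
    exp = np.exp(scaled - scaled.max())
    return exp / exp.sum()

logits = [2.0, 1.0, 0.5, -1.0]  # made-up scores for a 4-token vocabulary
p_plain = next_token_probs(logits)                            # no reshaping
p_sharp = next_token_probs(logits, temperature=0.7, top_k=0)  # settings used above
p_top2 = next_token_probs(logits, temperature=0.7, top_k=2)   # only 2 candidates left

# Sampling (do_sample=True) draws from the reshaped distribution
rng = np.random.default_rng(0)
sampled = rng.choice(len(logits), p=p_sharp)
```

A temperature below 1 concentrates probability on the highest-scoring tokens, while `top_k=0` leaves the whole vocabulary eligible for sampling.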

### Define initial text

In [4]:
s = ["I enjoy walking with my cute dog"]

### Create an explainer object

In [5]:
explainer = shap.Explainer(model, tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


### Compute shap values

In [6]:
shap_values = explainer(s)

Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence
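
The call above produces, for every generated token, one attribution value per input token. As a rough intuition for what such attributions mean, the sketch below computes exact Shapley values for a tiny hypothetical scoring function by enumerating every subset of input tokens. This is not the algorithm the Partition explainer runs (it approximates attributions over a token hierarchy); `exact_shapley` and `toy_score` are illustrative names, and the scores are invented.

```python
from itertools import combinations
from math import factorial

def exact_shapley(players, value):
    """Exact Shapley values by enumerating every coalition (toy sizes only)."""
    n = len(players)
    phi = {}
    for p in players:
        others = [q for q in players if q != p]
        total = 0.0
        for size in range(n):
            # Weight of a coalition of this size in the Shapley average
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for coalition in combinations(others, size):
                with_p = value(frozenset(coalition) | {p})
                without_p = value(frozenset(coalition))
                total += weight * (with_p - without_p)
        phi[p] = total
    return phi

def toy_score(visible):
    """Hypothetical score for generating "dog" given the visible input tokens."""
    score = 0.0
    if "dog" in visible:
        score += 2.0
    if "cute" in visible:
        score += 0.5
    if "cute" in visible and "dog" in visible:
        score += 1.0  # interaction term, split between both tokens
    return score

tokens = ["my", "cute", "dog"]
phi = exact_shapley(tokens, toy_score)
```

The attributions add up to the difference between the score with all tokens visible and the score with none, which is the additivity property that makes per-token values interpretable as shares of the model output.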


### Visualize shap explanations

In [7]:
shap.plots.text(shap_values[0])

Output token,I,enjoy,walking,with,my,cute,dog
",",-0.335,-0.292,-0.256,-0.272,-0.351,-0.19,4.471
and,-0.079,0.709,0.163,0.31,-0.333,-0.013,0.339
I,1.392,0.146,-0.006,-0.259,0.904,0.084,0.018
love,0.553,1.471,-0.334,-0.225,0.196,0.919,0.132
being,0.044,0.834,0.692,-0.015,0.019,0.245,-0.31
a,0.067,-0.2,-0.603,-0.053,-0.279,-0.123,0.398
family,0.27,0.222,-0.138,0.216,0.478,-0.019,1.064
dog,-0.158,0.198,0.345,-0.008,-0.364,0.282,1.84
.,-0.1,0.102,0.155,-0.062,-0.175,-0.019,0.095
I,0.63,-0.368,0.052,-0.044,0.083,-0.081,0.075
love,0.084,-0.09,-0.058,-0.091,0.25,-0.054,-0.087
being,0.128,-0.206,-0.016,-0.026,0.095,-0.085,-0.225
able,-0.142,0.05,-0.058,0.298,0.036,0.155,-0.119
to,-0.222,-0.144,0.268,-0.067,0.198,0.347,0.088
be,0.095,-0.194,-0.109,0.014,0.041,-0.087,0.071
alone,-0.177,0.119,0.284,-0.018,0.041,0.215,-0.115
with,-0.158,0.114,-0.065,0.139,0.031,0.192,0.043
my,0.077,0.046,-0.014,0.002,0.173,-0.124,0.017
dog,0.086,0.015,0.085,0.151,-0.001,-0.247,0.597
",",0.031,0.1,-0.015,0.049,0.11,0.132,0.022
and,0.057,0.122,0.089,0.18,0.18,0.092,0.174
I,0.251,0.266,0.15,0.226,0.119,0.074,0.091
love,-0.114,0.051,-0.04,0.027,0.206,-0.038,0.065
being,-0.03,0.002,-0.044,0.115,0.105,-0.039,0.134
able,-0.098,-0.091,-0.085,0.029,0.034,-0.044,-0.145
to,-0.091,-0.049,0.074,0.133,0.177,0.176,0.018
be,0.039,-0.032,-0.041,0.169,0.144,-0.06,0.2
with,0.038,-0.018,0.03,-0.021,0.053,-0.014,0.105
a,0.002,0.089,-0.003,0.07,-0.081,-0.038,0.001
man,-0.113,-0.205,0.141,0.003,0.177,0.172,-0.013
.,-0.168,0.129,0.04,-0.071,0.025,0.143,0.079
I,0.183,0.043,0.009,-0.046,0.047,0.038,0.016
love,-0.238,0.005,-0.11,-0.065,0.163,0.081,-0.02
having,-0.007,0.097,-0.026,0.109,-0.049,0.042,0.016
my,-0.062,-0.016,-0.111,0.021,0.161,0.056,-0.037
dog,-0.031,0.054,-0.068,-0.103,-0.12,-0.168,0.091
with,0.06,-0.031,-0.08,0.06,-0.01,-0.069,0.036
me,0.075,0.031,0.21,-0.173,0.161,-0.04,-0.121
and,-0.004,0.052,-0.025,0.017,-0.094,-0.025,0.057
being,0.019,0.003,-0.001,0.118,-0.002,0.053,0.017
with,0.054,-0.083,-0.05,-0.106,0.075,0.005,-0.052
my,0.006,0.001,0.009,-0.118,-0.027,-0.01,-0.012
kids,0.074,-0.114,0.116,-0.049,-0.217,-0.026,0.011
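
In the table above, each row is a generated token and each column is an input token. A small sketch of how one might read it programmatically follows: it parses three rows copied from the table with the standard `csv` module and reports, for each generated token, the input token with the largest absolute attribution. The variable names are illustrative, and the first header cell is only a label for the row names.

```python
import csv
import io

# Three data rows from the table above; the first header cell labels the row names.
table_csv = '''Output token,I,enjoy,walking,with,my,cute,dog
",",-0.335,-0.292,-0.256,-0.272,-0.351,-0.19,4.471
and,-0.079,0.709,0.163,0.31,-0.333,-0.013,0.339
I,1.392,0.146,-0.006,-0.259,0.904,0.084,0.018
'''

rows = list(csv.reader(io.StringIO(table_csv)))
input_tokens = rows[0][1:]

top_inputs = []
for output_token, *values in rows[1:]:
    scores = [float(v) for v in values]
    # Index of the input token with the largest absolute attribution
    best = max(range(len(scores)), key=lambda i: abs(scores[i]))
    top_inputs.append((output_token, input_tokens[best], scores[best]))

for output_token, input_token, score in top_inputs:
    print(f"{output_token!r} <- {input_token!r} ({score:+.3f})")
```

For these rows, the comma that follows the prompt is driven mostly by "dog", while the generated "I" is attributed mainly to the "I" in the prompt.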


### Another example...

In [8]:
s = ['Scientists confirmed the worst possible outcome: the massive asteroid will collide with Earth']

In [9]:
explainer = shap.Explainer(model, tokenizer)

explainers.Partition is still in an alpha state, so use with caution...


In [10]:
shap_values = explainer(s)

Setting `pad_token_id` to 50256 (first `eos_token_id`) to generate sequence


In [11]:
shap.plots.text(shap_values[0])

Output token,Scientists,confirmed,the,worst,possible,outcome,:,the,massive,asteroid,will,collide,with,Earth
.,-0.464,-0.464,-0.55,-0.464,-0.184,-0.184,-0.318,-0.519,-0.347,-0.34,-0.166,0.982,0.799,4.018
Ċ,0.182,0.182,0.075,0.137,0.206,0.206,0.089,0.059,0.184,0.226,-0.118,0.189,0.025,0.357
Ċ,0.341,0.341,0.213,0.152,0.197,0.197,0.003,-0.001,0.509,0.581,-0.122,0.533,0.499,0.669
If,-0.336,-0.336,-0.167,-0.138,0.169,0.169,-0.015,0.016,-0.132,-0.089,0.194,0.126,-0.047,-0.035
there,0.039,0.039,0.01,-0.01,0.066,0.066,0.022,0.016,0.079,0.117,-0.069,-0.052,-0.021,0.068
is,-0.058,-0.058,-0.073,-0.097,0.052,0.052,0.008,-0.042,-0.06,-0.028,0.276,0.121,0.031,0.025
a,0.051,0.051,-0.066,-0.083,0.056,0.056,0.036,0.015,-0.053,-0.024,0.092,0.222,0.032,0.085
collision,0.35,0.35,-0.118,-0.198,0.259,0.259,0.223,0.178,0.159,0.637,0.721,3.452,0.367,0.758
",",0.013,0.013,0.019,0.033,0.111,0.111,0.051,0.074,-0.048,-0.086,0.31,0.65,0.122,-0.19
it,0.285,0.285,0.076,0.084,-0.078,-0.078,-0.05,-0.027,0.18,0.196,0.003,0.246,0.049,0.184
would,0.403,0.403,0.02,0.029,0.237,0.237,0.028,0.094,0.178,0.243,-0.377,-0.012,0.21,0.468
not,-0.07,-0.07,-0.031,-0.041,0.073,0.073,0.015,0.019,-0.196,-0.196,-0.104,-0.149,-0.147,-0.177
be,0.165,0.165,0.063,0.082,-0.069,-0.069,-0.104,-0.077,0.086,0.068,-0.102,-0.133,-0.088,0.05
hurting,-0.146,-0.146,0.016,0.101,-0.006,-0.006,-0.089,-0.022,0.023,-0.078,0.055,0.041,0.022,-0.252
anyone,-0.144,-0.144,0.017,0.028,0.042,0.042,-0.042,-0.002,0.037,0.085,0.025,0.152,-0.003,-0.361
.,0.025,0.025,-0.002,0.011,-0.028,-0.028,-0.025,-0.022,-0.03,-0.052,-0.05,-0.076,-0.036,-0.122
Ċ,0.132,0.132,-0.055,-0.052,-0.076,-0.076,0.033,-0.006,-0.012,-0.037,0.057,-0.005,0.079,-0.11
Ċ,-0.01,-0.01,0.049,0.052,0.035,0.035,0.045,0.036,-0.015,-0.016,-0.059,-0.019,-0.037,0.005
"""",0.292,0.292,0.22,0.22,0.152,0.152,-0.037,0.124,0.261,0.197,-0.08,-0.039,0.137,0.21
There,0.303,0.303,0.028,0.008,0.064,0.064,-0.012,0.004,-0.051,-0.01,-0.012,0.004,0.031,-0.012
's,-0.029,-0.029,0.019,0.022,0.055,0.055,0.04,0.091,0.078,0.076,-0.042,-0.01,0.07,0.101
no,0.063,0.063,-0.003,-0.004,0.023,0.023,0.009,0.012,-0.055,-0.061,0.032,0.024,-0.015,0.005
real,0.118,0.118,-0.024,-0.045,0.092,0.092,0.043,0.042,0.018,0.051,0.015,0.038,0.004,0.05
danger,-0.024,-0.024,-0.021,-0.02,-0.048,-0.048,-0.067,-0.045,0.117,0.151,0.053,-0.119,0.063,0.0
",",-0.078,-0.078,-0.026,-0.017,-0.017,-0.017,-0.011,-0.013,-0.003,-0.035,0.009,0.007,0.009,-0.017
because,0.344,0.344,0.015,-0.008,0.106,0.106,0.074,0.039,0.05,0.119,0.055,0.204,0.04,0.123
we,0.036,0.036,0.004,0.013,-0.002,-0.002,-0.029,-0.03,0.04,0.065,-0.035,-0.021,0.038,0.188
don,0.228,0.228,-0.023,-0.035,0.03,0.03,0.045,0.031,-0.011,0.03,-0.003,0.077,0.003,0.013
't,0.022,0.022,-0.026,-0.002,-0.011,-0.011,-0.055,-0.075,0.179,0.179,-0.012,0.148,0.006,0.129
know,0.344,0.344,0.038,0.035,0.049,0.049,-0.01,0.011,0.116,0.142,0.11,0.027,0.032,-0.034
where,-0.144,-0.144,0.02,0.028,-0.05,-0.05,-0.017,-0.018,0.035,0.044,-0.088,-0.079,-0.019,0.046
it,0.122,0.122,0.009,0.01,-0.02,-0.02,-0.009,-0.01,0.137,0.207,0.08,0.094,0.005,0.019
's,0.037,0.037,0.02,0.023,0.007,0.007,0.017,0.032,0.012,0.015,-0.073,-0.023,0.024,0.057
coming,-0.07,-0.07,-0.012,-0.015,-0.046,-0.046,-0.048,-0.057,-0.006,-0.008,-0.166,-0.185,-0.062,-0.055
from,-0.033,-0.033,0.016,0.014,-0.024,-0.024,0.022,0.014,-0.164,-0.234,-0.12,0.008,-0.004,-0.066
",""",0.382,0.382,0.035,0.015,0.066,0.066,0.012,0.066,0.001,0.088,-0.049,0.08,0.086,0.186

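
The `Ċ` rows in the table above appear to be newline tokens: GPT-2's byte-level BPE displays every byte as a printable character, and the newline byte is shown as `Ċ`. The sketch below re-implements that byte-to-character mapping, modelled on the helper in the original GPT-2 encoder code, to show where the symbol comes from. It is an independent illustration and does not call the tokenizer loaded in this notebook.

```python
def bytes_to_unicode():
    """Byte -> printable character map in the style of GPT-2 byte-level BPE."""
    # Bytes that already have a visible, non-whitespace character keep it
    printable = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    byte_values = printable[:]
    code_points = printable[:]
    shift = 0
    for b in range(256):
        if b not in printable:
            # Remaining bytes (whitespace, control) are moved above U+00FF
            byte_values.append(b)
            code_points.append(256 + shift)
            shift += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}

byte_to_char = bytes_to_unicode()
newline_symbol = byte_to_char[ord("\n")]  # how a newline byte is displayed
space_symbol = byte_to_char[ord(" ")]     # how a leading space is displayed
print(newline_symbol, space_symbol)
```

Two consecutive `Ċ` rows therefore correspond to a blank line in the generated text.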