# Text to Text Explanation: Machine Translation Example

This notebook demonstrates how to generate model explanations for a text-to-text scenario using pretrained transformer models for machine translation. In this demo, we showcase explanations for two different models provided by Hugging Face: English to Spanish translation (https://huggingface.co/Helsinki-NLP/opus-mt-en-es) and English to French translation (https://huggingface.co/Helsinki-NLP/opus-mt-en-fr).

In [1]:
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import shap
import torch

## English to Spanish model

In [13]:
# Scratch cell cleaned up: it imported T5 classes from `transformers` plus a
# project-local package through the relative import `from .src.data import ...`.
# Relative imports cannot work inside a notebook ("attempted relative import
# with no known parent package"), and none of those names are used by this
# translation demo, so the imports were removed to keep Restart & Run All working.

In [2]:
# Load the English->Spanish tokenizer and model from the Hugging Face Hub.
# Use the GPU when one is available and fall back to the CPU otherwise; a bare
# `.cuda()` fails outright on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-es").to(device)

# define the input sentences we want to translate
data = [
    "Transformers have rapidly become the model of choice for NLP problems, replacing older recurrent neural network models"
]

Downloading:   0%|          | 0.00/802k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/826k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.59M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/44.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/312M [00:00<?, ?B/s]

### Explain the model's predictions

In [3]:
# Wrap the translation model in a SHAP explainer. The tokenizer is supplied as
# the masker, i.e. the component that splits each input string into the
# pieces SHAP will hide and reveal.
explainer = shap.Explainer(model, masker=tokenizer)

# An explainer is invoked just like a model; `fixed_context=1` is forwarded
# to that call unchanged.
shap_values = explainer(data, fixed_context=1)

### Visualize SHAP explanations

In [4]:
# Render the explanation: in the table below, each row is a generated Spanish
# token and each column is an input English token, holding its SHAP value.
shap.plots.text(shap_values)

Unnamed: 0_level_0,Transform,ers,have,rapidly,become,the,model,of,choice,for,N,LP,problems,",",replacing,older,recurrent,neural,network,models,Unnamed: 21_level_0
Los,1.945,5.094,1.883,-0.525,0.166,0.081,-0.245,0.304,-0.123,0.072,-0.231,-0.114,-0.225,-0.02,0.05,0.022,0.047,0.17,0.102,-0.117,-0.052
transformador,7.291,4.428,-0.044,0.134,-0.164,0.053,0.16,0.146,0.054,0.295,0.097,-0.098,-0.113,0.111,0.125,0.167,0.211,0.058,0.225,0.273,0.098
es,-0.165,-0.11,-0.009,-0.035,0.016,0.001,-0.008,-0.011,-0.017,-0.017,-0.003,-0.004,-0.002,0.015,-0.009,-0.001,-0.004,-0.008,-0.002,-0.002,0.004
se,-0.108,1.584,0.779,-0.917,5.282,-0.668,-0.709,-0.68,-0.031,0.037,-0.059,-0.019,-0.076,0.051,0.004,-0.032,-0.01,0.045,0.02,0.025,-0.022
han,-0.394,-0.291,6.01,-1.295,1.99,-0.018,-0.323,-0.154,-0.018,-0.002,-0.028,-0.019,-0.036,-0.02,-0.031,-0.011,-0.022,-0.02,-0.118,-0.046,-0.024
convertido,0.159,-0.027,-1.342,0.881,4.348,0.837,0.132,-0.136,0.018,0.027,0.119,0.066,-0.045,-0.014,0.022,0.039,0.035,-0.028,-0.08,-0.025,0.087
rápidamente,-0.395,-0.48,-1.043,12.262,-0.039,-0.175,-0.205,-0.322,0.16,0.094,0.078,-0.036,-0.005,0.117,0.095,0.097,0.11,0.053,-0.04,0.09,0.103
en,-0.456,-0.449,-0.597,-0.269,2.29,0.754,-0.01,-0.453,0.357,0.15,-0.152,-0.023,0.042,0.104,0.102,0.065,0.05,0.07,0.213,0.108,0.007
el,-0.03,-0.012,-0.104,0.292,-0.905,1.922,2.482,0.821,0.396,0.151,0.022,-0.14,-0.186,0.05,-0.018,0.053,0.046,0.008,-0.038,-0.024,0.033
modelo,0.014,-0.108,-0.048,0.384,-1.631,0.417,8.906,-1.484,1.006,0.203,0.07,-0.322,-0.106,0.036,0.124,0.201,0.181,0.182,0.175,0.202,0.108
de,0.276,0.103,0.267,0.373,0.227,-0.281,-0.447,1.064,-0.375,-0.107,0.43,0.053,-0.058,-0.181,-0.042,0.024,0.14,-0.066,-0.404,-0.322,0.209
elección,-0.869,-0.83,-0.876,-0.622,-0.212,-0.207,3.542,2.623,13.116,-1.388,-1.674,-1.601,-1.389,-0.875,0.027,-0.038,-0.059,0.161,0.159,0.021,-0.152
para,-0.055,0.049,-0.076,0.159,-0.079,-0.056,0.414,0.516,-0.414,4.883,-0.258,-0.471,0.031,-0.398,-0.021,0.076,0.076,0.097,0.082,0.139,0.049
problemas,0.045,0.084,0.176,0.106,-0.015,-0.321,0.153,-0.234,-1.044,1.566,0.444,-1.215,10.147,-0.463,0.146,0.005,0.041,0.163,0.142,0.168,0.059
N,0.233,0.059,0.114,0.169,0.028,-0.073,0.314,-0.131,0.57,0.241,8.984,-1.445,1.438,-0.201,0.145,0.041,0.014,-0.042,0.095,0.153,0.175
LP,-0.093,-0.181,-0.15,-0.137,-0.184,-0.158,-0.13,-0.134,-0.103,-0.056,0.205,11.528,0.166,-0.114,-0.08,0.029,0.056,0.07,-0.019,0.02,-0.034
",",-0.018,-0.087,-0.004,0.259,0.113,-0.121,0.327,-0.122,-0.565,-0.551,-1.791,-0.863,2.462,5.537,0.416,-0.015,-0.042,-0.076,0.063,0.066,-0.026
reemplaza,-0.18,-0.02,-0.053,-0.055,0.068,-0.12,-0.149,-0.07,0.214,0.052,0.031,0.221,0.323,0.84,13.138,-0.854,-0.473,-0.431,-0.338,-0.472,-0.709
ndo,-0.05,-0.085,-0.069,-0.082,-0.11,-0.082,0.033,-0.106,-0.045,-0.112,-0.05,-0.025,-0.098,0.078,1.278,0.099,0.097,-0.019,0.252,0.091,-0.111
modelos,-0.132,0.097,-0.074,-0.095,0.004,0.079,-0.015,0.097,-0.037,0.124,-0.001,-0.057,-0.018,-0.036,-0.452,1.308,0.214,-0.439,-0.995,9.974,-0.373
de,-0.028,-0.008,-0.038,0.006,-0.048,-0.08,-0.063,-0.063,-0.023,-0.011,-0.092,-0.069,-0.021,0.024,-0.472,0.303,-0.589,-0.661,2.694,0.399,-0.179
red,-0.117,-0.13,-0.147,-0.018,-0.223,-0.288,0.174,-0.321,0.311,-0.02,-0.086,-0.092,0.118,0.131,0.136,0.061,-0.092,-1.013,9.76,0.122,-0.276
neuro,0.068,-0.126,-0.083,-0.103,-0.045,-0.035,-0.097,-0.027,-0.081,-0.082,0.254,0.189,-0.091,-0.18,-1.071,-0.983,-1.693,12.643,2.035,-0.375,-0.876
nal,-0.118,-0.029,-0.081,-0.159,-0.106,-0.025,-0.066,-0.001,-0.001,-0.003,0.198,-0.065,-0.105,-0.228,-0.796,-0.771,-1.173,5.364,1.639,0.309,-0.354
recurrente,0.129,-0.079,-0.027,-0.01,-0.001,-0.018,0.042,-0.096,-0.006,0.007,-0.039,0.025,-0.113,-0.136,-0.033,-1.69,11.393,1.248,1.129,1.025,-0.534
s,0.012,0.051,-0.031,-0.043,0.043,0.089,-0.125,0.031,-0.072,-0.007,-0.026,0.031,0.025,-0.155,-0.068,0.373,0.271,-0.045,0.026,1.048,-0.077
más,-0.041,-0.135,-0.11,0.152,-0.115,-0.145,0.04,-0.104,0.045,0.174,-0.104,-0.04,-0.023,0.13,-1.824,7.363,1.043,-0.747,0.437,0.368,-0.244
antiguos,-0.159,-0.048,-0.088,-0.138,-0.181,0.084,-0.222,0.025,-0.16,-0.029,0.011,-0.028,-0.167,-0.119,-0.618,6.305,0.001,-0.427,0.469,0.793,-0.16


## English to French model

In [5]:
# Load the English->French tokenizer and model from the Hugging Face Hub.
# Use the GPU when one is available and fall back to the CPU otherwise; a bare
# `.cuda()` fails outright on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-fr")
model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-fr").to(device)

Downloading:   0%|          | 0.00/1.13k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/778k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/802k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.34M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/301M [00:00<?, ?B/s]

In [6]:
# Build a fresh explainer around the English->French model, again using the
# tokenizer as the text masker.
explainer = shap.Explainer(model, masker=tokenizer)

In [7]:
# Explain the French translation of the same input sentences.
# NOTE(review): unlike the English->Spanish cell (which calls
# `explainer(data, fixed_context=1)`), no `fixed_context` is passed here, so
# SHAP's default applies. The coarser multi-token input groups in the plot
# below may stem from this; confirm before comparing the two explanations.
shap_values = explainer(data)

In [8]:
# Render the French explanation: rows are generated French tokens, columns are
# (groups of) input English tokens, cells are SHAP values.
shap.plots.text(shap_values)

Unnamed: 0_level_0,Trans former s have,rapidly become the model,of choice for N,LP problems,",",replacing older,recurrent ne,ural network,models
Les,6.444,-0.428,-0.511,0.289,0.054,-0.955,-0.312,0.113,0.051
transformateurs,12.165,0.195,0.831,-0.67,-0.213,0.87,0.122,-0.042,0.055
sont,1.463,2.382,0.033,-0.047,-0.0,-0.081,-0.068,-0.144,-0.165
rapidement,2.078,7.607,0.215,0.394,0.093,0.426,0.193,0.1,0.097
devenus,0.982,6.918,-0.299,0.038,0.008,0.352,-0.185,-0.023,0.007
le,0.392,5.862,0.022,-0.036,0.01,0.007,-0.042,0.028,0.02
modèle,0.125,7.002,0.389,-0.137,0.047,-0.035,-0.082,0.163,0.593
de,-0.022,0.387,3.329,-0.13,-0.063,-0.02,-0.114,0.129,0.007
choix,-0.42,3.933,6.109,-0.179,-0.117,0.093,-0.003,-0.348,-0.311
pour,0.247,0.334,3.965,-0.088,-0.065,0.029,0.056,0.141,-0.039
les,-0.105,-0.21,0.513,2.212,0.048,-0.022,0.043,0.259,0.072
problèmes,0.317,-0.29,1.43,6.483,0.871,0.059,0.37,0.229,-0.051
de,-0.316,-0.055,-0.128,0.805,0.164,-0.05,0.17,0.203,0.109
N,0.271,0.074,6.024,1.467,0.437,0.055,0.764,0.224,0.006
LP,-0.285,-0.276,1.518,9.614,0.988,-0.063,-0.587,-0.169,-0.262
",",0.311,0.705,0.686,1.201,1.939,0.044,0.025,0.016,0.003
remplaçant,-0.123,0.948,0.565,0.524,0.769,7.441,0.36,-0.032,0.187
les,-0.226,0.453,-0.04,0.224,0.126,1.174,-0.165,0.175,0.878
anciens,0.149,0.219,-0.184,0.192,0.056,6.71,0.126,-0.146,0.179
modèles,-0.05,0.133,0.14,0.093,0.19,0.551,0.968,0.614,5.679
de,-0.305,-0.245,0.152,0.128,0.043,-0.053,-0.092,1.213,0.792
réseaux,-0.038,-0.02,-0.21,0.22,-0.043,-0.335,0.324,4.518,1.894
neuro,-0.048,-0.058,0.344,0.019,0.058,0.257,4.493,5.286,0.984
naux,-0.076,-0.337,0.053,-0.028,-0.049,-0.373,1.163,3.022,0.39
récurrent,-0.222,0.118,-0.058,0.08,0.013,3.32,7.238,1.104,0.13
s,-0.027,-0.003,-0.03,0.038,0.013,-0.009,0.015,0.14,0.079
