In [2]:
import shap
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, GlobalAveragePooling1D, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences

texts = ["I love using SHAP with transformers", "This is terrible", "Absolutely amazing"]
labels = [1, 0, 1]

tokenizer = Tokenizer(num_words=1000, oov_token="<OOV>")
tokenizer.fit_on_texts(texts)
X = pad_sequences(tokenizer.texts_to_sequences(texts), maxlen=10)
y = np.array(labels)


input_layer = Input(shape=(10,))
x = Embedding(input_dim=1000, output_dim=32)(input_layer)
x = GlobalAveragePooling1D()(x)
x = Dense(16, activation="relu")(x)
output = Dense(1, activation="sigmoid")(x)
model = Model(input_layer, output)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(X, y, epochs=5, batch_size=1)

def predict_fn(text_input):
    seq = tokenizer.texts_to_sequences(text_input)
    padded = pad_sequences(seq, maxlen=10)
    return model.predict(padded)

masker = shap.maskers.Text(" ")

explainer = shap.Explainer(predict_fn, masker)

shap_values = explainer(["I love using SHAP with transformers"])
shap.plots.text(shap_values[0])


Epoch 1/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 13ms/step - accuracy: 0.5417 - loss: 0.6960 
Epoch 2/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 10ms/step - accuracy: 0.5417 - loss: 0.6887
Epoch 3/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 12ms/step - accuracy: 1.0000 - loss: 0.6869
Epoch 4/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 1.0000 - loss: 0.6864
Epoch 5/5
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.4583 - loss: 0.6853    
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 67ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 62ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 44ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 35ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m