In [None]:
import functools

import numpy as np
import torch
import shap
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from captum.attr import visualization as viz

In [None]:
# Load the "sst2" configuration of the GLUE benchmark together with the
# tokenizer and sequence-classification model stored under `model_name`
# (per its name, a DistilBERT checkpoint fine-tuned on SST-2).
sst2_dataset = load_dataset("glue", "sst2")
model_name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

In [None]:
# Display the validation split; the example explained below is drawn from it.
sst2_dataset['validation']

In [None]:
def predict_fn(input_ids, attention_mask=None, batch_size=32, label=None,
               output_logits=False, repeat_input_ids=False):
    """
    Wrapper function for a Huggingface Transformers model into the format that KernelSHAP expects,
    i.e. where inputs and outputs are numpy arrays.

    Parameters
    ----------
    input_ids : array-like
        Token ids, one row per sequence. Cast to int64 before the forward pass,
        so float arrays holding integral values are accepted.
    attention_mask : array-like or None
        Attention mask with one row per sequence. Defaults to all ones with the
        shape of `input_ids`.
    batch_size : int
        Number of rows fed to the model per forward pass.
    label : int or None
        If given, only this class column of the outputs is returned.
    output_logits : bool
        If True, return `(probas, logits)` instead of `probas` only.
    repeat_input_ids : bool
        If True, `input_ids` must hold exactly one row; it is tiled so that it
        has as many rows as `attention_mask`.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        Softmax probabilities of shape `(n_rows, n_classes)`, or `(n_rows,)`
        when `label` is given. With `output_logits=True` a `(probas, logits)`
        tuple of equally shaped arrays.
    """

    # as_tensor avoids the extra copy (and the warning) torch.tensor() causes
    # when the caller already passes a tensor.
    input_ids = torch.as_tensor(input_ids)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = torch.as_tensor(attention_mask)

    if repeat_input_ids:
        assert input_ids.shape[0] == 1, \
            "repeat_input_ids=True requires input_ids to contain exactly one row"
        input_ids = input_ids.repeat(attention_mask.shape[0], 1)

    # Run the forward pass on whatever device the model currently lives on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    ds = torch.utils.data.TensorDataset(input_ids.long(), attention_mask.long())
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    logits = []
    with torch.no_grad():
        for batch_ids, batch_mask in dl:
            out = model(batch_ids.to(device), attention_mask=batch_mask.to(device))
            # out[0] holds the logits; bring them back to the CPU so they can
            # be converted to numpy below.
            logits.append(out[0].cpu())
    logits = torch.cat(logits, dim=0)
    # Softmax is row-wise, so applying it once to the concatenated logits is
    # equivalent to applying it per batch.
    probas = torch.nn.functional.softmax(logits, dim=1).numpy()
    logits = logits.numpy()

    if label is not None:
        probas = probas[:, label]
        logits = logits[:, label]

    return (probas, logits) if output_logits else probas

def tokens2words(tokens, seq, token_prefix="##"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'.

    A token that starts with `token_prefix` is treated as the continuation of
    the previous word, and its entry in `seq` is merged into that word's entry:
    strings are concatenated (with the leading prefix stripped), numeric values
    are summed.

    Parameters
    ----------
    tokens : sequence of str
        Tokens that determine the word boundaries.
    seq : sequence
        Values aligned with `tokens`; either strings or numeric scalars
        (Python numbers or objects exposing `.item()`, such as numpy / torch
        scalars).
    token_prefix : str
        Prefix that marks a continuation token.

    Returns
    -------
    list of str or torch.Tensor
        A list of words when the aggregated values are strings, otherwise a
        1-D tensor with one aggregated value per word. Empty input returns an
        empty list.
    """

    words = []
    for token, x in zip(tokens, seq):
        is_continuation = token.startswith(token_prefix)

        if isinstance(x, str):
            # Strip only the leading marker; str.replace would also remove
            # any later occurrences inside the piece.
            if is_continuation and x.startswith(token_prefix):
                x = x[len(token_prefix):]
        elif hasattr(x, "item"):
            # Unwrap numpy / torch scalars into plain Python numbers.
            x = x.item()

        # A continuation piece with no preceding word (e.g. a truncated
        # sequence) starts a new word instead of raising IndexError.
        if is_continuation and words:
            words[-1] += x
        else:
            words.append(x)

    if not words or isinstance(words[-1], str):
        return words
    return torch.tensor(words)

In [None]:
nsamples = 500  # passed as `nsamples` to each `shap_values` call below
idx = 101       # index of the validation example to explain
ref_token = tokenizer.mask_token_id # Could also consider <UNK> or <PAD> tokens

# KernelSHAP draws random feature coalitions; seed NumPy's global RNG so that
# Restart & Run All reproduces the same attributions.
# NOTE(review): assumes the installed shap version samples via np.random — confirm.
seed = 0
np.random.seed(seed)


In [None]:
# Pick the example to explain and its ground-truth label.
input_text = sst2_dataset["validation"][idx]["sentence"]
label = sst2_dataset["validation"][idx]["label"]
# Encode to a numpy array of token ids; row 0 is the single encoded sentence.
input_ids = tokenizer.encode(input_text, return_tensors="np")
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# Passing the tokens as both arguments merges "##" continuation pieces into words.
input_words = tokens2words(input_tokens, input_tokens)
# Class probabilities for the unperturbed input (one row, one column per class).
pred = predict_fn(input_ids)
# argmax over the flattened array equals the class index because there is only one row.
pred_label = pred.argmax()
pred_p = pred[0, pred_label]


In [None]:
# Shape of the encoded input.
input_ids.shape

In [None]:
# Probability the model assigns to its predicted class for this sentence.
pred_p

In [None]:
# Two reference inputs, one per explainer:
#   baseline      - token ids with every position except the first and last set to `ref_token`
#   baseline_attn - attention mask that is zero everywhere except the first and last position
baseline = input_ids.copy()
baseline_attn = np.zeros_like(input_ids)

# Keep CLS and SEP tokens (first and last position) fixed in baseline
baseline[:, 1:-1] = ref_token
baseline_attn[:, 0] = 1
baseline_attn[:, -1] = 1

In [None]:
# Show the baseline attention mask (ones only at the first and last position).
baseline_attn

In [None]:
# Restrict the model output to the probability of the predicted class.
predict_fn_label = functools.partial(predict_fn, label=pred_label)
# Bind `input_ids` as the first positional argument, so the array the explainer
# passes in lands in `attention_mask`; `repeat_input_ids=True` tiles the single
# row of `input_ids` to match the number of mask rows.
predict_fn_label_attn = functools.partial(predict_fn_label, input_ids, repeat_input_ids=True)

# `explainer` perturbs token ids against `baseline`;
# `explainer_attn` perturbs the attention mask against `baseline_attn`.
explainer = shap.KernelExplainer(predict_fn_label, baseline)
explainer_attn = shap.KernelExplainer(predict_fn_label_attn, baseline_attn)

In [None]:
# SHAP values for the token-id explainer, then aggregated from tokens to words
# (values of "##" continuation pieces are summed into the preceding word).
phi = explainer.shap_values(input_ids, nsamples=nsamples)
phi_words = tokens2words(input_tokens, phi.squeeze())

In [None]:
# Token-level SHAP values with singleton dimensions removed.
phi.squeeze()

In [None]:
# Word-level SHAP values (after merging word-piece tokens).
phi_words

In [None]:
def _make_viz_record(word_attributions, pred_prob, pred_class, true_class, words):
    """
    Build a captum `VisualizationDataRecord` from word-level attributions.

    Parameters
    ----------
    word_attributions : torch.Tensor
        One attribution score per entry in `words`.
    pred_prob : float
        Probability of the predicted class.
    pred_class : int
        Predicted class; also used as the class the attributions refer to.
    true_class : int
        Ground-truth class of the example.
    words : list of str
        Words the attributions belong to.

    Returns
    -------
    viz.VisualizationDataRecord
        Record whose attributions are scaled to unit L2 norm and whose
        attribution score is the sum of the unscaled attributions.
    """
    norm = word_attributions.norm()
    # An all-zero attribution vector has norm 0; dividing by it would fill the
    # record with NaNs, so leave such a vector unscaled.
    scaled = word_attributions / norm if norm > 0 else word_attributions
    return viz.VisualizationDataRecord(
        scaled, pred_prob, pred_class, true_class,
        pred_class, word_attributions.sum(), words, None)


# Record for the token-id explainer.
viz_rec = [_make_viz_record(phi_words, pred_p, pred_label, label, input_words)]

# SHAP values for the attention-mask explainer: the explained instance is the
# all-ones mask, i.e. every token attended to.
phi_attn = explainer_attn.shap_values(np.ones_like(input_ids), nsamples=nsamples)
phi_attn_words = tokens2words(input_tokens, phi_attn.squeeze())
viz_rec_attn = [_make_viz_record(phi_attn_words, pred_p, pred_label, label, input_words)]

In [None]:
# Render the word-level attributions from the token-id explainer.
viz.visualize_text(viz_rec)

In [None]:
# Render the word-level attributions from the attention-mask explainer.
viz.visualize_text(viz_rec_attn)