In [1]:
import pandas as pd
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import torch
import transformers
import datasets
import shap


# Fetch the training split of the "emotion" dataset, then mirror its text and
# label columns into a pandas DataFrame for convenient slicing further down.
dataset = datasets.load_dataset("emotion", split="train")
data = pd.DataFrame(dict(text=dataset["text"], emotion=dataset["label"]))


# Load the fine-tuned emotion classifier together with its matching tokenizer.
tokenizer = transformers.AutoTokenizer.from_pretrained("nateraw/bert-base-uncased-emotion", use_fast=True)
# Use the GPU when one is present, but fall back to the CPU instead of raising
# on hosts without CUDA.
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "nateraw/bert-base-uncased-emotion"
).to("cuda" if torch.cuda.is_available() else "cpu")
# Class names ordered by their integer id in the model config (presumably so
# that position i names output column i -- confirm against the model head).
labels = sorted(model.config.label2id, key=model.config.label2id.get)

# this defines an explicit python function that takes a list of strings and outputs scores for each class
def f(x):
    """Score a batch of strings with the module-level ``model``.

    Args:
        x: An iterable of strings to classify.

    Returns:
        A NumPy array holding, for each input string, the logit
        (log-odds) of the softmax probability of each class.
    """
    # Follow whatever device the model lives on rather than assuming CUDA.
    device = next(model.parameters()).device
    # Every string is padded/truncated to exactly 128 token ids so the batch
    # can be stacked into a single rectangular tensor.
    tv = torch.tensor(
        [tokenizer.encode(v, padding='max_length', max_length=128, truncation=True) for v in x]
    ).to(device)
    # Mark the non-padding positions, using the tokenizer's own pad id instead
    # of a hard-coded 0.
    attention_mask = (tv != tokenizer.pad_token_id).type(torch.int64)
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        outputs = model(tv, attention_mask=attention_mask)[0].cpu().numpy()
    # Softmax over the last axis; subtracting the row maximum first leaves the
    # result unchanged mathematically but keeps np.exp from overflowing.
    shifted = outputs - outputs.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    scores = exps / exps.sum(axis=-1, keepdims=True)
    # Return log-odds rather than raw probabilities.
    val = sp.special.logit(scores)
    return val

# Wrap the scoring function in a SHAP explainer; the tokenizer is supplied as
# the masker and the class names label the output columns.
explainer = shap.Explainer(f, masker=tokenizer, output_names=labels)

# Explain only the first text in the dataset.
shap_values = explainer(data["text"][:1])

# Render the resulting attributions with SHAP's text plot.
shap.plots.text(shap_values)

  def _pt_shuffle_rec(i, indexes, index_mask, partition_tree, M, pos):
  def delta_minimization_order(all_masks, max_swap_size=100, num_passes=2):
  def _reverse_window(order, start, length):
  def _reverse_window_score_gain(masks, order, start, length):
  def _mask_delta_score(m1, m2):
  def identity(x):
  def _identity_inverse(x):
  def logit(x):
  def _logit_inverse(x):
  def _build_fixed_single_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
  def _build_fixed_multi_output(averaged_outs, last_outs, outputs, batch_positions, varying_rows, num_varying_rows, link, linearizing_weights):
  def _init_masks(cluster_matrix, M, indices_row_pos, indptr):
  def _rec_fill_masks(cluster_matrix, indices_row_pos, indptr, indices, M, ind):
  def _single_delta_mask(dind, masked_inputs, last_mask, data, x, noop_code):
  def _delta_masking(masks, x, curr_delta_inds, varying_rows_out,
  def _jit_build_partition_tree(xmin, xmax, ymi