In [None]:
# Authors: Alexander Dolk and Hjalmar Davidsen

In [None]:
import re

import lime
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import shap
import torch
import torch.nn.functional as F
import transformers
from lime.lime_text import LimeTextExplainer
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [None]:
# Names of the 18 classifier outputs; handed to SHAP as `output_names` further down.
# presumably listed in the same order as the model's output units — verify against training.
labels = [
    'K567', 'K573', 'K358', 'K590', 'K800', 'K379',
    'K802', 'K610', 'K566', 'K509', 'K859', 'K572',
    'K353', 'K650', 'K922', 'K565', 'K210', 'K560',
]

# Tokenizer, read from a local directory only (no hub download).
# NOTE(review): `truncation` and `padding` are normally given when *calling* the
# tokenizer, and the documented padding strategy is spelled 'longest' (lowercase);
# `max_len` looks like an older alias of `model_max_length`. Confirm that these
# keyword arguments have any effect when passed to `from_pretrained`.
tokenizer = AutoTokenizer.from_pretrained(
    '<PATH TO TOKENIZER>',
    local_files_only=True,
    model_max_length=512,
    max_len=512,
    truncation=True,
    padding='Longest',
)

# Fine-tuned classifier, also local-only. `num_labels` has to stay equal to len(labels).
model = AutoModelForSequenceClassification.from_pretrained(
    '<PATH TO CLASSIFICATION MODEL>',
    local_files_only=True,
    problem_type="multi_label_classification",
    num_labels=18,
)

# Convenience pipeline that reports a score for every label.
# NOTE(review): `return_all_scores` is deprecated in recent transformers releases
# (replaced by `top_k=None`) — confirm against the installed version.
pred = transformers.pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    return_all_scores=True,
)

In [None]:
# Explicit python function that takes a list of strings and outputs one score per class;
# this is the callable handed to shap.Explainer.
def f(x, max_length=128):
    """Score a batch of texts with the classifier, in log-odds space.

    Parameters
    ----------
    x : iterable of str
        Texts to score.
    max_length : int, optional
        Number of token ids every text is padded / truncated to. Defaults to
        128, the value that used to be hard-coded in this function.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(x), n_classes) holding logit(softmax(model logits)),
        i.e. the log-odds of each class under a softmax over all classes.
    """
    texts = [str(v) for v in x]

    # Let the tokenizer build the attention mask itself instead of assuming
    # that the padding token id is 0 (it is not for every vocabulary).
    enc = tokenizer(
        texts,
        padding='max_length',
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
    )
    input_ids = enc['input_ids'].to(model.device)
    attention_mask = enc['attention_mask'].to(model.device)

    # Inference only: no autograd graph is needed, which saves memory across
    # the many evaluations SHAP performs.
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask)[0]

        # NOTE(review): the model is loaded with
        # problem_type="multi_label_classification", where the usual per-label
        # probability is a sigmoid. The softmax normalisation below is kept from
        # the original implementation — confirm that it is intended.
        #
        # For p = softmax(z):  logit(p_i) = z_i - log(sum_{j != i} exp(z_j)).
        # Evaluating it through logsumexp avoids the overflow of exp(z) and the
        # +/-inf that logit() returns once p saturates to exactly 0 or 1.
        n_classes = logits.shape[-1]
        # others[b, i, :] is a copy of logits[b, :] with entry i masked out.
        others = logits.unsqueeze(1).expand(-1, n_classes, -1)
        self_mask = torch.eye(n_classes, dtype=torch.bool, device=logits.device)
        log_sum_others = torch.logsumexp(
            others.masked_fill(self_mask, float('-inf')), dim=-1
        )
        val = logits - log_sum_others

    return val.cpu().numpy()

def custom_tokenizer(s, return_offsets_mapping=True):
    """Split `s` at every single non-word character.

    Custom tokenizers conform to a subset of the transformers API.

    Parameters
    ----------
    s : str
        Text to tokenize.
    return_offsets_mapping : bool, optional
        When true, the result also carries the character span of every token.

    Returns
    -------
    dict
        "input_ids": list of the substrings found between non-word characters
        (the separators themselves are not part of any token);
        "offset_mapping" (only if requested): list of (start, end) character
        offsets into `s`, one pair per token.

    Notes
    -----
    Each separator closes the current token, so two adjacent separators (for
    example ", ") or a separator at position 0 yield an empty token with a
    zero-width span.
    """
    tokens = []
    spans = []
    cursor = 0
    for match in re.finditer(r"\W", s):
        sep_start, sep_end = match.span(0)
        spans.append((cursor, sep_start))
        tokens.append(s[cursor:sep_start])
        cursor = sep_end
    # Trailing text after the last separator (or the whole string if none matched).
    if cursor != len(s):
        spans.append((cursor, len(s)))
        tokens.append(s[cursor:])

    out = {"input_ids": tokens}
    if return_offsets_mapping:
        out["offset_mapping"] = spans
    return out


# Which way of splitting the text into explainable units to use:
# "transformers tokenizer", "default masker" or "custom tokenizer".
method = "custom tokenizer"

# build an explainer by passing a transformers tokenizer
if method == "transformers tokenizer":
    explainer = shap.Explainer(f, tokenizer, output_names=labels)

# build an explainer by explicitly creating a masker
elif method == "default masker":
    # A string is treated by SHAP as a regex split pattern; \W matches any
    # non-word character (punctuation included), not only whitespace.
    masker = shap.maskers.Text(r"\W")
    explainer = shap.Explainer(f, masker, output_names=labels)

# build an explainer around the fully custom tokenizer defined above
elif method == "custom tokenizer":
    masker = shap.maskers.Text(custom_tokenizer)
    explainer = shap.Explainer(f, masker, output_names=labels)

# Fail here rather than with a NameError on `explainer` in a later cell.
else:
    raise ValueError(
        f"Unknown method {method!r}; expected 'transformers tokenizer', "
        "'default masker' or 'custom tokenizer'."
    )

In [None]:
# Explain one document: the explainer takes a list of texts, here a single
# placeholder that should be replaced by the discharge summary of interest.
texts_to_explain = ["<DISCHARGE SUMMARY TO EXPLAIN>"]
shap_values = explainer(texts_to_explain)

# Render the per-token attributions for each output label.
shap.plots.text(shap_values)