In [None]:
import functools
import shap
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
import pickle
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup
from captum.attr import visualization as viz

# Initialise SHAP's notebook JavaScript support. Presumably only needed for
# SHAP's interactive plots; none are drawn in the cells visible here.
shap.initjs()

In [None]:
# Checkpoint identifiers. Each one is also the stem of a pickled artefact
# stored under ../models/.
el = 'KB/electra-base-swedish-cased-discriminator'
# Alternatives previously used for `kb_bert`:
#   'KB/bert-base-swedish-cased'
#   'KB/electra-base-swedish-cased-discriminator'
kb_bert = 'xlm-roberta-base'

# NOTE(review): pickle.load can execute arbitrary code; only load files that
# were produced by a trusted run of this project.
# NOTE(review): `electra` is not referenced again in the cells visible here.
with open(f'../models/{el}.pickle', 'rb') as handle:
    electra = pickle.load(handle)

# The '.pickle_names' variant is loaded here; the plain '.pickle' file was
# used at some earlier point.
with open(f'../models/{kb_bert}.pickle_names', 'rb') as handle:
    bert = pickle.load(handle)

In [None]:
# Fix the NumPy and PyTorch seeds so the shuffle below is repeatable.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Load the two-column dataset and give the columns fixed names.
df = pd.read_csv("../data/dataset_no_recipe.csv")
df.columns = ['text', 'label']

# Shuffle the rows, then cut at 80 % for train and keep the rest for test.
random = df.iloc[np.random.permutation(len(df))]
split_at = round(len(df) * .8)
train, test = random.iloc[:split_at], random.iloc[split_at:]
# The held-out part was at one point exported with:
#   test.to_csv('../data/test2_names.csv', index = False)
for part in (train, test):
    print(part.shape)

In [None]:
#kb_bert = 'KB/bert-base-swedish-cased'
#kb_bert = 'KB/electra-base-swedish-cased-discriminator'
# Tokenizer for the checkpoint named by `kb_bert` (assigned in an earlier cell).
tokenizer = AutoTokenizer.from_pretrained(kb_bert)

# Tensors built by the word-aggregation helpers in later cells are created on this device.
device = torch.device("cpu")

#test_dataset = load_dataset("csv", data_files='../data/test2_names.csv')
test_dataset = load_dataset("csv", data_files='../data/test2.csv')
# Second name bound to the dataset as loaded, before `test_dataset` is
# rebound to the tokenized version below.
test_ind = test_dataset

def tokenize(batch):
    """Tokenize the 'text' entries of `batch` with the module-level `tokenizer`.

    NOTE(review): `max_length` is given without an explicit `truncation`
    argument -- confirm that the installed transformers version really
    truncates inputs longer than 512 tokens in that case.
    """
    return tokenizer(batch['text'], max_length = 512, add_special_tokens = True)

# NOTE(review): a later cell indexes test_dataset['train'], so this is
# presumably a DatasetDict; if so, len(test_dataset) counts splits rather
# than rows -- confirm the intended batch size.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset))
# Expose only these columns, formatted as torch tensors.
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
def tokens2words(tokens, seq, token_prefix="##"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'.

    A token that starts with `token_prefix` is a continuation piece: its
    entry from `seq` is merged into the word currently being built (strings
    are concatenated with the prefix removed, numeric entries are added).
    Any other token starts a new word.

    Args:
        tokens: sequence of token strings that drives the grouping.
        seq: sequence aligned with `tokens`; either strings or scalar-like
            objects exposing `.item()` (e.g. tensor elements).
        token_prefix: marker identifying a continuation token.

    Returns:
        A list of strings when the aggregated values are strings, otherwise
        a tensor created on the module-level `device`. Empty input yields an
        empty list.
    """
    words = []
    for token, x in zip(tokens, seq):
        is_text = isinstance(x, str)
        if token.startswith(token_prefix):
            piece = x.replace(token_prefix, "") if is_text else x
            if words:
                words[-1] += piece
            else:
                # A continuation piece with nothing before it used to raise
                # IndexError; treat it as the start of a word instead.
                words.append(piece if is_text else piece.item())
        else:
            words.append(x if is_text else x.item())

    if not words:
        # Nothing to aggregate (the original indexed words[-1] and crashed).
        return []
    return words if isinstance(words[-1], str) else torch.tensor(words, device=device)

In [None]:
# Take the 'train' split of the tokenized dataset and pull out the columns
# that the later cells work with.
input_x = test_dataset['train']
input_text, label, input_ids = (
    input_x[column] for column in ('text', 'label', 'input_ids')
)

#pred = predict_fn(input_ids, output_logits=True)
#sns.distplot(pred[0])


In [None]:
def tokens2wordssentence(tokens, seq, token_prefix="▁"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'.

    Grouping rules, decided per token:
      * token starts with `token_prefix`: a new word begins. String entries
        have the prefix replaced by a space; other entries are stored via
        `.item()`.
      * token starts with neither the prefix nor a space: continuation piece,
        merged into the current word (strings concatenated with spaces
        removed, numeric entries added). A continuation that appears before
        any word has been opened is dropped.
      * token starts with a space: stored as its own word, unchanged.

    Args:
        tokens: sequence of token strings that drives the grouping.
        seq: sequence aligned with `tokens`; either strings or scalar-like
            objects exposing `.item()` (e.g. tensor elements).
        token_prefix: marker identifying the first piece of a word.

    Returns:
        A list of strings when the aggregated values are strings, otherwise
        a tensor created on the module-level `device`. When no word is
        produced (empty input, or only leading continuation pieces) an empty
        list is returned instead of raising IndexError.
    """
    words = []
    for token, x in zip(tokens, seq):
        is_text = isinstance(x, str)
        if token.startswith(token_prefix):
            words.append(x.replace(token_prefix, " ") if is_text else x.item())
        elif not token.startswith(" "):
            # Only merge when there is a word to merge into; pieces seen
            # before the first word start are intentionally skipped.
            if words:
                words[-1] += (x.replace(" ", "") if is_text else x)
        else:
            words.append(x if is_text else x.item())

    if not words:
        return []
    return words if isinstance(words[-1], str) else torch.tensor(words, device=device)

In [None]:
# Accumulate per-word attribution totals over every example in `input_x`.
#   features[word] -> (sum of signed values, number of occurrences)
#   absol[word]    -> (sum of absolute values, number of occurrences)
# Assumes shap_val[i][0] holds one attribution per token of example i --
# TODO confirm against the code that produced the pickle.
shap_val = bert

input_tokens = []
features = {}
absol = {}
# Fetch the column once instead of re-reading it on every iteration.
all_input_ids = input_x['input_ids']
for i in range(len(input_x)):
    input_ids = all_input_ids[i]
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    input_words = tokens2wordssentence(input_tokens, input_tokens)
    phi_words = tokens2wordssentence(input_tokens, shap_val[i][0])
    for j in range(len(input_words)):
        word = input_words[j]
        phi = phi_words[j].item()
        signed_sum, signed_count = features.get(word, (0.0, 0))
        # The absolute total must continue from absol's own running value;
        # previously it was seeded from the signed sum in `features`.
        abs_sum, abs_count = absol.get(word, (0.0, 0))
        features[word] = (signed_sum + phi, signed_count + 1)
        absol[word] = (abs_sum + np.abs(phi), abs_count + 1)


In [None]:
# Display every word with its (signed sum, count) pair, largest sum first.
dict(sorted(features.items(), key=lambda entry: entry[1][0], reverse=True))

In [None]:
import matplotlib.pyplot as plt

# Horizontal bar chart of the words with the largest signed attribution sum.
# `features` maps word -> (signed sum, count); only the sum is plotted.
i=0
top = 20
# Keys left out of the ranking (padding marker, punctuation, stray digits/letters).
excluded = {' [PAD]', ' 1', '46', ' .', ',', '(', ')', '7', ':', 'i'}
top_dict = {}
for k,v in sorted(features.items(), key=lambda item:item[1][0], reverse=True):
    if i==top:
        break
    if k not in excluded:
        top_dict[k] = v[0]
        i+=1

plt.rcdefaults()
fig, ax = plt.subplots()

# Size the axis from what was actually collected: with fewer than `top`
# qualifying words a fixed np.arange(top) would not match the bar values.
positions = np.arange(len(top_dict))
ax.barh(positions, list(top_dict.values()), color='limegreen')
ax.set_yticks(positions)
ax.set_yticklabels(list(top_dict.keys()))
ax.invert_yaxis()  # largest value at the top
ax.set_xlabel("Global shap values")
ax.set_title("Top words for antibiotics prescription")
plt.savefig('top_ab_words.png', bbox_inches='tight')
plt.show()

In [None]:
# Horizontal bar chart of the words with the smallest (most negative) signed
# attribution sum. `features` maps word -> (signed sum, count); the sum is
# negated so the bars extend to the right.
i=0
top = 20
# Keys left out of the ranking (punctuation and a stray digit).
excluded = {':', '-', '.', '/', '%', '0', ','}
top_dict = {}
for k,v in sorted(features.items(), key=lambda item:item[1][0], reverse=False):
    if i==top:
        break
    if k not in excluded:
        top_dict[k] = -v[0]
        i+=1

plt.rcdefaults()
fig, ax = plt.subplots()

# Size the axis from what was actually collected: with fewer than `top`
# qualifying words a fixed np.arange(top) would not match the bar values.
positions = np.arange(len(top_dict))
ax.barh(positions, list(top_dict.values()), color='r')
ax.set_yticks(positions)
ax.set_yticklabels(list(top_dict.keys()))
ax.invert_yaxis()  # first-ranked word at the top
ax.set_xlabel("Global shap values")
ax.set_title("Top words for not prescribing antibiotics")
plt.savefig('top_noab_words.png', bbox_inches='tight')
plt.show()


In [None]:
# Horizontal bar chart of the words with the largest absolute attribution sum.
# `absol` maps word -> (sum of absolute values, count); only the sum is plotted.
i=0
top = 20
# Keys left out of the ranking (punctuation and stray digits).
excluded = {':', '-', '.', '/', '%', '0', '1', '26'}
top_dict = {}
for k,v in sorted(absol.items(), key=lambda item:item[1][0], reverse=True):
    if i==top:
        break
    if k not in excluded:
        top_dict[k] = v[0]
        i+=1

plt.rcdefaults()
fig, ax = plt.subplots()

# Size the axis from what was actually collected: with fewer than `top`
# qualifying words a fixed np.arange(top) would not match the bar values.
positions = np.arange(len(top_dict))
ax.barh(positions, list(top_dict.values()), color='b')
ax.set_yticks(positions)
ax.set_yticklabels(list(top_dict.keys()))
ax.invert_yaxis()  # largest value at the top
ax.set_xlabel("Global shap values")
ax.set_title("Top words ")
plt.savefig('top_words.png', bbox_inches='tight')
plt.show()