In [None]:
import functools
import shap
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup
from captum.attr import visualization as viz

# Initialise SHAP's JavaScript so its interactive plots can render in the notebook.
shap.initjs()

In [None]:
# Hub id of the checkpoint to explain. The fine-tuned weights are looked up under
# ../models/<kb_bert>.pt, so switching models only requires changing this one name.
#kb_bert = 'KB/bert-base-swedish-cased'
kb_bert = 'KB/electra-base-swedish-cased-discriminator'
tokenizer = AutoTokenizer.from_pretrained(kb_bert)
model = AutoModelForSequenceClassification.from_pretrained(kb_bert)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device`, so a checkpoint that was
# saved from a GPU can still be restored on a CPU-only machine.
model.load_state_dict(torch.load(f"../models/{kb_bert}.pt", map_location=device))

model = model.to(device)
# Make inference mode explicit: attributions need deterministic forward passes.
model.eval()

# Accumulates captum VisualizationDataRecord objects for viz.visualize_text.
vis_data_records = []

In [None]:
RANDOM_SEED = 0
# Seed NumPy (which drives the shuffle below) and PyTorch for reproducibility.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

df = pd.read_csv("../data/dataset_no_recipe.csv")
df.columns = ['text', 'label']

# Shuffle the rows once, then cut them 80 % / 20 % into train and test.
random = df.iloc[np.random.permutation(len(df))]
split_at = round(len(df) * .8)
train = random.iloc[:split_at]
test = random.iloc[split_at:]

# Write the held-out split to the path the dataset-loading cell reads
# ('../data/test2.csv'). It used to be written to '../data/models/test2.csv',
# so the loader never saw the split produced here.
test.to_csv('../data/test2.csv', index=False)
print(train.shape)
print(test.shape)

In [None]:
# load_dataset on a bare CSV file yields a DatasetDict; the rows are accessed
# through its 'train' split throughout this notebook.
test_dataset = load_dataset("csv", data_files='../data/test2.csv')

def tokenize(batch):
    """
    Tokenise the 'text' column of a batch with the notebook's tokenizer.

    Every sequence is truncated or padded to exactly 512 ids and special
    tokens are added, so all rows end up with the same length.
    """
    return tokenizer(batch['text'], padding='max_length', truncation=True,
                     max_length=512, add_special_tokens=True)

# len(test_dataset) is the number of splits in the DatasetDict, not the number
# of rows; use the row count of the split so it is tokenised in one batch.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset['train']))
# Expose only the model inputs and the label as torch tensors.
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
def predict_fn(input_ids, attention_mask=None, batch_size=64, label=None,
               output_logits=False, repeat_input_ids=False):
    """
    Wrapper function for a Huggingface Transformers model into the format that KernelSHAP expects,
    i.e. where inputs and outputs are numpy arrays.

    Parameters
    ----------
    input_ids : array-like or torch.Tensor, shape (n, seq_len)
        Token ids; cast to int64 before being fed to the model.
    attention_mask : array-like or torch.Tensor, optional
        Mask matching ``input_ids``. When omitted an all-ones mask is used,
        i.e. every position (padding included) is attended to.
    batch_size : int
        Number of rows per forward pass.
    label : int, optional
        If given, only this class column of the outputs is returned.
    output_logits : bool
        If True return ``(probas, logits)`` instead of ``probas`` alone.
    repeat_input_ids : bool
        If True, ``input_ids`` must hold exactly one row, which is repeated
        once per row of ``attention_mask``.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        Softmax probabilities (and logits when requested), shape
        (n, n_classes), or (n,) when ``label`` is set.
    """

    # as_tensor accepts numpy arrays and tensors alike; torch.tensor() on an
    # existing tensor makes a needless copy and emits a UserWarning.
    input_ids = torch.as_tensor(input_ids, device=device)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = torch.as_tensor(attention_mask, device=device)

    if repeat_input_ids:
        assert input_ids.shape[0] == 1
        input_ids = input_ids.repeat(attention_mask.shape[0], 1)

    ds = torch.utils.data.TensorDataset(input_ids.long(), attention_mask.long())
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    probas = []
    logits = []
    # Dropout must be off, otherwise repeated calls on the same ids disagree.
    model.eval()
    with torch.no_grad():
        for batch_ids, batch_mask in dl:
            # The first element of the model output holds the logits.
            batch_logits = model(batch_ids, attention_mask=batch_mask)[0]
            logits.append(batch_logits)
            probas.append(torch.nn.functional.softmax(batch_logits, dim=1))
    logits = torch.cat(logits, dim=0).cpu().numpy()
    probas = torch.cat(probas, dim=0).cpu().numpy()

    if label is not None:
        probas = probas[:, label]
        logits = logits[:, label]

    return (probas, logits) if output_logits else probas

def tokens2words(tokens, seq, token_prefix="##"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'

    A token that starts with ``token_prefix`` continues the previous word.
    String items are concatenated (with the leading prefix removed), numeric
    items are summed.

    Parameters
    ----------
    tokens : sequence of str
        Sub-word tokens that define the word boundaries.
    seq : sequence
        One item per token: either strings or numeric scalars
        (plain numbers or objects exposing ``.item()``).
    token_prefix : str
        Marker identifying continuation tokens.

    Returns
    -------
    list of str, torch.Tensor, or empty list
        Word-level strings, or a tensor on ``device`` with word-level sums.
        Empty input yields an empty list.
    """

    words = []
    for token, x in zip(tokens, seq):
        continues_word = token.startswith(token_prefix)
        if isinstance(x, str):
            # Drop only the leading marker; characters that merely look like
            # the marker further inside the piece belong to the text.
            if continues_word and x.startswith(token_prefix):
                piece = x[len(token_prefix):]
            else:
                piece = x
        else:
            piece = x.item() if hasattr(x, "item") else x

        # A continuation token with nothing before it starts a new word
        # instead of raising IndexError.
        if continues_word and words:
            words[-1] += piece
        else:
            words.append(piece)

    if not words:
        return words
    return words if isinstance(words[-1], str) else torch.tensor(words, device=device)

def add_attributions_to_visualizer(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Wrap one example's attributions in a captum VisualizationDataRecord and
    append it to ``vis_data_records`` (the list is modified in place).

    The record receives, in order: the attributions divided by their norm,
    ``pred``, ``pred_ind``, ``label``, ``pred_ind`` again, the sum of the raw
    attributions, ``tokens`` and ``delta``.
    """
    unit_attributions = attributions / attributions.norm()
    total_attribution = attributions.sum()
    record = viz.VisualizationDataRecord(
        unit_attributions, pred, pred_ind, label,
        pred_ind, total_attribution, tokens, delta)
    vis_data_records.append(record)
   

In [None]:
# Sampling budget handed to explainer.shap_values as nsamples= further down.
nsamples = 500
# Row of the test split that is explained and visualised individually.
idx = 370
#idx = np.random.choice(len(test_dataset['train']))
# Token id used to fill the SHAP baseline (currently the tokenizer's PAD id).
#ref_token = tokenizer.mask_token_id # Could also consider <UNK> or <PAD> tokens
ref_token = tokenizer.pad_token_id # Could also consider <UNK> or <MASK> tokens

In [None]:
# Raw text of the test example selected by idx.
test_dataset['train']['text'][idx]

In [None]:
# Pull out the single example selected by idx. input_ids and attention_mask get
# a leading batch dimension of 1 via unsqueeze(0); label stays as indexed.
input_x = test_dataset['train']
input_text = input_x['text'][idx]
label = input_x['label'][idx]
input_ids = input_x['input_ids'][idx].unsqueeze(0)
attention_mask = input_x['attention_mask'][idx].unsqueeze(0)

# Token strings for the ids, and the same tokens merged into whole words.
input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
input_words = tokens2words(input_tokens, input_tokens)
# The prediction is made without the attention mask, so predict_fn falls back to
# an all-ones mask (the mask-aware call is kept below, disabled).
#pred = predict_fn(input_ids=input_ids, attention_mask=attention_mask)
pred = predict_fn(input_ids=input_ids)
# argmax() works on the flattened array; with a single row it equals the class index.
pred_label = pred.argmax()
# Probability of the predicted class.
pred_p = pred[0, pred_label]

In [None]:
# Reference input for KernelSHAP: a copy of the example's ids in which every
# position except the first and the last is replaced by ref_token.
baseline = input_ids.detach().cpu().clone().numpy()
baseline_attn = np.zeros_like(input_ids)

# Only positions 0 and -1 keep their original ids.
# NOTE(review): the intent was to keep [CLS] and [SEP] fixed, but the inputs are
# padded to max_length, so for texts shorter than that the last position
# presumably holds [PAD] and the real [SEP] is overwritten - confirm.
baseline[:,1:-1] = ref_token
#baseline_attn[:, 0] = 1
#baseline_attn[:, -1] = 1

# Explain the probability of the class predicted for this example.
predict_fn_label = functools.partial(predict_fn, label=pred_label)
#predict_fn_label_attn = functools.partial(predict_fn_label, input_ids, repeat_input_ids=True)

explainer = shap.KernelExplainer(predict_fn_label, baseline)
#explainer_attn = shap.KernelExplainer(predict_fn_label_attn, baseline_attn)

# Use the configured sampling budget instead of repeating the literal 500.
phi = explainer.shap_values(input_ids.detach().cpu().clone().numpy(), nsamples=nsamples)
# Sum the per-token SHAP values into per-word values.
phi_words = tokens2words(input_tokens, phi.squeeze())

# A bare expression in the middle of a cell is never displayed; show both
# values together as the cell result.
phi.shape, explainer.expected_value

In [None]:
# Inspect the baseline id array that was handed to the KernelExplainer.
baseline

In [None]:
# Cut the word list just after the first '[SEP]' word - presumably to leave the
# trailing padding out of the visualisation. index() raises ValueError if no
# '[SEP]' word is present.
i = input_words.index('[SEP]') + 1

# Appends one record per execution; re-running this cell adds a duplicate entry.
add_attributions_to_visualizer(phi_words[:i], pred_p, pred_label, label, input_words[:i], None, vis_data_records)

# Disabled: inline equivalent of the helper call above.
#vis_data_records.append(viz.VisualizationDataRecord(
    #phi_words[:i]/phi_words[:i].norm(), pred_p, pred_label, label,
    #pred_label, phi_words[:i].sum(), input_words[:i], None))

# Disabled: attention-mask attribution experiment.
# NOTE(review): this snippet slices with idx (the example index) rather than i.
#phi_attn = explainer_attn.shap_values(np.ones_like(input_ids), nsamples=500)
#phi_attn_words = tokens2words(input_tokens, phi_attn.squeeze())
#viz_rec_attn = [viz.VisualizationDataRecord(
    #phi_attn_words[:idx]/phi_attn_words[:idx].norm(), pred_p, pred_label, label,
    #pred_label, phi_attn_words[:idx].sum(), input_words[:idx], None)]

In [None]:
# Render every record collected in vis_data_records with captum's text visualiser.
viz.visualize_text(vis_data_records)

In [None]:
# Probability the model assigned to its predicted class for the selected example.
pred_p

In [None]:
# Word-level SHAP values of the selected example.
phi_words

In [None]:
# Word-level SHAP values again (same expression as the previous cell).
phi_words

In [None]:
# Display the 'train' split of test_dataset, the split every cell here reads from.
test_dataset['train']

In [None]:
# Switch from the single example to the whole split: from here on these names
# hold full columns instead of one row.
input_x = test_dataset['train']
input_text = input_x['text']
attention_mask = input_x['attention_mask']
label = input_x['label']
input_ids = input_x['input_ids']

# With output_logits=True predict_fn returns a (probas, logits) tuple, so pred[0]
# is the probabilities. This call passes the real attention mask.
pred = predict_fn(input_ids=input_ids, attention_mask=attention_mask, output_logits=True)
# Distribution of the predicted probabilities.
sns.distplot(pred[0])


In [None]:
# Accuracy of the predictions that were made with the attention mask.
accuracy_score(label.numpy(), pred[0].argmax(axis=1))


In [None]:
# Same prediction without an attention mask; predict_fn then substitutes an
# all-ones mask. pred2[0] is the probabilities of the (probas, logits) tuple.
pred2 = predict_fn(input_ids=input_ids, output_logits=True)
sns.distplot(pred2[0])


In [None]:
# Accuracy of the predictions that were made without the attention mask.
accuracy_score(label.numpy(), pred2[0].argmax(axis=1))


In [None]:
# Boolean NumPy mask of the examples whose true label is 1.
positive_mask = np.asarray(label == 1)
# Predicted probabilities for those examples only.
pos_ids, pos_attn = input_ids[positive_mask], attention_mask[positive_mask]
pred_pos = predict_fn(pos_ids, attention_mask=pos_attn)
sns.distplot(pred_pos)

In [None]:
# Complement of positive_mask: the examples whose true label is not 1.
negative_mask = ~positive_mask
pred_neg = predict_fn(input_ids[negative_mask], attention_mask=attention_mask[negative_mask])
sns.distplot(pred_neg)

In [None]:
# Compare NumPy against NumPy. Feeding the torch label tensor to a NumPy ufunc
# hands the result back to torch (Tensor.__array_wrap__), which turns the
# boolean outcome into a uint8 tensor; `~` on that is a bitwise NOT and no
# longer a logical negation.
predicted_class = np.argmax(pred[0], axis=1)
mask_correct = predicted_class == label.numpy()

# Predicted probabilities of the correctly classified examples.
pred_cor = predict_fn(input_ids[mask_correct], attention_mask=attention_mask[mask_correct])
sns.distplot(pred_cor)

In [None]:
# Negate on a boolean array: if mask_correct is a uint8 tensor, `~` would be a
# bitwise NOT (0 -> 255, 1 -> 254) and every row would be selected.
mask_incorrect = ~np.asarray(mask_correct, dtype=bool)

# Predicted probabilities of the misclassified examples.
pred_incor = predict_fn(input_ids[mask_incorrect], attention_mask=attention_mask[mask_incorrect])
sns.distplot(pred_incor)

In [None]:
# Number of labels, i.e. of examples in the evaluated split.
len(label)

In [None]:
# Explain the probability of class 1 for every example of the split, reusing
# the baseline built for the single-example explanation.
predict_fn_label = functools.partial(predict_fn, label=1)
explainer = shap.KernelExplainer(predict_fn_label, baseline)
# shap_values takes the data as its first positional parameter (X); passing it
# as input_ids= leaves X unfilled and raises a TypeError.
phi = explainer.shap_values(input_ids.detach().cpu().clone().numpy(), nsamples=nsamples)

In [None]:
import os
import pickle

# kb_bert contains a '/', so the pickle lands in a sub-directory of ../data
# that may not exist yet; create it before writing.
phi_path = '../data/' + kb_bert + '.pickle'
os.makedirs(os.path.dirname(phi_path), exist_ok=True)

# Store data (serialize)
with open(phi_path, 'wb') as handle:
    pickle.dump(phi, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load data (deserialize)
# The file was written by this notebook a moment ago; never unpickle files
# from an untrusted source.
with open(phi_path, 'rb') as handle:
    unserialized_data = pickle.load(handle)

# Print one True/False verdict for the round trip instead of an element-wise
# boolean array the size of phi.
print(np.array_equal(phi, unserialized_data))

In [None]:
# Shape of the id tensor that was explained.
input_ids.shape

In [None]:
# Shape of the SHAP value array, for comparison with input_ids.shape.
phi.shape

In [None]:
#shap.force_plot(explainer.expected_value, phi, input_ids.detach().cpu().clone().numpy())

In [None]:
#shap.summary_plot(phi, input_ids[-100:].detach().cpu().clone().numpy())

In [None]:
# Accumulate, for every word, the sum of its SHAP values and how many times it
# occurred over the whole split: features[word] = (summed_value, count).
# (Dividing phi_words by phi_words.norm() before summing would give a
# per-example normalised variant.)
features = {}
for i in range(input_ids.shape[0]):
    # Row i of phi belongs to row i of input_ids. Indexing the ids with -i
    # paired every attribution row except the first with the tokens of a
    # different example.
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids[i, :])
    input_words = tokens2words(input_tokens, input_tokens)
    phi_words = tokens2words(input_tokens, phi[i].squeeze())
    for word, value in zip(input_words, phi_words.tolist()):
        total, count = features.get(word, (0.0, 0))
        features[word] = (total + value, count + 1)


In [None]:
# Expected value reported by the most recently built explainer.
explainer.expected_value

In [None]:
# Words ordered by their summed SHAP value, largest first; each value is the
# (summed_value, count) pair built during aggregation.
dict(sorted(features.items(), key=lambda entry: entry[1][0], reverse=True))