In [None]:
import functools
import shap
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup
from captum.attr import visualization as viz

# Initialise SHAP's JavaScript-based visualisation support in the notebook front end.
shap.initjs()

In [None]:
# Checkpoint to explain. Other checkpoints this notebook has been pointed at:
#   'KB/bert-base-swedish-cased', 'KB/electra-base-swedish-cased-discriminator',
#   'KB/albert-base-swedish-cased-alpha', 'bert-base-multilingual-cased',
#   'sentence-transformers/xlm-r-100langs-bert-base-nli-mean-tokens'
kb_bert = 'xlm-roberta-base'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(kb_bert)
model = AutoModelForSequenceClassification.from_pretrained(kb_bert)

# Fine-tuned weights live in ../models/<checkpoint>_ft.pt. map_location lets a checkpoint
# that was saved from a GPU also load on a CPU-only machine.
model.load_state_dict(torch.load("../models/"+kb_bert+"_ft.pt", map_location=device))
model = model.to(device)
model.eval()  # inference only: make the no-dropout mode explicit

# Collected captum VisualizationDataRecord objects for visualize_text.
vis_data_records = []

In [None]:
# Reproducible 80/20 shuffle split of the raw dataset; the test part is written to disk.
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

df = pd.read_csv("../data/dataset_no_recipe.csv")
df.columns = ['text', 'label']

# Named `shuffled` (not `random`) so the stdlib module name is not shadowed.
shuffled = df.iloc[np.random.permutation(len(df))]
n_train = round(len(df) * .8)
train = shuffled.iloc[:n_train]
test = shuffled.iloc[n_train:]
assert len(train) + len(test) == len(df)

test.to_csv('../data/test2_names.csv', index = False)
print(train.shape)
print(test.shape)

In [None]:
# NOTE(review): the split cell writes ../data/test2_names.csv, but this loads ../data/test2.csv —
# confirm which test file is intended.
test_dataset = load_dataset("csv", data_files='../data/test2.csv')
test_ind = test_dataset  # handle to the untokenized DatasetDict; not referenced elsewhere in this notebook

def tokenize(batch):
    """Tokenize a batch of texts to at most 512 tokens (with special tokens, no padding)."""
    return tokenizer(batch['text'], max_length = 512, truncation = True, add_special_tokens = True)

# batch_size=None tokenizes each split as a single batch. len(test_dataset) would be the number
# of splits in the DatasetDict (1), not the number of examples.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=None)
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
def predict_fn(input_ids, attention_mask=None, batch_size=64, label=None,
               output_logits=False, repeat_input_ids=False):
    """
    Wrap the global Huggingface Transformers `model` in the numpy-in / numpy-out form KernelSHAP expects.

    Args:
        input_ids: token ids as a numpy array, nested list or torch tensor, shape (n, seq_len);
            a single un-batched (seq_len,) sequence is also accepted. Float arrays (as KernelSHAP
            produces) are cast to int64 before the forward pass.
        attention_mask: optional mask with the same layout; defaults to all ones.
        batch_size: number of rows sent through the model per forward pass.
        label: if given, only the column for this class index is returned.
        output_logits: if True, return (probas, logits) instead of just probas.
        repeat_input_ids: if True, `input_ids` must hold exactly one row, which is repeated once
            per row of `attention_mask`.

    Returns:
        numpy array of softmax probabilities, shape (n, n_classes), or (n,) when `label` is set;
        a (probas, logits) tuple when `output_logits` is True.
    """
    # as_tensor avoids the copy-construct warning torch.tensor emits when handed a tensor.
    input_ids = torch.as_tensor(input_ids, device=device)
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)

    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = torch.as_tensor(attention_mask, device=device)
        if attention_mask.dim() == 1:
            attention_mask = attention_mask.unsqueeze(0)

    if repeat_input_ids:
        assert input_ids.shape[0] == 1, "repeat_input_ids expects exactly one row of input_ids"
        input_ids = input_ids.repeat(attention_mask.shape[0], 1)

    ds = torch.utils.data.TensorDataset(input_ids.long(), attention_mask.long())
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)

    batch_logits = []
    with torch.no_grad():
        for ids_batch, mask_batch in dl:
            out = model(ids_batch, attention_mask=mask_batch)
            batch_logits.append(out[0])

    all_logits = torch.cat(batch_logits, dim=0)
    # Softmax is row-wise, so applying it once after concatenation equals applying it per batch.
    probas = torch.softmax(all_logits, dim=1).cpu().numpy()
    logits = all_logits.cpu().numpy()

    if label is not None:
        probas = probas[:, label]
        logits = logits[:, label]

    return (probas, logits) if output_logits else probas


def tokens2words(tokens, seq, token_prefix="##"):
    """
    Aggregate `seq` to word level, using WordPiece-style `tokens` to find word boundaries.

    A token that starts with `token_prefix` continues the previous word: string entries are
    concatenated (with the prefix removed), numeric entries are summed. Any other token starts
    a new word. A continuation token with no preceding word (e.g. the very first token) starts
    a new word instead of raising.

    Note: the prefix marks a *continuation*. With SentencePiece tokens (e.g. XLM-R, where "▁"
    marks a word start) no token starts with "##", so every token stays its own "word".

    Args:
        tokens: token strings that define the grouping.
        seq: per-token values aligned with `tokens` — strings, or numbers / numpy scalars /
            0-d torch tensors.
        token_prefix: continuation marker.

    Returns:
        list of str when `seq` holds strings, otherwise a 1-D torch tensor on the global
        `device`; an empty list for empty input.
    """
    words = []
    for token, x in zip(tokens, seq):
        is_continuation = token.startswith(token_prefix)
        if isinstance(x, str):
            value = x.replace(token_prefix, "") if is_continuation else x
        else:
            # numpy scalars / 0-d tensors -> plain Python numbers so the sums stay homogeneous
            value = x.item() if hasattr(x, "item") else x

        if is_continuation and words:
            words[-1] += value
        else:
            words.append(value)

    if not words:
        return []
    return words if isinstance(words[-1], str) else torch.tensor(words, device=device)

def tokens2wordssentence(tokens, seq, token_prefix="▁"):
    """
    Aggregate `seq` to word level for SentencePiece-style `tokens`, where `token_prefix` marks a word start.

    Grouping rules:
      * a token starting with `token_prefix` (or with a space) starts a new word; for strings
        the prefix is replaced by a space;
      * the first token is dropped if it does not start a word (e.g. a leading `<s>`);
      * any other token is merged into the previous word (strings concatenated, numbers
        summed), so a trailing special token such as `</s>` is glued onto the last word;
        if there is no previous word yet, it starts one instead of raising.

    Strings and numbers follow the same rules, so one call with the tokens and one with
    per-token scores give aligned outputs.

    Args:
        tokens: token strings that define the grouping.
        seq: per-token values aligned with `tokens` — strings, or numbers / numpy scalars /
            0-d torch tensors.
        token_prefix: word-start marker.

    Returns:
        list of str when `seq` holds strings, otherwise a 1-D torch tensor on the global
        `device`; an empty list if nothing remains.
    """
    words = []
    for i, (token, x) in enumerate(zip(tokens, seq)):
        starts_word = token.startswith(token_prefix) or token.startswith(" ")
        if isinstance(x, str):
            value = x.replace(token_prefix, " ") if token.startswith(token_prefix) else x
        else:
            # numpy scalars / 0-d tensors -> plain Python numbers so the sums stay homogeneous
            value = x.item() if hasattr(x, "item") else x

        if starts_word:
            words.append(value)
        elif i == 0:
            continue  # leading non-word token (e.g. <s>) is dropped
        elif words:
            words[-1] += value
        else:
            words.append(value)

    if not words:
        return []
    return words if isinstance(words[-1], str) else torch.tensor(words, device=device)

def add_attributions_to_visualizer(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Append one captum VisualizationDataRecord describing a single explained example.

    Args:
        attributions: attribution scores aligned with `tokens`; must be a tensor-like object
            supporting `.norm()` and `.sum()`.
        pred: predicted probability for class `pred_ind`.
        pred_ind: predicted class index; also used as the attributed class.
        label: true class.
        tokens: words/tokens shown in the visualization.
        delta: convergence score (None when not available).
        vis_data_records: list the record is appended to (mutated in place).
    """
    norm = attributions.norm()
    # An all-zero attribution vector would otherwise become NaNs after normalisation.
    scaled = attributions / norm if norm > 0 else attributions
    vis_data_records.append(viz.VisualizationDataRecord(
                            scaled,
                            pred,
                            pred_ind,
                            label,
                            pred_ind,             # attributed class == predicted class
                            attributions.sum(),
                            tokens,
                            delta))
   

In [None]:
# Column views over the whole tokenized test split.
input_x = test_dataset['train']
input_text, attention_mask, label, input_ids = (
    input_x[column] for column in ('text', 'attention_mask', 'label', 'input_ids')
)
#sns.distplot(pred[0])

In [None]:
# Disabled scratch cell: the code is wrapped in a string literal, so running this cell only
# displays the string. When enabled, it would predict each test example one at a time, count
# how often argmax(pred) equals the label (accuracy), and plot the predicted probabilities.
"""input_x = test_dataset["train"]
acc = 0
predictions = []
for i in range(len(input_x)):
    attention_mask = input_x['attention_mask'][i].unsqueeze(0)
    input_ids = input_x['input_ids'][i].unsqueeze(0)

    #pred = predict_fn(input_ids=input_ids, attention_mask=attention_mask)
    pred = predict_fn(input_ids=input_ids)
    if np.argmax(pred) == input_x['label'][i]:
        acc+=1
    predictions.append(pred)
    
sns.distplot(predictions)
print(acc/len(input_x))"""

In [None]:
# KernelSHAP sample budget and the single test example to explain.
nsamples = 500
torch.cuda.empty_cache()
n_test = len(test_dataset['train'])
idx = np.random.choice(n_test)  # replace with a fixed index (e.g. 2029 or 20) to revisit a specific example
print(idx)
# Token id substituted for "absent" positions in the SHAP baseline (the pad token here; <unk> is an alternative).
ref_token = tokenizer.pad_token_id

In [None]:
# Pull the example at `idx` out of the test split and run the fine-tuned model on it.
input_x = test_dataset['train']
input_text = input_x['text'][idx]
label = input_x['label'][idx]
input_ids, attention_mask = (
    input_x[column][idx].unsqueeze(0) for column in ('input_ids', 'attention_mask')
)

input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
input_words = tokens2wordssentence(input_tokens, input_tokens)

# No attention_mask is passed: predict_fn then uses an all-ones mask, as in the SHAP calls.
pred = predict_fn(input_ids=input_ids)
class_probs = pred[0]
pred_label = class_probs.argmax()
pred_p = class_probs[pred_label]

In [None]:
# Show the raw text, its word grouping and its sub-word tokens, one per line.
for shown in (input_text, input_words, input_tokens):
    print(shown)


In [None]:
# KernelSHAP baseline: the same sequence with every position except the first and last replaced
# by ref_token, so the special tokens at both ends stay fixed.
baseline = input_ids.detach().cpu().clone().numpy()
baseline[:, 1:-1] = ref_token

# Explain only the probability of the predicted class.
predict_fn_label = functools.partial(predict_fn, label=pred_label)

explainer = shap.KernelExplainer(predict_fn_label, baseline)
phi = explainer.shap_values(input_ids.detach().cpu().clone().numpy(), nsamples=nsamples)

# NOTE(review): tokens2words merges on the "##" WordPiece prefix. For SentencePiece (XLM-R) tokens
# it presumably merges nothing, so phi_words stays per-token while input_words
# (tokens2wordssentence) is per-word, and the two lists handed to the visualizer can differ in
# length. Confirm the intended aggregation.
phi_words = tokens2words(input_tokens, phi.squeeze())

phi.shape, explainer.expected_value

In [None]:
# Queue the explained example for rendering (no convergence delta for KernelSHAP).
add_attributions_to_visualizer(
    attributions=phi_words,
    pred=pred_p,
    pred_ind=pred_label,
    label=label,
    tokens=input_words,
    delta=None,
    vis_data_records=vis_data_records,
)

In [None]:
# visualize_text displays the HTML itself and also returns it; the trailing ';' stops Jupyter
# from rendering the returned object a second time.
viz.visualize_text(vis_data_records);

In [None]:
# SHAP values of P(class 1) for every test example (class 1 regardless of the prediction).
predict_fn_label = functools.partial(predict_fn, label=1)
ref_token = tokenizer.pad_token_id
input_x = test_dataset["train"]
nsamples = 500

shap_val = []
for i in range(len(input_x)):
    # Row access formats one example; input_x['input_ids'][i] would re-materialize the whole
    # column on every iteration (quadratic in the test-set size).
    input_ids = input_x[i]['input_ids'].unsqueeze(0)
    baseline = input_ids.detach().cpu().clone().numpy()
    baseline[:, 1:-1] = ref_token  # keep the first/last (special) tokens fixed

    explainer = shap.KernelExplainer(predict_fn_label, baseline)
    phi = explainer.shap_values(input_ids.detach().cpu().clone().numpy(), nsamples=nsamples)
    shap_val.append(phi)
    torch.cuda.empty_cache()

In [None]:
import os
import pickle

# kb_bert may contain '/' (e.g. 'KB/bert-base-swedish-cased'), so create the sub-directory first.
shap_path = '../data/' + kb_bert + '.pickle_names'
os.makedirs(os.path.dirname(shap_path), exist_ok=True)

# Store data (serialize)
with open(shap_path, 'wb') as handle:
    pickle.dump(shap_val, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load data (deserialize) as a round-trip check. Only unpickle files this notebook wrote
# itself: pickle.load can execute arbitrary code.
with open(shap_path, 'rb') as handle:
    unserialized_data = pickle.load(handle)
assert len(unserialized_data) == len(shap_val)


In [None]:
# Disabled scratch cell for inspecting raw logits / hidden states of the last test example.
# Everything is commented out or wrapped in a string literal, so running the cell only
# displays the string.
#tokenizer.convert_ids_to_tokens(baseline[0])
#model.config.output_hidden_states = True
#attention_mask = torch.ones_like(input_ids)

"""input_ids = input_x['input_ids'][-1].unsqueeze(0)

input_ids = torch.tensor(input_ids, device=device)
attention_mask = torch.ones_like(input_ids, device=device)

output = model(input_ids=input_ids, attention_mask=attention_mask)
logits = output.logits
hidden_states = output.hidden_states"""

In [None]:
# For every distinct word/token across the test set, accumulate
# (sum of its SHAP values, number of occurrences).
features = {}
for i in range(len(input_x)):
    # Row access formats one example instead of the whole column on every iteration.
    input_ids = input_x[i]['input_ids']
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    input_words = tokens2words(input_tokens, input_tokens)
    phi_words = tokens2words(input_tokens, shap_val[i][0])
    for word, phi_word in zip(input_words, phi_words):
        total, count = features.get(word, (0.0, 0))
        features[word] = (phi_word.item() + total, count + 1)


In [None]:
# Words ordered by their summed SHAP value, most positive first.
dict(sorted(features.items(), key=lambda pair: pair[1][0], reverse=True))