In [None]:
import functools
import shap
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup
from captum.attr import visualization as viz

# Load SHAP's JavaScript into the notebook so shap's interactive plots can render.
# NOTE(review): the attributions below are rendered with captum's `viz`, not shap plots,
# so this call may be unnecessary.
shap.initjs()

In [None]:
# Model tag: selects the fine-tuned checkpoint under ../models/ and names the SHAP pickle later on.
kb_bert = 'mt5'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
# map_location lets a checkpoint that was saved from a GPU be restored on a CPU-only machine.
model.load_state_dict(torch.load("../models/" + kb_bert + ".pt", map_location=device))
model = model.to(device)
model.eval()  # inference only: explicit eval mode for generate()

# Captum visualization records accumulated by the attribution cells below.
vis_data_records = []

In [None]:
RANDOM_SEED = 0
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

df = pd.read_csv("../data/seq_dataset.csv")
df.columns = ['text', 'label']

# 80/20 random split; the permutation comes from the seeded global numpy RNG.
# Named `shuffled` rather than `random` so the stdlib module name is not shadowed.
shuffled = df.iloc[np.random.permutation(len(df))]
split_at = round(len(df) * .8)
train = shuffled.iloc[:split_at]
test = shuffled.iloc[split_at:]
assert len(train) + len(test) == len(df)

test.to_csv('../data/test2_seq.csv', index=False)
print(train.shape)
print(test.shape)

In [None]:
# Preview the first 15 rows of the held-out test split.
test.head(15)

In [None]:
class AntibioticsDataset(Dataset):
    """
    Map-style dataset pairing input texts with their text labels.

    Each item is tokenized on access. The text and the label are encoded with
    the same tokenizer and ``max_len``. No padding is requested, so item
    lengths vary.
    """

    def __init__(self, text, labels, tokenizer, max_len):
        # text / labels: indexable sequences of equal length (create_data_loader passes numpy arrays)
        self.text = text
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.text)

    def _encode(self, value):
        """Tokenize one string, explicitly truncating to ``max_len`` tokens."""
        return self.tokenizer.encode_plus(
            value,
            add_special_tokens=True,
            max_length=self.max_len,
            truncation=True,
            return_tensors='pt',
        )

    def __getitem__(self, item):
        """
        Return the tokenized example at index ``item``.

        Returns:
            dict with the raw ``text``, its 1-D ``input_ids`` and ``attention_mask``,
            the raw ``labels`` value, and the label's 1-D token ids ``labels_ids``.
        """
        text = str(self.text[item])
        label = self.labels[item]

        encoding = self._encode(text)
        # str() so non-string labels (e.g. ints parsed from CSV) can still be tokenized
        encoding_labels = self._encode(str(label))

        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': label,  # raw, untokenized label
            'labels_ids': encoding_labels['input_ids'].flatten(),
        }

def create_data_loader(df, tokenizer, max_len, batch_size):
    """
    Wrap a frame with ``text`` and ``label`` columns in an ``AntibioticsDataset``.

    Despite the name, no ``DataLoader`` is built: the dataset itself is returned,
    and ``batch_size`` is currently unused.
    """
    return AntibioticsDataset(
        text=df['text'].to_numpy(),
        labels=df['label'].to_numpy(),
        tokenizer=tokenizer,
        max_len=max_len,
    )


In [None]:
# Tokenized view of the test split (max 512 tokens per text).
test_dataset = create_data_loader(df=test, tokenizer=tokenizer, max_len=512, batch_size=4)

In [None]:
def predict_fn(input_ids, attention_mask=None, batch_size=64, label=None):
    """
    Wrap the seq2seq model in the numpy-in / numpy-out form KernelSHAP expects.

    For each row, the model generates text. The decoded text is mapped to 1 if it
    equals 'Positive' and to 0 otherwise.

    Args:
        input_ids: 2-D array-like (numpy array or tensor) of token ids, one row per sequence.
        attention_mask: optional array-like of the same shape; all ones when omitted.
        batch_size: number of rows passed to ``model.generate`` at a time.
        label: accepted so callers can bind it with ``functools.partial``; currently ignored.

    Returns:
        1-D numpy int64 array with one 0/1 prediction per input row (empty for empty input).
    """
    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor inputs;
    # .long() because KernelSHAP may hand over float-typed arrays.
    input_ids = torch.as_tensor(input_ids, device=device).long()
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = torch.as_tensor(attention_mask, device=device).long()

    if input_ids.shape[0] == 0:
        return np.zeros(0, dtype=np.int64)

    loader = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(input_ids, attention_mask), batch_size=batch_size
    )
    predictions = []
    with torch.no_grad():
        for batch_ids, batch_mask in loader:
            generated = model.generate(batch_ids, attention_mask=batch_mask)
            texts = tokenizer.batch_decode(
                generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
            )
            predictions.extend(1 if text == 'Positive' else 0 for text in texts)

    return np.asarray(predictions, dtype=np.int64)


def tokens2words(tokens, seq, token_prefix="##"):
    """
    Aggregate ``seq`` from token level to word level based on ``tokens``.

    A token that starts with ``token_prefix`` is treated as a continuation of the
    previous word. For string entries, the prefix is stripped and the entry is
    concatenated onto the previous word. Numeric entries are summed into the
    previous word. A continuation token with no preceding word starts a new word.

    NOTE(review): SentencePiece tokenizers (as used by mT5) mark word *starts*
    with "▁" instead of continuations. With the default "##", no merging happens
    for this model; passing "▁" would invert the grouping.

    Args:
        tokens: token strings aligned one-to-one with ``seq``.
        seq: either the token strings themselves, or per-token numeric scores that
            expose ``.item()`` (numpy scalars / 0-d tensors).
        token_prefix: continuation marker (WordPiece uses "##").

    Returns:
        A list of words when ``seq`` holds strings; otherwise a 1-D tensor of
        per-word scores on ``device``. Empty input yields an empty list.
    """
    words = []
    is_text = False
    for token, x in zip(tokens, seq):
        continuation = token.startswith(token_prefix)
        is_text = isinstance(x, str)
        if is_text:
            value = x.replace(token_prefix, "") if continuation else x
        else:
            value = x.item()  # plain Python number so accumulation stays homogeneous
        if continuation and words:
            words[-1] += value
        else:
            words.append(value)

    if not words or is_text:
        return words
    return torch.tensor(words, device=device)

def add_attributions_to_visualizer(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Append one captum ``VisualizationDataRecord`` to ``vis_data_records``.

    Args:
        attributions: 1-D tensor of per-word attribution scores.
        pred: prediction score (passed as captum's ``pred_prob``).
        pred_ind: predicted class; also used as the attributed class.
        label: true class.
        tokens: words aligned with ``attributions``.
        delta: convergence score, or ``None`` when unavailable.
        vis_data_records: list the new record is appended to (mutated in place).
    """
    norm = attributions.norm()
    # All-zero attributions (e.g. no masked coalition flips the hard 0/1 prediction)
    # would become NaN after division, so they are passed through unnormalized.
    normalized = attributions / norm if norm > 0 else attributions
    vis_data_records.append(viz.VisualizationDataRecord(
        normalized,
        pred,
        pred_ind,
        label,
        pred_ind,
        attributions.sum(),
        tokens,
        delta,
    ))
   

In [None]:
# Inspect one tokenized test example.
test_dataset[1]

In [None]:
# Settings for explaining a single test example with KernelSHAP.
torch.cuda.empty_cache()
nsamples = 500  # KernelSHAP sample budget
idx = 1  # index into test_dataset of the example to explain
print(idx)
# Token id substituted for masked-out positions in the SHAP baseline (the pad token).
ref_token = tokenizer.pad_token_id

In [None]:
# Tokenize the selected test example and get the model's prediction for it.
input_x = test_dataset[idx]
input_text = input_x['text']
label = input_x['labels']
input_ids = input_x['input_ids'].unsqueeze(0)  # add a batch dimension
attention_mask = input_x['attention_mask'].unsqueeze(0)

input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
input_words = tokens2words(input_tokens, input_tokens)

pred = predict_fn(input_ids=input_ids, attention_mask=attention_mask)
# predict_fn returns a length-1 array of 0/1 labels. Extract Python scalars so that
# later float()/int() conversions don't rely on deprecated size-1 array coercion.
pred_label = int(pred[0])
pred_p = float(pred_label)

In [None]:
# KernelSHAP baseline: every non-special token is replaced by `ref_token`, and
# special tokens stay fixed, so they are never attributed. The special positions are
# looked up from the tokenizer instead of assuming BERT-style [CLS] ... [SEP]:
# the T5 tokenizer used here appends only an end-of-sequence token, so a `[:, 1:-1]`
# slice would wrongly freeze the first real token.
input_np = input_ids.detach().cpu().numpy()
special_mask = np.array(
    tokenizer.get_special_tokens_mask(input_np[0].tolist(), already_has_special_tokens=True),
    dtype=bool,
)
baseline = input_np.copy()
baseline[:, ~special_mask] = ref_token

explainer = shap.KernelExplainer(predict_fn, baseline)

phi = explainer.shap_values(input_np, nsamples=nsamples)
phi_words = tokens2words(input_tokens, phi.squeeze())

phi.shape, explainer.expected_value

In [None]:
# Word-level SHAP values for the explained example (tensor aligned with input_words).
phi_words

In [None]:
# Store the explained example for captum's text visualization; true class is 1 for 'Positive'.
add_attributions_to_visualizer(
    phi_words,
    float(pred_p),
    pred_label,
    int(label == 'Positive'),
    input_words,
    None,  # no convergence delta for KernelSHAP
    vis_data_records,
)

In [None]:
# visualize_text displays the HTML table itself and also returns it; the trailing
# semicolon keeps Jupyter from rendering that return value a second time.
viz.visualize_text(vis_data_records);

In [None]:
# Run KernelSHAP over every test example. There is one explainer per example,
# because each baseline is derived from that example's own tokens.
predict_fn_label = functools.partial(predict_fn, label=1)  # predict_fn currently ignores `label`
ref_token = tokenizer.pad_token_id
nsamples = 500

# test_dataset is an integer-indexed AntibioticsDataset, not a HuggingFace DatasetDict,
# so there is no "train" split to index. Collect the per-example token ids explicitly.
input_x = {'input_ids': [test_dataset[i]['input_ids'] for i in range(len(test_dataset))]}

shap_val = []
for ids in input_x['input_ids']:
    input_ids = ids.unsqueeze(0)
    input_np = input_ids.detach().cpu().numpy()
    # Keep the tokenizer's special tokens fixed; mask every other position with ref_token.
    special_mask = np.array(
        tokenizer.get_special_tokens_mask(input_np[0].tolist(), already_has_special_tokens=True),
        dtype=bool,
    )
    baseline = input_np.copy()
    baseline[:, ~special_mask] = ref_token

    explainer = shap.KernelExplainer(predict_fn_label, baseline)
    phi = explainer.shap_values(input_np, nsamples=nsamples)
    shap_val.append(phi)
    torch.cuda.empty_cache()

In [None]:
import pickle

# Persist the per-example SHAP values so they can be reloaded without rerunning KernelSHAP.
shap_path = '../data/' + kb_bert + '.pickle'
with open(shap_path, 'wb') as handle:
    pickle.dump(shap_val, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Reload as a round-trip check. pickle.load can execute arbitrary code,
# so only load files this notebook wrote itself.
with open(shap_path, 'rb') as handle:
    unserialized_data = pickle.load(handle)
assert len(unserialized_data) == len(shap_val)


In [None]:
# NOTE(review): this cell has no runtime effect. Everything below is either commented out
# or a bare string literal that is evaluated and discarded. Consider deleting it.
#tokenizer.convert_ids_to_tokens(baseline[0])
#model.config.output_hidden_states = True
#attention_mask = torch.ones_like(input_ids)

"""input_ids = input_x['input_ids'][-1].unsqueeze(0)

input_ids = torch.tensor(input_ids, device=device)
attention_mask = torch.ones_like(input_ids, device=device)

output = model(input_ids=input_ids, attention_mask=attention_mask)
logits = output.logits
hidden_states = output.hidden_states"""

In [None]:
# Aggregate word-level SHAP values across all explained test examples:
# features[word] = (sum of SHAP values, number of occurrences).
# The loop runs over shap_val so it covers exactly the examples that were explained.
features = {}
for i in range(len(shap_val)):
    input_tokens = tokenizer.convert_ids_to_tokens(input_x['input_ids'][i])
    input_words = tokens2words(input_tokens, input_tokens)
    phi_words = tokens2words(input_tokens, shap_val[i][0])  # [0]: the single output row
    for word, phi in zip(input_words, phi_words):
        total, count = features.get(word, (0.0, 0))
        features[word] = (total + phi.item(), count + 1)


In [None]:
# Words ranked by their summed SHAP value, highest first.
dict(sorted(features.items(), key=lambda kv: kv[1][0], reverse=True))