In [None]:
import functools
import shap
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients, GradientShap

## Model

In [None]:
# Name of the pretrained checkpoint; also used to build the path of the
# fine-tuned weights below. Alternatives that were tried are kept commented.
#kb_bert = 'KB/bert-base-swedish-cased'
kb_bert = 'bert-base-multilingual-cased'
#kb_bert = 'xlm-roberta-base'
tokenizer = AutoTokenizer.from_pretrained(kb_bert)
model = AutoModelForSequenceClassification.from_pretrained(kb_bert)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from a GPU can still be loaded on a CPU-only machine.
model.load_state_dict(torch.load("../models/"+kb_bert+"_ft.pt", map_location=device))

model = model.to(device)
# The attribution methods below need deterministic forward passes, so make
# sure dropout is disabled.
model.eval()

# Records collected for the SHAP text visualisation.
vis_data_records = []

## Data

In [None]:
# Load the train / test CSV files. The "*_names" variants of the files were
# used as alternatives:
#   '../data/test2_names.csv', '../data/train2_names.csv'
train_dataset, test_dataset = (
    load_dataset("csv", data_files=csv_path)
    for csv_path in ('../data/train2.csv', '../data/test2.csv')
)

# Second reference to the test data exactly as loaded (no copy is made).
test_ind = test_dataset

def tokenize(batch):
    """
    Tokenize the 'text' column of a batch with the notebook-level tokenizer.

    Special tokens are added and every sequence is truncated to at most 512
    tokens. Without `truncation=True` the `max_length` argument has no effect,
    and longer texts would reach the model untruncated.
    """
    return tokenizer(batch['text'], max_length = 512, truncation = True,
                     add_special_tokens = True)

# `test_dataset` / `train_dataset` are indexed by split name (the data lives
# under 'train'), so len() of the outer object counts splits rather than rows.
# Use the length of the split itself so the whole split is tokenized as one batch.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=len(test_dataset['train']))
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])


train_dataset = train_dataset.map(tokenize, batched=True, batch_size=len(train_dataset['train']))
train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

## Prediction for SHAP

In [None]:
def predict_fn(input_ids, attention_mask=None, batch_size=64, label=None,
               output_logits=False):
    """
    Wrapper function for a Huggingface Transformers model into the format that KernelSHAP expects,
    i.e. where inputs and outputs are numpy arrays.

    Uses the notebook-level `model` and `device`.

    Args:
        input_ids: 2-D array-like or tensor of token ids, one row per sequence.
        attention_mask: optional mask with the same shape as `input_ids`;
            when omitted, a mask of ones is used.
        batch_size: number of rows sent through the model per forward pass.
        label: optional class index; when given, only that column of the
            outputs is returned.
        output_logits: when True, return `(probas, logits)` instead of `probas`.

    Returns:
        numpy array of softmax probabilities, or a `(probas, logits)` tuple of
        numpy arrays when `output_logits` is True.
    """

    # as_tensor accepts lists, numpy arrays and tensors alike; torch.tensor()
    # emits a copy-construct warning when it is handed an existing tensor.
    input_ids = torch.as_tensor(input_ids, device=device)
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = torch.as_tensor(attention_mask, device=device)

    ds = torch.utils.data.TensorDataset(input_ids.long(), attention_mask.long())
    dl = torch.utils.data.DataLoader(ds, batch_size=batch_size)
    logits = []
    with torch.no_grad():
        for batch_ids, batch_mask in dl:
            out = model(batch_ids, attention_mask=batch_mask)
            # out[0] holds the logits
            logits.append(out[0])
        logits = torch.cat(logits, dim=0)
        # Softmax is row-wise, so applying it once to the concatenated logits
        # gives the same values as applying it per batch.
        probas = torch.nn.functional.softmax(logits, dim=1)
    logits = logits.cpu().numpy()
    probas = probas.cpu().numpy()

    if label is not None:
        probas = probas[:, label]
        logits = logits[:, label]

    return (probas, logits) if output_logits else probas


def tokens2words(tokens, seq, token_prefix="##"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'

    A token starting with `token_prefix` is treated as the continuation of the
    previous word: string items are concatenated (with the prefix removed),
    numeric items are summed.

    Args:
        tokens: sequence of token strings.
        seq: sequence aligned with `tokens`; either strings or numeric
            scalars exposing `.item()` (numpy / torch scalars).
        token_prefix: marker of a continuation token.

    Returns:
        A list of strings when `seq` holds strings, otherwise a tensor on the
        notebook-level `device`. An empty input returns an empty list.
    """

    tmp = []
    for token, x in zip(tokens, seq):
        is_text = isinstance(x, str)
        is_continuation = token.startswith(token_prefix)
        if is_text and is_continuation:
            x = x.replace(token_prefix,"")

        if is_continuation and tmp:
            tmp[-1] += x
        elif is_text:
            # Also reached by a continuation token with no previous word to
            # merge into: it starts a new word instead of raising IndexError.
            tmp.append(x)
        else:
            tmp.append(x.item())

    # The result type follows the items; guard the empty case so tmp[-1] is safe.
    if not tmp or isinstance(tmp[-1], str):
        return tmp
    return torch.tensor(tmp, device=device)

def tokens2wordssentence(tokens, seq, token_prefix="▁"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'

    Here a token starting with `token_prefix` opens a new word (the prefix is
    replaced by a space in string items). Any other token that does not start
    with a space is merged into the previous word: strings are concatenated,
    numeric items are summed. If there is no previous word yet (for example a
    leading special token), the item starts a new word instead of being
    silently dropped.

    Args:
        tokens: sequence of token strings.
        seq: sequence aligned with `tokens`; either strings or numeric
            scalars exposing `.item()` (numpy / torch scalars).
        token_prefix: marker of a word-initial token.

    Returns:
        A list of strings when `seq` holds strings, otherwise a tensor on the
        notebook-level `device`. An empty input returns an empty list.
    """

    tmp = []
    for token, x in zip(tokens, seq):
        is_text = isinstance(x, str)
        if token.startswith(token_prefix):
            if is_text:
                x = x.replace(token_prefix," ")
            tmp.append(x if is_text else x.item())

        elif not token.startswith(" "):
            if is_text:
                x = x.replace(" ","")
            if tmp:
                tmp[-1] += x
            else:
                # Nothing to merge into yet: keep the item as its own word.
                tmp.append(x if is_text else x.item())
        else:
            tmp.append(x if is_text else x.item())

    # The result type follows the items; guard the empty case so tmp[-1] is safe.
    if not tmp or isinstance(tmp[-1], str):
        return tmp
    return torch.tensor(tmp, device=device)

def add_attributions_to_visualizer(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Build a captum VisualizationDataRecord for one example and append it to
    `vis_data_records` (the list is modified in place; nothing is returned).

    The record receives, in order: the attributions divided by their norm,
    `pred`, `pred_ind`, `label`, `pred_ind` again, the sum of the
    attributions, `tokens` and `delta`.
    """
    scaled_attributions = attributions / attributions.norm()
    attribution_total = attributions.sum()
    record = viz.VisualizationDataRecord(
        scaled_attributions,
        pred,
        pred_ind,
        label,
        pred_ind,
        attribution_total,
        tokens,
        delta,
    )
    vis_data_records.append(record)
   

## SHAP

In [None]:
# Passed as `nsamples` to explainer.shap_values further down.
nsamples = 1000

# Index of the example to explain. Set it to None to draw one at random.
# (The original cell drew a random index and then immediately overwrote it.)
# Other indices that were inspected: 20, 183, 471, 1469//872, 4625, 295, 1881
idx = 58
if idx is None:
    idx = np.random.choice(len(test_dataset['train']))
print(idx)
ref_token = tokenizer.pad_token_id # Could also consider <UNK> or <MASK> tokens

In [None]:
#input_x = train_dataset['train']
input_x = test_dataset['train']

# Fetch the formatted row once instead of formatting three whole columns just
# to pick a single element out of each. int() keeps the row lookup working if
# idx is a numpy integer.
row = input_x[int(idx)]
# 'text' is not among the columns selected by set_format, so it is read
# through column access.
input_text = input_x['text'][idx]
label = row['label']
input_ids = row['input_ids'].unsqueeze(0)
# NOTE(review): attention_mask is not used by the cells visible in this
# notebook; predict_fn falls back to a mask of ones.
attention_mask = row['attention_mask'].unsqueeze(0)

input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
# Passing the tokens as both arguments merges the word pieces into words.
input_words = tokens2words(input_tokens, input_tokens)
pred = predict_fn(input_ids=input_ids)
# A single row was predicted, so the arg-max of row 0 is the predicted class.
pred_label = pred[0].argmax()
pred_p = pred[0, pred_label]

In [None]:
# Independent numpy copy of the token ids to serve as the SHAP background sample.
baseline = input_ids.detach().cpu().numpy().copy()

# Replace everything except the first and last position (CLS and SEP tokens)
# with the reference token.
baseline[:, 1:-1] = ref_token

In [None]:
# Explain only the probability of the predicted class.
predict_fn_label = functools.partial(predict_fn, label=pred_label)

explainer = shap.KernelExplainer(predict_fn_label, baseline)

phi = explainer.shap_values(input_ids.detach().cpu().clone().numpy(), nsamples=nsamples)
# Aggregate the token-level values to word level.
phi_words = tokens2words(input_tokens, phi.squeeze())

# Only the last expression of a cell is displayed, so show both values as a
# tuple (a bare `phi.shape` line above the last one would be discarded).
phi.shape, explainer.expected_value

In [None]:
# Display the dataset label of the selected example.
label

In [None]:
# Append the word-level SHAP values to vis_data_records (None is passed as delta).
# NOTE: re-running this cell appends another record for the same example.
add_attributions_to_visualizer(phi_words, pred_p, pred_label, label, input_words, None, vis_data_records)

In [None]:
# Render every record collected so far in vis_data_records.
viz.visualize_text(vis_data_records)

## Integrated Gradients (IG)

In [None]:
class BertModelWrapper(nn.Module):
    """
    Thin module around a sequence-classification model whose forward pass
    takes only `input_ids` and returns class probabilities (softmax over
    dim 1 of the wrapped model's logits).
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids):
        # Only input_ids is forwarded; no attention mask is passed on.
        logits = self.model(input_ids=input_ids).logits
        return nn.functional.softmax(logits, dim=1)

def tokens2words(tokens, seq, token_prefix="##"):
    """
    Utility function to aggregate 'seq' on word-level based on 'tokens'

    NOTE: this redefines the `tokens2words` helper from the "Prediction for
    SHAP" section; the later definition is the one in effect after Run All.

    A token starting with `token_prefix` is treated as the continuation of the
    previous word: string items are concatenated (with the prefix removed),
    numeric items are summed.

    Args:
        tokens: sequence of token strings.
        seq: sequence aligned with `tokens`; either strings or numeric
            scalars exposing `.item()` (numpy / torch scalars).
        token_prefix: marker of a continuation token.

    Returns:
        A list of strings when `seq` holds strings, otherwise a tensor on the
        notebook-level `device`. An empty input returns an empty list.
    """

    tmp = []
    for token, x in zip(tokens, seq):
        is_text = isinstance(x, str)
        is_continuation = token.startswith(token_prefix)
        if is_text and is_continuation:
            x = x.replace(token_prefix,"")

        if is_continuation and tmp:
            tmp[-1] += x
        elif is_text:
            # Also reached by a continuation token with no previous word to
            # merge into: it starts a new word instead of raising IndexError.
            tmp.append(x)
        else:
            tmp.append(x.item())

    # The result type follows the items; guard the empty case so tmp[-1] is safe.
    if not tmp or isinstance(tmp[-1], str):
        return tmp
    return torch.tensor(tmp, device=device)

def add_attributions_to_visualizer_ig(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Build a captum VisualizationDataRecord for one example and append it to
    `vis_data_records` (the list is modified in place; nothing is returned).

    The record receives, in order: the attributions divided by their norm,
    `pred`, `pred_ind`, `label`, `pred_ind` again, the sum of the
    attributions, `tokens` and `delta`.
    """
    scaled_attributions = attributions / attributions.norm()
    attribution_total = attributions.sum()
    record = viz.VisualizationDataRecord(
        scaled_attributions,
        pred,
        pred_ind,
        label,
        pred_ind,
        attribution_total,
        tokens,
        delta,
    )
    vis_data_records.append(record)

def input_ref(sentence):
    """
    Build the (input, baseline) pair of token-id tensors for Integrated Gradients.

    Args:
        sentence: 2-D tensor or array-like of token ids (one row per sequence).

    Returns:
        `(input_ids, baseline)`, both on the notebook-level `device`. The
        baseline is a copy of the input in which every position except the
        first and last of each row is replaced by the tokenizer's pad token id.
    """
    if isinstance(sentence, torch.Tensor):
        # torch.tensor() on an existing tensor emits a copy-construct warning;
        # detach().clone() makes the same independent copy without it.
        input_ids = sentence.detach().clone().to(device)
    else:
        input_ids = torch.tensor(sentence, device=device)

    ref_token_id = tokenizer.pad_token_id # A token used for generating token reference
    baseline = input_ids.clone()
    baseline[:,1:-1] = ref_token_id
    return input_ids, baseline


In [None]:
bert_model_wrapper = BertModelWrapper(model)

# `base_model` resolves to the backbone of the loaded architecture, so the
# embedding layer is found without hard-coding the `.bert` attribute, which
# only exists on BERT checkpoints.
lig = LayerIntegratedGradients(bert_model_wrapper, bert_model_wrapper.model.base_model.embeddings)
# accumulate a couple of samples in this array for visualization purposes
vis_data_records_ig = []

bert_model_wrapper.eval()
bert_model_wrapper.zero_grad()

In [None]:
# Passed as `n_steps` to lig.attribute in the next cell.
n_steps = 1500
# NOTE: this rebinds `input_ids` and `baseline` from the SHAP section to the
# tensors returned by input_ref.
input_ids, baseline = input_ref(input_ids)

# The prediction is only used for reporting, so no autograd graph is needed;
# without no_grad the stored probability would keep the graph alive.
with torch.no_grad():
    pred = bert_model_wrapper(input_ids)
pred_label = pred.argmax()
pred_p = pred[0, pred_label]

In [None]:
# Attribute the predicted class to the embedding layer, going from `baseline`
# to `input_ids`. The convergence delta is returned alongside the attributions.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline,
    target=pred_label,
    n_steps=n_steps,
    internal_batch_size=16,
    return_convergence_delta=True,
)

In [None]:
print('pred: ', pred_label.item(), '(', '%.2f' % pred_p.item(), ')', ', delta: ', abs(delta.item()))

# Sum the attributions over dim 2 to get one score per token, drop the batch
# dimension and move the result to numpy.
att = attributions.sum(dim=2).squeeze(0).detach().cpu().clone().numpy()

# Merge the token scores into word scores.
att = tokens2words(input_tokens, att)

# Store the sample for visualization.
add_attributions_to_visualizer_ig(
    att, pred_p, pred_label, label, input_words, delta, vis_data_records_ig
)

In [None]:
# Number of word-level attribution scores produced by tokens2words.
len(att)

In [None]:
# Render every record collected so far in vis_data_records_ig.
viz.visualize_text(vis_data_records_ig)