In [None]:
import torch
import torch.nn as nn

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification

from datasets import load_dataset
import numpy as np

from captum.attr import LayerIntegratedGradients, GradientShap
from captum.attr import visualization

import shap
import scipy as sp

In [None]:
# Load the multilingual BERT tokenizer/classifier and overwrite its weights with
# the fine-tuned checkpoint.
MODEL_NAME = 'bert-base-multilingual-cased'
CHECKPOINT_PATH = "../models/bert-base-multilingual-cased.pt"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only machine; the model is moved to `device` in a later cell.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location="cpu"))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
class BertModelWrapper(nn.Module):
    """Adapt a sequence classifier to a plain ``input_ids -> probabilities`` function.

    Captum's attribution methods need a forward that returns a tensor. This
    wrapper calls the wrapped model, pulls out ``.logits`` and softmax-normalises
    them over dim 1. It assumes ``model(input_ids=...)`` returns an object with a
    ``.logits`` attribute, as the Hugging Face classifier loaded above does.
    """

    def __init__(self, model):
        super().__init__()
        # Stored as a submodule, so .to()/.eval() on the wrapper reach it too.
        self.model = model

    def forward(self, input_ids):
        """Return softmax probabilities (over dim 1) of the wrapped model's logits."""
        logits = self.model(input_ids=input_ids).logits
        return logits.softmax(dim=1)

In [None]:
def input_ref(model_wrapper, sentence, keep_special_tokens=False):
    """Tokenise ``sentence`` and build an Integrated-Gradients baseline of equal length.

    Args:
        model_wrapper: unused; kept so existing call sites keep working.
        sentence: raw text to explain.
        keep_special_tokens: if False (default), the baseline is all [PAD] tokens.
            If True, the baseline keeps [CLS] at the start and [SEP] at the end
            and fills only the interior with [PAD].

    Returns:
        ``(input_ids, ref_input_ids)``, two integer tensors of shape (1, seq_len)
        on ``device``.
    """
    input_ids = torch.tensor([tokenizer.encode(sentence, add_special_tokens=True)], device=device)

    # [PAD] serves as the "no information" reference token.
    ref_input_ids = torch.full_like(input_ids, tokenizer.pad_token_id)
    if keep_special_tokens:
        ref_input_ids[0, 0] = tokenizer.cls_token_id
        ref_input_ids[0, -1] = tokenizer.sep_token_id

    return input_ids, ref_input_ids

In [None]:
def add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records):
    """Reduce per-embedding attributions to per-token scores and append a Captum record.

    Args:
        attributions: attribution tensor from ``lig.attribute``. It is summed over
            dim 2 and squeezed on dim 0 below, so presumably (1, seq_len, hidden)
            — TODO confirm if ever used with batch > 1.
        tokens: token strings aligned with the attributed positions.
        pred: probability passed as the record's ``pred_prob``.
        pred_ind: predicted class index (``pred_class``).
        label: ground-truth label (``true_class``).
        delta: convergence delta returned by ``lig.attribute``.
        vis_data_records: list the new record is appended to (mutated in place).
    """
    # One score per token: collapse the embedding dimension.
    attributions = attributions.sum(dim=2).squeeze(0)
    attributions = attributions.detach().cpu().numpy()

    # Normalise the scores for display; guard the all-zero case, which would
    # otherwise yield 0/0 = NaN for every token.
    norm = np.linalg.norm(attributions)
    word_attributions = attributions / norm if norm > 0 else attributions

    vis_data_records.append(visualization.VisualizationDataRecord(
                            word_attributions,            # word_attributions
                            pred,                         # pred_prob
                            pred_ind,                     # pred_class
                            label,                        # true_class
                            "label",                      # attr_class  NOTE(review): pred_ind is the class actually attributed
                            attributions.sum(),           # attr_score (raw, un-normalised)
                            tokens[:len(attributions)],   # raw_input_ids
                            delta))                       # convergence_score

In [None]:
# Held-out data; a lone CSV passed as data_files is exposed as the 'train' split.
TEST_CSV = '../data/test2.csv'
test_dataset = load_dataset(path="csv", data_files=TEST_CSV)

In [None]:
def tokenize(batch):
    """Tokenise a batch's ``text`` into [CLS]...[SEP] sequences padded/truncated to 512 ids."""
    return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=512, add_special_tokens=True)

# batch_size=None passes each split to `tokenize` as one batch. Note that
# len(test_dataset) is the number of *splits* in the DatasetDict (here 1), not
# the number of rows, so it must not be used as the batch size.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=None)
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
bert_model_wrapper = BertModelWrapper(model)

# Token ids are discrete, so Integrated Gradients is taken with respect to the
# output of BERT's embedding layer instead.
embedding_layer = bert_model_wrapper.model.bert.embeddings
lig = LayerIntegratedGradients(bert_model_wrapper, embedding_layer)

# Collects one visualization record per explained sample.
vis_data_records_ig = []

In [None]:
# Moving the wrapper moves its `model` submodule; eval() puts both in inference
# mode (e.g. disables dropout), and stale gradients are cleared.
bert_model_wrapper.to(device).eval()
bert_model_wrapper.zero_grad()

In [None]:
# Pick one test example to explain. Seeded so the choice (and every downstream
# explanation) survives Restart & Run All; set SEED = None for a fresh random pick.
SEED = 42
rng = np.random.default_rng(SEED)

test_split = test_dataset['train']
r = int(rng.integers(len(test_split)))
label = test_split['label'][r].item()
sentence = test_split['text'][r]

In [None]:
# Display the sampled text that is explained below.
sentence

In [None]:
MAX_LEN = 512  # longest sequence the BERT encoder accepts

input_ids, ref_input_ids = input_ref(bert_model_wrapper, sentence)

if input_ids.shape[1] > MAX_LEN:
    # Truncate but keep the final [SEP]: slicing off the tail would drop it and
    # feed BERT a sequence without its end marker. Keep the first MAX_LEN-1 ids
    # plus the last id (applied to the baseline too so the shapes stay equal).
    input_ids = torch.cat([input_ids[:, :MAX_LEN - 1], input_ids[:, -1:]], dim=1)
    ref_input_ids = torch.cat([ref_input_ids[:, :MAX_LEN - 1], ref_input_ids[:, -1:]], dim=1)

# Plain prediction — no autograd graph needed here.
with torch.no_grad():
    probs = bert_model_wrapper(input_ids)
pred = probs[0, 1].item()                 # probability of class 1
pred_ind = int(probs[0].argmax().item())  # predicted class (rounding p(class 1) in the binary case)

In [None]:
# Probability the model assigns to class 1 for the sampled sentence.
pred

In [None]:
# Integrated Gradients from the reference baseline to the real input, explaining
# the predicted class. Interpolation steps are evaluated in chunks of IG_BATCH.
N_STEPS = 500
IG_BATCH = 32
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    target=pred_ind,
    n_steps=N_STEPS,
    internal_batch_size=IG_BATCH,
    return_convergence_delta=True,
)

In [None]:
print('pred: ', pred_ind, '(', '%.2f' % pred, ')', ', delta: ', abs(delta))

# Convert the (possibly truncated) ids back to word pieces for display.
# Re-running this cell appends another record to vis_data_records_ig.
token_ids = input_ids[0].tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)
add_attributions_to_visualizer(attributions, tokens, pred, pred_ind, label, delta, vis_data_records_ig)

In [None]:
# Render every Integrated Gradients record collected so far.
visualization.visualize_text(vis_data_records_ig)

In [None]:
def f(x):
    """Prediction function for SHAP: strings -> softmax class probabilities.

    Args:
        x: iterable of input strings (SHAP passes masked variants of the text).

    Returns:
        numpy array of shape (len(x), n_classes) with one probability row per string.
    """
    texts = [str(v) for v in x]
    enc = tokenizer(texts, padding='max_length', max_length=512, truncation=True, return_tensors='pt')
    # Use the notebook's `device` rather than hard-coding .cuda(), so this also
    # runs on CPU; the tokenizer's own attention mask replaces the `!= 0` guess.
    input_ids = enc['input_ids'].to(device)
    attention_mask = enc['attention_mask'].to(device)
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask)[0]
    # torch.softmax is numerically stable, unlike exp()/sum() on raw logits.
    return torch.softmax(logits, dim=-1).cpu().numpy()

In [None]:
# The tokenizer doubles as SHAP's text masker: it decides how the sentence is
# split into features that get hidden when probing `f`.
explainer = shap.Explainer(f, masker=tokenizer)
shap_values = explainer([sentence])

In [None]:
# Text plot of the SHAP values at output index 1 (presumably the positive class).
shap.plots.text(shap_values[:,:,1])