In [None]:
import functools
import argparse
import numpy as np 
import pandas as pd
import seaborn as sns
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter

from transformers import (
    AdamW,
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForSequenceClassification,
    AlbertForSequenceClassification,
    MT5ForConditionalGeneration,
    AutoModelWithLMHead,
    AutoTokenizer,
    AlbertTokenizer,
    T5Tokenizer,
    PretrainedConfig,
    PreTrainedTokenizer,
)
from transformers.optimization import get_linear_schedule_with_warmup

import shap
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients, GradientShap

In [None]:
# Checkpoint under inspection. Other checkpoints that were tried:
#   'KB/bert-base-swedish-cased'
#   'xlm-roberta-base'
#   'KB/electra-base-swedish-cased-discriminator'
kb_bert = 'bert-base-multilingual-cased'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(kb_bert)
model = AutoModelForSequenceClassification.from_pretrained(kb_bert)

# NOTE(review): no fine-tuned state dict is loaded here, so the classification
# head is presumably freshly initialised (a base pre-training checkpoint has
# none) — confirm before interpreting attributions. Fine-tuned checkpoints:
#   ../models/KB/bert-base-swedish-cased.pt
#   ../models/KB/electra-base-swedish-cased-discriminator_ft.pt
#   ../models/xlm-roberta-base_ft.pt
#   ../models/bert-base-multilingual-cased_ft.pt
# loaded via model.load_state_dict(torch.load(<path>)).
model = model.to(device)

In [None]:
test_dataset = load_dataset("csv", data_files='../models/test2.csv')
test_ind = test_dataset

def tokenize(batch):
    """Tokenize a batch of rows' 'text' field, hard-capped at 512 tokens.

    `truncation=True` is explicit so the 512 limit is enforced without the
    "truncation was not explicitly activated" warning.
    """
    return tokenizer(batch['text'], max_length = 512, truncation=True, add_special_tokens = True)

# load_dataset() without split= returns a DatasetDict, so len(test_dataset) is the
# number of splits (1), not the number of rows. batch_size=None maps each split as
# a single full batch, which is what the original length-based size intended.
test_dataset = test_dataset.map(tokenize, batched=True, batch_size=None)
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])

In [None]:
class BertModelWrapper(nn.Module):
    """
    Adapter exposing a Hugging Face sequence classifier as a plain
    ``input_ids -> class probabilities`` module, which is what Captum attributes.

    Parameters
    ----------
    model : nn.Module
        A model whose forward accepts ``input_ids=`` (and optionally
        ``attention_mask=``) and returns an object with a ``.logits`` attribute.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        """
        Return softmax probabilities over dimension 1 of the model's logits.

        ``attention_mask`` is optional so padded batches can be explained via
        Captum's ``additional_forward_args``; when omitted, the wrapped model is
        called with ``input_ids`` only, exactly as before.
        """
        if attention_mask is None:
            outputs = self.model(input_ids=input_ids)
        else:
            outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return nn.functional.softmax(outputs.logits, dim=1)

def tokens2words(tokens, seq, token_prefix="##"):
    """
    Aggregate ``seq`` to word level using the sub-word structure of ``tokens``.

    A token starting with ``token_prefix`` is merged into the preceding word:
    string items are concatenated (with ``token_prefix`` removed), numeric items
    are summed. A continuation token with no preceding word starts a new word.

    Parameters
    ----------
    tokens : sequence of str
        Sub-word tokens, aligned one-to-one with ``seq``.
    seq : sequence
        Either the tokens themselves (to rebuild words) or per-token numbers
        (Python numbers, or numpy/torch scalars) to be summed per word.
    token_prefix : str
        Marker of a continuation sub-word (WordPiece uses "##").

    Returns
    -------
    list of str when ``seq`` holds strings; otherwise a 1-D ``torch.Tensor`` of
    the default float dtype on the module-level ``device``. Empty input returns
    an empty list.
    """
    words = []
    for token, x in zip(tokens, seq):
        if isinstance(x, str):
            value = x
        else:
            # numpy/torch scalars -> Python numbers, so sums stay plain floats
            # and the final tensor dtype does not depend on whether merges happened.
            value = x.item() if hasattr(x, "item") else x

        if token.startswith(token_prefix) and words:
            if isinstance(value, str):
                value = value.replace(token_prefix, "")
            words[-1] += value
        else:
            words.append(value)

    if not words or isinstance(words[-1], str):
        return words
    return torch.tensor(words, device=device)

def add_attributions_to_visualizer(attributions, pred, pred_ind, label, tokens, delta, vis_data_records):
    """
    Wrap one explained example in a Captum ``VisualizationDataRecord`` and
    append it to ``vis_data_records``.

    Parameters
    ----------
    attributions : array-like
        One score per entry of ``tokens``. It is L2-normalised for colouring;
        the raw sum is reported as the attribution score.
    pred : float
        Probability of the predicted class.
    pred_ind : int
        Predicted class; also passed as the attribution class.
    label : int
        True class.
    tokens : list of str
        Text units shown in the visualisation, aligned with ``attributions``.
    delta : float
        Convergence delta reported by Integrated Gradients.
    vis_data_records : list
        Appended to in place.
    """
    norm = np.linalg.norm(attributions)
    # An all-zero attribution vector would otherwise become NaN (0 / 0).
    word_attributions = attributions / norm if norm > 0 else attributions
    vis_data_records.append(viz.VisualizationDataRecord(
                            word_attributions,
                            pred,
                            pred_ind,
                            label,
                            pred_ind,
                            attributions.sum(),
                            tokens,
                            delta))

def input_ref(sentence, ref_token_id=None):
    """
    Build the model input and an Integrated-Gradients baseline for it.

    Parameters
    ----------
    sentence : tensor or nested list of token ids
        Indexed as [batch, position]; the baseline assignment below needs two
        dimensions.
    ref_token_id : int, optional
        Id written into the baseline. Defaults to ``tokenizer.pad_token_id``.

    Returns
    -------
    (input_ids, baseline)
        ``input_ids`` on ``device``. ``baseline`` is a copy in which every
        position except the first and last is replaced by ``ref_token_id``, so
        the boundary tokens (presumably the tokenizer's special tokens) are kept.
    """
    # as_tensor avoids the copy-construct warning torch.tensor() emits for tensor input.
    input_ids = torch.as_tensor(sentence, device=device)

    if ref_token_id is None:
        ref_token_id = tokenizer.pad_token_id  # token used to fill the reference input
    baseline = input_ids.clone()
    baseline[:, 1:-1] = ref_token_id
    return input_ids, baseline


In [None]:
# Whole-split column views; the per-example cell further down overwrites them.
input_x = test_dataset['train']
input_text, attention_mask, label, input_ids = (
    input_x[column] for column in ('text', 'attention_mask', 'label', 'input_ids')
)

In [None]:
bert_model_wrapper = BertModelWrapper(model)

# Attribute with respect to the embedding layer. `base_model` resolves to the
# architecture-specific backbone (`.bert`, `.roberta`, `.electra`, ...), so this
# works for BERT, XLM-R and ELECTRA checkpoints alike; a hard-coded `.roberta`
# raises AttributeError for 'bert-base-multilingual-cased'.
lig = LayerIntegratedGradients(bert_model_wrapper, bert_model_wrapper.model.base_model.embeddings)
# Accumulates a couple of samples for visualization purposes.
vis_data_records_ig = []

bert_model_wrapper.eval()
bert_model_wrapper.zero_grad()

In [None]:
n_steps = 500
torch.cuda.empty_cache()
# Examples inspected so far: 344, 673, 1895, 1537.
# For a random pick instead: idx = np.random.choice(len(test_dataset['train']))
idx = 1537

input_x = test_dataset['train']
input_text = input_x['text'][idx]
label = input_x['label'][idx]
input_ids = input_x['input_ids'][idx].unsqueeze(0)
attention_mask = input_x['attention_mask'][idx].unsqueeze(0)

input_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
input_words = tokens2words(input_tokens, input_tokens)

input_ids, baseline = input_ref(input_ids)
# Inference only: skip building an autograd graph for the prediction, which
# would otherwise keep activations alive for as long as `pred` is referenced.
with torch.no_grad():
    pred = bert_model_wrapper(input_ids)
pred_label = pred.argmax()
pred_p = pred[0, pred_label]

In [None]:
# Integrated gradients w.r.t. the embedding layer, explaining the predicted class.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=baseline,
    target=pred_label,
    n_steps=n_steps,
    internal_batch_size=16,
    return_convergence_delta=True,
)

In [None]:
# Reduce the layer attributions to one score per sub-word token: sum over
# dim 2 (presumably the hidden axis of the embedding output), drop the batch
# dim of size 1, and move to a standalone numpy array.
att = (
    attributions.sum(dim=2)
    .squeeze(0)
    .detach()
    .cpu()
    .clone()
    .numpy()
)

# The same scores summed per whole word ('##' pieces merged).
phi_words = tokens2words(input_tokens, att)


In [None]:
print('pred: ', pred_label.item(), '(', '%.2f' % pred_p.item(), ')', ', delta: ', abs(delta.item()))

# Visualize at word level. `input_words` has one entry per whole word, so it must
# be paired with the word-level scores `phi_words`. The per-token `att` is longer
# whenever a word was split into '##' pieces, which would misalign the colours.
add_attributions_to_visualizer(phi_words.cpu().numpy(), pred_p, pred_label, label, input_words, delta, vis_data_records_ig)

In [None]:
# visualize_text displays the HTML itself and also returns it; the trailing
# semicolon stops Jupyter from rendering that return value a second time.
viz.visualize_text(vis_data_records_ig);

In [None]:
input_x = test_dataset["train"]
n_steps = 500
# Class whose score is explained for every example. 1 is presumably the
# positive class — confirm against the label encoding of test2.csv.
ig_target = 1
torch.cuda.empty_cache()
ig_val = []

for i in tqdm(range(len(input_x))):
    # Row access: input_x['input_ids'][i] materializes the whole column on every
    # iteration, which makes the loop quadratic in the dataset size.
    input_ids = input_x[i]['input_ids'].unsqueeze(0)
    input_ids, baseline = input_ref(input_ids)

    attributions, delta = lig.attribute(
        inputs=input_ids,
        baselines=baseline,
        n_steps=n_steps,
        internal_batch_size=8,
        return_convergence_delta=True,
        target=ig_target,
    )
    # Keep results on the CPU: accumulating CUDA tensors for every example
    # grows GPU memory and ties the pickled results to a CUDA device.
    ig_val.append(attributions.detach().cpu())
    torch.cuda.empty_cache()

In [None]:
import os
import pickle

ig_path = '../data/' + kb_bert + '_ig.pickle'
# Hub ids such as 'KB/bert-base-swedish-cased' contain '/', which becomes a
# sub-directory in this path; create it so open() does not fail.
os.makedirs(os.path.dirname(ig_path), exist_ok=True)

# Store data (serialize)
with open(ig_path, 'wb') as handle:
    pickle.dump(ig_val, handle, protocol=pickle.HIGHEST_PROTOCOL)

# Load data (deserialize) as a round-trip check. Only unpickle files this
# notebook wrote itself: pickle.load can execute arbitrary code.
with open(ig_path, 'rb') as handle:
    unserialized_data = pickle.load(handle)
assert len(unserialized_data) == len(ig_val)


In [None]:
# Shape of the first example's attribution tensor (presumably
# (1, seq_len, hidden_size) for an embedding-layer attribution — confirm).
ig_val[0].shape

In [None]:
# Aggregate word-level IG attributions over the whole split:
# features[word] = (sum of attributions, number of occurrences).
features = {}
for i in range(len(input_x)):
    # Row access avoids materializing the full 'input_ids' column each iteration.
    input_ids = input_x[i]['input_ids']
    input_tokens = tokenizer.convert_ids_to_tokens(input_ids)
    input_words = tokens2words(input_tokens, input_tokens)
    ig_val_sum = ig_val[i].sum(dim=2).squeeze(0)
    ig_val_sum = ig_val_sum.detach().cpu().clone().numpy()
    phi_words = tokens2words(input_tokens, ig_val_sum)
    # One bulk conversion to Python floats instead of a .item() device sync per word.
    for word, phi in zip(input_words, phi_words.tolist()):
        total, count = features.get(word, (0.0, 0))
        features[word] = (total + phi, count + 1)


In [None]:
# One row per word: total attribution, number of occurrences, and mean per
# occurrence. Sorted by the total, as before. The mean is shown alongside because
# totals are dominated by frequent words. Jupyter truncates long frames, so this
# no longer dumps every vocabulary entry into the output.
features_df = (
    pd.DataFrame.from_dict(features, orient='index', columns=['attr_sum', 'count'])
    .assign(attr_mean=lambda d: d['attr_sum'] / d['count'])
    .sort_values('attr_sum', ascending=False)
)
features_df