# 1. Imports

In [None]:
import numpy as np

import torch

from datasets import load_dataset, load_metric
from transformers import BertTokenizer, DataCollatorWithPadding, BertForSequenceClassification, BertConfig, \
    TrainingArguments, Trainer

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [None]:
# Directory name used as the Trainer's output_dir and as the prefix of the checkpoint path loaded later.
repo_name = "ft-sentiment"
# Seed used when shuffling the dataset splits so the selected subsets are reproducible.
SEED = 42
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Bare expression: shows the selected device as the cell output in a notebook.
device

# 2. Preprocess data

In [None]:
# Load the "imdb" dataset; its "train" and "test" splits are subsampled below.
imdb = load_dataset("imdb")

In [None]:
# Work on small, reproducibly shuffled subsets (300 train / 30 test examples)
# so the demo runs quickly. `select` takes any iterable of indices, so a
# `range` is passed directly instead of materialising intermediate lists.
train_dataset = imdb["train"].shuffle(seed=SEED).select(range(300))
test_dataset = imdb["test"].shuffle(seed=SEED).select(range(30))
# Print one example from each subset as a sanity check.
print(train_dataset[0])
print(test_dataset[0])

In [None]:
# Tokenizer for the 'bert-base-uncased' checkpoint, the same checkpoint the model is initialised from below.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [None]:
# Turn the raw review text into model inputs
def preprocess_function(examples):
    """Tokenize the "text" field of a batch of examples.

    Sequences longer than the tokenizer's maximum length are truncated.
    Returns whatever the tokenizer returns for the batch.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

# Tokenize both subsets batch-wise; the tokenizer's outputs are added as new columns
# (e.g. the "input_ids" column read later by get_attribution_score).
tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_test = test_dataset.map(preprocess_function, batched=True)

In [None]:
# Use a data collator to convert our samples to PyTorch tensors and pad each batch
# to the length of its longest sequence (dynamic padding)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 3. Training the model

In [None]:
# Pretrained 'bert-base-uncased' weights with a sequence-classification head configured for 2 labels.
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

In [None]:
# Evaluation metrics reported by the Trainer
def compute_metrics(eval_pred):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_pred: a ``(logits, labels)`` pair as handed over by the Trainer.

    Returns:
        dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    # Load both metric objects up front, keyed by the name they report under.
    metrics = {name: load_metric(name) for name in ("accuracy", "f1")}

    logits, labels = eval_pred
    # The predicted class is the index of the largest logit.
    predictions = np.argmax(logits, axis=-1)

    return {
        name: metric.compute(predictions=predictions, references=labels)[name]
        for name, metric in metrics.items()
    }

In [None]:
# Define the training configuration and a Trainer that wires together all the
# objects constructed so far (model, datasets, tokenizer, collator, metrics)
training_args = TrainingArguments(
    output_dir=repo_name,  # checkpoints are written below this directory
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",  # save a checkpoint at the end of every epoch
    push_to_hub=False,  # keep everything local
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model; with save_strategy="epoch" a checkpoint is saved under output_dir after the epoch
trainer.train()

In [None]:
# Compute the evaluation metrics (accuracy and F1 via compute_metrics) on the eval dataset
trainer.evaluate()

# 4. Interpreting

In [None]:
# Checkpoint written at the end of the single training epoch. The step number is
# hard-coded: 300 training examples at a per-device batch size of 16 give
# ceil(300 / 16) = 19 optimisation steps (assuming a single device and no gradient
# accumulation). NOTE(review): update the suffix if the subset size, batch size,
# epoch count or device count changes.
model_folder = repo_name + "/checkpoint-19"

In [None]:
# Reload the classifier from the saved checkpoint folder.
model = BertForSequenceClassification.from_pretrained(model_folder)
# Move the weights to the selected device (GPU if available).
model.to(device)
# Switch to evaluation mode for inference/attribution.
model.eval()
# Clear any gradients stored on the parameters before computing attributions.
model.zero_grad()

In [None]:
# Reload the tokenizer from the same checkpoint folder as the model
# (presumably saved there because the tokenizer was passed to the Trainer - confirm the files exist).
tokenizer = BertTokenizer.from_pretrained(model_folder)

In [None]:
ref_token_id = tokenizer.pad_token_id # Padding token; used to fill the reference (baseline) sequence and to pad inputs
sep_token_id = tokenizer.sep_token_id # Separator token; placed at the end of the text
cls_token_id = tokenizer.cls_token_id # Classification token; placed at the start of the text

In [None]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    """Build the input ids and the matching reference (baseline) ids for one text.

    The text is encoded without special tokens; the input sequence is then
    wrapped as ``cls + tokens + sep``. The reference sequence has the same
    length and the same ``cls``/``sep`` ends, with every text token replaced
    by ``ref_token_id``.

    Returns:
        ``(input_ids, ref_input_ids, n_tokens)`` where both id tensors have a
        leading batch dimension of 1 and live on ``device``, and ``n_tokens``
        is the number of text tokens (special tokens excluded).
    """
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    n_tokens = len(token_ids)

    # Real sequence and its same-length baseline.
    real_sequence = [cls_token_id, *token_ids, sep_token_id]
    baseline_sequence = [cls_token_id, *([ref_token_id] * n_tokens), sep_token_id]

    input_tensor = torch.tensor([real_sequence], device=device)
    baseline_tensor = torch.tensor([baseline_sequence], device=device)
    return input_tensor, baseline_tensor, n_tokens

In [None]:
def predict(inputs):
    """Run the global ``model`` on ``inputs`` and return the first element of its output.

    ``custom_forward`` applies a softmax to this value, i.e. it is treated as
    the classification logits.
    """
    outputs = model(inputs)
    return outputs[0]

def custom_forward(inputs):
    """Forward function used for attribution.

    Returns the softmax probability of class index 0 for every example,
    shaped ``(batch, 1)``.
    """
    logits = predict(inputs)
    probabilities = torch.softmax(logits, dim=1)
    class_zero = probabilities[:, 0]
    # Keep a trailing dimension so the output is a column, one row per example.
    return class_zero.unsqueeze(-1)

def summarize_attributions(attributions):
    """Collapse attributions over the last dimension and scale them to unit norm.

    The last dimension is summed away, a leading dimension of size 1 (if any)
    is squeezed out, and the result is divided by its overall norm (taken over
    all remaining elements).
    """
    summed = torch.sum(attributions, dim=-1)
    summed = torch.squeeze(summed, 0)
    scale = torch.norm(summed)
    return summed / scale

In [None]:
# Layer Integrated Gradients: attributes the output of custom_forward to the BERT embedding layer.
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [None]:
def get_attribution_score(data):
    """Compute per-token attribution scores for a batch of tokenized examples.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        data: a batch with a ``"text"`` column (raw strings) and an
            ``"input_ids"`` column (one list of token ids per example; the
            first and last id of each row are assumed to be the CLS and SEP
            tokens added by the tokenizer).

    Returns:
        dict with two new columns:
          * ``'words'``: the tokens of every text, from ``tokenizer.tokenize``.
          * ``'scores'``: one list of attribution scores per example, covering
            the positions between the first and last token of the row. The
            list is zero-padded up to ``len(words)`` when it is shorter
            (e.g. when the ids were truncated but the token list was not).
    """
    words = [tokenizer.tokenize(text) for text in data["text"]]
    lengths = [len(row) for row in data["input_ids"]]
    if not lengths:
        # Nothing to attribute for an empty batch.
        return {'words': [], 'scores': []}
    max_len = max(lengths)

    # Right-pad every row to the batch maximum so the batch is rectangular.
    padded_ids = [
        list(row) + [ref_token_id] * (max_len - len(row))
        for row in data["input_ids"]
    ]
    # Baseline: keep the CLS/SEP positions and replace everything else
    # (including the padding tail) with the reference token.
    baseline_ids = [
        [cls_token_id]
        + [ref_token_id] * (length - 2)
        + [sep_token_id]
        + [ref_token_id] * (max_len - length)
        for length in lengths
    ]

    # Use the torch.tensor factory: the legacy torch.Tensor constructor
    # rejects non-CPU devices and would create a float tensor first.
    input_ids = torch.tensor(padded_ids, dtype=torch.long, device=device)
    ref_input_ids = torch.tensor(baseline_ids, dtype=torch.long, device=device)

    attributions, _ = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
    attributions_sum = summarize_attributions(attributions)
    # summarize_attributions squeezes away a batch dimension of size 1;
    # restore it so the per-example loop below always iterates over rows.
    if attributions_sum.dim() == 1:
        attributions_sum = attributions_sum.unsqueeze(0)

    scores = []
    for length, row, tokens in zip(lengths, attributions_sum, words):
        # Drop the first and last (special-token) positions and any padding.
        token_scores = row[1:length - 1].tolist()
        if len(token_scores) < len(tokens):
            token_scores.extend([0] * (len(tokens) - len(token_scores)))
        scores.append(token_scores)

    return {
        'words': words,
        'scores': scores,
    }

In [None]:
# Attribute every example of the tokenized *test* subset (despite the variable name).
# batch_size=1 means each review is attributed on its own, so no padding is involved.
train_attr = tokenized_test.map(get_attribution_score, batched=True, batch_size=1)

In [None]:
# Inspect the first example, which now also carries the 'words' and 'scores' columns.
train_attr[0]