# 1. Imports

In [44]:
import numpy as np
import json
from tqdm.notebook import tqdm

import torch

from datasets import load_dataset, load_metric
from transformers import BertTokenizer, DataCollatorWithPadding, BertForSequenceClassification, BertConfig, \
    TrainingArguments, Trainer

from captum.attr import LayerIntegratedGradients

In [3]:
# Notebook-wide configuration: output folder for checkpoints, random seed and compute device.
repo_name = "ft-sentiment"
SEED = 42

# Seed torch's global RNG so that randomly initialised weights (e.g. the new classification
# head created when the pretrained encoder is loaded) are the same on every Restart & Run All.
torch.manual_seed(SEED)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# 2. Preprocess data

In [4]:
# Load the IMDB dataset; the returned object is indexed by split name ("train", "test") below.
imdb = load_dataset("imdb")

Reusing dataset imdb (/home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Work on small, reproducibly shuffled subsets: the first 300 shuffled training examples
# and the first 30 shuffled test examples. One example of each is printed as a sanity check.
train_dataset = imdb["train"].shuffle(seed=SEED).select(range(300))
test_dataset = imdb["test"].shuffle(seed=SEED).select(range(30))
for first_example in (train_dataset[0], test_dataset[0]):
    print(first_example)

Loading cached shuffled indices for dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-8a9e43a6ac4acdff.arrow
Loading cached shuffled indices for dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2eff9f118d84c6fe.arrow


{'text': 'There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier\'s plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it\'s the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...', 'label': 1}
{'text': "<br /><br />When I unsuspectedly rented A Thousand Acres, I thought I was in for an entertaining King Lear story and of course Michelle Pfeiffer was in it, so what could go wrong?<br /><br />Very quickly, 

In [6]:
# Tokenizer for the 'bert-base-uncased' checkpoint, the same checkpoint the model is initialised from.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [7]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` for the model.

    Args:
        examples: Mapping with a ``"text"`` entry; it is called through
            ``Dataset.map(..., batched=True)`` in this notebook.

    Returns:
        Whatever the module-level ``tokenizer`` returns for these texts when called
        with ``truncation=True``.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

# Tokenize both subsets in batches. The original columns ("text", "label") are still read
# from the tokenized datasets later in the notebook, next to the new "input_ids".
tokenized_train, tokenized_test = (
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, test_dataset)
)

Loading cached processed dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2d9e896ec1217be0.arrow
Loading cached processed dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-d70b3b7323b283e5.arrow


In [8]:
# Use a data collator to convert our samples to PyTorch tensors and batch them with the correct amount of padding
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 3. Training the model

In [None]:
# Pretrained 'bert-base-uncased' encoder with a sequence-classification head of two outputs
# (presumably the negative/positive classes of the IMDB `label` field — confirm against the dataset card).
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

In [None]:
# Define the evaluation metrics
def compute_metrics(eval_pred):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_pred: Pair ``(logits, labels)``; the predicted class of each row is the
            argmax of ``logits`` over the last axis.

    Returns:
        dict with the keys ``"accuracy"`` and ``"f1"``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    scores = {}
    for metric_name in ("accuracy", "f1"):
        metric = load_metric(metric_name)
        scores[metric_name] = metric.compute(predictions=predictions, references=labels)[metric_name]
    return scores

In [None]:
# Define a new Trainer with all the objects we constructed so far
training_args = TrainingArguments(
    # Checkpoints are written below this folder; the interpretation section reloads one of them.
    output_dir=repo_name,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # A single pass over the training subset.
    num_train_epochs=1,
    weight_decay=0.01,
    # Save a checkpoint at the end of every epoch.
    save_strategy="epoch", 
    # Keep everything local; nothing is uploaded to the model hub.
    push_to_hub=False,
)

# The Trainer ties together the model, the hyper-parameters, both tokenized subsets,
# the padding collator and the metric function defined in the previous cells.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Train the model: fine-tune on tokenized_train with the arguments configured in training_args
trainer.train()

In [None]:
# Compute the evaluation metrics (accuracy and F1 from compute_metrics) on tokenized_test
trainer.evaluate()

# 4. Interpreting

In [9]:
# Checkpoint folder inside the training output directory (checkpoints are saved per epoch).
# NOTE(review): the step number 19 presumably equals ceil(300 training examples / batch size 16)
# for the single epoch on one device — update it if the subset size, batch size, number of
# epochs or device count changes.
model_folder = repo_name + "/checkpoint-19"

In [10]:
# Reload the fine-tuned weights from the checkpoint folder, move the model to the selected
# device and switch it to evaluation mode.
model = BertForSequenceClassification.from_pretrained(model_folder).to(device).eval()
# Start from clean gradients before any attribution is computed.
model.zero_grad()

In [11]:
# Reload the tokenizer from the same checkpoint folder the model was loaded from.
tokenizer = BertTokenizer.from_pretrained(model_folder)

In [12]:
# Special token ids used to build the baseline (reference) input for the attributions.
ref_token_id = tokenizer.pad_token_id # Padding token; fills every inner position of the baseline input
sep_token_id = tokenizer.sep_token_id # Separator token; kept as the last position of the baseline input
cls_token_id = tokenizer.cls_token_id # Classification token; kept as the first position of the baseline input

In [21]:
def predict(inputs):
    """Run the module-level ``model`` on ``inputs`` and return the first element of its output.

    For the sequence-classification model used in this notebook that first element is
    treated as the logits (a softmax over ``dim=1`` is applied to it by the callers).
    """
    outputs = model(inputs)
    return outputs[0]

def custom_forward(inputs):
    """Forward function handed to Captum: probability of class 0 for every input row.

    Integrated Gradients evaluates the model on a whole batch of interpolated inputs in
    one call and needs one scalar output per row to back-propagate through. Returning
    only ``[0][0]`` (the first row) would leave every other interpolation step with a
    zero gradient, so the result would not be a real path integral.

    Args:
        inputs: Batch of token ids, forwarded unchanged to ``predict``.

    Returns:
        1-D tensor of shape ``(batch,)`` holding the softmax probability of class 0 for
        each row. For a batch of one this is the same ``(1,)`` tensor as before.
    """
    logits = predict(inputs)
    return torch.softmax(logits, dim=1)[:, 0]

def summarize_attributions(attributions):
    """Collapse attributions over the last axis and scale them to unit norm.

    Args:
        attributions: Attribution tensor; the last axis is summed away and a leading
            axis of size 1 (if present) is dropped.

    Returns:
        The summed attributions divided by their norm (``torch.norm``).
    """
    per_position = attributions.sum(dim=-1).squeeze(0)
    return per_position / torch.norm(per_position)

In [22]:
# Layer Integrated Gradients: attributes the output of custom_forward to the output of
# the model's embedding layer (model.bert.embeddings).
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [33]:
# Peek at the encoded token ids of the first training example. The tensor is built here
# explicitly so the cell also works on a fresh kernel (Restart & Run All) instead of
# relying on an `input_ids` variable left over from the attribution loop further down.
input_ids = torch.tensor([tokenized_train[0]["input_ids"]], dtype=torch.long, device=device)
input_ids

tensor([[  101,  2045,  2003,  2053,  7189,  2012,  2035,  2090,  3481,  3771,
          1998,  6337,  2099,  2021,  1996,  2755,  2008,  2119,  2024,  2610,
          2186,  2055,  6355,  6997,  1012,  6337,  2099,  3504, 15594,  2100,
          1010,  3481,  3771,  3504,  4438,  1012,  6337,  2099, 14811,  2024,
          3243,  3722,  1012,  3481,  3771,  1005,  1055,  5436,  2024,  2521,
          2062,  8552,  1012,  1012,  1012,  3481,  3771,  3504,  2062,  2066,
          3539,  8343,  1010,  2065,  2057,  2031,  2000,  3962, 12319,  1012,
          1012,  1012,  1996,  2364,  2839,  2003,  5410,  1998,  6881,  2080,
          1010,  2021,  2031,  1000, 17936,  6767,  7054,  3401,  1000,  1012,
          2111,  2066,  2000, 12826,  1010,  2000,  3648,  1010,  2000, 16157,
          1012,  2129,  2055,  2074,  9107,  1029,  6057,  2518,  2205,  1010,
          2111,  3015,  3481,  3771,  3504,  2137,  2021,  1010,  2006,  1996,
          2060,  2192,  1010,  9177,  2027,  9544,  

In [48]:
# Collects one record per example (text, label, tokens, attribution scores) for the JSON export.
export_data = []

In [None]:
# Build one export record per training example: the raw text, its gold label, the tokens
# and the per-token attribution scores together with the model's predicted label.
export_data = []  # reset here so that re-running this cell never appends duplicate records
for data in tqdm(tokenized_train):
    token_ids = data["input_ids"]  # produced with truncation=True; first/last ids are [CLS]/[SEP]

    # torch.tensor builds an int64 tensor directly on `device`. The legacy torch.Tensor
    # constructor rejects non-CPU devices and would round-trip the ids through float32.
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)

    # Baseline of the same length as the input: [CLS] + [PAD] * (n - 2) + [SEP].
    ref_input_ids = torch.tensor(
        [[cls_token_id] + [ref_token_id] * (len(token_ids) - 2) + [sep_token_id]],
        dtype=torch.long,
        device=device,
    )

    # `delta` (the convergence error of the approximation) is not exported; it is kept
    # available for interactive inspection.
    attributions, delta = lig.attribute(inputs=input_ids,
                                        baselines=ref_input_ids,
                                        return_convergence_delta=True)
    attributions_sum = summarize_attributions(attributions)

    # The predicted class only needs a plain forward pass, so skip autograd bookkeeping.
    # argmax over the logits equals argmax over their softmax.
    with torch.no_grad():
        predicted_label = int(torch.argmax(predict(input_ids), dim=1))

    record = {
        "text": data["text"],
        "labels": data["label"],
        # Derive the tokens from the (possibly truncated) ids so that `words` always lines
        # up one-to-one with `scores`; re-tokenizing the raw text is not truncated and
        # yields more words than scores for long reviews.
        "words": tokenizer.convert_ids_to_tokens(token_ids[1:-1]),
        "annotations": [
            {
                "label": predicted_label,
                # Drop the [CLS] and [SEP] positions so only real tokens are scored.
                "scores": attributions_sum[1:-1].tolist(),
            }
        ],
    }
    export_data.append(record)

  0%|          | 0/5 [00:00<?, ?it/s]

In [46]:
# Show only the first record: displaying the whole list prints every token and score of
# every example and bloats the saved notebook.
export_data[:1]

[{'words': ['there',
   'is',
   'no',
   'relation',
   'at',
   'all',
   'between',
   'fort',
   '##ier',
   'and',
   'profile',
   '##r',
   'but',
   'the',
   'fact',
   'that',
   'both',
   'are',
   'police',
   'series',
   'about',
   'violent',
   'crimes',
   '.',
   'profile',
   '##r',
   'looks',
   'crisp',
   '##y',
   ',',
   'fort',
   '##ier',
   'looks',
   'classic',
   '.',
   'profile',
   '##r',
   'plots',
   'are',
   'quite',
   'simple',
   '.',
   'fort',
   '##ier',
   "'",
   's',
   'plot',
   'are',
   'far',
   'more',
   'complicated',
   '.',
   '.',
   '.',
   'fort',
   '##ier',
   'looks',
   'more',
   'like',
   'prime',
   'suspect',
   ',',
   'if',
   'we',
   'have',
   'to',
   'spot',
   'similarities',
   '.',
   '.',
   '.',
   'the',
   'main',
   'character',
   'is',
   'weak',
   'and',
   'weird',
   '##o',
   ',',
   'but',
   'have',
   '"',
   'clair',
   '##vo',
   '##yan',
   '##ce',
   '"',
   '.',
   'people',
   'like',


In [None]:
# Write all records to import.json. The context manager guarantees the file is flushed
# and closed even if serialization fails part-way through.
with open('import.json', 'w') as export_file:
    json.dump({'data': export_data}, export_file, indent=2)