# 1. Imports

In [1]:
from pathlib import Path

import numpy as np

import torch

from datasets import load_dataset, load_metric
from transformers import BertTokenizer, DataCollatorWithPadding, BertForSequenceClassification, BertConfig, \
    TrainingArguments, Trainer

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

In [2]:
repo_name = "ft-sentiment"
SEED = 42
# Seed the RNGs before any model is built: the classification head created by
# `from_pretrained(..., num_labels=2)` is randomly initialised in a cell that runs
# before the Trainer is constructed, so the Trainer's own seeding does not cover it.
np.random.seed(SEED)
torch.manual_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

# 2. Preprocess data

In [3]:
# Download (or load from the local cache) the IMDB reviews dataset with all of its splits
imdb = load_dataset(path="imdb")

Reusing dataset imdb (/home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
# Small, seeded random subsets keep fine-tuning and attribution quick.
N_TRAIN_SAMPLES = 300
N_TEST_SAMPLES = 30
train_dataset = imdb["train"].shuffle(seed=SEED).select(list(range(N_TRAIN_SAMPLES)))
test_dataset = imdb["test"].shuffle(seed=SEED).select(list(range(N_TEST_SAMPLES)))
print(train_dataset[0])
print(test_dataset[0])

Loading cached shuffled indices for dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-8a9e43a6ac4acdff.arrow
Loading cached shuffled indices for dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-2eff9f118d84c6fe.arrow


{'text': 'There is no relation at all between Fortier and Profiler but the fact that both are police series about violent crimes. Profiler looks crispy, Fortier looks classic. Profiler plots are quite simple. Fortier\'s plot are far more complicated... Fortier looks more like Prime Suspect, if we have to spot similarities... The main character is weak and weirdo, but have "clairvoyance". People like to compare, to judge, to evaluate. How about just enjoying? Funny thing too, people writing Fortier looks American but, on the other hand, arguing they prefer American series (!!!). Maybe it\'s the language, or the spirit, but I think this series is more English than American. By the way, the actors are really good and funny. The acting is not superficial at all...', 'label': 1}
{'text': "<br /><br />When I unsuspectedly rented A Thousand Acres, I thought I was in for an entertaining King Lear story and of course Michelle Pfeiffer was in it, so what could go wrong?<br /><br />Very quickly, 

In [5]:
# WordPiece tokenizer matching the pre-trained checkpoint
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-uncased')

In [6]:
# Tokenize the raw review text into model inputs
def preprocess_function(examples):
    """Tokenize a batch of examples for ``Dataset.map(batched=True)``.

    Truncates each text to the tokenizer's maximum length and pads every sequence to the
    longest one in the batch handed over by ``map``.

    NOTE(review): because of this static padding, all ``input_ids`` of a ``map`` batch share
    one length. The attribution cells below rely on that (``torch.tensor(data["input_ids"])``),
    so ``DataCollatorWithPadding`` has little left to pad.
    """
    batch_texts = examples["text"]
    return tokenizer(batch_texts, padding=True, truncation=True)

tokenized_train, tokenized_test = (
    split.map(preprocess_function, batched=True) for split in (train_dataset, test_dataset)
)

Loading cached processed dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-91c464eab7729f65.arrow
Loading cached processed dataset at /home/akshen/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-c2eaed3ae2cad3fb.arrow


In [7]:
# Collator that stacks features into PyTorch tensors, padding each batch to its longest sequence
data_collator = DataCollatorWithPadding(tokenizer)

# 3. Training the model

In [None]:
# Pre-trained BERT encoder with a newly initialised 2-way classification head
model = BertForSequenceClassification.from_pretrained(pretrained_model_name_or_path='bert-base-uncased', num_labels=2)

In [None]:
# Define the evaluation metrics
def compute_metrics(eval_pred):
    """Compute accuracy and binary F1 (positive class = label 1) for ``Trainer.evaluate``.

    The metrics are computed directly with numpy. ``load_metric`` is deprecated (removed in
    ``datasets`` 3.0), and the previous version re-loaded both metric scripts on every call.

    Args:
        eval_pred: ``(logits, labels)`` pair as passed by ``Trainer``. The class scores are on
            the last axis of ``logits``.

    Returns:
        dict with float ``"accuracy"`` and ``"f1"``. F1 is 0.0 when there are no true or
        predicted positives, matching scikit-learn's default ``zero_division`` result.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    accuracy = float(np.mean(predictions == labels))

    true_pos = int(np.sum((predictions == 1) & (labels == 1)))
    false_pos = int(np.sum((predictions == 1) & (labels != 1)))
    false_neg = int(np.sum((predictions != 1) & (labels == 1)))
    # F1 = 2PR / (P + R) = 2TP / (2TP + FP + FN)
    denom = 2 * true_pos + false_pos + false_neg
    f1 = 2 * true_pos / denom if denom else 0.0
    return {"accuracy": accuracy, "f1": f1}

In [None]:
# Build the Trainer from the model, datasets, tokenizer, collator and metrics constructed so far
training_args = TrainingArguments(
    output_dir=repo_name,
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
    weight_decay=0.01,
    save_strategy="epoch",  # one checkpoint per epoch, written to f"{repo_name}/checkpoint-<step>"
    push_to_hub=False,
    seed=SEED,  # tie the Trainer's RNG to the notebook's SEED instead of its own default
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
# Train the model
# (fine-tunes `model` in place; with save_strategy="epoch" a checkpoint is written under `repo_name`)
trainer.train()

In [None]:
# Compute the evaluation metrics
# (runs `compute_metrics` on `tokenized_test`, the eval_dataset given to the Trainer)
trainer.evaluate()

# 4. Interpreting the model

In [14]:
# Use the newest checkpoint written by `trainer.train()` instead of hard-coding its step number:
# the step depends on dataset size, batch size and number of epochs. If no checkpoint is found,
# fall back to the previously hard-coded path.
checkpoint_dirs = [
    p for p in Path(repo_name).glob("checkpoint-*") if p.name.rsplit("-", 1)[-1].isdigit()
]
model_folder = (
    str(max(checkpoint_dirs, key=lambda p: int(p.name.rsplit("-", 1)[-1])))
    if checkpoint_dirs
    else repo_name + "/checkpoint-19"
)

In [15]:
# Reload the fine-tuned weights onto `device`, then switch to inference mode (disables dropout)
# and clear any leftover gradients before computing attributions.
model = BertForSequenceClassification.from_pretrained(model_folder).to(device)
model.eval()
model.zero_grad()

In [16]:
# Reload the tokenizer the Trainer saved with the checkpoint so it matches the fine-tuned model
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_folder)

In [17]:
# Special-token ids used to build the model inputs and the integrated-gradients baseline
cls_token_id = tokenizer.cls_token_id  # [CLS]: prepended to every sequence
sep_token_id = tokenizer.sep_token_id  # [SEP]: appended after the text tokens
ref_token_id = tokenizer.pad_token_id  # [PAD]: replaces every text token in the reference (baseline) input

In [18]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id, max_length=None):
    """Build the model input ids and the matching reference (baseline) ids for one text.

    Args:
        text: raw input string, tokenized with the global ``tokenizer``.
        ref_token_id: id placed at every text position of the baseline.
        sep_token_id: id appended after the text tokens in both sequences.
        cls_token_id: id prepended before the text tokens in both sequences.
        max_length: maximum total length including [CLS] and [SEP]. Defaults to
            ``tokenizer.model_max_length``. Longer texts are truncated from the end so the
            sequence fits the model's position embeddings.

    Returns:
        ``(input_ids, ref_input_ids, num_text_tokens)``: two int64 tensors of shape
        ``(1, seq_len)`` on ``device``, plus the number of (possibly truncated) text tokens
        between [CLS] and [SEP].
    """
    if max_length is None:
        max_length = tokenizer.model_max_length
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # Reserve two positions for [CLS] and [SEP].
    text_ids = text_ids[:max(max_length - 2, 0)]

    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids: same length, every text token replaced by ref_token_id
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return (
        torch.tensor([input_ids], device=device),
        torch.tensor([ref_input_ids], device=device),
        len(text_ids),
    )

In [19]:
def predict(inputs):
    """Run the global ``model`` on a batch of token ids and return its first output (the logits)."""
    model_output = model(inputs)
    return model_output[0]

def custom_forward(inputs):
    """Forward function for Captum: the probability of label 0 for each example.

    Applies a softmax over the classifier logits from ``predict`` and keeps only column 0
    (IMDB's ``neg`` label), returned with a trailing axis of size 1.
    """
    probabilities = torch.softmax(predict(inputs), dim=1)
    return probabilities[:, 0:1]

def summarize_attributions(attributions):
    """Collapse embedding-level attributions into one L2-normalized score per token.

    Args:
        attributions: tensor whose last axis is the embedding dimension, e.g.
            ``(batch, seq_len, hidden)`` from ``LayerIntegratedGradients`` on the embedding
            layer. A leading batch dimension of size 1 is squeezed away.

    Returns:
        ``(seq_len,)`` for a single example, otherwise ``(batch, seq_len)``. Each example's
        scores are divided by that example's own L2 norm, so batched scores match
        single-example ones. An all-zero example yields zeros instead of NaN.
    """
    token_scores = attributions.sum(dim=-1).squeeze(0)
    norms = torch.norm(token_scores, dim=-1, keepdim=True)
    # Avoid 0/0 -> NaN when every attribution of an example is zero.
    norms = torch.where(norms > 0, norms, torch.ones_like(norms))
    return token_scores / norms

In [20]:
# Integrated gradients attributed to the output of BERT's embedding layer
lig = LayerIntegratedGradients(forward_func=custom_forward, layer=model.bert.embeddings)

In [21]:
# Example sentence whose token attributions are computed and visualized below
text = "I am a gay black man"

In [22]:
# Model input and [PAD] baseline for the example sentence. The third return value is the number
# of text tokens (not a token id); it is kept under its existing name `sep_id`.
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)

# Token string for every position (including [CLS]/[SEP]), used to label the visualization
indices = input_ids[0].tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)

In [23]:
# Inspect the token ids: [CLS] + text tokens + [SEP]
input_ids

tensor([[ 101, 1045, 2572, 1037, 5637, 2304, 2158,  102]])

In [24]:
input_ids.shape

torch.Size([1, 8])

In [39]:
# Inspect the [PAD] baseline built by construct_input_ref_pair
# NOTE(review): the saved output below is stale; it was produced after a later cell reassigned `ref_input_ids`
ref_input_ids

tensor([[101,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0

In [40]:
ref_input_ids.shape

torch.Size([1, 512])

In [25]:
# Integrate gradients from the [PAD] baseline to the real input.
# `delta` is the convergence (completeness) error of the approximation.
attributions, delta = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
    return_convergence_delta=True,
)

In [26]:
# Attributions for every token position and embedding dimension of the embedding-layer output
attributions

tensor([[[ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 5.4194e-05,  2.6661e-04, -2.6724e-06,  ..., -2.4288e-06,
          -6.5795e-05,  3.1591e-05],
         [ 1.5350e-04,  3.2380e-04, -1.7821e-04,  ..., -1.1435e-04,
           4.4291e-04, -1.8851e-04],
         ...,
         [-5.3501e-05,  6.2052e-04,  1.3276e-04,  ..., -1.7988e-04,
           4.5214e-05, -2.0778e-04],
         [ 6.5290e-07,  3.3863e-04, -8.2308e-05,  ...,  1.1222e-04,
           3.7409e-05, -1.3750e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          -0.0000e+00, -0.0000e+00]]], dtype=torch.float64)

In [27]:
attributions.shape

torch.Size([1, 8, 768])

In [None]:
# Predict the class of the example sentence (no gradients are needed just for display)
with torch.no_grad():
    score = predict(input_ids)
probs = torch.softmax(score, dim=1)[0]
# Column 0 is IMDB label 0 ("neg"), the same probability custom_forward attributes.
# .item() works for CPU and CUDA tensors alike, unlike .numpy().
pred_label = torch.argmax(probs).item()
prob_negative = probs[0].item()

print('Text: ', text)
print(f'Predicted label: {pred_label}, prob negative (label 0): {prob_negative:.4f}')

In [28]:
# One normalized score per token position of the example sentence
attributions_sum = summarize_attributions(attributions=attributions)

In [29]:
# One normalized score per token position ([CLS] and [SEP] included)
attributions_sum

tensor([ 0.0000,  0.3756,  0.7016, -0.3179,  0.3146, -0.3724, -0.1675,  0.0000],
       dtype=torch.float64)

In [30]:
attributions_sum.shape

torch.Size([8])

In [None]:
# Scores without the [CLS] (first) and [SEP] (last) positions
attributions_sum[1:-1]

In [None]:
# Build a visualization record for the single example sentence
probs = torch.softmax(score, dim=1)[0]
pred_class = torch.argmax(probs)
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,        # word_attributions: one score per entry of all_tokens
                        probs[pred_class],       # pred_prob: probability of the *predicted* class
                        pred_class,              # pred_class
                        1,                       # true_class. NOTE(review): placeholder, the sentence has no gold label
                        0,                       # attr_class: custom_forward attributes the probability of label 0
                        attributions_sum.sum(),  # attr_score
                        all_tokens,              # raw_input_ids: token strings shown in the visualization
                        delta)                   # convergence_score

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

In [31]:
def get_attribution_score(data):
    """Compute per-position Layer Integrated Gradients scores for a batch of tokenized examples.

    Intended for ``Dataset.map(..., batched=True)``.

    Args:
        data: batch dict with ``"text"`` (list of str) and ``"input_ids"`` (list of id lists,
            assumed equal length because tokenization used ``padding=True``).

    Returns:
        dict with:
            ``"words"``: the WordPiece tokens of each raw text.
            ``"scores"``: one list per example of normalized scores, one per ``input_ids``
            position ([CLS], [SEP] and padding included).
    """
    words = [tokenizer.tokenize(t) for t in data["text"]]
    # torch.tensor builds int64 ids directly and accepts CUDA devices. The legacy torch.Tensor
    # constructor rejects non-CPU devices and round-trips the ids through float32.
    input_ids = torch.tensor(data["input_ids"], dtype=torch.long, device=device)
    batch_size, seq_len = input_ids.shape
    max_text_len = max(seq_len - 2, 0)

    ref_rows = []
    for tokens in words:
        # Texts longer than the model limit were truncated during tokenization; clamp so each
        # baseline row has exactly seq_len entries, like input_ids.
        n_text = min(len(tokens), max_text_len)
        pad_len = max(seq_len - n_text - 2, 0)
        ref_rows.append([cls_token_id] + [ref_token_id] * n_text + [sep_token_id] + [ref_token_id] * pad_len)
    ref_input_ids = torch.tensor(ref_rows, dtype=torch.long, device=device)

    # NOTE(review): custom_forward receives no attention mask, so padding positions are attended to.
    attributions = lig.attribute(inputs=input_ids, baselines=ref_input_ids)
    # summarize_attributions squeezes a batch dimension of size 1; restore one row per example so
    # the batched map always receives a list of per-example score lists.
    scores = summarize_attributions(attributions).reshape(batch_size, -1)
    return {
        'words': words,
        'scores': scores.detach().tolist(),
    }

In [None]:
# Attribution scores for the first two training examples, computed via a batched map
train_attr = tokenized_train.select(indices=[0, 1]).map(function=get_attribution_score, batched=True)

In [None]:
# Inspect the attribution result for the first selected training example
train_attr[0]

In [32]:
# Debugging: reproduce get_attribution_score step by step on a single training example
# NOTE(review): `tmp` / `data` are generic names; consider more descriptive ones
tmp = tokenized_train.select([0])

In [33]:
# Slice the one-row dataset into a batch dict (column name -> list of values), like a batched map receives
data = tmp[:]

In [35]:
words = [tokenizer.tokenize(review) for review in data["text"]]
words

[['there',
  'is',
  'no',
  'relation',
  'at',
  'all',
  'between',
  'fort',
  '##ier',
  'and',
  'profile',
  '##r',
  'but',
  'the',
  'fact',
  'that',
  'both',
  'are',
  'police',
  'series',
  'about',
  'violent',
  'crimes',
  '.',
  'profile',
  '##r',
  'looks',
  'crisp',
  '##y',
  ',',
  'fort',
  '##ier',
  'looks',
  'classic',
  '.',
  'profile',
  '##r',
  'plots',
  'are',
  'quite',
  'simple',
  '.',
  'fort',
  '##ier',
  "'",
  's',
  'plot',
  'are',
  'far',
  'more',
  'complicated',
  '.',
  '.',
  '.',
  'fort',
  '##ier',
  'looks',
  'more',
  'like',
  'prime',
  'suspect',
  ',',
  'if',
  'we',
  'have',
  'to',
  'spot',
  'similarities',
  '.',
  '.',
  '.',
  'the',
  'main',
  'character',
  'is',
  'weak',
  'and',
  'weird',
  '##o',
  ',',
  'but',
  'have',
  '"',
  'clair',
  '##vo',
  '##yan',
  '##ce',
  '"',
  '.',
  'people',
  'like',
  'to',
  'compare',
  ',',
  'to',
  'judge',
  ',',
  'to',
  'evaluate',
  '.',
  'how',
  'about

In [36]:
lengths = [len(tokens) for tokens in words]
lengths

[177]

In [37]:
# torch.tensor builds int64 ids directly. The legacy torch.Tensor(..., device=...) constructor
# only accepts CPU devices and round-trips the ids through float32.
input_ids = torch.tensor(data["input_ids"], dtype=torch.long, device=device)
input_ids

tensor([[  101,  2045,  2003,  2053,  7189,  2012,  2035,  2090,  3481,  3771,
          1998,  6337,  2099,  2021,  1996,  2755,  2008,  2119,  2024,  2610,
          2186,  2055,  6355,  6997,  1012,  6337,  2099,  3504, 15594,  2100,
          1010,  3481,  3771,  3504,  4438,  1012,  6337,  2099, 14811,  2024,
          3243,  3722,  1012,  3481,  3771,  1005,  1055,  5436,  2024,  2521,
          2062,  8552,  1012,  1012,  1012,  3481,  3771,  3504,  2062,  2066,
          3539,  8343,  1010,  2065,  2057,  2031,  2000,  3962, 12319,  1012,
          1012,  1012,  1996,  2364,  2839,  2003,  5410,  1998,  6881,  2080,
          1010,  2021,  2031,  1000, 17936,  6767,  7054,  3401,  1000,  1012,
          2111,  2066,  2000, 12826,  1010,  2000,  3648,  1010,  2000, 16157,
          1012,  2129,  2055,  2074,  9107,  1029,  6057,  2518,  2205,  1010,
          2111,  3015,  3481,  3771,  3504,  2137,  2021,  1010,  2006,  1996,
          2060,  2192,  1010,  9177,  2027,  9544,  

In [38]:
# [PAD] baseline per example: [CLS] + [PAD] * n_text + [SEP], right-padded to input_ids' length
seq_len = int(input_ids.size(1))
ref_input_ids = []
for text_len in lengths:
    # Texts longer than the model limit were truncated during tokenization; clamp so the
    # baseline row has exactly seq_len entries.
    n_text = min(text_len, max(seq_len - 2, 0))
    pad_len = max(seq_len - n_text - 2, 0)
    ref_input_ids.append([cls_token_id] + [ref_token_id] * n_text + [sep_token_id] + [ref_token_id] * pad_len)

ref_input_ids = torch.tensor(ref_input_ids, dtype=torch.long, device=device)
ref_input_ids

tensor([[101,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0,   0,   0,   0,   0,   0,   0,   0

In [41]:
print(input_ids.shape, ref_input_ids.shape)

torch.Size([1, 512]) torch.Size([1, 512])


In [None]:
# Integrated gradients for the debug example (overwrites the example-sentence `attributions`)
attributions = lig.attribute(
    inputs=input_ids,
    baselines=ref_input_ids,
)
attributions

In [None]:
# Collapse the debug attributions to one normalized score per token position
# (overwrites the example-sentence `attributions_sum` computed earlier)
attributions_sum = summarize_attributions(attributions)
attributions_sum