In [1]:
import torch
from torch.utils.data import DataLoader

from transformers import BertForSequenceClassification, BertTokenizerFast
from bertviz.bertviz import head_view
from tqdm import tqdm, trange

import utils
from dataset import MBTIDataset

In [2]:
%%javascript
// Register CDN locations for d3 and jQuery with RequireJS in the notebook front end.
// NOTE(review): presumably needed so bertviz's head_view (used further down) can
// load its JS dependencies — confirm against the bertviz version in use.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [3]:
# Load the fine-tuned 16-label classifier from ./checkpoint with attention outputs
# enabled, the 'bert-base-uncased' fast tokenizer, and the test set read from
# 100speeches.tsv (tokenized with that same tokenizer).
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained(
    './checkpoint',
    num_labels=16,
    output_attentions=True,
)
test_set = MBTIDataset('100speeches.tsv', tokenizer)

Reading 100speeches.tsv: 69it [00:00, 99.12it/s]


In [14]:
def get_attentions_from_correct_examples(model, test_set, tokenizer):
    """Run ``model`` over ``test_set`` and keep attentions for correctly classified items.

    Each item of ``test_set`` is unpacked as ``(post, label)``. ``post`` is fed to
    the model as a batch of one (presumably a 1-D sequence of token ids — confirm
    against MBTIDataset) and is also decoded back to token strings.

    Args:
        model: sequence-classification model. It is expected to be loaded with
            ``output_attentions=True`` so that its last output is the per-layer
            attention tuple. It is switched to eval mode.
        test_set: iterable of ``(post, label)`` pairs.
        tokenizer: tokenizer whose ``convert_ids_to_tokens`` maps ``post`` back to tokens.

    Returns:
        ``(attentions, tok_texts, labels)``: three parallel lists with one entry per
        example whose argmax prediction equals its label. The entries are the
        model's attention output (moved to CPU), the token strings, and
        ``utils.decode_label(label)``.
    """
    attentions, tok_texts, labels = [], [], []
    model.eval()
    # Inputs must live on the same device as the model's weights.
    device = next(model.parameters()).device
    with torch.no_grad():
        for post, label in tqdm(test_set, desc="Testing Iteration"):
            # as_tensor accepts lists or tensors without the copy-construct warning.
            input_ids = torch.as_tensor(post).unsqueeze(0).to(device)
            outputs = model(input_ids)
            # Positional indexing works for both the legacy tuple return and for
            # transformers' ModelOutput (where tuple-unpacking would iterate dict
            # keys). Logits come first and, with output_attentions=True, the
            # attentions come last.
            logits, attention = outputs[0], outputs[-1]
            pred = torch.argmax(logits[0])
            if pred.item() == label:
                # Keep stored attentions on CPU so later CPU-side indexing/plotting works.
                attentions.append(tuple(layer_attn.cpu() for layer_attn in attention))
                tok_texts.append(tokenizer.convert_ids_to_tokens(post))
                labels.append(utils.decode_label(label))
    return attentions, tok_texts, labels

In [15]:
# Keep attentions, token strings and decoded labels for the examples the model classifies correctly.
attentions, tok_texts, labels = get_attentions_from_correct_examples(
    model=model,
    test_set=test_set,
    tokenizer=tokenizer,
)

Testing Iteration: 100%|██████████| 69/69 [00:18<00:00,  3.80it/s]


In [27]:
# For the first few correctly classified examples, print the TOP_K tokens with the
# largest attention weight toward position 0 ([CLS]) in every head of `layer`.
# attentions[example][layer] is indexed [batch, head, seq, seq]; the slice
# [0, head, :, 0] takes column 0, i.e. how strongly each position attends to [CLS]
# (assuming the usual query-row / key-column layout).
# NOTE(review): if the intent is "what does [CLS] attend to", use [0, head, 0, :].
NUM_EXAMPLES = 5
TOP_K = 20
layer = 0

# Fewer correct examples than NUM_EXAMPLES would otherwise raise IndexError.
examples = list(range(min(NUM_EXAMPLES, len(labels))))
for example in examples:
    print(labels[example])
    print("\n\n")
    layer_attn = attentions[example][layer]
    heads = list(range(layer_attn.shape[1]))  # one entry per head in this layer
    k = min(TOP_K, layer_attn.shape[-1])  # topk raises if k exceeds the sequence length
    for head in heads:
        layer_val, layer_ind = torch.topk(layer_attn[0, head, :, 0], k=k)
        top_tokens = [tok_texts[example][i] for i in layer_ind.tolist()]
        print(top_tokens)
        print()
# `example`, `layer_ind` and `top_tokens` keep their last-iteration values and are
# reused by the head_view cell below.

INTJ



['were', 'of', 'by', 'haven', 'thought', 'and', 'with', 'identity', 'i', 'within', 'of', 'to', 'or', '[SEP]', 'of', ',', 'to', 'to', 'principle', 'help']

['##to', '##pha', '##rea', '##iger', '##her', '##ents', '##ved', '##ager', '##d', 'ha', '##scribe', '##zard', '##d', '##s', 'bell', '##scribe', '[CLS]', '##rand', '##ly', '##ly']

['gentlemen', 'february', 'february', 'assume', 'april', 'conform', 'sunk', 'imperial', 'imperial', 'neutral', 'earlier', 'officially', 'extraordinary', 'principle', 'constitutional', 'submarines', 'compassion', 'german', 'highways', 'submarines']

['gentlemen', '[SEP]', 'assume', 'constitutional', '[SEP]', '[CLS]', 'commanders', '##sea', 'conform', 'dominion', 'laid', 'imperial', 'somewhat', 'reckless', 'officially', 'neither', 'imperial', 'extraordinary', '[SEP]', 'per']


['and', 'and', 'and', '##rand', ',', 'the', 'reckless', 'and', '[SEP]', 'the', '[SEP]', ',', 'the', '[SEP]', '[SEP]', 'and', 'and', 'and', 'or', '##to']

['ports', 'ports', '##m

In [51]:
# Commented-out variant of the head_view call below.
# NOTE(review): it selects rows by `layer_ind` but keys by the first 20 sequence
# positions ([:20]), so its columns do not correspond to the `top_tokens` labels.
# head_view([layer[:,:,layer_ind,:20] for layer in attentions[example]], top_tokens)

In [23]:
# Render bertviz's head view restricted to [CLS] plus the top tokens found above.
# The same index set is applied to BOTH the query axis (dim 2) and the key axis
# (dim 3), so the square sub-matrix lines up with the token labels on each side.
# NOTE(review): `example`, `layer_ind` and `top_tokens` are left over from the last
# iteration of the exploration loop (last example, last head) — hidden state.
selected = torch.cat([layer_ind.new_zeros(1), layer_ind])  # position 0 ([CLS]) first
head_view(
    [layer_attn.index_select(2, selected).index_select(3, selected) for layer_attn in attentions[example]],
    ['[CLS]'] + top_tokens,
)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>