# Interpretation of [BertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification) using [**Captum**](https://captum.ai/) (on a model that was trained for 1 epoch)

Source code info:

Used notebook: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 (adaptation to classification from original tutorial on question answering: https://captum.ai/tutorials/Bert_SQUAD_Interpret)

(Used notebook is based on this github issue: https://github.com/pytorch/captum/issues/303)

Related github issue: https://github.com/pytorch/captum/issues/249

---

Used model: [roa7n/DNABert_K6_G_quad_1](https://huggingface.co/roa7n/DNABert_K6_G_quad_1)



## Load initial libraries, models, data:

In [1]:
!pip install transformers datasets tokenizers evaluate captum --quiet

In [2]:
TOKENIZER = 'yarongef/DistilProtBert'
HF_MODEL_NAME = 'simecek/knotted_proteins_demo_model'  # trained for 1 epoch

FILE_KNOTTED = "/home/jovyan/data/proteins/SPOUT_knotted.csv"
FILE_UNKNOTTED = "/home/jovyan/data/proteins/Rossmann_unknotted.csv"

# FILE_KNOTTED = "/home/jovyan/data/proteins/SPOUT_knotted_small.csv"
# FILE_UNKNOTTED = "/home/jovyan/data/proteins/Rossmann_unknotted_small.csv"  # echo "$(head -101 Rossmann_unknotted.csv)" > Rossmann_unknotted_small.csv

In [5]:
import torch
torch.cuda.memory_summary()



In [6]:
from transformers import BertTokenizer, BertForSequenceClassification

from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

### Load the tokenizer:

In [7]:
tokenizer = BertTokenizer.from_pretrained('yarongef/DistilProtBert', max_length=1024, truncation=True, num_labels=2)
tokenizer

PreTrainedTokenizer(name_or_path='yarongef/DistilProtBert', vocab_size=30, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### Load the model:

In [8]:
model = BertForSequenceClassification.from_pretrained(HF_MODEL_NAME, output_attentions=True)
model.to(device)
model.eval()
model.zero_grad()
# model

## Get the data:

In [9]:
import pandas as pd

df_neg = pd.read_csv(FILE_UNKNOTTED)
df_neg['label'] = 0
df_neg

Unnamed: 0,seq,label
0,MSVSMRDMLAAGVHFGHQTRFWNPKMAPYLFGARNKIHIINLEQTL...,0
1,MAGSETMERTVVWFRRDLRIDDNPALAAAAREGSVLPVFIWCSADE...,0
2,MEILSPSPPPSHCPLLRCGGHWEQHHGETWVHVAVGRSPEKTLSLL...,0
3,MIRSVVRSGRTVGRRSSRRLISQSTIKSNVQKLESPPIPPAVNPGK...,0
4,MTVAKGEMCSVNDCRFDDNDDEITKNKGKDSLADVTLCRKCKCENA...,0
...,...,...
108141,MAKTRVIAAMSGGVDSAVAAALLAEQGYEVIGVTMRMYEATQPAHA...,0
108142,MSSQFATSTLPNPASPRATVTRETVAVAMSGGVDSSTVAAMLRAQG...,0
108143,MKKVVIGMSGGVDSSVSAYLLKEQGYEVIGVTLNQHLEENSKDIED...,0
108144,MLYRLSPNKLEKLIFPLKDYSKQEIREIALKIGLEIHNKKDSQGIC...,0


In [10]:
df_pos = pd.read_csv(FILE_KNOTTED)
df_pos = df_pos.rename(columns={'seq;knotted?': 'seq'})
df_pos['label'] = 1
df_pos

Unnamed: 0,seq,label
0,MLLVKTLREMEYVAASHIKDAIGDVEIEIRPSGFLGLLIVHCDESL...,1
1,MAKYIIKTQKGFENIVVNNLKEIIGDFKYTVSPDGYQGIVIVEHDE...,1
2,MKFLVKTQRDMEAVAGNYITEAVPDAEVWIAPMGYTGLVLVEADEN...,1
3,MIFVKTQRGMEYIAAQNIKELLGDVKIEIRPAGYLGILVVHSDELE...,1
4,MIFVKTQRGMEYIAMQNIKELMGDVKIEVRPAGYLGVLIVHSDDIE...,1
...,...,...
140293,MKGKGFTVYGTELNEEAHALDKVEKTEDFAIIMGNEGQGVSQEILS...,1
140294,TVAAVARGGVAPEALPRDRPIVLVMGNEEQGLPEASIAACAARVTL...,1
140295,AGLEADGDKDYRDGDYRGGVALVIGGEGNGLARLTRELCDYIVSIP...,1
140296,MGVELTDESIRLAELPAARRRTVVVLGNEGSGIPSDAMELLDLAVE...,1


In [11]:
from datasets import Dataset

df_merged = pd.concat([df_neg, df_pos], ignore_index=True, sort=False)

# https://huggingface.co/docs/datasets/process
# dataset = Dataset.from_pandas(df_merged).shuffle(seed=42).train_test_split(test_size=0.2, shuffle=False)
dss = Dataset.from_pandas(df_merged).train_test_split(test_size=0.2, seed=42, shuffle=True)  # splits are shuffled by default
dss

DatasetDict({
    train: Dataset({
        features: ['seq', 'label'],
        num_rows: 198755
    })
    test: Dataset({
        features: ['seq', 'label'],
        num_rows: 49689
    })
})

In [12]:
def tokenize_function(s):
    seq_split = ' '.join(s['seq'])
    # print(seq_split)
    return tokenizer(seq_split)

# dataset = dss.map(tokenize_function, remove_columns='seq', num_proc=4)
# dataset.set_format('pt')
# dataset

In [13]:
dataset = dss

In [11]:
dataset['train'][0]['seq']

'MKKKILQLTLENAIAFKGKANPKAVINKIIPTVKDKSKLKAIGNEVSATIKKVNKLSLSKQKEQLKKINPTFFNKKIKVKKGIIDLPKVGKNFRARFAPSASGPLHIGHALVISLNKIYADKYKGKHILRIEDTNPDANFKEFYKMIPKDYTWLAGKPSETYIQSARVKTYYKYAEQLIKAGHLYVCEETPEEVKAKLKKGIQPFGRRDDPKEVLRKWKRMLTGKYNPGESVVRVKTDLKGKNPALKEWVAFRISGGTHPKVGNKVRVWPLMNFAVAIDDYELKMTHVIRGKDHEDNTKKQKMIYDFFGWTYPEYIHLGRINFKNMIISASDIRKGVEEGIYKGYDDEQVESLASIRKRGIKPKALLKFFYEIGPTKRDKTVDKKEVKHNK'

In [12]:
print(tokenize_function(dataset['train'][0]))

{'input_ids': [2, 21, 12, 12, 12, 11, 5, 18, 5, 15, 5, 9, 17, 6, 11, 6, 19, 12, 7, 12, 6, 17, 16, 12, 6, 8, 11, 17, 12, 11, 11, 16, 15, 8, 12, 14, 12, 10, 12, 5, 12, 6, 11, 7, 17, 9, 8, 10, 6, 15, 11, 12, 12, 8, 17, 12, 5, 10, 5, 10, 12, 18, 12, 9, 18, 5, 12, 12, 11, 17, 16, 15, 19, 19, 17, 12, 12, 11, 12, 8, 12, 12, 7, 11, 11, 14, 5, 16, 12, 8, 7, 12, 17, 19, 13, 6, 13, 19, 6, 16, 10, 6, 10, 7, 16, 5, 22, 11, 7, 22, 6, 5, 8, 11, 10, 5, 17, 12, 11, 20, 6, 14, 12, 20, 12, 7, 12, 22, 11, 5, 13, 11, 9, 14, 15, 17, 16, 14, 6, 17, 19, 12, 9, 19, 20, 12, 21, 11, 16, 12, 14, 20, 15, 24, 5, 6, 7, 12, 16, 10, 9, 15, 20, 11, 18, 10, 6, 13, 8, 12, 15, 20, 20, 12, 20, 6, 9, 18, 5, 11, 12, 6, 7, 22, 5, 20, 8, 23, 9, 9, 15, 16, 9, 9, 8, 12, 6, 12, 5, 12, 12, 7, 11, 18, 16, 19, 7, 13, 13, 14, 14, 16, 12, 9, 8, 5, 13, 12, 24, 12, 13, 21, 5, 15, 7, 12, 20, 17, 16, 7, 9, 10, 8, 8, 13, 8, 12, 15, 14, 5, 12, 7, 12, 17, 16, 6, 5, 12, 9, 24, 8, 6, 19, 13, 11, 10, 7, 7, 15, 22, 16, 12, 8, 7, 17, 12, 8, 13, 8

## Captum interpretation:

### 1. On untrained model:

*Helper function to perform forward pass of the model and make predictions:*

In [14]:
def predict(inputs):
    score = model(inputs)
    return score[0]

*Custom forward function that will allow us to access the postitions of our prediction using position input argument:*

In [15]:
### original for question answering looked like this: ######
# def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
#     pred = predict(inputs,
#                    token_type_ids=token_type_ids,
#                    position_ids=position_ids,
#                    attention_mask=attention_mask)
#     pred = pred[position]
#     return pred.max(1).values
############################################################

def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

Compute attributions with respect to the `BertEmbeddings` layer:

1. define baselines/ references,
2. numericalize baselines and inputs.

*(helper functions to achieve that)*

In [16]:
ref_token_id = tokenizer.pad_token_id  # token used for generating token reference
sep_token_id = tokenizer.sep_token_id  # token added to the end of the input text
cls_token_id = tokenizer.cls_token_id  # token used at the beginning of the input text
print(ref_token_id)
print(sep_token_id)
print(cls_token_id)

0
3
2


In [17]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):
    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

Define the input sequence `sample1` that we'd like to use as an input for our Bert model and interpret what the model was focusing on when predicting the class:

In [18]:
def get_sample_and_label(dataset, seq_id):
    seq = dataset[seq_id]['seq']
    label = dataset[seq_id]['label']
    seq_kmers = ' '.join(seq)
    tokenized_seq = tokenizer.encode(seq_kmers, add_special_tokens=False)
    print(f'Seq label: {label}')
    print(f'Input seq (raw): {seq}')
    print(f'Returned (tokenized[:10]): {tokenized_seq[:10]}')
    return tokenized_seq, label

sample1, sample_label1 = get_sample_and_label(dataset['test'], 0)

Seq label: 1
Input seq (raw): MKLEAVYGLHAVTTLLQRSPDQVVELWVMKGRQDQRMQRVLELAAEQGLDIREADKGLMNQKADEGNHQGIIAWRKPVQNKNEKHLPDILDSISGNALILILDGVTDPHNLGACLRTADAAGVQVVIAPKDKSAPLNATAAKVACGAAEAVPYIQVTNLARTMKELQERGIWIVGTAGEATHSIYQQDFTGPTALVMGAEGAGMRRLTREHCDYLVNIPMAGEVSSVNVSVATGICLFEAVRQRQLS;1
Returned (tokenized[:10]): [21, 12, 5, 9, 6, 8, 20, 7, 5, 22]


Let's numericalize the input `sample1` and generate corresponding baselines/references for all three sub-embeddings (word, token type and position embeddings) types using our helper functions defined above:

In [18]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(sample1, ref_token_id, sep_token_id, cls_token_id)
print(input_ids)
print(ref_input_ids)
print(sep_id)

tensor([[ 2, 21, 12,  5,  9,  6,  8, 20,  7,  5, 22,  6,  8, 15, 15,  5,  5, 18,
         13, 10, 16, 14, 18,  8,  8,  9,  5, 24,  8, 21, 12,  7, 13, 18, 14, 18,
         13, 21, 18, 13,  8,  5,  9,  5,  6,  6,  9, 18,  7,  5, 14, 11, 13,  9,
          6, 14, 12,  7,  5, 21, 17, 18, 12,  6, 14,  9,  7, 17, 22, 18,  7, 11,
         11,  6, 24, 13, 12, 16,  8, 18, 17, 12, 17,  9, 12, 22,  5, 16, 14, 11,
          5, 14, 10, 11, 10,  7, 17,  6,  5, 11,  5, 11,  5, 14,  7,  8, 15, 14,
         16, 22, 17,  5,  7,  6, 23,  5, 13, 15,  6, 14,  6,  6,  7,  8, 18,  8,
          8, 11,  6, 16, 12, 14, 12, 10,  6, 16,  5, 17,  6, 15,  6,  6, 12,  8,
          6, 23,  7,  6,  6,  9,  6,  8, 16, 20, 11, 18,  8, 15, 17,  5,  6, 13,
         15, 21, 12,  9,  5, 18,  9, 13,  7, 11, 24, 11,  8,  7, 15,  6,  7,  9,
          6, 15, 22, 10, 11, 20, 18, 18, 14, 19, 15,  7, 16, 15,  6,  5,  8, 21,
          7,  6,  9,  7,  6,  7, 21, 13, 13,  5, 15, 13,  9, 22, 23, 14, 20,  5,
          8, 17, 11, 16, 21,

In [19]:
predict(input_ids)

tensor([[-6.6038,  6.7049]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [20]:
custom_forward(input_ids)

tensor([1.6599e-06], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

There are 2 different ways of computing the attributions for emebdding layers. One option is to use `LayerIntegratedGradients` and compute the attributions with respect to `BertEmbedding`. The second option is to use `LayerIntegratedGradients` for each `word_embeddings`, `token_type_embeddings` and `position_embeddings` and compute the attributions w.r.t each embedding vector.

In [19]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

Helper function to summarize attributions for each word token in the sequence:

In [20]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [21]:
def interpret_and_visualize(tokenized_sample, label):
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(tokenized_sample, ref_token_id, sep_token_id, cls_token_id)
    score = predict(input_ids)

    attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
    
    indices = input_ids[0].detach().tolist()
    
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    
    attributions_sum = summarize_attributions(attributions)
    
#     # debug prints:
#     attributions_list = attributions_sum.tolist()
#     # print(attributions_list)
    
#     for i in range(len(attributions_list)):
#         if attributions_list[i] > 0.3:
#             print(f'POS: idx={i}, attribution={attributions_list[i]}')
#         elif attributions_list[i] < -0.3:
#             print(f'NEG: idx={i}, attribution={attributions_list[i]}')
#     ##

    score_vis = viz.VisualizationDataRecord(
        word_attributions=attributions_sum, 
        pred_prob=torch.softmax(score, dim=1)[0][0],
        pred_class=torch.argmax(torch.softmax(score, dim=1)[0]), 
        true_class=label, 
        attr_class=tokenized_sample, 
        attr_score=attributions_sum.sum(), 
        raw_input_ids=all_tokens, 
        convergence_score=delta)

    print('\033[1m', 'Visualization For Score', '\033[0m')
    viz.visualize_text([score_vis])

In [24]:
interpret_and_visualize(sample1, sample_label1)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 12, 5, 9, 6, 8, 20, 7, 5, 22, 6, 8, 15, 15, 5, 5, 18, 13, 10, 16, 14, 18, 8, 8, 9, 5, 24, 8, 21, 12, 7, 13, 18, 14, 18, 13, 21, 18, 13, 8, 5, 9, 5, 6, 6, 9, 18, 7, 5, 14, 11, 13, 9, 6, 14, 12, 7, 5, 21, 17, 18, 12, 6, 14, 9, 7, 17, 22, 18, 7, 11, 11, 6, 24, 13, 12, 16, 8, 18, 17, 12, 17, 9, 12, 22, 5, 16, 14, 11, 5, 14, 10, 11, 10, 7, 17, 6, 5, 11, 5, 11, 5, 14, 7, 8, 15, 14, 16, 22, 17, 5, 7, 6, 23, 5, 13, 15, 6, 14, 6, 6, 7, 8, 18, 8, 8, 11, 6, 16, 12, 14, 12, 10, 6, 16, 5, 17, 6, 15, 6, 6, 12, 8, 6, 23, 7, 6, 6, 9, 6, 8, 16, 20, 11, 18, 8, 15, 17, 5, 6, 13, 15, 21, 12, 9, 5, 18, 9, 13, 7, 11, 24, 11, 8, 7, 15, 6, 7, 9, 6, 15, 22, 10, 11, 20, 18, 18, 14, 19, 15, 7, 16, 15, 6, 5, 8, 21, 7, 6, 9, 7, 6, 7, 21, 13, 13, 5, 15, 13, 9, 22, 23, 14, 20, 5, 8, 17, 11, 16, 21, 6, 7, 9, 8, 10, 10, 8, 17, 8, 10, 8, 6, 15, 7, 11, 23, 5, 19, 9, 6, 8, 13, 18, 13, 18, 5, 10, 1, 1]",10.99,[CLS] M K L E A V Y G L H A V T T L L Q R S P D Q V V E L W V M K G R Q D Q R M Q R V L E L A A E Q G L D I R E A D K G L M N Q K A D E G N H Q G I I A W R K P V Q N K N E K H L P D I L D S I S G N A L I L I L D G V T D P H N L G A C L R T A D A A G V Q V V I A P K D K S A P L N A T A A K V A C G A A E A V P Y I Q V T N L A R T M K E L Q E R G I W I V G T A G E A T H S I Y Q Q D F T G P T A L V M G A E G A G M R R L T R E H C D Y L V N I P M A G E V S S V N V S V A T G I C L F E A V R Q R Q L S [UNK] [UNK] [SEP]
,,,,


**How to interpret the colours:**

https://datascience.stackexchange.com/questions/87670/what-exactly-negative-positive-value-of-captums-integrated-gradient-mean

```
Positive attribution score means that the input in that particular position positively contributed to the final prediction and negative means the opposite. The magnitude of the attribution score signifies the strength of the contribution. Zero attribution score means no contribution from that particular feature.
```

*Github issue: https://github.com/pytorch/captum/issues/249#issuecomment-580569266*

*TL;DR: The output is the prediction probability (`p`) of being of the positive class. A negative class would be (`1 - p`). We attribute positive class probability (`p`) to the inputs of our model and in case something is predicted with high probability (as the positive class) we see many tokens that positively contribute to it.*

*When `p` is very low, there are no words contributing to the positive class. When we attribute to the positive class prob (`p`) we find words that pull away from it (influence it negatively). Those tokens are obviously the ones that pull towards the negative class with higher (`1-p`) probability.*

https://github.com/pytorch/captum/issues/249#issuecomment-580846266

*In a general case, red means that those tokens are pulling away from the positive class and most probably pulling towards the opposite class however I think that red might not always mean that it will always attribute to the other class. I think that's the assumption that we make here. We assume that the classifier is able to identify that a token is negatively correlated with the positive class so it must know something about that token, namely, that it is strongly pulling towards the opposite class (because there are no other options) and this is much easier to imagine for 2 class problem.*

## Other input sequences:

### Baseline

(Should have attribution score close to 0.)

In [25]:
interpret_and_visualize(ref_input_ids.tolist()[0], 0)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.99),"[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3]",-1.06,[CLS] [CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [SEP] [SEP]
,,,,


In [26]:
len(ref_input_ids.tolist()[0])

251

In [27]:
# equivalent to 
# interpret_and_visualize([2] + [0] * (len(ref_input_ids.tolist()[0]) - 2) + [3], 0)

In [28]:
interpret_and_visualize([2] + [0] * 20 + [3], 0)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3]",1.2,[CLS] [CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [SEP] [SEP]
,,,,


In [29]:
interpret_and_visualize([2] + [0] * 400 + [3], 0)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.99),"[2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3]",0.99,[CLS] [CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [SEP] [SEP]
,,,,


In [30]:
interpret_and_visualize([0] * 50, 0)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]",,[CLS] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [SEP]
,,,,


### [Towards falsifiable interpretability research](https://www.youtube.com/watch?v=BQ06EydLF0Q)

- if the input is shifted, the attribution should not change

In [31]:
sample_shift3, sample_label_shift3 = get_sample_and_label(dataset['train'], 50)
interpret_and_visualize(sample_shift3, sample_label_shift3)

Seq label: 0
Input seq (raw): MTQVIPMLRRMTVVVGYVPSAEGRAALDAAIEEAARRGETLHLVNVGQSDASNDPKFLDEGEVERLRGRLAEAGVPFEIEQLVRGRDAAEEVVDAAERIGATLVVIGMRRRSPTGKLLFGSQAQRILLDADCPVLAVKATR
Returned (tokenized[:10]): [21, 15, 18, 8, 11, 16, 21, 5, 13, 13]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[21, 15, 18, 8, 11, 16, 21, 5, 13, 13, 21, 15, 8, 8, 8, 7, 20, 8, 16, 10, 6, 9, 7, 13, 6, 6, 5, 14, 6, 6, 11, 9, 9, 6, 6, 13, 13, 7, 9, 15, 5, 22, 5, 8, 17, 8, 7, 18, 10, 14, 6, 10, 17, 14, 16, 12, 19, 5, 14, 9, 7, 9, 8, 9, 13, 5, 13, 7, 13, 5, 6, 9, 6, 7, 8, 16, 19, 9, 11, 9, 18, 5, 8, 13, 7, 13, 14, 6, 6, 9, 9, 8, 8, 14, 6, 6, 9, 13, 11, 7, 6, 15, 5, 8, 8, 11, 7, 21, 13, 13, 13, 10, 16, 15, 7, 12, 5, 5, 19, 7, 10, 18, 6, 18, 13, 11, 5, 5, 14, 6, 14, 23, 16, 8, 5, 6, 8, 12, 6, 15, 13]",2.33,[CLS] M T Q V I P M L R R M T V V V G Y V P S A E G R A A L D A A I E E A A R R G E T L H L V N V G Q S D A S N D P K F L D E G E V E R L R G R L A E A G V P F E I E Q L V R G R D A A E E V V D A A E R I G A T L V V I G M R R R S P T G K L L F G S Q A Q R I L L D A D C P V L A V K A T R [SEP]
,,,,


In [32]:
# adding [PAD] token here forces the rest of the sequence to shift to right
sample_shift3 = [0] + sample_shift3
print(sample_shift3)

[0, 21, 15, 18, 8, 11, 16, 21, 5, 13, 13, 21, 15, 8, 8, 8, 7, 20, 8, 16, 10, 6, 9, 7, 13, 6, 6, 5, 14, 6, 6, 11, 9, 9, 6, 6, 13, 13, 7, 9, 15, 5, 22, 5, 8, 17, 8, 7, 18, 10, 14, 6, 10, 17, 14, 16, 12, 19, 5, 14, 9, 7, 9, 8, 9, 13, 5, 13, 7, 13, 5, 6, 9, 6, 7, 8, 16, 19, 9, 11, 9, 18, 5, 8, 13, 7, 13, 14, 6, 6, 9, 9, 8, 8, 14, 6, 6, 9, 13, 11, 7, 6, 15, 5, 8, 8, 11, 7, 21, 13, 13, 13, 10, 16, 15, 7, 12, 5, 5, 19, 7, 10, 18, 6, 18, 13, 11, 5, 5, 14, 6, 14, 23, 16, 8, 5, 6, 8, 12, 6, 15, 13]


In [33]:
interpret_and_visualize(sample_shift3, sample_label_shift3)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[0, 21, 15, 18, 8, 11, 16, 21, 5, 13, 13, 21, 15, 8, 8, 8, 7, 20, 8, 16, 10, 6, 9, 7, 13, 6, 6, 5, 14, 6, 6, 11, 9, 9, 6, 6, 13, 13, 7, 9, 15, 5, 22, 5, 8, 17, 8, 7, 18, 10, 14, 6, 10, 17, 14, 16, 12, 19, 5, 14, 9, 7, 9, 8, 9, 13, 5, 13, 7, 13, 5, 6, 9, 6, 7, 8, 16, 19, 9, 11, 9, 18, 5, 8, 13, 7, 13, 14, 6, 6, 9, 9, 8, 8, 14, 6, 6, 9, 13, 11, 7, 6, 15, 5, 8, 8, 11, 7, 21, 13, 13, 13, 10, 16, 15, 7, 12, 5, 5, 19, 7, 10, 18, 6, 18, 13, 11, 5, 5, 14, 6, 14, 23, 16, 8, 5, 6, 8, 12, 6, 15, 13]",1.35,[CLS] [PAD] M T Q V I P M L R R M T V V V G Y V P S A E G R A A L D A A I E E A A R R G E T L H L V N V G Q S D A S N D P K F L D E G E V E R L R G R L A E A G V P F E I E Q L V R G R D A A E E V V D A A E R I G A T L V V I G M R R R S P T G K L L F G S Q A Q R I L L D A D C P V L A V K A T R [SEP]
,,,,


### Modified sample:

According to [Axiomatic Attribution for Deep Networks](https://arxiv.org/abs/1703.01365) paper we should see a decrease/ increase in attribution score if we changed some of the input parts that highly contributed to it.

1. When the True and predicted labels are different:

In [34]:
sample_shift, sample_label_shift = get_sample_and_label(dataset['test'], 81)
interpret_and_visualize(sample_shift, sample_label_shift)

Seq label: 0
Input seq (raw): MIVYPKNWINIGQSIQIKEIENTILQVLSEINCNCISFSGGLDSSLMLYYMLQVYDQVYAFTMGSSEEHPDVEYSKLVVSDLENVVHRVYIPSYKELEIAEFRHGDFEGDKEVRLFYKYVKQYTDEIIACDGIDEFMCGYYSHQDKPYEDTYYTHLRELSGKHLIPLYKNSGDVKVYLPYLDDGLISLFSQIEISRKVDKGCRKKLLVEMADGKIPDEIIHRRKYGFCDVLKIKG
Returned (tokenized[:10]): [21, 11, 8, 20, 16, 12, 17, 24, 11, 17]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[21, 11, 8, 20, 16, 12, 17, 24, 11, 17, 11, 7, 18, 10, 11, 18, 11, 12, 9, 11, 9, 17, 15, 11, 5, 18, 8, 5, 10, 9, 11, 17, 23, 17, 23, 11, 10, 19, 10, 7, 7, 5, 14, 10, 10, 5, 21, 5, 20, 20, 21, 5, 18, 8, 20, 14, 18, 8, 20, 6, 19, 15, 21, 7, 10, 10, 9, 9, 22, 16, 14, 8, 9, 20, 10, 12, 5, 8, 8, 10, 14, 5, 9, 17, 8, 8, 22, 13, 8, 20, 11, 16, 10, 20, 12, 9, 5, 9, 11, 6, 9, 19, 13, 22, 7, 14, 19, 9, 7, 14, 12, 9, 8, 13, 5, 19, 20, 12, 20, 8, 12, 18, 20, 15, 14, 9, 11, 11, 6, 23, 14, 7, 11, 14, 9, 19, 21, 23, 7, 20, 20, 10, 22, 18, 14, 12, 16, 20, 9, 14, 15, 20, 20, 15, 22, 5, 13, 9, 5, 10, 7, 12, 22, 5, 11, 16, 5, 20, 12, 17, 10, 7, 14, 8, 12, 8, 20, 5, 16, 20, 5, 14, 14, 7, 5, 11, 10, 5, 19, 10, 18, 11, 9, 11, 10, 13, 12, 8, 14, 12, 7, 23, 13, 12, 12, 5, 5, 8, 9, 21, 6, 14, 7, 12, 11, 16, 14, 9, 11, 11, 22, 13, 13, 12, 20, 7, 19, 23, 14, 8, 5, 12, 11, 12, 7]",0.12,[CLS] M I V Y P K N W I N I G Q S I Q I K E I E N T I L Q V L S E I N C N C I S F S G G L D S S L M L Y Y M L Q V Y D Q V Y A F T M G S S E E H P D V E Y S K L V V S D L E N V V H R V Y I P S Y K E L E I A E F R H G D F E G D K E V R L F Y K Y V K Q Y T D E I I A C D G I D E F M C G Y Y S H Q D K P Y E D T Y Y T H L R E L S G K H L I P L Y K N S G D V K V Y L P Y L D D G L I S L F S Q I E I S R K V D K G C R K K L L V E M A D G K I P D E I I H R R K Y G F C D V L K I K G [SEP]
,,,,


In [35]:
# the attribution score 
sample_shift[2] = 0
interpret_and_visualize(sample_shift, sample_label_shift)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[21, 11, 0, 20, 16, 12, 17, 24, 11, 17, 11, 7, 18, 10, 11, 18, 11, 12, 9, 11, 9, 17, 15, 11, 5, 18, 8, 5, 10, 9, 11, 17, 23, 17, 23, 11, 10, 19, 10, 7, 7, 5, 14, 10, 10, 5, 21, 5, 20, 20, 21, 5, 18, 8, 20, 14, 18, 8, 20, 6, 19, 15, 21, 7, 10, 10, 9, 9, 22, 16, 14, 8, 9, 20, 10, 12, 5, 8, 8, 10, 14, 5, 9, 17, 8, 8, 22, 13, 8, 20, 11, 16, 10, 20, 12, 9, 5, 9, 11, 6, 9, 19, 13, 22, 7, 14, 19, 9, 7, 14, 12, 9, 8, 13, 5, 19, 20, 12, 20, 8, 12, 18, 20, 15, 14, 9, 11, 11, 6, 23, 14, 7, 11, 14, 9, 19, 21, 23, 7, 20, 20, 10, 22, 18, 14, 12, 16, 20, 9, 14, 15, 20, 20, 15, 22, 5, 13, 9, 5, 10, 7, 12, 22, 5, 11, 16, 5, 20, 12, 17, 10, 7, 14, 8, 12, 8, 20, 5, 16, 20, 5, 14, 14, 7, 5, 11, 10, 5, 19, 10, 18, 11, 9, 11, 10, 13, 12, 8, 14, 12, 7, 23, 13, 12, 12, 5, 5, 8, 9, 21, 6, 14, 7, 12, 11, 16, 14, 9, 11, 11, 22, 13, 13, 12, 20, 7, 19, 23, 14, 8, 5, 12, 11, 12, 7]",0.12,[CLS] M I [PAD] Y P K N W I N I G Q S I Q I K E I E N T I L Q V L S E I N C N C I S F S G G L D S S L M L Y Y M L Q V Y D Q V Y A F T M G S S E E H P D V E Y S K L V V S D L E N V V H R V Y I P S Y K E L E I A E F R H G D F E G D K E V R L F Y K Y V K Q Y T D E I I A C D G I D E F M C G Y Y S H Q D K P Y E D T Y Y T H L R E L S G K H L I P L Y K N S G D V K V Y L P Y L D D G L I S L F S Q I E I S R K V D K G C R K K L L V E M A D G K I P D E I I H R R K Y G F C D V L K I K G [SEP]
,,,,


2. When the True and predicted labels are the same:

In [36]:
sample_shift2, sample_label_shift2 = get_sample_and_label(dataset['train'], 200)
interpret_and_visualize(sample_shift2, sample_label_shift2)

Seq label: 0
Input seq (raw): MTDSSEHPPIVVGITPDTGQREALLWAAAEAQHSGAPLLLVHAWGMPSMSYGAAVLASDVAANLRAQGEQALTESEQFVTDRYPQVEVTGVAADEQPAEALRARAAGAAMVVLGARPPSKRGPFPVSAVALPVMAHVHCPVAVVPEEARKPATGEPFLVVGVDGSPSAAAAARLAFGEAAARGAALRAVCAWHSPWLGSLDVQAVAGEAERTLEEVVSPLSARHPGVRVEQEAVAGHPVQVLTDAAEGATGLVVGSRGHGGFVGMLLGSVSQGVLRHARCPVVVVPPAAEP
Returned (tokenized[:10]): [21, 15, 14, 10, 10, 9, 22, 16, 16, 11]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[21, 15, 14, 10, 10, 9, 22, 16, 16, 11, 8, 8, 7, 11, 15, 16, 14, 15, 7, 18, 13, 9, 6, 5, 5, 24, 6, 6, 6, 9, 6, 18, 22, 10, 7, 6, 16, 5, 5, 5, 8, 22, 6, 24, 7, 21, 16, 10, 21, 10, 20, 7, 6, 6, 8, 5, 6, 10, 14, 8, 6, 6, 17, 5, 13, 6, 18, 7, 9, 18, 6, 5, 15, 9, 10, 9, 18, 19, 8, 15, 14, 13, 20, 16, 18, 8, 9, 8, 15, 7, 8, 6, 6, 14, 9, 18, 16, 6, 9, 6, 5, 13, 6, 13, 6, 6, 7, 6, 6, 21, 8, 8, 5, 7, 6, 13, 16, 16, 10, 12, 13, 7, 16, 19, 16, 8, 10, 6, 8, 6, 5, 16, 8, 21, 6, 22, 8, 22, 23, 16, 8, 6, 8, 8, 16, 9, 9, 6, 13, 12, 16, 6, 15, 7, 9, 16, 19, 5, 8, 8, 7, 8, 14, 7, 10, 16, 10, 6, 6, 6, 6, 6, 13, 5, 6, 19, 7, 9, 6, 6, 6, 13, 7, 6, 6, 5, 13, 6, 8, 23, 6, 24, 22, 10, 16, 24, 5, 7, 10, 5, 14, 8, 18, 6, 8, 6, 7, 9, 6, 9, 13, 15, 5, 9, 9, 8, 8, 10, 16, 5, 10, 6, 13, 22, 16, 7, 8, 13, 8, 9, 18, 9, 6, 8, 6, 7, 22, 16, 8, 18, 8, 5, 15, 14, 6, 6, 9, 7, 6, 15, 7, 5, 8, 8, 7, 10, 13, 7, 22, 7, 7, 19, 8, 7, 21, 5, 5, 7, 10, 8, 10, 18, 7, 8, 5, 13, 22, 6, 13, 23, 16, 8, 8, 8, 8, 16, 16, 6, 6, 9, 16]",11.73,[CLS] M T D S S E H P P I V V G I T P D T G Q R E A L L W A A A E A Q H S G A P L L L V H A W G M P S M S Y G A A V L A S D V A A N L R A Q G E Q A L T E S E Q F V T D R Y P Q V E V T G V A A D E Q P A E A L R A R A A G A A M V V L G A R P P S K R G P F P V S A V A L P V M A H V H C P V A V V P E E A R K P A T G E P F L V V G V D G S P S A A A A A R L A F G E A A A R G A A L R A V C A W H S P W L G S L D V Q A V A G E A E R T L E E V V S P L S A R H P G V R V E Q E A V A G H P V Q V L T D A A E G A T G L V V G S R G H G G F V G M L L G S V S Q G V L R H A R C P V V V V P P A A E P [SEP]
,,,,


In [37]:
# the attribution score decreases
sample_shift2[0] = 0
interpret_and_visualize(sample_shift2, sample_label_shift2)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[0, 15, 14, 10, 10, 9, 22, 16, 16, 11, 8, 8, 7, 11, 15, 16, 14, 15, 7, 18, 13, 9, 6, 5, 5, 24, 6, 6, 6, 9, 6, 18, 22, 10, 7, 6, 16, 5, 5, 5, 8, 22, 6, 24, 7, 21, 16, 10, 21, 10, 20, 7, 6, 6, 8, 5, 6, 10, 14, 8, 6, 6, 17, 5, 13, 6, 18, 7, 9, 18, 6, 5, 15, 9, 10, 9, 18, 19, 8, 15, 14, 13, 20, 16, 18, 8, 9, 8, 15, 7, 8, 6, 6, 14, 9, 18, 16, 6, 9, 6, 5, 13, 6, 13, 6, 6, 7, 6, 6, 21, 8, 8, 5, 7, 6, 13, 16, 16, 10, 12, 13, 7, 16, 19, 16, 8, 10, 6, 8, 6, 5, 16, 8, 21, 6, 22, 8, 22, 23, 16, 8, 6, 8, 8, 16, 9, 9, 6, 13, 12, 16, 6, 15, 7, 9, 16, 19, 5, 8, 8, 7, 8, 14, 7, 10, 16, 10, 6, 6, 6, 6, 6, 13, 5, 6, 19, 7, 9, 6, 6, 6, 13, 7, 6, 6, 5, 13, 6, 8, 23, 6, 24, 22, 10, 16, 24, 5, 7, 10, 5, 14, 8, 18, 6, 8, 6, 7, 9, 6, 9, 13, 15, 5, 9, 9, 8, 8, 10, 16, 5, 10, 6, 13, 22, 16, 7, 8, 13, 8, 9, 18, 9, 6, 8, 6, 7, 22, 16, 8, 18, 8, 5, 15, 14, 6, 6, 9, 7, 6, 15, 7, 5, 8, 8, 7, 10, 13, 7, 22, 7, 7, 19, 8, 7, 21, 5, 5, 7, 10, 8, 10, 18, 7, 8, 5, 13, 22, 6, 13, 23, 16, 8, 8, 8, 8, 16, 16, 6, 6, 9, 16]",11.7,[CLS] [PAD] T D S S E H P P I V V G I T P D T G Q R E A L L W A A A E A Q H S G A P L L L V H A W G M P S M S Y G A A V L A S D V A A N L R A Q G E Q A L T E S E Q F V T D R Y P Q V E V T G V A A D E Q P A E A L R A R A A G A A M V V L G A R P P S K R G P F P V S A V A L P V M A H V H C P V A V V P E E A R K P A T G E P F L V V G V D G S P S A A A A A R L A F G E A A A R G A A L R A V C A W H S P W L G S L D V Q A V A G E A E R T L E E V V S P L S A R H P G V R V E Q E A V A G H P V Q V L T D A A E G A T G L V V G S R G H G G F V G M L L G S V S Q G V L R H A R C P V V V V P P A A E P [SEP]
,,,,


In [38]:
# the attribution score decreases 
sample_shift2[119] = 0
interpret_and_visualize(sample_shift2, sample_label_shift2)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[0, 15, 14, 10, 10, 9, 22, 16, 16, 11, 8, 8, 7, 11, 15, 16, 14, 15, 7, 18, 13, 9, 6, 5, 5, 24, 6, 6, 6, 9, 6, 18, 22, 10, 7, 6, 16, 5, 5, 5, 8, 22, 6, 24, 7, 21, 16, 10, 21, 10, 20, 7, 6, 6, 8, 5, 6, 10, 14, 8, 6, 6, 17, 5, 13, 6, 18, 7, 9, 18, 6, 5, 15, 9, 10, 9, 18, 19, 8, 15, 14, 13, 20, 16, 18, 8, 9, 8, 15, 7, 8, 6, 6, 14, 9, 18, 16, 6, 9, 6, 5, 13, 6, 13, 6, 6, 7, 6, 6, 21, 8, 8, 5, 7, 6, 13, 16, 16, 10, 0, 13, 7, 16, 19, 16, 8, 10, 6, 8, 6, 5, 16, 8, 21, 6, 22, 8, 22, 23, 16, 8, 6, 8, 8, 16, 9, 9, 6, 13, 12, 16, 6, 15, 7, 9, 16, 19, 5, 8, 8, 7, 8, 14, 7, 10, 16, 10, 6, 6, 6, 6, 6, 13, 5, 6, 19, 7, 9, 6, 6, 6, 13, 7, 6, 6, 5, 13, 6, 8, 23, 6, 24, 22, 10, 16, 24, 5, 7, 10, 5, 14, 8, 18, 6, 8, 6, 7, 9, 6, 9, 13, 15, 5, 9, 9, 8, 8, 10, 16, 5, 10, 6, 13, 22, 16, 7, 8, 13, 8, 9, 18, 9, 6, 8, 6, 7, 22, 16, 8, 18, 8, 5, 15, 14, 6, 6, 9, 7, 6, 15, 7, 5, 8, 8, 7, 10, 13, 7, 22, 7, 7, 19, 8, 7, 21, 5, 5, 7, 10, 8, 10, 18, 7, 8, 5, 13, 22, 6, 13, 23, 16, 8, 8, 8, 8, 16, 16, 6, 6, 9, 16]",11.66,[CLS] [PAD] T D S S E H P P I V V G I T P D T G Q R E A L L W A A A E A Q H S G A P L L L V H A W G M P S M S Y G A A V L A S D V A A N L R A Q G E Q A L T E S E Q F V T D R Y P Q V E V T G V A A D E Q P A E A L R A R A A G A A M V V L G A R P P S [PAD] R G P F P V S A V A L P V M A H V H C P V A V V P E E A R K P A T G E P F L V V G V D G S P S A A A A A R L A F G E A A A R G A A L R A V C A W H S P W L G S L D V Q A V A G E A E R T L E E V V S P L S A R H P G V R V E Q E A V A G H P V Q V L T D A A E G A T G L V V G S R G H G G F V G M L L G S V S Q G V L R H A R C P V V V V P P A A E P [SEP]
,,,,


#### TODO: Are predictions of 100% OK?

Predictions 1.00 could be OK, if it's just a rounding error:

In [39]:
print(dataset['test'][0])

{'seq': 'MKLEAVYGLHAVTTLLQRSPDQVVELWVMKGRQDQRMQRVLELAAEQGLDIREADKGLMNQKADEGNHQGIIAWRKPVQNKNEKHLPDILDSISGNALILILDGVTDPHNLGACLRTADAAGVQVVIAPKDKSAPLNATAAKVACGAAEAVPYIQVTNLARTMKELQERGIWIVGTAGEATHSIYQQDFTGPTALVMGAEGAGMRRLTREHCDYLVNIPMAGEVSSVNVSVATGICLFEAVRQRQLS;1', 'label': 1}


In [40]:
ex, ex_label = get_sample_and_label(dataset['test'], 0)
ex_torch = torch.tensor([ex], device=device)
prediction = model(ex_torch)[0]
prediction

Seq label: 1
Input seq (raw): MKLEAVYGLHAVTTLLQRSPDQVVELWVMKGRQDQRMQRVLELAAEQGLDIREADKGLMNQKADEGNHQGIIAWRKPVQNKNEKHLPDILDSISGNALILILDGVTDPHNLGACLRTADAAGVQVVIAPKDKSAPLNATAAKVACGAAEAVPYIQVTNLARTMKELQERGIWIVGTAGEATHSIYQQDFTGPTALVMGAEGAGMRRLTREHCDYLVNIPMAGEVSSVNVSVATGICLFEAVRQRQLS;1
Returned (tokenized[:10]): [21, 12, 5, 9, 6, 8, 20, 7, 5, 22]


tensor([[-6.5988,  6.7009]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [41]:
softmax1 = torch.softmax(prediction, dim=1)
softmax1

tensor([[1.6751e-06, 1.0000e+00]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [42]:
softmax2 = torch.softmax(prediction, dim=1)[0][0].unsqueeze(-1)
softmax2

tensor([1.6751e-06], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

In [43]:
predicted_class = torch.argmax(softmax1)
predicted_class

tensor(1, device='cuda:0')

### Test data:

In [44]:
for i in range(20):
    print(f'i={i}')
    sample, sample_label = get_sample_and_label(dataset['test'], i)
    interpret_and_visualize(sample, sample_label)

i=0
Seq label: 1
Input seq (raw): MKLEAVYGLHAVTTLLQRSPDQVVELWVMKGRQDQRMQRVLELAAEQGLDIREADKGLMNQKADEGNHQGIIAWRKPVQNKNEKHLPDILDSISGNALILILDGVTDPHNLGACLRTADAAGVQVVIAPKDKSAPLNATAAKVACGAAEAVPYIQVTNLARTMKELQERGIWIVGTAGEATHSIYQQDFTGPTALVMGAEGAGMRRLTREHCDYLVNIPMAGEVSSVNVSVATGICLFEAVRQRQLS;1
Returned (tokenized[:10]): [21, 12, 5, 9, 6, 8, 20, 7, 5, 22]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 12, 5, 9, 6, 8, 20, 7, 5, 22, 6, 8, 15, 15, 5, 5, 18, 13, 10, 16, 14, 18, 8, 8, 9, 5, 24, 8, 21, 12, 7, 13, 18, 14, 18, 13, 21, 18, 13, 8, 5, 9, 5, 6, 6, 9, 18, 7, 5, 14, 11, 13, 9, 6, 14, 12, 7, 5, 21, 17, 18, 12, 6, 14, 9, 7, 17, 22, 18, 7, 11, 11, 6, 24, 13, 12, 16, 8, 18, 17, 12, 17, 9, 12, 22, 5, 16, 14, 11, 5, 14, 10, 11, 10, 7, 17, 6, 5, 11, 5, 11, 5, 14, 7, 8, 15, 14, 16, 22, 17, 5, 7, 6, 23, 5, 13, 15, 6, 14, 6, 6, 7, 8, 18, 8, 8, 11, 6, 16, 12, 14, 12, 10, 6, 16, 5, 17, 6, 15, 6, 6, 12, 8, 6, 23, 7, 6, 6, 9, 6, 8, 16, 20, 11, 18, 8, 15, 17, 5, 6, 13, 15, 21, 12, 9, 5, 18, 9, 13, 7, 11, 24, 11, 8, 7, 15, 6, 7, 9, 6, 15, 22, 10, 11, 20, 18, 18, 14, 19, 15, 7, 16, 15, 6, 5, 8, 21, 7, 6, 9, 7, 6, 7, 21, 13, 13, 5, 15, 13, 9, 22, 23, 14, 20, 5, 8, 17, 11, 16, 21, 6, 7, 9, 8, 10, 10, 8, 17, 8, 10, 8, 6, 15, 7, 11, 23, 5, 19, 9, 6, 8, 13, 18, 13, 18, 5, 10, 1, 1]",10.99,[CLS] M K L E A V Y G L H A V T T L L Q R S P D Q V V E L W V M K G R Q D Q R M Q R V L E L A A E Q G L D I R E A D K G L M N Q K A D E G N H Q G I I A W R K P V Q N K N E K H L P D I L D S I S G N A L I L I L D G V T D P H N L G A C L R T A D A A G V Q V V I A P K D K S A P L N A T A A K V A C G A A E A V P Y I Q V T N L A R T M K E L Q E R G I W I V G T A G E A T H S I Y Q Q D F T G P T A L V M G A E G A G M R R L T R E H C D Y L V N I P M A G E V S S V N V S V A T G I C L F E A V R Q R Q L S [UNK] [UNK] [SEP]
,,,,


i=1
Seq label: 1
Input seq (raw): MGLHVVLYQPEIPQNTGNIMRTCAGTNTTLHLIEPLGFKVDDKSLKRSGVNYLEHTKFFVYPDFDTFLSKNQGEFLFFTRYGKKTPDQFDLSNSDKNIYLVFGRESTGIPKSILREHLDRCTRYPMNENIRSLNLSNTVCLGIYEVLRQQNYQGLSKTEPESMKGEDWLIKD;1
Returned (tokenized[:10]): [21, 7, 5, 22, 8, 8, 5, 20, 18, 16]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 7, 5, 22, 8, 8, 5, 20, 18, 16, 9, 11, 16, 18, 17, 15, 7, 17, 11, 21, 13, 15, 23, 6, 7, 15, 17, 15, 15, 5, 22, 5, 11, 9, 16, 5, 7, 19, 12, 8, 14, 14, 12, 10, 5, 12, 13, 10, 7, 8, 17, 20, 5, 9, 22, 15, 12, 19, 19, 8, 20, 16, 14, 19, 14, 15, 19, 5, 10, 12, 17, 18, 7, 9, 19, 5, 19, 19, 15, 13, 20, 7, 12, 12, 15, 16, 14, 18, 19, 14, 5, 10, 17, 10, 14, 12, 17, 11, 20, 5, 8, 19, 7, 13, 9, 10, 15, 7, 11, 16, 12, 10, 11, 5, 13, 9, 22, 5, 14, 13, 23, 15, 13, 20, 16, 21, 17, 9, 17, 11, 13, 10, 5, 17, 5, 10, 17, 15, 8, 23, 5, 7, 11, 20, 9, 8, 5, 13, 18, 18, 17, 20, 18, 7, 5, 10, 12, 15, 9, 16, 9, 10, 21, 12, 7, 9, 14, 24, 5, 11, 12, 14, 1, 1]",5.87,[CLS] M G L H V V L Y Q P E I P Q N T G N I M R T C A G T N T T L H L I E P L G F K V D D K S L K R S G V N Y L E H T K F F V Y P D F D T F L S K N Q G E F L F F T R Y G K K T P D Q F D L S N S D K N I Y L V F G R E S T G I P K S I L R E H L D R C T R Y P M N E N I R S L N L S N T V C L G I Y E V L R Q Q N Y Q G L S K T E P E S M K G E D W L I K D [UNK] [UNK] [SEP]
,,,,


i=2
Seq label: 1
Input seq (raw): MRLEIVAVGRRPPAWITEGFETFAARMPRHLPLGLREVNAGDARRSGDVVRARAQEADHLLSAVGDARLIALEETGKAWTTRDLADYLGDAMQQGDDLAFVIGGADGLDPRCLQAAERRWSLSALTLPHMLVRVVVAEQLYRAWTLLAGHPYHRGGSPDCARRPL;1
Returned (tokenized[:10]): [21, 13, 5, 9, 11, 8, 6, 8, 7, 13]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 13, 5, 9, 11, 8, 6, 8, 7, 13, 13, 16, 16, 6, 24, 11, 15, 9, 7, 19, 9, 15, 19, 6, 6, 13, 21, 16, 13, 22, 5, 16, 5, 7, 5, 13, 9, 8, 17, 6, 7, 14, 6, 13, 13, 10, 7, 14, 8, 8, 13, 6, 13, 6, 18, 9, 6, 14, 22, 5, 5, 10, 6, 8, 7, 14, 6, 13, 5, 11, 6, 5, 9, 9, 15, 7, 12, 6, 24, 15, 15, 13, 14, 5, 6, 14, 20, 5, 7, 14, 6, 21, 18, 18, 7, 14, 14, 5, 6, 19, 8, 11, 7, 7, 6, 14, 7, 5, 14, 16, 13, 23, 5, 18, 6, 6, 9, 13, 13, 24, 10, 5, 10, 6, 5, 15, 5, 16, 22, 21, 5, 8, 13, 8, 8, 8, 6, 9, 18, 5, 20, 13, 6, 24, 15, 5, 5, 6, 7, 22, 16, 20, 22, 13, 7, 7, 10, 16, 14, 23, 6, 13, 13, 16, 5, 1, 1]",6.58,[CLS] M R L E I V A V G R R P P A W I T E G F E T F A A R M P R H L P L G L R E V N A G D A R R S G D V V R A R A Q E A D H L L S A V G D A R L I A L E E T G K A W T T R D L A D Y L G D A M Q Q G D D L A F V I G G A D G L D P R C L Q A A E R R W S L S A L T L P H M L V R V V V A E Q L Y R A W T L L A G H P Y H R G G S P D C A R R P L [UNK] [UNK] [SEP]
,,,,


i=3
Seq label: 1
Input seq (raw): MGDETVKRIESPKNARVKQWKKLQTKKGRDETGLFLLEGFHLVEEAVKSRAPLVELMVDERTAIPPGWDVSVPVVIVTEAVMKAISSTETPQGIAAVCRQLPAELEGVKTALLIDAVQDPGNLGTMIRTADAAGIDAVILGEGCADVYNPKVVRATQGSLFHLPVVKGDLAQWIARFKEQGIPVYGTALENAVDYRTVPPSSSFALLVGNEGSGVRREWLEMTTETIYIPIYGQAESLNVAVAAGILLYSLQAVR;1
Returned (tokenized[:10]): [21, 7, 14, 9, 15, 8, 12, 13, 11, 9]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 7, 14, 9, 15, 8, 12, 13, 11, 9, 10, 16, 12, 17, 6, 13, 8, 12, 18, 24, 12, 12, 5, 18, 15, 12, 12, 7, 13, 14, 9, 15, 7, 5, 19, 5, 5, 9, 7, 19, 22, 5, 8, 9, 9, 6, 8, 12, 10, 13, 6, 16, 5, 8, 9, 5, 21, 8, 14, 9, 13, 15, 6, 11, 16, 16, 7, 24, 14, 8, 10, 8, 16, 8, 8, 11, 8, 15, 9, 6, 8, 21, 12, 6, 11, 10, 10, 15, 9, 15, 16, 18, 7, 11, 6, 6, 8, 23, 13, 18, 5, 16, 6, 9, 5, 9, 7, 8, 12, 15, 6, 5, 5, 11, 14, 6, 8, 18, 14, 16, 7, 17, 5, 7, 15, 21, 11, 13, 15, 6, 14, 6, 6, 7, 11, 14, 6, 8, 11, 5, 7, 9, 7, 23, 6, 14, 8, 20, 17, 16, 12, 8, 8, 13, 6, 15, 18, 7, 10, 5, 19, 22, 5, 16, 8, 8, 12, 7, 14, 5, 6, 18, 24, 11, 6, 13, 19, 12, 9, 18, 7, 11, 16, 8, 20, 7, 15, 6, 5, 9, 17, 6, 8, 14, 20, 13, 15, 8, 16, 16, 10, 10, 10, 19, 6, 5, 5, 8, 7, 17, 9, 7, 10, 7, 8, 13, 13, 9, 24, 5, 9, 21, 15, 15, 9, 15, 11, 20, 11, 16, 11, 20, 7, 18, 6, 9, 10, 5, 17, 8, 6, 8, 6, 6, 7, 11, 5, 5, 20, 10, 5, 18, 6, 8, 13, 1, 1]",10.65,[CLS] M G D E T V K R I E S P K N A R V K Q W K K L Q T K K G R D E T G L F L L E G F H L V E E A V K S R A P L V E L M V D E R T A I P P G W D V S V P V V I V T E A V M K A I S S T E T P Q G I A A V C R Q L P A E L E G V K T A L L I D A V Q D P G N L G T M I R T A D A A G I D A V I L G E G C A D V Y N P K V V R A T Q G S L F H L P V V K G D L A Q W I A R F K E Q G I P V Y G T A L E N A V D Y R T V P P S S S F A L L V G N E G S G V R R E W L E M T T E T I Y I P I Y G Q A E S L N V A V A A G I L L Y S L Q A V R [UNK] [UNK] [SEP]
,,,,


i=4
Seq label: 1
Input seq (raw): MIQEGGQGGEQPVLRLDLALVHYPVCNKNGETIGSAVTNLDLHDIARAGRTFGIDTLYIVTPFADQQALVRDILAHWQTGHGATYNPKRKEALALVRLCHDLAELYELVQAKWRQRPTVLATSAKAQANQLDFTEARRRIFSGEPHLILFGTGWGMAPEVFAEVDALLPPIVGLGEYNHLSVRSAAAIVLDRVSGIH;1
Returned (tokenized[:10]): [21, 11, 18, 9, 7, 7, 18, 7, 7, 9]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 11, 18, 9, 7, 7, 18, 7, 7, 9, 18, 16, 8, 5, 13, 5, 14, 5, 6, 5, 8, 22, 20, 16, 8, 23, 17, 12, 17, 7, 9, 15, 11, 7, 10, 6, 8, 15, 17, 5, 14, 5, 22, 14, 11, 6, 13, 6, 7, 13, 15, 19, 7, 11, 14, 15, 5, 20, 11, 8, 15, 16, 19, 6, 14, 18, 18, 6, 5, 8, 13, 14, 11, 5, 6, 22, 24, 18, 15, 7, 22, 7, 6, 15, 20, 17, 16, 12, 13, 12, 9, 6, 5, 6, 5, 8, 13, 5, 23, 22, 14, 5, 6, 9, 5, 20, 9, 5, 8, 18, 6, 12, 24, 13, 18, 13, 16, 15, 8, 5, 6, 15, 10, 6, 12, 6, 18, 6, 17, 18, 5, 14, 19, 15, 9, 6, 13, 13, 13, 11, 19, 10, 7, 9, 16, 22, 5, 11, 5, 19, 7, 15, 7, 24, 7, 21, 6, 16, 9, 8, 19, 6, 9, 8, 14, 6, 5, 5, 16, 16, 11, 8, 7, 5, 7, 9, 20, 17, 22, 5, 10, 8, 13, 10, 6, 6, 6, 11, 8, 5, 14, 13, 8, 10, 7, 11, 22, 1, 1]",8.63,[CLS] M I Q E G G Q G G E Q P V L R L D L A L V H Y P V C N K N G E T I G S A V T N L D L H D I A R A G R T F G I D T L Y I V T P F A D Q Q A L V R D I L A H W Q T G H G A T Y N P K R K E A L A L V R L C H D L A E L Y E L V Q A K W R Q R P T V L A T S A K A Q A N Q L D F T E A R R R I F S G E P H L I L F G T G W G M A P E V F A E V D A L L P P I V G L G E Y N H L S V R S A A A I V L D R V S G I H [UNK] [UNK] [SEP]
,,,,


i=5
Seq label: 1
Input seq (raw): MRLDVLTIFPEYLDPLRHALLGKAIEDGTLEVGVHDLRNWATGGHKAVDDTPYGGGPGMVMKPEVWGPALDDVAAGHVVGAELNSAAPHLKNARHDELGGVEKRIYAADDEDLDLPLLLVPTPAGKPFTQADARAWSNEKHIVFACGRYEGIDQRVIDDAAKRYRVREVSIGDYVLIGGEVAVLVIAEAVVRLIPGVLGNRRSHEEDSFSDGLLEGPSYTKPRTWRGLDVPEVLFSGNHAKVDRWRRDQALLRTQAIRPELIDASLLDSTDLKVLGLDK;1
Returned (tokenized[:10]): [21, 13, 5, 14, 8, 5, 15, 11, 19, 16]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 13, 5, 14, 8, 5, 15, 11, 19, 16, 9, 20, 5, 14, 16, 5, 13, 22, 6, 5, 5, 7, 12, 6, 11, 9, 14, 7, 15, 5, 9, 8, 7, 8, 22, 14, 5, 13, 17, 24, 6, 15, 7, 7, 22, 12, 6, 8, 14, 14, 15, 16, 20, 7, 7, 7, 16, 7, 21, 8, 21, 12, 16, 9, 8, 24, 7, 16, 6, 5, 14, 14, 8, 6, 6, 7, 22, 8, 8, 7, 6, 9, 5, 17, 10, 6, 6, 16, 22, 5, 12, 17, 6, 13, 22, 14, 9, 5, 7, 7, 8, 9, 12, 13, 11, 20, 6, 6, 14, 14, 9, 14, 5, 14, 5, 16, 5, 5, 5, 8, 16, 15, 16, 6, 7, 12, 16, 19, 15, 18, 6, 14, 6, 13, 6, 24, 10, 17, 9, 12, 22, 11, 8, 19, 6, 23, 7, 13, 20, 9, 7, 11, 14, 18, 13, 8, 11, 14, 14, 6, 6, 12, 13, 20, 13, 8, 13, 9, 8, 10, 11, 7, 14, 20, 8, 5, 11, 7, 7, 9, 8, 6, 8, 5, 8, 11, 6, 9, 6, 8, 8, 13, 5, 11, 16, 7, 8, 5, 7, 17, 13, 13, 10, 22, 9, 9, 14, 10, 19, 10, 14, 7, 5, 5, 9, 7, 16, 10, 20, 15, 12, 16, 13, 15, 24, 13, 7, 5, 14, 8, 16, 9, 8, 5, 19, 10, 7, 17, 22, 6, 12, 8, 14, 13, 24, 13, 13, 14, 18, 6, 5, 5, 13, 15, 18, 6, 11, 13, 16, 9, 5, 11, 14, 6, 10, 5, 5, 14, 10, 15, 14, 5, 12, 8, 5, 7, 5, 14, 12, 1, 1]",11.57,[CLS] M R L D V L T I F P E Y L D P L R H A L L G K A I E D G T L E V G V H D L R N W A T G G H K A V D D T P Y G G G P G M V M K P E V W G P A L D D V A A G H V V G A E L N S A A P H L K N A R H D E L G G V E K R I Y A A D D E D L D L P L L L V P T P A G K P F T Q A D A R A W S N E K H I V F A C G R Y E G I D Q R V I D D A A K R Y R V R E V S I G D Y V L I G G E V A V L V I A E A V V R L I P G V L G N R R S H E E D S F S D G L L E G P S Y T K P R T W R G L D V P E V L F S G N H A K V D R W R R D Q A L L R T Q A I R P E L I D A S L L D S T D L K V L G L D K [UNK] [UNK] [SEP]
,,,,


i=6
Seq label: 1
Input seq (raw): MTISKAKIKYIRSLEAKKHRDAEGVFVAEGPKVVGDLLAIMPAKLLVATSQWQTPEHLAATTELINVSEDELQKISFLRAPQQVMAVFPKPNQQESGLDTLVATNELTLMLDGIQDPGNLGTIIRLADWFGIRHVVCSNDTADVFNPKVIQATMGSIARVKVSYTPLEPLLDVLPASLPVYGTLLDGTNIYQQDLSSNGIIVMGNEGKGLSPAVRQRVSHKLLIPRFVGTEQGAESLNVAIATAIVCAEFRRQGAQMR;1
Returned (tokenized[:10]): [21, 15, 11, 10, 12, 6, 12, 11, 12, 20]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
1.0,1 (0.00),"[21, 15, 11, 10, 12, 6, 12, 11, 12, 20, 11, 13, 10, 5, 9, 6, 12, 12, 22, 13, 14, 6, 9, 7, 8, 19, 8, 6, 9, 7, 16, 12, 8, 8, 7, 14, 5, 5, 6, 11, 21, 16, 6, 12, 5, 5, 8, 6, 15, 10, 18, 24, 18, 15, 16, 9, 22, 5, 6, 6, 15, 15, 9, 5, 11, 17, 8, 10, 9, 14, 9, 5, 18, 12, 11, 10, 19, 5, 13, 6, 16, 18, 18, 8, 21, 6, 8, 19, 16, 12, 16, 17, 18, 18, 9, 10, 7, 5, 14, 15, 5, 8, 6, 15, 17, 9, 5, 15, 5, 21, 5, 14, 7, 11, 18, 14, 16, 7, 17, 5, 7, 15, 11, 11, 13, 5, 6, 14, 24, 19, 7, 11, 13, 22, 8, 8, 23, 10, 17, 14, 15, 6, 14, 8, 19, 17, 16, 12, 8, 11, 18, 6, 15, 21, 7, 10, 11, 6, 13, 8, 12, 8, 10, 20, 15, 16, 5, 9, 16, 5, 5, 14, 8, 5, 16, 6, 10, 5, 16, 8, 20, 7, 15, 5, 5, 14, 7, 15, 17, 11, 20, 18, 18, 14, 5, 10, 10, 17, 7, 11, 11, 8, 21, 7, 17, 9, 7, 12, 7, 5, 10, 16, 6, 8, 13, 18, 13, 8, 10, 22, 12, 5, 5, 11, 16, 13, 19, 8, 7, 15, 9, 18, 7, 6, 9, 10, 5, 17, 8, 6, 11, 6, 15, 6, 11, 8, 23, 6, 9, 19, 13, 13, 18, 7, 6, 18, 21, 13, 1, 1]",5.86,[CLS] M T I S K A K I K Y I R S L E A K K H R D A E G V F V A E G P K V V G D L L A I M P A K L L V A T S Q W Q T P E H L A A T T E L I N V S E D E L Q K I S F L R A P Q Q V M A V F P K P N Q Q E S G L D T L V A T N E L T L M L D G I Q D P G N L G T I I R L A D W F G I R H V V C S N D T A D V F N P K V I Q A T M G S I A R V K V S Y T P L E P L L D V L P A S L P V Y G T L L D G T N I Y Q Q D L S S N G I I V M G N E G K G L S P A V R Q R V S H K L L I P R F V G T E Q G A E S L N V A I A T A I V C A E F R R Q G A Q M R [UNK] [UNK] [SEP]
,,,,


i=7
Seq label: 0
Input seq (raw): RFRMPDRDCTWNDLIRGEITFAAGQIPDFVLQRANGEPLYTLVNPTDDAAMKITHVLRGEDLLSSTPRQIALYEAMIDLGVFDGPVPQFGHLPYVMGEGNKKLSKRDPESSLQMYRDRGYLPEALTNYLALLGWSPGGDVEFFSKEQMAQSFSLERVNPNPARFDVKKCTAINGDWTRHLAIDDLVERLVPYLQRDGVIGSTPSAEDMNLVRAVVPLISERLETLGQASAMVGFLFTDDIAIDAGDAEKIMGDQADDVLREAEAALAGLDEWTTEAIERALRATLIEERGLKPKMAFGPVRLAITGRRVSPPLFESMELLGADRSLSRIRSLHAS
Returned (tokenized[:10]): [13, 19, 13, 21, 16, 14, 13, 14, 23, 15]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[13, 19, 13, 21, 16, 14, 13, 14, 23, 15, 24, 17, 14, 5, 11, 13, 7, 9, 11, 15, 19, 6, 6, 7, 18, 11, 16, 14, 19, 8, 5, 18, 13, 6, 17, 7, 9, 16, 5, 20, 15, 5, 8, 17, 16, 15, 14, 14, 6, 6, 21, 12, 11, 15, 22, 8, 5, 13, 7, 9, 14, 5, 5, 10, 10, 15, 16, 13, 18, 11, 6, 5, 20, 9, 6, 21, 11, 14, 5, 7, 8, 19, 14, 7, 16, 8, 16, 18, 19, 7, 22, 5, 16, 20, 8, 21, 7, 9, 7, 17, 12, 12, 5, 10, 12, 13, 14, 16, 9, 10, 10, 5, 18, 21, 20, 13, 14, 13, 7, 20, 5, 16, 9, 6, 5, 15, 17, 20, 5, 6, 5, 5, 7, 24, 10, 16, 7, 7, 14, 8, 9, 19, 19, 10, 12, 9, 18, 21, 6, 18, 10, 19, 10, 5, 9, 13, 8, 17, 16, 17, 16, 6, 13, 19, 14, 8, 12, 12, 23, 15, 6, 11, 17, 7, 14, 24, 15, 13, 22, 5, 6, 11, 14, 14, 5, 8, 9, 13, 5, 8, 16, 20, 5, 18, 13, 14, 7, 8, 11, 7, 10, 15, 16, 10, 6, 9, 14, 21, 17, 5, 8, 13, 6, 8, 8, 16, 5, 11, 10, 9, 13, 5, 9, 15, 5, 7, 18, 6, 10, 6, 21, 8, 7, 19, 5, 19, 15, 14, 14, 11, 6, 11, 14, 6, 7, 14, 6, 9, 12, 11, 21, 7, 14, 18, 6, 14, 14, 8, 5, 13, 9, 6, 9, 6, 6, 5, 6, 7, 5, 14, 9, 24, 15, 15, 9, 6, 11, 9, 13, 6, 5, 13, 6, 15, 5, 11, 9, 9, 13, 7, 5, 12, 16, 12, 21, 6, 19, 7, 16, 8, 13, 5, 6, 11, 15, 7, 13, 13, 8, 10, 16, 16, 5, 19, 9, 10, 21, 9, 5, 5, 7, 6, 14, 13, 10, 5, 10, 13, 11, 13, 10, 5, 22, 6, 10]",1.02,[CLS] R F R M P D R D C T W N D L I R G E I T F A A G Q I P D F V L Q R A N G E P L Y T L V N P T D D A A M K I T H V L R G E D L L S S T P R Q I A L Y E A M I D L G V F D G P V P Q F G H L P Y V M G E G N K K L S K R D P E S S L Q M Y R D R G Y L P E A L T N Y L A L L G W S P G G D V E F F S K E Q M A Q S F S L E R V N P N P A R F D V K K C T A I N G D W T R H L A I D D L V E R L V P Y L Q R D G V I G S T P S A E D M N L V R A V V P L I S E R L E T L G Q A S A M V G F L F T D D I A I D A G D A E K I M G D Q A D D V L R E A E A A L A G L D E W T T E A I E R A L R A T L I E E R G L K P K M A F G P V R L A I T G R R V S P P L F E S M E L L G A D R S L S R I R S L H A S [SEP]
,,,,


i=8
Seq label: 0
Input seq (raw): MAVALFNVAGGCTRPATRFDIEADSGEPVLPSDKNLEKALDNSLDIVQHRGPDARGQWISPDRLVGFGHVRLSIVDLSSGGNQPFHDSREEVHAVVNGELYGHEEYRAALSNEFDFKGHSDCEIAIALYQHYGLSFLSHLRGEFALVLWDAKRQLFFAARDRYGAKSLYYTFVNGQLLVATEMKSFLAFGWQPEWCTRSIREKTWHHHSATFFKGIRKVKPGHFLTSRNFCPVEQGQYWDLDYPVKTKLETRTEAEMIAGVRERLLEAVRLRLCADVPVGVFLSGGLDSSAIAGMVAHLVRHEGAKIGNDASSKASRIECFTVQFDKESGVDESDVAQRTADWLGFGFHPVPMDEQSLVSRLEDVTWHSESPLPDVNGMGRLAMAERARAAGLKVVLTGEGADEHFGGYAELRADALLEPDLSWPPSCFPEREETWKSVVAGPSTHTHVSAVKDPSSPSSTRRMFNSTKVPNMASLVNALPFAPWMVHDIDTNPETALAESLSGPAREAIAHKWHPLHTSEYVFVKSPLSNFILRYNGDNIDMIHQIESRPVFLDHRLTEYANGLPPGLKMKYNPQDGDFREKHILREAVKPFVTEEVYNRRKQPFMGPSRFAAGGPLHQKLKGLLTRENVEALGFVDWSRVSAYLERAFQEKDSLSLRPALLTAQFVVLSRRFHVPKAKPVQKGENVLLNPGSDDPEEDMPFHRNFTVFNPPYMKLGNSGL
Returned (tokenized[:10]): [21, 6, 8, 6, 5, 19, 17, 8, 6, 7]


RuntimeError: CUDA out of memory. Tried to allocate 1.56 GiB (GPU 0; 44.56 GiB total capacity; 41.47 GiB already allocated; 1.32 GiB free; 42.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

### Train data:

In [22]:
# check memory usage:
torch.cuda.memory_summary(device=None, abbreviated=False)



In [19]:
for i in range(20):
    print(f'i={i}')
    sample, sample_label = get_sample_and_label(dataset['train'], i)
    interpret_and_visualize(sample, sample_label)

i=0
Seq label: 0
Input seq (raw): MKKKILQLTLENAIAFKGKANPKAVINKIIPTVKDKSKLKAIGNEVSATIKKVNKLSLSKQKEQLKKINPTFFNKKIKVKKGIIDLPKVGKNFRARFAPSASGPLHIGHALVISLNKIYADKYKGKHILRIEDTNPDANFKEFYKMIPKDYTWLAGKPSETYIQSARVKTYYKYAEQLIKAGHLYVCEETPEEVKAKLKKGIQPFGRRDDPKEVLRKWKRMLTGKYNPGESVVRVKTDLKGKNPALKEWVAFRISGGTHPKVGNKVRVWPLMNFAVAIDDYELKMTHVIRGKDHEDNTKKQKMIYDFFGWTYPEYIHLGRINFKNMIISASDIRKGVEEGIYKGYDDEQVESLASIRKRGIKPKALLKFFYEIGPTKRDKTVDKKEVKHNK
Returned (tokenized[:10]): [21, 12, 12, 12, 11, 5, 18, 5, 15, 5]
[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (1.00),"[21, 12, 12, 12, 11, 5, 18, 5, 15, 5, 9, 17, 6, 11, 6, 19, 12, 7, 12, 6, 17, 16, 12, 6, 8, 11, 17, 12, 11, 11, 16, 15, 8, 12, 14, 12, 10, 12, 5, 12, 6, 11, 7, 17, 9, 8, 10, 6, 15, 11, 12, 12, 8, 17, 12, 5, 10, 5, 10, 12, 18, 12, 9, 18, 5, 12, 12, 11, 17, 16, 15, 19, 19, 17, 12, 12, 11, 12, 8, 12, 12, 7, 11, 11, 14, 5, 16, 12, 8, 7, 12, 17, 19, 13, 6, 13, 19, 6, 16, 10, 6, 10, 7, 16, 5, 22, 11, 7, 22, 6, 5, 8, 11, 10, 5, 17, 12, 11, 20, 6, 14, 12, 20, 12, 7, 12, 22, 11, 5, 13, 11, 9, 14, 15, 17, 16, 14, 6, 17, 19, 12, 9, 19, 20, 12, 21, 11, 16, 12, 14, 20, 15, 24, 5, 6, 7, 12, 16, 10, 9, 15, 20, 11, 18, 10, 6, 13, 8, 12, 15, 20, 20, 12, 20, 6, 9, 18, 5, 11, 12, 6, 7, 22, 5, 20, 8, 23, 9, 9, 15, 16, 9, 9, 8, 12, 6, 12, 5, 12, 12, 7, 11, 18, 16, 19, 7, 13, 13, 14, 14, 16, 12, 9, 8, 5, 13, 12, 24, 12, 13, 21, 5, 15, 7, 12, 20, 17, 16, 7, 9, 10, 8, 8, 13, 8, 12, 15, 14, 5, 12, 7, 12, 17, 16, 6, 5, 12, 9, 24, 8, 6, 19, 13, 11, 10, 7, 7, 15, 22, 16, 12, 8, 7, 17, 12, 8, 13, 8, 24, 16, 5, 21, 17, 19, 6, 8, 6, 11, 14, 14, 20, 9, 5, 12, 21, 15, 22, 8, 11, 13, 7, 12, 14, 22, 9, 14, 17, 15, 12, 12, 18, 12, 21, 11, 20, 14, 19, 19, 7, 24, 15, 20, 16, 9, 20, 11, 22, 5, 7, 13, 11, 17, 19, 12, 17, 21, 11, 11, 10, 6, 10, 14, 11, 13, 12, 7, 8, 9, 9, 7, 11, 20, 12, 7, 20, 14, 14, 9, 18, 8, 9, 10, 5, 6, 10, 11, 13, 12, 13, 7, 11, 12, 16, 12, 6, 5, 5, 12, 19, 19, 20, 9, 11, 7, 16, 15, 12, 13, 14, 12, 15, 8, 14, 12, 12, 9, 8, 12, 22, 17, 12]",9.66,[CLS] M K K K I L Q L T L E N A I A F K G K A N P K A V I N K I I P T V K D K S K L K A I G N E V S A T I K K V N K L S L S K Q K E Q L K K I N P T F F N K K I K V K K G I I D L P K V G K N F R A R F A P S A S G P L H I G H A L V I S L N K I Y A D K Y K G K H I L R I E D T N P D A N F K E F Y K M I P K D Y T W L A G K P S E T Y I Q S A R V K T Y Y K Y A E Q L I K A G H L Y V C E E T P E E V K A K L K K G I Q P F G R R D D P K E V L R K W K R M L T G K Y N P G E S V V R V K T D L K G K N P A L K E W V A F R I S G G T H P K V G N K V R V W P L M N F A V A I D D Y E L K M T H V I R G K D H E D N T K K Q K M I Y D F F G W T Y P E Y I H L G R I N F K N M I I S A S D I R K G V E E G I Y K G Y D D E Q V E S L A S I R K R G I K P K A L L K F F Y E I G P T K R D K T V D K K E V K H N K [SEP]
,,,,


i=1
Seq label: 0
Input seq (raw): MCGIFGVVEFRGGTVDKSLIRQSAETQTHRGPDSIGVFSADGVGLGHNRLSLVDLSERANQPFLDETGRYALVFNGEIYNFHELKAELEGEGQTFRTTSDTEVLLYLLLRQGAEVALPKLNGMFAFALVDLKTRQVTMARDRFGMKPLHYHATADRLIFASETAAFGPWMEMRPHAGTVAAYLMNFGGPTRGVTFFDGIYQLGPGEVMTAAPGQAPDIRPFFALTDFIDDGEYDRLLGRSETEIVDEFEALMTDSIRLHAFADARVGAFCSGGVDSSLIVALASRSNSAIELFHANVVGSWSEVEAARALARHLKLELNAVDVVEQDFVTSIPRVMRHYGYPFTYHPNCGPLMMIAGLARDTGVKGLLSGEGSDEMFLGYPWLGRKRITDAWDRMRDGLAGAVRRIPAVGTILLPEQQLNAMKVRNILNGREMLDDLNKVSDALSHSRAAARDPRMRWTLDYMHHHLRTLLHRNDTMGMAASIEARFPFLENRIAHFAVNLPARHKLKFSPFTLEKAHPFIRDKWVVREVADRYIPRDLSQRIKIGFWTTVFQRLDISERYFATSGLGDMLSLSRRQFSDLVAEASPTFRLKLLHLDVWQRICVERQEEAAPAALLSDHVRILTESEARSTRRAGAKKKGGAAQLPSAPV
Returned (tokenized[:10]): [21, 23, 7, 11, 19, 7, 8, 8, 9, 19]


RuntimeError: CUDA out of memory. Tried to allocate 1.27 GiB (GPU 0; 44.56 GiB total capacity; 41.71 GiB already allocated; 258.00 MiB free; 43.35 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## [Multi-Embedding attribution](https://captum.ai/tutorials/Bert_SQUAD_Interpret#Multi-Embedding-attribution)

Modify previous code:

In [23]:
def predict2(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    score = model(inputs, 
                  token_type_ids=token_type_ids,
                  position_ids=position_ids, 
                  attention_mask=attention_mask)
    return score[0]

In [24]:
def custom_forward2(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    preds = predict2(inputs,
                   token_type_ids=token_type_ids,
                   position_ids=position_ids,
                   attention_mask=attention_mask)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

In [25]:
def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

In [26]:
tokenized_sample, tokenized_sample_label = get_sample_and_label(dataset['test'], 200)
tokenized_sample = tokenized_sample[1:len(tokenized_sample)-1]
print(tokenized_sample)

Seq label: 0
Input seq (raw): ISFLSYLDVSPDHHMLPTGRAQNIMILSRTSSLVDDDDDNVNSTVIAIDKDKNSHYAVKWAVDNLLGRSTDHCIQLVHVRNQCLHPHDFDQAVSREGRPPNEPELQQLFLPYRGFCARKGIQAKEVILHDIDIPSALIDYIAHHSISNIVVGASHRNAITRKFRDADVPSSLFKSAPASCAIYVISKGKVQSTRPAGRSETSRQRSQKVVRHTAHPDTHDSDDTNRNSVVVGRWRSTGSDIFSLDRSSDSLHTPQSNFGSSSRTSSPTLSIDSYASASSSQRNSDSSEPFGFRPYDMYLDNLESSVALESSNSPGSSQTIKGIEAEKMRLRIELKHTMDTYNSVCKEAVVARQKAGELQQWKKFEEQHKLEEAKLAEEAALVLAEVERHKTKAALEAAKMQQRLVEMETQRRKNTEMQAKQEAEEKKRAMDTLANNNVVYRRYSMTEIEVATDHFNSALKIGEGGYGPVYKGVLDHTIVAIKILRPDLSQGQQQFQREIEVLSCIRHPNMVLLLGACPEYGCLVYEYMDNGSLEDRLFRKDDTPPIPWPTRFKIAAEIATGLRFLQTDPEPIVHRDLKPGNILLDKNYQSKISDVGLARLVPPSAADSVTQYHMTAAAGTFCYIDPEYQQTGELSVKSDIYSLGVVLLQIITARPPIGLAHQVGEAIEQETFSEMLDPTVTDWPIEEALSLANLALKCCEMRKRDRPDLGSVLLPELDRLRDLGSVYLSINNQMIANETRFPNSDPVIGSTVNEEEEHDILELDIQRRSV
Returned (tokenized[:10]): [11, 10, 19, 5, 10, 20, 5, 14, 8, 10]
[10, 19, 5, 10, 20, 5, 14, 8, 10, 16, 14, 22, 22, 21, 5, 16, 15, 7, 13, 6, 18, 17, 11, 21, 11, 5, 10, 13, 15, 10, 10, 5, 8, 14, 14, 14

In [27]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(tokenized_sample, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

indices = input_ids[0].detach().tolist()
all_tokens = tokenizer.convert_ids_to_tokens(indices)
print(indices)
print(all_tokens)

[2, 10, 19, 5, 10, 20, 5, 14, 8, 10, 16, 14, 22, 22, 21, 5, 16, 15, 7, 13, 6, 18, 17, 11, 21, 11, 5, 10, 13, 15, 10, 10, 5, 8, 14, 14, 14, 14, 14, 17, 8, 17, 10, 15, 8, 11, 6, 11, 14, 12, 14, 12, 17, 10, 22, 20, 6, 8, 12, 24, 6, 8, 14, 17, 5, 5, 7, 13, 10, 15, 14, 22, 23, 11, 18, 5, 8, 22, 8, 13, 17, 18, 23, 5, 22, 16, 22, 14, 19, 14, 18, 6, 8, 10, 13, 9, 7, 13, 16, 16, 17, 9, 16, 9, 5, 18, 18, 5, 19, 5, 16, 20, 13, 7, 19, 23, 6, 13, 12, 7, 11, 18, 6, 12, 9, 8, 11, 5, 22, 14, 11, 14, 11, 16, 10, 6, 5, 11, 14, 20, 11, 6, 22, 22, 10, 11, 10, 17, 11, 8, 8, 7, 6, 10, 22, 13, 17, 6, 11, 15, 13, 12, 19, 13, 14, 6, 14, 8, 16, 10, 10, 5, 19, 12, 10, 6, 16, 6, 10, 23, 6, 11, 20, 8, 11, 10, 12, 7, 12, 8, 18, 10, 15, 13, 16, 6, 7, 13, 10, 9, 15, 10, 13, 18, 13, 10, 18, 12, 8, 8, 13, 22, 15, 6, 22, 16, 14, 15, 22, 14, 10, 14, 14, 15, 17, 13, 17, 10, 8, 8, 8, 7, 13, 24, 13, 10, 15, 7, 10, 14, 11, 19, 10, 5, 14, 13, 10, 10, 14, 10, 5, 22, 15, 16, 18, 10, 17, 19, 7, 10, 10, 10, 13, 15, 10, 10, 16, 15

Have a look into the sub-embeddings of `BertEmbeddings` and try to understand the contributions and roles of the predicted positions.

To do so, we will use `LayerIntegratedGradients` for all three layers: `word_embeddings`, `token_type_embeddings` and `position_embeddings`.

Create an instance of `LayerIntegratedGradients` and compute the attributions with respect to all those embeddings and summarize them for each word token:

In [28]:
lig2 = LayerIntegratedGradients(custom_forward2,  [model.bert.embeddings.word_embeddings,
                                 model.bert.embeddings.token_type_embeddings,
                                 model.bert.embeddings.position_embeddings])
lig2



<captum.attr._core.layer.layer_integrated_gradients.LayerIntegratedGradients at 0x7f0f88e16af0>

In [29]:
attributions = lig2.attribute(inputs=(input_ids, token_type_ids, position_ids),
                                  baselines=(ref_input_ids, ref_token_type_ids, ref_position_ids),
                                  additional_forward_args=(attention_mask))
attributions

RuntimeError: CUDA out of memory. Tried to allocate 152.00 MiB (GPU 0; 44.56 GiB total capacity; 43.38 GiB already allocated; 66.00 MiB free; 43.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
attributions_word = summarize_attributions(attributions[0])
attributions_token_type = summarize_attributions(attributions[1])
attributions_position = summarize_attributions(attributions[2])

In [None]:
print(attributions_word)
print(attributions_token_type)
print(attributions_position)

Auxilary function to help us compute top K attributions and corresponding indices:

In [None]:
def get_top_k_attributed_tokens(attrs, k=5):
    values, indices = torch.topk(attrs, k)
    top_tokens = [all_tokens[idx] for idx in indices]
    return top_tokens, values, indices

Remove interpretation hooks from all layers after finishing attribution.

Compute top K attributions for all sub-embeddings and place them in a dataframe for better visualization:

In [None]:
import pandas as pd

top_words, top_words_val, top_word_ind = get_top_k_attributed_tokens(attributions_word)
top_token_type, top_token_type_val, top_token_type_ind = get_top_k_attributed_tokens(attributions_token_type)
top_pos, top_pos_val, pos_ind = get_top_k_attributed_tokens(attributions_position)

df = pd.DataFrame({'Word (Index), Attribution': ["{} (i={}), {}".format(word, pos, round(val.item(),2)) for word, pos, val in zip(top_words, top_word_ind, top_words_val)],
                   'Token Type (Index), Attribution': ["{} (i={}), {}".format(ttype, pos, round(val.item(),2)) for ttype, pos, val in zip(top_token_type, top_token_type_ind, top_words_val)],
                   'Position (Index), Attribution': ["{} (i={}), {}".format(position, pos, round(val.item(),2)) for position, pos, val in zip(top_pos, pos_ind, top_pos_val)]})

# df.style.apply(['cell_ids: False'])
# just prints out token and its index:
# ['{}({})'.format(token, str(i)) for i, token in enumerate(all_tokens)]

Top 5 attribution results from all three embedding types in predicting the class:

In [None]:
df

*Word embeddings help to focus more on the surrounding tokens of the predicted position.*

*...*

## [Interpreting Bert Layers](https://captum.ai/tutorials/Bert_SQUAD_Interpret#Interpreting-Bert-Layers)

Let's look into the layers of the network - into the distribution of attribution scores for each token across all layers in Bert model and dive deeper into specific tokens.

Use the `LayerConductance` layer attribution algorithm. *It is encouraged you to try out and compare the results with other algorithms as well.*

In [None]:
def construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                    token_type_ids=None, ref_token_type_ids=None, \
                                    position_ids=None, ref_position_ids=None):
    input_embeddings = model.bert.embeddings(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
    ref_input_embeddings = model.bert.embeddings(ref_input_ids, token_type_ids=ref_token_type_ids, position_ids=ref_position_ids)
    
    return input_embeddings, ref_input_embeddings

In [None]:
def predict3(input_emb, attention_mask=None):
    score = model(inputs_embeds=input_emb, attention_mask=attention_mask, )
    return score[0].max(1).values

Sample:

In [None]:
tokenized_sample = dataset['test'][1]['input_ids']
tokenized_sample = tokenized_sample[1:len(tokenized_sample)-1]
print(tokenized_sample)

In [None]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(tokenized_sample, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

Iterate over all layers and compute the attributions for all tokens. 

Additionally, choose a specific token that we would like to examine in detail, specified by an id `token_to_explain` and store related information in a separate array.

In [None]:
from captum.attr import LayerConductance

layer_attrs = []

# the token that we would like to examine separately:
token_to_explain = 126 # the index of the token that we would like to examine more thoroughly
layer_attrs_dist = []

input_embeddings, ref_input_embeddings = construct_whole_bert_embeddings(input_ids, ref_input_ids, \
                                         token_type_ids=token_type_ids, ref_token_type_ids=ref_token_type_ids, \
                                         position_ids=position_ids, ref_position_ids=ref_position_ids)

In [None]:
input_embeddings

In [None]:
ref_input_embeddings

In [None]:
for i in range(model.config.num_hidden_layers):
    lc = LayerConductance(predict3, model.bert.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, baselines=ref_input_embeddings, additional_forward_args=(attention_mask))[0]
    layer_attrs.append(summarize_attributions(layer_attributions).cpu().detach().tolist())

    # storing attributions of the token id that we would like to examine in more detail in token_to_explain
    layer_attrs_dist.append(layer_attributions[0,token_to_explain,:].cpu().detach().tolist())

Heat map of attributions across all layers and tokens: 

In [None]:
# BERT has 12 layers (https://www.analyticsvidhya.com/blog/2021/05/all-you-need-to-know-about-bert/)
model.config.num_hidden_layers

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

fig, ax = plt.subplots(figsize=(40,5))  # change 40 here if the input seq is shorter/ longer
xticklabels=all_tokens
yticklabels=list(range(1,13))  
ax = sns.heatmap(np.array(layer_attrs), xticklabels=xticklabels, yticklabels=yticklabels, linewidth=0.2, cmap='BuPu')  # cmap='viridis'
plt.xlabel('Tokens')
plt.ylabel('Layers')
plt.show()

We can observe high attribution score for token `AGAGGT` in the last 10 layers. Token in first positions is especially prominent in the last two layers.

Dig deeper into specific tokens and look into the distribution of attributions per layer for the chosen token. 

The box plot diagram below shows the presence of outliers especially in almost all layers.

In [None]:
fig, ax = plt.subplots(figsize=(15,15))
ax = sns.boxplot(data=layer_attrs_dist)
plt.xlabel('Layers')
plt.ylabel('Attribution')
plt.show()

In addition to that we can also look into the distribution of attributions in each layer for any input token. This will help us to better understand and compare the distributional patterns of attributions across multiple layers. We can for example represent attributions as a probability density function (`pdf`) and compute the entropy of it in order to estimate the entropy of attributions in each layer. This can be easily computed using a histogram.

In [None]:
def pdf_attr(attrs, bins=100):
    return np.histogram(attrs, bins=bins, density=True)[0]

In this particular case let's compute the `pdf` for the attributions of the chosen token. We can however do it for all tokens.

Compute and visualize the `pdf`s and entropies using Shannon's Entropy measure for each layer for the chosen token:

In [None]:
layer_attrs_pdf = map(lambda layer_attrs_dist: pdf_attr(layer_attrs_dist), layer_attrs_dist)
layer_attrs_pdf = np.array(list(layer_attrs_pdf))

# summing attribution along embedding diemension for each layer
# size: #layers
attr_sum = np.array(layer_attrs_dist).sum(-1)

# size: #layers
layer_attrs_pdf_norm = np.linalg.norm(layer_attrs_pdf, axis=-1, ord=1)

#size: #bins x #layers
layer_attrs_pdf = np.transpose(layer_attrs_pdf)

#size: #bins x #layers
layer_attrs_pdf = np.divide(layer_attrs_pdf, layer_attrs_pdf_norm, where=layer_attrs_pdf_norm!=0)

The plot below visualizes the probability mass function (`pmf`) of attributions for each layer for the position of the chosen token. From the plot we can observe that the distributions are taking bell-curved shapes with different means and variances. We can now use attribution `pdf`s to compute entropies in the next cell.

In [None]:
fig, ax = plt.subplots(figsize=(20,10))
plt.plot(layer_attrs_pdf)
plt.xlabel('Bins')
plt.ylabel('Density')
plt.legend(['Layer '+ str(i) for i in range(1,13)])
plt.show()

Calculate and visualize attribution entropies based on Shannon entropy measure where the X-axis corresponds to the number of layers and the Y-axis corresponds to the total attribution in that layer. The size of the circles for each (layer, total_attribution) pair correspond to the normalized entropy value at that point.

In this particular example, we observe that the entropy changes pretty from layer to layer. 

**In a general case entropy can provide us an intuition about the distributional characteristics of attributions in each layer and can be useful especially when comparing it across multiple tokens.**


In [None]:
fig, ax = plt.subplots(figsize=(20,10))

# replacing 0s with 1s. np.log(1) = 0 and np.log(0) = -inf
layer_attrs_pdf[layer_attrs_pdf == 0] = 1
layer_attrs_pdf_log = np.log2(layer_attrs_pdf)

# size: #layers
entropies= -(layer_attrs_pdf * layer_attrs_pdf_log).sum(0)

plt.scatter(np.arange(12), attr_sum, s=entropies * 100)
plt.xlabel('Layers')
plt.ylabel('Total Attribution')
plt.show()

## Attention matrices, their importance scores, and vector norms

Continuing according to this tutorial: https://captum.ai/tutorials/Bert_SQUAD_Interpret2 

As proposed in paper [Attention is Not Only a Weight: Analyzing Transformers with Vector Norms](https://arxiv.org/abs/2004.10102) we will compare attention matrices with their importance scores when we attribute them to a particular class, and vector norms.

It will be shown that the importance scores computed for the attention matrices and specific class are more meaningful than the attention matrices alone or different norm vectors computed for different input activations.

## [Visualizing Attention Matrices](https://captum.ai/tutorials/Bert_SQUAD_Interpret2#Visualizing-Attention-Matrices)

If we want to get the `attentions` fro mmodel as well, we need to pass "output_attentions=True" when initializing the model as well! 

https://huggingface.co/docs/transformers/main_classes/output

In [None]:
def predict4(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    score = model(inputs, 
                  token_type_ids=token_type_ids,
                  position_ids=position_ids, 
                  attention_mask=attention_mask)
    return score[0], score.attentions

`output_attentions` 

- represent attention matrices (attention probabilities) for all 12 layers and all 12 heads
- represents softmax-normalized dot-product between the key and query vectors
- has been used as an importance indicator of how much a token attends/ relates to another token in the text (https://www.aclweb.org/anthology/W19-4828.pdf)

*Examples:*

- *in translation it is a good indicator of how much a token in one language attends to the corresponding translation in another language*
- *in Question Answering model it indicates which tokens attend/ relate to each other in question, text or answer segment*

Since `output_attentions` contains the layers in a list, we will stack them in order to move everything into a tensor:attention_mask

In [None]:
scores, output_attentions = predict4(input_ids,
                                     token_type_ids=token_type_ids,
                                     position_ids=position_ids, 
                                     attention_mask=attention_mask)

In [None]:
scores

In [None]:
print(output_attentions is None)

In [None]:
# shape -> layer x batch x head x seq_len x seq_len
output_attentions_all = torch.stack(output_attentions)
# output_attentions_all

**Helper function for visualizing Token-To-Token matices:**

*(visualize token-to-token relation/ attention scores for all heads in a given layer or for all layers across all heads)*

In [None]:
# TODO: fix visual (too many tokens)

def visualize_token2token_scores(scores_mat, x_label_name='Head'):
    fig = plt.figure(figsize=(20, 20))

    for idx, scores in enumerate(scores_mat):
        scores_np = np.array(scores)
        ax = fig.add_subplot(4, 3, idx+1)
        # append the attention weights
        im = ax.imshow(scores, cmap='viridis')

        fontdict = {'fontsize': 10}

        ax.set_xticks(range(len(all_tokens)))
        ax.set_yticks(range(len(all_tokens)))

        ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(all_tokens, fontdict=fontdict)
        ax.set_xlabel('{} {}'.format(x_label_name, idx+1))

        fig.colorbar(im, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

**Helper function for visualizing Token-To-Head matrices:**

*(visualize the importance scores for tokens across all heads in all layers)*

In [None]:
def visualize_token2head_scores(scores_mat):
    fig = plt.figure(figsize=(30, 50))

    for idx, scores in enumerate(scores_mat):
        scores_np = np.array(scores)
        ax = fig.add_subplot(6, 2, idx+1)
        # append the attention weights
        im = ax.matshow(scores_np, cmap='viridis')

        fontdict = {'fontsize': 20}

        ax.set_xticks(range(len(all_tokens)))
        ax.set_yticks(range(len(scores)))

        ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(range(len(scores[0])), fontdict=fontdict)
        ax.set_xlabel('Layer {}'.format(idx+1))

        fig.colorbar(im, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

Examine a specific layer: Define a fixed layer id that will be used for visualization purposes. 

In [None]:
layer = 11

Visualize attention matrices for the selected layer `layer`:

In [None]:
visualize_token2token_scores(output_attentions_all[layer].squeeze().detach().cpu().numpy())

Based on the visualizations above we observe that there is a high attention set along the diagonals and on an uninformative token such as `[SEP]`. This is something that was observed in previous papers which indicates that attention matrices aren't always a good indicator of finding which tokens are more important or which token is related to which. We observe similar pattern when we examine another layer.

In the cell below we compute and visualize L2 norm across head axis for all 12 layer. This provides a summary for each layer across all heads.

Defining normalization function depending on pytorch version.

In [None]:
if torch.__version__ >= '1.7.0':
    norm_fn = torch.linalg.norm
else:
    norm_fn = torch.norm

In [None]:
visualize_token2token_scores(norm_fn(output_attentions_all, dim=2).squeeze().detach().cpu().numpy(),
                             x_label_name='Layer')

Based on the visualiziation above we can convince ourselves that attention scores aren't trustworthy measures of importances for token-to-token relations across all layers. We see strong signal along the diagonal. These signals, however, aren't true indicators of what semantic the model learns.

## [Visualizing attribution/ importance scores](https://captum.ai/tutorials/Bert_SQUAD_Interpret2#Visualizing-attribution-/-importance-scores)

In the cells below we visualize the attribution scores of attention matrices for the prediction and compare with the actual attention matrices. To do so, first of all, we compute the attribution scores using `LayerConductance` algorithm.

Helper function to summarize attributions for each word token in the sequence:

In [None]:
# difference from "summarize_attributions": torch.norm -> norm_fn

def summarize_attributions2(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions/ norm_fn(attributions)
    return attributions

## [Interpreting BertLayer Outputs and Self-Attention Matrices in each Layer](https://captum.ai/tutorials/Bert_SQUAD_Interpret2#Visualizing-Attention-Matrices)

Let's look into the layers of our network - into the distribution of attribution scores for each token across all layers and attribution matrices for each head in all layers in Bert model => Layer Conductance algorithm.

Sample:

In [None]:
tokenized_sample = dataset['test'][1]['input_ids']
tokenized_sample = tokenized_sample[1:len(tokenized_sample)-1]
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(tokenized_sample, ref_token_id, sep_token_id, cls_token_id)
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
attention_mask = construct_attention_mask(input_ids)

Configure the `InterpretableEmbeddingsBase` again, in this case in order to interpret the layers of our model:

In [None]:
from captum.attr import configure_interpretable_embedding_layer

interpretable_embedding = configure_interpretable_embedding_layer(model, 'bert.embeddings.word_embeddings')

In [None]:
def predict5(inputs, token_type_ids=None, position_ids=None, attention_mask=None):
    score = model(inputs_embeds=inputs, 
                  token_type_ids=token_type_ids,
                  position_ids=position_ids, 
                  attention_mask=attention_mask)
    return score[0].max(1).values

Iterate over all layers and compute the attributions w.r.t. all tokens in the input and attention matrices:

In [None]:
layer_attrs = []
layer_attn_mat = []

for i in range(model.config.num_hidden_layers):
    lc = LayerConductance(predict5, model.bert.encoder.layer[i])
    layer_attributions = lc.attribute(inputs=input_embeddings, baselines=ref_input_embeddings, additional_forward_args=(token_type_ids, position_ids, attention_mask))
    layer_attrs.append(summarize_attributions2(layer_attributions[0]))
    layer_attn_mat.append(layer_attributions[1])

In [None]:
# layer x seq_len
layer_attrs = torch.stack(layer_attrs)

# layer x batch x head x seq_len x seq_len
layer_attn_mat = torch.stack(layer_attn_mat)

## [Interpreting Attribution Scores for Attention Matrices](https://captum.ai/tutorials/Bert_SQUAD_Interpret2#Interpreting-Attribution-Scores-for-Attention-Matrices)

Visualize the attribution scores of position predictions w.r.t. attention matrices. Note that each layer has 12 heads, hence attention matrices. We will first visualize for a specific layer and head, later we will summarize across all heads in order to gain a bigger picture.

Below we visualize the attribution scores of 12 heads for selected layer layer for position prediction:

In [None]:
visualize_token2token_scores(layer_attn_mat[layer].squeeze().cpu().detach().numpy())

As we can see from the visualizations above, in contrary to attention scores the attributions of specific target w.r.t. to those scores are more meaningful and most importantly -> they do not show diagonal patterns. 

These observations are for a selected layer. We can change the index of selected layer and examine interesting relationships in other layers.

In the cell below we visualize the attention attribution scores normalized across the head axis.

In [None]:
visualize_token2token_scores(norm_fn(layer_attn_mat, dim=2).squeeze().detach().cpu().numpy(),
                             x_label_name='Layer')

By looking at the visualizations above we can see that the model pays attention to very specific handpicked relationships when making a sprediction for the position.

## [Computing and Visualizing Vector Norms](https://captum.ai/tutorials/Bert_SQUAD_Interpret2#Computing-and-Visualizing-Vector-Norms)

In this section of the tutorial we will compute Vector norms for activation layers such as `||f(x)||`, `||α * f(x)||` and `||Σαf(x)||` as also described in https://arxiv.org/pdf/2004.10102.pdf

As also shown in the paper mentioned above, normalized activations are better indicators of importance scores than the attention scores however they aren't as indicative as the attribution scores. This is because normalized activations `||f(x)||` and `||α * f(x)||` aren't attributed to a specific output prediction. 

Below we define/ extract all parameters that we need to computation vector norms:

In [None]:
output_attentions_all_shape = output_attentions_all.shape

batch = output_attentions_all_shape[1]
num_heads = output_attentions_all_shape[2]
head_size = 64
all_head_size = 768

In order to compute above mentioned norms we need to get access to dense layer's weights and value vector of the self attention layer.

Getting Access to Value Activations: define the list of all layers for which we would like to access Value Activations

In [None]:
layers = [model.bert.encoder.layer[layer].attention.self.value for layer in range(len(model.bert.encoder.layer))]

Use Captum's `LayerActivation` algorithm to access the outputs of all layers:

*(Perform several transformations with the value layer activations and bring it to the shape so that we can compute different norms. The transformations are done the same way as it is described in the original paper and corresponding github implementation.)*

In [None]:
from captum.attr import LayerActivation

la = LayerActivation(predict5, layers)

value_layer_acts = la.attribute(input_embeddings, additional_forward_args=(token_type_ids, position_ids, attention_mask))
# shape -> layer x batch x seq_len x all_head_size
value_layer_acts = torch.stack(value_layer_acts)

In [None]:
new_x_shape = value_layer_acts.size()[:-1] + (num_heads, head_size)
value_layer_acts = value_layer_acts.view(*new_x_shape)

# layer x batch x neum_heads x 1 x head_size
value_layer_acts = value_layer_acts.permute(0, 1, 3, 2, 4)

value_layer_acts = value_layer_acts.permute(0, 1, 3, 2, 4).contiguous()
value_layer_acts_shape = value_layer_acts.size()

# layer x batch x seq_length x num_heads x 1 x head_size
value_layer_acts = value_layer_acts.view(value_layer_acts_shape[:-1] + (1, value_layer_acts_shape[-1],))

print('value_layer_acts: ', value_layer_acts.shape)

Getting Access to Dense Features: transform dense features so that we can use them to compute `||f(x)||` and `||α * f(x)||`

In [None]:
dense_acts = torch.stack([dlayer.attention.output.dense.weight for dlayer in model.bert.encoder.layer])

dense_acts = dense_acts.view(len(layers), all_head_size, num_heads, head_size)

# layer x num_heads x head_size x all_head_size
dense_acts = dense_acts.permute(0, 2, 3, 1).contiguous()

Compute `f(x)` score by multiplying the value vector with the weights of the dense vector for all layers:

In [None]:
# layers, batch, seq_length, num_heads, 1, all_head_size
f_x = torch.stack([value_layer_acts_i.matmul(dense_acts_i) for value_layer_acts_i, dense_acts_i in zip(value_layer_acts, dense_acts)])
f_x.shape

In [None]:
# layer x batch x seq_length x num_heads x 1 x all_head_size)
f_x_shape = f_x.size() 
f_x = f_x.view(f_x_shape[:-2] + (f_x_shape[-1],))
f_x = f_x.permute(0, 1, 3, 2, 4).contiguous() 

#(layers x batch, num_heads, seq_length, all_head_size)
f_x_shape = f_x.size() 

#(layers x batch, num_heads, seq_length)
f_x_norm = norm_fn(f_x, dim=-1)

Visualize `||f(x)||` scores for all layers and examine the distribution of those scores:

In [None]:
def visualize_token2head_scores(scores_mat):
    fig = plt.figure(figsize=(30, 50))

    for idx, scores in enumerate(scores_mat):
        scores_np = np.array(scores)
        ax = fig.add_subplot(6, 2, idx+1)
        # append the attention weights
        im = ax.matshow(scores_np, cmap='viridis')

        fontdict = {'fontsize': 20}

        ax.set_xticks(range(len(all_tokens)))
        ax.set_yticks(range(len(scores)))

        ax.set_xticklabels(all_tokens, fontdict=fontdict, rotation=90)
        ax.set_yticklabels(range(len(scores[0])), fontdict=fontdict)
        ax.set_xlabel('Layer {}'.format(idx+1))

        fig.colorbar(im, fraction=0.046, pad=0.04)
    plt.tight_layout()
    plt.show()

In [None]:
visualize_token2head_scores(f_x_norm.squeeze().detach().cpu().numpy())