# Interpretation of [BertForSequenceClassification](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertForSequenceClassification) using [**Captum**](https://captum.ai/)

Source code info:

Used notebook: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 (adaptation to classification from original tutorial on question answering: https://captum.ai/tutorials/Bert_SQUAD_Interpret)

(Used notebook is based on this github issue: https://github.com/pytorch/captum/issues/303)

Related github issue: https://github.com/pytorch/captum/issues/249

---

Used model: [armheb/DNA_bert_6](https://huggingface.co/armheb/DNA_bert_6?text=The+goal+of+life+is+%5BMASK%5D.)



## Load initial libraries, models, data:

In [1]:
!pip install transformers datasets tokenizers --quiet

In [2]:
!pip install captum



In [3]:
from transformers import BertTokenizer, BertForSequenceClassification, BertConfig

from captum.attr import visualization as viz
from captum.attr import IntegratedGradients, LayerConductance, LayerIntegratedGradients
from captum.attr import configure_interpretable_embedding_layer, remove_interpretable_embedding_layer

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

### Load tokenizer:

In [4]:
tokenizer = BertTokenizer.from_pretrained('armheb/DNA_bert_6')
tokenizer

PreTrainedTokenizer(name_or_path='armheb/DNA_bert_6', vocab_size=4101, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### Load model:

In [5]:
model = BertForSequenceClassification.from_pretrained('armheb/DNA_bert_6')
model.to(device)
model.eval()
model.zero_grad()
model

Some weights of the model checkpoint at armheb/DNA_bert_6 were not used when initializing BertForSequenceClassification: ['cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at armheb/DNA_bert_6 and are n

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(4101, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementw

## Get training data:

In [6]:
from datasets import load_dataset

DATASET_NAME = "simecek/human_nontata_promoters"

# take a small portion of the dataset for time purposes
# the fist and the last 500 samples because this specific dataset is ordered (positive, negative samples)
dataset_train = load_dataset(DATASET_NAME, split='train[:500]+train[-500:]')
dataset_train

Using custom data configuration simecek--human_nontata_promoters-2176576c12d02035
Reusing dataset parquet (/home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--human_nontata_promoters-2176576c12d02035/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Dataset({
    features: ['labels', 'seq'],
    num_rows: 1000
})

In [7]:
dataset_test = load_dataset(DATASET_NAME, split='test[:1000]+test[-1000:]')
dataset_test

Using custom data configuration simecek--human_nontata_promoters-2176576c12d02035
Reusing dataset parquet (/home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--human_nontata_promoters-2176576c12d02035/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


Dataset({
    features: ['labels', 'seq'],
    num_rows: 2000
})

In [8]:
# one training sample sequence and its label
dataset_train[0]

{'labels': 0,
 'seq': 'ACAGATTCAGGATGTCCTGTCGGGGCATGGACCCTGGAAAGCTGCGGACACCAGGAGGGCAGGCAAGAGAGTCTCATCTCTTGCTCCCTAGGAGCTATGAGTTGAGGGCGCCGTCTGAGCAGGAGGGACGGACGGGTGCCCAGGGTTTGAGGAAAGAGGGGTGTGGGAAGGACGCATGCTAGAACTTCAGAGCAGTTCAGCAGGTGCAGAATGGGAGTTATCATGGGGACTGTGGGAGAAGGGGCGGTGGG'}

In [9]:
dataset_test[0]

{'labels': 0,
 'seq': 'CAACAGTACAGGATACTGAACAAAACCTGACAGCCTGTCTCAGTCAGGGTTTAGCAGAAGAAACAGGGACCATTCTGGGTGTTGTAAACAGGAGAAGGTGTTTAACAAAGGGTGCTTAAGAAGCCGCAGGAGCAAGCTTCAGGCAGAGTCCCAGACTATAGTGTTAGTCTCCAGTGGCTGCAGCCAGAGGCCAAGAAGCTCCTGCTGTGTCCTGACATCCAGGAAGCTGGAGAGCGGGTGGTAGGCTACCA'}

### Custom K-mer tokenization:

In [10]:
def kmers(s, k=6):
  return [s[i:i + k] for i in range(0, len(s)-k+1)]

def tokenization(x): 
  return tokenizer(" ".join(kmers(x["seq"])))

example = {'seq': 'ATGGAAAGAGGCACCATTCT'}
print(f'Example: {example}')

example_kmers = " ".join(kmers(example['seq']))
print(f'Example_kmers: {example_kmers}')

tokenized_example = tokenization(example)
print(f'Tokenization example: {tokenized_example}')

decoded_example = tokenizer.decode(tokenized_example['input_ids'])
print(f'Decoded tokenized example: {decoded_example}')

Example: {'seq': 'ATGGAAAGAGGCACCATTCT'}
Example_kmers: ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT
Tokenization example: {'input_ids': [2, 501, 1989, 3848, 3089, 56, 212, 835, 3325, 999, 3983, 3629, 2214, 650, 2587, 2142, 3], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Decoded tokenized example: [CLS] ATGGAA TGGAAA GGAAAG GAAAGA AAAGAG AAGAGG AGAGGC GAGGCA AGGCAC GGCACC GCACCA CACCAT ACCATT CCATTC CATTCT [SEP]


In [11]:
dataset_train_tokenized = dataset_train.map(tokenization, batched=False)
dataset_test_tokenized = dataset_test.map(tokenization, batched=False)
dataset_train_tokenized

Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--human_nontata_promoters-2176576c12d02035/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-d88a00571a062a99.arrow
Loading cached processed dataset at /home/jovyan/.cache/huggingface/datasets/simecek___parquet/simecek--human_nontata_promoters-2176576c12d02035/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec/cache-21554c77d7445ea5.arrow


Dataset({
    features: ['labels', 'seq', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 1000
})

In [12]:
print(dataset_train_tokenized[0]['input_ids'])

[2, 566, 2250, 795, 3165, 360, 1428, 1601, 2294, 972, 3874, 3195, 479, 1902, 3500, 1698, 2683, 2528, 1908, 3524, 1796, 3075, 4093, 4070, 3980, 3620, 2177, 503, 1999, 3887, 3246, 684, 2724, 2689, 2549, 1989, 3848, 3091, 62, 236, 931, 3712, 2548, 1985, 3831, 3021, 3879, 3215, 557, 2216, 660, 2625, 2296, 980, 3908, 3331, 1021, 4072, 3988, 3651, 2301, 997, 3976, 3601, 2104, 209, 824, 3282, 827, 3294, 875, 3485, 1638, 2443, 1566, 2155, 414, 1642, 2460, 1635, 2430, 1515, 1951, 3695, 2478, 1705, 2712, 2644, 2369, 1272, 979, 3902, 3305, 918, 3660, 2337, 1144, 466, 1850, 3292, 865, 3448, 1492, 1860, 3331, 1024, 4083, 4031, 3824, 2994, 3771, 2782, 2924, 3489, 1656, 2515, 1853, 3304, 916, 3649, 2296, 980, 3908, 3329, 1015, 4048, 3892, 3265, 759, 3024, 3892, 3268, 770, 3068, 4067, 3967, 3567, 1965, 3752, 2708, 2628, 2306, 1018, 4058, 3932, 3425, 1400, 1492, 1857, 3317, 965, 3848, 3089, 56, 212, 836, 3332, 1026, 4092, 4066, 3964, 3556, 1924, 3585, 2037, 4040, 3860, 3137, 247, 976, 3891, 3261, 742, 

## Captum interpretation:

### 1. On untrained model:

*Helper function to perform forward pass of the model and make predictions:*

In [13]:
def predict(inputs):
    return model(inputs)[0]

*Custom forward function that will allow us to access the postitions of our prediction using position input argument:*

In [14]:
### original for question answering looked like this: ######
# def squad_pos_forward_func(inputs, token_type_ids=None, position_ids=None, attention_mask=None, position=0):
#     pred = predict(inputs,
#                    token_type_ids=token_type_ids,
#                    position_ids=position_ids,
#                    attention_mask=attention_mask)
#     pred = pred[position]
#     return pred.max(1).values
############################################################

def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

Compute attributions with respect to the `BertEmbeddings` layer:

1. define baselines/ references,
2. numericalize baselines and inputs.

*(helper functions to achieve that)*

In [15]:
# Token used for generating token reference:
ref_token_id = tokenizer.pad_token_id
ref_token_id

0

In [16]:
# Token added to the end of the input text:
sep_token_id = tokenizer.sep_token_id
sep_token_id

3

In [17]:
# Token used at the beginning of the input text:
cls_token_id = tokenizer.cls_token_id
cls_token_id

2

In [18]:
def construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id):

    text_ids = tokenizer.encode(text, add_special_tokens=False)
    # construct input token ids
    input_ids = [cls_token_id] + text_ids + [sep_token_id]
    # construct reference token ids 
    ref_input_ids = [cls_token_id] + [ref_token_id] * len(text_ids) + [sep_token_id]

    return torch.tensor([input_ids], device=device), torch.tensor([ref_input_ids], device=device), len(text_ids)

def construct_input_ref_token_type_pair(input_ids, sep_ind=0):
    seq_len = input_ids.size(1)
    token_type_ids = torch.tensor([[0 if i <= sep_ind else 1 for i in range(seq_len)]], device=device)
    ref_token_type_ids = torch.zeros_like(token_type_ids, device=device)# * -1
    return token_type_ids, ref_token_type_ids

def construct_input_ref_pos_id_pair(input_ids):
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
    # we could potentially also use random permutation with `torch.randperm(seq_length, device=device)`
    ref_position_ids = torch.zeros(seq_length, dtype=torch.long, device=device)

    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    ref_position_ids = ref_position_ids.unsqueeze(0).expand_as(input_ids)
    return position_ids, ref_position_ids
    
def construct_attention_mask(input_ids):
    return torch.ones_like(input_ids)

Define the input sequence `text` that we'd like to use as an input for our Bert model and interpret what the model was focusing on when predicting the class:

In [19]:
seq = "CAACAGTACAGGATACTGAACAAAACCTGACAGCCTGTCTCAGTCAGGGTTTAGCAGAAGAAACAGGGACCATTCTGGGTGTTGTAAACAGGAGAAGGTGTTTAACAAAGGGTGCTTAAGAAGCCGCAGGAGCAAGCTTCAGGCAGAGTCCCAGACTATAGTGTTAGTCTCCAGTGGCTGCAGCCAGAGGCCAAGAAGCTCCTGCTGTGTCCTGACATCCAGGAAGCTGGAGAGCGGGTGGTAGGCTACCA"
seq_kmers = kmers(seq)
text = tokenizer.encode(seq_kmers, add_special_tokens=False)
text

[2088,
 146,
 569,
 2263,
 845,
 3368,
 1172,
 577,
 2294,
 969,
 3863,
 3150,
 300,
 1185,
 629,
 2503,
 1805,
 3109,
 133,
 517,
 2055,
 15,
 46,
 172,
 673,
 2679,
 2509,
 1832,
 3219,
 575,
 2286,
 940,
 3746,
 2683,
 2526,
 1899,
 3485,
 1640,
 2450,
 1595,
 2269,
 872,
 3476,
 1604,
 2306,
 1018,
 4058,
 3929,
 3416,
 1363,
 1341,
 1256,
 913,
 3637,
 2248,
 785,
 3125,
 197,
 775,
 3085,
 40,
 148,
 580,
 2305,
 1015,
 4047,
 3885,
 3238,
 650,
 2587,
 2142,
 364,
 1444,
 1668,
 2562,
 2044,
 4066,
 3962,
 3548,
 1890,
 3449,
 1493,
 1861,
 3335,
 1037,
 40,
 148,
 577,
 2296,
 977,
 3893,
 3272,
 788,
 3138,
 252,
 994,
 3962,
 3546,
 1881,
 3413,
 1351,
 1293,
 1061,
 133,
 520,
 2068,
 68,
 258,
 1020,
 4067,
 3966,
 3562,
 1945,
 3669,
 2376,
 1297,
 1077,
 200,
 787,
 3135,
 240,
 947,
 3773,
 2792,
 2964,
 3649,
 2296,
 979,
 3901,
 3301,
 904,
 3603,
 2110,
 234,
 923,
 3677,
 2408,
 1428,
 1603,
 2301,
 1000,
 3985,
 3640,
 2258,
 827,
 3295,
 879,
 3501,
 1704,
 2705,
 

Let's numericalize the input `text` and generate corresponding baselines/references for all three sub-embeddings (word, token type and position embeddings) types using our helper functions defined above:

In [20]:
input_ids, ref_input_ids, sep_id = construct_input_ref_pair(text, ref_token_id, sep_token_id, cls_token_id)
print(input_ids)
print(ref_input_ids)
print(sep_id)

tensor([[   2, 2088,  146,  569, 2263,  845, 3368, 1172,  577, 2294,  969, 3863,
         3150,  300, 1185,  629, 2503, 1805, 3109,  133,  517, 2055,   15,   46,
          172,  673, 2679, 2509, 1832, 3219,  575, 2286,  940, 3746, 2683, 2526,
         1899, 3485, 1640, 2450, 1595, 2269,  872, 3476, 1604, 2306, 1018, 4058,
         3929, 3416, 1363, 1341, 1256,  913, 3637, 2248,  785, 3125,  197,  775,
         3085,   40,  148,  580, 2305, 1015, 4047, 3885, 3238,  650, 2587, 2142,
          364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861,
         3335, 1037,   40,  148,  577, 2296,  977, 3893, 3272,  788, 3138,  252,
          994, 3962, 3546, 1881, 3413, 1351, 1293, 1061,  133,  520, 2068,   68,
          258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077,  200,  787,
         3135,  240,  947, 3773, 2792, 2964, 3649, 2296,  979, 3901, 3301,  904,
         3603, 2110,  234,  923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640,
         2258,  827, 3295,  

In [21]:
token_type_ids, ref_token_type_ids = construct_input_ref_token_type_pair(input_ids, sep_id)
print(token_type_ids)
print(ref_token_type_ids)

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 1]], device='cuda:0')
tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [22]:
position_ids, ref_position_ids = construct_input_ref_pos_id_pair(input_ids)
print(position_ids)
print(ref_position_ids)

tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
          14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
          28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
          42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
          56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
          70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
          84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
          98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
         112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
         126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
         140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
         154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
         168, 169, 170, 171, 172, 173, 174, 175, 176

In [23]:
attention_mask = construct_attention_mask(input_ids)
print(attention_mask)

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')


In [24]:
indices = input_ids[0].detach().tolist()
print(indices)

[2, 2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3582, 2028, 400

In [25]:
all_tokens = tokenizer.convert_ids_to_tokens(indices)
print(all_tokens)

['[CLS]', 'CAACAG', 'AACAGT', 'ACAGTA', 'CAGTAC', 'AGTACA', 'GTACAG', 'TACAGG', 'ACAGGA', 'CAGGAT', 'AGGATA', 'GGATAC', 'GATACT', 'ATACTG', 'TACTGA', 'ACTGAA', 'CTGAAC', 'TGAACA', 'GAACAA', 'AACAAA', 'ACAAAA', 'CAAAAC', 'AAAACC', 'AAACCT', 'AACCTG', 'ACCTGA', 'CCTGAC', 'CTGACA', 'TGACAG', 'GACAGC', 'ACAGCC', 'CAGCCT', 'AGCCTG', 'GCCTGT', 'CCTGTC', 'CTGTCT', 'TGTCTC', 'GTCTCA', 'TCTCAG', 'CTCAGT', 'TCAGTC', 'CAGTCA', 'AGTCAG', 'GTCAGG', 'TCAGGG', 'CAGGGT', 'AGGGTT', 'GGGTTT', 'GGTTTA', 'GTTTAG', 'TTTAGC', 'TTAGCA', 'TAGCAG', 'AGCAGA', 'GCAGAA', 'CAGAAG', 'AGAAGA', 'GAAGAA', 'AAGAAA', 'AGAAAC', 'GAAACA', 'AAACAG', 'AACAGG', 'ACAGGG', 'CAGGGA', 'AGGGAC', 'GGGACC', 'GGACCA', 'GACCAT', 'ACCATT', 'CCATTC', 'CATTCT', 'ATTCTG', 'TTCTGG', 'TCTGGG', 'CTGGGT', 'TGGGTG', 'GGGTGT', 'GGTGTT', 'GTGTTG', 'TGTTGT', 'GTTGTA', 'TTGTAA', 'TGTAAA', 'GTAAAC', 'TAAACA', 'AAACAG', 'AACAGG', 'ACAGGA', 'CAGGAG', 'AGGAGA', 'GGAGAA', 'GAGAAG', 'AGAAGG', 'GAAGGT', 'AAGGTG', 'AGGTGT', 'GGTGTT', 'GTGTTT', 'TGTTTA', 

In [26]:
model(input_ids)

SequenceClassifierOutput(loss=None, logits=tensor([[ 0.1673, -0.0655]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [27]:
predict(input_ids)

tensor([[ 0.1673, -0.0655]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [28]:
custom_forward(input_ids)

tensor([0.5579], device='cuda:0', grad_fn=<UnsqueezeBackward0>)

There are 2 different ways of computing the attributions for emebdding layers. One option is to use `LayerIntegratedGradients` and compute the attributions with respect to `BertEmbedding`. The second option is to use `LayerIntegratedGradients` for each `word_embeddings`, `token_type_embeddings` and `position_embeddings` and compute the attributions w.r.t each embedding vector.

In [29]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

In [30]:
attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)

In [31]:
score = predict(input_ids)

print('Input sequence: ', text)
# cuda error: https://stackoverflow.com/questions/53900910/typeerror-can-t-convert-cuda-tensor-to-numpy-use-tensor-cpu-to-copy-the-tens
# updated torch.argmax(score[0]).numpy() -> torch.argmax(score[0]).cpu().data.numpy()
print('Predicted label: ' + str(torch.argmax(score[0]).cpu().data.numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].detach().cpu().data.numpy()))

Input sequence:  [2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3

Helper function to summarize attributions for each word token in the sequence:

In [32]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions / torch.norm(attributions)
    return attributions

In [33]:
attributions_sum = summarize_attributions(attributions)

In [34]:
# storing couple samples in an array for visualization purposes
score_vis = viz.VisualizationDataRecord(
                        attributions_sum,
                        torch.softmax(score, dim = 1)[0][0],
                        torch.argmax(torch.softmax(score, dim = 1)[0]),
                        0,
                        text,
                        attributions_sum.sum(),       
                        all_tokens,
                        delta)

print('\033[1m', 'Visualization For Score', '\033[0m')
viz.visualize_text([score_vis])

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.56),"[2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3582, 2028, 4003, 3709, 2536, 1939, 3647, 2285, 936, 3729, 2616, 2260, 835, 3327, 1005, 4005, 3720, 2577, 2101, 200, 787, 3134, 235, 927, 3694, 2476, 1699, 2686, 2540, 1954, 3708, 2530, 1915, 3551, 1902, 3500, 1697, 2679, 2509, 1830, 3211, 543, 2157, 424, 1684, 2625, 2293, 968, 3859, 3134, 236, 932, 3713, 2552, 2001, 3896, 3283, 832, 3316, 964, 3842, 3068, 4068, 3970, 3577, 2008, 3924, 3395, 1278, 1001, 3991, 3663, 2349]",-4.9,[CLS] CAACAG AACAGT ACAGTA CAGTAC AGTACA GTACAG TACAGG ACAGGA CAGGAT AGGATA GGATAC GATACT ATACTG TACTGA ACTGAA CTGAAC TGAACA GAACAA AACAAA ACAAAA CAAAAC AAAACC AAACCT AACCTG ACCTGA CCTGAC CTGACA TGACAG GACAGC ACAGCC CAGCCT AGCCTG GCCTGT CCTGTC CTGTCT TGTCTC GTCTCA TCTCAG CTCAGT TCAGTC CAGTCA AGTCAG GTCAGG TCAGGG CAGGGT AGGGTT GGGTTT GGTTTA GTTTAG TTTAGC TTAGCA TAGCAG AGCAGA GCAGAA CAGAAG AGAAGA GAAGAA AAGAAA AGAAAC GAAACA AAACAG AACAGG ACAGGG CAGGGA AGGGAC GGGACC GGACCA GACCAT ACCATT CCATTC CATTCT ATTCTG TTCTGG TCTGGG CTGGGT TGGGTG GGGTGT GGTGTT GTGTTG TGTTGT GTTGTA TTGTAA TGTAAA GTAAAC TAAACA AAACAG AACAGG ACAGGA CAGGAG AGGAGA GGAGAA GAGAAG AGAAGG GAAGGT AAGGTG AGGTGT GGTGTT GTGTTT TGTTTA GTTTAA TTTAAC TTAACA TAACAA AACAAA ACAAAG CAAAGG AAAGGG AAGGGT AGGGTG GGGTGC GGTGCT GTGCTT TGCTTA GCTTAA CTTAAG TTAAGA TAAGAA AAGAAG AGAAGC GAAGCC AAGCCG AGCCGC GCCGCA CCGCAG CGCAGG GCAGGA CAGGAG AGGAGC GGAGCA GAGCAA AGCAAG GCAAGC CAAGCT AAGCTT AGCTTC GCTTCA CTTCAG TTCAGG TCAGGC CAGGCA AGGCAG GGCAGA GCAGAG CAGAGT AGAGTC GAGTCC AGTCCC GTCCCA TCCCAG CCCAGA CCAGAC CAGACT AGACTA GACTAT ACTATA CTATAG TATAGT ATAGTG TAGTGT AGTGTT GTGTTA TGTTAG GTTAGT TTAGTC TAGTCT AGTCTC GTCTCC TCTCCA CTCCAG TCCAGT CCAGTG CAGTGG AGTGGC GTGGCT TGGCTG GGCTGC GCTGCA CTGCAG TGCAGC GCAGCC CAGCCA AGCCAG GCCAGA CCAGAG CAGAGG AGAGGC GAGGCC AGGCCA GGCCAA GCCAAG CCAAGA CAAGAA AAGAAG AGAAGC GAAGCT AAGCTC AGCTCC GCTCCT CTCCTG TCCTGC CCTGCT CTGCTG TGCTGT GCTGTG CTGTGT TGTGTC GTGTCC TGTCCT GTCCTG TCCTGA CCTGAC CTGACA TGACAT GACATC ACATCC CATCCA ATCCAG TCCAGG CCAGGA CAGGAA AGGAAG GGAAGC GAAGCT AAGCTG AGCTGG GCTGGA CTGGAG TGGAGA GGAGAG GAGAGC AGAGCG GAGCGG AGCGGG GCGGGT CGGGTG GGGTGG GGTGGT GTGGTA TGGTAG GGTAGG GTAGGC TAGGCT AGGCTA GGCTAC GCTACC CTACCA [SEP]
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.56),"[2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3582, 2028, 4003, 3709, 2536, 1939, 3647, 2285, 936, 3729, 2616, 2260, 835, 3327, 1005, 4005, 3720, 2577, 2101, 200, 787, 3134, 235, 927, 3694, 2476, 1699, 2686, 2540, 1954, 3708, 2530, 1915, 3551, 1902, 3500, 1697, 2679, 2509, 1830, 3211, 543, 2157, 424, 1684, 2625, 2293, 968, 3859, 3134, 236, 932, 3713, 2552, 2001, 3896, 3283, 832, 3316, 964, 3842, 3068, 4068, 3970, 3577, 2008, 3924, 3395, 1278, 1001, 3991, 3663, 2349]",-4.9,[CLS] CAACAG AACAGT ACAGTA CAGTAC AGTACA GTACAG TACAGG ACAGGA CAGGAT AGGATA GGATAC GATACT ATACTG TACTGA ACTGAA CTGAAC TGAACA GAACAA AACAAA ACAAAA CAAAAC AAAACC AAACCT AACCTG ACCTGA CCTGAC CTGACA TGACAG GACAGC ACAGCC CAGCCT AGCCTG GCCTGT CCTGTC CTGTCT TGTCTC GTCTCA TCTCAG CTCAGT TCAGTC CAGTCA AGTCAG GTCAGG TCAGGG CAGGGT AGGGTT GGGTTT GGTTTA GTTTAG TTTAGC TTAGCA TAGCAG AGCAGA GCAGAA CAGAAG AGAAGA GAAGAA AAGAAA AGAAAC GAAACA AAACAG AACAGG ACAGGG CAGGGA AGGGAC GGGACC GGACCA GACCAT ACCATT CCATTC CATTCT ATTCTG TTCTGG TCTGGG CTGGGT TGGGTG GGGTGT GGTGTT GTGTTG TGTTGT GTTGTA TTGTAA TGTAAA GTAAAC TAAACA AAACAG AACAGG ACAGGA CAGGAG AGGAGA GGAGAA GAGAAG AGAAGG GAAGGT AAGGTG AGGTGT GGTGTT GTGTTT TGTTTA GTTTAA TTTAAC TTAACA TAACAA AACAAA ACAAAG CAAAGG AAAGGG AAGGGT AGGGTG GGGTGC GGTGCT GTGCTT TGCTTA GCTTAA CTTAAG TTAAGA TAAGAA AAGAAG AGAAGC GAAGCC AAGCCG AGCCGC GCCGCA CCGCAG CGCAGG GCAGGA CAGGAG AGGAGC GGAGCA GAGCAA AGCAAG GCAAGC CAAGCT AAGCTT AGCTTC GCTTCA CTTCAG TTCAGG TCAGGC CAGGCA AGGCAG GGCAGA GCAGAG CAGAGT AGAGTC GAGTCC AGTCCC GTCCCA TCCCAG CCCAGA CCAGAC CAGACT AGACTA GACTAT ACTATA CTATAG TATAGT ATAGTG TAGTGT AGTGTT GTGTTA TGTTAG GTTAGT TTAGTC TAGTCT AGTCTC GTCTCC TCTCCA CTCCAG TCCAGT CCAGTG CAGTGG AGTGGC GTGGCT TGGCTG GGCTGC GCTGCA CTGCAG TGCAGC GCAGCC CAGCCA AGCCAG GCCAGA CCAGAG CAGAGG AGAGGC GAGGCC AGGCCA GGCCAA GCCAAG CCAAGA CAAGAA AAGAAG AGAAGC GAAGCT AAGCTC AGCTCC GCTCCT CTCCTG TCCTGC CCTGCT CTGCTG TGCTGT GCTGTG CTGTGT TGTGTC GTGTCC TGTCCT GTCCTG TCCTGA CCTGAC CTGACA TGACAT GACATC ACATCC CATCCA ATCCAG TCCAGG CCAGGA CAGGAA AGGAAG GGAAGC GAAGCT AAGCTG AGCTGG GCTGGA CTGGAG TGGAGA GGAGAG GAGAGC AGAGCG GAGCGG AGCGGG GCGGGT CGGGTG GGGTGG GGTGGT GTGGTA TGGTAG GGTAGG GTAGGC TAGGCT AGGCTA GGCTAC GCTACC CTACCA [SEP]
,,,,


### 2. On trained model:

#### Train the model:

In [35]:
from transformers import TrainingArguments, Trainer
from datasets import load_metric
import numpy as np

BATCH_SIZE = 32
LEARNING_RATE = 1e-5
EPOCHS = 10

training_arguments = TrainingArguments(
        output_dir='outputs', 
        learning_rate=LEARNING_RATE, 
        num_train_epochs=EPOCHS, 
        evaluation_strategy="epoch", 
        logging_strategy='epoch',
        per_device_train_batch_size=BATCH_SIZE, 
        per_device_eval_batch_size=BATCH_SIZE,
        fp16=True,
    )

def compute_metrics(eval_preds):
    metric = load_metric("accuracy")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=dataset_train_tokenized,
    eval_dataset=dataset_test_tokenized,
    compute_metrics=compute_metrics,
)
trainer

Using cuda_amp half precision backend


<transformers.trainer.Trainer at 0x7f9eb5ea7580>

In [36]:
trainer.train()

The following columns in the training set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: seq. If seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 1000
  Num Epochs = 10
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 1
  Total optimization steps = 320
COMET INFO: Experiment is live on comet.com https://www.comet.com/roa7n/huggingface/b17ae8ebd74a4025826e931f8c1b57e0

Automatic Comet.ml online logging enabled


Epoch,Training Loss,Validation Loss,Accuracy
1,0.645,0.583451,0.7405
2,0.5162,0.486638,0.7875
3,0.4562,0.450197,0.8035
4,0.4316,0.434721,0.812
5,0.408,0.423467,0.82
6,0.3966,0.413795,0.829
7,0.3812,0.410467,0.83
8,0.38,0.406175,0.831
9,0.3717,0.403953,0.8325
10,0.3713,0.403301,0.833


The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: seq. If seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 32
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: seq. If seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 32
The following columns in the evaluation set don't have a corresponding argument in `BertForSequenceClassification.forward` and have been ignored: seq. If seq are not expected by `BertForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 2000
  Batch size = 32
The following colum

TrainOutput(global_step=320, training_loss=0.43578380048274995, metrics={'train_runtime': 99.162, 'train_samples_per_second': 100.845, 'train_steps_per_second': 3.227, 'total_flos': 1274444174400000.0, 'train_loss': 0.43578380048274995, 'epoch': 10.0})

### Interpret again

In [37]:
score = predict(input_ids)

print('Input sequence: ', text)
print('Predicted label: ' + str(torch.argmax(score[0]).cpu().data.numpy()) + ', prob ungrammatical: ' + str(torch.softmax(score, dim = 1)[0][0].detach().cpu().data.numpy()))

Input sequence:  [2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3

In [38]:
def dna_to_input_text(seq):
    seq_kmers = kmers(seq)
    return tokenizer.encode(seq_kmers, add_special_tokens=False)

def interpret_and_visualize(tokenized_text):
    input_ids, ref_input_ids, sep_id = construct_input_ref_pair(tokenized_text, ref_token_id, sep_token_id, cls_token_id)
    score = predict(input_ids)
    
    attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True)
    
    indices = input_ids[0].detach().tolist()
    
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    
    attributions_sum = summarize_attributions(attributions)

    score_vis = viz.VisualizationDataRecord(
        attributions_sum, torch.softmax(score, dim=1)[0][0],
        torch.argmax(torch.softmax(score, dim=1)[0]), 0, tokenized_text, 
        attributions_sum.sum(), all_tokens, delta)

    print('\033[1m', 'Visualization For Score', '\033[0m')
    viz.visualize_text([score_vis])

In [40]:
interpret_and_visualize(text)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.91),"[2088, 146, 569, 2263, 845, 3368, 1172, 577, 2294, 969, 3863, 3150, 300, 1185, 629, 2503, 1805, 3109, 133, 517, 2055, 15, 46, 172, 673, 2679, 2509, 1832, 3219, 575, 2286, 940, 3746, 2683, 2526, 1899, 3485, 1640, 2450, 1595, 2269, 872, 3476, 1604, 2306, 1018, 4058, 3929, 3416, 1363, 1341, 1256, 913, 3637, 2248, 785, 3125, 197, 775, 3085, 40, 148, 580, 2305, 1015, 4047, 3885, 3238, 650, 2587, 2142, 364, 1444, 1668, 2562, 2044, 4066, 3962, 3548, 1890, 3449, 1493, 1861, 3335, 1037, 40, 148, 577, 2296, 977, 3893, 3272, 788, 3138, 252, 994, 3962, 3546, 1881, 3413, 1351, 1293, 1061, 133, 520, 2068, 68, 258, 1020, 4067, 3966, 3562, 1945, 3669, 2376, 1297, 1077, 200, 787, 3135, 240, 947, 3773, 2792, 2964, 3649, 2296, 979, 3901, 3301, 904, 3603, 2110, 234, 923, 3677, 2408, 1428, 1603, 2301, 1000, 3985, 3640, 2258, 827, 3295, 879, 3501, 1704, 2705, 2615, 2254, 809, 3222, 585, 2328, 1106, 316, 1250, 890, 3545, 1880, 3410, 1339, 1246, 875, 3487, 1645, 2472, 1682, 2620, 2276, 899, 3582, 2028, 4003, 3709, 2536, 1939, 3647, 2285, 936, 3729, 2616, 2260, 835, 3327, 1005, 4005, 3720, 2577, 2101, 200, 787, 3134, 235, 927, 3694, 2476, 1699, 2686, 2540, 1954, 3708, 2530, 1915, 3551, 1902, 3500, 1697, 2679, 2509, 1830, 3211, 543, 2157, 424, 1684, 2625, 2293, 968, 3859, 3134, 236, 932, 3713, 2552, 2001, 3896, 3283, 832, 3316, 964, 3842, 3068, 4068, 3970, 3577, 2008, 3924, 3395, 1278, 1001, 3991, 3663, 2349]",1.72,[CLS] CAACAG AACAGT ACAGTA CAGTAC AGTACA GTACAG TACAGG ACAGGA CAGGAT AGGATA GGATAC GATACT ATACTG TACTGA ACTGAA CTGAAC TGAACA GAACAA AACAAA ACAAAA CAAAAC AAAACC AAACCT AACCTG ACCTGA CCTGAC CTGACA TGACAG GACAGC ACAGCC CAGCCT AGCCTG GCCTGT CCTGTC CTGTCT TGTCTC GTCTCA TCTCAG CTCAGT TCAGTC CAGTCA AGTCAG GTCAGG TCAGGG CAGGGT AGGGTT GGGTTT GGTTTA GTTTAG TTTAGC TTAGCA TAGCAG AGCAGA GCAGAA CAGAAG AGAAGA GAAGAA AAGAAA AGAAAC GAAACA AAACAG AACAGG ACAGGG CAGGGA AGGGAC GGGACC GGACCA GACCAT ACCATT CCATTC CATTCT ATTCTG TTCTGG TCTGGG CTGGGT TGGGTG GGGTGT GGTGTT GTGTTG TGTTGT GTTGTA TTGTAA TGTAAA GTAAAC TAAACA AAACAG AACAGG ACAGGA CAGGAG AGGAGA GGAGAA GAGAAG AGAAGG GAAGGT AAGGTG AGGTGT GGTGTT GTGTTT TGTTTA GTTTAA TTTAAC TTAACA TAACAA AACAAA ACAAAG CAAAGG AAAGGG AAGGGT AGGGTG GGGTGC GGTGCT GTGCTT TGCTTA GCTTAA CTTAAG TTAAGA TAAGAA AAGAAG AGAAGC GAAGCC AAGCCG AGCCGC GCCGCA CCGCAG CGCAGG GCAGGA CAGGAG AGGAGC GGAGCA GAGCAA AGCAAG GCAAGC CAAGCT AAGCTT AGCTTC GCTTCA CTTCAG TTCAGG TCAGGC CAGGCA AGGCAG GGCAGA GCAGAG CAGAGT AGAGTC GAGTCC AGTCCC GTCCCA TCCCAG CCCAGA CCAGAC CAGACT AGACTA GACTAT ACTATA CTATAG TATAGT ATAGTG TAGTGT AGTGTT GTGTTA TGTTAG GTTAGT TTAGTC TAGTCT AGTCTC GTCTCC TCTCCA CTCCAG TCCAGT CCAGTG CAGTGG AGTGGC GTGGCT TGGCTG GGCTGC GCTGCA CTGCAG TGCAGC GCAGCC CAGCCA AGCCAG GCCAGA CCAGAG CAGAGG AGAGGC GAGGCC AGGCCA GGCCAA GCCAAG CCAAGA CAAGAA AAGAAG AGAAGC GAAGCT AAGCTC AGCTCC GCTCCT CTCCTG TCCTGC CCTGCT CTGCTG TGCTGT GCTGTG CTGTGT TGTGTC GTGTCC TGTCCT GTCCTG TCCTGA CCTGAC CTGACA TGACAT GACATC ACATCC CATCCA ATCCAG TCCAGG CCAGGA CAGGAA AGGAAG GGAAGC GAAGCT AAGCTG AGCTGG GCTGGA CTGGAG TGGAGA GGAGAG GAGAGC AGAGCG GAGCGG AGCGGG GCGGGT CGGGTG GGGTGG GGTGGT GTGGTA TGGTAG GGTAGG GTAGGC TAGGCT AGGCTA GGCTAC GCTACC CTACCA [SEP]
,,,,


## Other input sequences:

In [40]:
new_seq = dataset_test[1055]
new_seq

{'labels': 1,
 'seq': 'CTTTGTTTTCTCTGTGATGAAATGGGAAGTATACATACCATACTTCAGCTCGATATTGACATGCTATCTGAGTTTTGGGAAACATACTGATTATTTTCCCAGGGAATATAACTAGAATAAATATAATAAACACTTTTTTTTTTTCAGGTGCAGCTGCTGCTGCTGTGATGTCCAGTTCTAAAGTAACCACAGTCCTGAGGCCGACCTCACAGCTGCCAAATGCTGCTACTGCTCAGCCAGCAGTACAGCAC'}

In [41]:
tokenized_text = dna_to_input_text(new_seq['seq'])
interpret_and_visualize(tokenized_text)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.59),"[2402, 1402, 1498, 1882, 3419, 1374, 1387, 1438, 1644, 2466, 1660, 2529, 1910, 3532, 1825, 3189, 453, 1798, 3084, 36, 132, 513, 2037, 4040, 3858, 3129, 214, 841, 3351, 1101, 294, 1161, 535, 2127, 301, 1190, 649, 2583, 2126, 298, 1179, 605, 2408, 1427, 1598, 2283, 928, 3697, 2486, 1737, 2838, 3146, 284, 1121, 375, 1485, 1830, 3212, 547, 2174, 489, 1942, 3659, 2334, 1132, 417, 1656, 2514, 1850, 3290, 858, 3420, 1380, 1412, 1537, 2037, 4037, 3847, 3085, 38, 137, 535, 2126, 300, 1185, 630, 2506, 1817, 3158, 330, 1306, 1114, 347, 1375, 1391, 1453, 1704, 2708, 2628, 2305, 1013, 4038, 3849, 3094, 73, 277, 1095, 270, 1065, 152, 593, 2357, 1222, 777, 3093, 69, 262, 1033, 22, 73, 277, 1094, 265, 1045, 69, 263, 1037, 39, 142, 554, 2202, 602, 2394, 1370, 1370, 1370, 1370, 1370, 1370, 1371, 1373, 1384, 1428, 1602, 2300, 995, 3965, 3560, 1939, 3646, 2284, 931, 3710, 2540, 1955, 3710, 2540, 1955, 3710, 2540, 1954, 3708, 2529, 1910, 3532, 1826, 3195, 479, 1901, 3496, 1682, 2618, 2267, 862, 3433, 1429, 1605, 2312, 1042, 57, 213, 839, 3343, 1069, 167, 653, 2600, 2194, 571, 2271, 878, 3500, 1697, 2680, 2516, 1859, 3327, 1008, 4017, 3767, 2767, 2862, 3243, 669, 2663, 2445, 1576, 2195, 574, 2284, 931, 3711, 2541, 1957, 3717, 2566, 2060, 35, 126, 492, 1955, 3710, 2537, 1943, 3662, 2348, 1187, 638, 2539, 1949, 3688, 2451, 1599, 2285, 936, 3731, 2621, 2280, 914, 3641, 2263, 845, 3368, 1171, 573, 2279]",1.3,[CLS] CTTTGT TTTGTT TTGTTT TGTTTT GTTTTC TTTTCT TTTCTC TTCTCT TCTCTG CTCTGT TCTGTG CTGTGA TGTGAT GTGATG TGATGA GATGAA ATGAAA TGAAAT GAAATG AAATGG AATGGG ATGGGA TGGGAA GGGAAG GGAAGT GAAGTA AAGTAT AGTATA GTATAC TATACA ATACAT TACATA ACATAC CATACC ATACCA TACCAT ACCATA CCATAC CATACT ATACTT TACTTC ACTTCA CTTCAG TTCAGC TCAGCT CAGCTC AGCTCG GCTCGA CTCGAT TCGATA CGATAT GATATT ATATTG TATTGA ATTGAC TTGACA TGACAT GACATG ACATGC CATGCT ATGCTA TGCTAT GCTATC CTATCT TATCTG ATCTGA TCTGAG CTGAGT TGAGTT GAGTTT AGTTTT GTTTTG TTTTGG TTTGGG TTGGGA TGGGAA GGGAAA GGAAAC GAAACA AAACAT AACATA ACATAC CATACT ATACTG TACTGA ACTGAT CTGATT TGATTA GATTAT ATTATT TTATTT TATTTT ATTTTC TTTTCC TTTCCC TTCCCA TCCCAG CCCAGG CCAGGG CAGGGA AGGGAA GGGAAT GGAATA GAATAT AATATA ATATAA TATAAC ATAACT TAACTA AACTAG ACTAGA CTAGAA TAGAAT AGAATA GAATAA AATAAA ATAAAT TAAATA AAATAT AATATA ATATAA TATAAT ATAATA TAATAA AATAAA ATAAAC TAAACA AAACAC AACACT ACACTT CACTTT ACTTTT CTTTTT TTTTTT TTTTTT TTTTTT TTTTTT TTTTTT TTTTTT TTTTTC TTTTCA TTTCAG TTCAGG TCAGGT CAGGTG AGGTGC GGTGCA GTGCAG TGCAGC GCAGCT CAGCTG AGCTGC GCTGCT CTGCTG TGCTGC GCTGCT CTGCTG TGCTGC GCTGCT CTGCTG TGCTGT GCTGTG CTGTGA TGTGAT GTGATG TGATGT GATGTC ATGTCC TGTCCA GTCCAG TCCAGT CCAGTT CAGTTC AGTTCT GTTCTA TTCTAA TCTAAA CTAAAG TAAAGT AAAGTA AAGTAA AGTAAC GTAACC TAACCA AACCAC ACCACA CCACAG CACAGT ACAGTC CAGTCC AGTCCT GTCCTG TCCTGA CCTGAG CTGAGG TGAGGC GAGGCC AGGCCG GGCCGA GCCGAC CCGACC CGACCT GACCTC ACCTCA CCTCAC CTCACA TCACAG CACAGC ACAGCT CAGCTG AGCTGC GCTGCC CTGCCA TGCCAA GCCAAA CCAAAT CAAATG AAATGC AATGCT ATGCTG TGCTGC GCTGCT CTGCTA TGCTAC GCTACT CTACTG TACTGC ACTGCT CTGCTC TGCTCA GCTCAG CTCAGC TCAGCC CAGCCA AGCCAG GCCAGC CCAGCA CAGCAG AGCAGT GCAGTA CAGTAC AGTACA GTACAG TACAGC ACAGCA CAGCAC [SEP]
,,,,


In [42]:
new_seq = dataset_train[155]
new_seq

{'labels': 0,
 'seq': 'CGGGCTCTGCCATGCCCTCCTATGCTCAGGTGTGCTGAGGTCCACACGGCCCTGCCGTTGCACTGCAGCTGCCTGCAGGATTCAGTGCAGTGGCATGCAGTGCAGGTGCGGTGCCCCGGAGCCACAGGCCACACCACAGGGCCTGCATGCACAGGGGCTGCGGTGTCTGGGTTTGGGTAACTACGCCCTGTGACATTTGCACAGCAACAGAATTACCTAATGACGCATTTCTCAGAACACATCCCTGGCAC'}

In [43]:
tokenized_text = dna_to_input_text(new_seq['seq'])
interpret_and_visualize(tokenized_text)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.76),"[3070, 4075, 3998, 3692, 2467, 1663, 2541, 1958, 3724, 2595, 2175, 495, 1966, 3755, 2719, 2670, 2473, 1686, 2636, 2339, 1150, 491, 1949, 3688, 2452, 1602, 2300, 994, 3964, 3555, 1918, 3564, 1953, 3704, 2516, 1858, 3323, 991, 3949, 3495, 1677, 2599, 2192, 564, 2243, 767, 3055, 4014, 3756, 2723, 2687, 2544, 1970, 3770, 2780, 2915, 3453, 1511, 1934, 3628, 2211, 637, 2536, 1939, 3646, 2284, 931, 3711, 2542, 1964, 3747, 2685, 2536, 1940, 3649, 2294, 970, 3867, 3165, 360, 1426, 1596, 2275, 893, 3560, 1938, 3644, 2276, 899, 3581, 2022, 3980, 3619, 2173, 488, 1938, 3644, 2275, 893, 3560, 1940, 3650, 2300, 995, 3968, 3572, 1986, 3836, 3043, 3967, 3567, 1967, 3760, 2740, 2753, 2808, 3027, 3903, 3309, 935, 3725, 2600, 2196, 579, 2303, 1005, 4007, 3725, 2599, 2191, 557, 2215, 653, 2600, 2196, 580, 2307, 1023, 4078, 4012, 3747, 2685, 2534, 1932, 3619, 2173, 487, 1933, 3624, 2196, 580, 2308, 1027, 4094, 4076, 4003, 3712, 2548, 1986, 3836, 3042, 3963, 3550, 1900, 3492, 1668, 2562, 2042, 4058, 3932, 3428, 1412, 1538, 2041, 4053, 3911, 3342, 1065, 151, 592, 2355, 1215, 751, 2990, 3756, 2722, 2684, 2529, 1911, 3533, 1830, 3210, 538, 2140, 355, 1405, 1511, 1933, 3624, 2195, 573, 2277, 903, 3597, 2088, 145, 565, 2246, 778, 3097, 87, 335, 1326, 1193, 661, 2630, 2316, 1057, 119, 464, 1843, 3261, 742, 2954, 3610, 2139, 350, 1387, 1437, 1640, 2449, 1589, 2247, 781, 3111, 141, 550, 2187, 543, 2159, 430, 1708, 2724, 2691, 2557, 2023]",1.53,[CLS] CGGGCT GGGCTC GGCTCT GCTCTG CTCTGC TCTGCC CTGCCA TGCCAT GCCATG CCATGC CATGCC ATGCCC TGCCCT GCCCTC CCCTCC CCTCCT CTCCTA TCCTAT CCTATG CTATGC TATGCT ATGCTC TGCTCA GCTCAG CTCAGG TCAGGT CAGGTG AGGTGT GGTGTG GTGTGC TGTGCT GTGCTG TGCTGA GCTGAG CTGAGG TGAGGT GAGGTC AGGTCC GGTCCA GTCCAC TCCACA CCACAC CACACG ACACGG CACGGC ACGGCC CGGCCC GGCCCT GCCCTG CCCTGC CCTGCC CTGCCG TGCCGT GCCGTT CCGTTG CGTTGC GTTGCA TTGCAC TGCACT GCACTG CACTGC ACTGCA CTGCAG TGCAGC GCAGCT CAGCTG AGCTGC GCTGCC CTGCCT TGCCTG GCCTGC CCTGCA CTGCAG TGCAGG GCAGGA CAGGAT AGGATT GGATTC GATTCA ATTCAG TTCAGT TCAGTG CAGTGC AGTGCA GTGCAG TGCAGT GCAGTG CAGTGG AGTGGC GTGGCA TGGCAT GGCATG GCATGC CATGCA ATGCAG TGCAGT GCAGTG CAGTGC AGTGCA GTGCAG TGCAGG GCAGGT CAGGTG AGGTGC GGTGCG GTGCGG TGCGGT GCGGTG CGGTGC GGTGCC GTGCCC TGCCCC GCCCCG CCCCGG CCCGGA CCGGAG CGGAGC GGAGCC GAGCCA AGCCAC GCCACA CCACAG CACAGG ACAGGC CAGGCC AGGCCA GGCCAC GCCACA CCACAC CACACC ACACCA CACCAC ACCACA CCACAG CACAGG ACAGGG CAGGGC AGGGCC GGGCCT GGCCTG GCCTGC CCTGCA CTGCAT TGCATG GCATGC CATGCA ATGCAC TGCACA GCACAG CACAGG ACAGGG CAGGGG AGGGGC GGGGCT GGGCTG GGCTGC GCTGCG CTGCGG TGCGGT GCGGTG CGGTGT GGTGTC GTGTCT TGTCTG GTCTGG TCTGGG CTGGGT TGGGTT GGGTTT GGTTTG GTTTGG TTTGGG TTGGGT TGGGTA GGGTAA GGTAAC GTAACT TAACTA AACTAC ACTACG CTACGC TACGCC ACGCCC CGCCCT GCCCTG CCCTGT CCTGTG CTGTGA TGTGAC GTGACA TGACAT GACATT ACATTT CATTTG ATTTGC TTTGCA TTGCAC TGCACA GCACAG CACAGC ACAGCA CAGCAA AGCAAC GCAACA CAACAG AACAGA ACAGAA CAGAAT AGAATT GAATTA AATTAC ATTACC TTACCT TACCTA ACCTAA CCTAAT CTAATG TAATGA AATGAC ATGACG TGACGC GACGCA ACGCAT CGCATT GCATTT CATTTC ATTTCT TTTCTC TTCTCA TCTCAG CTCAGA TCAGAA CAGAAC AGAACA GAACAC AACACA ACACAT CACATC ACATCC CATCCC ATCCCT TCCCTG CCCTGG CCTGGC CTGGCA TGGCAC [SEP]
,,,,


In [44]:
new_seq = dataset_train[855]
new_seq

{'labels': 1,
 'seq': 'ATTTGAACAAAAACTACATATAGTATAGCAGAAAAATAAACTAATAGCATTTTATGTATTTATACATTCCTATTATGCAAGTTCTCCTATGATCCAGAATAATACTTTGATAATGCACTTTTAATTGCCTTGAGTAAAAGTATCCTCTTTTTTCTACTTTAGAAGCTGTTGTGAAGGCAGAGCAGCATCTGCTGAAGAGACAGAAACCAGCCCCAGAGGTGTCACAGGAAGGCACCAGCAAGGACATTGGT'}

In [45]:
tokenized_text = dna_to_input_text(new_seq['seq'])
interpret_and_visualize(tokenized_text)

[1m Visualization For Score [0m


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.77),"[353, 1397, 1479, 1805, 3109, 133, 517, 2053, 7, 14, 41, 151, 589, 2342, 1161, 534, 2121, 280, 1106, 313, 1238, 841, 3352, 1107, 317, 1256, 913, 3637, 2245, 773, 3077, 6, 9, 21, 69, 263, 1038, 41, 149, 582, 2313, 1048, 83, 317, 1254, 906, 3610, 2138, 345, 1366, 1356, 1314, 1145, 470, 1866, 3354, 1113, 342, 1353, 1303, 1101, 294, 1162, 539, 2143, 366, 1449, 1686, 2634, 2329, 1110, 332, 1315, 1149, 485, 1928, 3602, 2106, 219, 862, 3435, 1439, 1646, 2473, 1686, 2636, 2337, 1142, 459, 1823, 3181, 424, 1681, 2613, 2246, 777, 3093, 70, 265, 1047, 78, 298, 1178, 604, 2401, 1398, 1481, 1813, 3142, 268, 1059, 125, 487, 1934, 3626, 2202, 602, 2393, 1365, 1350, 1290, 1052, 99, 383, 1518, 1962, 3740, 2657, 2424, 1490, 1849, 3285, 837, 3333, 1032, 18, 57, 214, 843, 3359, 1134, 427, 1694, 2666, 2458, 1626, 2394, 1370, 1371, 1374, 1385, 1431, 1614, 2346, 1178, 601, 2392, 1361, 1333, 1224, 787, 3134, 236, 930, 3706, 2524, 1890, 3452, 1505, 1909, 3528, 1812, 3139, 253, 1000, 3985, 3640, 2259, 829, 3304, 915, 3645, 2278, 907, 3614, 2156, 419, 1662, 2540, 1953, 3701, 2504, 1809, 3128, 209, 823, 3277, 808, 3217, 565, 2245, 775, 3087, 45, 168, 659, 2623, 2287, 943, 3757, 2728, 2705, 2616, 2260, 834, 3324, 994, 3963, 3549, 1895, 3469, 1576, 2196, 577, 2293, 968, 3860, 3139, 253, 999, 3983, 3629, 2216, 659, 2621, 2277, 904, 3604, 2113, 247, 973, 3878, 3210, 540, 2148, 386]",2.4,[CLS] ATTTGA TTTGAA TTGAAC TGAACA GAACAA AACAAA ACAAAA CAAAAA AAAAAC AAAACT AAACTA AACTAC ACTACA CTACAT TACATA ACATAT CATATA ATATAG TATAGT ATAGTA TAGTAT AGTATA GTATAG TATAGC ATAGCA TAGCAG AGCAGA GCAGAA CAGAAA AGAAAA GAAAAA AAAAAT AAAATA AAATAA AATAAA ATAAAC TAAACT AAACTA AACTAA ACTAAT CTAATA TAATAG AATAGC ATAGCA TAGCAT AGCATT GCATTT CATTTT ATTTTA TTTTAT TTTATG TTATGT TATGTA ATGTAT TGTATT GTATTT TATTTA ATTTAT TTTATA TTATAC TATACA ATACAT TACATT ACATTC CATTCC ATTCCT TTCCTA TCCTAT CCTATT CTATTA TATTAT ATTATG TTATGC TATGCA ATGCAA TGCAAG GCAAGT CAAGTT AAGTTC AGTTCT GTTCTC TTCTCC TCTCCT CTCCTA TCCTAT CCTATG CTATGA TATGAT ATGATC TGATCC GATCCA ATCCAG TCCAGA CCAGAA CAGAAT AGAATA GAATAA AATAAT ATAATA TAATAC AATACT ATACTT TACTTT ACTTTG CTTTGA TTTGAT TTGATA TGATAA GATAAT ATAATG TAATGC AATGCA ATGCAC TGCACT GCACTT CACTTT ACTTTT CTTTTA TTTTAA TTTAAT TTAATT TAATTG AATTGC ATTGCC TTGCCT TGCCTT GCCTTG CCTTGA CTTGAG TTGAGT TGAGTA GAGTAA AGTAAA GTAAAA TAAAAG AAAAGT AAAGTA AAGTAT AGTATC GTATCC TATCCT ATCCTC TCCTCT CCTCTT CTCTTT TCTTTT CTTTTT TTTTTT TTTTTC TTTTCT TTTCTA TTCTAC TCTACT CTACTT TACTTT ACTTTA CTTTAG TTTAGA TTAGAA TAGAAG AGAAGC GAAGCT AAGCTG AGCTGT GCTGTT CTGTTG TGTTGT GTTGTG TTGTGA TGTGAA GTGAAG TGAAGG GAAGGC AAGGCA AGGCAG GGCAGA GCAGAG CAGAGC AGAGCA GAGCAG AGCAGC GCAGCA CAGCAT AGCATC GCATCT CATCTG ATCTGC TCTGCT CTGCTG TGCTGA GCTGAA CTGAAG TGAAGA GAAGAG AAGAGA AGAGAC GAGACA AGACAG GACAGA ACAGAA CAGAAA AGAAAC GAAACC AAACCA AACCAG ACCAGC CCAGCC CAGCCC AGCCCC GCCCCA CCCCAG CCCAGA CCAGAG CAGAGG AGAGGT GAGGTG AGGTGT GGTGTC GTGTCA TGTCAC GTCACA TCACAG CACAGG ACAGGA CAGGAA AGGAAG GGAAGG GAAGGC AAGGCA AGGCAC GGCACC GCACCA CACCAG ACCAGC CCAGCA CAGCAA AGCAAG GCAAGG CAAGGA AAGGAC AGGACA GGACAT GACATT ACATTG CATTGG ATTGGT [SEP]
,,,,


*Colours: https://github.com/pytorch/captum/issues/249#issuecomment-580569266*

*That output is the prediction probability (p) of being a positive sentiment. A negative sentiment would be (1 - p).
In our case we attribute positive sentiment probability (p) to the inputs of our model and in case something is predicted with high probability as positive sentiment we see many tokens that contribute positively to the positive sentiment.*

*In case when p is very low, there are no words contributing to the positive sentiment and when we attribute to the positive sentiment prob (p) we find words that pull away (influence negatively) to the positive sentiment. Those tokens are obviously the ones that pull towards the negative sentiment with higher (1-p) probability.*

https://github.com/pytorch/captum/issues/249#issuecomment-580846266

*In a general case, red means that those tokens are pulling away from the Hate Speech (1) and most probably pulling towards the opposite class however I think that red might not always mean that it will always attribute to the other class. I think that's the assumption that we make here. We assume that the classifier is able to identify that a token is negatively correlated with the Hate Speech(1) class so it must know something about that token, namely, that it is strongly pulling towards the opposite class (because there are no other options) and this is much easier to imagine for 2 class problem.*