# Interpretation of fine-tuned DistilProtBERT using [**Captum**](https://captum.ai/)

Source code info:

Used notebook: https://gist.github.com/davidefiocco/3e1a0ed030792230a33c726c61f6b3a5 (adaptation to classification from original tutorial on question answering: https://captum.ai/tutorials/Bert_SQUAD_Interpret)

(Used notebook is based on this github issue: https://github.com/pytorch/captum/issues/303)

Related github issue: https://github.com/pytorch/captum/issues/249

## Load initial libraries, models, data:

In [1]:
!pip install transformers datasets tokenizers evaluate captum --quiet

In [2]:
TOKENIZER = 'EvaKlimentova/knots_distillprotbert_alphafold'
HF_MODEL_NAME = 'EvaKlimentova/knots_distillprotbert_alphafold'
INPUT_CSV = '/home/jovyan/data/proteins_m1/minimums_p40_for_interpretation.csv'

In [3]:
import numpy as np
import torch

torch.no_grad()

<torch.autograd.grad_mode.no_grad at 0x7f5ec989d490>

In [4]:
import tensorflow as tf
from transformers import BertTokenizer, BertForSequenceClassification
from captum.attr import visualization as viz
from captum.attr import LayerIntegratedGradients

# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')  # solves all CUDA out of memory problems :)
device

2023-02-06 11:33:20.214487: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-06 11:33:21.851026: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2023-02-06 11:33:21.851099: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64


device(type='cpu')

### Load the tokenizer:

In [5]:
tokenizer = BertTokenizer.from_pretrained(TOKENIZER, max_length=1024, truncation=True, num_labels=2)
tokenizer

PreTrainedTokenizer(name_or_path='EvaKlimentova/knots_distillprotbert_alphafold', vocab_size=30, model_max_len=1000000000000000019884624838656, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### Load the model:

In [6]:
model = BertForSequenceClassification.from_pretrained(HF_MODEL_NAME, output_attentions=True)
model.to(device)
model.eval()
model.zero_grad()

### Load the data:

In [7]:
import csv
import pandas as pd

with open(INPUT_CSV, newline='') as f:
    reader = csv.reader(f)
    data = list(reader)

df = pd.DataFrame(data[1:], columns=data[0])
df

Unnamed: 0,id,sequence_str,sequence_len,sequence_pred,family,knot_start,knot_end,knot,knot_str,patch,min_sequence_str,min_pred,min_start,min_end,overlap_pred,drop_difference
0,A0A0X3YDV6,MVSKLGWQLVQLGRRLWVRTVAFAVLALFSALVAVLVQDYIPETLS...,416,1.0,DUF,184,335,PWLQAQPLQSLQQIPKDAKPVLTNTIGYLQLIDIKAINKWAAENNC...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,AEIASKALSPAINDPGTAIEIIGRGMRVLSNWQAPSMDKA,MVSKLGWQLVQLGRRLWVRTVAFAVLALFSALVAVLVQDYIPETLS...,0.99658203125,285,325,1.0,0.00341796875
1,J3D4J7,MANDFLFTSESVSEGHPDKVADQISDAILDAIFKQDPRSRVAAETL...,303,0.9995117,AdoMet synthase,15,279,HPDKVADQISDAILDAIFKQDPRSRVAAETLTNTGLVVLAGEITTN...,XXXXXXXXXXXXXXXHPDKVADQISDAILDAIFKQDPRSRVAAETL...,VDTYGGACPHGGGAFSGKDPSKVDRSAAYAARYVAKNIVA,MANDFLFTSESVSEGHPDKVADQISDAILDAIFKQDPRSRVAAETL...,0.0007562637329101,257,297,0.55,0.99875543626709
2,A0A645GPG4,MEDKYLKRAGLDYWSLVEIKYHDNLDAFFDMYREGKFFLSTTKAKN...,115,1.0,SPOUT,38,80,LSTTKAKNKYTDLKYEKDCFILFGKETAGLPKDLLLKNPDEC,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXLSTTKAKN...,FILFGKETAGLPKDLLLKNPDECIRVPMIDEARSLNLSNS,MEDKYLKRAGLDYWSLVEIKYHDNLDAFFDMYREGKFFLSTTKAKN...,0.0019111633300781,57,97,0.575,0.9980888366699219
3,A0A2G5ETQ7,MCFNGKRQSPIEIVKKNTVFDQNLGPLIVGYNDASATLINNGFNVE...,212,0.99853516,Carbonic anhydrase,8,207,SPIEIVKKNTVFDQNLGPLIVGYNDASATLINNGFNVELRYENDVG...,XXXXXXXXSPIEIVKKNTVFDQNLGPLIVGYNDASATLINNGFNVE...,PIEIVKKNTVFDQNLGPLIVGYNDASATLINNGFNVELRY,MCFNGKRQSXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,0.0004529953002929,9,49,1.0,0.9980821646997071
4,A0A3G9G4U1,MRRPVLLFTVAAFALMSAGLSSCGKPKAEHGDPHAEAAGEHGGGDH...,284,0.99902344,Carbonic anhydrase,89,280,PINLTGVAAPKSVNLTLDYTSSPAKIQNLGHAIQVSPTDGGGVVMD...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,TWLVYASPLSISPEQVDAYQRLTGPNARPIQPPQGRDILH,MRRPVLLFTVAAFALMSAGLSSCGKPKAEHGDPHAEAAGEHGGGDH...,0.046966552734375,240,280,1.0,0.952056887265625
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6589,A0A3S5HM02,MPKEIAAPEFKPYIPASAVLPEFTLRALVMGVVLGMIFGASSLYLV...,757,0.99902344,membrane,97,693,FGLGVTMPAILILGFDLEISRVALVAVLGGLLGILLMIPMRRAMIV...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,TAGSAGESIAFGLGVTMPAILILGFDLEISRVALVAVLGG,MPKEIAAPEFKPYIPASAVLPEFTLRALVMGVVLGMIFGASSLYLV...,0.998046875,87,127,0.75,0.0009765649999999848
6590,A0A3B4X9R3,MSPDCWKKRQNPQIKRFPMETGLVNVFRPVQELNERQVSASPATPL...,355,0.99902344,Carbonic anhydrase,80,325,PINIVTKTAVIDEHLDAFTYTKFDDKNTIKSITNSGHSVKCVLKED...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,SSEVSVTEEISIDDLLGNVNRAAYYRYNGSLTTPSCNEAV,MSPDCWKKRQNPQIKRFPMETGLVNVFRPVQELNERQVSASPATPL...,0.958984375,244,284,1.0,0.040039064999999985
6591,A0A1X0QVI8,MASQENVTKDELQYSEKKEEEIINEKETKGQEEFWLKDTSDVKAVE...,761,0.9995117,membrane,161,716,SAYGTSVLSTQQLYFNRTPGVAGSIFFLFSTQLIGYGIAGQLRSFM...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,KHDDEHYFNADETDLAIVNEIAITEDTPNVPIITVRAIVV,MASQENVTKDELQYSEKKEEEIINEKETKGQEEFWLKDTSDVKAVE...,0.9990234375,49,89,0.0,0.00048826250000000293
6592,A0A6M2AUN7,MAYLFTSESVSEGHPDKVADQISDALLDNFLAFDPESKVACETLVT...,418,0.9995117,AdoMet synthase,417,417,,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,MAYLFTSESVSEGHPDKVADQISDALLDNFLAFDPESKVA,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXCETLVT...,0.0023307800292968,0,40,0.0,0.9971809199707032


In [8]:
df_tst = df.sample(3, random_state=444)
df_tst

Unnamed: 0,id,sequence_str,sequence_len,sequence_pred,family,knot_start,knot_end,knot,knot_str,patch,min_sequence_str,min_pred,min_start,min_end,overlap_pred,drop_difference
4614,A0A1M4TWI8,MKKFTSVSDVENLQEIIKKALQIKENPLSETEKGKGKTIGLVFLNS...,332,1.0,ATCase/OTCase,184,251,LTWAPHIKPIAQAVGNSFAEWMQEMDVEFVITNPEGYDLDKNFTKE...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,NWHFDSSPSSPQTKLSINKKPKVVLTWAPHIKPIAQAVGN,MKKFTSVSDVENLQEIIKKALQIKENPLSETEKGKGKTIGLVFLNS...,0.0010566711425781,160,200,0.4,0.998943328857422
2956,A0A1Q3KF83,MSDPVDNKVKVLILQHPQEQDRVLGTAKLIATTLADARVVIGLSWR...,217,1.0,TDD,62,129,VLYLGSTQVKGGKQGPAPVVAVDRKGEPLADQAAGLRGLKGLIALD...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,LRGLKGLIALDGNWAQAKALWWRNAWLTKLRRFVVVPDGP,MSDPVDNKVKVLILQHPQEQDRVLGTAKLIATTLADARVVIGLSWR...,0.0135116577148437,97,137,0.8,0.9864883422851562
293,A0A4Q8QHA0,MKHFLSLNDIDSLPNLVEDAIALKKSPYQFDALGKNKTICLLFFNN...,312,1.0,ATCase/OTCase,170,237,LSWAPHPKALPHAVANSFVSMIKMQHAEFVITHPKGYELNPEITHG...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,AHPLQALADAITMEENKTKAKPKIVLSWAPHPKALPHAVA,MKHFLSLNDIDSLPNLVEDAIALKKSPYQFDALGKNKTICLLFFNN...,0.716796875,145,185,0.375,0.283203125


In [9]:
def tokenize_function(s):
    seq_split = ' '.join(s)
    return tokenizer(seq_split, return_tensors='pt')

## Captum interpretation:

In [10]:
def predict(inputs):
    return model(inputs).logits

def custom_forward(inputs):
    preds = predict(inputs)
    return torch.softmax(preds, dim = 1)[0][0].unsqueeze(-1)

Compute attributions with respect to the `BertEmbeddings` layer:

1. define baselines/ references,
2. numericalize baselines and inputs.

In [11]:
REF_TOKEN_ID = tokenizer.pad_token_id  # token used for generating token reference
SEP_TOKEN_ID = tokenizer.sep_token_id  # token added to the end of the input text
CLS_TOKEN_ID = tokenizer.cls_token_id  # token used at the beginning of the input text
print(f'REF={REF_TOKEN_ID}, SEP={SEP_TOKEN_ID}, CLS={CLS_TOKEN_ID}')

REF=0, SEP=3, CLS=2


Helper function for getting one sequence in raw, tokenized, patched, and patched_tokenized formats:

In [12]:
def get_sample(dataset, index):
    seq_info = df.iloc[index]
    seq_id = seq_info['id']
    seq_len = int(seq_info['sequence_len'])
    
    raw_seq = seq_info['sequence_str']
    tokenized_seq = tokenize_function(raw_seq)
    
    patched_seq = seq_info['min_sequence_str']
    tokenized_patched_seq = tokenize_function(patched_seq)

    print(f'Seq {index}-{seq_id}: {raw_seq} (length: {seq_len})')
    #print(f'Patched: \t  {patched_seq}')
    
    ref_input_ids = [CLS_TOKEN_ID] + [REF_TOKEN_ID] * seq_len + [SEP_TOKEN_ID]
    return raw_seq, tokenized_seq, tokenized_patched_seq, torch.tensor([ref_input_ids], device=device)

sample, tokenized_sample, tokenized_patched_sample, tokenized_baseline = get_sample(df, 1)

Seq 1-J3D4J7: MANDFLFTSESVSEGHPDKVADQISDAILDAIFKQDPRSRVAAETLTNTGLVVLAGEITTNAHVDYIQVARDTIKRIGYDNTEYGIDYKGCAVMVCYDKQSNDIAQGVDHASDDHLNIGAGDQGLMFGYACDETPDLMPAPIYYAHRLVERQAQLRKDGRLPFLRPDAKSQVTMRYVDGKPHSIDTVVLSTQHSPDQSETPHKMKASFNEAIIEEIIKPVLPKGMLTKDTRYLINPTGRFVIGGPQGDCGLTGRKIIVDTYGGACPHGGGAFSGKDPSKVDRSAAYAARYVAKNIVAAGLARQ (length: 303)


Get predictions for the tokenized sequence:

There are 2 different ways of computing the attributions for emebedding layers. One option is to use `LayerIntegratedGradients` and compute the attributions with respect to `BertEmbedding`. The second option is to use `LayerIntegratedGradients` for each `word_embeddings`, `token_type_embeddings` and `position_embeddings` and compute the attributions w.r.t each embedding vector.

In [13]:
lig = LayerIntegratedGradients(custom_forward, model.bert.embeddings)

Helper function to summarize attributions for each word token in the sequence:

In [14]:
def summarize_attributions(attributions):
    attributions = attributions.sum(dim=-1).squeeze(0)
    attributions = attributions/ torch.norm(attributions)
    return attributions

In [15]:
def interpret(input_tensors, ref_input_ids, label, n_steps=1):
    input_ids = input_tensors['input_ids']
    attributions, delta = lig.attribute(inputs=input_ids,
                                    baselines=ref_input_ids,
                                    return_convergence_delta=True,
                                    n_steps=n_steps)
    # https://captum.ai/api/layer.html#layer-integrated-gradients
    # https://github.com/pytorch/captum/issues/538
    
    indices = input_ids[0].detach().tolist()
    all_tokens = tokenizer.convert_ids_to_tokens(indices)
    attributions_sum = summarize_attributions(attributions)
    attribution_scores = [round(_, 4) for _ in attributions_sum.tolist()]
    
    del(input_ids)
    del(attributions)
    del(delta)
    del(indices)
    del(all_tokens)
    del(attributions_sum)
    
    return attribution_scores

Observe the difference between different values used for `n_steps` (https://captum.ai/api/layer.html#captum.attr.LayerIntegratedGradients: "The number of steps used by the approximation method.").

*Default is 50.*

In [16]:
def get_top_k(attributions, k, largest=True):
    attributions_tensor = torch.tensor(attributions, device=device)
    values, indices = torch.topk(attributions_tensor, k, largest=largest)
    if largest:
        return {'top_k_largest_values': values.tolist(), 'top_k_largest_indices': indices.tolist()}
    else:
        return {'top_k_smallest_values': values.tolist(), 'top_k_smallest_indices': indices.tolist()}

In [17]:
import time

most_influential_tokens_pos = []
most_influential_tokens_neg = []

for i in range(len(df_tst)):
    _, tokenized_sample, _, tokenized_baseline = get_sample(df_tst, i)
    sample_attributions = interpret(tokenized_sample, tokenized_baseline, 1, 50)
    most_influential_tokens_pos.append(get_top_k(sample_attributions, 20))
    most_influential_tokens_neg.append(get_top_k(sample_attributions, 20, False))
    
    del(tokenized_sample)
    del(tokenized_baseline)
    del(sample_attributions)
    
len(most_influential_tokens_pos)

Seq 0-A0A0X3YDV6: MVSKLGWQLVQLGRRLWVRTVAFAVLALFSALVAVLVQDYIPETLSDIIGAGAAENILNILATSMLTVTTFSLSVMVAAYSASSKDVSPRATRLLMEDSTTQNALATFVGSFLFSIVSIILLNTEVYNQRGRVVLFLATILVIVLIVVMILIWISHLSSLGRVGETAGKVEDQAYAALKRHNKYPWLQAQPLQSLQQIPKDAKPVLTNTIGYLQLIDIKAINKWAAENNCHAYVAVRPGVFVEPCRPLLWLSPAPADTSSVPLDAFCISNYRTFDQDPRFGLLVLAEIASKALSPAINDPGTAIEIIGRGMRVLSNWQAPSMDKATIQYPQVSVLPLSLTDLFDDFYTPIARDGAAMLEVQIRLQKSLLALAGQRAEFRPHTQRHSTAALARAEHAMAYSEDIKTLKALHQQLCDS (length: 416)
Seq 1-J3D4J7: MANDFLFTSESVSEGHPDKVADQISDAILDAIFKQDPRSRVAAETLTNTGLVVLAGEITTNAHVDYIQVARDTIKRIGYDNTEYGIDYKGCAVMVCYDKQSNDIAQGVDHASDDHLNIGAGDQGLMFGYACDETPDLMPAPIYYAHRLVERQAQLRKDGRLPFLRPDAKSQVTMRYVDGKPHSIDTVVLSTQHSPDQSETPHKMKASFNEAIIEEIIKPVLPKGMLTKDTRYLINPTGRFVIGGPQGDCGLTGRKIIVDTYGGACPHGGGAFSGKDPSKVDRSAAYAARYVAKNIVAAGLARQ (length: 303)
Seq 2-A0A645GPG4: MEDKYLKRAGLDYWSLVEIKYHDNLDAFFDMYREGKFFLSTTKAKNKYTDLKYEKDCFILFGKETAGLPKDLLLKNPDECIRVPMIDEARSLNLSNSVAIVVYEALRQIGFPNMI (length: 115)


3

In [18]:
df_top_k_pos = pd.DataFrame(most_influential_tokens_pos)
df_top_k_pos

Unnamed: 0,top_k_largest_values,top_k_largest_indices
0,"[0.09099999815225601, 0.08780000358819962, 0.0...","[201, 212, 89, 233, 172, 227, 217, 257, 231, 1..."
1,"[0.12049999833106995, 0.10670000314712524, 0.1...","[88, 146, 143, 302, 144, 193, 110, 182, 115, 2..."
2,"[0.09359999746084213, 0.06599999964237213, 0.0...","[113, 104, 2, 114, 112, 105, 9, 85, 91, 1, 110..."


In [19]:
df_top_k_neg = pd.DataFrame(most_influential_tokens_neg)
df_top_k_neg

Unnamed: 0,top_k_smallest_values,top_k_smallest_indices
0,"[-0.16670000553131104, -0.1535000056028366, -0...","[1, 4, 41, 403, 57, 60, 35, 49, 100, 45, 10, 7..."
1,"[-0.3440999984741211, -0.1429000049829483, -0....","[303, 59, 4, 3, 8, 45, 252, 2, 257, 227, 10, 1..."
2,"[-0.569599986076355, -0.2240999937057495, -0.1...","[115, 94, 90, 92, 75, 82, 72, 74, 76, 84, 98, ..."


In [20]:
df_tst['top_k_largest_indices'] = df_top_k_pos['top_k_largest_indices'].to_list()
df_tst['top_k_largest_values'] = df_top_k_pos['top_k_largest_values'].to_list()
df_tst['top_k_smallest_indices'] = df_top_k_neg['top_k_smallest_indices'].to_list()
df_tst['top_k_smallest_values'] = df_top_k_neg['top_k_smallest_values'].to_list()
df_tst

Unnamed: 0,id,sequence_str,sequence_len,sequence_pred,family,knot_start,knot_end,knot,knot_str,patch,min_sequence_str,min_pred,min_start,min_end,overlap_pred,drop_difference,top_k_largest_indices,top_k_largest_values,top_k_smallest_indices,top_k_smallest_values
4614,A0A1M4TWI8,MKKFTSVSDVENLQEIIKKALQIKENPLSETEKGKGKTIGLVFLNS...,332,1.0,ATCase/OTCase,184,251,LTWAPHIKPIAQAVGNSFAEWMQEMDVEFVITNPEGYDLDKNFTKE...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,NWHFDSSPSSPQTKLSINKKPKVVLTWAPHIKPIAQAVGN,MKKFTSVSDVENLQEIIKKALQIKENPLSETEKGKGKTIGLVFLNS...,0.0010566711425781,160,200,0.4,0.998943328857422,"[201, 212, 89, 233, 172, 227, 217, 257, 231, 1...","[0.09099999815225601, 0.08780000358819962, 0.0...","[1, 4, 41, 403, 57, 60, 35, 49, 100, 45, 10, 7...","[-0.16670000553131104, -0.1535000056028366, -0..."
2956,A0A1Q3KF83,MSDPVDNKVKVLILQHPQEQDRVLGTAKLIATTLADARVVIGLSWR...,217,1.0,TDD,62,129,VLYLGSTQVKGGKQGPAPVVAVDRKGEPLADQAAGLRGLKGLIALD...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,LRGLKGLIALDGNWAQAKALWWRNAWLTKLRRFVVVPDGP,MSDPVDNKVKVLILQHPQEQDRVLGTAKLIATTLADARVVIGLSWR...,0.0135116577148437,97,137,0.8,0.9864883422851562,"[88, 146, 143, 302, 144, 193, 110, 182, 115, 2...","[0.12049999833106995, 0.10670000314712524, 0.1...","[303, 59, 4, 3, 8, 45, 252, 2, 257, 227, 10, 1...","[-0.3440999984741211, -0.1429000049829483, -0...."
293,A0A4Q8QHA0,MKHFLSLNDIDSLPNLVEDAIALKKSPYQFDALGKNKTICLLFFNN...,312,1.0,ATCase/OTCase,170,237,LSWAPHPKALPHAVANSFVSMIKMQHAEFVITHPKGYELNPEITHG...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,AHPLQALADAITMEENKTKAKPKIVLSWAPHPKALPHAVA,MKHFLSLNDIDSLPNLVEDAIALKKSPYQFDALGKNKTICLLFFNN...,0.716796875,145,185,0.375,0.283203125,"[113, 104, 2, 114, 112, 105, 9, 85, 91, 1, 110...","[0.09359999746084213, 0.06599999964237213, 0.0...","[115, 94, 90, 92, 75, 82, 72, 74, 76, 84, 98, ...","[-0.569599986076355, -0.2240999937057495, -0.1..."


Check overlaps with actual knot core:

In [21]:
def compute_attribution_overlaps(attribution_list, start_i, end_i):
    score = 0
    for attribution_score in attribution_list:
        if attribution_score >= int(start_i) and attribution_score <= int(end_i):
            score += 1
    score = score / len(attribution_list)
    return score

df_tst['core_overlap_pos'] = df_tst.apply(lambda row: compute_attribution_overlaps(row['top_k_largest_indices'], row['knot_start'], row['knot_end']), axis=1)
df_tst['core_overlap_neg'] = df_tst.apply(lambda row: compute_attribution_overlaps(row['top_k_smallest_indices'], row['knot_start'], row['knot_end']), axis=1)
df_tst['min_overlap_pos'] = df_tst.apply(lambda row: compute_attribution_overlaps(row['top_k_largest_indices'], row['min_start'], row['min_end']), axis=1)
df_tst['min_overlap_neg'] = df_tst.apply(lambda row: compute_attribution_overlaps(row['top_k_smallest_indices'], row['min_start'], row['min_end']), axis=1)
df_tst

Unnamed: 0,id,sequence_str,sequence_len,sequence_pred,family,knot_start,knot_end,knot,knot_str,patch,...,overlap_pred,drop_difference,top_k_largest_indices,top_k_largest_values,top_k_smallest_indices,top_k_smallest_values,core_overlap_pos,core_overlap_neg,min_overlap_pos,min_overlap_neg
4614,A0A1M4TWI8,MKKFTSVSDVENLQEIIKKALQIKENPLSETEKGKGKTIGLVFLNS...,332,1.0,ATCase/OTCase,184,251,LTWAPHIKPIAQAVGNSFAEWMQEMDVEFVITNPEGYDLDKNFTKE...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,NWHFDSSPSSPQTKLSINKKPKVVLTWAPHIKPIAQAVGN,...,0.4,0.998943328857422,"[201, 212, 89, 233, 172, 227, 217, 257, 231, 1...","[0.09099999815225601, 0.08780000358819962, 0.0...","[1, 4, 41, 403, 57, 60, 35, 49, 100, 45, 10, 7...","[-0.16670000553131104, -0.1535000056028366, -0...",0.45,0.0,0.3,0.0
2956,A0A1Q3KF83,MSDPVDNKVKVLILQHPQEQDRVLGTAKLIATTLADARVVIGLSWR...,217,1.0,TDD,62,129,VLYLGSTQVKGGKQGPAPVVAVDRKGEPLADQAAGLRGLKGLIALD...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,LRGLKGLIALDGNWAQAKALWWRNAWLTKLRRFVVVPDGP,...,0.8,0.9864883422851562,"[88, 146, 143, 302, 144, 193, 110, 182, 115, 2...","[0.12049999833106995, 0.10670000314712524, 0.1...","[303, 59, 4, 3, 8, 45, 252, 2, 257, 227, 10, 1...","[-0.3440999984741211, -0.1429000049829483, -0....",0.35,0.0,0.15,0.05
293,A0A4Q8QHA0,MKHFLSLNDIDSLPNLVEDAIALKKSPYQFDALGKNKTICLLFFNN...,312,1.0,ATCase/OTCase,170,237,LSWAPHPKALPHAVANSFVSMIKMQHAEFVITHPKGYELNPEITHG...,XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX...,AHPLQALADAITMEENKTKAKPKIVLSWAPHPKALPHAVA,...,0.375,0.283203125,"[113, 104, 2, 114, 112, 105, 9, 85, 91, 1, 110...","[0.09359999746084213, 0.06599999964237213, 0.0...","[115, 94, 90, 92, 75, 82, 72, 74, 76, 84, 98, ...","[-0.569599986076355, -0.2240999937057495, -0.1...",0.0,0.0,0.0,0.0


In [22]:
def print_stats(score_str):
    print(f'{score_str}:')
    print(f'Mean: {df_tst[score_str].mean()}')
    print(f'Median: {df_tst[score_str].median()}')
    print(f'Max: {df_tst[score_str].max()}')
    print(f'Min: {df_tst[score_str].min()}')
    
print_stats('core_overlap_pos')
print()
print_stats('core_overlap_neg')
print()
print_stats('min_overlap_pos')
print()
print_stats('min_overlap_neg')

core_overlap_pos:
Mean: 0.26666666666666666
Median: 0.35
Max: 0.45
Min: 0.0

core_overlap_neg:
Mean: 0.0
Median: 0.0
Max: 0.0
Min: 0.0

min_overlap_pos:
Mean: 0.15
Median: 0.15
Max: 0.3
Min: 0.0

min_overlap_neg:
Mean: 0.016666666666666666
Median: 0.0
Max: 0.05
Min: 0.0


In [82]:
df_tst.to_csv('/home/jovyan/data/proteins_m1/p40_word_attribution_interpretation_n_step_50.csv', encoding='utf-8', index=False)