In [1]:
import torch
import ast
import torch.nn as nn
import pandas as pd
import numpy as np
import re

from Bio import SeqIO
from glob import glob
from transformers import (AutoModelForTokenClassification, AutoTokenizer,
                          AutoModelForMaskedLM, DataCollatorForTokenClassification,
                           EsmForMaskedLM, EsmTokenizer,
                           TrainingArguments, Trainer
                        )
from transformers.trainer_callback import ProgressCallback
from sklearn.metrics import (accuracy_score, precision_recall_fscore_support,
                             matthews_corrcoef, roc_auc_score)
from sklearn.model_selection import train_test_split
from pprint import pprint
from datasets import Dataset
from datetime import datetime

  from .autonotebook import tqdm as notebook_tqdm
2025-02-25 18:31:56.259061: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740483116.363056  105774 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1740483116.396305  105774 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-02-25 18:31:56.619582: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def load_fasta_dataframe(file_path):
    """
    Load fasta file into a pandas dataframe
    :param file_path: path to fasta file
    :return: dataframe with columns 'id', 'sequence', 'seq_len'
    """
    # read fasta file
    records = [
        (record.id, str(record.seq)) for record in SeqIO.parse(file_path, "fasta")
    ]

    # create dataframe
    df_fasta = pd.DataFrame(records, columns=["id", "sequence"])
    df_fasta["seq_len"] = df_fasta["sequence"].apply(len)
    return df_fasta


def load_binding_sites_dataframe(file_path, target=None):
    """
    Load binding sites file into a pandas dataframe
    :param file_path: path to binding sites file
    :param target: target protein class ('metal', 'nuclear', 'small')
    :return: dataframe with columns 'id', 'binding_sites'
    """
    # check if target is valid
    assert target in [
        "metal",
        "nuclear",
        "small",
    ], "target must be one of 'metal', 'nuclear', 'small'"

    # read binding sites file
    binding_sites = []
    with open(file_path, "r") as f:
        for line in f:
            protein_id, sites = line.strip().split("\t")
            binding_sites.append((protein_id, [int(site) for site in sites.split(",")]))

    # create dataframe
    df_binding_sites = pd.DataFrame(binding_sites, columns=["id", "binding_sites"])
    df_binding_sites["num_residues"] = df_binding_sites["binding_sites"].apply(len)
    df_binding_sites["target"] = target
    return df_binding_sites

In [3]:
df_proteins = load_fasta_dataframe("data/development_set/all.fasta")
display(df_proteins.head())

Unnamed: 0,id,sequence,seq_len
0,Q5LL55,MSETWLPTLVTATPQEGFDLAVKLSRIAVKKTQPDAQVRDTLRAVY...,76
1,H9L4N9,MQINIQGHHIDLTDSMQDYVHSKFDKLERFFDHINHVQVILRVEKL...,95
2,O34738,MKSWKVKEIVIMSVISIVFAVVYLLFTHFGNVLAGMFGPIAYEPIY...,199
3,P39579,MDFKQEVLDVLAEVCQDDIVKENPDIEIFEEGLLDSFGTVELLLAI...,78
4,P01887,MARSVTLVFLVLVSLTGLYAIQKTPQIQVYSRHPPENGKPNILNCY...,119


In [4]:
def convert_to_binary_list(original_binding_sites_lst, sequence_len):
    """Convert a Binding-Active site string to a binary list based on the sequence length."""
    binary_list = [0] * sequence_len  # Initialize a list of zeros

    # Ensure original_binding_sites_lst is a list and not empty
    if isinstance(original_binding_sites_lst, list) and len(original_binding_sites_lst) > 0:
        for idx in original_binding_sites_lst:
            if isinstance(idx, int) and 1 <= idx <= sequence_len:  # Ensure index is valid
                binary_list[idx - 1] = 1

    return binary_list

In [5]:
test_file = open("data/development_set/uniprot_test.txt")
test_ids = test_file.readlines()
test_ids = [re.sub("\n", "", t) for t in test_ids]

In [6]:
training_ids = []
for i in range(1, 6):
    fold_file = open(f"data/development_set/ids_split{i}.txt")
    fold_ids = fold_file.readlines()
    fold_ids = [re.sub("\n", "", t) for t in fold_ids]
    training_ids.extend(fold_ids)

print(training_ids)

['Q5LL55', 'H9L4N9', 'O34738', 'P39579', 'P01887', 'O32221', 'P0CL67', 'A7VAB4', 'O60895', 'P86179', 'P58568', 'Q9Y3B4', 'Q9NWV4', 'Q9KQN0', 'P85511', 'Q2VE61', 'P18138', 'Q9H5X1', 'Q7Z4H3', 'B2FQ63', 'Q05097', 'P20116', 'F6KMV5', 'A9CID9', 'P46926', 'D0VWR5', 'O34918', 'Q8TLY9', 'Q81HL8', 'Q9RJC1', 'P15570', 'Q9RY97', 'B3Y002', 'P07386', 'P77754', 'P94690', 'C9K1X5', 'P04382', 'O32108', 'P32021', 'P17900', 'P43933', 'P0AE05', 'D0VWX2', 'P49789', 'A0A384LKY8', 'P67700', 'Q58830', 'O87496', 'Q52424', 'C3SZN7', 'P50049', 'P16108', 'P05161', 'Q9RN60', 'P0CX80', 'G0Z026', 'P67809', 'P83467', 'P10175', 'F1NZ18', 'Q47765', 'Q82Z21', 'H9NAL3', 'Q9KCV1', 'Q07341', 'E7E815', 'A7UQX3', 'P04038', 'O68396', 'Q8Z4D7', 'Q77DJ5', 'P48061', 'Q8XMB9', 'P27999', 'I0BZV0', 'D1CIZ5', 'Q0P9D1', 'Q74G82', 'P07737', 'Q00277', 'Q96FJ2', 'P02876', 'Q8GGH0', 'P04418', 'P29256', 'Q1Q7P3', 'P03606', 'Q05776', 'P23907', 'P14135', 'P0A6I0', 'Q65NU7', 'Q5LAA6', 'Q38HX3', 'Q6XBH1', 'Q96YV5', 'Q8RQE7', 'O58552', 'P081

In [7]:
binding_sites_df = pd.read_csv("data/development_set/all_binding_sites_complete.csv")
binding_sites_df['binding_sites'] = binding_sites_df['binding_sites'].apply(ast.literal_eval)

In [21]:
training_binding_sites_df = binding_sites_df.loc[(binding_sites_df['prot_id'].isin(training_ids))]
display(training_binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
6,P16113,"[31, 42, 46, 50, 52, 62, 66, 67, 68, 69, 70, 8...",small,MEHVAFGSEDIENTLAKMDDGQLDGLAFGAIQLDGDGNILQYNAAE...,125,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
14,O76242,"[11, 12, 22, 25, 26, 45, 48, 49, 52, 53, 66, 6...",small,MVNWAAVVDDFYQELFKAHPEYQNKFGFKGVALGSLKGNAAYKTQA...,110,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, ..."
39,P26789,"[1, 2, 3, 6, 7, 10, 15, 17, 18, 19, 20, 22, 23...",small,MNQGKIWTVVNPAIGIPALLGSVTVIAILVHLAILSHTTWFPAYWQ...,53,"[1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, ..."
40,P26790,"[33, 37, 39, 15, 16, 18, 19, 20, 21, 22, 23, 2...",small,ATLTAEQSEELHKYVIDGTRVFLGLALVAHFLAFSATPWLH,41,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ..."
52,P21149,"[45, 46, 48, 50, 51, 52, 53, 83, 85, 55, 31]",small,MLSQVCRFGTITAVKGGVKKQLKFEDDQTLFTVLTEAGLMSADDTC...,100,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
18978,Q6N5V5,"[56, 59, 60, 63]",small,MTGPKQQPLPPDVEGREDAIEVLRAFVLDGGLSIAFMRAFEDPEMW...,98,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18981,C4LSE7,"[7, 8, 9, 10, 11, 44, 45, 12, 13, 123, 125, 126]",small,MKLLFVCLGNICRSPAAEAVMKKVIQNHHLTEKYICDSAGTCSYHE...,157,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, ..."
19012,Q0B0G9,"[143, 144, 145, 30, 44, 46, 47, 48, 49, 50, 62...",small,MLESVRKEWLEIMDRELLEKARSLINANYISTTLSTVDRNYEVNIA...,147,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
19047,Q5SHE1,"[129, 130, 74, 107, 75, 14, 111, 48, 49, 47, 126]",small,MDLTHFQDGRPRMVDVTEKPETFRTATAEAFVELTEEALSALEKGG...,157,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."


In [9]:
testing_binding_sites_df = binding_sites_df.loc[(binding_sites_df['prot_id'].isin(test_ids))]
display(testing_binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
4,P46859,"[160, 159, 163, 40, 17, 18, 19, 20, 21, 22, 23...",small,MSTTNHDHHIYVLMGVSGSGKSAVASEVAHQLHAAFLDGDFLHPRR...,175,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
10,Q9HAN9,"[269, 14, 15, 16, 17, 23, 24, 27, 155, 156, 15...",small,MENSEKTEVVLLACGSFNPITNMHLRLFELAKDYMNGTGRYTVVKG...,279,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, ..."
29,P62799,"[64, 103, 85, 89, 60, 61]",small,MSGRGKGGKGLGKGGAKRHRKVLRDNIQGITKPAIRRLARRGGVKR...,103,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
30,P06897,"[65, 66, 37, 38, 40, 17, 18, 93, 58, 61, 62]",small,MSGRGKQGGKTRAKAKTRSSRAGLQFPVGRVHRLLRKGNYAERVGA...,130,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
31,P02281,"[35, 72, 106, 107, 109, 110, 83, 84]",small,MPEPAKSAPAPKKGSKKAVTKTQKKDGKKRRKSRKESYAIYVYKVL...,126,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
18797,P76143,"[225, 226, 227, 203, 107, 57, 58, 251, 252, 25...",small,MADLDDIKDGKDFRTDQPQKNIPFTLKGCGALDWGMQSRLSRIFNP...,291,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18799,Q979C2,"[7, 8, 9, 10, 14, 16, 17, 20, 118, 119, 120, 9...",small,MIRVMATGVFDILHLGHIHYLKESKKLGDELVVVVARDSTARNNGK...,142,"[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, ..."
18846,Q9A585,"[128, 65, 162, 67, 68, 69, 165, 71, 155, 86, 1...",small,MAGEDADMIDRRTMLMAGGLAMTTTMAKAGEGTDAIQALIQAYFTA...,174,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18877,A2I2W2,"[164, 134, 166, 136, 169, 170, 171, 204, 173, ...",small,MKKIIFSALCALPLIVSLTSCGKKKDEPNQPSTPEAVTKTVTIDAS...,216,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [23]:
print(training_binding_sites_df['ligand_type'].value_counts())
print(testing_binding_sites_df['ligand_type'].value_counts())

ligand_type
small      606
metal      455
nuclear    108
Name: count, dtype: int64
ligand_type
small      220
metal      124
nuclear     66
Name: count, dtype: int64


In [25]:
grouped_training_binding_sites_df = training_binding_sites_df.groupby('prot_id').agg(list)
multiple_ligands_binding_df = grouped_training_binding_sites_df[
    grouped_training_binding_sites_df['ligand_type'].apply(
        lambda x: len(set(x)) > 1
    )
]
display(grouped_training_binding_sites_df)

Unnamed: 0_level_0,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
prot_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
A0A0F7RDM3,"[[129, 135, 137, 53, 29]]",[metal],[MEIRKKLVVPSKYGTKCPYTMKPKYITVHNTYNDAPAENEVNYMI...,[234],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A0A0F7RHX8,"[[67, 133, 135, 72, 76, 78, 123]]",[small],[MHLKEKITTIIQGQRTGVLSTVRNDKPHSAFMMFFHEDFVLYVAT...,[138],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A0A0H2W6Y8,"[[129, 197, 198, 207, 144, 85, 214, 119, 213, ...",[metal],[MMKILGLIGGMSWESTIPYYRMINQHVKAQLGGLHSAKIILYSVD...,[231],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A0A2B6C3P9,"[[111, 79, 114, 117, 122, 125]]",[metal],[MSINKWLFRFIGFLVMLVVITTLNSLNVFASVNDLAQPIASAKVI...,[129],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A0A384LKY8,"[[136, 76, 12], [5, 72, 136, 10, 43, 76, 12, 1...","[metal, small]",[MSEHFVGKYEVELKFRVMDLTTLHEQLVAQKATAFTLNNHEKDIY...,"[179, 179]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,..."
...,...,...,...,...,...
S0BAP9,"[[128, 129, 132, 56, 66, 69, 70, 72, 73, 77, 8...",[small],[MNPLSTVLLVLCATSAALASEFCSEADATIVIKQWNQIYNAGIGA...,[168],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
S3TFW2,"[[192, 193, 75, 204, 79, 277, 278, 279, 280, 2...",[small],[MTRPDSKSMNYQLLKTFSRQPIQFGRFLARLLAGLVNTLKITRTS...,[327],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
V6F235,"[[232, 230, 255]]",[small],[MRKSGCAVCSRSIGWVGLAVSTVLMVMKAFVGLIGGSQAMLADAM...,[318],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
V9P0A9,"[[257, 133, 135, 136, 13, 33, 34, 37, 38, 42, ...",[small],[MADFKFEPMRSLIYVDCVSEDYRPKLQRWIYKVHIPDSISQFEPY...,[283],"[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,..."


In [26]:
display(multiple_ligands_binding_df)

Unnamed: 0_level_0,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
prot_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
A0A384LKY8,"[[136, 76, 12], [5, 72, 136, 10, 43, 76, 12, 1...","[metal, small]",[MSEHFVGKYEVELKFRVMDLTTLHEQLVAQKATAFTLNNHEKDIY...,"[179, 179]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,..."
A1TY68,"[[64, 58, 60, 94], [35, 106, 108, 45, 47, 30]]","[metal, small]",[MEINADFTKPVVIDTDQLEWRPSPMKGVERRMLDRIGGEVARATS...,"[222, 222]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A4GRC7,"[[6, 7, 8, 15, 16, 18, 19, 20, 21, 23, 28, 29,...","[small, metal]",[MKLSLMVAISKNGVIGNGPDIPWSAKGEQLLFKAITYNQWLLVGR...,"[157, 157]","[[0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1,..."
A4GRE3,"[[25, 21], [96, 33, 2, 3, 65, 37, 107, 109, 50...","[metal, small]",[MPMVRVATNLPDKDVPANFEERLTDLLAESMNKPRNRIAIEVLAG...,"[119, 119]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
A7ATL3,"[[65, 66], [34, 69, 70, 46, 48, 49]]","[metal, small]",[MVSKSIVEERLRSMLSPQFLKVTDNSGGCGAAFNAYIVSQQFEGK...,"[86, 86]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
...,...,...,...,...,...
Q9UT12,"[[186, 132, 134], [132, 134, 144, 210, 212, 18...","[metal, small]",[MLYENMSDSFLLSDAGLEFDEALLEVDQEKDDYLDDFENWTVVPV...,"[225, 225]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
Q9X113,"[[68, 102, 70, 118, 23, 61, 63], [68, 102, 70,...","[small, metal]",[MKEGTGMVVRSSEITPERISNMRGGKGEVEMAHLLSKEAMHNKAR...,"[121, 121]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
Q9X1W7,"[[42, 108], [132, 37, 38, 134, 40, 41, 42, 43,...","[metal, small]",[MRHLRFENLTEEQLKRLAKILTENLKGGEVVILSGNLGAGKTTFV...,"[161, 161]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."
Q9XG81,"[[64, 65, 132, 136, 106, 83, 86, 54, 87, 62, 6...","[small, metal]",[MRFFLKLAPRCSVLLLLLLVTASRGLNIGDLLGSTPAKDQGCSRT...,"[153, 153]","[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,..."


### Test functions

In [30]:
multiple_ligands_binding_df['ligand_type'].value_counts()

ligand_type
[metal, small]             74
[small, metal]             54
[nuclear, metal]           15
[metal, nuclear]            4
[nuclear, small]            2
[nuclear, metal, small]     2
[nuclear, small, metal]     1
Name: count, dtype: int64

In [28]:
# Get samples
binding_sites = multiple_ligands_binding_df.iloc[0]['binding_sites']
ligand_types = multiple_ligands_binding_df.iloc[0]['ligand_type']

print(binding_sites)
print(ligand_types)

[[136, 76, 12], [5, 72, 136, 10, 43, 76, 12, 14, 111, 172, 113, 173, 83, 140, 85, 74, 63]]
['metal', 'small']


In [None]:
# 0: No-binding, 1: Metal, 2: small, 3: nuclear, 4: metal+small, 5: metal+nuclear, 6: small+nuclear, 7: all 3