In [1]:
from Bio import SeqIO
from transformers import AutoModel, AutoTokenizer, AutoModelForMaskedLM, EsmForMaskedLM, EsmTokenizer
from pathlib import Path
from pprint import pprint

import proteinbert
import os
import torch
import torch.nn as nn
import pandas as pd
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm
2025-01-31 18:21:18.428437: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1738322478.578383   22806 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1738322478.625797   22806 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-01-31 18:21:18.880320: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
def parse_fasta(fasta_file):
    sequence_records = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequence_records.append((record.id, str(record.seq)))
    return sequence_records

def parse_labels(label_file, ligand_type):
    binding_labels = []
    with open(label_file, "r") as f:
        for line in f:
            protein_id, indices = line.strip().split("\t")
            binding_labels.append((protein_id, set(map(int, indices.split(",")))))
    
    binding_labels_df = pd.DataFrame(binding_labels, columns=['prot_id', 'binding_sites'])
    binding_labels_df['ligand_type'] = ligand_type
    return binding_labels_df

In [3]:
fasta_file = "data/development_set/all.fasta"
metal_label_file = "data/development_set/binding_residues_2.5_metal.txt"
nuclear_label_file = "data/development_set/binding_residues_2.5_nuclear.txt"
small_label_file = "data/development_set/binding_residues_2.5_small.txt"

seq_records = parse_fasta(fasta_file)
metal_binding_sites_df = parse_labels(metal_label_file, 'metal')
nuclear_binding_sites_df = parse_labels(nuclear_label_file, 'nuclear')
small_binding_sites_df = parse_labels(small_label_file, 'small')

In [4]:
sequences_df = pd.DataFrame(seq_records, columns=['prot_id', 'sequence'])

In [5]:
all_binding_sites_df = pd.concat([metal_binding_sites_df, nuclear_binding_sites_df, small_binding_sites_df],
                                 ignore_index=True)

In [6]:
seq_dict = {}
for record in SeqIO.parse(fasta_file, "fasta"):
    seq_dict[record.id] = str(record.seq)

In [7]:
all_binding_sites_df['sequence'] = all_binding_sites_df['prot_id'].map(seq_dict)
all_binding_sites_df['binding_sites'] = all_binding_sites_df['binding_sites'].apply(lambda x: list(x))
all_binding_sites_df['sequence_length'] = all_binding_sites_df['sequence'].apply(lambda x: len(x) if type(x) == str else 0)

In [8]:
print(all_binding_sites_df.iloc[19052]['sequence'])
print(type(all_binding_sites_df.iloc[19052]['sequence']))

MAASSRAQVLSLYRAMLRESKRFSAYNYRTYAVRRIRDAFRENKNVKDPVEIQTLVNKAKRDLGVIRRQVHIGQLYSTDKLIIENRDMPRT
<class 'str'>


In [9]:
display(all_binding_sites_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length
0,P02185,"[65, 36, 37, 69, 39, 42, 44, 13, 120, 25, 123,...",metal,,0
1,P09211,"[8, 14, 78, 82, 114, 148, 117, 86, 118, 30, 31]",metal,,0
2,P00817,"[193, 121, 102, 59, 79, 116, 148, 118, 153, 15...",metal,,0
3,P01112,"[3, 137, 138, 13, 16, 17, 153, 154, 28, 30, 31...",metal,,0
4,P07378,"[39, 399, 376, 377, 219]",metal,,0
...,...,...,...,...,...
19048,C7G9B5,"[64, 102, 134, 136, 110, 62, 95]",small,,0
19049,Q91159,"[76, 77, 80, 81, 91, 125, 126, 127]",small,,0
19050,Q4K977,"[101, 102, 71, 14, 17, 152, 121, 123, 124]",small,,0
19051,Q9Y697,"[128, 257, 258, 255, 232, 234, 203, 235, 207, ...",small,,0


In [16]:
# sample_lst = [12, 13, 16, 20, 23, 24, 27, 30, 31]
# binary_lst = [0] * 32
# for idx in sample_lst:
#     binary_lst[idx - 1] = 1
# print(binary_lst)

def convert_to_binary_list(original_binding_sites_lst, sequence_len):
    """Convert a Binding-Active site string to a binary list based on the sequence length."""
    binary_list = [0] * sequence_len  # Initialize a list of zeros
    
    # Ensure original_binding_sites_lst is a list and not empty
    if isinstance(original_binding_sites_lst, list) and len(original_binding_sites_lst) > 0:
        for idx in original_binding_sites_lst:
            if isinstance(idx, int) and 1 <= idx <= sequence_len:  # Ensure index is valid
                binary_list[idx - 1] = 1  # Convert to 0-based index

    return binary_list

In [11]:
all_binding_sites_not_null_df = all_binding_sites_df[all_binding_sites_df['sequence'].notna()]
all_binding_sites_not_null_df

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length
6,P00698,"[128, 131, 132, 136, 139, 142, 29, 32, 33, 42,...",metal,MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNW...,147
15,P00648,"[130, 107, 109, 149, 120]",metal,MMKMEGIALKKRLSWISVCLLVLVSAAGMLFSTAAKTETSSHKAHT...,157
38,P19267,"[34, 38]",metal,MELPIAPIGRIIKDAGAERVSDDARITLAKILEEMGRDIASEAIKL...,69
55,P09598,"[160, 166, 39, 107, 52, 180, 184, 156, 93]",metal,MKKKVLALAAAITVVAPLQSVAFAHENDGGSKIKIVHRWSAEDKHK...,283
87,P01574,"[114, 118]",metal,MTNKCLLQIALLLCFSTTALSMSYNLLGFLQRSSNFQCQKLLWQLN...,187
...,...,...,...,...,...
18981,P58568,"[19, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 3...",small,MADKADQSSYLIKFISTAPVAATIWLTITAGILIEFNRFFPDLLFHPLP,49
18983,P58560,"[32, 33, 37, 41, 16, 17, 20, 21, 24, 25, 26, 2...",small,MATAFLPSILADASFLSSIFVPVIGWVVPIATFSFLFLYIEREDVA,46
18985,Q8YNB0,"[12, 13, 16, 20, 23, 24, 27, 30, 31]",small,MSSISDTQVYIALVVALIPGLLAWRLATELYK,32
19026,A9CID9,"[72, 73, 78, 71]",small,MKLVWTLSSWDDYEFWQRTDARMVEKINDLIRNAKRTPFAGLGKPE...,89


In [17]:
all_binding_sites_not_null_df['binary_binding_sites'] = all_binding_sites_not_null_df.apply(
    lambda row: convert_to_binary_list(row['binding_sites'], row['sequence_length']), axis=1
)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  all_binding_sites_not_null_df['binary_binding_sites'] = all_binding_sites_not_null_df.apply(


In [18]:
display(all_binding_sites_not_null_df)

Unnamed: 0,prot_id,binding_sites,ligand_type,sequence,sequence_length,binary_binding_sites
6,P00698,"[128, 131, 132, 136, 139, 142, 29, 32, 33, 42,...",metal,MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNW...,147,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
15,P00648,"[130, 107, 109, 149, 120]",metal,MMKMEGIALKKRLSWISVCLLVLVSAAGMLFSTAAKTETSSHKAHT...,157,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
38,P19267,"[34, 38]",metal,MELPIAPIGRIIKDAGAERVSDDARITLAKILEEMGRDIASEAIKL...,69,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
55,P09598,"[160, 166, 39, 107, 52, 180, 184, 156, 93]",metal,MKKKVLALAAAITVVAPLQSVAFAHENDGGSKIKIVHRWSAEDKHK...,283,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
87,P01574,"[114, 118]",metal,MTNKCLLQIALLLCFSTTALSMSYNLLGFLQRSSNFQCQKLLWQLN...,187,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...
18981,P58568,"[19, 21, 22, 23, 24, 25, 26, 28, 29, 30, 31, 3...",small,MADKADQSSYLIKFISTAPVAATIWLTITAGILIEFNRFFPDLLFHPLP,49,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18983,P58560,"[32, 33, 37, 41, 16, 17, 20, 21, 24, 25, 26, 2...",small,MATAFLPSILADASFLSSIFVPVIGWVVPIATFSFLFLYIEREDVA,46,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
18985,Q8YNB0,"[12, 13, 16, 20, 23, 24, 27, 30, 31]",small,MSSISDTQVYIALVVALIPGLLAWRLATELYK,32,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, ..."
19026,A9CID9,"[72, 73, 78, 71]",small,MKLVWTLSSWDDYEFWQRTDARMVEKINDLIRNAKRTPFAGLGKPE...,89,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
