In [1]:
import pandas as pd
from Bio import SeqIO
from pprint import pprint
from Bio.SeqUtils import ProtParam
import numpy as np

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.metrics import precision_score, recall_score, f1_score, matthews_corrcoef

### Data pre-processing

In [2]:
def load_fasta(fasta_file):
    return {record.id: str(record.seq) for record in SeqIO.parse(fasta_file, "fasta")}

def load_binding_sites(file_path):
    binding_sites = {}
    with open(file_path, "r") as f:
        for line in f:
            protein_id, sites = line.strip().split("\t")
            binding_sites[protein_id] = [int(site) for site in sites.split(",")]
    return binding_sites

In [6]:
proteins = load_fasta("data/development_set/all.fasta")
metal_sites = load_binding_sites(
    "data/development_set/binding_residues_2.5_metal.txt"
)
nuclear_sites = load_binding_sites(
    "data/development_set/binding_residues_2.5_nuclear.txt"
)
small_sites = load_binding_sites(
    "data/development_set/binding_residues_2.5_small.txt"
)

# print first few lines of each dictionary
pprint(list(proteins.items())[:3])
pprint(list(metal_sites.items())[:3])
print(list(nuclear_sites.items())[:3])
print(list(small_sites.items())[:3])

[('Q5LL55',
  'MSETWLPTLVTATPQEGFDLAVKLSRIAVKKTQPDAQVRDTLRAVYEKDANALIAVSAVVATHFQTIAAANDYWKD'),
 ('H9L4N9',
  'MQINIQGHHIDLTDSMQDYVHSKFDKLERFFDHINHVQVILRVEKLRQIAEATLHVNQAEIHAHADDENMYAAIDSLVDKLVRQLNKHKEKLSSH'),
 ('O34738',
  'MKSWKVKEIVIMSVISIVFAVVYLLFTHFGNVLAGMFGPIAYEPIYGIWFIVSVIAAYMIRKPGAALVSEIIAALVECLLGNPSGPMVIVIGIVQGLGAEAVFLATRWKAYSLPVLMLAGMGSSVASFIYDLFVSGYAAYSPGYLLIMLVIRLISGALLAGLLGKAVSDSLAYTGVLNGMALGKELKKKRKRASEHASL')]
[('P02185', [37, 30, 36, 39, 120, 69, 63, 42, 44, 25, 123, 60, 13, 65]),
 ('P09211', [30, 8, 31, 86, 14, 148, 117, 118, 82, 78, 114]),
 ('P00817', [153, 121, 57, 79, 193, 116, 155, 102, 148, 118, 94, 59])]
[('P00698', [28, 19, 144, 24, 25, 23, 146, 145]), ('P08046', [379, 350, 377, 375, 418, 380, 334, 359, 406, 382, 408, 401, 410, 354, 347, 420, 332, 349, 373, 378, 345, 355, 405, 348, 383, 364, 386, 407, 412, 384, 403, 387, 376, 352, 411, 366, 356, 358, 415, 351]), ('P27958', [1418, 1295, 1324, 1397, 1527, 1298, 1439, 1256, 1297, 1458, 1281, 1396, 1582, 1460, 1474, 1

In [7]:
all_binding_sites = {}
for protein_id in proteins:
    all_binding_sites[protein_id] = set(
        metal_sites.get(protein_id, [])
        + nuclear_sites.get(protein_id, [])
        + small_sites.get(protein_id, [])
    )

In [9]:
# Load fold splits
folds = []
for i in range(1, 6):
    with open(f"data/development_set/ids_split{i}.txt", "r") as f:
        folds.append([line.strip() for line in f])

# print first few lines of each fold
pprint(folds[0][:3])

# print data size of each fold
print([len(fold) for fold in folds])

['Q5LL55', 'H9L4N9', 'O34738']
[203, 203, 203, 203, 202]


In [10]:
def rename_sequence(sequence):
    return sequence.replace("U", "").replace("X", "")

def extract_features(sequence):
    sequence = rename_sequence(sequence)

    # Basic features
    pp = ProtParam.ProteinAnalysis(sequence)

    # Amino acid composition
    aa_comp = pp.get_amino_acids_percent()

    # Secondary structure prediction (simplified)
    ss_pred = pp.secondary_structure_fraction()

    # Hydrophobicity
    hydrophobicity = pp.gravy()

    # Isoelectric point
    isoelectric = pp.isoelectric_point()

    return list(aa_comp.values()) + list(ss_pred) + [hydrophobicity, isoelectric]


def prepare_data(protein_ids, proteins, all_binding_sites):
    X, y = [], []
    for protein_id in protein_ids:
        sequence = proteins[protein_id]
        features = extract_features(sequence)
        binding_sites = all_binding_sites.get(protein_id, set())

        for i, aa in enumerate(sequence):
            X.append(features + [i + 1])  # Add position as a feature
            y.append(1 if i + 1 in binding_sites else 0)

    return np.array(X), np.array(y)

In [18]:
sample_sequence = "MSETWLPTLVTATPQEGFDLAVKLSRIAVKKTQPDAQVRDTLRAVYEKDANALIAVSAVVATHFQTIAAANDYWKD"
# sample_sequence = rename_sequence(sample_sequence)
features = extract_features(sample_sequence)

In [20]:
# Prepare data for all folds
X_all, y_all = prepare_data(
    [id for fold in folds for id in fold], proteins, all_binding_sites
)

print(X_all.shape)
print(y_all.shape)

(170686, 26)
(170686,)


### Model Training


In [21]:
rf = RandomForestClassifier(n_estimators=100, random_state=42)
scores = cross_val_score(rf, X_all, y_all, cv=5, scoring="f1")
print(f"Cross-validation F1 scores: {scores}")
print(f"Mean F1 score: {scores.mean()}")

Cross-validation F1 scores: [0.01478743 0.00274307 0.00508331 0.00071403 0.00708079]
Mean F1 score: 0.006081724529713235


In [22]:
rf.fit(X_all, y_all)

### Prediction on test set

In [28]:
test_file = 'data/development_set/uniprot_test.txt'
with open(test_file, 'r') as f:
    test_prot_ids = [id.replace('\n', '') for id in f.readlines()] 

print(test_prot_ids)

['D0VX23', 'B7J6R7', 'Q3IDI7', 'E9JSA3', 'P12734', 'P80547', 'P39621', 'Q8KRV3', 'Q747J2', 'P10245', 'P60848', 'Q9EV85', 'P07603', 'Q03243', 'P10868', 'A6L7X2', 'Q5SK07', 'P40065', 'Q7A1N5', 'Q44501', 'Q5NV90', 'Q6D2K4', 'P63165', 'Q9HC16', 'Q8A8Q1', 'P68661', 'Q7VWF8', 'Q01747', 'Q8GCY3', 'Q8ZJW4', 'P56930', 'Q72EF4', 'O29338', 'B7IE18', 'P00426', 'Q9HVI1', 'Q46898', 'Q8DRZ8', 'Q81BL7', 'P15328', 'Q8T6U0', 'P19656', 'Q8LGG8', 'O92323', 'A2I2W2', 'Q13M28', 'Q15369', 'Q9SE33', 'D2Z0P2', 'Q96GG9', 'Q8P4Q6', 'P45696', 'P69687', 'Q57587', 'P0CB20', 'Q929T5', 'B9DL91', 'Q8IJK2', 'G2EA45', 'P29460', 'P07059', 'P80882', 'Q8IKH2', 'C8WS74', 'Q9KFV3', 'Q9BVM4', 'Q9NVS9', 'A0L5S6', 'B7TYB2', 'G0RUC2', 'Q06151', 'Q96YD0', 'Q8DJ43', 'Q3E840', 'Q9A585', 'P00698', 'P00648', 'P00766', 'P05373', 'P00651', 'P22636', 'P03050', 'P0ACH5', 'P19793', 'P56406', 'Q05599', 'P16184', 'P63159', 'P19080', 'P07445', 'P06956', 'P09184', 'P40347', 'Q55389', 'O15527', 'Q5SJ80', 'P50861', 'P02263', 'P84229', 'P62801',

In [29]:
X_test, _ = prepare_data(test_prot_ids, proteins, {})
y_pred_test = rf.predict(X_test)



In [32]:
print(list(y_pred_test))

[np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0), np.int64(0)