In [1]:
import os
import requests
import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
# parameters input
# URL of the sequence-based cluster file fetched over HTTP by the notebook
# (presumably the RCSB 30% sequence identity clustering, judging by the
# "bc-30" file name -- confirm)
pdb_clusters_url = "https://cdn.rcsb.org/resources/sequence/clusters/bc-30.out"
# text files listing, one entry per line, the structures to keep out of the
# training set; only the part of each entry before the first '_' (the pdbid)
# is used when the lists are read
training_exclusion_lists = [
    "data/lists/ppdb5_set.txt",
    "data/lists/masif-site_test_set.txt",
    "data/lists/skempi_v2.txt",
    "data/lists/memcplxdb.txt",
    "data/lists/excluded.txt"
]

# parameters output
# seed NumPy's global random state so that the np.random.shuffle call used for
# the train/test split is reproducible when the notebook is run top to bottom
np.random.seed(1337)
# fraction of the (non-excluded) clusters assigned to the training set; the
# remaining clusters form the testing set
train_ratio = 0.8
# not used by the cells visible here; presumably the directory where the
# resulting subunit lists are written -- confirm
output_dir = "data/datasets"

In [3]:
# fetch sequence based clusters
# - the timeout (seconds) prevents the cell from hanging forever on a stalled
#   connection (requests waits indefinitely by default)
# - raise_for_status() raises requests.HTTPError carrying the status code and
#   URL, and unlike a bare assert it is not stripped when running with -O
r = requests.get(pdb_clusters_url, timeout=60)
r.raise_for_status()
raw_data = r.text

# extract clusters: one cluster per non-blank line, members separated by
# whitespace; str.split() without argument also discards the empty tokens that
# repeated/trailing spaces or a stray '\r' would otherwise produce
pdb_clusters = []
for line in raw_data.splitlines():
    members = line.split()
    if len(members) > 0:
        pdb_clusters.append(members)

# fail with an explicit message instead of the generic numpy error raised by
# np.concatenate on an empty list
if len(pdb_clusters) == 0:
    raise ValueError(f"no clusters could be parsed from {pdb_clusters_url}")

# get all structure ids (flat array, in cluster order)
sids = np.concatenate(pdb_clusters)

# debug print
print(f"{len(pdb_clusters)} clusters for {sum(len(pdb_cluster) for pdb_cluster in pdb_clusters)} pdbs")

37681 clusters for 567273 pdbs


In [4]:
# exclude pdbs from training set based on pdbids (chain name may not be reliable)
pdbids_excluded = []
for fp in training_exclusion_lists:
    with open(fp, 'r') as fs:
        for line in fs:
            # keep only the pdbid, i.e. the part before the first '_'
            pdbid = line.strip().split('_')[0]
            # lines read from a file keep their trailing newline, so emptiness
            # must be tested after stripping: an empty id coming from a blank
            # line would otherwise match (and exclude) every single subunit
            if len(pdbid) > 0:
                pdbids_excluded.append(pdbid)

# unique pdbids (sorted numpy array)
pdbids_excluded = np.unique(pdbids_excluded)

# exclude from pdbids
# a subunit is excluded when the pdbid part of its id is exactly one of the
# excluded pdbids; the set lookup replaces a nested loop over all
# (subunit, pdbid) pairs and appends each matching subunit only once
pdbids_excluded_set = set(pdbids_excluded.tolist())
training_excluded = []
for sid in tqdm(sids):
    if sid.split('_')[0] in pdbids_excluded_set:
        training_excluded.append(sid)

# debug print
print(f"{len(training_excluded)} subunits excluded ({len(pdbids_excluded)} pdbids)")

100%|██████████| 567273/567273 [02:26<00:00, 3870.54it/s]

3281 subunits excluded (888 pdbids)





In [5]:
# define clusters mapping: subunit id -> index of its cluster in pdb_clusters
# (if an id were listed in several clusters, the last one wins)
clusters_mapping = {}
for k, pdb_cluster in enumerate(pdb_clusters):
    for pdbid in pdb_cluster:
        clusters_mapping[pdbid] = k

# assign clusters: group the subunit ids by cluster index, keeping the order in
# which clusters are first encountered (dicts preserve insertion order)
sid_clusters_dict = {}
not_located = []
for sid in sids:
    if sid in clusters_mapping:
        sid_clusters_dict.setdefault(clusters_mapping[sid], []).append(sid)
    else:
        not_located.append(sid)

# define clusters exclusion list: a cluster is excluded as a whole as soon as
# one of its subunits belongs to the excluded subunits
clusters_exclusion_l = [clusters_mapping[sid] for sid in training_excluded if sid in clusters_mapping]
not_located_exclusion = [sid for sid in training_excluded if sid not in clusters_mapping]

# set copy for constant-time membership tests; testing against the list would
# rescan it for every cluster
clusters_exclusion_set = set(clusters_exclusion_l)

# transform dict into list of list
sid_clusters = [members for k, members in sid_clusters_dict.items() if k not in clusters_exclusion_set]
sid_clusters_excluded = [members for k, members in sid_clusters_dict.items() if k in clusters_exclusion_set]

# debug print
print(f"extracted: {len(sid_clusters)} clusters / {sum(len(sid_cluster) for sid_cluster in sid_clusters)} subunits / {len(not_located)} subunits not located")
print(f"excluded: {len(sid_clusters_excluded)} clusters / {sum(len(sid_cluster) for sid_cluster in sid_clusters_excluded)} subunits / {len(not_located_exclusion)} subunits not located")

extracted: 36722 clusters / 479588 subunits / 0 subunits not located
excluded: 959 clusters / 87685 subunits / 0 subunits not located


In [6]:
# random permutation of the cluster indices
# note: the shuffle draws from NumPy's global random state (seeded in the
# parameters cell), so re-running this cell on its own gives a different split
N = len(sid_clusters)
ids = np.arange(N)
np.random.shuffle(ids)

# the split is made on clusters, not on subunits: the first train_ratio
# fraction of the shuffled clusters goes to training, the rest to testing
n = int(N*train_ratio)
ids_train, ids_test = ids[:n], ids[n:]

# flatten the selected clusters into plain lists of subunit ids
train_sids = [sid for i in ids_train for sid in sid_clusters[i]]
test_sids = [sid for i in ids_test for sid in sid_clusters[i]]

# every subunit of the excluded clusters ends up in the validation set
# (np.concatenate cannot be called on an empty list, hence the guard)
if len(sid_clusters_excluded) > 0:
    valid_sids = list(np.concatenate(sid_clusters_excluded))
else:
    valid_sids = []

# debug print
print(f"training dataset: {len(ids_train)} clusters / {len(train_sids)} subunits")
print(f"testing dataset: {len(ids_test)} clusters / {len(test_sids)} subunits")
print(f"validation dataset: {len(sid_clusters_excluded)} clusters / {len(valid_sids)} subunits")

training dataset: 29377 clusters / 381068 subunits
testing dataset: 7345 clusters / 98520 subunits
validation dataset: 959 clusters / 87685 subunits
