In [1]:
import json
import os
from glob import glob

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import src as sp
from runtime_models import load_structure

In [2]:
# load structures clustering (one row per structure ID "sid", plus CATH / sequence cluster columns)
df_bio_clust = pd.read_csv("datasets/biounits_cath_seq_cluster.csv")
sids = df_bio_clust['sid'].values

# training / testing set of PeSTo: both are merged into the exclusion list so that
# the benchmark contains nothing PeSTo has seen during training or evaluation
sids_train = np.genfromtxt("datasets/subunits_train_set.txt", dtype=np.dtype("U"))
sids_valid = np.genfromtxt("datasets/subunits_test_set.txt", dtype=np.dtype("U"))
sids_train = np.concatenate([sids_train, sids_valid])


def _to_sid(raw_id, sep):
    """Normalise an external chain identifier "<pdbid><sep><chain>" to "<PDBID>_<chain>".

    Only the first two `sep`-separated fields are used; the PDB ID is upper-cased,
    the chain ID is kept as-is (chain IDs are case-sensitive).
    """
    parts = raw_id.split(sep)
    return '{}_{}'.format(parts[0].upper(), parts[1])


# load ESM-IF1 dataset split (context manager so the file handle is closed)
with open("external_sets/esm-if1_splits.json", 'r') as fh:
    esm_splits = json.load(fh)

# training set of ESM (train + validation, identifiers written as "<pdbid>.<chain>")
sids_esm_train = np.array([_to_sid(sid, '.') for sid in esm_splits['train']])
sids_esm_valid = np.array([_to_sid(sid, '.') for sid in esm_splits['validation']])
sids_esm_train = np.concatenate([sids_esm_train, sids_esm_valid])

# ProteinMPNN info: every chain not belonging to a test cluster counts as training data
df_mpnn = pd.read_csv("external_sets/pdb_2021aug02/list.csv")[['CHAINID', 'CLUSTER']]
cids_mpnn = np.genfromtxt("external_sets/pdb_2021aug02/test_clusters.txt", dtype=int)
df_mpnn_train = df_mpnn[~df_mpnn['CLUSTER'].isin(cids_mpnn)]

# training set of ProteinMPNN (identifiers written as "<pdbid>_<chain>")
sids_mpnn_train = np.array([_to_sid(sid, '_') for sid in df_mpnn_train['CHAINID'].values])

# union of the 3 training sets
sids_sset_train = np.unique(np.concatenate([sids_train, sids_esm_train, sids_mpnn_train]))

# debug
df_bio_clust.shape, sids_train.shape, sids_esm_train.shape, sids_mpnn_train.shape, sids_sset_train.shape

((363788, 26), (473641,), (18228,), (537616,), (676377,))

In [3]:
# Keep only the training identifiers that have clustering information; the dict
# serves purely as an O(1) membership lookup.
sids_map = dict.fromkeys(sids, 0)
sids_match_train = np.array([s for s in sids_sset_train if s in sids_map])

# Partition the clustered structures: seen in training by any method vs. never seen.
is_train = df_bio_clust['sid'].isin(sids_match_train)
df_bio_clust_train = df_bio_clust[is_train]
df_bio_clust_nottrain = df_bio_clust[~is_train]

# debug
df_bio_clust_nottrain

Unnamed: 0,sid,C,A,T,H,S05,S15,S10,S25,S20,...,S65,S60,S75,S70,S85,S80,S95,S90,S100,date
52,12E8_H,5,1,1,1,5757,5757,5757,5813,5759,...,43740,62278,64947,65866,10024,9700,28816,27338,30,1998-03-14
53,12E8_L,5,1,1,1,5757,5757,5757,5813,5759,...,43740,62278,47543,33899,49106,40536,55609,48689,31,1998-03-14
54,12E8_M,5,1,1,1,5757,5757,5757,5813,5759,...,43740,62278,47543,33899,49106,40536,55609,48689,31,1998-03-14
55,12E8_P,5,1,1,1,5757,5757,5757,5813,5759,...,43740,62278,64947,65866,10024,9700,28816,27338,30,1998-03-14
100,15C8_H,5,1,1,1,5757,5757,5757,5813,5759,...,43740,62278,64947,65866,10024,9700,8,37993,62,1998-03-18
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
363246,6RFK_S,8,1,1,1,30532,30532,30532,30984,30539,...,5055,4870,5453,5276,5851,5635,7499,7005,94797,2019-04-15
363510,6ROG_X,8,1,1,1,31271,31271,31271,31733,31277,...,10899,10520,11586,11272,12291,11908,68607,64475,95019,2019-05-13
363523,6RPV_A,8,1,1,1,5802,5802,5802,5858,5804,...,2110,2016,2291,2205,2489,2389,2861,2652,95016,2019-05-14
363524,6RQ7_B,8,1,1,1,29376,29376,29376,29816,29382,...,43225,41809,45688,44484,48520,47036,53681,50662,75862,2019-05-15


In [4]:
# parameters: columns whose joined values define a cluster
# (CATH class/architecture/topology/homology + "S30", presumably the 30% sequence-identity cluster — confirm)
ckeys = ['C', 'A', 'T', 'H', 'S30']


def _group_sids_by_cluster(df, keys):
    """Group structure IDs by cluster.

    Returns a dict mapping the cluster key (the values of `keys` for a row, each
    stringified and joined with '-') to the list of that cluster's `sid`s. Both the
    dict keys and each list follow the row order of `df`.
    """
    clusters = {}
    for row in tqdm(df.to_dict("records")):
        key = '-'.join(str(row[k]) for k in keys)
        clusters.setdefault(key, []).append(row['sid'])
    return clusters


# extract clusters of the training structures and of the remainder
clust_train_map = _group_sids_by_cluster(df_bio_clust_train, ckeys)
clust_nottrain_map = _group_sids_by_cluster(df_bio_clust_nottrain, ckeys)

# cluster identifiers
ckeys_train = np.array(list(clust_train_map))
ckeys_nottrain = np.array(list(clust_nottrain_map))

# clusters with no member in the training set of any method, and their structure IDs
clust_test = ckeys_nottrain[~np.isin(ckeys_nottrain, ckeys_train)]
sids_test = np.array([sid for ckey in clust_test for sid in clust_nottrain_map[ckey]])

# debug
print(len(clust_test), len(sids_test))

100%|██████████| 354260/354260 [00:00<00:00, 854723.83it/s]
100%|██████████| 9528/9528 [00:00<00:00, 846896.00it/s]


72 228


In [5]:
# statistics
df_test = pd.DataFrame([{'pdbid':s.split('_')[0], 'cid':s.split('_')[1]} for s in sids_test])

# count subunits
df_test['num_subs'] = df_test.groupby('pdbid')['pdbid'].transform('size')

df_test

Unnamed: 0,pdbid,cid,num_subs
0,1ABO,A,2
1,1ABO,B,2
2,1ABQ,A,1
3,1AWO,A,1
4,1BBZ,A,4
...,...,...,...
223,6J5J,u,3
224,6KBR,C,1
225,6P6B,A,1
226,6PNW,A,2


In [6]:
# Every PDB entry with k selected chains contributes k rows with num_subs == k,
# so dividing each row count by k yields the number of PDB entries of that size.
sc = df_test['num_subs'].value_counts()
entries_per_size = sc / sc.index
entries_per_size.astype(int)

1    76
2    37
3    13
4     7
6     1
5     1
dtype: int64

In [7]:
# manual exclusion
pdbid_excluded = ['6H82', '6H9C']  # too big / slow to load

# output folders, created up front so the notebook runs on a fresh checkout
single_dir = "benchmark_data/single"
dimer_dir = "benchmark_data/dimer"
os.makedirs(single_dir, exist_ok=True)
os.makedirs(dimer_dir, exist_ok=True)


def _find_biounit(pdbid):
    """Return the path of the ``.pdb1.gz`` biounit file for `pdbid`.

    Raises FileNotFoundError naming the searched pattern (instead of a bare IndexError).
    """
    pattern = "data/all_biounits_v2/{}/{}.pdb1.gz".format(pdbid[1:3].lower(), pdbid.lower())
    matches = glob(pattern)
    if not matches:
        raise FileNotFoundError("no biounit file matching {}".format(pattern))
    return matches[0]


def _save_fasta(filepath, header, seq):
    """Write a single-record FASTA file: '>' + header, newline, then the sequence."""
    with open(filepath, 'w') as fs:
        fs.write(">{}\n".format(header) + seq)


for pdbid in tqdm(df_test['pdbid'].unique()):
    # excluded
    if pdbid in pdbid_excluded:
        continue

    # load structure
    pdb_filepath = _find_biounit(pdbid)
    subunits = load_structure(pdb_filepath)

    # selected chains, restricted to those present in the loaded structure
    # (subunit keys are compared on the part before ':')
    cids = df_test[df_test['pdbid'] == pdbid]['cid'].values
    cids = cids[np.isin(cids, np.unique([c.split(':')[0] for c in list(subunits)]))]

    # the "<chain>:0" subunit of each selected chain and its sequence
    chains = {cid: subunits["{}:0".format(cid)] for cid in cids}
    seqs = {cid: sp.subunit_to_sequence(chains[cid]) for cid in cids}

    # save subunits alone (chains with an empty sequence are skipped)
    for cid in cids:
        if len(seqs[cid]) > 0:
            _save_fasta("{}/{}_{}.fasta".format(single_dir, pdbid, cid), "{}_{}".format(pdbid, cid), seqs[cid])
            sp.save_pdb({cid: chains[cid]}, "{}/{}_{}.pdb".format(single_dir, pdbid, cid))

    # save dimers only when BOTH chains have a sequence; the former check on the
    # joined "seqA:seqB" string (len > 1) let through dimers with one empty chain
    if len(cids) == 2 and all(len(seqs[cid]) > 0 for cid in cids):
        _save_fasta("{}/{}.fasta".format(dimer_dir, pdbid), pdbid, ':'.join(seqs[cid] for cid in cids))
        sp.save_pdb(chains, "{}/{}.pdb".format(dimer_dir, pdbid))
        

100%|██████████| 135/135 [00:38<00:00,  3.53it/s]
