In [2]:
import pandas as pd
import pickle 
from numpy import random


In [3]:
f = '../data/fireprot/4_fireprotDB_bestpH.csv'

df = pd.read_csv(f).dropna(subset=['ddG'])
df.shape

(3438, 43)

In [16]:
# put seqs into FASTA for clustering

from Bio import SeqIO
from Bio.Seq import Seq

records = {}
for r in df.to_records():
    name = r.pdb_id_corrected
    if name not in records:
        records[name] = SeqIO.SeqRecord(Seq(r.pdb_sequence.replace('X', '-')), id=name, name=name, description='')
    
with open('../data/fireprot/fireprot_proteins.fasta', 'w') as outFile:
    SeqIO.write(list(records.values()), outFile, 'fasta')

# !mmseqs2 easy-cluster ../data/fireprot/fireprot_proteins.fasta ../data/fireprot/fireprot_proteins ../data/fireprot/tmp --min-seq-id 0.25

In [4]:
# TSV has 2 columns: cluster representative and cluster member

clust = pd.read_csv('../data/fireprot/fireprot_proteins_cluster.tsv', sep='\t', header=None, names=['cluster', 'member'])

vcs = clust.cluster.value_counts()
cs = clust.cluster.unique()
# print(len(cs), vcs)

In [13]:
# load tsv with homologue overlap between MegaScale and FireProt

overlap = pd.read_csv('../data/mmseqs_searches/mega_vs_fireprot.m8', sep='\t', header=None)
mega_overlap = overlap.iloc[:, 0].values
fp_overlap = overlap.iloc[:, 1].values
# print(mega_overlap, '\n\n', fp_overlap)

In [18]:
# generate homologue-free split

tmp = df

for m in mega_overlap:
    tmp = tmp[tmp['pdb_id_corrected'] != m.strip('.pdb')]
    
for f in fp_overlap:
    tmp = tmp[tmp['pdb_id_corrected'] != f.strip('.pdb')]

    
# splits = {'train': train_names, 'val': val_names, 'test': test_names}

with open('../data/fireprot/fireprot_splits.pkl', 'rb') as f:
    splits = pickle.load(f)
splits['homologue-free'] = tmp.pdb_id_corrected.unique()

print(splits.keys())


print(splits)

with open('../data/fireprot/fireprot_splits.pkl', 'wb') as f:
    pickle.dump(splits, f)

df['pdb_id_corrected'].isin(splits['homologue-free']).sum()

tmp.to_csv('../data/fireprot/fireprot_homologue_free.csv')

dict_keys(['train', 'val', 'test', 'homologue-free'])
{'train': ['1PGA', '2OCJ', '1ISP', '1LZ1', '4LYZ', '1CYC', '5PTI', '1EY0', '1CEY', '1BVC', '1C2R', '2RN2', '1WQ5', '1POH', '1IGV', '1ZG4', '1CSP', '1MJC', '2LZM', '1BTA', '1HRC', '1QLP', '3MBP', '1AG2', '2ABD', '451C', '1TTG', '1TEN', '2HPR', '1EL1', '1SUP', '1TPK', '1C5G', '1A5E', '1KFW', '1HME', '2A36', '1YU5', '1H7M', '1UZC', '1HK0', '1IO2', '1MSI', '4BLM', '2Q98', '1QJP', '1THQ', '1AQH', '1BCX', '1SSO', '1W3D', '1HGU', '2CHF', '1FMK', '1APS', '1KCQ', '2DRI'], 'val': array(['1OIA', '1IET', '1CYO', '1B5M', '1IFC', '1HMS', '1GV2', '1ONC',
       '1AEP', '1A2P', '2SIL', '3PGK', '2NVH', '1CHK', '5DFR'],
      dtype=object), 'test': array(['1QGV', '1IMQ', '1BRF', '1A23', '1KF2', '1HZ6', '1G4I', '1BNL',
       '1RRO', '1RTP', '1ANK', '1AKY', '2ADA', '1JNX', '1MGR', '1QND',
       '1E0W', '1AYE', '1C52', '1RBP', '1FTG', '2TRX', '1RN1', '2AFG',
       '1KE4', '1FRD', '1FXA', '1CAH'], dtype=object), 'homologue-free': array(['5DFR', '2OCJ'

In [6]:
# generate test/val/train splits

test_size, val_size, train_size = 0, 0, 0
total_size = df.shape[0]

test_df = pd.DataFrame(columns=df.columns)
test_size = test_df.shape[0]

random.seed(0)

used = []
# until size is met:
while test_size < total_size * 0.10:
    pick = random.randint(0, len(cs))
    if pick not in used:
        used.append(pick)
        cluster_picked = cs[pick]
    else:
        continue
    # pick a random cluster

    # add all members of that cluster to train_df
    cluster_all = clust[clust['cluster'] == cluster_picked]
    
    # TODO check if any cluster members are in restricted list (overlap w/FireProt)
    for c in cluster_all.member:
        if c in mega_overlap or c in fp_overlap:
            break
    
    if c in mega_overlap or c in fp_overlap:  # skip adding rows if held in overlap
        continue
    
    for c in cluster_all.member:
        target_rows = df.loc[df['pdb_id_corrected'] == c, :]
        if target_rows.shape[0] > 250:  # we don't want large proteins in the val/test set skewing performance estimates
            continue
        test_df = pd.concat([test_df, target_rows])
        test_size = test_df.shape[0]

print('Test Dataset Size:\t', test_df.shape)
print('=' * 50)
    
# same procedure for val set

val_df = pd.DataFrame(columns=df.columns)
val_size = val_df.shape[0]

while val_size < total_size * 0.10:
    pick = random.randint(0, len(cs))
    if pick not in used:
        used.append(pick)
        cluster_picked = cs[pick]
    else:
        continue
    # pick a random cluster

    # add all members of that cluster to val_df
    cluster_all = clust[clust['cluster'] == cluster_picked]
    
    # TODO check if any cluster members are in restricted list (overlap w/FireProt)
    for c in cluster_all.member:
        if c in mega_overlap or c in fp_overlap:
            break
    
    if c in mega_overlap or c in fp_overlap:  # skip adding rows if held in overlap
        continue
    
    for c in cluster_all.member:
        target_rows = df.loc[df['pdb_id_corrected'] == c, :]
        if target_rows.shape[0] > 500:  # we don't want large proteins in the val/test set skewing performance estimates
            continue
        val_df = pd.concat([val_df, target_rows])
        val_size = val_df.shape[0]

print('Validation Dataset Size:\t', val_df.shape)
print('=' * 50)

Test Dataset Size:	 (350, 43)
Validation Dataset Size:	 (402, 43)


  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, target_rows])
  test_df = pd.concat([test_df, ta

In [7]:
val_names, test_names = val_df.pdb_id_corrected.unique(), test_df.pdb_id_corrected.unique()
all_names = df.pdb_id_corrected.unique()

train_names = [a for a in all_names if a not in val_names and a not in test_names]
train_df = df.loc[df['pdb_id_corrected'].isin(train_names), :]

In [8]:
print(train_df.shape, val_df.shape, test_df.shape, df.shape)
print(len(train_names), len(val_names), len(test_names), clust.member.unique().shape)

(2686, 43) (402, 43) (350, 43) (3438, 43)
57 15 28 (100,)


In [9]:
splits = {'train': train_names, 'val': val_names, 'test': test_names}

with open('../data/fireprot/fireprot_splits.pkl', 'wb') as f:
    pickle.dump(splits, f)

In [10]:
train_df.to_csv('../data/fireprot/fireprot_train.csv')
val_df.to_csv('../data/fireprot/fireprot_val.csv')
test_df.to_csv('../data/fireprot/fireprot_test.csv')