In [12]:
import pandas as pd
from collections import defaultdict
import os
from numpy import random
import pickle
import itertools

In [2]:
# Load the processed K50 dG table; the path is relative to the notebook's working directory.
# NOTE(review): the stray echoed source line in this cell's output looks like a pandas
# DtypeWarning (mixed column types) -- confirm, and consider passing explicit dtypes.
mega_csv_path = "../data/mega_scale/Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv"
data = pd.read_csv(filepath_or_buffer=mega_csv_path)
data.shape

  data = pd.read_csv("../data/mega_scale/Processed_K50_dG_datasets/Tsuboyama2023_Dataset2_Dataset3_20230416.csv")


(776298, 39)

In [3]:
# Remove unreliable data points, insertions, deletions, and multiple mutations.
# The filters run one after another so that the string tests on mut_type only
# ever see rows that survived the previous step.

# 1) rows whose ddG_ML entry is the placeholder '-' are dropped
data = data.loc[data.ddG_ML != '-', :].reset_index(drop=True)

# 2) insertions ("ins") and deletions ("del") are dropped in a single pass
is_indel = data.mut_type.str.contains("ins") | data.mut_type.str.contains("del")
data = data.loc[~is_indel, :].reset_index(drop=True)

# 3) any mut_type containing ':' is dropped (multiple mutations)
data = data.loc[~data.mut_type.str.contains(":"), :].reset_index(drop=True)
data.shape

(391090, 39)

In [5]:
names = data.WT_name.unique()

# check that each name has matching PDB somewhere
pdb_dir = '../data/mega_scale/AlphaFold_model_PDBs'

pdb_files = os.listdir(pdb_dir)
# WT_name values are compared verbatim against the file names in pdb_dir;
# a set keeps the membership test cheap.
pdb_file_set = set(pdb_files)
names = [n for n in names if n in pdb_file_set]

seqs = data.aa_seq.unique()
# Work on a copy: adding the wt_seq column below must not modify `data`,
# otherwise re-running earlier cells would see a different frame.
df = data.copy()
mut_rows, wt_seqs = {}, {}
df['wt_seq'] = ''

# make new wt_seq column to accompany each data point
is_wt = df['mut_type'] == 'wt'  # loop-invariant: mut_type is never modified below
for wt_name in names:
    of_protein = df['WT_name'] == wt_name
    wt_rows = df.loc[of_protein & is_wt].reset_index(drop=True)
    if wt_rows.empty:
        # Fail with a readable message instead of an opaque KeyError on index 0.
        raise ValueError(f"no 'wt' row found for WT_name {wt_name!r}; cannot determine its wild-type sequence")
    # The first wild-type row supplies the reference sequence for this protein.
    wt_seq = wt_rows.aa_seq.iloc[0]
    mut_rows[wt_name] = df.loc[of_protein & ~is_wt].reset_index(drop=True)
    wt_seqs[wt_name] = wt_seq
    df.loc[of_protein, 'wt_seq'] = wt_seq

# keep only proteins that have a structure file
df = df.loc[df['WT_name'].isin(names), :]
df.shape

(272712, 40)

In [6]:
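# Write one wild-type sequence per protein to a FASTA file and cluster them with MMseqs2
# (minimum sequence identity 0.25). The code is left commented out; the cluster table
# loaded in the next cell (mega_proteins_cluster.tsv) is presumably the output of this step.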

# from Bio import SeqIO
# from Bio.Seq import Seq

# df = df.loc[df['WT_name'].isin(names), :]

# records = {}
# for r in df.to_records():
#     name = r.WT_name
#     if name not in records:
#         records[name] = SeqIO.SeqRecord(Seq(r.wt_seq), id=name, name=name, description='')
    
# with open('../data/mega_scale/mega_proteins.fasta', 'w') as outFile:
#     SeqIO.write(list(records.values()), outFile, 'fasta')

# !mmseqs2 easy-cluster ../data/mega_scale/mega_proteins.fasta ../data/mega_scale/mega_proteins ../data/mega_scale/tmp --min-seq-id 0.25

In [10]:
# Cluster table: one row per (cluster representative, member) pair.
clust = pd.read_csv('../data/mega_scale/mega_proteins_cluster.tsv', sep='\t', header=None, names=['cluster', 'member'])
vcs = clust.cluster.value_counts()
cs = clust.cluster.unique()

# splitting into 5 cross-validation folds based on cluster IDs
# Whole clusters are assigned to a fold, so members of one cluster never end up
# in different folds.
total_size = df.shape[0]
random.seed(10)

cv_df_list = []
# Target fraction of rows per fold; -1 means "everything not assigned yet".
cv_folds = [0.2, 0.2, 0.2, 0.2, -1]
used = []  # indices into `cs` that have already been drawn
print('Total size:\t', total_size)

for fold in cv_folds:

    if fold == -1:
        # select all that remain
        assigned = set()
        for d in cv_df_list:
            assigned.update(d['WT_name'])
        rdf = df[~df['WT_name'].isin(assigned)]
        print('Fold Size:\t', rdf.shape)
        print('=' * 50)
        cv_df_list.append(rdf)
        break

    # Collect the per-protein row blocks and concatenate once at the end:
    # growing a frame with pd.concat inside the loop is quadratic, and starting
    # from an empty object-typed frame can degrade the column dtypes.
    fold_pieces = []
    fold_size = 0
    # loop until size is met:
    while fold_size < total_size * fold:
        if len(used) >= len(cs):
            # Every cluster has been drawn; without this guard the loop would spin forever.
            raise RuntimeError('ran out of clusters before the fold reached its target size')
        pick = random.randint(0, len(cs))  # upper bound is exclusive
        if pick in used:
            continue
        used.append(pick)
        cluster_picked = cs[pick]
        # add all members of that cluster to the fold
        cluster_all = clust[clust['cluster'] == cluster_picked]
        for c in cluster_all.member:
            target_rows = df.loc[df['WT_name'] == c, :]
            fold_pieces.append(target_rows)
            fold_size += target_rows.shape[0]
    fold_df = pd.concat(fold_pieces) if fold_pieces else pd.DataFrame(columns=df.columns)
    print('Fold Size:\t', fold_df.shape)
    print('=' * 50)
    cv_df_list.append(fold_df)

Total size:	 272712
Fold Size:	 (54962, 40)
Fold Size:	 (54800, 40)
Fold Size:	 (55266, 40)
Fold Size:	 (55886, 40)
Fold Size:	 (51798, 40)


In [14]:
# assembling train/val/test splits based on random combinations of the 5 folds
# Fold numbers below are 1-based positions in cv_df_list.

train_list = [
    [1, 2, 3],
    [1, 2, 5],
    [1, 4, 5],
    [3, 4, 5],
    [2, 3, 4]
]

test_list = [5, 4, 3, 2, 1]
val_list = [4, 3, 2, 1, 5]

# The CV entries are added to the splits dictionary that already exists on disk.
# NOTE: pickle.load can execute arbitrary code; only load a splits file you created yourself.
with open('../data/mega_scale/mega_splits.pkl', 'rb') as f:
    splits = pickle.load(f)

# Expected totals are derived from the data instead of being hard-coded.
n_proteins = df['WT_name'].nunique()
all_fold_ids = list(range(1, len(cv_df_list) + 1))

for n, (tr, te, val) in enumerate(zip(train_list, test_list, val_list)):
    # every fold must be used exactly once per split
    assert sorted(tr + [te, val]) == all_fold_ids, f"split {n} does not use each fold exactly once"
    test_val = cv_df_list[te - 1].WT_name.unique()
    val_val = cv_df_list[val - 1].WT_name.unique()
    train_val = [cv_df_list[t - 1].WT_name.unique() for t in tr]
    train_val = list(itertools.chain.from_iterable(train_val))
    print(len(train_val), len(val_val), len(test_val))
    # together the three parts must cover every protein exactly once
    assert len(train_val) + len(val_val) + len(test_val) == n_proteins, \
        f"split {n} covers {len(train_val) + len(val_val) + len(test_val)} proteins, expected {n_proteins}"
    splits[f"cv_train_{n}"] = train_val
    splits[f"cv_val_{n}"] = val_val
    splits[f"cv_test_{n}"] = test_val

# with open('../data/mega_scale/mega_splits.pkl', 'wb') as f:
#     pickle.dump(splits, f)


188 59 51
176 63 59
172 63 63
173 62 63
185 51 62


dict_keys(['train', 'val', 'test', 'train_s669', 'cv_train_0', 'cv_val_0', 'cv_test_0', 'cv_train_1', 'cv_val_1', 'cv_test_1', 'cv_train_2', 'cv_val_2', 'cv_test_2', 'cv_train_3', 'cv_val_3', 'cv_test_3', 'cv_train_4', 'cv_val_4', 'cv_test_4'])

In [16]:
# doing non-cross-validation train/val/test split for main ablation study runs


def _draw_cluster_split(df, clust, cs, used, restricted, min_size):
    """Draw random clusters until the selected rows reach ``min_size``.

    Parameters
    ----------
    df : DataFrame with a ``WT_name`` column; rows are selected from it.
    clust : DataFrame with ``cluster`` and ``member`` columns.
    cs : array of cluster identifiers; clusters are drawn by position in it.
    used : list of positions in ``cs`` already drawn. It is extended in place,
        so consecutive calls never draw the same cluster twice.
    restricted : set of names; a cluster with any member in this set is
        skipped (it still counts as drawn).
    min_size : minimum number of rows to collect.

    Returns
    -------
    (DataFrame, int) : the selected rows, in the order they were drawn, and
    their count.

    Raises
    ------
    RuntimeError : if every cluster has been drawn before ``min_size`` is reached.
    """
    pieces = []
    size = 0
    while size < min_size:
        if len(used) >= len(cs):
            # Without this guard the loop would never terminate.
            raise RuntimeError('ran out of clusters before the split reached its target size')
        # pick random cluster (upper bound is exclusive)
        pick = random.randint(0, len(cs))
        if pick in used:  # avoid duplicate picks
            continue
        used.append(pick)
        members = clust.loc[clust['cluster'] == cs[pick], 'member']
        # skip the whole cluster if any member is in the restricted list
        if any(m in restricted for m in members):
            continue
        for m in members:
            rows = df.loc[df['WT_name'] == m, :]
            pieces.append(rows)
            size += rows.shape[0]
    # Concatenate once instead of growing a frame inside the loop.
    split_df = pd.concat(pieces) if pieces else pd.DataFrame(columns=df.columns)
    return split_df, size


# load tsv with homology overlap between MegaScale and FireProt
overlap = pd.read_csv('../data/mmseqs_searches/mega_vs_fireprot.m8', sep='\t', header=None)
mega_overlap = overlap.iloc[:, 0].values
fp_overlap = overlap.iloc[:, 1].values
# names from either column of the search result are kept out of test and validation
restricted_names = set(mega_overlap) | set(fp_overlap)

total_size = df.shape[0]
random.seed(1)

used = []
# add test proteins until minimum size cutoff is met:
test_df, test_size = _draw_cluster_split(df, clust, cs, used, restricted_names, total_size * 0.10)

print('Test Dataset Size:\t', test_size)
print('=' * 50)

# repeat on validation set; `used` is shared so no cluster lands in both sets
val_df, val_size = _draw_cluster_split(df, clust, cs, used, restricted_names, total_size * 0.10)

print('Validation Dataset Size:\t', val_size)
print('=' * 50)

Test Dataset Size:	 28312
Validation Dataset Size:	 27481


In [19]:
# Protein names of the two held-out sets, and of the full filtered table
val_names = val_df.WT_name.unique()
test_names = test_df.WT_name.unique()
all_names = df.WT_name.unique()

# Training set = every protein that is in neither the validation nor the test set
held_out = set(val_names) | set(test_names)
train_names = [name for name in all_names if name not in held_out]
train_df = df.loc[df['WT_name'].isin(train_names), :]
print('Training Dataset Size:', train_df.shape[0])

splits = {'train': train_names, 'val': val_names, 'test': test_names}

# save splits to pickle file
# NOTE(review): this dict holds only train/val/test; dumping it to mega_splits.pkl would
# replace whatever that file already contains (the CV cell loads the same file) -- confirm
# before uncommenting.
# with open('../data/mega_scale/mega_splits.pkl', 'wb') as f:
#     pickle.dump(splits, f)

# if desired, save split df to separate files (not needed for training)
# train_df.to_csv('../data/mega_scale/Processed_K50_dG_datasets/mega_train.csv')
# val_df.to_csv('../data/mega_scale/Processed_K50_dG_datasets/mega_val.csv')
# test_df.to_csv('../data/mega_scale/Processed_K50_dG_datasets/mega_test.csv')

Training Dataset Size: 216919
