# Train an ML Model That Finds Cryptic Pockets Within MD Trajectories.
The training dataset consists of crystal structures from the extended Cryptosite database.  There are several structures for each protein in the database.  One structure is "holo"; i.e. it has a cryptic pocket.  The other structures are "apo"; i.e. they (hopefully) don't have the pocket.  The data is meant to mimic an MD trajectory in which some frames have the pocket and others don't.

It is assumed that the only holo structure is the one found in the Cryptosite database.  This assumption isn't always true; some of the (supposedly) "apo" structures may have cryptic pockets that are partially or fully open.  This introduces errors into the training data; the database is nevertheless expected to be useful despite them.

### Initialize the dataset.
The notebook cell below creates the DataFrame `df_train` that stores training data for ML.  It contains a row for each residue of each protein in the extended database.  Each row contains the following:
* Info about the protein and residue.
* Info about the "crypto_apo" protein.  This is the apo protein from the Cryptosite database with the same sequence as the protein.
* Info about the holo protein from the Cryptosite database.
* Misc. features for ML, which are added later in the notebook.

In [None]:
import glob
import json
import os
import shutil

import pandas as pd

from align import align_res_nums
from get_concavity_score import get_concavity_score


# Start every run from empty output directories so that stale alignment or
# concavity files from a previous run cannot be mixed with the new results.
ali_output_dir = "apo_holo_alignments"
conc_output_dir = "train_set_conc"
for output_dir in (ali_output_dir, conc_output_dir):
    if os.path.isdir(output_dir):
        shutil.rmtree(output_dir)
    os.mkdir(output_dir)


# Training rows are collected in a plain list and converted to a DataFrame only
# once, at the end; that is faster than growing a DataFrame row by row.
# Each row holds, in order:
#   the residue itself             (prot_id, chain_id, resnum)
#   its holo counterpart           (holo_prot_id, holo_chain_id, holo_resnum)
#   its Cryptosite apo counterpart (crypto_apo_prot_id, crypto_apo_chain_id, crypto_apo_resnum)
#   the label                      (is_site)
# The residues of the original Cryptosite apo and holo structures are added
# first, because later parts of the code require that Cryptosite's apo and
# holo structures are present.
train_set_as_list = []
df_cryptosite_orig = pd.read_csv("../generate_training_database/gen_crypto_database/cryptosite_database.csv")
for _, crypto_row in df_cryptosite_orig.iterrows():
    holo_fields = [crypto_row["holo_pdb_id"].lower(), crypto_row["holo_chain"], crypto_row["holo_resnum"]]
    apo_fields = [crypto_row["apo_pdb_id"].lower(), crypto_row["apo_chain"], crypto_row["apo_resnum"]]
    # Cryptosite holo residue: it is its own holo counterpart and keeps the
    # database's is_cryptic label.
    train_set_as_list.append(holo_fields + holo_fields + apo_fields + [crypto_row["is_cryptic"]])
    # Cryptosite apo residue: it is its own crypto_apo counterpart; by
    # definition is_cryptic=False for all apo residues.
    train_set_as_list.append(apo_fields + holo_fields + apo_fields + [False])

# Build lookup tables from the Cryptosite database listing.  After the header
# line, each line is parsed by fixed character positions:
#   line[0:4]  crypto_apo PDB ID   line[5]   crypto_apo chain ID
#   line[7:11] holo PDB ID         line[12]  holo chain ID
# NOTE(review): positional parsing assumes 4-character PDB IDs and 1-character
# chain IDs in this file; confirm if the file format ever changes.
cryptosite_db_list_loc = "../generate_training_database/gen_crypto_database/pdbs_and_ligands.csv"
with open(cryptosite_db_list_loc, "r") as cryptosite_db_list_file:
    cryptosite_db_list_lines = cryptosite_db_list_file.readlines()

dict_holo_chains = {}         # holo PDB ID (lower case) -> holo chain ID
dict_crypto_apo_chains = {}   # crypto_apo PDB ID (lower case) -> crypto_apo chain ID
dict_crypto_apo_to_holo = {}  # crypto_apo PDB ID (lower case) -> holo PDB ID (lower case)
for line in cryptosite_db_list_lines[1:]:  # Skip the header.
    crypto_apo_pdb_id = line[0:4].lower()
    crypto_apo_chain_id = line[5]
    holo_pdb_id = line[7:11].lower()
    holo_chain_id = line[12]
    dict_holo_chains[holo_pdb_id] = holo_chain_id
    dict_crypto_apo_chains[crypto_apo_pdb_id] = crypto_apo_chain_id
    # The training-set loop looks up the holo partner by crypto_apo PDB ID and
    # uses the result directly as a PDB file name, so store the bare PDB ID.
    dict_crypto_apo_to_holo[crypto_apo_pdb_id] = holo_pdb_id

# Extend the training set with the apo structures from the extended database.
# For each crypto_apo PDB there is a "*good*" list file in extended_db/ whose
# name starts with the 4-character crypto_apo PDB ID.  Each listed apo chain is
# aligned to the holo chain and to the crypto_apo chain, and every apo residue
# that aligns to the holo chain becomes one training row with is_site=False.
extended_db_dir = "extended_db"  # Same directory for the list files and the downloaded PDBs.
cryptosite_dir = "../generate_training_database/gen_crypto_database"
max_apos_per_protein = 50  # Read at most this many apo entries from each list file.


def _align_and_save(pdb_loc, chain_id, ref_pdb_loc, ref_chain_id, out_loc, *labels):
    """Align residue numbers of one chain to a reference chain and save the result.

    Calls ``align_res_nums``, prints ``labels`` (progress output), writes the
    returned mapping to ``out_loc`` as JSON, and returns the mapping.
    """
    res_map = align_res_nums(pdb_loc, chain_id, ref_pdb_loc, ref_chain_id)
    print(*labels)
    with open(out_loc, "w") as out_file:
        json.dump(res_map, out_file)
    return res_map


num_passed = 0  # Number of list files processed; useful for estimating how fast the code is running.
for apo_list_loc in sorted(glob.glob(os.path.join(extended_db_dir, "*good*"))):
    crypto_apo_pdb_id = os.path.basename(apo_list_loc)[0:4]
    crypto_apo_chain_id = dict_crypto_apo_chains[crypto_apo_pdb_id]
    holo_pdb_id = dict_crypto_apo_to_holo[crypto_apo_pdb_id]
    holo_chain_id = dict_holo_chains[holo_pdb_id]
    holo_pdb_loc = "%s/holo_structures/%s.pdb" % (cryptosite_dir, holo_pdb_id)
    # 1AFQ.C and 2CGA.B don't show up in each other's sequence identity clusters.  I suspect that this is because
    # part of 2CGA.B corresponds to 1AFQ.B; this is a result of 1AFQ being divided into more chains than 2CGA.
    # Regardless of if this is a correct explanation, the code must look at the cryptosite database to get the
    # crypto_apo because of this issue.
    crypto_apo_pdb_loc = "%s/apo_structures/%s.pdb" % (cryptosite_dir, crypto_apo_pdb_id)
    with open(apo_list_loc, "r") as apo_list_opened:
        apo_list_lines = apo_list_opened.readlines()
    # Skip the header; slicing clamps at the end of the list for short files.
    for line in apo_list_lines[1:1 + max_apos_per_protein]:
        apo_pdb_id = line[0:4]
        apo_chain_id = line[5]
        apo_pdb_loc = os.path.join(extended_db_dir, "%s_download_pdbs" % crypto_apo_pdb_id, "%s.pdb" % apo_pdb_id)
        print("num_passed", num_passed, "args", apo_pdb_loc, apo_chain_id, holo_pdb_loc, holo_chain_id)

        apo_holo_dict = _align_and_save(
            apo_pdb_loc, apo_chain_id, holo_pdb_loc, holo_chain_id,
            "%s/apo_%s%s_holo_%s%s.json" % (ali_output_dir, apo_pdb_id, apo_chain_id, holo_pdb_id, holo_chain_id),
            apo_pdb_id, holo_pdb_id)
        apo_crypto_apo_dict = _align_and_save(
            apo_pdb_loc, apo_chain_id, crypto_apo_pdb_loc, crypto_apo_chain_id,
            "%s/apo_%s%s_crypto_apo_%s%s.json" % (ali_output_dir, apo_pdb_id, apo_chain_id,
                                                  crypto_apo_pdb_id, crypto_apo_chain_id),
            apo_pdb_id, crypto_apo_pdb_id)

        for apo_resnum, holo_resnum in apo_holo_dict.items():
            # NOTE(review): raises KeyError if an apo residue aligns to the holo
            # chain but not to the crypto_apo chain; confirm align_res_nums
            # guarantees that cannot happen.
            crypto_apo_resnum = apo_crypto_apo_dict[apo_resnum]
            train_set_as_list.append([apo_pdb_id, apo_chain_id, apo_resnum,
                                      holo_pdb_id, holo_chain_id, holo_resnum,
                                      crypto_apo_pdb_id, crypto_apo_chain_id, crypto_apo_resnum,
                                      False])
    num_passed += 1

df_train = pd.DataFrame(train_set_as_list, columns=["prot_id", "chain_id", "resnum",
                                                    "holo_prot_id", "holo_chain_id", "holo_resnum",
                                                    "crypto_apo_prot_id", "crypto_apo_chain_id", "crypto_apo_resnum",
                                                    "is_site"])
df_train.head()