In [None]:
import os
import sys
import glob
from pathlib import Path
from datetime import datetime
import logging
import cProfile, pstats, io
import h5py
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_curve,
    auc,
    average_precision_score,
    precision_score,
    recall_score,
    accuracy_score,
    f1_score,
    matthews_corrcoef)
import torch
from deeprankcore.trainer import Trainer
from deeprankcore.utils.exporters import HDF5OutputExporter
from deeprankcore.dataset import GraphDataset
from deeprankcore.neuralnets.gnn.naive_gnn import NaiveNetwork
#from pmhc_gnn import PMHCI_Network01

# initialize
# Wall-clock start of the run; also used to date-stamp the experiment folder.
starttime = datetime.now()
# One seed for every global RNG the run may draw from: torch's default generator
# and numpy's legacy global state (in case library code uses np.random).
# Values used across repeated runs: 11 22 33 44 55.
seed = 22
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
#################### To fill
# Input data
# run_day_data = '230329' # 692 data points (local folder)
# NOTE(review): project_folder below points at the local sample folder, not the
# proj folder — confirm which data-point count actually applies.
run_day_data = '230329' # 100k data points (proj folder)
# Paths
protein_class = 'I'
target_data = 'BA'
resolution_data = 'residue' # either 'residue' or 'atomic'
project_folder = '/home/ccrocion/snellius_data_sample' # local resized df path
#project_folder = '/projects/0/einf2380'
folder_data = f'{project_folder}/data/pMHC{protein_class}/features_output_folder/GNN/{resolution_data}/{run_day_data}'
# glob order is filesystem-dependent; sorting fixes the row order of the summary
# table and therefore makes the random_state-seeded splits reproducible.
input_data_path = sorted(glob.glob(os.path.join(folder_data, '*.hdf5')))
if not input_data_path:
    # Fail here instead of creating an empty experiment folder and crashing later.
    raise FileNotFoundError(f'No .hdf5 files found in {folder_data}')
# Experiment naming
exp_name = 'Batch Size 128_1'
exp_date = True # bool
exp_suffix = ''
# Target/s
target_group = 'target_values'
target_dataset = 'binary'
task = 'classif'
# NOTE(review): not passed to GraphDataset at the moment (the argument is commented
# out there); per-feature standardization is set through features_transform.
standardize = True
# Clusters
# If cluster_dataset is None, sets are randomly split
cluster_dataset = None # 'cl_allele'# None # 'allele_type'
cluster_dataset_type = 'string' # None # 'string'
# train_clusters = [0, 1, 2, 3, 4, 7, 9]
# val_clusters = [5, 8]
test_clusters = ['C']
# Dataset
# node_features = [
#     'bsa', 'hb_acceptors', 'hb_donors',
#     'hse', 'info_content', 'irc_negative_negative',
#     'irc_negative_positive', 'irc_nonpolar_negative', 'irc_nonpolar_nonpolar',
#     'irc_nonpolar_polar', 'irc_nonpolar_positive', 'irc_polar_negative',
#     'irc_polar_polar', 'irc_polar_positive', 'irc_positive_positive',
#     'irc_total', 'polarity',
#     'res_charge', 'res_depth', 'res_mass',
#     'res_pI', 'res_size', 'res_type', 'sasa']
# node_features = "all"
node_features = "all"
# edge_features = [
#     "covalent", "distance", "same_chain", "electrostatic", "vanderwaals"]
# edge_features = "all"
edge_features = "all"

# standardize & transform Dictionary
# Each dictionary maps a feature name to {'transform': callable or None,
# 'standardize': bool}. The three variants share the same feature list and differ
# only in whether the transforms below are attached and in the standardize flag,
# so they are generated rather than written out three times.
_FEATURE_NAMES = [
    'bsa', 'res_depth', 'info_content', 'sasa', 'electrostatic', 'vanderwaals',
    'res_size', 'res_charge', 'hb_donors', 'hb_acceptors', 'hse',
    'irc_nonpolar_negative', 'irc_nonpolar_nonpolar', 'irc_nonpolar_polar',
    'irc_nonpolar_positive', 'irc_polar_polar', 'irc_polar_positive', 'irc_total',
    'irc_negative_positive', 'irc_positive_positive', 'irc_polar_negative',
    'irc_negative_negative', 'res_mass', 'res_pI', 'distance', 'pssm',
]

# Element-wise transforms used by feat_trans_dict; features not listed here get None.
_FEATURE_TRANSFORMS = {
    'bsa': lambda t: np.log(t + 1),
    'res_depth': lambda t: np.log(t + 1),
    'info_content': lambda t: np.log(t + 1),
    'sasa': lambda t: np.sqrt(t),
    'electrostatic': lambda t: np.cbrt(t),
    'vanderwaals': lambda t: np.cbrt(t),
}


def _feature_config(transforms, standardize):
    """Build a features_transform mapping over _FEATURE_NAMES (in that order).

    transforms: dict of feature name -> callable; missing names map to None.
    standardize: flag copied into every feature's entry.
    """
    return {
        feature: {'transform': transforms.get(feature), 'standardize': standardize}
        for feature in _FEATURE_NAMES
    }


feat_trans_dict = _feature_config(_FEATURE_TRANSFORMS, standardize=True)
feat_notrans_dict = _feature_config({}, standardize=True)
feat_notrans_nostd_dict = _feature_config({}, standardize=False)

In [None]:
# Trainer
# -- model and optimisation --
net = NaiveNetwork            # network class (not an instance)
optimizer = torch.optim.Adam  # optimizer class (not an instance)
lr = 0.001
weight_decay = 0
batch_size = 128
epochs = 40
class_weights = True          # weighted loss function
save_model = 'best'
# -- hardware / data loading --
cuda = False
ngpu = 0
num_workers = 16
# -- diagnostics --
train_profiling = False
check_integrity = True
# -- early stopping --
earlystop_patience = 15
earlystop_maxgap = 0.06
min_epoch = 20
####################

In [None]:
#################### Folders and logger
# Outputs folder
exp_basepath = './experiments/'


def _next_exp_id(basepath, name):
    """Return `name` followed by the next free numeric experiment id.

    Folders in `basepath` whose name starts with `name` (case-insensitive) are
    read as '<name><id>[_<date>][_<suffix>]'. The result is '<name><max id + 1>',
    or '<name>0' when the folder is missing or no numeric id is found.
    Matching folders whose text after the prefix is not a number are skipped
    instead of aborting the cell with a ValueError.
    """
    if not os.path.exists(basepath):
        return name + '0'
    used_ids = []
    for folder in os.listdir(basepath):
        if not folder.lower().startswith(name.lower()):
            continue
        id_str = folder[len(name):].split('_')[0]
        if id_str.isdecimal():
            used_ids.append(int(id_str))
    return name + str(max(used_ids) + 1) if used_ids else name + '0'


exp_id = _next_exp_id(exp_basepath, exp_name)
exp_path = os.path.join(exp_basepath, exp_id)
if exp_date:
    today = starttime.strftime('%y%m%d')
    exp_path += '_' + today
if exp_suffix:
    exp_path += '_' + exp_suffix
# No exist_ok: raises if the folder already exists rather than reusing it.
os.makedirs(exp_path)

data_path = os.path.join(exp_path, 'data')
output_path = os.path.join(exp_path, 'output')
img_path = os.path.join(exp_path, 'images')
os.makedirs(data_path)
os.makedirs(output_path)
os.makedirs(img_path)
# Loggers
# Root logger: records go to <exp_path>/training.log and to stdout.
_log = logging.getLogger('')
_log.setLevel(logging.INFO)

# Re-running this cell in the same kernel used to stack another pair of handlers
# on the root logger: every message was emitted once per run and the previous
# experiment's training.log kept receiving records. Remove (and close) the
# handlers installed by an earlier run of this cell before adding new ones.
_exp_handler_names = ('exp_training_log', 'exp_stdout')
for _old_handler in list(_log.handlers):
    if _old_handler.get_name() in _exp_handler_names:
        _log.removeHandler(_old_handler)
        _old_handler.close()

fh = logging.FileHandler(os.path.join(exp_path, 'training.log'))
fh.set_name('exp_training_log')
sh = logging.StreamHandler(sys.stdout)
sh.set_name('exp_stdout')
fh.setLevel(logging.INFO)
sh.setLevel(logging.INFO)
formatter_fh = logging.Formatter('[%(asctime)s] - %(name)s - %(message)s',
                               datefmt='%a, %d %b %Y %H:%M:%S')
fh.setFormatter(formatter_fh)

_log.addHandler(fh)
_log.addHandler(sh)
####################

if __name__ == "__main__":
    _log.info(f'Created folder {exp_path}\n')

    _log.info("training.py has started!\n")

    #################### Data summary
    # One row per top-level HDF5 group (molecule) across all input files: its key,
    # its target value and, when clustering is enabled, its cluster label.
    summary_columns = ['entry', 'target']
    if cluster_dataset is not None:
        summary_columns.append('cluster')
    summary_rows = []

    for fname in input_data_path:
        try:
            with h5py.File(fname, 'r') as hdf5:
                for mol in hdf5.keys():
                    mol_targets = hdf5[mol][target_group]
                    # Read every field of the row before appending it, so a failure
                    # while reading (e.g. a missing cluster dataset) cannot leave the
                    # columns with different lengths.
                    row = [mol, float(mol_targets[target_dataset][()])]
                    if cluster_dataset is not None:
                        if cluster_dataset_type == 'string':
                            row.append(mol_targets[cluster_dataset].asstr()[()])
                        else:
                            row.append(float(mol_targets[cluster_dataset][()]))
                    summary_rows.append(row)

        except Exception as e:
            # Best effort: a failing file is reported and skipped; rows read from it
            # before the failure are kept.
            _log.error(e)
            _log.info(f'Error in opening {fname}, please check the file.')

    # Explicit columns keep 'entry'/'target'(/'cluster') present even with no rows.
    df_summ = pd.DataFrame(summary_rows, columns=summary_columns)

In [None]:
# Plain-text view of the summary table (one row per molecule).
summary_text = str(df_summ)
print(summary_text)

In [None]:
# Split the molecules into train / valid / test, tag df_summ with the phase of each
# entry, save it, and log per-split statistics.
if cluster_dataset is None:
    # random split: 10% test, then 20% of the remainder as validation, stratified on target
    df_train, df_test = train_test_split(df_summ, test_size=0.1, stratify=df_summ.target, random_state=42)
    df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)
else:
    # use cluster for test, random split for train and valid
    in_test_clusters = df_summ.cluster.isin(test_clusters)
    df_test = df_summ[in_test_clusters]
    df_train = df_summ[~in_test_clusters]
    df_train, df_valid = train_test_split(df_train, test_size=0.2, stratify=df_train.target, random_state=42)

# Set lookups keep this linear; `entry in Series.values` scans the whole array for
# every entry, which is quadratic on ~100k molecules.
test_entries = set(df_test.entry)
valid_entries = set(df_valid.entry)
df_summ['phase'] = [
    'test' if entry in test_entries else 'valid' if entry in valid_entries else 'train'
    for entry in df_summ.entry]

df_summ.to_hdf(
    os.path.join(output_path, 'summary_data.hdf5'),
    key='summary',
    mode='w')


def _pct(part, whole):
    """Rounded percentage of `part` in `whole`; 0 for an empty `whole` (no ZeroDivisionError)."""
    return round(100 * part / whole) if whole else 0


def _log_split_stats(label, df_split, n_total, show_clusters):
    """Log size, per-class (0/1) counts and, optionally, the clusters of one split."""
    _log.info(f'{label} set: {len(df_split)} samples, {_pct(len(df_split), n_total)}%')
    for cls in (0, 1):
        n_cls = int((df_split.target == cls).sum())
        _log.info(f'\t- Class {cls}: {n_cls} samples, {_pct(n_cls, len(df_split))}%')
    if show_clusters:
        _log.info(f'Clusters present: {df_split.cluster.unique()}\n')


_log.info(f'Data statistics:\n')
_log.info(f'Total samples: {len(df_summ)}\n')
if cluster_dataset is not None:
    _log.info(f'Clustering on Dataset: {cluster_dataset}.\n')
for split_label, df_split in (('Training', df_train), ('Validation', df_valid), ('Testing', df_test)):
    _log_split_stats(split_label, df_split, len(df_summ), cluster_dataset is not None)
####################

In [None]:
# Training GraphDataset restricted to the entries of the train split.
# `standardize` is not passed here; the per-feature 'standardize' flags come from
# features_transform (feat_notrans_dict: no transforms, standardize=True).
_log.info(f'HDF5DataSet loading...\n')
train_dataset_kwargs = dict(
    hdf5_path=input_data_path,
    subset=df_train.entry.tolist(),
    target=target_dataset,
    task=task,
    node_features=node_features,
    edge_features=edge_features,
    check_integrity=check_integrity,
    features_transform=feat_notrans_dict,
)
dataset_train = GraphDataset(**train_dataset_kwargs)
_log.info(f'Dataset_Train Setted\n')