In [None]:
import os
from visar.model_training_utils import (
    ST_model_hyperparam_screen, 
    ST_model_training,
    RobustMT_model_training,
    RobustMT_model_hyperparam_screen
)
from visar.VISAR_model_utils import (
    generate_RUNKEY_dataframe_baseline,
    generate_RUNKEY_dataframe_RobustMT,
    generate_RUNKEY_dataframe_ST,
    generate_performance_plot_ST,
    generate_performance_plot_RobustMT
)

import pandas as pd
import seaborn as sns
from collections import OrderedDict
# Restrict CUDA to the device with index 1 for everything run from this notebook.
# NOTE(review): this runs after the visar imports above; confirm those imports do
# not initialise CUDA before this variable is set.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

## model training

In [None]:
# Shared configuration for the whole notebook.

# --- targets / tasks ---
# refer to ./data/MT_assay_table_Feb28.csv for the task identifiers
task_names = ['T106', 'T227', 'T108']
protein_targets = ['5HT-1b', '5HT-2b', '5HT-2c']

# --- input data ---
MT_dat_name = './data/MT_data_clean_June28.csv'
smiles_field = 'salt_removed_smi'   # column name handed to the RUNKEY generators as smiles_field
id_field = 'molregno'               # column name handed to the RUNKEY generators as id_field

# --- featurisation ---
FP_type = 'Circular_2048'
n_features = 2048
add_features = None

# --- output locations ---
log_path = './logs/Demo_GPCRs'
dataset_file = './logs/Demo_GPCRs/tmp.csv'

In [None]:
# Hyperparameter grid: every key maps to the list of candidate values to screen.
# The key ORDER is significant: the automatic-selection cell recovers the layer
# sizes and dropout from fixed positions of the logged parameter string, so keep
# the entries in this order.
params_dict = OrderedDict([
    ('n_tasks', [1]),
    ('n_features', [2048]),  ## need modification given FP types
    ('activation', ['relu']),
    ('momentum', [.9]),
    ('batch_size', [128]),
    ('init', ['glorot_uniform']),
    ('learning_rate', [0.01]),
    ('decay', [1e-6]),
    ('nb_epoch', [30]),
    ('dropouts', [.2, .4]),
    ('nb_layers', [1]),
    ('batchnorm', [False]),
    ('layer_sizes', [
        (1024, 512), (1024, 128), (512, 128), (512, 64), (128, 64), (64, 32),
        (512, 128, 64), (128, 64, 32),
    ]),
    ('penalty', [0.1]),
])

In [None]:
# Hyperparameter screening (deepchem-based): run the grid in params_dict for the
# tasks in task_names, passing log_path as the logging directory.
log_output = ST_model_hyperparam_screen(
    MT_dat_name,
    task_names,
    FP_type,
    params_dict,
    log_path = log_path,
)

In [None]:
# option 1: automatic hyperparameter selection
# Load the tab-separated screening log (no header row) and, for every task, pick
# the parameter combination with the highest best-replicate r2 among those whose
# r2 spread across replicates is acceptably small.
hyper_param_df = pd.read_csv(log_path + '/hyperparam_log.txt', header = None, sep = '\t')
hyper_param_df.columns = ['rep_label', 'task_name', 'param', 'r2_score']
hyper_param_df = hyper_param_df.sort_values(by = ['task_name', 'param', 'rep_label'], axis = 0)

# task -> [(layer1, layer2, 1), dropout]; tasks without a stable candidate are left out
best_hyperparams = {}
for task in task_names:
    hyper_stat = (hyper_param_df.loc[hyper_param_df['task_name'] == task]
                  .groupby('param')
                  .agg({'r2_score': ['mean', 'max', 'std']}))

    # filter out ones without reasonable generalization power
    # (a parameter set with a single replicate has NaN std, fails this comparison
    # and is therefore dropped as well)
    valid_mask = hyper_stat['r2_score']['std'] < 0.15
    hyper_stat = hyper_stat.loc[valid_mask]

    if hyper_stat.shape[0] < 1:
        # Report the task of the current iteration (the loop variable is `task`).
        print(task + ' with training variance too high.')
        continue

    # Rank once; use .iloc for positional access because the index holds the
    # parameter strings, not integers.
    ranked_r2 = hyper_stat['r2_score']['max'].sort_values(ascending=False)
    select_param = ranked_r2.index[0]
    select_r2 = ranked_r2.iloc[0]

    # Flatten the logged parameter string by stripping parentheses, then read
    # values by position.
    # NOTE(review): assumes the string lists values in params_dict key order, so
    # that field 9 is the dropout and fields 12/13 are the first two layer sizes;
    # confirm against what ST_model_hyperparam_screen writes.
    # NOTE(review): for three-element layer_sizes entries only the first two
    # sizes are kept here; verify this is intended.
    select_param = select_param.replace('(', '').replace(')', '')
    param_fields = select_param.split(', ')

    tmp_layer1 = int(param_fields[12])
    tmp_layer2 = int(param_fields[13])
    tmp_drop = float(param_fields[9])

    best_hyperparams[task] = [(tmp_layer1, tmp_layer2, 1), tmp_drop]
    print(task + ': ' + str(hyper_stat.shape[0]) + ', ' + str(select_r2))

In [None]:
# Train the single-task models with the per-task settings collected in
# best_hyperparams, passing log_path as the result directory.
output_df = ST_model_training(
    MT_dat_name,
    FP_type,
    best_hyperparams,
    result_path = log_path,
)

In [None]:
# evaluation
# Build the metrics path from log_path (the result_path handed to the training
# cell) instead of repeating the directory as a literal, so the plot follows the
# configured run directory.
# NOTE(review): presumes ST_model_training writes performance_metrics.csv into
# its result_path; confirm.
plot_df = generate_performance_plot_ST(log_path + '/performance_metrics.csv')

# One bar panel per (performance, tt) combination, tasks on the x axis, bars coloured by method.
g = sns.catplot(x = 'task', y = 'value', hue = 'method',
                col = 'tt', row = 'performance',
                data = plot_df, kind = 'bar')

## process trained models and generate files for visualization

In [None]:
# baseline models -- RidgeCV
# Settings describing the custom compound file passed to the RUNKEY generator.
custom_file = './data/custom_file.txt'
custom_id_field = 'id'
custom_task_field = 'dummy_value'
custom_smiles_field = 'SMILES'
sep_custom_file = '\t'
model_flag = 'MT'

# One RUNKEY dataframe per task. The output prefix is derived from log_path so
# the files land in the configured run directory rather than a repeated literal.
for task in task_names:
    output_prefix = log_path + '/RidgeCV_' + task + '_new_'

    generate_RUNKEY_dataframe_baseline(
        output_prefix, task, dataset_file, FP_type, add_features,
        mode = 'RidgeCV',
        MT_dat_name = MT_dat_name,
        smiles_field = smiles_field,
        id_field = id_field,
        custom_file = custom_file,
        custom_id_field = custom_id_field,
        custom_task_field = custom_task_field,
        custom_smiles_field = custom_smiles_field,
        sep_custom_file = sep_custom_file,
    )

In [None]:
# baseline models -- SVR
# Settings describing the custom compound file passed to the RUNKEY generator.
custom_file = './data/custom_file.txt'
custom_id_field = 'id'
custom_task_field = 'dummy_value'
custom_smiles_field = 'SMILES'
sep_custom_file = '\t'
model_flag = 'MT'

# One RUNKEY dataframe per task. The output prefix is derived from log_path so
# the files land in the configured run directory rather than a repeated literal.
for task in task_names:
    output_prefix = log_path + '/SVR_' + task + '_new_'

    generate_RUNKEY_dataframe_baseline(
        output_prefix, task, dataset_file, FP_type, add_features,
        mode = 'SVR',
        MT_dat_name = MT_dat_name,
        smiles_field = smiles_field,
        id_field = id_field,
        custom_file = custom_file,
        custom_id_field = custom_id_field,
        custom_task_field = custom_task_field,
        custom_smiles_field = custom_smiles_field,
        sep_custom_file = sep_custom_file,
    )

In [None]:
# single task models
# Settings describing the custom compound file passed to the RUNKEY generator.
custom_file = './data/custom_file.txt'
custom_id_field = 'id'
custom_task_field = 'dummy_value'
custom_smiles_field = 'SMILES'
sep_custom_file = '\t'
model_flag = 'MT'

# One RUNKEY dataframe per task, built from that task's saved model. Both paths
# are derived from log_path so they follow the configured run directory.
for task in task_names:
    output_prefix = log_path + '/ST_' + task + '_new_'
    # NOTE(review): the '_rep0_50' suffix presumably encodes replicate 0 and an
    # epoch count of 50, while params_dict sets nb_epoch = [30]; confirm this
    # matches the file name ST_model_training actually writes.
    prev_model = log_path + '/' + task + '_rep0_50.hdf5'

    generate_RUNKEY_dataframe_ST(
        prev_model, output_prefix, [task], dataset_file, FP_type, add_features,
        mode = 'ST',
        MT_dat_name = MT_dat_name,
        n_layer = 1,
        smiles_field = smiles_field,
        id_field = id_field,
        custom_file = custom_file,
        custom_id_field = custom_id_field,
        custom_task_field = custom_task_field,
        custom_smiles_field = custom_smiles_field,
        sep_custom_file = sep_custom_file,
    )