In [None]:
import os
from visar.model_training_utils import (
    ST_model_hyperparam_screen, 
    ST_model_training,
    RobustMT_model_training,
    RobustMT_model_hyperparam_screen
)
from visar.VISAR_model_utils import (
    generate_RUNKEY_dataframe_baseline,
    generate_RUNKEY_dataframe_RobustMT,
    generate_RUNKEY_dataframe_ST,
    generate_performance_plot_ST,
    generate_performance_plot_RobustMT
)

import pandas as pd
import seaborn as sns
from collections import OrderedDict
# Restrict this process to the GPU with index 1 via the CUDA_VISIBLE_DEVICES environment variable.
os.environ.update({'CUDA_VISIBLE_DEVICES': '1'})

## model training

In [None]:
# set parameters
# Search space handed to the hyperparameter screen: every key maps to a list of
# candidate values (presumably the screen tries combinations of them -- confirm
# against RobustMT_model_hyperparam_screen).
#
# The dict is built from an explicit list of (name, candidates) pairs instead of
# keyword arguments. OrderedDict only preserves keyword-argument order on
# Python >= 3.6, and the automatic-selection cell reads values back out of the
# logged parameter string by position, so the declaration order below must not
# change (positions 9, 12/13 and 15 appear to correspond to dropouts,
# layer_sizes and bypass_layer_sizes).
params_dict = OrderedDict([
    ('n_tasks', [len(task_names)]),
    ('n_features', [2048]),  # needs to be modified to match the FP type in use
    ('activation', ['relu']),
    ('momentum', [.9]),
    ('batch_size', [128]),
    ('init', ['glorot_uniform']),
    ('learning_rate', [0.0001]),
    ('decay', [1e-6]),
    ('nb_epoch', [30]),
    ('dropouts', [.2, .4]),
    ('nb_layers', [1]),
    ('batchnorm', [False]),
    ('layer_sizes', [(1024, 512), (1024, 128), (512, 128), (512, 64)]),
    ('bypass_dropouts', [0.5]),
    ('bypass_layer_sizes', [[128], [64]]),
])

In [None]:
# hyperparameter screening
# Runs the RobustMT screen over the candidate values in params_dict; the
# returned value is kept in log_output.
log_output = RobustMT_model_hyperparam_screen(
    MT_dat_name,
    task_names,
    FP_type,
    params_dict,
    log_path,
    smiles_field,
    id_field,
)

In [None]:
# option 1: automatic hyperparameter selection
# Read the tab-separated screening log (no header row) and label its columns.
hyper_param_df = pd.read_csv(os.path.join(log_path, 'hyperparam_log.txt'), header=None, sep='\t')
hyper_param_df.columns = ['rep_label', 'param', 'r2_score']
hyper_param_df = hyper_param_df.sort_values(by=['param', 'rep_label'], axis=0)

# Stays None when no parameter setting passes the variance filter below;
# on success it becomes [(layer1, layer2), [bypass_layer], dropout].
best_hyperparams = None

# Per parameter setting: mean / max / std of the R2 score across replicates.
hyper_stat = hyper_param_df.groupby('param').agg({'r2_score': ['mean', 'max', 'std']})
# filter out ones without reasonable generalization power
# NOTE(review): std is NaN for a setting with a single replicate, and NaN < 0.15
# is False, so such settings are dropped here -- confirm this is intended.
valid_mask = hyper_stat['r2_score']['std'] < 0.15
hyper_stat = hyper_stat.loc[valid_mask]

if hyper_stat.shape[0] >= 1:
    # Rank the surviving settings by their best replicate and take the top one.
    # .iloc[0] is used for the value because the Series is indexed by the
    # parameter string, so a bare [0] relies on deprecated positional fallback.
    ranked_max_r2 = hyper_stat['r2_score']['max'].sort_values(ascending=False)
    select_param = ranked_max_r2.index[0]
    select_r2 = ranked_max_r2.iloc[0]

    # The parameter setting is logged as a string; strip the parentheses and
    # split it into its comma-separated fields once.
    select_param = select_param.replace('(', '')
    select_param = select_param.replace(')', '')
    fields = select_param.split(', ')

    # Field positions follow the declaration order of params_dict:
    # 9 -> dropout, 12 and 13 -> the two layer sizes, 15 -> bypass layer size.
    tmp_layer1 = int(fields[12])
    tmp_layer2 = int(fields[13])
    bypass_layer = int(fields[15].strip('[').strip(']'))
    tmp_drop = float(fields[9])

    best_hyperparams = [(tmp_layer1, tmp_layer2), [bypass_layer], tmp_drop]
    print(str(hyper_stat.shape[0]) + ', ' + str(select_r2))
else:
    # The original message referenced `task_name`, which is not defined in this
    # multitask notebook; report the task list instead.
    print(str(task_names) + ' with training variance too high.')

best_hyperparams

In [None]:
# train
# The selection cell leaves best_hyperparams empty when no parameter setting
# passed its variance filter. Stop here with an explicit message instead of
# failing on the indexing below with an opaque lookup error.
if not best_hyperparams:
    raise ValueError(
        'best_hyperparams is empty: no hyperparameter setting was selected. '
        'Re-run the screening/selection cells or set best_hyperparams manually as '
        '[(layer1, layer2), [bypass_layer], dropout].'
    )

# Unpack the selected architecture: shared layer sizes, bypass layer sizes, dropout.
layer_sizes = best_hyperparams[0]
bypass_layer_sizes = best_hyperparams[1]
dropout = best_hyperparams[2]

# Fixed training settings (same values as the single candidates in params_dict).
lr = 0.0001
bypass_dropouts = 0.5
n_features = 2048  # needs to be modified to match the FP type in use

RobustMT_model_training(MT_dat_name, FP_type, task_names, log_path,
                        n_features, layer_sizes, bypass_layer_sizes, bypass_dropouts, dropout, lr,
                        N_test = 500.0, add_features = None, n_epoch = 250, epoch_num = 10,
                        id_field = id_field, smiles_field = smiles_field)

In [None]:
# evaluation
import matplotlib.pyplot as plt

# Build the plotting frame from the train and test logs of the Demo_GPCRs run.
plot_df = generate_performance_plot_RobustMT(
    'logs/Demo_GPCRs/model_train_log.csv',
    'logs/Demo_GPCRs/model_test_log.csv',
)

# One panel per value of 'tt', one coloured line per task: R2 against step.
g = sns.FacetGrid(plot_df, col='tt', hue='tasks')
g = g.map(plt.plot, 'step', 'R2', marker='.').add_legend()

## process trained models and generate files for visualization

In [None]:
# multitask models

# --- custom file settings (forwarded unchanged as the custom_* keyword arguments) ---
custom_file = './data/custom_file.txt'
sep_custom_file = '\t'
custom_id_field = 'id'
custom_smiles_field = 'SMILES'
custom_task_field = 'dummy_value'

# --- trained model checkpoint and prefix for the generated output ---
prev_model = './logs/Demo_GPCRs/model-2250'
output_prefix = './logs/Demo_GPCRs/RobustMT2_'

# --- architecture settings passed along with the checkpoint ---
# NOTE(review): presumably these must match the values used in training -- confirm.
layer_sizes = [1024, 128]
bypass_layer_sizes = [128]
dropout = 0.2
n_layer = 1
n_bypass = 2
model_flag = 'MT'

# dataset_file, add_features, n_features, FP_type, task_names, MT_dat_name,
# smiles_field and id_field are expected to be defined by earlier cells.
generate_RUNKEY_dataframe_RobustMT(
    prev_model,
    output_prefix,
    task_names,
    dataset_file,
    FP_type,
    add_features,
    n_features,
    layer_sizes,
    bypass_layer_sizes,
    model_flag,
    n_bypass,
    MT_dat_name,
    model_test_log='./logs/Demo_GPCRs/model_test_log.csv',
    smiles_field=smiles_field,
    id_field=id_field,
    bypass_dropouts=[.5],
    dropout=dropout,
    learning_rate=0.001,
    n_layer=n_layer,
    custom_file=custom_file,
    custom_id_field=custom_id_field,
    custom_task_field=custom_task_field,
    custom_smiles_field=custom_smiles_field,
    sep_custom_file=sep_custom_file,
    K=5,
    valid_cutoff=None,
)