# Training / testing of GNN models 

In [19]:
# import necessary packages

import sys
import numpy as np
import datetime

# import util scripts
# add path
sys.path.append('../utils/')
sys.path.append('../pred_models/')

import pred_utils
import gnn_torch_utils
import gnn_torch_models

import importlib
importlib.reload(gnn_torch_utils)

<module 'gnn_torch_utils' from '../pred_models/gnn_torch_utils.py'>

In [2]:
# Load the complete dataset dictionary from disk.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load files from a trusted source.

dataset = np.load('../data/dataset.npy', allow_pickle=True).item()

# Print the keys level by level: top level, then 'PCC', then 'PCC' -> 't1'.
pcc_data = dataset['PCC']
for key_view in (dataset.keys(), pcc_data.keys(), pcc_data['t1'].keys()):
    print(key_view)


dict_keys(['PCC', 'STTC', 'CCH'])
dict_keys(['t1', 't2'])
dict_keys(['nodes', 'fc_graphs', 'target_fr', 'chip_ids'])


'dataset' stores all the data used for the prediction task. <br>
For example, dataset['PCC']['t1'] is the dataset for PCC and task 1.

'nodes' -> node features (waveform features, firing pattern features) <br>
'fc_graphs'-> FC graphs <br>
'target_fr' -> fold-changes in differential firing rates <br>
'chip_ids' -> Microelectrode array id for each network. Used for validation.


# Model selection (grid search)

In [5]:
# Grid search with nested leave-one-out cross validation.
# Only the scripts that perform the grid search are given here; the best parameters
# found by running them have been uploaded to the repository.

# candidate values for every hyper-parameter of the grid
dropout_probs = [0.1, 0.2, 0.3, 0.4]
learning_rates = [0.001, 0.005, 0.01]
l2_regs = [1e-4, 1e-3, 1e-2]
hidden_dims = [8, 16, 32]

# build the list of parameter sets from the candidate values
fit_param_list = gnn_torch_utils.gen_gridparams(
    dropout_probs,
    learning_rates,
    l2_regs,
    hidden_dims,
)

# show only the first parameter entry
first_param_set = fit_param_list[0]
print(first_param_set)

{'dropout_prob': 0.1, 'learning_rate': 0.001, 'weight_decay': 0.0001, 'hidden_dims': 8}


In [6]:
# PCC / task 1 is used as the example throughout the notebook.

# look up the subset once, then unpack its four entries
pcc_t1 = dataset['PCC']['t1']
nodes, FCs, target_frs, chip_ids = (
    pcc_t1[key] for key in ('nodes', 'fc_graphs', 'target_fr', 'chip_ids')
)


In [7]:
# Preparation step for the nested cross validation.

# index of every network in the dataset
full_index = np.arange(len(target_frs))

# Unique chip ids together with the index of the first network found for each chip
# (we have 24 networks from 8 different unique chips (for the undirected FC case)).
# return_index=True yields the first occurrence of every unique value in a single call,
# which replaces converting and scanning chip_ids once per chip.
uniq_chip, first_idx = np.unique(chip_ids, return_index=True)

# one representative network index per chip, in the same order as uniq_chip
uniq_indices = list(first_idx)


In [None]:
# Main outer-inner loop of the nested cross validation.

# Settings for the grid search; they do not depend on the held-out chip, so they are set once.
epochs = 1000 # we used 1000 for the paper.
iter_n = 1 # we will not iterate computations inside the inner loop
graph_type = 'sage1_max' # we will use graphsage model with 1 conv. layer using max pooling.
device = 'cuda'

# Convert once (instead of once per outer iteration) so that subsets can be taken with index arrays.
nodes_arr = np.array(nodes, dtype=object)
FCs_arr = np.array(FCs, dtype=object)
target_frs_arr = np.array(target_frs, dtype=object)
chip_ids_arr = np.array(chip_ids)

for_each_test_idx = [] # placeholder to collect resulting MSEs (outer loop)
for ii in uniq_indices:
    # Hold out not only the test network but every network that belongs to the same chip.
    same_chip = np.where(chip_ids_arr == chip_ids_arr[ii])[0]
    # same_chip is a subset of full_index, so the set difference is the intended operation.
    use_idx = np.setdiff1d(full_index, same_chip)

    nodes_inner = nodes_arr[use_idx]
    FCs_inner = FCs_arr[use_idx]
    target_frs_inner = target_frs_arr[use_idx]
    # Chip ids are subset with the same indices so that they stay aligned with the inner
    # networks; the full-length chip_ids no longer matches the positions of the inner arrays.
    # NOTE(review): assumes run_gridsearch_batch_x looks up chip_ids per network - confirm.
    chip_ids_inner = chip_ids_arr[use_idx]

    # this call runs the inner loop over every parameter set in fit_param_list
    gcn_result = gnn_torch_utils.run_gridsearch_batch_x(
        nodes_inner, FCs_inner, target_frs_inner,
        epochs, iter_n, graph_type, fit_param_list, device, chip_ids_inner,
    )
    for_each_test_idx.append(gcn_result) # collect the result of the inner loop


We interrupted the above fitting step because it can take a very long time.<br>
Instead, the best-performing parameters were uploaded to this repository under the path: '/gnn_prediction_sn/data/best_params'

# Training / testing with the selected parameters

In [8]:
# Load the parameter sets selected by the grid search.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load files from a trusted source.

best_param = np.load('../data/best_params/sage_params_x_uniq_hd_20_corr_0.npy', allow_pickle=True).item()

# Example: best parameter for the network 1 (index-wise 0) when using the graphsage model
# with 1 conv. layer with max pooling.
sage_best_sets = best_param['sage1_max_0_max_p']
print(sage_best_sets[0])
    

{'bs_epoch': 250, 'bs_val': 0.7983556112478638, 'dropout_prob': 0.3, 'learning_rate': 0.005, 'weight_decay': 0.001, 'hidden_dims': 8}


'bs_epoch' is the epoch at which the best validation performance was reached. <br>
'bs_val' is the resulting average MSE value.

In [9]:
# Networks that belong to the same chip share one parameter set, so the sets are repeated per network.

gnn_params = gnn_torch_utils.match_network_param(best_param, chip_ids)

# number of parameter sets before the matching step
n_chip_sets = len(best_param['sage1_max_0_max_p'])
print('best_param had {} parameter sets which corresponds to the number of chips'.format(n_chip_sets))

# number of parameter sets after the matching step
n_network_sets = len(gnn_params['sage1_max_0_max_p'])
print('gnn_params now have {} sets by repeating the parameter sets for each network'.format(n_network_sets))

best_param had 8 parameter sets which corresponds to the number of chips
gnn_params now have 24 sets by repeating the parameter sets for each network


In [20]:
# Settings for training / testing with the selected parameters.
n_epoch = 1000 # this will be overridden when the n_epoch defined in the parameter set is lower.
iter_n = 1 # for the paper, we iterated 30 times --> multiple runs with fixed random seed
graph_type = 'sage1_max' # we will use graphsage model with 1 conv. layer using max pooling.
device = 'cuda'

# per-network parameter sets for the model selected above
sage_param = gnn_params['sage1_max_0_max_p']

import warnings

# Turn off the np.array(dtype=object) warning only while fitting; the filter is restored
# afterwards so that warnings raised by later cells stay visible.
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    # graph_type is passed instead of repeating the 'sage1_max' literal.
    # NOTE(review): the trailing 0 is kept from the original call - presumably a grid-search flag; confirm.
    gnn_result = gnn_torch_utils.run_GNN_batch_x(
        nodes, FCs, target_frs, n_epoch, iter_n, graph_type, sage_param, device, chip_ids, 0,
    )

Epoch: 0, train_acc: 1.5909, validate_acc : 0.5056, LR : 0.00500000
Epoch: 50, train_acc: 0.9234, validate_acc : 0.4221, LR : 0.00500000
Epoch: 100, train_acc: 0.8318, validate_acc : 0.4101, LR : 0.00500000
Epoch: 150, train_acc: 0.7971, validate_acc : 0.3806, LR : 0.00500000
Epoch: 200, train_acc: 0.7864, validate_acc : 0.3746, LR : 0.00500000
iteration: 0, test_acc: 0.3686
Epoch: 0, train_acc: 1.0666, validate_acc : 0.3117, LR : 0.00500000
Epoch: 50, train_acc: 0.9069, validate_acc : 0.2827, LR : 0.00500000
Epoch: 100, train_acc: 0.8167, validate_acc : 0.2295, LR : 0.00500000
Epoch: 150, train_acc: 0.7948, validate_acc : 0.1836, LR : 0.00500000
Epoch: 200, train_acc: 0.7862, validate_acc : 0.1739, LR : 0.00500000
iteration: 0, test_acc: 0.1741
Epoch: 0, train_acc: 1.0686, validate_acc : 1.3804, LR : 0.00500000
Epoch: 50, train_acc: 0.8613, validate_acc : 1.1398, LR : 0.00500000
Epoch: 100, train_acc: 0.7893, validate_acc : 1.0988, LR : 0.00500000
Epoch: 150, train_acc: 0.7646, valida

Epoch: 650, train_acc: 0.7530, validate_acc : 2.0700, LR : 0.01000000
Epoch: 700, train_acc: 0.7505, validate_acc : 2.1066, LR : 0.01000000
Epoch: 750, train_acc: 0.7484, validate_acc : 2.0734, LR : 0.01000000
Epoch: 800, train_acc: 0.7480, validate_acc : 2.0946, LR : 0.01000000
Epoch: 850, train_acc: 0.7527, validate_acc : 2.0802, LR : 0.01000000
Epoch: 900, train_acc: 0.7506, validate_acc : 2.1254, LR : 0.01000000
Epoch: 950, train_acc: 0.7496, validate_acc : 2.0777, LR : 0.01000000
iteration: 0, test_acc: 2.0390
Epoch: 0, train_acc: 1.0728, validate_acc : 1.0675, LR : 0.01000000
Epoch: 50, train_acc: 0.8160, validate_acc : 0.9755, LR : 0.01000000
Epoch: 100, train_acc: 0.7670, validate_acc : 1.0300, LR : 0.01000000
iteration: 0, test_acc: 1.0370
Epoch: 0, train_acc: 1.0451, validate_acc : 0.9030, LR : 0.01000000
Epoch: 50, train_acc: 0.8773, validate_acc : 0.8187, LR : 0.01000000
Epoch: 100, train_acc: 0.8138, validate_acc : 0.7558, LR : 0.01000000
iteration: 0, test_acc: 0.7427
Epo

In [75]:
# Inspect the entries stored for one network.

first_network_result = gnn_result[0] # gnn result for network 1 (index 0)
print(first_network_result.keys())

dict_keys(['mse_train', 'mae_train', 'mse_test', 'mae_test', 'train_curve', 'validate_curve'])


# Training / testing of non-GNN models 

As the workflow is the same as for the GNN models, here we provide fitting scripts for the baseline model (average of target variables), linear regression, and random forest regressor. 

In [88]:
import non_gnn_models
importlib.reload(non_gnn_models)

# flag (used as a boolean): standard-scale the target variables as well
y_scale = 1

# Baseline model (average of target variables)
baseline_result = non_gnn_models.average_mse_batch_x(target_frs, y_scale, chip_ids)

# keys are read from the result object itself (it is not indexed per network here)
baseline_keys = baseline_result.keys()
print(baseline_keys)

dict_keys(['mse_test', 'mse_train', 'mae_test', 'mae_train'])


In [85]:
# Linear regression model.

iter_n = 30 # 30 runs with the fixed random seed
linear_result = non_gnn_models.linear_reg_batch_x(nodes, target_frs, iter_n, y_scale, chip_ids)

# 'R-sq' is the R-sq value for the training data, 'R-sq test' the R-sq value for the testing data
first_linear_result = linear_result[0]
print(first_linear_result.keys())

dict_keys(['R-sq', 'slope_coef', 'mse_train', 'mae_train', 'pred', 'R-sq test', 'mse_test', 'mae_test'])


In [91]:
# Random forest regression model with default parameters.

iter_n = 1 # 1 run (with the fixed random seed)

# False is passed in the parameter slot -> rf regressor with default parameters
rf_result = non_gnn_models.rf_reg_batch_x(nodes, target_frs, iter_n, y_scale, chip_ids, False)

# show the entries stored for network 1 (index 0)
first_rf_result = rf_result[0]
print(first_rf_result.keys())

dict_keys(['reg_score', 'mse_train', 'y_pred', 'feat_importance', 'mse_test', 'mae_train', 'mae_test'])


In [93]:
# Random forest regression model with a grid-searched parameter.
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load files from a trusted source.
rf_param = np.load('../data/best_params/rf_batch_best_param_0.2_0_max_p_x.npy', allow_pickle=True) # grid-searched parameter for undirected FC tasks

iter_n = 1 # 1 run (with the fixed random seed)

# rf_param is passed in the parameter slot -> rf regressor with the grid-searched parameter
rf_result = non_gnn_models.rf_reg_batch_x(nodes, target_frs, iter_n, y_scale, chip_ids, rf_param)

# show the entries stored for network 1 (index 0)
first_rf_result = rf_result[0]
print(first_rf_result.keys())


dict_keys(['reg_score', 'mse_train', 'y_pred', 'feat_importance', 'mse_test', 'mae_train', 'mae_test'])
