In [None]:
import os
# Move the working directory one level up.
# NOTE(review): this is relative to the current working directory, so
# re-running this cell without restarting the kernel moves up one more
# level each time. Presumably the parent is the project root that holds
# `gnn_library` and `params` -- confirm.
os.chdir(os.pardir)

In [None]:
import torch
import optuna
from optuna.trial import TrialState
from gnn_library.util import train, gen_train_input
from params import *

%load_ext autoreload
%autoreload 2

In [None]:
# Index of the GPU this notebook prefers when the host has that many GPUs.
PREFERRED_GPU_INDEX = 2

if not torch.cuda.is_available():
    # No CUDA runtime / GPU visible: run on CPU.
    device = torch.device('cpu')
elif torch.cuda.device_count() > PREFERRED_GPU_INDEX:
    device = torch.device('cuda:{}'.format(PREFERRED_GPU_INDEX))
else:
    # CUDA is available but the preferred index does not exist on this host
    # (fewer than PREFERRED_GPU_INDEX + 1 GPUs). Fall back to the first GPU
    # instead of building a device handle that points at a missing card.
    device = torch.device('cuda:0')

print("PyTorch has version {}".format(torch.__version__))
print('Using device:', device)

## GNN hyperparameter tuning

In [None]:
def define_model(trial):
    """Sample one hyperparameter configuration from an Optuna trial.

    Tunable values are drawn through ``trial.suggest_*``; everything else is
    a fixed setting. The result is the ``args`` dict that ``objective``
    passes to ``gen_train_input`` and ``train``.

    Args:
        trial: Object exposing ``suggest_int``, ``suggest_float`` and
            ``suggest_categorical`` (an Optuna trial when called through
            ``study.optimize``).

    Returns:
        dict: Mapping of setting name to value. Reads the notebook-level
        ``device`` for the ``'device'`` entry.
    """
    # The suggest calls are issued in the same order as the corresponding
    # keys appear in the returned dict.
    # NOTE(review): the "{}" suffix in the first two parameter names looks
    # like a leftover format placeholder. It is kept as-is because these
    # names are what the study storage records -- confirm before renaming.
    num_layers = trial.suggest_int("num_layers{}", 1, 6)
    num_mlp_layers = trial.suggest_int("num_mlp_layers{}", 1, 5)

    # Sizes below are searched as exponents of two, so the value recorded on
    # the trial is the exponent, not the final size.
    batch_size_exp = trial.suggest_int("log_batch_size", 1, 6)
    hidden_dim_exp = trial.suggest_int("hidden_dim", 1, 7)
    dropout = trial.suggest_float("dropout", 0, 0.5)
    epochs_exp = trial.suggest_int("epochs", 2, 8)

    optimizer_name = trial.suggest_categorical("optimizer", ["adam", "adagrad"])
    # Learning rate is sampled on a log scale.
    learning_rate = trial.suggest_float("lr", 1e-5, 1e-1, log=True)

    return dict(
        processor='GENConv',
        head='regression',
        num_layers=num_layers,
        num_mlp_layers=num_mlp_layers,
        aggr='max',
        batch_size=2 ** batch_size_exp,
        node_feature_dim=5,
        edge_feature_dim=1,
        graph_feature_dim=2,
        hidden_dim=2 ** hidden_dim_exp,
        output_dim=1,
        dropout=dropout,
        epochs=2 ** epochs_exp,
        opt=optimizer_name,
        opt_scheduler='none',
        opt_restart=0,
        weight_decay=5e-3,
        lr=learning_rate,
        device=device,
        noise=0,
    )

In [None]:
def objective(trial):
	"""Score one trial: sample a configuration, train, return the final metric.

	Args:
		trial: Trial object; forwarded to ``define_model`` to sample the
			configuration and to ``train`` (presumably so training can
			report/prune -- confirm in ``gnn_library.util.train``).

	Returns:
		The last entry of the third sequence returned by ``train``. The study
		in this notebook is created with ``direction='maximize'``, so larger
		is treated as better.
	"""
	hparams = define_model(trial)

	# Fixed seed so every trial builds its loaders from the same seed value.
	loaders = gen_train_input(BASE_MODEL_TRAIN_CONFIG, hparams, seed=0)
	train_loader, val_loader = loaders

	# `train` returns five values; only the third is used here.
	# NOTE(review): the original name was `test_accuracies`, but it is
	# computed with `val_loader` passed in -- confirm which split it reflects.
	_, _, accuracy_history, _, _ = train(train_loader, val_loader, hparams, trial)

	return accuracy_history[-1]

In [None]:

# Search budget handed to `study.optimize`: a cap on the number of trials
# and a time limit (Optuna interprets `timeout` in seconds).
N_TRIALS = 1000
TIMEOUT_SECONDS = 60000

# The study is stored in the SQLite file `hyperparam.db`; with
# `load_if_exists=True` an existing study of the same name is reused.
study_settings = {
	'study_name': 'hyperparam-study',
	'direction': 'maximize',
	'storage': 'sqlite:///hyperparam.db',
	'load_if_exists': True,
}
study = optuna.create_study(**study_settings)
study.optimize(objective, n_trials=N_TRIALS, timeout=TIMEOUT_SECONDS)

# Split the recorded trials by final state for the summary below.
pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])
complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])

print("Study statistics: ")
# NOTE(review): the first row counts every entry in `study.trials`, whatever
# its state, not only pruned + complete ones.
summary_rows = [
	("  Number of finished trials: ", len(study.trials)),
	("  Number of pruned trials: ", len(pruned_trials)),
	("  Number of complete trials: ", len(complete_trials)),
]
for label, count in summary_rows:
	print(label, count)

print("Best trial:")
trial = study.best_trial

print("  Value: ", trial.value)

print("  Params: ")
# Values shown are the raw sampled parameters as recorded on the trial.
for key, value in trial.params.items():
	print("    {}: {}".format(key, value))
