Oops. realized the edge features were totally fucked up. In this notebook, we test out hyperparameter tuning with Optuna. again. with GNBlocks and fixed (still fully connected) edges.

In [24]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
from utils.jraph_training import train_and_evaluate_with_data, create_dataset
# from utils.jraph_models import MLPGraphNetwork
import ml_collections
import optuna 
from functools import partial
from datetime import datetime
import os 

In [26]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

### set up functions for optuna

In [27]:
CHECKPOINT_PATH = "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/experiments/tuning"

In [28]:
def objective(trial, datasets):
    """ Defines the objective function to be optimized over, aka the validation loss of a model.
    
        Args:
            trial: object which characterizes the current run 
            datasets: dictionary of data. we explicitly pass this in so that we don't have to waste runtime regenerating the same dataset over and over. 
    """
    # create config 
    config = ml_collections.ConfigDict()

    # Optimizer.
    config.optimizer = "adam"
    # config.optimizer = trial.suggest_categorical("optimizer", ["adam", "sgd"])
    config.learning_rate = trial.suggest_float('learning_rate', 1e-4, 1e-2, 
                                               log=True)
    if config.optimizer == "sgd":
        config.momentum = trial.suggest_float('momentum', 0, 0.999) # upper bound is inclusive, and we want to exclude a momentum of 1 because that would yield no decay 

    # Data params that are used in training 
    config.output_steps = 4

    # Training hyperparameters.
    config.batch_size = 1 # variable currently not used
    config.epochs = 5
    config.log_every_epochs = 5
    config.eval_every_epochs = 5
    config.checkpoint_every_epochs = 10

    # GNN hyperparameters.
    config.model = 'MLPBlock'
    config.dropout_rate = trial.suggest_float('dropout_rate', 0, 0.6)
    config.skip_connections = False # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
    config.layer_norm = False # TODO perhaps we want to turn on later

    # choose the hidden layer feature size using powers of 2 
    config.edge_features = (
        2**trial.suggest_int("edge_mlp_1_power", 1, 5), # range 2 - 64; upper bound is inclusive
        2**trial.suggest_int("edge_mlp_2_power", 1, 5), # range 2 - 64
    )
    config.node_features = (
        2**trial.suggest_int("node_mlp_1_power", 1, 8), # range 2 - 512
        2**trial.suggest_int("node_mlp_2_power", 1, 8), # range 2 - 512
        2) 
    # note the last feature size will be the number of features that the graph predicts
    config.global_features = None

    # generate a workdir 
    # TODO: check if we actually care about referencing this in the future or if we can just create a temp dir 
    workdir=os.path.join(CHECKPOINT_PATH, "checkpoints", str(datetime.now()))

    # run training 
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(config=config, workdir=workdir, datasets=datasets)
    
    # retrieve and return val loss (MSE)
    print("eval_metrics_dict['val'].loss", eval_metrics_dict['val'].loss)
    print()
    return eval_metrics_dict['val'].loss.total




In [29]:
def get_data_config():
    config = ml_collections.ConfigDict()

    config.n_samples=100
    config.input_steps=1
    config.output_delay=8 # predict 24 hrs into the future 
    config.output_steps=4
    config.timestep_duration=3 # equivalent to 3 hours
    # note a 3 hour timestep resolution would be 5*24/3=40
    # if the time_resolution is 120, then a sampling frequency of 3 would achieve a 3 hour timestep 
    config.sample_buffer = -1 * (config.input_steps + config.output_delay + config.output_steps - 1) # negative buffer so that our sample input are continuous (i.e. the first sample would overlap a bit with consecutive samples) 
        # number of timesteps strictly between the end 
        # of one full sample and the start of the next sample
    config.time_resolution=120 # the number of 
                # raw data points generated per time unit, equivalent to the 
                # number of data points generated per 5 days in the simulation
    config.init_buffer_samples=100
    config.train_pct=0.7
    config.val_pct=0.2
    config.test_pct=0.1
    config.K=36
    config.F=8
    config.c=10
    config.b=10
    config.h=1
    config.seed=42
    config.normalize=True
    config.fully_connected_edges=True

    return config

In [30]:
def prepare_study(study_name):
    # generate dataset 
    dataset_config = get_data_config()
    datasets = create_dataset(dataset_config)

    # get the objective function that reuses the pre-generated datasets 
    objective_with_dataset = partial(objective, datasets=datasets)

    # run optimization study
    db_path = os.path.join(CHECKPOINT_PATH, study_name, "optuna_hparam_search.db")
    if not os.path.exists(os.path.join(CHECKPOINT_PATH, study_name)):
        os.makedirs(os.path.join(CHECKPOINT_PATH, study_name))

    study = optuna.create_study(
        study_name=study_name,
        storage=f'sqlite:///{db_path}', # generates a new db if it doesn't exist
        direction='minimize',
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=5, 
            n_warmup_steps=3
            ), 
        load_if_exists=True, 
    )
    # study.optimize(objective_with_dataset, 
    #                n_trials=25-len(study.trials), 
    #                n_jobs=1)
    
    return study, objective_with_dataset

### try hyperparameter tuning again with fewer params and lower learning rate options

In [31]:
# get study
study4, objective_with_dataset = prepare_study(study_name="hparam_study_4")

[I 2023-11-30 23:29:23,159] A new study created in RDB with name: hparam_study_4


In [32]:
study4.optimize(objective_with_dataset, 
                n_trials=5-len(study4.trials), 
                n_jobs=1)

[I 2023-11-30 23:29:32,231] Trial 0 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0034667424009722694, 'dropout_rate': 0.1964864104266136, 'edge_mlp_1_power': 1, 'edge_mlp_2_power': 4, 'node_mlp_1_power': 5, 'node_mlp_2_power': 6}. Best is trial 0 with value: 22.660564422607422.


eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:29:41,060] Trial 1 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.001066816559488955, 'dropout_rate': 0.09282417161763709, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:29:50,716] Trial 2 finished with value: 22.745704650878906 and parameters: {'learning_rate': 0.007280362043053632, 'dropout_rate': 0.20132252731494885, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.745705, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:00,466] Trial 3 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0018370814430460076, 'dropout_rate': 0.12937041581497571, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:09,348] Trial 4 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.005303979026581395, 'dropout_rate': 0.1691693910708621, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ok let's try visualizing this to see what in the hell is going wrong

In [33]:
fig = optuna.visualization.plot_intermediate_values(study4)
fig.show()

[W 2023-11-30 23:30:09,670] You need to set up the pruning feature to utilize `plot_intermediate_values()`


In [34]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study4, params=['learning_rate', 'dropout_rate'])
fig.show()

In [35]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study4, params=['learning_rate', 'node_mlp_1_power'])
fig.show()

In [36]:
study4.optimize(objective_with_dataset, 
                n_trials=30-len(study4.trials), 
                n_jobs=1)

ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:18,692] Trial 5 finished with value: 22814802.0 and parameters: {'learning_rate': 0.00017339592270283462, 'dropout_rate': 0.3032256835826346, 'edge_mlp

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22814802., dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:26,731] Trial 6 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.008175768154051626, 'dropout_rate': 0.4217541437402053, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:34,426] Trial 7 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0019016227469117167, 'dropout_rate': 0.13165895755478507, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:42,452] Trial 8 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.00048630216602522344, 'dropout_rate': 0.4289139710249268, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:49,952] Trial 9 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.009229296100081235, 'dropout_rate': 0.5917992200275655, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:30:58,128] Trial 10 finished with value: 49.12484359741211 and parameters: {'learning_rate': 0.003017155981381823, 'dropout_rate': 0.02762743369507939, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(49.124844, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:05,806] Trial 11 finished with value: 4209113.5 and parameters: {'learning_rate': 0.0010255112465886234, 'dropout_rate': 0.021553586834614002, 'edge_ml

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(4209113.5, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:14,156] Trial 12 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0006781586311251609, 'dropout_rate': 0.21152252669658989, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:21,939] Trial 13 finished with value: 22.660554885864258 and parameters: {'learning_rate': 0.0035733514657635317, 'dropout_rate': 0.07933914295817059, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660555, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:30,038] Trial 14 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.003822058106255783, 'dropout_rate': 0.020162709083485192, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:37,810] Trial 15 finished with value: 33394.25 and parameters: {'learning_rate': 0.0037050845176873387, 'dropout_rate': 0.2661140437966691, 'edge_mlp_1

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(33394.25, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:45,801] Trial 16 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0026794675518192978, 'dropout_rate': 0.095931164246243, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:31:53,582] Trial 17 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.005219008867149505, 'dropout_rate': 0.08524082776014036, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:01,815] Trial 18 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.002047390326807629, 'dropout_rate': 0.2531510488233926, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:10,184] Trial 19 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0050848915966154465, 'dropout_rate': 0.17430181894541852, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:17,441] Trial 20 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.009705621494843912, 'dropout_rate': 0.0021271497528881633,

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:25,262] Trial 21 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0013156901561046574, 'dropout_rate': 0.08390117180733755, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:33,220] Trial 22 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.0013878068040686499, 'dropout_rate': 0.09365149322591754, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:40,784] Trial 23 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.002534490031817606, 'dropout_rate': 0.06538890511889628, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:48,604] Trial 24 finished with value: 24.446767807006836 and parameters: {'learning_rate': 0.0007281636878263835, 'dropout_rate': 0.1463165125591745, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(24.446768, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:32:56,426] Trial 25 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.003786274794309696, 'dropout_rate': 0.1261857932644735, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:33:03,926] Trial 26 finished with value: 38.107181549072266 and parameters: {'learning_rate': 0.0014537843442665897, 'dropout_rate': 0.0723712209891952, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(38.10718, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:33:12,003] Trial 27 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.002456595858441338, 'dropout_rate': 0.04441066322306697, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:33:19,546] Trial 28 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.005565921210501105, 'dropout_rate': 0.18823598855157853, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-11-30 23:33:28,156] Trial 29 finished with value: 22.660564422607422 and parameters: {'learning_rate': 0.006666100002265899, 'dropout_rate': 0.20980454320082145, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



In [37]:
fig = optuna.visualization.plot_intermediate_values(study4)
fig.show()

[W 2023-11-30 23:33:28,281] You need to set up the pruning feature to utilize `plot_intermediate_values()`


In [38]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study4, params=['learning_rate', 'dropout_rate'])
fig.show()

In [39]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study4, params=['learning_rate', 'node_mlp_1_power'])
fig.show()

In [40]:
print(study4.direction)

StudyDirection.MINIMIZE


In [41]:
study4.trials

[FrozenTrial(number=0, state=TrialState.COMPLETE, values=[22.660564422607422], datetime_start=datetime.datetime(2023, 11, 30, 23, 29, 23, 273088), datetime_complete=datetime.datetime(2023, 11, 30, 23, 29, 32, 213207), params={'learning_rate': 0.0034667424009722694, 'dropout_rate': 0.1964864104266136, 'edge_mlp_1_power': 1, 'edge_mlp_2_power': 4, 'node_mlp_1_power': 5, 'node_mlp_2_power': 6}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'learning_rate': FloatDistribution(high=0.01, log=True, low=0.0001, step=None), 'dropout_rate': FloatDistribution(high=0.6, log=False, low=0.0, step=None), 'edge_mlp_1_power': IntDistribution(high=5, log=False, low=1, step=1), 'edge_mlp_2_power': IntDistribution(high=5, log=False, low=1, step=1), 'node_mlp_1_power': IntDistribution(high=8, log=False, low=1, step=1), 'node_mlp_2_power': IntDistribution(high=8, log=False, low=1, step=1)}, trial_id=1, value=None),
 FrozenTrial(number=1, state=TrialState.COMPLETE, values=[22.6605644

well. fixing the edge features did not seem to solve our problem. ok let's try getting rid of the fully connected graph then. 