In this notebook, we test out hyperparameter tuning with Optuna. again. with GNBlocks and only 5 edges per node. Also we restrict the range of possible parameter values to more closely match GNBLock_baseline because we know we got not-insane results from there. 

similar to sandbox 5 but this time we properly take the average loss. and slightly tweak some params ranges.

In [1]:
# ipython extension to autoreload imported modules so that any changes will be up to date before running code in this nb
%load_ext autoreload 
%autoreload 2

In [2]:
from utils.jraph_training import train_and_evaluate_with_data, create_dataset
# from utils.jraph_models import MLPGraphNetwork
from utils.jraph_data import print_graph_fts
import ml_collections
import optuna 
from functools import partial
from datetime import datetime
import os 

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# set up logging
import logging
logger = logging.getLogger()
logger.setLevel(logging.WARNING)

### set up functions for optuna

In [4]:
CHECKPOINT_PATH = "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/experiments/tuning"

In [5]:
def objective(trial, datasets):
    """ Defines the objective function to be optimized over, aka the validation loss of a model.
    
        Args:
            trial: object which characterizes the current run 
            datasets: dictionary of data. we explicitly pass this in so that we don't have to waste runtime regenerating the same dataset over and over. 
    """
    # create config 
    config = ml_collections.ConfigDict()

    # Optimizer.
    config.optimizer = "adam"
    # config.optimizer = trial.suggest_categorical("optimizer", ["adam", "sgd"])
    config.learning_rate = trial.suggest_float('learning_rate', 4e-4, 1e-2, 
                                               log=True)
    # if config.optimizer == "sgd":
    #     config.momentum = trial.suggest_float('momentum', 0, 0.999) # upper bound is inclusive, and we want to exclude a momentum of 1 because that would yield no decay 

    # Data params that are used in training 
    config.output_steps = 4

    # Training hyperparameters.
    config.batch_size = 1 # variable currently not used
    config.epochs = 10
    config.log_every_epochs = 5
    config.eval_every_epochs = 5
    config.checkpoint_every_epochs = 10

    # GNN hyperparameters.
    config.model = 'MLPBlock'
    config.dropout_rate = trial.suggest_float('dropout_rate', 0, 0.6)
    config.skip_connections = False # This was throwing a broadcast error in add_graphs_tuples_nodes when this was set to True
    config.layer_norm = False # TODO perhaps we want to turn on later

    # choose the hidden layer feature size using powers of 2 
    config.edge_features = (
        2**trial.suggest_int("edge_mlp_1_power", 1, 3), # range 2 - 8; upper bound is inclusive
        2**trial.suggest_int("edge_mlp_2_power", 1, 3), # range 2 - 8
    )
    config.node_features = (
        2**trial.suggest_int("node_mlp_1_power", 1, 7), # range 2 - 128
        # 2**trial.suggest_int("node_mlp_2_power", 1, 8), # range 2 - 512
        2) 
    # note the last feature size will be the number of features that the graph predicts
    config.global_features = None

    # generate a workdir 
    # TODO: check if we actually care about referencing this in the future or if we can just create a temp dir 
    workdir=os.path.join(CHECKPOINT_PATH, "checkpoints", str(datetime.now()))

    # run training 
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(config=config, workdir=workdir, datasets=datasets)
    
    # retrieve and return val loss (MSE)
    print("eval_metrics_dict['val'].loss", eval_metrics_dict['val'].loss)
    print()
    return eval_metrics_dict['val'].loss.total / eval_metrics_dict['val'].loss.count




In [6]:
def get_data_config():
    config = ml_collections.ConfigDict()

    config.n_samples=100
    config.input_steps=1
    config.output_delay=8 # predict 24 hrs into the future 
    config.output_steps=4
    config.timestep_duration=3 # equivalent to 3 hours
    # note a 3 hour timestep resolution would be 5*24/3=40
    # if the time_resolution is 120, then a sampling frequency of 3 would achieve a 3 hour timestep 
    config.sample_buffer = -1 * (config.input_steps + config.output_delay + config.output_steps - 1) # negative buffer so that our sample input are continuous (i.e. the first sample would overlap a bit with consecutive samples) 
        # number of timesteps strictly between the end 
        # of one full sample and the start of the next sample
    config.time_resolution=120 # the number of 
                # raw data points generated per time unit, equivalent to the 
                # number of data points generated per 5 days in the simulation
    config.init_buffer_samples=100
    config.train_pct=0.7
    config.val_pct=0.2
    config.test_pct=0.1
    config.K=36
    config.F=8
    config.c=10
    config.b=10
    config.h=1
    config.seed=42
    config.normalize=True
    config.fully_connected_edges=False

    return config

In [7]:
def prepare_study(study_name):
    # generate dataset 
    dataset_config = get_data_config()
    datasets = create_dataset(dataset_config)
    print_graph_fts(datasets['train']['inputs'][0][0])

    # get the objective function that reuses the pre-generated datasets 
    objective_with_dataset = partial(objective, datasets=datasets)

    # run optimization study
    db_path = os.path.join(CHECKPOINT_PATH, study_name, "optuna_hparam_search.db")
    if not os.path.exists(os.path.join(CHECKPOINT_PATH, study_name)):
        os.makedirs(os.path.join(CHECKPOINT_PATH, study_name))

    study = optuna.create_study(
        study_name=study_name,
        storage=f'sqlite:///{db_path}', # generates a new db if it doesn't exist
        direction='minimize',
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=5, 
            n_warmup_steps=5
            ), 
        load_if_exists=True, 
    )
    # study.optimize(objective_with_dataset, 
    #                n_trials=25-len(study.trials), 
    #                n_jobs=1)
    
    return study, objective_with_dataset

### try hyperparameter tuning again with fewer params and lower learning rate options

In [8]:
# get study
study8, objective_with_dataset = prepare_study(study_name="hparam_study_8")

Number of nodes: 36
Number of edges: 180
Node features shape: (36, 2)
Edge features shape: (180, 1)
Global features shape: (1, 1)


[I 2023-12-01 00:21:02,578] A new study created in RDB with name: hparam_study_8


In [9]:
study8.optimize(objective_with_dataset, 
                n_trials=5-len(study8.trials), 
                n_jobs=1)

[I 2023-12-01 00:21:09,914] Trial 0 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.004958867013306832, 'dropout_rate': 0.09142476605100704, 'edge_mlp_1_power': 3, 'edge_mlp_2_power': 1, 'node_mlp_1_power': 4}. Best is trial 0 with value: 1.133028268814087.


eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:17,425] Trial 1 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.00395826399056704, 'dropout_rate': 0.018219481715815954, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:24,712] Trial 2 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.0017108717729719133, 'dropout_rate': 0.5351367589385111, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:31,210] Trial 3 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.003307128265304628, 'dropout_rate': 0.275780233744929, 'edge

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:37,965] Trial 4 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.008851536262518615, 'dropout_rate': 0.542019077077234, 'edge

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



In [10]:
study8.trials

[FrozenTrial(number=0, state=TrialState.COMPLETE, values=[1.133028268814087], datetime_start=datetime.datetime(2023, 12, 1, 0, 21, 2, 691404), datetime_complete=datetime.datetime(2023, 12, 1, 0, 21, 9, 894042), params={'learning_rate': 0.004958867013306832, 'dropout_rate': 0.09142476605100704, 'edge_mlp_1_power': 3, 'edge_mlp_2_power': 1, 'node_mlp_1_power': 4}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'learning_rate': FloatDistribution(high=0.01, log=True, low=0.0004, step=None), 'dropout_rate': FloatDistribution(high=0.6, log=False, low=0.0, step=None), 'edge_mlp_1_power': IntDistribution(high=3, log=False, low=1, step=1), 'edge_mlp_2_power': IntDistribution(high=3, log=False, low=1, step=1), 'node_mlp_1_power': IntDistribution(high=7, log=False, low=1, step=1)}, trial_id=1, value=None),
 FrozenTrial(number=1, state=TrialState.COMPLETE, values=[1.133028268814087], datetime_start=datetime.datetime(2023, 12, 1, 0, 21, 9, 925605), datetime_complete=datetime

ok let's try visualizing this to see what in the hell is going wrong

In [11]:
fig = optuna.visualization.plot_intermediate_values(study8)
fig.show()

[W 2023-12-01 00:21:38,376] You need to set up the pruning feature to utilize `plot_intermediate_values()`


In [12]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study8, params=['learning_rate', 'dropout_rate'])
fig.show()

In [13]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study8, params=['learning_rate', 'node_mlp_1_power'])
fig.show()

ok literally why is this obsessed with the number 22.

In [14]:
study8.optimize(objective_with_dataset, 
                n_trials=30-len(study8.trials), 
                n_jobs=1)

ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:46,145] Trial 5 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.0013074150782212857, 'dropout_rate': 0.13018841960537633, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:52,996] Trial 6 finished with value: 1.3503010272979736 and parameters: {'learning_rate': 0.0005451661945199484, 'dropout_rate': 0.42500179535736465, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(27.006021, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:21:59,598] Trial 7 finished with value: 1.152649998664856 and parameters: {'learning_rate': 0.0017758788748046001, 'dropout_rate': 0.5678827187710389, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(23.053, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:06,307] Trial 8 finished with value: 1.0349403619766235 and parameters: {'learning_rate': 0.004943400259940722, 'dropout_rate': 0.06599471548289682, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.698807, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:13,504] Trial 9 finished with value: 1.3637374639511108 and parameters: {'learning_rate': 0.0009874580324473034, 'dropout_rate': 0.18345443604401604, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(27.27475, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:20,464] Trial 10 finished with value: 1.01261305809021 and parameters: {'learning_rate': 0.009365880664313988, 'dropout_rate': 0.004974525929238784, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.25226, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:27,039] Trial 11 finished with value: 1.0199716091156006 and parameters: {'learning_rate': 0.008161641667005134, 'dropout_rate': 0.01692027353290193, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.399431, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:33,331] Trial 12 finished with value: 1.0179482698440552 and parameters: {'learning_rate': 0.008507869795322127, 'dropout_rate': 3.8285263444292304e-05

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.358965, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:39,657] Trial 13 finished with value: 1.028456687927246 and parameters: {'learning_rate': 0.009334192970484148, 'dropout_rate': 0.006845258791804809, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.569134, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:46,524] Trial 14 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.0059946756059946794, 'dropout_rate': 0.179534030790483, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:22:53,846] Trial 15 finished with value: 1.1009385585784912 and parameters: {'learning_rate': 0.0027244639889053377, 'dropout_rate': 0.13101511569012492, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.018772, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:01,196] Trial 16 finished with value: 1.0876290798187256 and parameters: {'learning_rate': 0.005997706517496031, 'dropout_rate': 0.2592437987362193, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(21.752583, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:08,386] Trial 17 finished with value: 1.10599684715271 and parameters: {'learning_rate': 0.009314193431071939, 'dropout_rate': 0.07668624784797326, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.119938, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:15,212] Trial 18 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.0025620667684577297, 'dropout_rate': 0.3549765948492857, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:21,588] Trial 19 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.006592299590357683, 'dropout_rate': 0.20558726042904313, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:28,087] Trial 20 finished with value: 1.0339829921722412 and parameters: {'learning_rate': 0.004139007283600871, 'dropout_rate': 0.052915222899609045, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.679659, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:34,446] Trial 21 finished with value: 1.022143840789795 and parameters: {'learning_rate': 0.0071012405629852175, 'dropout_rate': 0.0065950682650443085,

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.442877, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:40,642] Trial 22 finished with value: 1.0133659839630127 and parameters: {'learning_rate': 0.009981354498725286, 'dropout_rate': 0.004927023212420835, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.267319, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:46,802] Trial 23 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.00982437699485291, 'dropout_rate': 0.10615666223448592, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:53,612] Trial 24 finished with value: 1.0369309186935425 and parameters: {'learning_rate': 0.007192446599728663, 'dropout_rate': 0.059847662286855514, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.738619, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:23:59,915] Trial 25 finished with value: 1.0766483545303345 and parameters: {'learning_rate': 0.00982664330110224, 'dropout_rate': 0.14860780380445382, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(21.532967, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:24:06,067] Trial 26 finished with value: 1.133028268814087 and parameters: {'learning_rate': 0.005702690470782735, 'dropout_rate': 0.04853655240689727, 'e

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(22.660564, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:24:12,515] Trial 27 finished with value: 1.330989122390747 and parameters: {'learning_rate': 0.007323315502807532, 'dropout_rate': 0.1048291022601456, 'ed

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(26.619781, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:24:18,787] Trial 28 finished with value: 1.0462067127227783 and parameters: {'learning_rate': 0.004824769800772931, 'dropout_rate': 0.039161527904321806, 

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(20.924135, dtype=float32), count=Array(20., dtype=float32))



ERROR:absl:Could not start profiling: Profile has already been started. Only one profile may be run at a time.
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/periodic_actions.py", line 327, in _start_session
    profiler.start(logdir=self._logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/clu/profiler.py", line 38, in start
    jax.profiler.start_trace(logdir)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/jax/_src/profiler.py", line 111, in start_trace
    raise RuntimeError("Profile has already been started. "
RuntimeError: Profile has already been started. Only one profile may be run at a time.
[I 2023-12-01 00:24:25,891] Trial 29 finished with value: 1.0754796266555786 and parameters: {'learning_rate': 0.007395346055601503, 'dropout_rate': 0.08254611510373808, '

eval_metrics_dict['val'].loss Metric.from_output.<locals>.FromOutput(total=Array(21.509592, dtype=float32), count=Array(20., dtype=float32))



In [15]:
fig = optuna.visualization.plot_intermediate_values(study8)
fig.show()

[W 2023-12-01 00:24:26,020] You need to set up the pruning feature to utilize `plot_intermediate_values()`


In [16]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study8, params=['learning_rate', 'dropout_rate'])
fig.show()

In [20]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study8, params=['learning_rate', 'edge_mlp_1_power'])
fig.show()

In [None]:
# plot the estimated accuracy surface over hyperparameters:
fig = optuna.visualization.plot_contour(study8, params=['learning_rate', 'node_mlp_1_power'])
fig.show()

In [18]:
print(study8.direction)

StudyDirection.MINIMIZE


In [19]:
study8.trials

[FrozenTrial(number=0, state=TrialState.COMPLETE, values=[1.133028268814087], datetime_start=datetime.datetime(2023, 12, 1, 0, 21, 2, 691404), datetime_complete=datetime.datetime(2023, 12, 1, 0, 21, 9, 894042), params={'learning_rate': 0.004958867013306832, 'dropout_rate': 0.09142476605100704, 'edge_mlp_1_power': 3, 'edge_mlp_2_power': 1, 'node_mlp_1_power': 4}, user_attrs={}, system_attrs={}, intermediate_values={}, distributions={'learning_rate': FloatDistribution(high=0.01, log=True, low=0.0004, step=None), 'dropout_rate': FloatDistribution(high=0.6, log=False, low=0.0, step=None), 'edge_mlp_1_power': IntDistribution(high=3, log=False, low=1, step=1), 'edge_mlp_2_power': IntDistribution(high=3, log=False, low=1, step=1), 'node_mlp_1_power': IntDistribution(high=7, log=False, low=1, step=1)}, trial_id=1, value=None),
 FrozenTrial(number=1, state=TrialState.COMPLETE, values=[1.133028268814087], datetime_start=datetime.datetime(2023, 12, 1, 0, 21, 9, 925605), datetime_complete=datetime