In this notebook, we test out hyperparameter tuning with Optuna. 

In [27]:
# IPython extension that re-imports edited modules before each cell runs, so changes to
# project code (e.g. the utils.* modules imported below) are picked up without a kernel restart.
# On re-runs the extension is already loaded and %load_ext only prints a notice.
%load_ext autoreload 
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [28]:
%reload_ext autoreload

In [29]:
from utils.hyperparam_tuning import prepare_study, get_data_config
from utils.jraph_training import create_dataset, train_and_evaluate_with_data
import ml_collections
import tempfile

In [30]:
# Logging setup: keep the root logger quiet by default (WARNING and above);
# the training cells below temporarily raise it to INFO.
import logging

logger = logging.root  # the same root logger object that logging.getLogger() returns
logger.setLevel("WARNING")

### Run hyperparameter tuning

In [31]:
# get study
# NOTE(review): as the error output below shows, prepare_study() now requires a
# `study_name` argument, so this call fails as written. Pass the name of the study these
# trials were recorded under, e.g. prepare_study(study_name="..."). TODO fill in the name.
# Presumably returns the Optuna study plus an objective already bound to the dataset;
# it is passed to study.optimize() below.
study, objective_with_dataset = prepare_study()

TypeError: prepare_study() missing 1 required positional argument: 'study_name'

In [23]:
# Top the study up to 15 trials in total: only the remainder after the trials already
# recorded is run. NOTE(review): study.trials presumably includes failed trials too
# (e.g. the NaN trials logged below), so those count toward the 15. n_jobs=1 runs
# trials sequentially.
study.optimize(
    objective_with_dataset,
    n_trials=15 - len(study.trials),
    n_jobs=1,
)

[I 2023-11-30 16:57:13,349] Trial 5 finished with value: 22.660564422607422 and parameters: {'optimizer': 'adam', 'learning_rate': 0.005477855138506957, 'n_blocks': 2, 'dropout_rate': 0.11683037445895036, 'edge_mlp_1_power': 5, 'edge_mlp_2_power': 3, 'node_mlp_1_power': 1, 'node_mlp_2_power': 9}. Best is trial 3 with value: 22.660564422607422.
[W 2023-11-30 16:59:09,760] Trial 6 failed with parameters: {'optimizer': 'sgd', 'learning_rate': 0.00016933345217322904, 'momentum': 0.27042246029031203, 'n_blocks': 8, 'dropout_rate': 0.5332399900416053, 'edge_mlp_1_power': 7, 'edge_mlp_2_power': 1, 'node_mlp_1_power': 8, 'node_mlp_2_power': 3} because of the following error: The value nan is not acceptable.
[W 2023-11-30 16:59:09,761] Trial 6 failed with value Array(nan, dtype=float32).
[W 2023-11-30 17:01:05,846] Trial 7 failed with parameters: {'optimizer': 'sgd', 'learning_rate': 0.00963386459129923, 'momentum': 0.1200246093338188, 'n_blocks': 6, 'dropout_rate': 0.049797152487681635, 'edge_

Why are we getting NaNs?

Let's keep running more trials to see whether the error can get below 1.

In [24]:
# Keep going until the study holds 30 trials in total. NOTE(review): trials that
# failed with NaN presumably count toward this total as well.
study.optimize(
    objective_with_dataset,
    n_trials=30 - len(study.trials),
    n_jobs=1,
)

[W 2023-11-30 17:09:36,050] Trial 15 failed with parameters: {'optimizer': 'sgd', 'learning_rate': 0.0025771234320143646, 'momentum': 0.96671754713713, 'n_blocks': 5, 'dropout_rate': 0.4503902165026035, 'edge_mlp_1_power': 4, 'edge_mlp_2_power': 1, 'node_mlp_1_power': 8, 'node_mlp_2_power': 4} because of the following error: The value nan is not acceptable.
[W 2023-11-30 17:09:36,052] Trial 15 failed with value Array(nan, dtype=float32).
[W 2023-11-30 17:12:33,109] Trial 16 failed with parameters: {'optimizer': 'adam', 'learning_rate': 0.00011463856548994692, 'n_blocks': 8, 'dropout_rate': 0.3147370342604336, 'edge_mlp_1_power': 9, 'edge_mlp_2_power': 4, 'node_mlp_1_power': 3, 'node_mlp_2_power': 6} because of the following error: The value nan is not acceptable.
[W 2023-11-30 17:12:33,110] Trial 16 failed with value Array(nan, dtype=float32).
[I 2023-11-30 17:12:52,102] Trial 17 finished with value: 22.660564422607422 and parameters: {'optimizer': 'adam', 'learning_rate': 0.0074250077

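Why are all these trials getting exactly the same validation error of 22.660564422607422? (One hypothesis to check: if the model's outputs collapse to a constant, e.g. all zeros, the error would be identical no matter which hyperparameters were sampled.)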

### Debug bad trials

Let's look into the trials that produced NaNs to troubleshoot what is happening. Perhaps this will also give insight into the 22.66 results.

In [18]:
# Training config reproducing trial 29 (an SGD trial) so it can be re-run by hand
# outside of Optuna and inspected.
trial_29_config = ml_collections.ConfigDict({
    # Optimizer.
    "optimizer": "sgd",
    "learning_rate": 0.00011053076030500855,
    "momentum": 0.949299633037675,

    # Data params that are used in training.
    "output_steps": 4,

    # Training hyperparameters.
    "batch_size": 1,  # variable currently not used
    "epochs": 5,
    "log_every_epochs": 5,
    "eval_every_epochs": 5,
    "checkpoint_every_epochs": 10,

    # GNN hyperparameters.
    "model": 'MLPGraphNetwork',
    "n_blocks": 8,
    "share_params": False,
    "dropout_rate": 0.19218851091771358,
    # This was throwing a broadcast error in add_graphs_tuples_nodes when set to True.
    "skip_connections": False,
    "layer_norm": False,  # TODO perhaps we want to turn on later

    # Hidden-layer feature sizes are chosen as powers of 2. The LAST node feature
    # size is the number of features that the graph predicts.
    "edge_features": (2**1, 2**6),
    "node_features": (2**2, 2**4, 2),
    "global_features": None,
})


In [5]:
# get datasets
# Build the data config from utils.hyperparam_tuning.get_data_config (presumably the same
# data settings the Optuna objective uses) and materialize the datasets once, so the
# manual re-runs of individual trials below all reuse the same `datasets`.
dataset_config = get_data_config()
datasets = create_dataset(dataset_config)

In [19]:
# Re-run trial 29 by hand in a fresh temporary workdir, with INFO logging enabled
# so the training progress is visible in the cell output.
workdir = tempfile.mkdtemp()
logger.setLevel(logging.INFO)
try:
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
        config=trial_29_config, workdir=workdir, datasets=datasets)
finally:
    # Restore quiet logging even if training raises or is interrupted
    # (e.g. KeyboardInterrupt or quitting a debugger). Otherwise every later
    # cell inherits INFO-level output.
    logger.setLevel(logging.WARNING)

INFO:absl:Hyperparameters: {'batch_size': 1, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.19218851091771358, 'edge_features': (2, 64), 'epochs': 5, 'eval_every_epochs': 5, 'global_features': None, 'layer_norm': False, 'learning_rate': 0.00011053076030500855, 'log_every_epochs': 5, 'model': 'MLPGraphNetwork', 'momentum': 0.949299633037675, 'n_blocks': 8, 'node_features': (4, 16, 2), 'optimizer': 'sgd', 'output_steps': 4, 'share_params': False, 'skip_connections': False}
INFO:absl:Initializing network.


processed_graphs.nodes Traced<ShapedArray(float32[36,2])>with<DynamicJaxprTrace(level=1/0)>
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(128)[0;36m__call__[0;34m()[0m
[0;32m    126 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    127 [0;31m[0;34m[0m[0m
[0m[0;32m--> 128 [0;31m        [0;32mreturn[0m [0;34m[[0m[0mprocessed_graphs[0m[0;34m][0m [0;31m# so that the input and output types will be consistent, and allow nn.Sequential to work[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    129 [0;31m[0;34m[0m[0m
[0m[0;32m    130 [0;31m[0;34m[0m[0m
[0m
processed_graphs.nodes Traced<ShapedArray(float32[36,2])>with<DynamicJaxprTrace(level=1/0)>
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(128)[0;36m__call__[0;34m()[0m
[0;32m    126 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0

INFO:absl:
+----------------------------------------+----------+------+-----------+--------+
| Name                                   | Shape    | Size | Mean      | Std    |
+----------------------------------------+----------+------+-----------+--------+
| params/MLPBlock_0/MLP_0/Dense_0/bias   | (2,)     | 2    | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_0/Dense_0/kernel | (6, 2)   | 12   | -0.0855   | 0.551  |
| params/MLPBlock_0/MLP_0/Dense_1/bias   | (64,)    | 64   | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_0/Dense_1/kernel | (2, 64)  | 128  | 0.0121    | 0.72   |
| params/MLPBlock_0/MLP_1/Dense_0/bias   | (4,)     | 4    | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_1/Dense_0/kernel | (131, 4) | 524  | 0.000394  | 0.0889 |
| params/MLPBlock_0/MLP_1/Dense_1/bias   | (16,)    | 16   | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_1/Dense_1/kernel | (4, 16)  | 64   | -0.0177   | 0.552  |
| params/MLPBlock_0/MLP_1/Dense_2/bias   | (2,)     | 2    | 0.0       | 0.0    |
| par

processed_graphs.nodes Traced<ConcreteArray([[30.47887     0.        ]
 [33.533257    0.        ]
 [34.019836    0.        ]
 [ 2.9336348   0.        ]
 [23.69712     0.        ]
 [48.554474    0.        ]
 [58.999664    0.        ]
 [20.547878    0.        ]
 [33.964165    0.        ]
 [ 0.         21.087168  ]
 [48.11723     0.        ]
 [39.197285    0.        ]
 [ 0.          0.        ]
 [ 6.0189505   0.        ]
 [ 2.3474889   0.        ]
 [ 0.         14.430358  ]
 [15.018629    0.        ]
 [ 0.          0.        ]
 [17.03293     0.        ]
 [ 0.         15.193044  ]
 [41.03015     0.        ]
 [ 0.         30.221815  ]
 [98.12388     0.        ]
 [ 0.         16.945902  ]
 [ 0.          0.        ]
 [ 0.          0.        ]
 [ 0.          0.        ]
 [ 1.3857424   0.        ]
 [44.797195    0.        ]
 [ 0.          0.        ]
 [ 0.         17.118277  ]
 [23.5069      0.        ]
 [ 0.          5.726407  ]
 [ 0.56623733  0.        ]
 [ 0.          0.        ]
 [45.697906

The predictions are initially valid floats, but by around the third pass-through they turn into NaNs. Why?

I am not sure whether this is the culprit, but the error appears to grow exponentially over time.

OK, let's put this exploration on hold and try to figure out why the other trials that didn't crash all had errors of 22.66.

In [7]:
# Training config reproducing trial 24 (an Adam trial that reported the suspicious
# 22.66 error) so it can be re-run by hand and inspected.
trial_24_config = ml_collections.ConfigDict({
    # Optimizer.
    "optimizer": "adam",
    "learning_rate": 0.00047851982698472186,

    # Data params that are used in training.
    "output_steps": 4,

    # Training hyperparameters.
    "batch_size": 1,  # variable currently not used
    "epochs": 5,
    "log_every_epochs": 5,
    "eval_every_epochs": 5,
    "checkpoint_every_epochs": 10,

    # GNN hyperparameters.
    "model": 'MLPGraphNetwork',
    "n_blocks": 4,
    "share_params": False,
    "dropout_rate": 0.4968507323037491,
    # This was throwing a broadcast error in add_graphs_tuples_nodes when set to True.
    "skip_connections": False,
    "layer_norm": False,  # TODO perhaps we want to turn on later

    # Hidden-layer feature sizes are chosen as powers of 2. The LAST node feature
    # size is the number of features that the graph predicts.
    "edge_features": (2**7, 2**8),
    "node_features": (2**7, 2**6, 2),
    "global_features": None,
})


In [10]:
# Re-run trial 24 by hand in a fresh temporary workdir, with INFO logging enabled
# so the training progress is visible in the cell output.
workdir = tempfile.mkdtemp()
logger.setLevel(logging.INFO)
try:
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
        config=trial_24_config, workdir=workdir, datasets=datasets)
finally:
    # Restore quiet logging even if training raises or is interrupted
    # (e.g. KeyboardInterrupt or quitting a debugger). Otherwise every later
    # cell inherits INFO-level output.
    logger.setLevel(logging.WARNING)

INFO:absl:Hyperparameters: {'batch_size': 1, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.4968507323037491, 'edge_features': (128, 256), 'epochs': 5, 'eval_every_epochs': 5, 'global_features': None, 'layer_norm': False, 'learning_rate': 0.00047851982698472186, 'log_every_epochs': 5, 'model': 'MLPGraphNetwork', 'n_blocks': 4, 'node_features': (128, 64, 2), 'optimizer': 'adam', 'output_steps': 4, 'share_params': False, 'skip_connections': False}
INFO:absl:Initializing network.


processed_graphs_list [GraphsTuple(nodes=Traced<ShapedArray(float32[36,2])>with<DynamicJaxprTrace(level=1/0)>, edges=Traced<ShapedArray(float32[1296,1])>with<DynamicJaxprTrace(level=1/0)>, receivers=Traced<ShapedArray(int32[1296])>with<DynamicJaxprTrace(level=1/0)>, senders=Traced<ShapedArray(int32[1296])>with<DynamicJaxprTrace(level=1/0)>, globals=Traced<ShapedArray(float32[1,1])>with<DynamicJaxprTrace(level=1/0)>, n_node=Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>, n_edge=Traced<ShapedArray(int32[1])>with<DynamicJaxprTrace(level=1/0)>)]
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(186)[0;36m__call__[0;34m()[0m
[0;32m    184 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    185 [0;31m[0;34m[0m[0m
[0m[0;32m--> 186 [0;31m        [0;32mreturn[0m [0mprocessed_graphs_list[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    187 [0;31m[0;34m[0

INFO:absl:
+----------------------------------------+------------+--------+-----------+--------+
| Name                                   | Shape      | Size   | Mean      | Std    |
+----------------------------------------+------------+--------+-----------+--------+
| params/MLPBlock_0/MLP_0/Dense_0/bias   | (128,)     | 128    | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_0/Dense_0/kernel | (6, 128)   | 768    | 0.00353   | 0.405  |
| params/MLPBlock_0/MLP_0/Dense_1/bias   | (256,)     | 256    | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_0/Dense_1/kernel | (128, 256) | 32,768 | -0.000687 | 0.0881 |
| params/MLPBlock_0/MLP_1/Dense_0/bias   | (128,)     | 128    | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_1/Dense_0/kernel | (515, 128) | 65,920 | 0.000205  | 0.0441 |
| params/MLPBlock_0/MLP_1/Dense_1/bias   | (64,)      | 64     | 0.0       | 0.0    |
| params/MLPBlock_0/MLP_1/Dense_1/kernel | (128, 64)  | 8,192  | 0.000387  | 0.0882 |
| params/MLPBlock_0/MLP_1/Dense_2/bias   | 

processed_graphs_list [GraphsTuple(nodes=Traced<ConcreteArray([[     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [ 26301.068   34053.844 ]
 [     0.          0.    ]
 [     0.     157839.47  ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.       2693.6804]
 [ 18218.834       0.    ]
 [     0.          0.    ]
 [     0.      23090.492 ]
 [     0.          0.    ]
 [  3020.5586  19424.756 ]
 [     0.      14149.171 ]
 [     0.          0.    ]
 [ 27222.848       0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [  1878.6757 201805.48  ]
 [     0.      19328.293 ]
 [     0.          0.    ]
 [     0.      13839.831 ]
 [     0.       6773.979 ]
 [     0.          0.    ]
 [     0.       9021.793 ]
 [ 22374.08    42674.94  ]
 [     0.          0.    ]
 [  6558.3643      0.    ]
 [ 22153.752       0.    ]
 [     0.      54879

INFO:absl:Finished training step 0.


processed_graphs_list [GraphsTuple(nodes=Traced<ConcreteArray([[     0.          0.    ]
 [  6653.9814      0.    ]
 [     0.      60863.31  ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [ 50120.395   55096.844 ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.      31649.81  ]
 [  4869.107       0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.      23778.049 ]
 [     0.      55112.875 ]
 [     0.          0.    ]
 [     0.      47212.13  ]
 [     0.          0.    ]
 [     0.     108309.055 ]
 [     0.      86366.63  ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.      14146.822 ]
 [     0.          0.    ]
 [     0.      25214.145 ]
 [     0.      16184.681 ]
 [     0.          0.    ]
 [     0.      13732.679 ]
 [     0.          0.    ]
 [     0.      31847.443 ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0.    ]
 [     0.          0

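Even for trial 24, where the reported error was 22.66, the predicted node values we're seeing (and hence the loss) are also growing exponentially. What is going wrong? Note also that every printed node value is either exactly 0 or positive. This suggests a ReLU-style activation may be applied to the final output layer, clipping all negative predictions to 0.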

Some people online say this is caused by the learning rate being too high. But the GNBlock_overfit config, which trained and got errors below 1.0, had a learning rate of 1e-3, whereas most of these learning rates are on the order of 1e-4. So how could that be? Could it be due to the combination of too many parameters and the learning rate?

Maybe we need to retest this whole pipeline. Let's use the params from the GNBlock_overfit config, because we know that one gave us an error slightly under 1.0 before.

In [11]:
# Load the known-good GNBlock_overfit experiment config (it previously reached an error
# slightly under 1.0, per the note above) and display it for comparison with the trials.
from experiments.configs import GNBlock_overfit
GNBlock_overfit_config = GNBlock_overfit.get_config()
GNBlock_overfit_config

F: 8
K: 36
add_self_loops: true
add_undirected_edges: true
add_virtual_node: true
b: 10
batch_size: 3
c: 10
checkpoint_every_epochs: 10
dropout_rate: 0.1
edge_features: !!python/tuple
- 32
- 16
- 8
epochs: 500
eval_every_epochs: 10
global_features: null
h: 1
init_buffer_samples: 100
input_steps: 1
layer_norm: false
learning_rate: 0.001
log_every_epochs: 1
model: MLPBlock
n_samples: 20
node_features: !!python/tuple
- 32
- 64
- 32
- 2
normalize: true
optimizer: adam
output_delay: 8
output_steps: 4
sample_buffer: -12
seed: 42
skip_connections: false
test_pct: 0.1
time_resolution: 120
timestep_duration: 3
train_pct: 0.7
val_pct: 0.2

In [12]:
# Build datasets from the overfit config itself (it carries its own data params such as
# K, F, n_samples and normalize) rather than reusing `datasets` from get_data_config().
GNBlock_overfit_datasets = create_dataset(GNBlock_overfit_config)

In [14]:
# Re-train with the known-good GNBlock_overfit config and its own datasets to sanity-check
# the training pipeline, with INFO logging so the progress is visible.
workdir = tempfile.mkdtemp()
logger.setLevel(logging.INFO)
try:
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(
        config=GNBlock_overfit_config, workdir=workdir, datasets=GNBlock_overfit_datasets)
finally:
    # Restore quiet logging even if training raises or is interrupted. The
    # KeyboardInterrupt below previously left the logger stuck at INFO.
    logger.setLevel(logging.WARNING)

INFO:absl:Hyperparameters: {'F': 8, 'K': 36, 'add_self_loops': True, 'add_undirected_edges': True, 'add_virtual_node': True, 'b': 10, 'batch_size': 3, 'c': 10, 'checkpoint_every_epochs': 10, 'dropout_rate': 0.1, 'edge_features': (32, 16, 8), 'epochs': 500, 'eval_every_epochs': 10, 'global_features': None, 'h': 1, 'init_buffer_samples': 100, 'input_steps': 1, 'layer_norm': False, 'learning_rate': 0.001, 'log_every_epochs': 1, 'model': 'MLPBlock', 'n_samples': 20, 'node_features': (32, 64, 32, 2), 'normalize': True, 'optimizer': 'adam', 'output_delay': 8, 'output_steps': 4, 'sample_buffer': -12, 'seed': 42, 'skip_connections': False, 'test_pct': 0.1, 'time_resolution': 120, 'timestep_duration': 3, 'train_pct': 0.7, 'val_pct': 0.2}
INFO:absl:Initializing network.
INFO:absl:
+-----------------------------+----------+-------+----------+-------+
| Name                        | Shape    | Size  | Mean     | Std   |
+-----------------------------+----------+-------+----------+-------+
| params

KeyboardInterrupt: 

### Try hyperparameter tuning again with fewer params and lower learning-rate options

MOVED TO A NEW NOTEBOOK, since this one is not refreshing the imported modules and keeps hitting the pdb breakpoint even though I deleted it.

In [32]:
# get study
# A second study with its own name. Per the log below, prepare_study reuses the stored
# study if one named "hparam_study_2" already exists. This also rebinds
# objective_with_dataset to the objective returned for this study.
study2, objective_with_dataset = prepare_study(study_name="hparam_study_2")

[I 2023-11-30 21:18:43,885] Using an existing study with name 'hparam_study_2' instead of creating a new one.


In [33]:
# Top study2 up to 30 trials in total (only the remainder is run; trials run
# one at a time because n_jobs=1). NOTE(review): failed trials presumably also
# count toward the 30.
study2.optimize(
    objective_with_dataset,
    n_trials=30 - len(study2.trials),
    n_jobs=1,
)

> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(186)[0;36m__call__[0;34m()[0m
[0;32m    184 [0;31m[0;34m[0m[0m
[0m[0;32m    185 [0;31m[0;34m[0m[0m
[0m[0;32m--> 186 [0;31m[0;32mdef[0m [0mnaive_const_fn[0m[0;34m([0m[0mgraph[0m[0;34m:[0m [0mjraph[0m[0;34m.[0m[0mGraphsTuple[0m[0;34m)[0m [0;34m->[0m [0mjraph[0m[0;34m.[0m[0mGraphsTuple[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    187 [0;31m    [0;32mreturn[0m [0mgraph[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    188 [0;31m[0;34m[0m[0m
[0m
> [0;32m/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_models.py[0m(186)[0;36m__call__[0;34m()[0m
[0;32m    184 [0;31m[0;34m[0m[0m
[0m[0;32m    185 [0;31m[0;34m[0m[0m
[0m[0;32m--> 186 [0;31m[0;32mdef[0m [0mnaive_const_fn[0m[0;34m([0m[0mgraph[0m[0;34m:[0m [0mjraph[0m[0;34m.[0m[0mGraphsTuple[0m[0;34m)[0m [0;34m->[0m [0mjraph[0m[0;34m.

[W 2023-11-30 21:21:14,547] Trial 3 failed with parameters: {'learning_rate': 0.0008429681477469999, 'n_blocks': 1, 'dropout_rate': 0.5280166331584394, 'edge_mlp_1_power': 1, 'edge_mlp_2_power': 4, 'node_mlp_1_power': 9} because of the following error: BdbQuit().
Traceback (most recent call last):
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/lorenzvenv/lib/python3.9/site-packages/optuna/study/_optimize.py", line 200, in _run_trial
    value_or_values = func(trial)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/hyperparam_tuning.py", line 65, in objective
    state, train_metrics, eval_metrics_dict = train_and_evaluate_with_data(config=config, workdir=workdir, datasets=datasets)
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_training.py", line 456, in train_and_evaluate_with_data
    state, metrics_update, _ = train_step_fn(
  File "/Users/h.lu/Documents/_code/_research lorenz code/lorenzGNN/utils/jraph_t