# AtmoML Hyperparameter Tuning

In [10]:
import logging
import keras_tuner
import keras
import time
import pathlib

from usl_models.atmo_ml.model import AtmoModel
from usl_models.atmo_ml import dataset, visualizer, vars


# --- Global notebook setup -------------------------------------------------
# Only let WARNING-and-above records through the root logger.
logging.getLogger().setLevel(logging.WARNING)
# Seed before any dataset/model construction so runs are repeatable.
keras.utils.set_random_seed(812)
# Project helper; presumably applies the shared matplotlib settings -- see
# usl_models.atmo_ml.visualizer for what it actually configures.
visualizer.init_plt()

# --- Tunable constants -----------------------------------------------------
batch_size = 8
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")

# (simulation name, date) pairs: May 25-28 for one NYC and one PHX simulation,
# listed simulation by simulation in ascending date order.
example_keys = [
    (sim_name, f"{year_month}-{day}")
    for sim_name, year_month in (
        ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05"),
        ("PHX_Heat_Test/PHX_summer_2008_25p", "2008-05"),
    )
    for day in range(25, 29)
]

# Captured once here so later cells derive their log directories from the
# same run identifier.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# --- Datasets ---------------------------------------------------------------
ds_config = dataset.Config(output_timesteps=2)
# Arguments shared by the training and validation loaders.
load_kwargs = dict(example_keys=example_keys, config=ds_config)

train_ds = dataset.load_dataset_cached(filecache_dir, **load_kwargs).batch(
    batch_size=batch_size
)
# NOTE(review): validation is built from the same example_keys as training
# (only shuffle differs), so any validation metric is computed on examples the
# model was fitted on -- confirm this is intended.
val_ds = dataset.load_dataset_cached(
    filecache_dir, shuffle=False, **load_kwargs
).batch(batch_size=batch_size)

In [11]:
# Hypermodel exposing the two kernel-size choices to the tuner. With these
# candidate lists there are three distinct combinations (3 x 1).
hypermodel = AtmoModel.get_hypermodel(
    input_cnn_kernel_size=[1, 2, 5],
    lstm_kernel_size=[5],
)

# Bayesian search that minimises the validation loss; results are stored under
# a timestamped project name so separate runs do not collide.
tuner_options = dict(
    objective="val_loss",
    max_trials=10,
    project_name=f"logs/htune_project_{timestamp}",
)
tuner = keras_tuner.BayesianOptimization(hypermodel, **tuner_options)

# Print the hyperparameters the tuner is going to explore.
tuner.search_space_summary()

Search space summary
Default search space size: 2
input_cnn_kernel_size (Choice)
{'default': 1, 'conditions': [], 'values': [1, 2, 5], 'ordered': True}
lstm_kernel_size (Choice)
{'default': 5, 'conditions': [], 'values': [5], 'ordered': True}


In [None]:
# TensorBoard output for this run goes to a timestamped directory; the path is
# printed so it can be copied into a `tensorboard --logdir` invocation.
log_dir = f"logs/htune_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Run the hyperparameter search: each trial trains for up to 100 epochs and is
# scored on val_ds.
tuner.search(
    train_ds,
    epochs=100,
    validation_data=val_ds,
    callbacks=[tb_callback],
)

# Keep the top-ranked model and its hyperparameters for the following cells.
best_model = tuner.get_best_models(num_models=1)[0]
best_hp = tuner.get_best_hyperparameters(num_trials=1)[0]
best_hp.values

Trial 2 Complete [00h 02m 13s]
val_loss: 0.02264540083706379

Best val_loss So Far: 0.0167386494576931
Total elapsed time: 00h 04m 25s

Search: Running Trial #3

Value             |Best Value So Far |Hyperparameter
1                 |5                 |input_cnn_kernel_size
5                 |5                 |lstm_kernel_size

Epoch 1/100


2025-03-02 02:14:15.051968: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inatmo_conv_lstm/conv_lstm/conv_lstm2d/while/body/_1/atmo_conv_lstm/conv_lstm/conv_lstm2d/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


      1/Unknown - 11s 11s/step - loss: 0.1514 - mean_absolute_error: 0.2943 - root_mean_squared_error: 0.3891 - mean_absolute_percentage_error: 1594.8286 - nrmse: 0.3891 - ssim_metric: 0.1150 - psnr_metric: 8.3367 - mse_RH2: 0.1815 - mse_T2: 0.1617 - mse_WSPD_WDIR10: 0.0016 - mse_WSPD_WDIR10_COS: 0.1996 - mse_WSPD_WDIR10_SIN: 0.2125

In [None]:
# Continue training the winning model from the search, then persist it.
model = AtmoModel(model=best_model)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# validation_freq is forwarded to AtmoModel.fit; presumably validation runs
# every 10th epoch (Keras semantics) -- confirm against the wrapper.
model.fit(
    train_ds,
    val_ds,
    epochs=200,
    callbacks=[tb_callback],
    validation_freq=10,
)

# The checkpoint lives next to the TensorBoard logs of this run.
model_dir = f"{log_dir}/model"
model.save_model(model_dir)

Epoch 1/10


2025-03-02 02:07:23.524506: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inatmo_conv_lstm/conv_lstm/conv_lstm2d/while/body/_1/atmo_conv_lstm/conv_lstm/conv_lstm2d/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<keras.src.callbacks.History at 0x7fbfb872b750>

In [None]:
# Plot results: reload the saved checkpoint and visualise predictions for the
# first validation batch.
model = AtmoModel.from_checkpoint(f"{log_dir}/model")

# val_ds was built with shuffle=False; only its first batch is used here.
input_batch, label_batch = next(iter(val_ds))
pred_batch = model.call(input_batch)

# Compare the RH input variable against the RH2 output variable; passing
# max_examples=None presumably plots every example in the batch -- see
# visualizer.plot_batch.
figures = visualizer.plot_batch(
    input_batch=input_batch,
    label_batch=label_batch,
    pred_batch=pred_batch,
    st_var=vars.Spatiotemporal.RH,
    sto_var=vars.SpatiotemporalOutput.RH2,
    max_examples=None,
)
for fig in figures:
    fig.show()