# AtmoML Hyperparameter Tuning

In [None]:
import logging
import keras_tuner
import keras
import time
import pathlib

from usl_models.atmo_ml.model import AtmoModel
from usl_models.atmo_ml import dataset, visualizer, vars


# Notebook-wide setup: quiet the root logger, fix the random seed so runs are
# repeatable, and initialise the plotting helpers.
logging.getLogger().setLevel(logging.WARNING)
keras.utils.set_random_seed(812)
visualizer.init_plt()

batch_size = 8
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")

# Example keys are (name, "YYYY-MM-DD") tuples. They are assembled from the
# two names used in this notebook and the days selected for each, keeping the
# original ordering of the list.
_NYC_NAME = "NYC_Heat_Test/NYC_summer_2000_01p"
_PHX_NAME = "PHX_Heat_Test/PHX_summer_2008_25p"
_NYC_DAYS = (
    "2000-05-25",
    "2000-05-26",
    "2000-05-27",
    "2000-05-28",
    "2000-05-29",
    "2000-06-01",
    "2000-06-02",
    "2000-06-25",
    "2000-06-26",
    "2000-06-27",
    "2000-06-28",
    "2000-07-03",
    "2000-07-25",
    "2000-07-26",
    "2000-07-27",
    "2000-07-28",
    "2000-08-03",
    "2000-08-01",
    "2000-08-27",
    "2000-08-28",
)
_PHX_DAYS = (
    "2008-05-25",
    "2008-05-26",
    "2008-05-27",
    "2008-05-28",
)
example_keys = [(_NYC_NAME, day) for day in _NYC_DAYS]
example_keys += [(_PHX_NAME, day) for day in _PHX_DAYS]

# Shared by the tuner project name and the log directory in later cells.
timestamp = time.strftime("%Y%m%d-%H%M%S")

ds_config = dataset.Config(output_timesteps=2)


def _load_batched(**load_options):
    """Load the cached dataset for ``example_keys`` and batch it.

    Any keyword arguments are forwarded unchanged to
    ``dataset.load_dataset_cached``.
    """
    loaded = dataset.load_dataset_cached(
        filecache_dir,
        example_keys=example_keys,
        config=ds_config,
        **load_options,
    )
    return loaded.batch(batch_size=batch_size)


# NOTE(review): both datasets are built from the same example_keys, so the
# validation loss is measured on the training examples -- confirm this is
# intended before relying on val_loss for model selection.
train_ds = _load_batched()
val_ds = _load_batched(shuffle=False)

In [None]:
# Candidate values for every tunable hyperparameter. Single-element lists pin
# that parameter to one value; multi-element lists are explored by the tuner.
search_space = dict(
    input_cnn_kernel_size=[3, 5, 7],
    lstm_units=[32],
    lstm_kernel_size=[5],
    lstm_dropout=[0.2],
    lstm_recurrent_dropout=[0.3],
    conv1_stride=[1, 2, 3, 4, 5, 7],
    conv2_stride=[2, 5, 7],
    convlstm_stride=[1, 5],
    spatial_filters=[64],
    spatiotemporal_filters=[64],
    spatial_activation=["relu"],
    st_activation=["relu"],
    lstm_activation=["relu"],
    output_activation=["tanh"],
)
hypermodel = AtmoModel.get_hypermodel(**search_space)

# Bayesian optimisation over the search space, minimising validation loss and
# capped at 10 trials.
tuner = keras_tuner.BayesianOptimization(
    hypermodel,
    objective="val_loss",
    max_trials=10,
    project_name=f"logs/htune_project_{timestamp}",
)

tuner.search_space_summary()

In [None]:
# Run the hyperparameter search, logging to a timestamped TensorBoard directory.
log_dir = f"logs/htune_{timestamp}"
print(log_dir)

tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
tuner.search(
    train_ds,
    epochs=50,
    validation_data=val_ds,
    callbacks=[tb_callback],
)

# Keep the top-ranked model and the hyperparameters that produced it.
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]

# Last expression of the cell: displays the winning hyperparameter values.
best_hp.values

In [None]:
# Wrap the best model found by the tuner, continue training it for longer
# (validating every 10th epoch), then save it under the log directory.
model = AtmoModel(model=best_model)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
model.fit(
    train_ds,
    val_ds,
    epochs=1000,
    callbacks=[tb_callback],
    validation_freq=10,
)
model_path = f"{log_dir}/model"
model.save_model(model_path)

In [None]:
# Plot predictions against labels for the first validation batch.
# To plot from a previously saved model instead, uncomment:
# model = AtmoModel.from_checkpoint(log_dir + "/model")
input_batch, label_batch = next(iter(val_ds))
pred_batch = model.call(input_batch)

# max_examples=None is passed through as-is (presumably "no limit" -- confirm
# against visualizer.plot_batch).
figures = visualizer.plot_batch(
    input_batch=input_batch,
    label_batch=label_batch,
    pred_batch=pred_batch,
    st_var=vars.Spatiotemporal.TT,
    sto_var=vars.SpatiotemporalOutput.T2,
    max_examples=None,
)
for fig in figures:
    fig.show()