# Atmo Model Training Notebook

Train an Atmo Model using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import keras
import os, time
import pathlib
from usl_models.atmo_ml.model import AtmoModel
from usl_models.atmo_ml import dataset, visualizer, vars

import logging

# Quiet root-logger output, seed the Python/NumPy/backend RNGs through Keras
# for reproducible runs, and initialise the plotting helpers.
logging.getLogger().setLevel(logging.WARNING)
keras.utils.set_random_seed(812)
visualizer.init_plt()

batch_size = 8
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")

# Every key shares the same dataset prefix; one entry per date.
_sim_run = "NYC_Heat_Test/NYC_summer_2000_01p"
example_keys = [
    (_sim_run, day)
    for day in ("2000-05-25", "2000-05-26", "2000-05-27", "2000-05-28")
]

# Dataset configuration shared by the training and validation pipelines.
ds_config = dataset.Config(output_timesteps=1)


def _load_batched(keys):
    """Load the cached examples named by *keys* and group them into batches."""
    return dataset.load_dataset_cached(
        filecache_dir,
        example_keys=keys,
        config=ds_config,
    ).batch(batch_size=batch_size)


# Create training dataset with fused spatiotemporal data.
train_ds = _load_batched(example_keys)
# NOTE(review): validation reuses the training keys, so validation metrics are
# not a held-out evaluation -- consider splitting example_keys.
val_ds = _load_batched(example_keys)

In [None]:
# Initialize the Atmo Model: start from the library defaults and override the
# output horizon (kept in sync with the dataset config) and the LSTM width.
overrides = {"output_timesteps": ds_config.output_timesteps, "lstm_units": 64}
params = AtmoModel.default_params()
params.update(overrides)
model = AtmoModel(params)
model.summary(expand_nested=True)

In [None]:
# Train the model.
# Training knobs, pulled out so they can be tweaked in one place.
epochs = 100
validation_freq = 10  # forwarded to AtmoModel.fit as-is

# Create a unique log directory by appending the current (local) timestamp,
# so successive runs don't overwrite each other's TensorBoard logs.
log_dir = os.path.join("./logs", "run_" + time.strftime("%Y%m%d-%H%M%S"))
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
model.fit(
    train_ds,
    val_ds,
    epochs=epochs,
    callbacks=[tb_callback],
    validation_freq=validation_freq,
)
# Save alongside the TensorBoard logs for this run; os.path.join keeps path
# construction consistent with how log_dir itself is built.
model.save_model(os.path.join(log_dir, "model"))

In [None]:
# Plot label vs. prediction for the first example(s) of one validation batch.
# The previous version looped over the batch but broke after the first
# iteration; the count is now explicit. 1 reproduces that behavior; raise it to
# inspect more examples.
num_examples_to_plot = 1

for input_batch, label_batch in val_ds.take(1):
    preds = model.call(input_batch)
    # zip() stops at whichever is shorter: the requested count or the batch,
    # so small (or empty) final batches are handled without indexing errors.
    for b, _ in zip(range(num_examples_to_plot), label_batch):
        figs = visualizer.plot(
            inputs={k: v[b] for k, v in input_batch.items()},
            label=label_batch[b],
            pred=preds[b],
            st_var=vars.Spatiotemporal.RH,
            sto_var=vars.SpatiotemporalOutput.RH2,
        )
        for fig in figs:
            fig.show()