# Atmo Model Training Notebook

Train an AtmoModel using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import logging
import os, time
import pathlib

import keras

from usl_models.atmo_ml.model import AtmoModel
from usl_models.atmo_ml import dataset, visualizer, vars


logging.getLogger().setLevel(logging.WARNING)
visualizer.init_plt()


batch_size = 4

filecache_path = pathlib.Path("/home/shared/climateiq/filecache")
# Keys of the cached examples to load; presumably (simulation path, date)
# pairs — TODO confirm against dataset.load_dataset_cached.
example_keys = [
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-25"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-26"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-27"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-28"),
]

# Split the keys so validation metrics are computed on examples the model
# never trains on. The split is contiguous (no shuffling), so the validation
# set is the tail of example_keys; keep at least one training example.
train_frac = 0.8
n_train = max(1, int(len(example_keys) * train_frac))
train_keys = example_keys[:n_train]
val_keys = example_keys[n_train:]

train_ds = dataset.load_dataset_cached(
    filecache_path,
    example_keys=train_keys,
).batch(batch_size=batch_size)
val_ds = dataset.load_dataset_cached(
    filecache_path,
    example_keys=val_keys,
).batch(batch_size=batch_size)

In [None]:
# Build a fresh AtmoModel and print its architecture, expanding nested layers.
model = AtmoModel()
model.summary(expand_nested=True)

In [None]:
# Train the model.
# Each run logs to its own timestamped directory so TensorBoard runs from
# repeated executions of this cell do not overwrite each other.
log_dir = os.path.join("./logs", "run_" + time.strftime("%Y%m%d-%H%M%S"))
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

epochs = 2
# Validate every 10 epochs on long runs, but never less often than once per
# run. Assuming AtmoModel.fit forwards validation_freq to keras.Model.fit,
# a validation_freq larger than epochs would skip validation entirely.
validation_freq = min(10, epochs)
model.fit(
    train_ds,
    val_ds,
    epochs=epochs,
    callbacks=[tb_callback],
    validation_freq=validation_freq,
)
model.save_model(os.path.join(log_dir, "model"))

In [None]:
# Plot predictions next to labels for every validation example, showing one
# spatiotemporal input variable and one spatiotemporal output variable.

st_var = vars.Spatiotemporal.TT
sto_var = vars.SpatiotemporalOutput.T2

# To reuse a previously saved model instead of the one trained above:
# model = AtmoModel.from_checkpoint(log_dir + "/model")

for inputs, labels in val_ds:
    preds = model.call(inputs)
    for idx, _ in enumerate(labels):
        # Slice out example `idx` from every named input tensor in the batch.
        example_inputs = {name: tensor[idx] for name, tensor in inputs.items()}
        figures = visualizer.plot(
            inputs=example_inputs,
            label=labels[idx],
            pred=preds[idx],
            st_var=st_var,
            sto_var=sto_var,
        )
        for figure in figures:
            figure.show()