# Atmo Model Training Notebook

Train an Atmo Model using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import keras
import os, time
import pathlib
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml import dataset, visualizer, vars

import logging

# Surface INFO-level messages from the root logger in the notebook output.
logging.getLogger().setLevel(logging.INFO)

# Example simulation location (presumably a storage bucket path — confirm):
# climateiq-study-area-feature-chunks/NYC_Heat/NYC_summer_2000_01p
time_steps_per_day = 6
batch_size = 2

# (study-area directory, [simulation sub-directories]) pairs.
# Commented-out entries are simulations that are currently switched off;
# uncomment a line to include that simulation in `sim_names`.
sim_dirs = [
    (
        "NYC_Heat_Test",
        [
            "NYC_summer_2000_01p",
            # 'NYC_summer_2010_99p',
            # 'NYC_summer_2015_50p',
            # 'NYC_summer_2017_25p',
            # 'NYC_summer_2018_75p'
        ],
    ),
    (
        "PHX_Heat_Test",
        [
            # 'PHX_summer_2008_25p',
            # 'PHX_summer_2009_50p',
            # 'PHX_summer_2011_99p',
            # 'PHX_summer_2015_75p',
            # 'PHX_summer_2020_01p'
        ],
    ),
]

# Flatten the pairs into "<study area>/<simulation>" names. A study area
# whose sub-directory list is empty contributes nothing.
sim_names = [
    f"{sim_dir}/{subdir}"
    for sim_dir, subdirs in sim_dirs
    for subdir in subdirs
]

print(sim_names)


# Output variable(s) the model is trained to predict.
output_vars = [
    vars.SpatiotemporalOutput.RH2,
]

# Root directory of the locally cached examples.
filecache_path = pathlib.Path("/home/shared/climateiq/filecache")

# (simulation name, date) pairs selecting which cached examples to load.
# They are listed in date order, so the split below holds out the latest
# day(s) for validation.
example_keys = [
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-25"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-26"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-27"),
    ("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-28"),
]

# Fraction of `example_keys` used for training; the rest is used for
# validation. Previously this value was unused and both datasets were built
# from the full key list, so validation metrics were computed on training
# data.
train_frac = 0.8
num_train = int(len(example_keys) * train_frac)
train_keys = example_keys[:num_train]
val_keys = example_keys[num_train:]
if not train_keys or not val_keys:
    raise ValueError(
        f"train_frac={train_frac} leaves an empty train or validation split "
        f"for {len(example_keys)} example key(s)"
    )

train_ds = dataset.load_dataset_cached(
    filecache_path,
    output_vars=output_vars,
    example_keys=train_keys,
).batch(batch_size=batch_size)
val_ds = dataset.load_dataset_cached(
    filecache_path,
    output_vars=output_vars,
    example_keys=val_keys,
).batch(batch_size=batch_size)

In [None]:
# Initialize the Atmo Model
# The parameters are built from `output_vars`, the list of output variables
# chosen when the datasets were loaded.
model_params = AtmoModelParams(output_vars=output_vars)
model = AtmoModel(model_params)
# Print the model architecture. `expand_nested=True` presumably also lists
# the layers of nested sub-models, as in Keras `Model.summary` — confirm.
model.summary(expand_nested=True)

In [None]:
# Train the model.
# Every run writes TensorBoard logs to its own directory, named after the
# time the cell was executed, so earlier runs are not overwritten.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join("./logs", f"run_{run_stamp}")
print(log_dir)

tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# `validation_freq=10` with `epochs=50`: presumably validation runs every
# 10th epoch, as in Keras `Model.fit` — confirm against `AtmoModel.fit`.
model.fit(
    train_ds,
    val_ds,
    epochs=50,
    callbacks=[tb_callback],
    validation_freq=10,
)

In [None]:
visualizer.init_plt()

# Plot inputs, labels and predictions for every example in the validation
# dataset, one example at a time.
for input_batch, label_batch in val_ds:
    preds = model.call(input_batch)
    for b, label in enumerate(label_batch):
        # Select example `b` from every input tensor in the batch.
        example_inputs = {name: tensor[b] for name, tensor in input_batch.items()}
        # Indexing dropped the leading axis of the label and the prediction;
        # `expand_dims` adds it back with size 1 (presumably the batch axis
        # `visualizer.plot` expects — confirm).
        figs = visualizer.plot(
            inputs=example_inputs,
            label=tf.expand_dims(label, axis=0),
            pred=tf.expand_dims(preds[b], axis=0),
            st_var=vars.Spatiotemporal.RH,
            sto_var=vars.SpatiotemporalOutput.RH2,
        )
        for fig in figs:
            fig.show()