# Atmo Model Training Notebook

Train an Atmo model with the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import keras
import os, time
import pathlib
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml import dataset
from google.cloud import storage

import logging

# Set the root logger's level to INFO so informational records are not filtered out.
logging.getLogger().setLevel(logging.INFO)

# Cloud Storage bucket names for the feature chunks and label chunks.
# NOTE(review): neither name is referenced by the cells in this notebook section;
# confirm whether they are still needed.
data_bucket_name = "climateiq-study-area-feature-chunks"
label_bucket_name = "climateiq-study-area-label-chunks"

# Number of examples per batch for both the training and validation datasets.
batch_size = 2

# (study area, [simulation runs]) pairs. Uncomment a run to include it.
sim_dirs = [
    (
        "NYC_Heat_Test",
        [
            "NYC_summer_2000_01p",
            # 'NYC_summer_2010_99p',
            # 'NYC_summer_2015_50p',
            # 'NYC_summer_2017_25p',
            # 'NYC_summer_2018_75p'
        ],
    ),
    (
        "PHX_Heat_Test",
        [
            # 'PHX_summer_2008_25p',
            # 'PHX_summer_2009_50p',
            # 'PHX_summer_2011_99p',
            # 'PHX_summer_2015_75p',
            # 'PHX_summer_2020_01p'
        ],
    ),
]

# Directory handed to `dataset.load_dataset_cached` when building the datasets.
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")

# Flatten sim_dirs into "<study area>/<run>" names, preserving order.
sim_names = [
    f"{study_area}/{run}" for study_area, runs in sim_dirs for run in runs
]

print(sim_names)

In [None]:
# Initialize the Atmo Model from a default-constructed AtmoModelParams() and
# print its architecture. `expand_nested=True` presumably forwards to Keras
# `Model.summary` so nested sub-models are listed too — confirm in AtmoModel.
model_params = AtmoModelParams()
model = AtmoModel(model_params)
model.summary(expand_nested=True)

In [None]:
# Train the model.

# example_keys=None loads every cached example. To restrict training, set it to a
# list of keys like the commented-out single-example line below.
example_keys = None
# example_keys = [("NYC_Heat_Test/NYC_summer_2000_01p", "2000-05-25")]  # To test on single example
train_frac = 0.8

# Build the batched training split (hash_range [0, train_frac)) first, then the
# validation split (hash_range [train_frac, 1.0)).
# NOTE(review): presumably each example is assigned to a split by hashing its
# key into [0, 1) — confirm in `dataset.load_dataset_cached`.
train_ds, val_ds = (
    dataset.load_dataset_cached(
        filecache_dir, example_keys=example_keys, hash_range=split_range
    ).batch(batch_size=batch_size)
    for split_range in ((0, train_frac), (train_frac, 1.0))
)

# Each run logs to its own timestamped TensorBoard directory under ./logs.
run_stamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join("./logs", f"run_{run_stamp}")
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
print("Tensorboard log directory:", log_dir)

model.fit(train_ds, val_ds, epochs=150, callbacks=[tb_callback])

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Visually compare ground truth against the model's prediction for one
# validation example.

# Predictions over the whole validation set, using the underlying Keras model.
# NOTE(review): `predictions` is not used by this cell and costs a full pass over
# val_ds; it is kept only in case later cells read it.
predictions = model._model.predict(val_ds)

# Each val_ds element is assumed to be an (input_data, ground_truth) pair.
# take(1) yields at most one batch, so this body runs at most once.
for input_data, ground_truth in val_ds.take(1):
    predicted_labels = model._model.predict(input_data)

    # Slice out the panel to draw: sample 0, index 0 on axis 1, channel 0.
    # NOTE(review): assumes rank-5 tensors, presumably
    # (batch, time, height, width, channel) — confirm against the dataset.
    true_img = np.asarray(ground_truth[0, 0, :, :, 0])
    pred_img = np.asarray(predicted_labels[0, 0, :, :, 0])

    # One color range shared by BOTH panels, taken from the data actually shown,
    # so the same color means the same value in each image.
    vmin = float(min(true_img.min(), pred_img.min()))
    vmax = float(max(true_img.max(), pred_img.max()))

    fig, axes = plt.subplots(1, 2, figsize=(12, 6), dpi=150)  # Higher DPI for quality

    # Ground Truth Visualization
    img1 = axes[0].imshow(true_img, cmap="viridis", vmin=vmin, vmax=vmax)
    axes[0].set_title("Ground Truth")
    plt.colorbar(img1, ax=axes[0], fraction=0.046, pad=0.04)

    # Prediction Visualization
    img2 = axes[1].imshow(pred_img, cmap="viridis", vmin=vmin, vmax=vmax)
    axes[1].set_title("Predicted Labels")
    plt.colorbar(img2, ax=axes[1], fraction=0.046, pad=0.04)

    plt.tight_layout()
    plt.show()