# Atmo Model Training Notebook

Train an Atmo model using the `usl_models` library.

In [103]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import keras
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml import dataset
from google.cloud import storage

import logging

# Let INFO-level records from the root logger through (the data loader logs
# one record per loaded day at this level).
logging.getLogger().setLevel(logging.INFO)

# Buckets holding the feature chunks and the matching label chunks.
# Example feature location:
#   climateiq-study-area-feature-chunks/NYC_Heat/NYC_summer_2000_01p
data_bucket_name = "climateiq-study-area-feature-chunks"
label_bucket_name = "climateiq-study-area-label-chunks"
time_steps_per_day = 6
batch_size = 25

# (study-area directory, [simulation sub-directories]) pairs to train on.
# Commented-out entries are further simulations that can be switched on.
sim_dirs = [
    ('NYC_Heat_Test', [
        'NYC_summer_2000_01p',
        # 'NYC_summer_2010_99p',
        # 'NYC_summer_2015_50p',
        # 'NYC_summer_2017_25p',
        # 'NYC_summer_2018_75p'
    ]),
    # ('PHX_Heat_Test', [
    #     'PHX_summer_2008_25p',
    #     # 'PHX_summer_2009_50p',
    #     # 'PHX_summer_2011_99p',
    #     # 'PHX_summer_2015_75p',
    #     # 'PHX_summer_2020_01p'
    # ])
]

# Flatten the pairs into "<study area>/<simulation>" paths, keeping the
# order in which they are listed above.
sim_names = [
    f"{area_dir}/{sim_subdir}"
    for area_dir, sim_subdirs in sim_dirs
    for sim_subdir in sim_subdirs
]

print(sim_names)
# google.cloud.storage client bound to the "climateiq" project.
# NOTE(review): no later cell visible in this notebook references `client`;
# confirm whether it is still needed.
client = storage.Client(project="climateiq")


In [104]:
# Build the training and validation datasets with dataset.load_dataset.
# Both read the same buckets and simulations; they differ only in the
# hash_range they request: (0.0, train_frac) for training and
# (train_frac, 1.0) for validation.
# NOTE(review): presumably load_dataset assigns each simulation day to a
# split by hashing it into [0, 1] -- confirm in usl_models.atmo_ml.dataset.
train_frac = 0.8

shared_load_kwargs = dict(
    data_bucket_name=data_bucket_name,
    label_bucket_name=label_bucket_name,
    sim_names=sim_names,
)

train_ds = dataset.load_dataset(
    hash_range=(0.0, train_frac), **shared_load_kwargs
).batch(batch_size=batch_size)

val_ds = dataset.load_dataset(
    hash_range=(train_frac, 1.0), **shared_load_kwargs
).batch(batch_size=batch_size)

INFO:root:sim_name_dates [('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-24'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-27'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-12'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-06-14'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-21'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-22'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-01'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-05'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-06-04'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-10'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-06'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-26'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-06-17'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-03'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-08-19'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-28'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-15'), ('NYC_Heat_Test/NYC_summer_2000_01p', '2000-07-18'), ('NYC_Heat_Test/NYC_

In [89]:
# Count the training samples: add up the leading (batch) dimension of the
# 'spatiotemporal' input across every batch. This iterates the whole dataset.
num_samples = sum(
    batch_elements[0]['spatiotemporal'].shape[0]
    for batch_elements in train_ds
)
print("Number of samples:", num_samples)

INFO:root:load_day (2000-08-28, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-16, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-09, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-24, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-28, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-21, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-01, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-07, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-22, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-22, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-18, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-08, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-12, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-23, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-24, NYC_Heat_Test/NYC_summer_2000_

In [90]:
# Count the validation samples: add up the leading (batch) dimension of the
# 'spatiotemporal' input across every batch. This iterates the whole dataset.
num_samples = sum(
    batch_elements[0]['spatiotemporal'].shape[0]
    for batch_elements in val_ds
)
print("Number of samples:", num_samples)

INFO:root:load_day (2000-05-25, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-02, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-15, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-05-26, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-20, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-17, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-18, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-27, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-29, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-21, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-10, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-20, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-31, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-15, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-05-30, NYC_Heat_Test/NYC_summer_2000_

In [105]:
# Initialize the Atmo Model
# AtmoModelParams() is constructed with no arguments, so the model is built
# with whatever defaults that class defines.
model_params = AtmoModelParams()
model = AtmoModel(model_params)

In [106]:
import sys
# Set up logging to a file
logging.basicConfig(filename="training_log.txt", level=logging.INFO)
sys.stdout = open("training_log.txt", "w")  # Redirect stdout to file

In [107]:
# Train the model for 15 epochs, writing TensorBoard event files to ./logs.
tb_callback = keras.callbacks.TensorBoard(log_dir="./logs")
# The training split (hash range 0.0-0.8) is passed first and the validation
# split second, so the model is fitted on train_ds and evaluated on val_ds.
# NOTE(review): assumes AtmoModel.fit takes (train_dataset, val_dataset, ...)
# positionally -- confirm against usl_models.atmo_ml.model.
model.fit(train_ds, val_ds, epochs=15, callbacks=[tb_callback])




2024-12-16 20:31:28.989633: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inatmo_conv_lstm_13/conv_lstm/conv_lstm2d_13/while/body/_1/atmo_conv_lstm_13/conv_lstm/conv_lstm2d_13/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
INFO:root:load_day (2000-06-04, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-01, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-30, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-18, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-26, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-07-14, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-05-28, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-23, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-06-01, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:load_day (2000-08-27, NYC_Heat_T

KeyboardInterrupt: 

In [None]:
class PrintMetricsCallback(keras.callbacks.Callback):
    """Print a one-line summary of the loss/accuracy metrics after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        """Print the metrics recorded for the epoch that just finished.

        Args:
            epoch: Zero-based epoch index; printed one-based.
            logs: Mapping of metric name to value for the epoch. May be None
                or may lack some of the keys (for example no 'val_*' entries
                when training without validation data, or no 'accuracy' when
                that metric is not compiled in); missing values print as None
                instead of raising.
        """
        # The signature allows logs=None, so never subscript it directly.
        logs = logs or {}
        print(f"Epoch {epoch + 1}: Training loss = {logs.get('loss')}, "
              f"Validation loss = {logs.get('val_loss')}, "
              f"Accuracy = {logs.get('accuracy')}, "
              f"Validation accuracy = {logs.get('val_accuracy')}")

In [None]:
# Pull the first batch from the training dataset and show the shape of each
# input tensor, keyed by input name (the dict is the cell's displayed value).
inputs, labels = next(iter(train_ds))
{input_name: input_tensor.shape for input_name, input_tensor in inputs.items()}

{'spatiotemporal': TensorShape([2, 6, 200, 200, 12]),
 'spatial': TensorShape([2, 200, 200, 22]),
 'lu_index': TensorShape([2, 200, 200])}

In [None]:
# Print the layer-by-layer summary of the model stored on AtmoModel's
# `_model` attribute.
# NOTE(review): `_model` is a private attribute (leading underscore) --
# confirm there is no public accessor that should be used instead.
model._model.summary()

Model: "atmo_conv_lstm"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding (Embedding)       multiple                  488       
                                                                 
 spatial_cnn (Sequential)    (None, 50, 50, 128)       252992    
                                                                 
 spatiotemporal_cnn (Sequen  (None, None, 50, 50, 64   30480     
 tial)                       )                                   
                                                                 
 conv_lstm (Sequential)      (None, None, 50, 50, 51   45877248  
                             2)                                  
                                                                 
 t2_output_cnn (Sequential)  (None, None, 200, 200,    69729     
                             1)                                  
                                                    

In [None]:
# Test calling the model on some training data
# One batch is pulled from train_ds; `labels` is unpacked but not used here.
inputs, labels = next(iter(train_ds))
# NOTE(review): this invokes `call` directly rather than `model(inputs)` --
# confirm that is the intended entry point on AtmoModel.
prediction = model.call(inputs)
print("Prediction shape:", prediction.shape)

lu_index_input (2, 200, 200)
lu_index_embedded_flat tf.Tensor(
[[[-0.0290212  -0.00429965  0.0070921  ...  0.04053745 -0.02890884
    0.00442706]
  [ 0.04117144  0.03346105  0.00490087 ...  0.01888017  0.02810446
    0.0290962 ]
  [ 0.01980658  0.01989866 -0.02752843 ...  0.02466523  0.03207915
    0.01597965]
  ...
  [ 0.04117144  0.03346105  0.00490087 ...  0.01888017  0.02810446
    0.0290962 ]
  [-0.00921012  0.02187858 -0.02844076 ...  0.04788701  0.01579146
   -0.04698772]
  [-0.00921012  0.02187858 -0.02844076 ...  0.04788701  0.01579146
   -0.04698772]]

 [[-0.0290212  -0.00429965  0.0070921  ...  0.04053745 -0.02890884
    0.00442706]
  [ 0.04117144  0.03346105  0.00490087 ...  0.01888017  0.02810446
    0.0290962 ]
  [ 0.01980658  0.01989866 -0.02752843 ...  0.02466523  0.03207915
    0.01597965]
  ...
  [ 0.04117144  0.03346105  0.00490087 ...  0.01888017  0.02810446
    0.0290962 ]
  [-0.00921012  0.02187858 -0.02844076 ...  0.04788701  0.01579146
   -0.04698772]
  [-0.0092