# Atmo Model Training Notebook

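Train an Atmo model using the `usl_models` library. The notebook:
- loads feature/label chunks from GCS;
- splits them into train/validation/test sets;
- fits the model;
- runs a quick prediction smoke test.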

In [11]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml import dataset
from google.cloud import storage

# Configuration: data locations, simulations to load, and dataset settings.
# Example source location (consistent with `sim_names` below):
#   climateiq-study-area-feature-chunks/NYC_Heat_Test/NYC_summer_2000_01p
data_bucket_name = "climateiq-study-area-feature-chunks"
label_bucket_name = "climateiq-atmospheric-label-chunks"
# Simulation paths (relative to the buckets above) to include in the dataset.
sim_names = ['NYC_Heat_Test/NYC_summer_2000_01p']
# Forwarded to dataset.load_dataset.
time_steps_per_day = 6
batch_size = 4

# Seed TensorFlow's global RNG so TF random ops (e.g. weight initialization)
# are reproducible across Restart & Run All.
seed = 42
tf.random.set_seed(seed)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# Build the full dataset for `sim_names` from the feature and label buckets
# configured above, using dataset.load_dataset.
ds = dataset.load_dataset(
    data_bucket_name=data_bucket_name,
    label_bucket_name=label_bucket_name,
    sim_names=sim_names,
    time_steps_per_day=time_steps_per_day,
    batch_size=batch_size,
)

# Split into train / validation / test subsets. Sanity-check that the
# fractions cover the whole dataset before splitting.
train_frac, val_frac, test_frac = 0.7, 0.15, 0.15
assert abs(train_frac + val_frac + test_frac - 1.0) < 1e-9, "split fractions must sum to 1"
train_dataset, val_dataset, test_dataset = dataset.split_dataset(
    ds, train_frac=train_frac, val_frac=val_frac, test_frac=test_frac
)

In [None]:
# Peek at a single training batch and report each tensor's shape, to confirm
# what the model will be fed before training starts.
for model_input, labels in train_dataset.take(1):
    spatiotemporal_shape = model_input['spatiotemporal'].shape
    print(f"Spatiotemporal data shape: {spatiotemporal_shape}")

    spatial_shape = model_input['spatial'].shape
    print(f"Spatial data shape: {spatial_shape}")

    lu_index_shape = model_input['lu_index'].shape
    print(f"LU Index data shape: {lu_index_shape}")

    label_shape = labels.shape
    print(f"Labels shape: {label_shape}")

In [None]:
# Build the Atmo model from AtmoModelParams() constructed with no arguments
# (i.e. its default settings).
model_params = AtmoModelParams()
model = AtmoModel(model_params)

# Configure the training objective and optimizer.
model.compile(
    loss='mse',
    optimizer='adam',
)

In [None]:
# Fit the model on the training split.
# NOTE(review): `val_dataset` is built by the dataset cell but never used here.
# Consider passing it once the validation keyword accepted by AtmoModel.fit is
# confirmed.
num_epochs = 1
model.fit(train_dataset, epochs=num_epochs)

In [None]:
# Smoke test: run the trained model on the first training batch and report the
# shape of its output.
# NOTE(review): if AtmoModel is a keras.Model subclass, prefer
# `model(inputs, training=False)` over calling `.call()` directly.
batch_iterator = iter(train_dataset)
inputs, labels = next(batch_iterator)
prediction = model.call(inputs)
print("Prediction shape:", prediction.shape)