# Atmo Model Training Notebook

Train an Atmo Model using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml.datasets import create_atmo_dataset

# Dataset loading settings.
batch_size = 4
time_steps_per_day = 6

# Buckets: one named for simulation input, one for simulation output.
data_bucket_name = "climateiq-atmospheric-simulation-input"
label_bucket_name = "climateiq-atmospheric-simulation-output"

# Folder names, one per kind of data.
(
    spatiotemporal_folder,
    spatial_folder,
    lu_index_folder,
    label_folder,
) = ("spatiotemporal", "spatial", "lu_index", "labels")

In [None]:

# Build the datasets from the bucket and folder settings defined above.
# The result is unpacked into three datasets; presumably these are the
# train/validation/test splits — confirm against create_atmo_dataset.
# NOTE(review): no simulation-name filter is passed here, so the loader's
# default selection applies — confirm that is intended.
train_dataset, val_dataset, test_dataset = create_atmo_dataset(
    data_bucket_name=data_bucket_name,
    label_bucket_name=label_bucket_name,
    spatiotemporal_folder=spatiotemporal_folder,
    spatial_folder=spatial_folder,
    lu_index_folder=lu_index_folder,
    label_folder=label_folder,
    time_steps_per_day=time_steps_per_day,
    batch_size=batch_size,

)

In [None]:
# Peek at one batch of the training dataset and report the shape of each
# input entry (looked up by key) followed by the shape of the labels.
for model_input, labels in train_dataset.take(1):
    for title, key in (
        ("Spatiotemporal", "spatiotemporal"),
        ("Spatial", "spatial"),
        ("LU Index", "lu_index"),
    ):
        print(f"{title} data shape: {model_input[key].shape}")
    print(f"Labels shape: {labels.shape}")

In [None]:
# Construct the Atmo model from a default-constructed parameter object.
model_params = AtmoModelParams()
model = AtmoModel(model_params)

# Configure training: 'mse' as the loss and 'adam' as the optimizer.
model.compile(loss="mse", optimizer="adam")

In [None]:
# Train the model on the training dataset for a single epoch.
# NOTE(review): val_dataset and test_dataset are created earlier in this
# notebook but are not passed here — confirm whether validation is intended.
model.fit(train_dataset, epochs=1)

In [None]:
# Smoke-test inference: pull the first (inputs, labels) batch from the
# training dataset, run the inputs through the model, and print the shape of
# the result. `labels` is unpacked but not used in this cell.
inputs, labels = next(iter(train_dataset))
# NOTE(review): this invokes `call` directly rather than `model(inputs)`;
# if AtmoModel is a Keras model, `model(inputs)` is the usual entry point —
# confirm which is intended.
prediction = model.call(inputs)
print("Prediction shape:", prediction.shape)