# Atmo Model Training Notebook

Train an Atmo model using the `usl_models` library.

In [7]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import keras
from usl_models.atmo_ml.model import AtmoModel, AtmoModelParams
from usl_models.atmo_ml import dataset
from google.cloud import storage

import logging

# Let INFO-level records through the root logger.
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)

# Example feature location:
#   climateiq-study-area-feature-chunks/NYC_Heat/NYC_summer_2000_01p
# Bucket names and batching configuration.
data_bucket_name = "climateiq-study-area-feature-chunks"
label_bucket_name = "climateiq-study-area-label-chunks"
time_steps_per_day = 6
batch_size = 16

# (study-area directory, [simulation sub-directories]) pairs.
# Commented-out entries are further simulations that can be switched on.
sim_dirs = [
    (
        "NYC_Heat_Test",
        [
            "NYC_summer_2000_01p",
            # "NYC_summer_2010_99p",
            # "NYC_summer_2015_50p",
            # "NYC_summer_2017_25p",
            # "NYC_summer_2018_75p",
        ],
    ),
    (
        "PHX_Heat_Test",
        [
            "PHX_summer_2008_25p",
            # "PHX_summer_2009_50p",
            # "PHX_summer_2011_99p",
            # "PHX_summer_2015_75p",
            # "PHX_summer_2020_01p",
        ],
    ),
]

# Flatten into "<study area dir>/<simulation>" names, keeping declaration order.
sim_names = [
    f"{area_dir}/{sim}"
    for area_dir, sims in sim_dirs
    for sim in sims
]

print(sim_names)
# Google Cloud Storage client for the "climateiq" project.
# NOTE(review): no other cell visible in this notebook reads `client` — confirm
# whether it is still needed.
client = storage.Client(project="climateiq")


In [3]:
train_frac = 0.8

# Arguments shared by both fake datasets. No hash-range split is applied here,
# so the training and validation datasets are built from identical calls.
fake_dataset_args = dict(
    data_bucket_name=data_bucket_name,
    label_bucket_name=label_bucket_name,
    sim_names=sim_names,
)

# Batched fake training dataset.
train_ds = dataset.load_fake_dataset(**fake_dataset_args).batch(
    batch_size=batch_size
)

# Batched fake validation dataset.
val_ds = dataset.load_fake_dataset(**fake_dataset_args).batch(
    batch_size=batch_size
)


INFO:root:Total simulation days before filtering: 200
2024-12-18 20:43:07.865107: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38364 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0
INFO:root:load_day (2000-05-24, NYC_Heat_Test/NYC_summer_2000_01p)
INFO:root:Total simulation days before filtering: 200
INFO:root:load_day (2000-05-24, NYC_Heat_Test/NYC_summer_2000_01p)


In [None]:
train_frac = 0.8

# Arguments common to the training and validation datasets; only the hash
# range differs between the two.
shared_dataset_args = dict(
    data_bucket_name=data_bucket_name,
    label_bucket_name=label_bucket_name,
    sim_names=sim_names,
)

# Training split: hash range [0.0, train_frac).
# NOTE(review): bound inclusivity is decided inside `load_dataset` — confirm.
train_split = dataset.load_dataset(
    hash_range=(0.0, train_frac),
    **shared_dataset_args,
)
train_ds = train_split.batch(batch_size=batch_size)

# Validation split: the remaining hash range from train_frac up to 1.0.
val_split = dataset.load_dataset(
    hash_range=(train_frac, 1.0),
    **shared_dataset_args,
)
val_ds = val_split.batch(batch_size=batch_size)


In [None]:
# Count training samples by adding up the leading dimension of the
# 'spatiotemporal' input of every batch (a full pass over train_ds).
num_samples = sum(
    batch[0]['spatiotemporal'].shape[0] for batch in train_ds
)
print("Number of samples:", num_samples)

In [None]:
# Same count for the validation set (a full pass over val_ds).
num_samples = sum(
    batch[0]['spatiotemporal'].shape[0] for batch in val_ds
)
print("Number of samples:", num_samples)

In [4]:
# Initialize the Atmo Model.
# AtmoModelParams() is constructed with no arguments, so whatever defaults the
# params class declares are used.
model_params = AtmoModelParams()
model = AtmoModel(model_params)

In [6]:
import sys

# Route both log records and stdout to one file for the long training run.
LOG_PATH = "training_log.txt"

# force=True is required: earlier cells already emitted records through the
# root logger, which therefore already has a handler, and basicConfig() is
# silently ignored in that case (requires Python 3.8+).
logging.basicConfig(filename=LOG_PATH, level=logging.INFO, force=True)

# Open in append mode: the logging handler appends to this same file, and a
# "w" handle would truncate it and then overwrite log records from its own
# file offset. Line buffering keeps printed output visible while training runs.
# NOTE: from here on, print() output no longer appears in the notebook.
sys.stdout = open(LOG_PATH, "a", buffering=1)  # Redirect stdout to file

In [None]:
# Train the model, writing TensorBoard event files under ./logs.
tb_callback = keras.callbacks.TensorBoard(log_dir="./logs")
model.fit(
    train_ds,
    val_ds,
    epochs=500,
    callbacks=[tb_callback],
)



INFO:root:Total generated samples: 1
INFO:root:Total generated samples: 1
2024-12-18 20:49:21.607804: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 2300809643403174713
2024-12-18 20:49:21.607858: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17742187177231888353
2024-12-18 20:49:21.607873: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 7270693604177293786
2024-12-18 20:49:21.607882: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 11116527186322823712
2024-12-18 20:49:21.607893: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 17456970470947728542
2024-12-18 20:49:21.607918: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14338868012032743136
INFO:root:Total generated sa

In [None]:
# Save the underlying Keras model (model._model) in TensorFlow SavedModel
# format under saved_model/atmo_model2.
model._model.save("saved_model/atmo_model2", save_format="tf")


In [None]:
from tensorflow.keras.models import load_model

# Load the model from the directory the save cell writes to
# ("saved_model/atmo_model2"). The previous path, "saved_model/atmo_model",
# is not produced anywhere in this notebook, so on a fresh run it would be
# missing or left over from an older session.
loaded_model = load_model("saved_model/atmo_model2")


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Get predictions from the validation set (a full pass over val_ds).
# NOTE(review): `predictions` is not read by any cell visible in this
# notebook — consider dropping this call if it is not needed.
predictions = model._model.predict(val_ds)  # Use the underlying Keras model

# val_ds yields (input_data, ground_truth) pairs; take(1) limits the loop to a
# single batch, so no explicit break is needed.
for input_data, ground_truth in val_ds.take(1):
    # Predict on just this batch so predictions line up with ground_truth.
    predicted_labels = model._model.predict(input_data)

    # Visualize the first sample of the batch.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # imshow needs a 2-D image, so fix index 0 on the second axis as well as
    # the sample and channel axes. This is the same [sample, 0, :, :, channel]
    # layout the other plotting cells in this notebook use; slicing only
    # [0, :, :, :, 0] leaves a 3-D array that imshow cannot draw.
    # NOTE(review): the second axis is presumably the time step — confirm.
    axes[0].imshow(ground_truth[0, 0, :, :, 0], cmap='viridis')
    axes[0].set_title('Ground Truth')

    # Prediction, same sample / second-axis index / channel as above.
    axes[1].imshow(predicted_labels[0, 0, :, :, 0], cmap='viridis')
    axes[1].set_title('Predicted Labels')

    plt.show()


In [None]:
# Inspect the shapes of the last plotted batch. A notebook cell only displays
# its final expression, so the shapes are collected in one dict rather than
# written as separate statements (of which only the last would be shown).
# The channel is indexed with a literal 0 because `channel_index` is only
# defined in a later cell, which would make this cell fail on a fresh
# Restart & Run All; the resulting shape does not depend on which channel is
# selected.
{
    "predicted_labels": predicted_labels.shape,
    "ground_truth[0, :, :, :]": ground_truth[0, :, :, :].shape,
    "predicted_labels[0, :, :, :]": predicted_labels[0, :, :, :].shape,
    "ground_truth[0, 0, :, :, 0]": ground_truth[0, 0, :, :, 0].shape,
}

In [None]:
# Choose which sample of the batch and which output channel to plot.
# The requested sample is clamped to the last one available, so a short batch
# (for example the final batch, or a tiny test dataset) does not raise an
# index error.
requested_sample_index = 14
sample_index = min(requested_sample_index, ground_truth.shape[0] - 1)
channel_index = 1  # Output channel to visualize

# Extract 2-D images: [sample, 0, :, :, channel].
# NOTE(review): the second axis is presumably the time step — confirm.
ground_truth_image = ground_truth[sample_index, 0, :, :, channel_index]  # Shape: (200, 200)
predicted_image = predicted_labels[sample_index, 0, :, :, channel_index]  # Shape: (200, 200)

# Visualize the ground truth and the prediction side by side.
fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Titles are built from channel_index so they always name the channel that is
# actually plotted (they previously said "Channel 0" while plotting channel 1).
axes[0].imshow(ground_truth_image, cmap='viridis')
axes[0].set_title(f'Ground Truth (Channel {channel_index})')

axes[1].imshow(predicted_image, cmap='viridis')
axes[1].set_title(f'Predicted Labels (Channel {channel_index})')

plt.show()

In [None]:
# Plot one sample of one validation batch: ground truth next to the prediction.
# Predictions are computed per batch inside the loop, so no separate
# full-dataset predict pass is needed here.
sample_index = 0  # Index of the sample in the batch
channel_index = 1  # Output channel to visualize

for input_data, ground_truth in val_ds.take(1):  # take(1): a single batch
    # Use the underlying Keras model to predict on this batch only.
    predicted_labels = model._model.predict(input_data)

    # Extract 2-D images: [sample, 0, :, :, channel].
    # NOTE(review): the second axis is presumably the time step — confirm.
    ground_truth_image = ground_truth[sample_index, 0, :, :, channel_index]  # Shape: (200, 200)
    predicted_image = predicted_labels[sample_index, 0, :, :, channel_index]  # Shape: (200, 200)

    # Visualize the ground truth and the prediction side by side.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Titles are built from channel_index so they always name the channel that
    # is actually plotted.
    axes[0].imshow(ground_truth_image, cmap='viridis')
    axes[0].set_title(f'Ground Truth (Channel {channel_index})')

    axes[1].imshow(predicted_image, cmap='viridis')
    axes[1].set_title(f'Predicted Labels (Channel {channel_index})')

    plt.show()

In [None]:
# Report the Python type of one training batch. take(1) already limits the
# loop to a single iteration.
for batch in train_ds.take(1):
    batch_type = type(batch)
    print(batch_type)


In [None]:
# Leading dimension of the 'spatiotemporal' input (the same quantity the
# sample-counting cells sum per batch).
# NOTE(review): `input_data` is whatever an earlier plotting loop left behind;
# this cell fails on a fresh kernel unless one of those cells has run.
input_data['spatiotemporal'].shape[0]

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Plot a few (input, ground truth) pairs from one training batch.
# take(1) limits the loop to a single batch, so no explicit break is needed.
for input_data, ground_truth in train_ds.take(1):
    # Read the real number of samples from the batch itself. A hard-coded
    # value could exceed the samples actually present, and assigning it to
    # `batch_size` would overwrite the notebook-level batching configuration.
    samples_in_batch = input_data['spatiotemporal'].shape[0]

    # Plot at most 4 samples, and never more than the batch contains.
    num_samples_to_plot = min(4, samples_in_batch)

    for i in range(num_samples_to_plot):
        fig, axes = plt.subplots(1, 2, figsize=(12, 6))

        # Input: index 2 on the last axis of 'spatiotemporal'.
        # NOTE(review): this assumes [i, :, :, 2] yields a 2-D image — confirm
        # the rank and axis order of 'spatiotemporal'.
        axes[0].imshow(input_data['spatiotemporal'][i, :, :, 2], cmap='viridis')
        axes[0].set_title(f"Input Data - Sample {i+1}")

        # Ground truth: index 0 on the second axis, channel 0.
        axes[1].imshow(ground_truth[i, 0, :, :, 0], cmap='viridis')
        axes[1].set_title(f"Ground Truth - Sample {i+1}")

        plt.tight_layout()
        plt.show()


In [None]:
# Pull the first training batch and show the shape of every input tensor,
# keyed by input name.
train_batches = iter(train_ds)
inputs, labels = next(train_batches)
dict((name, feature.shape) for name, feature in inputs.items())

In [None]:
# Print the layer-by-layer summary of the underlying Keras model.
model._model.summary()

In [None]:
# Smoke test: run one forward pass on the first training batch and report the
# shape of the result. `labels` is unpacked but not used here.
inputs, labels = next(iter(train_ds))
prediction = model.call(inputs)
print("Prediction shape:", prediction.shape)