# Flood Model Training Notebook

Train a flood ConvLSTM model with the `usl_models` library, then evaluate and visualize its predictions.

In [1]:
import tensorflow as tf
import keras_tuner
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model_params import FloodModelParams
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset
from usl_models.flood_ml import customloss, dataset

# --- Runtime setup ---
# Only WARNING and above from the root logger; keeps cell output readable.
logging.getLogger().setLevel(logging.WARNING)
# One call seeds the Python, NumPy and backend RNGs (keras.utils.set_random_seed).
SEED = 812
keras.utils.set_random_seed(SEED)

# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Run identifier; the tuning cell derives its log directory from it.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# --- Simulations to train on ---
# City name -> config folder name; together they form the simulation prefix.
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
}

# Rainfall scenario IDs to include (currently only scenario 5).
rainfall_files = [5]

# One simulation per (city, rainfall scenario) pair, e.g.
# "Manhattan-Manhattan_config/Rainfall_Data_5.txt".
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in city_config_mapping.items()
    for rain_id in rainfall_files
]

print(f"Training on {len(sim_names)} simulations.")
for s in sim_names:
    print(s)

# --- Datasets ---
# Height/width of each windowed input and output; small windows keep tuning fast.
WINDOW_SIZE = 10
BATCH_SIZE = 4
ds_config = dataset.Config(
    input_height=WINDOW_SIZE,
    input_width=WINDOW_SIZE,
    output_height=WINDOW_SIZE,
    output_width=WINDOW_SIZE,
)

# .cache() keeps the elements after the first full pass, so later epochs and
# tuner trials do not rebuild the input pipeline.
train_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=BATCH_SIZE, dataset_split="train", ds_config=ds_config
).cache()

validation_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=BATCH_SIZE, dataset_split="val", ds_config=ds_config
).cache()

2025-07-09 17:08:10.464694: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-07-09 17:08:10.515807: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-07-09 17:08:10.515838: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-07-09 17:08:10.517134: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-07-09 17:08:10.525683: I tensorflow/core/platform/cpu_feature_guar

Training on 2 simulations.
Manhattan-Manhattan_config/Rainfall_Data_5.txt
Atlanta-Atlanta_config/Rainfall_Data_5.txt


2025-07-09 17:08:15.583242: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38364 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


In [None]:
# Each tuning run gets its own timestamped directory for tuner state and TensorBoard logs.
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Candidate values per hyperparameter, forwarded to FloodModel.get_hypermodel
# (presumably turned into tuner choices there — confirm in model.py).
search_space = dict(
    lstm_units=[32, 64, 128],
    lstm_kernel_size=[3, 5],
    lstm_dropout=[0.2, 0.3],
    lstm_recurrent_dropout=[0.2, 0.3],
    n_flood_maps=[5],
    m_rainfall=[6],
    loss_scale=[50.0, 100.0, 200.0],
)
hypermodel = FloodModel.get_hypermodel(**search_space)

# With max_trials=1 the tuner evaluates a single configuration.
tuner = keras_tuner.BayesianOptimization(
    hypermodel,
    objective="val_loss",
    max_trials=1,
    project_name=log_dir,
)

tuner.search(
    train_dataset,
    validation_data=validation_dataset,
    epochs=2,
    callbacks=[tb_callback],
)

# Best trial's model and hyperparameters; the dict of values is shown as cell output.
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]
best_hp.values

In [None]:
# Train the final model with the best hyperparameters found by the tuner.
from keras.callbacks import ModelCheckpoint, EarlyStopping

FINAL_EPOCHS = 2
# NOTE: patience exceeds FINAL_EPOCHS, so early stopping cannot trigger at the
# current epoch budget; it only takes effect once FINAL_EPOCHS is raised.
EARLY_STOPPING_PATIENCE = 100

# `loss_scale` is tuned alongside the model params but passed to FloodModel
# separately (presumably not a FloodModel.Params field), so split it off.
final_params_dict = best_hp.values.copy()
loss_scale = final_params_dict.pop("loss_scale", 100.0)
final_params = FloodModel.Params(**final_params_dict)
model = FloodModel(params=final_params, loss_scale=loss_scale)

callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    # Keep only the checkpoint with the lowest validation loss.
    ModelCheckpoint(
        filepath=log_dir + "/checkpoint",
        save_best_only=True,
        monitor="val_loss",
        mode="min",
        save_format="tf",
    ),
    # Stop after EARLY_STOPPING_PATIENCE epochs without val_loss improvement and
    # restore the weights from the best epoch.
    EarlyStopping(
        monitor="val_loss",
        patience=EARLY_STOPPING_PATIENCE,
        restore_best_weights=True,
        mode="min",
    ),
]

# Train, then persist the final model next to the tuning logs.
model.fit(train_dataset, validation_dataset, epochs=FINAL_EPOCHS, callbacks=callbacks)

model.save_model(log_dir + "/model")

In [None]:
# Reload the saved model and run it on full-size validation inputs.
# Rebuild the validation split at 1000x1000 (the training cells used small windows).
# NOTE: this rebinds `validation_dataset`; later cells use whichever binding is current.
FULL_TILE_SIZE = 1000
ds_configa = dataset.Config(
    input_height=FULL_TILE_SIZE,
    input_width=FULL_TILE_SIZE,
    output_height=FULL_TILE_SIZE,
    output_width=FULL_TILE_SIZE,
)
validation_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="val", ds_config=ds_configa
).cache()

# Use the loss scale the model was trained with (the tuned value, with the same
# default the training cell uses) instead of a hard-coded guess.
loss_scale = best_hp.values.get("loss_scale", 100.0)

model_path = log_dir + "/model"
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# The saved model is loaded with the hybrid loss registered as "loss_fn".
# A single load is enough; the previous second load did the same work again.
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

inputs, labels_ = next(iter(validation_dataset))
prediction = model.call(inputs)
prediction.shape

In [None]:
# This code is used to stitch together chunks of geospatial and spatiotemporal data
# into a full-sized input for the FloodConvLSTM model.
# It assumes that the validation dataset is already loaded and consists of chunks of data.
# It stitches the chunks together to create a full input tensor for the model.
import numpy as np
import tensorflow as tf
from usl_models.flood_ml.dataset import _generate_temporal_tensor
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model import FloodConvLSTM
import numpy as np
import tensorflow as tf
from usl_models.flood_ml import dataset, customloss
from usl_models.flood_ml.model import FloodConvLSTM
# --- Stitching configuration ---
# The full input is a chunks_per_col x chunks_per_row grid of chunk_h x chunk_w chunks.
chunks_per_row = 1
chunks_per_col = 1
chunk_h, chunk_w = 1000, 1000
# Extra zero margin added to the bottom/right of the stitched canvas; only the
# top-left grid region is filled with data. NOTE(review): presumably this tests
# the model on inputs larger than the chunks — confirm the padding is intended.
pad = 300
n_flood_maps = 5
geo_features = 9
m_rainfall = 6
batch_size = 1
# Temporal input copied from the first chunk (filled in the loop below).
stitched_temp_full = None

canvas_h = chunks_per_col * chunk_h + pad
canvas_w = chunks_per_row * chunk_w + pad

# === Load validation dataset at chunk resolution ===
ds_config = dataset.Config(
    input_height=chunk_h,
    input_width=chunk_w,
    output_height=chunk_h,
    output_width=chunk_w,
)
validation_dataset = dataset.load_dataset_windowed(
    sim_names=sim_names, batch_size=batch_size, dataset_split="val", ds_config=ds_config
).cache()

# === Collect one sample to print its shape ===
sample_input, _ = next(iter(validation_dataset))
print("===> EXAMPLE VALIDATION DATASET SHAPES")
for k, v in sample_input.items():
    print(f"{k}: {v.shape}")

# === Initialize stitched (zero-padded) arrays ===
stitched_geo = np.zeros((canvas_h, canvas_w, geo_features), dtype=np.float32)
stitched_st = np.zeros((n_flood_maps, canvas_h, canvas_w, 1), dtype=np.float32)

# === Copy chunks into the canvas in row-major order ===
n_chunks = chunks_per_row * chunks_per_col
n_filled = 0
for chunk_idx, (inputs, _) in enumerate(validation_dataset.take(n_chunks)):
    row_idx, col_idx = divmod(chunk_idx, chunks_per_row)
    r = row_idx * chunk_h
    c = col_idx * chunk_w
    geo = inputs["geospatial"].numpy().squeeze(0)  # (H, W, F)
    st = inputs["spatiotemporal"].numpy().squeeze(0)  # (N, H, W, 1)
    if stitched_temp_full is None:
        # Per the printed sample shapes this is the chunk's temporal window,
        # (n_flood_maps, m_rainfall), not the full-length series.
        stitched_temp_full = inputs["temporal"].numpy()[0]
    stitched_geo[r : r + chunk_h, c : c + chunk_w, :] = geo
    stitched_st[:, r : r + chunk_h, c : c + chunk_w, :] = st
    n_filled += 1

# A short dataset would otherwise leave silent all-zero chunks in the canvas.
if n_filled < n_chunks:
    raise ValueError(
        f"Validation dataset yielded {n_filled} chunks; "
        f"{n_chunks} are needed to fill the stitching grid."
    )

# === Get full-length temporal input ===
# NOTE(review): `temporal` is fetched here but not used below; the model's
# temporal window is built from `stitched_temp_full` (the first chunk's
# temporal input). Confirm which source is intended.
from usl_models.flood_ml import metastore
from google.cloud import firestore, storage

firestore_client = firestore.Client()
storage_client = storage.Client()
temporal, _ = _generate_temporal_tensor(
    metastore.get_temporal_feature_metadata(firestore_client, sim_names[0]),
    storage_client,
    sim_names[0],
    m_rainfall,
)
temporal = tf.convert_to_tensor(temporal[tf.newaxis, ...])  # shape (1, T, 6); T was 864 when written

# === Final temporal window ===
t = 4  # Must be >= n_flood_maps - 1
assert t >= n_flood_maps - 1, "t must be >= n_flood_maps - 1"

temporal_window = FloodConvLSTM._get_temporal_window(
    tf.convert_to_tensor(stitched_temp_full[None, ...]),  # (1, T, M)
    t=t,
    n=n_flood_maps,
)  # -> (1, 5, 6) per the printed shapes

# === Prepare model input ===
input_dict = {
    "geospatial": tf.convert_to_tensor(stitched_geo[None, ...]),  # (1, H, W, F)
    "spatiotemporal": tf.convert_to_tensor(stitched_st[None, ...]),  # (1, N, H, W, 1)
    "temporal": temporal_window,  # (1, 5, 6)
}
print("\n===> STITCHED INPUT SHAPES")
for k, v in input_dict.items():
    print(f"{k}: {v.shape}")

2025-07-09 17:08:22.436584: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


===> EXAMPLE VALIDATION DATASET SHAPES
geospatial: (1, 1000, 1000, 9)
temporal: (1, 5, 6)
spatiotemporal: (1, 5, 1000, 1000, 1)


2025-07-09 17:08:24.200758: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.



===> STITCHED INPUT SHAPES
geospatial: (1, 1300, 1300, 9)
spatiotemporal: (1, 5, 1300, 1300, 1)
temporal: (1, 5, 6)


In [3]:
# Run a previously trained model on the stitched (padded) input built above.
# NOTE(review): absolute, user-specific path to an earlier run, not the model
# trained earlier in this notebook; make it configurable before sharing.
model_path = "/home/se2890/climateiq-cnn-2/logs/htune_project_20250708-184915/model"

# Load with the hybrid loss registered under the name "loss_fn".
loss_fn = customloss.make_hybrid_loss(scale=200.0)
model = tf.keras.models.load_model(
    model_path,
    custom_objects={"loss_fn": loss_fn},
)

# Forward pass on the stitched input dictionary.
prediction = model.call(input_dict)
print("\nPrediction shape:", prediction.shape)

2025-07-09 17:08:30.858396: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8900



Prediction shape: (1, 1300, 1300, 1)


In [None]:
# # Test calling the model for n predictions
# full_dataset = load_dataset(sim_names=sim_names, batch_size=1)
# inputs, labels = next(iter(full_dataset))
# predictions = model.call_n(inputs, n=4)
# predictions.shape

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants

# Visualize ground truth vs. prediction for the model trained in this notebook.
# Use the tuned loss scale (same default as the training cell) instead of a
# hard-coded value that may not match the trained model.
loss_scale = best_hp.values.get("loss_scale", 100.0)
model_path = log_dir + "/model"
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# Number of samples to visualize.
n_samples = 20
# Index into the leading axis, used only when the squeezed arrays have one.
timestep = 1

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    print(f"\nSample {i+1} Prediction Stats:")
    print("  Min:", prediction.min())
    print("  Max:", prediction.max())
    print("  Mean:", prediction.mean())

    # The model outputs one map per sample, so with batch_size=1 the squeezed
    # arrays are already 2-D; `[timestep]` would then select a single row,
    # which imshow rejects. Index only when a leading axis is present.
    gt_t = ground_truth[timestep] if ground_truth.ndim > 2 else ground_truth
    pred_t = prediction[timestep] if prediction.ndim > 2 else prediction
    # NaN-safe shared color scale for both panels.
    vmax_val = max(np.nanmax(gt_t), np.nanmax(pred_t))

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Evaluate an externally trained model: per-sample metrics plus GT/prediction plots.
# Known value used during training of that model.
loss_scale = 150.0

# NOTE(review): absolute, user-specific path to a model from another run;
# make this configurable before sharing.
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# Uses whichever `validation_dataset` is currently bound in the kernel.
n_samples = 20
timestep = 2
# Pixels with values above this threshold count as flooded in the IoU metric.
flood_threshold = 0.1
metrics_list = []

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    # A single-map prediction with batch_size=1 squeezes to 2-D, where
    # `[timestep]` would select one row; index only when a leading axis exists.
    gt_t = ground_truth[timestep] if ground_truth.ndim > 2 else ground_truth
    pred_t = prediction[timestep] if prediction.ndim > 2 else prediction
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Mask out NaNs in the ground truth.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)
    gt_wet = gt_flat > flood_threshold
    pred_wet = pred_flat > flood_threshold
    iou = np.logical_and(gt_wet, pred_wet).sum() / max(1, np.logical_or(gt_wet, pred_wet).sum())
    # SSIM is undefined when the ground-truth range is 0 (constant map) or NaN;
    # record NaN explicitly instead of dividing by zero.
    data_range = gt_t.max() - gt_t.min()
    ssim_val = ssim(gt_t, pred_t, data_range=data_range) if data_range > 0 else np.nan

    metrics_list.append(
        {
            "Sample": i + 1,
            "MAE": mae,
            "RMSE": rmse,
            "Bias": bias,
            "IoU > 0.1": iou,
            "SSIM": ssim_val,
        }
    )

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

# Summarize metrics across samples (describe() skips NaN SSIM values).
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml import customloss
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Compare two externally trained models (with / without attention) side by side.
loss_scale = 200.0
timestep = 3
n_samples = 20
# Pixels with values above this threshold count as flooded in the IoU metric.
flood_threshold = 0.1

# NOTE(review): absolute, user-specific paths; make these configurable.
model_path_1 = (
    "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/attention/model"
)
model_path_2 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250612-010926/model"

loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model_1 = tf.keras.models.load_model(model_path_1, custom_objects={"loss_fn": loss_fn})
model_2 = tf.keras.models.load_model(model_path_2, custom_objects={"loss_fn": loss_fn})


def _select_map(arr, timestep):
    """Return a 2-D map from a squeezed label/prediction array.

    Arrays with a leading axis (ndim > 2) are indexed at `timestep`. Arrays that
    are already 2-D (a single-map output with batch_size=1) are returned as-is,
    because indexing them would select a single row.
    """
    return arr[timestep] if arr.ndim > 2 else arr


def _model_metrics(gt_t, pred_t, suffix, threshold):
    """Compute MAE/RMSE/Bias/IoU/SSIM of `pred_t` vs `gt_t`, keys suffixed `_<suffix>`.

    NaN ground-truth pixels are excluded from the pixel-wise metrics. SSIM is NaN
    when the ground-truth range is 0 or NaN, because it is undefined there.
    """
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()
    gt_wet = gt_flat > threshold
    pred_wet = pred_flat > threshold
    data_range = gt_t.max() - gt_t.min()
    return {
        f"MAE_{suffix}": mean_absolute_error(gt_flat, pred_flat),
        f"RMSE_{suffix}": np.sqrt(mean_squared_error(gt_flat, pred_flat)),
        f"Bias_{suffix}": np.mean(pred_flat) - np.mean(gt_flat),
        f"IoU_{suffix}": np.logical_and(gt_wet, pred_wet).sum()
        / max(1, np.logical_or(gt_wet, pred_wet).sum()),
        f"SSIM_{suffix}": (
            ssim(gt_t, pred_t, data_range=data_range) if data_range > 0 else np.nan
        ),
    }


metrics_list = []

# Evaluate on the validation split (previously iterated train_dataset by mistake).
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    pred_1 = model_1(input_data).numpy().squeeze()
    pred_2 = model_2(input_data).numpy().squeeze()

    gt_t = _select_map(ground_truth, timestep)
    pred_1_t = _select_map(pred_1, timestep)
    pred_2_t = _select_map(pred_2, timestep)
    vmax_val = np.nanpercentile([gt_t, pred_1_t, pred_2_t], 99.5)

    metrics_list.append(
        {
            "Sample": i + 1,
            **_model_metrics(gt_t, pred_1_t, "1", flood_threshold),
            **_model_metrics(gt_t, pred_2_t, "2", flood_threshold),
        }
    )

    fig, axes = plt.subplots(1, 3, figsize=(21, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)
    panels = [
        (gt_t, "Ground Truth"),
        (pred_1_t, "attention"),
        (pred_2_t, "without attention"),
    ]
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

# Summary metrics (describe() skips NaN SSIM values).
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())