# Flood Model Training Notebook

Train a Flood ConvLSTM model using the `usl_models` library.

In [None]:
from usl_models.flood_ml.dataset import load_dataset_windowed_cached

import tensorflow as tf
import keras
import keras_tuner
import time
from datetime import datetime
import logging

from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import load_dataset_windowed

import pathlib

# === CONFIG ===
# Notebook-wide setup: logging level, random seed, GPU memory growth, and the
# timestamped log directory used by later cells.
# Raise the root logger threshold to WARNING.
logging.getLogger().setLevel(logging.WARNING)
# Fixed seed (812) for the random number generators Keras controls.
keras.utils.set_random_seed(812)

# Turn on memory growth for every GPU that TensorFlow reports.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Timestamp (local time, YYYYMMDD-HHMMSS) taken once when this cell runs; it
# names this run's log directory and is reused for the tuner project name below.
timestamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/training_{timestamp}"

In [None]:
# Configuration for downloading, training, and saving the FloodML model from the cached dataset.
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")
# Maps each city to its config folder; the pair forms the "<city>-<config>"
# prefix of every simulation name built below.
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
    # "Phoenix_central": "PHX_CCC",
    # "Atlanta_Prediction": "Atlanta_config",
}
# IDs of the Rainfall_Data_<id>.txt files to include for every city.
rainfall_files = [7, 5, 13, 11, 9, 16, 15, 10, 12, 2, 3]
# rainfall_files = [5]  # single-file alternative
# Only referenced by the commented-out download step below.
dataset_splits = ["test", "train", "val"]
# Forwarded to the dataset loaders as their n_flood_maps / m_rainfall / batch_size arguments.
n_flood_maps = 5
m_rainfall = 6
batch_size = 10
# NOTE(review): `epochs` does not appear to be read by any later cell (the
# tuner and fit calls pass literal values) — confirm whether it is still needed.
epochs = 2
# Generate sim_names: one "<city>-<config>/Rainfall_Data_<id>.txt" entry per
# (city, rainfall file) pair.
sim_names = []
for city, config in city_config_mapping.items():
    for rain_id in rainfall_files:
        sim_name = f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        sim_names.append(sim_name)

# === STEP 1: DOWNLOAD DATASET TO FILECACHE ===
# print("Downloading simulations into local cache")
# download_dataset(
#     sim_names=sim_names,
#     output_path=filecache_dir,
#     dataset_splits=dataset_splits,
#    include_labels=True
# )


# print(":white_check_mark: Download complete.")

In [None]:
# For faster loading during hyperparameter tuning, use this function.
def get_datasets(batch_size=2):
    """Load the cached, windowed train and validation datasets for Manhattan.

    The simulation list is fixed inside the function (Manhattan rainfall
    files 7, 5, 16 and 15). Before loading, every simulation name is printed
    together with whether that path exists under the file cache.

    Args:
        batch_size: Batch size forwarded to both dataset loaders.

    Returns:
        A ``(train_ds, val_ds)`` tuple: the results of
        ``load_dataset_windowed_cached`` for the "train" and "val" splits.
    """
    cache_root = pathlib.Path("/home/shared/climateiq/filecache")
    configs_by_city = {"Manhattan": "Manhattan_config"}
    rain_ids = [7, 5, 16, 15]

    names = [
        f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        for city, config in configs_by_city.items()
        for rain_id in rain_ids
    ]

    print("Sim names in use:")
    for name in names:
        print("  ", name, (cache_root / name).exists())

    # Settings common to both splits; only ``dataset_split`` differs.
    shared_kwargs = dict(
        filecache_dir=cache_root,
        sim_names=names,
        batch_size=batch_size,
        n_flood_maps=5,
        m_rainfall=6,
        shuffle=True,
    )
    train_ds = load_dataset_windowed_cached(dataset_split="train", **shared_kwargs)
    val_ds = load_dataset_windowed_cached(dataset_split="val", **shared_kwargs)
    return train_ds, val_ds

In [None]:
# Build the small tuning datasets and print the dataset objects.
train_dataset, val_dataset = get_datasets(batch_size=2)
print("Train dataset:", train_dataset)
print("Validation dataset:", val_dataset)
# %%
# Inspect a single training batch: input keys, tensor shapes, and a 5x5 corner
# of the first label in the batch.
for x, y in train_dataset.take(1):
    print("\n🔹 Input keys:", x.keys())
    print("🔹 Geospatial shape:", x["geospatial"].shape)
    print("🔹 Temporal shape:", x["temporal"].shape)
    print("🔹 Spatiotemporal shape:", x["spatiotemporal"].shape)
    print("🔹 Label shape:", y.shape)
    print("🔹 Example pixel values:", y.numpy()[0, :5, :5])

In [None]:
# === STEP 2: LOAD CACHED WINDOWED DATASETS ===
# Loads the train / val / test splits from the local file cache, using the
# configuration (filecache_dir, sim_names, batch_size, ...) defined above.

# Loader settings shared by all three splits; only `dataset_split` varies.
_windowed_kwargs = dict(
    filecache_dir=filecache_dir,
    sim_names=sim_names,
    batch_size=batch_size,
    n_flood_maps=n_flood_maps,
    m_rainfall=m_rainfall,
    shuffle=True,
)

# One prefetch at the end of each pipeline is sufficient. The previous
# identity `map` followed by a second prefetch did not change the elements
# and only added pipeline stages, so it has been removed.
# Optional: insert `.cache("/tmp/train_cache")` before `.prefetch(...)` to
# cache the training split on disk.
train_dataset = load_dataset_windowed_cached(
    dataset_split="train", **_windowed_kwargs
).prefetch(tf.data.AUTOTUNE)

# Optional: insert `.cache("/tmp/val_cache")` before `.prefetch(...)` to
# cache the validation split on disk.
validation_dataset = load_dataset_windowed_cached(
    dataset_split="val", **_windowed_kwargs
).prefetch(tf.data.AUTOTUNE)

test_dataset = load_dataset_windowed_cached(
    dataset_split="test", **_windowed_kwargs
).prefetch(tf.data.AUTOTUNE)

In [None]:
# %%
# === Debug dataset structure and counts ===
# Walks the whole tuning train dataset once to count samples and record the
# tensor shapes.
# NOTE: this cell rebinds the global `train_dataset` to the dataset returned by
# get_datasets(); cells run afterwards see that dataset until it is reassigned.
train_ds, val_ds = get_datasets(batch_size=2)
num_samples = 0
example_shapes = None
unique_chunks = set()  # never populated in this cell
train_dataset = train_ds
print("🔍 Scanning train_dataset...")
for i, (inputs, labels) in enumerate(train_dataset):
    # The leading dimension of the batch is counted as its number of samples.
    num_samples += inputs["geospatial"].shape[0]
    # Overwritten on every iteration, so it ends up holding the last batch's shapes.
    example_shapes = {
        "geospatial": inputs["geospatial"].shape,
        "temporal": inputs["temporal"].shape,
        "spatiotemporal": inputs["spatiotemporal"].shape,
        "labels": labels.shape,
    }
    # Optional: show the first few batches
    if i < 2:
        print(f"\nBatch {i}:")
        print(f"  geospatial: {inputs['geospatial'].shape}")
        print(f"  temporal: {inputs['temporal'].shape}")
        print(f"  spatiotemporal: {inputs['spatiotemporal'].shape}")
        print(f"  labels: {labels.shape}")
    # Progress message every 100 batches (skipping batch 0).
    if i % 100 == 0 and i > 0:
        print(f"  Processed {i} batches so far...")

print("\n✅ Dataset scan complete.")
print(f"Total samples (windows): {num_samples}")
print(f"Example tensor shapes: {example_shapes}")

In [None]:
# %%
# NOTE(review): `labels` is not assigned in this cell; it appears to be the
# value left over from the scanning loop in the previous cell, so that cell
# must have run first. Those labels come from the windowed dataset, so the
# "before windowing" wording in the message may be misleading — confirm.
print(f"[DEBUG] Labels shape before windowing: {labels.shape}")

# %%
# Print the geospatial-input shape and label shape of one training batch.
for inputs, labels in train_dataset.take(1):
    print(inputs["geospatial"].shape, labels.shape)

In [None]:
# If working locally, comment out this block
# NOTE: this cell overwrites city_config_mapping, rainfall_files, sim_names,
# train_dataset and validation_dataset defined by earlier cells.


# Cities and their config folders
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
}

# Rainfall files you want
rainfall_files = [5]  # only Rainfall_Data_5

# Generate sim_names: one "<city>-<config>/Rainfall_Data_<id>.txt" entry per
# (city, rainfall file) pair.
sim_names = []
for city, config in city_config_mapping.items():
    for rain_id in rainfall_files:
        sim_name = f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        sim_names.append(sim_name)

print(f"Training on {len(sim_names)} simulations.")
for s in sim_names:
    print(s)

# Now load dataset with the non-cached loader (no filecache_dir argument),
# using batch size 2. `.cache()` without a filename caches elements in memory.
train_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=2, dataset_split="train"
).cache()

validation_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=2, dataset_split="val"
).cache()

In [None]:
import gc

# Bayesian-optimization tuner over the FloodModel hypermodel, minimising
# validation loss. Each list passed to get_hypermodel holds the values given
# for that hyperparameter (presumably the candidates to choose from — confirm
# in FloodModel.get_hypermodel).
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(
        lstm_units=[32, 64, 128],
        lstm_kernel_size=[3, 5],
        lstm_dropout=[0.2, 0.3],
        lstm_recurrent_dropout=[0.2, 0.3],
        n_flood_maps=[5],
        m_rainfall=[6],
    ),
    objective="val_loss",
    max_trials=10,  # increase if you want more search
    project_name=log_dir,
)

# TensorBoard callback writing to this run's log directory; histogram_freq=0
# and profile_batch=0 turn off weight histograms and profiling.
tb_callback = keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    profile_batch=0,
)


def tuner_search(batch_size=2, num_train_samples=200, num_val_samples=100, epochs=2):
    """Run the Bayesian-optimization tuner on a limited number of samples.

    Sample counts are converted to batch counts by integer division with
    ``batch_size`` (at least one batch each). The tuner is then run on that
    many batches taken from the datasets returned by ``get_datasets``.

    Relies on the notebook-level ``tuner`` and ``tb_callback`` objects.

    Args:
        batch_size: Batch size passed to ``get_datasets``; must be positive.
        num_train_samples: Target number of training samples; rounded down
            to a whole number of batches.
        num_val_samples: Target number of validation samples; rounded down
            to a whole number of batches.
        epochs: Number of epochs forwarded to ``tuner.search``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # Fail early with a clear message instead of a division error or a
    # silently clamped batch count further down.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Clear memory and TensorFlow graph
    gc.collect()
    tf.keras.backend.clear_session()

    # Get datasets
    train_ds, val_ds = get_datasets(batch_size=batch_size)

    # Convert sample counts → batch counts
    num_train_batches = max(1, num_train_samples // batch_size)
    num_val_batches = max(1, num_val_samples // batch_size)

    print(
        f"Using {num_train_batches} train batches "
        f"({num_train_batches * batch_size} samples)"
    )
    print(
        f"Using {num_val_batches} validation batches "
        f"({num_val_batches * batch_size} samples)"
    )

    # Run tuner
    tuner.search(
        train_ds.take(num_train_batches),
        validation_data=val_ds.take(num_val_batches),
        epochs=epochs,
        callbacks=[tb_callback],
        verbose=1,
    )


# Log the device each op is placed on (optional; the output can be very verbose)
tf.debugging.set_log_device_placement(True)

# Run tuner on ~200 training and ~50 validation samples with batch size 2
tuner_search(batch_size=2, num_train_samples=200, num_val_samples=50)

# Retrieve the best hyperparameters and build a model from them.
# NOTE(review): hypermodel.build() presumably returns a freshly built model
# rather than the trained trial model — confirm before using `best_model`
# for inference.
best_hp = tuner.get_best_hyperparameters()[0]
best_model = tuner.hypermodel.build(best_hp)
print("Best hyperparameters:", best_hp.values)

In [None]:
# Rebinds `tuner` to a new single-trial Bayesian-optimization tuner with the
# same hyperparameter lists as the earlier tuner cell; its project name is the
# timestamped "htune_project" path.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(
        lstm_units=[32, 64, 128],
        lstm_kernel_size=[3, 5],
        lstm_dropout=[0.2, 0.3],
        lstm_recurrent_dropout=[0.2, 0.3],
        n_flood_maps=[5],
        m_rainfall=[6],
    ),
    objective="val_loss",
    max_trials=1,
    project_name=f"logs/htune_project_{timestamp}",
)

# Print a summary of the search space.
tuner.search_space_summary()

In [None]:
# Log the device each op is placed on.
tf.debugging.set_log_device_placement(True)
# Rebind log_dir to the tuner project path; the training cell below writes its
# TensorBoard logs, checkpoint and final model under this directory.
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
# Search using the first 200 elements of the training dataset and the first 50
# elements of the validation dataset, with 10 epochs.
tuner.search(
    train_dataset.take(200),
    epochs=10,
    validation_data=validation_dataset.take(50),
    callbacks=[tb_callback],
)
# Keep the top-ranked model and hyperparameter set from the search.
best_model, best_hp = tuner.get_best_models()[0], tuner.get_best_hyperparameters()[0]
# Bare expression: shown as the cell output when run in a notebook.
best_hp.values

In [None]:
# %%
from keras.callbacks import ModelCheckpoint, EarlyStopping

train_ds, val_ds = get_datasets(batch_size=2)
# Define final parameters and model from the best hyperparameters (`best_hp`
# must have been set by a tuner cell).
final_params_dict = best_hp.values.copy()
final_params = FloodModel.Params(**final_params_dict)
model = FloodModel(params=final_params)
# Define callbacks: TensorBoard logging, best-only checkpointing on val_loss,
# and early stopping.
callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    ModelCheckpoint(
        filepath=log_dir + "/checkpoint",
        save_best_only=True,
        monitor="val_loss",
        mode="min",
        save_format="tf",
    ),
    # NOTE(review): patience (100) is larger than the epoch count passed to
    # fit() below (10), so early stopping cannot trigger in this run;
    # presumably kept for restore_best_weights — confirm.
    EarlyStopping(
        monitor="val_loss",  # What to monitor
        patience=100,  # Number of epochs with no improvement to wait
        restore_best_weights=True,  # Restore model weights from best epoch
        mode="min",  # "min" because lower val_loss is better
    ),
]
tf.debugging.set_log_device_placement(True)
# Train
# NOTE(review): val_ds is passed as the second positional argument; this
# assumes FloodModel.fit takes the validation dataset there — confirm.
model.fit(train_ds, val_ds, epochs=10, callbacks=callbacks)

# Save final model under the current log directory
model.save_model(log_dir + "/model")

In [None]:
# Smoke test: run the model on one batch from train_dataset and show the shape
# of the prediction as the cell output.
inputs, labels_ = next(iter(train_dataset))
prediction = model.call(inputs)
prediction.shape

In [None]:
# Prediction mode
from usl_models.flood_ml import dataset

# Parameters
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")
# Prediction study area (a one-element list despite the singular name) and the
# rainfall simulation passed to the loader as `rainfall_sim_name`.
sim_name = ["Atlanta_Prediction"]
rainfall_sim = "Atlanta-Atlanta_config/Rainfall_Data_22.txt"


# Download (prediction mode)
# dataset.download_dataset(
#     sim_names=sim_name,          # study area
#     output_path=filecache_dir,
#     include_labels=False,                      # no labels
#     rainfall_sim_name=rainfall_sim,  # simulation for temporal vector
#     allow_missing_sim=True                     # skip temporal if missing
# )
# Prediction mode: load the cached dataset without labels and without a
# train/val/test split.
full_dataset = dataset.load_dataset_cached(
    filecache_dir=filecache_dir,
    sim_names=sim_name,  # study area
    dataset_split=None,  # no split for prediction
    batch_size=2,
    include_labels=False,
    rainfall_sim_name=rainfall_sim,  # actual rainfall sim
)


# Download (training mode)
# dataset_splits = ["test", "train", "val"]
# dataset.download_dataset(
#     sim_names=["Atlanta-Atlanta_config/Rainfall_Data_22.txt"],  # normal simulations
#     output_path=filecache_dir,
#     dataset_splits=dataset_splits,               # train/val/test splits
#     include_labels=True                        # get labels too
# )

# full_dataset = dataset.load_dataset_cached(
#     filecache_dir=filecache_dir,
#     sim_names=["Atlanta-Atlanta_config/Rainfall_Data_20.txt"],
#     dataset_split="train",
#     include_labels=True
# )

In [None]:
import tensorflow as tf
from usl_models.flood_ml.model import FloodModel, SpatialAttention

# Number of steps requested from call_n below
N_steps = 4
# Path to your saved model
model_path = "/home/se2890/climateiq-cnn-9/logs/training_20251015-164325/model"
# First load: only used to print the model summary.
loaded_model = tf.keras.models.load_model(model_path)
loaded_model.summary()
# Load the model
model = tf.keras.models.load_model(model_path)
# model = FloodModel.from_checkpoint(model_path)

from usl_models.flood_ml.model import SpatialAttention

# Third load of the same path, this time registering the SpatialAttention
# layer and skipping compilation; its weights are then copied into `model`.
# NOTE(review): `model` was loaded from the same path, so the weight copy looks
# redundant, and the two earlier loads do not pass custom_objects — confirm
# which load is actually required.
custom_objects = {"SpatialAttention": SpatialAttention}
loaded_model = tf.keras.models.load_model(
    model_path, custom_objects=custom_objects, compile=False
)
model.set_weights(loaded_model.get_weights())

# Test calling the model for n predictions
# full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split= "train")
# NOTE(review): this unpacks each dataset element as a 3-tuple, while
# `full_dataset` was built with include_labels=False — confirm the element
# structure. It also assumes the loaded model exposes `call_n` — confirm.
inputs, labels, _ = next(iter(full_dataset))
predictions = model.call_n(inputs, n=N_steps)
predictions.shape

In [None]:
# loss_scale = best_hp.get("loss_scale")
# print("Loss scale used during training:", loss_scale)

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml import customloss

# Loss scale: known value used during training.
loss_scale = 200.0

# Saved model to visualize.
model_path = "/home/se2890/climateiq-cnn-5/logs/htune_project_20250801-155126/model"

# The model is deserialized with the hybrid loss registered under "loss_fn".
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# How many elements of validation_dataset to plot.
n_samples = 20

# For each element: print prediction statistics, then plot ground truth and
# prediction side by side on a shared colour scale.
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    print(f"\nSample {i+1} Prediction Stats:")
    for stat_name, stat_value in (
        ("Min", prediction.min()),
        ("Max", prediction.max()),
        ("Mean", prediction.mean()),
    ):
        print(f"  {stat_name}:", stat_value)

    # Index along the first axis left after squeeze().
    # NOTE(review): if the batch axis survives squeeze() (batch size > 1) this
    # selects a batch element rather than a timestep — confirm the shapes.
    timestep = 3
    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    # Shared upper colour limit so both panels are directly comparable.
    vmax_val = max(gt_t.max(), pred_t.max())

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = ((axes[0], gt_t, "Ground Truth"), (axes[1], pred_t, "Prediction"))
    for ax, image, title in panels:
        im = ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# `customloss` is used below; import it here (as the other evaluation cells do)
# so this cell does not depend on an earlier cell having been run first.
from usl_models.flood_ml import customloss

# Loss scale: known value used during training.
loss_scale = 150.0

# Path to trained model
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

# Create the loss function with the correct scale
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load model with custom loss function
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})


# `validation_dataset` must already be defined by an earlier cell, e.g.:
# from usl_models.flood_ml.dataset import load_dataset_windowed
# validation_dataset = load_dataset_windowed(...)

n_samples = 20
timestep = 2
metrics_list = []

# For each element: compute error metrics at index `timestep`, record them,
# and plot ground truth next to the prediction.
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    # Index along the first axis left after squeeze().
    # NOTE(review): if the batch axis survives squeeze() (batch size > 1) this
    # selects a batch element rather than a timestep — confirm the shapes.
    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    # Shared colour limit: 99.5th percentile over both images, ignoring NaNs.
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Mask out NaNs (positions where the ground truth is NaN)
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)
    # Intersection-over-union of the cells above 0.1 in ground truth vs
    # prediction; the max(1, ...) guard avoids dividing by zero when neither
    # exceeds the threshold anywhere.
    iou = np.logical_and(gt_flat > 0.1, pred_flat > 0.1).sum() / max(
        1, np.logical_or(gt_flat > 0.1, pred_flat > 0.1).sum()
    )
    # NOTE(review): SSIM is computed on the unmasked images; if gt_t contains
    # NaNs (the mask above suggests it may) the result will be NaN — confirm.
    ssim_val = ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min())

    metrics_list.append(
        {
            "Sample": i + 1,
            "MAE": mae,
            "RMSE": rmse,
            "Bias": bias,
            "IoU > 0.1": iou,
            "SSIM": ssim_val,
        }
    )

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()

# Convert to DataFrame and print summary statistics across samples
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml import customloss
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Settings shared by both models in this comparison.
loss_scale = 200.0
timestep = 3
n_samples = 20

# Saved models to compare: model 1 is plotted as "attention", model 2 as
# "without attention".
model_path_1 = (
    "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/attention/model"
)
model_path_2 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250612-010926/model"

# Both models are deserialized with the same hybrid loss registered as "loss_fn".
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model_1 = tf.keras.models.load_model(model_path_1, custom_objects={"loss_fn": loss_fn})
model_2 = tf.keras.models.load_model(model_path_2, custom_objects={"loss_fn": loss_fn})


def _model_scores(tag, gt_t, gt_flat, pred_t, pred_flat):
    """Return MAE, RMSE, bias, IoU and SSIM for one model, keyed ``<metric>_<tag>``."""
    above_both = np.logical_and(gt_flat > 0.1, pred_flat > 0.1).sum()
    above_either = np.logical_or(gt_flat > 0.1, pred_flat > 0.1).sum()
    return {
        f"MAE_{tag}": mean_absolute_error(gt_flat, pred_flat),
        f"RMSE_{tag}": np.sqrt(mean_squared_error(gt_flat, pred_flat)),
        f"Bias_{tag}": np.mean(pred_flat) - np.mean(gt_flat),
        # max(1, ...) avoids dividing by zero when nothing exceeds the threshold.
        f"IoU_{tag}": above_both / max(1, above_either),
        f"SSIM_{tag}": ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min()),
    }


# The loop reads the already-prepared `train_dataset`.
# NOTE(review): the other evaluation cells iterate `validation_dataset`;
# confirm that the training split is intended here.
metrics_list = []

for i, (input_data, ground_truth) in enumerate(train_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()

    pred_1 = model_1(input_data).numpy().squeeze()
    pred_2 = model_2(input_data).numpy().squeeze()

    # Same index along the first remaining axis for truth and both predictions.
    gt_t, pred_1_t, pred_2_t = (
        arr[timestep] for arr in (ground_truth, pred_1, pred_2)
    )
    # Shared colour limit: 99.5th percentile over all three images, ignoring NaNs.
    vmax_val = np.nanpercentile([gt_t, pred_1_t, pred_2_t], 99.5)

    # Keep only positions where the ground truth is not NaN.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_1_flat = pred_1_t[mask].flatten()
    pred_2_flat = pred_2_t[mask].flatten()

    # One row per sample: sample number, then model 1 metrics, then model 2 metrics.
    row = {"Sample": i + 1}
    row.update(_model_scores(1, gt_t, gt_flat, pred_1_t, pred_1_flat))
    row.update(_model_scores(2, gt_t, gt_flat, pred_2_t, pred_2_flat))
    metrics_list.append(row)

    # Three panels on a common colour scale.
    fig, axes = plt.subplots(1, 3, figsize=(21, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = (
        (gt_t, "Ground Truth"),
        (pred_1_t, "attention"),
        (pred_2_t, "without attention"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Summary statistics across samples
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())