# Flood Model Training Notebook

Train a flood ConvLSTM model using the `usl_models` library.

In [1]:
from usl_models.flood_ml.dataset import  load_dataset_windowed_cached

import tensorflow as tf
import keras_tuner
import time
from datetime import datetime
import keras
import logging
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import load_dataset_windowed
import pathlib
# === CONFIG ===
# Notebook-wide setup: logging threshold, random seed, GPU memory growth, and
# the timestamped log directory name.
# Only WARNING and above are emitted by the root logger.
logging.getLogger().setLevel(logging.WARNING)
# Fixed seed so repeated runs start from the same random state.
keras.utils.set_random_seed(812)

# Turn on memory growth for every GPU that TensorFlow lists.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# The timestamp is taken when this cell runs, so re-running the cell yields a
# new `log_dir` value.
timestamp = time.strftime("%Y%m%d-%H%M%S")
log_dir = f"logs/training_{timestamp}"

2025-10-15 17:59:11.744914: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-10-15 17:59:11.796899: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-10-15 17:59:11.796950: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-10-15 17:59:11.798448: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-10-15 17:59:11.806469: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Report the GPUs TensorFlow can see: the count first, then the device list.
gpu_devices = tf.config.list_physical_devices("GPU")
print("Num GPUs Available: ", len(gpu_devices))
print(gpu_devices)


Num GPUs Available:  1
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [None]:
# Configuration for downloading, training and saving the FloodML model from
# the cached dataset.
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")

# Study area -> config folder. Uncomment entries to include more areas.
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
    # "Phoenix_central": "PHX_CCC",
    # "Atlanta_Prediction": "Atlanta_config",
}

# IDs of the Rainfall_Data_<id>.txt files to use.
rainfall_files = [7, 5, 13, 11, 9, 16, 15, 10, 12, 2, 3]
# rainfall_files = [5]  # single-file variant
dataset_splits = ["test", "train", "val"]
n_flood_maps = 5
m_rainfall = 6
batch_size = 10
epochs = 2

# One simulation name per (city, rainfall file) pair, cities outermost.
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in city_config_mapping.items()
    for rain_id in rainfall_files
]

# === STEP 1: DOWNLOAD DATASET TO FILECACHE ===
# Disabled; uncomment to download the simulations into the local cache.
# print("Downloading simulations into local cache")
# download_dataset(
#     sim_names=sim_names,
#     output_path=filecache_dir,
#     dataset_splits=dataset_splits,
#     include_labels=True,
# )
# print(":white_check_mark: Download complete.")


2025-10-15 14:14:01.567913: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38364 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


In [6]:
# For faster loading during hyperparameter tuning, use this function.
def get_datasets(
    batch_size=4,
    filecache_dir=None,
    city_config_mapping=None,
    rainfall_files=None,
    n_flood_maps=5,
    m_rainfall=6,
    prefetch_buffer=8,
):
    """Build the cached, windowed train and validation datasets.

    Called with only ``batch_size`` (or with no arguments) this loads the same
    simulations with the same settings as before; the extra keyword arguments
    let a caller reuse the notebook-level configuration instead of the values
    baked into this function.

    Args:
        batch_size: Batch size forwarded to the loader.
        filecache_dir: Root of the local file cache. Defaults to
            ``/home/shared/climateiq/filecache``.
        city_config_mapping: Mapping of study-area name to config folder.
            Defaults to ``{"Manhattan": "Manhattan_config"}``.
        rainfall_files: Iterable of rainfall file IDs. Defaults to
            ``[7, 5, 13, 11, 9, 16, 15, 10, 12, 2, 3]``.
        n_flood_maps: Forwarded unchanged to the loader.
        m_rainfall: Forwarded unchanged to the loader.
        prefetch_buffer: Buffer size given to ``.prefetch()`` on both datasets.

    Returns:
        A ``(train_ds, val_ds)`` tuple, each with ``.prefetch(prefetch_buffer)``
        applied to the loader's result.
    """
    # Resolve defaults inside the body so no mutable object is shared between
    # calls.
    if filecache_dir is None:
        filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")
    if city_config_mapping is None:
        city_config_mapping = {"Manhattan": "Manhattan_config"}
    if rainfall_files is None:
        rainfall_files = [7, 5, 13, 11, 9, 16, 15, 10, 12, 2, 3]

    # One simulation name per (city, rainfall file) pair, cities outermost.
    sim_names = [
        f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        for city, config in city_config_mapping.items()
        for rain_id in rainfall_files
    ]

    def _load(dataset_split):
        # Both splits use identical loader settings, including shuffle=True.
        return load_dataset_windowed_cached(
            filecache_dir=filecache_dir,
            sim_names=sim_names,
            dataset_split=dataset_split,
            batch_size=batch_size,
            n_flood_maps=n_flood_maps,
            m_rainfall=m_rainfall,
            shuffle=True,
        ).prefetch(prefetch_buffer)

    train_ds = _load("train")
    val_ds = _load("val")
    return train_ds, val_ds

In [None]:
# === STEP 2: LOAD CACHED WINDOWED DATASETS ===
# Reads the train/val/test splits from `filecache_dir` using the settings from
# the configuration cell (`sim_names`, `batch_size`, `n_flood_maps`,
# `m_rainfall`).
def load_split(dataset_split):
    """Load one split of the cached windowed dataset with prefetching.

    Args:
        dataset_split: Split name passed to the loader ("train", "val" or
            "test" in this notebook).

    Returns:
        The loader's dataset with ``.prefetch(tf.data.AUTOTUNE)`` applied.
    """
    return load_dataset_windowed_cached(
        filecache_dir=filecache_dir,
        sim_names=sim_names,
        dataset_split=dataset_split,
        batch_size=batch_size,
        n_flood_maps=n_flood_maps,
        m_rainfall=m_rainfall,
        shuffle=True,
    ).prefetch(tf.data.AUTOTUNE)


# A single prefetch per dataset is enough: the earlier identity
# `map(lambda x, y: (x, y))` plus second `prefetch` on train/val passed every
# element through unchanged and only added pipeline stages.
# An on-disk `.cache("/tmp/train_cache")` / `.cache("/tmp/val_cache")` was
# previously tried here and left disabled.
train_dataset = load_split("train")
validation_dataset = load_split("val")
test_dataset = load_split("test")


In [3]:
# === Debug dataset structure and counts ===
# One full pass over `train_dataset`: counts the samples and records the tensor
# shapes, printing the first two batches in detail.
num_samples = 0
example_shapes = None
# NOTE(review): never populated in this cell; remove if no other cell uses it.
unique_chunks = set()

print("🔍 Scanning train_dataset...")
for i, (inputs, labels) in enumerate(train_dataset):
    # Count via the batch's own leading dimension, not the configured size.
    num_samples += inputs["geospatial"].shape[0]
    # Overwritten on every batch, so after the loop it holds the shapes of the
    # last batch seen.
    example_shapes = {
        "geospatial": inputs["geospatial"].shape,
        "temporal": inputs["temporal"].shape,
        "spatiotemporal": inputs["spatiotemporal"].shape,
        "labels": labels.shape,
    }
    # Optional: show the first few batches
    if i < 2:
        print(f"\nBatch {i}:")
        for name, shape in example_shapes.items():
            print(f"  {name}: {shape}")
    if i % 100 == 0 and i > 0:
        print(f"  Processed {i} batches so far...")

print("\n✅ Dataset scan complete.")
print(f"Total samples (windows): {num_samples}")
print(f"Example tensor shapes: {example_shapes}")


🔍 Scanning train_dataset...

Batch 0:
  geospatial: (2, 1000, 1000, 9)
  temporal: (2, 5, 6)
  spatiotemporal: (2, 5, 1000, 1000, 1)
  labels: (2, 1000, 1000)

Batch 1:
  geospatial: (2, 1000, 1000, 9)
  temporal: (2, 5, 6)
  spatiotemporal: (2, 5, 1000, 1000, 1)
  labels: (2, 1000, 1000)
  Processed 100 batches so far...
  Processed 200 batches so far...
  Processed 300 batches so far...
  Processed 400 batches so far...
  Processed 500 batches so far...
  Processed 600 batches so far...
  Processed 700 batches so far...
  Processed 800 batches so far...
  Processed 900 batches so far...
  Processed 1000 batches so far...
  Processed 1100 batches so far...
  Processed 1200 batches so far...
  Processed 1300 batches so far...

✅ Dataset scan complete.
Total samples (windows): 2748
Example tensor shapes: {'geospatial': TensorShape([2, 1000, 1000, 9]), 'temporal': TensorShape([2, 5, 6]), 'spatiotemporal': TensorShape([2, 5, 1000, 1000, 1]), 'labels': TensorShape([2, 1000, 1000])}


In [None]:
# NOTE(review): `labels` is not defined in this cell; it relies on the value an
# earlier loop over `train_dataset` left in the kernel, so this fails on a
# fresh kernel. That value comes from the already-windowed dataset, which the
# "before windowing" wording does not reflect — confirm the intent.
print(f"[DEBUG] Labels shape before windowing: {labels.shape}")

In [None]:
# Sanity check: take one batch from `train_dataset` and print the geospatial
# input shape next to the label shape.
for inputs, labels in train_dataset.take(1):
    print(inputs["geospatial"].shape, labels.shape)


In [None]:
# Alternative loading path through `load_dataset_windowed` (no local file
# cache argument). If working locally, comment out this block.
# Running it rebinds `city_config_mapping`, `rainfall_files`, `sim_names`,
# `train_dataset` and `validation_dataset`.

# Cities and their config folders
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
}

# Rainfall file IDs to use (currently only file 5).
rainfall_files = [5]

# Expand every (city, rainfall file) pair into a simulation name.
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in city_config_mapping.items()
    for rain_id in rainfall_files
]

print(f"Training on {len(sim_names)} simulations.")
for sim in sim_names:
    print(sim)

# Load the train and validation splits with shared settings, then cache each.
windowed_kwargs = {"sim_names": sim_names, "batch_size": 2}
train_dataset = load_dataset_windowed(dataset_split="train", **windowed_kwargs).cache()
validation_dataset = load_dataset_windowed(dataset_split="val", **windowed_kwargs).cache()

In [7]:
import gc
# Bayesian-optimisation tuner over the FloodModel hypermodel. Each list holds
# the candidate values for that hyperparameter; `n_flood_maps` and `m_rainfall`
# have a single candidate each, so they are effectively fixed.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(
        lstm_units=[32, 64, 128],
        lstm_kernel_size=[3, 5],
        lstm_dropout=[0.2, 0.3],
        lstm_recurrent_dropout=[0.2, 0.3],
        n_flood_maps=[5],
        m_rainfall=[6],
    ),
    objective="val_loss",
    max_trials=1,  # increase if you want more search
    # The notebook's log directory string doubles as the tuner project name.
    project_name=log_dir,
)

# NOTE(review): this callback is created here, but `tuner_search` below does
# not pass it to `tuner.search` — confirm whether it should be.
tb_callback = keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=0,
    profile_batch=0,
)

def tuner_search(
    batch_size=4,
    epochs=10,
    train_batches=200,
    val_batches=50,
    callbacks=None,
    verbose=1,
):
    """Run the notebook-level ``tuner`` on a truncated train/validation stream.

    With default arguments this is the same search as before: 200 training
    batches, 50 validation batches, 10 epochs, no callbacks. The extra
    parameters expose those previously hard-coded values.

    Args:
        batch_size: Batch size forwarded to ``get_datasets``.
        epochs: Epochs per trial.
        train_batches: Number of training batches taken per epoch.
        val_batches: Number of validation batches taken per epoch.
        callbacks: Optional iterable of Keras callbacks for ``tuner.search``;
            ``None`` means no callbacks.
        verbose: Verbosity forwarded to ``tuner.search``.
    """
    # Clear any leftover sessions
    gc.collect()
    tf.keras.backend.clear_session()

    train_ds, val_ds = get_datasets(batch_size=batch_size)

    tuner.search(
        train_ds.take(train_batches),
        validation_data=val_ds.take(val_batches),
        epochs=epochs,
        # Always hand the tuner a fresh list so the caller's is never mutated.
        callbacks=list(callbacks) if callbacks is not None else [],
        verbose=verbose,
    )

# Debug aid: turns on TensorFlow's device-placement logging before the search.
# This is a process-wide setting and stays on for later cells.
tf.debugging.set_log_device_placement(True)
tuner_search(batch_size=2)

# Keep the first entry of the tuner's best models and best hyperparameters.
best_model, best_hp = tuner.get_best_models()[0], tuner.get_best_hyperparameters()[0]
print("Best hyperparameters:", best_hp.values)


Trial 1 Complete [00h 04m 58s]
val_loss: 0.0009991672122851014

Best val_loss So Far: 0.0009991672122851014
Total elapsed time: 00h 04m 58s
Best hyperparameters: {'lstm_units': 64, 'lstm_kernel_size': 3, 'lstm_dropout': 0.3, 'lstm_recurrent_dropout': 0.3, 'n_flood_maps': 5, 'm_rainfall': 6}


In [5]:
# Candidate values for each hyperparameter; single-element lists are fixed.
search_space = dict(
    lstm_units=[32, 64, 128],
    lstm_kernel_size=[3, 5],
    lstm_dropout=[0.2, 0.3],
    lstm_recurrent_dropout=[0.2, 0.3],
    n_flood_maps=[5],
    m_rainfall=[6],
)

# Rebinds `tuner` to a Bayesian-optimisation tuner stored under the
# "htune_project" name for this run's timestamp.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(**search_space),
    objective="val_loss",
    max_trials=1,
    project_name=f"logs/htune_project_{timestamp}",
)

tuner.search_space_summary()

Search space summary
Default search space size: 6
lstm_units (Choice)
{'default': 32, 'conditions': [], 'values': [32, 64, 128], 'ordered': True}
lstm_kernel_size (Choice)
{'default': 3, 'conditions': [], 'values': [3, 5], 'ordered': True}
lstm_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
lstm_recurrent_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
n_flood_maps (Choice)
{'default': 5, 'conditions': [], 'values': [5], 'ordered': True}
m_rainfall (Choice)
{'default': 6, 'conditions': [], 'values': [6], 'ordered': True}


In [None]:

# Debug aid: turns on TensorFlow's device-placement logging (process-wide).
tf.debugging.set_log_device_placement(True)
# Rebinds the notebook-level `log_dir`: any cell run after this one that reads
# `log_dir` (TensorBoard, checkpoints, saved model) uses this htune path.
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
# Search over the first 200 training batches and 50 validation batches of the
# notebook-level datasets, for 10 epochs per trial.
tuner.search(
    train_dataset.take(200),
    epochs=10,
    validation_data=validation_dataset.take(50),
    callbacks=[tb_callback],
)
# Keep the first entry of the tuner's best models and best hyperparameters.
best_model, best_hp = tuner.get_best_models()[0], tuner.get_best_hyperparameters()[0]
# Bare expression so the notebook displays the chosen values.
best_hp.values

logs/htune_project_20251015-141132

Search: Running Trial #1

Value             |Best Value So Far |Hyperparameter
64                |64                |lstm_units
3                 |3                 |lstm_kernel_size
0.3               |0.3               |lstm_dropout
0.3               |0.3               |lstm_recurrent_dropout
5                 |5                 |n_flood_maps
6                 |6                 |m_rainfall

Epoch 1/10


2025-10-15 14:13:15.070899: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inflood_conv_lstm/conv_lstm/conv_lstm2d/while/body/_1/flood_conv_lstm/conv_lstm/conv_lstm2d/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer
2025-10-15 14:13:18.007358: I external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:454] Loaded cuDNN version 8900
2025-10-15 14:13:24.118440: I external/local_xla/xla/service/service.cc:168] XLA service 0x7f3560d58e40 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
2025-10-15 14:13:24.118512: I external/local_xla/xla/service/service.cc:176]   StreamExecutor device (0): NVIDIA A100-SXM4-40GB, Compute Capability 8.0
2025-10-15 14:13:24.126012: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:269] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
I0000 00:

In [11]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
# Fresh train/validation datasets for the final run.
train_ds, val_ds = get_datasets(batch_size=2)
# Define final parameters and model
# `best_hp` comes from whichever tuner cell ran last; its values are copied so
# the tuner's own dict is not shared.
final_params_dict = best_hp.values.copy()
final_params = FloodModel.Params(**final_params_dict)
model = FloodModel(params=final_params)
# Define callbacks
# TensorBoard logs, the checkpoint and the final model all go under `log_dir`.
callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    ModelCheckpoint(
        filepath=log_dir + "/checkpoint",
        save_best_only=True,
        monitor="val_loss",
        mode="min",
        save_format="tf",
    ),
    # NOTE(review): patience=100 exceeds the epochs=2 passed to fit below, so
    # early stopping presumably cannot trigger in this run — confirm.
    EarlyStopping(  # stop once val_loss stops improving
        monitor="val_loss",  # What to monitor
        patience=100,  # Number of epochs with no improvement to wait
        restore_best_weights=True,  # Restore model weights from best epoch
        mode="min",  # "min" because lower val_loss is better
    ),
]

# Train
# NOTE(review): `val_ds` is passed positionally as the second argument; this
# relies on FloodModel.fit's own signature (not visible here) — confirm.
model.fit(train_ds, val_ds, epochs=2, callbacks=callbacks)

# Save final model
model.save_model(log_dir + "/model")

Epoch 1/2


2025-10-15 16:50:56.078717: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inflood_conv_lstm_3/conv_lstm/conv_lstm2d_3/while/body/_1/flood_conv_lstm_3/conv_lstm/conv_lstm2d_3/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


   1374/Unknown - 178s 125ms/step - loss: 0.0019 - mean_absolute_error: 0.0137 - root_mean_squared_error: 0.0696

2025-10-15 16:54:20.263860: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 3258415898515464409
2025-10-15 16:54:20.263930: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 14801390994140985373
2025-10-15 16:54:20.263940: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 4185487910489117445
2025-10-15 16:54:20.263952: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 10507816670147353428
2025-10-15 16:54:20.263973: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 16856470375749547650


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Assets written to: logs/training_20251015-164325/checkpoint/assets


INFO:tensorflow:Assets written to: logs/training_20251015-164325/checkpoint/assets


Epoch 2/2


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Assets written to: logs/training_20251015-164325/checkpoint/assets


INFO:tensorflow:Assets written to: logs/training_20251015-164325/checkpoint/assets


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cd3e50>, 140089889807856), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914cdedd0>, 140089889808192), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b1eb50>, 140089300006848), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6914b39a50>, 140089300007184), {}).


INFO:tensorflow:Assets written to: logs/training_20251015-164325/model/assets


INFO:tensorflow:Assets written to: logs/training_20251015-164325/model/assets


In [None]:
# Test calling the model on some data.
# Uses the first batch of the notebook-level `train_dataset`; `labels_` is
# unpacked but not used.
inputs, labels_ = next(iter(train_dataset))
prediction = model.call(inputs)
# Bare expression so the notebook displays the prediction's shape.
prediction.shape

TensorShape([2, 1000, 1000, 1])

In [14]:
# Prediction mode: load an unlabeled dataset for a study area, paired with a
# named rainfall simulation.
from usl_models.flood_ml import dataset
# Parameters
filecache_dir = pathlib.Path("/home/shared/climateiq/filecache")
# Prediction inputs. `sim_name` is a one-element list despite the singular
# name; it is passed as `sim_names` below.
sim_name = ["Atlanta_Prediction"]
rainfall_sim = "Atlanta-Atlanta_config/Rainfall_Data_22.txt"



# Download (prediction mode) — disabled; uncomment to fetch into the cache.
# dataset.download_dataset(
#     sim_names=sim_name,          # study area
#     output_path=filecache_dir,
#     include_labels=False,                      # no labels
#     rainfall_sim_name=rainfall_sim,  # simulation for temporal vector
#     allow_missing_sim=True                     # skip temporal if missing
# )
# Load the dataset (prediction mode): no split, no labels.
full_dataset = dataset.load_dataset_cached(
    filecache_dir=filecache_dir,
    sim_names=sim_name,            # study area
    dataset_split=None,                          # no split for prediction
    batch_size=2,
    include_labels=False,
    rainfall_sim_name=rainfall_sim  # actual rainfall sim
)


# Download (training mode) — disabled alternative.
# dataset_splits = ["test", "train", "val"]
# dataset.download_dataset(
#     sim_names=["Atlanta-Atlanta_config/Rainfall_Data_22.txt"],  # normal simulations
#     output_path=filecache_dir,
#     dataset_splits=dataset_splits,               # train/val/test splits
#     include_labels=True                        # get labels too
# )

# Load (training mode) — disabled alternative.
# NOTE(review): this loads Rainfall_Data_20 while the download block above
# names Rainfall_Data_22 — confirm which is intended before enabling.
# full_dataset = dataset.load_dataset_cached(
#     filecache_dir=filecache_dir,
#     sim_names=["Atlanta-Atlanta_config/Rainfall_Data_20.txt"],
#     dataset_split="train",
#     include_labels=True
# )


In [15]:
# Load the saved model from disk and run an N-step prediction on one batch of
# `full_dataset`.
import tensorflow as tf
from usl_models.flood_ml.model import FloodModel, SpatialAttention
# Number of prediction steps requested from `call_n` below.
N_steps = 4
# Path to your saved model
# NOTE(review): absolute, user-specific path; it must match the run directory
# the training cell wrote to.
model_path = "/home/se2890/climateiq-cnn-9/logs/training_20251015-164325/model"
# First load: used only to print the summary.
loaded_model = tf.keras.models.load_model(model_path)
loaded_model.summary()
# Load the model
# Second load of the same path; this is the object used for prediction.
model = tf.keras.models.load_model(model_path)
# model = FloodModel.from_checkpoint(model_path)

# Third load of the same path, this time with SpatialAttention registered and
# compile=False. `SpatialAttention` was already imported at the top of the cell.
from usl_models.flood_ml.model import SpatialAttention
custom_objects = {'SpatialAttention': SpatialAttention}
loaded_model = tf.keras.models.load_model(
    model_path,
    custom_objects=custom_objects,
    compile=False
)
# NOTE(review): both objects were loaded from the same `model_path`, so this
# copy presumably leaves the weights unchanged — confirm whether the extra
# loads are still needed.
model.set_weights(loaded_model.get_weights())

# Test calling the model for n predictions
# full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split= "train")
# Each element of `full_dataset` is unpacked as three items here. It was loaded
# with include_labels=False, so the name `labels` may be misleading — confirm.
inputs, labels, _ = next(iter(full_dataset))
predictions = model.call_n(inputs, n=N_steps)
# Bare expression so the notebook displays the predictions' shape.
predictions.shape

Model: "flood_conv_lstm_3"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 spatiotemporal_cnn (Sequen  (None, None, 250, 250,    3424      
 tial)                       16)                                 
                                                                 
 geospatial_cnn (Sequential  (None, 250, 250, 64)      29280     
 )                                                               
                                                                 
 spatial_attention_8 (Spati  multiple                  99        
 alAttention)                                                    
                                                                 
 conv_lstm (Sequential)      (None, 250, 250, 64)      345856    
                                                                 
 spatial_attention_9 (Spati  multiple                  99        
 alAttention)                                    

TensorShape([2, 4, 1000, 1000])

In [None]:
# loss_scale = best_hp.get("loss_scale")
# print("Loss scale used during training:", loss_scale)

In [None]:
# Visual comparison of ground truth and prediction for the first `n_samples`
# batches of the notebook-level `validation_dataset`.
# NOTE(review): `np`, `load_dataset_windowed`, `constants` and `FloodModel` are
# imported here but not used in this cell.
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml import customloss
# Loss scale; must match the value used when the model was trained.
loss_scale = 200.0

# Path to trained model
# NOTE(review): absolute, user-specific path to a different run than the
# `log_dir` model saved earlier in this notebook — confirm it is the intended
# model.
model_path = "/home/se2890/climateiq-cnn-5/logs/htune_project_20250801-155126/model"

# Create the loss function with the correct scale
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load model with custom loss function
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})
# Number of batches to visualize (one figure per batch taken).
n_samples = 20

# Loop through the dataset and predict
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    # squeeze() drops every size-1 axis from both arrays.
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    print(f"\nSample {i+1} Prediction Stats:")
    print("  Min:", prediction.min())
    print("  Max:", prediction.max())
    print("  Mean:", prediction.mean())

    # Choose timestep to plot
    # NOTE(review): the train_dataset scan earlier in this notebook shows labels
    # shaped (batch, H, W); if validation batches match, this index selects a
    # batch element rather than a timestep, and raises IndexError when the
    # batch has fewer than 4 elements — confirm.
    timestep = 3
    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    # Shared colour-scale maximum so both panels are directly comparable.
    vmax_val = max(gt_t.max(), pred_t.max())

    # Plot Ground Truth and Prediction
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Imported in this cell so it runs on its own instead of relying on another
# cell having already put `customloss` into the kernel namespace.
from usl_models.flood_ml import customloss

# Loss scale known to have been used during training.
loss_scale = 150.0

# Path to trained model
# NOTE(review): absolute, user-specific path; adjust it for your environment.
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

# Rebuild the hybrid loss with that scale and register it under "loss_fn" so
# load_model can resolve the custom object.
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})


def _masked_ssim(reference, estimate, valid):
    """Return the SSIM of `estimate` against `reference`, tolerating NaN pixels.

    Args:
        reference: Ground-truth image (expected 2-D); may contain NaN.
        estimate: Predicted image with the same shape as `reference`.
        valid: Boolean array, True where `reference` holds a usable value.

    Returns:
        The SSIM value. NaN when there is no valid pixel or when the estimate
        is NaN on a valid pixel; 1.0 when both images are the same constant on
        the valid pixels.

    Pixels outside `valid` are set to 0 in both images before scoring, so they
    count as perfect agreement rather than turning the whole score into NaN.
    When `valid` is all True the result equals a plain
    `ssim(reference, estimate, data_range=reference.max() - reference.min())`,
    except for a constant reference (see the fallback below).
    """
    if not valid.any() or np.isnan(estimate[valid]).any():
        return float("nan")

    ref_valid = reference[valid]
    data_range = ref_valid.max() - ref_valid.min()
    if not data_range > 0:
        # Constant reference: a zero range would remove SSIM's stabilising
        # constants, so fall back to the range spanned by both images.
        joint = np.concatenate([ref_valid, estimate[valid]])
        data_range = joint.max() - joint.min()
        if not data_range > 0:
            return 1.0

    return ssim(
        np.where(valid, reference, 0.0),
        np.where(valid, estimate, 0.0),
        data_range=data_range,
    )


# Assuming validation_dataset is already defined
# Example:
# from usl_models.flood_ml.dataset import load_dataset_windowed
# validation_dataset = load_dataset_windowed(...)

n_samples = 20
timestep = 2
iou_threshold = 0.1  # values above this count as positive for the IoU
metrics_list = []

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    # NOTE(review): squeeze() removes every size-1 axis; indexing by timestep
    # below presumably relies on a batch size of 1 -- confirm.
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    # Clip the colour scale at the 99.5th percentile of both images.
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Mask out NaNs in the ground truth; boolean indexing already yields 1-D arrays.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask]
    pred_flat = pred_t[mask]

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)

    gt_pos = gt_flat > iou_threshold
    pred_pos = pred_flat > iou_threshold
    # max(1, ...) avoids a division by zero when neither image exceeds the threshold.
    iou = np.logical_and(gt_pos, pred_pos).sum() / max(1, np.logical_or(gt_pos, pred_pos).sum())

    ssim_val = _masked_ssim(gt_t, pred_t, mask)

    metrics_list.append({
        "Sample": i+1,
        "MAE": mae,
        "RMSE": rmse,
        "Bias": bias,
        "IoU > 0.1": iou,
        "SSIM": ssim_val
    })

    # Plot ground truth and prediction side by side on the same colour scale.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    for ax, image, title in zip(axes, (gt_t, pred_t), ("Ground Truth", "Prediction")):
        im = ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

# Convert to DataFrame
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())


In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml import customloss
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Parameters
loss_scale = 200.0
timestep = 3
n_samples = 20
iou_threshold = 0.1  # values above this count as positive for the IoU

# Paths to models
# NOTE(review): absolute, user-specific paths; adjust them for your environment.
model_path_1 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/attention/model"
model_path_2 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250612-010926/model"

# Rebuild the hybrid loss and register it under "loss_fn" so load_model can
# resolve the custom object for both models.
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load models
model_1 = tf.keras.models.load_model(model_path_1, custom_objects={"loss_fn": loss_fn})
model_2 = tf.keras.models.load_model(model_path_2, custom_objects={"loss_fn": loss_fn})


def _model_metrics(gt_t, pred_t, mask, suffix):
    """Score one model's prediction against the ground truth for one timestep.

    Args:
        gt_t: Ground-truth image (expected 2-D); may contain NaN.
        pred_t: Predicted image with the same shape as `gt_t`.
        mask: Boolean array, True where `gt_t` is not NaN.
        suffix: Label appended to each metric name, e.g. "1" gives "MAE_1".

    Returns:
        Dict with MAE, RMSE, Bias, IoU and SSIM entries keyed `<name>_<suffix>`,
        in that order.

    MAE, RMSE, Bias and IoU use only the pixels selected by `mask`. For SSIM,
    pixels outside `mask` are set to 0 in both images, so they count as
    agreement instead of turning the score into NaN.
    """
    gt_flat = gt_t[mask]
    pred_flat = pred_t[mask]

    gt_pos = gt_flat > iou_threshold
    pred_pos = pred_flat > iou_threshold
    # max(1, ...) avoids a division by zero when neither image exceeds the threshold.
    iou = np.logical_and(gt_pos, pred_pos).sum() / max(1, np.logical_or(gt_pos, pred_pos).sum())

    if np.isnan(pred_flat).any():
        ssim_val = float("nan")
    else:
        data_range = gt_flat.max() - gt_flat.min()
        if not data_range > 0:
            # Constant ground truth: a zero range would remove SSIM's
            # stabilising constants, so use the range spanned by both images.
            joint = np.concatenate([gt_flat, pred_flat])
            data_range = joint.max() - joint.min()
        if data_range > 0:
            ssim_val = ssim(
                np.where(mask, gt_t, 0.0),
                np.where(mask, pred_t, 0.0),
                data_range=data_range,
            )
        else:
            # Both images are the same constant on the valid pixels.
            ssim_val = 1.0

    return {
        f"MAE_{suffix}": mean_absolute_error(gt_flat, pred_flat),
        f"RMSE_{suffix}": np.sqrt(mean_squared_error(gt_flat, pred_flat)),
        f"Bias_{suffix}": np.mean(pred_flat) - np.mean(gt_flat),
        f"IoU_{suffix}": iou,
        f"SSIM_{suffix}": ssim_val,
    }


# Load validation dataset (ensure it's already prepared)
# Example:
# validation_dataset = load_dataset_windowed(...)

metrics_list = []

# NOTE(review): this loop previously iterated `train_dataset`, which contradicts
# the comment above and the other evaluation cells. It now uses the validation
# split; switch back if scoring on training data was intentional.
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    # NOTE(review): squeeze() removes every size-1 axis; indexing by timestep
    # below presumably relies on a batch size of 1 -- confirm.
    ground_truth = ground_truth.numpy().squeeze()

    pred_1 = model_1(input_data).numpy().squeeze()
    pred_2 = model_2(input_data).numpy().squeeze()

    gt_t = ground_truth[timestep]
    pred_1_t = pred_1[timestep]
    pred_2_t = pred_2[timestep]
    # Clip the shared colour scale at the 99.5th percentile of all three images.
    vmax_val = np.nanpercentile([gt_t, pred_1_t, pred_2_t], 99.5)

    mask = ~np.isnan(gt_t)

    # Compute metrics for both models with the same helper so they stay comparable.
    metrics_list.append({
        "Sample": i+1,
        **_model_metrics(gt_t, pred_1_t, mask, "1"),
        **_model_metrics(gt_t, pred_2_t, mask, "2"),
    })

    # Plotting: ground truth and both predictions on one colour scale.
    fig, axes = plt.subplots(1, 3, figsize=(21, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = (
        (gt_t, "Ground Truth"),
        (pred_1_t, "attention"),
        (pred_2_t, "without attention"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

# Summary metrics
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())
