# Flood Model Training Notebook

Train a Flood ConvLSTM Model using `usl_models` lib.

In [1]:
import tensorflow as tf
import keras_tuner
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model_params import FloodModelParams
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset
from usl_models.flood_ml import customloss

# Setup
logging.getLogger().setLevel(logging.WARNING)
keras.utils.set_random_seed(812)

for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

timestamp = time.strftime("%Y%m%d-%H%M%S")

# Cities and their config folders
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
}

# Rainfall files you want
rainfall_files = [5]  # Only 5 and 6

# Generate sim_names
sim_names = []
for city, config in city_config_mapping.items():
    for rain_id in rainfall_files:
        sim_name = f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        sim_names.append(sim_name)

print(f"Training on {len(sim_names)} simulations.")
for s in sim_names:
    print(s)

# Now load dataset
train_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="train"
).cache()

validation_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="val"
).cache()

2025-08-15 21:51:47.539372: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-08-15 21:51:47.590155: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-08-15 21:51:47.590186: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-08-15 21:51:47.591431: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-08-15 21:51:47.599068: I tensorflow/core/platform/cpu_feature_guar

Training on 1 simulations.
Manhattan-Manhattan_config/Rainfall_Data_5.txt


2025-08-15 21:51:52.622128: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1929] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 38364 MB memory:  -> device: 0, name: NVIDIA A100-SXM4-40GB, pci bus id: 0000:00:04.0, compute capability: 8.0


In [2]:
full_dataset = load_dataset(
    sim_names=sim_names,
    dataset_split=None,    # ← use ALL chunks (no train/val/test filter)
    batch_size=4 
)

In [3]:
import tensorflow as tf
from usl_models.flood_ml.dataset import load_dataset, load_dataset_windowed

# --- your existing sim_names here ---
# sim_names = [...]

def count_dataset(ds):
    return sum(1 for _ in ds)

# 1) CHUNKS (non-windowed) --------------------------
train_chunks_ds = load_dataset(
    sim_names=sim_names,
    dataset_split="train",
    batch_size=None,          # <- no batching: 1 element == 1 chunk
)
all_chunks_ds = load_dataset(
    sim_names=sim_names,
    dataset_split=None,       # <- ALL splits combined
    batch_size=None,
)

train_chunks = count_dataset(train_chunks_ds)
all_chunks = count_dataset(all_chunks_ds)

print(f"Chunks in TRAIN only: {train_chunks}")
print(f"Chunks in ALL splits: {all_chunks}")

# 2) WINDOWS (windowed for teacher-forcing) --------
train_windows_ds = load_dataset_windowed(
    sim_names=sim_names,
    dataset_split="train",
    batch_size=None,          # <- no batching: 1 element == 1 window
)
all_windows_ds = load_dataset_windowed(
    sim_names=sim_names,
    dataset_split=None,       # <- ALL splits combined (works with your updated code)
    batch_size=None,
)

train_windows = count_dataset(train_windows_ds)
all_windows = count_dataset(all_windows_ds)

print(f"Windows in TRAIN only: {train_windows}")
print(f"Windows in ALL splits: {all_windows}")

# 3) (Optional) peek at one example shape ----------
ex_inputs, ex_labels = next(iter(all_chunks_ds.take(1)))
print("Example CHUNK shapes:")
print("  spatiotemporal:", ex_inputs["spatiotemporal"].shape)  # [N, H, W, 1]
print("  geospatial:    ", ex_inputs["geospatial"].shape)       # [H, W, F]
print("  temporal:      ", ex_inputs["temporal"].shape)         # [T_MAX, M]
print("  labels:        ", ex_labels.shape)                     # [T_label, H, W]


Chunks in TRAIN only: 13
Chunks in ALL splits: 26
Windows in TRAIN only: 169
Windows in ALL splits: 338
Example CHUNK shapes:
  spatiotemporal: (5, 1000, 1000, 1)
  geospatial:     (1000, 1000, 9)
  temporal:       (864, 6)
  labels:         (13, 1000, 1000)


In [4]:
# This will iterate the dataset WITHOUT batching so you see the raw order
debug_dataset = load_dataset(
    sim_names=sim_names,
    dataset_split=None,   # full (train+val+test)
    batch_size=None       # no batching so we see per-chunk order
)

print("First 20 chunk indices from full dataset:")
for i, (features, labels) in enumerate(debug_dataset.take(20)):
    # Pull the chunk's position from Firestore metadata order
    # Features are shape [H, W, ...], but we don't know the indices unless we print in _iter_geo_feature_label_tensors
    print(f"Sample {i}: feature shape = {features['geospatial'].shape}, label shape = {labels.shape}")

# If you want exact (x_index, y_index) printed:
# Add a debug print in _iter_geo_feature_label_tensors before yield:
# print("Yielding chunk", index)


First 20 chunk indices from full dataset:
Sample 0: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 1: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 2: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 3: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 4: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 5: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 6: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 7: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 8: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 9: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 10: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 11: feature shape = (1000, 1000, 9), label shape = (13, 1000, 1000)
Sample 12: feature shape = (1000, 1000, 9), label shape = (13, 1000,

In [5]:
full_dataset = load_dataset(
    sim_names=sim_names,
    dataset_split=None,  # Load all splits combined
    batch_size=None      # No batching, one chunk at a time
)

from usl_models.flood_ml import metastore
from google.cloud import firestore, storage

firestore_client = firestore.Client()
storage_client = storage.Client()

for sim_name in sim_names:
    # Get ALL label chunks without split filter
    label_chunks_collection = metastore._get_simulation_doc(
        firestore_client, sim_name
    ).collection("label_chunks")
    label_metadata = [doc.to_dict() for doc in label_chunks_collection.stream()]

    # Sort by (x_index, y_index) to match dataset order
    ordered_labels = sorted(
        label_metadata, key=lambda l: (l["x_index"], l["y_index"])
    )

    # Iterate over ALL chunks (no slicing)
    for i, label in enumerate(ordered_labels):
        print(f"Chunk {i}: {label['gcs_uri']}")


Chunk 0: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/1_7.npy
Chunk 1: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/1_8.npy
Chunk 2: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/1_9.npy
Chunk 3: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/1_10.npy
Chunk 4: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/2_5.npy
Chunk 5: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/2_6.npy
Chunk 6: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/2_7.npy
Chunk 7: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/2_8.npy
Chunk 8: gs://test-climateiq-study-area-label-chunks/Manhattan/Manhattan_config/Rainfall_Data_5.txt/2_9.npy
Chunk 9: gs://test-climatei

In [6]:
def count_dataset(ds):
    return sum(1 for _ in ds)

# No batching so we count raw elements
train_ds = load_dataset(sim_names=sim_names, dataset_split="train", batch_size=None)
val_ds   = load_dataset(sim_names=sim_names, dataset_split="val", batch_size=None)
test_ds  = load_dataset(sim_names=sim_names, dataset_split="test", batch_size=None)
full_ds  = load_dataset(sim_names=sim_names, dataset_split=None,    batch_size=None)

train_count = count_dataset(train_ds)
val_count   = count_dataset(val_ds)
test_count  = count_dataset(test_ds)
full_count  = count_dataset(full_ds)

print(f"Train chunks: {train_count}")
print(f"Val chunks:   {val_count}")
print(f"Test chunks:  {test_count}")
print(f"Full chunks:  {full_count}")
print(f"Sum of splits: {train_count + val_count + test_count}")
print("MATCH:", full_count == (train_count + val_count + test_count))


Train chunks: 13
Val chunks:   6
Test chunks:  7
Full chunks:  26
Sum of splits: 26
MATCH: True


In [7]:
train_count = sum(1 for _ in train_dataset)
full_count = sum(1 for _ in full_dataset)

print(f"Train batches: {train_count}")
print(f"Full batches: {full_count}")

Train batches: 43
Full batches: 26


In [8]:
full_chunks = 0
for inputs, labels in full_dataset:
    full_chunks += int(inputs["geospatial"].shape[0])  # batch size for this batch
print("Full chunks (all splits combined):", full_chunks)


Full chunks (all splits combined): 26000


In [9]:
T_max = 169  # adjust
train_windows = 0
for inputs, label in train_dataset:  # windowed: label shape [B, H, W]
    train_windows += int(label.shape[0])
approx_train_chunks = train_windows // T_max
print("Train windows:", train_windows)
print("≈ Train chunks:", approx_train_chunks)


Train windows: 169
≈ Train chunks: 1


In [10]:
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(
        lstm_units=[32, 64, 128],
        lstm_kernel_size=[3, 5],
        lstm_dropout=[0.2, 0.3],
        lstm_recurrent_dropout=[0.2, 0.3],
        n_flood_maps=[5],
        m_rainfall=[6],
    ),
    objective="val_loss",
    max_trials=1,
    project_name=f"logs/htune_project_{timestamp}",
)

tuner.search_space_summary()

Search space summary
Default search space size: 6
lstm_units (Choice)
{'default': 32, 'conditions': [], 'values': [32, 64, 128], 'ordered': True}
lstm_kernel_size (Choice)
{'default': 3, 'conditions': [], 'values': [3, 5], 'ordered': True}
lstm_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
lstm_recurrent_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
n_flood_maps (Choice)
{'default': 5, 'conditions': [], 'values': [5], 'ordered': True}
m_rainfall (Choice)
{'default': 6, 'conditions': [], 'values': [6], 'ordered': True}


In [11]:
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
tuner.search(
    train_dataset,
    epochs=2,
    validation_data=validation_dataset,
    callbacks=[tb_callback],
)
best_model, best_hp = tuner.get_best_models()[0], tuner.get_best_hyperparameters()[0]
best_hp.values

Trial 1 Complete [00h 00m 49s]
val_loss: 0.0028904567006975412

Best val_loss So Far: 0.0028904567006975412
Total elapsed time: 00h 00m 49s


{'lstm_units': 32,
 'lstm_kernel_size': 3,
 'lstm_dropout': 0.2,
 'lstm_recurrent_dropout': 0.2,
 'n_flood_maps': 5,
 'm_rainfall': 6}

In [12]:
from keras.callbacks import ModelCheckpoint, EarlyStopping

# Define final parameters and model
final_params_dict = best_hp.values.copy()
final_params = FloodModel.Params(**final_params_dict)
model = FloodModel(params=final_params)
# Define callbacks
callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    ModelCheckpoint(
        filepath=log_dir + "/checkpoint",
        save_best_only=True,
        monitor="val_loss",
        mode="min",
        save_format="tf",
    ),
    EarlyStopping(  # <--- ADD THIS
        monitor="val_loss",  # What to monitor
        patience=100,  # Number of epochs with no improvement to wait
        restore_best_weights=True,  # Restore model weights from best epoch
        mode="min",  # "min" because lower val_loss is better
    ),
]

# Train
model.fit(train_dataset, validation_dataset, epochs=2, callbacks=callbacks)

# Save final model
model.save_model(log_dir + "/model")

Epoch 1/2


2025-08-15 22:00:13.310409: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inflood_conv_lstm_1/conv_lstm/conv_lstm2d_1/while/body/_1/flood_conv_lstm_1/conv_lstm/conv_lstm2d_1/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


     43/Unknown - 13s 182ms/step - loss: 0.0040 - mean_absolute_error: 0.0254 - root_mean_squared_error: 0.1021

2025-08-15 22:00:22.939146: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 216111812316947242


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/checkpoint/assets


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/checkpoint/assets


Epoch 2/2


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/checkpoint/assets


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/checkpoint/assets


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f045fc210>, 140115259971760), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f046070d0>, 140115259972096), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0464f210>, 140115258687056), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f6f0465a2d0>, 140115258785840), {}).


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/model/assets


INFO:tensorflow:Assets written to: logs/htune_project_20250815-215150/model/assets


In [None]:
# # Test calling the model on some data.
inputs, labels_ = next(iter(train_dataset))
prediction = model.call(inputs)
prediction.shape

In [13]:
import tensorflow as tf
from usl_models.flood_ml.model import FloodModel, SpatialAttention
# Path to your saved model
model_path = "/home/se2890/climateiq-cnn-6/logs/htune_project_20250815-215150/model"
#loaded_model = tf.keras.models.load_model(model_path)
#loaded_model.summary()
# Load the model
model = tf.keras.models.load_model(model_path)

from usl_models.flood_ml.model import SpatialAttention
custom_objects = {'SpatialAttention': SpatialAttention}
loaded_model = tf.keras.models.load_model(
    model_path,
    custom_objects=custom_objects,
    compile=False
)
model.set_weights(loaded_model.get_weights())

# # # Test calling the model for n predictions
full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split= "train")
inputs, labels = next(iter(full_dataset))
predictions = model.call_n(inputs, n=10)
predictions.shape

TensorShape([4, 10, 1000, 1000])

In [None]:
ref_shapes = None

for i, (inputs, labels) in enumerate(full_dataset):
    current_shapes = (
        inputs["spatiotemporal"].shape,
        inputs["geospatial"].shape,
        inputs["temporal"].shape,
        labels.shape,
    )

    if ref_shapes is None:
        ref_shapes = current_shapes
    else:
        assert current_shapes == ref_shapes, f"Mismatch at batch {i}: {current_shapes} ≠ {ref_shapes}"


In [None]:
full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split="train")

all_preds = []
all_labels = []

for i, (inputs, labels) in enumerate(full_dataset):
    print(f"\n--- Batch {i} ---")
    
    st = inputs["spatiotemporal"]
    geo = inputs["geospatial"]
    temp = inputs["temporal"]

    print(f"spatiotemporal shape: {st.shape}")
    print(f"geospatial shape:     {geo.shape}")
    print(f"temporal shape:       {temp.shape}")
    print(f"labels shape:         {labels.shape}")

    try:
        preds = model.call_n(inputs, n=10)
        print(f"predictions shape:    {preds.shape}")
        all_preds.append(preds)
        all_labels.append(labels)
    except Exception as e:
        print(f"Error at batch {i}: {e}")
        break


In [None]:
BATCH_SIZE = 4
N_STEPS = 10

all_preds = []
all_labels = []

for i, (inputs, labels) in enumerate(full_dataset):
    current_bs = inputs["spatiotemporal"].shape[0]

    if current_bs < BATCH_SIZE:
        print(f"[Batch {i}] Incomplete batch of size {current_bs}, padding to {BATCH_SIZE}")

        # Repeat the last sample to pad
        repeats = BATCH_SIZE - current_bs

        def pad_tensor(t):
            return tf.concat([t, tf.repeat(t[-1:], repeats=repeats, axis=0)], axis=0)

        padded_inputs = {
            k: pad_tensor(v) for k, v in inputs.items()
        }

        # Predict on padded batch
        preds_padded = model.call_n(padded_inputs, n=N_STEPS)  # [B, T, H, W]

        # Remove the extra samples
        preds = preds_padded[:current_bs]
    else:
        preds = model.call_n(inputs, n=N_STEPS)

    print(f"[Batch {i}] Prediction shape: {preds.shape}")
    all_preds.append(preds)
    all_labels.append(labels)


In [None]:
# After running all batches
final_preds = tf.concat(all_preds, axis=0)  # shape: [N, T, H, W]
max_preds_all = tf.reduce_max(final_preds, axis=1)  # shape: [N, H, W]
max_preds_all.shape

In [None]:
max_labels_all = []

for labels in all_labels:
    max_labels = tf.reduce_max(labels, axis=1)  # shape: [B, H, W]
    max_labels_all.append(max_labels)

# Now stack (uses less memory)
max_labels_all = tf.concat(max_labels_all, axis=0)  # shape: [N, H, W]


In [None]:
max_labels_all.shape

In [None]:
import matplotlib.pyplot as plt

i = 10  # sample index

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(max_preds_all[i], cmap="Blues")
plt.title("Predicted Max Flood")
plt.colorbar()

plt.subplot(1, 2, 2)
plt.imshow(max_labels_all[i], cmap="Blues")
plt.title("Ground Truth Max Flood")
plt.colorbar()

plt.tight_layout()
plt.show()


In [None]:

import numpy as np
import tensorflow as tf
from google.cloud import storage
import tempfile

# Initialize GCS client
client = storage.Client()
bucket_name = "mloutputstest"
bucket = client.bucket(bucket_name)

# Loop over batches
for i, preds in enumerate(all_preds):  # each preds: [B, T, H, W]
    max_preds = tf.reduce_max(preds, axis=1).numpy()  # shape: [B, H, W]

    for j in range(max_preds.shape[0]):
        sample = max_preds[j]  # shape: [H, W]

        # Save to temporary .npy file
        with tempfile.NamedTemporaryFile(suffix=".npy") as tmp:
            np.save(tmp.name, sample)

            # Upload to GCS
            blob_name = f"predictions/max_flood_batch{i}_sample{j}.npy"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(tmp.name)

            print(f"Uploaded: gs://{bucket_name}/{blob_name}")


In [None]:
pip install rasterio

In [None]:
import tensorflow as tf
import numpy as np
import rasterio
from rasterio.transform import from_origin
from google.cloud import storage
import tempfile

# Define metadata — adjust as needed
pixel_size = 1  # in meters or units per pixel
top_left_x = 0  # e.g., UTM x or longitude
top_left_y = 0  # e.g., UTM y or latitude

transform = from_origin(top_left_x, top_left_y, pixel_size, pixel_size)
crs = "EPSG:4326"  # or your local UTM projection

# Setup GCS
client = storage.Client()
bucket_name = "mloutputstest"
bucket = client.bucket(bucket_name)

# Loop over prediction batches
for i, preds in enumerate(all_preds):  # shape: [B, T, H, W]
    max_preds = tf.reduce_max(preds, axis=1).numpy()  # shape: [B, H, W]

    for j in range(max_preds.shape[0]):
        sample = max_preds[j]

        with tempfile.NamedTemporaryFile(suffix=".tif") as tmp:
            with rasterio.open(
                tmp.name,
                "w",
                driver="GTiff",
                height=sample.shape[0],
                width=sample.shape[1],
                count=1,
                dtype=sample.dtype,
                crs=crs,
                transform=transform,
            ) as dst:
                dst.write(sample, 1)

            # Upload to GCS
            blob_name = f"predictionstiff/max_flood_batch{i}_sample{j}.tif"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(tmp.name)

            print(f"Uploaded GeoTIFF: gs://{bucket_name}/{blob_name}")


In [None]:
import tensorflow as tf
from usl_models.flood_ml.model import FloodModel, SpatialAttention
# Path to your saved model
model_path = "/home/se2890/climateiq-cnn-6/logs/htune_project_20250815-144148/model"
#loaded_model = tf.keras.models.load_model(model_path)
#loaded_model.summary()
# Load the model
model = tf.keras.models.load_model(model_path)

from usl_models.flood_ml.model import SpatialAttention
custom_objects = {'SpatialAttention': SpatialAttention}
loaded_model = tf.keras.models.load_model(
    model_path,
    custom_objects=custom_objects,
    compile=False
)
model.set_weights(loaded_model.get_weights())

# # # Test calling the model for n predictions
full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split= "train")
inputs, labels = next(iter(full_dataset))
predictions = model.call_n(inputs, n=10)
predictions.shape

In [None]:
ref_shapes = None

for i, (inputs, labels) in enumerate(full_dataset):
    current_shapes = (
        inputs["spatiotemporal"].shape,
        inputs["geospatial"].shape,
        inputs["temporal"].shape,
        labels.shape,
    )

    if ref_shapes is None:
        ref_shapes = current_shapes
    else:
        assert current_shapes == ref_shapes, f"Mismatch at batch {i}: {current_shapes} ≠ {ref_shapes}"


In [None]:
full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split="train")

all_preds = []
all_labels = []

for i, (inputs, labels) in enumerate(full_dataset):
    print(f"\n--- Batch {i} ---")
    
    st = inputs["spatiotemporal"]
    geo = inputs["geospatial"]
    temp = inputs["temporal"]

    print(f"spatiotemporal shape: {st.shape}")
    print(f"geospatial shape:     {geo.shape}")
    print(f"temporal shape:       {temp.shape}")
    print(f"labels shape:         {labels.shape}")

    try:
        preds = model.call_n(inputs, n=10)
        print(f"predictions shape:    {preds.shape}")
        all_preds.append(preds)
        all_labels.append(labels)
    except Exception as e:
        print(f"Error at batch {i}: {e}")
        break


In [None]:
BATCH_SIZE = 4
N_STEPS = 10

all_preds = []
all_labels = []

for i, (inputs, labels) in enumerate(full_dataset):
    current_bs = inputs["spatiotemporal"].shape[0]

    if current_bs < BATCH_SIZE:
        print(f"[Batch {i}] Incomplete batch of size {current_bs}, padding to {BATCH_SIZE}")

        # Repeat the last sample to pad
        repeats = BATCH_SIZE - current_bs

        def pad_tensor(t):
            return tf.concat([t, tf.repeat(t[-1:], repeats=repeats, axis=0)], axis=0)

        padded_inputs = {
            k: pad_tensor(v) for k, v in inputs.items()
        }

        # Predict on padded batch
        preds_padded = model.call_n(padded_inputs, n=N_STEPS)  # [B, T, H, W]

        # Remove the extra samples
        preds = preds_padded[:current_bs]
    else:
        preds = model.call_n(inputs, n=N_STEPS)

    print(f"[Batch {i}] Prediction shape: {preds.shape}")
    all_preds.append(preds)
    all_labels.append(labels)


In [None]:
# After running all batches
final_preds = tf.concat(all_preds, axis=0)  # shape: [N, T, H, W]
max_preds_all = tf.reduce_max(final_preds, axis=1)  # shape: [N, H, W]
max_preds_all.shape

In [None]:
max_labels_all = []

for labels in all_labels:
    max_labels = tf.reduce_max(labels, axis=1)  # shape: [B, H, W]
    max_labels_all.append(max_labels)

# Now stack (uses less memory)
max_labels_all = tf.concat(max_labels_all, axis=0)  # shape: [N, H, W]


In [None]:
max_labels_all.shape

In [None]:
import matplotlib.pyplot as plt

i = 10  # sample index

plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.imshow(max_preds_all[i], cmap="Blues")
plt.title("Predicted Max Flood")
plt.colorbar()

plt.subplot(1, 2, 2)
plt.imshow(max_labels_all[i], cmap="Blues")
plt.title("Ground Truth Max Flood")
plt.colorbar()

plt.tight_layout()
plt.show()


In [None]:

import numpy as np
import tensorflow as tf
from google.cloud import storage
import tempfile

# Initialize GCS client
client = storage.Client()
bucket_name = "mloutputstest"
bucket = client.bucket(bucket_name)

# Loop over batches
for i, preds in enumerate(all_preds):  # each preds: [B, T, H, W]
    max_preds = tf.reduce_max(preds, axis=1).numpy()  # shape: [B, H, W]

    for j in range(max_preds.shape[0]):
        sample = max_preds[j]  # shape: [H, W]

        # Save to temporary .npy file
        with tempfile.NamedTemporaryFile(suffix=".npy") as tmp:
            np.save(tmp.name, sample)

            # Upload to GCS
            blob_name = f"predictions/max_flood_batch{i}_sample{j}.npy"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(tmp.name)

            print(f"Uploaded: gs://{bucket_name}/{blob_name}")


In [None]:
pip install rasterio

In [None]:
import tensorflow as tf
import numpy as np
import rasterio
from rasterio.transform import from_origin
from google.cloud import storage
import tempfile

# Define metadata — adjust as needed
pixel_size = 1  # in meters or units per pixel
top_left_x = 0  # e.g., UTM x or longitude
top_left_y = 0  # e.g., UTM y or latitude

transform = from_origin(top_left_x, top_left_y, pixel_size, pixel_size)
crs = "EPSG:4326"  # or your local UTM projection

# Setup GCS
client = storage.Client()
bucket_name = "mloutputstest"
bucket = client.bucket(bucket_name)

# Loop over prediction batches
for i, preds in enumerate(all_preds):  # shape: [B, T, H, W]
    max_preds = tf.reduce_max(preds, axis=1).numpy()  # shape: [B, H, W]

    for j in range(max_preds.shape[0]):
        sample = max_preds[j]

        with tempfile.NamedTemporaryFile(suffix=".tif") as tmp:
            with rasterio.open(
                tmp.name,
                "w",
                driver="GTiff",
                height=sample.shape[0],
                width=sample.shape[1],
                count=1,
                dtype=sample.dtype,
                crs=crs,
                transform=transform,
            ) as dst:
                dst.write(sample, 1)

            # Upload to GCS
            blob_name = f"predictionstiff/max_flood_batch{i}_sample{j}.tif"
            blob = bucket.blob(blob_name)
            blob.upload_from_filename(tmp.name)

            print(f"Uploaded GeoTIFF: gs://{bucket_name}/{blob_name}")


In [None]:
import tensorflow as tf
from usl_models.flood_ml.model import FloodModel, SpatialAttention
# Path to your saved model
model_path = "/home/se2890/climateiq-cnn-5/logs/htune_project_20250804-182136/model"
loaded_model = tf.keras.models.load_model(model_path)
loaded_model.summary()
# Load the model
# model = tf.keras.models.load_model(model_path)
# model = FloodModel.from_checkpoint(model_path)

from usl_models.flood_ml.model import SpatialAttention
custom_objects = {'SpatialAttention': SpatialAttention}
loaded_model = tf.keras.models.load_model(
    model_path,
    custom_objects=custom_objects,
    compile=False
)
model.set_weights(loaded_model.get_weights())

# # Test calling the model for n predictions
full_dataset = load_dataset(sim_names=sim_names, batch_size=4, dataset_split= "train")
inputs, labels = next(iter(full_dataset))
predictions = model.call_n(inputs, n=4)
predictions.shape

In [None]:
loss_scale = best_hp.get("loss_scale")
print("Loss scale used during training:", loss_scale)

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants

# Path to trained model
# Known value used during training
loss_scale = 200.0

# Path to trained model
model_path = "/home/se2890/climateiq-cnn-5/logs/htune_project_20250801-155126/model"

# Create the loss function with the correct scale
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load model with custom loss function
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})
# Number of samples to visualize
n_samples = 20

# Loop through the dataset and predict
for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    print(f"\nSample {i+1} Prediction Stats:")
    print("  Min:", prediction.min())
    print("  Max:", prediction.max())
    print("  Mean:", prediction.mean())

    # Choose timestep to plot
    timestep = 3
    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    vmax_val = max(gt_t.max(), pred_t.max())

    # Plot Ground Truth and Prediction
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Path to trained model
# Known value used during training
loss_scale = 150.0

# Path to trained model
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

# Create the loss function with the correct scale
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load model with custom loss function
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})


# Assuming validation_dataset is already defined
# Example:
# from usl_models.flood_ml.dataset import load_dataset_windowed
# validation_dataset = load_dataset_windowed(...)

n_samples = 20
timestep = 2
metrics_list = []

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    gt_t = ground_truth[timestep]
    pred_t = prediction[timestep]
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Mask out NaNs
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)
    iou = np.logical_and(gt_flat > 0.1, pred_flat > 0.1).sum() / max(1, np.logical_or(gt_flat > 0.1, pred_flat > 0.1).sum())
    ssim_val = ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min())

    metrics_list.append({
        "Sample": i+1,
        "MAE": mae,
        "RMSE": rmse,
        "Bias": bias,
        "IoU > 0.1": iou,
        "SSIM": ssim_val
    })

    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    im1 = axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")
    plt.colorbar(im1, ax=axes[0], shrink=0.8)

    im2 = axes[1].imshow(pred_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("Prediction")
    axes[1].axis("off")
    plt.colorbar(im2, ax=axes[1], shrink=0.8)

    plt.tight_layout()
    plt.show()

# Convert to DataFrame
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())


In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml import customloss
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Parameters
loss_scale = 200.0
timestep = 3
n_samples = 20

# Paths to models
model_path_1 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/attention/model"
model_path_2 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250612-010926/model"

# Loss function
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# Load models
model_1 = tf.keras.models.load_model(model_path_1, custom_objects={"loss_fn": loss_fn})
model_2 = tf.keras.models.load_model(model_path_2, custom_objects={"loss_fn": loss_fn})

# Load validation dataset (ensure it's already prepared)
# Example:
# validation_dataset = load_dataset_windowed(...)

metrics_list = []

for i, (input_data, ground_truth) in enumerate(train_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()

    pred_1 = model_1(input_data).numpy().squeeze()
    pred_2 = model_2(input_data).numpy().squeeze()

    gt_t = ground_truth[timestep]
    pred_1_t = pred_1[timestep]
    pred_2_t = pred_2[timestep]
    vmax_val = np.nanpercentile([gt_t, pred_1_t, pred_2_t], 99.5)

    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_1_flat = pred_1_t[mask].flatten()
    pred_2_flat = pred_2_t[mask].flatten()

    # Compute metrics
    metrics_list.append({
        "Sample": i+1,
        "MAE_1": mean_absolute_error(gt_flat, pred_1_flat),
        "RMSE_1": np.sqrt(mean_squared_error(gt_flat, pred_1_flat)),
        "Bias_1": np.mean(pred_1_flat) - np.mean(gt_flat),
        "IoU_1": np.logical_and(gt_flat > 0.1, pred_1_flat > 0.1).sum() / max(1, np.logical_or(gt_flat > 0.1, pred_1_flat > 0.1).sum()),
        "SSIM_1": ssim(gt_t, pred_1_t, data_range=gt_t.max() - gt_t.min()),

        "MAE_2": mean_absolute_error(gt_flat, pred_2_flat),
        "RMSE_2": np.sqrt(mean_squared_error(gt_flat, pred_2_flat)),
        "Bias_2": np.mean(pred_2_flat) - np.mean(gt_flat),
        "IoU_2": np.logical_and(gt_flat > 0.1, pred_2_flat > 0.1).sum() / max(1, np.logical_or(gt_flat > 0.1, pred_2_flat > 0.1).sum()),
        "SSIM_2": ssim(gt_t, pred_2_t, data_range=gt_t.max() - gt_t.min()),
    })

    # Plotting
    fig, axes = plt.subplots(1, 3, figsize=(21, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    axes[0].imshow(gt_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[0].set_title("Ground Truth")
    axes[0].axis("off")

    axes[1].imshow(pred_1_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[1].set_title("attention")
    axes[1].axis("off")

    axes[2].imshow(pred_2_t, cmap="Blues", vmin=0, vmax=vmax_val)
    axes[2].set_title("without attention")
    axes[2].axis("off")

    plt.tight_layout()
    plt.show()

# Summary metrics
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
print(df.describe())
