# Flood Model Training Notebook

Train a flood ConvLSTM model using the `usl_models` library.

In [9]:
import tensorflow as tf
import keras_tuner
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model_params import FloodModelParams
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset
from usl_models.flood_ml import customloss, dataset

# Setup: raise the root logger threshold to WARNING and fix the random seed.
logging.getLogger().setLevel(logging.WARNING)
keras.utils.set_random_seed(812)

# Enable on-demand GPU memory growth for every visible GPU.
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

# Timestamp used below to name this run's log / tuner directory.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# City name -> name of that city's configuration folder.
city_config_mapping = {
    "Manhattan": "Manhattan_config",
    "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
}

# IDs of the rainfall files to train on (currently only file 5).
rainfall_files = [5]

# One simulation name per (city, rainfall file) combination.
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in city_config_mapping.items()
    for rain_id in rainfall_files
]

print(f"Training on {len(sim_names)} simulations.")
for name in sim_names:
    print(name)

# Window geometry shared by the training and validation datasets.
ds_config = dataset.Config(
    input_height=10,
    input_width=10,
    output_height=10,
    output_width=10,
)

# Arguments common to both splits; only `dataset_split` differs.
windowed_kwargs = dict(sim_names=sim_names, batch_size=4, ds_config=ds_config)

train_dataset = load_dataset_windowed(dataset_split="train", **windowed_kwargs).cache()

validation_dataset = load_dataset_windowed(dataset_split="val", **windowed_kwargs).cache()

Training on 2 simulations.
Manhattan-Manhattan_config/Rainfall_Data_5.txt
Atlanta-Atlanta_config/Rainfall_Data_5.txt


In [11]:
# The same path serves as the TensorBoard log directory and the tuner project name.
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Candidate values handed to FloodModel.get_hypermodel for each hyperparameter.
search_space = dict(
    lstm_units=[32, 64, 128],
    lstm_kernel_size=[3, 5],
    lstm_dropout=[0.2, 0.3],
    lstm_recurrent_dropout=[0.2, 0.3],
    n_flood_maps=[5],
    m_rainfall=[6],
    loss_scale=[50.0, 100.0, 200.0],
)

# Bayesian optimisation over the search space, minimising validation loss.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(**search_space),
    objective="val_loss",
    max_trials=1,
    project_name=log_dir,
)

tuner.search(
    train_dataset,
    validation_data=validation_dataset,
    epochs=2,
    callbacks=[tb_callback],
)

# Keep the top-ranked model and its hyperparameters; display the latter.
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]
best_hp.values

Trial 1 Complete [00h 01m 24s]
val_loss: 0.006333703175187111

Best val_loss So Far: 0.006333703175187111
Total elapsed time: 00h 01m 24s












































































































































































{'lstm_units': 64,
 'lstm_kernel_size': 3,
 'lstm_dropout': 0.3,
 'lstm_recurrent_dropout': 0.3,
 'n_flood_maps': 5,
 'm_rainfall': 6,
 'loss_scale': 200.0}

In [None]:
# Candidate values handed to FloodModel.get_hypermodel for each hyperparameter.
search_space = dict(
    lstm_units=[32, 64, 128],
    lstm_kernel_size=[3, 5],
    lstm_dropout=[0.2, 0.3],
    lstm_recurrent_dropout=[0.2, 0.3],
    n_flood_maps=[5],
    m_rainfall=[6],
    loss_scale=[50.0, 100.0, 200.0],
)

# Rebinds `tuner` to a new 10-trial Bayesian search under the same project name
# pattern as the single-trial cell.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(**search_space),
    objective="val_loss",
    max_trials=10,
    project_name=f"logs/htune_project_{timestamp}",
)

tuner.search_space_summary()

In [19]:
# Print the hyperparameter search space registered on the current `tuner`.
tuner.search_space_summary()

Search space summary
Default search space size: 7
lstm_units (Choice)
{'default': 32, 'conditions': [], 'values': [32, 64, 128], 'ordered': True}
lstm_kernel_size (Choice)
{'default': 3, 'conditions': [], 'values': [3, 5], 'ordered': True}
lstm_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
lstm_recurrent_dropout (Choice)
{'default': 0.2, 'conditions': [], 'values': [0.2, 0.3], 'ordered': True}
n_flood_maps (Choice)
{'default': 5, 'conditions': [], 'values': [5], 'ordered': True}
m_rainfall (Choice)
{'default': 6, 'conditions': [], 'values': [6], 'ordered': True}
loss_scale (Choice)
{'default': 50.0, 'conditions': [], 'values': [50.0, 100.0, 200.0], 'ordered': True}


In [None]:
# The same path is used for the TensorBoard event files written during the search.
log_dir = f"logs/htune_project_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Run the search with whichever `tuner` the kernel currently holds.
tuner.search(
    train_dataset,
    validation_data=validation_dataset,
    epochs=2,
    callbacks=[tb_callback],
)

# Keep the top-ranked model and its hyperparameters; display the latter.
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]
best_hp.values

In [17]:

def from_checkpoint(artifact_uri: str, **kwargs):
    """Rebuild a FloodModel from a saved Keras model and copy its weights over.

    The saved model is loaded only to read its config and weights; a fresh
    `FloodModel` is then constructed from the recovered parameters.

    Args:
        artifact_uri: Location passed to `keras.models.load_model`.
        **kwargs: Forwarded unchanged to the `FloodModel` constructor.

    Returns:
        The newly constructed `FloodModel`, after `set_weights` has been
        called on its inner `_model`.

    NOTE(review): the output recorded in this notebook shows the `set_weights`
    call raising ValueError (weight list of length 19 provided, 15 expected).
    Confirm what the extra saved entries are before relying on this helper;
    other cells use `FloodModel.from_checkpoint` instead.
    """
    saved = keras.models.load_model(artifact_uri)

    # Recover the architecture parameters from the saved model's config.
    restored_params = FloodModel.Params.from_config(saved.get_config())
    print(f"Loaded model params: {restored_params}")

    # Build a new model from those parameters, then transfer the saved weights.
    rebuilt = FloodModel(params=restored_params, **kwargs)
    rebuilt._model.set_weights(saved.get_weights())
    return rebuilt

In [12]:
from keras.callbacks import ModelCheckpoint, EarlyStopping

# Split the best trial's values: `loss_scale` goes to the FloodModel
# constructor, every remaining entry becomes a FloodModel.Params field.
final_params_dict = best_hp.values.copy()
loss_scale = final_params_dict.pop("loss_scale", 100.0)
final_params = FloodModel.Params(**final_params_dict)
model = FloodModel(params=final_params, loss_scale=loss_scale)

# TensorBoard events go to the same directory as the tuner run.
tensorboard_cb = keras.callbacks.TensorBoard(log_dir=log_dir)

# Save a checkpoint whenever validation loss reaches a new minimum.
checkpoint_cb = ModelCheckpoint(
    filepath=log_dir + "/checkpoint",
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_format="tf",
)

# Stop after 100 epochs without a lower val_loss and restore the best weights.
# With epochs=2 below, this patience cannot be reached.
early_stopping_cb = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=100,
    restore_best_weights=True,
)

callbacks = [tensorboard_cb, checkpoint_cb, early_stopping_cb]

# Train, then save the final model next to the checkpoints.
model.fit(train_dataset, validation_dataset, epochs=2, callbacks=callbacks)

model.save_model(log_dir + "/model")

Epoch 1/2


2025-06-26 18:25:10.881894: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:961] layout failed: INVALID_ARGUMENT: Size of values 0 does not match size of permutation 4 @ fanin shape inflood_conv_lstm_1/conv_lstm/conv_lstm2d_1/while/body/_1/flood_conv_lstm_1/conv_lstm/conv_lstm2d_1/while/dropout_7/SelectV2-2-TransposeNHWCToNCHW-LayoutOptimizer


     89/Unknown - 8s 25ms/step - loss: 0.0061 - mean_absolute_error: 0.0292 - root_mean_squared_error: 0.1137

2025-06-26 18:25:15.047877: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 15534628773044783567
2025-06-26 18:25:15.047929: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 13295164407447007945
2025-06-26 18:25:15.047939: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 12771502634948723015
2025-06-26 18:25:15.047947: I tensorflow/core/framework/local_rendezvous.cc:421] Local rendezvous recv item cancelled. Key hash: 8480965922338745377


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c1127d0>, 140317996690080), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c1127d0>, 140317996690080), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c371310>, 140317885506240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c371310>, 140317885506240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c342150>, 140317994654544), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c342150>, 140317994654544), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c322d10>, 140317994654880), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c322d10>, 140317994654880), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c1127d0>, 140317996690080), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c1127d0>, 140317996690080), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c371310>, 140317885506240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c371310>, 140317885506240), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c342150>, 140317994654544), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(7, 7, 2, 1), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c342150>, 140317994654544), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c322d10>, 140317994654880), {}).


INFO:tensorflow:Unsupported signature for serialization: ((TensorSpec(shape=(1,), dtype=tf.float32, name='gradient'), <tensorflow.python.framework.func_graph.UnknownArgument object at 0x7f9e7c322d10>, 140317994654880), {}).


INFO:tensorflow:Assets written to: logs/htune_project_20250626-182308/checkpoint/assets


INFO:tensorflow:Assets written to: logs/htune_project_20250626-182308/checkpoint/assets


Epoch 2/2
INFO:tensorflow:Assets written to: logs/htune_project_20250626-182308/model/assets


INFO:tensorflow:Assets written to: logs/htune_project_20250626-182308/model/assets


In [20]:
# Display the hyperparameter values of the best trial found by the tuner.
best_hp.values

{'lstm_units': 64,
 'lstm_kernel_size': 3,
 'lstm_dropout': 0.3,
 'lstm_recurrent_dropout': 0.3,
 'n_flood_maps': 5,
 'm_rainfall': 6,
 'loss_scale': 200.0}

In [18]:
# Rebinds `ds_config` to 5x5 windows (the training cell used 10x10) and builds
# a validation dataset with that geometry.
ds_config = dataset.Config(
    input_width=5, input_height=5, output_width=5, output_height=5
)
val_ds = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="val", ds_config=ds_config
).cache()
# Reload the model saved by the training cell via the notebook-local helper.
# NOTE(review): the recorded output shows this call failing with a ValueError
# raised by `set_weights` inside `from_checkpoint`.
modello = from_checkpoint(log_dir + "/model")
# Take the first validation batch and run one forward pass on it.
input_batch, label_batch = next(iter(val_ds))
pred_batch = modello.call(input_batch)





Loaded model params: FloodModel.Params(lstm_units=64, lstm_kernel_size=3, lstm_dropout=0.3, lstm_recurrent_dropout=0.3, m_rainfall=6, n_flood_maps=5, num_features=22, pad_mode='REFLECT', optimizer=<keras.src.optimizers.adam.Adam object at 0x7f9e4f73f490>)


ValueError: You called `set_weights(weights)` on layer "flood_conv_lstm_3" with a weight list of length 19, but the layer was expecting 15 weights. Provided weights: [array([[[[-0.0218656 ,  0.10016118,  0.12838762, ...

In [None]:
# Load an exported SavedModel directory with the low-level TensorFlow loader.
# NOTE(review): the run directory is hard-coded to a specific timestamp rather
# than derived from `log_dir` — confirm it exists before running.
saved_model_path = "logs/htune_project_20250625-184752/model_dynamic"
loaded_model = tf.saved_model.load(saved_model_path)


In [None]:
# Look up the default serving signature of the loaded SavedModel.
predict_fn = loaded_model.signatures["serving_default"]
# Random inputs for a single 64x64 example, keyed like the exported signature.
# Depends on `final_params` from the training cell. `inputs` is only built
# here; this cell does not call `predict_fn`.
inputs = {
    "geospatial": tf.random.uniform((1, 64, 64, constants.GEO_FEATURES)),  # shape can vary
    "spatiotemporal": tf.random.uniform((1, final_params.n_flood_maps, 64, 64, 1)),
    "temporal": tf.random.uniform((1, constants.MAX_RAINFALL_DURATION, final_params.m_rainfall)),
}


In [1]:
# Build a 5x5-window validation dataset and run `model` on its first batch.
# NOTE(review): the recorded output is `NameError: name 'dataset' is not
# defined` — this cell needs the import/setup cell to have run first, and
# `model` to be bound by an earlier cell.
ds_config = dataset.Config(
    input_width=5, input_height=5, output_width=5, output_height=5
)
val_ds = load_dataset_windowed(
    sim_names=sim_names, batch_size=2, dataset_split="val", ds_config=ds_config
).cache()

input_batch, label_batch = next(iter(val_ds))
pred_batch = model.call(input_batch)

NameError: name 'dataset' is not defined

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Scale passed to make_hybrid_loss when rebuilding the custom loss below.
loss_scale = 150.0

# Saved model to evaluate.
# NOTE(review): absolute, user-specific path — will not resolve on another machine.
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

# Register the hybrid loss under the name "loss_fn" so load_model can resolve it.
# NOTE(review): this rebinds `model`; other cells use `model` as a FloodModel
# wrapper (`model._model`, `model.call`).
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# `validation_dataset` must already be defined by the data-loading cell.
n_samples = 20
timestep = 2
metrics_list = []

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    # NOTE(review): this indexes the first axis left after squeeze(); confirm
    # that axis is time rather than batch for the dataset in use.
    gt_t, pred_t = ground_truth[timestep], prediction[timestep]
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Error metrics are computed only where the ground truth is not NaN.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)

    # IoU of the cells above 0.1; the denominator is floored at 1 so an empty
    # union yields 0 instead of dividing by zero.
    gt_above = gt_flat > 0.1
    pred_above = pred_flat > 0.1
    intersection = np.logical_and(gt_above, pred_above).sum()
    union = np.logical_or(gt_above, pred_above).sum()
    iou = intersection / max(1, union)

    # SSIM is computed on the unmasked 2-D slices.
    ssim_val = ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min())

    metrics_list.append(
        {
            "Sample": i + 1,
            "MAE": mae,
            "RMSE": rmse,
            "Bias": bias,
            "IoU > 0.1": iou,
            "SSIM": ssim_val,
        }
    )

    # Ground truth and prediction side by side on a shared colour scale.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = ((gt_t, "Ground Truth"), (pred_t, "Prediction"))
    for ax, (image, title) in zip(axes, panels):
        im = ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

# Summary statistics of the per-sample metrics.
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
summary = df.describe()
print(summary)


In [None]:
# The `_model` attribute of the FloodModel wrapper is what gets exported.
# NOTE(review): requires `model` to still be a FloodModel; some evaluation
# cells rebind `model` to the result of `tf.keras.models.load_model`.
flood_model = model._model

# Batch and spatial dimensions are declared as None in the signature; the
# feature / sequence dimensions are fixed from `constants` and `final_params`.
serving_input_spec = {
    "geospatial": tf.TensorSpec(
        shape=(None, None, None, constants.GEO_FEATURES),
        dtype=tf.float32,
    ),
    "spatiotemporal": tf.TensorSpec(
        shape=(None, final_params.n_flood_maps, None, None, 1),
        dtype=tf.float32,
    ),
    "temporal": tf.TensorSpec(
        shape=(None, constants.MAX_RAINFALL_DURATION, final_params.m_rainfall),
        dtype=tf.float32,
    ),
}


@tf.function(input_signature=[serving_input_spec])
def dynamic_call(inputs):
    """Call `flood_model` on a dict of input tensors with training disabled."""
    return flood_model(inputs, training=False)


# Export `flood_model` with `dynamic_call` as its default serving signature.
tf.saved_model.save(
    flood_model,
    export_dir=log_dir + "/model_dynamic",
    signatures={"serving_default": dynamic_call},
)


In [None]:
# Reload the SavedModel written by the export cell and fetch its serving signature.
modello = tf.saved_model.load(log_dir + "/model_dynamic")
predict_fn = modello.signatures["serving_default"]
# Run the signature on the first batch of `val_ds` (defined in an earlier cell).
input_batch, label_batch = next(iter(val_ds))
# NOTE(review): functions restored from `.signatures` generally expect tensors
# as keyword arguments — confirm that passing the dict positionally works here.
output = predict_fn(input_batch)


In [None]:
# Test calling the model on some data: one forward pass on the first training
# batch, then display the prediction's shape.
# NOTE(review): uses `model.call`, so `model` must still be the FloodModel wrapper.
inputs, labels_ = next(iter(train_dataset))
prediction = model.call(inputs)
prediction.shape

In [None]:
# # Test calling the model for n predictions
# full_dataset = load_dataset(sim_names=sim_names, batch_size=1)
# inputs, labels = next(iter(full_dataset))
# predictions = model.call_n(inputs, n=4)
# predictions.shape

In [None]:
# Rebind `model` to a FloodModel restored from a saved run.
# NOTE(review): absolute, user-specific path — will not resolve on another machine.
model = FloodModel.from_checkpoint("/home/elhajjas/climateiq-cnn-11/logs/htune_project_20250623-225804/model")

In [None]:
# Earlier loading approaches, kept commented out for reference:
# model_path = "/home/elhajjas/climateiq-cnn-11/logs/htune_project_20250623-223109/model"

# # Create the loss function with the correct scale
# loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

# # Load model with custom loss function
# model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# model = AtmoModel.from_checkpoint("logs/main_20250319-204950/model")


# Current approach: restore through the library's FloodModel.from_checkpoint.
# NOTE(review): absolute, user-specific path — will not resolve on another machine.
model = FloodModel.from_checkpoint("/home/elhajjas/climateiq-cnn-11/logs/htune_project_20250625-184752/model")


# Build a 5x5-window validation dataset (rebinds `ds_config` and `val_ds`).
ds_config = dataset.Config(
    input_width=5, input_height=5, output_width=5, output_height=5
)
val_ds = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="val", ds_config=ds_config
).cache()

# Run one forward pass on the first validation batch.
input_batch, label_batch = next(iter(val_ds))
pred_batch = model.call(input_batch)

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants

# Scale passed to make_hybrid_loss when rebuilding the custom loss below.
loss_scale = 200.0

# Saved model to visualise.
# NOTE(review): a directory named "model_dynamic" is what the export cell
# writes with tf.saved_model.save; confirm tf.keras.models.load_model can
# open it and that the result is callable as used below.
model_path = "logs/htune_project_20250625-184752/model_dynamic"

# Register the hybrid loss under the name "loss_fn" so load_model can resolve it.
# NOTE(review): this rebinds `model`.
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# Number of batches taken from `val_ds` (defined in an earlier cell).
n_samples = 20

for i, (input_data, ground_truth) in enumerate(val_ds.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    # Basic range statistics of the raw prediction.
    print(f"\nSample {i+1} Prediction Stats:")
    for label, stat in (
        ("  Min:", prediction.min),
        ("  Max:", prediction.max),
        ("  Mean:", prediction.mean),
    ):
        print(label, stat())

    # Slice to plot.
    # NOTE(review): this indexes the first axis left after squeeze(); confirm
    # that axis is time rather than batch for the dataset in use.
    timestep = 3
    gt_t, pred_t = ground_truth[timestep], prediction[timestep]
    vmax_val = max(gt_t.max(), pred_t.max())

    # Ground truth and prediction side by side on a shared colour scale.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = ((gt_t, "Ground Truth"), (pred_t, "Prediction"))
    for ax, (image, title) in zip(axes, panels):
        im = ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Scale passed to make_hybrid_loss when rebuilding the custom loss below.
loss_scale = 150.0

# Saved model to evaluate.
# NOTE(review): absolute, user-specific path — will not resolve on another machine.
model_path = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250611-205219/model"

# Register the hybrid loss under the name "loss_fn" so load_model can resolve it.
# NOTE(review): this rebinds `model`; other cells use `model` as a FloodModel
# wrapper (`model._model`, `model.call`).
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)
model = tf.keras.models.load_model(model_path, custom_objects={"loss_fn": loss_fn})

# `validation_dataset` must already be defined by the data-loading cell.
n_samples = 20
timestep = 2
metrics_list = []

for i, (input_data, ground_truth) in enumerate(validation_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()
    prediction = model(input_data).numpy().squeeze()

    # NOTE(review): this indexes the first axis left after squeeze(); confirm
    # that axis is time rather than batch for the dataset in use.
    gt_t, pred_t = ground_truth[timestep], prediction[timestep]
    vmax_val = np.nanpercentile([gt_t, pred_t], 99.5)

    # Error metrics are computed only where the ground truth is not NaN.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_flat = pred_t[mask].flatten()

    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)

    # IoU of the cells above 0.1; the denominator is floored at 1 so an empty
    # union yields 0 instead of dividing by zero.
    gt_above = gt_flat > 0.1
    pred_above = pred_flat > 0.1
    intersection = np.logical_and(gt_above, pred_above).sum()
    union = np.logical_or(gt_above, pred_above).sum()
    iou = intersection / max(1, union)

    # SSIM is computed on the unmasked 2-D slices.
    ssim_val = ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min())

    metrics_list.append(
        {
            "Sample": i + 1,
            "MAE": mae,
            "RMSE": rmse,
            "Bias": bias,
            "IoU > 0.1": iou,
            "SSIM": ssim_val,
        }
    )

    # Ground truth and prediction side by side on a shared colour scale.
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = ((gt_t, "Ground Truth"), (pred_t, "Prediction"))
    for ax, (image, title) in zip(axes, panels):
        im = ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")
        plt.colorbar(im, ax=ax, shrink=0.8)

    plt.tight_layout()
    plt.show()

# Summary statistics of the per-sample metrics.
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
summary = df.describe()
print(summary)


In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from usl_models.flood_ml.dataset import load_dataset_windowed
from usl_models.flood_ml import constants
from usl_models.flood_ml import customloss
from sklearn.metrics import mean_absolute_error, mean_squared_error
from skimage.metrics import structural_similarity as ssim
import pandas as pd

# Parameters
loss_scale = 200.0
timestep = 3
n_samples = 20

# The two saved models being compared; the plot titles below label the first
# "attention" and the second "without attention".
# NOTE(review): absolute, user-specific paths — will not resolve on another machine.
model_path_1 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/attention/model"
model_path_2 = "/home/elhajjas/climateiq-cnn-11/usl_models/notebooks/logs/htune_project_20250612-010926/model"

# Register the hybrid loss under the name "loss_fn" so load_model can resolve it.
loss_fn = customloss.make_hybrid_loss(scale=loss_scale)

model_1 = tf.keras.models.load_model(model_path_1, custom_objects={"loss_fn": loss_fn})
model_2 = tf.keras.models.load_model(model_path_2, custom_objects={"loss_fn": loss_fn})


def _score_prediction(gt_t, gt_flat, pred_t, pred_flat):
    """Return (MAE, RMSE, bias, IoU of cells above 0.1, SSIM) for one prediction.

    `gt_flat` / `pred_flat` are the NaN-masked values used for the error
    metrics and IoU; `gt_t` / `pred_t` are the unmasked 2-D slices used for SSIM.
    """
    mae = mean_absolute_error(gt_flat, pred_flat)
    rmse = np.sqrt(mean_squared_error(gt_flat, pred_flat))
    bias = np.mean(pred_flat) - np.mean(gt_flat)
    gt_above = gt_flat > 0.1
    pred_above = pred_flat > 0.1
    # Denominator floored at 1 so an empty union yields 0, not a division by zero.
    iou = np.logical_and(gt_above, pred_above).sum() / max(1, np.logical_or(gt_above, pred_above).sum())
    ssim_val = ssim(gt_t, pred_t, data_range=gt_t.max() - gt_t.min())
    return mae, rmse, bias, iou, ssim_val


metrics_list = []

# NOTE(review): the samples come from `train_dataset`, whereas the other
# evaluation cells iterate a validation dataset — confirm this is intended.
for i, (input_data, ground_truth) in enumerate(train_dataset.take(n_samples)):
    ground_truth = ground_truth.numpy().squeeze()

    pred_1 = model_1(input_data).numpy().squeeze()
    pred_2 = model_2(input_data).numpy().squeeze()

    # NOTE(review): this indexes the first axis left after squeeze(); confirm
    # that axis is time rather than batch for the dataset in use.
    gt_t = ground_truth[timestep]
    pred_1_t = pred_1[timestep]
    pred_2_t = pred_2[timestep]
    vmax_val = np.nanpercentile([gt_t, pred_1_t, pred_2_t], 99.5)

    # Error metrics are computed only where the ground truth is not NaN.
    mask = ~np.isnan(gt_t)
    gt_flat = gt_t[mask].flatten()
    pred_1_flat = pred_1_t[mask].flatten()
    pred_2_flat = pred_2_t[mask].flatten()

    # Compute metrics for each model in turn.
    mae_1, rmse_1, bias_1, iou_1, ssim_1 = _score_prediction(gt_t, gt_flat, pred_1_t, pred_1_flat)
    mae_2, rmse_2, bias_2, iou_2, ssim_2 = _score_prediction(gt_t, gt_flat, pred_2_t, pred_2_flat)

    metrics_list.append(
        {
            "Sample": i + 1,
            "MAE_1": mae_1,
            "RMSE_1": rmse_1,
            "Bias_1": bias_1,
            "IoU_1": iou_1,
            "SSIM_1": ssim_1,
            "MAE_2": mae_2,
            "RMSE_2": rmse_2,
            "Bias_2": bias_2,
            "IoU_2": iou_2,
            "SSIM_2": ssim_2,
        }
    )

    # Ground truth and both predictions on a shared colour scale.
    fig, axes = plt.subplots(1, 3, figsize=(21, 6))
    fig.suptitle(f"Sample {i+1} - Timestep {timestep}", fontsize=16)

    panels = (
        (gt_t, "Ground Truth"),
        (pred_1_t, "attention"),
        (pred_2_t, "without attention"),
    )
    for ax, (image, title) in zip(axes, panels):
        ax.imshow(image, cmap="Blues", vmin=0, vmax=vmax_val)
        ax.set_title(title)
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Summary statistics of the per-sample metrics.
df = pd.DataFrame(metrics_list)
print("\n=== Metrics Summary ===")
summary = df.describe()
print(summary)
