# Flood Model Training Notebook

Train a flood ConvLSTM model using the `usl_models` library.

In [None]:
import tensorflow as tf
import time
import keras
import logging
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import mean_squared_error, mean_absolute_error

from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import load_dataset_windowed

# Enable memory growth on every visible GPU (a no-op when none are found).
gpus = tf.config.list_physical_devices("GPU")
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    # Report the problem and continue with TensorFlow's default behaviour.
    print(e)

# Quiet the root logger and fix the random seed for repeatable runs.
logging.getLogger().setLevel(logging.WARNING)
keras.utils.set_random_seed(812)

# Early stopping on the training loss. The patience value is far larger than
# the epoch count passed to fit() later in this notebook, so this callback is
# not expected to cut the run short.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="loss",
    min_delta=1e-5,
    patience=100000,
    restore_best_weights=True,
)


# ===== DATA LOADING =====
def remove_elevation_features(input_dict, label):
    """Trim the "geospatial" input to the first eight entries of its last axis.

    Intended for use with ``Dataset.map``. The caller's dictionary is left
    untouched; a shallow copy carrying the trimmed entry is returned instead.

    Args:
        input_dict: Mapping of model inputs. Must contain a "geospatial"
            entry that supports ``[..., :8]`` indexing.
        label: Target paired with the inputs; returned unchanged.

    Returns:
        A ``(inputs, label)`` tuple, where ``inputs`` equals ``input_dict``
        except that "geospatial" keeps only indices 0-7 of its last axis.
    """
    # NOTE(review): presumably the channels dropped from the end are the
    # elevation features the function name refers to - confirm against the
    # dataset's channel layout.
    trimmed = dict(input_dict)
    trimmed["geospatial"] = input_dict["geospatial"][..., :8]
    return trimmed, label


# Timestamp that later names this run's log directory.
timestamp = time.strftime("%Y%m%d-%H%M%S")
sim_names = ["Atlanta-Atlanta_config/Rainfall_Data_1.txt"]

# Full windowed training split (batch size 1) with the geospatial input trimmed.
train_dataset_full = load_dataset_windowed(
    sim_names=sim_names, batch_size=1, dataset_split="train"
).map(remove_elevation_features)

# Scan the first 20 samples and report those whose mean label exceeds 0.001,
# to help choose which samples to train on.
for i, (x, y) in enumerate(train_dataset_full.take(20)):
    mean_depth = tf.reduce_mean(y).numpy()
    if mean_depth > 0.001:
        print(f"Sample {i} has mean flood depth {mean_depth:.4f}")


# Positions (within the training split) of the manually chosen samples.
selected_indices = [7, 12, 19]

# One single-element dataset per chosen position.
selected_samples = [
    train_dataset_full.skip(idx).take(1) for idx in selected_indices
]

# Chain the single-element datasets together in the order listed above.
train_dataset = selected_samples[0]
for extra in selected_samples[1:]:
    train_dataset = train_dataset.concatenate(extra)

# Cache the chosen samples, then repeat them indefinitely.
train_dataset = train_dataset.cache().repeat()

# Validation: the first three elements of the "val" split, trimmed and cached.
val_dataset_full = load_dataset_windowed(
    sim_names=sim_names, batch_size=1, dataset_split="val"
)
validation_data = val_dataset_full.map(remove_elevation_features).take(3).cache()

# Override the library constant to match the eight geospatial channels kept by
# remove_elevation_features.
# NOTE(review): this mutates shared module state; which other usl_models code
# reads this constant is not visible from this notebook.
constants.GEO_FEATURES = 8

# ===== MODEL SETUP =====
# num_features follows the overridden constants.GEO_FEATURES so the model input
# matches the trimmed geospatial tensor produced by the data pipeline.
params = FloodModel.Params(
    num_features=constants.GEO_FEATURES,
    lstm_units=128,
    lstm_kernel_size=3,
    lstm_dropout=0,
    lstm_recurrent_dropout=0,
    n_flood_maps=5,
    m_rainfall=6,
    optimizer=keras.optimizers.Adam(learning_rate=0.001),
)
model = FloodModel(params=params)

# ===== TRAINING =====
log_dir = f"logs/training_{timestamp}"
# Report the real number of training samples (one per selected index).
print(f"Training with {len(selected_indices)} samples in {log_dir}")

# train_dataset repeats forever, so steps_per_epoch is what ends each epoch.
steps_per_epoch = 10
# validation_data is a finite, cached dataset of at most three batches, so
# validation_steps is omitted and Keras consumes it fully every epoch. The
# previous value of 10 asked for more batches than the dataset can supply.
# NOTE(review): this reaches into the private `_model` attribute; prefer a
# public FloodModel training entry point if one exists.
history = model._model.fit(
    train_dataset,
    epochs=800,
    steps_per_epoch=steps_per_epoch,
    validation_data=validation_data,
    callbacks=[keras.callbacks.TensorBoard(log_dir), early_stop],
)


# ===== SAVE MODEL =====
model.save_model(f"{log_dir}/model")

In [None]:
# import tensorflow as tf
# Path to your saved model
# model_path = "logs/Baseline-200epochs/model"

# Load the model
# model = tf.keras.models.load_model(model_path)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Take the first element of the repeated training subset, i.e. the first of
# the manually selected samples. Despite the "val_" prefix this is training data.
val_sample = next(iter(train_dataset))
val_input = val_sample[0]

# Drop all singleton axes so both arrays can be shown as images.
# NOTE(review): assumes squeeze() leaves a 2-D (H, W) map - confirm.
val_gt = val_sample[1].numpy().squeeze()
val_pred = model.call(val_input).numpy().squeeze()

print(
    "Prediction stats:",
    tf.reduce_min(val_pred).numpy(),
    tf.reduce_max(val_pred).numpy(),
    tf.reduce_mean(val_pred).numpy(),
)

# Shared upper colour limit so ground truth and prediction are comparable.
vmax_val = max(val_gt.max(), val_pred.max())

# (title, image, imshow keyword arguments) for each of the three panels.
panels = [
    (
        "Ground Truth (Flood Depth)",
        val_gt,
        {"cmap": "Blues", "vmin": 0, "vmax": vmax_val},
    ),
    (
        "Model Prediction",
        val_pred,
        {"cmap": "Blues", "vmin": 0, "vmax": vmax_val},
    ),
    (
        "Absolute Error",
        np.abs(val_gt - val_pred),
        {"cmap": "hot"},
    ),
]

plt.figure(figsize=(18, 6))
for position, (title, image, imshow_kwargs) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(image, **imshow_kwargs)
    plt.title(title)
    plt.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Element-wise error between the ground truth and prediction currently held
# in val_gt / val_pred, summarised as MAE and RMSE.
residual = val_gt - val_pred
print(f"MAE: {np.mean(np.abs(residual))}")
print(f"RMSE: {np.sqrt(np.mean(residual ** 2))}")

In [None]:
# Evaluate the model on the first three elements of the full training split
# and plot ground truth, prediction and absolute error for each.
for val_sample in train_dataset_full.take(3):
    val_input, val_target = val_sample
    val_gt = val_target.numpy()
    # Remove the trailing axis of the model output so it lines up with val_gt.
    val_pred = tf.squeeze(model.call(val_input), axis=-1).numpy()

    flat_gt = val_gt.flatten()
    flat_pred = val_pred.flatten()
    mae = mean_absolute_error(flat_gt, flat_pred)
    rmse = np.sqrt(np.mean((flat_gt - flat_pred) ** 2))
    print(f"MAE: {mae:.4f}, RMSE: {rmse:.4f}")

    # Only the first item of the batch is drawn.
    panels = (
        ("GT", val_gt[0], "Blues"),
        ("Prediction", val_pred[0], "Blues"),
        ("Error", np.abs(val_pred[0] - val_gt[0]), "hot"),
    )
    plt.figure(figsize=(15, 5))
    for slot, (title, image, cmap) in enumerate(panels, start=1):
        plt.subplot(1, 3, slot)
        plt.imshow(image, cmap=cmap)
        plt.title(title)
    plt.tight_layout()
    plt.show()

In [None]:
# Compare the average value of the arrays currently held in val_gt / val_pred.
mean_gt = val_gt.mean()
mean_pred = val_pred.mean()
print("Mean GT:", mean_gt, "Mean Pred:", mean_pred)

In [None]:
# Histogram (100 bins) of every value in val_pred, flattened to 1-D.
# Kept as the cell's final bare expression, as in the original cell.
plt.hist(val_pred.flatten(), bins=100)

In [None]:
# First element of the repeated training subset. The "val_" names are kept
# for continuity with the earlier cells, but this is training data.
val_sample = next(iter(train_dataset))

# === Unpack input / target and run the model ===
val_input = val_sample[0]
# Remove the trailing axis of the model output so it lines up with the target.
# NOTE(review): assumes target is (batch, H, W) and output is (batch, H, W, 1)
# - confirm against the dataset and model definitions.
val_pred = tf.squeeze(model.call(val_input), axis=-1).numpy()
val_gt = val_sample[1].numpy()

# === Per-sample metrics across the batch ===
threshold = 0.01  # values above this count as "flooded" for binary accuracy
batch_size = val_gt.shape[0]

mae_list = []
rmse_list = []
binary_acc_list = []

for gt, pred in zip(val_gt, val_pred):
    gt_flat = gt.flatten()
    pred_flat = pred.flatten()

    # Regression errors
    mae_list.append(mean_absolute_error(gt_flat, pred_flat))
    rmse_list.append(np.sqrt(np.mean((gt_flat - pred_flat) ** 2)))

    # Fraction of cells where thresholded truth and prediction agree
    agreement = (gt > threshold) == (pred > threshold)
    binary_acc_list.append(agreement.sum() / agreement.size)

# === Report aggregate metrics ===
print(f"Average MAE over batch: {np.mean(mae_list):.4f}")
print(f"Average RMSE over batch: {np.mean(rmse_list):.4f}")
print(f"Average Binary Accuracy: {np.mean(binary_acc_list):.4f}")

# Whole-batch value ranges, plus the distinct target values of the first item
print("GT max:", np.max(val_gt))
print("GT min:", np.min(val_gt))
print("Prediction max:", np.max(val_pred))
print("Prediction min:", np.min(val_pred))
print("GT mean:", np.mean(val_gt))
print("GT unique values (sample):", np.unique(val_gt[0]))

# === Visualization: at most 10 items from the batch ===
num_samples = min(10, batch_size)
for i in range(num_samples):
    gt, pred = val_gt[i], val_pred[i]

    print(f"\nSample {i}")
    print("GT max:", np.max(gt), "GT min:", np.min(gt))
    print("Prediction max:", np.max(pred), "Prediction min:", np.min(pred))

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    gt_ax, pred_ax, err_ax = axs
    # Fixed colour range (0 to 0.08) shared by the truth and prediction panels.
    gt_ax.imshow(gt, cmap="Blues", vmin=0, vmax=0.08)
    gt_ax.set_title(f"GT Flood Map #{i}")
    pred_ax.imshow(pred, cmap="Blues", vmin=0, vmax=0.08)
    pred_ax.set_title(f"Prediction #{i}")
    err_ax.imshow(np.abs(gt - pred), cmap="hot")
    err_ax.set_title("Error Map")
    plt.tight_layout()
    plt.show()