# Flood Model Training Notebook

Train a flood ConvLSTM model using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import keras_tuner
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model_params import FloodModelParams
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset

# Only surface warnings and above from the root logger.
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)

# Seed Keras' random number generation with a fixed value.
keras.utils.set_random_seed(812)

# Run timestamp; used below to build per-run log and project directory names.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# Simulations to train on.
# NOTE(review): the original note says a constant should be changed to 8 to
# run Manhattan; which constant is meant is not visible here — TODO confirm.
sim_names = [
    "Manhattan-config_v1/Rainfall_Data_1.txt",
]

In [None]:
# Windowed training split for the selected simulations.
train_dataset = load_dataset_windowed(
    sim_names=sim_names,
    batch_size=4,
    dataset_split="train",
)

# Windowed validation split, loaded with the same batch size.
validation_data = load_dataset_windowed(
    sim_names=sim_names,
    batch_size=4,
    dataset_split="val",
)

In [None]:
# Hypermodel for the search. Each list is presumably the set of candidate
# values for that hyperparameter — confirm against FloodModel.get_hypermodel.
hypermodel = FloodModel.get_hypermodel(
    lstm_units=[32, 64, 128],
    lstm_kernel_size=[3, 5],
    lstm_dropout=[0.2, 0.3],
    lstm_recurrent_dropout=[0.2, 0.3],
    n_flood_maps=[5],
    m_rainfall=[6],
)

# Bayesian optimization over the hypermodel, scored on validation loss and
# limited to 10 trials.
tuner = keras_tuner.BayesianOptimization(
    hypermodel,
    objective="val_loss",
    max_trials=10,
    project_name=f"logs/htune_project_{timestamp}",
)

# Show the search space that will be explored.
tuner.search_space_summary()


In [None]:
# TensorBoard output for the tuning run goes to a timestamped directory.
log_dir = f"logs/htune_{timestamp}"
print(log_dir)
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)

# Run the hyperparameter search, validating each trial on the validation split.
tuner.search(
    train_dataset,
    epochs=100,
    validation_data=validation_data,
    callbacks=[tb_callback],
)

# Take the first entry of the tuner's best models and best hyperparameters.
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]

# Display the selected hyperparameter values as the cell output.
best_hp.values

In [None]:
# Build a fresh model from the hyperparameter values chosen by the tuner.
final_params = FloodModel.Params(**best_hp.values)
model = FloodModel(params=final_params)

# Train the model, logging to the same directory as the tuning run.
# NOTE(review): validation_data is passed positionally; presumably the second
# parameter of FloodModel.fit is the validation dataset — confirm.
tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
model.fit(
    train_dataset,
    validation_data,
    epochs=200,
    callbacks=[tb_callback],
)

# Save the trained model under the run's log directory.
model.save_model(f"{log_dir}/model")


In [None]:
# Smoke test: run the model once on the first batch of the training set and
# display the shape of the result.
first_batch = next(iter(train_dataset))
inputs, labels_ = first_batch
prediction = model.call(inputs)
prediction.shape

In [None]:
# Smoke test: request n=4 predictions from the model for the first element of
# the non-windowed dataset and display the shape of the result.
full_dataset = load_dataset(sim_names=sim_names, batch_size=1)
first_example = next(iter(full_dataset))
inputs, labels = first_example
predictions = model.call_n(inputs, n=4)
predictions.shape

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset

# Configure GPU: turn on memory growth for every GPU TensorFlow can see.
# With no GPUs the loop body never runs. A RuntimeError from TensorFlow is
# printed instead of aborting the notebook.
gpus = tf.config.list_physical_devices("GPU")
try:
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
    print(err)

# Only surface warnings and above from the root logger.
root_logger = logging.getLogger()
root_logger.setLevel(logging.WARNING)

# Seed Keras' random number generation with a fixed value.
keras.utils.set_random_seed(812)

# ===== DATA LOADING =====
def remove_elevation_features(input_dict, label):
    """Drop the first two geospatial channels (the elevation features).

    Args:
        input_dict: Mapping of model inputs. Must contain a 'geospatial'
            entry that supports ``[..., 2:]`` indexing; the last axis is
            treated as the feature/channel axis.
        label: Target value; passed through unchanged.

    Returns:
        An ``(inputs, label)`` tuple. ``inputs`` is a shallow copy of
        ``input_dict`` whose 'geospatial' entry keeps channels 2 onward
        (7 features when the input has 9 channels).
    """
    # Build a new dict instead of assigning into the argument, so the
    # caller's mapping is not modified as a side effect.
    trimmed = dict(input_dict)
    trimmed['geospatial'] = input_dict['geospatial'][..., 2:]
    return trimmed, label

# Run timestamp; used below to name the training log directory.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# Simulations to train on.
sim_names = [
    "Atlanta-Atlanta_config/Rainfall_Data_1.txt",
]

# Load datasets: each split is loaded windowed, then the elevation channels
# are dropped from its geospatial input.
train_dataset = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="train"
)
train_dataset = train_dataset.map(remove_elevation_features)

validation_data = load_dataset_windowed(
    sim_names=sim_names, batch_size=4, dataset_split="val"
)
validation_data = validation_data.map(remove_elevation_features)

# Override the library-level feature count for this session.
constants.GEO_FEATURES = 7  # Must match the number of features after removal

# ===== MODEL SETUP =====
# Fixed hyperparameters for this run. The feature count is read from
# constants.GEO_FEATURES.
optimizer = keras.optimizers.Adam(learning_rate=0.001)
standard_params = FloodModel.Params(
    num_features=constants.GEO_FEATURES,
    lstm_units=64,
    lstm_kernel_size=3,
    lstm_dropout=0.2,
    lstm_recurrent_dropout=0.2,
    n_flood_maps=5,
    m_rainfall=6,
    optimizer=optimizer,
)

# Build the model from those parameters.
model = FloodModel(params=standard_params)

# ===== TRAINING =====
log_dir = f"logs/training_{timestamp}"
print(f"Training with {constants.GEO_FEATURES} features in {log_dir}")

# Verify data loading: pull one batch and print the shape of each input so a
# mismatch is visible before training starts. Any failure is reported and
# then re-raised.
try:
    sample = next(iter(train_dataset))
    print("Sample input shapes:")
    # train_dataset has already had the elevation channels removed, so the
    # expected channel count is constants.GEO_FEATURES, not the raw count.
    print(
        f"Geospatial: {sample[0]['geospatial'].shape} "
        f"(should be (4, 1000, 1000, {constants.GEO_FEATURES}))"
    )
    print(f"Temporal: {sample[0]['temporal'].shape}")
    print(f"Spatiotemporal: {sample[0]['spatiotemporal'].shape}")
except Exception as e:
    print(f"Data loading error: {e}")
    raise

# Train using the underlying Keras model (model._model) rather than the
# FloodModel wrapper, logging to TensorBoard.
# NOTE(review): no validation data is passed to this fit call even though
# validation_data is loaded earlier — confirm this is intended.
tensorboard_cb = keras.callbacks.TensorBoard(log_dir)
history = model._model.fit(
    train_dataset,
    epochs=500,
    callbacks=[tensorboard_cb],
)

# ===== EVALUATION =====
# Save the trained model under the run's log directory.
model.save_model(f"{log_dir}/model")

# # Manual validation
# val_sample = next(iter(validation_data))
# val_pred = model.call(val_sample[0])
# val_loss = tf.keras.losses.MeanSquaredError()(val_sample[1], val_pred)
# print(f"Validation loss: {val_loss.numpy():.4f}")

# # Prediction test
# test_dataset = load_dataset(sim_names=sim_names, batch_size=1,dataset_split='test').map(remove_elevation_features)
# test_input, _ = next(iter(test_dataset))
# predictions = model.call_n(test_input, n=4)
# print("Autoregressive predictions shape:", predictions.shape)