# Flood Model Training Notebook

Tune hyperparameters for, and then train, a flood ConvLSTM model using the `usl_models` library.

In [None]:
%load_ext autoreload
%autoreload 2
import tensorflow as tf
import keras_tuner
import time
import keras
import logging
from usl_models.flood_ml import constants
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.model_params import FloodModelParams
from usl_models.flood_ml.dataset import load_dataset_windowed, load_dataset

# Only WARNING and above from the root logger.
logging.getLogger().setLevel(logging.WARNING)

# Fixed seed so repeated runs of the notebook are comparable.
RANDOM_SEED = 812
keras.utils.set_random_seed(RANDOM_SEED)

# Run identifier; used below to name the tuner project and log directories.
timestamp = time.strftime("%Y%m%d-%H%M%S")

# Simulations whose data is used for training/validation.
sim_names = [
    "Manhattan-config_v1/Rainfall_Data_1.txt",
]

In [None]:
# Load the "train" and "val" splits of the same simulations, batch size 4.
# The generator is consumed in order, so "train" is loaded before "val".
_BATCH_SIZE = 4
train_dataset, validation_data = (
    load_dataset_windowed(
        sim_names=sim_names,
        batch_size=_BATCH_SIZE,
        dataset_split=split,
    )
    for split in ("train", "val")
)

In [None]:
# Bayesian-optimization hyperparameter search over the FloodModel hypermodel.
# NOTE(review): each list passed to get_hypermodel presumably holds the
# candidate values for that hyperparameter (single-element lists effectively
# pin the value) — confirm in FloodModel.get_hypermodel.
tuner = keras_tuner.BayesianOptimization(
    FloodModel.get_hypermodel(
        lstm_units=[32, 64, 128],
        lstm_kernel_size=[3, 5],
        lstm_dropout=[0.2, 0.3],
        lstm_recurrent_dropout=[0.2, 0.3],
        n_flood_maps=[5],
        m_rainfall=[6],
    ),
    # The remaining arguments configure the tuner itself, not the hypermodel.
    objective="val_loss",
    max_trials=10,
    # KerasTuner stores trial state under <directory>/<project_name>. Keep the
    # path in `directory` instead of embedding it in the project name; the
    # resulting location (logs/htune_project_<timestamp>) is unchanged.
    directory="logs",
    project_name=f"htune_project_{timestamp}",
)

tuner.search_space_summary()


In [None]:
# Run the hyperparameter search, logging to TensorBoard under log_dir.
log_dir = f"logs/htune_{timestamp}"
print(log_dir)

tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
tuner.search(
    train_dataset,
    epochs=100,
    validation_data=validation_data,
    callbacks=[tb_callback],
)

# Best model first, then its hyperparameters (same order as before).
best_model = tuner.get_best_models()[0]
best_hp = tuner.get_best_hyperparameters()[0]
best_hp.values

In [None]:
# Build a fresh model from the best hyperparameter values and train it.
final_params = FloodModel.Params(**best_hp.values)
model = FloodModel(params=final_params)

tb_callback = keras.callbacks.TensorBoard(log_dir=log_dir)
model.fit(
    train_dataset,
    # NOTE(review): passed positionally — confirm this maps to the validation
    # argument of FloodModel.fit (for a plain keras.Model it would be `y`).
    validation_data,
    epochs=2000,
    callbacks=[tb_callback],
)
model.save_model(f"{log_dir}/model")


In [None]:
# Smoke test: run the first training batch through the model and show
# the shape of the output.
first_batch = next(iter(train_dataset))
inputs, labels_ = first_batch
prediction = model.call(inputs)
prediction.shape

In [None]:
# Smoke test for multi-step prediction: take the first example from
# load_dataset (batch size 1) and ask the model for N_PREDICTIONS outputs.
N_PREDICTIONS = 4
full_dataset = load_dataset(sim_names=sim_names, batch_size=1)
first_example = next(iter(full_dataset))
inputs, labels = first_example
predictions = model.call_n(inputs, n=N_PREDICTIONS)
predictions.shape