# FloodML Training Notebook

Clean, single-purpose notebook for training the FloodConvLSTM model.

**Sections:**
1. Configuration - all settings in one place
2. Environment Setup - imports, GPU, seeds
3. Download Data (Optional) - skip if data already cached
4. Load Data - from local cache
5. Model Building - create FloodModel with params
6. Training - fit with callbacks
7. Quick Evaluation - autoregressive prediction on one validation chunk vs. ground truth
8. Save Model - final model and training config

---
## 1. Configuration

Edit these values to customize your training run.

In [None]:
# ============================================================================
# CONFIGURATION - Edit these values
# ============================================================================
import time

# --- Data paths ---
# Local directory that holds (or will receive) the cached simulation data.
FILECACHE_DIR = "/home/shared/climateiq/filecache"

# --- Download settings ---
# True  -> fetch the configured simulations from GCS into FILECACHE_DIR.
# False -> skip the download and train from the existing cache.
DOWNLOAD_DATA = False

# --- Cities and rainfall scenarios ---
# Maps each city name to its config folder: {"CityName": "config_folder"}.
# Uncomment entries to include additional cities.
CITY_CONFIG = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
    # "Phoenix_PV": "PHX_PV",
    # "Phoenix_central": "PHX_CCC",
}

# Rainfall scenario IDs, combined with every configured city.
RAINFALL_IDS = [5, 7, 9, 10, 11, 12, 13, 15, 16]

# Dataset splits to download (when enabled) and use.
DATASET_SPLITS = ["train", "val", "test"]

# --- Model hyperparameters ---
LSTM_UNITS = 128              # ConvLSTM hidden units (64, 128, or 256)
LSTM_KERNEL_SIZE = 3          # ConvLSTM kernel size (3 or 5)
LSTM_DROPOUT = 0.2            # Dropout rate
LSTM_RECURRENT_DROPOUT = 0.2  # Recurrent dropout rate
N_FLOOD_MAPS = 5              # Timesteps of flood history used as input
M_RAINFALL = 6                # Number of temporal features

# --- Training settings ---
BATCH_SIZE = 4                # Lower this if you hit OOM errors
EPOCHS = 50                   # Upper bound on training epochs
EARLY_STOPPING_PATIENCE = 10  # Epochs without val_loss improvement before stopping
MAX_CHUNKS = None             # None = every chunk; an int caps chunks for quick tests

# --- Output ---
# Every run gets its own timestamped log directory.
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
LOG_DIR = "logs/flood_training_" + TIMESTAMP

print(f"Download data: {DOWNLOAD_DATA}")
print(f"Log directory: {LOG_DIR}")

---
## 2. Environment Setup

In [None]:
# Auto-reload modules during development
%load_ext autoreload
%autoreload 2

import os
import pathlib
import logging

# Suppress TF info messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger().setLevel(logging.WARNING)

import tensorflow as tf
import keras
import numpy as np

# Set random seed for reproducibility
SEED = 42
keras.utils.set_random_seed(SEED)

# Enable GPU memory growth (prevents OOM)
for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPUs available: {len(tf.config.list_physical_devices('GPU'))}")

In [None]:
# Project components: model wrapper, cached-dataset helpers, shared constants.
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import (
    load_dataset_windowed_cached,
    load_dataset_cached,
    download_dataset,
)
from usl_models.flood_ml import constants

# Echo the library defaults so they can be compared with the Configuration cell.
map_size = f"{constants.MAP_HEIGHT}x{constants.MAP_WIDTH}"
constants_summary = (
    f"Constants: MAP_SIZE={map_size}, "
    f"N_FLOOD_MAPS={constants.N_FLOOD_MAPS}, M_RAINFALL={constants.M_RAINFALL}"
)
print(constants_summary)

In [None]:
# One simulation per (city, rainfall scenario) pair, named
# "<city>-<config>/Rainfall_Data_<id>.txt"; used by both download and training.
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in CITY_CONFIG.items()
    for rain_id in RAINFALL_IDS
]

# Report which simulations already exist under the local cache directory.
cache_root = pathlib.Path(FILECACHE_DIR)
print(f"Configured {len(sim_names)} simulations:")
for sim in sim_names:
    label = "CACHED" if (cache_root / sim).exists() else "NOT CACHED"
    print(f"  [{label}] {sim}")

---
## 3. Download Data (Optional)

**Skip this section if data is already cached.**

Set `DOWNLOAD_DATA = True` in Configuration to enable download from GCS.

In [None]:
# Optionally pull the configured simulations from GCS into the local filecache.
# Controlled by the DOWNLOAD_DATA flag in the Configuration cell.
if not DOWNLOAD_DATA:
    print("Skipping download (DOWNLOAD_DATA=False)")
    print("Using existing cached data.")
else:
    banner = "=" * 60
    print(banner)
    print("DOWNLOADING DATA FROM GCS")
    print(banner)
    print(f"Target directory: {FILECACHE_DIR}")
    print(f"Simulations: {len(sim_names)}")
    print(f"Splits: {DATASET_SPLITS}")
    print()

    download_dataset(
        sim_names=sim_names,
        output_path=pathlib.Path(FILECACHE_DIR),
        dataset_splits=DATASET_SPLITS,
        include_labels=True,  # labels are needed for training
    )

    print("\nDownload complete!")

---
## 4. Load Data

Load cached datasets for training.

In [None]:
# Fail fast if any configured simulation directory is absent from the cache.
cache_root = pathlib.Path(FILECACHE_DIR)
missing = [sim for sim in sim_names if not (cache_root / sim).exists()]

if not missing:
    print(f"All {len(sim_names)} simulations found in cache.")
else:
    print("ERROR: The following simulations are not cached:")
    for sim in missing:
        print(f"  - {sim}")
    print("\nSet DOWNLOAD_DATA=True and re-run, or check FILECACHE_DIR path.")
    raise FileNotFoundError(f"{len(missing)} simulations not found in cache")

In [None]:
# Settings shared by the training and validation pipelines.
windowed_kwargs = dict(
    filecache_dir=pathlib.Path(FILECACHE_DIR),
    sim_names=sim_names,
    batch_size=BATCH_SIZE,
    n_flood_maps=N_FLOOD_MAPS,
    m_rainfall=M_RAINFALL,
    max_chunks=MAX_CHUNKS,
)

# Training split: windowed for teacher forcing, shuffled.
print("Loading training dataset...")
train_dataset = load_dataset_windowed_cached(
    dataset_split="train", shuffle=True, **windowed_kwargs
).prefetch(tf.data.AUTOTUNE)

# Validation split: same windowing, fixed order.
print("Loading validation dataset...")
val_dataset = load_dataset_windowed_cached(
    dataset_split="val", shuffle=False, **windowed_kwargs
).prefetch(tf.data.AUTOTUNE)

print("Datasets loaded.")

In [None]:
# Sanity-check one training batch: print input/label shapes and confirm the
# temporal feature columns are not copies of each other.
for inputs, labels in train_dataset.take(1):
    print("Sample batch shapes:")
    print(f"  geospatial:    {inputs['geospatial'].shape}")
    print(f"  temporal:      {inputs['temporal'].shape}")
    print(f"  spatiotemporal:{inputs['spatiotemporal'].shape}")
    print(f"  labels:        {labels.shape}")

    # Regression check for an earlier bug where temporal features were identical.
    # Compares feature columns 0 and 1 of the first sample in the batch
    # (assumes temporal is (batch, time, features) -- TODO confirm).
    t = inputs['temporal'].numpy()
    if t.shape[-1] < 2:
        # Nothing to compare when M_RAINFALL < 2; indexing column 1 would raise.
        print("  Temporal features distinct: n/a (fewer than 2 features)")
    elif t.max() > 0:
        cols_distinct = not np.allclose(t[0, :, 0], t[0, :, 1])
        print(f"  Temporal features distinct: {cols_distinct}")
    else:
        # An all-zero window makes every column equal, so the check is meaningless.
        print("  Temporal features distinct: skipped (all-zero window)")

---
## 5. Model Building

In [None]:
# Assemble hyperparameters from the Configuration cell and build the model.
params = FloodModel.Params(
    lstm_units=LSTM_UNITS,
    lstm_kernel_size=LSTM_KERNEL_SIZE,
    lstm_dropout=LSTM_DROPOUT,
    lstm_recurrent_dropout=LSTM_RECURRENT_DROPOUT,
    n_flood_maps=N_FLOOD_MAPS,
    m_rainfall=M_RAINFALL,
)

model = FloodModel(params=params)

# List every parameter except the "optimizer" entry.
print("Model parameters:")
shown_params = {
    name: value
    for name, value in params.to_dict().items()
    if name != "optimizer"
}
for name, value in shown_params.items():
    print(f"  {name}: {value}")

In [None]:
# Print the layer summary of the Keras model held inside the FloodModel wrapper
# (accessed via its private `_model` attribute).
keras_model = model._model
keras_model.summary()

---
## 6. Training

In [None]:
# Make sure the run directory exists before any callback writes into it.
os.makedirs(LOG_DIR, exist_ok=True)

best_checkpoint_path = f"{LOG_DIR}/best_model.keras"

# TensorBoard logging (no weight histograms, profiler disabled).
tensorboard_cb = keras.callbacks.TensorBoard(
    log_dir=LOG_DIR,
    histogram_freq=0,
    profile_batch=0,
)
# Keep only the checkpoint with the lowest validation loss.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=best_checkpoint_path,
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    verbose=1,
)
# Stop once val_loss stalls for EARLY_STOPPING_PATIENCE epochs.
early_stop_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
    verbose=1,
)
callbacks = [tensorboard_cb, checkpoint_cb, early_stop_cb]

print(f"Training for up to {EPOCHS} epochs...")
print(f"Logs will be saved to: {LOG_DIR}")

In [None]:
# Train with the validation set and the callbacks configured above.
fit_kwargs = dict(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
)
history = model.fit(**fit_kwargs)

print("\nTraining complete!")

In [None]:
# Plot training history and report loss at both the last and the best epoch.
import matplotlib.pyplot as plt

hist = history.history

# EarlyStopping(restore_best_weights=True) can leave the model at the best
# epoch's weights, so the last epoch alone does not describe the model in hand.
best_epoch = int(np.argmin(hist['val_loss']))

# The MAE metric key depends on how the metric was declared ("mae" vs.
# "mean_absolute_error"); accept either.
mae_key = next((k for k in ("mean_absolute_error", "mae") if k in hist), None)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Loss
axes[0].plot(hist['loss'], label='Train')
axes[0].plot(hist['val_loss'], label='Validation')
axes[0].axvline(best_epoch, color='gray', linestyle='--', alpha=0.6,
                label=f'Best (epoch {best_epoch + 1})')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training & Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# MAE
if mae_key is not None:
    axes[1].plot(hist[mae_key], label='Train')
    val_mae_key = f"val_{mae_key}"
    if val_mae_key in hist:
        axes[1].plot(hist[val_mae_key], label='Validation')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('MAE (m)')
    axes[1].set_title('Mean Absolute Error')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
else:
    axes[1].axis('off')  # no MAE metric was tracked; hide the empty panel

plt.tight_layout()
plt.savefig(f"{LOG_DIR}/training_history.png", dpi=150)
plt.show()

print("\nFinal metrics:")
print(f"  Train loss: {hist['loss'][-1]:.6f}")
print(f"  Val loss:   {hist['val_loss'][-1]:.6f}")
print(f"  Best val loss: {hist['val_loss'][best_epoch]:.6f} (epoch {best_epoch + 1})")

---
## 7. Quick Evaluation

Run a quick prediction to verify the model works.

In [None]:
# Load one raw (non-windowed) validation chunk from the first simulation to
# exercise the model's autoregressive call_n path.
# NOTE(review): _iter_model_inputs_cached is a private helper of the dataset
# module; its yield format (input, labels, chunk name) is taken from usage here.
from usl_models.flood_ml.dataset import _iter_model_inputs_cached

sim_dir = pathlib.Path(FILECACHE_DIR) / sim_names[0]
raw_gen = _iter_model_inputs_cached(
    sim_dir=sim_dir,
    dataset_split="val",
    n_flood_maps=N_FLOOD_MAPS,
    m_rainfall=M_RAINFALL,
    max_chunks=1,
    include_labels=True,
    shuffle=False,
)
try:
    raw_input, raw_labels, chunk_name = next(raw_gen)
except StopIteration:
    # A bare StopIteration in a notebook cell carries no message; say what is missing.
    raise RuntimeError(
        f"No 'val' chunks found for simulation {sim_names[0]!r} under {sim_dir}"
    ) from None
print(f"Testing on chunk: {chunk_name}")
print(f"Labels shape: {raw_labels.shape}")

In [None]:
# Predict autoregressively for as many steps as the ground-truth labels cover.
N_STEPS = raw_labels.shape[0]

# Add a leading batch axis of size 1 to each input feature.
input_keys = ("geospatial", "temporal", "spatiotemporal")
batch_input = FloodModel.Input(
    **{key: raw_input[key][tf.newaxis] for key in input_keys}
)

predictions = model.call_n(batch_input, n=N_STEPS)
print(f"Predictions shape: {predictions.shape}")

In [None]:
# Compare the autoregressive prediction with ground truth, per timestep and overall.
# Take the single batch element and clip negative predicted values to 0.
pred_np = np.clip(predictions.numpy()[0], 0, None)
label_np = raw_labels.numpy()

# Use one bound for both the table and the overall MAE so they cover the same
# timesteps (and never slice arrays of mismatched length).
n_compare = min(N_STEPS, label_np.shape[0], pred_np.shape[0])

print("Per-timestep comparison:")
print(f"{'t':>3} | {'GT max':>8} {'GT mean':>10} | {'Pred max':>8} {'Pred mean':>10} | {'MAE':>8}")
print("-" * 70)

for t in range(n_compare):
    gt = label_np[t]
    pred = pred_np[t]  # avoid `pd`, the conventional pandas alias
    mae = np.abs(gt - pred).mean()
    print(f"{t:>3} | {gt.max():>8.3f} {gt.mean():>10.6f} | {pred.max():>8.3f} {pred.mean():>10.6f} | {mae:>8.5f}")

overall_mae = np.abs(label_np[:n_compare] - pred_np[:n_compare]).mean()
print("-" * 70)
print(f"Overall MAE: {overall_mae:.5f} m")

In [None]:
# Show ground truth and prediction side by side for a few timesteps.
import matplotlib.pyplot as plt

# Only timesteps present in both arrays can be plotted.
n_available = min(N_STEPS, label_np.shape[0], pred_np.shape[0])

# Steps 0, 4, 8, 12 for long horizons; otherwise the first (up to) four steps.
timesteps_to_show = [0, 4, 8, 12] if n_available > 12 else list(range(min(4, n_available)))

if not timesteps_to_show:
    print("No timesteps available to visualize.")
else:
    n_rows = len(timesteps_to_show)
    # squeeze=False keeps `axes` 2-D even for a single row, so axes[i, col]
    # indexing also works when only one timestep is shown.
    fig, axes = plt.subplots(n_rows, 2, figsize=(10, 4 * n_rows), squeeze=False)

    for i, t in enumerate(timesteps_to_show):
        gt = label_np[t]
        pred = pred_np[t]
        # Shared color scale per row (floor of 0.1) so both panels are comparable.
        vmax = max(gt.max(), pred.max(), 0.1)

        axes[i, 0].imshow(gt, cmap='Blues', vmin=0, vmax=vmax)
        axes[i, 0].set_title(f't={t} Ground Truth (max={gt.max():.2f}m)')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(pred, cmap='Blues', vmin=0, vmax=vmax)
        axes[i, 1].set_title(f't={t} Prediction (max={pred.max():.2f}m)')
        axes[i, 1].axis('off')

    plt.tight_layout()
    plt.savefig(f"{LOG_DIR}/prediction_sample.png", dpi=150)
    plt.show()

---
## 8. Save Model

In [None]:
# Persist the trained model into the run directory next to the best checkpoint.
final_model_path = f"{LOG_DIR}/final_model.keras"
best_model_path = f"{LOG_DIR}/best_model.keras"

model.save_model(final_model_path)

print(f"\nModel saved to: {final_model_path}")
print(f"Best checkpoint: {best_model_path}")
print(f"\nTo load this model for prediction:")
print(f'  model = FloodModel.from_checkpoint("{final_model_path}")')

In [None]:
# Save a JSON record of this run's configuration and outcome for reference.
import json

hist = history.history
val_losses = hist['val_loss']
# EarlyStopping(restore_best_weights=True) can leave the model at the best
# epoch's weights, so record that epoch as well as the last one.
best_idx = int(np.argmin(val_losses))

config = {
    "sim_names": sim_names,
    "params": params.to_dict(),
    "batch_size": BATCH_SIZE,
    "epochs_trained": len(hist['loss']),
    "final_train_loss": float(hist['loss'][-1]),
    "final_val_loss": float(val_losses[-1]),
    "best_epoch": best_idx + 1,  # 1-based
    "best_val_loss": float(val_losses[best_idx]),
}

config_path = f"{LOG_DIR}/training_config.json"
# params.to_dict() may include values json cannot encode (the parameter-printing
# cell skips an "optimizer" entry); store those as their repr rather than
# failing halfway through writing the file.
with open(config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2, default=repr)

print(f"Config saved to: {config_path}")