# FloodML Training Notebook

Train the FloodConvLSTM model with support for:
- **Full maps** (1000x1000) - original approach
- **Patches** (e.g., 256x256) - sliding window with flood filtering

**Sections:**
1. Configuration
2. Environment Setup
3. Download Data (Optional)
4. Load Data
5. Hyperparameter Tuning & Model Building
6. Training
7. Quick Evaluation
8. Save Model

---
## 1. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Edit these values
# ============================================================================
# Every user-tunable setting lives in this cell; later cells only read these
# names. Re-run the notebook from the top after changing anything here.
import time

# --- Data paths ---
FILECACHE_DIR = "/home/shared/climateiq/filecache"

# --- Download settings ---
DOWNLOAD_DATA = False  # Set True to download from GCS

# --- Cities and rainfall scenarios ---
# Maps city name -> config name; simulation names are later built as
# "<city>-<config>/Rainfall_Data_<id>.txt".
CITY_CONFIG = {
    "Manhattan": "Manhattan_config",
    # "Atlanta": "Atlanta_config",
    # "Phoenix_SM": "PHX_SM",
}

RAINFALL_IDS = [5, 7, 9, 10, 11, 12, 13, 15, 16]
DATASET_SPLITS = ["train", "val", "test"]

# ==========================================================================
# PATCH MODE vs FULL MAP MODE
# ==========================================================================
USE_PATCHES = True

# Side length (pixels) of a full, un-patched map.
FULL_MAP_SIZE = 1000

# --- Patch settings (only used if USE_PATCHES=True) ---
PATCH_SIZE = 256
PATCH_STRIDE = 128
MIN_FLOOD_FRACTION = 0.01
MIN_MAX_DEPTH = 0.1
MAX_PATCHES_PER_CHUNK = 20

# ==========================================================================
# HYPERPARAMETER TUNING vs FIXED PARAMS
# ==========================================================================
# RUN_HYPERTUNING = True  -> Run Bayesian optimization to find best params
# RUN_HYPERTUNING = False -> Use fixed hyperparameters below
RUN_HYPERTUNING = True

# --- Hypertuning settings (only used if RUN_HYPERTUNING=True) ---
HTUNE_MAX_TRIALS = 5           # Number of hyperparameter combinations to try
HTUNE_EPOCHS = 10              # Epochs per trial
HTUNE_TRAIN_SAMPLES = 200      # Training samples per trial (for speed)
HTUNE_VAL_SAMPLES = 50         # Validation samples per trial

# --- Hyperparameter search space ---
HTUNE_LSTM_UNITS = [32, 64, 128]
HTUNE_LSTM_KERNEL_SIZE = [3, 5]
HTUNE_LSTM_DROPOUT = [0.2, 0.3]
HTUNE_LSTM_RECURRENT_DROPOUT = [0.2, 0.3]

# --- Fixed hyperparameters (used if RUN_HYPERTUNING=False) ---
# When hypertuning is on, the tuning cell overwrites these with the best trial.
LSTM_UNITS = 128
LSTM_KERNEL_SIZE = 3
LSTM_DROPOUT = 0.2
LSTM_RECURRENT_DROPOUT = 0.2

# --- Constants (don't change unless you know what you're doing) ---
N_FLOOD_MAPS = 5
M_RAINFALL = 6

# --- Training settings ---
BATCH_SIZE = 4 if USE_PATCHES else 2
EPOCHS = 50
EARLY_STOPPING_PATIENCE = 10
MAX_CHUNKS = None  # None = all, or integer for quick test

# --- Sanity checks: fail here rather than deep inside loading/training ---
if not CITY_CONFIG:
    raise ValueError("CITY_CONFIG must contain at least one city")
if not RAINFALL_IDS:
    raise ValueError("RAINFALL_IDS must contain at least one rainfall scenario")
if USE_PATCHES:
    if not 0 < PATCH_SIZE <= FULL_MAP_SIZE:
        raise ValueError(
            f"PATCH_SIZE must be in 1..{FULL_MAP_SIZE}, got {PATCH_SIZE}"
        )
    if PATCH_STRIDE <= 0:
        raise ValueError(f"PATCH_STRIDE must be positive, got {PATCH_STRIDE}")
    if not 0.0 <= MIN_FLOOD_FRACTION <= 1.0:
        raise ValueError(
            f"MIN_FLOOD_FRACTION must be within [0, 1], got {MIN_FLOOD_FRACTION}"
        )

# --- Output ---
# The timestamp makes every run write to its own log directory.
TIMESTAMP = time.strftime("%Y%m%d-%H%M%S")
mode_str = f"patches{PATCH_SIZE}" if USE_PATCHES else "fullmap"
htune_str = "htune_" if RUN_HYPERTUNING else ""
LOG_DIR = f"logs/flood_{htune_str}{mode_str}_{TIMESTAMP}"

# Derived: spatial dimensions for model
SPATIAL_SIZE = PATCH_SIZE if USE_PATCHES else FULL_MAP_SIZE

print(f"Mode: {'PATCHES' if USE_PATCHES else 'FULL MAPS'} ({SPATIAL_SIZE}x{SPATIAL_SIZE})")
print(f"Hypertuning: {'ON' if RUN_HYPERTUNING else 'OFF'}")
print(f"Log directory: {LOG_DIR}")

---
## 2. Environment Setup

In [None]:
%load_ext autoreload
%autoreload 2

import gc
import os
import pathlib
import logging

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
logging.getLogger().setLevel(logging.WARNING)

import tensorflow as tf
import keras
import keras_tuner
import numpy as np

SEED = 42
keras.utils.set_random_seed(SEED)

for gpu in tf.config.list_physical_devices("GPU"):
    tf.config.experimental.set_memory_growth(gpu, True)

print(f"TensorFlow: {tf.__version__}")
print(f"Keras Tuner: {keras_tuner.__version__}")
print(f"GPUs: {len(tf.config.list_physical_devices('GPU'))}")

In [None]:
from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml.dataset import (
    load_dataset_windowed_cached,
    load_dataset_windowed_patches,
    download_dataset,
)
from usl_models.flood_ml import constants

print(f"Map size: {constants.MAP_HEIGHT}x{constants.MAP_WIDTH}")

# The configuration cell hard-codes the full-map size, so cross-check it against
# the library's constants now instead of hitting a shape error during training.
if USE_PATCHES:
    if PATCH_SIZE > min(constants.MAP_HEIGHT, constants.MAP_WIDTH):
        raise ValueError(
            f"PATCH_SIZE={PATCH_SIZE} exceeds the map size "
            f"{constants.MAP_HEIGHT}x{constants.MAP_WIDTH}"
        )
elif not (constants.MAP_HEIGHT == constants.MAP_WIDTH == SPATIAL_SIZE):
    raise ValueError(
        f"SPATIAL_SIZE={SPATIAL_SIZE} does not match the map size "
        f"{constants.MAP_HEIGHT}x{constants.MAP_WIDTH}"
    )

In [None]:
# One simulation name per (city, rainfall scenario) combination, cities first.
sim_names = [
    f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
    for city, config in CITY_CONFIG.items()
    for rain_id in RAINFALL_IDS
]

# Report which simulations are already present under the local file cache.
print(f"Configured {len(sim_names)} simulations:")
cache_root = pathlib.Path(FILECACHE_DIR)
for name in sim_names:
    status = "OK" if (cache_root / name).exists() else "MISSING"
    print(f"  [{status}] {name}")

---
## 3. Download Data (Optional)

In [None]:
# Fetch the configured simulations into the file cache, if requested.
if not DOWNLOAD_DATA:
    print("Skipping download (DOWNLOAD_DATA=False)")
else:
    print("Downloading from GCS...")
    download_args = dict(
        sim_names=sim_names,
        output_path=pathlib.Path(FILECACHE_DIR),
        dataset_splits=DATASET_SPLITS,
        include_labels=True,
    )
    download_dataset(**download_args)
    print("Download complete!")

---
## 4. Load Data

In [None]:
# Stop early if any configured simulation is absent from the file cache.
cache_root = pathlib.Path(FILECACHE_DIR)
missing = []
for name in sim_names:
    if not (cache_root / name).exists():
        missing.append(name)

if len(missing) > 0:
    print(f"ERROR: {len(missing)} simulations not cached")
    print("\n".join(f"  - {name}" for name in missing))
    raise FileNotFoundError("Set DOWNLOAD_DATA=True")

print(f"All {len(sim_names)} simulations found.")

In [None]:
# Build the training and validation datasets for the selected mode.
filecache = pathlib.Path(FILECACHE_DIR)


def _load_split(loader, split, shuffle, **mode_args):
    """Call `loader` for one split with the arguments shared by both modes.

    `mode_args` carries the extra keyword arguments that only one of the
    loaders accepts (the patch settings).
    """
    return loader(
        filecache_dir=filecache,
        sim_names=sim_names,
        dataset_split=split,
        batch_size=BATCH_SIZE,
        n_flood_maps=N_FLOOD_MAPS,
        m_rainfall=M_RAINFALL,
        max_chunks=MAX_CHUNKS,
        shuffle=shuffle,
        **mode_args,
    )


if USE_PATCHES:
    print(f"Loading PATCH datasets ({PATCH_SIZE}x{PATCH_SIZE}, stride={PATCH_STRIDE})...")
    print(f"  Filters: min_flood_fraction={MIN_FLOOD_FRACTION}, min_max_depth={MIN_MAX_DEPTH}m")

    # Patch-only settings, identical for both splits.
    patch_args = dict(
        patch_size=PATCH_SIZE,
        stride=PATCH_STRIDE,
        max_patches_per_chunk=MAX_PATCHES_PER_CHUNK,
        min_flood_fraction=MIN_FLOOD_FRACTION,
        min_max_depth=MIN_MAX_DEPTH,
    )
    # Only the training split is shuffled.
    train_dataset = _load_split(load_dataset_windowed_patches, "train", True, **patch_args)
    val_dataset = _load_split(load_dataset_windowed_patches, "val", False, **patch_args)
else:
    print("Loading FULL MAP datasets (1000x1000)...")

    train_dataset = _load_split(load_dataset_windowed_cached, "train", True).prefetch(
        tf.data.AUTOTUNE
    )
    val_dataset = _load_split(load_dataset_windowed_cached, "val", False).prefetch(
        tf.data.AUTOTUNE
    )

print("Datasets loaded.")

In [None]:
# Diagnostic: count the samples in each split and summarise a few label patches.
# NOTE: this iterates over both datasets once in full, so it can be slow.
if USE_PATCHES:
    print("Analyzing patch selection...")

    max_sample_batches = 5   # label batches kept in memory for the summary
    max_patches_shown = 10   # per-patch lines printed

    # Count total windows (samples) in train dataset
    train_count = 0
    sample_labels = []
    for _, batch_labels in train_dataset:
        train_count += batch_labels.shape[0]
        if len(sample_labels) < max_sample_batches:
            sample_labels.append(batch_labels.numpy())

    val_count = sum(batch_labels.shape[0] for _, batch_labels in val_dataset)

    print("\n=== PATCH STATISTICS ===")
    print(f"Training samples (windows): {train_count}")
    print(f"Validation samples (windows): {val_count}")
    print(f"Total: {train_count + val_count}")

    if train_count == 0:
        print("WARNING: no training patches were produced - check the patch filters.")

    # Analyze sample patches
    if sample_labels:
        all_samples = np.concatenate(sample_labels, axis=0)
        n_shown = min(max_patches_shown, len(all_samples))
        print(f"\n=== SAMPLE PATCH VALUES (first {n_shown} of {len(all_samples)} sampled patches) ===")
        for i in range(n_shown):
            patch = all_samples[i]
            max_depth = patch.max()
            flooded_pct = (patch > 0).mean() * 100
            print(f"  Patch {i}: max_depth={max_depth:.3f}m, flooded={flooded_pct:.1f}%")

        # Avoid taking the mean of an empty selection (NaN plus a RuntimeWarning)
        # when none of the sampled pixels is flooded.
        flooded_pixels = all_samples[all_samples > 0]
        if flooded_pixels.size > 0:
            flooded_mean = f"{flooded_pixels.mean():.4f}m"
        else:
            flooded_mean = "n/a"
        print(f"\n  Overall: max={all_samples.max():.3f}m, "
              f"mean={flooded_mean} (flooded pixels only)")
else:
    print("Full map mode - no patch filtering")

In [None]:
# Inspect one training batch: tensor shapes plus two quick sanity checks.
for inputs, labels in train_dataset.take(1):
    print("Sample batch:")
    for key in ("geospatial", "temporal", "spatiotemporal"):
        heading = key + ":"
        print(f"  {heading:<15}{inputs[key].shape}")
    heading = "labels:"
    print(f"  {heading:<15}{labels.shape}")

    # The first two temporal columns of the first sample should not coincide
    # (only checked when the temporal tensor contains a positive value).
    temporal = inputs['temporal'].numpy()
    if temporal.max() > 0:
        first_col = temporal[0, :, 0]
        second_col = temporal[0, :, 1]
        distinct = not np.allclose(first_col, second_col)
        print(f"  Temporal cols distinct: {distinct}")

    # Label summary: maximum value and share of strictly positive pixels.
    label_values = labels.numpy()
    flooded_pct = (label_values > 0).mean() * 100
    print(f"  Label max: {label_values.max():.3f}m, flooded: {flooded_pct:.1f}%")

---
## 5. Hyperparameter Tuning & Model Building

In [None]:
# Builds fresh train/val datasets on demand (used by the tuner).
def get_datasets_for_tuning(batch_size):
    """Return a new ``(train_ds, val_ds)`` pair for hyperparameter tuning.

    The loader is selected by ``USE_PATCHES``; every setting other than
    ``batch_size`` comes from the notebook configuration. Only the training
    split is shuffled, and full-map datasets are additionally prefetched.
    """
    cache_root = pathlib.Path(FILECACHE_DIR)

    # Keyword arguments common to every loader call.
    shared_args = dict(
        filecache_dir=cache_root,
        sim_names=sim_names,
        batch_size=batch_size,
        n_flood_maps=N_FLOOD_MAPS,
        m_rainfall=M_RAINFALL,
        max_chunks=MAX_CHUNKS,
    )
    # (split name, shuffle flag) in the order the datasets are returned.
    split_plan = (("train", True), ("val", False))

    if not USE_PATCHES:
        train_ds, val_ds = [
            load_dataset_windowed_cached(
                dataset_split=split, shuffle=do_shuffle, **shared_args
            ).prefetch(tf.data.AUTOTUNE)
            for split, do_shuffle in split_plan
        ]
        return train_ds, val_ds

    # Patch-only settings, identical for both splits.
    patch_args = dict(
        patch_size=PATCH_SIZE,
        stride=PATCH_STRIDE,
        max_patches_per_chunk=MAX_PATCHES_PER_CHUNK,
        min_flood_fraction=MIN_FLOOD_FRACTION,
        min_max_depth=MIN_MAX_DEPTH,
    )
    train_ds, val_ds = [
        load_dataset_windowed_patches(
            dataset_split=split, shuffle=do_shuffle, **shared_args, **patch_args
        )
        for split, do_shuffle in split_plan
    ]
    return train_ds, val_ds

print("Dataset loader ready for tuning.")

In [None]:
# Run Bayesian Optimization hyperparameter search.
# Leaves `best_hp` as None when tuning is disabled; otherwise the LSTM_* values
# are overwritten with the best trial so the model-building cell picks them up.
best_hp = None

if RUN_HYPERTUNING:
    print("=" * 60)
    print("HYPERPARAMETER TUNING - Bayesian Optimization")
    print("=" * 60)

    os.makedirs(LOG_DIR, exist_ok=True)

    # Create tuner. n_flood_maps / m_rainfall are single-element lists, so the
    # search only varies the four LSTM settings.
    # NOTE(review): SPATIAL_SIZE is not passed to get_hypermodel here - confirm
    # the hypermodel builds for the patch size when USE_PATCHES=True.
    tuner = keras_tuner.BayesianOptimization(
        FloodModel.get_hypermodel(
            lstm_units=HTUNE_LSTM_UNITS,
            lstm_kernel_size=HTUNE_LSTM_KERNEL_SIZE,
            lstm_dropout=HTUNE_LSTM_DROPOUT,
            lstm_recurrent_dropout=HTUNE_LSTM_RECURRENT_DROPOUT,
            n_flood_maps=[N_FLOOD_MAPS],
            m_rainfall=[M_RAINFALL],
        ),
        objective="val_loss",
        max_trials=HTUNE_MAX_TRIALS,
        seed=SEED,            # make the sequence of sampled trials repeatable
        directory=LOG_DIR,    # same location as before: <LOG_DIR>/tuner
        project_name="tuner",
    )

    # Show search space
    print("\nSearch space:")
    tuner.search_space_summary()

    # TensorBoard callback for tuning (histograms and profiling disabled)
    tb_callback = keras.callbacks.TensorBoard(
        log_dir=f"{LOG_DIR}/tuner_tb",
        histogram_freq=0,
        profile_batch=0,
    )

    def run_tuner_search(batch_size=BATCH_SIZE):
        """Run the tuner on a truncated train/val dataset for speed.

        Args:
            batch_size: Batch size used for the tuning datasets; also converts
                the HTUNE_*_SAMPLES sample budgets into batch counts.
        """
        # Free memory left over from earlier cells before building trial models.
        gc.collect()
        # Use the same Keras namespace as the rest of the notebook.
        keras.backend.clear_session()

        train_ds, val_ds = get_datasets_for_tuning(batch_size)

        # Convert sample counts to batch counts (at least one batch each)
        num_train_batches = max(1, HTUNE_TRAIN_SAMPLES // batch_size)
        num_val_batches = max(1, HTUNE_VAL_SAMPLES // batch_size)

        print(f"\nUsing {num_train_batches} train batches ({num_train_batches * batch_size} samples)")
        print(f"Using {num_val_batches} val batches ({num_val_batches * batch_size} samples)")
        print(f"Epochs per trial: {HTUNE_EPOCHS}")
        print(f"Max trials: {HTUNE_MAX_TRIALS}")

        tuner.search(
            train_ds.take(num_train_batches),
            validation_data=val_ds.take(num_val_batches),
            epochs=HTUNE_EPOCHS,
            callbacks=[tb_callback],
            verbose=1,
        )

    # Run the search
    run_tuner_search()

    # Get best hyperparameters
    best_candidates = tuner.get_best_hyperparameters()
    if not best_candidates:
        raise RuntimeError(
            "Hyperparameter search produced no completed trials; "
            "check the tuner logs or set RUN_HYPERTUNING=False."
        )
    best_hp = best_candidates[0]
    print("\n" + "=" * 60)
    print("BEST HYPERPARAMETERS FOUND:")
    print("=" * 60)
    for k, v in best_hp.values.items():
        print(f"  {k}: {v}")

    # Overwrite the fixed values: the model-building cell reads these names.
    LSTM_UNITS = best_hp.get("lstm_units")
    LSTM_KERNEL_SIZE = best_hp.get("lstm_kernel_size")
    LSTM_DROPOUT = best_hp.get("lstm_dropout")
    LSTM_RECURRENT_DROPOUT = best_hp.get("lstm_recurrent_dropout")

else:
    print("Skipping hyperparameter tuning (RUN_HYPERTUNING=False)")
    print(f"Using fixed params: lstm_units={LSTM_UNITS}, kernel={LSTM_KERNEL_SIZE}, "
          f"dropout={LSTM_DROPOUT}, rec_dropout={LSTM_RECURRENT_DROPOUT}")

In [None]:
# Assemble the model parameters from the (fixed or tuned) LSTM settings.
param_values = dict(
    lstm_units=LSTM_UNITS,
    lstm_kernel_size=LSTM_KERNEL_SIZE,
    lstm_dropout=LSTM_DROPOUT,
    lstm_recurrent_dropout=LSTM_RECURRENT_DROPOUT,
    n_flood_maps=N_FLOOD_MAPS,
    m_rainfall=M_RAINFALL,
)
params = FloodModel.Params(**param_values)

# IMPORTANT: the spatial dims must match the data (patch size or full map).
model = FloodModel(params=params, spatial_dims=(SPATIAL_SIZE,) * 2)

print(f"Model spatial dims: {SPATIAL_SIZE}x{SPATIAL_SIZE}")
print("Parameters:")
for name, value in params.to_dict().items():
    # The optimizer entry is left out of the listing.
    if name == "optimizer":
        continue
    print(f"  {name}: {value}")

In [None]:
# Print the summary of the object held in the wrapper's private `_model`
# attribute (presumably the underlying Keras model - TODO confirm).
model._model.summary()

---
## 6. Training

In [None]:
# Make sure the run directory exists before any callback writes into it.
os.makedirs(LOG_DIR, exist_ok=True)

# Training metrics for TensorBoard.
tensorboard_cb = keras.callbacks.TensorBoard(log_dir=LOG_DIR, histogram_freq=0)

# Keep only the checkpoint with the lowest validation loss.
checkpoint_cb = keras.callbacks.ModelCheckpoint(
    filepath=f"{LOG_DIR}/best_model.keras",
    save_best_only=True,
    monitor="val_loss",
    mode="min",
    verbose=1,
)

# Stop once validation loss has not improved for EARLY_STOPPING_PATIENCE epochs.
early_stop_cb = keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
    verbose=1,
)

callbacks = [tensorboard_cb, checkpoint_cb, early_stop_cb]

print(f"Training for up to {EPOCHS} epochs...")
print(f"Logs: {LOG_DIR}")

In [None]:
# Train on the datasets loaded above, using the callbacks from the previous cell.
fit_args = {
    "train_dataset": train_dataset,
    "val_dataset": val_dataset,
    "epochs": EPOCHS,
    "callbacks": callbacks,
}
history = model.fit(**fit_args)

print("\nTraining complete!")

In [None]:
# Plot training history: loss on the left, MAE on the right (when tracked).
import matplotlib.pyplot as plt

train_losses = history.history['loss']
val_losses = history.history['val_loss']

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

axes[0].plot(train_losses, label='Train')
axes[0].plot(val_losses, label='Val')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

if 'mean_absolute_error' in history.history:
    axes[1].plot(history.history['mean_absolute_error'], label='Train')
    if 'val_mean_absolute_error' in history.history:
        axes[1].plot(history.history['val_mean_absolute_error'], label='Val')
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('MAE (m)')
    axes[1].set_title('MAE')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
else:
    # No MAE metric was recorded: hide the panel instead of showing an empty frame.
    axes[1].set_visible(False)

plt.tight_layout()
plt.savefig(f"{LOG_DIR}/history.png", dpi=150)
plt.show()

print(f"Final train loss: {train_losses[-1]:.6f}")
print(f"Final val loss:   {val_losses[-1]:.6f}")

# The callbacks checkpoint (and may restore) the best-val-loss weights, so the
# last-epoch numbers above are not necessarily those of the kept model.
best_epoch = min(range(len(val_losses)), key=val_losses.__getitem__)
print(f"Best val loss:    {val_losses[best_epoch]:.6f} (epoch {best_epoch + 1})")

---
## 7. Quick Evaluation

In [None]:
# Run the trained model on a single validation batch and compare with the labels.
for inputs, labels in val_dataset.take(1):
    # One forward pass; negative outputs are clamped to zero.
    raw_pred = model.call(inputs)
    pred_np = np.clip(raw_pred.numpy(), 0, None)
    label_np = labels.numpy()

    report_lines = [
        "Single-step prediction test:",
        f"  Pred shape: {pred_np.shape}",
        f"  Label shape: {label_np.shape}",
        f"  Pred range: [{pred_np.min():.3f}, {pred_np.max():.3f}]",
        f"  Label range: [{label_np.min():.3f}, {label_np.max():.3f}]",
    ]
    print("\n".join(report_lines))

    # Mean absolute error, after dropping size-1 axes from the prediction.
    abs_error = np.abs(pred_np.squeeze() - label_np)
    mae = abs_error.mean()
    print(f"  MAE: {mae:.5f} m")

In [None]:
# Visualize prediction vs ground truth for up to four samples of the batch
# evaluated in the previous cell (reads `pred_np` and `label_np` from it).
import matplotlib.pyplot as plt

n_show = min(4, pred_np.shape[0])
# squeeze=False keeps `axes` two-dimensional even when only one row is drawn.
fig, axes = plt.subplots(n_show, 2, figsize=(10, 4 * n_show), squeeze=False)

for i in range(n_show):
    gt = label_np[i]
    # Not named `pd`, which would shadow the conventional pandas alias.
    pred_img = pred_np[i].squeeze()
    # Shared colour scale per row, with a floor so the range is never degenerate.
    vmax = max(gt.max(), pred_img.max(), 0.1)

    axes[i, 0].imshow(gt, cmap='Blues', vmin=0, vmax=vmax)
    axes[i, 0].set_title(f'GT (max={gt.max():.2f}m)')
    axes[i, 0].axis('off')

    axes[i, 1].imshow(pred_img, cmap='Blues', vmin=0, vmax=vmax)
    axes[i, 1].set_title(f'Pred (max={pred_img.max():.2f}m)')
    axes[i, 1].axis('off')

plt.tight_layout()
plt.savefig(f"{LOG_DIR}/prediction_sample.png", dpi=150)
plt.show()

---
## 8. Save Model

In [None]:
# Save the trained model next to the best checkpoint written during training.
final_path = LOG_DIR + "/final_model.keras"
model.save_model(final_path)

saved_artifacts = (
    ("Model saved to", final_path),
    ("Best checkpoint", f"{LOG_DIR}/best_model.keras"),
)
for description, location in saved_artifacts:
    print(f"{description}: {location}")

In [None]:
# Save the run configuration and headline results next to the model.
import json

train_losses = history.history['loss']
val_losses = history.history['val_loss']

config = {
    "mode": "patches" if USE_PATCHES else "full_map",
    "spatial_size": SPATIAL_SIZE,
    "patch_size": PATCH_SIZE if USE_PATCHES else None,
    "patch_stride": PATCH_STRIDE if USE_PATCHES else None,
    "sim_names": sim_names,
    "hypertuning_used": RUN_HYPERTUNING,
    "best_hyperparameters": best_hp.values if best_hp else None,
    "params": params.to_dict(),
    "batch_size": BATCH_SIZE,
    "epochs_trained": len(train_losses),
    "final_train_loss": float(train_losses[-1]),
    "final_val_loss": float(val_losses[-1]),
    # Extra fields needed to reproduce the run.
    "best_val_loss": float(min(val_losses)),
    "seed": SEED,
    "rainfall_ids": RAINFALL_IDS,
    "max_chunks": MAX_CHUNKS,
    "max_patches_per_chunk": MAX_PATCHES_PER_CHUNK if USE_PATCHES else None,
    "min_flood_fraction": MIN_FLOOD_FRACTION if USE_PATCHES else None,
    "min_max_depth": MIN_MAX_DEPTH if USE_PATCHES else None,
}

# Serialize before opening the file so a failure cannot leave a truncated
# config.json. default=str is a fallback for values json cannot encode natively
# (e.g. a non-primitive entry in params.to_dict()).
config_json = json.dumps(config, indent=2, default=str)
with open(f"{LOG_DIR}/config.json", "w") as f:
    f.write(config_json)

print(f"Config saved to: {LOG_DIR}/config.json")
print("\nSummary:")
print(f"  Hypertuning: {'Yes' if RUN_HYPERTUNING else 'No'}")
print(f"  Best params: lstm_units={LSTM_UNITS}, kernel={LSTM_KERNEL_SIZE}")
print(f"  Final val loss: {config['final_val_loss']:.6f}")