# FloodML Prediction Notebook

Clean, single-purpose notebook for running flood predictions.

**Sections:**
1. Configuration - model path, data settings
2. Environment Setup
3. Load Model
4. Download Data (Optional) - skip if data already cached
5. Load Data - supports both labeled data (validation mode, with ground truth) and prediction-only data (no labels)
6. Run Predictions
7. Visualize Results
8. Compute Metrics (if labels available)
9. Export Results (GeoTIFF, PNG, GCS upload)

---
## 1. Configuration

In [None]:
# ============================================================================
# CONFIGURATION - Edit these values
# ============================================================================
# The later cells read these names directly, so re-run this cell after editing
# a value and before re-running any cell that depends on it.

# --- Model ---
# Path handed to the "Load Model" cell. The X's are a placeholder and have to
# be replaced with a real training-run directory.
MODEL_PATH = "logs/flood_training_XXXXXXXX-XXXXXX/final_model.keras"  # <-- UPDATE THIS

# --- Data paths ---
# Local cache directory: download target and the location data is loaded from.
FILECACHE_DIR = "/home/shared/climateiq/filecache"

# --- Download settings ---
# Set to True ONLY if you need to download data from GCS.
# When False the download cell just prints a "skipping" message.
DOWNLOAD_DATA = False

# --- Prediction mode ---
# INCLUDE_LABELS = True  -> Validation mode (has ground truth to compare)
# INCLUDE_LABELS = False -> Pure prediction mode (no ground truth)
# This flag selects which of the two settings groups below is actually used.
INCLUDE_LABELS = True

# ===== FOR VALIDATION MODE (INCLUDE_LABELS=True) =====
# Simulation names are built as "<city>-<config>/Rainfall_Data_<id>.txt" from
# every (city, config) pair in CITY_CONFIG combined with every RAINFALL_IDS entry.
CITY_CONFIG = {
    "Manhattan": "Manhattan_config",
}
RAINFALL_IDS = [5]  # Rainfall scenario ids to predict on (one simulation per id per city)
DATASET_SPLIT = "val"  # "train", "val", or "test"; passed as the dataset split in validation mode

# ===== FOR PREDICTION MODE (INCLUDE_LABELS=False) =====
# Specify study area and rainfall scenario
STUDY_AREA = "Atlanta_Prediction"  # Study area name (for features)
RAINFALL_SIM = "Atlanta-Atlanta_config/Rainfall_Data_22.txt"  # Simulation whose rainfall file supplies the temporal features

# --- Prediction settings ---
N_STEPS = 13              # Number of timesteps to predict (passed to model.call_n as n)
BATCH_SIZE = 2            # Reduce if OOM errors; a short final batch is padded up to this size
MAX_CHUNKS = None         # None = all chunks, or set integer for quick tests

# --- Export settings ---
EXPORT_RESULTS = False    # Set True to export GeoTIFF/PNG
GCS_BUCKET = "mloutputstest"  # GCS bucket for uploads (if EXPORT_RESULTS=True)

# Model params (should match training)
# Used both when loading the dataset and when rebuilding the model in the
# fallback loading path.
N_FLOOD_MAPS = 5
M_RAINFALL = 6

# Echo the two settings that most change what the notebook does.
print(f"Mode: {'Validation (with labels)' if INCLUDE_LABELS else 'Prediction (no labels)'}")
print(f"Download data: {DOWNLOAD_DATA}")

---
## 2. Environment Setup

In [None]:
%load_ext autoreload
%autoreload 2

import os
import pathlib
import logging

# Deliberately placed between the stdlib imports and `import tensorflow`, so the
# variable is already set when TensorFlow is loaded. Level "2" is presumably
# meant to hide INFO and WARNING output from the native runtime - confirm
# against the TensorFlow version in use.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
# Root logger threshold: only WARNING and above are emitted by Python logging.
logging.getLogger().setLevel(logging.WARNING)

import tensorflow as tf
import keras
import numpy as np
import matplotlib.pyplot as plt

# Seed Keras' random number generators with a fixed value.
keras.utils.set_random_seed(42)

# Turn on memory growth for each GPU TensorFlow reports.
gpus = tf.config.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

print(f"TensorFlow version: {tf.__version__}")
print(f"GPUs available: {len(gpus)}")

In [None]:
from usl_models.flood_ml.model import FloodModel, SpatialAttention
from usl_models.flood_ml.dataset import load_dataset_cached, download_dataset
from usl_models.flood_ml import constants

In [None]:
# Work out which simulations this run uses and report whether each one is
# already present in the local file cache.
cache_root = pathlib.Path(FILECACHE_DIR)

if not INCLUDE_LABELS:
    # Prediction mode: a single study area plus the rainfall file of RAINFALL_SIM.
    sim_names = [STUDY_AREA]
    rainfall_dir = cache_root / STUDY_AREA / pathlib.Path(RAINFALL_SIM).name
    print("Prediction mode:")
    print(f"  Study area: {STUDY_AREA}")
    print(f"  Rainfall: {RAINFALL_SIM}")
    print(f"  Cache exists: {rainfall_dir.exists()}")
else:
    # Validation mode: one simulation per (city, rainfall id) combination.
    sim_names = [
        f"{city}-{config}/Rainfall_Data_{rain_id}.txt"
        for city, config in CITY_CONFIG.items()
        for rain_id in RAINFALL_IDS
    ]
    print(f"Validation mode - {len(sim_names)} simulations:")
    for sim_name in sim_names:
        status = 'CACHED' if (cache_root / sim_name).exists() else 'NOT CACHED'
        print(f"  [{status}] {sim_name}")

---
## 3. Load Model

In [None]:
# Load the trained model from MODEL_PATH, leaving a FloodModel in `model`.
# Two strategies are tried in order: the project's own checkpoint loader, then
# a plain Keras load whose weights are copied into a freshly built FloodModel.
print(f"Loading model from: {MODEL_PATH}")

# Custom objects needed for loading: lets Keras resolve the SpatialAttention
# layer by name in the fallback path below.
custom_objects = {'SpatialAttention': SpatialAttention}

try:
    # Try loading as FloodModel
    model = FloodModel.from_checkpoint(MODEL_PATH)
    print("Loaded via FloodModel.from_checkpoint()")
except Exception as e:
    # NOTE(review): any failure lands here, including a wrong MODEL_PATH; in
    # that case the fallback below will fail too and its error is the one that
    # surfaces. The first error is only printed, not re-raised.
    print(f"FloodModel.from_checkpoint failed: {e}")
    print("Trying direct keras load...")
    
    # Fallback: direct keras load (compile=False - no optimizer/loss needed
    # for inference)
    loaded_model = tf.keras.models.load_model(
        MODEL_PATH,
        custom_objects=custom_objects,
        compile=False
    )
    
    # Wrap in FloodModel: build a new model from the configured params and
    # copy the loaded weights across.
    # NOTE(review): assumes N_FLOOD_MAPS / M_RAINFALL reproduce the saved
    # architecture exactly; otherwise set_weights is expected to fail - confirm.
    params = FloodModel.Params(
        n_flood_maps=N_FLOOD_MAPS,
        m_rainfall=M_RAINFALL,
    )
    model = FloodModel(params=params)
    model._model.set_weights(loaded_model.get_weights())
    print("Loaded via keras and wrapped in FloodModel")

print("Model loaded successfully!")

In [None]:
# Show model summary.
# NOTE: this reaches into FloodModel's private `_model` attribute (presumably
# the wrapped Keras model - confirm); it may break if FloodModel's internals change.
model._model.summary()

---
## 4. Download Data (Optional)

**Skip this section if data is already cached.**

Set `DOWNLOAD_DATA = True` in Configuration to enable download from GCS.

In [None]:
# Optionally fetch the required data from GCS into the local file cache.
# Nothing is downloaded unless DOWNLOAD_DATA is True.

if not DOWNLOAD_DATA:
    print("Skipping download (DOWNLOAD_DATA=False)")
    print("Using existing cached data.")
else:
    banner = "=" * 60
    download_root = pathlib.Path(FILECACHE_DIR)
    mode_label = 'Validation (with labels)' if INCLUDE_LABELS else 'Prediction (features only)'

    print(banner)
    print("DOWNLOADING DATA FROM GCS")
    print(banner)
    print(f"Target directory: {FILECACHE_DIR}")
    print(f"Mode: {mode_label}")
    print()

    if not INCLUDE_LABELS:
        # Prediction mode: features for the study area only, with rainfall
        # taken from RAINFALL_SIM.
        print(f"Downloading study area: {STUDY_AREA}")
        print(f"Rainfall source: {RAINFALL_SIM}")
        download_dataset(
            sim_names=[STUDY_AREA],
            output_path=download_root,
            include_labels=False,
            rainfall_sim_name=RAINFALL_SIM,
            allow_missing_sim=True,
        )
    else:
        # Validation mode: the configured simulations, labels included, for
        # the selected split.
        print(f"Downloading {len(sim_names)} simulations with labels...")
        download_dataset(
            sim_names=sim_names,
            output_path=download_root,
            dataset_splits=[DATASET_SPLIT],
            include_labels=True,
        )

    print("\nDownload complete!")

---
## 5. Load Data

In [None]:
# Build `dataset` from the local file cache for whichever mode is active.
filecache_dir = pathlib.Path(FILECACHE_DIR)

# Arguments that are the same in both modes; shuffling is disabled in both.
shared_kwargs = dict(
    filecache_dir=filecache_dir,
    batch_size=BATCH_SIZE,
    n_flood_maps=N_FLOOD_MAPS,
    m_rainfall=M_RAINFALL,
    max_chunks=MAX_CHUNKS,
    shuffle=False,
)

if not INCLUDE_LABELS:
    # PREDICTION MODE - features only, rainfall taken from RAINFALL_SIM
    print("Loading data WITHOUT labels (prediction mode)")
    print(f"Study area: {STUDY_AREA}")
    print(f"Rainfall sim: {RAINFALL_SIM}")

    dataset = load_dataset_cached(
        sim_names=[STUDY_AREA],
        dataset_split=None,
        include_labels=False,
        rainfall_sim_name=RAINFALL_SIM,
        **shared_kwargs,
    )
else:
    # VALIDATION MODE - ground truth labels are loaded alongside the features
    print(f"Loading data WITH labels (split: {DATASET_SPLIT})")
    print(f"Simulations: {sim_names}")

    dataset = load_dataset_cached(
        sim_names=sim_names,
        dataset_split=DATASET_SPLIT,
        include_labels=True,
        **shared_kwargs,
    )

print("Dataset loaded.")

In [None]:
# Peek at the first batch: print the shape of each model input and of the
# labels, plus the chunk identifiers carried in the metadata.
for inputs, labels, metadata in dataset.take(1):
    print("Sample batch:")
    for feature_key in ("geospatial", "temporal", "spatiotemporal"):
        print(f"  {feature_key}: {inputs[feature_key].shape}")
    print(f"  labels: {labels.shape}")
    print(f"  chunk names: {metadata['feature_chunk'].numpy()}")

---
## 6. Run Predictions

In [None]:
# Run predictions on all chunks.
#
# For every batch the model is rolled forward N_STEPS timesteps, negative
# depths are clipped to zero, and the per-pixel maximum over the time axis is
# kept. Results accumulate in three parallel sequences that must stay index-
# aligned, because later cells pair them up by position:
#   all_predictions[i], all_labels[i] (validation mode), all_chunk_names[i]


def _pad_batch(tensor, n_missing):
    """Return `tensor` with its last example repeated `n_missing` times on axis 0."""
    filler = tf.repeat(tensor[-1:], repeats=n_missing, axis=0)
    return tf.concat([tensor, filler], axis=0)


all_predictions = []   # List of [H, W] max-depth arrays
all_labels = []        # List of [H, W] max-depth arrays (validation mode only)
all_chunk_names = []   # Chunk identifiers

print(f"Running predictions (n_steps={N_STEPS})...")

for batch_idx, (inputs, labels, metadata) in enumerate(dataset):
    current_bs = inputs["geospatial"].shape[0]
    chunk_names = [s.decode() if isinstance(s, bytes) else s
                   for s in metadata["feature_chunk"].numpy()]

    # Handle an incomplete final batch by padding it up to BATCH_SIZE with
    # copies of the last example, then discarding the padded rows.
    # (Presumably the model needs a fixed batch size - confirm.)
    if current_bs < BATCH_SIZE:
        n_missing = BATCH_SIZE - current_bs
        padded_inputs = {k: _pad_batch(v, n_missing) for k, v in inputs.items()}
        preds = model.call_n(padded_inputs, n=N_STEPS)[:current_bs]
    else:
        preds = model.call_n(inputs, n=N_STEPS)

    # Convert to numpy and compute max over time
    preds_np = np.clip(preds.numpy(), 0, None)  # Clip negatives
    max_pred = np.max(preds_np, axis=1)         # [B, H, W]

    # Store predictions
    all_predictions.extend(max_pred[:current_bs])
    all_chunk_names.extend(chunk_names[:current_bs])

    # Store labels for EVERY batch in validation mode. The previous check
    # (`labels.numpy().max() > 0`) dropped batches that were entirely dry or
    # contained NaN, which shifted all later labels onto the wrong chunks.
    # NaN pixels are kept here; the metrics cell masks them out.
    if INCLUDE_LABELS:
        max_label = np.max(labels.numpy(), axis=1)
        all_labels.extend(max_label[:current_bs])

    if (batch_idx + 1) % 5 == 0:
        print(f"  Processed {batch_idx + 1} batches ({len(all_predictions)} chunks)")

# Convert to arrays (all_labels becomes None when no labels were collected)
all_predictions = np.array(all_predictions)
all_labels = np.array(all_labels) if all_labels else None

print(f"\nPrediction complete!")
print(f"  Total chunks: {len(all_chunk_names)}")
print(f"  Predictions shape: {all_predictions.shape}")
print(f"  Labels available: {all_labels is not None}")

---
## 7. Visualize Results

In [None]:
# Show the predicted max-depth map for one chunk, next to its ground truth
# when a label exists for that index.
idx = 0  # Change to view different chunks

pred = all_predictions[idx]
has_label = all_labels is not None and idx < len(all_labels)

if not has_label:
    # Prediction only: single panel with its own colour scale.
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    im = ax.imshow(pred, cmap='Blues')
    ax.set_title(f'Predicted Max Depth\n(max={pred.max():.2f}m)')
    ax.axis('off')
    plt.colorbar(im, ax=ax)
else:
    label = all_labels[idx]
    # Shared colour scale so the two panels are directly comparable; the 0.1
    # floor keeps the scale non-degenerate when both maps are all zero.
    vmax = max(pred.max(), label.max(), 0.1)

    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    panels = (
        (axes[0], label, 'Ground Truth Max Depth'),
        (axes[1], pred, 'Predicted Max Depth'),
    )
    for panel_ax, depth_map, heading in panels:
        image = panel_ax.imshow(depth_map, cmap='Blues', vmin=0, vmax=vmax)
        panel_ax.set_title(f'{heading}\n(max={depth_map.max():.2f}m)')
        panel_ax.axis('off')
        plt.colorbar(image, ax=panel_ax, shrink=0.8)

plt.suptitle(f'Chunk: {all_chunk_names[idx]}', fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
# Plot up to five chunks, one per row. With labels there are two columns
# (ground truth | prediction); without labels a single prediction column.
n_show = min(5, len(all_predictions))
has_labels = all_labels is not None and len(all_labels) >= n_show

if n_show == 0:
    # plt.subplots rejects a zero-row grid, so bail out with a message instead.
    print("No predictions to plot.")
else:
    ncols = 2 if has_labels else 1
    # squeeze=False always returns a 2-D array of Axes, so axes[i, j] is valid
    # for every grid shape (including 1x1 and N x 1, which otherwise come back
    # as a bare Axes / 1-D array and break the indexing below).
    fig, axes = plt.subplots(
        n_show, ncols, figsize=(6*ncols, 5*n_show), squeeze=False
    )

    for i in range(n_show):
        pred = all_predictions[i]

        if has_labels:
            label = all_labels[i]
            # Shared colour scale per row; 0.1 floor avoids a degenerate range.
            vmax = max(pred.max(), label.max(), 0.1)

            axes[i, 0].imshow(label, cmap='cubehelix', vmin=0, vmax=vmax)
            axes[i, 0].set_title(f'GT: {all_chunk_names[i]}')
            axes[i, 0].axis('off')

            axes[i, 1].imshow(pred, cmap='cubehelix', vmin=0, vmax=vmax)
            axes[i, 1].set_title(f'Pred (max={pred.max():.2f}m)')
            axes[i, 1].axis('off')
        else:
            axes[i, 0].imshow(pred, cmap='cubehelix')
            axes[i, 0].set_title(f'{all_chunk_names[i]} (max={pred.max():.2f}m)')
            axes[i, 0].axis('off')

    plt.tight_layout()
    plt.show()

---
## 8. Compute Metrics (if labels available)

In [None]:
# Per-chunk and aggregate error metrics between predicted and ground-truth
# max-depth maps. Only runs when labels were collected.
if all_labels is not None:
    import pandas as pd

    thresh = 0.1  # Depth (m) above which a pixel counts as flooded for IoU

    if len(all_labels) != len(all_predictions):
        # zip() below stops at the shorter sequence; make the truncation visible.
        print(
            f"Warning: {len(all_predictions)} predictions but "
            f"{len(all_labels)} labels - only the common prefix is scored."
        )

    metrics_per_chunk = []

    for chunk_name, pred_map, label_map in zip(
        all_chunk_names, all_predictions, all_labels
    ):
        pred = np.asarray(pred_map, dtype=np.float64).ravel()
        label = np.asarray(label_map, dtype=np.float64).ravel()

        # Score only pixels where both maps hold a finite value.
        mask = np.isfinite(label) & np.isfinite(pred)
        if not mask.any():
            print(f"  Skipping {chunk_name}: no valid pixels")
            continue
        pred_m = pred[mask]
        label_m = label[mask]

        error = pred_m - label_m
        mae = float(np.abs(error).mean())
        rmse = float(np.sqrt(np.mean(error ** 2)))
        bias = float(pred_m.mean() - label_m.mean())

        # IoU for flooded pixels. When neither map has a flooded pixel the
        # ratio is undefined; record NaN so the chunk is left out of the
        # aggregate instead of counting as a complete miss (IoU = 0).
        flooded_true = label_m > thresh
        flooded_pred = pred_m > thresh
        intersection = np.logical_and(flooded_true, flooded_pred).sum()
        union = np.logical_or(flooded_true, flooded_pred).sum()
        iou = float(intersection / union) if union > 0 else float("nan")

        metrics_per_chunk.append({
            'chunk': chunk_name,
            'mae': mae,
            'rmse': rmse,
            'bias': bias,
            'iou': iou,
        })

    # Summary
    df = pd.DataFrame(metrics_per_chunk)

    if df.empty:
        print("No chunks with valid pixels - nothing to report.")
    else:
        print("Per-chunk metrics:")
        print(df.to_string(index=False))

        # pandas mean()/std() skip NaN by default, so undefined IoUs are ignored.
        print("\nAggregate metrics:")
        print(f"  MAE:  {df['mae'].mean():.4f} +/- {df['mae'].std():.4f} m")
        print(f"  RMSE: {df['rmse'].mean():.4f} +/- {df['rmse'].std():.4f} m")
        print(f"  Bias: {df['bias'].mean():.4f} +/- {df['bias'].std():.4f} m")
        print(f"  IoU:  {df['iou'].mean():.4f} +/- {df['iou'].std():.4f}")
else:
    print("No labels available - skipping metrics computation.")

---
## 9. Export Results (Optional)

Export predictions as GeoTIFF and PNG files, and optionally upload them to GCS.

In [None]:
# Write each prediction to prediction_output/ as a PNG and, when rasterio is
# importable, as a single-band GeoTIFF. Does nothing unless EXPORT_RESULTS is set.
if not EXPORT_RESULTS:
    print("Export disabled. Set EXPORT_RESULTS=True to enable.")
else:
    import tempfile

    # rasterio is optional: without it only the PNGs are written.
    try:
        import rasterio
        from rasterio.transform import from_origin
    except ImportError:
        print("Warning: rasterio not installed. GeoTIFF export disabled.")
        print("Install with: pip install rasterio")
        HAS_RASTERIO = False
    else:
        HAS_RASTERIO = True

    # Local output directory
    output_dir = pathlib.Path("prediction_output")
    output_dir.mkdir(exist_ok=True)

    # GeoTIFF settings (update these for your coordinate system).
    # NOTE: every chunk is written with this same origin, so the files are not
    # positioned relative to each other until real coordinates are supplied.
    pixel_size = 10  # metres per pixel
    top_left_x = 0
    top_left_y = 0
    crs = "EPSG:32618"  # UTM zone 18N (adjust for your area)

    total = len(all_predictions)
    print(f"Exporting {total} predictions...")

    for done, (pred, chunk_name) in enumerate(
        zip(all_predictions, all_chunk_names), start=1
    ):
        # Path separators in the chunk name would create sub-directories.
        safe_name = chunk_name.replace("/", "_").replace("\\", "_")

        # Save PNG
        plt.imsave(output_dir / f"{safe_name}.png", pred, cmap='Blues')

        # Save GeoTIFF
        if HAS_RASTERIO:
            geotiff_profile = dict(
                driver='GTiff',
                height=pred.shape[0],
                width=pred.shape[1],
                count=1,
                dtype=pred.dtype,
                crs=crs,
                transform=from_origin(top_left_x, top_left_y, pixel_size, pixel_size),
            )
            with rasterio.open(
                output_dir / f"{safe_name}.tif", 'w', **geotiff_profile
            ) as dst:
                dst.write(pred, 1)

        if done % 10 == 0:
            print(f"  Exported {done}/{total}")

    print(f"\nExported to: {output_dir.absolute()}")

In [None]:
# Optional: Upload to Google Cloud Storage
# Uploads every regular file in prediction_output/ to
# gs://<GCS_BUCKET>/<gcs_folder>/. Requires EXPORT_RESULTS as well, since that
# is what produces the files.
UPLOAD_TO_GCS = False  # Set True to upload

if EXPORT_RESULTS and UPLOAD_TO_GCS:
    from google.cloud import storage

    client = storage.Client()
    bucket = client.bucket(GCS_BUCKET)

    # Define GCS folder structure.
    # NOTE: in validation mode the folder is named after the FIRST simulation
    # only, even when several simulations were predicted.
    if INCLUDE_LABELS:
        gcs_folder = f"predictions/{sim_names[0].replace('/', '_')}"
    else:
        gcs_folder = f"predictions/{STUDY_AREA}/{pathlib.Path(RAINFALL_SIM).name}"

    print(f"Uploading to gs://{GCS_BUCKET}/{gcs_folder}/...")

    output_dir = pathlib.Path("prediction_output")
    # Regular files only: glob("*") also yields sub-directories, which cannot
    # be uploaded with upload_from_filename. Sorted for a stable upload order.
    files_to_upload = sorted(p for p in output_dir.glob("*") if p.is_file())

    if not files_to_upload:
        print(f"  No files found in {output_dir.absolute()} - nothing uploaded.")
    else:
        for f in files_to_upload:
            blob_path = f"{gcs_folder}/{f.name}"
            bucket.blob(blob_path).upload_from_filename(str(f))
            print(f"  Uploaded: {blob_path}")

        print(f"\nUpload complete!")
else:
    print("GCS upload disabled.")

---
## Summary

In [None]:
# Final recap of the run: model, mode, chunk count, depth range and, when
# labels exist, the mean absolute error over all valid pixels.
print("=" * 50)
print("PREDICTION SUMMARY")
print("=" * 50)
print(f"Model: {MODEL_PATH}")
print(f"Mode: {'Validation (with labels)' if INCLUDE_LABELS else 'Prediction only'}")
print(f"Chunks processed: {len(all_predictions)}")
print(f"Prediction steps: {N_STEPS}")

# min()/max() raise on an empty array, so guard the zero-chunk case.
if np.size(all_predictions) > 0:
    print(f"Max depth range: {all_predictions.min():.3f} - {all_predictions.max():.3f} m")
else:
    print("Max depth range: n/a (no predictions)")

if all_labels is not None:
    if np.shape(all_labels) != np.shape(all_predictions):
        # Element-wise subtraction would fail or broadcast incorrectly.
        print(
            f"Overall MAE: n/a (labels {np.shape(all_labels)} vs "
            f"predictions {np.shape(all_predictions)})"
        )
    else:
        abs_error = np.abs(all_predictions - all_labels)
        # Ignore NaN/inf pixels, matching the per-chunk metrics, so a single
        # missing label value does not turn the whole figure into NaN.
        valid = np.isfinite(abs_error)
        if valid.any():
            overall_mae = abs_error[valid].mean()
            print(f"Overall MAE: {overall_mae:.4f} m")
        else:
            print("Overall MAE: n/a (no valid pixels)")