# Flood Risk Forecasting with LSTM
This notebook demonstrates how to fetch hydrological data, train an LSTM model to predict streamflow, and derive flood risk alerts.

In [None]:
%matplotlib inline
import os
import datetime as dt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from flood_risk.data_fetch import (USGSConfig, NOAAConfig, fetch_noaa_precipitation, fetch_usgs_streamflow, merge_hydro_meteorological, create_supervised_sequences)
from flood_risk.preprocess import TimeSeriesScaler, train_val_test_split, compute_risk_score
from flood_risk.model import LSTMForecaster
from flood_risk.training import TrainingConfig, create_dataloader, train_model, iter_predictions
from flood_risk.alerting import evaluate_risk
# Apply a notebook-wide plotting style before any figure is created.
# NOTE(review): the 'seaborn-v0_8' style name was introduced in matplotlib 3.6 —
# confirm the environment's matplotlib version provides it.
plt.style.use('seaborn-v0_8')


## Configuration
Adjust the site/station identifiers, the date range, the lookback window, and the forecast horizon for your region of interest.

In [None]:
# --- Data sources ------------------------------------------------------------
SITE_ID = '07289000'  # example: Mississippi River at Vicksburg
STATION_ID = 'GHCND:USW00012916'  # example: New Orleans International Airport
START_DATE = dt.date(2020, 1, 1)
END_DATE = dt.date(2020, 12, 31)

# --- Windowing and alerting --------------------------------------------------
LOOKBACK = 30  # passed as `lookback=` when building the supervised windows
HORIZON = 7  # passed as `horizon=` to the window builder and the forecaster
RISK_THRESHOLD = 200000  # cubic feet per second

# --- Credentials -------------------------------------------------------------
# Read from the environment so the token never lands in the notebook; this is
# None when the variable is unset.
NOAA_TOKEN = os.getenv('NOAA_TOKEN')

# --- Reproducibility ---------------------------------------------------------
# Seed every RNG the notebook relies on so that Restart & Run All yields the
# same synthetic fallback data (np.random.randn) and the same torch random
# stream on every run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)


## Data Ingestion
The cell below attempts to download live data. If any step fails (for example, because the APIs are unavailable), the error is printed and synthetic data is generated instead to illustrate the workflow.

In [None]:
# Try the live services first; any failure drops through to a synthetic series
# so the remaining cells can still be executed end to end.
try:
    usgs_cfg = USGSConfig(
        site=SITE_ID,
        start_date=START_DATE,
        end_date=END_DATE,
    )
    noaa_cfg = NOAAConfig(
        station=STATION_ID,
        start_date=START_DATE,
        end_date=END_DATE,
    )
    streamflow = fetch_usgs_streamflow(usgs_cfg)
    precip = fetch_noaa_precipitation(noaa_cfg, token=NOAA_TOKEN)
    data = merge_hydro_meteorological(streamflow, precip, freq='D')
    source_label = 'USGS/NOAA API'
except Exception as exc:
    # Deliberately broad: the fallback is meant to cover every failure mode,
    # and the reason is printed so it is never silent.
    print(f'Falling back to synthetic data because: {exc}')
    index = pd.date_range(start=START_DATE, end=END_DATE, freq='D')
    # Discharge: sinusoid around 150000 with amplitude 40000, spanning
    # 12*pi radians over the date range.
    discharge = 150000 + 40000 * np.sin(
        np.linspace(0, 12 * np.pi, len(index))
    )
    # Precipitation: standard-normal draws with negatives floored at 0,
    # scaled by 10.
    precip = 10 * np.maximum(0, np.random.randn(len(index)))
    data = pd.DataFrame(
        {'discharge_cfs': discharge, 'precip_mm': precip},
        index=index,
    )
    source_label = 'synthetic generator'

# Preview the first rows of whichever dataset was produced.
data.head()


## Feature Engineering
Convert the daily time series into rolling windows for LSTM training.

In [None]:
# Turn the frame into supervised (input window, target window) samples.
x_seq, y_seq = create_supervised_sequences(data, target_col='discharge_cfs', lookback=LOOKBACK, horizon=HORIZON)

# Stack the per-sample arrays into dense arrays; X is indexed as X.shape[2]
# later on, so it is expected to be at least 3-D.
X = np.stack(x_seq.values)
y = np.stack(y_seq.values)

# NOTE(review): the feature scaler and the target min/max below are fitted on
# ALL samples before the train/val/test split, so validation/test statistics
# leak into training. Consider fitting on the training portion only — confirm
# what TimeSeriesScaler and train_val_test_split support before changing this.
scaler = TimeSeriesScaler()
X_scaled = scaler.fit_transform(X)

# Min-max scale the targets. y_min / y_max are kept because the evaluation
# cell uses them to map predictions back to the original units.
y_min = y.min()
y_max = y.max()
y_range = y_max - y_min
if y_range == 0:
    # A constant target would make the division below 0/0 and fill the
    # targets with NaN; dividing by 1 maps every target to 0 instead, and the
    # inverse transform still recovers y_min.
    y_range = 1.0
y_scaled = (y - y_min) / y_range

(train, val, test) = train_val_test_split(X_scaled, y_scaled)

# Display the number of samples in each partition.
len(train[0]), len(val[0]), len(test[0])


## Model Training
Instantiate the LSTM forecaster and train with early stopping.

In [None]:
# The feature count is the size of the last axis of the stacked input windows.
input_size = X.shape[2]

model = LSTMForecaster(
    input_size=input_size,
    hidden_size=64,
    num_layers=2,
    horizon=HORIZON,
)

# Train on the GPU when one is visible, otherwise on the CPU.
config = TrainingConfig(
    epochs=30,
    learning_rate=1e-3,
    batch_size=32,
    patience=5,
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

# Only the training loader is shuffled; validation order is kept fixed.
train_loader = create_dataloader(
    train[0],
    train[1],
    batch_size=config.batch_size,
    shuffle=True,
)
val_loader = create_dataloader(
    val[0],
    val[1],
    batch_size=config.batch_size,
    shuffle=False,
)

history = train_model(model, train_loader, val_loader, config)

# Show the recorded training history.
history


### Loss Curves

In [None]:
# Plot training and validation loss on one set of axes.
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(history['train_loss'], label='Train')
ax.plot(history['val_loss'], label='Validation')
ax.set_title('Training History')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE Loss')
ax.legend()
plt.show()


## Evaluation
Generate predictions for the holdout set and compare with actual discharge.

In [None]:
# Run the model over the held-out split without shuffling.
test_loader = create_dataloader(
    test[0],
    test[1],
    batch_size=config.batch_size,
    shuffle=False,
)
device = torch.device(config.device)
pred_batches = [batch for batch in iter_predictions(model, test_loader, device)]
predictions_scaled = np.concatenate(pred_batches, axis=0)

# Undo the min-max scaling applied to the targets in the feature cell.
predictions = y_min + predictions_scaled * (y_max - y_min)
actuals = y_min + test[1] * (y_max - y_min)

# Column 0 of each array is plotted against the last len(...) dates of `data`.
# NOTE(review): if every sample's target window spans HORIZON steps, the first
# target of the final sample would fall HORIZON - 1 steps before the last date,
# so these labels may be shifted — verify against create_supervised_sequences
# and the ordering used by train_val_test_split.
pred_dates = data.index[-len(predictions):]
actual_dates = data.index[-len(actuals):]
pred_series = pd.Series(predictions[:, 0], index=pred_dates)
actual_series = pd.Series(actuals[:, 0], index=actual_dates)

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(actual_series, label='Observed')
ax.plot(pred_series, label='Forecasted')
ax.set_title(f'Streamflow forecast comparison ({source_label})')
ax.set_ylabel('Discharge (cfs)')
ax.legend()
plt.show()


## Flood Risk Assessment
Calculate risk scores relative to a discharge threshold and inspect exceedance days.

In [None]:
# Score the forecast series against the configured discharge threshold.
risk_series = compute_risk_score(pred_series, threshold=RISK_THRESHOLD)

fig, ax = plt.subplots(figsize=(10, 3))
ax.plot(risk_series, label='Risk score')
# Dashed reference line at 1.0, labelled as the alert threshold.
ax.axhline(1.0, color='red', linestyle='--', label='Alert threshold')
ax.set_ylabel('Risk Ratio')
ax.set_title('Flood risk ratio over forecast horizon')
ax.legend()
plt.show()

# Show the most recent risk values.
risk_series.tail()


## Alert Simulation
Use the alerting utilities to determine whether an AWS notification should be sent.

In [None]:
# NOTE(review): despite its name, this takes column 0 of `predictions` rather
# than a maximum across the forecast horizon — confirm that this is the input
# evaluate_risk is meant to receive.
peak_predictions = predictions[:, 0]
risk_score = evaluate_risk(peak_predictions, threshold=RISK_THRESHOLD)

# A score of 1.0 or more is reported as an alert.
alert_message = (
    f'ALERT: Risk score {risk_score:.2f} exceeds threshold!'
    if risk_score >= 1.0
    else f'No alert. Risk score {risk_score:.2f}.'
)
print(alert_message)

risk_score
