# Runoff Forecasting using LSTM and Transformer Models

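This notebook evaluates LSTM and Transformer models that correct National Water Model (NWM) runoff forecasts at two stations (21609641 and 20380357). Each model predicts the NWM forecast error at every lead time. Subtracting the predicted error from the raw forecast gives a corrected forecast, which is scored against USGS observations. Here the LSTM is evaluated at station 21609641 and the Transformer at station 20380357.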

## Import Required Libraries
Import NumPy, pandas, Matplotlib, and seaborn for data handling and visualization, TensorFlow/Keras for loading the trained models, and joblib for loading the fitted target scalers. The PyTorch import is left commented out as an option.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os
from joblib import load as joblib_load
# import torch # Uncomment if using PyTorch

# Evaluation metrics (CC, RMSE, PBIAS, NSE) are implemented as custom,
# NaN-aware functions in the "Evaluate Models" section.
# from sklearn.metrics import mean_squared_error, r2_score  # optional cross-check

print("TensorFlow Version:", tf.__version__)
# print("PyTorch Version:", torch.__version__) # Uncomment if using PyTorch

## Load and Explore Data
Load the processed test arrays for both stations, inspect their shapes, and plot observed against NWM runoff at the first lead time to look for patterns and anomalies.

In [None]:
# Define file paths (adjust as necessary)
PROCESSED_DATA_DIR = 'data/processed'
MODELS_DIR = 'models'
RESULTS_DIR = 'results'
PLOTS_DIR = os.path.join(RESULTS_DIR, 'plots')
METRICS_DIR = os.path.join(RESULTS_DIR, 'metrics')

# Create results directories if they don't exist
os.makedirs(PLOTS_DIR, exist_ok=True)
os.makedirs(METRICS_DIR, exist_ok=True)

# Station IDs evaluated in this notebook
station_id_s1 = '21609641'
station_id_s2 = '20380357'

# Load each station's processed test file and print its array shapes
for station_id in (station_id_s1, station_id_s2):
    test_path = os.path.join(PROCESSED_DATA_DIR, 'test', f'{station_id}.npz')
    try:
        # NpzFile supports the context-manager protocol, so the file is closed automatically
        with np.load(test_path) as test_data:
            print(f"\nLoaded test data keys for {station_id}: {list(test_data.keys())}")
            for key in ('X_test', 'y_test_scaled', 'nwm_test_original', 'usgs_test_original'):
                print(f"  {key} shape: {test_data[key].shape}")
    except FileNotFoundError as e:
        print(f"Error loading test data: {e}. Please ensure preprocessing was run.")

# --- (Optional) Exploratory data analysis on the loaded arrays ---
# Plot observed vs. NWM runoff at the first lead time for each station.
for station_id in (station_id_s1, station_id_s2):
    test_path = os.path.join(PROCESSED_DATA_DIR, 'test', f'{station_id}.npz')
    try:
        with np.load(test_path) as test_data:
            usgs_lead1 = test_data['usgs_test_original'][:, 0]  # Lead time 1
            nwm_lead1 = test_data['nwm_test_original'][:, 0]  # Lead time 1
    except FileNotFoundError:
        print(f"Skipping EDA plot for station {station_id}: test data not found.")
        continue

    # The arrays carry no timestamps, so plot against the sample index
    time_index = np.arange(len(usgs_lead1))
    plt.figure(figsize=(12, 6))
    plt.plot(time_index, usgs_lead1, label='USGS Observation (Lead 1)')
    plt.plot(time_index, nwm_lead1, label='NWM Forecast (Lead 1)', alpha=0.7)
    plt.title(f'Station {station_id} - Observed vs. NWM Forecast (Lead 1, Test Set)')
    plt.xlabel('Time Step Index')
    plt.ylabel('Runoff (cms)')
    plt.legend()
    plt.show()
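Beyond the line plots, a per-lead-time summary of the raw NWM error makes systematic bias and missing data easy to spot. The sketch below assumes the `(n_samples, n_lead_times)` layout of `usgs_test_original` and `nwm_test_original` printed above and defines the error as NWM − USGS; the helper name `summarize_nwm_errors` is hypothetical.

```python
import numpy as np
import pandas as pd


def summarize_nwm_errors(usgs, nwm):
    """Per-lead-time statistics of the raw NWM error (NWM - USGS).

    usgs, nwm: arrays of shape (n_samples, n_lead_times) in the same units.
    NaN pairs are ignored in the statistics and counted in 'n_missing'.
    """
    usgs = np.asarray(usgs, dtype=float)
    nwm = np.asarray(nwm, dtype=float)
    error = nwm - usgs  # positive = NWM overestimates
    return pd.DataFrame({
        'lead_time': np.arange(1, error.shape[1] + 1),
        'mean_error': np.nanmean(error, axis=0),
        'std_error': np.nanstd(error, axis=0),
        'median_abs_error': np.nanmedian(np.abs(error), axis=0),
        'n_missing': np.isnan(error).sum(axis=0),
    })


# Synthetic example: 4 samples x 2 lead times, one missing observation
usgs = np.array([[1.0, 2.0], [2.0, np.nan], [3.0, 4.0], [4.0, 5.0]])
nwm = np.array([[2.0, 2.0], [3.0, 3.0], [4.0, 5.0], [5.0, 6.0]])
summary = summarize_nwm_errors(usgs, nwm)
print(summary)
```

In the notebook it could be applied to a station's `usgs_test_original` and `nwm_test_original` arrays. A mean error far from zero at every lead time points to a bias that an error-correction model should be able to learn.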

## Preprocess Data
Preprocessing (cleaning, aligning NWM forecasts with USGS observations, creating input-output sequences, and splitting the data into training and test sets) is implemented in `src/preprocess.py`. This notebook only loads its outputs.

In [None]:
# --- Preprocessing is implemented in src/preprocess.py ---
# This cell is kept for context; the evaluation step loads the resulting .npz files.
print("Data preprocessing is assumed to be completed by 'src/preprocess.py'.")
print("Loading preprocessed data from .npz files in the evaluation step.")
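Because `src/preprocess.py` is not shown here, its exact windowing is unknown. As a hedged sketch of what "creating input-output sequences" and a time-ordered split typically look like for inputs shaped `(n_samples, seq_len, n_features)` and targets shaped `(n_samples, n_lead_times)`: the function names `make_sequences` and `chronological_split`, the target alignment (the error row at the last input step), and the 80/20 split are all assumptions.

```python
import numpy as np


def make_sequences(features, errors, seq_len):
    """Hypothetical sliding-window builder.

    features: (n_times, n_features) inputs ordered in time.
    errors:   (n_times, n_lead_times) error targets, one row per forecast issue time.
    Sample k uses feature rows k .. k+seq_len-1 and targets the error row at
    k+seq_len-1 (the last input step).
    """
    features = np.asarray(features)
    errors = np.asarray(errors)
    n_samples = features.shape[0] - seq_len + 1
    if n_samples <= 0:
        raise ValueError("seq_len must not exceed the number of time steps")
    X = np.stack([features[k:k + seq_len] for k in range(n_samples)])
    y = errors[seq_len - 1:]
    return X, y


def chronological_split(X, y, train_frac=0.8):
    """Split without shuffling so the test period comes after the training period."""
    n_train = int(len(X) * train_frac)
    return X[:n_train], X[n_train:], y[:n_train], y[n_train:]


features = np.arange(20, dtype=float).reshape(10, 2)  # 10 time steps, 2 features
errors = np.arange(30, dtype=float).reshape(10, 3)    # 3 lead times
X, y = make_sequences(features, errors, seq_len=4)
X_train, X_test, y_train, y_test = chronological_split(X, y)
print(X.shape, y.shape, X_train.shape, X_test.shape)
```

Splitting in time order keeps future flows out of the training set; for the same reason, any scaler should be fit on the training portion only.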

## Evaluate Models
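Evaluate the trained models on the test set. Each model predicts the scaled NWM forecast error for every lead time. After inverse scaling, the predicted error is subtracted from the raw NWM forecast. Both the raw and the corrected forecasts are then scored against USGS observations, separately for each lead time, with the correlation coefficient (CC), root mean squared error (RMSE), percent bias (PBIAS), and Nash-Sutcliffe efficiency (NSE).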
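For reference, the metric functions in the next cell compute the following quantities for one lead time. Each is taken over the $n$ valid (non-NaN) pairs of USGS observations $Q_{\mathrm{obs},t}$ and forecasts $Q_{\mathrm{sim},t}$ (raw NWM or corrected). Ideal values are CC = 1, RMSE = 0, PBIAS = 0, and NSE = 1. With the sign convention used in `calculate_pbias`, a positive PBIAS means the forecast overestimates runoff on average.

$$
\begin{aligned}
\mathrm{RMSE} &= \sqrt{\frac{1}{n}\sum_{t=1}^{n}\left(Q_{\mathrm{sim},t} - Q_{\mathrm{obs},t}\right)^2} \\
\mathrm{CC} &= \frac{\sum_{t=1}^{n}\left(Q_{\mathrm{obs},t}-\bar{Q}_{\mathrm{obs}}\right)\left(Q_{\mathrm{sim},t}-\bar{Q}_{\mathrm{sim}}\right)}{\sqrt{\sum_{t=1}^{n}\left(Q_{\mathrm{obs},t}-\bar{Q}_{\mathrm{obs}}\right)^2}\,\sqrt{\sum_{t=1}^{n}\left(Q_{\mathrm{sim},t}-\bar{Q}_{\mathrm{sim}}\right)^2}} \\
\mathrm{PBIAS} &= 100\,\frac{\sum_{t=1}^{n}\left(Q_{\mathrm{sim},t}-Q_{\mathrm{obs},t}\right)}{\sum_{t=1}^{n} Q_{\mathrm{obs},t}} \\
\mathrm{NSE} &= 1 - \frac{\sum_{t=1}^{n}\left(Q_{\mathrm{obs},t}-Q_{\mathrm{sim},t}\right)^2}{\sum_{t=1}^{n}\left(Q_{\mathrm{obs},t}-\bar{Q}_{\mathrm{obs}}\right)^2}
\end{aligned}
$$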

In [None]:
# --- Evaluation Metrics Functions ---
def calculate_rmse(y_true, y_pred):
    # Ensure inputs are numpy arrays
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    # Filter out NaN values to avoid computation errors
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(mask) == 0:
        return np.nan # Return NaN if no valid pairs
    return np.sqrt(np.mean((y_true[mask] - y_pred[mask])**2))

def calculate_cc(y_true, y_pred):
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(mask) < 2: # Need at least 2 points for correlation
        return np.nan
    # Check for zero standard deviation
    if np.std(y_true[mask]) == 0 or np.std(y_pred[mask]) == 0:
        return np.nan
    return np.corrcoef(y_true[mask], y_pred[mask])[0, 1]

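def calculate_pbias(y_true, y_pred):
    # Sign convention: positive PBIAS = forecast overestimates on average,
    # negative = underestimates (some hydrology references use the opposite sign).
    y_true, y_pred = np.array(y_true), np.array(y_pred)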
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(mask) == 0:
        return np.nan
    sum_true = np.sum(y_true[mask])
    if sum_true == 0:
        return np.nan # Avoid division by zero
    return 100 * np.sum(y_pred[mask] - y_true[mask]) / sum_true

def calculate_nse(y_true, y_pred):
    y_true, y_pred = np.array(y_true), np.array(y_pred)
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(mask) == 0:
        return np.nan
    mean_true = np.mean(y_true[mask])
    numerator = np.sum((y_true[mask] - y_pred[mask])**2)
    denominator = np.sum((y_true[mask] - mean_true)**2)
    if denominator == 0:
        # Handle case where observed data is constant: NSE is -inf if pred != true, 1 if pred == true
        return 1.0 if numerator == 0 else -np.inf
    return 1 - (numerator / denominator)

# --- Helper Function to Load Data and Scaler ---
def load_evaluation_data(station_id):
    test_data_path = os.path.join(PROCESSED_DATA_DIR, 'test', f'{station_id}.npz')
    scaler_path = os.path.join(PROCESSED_DATA_DIR, 'scalers', f'{station_id}_y_scaler.joblib')
    if not os.path.exists(test_data_path) or not os.path.exists(scaler_path):
        print(f"Warning: Test data or scaler not found for station {station_id}. Skipping evaluation.")
        return None, None, None, None, None
    try:
        data = np.load(test_data_path)
        X_test = data['X_test']
        y_test_scaled = data['y_test_scaled'] # Scaled errors
        nwm_test_original = data['nwm_test_original']
        usgs_test_original = data['usgs_test_original']
        data.close()
        y_scaler = joblib_load(scaler_path)
        print(f"Loaded test data and y_scaler for station {station_id}.")
        # Ensure y_test_scaled has the expected shape (samples, lead_times)
        if y_test_scaled.ndim == 1:
            # Preprocessing may have saved it flattened; reshape using the NWM array's lead-time count
            n_samples = X_test.shape[0]
            n_lead_times = nwm_test_original.shape[1]
            if y_test_scaled.size == n_samples * n_lead_times:
                y_test_scaled = y_test_scaled.reshape(n_samples, n_lead_times)
                print(f"Reshaped y_test_scaled to {y_test_scaled.shape}")
            else:
                print(f"Warning: Cannot reshape y_test_scaled for station {station_id}. Unexpected length.")
                return None, None, None, None, None
        elif y_test_scaled.ndim != 2:
            print(f"Warning: y_test_scaled for station {station_id} has unexpected shape {y_test_scaled.shape}.")
            return None, None, None, None, None

        return X_test, y_test_scaled, nwm_test_original, usgs_test_original, y_scaler
    except Exception as e:
        print(f"Error loading data for station {station_id}: {e}")
        return None, None, None, None, None

# --- Helper Function to Evaluate a Model ---
def run_evaluation(station_id, model_type):
    model_filename = f"{station_id}_{model_type.lower()}_best.keras"
    model_path = os.path.join(MODELS_DIR, model_filename)
    if not os.path.exists(model_path):
        print(f"Model file not found: {model_path}. Skipping evaluation.")
        return None, None, None, None

    X_test, y_test_scaled, nwm_test_original, usgs_test_original, y_scaler = load_evaluation_data(station_id)
    if X_test is None:
        return None, None, None, None

    try:
        model = tf.keras.models.load_model(model_path)
        print(f"Loaded model {model_filename}")

        # Predict scaled errors
        predicted_errors_scaled = model.predict(X_test)
        print(f"Predicted errors (scaled) shape: {predicted_errors_scaled.shape}")

        # Ensure the prediction shape matches y_test_scaled (samples, lead_times)
        if predicted_errors_scaled.shape != y_test_scaled.shape:
            print(f"Warning: Shape mismatch between predicted ({predicted_errors_scaled.shape}) and true scaled errors ({y_test_scaled.shape}). Trying to reshape prediction.")
            if predicted_errors_scaled.ndim == 1 and y_test_scaled.ndim == 2:
                # Prediction is flattened: restore the lead-time dimension
                if predicted_errors_scaled.size == y_test_scaled.size:
                    predicted_errors_scaled = predicted_errors_scaled.reshape(y_test_scaled.shape)
                    print(f"Reshaped prediction to {predicted_errors_scaled.shape}")
                else:
                    print("Cannot reshape prediction due to element count mismatch.")
                    return None, None, None, None
            elif predicted_errors_scaled.ndim == 3 and predicted_errors_scaled.shape[-1] == 1 and y_test_scaled.ndim == 2:
                # Prediction has a trailing singleton dimension: (samples, lead_times, 1)
                if predicted_errors_scaled.shape[:2] == y_test_scaled.shape:
                    predicted_errors_scaled = predicted_errors_scaled.squeeze(-1)
                    print(f"Reshaped prediction to {predicted_errors_scaled.shape}")
                else:
                    print("Cannot reshape prediction due to dimension mismatch.")
                    return None, None, None, None
            else:
                print("Unhandled shape mismatch.")
                return None, None, None, None

        # Inverse-transform the predicted errors back to physical units (cms).
        # The scaler may have been fit per lead time (n_features == n_lead_times)
        # or on a single stacked column (n_features == 1); handle both layouts.
        n_samples, n_lead_times = predicted_errors_scaled.shape
        if getattr(y_scaler, 'n_features_in_', n_lead_times) == 1:
            predicted_errors_unscaled = y_scaler.inverse_transform(
                predicted_errors_scaled.reshape(-1, 1)).reshape(n_samples, n_lead_times)
        else:
            predicted_errors_unscaled = y_scaler.inverse_transform(predicted_errors_scaled)
        print(f"Predicted errors (unscaled) shape: {predicted_errors_unscaled.shape}")

        # Corrected forecast = NWM - predicted error. This assumes preprocessing
        # defined the error target as NWM - USGS; if it was USGS - NWM, flip the sign.
        corrected_nwm_forecasts = nwm_test_original - predicted_errors_unscaled
        print(f"Corrected NWM forecasts shape: {corrected_nwm_forecasts.shape}")

        # Calculate Metrics for each Lead Time
        print("Calculating evaluation metrics per lead time...")
        metrics = {'lead_time': list(range(1, n_lead_times + 1))}
        metric_funcs = {'CC': calculate_cc, 'RMSE': calculate_rmse, 'PBIAS': calculate_pbias, 'NSE': calculate_nse}

        for metric_name, func in metric_funcs.items():
            metrics[f'NWM_{metric_name}'] = []
            metrics[f'Corrected_{metric_name}'] = []
            for i in range(n_lead_times):
                obs = usgs_test_original[:, i]
                nwm_pred = nwm_test_original[:, i]
                corrected_pred = corrected_nwm_forecasts[:, i]

                # The metric functions ignore NaN pairs internally
                nwm_metric = func(obs, nwm_pred)
                corrected_metric = func(obs, corrected_pred)

                metrics[f'NWM_{metric_name}'].append(nwm_metric)
                metrics[f'Corrected_{metric_name}'].append(corrected_metric)

        metrics_df = pd.DataFrame(metrics)
        metrics_filename = os.path.join(METRICS_DIR, f"{station_id}_{model_type.lower()}_evaluation_metrics.csv")
        metrics_df.to_csv(metrics_filename, index=False)
        print(f"Saved evaluation metrics to {metrics_filename}")
        print(metrics_df.head())

        return metrics_df, usgs_test_original, nwm_test_original, corrected_nwm_forecasts

    except Exception as e:
        print(f"Error during evaluation for station {station_id} ({model_type}): {e}")
        import traceback
        traceback.print_exc()
        return None, None, None, None

# --- Run Evaluation for Both Stations ---
print("\n--- Evaluating Station 21609641 (LSTM) ---")
metrics_s1, usgs_s1, nwm_s1, corrected_s1 = run_evaluation(station_id_s1, 'lstm')

print("\n--- Evaluating Station 20380357 (Transformer) ---")
metrics_s2, usgs_s2, nwm_s2, corrected_s2 = run_evaluation(station_id_s2, 'transformer')

print("\n--- Evaluation Summary --- ")
if metrics_s1 is not None:
    print(f"\nMetrics for Station {station_id_s1} (LSTM):")
    print(metrics_s1.describe())
if metrics_s2 is not None:
    print(f"\nMetrics for Station {station_id_s2} (Transformer):")
    print(metrics_s2.describe())
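The correction step in `run_evaluation` rests on an assumption about the sign of the error target. A small synthetic check makes it concrete: if the target is NWM − USGS, subtracting a perfectly predicted and correctly inverse-scaled error recovers the observations exactly. The per-lead-time z-scoring below stands in for whatever scaler `preprocess.py` actually fits (an assumption), and `correct_forecasts` is a hypothetical helper name.

```python
import numpy as np


def correct_forecasts(nwm, predicted_error):
    """Subtract a predicted error from the raw forecast.

    Assumes the error target was defined as error = NWM - USGS, so a perfect
    error prediction recovers the observation exactly.
    """
    return np.asarray(nwm) - np.asarray(predicted_error)


# Toy data: 3 samples x 2 lead times (synthetic numbers, not project data)
usgs = np.array([[10.0, 11.0], [12.0, 13.0], [9.0, 9.5]])
nwm = np.array([[12.0, 14.0], [11.0, 15.0], [9.5, 8.0]])
true_error = nwm - usgs

# Per-lead-time z-scoring, mirroring a scaler fit on (n_samples, n_lead_times)
mu, sigma = true_error.mean(axis=0), true_error.std(axis=0)
scaled = (true_error - mu) / sigma
unscaled = scaled * sigma + mu  # the inverse transform

corrected = correct_forecasts(nwm, unscaled)
print(corrected)
```

Adding instead of subtracting (`nwm + unscaled`) would give `2 * nwm - usgs`, which doubles the NWM error rather than removing it.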

## Generate Visualizations
Create box plots of observed, NWM, and corrected runoff for each lead time, and box plots that summarize each evaluation metric across lead times (useful when multiple lead times were predicted and evaluated). A time-series plot for a single lead time is also generated.

In [None]:
# --- Visualization Functions ---
def plot_runoff_comparison(station_id, model_type, usgs_data, nwm_data, corrected_data):
    if usgs_data is None or nwm_data is None or corrected_data is None:
        print(f"Skipping runoff comparison plot for {station_id}: Missing data.")
        return
    n_lead_times = usgs_data.shape[1]
    lead_times = list(range(1, n_lead_times + 1))

    plt.figure(figsize=(18, 8)) # Wider figure
    plot_data = []
    for i in range(n_lead_times):
        plot_data.append(pd.DataFrame({
            'Runoff': usgs_data[:, i],
            'Lead Time': lead_times[i],
            'Type': 'Observed (USGS)'
        }))
        plot_data.append(pd.DataFrame({
            'Runoff': nwm_data[:, i],
            'Lead Time': lead_times[i],
            'Type': 'NWM Forecast'
        }))
        plot_data.append(pd.DataFrame({
            'Runoff': corrected_data[:, i],
            'Lead Time': lead_times[i],
            'Type': f'Corrected ({model_type.upper()})'
        }))
    plot_df = pd.concat(plot_data)

    # Filter out potential NaNs before plotting
    plot_df.dropna(subset=['Runoff'], inplace=True)

    sns.boxplot(data=plot_df, x='Lead Time', y='Runoff', hue='Type', showfliers=False) # Hide outliers for clarity
    plt.title(f'Runoff Comparison by Lead Time - Station {station_id}')
    plt.xlabel('Lead Time (Hours)')
    plt.ylabel('Runoff (cms)')
    plt.xticks(rotation=45)
    plt.legend(title='Forecast Type', bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout to make space for legend
    plot_filename = os.path.join(PLOTS_DIR, f"{station_id}_{model_type.lower()}_runoff_boxplot.png")
    plt.savefig(plot_filename)
    print(f"Saved runoff comparison plot to {plot_filename}")
    plt.show()
    plt.close()

def plot_metrics_comparison(station_id, model_type, metrics_df):
    if metrics_df is None:
        print(f"Skipping metrics comparison plot for {station_id}: Missing metrics data.")
        return

    corrected_label = f'Corrected ({model_type.upper()})'
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    axes = axes.flatten()
    metric_plot_names = ['CC', 'RMSE', 'PBIAS', 'NSE']

    for ax, metric_name in zip(axes, metric_plot_names):
        # Check that both metric columns exist
        nwm_col = f'NWM_{metric_name}'
        corrected_col = f'Corrected_{metric_name}'
        if nwm_col not in metrics_df.columns or corrected_col not in metrics_df.columns:
            print(f"Warning: Skipping plot for metric '{metric_name}' - columns not found in DataFrame.")
            ax.set_visible(False)
            continue

        # Long format: one row per (lead time, forecast type)
        melted_df = pd.melt(metrics_df,
                            id_vars=['lead_time'],
                            value_vars=[nwm_col, corrected_col],
                            var_name='Forecast Type',
                            value_name=metric_name)
        melted_df['Forecast Type'] = melted_df['Forecast Type'].map(
            {nwm_col: 'NWM', corrected_col: corrected_label})
        # Drop NaN and infinite values (calculate_nse can return -inf for constant observations)
        melted_df = melted_df.replace([np.inf, -np.inf], np.nan).dropna(subset=[metric_name])

        # metrics_df holds one value per lead time, so a box per lead time would collapse
        # to a single line; instead each box summarizes the metric across all lead times,
        # with the individual lead times overlaid as points.
        sns.boxplot(data=melted_df, x='Forecast Type', y=metric_name, ax=ax, showfliers=False)
        sns.stripplot(data=melted_df, x='Forecast Type', y=metric_name, ax=ax, color='black', size=4)
        ax.set_title(f'{metric_name} Across Lead Times')
        ax.set_xlabel('')
        ax.set_ylabel(metric_name)
        ax.grid(axis='y', linestyle='--')

    fig.suptitle(f'Evaluation Metrics Across Lead Times - Station {station_id} ({model_type.upper()})')
    fig.tight_layout()
    plot_filename = os.path.join(PLOTS_DIR, f"{station_id}_{model_type.lower()}_metrics_boxplot.png")
    # bbox_inches='tight' keeps the suptitle inside the saved image
    fig.savefig(plot_filename, bbox_inches='tight')
    print(f"Saved metrics comparison plot to {plot_filename}")
    plt.show()
    plt.close(fig)

# --- Generate Plots ---
print("\n--- Generating Visualizations ---")
plot_runoff_comparison(station_id_s1, 'lstm', usgs_s1, nwm_s1, corrected_s1)
plot_metrics_comparison(station_id_s1, 'lstm', metrics_s1)

plot_runoff_comparison(station_id_s2, 'transformer', usgs_s2, nwm_s2, corrected_s2)
plot_metrics_comparison(station_id_s2, 'transformer', metrics_s2)

# --- (Optional) Time Series Plot for a specific lead time ---
def plot_timeseries_lead_time(station_id, model_type, usgs_data, nwm_data, corrected_data, lead_time_index=0):
    if usgs_data is None or nwm_data is None or corrected_data is None:
        print(f"Skipping time series plot for {station_id}: Missing data.")
        return
    if lead_time_index >= usgs_data.shape[1]:
        print(f"Skipping time series plot for {station_id}: Invalid lead_time_index {lead_time_index}.")
        return

    lead_time_hour = lead_time_index + 1
    time_index = np.arange(len(usgs_data[:, lead_time_index]))
    plt.figure(figsize=(15, 6))
    plt.plot(time_index, usgs_data[:, lead_time_index], label='Observed (USGS)', color='blue')
    plt.plot(time_index, nwm_data[:, lead_time_index], label='NWM Forecast', color='green', alpha=0.7, linestyle='--')
    plt.plot(time_index, corrected_data[:, lead_time_index], label=f'Corrected ({model_type.upper()})', color='red', alpha=0.8)
    plt.title(f'Station {station_id}: Forecast vs Actual Runoff (Lead Time {lead_time_hour} hr - Test Set)')
    plt.xlabel('Time Step Index')
    plt.ylabel('Runoff (cms)')
    plt.legend()
    plt.show()
    plt.close()

print("\n--- Generating Time Series Plots (Lead Time 1) ---")
plot_timeseries_lead_time(station_id_s1, 'lstm', usgs_s1, nwm_s1, corrected_s1, lead_time_index=0)
plot_timeseries_lead_time(station_id_s2, 'transformer', usgs_s2, nwm_s2, corrected_s2, lead_time_index=0)
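`plot_runoff_comparison` builds its long-format table by concatenating one small DataFrame per lead time and series. A vectorized construction with NumPy is sketched below. `runoff_long_format` is a hypothetical name, and the column and label names match those the runoff box plot uses.

```python
import numpy as np
import pandas as pd


def runoff_long_format(usgs, nwm, corrected, model_type):
    """Stack (n_samples, n_lead_times) arrays into one long DataFrame for seaborn."""
    series = {
        'Observed (USGS)': np.asarray(usgs, dtype=float),
        'NWM Forecast': np.asarray(nwm, dtype=float),
        f'Corrected ({model_type.upper()})': np.asarray(corrected, dtype=float),
    }
    n_samples, n_lead_times = series['Observed (USGS)'].shape
    # Lead-time label for every element of a row-major (n_samples, n_lead_times) array
    lead = np.tile(np.arange(1, n_lead_times + 1), n_samples)
    frames = [
        pd.DataFrame({'Runoff': arr.ravel(), 'Lead Time': lead, 'Type': name})
        for name, arr in series.items()
    ]
    return pd.concat(frames, ignore_index=True).dropna(subset=['Runoff'])


# Synthetic example: 2 samples x 2 lead times with one missing value
usgs = np.array([[1.0, 2.0], [3.0, np.nan]])
df = runoff_long_format(usgs, usgs + 1, usgs - 0.5, 'lstm')
print(df)
```

Inside `plot_runoff_comparison`, `plot_df = runoff_long_format(usgs_data, nwm_data, corrected_data, model_type)` could replace the loop and the `dropna` call. The rows are grouped by series rather than by lead time, but seaborn orders the hue levels by first appearance and sorts numeric x levels, so the resulting plot should look the same.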

## Compare Model Performance
Analyze the LSTM results for station 21609641 and the Transformer results for station 20380357, comparing each with the raw NWM forecast. Highlight differences in behavior and accuracy based on the evaluation metrics and visualizations.
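The architectures of the saved `*_best.keras` models are not shown in this notebook. To make the contrast between sequential processing and attention concrete, the sketch below builds a minimal LSTM and a minimal single-block Transformer encoder. Both map an input window `(seq_len, n_features)` to one error per lead time. Layer sizes, the learned positional embedding, and the names `build_lstm`, `build_transformer`, and `PositionalEmbedding` are assumptions, not the trained models' configuration.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


class PositionalEmbedding(layers.Layer):
    """Adds a learned embedding of each time-step index (attention alone ignores order)."""

    def __init__(self, seq_len, d_model, **kwargs):
        super().__init__(**kwargs)
        self.seq_len = seq_len
        self.embedding = layers.Embedding(seq_len, d_model)

    def call(self, x):
        positions = tf.range(self.seq_len)
        return x + self.embedding(positions)  # broadcasts over the batch dimension


def build_lstm(seq_len, n_features, n_lead_times, units=64):
    inputs = keras.Input(shape=(seq_len, n_features))
    # The LSTM reads the window one step at a time; its last hidden state summarizes it
    x = layers.LSTM(units)(inputs)
    outputs = layers.Dense(n_lead_times)(x)
    return keras.Model(inputs, outputs, name="lstm_sketch")


def build_transformer(seq_len, n_features, n_lead_times, d_model=64, num_heads=4):
    inputs = keras.Input(shape=(seq_len, n_features))
    x = layers.Dense(d_model)(inputs)
    x = PositionalEmbedding(seq_len, d_model)(x)
    # Self-attention relates every time step to every other step in a single operation
    attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=d_model // num_heads)(x, x)
    x = layers.LayerNormalization()(layers.Add()([x, attn]))
    ff = layers.Dense(2 * d_model, activation="relu")(x)
    ff = layers.Dense(d_model)(ff)
    x = layers.LayerNormalization()(layers.Add()([x, ff]))
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(n_lead_times)(x)
    return keras.Model(inputs, outputs, name="transformer_sketch")


X = np.random.rand(3, 24, 5).astype("float32")  # 3 windows, 24 steps, 5 features (arbitrary)
lstm_model = build_lstm(24, 5, n_lead_times=18)
transformer_model = build_transformer(24, 5, n_lead_times=18)
print(lstm_model.predict(X, verbose=0).shape, transformer_model.predict(X, verbose=0).shape)
```

The LSTM must pass information through its recurrent state step by step, whereas attention compares all pairs of steps directly, which can make distant events in the window (for example, an upstream peak early in the input) easier to use.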

**Station 21609641 (LSTM):**
*   Analyze the LSTM model's performance (CC, RMSE, PBIAS, NSE) relative to the raw NWM forecast at each lead time.
*   Compare the runoff distributions in the box plot: does the correction reduce the NWM bias, and is its spread closer to that of the observations?
*   Examine the time series plot: does the LSTM capture peaks and troughs better than NWM? Are there timing lags?

**Station 20380357 (Transformer):**
*   Analyze the Transformer model's performance similarly.
*   Compare its metrics to the raw NWM forecast for this station.
*   Discuss its runoff distribution (box plot) and time series behavior relative to the observations and the raw NWM forecast.

**Overall Comparison:**
*   Which model type performed better overall, considering the metrics? Because each architecture was evaluated at a different station, metric differences mix model effects with station (basin and data) effects; a controlled comparison would evaluate both architectures at the same station.
*   Did one model type show specific strengths (e.g., capturing extremes, lower bias)?
*   Relate performance differences to potential characteristics of the data for each station or the inherent differences between LSTM (sequential processing) and Transformer (attention mechanism) architectures.
*   Discuss potential reasons for observed performance (e.g., data quality, sequence length choice, model complexity).
*   Suggest future improvements or experiments (e.g., hyperparameter tuning, different features, longer sequences, different model variants).
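Once both `metrics_*` DataFrames exist, a compact table helps answer these questions. The sketch assumes the column layout written by `run_evaluation` (`lead_time`, `NWM_<metric>`, `Corrected_<metric>`). The helper name `summarize_improvement` and the "better" rules (higher CC and NSE, lower RMSE, PBIAS closer to zero) are choices made here, not part of the original pipeline.

```python
import numpy as np
import pandas as pd

# For each metric: a function mapping raw values to "higher is better" scores
BETTER = {
    'CC': lambda v: v,
    'NSE': lambda v: v,
    'RMSE': lambda v: -v,
    'PBIAS': lambda v: -np.abs(v),
}


def summarize_improvement(runs):
    """runs: dict mapping a run label to a per-lead-time metrics DataFrame."""
    rows = []
    for label, df in runs.items():
        for metric, score in BETTER.items():
            nwm, corr = df[f'NWM_{metric}'], df[f'Corrected_{metric}']
            rows.append({
                'run': label,
                'metric': metric,
                'NWM_median': nwm.median(),
                'Corrected_median': corr.median(),
                # Lead times where the corrected forecast beats raw NWM (NaN counts as not improved)
                'lead_times_improved': int((score(corr) > score(nwm)).sum()),
                'n_lead_times': len(df),
            })
    return pd.DataFrame(rows)


# Synthetic metrics for two lead times (illustrative numbers only)
df = pd.DataFrame({
    'lead_time': [1, 2],
    'NWM_CC': [0.8, 0.7], 'Corrected_CC': [0.9, 0.6],
    'NWM_RMSE': [5.0, 6.0], 'Corrected_RMSE': [4.0, 4.5],
    'NWM_PBIAS': [10.0, -8.0], 'Corrected_PBIAS': [-2.0, -9.0],
    'NWM_NSE': [0.5, 0.4], 'Corrected_NSE': [0.6, 0.5],
})
summary = summarize_improvement({'example run': df})
print(summary)
```

In the notebook it could be called as `summarize_improvement({f'{station_id_s1} (LSTM)': metrics_s1, f'{station_id_s2} (Transformer)': metrics_s2})`, leaving out any run whose metrics are `None`.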