# Exploratory Data Analysis for Runoff Forecasting

This notebook explores the National Water Model (NWM) forecasts and USGS observational data to understand patterns, biases, and potential for improvement using deep learning.

In [None]:
# Import necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Set plotting style.
# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names (plt.style.use then raises OSError), so try the
# new name first, then the legacy one, then seaborn's own equivalent.
for _style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
else:
    sns.set_style('whitegrid')
sns.set_palette('deep')
plt.rcParams['figure.figsize'] = [12, 6]

# Configure notebook display options
pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', 100)

## 1. Load the Data

In [None]:
# Define data paths (relative to the notebook's working directory)
data_dir = os.path.join('..', 'data', 'raw')
nwm_file = os.path.join(data_dir, 'nwm_forecasts.csv')
usgs_file = os.path.join(data_dir, 'usgs_observations.csv')


def load_csv_or_empty(path, label, date_col='datetime'):
    """Read a CSV with `date_col` parsed as datetimes, or return an empty frame.

    Parameters
    ----------
    path : str
        Location of the CSV file.
    label : str
        Dataset name used in status messages (e.g. ``'NWM'``).
    date_col : str, default ``'datetime'``
        Column passed to ``parse_dates``; ``read_csv`` raises ``ValueError``
        if the file has no such column.

    Returns
    -------
    pandas.DataFrame
        The loaded data, or an empty DataFrame when ``path`` does not exist,
        so downstream cells can guard with ``df.empty``.
    """
    try:
        df = pd.read_csv(path, parse_dates=[date_col])
    except FileNotFoundError:
        print(f"{label} file not found: {path}")
        return pd.DataFrame()

    print(f"{label} data loaded with shape: {df.shape}")
    # read_csv silently leaves the column as strings (object dtype) when it
    # cannot parse every value; later cells use the `.dt` accessor, so surface
    # the problem here rather than as an AttributeError further down.
    if not pd.api.types.is_datetime64_any_dtype(df[date_col]):
        print(f"Warning: {label} column '{date_col}' was not parsed as datetimes "
              f"(dtype {df[date_col].dtype}); check for malformed or mixed-format timestamps.")
    return df


# Load NWM forecast data and USGS observation data
nwm_df = load_csv_or_empty(nwm_file, 'NWM')
usgs_df = load_csv_or_empty(usgs_file, 'USGS')

## 2. Examine the Data Structure

In [None]:
def describe_frame(df, label):
    """Show a sample, the column/dtype/non-null report, and summary statistics for `df`.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataset to summarize.
    label : str
        Dataset name used in the section headings (e.g. ``'NWM'``).
    """
    print(f"\n{label} Data Sample:")
    display(df.head())
    print(f"\n{label} Data Info:")
    # DataFrame.info() prints its report itself and returns None; wrapping it
    # in display() would additionally render a spurious "None".
    df.info()
    print(f"\n{label} Data Description:")
    display(df.describe())


# If data is available, examine the structure
for frame, label in ((nwm_df, 'NWM'), (usgs_df, 'USGS')):
    if not frame.empty:
        describe_frame(frame, label)

## 3. Check for Missing Values

In [None]:
# Per-column and total counts of missing values for every dataset that loaded
for frame, heading in (
    (nwm_df, "Missing values in NWM data:"),
    (usgs_df, "\nMissing values in USGS data:"),
):
    if frame.empty:
        continue
    nulls_per_column = frame.isnull().sum()
    print(heading)
    display(nulls_per_column)
    print(f"Total missing values: {nulls_per_column.sum()}")

## 4. Explore Temporal Coverage

In [None]:
# First timestamp, last timestamp and overall span of each loaded dataset
for frame, heading in (
    (nwm_df, "NWM data temporal coverage:"),
    (usgs_df, "\nUSGS data temporal coverage:"),
):
    if not frame.empty:
        first_ts = frame['datetime'].min()
        last_ts = frame['datetime'].max()
        print(heading)
        print(f"Start: {first_ts}")
        print(f"End: {last_ts}")
        print(f"Duration: {last_ts - first_ts}")

## 5. Explore Station Coverage

In [None]:
# Distinct stations per dataset, and how many appear in both
nwm_has_ids = (not nwm_df.empty) and ('station_id' in nwm_df.columns)
usgs_has_ids = (not usgs_df.empty) and ('station_id' in usgs_df.columns)

if nwm_has_ids:
    nwm_stations = nwm_df['station_id'].unique()
    print(f"Number of stations in NWM data: {len(nwm_stations)}")

if usgs_has_ids:
    usgs_stations = usgs_df['station_id'].unique()
    print(f"Number of stations in USGS data: {len(usgs_stations)}")

# Overlap between the two station sets (share is relative to NWM stations)
if nwm_has_ids and usgs_has_ids:
    common_stations = set(nwm_stations) & set(usgs_stations)
    overlap_pct = len(common_stations) / len(nwm_stations) * 100
    print(f"Number of common stations: {len(common_stations)}")
    print(f"Percentage of NWM stations with USGS data: {overlap_pct:.2f}%")

## 6. Visualize Distribution of Runoff Values

In [None]:
# Histogram (+ KDE) of raw runoff for each source, one panel each.
# Counts are log-scaled so the sparse high-flow tail stays visible.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

panels = (
    (nwm_df, 'runoff_nwm', 'NWM Runoff Distribution', axes[0]),
    (usgs_df, 'runoff_usgs', 'USGS Runoff Distribution', axes[1]),
)
for frame, column, title, panel_ax in panels:
    if frame.empty or column not in frame.columns:
        continue
    sns.histplot(frame[column], bins=50, kde=True, ax=panel_ax)
    panel_ax.set_title(title)
    panel_ax.set_xlabel('Runoff (cms)')
    panel_ax.set_ylabel('Frequency')
    panel_ax.set_yscale('log')

plt.tight_layout()
plt.show()

## 7. Merge Datasets for Comparison

In [None]:
# Merge NWM and USGS datasets if both are available.
# Both frames must carry the join keys; earlier cells guard on 'station_id'
# the same way, so skip with a message instead of raising KeyError.
MERGE_KEYS = ['datetime', 'station_id']

both_have_keys = (
    not nwm_df.empty
    and not usgs_df.empty
    and all(key in nwm_df.columns for key in MERGE_KEYS)
    and all(key in usgs_df.columns for key in MERGE_KEYS)
)

if both_have_keys:
    # Inner join keeps only (datetime, station_id) pairs present in both
    # sources; suffixes apply only to non-key columns present in both frames.
    merged_df = pd.merge(
        nwm_df,
        usgs_df,
        on=MERGE_KEYS,
        how='inner',
        suffixes=('_nwm', '_usgs')
    )

    print(f"Merged data shape: {merged_df.shape}")
    # An inner join drops unmatched rows silently; report how many survived so
    # timestamp or station-ID mismatches between the sources are visible.
    print(f"Rows matched: {len(merged_df)} (NWM rows: {len(nwm_df)}, USGS rows: {len(usgs_df)})")
    n_dup_keys = int(merged_df.duplicated(subset=MERGE_KEYS).sum())
    if n_dup_keys:
        print(f"Warning: {n_dup_keys} merged rows repeat a (datetime, station_id) key; "
              "one or both sources contain duplicate keys (possibly several forecasts per time).")
    display(merged_df.head())
else:
    print("Skipping merge: both datasets must be loaded and contain 'datetime' and 'station_id' columns.")

## 8. Analyze NWM Forecast Bias

In [None]:
# Calculate bias if merged data is available
if 'merged_df' in locals() and 'runoff_nwm' in merged_df.columns and 'runoff_usgs' in merged_df.columns:
    # Signed error: positive values mean NWM over-predicts the observed runoff
    merged_df['error'] = merged_df['runoff_nwm'] - merged_df['runoff_usgs']

    # Relative error (%) is undefined where observed runoff is exactly zero:
    # dividing there yields +/-inf (or NaN for 0/0), which would poison the
    # mean/std below. Mask zero observations to NaN so they are excluded.
    observed_nonzero = merged_df['runoff_usgs'].where(merged_df['runoff_usgs'] != 0)
    merged_df['rel_error'] = merged_df['error'] / observed_nonzero * 100
    n_zero_obs = int((merged_df['runoff_usgs'] == 0).sum())
    if n_zero_obs:
        print(f"Relative error set to NaN for {n_zero_obs} rows with zero observed runoff.")

    # Summary statistics of error
    print("Error summary statistics:")
    display(merged_df[['error', 'rel_error']].describe())

    # Visualize error distribution
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Absolute error
    sns.histplot(merged_df['error'], bins=50, kde=True, ax=ax1)
    ax1.axvline(x=0, color='r', linestyle='--')
    ax1.set_title('NWM Absolute Error Distribution')
    ax1.set_xlabel('Error (cms)')
    ax1.set_ylabel('Frequency')

    # Relative error, clipped to +/-200 % so a few extreme ratios do not
    # flatten the histogram (clipped values pile up at the edges).
    rel_error_filtered = merged_df['rel_error'].clip(-200, 200)
    sns.histplot(rel_error_filtered, bins=50, kde=True, ax=ax2)
    ax2.axvline(x=0, color='r', linestyle='--')
    ax2.set_title('NWM Relative Error Distribution (clipped)')
    ax2.set_xlabel('Relative Error (%)')
    ax2.set_ylabel('Frequency')

    plt.tight_layout()
    plt.show()

## 9. Analyze Temporal Patterns in Error

In [None]:
# Analyze temporal patterns in NWM error.
# Seasons follow the standard meteorological convention (Dec-Feb winter,
# Mar-May spring, Jun-Aug summer, Sep-Nov fall), so December is grouped
# with the winter it belongs to rather than with autumn.
SEASON_ORDER = ['Winter', 'Spring', 'Summer', 'Fall']
MONTH_TO_SEASON = {
    12: 'Winter', 1: 'Winter', 2: 'Winter',
    3: 'Spring', 4: 'Spring', 5: 'Spring',
    6: 'Summer', 7: 'Summer', 8: 'Summer',
    9: 'Fall', 10: 'Fall', 11: 'Fall',
}

if 'merged_df' in locals() and 'error' in merged_df.columns:
    # Extract time components.
    # NOTE(review): hour/month are in whatever timezone the source timestamps
    # use; confirm NWM and USGS share one before interpreting diurnal patterns.
    merged_df['hour'] = merged_df['datetime'].dt.hour
    merged_df['month'] = merged_df['datetime'].dt.month
    merged_df['season'] = pd.Categorical(
        merged_df['month'].map(MONTH_TO_SEASON),
        categories=SEASON_ORDER,
        ordered=True,
    )

    # Error by hour of day and by month
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    hourly_error = merged_df.groupby('hour')['error'].mean()
    hourly_error.plot(kind='line', marker='o', ax=ax1)
    ax1.set_title('Mean NWM Error by Hour of Day')
    ax1.set_xlabel('Hour')
    ax1.set_ylabel('Mean Error (cms)')
    ax1.set_xticks(range(0, 24, 2))
    ax1.axhline(y=0, color='r', linestyle='--')

    monthly_error = merged_df.groupby('month')['error'].mean()
    monthly_error.plot(kind='line', marker='o', ax=ax2)
    ax2.set_title('Mean NWM Error by Month')
    ax2.set_xlabel('Month')
    ax2.set_ylabel('Mean Error (cms)')
    ax2.set_xticks(range(1, 13))
    ax2.axhline(y=0, color='r', linestyle='--')

    plt.tight_layout()
    plt.show()

    # Error by season (box plot); category order fixes the x-axis order
    fig, season_ax = plt.subplots(figsize=(10, 6))
    sns.boxplot(x='season', y='error', data=merged_df, ax=season_ax)
    season_ax.set_title('NWM Error Distribution by Season')
    season_ax.set_xlabel('Season')
    season_ax.set_ylabel('Error (cms)')
    season_ax.axhline(y=0, color='r', linestyle='--')
    plt.show()

## 10. Analyze Station-Specific Patterns

In [None]:
# Analyze patterns by station
MIN_RECORDS_PER_STATION = 100  # only plot stations with at least this many matched records
N_STATIONS_TO_PLOT = 5         # how many of the highest-|bias| stations to plot
SAMPLE_ROWS = 720              # ~30 days if records are hourly (NOTE(review): confirm cadence)


def plot_station_diagnostics(station_data, station, sample_rows=SAMPLE_ROWS):
    """Plot NWM-vs-observed scatter and a short time-series window for one station.

    Parameters
    ----------
    station_data : pandas.DataFrame
        Merged rows for one station; needs 'datetime', 'runoff_usgs', 'runoff_nwm'.
    station : object
        Station identifier, used only in plot titles.
    sample_rows : int
        Number of earliest records (by datetime) shown in the time-series panel.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Scatter of predicted vs observed; the dashed 1:1 line marks a perfect forecast
    ax1.scatter(station_data['runoff_usgs'], station_data['runoff_nwm'], alpha=0.5)
    max_val = max(station_data['runoff_usgs'].max(), station_data['runoff_nwm'].max()) * 1.1
    ax1.plot([0, max_val], [0, max_val], 'r--')
    ax1.set_xlabel('Observed Runoff (cms)')
    ax1.set_ylabel('NWM Predicted Runoff (cms)')
    ax1.set_title(f'Station {station}: NWM vs Observed')

    # Time series of the first `sample_rows` records
    sample_period = station_data.sort_values('datetime').iloc[:sample_rows]
    ax2.plot(sample_period['datetime'], sample_period['runoff_usgs'], 'b-', label='Observed')
    ax2.plot(sample_period['datetime'], sample_period['runoff_nwm'], 'r-', label='NWM')
    ax2.set_xlabel('Date')
    ax2.set_ylabel('Runoff (cms)')
    ax2.set_title(f'Station {station}: Sample Time Series')
    ax2.legend()

    plt.tight_layout()
    plt.show()


# 'error'/'rel_error' only exist if the bias cell ran, so require them too
# instead of failing with a KeyError in the aggregation below.
station_cols_needed = {'station_id', 'datetime', 'error', 'rel_error', 'runoff_usgs', 'runoff_nwm'}
if 'merged_df' in locals() and station_cols_needed.issubset(merged_df.columns):
    # Calculate station-wise error statistics
    station_stats = merged_df.groupby('station_id').agg({
        'error': ['mean', 'std', 'median'],
        'rel_error': ['mean', 'std', 'median'],
        'runoff_usgs': ['mean', 'min', 'max', 'count']
    })

    # Flatten the (column, statistic) MultiIndex into e.g. 'error_mean'
    station_stats.columns = ['_'.join(col).strip() for col in station_stats.columns.values]

    # Sort by absolute mean error (largest bias first, regardless of sign)
    station_stats = station_stats.sort_values(by='error_mean', key=abs, ascending=False)

    print("Station-wise error statistics (top 10 by absolute error):")
    display(station_stats.head(10))

    # Plot the highest-bias stations that have enough data to be meaningful
    well_sampled = station_stats['runoff_usgs_count'] >= MIN_RECORDS_PER_STATION
    stations_to_plot = station_stats[well_sampled].index[:N_STATIONS_TO_PLOT]

    for station in stations_to_plot:
        plot_station_diagnostics(merged_df[merged_df['station_id'] == station], station)

## 11. Analyze Relationship Between Error and Flow Magnitude

In [None]:
# Analyze how error relates to flow magnitude.
# 'rel_error' is used below too, so require it alongside 'error' (both are
# created by the bias cell) instead of failing with a KeyError.
flow_cols_needed = {'runoff_usgs', 'error', 'rel_error'}
if 'merged_df' in locals() and flow_cols_needed.issubset(merged_df.columns):
    # Decile bins of observed flow; duplicates='drop' merges bins whose edges
    # coincide (e.g. many identical low flows), so there may be fewer than 10.
    merged_df['flow_bin'] = pd.qcut(merged_df['runoff_usgs'], q=10, duplicates='drop')

    # Error statistics per flow bin. observed=False keeps every bin listed and
    # avoids pandas' FutureWarning about the changing default for categorical keys.
    flow_bin_stats = merged_df.groupby('flow_bin', observed=False).agg({
        'error': ['mean', 'std', 'median'],
        'rel_error': ['mean', 'std', 'median'],
        'runoff_usgs': ['mean', 'count']
    })

    # Flatten the (column, statistic) MultiIndex into e.g. 'error_mean'
    flow_bin_stats.columns = ['_'.join(col).strip() for col in flow_bin_stats.columns.values]

    print("Error statistics by flow magnitude bin:")
    display(flow_bin_stats)

    # Visualize relationship between error and flow magnitude
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

    # Scatter of error against observed flow (low alpha: many overlapping points)
    ax1.scatter(merged_df['runoff_usgs'], merged_df['error'], alpha=0.1)
    ax1.axhline(y=0, color='r', linestyle='--')
    ax1.set_xlabel('Observed Runoff (cms)')
    ax1.set_ylabel('Error (cms)')
    ax1.set_title('NWM Error vs Flow Magnitude')

    # Box plot of relative error by flow bin
    sns.boxplot(x='flow_bin', y='rel_error', data=merged_df, ax=ax2)
    ax2.axhline(y=0, color='r', linestyle='--')
    ax2.set_xlabel('Flow Magnitude Bin')
    ax2.set_ylabel('Relative Error (%)')
    ax2.set_title('Relative Error by Flow Magnitude')
    # Rotate the long interval labels without re-setting (and freezing) them
    ax2.tick_params(axis='x', labelrotation=90)

    plt.tight_layout()
    plt.show()

## 12. Summary and Insights for Model Development

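### Key Findings:

_TODO: fill in the items below after running every cell on the full dataset; bracketed text is a placeholder._

1. **Data Coverage**: [TODO: summarize temporal and spatial (station) coverage, including the NWM/USGS station overlap]

2. **Bias Patterns**: [TODO: summarize the observed bias patterns]
   - Temporal patterns: [TODO: hourly and seasonal variations]
   - Station-specific patterns: [TODO: variations across stations]
   - Flow magnitude relationships: [TODO: how error varies with flow magnitude]

3. **Missing Data**: [TODO: summarize missing-data issues in each source]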

### Implications for Model Development:

1. **Feature Engineering**:
   - Include temporal features (hour, day, month, season)
   - Include station-specific features or embeddings
   - Consider flow magnitude as a feature

2. **Model Architecture**:
   - Sequential models (LSTM/GRU) to capture temporal dependencies
   - Consider station-specific models for locations with unique patterns
   - Implement bias correction techniques specifically for high-flow events

3. **Data Preparation**:
   - Handle missing values appropriately
   - Consider data normalization strategies
   - Ensure adequate representation of both high and low flow events in training data

4. **Evaluation Metrics**:
   - Focus on standard hydrological metrics: root-mean-square error (RMSE), Nash–Sutcliffe efficiency (NSE), and percent bias (PBIAS)
   - Consider performance across different flow regimes
   - Evaluate performance by season and station