# GRDC Hydro Data Processing and Analysis

## Overview
This notebook processes and analyzes streamflow data from the Global Runoff Data Centre (GRDC) to create a comprehensive dataset of river discharge and runoff. The analysis is particularly useful for hydropower assessment, water resource management, and climate impact studies in Central African regions.

## Required Input Data
This analysis requires three types of input data, all available from open-source repositories. You must download these datasets separately before running this notebook:

1. **Monthly Streamflow Data from GRDC Stations**
   - Source: [Global Runoff Data Centre (GRDC)](https://www.bafg.de/GRDC/)
   - Format: NetCDF files containing monthly discharge measurements
   - Required fields: station metadata (coordinates, catchment area), discharge time series

2. **Hydropower Plant Data from the Global Hydropower Tracker**
   - Source: [Global Energy Monitor](https://globalenergymonitor.org/projects/global-hydropower-tracker/)
   - Format: Excel spreadsheet (.xlsx)
   - Required fields: plant names, locations, capacities, status, river names

3. **River Network Data from HydroRIVERS**
   - Source: [HydroSHEDS](https://www.hydrosheds.org/products/hydrorivers)
   - Format: Shapefile (.shp)
   - Required fields: river geometries, stream order (`ORD_STRA`, `ORD_FLOW`) and long-term average discharge (`DIS_AV_CMS`)

## Expected Outputs
This notebook will generate:
- Filtered and processed discharge and runoff datasets
- Interactive maps showing stations, rivers, and hydropower plants
- Time series plots of runoff evolution
- Monthly discharge patterns for hydropower-relevant stations
- CSV files with processed data for further analysis

## Directory Structure
Before running, ensure you have the following directory structure:
```
hydro/
├── input/
│   ├── grdc_input/         # GRDC station data (NetCDF files)
│   ├── river_input/        # HydroRIVERS shapefiles
│   └── Global-Hydropower-Tracker-April-2025.xlsx
└── output/                 # Will be created if it doesn't exist
```
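The layout above can be checked before running the notebook. The sketch below is a hypothetical helper (`check_inputs` is not part of the notebook or of `utils`). It reports which expected input paths are missing and creates `output/`. It assumes the tracker workbook keeps the April 2025 file name used in this notebook, which may differ for other releases.

```python
from pathlib import Path

def check_inputs(base_dir="hydro",
                 tracker_name="Global-Hydropower-Tracker-April-2025.xlsx"):
    """Hypothetical pre-flight check: list missing inputs and create output/."""
    base = Path(base_dir)
    expected = [
        base / "input" / "grdc_input",    # GRDC NetCDF files
        base / "input" / "river_input",   # HydroRIVERS shapefiles
        base / "input" / tracker_name,    # Global Hydropower Tracker workbook
    ]
    missing = [str(p) for p in expected if not p.exists()]
    # output/ is created on demand, like os.makedirs in the configuration cell
    (base / "output").mkdir(parents=True, exist_ok=True)
    return missing
```

Calling `check_inputs()` from the notebook's working directory returns an empty list when every input is in place.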



In [None]:
import os
import geopandas as gpd
import xarray as xr
import folium
import folium.plugins
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from branca.colormap import LinearColormap
from folium import CircleMarker
from matplotlib.colors import LogNorm
from shapely.geometry import box
import calendar
import numpy as np

from rapidfuzz import process, fuzz
import unicodedata
import re

# Import local utilities
from utils import map_grdc_stationbasins_and_subregions


## User Input Configuration

### Purpose
This section defines the geographical scope of the analysis and sets up the directory structure for input and output files. You need to customize these parameters to match your specific analysis requirements.

### Parameters to Configure
- **countries**: List of countries to include in the analysis. The hydropower plant data will be filtered to include only plants in these countries.
- **base_dir**: Base directory for all hydro-related data. This should point to where your input data is stored.
- **folder_in**: Directory containing all input data files (GRDC data, hydropower tracker, river shapefiles).
- **folder_out**: Directory where all output files will be saved (automatically created if it doesn't exist).

### Before Running
1. Ensure all required input data is downloaded and placed in the correct directories
2. Verify that the country names exactly match those in the `Country/Area 1` column of the Global Hydropower Tracker dataset
3. Check that you have write permissions for the output directory
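Step 2 can be partly automated. The following is a minimal sketch. It assumes you have the tracker's country names as a list of strings (for example `data_hpp['Country/Area 1'].unique()` read from the workbook before the country filter is applied). `check_country_names` is a hypothetical helper that uses the standard-library `difflib` to suggest close spellings for names that do not match exactly. The country list in the example is a toy list.

```python
import difflib

def check_country_names(requested, available):
    """Return {requested_name: [close matches]} for names not found verbatim."""
    available = sorted(set(available))
    problems = {}
    for name in requested:
        if name not in available:
            # get_close_matches ranks candidates by SequenceMatcher ratio (default cutoff 0.6)
            problems[name] = difflib.get_close_matches(name, available, n=3)
    return problems

# Toy example: 'Congo DR' is not spelled the way this toy list spells it
tracker_countries = ["Angola", "Cameroon", "DR Congo", "Gabon"]
print(check_country_names(["Angola", "Congo DR"], tracker_countries))
```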

In [None]:
countries = ['Angola', 'Burundi', 'Cameroon', 'Central African Republic', 'Chad', 'Republic of the Congo', 'DR Congo', 'Equatorial Guinea', 'Gabon']

# Define the base directory and the input/output folders
base_dir = 'hydro'
folder_in  = os.path.join(base_dir, 'input')
folder_out = os.path.join(base_dir, 'output')
os.makedirs(folder_out, exist_ok=True)  # create the output folder if it doesn't exist

## 1. Load Streamflow Data

### Purpose
This section loads monthly streamflow data from GRDC NetCDF files. The data contains time series of river discharge measurements from gauging stations across the selected region, along with station metadata (coordinates, catchment area, etc.).

### Process
1. The code recursively searches through the `folder_grdc` directory for files named 'GRDC-Monthly.nc'
2. Each NetCDF file is loaded as an xarray Dataset
3. All datasets are concatenated into a single dataset for further processing

### Technical Details
- GRDC data is stored in NetCDF format (.nc files)
- The xarray library is used to handle the multidimensional data efficiently
- Each station has a unique identifier ('id') used to merge datasets
- Key variables include:
  - 'runoff_mean': Monthly discharge values (m³/s)
  - 'area': Catchment area (km²)
  - 'geo_x', 'geo_y': Station coordinates (longitude, latitude)
  - 'station_name': Name of the gauging station
  - 'river_name': Name of the river where the station is located

### Troubleshooting
- If no data is found, check that the GRDC files are in the correct directory structure
- Ensure NetCDF files are properly formatted with the expected variables
- Large datasets may require significant memory; consider filtering by region earlier if performance issues occur

In [None]:
folder_grdc = os.path.join(folder_in, 'grdc_input')

# Initialize a list to hold all xarray Datasets
datasets = []

# Walk through all directories under 'grdc_input'
for root, dirs, files in os.walk(folder_grdc):
    if 'GRDC-Monthly.nc' in files:
        file_path = os.path.join(root, 'GRDC-Monthly.nc')
        try:
            ds = xr.open_dataset(file_path)
            datasets.append(ds)
            print(f"Loaded: {file_path}")
        except Exception as e:
            print(f"Failed to load {file_path}: {e}")

# Stop early if nothing was found, since every later cell depends on this data
if not datasets:
    raise FileNotFoundError(f"No 'GRDC-Monthly.nc' files found under {folder_grdc}")

# Concatenate all datasets along the station dimension ('id')
data_discharge = xr.concat(datasets, dim='id')
print(f"Total merged dimensions: {dict(data_discharge.sizes)}")

## 2. Filter Data for Quality Control

### Purpose
This section applies quality control filters to the streamflow data to ensure reliable analysis. Poor quality or incomplete data can lead to misleading results, so it's important to apply appropriate filtering criteria.

### Filtering Steps
1. **Remove Stations with Invalid Catchment Area**
   - Stations with missing or non-positive catchment area values are removed
   - Catchment area is essential for calculating runoff in mm/year
   - These stations are flagged and their metadata is preserved for reference

2. **Remove Stations with Low Runoff Values**
   - Stations whose monthly discharge is below 15 m³/s (or missing) in every month of the record are removed
   - Low values may indicate intermittent streams, measurement errors, or stations on very small tributaries
   - This threshold can be adjusted based on your specific analysis needs
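The low-runoff rule drops a station only when *every* month is below the threshold or missing. A single month at or above 15 m³/s is enough to keep it. The plain-Python sketch below mirrors the xarray expression `((runoff < threshold) | runoff.isnull()).all(dim="time")` on toy series. The station names and values are made up for illustration.

```python
import math

def is_low_flow(series, threshold=15.0):
    """True if every value is below `threshold` or missing (None/NaN)."""
    def low_or_missing(q):
        return q is None or math.isnan(q) or q < threshold
    return all(low_or_missing(q) for q in series)

toy_stations = {
    "station_A": [3.2, 4.1, float("nan"), 12.9],   # always below 15 -> removed
    "station_B": [5.0, 8.0, 16.5, 9.0],            # one month above 15 -> kept
    "station_C": [float("nan")] * 4,               # no data at all -> removed
}
removed = [name for name, q in toy_stations.items() if is_low_flow(q)]
print(removed)  # ['station_A', 'station_C']
```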

### Interpretation
- The filtering process typically removes a subset of stations from the original dataset
- The summary at the end shows how many stations were removed for each reason
- Removed stations are not discarded completely but kept in separate dataframes for reference
- In the visualization section, these removed stations will be shown as separate layers on the map

### Considerations
- Increasing the runoff threshold will result in fewer stations but potentially more reliable data
- Decreasing the threshold will include more stations but may introduce noise from small or intermittent streams
- For specific regional analyses, you might want to adjust these thresholds based on local hydrology

In [None]:
# -------------------------------
# 1. Filter out stations with no valid area
# -------------------------------

# Identify stations with invalid or missing area
invalid_area_mask = (data_discharge["area"] <= 0) | data_discharge["area"].isnull()
removed_area_ids = data_discharge["id"].values[invalid_area_mask.values]

# Extract metadata of removed stations
removed_area_meta = data_discharge.sel(id=removed_area_ids)[["station_name", "geo_x", "geo_y", "area"]]
stations_removed_area_df = removed_area_meta.to_dataframe().reset_index()

# Remove stations with invalid area (isel keeps the original dtypes, unlike where(..., drop=True))
data_discharge = data_discharge.isel(id=~invalid_area_mask.values)

# -------------------------------
# 2. Filter out stations with low runoff values
# -------------------------------

threshold_runoff = 15  # m³/s

# Identify stations where all runoff values are below threshold or NaN
runoff = data_discharge["runoff_mean"]
low_flow_mask = (runoff < threshold_runoff) | runoff.isnull()
stations_to_remove = low_flow_mask.all(dim="time")
removed_ids = data_discharge["id"].values[stations_to_remove.values]

# Extract metadata of removed stations
removed_meta = data_discharge.sel(id=removed_ids)[["station_name", "geo_x", "geo_y"]]
stations_removed_lowflow_df = removed_meta.to_dataframe().reset_index()

# Remove those stations from the dataset
data_discharge = data_discharge.sel(id=~stations_to_remove)

# -------------------------------
# 3. Summary (optional display or export)
# -------------------------------

print(f"Stations removed due to invalid area: {len(stations_removed_area_df)}")
print(f"Stations removed due to low runoff: {len(stations_removed_lowflow_df)}")

## 3. Format Data for Analysis

### Purpose
This section transforms the filtered xarray Dataset into a more analysis-friendly format. The data is restructured to facilitate time series analysis, station comparison, and export to standard file formats.

### Key Transformations
1. **Convert to Pandas DataFrame**
   - The xarray Dataset is converted to a pandas DataFrame for easier manipulation
   - Station metadata (coordinates, names) is merged with the discharge data
   - Time information is extracted into separate year and month columns

2. **Handle Duplicate Station Names**
   - The code checks for and resolves duplicate station names
   - Duplicates are renamed with unique identifiers to prevent data confusion
   - This is important because station names are used as identifiers in later analyses

3. **Create Pivot Tables**
   - Data is pivoted into a wide-format table with years as rows and (station, month) pairs as columns
   - This format is ideal for time series analysis and visualization
   - The pivoted data is saved as a CSV file for external use
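The renaming step only needs to touch names that are shared by more than one station `id`. The following is a minimal sketch on plain `(id, station_name)` tuples, using toy names. It uses the same `_1`, `_2` suffix scheme as the notebook. One difference: numbering here restarts for each name, whereas the notebook's counter runs across all duplicated ids.

```python
from collections import defaultdict

def make_unique_names(pairs):
    """Map each station id to a unique name.

    pairs: iterable of (station_id, station_name), possibly repeated (long format).
    Assumes each id carries a single name. Names used by one id are kept; names
    shared by several ids get a suffix _1, _2, ... per name.
    """
    ids_by_name = defaultdict(list)
    for station_id, name in pairs:
        if station_id not in ids_by_name[name]:
            ids_by_name[name].append(station_id)

    unique = {}
    for name, ids in ids_by_name.items():
        if len(ids) == 1:
            unique[ids[0]] = name
        else:
            for i, station_id in enumerate(sorted(ids), start=1):
                unique[station_id] = f"{name}_{i}"
    return unique

# Toy long-format rows: ids 101 and 205 share the name 'BRIDGE'
rows = [(101, "BRIDGE"), (101, "BRIDGE"), (205, "BRIDGE"), (310, "FERRY")]
print(make_unique_names(rows))
# {101: 'BRIDGE_1', 205: 'BRIDGE_2', 310: 'FERRY'}
```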

### Output Files
- **grdc_discharge_monthly-m3-s.csv**: Monthly discharge data in cubic meters per second
  - Rows represent years
  - Columns represent combinations of station names and months
  - This format allows for easy filtering and aggregation by station or time period

### Data Structure
After processing, the main DataFrame (`data_station`) contains:
- **station_name**: Name of the gauging station
- **id**: Unique station identifier
- **time**: Original timestamp
- **year, month**: Extracted time components
- **Q**: Discharge value in m³/s
- **geo_x, geo_y**: Station coordinates
- **area**: Catchment area in km²
- **station_label**: Formatted label combining name and coordinates

### Troubleshooting
- If you encounter issues with duplicate station names, check the original data source
- Missing values in the pivot table indicate months with no data for that station
- Large datasets may cause memory issues; consider filtering to specific regions if needed

In [None]:
def checking_duplicates_grdc(df):
    """
    Check for duplicated (year, station_name, month) entries in the DataFrame.

    This function identifies records that have the same year, station_name, and month,
    which should be unique in a properly formatted dataset.

    Parameters:
        df (pandas.DataFrame): DataFrame containing at least the columns 'year', 'station_name', and 'month'

    Returns:
        list: List of station names that have duplicate entries
    """
    dupes = (
        df.groupby(["year", "station_name", "month"])
        .size()
        .reset_index(name='count')
        .query("count > 1")
    )

    print(f"Found {len(dupes)} duplicated (year, station_name, month) entries")

    problem_stations = dupes['station_name'].unique()
    print(f'Duplicated stations: {problem_stations}')

    # Merge the duplicate keys back into df to see full rows
    duplicated_rows = df.merge(dupes[["year", "station_name", "month"]], on=["year", "station_name", "month"])
    #display(duplicated_rows.sort_values(["station_name", "year", "month"]))

    return problem_stations

In [None]:
var_name = 'runoff_mean'

meta = data_discharge[["station_name", "geo_x", "geo_y"]].to_dataframe().reset_index()
area = data_discharge["area"].to_dataframe().reset_index()

# Convert to DataFrame
data_station = data_discharge[var_name].to_dataframe(name="Q").reset_index()

# Merge metadata into the main DataFrame
data_station = data_station.merge(meta, on="id")
data_station = data_station.merge(area, on="id")

# Add year and month columns
data_station["year"] = data_station["time"].dt.year
data_station["month"] = data_station["time"].dt.month

# Check whether any station_name is shared by more than one station id
# (the long-format table always repeats names across months, so compare ids per name)
if (data_station.groupby("station_name")["id"].nunique() > 1).any():
    problem_stations = checking_duplicates_grdc(data_station)

    # Get the station IDs behind the duplicated names
    problem_ids = data_station[data_station["station_name"].isin(problem_stations)]["id"].unique()

    # Assign unique station names (name + running suffix) for problematic IDs only
    first_names = data_station.drop_duplicates("id").set_index("id")["station_name"]
    rename_map = {id_: f"{first_names[id_]}_{i+1}" for i, id_ in enumerate(problem_ids)}

    # Apply the renaming by `id`, keeping all other names unchanged
    data_station["station_name"] = data_station["id"].map(rename_map).fillna(data_station["station_name"])

# Create a unique label for each station using name + coordinates (optional)
data_station["station_label"] = data_station["station_name"].str.strip() + " (" + data_station["geo_y"].round(2).astype(str) + ", " + data_station["geo_x"].round(2).astype(str) + ")"

# Pivot to wide format: year as index, MultiIndex (station_name, month) as columns
data_station_pivot = data_station.pivot(index="year", columns=["station_name", "month"], values="Q")
data_station_pivot.dropna(axis=1, how='all', inplace=True)
data_station_pivot.sort_index(ascending=True, axis=1, inplace=True)
data_station_pivot.to_csv(os.path.join(folder_out, 'grdc_discharge_monthly-m3-s.csv'), index=True)

## 4. Calculate Runoff Data

### Purpose
This section converts river discharge measurements (m³/s) into runoff depth (mm/year). Runoff depth normalizes flow by catchment area, allowing for direct comparison between watersheds of different sizes. This is crucial for:
- Comparing water yield across different basins
- Assessing regional water resources
- Evaluating potential for hydropower development
- Analyzing climate change impacts on water availability

### Conversion Formula
Runoff data (mm/year) is calculated from monthly discharge data (m³/s) using the formula:

```
runoff_mm = (Σ(Q_monthly_avg × 86400 × days_in_month)) / area_m2 × 1000
```

Where:
- **Q** is discharge in m³/s (cubic meters per second)
- **86400** is seconds per day (60 × 60 × 24)
- **days_in_month** accounts for varying month lengths (28-31 days)
- **area_m2** is the catchment area in square meters
- **1000** converts meters to millimeters
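Here is a worked check of the formula. Take a constant 1,000 m³/s over a 365-day year from a 100,000 km² catchment. The annual volume is 1,000 × 86,400 × 365 = 3.1536 × 10¹⁰ m³. Spread over 10¹¹ m², that is 0.31536 m, or about 315 mm/year. The same arithmetic in plain Python follows (values are illustrative only):

```python
import calendar

def annual_runoff_mm(monthly_q_m3s, area_km2, year):
    """Annual runoff depth (mm) from 12 monthly mean discharges (m³/s)."""
    if len(monthly_q_m3s) != 12:
        raise ValueError("expected 12 monthly values")
    volume_m3 = sum(
        q * 86400 * calendar.monthrange(year, month)[1]   # seconds per day x days in month
        for month, q in enumerate(monthly_q_m3s, start=1)
    )
    area_m2 = area_km2 * 1e6
    return volume_m3 / area_m2 * 1000   # m -> mm

print(round(annual_runoff_mm([1000.0] * 12, 100_000, 2023), 2))  # 315.36
print(round(annual_runoff_mm([1000.0] * 12, 100_000, 2024), 2))  # 316.22 (leap year)
```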

### Process
1. The code first filters for complete years (with all 12 months of data)
2. It then applies the conversion formula to calculate annual runoff for each station
3. Results are saved as a CSV file with years as rows and stations as columns
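The completeness rule in step 1 keeps a station-year only when all 12 monthly values are present. Years with even one missing month are dropped rather than gap-filled. The stdlib sketch below applies the rule to toy `(station, year, month, Q)` records. It counts distinct valid months, which is a slightly stricter assumption than counting non-NaN rows.

```python
import math
from collections import defaultdict

def full_years(records):
    """Return the set of (station, year) pairs with 12 non-missing monthly values."""
    valid_months = defaultdict(set)
    for station, year, month, q in records:
        if q is not None and not math.isnan(q):
            valid_months[(station, year)].add(month)
    return {key for key, months in valid_months.items() if len(months) == 12}

records = [("S1", 2000, m, 50.0) for m in range(1, 13)]    # complete year
records += [("S1", 2001, m, 50.0) for m in range(1, 12)]   # December missing
records += [("S1", 2002, m, float("nan") if m == 6 else 50.0) for m in range(1, 13)]  # June is NaN
print(full_years(records))  # {('S1', 2000)}
```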

### Output Files
- **grdc_runoff_mm-year.csv**: Annual runoff depth in mm/year
  - This standardized unit allows for comparison across different watersheds
  - Typical values range from 100-2000 mm/year depending on climate and geography

### Interpretation
- Higher runoff values (>1000 mm/year) typically indicate wet regions with high rainfall
- Lower values (<300 mm/year) suggest drier conditions or high evapotranspiration
- Sudden changes in runoff patterns may indicate:
  - Climate shifts
  - Land use changes in the catchment
  - Dam construction upstream
  - Data quality issues

### Troubleshooting
- If runoff values seem unreasonably high or low, check:
  - Catchment area values (common source of error)
  - Unit conversion factors
  - Completeness of monthly data


In [None]:
def filter_full_years(df):
    """
    Keep only (station, year) pairs with 12 valid (non-NaN) monthly values.
    Returns the filtered DataFrame; the number of dropped rows is printed.
    """
    original_len = len(df)

    # Count valid (non-NaN) months per station-year
    valid_counts = (
        df.groupby(['station_name', 'year'])['Q']
        .count()
        .reset_index(name='valid_months')
    )

    # Only keep those with all 12 months
    full_years = valid_counts[valid_counts['valid_months'] == 12]

    # Merge to filter original DataFrame
    df_filtered = df.merge(full_years[['station_name', 'year']], on=['station_name', 'year'])

    removed_rows = original_len - len(df_filtered)
    print(f"Removed {removed_rows} rows; kept {len(df_filtered)} rows from complete (12-month) years.")

    return df_filtered

def discharge_to_runoff(df):
    """
    Convert monthly discharge (m³/s) into annual runoff (mm/year) at the station level.

    This function performs the following steps:
    1. Calculates the number of days in each month
    2. Converts catchment area from km² to m²
    3. Computes monthly water volume (m³) from discharge (m³/s)
    4. Sums monthly volumes for each station-year
    5. Divides by catchment area to get runoff depth (m)
    6. Converts from m to mm by multiplying by 1000

    The equation used is:
        runoff_mm = (Σ(Q_monthly_avg × 86400 × days_in_month)) / area_m2 × 1000

    Where:
        - Q is discharge in m³/s
        - 86400 is seconds per day
        - days_in_month accounts for monthly totals
        - area is the catchment area in m²
        - 1000 converts meters to millimeters

    Parameters:
        df (pandas.DataFrame): DataFrame containing monthly discharge data with columns:
            - 'station_name': Name of the gauging station
            - 'year': Year of the measurement
            - 'time': Datetime of the measurement
            - 'Q': Discharge in m³/s
            - 'area': Catchment area in km²

    Assumes:
        - Input DataFrame has one row per station/month
        - Years are complete (12 months per station)
        - Area is in km²

    Returns:
        pandas.DataFrame: DataFrame with columns:
            - 'station_name': Name of the gauging station
            - 'year': Year of the measurement
            - 'runoff_mm_year': Annual runoff in mm/year
    """
    df = df.copy()

    # Add number of days in each month
    df['days_in_month'] = pd.to_datetime(df['time']).dt.days_in_month

    # Convert area to m²
    df['area_m2'] = df['area'] * 1e6

    # Compute volume in m³ for each month
    df['volume_m3'] = df['Q'] * 86400 * df['days_in_month']

    # Sum monthly volumes per station-year
    runoff_by_year = (
        df.groupby(['station_name', 'year'])
        .apply(lambda x: x['volume_m3'].sum() / x['area_m2'].iloc[0] * 1000, include_groups=False)  # m to mm
        .reset_index(name='runoff_mm_year')
    )

    return runoff_by_year

In [None]:
# Convert discharge data (m3-s) to runoff in (mm-year)
data_station_filtered = filter_full_years(data_station)
data_runoff = discharge_to_runoff(data_station_filtered)
# Save to CSV (years as rows, stations as columns)
runoff_table = data_runoff.round(0).pivot(index="year", columns="station_name", values="runoff_mm_year")
runoff_table.to_csv(os.path.join(folder_out, 'grdc_runoff_mm-year.csv'), index=True)

location = data_station[['station_name', 'geo_x', 'geo_y']].drop_duplicates().reset_index(drop=True)
# Merge with runoff_by_year data
data_runoff = data_runoff.merge(location, on='station_name')

## 5. Add Hydropower Plant Data & Associate with River Data

### Purpose
This section integrates hydropower plant (HPP) data with the streamflow analysis. It links each hydropower plant to the most relevant gauging station, either on the same river or by proximity. This connection is crucial for:
- Assessing water availability for existing and planned hydropower plants
- Evaluating seasonal flow patterns at hydropower locations
- Analyzing long-term trends that might affect energy production
- Supporting feasibility studies for new hydropower development

### Process
1. **Load Hydropower Plant Data**
   - Data is loaded from the Global Hydropower Tracker Excel file
   - Plants are filtered to include only those in the countries of interest
   - Basic validation checks ensure required columns are present

2. **Match Hydropower Plants to Rivers**
   - River names from HPP data are cleaned and standardized
   - Fuzzy matching algorithms find the best matches between HPP rivers and station rivers
   - A match score indicates the confidence level of each river name match

3. **Associate HPPs with Gauging Stations**
   - For HPPs with matched rivers, the station on the same river is preferred
   - For unmatched rivers, the nearest station by geographic distance is used
   - Each HPP is tagged with its nearest station and whether it's on the same river
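The river-matching step can be sketched with the standard library. This is a stand-in, not the notebook's method. The notebook scores names with rapidfuzz's `fuzz.token_set_ratio`. The sketch instead uses `difflib.SequenceMatcher.ratio()` scaled to 0-100, so its scores (and the effect of the 85 threshold) will differ. The river names in the example are illustrative.

```python
import difflib
import re
import unicodedata

def clean(name):
    """Strip accents, drop the words RIVER/RIVIERE/R (with an optional dot), normalise spacing."""
    name = unicodedata.normalize("NFKD", name).encode("ascii", "ignore").decode("ascii")
    name = re.sub(r"\b(RIVER|RIVIERE|R)\b\.?", "", name, flags=re.IGNORECASE)
    return re.sub(r"\s+", " ", name).strip().upper()

def best_match(query, candidates, threshold=85):
    """Return (candidate, score 0-100) if score >= threshold, else (None, score)."""
    q = clean(query)
    scored = [
        (c, 100 * difflib.SequenceMatcher(None, q, clean(c)).ratio())
        for c in candidates
    ]
    cand, score = max(scored, key=lambda pair: pair[1])
    return (cand if score >= threshold else None), score

station_rivers = ["CONGO", "OUBANGUI", "SANGHA", "KASAI"]
print(best_match("Rivière Sangha", station_rivers))  # ('SANGHA', 100.0)
print(best_match("Ogooué", station_rivers))          # no candidate reaches 85
```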

### Output Data
The resulting dataset includes:
- All hydropower plants in the selected countries
- The nearest gauging station for each plant
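- Distance to the station in meters (approximate: computed in Web Mercator, EPSG:3857, which overstates distances by roughly 1/cos(latitude), i.e. by up to several percent across the selected countries)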
- Flag indicating whether the plant and station are on the same river
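Distances computed in Web Mercator (EPSG:3857), as `associate_hpp_station` does, are stretched by roughly 1/cos(latitude). A great-circle (haversine) distance on a spherical Earth is one way to cross-check them. The sketch below assumes a mean Earth radius of 6,371 km. For ellipsoidal distances, a geodesic library such as `pyproj.Geod` is more accurate.

```python
import math

EARTH_RADIUS_M = 6_371_000  # mean Earth radius (spherical approximation)

def haversine_m(lon1, lat1, lon2, lat2):
    """Great-circle distance in metres between two lon/lat points (degrees)."""
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = phi2 - phi1
    dlmb = math.radians(lon2 - lon1)
    a = math.sin(dphi / 2) ** 2 + math.cos(phi1) * math.cos(phi2) * math.sin(dlmb / 2) ** 2
    return 2 * EARTH_RADIUS_M * math.asin(math.sqrt(a))

def mercator_scale(lat):
    """Approximate factor by which EPSG:3857 distances exceed true distances at `lat`."""
    return 1 / math.cos(math.radians(lat))

# One degree of longitude along the equator (toy points)
print(round(haversine_m(15.0, 0.0, 16.0, 0.0)))  # 111195
print(round(mercator_scale(-15.0), 3))           # 1.035
```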

### Interpretation
- HPPs matched to stations on the same river (same_river = True) provide more reliable flow estimates
- For plants with same_river = False, flow patterns should be used with caution as they represent nearby but different watersheds
- Distance metrics help assess the reliability of the station-plant association

### Considerations
- River name matching is challenging due to spelling variations, translations, and naming conventions
- The fuzzy matching threshold (currently 85%) can be adjusted to be more or less strict
- Geographic proximity doesn't always indicate hydrological similarity, especially in mountainous regions

In [None]:
path_hpp = os.path.join(folder_in, 'Global-Hydropower-Tracker-April-2025.xlsx')

# Columns needed downstream (defined before the try block so the fallback below can use them)
required_columns = ['Country/Area 1', 'Project Name', 'River / Watercourse', 'Capacity (MW)', 'Status', 'Latitude', 'Longitude']

try:
    if not os.path.exists(path_hpp):
        raise FileNotFoundError(f"Hydropower data file not found: {path_hpp}")

    data_hpp = pd.read_excel(path_hpp, sheet_name='Data', header=0, index_col=None)

    # Check that the required columns exist
    missing_columns = [col for col in required_columns if col not in data_hpp.columns]

    if missing_columns:
        raise ValueError(f"Missing required columns in hydropower data: {missing_columns}")

    # Filter for countries of interest (.copy() so that adding columns later does not warn)
    data_hpp = data_hpp.loc[data_hpp['Country/Area 1'].isin(countries), :].copy()

    print(f"Loaded {len(data_hpp)} hydropower plants in the {len(countries)} selected countries")

    # Check if we have any data after filtering
    if len(data_hpp) == 0:
        print("Warning: No hydropower plants found for the specified countries")

    display(data_hpp.head())

except Exception as e:
    print(f"Error loading hydropower data: {e}")
    # Create an empty DataFrame with the required columns to avoid breaking the rest of the code
    data_hpp = pd.DataFrame(columns=required_columns)
    print("Created empty hydropower DataFrame to continue processing")

In [None]:
# Add river name to the filtered station data. Merge on the station id, because
# duplicate station names may have been renamed in Section 3.
meta_df = data_discharge[["river_name"]].to_dataframe().reset_index()[["id", "river_name"]]

data_station_filtered_river = data_station_filtered.merge(meta_df, on="id", how="left")

In [None]:
def clean_river_name(name):
    """
    Clean and standardize river names for better matching.

    This function:
    1. Handles NaN/None values
    2. Normalizes Unicode characters and removes accents
    3. Removes common river-related words (RIVER, RIVIERE, etc.)
    4. Standardizes spacing and converts to uppercase

    Parameters:
        name (str): The river name to clean

    Returns:
        str: The cleaned and standardized river name
    """
    if pd.isna(name):
        return ""

    # Normalize unicode and remove accents (e.g. é → e)
    name = unicodedata.normalize('NFKD', str(name)).encode('ASCII', 'ignore').decode('ascii')

    # Remove common river words (accents are already stripped, so RIVIERE also covers RIVIÈRE),
    # including abbreviations such as "R." with their trailing dot, then clean up spacing
    name = re.sub(r"\b(RIVER|RIVIERE|R)\b\.?", "", name, flags=re.IGNORECASE)
    name = re.sub(r"\s+", " ", name).strip().upper()

    return name

# Prepare cleaned lists
station_rivers_raw = data_station_filtered_river['river_name'].dropna().unique()
station_rivers_cleaned = {clean_river_name(r): r for r in station_rivers_raw}

matches = []

# Iterate over all HPPs (row-wise)
for _, row in data_hpp.dropna(subset=["River / Watercourse"]).iterrows():
    hpp_river_orig = row["River / Watercourse"]
    hpp_river_clean = clean_river_name(hpp_river_orig)

    # Find best match among station rivers
    match, score, _ = process.extractOne(
        hpp_river_clean, station_rivers_cleaned.keys(), scorer=fuzz.token_set_ratio
    )

    if score >= 85:
        matched_station_river = station_rivers_cleaned[match]
        stations = data_station_filtered_river[
            data_station_filtered_river["river_name"] == matched_station_river
        ]["station_name"].tolist()
    else:
        matched_station_river = None
        stations = []

    matches.append({
        "hpp_name": row["Project Name"],
        "hpp_river": hpp_river_orig,
        "matched_river": matched_station_river,
        "score": score,
        "stations": set(stations),
        "hpp_capacity": row["Capacity (MW)"],
        "hpp_status": row["Status"]
    })

# Convert to DataFrame for easy display or export
matches_df = pd.DataFrame(matches)
matches_df.to_csv(os.path.join(folder_out, 'hpp_grdc_hydro_matches.csv'), index=False)


In [None]:
# Associate each hydropower plant with a gauging station:
# prefer the nearest station on the matched river, else the nearest station overall
def associate_hpp_station(data_hpp, data_station, matches_df):
    """
    Return a GeoDataFrame of HPPs with the columns 'nearest_station',
    'distance_to_station_m' and 'same_river'.

    Distances are computed in Web Mercator (EPSG:3857) and are therefore
    approximate (overestimated by roughly 1/cos(latitude)).
    """
    # Step 1: Build GeoDataFrames (one point per station, not one per station-month)
    gdf_hpp = gpd.GeoDataFrame(
        data_hpp.copy(),
        geometry=gpd.points_from_xy(data_hpp["Longitude"], data_hpp["Latitude"]),
        crs="EPSG:4326"
    )

    stations = data_station[["station_name", "geo_x", "geo_y"]].drop_duplicates(subset="station_name")
    gdf_station = gpd.GeoDataFrame(
        stations,
        geometry=gpd.points_from_xy(stations["geo_x"], stations["geo_y"]),
        crs="EPSG:4326"
    ).set_index("station_name")

    # Step 2: Project both to a metric CRS (meters)
    gdf_hpp = gdf_hpp.to_crs(epsg=3857)
    gdf_station = gdf_station.to_crs(epsg=3857)

    # Step 3: Map each HPP name to the set of stations on its matched river
    station_lookup = {
        row["hpp_name"]: row["stations"]
        for _, row in matches_df.iterrows()
        if row["stations"]
    }

    # Step 4: For each HPP, pick the nearest candidate station
    # (candidates = stations on the matched river if any, otherwise all stations)
    nearest, dist, same_river = [], [], []
    for _, hpp in gdf_hpp.iterrows():
        candidates = station_lookup.get(hpp["Project Name"], set())
        on_river = bool(candidates)
        pool = gdf_station.loc[sorted(candidates)] if on_river else gdf_station
        distances = pool.geometry.distance(hpp.geometry)
        nearest.append(distances.idxmin())
        dist.append(distances.min())
        same_river.append(on_river)

    # Step 5: Attach the results to the HPP GeoDataFrame
    gdf_hpp["nearest_station"] = nearest
    gdf_hpp["distance_to_station_m"] = dist
    gdf_hpp["same_river"] = same_river

    return gdf_hpp

gdf_hpp = associate_hpp_station(data_hpp, data_station_filtered, matches_df)

# Copy the new columns back by index (a merge on the non-unique 'Project Name'
# would duplicate rows for projects that share a name)
new_cols = ["nearest_station", "distance_to_station_m", "same_river"]
data_hpp[new_cols] = gdf_hpp[new_cols]

## 6. Visualize All Data on Interactive Maps

### Purpose
This section creates comprehensive interactive maps that integrate all the processed data: gauging stations, river networks, and hydropower plants. These visualizations serve multiple purposes:
- Providing a spatial overview of water resources and infrastructure
- Identifying relationships between rivers, stations, and hydropower plants
- Enabling exploration of regional patterns in runoff and discharge
- Supporting decision-making for water resource management and energy planning

### Map Features
The interactive map includes several key components:

1. **Base Layers**
   - Multiple base maps (light, dark, satellite imagery, OpenStreetMap)
   - Toggle controls to switch between different background maps

2. **River Network Layer**
   - Rivers colored by discharge magnitude
   - Interactive tooltips showing river names and average discharge
   - Filtered to show only major rivers (Strahler order `ORD_STRA` ≥ 6, average discharge `DIS_AV_CMS` > 50 m³/s, flow order `ORD_FLOW` ≤ 5)

3. **Gauging Station Layer**
   - Stations sized by average runoff (larger circles = higher runoff)
   - Stations colored by years of available data (more years = brighter colors)
   - Popup information showing detailed station metadata and statistics
   - Separate layer for stations removed during filtering

4. **Hydropower Plant Layer**
   - Plants represented as squares sized by capacity (MW)
   - Color-coded by status (operating, construction, announced, etc.)
   - Detailed popups showing plant information and associated station data
   - Information on whether the plant is on the same river as its associated station

### Interactive Tools
The map includes several interactive features:
- Layer controls to toggle visibility of different data elements
- Fullscreen mode for detailed exploration
- Measurement tool for calculating distances between features
- Popups with detailed information on each map element

### Output Files
- **station_runoff_map.html**: Basic map without river data
- **station_runoff_map_with_rivers.html**: Complete map with river network (if river data is available)

### Viewing and Sharing
- Open the HTML files in any modern web browser
- All map data is embedded in the HTML file, but an internet connection is still needed to load the base-map tiles and the Leaflet JavaScript/CSS libraries
- Maps can be shared as standalone files
- For large datasets, file sizes may be substantial

### Troubleshooting
- If rivers don't appear, check that the HydroRIVERS shapefile path is correct
- Large river datasets may slow down map rendering
- If the map appears blank, try a different web browser
- For memory issues, reduce the number of rivers by increasing the stream order threshold

In [None]:
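import os

import geopandas as gpd
from shapely.geometry import box


# Function to clip rivers to the bounding box of runoff stations
def clip_rivers_to_stations(rivers_path, df_runoff, ord_stra=6, min_discharge=50, max_ord_flow=5, pad_deg=1.0):
    """
    Load HydroRIVERS reaches around the gauging stations and keep only the main rivers.

    Parameters:
        rivers_path (str): Path to the HydroRIVERS shapefile, downloaded in shapefile format
            from https://www.hydrosheds.org/products/hydrorivers (e.g. HydroRIVERS_v10_af.shp).
        df_runoff (DataFrame): Station data with "geo_x" (longitude) and "geo_y" (latitude).
        ord_stra (int): Minimum Strahler stream order (ORD_STRA) to keep.
        min_discharge (float): Minimum long-term average discharge (DIS_AV_CMS, m³/s).
        max_ord_flow (int): Maximum flow order (ORD_FLOW); lower values are larger rivers.
        pad_deg (float): Padding around the station bounding box, in degrees.

    Returns:
        GeoDataFrame: River reaches clipped to the padded station bounding box.
    """
    if not os.path.exists(rivers_path):
        raise FileNotFoundError(f"HydroRIVERS shapefile not found: {rivers_path}")

    # Step 1: Build the region of interest (ROI) from the station coordinates
    stations_gdf = gpd.GeoDataFrame(
        df_runoff,
        geometry=gpd.points_from_xy(df_runoff["geo_x"], df_runoff["geo_y"]),
        crs="EPSG:4326"
    )
    roi_bounds = stations_gdf.total_bounds  # [minx, miny, maxx, maxy]
    roi_polygon = box(*roi_bounds).buffer(pad_deg)  # padding in degrees (corners become rounded)
    roi_gdf = gpd.GeoDataFrame(geometry=[roi_polygon], crs="EPSG:4326")

    # Step 2: Read only reaches intersecting the ROI's bounds (HydroRIVERS is in WGS84 degrees),
    # which is much faster than loading a whole continental file
    rivers_gdf = gpd.read_file(rivers_path, bbox=tuple(roi_polygon.bounds))

    # Step 3: Keep large rivers only: high discharge, high Strahler order, low flow order
    rivers_gdf = rivers_gdf[
        (rivers_gdf["DIS_AV_CMS"] > min_discharge)
        & (rivers_gdf["ORD_STRA"] >= ord_stra)
        & (rivers_gdf["ORD_FLOW"] <= max_ord_flow)
    ]

    # Step 4: Clip to the exact padded ROI polygon
    return gpd.clip(rivers_gdf, roi_gdf)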
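The region of interest is the stations' bounding box widened by a fixed margin in degrees. A minimal pure-Python sketch of that computation follows, assuming plain lists of longitudes and latitudes; `padded_bounds` is a hypothetical helper. Note that `box(...).buffer(1.0)` in the notebook rounds the corners, so this rectangle slightly over-covers them, and that one degree of latitude is roughly 111 km while a degree of longitude shrinks toward the poles.

```python
def padded_bounds(lons, lats, pad_deg=1.0):
    """(minx, miny, maxx, maxy) of the points, widened by pad_deg degrees on every side."""
    return (min(lons) - pad_deg, min(lats) - pad_deg,
            max(lons) + pad_deg, max(lats) + pad_deg)


# Illustrative station coordinates in degrees (not real GRDC stations)
lons = [12.5, 15.3, 18.1]
lats = [-4.2, -1.0, 2.7]
print(padded_bounds(lons, lats))
```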


In [None]:
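# This cell can take a while to run (several minutes), depending on the number of stations and rivers
import os

import folium
import folium.plugins  # Fullscreen and MeasureControl
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from branca.colormap import LinearColormap
from matplotlib.colors import LogNorm


def make_maps(df_runoff, data_hpp=None, rivers_path=None, folder_out=folder_out):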
    """
    Create interactive maps showing runoff stations, rivers, and hydropower plants.

    This function generates a Folium map with multiple layers:
    1. River network colored by discharge (if rivers_path is provided)
    2. Gauging stations with size based on runoff and color based on years of data
    3. Hydropower plants with size based on capacity and color based on status

    Parameters:
        df_runoff (DataFrame): DataFrame containing runoff data with station coordinates
        data_hpp (DataFrame, optional): DataFrame containing hydropower plant data
        rivers_path (str, optional): Path to the HydroRIVERS shapefile
        folder_out (str): Output folder for saving the map

    Returns:
        folium.Map: The interactive map object
    """
    # Aggregate per station: mean annual runoff and number of years with data
    agg = (
        df_runoff.groupby(["station_name", "geo_x", "geo_y"])
          .agg(
              avg_runoff=("runoff_mm_year", "mean"),
              n_years=("year", "count")
          )
          .reset_index()
    )

    # Create base map centered on the region
    center_lat = agg["geo_y"].mean()
    center_lon = agg["geo_x"].mean()

    # Create map with improved base layer options
    m = folium.Map(
        location=[center_lat, center_lon], 
        zoom_start=5, 
        tiles="CartoDB positron"
    )

    # Add additional base layers for better visualization options
    folium.TileLayer('CartoDB dark_matter', name='Dark Map').add_to(m)
    folium.TileLayer('OpenStreetMap', name='OpenStreetMap').add_to(m)
    folium.TileLayer(
        tiles='https://server.arcgisonline.com/ArcGIS/rest/services/World_Imagery/MapServer/tile/{z}/{y}/{x}',
        attr='Esri',
        name='Satellite'
    ).add_to(m)

    # Add rivers from HydroRIVERS
    name = f'station_runoff_map.html'
    if rivers_path is not None:
        try:
            rivers_clipped = clip_rivers_to_stations(rivers_path, df_runoff)

            # Create a feature group for rivers to allow toggling
            rivers_group = folium.FeatureGroup(name="Rivers", show=True)

            colormap = plt.cm.viridis
            norm = mcolors.Normalize(
                vmin=rivers_clipped["DIS_AV_CMS"].min(),
                vmax=rivers_clipped["DIS_AV_CMS"].max())

            # Add rivers to the map with color based on discharge
            for _, row in rivers_clipped.iterrows():
                discharge = row["DIS_AV_CMS"]
                color = mcolors.to_hex(colormap(norm(discharge)))

                folium.GeoJson(
                    row["geometry"],
                    style_function=lambda feature, color=color: {
                        "color": color,
                        "weight": 2,
                        "opacity": 0.8
                    },
                    # HydroRIVERS v1.0 has no river-name attribute, so show the reach ID (HYRIV_ID)
                    tooltip=f"Reach ID: {row.get('HYRIV_ID', 'n/a')}<br>Discharge: {discharge:.1f} m³/s"
                ).add_to(rivers_group)

            # Add the rivers group to the map
            rivers_group.add_to(m)

            # Legend ticks: 5 values from min to max, rounded to two significant figures
            # (e.g. 72.4 → 72, 1502 → 1500) and coloured with the same norm as the river lines
            vmin = rivers_clipped["DIS_AV_CMS"].min()
            vmax = rivers_clipped["DIS_AV_CMS"].max()

            ticks_raw = np.linspace(vmin, vmax, 5)
            ticks = sorted({float(round(t, -int(np.floor(np.log10(t))) + 1)) for t in ticks_raw})
            colors = [mcolors.to_hex(colormap(norm(t))) for t in ticks]

            # Add high-contrast legend
            discharge_colormap = LinearColormap(
                colors=colors,
                index=ticks,
                vmin=ticks[0],
                vmax=ticks[-1],
                caption="Avg River Discharge (m³/s)"
            )
            discharge_colormap.add_to(m)

            name = f'station_runoff_map_with_rivers.html'

        except Exception as e:
            print(f"Error loading river data: {e}")
            print("Continuing without river data")

    # Create a feature group for active stations
    stations_group = folium.FeatureGroup(name="Gauging Stations", show=True)

    # Add gauging stations (colour = years of data on a log scale, size = mean annual runoff)
    if not agg.empty:
        norm_years = LogNorm(vmin=agg["n_years"].min(), vmax=agg["n_years"].max())
        colormap = plt.cm.viridis  # or any other colormap

        for _, row in agg.iterrows():
            # Normalize color from colormap
            rgba = colormap(norm_years(row["n_years"]))
            hex_color = mcolors.to_hex(rgba)

            # Scale size (radius) based on average runoff
            radius = max(4, (row["avg_runoff"] / agg["avg_runoff"].max()) * 20)

            # Add circle marker with enhanced popup
            popup_content = f"""
            <div style="font-family: Arial; min-width: 200px;">
                <h4 style="margin-bottom: 10px;">{row['station_name']}</h4>
                <table style="width: 100%; border-collapse: collapse;">
                    <tr>
                        <td style="padding: 3px;"><b>Avg Runoff:</b></td>
                        <td style="padding: 3px;">{row['avg_runoff']:.0f} mm/year</td>
                    </tr>
                    <tr>
                        <td style="padding: 3px;"><b>Years of Data:</b></td>
                        <td style="padding: 3px;">{row['n_years']}</td>
                    </tr>
                    <tr>
                        <td style="padding: 3px;"><b>Coordinates:</b></td>
                        <td style="padding: 3px;">{row['geo_y']:.4f}, {row['geo_x']:.4f}</td>
                    </tr>
                </table>
            </div>
            """

            folium.CircleMarker(
                location=[row["geo_y"], row["geo_x"]],
                radius=radius,
                color=hex_color,
                fill=True,
                fill_opacity=0.8,
                popup=folium.Popup(popup_content, max_width=300)
            ).add_to(stations_group)

        # Add stations group to map
        stations_group.add_to(m)

        # FeatureGroup for removed stations
        removed_stations_group = folium.FeatureGroup(name="Removed Stations", show=False)

        # Stations removed because the catchment area is missing or ≤ 0 (red markers; read from a global DataFrame)
        if 'stations_removed_area_df' in globals() and stations_removed_area_df is not None and not stations_removed_area_df.empty:
            for _, row in stations_removed_area_df.iterrows():
                folium.RegularPolygonMarker(
                    location=[row["geo_y"], row["geo_x"]],
                    number_of_sides=4,
                    radius=5,
                    rotation=45,
                    color="red",
                    fill=True,
                    fill_color="red",
                    fill_opacity=1.0,
                    popup=f"<b>{row['station_name']}</b><br>Removed: Area ≤ 0"
                ).add_to(removed_stations_group)

        # Stations removed because the flow is below the threshold (orange markers; read from a global DataFrame)
        if 'stations_removed_lowflow_df' in globals() and stations_removed_lowflow_df is not None and not stations_removed_lowflow_df.empty:
            for _, row in stations_removed_lowflow_df.iterrows():
                folium.RegularPolygonMarker(
                    location=[row["geo_y"], row["geo_x"]],
                    number_of_sides=4,
                    radius=5,
                    rotation=45,
                    color="orange",
                    fill=True,
                    fill_color="orange",
                    fill_opacity=1.0,
                    popup=f"<b>{row['station_name']}</b><br>Removed: Flow < threshold"
                ).add_to(removed_stations_group)

        # Add group to map
        removed_stations_group.add_to(m)

        # Combined legend for station colour and size. Swatch colours and circle radii are computed
        # with the same LogNorm and radius formula as the markers, so the legend matches the map.
        years_vals = np.percentile(agg["n_years"], [0, 50, 100]).round(0).astype(int)
        runoff_vals = np.percentile(agg["avg_runoff"], [0, 50, 100]).round(0)
        year_colors = [mcolors.to_hex(colormap(norm_years(v))) for v in years_vals]
        runoff_radii = [max(4, (v / agg["avg_runoff"].max()) * 20) for v in runoff_vals]

        combined_legend_html = f'''
        <div style="
            position: fixed;
            bottom: 50px; left: 50px; width: 260px;
            background-color: white;
            border: 2px solid grey;
            z-index: 9999;
            font-size: 13px;
            padding: 12px;">

            <b style="font-size:14px;">Gauging Stations</b><br><br>

            <b>Years of Available Data (Color, log scale)</b><br>
            <i style="background: {year_colors[0]}; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[0]}<br>
            <i style="background: {year_colors[1]}; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[1]}<br>
            <i style="background: {year_colors[2]}; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[2]}<br><br>

            <b>Avg Runoff (mm/year)</b><br>
            <svg width="140" height="134">
              <circle cx="30" cy="22" r="{runoff_radii[0]:.1f}" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="60" y="27" font-size="12">{int(runoff_vals[0])}</text>
              <circle cx="30" cy="66" r="{runoff_radii[1]:.1f}" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="60" y="71" font-size="12">{int(runoff_vals[1])}</text>
              <circle cx="30" cy="110" r="{runoff_radii[2]:.1f}" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="60" y="115" font-size="12">{int(runoff_vals[2])}</text>
            </svg>

            <b>Removed Stations</b><br>
            <i style="background: red; width: 10px; height: 10px;
               display: inline-block; margin-right: 6px;"></i> Area missing or ≤ 0<br>
            <i style="background: orange; width: 10px; height: 10px;
               display: inline-block; margin-right: 6px;"></i> Low flow
        </div>
        '''

        # Add legend to the map
        m.get_root().html.add_child(folium.Element(combined_legend_html))

    # Include hydropower plants
    if data_hpp is not None and not data_hpp.empty:
        # Create a feature group for hydropower plants
        hpp_group = folium.FeatureGroup(name="Hydropower Plants", show=True)

        # Define fixed status colors
        custom_status_colors = {
            'operating': '#2ca02c',           # green
            'construction': '#ff7f0e',        # orange
            'announced': '#999999',           # grey
            'pre-construction': '#9467bd'     # purple
        }

        # Filter data_hpp for selected statuses only
        valid_statuses = list(custom_status_colors.keys())
        data_hpp_filtered = data_hpp[data_hpp['Status'].isin(valid_statuses)]

        if not data_hpp_filtered.empty:
            # Normalize capacity for marker sizing
            cap_min = data_hpp_filtered['Capacity (MW)'].min()
            cap_max = data_hpp_filtered['Capacity (MW)'].max()

            def normalize_capacity(value):
                return 6 + ((value - cap_min) / (cap_max - cap_min + 1e-6)) * 10  # radius 6-16

            # Normalise colour specs to hex strings for Folium (the values above are already hex)
            status_colors = {k: mcolors.to_hex(v) for k, v in custom_status_colors.items()}

            # Add each HPP as a square marker with enhanced popup
            for _, row in data_hpp_filtered.iterrows():
                # Create a more structured popup with selected fields
                selected_fields = [
                    'Project Name', 'Status', 'Capacity (MW)', 
                    'River / Watercourse', 'Country/Area 1'
                ]

                popup_content = f"""
                <div style="font-family: Arial; min-width: 250px;">
                    <h4 style="margin-bottom: 10px;">{row['Project Name']}</h4>
                    <table style="width: 100%; border-collapse: collapse;">
                """

                for field in selected_fields:
                    if field in row and field != 'Project Name':
                        value = row[field]
                        popup_content += f"""
                        <tr>
                            <td style="padding: 3px;"><b>{field}:</b></td>
                            <td style="padding: 3px;">{value}</td>
                        </tr>
                        """

                # Add nearest station info if available
                if 'nearest_station' in row and 'distance_to_station_m' in row:
                    distance_km = row['distance_to_station_m'] / 1000
                    same_river = "Yes" if row.get('same_river', False) else "No"
                    popup_content += f"""
                    <tr>
                        <td style="padding: 3px;"><b>Nearest Station:</b></td>
                        <td style="padding: 3px;">{row['nearest_station']}</td>
                    </tr>
                    <tr>
                        <td style="padding: 3px;"><b>Distance:</b></td>
                        <td style="padding: 3px;">{distance_km:.1f} km</td>
                    </tr>
                    <tr>
                        <td style="padding: 3px;"><b>Same River:</b></td>
                        <td style="padding: 3px;">{same_river}</td>
                    </tr>
                    """

                popup_content += """
                    </table>
                </div>
                """

                color = status_colors.get(row['Status'], '#555')
                size = normalize_capacity(row['Capacity (MW)'])

                folium.RegularPolygonMarker(
                    location=[row['Latitude'], row['Longitude']],
                    number_of_sides=4,  # square
                    radius=size,
                    color=None,
                    fill=True,
                    fill_color=color,
                    fill_opacity=0.9,
                    popup=folium.Popup(popup_content, max_width=300)
                ).add_to(hpp_group)

            # Add the HPP group to the map
            hpp_group.add_to(m)

            # Capacity values for the size legend; square sizes use the same scaling as the plant markers
            capacity_vals = np.percentile(data_hpp_filtered['Capacity (MW)'], [0, 50, 100]).round(0).astype(int)
            marker_sizes = [normalize_capacity(v) for v in capacity_vals]

            # Create status color legend HTML
            status_legend_items = ""
            for status in valid_statuses:
                hex_color = custom_status_colors[status]
                status_legend_items += f'''
                    <i style="background:{hex_color}; width: 18px; height: 18px;
                    float: left; margin-right: 6px;"></i> {status}<br>
                '''

            # Size legend: one grey square per capacity value, stacked 38 px apart
            # (side drawn as 2 × the marker radius)
            size_legend_items = ""
            for k, (cap, size) in enumerate(zip(capacity_vals, marker_sizes)):
                y = 4 + 38 * k
                size_legend_items += (
                    f'<rect x="20" y="{y}" width="{2 * size:.1f}" height="{2 * size:.1f}" fill="#999"/>'
                    f'<text x="60" y="{y + size + 4:.1f}" font-size="12">{cap}</text>'
                )

            # Full HTML block
            hpp_legend_html = f'''
            <div style="
                position: fixed;
                bottom: 50px; left: 340px; width: 280px;
                background-color: white;
                border: 2px solid grey;
                z-index: 9999;
                font-size: 13px;
                padding: 12px;">

                <b style="font-size:14px;">Hydropower Plants</b><br><br>

                <b>Status (Color)</b><br>
                {status_legend_items}
                <br>

                <b>Capacity (MW)</b><br>
                <svg width="160" height="116">
                  {size_legend_items}
                </svg>
            </div>
            '''
            m.get_root().html.add_child(folium.Element(hpp_legend_html))

    # Add layer control to toggle different map elements
    folium.LayerControl().add_to(m)

    # Add fullscreen button for better user experience
    folium.plugins.Fullscreen().add_to(m)

    # Add measure tool for distance measurements
    folium.plugins.MeasureControl(position='topleft', primary_length_unit='kilometers').add_to(m)

    # Save the map, creating the output folder if it does not exist yet
    try:
        os.makedirs(folder_out, exist_ok=True)
        out_path = os.path.join(folder_out, name)
        m.save(out_path)
        print(f"Map saved to {out_path}")
    except Exception as e:
        print(f"Error saving map: {e}")

    return m


# Build the map; rivers are added only if the HydroRIVERS shapefile can be loaded

river_path = os.path.join(folder_in, 'river_input', 'HydroRIVERS_v10_af_shp', 'HydroRIVERS_v10_af.shp')
make_maps(data_runoff, data_hpp=data_hpp, rivers_path=river_path)
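Two details of the legends are easy to get wrong: discharge ticks are rounded to two significant figures, and station colours follow a logarithmic scale (`LogNorm`), so the median number of years does not generally sit at the midpoint of the colour bar. The pure-Python helpers below (`round_sig` and `log_fraction`, both hypothetical names) reproduce those two calculations so legend values can be checked by hand.

```python
import math


def round_sig(x, sig=2):
    """Round a positive number to `sig` significant figures (72.4 -> 72, 1502 -> 1500)."""
    return round(x, -int(math.floor(math.log10(x))) + (sig - 1))


def log_fraction(value, vmin, vmax):
    """Position of value on a log colour scale: 0 at vmin, 1 at vmax (all must be > 0)."""
    return (math.log(value) - math.log(vmin)) / (math.log(vmax) - math.log(vmin))


print([round_sig(t) for t in (72.4, 1502, 18750)])
print(round(log_fraction(20, 5, 80), 3), round(log_fraction(40, 5, 80), 3))
```

For stations with 5 to 80 years of data, `log_fraction(20, 5, 80)` is about 0.5 because 20 is the geometric mean of 5 and 80, while a median of 40 years would sit at about 0.75 on the colour scale.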

## 7. Plot Station Runoff for Each Hydropower Plant

### Purpose
This section creates visualizations that show the historical runoff patterns for gauging stations associated with hydropower plants. These plots are essential for:
- Understanding long-term water availability at hydropower locations
- Identifying trends and variability in water resources
- Assessing potential risks to hydropower generation
- Supporting operational planning and investment decisions

### Visualization Types

#### 1. Annual Runoff Evolution
The first plot (`plot_runoff_evolution`) shows how annual runoff has changed over time for each station associated with a hydropower plant:
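- Each station that is the nearest station of at least one hydropower plant gets its own subplot
- X-axis represents years
- Y-axis shows runoff in mm/year (shared across subplots)
- The title shows the station name and the associated hydropower project; if several plants share a station, only one project name is shown
- The label in parentheses is "Same River" when the station lies on the plant's river, otherwise "Closest Station"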
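The panels are laid out on a fixed three-column grid: the number of rows is a ceiling division, and any axes left over in the last row are deleted with `fig.delaxes`. The arithmetic, as a small hypothetical helper:

```python
def grid_shape(n_panels, ncols=3):
    """Rows and columns for n_panels subplots, plus how many empty axes must be deleted."""
    nrows = (n_panels + ncols - 1) // ncols  # ceiling division
    unused = nrows * ncols - n_panels
    return nrows, ncols, unused


for n in (1, 3, 7):
    print(n, grid_shape(n))
```

With zero stations `nrows` is 0, which `plt.subplots` rejects, so it is worth returning early in that case.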

#### 2. Monthly Discharge Patterns
The second plot (`plot_discharge_month`) displays the average monthly discharge pattern for each station:
- Shows the seasonal flow pattern throughout the year
- Helps identify high and low flow seasons
- Critical for understanding hydropower generation potential throughout the year
- Useful for operational planning and maintenance scheduling
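The monthly pattern is the mean discharge of each calendar month across all available years. A minimal pandas sketch on made-up data for one station (column names follow the notebook: `station_name`, `year`, `month`, `Q`):

```python
import calendar

import pandas as pd

# Made-up monthly discharge (m³/s) for one station over two years
df = pd.DataFrame({
    "station_name": ["A"] * 24,
    "year": [2000] * 12 + [2001] * 12,
    "month": list(range(1, 13)) * 2,
    "Q": [100, 90, 80, 120, 200, 350, 500, 600, 450, 300, 180, 120,
          110, 95, 85, 130, 210, 370, 520, 640, 470, 310, 190, 125],
})

# Mean discharge per calendar month across all years, relabelled Jan..Dec
regime = df.groupby("month")["Q"].mean().sort_index()
regime.index = [calendar.month_abbr[m] for m in regime.index]
print(regime)
```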

### Interpretation Guidelines

**For Annual Runoff Plots:**
- Look for long-term trends (increasing, decreasing, or stable patterns)
- Note any abrupt changes that might indicate upstream development or climate shifts
- Compare patterns across different stations to identify regional trends
- Pay special attention to recent years for current conditions
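One simple way to put a number on a long-term trend is the ordinary least-squares slope of annual runoff against year. The sketch below (hypothetical `linear_trend`, made-up values) computes it in plain Python. It is descriptive only; formal trend testing (the Mann-Kendall test is common in hydrology) is outside what this notebook does.

```python
def linear_trend(years, values):
    """Ordinary least-squares slope and intercept of values regressed on years."""
    n = len(years)
    if n < 2:
        raise ValueError("need at least two years")
    mean_x = sum(years) / n
    mean_y = sum(values) / n
    sxx = sum((x - mean_x) ** 2 for x in years)
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in zip(years, values))
    slope = sxy / sxx
    return slope, mean_y - slope * mean_x


years = [2000, 2001, 2002, 2003, 2004]
runoff = [400, 390, 385, 370, 360]  # mm/year, made-up values
slope, intercept = linear_trend(years, runoff)
print(f"{slope:.1f} mm/year per year ({10 * slope:.0f} mm/year per decade)")
```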

**For Monthly Discharge Plots:**
- Identify the high and low flow seasons
- Note the magnitude of seasonal variation (highly seasonal vs. relatively stable)
- Consider how the seasonal pattern aligns with energy demand patterns
- For plants with "Same River = False", interpret with caution as patterns may differ
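The high- and low-flow months and the size of the seasonal swing can be read directly from the twelve monthly means. A minimal sketch, assuming the means are given in calendar order starting in January; `seasonality_summary` is a hypothetical helper and the values are illustrative.

```python
import calendar


def seasonality_summary(monthly_q):
    """Return (peak month, low month, peak/low ratio) from 12 monthly means ordered Jan..Dec."""
    if len(monthly_q) != 12:
        raise ValueError("expected 12 monthly values (January to December)")
    peak = max(range(12), key=lambda i: monthly_q[i])
    low = min(range(12), key=lambda i: monthly_q[i])
    # A zero low-flow month (e.g. an intermittent river) gives an infinite ratio
    ratio = monthly_q[peak] / monthly_q[low] if monthly_q[low] > 0 else float("inf")
    return calendar.month_abbr[peak + 1], calendar.month_abbr[low + 1], ratio


monthly_means = [105, 92.5, 82.5, 125, 205, 360, 510, 620, 460, 305, 185, 122.5]
print(seasonality_summary(monthly_means))
```

A peak-to-low ratio close to 1 indicates a relatively stable regime; large ratios point to strongly seasonal flow.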

### Output Files
- **runoff_evolution_hpp.png**: Time series of annual runoff for each hydropower-associated station
- **runoff_monthly_evolution_hpp.png**: Average monthly discharge patterns for each station

### Customization Options
- The figure size (`figsize`) and the number of subplot columns (`ncols = 3`) are set inside each plotting function; edit them there
- The output file names are fixed in the `plt.savefig` calls; `folder_out` only sets the output directory (pass `folder_out=None` to display the figure instead of saving it)

In [None]:
# Quick check: runoff records for the stations linked to at least one hydropower plant
stations = data_hpp['nearest_station'].dropna().unique()
temp = data_runoff[data_runoff['station_name'].isin(stations)].copy()
temp.head()

In [None]:
import os

import matplotlib.pyplot as plt
import pandas as pd


def plot_runoff_evolution(data_runoff, data_hpp, folder_out=None):
    """
    Plot the evolution of annual runoff for each hydropower plant's nearest station.
    This function creates one subplot per station associated with at least one
    hydropower plant, showing annual runoff over the years.

    Parameters:
        data_runoff (DataFrame): Annual runoff with 'station_name', 'year' and 'runoff_mm_year'.
        data_hpp (DataFrame): Hydropower plant data with 'nearest_station', 'Project Name' and 'same_river'.
        folder_out (str, optional): Output folder to save the plot. If None, displays the plot.
    """
    stations = data_hpp['nearest_station'].dropna().unique()
    if len(stations) == 0:
        print("No stations are associated with hydropower plants; nothing to plot.")
        return

    temp = data_runoff[data_runoff['station_name'].isin(stations)].copy()
    temp['year'] = temp['year'].astype(int)

    # Map station_name to project name and river-match flag
    # (if several plants share a station, the last plant in data_hpp is used)
    station_to_project = data_hpp.set_index('nearest_station')['Project Name'].to_dict()
    station_to_river_match = data_hpp.set_index('nearest_station')['same_river'].to_dict()

    n_stations = len(stations)
    ncols = 3
    nrows = (n_stations + ncols - 1) // ncols  # ceiling division

    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows), sharey=True, squeeze=False)
    axs = axs.flatten()

    min_year = temp['year'].min()
    max_year = temp['year'].max()

    for ax, station in zip(axs, stations):
        subset = temp[temp['station_name'] == station].sort_values('year')
        ax.plot(subset['year'], subset['runoff_mm_year'], marker='o', linestyle='-')
        same = station_to_river_match.get(station, False)
        match_label = "Same River" if (pd.notna(same) and bool(same)) else "Closest Station"
        ax.set_title(f"{station} / {station_to_project.get(station, 'n/a')} ({match_label})", fontsize=10)
        ax.set_xlabel("Year")
        ax.set_ylabel("Runoff (mm/year)")
        ax.set_xlim(min_year, max_year)
        ax.grid(True)

    # Remove any unused subplots
    for ax in axs[n_stations:]:
        fig.delaxes(ax)

    plt.tight_layout()
    if folder_out:
        os.makedirs(folder_out, exist_ok=True)
        plt.savefig(os.path.join(folder_out, 'runoff_evolution_hpp.png'))
        plt.close(fig)
    else:
        plt.show()


plot_runoff_evolution(data_runoff, data_hpp, folder_out=folder_out)


In [None]:
import calendar
import os

import matplotlib.pyplot as plt
import pandas as pd


def plot_discharge_month(data_station_filtered, data_hpp, folder_out=None):
    """
    Plot the average monthly discharge for each hydropower plant's nearest station.
    This function creates one subplot per station associated with at least one
    hydropower plant, showing the mean discharge of each calendar month over all years.

    Parameters:
        data_station_filtered (DataFrame): Monthly discharge with 'station_name', 'month' and 'Q'.
        data_hpp (DataFrame): Hydropower plant data with 'nearest_station', 'Project Name' and 'same_river'.
        folder_out (str, optional): Output folder to save the plot. If None, displays the plot.
    """
    stations = data_hpp['nearest_station'].dropna().unique()
    if len(stations) == 0:
        print("No stations are associated with hydropower plants; nothing to plot.")
        return

    temp = data_station_filtered[data_station_filtered['station_name'].isin(stations)].copy()

    # Map station_name to project name and river-match flag
    # (if several plants share a station, the last plant in data_hpp is used)
    station_to_project = data_hpp.set_index('nearest_station')['Project Name'].to_dict()
    station_to_river_match = data_hpp.set_index('nearest_station')['same_river'].to_dict()

    n_stations = len(stations)
    ncols = 3
    nrows = (n_stations + ncols - 1) // ncols  # ceiling division

    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows), sharey=False, squeeze=False)
    axs = axs.flatten()

    for ax, station in zip(axs, stations):
        subset = temp[temp['station_name'] == station]
        # Mean discharge of each calendar month over all available years
        monthly = subset.groupby(subset['month'].astype(int))['Q'].mean().sort_index()
        ax.plot(monthly.index, monthly.values, marker='o', linestyle='-')
        ax.set_xticks(range(1, 13))
        ax.set_xticklabels(calendar.month_abbr[1:13])
        same = station_to_river_match.get(station, False)
        match_label = "Same River" if (pd.notna(same) and bool(same)) else "Closest Station"
        ax.set_title(f"{station} / {station_to_project.get(station, 'n/a')} ({match_label})", fontsize=10)
        ax.set_ylabel("Discharge (m³/s)")
        ax.grid(True)

    # Remove any unused subplots
    for ax in axs[n_stations:]:
        fig.delaxes(ax)

    plt.tight_layout()
    if folder_out:
        os.makedirs(folder_out, exist_ok=True)
        plt.savefig(os.path.join(folder_out, 'runoff_monthly_evolution_hpp.png'))
        plt.close(fig)
    else:
        plt.show()

plot_discharge_month(data_station_filtered, data_hpp, folder_out=folder_out)

## Filling Missing Values in Time Series Data

### Purpose
This experimental section demonstrates methods for handling missing values in hydrological time series data. Missing data is a common challenge in hydrology due to equipment failures, maintenance periods, or data transmission issues. This section explores two approaches:

1. **Climatology-Based Filling**: Replacing missing values with long-term monthly averages
2. **Interpolation-Based Filling**: Using linear interpolation to estimate missing values

### Why This Matters
Complete time series are essential for:
- Reliable trend analysis
- Accurate seasonal pattern identification
- Input to hydrological and energy models
- Consistent comparison between stations and time periods

### Methodology

#### Climatology-Based Filling (`fill_missing_climatology`)
- Calculates the long-term average for each month at each station
- Replaces missing values with these monthly averages
- Preserves the seasonal pattern but loses inter-annual variability
- Works well for stations with strong seasonal patterns
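- Keeps only station-years with at least `min_months` valid monthly values (default 3; the call at the end of this section uses 6); sparser years are dropped, not filled
- Fills only months that exist as rows with a missing `Q`; months with no row at all are not created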
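The climatology idea can be written compactly with `groupby(...).transform("mean")` plus `fillna`. The sketch below uses made-up data and, unlike `fill_missing_climatology`, skips the sparse-year filter, so it illustrates only the fill step.

```python
import numpy as np
import pandas as pd

# Made-up data: one station, January and February over three years, with two gaps
df = pd.DataFrame({
    "station_name": ["A"] * 6,
    "year": [2000, 2000, 2001, 2001, 2002, 2002],
    "month": [1, 2, 1, 2, 1, 2],
    "Q": [100.0, 80.0, np.nan, 90.0, 120.0, np.nan],
})

# Long-term mean per station and calendar month, broadcast back to every row (NaN ignored)
clim = df.groupby(["station_name", "month"])["Q"].transform("mean")

filled_mask = df["Q"].isna()           # which values will be replaced
df["Q_filled"] = df["Q"].fillna(clim)  # gaps take the monthly climatology
print(df)
```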

#### Interpolation-Based Filling (`drop_sparse_years_and_interpolate`)
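- Removes years with too many missing months (fewer than `min_months` valid months, 9 by default)
- Uses linear interpolation within each station to fill the remaining gaps; gaps at the start or end of a record take the nearest valid value (`limit_direction='both'`)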
- Better preserves trends and inter-annual patterns
- Works well for short gaps but may create unrealistic values for longer gaps
- More sensitive to outliers than the climatology method
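The edge behaviour of pandas' linear interpolation is worth seeing once: interior gaps are filled along a straight line, but gaps before the first or after the last valid value are held at that value rather than extrapolated. A small illustration on a made-up series, based on pandas' documented `Series.interpolate` behaviour:

```python
import numpy as np
import pandas as pd

q = pd.Series([np.nan, 100.0, np.nan, 140.0, np.nan, np.nan])

forward = q.interpolate(method="linear")                       # pandas default: limit_direction="forward"
both = q.interpolate(method="linear", limit_direction="both")  # setting used in the notebook

print(forward.tolist())
print(both.tolist())
```

With the default `limit_direction="forward"` the leading gap stays NaN; `"both"` fills it with the first observed value as well.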

### Visualization and Comparison
The `plot_all_fill_methods` function creates comparison plots showing:
- Original data with gaps
- Climatology-filled data
- Interpolation-filled data

These plots help assess which method is more appropriate for each station based on the pattern and extent of missing data.

### Interpretation Guidelines
- **For climatology-filled data**: Good for preserving seasonal patterns but may mask long-term trends
- **For interpolated data**: Better for trend analysis but may create unrealistic values for long gaps
- **When comparing methods**: Look for divergence between methods, which indicates higher uncertainty

### Status: Experimental
This section is marked as "in progress" because:
- The methods are still being refined
- Parameter values (min_months) may need adjustment for specific applications
- Additional methods (e.g., ARIMA models, machine learning approaches) could be added
- Validation against known values has not been fully implemented

### Next Steps for Development
- Add validation metrics to quantify filling accuracy
- Implement more sophisticated filling methods
- Add uncertainty estimates for filled values
- Create diagnostic plots to help select the best method for each station
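For validation, a common approach is to hide values that were actually observed, fill them with the method under test, and measure the error against the hidden truth. The sketch below does this in plain Python with a simple linear gap filler. `interpolate_gaps` and `holdout_rmse` are hypothetical names, missing values are represented as `None`, and the random hold-out scheme is an assumption rather than part of the notebook.

```python
import math
import random


def interpolate_gaps(values):
    """Fill None entries by linear interpolation; edge gaps take the nearest valid value."""
    known = [i for i, v in enumerate(values) if v is not None]
    if not known:
        raise ValueError("need at least one observed value")
    out = list(values)
    for i, v in enumerate(values):
        if v is not None:
            continue
        left = max((k for k in known if k < i), default=None)
        right = min((k for k in known if k > i), default=None)
        if left is None:
            out[i] = values[right]
        elif right is None:
            out[i] = values[left]
        else:
            w = (i - left) / (right - left)
            out[i] = values[left] + w * (values[right] - values[left])
    return out


def holdout_rmse(complete, fill_fn, frac=0.2, seed=0):
    """Hide a random fraction of observed values, refill them with fill_fn, return the RMSE."""
    rng = random.Random(seed)
    n_hide = max(1, int(len(complete) * frac))
    hidden = set(rng.sample(range(len(complete)), n_hide))
    masked = [None if i in hidden else v for i, v in enumerate(complete)]
    filled = fill_fn(masked)
    return math.sqrt(sum((filled[i] - complete[i]) ** 2 for i in hidden) / n_hide)


if __name__ == "__main__":
    # Illustrative monthly discharge (m³/s) for one year
    q = [105, 92, 83, 125, 205, 360, 510, 620, 460, 305, 185, 122]
    print(round(holdout_rmse(q, interpolate_gaps, frac=0.25, seed=1), 1))
```

Repeating the hold-out with several seeds, and for each filling method, gives a rough distribution of errors to compare methods station by station.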

In [None]:
def fill_missing_climatology(df, min_months=3):
    """
    Fill missing monthly discharge (Q) with each station's monthly climatology.

    Station-years with fewer than `min_months` valid values are dropped first.
    Only rows that exist with a missing Q are filled; absent months are not created.

    Returns:
        (DataFrame, Series): the filled DataFrame (with a fresh index) and a boolean
        mask aligned with it that is True where Q was filled.
    """
    df = df.copy()

    # Step 1: Drop sparse station-years (count() ignores NaN)
    n_valid = df.groupby(["station_name", "year"])["Q"].transform("count")
    df = df[n_valid >= min_months].reset_index(drop=True)

    # Step 2: Climatology = mean Q per station and calendar month, broadcast to every row
    q_clim = df.groupby(["station_name", "month"])["Q"].transform("mean")

    # Step 3: Fill missing values wherever a climatology value exists
    fill_mask = df["Q"].isna() & q_clim.notna()
    df.loc[fill_mask, "Q"] = q_clim[fill_mask]

    return df, fill_mask

def drop_sparse_years_and_interpolate(df, min_months=9):
    """
    Drop station-years with fewer than `min_months` valid values, then fill the
    remaining gaps by linear interpolation within each station.

    Assumes monthly data. Interpolation is positional (rows are treated as equally
    spaced), and leading/trailing gaps take the nearest valid value.
    """
    df_clean = df.copy()

    # Keep only station-years with enough valid months (count() ignores NaN),
    # then sort chronologically within each station before interpolating
    n_valid = df_clean.groupby(["station_name", "year"])["Q"].transform("count")
    df_clean = df_clean[n_valid >= min_months].sort_values(["station_name", "year", "month"])

    # Interpolate per station
    df_clean["Q"] = df_clean.groupby("station_name")["Q"].transform(
        lambda x: x.interpolate(method="linear", limit_direction="both")
    )

    return df_clean

def plot_all_fill_methods(df_original, df_clim, df_interp, output_dir):
    """
    Plot original, climatology-filled and interpolated discharge (Q) for each station,
    saving one PNG per station to <output_dir>/stations/.

    Assumes all DataFrames have columns: ['year', 'month', 'station_name', 'Q'].
    """
    output_dir = os.path.join(output_dir, 'stations')
    os.makedirs(output_dir, exist_ok=True)  # also creates the parent folder if needed

    # Combine 'year' and 'month' into a datetime (first day of the month) and sort by it
    def add_datetime(df):
        dates = pd.to_datetime(pd.DataFrame({
            "year": df["year"].astype(int),
            "month": df["month"].astype(int),
            "day": 1,
        }))
        return df.assign(date=dates).sort_values("date")

    df_original = add_datetime(df_original)
    df_clim = add_datetime(df_clim)
    df_interp = add_datetime(df_interp)

    # Loop over each station
    for station in df_original['station_name'].unique():
        fig, ax = plt.subplots(figsize=(12, 4))

        # Subsets for current station
        df_o = df_original[df_original['station_name'] == station]
        df_c = df_clim[df_clim['station_name'] == station]
        df_i = df_interp[df_interp['station_name'] == station]

        # Plot all three series; NaN values in the original appear as breaks in the black line
        ax.plot(df_o['date'], df_o['Q'], label="Original", alpha=0.5, marker='o', linestyle='-', color='black')
        ax.plot(df_c['date'], df_c['Q'], label="Climatology Fill", linestyle='--', color='orange')
        ax.plot(df_i['date'], df_i['Q'], label="Interpolated", linestyle='-', color='blue')

        ax.set_title(f"Station: {station}")
        ax.set_ylabel("Discharge Q (m³/s)")
        ax.set_xlabel("Time")
        ax.legend()
        ax.grid(True)

        # Replace characters that are unsafe in file names (e.g. '/', spaces)
        safe_name = "".join(c if c.isalnum() or c in "-_" else "_" for c in str(station))

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{safe_name}.png"))
        plt.close(fig)


In [None]:
df_filled, filled_mask = fill_missing_climatology(df_filtered, min_months=6)
df_interpolated = drop_sparse_years_and_interpolate(df_filtered, min_months=9)
plot_all_fill_methods(df_filtered, df_filled, df_interpolated, output_dir=folder_out)