# GRDC Hydro data processing

Processing streamflow data from the Global Runoff Data Centre (GRDC) to create a comprehensive dataset of river discharge and runoff.

Rely on three kind of input data that all exists in open-source and needs to be downloaded separately:
1. Monthly streamflow data from GRDC stations
2. Hydropower plant data from the Global Hydropower Tracker
3. River data from HydroRIVERS



In [None]:
import os
import geopandas as gpd
import xarray as xr
import folium
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from branca.colormap import LinearColormap
from folium import CircleMarker
from matplotlib.colors import LogNorm

import geopandas as gpd
from shapely.geometry import box
import calendar

from rapidfuzz import process, fuzz
import unicodedata
import re


## User input

In [104]:
countries = ['Angola', 'Burundi', 'Cameroon', 'Central African Republic', 'Chad', 'Republic of the Congo', 'DR Congo', 'Equatorial Guinea', 'Gabon']

# Define the base input directory
base_dir = 'data_grdc_hydro_capp/input'
folder_out = 'data_grdc_hydro_capp/output'
if not os.path.exists(folder_out): os.makedirs(folder_out)

## 1. Load streamflow data

In [105]:
# Initialize a list to hold all xarray Datasets
datasets = []

# Walk through all directories under 'input'
for root, dirs, files in os.walk(base_dir):
    if 'GRDC-Monthly.nc' in files:
        file_path = os.path.join(root, 'GRDC-Monthly.nc')
        try:
            data_discharge = xr.open_dataset(file_path)
            datasets.append(data_discharge)
            print(f"Loaded: {file_path}")
        except Exception as e:
            print(f"Failed to load {file_path}: {e}")

# Merge all datasets
if datasets:
    try:
        data_discharge = xr.concat(datasets, dim='id')  # You can use another dim if needed
        print(f"Total merged dimensions: {data_discharge.dims}")
    except Exception as e:
        print(f"Failed to merge datasets: {e}")
else:
    print("No streamflow data found.")

Loaded: data_grdc_hydro_capp/input/sao_tome/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/gabon/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/democroatic_republic_congo/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/central_africa/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/angola/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/chad/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/republic_congo/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/cameroon/GRDC-Monthly.nc
Loaded: data_grdc_hydro_capp/input/burundi/GRDC-Monthly.nc


## 2. Filter  data

In [106]:
# -------------------------------
# 1. Filter out stations with no valid area
# -------------------------------

# Identify stations with invalid or missing area
invalid_area_mask = (data_discharge["area"] <= 0) | data_discharge["area"].isnull()
removed_area_ids = data_discharge["id"].values[invalid_area_mask.values]

# Extract metadata of removed stations
removed_area_meta = data_discharge.sel(id=removed_area_ids)[["station_name", "geo_x", "geo_y", "area"]]
stations_removed_area_df = removed_area_meta.to_dataframe().reset_index()

# Remove stations with invalid area
data_discharge = data_discharge.where(data_discharge["area"] > 0, drop=True)

# -------------------------------
# 2. Filter out stations with low runoff values
# -------------------------------

threshold_runoff = 15  # m³/s

# Identify stations where all runoff values are below threshold or NaN
runoff = data_discharge["runoff_mean"]
low_flow_mask = (runoff < threshold_runoff) | runoff.isnull()
stations_to_remove = low_flow_mask.all(dim="time")
removed_ids = data_discharge["id"].values[stations_to_remove.values]

# Extract metadata of removed stations
removed_meta = data_discharge.sel(id=removed_ids)[["station_name", "geo_x", "geo_y"]]
stations_removed_lowflow_df = removed_meta.to_dataframe().reset_index()

# Remove those stations from the dataset
data_discharge = data_discharge.sel(id=~stations_to_remove)

# -------------------------------
# 3. Summary (optional display or export)
# -------------------------------

print(f"Stations removed due to invalid area: {len(stations_removed_area_df)}")
print(f"Stations removed due to low runoff: {len(stations_removed_lowflow_df)}")

Stations removed due to invalid area: 2
Stations removed due to low runoff: 40


## 3. Format data for analysis

In [107]:
def checking_duplicates_grdc(df):
    """
    Check for duplicated (year, station_name, month) entries in the DataFrame.

    :param df:
    :return:
    """
    dupes = (
        df.groupby(["year", "station_name", "month"])
        .size()
        .reset_index(name='count')
        .query("count > 1")
    )

    print(f"Found {len(dupes)} duplicated (year, station_name, month) entries")

    problem_stations = dupes['station_name'].unique()
    print(f'Duplicated tation: {problem_stations}')

    # Merge the duplicate keys back into df to see full rows
    duplicated_rows = df.merge(dupes[["year", "station_name", "month"]], on=["year", "station_name", "month"])
    #display(duplicated_rows.sort_values(["station_name", "year", "month"]))

    return problem_stations

In [108]:
var_name = 'runoff_mean'

meta = data_discharge[["station_name", "geo_x", "geo_y"]].to_dataframe().reset_index()
area = data_discharge["area"].to_dataframe().reset_index()

# Convert to DataFrame
data_station = data_discharge[var_name].to_dataframe(name="Q").reset_index()

# Merge metadata into the main DataFrame
data_station = data_station.merge(meta, on="id")
data_station = data_station.merge(area, on="id")

# Add year and month columns
data_station["year"] = data_station["time"].dt.year
data_station["month"] = data_station["time"].dt.month

# Check if station_name has duplicates
if data_station.duplicated(subset=["station_name"]).any():
    problem_stations = checking_duplicates_grdc(data_station)

    # Step 3: Get corresponding station IDs
    problem_ids = data_station[data_station["station_name"].isin(problem_stations)]["id"].unique()

    # Step 4: Assign unique station names for problematic IDs only
    rename_map = {
        id_: f"{data_station[data_station['id'] == id_]['station_name'].iloc[0]}_{i+1}"
        for i, id_ in enumerate(problem_ids)
    }

    # Step 5: Apply renaming based on `id` consistently
    data_station["station_name"] = data_station.apply(
        lambda row: rename_map[row["id"]] if row["id"] in rename_map else row["station_name"],
        axis=1
    )

# Create a unique label for each station using name + coordinates (optional)
data_station["station_label"] = data_station["station_name"].str.strip() + " (" + data_station["geo_y"].round(2).astype(str) + ", " + data_station["geo_x"].round(2).astype(str) + ")"

# Pivot to wide format: year as index, MultiIndex (month, id) as columns
data_station_pivot = data_station.pivot(index="year", columns=["station_name", "month"], values="Q")
data_station_pivot.dropna(axis=1, how='all', inplace=True)
data_station_pivot.sort_index(ascending=True, axis=1, inplace=True)
data_station_pivot.to_csv(os.path.join(folder_out, f'grdc_discharge_monthly-m3-s.csv'), index=True)

Found 2736 duplicated (year, station_name, month) entries
Duplicated tation: ['KRIBI' 'MUYANGE']


## 4. Calculate runoff data

Runoff data (mm/year) is calculated from monthly discharge data (m³/s) using the formula:

    runoff_mm = (Σ(Q_monthly_avg × 86400 × days_in_month)) / area_m2 × 1000

Where:
- Q is discharge in m³/s
- 86400 is seconds per day
- days_in_month accounts for monthly totals
- area is the catchment area in m²
- 1000 converts meters to millimeters

It is more standard way to compare runoff across different catchments.


In [109]:
def filter_full_years(df):
    """
    Keep only (station, year) pairs with 12 valid months.
    Returns filtered DataFrame and number of dropped rows.
    """
    original_len = len(df)

    # Count valid months per station-year
    valid_counts = (
        df.groupby(['station_name', 'year'])['Q']
        .apply(lambda x: x.notna().sum())
        .reset_index(name='valid_months')
    )

    # Only keep those with all 12 months
    full_years = valid_counts[valid_counts['valid_months'] == 12]

    # Merge to filter original DataFrame
    df_filtered = df.merge(full_years[['station_name', 'year']], on=['station_name', 'year'])

    removed_rows = original_len - len(df_filtered)
    print(f"Removed {removed_rows} rows — kept {len(df_filtered)} only full (12-month) years.")

    return df_filtered

def discharge_to_runoff(df):
    """
    Convert annual discharge (m³/s) into annual runoff (mm/year) at the station level.

    The equation used is:

        runoff_mm = (Σ(Q_monthly_avg × 86400 × days_in_month)) / area_m2 × 1000

    Where:
        - Q is discharge in m³/s
        - 86400 is seconds per day
        - days_in_month accounts for monthly totals
        - area is the catchment area in m²
        - 1000 converts meters to millimeters

    Assumes:
        - Input DataFrame has one row per station/month
        - Years are complete (12 months per station)

    Returns:
        - DataFrame with columns: station_label, year, runoff_mm_year
    """
    df = df.copy()

    # Add number of days in each month
    df['days_in_month'] = pd.to_datetime(df['time']).dt.days_in_month

    # Convert area to m²
    df['area_m2'] = df['area'] * 1e6

    # Compute volume in m³ for each month
    df['volume_m3'] = df['Q'] * 86400 * df['days_in_month']

    # Sum monthly volumes per station-year
    runoff_by_year = (
        df.groupby(['station_name', 'year'])
        .apply(lambda x: x['volume_m3'].sum() / x['area_m2'].iloc[0] * 1000, include_groups=False)  # m to mm
        .reset_index(name='runoff_mm_year')
    )

    return runoff_by_year

In [110]:
# Convert discharge data (m3-s) to runoff in (mm-year)
data_station_filtered = filter_full_years(data_station)
data_runoff = discharge_to_runoff(data_station_filtered)
# Save to CSV
data_runoff.round(0).pivot(index="year", columns="station_name", values="runoff_mm_year").to_csv(os.path.join(folder_out, f'grdc_runoff_mm-year.csv'), index=True)

location = data_station[['station_name', 'geo_x', 'geo_y']].drop_duplicates().reset_index(drop=True)
# Merge with runoff_by_year data
data_runoff = data_runoff.merge(location, on='station_name')

Removed 198480 rows — kept 25872 only full (12-month) years.


## 5. Add hydropowerplant data & associate with river data

In [111]:
path_hpp = 'data_grdc_hydro_capp/Global-Hydropower-Tracker-April-2025.xlsx'
if not os.path.exists(path_hpp):
    raise FileNotFoundError(f"Hydropower data file not found: {path_hpp}")

data_hpp = pd.read_excel(path_hpp, sheet_name='Data', header=[0], index_col=None)

data_hpp = data_hpp.loc[data_hpp['Country/Area 1'].isin(countries), :]

display(data_hpp.head())

Unnamed: 0,Date Last Researched,Country/Area 1,Country/Area 2,Project Name,Project Name (local lang/script),Other name(s),Capacity (MW),Binational,Country/Area 1 Capacity (MW),Country/Area 2 Capacity (MW),...,Region 1,City 2,Local Area 2,Major Area 2,State/Province 2,Subregion 2,Region 2,GEM location ID,GEM unit ID,Wiki URL
21,2024-11-22,Angola,,Caculo Cabaça hydroelectric plant,,,2172,No,2172.0,0.0,...,Africa,,,,,,,L100000600012,G100000600012,https://www.gem.wiki/Caculo_Cabaça_hydroelectr...
22,2022-09-12,Angola,,Cambambe I hydroelectric plant,,,180,No,180.0,0.0,...,Africa,,,,,,,L100000600013,G100000600013,https://www.gem.wiki/Cambambe_I_hydroelectric_...
23,2022-09-12,Angola,,Cambambe II hydroelectric plant,,,700,No,700.0,0.0,...,Africa,,,,,,,L100000600014,G100000600014,https://www.gem.wiki/Cambambe_II_hydroelectric...
24,2022-09-12,Angola,,Capanda hydroelectric plant,,,520,No,520.0,0.0,...,Africa,,,,,,,L100000600015,G100000600015,https://www.gem.wiki/Capanda_hydroelectric_plant
25,2024-11-25,Angola,,Gove Dam hydroelectric plant,,,60,No,60.0,0.0,...,Africa,,,,,,,L100001025775,G100001030660,https://www.gem.wiki/Gove_Dam_hydroelectric_plant


In [112]:
# Add river name to runoff data
meta_df = data_discharge[["station_name", "river_name"]].to_dataframe().reset_index()
meta_df = meta_df.drop_duplicates(subset=["station_name"])

data_station_filtered_river = data_station_filtered.merge(meta_df, on="station_name", how="left")

In [113]:
data_station_filtered_river[['station_name', 'river_name']].drop_duplicates()

Unnamed: 0,station_name,river_name
0,KINSHASA,CONGO RIVER
133,BANGUI,UBANGI
372,LAMBARENE,OGOOUE
470,NDJAMENA(FORT LAMY),CHARI
640,SARH(FORT ARCHAMBAULT),CHARI
...,...,...
24100,KAGA-BANDORO,GRIBINGUI
24106,BANGASSOU,MBOMOU
24125,MUYANGE_3,
24126,MUYANGE_4,


In [114]:
def clean_river_name(name):
    if pd.isna(name):
        return ""

    # Normalize unicode and remove accents (e.g. é → e)
    name = unicodedata.normalize('NFKD', name).encode('ASCII', 'ignore').decode('utf-8')

    # Remove common river words and clean up spacing
    name = re.sub(r"\b(RIVER|RIVIERE|RIVIÈRE|R\.|R)\b", "", name, flags=re.IGNORECASE)
    name = re.sub(r"\s+", " ", name).strip().upper()

    return name

# Prepare cleaned lists
station_rivers_raw = data_station_filtered_river['river_name'].dropna().unique()
station_rivers_cleaned = {clean_river_name(r): r for r in station_rivers_raw}

matches = []

# Iterate over all HPPs (row-wise)
for _, row in data_hpp.dropna(subset=["River / Watercourse"]).iterrows():
    hpp_river_orig = row["River / Watercourse"]
    hpp_river_clean = clean_river_name(hpp_river_orig)

    # Find best match among station rivers
    match, score, _ = process.extractOne(
        hpp_river_clean, station_rivers_cleaned.keys(), scorer=fuzz.token_set_ratio
    )

    if score >= 85:
        matched_station_river = station_rivers_cleaned[match]
        stations = data_station_filtered_river[
            data_station_filtered_river["river_name"] == matched_station_river
        ]["station_name"].tolist()
    else:
        matched_station_river = None
        stations = []

    matches.append({
        "hpp_name": row["Project Name"],
        "hpp_river": hpp_river_orig,
        "matched_river": matched_station_river,
        "score": score,
        "stations": set(stations),
        "hpp_capacity": row["Capacity (MW)"],
        "hpp_status": row["Status"]
    })

# Convert to DataFrame for easy display or export
matches_df = pd.DataFrame(matches)
matches_df.to_csv(os.path.join(folder_out, 'hpp_grdc_hydro_matches.csv'), index=False)


In [115]:
# Associate hydropower plants with gauging stations based on proximity
def associate_hpp_station(data_hpp, data_station, matches_df):
    # Step 1: Build GeoDataFrames
    gdf_hpp = gpd.GeoDataFrame(
        data_hpp.copy(),
        geometry=gpd.points_from_xy(data_hpp["Longitude"], data_hpp["Latitude"]),
        crs="EPSG:4326"
    )

    gdf_station = gpd.GeoDataFrame(
        data_station.copy(),
        geometry=gpd.points_from_xy(data_station["geo_x"], data_station["geo_y"]),
        crs="EPSG:4326"
    )

    # Step 2: Project both to a metric CRS (meters)
    gdf_hpp = gdf_hpp.to_crs(epsg=3857)
    gdf_station = gdf_station.to_crs(epsg=3857)

    # Step 3: Build dictionary from matches_df
    station_lookup = {
        row["hpp_name"]: row["stations"]
        for _, row in matches_df.iterrows()
        if row["stations"]
    }

    # Step 4: For each HPP, associate station
    results = []

    for idx, hpp in gdf_hpp.iterrows():
        hpp_name = hpp["Project Name"]
        if hpp_name in station_lookup and station_lookup[hpp_name]:
            # Use the first station from the match (same river)
            station_name = list(station_lookup[hpp_name])[0]
            station = gdf_station[gdf_station["station_name"] == station_name].iloc[0]
            distance = hpp.geometry.distance(station.geometry)
            results.append({
                "hpp_index": idx,
                "nearest_station": station_name,
                "distance_to_station_m": distance,
                "same_river": True
            })
        else:
            # Find nearest station by distance
            distances = gdf_station.geometry.distance(hpp.geometry)
            min_idx = distances.idxmin()
            nearest_station = gdf_station.loc[min_idx]
            results.append({
                "hpp_index": idx,
                "nearest_station": nearest_station["station_name"],
                "distance_to_station_m": distances[min_idx],
                "same_river": False
            })

    # Step 5: Merge back into HPP GeoDataFrame
    nearest_df = pd.DataFrame(results)
    gdf_hpp["nearest_station"] = nearest_df["nearest_station"].values
    gdf_hpp["distance_to_station_m"] = nearest_df["distance_to_station_m"].values
    gdf_hpp["same_river"] = nearest_df["same_river"].values

    return gdf_hpp

gdf_hpp = associate_hpp_station(data_hpp, data_station_filtered, matches_df)
data_hpp = data_hpp.merge(
    gdf_hpp[["Project Name", "nearest_station", "distance_to_station_m", 'same_river']],
)

## 6. Visualize all data on the maps

In [116]:
# Function to clip rivers to the bounding box of runoff stations
def clip_rivers_to_stations(rivers_path, df_runoff, ord_stra=6):
    # Step 1: Load the HydroRIVERS shapefile

    # To be downloaded from: https://www.hydrosheds.org/products/hydrorivers using shapefile format
    # Load available layers (should return e.g. 'HydroRIVERS_v10_af')

    if not os.path.exists(rivers_path):
        raise FileNotFoundError(f"Geodatabase not found: {rivers_path}")

    # Load rivers
    rivers_gdf = gpd.read_file(rivers_path)

    # Filter based on average discharge and stream order
    rivers_gdf = rivers_gdf[(rivers_gdf["DIS_AV_CMS"] > 50) & (rivers_gdf["ORD_STRA"] >= 6) & (rivers_gdf["ORD_FLOW"] <= 5)]
    #rivers_gdf = rivers_gdf[(rivers_gdf["ORD_STRA"] >= ord_stra)]

    # Step 2: Prepare the bounding box for clipping
    # Create GeoDataFrame from stations
    stations_gdf = gpd.GeoDataFrame(
        df_runoff,
        geometry=gpd.points_from_xy(df_runoff["geo_x"], df_runoff["geo_y"]),
        crs="EPSG:4326"
    )

    # Get bounding box coordinates from the stations
    roi_bounds = stations_gdf.total_bounds  # [minx, miny, maxx, maxy]

    # Create a polygon box from the bounds
    roi_polygon = box(*roi_bounds).buffer(1.0)  # optional padding of 1 degree

    # Create a GeoDataFrame for clipping
    roi_gdf = gpd.GeoDataFrame(geometry=[roi_polygon], crs="EPSG:4326")

    rivers_clipped = gpd.clip(rivers_gdf, roi_gdf)

    return rivers_clipped


In [117]:
# This cell can take a while to run (some minutes), depending on the number of stations and rivers
def make_maps(df_runoff, data_hpp=None, rivers_path=None, folder_out=folder_out):

    # Read station data
    agg = (
        df_runoff.groupby(["station_name", "geo_x", "geo_y"])
          .agg(
              avg_runoff=("runoff_mm_year", "mean"),
              n_years=("year", "count")
          )
          .reset_index()
    )

    # Create base map centered on the region
    center_lat = agg["geo_y"].mean()
    center_lon = agg["geo_x"].mean()

    m = folium.Map(location=[center_lat, center_lon], zoom_start=5, tiles="CartoDB positron")

    # Add rivers from HydroRivers
    name = f'station_runoff_map.html'
    if rivers_path is not None:
        rivers_clipped = clip_rivers_to_stations(rivers_path, df_runoff)

        colormap = plt.cm.viridis
        norm = mcolors.Normalize(
            vmin=rivers_clipped["DIS_AV_CMS"].min(),
            vmax=rivers_clipped["DIS_AV_CMS"].max())


        # Add rivers to the map with color based on discharge
        for _, row in rivers_clipped.iterrows():
            discharge = row["DIS_AV_CMS"]
            color = mcolors.to_hex(colormap(norm(discharge)))

            folium.GeoJson(
                row["geometry"],
                style_function=lambda feature, color=color: {
                    "color": color,
                    "weight": 2,
                    "opacity": 0.8
                }
            ).add_to(m)


        # Define vmin and vmax
        vmin = rivers_clipped["DIS_AV_CMS"].min()
        vmax = rivers_clipped["DIS_AV_CMS"].max()

        ticks_raw = np.linspace(vmin, vmax, 5)
        ticks = [round(t, -int(np.floor(np.log10(t))) + 1) for t in ticks_raw]  # e.g., 72 → 70, 1502 → 1500

        norm = mcolors.Normalize(vmin=min(ticks), vmax=max(ticks))
        colors = [mcolors.to_hex(plt.cm.viridis(norm(t))) for t in ticks]

        # Add high-contrast legend
        discharge_colormap = LinearColormap(
            colors=colors,
            index=ticks,
            vmin=vmin,
            vmax=vmax,
            caption="Avg River Discharge (m³/s)"
        )
        discharge_colormap.add_to(m)

        name = f'station_runoff_map_with_rivers.html'

    # Include station
    if agg is not None:
        norm_years = LogNorm(vmin=agg["n_years"].min(), vmax=agg["n_years"].max())
        colormap = plt.cm.viridis  # or any other colormap

        for _, row in agg.iterrows():

            # Normalize color from colormap
            rgba = colormap(norm_years(row["n_years"]))
            hex_color = mcolors.to_hex(rgba)

            # Scale size (radius) based on average runoff
            radius = max(4, (row["avg_runoff"] / agg["avg_runoff"].max()) * 20)

            # Add circle marker
            folium.CircleMarker(
                location=[row["geo_y"], row["geo_x"]],
                radius=radius,
                color=hex_color,
                fill=True,
                fill_opacity=0.8,
                popup=(f"<b>{row['station_name']}</b><br>"
                       f"Avg Runoff: {row['avg_runoff']:.0f} mm/year<br>"
                       f"Years: {row['n_years']}")
            ).add_to(m)


        # FeatureGroup for removed stations
        removed_stations_group = folium.FeatureGroup(name="Removed Stations", show=True)

        # Plot removed due to missing area (red X)
        if stations_removed_area_df is not None and not stations_removed_area_df.empty:
            for _, row in stations_removed_area_df.iterrows():
                folium.RegularPolygonMarker(
                    location=[row["geo_y"], row["geo_x"]],
                    number_of_sides=4,
                    radius=5,
                    rotation=45,
                    color="red",
                    fill=True,
                    fill_color="red",
                    fill_opacity=1.0,
                    popup=f"<b>{row['station_name']}</b><br>Removed: Area ≤ 0"
                ).add_to(removed_stations_group)

        # Plot removed due to low flow (orange X)
        if stations_removed_lowflow_df is not None and not stations_removed_lowflow_df.empty:
            for _, row in stations_removed_lowflow_df.iterrows():
                folium.RegularPolygonMarker(
                    location=[row["geo_y"], row["geo_x"]],
                    number_of_sides=4,
                    radius=5,
                    rotation=45,
                    color="orange",
                    fill=True,
                    fill_color="orange",
                    fill_opacity=1.0,
                    popup=f"<b>{row['station_name']}</b><br>Removed: Flow < threshold"
                ).add_to(removed_stations_group)

        # Add group to map
        removed_stations_group.add_to(m)

        # Create a combined legend for stations color and size
        years_vals = np.percentile(agg["n_years"], [0, 50, 100]).round(0).astype(int)
        runoff_vals = np.percentile(agg["avg_runoff"], [0, 50, 100]).round(0)

        combined_legend_html = f'''
        <div style="
            position: fixed;
            bottom: 50px; left: 50px; width: 260px;
            background-color: white;
            border: 2px solid grey;
            z-index: 9999;
            font-size: 13px;
            padding: 12px;">

            <b style="font-size:14px;">Gauging Stations</b><br><br>

            <b>Years of Available Data (Color)</b><br>
            <i style="background: #440154; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[0]}<br>
            <i style="background: #21918c; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[1]}<br>
            <i style="background: #fde725; width: 18px; height: 18px;
               float: left; margin-right: 5px;"></i> {years_vals[2]}<br><br>

            <b>Avg Runoff (mm/year)</b><br>
            <svg width="120" height="90">
              <circle cx="30" cy="20" r="6" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="50" y="25" font-size="12">{int(runoff_vals[0])}</text>
              <circle cx="30" cy="45" r="10" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="50" y="50" font-size="12">{int(runoff_vals[1])}</text>
              <circle cx="30" cy="75" r="14" fill="#777" fill-opacity="0.7" stroke="#333" />
              <text x="50" y="80" font-size="12">{int(runoff_vals[2])}</text>
            </svg>

            <b>Removed Stations</b><br>
            <i class="fa fa-times" style="color:red; margin-right: 6px;"></i> Area missing<br>
            <i class="fa fa-times" style="color:orange; margin-right: 6px;"></i> Low flow
        </div>
        '''

        # Add both legend to the map
        m.get_root().html.add_child(folium.Element(combined_legend_html))

    # Include hydropower plants
    if data_hpp is not None:

        # Define fixed status colors
        custom_status_colors = {
            'operating': '#2ca02c',           # green
            'construction': '#ff7f0e',        # orange
            'announced': '#999999',           # grey
            'pre-construction': '#9467bd'     # purple
        }

        # Filter data_hpp for selected statuses only
        valid_statuses = list(custom_status_colors.keys())
        data_hpp = data_hpp[data_hpp['Status'].isin(valid_statuses)]


        # Normalize capacity for marker sizing
        cap_min = data_hpp['Capacity (MW)'].min()
        cap_max = data_hpp['Capacity (MW)'].max()

        def normalize_capacity(value):
            return 6 + ((value - cap_min) / (cap_max - cap_min + 1e-6)) * 10  # radius 4–14


        # Convert RGBA to HEX for Folium
        status_colors = {k: mcolors.to_hex(v) for k, v in custom_status_colors.items()}

        # Add each HPP as a square marker
        for _, row in data_hpp.iterrows():
            popup_html = "<br>".join([f"<b>{k}</b>: {v}" for k, v in row.items()])
            color = status_colors.get(row['Status'], '#555')
            size = normalize_capacity(row['Capacity (MW)'])

            folium.RegularPolygonMarker(
                location=[row['Latitude'], row['Longitude']],
                number_of_sides=4,  # square
                radius=size,
                color=None,
                fill=True,
                fill_color=color,
                fill_opacity=0.9,
                popup=folium.Popup(popup_html, max_width=300)
            ).add_to(m)

        # Prepare capacity values for size scale legend
        capacity_vals = np.percentile(data_hpp['Capacity (MW)'], [0, 50, 100]).round(0).astype(int)
        circle_sizes = [6, 10, 14]  # Match values used in marker radius scaling


        # Create status color legend HTML
        status_legend_items = ""
        for status in valid_statuses:
            hex_color = custom_status_colors[status]
            status_legend_items += f'''
                <i style="background:{hex_color}; width: 18px; height: 18px;
                float: left; margin-right: 6px;"></i> {status}<br>
            '''

        # Full HTML block
        hpp_legend_html = f'''
        <div style="
            position: fixed;
            bottom: 50px; left: 340px; width: 280px;
            background-color: white;
            border: 2px solid grey;
            z-index: 9999;
            font-size: 13px;
            padding: 12px;">

            <b style="font-size:14px;">Hydropower Plants</b><br><br>

            <b>Status (Color)</b><br>
            {status_legend_items}
            <br>

            <b>Capacity (MW)</b><br>
            <svg width="160" height="100">
              <rect x="20" y="10" width="{circle_sizes[0]*2}" height="{circle_sizes[0]*2}" fill="#999"/>
              <text x="60" y="25" font-size="12">{capacity_vals[0]}</text>
              <rect x="20" y="40" width="{circle_sizes[1]*2}" height="{circle_sizes[1]*2}" fill="#999"/>
              <text x="60" y="55" font-size="12">{capacity_vals[1]}</text>
              <rect x="20" y="70" width="{circle_sizes[2]*2}" height="{circle_sizes[2]*2}" fill="#999"/>
              <text x="60" y="85" font-size="12">{capacity_vals[2]}</text>
            </svg>
        </div>
        '''
        m.get_root().html.add_child(folium.Element(hpp_legend_html))


    # Step 4: Save or display
    m.save(os.path.join(folder_out, name))
    m


# Add info


make_maps(data_runoff, data_hpp=data_hpp, rivers_path = "data_grdc_hydro_capp/hydro_input/HydroRIVERS_v10_af_shp/HydroRIVERS_v10_af.shp")

## 7. Plot station runoff for each hydropower plant

In [118]:
stations = data_hpp['nearest_station'].unique()
# Filter data_runoff to include only stations that have hydropower plants associated
temp = data_runoff[data_runoff['station_name'].isin(stations)].copy()
# Subplots for each station and evolution of

In [120]:
def plot_runoff_evolution(data_runoff, data_hpp, folder_out=None):
    stations = data_hpp['nearest_station'].unique()
    temp = data_runoff[data_runoff['station_name'].isin(stations)].copy()

    # Map station_name to project name and river match info
    station_to_project = data_hpp.set_index('nearest_station')['Project Name'].to_dict()
    station_to_river_match = data_hpp.set_index('nearest_station')['same_river'].to_dict()

    temp['Project Name'] = temp['station_name'].map(station_to_project)
    temp['same_river'] = temp['station_name'].map(station_to_river_match)

    n_stations = len(stations)
    ncols = 3
    nrows = (n_stations + ncols - 1) // ncols

    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows), sharey=True)
    axs = axs.flatten()

    min_year = temp['year'].min()
    max_year = temp['year'].max()

    for i, station in enumerate(stations):
        ax = axs[i]
        subset = temp[temp['station_name'] == station]
        ax.plot(subset['year'].astype(int), subset['runoff_mm_year'], marker='o', linestyle='-')
        match_label = "Same River" if station_to_river_match[station] else "Closest Station"
        ax.set_title(f"{station} / {station_to_project[station]} ({match_label})", fontsize=10)
        ax.set_ylabel("Runoff (mm/year)")
        ax.set_xlim(min_year, max_year)
        ax.grid(True)

    # Remove any unused subplots
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])

    plt.tight_layout()
    if folder_out:
        plt.savefig(os.path.join(folder_out, 'runoff_evolution_with_river_info.png'))
        plt.close()
    else:
        plt.show()


plot_runoff_evolution(data_runoff, data_hpp, folder_out=folder_out)


In [138]:
def plot_discharge_month(data_station_filtered, data_hpp, folder_out=None):

    stations = data_hpp['nearest_station'].unique()
    temp = data_station_filtered[data_station_filtered['station_name'].isin(stations)].copy()

    # Map station_name to project name and river match info
    station_to_project = data_hpp.set_index('nearest_station')['Project Name'].to_dict()
    station_to_river_match = data_hpp.set_index('nearest_station')['same_river'].to_dict()

    temp['Project Name'] = temp['station_name'].map(station_to_project)
    temp['same_river'] = temp['station_name'].map(station_to_river_match)

    n_stations = len(stations)
    ncols = 3
    nrows = (n_stations + ncols - 1) // ncols

    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(18, 5 * nrows), sharey=False)
    axs = axs.flatten()

    for i, station in enumerate(stations):
        ax = axs[i]
        subset = temp[temp['station_name'] == station]
        subset = subset.groupby(['Project Name', 'month', 'same_river'])['Q'].mean().reset_index()
        subset['month_abbr'] = subset['month'].astype(int).apply(lambda x: calendar.month_abbr[x])
        ax.plot(subset['month_abbr'], subset['Q'], marker='o', linestyle='-')
        match_label = "Same River" if station_to_river_match[station] else "Closest Station"
        ax.set_title(f"{station} / {station_to_project[station]} ({match_label})", fontsize=10)
        ax.set_ylabel("Discharge (m3/s)")
        ax.grid(True)

    # Remove any unused subplots
    for j in range(i + 1, len(axs)):
        fig.delaxes(axs[j])

    plt.tight_layout()
    if folder_out:
        plt.savefig(os.path.join(folder_out, 'runoff_evolution_with_river_info.png'))
        plt.close()
    else:
        plt.show()

plot_discharge_month(data_station_filtered, data_hpp, folder_out=folder_out)

## Filling missing values (in progress)

In [15]:
def fill_missing_climatology(df, min_months=3):
    """
    Fill missing values using monthly climatology (station-wise).
    Returns a filled DataFrame and a mask of filled values.
    """
    df = df.copy()

    # Step 1: Filter out sparse years
    valid_years = (
        df.groupby(["station_name", "year"])["Q"]
        .apply(lambda x: x.notna().sum() >= min_months)
        .reset_index(name="keep")
    )

    # Merge to filter out sparse rows
    df = df.merge(valid_years[valid_years["keep"]], on=["station_name", "year"])
    df.drop(columns="keep", inplace=True)

    # Step 2: Build climatology
    climatology = (
        df.groupby(["station_name", "month"])["Q"]
        .mean()
        .rename("Q_clim")
        .reset_index()
    )

    df = df.merge(climatology, on=["station_name", "month"], how="left")

    # Step 3: Fill missing with climatology
    fill_mask = df["Q"].isna()
    df.loc[fill_mask, "Q"] = df.loc[fill_mask, "Q_clim"]

    df.drop(columns="Q_clim", inplace=True)

    return df, fill_mask

def drop_sparse_years_and_interpolate(df, min_months=9):
    """
    Drops years with too much missing data and interpolates gaps per station.
    Assumes monthly data. Returns cleaned & interpolated DataFrame.
    """
    df_clean = df.copy()

    # Count non-NaN entries per year-station
    valid_counts = (
        df_clean.groupby(["station_name", "year"])["Q"]
        .apply(lambda x: x.notna().sum())
        .rename("valid_months")
        .reset_index()
    )

    # Keep only rows with enough months
    valid_years = valid_counts[valid_counts["valid_months"] >= min_months]
    df_clean = df_clean.merge(valid_years[["station_name", "year"]], on=["station_name", "year"])

    # Sort for interpolation
    df_clean = df_clean.sort_values(["station_name", "year", "month"])

    # Interpolate per station
    df_clean["Q"] = df_clean.groupby("station_name")["Q"].transform(lambda x: x.interpolate(method='linear', limit_direction='both'))

    return df_clean

def plot_all_fill_methods(df_original, df_clim, df_interp, output_dir):
    """
    Plots original, climatology-filled, and interpolated runoff data per station in one plot.

    Assumes all DataFrames have columns: ['year', 'month', 'station_name', 'Q'].
    """
    os.makedirs(output_dir, exist_ok=True)
    output_dir = os.path.join(output_dir, 'stations')
    os.makedirs(output_dir, exist_ok=True)

    # Combine 'year' and 'month' into datetime
    def add_datetime(df):
        return df.assign(
            date=pd.to_datetime(df['year'].astype(str) + '-' + df['month'].astype(str).str.zfill(2))
        )

    df_original = add_datetime(df_original)
    df_clim = add_datetime(df_clim)
    df_interp = add_datetime(df_interp)

    # Loop over each station
    stations = df_original['station_name'].unique()

    for station in stations:
        fig, ax = plt.subplots(figsize=(12, 4))

        # Subsets for current station
        df_o = df_original[df_original['station_name'] == station]
        df_c = df_clim[df_clim['station_name'] == station]
        df_i = df_interp[df_interp['station_name'] == station]

        # Plot all
        ax.plot(df_o['date'], df_o['Q'], label="Original", alpha=0.5, marker='o', linestyle='-', color='black')
        ax.plot(df_c['date'], df_c['Q'], label="Climatology Fill", linestyle='--', color='orange')
        ax.plot(df_i['date'], df_i['Q'], label="Interpolated", linestyle='-', color='blue')

        ax.set_title(f"Station: {station}")
        ax.set_ylabel("Discharge / Runoff (Q)")
        ax.set_xlabel("Time")
        ax.legend()
        ax.grid(True)

        plt.tight_layout()
        plt.savefig(os.path.join(output_dir, f"{station}.png"))
        plt.close()


In [16]:
df_filled, filled_mask = fill_missing_climatology(df_filtered, min_months=6)
df_interpolated = drop_sparse_years_and_interpolate(df_filtered, min_months=9)
plot_all_fill_methods(df_filtered, df_filled, df_interpolated, output_dir=folder_out)

In [17]:
df

Unnamed: 0,time,id,Q,station_name,geo_x,geo_y,area,year,month,station_label
0,1903-01-01,1926500,,MANUEL CAROCA,6.650000,0.12000,103.199997,1903,1,"MANUEL CAROCA (0.12, 6.65)"
1,1903-01-01,1626100,,ANDOK-FOULA,10.230000,0.37000,1700.000000,1903,1,"ANDOK-FOULA (0.37, 10.23)"
2,1903-01-01,1643100,,LAMBARENE,10.230000,-0.68000,205000.000000,1903,1,"LAMBARENE (-0.68, 10.23)"
3,1903-01-01,1643150,,FOUGAMOU,10.593750,-1.21041,22000.000000,1903,1,"FOUGAMOU (-1.21, 10.59)"
4,1903-01-01,1643200,,NDJOLE,10.770000,-0.18000,158100.000000,1903,1,"NDJOLE (-0.18, 10.77)"
...,...,...,...,...,...,...,...,...,...,...
224347,2016-12-01,1670141,,MUYANGE_3,29.860600,-3.39190,1505.599976,2016,12,"MUYANGE_3 (-3.39, 29.86)"
224348,2016-12-01,1670145,,MUYANGE_4,29.809700,-3.52360,662.000000,2016,12,"MUYANGE_4 (-3.52, 29.81)"
224349,2016-12-01,1670150,,MUBUGA,30.041401,-3.37220,933.400024,2016,12,"MUBUGA (-3.37, 30.04)"
224350,2016-12-01,1670160,,NYANKANDA,30.318001,-3.29030,682.000000,2016,12,"NYANKANDA (-3.29, 30.32)"
