In [1]:
# Compare all data sources
import pandas as pd
import matplotlib.pyplot as plt
import os

In [2]:
path_precip_chirps = 'CHIRPS/output/GN/CHIPRS_GN_precipitation_mm-year.csv'
path_precip_era5  = 'era5/output/GN/era5_GN_runoff_mm-year.csv'

folder = 'comparison'
if not os.path.exists(folder): os.makedirs(folder)

In [3]:
runoff_grun = pd.read_csv(path_runoff_grun, index_col=0, header=[0, 1])
runoff_grun = runoff_grun.droplevel(0, axis=1)
runoff_grdc = pd.read_csv(path_runoff_grdc, index_col=0, header=[0])
runoff_era5 = pd.read_csv(path_runoff_era5, index_col=0, header=[0, 1])
runoff_era5 = runoff_era5.droplevel(0, axis=1)

In [4]:
def get_common_dataframes(data_dict):
    """
    Align multiple DataFrames on shared rows (index) and columns.

    Parameters:
        data_dict (dict): Dictionary with names as keys and pandas DataFrames as values.
                          Each DataFrame should have the same structure: rows = years, columns = stations.

    Returns:
        - aligned_data (dict): Dictionary of filtered DataFrames with only common years and stations.
        - unmatched_columns (dict): Dictionary of unmatched station names per source.
    """
    # Get the set of common years (index) and common stations (columns)
    common_index = set.intersection(*[set(df.index) for df in data_dict.values()])
    common_columns = set.intersection(*[set(df.columns) for df in data_dict.values()])

    print(f"✅ Common years: {len(common_index)}")
    print(f"✅ Common stations: {len(common_columns)}")

    # Filter all dataframes
    aligned_data = {
        name: df.loc[sorted(common_index), sorted(common_columns)]
        for name, df in data_dict.items()
    }

    # Report stations that are missing in each dataset
    unmatched_columns = {
        name: set(df.columns) - common_columns
        for name, df in data_dict.items()
    }

    for name, unmatched in unmatched_columns.items():
        print(f"📌 Stations only in {name}: {unmatched}")

    return aligned_data, unmatched_columns

def calculate_station_error(grun_df, grdc_df, method="rmse"):
    """
    Compute error metrics (e.g. RMSE, MAE, correlation) per station.
    Assumes input DataFrames are aligned on years and stations.
    """
    import numpy as np
    results = {}

    for station in grun_df.columns:
        y_true = grdc_df[station]
        y_pred = grun_df[station]

        mask = y_true.notna() & y_pred.notna()
        if mask.sum() == 0:
            continue

        if method == "rmse":
            error = np.sqrt(np.mean((y_true[mask] - y_pred[mask])**2))
        elif method == "mae":
            error = np.mean(np.abs(y_true[mask] - y_pred[mask]))
        elif method == "corr":
            error = y_true[mask].corr(y_pred[mask])
        else:
            raise ValueError("Unsupported method")

        results[station] = error

    return pd.Series(results, name=method)

def plot_station_comparison_dict(data_dict, station, save_path=None):
    """
    Plot runoff time series from multiple datasets for a single station.

    Parameters:
        data_dict (dict): keys are dataset names (e.g., 'GRUN', 'GRDC'),
                          values are DataFrames with years as index and station names as columns
        station (str): station name to plot
        save_path (str or None): if provided, saves plot to file; else displays it
    """
    plt.figure(figsize=(10, 4))

    # Loop through all datasets and plot the station
    for label, df in data_dict.items():
        if station in df.columns:
            plt.plot(df.index, df[station], label=label, marker='o')
        else:
            print(f"⚠️ Station '{station}' not found in dataset '{label}' — skipping.")

    plt.title(f"Runoff Comparison for Station: {station}")
    plt.xlabel("Year")
    plt.ylabel("Runoff (mm/year)")
    plt.legend()
    plt.grid(True)

    if save_path:
        plt.savefig(save_path, dpi=300)
        plt.close()
    else:
        plt.show()

In [5]:
data_dict = {
    "GRUN": runoff_grun,
    "GRDC": runoff_grdc,
    "ERA5": runoff_era5
}

aligned_data, unmatched_columns = get_common_dataframes(data_dict)

grun_aligned = aligned_data["GRUN"]
grdc_aligned = aligned_data["GRDC"]
era5_aligned = aligned_data["ERA5"]

rmse_per_station = calculate_station_error(grun_aligned, grdc_aligned, method="rmse")
print(rmse_per_station.sort_values())

data_aligned_dict = {
    "GRUN": grun_aligned,
    "GRDC": grdc_aligned,
    "ERA5": era5_aligned
}

# Plot one station comparison
for station in grun_aligned.columns:
    plot_station_comparison_dict(data_aligned_dict, station=station, save_path=os.path.join(folder, f'comparison_{station}.png'))

✅ Common years: 65
✅ Common stations: 18
📌 Stations only in GRUN: set()
📌 Stations only in GRDC: {'GAOUAL'}
📌 Stations only in ERA5: set()
KONSANKORO            76.175538
MANDIANA              78.637768
DIALAKORO             81.604306
TIGUIBERY             99.933699
BARO                 122.448649
KANKAN               128.989196
TINKISSO             139.224900
FARANAH              150.216020
OUARAN               158.831804
DABOLA               165.981345
KOUROUSSA            200.474491
PONT DE TELIMELE     361.183856
BAC                  371.869157
NONGOA               412.496708
BADERA               422.344884
DIAWLA               766.345341
KISSIDOUGOU         1146.609602
KEROUANE            1843.311743
Name: rmse, dtype: float64
