# All Model Errors

In [1]:
from pathlib import Path
import os
import warnings

%load_ext autoreload
%autoreload 2

# ignore warnings for now ...
warnings.filterwarnings('ignore')

if Path('.').absolute().parents[1].name == 'ml_drought':
    os.chdir(Path('.').absolute().parents[1])

!pwd

/home/tommy/ml_drought


In [2]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import matplotlib as mpl
from tqdm import tqdm
from collections import defaultdict

mpl.rcParams['figure.dpi'] = 150

In [3]:
label_size = 14  # 10
plt.rcParams.update(
    {'axes.labelsize': label_size,
     'legend.fontsize': label_size,
     "font.size": 14,
    }
)

In [4]:
# data_dir = Path('data/')
data_dir = Path('/cats/datastore/data/')

assert data_dir.exists()

In [5]:
from src.utils import drop_nans_and_flatten

from src.analysis import read_train_data, read_test_data, read_pred_data
from src.analysis.evaluation import join_true_pred_da
from src.models import load_model

# Read in the CAMELS data

In [6]:
# read in the training data
ds = xr.open_dataset(data_dir / "RUNOFF/ALL_dynamic_ds.nc")
ds['station_id'] = ds['station_id'].astype(int)

all_static = xr.open_dataset(data_dir / f'RUNOFF/interim/static/data.nc')
all_static['station_id'] = all_static['station_id'].astype(int)
static = all_static

In [7]:
bool_wb = xr.open_dataset(data_dir / "RUNOFF/bool_water_balance_20pct.nc")
bool_wb = bool_wb.to_array().isel(variable=0).drop("variable")
bool_wb

# Read AWS Trained Models

In [8]:
lstm_ensemble_df_peti = pd.read_csv("/cats/datastore/data/runs/ensemble/data_ENS.csv").drop("Unnamed: 0", axis=1)
lstm_ensemble_df_peti["time"] = pd.to_datetime(lstm_ensemble_df_peti["time"])
lstm_ensemble_peti = lstm_ensemble_df_peti.set_index(["station_id", "time"]).to_xarray()
lstm_preds_peti = lstm_ensemble_peti

In [9]:
# ealstm_less_vars = pd.read_csv(data_dir / "runs/ealstm_less_vars_2004_1707_1424/results_ealstm_less_vars_2004_1707_1424_E015.csv")
# ealstm_preds = xr.open_dataset(data_dir / "runs/ensemble/data_ENS.csv")
ealstm_ensemble_df = pd.read_csv(data_dir / "runs/ensemble_EALSTM/data_ENS.csv").drop("Unnamed: 0", axis=1)
ealstm_ensemble_df["time"] = pd.to_datetime(ealstm_ensemble_df["time"])
ealstm_preds = ealstm_ensemble_df.set_index(["station_id", "time"]).to_xarray()

ealstm_preds["station_id"] = [int(sid) for sid in ealstm_preds["station_id"]]

In [10]:
lstm_ensemble_df = pd.read_csv("/cats/datastore/data/runs/ensemble_pet/data_ENS.csv").drop("Unnamed: 0", axis=1)
lstm_ensemble_df = pd.read_csv("/cats/datastore/data/runs/ensemble_pet_trainperiod/data_ENS.csv").drop("Unnamed: 0", axis=1)
lstm_ensemble_df["time"] = pd.to_datetime(lstm_ensemble_df["time"])
lstm_ensemble = lstm_ensemble_df.set_index(["station_id", "time"]).to_xarray()
lstm_preds = lstm_ensemble

In [11]:
metric_df = pd.read_csv(data_dir / "runs/ensemble_pet/metric_df.csv", index_col=0)
metric_df.columns = [c.lower() for c in metric_df.columns]
metric_df.head()

Unnamed: 0,station_id,nse,kge,mse,fhv,fms,flv
0,10002,0.90534,0.867626,0.21353,-8.473469,15.031782,12.874872
1,10003,0.92587,0.880633,0.113025,-5.930443,2.789133,37.174278
2,1001,0.87541,0.91654,0.512555,0.733456,-23.276409,65.910174
3,101002,0.757151,0.649167,0.300624,-24.224581,-2.596738,70.965051
4,101005,0.827427,0.814456,0.205723,-16.600439,2.974179,44.405252


# FUSE Data

In [12]:
all_paths = [d for d in (data_dir / "RUNOFF/FUSE/Timeseries_SimQ_Best/").glob("*_Best_Qsim.txt")]

if not (data_dir / "RUNOFF/ALL_fuse_ds.nc").exists():
    all_dfs = []
    for txt in tqdm(all_paths):
        df = pd.read_csv(txt, skiprows=3, header=0)
        df.columns = [c.rstrip().lstrip() for c in df.columns]
        df = df.rename(columns={"YYYY": "year", "MM": "month", "DD": "day"})
        df["time"] = pd.to_datetime(df[["year", "month", "day"]])
        station_id = int(str(txt).split("/")[-1].split("_")[0])
        df["station_id"] = [station_id for _ in range(len(df))]
        df = df.drop(["year", "month", "day", "HH"], axis=1).set_index(["station_id", "time"])
        all_dfs.append(df)
        
    fuse_ds = pd.concat(all_dfs).to_xarray()
    fuse_ds.to_netcdf(data_dir / "RUNOFF/ALL_fuse_ds.nc")
    
else:
    fuse_ds = xr.open_dataset(data_dir / "RUNOFF/ALL_fuse_ds.nc")

In [13]:
fuse_ds = fuse_ds.sel(time=slice('1998-01-01', '2009-01-01'))

In [14]:
# join with observations for stations that exist
obs = (
    ds.sel(station_id=np.isin(ds["station_id"], fuse_ds["station_id"]), time=fuse_ds["time"])["discharge_spec"]
).rename("obs")
fuse_data = fuse_ds.sel(station_id=obs.station_id).merge(obs)

# Test on only the water balancing stations?

In [15]:
ONLY_WATER_BALANCING = False

if ONLY_WATER_BALANCING:
    hold_out_fuse = fuse_data.sel(station_id=np.isin(fuse_data.station_id, bool_wb.where(~bool_wb, drop=True).station_id))
    fuse_data = fuse_data.sel(station_id=np.isin(fuse_data.station_id, bool_wb.where(bool_wb, drop=True).station_id))


# Match Stations / Times

In [16]:
if ONLY_WATER_BALANCING:
    hold_out_stations_lstm = np.isin(lstm_preds.station_id, hold_out_fuse.station_id)
    hold_out_stations_ealstm = np.isin(ealstm_preds.station_id, hold_out_fuse.station_id)
    
all_stations_lstm = np.isin(lstm_preds.station_id, fuse_data.station_id)
all_stations_ealstm = np.isin(ealstm_preds.station_id, fuse_data.station_id)

In [17]:
if ONLY_WATER_BALANCING:
    hold_out_lstm_preds = lstm_preds.sel(station_id=hold_out_stations_lstm, time=fuse_data.time)
    hold_out_lstm_preds = ealstm_preds.sel(station_id=hold_out_stations_ealstm, time=fuse_data.time)

lstm_preds = lstm_preds.sel(station_id=all_stations_lstm, time=fuse_data.time)
ealstm_preds = ealstm_preds.sel(station_id=all_stations_ealstm, time=fuse_data.time)

# Errors

In [18]:
!git pull
from scripts.drafts.calculate_error_scores import calculate_errors, error_func

Already up to date.


In [19]:
ealstm_df = calculate_errors(ealstm_preds).set_index("station_id")
lstm_df = calculate_errors(lstm_preds).set_index("station_id")

# lstm_df.to_csv(data_dir / "RUNOFF/PET_LSTM_RESULTS.csv")

In [20]:
metric_df["rmse"] = np.sqrt(metric_df["mse"])
lstm_df["rmse"] = np.sqrt(lstm_df["mse"])
ealstm_df["rmse"] = np.sqrt(ealstm_df["mse"])

display(metric_df.set_index("station_id").sort_index().head())
display(lstm_df.sort_index().head())

Unnamed: 0_level_0,nse,kge,mse,fhv,fms,flv,rmse
station_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
1001,0.87541,0.91654,0.512555,0.733456,-23.276409,65.910174,0.71593
2001,0.795517,0.734169,1.299555,-22.914962,-9.380016,15.606193,1.13998
2002,0.796522,0.704892,2.278604,-28.140312,-22.949306,61.070805,1.509504
3003,0.879771,0.849186,4.217607,-14.253567,-19.133229,45.2644,2.053681
4001,0.873894,0.879259,1.772178,-1.986958,-18.85153,44.173871,1.331232


Unnamed: 0_level_0,nse,kge,mse,bias,log_nse,inv_kge,abs_pct_bias,mape,mam30_ape,rmse
station_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1001,0.857709,0.866834,0.585375,-5.134048,0.81215,0.039699,5.134048,65.996419,146.056224,0.765098
2001,0.773687,0.728138,1.438293,-13.445427,0.870523,0.735259,13.445427,22.155156,24.44301,1.199288
2002,0.783678,0.695336,2.422432,-10.413825,0.844212,0.403838,10.413825,43.833926,55.081344,1.556416
3003,0.839633,0.785481,5.62565,-4.024232,0.866475,0.307938,4.024232,46.916948,46.595603,2.371845
4001,0.866291,0.860072,1.879023,-1.970809,0.79738,0.549027,1.970809,30.063662,28.471197,1.370775


In [21]:
lstm_preds

In [22]:
# for station_id in lstm_preds.station_id
station_id = lstm_preds.station_id.values[0]
df = lstm_preds.sel(station_id=station_id).to_dataframe().drop("station_id", axis=1)
# station_id
df.head()

Unnamed: 0_level_0,obs,sim
time,Unnamed: 1_level_1,Unnamed: 2_level_1
1998-01-01,2.14,1.525811
1998-01-02,1.81,1.758741
1998-01-03,4.57,2.450145
1998-01-04,3.3,2.389513
1998-01-05,3.1,2.043354


In [23]:
if False:
    import HydroErr as he
    import hydrostats as hs
    from typing import List
    from tqdm import tqdm

    all_result_dfs: List[pd.DataFrame] = []
    preds: xr.Dataset = lstm_preds
    # preds = fuse_data[["obs", "SimQ_TOPMODEL"]].rename({"SimQ_TOPMODEL": "TOPMODEL"})

    for ix, station_id in tqdm(enumerate(preds.station_id.values), desc="Calculating Metrics"):
    # station_id = lstm_preds.station_id.values[0]
        df = preds.sel(station_id=station_id).to_dataframe().drop("station_id", axis=1)

        epsilon = 1e-10

        result_df = hs.make_table(merged_dataframe=df, 
                      metrics=['NSE', 'KGE (2012)', 'MAPE'], 
                      seasonal_periods=[['12-01', '02-29'], ['03-01', '05-31'], ['06-01', '08-31'], ['09-01', '11-30']], 
                      remove_neg=True, remove_zero=False, 
                      location=station_id
        )

        inv_kge_df = hs.make_table(merged_dataframe=(1 / df + epsilon), 
                      metrics=['KGE (2012)'], 
                      seasonal_periods=[['12-01', '02-29'], ['03-01', '05-31'], ['06-01', '08-31'], ['09-01', '11-30']], 
                      remove_neg=True, remove_zero=False, 
                      location=station_id
        ).rename({"KGE (2012)": "invKGE"}, axis=1)
        log_nse_df = hs.make_table(merged_dataframe=np.log(df + epsilon), 
                      metrics=['NSE'], 
                      seasonal_periods=[['12-01', '02-29'], ['03-01', '05-31'], ['06-01', '08-31'], ['09-01', '11-30']], 
                      remove_neg=True, remove_zero=False, 
                      location=station_id
        ).rename({"NSE": "logNSE"}, axis=1)

        # join all error metrics together
        result_df = pd.concat([result_df, inv_kge_df.drop("Location", axis=1), log_nse_df.drop("Location", axis=1)], axis=1)

        # rename columns/rows
        result_df = result_df.rename({
            "Full Time Series": "All",
            "December-01:February-29": "DJF",
            "March-01:May-31": "MAM",
            "June-01:August-31": "JJA",
            "September-01:November-30": "SON",
        }).rename({"Location": "station_id", "KGE (2012)": "KGE"}, axis=1)
        result_df = result_df.reset_index().rename({"index": "period"}, axis=1)
        all_result_dfs.append(result_df)

        if ix == 5:
            break


    out_df = pd.concat(all_result_dfs)
    display(out_df.loc[out_df["period"] == "All"].set_index("station_id").head())

# FUSE - Calculate from Sim

In [24]:
!git pull
from scripts.drafts.calculate_error_scores import FuseErrors

if False:
    # 1. CALCULATE all error metric (slower code for same results)
    f_class = FuseErrors(fuse_data)
    fuse_errors = f_class.fuse_errors
        
    # 2. extract the error dfs 
    fuse_bias = f_class.get_metric_df("bias")
    fuse_nse_df = f_class.get_metric_df("nse")
    fuse_kge_df = f_class.get_metric_df("kge")
    print(fuse_kge_df.shape)
    print(fuse_kge_df.dropna().shape)
    fuse_kge_df.dropna().head()

Already up to date.


In [25]:
from scripts.drafts.calculate_error_scores import calculate_all_data_errors, get_metric_dataframes_from_output_dict

fuse_output_dict = calculate_all_data_errors(fuse_data)
fuse_metric_dict = get_metric_dataframes_from_output_dict(fuse_output_dict)

Errors: 100%|██████████| 4/4 [00:50<00:00, 12.63s/it]


In [26]:
if False:
    from src.analysis.evaluation import _kge_func


    fuse_data[["SimQ_SACRAMENTO", "obs"]]

    true_vals = fuse_data["SimQ_SACRAMENTO"].transpose("station_id", "time")
    pred_vals = fuse_data["obs"].transpose("station_id", "time")

    out = dict()
    for station_id in fuse_data.station_id.values:
        true_vals = fuse_data["SimQ_SACRAMENTO"].sel(station_id=station_id)
        pred_vals = fuse_data["obs"].sel(station_id=station_id)

        out[station_id] = _kge_func(true_vals.values, pred_vals.values)


    sacramento = pd.DataFrame(out, index=["kge"]).T
    f, ax = plt.subplots()
    sns.distplot(sacramento)
    sns.distplot(lstm_df["kge"])
    ax.axvline(sacramento["kge"].median(), color="C0")
    ax.axvline(lstm_df["kge"].median(), color="C1")
    sns.despine()

In [27]:
fuse_data

In [28]:
# fuse_output_dict
# fuse_metric_dict = get_metric_dataframes_from_output_dict(fuse_output_dict)

In [29]:
fuse_bias = fuse_metric_dict["bias"]
fuse_nse_df = fuse_metric_dict["nse"]
fuse_kge_df = fuse_metric_dict["kge"]
fuse_invkge_df = fuse_metric_dict["inv_kge"]
fuse_lognse_df = fuse_metric_dict["log_nse"]
fuse_mape_df = fuse_metric_dict["mape"]
fuse_abs_pct_bias_df = fuse_metric_dict["abs_pct_bias"]

In [30]:
topmodel_df = fuse_output_dict["TOPMODEL"]
vic_df = fuse_output_dict["VIC"]
sacramento_df = fuse_output_dict["SACRAMENTO"]
prms_df = fuse_output_dict["PRMS"]
display(prms_df.head())
display(lstm_df.head())

Unnamed: 0_level_0,nse,kge,mse,bias,log_nse,inv_kge,abs_pct_bias,mape,mam30_ape,rmse
station_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1001,0.786218,0.850647,0.879483,8.65518,0.491752,-68.734368,8.65518,86.843233,91.378756,0.879483
2001,0.686722,0.748769,1.99098,0.096087,0.35461,-4.973463,0.096087,45.787281,37.812513,1.99098
2002,0.683937,0.810254,3.539354,-4.450079,-0.010685,-46.244676,4.450079,57.844326,33.976457,3.539354
3003,0.793153,0.889991,7.25615,1.712348,-0.768618,-312.467467,1.712348,59.341185,61.818435,7.25615
4001,0.777022,0.828668,3.133536,-11.385971,0.435414,-1.608734,11.385971,32.896235,42.565224,3.133536


Unnamed: 0_level_0,nse,kge,mse,bias,log_nse,inv_kge,abs_pct_bias,mape,mam30_ape,rmse
station_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
1001,0.857709,0.866834,0.585375,-5.134048,0.81215,0.039699,5.134048,65.996419,146.056224,0.765098
2001,0.773687,0.728138,1.438293,-13.445427,0.870523,0.735259,13.445427,22.155156,24.44301,1.199288
2002,0.783678,0.695336,2.422432,-10.413825,0.844212,0.403838,10.413825,43.833926,55.081344,1.556416
3003,0.839633,0.785481,5.62565,-4.024232,0.866475,0.307938,4.024232,46.916948,46.595603,2.371845
4001,0.866291,0.860072,1.879023,-1.970809,0.79738,0.549027,1.970809,30.063662,28.471197,1.370775


### read published scores (best sim)

In [31]:
# !git pull

In [32]:
from scripts.drafts.calculate_error_scores import FUSEPublishedScores

In [33]:
fuse_dir = data_dir / "FUSE"
fpub = FUSEPublishedScores(fuse_dir)
pub_nse = fpub.read_nse_scores()
pub_best = fpub.read_best_scores()

In [34]:
# pub_nse

# All Errors in one data structure

In [35]:
from scripts.drafts.calculate_error_scores import DeltaError

# calculate all error metrics
processor = DeltaError(ealstm_preds, lstm_preds, fuse_data)
all_preds = processor.all_preds

In [36]:
# # plot_cdf()
if "all_errors" not in globals().keys():
    all_errors = calculate_all_data_errors(all_preds)
all_metrics = get_metric_dataframes_from_output_dict(all_errors)
all_metrics.keys()

Errors: 100%|██████████| 6/6 [01:08<00:00, 11.48s/it]


dict_keys(['nse', 'kge', 'mse', 'bias', 'log_nse', 'inv_kge', 'abs_pct_bias', 'mape', 'mam30_ape', 'rmse'])

## Kratzert errors

In [41]:
!git pull
from scripts.drafts.calculate_error_scores import DeltaError

processor = DeltaError(ealstm_preds, lstm_preds, fuse_data)

Already up to date.


In [38]:
if "kratzert_results" not in globals().keys():
    kratzert_results = processor.calc_kratzert_error_functions(all_preds, metrics=["flv"])

TypeError: calc_kratzert_error_functions() got an unexpected keyword argument 'metrics'

In [None]:
kratzert_metrics = get_metric_dataframes_from_output_dict(kratzert_results)
kratzert_metrics.keys()

In [None]:
summaries = ["median"]
kratzert_metrics["FLV"].describe().describe(percentiles=[0.05, 0.5, 0.95]).rename(
        {
            "5%": "q5",
            "50%": "median",
            "95%": "q95",
        }
    ).loc[summaries].T

In [None]:
# percent bias in overall runoff ratio


# All Errors Table

In [None]:
metrics = ["nse", "kge", "log_nse", "inv_kge", "bias", "abs_pct_bias", "mam30_ape"]
summaries = ["q5", "median"]

all_summary = []
for metric in metrics:
    summary_df = all_metrics[metric].describe(percentiles=[0.05, 0.5, 0.95]).rename(
        {
            "5%": "q5",
            "50%": "median",
            "95%": "q95",
        }
    ).loc[summaries].T

    summary_df.columns = pd.MultiIndex.from_arrays(
        ([metric.upper() for _ in range(len(summaries))], summaries)
    )

    all_summary.append(summary_df)
    
all_metric_summary = pd.concat(all_summary, axis=1)
all_metric_summary

In [None]:
print(all_metric_summary.to_latex(float_format="%.2f", multirow=True))

# Overall Model Performance Comparison
- Kolmogorov-Smirnov statistic on 2 samples
- This tests whether 2 samples are drawn from the same distribution
- Calculate the Wilcoxon signed-rank test
- tests the null hypothesis that two related paired samples come from the same distribution. In particular, it tests whether the distribution of the differences x - y is symmetric about zero. It is a non-parametric version of the paired T-test.

In [None]:
from scipy.stats import wilcoxon, ks_2samp
from collections import defaultdict

models = ["TOPMODEL", "PRMS", "SACRAMENTO", "ARNOVIC", "EALSTM"]
stations = lstm_df.index

comparison = fuse_nse_df.join(lstm_df["nse"]).rename({"nse": "LSTM"}, axis=1)
comparison = comparison.join(ealstm_df["nse"]).rename({"nse": "EALSTM"}, axis=1)

In [None]:
# f: FuseErrors
from typing import Callable, List


def _result_df(func: Callable, metric_df: pd.DataFrame, models: List[str], ref_model: str) -> pd.DataFrame:
    results = defaultdict(dict)
    for model in models:
        res_ = func(metric_df[model], metric_df[ref_model])
        results[model]["statistic"] = res_.statistic
        results[model]["pvalue"] = res_.pvalue
    
    return pd.DataFrame(results)


def create_joined_metric_df(metric: str):
    # build the dataframe of metrics (FUSE + LSTM + EALSTM)
    metric_df = fuse_metric_dict[metric].join(lstm_df[metric]).rename({metric: "LSTM"}, axis=1)
    metric_df = metric_df.join(ealstm_df[metric]).rename({metric: "EALSTM"}, axis=1)
    metric_df = metric_df.dropna()
    return metric_df 


def run_test(test: str = "ks",  metric: str = "nse", ref_model: str = "LSTM"):
    assert test in ["ks", "wilcoxon"]
    lookup = {"ks": ks_2samp, "wilcoxon": wilcoxon}
    func = lookup[test]
    
    # build the dataframe of metrics (FUSE + LSTM + EALSTM)
    metric_df = create_joined_metric_df(metric=metric)
    
    other_model = "EALSTM" if ref_model == "LSTM" else "LSTM"
    models = ["TOPMODEL", "PRMS", "SACRAMENTO", "ARNOVIC", other_model]
    # run the test
    df = _result_df(func, metric_df, models=models, ref_model=ref_model)
    return df

In [None]:
run_test("ks", "nse", "EALSTM")
run_test("ks", "nse", "LSTM")

In [None]:
# display(run_test("ks", "nse", "LSTM"))
display(run_test("ks", "nse", "LSTM"))

In [None]:
# Wilcoxon Test
data = create_joined_metric_df("nse")

wilcoxon_results = defaultdict(dict)
for model in models:
    res_ = wilcoxon(comparison.dropna()[model], comparison.dropna()["LSTM"])
    wilcoxon_results[model]["statistic"] = res_.statistic
    wilcoxon_results[model]["pvalue"] = res_.pvalue
    
pd.DataFrame(wilcoxon_results)

#### significant testing bias

In [None]:
comparison = fuse_bias.join(lstm_df["bias"]).rename({"bias": "LSTM"}, axis=1)
comparison = comparison.join(ealstm_df["bias"]).rename({"bias": "EALSTM"}, axis=1)

(create_joined_metric_df("bias").dropna() == comparison.dropna()).mean()

In [None]:
# KS Test

ks_2samp_results = defaultdict(dict)
for model in models:
    res_ = ks_2samp(comparison.dropna()[model], comparison.dropna()["LSTM"])
    ks_2samp_results[model]["statistic"] = res_.statistic
    ks_2samp_results[model]["pvalue"] = res_.pvalue
    
display(pd.DataFrame(ks_2samp_results))

run_test("ks", "bias", "LSTM")

In [None]:
# Wilcoxon Test
wilcoxon_results = defaultdict(dict)
for model in models:
    res_ = wilcoxon(comparison.dropna()[model], comparison.dropna()["LSTM"])
    wilcoxon_results[model]["statistic"] = res_.statistic
    wilcoxon_results[model]["pvalue"] = res_.pvalue
    
display(pd.DataFrame(wilcoxon_results))
run_test("wilcoxon", "bias", "LSTM")

# NSE

In [None]:
ml_sids = np.isin(lstm_df.index, [sid for sid in fuse_nse_df.dropna().index])
concept_sids = np.isin(fuse_nse_df.index, [sid for sid in fuse_nse_df.dropna().index])

test_sids = [id_ for id_ in lstm_df[ml_sids].index]
# try:
#     lstm_nse = lstm_df["nse"].set_index("station_id")
#     ealstm_nse = ealstm_df["nse"].set_index("station_id")
# except KeyError:
#     print("Already set Index to SID!")

In [None]:
try:
    lstm_kwargs = {
        "clip": [-0.5, 1],
        "label": f"LSTM: {(lstm_data).median()}",
        "ls": "-",
        "linewidth": 3,
    }
    ealstm_kwargs = {
        "clip": [-0.5, 1],
        "label": f"EALSTM: {(ealstm_data).median()}",
        "ls": "-",
        "linewidth": 3,
    }

    def cdf_plot(data: pd.Series, ax, clip: List = [None, None], kwargs: Dict = {}):
        sns.kdeplot(
            lstm_data,
            cumulative=True,
            legend=False, ax=ax,
            **kwargs
        )


    # def plot_all_cdfs(metric: str, test_sids: List[int]):

    # lstm_data = lstm_df.loc[test_sids, 'nse']
    # ealstm_data = ealstm_df.loc[test_sids, 'nse']
    # concept_data = fuse_nse_df.loc[test_sids]

    # fig, ax = plt.subplots(figsize=(12, 8))

except NameError:
    print("LSTM Data not yet defined")

In [None]:
from typing import Optional, List, Tuple

ml_sids = all_metrics["nse"].index.values

def plot_cdf(
    error_data, metric: str = "", 
    sids: List[int] = ml_sids, 
    clip: Optional[Tuple] = None, 
    ax = None, 
    title=None,
    models: Optional[List[str]] = None,
    median: bool = True
):
    colors = sns.color_palette()
    kwargs_dict = {
        "TOPMODEL": {"linewidth": 1, "alpha":0.8, "color": colors[2], "clip": clip},
        "PRMS": {"linewidth": 1, "alpha":0.8, "color": colors[3], "clip": clip},
        "ARNOVIC": {"linewidth": 1, "alpha":0.8, "color": colors[4], "clip": clip},
        "VIC": {"linewidth": 1, "alpha":0.8, "color": colors[4], "clip": clip},
        "SACRAMENTO": {"linewidth": 1, "alpha":0.8, "color": colors[5], "clip": clip},
        "gr4j": {"linewidth": 1, "alpha":0.8, "color": colors[9], "clip": clip},
        "climatology": {"linewidth": 1, "alpha":0.8, "color": colors[6], "clip": clip, "ls": "-."},
        "climatology_doy": {"linewidth": 1, "alpha":0.8, "color": colors[6], "clip": clip, "ls": "-."},
        "climatology_mon": {"linewidth": 1, "alpha":0.8, "color": colors[8], "clip": clip, "ls": "-."},
        "persistence": {"linewidth": 1, "alpha":0.8, "color": colors[7], "clip": clip, "ls": "-."},
        "EALSTM": {"linewidth": 3, "alpha": 1, "color": colors[1], "clip": clip},
        "LSTM": {"linewidth": 3, "alpha": 1, "color": colors[0], "clip": clip},
    }
    if ax is None:
        fig, ax = plt.subplots(figsize=(12, 3))
    
    if models is None:
        models = [c for c in error_data.columns if c in kwargs_dict]
    for ix, model in enumerate(models):
        summary_stat = error_data[model].dropna().median() if median else error_data[model].dropna().mean()
        sns.kdeplot(
            error_data[model].dropna(),
            cumulative=True,
            legend=False, ax=ax,
            label=f"{model}: {summary_stat:.2f}",
            **kwargs_dict[model]
        )

        ax.axvline(summary_stat, ls="--", color=kwargs_dict[model]["color"])

    ax.set_xlim(clip)
    ax.set_xlabel(metric)
    ax.set_ylabel("Cumulative density")
    title = title if title is not None else f"Cumuluative Density Function of Station {metric} Scores"
    ax.set_title(title)
    sns.despine()
    plt.legend()
    
    return ax

f, ax = plt.subplots(figsize=(12, 8))
plot_cdf(all_metrics["nse"], metric="NSE", title="", ax=ax, clip=(0, 1), median=True);

In [None]:
for column in all_metrics["nse"]:
    print(f"{column} NSE - {all_metrics['nse'][column].median():.2f}")

# Bias

In [None]:
lstm_df.head()
fuse_metric_dict.keys()

fuse_metric_dict["abs_pct_bias"].min()

In [None]:
f, ax = plt.subplots(figsize=(12, 8))
plot_cdf(all_metrics["abs_pct_bias"], metric="Absolute Percentage Bias [%]", title="", ax=ax, clip=(0, 50));

In [None]:
f, ax = plt.subplots(figsize=(12, 8))
plot_cdf(all_metrics["bias"], metric="Mean Bias [%]", title="", ax=ax, clip=(-30, 30));

In [None]:
ml_sids; concept_sids;

In [None]:
for column in all_metrics["bias"]:
    print(f"{column} Bias - {all_metrics['bias'][column].median():.2f}")

In [None]:
for column in all_metrics["abs_pct_bias"]:
    print(f"{column} Abs Bias (%) - {all_metrics['abs_pct_bias'][column].median():.2f}")

# KGE

In [None]:
f, ax = plt.subplots(figsize=(12, 8))
plot_cdf(all_metrics["kge"], metric="KGE", title="", ax=ax, clip=(0, 1), median=False);

In [None]:
print(f"LSTM KGE: {lstm_df['kge'].median():.3f}")
print(f"EALSTM KGE: {ealstm_df['kge'].median():.3f}")

for model in [c for c in fuse_nse_df.columns if (not "Name" in c) and (not "station" in c)]:
    print(f"{model} KGE: {fuse_kge_df[model].dropna().median():.3f}")

# Tables

In [None]:
catchment_ids = [int(c) for c in ["12002", "15006", "27009", "27034", "27041", "39001", "39081", "43021", "47001", "54001", "54057", "71001", "84013",]]
catchment_names = ["Dee@Park", "Tay@Ballathie", "Ouse@Skelton", "Ure@Kilgram", "Derwent@Buttercrambe", "Thames@Kingston", "Ock@Abingdon", "Avon@Knapp", "Tamar@Gunnislake", "Severn@Bewdley", "Severn@Haw", "Ribble@Samlesbury", "Clyde@Daldowie"]
station_map = dict(zip(catchment_ids, catchment_names))

In [None]:
process_errors = pd.read_csv(data_dir / "RUNOFF/jules_classic.csv")

classic = process_errors.loc[process_errors["Model"] == "Classic", :].drop('Model', axis=1)
classic = classic.rename(columns={"ID": "Station ID"}).set_index("Station ID")
classic.columns = [["CLASSIC" for _ in range(len(classic.columns))], classic.columns]

jules = process_errors.loc[process_errors["Model"] == "Jules", :].drop('Model', axis=1)
jules = jules.rename(columns={"ID": "Station ID"}).set_index("Station ID").drop("Name", axis=1)
jules.columns = [["JULES" for _ in range(len(jules.columns))], jules.columns]

process_errors = pd.concat([classic, jules], axis=1)
process_errors.to_pickle(data_dir / "RUNOFF/process_models.pkl")
process_errors

In [None]:
# NSE
nse = fuse_nse_df.join(
    ealstm_df["nse"].rename("EALSTM")
)
nse = nse.join(
    lstm_df["nse"].rename("LSTM")
)

# BIAS
bias = fuse_bias.join(
    ealstm_df["bias"].rename("EALSTM")
)
bias = bias.join(
    lstm_df["bias"].rename("LSTM")
)

# bias['Name'] = nse["Name"]
# bias = bias[["Name"] + [c for c in bias.columns if c != "Name"]]

In [None]:
nse_13 = nse.loc[catchment_ids]
nse_13.index.name = "Station ID"
bias_13 = bias.loc[catchment_ids]
bias_13.index.name = "Station ID"

In [None]:
columns = [
 "LSTM",
 "EALSTM",
 "CLASSIC",
 "JULES",
"PRMS",
"SACRAMENTO",
"TOPMODEL",
"ARNOVIC"]

all_bias = bias_13.join(process_errors.drop(["Name", "NSE"], axis=1, level=1).droplevel(axis=1, level=1))
all_bias = all_bias[columns]

all_bias = static["gauge_name"].to_dataframe("Name").join(all_bias).dropna()
all_bias.head()

In [None]:
all_nse = nse_13.join(process_errors.drop(["Bias", "Name"], axis=1, level=1).droplevel(axis=1, level=1))
all_nse = all_nse[columns]
all_nse = static["gauge_name"].to_dataframe("Name").join(all_nse).dropna()

all_nse.head()
all_nse

In [None]:
print(all_nse.to_latex(float_format="%.2f"))

In [None]:
print(all_bias.to_latex(float_format="%.2f"))

In [None]:
all_nse.to_csv(data_dir / "RUNOFF/all_nse.csv")
all_bias.to_csv(data_dir / "RUNOFF/all_bias.csv")

# Spatial Plots

In [None]:
vic_errors = fuse_output_dict["ARNOVIC"]  # f_class.get_model_df("VIC")
prms_errors = fuse_output_dict["PRMS"]  # f_class.get_model_df("PRMS")
top_errors = fuse_output_dict["TOPMODEL"] # f_class.get_model_df("TOPMODEL")
sac_errors = fuse_output_dict["SACRAMENTO"]  # f_class.get_model_df("Sacramento")

vic_errors
# fuse_errors.drop("Name", axis=1, level=1)

# Create Geospatial Map
- http://darribas.org/gds15/content/labs/lab_03.html

In [None]:
import geopandas as gpd

shp_path = data_dir / "CAMELS_GB_DATASET/Catchment_Boundaries/CAMELS_GB_catchment_boundaries.shp"
assert shp_path.exists()

# load in the shapefile
geo_df = gpd.read_file(shp_path)
geo_df['ID_STRING'] = geo_df['ID_STRING'].astype('int')
geo_df.crs = {'init' :'epsg:27700'}  # 4277  27700

# points_gdf = 
d = static[["gauge_lat", "gauge_lon"]].to_dataframe()
points = gpd.GeoSeries(gpd.points_from_xy(d["gauge_lon"], d["gauge_lat"]), index=d.index)
points.name = "geometry"

In [None]:
def create_spatial_dataframe(error_df: pd.DataFrame, geo_df: gpd.GeoDataFrame, polygon: bool = False) -> gpd.GeoDataFrame:
    assert error_df.index.dtype == geo_df['ID_STRING'].dtype, "Need to be the same type (integer)"
    error_gdf = gpd.GeoDataFrame(
        geo_df.set_index('ID_STRING').join(error_df)
    )
    if not polygon:
        static_df = static.to_dataframe()
        d = static_df[["gauge_lat", "gauge_lon"]]

        points = gpd.GeoSeries(gpd.points_from_xy(d["gauge_lon"], d["gauge_lat"]), index=d.index)
        points.name = "geometry"
        error_gdf = error_gdf.drop("geometry", axis=1).join(points)
        error_gdf.crs = {'init' :'epsg:4326'}
    else:
        error_gdf.crs = {'init' :'epsg:27700'}
        
    return error_gdf
    
    
lstm_gdf = create_spatial_dataframe(lstm_df, geo_df)
ealstm_gdf = create_spatial_dataframe(ealstm_df, geo_df)
vic_gdf  = create_spatial_dataframe(vic_errors, geo_df)
prms_gdf  = create_spatial_dataframe(prms_errors, geo_df)
top_gdf  = create_spatial_dataframe(top_errors, geo_df)
sac_gdf  = create_spatial_dataframe(sac_errors, geo_df)

In [None]:
# lstm_gdf.to_file(data_dir / "RUNOFF/shp_files/lstm.shp")
# ealstm_gdf.to_file(data_dir / "RUNOFF/shp_files/ealstm.shp")
# vic_gdf.to_file(data_dir / "RUNOFF/shp_files/vic.shp")
# prms_gdf.to_file(data_dir / "RUNOFF/shp_files/prms.shp")
# top_gdf.to_file(data_dir / "RUNOFF/shp_files/top.shp")
# sac_gdf.to_file(data_dir / "RUNOFF/shp_files/sac.shp")

## Get UK Boundaries

Get the COUNTY SHAPE data [here:](https://opendata.arcgis.com/datasets/1919db8ffcc5445ea4ba5b8a10acfccd_0.zip?outSR=%7B%22latestWkid%22%3A27700%2C%22wkid%22%3A27700%7D)
```
!wget https://opendata.arcgis.com/datasets/1919db8ffcc5445ea4ba5b8a10acfccd_0.zip
!unzip 1919db8ffcc5445ea4ba5b8a10acfccd_0.zip
!mkdir Counties_and_Unitary_Authorities_April_2019_Boundaries_EW_BFC
!mv Counties* Counties_and_Unitary_Authorities_April_2019_Boundaries_EW_BFC
```

Get all of these shapefiles and merge into one big polygon
```python
uk = gpd.read_file(data_dir / "RUNOFF/Counties_and_Unitary_Authorities_April_2019_Boundaries_EW_BFC/Counties_and_Unitary_Authorities_April_2019_Boundaries_EW_BFC.shp")
uk.plot()

from shapely.ops import unary_union  # cascaded_union, 
uk_bound = unary_union([p for p in uk.geometry])
uk_bound = gpd.GeoSeries(uk_bound)
```

[Link to CRS Discussion](https://communityhub.esriuk.com/geoxchange/2012/3/26/coordinate-systems-and-projections-for-beginners.html#:~:text=If%20you%20work%20with%20UK,that%20you%20should%20know%20about.&text=Web%20Mercator%20is%20a%20PCS,36%20used%20for%20British%20maps)


![title](https://static1.squarespace.com/static/55bb8935e4b046642e9d3fa7/55bb8e8ee4b03fcc125a74c0/55bb8e91e4b03fcc125a7a67/1331725592717/1000w/coordsys_diagram.png)


In [None]:
assert (data_dir / "RUNOFF/natural_earth_hires/ne_10m_admin_0_countries.shp").exists(), "Download the natural earth hires from https://www.naturalearthdata.com/http//www.naturalearthdata.com/download/10m/cultural/ne_10m_admin_0_countries.zip"

world = gpd.read_file(data_dir / "RUNOFF/natural_earth_hires/ne_10m_admin_0_countries.shp")
uk = world.query("ADM0_A3 == 'GBR'")
# uk.plot(facecolor='none', edgecolor='k')

In [None]:
from mpl_toolkits.axes_grid1 import make_axes_locatable
markersize = 10
# "scheme": "quantiles" if quantiles else None
opts = {
    "rmse": {"vmin": 0, "vmax": 1, 'cmap': 'viridis', "markersize": markersize},   # rmse   
    "rmse_norm": {"vmin": 0, "vmax": 0.5, 'cmap': 'viridis', "markersize": markersize},   # rmse   
    "nse": {"vmin": 0.7, "vmax": 1, 'cmap': 'viridis_r', "markersize": markersize},    # nse  8
    "log_nse": {"vmin": 0.7, "vmax": 1, 'cmap': 'viridis_r', "markersize": markersize},    # log nse  8
    "kge": {"vmin": 0.7, "vmax": 1, 'cmap': 'plasma_r', "markersize": markersize},    # kge   
    "inv_kge": {"vmin": 0.7, "vmax": 1, 'cmap': 'plasma_r', "markersize": markersize},    # inv kge   
    "mape": {"vmin": 0, "vmax": None, 'cmap': 'plasma_r', "markersize": markersize},    # inv kge   
    "bias": {"vmin": -20, "vmax": 20, 'cmap': 'RdBu', "markersize": markersize},    # bias  
    "abs_pct_bias": {"vmin": 0, "vmax": 50, 'cmap': 'RdBu', "markersize": markersize},    # abs_pct_bias
    "mam30_ape": {"vmin": 0, "vmax": 50, 'cmap': 'RdBu', "markersize": markersize},    # mam30_ape  
}

from typing import List 

def plot_geospatial_data(model_data, model: str, metrics: List[str] = ["nse", "bias", "kge"]):
    assert all(np.isin(metrics, model_data.columns))
    fig, axs = plt.subplots(1, 3, figsize=(5*3, 8))
    
    for ix, metric in enumerate(metrics):
        ax = axs[ix]
        divider = make_axes_locatable(ax)
        cax = divider.append_axes("right", size="5%", pad=0.001) # depends on the user needs
        # plot the chloropleth
        model_data.to_crs(epsg=4326).plot(metric, ax=ax, legend=True, cax=cax, **opts[metric]);
        
        # plot the surrounding lines
        uk.plot(facecolor='none', edgecolor='k', ax=ax, linewidth=0.3)
        
        ax.set_xlim([-8.2, 2.1])
        ax.set_ylim([50, 59.5])
        ax.axis('off');

        ax.set_title(metric.upper())

#     fig.tight_layout(rect=[0, 0.03, 1, 0.965]);
#     fig.suptitle(f"{model} Model Error", size=14);

In [None]:
ealstm_gdf.columns

In [None]:
metrics = ["kge", "inv_kge", "log_nse"]

plot_geospatial_data(ealstm_gdf, model="EALSTM", metrics=metrics)
plot_geospatial_data(lstm_gdf, model="LSTM", metrics=metrics)

In [None]:
plot_geospatial_data(vic_gdf, model="VIC", metrics=metrics)
plot_geospatial_data(prms_gdf, model="PRMS", metrics=metrics)
plot_geospatial_data(top_gdf, model="TOPMODEL", metrics=metrics)
plot_geospatial_data(sac_gdf, model="Sacramento", metrics=metrics)

# Explore Other Metrics

In [None]:
# # plot_cdf()
if "all_errors" not in globals().keys():
    all_errors = calculate_all_data_errors(all_preds)
all_metrics = get_metric_dataframes_from_output_dict(all_errors)
all_metrics.keys()

In [None]:
for metric in []:
f, ax = plt.subplots(figsize=(15, 8))
plot_cdf(all_metrics["mam30_ape"], ax=ax, metric="mam30_ape", clip=(0, 50), title="")

In [None]:
if False:
    ax = plot_cdf(q95_rmse_df, metric="RMSE", sids=ml_sids, clip=(0, 10))
    ax.set_title("CDF of Q95 RMSE")

    # fig, ax = plt.subplots(figsize=(12, 8))
    ax = plot_cdf(q5_rmse_df, metric="RMSE", sids=ml_sids, clip=(0, 2), ax=None)
    ax.set_title("CDF of Q5 RMSE")

In [None]:
pubs = (
    pub_nse.set_index("station_id")
    .rename({"NSE_TOPMODEL": "TOPMODEL", "NSE_PRMS": "PRMS", "NSE_VIC": "VIC", "NSE_SACRAMENTO": "SACRAMENTO"}, axis=1)
    .loc[:, ["TOPMODEL", "PRMS", "VIC", "SACRAMENTO"]]
)
all_pubs = pubs.join(lstm_df["nse"].rename("LSTM")).join(ealstm_df["nse"].rename("EALSTM")).dropna()

In [None]:
f, ax = plt.subplots(figsize=(12, 8))
ax = plot_cdf(all_pubs, metric="NSE", sids=ml_sids, clip=(0, 1), ax=ax)
ax.set_title("CDF of Published NSE Scores")

# Look at performance at either end of the extremes

## Calculate the Q5 / Q95 metrics

In [None]:
# all_preds 
from scripts.drafts.calculate_error_scores import DeltaError

# calculate all error metrics
processor = DeltaError(ealstm_preds, lstm_preds, fuse_data)
all_preds = processor.all_preds

In [None]:
# less than or equal Q5
q5_flows = all_preds.where(all_preds["obs"] <= all_preds["obs"].quantile(q=0.05, dim=["time"]))
# more than or equal Q95
q95_flows = all_preds.where(all_preds["obs"] >= all_preds["obs"].quantile(q=0.95, dim=["time"]))

In [None]:
f, axs = plt.subplots(2, 1, figsize=(12, 3))
ax = axs[0]
sns.distplot(q95_flows["obs"].values[~np.isnan(q95_flows["obs"].values)], ax=ax)
ax = axs[1]
sns.distplot(q5_flows["obs"].values[~np.isnan(q5_flows["obs"].values)], ax=ax)
plt.tight_layout()
sns.despine()

In [None]:
def calculate_all_errors_xr(all_simulations: xr.Dataset) -> xr.Dataset:
    assert all(np.isin(["obs"], all_simulations.data_vars))

    all_errors: List[xr.Dataset] = []
    for ix, model in enumerate(tqdm([v for v in all_simulations.data_vars if v != "obs"], desc="Calculating Errors")):
        preds = all_simulations[[model, "obs"]].rename({model: "sim"})
        errors_df = calculate_errors(preds)
        errors_df["model"] = [model for _ in range(len(errors_df))]
        error_xr = errors_df.set_index(["station_id", "model"]).to_xarray()
        all_errors.append(error_xr)
    all_errors = xr.combine_by_coords(all_errors)
    return all_errors

In [None]:
if False:
    q5_errors = calculate_all_errors_xr(q5_flows)
    q95_errors = calculate_all_errors_xr(q95_flows)
    
    q95_errors["rmse"] = np.sqrt(q95_errors["mse"])
    q5_errors["rmse"] = np.sqrt(q5_errors["mse"])

In [None]:
if False:
    all_q95_errors = q95_errors.to_dataframe().reset_index().set_index("station_id")
    # all_q95_errors["rmse"] = np.sqrt(all_q95_errors["mse"])
    all_q5_errors = q5_errors.to_dataframe().reset_index().set_index("station_id")
    # all_q5_errors["rmse"] = np.sqrt(all_q5_errors["mse"])

In [None]:
def get_metric_from_xarray_error(xr_errors: xr.Dataset, metric: str = "nse") -> pd.DataFrame:
    # convert to dataframe
    _errors = xr_errors.to_dataframe().reset_index().set_index("station_id")
    # extract the model
    df_errors = (
        _errors
        .loc[:, ["model", metric]]
        .reset_index()
        .set_index(["station_id", "model"])
        .unstack().droplevel(axis=1, level=0)
    )
    return df_errors

if False:
    # q95_nse_df = all_q95_errors.loc[:, ["model", "nse"]].reset_index().set_index(["station_id", "model"]).unstack().droplevel(axis=1, level=0)
    q95_nse_df = get_metric_from_xarray_error(q95_errors, metric="nse")
    q5_rmse_df = get_metric_from_xarray_error(q5_errors, metric="rmse")
    q95_rmse_df = get_metric_from_xarray_error(q95_errors, metric="rmse")
    q5_bias_df = get_metric_from_xarray_error(q5_errors, metric="bias")
    q95_bias_df = get_metric_from_xarray_error(q95_errors, metric="bias")
    q95_nse_df.head()

In [None]:

# ax = plot_cdf(q95_rmse_df, metric="RMSE", sids=ml_sids, clip=None)

In [None]:
lstm_data
