# National Performance (Part 1)

In [1]:
from pathlib import Path
import os
import warnings

%load_ext autoreload
%autoreload 2

# Suppress every warning for now.
warnings.filterwarnings(action='ignore')

# If the notebook was started two directory levels below the `ml_drought`
# folder, switch the working directory to that folder.
cwd = Path('.').absolute()
project_root = cwd.parents[1]
if project_root.name == 'ml_drought':
    os.chdir(project_root)

!pwd

/home/tommy/ml_drought


In [2]:
import xarray as xr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pickle
import seaborn as sns
import matplotlib as mpl
from tqdm import tqdm
from collections import defaultdict

# Set the figure resolution to 150 dpi via matplotlib's global rcParams.
mpl.rcParams['figure.dpi'] = 150

In [3]:
# Font size shared by axis labels and legend entries (10 noted as an alternative).
label_size = 14

# Apply the label size to the axis-label and legend settings.
for rc_key in ('axes.labelsize', 'legend.fontsize'):
    plt.rcParams[rc_key] = label_size

# Base font size for all remaining text elements.
plt.rcParams["font.size"] = 14

In [4]:
# Root directory for the RUNOFF inputs and the model runs read by this notebook.
# Defaults to the shared datastore; set the ML_DROUGHT_DATA_DIR environment
# variable to point elsewhere (e.g. the repo-relative 'data/' directory)
# without editing the notebook.
data_dir = Path(os.environ.get("ML_DROUGHT_DATA_DIR", "/cats/datastore/data/"))

# Fail early, and say which path was tried, if the directory is missing.
assert data_dir.exists(), f"data_dir does not exist: {data_dir}"

In [5]:
from src.utils import drop_nans_and_flatten

from src.analysis import read_train_data, read_test_data, read_pred_data
from src.analysis.evaluation import join_true_pred_da
from src.models import load_model

# Read in the CAMELS data

In [6]:
# Locations of the dynamic dataset (the training data) and the static dataset.
dynamic_path = data_dir / "RUNOFF/ALL_dynamic_ds.nc"
static_path = data_dir / "RUNOFF/interim/static/data.nc"

ds = xr.open_dataset(dynamic_path)
all_static = xr.open_dataset(static_path)

# Cast the station identifiers of both datasets to integers.
for dataset in (ds, all_static):
    dataset['station_id'] = dataset['station_id'].astype(int)

# `static` is a second name for the same object as `all_static`.
static = all_static

# Read all predictions

In [7]:
from scripts.drafts.io_results import read_ensemble_results, read_fuse_data
from scripts.drafts.calculate_error_scores import DeltaError

# SAVE: write freshly computed results to disk.
# RELOAD: rebuild the results from the raw runs instead of reading the cached file.
SAVE = True
RELOAD = False

# Cached netCDF file holding the combined predictions.
all_preds_path = data_dir / "RUNOFF/all_preds.nc"

if not RELOAD:
    # Fast path: read the previously saved predictions.
    all_preds = xr.open_dataset(all_preds_path)

else:
    # Ensemble predictions of the two neural-network models.
    pet_ealstm_ensemble_dir = data_dir / "runs/ensemble_pet_ealstm"
    ealstm_preds = read_ensemble_results(pet_ealstm_ensemble_dir)

    lstm_ensemble_dir = data_dir / "runs/ensemble_pet"
    lstm_preds = read_ensemble_results(lstm_ensemble_dir)

    # FUSE data, read together with the LSTM observations.
    raw_fuse_path = data_dir / "RUNOFF/FUSE"
    fuse_data = read_fuse_data(raw_fuse_path, lstm_preds["obs"])

    # Boolean masks marking the stations / timesteps also present in the FUSE data.
    all_stations_lstm = np.isin(lstm_preds.station_id, fuse_data.station_id)
    all_stations_ealstm = np.isin(ealstm_preds.station_id, fuse_data.station_id)
    lstm_time_mask = np.isin(lstm_preds.time, fuse_data.time)
    ealstm_time_mask = np.isin(ealstm_preds.time, fuse_data.time)

    # Restrict both prediction sets to the matching stations and timesteps.
    lstm_preds = lstm_preds.sel(station_id=all_stations_lstm, time=lstm_time_mask)
    ealstm_preds = ealstm_preds.sel(station_id=all_stations_ealstm, time=ealstm_time_mask)

    # NOTE(review): `processor.all_preds` presumably holds the predictions of all
    # models plus the benchmarks -- confirm against DeltaError.
    processor = DeltaError(
        ealstm_preds,
        lstm_preds,
        fuse_data,
        benchmark_calculation_ds=ds[["discharge_spec"]],
        incl_benchmarks=True,
    )
    all_preds = processor.all_preds

    if SAVE:
        all_preds.to_netcdf(all_preds_path)

all_preds

In [14]:
import pickle
from scripts.drafts.calculate_error_scores import calculate_all_data_errors, get_metric_dataframes_from_output_dict

# Set to True to recompute the error scores instead of reading the cached pickles.
RELOAD = False

# Cache files for the raw error output and the per-metric tables derived from it.
all_errors_path = data_dir / "RUNOFF/all_errors.pkl"
all_metrics_path = data_dir / "RUNOFF/all_metrics.pkl"

if RELOAD:
    all_errors = calculate_all_data_errors(all_preds, decompose_kge=True)
    all_metrics = get_metric_dataframes_from_output_dict(all_errors)
    if SAVE:
        # Context managers guarantee the files are flushed and closed,
        # even if pickling raises part-way through.
        with all_errors_path.open("wb") as fp:
            pickle.dump(all_errors, fp)
        with all_metrics_path.open("wb") as fp:
            pickle.dump(all_metrics, fp)

else:
    # NOTE: pickle.load can execute arbitrary code -- only load cache files
    # that were written by this notebook / a trusted source.
    with all_errors_path.open("rb") as fp:
        all_errors = pickle.load(fp)
    with all_metrics_path.open("rb") as fp:
        all_metrics = pickle.load(fp)

# Add percentage versions (value * 100) of the bias and std error metrics.
all_metrics["bias_error_pct"] = all_metrics["bias_error"] * 100
all_metrics["std_error_pct"] = all_metrics["std_error"] * 100

# All Errors Table

In [15]:
# Metrics to summarise, the model columns to include, and the summary
# statistics kept for each metric.
metrics = ["nse", "kge", "log_nse", "bias_error_pct", "std_error_pct", "correlation"]
models = ['TOPMODEL', 'ARNOVIC', 'PRMS', 'SACRAMENTO', 'EALSTM', 'LSTM']
summaries = ["q5", "median"]

# Short names for the percentile rows returned by `describe`.
percentile_labels = {"5%": "q5", "50%": "median", "95%": "q95"}

all_summary = []
for metric in metrics:
    # Describe each model column, keep only the requested statistics and
    # transpose so that models become the rows.
    described = all_metrics[metric][models].describe(percentiles=[0.05, 0.5, 0.95])
    summary_df = described.rename(percentile_labels).loc[summaries].T

    # Two-level column header: (metric name, summary statistic).
    summary_df.columns = pd.MultiIndex.from_arrays(
        [[metric] * len(summaries), summaries]
    )

    all_summary.append(summary_df)

# Place the per-metric tables side by side.
all_metric_summary = pd.concat(all_summary, axis=1)
all_metric_summary

Unnamed: 0_level_0,nse,nse,kge,kge,log_nse,log_nse,bias_error_pct,bias_error_pct,std_error_pct,std_error_pct,correlation,correlation
Unnamed: 0_level_1,q5,median,q5,median,q5,median,q5,median,q5,median,q5,median
TOPMODEL,0.37126,0.761844,0.50094,0.809525,-3.083626,0.547909,-18.500998,-3.895669,-31.001417,-9.50946,0.704736,0.882474
ARNOVIC,0.450332,0.780652,0.448192,0.799972,-1.260614,0.582283,-10.818944,6.311045,-26.821041,-10.016732,0.778176,0.896385
PRMS,-0.454427,0.766325,0.101257,0.827186,-4.664449,-0.313191,-12.176392,3.482145,-20.066977,-3.090016,0.730357,0.890359
SACRAMENTO,0.199348,0.799887,0.456775,0.839339,-2.481576,0.412521,-16.135613,-1.167484,-21.161546,-6.985332,0.737769,0.903701
EALSTM,0.686097,0.863691,0.652947,0.845071,0.623193,0.894506,-13.620205,-2.664854,-27.592921,-11.079071,0.867559,0.93748
LSTM,0.720685,0.884243,0.687628,0.873273,0.679941,0.915161,-12.040954,-1.984744,-23.594855,-8.560276,0.88096,0.945526


In [16]:
# Render the summary table as LaTeX source (floats formatted with "%.2f") and print it.
latex_table = all_metric_summary.to_latex(float_format="%.2f", multirow=True)
print(latex_table)

\begin{tabular}{lrrrrrrrrrrrr}
\toprule
{} & \multicolumn{2}{l}{nse} & \multicolumn{2}{l}{kge} & \multicolumn{2}{l}{log\_nse} & \multicolumn{2}{l}{bias\_error\_pct} & \multicolumn{2}{l}{std\_error\_pct} & \multicolumn{2}{l}{correlation} \\
{} &    q5 & median &   q5 & median &      q5 & median &             q5 & median &            q5 & median &          q5 & median \\
\midrule
TOPMODEL   &  0.37 &   0.76 & 0.50 &   0.81 &   -3.08 &   0.55 &         -18.50 &  -3.90 &        -31.00 &  -9.51 &        0.70 &   0.88 \\
ARNOVIC    &  0.45 &   0.78 & 0.45 &   0.80 &   -1.26 &   0.58 &         -10.82 &   6.31 &        -26.82 & -10.02 &        0.78 &   0.90 \\
PRMS       & -0.45 &   0.77 & 0.10 &   0.83 &   -4.66 &  -0.31 &         -12.18 &   3.48 &        -20.07 &  -3.09 &        0.73 &   0.89 \\
SACRAMENTO &  0.20 &   0.80 & 0.46 &   0.84 &   -2.48 &   0.41 &         -16.14 &  -1.17 &        -21.16 &  -6.99 &        0.74 &   0.90 \\
EALSTM     &  0.69 &   0.86 & 0.65 &   0.85 &    0.62 &   0

# CDF Plots

# Spatial Plots