# Runoff mass balance errors


## Setup

In [None]:
import pathlib as pl
from pprint import pprint
from shutil import rmtree, copy2

import hvplot.xarray  # noqa
from IPython.display import display
import jupyter_black
import numpy as np
import pywatershed as pws
import xarray as xr

# Load the jupyter_black extension for this notebook session
# (presumably auto-formats code cells with black -- see its docs).
jupyter_black.load()

In [None]:
# Configuration
# Domain name: selects the test-data directory and names the run directories.
domain_name = "drb_2yr"

# Directory under which this notebook creates its run directories.
nb_output_dir = pl.Path("./runoff_errors")

In [None]:
# Paths derived from the pywatershed installation and the configuration above.
pws_root = pws.constants.__pywatershed_root__
domain_dir = pws_root / f"../test_data/{domain_name}"
nb_output_dir.mkdir(exist_ok=True)

# A float64 zero, plus machine epsilon for double and single precision.
zero = np.float64(0.0)
epsilon64 = np.finfo(np.float64).eps
epsilon32 = np.finfo(np.float32).eps

## Run PRMS in double precision and convert its output to NetCDF

In [None]:
# Directory containing the PRMS executables, relative to the pywatershed root.
bin_dir = pws_root / "../bin/"
# bin_mixed = bin_dir / "prms_mac_m1_ifort_mixed_prec"
# Double precision PRMS executable used for the reference run below.
bin_double = bin_dir / "prms_mac_m1_ifort_dbl_prec"

In [None]:
def run_prms(
    binary: pl.Path, run_dir: pl.Path, skip_if_exists: bool = False
) -> None:
    """Run a PRMS binary in a new directory and post-process its output.

    The binary and the domain input files (taken from the notebook-level
    ``domain_dir``) are copied into ``run_dir`` and PRMS is executed there.
    Every CSV file in ``run_dir / "output"`` except ``stats.csv`` is then
    converted to NetCDF, and these derived variables are written as
    additional NetCDF files in the same output directory:

    * ``<var>_prev`` and ``<var>_change`` for a fixed list of storage
      variables (previous-timestep value and timestep-to-timestep change),
    * ``through_rain``,
    * ``infil_hru`` (``infil`` scaled by the pervious fraction).

    Args:
        binary: Path to the PRMS executable to copy and run.
        run_dir: Directory to create for the run. It must not already exist
            unless ``skip_if_exists`` is True; deleting it is up to the user.
        skip_if_exists: If True and ``run_dir`` exists, print a message and
            return without running anything.

    Returns:
        None.

    Raises:
        FileExistsError: If ``run_dir`` exists and ``skip_if_exists`` is
            False (raised by ``mkdir``).
    """
    import subprocess

    from pywatershed import CsvFile

    if skip_if_exists and run_dir.exists():
        print(
            f"Run ({run_dir}) already exists and skip_if_exists=True. Using existing run."
        )
        return None

    run_dir.mkdir(parents=True)  # must not exist, on user to delete
    copy2(binary, run_dir / binary.name)
    for ff in [
        "nhm.control",
        "myparam.param",
        "tmax.cbh",
        "tmin.cbh",
        "prcp.cbh",
        "sf_data",
    ]:
        copy2(domain_dir / ff, run_dir / ff)

    output_dir = run_dir / "output"
    output_dir.mkdir()

    # The shell is needed for `time`, the redirection and the pipe to `tee`.
    # Because of the pipe, the shell's exit status is normally that of `tee`
    # rather than PRMS, so it is not inspected here; check run.log instead.
    exe_command = f"time ./{binary.name} nhm.control -MAXDATALNLEN 60000 2>&1 | tee run.log"
    subprocess.run(
        exe_command,
        shell=True,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        cwd=run_dir,
    )

    # convert to netcdf
    # could make these arguments
    chunking = {
        "time": 0,
        "doy": 0,
        "nhm_id": 0,
        "nhm_seg": 0,
    }

    for cc in output_dir.glob("*.csv"):
        if cc.name in ["stats.csv"]:
            continue
        nc_path = cc.with_suffix(".nc")
        CsvFile(cc).to_netcdf(nc_path, chunk_sizes=chunking)

    # previous and change variables
    for vv in [
        "pk_ice",
        "freeh2o",
        "soil_moist",
        "hru_impervstor",
        "dprst_stor_hru",
        "soil_lower",
        "soil_rechr",
    ]:
        # The context manager closes the source file once both derived
        # files have been written.
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            data = ds[vv]
            # Shift the series one step along the first axis so that each
            # entry holds the value of the preceding timestep.
            prev_da = data.copy()
            prev_da[:] = np.roll(prev_da.values, 1, axis=0)
            assert (prev_da[1:, :].values == data[0:-1, :].values).all()
            # The first timestep has no predecessor (np.roll wrapped the
            # last value around): use zero. np.nan would be better but
            # causes plotting to fail.
            prev_da[0, :] = zero
            change_da = data - prev_da

            prev_da.rename(f"{vv}_prev").to_dataset().to_netcdf(
                output_dir / f"{vv}_prev.nc"
            )
            change_da.rename(f"{vv}_change").to_dataset().to_netcdf(
                output_dir / f"{vv}_change.nc"
            )

    # through_rain
    dep_vars = [
        "net_ppt",
        "pptmix_nopack",
        "snowmelt",
        "pkwater_equiv",
        "snow_evap",
        "net_snow",
        "net_rain",
    ]
    data = {}
    for vv in dep_vars:
        # Load into memory so the file can be closed right away.
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            data[vv] = ds[vv].load()

    nearzero = 1.0e-6

    cond1 = data["net_ppt"] > zero
    cond2 = data["pptmix_nopack"] != 0
    cond3 = data["snowmelt"] < nearzero
    cond4 = data["pkwater_equiv"] < epsilon32
    cond5 = data["snow_evap"] < nearzero
    cond6 = data["net_snow"] < nearzero

    # Same shape and coordinates as net_rain, filled with zeros.
    through_rain = data["net_rain"] * zero
    # these are in reverse order: each assignment overwrites the previous
    # ones wherever its own condition holds, so the last has precedence.
    through_rain[:] = np.where(
        cond1 & cond3 & cond4 & cond6, data["net_rain"], zero
    )
    through_rain[:] = np.where(
        cond1 & cond3 & cond4 & cond5, data["net_ppt"], through_rain
    )
    through_rain[:] = np.where(cond1 & cond2, data["net_rain"], through_rain)

    through_rain.to_dataset(name="through_rain").to_netcdf(
        output_dir / "through_rain.nc"
    )

    # infil_hru: infil scaled by the pervious fraction, i.e. what remains
    # after removing the impervious and depression-storage fractions.
    params = pws.parameters.PrmsParameters.load(
        domain_dir / "myparam.param"
    ).parameters
    imperv_frac = params["hru_percent_imperv"]
    dprst_frac = params["dprst_frac"]
    perv_frac = 1.0 - imperv_frac - dprst_frac
    with xr.open_dataset(output_dir / "infil.nc") as ds:
        da = ds["infil"].rename("infil_hru")
        da *= perv_frac
        da.to_dataset().to_netcdf(output_dir / "infil_hru.nc")

In [None]:
# run_prms(
#     bin_mixed,
#     nb_output_dir / f"{domain_name}_prms_mixed_run",
#     skip_if_exists=skip_if_exists_prms_mixed,
# )

In [None]:
# %debug

In [None]:
# Run directory for the double precision PRMS run.
prms_dbl_run_dir = nb_output_dir / f"{domain_name}_prms_double_run"
# When True, an existing run directory is reused instead of re-running PRMS.
skip_if_exists_prms_double = True
run_prms(
    bin_double, prms_dbl_run_dir, skip_if_exists=skip_if_exists_prms_double
)

## Run pywatershed forced with output from the PRMS double precision run

In [None]:
# Process classes for the pywatershed model; only the runoff process here.
process = [pws.PRMSRunoff]
# Directory of the pywatershed run (its output goes in an "output" subdir).
pws_run_dir = nb_output_dir / f"{domain_name}_pws_run"
# Directory where the PRMS NetCDF files are gathered to serve as inputs.
input_dir_cp = prms_dbl_run_dir / "inputs"

In [None]:
# When True, an existing pywatershed output directory is reused.
skip_if_exists_pws = True
output_dir = pws_run_dir / "output"

# Start from the PRMS control file, then override the options for this run.
control = pws.Control.load_prms(domain_dir / "nhm.control")
option_overrides = {
    "input_dir": input_dir_cp,
    "budget_type": "error",
    "calc_method": "numpy",
    "netcdf_output_dir": output_dir,
}
control.options = control.options | option_overrides
# Remove the output-variable selection inherited from the control file.
del control.options["netcdf_output_var_names"]

params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

In [None]:
reuse_existing_output = skip_if_exists_pws and output_dir.exists()

if reuse_existing_output:
    print(
        f"Output ({output_dir}) already exists and skip_if_exists=True. Using existing run."
    )
else:
    # Gather the NetCDF files of the PRMS run (top level first, then its
    # output directory) into the directory pywatershed reads inputs from.
    input_dir_cp.mkdir(exist_ok=True, parents=True)
    for src_dir in (prms_dbl_run_dir, prms_dbl_run_dir / "output"):
        for ff in src_dir.glob("*.nc"):
            copy2(ff, input_dir_cp / ff.name)

    submodel = pws.Model(process, control=control, parameters=params)
    submodel.run(finalize=True)

In [None]:
# Every variable of the process must have been written by pywatershed.
# Variables without a PRMS counterpart in input_dir_cp are only reported:
# a plain existence test replaces the former assert wrapped in a bare
# `except:`, which swallowed every exception (including KeyboardInterrupt)
# and would never trigger when asserts are disabled.
for vv in process[0].get_variables():
    print(vv)
    assert (output_dir / f"{vv}.nc").exists()
    if not (input_dir_cp / f"{vv}.nc").exists():
        print(f"********** {vv} not in input_dir_cp")

## Start by comparing the budget variables

In [None]:
# Mass budget terms of the runoff process: a mapping from term name
# (e.g. "inputs", "outputs") to a collection of variable names.
budget_terms = process[0].get_mass_budget_terms()

In [None]:
# additional variables
extra_outputs = [
    "contrib_fraction",
    "infil",
    "infil_hru",
    "sroff",
    "hru_sroffp",
    "hru_sroffi",
    # "imperv_stor",
    # "imperv_evap",
    "hru_impervevap",
    "hru_impervstor",
    # "hru_impervstor_old",
    "hru_impervstor_change",
    # "dprst_vol_frac",
    # "dprst_vol_clos",
    # "dprst_vol_open",
    # "dprst_vol_clos_frac",
    # "dprst_vol_open_frac",
    # "dprst_area_clos",
    # "dprst_area_open",
    # "dprst_area_clos_max",
    # "dprst_area_open_max",
    "dprst_sroff_hru",
    "dprst_seep_hru",
    "dprst_evap_hru",
    "dprst_insroff_hru",
    "dprst_stor_hru",
    # "dprst_stor_hru_old",
    "dprst_stor_hru_change",
]
# Merge without duplicates. dict.fromkeys keeps first-seen order, so the
# comparison order is reproducible between sessions (a set is not), and a
# new list is built instead of mutating the one returned by the process.
budget_terms["outputs"] = list(
    dict.fromkeys([*budget_terms["outputs"], *extra_outputs])
)

In [None]:
# For every non-input budget variable, pair the pywatershed result ("pws")
# with the PRMS result ("prms") in one dataset keyed by variable name.
# The files stay open (lazy loading) because the datasets are used later.
comparisons = {}
for term, term_vars in budget_terms.items():  # avoid shadowing builtin vars
    if term == "inputs":
        continue
    print(term)
    for vv in term_vars:
        print("    ", vv)

        pws_file = output_dir / f"{vv}.nc"
        assert pws_file.exists()
        pws_ds = xr.open_dataset(pws_file)[vv].rename("pws")

        prms_file = input_dir_cp / f"{vv}.nc"
        assert prms_file.exists()
        prms_ds = xr.open_dataset(prms_file)[vv].rename("prms")

        comparisons[vv] = xr.merge([pws_ds, prms_ds])

In [None]:
# comparisons

In [None]:
def plot_var(var_name, diff=False, nhm_id: list = None):
    """Display an interactive plot of one compared variable.

    Args:
        var_name: Key into the notebook-level ``comparisons`` dict.
        diff: If True, plot ``error`` (pws - prms) and ``relative_error``
            (error / prms) instead of the two series themselves.
        nhm_id: Optional collection of nhm_id values to restrict the plot
            to; None or an empty collection keeps all of them.
    """
    from textwrap import fill

    var_meta = pws.meta.find_variables(var_name)[var_name]
    wrapped_desc = fill(var_meta["desc"], 40)
    ylabel = f"{wrapped_desc}\n({var_meta['units']})"
    title = var_name
    ds = comparisons[var_name]

    if diff:
        # Replace the two series by their difference and relative difference.
        error = ds["pws"] - ds["prms"]
        ds = ds.assign(
            error=error, relative_error=error / ds["prms"]
        ).drop_vars(["pws", "prms"])
        ylabel = "Difference PWS - PRMS\n" + ylabel
        title = "ERRORS: Difference in " + title

    if nhm_id is not None and len(nhm_id) > 0:
        ds = ds.where(ds.nhm_id.isin(nhm_id), drop=True)

    plot = ds.hvplot(
        frame_width=700,
        groupby="nhm_id",
        title=title,
        ylabel=ylabel,
    )
    display(plot)

In [None]:
def var_close(var_name):
    """Plot a compared variable, showing errors where pws and prms differ.

    Values count as close when the absolute difference is below ``atol`` or
    the difference relative to the magnitude of the prms value is below
    ``rtol`` (both 1.0e-7). If everything is close, both series are plotted;
    otherwise the errors are plotted for the nhm_ids that have at least one
    value that is not close.

    Args:
        var_name: Key into the notebook-level ``comparisons`` dict.
    """
    print(var_name)
    var_ds = comparisons[var_name]
    abs_diff = abs(var_ds["pws"] - var_ds["prms"])
    # Divide by the magnitude of the reference: with the signed value, any
    # negative reference gave a negative ratio that always passed the test.
    rel_abs_diff = abs_diff / abs(var_ds["prms"])
    rtol = atol = 1.0e-7
    close = (abs_diff < atol) | (rel_abs_diff < rtol)
    if close.all():
        plot_var(var_name, diff=False)
        return

    # Positions along the second axis (indexed as nhm_id, as before) with at
    # least one value that is not close; np.unique drops the repeats.
    not_close_inds = np.unique(np.where(~close)[1])
    nhm_ids = abs_diff.nhm_id[not_close_inds]
    plot_var(var_name, diff=True, nhm_id=nhm_ids)

In [None]:
# Check one variable on its own before looping over all of them.
var_close("hru_impervstor_change")

In [None]:
# Run the closeness check (and plot) for every compared variable.
for var_name in comparisons:
    var_close(var_name)

## Look at specific time and budget errors

Runoff mass balance errors have been solved.

In [None]:
# Start again from the process' budget terms and add, as inputs, the extra
# variables inspected in the cells below.
budget_terms = process[0].get_mass_budget_terms()
extra_inputs = [
    "net_ppt",
    "net_rain",
    "net_snow",
    "pptmix_nopack",
    "pk_ice_prev",
    "freeh2o_prev",
    "newsnow",
    "snow_evap",
]
budget_terms["inputs"] += extra_inputs

In [None]:
# Each case is (time stamp, positional indices along nhm_id to inspect).
budget_cases = [
    (
        "1979-01-11T00:00:00",
        [0, 1, 2, 3, 4, 5, 6, 9, 10, 11, 12, 13],
    ),
]

case_ind = 0
budget_time = np.datetime64(budget_cases[case_ind][0])
budget_location_inds = budget_cases[case_ind][1]

# For every budget variable, the pws and prms values at the selected time
# and locations.
budget_comps = {}
for term, term_vars in budget_terms.items():  # avoid shadowing builtin vars
    print(term)

    for vv in term_vars:
        print("    ", vv)

        # Inputs are read from the PRMS-derived input directory for both
        # sides, so their "pws" and "prms" entries come from the same file.
        pws_dir = input_dir_cp if term == "inputs" else output_dir
        pws_file = pws_dir / f"{vv}.nc"
        assert pws_file.exists()
        pws_ds = xr.open_dataset(pws_file)[vv].rename("pws")

        prms_file = input_dir_cp / f"{vv}.nc"
        assert prms_file.exists()
        prms_ds = xr.open_dataset(prms_file)[vv].rename("prms")

        budget_comps[vv] = (
            xr.merge([pws_ds, prms_ds])
            .sel(time=budget_time)
            .isel(nhm_id=budget_location_inds)
        )

In [None]:
bc = budget_comps

# Variable names making up each side of the runoff mass balance.
input_terms = ["through_rain", "snowmelt", "intcp_changeover"]
output_terms = [
    "hru_sroffi",
    "hru_sroffp",
    "dprst_sroff_hru",
    "infil_hru",
    "hru_impervevap",
    "dprst_seep_hru",
    "dprst_evap_hru",
]
storage_change_terms = ["hru_impervstor_change", "dprst_stor_hru_change"]

# Left-to-right sums, starting from the first term of each list.
inputs = sum((bc[name] for name in input_terms[1:]), bc[input_terms[0]])
outputs = sum((bc[name] for name in output_terms[1:]), bc[output_terms[0]])
storage_changes = sum(
    (bc[name] for name in storage_change_terms[1:]),
    bc[storage_change_terms[0]],
)
balance = inputs - outputs - storage_changes

In [None]:
# Print the aggregated PRMS budget terms at the selected time and locations.
print(f"{budget_location_inds=}")
print(f"{inputs.prms.values=}")
print(f"{outputs.prms.values=}")
print(f"{storage_changes.prms.values=}")

print("-----------")

# Individual terms, PRMS and pywatershed side by side where both are shown.
print(f'{bc["through_rain"].pws.values=}')

print(f'{bc["snow_evap"].prms.values=}')
print(f'{bc["hru_impervstor_change"].prms.values=}')
print(f'{bc["hru_impervstor_change"].pws.values=}')
print(f'{bc["dprst_stor_hru_change"].prms.values=}')
print(f'{bc["dprst_stor_hru_change"].pws.values=}')
# Residual of the PRMS mass balance (inputs - outputs - storage changes).
print(f"{balance.prms.values=}")

# print(f"{bc["hru_sroffi"].prms.sum().values=}")
# print(f"{bc["hru_sroffp"].prms.sum().values=}")
# print(f"{bc["dprst_sroff_hru"].prms.sum().values=}")
# print(f"{bc["infil_hru"].prms.sum().values=}")
# print(f"{bc["hru_impervevap"].prms.sum().values=}")
# print(f"{bc["dprst_seep_hru"].prms.sum().values=}")
# print(f"{bc["dprst_evap_hru"].prms.sum().values=}")

# print(f"{storage_changes.prms.values=}")

In [None]:
# Balance residual with the through_rain term subtracted, for each model.
print(f'{(balance - bc["through_rain"]).pws.values=}')
print(f'{(balance - bc["through_rain"]).prms.values=}')

In [None]:
# Compare the magnitude of the difference: with the signed difference, any
# pws balance below the prms balance passed regardless of its size.
(abs(balance.pws - balance.prms) < 1.0e-8).all().values

In [None]:
# Sum of the pywatershed balance residual over the selected locations.
balance.pws.sum()

In [None]:
# Print through_rain next to the variables used to inspect it. All of these
# were added to the budget "inputs", so their "pws" values are read from the
# same PRMS-derived files as their "prms" values.
print(f'{bc["through_rain"].pws.values=}')
print(f'{bc["net_rain"].pws.values=}')
print(f'{bc["net_snow"].pws.values=}')
print(f'{bc["net_ppt"].pws.values=}')
print(f'{bc["pptmix_nopack"].pws.values=}')
print(f'{bc["newsnow"].pws.values=}')
# True where the previous pack ice plus free water is below float32 epsilon.
print(
    f'{(bc["pk_ice_prev"].pws.values + bc["freeh2o_prev"].pws.values) < epsilon32=}'
)

In [None]:
# Element-wise largest magnitude among the input terms and among the output
# terms. The builtin max() cannot be used on xarray Datasets: it decides
# with bool(a > b), and a non-empty Dataset is always truthy, so it simply
# returned its last argument. Stack the terms along a new dimension and
# reduce over it instead.
input_max = xr.concat(
    [
        abs(bc["through_rain"]),
        abs(bc["snowmelt"]),
        abs(bc["intcp_changeover"]),
    ],
    dim="budget_term",
).max(dim="budget_term")
output_max = xr.concat(
    [
        abs(bc["hru_sroffi"]),
        abs(bc["hru_sroffp"]),
        abs(bc["dprst_sroff_hru"]),
        abs(bc["infil_hru"]),
        abs(bc["hru_impervevap"]),
        abs(bc["dprst_seep_hru"]),
        abs(bc["dprst_evap_hru"]),
    ],
    dim="budget_term",
).max(dim="budget_term")

In [None]:
# Show the pywatershed values of input_max as a plain list.
input_max.pws.values.tolist()

In [None]:
# Show the PRMS values of output_max as a plain list.
output_max.prms.values.tolist()

In [None]:
# Balance residual of each model divided by the overall maximum of that
# model's output_max values.
print((balance.pws / output_max.pws.max()).values)
print((balance.prms / output_max.prms.max()).values)