# Snow mass balance errors


In [None]:
# Auto-format the code cells in this notebook by loading the jupyter_black
# IPython extension (the extension must be installed in this kernel).
%load_ext jupyter_black

## Setup

In [None]:
import pathlib as pl
from pprint import pprint
from shutil import rmtree, copy2

import hvplot.xarray  # noqa
from IPython.display import display
import numpy as np
import pywatershed as pws
import xarray as xr

In [None]:
# Test domain used for every run below, located relative to the pywatershed
# package root, and the directory where this notebook writes its artifacts.
domain_name = "drb_2yr"
pws_root = pws.constants.__pywatershed_root__
domain_dir = pws_root / "../test_data" / domain_name

nb_output_dir = pl.Path("./snow_errors")
nb_output_dir.mkdir(exist_ok=True)

## Run PRMS in mixed and double precision and convert the output to netCDF

In [None]:
# PRMS 5.2.1 executables: one mixed-precision and one double-precision build.
bin_dir = pws_root / "../prms_src/prms5.2.1/bin/"
bin_mixed, bin_double = (
    bin_dir / f"prms_521_{precision}_mac_m1_intel"
    for precision in ("mixed", "double")
)

In [None]:
def _prms_setup_run_dir(binary: pl.Path, run_dir: pl.Path) -> pl.Path:
    """Create run_dir, copy the binary and domain inputs into it; return output dir."""
    run_dir.mkdir()  # must not exist, on user to delete
    copy2(binary, run_dir / binary.name)
    for ff in [
        "control.test",
        "myparam.param",
        "tmax.cbh",
        "tmin.cbh",
        "prcp.cbh",
        "sf_data",
    ]:
        copy2(domain_dir / ff, run_dir / ff)

    output_dir = run_dir / "output"
    output_dir.mkdir()
    return output_dir


def _prms_execute(binary: pl.Path, run_dir: pl.Path) -> None:
    """Run the PRMS binary inside run_dir, teeing combined output to run.log."""
    import subprocess

    exe_command = f"time ./{binary.name} control.test -MAXDATALNLEN 60000 2>&1 | tee run.log"
    # NOTE: a shell pipeline's exit status is that of its last command (tee),
    # so a failed PRMS run is not detected here; it will most likely surface
    # as missing output files in the conversion steps that follow.
    subprocess.run(
        exe_command,
        shell=True,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        cwd=run_dir,
    )


def _prms_outputs_to_netcdf(
    run_dir: pl.Path, output_dir: pl.Path, chunking: dict
) -> None:
    """Convert output CSVs (except stats.csv) and soltab_debug to netCDF."""
    from pywatershed import CsvFile, Soltab
    from pywatershed.parameters import PrmsParameters

    for csv_path in output_dir.glob("*.csv"):
        if csv_path.name == "stats.csv":
            continue
        CsvFile(csv_path).to_netcdf(
            csv_path.with_suffix(".nc"), chunk_sizes=chunking
        )

    # the nhm_ids are not available in the soltab_debug file currently, so
    # get them from the domain parameters
    params = PrmsParameters.load(run_dir / "myparam.param")
    soltab = Soltab(
        run_dir / "soltab_debug",
        output_dir=output_dir,
        nhm_ids=params.parameters["nhm_id"],
        chunk_sizes=chunking,
    )
    for var in soltab.variables:
        assert (output_dir / f"{var}.nc").exists()


def _prms_write_prev_vars(output_dir: pl.Path) -> None:
    """Write <var>_prev.nc: each variable shifted one step along axis 0, NaN first."""
    for vv in ["pk_ice", "freeh2o", "soil_moist"]:
        # load into memory and close the file instead of leaking the handle
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            current = ds[vv].load()
        prev_da = current.copy()
        prev_da[:] = np.roll(prev_da.values, 1, axis=0)
        assert (prev_da[1:, :].values == current[0:-1, :].values).all()
        # the first step has no predecessor; np.roll wrapped the last one there
        prev_da[0, :] = np.nan
        prev_da.rename(f"{vv}_prev").to_dataset().to_netcdf(
            output_dir / f"{vv}_prev.nc"
        )


def _prms_write_through_rain(output_dir: pl.Path) -> None:
    """Derive through_rain from previously written PRMS netCDF files."""
    dep_vars = [
        "pk_ice_prev",
        "freeh2o_prev",
        "newsnow",
        "pptmix_nopack",
        "net_rain",
    ]
    inputs = {}
    for vv in dep_vars:
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            inputs[vv] = ds[vv].load()

    zero = np.float64(0.0)
    epsilon64 = np.finfo(np.float64).eps

    # net_rain passes through where the previous-step pk_ice + freeh2o is
    # (numerically) zero and newsnow is not 1, or where pptmix_nopack is 1;
    # elsewhere through_rain is zero.
    wh_through = (
        ((inputs["pk_ice_prev"] + inputs["freeh2o_prev"]) <= epsilon64)
        & ~(inputs["newsnow"] == 1)
    ) | (inputs["pptmix_nopack"] == 1)

    through_rain = inputs["net_rain"].copy()
    through_rain[:] = np.where(wh_through, inputs["net_rain"], zero)
    through_rain.to_dataset(name="through_rain").to_netcdf(
        output_dir / "through_rain.nc"
    )


def run_prms(binary: pl.Path, run_dir: pl.Path):
    """Run a PRMS binary on the test domain and post-process its output.

    Steps:
      1. Create ``run_dir`` (it must not already exist) and copy the binary
         plus the domain's control, parameter, CBH and ``sf_data`` files in.
      2. Run the binary; combined stdout/stderr is also written to
         ``run_dir/run.log``.
      3. Convert every output CSV (except ``stats.csv``) and the
         ``soltab_debug`` file to netCDF under ``run_dir/output``.
      4. Write ``pk_ice_prev``, ``freeh2o_prev``, ``soil_moist_prev`` and
         ``through_rain`` netCDF files to ``run_dir/output``.

    Args:
        binary: path to the PRMS executable.
        run_dir: directory to create for this run.

    Raises:
        FileExistsError: if ``run_dir`` already exists (from ``Path.mkdir``).
    """
    # netCDF chunk sizes for the converted files (could make these arguments)
    chunking = {
        "time": 0,
        "doy": 0,
        "nhm_id": 100,
        "nhm_seg": 100,
    }

    output_dir = _prms_setup_run_dir(binary, run_dir)
    _prms_execute(binary, run_dir)
    _prms_outputs_to_netcdf(run_dir, output_dir, chunking)
    _prms_write_prev_vars(output_dir)
    _prms_write_through_rain(output_dir)

In [None]:
# Mixed-precision PRMS run (run directory must not already exist).
prms_mixed_run_dir = nb_output_dir / "prms_mixed_run"
run_prms(bin_mixed, prms_mixed_run_dir)

In [None]:
# Double-precision PRMS run; its netCDF outputs are staged as pywatershed
# inputs further down.
prms_dbl_run_dir = nb_output_dir.joinpath("prms_double_run")
run_prms(bin_double, prms_dbl_run_dir)

## Run pywatershed (snow process only)

In [None]:
# The model below contains only the snow process.
snow_process = pws.PRMSSnow
process = [snow_process]

In [None]:
# Stage the double-precision PRMS netCDF files (run dir and its output dir)
# as the pywatershed input directory.
pws_run_dir = nb_output_dir / "pws_run"
input_dir = pws_run_dir / "pws_input"
input_dir.mkdir(exist_ok=True, parents=True)

prms_nc_files = [
    *prms_dbl_run_dir.glob("*.nc"),
    *(prms_dbl_run_dir / "output").glob("*.nc"),
]
for nc_file in prms_nc_files:
    copy2(nc_file, input_dir / nc_file.name)

In [None]:
control = pws.Control.load(domain_dir / "control.test")
output_dir = pws_run_dir / "output"

# Overrides layered on top of the options read from control.test:
# read inputs from the staged PRMS files and write netCDF to output_dir.
# budget_type "warn" presumably reports rather than raises on budget
# imbalance — confirm against pywatershed docs.
run_options = {
    "input_dir": input_dir,
    "budget_type": "warn",
    "calc_method": "numpy",
    "netcdf_output_dir": output_dir,
}
control.options = control.options | run_options

params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

In [None]:
# Build a model from only the processes in `process` and run it through all
# timesteps, passing finalize=True to the run.
submodel = pws.Model(process, control=control, parameters=params)
submodel.run(finalize=True)

In [None]:
# Every snow-process variable must have a pywatershed output file; variables
# without a matching PRMS file in input_dir are only reported, not fatal.
for vv in process[0].get_variables():
    print(vv)
    assert (output_dir / f"{vv}.nc").exists()
    # explicit check instead of a bare `except:` around an assert, which
    # would also swallow KeyboardInterrupt and be skipped under `python -O`
    if not (input_dir / f"{vv}.nc").exists():
        print(f"********** {vv} not in input_dir")

## Start by comparing the budget variables

In [None]:
# Names of the snow process's mass-budget variables, grouped by term
# (e.g. "inputs", "outputs"); the non-input groups are compared below.
budget_terms = pws.PRMSSnow.get_mass_budget_terms()

In [None]:
# Additional (non-budget) variables to compare alongside the budget outputs.
# A new list is built instead of extending in place with `+=`, so the list
# returned by get_mass_budget_terms() is never mutated (in case it is shared),
# and names are de-duplicated so re-running this cell does not append again.
extra_output_vars = [
    "pk_ice_prev",
    "freeh2o_prev",
    "newsnow",
    "pptmix_nopack",
]
budget_terms["outputs"] = list(
    dict.fromkeys([*budget_terms["outputs"], *extra_output_vars])
)

In [None]:
# For each non-input budget variable, merge the pywatershed output (renamed
# "pws") and the PRMS output (renamed "prms") into one Dataset. The files are
# opened lazily and left open because the plots below read from them.
comparisons = {}
for term, var_names in budget_terms.items():  # not `vars`: avoid shadowing builtin
    if term == "inputs":
        continue
    print(term)
    for vv in var_names:
        print("    ", vv)

        pws_file = output_dir / f"{vv}.nc"
        assert pws_file.exists()
        pws_ds = xr.open_dataset(pws_file)[vv].rename("pws")

        prms_file = input_dir / f"{vv}.nc"
        assert prms_file.exists()
        prms_ds = xr.open_dataset(prms_file)[vv].rename("prms")

        comparisons[vv] = xr.merge([pws_ds, prms_ds])

In [None]:
# comparisons

In [None]:
def plot_var(var_name, diff=False, nhm_id: list = None):
    """Display an hvplot of one compared variable, grouped by nhm_id.

    Args:
        var_name: key into the notebook-level ``comparisons`` dict.
        diff: if True, plot ``pws - prms``; otherwise plot both series.
        nhm_id: optional collection of nhm_id values to restrict the plot to;
            ``None`` keeps all of them.

    Returns:
        None. The plot is shown with ``display``.
    """
    compared = comparisons[var_name]
    to_plot = compared["pws"] - compared["prms"] if diff else compared

    if nhm_id is not None:
        keep = to_plot.nhm_id.isin(nhm_id)
        to_plot = to_plot.where(keep, drop=True)

    figure = to_plot.hvplot(frame_width=700, title=var_name, groupby="nhm_id")
    display(figure)

In [None]:
def var_close(var_name, rtol=1.0e-2, atol=1.0e-2):
    """Plot ``pws - prms`` for the HRUs where a compared variable disagrees.

    A value is considered close when its absolute difference is below
    ``atol`` OR its difference relative to ``|prms|`` is below ``rtol``.
    Values that are NaN in both runs also count as close.

    Args:
        var_name: key into the notebook-level ``comparisons`` dict.
        rtol: relative tolerance (default 1e-2, as originally hard-coded).
        atol: absolute tolerance (default 1e-2, as originally hard-coded).

    Returns:
        numpy array of the nhm_id values with at least one value that is not
        close (empty when everything is close; nothing is plotted then).
    """
    var_ds = comparisons[var_name]
    pws_da = var_ds["pws"]
    prms_da = var_ds["prms"]

    abs_diff = abs(pws_da - prms_da)
    # Divide by |prms| so negative reference values (e.g. storage changes)
    # cannot yield negative "relative" differences that always pass rtol.
    # Zero reference values give inf/nan here, which simply fail the test.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_abs_diff = abs_diff / abs(prms_da)
    both_nan = pws_da.isnull() & prms_da.isnull()
    close = (abs_diff < atol) | (rel_abs_diff < rtol) | both_nan

    # Reduce over every dimension except nhm_id, instead of assuming nhm_id
    # is the second axis. (np.where(~close) returns one index array per
    # dimension, so its length was never 0 and the all-close case was missed.)
    other_dims = [dim for dim in close.dims if dim != "nhm_id"]
    hru_close = close.all(dim=other_dims) if other_dims else close
    not_close_ids = hru_close.nhm_id.values[~hru_close.values]

    if not_close_ids.size == 0:
        return not_close_ids

    # plot_var displays the plot itself and returns None; wrapping it in
    # display() would additionally display that None.
    plot_var(var_name, diff=True, nhm_id=not_close_ids)
    return not_close_ids

In [None]:
# Check every compared variable; var_close plots the pws - prms difference
# for the HRUs whose values are not within tolerance.
for var_name in comparisons:
    var_close(var_name)