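# Snow mass balance errors

Run PRMS (double precision) on a test domain, drive pywatershed's `PRMSSnow` with the PRMS outputs as inputs, and compare the two models' snow mass-budget and related variables to locate where they disagree.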


In [None]:
# Load the jupyter_black IPython extension so the code cells in this
# notebook are auto-formatted with black.
%load_ext jupyter_black

## Setup

In [None]:
import pathlib as pl
from pprint import pprint
from shutil import rmtree, copy2

import hvplot.xarray  # noqa
from icecream import ic
from IPython.display import display
import numpy as np
import pywatershed as pws
import xarray as xr

In [None]:
# Test domain (a directory under pywatershed's test_data) to analyze.
domain_name = "ucb_2yr"

# The PRMS run directories and the pywatershed run directory are all
# created under this directory.
nb_output_dir = pl.Path("./snow_errors")

# When True, reuse an existing run directory instead of recomputing it.
# (The mixed-precision flag has no effect while that run is disabled.)
skip_if_exists_prms_mixed = skip_if_exists_prms_double = skip_if_exists_pws = True

In [None]:
# Locate the test domain relative to the installed pywatershed package and
# make sure the notebook's output directory exists.
pws_root = pws.constants.__pywatershed_root__
domain_dir = pws_root / f"../test_data/{domain_name}"
nb_output_dir.mkdir(exist_ok=True)

# Numerical constants taken from pywatershed, so that thresholds used in
# this notebook match the model's own.
zero, epsilon64, epsilon32 = (
    pws.constants.zero,
    pws.constants.epsilon64,
    pws.constants.epsilon32,
)

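## Run PRMS (double precision) and convert its output to NetCDF

The mixed-precision run is currently disabled: its binary path and `run_prms` call are commented out below.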

In [None]:
# PRMS executables are expected in pywatershed's bin/ directory. The binary
# names suggest macOS (M1) builds -- other platforms need other binaries.
bin_dir = pws_root / "../bin/"
bin_double = bin_dir / "prms_mac_m1_ifort_dbl_prec"
# Mixed-precision binary, unused while the mixed run is disabled:
# bin_mixed = bin_dir / "prms_521_mixed_mac_m1_intel"

In [None]:
def run_prms(binary: pl.Path, run_dir: pl.Path, skip_if_exists=False):
    """Run a PRMS binary on the domain and post-process its output to NetCDF.

    ``run_dir`` is created fresh and populated with ``binary`` and the
    domain input files from the module-level ``domain_dir``. PRMS is then
    executed there. Its CSV output is converted to NetCDF, and these derived
    files are written to ``run_dir / "output"``:

    * solar tables, converted from ``soltab_debug``;
    * ``<var>_prev.nc`` and ``<var>_change.nc`` for pk_ice, freeh2o and
      soil_moist;
    * ``through_rain.nc``, derived from other PRMS snow outputs.

    Args:
        binary: path to the PRMS executable; it is copied into ``run_dir``.
        run_dir: directory to create for this run.
        skip_if_exists: if True and ``run_dir`` already holds a complete
            run, do nothing and reuse it.

    Returns:
        None.

    Raises:
        FileExistsError: if ``run_dir`` exists and either
            ``skip_if_exists`` is False, or the existing run is incomplete
            (for example, left behind by a failed earlier run). Delete
            ``run_dir`` and re-run in that case.
        RuntimeError: if PRMS produced no CSV output.
    """
    if skip_if_exists and run_dir.exists():
        # through_rain.nc is the last file written below. If it is missing,
        # an earlier call failed part-way, and silently reusing that
        # directory would feed broken data into the comparison.
        if (run_dir / "output" / "through_rain.nc").exists():
            print(
                f"Run ({run_dir}) already exists and skip_if_exists=True. Using existing run."
            )
            return None
        raise FileExistsError(
            f"Run ({run_dir}) exists but is incomplete (no "
            "output/through_rain.nc). Delete it and re-run."
        )

    # Chunk sizes for the NetCDF files written below (could be arguments).
    chunking = {
        "time": 0,
        "doy": 0,
        "nhm_id": 100,
        "nhm_seg": 100,
    }

    output_dir = _setup_prms_run_dir(binary, run_dir)
    _execute_prms(binary, run_dir)
    _prms_csvs_to_netcdf(output_dir, chunking)
    _prms_soltab_to_netcdf(run_dir, output_dir, chunking)
    _write_prev_and_change(output_dir)
    _write_through_rain(output_dir)
    return None


def _setup_prms_run_dir(binary: pl.Path, run_dir: pl.Path) -> pl.Path:
    """Create run_dir holding the binary and domain inputs; return output/."""
    run_dir.mkdir()  # must not exist; deleting old runs is up to the user
    copy2(binary, run_dir / binary.name)
    for file_name in [
        "control.test",
        "myparam.param",
        "tmax.cbh",
        "tmin.cbh",
        "prcp.cbh",
        "sf_data",
    ]:
        copy2(domain_dir / file_name, run_dir / file_name)

    output_dir = run_dir / "output"
    output_dir.mkdir()
    return output_dir


def _execute_prms(binary: pl.Path, run_dir: pl.Path) -> None:
    """Run the copied binary in run_dir, teeing stdout+stderr to run.log."""
    import subprocess

    # shell=True is needed for `time`, the 2>&1 redirect and the pipe.
    # NOTE: the pipeline's exit status is tee's, not PRMS's, so a failed run
    # is detected afterwards, when its expected output files are missing.
    exe_command = f"time ./{binary.name} control.test -MAXDATALNLEN 60000 2>&1 | tee run.log"
    subprocess.run(
        exe_command,
        shell=True,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
        cwd=run_dir,
    )


def _prms_csvs_to_netcdf(output_dir: pl.Path, chunking: dict) -> None:
    """Convert each PRMS CSV output (except stats.csv) to a sibling .nc."""
    from pywatershed import CsvFile

    csv_paths = sorted(output_dir.glob("*.csv"))
    if not csv_paths:
        raise RuntimeError(
            f"PRMS wrote no CSV output to {output_dir}; check run.log in "
            "the run directory."
        )
    for csv_path in csv_paths:
        if csv_path.name in ["stats.csv"]:
            continue
        CsvFile(csv_path).to_netcdf(
            csv_path.with_suffix(".nc"), chunk_sizes=chunking
        )


def _prms_soltab_to_netcdf(
    run_dir: pl.Path, output_dir: pl.Path, chunking: dict
) -> None:
    """Convert PRMS's soltab_debug solar tables to NetCDF in output_dir."""
    from pywatershed import Soltab
    from pywatershed.parameters import PrmsParameters

    # soltab_debug does not currently carry the nhm_ids, so take them from
    # the run's parameter file.
    nhm_ids = PrmsParameters.load(run_dir / "myparam.param").parameters[
        "nhm_id"
    ]
    soltab = Soltab(
        run_dir / "soltab_debug",
        output_dir=output_dir,
        nhm_ids=nhm_ids,
        chunk_sizes=chunking,
    )
    for var in soltab.variables:
        assert (output_dir / f"{var}.nc").exists()


def _write_prev_and_change(
    output_dir: pl.Path, var_names=("pk_ice", "freeh2o", "soil_moist")
) -> None:
    """Write <var>_prev.nc and <var>_change.nc for each variable in var_names.

    ``<var>_prev`` is the variable shifted by one step along axis 0, with
    the first step set to zero. ``<var>_change`` is ``<var> - <var>_prev``.
    """
    for vv in var_names:
        # Load into memory, then close the file rather than leaking it.
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            data = ds[vv].load()

        prev_da = data.copy()
        # Shift along axis 0 (assumed to be time) so prev[t] == data[t-1].
        prev_da[:] = np.roll(prev_da.values, 1, axis=0)
        assert (prev_da[1:, :].values == data[0:-1, :].values).all()
        # The first step has no predecessor. np.nan would be more honest,
        # but it makes plotting fail, so use zero.
        prev_da[0, :] = 0.0
        change_da = data - prev_da

        prev_da.rename(f"{vv}_prev").to_dataset().to_netcdf(
            output_dir / f"{vv}_prev.nc"
        )
        change_da.rename(f"{vv}_change").to_dataset().to_netcdf(
            output_dir / f"{vv}_change.nc"
        )


def _write_through_rain(output_dir: pl.Path) -> None:
    """Derive through_rain from PRMS snow outputs and write through_rain.nc."""
    dep_vars = [
        "net_ppt",
        "pptmix_nopack",
        "snowmelt",
        "pkwater_equiv",
        "snow_evap",
        "net_snow",
        "net_rain",
    ]
    data = {}
    for vv in dep_vars:
        with xr.open_dataset(output_dir / f"{vv}.nc") as ds:
            data[vv] = ds[vv].load()

    nearzero = 1.0e-6

    cond1 = data["net_ppt"] > zero
    cond2 = data["pptmix_nopack"] != 0
    cond3 = data["snowmelt"] < nearzero
    cond4 = data["pkwater_equiv"] < epsilon32
    cond5 = data["snow_evap"] < nearzero
    cond6 = data["net_snow"] < nearzero

    through_rain = data["net_rain"] * zero
    # Each later np.where overwrites the earlier result wherever its own
    # condition holds, so the rules are listed lowest priority first (the
    # reverse of an if/elif chain).
    through_rain[:] = np.where(
        cond1 & cond3 & cond4 & cond6, data["net_rain"], zero
    )
    through_rain[:] = np.where(
        cond1 & cond3 & cond4 & cond5, data["net_ppt"], through_rain
    )
    through_rain[:] = np.where(cond1 & cond2, data["net_rain"], through_rain)

    through_rain.to_dataset(name="through_rain").to_netcdf(
        output_dir / "through_rain.nc"
    )

In [None]:
# run_prms(bin_mixed, nb_output_dir / "prms_mixed_run")

In [None]:
# Double-precision PRMS run. Its NetCDF output is used both as the
# pywatershed inputs and as the reference values in the comparison below.
prms_dbl_run_dir = nb_output_dir / f"{domain_name}_prms_double_run"
run_prms(
    binary=bin_double,
    run_dir=prms_dbl_run_dir,
    skip_if_exists=skip_if_exists_prms_double,
)

## pywatershed run

In [None]:
# Layout of the pywatershed run. The PRMS outputs are staged in input_dir
# and read from there as model inputs.
pws_run_dir = nb_output_dir / f"{domain_name}_pws_run"
input_dir = pws_run_dir / "pws_input"

# Processes to run in pywatershed: only the snow process.
process = [pws.PRMSSnow]

In [None]:
output_dir = pws_run_dir / "output"
params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

# Start from the domain's PRMS control file and override the options that
# point pywatershed at the staged PRMS inputs and this notebook's output.
control = pws.Control.load(domain_dir / "control.test")
control.options = control.options | {
    "input_dir": input_dir,
    # presumably reports mass-budget imbalances rather than raising -- confirm
    "budget_type": "warn",
    "calc_method": "numpy",
    "netcdf_output_dir": output_dir,
}

In [None]:
if not (output_dir.exists() and skip_if_exists_pws):
    # Stage every NetCDF file from the PRMS double run (top level and
    # output/) as pywatershed inputs.
    input_dir.mkdir(exist_ok=True, parents=True)
    for source_dir in (prms_dbl_run_dir, prms_dbl_run_dir / "output"):
        for nc_file in source_dir.glob("*.nc"):
            copy2(nc_file, input_dir / nc_file.name)

    # Build and run a model made of only the processes in `process`.
    submodel = pws.Model(process, control=control, parameters=params)
    submodel.run(finalize=True)
else:
    print(
        f"Output ({output_dir}) already exists and skip_if_exists=True. Using existing run."
    )

In [None]:
# Every variable of the process under test must have been written by the
# pywatershed run. Variables missing from the PRMS-derived input_dir are
# reported but tolerated: they simply cannot be compared.
for vv in process[0].get_variables():
    print(vv)
    assert (
        output_dir / f"{vv}.nc"
    ).exists(), f"pywatershed did not write {vv}.nc to {output_dir}"
    # Explicit check instead of a bare `except:`, which would also swallow
    # unrelated errors such as KeyboardInterrupt.
    if not (input_dir / f"{vv}.nc").exists():
        print(f"********** {vv} not in input_dir")

## Start by comparing the budget variables

In [None]:
# Mass-budget variable names, keyed by term (e.g. "inputs", "outputs"),
# taken from the same process that was actually run. The lists are copied
# so that extending them later cannot mutate anything the class might
# share between calls.
budget_terms = {
    term: list(var_names)
    for term, var_names in process[0].get_mass_budget_terms().items()
}

In [None]:
# additional variables
# Additional snow variables to compare, beyond the mass-budget terms.
extra_outputs = [
    "albedo",
    "freeh2o",
    "freeh2o_change",
    "freeh2o_prev",
    "iso",
    "newsnow",
    "pk_def",
    "pk_den",
    "pk_depth",
    "pk_ice",
    "pk_ice_change",
    "pk_ice_prev",
    "pk_temp",
    "pkwater_ante",
    "pkwater_equiv",
    "pptmix_nopack",
    "pst",
    "snow_evap",
    "snowcov_area",
    "snowmelt",
    "tcal",
    "through_rain",
]
# Considered but not compared: ai, frac_swe, iasw, int_alb, lso, lst, mso,
# pk_precip, pksv, pkwater_equiv_change, pss, salb, scrv, slst,
# snowcov_areasv, snsv.

# Merge without duplicates, preserving order. Unlike an in-place `+=`,
# re-running this cell leaves the list unchanged.
budget_terms["outputs"] = list(
    dict.fromkeys(budget_terms["outputs"] + extra_outputs)
)

In [None]:
# For every non-input budget variable, pair the pywatershed output with the
# PRMS value (staged in input_dir) into one Dataset with "pws" and "prms"
# data variables. Inputs are skipped, presumably because they are PRMS
# values passed straight into pywatershed.
comparisons = {}
# `var_names`, not `vars`: a global `vars` would shadow the builtin for the
# rest of the kernel session.
for term, var_names in budget_terms.items():
    if term == "inputs":
        continue
    print(term)
    for vv in sorted(var_names):
        print("    ", vv)
        if vv in comparisons:
            # Already loaded under an earlier term.
            continue

        pws_file = output_dir / f"{vv}.nc"
        assert pws_file.exists(), f"missing pywatershed output: {pws_file}"
        pws_ds = xr.open_dataset(pws_file)[vv].rename("pws")

        prms_file = input_dir / f"{vv}.nc"
        assert prms_file.exists(), f"missing PRMS output: {prms_file}"
        prms_ds = xr.open_dataset(prms_file)[vv].rename("prms")

        comparisons[vv] = xr.merge([pws_ds, prms_ds])

In [None]:
# comparisons

In [None]:
def plot_var(var_name, diff=False, nhm_id: list = None):
    """Display an interactive hvplot of one compared variable, grouped by nhm_id.

    Args:
        var_name: key into the global ``comparisons`` dict. It is also
            looked up in pywatershed's variable metadata to build the
            y-axis label.
        diff: if True, plot only the error ``pws - prms`` instead of both
            series.
        nhm_id: optional collection of nhm_id values to restrict the plot
            to; ``None`` or empty means all.
    """
    from textwrap import fill

    var_meta = pws.meta.find_variables(var_name)[var_name]
    ylabel = f"{fill(var_meta['desc'], 40)}\n({var_meta['units']})"
    title = var_name
    plot_ds = comparisons[var_name]

    if diff:
        # Replace the two series with their difference.
        plot_ds = plot_ds.assign(
            error=plot_ds["pws"] - plot_ds["prms"]
        ).drop_vars(["pws", "prms"])
        ylabel = "Difference PWS - PRMS\n" + ylabel
        title = "ERRORS: Difference in " + title

    if nhm_id is not None and len(nhm_id):
        plot_ds = plot_ds.where(plot_ds.nhm_id.isin(nhm_id), drop=True)

    hv_plot = plot_ds.hvplot(
        frame_width=700,
        groupby="nhm_id",
        title=title,
        ylabel=ylabel,
    )
    display(hv_plot)

In [None]:
def var_close(var_name, rtol=0.02, atol=1.0e-2):
    """Plot one compared variable, focusing on nhm_ids where pws and PRMS differ.

    A value counts as close when ``|pws - prms| < atol`` or
    ``|pws - prms| / |prms| < rtol``. If every value is close, both series
    are plotted. Otherwise the difference is plotted for only those nhm_ids
    with at least one value that is not close.

    Args:
        var_name: key into the global ``comparisons`` dict.
        rtol: relative tolerance (default 0.02).
        atol: absolute tolerance (default 1.0e-2).
    """
    print(var_name)
    var_ds = comparisons[var_name]
    abs_diff = abs(var_ds["pws"] - var_ds["prms"])
    # Divide by |prms|. With the signed value, any negative PRMS value
    # (e.g. a storage change) gives a negative ratio that always passes
    # rtol and hides real errors. Division by zero gives inf/nan, which
    # compare False, so those points fall back to the atol test.
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_abs_diff = abs_diff / abs(var_ds["prms"])
    close = (abs_diff < atol) | (rel_abs_diff < rtol)

    if close.all():
        plot_var(var_name, diff=False)
        return

    # Collapse every non-nhm_id dimension so each offending nhm_id is
    # listed once, whatever the dimension order.
    not_close = ~close
    other_dims = [dim for dim in not_close.dims if dim != "nhm_id"]
    if other_dims:
        not_close = not_close.any(dim=other_dims)
    nhm_ids = not_close.nhm_id.values[not_close.values]
    plot_var(var_name, diff=True, nhm_id=nhm_ids)

In [None]:
# Check every compared variable. Each is plotted either as the raw pws/PRMS
# series (all close) or as their difference at the disagreeing nhm_ids.
for var_name in comparisons:
    var_close(var_name)