# Compare pywatershed and PRMS

This notebook compares pywatershed (PWS) and PRMS outputs. As part of the v1.0 release, it is meant to give users insight into how similar the results from pywatershed and PRMS are. They are identical or very close in the vast majority of cases. The notebook makes it easy to find where they are not by providing statistics at individual HRUs and timeseries for all HRUs.

Note that this notebook requires an editable install of pywatershed (`pip install -e .` in the pywatershed repository root) for the requisite data. PRMS/NHM domains which may be used are in the `test_data/` directory of pywatershed (`hru_1`, `drb_2yr`, and `ucb_2yr`), but any other domain may be used.

## Notes on setting up other domains
You may want to supply your own domain and see how pywatershed works on it. Here are notes on doing so. Each domain must supply the required files in `test_data/your_domain`, as given in this listing:

```
control.test  prcp.cbh      sf_data       tmax.nc       tmin.nc
myparam.param prcp.nc       tmax.cbh      tmin.cbh
```

The `*.cbh` files must be pre-converted to NetCDF for `prcp`, `tmin`, and `tmax`; how to do this is shown near the top of notebook 02. The `control.test` and `myparam.param` files are used by both PRMS and PWS.

The `control.test` files in the repo are set up specifically to allow running sub-models and include a nearly maximal amount of model output, which is time-inefficient for both PRMS and PWS. The stock control files can be found in `test_data/common`: there is one file for single-HRU domains and one for multi-HRU domains, and these are identical (as appropriate) for the domains included in the repository.

When running a large domain, for example, it is desirable to reduce the total amount of output, but this may prevent PWS sub-models from being run because PRMS does not necessarily supply all the required fields. You may therefore modify the `control.test` file, but take careful note of which options are available in pywatershed, as currently only the NHM configuration is available.
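
Before running a new domain it can help to confirm that the directory is complete. The following is a minimal sketch, not part of pywatershed: the helper name `missing_domain_files` and the constant `REQUIRED_DOMAIN_FILES` are hypothetical, and the file names are copied from the listing above.

```python
import pathlib as pl

# File names copied from the listing above (hypothetical constant name).
REQUIRED_DOMAIN_FILES = (
    "control.test",
    "myparam.param",
    "sf_data",
    "prcp.cbh",
    "tmax.cbh",
    "tmin.cbh",
    "prcp.nc",
    "tmax.nc",
    "tmin.nc",
)


def missing_domain_files(domain_dir):
    """Return the sorted names of required files absent from domain_dir."""
    domain_dir = pl.Path(domain_dir)
    return sorted(
        name
        for name in REQUIRED_DOMAIN_FILES
        if not (domain_dir / name).exists()
    )
```

Calling `missing_domain_files("test_data/your_domain")` returns an empty list when every file in the listing is present, and otherwise names the files that still need to be supplied or converted.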

The runs of PRMS use double precision binaries produced by the `prms_src/prms5.2.1` source code in the pywatershed repository. The procedure used below is exactly as done in CI for running regression tests against PRMS.

All of the code required for plotting below is included so that it can be further tailored to your tastes.

## Imports, etc

In [None]:
# auto-format the code in this notebook
%load_ext jupyter_black

In [None]:
import pathlib as pl
from platform import processor
from pprint import pprint
from shutil import rmtree
import subprocess
from sys import platform
import warnings

import hvplot.pandas  # noqa
import hvplot.xarray  # noqa
import numpy as np
import pandas as pd
import pywatershed as pws
import xarray as xr

repo_root = pws.constants.__pywatershed_root__.parent
nb_output_dir = pl.Path("./03_compare_pws_prms")

## Configuration

Specify what you want!

In [None]:
domain_name: str = "drb_2yr"  # must be present in test_data/domain_name
calc_method: str = "numba"
budget_type: str = None

run_prms: bool = True  # always forced: existing PRMS output is overwritten

run_pws: bool = True  # run pywatershed
force_pws_run: bool = True  # if the run directory exists, remove it and re-run (otherwise an error is raised)

## Run PRMS

In [None]:
domain_dir = repo_root / f"test_data/{domain_name}"

In [None]:
# use pytest to run the domains as in CI
if run_prms:
    print(f"PRMS running domain in {domain_dir}")
    subprocess.run(
        f"pytest -s -n=2 run_prms_domains.py --domain={domain_name} -vv --force",
        shell=True,
        cwd=repo_root / "test_data/generate",
    )

In [None]:
# Convert PRMS output to netcdf as in CI
if run_prms:
    if "conus" in domain_name:
        nproc = 2  # memory bound for CONUS
        conv_only = "::make_netcdf_files"
    else:
        nproc = 8  # processor bound otherwise
        conv_only = ""

    subprocess.run(
        f"pytest -n={nproc} convert_prms_output_to_nc.py{conv_only} --domain={domain_name} --force",
        shell=True,
        cwd=repo_root / "test_data/generate",
    )
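
The two cells above launch `pytest` through `subprocess.run` without inspecting the exit status, so a failed PRMS run or conversion could go unnoticed until the comparison step. A minimal sketch of one way to surface such failures follows; the helper `run_checked` is hypothetical and the command shown is only a stand-in so the example runs anywhere.

```python
import subprocess
import sys


def run_checked(cmd, cwd=None):
    """Run cmd (a list of arguments) and raise if it exits non-zero."""
    result = subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
    if result.returncode != 0:
        # Include the tail of the captured output to help diagnose the failure.
        raise RuntimeError(
            f"Command {cmd} failed with exit code {result.returncode}:\n"
            f"{result.stdout[-2000:]}{result.stderr[-2000:]}"
        )
    return result


# Stand-in command; in this notebook it would be the pytest invocation.
demo = run_checked([sys.executable, "-c", "print('ok')"])
```

Passing the command as a list of arguments also avoids `shell=True`, at the cost of splitting the command string by hand.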

## Run pywatershed

In [None]:
if run_pws:
    nhm_processes = [
        pws.PRMSSolarGeometry,  # sub-models are possible
        pws.PRMSAtmosphere,
        pws.PRMSCanopy,
        pws.PRMSSnow,
        pws.PRMSRunoff,
        pws.PRMSSoilzone,
        pws.PRMSGroundwater,
        pws.PRMSChannel,
    ]

    if len(nhm_processes) == 8:
        input_dir = domain_dir
        run_dir = nb_output_dir / f"{domain_name}_full_nhm"
    else:
        input_dir = domain_dir / "output"
        run_dir = nb_output_dir / f"{domain_name}_subset_nhm"

    control = pws.Control.load(domain_dir / "control.test")
    params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

    if run_dir.exists():
        if force_pws_run:
            rmtree(run_dir)
        else:
            raise RuntimeError("run directory exists")

    print(f"PWS writing output to {run_dir}")

    control.options = control.options | {
        "input_dir": input_dir,
        "budget_type": budget_type,
        "calc_method": calc_method,
        "netcdf_output_dir": run_dir,
    }

    nhm = pws.Model(
        nhm_processes,
        control=control,
        parameters=params,
    )
    nhm.run(finalize=True)

## Compare outputs
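
The statistics in this section are computed location-wise over time. As a reminder of the definitions, the dependency-free sketch below computes RMSE and relative RMSE for a single location and contrasts RMSE with the absolute mean error, which can be zero even when the two series differ. The function names and sample values are illustrative only and are not taken from pywatershed.

```python
import math


def rmse(sim, ref):
    """Root-mean-square error: square the differences, average, take the root."""
    squared = [(s - r) ** 2 for s, r in zip(sim, ref)]
    return math.sqrt(sum(squared) / len(squared))


def rrmse(sim, ref):
    """Relative RMSE: as rmse, but each difference is divided by ref first.

    Undefined wherever a reference value is zero.
    """
    squared = [((s - r) / r) ** 2 for s, r in zip(sim, ref)]
    return math.sqrt(sum(squared) / len(squared))


def abs_mean_error(sim, ref):
    """Absolute value of the mean difference (errors of opposite sign cancel)."""
    diffs = [s - r for s, r in zip(sim, ref)]
    return abs(sum(diffs) / len(diffs))


# Illustrative values for one location over two time steps.
pws_vals = [1.0, 3.0]
prms_vals = [2.0, 2.0]

print(rmse(pws_vals, prms_vals))  # 1.0
print(rrmse(pws_vals, prms_vals))  # 0.5
print(abs_mean_error(pws_vals, prms_vals))  # 0.0
```

The order of operations matters: squaring must happen before the mean is taken over time, otherwise positive and negative differences cancel and the statistic understates the disagreement.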

In [None]:
def compare_var_timeseries(var_name, rmse_min=None):
    """Plot PWS and PRMS timeseries of a variable for all locations in the domain (scrollable).

    Args:
        var_name: string name of the variable
        rmse_min: only plot locations where the RMSE between PRMS and PWS
            is at least this value

    """
    from textwrap import fill

    var_meta = pws.meta.find_variables(var_name)[var_name]
    ylabel = f"{fill(var_meta['desc'], 40)}\n({var_meta['units']})"

    prms_file = domain_dir / f"output/{var_name}.nc"
    if not prms_file.exists():
        return None
    prms_var = xr.open_dataarray(prms_file)
    pws_var = xr.open_dataarray(run_dir / f"{var_name}.nc")

    if rmse_min is not None:
        if "time" in prms_var.dims:
            time_dim = "time"
        else:
            time_dim = "doy"

        # square the differences before averaging over time, then take the root
        rmse = np.sqrt(((pws_var - prms_var) ** 2).mean(dim=time_dim))
        mask_ge_min = rmse >= rmse_min
        n_mask = len(np.where(mask_ge_min)[0])
        print(f"There are {n_mask} locations with RMSE >= {rmse_min}")
        if n_mask == 0:
            return None
        prms_var = prms_var.where(mask_ge_min, drop=True)
        pws_var = pws_var.where(mask_ge_min, drop=True)

    comp_ds = xr.merge(
        [
            prms_var.rename("prms"),
            pws_var.rename("pws"),
        ]
    )
    space_coord = list(comp_ds.coords)
    for t_coord in ["doy", "time"]:
        if t_coord in space_coord:
            space_coord.remove(t_coord)

    display(
        comp_ds.hvplot(
            frame_width=800,
            frame_height=500,
            groupby=space_coord,
            # title=title,
            ylabel=ylabel,
            group_label="Model",
        )
    )

In [None]:
def calc_stat_location(var_name, stat_name):
    """Calculate a statistic location-wise (over time).

    Args:
        var_name: str for the variable of interest
        stat_name: one of ["rmse", "rrmse"]
    """
    prms_file = domain_dir / f"output/{var_name}.nc"
    if not prms_file.exists():
        print(f"PRMS file '{prms_file}' DNE, skipping.")
        return None
    prms = xr.open_dataarray(prms_file, decode_timedelta=False)
    pws_file = run_dir / f"{var_name}.nc"
    assert pws_file.exists()
    nhm_after = xr.open_dataarray(pws_file, decode_timedelta=False)
    if "time" in prms.dims:
        time_dim = "time"
    else:
        time_dim = "doy"
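    diff = nhm_after - prms
    if stat_name.lower() == "rmse":
        # square the differences before averaging over time, then take the root
        stat = np.sqrt((diff**2).mean(dim=time_dim))
    elif stat_name.lower() == "rrmse":
        # relative differences are inf/nan wherever the PRMS value is zero
        stat = np.sqrt(((diff / prms) ** 2).mean(dim=time_dim))
    else:
        raise ValueError(f"stat_name must be 'rmse' or 'rrmse', got '{stat_name}'")
    return stat.to_dataframe().melt(ignore_index=False)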


def box_jitter_plot(
    df, subplot_width: int = 400, stat_name: str = "Statistic"
):
    """Violin plot of a dataframe with the individual points overlaid.

    Args:
        df: a pd.DataFrame
        subplot_width: int for how wide the subplots should be
        stat_name: str of the statistic name
    """
    from textwrap import fill

    var_name = df.variable.iloc[0]
    var_meta = pws.meta.find_variables(var_name)[var_name]
    ylabel = (
        f"{stat_name} of\n{fill(var_meta['desc'], 40)}\n({var_meta['units']})"
    )
    coord = df.index.name

    box = df.hvplot.violin(y="value", by="variable", legend=False)
    jitter = df.hvplot.scatter(
        y="value",
        x="variable",
        hover_cols=[coord],
    )
    return (box * jitter).opts(
        width=subplot_width,
        # xlabel=f"over {coord}s",
        xlabel="",
        ylabel=ylabel,
    )


def plot_proc_stats(
    proc, stat_name: str = "RMSE", ncols: int = 5, subplot_width: int = 300
):
    """Plot pywatershed process stats.

    For a process (e.g. pws.PRMSRunoff), make box/violin plots of its stats
    for each of its (available) variables.

    Args:
        proc: a pws.Process subclass
        stat_name: string of the statistic desired to be passed to box_jitter_plot
        ncols: int number of columns in the plot
        subplot_width: int width of the subplots

    """
    var_stats = []
    for var_name in proc.get_variables():
        var_stats += [calc_stat_location(var_name, stat_name)]

    var_plots = [
        box_jitter_plot(vv, subplot_width=subplot_width, stat_name=stat_name)
        for vv in var_stats
        if vv is not None
    ]
    if len(var_plots) == 0:
        return None

    plot = var_plots[0]
    for vv in var_plots[1:]:
        plot += vv

    plot = plot.opts(shared_axes=False)
    if len(var_plots) > 1:
        plot = plot.cols(ncols)

    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        display(plot)

In [None]:
if pws.PRMSSolarGeometry in nhm_processes:
    plot_proc_stats(pws.PRMSSolarGeometry, "RMSE", 5)

In [None]:
if pws.PRMSSolarGeometry in nhm_processes:
    compare_var_timeseries("soltab_potsw")

In [None]:
if pws.PRMSAtmosphere in nhm_processes:
    plot_proc_stats(pws.PRMSAtmosphere, "RMSE", 4)

In [None]:
if pws.PRMSAtmosphere in nhm_processes:
    compare_var_timeseries("tmaxf")

In [None]:
if pws.PRMSCanopy in nhm_processes:
    plot_proc_stats(pws.PRMSCanopy, "RMSE", 4)

In [None]:
if pws.PRMSCanopy in nhm_processes:
    compare_var_timeseries("intcp_stor")

In [None]:
if pws.PRMSSnow in nhm_processes:
    plot_proc_stats(pws.PRMSSnow, "RMSE", 4)

In [None]:
if pws.PRMSSnow in nhm_processes:
    compare_var_timeseries("pkwater_equiv")

In [None]:
if pws.PRMSRunoff in nhm_processes:
    plot_proc_stats(pws.PRMSRunoff, "RMSE", 4)

In [None]:
if pws.PRMSRunoff in nhm_processes:
    compare_var_timeseries("contrib_fraction")

In [None]:
if pws.PRMSSoilzone in nhm_processes:
    plot_proc_stats(pws.PRMSSoilzone, "RMSE", 4)

In [None]:
if pws.PRMSSoilzone in nhm_processes:
    compare_var_timeseries("soil_rechr")

In [None]:
if pws.PRMSGroundwater in nhm_processes:
    plot_proc_stats(pws.PRMSGroundwater, "RMSE", 4)

In [None]:
if pws.PRMSGroundwater in nhm_processes:
    compare_var_timeseries("gwres_flow_vol")

In [None]:
if pws.PRMSChannel in nhm_processes:
    plot_proc_stats(pws.PRMSChannel, "RMSE", 4)

In [None]:
if pws.PRMSChannel in nhm_processes:
    compare_var_timeseries("seg_outflow")  # , rmse_min=0.01)