# FlowGraph with STARFit Reservoir: Big Sandy Reservoir
1. Plot the domain, show where the reservoir should be
2. Run the FGR simulation, look at flows out of where reservoir should be
3. Insert the STARFit reservoir, compare flows.

TODO: Where will the FGR data reside? As release assets?

In [None]:
from copy import deepcopy
import pathlib as pl
import pickle
from pprint import pprint
from shutil import rmtree

import jupyter_black
import numpy as np
from tqdm.auto import tqdm
import xarray as xr

import hvplot.xarray  # noqa, after xr

import pywatershed as pws
from pywatershed.plot import DomainPlot
from pywatershed.constants import __pywatershed_root__ as repo_root
from pywatershed.constants import zero

jupyter_black.load()

In [None]:
# Directory that holds every output written by this notebook.
nb_output_dir = pl.Path("./06_flow_graph_starfit")
# exist_ok avoids the separate exists() check (and the race between the
# check and the creation).
nb_output_dir.mkdir(exist_ok=True)

## Big Sandy Reservoir

In [None]:
# Inputs for the STARFIT parameters: the GRanD reservoir shapefile and the
# ISTARF-CONUS table, both under a local data directory.
sf_data_dir = pl.Path("/Users/jmccreight/usgs/data/starfit_datasets/")
grand_file = sf_data_dir.joinpath("GRanD_Version_1_3/GRanD_reservoirs_v1_3.shp")
istarf_file = sf_data_dir.joinpath("ISTARF-CONUS.csv")

# Parameter set built from both files (no grand_ids subset requested here).
sf_params = pws.parameters.StarfitParameters.from_istarf_conus_grand(
    grand_file=grand_file,
    istarf_file=istarf_file,
)

In [None]:
# Locate the first reservoir whose GRanD name contains "big sandy"
# (case-insensitive), then look up its grand_id at that position.
grand_names = sf_params.parameters["GRanD_NAME"].tolist()
big_sandy_index = [
    idx
    for idx, name in enumerate(grand_names)
    if "big sandy" in str(name).lower()
][0]
big_sandy_grand_id = sf_params.parameters["grand_id"][big_sandy_index]

In [None]:
# get a parameter set with just the big sandy dike
# NOTE: grand_ids takes GRanD identifiers, so pass the grand_id looked up
# above rather than the positional index into the full parameter table.
sf_params = pws.parameters.StarfitParameters.from_istarf_conus_grand(
    grand_file=grand_file,
    istarf_file=istarf_file,
    grand_ids=[big_sandy_grand_id],
)

In [None]:
# Coordinates as reported by the selected STARFIT parameter set.
start_lat, start_lon = (
    sf_params.parameters["LAT_DD"],
    sf_params.parameters["LONG_DD"],
)
# Unfortunately the values above are for a different reservoir, so they are
# overridden with coordinates taken by hand from google maps.
# TODO: is the polygon correct?
start_lat, start_lon = 42.25547378652696, -109.43063080023737

In [None]:
# Root of the domain data and its GIS sub-directory.
domain_dir = pl.Path("/Users/jmccreight/usgs/data/pynhm/fgr")
domain_gis_dir = domain_dir.joinpath("GIS")

# PRMS control file for the domain.
control_file = domain_dir.joinpath("nhm.control")

# Shapefiles for the HRUs and the stream segments.
shp_file_hru = domain_gis_dir.joinpath("model_nhru.shp")
shp_file_seg = domain_gis_dir.joinpath("model_nsegment.shp")

In [None]:
# Plot the HRU and segment shapefiles, with the initial view set by the
# reservoir coordinates chosen above and a zoom of 13.
# add GRanD shp file? or add to the object afterwards? option to get polygons
# for sf_params above? but how to show connectivity?
DomainPlot(
    hru_shp_file=shp_file_hru,
    segment_shp_file=shp_file_seg,
    start_lat=start_lat,
    start_lon=start_lon,
    start_zoom=13,
)

From the plot above (mousing over the segments), we can see that the reservoir should be inserted above nhm_seg 44426 and below nhm_segs 44434 and 44435.

## Flaming Gorge Domain run with NHM and NO RESERVOIR
Let's look at the flows on segment 44426 with no reservoir present.

In [None]:
# Load the PRMS control file and limit the run to 365 * 2 time steps.
control = pws.Control.load_prms(control_file, warn_unused_options=False)
control.edit_n_time_steps(365 * 2)
# The parameter file name comes from the control file and is resolved
# relative to the domain directory.
parameter_file = domain_dir / control.options["parameter_file"]
params = pws.parameters.PrmsParameters.load(parameter_file)

In [None]:
# # run just once
# cbh_nc_dir = domain_dir
# cbh_files = [
#     domain_dir / "prcp.cbh",
#     domain_dir / "tmax.cbh",
#     domain_dir / "tmin.cbh",
# ]

# params = pws.parameters.PrmsParameters.load(domain_dir / "myparam.param")

# for cbh_file in cbh_files:
#     out_file = cbh_nc_dir / cbh_file.with_suffix(".nc").name
#     pws.utils.cbh_file_to_netcdf(cbh_file, params, out_file)

In [None]:
# Processes for the full NHM run, including PRMSChannel routing.
nhm_processes = [
    pws.PRMSSolarGeometry,
    pws.PRMSAtmosphere,
    pws.PRMSCanopy,
    pws.PRMSSnow,
    pws.PRMSRunoff,
    pws.PRMSSoilzone,
    pws.PRMSGroundwater,
    pws.PRMSChannel,
]

# Outputs of this run are written below the notebook output directory.
run_dir = nb_output_dir / "fgr_nhm"

# we'll use the to-channel fluxes later when running FlowGraph as a post-process
control.options["netcdf_output_var_names"] = [
    "seg_outflow",
    "sroff_vol",
    "ssres_flow_vol",
    "gwres_flow_vol",
]

# Layer the run-specific options over those read from the control file.
control.options = control.options | {
    "input_dir": domain_dir,
    "budget_type": "error",
    "calc_method": "numba",
    "netcdf_output_dir": run_dir,
}

In [None]:
# Only run when there is no output directory yet.
if not run_dir.exists():
    # must delete the run dir to re-run
    run_dir.mkdir()
    nhm = pws.Model(
        nhm_processes,
        control=control,
        parameters=params,
    )
    # finalize=True already finalizes the model at the end of the run, so no
    # separate nhm.finalize() call is needed afterwards.
    nhm.run(finalize=True)

In [None]:
# seg_outflow from the NHM run, selected at nhm_seg 44426 (the segment the
# reservoir should flow into).
outflow = xr.open_dataarray(run_dir / "seg_outflow.nc").sel(nhm_seg=44426)

In [None]:
# Plot of the no-reservoir outflow series.
outflow.hvplot()

## FlowGraph in Model

In [None]:
# PRMSChannel parameters from the per-process netcdf file in the domain dir.
params_file_channel = domain_dir / "parameters_PRMSChannel.nc"
params_channel = pws.parameters.PrmsParameters.from_netcdf(params_file_channel)

# Discretization referred to below by the name "dis_hru".
dis_file = domain_dir / "parameters_dis_hru.nc"
dis_hru = pws.Parameters.from_netcdf(dis_file, encoding=False)

# Discretization referred to below by the name "dis_both"; it supplies
# "tosegment" and "nhm_seg" to the flow graph construction.
dis_both_file = domain_dir / "parameters_dis_both.nc"
dis_both = pws.Parameters.from_netcdf(dis_both_file, encoding=False)

In [None]:
# Fresh control for the in-model FlowGraph run: same 365 * 2 time steps, a
# separate output directory, and the node-based output variables.
control = pws.Control.load_prms(control_file, warn_unused_options=False)
control.edit_n_time_steps(365 * 2)
run_dir = nb_output_dir / "fgr_starfit"
control.options = control.options | {
    "input_dir": domain_dir,
    "budget_type": "error",
    "calc_method": "numba",
    "netcdf_output_dir": run_dir,
    "netcdf_output_var_names": ["node_outflows", "node_upstream_inflows"],
}

In [None]:
# Same processes as the NHM run, but without PRMSChannel.
nhm_processes = [
    pws.PRMSSolarGeometry,
    pws.PRMSAtmosphere,
    pws.PRMSCanopy,
    pws.PRMSSnow,
    pws.PRMSRunoff,
    pws.PRMSSoilzone,
    pws.PRMSGroundwater,
]

model_dict = {}

for proc in nhm_processes:
    # The class name, e.g. "PRMSCanopy".
    proc_name = proc.__name__
    # Model-dictionary keys are arbitrary; use a lower-case instance-style
    # name derived from the class name, e.g. "prms_canopy".
    proc_rename = "prms_" + proc_name[4:].lower()
    # Per-process parameter file in the domain directory.
    proc_param_file = domain_dir / f"parameters_{proc_name}.nc"
    # Each process entry holds: the required "class", an instance of
    # Parameters under "parameters", and under "dis" the name of the
    # discretization that is supplied to the model dictionary later.
    proc_dict = {
        "class": proc,
        "parameters": pws.Parameters.from_netcdf(proc_param_file),
        "dis": "dis_both" if proc_rename == "prms_channel" else "dis_hru",
    }
    model_dict[proc_rename] = proc_dict

In [None]:
# Show the process entries in insertion order rather than sorted by key.
pprint(model_dict, sort_dicts=False)

In [None]:
# this graph will have no-inflow to non-prms_channel nodes. could those be added later?
def prms_channel_flow_graph_preprocess(
    prms_channel_params,
    prms_channel_dis,
    prms_channel_dis_name,
    new_nodes_maker_dict,
    new_nodes_maker_names,
    new_nodes_maker_indices,
    new_nodes_flow_to_nhm_seg,
    graph_budget_type="error",
):
    """Build model-dictionary entries for a FlowGraph based on PRMSChannel.

    The graph has one node per PRMSChannel segment followed by the new nodes.
    Each new node is inserted immediately upstream of one segment: the
    segments that flowed to that segment are redirected to the new node, and
    the new node flows to the segment.

    Args:
        prms_channel_params: PRMSChannel parameters; must provide
            ``dims["nsegment"]`` and ``to_xr_ds()``.
        prms_channel_dis: discretization whose ``parameters`` provide
            ``"tosegment"`` (1-based) and ``"nhm_seg"``.
        prms_channel_dis_name: name of that discretization in the model
            dictionary; used as the ``"dis"`` of the inflow exchange.
        new_nodes_maker_dict: node-maker instances keyed by name; must not
            contain a ``PRMSChannelFlowNodeMaker``.
        new_nodes_maker_names: list with, for each new node, the key of its
            maker in ``new_nodes_maker_dict``.
        new_nodes_maker_indices: list with, for each new node, its index
            within its maker.
        new_nodes_flow_to_nhm_seg: for each new node, the ``nhm_seg`` it flows
            to; the values must be unique.
        graph_budget_type: budget type passed to the FlowGraph entry.

    Returns:
        A dict with two model-dictionary entries, ``"inflow_exchange"`` and
        ``"prms_channel_flow_graph"``.
    """
    # The prms_channel maker is added by this function itself.
    assert not any(
        isinstance(vv, pws.PRMSChannelFlowNodeMaker)
        for vv in new_nodes_maker_dict.values()
    ), "new_nodes_maker_dict must not contain a PRMSChannelFlowNodeMaker"

    assert len(new_nodes_maker_names) == len(
        new_nodes_maker_indices
    ), "new_nodes_maker_names and new_nodes_maker_indices differ in length"
    assert len(new_nodes_maker_names) == len(
        new_nodes_flow_to_nhm_seg
    ), "new_nodes_maker_names and new_nodes_flow_to_nhm_seg differ in length"
    # JLM: I think this is the only condition to check with new_nodes_flow_to_nhm_seg
    assert len(new_nodes_flow_to_nhm_seg) == len(
        np.unique(new_nodes_flow_to_nhm_seg)
    ), "new_nodes_flow_to_nhm_seg must not contain duplicate segments"

    nseg = prms_channel_params.dims["nsegment"]
    nnodes = nseg + len(new_nodes_maker_names)

    # Graph nodes: all prms_channel segments first, then the new nodes.
    node_maker_name = ["prms_channel"] * nseg + new_nodes_maker_names
    node_maker_index = np.array(
        np.arange(nseg).tolist() + new_nodes_maker_indices
    )
    # Loop-invariant mask of the nodes that belong to prms_channel.
    is_prms_channel = np.array(node_maker_name) == "prms_channel"

    to_graph_index = np.zeros(nnodes, dtype=np.int64)
    dis_params = prms_channel_dis.parameters
    tosegment = dis_params["tosegment"] - 1  # fortran to python indexing
    to_graph_index[0:nseg] = tosegment

    for ii, nhm_seg in enumerate(new_nodes_flow_to_nhm_seg):
        # Segment index the new node flows to ...
        wh_intervene_above_nhm = np.where(dis_params["nhm_seg"] == nhm_seg)
        # ... and the segment indices that currently flow to that segment.
        wh_intervene_below_nhm = np.where(
            tosegment == wh_intervene_above_nhm[0][0]
        )
        # have to map to the graph from an index found in prms_channel
        wh_intervene_above_graph = np.where(
            is_prms_channel
            & (node_maker_index == wh_intervene_above_nhm[0][0])
        )
        wh_intervene_below_graph = np.where(
            is_prms_channel
            & np.isin(node_maker_index, wh_intervene_below_nhm[0])
        )

        # New node flows to the segment; its former tributaries flow to the
        # new node.
        to_graph_index[nseg + ii] = wh_intervene_above_graph[0][0]
        to_graph_index[wh_intervene_below_graph] = nseg + ii

    params_flow_graph = pws.Parameters(
        dims={
            "nnodes": nnodes,
        },
        coords={
            "node_coord": np.arange(nnodes),
        },
        data_vars={
            "node_maker_name": node_maker_name,
            "node_maker_index": node_maker_index,
            "to_graph_index": to_graph_index,
        },
        metadata={
            "node_coord": {"dims": ["nnodes"]},
            "node_maker_name": {"dims": ["nnodes"]},
            "node_maker_index": {"dims": ["nnodes"]},
            "to_graph_index": {"dims": ["nnodes"]},
        },
        validate=True,
    )

    # make available at top level __init__
    # Use the discretization and parameters passed in, not notebook globals.
    node_maker_dict = {
        "prms_channel": pws.PRMSChannelFlowNodeMaker(
            prms_channel_dis, prms_channel_params
        ),
    } | new_nodes_maker_dict

    def exchange_calculation(self) -> None:
        """Sum the HRU inputs and route them to node inflows or to sinks."""
        _hru_segment = self.hru_segment - 1
        s_per_time = self.control.time_step_seconds
        # Inputs are volumes per time step; dividing by the seconds per time
        # step turns their sum into a rate.
        self._inputs_sum = (
            sum(vv.current for vv in self._input_variables_dict.values())
            / s_per_time
        )

        # Resetting every node to zero means the new (non-prms_channel)
        # nodes, which no HRU maps to, get zero inflow.
        self.inflows[:] = zero
        # sinks is an HRU variable, its accounting in budget is fine because
        # global collapses it to a scalar before summing over variables
        self.sinks[:] = zero

        for ihru in range(self.nhru):
            iseg = _hru_segment[ihru]
            if iseg < 0:
                # HRU is not connected to a segment.
                self.sinks[ihru] += self._inputs_sum[ihru]
            else:
                self.inflows[iseg] += self._inputs_sum[ihru]

        self.inflows_vol[:] = self.inflows * s_per_time
        self.sinks_vol[:] = self.sinks * s_per_time

    Exchange = pws.base.flow_graph.inflow_exchange_factory(
        dimension_names=("nhru", "nnodes"),
        parameter_names=("hru_segment", "node_coord"),
        input_names=pws.PRMSChannel.get_inputs(),
        init_values={
            "inflows": np.nan,
            "inflows_vol": np.nan,
            "sinks": np.nan,
            "sinks_vol": np.nan,
        },
        mass_budget_terms={
            "inputs": [
                "sroff_vol",
                "ssres_flow_vol",
                "gwres_flow_vol",
            ],
            "outputs": ["inflows_vol", "sinks_vol"],
            "storage_changes": [],
        },
        calculation=exchange_calculation,
    )  # get the budget type into the exchange too: exchange_budget_type

    # Exchange parameters: the channel parameters plus a node coordinate.
    # TODO: this is funky, can we make this more elegant?
    params_ds = prms_channel_params.to_xr_ds().copy()
    params_ds["node_coord"] = xr.Variable(
        dims="nnodes",
        data=np.arange(nnodes),
    )
    params_ds = params_ds.set_coords("node_coord")
    params_exchange = pws.Parameters.from_ds(params_ds)

    return {
        "inflow_exchange": {
            "class": Exchange,
            "parameters": params_exchange,
            "dis": prms_channel_dis_name,
        },
        "prms_channel_flow_graph": {
            "class": pws.FlowGraph,
            "node_maker_dict": node_maker_dict,
            "parameters": params_flow_graph,
            "dis": None,
            "budget_type": graph_budget_type,
        },
    }

In [None]:
# A single maker for the new nodes. Its key must match the names listed in
# new_nodes_maker_names below.
node_maker_dict = {
    "pass_through": pws.hydrology.pass_through_node.PassThroughNodeMaker()
}  # to make STARFIT

# could pass the model_dict and return the model_dict?
# do it for 2+ reservoirs/pass throughs
# Two pass-through nodes (indices 0 and 1 of the maker) are inserted,
# flowing to nhm_segs 44426 and 44418 respectively.
graph_dict = prms_channel_flow_graph_preprocess(
    params_channel,
    dis_both,
    "dis_both",
    node_maker_dict,
    ["pass_through"] * 2,
    [0, 1],
    [44426, 44418],
)

In [None]:
# Process names as used for the model_dict keys, e.g. "prms_canopy".
nhm_processes_names = [
    ("prms_" + pp.__name__.lower()[4:]) for pp in nhm_processes
]
# The NHM processes run first, followed by the exchange and the flow graph.
model_order = nhm_processes_names + list(graph_dict)

# Each discretization is registered under its own name; the process entries
# refer to these by the strings "dis_hru" and "dis_both".
model_dict = (
    model_dict
    | {
        "control": control,
        "dis_hru": dis_hru,
        "dis_both": dis_both,
        "model_order": model_order,
    }
    | graph_dict
)

In [None]:
# Only run when there is no output directory yet; delete the run dir to
# re-run.
if not run_dir.exists():
    run_dir.mkdir()
    model = pws.Model(model_dict)
    model.run()
    model.finalize()

In [None]:
# Position of nhm_seg 44426 in the PRMS parameters. The first nsegment graph
# nodes are the prms_channel segments in their original order, so the same
# position is used to pick the node.
wh_44426 = np.where(params.parameters["nhm_seg"] == 44426)[0]
# node_coord is dropped before this is compared and merged with the
# segment-based outflow.
outflow_nodes = xr.open_dataarray(run_dir / "node_outflows.nc")[
    :, wh_44426
].drop_vars("node_coord")

In [None]:
# The in-model FlowGraph outflow must match the NHM outflow to within 1e-9.
assert (abs(outflow_nodes - outflow) < 1e-9).all()

In [None]:
# Plot the NHM and FlowGraph outflows together.
xr.merge([outflow, outflow_nodes]).hvplot()

## FlowGraph as a post-process

In [None]:
# Fresh control for the post-process FlowGraph run, writing to its own
# output directory.
control = pws.Control.load_prms(control_file, warn_unused_options=False)
control.edit_n_time_steps(365 * 2)
run_dir = nb_output_dir / "fgr_starfit_post"
control.options = control.options | {
    "input_dir": domain_dir,
    "budget_type": "error",
    "calc_method": "numba",
    "netcdf_output_dir": run_dir,
    "netcdf_output_var_names": ["node_outflows", "node_upstream_inflows"],
}

# Channel parameters, re-read from the per-process netcdf file.
params_file_channel = domain_dir / "parameters_PRMSChannel.nc"
params_channel = pws.parameters.PrmsParameters.from_netcdf(params_file_channel)

# The HRU discretization is not loaded for the post-process; any copy left
# over from earlier cells is removed.
# dis_file = domain_dir / "parameters_dis_hru.nc"
# dis_hru = pws.Parameters.from_netcdf(dis_file, encoding=False)
if "dis_hru" in locals():
    del dis_hru

dis_both_file = domain_dir / "parameters_dis_both.nc"
dis_both = pws.Parameters.from_netcdf(dis_both_file, encoding=False)

In [None]:
def prms_channel_flow_graph_postprocess(
    input_dir,
    prms_channel_params,
    prms_channel_dis,
    new_nodes_maker_dict,
    new_nodes_maker_names,
    new_nodes_maker_indices,
    new_nodes_flow_to_nhm_seg,
    graph_budget_type="error",
):
    """Build a stand-alone FlowGraph driven by previously written PRMS output.

    The graph has one node per PRMSChannel segment followed by the new nodes.
    Each new node is inserted immediately upstream of one segment: the
    segments that flowed to that segment are redirected to the new node, and
    the new node flows to the segment. Lateral inflows are read from netcdf
    files in ``input_dir``; the new nodes receive zero lateral inflow.

    NOTE: the notebook-level ``control`` is used for the netcdf adapters and
    for the FlowGraph; it is not a parameter of this function.

    Args:
        input_dir: directory holding one ``<name>.nc`` file for each name in
            ``pws.PRMSChannel.get_inputs()``.
        prms_channel_params: PRMSChannel parameters; must provide
            ``dims["nsegment"]``.
        prms_channel_dis: discretization whose ``parameters`` provide
            ``"tosegment"`` (1-based) and ``"nhm_seg"``.
        new_nodes_maker_dict: node-maker instances keyed by name; must not
            contain a ``PRMSChannelFlowNodeMaker``.
        new_nodes_maker_names: list with, for each new node, the key of its
            maker in ``new_nodes_maker_dict``.
        new_nodes_maker_indices: list with, for each new node, its index
            within its maker.
        new_nodes_flow_to_nhm_seg: for each new node, the ``nhm_seg`` it flows
            to; the values must be unique.
        graph_budget_type: budget type passed to the FlowGraph.

    Returns:
        The constructed ``pws.FlowGraph`` instance.
    """
    # The prms_channel maker is added by this function itself.
    assert not any(
        isinstance(vv, pws.PRMSChannelFlowNodeMaker)
        for vv in new_nodes_maker_dict.values()
    ), "new_nodes_maker_dict must not contain a PRMSChannelFlowNodeMaker"

    assert len(new_nodes_maker_names) == len(
        new_nodes_maker_indices
    ), "new_nodes_maker_names and new_nodes_maker_indices differ in length"
    assert len(new_nodes_maker_names) == len(
        new_nodes_flow_to_nhm_seg
    ), "new_nodes_maker_names and new_nodes_flow_to_nhm_seg differ in length"
    # JLM: I think this is the only condition to check with new_nodes_flow_to_nhm_seg
    assert len(new_nodes_flow_to_nhm_seg) == len(
        np.unique(new_nodes_flow_to_nhm_seg)
    ), "new_nodes_flow_to_nhm_seg must not contain duplicate segments"

    nseg = prms_channel_params.dims["nsegment"]
    nnodes = nseg + len(new_nodes_maker_names)

    # Graph nodes: all prms_channel segments first, then the new nodes.
    node_maker_name = ["prms_channel"] * nseg + new_nodes_maker_names
    node_maker_index = np.array(
        np.arange(nseg).tolist() + new_nodes_maker_indices
    )
    # Loop-invariant mask of the nodes that belong to prms_channel.
    is_prms_channel = np.array(node_maker_name) == "prms_channel"

    to_graph_index = np.zeros(nnodes, dtype=np.int64)
    dis_params = prms_channel_dis.parameters
    tosegment = dis_params["tosegment"] - 1  # fortran to python indexing
    to_graph_index[0:nseg] = tosegment

    for ii, nhm_seg in enumerate(new_nodes_flow_to_nhm_seg):
        # Segment index the new node flows to ...
        wh_intervene_above_nhm = np.where(dis_params["nhm_seg"] == nhm_seg)
        # ... and the segment indices that currently flow to that segment.
        wh_intervene_below_nhm = np.where(
            tosegment == wh_intervene_above_nhm[0][0]
        )
        # have to map to the graph from an index found in prms_channel
        wh_intervene_above_graph = np.where(
            is_prms_channel
            & (node_maker_index == wh_intervene_above_nhm[0][0])
        )
        wh_intervene_below_graph = np.where(
            is_prms_channel
            & np.isin(node_maker_index, wh_intervene_below_nhm[0])
        )

        # New node flows to the segment; its former tributaries flow to the
        # new node.
        to_graph_index[nseg + ii] = wh_intervene_above_graph[0][0]
        to_graph_index[wh_intervene_below_graph] = nseg + ii

    params_flow_graph = pws.Parameters(
        dims={
            "nnodes": nnodes,
        },
        coords={
            "node_coord": np.arange(nnodes),
        },
        data_vars={
            "node_maker_name": node_maker_name,
            "node_maker_index": node_maker_index,
            "to_graph_index": to_graph_index,
        },
        metadata={
            "node_coord": {"dims": ["nnodes"]},
            "node_maker_name": {"dims": ["nnodes"]},
            "node_maker_index": {"dims": ["nnodes"]},
            "to_graph_index": {"dims": ["nnodes"]},
        },
        validate=True,
    )

    # make available at top level __init__
    # Use the discretization and parameters passed in, not notebook globals.
    node_maker_dict = {
        "prms_channel": pws.PRMSChannelFlowNodeMaker(
            prms_channel_dis, prms_channel_params
        ),
    } | new_nodes_maker_dict

    # combine PRMS lateral inflows to a single non-volumetric inflow
    input_variables = {}
    for key in pws.PRMSChannel.get_inputs():
        nc_path = input_dir / f"{key}.nc"
        input_variables[key] = pws.AdapterNetcdf(nc_path, key, control)

    inflows_prms = pws.hydrology.prms_channel_flow_graph.HruSegmentFlowAdapter(
        prms_channel_params, **input_variables
    )

    class GraphInflowAdapter(pws.Adapter):
        """Pad the per-segment PRMS inflows with zeros for the new nodes."""

        def __init__(
            self,
            prms_inflows: pws.Adapter,
            variable: str = "inflows",
        ):
            self._variable = variable
            self._prms_inflows = prms_inflows

            self._nnodes = nnodes
            self._nseg = nseg
            # Start from NaN so values that were never advanced stand out.
            self._current_value = np.zeros(self._nnodes) * pws.constants.nan
            return

        def advance(self) -> None:
            self._prms_inflows.advance()
            self._current_value[0 : self._nseg] = self._prms_inflows.current
            # no inflow to non-prms-channel nodes
            self._current_value[self._nseg :] = zero
            return

    inflows_graph = GraphInflowAdapter(inflows_prms)

    flow_graph = pws.FlowGraph(
        control=control,
        discretization=prms_channel_dis,
        parameters=params_flow_graph,
        inflows=inflows_graph,
        node_maker_dict=node_maker_dict,
        # Honor the argument instead of a hard-coded "error".
        budget_type=graph_budget_type,
    )
    return flow_graph

In [None]:
# The post-process FlowGraph reads the fluxes written by the NHM run.
input_dir = nb_output_dir / "fgr_nhm"  # use the output of the NHM run

# Same two pass-through nodes as in the in-model graph: maker indices 0 and
# 1, flowing to nhm_segs 44426 and 44418.
flow_graph = prms_channel_flow_graph_postprocess(
    input_dir=input_dir,
    prms_channel_params=params_channel,
    prms_channel_dis=dis_both,
    new_nodes_maker_dict={
        "pass_through": pws.hydrology.pass_through_node.PassThroughNodeMaker()
    },
    new_nodes_maker_names=["pass_through", "pass_through"],
    new_nodes_maker_indices=[0, 1],
    new_nodes_flow_to_nhm_seg=[44426, 44418],
)

In [None]:
# Only run when there is no output directory yet; delete the run dir to
# re-run.
if not run_dir.exists():
    run_dir.mkdir()
    flow_graph.initialize_netcdf()
    # Drive the FlowGraph by hand: advance the control clock, then advance,
    # calculate and write the graph, once per time step.
    for istep in tqdm(range(control.n_times)):
        control.advance()
        flow_graph.advance()
        # NOTE(review): the meaning of the 1.0 argument is not visible
        # here; confirm against FlowGraph.calculate.
        flow_graph.calculate(1.0)
        flow_graph.output()

    flow_graph.finalize()

In [None]:
# Position of nhm_seg 44426 in the PRMS parameters. The first nsegment graph
# nodes are the prms_channel segments in their original order, so the same
# position is used to pick the node.
wh_44426 = np.where(params.parameters["nhm_seg"] == 44426)[0]
# Renamed so it can be merged alongside the other outflow series.
outflow_nodes_post = (
    xr.open_dataarray(run_dir / "node_outflows.nc")[:, wh_44426]
    .drop_vars("node_coord")
    .rename("node_outflows_post")
)

In [None]:
# The post-process FlowGraph outflow must match the NHM outflow to within
# 1e-9.
assert (abs(outflow_nodes_post - outflow) < 1e-9).all()

In [None]:
# Plot the NHM, in-model FlowGraph and post-process FlowGraph outflows
# together.
xr.merge([outflow, outflow_nodes, outflow_nodes_post]).hvplot()