# MMR To MF6
Purpose
* export PRMS files for Muskingum-Mann Routing (MMR) so they can be run in MODFLOW 6 (MF6) as Muskingum Routing (MR)
* run MF6 with these files
* compare the PRMS MMR and MF6 MR results

In [None]:
# This notebook uses the environment defined in
# pynhm/ci/requirements/pynhmf6nb.yml
# Running `python update_flopy.py` in modflow6/autotest is likely required.

import flopy
import pathlib as pl
import shutil

import hvplot.xarray
import numpy as np
import pandas as pd
import pint
import xarray as xr

import pywatershed
from pywatershed.utils.prms_to_mf6 import MMRToMF6

In [None]:
# General setup: user-specified, machine-specific settings.
# The absolute paths below must be edited to match the local installation.
# mf6_bin: the MF6 executable that is copied into the run directory and run.
mf6_bin = pl.Path("/Users/jmccreight/usgs/modflow6/bin/mf6")
# domain: name of the test-data subdirectory; also used as the MF6 sim name.
domain = "drb_2yr"
# Simulation window: passed to MMRToMF6 and used to subset the PRMS output.
start_time = np.datetime64("1979-01-01T00:00:00")
end_time = np.datetime64("1979-07-01T00:00:00")
# NetCDF file from which the observed "discharge" variable is read.
observations_nc_file = pl.Path(
    "/Users/jmccreight/usgs/data/pynhm/nhm_subsets/drb/drb_2yr_gage_poi_obs.nc"
)

In [None]:
# Data paths: everything lives under <repo root>/test_data/<domain>
pynhm_root_dir = pywatershed.constants.__pywatershed_root__.parent
test_data_dir = pynhm_root_dir / "test_data"
domain_dir = test_data_dir / f"{domain}"
# PRMS control and parameter files for the domain
control_file = domain_dir / "control.test"
param_file = domain_dir / "myparam.param"
# PRMS output directory (sits next to the control file); it supplies the
# inflows for MF6 and the PRMS segment outflows used in the comparison
inflow_dir = domain_dir / "output"

In [None]:
# where we'll run the experiment
# NOTE: an existing directory of this name is removed first, so the results of
# any previous run are discarded and every run starts from an empty directory.
mmr_to_mf6_dir = pynhm_root_dir / f"evaluation/prms/tmp_{domain}_mmr_to_mf6"
if mmr_to_mf6_dir.exists():
    shutil.rmtree(mmr_to_mf6_dir)

# parents=True also creates missing intermediate directories
mmr_to_mf6_dir.mkdir(parents=True)

In [None]:
# Generate the data for mf6 MR routing from PRMS
# The PRMS parameter/control files and the PRMS output directory are the
# inputs; the MF6 files are targeted at the run directory created above.
# write_on_init=False presumably defers writing the MF6 files until
# mm.write() is called, so the simulation can still be modified first
# (TODO confirm against MMRToMF6).
mm = MMRToMF6(
    param_file=param_file,
    control_file=control_file,
    output_dir=mmr_to_mf6_dir,
    inflow_dir=inflow_dir,
    sim_name=domain,
    bc_binary_files=True,  # T, T
    bc_flows_combine=True,  # F, T
    write_on_init=False,
    # length_units="meters",
    # time_units="seconds",
    start_time=start_time,
    end_time=end_time,
    save_flows=True,
)

In [None]:
# Configure the simulation built by `mm` so that MF6 saves the flows:
# 1) attach an output-control package to the first SNF model that saves (and
#    prints) all budget terms; the budget is written to "<domain>.bud", which
#    is read back after the run.
oc = flopy.mf6.ModflowSnfoc(
    mm.sim.snf[0],
    budget_filerecord=f"{domain}.bud",
    saverecord=[
        ("BUDGET", "ALL"),
    ],
    printrecord=[
        ("BUDGET", "ALL"),
    ],
)

# 2) turn on save_flows for the "mmr" package of the same model.
mmr = mm.sim.snf[0].get_package("mmr")
mmr.save_flows = True

In [None]:
# Write the data for MF6
# (done explicitly here because `mm` was constructed with write_on_init=False)
mm.write()

In [None]:
# copy the MF6 binary into the run directory, keeping its file name
# (copy2 also attempts to preserve the file metadata)
shutil.copy2(mf6_bin, mmr_to_mf6_dir / mf6_bin.name)

In [None]:
# Run MF6 MR
import subprocess

# Run the copy of the executable that was placed in the run directory, using
# its actual file name and an absolute path: a relative "./mf6" is only
# resolved against `cwd` on POSIX and assumes the binary is called "mf6".
# stdout/stderr are captured (as bytes) on run_result for later inspection.
run_result = subprocess.run(
    [str((mmr_to_mf6_dir / mf6_bin.name).resolve())],
    cwd=mmr_to_mf6_dir,
    capture_output=True,
)

In [None]:
# confirm the MF6 run succeeded; because its output was captured rather than
# shown, include it in the failure message so the cause is visible
assert run_result.returncode == 0, (
    f"mf6 exited with return code {run_result.returncode}\n"
    f"--- stdout ---\n{run_result.stdout.decode(errors='replace')}\n"
    f"--- stderr ---\n{run_result.stderr.decode(errors='replace')}"
)

In [None]:
# parse the results: budget terms from the binary budget file written by the
# run, and the connectivity arrays from the binary grid (disl) file
budobj = flopy.utils.binaryfile.CellBudgetFile(mmr_to_mf6_dir / f"{domain}.bud")
# one get_data call per budget term, in this order
flowja, qstorage, qextoutflow = (
    budobj.get_data(text=record_name)
    for record_name in ("FLOW-JA-FACE", "STORAGE", "EXT-OUTFLOW")
)

disl_grb = mmr_to_mf6_dir / f"{domain}.disl.grb"
grb = flopy.mf6.utils.MfGrdFile(disl_grb)
# ia/ja: connectivity index arrays used to look up entries of flowja
ia, ja = grb.ia, grb.ja

# reuse this
tosegment = mm.sim.snf[0].disl.tosegment.get_data()

In [None]:
# build the flow from the budget
def get_outflow(itime):
    """Return the outflow of every node for one saved budget time.

    Relies on the notebook-level variables ``ia``, ``ja``, ``grb``,
    ``tosegment``, ``flowja`` and ``qextoutflow``.

    For a node whose ``tosegment`` is -1 the outflow is the negated
    EXT-OUTFLOW entry of that node. For every other node it is the negated
    FLOW-JA-FACE entry of the connection from the node to its ``tosegment``.
    (The negation presumably turns MF6's "flow out of a cell is negative"
    convention into a positive outflow; TODO confirm.)

    Parameters
    ----------
    itime : int
        Index into the lists of saved budget records.

    Returns
    -------
    numpy.ndarray
        1-D array of length ``ia.shape[0] - 1`` with one outflow per node.

    Raises
    ------
    ValueError
        If a node's ``tosegment`` is not among the node's connections.
    """
    outflow = np.zeros(ia.shape[0] - 1)
    flowjaflat = flowja[itime].flatten()
    qextflat = qextoutflow[itime].flatten()
    for n in range(grb.nodes):
        itoseg = tosegment[n]
        if itoseg == -1:
            # no downstream segment: the flow leaves through EXT-OUTFLOW
            outflow[n] = -qextflat[n]
            continue

        # Scan the connections of node n; the first position, ia[n], is
        # skipped (presumably the node's own diagonal entry).
        for ipos in range(ia[n] + 1, ia[n + 1]):
            if ja[ipos] == itoseg:
                outflow[n] = -flowjaflat[ipos]
                break
        else:
            raise ValueError(
                f"could not find entry for tosegment {itoseg} in flowja for node {n}"
            )

    return outflow

In [None]:
# MF6 outflow for every saved budget time: one array of node outflows per entry
mf6_flow = [get_outflow(itime) for itime, _ in enumerate(flowja)]

In [None]:
# bring in PRMS flows and convert units
units = pint.UnitRegistry()

# PRMS segment outflow, limited to the simulation window and renamed "prms"
flow_ds = xr.open_dataset(inflow_dir / "seg_outflow.nc")
flow_ds = flow_ds.sel(time=slice(start_time, end_time))
flow_ds = flow_ds.rename(seg_outflow="prms")

# ft^3/s -> m^3/s; the converted values are written back into the existing
# variable in place
flow_ds["prms"][:, :] = (
    (flow_ds["prms"].values * units("feet ** 3 / second"))
    .to("meters ** 3 / second")
    .magnitude
)

In [None]:
# add the MF6 flows to the xarray dataset
# NOTE(review): the DataArray is created without coordinates, so it is matched
# to flow_ds by position only; this assumes one MF6 record per time in flow_ds
# and MF6 node order equal to the nhm_seg order. TODO confirm.
flow_ds["mf6"] = xr.DataArray(mf6_flow, dims=["time", "nhm_seg"])

In [None]:
# cheap way to identify the outlet: position of the largest time-mean flow
# NOTE(review): flow_ds appears to be a Dataset here, in which case `.values`
# is the mapping method rather than an array; confirm the displayed output is
# what is intended.
display(flow_ds.mean(dim="time").argmax().values)

In [None]:
# nhm_seg value at position 24
# NOTE(review): 24 is hard-coded, presumably the outlet position reported by
# the previous cell for this domain; re-check it when `domain` changes.
flow_ds.nhm_seg[24].values

In [None]:
# plot the flow variables in flow_ds against time, grouped by nhm_seg
flow_ds.hvplot(x="time", groupby="nhm_seg", ylabel="streamflow (m^3/s)", xlabel="")

In [None]:
# Label each segment with the id of the gage (point of interest) located on
# it, then keep only the gaged segments, indexed by gage id.
params = pywatershed.PrmsParameters.load(param_file)
#            123456789012345
empty_str = "               "
# Start with every segment marked "no gage" (15 blanks). A plain fixed-width
# unicode array is used instead of np.chararray, which NumPy keeps only for
# backwards compatibility and does not recommend for new code.
poi_id = np.full(flow_ds.prms.nhm_seg.shape, empty_str, dtype="U15")
# NOTE(review): poi_gage_segment is used directly as a zero-based position
# along nhm_seg; PRMS parameter files commonly hold 1-based segment indices.
# Confirm whether the loaded parameter is already zero-based.
for ii, jj in enumerate(params.parameters["poi_gage_segment"].tolist()):
    poi_id[jj] = params.parameters["poi_gage_id"][ii]

flow_ds["poi_id"] = xr.DataArray(poi_id, dims=["nhm_seg"])
# drop segments without a gage and make poi_id the dimension in place of nhm_seg
mod_obs_ds = (
    flow_ds.where(flow_ds.poi_id != empty_str, drop=True)
    .set_coords("poi_id")
    .swap_dims(nhm_seg="poi_id")
)

In [None]:
# Observed discharge, renamed "observed" and converted from ft^3/s to m^3/s
# (the converted values are written back in place).
obs_ds = xr.open_dataset(observations_nc_file)["discharge"].rename("observed")
obs_ds[:] = (
    (obs_ds.values * units("feet ** 3 / second")).to("meters ** 3 / second").magnitude
)
# despite its name, obs_ds is a DataArray up to this point; convert it to a
# Dataset so it can be merged with the model output
obs_ds = obs_ds.to_dataset()

In [None]:
# Combine observations and model output; join="inner" keeps only the
# coordinate values that both datasets have in common.
eval_ds = xr.merge([obs_ds, mod_obs_ds], join="inner")

In [None]:
# Flag the gages whose observations are missing at every time, then drop them
obs_all_na = eval_ds["observed"].isnull().all(dim="time")
eval_ds = eval_ds.where(~obs_all_na, drop=True)

In [None]:
# plot observed and modeled flow against time, grouped by gage id (poi_id)
eval_ds.hvplot(x="time", groupby="poi_id", ylabel="streamflow (m^3/s)", xlabel="")