In [None]:
# Run the shared workspace-setup notebook in this kernel's namespace.
# NOTE(review): this notebook uses names it never defines itself (model_dir, param_filename,
# control_file_name, out_dir, custom_output_file, np, xr); presumably they all come from
# 0a_Workspace_setup.ipynb — confirm.
%run "0a_Workspace_setup.ipynb"

In [None]:
import pywatershed as pws
#import os
import dask


#Check to make sure these are really needed
from pywatershed.parameters.prms_parameters import JSONParameterEncoder
import sys
sys.path.append('../scripts/')
#from pest_utils import pars_to_tpl_entries
sys.path.append('../dependencies/')
#import pyemu

import shutil
import time


In [None]:
# model_folder_name = '20240619_v1.1_gm_WallaWalla_byHWobs_custom_cal'# This line will be edited by the user
# root_dir = pl.Path('../').resolve()
# notebook_dir = pl.Path('./').resolve()
# model_dir = pl.Path(root_dir/ f'{model_folder_name}').resolve()

# print(f'The root directory is {root_dir}')
# print(f'The notebook directory is {notebook_dir}')
# print(f'The model directory is {model_dir}.')

In [None]:
# # First make an output directory should one not exist
# if not (model_dir / 'output').exists():
#         (model_dir / 'output').mkdir()

### Write the parameter file as a JSON file (the write call below is currently commented out)
#### This makes the parameter file compatible with our current pywatershed (pws) notebooks


In [None]:
# Load the model parameter file. param_filename is not defined in this notebook; presumably it is
# set by 0a_Workspace_setup.ipynb — confirm. Per the original note, it may point to the pywatershed
# custom file or to another NHM parameter file from a Bandit extraction (e.g. byHRU).

pardat = pws.parameters.PrmsParameters.load(param_filename)# load parameter file from extraction

# JSON export is currently disabled; uncomment to write parameters.json into model_dir.
#pardat.parameters_to_json(model_dir /"parameters.json")

#### Some useful pws checks


In [None]:
#pws.PRMSCanopy.get_variables()
#pws.PRMSSnow.get_variables()
#pws.PRMSRunoff.get_variables()
#pws.PRMSSoilzone.get_variables()
#pws.PRMSGroundwater.get_variables()
#pws.PRMSChannel.get_variables()
#pws.PRMSStarfit.get_variables()

#pws.meta.find_variables([pws.PRMSChannel.get_variables()[2]])

#Helpful table for explaining variables https://water.usgs.gov/water-resources/software/PRMS/PRMS_tables_5.2.1.pdf

In [None]:
# Show pywatershed metadata for the variable at index 6 of PRMSChannel.get_variables().
pws.meta.find_variables([pws.PRMSChannel.get_variables()[6]])

In [None]:
# List the variable names provided by the PRMSAtmosphere process.
pws.PRMSAtmosphere.get_variables()

In [None]:
#pws.meta.find_variables([pws.PRMSAtmosphere.get_variables()[5]])

In [None]:
# List the variable names provided by the PRMSCanopy process.
pws.PRMSCanopy.get_variables()

In [None]:
# Show pywatershed metadata for the variable at index 6 of PRMSCanopy.get_variables().
pws.meta.find_variables([pws.PRMSCanopy.get_variables()[6]])

In [None]:
#pws.PRMSSnow.get_variables()

In [None]:
# NOTE(review): this only displays the PRMSAtmosphere class object itself. It looks like a leftover
# exploration cell (perhaps .get_variables() was intended) — confirm or remove.
pws.PRMSAtmosphere

In [None]:
# Show pywatershed metadata for the variable at index 2 of PRMSSnow.get_variables().
pws.meta.find_variables([pws.PRMSSnow.get_variables()[2]])

### Run the model: custom output loop and default output loop
#### The default loop writes only the pywatershed standard output variables, as one .nc file per variable.
#### The custom loop uses the standard variables to calculate other output variables (known to pywatershed) and writes a single .nc file with all standard and custom variables and their metadata, using a special dimension for the POIs.

In [None]:
# Time the whole run (setup + time loop + write); reported when the run finishes.
sttime = time.time()
print("You will be prompted when the model is finished running in about 5 minutes.")

# True -> also write pywatershed's default output (one .nc file per variable),
# which the Diagnostic Check cell compares against the custom output.
model_output_netcdf = False

# Paths (model_dir, control_file_name, param_filename, out_dir, custom_output_file)
# are not set here; the parameter-file setup was moved to 0a_Workspace_setup.ipynb.

# Load the parameter and control files into pywatershed.
params = pws.parameters.PrmsParameters.load(param_filename)
control = pws.Control.load_prms(
    model_dir / control_file_name,
    warn_unused_options=False,
)

# Control options shared by both output modes.
control.options = control.options | dict(
    input_dir=model_dir,
    budget_type=None,
    verbosity=0,
    calc_method="numba",
)

# Default NetCDF output options (disabled unless model_output_netcdf is True).
# sroff_vol / ssres_flow_vol / gwres_flow_vol are listed because the Diagnostic
# Check rebuilds the diagnostic variable from their default output files.
# Other candidates (e.g. potet, tmaxf, snowcov_area, soil_rechr) can be added.
if not model_output_netcdf:
    netcdf_options = dict(netcdf_output_var_names=None, netcdf_output_dir=None)
else:
    netcdf_options = dict(
        netcdf_output_var_names=[
            "hru_actet",
            "sroff_vol",
            "ssres_flow_vol",
            "gwres_flow_vol",
            "seg_outflow",
            "recharge",
            "net_rain",
            "net_snow",
            "net_ppt",
            "sroff",
            "ssres_flow",
            "gwres_flow",
            "gwres_sink",
            "snowmelt",
            "gwres_stor",
            "gwres_stor_change",
            "ssres_stor",
            "unused_potet",
        ],
        netcdf_output_dir=out_dir,
    )
control.options = control.options | netcdf_options

# PRMS process classes that make up the model.
prms_processes = [
    pws.PRMSSolarGeometry,
    pws.PRMSAtmosphere,
    pws.PRMSCanopy,
    pws.PRMSSnow,
    pws.PRMSRunoff,
    pws.PRMSSoilzone,
    pws.PRMSGroundwater,
    pws.PRMSChannel,
]
model = pws.Model(prms_processes, control=control, parameters=params)

# Custom model output at selected spatial locations for all times.
# The output is accumulated in an in-memory xarray Dataset and written once after
# the run. xarray is not fast for per-step writes, but a single write at the end is
# fine; netCDF4 could be used directly if performance becomes a concern.

# /////////////////////////////////
# Specifications: the user-facing configuration of the custom output.
# Model variables copied into the custom output at every time step.
var_list = [
    "hru_actet",
    #"potet",
    #"tmaxf",
    "seg_outflow",
    "recharge",
    #"snowcov_area",
    #"soil_rechr",
    "net_rain",
    "net_snow",
    "net_ppt",
    "sroff",# values in inches for area-weighted averaging
    "ssres_flow",# values in inches for area-weighted averaging
    "gwres_flow",# values in inches for area-weighted averaging
    "gwres_sink",
    "snowmelt",
    "gwres_stor",
    "gwres_stor_change",
    "ssres_stor",
    "unused_potet",
]


# seg_outflow is kept only at the POI gage segments.
# "- 1" converts the 1-based (Fortran/PRMS) segment numbers to 0-based indices, and the
# array is wrapped in a 1-tuple to mimic the return value of np.where.
wh_gages = (params.parameters["poi_gage_segment"] - 1,)
spatial_subsets = {
    "poi_gages": {
        # NOTE(review): "coord_name" is not read by the code below; the subset key
        # ("poi_gages") is used as the output coordinate name instead.
        "coord_name": "nhm_seg",
        "indices": wh_gages,
        "new_coord": params.parameters["poi_gage_id"],
        # Other variables on the same segment coordinate can be listed here, but only
        # those also present in var_list are actually written (e.g. seg_gwflow is not).
        "variables": ["seg_outflow", "seg_gwflow"],
    },
}


# Diagnostic variable: not produced by pywatershed itself, computed from model variables.
def sum_hru_flows(sroff_vol, ssres_flow_vol, gwres_flow_vol):
    """Return the total volume delivered to the stream network from each HRU.

    Adds the sroff_vol, ssres_flow_vol and gwres_flow_vol volumes (elementwise
    for arrays). The argument names must match the "inputs" listed for this
    function in diagnostic_var_dict, because it is called with those variables
    as keyword arguments. The inputs do not need to appear in var_list.
    """
    surface_and_subsurface = sroff_vol + ssres_flow_vol
    return surface_and_subsurface + gwres_flow_vol


# Diagnostic (derived) output variables, keyed by output name. For each entry:
#   inputs   - model variable names, passed to `function` as keyword arguments every time step
#   function - computes the diagnostic from those inputs
#   like_var - existing variable whose dims, coordinates and dtype the output copies
#   metadata - attrs for the output variable (a default "desc" is added if missing)
diagnostic_var_dict = {
    "hru_streamflow_out": {
        "inputs": ["sroff_vol", "ssres_flow_vol", "gwres_flow_vol"],
        "function": sum_hru_flows,
        "like_var": "sroff_vol",
        "metadata": {"desc": "Total volume to stream network from each HRU", "units": "cubic feet"},
    },
}

# TODO: specify subsets in time
# TODO: specify different output files

# /////////////////////////////////
# Bookkeeping derived from the specifications above.

out_subset_ds = xr.Dataset()

# Every model variable needed: the requested outputs plus the inputs of each
# diagnostic variable (duplicates are harmless).
needed_vars = var_list + [
    input_name
    for diag_spec in diagnostic_var_dict.values()
    for input_name in diag_spec["inputs"]
]
needed_metadata = pws.meta.get_vars(needed_vars)
dims = {dim for meta in needed_metadata.values() for dim in meta["dims"]}

# Variables written only at a spatial subset, and the subset each one belongs to.
subset_vars = [
    subset_var
    for subset_spec in spatial_subsets.values()
    for subset_var in subset_spec["variables"]
]
var_subset_key = {
    subset_var: subset_name
    for subset_var in subset_vars
    for subset_name, subset_spec in spatial_subsets.items()
    if subset_var in subset_spec["variables"]
}

diagnostic_vars = list(diagnostic_var_dict)

# Map each needed variable to the key of the model process that provides it
# (if several processes list it, the last one in model.processes wins).
var_proc = {
    needed_var: proc_key
    for needed_var in needed_vars
    for proc_key, process in model.processes.items()
    if needed_var in process.get_variables()
}

# One day-resolution entry per time step, from start_time through end_time inclusive.
time_coord = np.arange(
    control.start_time,
    control.end_time + control.time_step,
    dtype="datetime64[D]",
)
n_time_steps = len(time_coord)
out_subset_ds["time"] = xr.Variable(["time"], time_coord)
out_subset_ds = out_subset_ds.set_coords("time")

# Coordinate variable that labels each model dimension (hard-coded; not looked up).
dim_coord = {"nhru": "nhm_id", "nsegment": "nhm_seg"}

####################################################################################
# Declare (pre-fill) memory for every custom output variable.
for var in var_list + diagnostic_vars:
    # "Impostor" approach: a diagnostic variable has no pywatershed metadata of its
    # own, so borrow dims, coordinates and dtype from its like_var.
    orig_diag_var = None
    if var in diagnostic_vars:
        orig_diag_var = var
        var = diagnostic_var_dict[var]["like_var"]

    proc = model.processes[var_proc[var]]
    dim_name = needed_metadata[var]["dims"][0]
    dim_len = proc.params.dims[dim_name]
    coord_name = dim_coord[dim_name]
    coord_data = proc.params.coords[coord_name]
    # Named var_type (not `type`) so the builtin type() is not shadowed in the
    # notebook namespace for every later cell.
    var_type = needed_metadata[var]["type"]

    var_meta = {
        kk: vv
        for kk, vv in needed_metadata[var].items()
        if kk in ["desc", "units"]
    }

    if orig_diag_var is not None:
        var = orig_diag_var
        # The borrowed description belongs to like_var, not to the diagnostic.
        # pop() tolerates like_var metadata that has no "desc" (del raised KeyError).
        var_meta.pop("desc", None)
        if "metadata" in diagnostic_var_dict[var]:
            # Copy so adding a default "desc" below never mutates the user's spec.
            var_meta = dict(diagnostic_var_dict[var]["metadata"])
        if "desc" not in var_meta:
            var_meta["desc"] = "Custom output diagnostic variable"

    # Variables kept only at a spatial subset get the subset's own dimension/coordinate.
    if var in subset_vars:
        subset_key = var_subset_key[var]
        subset_info = spatial_subsets[subset_key]
        dim_name = f"n{subset_key}"
        coord_name = subset_key
        dim_len = len(subset_info["indices"][0])
        coord_data = subset_info["new_coord"]

    if coord_name not in out_subset_ds.variables:
        out_subset_ds[coord_name] = xr.DataArray(coord_data, dims=[dim_name])
        out_subset_ds = out_subset_ds.set_coords(coord_name)

    # (time, space) array pre-filled with pywatershed's fill value for this dtype.
    out_subset_ds[var] = xr.Variable(
        ["time", dim_name],
        np.full(
            [n_time_steps, dim_len],
            pws.constants.fill_values_dict[np.dtype(var_type)],
            var_type,
        ),
    )

    out_subset_ds[var].attrs = var_meta

#########################################################################################
# Run the model one time step at a time, copying the requested values into the
# custom output Dataset after each step.


def _current_values(process, var_name):
    """Return the current time step's values of `var_name` from a pywatershed process.

    Variables stored as a TimeseriesArray are unwrapped to their `.current` values;
    anything else is returned unchanged. Used for both plain outputs and diagnostic
    inputs so the two are always read the same way.
    """
    values = process[var_name]
    if isinstance(values, pws.base.timeseries.TimeseriesArray):
        values = values.current
    return values


for istep in range(n_time_steps):
    model.advance()
    model.calculate()

    if model_output_netcdf:
        model.output()

    for var in var_list:
        data = _current_values(model.processes[var_proc[var]], var)
        if var not in subset_vars:
            out_subset_ds[var][istep, :] = data
        else:
            indices = spatial_subsets[var_subset_key[var]]["indices"]
            out_subset_ds[var][istep, :] = data[indices]

    # Diagnostic variables are computed here, once per time step, from their inputs.
    for diag_key, diag_val in diagnostic_var_dict.items():
        input_dict = {
            ii: _current_values(model.processes[var_proc[ii]], ii)
            for ii in diag_val["inputs"]
        }
        out_subset_ds[diag_key][istep, :] = diag_val["function"](**input_dict)


out_subset_ds.to_netcdf(custom_output_file)
out_subset_ds.close()

# Finalize the model so pywatershed's default NetCDF output files (if enabled) are
# flushed and closed before the Diagnostic Check cell reads them back.
model.finalize()
print(f"Model run finished! That took {time.time()-sttime:.3f} looong seconds")

del proc
del input_dict
del model
del out_subset_ds


### Diagnostic Check
#### Checks the custom output against pywatershed's default (standard) output files; runs only when `model_output_netcdf` is True.

In [None]:
if model_output_netcdf:
    # The context manager closes the custom output file even when an assertion fails.
    # A dangling handle may otherwise prevent rewriting the file on the next run.
    with xr.open_dataset(custom_output_file) as out_subset_ds:
        # Plain output variables must match the default per-variable NetCDF files
        # (restricted to the subset indices where applicable).
        for vv in var_list:
            default_output_file = out_dir / f"{vv}.nc"
            print("checking variable: ", vv)
            # load_dataarray reads the data into memory and closes the file itself.
            answer = xr.load_dataarray(default_output_file)

            result = out_subset_ds[vv]

            if vv in subset_vars:
                indices = spatial_subsets[var_subset_key[vv]]["indices"]
                answer = answer[:, indices[0]]

            np.testing.assert_allclose(answer, result)

        # Diagnostic variables are recomputed from the default files of their inputs.
        for diag_key, diag_val in diagnostic_var_dict.items():
            print("checking diagnostic variable: ", diag_key)
            input_dict = {
                ii: xr.load_dataarray(out_dir / f"{ii}.nc")
                for ii in diag_val["inputs"]
            }

            answer = diag_val["function"](**input_dict)
            result = out_subset_ds[diag_key]

            np.testing.assert_allclose(answer, result)

#### Reading the custom output.nc file

In [None]:
# Read back exactly the file the model-run cell wrote (custom_output_file) rather than a
# separately hard-coded path, which could point at a different or stale file.
model_output = xr.load_dataset(custom_output_file)

In [None]:
# Display the custom output Dataset loaded in the previous cell.
model_output

In [None]:
#model_output.snowmelt.values[100,400]