Skip to content

Commit

Permalink
Refactor: model now takes a dictionary of [control, dis, process, ord…
Browse files Browse the repository at this point in the history
…er] while maintaining backwards compatability, very minor changes to api (passing list instead of unpacking list for arguments).
  • Loading branch information
jmccreight committed Jun 16, 2023
1 parent 6be7331 commit afbe91c
Show file tree
Hide file tree
Showing 16 changed files with 348 additions and 139 deletions.
10 changes: 7 additions & 3 deletions autotest/test_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def params_simple():


@pytest.fixture(scope="function")
def control_simple(params_simple):
return Control(**time_dict, params=params_simple)
def control_simple():
return Control(**time_dict)


def test_control_simple(control_simple):
Expand Down Expand Up @@ -88,7 +88,7 @@ def test_control_simple(control_simple):
control_simple.advance()


def test_control_advance(control_simple):
def test_control_advance(control_simple, params_simple):
# common inputs for 2 canopies
input_variables = {}
for key in PRMSCanopy.get_inputs():
Expand All @@ -98,13 +98,17 @@ def test_control_advance(control_simple):
# ntimes = control.n_times
cnp1 = PRMSCanopy(
control=control_simple,
discretization=None,
parameters=params_simple,
**input_variables,
verbose=True,
)
cnp1.name = "cnp1"

cnp2 = PRMSCanopy(
control=control_simple,
discretization=None,
parameters=params_simple,
**input_variables,
verbose=True,
)
Expand Down
82 changes: 71 additions & 11 deletions autotest/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pywatershed.base.adapter import adapter_factory
from pywatershed.base.control import Control
from pywatershed.base.model import Model
from pywatershed.parameters import PrmsParameters
from pywatershed.parameters import Parameters, PrmsParameters

compare_to_prms521 = False # TODO TODO TODO
failfast = True
Expand All @@ -30,25 +30,64 @@
}


params = ("params_sep", "params_one")


@pytest.fixture(scope="function")
def params(domain):
return PrmsParameters.load(domain["param_file"])
def control(domain):
return Control.load(domain["control_file"])


@pytest.fixture(scope="function")
def control(domain, params):
control = Control.load(domain["control_file"], params=params)
control.edit_n_time_steps(n_time_steps)
return control
def discretization(domain):
dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
dis_seg_file = domain["dir"] / "parameters_dis_seg.nc"
dis_hru = Parameters.from_netcdf(dis_hru_file, encoding=False)
# PRMSChannel needs both dis where as it should only need dis_seg
# and will when we have exchanges
dis_combined = Parameters.merge(
Parameters.from_netcdf(dis_hru_file, encoding=False),
Parameters.from_netcdf(dis_seg_file, encoding=False),
)
dis = {"dis_hru": dis_hru, "dis_combined": dis_combined}

return dis


@pytest.fixture(scope="function", params=params)
def parameters(domain, request):
if request.param == "params_one":
params = PrmsParameters.load(domain["param_file"])
else:
# In this case we are not passing parameters
# but the model dict
params = {}
for process in test_models["nhm"]:
proc_name = process.__name__
params[proc_name] = {}
proc = params[proc_name]
proc["class"] = process
proc_param_file = domain["dir"] / f"parameters_{proc_name}.nc"
proc["parameters"] = PrmsParameters.from_netcdf(proc_param_file)
if proc_name == "PRMSChannel":
proc["dis"] = "dis_combined"
else:
proc["dis"] = "dis_hru"

return params


@pytest.mark.parametrize(
"processes",
test_models.values(),
ids=test_models.keys(),
)
def test_model(domain, control, processes, tmp_path):
def test_model(
domain, control, discretization, parameters, processes, tmp_path
):
"""Run the full NHM model"""
control_copy = control

tmp_path = pl.Path(tmp_path)
output_dir = domain["prms_output_dir"]

Expand All @@ -60,10 +99,31 @@ def test_model(domain, control, processes, tmp_path):
for ff in output_dir.parent.resolve().glob("*.nc"):
shutil.copy(ff, input_dir / ff.name)

if isinstance(parameters, PrmsParameters):
process_list_or_model_dict = processes
discretization = None # not used

elif isinstance(parameters, dict):
assert isinstance(discretization, dict)
process_list_or_model_dict = discretization | parameters
process_list_or_model_dict["control"] = control
process_list_or_model_dict["model_order"] = [
pp.__name__ for pp in processes
]
# all arr folded in to the above dict
control = None
discretization = None
parameters = None

else:
raise ValueError("what type is parameters?")

# TODO: Eliminate potet and other variables from being used
model = Model(
*processes,
process_list_or_model_dict,
control=control,
discretization_dict=discretization,
parameters=parameters,
input_dir=input_dir,
budget_type=budget_type,
load_n_time_batches=3,
Expand Down Expand Up @@ -177,7 +237,7 @@ def test_model(domain, control, processes, tmp_path):
else:
nc_pth = input_dir / f"{vv}.nc"
ans[unit_name][vv] = adapter_factory(
nc_pth, variable_name=vv, control=control
nc_pth, variable_name=vv, control=control_copy
)

# ---------------------------------
Expand Down Expand Up @@ -206,7 +266,7 @@ def test_model(domain, control, processes, tmp_path):
all_success = True
fail_prms_compare = False
fail_regression = False
for istep in range(control.n_times):
for istep in range(control_copy.n_times):
model.advance()
model.calculate()

Expand Down
16 changes: 10 additions & 6 deletions autotest/test_netcdf_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def params(domain):


@pytest.fixture(scope="function")
def control(domain, params):
control = Control.load(domain["control_file"], params=params)
def control(domain):
control = Control.load(domain["control_file"])
control.edit_n_time_steps(n_time_steps)
return control

Expand Down Expand Up @@ -58,7 +58,7 @@ def control(domain, params):
check_budget_sum_vars_params,
ids=[str(ii) for ii in check_budget_sum_vars_params],
)
def test_process_budgets(domain, control, tmp_path, budget_sum_param):
def test_process_budgets(domain, control, params, tmp_path, budget_sum_param):
tmp_dir = pl.Path(tmp_path)
# print(tmp_dir)
model_procs = [pywatershed.PRMSCanopy, pywatershed.PRMSChannel]
Expand All @@ -81,8 +81,10 @@ def test_process_budgets(domain, control, tmp_path, budget_sum_param):

# TODO: Eliminate potet and other variables from being used
model = Model(
*model_procs,
model_procs,
control=control,
discretization_dict=None,
parameters=params,
input_dir=input_dir,
budget_type=budget_type,
)
Expand Down Expand Up @@ -159,7 +161,7 @@ def test_process_budgets(domain, control, tmp_path, budget_sum_param):
[False, True],
ids=["grp_by_process", "separate"],
)
def test_separate_together(domain, control, tmp_path, separate):
def test_separate_together(domain, control, params, tmp_path, separate):
tmp_dir = pl.Path(tmp_path)

model_procs = [
Expand All @@ -181,8 +183,10 @@ def test_separate_together(domain, control, tmp_path, separate):
shutil.copy(ff, input_dir / ff.name)

model = Model(
*model_procs,
model_procs,
control=control,
discretization_dict=None,
parameters=params,
input_dir=input_dir,
budget_type=budget_type,
)
Expand Down
10 changes: 6 additions & 4 deletions autotest/test_nhm_self_drive.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@ def test_drive_indiv_process(domain, tmp_path):
nhm_output_dir = pl.Path(tmp_path) / "nhm_output"

params = pws.parameters.PrmsParameters.load(domain["param_file"])
control = pws.Control.load(domain["control_file"], params=params)
control = pws.Control.load(domain["control_file"])
control.edit_n_time_steps(n_time_steps)

nhm = pws.Model(
*nhm_processes,
nhm_processes,
control=control,
parameters=params,
input_dir=domain["prms_run_dir"],
budget_type="warn",
calc_method="numba",
Expand All @@ -56,12 +57,13 @@ def test_drive_indiv_process(domain, tmp_path):
proc_model_output_dir.mkdir()

params = pws.parameters.PrmsParameters.load(domain["param_file"])
control = pws.Control.load(domain["control_file"], params=params)
control = pws.Control.load(domain["control_file"])
control.edit_n_time_steps(n_time_steps)

proc_model = pws.Model(
proc,
[proc],
control=control,
parameters=params,
input_dir=nhm_output_dir,
budget_type="warn",
calc_method="numba",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def params(domain):


@pytest.fixture(scope="function")
def control(domain, params):
return Control.load(domain["control_file"], params=params)
def control(domain):
return Control.load(domain["control_file"])


class TestPRMSCanopyRunoffDomain:
def test_init(self, domain, control, tmp_path):
def test_init(self, domain, control, params, tmp_path):
tmp_path = pl.Path(tmp_path)

# get the answer data
Expand All @@ -48,7 +48,12 @@ def test_init(self, domain, control, tmp_path):
nc_pth = output_dir / f"{key}.nc"
input_variables[key] = nc_pth

canopy = PRMSCanopy(control=control, **input_variables)
canopy = PRMSCanopy(
control=control,
discretization=None,
parameters=params,
**input_variables,
)

# instantiate runoff
input_variables = {}
Expand All @@ -60,7 +65,12 @@ def test_init(self, domain, control, tmp_path):
input_variables["net_ppt"] = None
input_variables["net_rain"] = None
input_variables["net_snow"] = None
runoff = PRMSRunoff(control=control, **input_variables)
runoff = PRMSRunoff(
control=control,
discretization=None,
parameters=params,
**input_variables,
)

# wire up output from canopy as input to runoff
runoff.set_input_to_adapter(
Expand Down
2 changes: 1 addition & 1 deletion autotest/test_prms_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
fail_fast = False

calc_methods = ("numpy", "numba", "fortran")
params = ["params_sep", "params_one"]
params = ("params_sep", "params_one")


@pytest.fixture(scope="function")
Expand Down
8 changes: 5 additions & 3 deletions autotest/test_et.py → autotest/test_prms_et.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ def params(domain):


@pytest.fixture(scope="function")
def control(domain, params):
return Control.load(domain["control_file"], params=params)
def control(domain):
return Control.load(domain["control_file"])


class TestPRMSEt:
def test_init(self, domain, control, tmp_path):
def test_init(self, domain, control, params, tmp_path):
tmp_path = pl.Path(tmp_path)
output_dir = domain["prms_output_dir"]

Expand All @@ -31,6 +31,8 @@ def test_init(self, domain, control, tmp_path):

et = PRMSEt(
control=control,
discretization=None,
parameters=params,
budget_type="strict",
**et_inputs,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@ def params(domain):


@pytest.fixture(scope="function")
def control(domain, params):
return Control.load(domain["control_file"], params=params)
def control(domain):
return Control.load(domain["control_file"])


class TestPRMSCanopyRunoffDomain:
def test_init(self, domain, control, tmp_path):
def test_init(self, domain, control, params, tmp_path):
tmp_path = pl.Path(tmp_path)
output_dir = domain["prms_output_dir"]

Expand Down Expand Up @@ -78,6 +78,8 @@ def test_init(self, domain, control, tmp_path):

et = PRMSEt(
control=control,
discretization=None,
parameters=params,
budget_type="error",
**et_inputs,
)
Expand All @@ -94,6 +96,8 @@ def test_init(self, domain, control, tmp_path):

canopy = PRMSCanopy(
control=control,
discretization=None,
parameters=params,
budget_type="error",
**canopy_inputs,
)
Expand Down Expand Up @@ -121,6 +125,8 @@ def test_init(self, domain, control, tmp_path):

runoff = PRMSRunoff(
control=control,
discretization=None,
parameters=params,
**runoff_inputs,
budget_type=None,
)
Expand Down
Loading

0 comments on commit afbe91c

Please sign in to comment.