Skip to content

Commit

Permalink
Refactor: control does not take/know parameters, pass dis and paramet…
Browse files Browse the repository at this point in the history
…ers to individual processes, individual process tests passing
  • Loading branch information
jmccreight committed Jun 15, 2023
1 parent 65626cb commit 6be7331
Show file tree
Hide file tree
Showing 24 changed files with 707 additions and 649 deletions.
239 changes: 130 additions & 109 deletions autotest/test_prms_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,118 +12,139 @@
params = ["params_sep", "params_one"]


# @pytest.fixture(scope="function", params=params)
# def control(domain, request):
# if request.param == "params_one":
# params = PrmsParameters.load(domain["param_file"])
# dis = None

# else:
# # channel needs both hru and seg dis files
# dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
# dis_data = Parameters.merge(
# Parameters.from_netcdf(dis_hru_file, encoding=False),
# )
# dis = {"dis_hru": dis_data}

# param_file = domain["dir"] / "parameters_PRMSAtmosphere.nc"
# params = {"PRMSAtmosphere": PrmsParameters.from_netcdf(param_file)}

# return Control.load(domain["control_file"], params=params, dis=dis)


@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])


@pytest.fixture(scope="function")
def discretization(domain):
dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
return Parameters.from_netcdf(dis_hru_file, encoding=False)


@pytest.fixture(scope="function", params=params)
def control(domain, request):
def parameters(domain, request):
if request.param == "params_one":
params = PrmsParameters.load(domain["param_file"])
dis = None

else:
# channel needs both hru and seg dis files
dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
dis_data = Parameters.merge(
Parameters.from_netcdf(dis_hru_file, encoding=False),
)
dis = {"dis_hru": dis_data}

param_file = domain["dir"] / "parameters_PRMSAtmosphere.nc"
params = {"PRMSAtmosphere": PrmsParameters.from_netcdf(param_file)}

return Control.load(domain["control_file"], params=params, dis=dis)


class TestPRMSAtmosphere:
def test_init(self, domain, control, tmp_path):
output_dir = domain["prms_output_dir"]
cbh_dir = domain["cbh_inputs"]["prcp"].parent.resolve()

# get the answer data
comparison_var_names = [
"tmaxf",
"tminf",
"hru_ppt",
"hru_rain",
"hru_snow",
"swrad",
"potet",
"transp_on",
"tmaxc",
"tavgc",
"tminc",
"prmx",
"pptmix",
"orad_hru",
]
ans = {}
for key in comparison_var_names:
nc_pth = output_dir / f"{key}.nc"
ans[key] = adapter_factory(
nc_pth, variable_name=key, control=control
)

input_variables = {}
for key in PRMSAtmosphere.get_inputs():
dir = ""
if "soltab" in key:
dir = "output/"
nc_pth = cbh_dir / f"{dir}{key}.nc"
input_variables[key] = nc_pth

atm = PRMSAtmosphere(
control=control,
**input_variables,
budget_type=None,
netcdf_output_dir=tmp_path,
)

all_success = True
for istep in range(control.n_times):
control.advance()
atm.advance()
atm.calculate(1.0)
# print(atm.budget)

# compare along the way
for key, val in ans.items():
val.advance()

for key in ans.keys():
a1 = ans[key].current
a2 = atm[key].current

params = PrmsParameters.from_netcdf(param_file)

return params


def test_compare_prms(domain, control, discretization, parameters, tmp_path):
output_dir = domain["prms_output_dir"]
cbh_dir = domain["cbh_inputs"]["prcp"].parent.resolve()

# get the answer data
comparison_var_names = [
"tmaxf",
"tminf",
"hru_ppt",
"hru_rain",
"hru_snow",
"swrad",
"potet",
"transp_on",
"tmaxc",
"tavgc",
"tminc",
"prmx",
"pptmix",
"orad_hru",
]
ans = {}
for key in comparison_var_names:
nc_pth = output_dir / f"{key}.nc"
ans[key] = adapter_factory(nc_pth, variable_name=key, control=control)

input_variables = {}
for key in PRMSAtmosphere.get_inputs():
dir = ""
if "soltab" in key:
dir = "output/"
nc_pth = cbh_dir / f"{dir}{key}.nc"
input_variables[key] = nc_pth

atm = PRMSAtmosphere(
control=control,
discretization=discretization,
parameters=parameters,
**input_variables,
budget_type=None,
netcdf_output_dir=tmp_path,
)

all_success = True
for istep in range(control.n_times):
control.advance()
atm.advance()
atm.calculate(1.0)
# print(atm.budget)

# compare along the way
for key, val in ans.items():
val.advance()

for key in ans.keys():
a1 = ans[key].current
a2 = atm[key].current

tol = 1e-5
if key == "swrad":
tol = 5e-4
warn(f"using tol = {tol} for variable {key}")
if key == "tavgc":
tol = 1e-5
if key == "swrad":
tol = 5e-4
warn(f"using tol = {tol} for variable {key}")
if key == "tavgc":
tol = 1e-5
warn(f"using tol = {tol} for variable {key}")

success_a = np.allclose(a2, a1, atol=tol, rtol=0.00)
success_r = np.allclose(a2, a1, atol=0.00, rtol=tol)
success = False
if (not success_a) and (not success_r):
diff = a2 - a1
diffratio = abs(diff / a2)
if (diffratio < 1e-6).all():
success = True
continue
all_success = False
diffmin = diff.min()
diffmax = diff.max()
abs_diff = abs(diff)
absdiffmax = abs_diff.max()
wh_absdiffmax = np.where(abs_diff)[0]
print(f"time step {istep}")
print(f"output variable {key}")
print(f"prms {a1.min()} {a1.max()}")
print(f"pywatershed {a2.min()} {a2.max()}")
print(f"diff {diffmin} {diffmax}")
print(f"absdiffmax {absdiffmax}")
print(f"wh_absdiffmax {wh_absdiffmax}")
assert success

atm.finalize()

if not all_success:
raise Exception("pywatershed results do not match prms results")
warn(f"using tol = {tol} for variable {key}")

success_a = np.allclose(a2, a1, atol=tol, rtol=0.00)
success_r = np.allclose(a2, a1, atol=0.00, rtol=tol)
success = False
if (not success_a) and (not success_r):
diff = a2 - a1
diffratio = abs(diff / a2)
if (diffratio < 1e-6).all():
success = True
continue
all_success = False
diffmin = diff.min()
diffmax = diff.max()
abs_diff = abs(diff)
absdiffmax = abs_diff.max()
wh_absdiffmax = np.where(abs_diff)[0]
print(f"time step {istep}")
print(f"output variable {key}")
print(f"prms {a1.min()} {a1.max()}")
print(f"pywatershed {a2.min()} {a2.max()}")
print(f"diff {diffmin} {diffmax}")
print(f"absdiffmax {absdiffmax}")
print(f"wh_absdiffmax {wh_absdiffmax}")
assert success

atm.finalize()

if not all_success:
raise Exception("pywatershed results do not match prms results")
73 changes: 18 additions & 55 deletions autotest/test_prms_canopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,77 +6,38 @@
from pywatershed.base.adapter import adapter_factory
from pywatershed.base.control import Control
from pywatershed.hydrology.PRMSCanopy import PRMSCanopy, has_prmscanopy_f
from pywatershed.parameters import PrmsParameters
from pywatershed.parameters import Parameters, PrmsParameters

calc_methods = ("numpy", "numba", "fortran")
params = ("params_sep", "params_one")


def test_simple():
time_dict = {
"start_time": np.datetime64("1979-01-03T00:00:00.00"),
"end_time": np.datetime64("1979-01-04T00:00:00.00"),
"time_step": np.timedelta64(1, "D"),
}

nhru = 2
prms_params = {
"dims": {"nhru": nhru},
"data_vars": {
"hru_area": np.array(nhru * [1.0]),
"covden_sum": np.array(nhru * [0.5]),
"covden_win": np.array(nhru * [0.5]),
"srain_intcp": np.array(nhru * [1.0]),
"wrain_intcp": np.array(nhru * [1.0]),
"snow_intcp": np.array(nhru * [1.0]),
"epan_coef": np.array(nhru * [1.0]),
"potet_sublim": np.array(nhru * [1.0]),
"cov_type": np.array(nhru * [1]),
},
}
prms_params["metadata"] = {}
for kk in prms_params["data_vars"].keys():
prms_params["metadata"][kk] = {"dims": ("nhru",)}

prms_params = PrmsParameters(**prms_params)

control = Control(**time_dict, params=prms_params)
@pytest.fixture(scope="function")
def control(domain):
return Control.load(domain["control_file"])

input_variables = {}
for key in PRMSCanopy.get_inputs():
input_variables[key] = np.ones([nhru])

# todo: this is testing instantiation, but not physics
cnp = PRMSCanopy(
control=control,
**input_variables,
budget_type="error",
)
control.advance()
cnp.advance()
cnp.calculate(time_length=1.0)

return
@pytest.fixture(scope="function")
def discretization(domain):
dis_hru_file = domain["dir"] / "parameters_dis_hru.nc"
return Parameters.from_netcdf(dis_hru_file, encoding=False)


@pytest.fixture(scope="function", params=["params_sep", "params_one"])
def params(domain, request):
@pytest.fixture(scope="function", params=params)
def parameters(domain, request):
if request.param == "params_one":
params = PrmsParameters.load(domain["param_file"])
else:
params = PrmsParameters.from_netcdf(
domain["dir"] / "parameters_PRMSCanopy.nc"
)
param_file = domain["dir"] / "parameters_PRMSCanopy.nc"
params = PrmsParameters.from_netcdf(param_file)

return params


@pytest.fixture(scope="function")
def control(domain, params):
return Control.load(domain["control_file"], params=params)


@pytest.mark.parametrize("calc_method", calc_methods)
def test_compare_prms(domain, control, tmp_path, calc_method):
def test_compare_prms(
domain, control, discretization, parameters, tmp_path, calc_method
):
if not has_prmscanopy_f and calc_method == "fortran":
pytest.skip(
"PRMSCanopy fortran code not available, skipping its test."
Expand Down Expand Up @@ -109,6 +70,8 @@ def test_compare_prms(domain, control, tmp_path, calc_method):

cnp = PRMSCanopy(
control=control,
discretization=discretization,
parameters=parameters,
**input_variables,
budget_type="error",
calc_method=calc_method,
Expand Down
Loading

0 comments on commit 6be7331

Please sign in to comment.