Skip to content

Commit

Permalink
refactor: calc_method clean up on channel and gw
Browse files Browse the repository at this point in the history
  • Loading branch information
jmccreight committed Jun 22, 2023
1 parent 2c2526c commit cc44610
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 117 deletions.
41 changes: 20 additions & 21 deletions pywatershed/hydrology/PRMSChannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,27 +344,26 @@ def _init_calc_method(self):
if self._calc_method.lower() == "numba":
import numba as nb

if not hasattr(self, "_muskingum_mann_numba"):
numba_msg = f"{self.name} jit compiling with numba "
# this method can not be parallelized (? true?)
print(numba_msg, flush=True)

self._muskingum_mann = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int64[:], # _segment_order
nb.int64[:], # _tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int64[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
parallel=False,
)(self._muskingum_mann_numpy)
numba_msg = f"{self.name} jit compiling with numba "
# this method can not be parallelized (? true?)
print(numba_msg, flush=True)

self._muskingum_mann = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int64[:], # _segment_order
nb.int64[:], # _tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int64[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
parallel=False,
)(self._muskingum_mann_numpy)

elif self._calc_method.lower() in ["none", "numpy"]:
self._muskingum_mann = self._muskingum_mann_numpy
Expand Down
154 changes: 58 additions & 96 deletions pywatershed/hydrology/PRMSGroundwater.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ def __init__(

self._set_inputs(locals())
self._set_budget(budget_type)

if calc_method == "numba":
# read-only arrays dont have numba signatures
self._hru_area = self.hru_area.copy()
self._gwflow_coef = self.gwflow_coef.copy()
self._gwsink_coef = self.gwsink_coef.copy()
self._init_calc_method()

return

Expand Down Expand Up @@ -132,6 +127,46 @@ def _set_initial_conditions(self):
self.gwres_stor_old[:] = self.gwstor_init.copy()
return

def _init_calc_method(self):
if self._calc_method.lower() == "numba":
import numba as nb

numba_msg = f"{self.name} jit compiling with numba "
nb_parallel = (numba_num_threads is not None) and (
numba_num_threads > 1
)
if nb_parallel:
numba_msg += f"and using {numba_num_threads} threads"
print(numba_msg, flush=True)

self._calculate_gw = nb.njit(
nb.types.UniTuple(nb.float64[:], 5)(
nb.types.Array(nb.types.float64, 1, "C", readonly=True),
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.types.Array(nb.types.float64, 1, "C", readonly=True),
nb.types.Array(nb.types.float64, 1, "C", readonly=True),
nb.float64[:],
nb.types.Array(nb.types.float64, 1, "C", readonly=True),
),
fastmath=True,
parallel=False,
)(self._calculate_numpy)

elif self._calc_method.lower() in ["none", "numpy"]:
self._calculate_gw = self._calculate_numpy

elif self._calc_method.lower() == "fortran":
self._calculate_gw = _calculate_fortran

else:
msg = f"Invalid calc_method={self._calc_method} for {self.name}"
raise ValueError(msg)

return

def _advance_variables(self) -> None:
"""Advance the groundwater reservoir variables
Returns:
Expand All @@ -150,97 +185,24 @@ def _calculate(self, simulation_time):
None
"""

self._simulation_time = simulation_time

if self._calc_method.lower() == "numba":
import numba as nb

if not hasattr(self, "_calculate_numba"):
numba_msg = f"{self.name} jit compiling with numba "
nb_parallel = (numba_num_threads is not None) and (
numba_num_threads > 1
)
if nb_parallel:
numba_msg += f"and using {numba_num_threads} threads"
print(numba_msg, flush=True)

self._calculate_numba = nb.njit(
nb.types.UniTuple(nb.float64[:], 5)(
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.float64[:],
nb.types.Array(nb.types.float64, 1, "C", readonly=True)
# nb.float64[:],
),
parallel=False,
)(self._calculate_numpy)

(
self.gwres_stor[:],
self.gwres_flow[:],
self.gwres_sink[:],
self.gwres_stor_change[:],
self.gwres_flow_vol[:],
) = self._calculate_numba(
self._hru_area,
self.soil_to_gw,
self.ssr_to_gw,
self.dprst_seep_hru,
self.gwres_stor,
self._gwflow_coef,
self._gwsink_coef,
self.gwres_stor_old,
self.hru_in_to_cf,
)

elif self._calc_method.lower() == "fortran":
(
self.gwres_stor[:],
self.gwres_flow[:],
self.gwres_sink[:],
self.gwres_stor_change[:],
self.gwres_flow_vol[:],
) = _calculate_fortran(
self.hru_area,
self.soil_to_gw,
self.ssr_to_gw,
self.dprst_seep_hru,
self.gwres_stor,
self.gwflow_coef,
self.gwsink_coef,
self.gwres_stor_old,
self.hru_in_to_cf,
)

elif self._calc_method.lower() in ["none", "numpy"]:
(
self.gwres_stor[:],
self.gwres_flow[:],
self.gwres_sink[:],
self.gwres_stor_change[:],
self.gwres_flow_vol[:],
) = self._calculate_numpy(
self.hru_area,
self.soil_to_gw,
self.ssr_to_gw,
self.dprst_seep_hru,
self.gwres_stor,
self.gwflow_coef,
self.gwsink_coef,
self.gwres_stor_old,
self.hru_in_to_cf,
)

else:
msg = f"Invalid calc_method={self._calc_method} for {self.name}"
raise ValueError(msg)

(
self.gwres_stor[:],
self.gwres_flow[:],
self.gwres_sink[:],
self.gwres_stor_change[:],
self.gwres_flow_vol[:],
) = self._calculate_gw(
self.hru_area,
self.soil_to_gw,
self.ssr_to_gw,
self.dprst_seep_hru,
self.gwres_stor,
self.gwflow_coef,
self.gwsink_coef,
self.gwres_stor_old,
self.hru_in_to_cf,
)
return

@staticmethod
Expand Down

0 comments on commit cc44610

Please sign in to comment.