Skip to content

Commit

Permalink
Refactor: PRMSChannel init self._muskingum_mann from numpy, numba, fo…
Browse files Browse the repository at this point in the history
…rtran instead if during calculate
  • Loading branch information
jmccreight committed May 26, 2023
1 parent dd99eb7 commit e010992
Showing 1 changed file with 78 additions and 115 deletions.
193 changes: 78 additions & 115 deletions pywatershed/hydrology/PRMSChannel.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ def __init__(
# process channel data
self._initialize_channel_data()

self._init_calc_method()

return

@staticmethod
Expand Down Expand Up @@ -333,6 +335,62 @@ def _initialize_channel_data(self) -> None:

return

def _init_calc_method(self):
if self._calc_method.lower() == "numba":
import numba as nb

if not hasattr(self, "_muskingum_mann_numba"):
numba_msg = f"{self.name} jit compiling with numba "
# this method can not be parallelized
print(numba_msg, flush=True)

# This is annoying that long integers on windows are 32bit
if platform.system() == "Windows":
self._muskingum_mann = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int32[:], # _segment_order
nb.int32[:], # tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int32[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
parallel=False,
)(self._muskingum_mann_numpy)

else:
self._muskingum_mann = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int64[:], # _segment_order
nb.int64[:], # tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int64[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
parallel=False,
)(self._muskingum_mann_numpy)

elif self._calc_method.lower() in ["none", "numpy"]:
self._muskingum_mann = self._muskingum_mann_numpy

elif self._calc_method.lower() == "fortran":
self._muskingum_mann = _calculate_fortran

else:
msg = f"Invalid calc_method={self._calc_method} for {self.name}"
raise ValueError(msg)

def _advance_variables(self) -> None:
"""Advance the channel segment variables
Returns:
Expand Down Expand Up @@ -387,122 +445,27 @@ def _calculate(self, simulation_time: float) -> None:
self.seg_lateral_inflow[iseg] += lateral_inflow

# solve muskingum_mann routing
if self._calc_method.lower() == "numba":
import numba as nb

if not hasattr(self, "_muskingum_mann_numba"):
numba_msg = f"{self.name} jit compiling with numba "
nb_parallel = (numba_num_threads is not None) and (
numba_num_threads > 1
)
if nb_parallel:
numba_msg += f"and using {numba_num_threads} threads"
print(numba_msg, flush=True)

# This is annoying that long integers on windows are 32bit
if platform.system() == "Windows":
self._muskingum_mann_numba = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int32[:], # _segment_order
nb.int32[:], # tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int32[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
parallel=False,
)(self._muskingum_mann_numpy)

else:
self._muskingum_mann_numba = nb.njit(
nb.types.UniTuple(nb.float64[:], 7)(
nb.int64[:], # _segment_order
nb.int64[:], # tosegment
nb.float64[:], # seg_lateral_inflow
nb.float64[:], # _seg_inflow0
nb.float64[:], # _outflow_ts
nb.int64[:], # _tsi
nb.float64[:], # _ts
nb.float64[:], # _c0
nb.float64[:], # _c1
nb.float64[:], # _c2
),
fastmath=True,
)(self._muskingum_mann_numpy)

(
self.seg_upstream_inflow[:],
self._seg_inflow0[:],
self._seg_inflow[:],
self.seg_outflow[:],
self._inflow_ts[:],
self._outflow_ts[:],
self._seg_current_sum[:],
) = self._muskingum_mann_numba(
self._segment_order,
self.tosegment,
self.seg_lateral_inflow,
self._seg_inflow0,
self._outflow_ts,
self._tsi,
self._ts,
self._c0,
self._c1,
self._c2,
)

elif self._calc_method.lower() in ["none", "numpy"]:
(
self.seg_upstream_inflow[:],
self._seg_inflow0[:],
self._seg_inflow[:],
self.seg_outflow[:],
self._inflow_ts[:],
self._outflow_ts[:],
self._seg_current_sum[:],
) = self._muskingum_mann_numpy(
self._segment_order,
self.tosegment,
self.seg_lateral_inflow,
self._seg_inflow0,
self._outflow_ts,
self._tsi,
self._ts,
self._c0,
self._c1,
self._c2,
)

elif self._calc_method.lower() == "fortran":
(
self.seg_upstream_inflow[:],
self._seg_inflow0[:],
self._seg_inflow[:],
self.seg_outflow[:],
self._inflow_ts[:],
self._outflow_ts[:],
self._seg_current_sum[:],
) = _calculate_fortran(
self._segment_order,
self.tosegment,
self.seg_lateral_inflow,
self._seg_inflow0,
self._outflow_ts,
self._tsi,
self._ts,
self._c0,
self._c1,
self._c2,
)

else:
msg = f"Invalid calc_method={self._calc_method} for {self.name}"
raise ValueError(msg)
(
self.seg_upstream_inflow[:],
self._seg_inflow0[:],
self._seg_inflow[:],
self.seg_outflow[:],
self._inflow_ts[:],
self._outflow_ts[:],
self._seg_current_sum[:],
) = self._muskingum_mann(
self._segment_order,
self.tosegment,
self.seg_lateral_inflow,
self._seg_inflow0,
self._outflow_ts,
self._tsi,
self._ts,
self._c0,
self._c1,
self._c2,
)

self.seg_stor_change[:] = (
self._seg_inflow - self.seg_outflow
Expand Down

0 comments on commit e010992

Please sign in to comment.