Skip to content

Commit

Permalink
set rng in numba, and make all calls to random in numba jitted functions
Browse files Browse the repository at this point in the history
  • Loading branch information
amoodie committed May 9, 2020
1 parent 01bfe2b commit 9fda2c2
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 10 deletions.
6 changes: 5 additions & 1 deletion pyDeltaRCM/init_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
import logging
import time
import yaml

from . import utils


# tools for initiating deltaRCM model domain


Expand Down Expand Up @@ -98,7 +102,7 @@ def import_files(self):
if self.seed is not None:
if self.verbose >= 2:
print("setting random seed to %s " % str(self.seed))
np.random.seed(self.seed)
utils.set_random_seed(self.seed)

def set_constants(self):

Expand Down
8 changes: 4 additions & 4 deletions pyDeltaRCM/sed_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,8 +240,8 @@ def sand_route(self):
theta_sed = self.theta_sand

num_starts = int(self.Np_sed * self.f_bedload)
start_indices = [self.random_pick_inlet(
self.inlet) for x in range(num_starts)]
typed_inlet = self.inlet_typed
start_indices = [utils.random_pick_inlet(typed_inlet) for x in range(num_starts)]

for np_sed in range(num_starts):

Expand Down Expand Up @@ -282,8 +282,8 @@ def mud_route(self):
theta_sed = self.theta_mud

num_starts = int(self.Np_sed * (1 - self.f_bedload))
start_indices = [self.random_pick_inlet(
self.inlet) for x in range(num_starts)]
typed_inlet = self.inlet_typed
start_indices = [utils.random_pick_inlet(typed_inlet) for x in range(num_starts)]

for np_sed in range(num_starts):

Expand Down
10 changes: 10 additions & 0 deletions pyDeltaRCM/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,16 @@
# utilities used in various places in the model and docs


@numba.njit
def set_random_seed(_seed):
np.random.seed(_seed)


@numba.njit
def get_random_uniform(N):
return np.random.uniform(0, 1, N)


@numba.njit
def random_pick(probs):
"""
Expand Down
11 changes: 6 additions & 5 deletions tests/test_yaml_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from pyDeltaRCM.deltaRCM_driver import pyDeltaRCM

from pyDeltaRCM.utils import set_random_seed, get_random_uniform

# utilities for file writing
def create_temporary_file(tmp_path, file_name):
Expand Down Expand Up @@ -105,13 +106,13 @@ def test_random_seed_settings_value(tmp_path):
p, f = create_temporary_file(tmp_path, file_name)
write_parameter_to_file(f, 'seed', 9999)
f.close()
np.random.seed(9999)
_preval_same = np.random.uniform()
np.random.seed(5)
_preval_diff = np.random.uniform(1000)
set_random_seed(9999)
_preval_same = get_random_uniform(1)
set_random_seed(5)
_preval_diff = get_random_uniform(1000)
delta = pyDeltaRCM(input_file=p)
assert delta.seed == 9999
_postval_same = np.random.uniform()
_postval_same = get_random_uniform(1)
assert _preval_same == _postval_same


Expand Down

0 comments on commit 9fda2c2

Please sign in to comment.