Skip to content

Commit

Permalink
Pr revert 1658 (#1659)
Browse files Browse the repository at this point in the history
* Revert "fix random seed (#1658)"

This reverts commit 101d7a1.

* revert PR 1658 and add more docstring for random seed

* update
  • Loading branch information
hnyu authored Jun 5, 2024
1 parent 4639a81 commit bbc9664
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 29 deletions.
13 changes: 12 additions & 1 deletion alf/algorithms/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,18 @@ def __init__(self,
normalized whereas hindsight data directly pulled from the replay buffer
will not be normalized. Data will be in mismatch, causing training to
suffer and potentially fail.
random_seed (None|int): random seed, a random seed is used if None
random_seed (None|int): random seed, a random seed is used if None.
If the random seed is None, every DDP rank (when multi-GPU training is
used) will have ``None`` set as its ``TrainerConfig.random_seed``.
This means that the actual random seed used by each rank is chosen at
random. A None random seed won't set a deterministic torch behavior.
If a specific random seed is set, DDP rank>0 (if multi-gpu training
used) will have a random seed set to a value that is deterministically
"randomized" from this random seed. In this case, all ranks will
have a deterministic torch behavior. NOTE: By the current design,
you won't be able to reproduce a training job if its random seed
was set as None. For reproducible training jobs, always set the
random seed in the first place.
num_iterations (int): For RL trainer, indicates number of update
iterations (ignored if 0). Note that for off-policy algorithms, if
``initial_collect_steps>0``, then the first
Expand Down
21 changes: 6 additions & 15 deletions alf/config_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,22 +269,13 @@ def get_env():
#
# The random.seed(random_seed) is temporary and will be overridden by
# set_random_seed() below.
random.seed(random_seed)

# ``random_seed`` here is the *meta* random seed. We will use it to generate
# an actual random seed for each process.
if random_seed is None:
# If the user does not specify the meta random seed, we will generate a
# "random" meta random seed here.
random.seed(None)
random_seed = random.randint(0, 2**32)

# Adjust the random seed based on the DDP rank. Note that for single process,
# the ddp rank is -1. So single process or rank=0 won't change the random seed.
# This means that if a user provides a random seed shown in the tensor
# board, it will be used as is for single process and rank=0. This ensures
# that the training can be reproduced.
for _ in range(PerProcessContext().ddp_rank):
random_seed = random.randint(0, 2**32)
if random_seed is not None:
# If random seed is None, we will have None for other ranks, too.
# A 'None' random seed won't set a deterministic torch behavior.
for _ in range(PerProcessContext().ddp_rank):
random_seed = random.randint(0, 2**32)
config1("TrainerConfig.random_seed", random_seed, raise_if_used=False)

# We have to call set_random_seed() here because we need the actual
Expand Down
17 changes: 4 additions & 13 deletions alf/environments/process_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,21 +300,12 @@ def start(self, wait_to_start=True):
ddp_num_procs = PerProcessContext().num_processes
ddp_rank = PerProcessContext().ddp_rank

pre_configs = alf.get_handled_pre_configs()
if self._start_method == 'spawn':
# If we spawn an env, TrainerConfig.random_seed of the main process
# won't be inherited. In case we need this info later in the env, we
# need to pass it explicitly, so that later we can do
# ``alf.get_config_value("TrainerConfig.random_seed")``.
pre_configs.append(
('TrainerConfig.random_seed',
alf.get_config_value('TrainerConfig.random_seed')))
self._process = mp_ctx.Process(
target=_worker,
args=(conn, self._env_constructor, self._start_method, pre_configs,
self._env_id, self._flatten, self._fast, self._num_envs,
self._torch_num_threads, ddp_num_procs, ddp_rank,
self._name),
args=(conn, self._env_constructor, self._start_method,
alf.get_handled_pre_configs(), self._env_id, self._flatten,
self._fast, self._num_envs, self._torch_num_threads,
ddp_num_procs, ddp_rank, self._name),
name=f"ProcessEnvironment-{self._env_id}")
atexit.register(self.close)
self._process.start()
Expand Down

0 comments on commit bbc9664

Please sign in to comment.