Skip to content

Commit

Permalink
Fix seeding for vectorized environments and BaseRLModel (#676)
Browse files Browse the repository at this point in the history
* Add a seed() method to vectorized environments (fixes #675).

* Updated Changelog

* Updated Changelog (again) and added type hints.

* Update docstring

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
NeoExtended and araffin committed Feb 3, 2020
1 parent c6acd1e commit 4476ef2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 8 deletions.
4 changes: 3 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ New Features:
^^^^^^^^^^^^^
- Parallelized updating and sampling from the replay buffer in DQN. (@flodorner)
- Docker build script, `scripts/build_docker.sh`, can push images automatically.
- Added a seeding method for vectorized environments. (@NeoExtended)

Bug Fixes:
^^^^^^^^^^
Expand All @@ -33,6 +34,7 @@ Bug Fixes:
`self.runner` instead of reinitializing a new Runner every time `learn()` is called.
- Fixed a bug in `check_env` where it would fail on high dimensional action spaces
- Fixed `Monitor.close()` that was not calling the parent method
- Fixed a bug in `BaseRLModel` when seeding vectorized environments. (@NeoExtended)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -612,4 +614,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward @jbulow @Antymon @seheevic @justinkterry @edbeeching
@flodorner @KuKuXia
@flodorner @KuKuXia @NeoExtended
7 changes: 1 addition & 6 deletions stable_baselines/common/base_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,7 @@ def set_random_seed(self, seed):
# Seed python, numpy and tf random generator
set_global_seeds(seed)
if self.env is not None:
if isinstance(self.env, VecEnv):
# Use a different seed for each env
for idx in range(self.env.num_envs):
self.env.env_method("seed", seed + idx)
else:
self.env.seed(seed)
self.env.seed(seed)
# Seed the action space
# useful when selecting random actions
self.env.action_space.seed(seed)
Expand Down
17 changes: 16 additions & 1 deletion stable_baselines/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
import inspect
import pickle
from typing import Sequence
from typing import Sequence, Optional, List, Union

import cloudpickle
import numpy as np
Expand Down Expand Up @@ -127,6 +127,18 @@ def env_method(self, method_name, *method_args, indices=None, **method_kwargs):
"""
pass

@abstractmethod
def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
"""
Sets the random seeds for all environments, based on a given seed.
Each individual environment will still get its own seed, by incrementing the given seed.
:param seed: (Optional[int]) The random seed. May be None for completely random seeding.
:return: (List[Union[None, int]]) Returns a list containing the seeds for each individual env.
Note that all list elements may be None, if the env does not return anything when being seeded.
"""
pass

def step(self, actions):
"""
Step the environments with the given action
Expand Down Expand Up @@ -225,6 +237,9 @@ def reset(self):
def step_wait(self):
pass

def seed(self, seed=None):
return self.venv.seed(seed)

def close(self):
return self.venv.close()

Expand Down
6 changes: 6 additions & 0 deletions stable_baselines/common/vec_env/dummy_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def step_wait(self):
return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
self.buf_infos.copy())

def seed(self, seed=None):
seeds = list()
for idx, env in enumerate(self.envs):
seeds.append(env.seed(seed + idx))
return seeds

def reset(self):
for env_idx in range(self.num_envs):
obs = self.envs[env_idx].reset()
Expand Down
7 changes: 7 additions & 0 deletions stable_baselines/common/vec_env/subproc_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ def _worker(remote, parent_remote, env_fn_wrapper):
info['terminal_observation'] = observation
observation = env.reset()
remote.send((observation, reward, done, info))
elif cmd == 'seed':
remote.send(env.seed(data))
elif cmd == 'reset':
observation = env.reset()
remote.send(observation)
Expand Down Expand Up @@ -107,6 +109,11 @@ def step_wait(self):
obs, rews, dones, infos = zip(*results)
return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos

def seed(self, seed=None):
for idx, remote in enumerate(self.remotes):
remote.send(('seed', seed + idx))
return [remote.recv() for remote in self.remotes]

def reset(self):
for remote in self.remotes:
remote.send(('reset', None))
Expand Down

0 comments on commit 4476ef2

Please sign in to comment.