Skip to content

Commit

Permalink
Merge branch 'master' into feat/gymnasium-support
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin committed Apr 12, 2023
2 parents 77b0950 + 15c9daa commit 63307d4
Show file tree
Hide file tree
Showing 8 changed files with 46 additions and 11 deletions.
6 changes: 4 additions & 2 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a1 (WIP)
Release 2.0.0a3 (WIP)
--------------------------

**Gymnasium support**
Expand All @@ -19,6 +19,7 @@ Breaking Changes:
- Switched to Gymnasium as primary backend, Gym 0.21 and 0.26 are still supported via the ``shimmy`` package
- The deprecated ``online_sampling`` argument of ``HerReplayBuffer`` was removed
- Removed deprecated ``stack_observation_space`` method of ``StackedObservations``
- Renamed environment output observations in ``evaluate_policy`` to prevent shadowing the input observations during callbacks (@npit)

New Features:
^^^^^^^^^^^^^
Expand All @@ -31,6 +32,7 @@ New Features:

Bug Fixes:
^^^^^^^^^^
- Fixed ``VecExtractDictObs`` not handling the terminal observation (@WeberSamuel)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1314,4 +1316,4 @@ And all the contributors:
@carlosluis @arjun-kg @tlpss
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @FieteO @jonasreiher @npit @WeberSamuel
4 changes: 3 additions & 1 deletion stable_baselines3/common/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def evaluate_policy(
episode_start=episode_starts,
deterministic=deterministic,
)
observations, rewards, dones, infos = env.step(actions)
new_observations, rewards, dones, infos = env.step(actions)
current_rewards += rewards
current_lengths += 1
for i in range(n_envs):
Expand Down Expand Up @@ -125,6 +125,8 @@ def evaluate_policy(
current_rewards[i] = 0
current_lengths[i] = 0

observations = new_observations

if render:
env.render()

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/vec_env/base_vec_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def getattr_recursive(self, name: str) -> Any:

return attr

def getattr_depth_check(self, name: str, already_found: bool) -> str:
def getattr_depth_check(self, name: str, already_found: bool) -> Optional[str]:
"""See base class.
:return: name of module whose attribute is being shadowed, if any.
Expand Down
6 changes: 4 additions & 2 deletions stable_baselines3/common/vec_env/stacked_observations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,16 @@ def reset(self, observation: TObs) -> TObs:
:return: The stacked reset observation
"""
if isinstance(observation, dict):
return {key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()}
return {
key: self.sub_stacked_observations[key].reset(obs) for key, obs in observation.items()
} # pytype: disable=bad-return-type

self.stacked_obs[...] = 0
if self.channels_first:
self.stacked_obs[:, -observation.shape[self.stack_dimension] :, ...] = observation
else:
self.stacked_obs[..., -observation.shape[self.stack_dimension] :] = observation
return self.stacked_obs
return self.stacked_obs # pytype: disable=bad-return-type

def update(
self,
Expand Down
7 changes: 5 additions & 2 deletions stable_baselines3/common/vec_env/vec_extract_dict_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,8 @@ def reset(self) -> np.ndarray:
return obs[self.key]

def step_wait(self) -> VecEnvStepReturn:
    """Step the wrapped vectorized env and extract ``self.key`` everywhere.

    The wrapped env returns dict observations; this wrapper narrows both the
    regular observations and any ``terminal_observation`` stored in the
    per-env info dicts down to the configured key, so terminal observations
    match the extracted observation space.

    :return: extracted observations, rewards, dones and infos
    """
    obs, reward, done, infos = self.venv.step_wait()
    for info in infos:
        # When an env terminated, its last (dict) observation was stashed in
        # the info dict; extract the same key so callers see a consistent type.
        if "terminal_observation" in info:
            info["terminal_observation"] = info["terminal_observation"][self.key]
    return obs[self.key], reward, done, infos
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a2
2.0.0a3
4 changes: 4 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,10 @@ def test_evaluate_policy(direct_policy: bool):

def dummy_callback(locals_, _globals):
    """Evaluation callback: count calls and sanity-check exposed locals.

    Checks that ``evaluate_policy`` exposes both the pre-step and the
    freshly stepped observations, and that the two are distinct objects
    holding different values (i.e. the input observations are not shadowed).
    """
    locals_["model"].n_callback_calls += 1
    for key in ("observations", "new_observations"):
        assert key in locals_
    old_obs = locals_["observations"]
    new_obs = locals_["new_observations"]
    assert new_obs is not old_obs
    assert not np.allclose(new_obs, old_obs)

assert model.policy is not None
policy = model.policy if direct_policy else model
Expand Down
26 changes: 24 additions & 2 deletions tests/test_vec_extract_dict_obs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,31 @@ def __init__(self):
self.num_envs = 4
self.action_space = spaces.Discrete(2)
self.observation_space = spaces.Dict({"rgb": spaces.Box(low=0.0, high=255.0, shape=(86, 86), dtype=np.float32)})
self.n_steps = 0
self.max_steps = 5

def step_async(self, actions):
    """Record the requested actions (this dummy env never reads them back)."""
    self.actions = actions

def step_wait(self):
    """Advance the dummy vectorized env by one step.

    Episodes last ``self.max_steps`` steps; on the terminal step every env
    reports done and an info dict carrying a dict ``terminal_observation``
    plus the ``TimeLimit.truncated`` flag, mirroring real VecEnv behaviour.

    :return: (observations dict, rewards, dones, infos)
    """
    self.n_steps += 1
    done = self.n_steps >= self.max_steps
    if done:
        infos = [
            {"terminal_observation": {"rgb": np.zeros((86, 86))}, "TimeLimit.truncated": True}
            for _ in range(self.num_envs)
        ]
    else:
        # One (empty) info dict per env, matching the VecEnv API contract.
        infos = [{} for _ in range(self.num_envs)]
    return (
        {"rgb": np.zeros((self.num_envs, 86, 86))},
        np.zeros((self.num_envs,)),
        np.ones((self.num_envs,), dtype=bool) * done,
        infos,
    )

def reset(self):
    """Restart the episode counter and emit an all-zero batched observation."""
    self.n_steps = 0
    batch = np.zeros((self.num_envs, 86, 86))
    return {"rgb": batch}

def render(self, close=False):
Expand All @@ -40,6 +52,16 @@ def test_extract_dict_obs():
env = VecExtractDictObs(env, "rgb")
assert env.reset().shape == (4, 86, 86)

for _ in range(10):
obs, _, dones, infos = env.step([env.action_space.sample() for _ in range(env.num_envs)])
assert obs.shape == (4, 86, 86)
for idx, info in enumerate(infos):
if "terminal_observation" in info:
assert dones[idx]
assert info["terminal_observation"].shape == (86, 86)
else:
assert not dones[idx]


def test_vec_with_ppo():
"""
Expand Down

0 comments on commit 63307d4

Please sign in to comment.