Skip to content

Commit

Permalink
Fix double reset and improve typing coverage (#136)
Browse files Browse the repository at this point in the history
* Fix double reset and improve typing coverage

* Revert minor edit

* Add doc about types
  • Loading branch information
araffin committed Aug 5, 2020
1 parent cceffd5 commit 21e9994
Show file tree
Hide file tree
Showing 9 changed files with 143 additions and 109 deletions.
25 changes: 25 additions & 0 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,31 @@
Changelog
==========

Pre-Release 0.9.0a0 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added ``unwrap_vec_wrapper()`` to ``common.vec_env`` to extract ``VecEnvWrapper`` if needed

Bug Fixes:
^^^^^^^^^^
- Fixed a bug where the environment was reset twice when using ``evaluate_policy``

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Improve typing coverage of the ``VecEnv``
- Removed ``AlreadySteppingError`` and ``NotSteppingError`` that were not used

Documentation:
^^^^^^^^^^^^^^

Pre-Release 0.8.0 (2020-08-03)
------------------------------

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
for _ in range(10):
action = [env.action_space.sample()]
action = np.array([env.action_space.sample()])
_, _, _, _ = vec_env.step(action)


Expand Down
34 changes: 21 additions & 13 deletions stable_baselines3/common/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,25 @@
# Copied from stable_baselines
import typing
from typing import Callable, List, Optional, Tuple, Union

import gym
import numpy as np

from stable_baselines3.common.vec_env import VecEnv

if typing.TYPE_CHECKING:
from stable_baselines3.common.base_class import BaseAlgorithm


def evaluate_policy(
model,
env,
n_eval_episodes=10,
deterministic=True,
render=False,
callback=None,
reward_threshold=None,
return_episode_rewards=False,
):
model: "BaseAlgorithm",
env: Union[gym.Env, VecEnv],
n_eval_episodes: int = 10,
deterministic: bool = True,
render: bool = False,
callback: Optional[Callable] = None,
reward_threshold: Optional[float] = None,
return_episode_rewards: bool = False,
) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
"""
Runs policy for ``n_eval_episodes`` episodes and returns average reward.
This is made to work only with one env.
Expand All @@ -28,7 +34,7 @@ def evaluate_policy(
called after each step.
:param reward_threshold: (float) Minimum expected reward per episode,
this will raise an error if the performance is not met
:param return_episode_rewards: (bool) If True, a list of reward per episode
:param return_episode_rewards: (Optional[float]) If True, a list of reward per episode
will be returned instead of the mean.
:return: (float, float) Mean reward per episode, std of reward per episode
returns ([float], [int]) when ``return_episode_rewards`` is True
Expand All @@ -37,8 +43,10 @@ def evaluate_policy(
assert env.num_envs == 1, "You must pass only one environment when using this function"

episode_rewards, episode_lengths = [], []
for _ in range(n_eval_episodes):
obs = env.reset()
for i in range(n_eval_episodes):
# Avoid double reset, as VecEnv are reset automatically
if not isinstance(env, VecEnv) or i == 0:
obs = env.reset()
done, state = False, None
episode_reward = 0.0
episode_length = 0
Expand Down
29 changes: 17 additions & 12 deletions stable_baselines3/common/vec_env/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
# flake8: noqa F401
import typing
from copy import deepcopy
from typing import Optional, Union

from stable_baselines3.common.vec_env.base_vec_env import (
AlreadySteppingError,
CloudpickleWrapper,
NotSteppingError,
VecEnv,
VecEnvWrapper,
)
from typing import Optional, Type, Union

from stable_baselines3.common.vec_env.base_vec_env import CloudpickleWrapper, VecEnv, VecEnvWrapper
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_check_nan import VecCheckNan
Expand All @@ -23,19 +17,30 @@
from stable_baselines3.common.type_aliases import GymEnv


def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
def unwrap_vec_wrapper(env: Union["GymEnv", VecEnv], vec_wrapper_class: Type[VecEnvWrapper]) -> Optional[VecEnvWrapper]:
"""
Retrieve a ``VecEnvWrapper`` object by recursively searching.
:param env: (gym.Env)
:return: (VecNormalize)
:param vec_wrapper_class: (VecEnvWrapper)
:return: (VecEnvWrapper)
"""
env_tmp = env
while isinstance(env_tmp, VecEnvWrapper):
if isinstance(env_tmp, VecNormalize):
if isinstance(env_tmp, vec_wrapper_class):
return env_tmp
env_tmp = env_tmp.venv
return None


def unwrap_vec_normalize(env: Union["GymEnv", VecEnv]) -> Optional[VecNormalize]:
"""
:param env: (gym.Env)
:return: (VecNormalize)
"""
return unwrap_vec_wrapper(env, VecNormalize) # pytype:disable=bad-return-type


# Define here to avoid circular import
def sync_envs_normalization(env: "GymEnv", eval_env: "GymEnv") -> None:
"""
Expand Down

0 comments on commit 21e9994

Please sign in to comment.