Skip to content

Commit

Permalink
Fix check_env, Monitor.close and add Makefile (#673)
Browse files Browse the repository at this point in the history
* Fix `check_env` and add Makefile

* Fixed doc build

* Fixed and typed Monitor
  • Loading branch information
araffin committed Feb 3, 2020
1 parent 34d2bee commit c6acd1e
Show file tree
Hide file tree
Showing 8 changed files with 99 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,6 @@
- [ ] My change requires a change to the documentation.
- [ ] I have updated the tests accordingly (*required for a bug fix or a new feature*).
- [ ] I have updated the documentation accordingly.
- [ ] I have ensured `pytest` and `pytype` both pass.
- [ ] I have ensured `pytest` and `pytype` both pass (by running `make pytest` and `make type`).

<!--- This Template is an edited version of the one from https://github.com/evilsocket/pwnagotchi/ -->
37 changes: 28 additions & 9 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,17 @@ from stable_baselines import PPO2

In general, we recommend using pycharm to format everything in an efficient way.

Please documentation each function/method using the following template:
Please document each function/method and [type](https://google.github.io/pytype/user_guide.html) them using the following template:

```python

def my_function(arg1, arg2):
def my_function(arg1: type1, arg2: type2) -> returntype:
"""
Short description of the function.
:param arg1: (arg1 type) describe what is arg1
:param arg2: (arg2 type) describe what is arg2
:return: (return type) describe what is returned
:param arg1: (type1) describe what is arg1
:param arg2: (type2) describe what is arg2
:return: (returntype) describe what is returned
"""
...
return my_variable
Expand All @@ -77,7 +77,7 @@ def my_function(arg1, arg2):

Before proposing a PR, please open an issue, where the feature will be discussed. This prevent from duplicated PR to be proposed and also ease the code review process.

Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a , @araffin or @erniejunior ).
Each PR need to be reviewed and accepted by at least one of the maintainers (@hill-a, @araffin, @erniejunior, @AdamGleave or @Miffyli).
A PR must pass the Continuous Integration tests (travis + codacy) to be merged with the master branch.

Note: in rare cases, we can create exception for codacy failure.
Expand All @@ -88,15 +88,34 @@ All new features must add tests in the `tests/` folder ensuring that everything
We use [pytest](https://pytest.org/).
Also, when a bug fix is proposed, tests should be added to avoid regression.

To run tests with `pytest` and type checking with `pytype`:
To run tests with `pytest`:

```
./scripts/run_tests.sh
make pytest
```

Type checking with `pytype`:

```
make type
```

Build the documentation:

```
make doc
```

Check documentation spelling (you need to install `sphinxcontrib.spelling` package for that):

```
make spelling
```


## Changelog and Documentation

Please do not forget to update the changelog and add documentation if needed.
Please do not forget to update the changelog (`docs/misc/changelog.rst`) and add documentation if needed.
A README is present in the `docs/` folder for instructions on how to build the documentation.


Expand Down
19 changes: 19 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Run pytest and coverage report
pytest:
./scripts/run_tests.sh

# Type check
type:
pytype

# Build the doc
doc:
cd docs && make html

# Check the spelling in the doc
spelling:
cd docs && make spelling

# Clean the doc build folder
clean:
cd docs && make clean
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ Some of the baselines examples use [MuJoCo](http://www.mujoco.org) (multi-joint
All unit tests in baselines can be run using pytest runner:
```
pip install pytest pytest-cov
pytest --cov-config .coveragerc --cov-report html --cov-report term --cov=.
make pytest
```

## Projects Using Stable-Baselines
Expand Down
7 changes: 4 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ Breaking Changes:
New Features:
^^^^^^^^^^^^^
- Parallelized updating and sampling from the replay buffer in DQN. (@flodorner)

- Docker build script, `scripts/build_docker.sh`, can push images automatically.

Bug Fixes:
Expand All @@ -30,9 +29,10 @@ Bug Fixes:
- Fixed a bug in PPO2, ACER, A2C, and ACKTR where repeated calls to `learn(total_timesteps)` reset
the environment on every call, potentially biasing samples toward early episode timesteps.
(@shwang)

- Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated
- Fixed by adding lazy property `ActorCriticRLModel.runner`. Subclasses now use lazily-generated
`self.runner` instead of reinitializing a new Runner every time `learn()` is called.
- Fixed a bug in `check_env` where it would fail on high dimensional action spaces
- Fixed `Monitor.close()` that was not calling the parent method

Deprecations:
^^^^^^^^^^^^^
Expand All @@ -41,6 +41,7 @@ Others:
^^^^^^^
- Removed redundant return value from `a2c.utils::total_episode_reward_logger`. (@shwang)
- Cleanup and refactoring in `common/identity_env.py` (@shwang)
- Added a Makefile to simplify common development tasks (build the doc, type check, run the tests)

Documentation:
^^^^^^^^^^^^^^
Expand Down
46 changes: 27 additions & 19 deletions stable_baselines/bench/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,33 @@
import os
import time
from glob import glob
from typing import Tuple, Dict, Any, List, Optional

import gym
import pandas
from gym.core import Wrapper
import numpy as np


class Monitor(Wrapper):
class Monitor(gym.Wrapper):
EXT = "monitor.csv"
file_handler = None

def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), info_keywords=()):
def __init__(self,
env: gym.Env,
filename: Optional[str],
allow_early_resets: bool = True,
reset_keywords=(),
info_keywords=()):
"""
A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
:param env: (Gym environment) The environment
:param filename: (str) the location to save a log file, can be None for no log
:param env: (gym.Env) The environment
:param filename: (Optional[str]) the location to save a log file, can be None for no log
:param allow_early_resets: (bool) allows the reset of the environment before it is done
:param reset_keywords: (tuple) extra keywords for the reset call, if extra parameters are needed at reset
:param info_keywords: (tuple) extra information to log, from the information return of environment.step
"""
Wrapper.__init__(self, env=env)
super(Monitor, self).__init__(env=env)
self.t_start = time.time()
if filename is None:
self.file_handler = None
Expand Down Expand Up @@ -53,12 +60,12 @@ def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), in
self.total_steps = 0
self.current_reset_info = {} # extra info about the current episode, that was passed in during reset()

def reset(self, **kwargs):
def reset(self, **kwargs) -> np.ndarray:
"""
Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
:param kwargs: Extra keywords saved for the next episode. only if defined by reset_keywords
:return: ([int] or [float]) the first observation of the environment
:return: (np.ndarray) the first observation of the environment
"""
if not self.allow_early_resets and not self.needs_reset:
raise RuntimeError("Tried to reset an environment before done. If you want to allow early resets, "
Expand All @@ -68,16 +75,16 @@ def reset(self, **kwargs):
for key in self.reset_keywords:
value = kwargs.get(key)
if value is None:
raise ValueError('Expected you to pass kwarg %s into reset' % key)
raise ValueError('Expected you to pass kwarg {} into reset'.format(key))
self.current_reset_info[key] = value
return self.env.reset(**kwargs)

def step(self, action):
def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[Any, Any]]:
"""
Step the environment with the given action
:param action: ([int] or [float]) the action
:return: ([int] or [float], [float], [bool], dict) observation, reward, done, information
:param action: (np.ndarray) the action
:return: (Tuple[np.ndarray, float, bool, Dict[Any, Any]]) observation, reward, done, information
"""
if self.needs_reset:
raise RuntimeError("Tried to step environment that needs reset")
Expand Down Expand Up @@ -105,34 +112,35 @@ def close(self):
"""
Closes the environment
"""
super(Monitor, self).close()
if self.file_handler is not None:
self.file_handler.close()

def get_total_steps(self):
def get_total_steps(self) -> int:
"""
Returns the total number of timesteps
:return: (int)
"""
return self.total_steps

def get_episode_rewards(self):
def get_episode_rewards(self) -> List[float]:
"""
Returns the rewards of all the episodes
:return: ([float])
"""
return self.episode_rewards

def get_episode_lengths(self):
def get_episode_lengths(self) -> List[int]:
"""
Returns the number of timesteps of all the episodes
:return: ([int])
"""
return self.episode_lengths

def get_episode_times(self):
def get_episode_times(self) -> List[float]:
"""
Returns the runtime in seconds of all the episodes
Expand All @@ -148,7 +156,7 @@ class LoadMonitorResultsError(Exception):
pass


def get_monitor_files(path):
def get_monitor_files(path: str) -> List[str]:
"""
get all the monitor files in the given path
Expand All @@ -158,12 +166,12 @@ def get_monitor_files(path):
return glob(os.path.join(path, "*" + Monitor.EXT))


def load_results(path):
def load_results(path: str) -> pandas.DataFrame:
"""
Load all Monitor logs from a given directory path matching ``*monitor.csv`` and ``*monitor.json``
:param path: (str) the directory path containing the log file(s)
:return: (Pandas DataFrame) the logged data
:return: (pandas.DataFrame) the logged data
"""
# get both csv and (old) json files
monitor_files = (glob(os.path.join(path, "*monitor.json")) + get_monitor_files(path))
Expand Down
8 changes: 4 additions & 4 deletions stable_baselines/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _check_spaces(env: gym.Env) -> None:
assert isinstance(env.action_space, spaces.Space), "The action space must inherit from gym.spaces" + gym_spaces


def _check_render(env: gym.Env, warn=True, headless=False) -> None:
def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None:
"""
Check the declared render modes and the `render()`/`close()`
method of the environment.
Expand Down Expand Up @@ -163,7 +163,7 @@ def _check_render(env: gym.Env, warn=True, headless=False) -> None:
env.close()


def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None:
def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None:
"""
Check that an environment follows Gym API.
This is particularly useful when using a custom environment.
Expand Down Expand Up @@ -205,8 +205,8 @@ def check_env(env: gym.Env, warn=True, skip_render_check=True) -> None:

# Check for the action space, it may lead to hard-to-debug issues
if (isinstance(action_space, spaces.Box) and
(np.abs(action_space.low) != np.abs(action_space.high)
or np.abs(action_space.low) > 1 or np.abs(action_space.high) > 1)):
(np.any(np.abs(action_space.low) != np.abs(action_space.high))
or np.any(np.abs(action_space.low) > 1) or np.any(np.abs(action_space.high) > 1))):
warnings.warn("We recommend you to use a symmetric and normalized Box action space (range=[-1, 1]) "
"cf https://stable-baselines.readthedocs.io/en/master/guide/rl_tips.html")

Expand Down
15 changes: 15 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,21 @@ def test_custom_envs(env_class):
check_env(env)


def test_high_dimension_action_space():
"""
Test for continuous action space
with more than one action.
"""
env = gym.make('Pendulum-v0')
# Patch the action space
env.action_space = spaces.Box(low=-1, high=1, shape=(20,), dtype=np.float32)
# Patch to avoid error
def patched_step(_action):
return env.observation_space.sample(), 0.0, False, {}
env.step = patched_step
check_env(env)


@pytest.mark.parametrize("new_obs_space", [
# Small image
spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8),
Expand Down

0 comments on commit c6acd1e

Please sign in to comment.