Refactor Tests + Add Helpers (#508)
* Add helpers

* Refactor some tests

* Continue refactoring

* Fix for codacy

* Fixes for travis

* Clean up imports

* Fix syntax error

* Fix VecEnv constructor

* Fix perf check in tests

* Seed identity env + minor updates

* Allow more diff after training again

* Try to fix travis non-determinism

* Add tests for the new helpers

* Codacy fixes

* Fix callback logic

* Address comments

* Address review comments

* Make codacy happy

* Fix docstring indentation

* Update README example

* Remove use_subprocess and update doc
araffin committed Nov 24, 2019
1 parent 42c2290 commit e315ebe
Showing 31 changed files with 324 additions and 224 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -12,6 +12,7 @@
__pycache__/
_build/
*.npz
*.zip

# Setuptools distribution and build folders.
/dist/
4 changes: 3 additions & 1 deletion README.md
@@ -113,7 +113,9 @@ from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2

env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env]) # The algorithms require a vectorized environment to run
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])

model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)
7 changes: 7 additions & 0 deletions docs/common/evaluation.rst
@@ -0,0 +1,7 @@
.. _eval:

Evaluation Helper
=================

.. automodule:: stable_baselines.common.evaluation
:members:
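
For orientation, here is a minimal sketch of how the helper documented by this new page is meant to be called; it mirrors the example added to docs/guide/examples.rst in this commit, with the environment and timestep budget chosen purely for illustration.

from stable_baselines import DQN
from stable_baselines.common.evaluation import evaluate_policy

# Train a small agent, then measure it over a fixed number of evaluation episodes.
model = DQN('MlpPolicy', 'CartPole-v1', verbose=0)
model.learn(total_timesteps=10000)

# Runs the policy for n_eval_episodes full episodes and aggregates the episode rewards
# (the unpacking below follows the examples.rst snippet in this diff).
mean_reward, n_steps = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
print(mean_reward)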
24 changes: 17 additions & 7 deletions docs/guide/examples.rst
@@ -56,6 +56,8 @@ In the following example, we will train, save and load a DQN model on the Lunar
import gym
from stable_baselines import DQN
from stable_baselines.common.evaluation import evaluate_policy
# Create environment
env = gym.make('LunarLander-v2')
@@ -71,6 +73,9 @@ In the following example, we will train, save and load a DQN model on the Lunar
# Load the trained agent
model = DQN.load("dqn_lunar")
# Evaluate the agent
mean_reward, n_steps = evaluate_policy(model, model.get_env(), n_eval_episodes=10)
# Enjoy trained agent
obs = env.reset()
for i in range(1000):
@@ -98,7 +103,7 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import set_global_seeds
from stable_baselines.common import set_global_seeds, make_vec_env
from stable_baselines import ACKTR
def make_env(env_id, rank, seed=0):
@@ -123,6 +128,10 @@ Multiprocessing: Unleashing the Power of Vectorized Environments
# Create the vectorized environment
env = SubprocVecEnv([make_env(env_id, i) for i in range(num_cpu)])
# Stable Baselines provides you with make_vec_env() helper
# which does exactly the previous steps for you:
# env = make_vec_env(env_id, n_envs=num_cpu, seed=0)
model = ACKTR(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
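
A minimal sketch of the same setup using the helper mentioned in the comment above; the `vec_env_cls` keyword is not shown in this diff and is assumed from the library's `cmd_util`, so treat it as illustrative.

from stable_baselines.common import make_vec_env
from stable_baselines import ACKTR

# Builds and seeds n_envs copies of the environment in one call.
# By default the copies live in a DummyVecEnv; vec_env_cls=SubprocVecEnv is the
# assumed way to get separate processes (check the helper's signature).
env = make_vec_env('CartPole-v1', n_envs=4, seed=0)

model = ACKTR('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=25000)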
@@ -340,8 +349,6 @@ A2C policy gradient updates on the model.
import gym
import numpy as np
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import A2C
def mutate(params):
@@ -365,9 +372,8 @@ A2C policy gradient updates on the model.
# Create env
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
# Create policy with a small network
model = A2C(MlpPolicy, env, ent_coef=0.0, learning_rate=0.1,
model = A2C('MlpPolicy', env, ent_coef=0.0, learning_rate=0.1,
policy_kwargs={'net_arch': [8, ]})
# Use traditional actor-critic policy gradient updates to
@@ -546,6 +552,9 @@ You can also move from learning on one environment to another for `continual lea
obs, rewards, dones, info = env.step(action)
env.render()
# Close the processes
env.close()
# The number of environments must be identical when changing environments
env = make_atari_env('SpaceInvadersNoFrameskip-v4', num_env=8, seed=0)
@@ -558,6 +567,7 @@ You can also move from learning on one environment to another for `continual lea
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()
env.close()
Record a Video
@@ -591,6 +601,7 @@ Record a mp4 video (here using a random agent).
for _ in range(video_length + 1):
action = [env.action_space.sample()]
obs, _, _, _ = env.step(action)
# Save the video
env.close()
@@ -606,10 +617,9 @@
import imageio
import numpy as np
from stable_baselines.common.policies import MlpPolicy
from stable_baselines import A2C
model = A2C(MlpPolicy, "LunarLander-v2").learn(100000)
model = A2C("MlpPolicy", "LunarLander-v2").learn(100000)
images = []
obs = model.env.reset()
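The frame-collection loop of the GIF example is cut off by the diff view above; the lines below are a hedged sketch of how it typically continues, reusing the imageio, np, model, images and obs names defined just before the cut (frame count, stride and filename are illustrative).

img = model.env.render(mode='rgb_array')
for _ in range(350):
    images.append(img)
    action, _ = model.predict(obs)
    obs, _, _, _ = model.env.step(action)
    img = model.env.render(mode='rgb_array')

# Keep every other frame to limit the file size.
imageio.mimsave('lander_a2c.gif',
                [np.array(img) for i, img in enumerate(images) if i % 2 == 0],
                fps=29)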
4 changes: 3 additions & 1 deletion docs/guide/quickstart.rst
@@ -17,7 +17,9 @@ Here is a quick example of how to train and run PPO2 on a cartpole environment:
from stable_baselines import PPO2
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env]) # The algorithms require a vectorized environment to run
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)
1 change: 1 addition & 0 deletions docs/index.rst
@@ -80,6 +80,7 @@ This toolset is a fork of OpenAI Baselines, with a major structural refactoring,
common/tf_utils
common/cmd_utils
common/schedules
common/evaluation

.. toctree::
:maxdepth: 1
12 changes: 11 additions & 1 deletion docs/misc/changelog.rst
@@ -13,10 +13,16 @@ Breaking Changes:
^^^^^^^^^^^^^^^^^
- The `seed` argument has been moved from `learn()` method to model constructor
in order to have reproducible results
- `allow_early_resets` of the `Monitor` wrapper now defaults to `True`
- `make_atari_env` now returns a `DummyVecEnv` by default (instead of a `SubprocVecEnv`);
  this usually improves performance.

New Features:
^^^^^^^^^^^^^
- Add `n_cpu_tf_sess` to model constructor to choose the number of threads used by Tensorflow
- Environments are automatically wrapped in a `DummyVecEnv` if needed when passing them to the model constructor
- Added `stable_baselines.common.make_vec_env` helper to simplify VecEnv creation
- Added `stable_baselines.common.evaluation.evaluate_policy` helper to simplify model evaluation
- `VecNormalize` now supports being pickled and unpickled.
- Add parameter `exploration_initial_eps` to DQN. (@jdossgollin)
- Add type checking and PEP 561 compliance.
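
As a quick illustration of the `n_cpu_tf_sess` option from the list above (a sketch; a single thread plus a fixed seed is just the usual recipe for reproducible runs, not part of the diff):

from stable_baselines import PPO2

# seed now lives in the constructor, and n_cpu_tf_sess bounds the number of
# threads Tensorflow may use; together they make runs easier to reproduce.
model = PPO2('MlpPolicy', 'CartPole-v1', seed=0, n_cpu_tf_sess=1, verbose=0)
model.learn(total_timesteps=10000)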
@@ -38,6 +44,7 @@ Deprecations:
Others:
^^^^^^^
- Add upper bound for Tensorflow version (<2.0.0).
- Refactored tests to remove duplicated code
- Add pull request template

Documentation:
@@ -46,8 +53,11 @@ Documentation:
- Add Snake Game AI project (@pedrohbtp)
- Add note on the supported Tensorflow versions.
- Remove unnecessary steps required for Windows installation.
- Remove `DummyVecEnv` creation when not needed
- Added `make_vec_env` to the examples to simplify VecEnv creation
- Add QuaRL project (@srivatsankrishnan)
- Add Pwnagotchi project (@evilsocket)
- Fix multiprocessing example (@rusu24edward)
- Fix `result_plotter` example
- Fix typo in algos.rst, "containes" to "contains" (@SyllogismRXS)

@@ -530,4 +540,4 @@ Thanks to @bjmuld @iambenzo @iandanforth @r7vme @brendenpetersen @huvar @abhiskk
@EliasHasle @mrakgr @Bleyddyn @antoine-galataud @junhyeokahn @AdamGleave @keshaviyengar @tperol
@XMaster96 @kantneel @Pastafarianist @GerardMaggiolino @PatrickWalter214 @yutingsz @sc420 @Aaahh @billtubbs
@Miffyli @dwiel @miguelrass @qxcv @jaberkow @eavelardev @ruifeng96150 @pedrohbtp @srivatsankrishnan @evilsocket
@MarvineGothic @jdossgollin @SyllogismRXS
@MarvineGothic @jdossgollin @SyllogismRXS @rusu24edward
7 changes: 3 additions & 4 deletions docs/modules/a2c.rst
@@ -49,12 +49,11 @@ Train a A2C agent on `CartPole-v1` using 4 processes.
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import make_vec_env
from stable_baselines import A2C
# multiprocess environment
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
# Parallel environments
env = make_vec_env('CartPole-v1', n_envs=4)
model = A2C(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
5 changes: 2 additions & 3 deletions docs/modules/acer.rst
@@ -43,12 +43,11 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import make_vec_env
from stable_baselines import ACER
# multiprocess environment
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
env = make_vec_env('CartPole-v1', n_envs=4)
model = ACER(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
5 changes: 2 additions & 3 deletions docs/modules/acktr.rst
@@ -44,12 +44,11 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy, MlpLstmPolicy, MlpLnLstmPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import make_vec_env
from stable_baselines import ACKTR
# multiprocess environment
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
env = make_vec_env('CartPole-v1', n_envs=4)
model = ACKTR(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
10 changes: 2 additions & 8 deletions docs/modules/ddpg.rst
@@ -63,12 +63,10 @@ Example
import numpy as np
from stable_baselines.ddpg.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
from stable_baselines.common.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise, AdaptiveParamNoiseSpec
from stable_baselines import DDPG
env = gym.make('MountainCarContinuous-v0')
env = DummyVecEnv([lambda: env])
# the noise objects for DDPG
n_actions = env.action_space.shape[-1]
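
The remainder of this example is elided by the diff view; a hedged sketch of how the relocated noise class is usually wired in, continuing from the env, n_actions and imports shown above (the sigma value is illustrative):

# Gaussian exploration noise applied independently to each action dimension
# (NormalActionNoise comes from the common.noise import shown above).
action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions))

model = DDPG(MlpPolicy, env, action_noise=action_noise, verbose=1)
model.learn(total_timesteps=50000)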
@@ -148,7 +146,6 @@ You can easily define a custom architecture for the policy network:
import gym
from stable_baselines.ddpg.policies import FeedForwardPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import DDPG
# Custom MLP policy of two layers of size 16 each
@@ -159,10 +156,7 @@ You can easily define a custom architecture for the policy network:
layer_norm=False,
feature_extraction="mlp")
# Create and wrap the environment
env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])
model = DDPG(CustomDDPGPolicy, env, verbose=1)
model = DDPG(CustomDDPGPolicy, 'Pendulum-v0', verbose=1)
# Train the agent
model.learn(total_timesteps=100000)
2 changes: 1 addition & 1 deletion docs/modules/gail.rst
@@ -111,7 +111,7 @@ Example
# Load the expert dataset
dataset = ExpertDataset(expert_path='expert_pendulum.npz', traj_limitation=10, verbose=1)
model = GAIL("MlpPolicy", 'Pendulum-v0', dataset, verbose=1)
model = GAIL('MlpPolicy', 'Pendulum-v0', dataset, verbose=1)
# Note: in practice, you need to train for 1M steps to have a working policy
model.learn(total_timesteps=1000)
model.save("gail_pendulum")
2 changes: 0 additions & 2 deletions docs/modules/ppo1.rst
@@ -59,11 +59,9 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO1
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
model = PPO1(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
5 changes: 2 additions & 3 deletions docs/modules/ppo2.rst
@@ -61,12 +61,11 @@ Train a PPO agent on `CartPole-v1` using 4 processes.
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv
from stable_baselines.common import make_vec_env
from stable_baselines import PPO2
# multiprocess environment
n_cpu = 4
env = SubprocVecEnv([lambda: gym.make('CartPole-v1') for i in range(n_cpu)])
env = make_vec_env('CartPole-v1', n_envs=4)
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
2 changes: 0 additions & 2 deletions docs/modules/sac.rst
@@ -75,11 +75,9 @@ Example
import numpy as np
from stable_baselines.sac.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import SAC
env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])
model = SAC(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=50000, log_interval=10)
1 change: 0 additions & 1 deletion docs/modules/td3.rst
@@ -73,7 +73,6 @@ Example
from stable_baselines.ddpg.noise import NormalActionNoise, OrnsteinUhlenbeckActionNoise
env = gym.make('Pendulum-v0')
env = DummyVecEnv([lambda: env])
# The noise objects for TD3
n_actions = env.action_space.shape[-1]
2 changes: 0 additions & 2 deletions docs/modules/trpo.rst
@@ -49,11 +49,9 @@ Example
import gym
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import TRPO
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env])
model = TRPO(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=25000)
4 changes: 3 additions & 1 deletion setup.py
@@ -80,7 +80,9 @@
from stable_baselines import PPO2
env = gym.make('CartPole-v1')
env = DummyVecEnv([lambda: env]) # The algorithms require a vectorized environment to run
# Optional: PPO2 requires a vectorized environment to run
# the env is now wrapped automatically when passing it to the constructor
# env = DummyVecEnv([lambda: env])
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=10000)
2 changes: 1 addition & 1 deletion stable_baselines/bench/monitor.py
@@ -14,7 +14,7 @@ class Monitor(Wrapper):
EXT = "monitor.csv"
file_handler = None

def __init__(self, env, filename, allow_early_resets=False, reset_keywords=(), info_keywords=()):
def __init__(self, env, filename, allow_early_resets=True, reset_keywords=(), info_keywords=()):
"""
A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
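In practice the new default means a monitored env may be reset mid-episode without an explicit opt-in; a small sketch (the env id is illustrative, filename=None skips writing the csv):

import gym

from stable_baselines.bench import Monitor

env = Monitor(gym.make('CartPole-v1'), filename=None)
obs = env.reset()
# With allow_early_resets=True by default, resetting before the episode is done
# no longer raises a RuntimeError (handy for evaluation and test helpers).
obs = env.reset()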
1 change: 1 addition & 0 deletions stable_baselines/common/__init__.py
@@ -6,3 +6,4 @@
from stable_baselines.common.misc_util import zipsame, set_global_seeds, boolean_flag
from stable_baselines.common.base_class import BaseRLModel, ActorCriticRLModel, OffPolicyRLModel, SetVerbosity, \
TensorboardWriter
from stable_baselines.common.cmd_util import make_vec_env
7 changes: 6 additions & 1 deletion stable_baselines/common/base_class.py
@@ -70,7 +70,12 @@ def __init__(self, policy, env, verbose=0, *, requires_vec_env, policy_base,
if isinstance(env, VecEnv):
self.n_envs = env.num_envs
else:
raise ValueError("Error: the model requires a vectorized environment, please use a VecEnv wrapper.")
# The model requires a VecEnv
# wrap it in a DummyVecEnv to avoid error
self.env = DummyVecEnv([lambda: env])
if self.verbose >= 1:
print("Wrapping the env in a DummyVecEnv.")
self.n_envs = 1
else:
if isinstance(env, VecEnv):
if env.num_envs == 1:
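From the user's side, this branch means a plain gym.Env can be handed straight to a model constructor; a minimal sketch (algorithm and env chosen for illustration):

import gym

from stable_baselines import PPO2

env = gym.make('CartPole-v1')

# No manual DummyVecEnv([lambda: env]) needed any more: the constructor detects
# the non-vectorized env, wraps it itself and prints
# "Wrapping the env in a DummyVecEnv." when verbose >= 1.
model = PPO2('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=10000)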
