Review of code (A2C, PPO and refactoring) (#35)
* Split torch module code into torch_layers file

* Updated reference to CNN

* Change 'CxWxH' to 'CxHxW', as per common notation

* Fix missing import in policies.py

* Move PPOPolicy to OnlineActorCriticPolicy

* Create OnPolicyRLModel from PPO, and make A2C and PPO inherit

* Update A2C optimizer comment

* Clean weight init scales for clarity

* Fix A2C log_interval default parameter

* Rename 'progress' to 'progress_remaining'

* Rename 'Models' to 'Algorithms'

* Rename 'OnlineActorCriticPolicy' to 'ActorCriticPolicy'

* Move static functions out from BaseAlgorithm

* Move on/off_policy base algorithms to their own files

* Add policies.py files for A2C/PPO

* Fix docs

* Fix pytype

* Update documentation on OnPolicyAlgorithm

* Add proper docstring for on_policy rollout gathering

* Add a bit of clarification on the MlpPolicy/CnnPolicy naming

* Move static function is_vectorized_policies to utils.py

* Checking docstrings, pep8 fixes

* Update changelog

* Clean changelog

* Remove policy warnings for sac/td3

* Add monitor_wrapper for OnPolicyAlgorithm. Clean tb logging variables. Add parameter keywords to OffPolicyAlgorithm super init

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
Miffyli and araffin committed Jun 9, 2020
1 parent 11d33eb commit 44f8218
Showing 26 changed files with 1,511 additions and 1,279 deletions.
2 changes: 1 addition & 1 deletion docs/guide/callbacks.rst
@@ -34,7 +34,7 @@ This will give you access to events (``_on_training_start``, ``_on_step``) and u
# Those variables will be accessible in the callback
# (they are defined in the base class)
# The RL model
# self.model = None # type: BaseRLModel
# self.model = None # type: BaseAlgorithm
# An alias for self.model.get_env(), the environment used for training
# self.training_env = None # type: Union[gym.Env, VecEnv, None]
# Number of times the callback was called
16 changes: 11 additions & 5 deletions docs/guide/developer.rst
@@ -19,7 +19,9 @@ The library is not meant to be modular, although inheritance is used to reduce c
Algorithms Structure
====================


Each algorithm (on-policy and off-policy ones) follows a common structure.
The policy contains the code for acting in the environment, and the algorithm updates this policy.
There is one folder per algorithm, and in that folder there is the algorithm and the policy definition (``policies.py``).

Each algorithm has two main methods:
@@ -34,13 +36,14 @@ Where to start?

The first thing you need to read and understand are the base classes in the ``common/`` folder:

- ``BaseRLModel`` in ``base_class.py`` which defines how an RL class should look like.
- ``BaseAlgorithm`` in ``base_class.py`` which defines how an RL class should look like.
It contains also all the "glue code" for saving/loading and the common operations (wrapping environments)

- ``BasePolicy`` in ``policies.py`` which defines how a policy class should look like.
It contains also all the magic for the ``.predict()`` method, to handle as many cases as possible
It contains also all the magic for the ``.predict()`` method, to handle as many spaces/cases as possible

- ``OffPolicyRLModel`` in ``base_class.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms
- ``OffPolicyAlgorithm`` in ``off_policy_algorithm.py`` that contains the implementation of ``collect_rollouts()`` for the off-policy algorithms,
and similarly ``OnPolicyAlgorithm`` in ``on_policy_algorithm.py``.
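
For orientation, a sketch of the corresponding import paths after this refactor (module locations as listed above, not part of this diff):

    from stable_baselines3.common.base_class import BaseAlgorithm
    from stable_baselines3.common.policies import BasePolicy
    from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
    from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm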


All the environments handled internally are assumed to be ``VecEnv`` (``gym.Env`` are automatically wrapped).
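
A minimal sketch of that wrapping (not part of this diff; ``CartPole-v1`` is just an assumed Gym id):

    import gym
    from stable_baselines3.common.vec_env import DummyVecEnv

    # Algorithms do this automatically when handed a plain gym.Env:
    # wrap it into a single-environment VecEnv
    env = DummyVecEnv([lambda: gym.make("CartPole-v1")])
    obs = env.reset()  # observations are batched, shape (1, obs_dim)
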
@@ -50,7 +53,7 @@ Pre-Processing
==============

To handle different observation spaces, some pre-processing needs to be done (e.g. one-hot encoding for discrete observation).
Most of the code for pre-processing is in ``common/preprocessing.py``.
Most of the code for pre-processing is in ``common/preprocessing.py`` and ``common/policies.py``.

For images, we make use of an additional wrapper ``VecTransposeImage`` because PyTorch uses the "channel-first" convention.
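
As a rough illustration of both points (plain PyTorch/NumPy, not the exact SB3 helpers):

    import numpy as np
    import torch as th
    import torch.nn.functional as F

    # One-hot encoding of a Discrete(4) observation with value 2
    discrete_obs = th.tensor([2])
    one_hot = F.one_hot(discrete_obs, num_classes=4).float()  # [[0., 0., 1., 0.]]

    # Images usually arrive as HxWxC; PyTorch convolutions expect CxHxW,
    # which is the transpose VecTransposeImage performs
    image_hwc = np.zeros((84, 84, 3), dtype=np.uint8)
    image_chw = np.transpose(image_hwc, (2, 0, 1))
    assert image_chw.shape == (3, 84, 84)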

@@ -61,9 +64,12 @@ Policy Structure
When we refer to "policy" in Stable-Baselines3, this is usually an abuse of language compared to RL terminology.
In SB3, "policy" refers to the class that handles all the networks useful for training,
so not only the network used to predict actions (the "learned controller").

For instance, the ``TD3`` policy contains the actor, the critic and the target networks.

To avoid the hassle of importing a specific policy class for a specific algorithm (e.g. both A2C and PPO use ``ActorCriticPolicy``),
SB3 uses names like "MlpPolicy" and "CnnPolicy" to refer to policies using small feed-forward networks or convolutional networks,
respectively. Importing ``[algorithm]/policies.py`` registers an appropriate policy for that algorithm under those names.
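
A minimal sketch of what this buys in practice (not part of this diff; ``CartPole-v1`` is an assumed Gym id):

    from stable_baselines3 import A2C, PPO

    # Both algorithms resolve the same string name to the shared ActorCriticPolicy,
    # registered by their respective policies.py on import
    a2c_model = A2C("MlpPolicy", "CartPole-v1")
    ppo_model = PPO("MlpPolicy", "CartPole-v1")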

Probability distributions
=========================

33 changes: 29 additions & 4 deletions docs/misc/changelog.rst
@@ -10,27 +10,51 @@ Pre-Release 0.7.0a1 (WIP)
Breaking Changes:
^^^^^^^^^^^^^^^^^
- ``render()`` method of ``VecEnvs`` now only accept one argument: ``mode``

- Created new file common/torch_layers.py, similar to SB refactoring

- Contains all PyTorch network layer definitions and feature extractors: ``MlpExtractor``, ``create_mlp``, ``NatureCNN``

- Renamed ``BaseRLModel`` to ``BaseAlgorithm`` (along with offpolicy and onpolicy variants)
- Moved on-policy and off-policy base algorithms to ``common/on_policy_algorithm.py`` and ``common/off_policy_algorithm.py``, respectively.
- Moved ``PPOPolicy`` to ``ActorCriticPolicy`` in common/policies.py
- Moved ``PPO`` (algorithm class) into ``OnPolicyAlgorithm`` (``common/on_policy_algorithm.py``), to be shared with A2C
- Moved the following functions out of ``BaseAlgorithm``:

- ``_load_from_file`` to ``load_from_zip_file`` (save_util.py)
- ``_save_to_file_zip`` to ``save_to_zip_file`` (save_util.py)
- ``safe_mean`` to ``safe_mean`` (utils.py)
- ``check_env`` to ``check_for_correct_spaces`` (utils.py; renamed to avoid confusion with the environment checker tools)

- Moved static function ``_is_vectorized_observation`` from common/policies.py to common/utils.py under name ``is_vectorized_observation``.
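
As a sketch (not a changelog entry), the new homes of the moved helpers, assuming the files above live under ``stable_baselines3/common/``:

    from stable_baselines3.common.save_util import load_from_zip_file, save_to_zip_file
    from stable_baselines3.common.utils import (check_for_correct_spaces,
                                                is_vectorized_observation, safe_mean)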


New Features:
^^^^^^^^^^^^^

Bug Fixes:
^^^^^^^^^^
- Fixed ``render()`` method for ``VecEnvs``
- Fixed ``seed()``` method for ``SubprocVecEnv``
- Fixed ``seed()`` method for ``SubprocVecEnv``
- Fixed loading on GPU for testing when using gSDE and ``deterministic=False``
- Fixed ``register_policy`` to allow re-registering same policy for same sub-class (i.e. assign same value to same key).

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^
- Re-enable unsafe ``fork`` start method in the tests (was causing a deadlock with tensorflow)
- Added a test for seeding ``SubprocVecEnv``` and rendering
- Added a test for seeding ``SubprocVecEnv`` and rendering
- Fixed reference in NatureCNN (pointed to older version with different network architecture)
- Fixed comments saying "CxWxH" instead of "CxHxW" (same style as in torch docs / commonly used)
- Added further comments on registering/getting policies ("MlpPolicy", "CnnPolicy").
- Renamed ``progress`` (value going from 1 at the start of training to 0 at the end) to ``progress_remaining``.
- Added ``policies.py`` files for A2C/PPO, which define MlpPolicy/CnnPolicy (renamed ActorCriticPolicies).
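
As an aside (a sketch, not a changelog entry), ``progress_remaining`` is the value passed to callable hyper-parameters, so a linear learning-rate schedule under the new name could look like this (environment id and rate are arbitrary):

    from stable_baselines3 import PPO

    def linear_schedule(progress_remaining: float) -> float:
        # progress_remaining goes from 1 at the start of training to 0 at the end
        return 3e-4 * progress_remaining

    model = PPO("MlpPolicy", "CartPole-v1", learning_rate=linear_schedule)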

Documentation:
^^^^^^^^^^^^^^

- Added a paragraph on "MlpPolicy"/"CnnPolicy" and policy naming scheme under "Developer Guide"
- Fixed second-level listing in changelog

Pre-Release 0.6.0 (2020-06-01)
------------------------------
@@ -41,6 +65,7 @@ Breaking Changes:
^^^^^^^^^^^^^^^^^
- Remove State-Dependent Exploration (SDE) support for ``TD3``
- Methods were renamed in the logger:

- ``logkv`` -> ``record``, ``writekvs`` -> ``write``, ``writeseq`` -> ``write_sequence``,
- ``logkvs`` -> ``record_dict``, ``dumpkvs`` -> ``dump``,
- ``getkvs`` -> ``get_log_dict``, ``logkv_mean`` -> ``record_mean``,
21 changes: 18 additions & 3 deletions docs/modules/base.rst
@@ -8,14 +8,29 @@ Base RL Class

Common interface for all the RL algorithms

.. autoclass:: BaseRLModel
.. autoclass:: BaseAlgorithm
:members:


.. automodule:: stable_baselines3.common.off_policy_algorithm


Base Off-Policy Class
^^^^^^^^^^^^^^^^^^^^^

The base RL model for Off-Policy algorithm (ex: SAC/TD3)
The base RL algorithm for Off-Policy algorithm (ex: SAC/TD3)

.. autoclass:: OffPolicyAlgorithm
:members:


.. automodule:: stable_baselines3.common.on_policy_algorithm


Base On-Policy Class
^^^^^^^^^^^^^^^^^^^^^

The base RL algorithm for On-Policy algorithm (ex: A2C/PPO)

.. autoclass:: OffPolicyRLModel
.. autoclass:: OnPolicyAlgorithm
:members:
6 changes: 0 additions & 6 deletions docs/modules/sac.rst
@@ -12,12 +12,6 @@ SAC is the successor of `Soft Q-Learning SQL <https://arxiv.org/abs/1702.08165>`
A key feature of SAC, and a major difference with common RL algorithms, is that it is trained to maximize a trade-off between expected return and entropy, a measure of randomness in the policy.


.. warning::

The SAC model does not support ``stable_baselines3.ppo.policies`` because it uses double q-values
and value estimation, as a result it must use its own policy models (see :ref:`sac_policies`).


.. rubric:: Available Policies

.. autosummary::
6 changes: 0 additions & 6 deletions docs/modules/td3.rst
@@ -12,12 +12,6 @@ TD3 is a direct successor of DDPG and improves it using three major tricks: clip
We recommend reading `OpenAI Spinning guide on TD3 <https://spinningup.openai.com/en/latest/algorithms/td3.html>`_ to learn more about those.


.. warning::

The TD3 model does not support ``stable_baselines3.ppo.policies`` because it uses double q-values
estimation, as a result it must use its own policy models (see :ref:`td3_policies`).


.. rubric:: Available Policies

.. autosummary::
2 changes: 1 addition & 1 deletion stable_baselines3/a2c/__init__.py
@@ -1,2 +1,2 @@
from stable_baselines3.a2c.a2c import A2C
from stable_baselines3.ppo.policies import MlpPolicy, CnnPolicy
from stable_baselines3.a2c.policies import MlpPolicy, CnnPolicy
38 changes: 20 additions & 18 deletions stable_baselines3/a2c/a2c.py
@@ -4,13 +4,13 @@
from typing import Type, Union, Callable, Optional, Dict, Any

from stable_baselines3.common import logger
from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback
from stable_baselines3.common.utils import explained_variance
from stable_baselines3.ppo.policies import PPOPolicy
from stable_baselines3.ppo.ppo import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class A2C(PPO):
class A2C(OnPolicyAlgorithm):
"""
Advantage Actor Critic (A2C)
@@ -20,7 +20,7 @@ class A2C(PPO):
Introduction to A2C: https://hackernoon.com/intuitive-rl-intro-to-advantage-actor-critic-a2c-4ff545978752
:param policy: (PPOPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
:param policy: (ActorCriticPolicy or str) The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: (Gym environment or str) The environment to learn from (if registered in Gym, can be str)
:param learning_rate: (float or callable) The learning rate, it can be a function
:param n_steps: (int) The number of steps to run for each environment per update
@@ -49,7 +49,8 @@ class A2C(PPO):
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: (bool) Whether or not to build the network at the creation of the instance
"""
def __init__(self, policy: Union[str, Type[PPOPolicy]],

def __init__(self, policy: Union[str, Type[ActorCriticPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Callable] = 7e-4,
n_steps: int = 5,
@@ -72,16 +73,17 @@ def __init__(self, policy: Union[str, Type[PPOPolicy]],
_init_setup_model: bool = True):

super(A2C, self).__init__(policy, env, learning_rate=learning_rate,
n_steps=n_steps, batch_size=None, n_epochs=1,
gamma=gamma, gae_lambda=gae_lambda, ent_coef=ent_coef,
vf_coef=vf_coef, max_grad_norm=max_grad_norm,
n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda,
ent_coef=ent_coef, vf_coef=vf_coef, max_grad_norm=max_grad_norm,
use_sde=use_sde, sde_sample_freq=sde_sample_freq,
tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs,
verbose=verbose, device=device, create_eval_env=create_eval_env,
seed=seed, _init_setup_model=False)

self.normalize_advantage = normalize_advantage
# Override PPO optimizer to match original implementation

# Update optimizer inside the policy if we want to use RMSProp
# (original implementation) rather than Adam
if use_rms_prop and 'optimizer_class' not in self.policy_kwargs:
self.policy_kwargs['optimizer_class'] = th.optim.RMSprop
self.policy_kwargs['optimizer_kwargs'] = dict(alpha=0.99, eps=rms_prop_eps,
@@ -90,13 +92,13 @@ def __init__(self, policy: Union[str, Type[PPOPolicy]],
if _init_setup_model:
self._setup_model()

def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
def train(self) -> None:
"""
Update policy using the currently gathered
rollout buffer (one gradient step over whole data).
"""
# Update optimizer learning rate
self._update_learning_rate(self.policy.optimizer)
# A2C with gradient_steps > 1 does not make sense
assert gradient_steps == 1, "A2C does not support multiple gradient steps"
# We do not use minibatches for A2C
assert batch_size is None, "A2C does not support minibatch"

for rollout_data in self.rollout_buffer.get(batch_size=None):

@@ -160,7 +162,7 @@ def learn(self,
eval_log_path: Optional[str] = None,
reset_num_timesteps: bool = True) -> 'A2C':

return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback, log_interval=log_interval,
eval_env=eval_env, eval_freq=eval_freq, n_eval_episodes=n_eval_episodes,
tb_log_name=tb_log_name, eval_log_path=eval_log_path,
reset_num_timesteps=reset_num_timesteps)
return super(A2C, self).learn(total_timesteps=total_timesteps, callback=callback,
log_interval=log_interval, eval_env=eval_env, eval_freq=eval_freq,
n_eval_episodes=n_eval_episodes, tb_log_name=tb_log_name,
eval_log_path=eval_log_path, reset_num_timesteps=reset_num_timesteps)
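
For context, a usage sketch of the refactored ``A2C`` (not part of this diff; environment id and timestep count are arbitrary). Passing ``optimizer_class`` through ``policy_kwargs`` skips the RMSprop default set in ``__init__`` above:

    import torch as th
    from stable_baselines3 import A2C

    # Supplying optimizer_class bypasses the RMSprop default, because A2C only
    # injects RMSprop when 'optimizer_class' is absent from policy_kwargs
    model = A2C("MlpPolicy", "CartPole-v1",
                policy_kwargs=dict(optimizer_class=th.optim.Adam))
    model.learn(total_timesteps=1000)
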
9 changes: 9 additions & 0 deletions stable_baselines3/a2c/policies.py
@@ -0,0 +1,9 @@
# This file is here just to define MlpPolicy/CnnPolicy
# that work for A2C
from stable_baselines3.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, register_policy

MlpPolicy = ActorCriticPolicy
CnnPolicy = ActorCriticCnnPolicy

register_policy("MlpPolicy", ActorCriticPolicy)
register_policy("CnnPolicy", ActorCriticCnnPolicy)
