Tensorboard integration (#30)
* init commit tensorboard-integration

* Added tb logger to ppo (with output exclusions)

* fixed truncated stdout

* categorize stdout outputs by tag

* separated exclusions from values, added missing logs

* saving exclusions as dict instead of list

* reformatting, auto run indexing

* included renaming suggestions, fixed tests

* tb support for sac

* linting

* moved logging to base class

* tb support for td3

* removed histograms, non-verbose output working

* modified changelog

* linting

* fixed type error

* moved logger config to utils

* removed episode_rewards log from ppo

* Enable tensorboard in tests

* Remove unused import

* Update logger sub titles

* Minor edit for PPO

* Update logger and tb log folder

* Pass correct logger to Callbacks

* updated docs

* added tb example image to docs

* add support for continuing training in tensorboard

* added tensorboard to docs index

* added tb test

* moved logger config to _setup_learn, updated tests

* accessing verbose from base class

* Update doc and tests

* Rename session -> time

* Update version

* Update logger truncate

* Update types

* Remove duplicated code

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
rolandgvc and araffin committed Jun 1, 2020
1 parent 42f432c commit bb01253
Showing 19 changed files with 488 additions and 233 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -29,6 +29,7 @@ These algorithms will make it easier for the research community and industry to
| Custom policies | :heavy_check_mark: |
| Common interface | :heavy_check_mark: |
| Ipython / Notebook friendly | :heavy_check_mark: |
| Tensorboard support | :heavy_check_mark: |
| PEP8 code style | :heavy_check_mark: |
| Custom callback | :heavy_check_mark: |
| High code coverage | :heavy_check_mark: |
@@ -48,7 +49,6 @@ Planned features:

### Planned features (v1.1+)

- [ ] Full Tensorboard support
- [ ] DQN extensions (prioritized replay, double q-learning, ...)
- [ ] Support for `Tuple` and `Dict` observation spaces
- [ ] Recurrent Policies
@@ -99,7 +99,7 @@ Install the Stable Baselines3 package:
pip install stable-baselines3[extra]
```

This includes optional dependencies like OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
This includes optional dependencies like Tensorboard, OpenCV or `atari-py` to train on atari games. If you do not need those, you can use:
```
pip install stable-baselines3
```
Binary file added docs/_static/img/Tensorboard_example.png
2 changes: 1 addition & 1 deletion docs/guide/callbacks.rst
@@ -44,7 +44,7 @@ This will give you access to events (``_on_training_start``, ``_on_step``) and u
# self.locals = None # type: Dict[str, Any]
# self.globals = None # type: Dict[str, Any]
# The logger object, used to report things in the terminal
# self.logger = None # type: logger.Logger
# self.logger = None # stable_baselines3.common.logger
# # Sometimes, for event callback, it is useful
# # to have access to the parent object
# self.parent = None # type: Optional[BaseCallback]
2 changes: 1 addition & 1 deletion docs/guide/install.rst
@@ -29,7 +29,7 @@ To install Stable Baselines3 with pip, execute:
pip install stable-baselines3[extra]
This includes optional dependencies like OpenCV or ```atari-py``` to train on atari games. If you do not need those, you can use:
This includes optional dependencies like Tensorboard, OpenCV or ```atari-py``` to train on atari games. If you do not need those, you can use:

.. code-block:: bash
82 changes: 82 additions & 0 deletions docs/guide/tensorboard.rst
@@ -0,0 +1,82 @@
.. _tensorboard:

Tensorboard Integration
=======================

Basic Usage
------------

To use Tensorboard with Stable Baselines3, you simply need to pass the location of the log folder to the RL agent:

.. code-block:: python

    from stable_baselines3 import A2C

    model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10000)

You can also define a custom logging name when training (by default, it is the algorithm name):

.. code-block:: python

    from stable_baselines3 import A2C

    model = A2C('MlpPolicy', 'CartPole-v1', verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10000, tb_log_name="first_run")
    # Pass reset_num_timesteps=False to continue the training curve in tensorboard
    # By default, it will create a new curve
    model.learn(total_timesteps=10000, tb_log_name="second_run", reset_num_timesteps=False)
    model.learn(total_timesteps=10000, tb_log_name="third_run", reset_num_timesteps=False)

Once the learn function is called, you can monitor the RL agent during or after training with the following bash command:

.. code-block:: bash

    tensorboard --logdir ./a2c_cartpole_tensorboard/

You can also add past logging folders:

.. code-block:: bash

    tensorboard --logdir ./a2c_cartpole_tensorboard/;./ppo2_cartpole_tensorboard/

It will display information such as the episode reward (when using a ``Monitor`` wrapper), the model losses and other parameters unique to some models.

.. image:: ../_static/img/Tensorboard_example.png
:width: 600
:alt: plotting
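
For the episode reward to be logged, the environment has to be wrapped in a ``Monitor`` wrapper (or report episode statistics itself). A minimal sketch, assuming a single Gym environment and the ``Monitor`` wrapper from ``stable_baselines3.common.monitor``:

.. code-block:: python

    import gym

    from stable_baselines3 import A2C
    from stable_baselines3.common.monitor import Monitor

    # Monitor stores the episode return and length in the info dict,
    # which the logger aggregates into the episode reward/length means
    env = Monitor(gym.make('CartPole-v1'))
    model = A2C('MlpPolicy', env, verbose=1, tensorboard_log="./a2c_cartpole_tensorboard/")
    model.learn(total_timesteps=10000)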

Logging More Values
-------------------

Using a callback, you can easily log more values with TensorBoard.
Here is a simple example of how to log an arbitrary scalar value:

.. code-block:: python

    import numpy as np

    from stable_baselines3 import SAC
    from stable_baselines3.common.callbacks import BaseCallback

    model = SAC("MlpPolicy", "Pendulum-v0", tensorboard_log="/tmp/sac/", verbose=1)


    class TensorboardCallback(BaseCallback):
        """
        Custom callback for plotting additional values in tensorboard.
        """

        def __init__(self, verbose=0):
            super(TensorboardCallback, self).__init__(verbose)

        def _on_step(self) -> bool:
            # Log scalar value (here a random variable)
            value = np.random.random()
            self.logger.record('random_value', value)
            return True


    model.learn(50000, callback=TensorboardCallback())
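
The algorithms group their own values by prefix (for example ``train/``, ``rollout/`` and ``time/``), which Tensorboard displays as separate sections, and a value can be kept out of a given output format with the ``exclude`` argument of ``record``. An illustrative variation of the callback body above using the same pattern (the key names are made up):

.. code-block:: python

    # Group custom values under their own section and keep the raw counter
    # out of Tensorboard (it is still written to the other outputs)
    self.logger.record('custom/random_value', value)
    self.logger.record('custom/num_timesteps', self.num_timesteps, exclude='tensorboard')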
3 changes: 2 additions & 1 deletion docs/index.rst
@@ -25,7 +25,7 @@ Main Features
- Documented functions and classes
- Tests, high code coverage and type hints
- Clean code

- Tensorboard support


.. toctree::
@@ -42,6 +42,7 @@ Main Features
guide/custom_env
guide/custom_policy
guide/callbacks
guide/tensorboard
guide/rl_zoo
guide/migration
guide/checking_nan
10 changes: 9 additions & 1 deletion docs/misc/changelog.rst
@@ -3,12 +3,17 @@
Changelog
==========

Pre-Release 0.6.0a10 (WIP)
Pre-Release 0.6.0a11 (WIP)
------------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Remove State-Dependent Exploration (SDE) support for ``TD3``
- Methods were renamed in the logger:
- ``logkv`` -> ``record``, ``writekvs`` -> ``write``, ``writeseq`` -> ``write_sequence``,
- ``logkvs`` -> ``record_dict``, ``dumpkvs`` -> ``dump``,
- ``getkvs`` -> ``get_log_dict``, ``logkv_mean`` -> ``record_mean``,
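
An illustrative before/after of a typical call site affected by the renaming (hypothetical key and value):

.. code-block:: python

    from stable_baselines3.common import logger

    # Old API (0.5.x)             ->  New API (0.6.0)
    # logger.logkv("value", 1.0)  ->  logger.record("value", 1.0)
    # logger.logkv_mean("v", 2.0) ->  logger.record_mean("v", 2.0)
    # logger.dumpkvs()            ->  logger.dump()
    logger.record("value", 1.0)
    logger.dump()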


New Features:
^^^^^^^^^^^^^
@@ -18,14 +23,17 @@ New Features:
- Added ``cmd_util`` and ``atari_wrappers``
- Added support for ``MultiDiscrete`` and ``MultiBinary`` observation spaces (@rolandgvc)
- Added ``MultiCategorical`` and ``Bernoulli`` distributions for PPO/A2C (@rolandgvc)
- Added support for logging to tensorboard (@rolandgvc)
- Added ``VectorizedActionNoise`` for continuous vectorized environments (@PartiallyTyped)
- Log evaluation in the ``EvalCallback`` using the logger

Bug Fixes:
^^^^^^^^^^
- Fixed a bug that prevented model trained on cpu to be loaded on gpu
- Fixed version number that had a new line included
- Fixed weird seg fault in docker image due to FakeImageEnv by reducing screen size
- Fixed ``sde_sample_freq`` that was not taken into account for SAC
- Pass logger module to ``BaseCallback`` otherwise they cannot write in the one used by the algorithms

Deprecations:
^^^^^^^^^^^^^
4 changes: 3 additions & 1 deletion setup.py
@@ -106,7 +106,9 @@
# For render
'opencv-python',
# For atari games,
'atari_py~=0.2.0', 'pillow'
'atari_py~=0.2.0', 'pillow',
# Tensorboard support
'tensorboard'
]
},
description='Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms.',
12 changes: 6 additions & 6 deletions stable_baselines3/a2c/a2c.py
@@ -141,13 +141,13 @@ def train(self, gradient_steps: int, batch_size: Optional[int] = None) -> None:
self.rollout_buffer.values.flatten())

self._n_updates += 1
logger.logkv("n_updates", self._n_updates)
logger.logkv("explained_variance", explained_var)
logger.logkv("entropy_loss", entropy_loss.item())
logger.logkv("policy_loss", policy_loss.item())
logger.logkv("value_loss", value_loss.item())
logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
logger.record("train/explained_variance", explained_var)
logger.record("train/entropy_loss", entropy_loss.item())
logger.record("train/policy_loss", policy_loss.item())
logger.record("train/value_loss", value_loss.item())
if hasattr(self.policy, 'log_std'):
logger.logkv("std", th.exp(self.policy.log_std).mean().item())
logger.record("train/std", th.exp(self.policy.log_std).mean().item())

def learn(self,
total_timesteps: int,
55 changes: 35 additions & 20 deletions stable_baselines3/common/base_class.py
@@ -11,7 +11,7 @@
import torch as th
import numpy as np

from stable_baselines3.common import logger
from stable_baselines3.common import logger, utils
from stable_baselines3.common.policies import BasePolicy, get_policy_from_name
from stable_baselines3.common.utils import set_random_seed, get_schedule_fn, update_learning_rate, get_device
from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, unwrap_vec_normalize, VecNormalize, VecTransposeImage
@@ -35,6 +35,7 @@ class BaseRLModel(ABC):
:param learning_rate: (float or callable) learning rate for the optimizer,
it can be a function of the current progress (from 1 to 0)
:param policy_kwargs: (Dict[str, Any]) Additional arguments to be passed to the policy on creation
:param tensorboard_log: (str) the log location for tensorboard (if None, no logging)
:param verbose: (int) The verbosity level: 0 none, 1 training information, 2 debug
:param device: (Union[th.device, str]) Device on which the code should run.
By default, it will try to use a Cuda compatible device and fallback to cpu
@@ -58,6 +59,7 @@ def __init__(self,
policy_base: Type[BasePolicy],
learning_rate: Union[float, Callable],
policy_kwargs: Dict[str, Any] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
support_multi_env: bool = False,
@@ -91,6 +93,7 @@ def __init__(self,
self.start_time = None
self.policy = None
self.learning_rate = learning_rate
self.tensorboard_log = tensorboard_log
self.lr_schedule = None # type: Optional[Callable]
self._last_obs = None # type: Optional[np.ndarray]
# When using VecNormalize:
@@ -191,7 +194,7 @@ def _update_learning_rate(self, optimizers: Union[List[th.optim.Optimizer], th.o
An optimizer or a list of optimizers.
"""
# Log the current learning rate
logger.logkv("learning_rate", self.lr_schedule(self._current_progress))
logger.record("train/learning_rate", self.lr_schedule(self._current_progress))

if not isinstance(optimizers, list):
optimizers = [optimizers]
@@ -289,7 +292,7 @@ def learn(self, total_timesteps: int,
"""
Return a trained model.
:param total_timesteps: (int) The total number of samples to train on
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param callback: (function (dict, dict)) -> boolean function called at every steps with state of the algorithm.
It takes the local and global variables. If it returns False, training is aborted.
:param log_interval: (int) The number of timesteps before logging.
@@ -491,23 +494,27 @@ def _init_callback(self,
return callback

def _setup_learn(self,
total_timesteps: int,
eval_env: Optional[GymEnv],
callback: Union[None, Callable, List[BaseCallback], BaseCallback] = None,
eval_freq: int = 10000,
n_eval_episodes: int = 5,
log_path: Optional[str] = None,
reset_num_timesteps: bool = True,
) -> 'BaseCallback':
tb_log_name: str = 'run',
) -> Tuple[int, 'BaseCallback']:
"""
Initialize different variables needed for training.
:param total_timesteps: (int) The total number of samples (env steps) to train on
:param eval_env: (Optional[GymEnv])
:param callback: (Union[None, BaseCallback, List[BaseCallback, Callable]])
:param eval_freq: (int)
:param n_eval_episodes: (int)
:param log_path (Optional[str]): Path to a log folder
:param reset_num_timesteps: (bool) Whether to reset or not the ``num_timesteps`` attribute
:return: (BaseCallback)
:param tb_log_name: (str) the name of the run for tensorboard log
:return: (int, Tuple[BaseCallback])
"""
self.start_time = time.time()
self.ep_info_buffer = deque(maxlen=100)
@@ -519,6 +526,9 @@ def _setup_learn(self,
if reset_num_timesteps:
self.num_timesteps = 0
self._episode_num = 0
else:
# Make sure training timesteps are ahead of the internal counter
total_timesteps += self.num_timesteps

# Avoid resetting the environment when calling ``.learn()`` consecutive times
if reset_num_timesteps or self._last_obs is None:
@@ -532,10 +542,13 @@ def _setup_learn(self,

eval_env = self._get_eval_env(eval_env)

# Configure logger's outputs
utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)

# Create eval callback if needed
callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)

return callback
return total_timesteps, callback

def _update_info_buffer(self, infos: List[Dict[str, Any]], dones: Optional[np.ndarray] = None) -> None:
"""
@@ -697,6 +710,7 @@ def __init__(self,
learning_starts: int = 100,
batch_size: int = 256,
policy_kwargs: Dict[str, Any] = None,
tensorboard_log: Optional[str] = None,
verbose: int = 0,
device: Union[th.device, str] = 'auto',
support_multi_env: bool = False,
@@ -709,13 +723,13 @@ def __init__(self,
sde_support: bool = True):

super(OffPolicyRLModel, self).__init__(policy, env, policy_base, learning_rate,
policy_kwargs, verbose,
policy_kwargs, tensorboard_log, verbose,
device, support_multi_env, create_eval_env, monitor_wrapper,
seed, use_sde, sde_sample_freq)
self.buffer_size = buffer_size
self.batch_size = batch_size
self.learning_starts = learning_starts
self.actor = None
self.actor = None # type: Optional[th.nn.Module]
self.replay_buffer = None # type: Optional[ReplayBuffer]
# Update policy keyword arguments
if sde_support:
@@ -752,7 +766,7 @@ def load_replay_buffer(self, path: str):
self.replay_buffer = pickle.load(file_handler)
assert isinstance(self.replay_buffer, ReplayBuffer), 'The replay buffer must inherit from ReplayBuffer class'

def collect_rollouts(self,
def collect_rollouts(self, # noqa: C901
env: VecEnv,
# Type hint as string to avoid circular import
callback: 'BaseCallback',
@@ -873,22 +887,23 @@ def collect_rollouts(self,
if action_noise is not None:
action_noise.reset()

# Display training infos
if self.verbose >= 1 and log_interval is not None and self._episode_num % log_interval == 0:
# Log training infos
if log_interval is not None and self._episode_num % log_interval == 0:
fps = int(self.num_timesteps / (time.time() - self.start_time))
logger.logkv("episodes", self._episode_num)
logger.record("time/episodes", self._episode_num, exclude="tensorboard")
if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:
logger.logkv('ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
logger.logkv('ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
logger.logkv("fps", fps)
logger.logkv('time_elapsed', int(time.time() - self.start_time))
logger.logkv("total timesteps", self.num_timesteps)
logger.record('rollout/ep_rew_mean', self.safe_mean([ep_info['r'] for ep_info in self.ep_info_buffer]))
logger.record('rollout/ep_len_mean', self.safe_mean([ep_info['l'] for ep_info in self.ep_info_buffer]))
logger.record("time/fps", fps)
logger.record('time/time_elapsed', int(time.time() - self.start_time), exclude="tensorboard")
logger.record("time/total timesteps", self.num_timesteps, exclude="tensorboard")
if self.use_sde:
logger.logkv("std", (self.actor.get_std()).mean().item())
logger.record("train/std", (self.actor.get_std()).mean().item())

if len(self.ep_success_buffer) > 0:
logger.logkv('success rate', self.safe_mean(self.ep_success_buffer))
logger.dumpkvs()
logger.record('rollout/success rate', self.safe_mean(self.ep_success_buffer))
# Pass the number of timesteps for tensorboard
logger.dump(step=self.num_timesteps)

mean_reward = np.mean(episode_rewards) if total_episodes > 0 else 0.0

