Skip to content

Commit

Permalink
Support hparams logging to tensorboard (#984)
Browse files Browse the repository at this point in the history
* create Hparam class & support in all OutputFormats

* add hparams documentation & example

* add hparam tests

* remove unnecessary test & fix name

* format changes

* support hyperparameters logging to tensorboard

* fix HParams class docstring

* use more explicit variable names

* raise error instead of warning

* Unpin protobuf

* Add test for logging hparams

Co-authored-by: Antonin RAFFIN <antonin.raffin@ensta.org>
  • Loading branch information
timothe-chaumont and araffin committed Aug 22, 2022
1 parent 57e0054 commit 01cc127
Show file tree
Hide file tree
Showing 7 changed files with 139 additions and 10 deletions.
49 changes: 49 additions & 0 deletions docs/guide/tensorboard.rst
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,55 @@ Here is an example of how to render an episode and log the resulting video to Te
video_recorder = VideoRecorderCallback(gym.make("CartPole-v1"), render_freq=5000)
model.learn(total_timesteps=int(5e4), callback=video_recorder)
Logging Hyperparameters
-----------------------

TensorBoard supports logging of hyperparameters in its HPARAMS tab, which makes it easier to compare different training runs.

.. warning::
To display hyperparameters in the HPARAMS section, a ``metric_dict`` must be given (as well as a ``hparam_dict``).


Here is an example of how to save hyperparameters in TensorBoard:

.. code-block:: python
from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
class HParamCallback(BaseCallback):
def __init__(self):
"""
Saves the hyperparameters and metrics at the start of the training, and logs them to TensorBoard.
"""
super().__init__()
def _on_training_start(self) -> None:
hparam_dict = {
"algorithm": self.model.__class__.__name__,
"learning rate": self.model.learning_rate,
"gamma": self.model.gamma,
}
# define the metrics that will appear in the `HPARAMS` Tensorboard tab by referencing their tag
# Tensorboard will find & display metrics from the `SCALARS` tab
metric_dict = {
"rollout/ep_len_mean": 0,
"train/value_loss": 0,
}
self.logger.record(
"hparams",
HParam(hparam_dict, metric_dict),
exclude=("stdout", "log", "json", "csv"),
)
def _on_step(self) -> bool:
return True
model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
model.learn(total_timesteps=int(5e4), callback=HParamCallback())
Directly Accessing The Summary Writer
-------------------------------------
Expand Down
8 changes: 5 additions & 3 deletions docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@
Changelog
==========

Release 1.6.1a0 (WIP)
Release 1.6.1a1 (WIP)
---------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^
- Switched minimum tensorboard version to 2.9.1

New Features:
^^^^^^^^^^^^^
- Support logging hyperparameters to tensorboard (@timothe-chaumont)

SB3-Contrib
^^^^^^^^^^^
Expand All @@ -33,12 +35,12 @@ Others:

Documentation:
^^^^^^^^^^^^^^
- Added an example of callback that logs hyperparameters to tensorboard. (@timothe-chaumont)
- Fixed typo in docstring "nature" -> "Nature" (@Melanol)
- Added info on splitting tensorboard logs (@Melanol)
- Fixed typo in ppo doc (@francescoluciano)
- Fixed typo in install doc(@jlp-ue)


Release 1.6.0 (2022-07-11)
---------------------------

Expand Down Expand Up @@ -1024,4 +1026,4 @@ And all the contributors:
@eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP
@simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485
@Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont
5 changes: 1 addition & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,7 @@
"autorom[accept-rom-license]~=0.4.2",
"pillow",
# Tensorboard support
"tensorboard>=2.2.0",
# Protobuf >= 4 has breaking changes
# which do not play well with tensorboard
"protobuf~=3.19.0",
"tensorboard>=2.9.1",
# Checking memory taken by replay buffer
"psutil",
],
Expand Down
32 changes: 32 additions & 0 deletions stable_baselines3/common/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

try:
from torch.utils.tensorboard import SummaryWriter
from torch.utils.tensorboard.summary import hparams
except ImportError:
SummaryWriter = None

Expand Down Expand Up @@ -66,6 +67,22 @@ def __init__(self, image: Union[th.Tensor, np.ndarray, str], dataformats: str):
self.dataformats = dataformats


class HParam:
    """
    Hyperparameter data class storing hyperparameters and metrics in dictionaries.

    :param hparam_dict: key-value pairs of hyperparameters to log
    :param metric_dict: key-value pairs of metrics to log.
        A non-empty metrics dict is required to display hyperparameters
        in the corresponding Tensorboard section.
    :raises ValueError: if ``metric_dict`` is empty (TensorBoard silently hides
        hparams that reference no metric, so we fail fast instead)
    """

    def __init__(self, hparam_dict: Dict[str, Union[bool, str, float, int, None]], metric_dict: Dict[str, Union[float, int]]):
        self.hparam_dict = hparam_dict
        if not metric_dict:
            # ValueError (not bare Exception) so callers can catch it precisely;
            # still caught by any existing `except Exception` handlers.
            raise ValueError("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
        self.metric_dict = metric_dict


class FormatUnsupportedError(NotImplementedError):
"""
Custom error to display informative message when
Expand Down Expand Up @@ -165,6 +182,9 @@ def write(self, key_values: Dict, key_excluded: Dict, step: int = 0) -> None:
elif isinstance(value, Image):
raise FormatUnsupportedError(["stdout", "log"], "image")

elif isinstance(value, HParam):
raise FormatUnsupportedError(["stdout", "log"], "hparam")

elif isinstance(value, float):
# Align left
value_str = f"{value:<8.3g}"
Expand Down Expand Up @@ -264,6 +284,8 @@ def cast_to_json_serializable(value: Any):
raise FormatUnsupportedError(["json"], "figure")
if isinstance(value, Image):
raise FormatUnsupportedError(["json"], "image")
if isinstance(value, HParam):
raise FormatUnsupportedError(["json"], "hparam")
if hasattr(value, "dtype"):
if value.shape == () or len(value) == 1:
# if value is a dimensionless numpy array or of length 1, serialize as a float
Expand Down Expand Up @@ -333,6 +355,9 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
elif isinstance(value, Image):
raise FormatUnsupportedError(["csv"], "image")

elif isinstance(value, HParam):
raise FormatUnsupportedError(["csv"], "hparam")

elif isinstance(value, str):
# escape quotechars by prepending them with another quotechar
value = value.replace(self.quotechar, self.quotechar + self.quotechar)
Expand Down Expand Up @@ -389,6 +414,13 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
if isinstance(value, Image):
self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

if isinstance(value, HParam):
# we don't use `self.writer.add_hparams` to have control over the log_dir
experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
self.writer.file_writer.add_summary(experiment)
self.writer.file_writer.add_summary(session_start_info)
self.writer.file_writer.add_summary(session_end_info)

# Flush the output to the file
self.writer.flush()

Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.6.1a0
1.6.1a1
14 changes: 14 additions & 0 deletions tests/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
CSVOutputFormat,
Figure,
FormatUnsupportedError,
HParam,
HumanOutputFormat,
Image,
Logger,
Expand Down Expand Up @@ -296,6 +297,19 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
writer.close()


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
    # HParam values are only handled by the tensorboard writer; all other
    # output formats must reject them with an informative error.
    writer = make_output_format(unsupported_format, tmp_path)

    hparam = HParam(
        hparam_dict={"learning rate": np.random.random()},
        metric_dict={"train/value_loss": 0},
    )
    with pytest.raises(FormatUnsupportedError) as exec_info:
        writer.write({"hparam": hparam}, key_excluded={"hparam": ()})
    # The error message should name the offending format.
    assert unsupported_format in str(exec_info.value)
    writer.close()


def test_key_length(tmp_path):
writer = make_output_format("stdout", tmp_path)
assert writer.max_length == 36
Expand Down
39 changes: 37 additions & 2 deletions tests/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import pytest

from stable_baselines3 import A2C, PPO, SAC, TD3
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam
from stable_baselines3.common.utils import get_latest_run_id

MODEL_DICT = {
Expand All @@ -15,15 +17,48 @@
N_STEPS = 100


class HParamCallback(BaseCallback):
    """
    Callback that records the run's hyperparameters (and placeholder metrics)
    to TensorBoard's HPARAMS tab at the start of training.
    """

    def __init__(self):
        super().__init__()

    def _on_training_start(self) -> None:
        hparams_to_log = {
            "algorithm": self.model.__class__.__name__,
            "learning rate": self.model.learning_rate,
            "gamma": self.model.gamma,
        }
        # Metrics referenced by tag here are picked up by Tensorboard from the
        # `SCALARS` tab and shown alongside the hyperparameters in `HPARAMS`.
        metrics_to_log = {"rollout/ep_len_mean": 0}
        self.logger.record(
            "hparams",
            HParam(hparams_to_log, metrics_to_log),
            # Only the tensorboard writer supports HParam values.
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True


@pytest.mark.parametrize("model_name", MODEL_DICT.keys())
def test_tensorboard(tmp_path, model_name):
# Skip if no tensorboard installed
pytest.importorskip("tensorboard")

logname = model_name.upper()
algo, env_id = MODEL_DICT[model_name]
model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path)
model.learn(N_STEPS)
kwargs = {}
if model_name == "ppo":
kwargs["n_steps"] = 64
elif model_name in {"sac", "td3"}:
kwargs["train_freq"] = 2
model = algo("MlpPolicy", env_id, verbose=1, tensorboard_log=tmp_path, **kwargs)
model.learn(N_STEPS, callback=HParamCallback())
model.learn(N_STEPS, reset_num_timesteps=False)

assert os.path.isdir(tmp_path / str(logname + "_1"))
Expand Down

0 comments on commit 01cc127

Please sign in to comment.