Support hparams logging to tensorboard #984
Conversation
stable_baselines3/common/logger.py
Outdated
@@ -389,6 +414,13 @@ def write(self, key_values: Dict[str, Any], key_excluded: Dict[str, Union[str, T
             if isinstance(value, Image):
                 self.writer.add_image(key, value.image, step, dataformats=value.dataformats)

+            if isinstance(value, HParam):
+                # we don't use `self.writer.add_hparams` to have control over the log_dir
+                exp, ssi, sei = hparams(value.hparam_dict, metric_dict=value.metric_dict)
please use explicit names, what is ssi? sei?
It's done:
experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
The content and meaning of those variables are described in the hparams docstring. (The initial short names were the ones used in pytorch's SummaryWriter class.)
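For readers following the thread, the full block presumably goes on to write the three summaries to the underlying event file, which is what SummaryWriter.add_hparams does internally (here hparams is torch.utils.tensorboard.summary.hparams). A sketch of that continuation, not a verbatim quote of the diff:

if isinstance(value, HParam):
    # we don't use `self.writer.add_hparams` to have control over the log_dir
    experiment, session_start_info, session_end_info = hparams(value.hparam_dict, metric_dict=value.metric_dict)
    # write the three summaries to the same event file as the other logs
    self.writer.file_writer.add_summary(experiment)
    self.writer.file_writer.add_summary(session_start_info)
    self.writer.file_writer.add_summary(session_end_info)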
@@ -296,6 +297,19 @@ def test_report_figure_to_unsupported_format_raises_error(tmp_path, unsupported_
     writer.close()


+@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
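The hunk is cut off after the decorator; judging from the test name given in the description below, the added test presumably continues roughly like this (a sketch, not the literal diff):

import pytest

from stable_baselines3.common.logger import FormatUnsupportedError, HParam, make_output_format


@pytest.mark.parametrize("unsupported_format", ["stdout", "log", "json", "csv"])
def test_report_hparam_to_unsupported_format_raises_error(tmp_path, unsupported_format):
    writer = make_output_format(unsupported_format, str(tmp_path))
    # the human, JSON and CSV writers are expected to reject hparam entries
    with pytest.raises(FormatUnsupportedError):
        hparam_dict = {"learning rate": 3e-4, "gamma": 0.99}
        metric_dict = {"train/value_loss": 0.0}
        writer.write({"hparam": HParam(hparam_dict, metric_dict)}, key_excluded={"hparam": ()})
    writer.close()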
is the new feature also tested somewhere?
currently only the failure case is tested?
I tested the new feature in practice, but not programmatically. It is not as easy as for the other types of logs (e.g. images):
I did not create a test_report_hparam_to_tensorboard test, as hparam logs are not seen by EventAccumulator (cf. Stack Overflow and the EventAccumulator implementation).
Two solutions are proposed on that Stack Overflow page, but one requires an external library and the other is not very clean.
Even without reading the logged hparams back, we should at least run the logger.
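For reference, a minimal smoke test along those lines could simply run the TensorBoard writer with an HParam entry and check that it produces output without raising (the test name and final assertion here are illustrative, not taken from the PR):

import pytest

from stable_baselines3.common.logger import HParam, make_output_format


def test_report_hparam_to_tensorboard(tmp_path):
    pytest.importorskip("tensorboard")
    writer = make_output_format("tensorboard", str(tmp_path))
    hparam_dict = {"learning rate": 3e-4, "gamma": 0.99}
    metric_dict = {"train/value_loss": 0.0}
    # only a smoke test: the hparams cannot easily be read back through EventAccumulator
    writer.write({"hparam": HParam(hparam_dict, metric_dict)}, key_excluded={"hparam": ()})
    writer.close()
    assert any(tmp_path.iterdir())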
LGTM, thanks =)
Description
I created a HParam data class in the same way as the Figure and Image ones. It can take any number of distinct hyperparameters and metrics as input.
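For readers of the thread, the data class is essentially a small container for the two dictionaries; a sketch of its shape (the exact signature in the PR may differ):

from typing import Any, Dict


class HParam:
    """
    Hyperparameter data class storing hyperparameters and metrics for TensorBoard logging.

    :param hparam_dict: key-value pairs of hyperparameters to log
    :param metric_dict: key-value pairs of metrics to log
    """

    def __init__(self, hparam_dict: Dict[str, Any], metric_dict: Dict[str, float]) -> None:
        self.hparam_dict = hparam_dict
        if not metric_dict:
            raise Exception("`metric_dict` must not be empty to display hyperparameters to the HPARAMS tensorboard tab.")
        self.metric_dict = metric_dict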
In each .write() method of HumanOutputFormat, JSONOutputFormat and CSVOutputFormat, I have raised an error when the hparams format is given.
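The checks in those writers are presumably of this form (a sketch, not the literal diff):

# e.g. in HumanOutputFormat.write(); the JSON and CSV writers get analogous checks
if isinstance(value, HParam):
    raise FormatUnsupportedError(["stdout", "log"], "hparam")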
I modified the .write() method of TensorBoardOutputFormat to log the hparam values to tensorboard. As described in the code comment, I have not used self.writer.add_hparams(hparam_dict, metric_dict), provided by pytorch, but reused some of the code from that method: with add_hparams, the hparams could not be saved in the same run folder as the other logs, so metrics from the SCALARS tab could not appear in the HPARAMS tab.
I created one parametrized test, test_report_hparam_to_unsupported_format_raises_error(), in the same way as for the other data classes. I did not create a test_report_hparam_to_tensorboard test, as hparam logs are not seen by EventAccumulator (cf. Stack Overflow and the EventAccumulator implementation).
Finally, in the tensorboard section of the documentation, I added an example of a callback that uses this new code to log hyperparameters to tensorboard.
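That documentation example is a callback along these lines (a sketch of the idea; the hyperparameters and metric tags shown here are illustrative choices, not necessarily the ones used in the docs):

from stable_baselines3 import A2C
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.logger import HParam


class HParamCallback(BaseCallback):
    """Logs the model hyperparameters to TensorBoard once, at the start of training."""

    def _on_training_start(self) -> None:
        hparam_dict = {
            "algorithm": self.model.__class__.__name__,
            "learning rate": self.model.learning_rate,
            "gamma": self.model.gamma,
        }
        # the metric values act as placeholders: TensorBoard pairs these tags with
        # the scalars logged under the same names, so they show up in the HPARAMS tab
        metric_dict = {
            "rollout/ep_len_mean": 0,
            "train/value_loss": 0.0,
        }
        self.logger.record(
            "hparams",
            HParam(hparam_dict, metric_dict),
            exclude=("stdout", "log", "json", "csv"),
        )

    def _on_step(self) -> bool:
        return True


model = A2C("MlpPolicy", "CartPole-v1", tensorboard_log="runs/", verbose=1)
model.learn(total_timesteps=10_000, callback=HParamCallback())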
Additional information - choices made
As hyperparameters are key-value pairs, adding support for the csv, json, or human output formats could be a future development.
It is not required to pass a metric_dict to hparams() or writer.add_hparams(), but if we don't, nothing is displayed in the HPARAMS tab. I have added a warning to alert the user about that.
When adding metrics in metric_dict, users have two choices: pass the metric values directly, or pass placeholder values under the tags of scalars that are already logged, so that tensorboard displays the last logged value for each tag. Displaying the last value (instead of the best value, for example; issue discussed here) is not ideal, but I used the second option in the documentation example as I found it more relevant and intuitive.
In the example, I put the logic in _on_training_start() as we only have to log the hyperparameters & metrics once, but I could also have added it in on_step(), as I did here.
Motivation and Context
[x] I have raised an issue to propose this change (required for new features and bug fixes)
I have proposed a solution to #428 (Also log hyperparameters to the tensorboard): a way of logging hyperparameters to tensorboard.
Closes #428
Types of changes
Checklist:
make format (required)
make check-codestyle and make lint (required)
make pytest and make type both pass (required)
make doc (required)
Note: You can run most of the checks using make commit-checks.
Note: we are using a maximum length of 127 characters per line.