Skip to content

Commit

Permalink
Update tensorboardlogger's hparam function.
Browse files Browse the repository at this point in the history
  • Loading branch information
mzenk committed May 26, 2020
1 parent 6c3d5ba commit d0f9907
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ The Logging API supports

And offers different Backends, e.g. :
* Visdom ([visdom-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.visdom))
* TensorboardX ([tensorboard-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.tensorboard))
* Tensorboard ([tensorboard-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.tensorboard))
* Matplotlib / Seaborn ([plt-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.plt))
* Local Disk ([file-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.file))
* Telegram & Slack ([message-loggers](https://trixi.readthedocs.io/en/latest/_api/trixi.logger.html#module-trixi.logger.message))
Expand Down
26 changes: 21 additions & 5 deletions trixi/logger/tensorboard/tensorboardlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def show_piechart(self, array, name="piechart", counter=None, *args, **kwargs):

def show_embedding(self, tensor, labels=None, name="default", label_img=None, counter=None, *args, **kwargs):
"""
Displays an embedding of a tensor (for more details see tensorboardX)
Displays an embedding of a tensor (for more details see tensorboard)
Args:
tensor (torch.tensor/np.array): Tensor to be embedded and then displayed
Expand Down Expand Up @@ -295,21 +295,37 @@ def show_pr_curve(self, tensor, labels, name="pr-curve", counter=None, *args, **
tag=name, labels=labels, predictions=tensor, global_step=self.val_dict["{}-pr-curve".format(name)]
)

# adapted from torch.utils.tensorboard.add_hparams
def show_hparams(self, hparam_dict=None, metric_dict=None, counter=None, *args, **kwargs):
"""
When using tensorboard's hparam-view, it can be filled with this function.
It is usually called once per experiment. The values can later be updated by calling show_value with the same
name (and WITHOUT tag) as used in metric_dict. It is recommended to choose that name such that it does not
coincide with other quantities your logging, because it will also appear in the scalars view
(e.g. `hp/loss' instead of `loss`)
Args:
hparam_dict: Each key-value pair in the dictionary is the name of the hyper parameter and it’s corresponding value.
metric_dict: Each key-value pair in the dictionary is the name of the metric and it’s corresponding value.
Note that the key used here should be unique in the tensorboard record. Otherwise the value you added by
show_value will be displayed in the hparam plugin. In most cases, this is unwanted.
counter (int): Global step value
"""
if counter is not None:
self.val_dict["{}-hparams"] = counter
self.val_dict["hparams"] = counter
else:
self.val_dict["{}-hparams"] += 1
self.val_dict["hparams"] += 1

if type(hparam_dict) is not dict or type(metric_dict) is not dict:
raise TypeError('hparam_dict and metric_dict should be dictionary.')

from torch.utils.tensorboard.summary import hparams
exp, ssi, sei = hparams(hparam_dict, metric_dict)

self.writer.add_hparams(hparam_dict=hparam_dict, metric_dict=metric_dict)
self.writer.file_writer.add_summary(exp)
self.writer.file_writer.add_summary(ssi)
self.writer.file_writer.add_summary(sei)
for k, v in metric_dict.items():
self.writer.add_scalar(k, v, counter)

def close(self):
self.writer.close()

0 comments on commit d0f9907

Please sign in to comment.