Skip to content

Commit

Permalink
Changed handling of tensorboard logger when receiving float as count
Browse files Browse the repository at this point in the history
  • Loading branch information
dzimmerer committed May 8, 2020
1 parent ff329a8 commit 6c3d5ba
Showing 1 changed file with 33 additions and 18 deletions.
51 changes: 33 additions & 18 deletions trixi/logger/tensorboard/tensorboardlogger.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,16 @@ def show_value(self, value, name="Value", counter=None, tag=None, **kwargs):
else:
key = tag + "-" + name

if counter is not None:
self.val_dict["{}-image".format(key)] = counter
if counter is not None and isinstance(counter, int):
self.val_dict[f"{key}-image"] = counter
else:
self.val_dict["{}-image".format(key)] += 1
self.val_dict[f"{key}-image"] += 1

if tag is not None:
self.writer.add_scalars(tag, {name: value}, global_step=self.val_dict["{}-image".format(key)])
self.writer.add_scalars(tag, {name: value}, global_step=self.val_dict[f"{key}-image"])
self.writer.scalar_dict = {}
else:
self.writer.add_scalar(name, value, global_step=self.val_dict["{}-image".format(key)])
self.writer.add_scalar(name, value, global_step=self.val_dict[f"{key}-image"])

def show_text(self, text, name="Text", counter=None, **kwargs):
"""
Expand All @@ -103,9 +103,20 @@ def show_text(self, text, name="Text", counter=None, **kwargs):
self.writer.add_text(name, text, global_step=self.val_dict["{}-text".format(name)])

@convert_params
def show_image_grid(self, image_array, name="Image-Grid", counter=None, nrow=8, padding=2,
normalize=False, range=None, scale_each=False, pad_value=0,
*args, **kwargs):
def show_image_grid(
self,
image_array,
name="Image-Grid",
counter=None,
nrow=8,
padding=2,
normalize=False,
range=None,
scale_each=False,
pad_value=0,
*args,
**kwargs,
):
"""
Sends an array of images to tensorboard as a grid. Like :meth:`.show_image`, but generates
image grid before.
Expand All @@ -123,12 +134,9 @@ def show_image_grid(self, image_array, name="Image-Grid", counter=None, nrow=8,
pad_value (float): Fill padding with this value
"""

image_args = dict(nrow=nrow,
padding=padding,
normalize=normalize,
range=range,
scale_each=scale_each,
pad_value=pad_value)
image_args = dict(
nrow=nrow, padding=padding, normalize=normalize, range=range, scale_each=scale_each, pad_value=pad_value
)

if counter is not None:
self.val_dict["{}-image".format(name)] = counter
Expand Down Expand Up @@ -223,8 +231,7 @@ def show_piechart(self, array, name="piechart", counter=None, *args, **kwargs):
figure = super().show_piechart(array, name, *args, **kwargs)
self.writer.add_figure(tag=name, figure=figure, global_step=self.val_dict["{}-figure".format(name)])

def show_embedding(self, tensor, labels=None, name='default', label_img=None, counter=None,
*args, **kwargs):
def show_embedding(self, tensor, labels=None, name="default", label_img=None, counter=None, *args, **kwargs):
"""
Displays an embedding of a tensor (for more details see tensorboardX)
Expand All @@ -242,7 +249,13 @@ def show_embedding(self, tensor, labels=None, name='default', label_img=None, co
else:
self.val_dict["{}-embedding".format(name)] += 1

self.writer.add_embedding(mat=tensor, metadata=labels, label_img=label_img, tag=name, global_step=self.val_dict["{}-embedding".format(name)])
self.writer.add_embedding(
mat=tensor,
metadata=labels,
label_img=label_img,
tag=name,
global_step=self.val_dict["{}-embedding".format(name)],
)

def show_histogram(self, array, name="Histogram", counter=None, *args, **kwargs):
"""
Expand Down Expand Up @@ -278,7 +291,9 @@ def show_pr_curve(self, tensor, labels, name="pr-curve", counter=None, *args, **
else:
self.val_dict["{}-pr-curve".format(name)] += 1

self.writer.add_pr_curve(tag=name, labels=labels, predictions=tensor, global_step=self.val_dict["{}-pr-curve".format(name)])
self.writer.add_pr_curve(
tag=name, labels=labels, predictions=tensor, global_step=self.val_dict["{}-pr-curve".format(name)]
)

def show_hparams(self, hparam_dict=None, metric_dict=None, counter=None, *args, **kwargs):
"""
Expand Down

0 comments on commit 6c3d5ba

Please sign in to comment.