diff --git a/docs/requirements.txt b/docs/requirements.txt index cefb47e7e0..55bb8f0cb0 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -22,3 +22,4 @@ pandas einops transformers mlflow +tensorboardX diff --git a/docs/source/installation.md b/docs/source/installation.md index 2649756815..008ff00e79 100644 --- a/docs/source/installation.md +++ b/docs/source/installation.md @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is - The options are ``` -[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib] +[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX] ``` which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`, -`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, respectively. +`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX` respectively. - `pip install 'monai[all]'` installs all the optional dependencies. diff --git a/monai/handlers/tensorboard_handlers.py b/monai/handlers/tensorboard_handlers.py index d294d0adb5..490ba5d2d1 100644 --- a/monai/handlers/tensorboard_handlers.py +++ b/monai/handlers/tensorboard_handlers.py @@ -20,6 +20,7 @@ from monai.visualize import plot_2d_or_3d_image Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events") + if TYPE_CHECKING: from ignite.engine import Engine from torch.utils.tensorboard import SummaryWriter @@ -35,8 +36,8 @@ class TensorBoardHandler: Base class for the handlers to write data into TensorBoard. Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. """ @@ -94,8 +95,8 @@ def __init__( ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. epoch_event_writer: customized callable TensorBoard writer for epoch level. Must accept parameter "engine" and "summary_writer", use default event writer if None. @@ -180,7 +181,7 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None: Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ current_epoch = self.global_epoch_transform(engine.state.epoch) @@ -203,7 +204,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. - writer: TensorBoard writer, created in TensorBoardHandler. + writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler. """ loss = self.output_transform(engine.state.output) @@ -243,6 +244,7 @@ class TensorBoardImageHandler(TensorBoardHandler): 2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch, for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images' last three dimensions will be shown as animated GIF along the last axis (typically Depth). + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. It can be used for any Ignite Engine (trainer, validator and evaluator). User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``, @@ -277,8 +279,8 @@ def __init__( ) -> None: """ Args: - summary_writer: user can specify TensorBoard SummaryWriter, - default to create a new writer. + summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter, + default to create a new TensorBoard writer. log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`. interval: plot content from engine.state every N epochs or every N iterations, default is 1. epoch_level: plot content from engine.state every N epochs or N iterations. `True` is epoch level, @@ -299,7 +301,7 @@ def __init__( For example, in evaluation, the evaluator engine needs to know current epoch from trainer. index: plot which element in a data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. """ super().__init__(summary_writer=summary_writer, log_dir=log_dir) self.interval = interval diff --git a/monai/visualize/img2tensorboard.py b/monai/visualize/img2tensorboard.py index fd6dc9483b..eef1b2e764 100644 --- a/monai/visualize/img2tensorboard.py +++ b/monai/visualize/img2tensorboard.py @@ -20,6 +20,8 @@ PIL, _ = optional_import("PIL") GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image") +SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary") +SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter") if TYPE_CHECKING: from tensorboard.compat.proto.summary_pb2 import Summary @@ -28,11 +30,10 @@ Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary") SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter") - __all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"] -def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary: +def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], writer, scale_factor: float = 1.0): """Function to actually create the animated gif. Args: @@ -54,14 +55,17 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale for b_data in PIL.GifImagePlugin.getdata(i): img_str += b_data img_str += b"\x3B" - summary_image_str = Summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) - image_summary = Summary.Value(tag=tag, image=summary_image_str) - return Summary(value=[image_summary]) + + summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary + summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str) + image_summary = summary.Value(tag=tag, image=summary_image_str) + return summary(value=[image_summary]) def make_animated_gif_summary( tag: str, image: Union[np.ndarray, torch.Tensor], + writer=None, max_out: int = 3, animation_axes: Sequence[int] = (3,), image_axes: Sequence[int] = (1, 2), @@ -73,7 +77,8 @@ def make_animated_gif_summary( Args: tag: Data identifier image: The image, expected to be in CHWD format - max_out: maximum number of slices to animate through + writer: the tensorboard writer to plot image + max_out: maximum number of image channels to animate through animation_axes: axis to animate on (not currently used) image_axes: axes of image (not currently used) other_indices: (not currently used) @@ -95,11 +100,12 @@ def make_animated_gif_summary( slicing.append(slice(other_ind, other_ind + 1)) image = image[tuple(slicing)] + summary_op = [] for it_i in range(min(max_out, list(image.shape)[0])): one_channel_img: Union[torch.Tensor, np.ndarray] = ( image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :] ) - summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor) + summary_op.append(_image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, scale_factor)) return summary_op @@ -107,8 +113,8 @@ def add_animated_gif( writer: SummaryWriter, tag: str, image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, + max_out: int = 3, + scale_factor: float = 1.0, global_step: Optional[int] = None, ) -> None: """Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter. @@ -117,25 +123,31 @@ def add_animated_gif( writer: Tensorboard SummaryWriter to write to tag: Data identifier image_tensor: tensor for the image to add, expected to be in CHWD format - max_out: maximum number of slices to animate through + max_out: maximum number of image channels to animate through scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will scale it to displayable range global_step: Global step value to record """ - writer._get_file_writer().add_summary( - make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[2, 3], scale_factor=scale_factor - ), - global_step, + summary = make_animated_gif_summary( + tag=tag, + image=image_tensor, + writer=writer, + max_out=max_out, + animation_axes=[1], + image_axes=[2, 3], + scale_factor=scale_factor, ) + for s in summary: + # add GIF for every channel separately + writer._get_file_writer().add_summary(s, global_step) def add_animated_gif_no_channels( writer: SummaryWriter, tag: str, image_tensor: Union[np.ndarray, torch.Tensor], - max_out: int, - scale_factor: float, + max_out: int = 3, + scale_factor: float = 1.0, global_step: Optional[int] = None, ) -> None: """Creates an animated gif out of an image tensor in 'HWD' format that does not have @@ -146,15 +158,21 @@ def add_animated_gif_no_channels( writer: Tensorboard SummaryWriter to write to tag: Data identifier image_tensor: tensor for the image to add, expected to be in HWD format - max_out: maximum number of slices to animate through + max_out: maximum number of image channels to animate through scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will scale it to displayable range global_step: Global step value to record """ writer._get_file_writer().add_summary( make_animated_gif_summary( - tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[1, 2], scale_factor=scale_factor - ), + tag=tag, + image=image_tensor, + writer=writer, + max_out=max_out, + animation_axes=[1], + image_axes=[1, 2], + scale_factor=scale_factor, + )[0], global_step, ) @@ -165,23 +183,24 @@ def plot_2d_or_3d_image( writer: SummaryWriter, index: int = 0, max_channels: int = 1, - max_frames: int = 64, + max_frames: int = 24, tag: str = "output", ) -> None: """Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image. Note: Plot 3D or 2D image(with more than 3 channels) as separate images. + And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video. Args: data: target data to be plotted as image on the TensorBoard. The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions, and only plot the first in the batch. step: current step to plot in a chart. - writer: specify TensorBoard SummaryWriter to plot the image. + writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image. index: plot which element in the input data batch, default is the first element. max_channels: number of channels to plot. - max_frames: number of frames for 2D-t plot. + max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`. tag: tag of the plotted image on TensorBoard. """ data_index = data[index] @@ -206,7 +225,13 @@ def plot_2d_or_3d_image( if d.ndim >= 4: spatial = d.shape[-3:] - for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:max_channels]): - d3 = rescale_array(d3, 0, 255) - add_animated_gif(writer, f"{tag}_HWD_{j}", d3[None], max_frames, 1.0, step) + d = d.reshape([-1] + list(spatial)) + if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX): # RGB + writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT") + return + # scale data to 0 - 255 for visualization + max_channels = min(max_channels, d.shape[0]) + d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0) + # will plot every channel as a separate GIF image + add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, global_step=step) return diff --git a/requirements-dev.txt b/requirements-dev.txt index 56d5709cb3..1d9d52bca5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -41,3 +41,4 @@ einops transformers mlflow matplotlib +tensorboardX diff --git a/setup.cfg b/setup.cfg index 6f94bee7c0..1a87d0d91a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -47,6 +47,7 @@ all = transformers mlflow matplotlib + tensorboardX nibabel = nibabel skimage = @@ -83,6 +84,8 @@ mlflow = mlflow matplotlib = matplotlib +tensorboardX = + tensorboardX [flake8] select = B,C,E,F,N,P,T4,W,B9 diff --git a/tests/test_img2tensorboard.py b/tests/test_img2tensorboard.py index bd0369868e..bf6890bcad 100644 --- a/tests/test_img2tensorboard.py +++ b/tests/test_img2tensorboard.py @@ -29,9 +29,10 @@ def test_write_gray(self): image_axes=(1, 2), scale_factor=253.0, ) - assert isinstance( - summary_object_np, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" + for s in summary_object_np: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from numpy array" tensorarr = torch.tensor(nparr) summary_object_tensor = make_animated_gif_summary( @@ -42,9 +43,10 @@ def test_write_gray(self): image_axes=(1, 2), scale_factor=253.0, ) - assert isinstance( - summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary - ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" + for s in summary_object_tensor: + assert isinstance( + s, tensorboard.compat.proto.summary_pb2.Summary + ), "make_animated_gif_summary must return a tensorboard.summary object from tensor input" if __name__ == "__main__": diff --git a/tests/test_plot_2d_or_3d_image.py b/tests/test_plot_2d_or_3d_image.py index 645658e311..c645c8ff86 100644 --- a/tests/test_plot_2d_or_3d_image.py +++ b/tests/test_plot_2d_or_3d_image.py @@ -17,7 +17,11 @@ from parameterized import parameterized from torch.utils.tensorboard import SummaryWriter +from monai.utils import optional_import from monai.visualize import plot_2d_or_3d_image +from tests.utils import SkipIfNoModule + +SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter") TEST_CASE_1 = [(1, 1, 10, 10)] @@ -32,10 +36,30 @@ class TestPlot2dOr3dImage(unittest.TestCase): @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) - def test_tb_image_shape(self, shape): + def test_tb_image(self, shape): with tempfile.TemporaryDirectory() as tempdir: writer = SummaryWriter(log_dir=tempdir) - plot_2d_or_3d_image(torch.zeros(shape), 0, writer) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=20) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5]) + def test_tbx_image(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2) + writer.flush() + writer.close() + self.assertTrue(len(glob.glob(tempdir)) > 0) + + @SkipIfNoModule("tensorboardX") + @parameterized.expand([TEST_CASE_5]) + def test_tbx_video(self, shape): + with tempfile.TemporaryDirectory() as tempdir: + writer = SummaryWriterX(log_dir=tempdir) + plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3) writer.flush() writer.close() self.assertTrue(len(glob.glob(tempdir)) > 0)