Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
42a45e0
Merge pull request #19 from Project-MONAI/master
Nic-Ma Feb 1, 2021
cd16a13
Merge pull request #32 from Project-MONAI/master
Nic-Ma Feb 24, 2021
6f87afd
Merge pull request #180 from Project-MONAI/dev
Nic-Ma Jul 22, 2021
f398298
Merge pull request #214 from Project-MONAI/dev
Nic-Ma Sep 8, 2021
b983811
Merge pull request #272 from Project-MONAI/dev
Nic-Ma Nov 3, 2021
493bafb
[DLMED] add support for both TensorBoard and TensorBoardX
Nic-Ma Nov 3, 2021
76892f0
[DLMED] add RGB color GIF
Nic-Ma Nov 3, 2021
b0228fa
[DLMED] add support to handlers
Nic-Ma Nov 3, 2021
dd86c7c
[DLMED] add optional import
Nic-Ma Nov 3, 2021
75abfcc
[DLMED] format code
Nic-Ma Nov 3, 2021
43f180d
[DLMED] fix flake8≈
Nic-Ma Nov 3, 2021
ccb839e
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 3, 2021
ffc11a4
[DLMED] fix typo
Nic-Ma Nov 3, 2021
ffaf19c
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 3, 2021
e117005
[DLMED] fix typing
Nic-Ma Nov 3, 2021
87074f6
[DLEMD] test python 3.6
Nic-Ma Nov 3, 2021
5e54c39
Merge branch 'dev' into 3240-support-color-plot
wyli Nov 3, 2021
a14ee61
test remove typing
wyli Nov 4, 2021
1da5f9c
Merge branch 'dev' into 3240-support-color-plot
wyli Nov 4, 2021
4aadb5d
[DLMED] update according to comments
Nic-Ma Nov 5, 2021
70899b0
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 5, 2021
94ad219
[DLMED] fix tests
Nic-Ma Nov 5, 2021
8f5c5af
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 6, 2021
0883fb1
Merge branch 'dev' into 3240-support-color-plot
wyli Nov 6, 2021
01d1e6c
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 6, 2021
af3808c
temp fix
wyli Nov 7, 2021
52a6cd8
[DLMED] remove moviepy
Nic-Ma Nov 8, 2021
85ad82d
Merge branch 'dev' into 3240-support-color-plot
Nic-Ma Nov 8, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ pandas
einops
transformers
mlflow
tensorboardX
4 changes: 2 additions & 2 deletions docs/source/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,9 @@ Since MONAI v0.2.0, the extras syntax such as `pip install 'monai[nibabel]'` is

- The options are
```
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib]
[nibabel, skimage, pillow, tensorboard, gdown, ignite, torchvision, itk, tqdm, lmdb, psutil, cucim, openslide, pandas, einops, transformers, mlflow, matplotlib, tensorboardX]
```
which correspond to `nibabel`, `scikit-image`, `pillow`, `tensorboard`,
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, respectively.
`gdown`, `pytorch-ignite`, `torchvision`, `itk`, `tqdm`, `lmdb`, `psutil`, `cucim`, `openslide-python`, `pandas`, `einops`, `transformers`, `mlflow`, `matplotlib`, `tensorboardX` respectively.

- `pip install 'monai[all]'` installs all the optional dependencies.
20 changes: 11 additions & 9 deletions monai/handlers/tensorboard_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from monai.visualize import plot_2d_or_3d_image

Events, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Events")

if TYPE_CHECKING:
from ignite.engine import Engine
from torch.utils.tensorboard import SummaryWriter
Expand All @@ -35,8 +36,8 @@ class TensorBoardHandler:
Base class for the handlers to write data into TensorBoard.

Args:
summary_writer: user can specify TensorBoard SummaryWriter,
default to create a new writer.
summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,
default to create a new TensorBoard writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.

"""
Expand Down Expand Up @@ -94,8 +95,8 @@ def __init__(
) -> None:
"""
Args:
summary_writer: user can specify TensorBoard SummaryWriter,
default to create a new writer.
summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,
default to create a new TensorBoard writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
epoch_event_writer: customized callable TensorBoard writer for epoch level.
Must accept parameter "engine" and "summary_writer", use default event writer if None.
Expand Down Expand Up @@ -180,7 +181,7 @@ def _default_epoch_writer(self, engine: Engine, writer: SummaryWriter) -> None:

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
writer: TensorBoard writer, created in TensorBoardHandler.
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.

"""
current_epoch = self.global_epoch_transform(engine.state.epoch)
Expand All @@ -203,7 +204,7 @@ def _default_iteration_writer(self, engine: Engine, writer: SummaryWriter) -> No

Args:
engine: Ignite Engine, it can be a trainer, validator or evaluator.
writer: TensorBoard writer, created in TensorBoardHandler.
writer: TensorBoard or TensorBoardX writer, passed or created in TensorBoardHandler.

"""
loss = self.output_transform(engine.state.output)
Expand Down Expand Up @@ -243,6 +244,7 @@ class TensorBoardImageHandler(TensorBoardHandler):
2D output (shape in Batch, channel, H, W) will be shown as simple image using the first element in the batch,
for 3D to ND output (shape in Batch, channel, H, W, D) input, each of ``self.max_channels`` number of images'
last three dimensions will be shown as animated GIF along the last axis (typically Depth).
And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

It can be used for any Ignite Engine (trainer, validator and evaluator).
User can easily add it to engine for any expected Event, for example: ``EPOCH_COMPLETED``,
Expand Down Expand Up @@ -277,8 +279,8 @@ def __init__(
) -> None:
"""
Args:
summary_writer: user can specify TensorBoard SummaryWriter,
default to create a new writer.
summary_writer: user can specify TensorBoard or TensorBoardX SummaryWriter,
default to create a new TensorBoard writer.
log_dir: if using default SummaryWriter, write logs to this directory, default is `./runs`.
interval: plot content from engine.state every N epochs or every N iterations, default is 1.
epoch_level: plot content from engine.state every N epochs or N iterations. `True` is epoch level,
Expand All @@ -299,7 +301,7 @@ def __init__(
For example, in evaluation, the evaluator engine needs to know current epoch from trainer.
index: plot which element in a data batch, default is the first element.
max_channels: number of channels to plot.
max_frames: number of frames for 2D-t plot.
max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
"""
super().__init__(summary_writer=summary_writer, log_dir=log_dir)
self.interval = interval
Expand Down
77 changes: 51 additions & 26 deletions monai/visualize/img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

PIL, _ = optional_import("PIL")
GifImage, _ = optional_import("PIL.GifImagePlugin", name="Image")
SummaryX, _ = optional_import("tensorboardX.proto.summary_pb2", name="Summary")
SummaryWriterX, has_tensorboardx = optional_import("tensorboardX", name="SummaryWriter")

if TYPE_CHECKING:
from tensorboard.compat.proto.summary_pb2 import Summary
Expand All @@ -28,11 +30,10 @@
Summary, _ = optional_import("tensorboard.compat.proto.summary_pb2", name="Summary")
SummaryWriter, _ = optional_import("torch.utils.tensorboard", name="SummaryWriter")


__all__ = ["make_animated_gif_summary", "add_animated_gif", "add_animated_gif_no_channels", "plot_2d_or_3d_image"]


def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale_factor: float = 1.0) -> Summary:
def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], writer, scale_factor: float = 1.0):
"""Function to actually create the animated gif.

Args:
Expand All @@ -54,14 +55,17 @@ def _image3_animated_gif(tag: str, image: Union[np.ndarray, torch.Tensor], scale
for b_data in PIL.GifImagePlugin.getdata(i):
img_str += b_data
img_str += b"\x3B"
summary_image_str = Summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
image_summary = Summary.Value(tag=tag, image=summary_image_str)
return Summary(value=[image_summary])

summary = SummaryX if has_tensorboardx and isinstance(writer, SummaryWriterX) else Summary
summary_image_str = summary.Image(height=10, width=10, colorspace=1, encoded_image_string=img_str)
image_summary = summary.Value(tag=tag, image=summary_image_str)
return summary(value=[image_summary])


def make_animated_gif_summary(
tag: str,
image: Union[np.ndarray, torch.Tensor],
writer=None,
max_out: int = 3,
animation_axes: Sequence[int] = (3,),
image_axes: Sequence[int] = (1, 2),
Expand All @@ -73,7 +77,8 @@ def make_animated_gif_summary(
Args:
tag: Data identifier
image: The image, expected to be in CHWD format
max_out: maximum number of slices to animate through
writer: the tensorboard writer to plot image
max_out: maximum number of image channels to animate through
animation_axes: axis to animate on (not currently used)
image_axes: axes of image (not currently used)
other_indices: (not currently used)
Expand All @@ -95,20 +100,21 @@ def make_animated_gif_summary(
slicing.append(slice(other_ind, other_ind + 1))
image = image[tuple(slicing)]

summary_op = []
for it_i in range(min(max_out, list(image.shape)[0])):
one_channel_img: Union[torch.Tensor, np.ndarray] = (
image[it_i, :, :, :].squeeze(dim=0) if isinstance(image, torch.Tensor) else image[it_i, :, :, :]
)
summary_op = _image3_animated_gif(tag + suffix.format(it_i), one_channel_img, scale_factor)
summary_op.append(_image3_animated_gif(tag + suffix.format(it_i), one_channel_img, writer, scale_factor))
return summary_op


def add_animated_gif(
writer: SummaryWriter,
tag: str,
image_tensor: Union[np.ndarray, torch.Tensor],
max_out: int,
scale_factor: float,
max_out: int = 3,
scale_factor: float = 1.0,
global_step: Optional[int] = None,
) -> None:
"""Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.
Expand All @@ -117,25 +123,31 @@ def add_animated_gif(
writer: Tensorboard SummaryWriter to write to
tag: Data identifier
image_tensor: tensor for the image to add, expected to be in CHWD format
max_out: maximum number of slices to animate through
max_out: maximum number of image channels to animate through
scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
scale it to displayable range
global_step: Global step value to record
"""
writer._get_file_writer().add_summary(
make_animated_gif_summary(
tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[2, 3], scale_factor=scale_factor
),
global_step,
summary = make_animated_gif_summary(
tag=tag,
image=image_tensor,
writer=writer,
max_out=max_out,
animation_axes=[1],
image_axes=[2, 3],
scale_factor=scale_factor,
)
for s in summary:
# add GIF for every channel separately
writer._get_file_writer().add_summary(s, global_step)


def add_animated_gif_no_channels(
writer: SummaryWriter,
tag: str,
image_tensor: Union[np.ndarray, torch.Tensor],
max_out: int,
scale_factor: float,
max_out: int = 3,
scale_factor: float = 1.0,
global_step: Optional[int] = None,
) -> None:
"""Creates an animated gif out of an image tensor in 'HWD' format that does not have
Expand All @@ -146,15 +158,21 @@ def add_animated_gif_no_channels(
writer: Tensorboard SummaryWriter to write to
tag: Data identifier
image_tensor: tensor for the image to add, expected to be in HWD format
max_out: maximum number of slices to animate through
max_out: maximum number of image channels to animate through
scale_factor: amount to multiply values by. If the image data is between 0 and 1,
using 255 for this value will scale it to displayable range
global_step: Global step value to record
"""
writer._get_file_writer().add_summary(
make_animated_gif_summary(
tag, image_tensor, max_out=max_out, animation_axes=[1], image_axes=[1, 2], scale_factor=scale_factor
),
tag=tag,
image=image_tensor,
writer=writer,
max_out=max_out,
animation_axes=[1],
image_axes=[1, 2],
scale_factor=scale_factor,
)[0],
global_step,
)

Expand All @@ -165,23 +183,24 @@ def plot_2d_or_3d_image(
writer: SummaryWriter,
index: int = 0,
max_channels: int = 1,
max_frames: int = 64,
max_frames: int = 24,
tag: str = "output",
) -> None:
"""Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

Note:
Plot 3D or 2D image(with more than 3 channels) as separate images.
And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

Args:
data: target data to be plotted as image on the TensorBoard.
The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions,
and only plot the first in the batch.
step: current step to plot in a chart.
writer: specify TensorBoard SummaryWriter to plot the image.
writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
index: plot which element in the input data batch, default is the first element.
max_channels: number of channels to plot.
max_frames: number of frames for 2D-t plot.
max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
tag: tag of the plotted image on TensorBoard.
"""
data_index = data[index]
Expand All @@ -206,7 +225,13 @@ def plot_2d_or_3d_image(

if d.ndim >= 4:
spatial = d.shape[-3:]
for j, d3 in enumerate(d.reshape([-1] + list(spatial))[:max_channels]):
d3 = rescale_array(d3, 0, 255)
add_animated_gif(writer, f"{tag}_HWD_{j}", d3[None], max_frames, 1.0, step)
d = d.reshape([-1] + list(spatial))
if d.shape[0] == 3 and max_channels == 3 and has_tensorboardx and isinstance(writer, SummaryWriterX): # RGB
writer.add_video(tag, d[None], step, fps=max_frames, dataformats="NCHWT")
return
# scale data to 0 - 255 for visualization
max_channels = min(max_channels, d.shape[0])
d = np.stack([rescale_array(i, 0, 255) for i in d[:max_channels]], axis=0)
# will plot every channel as a separate GIF image
add_animated_gif(writer, f"{tag}_HWD", d, max_out=max_channels, global_step=step)
return
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,4 @@ einops
transformers
mlflow
matplotlib
tensorboardX
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ all =
transformers
mlflow
matplotlib
tensorboardX
nibabel =
nibabel
skimage =
Expand Down Expand Up @@ -83,6 +84,8 @@ mlflow =
mlflow
matplotlib =
matplotlib
tensorboardX =
tensorboardX

[flake8]
select = B,C,E,F,N,P,T4,W,B9
Expand Down
14 changes: 8 additions & 6 deletions tests/test_img2tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ def test_write_gray(self):
image_axes=(1, 2),
scale_factor=253.0,
)
assert isinstance(
summary_object_np, tensorboard.compat.proto.summary_pb2.Summary
), "make_animated_gif_summary must return a tensorboard.summary object from numpy array"
for s in summary_object_np:
assert isinstance(
s, tensorboard.compat.proto.summary_pb2.Summary
), "make_animated_gif_summary must return a tensorboard.summary object from numpy array"

tensorarr = torch.tensor(nparr)
summary_object_tensor = make_animated_gif_summary(
Expand All @@ -42,9 +43,10 @@ def test_write_gray(self):
image_axes=(1, 2),
scale_factor=253.0,
)
assert isinstance(
summary_object_tensor, tensorboard.compat.proto.summary_pb2.Summary
), "make_animated_gif_summary must return a tensorboard.summary object from tensor input"
for s in summary_object_tensor:
assert isinstance(
s, tensorboard.compat.proto.summary_pb2.Summary
), "make_animated_gif_summary must return a tensorboard.summary object from tensor input"


if __name__ == "__main__":
Expand Down
28 changes: 26 additions & 2 deletions tests/test_plot_2d_or_3d_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from parameterized import parameterized
from torch.utils.tensorboard import SummaryWriter

from monai.utils import optional_import
from monai.visualize import plot_2d_or_3d_image
from tests.utils import SkipIfNoModule

SummaryWriterX, _ = optional_import("tensorboardX", name="SummaryWriter")

TEST_CASE_1 = [(1, 1, 10, 10)]

Expand All @@ -32,10 +36,30 @@

class TestPlot2dOr3dImage(unittest.TestCase):
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_tb_image_shape(self, shape):
def test_tb_image(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
writer = SummaryWriter(log_dir=tempdir)
plot_2d_or_3d_image(torch.zeros(shape), 0, writer)
plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=20)
writer.flush()
writer.close()
self.assertTrue(len(glob.glob(tempdir)) > 0)

@SkipIfNoModule("tensorboardX")
@parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4, TEST_CASE_5])
def test_tbx_image(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
writer = SummaryWriterX(log_dir=tempdir)
plot_2d_or_3d_image(torch.zeros(shape), 0, writer, max_channels=2)
writer.flush()
writer.close()
self.assertTrue(len(glob.glob(tempdir)) > 0)

@SkipIfNoModule("tensorboardX")
@parameterized.expand([TEST_CASE_5])
def test_tbx_video(self, shape):
with tempfile.TemporaryDirectory() as tempdir:
writer = SummaryWriterX(log_dir=tempdir)
plot_2d_or_3d_image(torch.rand(shape), 0, writer, max_channels=3)
writer.flush()
writer.close()
self.assertTrue(len(glob.glob(tempdir)) > 0)
Expand Down