Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ WSIReader
Image writer
------------

resolve_writer
~~~~~~~~~~~~~~
.. autofunction:: resolve_writer

register_writer
~~~~~~~~~~~~~~~
.. autofunction:: register_writer

ImageWriter
~~~~~~~~~~~
.. autoclass:: ImageWriter
Expand Down
11 changes: 10 additions & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,16 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .image_writer import ImageWriter, ITKWriter, NibabelWriter, PILWriter, logger
from .image_writer import (
SUPPORTED_WRITERS,
ImageWriter,
ITKWriter,
NibabelWriter,
PILWriter,
logger,
register_writer,
resolve_writer,
)
from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand Down
5 changes: 3 additions & 2 deletions monai/data/folder_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class FolderLayout:
layout = FolderLayout(
output_dir="/test_run_1/",
postfix="seg",
extension=".nii",
extension="nii",
makedirs=False)
layout.filename(subject="Sub-A", idx="00", modality="T1")
# return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii"
Expand Down Expand Up @@ -95,5 +95,6 @@ def filename(self, subject: PathLike = "subject", idx=None, **kwargs):
for k, v in kwargs.items():
full_name += f"_{k}-{v}"
if self.ext is not None:
full_name += f"{self.ext}"
ext = f"{self.ext}"
full_name += f".{ext}" if ext and not ext.startswith(".") else f"{ext}"
return full_name
4 changes: 1 addition & 3 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import DtypeLike, KeysCollection, PathLike
from monai.data.utils import correct_nifti_header_if_necessary
from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg

from .utils import is_supported_format

if TYPE_CHECKING:
import itk
import nibabel as nib
Expand Down
98 changes: 92 additions & 6 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union

import numpy as np

Expand All @@ -22,13 +22,15 @@
GridSampleMode,
GridSamplePadMode,
InterpolateMode,
OptionalImportError,
convert_data_type,
look_up_option,
optional_import,
require_pkg,
)

DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s"
EXT_WILDCARD = "*"
logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT)

if TYPE_CHECKING:
Expand All @@ -41,7 +43,76 @@
PILImage, _ = optional_import("PIL.Image")


__all__ = ["ImageWriter", "ITKWriter", "NibabelWriter", "PILWriter", "logger"]
__all__ = [
"ImageWriter",
"ITKWriter",
"NibabelWriter",
"PILWriter",
"SUPPORTED_WRITERS",
"register_writer",
"resolve_writer",
"logger",
]

SUPPORTED_WRITERS: Dict = {}


def register_writer(ext_name, *im_writers):
"""
Register ``ImageWriter``, so that writing a file with filename extension ``ext_name``
could be resolved to a tuple of potentially appropriate ``ImageWriter``.
The customised writers could be registered by:

.. code-block:: python

from monai.data import register_writer
# `MyWriter` must implement `ImageWriter` interface
register_writer("nii", MyWriter)

Args:
ext_name: the filename extension of the image.
Comment thread
wyli marked this conversation as resolved.
As an indexing key, it will be converted to a lower case string.
im_writers: one or multiple ImageWriter classes with high priority ones first.
"""
fmt = f"{ext_name}".lower()
if fmt.startswith("."):
fmt = fmt[1:]
existing = look_up_option(fmt, SUPPORTED_WRITERS, default=())
all_writers = im_writers + existing
SUPPORTED_WRITERS[fmt] = all_writers
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @wyli ,

Here you use a global dictionary to store the supported format -> writers mapping, while in the image readers we store readers in LoadImage transform with its own supported formats, for example:
https://github.com/Project-MONAI/MONAI/blob/dev/monai/data/image_reader.py#L386

I am not very sure which way is better, the readers way is similar to the Chain of Responsibility design pattern, This global SUPPORTED_WRITERS may be not easy to maintain, especially in multi-threads cases, etc.
Maybe because I was Java / C++ developer before, I usually never use global variables even in python. I prefer to use class to implement complicated logic, because it can maintain self-contained local states. For example you implemented FolderLayout as a class instead of a function.
Glad to see your opinions.
@ericspod @rijobro What do you guys think?

Thanks in advance.

Copy link
Copy Markdown
Contributor Author

@wyli wyli Feb 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SUPPORTED_WRITERS is a module-level variable, it's not a global one. it's a global dict accessed without global keyword.. I believe it's a
very common design, for example, scikit-image and PIL used it in a similar way

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I am not sure whether it can work fine if calling SaveImage in multi-thread? I think all our non-random transforms are thread-safe now.

Thanks.



def resolve_writer(ext_name, error_if_not_found=True) -> Sequence:
Comment thread
wyli marked this conversation as resolved.
"""
Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS``
according to the filename extension key ``ext_name``.

Args:
ext_name: the filename extension of the image.
As an indexing key it will be converted to a lower case string.
error_if_not_found: whether to raise an error if no suitable image writer is found.
if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``.
"""
if not SUPPORTED_WRITERS:
init()
fmt = f"{ext_name}".lower()
if fmt.startswith("."):
fmt = fmt[1:]
avail_writers = []
default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ())
for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers):
try:
_writer() # this triggers `monai.utils.module.require_pkg` to check the system availability
avail_writers.append(_writer)
except OptionalImportError:
continue
except Exception: # other writer init errors indicating it exists
avail_writers.append(_writer)
if not avail_writers and error_if_not_found:
raise OptionalImportError(f"No ImageWriter backend found for {fmt}.")
writer_tuple = ensure_tuple(avail_writers)
SUPPORTED_WRITERS[fmt] = writer_tuple
return writer_tuple


class ImageWriter:
Expand Down Expand Up @@ -297,7 +368,9 @@ def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs):
"""
super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs)

def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs):
def set_data_array(
self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs
):
"""
Convert ``data_array`` into 'channel-last' numpy ndarray.

Expand All @@ -309,14 +382,15 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end
kwargs: keyword arguments passed to ``self.convert_to_channel_last``,
currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively.
"""
_r = len(data_array.shape)
self.data_obj = self.convert_to_channel_last(
data=data_array,
channel_dim=channel_dim,
squeeze_end_dims=squeeze_end_dims,
spatial_ndim=kwargs.pop("spatial_ndim", 3),
contiguous=kwargs.pop("contiguous", True),
)
self.channel_dim = channel_dim
self.channel_dim = channel_dim if len(self.data_obj.shape) >= _r else None # channel dim is at the end

def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options):
"""
Expand All @@ -335,7 +409,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru
data_array=self.data_obj,
affine=affine,
target_affine=original_affine if resample else None,
output_spatial_shape=spatial_shape,
output_spatial_shape=spatial_shape if resample else None,
mode=options.pop("mode", GridSampleMode.BILINEAR),
padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER),
align_corners=options.pop("align_corners", False),
Expand Down Expand Up @@ -476,7 +550,7 @@ def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **op
data_array=self.data_obj,
affine=affine,
target_affine=original_affine if resample else None,
output_spatial_shape=spatial_shape,
output_spatial_shape=spatial_shape if resample else None,
mode=options.pop("mode", GridSampleMode.BILINEAR),
padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER),
align_corners=options.pop("align_corners", False),
Expand Down Expand Up @@ -716,3 +790,15 @@ def create_backend_obj(
data = np.moveaxis(data, 0, 1)

return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None))


def init():
"""
Initialize the image writer modules according to the filename extension.
"""
for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"):
register_writer(ext, PILWriter) # TODO: test 16-bit
for ext in ("nii.gz", "nii"):
register_writer(ext, NibabelWriter, ITKWriter)
register_writer("nrrd", ITKWriter, NibabelWriter)
register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter)
5 changes: 5 additions & 0 deletions monai/data/nifti_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
from monai.data.utils import create_file_basename
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import deprecated


@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.")
class NiftiSaver:
"""
Save the data as NIfTI file, it can support single data content or a batch of data.
Expand All @@ -32,6 +34,9 @@ class NiftiSaver:

Note: image should include channel dimension: [B],C,H,W,[D].

.. deprecated:: 0.8
Use :py:class:`monai.transforms.SaveImage` instead.

"""

def __init__(
Expand Down
7 changes: 6 additions & 1 deletion monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import
from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")


@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.")
def write_nifti(
data: NdarrayOrTensor,
file_name: str,
Expand Down Expand Up @@ -98,6 +99,10 @@ def write_nifti(
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.

.. deprecated:: 0.8
Use :py:meth:`monai.data.NibabelWriter` instead.

"""
data, *_ = convert_data_type(data, np.ndarray)
affine, *_ = convert_data_type(affine, np.ndarray)
Expand Down
6 changes: 5 additions & 1 deletion monai/data/png_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@
from monai.data.png_writer import write_png
from monai.data.utils import create_file_basename
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, look_up_option
from monai.utils import InterpolateMode, deprecated, look_up_option


@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.")
class PNGSaver:
"""
Save the data as png file, it can support single data content or a batch of data.
Expand All @@ -30,6 +31,9 @@ class PNGSaver:
where the input image name is extracted from the provided meta data dictionary.
If no meta data provided, use index from 0 as the filename prefix.

.. deprecated:: 0.8
Use :py:class:`monai.transforms.SaveImage` instead.

"""

def __init__(
Expand Down
6 changes: 5 additions & 1 deletion monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,12 @@
import numpy as np

from monai.transforms.spatial.array import Resize
from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import
from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import

Image, _ = optional_import("PIL", name="Image")


@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.")
def write_png(
data: np.ndarray,
file_name: str,
Expand Down Expand Up @@ -46,6 +47,9 @@ def write_png(
Raises:
ValueError: When ``scale`` is not one of [255, 65535].

.. deprecated:: 0.8
Use :py:meth:`monai.data.PILWriter` instead.

"""
if not isinstance(data, np.ndarray):
raise ValueError("input data must be numpy array.")
Expand Down
Loading