From 33f0348f2511c8ebe888fafe61081b39338525c0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 8 Feb 2022 14:26:04 +0000 Subject: [PATCH 01/11] update saveimage and writer selector Signed-off-by: Wenqi Li --- docs/source/data.rst | 8 + monai/data/__init__.py | 11 +- monai/data/image_reader.py | 4 +- monai/data/image_writer.py | 79 ++++++++- monai/data/nifti_saver.py | 5 + monai/data/nifti_writer.py | 7 +- monai/data/png_saver.py | 6 +- monai/data/png_writer.py | 6 +- monai/transforms/io/array.py | 199 ++++++++++++---------- monai/transforms/io/dictionary.py | 97 +++++------ tests/min_tests.py | 2 - tests/test_handler_segmentation_saver.py | 8 +- tests/test_integration_segmentation_3d.py | 17 +- tests/test_nifti_saver.py | 111 ------------ tests/test_png_saver.py | 76 --------- tests/test_save_image.py | 12 +- tests/test_save_imaged.py | 12 +- 17 files changed, 319 insertions(+), 341 deletions(-) delete mode 100644 tests/test_nifti_saver.py delete mode 100644 tests/test_png_saver.py diff --git a/docs/source/data.rst b/docs/source/data.rst index f2377e2972..2bdf401c7f 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -153,6 +153,14 @@ WSIReader Image writer ------------ +resolve_writer +~~~~~~~~~~~~~~ +.. autofunction:: resolve_writer + +register_writer +~~~~~~~~~~~~~~~ +.. autofunction:: register_writer + ImageWriter ~~~~~~~~~~~ .. autoclass:: ImageWriter diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 86630ae495..bed194d2f4 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -35,7 +35,16 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader -from .image_writer import ImageWriter, ITKWriter, NibabelWriter, PILWriter, logger +from .image_writer import ( + SUPPORTED_WRITERS, + ImageWriter, + ITKWriter, + NibabelWriter, + PILWriter, + logger, + register_writer, + resolve_writer, +) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 9f0e3f32cf..0be7feb1e5 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -18,12 +18,10 @@ from torch.utils.data._utils.collate import np_str_obj_array_pattern from monai.config import DtypeLike, KeysCollection, PathLike -from monai.data.utils import correct_nifti_header_if_necessary +from monai.data.utils import correct_nifti_header_if_necessary, is_supported_format from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import ensure_tuple, ensure_tuple_rep, optional_import, require_pkg -from .utils import is_supported_format - if TYPE_CHECKING: import itk import nibabel as nib diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 375063e397..62ffc6c072 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TYPE_CHECKING, Mapping, Optional, Sequence, Union +from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union import numpy as np @@ -22,6 +22,7 @@ GridSampleMode, GridSamplePadMode, InterpolateMode, + OptionalImportError, convert_data_type, look_up_option, optional_import, @@ -41,7 +42,69 @@ PILImage, _ = optional_import("PIL.Image") -__all__ = ["ImageWriter", "ITKWriter", "NibabelWriter", "PILWriter", "logger"] +__all__ = [ + "ImageWriter", + "ITKWriter", + "NibabelWriter", + "PILWriter", + "SUPPORTED_WRITERS", + "register_writer", + "resolve_writer", + "logger", +] + +SUPPORTED_WRITERS: Dict = {} + + +def register_writer(ext_name, *im_writer): + """ + Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` + could be resolved to a tuple of potentially appropriate ``ImageWriter``. + The customised writers could be registered by: + + .. code-block:: python + + from monai.data import image_writer + # `MyWriter` must implement `ImageWriter` interface + image_writer.register_writer(".nii", MyWriter) + + Args: + ext_name: the filename extension of the image. + As an indexing key, it will be converted to a lower case string. + im_writer: one or multiple ImageWriter classes with high priority ones first. + """ + fmt = f"{ext_name}".lower() + existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) + all_writers = im_writer + existing + SUPPORTED_WRITERS[fmt] = all_writers + + +def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: + """ + Resolves to a tuple of available ``ImageWriter`` in ``SUPPORTED_WRITERS`` + according to the filename extension key ``ext_name``. + + Args: + ext_name: the filename extension of the image. + As an indexing key it will be converted to a lower case string. + error_if_not_found: whether to raise an error if no suitable image writer is found. + if True , raise an ``OptionalImportError``, otherwise return an empty tuple. Default is ``True``. + """ + if not SUPPORTED_WRITERS: + init() + fmt = f"{ext_name}".lower() + avail_writers = [] + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=SUPPORTED_WRITERS["*"]): + try: + _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability + avail_writers.append(_writer) + except OptionalImportError: + pass + if not avail_writers and error_if_not_found: + raise OptionalImportError(f"No ImageWriter backend found for {fmt}.") + writer_tuple = ensure_tuple(avail_writers) + SUPPORTED_WRITERS[fmt] = writer_tuple + return writer_tuple class ImageWriter: @@ -716,3 +779,15 @@ def create_backend_obj( data = np.moveaxis(data, 0, 1) return PILImage.fromarray(data, mode=kwargs.pop("image_mode", None)) + + +def init(): + """ + Initialize the image writer modules according to the filename extension. + """ + for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"): + register_writer(ext, PILWriter) # TODO: test 16-bit + for ext in (".nii.gz", ".nii"): + register_writer(ext, NibabelWriter, ITKWriter) + register_writer(".nrrd", ITKWriter, NibabelWriter) + register_writer("*", ITKWriter, NibabelWriter) diff --git a/monai/data/nifti_saver.py b/monai/data/nifti_saver.py index a5acdd032e..3fdc0aa3e8 100644 --- a/monai/data/nifti_saver.py +++ b/monai/data/nifti_saver.py @@ -19,8 +19,10 @@ from monai.data.utils import create_file_basename from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import deprecated +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class NiftiSaver: """ Save the data as NIfTI file, it can support single data content or a batch of data. @@ -32,6 +34,9 @@ class NiftiSaver: Note: image should include channel dimension: [B],C,H,W,[D]. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( diff --git a/monai/data/nifti_writer.py b/monai/data/nifti_writer.py index b658121e49..8a6172955f 100644 --- a/monai/data/nifti_writer.py +++ b/monai/data/nifti_writer.py @@ -19,12 +19,13 @@ from monai.data.utils import compute_shape_offset, to_affine_nd from monai.networks.layers import AffineTransform from monai.transforms.utils_pytorch_numpy_unification import allclose -from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils import GridSampleMode, GridSamplePadMode, deprecated, optional_import from monai.utils.type_conversion import convert_data_type nib, _ = optional_import("nibabel") +@deprecated(since="0.8", msg_suffix="use monai.data.NibabelWriter instead.") def write_nifti( data: NdarrayOrTensor, file_name: str, @@ -98,6 +99,10 @@ def write_nifti( dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision. If None, use the data type of input data. output_dtype: data type for saving data. Defaults to ``np.float32``. + + .. deprecated:: 0.8 + Use :py:meth:`monai.data.NibabelWriter` instead. + """ data, *_ = convert_data_type(data, np.ndarray) affine, *_ = convert_data_type(affine, np.ndarray) diff --git a/monai/data/png_saver.py b/monai/data/png_saver.py index a83a560e9f..9a1ade0efa 100644 --- a/monai/data/png_saver.py +++ b/monai/data/png_saver.py @@ -18,9 +18,10 @@ from monai.data.png_writer import write_png from monai.data.utils import create_file_basename from monai.utils import ImageMetaKey as Key -from monai.utils import InterpolateMode, look_up_option +from monai.utils import InterpolateMode, deprecated, look_up_option +@deprecated(since="0.8", msg_suffix="use monai.transforms.SaveImage instead.") class PNGSaver: """ Save the data as png file, it can support single data content or a batch of data. @@ -30,6 +31,9 @@ class PNGSaver: where the input image name is extracted from the provided meta data dictionary. If no meta data provided, use index from 0 as the filename prefix. + .. deprecated:: 0.8 + Use :py:class:`monai.transforms.SaveImage` instead. + """ def __init__( diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index 7fcdb7fdb0..5d05536923 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -14,11 +14,12 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import Image, _ = optional_import("PIL", name="Image") +@deprecated(since="0.8", msg_suffix="use monai.data.PILWriter instead.") def write_png( data: np.ndarray, file_name: str, @@ -46,6 +47,9 @@ def write_png( Raises: ValueError: When ``scale`` is not one of [255, 65535]. + .. deprecated:: 0.8 + Use :py:meth:`monai.data.PILWriter` instead. + """ if not isinstance(data, np.ndarray): raise ValueError("input data must be numpy array.") diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 0b8b7ba156..5b9fdbc3c3 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -24,9 +24,9 @@ import torch from monai.config import DtypeLike, PathLike +from monai.data import image_writer +from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader -from monai.data.nifti_saver import NiftiSaver -from monai.data.png_saver import PNGSaver from monai.transforms.transform import Transform from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key @@ -82,7 +82,7 @@ class LoadImage(Transform): - User-specified reader in the constructor of `LoadImage`. - Readers from the last to the first in the registered list. - Current default readers: (nii, nii.gz -> NibabelReader), (png, jpg, bmp -> PILReader), - (npz, npy -> NumpyReader), (dcm, DICOM series and others -> ITKReader). + (npz, npy -> NumpyReader), (DICOM file -> ITKReader). See also: @@ -112,7 +112,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. or a tuple of two elements containing the data array, and the meta data in a dictionary format otherwise. - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. """ @@ -227,69 +227,59 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option class SaveImage(Transform): """ - Save transformed data into files, support NIfTI and PNG formats. - It can work for both numpy array and PyTorch Tensor in both preprocessing transform - chain and postprocessing transform chain. - The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, - where the input image name is extracted from the provided meta data dictionary. - If no meta data provided, use index from 0 as the filename prefix. - It can also save a list of PyTorch Tensor or numpy array without `batch dim`. + Save the image (in the form of torch tensor or numpy ndarray) and metadata dictionary into files. - Note: image should be channel-first shape: [C,H,W,[D]]. + The name of saved file will be `{input_image_name}_{output_postfix}{output_ext}`, + where the `input_image_name` is extracted from the provided metadata dictionary. + If no metadata provided, a running index starting from 0 will be used as the filename prefix. Args: output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. - output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_ext: output file extension name. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. - output_dtype: data type for saving data. Defaults to ``np.float32``. - it's used for NIfTI format only. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. - + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. """ def __init__( @@ -297,55 +287,90 @@ def __init__( output_dir: PathLike = "./", output_postfix: str = "trans", output_ext: str = ".nii.gz", + output_dtype: DtypeLike = np.float32, resample: bool = True, mode: Union[GridSampleMode, InterpolateMode, str] = "nearest", padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, scale: Optional[int] = None, dtype: DtypeLike = np.float64, - output_dtype: DtypeLike = np.float32, squeeze_end_dims: bool = True, data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, + output_format: Optional[str] = None, + writer: Optional[image_writer.ImageWriter] = None, ) -> None: - self.saver: Union[NiftiSaver, PNGSaver] - if output_ext in {".nii.gz", ".nii"}: - self.saver = NiftiSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=GridSampleMode(mode), - padding_mode=padding_mode, - dtype=dtype, - output_dtype=output_dtype, - squeeze_end_dims=squeeze_end_dims, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - elif output_ext == ".png": - self.saver = PNGSaver( - output_dir=output_dir, - output_postfix=output_postfix, - output_ext=output_ext, - resample=resample, - mode=InterpolateMode(mode), - scale=scale, - data_root_dir=data_root_dir, - separate_folder=separate_folder, - print_log=print_log, - ) - else: - raise ValueError(f"unsupported output extension: {output_ext}.") + self.folder_layout = FolderLayout( + output_dir=output_dir, + postfix=output_postfix, + extension=output_ext, + parent=separate_folder, + makedirs=True, + data_root_dir=data_root_dir, + ) + + self.output_ext = output_ext.lower() + self.writers = image_writer.resolve_writer(output_format or self.output_ext) if writer is None else (writer,) + + _output_dtype = output_dtype + if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16): + _output_dtype = np.uint8 + self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale} + self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": 0} + self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype} + self.write_kwargs = {"verbose": print_log} + self._data_index = 0 + + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + """ + Set the options for the underlying writer by updating `kwargs` dictionaries. + + The arguments correspond to the following usage: + + - `writer = ImageWriter(**init_kwargs)` + - `writer.set_data_array(array, **data_kwargs)` + - `writer.set_metadata(meta_data, **meta_kwargs)` + - `writer.write(filename, **write_kwargs)` + + """ + if init_kwargs is not None: + self.init_kwargs.update(init_kwargs) + if data_kwargs is not None: + self.data_kwargs.update(data_kwargs) + if meta_kwargs is not None: + self.meta_kwargs.update(meta_kwargs) + if write_kwargs is not None: + self.write_kwargs.update(write_kwargs) def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None): """ Args: - img: target data content that save into file. + img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. meta_data: key-value pairs of meta_data corresponding to the data. - """ - self.saver.save(img, meta_data) + subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index) - return img + for writer_cls in self.writers: + try: + writer_obj = writer_cls(**self.init_kwargs) + writer_obj.set_data_array(data_array=img, **self.data_kwargs) + writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) + writer_obj.write(filename, **self.write_kwargs) + except Exception as e: + logging.getLogger(self.__class__.__name__).exception(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{writer_cls.__class__.__name__}: unable to write {filename}." + ) + else: + self._data_index += 1 + return img + raise RuntimeError( + f"cannot find a suitable writer for {filename}.\n" + " Please install the reader libraries, see also the installation instructions:\n" + " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" + f" The current registered: {self.writers}.\n" + ) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 071db4b5b2..67c7bd0588 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -21,6 +21,7 @@ import numpy as np from monai.config import DtypeLike, KeysCollection +from monai.data import image_writer from monai.data.image_reader import ImageReader from monai.transforms.io.array import LoadImage, SaveImage from monai.transforms.transform import MapTransform @@ -150,68 +151,61 @@ class SaveImaged(MapTransform): Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` - meta_keys: explicitly indicate the key of the corresponding meta data dictionary. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `key_{meta_key_postfix}`. - meta_key_postfix: if meta_keys is None and `key_{postfix}` was used to store the metadata in `LoadImaged`. - need the key to extract metadata to save images, default is `meta_dict`. - for example, for data with key `image`, the metadata by default is in `image_meta_dict`. - the meta data is a dictionary object which contains: filename, affine, original_shape, etc. - if no corresponding metadata, set to `None`. + meta_keys: explicitly indicate the key of the corresponding metadata dictionary. + For example, for data with key `image`, the metadata by default is in `image_meta_dict`. + The metadata is a dictionary contains values such as filename, original_shape. + This argument can be a sequence of string, map to the `keys`. + If `None`, will try to construct meta_keys by `key_{meta_key_postfix}`. + meta_key_postfix: if `meta_keys` is `None`, use `key_{meta_key_postfix}` to retrieve the metadict. output_dir: output image directory. output_postfix: a string appended to all output file names, default to `trans`. output_ext: output file extension name, available extensions: `.nii.gz`, `.nii`, `.png`. - resample: whether to resample before saving the data array. - if saving PNG format image, based on the `spatial_shape` from metadata. - if saving NIfTI format image, based on the `original_affine` from metadata. - mode: This option is used when ``resample = True``. Defaults to ``"nearest"``. - - - NIfTI files {``"bilinear"``, ``"nearest"``} - Interpolation mode to calculate output values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``} - The interpolation mode. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html - - padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + output_dtype: data type for saving data. Defaults to ``np.float32``. + resample: whether to resample image (if needed) before saving the data array, + based on the `spatial_shape` (and `original_affine`) from metadata. + mode: This option is used when ``resample=True``. Defaults to ``"nearest"``. + Depending on the writers, the possible options are: - - NIfTI files {``"zeros"``, ``"border"``, ``"reflection"``} - Padding mode for outside grid values. - See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html - - PNG files - This option is ignored. + - {``"bilinear"``, ``"nearest"``, ``"bicubic"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample + - {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}. + See also: https://pytorch.org/docs/stable/nn.functional.html#interpolate + padding_mode: This option is used when ``resample = True``. Defaults to ``"border"``. + Possible options are {``"zeros"``, ``"border"``, ``"reflection"``} + See also: https://pytorch.org/docs/stable/nn.functional.html#grid-sample scale: {``255``, ``65535``} postprocess data by clipping to [0, 1] and scaling - [0, 255] (uint8) or [0, 65535] (uint16). Default is None to disable scaling. - it's used for PNG format only. + [0, 255] (uint8) or [0, 65535] (uint16). Default is `None` (no scaling). dtype: data type during resampling computation. Defaults to ``np.float64`` for best precision. if None, use the data type of input data. To be compatible with other modules, - the output data type is always ``np.float32``. - it's used for NIfTI format only. output_dtype: data type for saving data. Defaults to ``np.float32``. it's used for NIfTI format only. allow_missing_keys: don't raise exception if key is missing. squeeze_end_dims: if True, any trailing singleton dimensions will be removed (after the channel has been moved to the end). So if input is (C,H,W,D), this will be altered to (H,W,D,C), and - then if C==1, it will be saved as (H,W,D). If D also ==1, it will be saved as (H,W). If false, + then if C==1, it will be saved as (H,W,D). If D is also 1, it will be saved as (H,W). If `false`, image will always be saved as (H,W,D,C). - it's used for NIfTI format only. data_root_dir: if not empty, it specifies the beginning parts of the input file's - absolute path. it's used to compute `input_file_rel_path`, the relative path to the file from + absolute path. It's used to compute `input_file_rel_path`, the relative path to the file from `data_root_dir` to preserve folder structure when saving in case there are files in different - folders with the same file names. for example: - input_file_name: /foo/bar/test1/image.nii, - output_postfix: seg - output_ext: nii.gz - output_dir: /output, - data_root_dir: /foo/bar, - output will be: /output/test1/image/image_seg.nii.gz - separate_folder: whether to save every file in a separate folder, for example: if input filename is - `image.nii`, postfix is `seg` and folder_path is `output`, if `True`, save as: - `output/image/image_seg.nii`, if `False`, save as `output/image_seg.nii`. default to `True`. - print_log: whether to print log about the saved file path, etc. default to `True`. + folders with the same file names. For example, with the following inputs: + + - input_file_name: `/foo/bar/test1/image.nii` + - output_postfix: `seg` + - output_ext: `.nii.gz` + - output_dir: `/output` + - data_root_dir: `/foo/bar` + + The output will be: /output/test1/image/image_seg.nii.gz + + separate_folder: whether to save every file in a separate folder. For example: for the input filename + `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: + `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. + print_log: whether to print logs when saving. Default to `True`. + output_format: an optional string to specify the output image writer. + see also: `monai.data.image_writer.SUPPORTED_WRITERS`. + writer: a customised image writer to save data arrays. + if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. """ @@ -234,11 +228,13 @@ def __init__( data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, + output_format: Optional[str] = None, + writer: Optional[image_writer.ImageWriter] = None, ) -> None: super().__init__(keys, allow_missing_keys) self.meta_keys = ensure_tuple_rep(meta_keys, len(self.keys)) self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys)) - self._saver = SaveImage( + self.saver = SaveImage( output_dir=output_dir, output_postfix=output_postfix, output_ext=output_ext, @@ -252,15 +248,20 @@ def __init__( data_root_dir=data_root_dir, separate_folder=separate_folder, print_log=print_log, + output_format=output_format, + writer=writer, ) + def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): + self.saver.set_options(init_kwargs, data_kwargs, meta_kwargs, write_kwargs) + def __call__(self, data): d = dict(data) for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None - self._saver(img=d[key], meta_data=meta_data) + self.saver(img=d[key], meta_data=meta_data) return d diff --git a/tests/min_tests.py b/tests/min_tests.py index 090167c4b1..cfb7038703 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -110,7 +110,6 @@ def run_testsuit(): "test_mlp", "test_nifti_header_revise", "test_nifti_rw", - "test_nifti_saver", "test_occlusion_sensitivity", "test_orientation", "test_orientationd", @@ -120,7 +119,6 @@ def run_testsuit(): "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", - "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", "test_rand_rotate", diff --git a/tests/test_handler_segmentation_saver.py b/tests/test_handler_segmentation_saver.py index 3632a98cfc..ee6566f6cb 100644 --- a/tests/test_handler_segmentation_saver.py +++ b/tests/test_handler_segmentation_saver.py @@ -39,7 +39,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ @@ -65,7 +67,9 @@ def _train_func(engine, batch): engine = Engine(_train_func) # set up testing handler - saver = SegmentationSaver(output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255) + saver = SegmentationSaver( + output_dir=tempdir, output_postfix="seg", output_ext=output_ext, scale=255, output_dtype=np.uint8 + ) saver.attach(engine) data = [ diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 718c9291fb..5c273d0a46 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -21,7 +21,7 @@ from torch.utils.tensorboard import SummaryWriter import monai -from monai.data import NiftiSaver, create_test_image_3d, decollate_batch +from monai.data import create_test_image_3d, decollate_batch from monai.inferers import sliding_window_inference from monai.metrics import DiceMetric from monai.networks import eval_mode @@ -34,6 +34,7 @@ LoadImaged, RandCropByPosNegLabeld, RandRotate90d, + SaveImage, ScaleIntensityd, Spacingd, ToTensor, @@ -213,17 +214,25 @@ def run_inference_test(root_dir, device="cuda:0"): with eval_mode(model): # resampling with align_corners=True or dtype=float64 will generate # slight different results between PyTorch 1.5 an 1.6 - saver = NiftiSaver(output_dir=os.path.join(root_dir, "output"), dtype=np.float32) + saver = SaveImage( + output_dir=os.path.join(root_dir, "output"), + dtype=np.float32, + output_ext=".nii.gz", + output_postfix="seg", + mode="bilinear", + ) for val_data in val_loader: val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device) # define sliding window size and batch size for windows inference sw_batch_size, roi_size = 4, (96, 96, 96) val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model) - # decollate prediction into a list and execute post processing for every item + # decollate prediction into a list val_outputs = [val_post_tran(i) for i in decollate_batch(val_outputs)] + val_meta = decollate_batch(val_data[PostFix.meta("img")]) # compute metrics dice_metric(y_pred=val_outputs, y=val_labels) - saver.save_batch(val_outputs, val_data[PostFix.meta("img")]) + for img, meta in zip(val_outputs, val_meta): # save a decollated batch of files + saver(img, meta) return dice_metric.aggregate().item() diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py deleted file mode 100644 index 6855a59041..0000000000 --- a/tests/test_nifti_saver.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import numpy as np -import torch - -from monai.data import NiftiSaver -from monai.transforms import LoadImage - - -class TestNiftiSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} - saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_3d_no_resize_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = NiftiSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False - ) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], - "spatial_shape": [(10, 10, 2)] * 8, - "affine": [np.diag(np.ones(4)) * 5] * 8, - "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, - } - saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") - img, _ = LoadImage("nibabelreader")(filepath) - self.assertEqual(img.shape, (1, 2, 2, 8)) - - def test_squeeze_end_dims(self): - with tempfile.TemporaryDirectory() as tempdir: - - for squeeze_end_dims in [False, True]: - - saver = NiftiSaver( - output_dir=tempdir, - output_postfix="", - output_ext=".nii.gz", - dtype=np.float32, - squeeze_end_dims=squeeze_end_dims, - ) - - fname = "testfile_squeeze" - meta_data = {"filename_or_obj": fname} - - # 2d image w channel - saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) - - im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) - self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) - self.assertTrue(meta["dim"][0] == im.ndim) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py deleted file mode 100644 index d832718643..0000000000 --- a/tests/test_png_saver.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright (c) MONAI Consortium -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os -import tempfile -import unittest -from pathlib import Path - -import torch - -from monai.data import PNGSaver - - -class TestPNGSaver(unittest.TestCase): - def test_saved_content(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_three_channel(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) - - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} - saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_content_spatial_size(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) - - meta_data = { - "filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)], - "spatial_shape": [(4, 4) for i in range(8)], - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - def test_saved_specified_root(self): - with tempfile.TemporaryDirectory() as tempdir: - - saver = PNGSaver( - output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" - ) - - meta_data = { - "filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)] - } - saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) - for i in range(8): - filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_save_image.py b/tests/test_save_image.py index d3671cf830..7c703b5220 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -13,6 +13,7 @@ import tempfile import unittest +import numpy as np import torch from parameterized import parameterized @@ -22,9 +23,18 @@ TEST_CASE_2 = [torch.randint(0, 255, (1, 2, 3, 4)), None, ".nii.gz", False] +TEST_CASE_3 = [torch.randint(0, 255, (1, 2, 3, 4)), {"filename_or_obj": "testfile0.nrrd"}, ".nrrd", False] + +TEST_CASE_4 = [ + np.random.randint(0, 255, (3, 2, 4, 5), dtype=np.uint8), + {"filename_or_obj": "testfile0.dcm"}, + ".dcm", + False, +] + class TestSaveImage(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4]) def test_saved_content(self, test_data, meta_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImage( diff --git a/tests/test_save_imaged.py b/tests/test_save_imaged.py index 6f0bb4c2ba..a6988683e5 100644 --- a/tests/test_save_imaged.py +++ b/tests/test_save_imaged.py @@ -35,9 +35,19 @@ False, ] +TEST_CASE_3 = [ + { + "img": torch.randint(0, 255, (1, 2, 3, 4)), + PostFix.meta("img"): {"filename_or_obj": "testfile0.nrrd"}, + "patch_index": 6, + }, + ".nrrd", + False, +] + class TestSaveImaged(unittest.TestCase): - @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3]) def test_saved_content(self, test_data, output_ext, resample): with tempfile.TemporaryDirectory() as tempdir: trans = SaveImaged( From 0df2424d00799970ea5d36efa3f0b97b369fd678 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 8 Feb 2022 20:37:06 +0000 Subject: [PATCH 02/11] more tests Signed-off-by: Wenqi Li --- monai/transforms/io/array.py | 12 ++++++------ monai/transforms/io/dictionary.py | 6 +++--- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 5b9fdbc3c3..04a0a41310 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -276,7 +276,7 @@ class SaveImage(Transform): `image.nii`, postfix `seg` and folder_path `output`, if `separate_folder=True`, it will be saved as: `output/image/image_seg.nii`, if `False`, saving as `output/image_seg.nii`. Default to `True`. print_log: whether to print logs when saving. Default to `True`. - output_format: an optional string to specify the output image writer. + output_format: an optional string of filename extension to specify the output image writer. see also: `monai.data.image_writer.SUPPORTED_WRITERS`. writer: a customised image writer to save data arrays. if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. @@ -309,8 +309,8 @@ def __init__( data_root_dir=data_root_dir, ) - self.output_ext = output_ext.lower() - self.writers = image_writer.resolve_writer(output_format or self.output_ext) if writer is None else (writer,) + self.output_ext = output_ext.lower() or output_format.lower() + self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,) _output_dtype = output_dtype if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): @@ -348,7 +348,7 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic """ Args: img: target data content that save into file. The image should be channel-first, shape: `[C,H,W,[D]]`. - meta_data: key-value pairs of meta_data corresponding to the data. + meta_data: key-value pairs of metadata corresponding to the data. """ subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None @@ -370,7 +370,7 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic return img raise RuntimeError( f"cannot find a suitable writer for {filename}.\n" - " Please install the reader libraries, see also the installation instructions:\n" + " Please install the writer libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" - f" The current registered: {self.writers}.\n" + f" The current registered writers for {self.output_ext}: {self.writers}.\n" ) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 67c7bd0588..3ae6db20dc 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -55,7 +55,7 @@ class LoadImaged(MapTransform): - If `reader` is specified, the loader will attempt to use the specified readers and the default supported readers. This might introduce overheads when handling the exceptions of trying the incompatible loaders. - In this case, it is therefore recommended to set the most appropriate reader as + In this case, it is therefore recommended setting the most appropriate reader as the last item of the `reader` parameter. See also: @@ -85,7 +85,7 @@ def __init__( at runtime or use the default readers. If a string of reader name provided, will construct a reader object with the `*args` and `**kwargs` parameters, supported reader name: "NibabelReader", "PILReader", "ITKReader", "NumpyReader". - dtype: if not None convert the loaded image data to this data type. + dtype: if not None, convert the loaded image data to this data type. meta_keys: explicitly indicate the key to store the corresponding meta data dictionary. the meta data is a dictionary object which contains: filename, original_shape, etc. it can be a sequence of string, map to the `keys`. @@ -93,7 +93,7 @@ def __init__( meta_key_postfix: if meta_keys is None, use `key_{postfix}` to store the metadata of the nifti image, default is `meta_dict`. The meta data is a dictionary object. For example, load nifti file for `image`, store the metadata into `image_meta_dict`. - overwriting: whether allow to overwrite existing meta data of same key. + overwriting: whether allow overwriting existing meta data of same key. default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. From 677cb10a3bb4d8b832aac739eba00c65a29a0f4d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 8 Feb 2022 20:38:33 +0000 Subject: [PATCH 03/11] more tests Signed-off-by: Wenqi Li --- tests/test_save_image.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_save_image.py b/tests/test_save_image.py index 7c703b5220..a1297c1e61 100644 --- a/tests/test_save_image.py +++ b/tests/test_save_image.py @@ -41,8 +41,7 @@ def test_saved_content(self, test_data, meta_data, output_ext, resample): output_dir=tempdir, output_ext=output_ext, resample=resample, - # test saving into the same folder - separate_folder=False, + separate_folder=False, # test saving into the same folder ) trans(test_data, meta_data) From 05b94ad56fbd6265636bb894082eb692c288ae5c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 9 Feb 2022 09:17:35 +0000 Subject: [PATCH 04/11] adds saving loading tests Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 7 ++- monai/transforms/io/array.py | 4 +- monai/transforms/io/dictionary.py | 2 +- tests/min_tests.py | 1 + tests/test_image_rw.py | 82 +++++++++++++++++++++++++++++++ 5 files changed, 92 insertions(+), 4 deletions(-) create mode 100644 tests/test_image_rw.py diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 62ffc6c072..6394e1c8a0 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -360,7 +360,9 @@ def __init__(self, output_dtype: DtypeLike = np.float32, **kwargs): """ super().__init__(output_dtype=output_dtype, affine=None, channel_dim=0, **kwargs) - def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs): + def set_data_array( + self, data_array: NdarrayOrTensor, channel_dim: Optional[int] = 0, squeeze_end_dims: bool = True, **kwargs + ): """ Convert ``data_array`` into 'channel-last' numpy ndarray. @@ -372,6 +374,7 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end kwargs: keyword arguments passed to ``self.convert_to_channel_last``, currently support ``spatial_ndim`` and ``contiguous``, defauting to ``3`` and ``False`` respectively. """ + _r = len(data_array.shape) self.data_obj = self.convert_to_channel_last( data=data_array, channel_dim=channel_dim, @@ -379,7 +382,7 @@ def set_data_array(self, data_array, channel_dim: Optional[int] = 0, squeeze_end spatial_ndim=kwargs.pop("spatial_ndim", 3), contiguous=kwargs.pop("contiguous", True), ) - self.channel_dim = channel_dim + self.channel_dim = channel_dim if len(self.data_obj.shape) >= _r else None # channel dim is at the end def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = True, **options): """ diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 04a0a41310..4f4fbaf464 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -297,7 +297,7 @@ def __init__( data_root_dir: PathLike = "", separate_folder: bool = True, print_log: bool = True, - output_format: Optional[str] = None, + output_format: str = "", writer: Optional[image_writer.ImageWriter] = None, ) -> None: self.folder_layout = FolderLayout( @@ -311,6 +311,7 @@ def __init__( self.output_ext = output_ext.lower() or output_format.lower() self.writers = image_writer.resolve_writer(self.output_ext) if writer is None else (writer,) + self.writer_obj = None _output_dtype = output_dtype if self.output_ext == ".png" and _output_dtype not in (np.uint8, np.uint16): @@ -360,6 +361,7 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic writer_obj.set_data_array(data_array=img, **self.data_kwargs) writer_obj.set_metadata(meta_dict=meta_data, **self.meta_kwargs) writer_obj.write(filename, **self.write_kwargs) + self.writer_obj = writer_obj except Exception as e: logging.getLogger(self.__class__.__name__).exception(e, exc_info=True) logging.getLogger(self.__class__.__name__).info( diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index 3ae6db20dc..d9a3f44e4b 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -228,7 +228,7 @@ def __init__( data_root_dir: str = "", separate_folder: bool = True, print_log: bool = True, - output_format: Optional[str] = None, + output_format: str = "", writer: Optional[image_writer.ImageWriter] = None, ) -> None: super().__init__(keys, allow_missing_keys) diff --git a/tests/min_tests.py b/tests/min_tests.py index cfb7038703..192d390ded 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -87,6 +87,7 @@ def run_testsuit(): "test_header_correct", "test_hilbert_transform", "test_image_dataset", + "test_image_rw", "test_img2tensorboard", "test_integration_fast_train", "test_integration_segmentation_3d", diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py new file mode 100644 index 0000000000..bd779b8ee7 --- /dev/null +++ b/tests/test_image_rw.py @@ -0,0 +1,82 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os +import shutil +import tempfile +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.data.image_reader import ITKReader, NibabelReader +from monai.data.image_writer import ITKWriter, NibabelWriter +from monai.transforms import LoadImage, SaveImage, moveaxis +from tests.utils import TEST_NDARRAYS, assert_allclose + + +class TestLoadSaveNifti(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def nifti_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) - 1 + for p in TEST_NDARRAYS: + output_ext = ".nii.gz" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver( + p(test_data), + { + "filename_or_obj": f"{filepath}.png", + "affine": np.eye(4), + "original_affine": np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]), + }, + ) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader, squeeze_non_spatial_dims=True) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + if resample: + _test_data = moveaxis(_test_data, 0, 1) + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.nifti_rw(test_data, reader, writer, np.uint8) + self.nifti_rw(test_data, reader, writer, np.float32) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_3d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 2, 3, 8) + self.nifti_rw(test_data, reader, writer, int) + self.nifti_rw(test_data, reader, writer, int, False) + + @parameterized.expand(itertools.product([NibabelReader, ITKReader], [NibabelWriter, ITKWriter])) + def test_4d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(2, 1, 3, 8) + self.nifti_rw(test_data, reader, writer, np.float16) + + +if __name__ == "__main__": + unittest.main() From aa7af6312295d6a9e70a8f5f3282f59191413a15 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 9 Feb 2022 23:26:31 +0000 Subject: [PATCH 05/11] fixes #3783 Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 38cfeb00c9..096ce9d79a 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -173,14 +173,14 @@ def __call__( if src_affine is None: src_affine = np.eye(4, dtype=np.float64) spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) - if spatial_size is not -1 and spatial_size is not None: + if (isinstance(spatial_size, int) and spatial_size != -1) and spatial_size is not None: spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine) dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine dst_affine, *_ = convert_to_dst_type(dst_affine, dst_affine, dtype=torch.float32) in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) - if spatial_size is -1: # using the input spatial size + if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size spatial_size = in_spatial_size elif spatial_size is None: # auto spatial size spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore From 743901980e773c06a3ebfd5c5ea3e6f7291b778c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 9 Feb 2022 23:32:44 +0000 Subject: [PATCH 06/11] enhance import checks Signed-off-by: Wenqi Li --- runtests.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/runtests.sh b/runtests.sh index 6bd1b3d51f..23ab4f4bf1 100755 --- a/runtests.sh +++ b/runtests.sh @@ -111,7 +111,7 @@ function print_usage { function check_import { echo "Python: ${PY_EXE}" - ${cmdPrefix}${PY_EXE} -c "import monai" + ${cmdPrefix}${PY_EXE} -W error -c "import monai" } function print_version { From f86ce2b681dc4a45302ed40a5110a71ada628fbf Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Feb 2022 10:13:47 +0000 Subject: [PATCH 07/11] warn to exception; int check Signed-off-by: Wenqi Li --- monai/transforms/spatial/array.py | 2 +- monai/utils/module.py | 7 +++++-- runtests.sh | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 096ce9d79a..34cdc37148 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -173,7 +173,7 @@ def __call__( if src_affine is None: src_affine = np.eye(4, dtype=np.float64) spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3) - if (isinstance(spatial_size, int) and spatial_size != -1) and spatial_size is not None: + if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None: spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size src_affine = to_affine_nd(spatial_rank, src_affine) dst_affine = to_affine_nd(spatial_rank, dst_affine) if dst_affine is not None else src_affine diff --git a/monai/utils/module.py b/monai/utils/module.py index 8813301d8e..613eaf64dc 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -421,9 +421,12 @@ def version_leq(lhs: str, rhs: str): """ lhs, rhs = str(lhs), str(rhs) - ver, has_ver = optional_import("pkg_resources", name="parse_version") + pkging, has_ver = optional_import("pkg_resources", name="packaging") if has_ver: - return ver(lhs) <= ver(rhs) + try: + return pkging.version.Version(lhs) <= pkging.version.Version(rhs) + except pkging.version.InvalidVersion: + return False def _try_cast(val: str): val = val.strip() diff --git a/runtests.sh b/runtests.sh index 23ab4f4bf1..5464f3d020 100755 --- a/runtests.sh +++ b/runtests.sh @@ -111,7 +111,7 @@ function print_usage { function check_import { echo "Python: ${PY_EXE}" - ${cmdPrefix}${PY_EXE} -W error -c "import monai" + ${cmdPrefix}${PY_EXE} -W error -W ignore::DeprecationWarning -c "import monai" } function print_version { From 9292690646a8ae551c6062bd9157d83c2260175c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Feb 2022 11:08:42 +0000 Subject: [PATCH 08/11] fixes tests Signed-off-by: Wenqi Li --- monai/utils/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/utils/module.py b/monai/utils/module.py index 613eaf64dc..b21828fbbb 100644 --- a/monai/utils/module.py +++ b/monai/utils/module.py @@ -426,7 +426,7 @@ def version_leq(lhs: str, rhs: str): try: return pkging.version.Version(lhs) <= pkging.version.Version(rhs) except pkging.version.InvalidVersion: - return False + return True def _try_cast(val: str): val = val.strip() From 01853a6d182639a81d06d72def22f00861e89c77 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Feb 2022 18:01:22 +0000 Subject: [PATCH 09/11] update based on comments Signed-off-by: Wenqi Li --- monai/data/folder_layout.py | 5 +- monai/data/image_writer.py | 25 ++++---- monai/transforms/io/array.py | 2 +- tests/min_tests.py | 2 + tests/test_nifti_saver.py | 111 +++++++++++++++++++++++++++++++++++ tests/test_png_saver.py | 76 ++++++++++++++++++++++++ 6 files changed, 208 insertions(+), 13 deletions(-) create mode 100644 tests/test_nifti_saver.py create mode 100644 tests/test_png_saver.py diff --git a/monai/data/folder_layout.py b/monai/data/folder_layout.py index d8ce162c27..b2f41b0651 100644 --- a/monai/data/folder_layout.py +++ b/monai/data/folder_layout.py @@ -29,7 +29,7 @@ class FolderLayout: layout = FolderLayout( output_dir="/test_run_1/", postfix="seg", - extension=".nii", + extension="nii", makedirs=False) layout.filename(subject="Sub-A", idx="00", modality="T1") # return value: "/test_run_1/Sub-A_seg_00_modality-T1.nii" @@ -95,5 +95,6 @@ def filename(self, subject: PathLike = "subject", idx=None, **kwargs): for k, v in kwargs.items(): full_name += f"_{k}-{v}" if self.ext is not None: - full_name += f"{self.ext}" + ext = f"{self.ext}" + full_name += f".{ext}" if ext and not ext.startswith(".") else f"{ext}" return full_name diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 6394e1c8a0..074f4e22cf 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -56,7 +56,7 @@ SUPPORTED_WRITERS: Dict = {} -def register_writer(ext_name, *im_writer): +def register_writer(ext_name, *im_writers): """ Register ``ImageWriter``, so that writing a file with filename extension ``ext_name`` could be resolved to a tuple of potentially appropriate ``ImageWriter``. @@ -64,18 +64,20 @@ def register_writer(ext_name, *im_writer): .. code-block:: python - from monai.data import image_writer + from monai.data import register_writer # `MyWriter` must implement `ImageWriter` interface - image_writer.register_writer(".nii", MyWriter) + register_writer("nii", MyWriter) Args: ext_name: the filename extension of the image. As an indexing key, it will be converted to a lower case string. - im_writer: one or multiple ImageWriter classes with high priority ones first. + im_writers: one or multiple ImageWriter classes with high priority ones first. """ fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] existing = look_up_option(fmt, SUPPORTED_WRITERS, default=()) - all_writers = im_writer + existing + all_writers = im_writers + existing SUPPORTED_WRITERS[fmt] = all_writers @@ -93,8 +95,11 @@ def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: if not SUPPORTED_WRITERS: init() fmt = f"{ext_name}".lower() + if fmt.startswith("."): + fmt = fmt[1:] avail_writers = [] - for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=SUPPORTED_WRITERS["*"]): + default_writers = SUPPORTED_WRITERS.get("*", ()) + for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers): try: _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability avail_writers.append(_writer) @@ -788,9 +793,9 @@ def init(): """ Initialize the image writer modules according to the filename extension. """ - for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"): + for ext in ("png", "jpg", "jpeg", "bmp", "tiff", "tif"): register_writer(ext, PILWriter) # TODO: test 16-bit - for ext in (".nii.gz", ".nii"): + for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) - register_writer(".nrrd", ITKWriter, NibabelWriter) - register_writer("*", ITKWriter, NibabelWriter) + register_writer("nrrd", ITKWriter, NibabelWriter) + register_writer("*", ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 4f4fbaf464..46460292b0 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -326,7 +326,7 @@ def __init__( def set_options(self, init_kwargs=None, data_kwargs=None, meta_kwargs=None, write_kwargs=None): """ - Set the options for the underlying writer by updating `kwargs` dictionaries. + Set the options for the underlying writer by updating the `self.*_kwargs` dictionaries. The arguments correspond to the following usage: diff --git a/tests/min_tests.py b/tests/min_tests.py index 192d390ded..426650eb04 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -111,6 +111,7 @@ def run_testsuit(): "test_mlp", "test_nifti_header_revise", "test_nifti_rw", + "test_nifti_saver", "test_occlusion_sensitivity", "test_orientation", "test_orientationd", @@ -120,6 +121,7 @@ def run_testsuit(): "test_pil_reader", "test_plot_2d_or_3d_image", "test_png_rw", + "test_png_saver", "test_prepare_batch_default", "test_prepare_batch_extra_input", "test_rand_rotate", diff --git a/tests/test_nifti_saver.py b/tests/test_nifti_saver.py new file mode 100644 index 0000000000..6855a59041 --- /dev/null +++ b/tests/test_nifti_saver.py @@ -0,0 +1,111 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import numpy as np +import torch + +from monai.data import NiftiSaver +from monai.transforms import LoadImage + + +class TestNiftiSaver(unittest.TestCase): + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".nii.gz") + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} + saver.save_batch(torch.zeros(8, 1, 2, 2), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)], + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_3d_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_3d_no_resize_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = NiftiSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz", dtype=np.float32, resample=False + ) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)], + "spatial_shape": [(10, 10, 2)] * 8, + "affine": [np.diag(np.ones(4)) * 5] * 8, + "original_affine": [np.diag(np.ones(4)) * 1.0] * 8, + } + saver.save_batch(torch.randint(0, 255, (8, 8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join(tempdir, "testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + img, _ = LoadImage("nibabelreader")(filepath) + self.assertEqual(img.shape, (1, 2, 2, 8)) + + def test_squeeze_end_dims(self): + with tempfile.TemporaryDirectory() as tempdir: + + for squeeze_end_dims in [False, True]: + + saver = NiftiSaver( + output_dir=tempdir, + output_postfix="", + output_ext=".nii.gz", + dtype=np.float32, + squeeze_end_dims=squeeze_end_dims, + ) + + fname = "testfile_squeeze" + meta_data = {"filename_or_obj": fname} + + # 2d image w channel + saver.save(torch.randint(0, 255, (1, 2, 2)), meta_data) + + im, meta = LoadImage()(os.path.join(tempdir, fname, fname + ".nii.gz")) + self.assertTrue(im.ndim == 2 if squeeze_end_dims else 4) + self.assertTrue(meta["dim"][0] == im.ndim) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_png_saver.py b/tests/test_png_saver.py new file mode 100644 index 0000000000..d832718643 --- /dev/null +++ b/tests/test_png_saver.py @@ -0,0 +1,76 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from pathlib import Path + +import torch + +from monai.data import PNGSaver + + +class TestPNGSaver(unittest.TestCase): + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_content_three_channel(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=Path(tempdir), output_postfix="seg", output_ext=".png", scale=255) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)]} + saver.save_batch(torch.randint(1, 200, (8, 3, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_content_spatial_size(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver(output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255) + + meta_data = { + "filename_or_obj": ["testfile" + str(i) + ".jpg" for i in range(8)], + "spatial_shape": [(4, 4) for i in range(8)], + } + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + def test_saved_specified_root(self): + with tempfile.TemporaryDirectory() as tempdir: + + saver = PNGSaver( + output_dir=tempdir, output_postfix="seg", output_ext=".png", scale=255, data_root_dir="test" + ) + + meta_data = { + "filename_or_obj": [os.path.join("test", "testfile" + str(i), "image" + ".jpg") for i in range(8)] + } + saver.save_batch(torch.randint(1, 200, (8, 1, 2, 2)), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "image", "image" + "_seg.png") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main() From 2c09f8efc43ed4a34644c25d1340ea516aaa412e Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Feb 2022 21:31:29 +0000 Subject: [PATCH 10/11] fixes #3787 Signed-off-by: Wenqi Li --- monai/transforms/io/array.py | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 46460292b0..a852188f2d 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -16,6 +16,7 @@ import inspect import logging import sys +import traceback import warnings from pathlib import Path from typing import Dict, List, Optional, Sequence, Union @@ -184,7 +185,7 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option """ filename = tuple(f"{Path(s).expanduser()}" for s in ensure_tuple(filename)) # allow Path objects - img = None + img, err = None, [] if reader is not None: img = reader.read(filename) # runtime specified reader else: @@ -197,20 +198,24 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option try: img = reader.read(filename) except Exception as e: - logging.getLogger(self.__class__.__name__).debug( - f"{reader.__class__.__name__}: unable to load {filename}.\n" f"Error: {e}" + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) + logging.getLogger(self.__class__.__name__).info( + f"{reader.__class__.__name__}: unable to load {filename}.\n" ) else: + err = [] break if img is None or reader is None: if isinstance(filename, tuple) and len(filename) == 1: filename = filename[0] + msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( - f"cannot find a suitable reader for file: {filename}.\n" + f"{self.__class__.__name__} cannot find a suitable reader for file: {filename}.\n" " Please install the reader libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" - f" The current registered: {self.readers}.\n" + f" The current registered: {self.readers}.\n{msg}" ) img_array, meta_data = reader.get_data(img) @@ -280,6 +285,8 @@ class SaveImage(Transform): see also: `monai.data.image_writer.SUPPORTED_WRITERS`. writer: a customised image writer to save data arrays. if `None`, use the default writer from `monai.data.image_writer` according to `output_ext`. + channel_dim: the index of the channel dimension. Default to `0`. + `None` to indicate no channel dimension. """ def __init__( @@ -299,6 +306,7 @@ def __init__( print_log: bool = True, output_format: str = "", writer: Optional[image_writer.ImageWriter] = None, + channel_dim: Optional[int] = 0, ) -> None: self.folder_layout = FolderLayout( output_dir=output_dir, @@ -319,7 +327,7 @@ def __init__( if self.output_ext == ".dcm" and _output_dtype not in (np.uint8, np.uint16): _output_dtype = np.uint8 self.init_kwargs = {"output_dtype": _output_dtype, "scale": scale} - self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": 0} + self.data_kwargs = {"squeeze_end_dims": squeeze_end_dims, "channel_dim": channel_dim} self.meta_kwargs = {"resample": resample, "mode": mode, "padding_mode": padding_mode, "dtype": dtype} self.write_kwargs = {"verbose": print_log} self._data_index = 0 @@ -354,7 +362,10 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic subject = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None filename = self.folder_layout.filename(subject=f"{subject}", idx=patch_index) + if meta_data and len(ensure_tuple(meta_data.get("spatial_shape", ()))) == len(img.shape): + self.data_kwargs["channel_dim"] = None + err = [] for writer_cls in self.writers: try: writer_obj = writer_cls(**self.init_kwargs) @@ -363,16 +374,18 @@ def __call__(self, img: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dic writer_obj.write(filename, **self.write_kwargs) self.writer_obj = writer_obj except Exception as e: - logging.getLogger(self.__class__.__name__).exception(e, exc_info=True) + err.append(traceback.format_exc()) + logging.getLogger(self.__class__.__name__).debug(e, exc_info=True) logging.getLogger(self.__class__.__name__).info( - f"{writer_cls.__class__.__name__}: unable to write {filename}." + f"{writer_cls.__class__.__name__}: unable to write {filename}.\n" ) else: self._data_index += 1 return img + msg = "\n".join([f"{e}" for e in err]) raise RuntimeError( - f"cannot find a suitable writer for {filename}.\n" + f"{self.__class__.__name__} cannot find a suitable writer for {filename}.\n" " Please install the writer libraries, see also the installation instructions:\n" " https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies.\n" - f" The current registered writers for {self.output_ext}: {self.writers}.\n" + f" The current registered writers for {self.output_ext}: {self.writers}.\n{msg}" ) From 01525439c238ec24ca93939d9cf9d0732a7bbc45 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 10 Feb 2022 22:28:52 +0000 Subject: [PATCH 11/11] unit testing Signed-off-by: Wenqi Li --- monai/data/image_writer.py | 13 ++++--- monai/transforms/spatial/array.py | 13 ++++--- tests/test_image_rw.py | 58 +++++++++++++++++++++++++++++-- 3 files changed, 73 insertions(+), 11 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 074f4e22cf..e9f753fb34 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -30,6 +30,7 @@ ) DEFAULT_FMT = "%(asctime)s %(levelname)s %(filename)s:%(lineno)d - %(message)s" +EXT_WILDCARD = "*" logger = get_logger(module_name=__name__, fmt=DEFAULT_FMT) if TYPE_CHECKING: @@ -98,13 +99,15 @@ def resolve_writer(ext_name, error_if_not_found=True) -> Sequence: if fmt.startswith("."): fmt = fmt[1:] avail_writers = [] - default_writers = SUPPORTED_WRITERS.get("*", ()) + default_writers = SUPPORTED_WRITERS.get(EXT_WILDCARD, ()) for _writer in look_up_option(fmt, SUPPORTED_WRITERS, default=default_writers): try: _writer() # this triggers `monai.utils.module.require_pkg` to check the system availability avail_writers.append(_writer) except OptionalImportError: - pass + continue + except Exception: # other writer init errors indicating it exists + avail_writers.append(_writer) if not avail_writers and error_if_not_found: raise OptionalImportError(f"No ImageWriter backend found for {fmt}.") writer_tuple = ensure_tuple(avail_writers) @@ -406,7 +409,7 @@ def set_metadata(self, meta_dict: Optional[Mapping] = None, resample: bool = Tru data_array=self.data_obj, affine=affine, target_affine=original_affine if resample else None, - output_spatial_shape=spatial_shape, + output_spatial_shape=spatial_shape if resample else None, mode=options.pop("mode", GridSampleMode.BILINEAR), padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), align_corners=options.pop("align_corners", False), @@ -547,7 +550,7 @@ def set_metadata(self, meta_dict: Optional[Mapping], resample: bool = True, **op data_array=self.data_obj, affine=affine, target_affine=original_affine if resample else None, - output_spatial_shape=spatial_shape, + output_spatial_shape=spatial_shape if resample else None, mode=options.pop("mode", GridSampleMode.BILINEAR), padding_mode=options.pop("padding_mode", GridSamplePadMode.BORDER), align_corners=options.pop("align_corners", False), @@ -798,4 +801,4 @@ def init(): for ext in ("nii.gz", "nii"): register_writer(ext, NibabelWriter, ITKWriter) register_writer("nrrd", ITKWriter, NibabelWriter) - register_writer("*", ITKWriter, NibabelWriter, ITKWriter) + register_writer(EXT_WILDCARD, ITKWriter, NibabelWriter, ITKWriter) diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 34cdc37148..0f11dc4390 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -144,8 +144,8 @@ def __call__( the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. dst_affine: destination affine matrix. Defaults to ``None``, which means the same as `src_affine`. the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``. - when `dst` is None, the input will be returned without resampling, but the data type - will be `float32`. + when `dst_affine` and `spatial_size` are None, the input will be returned without resampling, + but the data type will be `float32`. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, the transform will compute a spatial size automatically containing the previous field of view. @@ -169,6 +169,7 @@ def __call__( When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``, MONAI's resampling implementation will be used. + Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step. """ if src_affine is None: src_affine = np.eye(4, dtype=np.float64) @@ -182,11 +183,15 @@ def __call__( in_spatial_size = np.asarray(img.shape[1 : spatial_rank + 1]) if isinstance(spatial_size, int) and (spatial_size == -1): # using the input spatial size spatial_size = in_spatial_size - elif spatial_size is None: # auto spatial size + elif spatial_size is None and spatial_rank > 1: # auto spatial size spatial_size, _ = compute_shape_offset(in_spatial_size, src_affine, dst_affine) # type: ignore spatial_size = np.asarray(fall_back_tuple(ensure_tuple(spatial_size)[:spatial_rank], in_spatial_size)) - if allclose(src_affine, dst_affine, atol=AFFINE_TOL) and allclose(spatial_size, in_spatial_size): + if ( + allclose(src_affine, dst_affine, atol=AFFINE_TOL) + and allclose(spatial_size, in_spatial_size) + or spatial_rank == 1 + ): # no significant change, return original image output_data, *_ = convert_to_dst_type(img, img, dtype=torch.float32) return output_data, dst_affine diff --git a/tests/test_image_rw.py b/tests/test_image_rw.py index bd779b8ee7..e1079e63f7 100644 --- a/tests/test_image_rw.py +++ b/tests/test_image_rw.py @@ -18,9 +18,10 @@ import numpy as np from parameterized import parameterized -from monai.data.image_reader import ITKReader, NibabelReader -from monai.data.image_writer import ITKWriter, NibabelWriter +from monai.data.image_reader import ITKReader, NibabelReader, PILReader +from monai.data.image_writer import ITKWriter, NibabelWriter, PILWriter, register_writer, resolve_writer from monai.transforms import LoadImage, SaveImage, moveaxis +from monai.utils import OptionalImportError from tests.utils import TEST_NDARRAYS, assert_allclose @@ -78,5 +79,58 @@ def test_4d(self, reader, writer): self.nifti_rw(test_data, reader, writer, np.float16) +class TestLoadSavePNG(unittest.TestCase): + def setUp(self): + self.test_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self.test_dir, ignore_errors=True) + + def png_rw(self, test_data, reader, writer, dtype, resample=True): + test_data = test_data.astype(dtype) + ndim = len(test_data.shape) - 1 + for p in TEST_NDARRAYS: + output_ext = ".png" + filepath = f"testfile_{ndim}d" + saver = SaveImage( + output_dir=self.test_dir, output_ext=output_ext, resample=resample, separate_folder=False, writer=writer + ) + saver(p(test_data), {"filename_or_obj": f"{filepath}.png", "spatial_shape": (6, 8)}) + saved_path = os.path.join(self.test_dir, filepath + "_trans" + output_ext) + self.assertTrue(os.path.exists(saved_path)) + loader = LoadImage(reader=reader) + data, meta = loader(saved_path) + if meta["original_channel_dim"] == -1: + _test_data = moveaxis(test_data, 0, -1) + else: + _test_data = test_data[0] + assert_allclose(data, _test_data) + + @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) + def test_2d(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(1, 6, 8) + self.png_rw(test_data, reader, writer, np.uint8) + + @parameterized.expand(itertools.product([PILReader, ITKReader], [PILWriter, ITKWriter])) + def test_rgb(self, reader, writer): + test_data = np.arange(48, dtype=np.uint8).reshape(3, 2, 8) + self.png_rw(test_data, reader, writer, np.uint8, False) + + +class TestRegRes(unittest.TestCase): + def test_0_default(self): + self.assertTrue(len(resolve_writer(".png")) > 0, "has png writer") + self.assertTrue(len(resolve_writer(".nrrd")) > 0, "has nrrd writer") + self.assertTrue(len(resolve_writer("unknown")) > 0, "has writer") + register_writer("unknown1", lambda: (_ for _ in ()).throw(OptionalImportError)) + with self.assertRaises(OptionalImportError): + resolve_writer("unknown1") + + def test_1_new(self): + register_writer("new", lambda x: x + 1) + register_writer("new2", lambda x: x + 1) + self.assertEqual(resolve_writer("new")[0](0), 1) + + if __name__ == "__main__": unittest.main()