From ea998b1fe5ad3b892cdb762b8a66f8b40d0a1fb9 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Sun, 13 Feb 2022 00:31:25 +0800 Subject: [PATCH 1/2] [DLMED] add ensure_channel_first Signed-off-by: Nic Ma --- monai/transforms/io/array.py | 25 +++++++++++++++++++++---- monai/transforms/io/dictionary.py | 5 ++++- tests/test_load_image.py | 16 +++++++++++++--- tests/test_load_imaged.py | 2 +- 4 files changed, 39 insertions(+), 9 deletions(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index a852188f2d..25712cdde4 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -24,11 +24,12 @@ import numpy as np import torch -from monai.config import DtypeLike, PathLike +from monai.config import DtypeLike, PathLike, NdarrayOrTensor from monai.data import image_writer from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader from monai.transforms.transform import Transform +from monai.transforms.utility.array import EnsureChannelFirst from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import @@ -91,7 +92,15 @@ class LoadImage(Transform): """ - def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.float32, *args, **kwargs) -> None: + def __init__( + self, + reader=None, + image_only: bool = False, + dtype: DtypeLike = np.float32, + ensure_channel_first: bool = False, + *args, + **kwargs, + ) -> None: """ Args: reader: reader to load image file and meta data @@ -104,6 +113,8 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. image_only: if True return only the image volume, otherwise return image data array and header dict. dtype: if not None convert the loaded image to this data type. + ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. @@ -121,6 +132,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np. self.auto_select = reader is None self.image_only = image_only self.dtype = dtype + self.ensure_channel_first = ensure_channel_first self.readers: List[ImageReader] = [] for r in SUPPORTED_READERS: # set predefined readers as default @@ -218,14 +230,19 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option f" The current registered: {self.readers}.\n{msg}" ) + img_array: NdarrayOrTensor img_array, meta_data = reader.get_data(img) img_array = img_array.astype(self.dtype, copy=False) + if not isinstance(meta_data, dict): + raise ValueError("`meta_data` must be a dict.") + # make sure all elements in metadata are little endian + meta_data = switch_endianness(meta_data, "<") + if self.ensure_channel_first: + img_array = EnsureChannelFirst()(img_array, meta_data) if self.image_only: return img_array meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader - # make sure all elements in metadata are little endian - meta_data = switch_endianness(meta_data, "<") return img_array, meta_data diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index d9a3f44e4b..ebe06898a6 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -73,6 +73,7 @@ def __init__( meta_key_postfix: str = DEFAULT_POST_FIX, overwriting: bool = False, image_only: bool = False, + ensure_channel_first: bool = False, allow_missing_keys: bool = False, *args, **kwargs, @@ -97,12 +98,14 @@ def __init__( default is False, which will raise exception if encountering existing key. image_only: if True return dictionary containing just only the image volumes, otherwise return dictionary containing image data array and header dict per input key. + ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert + the image array shape to `channel first`. default to `False`. allow_missing_keys: don't raise exception if key is missing. args: additional parameters for reader if providing a reader name. kwargs: additional parameters for reader if providing a reader name. """ super().__init__(keys, allow_missing_keys) - self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs) + self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs) if not isinstance(meta_key_postfix, str): raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.") self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index f215c925d8..2c8638ebbe 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -116,7 +116,11 @@ def get_data(self, _obj): TEST_CASE_13 = [{"reader": "nibabelreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)] -TEST_CASE_14 = [{"reader": "nibabelreader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)] +TEST_CASE_14 = [ + {"reader": "nibabelreader", "channel_dim": -1, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 128, 3), +] TEST_CASE_15 = [{"reader": "nibabelreader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)] @@ -124,7 +128,11 @@ def get_data(self, _obj): TEST_CASE_17 = [{"reader": "ITKReader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)] -TEST_CASE_18 = [{"reader": "ITKReader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)] +TEST_CASE_18 = [ + {"reader": "ITKReader", "channel_dim": 2, "ensure_channel_first": True}, + "test_image.nii.gz", + (128, 128, 3, 128), +] class TestLoadImage(unittest.TestCase): @@ -290,7 +298,9 @@ def test_channel_dim(self, input_param, filename, expected_shape): nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename) result = LoadImage(**input_param)(filename) - self.assertTupleEqual(result[0].shape, expected_shape) + self.assertTupleEqual( + result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape + ) self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128)) self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"]) diff --git a/tests/test_load_imaged.py b/tests/test_load_imaged.py index f4499311d3..bc001cf2fd 100644 --- a/tests/test_load_imaged.py +++ b/tests/test_load_imaged.py @@ -82,7 +82,7 @@ class TestConsistency(unittest.TestCase): def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext): data_dict = {"img": filename} keys = data_dict.keys() - xforms = Compose([LoadImaged(keys, reader=reader_1), EnsureChannelFirstD(keys)]) + xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)]) img_dict = xforms(data_dict) # load dicom with itk self.assertTupleEqual(img_dict["img"].shape, ch_shape) self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape) From be6607152ce318d342d05c14bcab92c7650cbf0f Mon Sep 17 00:00:00 2001 From: monai-bot Date: Sat, 12 Feb 2022 16:37:31 +0000 Subject: [PATCH 2/2] [MONAI] python code formatting Signed-off-by: monai-bot --- monai/transforms/io/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/io/array.py b/monai/transforms/io/array.py index 25712cdde4..f3715b2712 100644 --- a/monai/transforms/io/array.py +++ b/monai/transforms/io/array.py @@ -24,7 +24,7 @@ import numpy as np import torch -from monai.config import DtypeLike, PathLike, NdarrayOrTensor +from monai.config import DtypeLike, NdarrayOrTensor, PathLike from monai.data import image_writer from monai.data.folder_layout import FolderLayout from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader