Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions monai/transforms/io/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@
import numpy as np
import torch

from monai.config import DtypeLike, PathLike
from monai.config import DtypeLike, NdarrayOrTensor, PathLike
from monai.data import image_writer
from monai.data.folder_layout import FolderLayout
from monai.data.image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from monai.transforms.transform import Transform
from monai.transforms.utility.array import EnsureChannelFirst
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import InterpolateMode, OptionalImportError, ensure_tuple, look_up_option, optional_import
Expand Down Expand Up @@ -91,7 +92,15 @@ class LoadImage(Transform):

"""

def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.float32, *args, **kwargs) -> None:
def __init__(
self,
reader=None,
image_only: bool = False,
dtype: DtypeLike = np.float32,
ensure_channel_first: bool = False,
*args,
**kwargs,
) -> None:
"""
Args:
reader: reader to load image file and meta data
Expand All @@ -104,6 +113,8 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.

image_only: if True return only the image volume, otherwise return image data array and header dict.
dtype: if not None convert the loaded image to this data type.
ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert
the image array shape to `channel first`. default to `False`.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.

Expand All @@ -121,6 +132,7 @@ def __init__(self, reader=None, image_only: bool = False, dtype: DtypeLike = np.
self.auto_select = reader is None
self.image_only = image_only
self.dtype = dtype
self.ensure_channel_first = ensure_channel_first

self.readers: List[ImageReader] = []
for r in SUPPORTED_READERS: # set predefined readers as default
Expand Down Expand Up @@ -218,14 +230,19 @@ def __call__(self, filename: Union[Sequence[PathLike], PathLike], reader: Option
f" The current registered: {self.readers}.\n{msg}"
)

img_array: NdarrayOrTensor
img_array, meta_data = reader.get_data(img)
img_array = img_array.astype(self.dtype, copy=False)
if not isinstance(meta_data, dict):
raise ValueError("`meta_data` must be a dict.")
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")
if self.ensure_channel_first:
img_array = EnsureChannelFirst()(img_array, meta_data)

if self.image_only:
return img_array
meta_data[Key.FILENAME_OR_OBJ] = f"{ensure_tuple(filename)[0]}" # Path obj should be strings for data loader
# make sure all elements in metadata are little endian
meta_data = switch_endianness(meta_data, "<")

return img_array, meta_data

Expand Down
5 changes: 4 additions & 1 deletion monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(
meta_key_postfix: str = DEFAULT_POST_FIX,
overwriting: bool = False,
image_only: bool = False,
ensure_channel_first: bool = False,
allow_missing_keys: bool = False,
*args,
**kwargs,
Expand All @@ -97,12 +98,14 @@ def __init__(
default is False, which will raise exception if encountering existing key.
image_only: if True return dictionary containing just only the image volumes, otherwise return
dictionary containing image data array and header dict per input key.
ensure_channel_first: if `True` and loaded both image array and meta data, automatically convert
the image array shape to `channel first`. default to `False`.
allow_missing_keys: don't raise exception if key is missing.
args: additional parameters for reader if providing a reader name.
kwargs: additional parameters for reader if providing a reader name.
"""
super().__init__(keys, allow_missing_keys)
self._loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
self._loader = LoadImage(reader, image_only, dtype, ensure_channel_first, *args, **kwargs)
if not isinstance(meta_key_postfix, str):
raise TypeError(f"meta_key_postfix must be a str but is {type(meta_key_postfix).__name__}.")
self.meta_keys = ensure_tuple_rep(None, len(self.keys)) if meta_keys is None else ensure_tuple(meta_keys)
Expand Down
16 changes: 13 additions & 3 deletions tests/test_load_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,23 @@ def get_data(self, _obj):

TEST_CASE_13 = [{"reader": "nibabelreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)]

TEST_CASE_14 = [{"reader": "nibabelreader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)]
TEST_CASE_14 = [
{"reader": "nibabelreader", "channel_dim": -1, "ensure_channel_first": True},
"test_image.nii.gz",
(128, 128, 128, 3),
]

TEST_CASE_15 = [{"reader": "nibabelreader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)]

TEST_CASE_16 = [{"reader": "itkreader", "channel_dim": 0}, "test_image.nii.gz", (3, 128, 128, 128)]

TEST_CASE_17 = [{"reader": "ITKReader", "channel_dim": -1}, "test_image.nii.gz", (128, 128, 128, 3)]

TEST_CASE_18 = [{"reader": "ITKReader", "channel_dim": 2}, "test_image.nii.gz", (128, 128, 3, 128)]
TEST_CASE_18 = [
{"reader": "ITKReader", "channel_dim": 2, "ensure_channel_first": True},
"test_image.nii.gz",
(128, 128, 3, 128),
]


class TestLoadImage(unittest.TestCase):
Expand Down Expand Up @@ -290,7 +298,9 @@ def test_channel_dim(self, input_param, filename, expected_shape):
nib.save(nib.Nifti1Image(test_image, np.eye(4)), filename)
result = LoadImage(**input_param)(filename)

self.assertTupleEqual(result[0].shape, expected_shape)
self.assertTupleEqual(
result[0].shape, (3, 128, 128, 128) if input_param.get("ensure_channel_first", False) else expected_shape
)
self.assertTupleEqual(tuple(result[1]["spatial_shape"]), (128, 128, 128))
self.assertEqual(result[1]["original_channel_dim"], input_param["channel_dim"])

Expand Down
2 changes: 1 addition & 1 deletion tests/test_load_imaged.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ class TestConsistency(unittest.TestCase):
def _cmp(self, filename, shape, ch_shape, reader_1, reader_2, outname, ext):
data_dict = {"img": filename}
keys = data_dict.keys()
xforms = Compose([LoadImaged(keys, reader=reader_1), EnsureChannelFirstD(keys)])
xforms = Compose([LoadImaged(keys, reader=reader_1, ensure_channel_first=True)])
img_dict = xforms(data_dict) # load dicom with itk
self.assertTupleEqual(img_dict["img"].shape, ch_shape)
self.assertTupleEqual(tuple(img_dict[PostFix.meta("img")]["spatial_shape"]), shape)
Expand Down