diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 303091f0d8..9c212784ca 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -323,14 +323,22 @@ class NibabelReader(ImageReader): Args: as_closest_canonical: if True, load the image as closest to canonical axis format. + squeeze_non_spatial_dims: if True, non-spatial singletons will be squeezed, e.g. (256,256,1,3) -> (256,256,3) kwargs: additional args for `nibabel.load` API. more details about available args: https://github.com/nipy/nibabel/blob/master/nibabel/loadsave.py """ - def __init__(self, as_closest_canonical: bool = False, dtype: DtypeLike = np.float32, **kwargs): + def __init__( + self, + as_closest_canonical: bool = False, + squeeze_non_spatial_dims: bool = False, + dtype: DtypeLike = np.float32, + **kwargs, + ): super().__init__() self.as_closest_canonical = as_closest_canonical + self.squeeze_non_spatial_dims = squeeze_non_spatial_dims self.dtype = dtype self.kwargs = kwargs @@ -395,6 +403,10 @@ def get_data(self, img): header["affine"] = self._get_affine(i) header["spatial_shape"] = self._get_spatial_shape(i) data = self._get_array_data(i) + if self.squeeze_non_spatial_dims: + for d in range(len(data.shape), len(header["spatial_shape"]), -1): + if data.shape[d - 1] == 1: + data = data.squeeze(axis=d - 1) img_array.append(data) header["original_channel_dim"] = "no_channel" if len(data.shape) == len(header["spatial_shape"]) else -1 _copy_compatible_dict(header, compatible_meta) diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 3f78d3892d..687829dfa8 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -170,6 +170,21 @@ def test_itk_reader_multichannel(self): np.testing.assert_allclose(result[:, :, 1], test_image[:, :, 1].T) np.testing.assert_allclose(result[:, :, 2], test_image[:, :, 2].T) + def test_load_nifti_multichannel(self): + test_image = np.random.randint(0, 256, size=(31, 64, 16, 2)).astype(np.float32) + with tempfile.TemporaryDirectory() as tempdir: + filename = os.path.join(tempdir, "test_image.nii.gz") + itk_np_view = itk.image_view_from_array(test_image, is_vector=True) + itk.imwrite(itk_np_view, filename) + + itk_img, itk_header = LoadImage(reader=ITKReader())(Path(filename)) + self.assertTupleEqual(tuple(itk_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(itk_img.shape), (16, 64, 31, 2)) + + nib_image, nib_header = LoadImage(reader=NibabelReader(squeeze_non_spatial_dims=True))(Path(filename)) + self.assertTupleEqual(tuple(nib_header["spatial_shape"]), (16, 64, 31)) + self.assertTupleEqual(tuple(nib_image.shape), (16, 64, 31, 2)) + def test_load_png(self): spatial_size = (256, 224) test_image = np.random.randint(0, 256, size=spatial_size)