From 43e4472a9a8c1cdf0136b2dfb50e4d68a3cc5c3a Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 11:01:19 +0800 Subject: [PATCH 1/8] [DLMED] add ImageWriter Signed-off-by: Nic Ma --- monai/data/image_writer.py | 104 +++++++++++++++++++++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 monai/data/image_writer.py diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..b6ec06edee --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,104 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Union + +import numpy as np +import torch + +from monai.config import DtypeLike +from monai.data.utils import create_file_basename +from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.utils import ImageMetaKey as Key + +itk, _ = optional_import("itk", allow_namespace_pkg=True) + + +class ImageWriter(ABC): + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "seg", + output_ext: str = ".nii.gz", + squeeze_end_dims: bool = True, + data_root_dir: str = "", + separate_folder: bool = True, + print_log: bool = True, + ) -> None: + self.output_dir = output_dir + self.output_postfix = output_postfix + self.output_ext = output_ext + self._data_index = 0 + self.squeeze_end_dims = squeeze_end_dims + self.data_root_dir = data_root_dir + self.separate_folder = separate_folder + self.print_log = print_log + + def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + self._data_index += 1 + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + + path = create_file_basename( + postfix=self.output_postfix, + input_file_name=filename, + folder_path=self.output_dir, + data_root_dir=self.data_root_dir, + separate_folder=self.separate_folder, + patch_index=patch_index, + ) + path = f"{path}{self.output_ext}" + + # change data to "channel last" format and write to file + data = np.moveaxis(np.asarray(data), 0, -1) + + # if desired, remove trailing singleton dimensions + if self.squeeze_end_dims: + while data.shape[-1] == 1: + data = np.squeeze(data, -1) + + self.write(data=data, meta_data=meta_data, filename=path) + + if self.print_log: + print(f"file written: {path}.") + + def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + for i, data in enumerate(batch_data): # save a batch of files + self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) + + @abstractmethod + def write(self, data, meta_data, filename): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class ITKWriter(ImageWriter): + def __init__( + self, + resample: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, + ) -> None: + self.resample = resample + self.mode: GridSampleMode = GridSampleMode(mode) + self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.align_corners = align_corners + self.dtype = dtype + self.output_dtype = output_dtype + + def write(self, data, meta_data, filename): + pass From 7751d16ab76618d074f5bbac3d0f233ce7fe4d41 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 11:06:16 +0800 Subject: [PATCH 2/8] [DLMED] add docs Signed-off-by: Nic Ma --- docs/source/data.rst | 8 ++++++++ monai/data/__init__.py | 1 + 2 files changed, 9 insertions(+) diff --git a/docs/source/data.rst b/docs/source/data.rst index a5c3509fc9..d1800ba86a 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -133,6 +133,14 @@ WSIReader .. autoclass:: WSIReader :members: +Image writer +------------ + +ITKWriter +~~~~~~~~~ +.. autoclass:: ITKWriter + :members: + Nifti format handling --------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index af42627f5f..a8cf719499 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -27,6 +27,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_writer import ImageWriter, ITKWriter from .iterable_dataset import CSVIterableDataset, IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti From b32a123f012f9de9c95ea6870b6356844a2c6374 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 11:13:58 +0800 Subject: [PATCH 3/8] [DLMED] adjust APIs Signed-off-by: Nic Ma --- monai/data/image_writer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index b6ec06edee..be1e988b53 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -43,7 +43,7 @@ def __init__( self.separate_folder = separate_folder self.print_log = print_log - def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + def write(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) self._data_index += 1 patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None @@ -69,17 +69,17 @@ def save(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] while data.shape[-1] == 1: data = np.squeeze(data, -1) - self.write(data=data, meta_data=meta_data, filename=path) + self._write_file(data=data, meta_data=meta_data, filename=path) if self.print_log: print(f"file written: {path}.") - def save_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - for i, data in enumerate(batch_data): # save a batch of files - self.save(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) + def write_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + for i, data in enumerate(batch_data): # write a batch of data to files + self.write(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) @abstractmethod - def write(self, data, meta_data, filename): + def _write_file(self, data, meta_data, filename): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") From 0bf3feda2e16d0c0c7d30268f33a0d20d502b258 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 17:48:40 +0800 Subject: [PATCH 4/8] [DLMED] add orientation and resample Signed-off-by: Nic Ma --- monai/data/__init__.py | 2 ++ monai/data/image_writer.py | 67 ++++++++++++++++++++++++++++++++++---- monai/data/utils.py | 61 +++++++++++++++++++++++++++++++++- tests/test_itk_writer.py | 35 ++++++++++++++++++++ 4 files changed, 158 insertions(+), 7 deletions(-) create mode 100644 tests/test_itk_writer.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index a8cf719499..62437e6a79 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -38,6 +38,8 @@ from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader from .utils import ( + adjust_orientation_by_affine, + adjust_spatial_shape_by_affine, compute_importance_map, compute_shape_offset, convert_tables_to_dicts, diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index be1e988b53..321a1824ca 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -16,9 +16,15 @@ import torch from monai.config import DtypeLike -from monai.data.utils import create_file_basename -from monai.utils import GridSampleMode, GridSamplePadMode, optional_import +from monai.data.utils import ( + adjust_orientation_by_affine, + adjust_spatial_shape_by_affine, + create_file_basename, + to_affine_nd, +) +from monai.utils import GridSampleMode, GridSamplePadMode from monai.utils import ImageMetaKey as Key +from monai.utils import optional_import itk, _ = optional_import("itk", allow_namespace_pkg=True) @@ -69,7 +75,7 @@ def write(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] while data.shape[-1] == 1: data = np.squeeze(data, -1) - self._write_file(data=data, meta_data=meta_data, filename=path) + self._write_file(data=data, filename=path, meta_data=meta_data) if self.print_log: print(f"file written: {path}.") @@ -79,13 +85,20 @@ def write_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Op self.write(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) @abstractmethod - def _write_file(self, data, meta_data, filename): + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[np.ndarray] = None): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") class ITKWriter(ImageWriter): def __init__( self, + output_dir: str = "./", + output_postfix: str = "seg", + output_ext: str = ".dcm", + squeeze_end_dims: bool = True, + data_root_dir: str = "", + separate_folder: bool = True, + print_log: bool = True, resample: bool = True, mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, @@ -93,6 +106,15 @@ def __init__( dtype: DtypeLike = np.float64, output_dtype: DtypeLike = np.float32, ) -> None: + super().__init__( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + squeeze_end_dims=squeeze_end_dims, + data_root_dir=data_root_dir, + separate_folder=separate_folder, + print_log=print_log, + ) self.resample = resample self.mode: GridSampleMode = GridSampleMode(mode) self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) @@ -100,5 +122,38 @@ def __init__( self.dtype = dtype self.output_dtype = output_dtype - def write(self, data, meta_data, filename): - pass + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[np.ndarray] = None): + target_affine = meta_data.get("original_affine", None) if meta_data else None + affine = meta_data.get("affine", None) if meta_data else None + spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None + + if not isinstance(data, np.ndarray): + raise AssertionError("input data must be numpy array.") + dtype = self.dtype or data.dtype + sr = min(data.ndim, 3) + if affine is None: + affine = np.eye(4, dtype=np.float64) + affine = to_affine_nd(sr, affine) + + if target_affine is None: + target_affine = affine + target_affine = to_affine_nd(sr, target_affine) + + if not np.allclose(affine, target_affine, atol=1e-3): + data, affine = adjust_orientation_by_affine(data=data, affine=affine, target_affine=target_affine) + if self.resample: + data, affine = adjust_spatial_shape_by_affine( + data=data, + affine=affine, + target_affine=target_affine, + output_spatial_shape=spatial_shape, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, + dtype=dtype, + ) + + itk_np_view = itk.image_view_from_array(data.astype(self.output_dtype)) + # TODO: need to set affine matrix into file header + # itk_np_view.SetMatrix(to_affine_nd(3, affine)) + itk.imwrite(itk_np_view, filename) diff --git a/monai/data/utils.py b/monai/data/utils.py index 94c8582e9a..ed413c2334 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -26,10 +26,13 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai.networks.layers.simplelayers import GaussianFilter +from monai.config import DtypeLike +from monai.networks.layers import AffineTransform, GaussianFilter from monai.utils import ( MAX_SEED, BlendMode, + GridSampleMode, + GridSamplePadMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep, @@ -75,6 +78,8 @@ "pad_list_data_collate", "no_collation", "convert_tables_to_dicts", + "adjust_orientation_by_affine", + "adjust_spatial_shape_by_affine", ] @@ -1127,3 +1132,57 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def adjust_orientation_by_affine(data: np.ndarray, affine: np.ndarray, target_affine: np.ndarray): + start_ornt = nib.orientations.io_orientation(affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + data_shape = data.shape + data = nib.orientations.apply_orientation(data, ornt_transform) + new_affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + return data, new_affine + + +def adjust_spatial_shape_by_affine( + data: np.ndarray, + affine: np.ndarray, + target_affine: np.ndarray, + output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, +): + affine_xform = AffineTransform( + normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True + ) + transform = np.linalg.inv(affine) @ target_affine + if output_spatial_shape is None: + output_spatial_shape, _ = compute_shape_offset(data.shape, affine, target_affine) + output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] + if data.ndim > 3: # multi channel, resampling each channel + while len(output_spatial_shape_) < 3: + output_spatial_shape_ = output_spatial_shape_ + [1] + spatial_shape, channel_shape = data.shape[:3], data.shape[3:] + data_np = data.reshape(list(spatial_shape) + [-1]) + data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch + data_torch = affine_xform( + torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + spatial_size=output_spatial_shape_[:3], + ) + data_np = data_torch.squeeze(0).detach().cpu().numpy() + data_np = np.moveaxis(data_np, 0, -1) # channel last to save file + data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape)) + else: # single channel image, need to expand to have batch and channel + while len(output_spatial_shape_) < len(data.shape): + output_spatial_shape_ = output_spatial_shape_ + [1] + data_torch = affine_xform( + torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + spatial_size=output_spatial_shape_[: len(data.shape)], + ) + data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() + + return data_np, target_affine diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..f2566f4dcb --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 +1,35 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest + +import torch + +from monai.data import ITKWriter + + +class TestITKWriter(unittest.TestCase): + def test_saved_content(self): + with tempfile.TemporaryDirectory() as tempdir: + + writer = ITKWriter(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz") + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} + writer.write_batch(torch.zeros(8, 1, 2, 2), meta_data) + for i in range(8): + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + + +if __name__ == "__main__": + unittest.main() From 31d93e4e4e3ffe524530a634a91e9be7328356d7 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 17:52:18 +0800 Subject: [PATCH 5/8] [DLMED] fix flake8 Signed-off-by: Nic Ma --- monai/data/image_writer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 321a1824ca..3da1d32a0f 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -85,7 +85,7 @@ def write_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Op self.write(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) @abstractmethod - def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[np.ndarray] = None): + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -122,7 +122,7 @@ def __init__( self.dtype = dtype self.output_dtype = output_dtype - def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[np.ndarray] = None): + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None): target_affine = meta_data.get("original_affine", None) if meta_data else None affine = meta_data.get("affine", None) if meta_data else None spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None From 61ba6eeafd6deea2d8498a719ecf7bffb5cd0e4f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 23 Jul 2021 17:59:45 +0800 Subject: [PATCH 6/8] [DLMED] skip min test Signed-off-by: Nic Ma --- tests/min_tests.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/min_tests.py b/tests/min_tests.py index 1cd54f35d0..08636a05f4 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -135,6 +135,7 @@ def run_testsuit(): "test_unetr_block", "test_vit", "test_handler_decollate_batch", + "test_itk_writer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" From cff6d866c59768aa4299f8e5e305c8357097d883 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 26 Jul 2021 12:54:46 +0800 Subject: [PATCH 7/8] [DLMED] save affine into image header Signed-off-by: Nic Ma --- monai/data/image_writer.py | 9 ++++++--- tests/test_itk_writer.py | 7 ++++--- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 3da1d32a0f..7eb53e0896 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -132,7 +132,7 @@ def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] dtype = self.dtype or data.dtype sr = min(data.ndim, 3) if affine is None: - affine = np.eye(4, dtype=np.float64) + affine = np.eye(sr + 1, dtype=np.float64) affine = to_affine_nd(sr, affine) if target_affine is None: @@ -154,6 +154,9 @@ def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] ) itk_np_view = itk.image_view_from_array(data.astype(self.output_dtype)) - # TODO: need to set affine matrix into file header - # itk_np_view.SetMatrix(to_affine_nd(3, affine)) + # set affine matrix into file header + spacing = np.linalg.norm(affine[:-1, :-1] @ np.eye(sr), axis=0) + itk_np_view.SetSpacing(spacing) + itk_np_view.SetDirection(affine[:-1, :-1] @ np.diag(1 / spacing)) + itk_np_view.SetOrigin(affine[:-1, -1]) itk.imwrite(itk_np_view, filename) diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index f2566f4dcb..e92d4acdea 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -13,6 +13,7 @@ import tempfile import unittest +import numpy as np import torch from monai.data import ITKWriter @@ -22,12 +23,12 @@ class TestITKWriter(unittest.TestCase): def test_saved_content(self): with tempfile.TemporaryDirectory() as tempdir: - writer = ITKWriter(output_dir=tempdir, output_postfix="seg", output_ext=".nii.gz") + writer = ITKWriter(output_dir=tempdir, output_postfix="seg", output_ext=".dcm", output_dtype=np.uint8) meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} - writer.write_batch(torch.zeros(8, 1, 2, 2), meta_data) + writer.write_batch(torch.zeros(8, 1, 2, 3, 4), meta_data) for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.nii.gz") + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.dcm") self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) From 5cd3d353938a078bc63b8b957822891340bc5d4f Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Mon, 26 Jul 2021 18:28:21 +0800 Subject: [PATCH 8/8] [DLMED] add more tests Signed-off-by: Nic Ma --- monai/data/image_writer.py | 17 +++++++------- tests/test_itk_writer.py | 48 ++++++++++++++++++++++++++++++++------ 2 files changed, 49 insertions(+), 16 deletions(-) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 7eb53e0896..c5c761a45a 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -80,10 +80,6 @@ def write(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] if self.print_log: print(f"file written: {path}.") - def write_batch(self, batch_data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: - for i, data in enumerate(batch_data): # write a batch of data to files - self.write(data=data, meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) - @abstractmethod def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None): raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") @@ -153,10 +149,13 @@ def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] dtype=dtype, ) - itk_np_view = itk.image_view_from_array(data.astype(self.output_dtype)) + itk_view = itk.image_view_from_array(data.astype(self.output_dtype)) + # nibabel to itk affine + flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] + affine = np.diag(flip_diag) @ affine # set affine matrix into file header spacing = np.linalg.norm(affine[:-1, :-1] @ np.eye(sr), axis=0) - itk_np_view.SetSpacing(spacing) - itk_np_view.SetDirection(affine[:-1, :-1] @ np.diag(1 / spacing)) - itk_np_view.SetOrigin(affine[:-1, -1]) - itk.imwrite(itk_np_view, filename) + itk_view.SetSpacing(spacing) + itk_view.SetDirection(affine[:-1, :-1] @ np.diag(1 / spacing)) + itk_view.SetOrigin(affine[:-1, -1]) + itk.imwrite(itk_view, filename) diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py index e92d4acdea..2d33efceec 100644 --- a/tests/test_itk_writer.py +++ b/tests/test_itk_writer.py @@ -12,24 +12,58 @@ import os import tempfile import unittest +from parameterized import parameterized import numpy as np import torch -from monai.data import ITKWriter +from monai.data import ITKReader, ITKWriter + +TEST_CASE_1 = [".nii.gz", np.float32, None, None, None, False, torch.zeros(8, 1, 2, 3, 4), (4, 3, 2)] + +TEST_CASE_2 = [ + ".dcm", + np.uint8, + [np.diag(np.ones(4)) * 1.0 for _ in range(8)], + [np.diag(np.ones(4)) * 5.0 for _ in range(8)], + [(10, 10, 2) for _ in range(8)], + True, + torch.zeros(8, 3, 2, 3, 4), + (10, 10, 2, 3), +] class TestITKWriter(unittest.TestCase): - def test_saved_content(self): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_saved_content(self, output_ext, output_dtype, affine, original_affine, output_shape, resample, data, expected_shape): with tempfile.TemporaryDirectory() as tempdir: - writer = ITKWriter(output_dir=tempdir, output_postfix="seg", output_ext=".dcm", output_dtype=np.uint8) + writer = ITKWriter( + output_dir=tempdir, + output_postfix="seg", + output_ext=output_ext, + output_dtype=output_dtype, + resample=resample, + ) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]} + if output_shape is not None: + meta_data["spatial_shape"] = output_shape + if affine is not None: + meta_data["affine"] = affine + if original_affine is not None: + meta_data["original_affine"] = original_affine - meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii" for i in range(8)]} - writer.write_batch(torch.zeros(8, 1, 2, 3, 4), meta_data) for i in range(8): - filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg.dcm") - self.assertTrue(os.path.exists(os.path.join(tempdir, filepath))) + writer.write(data=data[i], meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) + reader = ITKReader() + img = reader.read(data=os.path.join(tempdir, filepath)) + result, meta = reader.get_data(img) + self.assertTupleEqual(result.shape, expected_shape) + if affine is not None: + # no need to compare the last line of affine matrix + np.testing.assert_allclose(meta["affine"][:-1], original_affine[i][:-1]) if __name__ == "__main__":