diff --git a/docs/source/data.rst b/docs/source/data.rst index a5c3509fc9..d1800ba86a 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -133,6 +133,14 @@ WSIReader .. autoclass:: WSIReader :members: +Image writer +------------ + +ITKWriter +~~~~~~~~~ +.. autoclass:: ITKWriter + :members: + Nifti format handling --------------------- diff --git a/monai/data/__init__.py b/monai/data/__init__.py index af42627f5f..62437e6a79 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -27,6 +27,7 @@ from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter from .image_dataset import ImageDataset from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader +from .image_writer import ImageWriter, ITKWriter from .iterable_dataset import CSVIterableDataset, IterableDataset from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti @@ -37,6 +38,8 @@ from .test_time_augmentation import TestTimeAugmentation from .thread_buffer import ThreadBuffer, ThreadDataLoader from .utils import ( + adjust_orientation_by_affine, + adjust_spatial_shape_by_affine, compute_importance_map, compute_shape_offset, convert_tables_to_dicts, diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py new file mode 100644 index 0000000000..c5c761a45a --- /dev/null +++ b/monai/data/image_writer.py @@ -0,0 +1,161 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Dict, Optional, Union + +import numpy as np +import torch + +from monai.config import DtypeLike +from monai.data.utils import ( + adjust_orientation_by_affine, + adjust_spatial_shape_by_affine, + create_file_basename, + to_affine_nd, +) +from monai.utils import GridSampleMode, GridSamplePadMode +from monai.utils import ImageMetaKey as Key +from monai.utils import optional_import + +itk, _ = optional_import("itk", allow_namespace_pkg=True) + + +class ImageWriter(ABC): + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "seg", + output_ext: str = ".nii.gz", + squeeze_end_dims: bool = True, + data_root_dir: str = "", + separate_folder: bool = True, + print_log: bool = True, + ) -> None: + self.output_dir = output_dir + self.output_postfix = output_postfix + self.output_ext = output_ext + self._data_index = 0 + self.squeeze_end_dims = squeeze_end_dims + self.data_root_dir = data_root_dir + self.separate_folder = separate_folder + self.print_log = print_log + + def write(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None: + filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index) + self._data_index += 1 + patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None + + if isinstance(data, torch.Tensor): + data = data.detach().cpu().numpy() + + path = create_file_basename( + postfix=self.output_postfix, + input_file_name=filename, + folder_path=self.output_dir, + data_root_dir=self.data_root_dir, + separate_folder=self.separate_folder, + patch_index=patch_index, + ) + path = f"{path}{self.output_ext}" + + # change data to "channel last" format and write to file + data = np.moveaxis(np.asarray(data), 0, -1) + + # if desired, remove trailing singleton dimensions + if self.squeeze_end_dims: + while data.shape[-1] == 1: + data = np.squeeze(data, -1) + + self._write_file(data=data, filename=path, meta_data=meta_data) + + if self.print_log: + print(f"file written: {path}.") + + @abstractmethod + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None): + raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.") + + +class ITKWriter(ImageWriter): + def __init__( + self, + output_dir: str = "./", + output_postfix: str = "seg", + output_ext: str = ".dcm", + squeeze_end_dims: bool = True, + data_root_dir: str = "", + separate_folder: bool = True, + print_log: bool = True, + resample: bool = True, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, + output_dtype: DtypeLike = np.float32, + ) -> None: + super().__init__( + output_dir=output_dir, + output_postfix=output_postfix, + output_ext=output_ext, + squeeze_end_dims=squeeze_end_dims, + data_root_dir=data_root_dir, + separate_folder=separate_folder, + print_log=print_log, + ) + self.resample = resample + self.mode: GridSampleMode = GridSampleMode(mode) + self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode) + self.align_corners = align_corners + self.dtype = dtype + self.output_dtype = output_dtype + + def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None): + target_affine = meta_data.get("original_affine", None) if meta_data else None + affine = meta_data.get("affine", None) if meta_data else None + spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None + + if not isinstance(data, np.ndarray): + raise AssertionError("input data must be numpy array.") + dtype = self.dtype or data.dtype + sr = min(data.ndim, 3) + if affine is None: + affine = np.eye(sr + 1, dtype=np.float64) + affine = to_affine_nd(sr, affine) + + if target_affine is None: + target_affine = affine + target_affine = to_affine_nd(sr, target_affine) + + if not np.allclose(affine, target_affine, atol=1e-3): + data, affine = adjust_orientation_by_affine(data=data, affine=affine, target_affine=target_affine) + if self.resample: + data, affine = adjust_spatial_shape_by_affine( + data=data, + affine=affine, + target_affine=target_affine, + output_spatial_shape=spatial_shape, + mode=self.mode, + padding_mode=self.padding_mode, + align_corners=self.align_corners, + dtype=dtype, + ) + + itk_view = itk.image_view_from_array(data.astype(self.output_dtype)) + # nibabel to itk affine + flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1] + affine = np.diag(flip_diag) @ affine + # set affine matrix into file header + spacing = np.linalg.norm(affine[:-1, :-1] @ np.eye(sr), axis=0) + itk_view.SetSpacing(spacing) + itk_view.SetDirection(affine[:-1, :-1] @ np.diag(1 / spacing)) + itk_view.SetOrigin(affine[:-1, -1]) + itk.imwrite(itk_view, filename) diff --git a/monai/data/utils.py b/monai/data/utils.py index 94c8582e9a..ed413c2334 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -26,10 +26,13 @@ import torch from torch.utils.data._utils.collate import default_collate -from monai.networks.layers.simplelayers import GaussianFilter +from monai.config import DtypeLike +from monai.networks.layers import AffineTransform, GaussianFilter from monai.utils import ( MAX_SEED, BlendMode, + GridSampleMode, + GridSamplePadMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep, @@ -75,6 +78,8 @@ "pad_list_data_collate", "no_collation", "convert_tables_to_dicts", + "adjust_orientation_by_affine", + "adjust_spatial_shape_by_affine", ] @@ -1127,3 +1132,57 @@ def convert_tables_to_dicts( data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)] return data + + +def adjust_orientation_by_affine(data: np.ndarray, affine: np.ndarray, target_affine: np.ndarray): + start_ornt = nib.orientations.io_orientation(affine) + target_ornt = nib.orientations.io_orientation(target_affine) + ornt_transform = nib.orientations.ornt_transform(start_ornt, target_ornt) + data_shape = data.shape + data = nib.orientations.apply_orientation(data, ornt_transform) + new_affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape) + return data, new_affine + + +def adjust_spatial_shape_by_affine( + data: np.ndarray, + affine: np.ndarray, + target_affine: np.ndarray, + output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None, + mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR, + padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER, + align_corners: bool = False, + dtype: DtypeLike = np.float64, +): + affine_xform = AffineTransform( + normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True + ) + transform = np.linalg.inv(affine) @ target_affine + if output_spatial_shape is None: + output_spatial_shape, _ = compute_shape_offset(data.shape, affine, target_affine) + output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else [] + if data.ndim > 3: # multi channel, resampling each channel + while len(output_spatial_shape_) < 3: + output_spatial_shape_ = output_spatial_shape_ + [1] + spatial_shape, channel_shape = data.shape[:3], data.shape[3:] + data_np = data.reshape(list(spatial_shape) + [-1]) + data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch + data_torch = affine_xform( + torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0), + torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + spatial_size=output_spatial_shape_[:3], + ) + data_np = data_torch.squeeze(0).detach().cpu().numpy() + data_np = np.moveaxis(data_np, 0, -1) # channel last to save file + data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape)) + else: # single channel image, need to expand to have batch and channel + while len(output_spatial_shape_) < len(data.shape): + output_spatial_shape_ = output_spatial_shape_ + [1] + data_torch = affine_xform( + torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]), + torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)), + spatial_size=output_spatial_shape_[: len(data.shape)], + ) + data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy() + + return data_np, target_affine diff --git a/tests/min_tests.py b/tests/min_tests.py index 1cd54f35d0..08636a05f4 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -135,6 +135,7 @@ def run_testsuit(): "test_unetr_block", "test_vit", "test_handler_decollate_batch", + "test_itk_writer", ] assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}" diff --git a/tests/test_itk_writer.py b/tests/test_itk_writer.py new file mode 100644 index 0000000000..2d33efceec --- /dev/null +++ b/tests/test_itk_writer.py @@ -0,0 +1,70 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile +import unittest +from parameterized import parameterized + +import numpy as np +import torch + +from monai.data import ITKReader, ITKWriter + +TEST_CASE_1 = [".nii.gz", np.float32, None, None, None, False, torch.zeros(8, 1, 2, 3, 4), (4, 3, 2)] + +TEST_CASE_2 = [ + ".dcm", + np.uint8, + [np.diag(np.ones(4)) * 1.0 for _ in range(8)], + [np.diag(np.ones(4)) * 5.0 for _ in range(8)], + [(10, 10, 2) for _ in range(8)], + True, + torch.zeros(8, 3, 2, 3, 4), + (10, 10, 2, 3), +] + + +class TestITKWriter(unittest.TestCase): + @parameterized.expand([TEST_CASE_1, TEST_CASE_2]) + def test_saved_content(self, output_ext, output_dtype, affine, original_affine, output_shape, resample, data, expected_shape): + with tempfile.TemporaryDirectory() as tempdir: + + writer = ITKWriter( + output_dir=tempdir, + output_postfix="seg", + output_ext=output_ext, + output_dtype=output_dtype, + resample=resample, + ) + + meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]} + if output_shape is not None: + meta_data["spatial_shape"] = output_shape + if affine is not None: + meta_data["affine"] = affine + if original_affine is not None: + meta_data["original_affine"] = original_affine + + for i in range(8): + writer.write(data=data[i], meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None) + filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext) + reader = ITKReader() + img = reader.read(data=os.path.join(tempdir, filepath)) + result, meta = reader.get_data(img) + self.assertTupleEqual(result.shape, expected_shape) + if affine is not None: + # no need to compare the last line of affine matrix + np.testing.assert_allclose(meta["affine"][:-1], original_affine[i][:-1]) + + +if __name__ == "__main__": + unittest.main()