Skip to content
8 changes: 8 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,14 @@ WSIReader
.. autoclass:: WSIReader
:members:

Image writer
------------

ITKWriter
~~~~~~~~~
.. autoclass:: ITKWriter
:members:

Nifti format handling
---------------------

Expand Down
3 changes: 3 additions & 0 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .grid_dataset import GridPatchDataset, PatchDataset, PatchIter
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader, WSIReader
from .image_writer import ImageWriter, ITKWriter
from .iterable_dataset import CSVIterableDataset, IterableDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
Expand All @@ -37,6 +38,8 @@
from .test_time_augmentation import TestTimeAugmentation
from .thread_buffer import ThreadBuffer, ThreadDataLoader
from .utils import (
adjust_orientation_by_affine,
adjust_spatial_shape_by_affine,
compute_importance_map,
compute_shape_offset,
convert_tables_to_dicts,
Expand Down
161 changes: 161 additions & 0 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Dict, Optional, Union

import numpy as np
import torch

from monai.config import DtypeLike
from monai.data.utils import (
adjust_orientation_by_affine,
adjust_spatial_shape_by_affine,
create_file_basename,
to_affine_nd,
)
from monai.utils import GridSampleMode, GridSamplePadMode
from monai.utils import ImageMetaKey as Key
from monai.utils import optional_import

itk, _ = optional_import("itk", allow_namespace_pkg=True)


class ImageWriter(ABC):
    """
    Abstract base class for writing an image array (with optional metadata) to file.

    This class handles the format-independent steps: building the output path from
    the metadata, converting torch tensors to numpy, reordering to channel-last and
    optionally squeezing trailing singleton dimensions.  Subclasses implement
    ``_write_file`` to serialize the prepared array in a concrete file format.

    Args:
        output_dir: directory in which to save output files.
        output_postfix: postfix appended to the input file name when building
            the output file name.
        output_ext: output file extension, e.g. ``".nii.gz"``.
        squeeze_end_dims: if True, remove trailing singleton dimensions from the
            channel-last array before writing.
        data_root_dir: path prefix stripped from the input file name when
            reconstructing the output sub-path (forwarded to ``create_file_basename``).
        separate_folder: if True, save each output file in its own sub-folder.
        print_log: if True, print the path of every file written.
    """

    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".nii.gz",
        squeeze_end_dims: bool = True,
        data_root_dir: str = "",
        separate_folder: bool = True,
        print_log: bool = True,
    ) -> None:
        self.output_dir = output_dir
        self.output_postfix = output_postfix
        self.output_ext = output_ext
        # running counter used as the file name when meta_data provides none
        self._data_index = 0
        self.squeeze_end_dims = squeeze_end_dims
        self.data_root_dir = data_root_dir
        self.separate_folder = separate_folder
        self.print_log = print_log

    def write(self, data: Union[torch.Tensor, np.ndarray], meta_data: Optional[Dict] = None) -> None:
        """
        Prepare ``data`` (channel-first) and delegate the actual serialization
        to ``_write_file``.

        Args:
            data: channel-first image array or tensor to be written.
            meta_data: optional metadata dictionary; ``filename_or_obj`` and
                ``patch_index`` (if present) are used to build the output file name.
        """
        filename = meta_data[Key.FILENAME_OR_OBJ] if meta_data else str(self._data_index)
        self._data_index += 1
        patch_index = meta_data.get(Key.PATCH_INDEX, None) if meta_data else None

        if isinstance(data, torch.Tensor):
            data = data.detach().cpu().numpy()

        path = create_file_basename(
            postfix=self.output_postfix,
            input_file_name=filename,
            folder_path=self.output_dir,
            data_root_dir=self.data_root_dir,
            separate_folder=self.separate_folder,
            patch_index=patch_index,
        )
        path = f"{path}{self.output_ext}"

        # change data to "channel last" format and write to file
        data = np.moveaxis(np.asarray(data), 0, -1)

        # if desired, remove trailing singleton dimensions.
        # the ``data.ndim`` guard prevents an IndexError when an all-singleton
        # array (e.g. shape ``(1,)``) is squeezed down to a 0-d array, where
        # ``data.shape[-1]`` would fail.
        if self.squeeze_end_dims:
            while data.ndim and data.shape[-1] == 1:
                data = np.squeeze(data, -1)

        self._write_file(data=data, filename=path, meta_data=meta_data)

        if self.print_log:
            print(f"file written: {path}.")

    @abstractmethod
    def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None):
        """Serialize the channel-last ``data`` array to ``filename``; must be implemented by subclasses."""
        raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class ITKWriter(ImageWriter):
    """
    Write image data to file using ITK (default extension ``.dcm``), optionally
    reorienting and resampling the array into the space of the ``original_affine``
    found in the metadata before writing.
    """

    def __init__(
        self,
        output_dir: str = "./",
        output_postfix: str = "seg",
        output_ext: str = ".dcm",
        squeeze_end_dims: bool = True,
        data_root_dir: str = "",
        separate_folder: bool = True,
        print_log: bool = True,
        resample: bool = True,
        mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
        padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
        align_corners: bool = False,
        dtype: DtypeLike = np.float64,
        output_dtype: DtypeLike = np.float32,
    ) -> None:
        """
        Args:
            output_dir, output_postfix, output_ext, squeeze_end_dims,
                data_root_dir, separate_folder, print_log: see ``ImageWriter``.
            resample: whether to resample the data into the space of the
                metadata's ``original_affine`` when the affines differ.
            mode: grid-sample interpolation mode used for resampling.
            padding_mode: grid-sample padding mode used for resampling.
            align_corners: grid-sample ``align_corners`` option.
            dtype: working precision used during the resampling computation
                (``None`` falls back to the input array's dtype).
            output_dtype: dtype the array is cast to before writing.
        """
        super().__init__(
            output_dir=output_dir,
            output_postfix=output_postfix,
            output_ext=output_ext,
            squeeze_end_dims=squeeze_end_dims,
            data_root_dir=data_root_dir,
            separate_folder=separate_folder,
            print_log=print_log,
        )
        self.resample = resample
        self.mode: GridSampleMode = GridSampleMode(mode)
        self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
        self.align_corners = align_corners
        self.dtype = dtype
        self.output_dtype = output_dtype

    def _write_file(self, data: np.ndarray, filename: str, meta_data: Optional[Dict] = None):
        """
        Write the channel-last ``data`` to ``filename`` via ITK, aligning it to
        the metadata's ``original_affine`` (orientation flip/permutation first,
        then optional resampling) and storing spacing/direction/origin in the
        file header.
        """
        # NOTE(review): assumes metadata keys "affine" (current array affine),
        # "original_affine" (target output space) and "spatial_shape" (desired
        # output shape) as produced by the MONAI image readers -- TODO confirm.
        target_affine = meta_data.get("original_affine", None) if meta_data else None
        affine = meta_data.get("affine", None) if meta_data else None
        spatial_shape = meta_data.get("spatial_shape", None) if meta_data else None

        if not isinstance(data, np.ndarray):
            raise AssertionError("input data must be numpy array.")
        dtype = self.dtype or data.dtype  # computation precision for resampling
        sr = min(data.ndim, 3)  # spatial rank, capped at 3
        if affine is None:
            affine = np.eye(sr + 1, dtype=np.float64)  # default: identity affine
        affine = to_affine_nd(sr, affine)

        if target_affine is None:
            target_affine = affine  # nothing to align to; keep current space
        target_affine = to_affine_nd(sr, target_affine)

        # only transform when the current and target affines actually differ
        if not np.allclose(affine, target_affine, atol=1e-3):
            # axis flips/permutations first (cheap, exact) ...
            data, affine = adjust_orientation_by_affine(data=data, affine=affine, target_affine=target_affine)
            if self.resample:
                # ... then interpolating resample for the remaining difference
                data, affine = adjust_spatial_shape_by_affine(
                    data=data,
                    affine=affine,
                    target_affine=target_affine,
                    output_spatial_shape=spatial_shape,
                    mode=self.mode,
                    padding_mode=self.padding_mode,
                    align_corners=self.align_corners,
                    dtype=dtype,
                )

        itk_view = itk.image_view_from_array(data.astype(self.output_dtype))
        # nibabel to itk affine
        # sign flips chosen per spatial rank; presumably converting the RAS
        # (nibabel) convention to ITK's LPS convention -- TODO confirm
        flip_diag = [[-1, 1], [-1, -1, 1], [-1, -1, 1, 1]][sr - 1]
        affine = np.diag(flip_diag) @ affine
        # set affine matrix into file header
        # spacing = per-axis column norms; direction = normalized rotation part
        spacing = np.linalg.norm(affine[:-1, :-1] @ np.eye(sr), axis=0)
        itk_view.SetSpacing(spacing)
        itk_view.SetDirection(affine[:-1, :-1] @ np.diag(1 / spacing))
        itk_view.SetOrigin(affine[:-1, -1])
        itk.imwrite(itk_view, filename)
61 changes: 60 additions & 1 deletion monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,13 @@
import torch
from torch.utils.data._utils.collate import default_collate

from monai.networks.layers.simplelayers import GaussianFilter
from monai.config import DtypeLike
from monai.networks.layers import AffineTransform, GaussianFilter
from monai.utils import (
MAX_SEED,
BlendMode,
GridSampleMode,
GridSamplePadMode,
NumpyPadMode,
ensure_tuple,
ensure_tuple_rep,
Expand Down Expand Up @@ -75,6 +78,8 @@
"pad_list_data_collate",
"no_collation",
"convert_tables_to_dicts",
"adjust_orientation_by_affine",
"adjust_spatial_shape_by_affine",
]


Expand Down Expand Up @@ -1127,3 +1132,57 @@ def convert_tables_to_dicts(
data = [dict(d, **{k: v[i] for k, v in groups.items()}) for i, d in enumerate(data)]

return data


def adjust_orientation_by_affine(data: np.ndarray, affine: np.ndarray, target_affine: np.ndarray):
    """
    Flip/permute the axes of ``data`` so that its orientation codes match those
    of ``target_affine``, returning the reoriented array and its updated affine.

    Args:
        data: image array to reorient.
        affine: current affine matrix of ``data``.
        target_affine: affine whose axis orientation codes should be matched.

    Returns:
        a tuple of (reoriented array, updated affine).
    """
    orientations = nib.orientations
    # transform mapping the current axis codes onto the target's axis codes
    axis_transform = orientations.ornt_transform(
        orientations.io_orientation(affine), orientations.io_orientation(target_affine)
    )
    original_shape = data.shape
    reoriented = orientations.apply_orientation(data, axis_transform)
    # fold the inverse of the applied flip/permutation into the affine
    updated_affine = affine @ orientations.inv_ornt_aff(axis_transform, original_shape)
    return reoriented, updated_affine


def adjust_spatial_shape_by_affine(
    data: np.ndarray,
    affine: np.ndarray,
    target_affine: np.ndarray,
    output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None,
    mode: Union[GridSampleMode, str] = GridSampleMode.BILINEAR,
    padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
    align_corners: bool = False,
    dtype: DtypeLike = np.float64,
):
    """
    Resample ``data`` from the space described by ``affine`` into the space of
    ``target_affine`` using ``AffineTransform``.

    Args:
        data: array to resample; up to 3 leading spatial dimensions, any
            trailing dimensions are treated as channels.
        affine: current affine of ``data``.
        target_affine: affine of the desired output space.
        output_spatial_shape: spatial shape of the output; when None it is
            computed from the two affines via ``compute_shape_offset``.
        mode: grid-sample interpolation mode.
        padding_mode: grid-sample padding mode.
        align_corners: grid-sample ``align_corners`` option.
        dtype: precision used for the resampling computation.

    Returns:
        a tuple of (resampled array, ``target_affine``).
    """
    affine_xform = AffineTransform(
        normalized=False, mode=mode, padding_mode=padding_mode, align_corners=align_corners, reverse_indexing=True
    )
    # affine mapping output voxel coordinates back to input voxel coordinates
    transform = np.linalg.inv(affine) @ target_affine
    if output_spatial_shape is None:
        output_spatial_shape, _ = compute_shape_offset(data.shape, affine, target_affine)
    output_spatial_shape_ = list(output_spatial_shape) if output_spatial_shape is not None else []
    if data.ndim > 3:  # multi channel, resampling each channel
        # pad the requested shape to the 3 spatial dims handled in this branch
        while len(output_spatial_shape_) < 3:
            output_spatial_shape_ = output_spatial_shape_ + [1]
        spatial_shape, channel_shape = data.shape[:3], data.shape[3:]
        # flatten all channel dims into one, then move it first for pytorch
        data_np = data.reshape(list(spatial_shape) + [-1])
        data_np = np.moveaxis(data_np, -1, 0)  # channel first for pytorch
        data_torch = affine_xform(
            torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0),  # add batch dim
            torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)),
            spatial_size=output_spatial_shape_[:3],
        )
        data_np = data_torch.squeeze(0).detach().cpu().numpy()
        data_np = np.moveaxis(data_np, 0, -1)  # channel last to save file
        # restore the original (possibly multi-dim) channel shape
        data_np = data_np.reshape(list(data_np.shape[:3]) + list(channel_shape))
    else:  # single channel image, need to expand to have batch and channel
        while len(output_spatial_shape_) < len(data.shape):
            output_spatial_shape_ = output_spatial_shape_ + [1]
        data_torch = affine_xform(
            torch.as_tensor(np.ascontiguousarray(data).astype(dtype)[None, None]),  # add batch + channel dims
            torch.as_tensor(np.ascontiguousarray(transform).astype(dtype)),
            spatial_size=output_spatial_shape_[: len(data.shape)],
        )
        data_np = data_torch.squeeze(0).squeeze(0).detach().cpu().numpy()

    return data_np, target_affine
1 change: 1 addition & 0 deletions tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ def run_testsuit():
"test_unetr_block",
"test_vit",
"test_handler_decollate_batch",
"test_itk_writer",
]
assert sorted(exclude_cases) == sorted(set(exclude_cases)), f"Duplicated items in {exclude_cases}"

Expand Down
70 changes: 70 additions & 0 deletions tests/test_itk_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2020 - 2021 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from parameterized import parameterized

import numpy as np
import torch

from monai.data import ITKReader, ITKWriter

TEST_CASE_1 = [".nii.gz", np.float32, None, None, None, False, torch.zeros(8, 1, 2, 3, 4), (4, 3, 2)]

TEST_CASE_2 = [
".dcm",
np.uint8,
[np.diag(np.ones(4)) * 1.0 for _ in range(8)],
[np.diag(np.ones(4)) * 5.0 for _ in range(8)],
[(10, 10, 2) for _ in range(8)],
True,
torch.zeros(8, 3, 2, 3, 4),
(10, 10, 2, 3),
]


class TestITKWriter(unittest.TestCase):
    """
    Round-trip test for ``ITKWriter``: write each element of a batch to file,
    read it back with ``ITKReader`` and check the resulting shape (and, when
    provided, the affine) against the expected values.
    """

    @parameterized.expand([TEST_CASE_1, TEST_CASE_2])
    def test_saved_content(
        self, output_ext, output_dtype, affine, original_affine, output_shape, resample, data, expected_shape
    ):
        with tempfile.TemporaryDirectory() as tempdir:

            writer = ITKWriter(
                output_dir=tempdir,
                output_postfix="seg",
                output_ext=output_ext,
                output_dtype=output_dtype,
                resample=resample,
            )

            # batch metadata: every key maps to a list of 8 per-item values
            meta_data = {"filename_or_obj": ["testfile" + str(i) + ".nii.gz" for i in range(8)]}
            if output_shape is not None:
                meta_data["spatial_shape"] = output_shape
            if affine is not None:
                meta_data["affine"] = affine
            if original_affine is not None:
                meta_data["original_affine"] = original_affine

            for i in range(8):
                # slice the i-th item's metadata out of the batch before writing
                writer.write(
                    data=data[i], meta_data={k: meta_data[k][i] for k in meta_data} if meta_data is not None else None
                )
                # writer uses separate_folder=True by default: file goes in its own sub-folder
                filepath = os.path.join("testfile" + str(i), "testfile" + str(i) + "_seg" + output_ext)
                reader = ITKReader()
                img = reader.read(data=os.path.join(tempdir, filepath))
                result, meta = reader.get_data(img)
                self.assertTupleEqual(result.shape, expected_shape)
                if affine is not None:
                    # no need to compare the last line of affine matrix
                    np.testing.assert_allclose(meta["affine"][:-1], original_affine[i][:-1])


if __name__ == "__main__":
unittest.main()