Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions monai/data/nifti_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
import torch

from monai.config import DtypeLike
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.utils import compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.utils import GridSampleMode, GridSamplePadMode, optional_import
from monai.utils.type_conversion import convert_data_type

nib, _ = optional_import("nibabel")


def write_nifti(
data: np.ndarray,
data: NdarrayOrTensor,
file_name: str,
affine: Optional[np.ndarray] = None,
affine: Optional[NdarrayOrTensor] = None,
target_affine: Optional[np.ndarray] = None,
resample: bool = True,
output_spatial_shape: Union[Sequence[int], np.ndarray, None] = None,
Expand Down Expand Up @@ -96,13 +98,17 @@ def write_nifti(
If None, use the data type of input data.
output_dtype: data type for saving data. Defaults to ``np.float32``.
"""
if isinstance(data, torch.Tensor):
data, *_ = convert_data_type(data, np.ndarray)
if isinstance(affine, torch.Tensor):
affine, *_ = convert_data_type(affine, np.ndarray)
if not isinstance(data, np.ndarray):
raise AssertionError("input data must be numpy array.")
raise AssertionError("input data must be numpy array or torch tensor.")
dtype = dtype or data.dtype
sr = min(data.ndim, 3)
if affine is None:
affine = np.eye(4, dtype=np.float64)
affine = to_affine_nd(sr, affine)
affine = to_affine_nd(sr, affine) # type: ignore

if target_affine is None:
target_affine = affine
Expand All @@ -122,7 +128,7 @@ def write_nifti(
data = nib.orientations.apply_orientation(data, ornt_transform)
_affine = affine @ nib.orientations.inv_ornt_aff(ornt_transform, data_shape)
if np.allclose(_affine, target_affine, atol=1e-3) or not resample:
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine))
results_img = nib.Nifti1Image(data.astype(output_dtype), to_affine_nd(3, _affine)) # type: ignore
nib.save(results_img, file_name)
return

Expand All @@ -138,7 +144,7 @@ def write_nifti(
while len(output_spatial_shape_) < 3:
output_spatial_shape_ = output_spatial_shape_ + [1]
spatial_shape, channel_shape = data.shape[:3], data.shape[3:]
data_np = data.reshape(list(spatial_shape) + [-1])
data_np: np.ndarray = data.reshape(list(spatial_shape) + [-1]) # type: ignore
data_np = np.moveaxis(data_np, -1, 0) # channel first for pytorch
data_torch = affine_xform(
torch.as_tensor(np.ascontiguousarray(data_np).astype(dtype)).unsqueeze(0),
Expand Down
55 changes: 30 additions & 25 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design
"""
import warnings
from copy import deepcopy
from typing import Any, List, Optional, Sequence, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -85,6 +86,8 @@ class Spacing(Transform):
Resample input image into the specified `pixdim`.
"""

backend = [TransformBackends.TORCH]

def __init__(
self,
pixdim: Union[Sequence[float], float],
Expand Down Expand Up @@ -136,14 +139,14 @@ def __init__(

def __call__(
self,
data_array: np.ndarray,
affine: Optional[np.ndarray] = None,
data_array: NdarrayOrTensor,
affine: Optional[NdarrayOrTensor] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
output_spatial_shape: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor]:
"""
Args:
data_array: in shape (num_channels, H[, W, ...]).
Expand Down Expand Up @@ -171,17 +174,17 @@ def __call__(
data_array (resampled into `self.pixdim`), original affine, current affine.

"""
data_array, *_ = convert_data_type(data_array, np.ndarray) # type: ignore
_dtype = dtype or self.dtype or data_array.dtype
sr = data_array.ndim - 1
sr = int(data_array.ndim - 1)
if sr <= 0:
raise ValueError("data_array must have at least one spatial dimension.")
if affine is None:
# default to identity
affine = np.eye(sr + 1, dtype=np.float64)
affine_ = np.eye(sr + 1, dtype=np.float64)
else:
affine_ = to_affine_nd(sr, affine)
affine, *_ = convert_data_type(affine, np.ndarray)
affine_ = to_affine_nd(sr, affine) # type: ignore

out_d = self.pixdim[:sr]
if out_d.size < sr:
Expand All @@ -197,26 +200,28 @@ def __call__(

# no resampling if it's identity transform
if np.allclose(transform, np.diag(np.ones(len(transform))), atol=1e-3):
output_data = data_array.copy().astype(np.float32)
new_affine = to_affine_nd(affine, new_affine)
return output_data, affine, new_affine
output_data, *_ = convert_data_type(deepcopy(data_array), dtype=_dtype)
new_affine = to_affine_nd(affine, new_affine) # type: ignore

# resample
affine_xform = AffineTransform(
normalized=False,
mode=look_up_option(mode or self.mode, GridSampleMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
align_corners=self.align_corners if align_corners is None else align_corners,
reverse_indexing=True,
)
output_data = affine_xform(
# AffineTransform requires a batch dim
torch.as_tensor(np.ascontiguousarray(data_array).astype(_dtype)).unsqueeze(0),
torch.as_tensor(np.ascontiguousarray(transform).astype(_dtype)),
spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape,
)
output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore
new_affine = to_affine_nd(affine, new_affine)
else:
# resample
affine_xform = AffineTransform(
normalized=False,
mode=look_up_option(mode or self.mode, GridSampleMode),
padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode),
align_corners=self.align_corners if align_corners is None else align_corners,
reverse_indexing=True,
)
data_array_t: torch.Tensor
data_array_t, *_ = convert_data_type(data_array, torch.Tensor, dtype=_dtype) # type: ignore
output_data = affine_xform(
# AffineTransform requires a batch dim
data_array_t.unsqueeze(0),
convert_data_type(transform, torch.Tensor, data_array_t.device, dtype=_dtype)[0],
spatial_size=output_shape if output_spatial_shape is None else output_spatial_shape,
).squeeze(0)
output_data, *_ = convert_to_dst_type(output_data, data_array, dtype=_dtype)
new_affine = to_affine_nd(affine, new_affine) # type: ignore

return output_data, affine, new_affine

Expand Down
16 changes: 9 additions & 7 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ class Spacingd(MapTransform, InvertibleTransform):
:py:class:`monai.transforms.Spacing`
"""

backend = Spacing.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -211,8 +213,8 @@ def __init__(
self.meta_key_postfix = ensure_tuple_rep(meta_key_postfix, len(self.keys))

def __call__(
self, data: Mapping[Union[Hashable, str], Dict[str, np.ndarray]]
) -> Dict[Union[Hashable, str], Union[np.ndarray, Dict[str, np.ndarray]]]:
self, data: Mapping[Union[Hashable, str], Dict[str, NdarrayOrTensor]]
) -> Dict[Union[Hashable, str], Union[NdarrayOrTensor, Dict[str, NdarrayOrTensor]]]:
d: Dict = dict(data)
for key, mode, padding_mode, align_corners, dtype, meta_key, meta_key_postfix in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.meta_keys, self.meta_key_postfix
Expand All @@ -226,7 +228,7 @@ def __call__(
# using affine fetched from d[affine_key]
original_spatial_shape = d[key].shape[1:]
d[key], old_affine, new_affine = self.spacing_transform(
data_array=np.asarray(d[key]),
data_array=d[key],
affine=meta_data["affine"],
mode=mode,
padding_mode=padding_mode,
Expand All @@ -249,7 +251,7 @@ def __call__(
meta_data["affine"] = new_affine
return d

def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndarray]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = deepcopy(dict(data))
for key, dtype in self.key_iterator(d, self.dtype):
transform = self.get_most_recent_transform(d, key)
Expand All @@ -269,15 +271,15 @@ def inverse(self, data: Mapping[Hashable, np.ndarray]) -> Dict[Hashable, np.ndar
inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal)
# Apply inverse
d[key], _, new_affine = inverse_transform(
data_array=np.asarray(d[key]),
affine=meta_data["affine"],
data_array=d[key],
affine=meta_data["affine"], # type: ignore
mode=mode,
padding_mode=padding_mode,
align_corners=False if align_corners == "none" else align_corners,
dtype=dtype,
output_spatial_shape=orig_size,
)
meta_data["affine"] = new_affine
meta_data["affine"] = new_affine # type: ignore
# Remove the applied transform
self.pop_transform(d, key)

Expand Down
Loading