diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index 2cf6e19dbb..9d63378a2b 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -395,7 +395,7 @@ def __init__( self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) self.keep_size = keep_size - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) # zoom box @@ -408,7 +408,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N box_key, extra_info={"zoom": self.zoomer.zoom, "src_spatial_size": src_spatial_size, "type": "box_key"}, ) - d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)( + d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)( # type: ignore d[box_key], src_spatial_size=src_spatial_size ) @@ -431,7 +431,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -461,7 +461,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"] box_inverse_transform = ZoomBox(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size) - d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) + d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -545,7 +545,7 @@ def set_random_state( self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: @@ -568,7 +568,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N box_key, extra_info={"zoom": self.rand_zoom._zoom, "src_spatial_size": src_spatial_size, "type": "box_key"}, ) - d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)( + d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)( # type: ignore d[box_key], src_spatial_size=src_spatial_size ) @@ -595,7 +595,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): @@ -626,7 +626,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"] box_inverse_transform = ZoomBox(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size) - d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) + d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # 
type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -667,7 +667,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d = dict(data) for key in self.image_keys: - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key]) # type: ignore self.push_transform(d, key, extra_info={"type": "image_key"}) for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys): @@ -685,7 +685,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd # flip image, copied from monai.transforms.spatial.dictionary.Flipd if key_type == "image_key": - d[key] = self.flipper(d[key]) + d[key] = self.flipper(d[key]) # type: ignore # flip boxes if key_type == "box_key": @@ -743,7 +743,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N for key in self.image_keys: if self._do_transform: - d[key] = self.flipper(d[key], randomize=False) + d[key] = self.flipper(d[key], randomize=False) # type: ignore self.push_transform(d, key, extra_info={"type": "image_key"}) for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys): @@ -763,7 +763,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd if transform[TraceKeys.DO_TRANSFORM]: # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd if key_type == "image_key": - d[key] = self.flipper(d[key], randomize=False) + d[key] = self.flipper(d[key], randomize=False) # type: ignore # flip boxes if key_type == "box_key": @@ -1271,7 +1271,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable self.push_transform(d, key, extra_info={"spatial_size": spatial_size, "type": "box_key"}) for key in self.image_keys: - d[key] = self.img_rotator(d[key]) + d[key] = self.img_rotator(d[key]) # type: ignore self.push_transform(d, key, extra_info={"type": "image_key"}) return d @@ -1285,7 +1285,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd if key_type == "image_key": inverse_transform = Rotate90(num_times_to_rotate, self.img_rotator.spatial_axes) - d[key] = inverse_transform(d[key]) + d[key] = inverse_transform(d[key]) # type: ignore if key_type == "box_key": spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"] inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes) @@ -1329,7 +1329,7 @@ def __init__( super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys) self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys)) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: # type: ignore self.randomize() d = dict(data) @@ -1357,11 +1357,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable for key in self.image_keys: if self._do_transform: - d[key] = img_rotator(d[key]) + d[key] = img_rotator(d[key]) # type: ignore self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"}) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: # type: ignore d = deepcopy(dict(data)) if self._rand_k % 4 == 0: return d @@ -1376,7 +1376,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> 
Dict[Hashable, Nd # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd if key_type == "image_key": inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) - d[key] = inverse_transform(d[key]) + d[key] = inverse_transform(d[key]) # type: ignore if key_type == "box_key": spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"] inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes) diff --git a/monai/data/image_writer.py b/monai/data/image_writer.py index 2ed22beea6..97b657ad9f 100644 --- a/monai/data/image_writer.py +++ b/monai/data/image_writer.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union import numpy as np @@ -269,6 +270,9 @@ def resample_if_needed( resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype) output_array = resampler(data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape) # convert back at the end + if isinstance(output_array, MetaTensor): + warnings.warn("ignoring the tracking transform info.") + output_array.applied_operations = [] data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore return data_array[0], affine @@ -764,11 +768,11 @@ def resample_and_clip( _min, _max = np.min(data), np.max(data) if len(data.shape) == 3: data = np.moveaxis(data, -1, 0) # to channel first - data = xform(data) # type: ignore + data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0] # type: ignore data = np.moveaxis(data, 0, -1) else: # (H, W) data = np.expand_dims(data, 0) # make a channel - data = xform(data)[0] # type: ignore + data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0][0] # type: ignore if mode != InterpolateMode.NEAREST: data = np.clip(data, _min, _max) return data diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index ea6b377a38..560e5ab2c5 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -183,7 +183,7 @@ def __repr__(self) -> str: @property def meta(self) -> dict: """Get the meta.""" - return self._meta + return self._meta if hasattr(self, "_meta") else self.get_default_meta() @meta.setter def meta(self, d) -> None: @@ -195,7 +195,9 @@ def meta(self, d) -> None: @property def applied_operations(self) -> list: """Get the applied operations.""" - return self._applied_operations + if hasattr(self, "_applied_operations"): + return self._applied_operations + return self.get_default_applied_operations() @applied_operations.setter def applied_operations(self, t) -> None: @@ -215,7 +217,7 @@ def pop_applied_operation(self) -> Any: @property def is_batch(self) -> bool: """Return whether object is part of batch or not.""" - return self._is_batch + return self._is_batch if hasattr(self, "_is_batch") else False @is_batch.setter def is_batch(self, val: bool) -> None: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 01c964e4ab..3ccda3361f 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -183,6 +183,9 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: # else, handle the `MetaTensor` metadata. 
else: meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + # this is not implemented but the network arch may run into this case: + # if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args): + # raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.") ret._copy_meta(meta_args) # If we have a batch of data, then we need to be careful if a slice of @@ -195,17 +198,17 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence: metas = decollate_batch(ret.meta) # if indexing e.g., `batch[0]` if func == torch.Tensor.__getitem__: - idx = args[1] - if isinstance(idx, Sequence): - idx = idx[0] + batch_idx = args[1] + if isinstance(batch_idx, Sequence): + batch_idx = batch_idx[0] # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the # first element will be `slice(None, None, None)` and `Ellipsis`, # respectively. Don't need to do anything with the metadata. - if idx not in (slice(None, None, None), Ellipsis): - meta = metas[idx] + if batch_idx not in (slice(None, None, None), Ellipsis): + meta = metas[batch_idx] # if using e.g., `batch[0:2]`, then `is_batch` should still be # `True`. Also re-collate the remaining elements. - if isinstance(meta, list) and len(meta) > 1: + if isinstance(meta, list): ret.meta = list_data_collate(meta) # if using e.g., `batch[0]` or `batch[0, 1]`, then return single # element from batch, and set `is_batch` to `False`. @@ -243,6 +246,19 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: # we might have 1 or multiple outputs. Might be MetaTensor, might be something # else (e.g., `__repr__` returns a string). # Convert to list (if necessary), process, and at end remove list if one was added. + if ( + hasattr(torch, "return_types") + and hasattr(func, "__name__") + and hasattr(torch.return_types, func.__name__) + and isinstance(getattr(torch.return_types, func.__name__), type) + and isinstance(ret, getattr(torch.return_types, func.__name__)) + ): + # for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like + out_items = MetaTensor.update_meta(ret, func, args, kwargs) + for idx in range(ret.n_fields): + ret[idx].meta = out_items[idx].meta + ret[idx].applied_operations = out_items[idx].applied_operations + return ret if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence): ret = [ret] unpack = True diff --git a/monai/data/png_writer.py b/monai/data/png_writer.py index b1fe7eb327..9fb463e9b9 100644 --- a/monai/data/png_writer.py +++ b/monai/data/png_writer.py @@ -14,7 +14,14 @@ import numpy as np from monai.transforms.spatial.array import Resize -from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import +from monai.utils import ( + InterpolateMode, + convert_data_type, + deprecated, + ensure_tuple_rep, + look_up_option, + optional_import, +) Image, _ = optional_import("PIL", name="Image") @@ -74,9 +81,9 @@ def write_png( if scale is not None: data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1] if scale == np.iinfo(np.uint8).max: - data = (scale * data).astype(np.uint8, copy=False) + data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8, drop_meta=True)[0] elif scale == np.iinfo(np.uint16).max: - data = (scale * data).astype(np.uint16, copy=False) + data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16, drop_meta=True)[0] else: raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]") diff 
--git a/monai/data/utils.py b/monai/data/utils.py index 8faf2defe3..dd863c4898 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -392,6 +392,24 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"): return +def collate_meta_tensor(batch): + """collate a sequence of meta tensor sequences/dictionaries into + a single batched metatensor or a dictionary of batched metatensor""" + if not isinstance(batch, Sequence): + raise NotImplementedError() + elem_0 = first(batch) + if isinstance(elem_0, MetaObj): + collated = default_collate(batch) + collated.meta = default_collate([i.meta or TraceKeys.NONE for i in batch]) + collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch] + collated.is_batch = True + return collated + if isinstance(elem_0, Mapping): + return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0} + # no more recursive search for MetaTensor + return default_collate(batch) + + def list_data_collate(batch: Sequence): """ Enhancement for PyTorch DataLoader default collate. @@ -411,19 +429,9 @@ def list_data_collate(batch: Sequence): for k in elem: key = k data_for_batch = [d[key] for d in data] - ret[key] = default_collate(data_for_batch) - if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): - meta_list = [i.meta or TraceKeys.NONE for i in data_for_batch] - ret[key].meta = default_collate(meta_list) - ops_list = [i.applied_operations or TraceKeys.NONE for i in data_for_batch] - ret[key].applied_operations = default_collate(ops_list) - ret[key].is_batch = True + ret[key] = collate_meta_tensor(data_for_batch) else: - ret = default_collate(data) - if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): - ret.meta = default_collate([i.meta or TraceKeys.NONE for i in data]) - ret.applied_operations = default_collate([i.applied_operations or TraceKeys.NONE for i in data]) - ret.is_batch = True + ret = collate_meta_tensor(data) return ret except RuntimeError as re: re_str = str(re) @@ -550,7 +558,7 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if isinstance(t, MetaObj): t.meta = m t.is_batch = False - for t, m in zip(out_list, decollate_batch(batch.applied_operations)): + for t, m in zip(out_list, batch.applied_operations): if isinstance(t, MetaObj): t.applied_operations = m t.is_batch = False @@ -848,7 +856,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type) """ - affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0] + affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True, drop_meta=True)[0] affine_np = affine_np.copy() if affine_np.ndim != 2: raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.") diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py index c4c4bd891c..b7e13323ec 100644 --- a/monai/inferers/utils.py +++ b/monai/inferers/utils.py @@ -15,12 +15,14 @@ import torch import torch.nn.functional as F +from monai.data.meta_tensor import MetaTensor from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size from monai.transforms import Resize from monai.utils import ( BlendMode, PytorchPadMode, convert_data_type, + convert_to_dst_type, ensure_tuple, fall_back_tuple, look_up_option, @@ -172,7 +174,9 @@ def sliding_window_inference( [slice(int(idx / num_win), 
int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) for idx in slice_range ] - window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + window_data = torch.cat( + [convert_data_type(inputs[win_slice], torch.Tensor, drop_meta=True)[0] for win_slice in unravel_slice] + ).to(sw_device) seg_prob_out = predictor(window_data, *args, **kwargs) # batched patch segmentation # convert seg_prob_out to tuple seg_prob_tuple, this does not allocate new memory. @@ -272,7 +276,10 @@ def sliding_window_inference( final_output = dict(zip(dict_key, output_image_list)) else: final_output = tuple(output_image_list) # type: ignore - return final_output[0] if is_tensor_output else final_output # type: ignore + final_output = final_output[0] if is_tensor_output else final_output # type: ignore + if isinstance(inputs, MetaTensor): + final_output = convert_to_dst_type(final_output, inputs)[0] # type: ignore + return final_output def _get_scan_interval( diff --git a/monai/metrics/utils.py b/monai/metrics/utils.py index c17df7a54a..4722f0f040 100644 --- a/monai/metrics/utils.py +++ b/monai/metrics/utils.py @@ -155,8 +155,8 @@ def get_mask_edges(seg_pred, seg_gt, label_idx: int = 1, crop: bool = True) -> T seg_pred, seg_gt = np.expand_dims(seg_pred, axis=channel_dim), np.expand_dims(seg_gt, axis=channel_dim) box_start, box_end = generate_spatial_bounding_box(np.asarray(seg_pred | seg_gt)) cropper = SpatialCrop(roi_start=box_start, roi_end=box_end) - seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray)[0] - seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray)[0] + seg_pred = convert_data_type(np.squeeze(cropper(seg_pred), axis=channel_dim), np.ndarray, drop_meta=True)[0] + seg_gt = convert_data_type(np.squeeze(cropper(seg_gt), axis=channel_dim), np.ndarray, drop_meta=True)[0] # Do binary erosion and use XOR to get edges edges_pred = binary_erosion(seg_pred) ^ seg_pred diff --git a/monai/networks/layers/spatial_transforms.py b/monai/networks/layers/spatial_transforms.py index 07ddb3ce9d..b349711a9b 100644 --- a/monai/networks/layers/spatial_transforms.py +++ b/monai/networks/layers/spatial_transforms.py @@ -9,11 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from copy import deepcopy from typing import Optional, Sequence, Union import torch import torch.nn as nn +import monai from monai.networks import to_norm_affine from monai.utils import GridSampleMode, GridSamplePadMode, ensure_tuple, look_up_option, optional_import @@ -116,6 +118,10 @@ def grid_pull( ] out: torch.Tensor out = _GridPull.apply(input, grid, interpolation, bound, extrapolate) + if isinstance(input, monai.data.MetaTensor): + out = monai.data.MetaTensor( + out, meta=deepcopy(input.meta), applied_operations=deepcopy(input.applied_operations) + ) return out @@ -217,7 +223,12 @@ def grid_push( if shape is None: shape = tuple(input.shape[2:]) - return _GridPush.apply(input, grid, shape, interpolation, bound, extrapolate) + out: torch.Tensor = _GridPush.apply(input, grid, shape, interpolation, bound, extrapolate) + if isinstance(input, monai.data.MetaTensor): + out = monai.data.MetaTensor( + out, meta=deepcopy(input.meta), applied_operations=deepcopy(input.applied_operations) + ) + return out class _GridCount(torch.autograd.Function): @@ -313,7 +324,12 @@ def grid_count(grid: torch.Tensor, shape=None, interpolation="linear", bound="ze if shape is None: shape = tuple(grid.shape[2:]) - return _GridCount.apply(grid, shape, interpolation, bound, extrapolate) + if isinstance(grid, monai.data.MetaTensor): + out = monai.data.MetaTensor( + out, meta=deepcopy(grid.meta), applied_operations=deepcopy(grid.applied_operations) + ) + return out class _GridGrad(torch.autograd.Function): @@ -408,7 +424,12 @@ def grid_grad(input: torch.Tensor, grid: torch.Tensor, interpolation="linear", b for i in ensure_tuple(interpolation) ] - return _GridGrad.apply(input, grid, interpolation, bound, extrapolate) + out: torch.Tensor = _GridGrad.apply(input, grid, interpolation, bound, extrapolate) + if isinstance(input, monai.data.MetaTensor): + out = monai.data.MetaTensor( + out, meta=deepcopy(input.meta), applied_operations=deepcopy(input.applied_operations) + ) + return out class AffineTransform(nn.Module): diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index e2291ea7a6..d1bc6d84df 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -607,6 +607,7 @@ rescale_array_int_max, rescale_instance_array, resize_center, + scale_affine, weighted_patch_samples, zero_margins, ) diff --git a/monai/transforms/croppad/array.py b/monai/transforms/croppad/array.py index a8fd4a0243..0225db6abe 100644 --- a/monai/transforms/croppad/array.py +++ b/monai/transforms/croppad/array.py @@ -113,6 +113,7 @@ def _pt_pad(img: torch.Tensor, all_pad_width, mode, **kwargs) -> torch.Tensor: return pad_pt(img.unsqueeze(0), pt_pad_width, mode=mode, **kwargs).squeeze(0) def _forward_meta(self, out, img, to_pad): + if not isinstance(out, MetaTensor) or not isinstance(img, MetaTensor): return out meta_dict = copy.deepcopy(img.meta) @@ -442,8 +443,8 @@ def get_im_center(img: torch.Tensor): @staticmethod def calculate_slices_from_center_and_size(roi_center, roi_size) -> List[slice]: roi_slices = [] - roi_center = [roi_center] if not isinstance(roi_center, Iterable) else roi_center - roi_size = [roi_size] if not isinstance(roi_size, Iterable) else roi_size + roi_center = roi_center if isinstance(roi_center, Iterable) else [roi_center] + roi_size = roi_size if isinstance(roi_size, Iterable) else [roi_size] for c, s in zip(roi_center, roi_size): c = c.item() if isinstance(c, torch.Tensor) else c s = s.item() if isinstance(s, torch.Tensor) else s
@@ -460,8 +461,8 @@ def calculate_slices_from_center_and_size(roi_center, roi_size) -> List[slice]: @staticmethod def calculate_slices_from_start_and_end(roi_start, roi_end) -> List[slice]: # start +ve, end <= start - roi_start = [roi_start] if not isinstance(roi_start, Iterable) else roi_start - roi_end = [roi_end] if not isinstance(roi_end, Iterable) else roi_end + roi_start = roi_start if isinstance(roi_start, Iterable) else [roi_start] + roi_end = roi_end if isinstance(roi_end, Iterable) else [roi_end] roi_start = [max(r, 0) for r in roi_start] # type: ignore roi_end = [max(r, s) for r, s in zip(roi_start, roi_end)] # type: ignore roi_slices = [slice(s, e) for s, e in zip(roi_start, roi_end)] @@ -508,7 +509,7 @@ def _forward_meta(self, out, img, slices): out.meta = meta_dict out.meta["affine"] = img.affine @ convert_to_dst_type(mat, img.affine)[0] # out.meta["original_affine"] = img.affine - # out.meta["spatial_shape"] = out.shape[1:]f + # out.meta["spatial_shape"] = out.shape[1:] return out def _forward(self, img: torch.Tensor, slices: Optional[Tuple[slice, ...]]) -> torch.Tensor: @@ -1411,6 +1412,9 @@ def __call__( def inverse(self, img: torch.Tensor) -> torch.Tensor: transform = self.pop_transform(img) + return self.inverse_transform(img, transform) + + def inverse_transform(self, img: torch.Tensor, transform) -> torch.Tensor: if not isinstance(img, MetaTensor): raise RuntimeError() # we joined the cropping and padding, so put them back before calling the inverse diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index c1ae43e977..3e6149e346 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -10,6 +10,7 @@ # limitations under the License. import os +import warnings from contextlib import contextmanager from typing import Any, Hashable, Mapping, Optional, Tuple @@ -126,9 +127,6 @@ def push_transform( Returns: None, but data has been updated to store the applied transformation. - - Raises: - - RuntimeError: data is neither `MetaTensor` nor dictionary """ if not self.tracing: return @@ -147,18 +145,24 @@ def push_transform( data[self.trace_key(key)] = [] data[self.trace_key(key)].append(info) else: - raise RuntimeError("`data` should be either `MetaTensor` or dictionary.") + warnings.warn(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}. {info} not tracked.") def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" - xform_name = transform.get(TraceKeys.CLASS_NAME, "") xform_id = transform.get(TraceKeys.ID, "") if xform_id == id(self): return + # TraceKeys.NONE to skip the id check + if xform_id == TraceKeys.NONE: + return + xform_name = transform.get(TraceKeys.CLASS_NAME, "") # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: return - raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") + raise RuntimeError( + f"Error {self.__class__.__name__} getting the most recently " + f"applied invertible transform {xform_name} {xform_id} != {id(self)}."
+ ) def get_most_recent_transform(self, data, key: Hashable = None, check: bool = True, pop: bool = False): """ @@ -206,6 +210,8 @@ def get_most_recent_transform(self, data, key: Hashable = None, check: bool = Tr all_transforms = data[key].applied_operations else: all_transforms = data[self.trace_key(key)] + else: + raise ValueError(f"`data` should be either `MetaTensor` or dictionary, got {type(data)}.") if check: self.check_transforms_match(all_transforms[-1]) return all_transforms.pop() if pop else all_transforms[-1] @@ -250,7 +256,7 @@ def pop_transform(self, data, key: Hashable = None, check: bool = True): @contextmanager def trace_transform(self, to_trace: bool): - """Temporarily set the tracing status of a transfrom with a context manager.""" + """Temporarily set the tracing status of a transform with a context manager.""" prev = self.tracing self.tracing = to_trace yield diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 6625a9d791..16427122d2 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -23,6 +23,7 @@ from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike from monai.data.csv_saver import CSVSaver +from monai.data.meta_tensor import MetaTensor from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( Activations, @@ -623,12 +624,22 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.device, self.post_func, ): - transform_key = InvertibleTransform.trace_key(orig_key) - if transform_key not in d: - warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") - continue - - transform_info = d[transform_key] + if isinstance(d[key], MetaTensor): + if orig_key not in d: + warnings.warn(f"transform info of `{orig_key}` is not available in MetaTensor {key}.") + continue + else: + transform_key = InvertibleTransform.trace_key(orig_key) + if transform_key not in d: + warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") + continue + + if orig_key in d and isinstance(d[orig_key], MetaTensor): + transform_info = d[orig_key].applied_operations + meta_info = d[orig_key].meta + else: + transform_info = d[InvertibleTransform.trace_key(orig_key)] + meta_info = d.get(orig_meta_key or f"{orig_key}_{meta_key_postfix}", {}) if nearest_interp: transform_info = convert_inverse_interp_mode( trans_info=deepcopy(transform_info), mode="nearest", align_corners=None @@ -638,23 +649,28 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: if isinstance(input, torch.Tensor): input = input.detach() + if not isinstance(input, MetaTensor): + input = MetaTensor(input) + input.applied_operations = transform_info + input.meta = meta_info + # construct the input dict data - input_dict = {orig_key: input, transform_key: transform_info} - orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" - if orig_meta_key in d: - input_dict[orig_meta_key] = d[orig_meta_key] + input_dict = {orig_key: input} with allow_missing_keys_mode(self.transform): # type: ignore inverted = self.transform.inverse(input_dict) # save the inverted data - d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key]) + if to_tensor and not isinstance(inverted[orig_key], MetaTensor): + inverted_data = self._totensor(inverted[orig_key]) + else: + inverted_data = inverted[orig_key] + d[key] = 
post_func(inverted_data.to(device)) # save the inverted meta dict if orig_meta_key in d: meta_key = meta_key or f"{key}_{meta_key_postfix}" d[meta_key] = inverted.get(orig_meta_key) - return d diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index c4ad1eb40e..b0cc6a8e29 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -13,8 +13,9 @@ https://github.com/Project-MONAI/MONAI/wiki/MONAI_Design """ import warnings +from copy import deepcopy from enum import Enum -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import numpy as np import torch @@ -27,10 +28,10 @@ from monai.data.utils import AFFINE_TOL, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull from monai.networks.utils import meshgrid_ij, normalize_transform -from monai.transforms.croppad.array import CenterSpatialCrop, Pad +from monai.transforms.croppad.array import CenterSpatialCrop, ResizeWithPadOrCrop from monai.transforms.intensity.array import GaussianSmooth from monai.transforms.inverse import InvertibleTransform -from monai.transforms.transform import Randomizable, RandomizableTransform, ThreadUnsafe, Transform +from monai.transforms.transform import Randomizable, RandomizableTransform, Transform from monai.transforms.utils import ( convert_pad_mode, create_control_grid, @@ -40,8 +41,9 @@ create_shear, create_translate, map_spatial_axes, + scale_affine, ) -from monai.transforms.utils_pytorch_numpy_unification import allclose, moveaxis +from monai.transforms.utils_pytorch_numpy_unification import allclose, linalg_inv, moveaxis from monai.utils import ( GridSampleMode, GridSamplePadMode, @@ -297,12 +299,13 @@ def __call__( if not USE_COMPILED: _t_l = normalize_transform( in_spatial_size, xform.device, xform.dtype, align_corners=True # type: ignore - ) + )[0] xform = _t_l @ xform # type: ignore affine_xform = Affine( affine=xform, spatial_size=spatial_size, normalized=True, image_only=True, dtype=_dtype ) - img = affine_xform(img, mode=mode, padding_mode=padding_mode) + with affine_xform.trace_transform(False): + img = affine_xform(img, mode=mode, padding_mode=padding_mode) else: affine_xform = AffineTransform( normalized=False, @@ -670,10 +673,10 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor: return data -class Flip(Transform): +class Flip(InvertibleTransform): """ Reverses the order of elements along the given spatial axis. Preserves shape. - Uses ``np.flip`` in practice. See numpy.flip for additional details: + Uses ``np.flip``/``torch.flip``. See numpy.flip for additional details: https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html. 
Args: @@ -690,17 +693,44 @@ class Flip(Transform): def __init__(self, spatial_axis: Optional[Union[Sequence[int], int]] = None) -> None: self.spatial_axis = spatial_axis - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def forward_meta(self, img_meta, shape, axes): + # shape and axes include the channel dim + m = dict(img_meta) + affine = m.get("affine") + if affine is None: + return m + mat = convert_to_dst_type(torch.eye(len(affine)), affine)[0] + for axis in axes: + sp = axis - 1 + mat[sp, sp], mat[sp, -1] = mat[sp, sp] * -1, shape[axis] - 1 + m["affine"] = affine @ mat + return m + + def forward_image(self, img, axes) -> torch.Tensor: + return torch.flip(img, axes) + + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + img: channel first array, must have shape: (num_channels, H[, W, ..., ]) """ - if isinstance(img, np.ndarray): - return np.ascontiguousarray(np.flip(img, map_spatial_axes(img.ndim, self.spatial_axis))) - return torch.flip(img, map_spatial_axes(img.ndim, self.spatial_axis)) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + axes = map_spatial_axes(img.ndim, self.spatial_axis) + out = self.forward_image(img, axes) + if isinstance(out, MetaTensor): + out.meta = self.forward_meta(out.meta, out.shape, axes) + self.push_transform(out) + return out + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + self.pop_transform(data) + flipper = Flip(spatial_axis=self.spatial_axis) + with flipper.trace_transform(False): + return flipper(data) -class Resize(Transform): +class Resize(InvertibleTransform): """ Resize the input image to given spatial size (with scaling, not cropping/padding). Implemented using :py:class:`torch.nn.functional.interpolate`. @@ -752,12 +782,12 @@ def __init__( def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, mode: Optional[Union[InterpolateMode, str]] = None, align_corners: Optional[bool] = None, anti_aliasing: Optional[bool] = None, anti_aliasing_sigma: Union[Sequence[float], float, None] = None, - ) -> NdarrayOrTensor: + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). 
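Review note: with the hunk above, `Flip` becomes a metadata-aware `InvertibleTransform`. A minimal usage sketch, not part of the patch, assuming the default identity affine that `MetaTensor` supplies when none is given:

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Flip

    img = MetaTensor(torch.arange(6.0).reshape(1, 2, 3))  # channel-first
    flip = Flip(spatial_axis=0)
    out = flip(img)               # data flipped; forward_meta reflects the axis in out.meta["affine"]
    restored = flip.inverse(out)  # pops the tracked op, then re-applies Flip untraced
    assert torch.equal(restored.as_tensor(), img.as_tensor())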
@@ -785,8 +815,8 @@ def __call__( anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma + input_ndim = img.ndim - 1 # spatial ndim if self.size_mode == "all": - input_ndim = img.ndim - 1 # spatial ndim output_ndim = len(ensure_tuple(self.spatial_size)) if output_ndim > input_ndim: input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1) @@ -806,7 +836,10 @@ def __call__( if tuple(img.shape[1:]) == spatial_size_: # spatial shape is already the desired return img - img_, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float) + + original_sp_size = img.shape[1:] + img_: MetaTensor = convert_data_type(img, MetaTensor, dtype=torch.float)[0] + if anti_aliasing and any(x < y for x, y in zip(spatial_size_, img_.shape[1:])): factors = torch.div(torch.Tensor(list(img_.shape[1:])), torch.Tensor(spatial_size_)) if anti_aliasing_sigma is None: @@ -820,17 +853,56 @@ def __call__( anti_aliasing_filter = GaussianSmooth(sigma=anti_aliasing_sigma) img_ = anti_aliasing_filter(img_) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value + _align_corners = self.align_corners if align_corners is None else align_corners + resized = torch.nn.functional.interpolate( - input=img_.unsqueeze(0), - size=spatial_size_, - mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, - align_corners=self.align_corners if align_corners is None else align_corners, + input=img_.unsqueeze(0), size=spatial_size_, mode=_mode, align_corners=_align_corners ) out, *_ = convert_to_dst_type(resized.squeeze(0), img) + if not isinstance(out, MetaTensor): + return out + out.meta = self.forward_meta(img.meta, original_sp_size, spatial_size_) # type: ignore + self.push_transform( + out, + orig_size=original_sp_size, + extra_info={ + "mode": _mode, + "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, + "new_dim": len(original_sp_size) - input_ndim, # additional dims appended + }, + ) return out + def forward_meta(self, img_meta, spatial_size, new_spatial_size): + affine = convert_data_type(img_meta["affine"], torch.Tensor)[0] + meta = deepcopy(img_meta) + meta["affine"] = scale_affine(affine, spatial_size, new_spatial_size) + return meta -class Rotate(Transform, ThreadUnsafe): + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + transform = self.pop_transform(data) + return self.inverse_transform(data, transform) + + def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: + orig_size = transform[TraceKeys.ORIG_SIZE] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + xform = Resize( + spatial_size=orig_size, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners + ) + with xform.trace_transform(False): + data = xform(data) + for _ in range(transform[TraceKeys.EXTRA_INFO]["new_dim"]): + data = data.squeeze(-1) # remove the additional dims + return data + + +class Rotate(InvertibleTransform): """ Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`. 
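Review note: `Resize` now records `orig_size`, `mode`, `align_corners` and `new_dim` in the trace so that `inverse` can rebuild the original spatial shape. A sketch, not part of the patch, assuming the input carries an `affine` in its meta (a `MetaTensor` provides one by default):

    import torch
    from monai.data import MetaTensor
    from monai.transforms import Resize

    img = MetaTensor(torch.rand(1, 32, 32))
    resize = Resize(spatial_size=(16, 16), mode="area")
    out = resize(img)           # affine rescaled via scale_affine; op pushed onto applied_operations
    back = resize.inverse(out)  # re-resizes to the orig_size recorded in the trace
    print(back.shape)           # torch.Size([1, 32, 32])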
@@ -869,16 +941,15 @@ def __init__( self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) self.align_corners = align_corners self.dtype = dtype - self._rotation_matrix: Optional[NdarrayOrTensor] = None def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, dtype: Union[DtypeLike, torch.dtype] = None, - ) -> NdarrayOrTensor: + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D]. @@ -900,14 +971,15 @@ def __call__( ValueError: When ``img`` spatially is not one of [2D, 3D]. """ - _dtype = dtype or self.dtype or img.dtype - - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=_dtype) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + _dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor) + img_t: MetaTensor = convert_data_type(img, MetaTensor, dtype=_dtype)[0] im_shape = np.asarray(img_t.shape[1:]) # spatial dimensions input_ndim = len(im_shape) if input_ndim not in (2, 3): - raise ValueError(f"Unsupported img dimension: {input_ndim}, available options are [2, 3].") + raise ValueError(f"Unsupported image dimension: {input_ndim}, available options are [2, 3].") _angle = ensure_tuple_rep(self.angle, 1 if input_ndim == 2 else 3) transform = create_rotate(input_ndim, _angle) shift = create_translate(input_ndim, ((im_shape - 1) / 2).tolist()) @@ -923,29 +995,73 @@ def __call__( transform = shift @ transform @ shift_1 transform_t, *_ = convert_to_dst_type(transform, img_t) - + _mode = look_up_option(mode or self.mode, GridSampleMode).value + _padding_mode = look_up_option(padding_mode or self.padding_mode, GridSamplePadMode).value + _align_corners = self.align_corners if align_corners is None else align_corners xform = AffineTransform( normalized=False, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, + mode=_mode, + padding_mode=_padding_mode, + align_corners=_align_corners, reverse_indexing=True, ) output: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=output_shape).float().squeeze(0) - self._rotation_matrix = transform - out: NdarrayOrTensor out, *_ = convert_to_dst_type(output, dst=img, dtype=output.dtype) + if not isinstance(out, MetaTensor): + return out + out.meta = self.forward_meta(img.meta, transform_t) # type: ignore + self.push_transform( + out, + orig_size=img_t.shape[1:], + extra_info={ + "rot_mat": transform, + "mode": _mode, + "padding_mode": _padding_mode, + "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, + "dtype": str(_dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32 + }, + ) return out - def get_rotation_matrix(self) -> Optional[NdarrayOrTensor]: - """ - Get the most recently applied rotation matrix - This is not thread-safe. 
- """ - return self._rotation_matrix + def forward_meta(self, img_meta, rotate_mat): + meta_dict = deepcopy(img_meta) + affine = convert_data_type(img_meta["affine"], torch.Tensor, drop_meta=True)[0] + mat = to_affine_nd(len(affine) - 1, rotate_mat) + meta_dict["affine"] = affine @ convert_to_dst_type(mat, affine, drop_meta=True)[0] + return meta_dict + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + transform = self.pop_transform(data) + return self.inverse_transform(data, transform) + + def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + dtype = transform[TraceKeys.EXTRA_INFO]["dtype"] + inv_rot_mat = linalg_inv(fwd_rot_mat) -class Zoom(Transform): + xform = AffineTransform( + normalized=False, + mode=mode, + padding_mode=padding_mode, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, + reverse_indexing=True, + ) + img_t: torch.Tensor = convert_data_type(data, MetaTensor, dtype=dtype)[0] + transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) + sp_size = transform[TraceKeys.ORIG_SIZE] + out: torch.Tensor = xform(img_t.unsqueeze(0), transform_t, spatial_size=sp_size).float().squeeze(0) + out = convert_to_dst_type(out, dst=data, dtype=out.dtype)[0] + if isinstance(data, MetaTensor): + out.meta = self.forward_meta(data.meta, transform_t) # type: ignore + return out + + +class Zoom(InvertibleTransform): """ Zooms an ND image using :py:class:`torch.nn.functional.interpolate`. For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html. @@ -996,11 +1112,11 @@ def __init__( def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, - ) -> NdarrayOrTensor: + ) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]). 
@@ -1020,39 +1136,78 @@ def __call__( See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html """ - img_t, *_ = convert_data_type(img, torch.Tensor, dtype=torch.float32) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + img_t: torch.Tensor = convert_data_type(img, MetaTensor, dtype=torch.float32)[0] _zoom = ensure_tuple_rep(self.zoom, img.ndim - 1) # match the spatial image dim + _mode = look_up_option(self.mode if mode is None else mode, InterpolateMode).value + _align_corners = self.align_corners if align_corners is None else align_corners + _padding_mode = padding_mode or self.padding_mode + zoomed: NdarrayOrTensor = torch.nn.functional.interpolate( # type: ignore recompute_scale_factor=True, input=img_t.unsqueeze(0), scale_factor=list(_zoom), - mode=look_up_option(self.mode if mode is None else mode, InterpolateMode).value, - align_corners=self.align_corners if align_corners is None else align_corners, + mode=_mode, + align_corners=_align_corners, ) zoomed = zoomed.squeeze(0) + orig_size, z_size = img_t.shape, zoomed.shape - if self.keep_size and not np.allclose(img_t.shape, zoomed.shape): - - pad_vec = [(0, 0)] * len(img_t.shape) - slice_vec = [slice(None)] * len(img_t.shape) - for idx, (od, zd) in enumerate(zip(img_t.shape, zoomed.shape)): - diff = od - zd - half = abs(diff) // 2 - if diff > 0: # need padding - pad_vec[idx] = (half, diff - half) - elif diff < 0: # need slicing - slice_vec[idx] = slice(half, half + od) + out, *_ = convert_to_dst_type(zoomed, dst=img) + out.meta = self.forward_meta(img.meta, orig_size[1:], z_size[1:]) # type: ignore + do_pad_crop = self.keep_size and not np.allclose(orig_size, z_size) + if do_pad_crop: + _pad_crop = ResizeWithPadOrCrop(spatial_size=img_t.shape[1:], mode=_padding_mode) + out = _pad_crop(out) + self.push_transform( + out, + orig_size=orig_size[1:], + extra_info={ + "mode": _mode, + "align_corners": _align_corners if _align_corners is not None else TraceKeys.NONE, + "do_padcrop": do_pad_crop, + }, + ) + return out - padder = Pad(pad_vec, padding_mode or self.padding_mode) - zoomed = padder(zoomed) # type: ignore - zoomed = zoomed[tuple(slice_vec)] + def forward_meta(self, img_meta, spatial_size, new_spatial_size): + affine = convert_data_type(img_meta["affine"], torch.Tensor)[0] + meta = deepcopy(img_meta) + meta["affine"] = scale_affine(affine, spatial_size, new_spatial_size) + return meta - out, *_ = convert_to_dst_type(zoomed, dst=img) + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + transform = self.pop_transform(data) + return self.inverse_transform(data, transform) + + def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + if transform[TraceKeys.EXTRA_INFO]["do_padcrop"]: + orig_size = transform[TraceKeys.ORIG_SIZE] + pad_or_crop = ResizeWithPadOrCrop(spatial_size=orig_size, mode="edge") + xform = data.applied_operations[-1] # remove the padding cropping + xform[TraceKeys.ID] = TraceKeys.NONE + xform[TraceKeys.EXTRA_INFO]["pad_info"][TraceKeys.ID] = TraceKeys.NONE + xform[TraceKeys.EXTRA_INFO]["crop_info"][TraceKeys.ID] = TraceKeys.NONE + data = pad_or_crop.inverse(data) + # Create inverse transform + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + inverse_transform = Resize(spatial_size=transform[TraceKeys.ORIG_SIZE]) + # Apply inverse + with inverse_transform.trace_transform(False): + out = inverse_transform( + data, mode=mode, align_corners=None if align_corners == TraceKeys.NONE else align_corners + ) return out
-class Rotate90(Transform): +class Rotate90(InvertibleTransform): """ Rotate an array by 90 degrees in the plane specified by `axes`. See np.rot90 for additional details: @@ -1076,18 +1231,58 @@ def __init__(self, k: int = 1, spatial_axes: Tuple[int, int] = (0, 1)) -> None: raise ValueError("spatial_axes must be 2 int numbers to indicate the axes to rotate 90 degrees.") self.spatial_axes = spatial_axes_ - def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), """ - rot90: Callable = torch.rot90 if isinstance(img, torch.Tensor) else np.rot90 # type: ignore - out: NdarrayOrTensor = rot90(img, self.k, map_spatial_axes(img.ndim, self.spatial_axes)) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + axes = map_spatial_axes(img.ndim, self.spatial_axes) + ori_shape = img.shape[1:] + out: NdarrayOrTensor = torch.rot90(img, self.k, axes) out, *_ = convert_data_type(out, dtype=img.dtype) + if not isinstance(out, MetaTensor): + return MetaTensor(out) + out.meta = self.forward_meta(img.meta, ori_shape, out.shape[1:], axes, self.k) # type: ignore + self.push_transform(out, extra_info={"axes": [d - 1 for d in axes], "k": self.k}) # compensate spatial dim return out + def forward_meta(self, img_meta, spatial_size, new_spatial_size, axes, k): + meta_dict = deepcopy(img_meta) + affine = convert_data_type(img_meta["affine"], torch.Tensor)[0] + r, sp_r = len(affine) - 1, len(spatial_size) + mat = to_affine_nd(r, create_translate(sp_r, [-float(d - 1) / 2 for d in new_spatial_size])) + if sp_r == 2: + rot90 = to_affine_nd(r, create_rotate(sp_r, [-np.pi / 2])) + else: + idx = {0, 1, 2} - set(axes) + angle = [0, 0, 0] + angle[2 - idx.pop()] = -np.pi / 2 + rot90 = to_affine_nd(r, create_rotate(sp_r, angle)) + for _ in range(k): + mat = rot90 @ mat + mat = to_affine_nd(r, create_translate(sp_r, [float(d - 1) / 2 for d in spatial_size])) @ mat + meta_dict["affine"] = affine @ convert_to_dst_type(mat, affine)[0] + return meta_dict + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + transform = self.pop_transform(data) + return self.inverse_transform(data, transform) + + def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor: + axes = transform[TraceKeys.EXTRA_INFO]["axes"] + k = transform[TraceKeys.EXTRA_INFO]["k"] + inv_k = 4 - k % 4 + xform = Rotate90(k=inv_k, spatial_axes=axes) + with xform.trace_transform(False): + data = xform(data) + return data + -class RandRotate90(RandomizableTransform): +class RandRotate90(RandomizableTransform, InvertibleTransform): """ With probability `prob`, input arrays are rotated by 90 degrees in the plane specified by `spatial_axes`.
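Review note: the `Rand*` inverses below lean on the new `TraceKeys.NONE` escape hatch added to `check_transforms_match` in the `monai/transforms/inverse.py` hunk: overwriting a stored op's id lets a freshly constructed transform pop and invert an op that a different instance recorded. The recurring pattern, annotated (excerpted from `RandRotate90.inverse`):

    data.applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE  # disarm the instance-id check
    rotate_xform = self.pop_transform(data)                     # any instance may now pop the op
    return Rotate90().inverse_transform(data, rotate_xform)     # a fresh instance runs the inverse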
@@ -1116,7 +1311,7 @@ def randomize(self, data: Optional[Any] = None) -> None: return None self._rand_k = self.R.randint(self.max_k) + 1 - def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -1125,13 +1320,24 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen if randomize: self.randomize() - if not self._do_transform: - return img + if self._do_transform: + out = Rotate90(self._rand_k, self.spatial_axes)(img) + else: + out = MetaTensor(img) if not isinstance(img, MetaTensor) and get_track_meta() else img + self.push_transform(out) + return out - return Rotate90(self._rand_k, self.spatial_axes)(img) + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + if not self.pop_transform(data)[TraceKeys.DO_TRANSFORM]: + return data + data.applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE + rotate_xform = self.pop_transform(data) + return Rotate90().inverse_transform(data, rotate_xform) -class RandRotate(RandomizableTransform): +class RandRotate(RandomizableTransform, InvertibleTransform): """ Randomly rotate the input arrays. @@ -1202,9 +1408,10 @@ def randomize(self, data: Optional[Any] = None) -> None: self.y = self.R.uniform(low=self.range_y[0], high=self.range_y[1]) self.z = self.R.uniform(low=self.range_z[0], high=self.range_z[1]) + @deprecated_arg(name="get_matrix", since="0.9", msg_suffix="please use `img.meta` instead.") def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, align_corners: Optional[bool] = None, @@ -1227,27 +1434,36 @@ def __call__( If None, use the data type of input data. To be compatible with other modules, the output data type is always ``np.float32``. randomize: whether to execute `randomize()` function first, default to True. - get_matrix: whether to return the rotated image and rotate matrix together, default to False. 
""" if randomize: self.randomize() - if not self._do_transform: - return img + if self._do_transform: + rotator = Rotate( + angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), + keep_size=self.keep_size, + mode=look_up_option(mode or self.mode, GridSampleMode), + padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), + align_corners=self.align_corners if align_corners is None else align_corners, + dtype=dtype or self.dtype or img.dtype, + ) + out = rotator(img) + else: + out = MetaTensor(img) if not isinstance(img, MetaTensor) and get_track_meta() else img + self.push_transform(out) + return out - rotator = Rotate( - angle=self.x if img.ndim == 3 else (self.x, self.y, self.z), - keep_size=self.keep_size, - mode=look_up_option(mode or self.mode, GridSampleMode), - padding_mode=look_up_option(padding_mode or self.padding_mode, GridSamplePadMode), - align_corners=self.align_corners if align_corners is None else align_corners, - dtype=dtype or self.dtype or img.dtype, - ) - img = rotator(img) - return (img, rotator.get_rotation_matrix()) if get_matrix else img + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + if not self.pop_transform(data)[TraceKeys.DO_TRANSFORM]: + return data + data.applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE + rotate_xform = self.pop_transform(data) + return Rotate(0).inverse_transform(data, rotate_xform) -class RandFlip(RandomizableTransform): +class RandFlip(RandomizableTransform, InvertibleTransform): """ Randomly flips the image along axes. Preserves shape. See numpy.flip for additional details. @@ -1264,7 +1480,7 @@ def __init__(self, prob: float = 0.1, spatial_axis: Optional[Union[Sequence[int] RandomizableTransform.__init__(self, prob) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: img: channel first array, must have shape: (num_channels, H[, W, ..., ]), @@ -1272,14 +1488,19 @@ def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTen """ if randomize: self.randomize(None) + out = self.flipper(img) if self._do_transform else img + out = MetaTensor(out) if not isinstance(out, MetaTensor) and get_track_meta() else out + self.push_transform(out) + return out - if not self._do_transform: - return img - - return self.flipper(img) + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + if not transform[TraceKeys.DO_TRANSFORM]: + return data + return self.flipper.inverse(data) -class RandAxisFlip(RandomizableTransform): +class RandAxisFlip(RandomizableTransform, InvertibleTransform): """ Randomly select a spatial axis and flip along it. See numpy.flip for additional details. 
@@ -1295,6 +1516,7 @@ class RandAxisFlip(RandomizableTransform): def __init__(self, prob: float = 0.1) -> None: RandomizableTransform.__init__(self, prob) self._axis: Optional[int] = None + self.flipper = Flip(spatial_axis=self._axis) def randomize(self, data: NdarrayOrTensor) -> None: super().randomize(None) @@ -1302,22 +1524,32 @@ def randomize(self, data: NdarrayOrTensor) -> None: return None self._axis = self.R.randint(data.ndim - 1) - def __call__(self, img: NdarrayOrTensor, randomize: bool = True) -> NdarrayOrTensor: + def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor: """ Args: - img: channel first array, must have shape: (num_channels, H[, W, ..., ]), + img: channel first array, must have shape: (num_channels, H[, W, ..., ]) randomize: whether to execute `randomize()` function first, default to True. """ if randomize: self.randomize(data=img) - if not self._do_transform: - return img + if self._do_transform: + self.flipper.spatial_axis = self._axis + out = self.flipper(img) + else: + out = MetaTensor(img) if not isinstance(img, MetaTensor) and get_track_meta() else img + self.push_transform(out, extra_info={"axes": self._axis}) + return out - return Flip(spatial_axis=self._axis)(img) + def inverse(self, data: torch.Tensor) -> torch.Tensor: + transform = self.pop_transform(data) + if not transform[TraceKeys.DO_TRANSFORM]: + return data + self.flipper.spatial_axis = transform[TraceKeys.EXTRA_INFO]["axes"] + return self.flipper.inverse(data) -class RandZoom(RandomizableTransform): +class RandZoom(RandomizableTransform, InvertibleTransform): """ Randomly zooms input arrays with given probability within given zoom range. @@ -1392,12 +1624,12 @@ def randomize(self, img: NdarrayOrTensor) -> None: def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, mode: Optional[Union[InterpolateMode, str]] = None, padding_mode: Optional[Union[NumpyPadMode, PytorchPadMode, str]] = None, align_corners: Optional[bool] = None, randomize: bool = True, - ) -> NdarrayOrTensor: + ) -> torch.Tensor: """ Args: img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
@@ -1422,16 +1654,27 @@ def __call__( self.randomize(img=img) if not self._do_transform: - return img + out = MetaTensor(img) if not isinstance(img, MetaTensor) and get_track_meta() else img + else: + out = Zoom( + self._zoom, + keep_size=self.keep_size, + mode=look_up_option(mode or self.mode, InterpolateMode), + padding_mode=padding_mode or self.padding_mode, + align_corners=self.align_corners if align_corners is None else align_corners, + **self.kwargs, + )(img) + self.push_transform(out) + return out - return Zoom( - self._zoom, - keep_size=self.keep_size, - mode=look_up_option(mode or self.mode, InterpolateMode), - padding_mode=padding_mode or self.padding_mode, - align_corners=self.align_corners if align_corners is None else align_corners, - **self.kwargs, - )(img) + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + if not self.pop_transform(data)[TraceKeys.DO_TRANSFORM]: + return data + data.applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE + xform = self.pop_transform(data) + return Zoom(self._zoom).inverse_transform(data, xform) class AffineGrid(Transform): @@ -1531,8 +1774,8 @@ def __call__( affine = self.affine grid, *_ = convert_data_type(grid, torch.Tensor, device=_device, dtype=self.dtype or grid.dtype) - affine, *_ = convert_to_dst_type(affine, grid) - + affine = to_affine_nd(len(grid) - 1, affine) + affine, *_ = convert_to_dst_type(affine, grid, drop_meta=True) grid = (affine @ grid.reshape((grid.shape[0], -1))).reshape([-1] + list(grid.shape[1:])) return grid, affine @@ -1602,7 +1845,7 @@ def __init__( self.scale_params: Optional[List[float]] = None self.device = device - self.affine: Optional[NdarrayOrTensor] = None + self.affine: Optional[NdarrayOrTensor] = torch.eye(4, dtype=torch.float64) def _get_rand_param(self, param_range, add_scalar: float = 0.0): out_param = [] @@ -1791,6 +2034,8 @@ def __call__( _dtype = dtype or self.dtype or img.dtype img_t = img if isinstance(img, torch.Tensor) else torch.as_tensor(img) img_t, *_ = convert_data_type(img_t, dtype=_dtype, device=_device) + if isinstance(grid, MetaTensor): + grid = grid.as_tensor() # drops any meta/tracking transform info grid_t, *_ = convert_to_dst_type(grid, img_t) if grid_t is grid: # copy if needed (convert_data_type converts to contiguous) grid_t = grid_t.clone(memory_format=torch.contiguous_format) @@ -1832,7 +2077,7 @@ def __call__( return out_val -class Affine(Transform): +class Affine(InvertibleTransform): """ Transform ``img`` given the affine parameters. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. 
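Review note: the AffineGrid change above runs the user-supplied matrix through to_affine_nd before composing it with the grid, so an affine whose rank does not match the grid (e.g. a 2D matrix applied to a 3D grid) is promoted instead of failing. A quick illustration of that helper (arbitrary values):

    import numpy as np
    from monai.data.utils import to_affine_nd

    aff_2d = np.array([[2.0, 0.0, 7.0],
                       [0.0, 3.0, 8.0],
                       [0.0, 0.0, 1.0]])
    print(to_affine_nd(3, aff_2d))
    # 4x4 output: the 2x2 block and the translation column are kept,
    # the extra spatial dimension falls back to identity

Here len(grid) - 1 is the spatial rank of the sampling grid, which is why it is passed as the target dimensionality.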
@@ -1926,18 +2171,19 @@ def __init__( device=device, ) self.image_only = image_only - self.resampler = Resample(norm_coords=not normalized, device=device, dtype=dtype) + self.norm_coord = not normalized + self.resampler = Resample(norm_coords=self.norm_coord, device=device, dtype=dtype) self.spatial_size = spatial_size self.mode: GridSampleMode = look_up_option(mode, GridSampleMode) self.padding_mode: GridSamplePadMode = look_up_option(padding_mode, GridSamplePadMode) def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, - ) -> Union[NdarrayOrTensor, Tuple[NdarrayOrTensor, NdarrayOrTensor]]: + ) -> Union[torch.Tensor, Tuple[torch.Tensor, NdarrayOrTensor]]: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -1956,14 +2202,62 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html """ - sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:]) + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) + img_size = img.shape[1:] + sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img_size) + _mode = mode or self.mode + _padding_mode = padding_mode or self.padding_mode grid, affine = self.affine_grid(spatial_size=sp_size) - ret = self.resampler(img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode) - - return ret if self.image_only else (ret, affine) - + out = self.resampler(img, grid=grid, mode=_mode, padding_mode=_padding_mode) + if not isinstance(out, MetaTensor): + return out if self.image_only else (out, affine) + if not self.norm_coord: + warnings.warn("customized transform may not work with the metadata operation.") + out.meta = self.forward_meta(img.meta, affine, img_size, sp_size) # type: ignore + self.push_transform( + out, orig_size=img_size, extra_info={"affine": affine, "mode": _mode, "padding_mode": _padding_mode} + ) + return out if self.image_only else (out, affine) + + @classmethod + def compute_w_affine(cls, affine, mat, img_size, sp_size): + r = len(affine) - 1 + mat = to_affine_nd(r, mat) + shift_1 = create_translate(r, [float(d - 1) / 2 for d in img_size[:r]]) + shift_2 = create_translate(r, [-float(d - 1) / 2 for d in sp_size[:r]]) + mat = shift_1 @ convert_data_type(mat, np.ndarray, drop_meta=True)[0] @ shift_2 + return affine @ convert_to_dst_type(mat, affine)[0] + + def forward_meta(self, img_meta, mat, img_size, sp_size): + meta_dict = deepcopy(img_meta) + affine = convert_data_type(img_meta["affine"], torch.Tensor)[0] + meta_dict["affine"] = Affine.compute_w_affine(affine, mat, img_size, sp_size) + return meta_dict -class RandAffine(RandomizableTransform): + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError() + transform = self.pop_transform(data) + orig_size = transform[TraceKeys.ORIG_SIZE] + # Create inverse transform + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + inv_affine = linalg_inv(fwd_affine) + inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] + + affine_grid = AffineGrid(affine=inv_affine) + 
grid, _ = affine_grid(orig_size) + # Apply inverse transform + out = self.resampler(data, grid, mode, padding_mode) + if not isinstance(out, MetaTensor): + out = MetaTensor(out) + out.meta = self.forward_meta(data.meta, inv_affine, data.shape[1:], orig_size) + return out # type: ignore + + +class RandAffine(RandomizableTransform, InvertibleTransform): """ Random affine transform. A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb. @@ -2110,12 +2404,13 @@ def randomize(self, data: Optional[Any] = None) -> None: def __call__( self, - img: NdarrayOrTensor, + img: torch.Tensor, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: Optional[Union[GridSampleMode, str]] = None, padding_mode: Optional[Union[GridSamplePadMode, str]] = None, randomize: bool = True, - ) -> NdarrayOrTensor: + grid=None, + ) -> torch.Tensor: """ Args: img: shape must be (num_channels, H, W[, D]), @@ -2131,25 +2426,70 @@ def __call__( Padding mode for outside grid values. Defaults to ``self.padding_mode``. See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html randomize: whether to execute `randomize()` function first, default to True. + grid: precomputed grid to be used (mainly to accelerate `RandAffined`). """ if randomize: self.randomize() - # if not doing transform and spatial size doesn't change, nothing to do # except convert to float and device sp_size = fall_back_tuple(self.spatial_size if spatial_size is None else spatial_size, img.shape[1:]) do_resampling = self._do_transform or (sp_size != ensure_tuple(img.shape[1:])) + _mode = mode or self.mode + _padding_mode = padding_mode or self.padding_mode + if not isinstance(img, MetaTensor) and get_track_meta(): + img = MetaTensor(img) if not do_resampling: - img, *_ = convert_data_type(img, dtype=torch.float32, device=self.resampler.device) - return img - grid = self.get_identity_grid(sp_size) - if self._do_transform: - grid = self.rand_affine_grid(grid=grid, randomize=randomize) - out: NdarrayOrTensor = self.resampler( - img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode + out: torch.Tensor = convert_data_type(img, dtype=torch.float32, device=self.resampler.device)[0] + else: + if grid is None: + grid = self.get_identity_grid(sp_size) + if self._do_transform: + grid = self.rand_affine_grid(grid=grid, randomize=randomize) + out = self.resampler(img=img, grid=grid, mode=_mode, padding_mode=_padding_mode) + mat = self.rand_affine_grid.get_transformation_matrix() + if not isinstance(out, MetaTensor): + out = MetaTensor(out) + self.push_transform( + out, + orig_size=img.shape[1:], + extra_info={"affine": mat, "mode": _mode, "padding_mode": _padding_mode, "do_resampling": do_resampling}, ) - return out + if isinstance(img, MetaTensor): + out.meta = self.forward_meta(img.meta, mat, img.shape[1:], sp_size) + return out # type: ignore + + def forward_meta(self, img_meta, mat, img_size, sp_size): + meta_dict = deepcopy(img_meta) + affine = convert_data_type(img_meta["affine"], torch.Tensor)[0] + meta_dict["affine"] = Affine.compute_w_affine(affine, mat, img_size, sp_size) + return meta_dict + + def inverse(self, data: torch.Tensor) -> torch.Tensor: + if not isinstance(data, MetaTensor): + raise NotImplementedError("data must be a MetaTensor") + + transform = self.pop_transform(data) + # if transform was not performed nothing to do. 
+ if not transform[TraceKeys.EXTRA_INFO]["do_resampling"]: + return data + orig_size = transform[TraceKeys.ORIG_SIZE] + orig_size = fall_back_tuple(orig_size, data.shape[1:]) + # Create inverse transform + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + inv_affine = linalg_inv(fwd_affine) + inv_affine = convert_to_dst_type(inv_affine, data, dtype=inv_affine.dtype)[0] + affine_grid = AffineGrid(affine=inv_affine) + grid, _ = affine_grid(orig_size) + + # Apply inverse transform + out = self.resampler(data, grid, mode, padding_mode) + if not isinstance(out, MetaTensor): + out = MetaTensor(out) + out.meta = self.forward_meta(data.meta, inv_affine, data.shape[1:], orig_size) + return out # type: ignore class Rand2DElastic(RandomizableTransform): diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 718f019c26..9006ec0305 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -16,7 +16,6 @@ """ from copy import deepcopy -from enum import Enum from typing import Any, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -24,13 +23,12 @@ from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor -from monai.networks.layers import AffineTransform +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter -from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad +from monai.transforms.croppad.array import CenterSpatialCrop from monai.transforms.inverse import InvertibleTransform from monai.transforms.spatial.array import ( Affine, - AffineGrid, Flip, GridDistortion, GridPatch, @@ -40,7 +38,6 @@ Rand3DElastic, RandAffine, RandAxisFlip, - RandFlip, RandGridDistortion, RandGridPatch, RandRotate, @@ -70,7 +67,6 @@ from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import PostFix, TraceKeys from monai.utils.module import optional_import -from monai.utils.type_conversion import convert_data_type, convert_to_dst_type nib, _ = optional_import("nibabel") @@ -475,27 +471,16 @@ def __init__( super().__init__(keys, allow_missing_keys) self.rotator = Rotate90(k, spatial_axes) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): - self.push_transform(d, key) d[key] = self.rotator(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - _ = self.get_most_recent_transform(d, key) - # Create inverse transform - spatial_axes = self.rotator.spatial_axes - num_times_rotated = self.rotator.k - num_times_to_rotate = 4 - num_times_rotated - inverse_transform = Rotate90(num_times_to_rotate, spatial_axes) - # Apply inverse - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - + d[key] = self.rotator.inverse(d[key]) return d @@ -540,7 +525,7 @@ def randomize(self, data: Optional[Any] = None) -> None: self._rand_k = self.R.randint(self.max_k) + 1 super().randomize(None) - def __call__(self, data: 
Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, torch.Tensor]: self.randomize() d = dict(data) @@ -550,24 +535,19 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable for key in self.key_iterator(d): if self._do_transform: d[key] = rotator(d[key]) - self.push_transform(d, key, extra_info={"rand_k": self._rand_k}) + self.push_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - d = deepcopy(dict(data)) + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: + d = dict(data) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Check if random transform was actually performed (based on `prob`) - if transform[TraceKeys.DO_TRANSFORM]: - # Create inverse transform - num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"] - num_times_to_rotate = 4 - num_times_rotated - inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) - # Apply inverse - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - + if not isinstance(d[key], MetaTensor): + continue + d[key].applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE # type: ignore + if self.pop_transform(d[key])[TraceKeys.DO_TRANSFORM]: + d[key].applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE # type: ignore + xform = self.pop_transform(d[key]) + d[key] = Rotate90().inverse_transform(d[key], xform) return d @@ -614,38 +594,16 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.resizer = Resize(spatial_size=spatial_size, size_mode=size_mode) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, align_corners in self.key_iterator(d, self.mode, self.align_corners): - self.push_transform( - d, - key, - extra_info={ - "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - }, - ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - orig_size = transform[TraceKeys.ORIG_SIZE] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - # Create inverse transform - inverse_transform = Resize( - spatial_size=orig_size, - mode=mode, - align_corners=None if align_corners == TraceKeys.NONE else align_corners, - ) - # Apply inverse transform - d[key] = inverse_transform(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - + d[key] = self.resizer.inverse(d[key]) return d @@ -737,44 +695,16 @@ def __init__( self.mode = ensure_tuple_rep(mode, len(self.keys)) self.padding_mode = ensure_tuple_rep(padding_mode, len(self.keys)) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, 
mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - orig_size = d[key].shape[1:] - d[key], affine = self.affine(d[key], mode=mode, padding_mode=padding_mode) - self.push_transform( - d, - key, - orig_size=orig_size, - extra_info={ - "affine": affine, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - }, - ) + d[key], _ = self.affine(d[key], mode=mode, padding_mode=padding_mode) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) - for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - orig_size = transform[TraceKeys.ORIG_SIZE] - # Create inverse transform - fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = np.linalg.inv(fwd_affine) - - affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) - - # Apply inverse transform - d[key] = self.affine.resampler(d[key], grid, mode, padding_mode) - - # Remove the applied transform - self.pop_transform(d, key) - + d[key] = self.affine.inverse(d[key]) return d @@ -889,58 +819,30 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N # all the keys share the same random Affine factor self.rand_affine.randomize() - device = self.rand_affine.resampler.device spatial_size = d[first_key].shape[1:] # type: ignore sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) - affine: torch.Tensor = torch.eye(len(sp_size) + 1, dtype=torch.float64, device=device) # converting affine to tensor because the resampler currently only support torch backend grid = None if do_resampling: # need to prepare grid grid = self.rand_affine.get_identity_grid(sp_size) if self._do_transform: # add some random factors grid = self.rand_affine.rand_affine_grid(grid=grid) - affine = self.rand_affine.rand_affine_grid.get_transformation_matrix() for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): - self.push_transform( - d, - key, - extra_info={ - "affine": affine, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - }, - ) # do the transform if do_resampling: - d[key] = self.rand_affine.resampler(d[key], grid, mode=mode, padding_mode=padding_mode) - + d[key] = self.rand_affine(d[key], mode=mode, padding_mode=padding_mode, grid=grid) + self.push_transform(d[key], extra_info={"do_resampling": do_resampling}) return d def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # if transform was not performed and spatial size is None, nothing to do. 
- if transform[TraceKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: - orig_size = transform[TraceKeys.ORIG_SIZE] - # Create inverse transform - fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - inv_affine = np.linalg.inv(fwd_affine) - - affine_grid = AffineGrid(affine=inv_affine) - grid, _ = affine_grid(orig_size) - - # Apply inverse transform - d[key] = self.rand_affine.resampler(d[key], grid, mode, padding_mode) - - # Remove the applied transform - self.pop_transform(d, key) + if self.pop_transform(d[key])[TraceKeys.EXTRA_INFO]["do_resampling"]: + d[key] = self.rand_affine.inverse(d[key]) return d @@ -1233,22 +1135,16 @@ def __init__( super().__init__(keys, allow_missing_keys) self.flipper = Flip(spatial_axis=spatial_axis) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key in self.key_iterator(d): - self.push_transform(d, key) d[key] = self.flipper(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - _ = self.get_most_recent_transform(d, key) - # Inverse is same as forward - d[key] = self.flipper(d[key]) - # Remove the applied transform - self.pop_transform(d, key) - + d[key] = self.flipper.inverse(d[key]) return d @@ -1266,7 +1162,7 @@ class RandFlipd(RandomizableTransform, MapTransform, InvertibleTransform): allow_missing_keys: don't raise exception if key is missing. 
""" - backend = RandFlip.backend + backend = Flip.backend def __init__( self, @@ -1277,35 +1173,34 @@ def __init__( ) -> None: MapTransform.__init__(self, keys, allow_missing_keys) RandomizableTransform.__init__(self, prob) - self.flipper = RandFlip(prob=1.0, spatial_axis=spatial_axis) + self.flipper = Flip(spatial_axis=spatial_axis) def set_random_state( self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None ) -> "RandFlipd": super().set_random_state(seed, state) - self.flipper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) for key in self.key_iterator(d): if self._do_transform: - d[key] = self.flipper(d[key], randomize=False) - self.push_transform(d, key) + d[key] = self.flipper(d[key]) + self.push_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Check if random transform was actually performed (based on `prob`) - if transform[TraceKeys.DO_TRANSFORM]: - # Inverse is same as forward - d[key] = self.flipper(d[key], randomize=False) - # Remove the applied transform - self.pop_transform(d, key) + xform = self.pop_transform(d[key]) + if not xform[TraceKeys.DO_TRANSFORM]: + continue + d[key].applied_operations[-1][TraceKeys.ID] = TraceKeys.NONE # type: ignore + self.pop_transform(d[key]) # drop the Flip + with self.flipper.trace_transform(False): + d[key] = self.flipper(d[key]) return d @@ -1337,7 +1232,7 @@ def set_random_state( self.flipper.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: @@ -1350,20 +1245,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key], randomize=False) - self.push_transform(d, key, extra_info={"axis": self.flipper._axis}) + self.push_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Check if random transform was actually performed (based on `prob`) - if transform[TraceKeys.DO_TRANSFORM]: - flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axis"]) - # Inverse is same as forward - d[key] = flipper(d[key]) - # Remove the applied transform - self.pop_transform(d, key) + if self.pop_transform(d[key])[TraceKeys.DO_TRANSFORM]: + d[key] = self.flipper.inverse(d[key]) return d @@ -1416,56 +1305,20 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.dtype = ensure_tuple_rep(dtype, len(self.keys)) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, 
torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode, align_corners, dtype in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners, self.dtype ): - orig_size = d[key].shape[1:] d[key] = self.rotator( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype ) - rot_mat = self.rotator.get_rotation_matrix() - self.push_transform( - d, - key, - orig_size=orig_size, - extra_info={ - "rot_mat": rot_mat, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - }, - ) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) - for key, dtype in self.key_iterator(d, self.dtype): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - inv_rot_mat = np.linalg.inv(fwd_rot_mat) - - xform = AffineTransform( - normalized=False, - mode=mode, - padding_mode=padding_mode, - align_corners=False if align_corners == TraceKeys.NONE else align_corners, - reverse_indexing=True, - ) - img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) - transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) - - out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) - out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) - d[key] = out - # Remove the applied transform - self.pop_transform(d, key) - + for key in self.key_iterator(d): + d[key] = self.rotator.inverse(d[key]) return d @@ -1535,7 +1388,7 @@ def set_random_state( self.rand_rotate.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) self.randomize(None) @@ -1545,59 +1398,22 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d, self.mode, self.padding_mode, self.align_corners, self.dtype ): if self._do_transform: - d[key], rot_mat = self.rand_rotate( + d[key] = self.rand_rotate( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype, randomize=False, - get_matrix=True, ) - else: - rot_mat = np.eye(d[key].ndim) - self.push_transform( - d, - key, - orig_size=d[key].shape[1:], - extra_info={ - "rot_mat": rot_mat, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - }, - ) + self.push_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) - for key, dtype in self.key_iterator(d, self.dtype): - transform = self.get_most_recent_transform(d, key) - # Check if random transform was actually performed (based on 
`prob`) - if transform[TraceKeys.DO_TRANSFORM]: - # Create inverse transform - fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - inv_rot_mat = np.linalg.inv(fwd_rot_mat) - - xform = AffineTransform( - normalized=False, - mode=mode, - padding_mode=padding_mode, - align_corners=False if align_corners == TraceKeys.NONE else align_corners, - reverse_indexing=True, - ) - img_t, *_ = convert_data_type(d[key], torch.Tensor, dtype=dtype) - transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) - output: torch.Tensor - out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) - out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) - d[key] = out - # Remove the applied transform - self.pop_transform(d, key) - + for key in self.key_iterator(d): + if self.pop_transform(d[key])[TraceKeys.DO_TRANSFORM]: + d[key] = self.rand_rotate.inverse(d[key]) return d @@ -1651,45 +1467,18 @@ def __init__( self.align_corners = ensure_tuple_rep(align_corners, len(self.keys)) self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs) - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): - self.push_transform( - d, - key, - extra_info={ - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - }, - ) d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Create inverse transform - zoom = np.array(self.zoomer.zoom) - inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size) - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - # Apply inverse - d[key] = inverse_transform( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=None if align_corners == TraceKeys.NONE else align_corners, - ) - # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore - # Remove the applied transform - self.pop_transform(d, key) - + d[key] = self.zoomer.inverse(d[key]) return d @@ -1761,7 +1550,7 @@ def set_random_state( self.rand_zoom.set_random_state(seed, state) return self - def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = dict(data) first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: @@ -1778,42 +1567,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[key] = 
self.rand_zoom( d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False ) - self.push_transform( - d, - key, - extra_info={ - "zoom": self.rand_zoom._zoom, - "mode": mode.value if isinstance(mode, Enum) else mode, - "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, - }, - ) + self.push_transform(d[key]) return d - def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]: d = deepcopy(dict(data)) for key in self.key_iterator(d): - transform = self.get_most_recent_transform(d, key) - # Check if random transform was actually performed (based on `prob`) - if transform[TraceKeys.DO_TRANSFORM]: - # Create inverse transform - zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) - mode = transform[TraceKeys.EXTRA_INFO]["mode"] - padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] - inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.rand_zoom.keep_size) - # Apply inverse - d[key] = inverse_transform( - d[key], - mode=mode, - padding_mode=padding_mode, - align_corners=None if align_corners == TraceKeys.NONE else align_corners, - ) - # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore - # Remove the applied transform - self.pop_transform(d, key) - + if self.pop_transform(d[key])[TraceKeys.DO_TRANSFORM]: + d[key] = self.rand_zoom.inverse(d[key]) return d diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index bfd22a341b..d7aa0c04a6 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -395,6 +395,8 @@ def __call__(self, img: NdarrayOrTensor): """ Apply the transform to `img` and make it contiguous. """ + if isinstance(img, MetaTensor): + img.applied_operations = [] # drops tracking info return convert_to_tensor(img, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence) @@ -411,6 +413,9 @@ class EnsureType(Transform): device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. + drop_meta: whether to drop the meta information of the input data, default to `True`. + If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor. + If `False`, converting a MetaTensor into a non-metatensor instance will raise an error. 
""" @@ -422,11 +427,13 @@ def __init__( dtype: Optional[Union[DtypeLike, torch.dtype]] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True, + drop_meta: bool = True, ) -> None: self.data_type = look_up_option(data_type.lower(), {"tensor", "numpy"}) self.dtype = dtype self.device = device self.wrap_sequence = wrap_sequence + self.drop_meta = drop_meta def __call__(self, data: NdarrayOrTensor): """ @@ -440,7 +447,12 @@ def __call__(self, data: NdarrayOrTensor): output_type = torch.Tensor if self.data_type == "tensor" else np.ndarray out: NdarrayOrTensor out, *_ = convert_data_type( - data=data, output_type=output_type, dtype=self.dtype, device=self.device, wrap_sequence=self.wrap_sequence + data=data, + output_type=output_type, + dtype=self.dtype, + device=self.device, + wrap_sequence=self.wrap_sequence, + drop_meta=self.drop_meta, ) return out @@ -1051,7 +1063,8 @@ def __call__(self, img: NdarrayOrTensor): img: PyTorch Tensor data for the TorchVision transform. """ - img_t, *_ = convert_data_type(img, torch.Tensor) + img_t, *_ = convert_data_type(img, torch.Tensor, drop_meta=True) + out = self.trans(img_t) out, *_ = convert_to_dst_type(src=out, dst=img) return out diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 74c215eb63..db061e9ab2 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -552,6 +552,7 @@ def __init__( dtype: Union[DtypeLike, torch.dtype] = None, device: Optional[torch.device] = None, wrap_sequence: bool = True, + drop_meta: bool = True, allow_missing_keys: bool = False, ) -> None: """ @@ -563,10 +564,15 @@ def __init__( device: for Tensor data type, specify the target device. wrap_sequence: if `False`, then lists will recursively call this function, default to `True`. E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`. + drop_meta: whether to drop the meta information of the input data, default to `True`. + If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor. + If `False`, converting a MetaTensor into a non-metatensor instance will raise an error. allow_missing_keys: don't raise exception if key is missing. 
""" super().__init__(keys, allow_missing_keys) - self.converter = EnsureType(data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence) + self.converter = EnsureType( + data_type=data_type, dtype=dtype, device=device, wrap_sequence=wrap_sequence, drop_meta=drop_meta + ) def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9b148d7587..d6b1fcffef 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -105,6 +105,7 @@ "convert_pad_mode", "convert_to_contiguous", "get_unique_labels", + "scale_affine", ] @@ -1185,16 +1186,13 @@ def map_spatial_axes( """ if spatial_axes is None: - spatial_axes_ = list(range(1, img_ndim) if channel_first else range(img_ndim - 1)) - - else: - spatial_axes_ = [] - for a in ensure_tuple(spatial_axes): - if channel_first: - spatial_axes_.append(a if a < 0 else a + 1) - else: - spatial_axes_.append(a - 1 if a < 0 else a) - + return list(range(1, img_ndim) if channel_first else range(img_ndim - 1)) + spatial_axes_ = [] + for a in ensure_tuple(spatial_axes): + if channel_first: + spatial_axes_.append(a % img_ndim if a < 0 else a + 1) + else: + spatial_axes_.append((a - 1) % (img_ndim - 1) if a < 0 else a) return spatial_axes_ @@ -1573,5 +1571,30 @@ def convert_to_contiguous(data, **kwargs): return data +def scale_affine(affine, spatial_size, new_spatial_size, centered: bool = True): + """ + Scale the affine matrix according to the new spatial size. + + Args: + affine: affine matrix to scale. + spatial_size: original spatial size. + new_spatial_size: new spatial size. + centered: whether the scaling is with respect to + the image center (True, default) or corner (False). + + Returns: + Scaled affine matrix. + + """ + if spatial_size == new_spatial_size: + return affine + r = len(affine) - 1 + s = np.array([float(o) / max(n, 0) for o, n in zip(spatial_size, new_spatial_size)]) + scale = create_scale(r, s.tolist()) + if centered: + scale[:r, -1] = (np.diag(scale)[:r] - 1) / 2 # type: ignore + return affine @ convert_to_dst_type(scale, affine)[0] + + if __name__ == "__main__": print_transform_backends() diff --git a/monai/transforms/utils_pytorch_numpy_unification.py b/monai/transforms/utils_pytorch_numpy_unification.py index 2aedc77dd7..5e84efafe7 100644 --- a/monai/transforms/utils_pytorch_numpy_unification.py +++ b/monai/transforms/utils_pytorch_numpy_unification.py @@ -376,7 +376,7 @@ def mode(x: NdarrayTensor, dim: int = -1, to_long: bool = True) -> NdarrayTensor to_long: convert input to long before performing mode. """ dtype = torch.int64 if to_long else None - x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype) + x_t, *_ = convert_data_type(x, torch.Tensor, dtype=dtype, drop_meta=True) o_t = torch.mode(x_t, dim).values o, *_ = convert_to_dst_type(o_t, x) return o @@ -389,3 +389,14 @@ def unique(x: NdarrayTensor) -> NdarrayTensor: x: array/tensor """ return torch.unique(x) if isinstance(x, torch.Tensor) else np.unique(x) # type: ignore + + +def linalg_inv(x: NdarrayTensor) -> NdarrayTensor: + """`torch.linalg.inv` with equivalent implementation for numpy. 
+ + Args: + x: array/tensor + """ + if isinstance(x, torch.Tensor) and hasattr(torch, "inverse"): # pytorch 1.7.0 + return torch.inverse(x) # type: ignore + return torch.linalg.inv(x) if isinstance(x, torch.Tensor) else np.linalg.inv(x) # type: ignore diff --git a/monai/utils/type_conversion.py b/monai/utils/type_conversion.py index 30eb045a57..6199be212e 100644 --- a/monai/utils/type_conversion.py +++ b/monai/utils/type_conversion.py @@ -10,11 +10,13 @@ # limitations under the License. import re +from copy import deepcopy from typing import Any, Optional, Sequence, Tuple, Type, Union import numpy as np import torch +import monai from monai.config.type_definitions import DtypeLike, NdarrayTensor from monai.utils import optional_import @@ -112,12 +114,10 @@ def convert_to_tensor( wrap_sequence: if `False`, then lists will recursively call this function. E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. - """ - # avoids circular import - from monai.data.meta_tensor import MetaTensor + """ if isinstance(data, torch.Tensor): - if isinstance(data, MetaTensor): + if isinstance(data, monai.data.MetaTensor): data = data.as_tensor() return data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore if isinstance(data, np.ndarray): @@ -160,13 +160,10 @@ def convert_to_meta_tensor( E.g., `[1, 2]` -> `[tensor(1), tensor(2)]`. If `True`, then `[1, 2]` -> `tensor([1, 2])`. """ - # avoids circular import - from monai.data.meta_tensor import MetaTensor - if isinstance(data, torch.Tensor): out = data.to(dtype=dtype, device=device, memory_format=torch.contiguous_format) # type: ignore - if not isinstance(out, MetaTensor): - out = MetaTensor(out) + if not isinstance(out, monai.data.MetaTensor): + out = monai.data.MetaTensor(out) return out if isinstance(data, np.ndarray): # skip array of string classes and object, refer to: @@ -176,20 +173,20 @@ def convert_to_meta_tensor( # `ascontiguousarray` will add 1 dim if img has no dim, so we only apply on data with dims if data.ndim > 0: data = np.ascontiguousarray(data) - return MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore + return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore elif (has_cp and isinstance(data, cp_ndarray)) or isinstance(data, (float, int, bool)): - return MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore + return monai.data.MetaTensor(torch.as_tensor(data, dtype=dtype, device=device)) # type: ignore elif isinstance(data, list): list_ret = [convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data] return ( - MetaTensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) # type: ignore + monai.data.MetaTensor(torch.as_tensor(list_ret, dtype=dtype, device=device)) # type: ignore if wrap_sequence else list_ret ) elif isinstance(data, tuple): tuple_ret = tuple(convert_to_meta_tensor(i, dtype=dtype, device=device) for i in data) return ( - MetaTensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) # type: ignore + monai.data.MetaTensor(torch.as_tensor(tuple_ret, dtype=dtype, device=device)) # type: ignore if wrap_sequence else tuple_ret ) @@ -274,6 +271,7 @@ def convert_data_type( device: Optional[torch.device] = None, dtype: Union[DtypeLike, torch.dtype] = None, wrap_sequence: bool = False, + drop_meta: bool = True, ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]: """ Convert to `torch.Tensor`/`np.ndarray` from `torch.Tensor`/`np.ndarray`/`float`/`int` 
etc.
@@ -287,6 +285,10 @@ def convert_data_type(
         If left blank, it remains unchanged.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
+        drop_meta: whether to drop the meta information of the input data, default to `True`.
+            If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
+            If `False`, converting a MetaTensor into a non-tensor instance will raise an error.
+
     Returns:
         modified data, orig_type, orig_device
@@ -299,12 +301,9 @@
         (1.0, <class 'float'>, None)

     """
-    # avoids circular import
-    from monai.data.meta_tensor import MetaTensor
-
     orig_type: type
-    if isinstance(data, MetaTensor):
-        orig_type = MetaTensor
+    if isinstance(data, monai.data.MetaTensor):
+        orig_type = monai.data.MetaTensor
     elif isinstance(data, torch.Tensor):
         orig_type = torch.Tensor
     elif isinstance(data, np.ndarray):
@@ -320,9 +319,23 @@
     dtype_ = get_equivalent_dtype(dtype, output_type)

+    # input MetaTensor, output type plain torch tensor: this will potentially drop the meta info
+    is_meta_to_tensor = (
+        issubclass(output_type, torch.Tensor)
+        and not issubclass(output_type, monai.data.MetaObj)
+        and isinstance(data, monai.data.MetaObj)
+    )
+    if not drop_meta and isinstance(data, monai.data.MetaObj):
+        if is_meta_to_tensor:
+            output_type = type(data)
+        elif not issubclass(output_type, monai.data.MetaObj):
+            raise RuntimeError(
+                f"the specified output_type {output_type} is not compatible with option drop_meta=False."
+            )
+
     data_: NdarrayTensor
-    if issubclass(output_type, MetaTensor):
+    if issubclass(output_type, monai.data.MetaTensor):
         data_ = convert_to_meta_tensor(data, dtype=dtype_, device=device, wrap_sequence=wrap_sequence)
         return data_, orig_type, orig_device
     if issubclass(output_type, torch.Tensor):
@@ -338,7 +351,11 @@

 def convert_to_dst_type(
-    src: Any, dst: NdarrayTensor, dtype: Union[DtypeLike, torch.dtype, None] = None, wrap_sequence: bool = False
+    src: Any,
+    dst: NdarrayTensor,
+    dtype: Union[DtypeLike, torch.dtype, None] = None,
+    wrap_sequence: bool = False,
+    drop_meta: bool = True,
 ) -> Tuple[NdarrayTensor, type, Optional[torch.device]]:
     """
     Convert source data to the same data type and device as the destination data.
@@ -352,27 +369,37 @@
         dtype: an optional argument if the target `dtype` is different from the original `dst`'s data type.
         wrap_sequence: if `False`, then lists will recursively call this function.
             E.g., `[1, 2]` -> `[array(1), array(2)]`. If `True`, then `[1, 2]` -> `array([1, 2])`.
+        drop_meta: whether to drop the meta information of the input data, default to `True`.
+            If `True`, then the meta information will be dropped quietly, unless the output type is MetaTensor.
+            If `False`, converting a MetaTensor into a non-tensor instance will raise an error.

     See Also:
         :func:`convert_data_type`
     """
-    # avoids circular import
-    from monai.data.meta_tensor import MetaTensor

     device = dst.device if isinstance(dst, torch.Tensor) else None
     if dtype is None:
         dtype = dst.dtype

+    copy_meta = False
     output_type: Any
-    if isinstance(dst, MetaTensor):
-        output_type = MetaTensor
+    if isinstance(dst, monai.data.MetaTensor):
+        output_type = monai.data.MetaTensor
+        if not isinstance(src, monai.data.MetaTensor):
+            copy_meta = True  # converting a non-meta tensor to a meta tensor: carry over the destination's metadata as well.
elif isinstance(dst, torch.Tensor): output_type = torch.Tensor elif isinstance(dst, np.ndarray): output_type = np.ndarray else: output_type = type(dst) - return convert_data_type(data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence) + output: NdarrayTensor + output, _type, _device = convert_data_type( + data=src, output_type=output_type, device=device, dtype=dtype, wrap_sequence=wrap_sequence, drop_meta=drop_meta + ) + if copy_meta and isinstance(output, monai.data.MetaTensor): # type: ignore + output.meta, output.applied_operations = deepcopy(dst.meta), deepcopy(dst.applied_operations) # type: ignore + return output, _type, _device def convert_to_list(data: Union[Sequence, torch.Tensor, np.ndarray]) -> list: diff --git a/runtests.sh b/runtests.sh index a69c408e3c..cf8f1435d8 100755 --- a/runtests.sh +++ b/runtests.sh @@ -268,6 +268,7 @@ do --autofix) doIsortFix=true doBlackFix=true + doPrecommit=true doIsortFormat=true doBlackFormat=true doCopyRight=true @@ -403,6 +404,28 @@ then fi fi +if [ $doPrecommit = true ] +then + set +e # disable exit on failure so that diagnostics can be given on failure + echo "${separator}${blue}pre-commit${noColor}" + + # ensure that the necessary packages for code format testing are installed + if ! is_pip_installed pre_commit + then + install_deps + fi + ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files + + pre_commit_status=$? + if [ ${pre_commit_status} -ne 0 ] + then + print_style_fail_msg + exit ${pre_commit_status} + else + echo "${green}passed!${noColor}" + fi + set -e # enable exit on failure +fi if [ $doIsortFormat = true ] then @@ -501,29 +524,6 @@ then set -e # enable exit on failure fi -if [ $doPrecommit = true ] -then - set +e # disable exit on failure so that diagnostics can be given on failure - echo "${separator}${blue}pre-commit${noColor}" - - # ensure that the necessary packages for code format testing are installed - if ! is_pip_installed pre_commit - then - install_deps - fi - ${cmdPrefix}${PY_EXE} -m pre_commit run --all-files - - pre_commit_status=$? 
- if [ ${pre_commit_status} -ne 0 ] - then - print_style_fail_msg - exit ${pre_commit_status} - else - echo "${green}passed!${noColor}" - fi - set -e # enable exit on failure -fi - if [ $doPylintFormat = true ] then set +e # disable exit on failure so that diagnostics can be given on failure diff --git a/tests/test_affine.py b/tests/test_affine.py index d681d2941b..9803baef6c 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Affine -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, test_local_inversion TESTS = [] for p in TEST_NDARRAYS: @@ -159,7 +159,8 @@ def test_affine(self, input_param, input_data, expected_val): result = g(**input_data) if isinstance(result, tuple): result = result[0] - assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + test_local_inversion(g, result, input_data["img"]) + assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_affined.py b/tests/test_affined.py index 665c93d23f..4b3addf1dc 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -16,7 +16,7 @@ from parameterized import parameterized from monai.transforms import Affined -from tests.utils import TEST_NDARRAYS, assert_allclose +from tests.utils import TEST_NDARRAYS, assert_allclose, test_local_inversion TESTS = [] for p in TEST_NDARRAYS: @@ -160,8 +160,9 @@ class TestAffined(unittest.TestCase): @parameterized.expand(TESTS) def test_affine(self, input_param, input_data, expected_val): g = Affined(**input_param) - result = g(input_data)["img"] - assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4) + result = g(input_data) + test_local_inversion(g, result, input_data, dict_key="img") + assert_allclose(result["img"], expected_val, rtol=1e-4, atol=1e-4, type_test=False) if __name__ == "__main__": diff --git a/tests/test_convert_data_type.py b/tests/test_convert_data_type.py index bda392f8ea..0e457fa134 100644 --- a/tests/test_convert_data_type.py +++ b/tests/test_convert_data_type.py @@ -47,7 +47,7 @@ class TestTensor(torch.Tensor): class TestConvertDataType(unittest.TestCase): @parameterized.expand(TESTS) def test_convert_data_type(self, in_image, im_out): - converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out)) + converted_im, orig_type, orig_device = convert_data_type(in_image, type(im_out), drop_meta=True) # check input is unchanged self.assertEqual(type(in_image), orig_type) if isinstance(in_image, torch.Tensor): @@ -79,7 +79,7 @@ class TestConvertDataSame(unittest.TestCase): # add test for subclass of Tensor @parameterized.expand(TESTS + [(np.array(1.0), TestTensor(np.array(1.0)))]) def test_convert_data_type(self, in_image, im_out): - converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out) + converted_im, orig_type, orig_device = convert_to_dst_type(in_image, im_out, drop_meta=True) # check input is unchanged self.assertEqual(type(in_image), orig_type) if isinstance(in_image, torch.Tensor): diff --git a/tests/test_ensure_typed.py b/tests/test_ensure_typed.py index cadab9bd56..4066df723f 100644 --- a/tests/test_ensure_typed.py +++ b/tests/test_ensure_typed.py @@ -14,6 +14,7 @@ import numpy as np import torch +from monai.data import MetaTensor from monai.transforms import EnsureTyped from tests.utils import assert_allclose @@ -90,6 +91,19 @@ def test_dict(self): self.assertEqual(result["meta"]["path"], 
"temp/test") self.assertEqual(result["extra"], None) + def test_error(self): + # simulate metatenor input data but drop_meta=False to not removing the meta + test_data = { + "img": MetaTensor( + np.array([1.0, 2.0], dtype=np.float32), + meta={"dims": 3, "size": np.array([1, 2, 3]), "path": "temp/test"}, + ) + } + with self.assertRaises(RuntimeError): + EnsureTyped(keys="img", data_type="numpy", device="cpu", drop_meta=False)(test_data) + result = EnsureTyped(keys="img", data_type="tensor", device="cpu", drop_meta=False)(test_data) + self.assertTrue(isinstance(result["img"], MetaTensor)) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_flip.py b/tests/test_flip.py index 17cf0d2c39..0894f1993b 100644 --- a/tests/test_flip.py +++ b/tests/test_flip.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import Flip -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -37,7 +37,8 @@ def test_correct_results(self, _, spatial_axis): expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) result = flip(im) - assert_allclose(result, p(expected)) + assert_allclose(result, p(expected), type_test=False) + test_local_inversion(flip, result, im) if __name__ == "__main__": diff --git a/tests/test_flipd.py b/tests/test_flipd.py index 900779f4e0..6dda13ae2d 100644 --- a/tests/test_flipd.py +++ b/tests/test_flipd.py @@ -15,7 +15,7 @@ from parameterized import parameterized from monai.transforms import Flipd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)] @@ -35,8 +35,10 @@ def test_correct_results(self, _, spatial_axis): flip = Flipd(keys="img", spatial_axis=spatial_axis) expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]] expected = np.stack(expected) - result = flip({"img": p(self.imt[0])})["img"] - assert_allclose(result, p(expected)) + im = p(self.imt[0]) + result = flip({"img": im})["img"] + assert_allclose(result, p(expected), type_test=False) + test_local_inversion(flip, {"img": result}, {"img": im}, "img") if __name__ == "__main__": diff --git a/tests/test_integration_bundle_run.py b/tests/test_integration_bundle_run.py index 556e0bbaea..c81836099d 100644 --- a/tests/test_integration_bundle_run.py +++ b/tests/test_integration_bundle_run.py @@ -92,7 +92,7 @@ def test_shape(self, config_file, expected_shape): else: override = f"--network %{overridefile1}#move_net --dataset#_target_ %{overridefile2}" # test with `monai.bundle` as CLI entry directly - cmd = f"-m monai.bundle run evaluating --postprocessing#transforms#3#output_postfix seg {override}" + cmd = f"-m monai.bundle run evaluating --postprocessing#transforms#2#output_postfix seg {override}" la = ["coverage", "run"] + cmd.split(" ") + ["--meta_file", meta_file] + ["--config_file", config_file] test_env = os.environ.copy() print(f"CUDA_VISIBLE_DEVICES in {__file__}", test_env.get("CUDA_VISIBLE_DEVICES")) diff --git a/tests/test_integration_determinism.py b/tests/test_integration_determinism.py index 64c018b4f5..94d2325514 100644 --- a/tests/test_integration_determinism.py +++ 
b/tests/test_integration_determinism.py @@ -18,7 +18,7 @@ from monai.data import create_test_image_2d from monai.losses import DiceLoss from monai.networks.nets import UNet -from monai.transforms import AddChannel, Compose, RandRotate90, RandSpatialCrop, ScaleIntensity, ToTensor +from monai.transforms import AddChannel, Compose, RandRotate90, RandSpatialCrop, ScaleIntensity from monai.utils import set_determinism from tests.utils import DistTestCase, TimedCall @@ -47,7 +47,7 @@ def __len__(self): loss = DiceLoss(sigmoid=True) opt = torch.optim.Adam(net.parameters(), 1e-2) train_transforms = Compose( - [AddChannel(), ScaleIntensity(), RandSpatialCrop((96, 96), random_size=False), RandRotate90(), ToTensor()] + [AddChannel(), ScaleIntensity(), RandSpatialCrop((96, 96), random_size=False), RandRotate90()] ) src = DataLoader(_TestBatch(train_transforms), batch_size=batch_size, shuffle=True) diff --git a/tests/test_integration_fast_train.py b/tests/test_integration_fast_train.py index 5f528e072b..08143ff690 100644 --- a/tests/test_integration_fast_train.py +++ b/tests/test_integration_fast_train.py @@ -34,10 +34,8 @@ Compose, CropForegroundd, EnsureChannelFirstd, - EnsureType, EnsureTyped, FgBgToIndicesd, - FromMetaTensord, LoadImaged, RandAffined, RandAxisFlipd, @@ -90,15 +88,13 @@ def test_train_timing(self): LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), - FromMetaTensord(["image", "label"]), ScaleIntensityd(keys="image"), CropForegroundd(keys=["image", "label"], source_key="image"), # pre-compute foreground and background indexes # and cache them to accelerate training FgBgToIndicesd(keys="label", fg_postfix="_fg", bg_postfix="_bg"), - # change to execute transforms with Tensor data - EnsureTyped(keys=["image", "label"]), # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch + EnsureTyped(keys=["image", "label"], drop_meta=True), ToDeviced(keys=["image", "label"], device=device), # randomly crop out patch samples from big # image based on pos / neg ratio @@ -137,10 +133,8 @@ def test_train_timing(self): LoadImaged(keys=["image", "label"]), EnsureChannelFirstd(keys=["image", "label"]), Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "nearest")), - FromMetaTensord(["image", "label"]), ScaleIntensityd(keys="image"), CropForegroundd(keys=["image", "label"], source_key="image"), - EnsureTyped(keys=["image", "label"]), # move the data to GPU and cache to avoid CPU -> GPU sync in every epoch ToDeviced(keys=["image", "label"], device=device), ] @@ -173,8 +167,8 @@ def test_train_timing(self): optimizer = Novograd(model.parameters(), learning_rate * 10) scaler = torch.cuda.amp.GradScaler() - post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)]) - post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)]) + post_pred = AsDiscrete(argmax=True, to_onehot=2) + post_label = AsDiscrete(to_onehot=2) dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False) diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py index 47e4a53fc1..e98c7a3d6e 100644 --- a/tests/test_integration_segmentation_3d.py +++ b/tests/test_integration_segmentation_3d.py @@ -38,8 +38,6 @@ SaveImage, ScaleIntensityd, Spacingd, - ToTensor, - ToTensord, ) from monai.utils import set_determinism from monai.utils.enums import PostFix @@ -65,13 +63,11 @@ def 
diff --git a/tests/test_integration_segmentation_3d.py b/tests/test_integration_segmentation_3d.py
index 47e4a53fc1..e98c7a3d6e 100644
--- a/tests/test_integration_segmentation_3d.py
+++ b/tests/test_integration_segmentation_3d.py
@@ -38,8 +38,6 @@
     SaveImage,
     ScaleIntensityd,
     Spacingd,
-    ToTensor,
-    ToTensord,
 )
 from monai.utils import set_determinism
 from monai.utils.enums import PostFix
@@ -65,13 +63,11 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
             # resampling with align_corners=True or dtype=float64 will generate
             # slightly different results between PyTorch 1.5 and 1.6
             Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
-            FromMetaTensord(["img", "seg"]),
             ScaleIntensityd(keys="img"),
             RandCropByPosNegLabeld(
                 keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
             ),
             RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=[0, 2]),
-            ToTensord(keys=["img", "seg"]),
         ]
     )
     train_transforms.set_random_state(1234)
@@ -82,9 +78,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
             # resampling with align_corners=True or dtype=float64 will generate
             # slightly different results between PyTorch 1.5 and 1.6
             Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
-            FromMetaTensord(["img", "seg"]),
             ScaleIntensityd(keys="img"),
-            ToTensord(keys=["img", "seg"]),
         ]
     )

@@ -100,7 +94,7 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
     # create a validation data loader
     val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
     val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
-    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

     # create UNet, DiceLoss and Adam optimizer
@@ -195,13 +189,12 @@ def run_inference_test(root_dir, device="cuda:0"):
             Spacingd(keys=["img", "seg"], pixdim=[1.2, 0.8, 0.7], mode=["bilinear", "nearest"], dtype=np.float32),
             FromMetaTensord(["img", "seg"]),
             ScaleIntensityd(keys="img"),
-            ToTensord(keys=["img", "seg"]),
         ]
     )
     val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
     # sliding window inference need to input 1 image in every iteration
     val_loader = monai.data.DataLoader(val_ds, batch_size=1, num_workers=4)
-    val_post_tran = Compose([ToTensor(), Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+    val_post_tran = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

     model = UNet(
diff --git a/tests/test_integration_workflows.py b/tests/test_integration_workflows.py
index 06fdef995b..2ccf1c97f5 100644
--- a/tests/test_integration_workflows.py
+++ b/tests/test_integration_workflows.py
@@ -52,7 +52,6 @@
     SaveImaged,
     ScaleIntensityd,
     ToMetaTensord,
-    ToTensord,
 )
 from monai.utils import set_determinism
 from monai.utils.enums import PostFix
@@ -73,22 +72,18 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):
         [
             LoadImaged(keys=["image", "label"]),
             AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
-            FromMetaTensord(["image", "label"]),
             ScaleIntensityd(keys=["image", "label"]),
             RandCropByPosNegLabeld(
                 keys=["image", "label"], label_key="label", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
             ),
             RandRotate90d(keys=["image", "label"], prob=0.5, spatial_axes=[0, 2]),
-            ToTensord(keys=["image", "label"]),
         ]
     )
     val_transforms = Compose(
         [
             LoadImaged(keys=["image", "label"]),
             AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
-            FromMetaTensord(["image", "label"]),
             ScaleIntensityd(keys=["image", "label"]),
-            ToTensord(keys=["image", "label"]),
         ]
     )
@@ -116,7 +111,6 @@ def run_training_test(root_dir, device="cuda:0", amp=False, num_workers=4):

     val_postprocessing = Compose(
         [
-            ToTensord(keys=["pred", "label"]),
             Activationsd(keys="pred", sigmoid=True),
             AsDiscreted(keys="pred", threshold=0.5),
             KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
@@ -159,7 +153,6 @@ def _forward_completed(self, engine):

     train_postprocessing = Compose(
         [
-            ToTensord(keys=["pred", "label"]),
             Activationsd(keys="pred", sigmoid=True),
             AsDiscreted(keys="pred", threshold=0.5),
             KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
@@ -226,10 +219,9 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor
     val_transforms = Compose(
         [
             LoadImaged(keys=["image", "label"]),
-            FromMetaTensord(["image", "label"]),
             AsChannelFirstd(keys=["image", "label"], channel_dim=-1),
             ScaleIntensityd(keys=["image", "label"]),
-            ToTensord(keys=["image", "label"]),
+            FromMetaTensord(["image", "label"]),
         ]
     )
@@ -249,7 +241,6 @@ def run_inference_test(root_dir, model_file, device="cuda:0", amp=False, num_wor

     val_postprocessing = Compose(
         [
-            ToTensord(keys=["pred", "label"]),
             Activationsd(keys="pred", sigmoid=True),
             AsDiscreted(keys="pred", threshold=0.5),
             KeepLargestConnectedComponentd(keys="pred", applied_labels=[1]),
diff --git a/tests/test_integration_workflows_gan.py b/tests/test_integration_workflows_gan.py
index f65a30450a..ff53851ce0 100644
--- a/tests/test_integration_workflows_gan.py
+++ b/tests/test_integration_workflows_gan.py
@@ -26,7 +26,7 @@
 from monai.handlers import CheckpointSaver, StatsHandler, TensorBoardStatsHandler
 from monai.networks import normal_init
 from monai.networks.nets import Discriminator, Generator
-from monai.transforms import AsChannelFirstd, Compose, LoadImaged, RandFlipd, ScaleIntensityd, ToTensord
+from monai.transforms import AsChannelFirstd, Compose, LoadImaged, RandFlipd, ScaleIntensityd
 from monai.utils import set_determinism
 from tests.utils import DistTestCase, TimedCall, skip_if_quick
@@ -42,7 +42,6 @@ def run_training_test(root_dir, device="cuda:0"):
             AsChannelFirstd(keys=["reals"]),
             ScaleIntensityd(keys=["reals"]),
             RandFlipd(keys=["reals"], prob=0.5),
-            ToTensord(keys=["reals"]),
         ]
     )
     train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=0.5)
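A pattern repeats across these integration workflows: the trailing `ToTensor`/`ToTensord` calls disappear because `LoadImaged` now yields `MetaTensor`s, which are already `torch.Tensor` subclasses carrying their own metadata, making the extra conversion redundant. A small illustrative sketch (values chosen for the example only):

```python
import torch
from monai.data import MetaTensor

x = MetaTensor(torch.zeros(1, 4, 4), meta={"modality": "MR"})
assert isinstance(x, torch.Tensor)  # usable wherever a torch.Tensor is expected
assert x.meta["modality"] == "MR"   # while still carrying its metadata
```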
diff --git a/tests/test_inverse.py b/tests/test_inverse.py
index ae3514be18..345b768755 100644
--- a/tests/test_inverse.py
+++ b/tests/test_inverse.py
@@ -10,6 +10,7 @@
 # limitations under the License.

 import random
+import sys
 import unittest
 from functools import partial
 from typing import TYPE_CHECKING, List, Tuple
@@ -19,10 +20,12 @@
 import torch
 from parameterized import parameterized

-from monai.data import create_test_image_2d, create_test_image_3d
+from monai.data import CacheDataset, DataLoader, MetaTensor, create_test_image_2d, create_test_image_3d, decollate_batch
+from monai.networks.nets import UNet
 from monai.transforms import (
     AddChanneld,
     Affined,
+    BatchInverseTransform,
     BorderPadd,
     CenterScaleCropd,
     CenterSpatialCropd,
@@ -47,6 +50,7 @@
     RandSpatialCropd,
     RandSpatialCropSamplesd,
     RandWeightedCropd,
+    RandZoomd,
     Resized,
     ResizeWithPadOrCrop,
     ResizeWithPadOrCropd,
@@ -55,10 +59,13 @@
     Spacingd,
     SpatialCropd,
     SpatialPadd,
+    ToMetaTensord,
     Transposed,
+    Zoomd,
+    allow_missing_keys_mode,
+    convert_inverse_interp_mode,
 )
-from monai.transforms.meta_utility.dictionary import ToMetaTensord
-from monai.utils import get_seed, optional_import, set_determinism
+from monai.utils import first, get_seed, optional_import, set_determinism
 from tests.utils import make_nifti_image, make_rand_affine

 if TYPE_CHECKING:
@@ -158,30 +165,32 @@
 TESTS.append(("ResizeWithPadOrCropd 3d", "3D", 0, True, ResizeWithPadOrCropd(KEYS, [201, 150, 105])))

-TESTS.append(("Flipd 3d", "3D", 0, False, Flipd(KEYS, [1, 2])))
+TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2])))
+TESTS.append(("Flipd 3d", "3D", 0, True, Flipd(KEYS, [1, 2])))

-TESTS.append(("RandFlipd 3d", "3D", 0, False, RandFlipd(KEYS, 1, [1, 2])))
+TESTS.append(("RandFlipd 3d", "3D", 0, True, RandFlipd(KEYS, 1, [1, 2])))

-TESTS.append(("RandAxisFlipd 3d", "3D", 0, False, RandAxisFlipd(KEYS, 1)))
+TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1)))
+TESTS.append(("RandAxisFlipd 3d", "3D", 0, True, RandAxisFlipd(KEYS, 1)))

 for acc in [True, False]:
     TESTS.append(("Orientationd 3d", "3D", 0, True, Orientationd(KEYS, "RAS", as_closest_canonical=acc)))

-TESTS.append(("Rotate90d 2d", "2D", 0, False, Rotate90d(KEYS)))
+TESTS.append(("Rotate90d 2d", "2D", 0, True, Rotate90d(KEYS)))

-TESTS.append(("Rotate90d 3d", "3D", 0, False, Rotate90d(KEYS, k=2, spatial_axes=(1, 2))))
+TESTS.append(("Rotate90d 3d", "3D", 0, True, Rotate90d(KEYS, k=2, spatial_axes=(1, 2))))

-TESTS.append(("RandRotate90d 3d", "3D", 0, False, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2))))
+TESTS.append(("RandRotate90d 3d", "3D", 0, True, RandRotate90d(KEYS, prob=1, spatial_axes=(1, 2))))

 TESTS.append(("Spacingd 3d", "3D", 3e-2, True, Spacingd(KEYS, [0.5, 0.7, 0.9], diagonal=False)))

-TESTS.append(("Resized 2d", "2D", 2e-1, False, Resized(KEYS, [50, 47])))
+TESTS.append(("Resized 2d", "2D", 2e-1, True, Resized(KEYS, [50, 47])))

-TESTS.append(("Resized 3d", "3D", 5e-2, False, Resized(KEYS, [201, 150, 78])))
+TESTS.append(("Resized 3d", "3D", 5e-2, True, Resized(KEYS, [201, 150, 78])))

-TESTS.append(("Resized longest 2d", "2D", 2e-1, False, Resized(KEYS, 47, "longest", "area")))
+TESTS.append(("Resized longest 2d", "2D", 2e-1, True, Resized(KEYS, 47, "longest", "area")))

-TESTS.append(("Resized longest 3d", "3D", 5e-2, False, Resized(KEYS, 201, "longest", "trilinear", True)))
+TESTS.append(("Resized longest 3d", "3D", 5e-2, True, Resized(KEYS, 201, "longest", "trilinear", True)))

 TESTS.append(
     ("Lambdad 2d", "2D", 5e-2, False, Lambdad(KEYS, func=lambda x: x + 5, inv_func=lambda x: x - 5, overwrite=True))
@@ -197,22 +206,22 @@
     )
 )

-# TESTS.append(("Zoomd 1d", "1D odd", 0, False, Zoomd(KEYS, zoom=2, keep_size=False)))
+TESTS.append(("Zoomd 1d", "1D odd", 0, True, Zoomd(KEYS, zoom=2, keep_size=False)))

-# TESTS.append(("Zoomd 2d", "2D", 2e-1, False, Zoomd(KEYS, zoom=0.9)))
+TESTS.append(("Zoomd 2d", "2D", 2e-1, True, Zoomd(KEYS, zoom=0.9)))

-# TESTS.append(("Zoomd 3d", "3D", 3e-2, False, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False)))
+TESTS.append(("Zoomd 3d", "3D", 3e-2, True, Zoomd(KEYS, zoom=[2.5, 1, 3], keep_size=False)))

-# TESTS.append(("RandZoom 3d", "3D", 9e-2, False, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True)))
+TESTS.append(("RandZoom 3d", "3D", 9e-2, True, RandZoomd(KEYS, 1, [0.5, 0.6, 0.9], [1.1, 1, 1.05], keep_size=True)))

-TESTS.append(("RandRotated, prob 0", "2D", 0, False, RandRotated(KEYS, prob=0, dtype=np.float64)))
+TESTS.append(("RandRotated, prob 0", "2D", 0, True, RandRotated(KEYS, prob=0, dtype=np.float64)))

 TESTS.append(
     (
         "Rotated 2d",
         "2D",
         8e-2,
-        False,
+        True,
         Rotated(KEYS, random.uniform(np.pi / 6, np.pi), keep_size=True, align_corners=False, dtype=np.float64),
     )
 )
@@ -222,7 +231,7 @@
         "Rotated 3d",
         "3D",
         1e-1,
-        False,
+        True,
         Rotated(KEYS, [random.uniform(np.pi / 6, np.pi) for _ in range(3)], True, dtype=np.float64),
     )
 )
@@ -232,7 +241,7 @@
         "RandRotated 3d",
         "3D",
         1e-1,
-        False,
+        True,
         RandRotated(KEYS, *[random.uniform(np.pi / 6, np.pi) for _ in range(3)], 1, dtype=np.float64),  # type: ignore
     )
 )
@@ -246,7 +255,7 @@
         "Affine 3d",
         "3D",
         1e-1,
-        False,
+        True,
         Affined(
             KEYS,
             spatial_size=[155, 179, 192],
@@ -263,7 +272,7 @@
         "RandAffine 3d",
         "3D",
         1e-1,
-        False,
+        True,
         RandAffined(
             KEYS,
             [155, 179, 192],
@@ -277,7 +286,7 @@
     )
 )

-TESTS.append(("RandAffine 3d", "3D", 0, False, RandAffined(KEYS, spatial_size=None, prob=0)))
+TESTS.append(("RandAffine 3d", "3D", 0, True, RandAffined(KEYS, spatial_size=None, prob=0)))

 TESTS.append(
     (
@@ -451,52 +460,47 @@ def test_fail(self):
         with self.assertRaises(RuntimeError):
             t2.inverse(data)

-    # @parameterized.expand(N_SAMPLES_TESTS)
-    # def test_inverse_inferred_seg(self, extra_transform):
-
-    #     test_data = []
-    #     for _ in range(20):
-    #         image, label = create_test_image_2d(100, 101)
-    #         test_data.append({"image": image, "label": label.astype(np.float32)})
-
-    #     batch_size = 10
-    #     # num workers = 0 for mac
-    #     num_workers = 2 if sys.platform == "linux" else 0
-    #     transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
-    #     num_invertible_transforms = sum(1 for i in transforms.transforms if isinstance(i, InvertibleTransform))
-
-    #     dataset = CacheDataset(test_data, transform=transforms, progress=False)
-    #     loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
-
-    #     device = "cuda" if torch.cuda.is_available() else "cpu"
-    #     model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(2,)).to(device)
-
-    #     data = first(loader)
-    #     self.assertEqual(len(data["label"].applied_operations), num_invertible_transforms)
-    #     self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)
-
-    #     labels = data["label"].to(device)
-    #     segs = model(labels).detach().cpu()
-    #     label_transform_key = TraceableTransform.trace_key("label")
-    #     segs_dict = {"label": segs, label_transform_key: data[label_transform_key]}
-
-    #     segs_dict_decollated = decollate_batch(segs_dict)
-    #     # inverse of individual segmentation
-    #     seg_dict = first(segs_dict_decollated)
-    #     # test to convert interpolation mode for 1 data of model output batch
-    #     convert_inverse_interp_mode(seg_dict, mode="nearest", align_corners=None)
-
-    #     with allow_missing_keys_mode(transforms):
-    #         inv_seg = transforms.inverse(seg_dict)["label"]
-    #     self.assertEqual(len(data["label_transforms"]), num_invertible_transforms)
-    #     self.assertEqual(len(seg_dict["label_transforms"]), num_invertible_transforms)
-    #     self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
-
-    #     # Inverse of batch
-    #     batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
-    #     with allow_missing_keys_mode(transforms):
-    #         inv_batch = batch_inverter(segs_dict)
-    #     self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)
+    @parameterized.expand(N_SAMPLES_TESTS)
+    def test_inverse_inferred_seg(self, extra_transform):
+
+        test_data = []
+        for _ in range(20):
+            image, label = create_test_image_2d(100, 101)
+            test_data.append({"image": image, "label": label.astype(np.float32)})
+
+        batch_size = 10
+        # num workers = 0 for mac
+        num_workers = 2 if sys.platform == "linux" else 0
+        transforms = Compose([AddChanneld(KEYS), SpatialPadd(KEYS, (150, 153)), extra_transform])
+
+        dataset = CacheDataset(test_data, transform=transforms, progress=False)
+        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model = UNet(spatial_dims=2, in_channels=1, out_channels=1, channels=(2, 4), strides=(1,)).to(device)
+
+        data = first(loader)
+        self.assertEqual(data["image"].shape[0], batch_size * NUM_SAMPLES)
+
+        labels = data["label"].to(device)
+        self.assertTrue(isinstance(labels, MetaTensor))
+        segs = model(labels).detach().cpu()
+        segs_decollated = decollate_batch(segs)
+        self.assertTrue(isinstance(segs_decollated[0], MetaTensor))
+        # inverse of individual segmentation
+        seg_metatensor = first(segs_decollated)
+        # test to convert interpolation mode for 1 data of model output batch
+        convert_inverse_interp_mode(seg_metatensor.applied_operations, mode="nearest", align_corners=None)
+
+        with allow_missing_keys_mode(transforms):
+            inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
+        self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
+
+        # Inverse of batch
+        batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
+        with allow_missing_keys_mode(transforms):
+            inv_batch = batch_inverter(first(loader))
+        self.assertEqual(inv_batch[0]["label"].shape[1:], test_data[0]["label"].shape)


 if __name__ == "__main__":
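The re-enabled `test_inverse_inferred_seg` relies on each invertible transform pushing a record onto `MetaTensor.applied_operations`, which `Compose.inverse` then replays in reverse. A condensed sketch of that round trip, assuming the dictionary transforms trace the way the test expects:

```python
import torch
from monai.data import MetaTensor
from monai.transforms import AddChanneld, Compose, SpatialPadd

t = Compose([AddChanneld(["label"]), SpatialPadd(["label"], (150, 153))])
d = t({"label": MetaTensor(torch.zeros(100, 101))})
assert d["label"].applied_operations  # trace entries recorded during the forward pass

inv = t.inverse(d)["label"]           # replays the recorded trace in reverse
assert tuple(inv.shape[-2:]) == (100, 101)  # spatial padding undone
```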
diff --git a/tests/test_inverse_collation.py b/tests/test_inverse_collation.py
index 7cc5f77941..4614432808 100644
--- a/tests/test_inverse_collation.py
+++ b/tests/test_inverse_collation.py
@@ -17,11 +17,19 @@
 import torch
 from parameterized import parameterized

-from monai.data import CacheDataset, DataLoader, create_test_image_2d, create_test_image_3d, pad_list_data_collate
+from monai.data import (
+    CacheDataset,
+    DataLoader,
+    MetaTensor,
+    create_test_image_2d,
+    create_test_image_3d,
+    decollate_batch,
+    pad_list_data_collate,
+)
 from monai.transforms import (
     AddChanneld,
     Compose,
-    FromMetaTensord,
+    Flipd,
     LoadImaged,
     RandAffined,
     RandAxisFlipd,
@@ -30,7 +38,7 @@
     RandRotated,
     RandZoomd,
     ResizeWithPadOrCropd,
-    ToTensord,
+    Rotated,
 )
 from monai.utils import optional_import, set_determinism
 from tests.utils import make_nifti_image
@@ -47,10 +55,12 @@
     (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 3)
     for collate_fn in [None, pad_list_data_collate]
     for t in [
+        Flipd(KEYS, spatial_axis=1),
         RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1, 2]),
         RandAxisFlipd(keys=KEYS, prob=0.5),
-        Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2)), ToTensord(keys=KEYS)]),
+        Compose([RandRotate90d(keys=KEYS, spatial_axes=(1, 2))]),
         RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
+        Rotated(keys=KEYS, angle=np.pi, dtype=np.float64),
         RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),
         RandAffined(
             keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -62,10 +72,12 @@
     (t.__class__.__name__ + (" pad_list_data_collate" if collate_fn else " default_collate"), t, collate_fn, 2)
     for collate_fn in [None, pad_list_data_collate]
     for t in [
+        Flipd(KEYS, spatial_axis=1),
         RandFlipd(keys=KEYS, prob=0.5, spatial_axis=[1]),
         RandAxisFlipd(keys=KEYS, prob=0.5),
-        Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1)), ToTensord(keys=KEYS)]),
+        Compose([RandRotate90d(keys=KEYS, prob=0.5, spatial_axes=(0, 1))]),
         RandZoomd(keys=KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
+        Rotated(keys=KEYS, angle=np.pi / 2, dtype=np.float64),
         RandRotated(keys=KEYS, prob=0.5, range_x=np.pi, dtype=np.float64),
         RandAffined(
             keys=KEYS, prob=0.5, rotate_range=np.pi, device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -85,12 +97,12 @@ def setUp(self):
         b_size = 11
         im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107))
-        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS), FromMetaTensord(KEYS)])
+        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)])
         self.data_3d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)]

         b_size = 8
         im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_2d(62, 37, rad_max=10))
-        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS), FromMetaTensord(KEYS)])
+        load_ims = Compose([LoadImaged(KEYS), AddChanneld(KEYS)])
         self.data_2d = [load_ims({"image": im_fname, "label": seg_fname}) for _ in range(b_size)]

         self.batch_size = 7
@@ -100,11 +112,12 @@ def tearDown(self):

     @parameterized.expand(TESTS_2D + TESTS_3D)
     def test_collation(self, _, transform, collate_fn, ndim):
+        """transform, collate_fn, ndim"""
         data = self.data_3d if ndim == 3 else self.data_2d
         if collate_fn:
             modified_transform = transform
         else:
-            modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100), ToTensord(KEYS)])
+            modified_transform = Compose([transform, ResizeWithPadOrCropd(KEYS, 100)])

         # num workers = 0 for mac or gpu transforms
         num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
@@ -113,9 +126,20 @@ def test_collation(self, _, transform, collate_fn, ndim):
         loader = DataLoader(dataset, num_workers, batch_size=self.batch_size, collate_fn=collate_fn)

         for item in loader:
-            np.testing.assert_array_equal(
-                item["image_transforms"][1]["do_transforms"], item["label_transforms"][1]["do_transforms"]
-            )
+            if isinstance(item, dict):
+                np.testing.assert_array_equal(item["image"].shape, item["label"].shape)
+                continue
+            d = decollate_batch(item)
+            self.assertTrue(len(d) <= self.batch_size)
+            for b in d:
+                self.assertTrue(isinstance(b["image"], MetaTensor))
+                np.testing.assert_array_equal(
+                    b["image"].applied_operations[-1]["orig_size"], b["label"].applied_operations[-1]["orig_size"]
+                )
+                np.testing.assert_array_equal(
+                    b["image"].applied_operations[-1].get("_do_transform"),
+                    b["label"].applied_operations[-1].get("_do_transform"),
+                )


 if __name__ == "__main__":
diff --git a/tests/test_invertd.py b/tests/test_invertd.py
index 92ec30acc5..7ca562a2d9 100644
--- a/tests/test_invertd.py
+++ b/tests/test_invertd.py
@@ -9,9 +9,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import sys
 import unittest

+import numpy as np
+import torch
+
+from monai.data import DataLoader, Dataset, create_test_image_3d, decollate_batch
+from monai.transforms import (
+    CastToTyped,
+    Compose,
+    CopyItemsd,
+    EnsureChannelFirstd,
+    Invertd,
+    LoadImaged,
+    Orientationd,
+    RandAffined,
+    RandAxisFlipd,
+    RandFlipd,
+    RandRotate90d,
+    RandRotated,
+    RandZoomd,
+    ResizeWithPadOrCropd,
+    ScaleIntensityd,
+    Spacingd,
+)
 from monai.utils import set_determinism
+from tests.utils import make_nifti_image

 KEYS = ["image", "label"]

@@ -19,125 +43,99 @@
 class TestInvertd(unittest.TestCase):
     def test_invert(self):
         set_determinism(seed=0)
-        # im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
-        # transform = Compose(
-        #     [
-        #         LoadImaged(KEYS),
-        #         AddChanneld(KEYS),
-        #         Orientationd(KEYS, "RPS"),
-        #         Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
-        #         FromMetaTensord(KEYS),
-        #         ScaleIntensityd("image", minv=1, maxv=10),
-        #         RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
-        #         RandAxisFlipd(KEYS, prob=0.5),
-        #         RandRotate90d(KEYS, spatial_axes=(1, 2)),
-        #         RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
-        #         RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
-        #         RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
-        #         ResizeWithPadOrCropd(KEYS, 100),
-        #         # test EnsureTensor for complicated dict data and invert it
-        #         CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
-        #         # test to support Tensor, Numpy array and dictionary when inverting
-        #         EnsureTyped(keys=["image", "test_dict"]),
-        #         ToTensord("image"),
-        #         CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
-        #         CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
-        #         CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
-        #     ]
-        # )
-        # data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]
-
-        # # num workers = 0 for mac or gpu transforms
-        # num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
-
-        # dataset = CacheDataset(data, transform=transform, progress=False)
-        # loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)
-        # inverter = Invertd(
-        #     # `image` was not copied, invert the original value directly
-        #     keys=["image_inverted", "label_inverted", "test_dict"],
-        #     transform=transform,
-        #     orig_keys=["label", "label", "test_dict"],
-        #     meta_keys=[PostFix.meta("image_inverted"), PostFix.meta("label_inverted"), None],
-        #     orig_meta_keys=[PostFix.meta("label"), PostFix.meta("label"), None],
-        #     nearest_interp=True,
-        #     to_tensor=[True, False, False],
-        #     device="cpu",
-        # )
-
-        # inverter_1 = Invertd(
-        #     # `image` was not copied, invert the original value directly
-        #     keys=["image_inverted1", "label_inverted1"],
-        #     transform=transform,
-        #     orig_keys=["image", "image"],
-        #     meta_keys=[PostFix.meta("image_inverted1"), PostFix.meta("label_inverted1")],
-        #     orig_meta_keys=[PostFix.meta("image"), PostFix.meta("image")],
-        #     nearest_interp=[True, False],
-        #     to_tensor=[True, True],
-        #     device="cpu",
-        # )
-
-        # expected_keys = [
-        #     "image",
-        #     "image_inverted",
-        #     "image_inverted1",
-        #     PostFix.meta("image_inverted1"),
-        #     PostFix.meta("image_inverted"),
-        #     PostFix.meta("image"),
-        #     "image_transforms",
-        #     "label",
-        #     "label_inverted",
-        #     "label_inverted1",
-        #     PostFix.meta("label_inverted1"),
-        #     PostFix.meta("label_inverted"),
-        #     PostFix.meta("label"),
-        #     "label_transforms",
-        #     "test_dict",
-        #     "test_dict_transforms",
-        # ]
-        # # execute 1 epoch
-        # for d in loader:
-        #     d = decollate_batch(d)
-        #     for item in d:
-        #         item = inverter(item)
-        #         item = inverter_1(item)
-        #
-        #         self.assertListEqual(sorted(item), expected_keys)
-        #         self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
-        #         self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
-        #         # check the nearest interpolation mode
-        #         i = item["image_inverted"]
-        #         torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
-        #         self.assertTupleEqual(i.shape[1:], (100, 101, 107))
-        #         i = item["label_inverted"]
-        #         torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
-        #         self.assertTupleEqual(i.shape[1:], (100, 101, 107))
-        #         # test inverted test_dict
-        #         self.assertTrue(isinstance(item["test_dict"]["affine"], np.ndarray))
-        #         self.assertTrue(isinstance(item["test_dict"]["filename_or_obj"], str))
-        #
-        #         # check the case that different items use different interpolation mode to invert transforms
-        #         d = item["image_inverted1"]
-        #         # if the interpolation mode is nearest, accumulated diff should be smaller than 1
-        #         self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
-        #         self.assertTupleEqual(d.shape, (1, 100, 101, 107))
-        #
-        #         d = item["label_inverted1"]
-        #         # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
-        #         self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
-        #         self.assertTupleEqual(d.shape, (1, 100, 101, 107))
-        #
-        #         # check labels match
-        #         reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
-        #         original = LoadImaged(KEYS)(data[-1])["label"]
-        #         n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
-        #         reverted_name = item["label_inverted"].meta["filename_or_obj"]
-        #         original_name = data[-1]["label"]
-        #         self.assertEqual(reverted_name, original_name)
-        #         print("invert diff", reverted.size - n_good)
-        #         # 25300: 2 workers (cpu, non-macos)
-        #         # 1812: 0 workers (gpu or macos)
-        #         # 1821: windows torch 1.10.0
-        #         self.assertTrue((reverted.size - n_good) in (34007, 1812, 1821), f"diff. {reverted.size - n_good}")
+        im_fname, seg_fname = (make_nifti_image(i) for i in create_test_image_3d(101, 100, 107, noise_max=100))
+        transform = Compose(
+            [
+                LoadImaged(KEYS),
+                EnsureChannelFirstd(KEYS),
+                Orientationd(KEYS, "RPS"),
+                Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
+                ScaleIntensityd("image", minv=1, maxv=10),
+                RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
+                RandAxisFlipd(KEYS, prob=0.5),
+                RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)),
+                RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
+                RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64),
+                RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
+                ResizeWithPadOrCropd(KEYS, 100),
+                # test EnsureTensor for complicated dict data and invert it
+                # CopyItemsd(PostFix.meta("image"), times=1, names="test_dict"),
+                # test to support Tensor, Numpy array and dictionary when inverting
+                # EnsureTyped(keys=["image", "test_dict"]),
+                # ToTensord("image"),
+                CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]),
+                CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),
+                CopyItemsd("image", times=2, names=["image_inverted", "image_inverted1"]),
+            ]
+        )
+        data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]
+
+        # num workers = 0 for mac or gpu transforms
+        num_workers = 0 if sys.platform != "linux" or torch.cuda.is_available() else 2
+
+        dataset = Dataset(data, transform=transform)
+        transform.inverse(dataset[0])
+        loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
+        inverter = Invertd(
+            # `image` was not copied, invert the original value directly
+            keys=["image_inverted", "label_inverted"],
+            transform=transform,
+            orig_keys=["label", "label"],
+            nearest_interp=True,
+            device="cpu",
+        )
+
+        inverter_1 = Invertd(
+            # `image` was not copied, invert the original value directly
+            keys=["image_inverted1", "label_inverted1"],
+            transform=transform,
+            orig_keys=["image", "image"],
+            nearest_interp=[True, False],
+            device="cpu",
+        )
+
+        expected_keys = ["image", "image_inverted", "image_inverted1", "label", "label_inverted", "label_inverted1"]
+        # execute 1 epoch
+        for d in loader:
+            d = decollate_batch(d)
+            for item in d:
+                item = inverter(item)
+                item = inverter_1(item)
+
+                self.assertListEqual(sorted(item), expected_keys)
+                self.assertTupleEqual(item["image"].shape[1:], (100, 100, 100))
+                self.assertTupleEqual(item["label"].shape[1:], (100, 100, 100))
+                # check the nearest interpolation mode
+                i = item["image_inverted"]
+                torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
+                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
+                i = item["label_inverted"]
+                torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
+                self.assertTupleEqual(i.shape[1:], (100, 101, 107))
+
+                # check the case that different items use different interpolation mode to invert transforms
+                d = item["image_inverted1"]
+                # if the interpolation mode is nearest, accumulated diff should be smaller than 1
+                self.assertLess(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 1.0)
+                self.assertTupleEqual(d.shape, (1, 100, 101, 107))
+
+                d = item["label_inverted1"]
+                # if the interpolation mode is not nearest, accumulated diff should be greater than 10000
+                self.assertGreater(torch.sum(d.to(torch.float) - d.to(torch.uint8).to(torch.float)).item(), 10000.0)
+                self.assertTupleEqual(d.shape, (1, 100, 101, 107))
+
+                # check labels match
+                reverted = item["label_inverted"].detach().cpu().numpy().astype(np.int32)
+                original = LoadImaged(KEYS)(data[-1])["label"]
+                n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
+                reverted_name = item["label_inverted"].meta["filename_or_obj"]
+                original_name = data[-1]["label"]
+                self.assertEqual(reverted_name, original_name)
+                print("invert diff", reverted.size - n_good)
+                # 25300: 2 workers (cpu, non-macos)
+                # 1812: 0 workers (gpu or macos)
+                # 1821: windows torch 1.10.0
+                self.assertTrue((reverted.size - n_good) < 40000, f"diff. {reverted.size - n_good}")

         set_determinism(seed=None)
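Compared with the commented-out version, `Invertd` drops the `meta_keys`/`orig_meta_keys`/`to_tensor` plumbing entirely: the transform trace and metadata now travel on the `MetaTensor` named by `orig_keys`. A minimal sketch of the simplified wiring, assuming a `transform` pipeline and a decollated `item` as in the test above:

```python
from monai.transforms import Invertd

inverter = Invertd(
    keys="pred",           # entry whose values should be inverted
    transform=transform,   # the forward pipeline whose recorded trace is replayed
    orig_keys="image",     # use the applied_operations recorded on item["image"]
    nearest_interp=True,   # switch interpolation to nearest while inverting
    device="cpu",
)
# item = inverter(item)   # applied per decollated item, as in the loop above
```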
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 976d8e1e0f..d8ab0823d9 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -296,7 +296,7 @@ def test_collate(self, device, dtype):
         self.assertIsInstance(collated.affine, torch.Tensor)
         expected_shape = (numel,) + tuple(ims[0].affine.shape)
         self.assertTupleEqual(tuple(collated.affine.shape), expected_shape)
-        self.assertEqual(len(collated.applied_operations), 1)
+        self.assertEqual(len(collated.applied_operations), numel)

     @parameterized.expand(TESTS)
     def test_dataset(self, device, dtype):
@@ -321,7 +321,7 @@ def test_dataloader(self, dtype):
             self.assertIsInstance(batch, MetaTensor)
             self.assertTupleEqual(tuple(batch.shape), expected_im_shape)
             self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape)
-            self.assertEqual(len(batch.applied_operations), 1)
+            self.assertEqual(len(batch.applied_operations), batch_size)

     @SkipIfBeforePyTorchVersion((1, 9))
     def test_indexing(self):
diff --git a/tests/test_pad_collation.py b/tests/test_pad_collation.py
index 530e5f86a3..9ea3a7bc73 100644
--- a/tests/test_pad_collation.py
+++ b/tests/test_pad_collation.py
@@ -31,7 +31,6 @@
     RandZoom,
     RandZoomd,
     ToTensor,
-    ToTensord,
 )
 from monai.utils import set_determinism

@@ -44,7 +43,9 @@
     TESTS.append((dict, pad_collate, RandSpatialCropd("image", roi_size=[8, 7], random_size=True)))
     TESTS.append((dict, pad_collate, RandRotated("image", prob=1, range_x=np.pi, keep_size=False, dtype=np.float64)))
     TESTS.append((dict, pad_collate, RandZoomd("image", prob=1, min_zoom=1.1, max_zoom=2.0, keep_size=False)))
-    TESTS.append((dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=2), ToTensord("image")])))
+    TESTS.append(
+        (dict, pad_collate, Compose([RandRotate90d("image", prob=1, max_k=3), RandRotate90d("image", prob=1, max_k=4)]))
+    )

     TESTS.append((list, pad_collate, RandSpatialCrop(roi_size=[8, 7], random_size=True)))
     TESTS.append((list, pad_collate, RandRotate(prob=1, range_x=np.pi, keep_size=False, dtype=np.float64)))
diff --git a/tests/test_rand_affine.py b/tests/test_rand_affine.py
index dcfe193213..363ea93650 100644
--- a/tests/test_rand_affine.py
+++ b/tests/test_rand_affine.py
@@ -17,7 +17,7 @@

 from monai.transforms import RandAffine
 from monai.utils.type_conversion import convert_data_type
-from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env
+from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env, test_local_inversion

 _rtol = 1e-3 if is_tf32_env() else 1e-4

@@ -144,7 +144,9 @@ def test_rand_affine(self, input_param, input_data, expected_val):
         result = g(**input_data)
         if input_param.get("cache_grid", False):
             self.assertTrue(g._cached_grid is not None)
-        assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4)
+        test_local_inversion(g, result, input_data)
+
+        assert_allclose(result, expected_val, rtol=_rtol, atol=1e-4, type_test=False)

     def test_ill_cache(self):
         with self.assertWarns(UserWarning):
@@ -189,8 +191,8 @@ def test_no_randomize(self, initial_randomize, cache_grid):
             arr2 = rand_affine(arr, randomize=False)
             m2 = rand_affine.rand_affine_grid.get_transformation_matrix()

-            assert_allclose(m1, m2)
-            assert_allclose(arr1, arr2)
+            assert_allclose(m1, m2, type_test=False)
+            assert_allclose(arr1, arr2, type_test=False)


 if __name__ == "__main__":
diff --git a/tests/test_rand_affined.py b/tests/test_rand_affined.py
index 4f549cb7ab..e5f2582dd1 100644
--- a/tests/test_rand_affined.py
+++ b/tests/test_rand_affined.py
@@ -15,14 +15,15 @@
 import torch
 from parameterized import parameterized

+from monai.data import MetaTensor
 from monai.transforms import RandAffined
 from monai.utils import GridSampleMode
-from tests.utils import TEST_NDARRAYS_NO_META_TENSOR, assert_allclose, is_tf32_env
+from tests.utils import TEST_NDARRAYS, assert_allclose, is_tf32_env

 _rtol = 1e-3 if is_tf32_env() else 1e-4

 TESTS = []
-for p in TEST_NDARRAYS_NO_META_TENSOR:
+for p in TEST_NDARRAYS:
     for device in [None, "cpu", "cuda"] if torch.cuda.is_available() else [None, "cpu"]:
         TESTS.append(
             [
@@ -211,12 +212,17 @@ def test_rand_affined(self, input_param, input_data, expected_val):
             if "_transforms" in key:
                 continue
             expected = expected_val[key] if isinstance(expected_val, dict) else expected_val
-            assert_allclose(result, expected, rtol=_rtol, atol=1e-3)
+            assert_allclose(result, expected, rtol=_rtol, atol=1e-3, type_test=False)

         g.set_random_state(4)
         res = g(input_data)
         # affine should be tensor because the resampler only supports pytorch backend
-        self.assertTrue(isinstance(res["img_transforms"][0]["extra_info"]["affine"], torch.Tensor))
+        if isinstance(res["img"], MetaTensor) and "extra_info" in res["img"].applied_operations[0]:
+            if not res["img"].applied_operations[-1]["extra_info"]["do_resampling"]:
+                return
+            affine_img = res["img"].applied_operations[0]["extra_info"]["affine"]
+            affine_seg = res["seg"].applied_operations[0]["extra_info"]["affine"]
+            assert_allclose(affine_img, affine_seg, rtol=_rtol, atol=1e-3)

     def test_ill_cache(self):
         with self.assertWarns(UserWarning):
diff --git a/tests/test_rand_axis_flip.py b/tests/test_rand_axis_flip.py
index b7c504557f..760f6c23ea 100644
--- a/tests/test_rand_axis_flip.py
+++ b/tests/test_rand_axis_flip.py
@@ -14,16 +14,18 @@
 import numpy as np

 from monai.transforms import RandAxisFlip
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion


 class TestRandAxisFlip(NumpyImageTestCase2D):
     def test_correct_results(self):
         for p in TEST_NDARRAYS:
             flip = RandAxisFlip(prob=1.0)
-            result = flip(p(self.imt[0]))
+            im = p(self.imt[0])
+            result = flip(im)

             expected = [np.flip(channel, flip._axis) for channel in self.imt[0]]
-            assert_allclose(result, p(np.stack(expected)))
+            assert_allclose(result, p(np.stack(expected)), type_test=False)
+            test_local_inversion(flip, result, im)


 if __name__ == "__main__":
diff --git a/tests/test_rand_axis_flipd.py b/tests/test_rand_axis_flipd.py
index ff97d5dc1e..3f2bc80194 100644
--- a/tests/test_rand_axis_flipd.py
+++ b/tests/test_rand_axis_flipd.py
@@ -14,17 +14,18 @@
 import numpy as np

 from monai.transforms import RandAxisFlipd
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase3D, assert_allclose, test_local_inversion


 class TestRandAxisFlip(NumpyImageTestCase3D):
     def test_correct_results(self):
         for p in TEST_NDARRAYS:
             flip = RandAxisFlipd(keys="img", prob=1.0)
-            result = flip({"img": p(self.imt[0])})["img"]
-
+            im = p(self.imt[0])
+            result = flip({"img": im})
+            test_local_inversion(flip, result, {"img": im}, "img")
             expected = [np.flip(channel, flip.flipper._axis) for channel in self.imt[0]]
-            assert_allclose(result, p(np.stack(expected)))
+            assert_allclose(result["img"], p(np.stack(expected)), type_test=False)


 if __name__ == "__main__":
diff --git a/tests/test_rand_flip.py b/tests/test_rand_flip.py
index b9e9a8c4d6..5d1723499f 100644
--- a/tests/test_rand_flip.py
+++ b/tests/test_rand_flip.py
@@ -15,7 +15,7 @@
 from parameterized import parameterized

 from monai.transforms import RandFlip
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion

 INVALID_CASES = [("wrong_axis", ["s", 1], TypeError), ("not_numbers", "s", TypeError)]

@@ -37,7 +37,8 @@ def test_correct_results(self, _, spatial_axis):
             expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
             expected = np.stack(expected)
             result = flip(im)
-            assert_allclose(result, p(expected))
+            assert_allclose(result, p(expected), type_test=False)
+            test_local_inversion(flip, result, im)


 if __name__ == "__main__":
diff --git a/tests/test_rand_flipd.py b/tests/test_rand_flipd.py
index 9a92661c59..edefeaf5bf 100644
--- a/tests/test_rand_flipd.py
+++ b/tests/test_rand_flipd.py
@@ -15,7 +15,7 @@
 from parameterized import parameterized

 from monai.transforms import RandFlipd
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion

 VALID_CASES = [("no_axis", None), ("one_axis", 1), ("many_axis", [0, 1])]

@@ -25,10 +25,12 @@ class TestRandFlipd(NumpyImageTestCase2D):
     def test_correct_results(self, _, spatial_axis):
         for p in TEST_NDARRAYS:
             flip = RandFlipd(keys="img", prob=1.0, spatial_axis=spatial_axis)
-            result = flip({"img": p(self.imt[0])})["img"]
+            im = p(self.imt[0])
+            result = flip({"img": im})["img"]
             expected = [np.flip(channel, spatial_axis) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(result, p(expected))
+            assert_allclose(result, p(expected), type_test=False)
+            test_local_inversion(flip, {"img": result}, {"img": im}, "img")


 if __name__ == "__main__":
diff --git a/tests/test_rand_rotate.py b/tests/test_rand_rotate.py
index 7a85fce23b..172bc6c59b 100644
--- a/tests/test_rand_rotate.py
+++ b/tests/test_rand_rotate.py
@@ -18,7 +18,7 @@
 from parameterized import parameterized

 from monai.transforms import RandRotate
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion

 TEST_CASES_2D: List[Tuple] = []
 for p in TEST_NDARRAYS:
@@ -108,8 +108,10 @@ def test_correct_results(self, im_type, x, y, z, keep_size, mode, padding_mode,
             dtype=np.float64,
         )
         rotate_fn.set_random_state(243)
-        rotated = rotate_fn(im_type(self.imt[0]))
+        im = im_type(self.imt[0])
+        rotated = rotate_fn(im)
         torch.testing.assert_allclose(rotated.shape, expected, rtol=1e-7, atol=0)
+        test_local_inversion(rotate_fn, rotated, im)


 if __name__ == "__main__":
diff --git a/tests/test_rand_rotate90.py b/tests/test_rand_rotate90.py
index b845944062..12c14508e2 100644
--- a/tests/test_rand_rotate90.py
+++ b/tests/test_rand_rotate90.py
@@ -14,7 +14,7 @@
 import numpy as np

 from monai.transforms import RandRotate90
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion


 class TestRandRotate90(NumpyImageTestCase2D):
@@ -22,37 +22,45 @@ def test_default(self):
         rotate = RandRotate90()
         for p in TEST_NDARRAYS:
             rotate.set_random_state(123)
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)

     def test_k(self):
         rotate = RandRotate90(max_k=2)
         for p in TEST_NDARRAYS:
             rotate.set_random_state(123)
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)

     def test_spatial_axes(self):
-        rotate = RandRotate90(spatial_axes=(0, 1))
+        rotate = RandRotate90(spatial_axes=(0, 1), prob=1.0)
         for p in TEST_NDARRAYS:
-            rotate.set_random_state(123)
-            rotated = rotate(p(self.imt[0]))
-            expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
+            rotate.set_random_state(1234)
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            expected = [np.rot90(channel, rotate._rand_k, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
+            test_local_inversion(rotate, rotated, im)

     def test_prob_k_spatial_axes(self):
         rotate = RandRotate90(prob=1.0, max_k=2, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
             rotate.set_random_state(234)
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)


 if __name__ == "__main__":
diff --git a/tests/test_rand_rotate90d.py b/tests/test_rand_rotate90d.py
index ded18e430a..690bfcd66d 100644
--- a/tests/test_rand_rotate90d.py
+++ b/tests/test_rand_rotate90d.py
@@ -14,7 +14,7 @@
 import numpy as np

 from monai.transforms import RandRotate90d
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion


 class TestRandRotate90d(NumpyImageTestCase2D):
@@ -22,41 +22,49 @@ def test_default(self):
         key = None
         rotate = RandRotate90d(keys=key)
         for p in TEST_NDARRAYS:
-            rotate.set_random_state(123)
-            rotated = rotate({key: p(self.imt[0])})
+            rotate.set_random_state(1323)
+            im = {key: p(self.imt[0])}
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im, key)
             expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)

     def test_k(self):
         key = "test"
         rotate = RandRotate90d(keys=key, max_k=2)
         for p in TEST_NDARRAYS:
             rotate.set_random_state(234)
-            rotated = rotate({key: p(self.imt[0])})
+            im = {key: p(self.imt[0])}
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im, key)
             expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)

     def test_spatial_axes(self):
         key = "test"
         rotate = RandRotate90d(keys=key, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
             rotate.set_random_state(234)
-            rotated = rotate({key: p(self.imt[0])})
+            im = {key: p(self.imt[0])}
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im, key)
             expected = [np.rot90(channel, 0, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)

     def test_prob_k_spatial_axes(self):
         key = "test"
         rotate = RandRotate90d(keys=key, prob=1.0, max_k=2, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
             rotate.set_random_state(234)
-            rotated = rotate({key: p(self.imt[0])})
+            im = {key: p(self.imt[0])}
+            rotated = rotate(im)
             expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)
+            test_local_inversion(rotate, rotated, im, key)

     def test_no_key(self):
         key = "unknown"
diff --git a/tests/test_rand_rotated.py b/tests/test_rand_rotated.py
index 464b37d925..b6b2798b1b 100644
--- a/tests/test_rand_rotated.py
+++ b/tests/test_rand_rotated.py
@@ -19,7 +19,7 @@

 from monai.transforms import RandRotated
 from monai.utils import GridSampleMode, GridSamplePadMode
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion

 TEST_CASES_2D: List[Tuple] = []
 for p in TEST_NDARRAYS:
@@ -118,8 +118,9 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode,
             align_corners=align_corners,
             dtype=np.float64,
         )
+        im = im_type(self.imt[0])
         rotate_fn.set_random_state(243)
-        rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
+        rotated = rotate_fn({"img": im, "seg": im_type(self.segn[0])})

         _order = 0 if mode == "nearest" else 1
         if padding_mode == "border":
@@ -132,6 +133,7 @@ def test_correct_results(self, im_type, degrees, keep_size, mode, padding_mode,
         expected = scipy.ndimage.rotate(
             self.imt[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=_order, mode=_mode, prefilter=False
         )
+        test_local_inversion(rotate_fn, rotated, {"img": im}, "img")
         for k, v in rotated.items():
             rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v
         expected = np.stack(expected).astype(np.float32)
diff --git a/tests/test_rand_zoom.py b/tests/test_rand_zoom.py
index 35472024ef..71f20b0de7 100644
--- a/tests/test_rand_zoom.py
+++ b/tests/test_rand_zoom.py
@@ -17,7 +17,7 @@

 from monai.transforms import RandZoom
 from monai.utils import GridSampleMode, InterpolateMode
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion

 VALID_CASES = [(0.8, 1.2, "nearest", False), (0.8, 1.2, InterpolateMode.NEAREST, False)]

@@ -28,20 +28,24 @@ def test_correct_results(self, min_zoom, max_zoom, mode, keep_size):
         for p in TEST_NDARRAYS:
             random_zoom = RandZoom(prob=1.0, min_zoom=min_zoom, max_zoom=max_zoom, mode=mode, keep_size=keep_size)
             random_zoom.set_random_state(1234)
-            zoomed = random_zoom(p(self.imt[0]))
+            im = p(self.imt[0])
+            zoomed = random_zoom(im)
+            test_local_inversion(random_zoom, zoomed, im)
             expected = [
                 zoom_scipy(channel, zoom=random_zoom._zoom, mode="nearest", order=0, prefilter=False)
                 for channel in self.imt[0]
             ]

             expected = np.stack(expected).astype(np.float32)
-            assert_allclose(zoomed, p(expected), atol=1.0)
+            assert_allclose(zoomed, p(expected), atol=1.0, type_test=False)

     def test_keep_size(self):
         for p in TEST_NDARRAYS:
             im = p(self.imt[0])
             random_zoom = RandZoom(prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True)
+            random_zoom.set_random_state(12)
             zoomed = random_zoom(im)
+            test_local_inversion(random_zoom, zoomed, im)
             self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
             zoomed = random_zoom(im)
             self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape[1:]))
@@ -67,8 +71,8 @@ def test_auto_expand_3d(self):
             random_zoom.set_random_state(1234)
             test_data = p(np.random.randint(0, 2, size=[2, 2, 3, 4]))
             zoomed = random_zoom(test_data)
-            assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2)
-            assert_allclose(zoomed.shape, (2, 2, 3, 3))
+            assert_allclose(random_zoom._zoom, (1.048844, 1.048844, 0.962637), atol=1e-2, type_test=False)
+            assert_allclose(zoomed.shape, (2, 2, 3, 3), type_test=False)


 if __name__ == "__main__":
diff --git a/tests/test_rand_zoomd.py b/tests/test_rand_zoomd.py
index a22f2f36f1..ee82fb7917 100644
--- a/tests/test_rand_zoomd.py
+++ b/tests/test_rand_zoomd.py
@@ -16,7 +16,7 @@
 from scipy.ndimage import zoom as zoom_scipy

 from monai.transforms import RandZoomd
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion

 VALID_CASES = [(0.8, 1.2, "nearest", None, False)]

@@ -37,14 +37,16 @@ def test_correct_results(self, min_zoom, max_zoom, mode, align_corners, keep_siz

         for p in TEST_NDARRAYS:
             random_zoom.set_random_state(1234)
-            zoomed = random_zoom({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            zoomed = random_zoom({key: im})
+            test_local_inversion(random_zoom, zoomed, {key: im}, key)
             expected = [
                 zoom_scipy(channel, zoom=random_zoom.rand_zoom._zoom, mode="nearest", order=0, prefilter=False)
                 for channel in self.imt[0]
             ]

             expected = np.stack(expected).astype(np.float32)
-            assert_allclose(zoomed[key], p(expected), atol=1.0)
+            assert_allclose(zoomed[key], p(expected), atol=1.0, type_test=False)

     def test_keep_size(self):
         key = "img"
@@ -52,7 +54,9 @@ def test_keep_size(self):
             keys=key, prob=1.0, min_zoom=0.6, max_zoom=0.7, keep_size=True, padding_mode="constant", constant_values=2
         )
         for p in TEST_NDARRAYS:
-            zoomed = random_zoom({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            zoomed = random_zoom({key: im})
+            test_local_inversion(random_zoom, zoomed, {key: im}, key)
             np.testing.assert_array_equal(zoomed[key].shape, self.imt.shape[1:])

     @parameterized.expand(
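The `test_resample_to_matchd` changes that follow replace every `*_meta_dict` lookup with attribute access on the tensor itself; the same spatial metadata is now one attribute away. A sketch of the pattern, assuming a `data` dict of MetaTensors as in that test:

```python
# before: data["im3_meta_dict"]["affine"], data["im3_meta_dict"]["filename_or_obj"]
affine = data["im3"].affine                   # affine travels on the MetaTensor
fname = data["im3"].meta["filename_or_obj"]   # other fields remain in .meta
```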
diff --git a/tests/test_resample_to_matchd.py b/tests/test_resample_to_matchd.py
index baf4df7b19..a63f12f426 100644
--- a/tests/test_resample_to_matchd.py
+++ b/tests/test_resample_to_matchd.py
@@ -18,7 +18,6 @@
     Compose,
     CopyItemsd,
     EnsureChannelFirstd,
-    FromMetaTensord,
     Invertd,
     Lambda,
     LoadImaged,
@@ -29,7 +28,7 @@


 def update_fname(d):
-    d["im3_meta_dict"]["filename_or_obj"] = "file3.nii.gz"
+    d["im3"].meta["filename_or_obj"] = "file3.nii.gz"
     return d


@@ -59,22 +58,18 @@ def test_correct(self):
                 EnsureChannelFirstd(("im1", "im2")),
                 CopyItemsd(("im2"), names=("im3")),
                 ResampleToMatchd("im3", "im1"),
-                FromMetaTensord("im3"),
                 Lambda(update_fname),
-                SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False),
+                SaveImaged("im3", output_dir=self.tmpdir, output_postfix="", separate_folder=False, resample=False),
             ]
         )
         data = transforms({"im1": self.fnames[0], "im2": self.fnames[1]})
         # check that output sizes match
         assert_allclose(data["im1"].shape, data["im3"].shape)
         # and that the meta data has been updated accordingly
-        # assert_allclose(data["im3"].shape[1:], data["im3_meta_dict"]["spatial_shape"], type_test=False)
-        assert_allclose(data["im3_meta_dict"]["affine"], data["im1"].affine)
+        assert_allclose(data["im3"].affine, data["im1"].affine)
         # check we're different from the original
         self.assertTrue(any(i != j for i, j in zip(data["im3"].shape, data["im2"].shape)))
-        self.assertTrue(
-            any(i != j for i, j in zip(data["im3_meta_dict"]["affine"].flatten(), data["im2"].affine.flatten()))
-        )
+        self.assertTrue(any(i != j for i, j in zip(data["im3"].affine.flatten(), data["im2"].affine.flatten())))
         # test the inverse
         data = Invertd("im3", transforms, "im3")(data)
         assert_allclose(data["im2"].shape, data["im3"].shape)
diff --git a/tests/test_resize.py b/tests/test_resize.py
index 5f946a13e3..43fff6c08f 100644
--- a/tests/test_resize.py
+++ b/tests/test_resize.py
@@ -16,6 +16,7 @@
 import torch
 from parameterized import parameterized

+from monai.data import MetaTensor
 from monai.transforms import Resize
 from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, is_tf32_env, pytorch_after

@@ -60,6 +61,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):
             _order = 1
         if spatial_size == (32, -1):
             spatial_size = (32, 64)
+
         expected = [
             skimage.transform.resize(
                 channel, spatial_size, order=_order, clip=False, preserve_range=False, anti_aliasing=anti_aliasing
@@ -69,19 +71,27 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):

         expected = np.stack(expected).astype(np.float32)
         for p in TEST_NDARRAYS:
-            out = resize(p(self.imt[0]))
+            im = p(self.imt[0])
+            out = resize(im)
+            if isinstance(im, MetaTensor):
+                if not out.applied_operations:
+                    return  # skipped because good shape
+                im_inv = resize.inverse(out)
+                self.assertTrue(not im_inv.applied_operations)
+                assert_allclose(im_inv.shape, im.shape)
+                assert_allclose(im_inv.affine, im.affine, atol=1e-3, rtol=1e-3)
             if not anti_aliasing:
                 assert_allclose(out, expected, type_test=False, atol=0.9)
-            else:
-                # skimage uses reflect padding for anti-aliasing filter.
-                # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead.
-                # Thus their results near the image boundary will be different.
-                if isinstance(out, torch.Tensor):
-                    out = out.cpu().detach().numpy()
-                good = np.sum(np.isclose(expected, out, atol=0.9))
-                self.assertLessEqual(
-                    np.abs(good - expected.size) / float(expected.size), diff_t, f"at most {diff_t} percent mismatch "
-                )
+                return
+            # skimage uses reflect padding for anti-aliasing filter.
+            # Our implementation reuses GaussianSmooth() as anti-aliasing filter, which uses zero padding instead.
+            # Thus their results near the image boundary will be different.
+            if isinstance(out, torch.Tensor):
+                out = out.cpu().detach().numpy()
+            good = np.sum(np.isclose(expected, out, atol=0.9))
+            self.assertLessEqual(
+                np.abs(good - expected.size) / float(expected.size), diff_t, f"at most {diff_t} percent mismatch "
+            )

     @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
     def test_longest_shape(self, input_param, expected_shape):
diff --git a/tests/test_resized.py b/tests/test_resized.py
index d7374ea930..732c141123 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -16,7 +16,7 @@
 from parameterized import parameterized

 from monai.transforms import Resized
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion

 TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)]

@@ -57,8 +57,10 @@ def test_correct_results(self, spatial_size, mode):

         expected = np.stack(expected).astype(np.float32)
         for p in TEST_NDARRAYS:
-            out = resize({"img": p(self.imt[0])})["img"]
-            assert_allclose(out, expected, type_test=False, atol=0.9)
+            im = p(self.imt[0])
+            out = resize({"img": im})
+            test_local_inversion(resize, out, {"img": im}, "img")
+            assert_allclose(out["img"], expected, type_test=False, atol=0.9)

     @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
     def test_longest_shape(self, input_param, expected_shape):
diff --git a/tests/test_rotate.py b/tests/test_rotate.py
index 01842f6d73..d174973f26 100644
--- a/tests/test_rotate.py
+++ b/tests/test_rotate.py
@@ -18,7 +18,7 @@
 from parameterized import parameterized

 from monai.transforms import Rotate
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion

 TEST_CASES_2D: List[Tuple] = []
 for p in TEST_NDARRAYS:
@@ -101,8 +101,10 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
     @parameterized.expand(TEST_CASES_SHAPE_3D)
     def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners):
         rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64)
-        rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode)
+        im = im_type(self.imt[0])
+        rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)
         np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
+        test_local_inversion(rotate_fn, rotated, im)

     def test_ill_case(self):
         for p in TEST_NDARRAYS:
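`test_local_inversion`, imported from `tests.utils` throughout these files, is not part of this patch. Judging purely from its call sites — `test_local_inversion(xform, out, im)` for array transforms, plus a trailing key for dictionary ones — a hypothetical reconstruction might look like:

```python
def test_local_inversion(xform, out, orig, key=None):
    # hypothetical sketch: invert the traced transform and check the
    # original geometry is recovered (actual helper lives in tests/utils.py)
    inv = xform.inverse(out)
    inv_im = inv[key] if key is not None else inv
    orig_im = orig[key] if key is not None else orig
    assert tuple(inv_im.shape) == tuple(orig_im.shape)
```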
diff --git a/tests/test_resized.py b/tests/test_resized.py
index d7374ea930..732c141123 100644
--- a/tests/test_resized.py
+++ b/tests/test_resized.py
@@ -16,7 +16,7 @@
 from parameterized import parameterized
 
 from monai.transforms import Resized
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion
 
 TEST_CASE_0 = [{"keys": "img", "spatial_size": 15}, (6, 10, 15)]
 
@@ -57,8 +57,10 @@ def test_correct_results(self, spatial_size, mode):
         expected = np.stack(expected).astype(np.float32)
 
         for p in TEST_NDARRAYS:
-            out = resize({"img": p(self.imt[0])})["img"]
-            assert_allclose(out, expected, type_test=False, atol=0.9)
+            im = p(self.imt[0])
+            out = resize({"img": im})
+            test_local_inversion(resize, out, {"img": im}, "img")
+            assert_allclose(out["img"], expected, type_test=False, atol=0.9)
 
     @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_2, TEST_CASE_3])
     def test_longest_shape(self, input_param, expected_shape):
diff --git a/tests/test_rotate.py b/tests/test_rotate.py
index 01842f6d73..d174973f26 100644
--- a/tests/test_rotate.py
+++ b/tests/test_rotate.py
@@ -18,7 +18,7 @@
 from parameterized import parameterized
 
 from monai.transforms import Rotate
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion
 
 TEST_CASES_2D: List[Tuple] = []
 for p in TEST_NDARRAYS:
@@ -101,8 +101,10 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
     @parameterized.expand(TEST_CASES_SHAPE_3D)
     def test_correct_shape(self, im_type, angle, mode, padding_mode, align_corners):
         rotate_fn = Rotate(angle, True, align_corners=align_corners, dtype=np.float64)
-        rotated = rotate_fn(im_type(self.imt[0]), mode=mode, padding_mode=padding_mode)
+        im = im_type(self.imt[0])
+        rotated = rotate_fn(im, mode=mode, padding_mode=padding_mode)
         np.testing.assert_allclose(self.imt[0].shape, rotated.shape)
+        test_local_inversion(rotate_fn, rotated, im)
 
     def test_ill_case(self):
         for p in TEST_NDARRAYS:
diff --git a/tests/test_rotate90.py b/tests/test_rotate90.py
index 9865120688..fe50cc5aac 100644
--- a/tests/test_rotate90.py
+++ b/tests/test_rotate90.py
@@ -14,41 +14,91 @@
 import numpy as np
 
 from monai.transforms import Rotate90
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, assert_allclose, test_local_inversion
 
 
 class TestRotate90(NumpyImageTestCase2D):
     def test_rotate90_default(self):
         rotate = Rotate90()
         for p in TEST_NDARRAYS:
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
 
     def test_k(self):
         rotate = Rotate90(k=2)
         for p in TEST_NDARRAYS:
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
 
     def test_spatial_axes(self):
         rotate = Rotate90(spatial_axes=(0, -1))
         for p in TEST_NDARRAYS:
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
 
     def test_prob_k_spatial_axes(self):
         rotate = Rotate90(k=2, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
-            rotated = rotate(p(self.imt[0]))
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
             expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
+
+
+class TestRotate903d(NumpyImageTestCase3D):
+    def test_rotate90_default(self):
+        rotate = Rotate90()
+        for p in TEST_NDARRAYS:
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
+            expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
+            expected = np.stack(expected)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
+
+    def test_k(self):
+        rotate = Rotate90(k=2)
+        for p in TEST_NDARRAYS:
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
+            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
+            expected = np.stack(expected)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
+
+    def test_spatial_axes(self):
+        rotate = Rotate90(spatial_axes=(0, -1))
+        for p in TEST_NDARRAYS:
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
+            expected = [np.rot90(channel, 1, (0, -1)) for channel in self.imt[0]]
+            expected = np.stack(expected)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
+
+    def test_prob_k_spatial_axes(self):
+        rotate = Rotate90(k=2, spatial_axes=(0, 1))
+        for p in TEST_NDARRAYS:
+            im = p(self.imt[0])
+            rotated = rotate(im)
+            test_local_inversion(rotate, rotated, im)
+            expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
+            expected = np.stack(expected)
+            assert_allclose(rotated, p(expected), rtol=1.0e-5, atol=1.0e-8, type_test=False)
 
 
 if __name__ == "__main__":
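The new TestRotate903d class repeats the 2D checks on 3D inputs; in every case test_local_inversion verifies that Rotate90.inverse undoes the tracked operation. A small sketch of that roundtrip on its own, assuming a MetaTensor-aware MONAI where Rotate90 records applied operations:

import torch
from monai.data import MetaTensor
from monai.transforms import Rotate90

im = MetaTensor(torch.arange(12.0).reshape(1, 3, 4), meta={"affine": torch.eye(4)})
rot = Rotate90(k=1, spatial_axes=(0, 1))
out = rot(im)            # spatial dims swap: (1, 3, 4) -> (1, 4, 3), op recorded on `out`
back = rot.inverse(out)  # replays the record, equivalent to rotating by k=3
assert back.shape == im.shape and back.applied_operations == []
assert torch.allclose(back.affine.float(), im.affine.float(), atol=1e-3)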
diff --git a/tests/test_rotate90d.py b/tests/test_rotate90d.py
index ef4bad9419..bf50bd88fb 100644
--- a/tests/test_rotate90d.py
+++ b/tests/test_rotate90d.py
@@ -14,7 +14,7 @@
 import numpy as np
 
 from monai.transforms import Rotate90d
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion
 
 
 class TestRotate90d(NumpyImageTestCase2D):
@@ -22,37 +22,45 @@ def test_rotate90_default(self):
         key = "test"
         rotate = Rotate90d(keys=key)
         for p in TEST_NDARRAYS:
-            rotated = rotate({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            rotated = rotate({key: im})
+            test_local_inversion(rotate, rotated, {key: im}, key)
             expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)
 
     def test_k(self):
         key = None
         rotate = Rotate90d(keys=key, k=2)
         for p in TEST_NDARRAYS:
-            rotated = rotate({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            rotated = rotate({key: im})
+            test_local_inversion(rotate, rotated, {key: im}, key)
             expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)
 
     def test_spatial_axes(self):
         key = "test"
         rotate = Rotate90d(keys=key, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
-            rotated = rotate({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            rotated = rotate({key: im})
+            test_local_inversion(rotate, rotated, {key: im}, key)
             expected = [np.rot90(channel, 1, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)
 
     def test_prob_k_spatial_axes(self):
         key = "test"
         rotate = Rotate90d(keys=key, k=2, spatial_axes=(0, 1))
         for p in TEST_NDARRAYS:
-            rotated = rotate({key: p(self.imt[0])})
+            im = p(self.imt[0])
+            rotated = rotate({key: im})
+            test_local_inversion(rotate, rotated, {key: im}, key)
             expected = [np.rot90(channel, 2, (0, 1)) for channel in self.imt[0]]
             expected = np.stack(expected)
-            assert_allclose(rotated[key], p(expected))
+            assert_allclose(rotated[key], p(expected), type_test=False)
 
     def test_no_key(self):
         key = "unknown"
diff --git a/tests/test_rotated.py b/tests/test_rotated.py
index 24ed82b84d..c7f46bc662 100644
--- a/tests/test_rotated.py
+++ b/tests/test_rotated.py
@@ -19,7 +19,7 @@
 
 from monai.data import MetaTensor
 from monai.transforms import Rotated
-from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D
+from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, NumpyImageTestCase3D, test_local_inversion
 
 TEST_CASES_2D: List[Tuple] = []
 for p in TEST_NDARRAYS:
@@ -44,7 +44,8 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
         rotate_fn = Rotated(
             ("img", "seg"), angle, keep_size, (mode, "nearest"), padding_mode, align_corners, dtype=np.float64
         )
-        rotated = rotate_fn({"img": im_type(self.imt[0]), "seg": im_type(self.segn[0])})
+        im = im_type(self.imt[0])
+        rotated = rotate_fn({"img": im, "seg": im_type(self.segn[0])})
         if keep_size:
             np.testing.assert_allclose(self.imt[0].shape, rotated["img"].shape)
         _order = 0 if mode == "nearest" else 1
@@ -61,6 +62,7 @@ def test_correct_results(self, im_type, angle, keep_size, mode, padding_mode, al
             rotated[k] = v.cpu() if isinstance(v, torch.Tensor) else v
         good = np.sum(np.isclose(expected, rotated["img"][0], atol=1e-3))
         self.assertLessEqual(np.abs(good - expected.size), 5, "diff at most 5 pixels")
+        test_local_inversion(rotate_fn, rotated, {"img": im}, "img")
 
         expected = scipy.ndimage.rotate(
             self.segn[0, 0], -np.rad2deg(angle), (0, 1), not keep_size, order=0, mode=_mode, prefilter=False
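Dictionary transforms follow the same pattern as the array ones: the operation is recorded on the value's applied_operations, so the same transform instance can invert a whole dict. A hedged sketch mirroring the Rotated test above:

import numpy as np
import torch
from monai.data import MetaTensor
from monai.transforms import Rotated

data = {"img": MetaTensor(torch.rand(1, 32, 32), meta={"affine": torch.eye(4)})}
rot = Rotated("img", angle=np.pi / 6, keep_size=True, padding_mode="zeros", dtype=np.float64)
out = rot(data)         # out["img"].applied_operations now holds one record
inv = rot.inverse(out)  # pops the record and applies the opposite rotation
assert inv["img"].shape == data["img"].shape and inv["img"].applied_operations == []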
diff --git a/tests/test_sliding_window_inference.py b/tests/test_sliding_window_inference.py
index e1fa7d600e..8b8ec47d32 100644
--- a/tests/test_sliding_window_inference.py
+++ b/tests/test_sliding_window_inference.py
@@ -15,9 +15,10 @@
 import torch
 from parameterized import parameterized
 
+from monai.data.utils import list_data_collate
 from monai.inferers import SlidingWindowInferer, sliding_window_inference
 from monai.utils import optional_import
-from tests.utils import skip_if_no_cuda
+from tests.utils import TEST_TORCH_AND_META_TENSORS, skip_if_no_cuda
 
 _, has_tqdm = optional_import("tqdm")
 
@@ -68,9 +69,11 @@ def compute(data):
             np.testing.assert_string_equal(device.type, result.device.type)
             np.testing.assert_allclose(result.cpu().numpy(), expected_val)
 
-    def test_default_device(self):
+    @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS])
+    def test_default_device(self, data_type):
         device = "cuda" if torch.cuda.is_available() else "cpu:0"
-        inputs = torch.ones((1, 3, 16, 15, 7)).to(device=device)
+        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device=device)
+        inputs = list_data_collate([inputs])  # make a proper batch
         roi_shape = (4, 10, 7)
         sw_batch_size = 10
@@ -82,9 +85,11 @@ def compute(data):
         expected_val = np.ones((1, 3, 16, 15, 7), dtype=np.float32) + 1
         np.testing.assert_allclose(result.cpu().numpy(), expected_val)
 
+    @parameterized.expand([[x] for x in TEST_TORCH_AND_META_TENSORS])
     @skip_if_no_cuda
-    def test_sw_device(self):
-        inputs = torch.ones((1, 3, 16, 15, 7)).to(device="cpu")
+    def test_sw_device(self, data_type):
+        inputs = data_type(torch.ones((3, 16, 15, 7))).to(device="cpu")
+        inputs = list_data_collate([inputs])  # make a proper batch
         roi_shape = (4, 10, 7)
         sw_batch_size = 10
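The list_data_collate([inputs]) idiom above is how a single channel-first tensor (Meta or plain) becomes a proper (1, C, ...) batch without losing metadata. Roughly, the changed tests now amount to:

import torch
from monai.data import MetaTensor
from monai.data.utils import list_data_collate
from monai.inferers import sliding_window_inference

inputs = MetaTensor(torch.ones(3, 16, 15, 7))  # channel-first, no batch dimension
batch = list_data_collate([inputs])            # collate to shape (1, 3, 16, 15, 7), meta kept
result = sliding_window_inference(batch, roi_size=(4, 10, 7), sw_batch_size=10, predictor=lambda x: x + 1)
assert tuple(result.shape) == (1, 3, 16, 15, 7)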
diff --git a/tests/test_testtimeaugmentation.py b/tests/test_testtimeaugmentation.py
index 4b5ded3de1..75f3fdc181 100644
--- a/tests/test_testtimeaugmentation.py
+++ b/tests/test_testtimeaugmentation.py
@@ -14,10 +14,23 @@
 from typing import TYPE_CHECKING
 
 import numpy as np
+import torch
 
-from monai.data import create_test_image_2d
+from monai.data import CacheDataset, DataLoader, create_test_image_2d
 from monai.data.test_time_augmentation import TestTimeAugmentation
-from monai.transforms import AddChanneld, Compose, RandScaleIntensityd
+from monai.data.utils import pad_list_data_collate
+from monai.losses import DiceLoss
+from monai.networks.nets import UNet
+from monai.transforms import (
+    Activations,
+    AddChanneld,
+    AsDiscrete,
+    Compose,
+    CropForegroundd,
+    DivisiblePadd,
+    RandAffined,
+    RandScaleIntensityd,
+)
 from monai.transforms.croppad.dictionary import SpatialPadd
 from monai.transforms.spatial.dictionary import RandFlipd
 from monai.utils import optional_import, set_determinism
@@ -61,77 +74,76 @@ def tearDown(self) -> None:
         set_determinism(None)
 
     def test_test_time_augmentation(self):
-        pass
-        # input_size = (20, 40)  # test different input data shape to pad list collate
-        # keys = ["image", "label"]
-        # num_training_ims = 10
-
-        # train_data = self.get_data(num_training_ims, input_size)
-        # test_data = self.get_data(1, input_size)
-        # device = "cuda" if torch.cuda.is_available() else "cpu"
-
-        # transforms = Compose(
-        #     [
-        #         AddChanneld(keys),
-        #         RandAffined(
-        #             keys,
-        #             prob=1.0,
-        #             spatial_size=(30, 30),
-        #             rotate_range=(np.pi / 3, np.pi / 3),
-        #             translate_range=(3, 3),
-        #             scale_range=((0.8, 1), (0.8, 1)),
-        #             padding_mode="zeros",
-        #             mode=("bilinear", "nearest"),
-        #             as_tensor_output=False,
-        #         ),
-        #         CropForegroundd(keys, source_key="image"),
-        #         DivisiblePadd(keys, 4),
-        #     ]
-        # )
-
-        # train_ds = CacheDataset(train_data, transforms)
-        # # output might be different size, so pad so that they match
-        # train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
-
-        # model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
-        # loss_function = DiceLoss(sigmoid=True)
-        # optimizer = torch.optim.Adam(model.parameters(), 1e-3)
-
-        # num_epochs = 10
-        # for _ in trange(num_epochs):
-        #     epoch_loss = 0
-
-        #     for batch_data in train_loader:
-        #         inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
-        #         optimizer.zero_grad()
-        #         outputs = model(inputs)
-        #         loss = loss_function(outputs, labels)
-        #         loss.backward()
-        #         optimizer.step()
-        #         epoch_loss += loss.item()
-
-        #     epoch_loss /= len(train_loader)
-
-        # post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
-
-        # tt_aug = TestTimeAugmentation(
-        #     transform=transforms,
-        #     batch_size=5,
-        #     num_workers=0,
-        #     inferrer_fn=model,
-        #     device=device,
-        #     to_tensor=True,
-        #     output_device="cpu",
-        #     post_func=post_trans,
-        # )
-        # mode, mean, std, vvc = tt_aug(test_data)
-        # self.assertEqual(mode.shape, (1,) + input_size)
-        # self.assertEqual(mean.shape, (1,) + input_size)
-        # self.assertTrue(all(np.unique(mode) == (0, 1)))
-        # self.assertGreaterEqual(mean.min(), 0.0)
-        # self.assertLessEqual(mean.max(), 1.0)
-        # self.assertEqual(std.shape, (1,) + input_size)
-        # self.assertIsInstance(vvc, float)
+        input_size = (20, 40)  # test different input data shape to pad list collate
+        keys = ["image", "label"]
+        num_training_ims = 10
+
+        train_data = self.get_data(num_training_ims, input_size)
+        test_data = self.get_data(1, input_size)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        transforms = Compose(
+            [
+                AddChanneld(keys),
+                RandAffined(
+                    keys,
+                    prob=1.0,
+                    spatial_size=(30, 30),
+                    rotate_range=(np.pi / 3, np.pi / 3),
+                    translate_range=(3, 3),
+                    scale_range=((0.8, 1), (0.8, 1)),
+                    padding_mode="zeros",
+                    mode=("bilinear", "nearest"),
+                    as_tensor_output=False,
+                ),
+                CropForegroundd(keys, source_key="image"),
+                DivisiblePadd(keys, 4),
+            ]
+        )
+
+        train_ds = CacheDataset(train_data, transforms)
+        # output might be different size, so pad so that they match
+        train_loader = DataLoader(train_ds, batch_size=2, collate_fn=pad_list_data_collate)
+
+        model = UNet(2, 1, 1, channels=(6, 6), strides=(2, 2)).to(device)
+        loss_function = DiceLoss(sigmoid=True)
+        optimizer = torch.optim.Adam(model.parameters(), 1e-3)
+
+        num_epochs = 10
+        for _ in trange(num_epochs):
+            epoch_loss = 0
+
+            for batch_data in train_loader:
+                inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
+                optimizer.zero_grad()
+                outputs = model(inputs)
+                loss = loss_function(outputs, labels)
+                loss.backward()
+                optimizer.step()
+                epoch_loss += loss.item()
+
+            epoch_loss /= len(train_loader)
+
+        post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
+
+        tt_aug = TestTimeAugmentation(
+            transform=transforms,
+            batch_size=5,
+            num_workers=0,
+            inferrer_fn=model,
+            device=device,
+            to_tensor=True,
+            output_device="cpu",
+            post_func=post_trans,
+        )
+        mode, mean, std, vvc = tt_aug(test_data)
+        self.assertEqual(mode.shape, (1,) + input_size)
+        self.assertEqual(mean.shape, (1,) + input_size)
+        self.assertTrue(all(np.unique(mode) == (0, 1)))
+        self.assertGreaterEqual(mean.min(), 0.0)
+        self.assertLessEqual(mean.max(), 1.0)
+        self.assertEqual(std.shape, (1,) + input_size)
+        self.assertIsInstance(vvc, float)
 
     def test_warn_non_random(self):
         transforms = Compose([AddChanneld("im"), SpatialPadd("im", 1)])
@@ -164,20 +176,6 @@ def test_image_no_label(self):
         tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
         tta(self.get_data(1, (20, 20), include_label=False))
 
-    # # @unittest.skipUnless(has_nib, "Requires nibabel")
-    # def test_requires_meta_dict(self):
-    #     transforms = Compose(
-    #         [
-    #             AddChanneld("image"),
-    #             RandFlipd("image"),
-    #             ToMetaTensord("image"),
-    #             Spacingd("image", pixdim=1.1),
-    #             FromMetaTensord("image"),
-    #         ]
-    #     )
-    #     tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image")
-    #     tta(self.get_data(1, (20, 20), include_label=False))
-
 
 if __name__ == "__main__":
     unittest.main()
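For readers unfamiliar with the four values asserted at the end of the revived test: tt_aug stacks the inverted predictions from the random augmentations and reduces them voxel-wise, and vvc is the volume variation coefficient across augmentations. A conceptual sketch of that aggregation (the library's own implementation may differ in detail):

import torch

preds = torch.rand(5, 1, 20, 40).round()      # stand-in for 5 augmented-then-inverted binary predictions
mode = preds.mode(dim=0).values               # voxel-wise most frequent value
mean = preds.mean(dim=0)                      # voxel-wise average, in [0, 1]
std = preds.std(dim=0)                        # voxel-wise spread
vols = preds.flatten(start_dim=1).sum(dim=1)  # foreground volume of each prediction
vvc = float(vols.std() / vols.mean())         # volume variation coefficient, a single float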
nibabel") - # def test_requires_meta_dict(self): - # transforms = Compose( - # [ - # AddChanneld("image"), - # RandFlipd("image"), - # ToMetaTensord("image"), - # Spacingd("image", pixdim=1.1), - # FromMetaTensord("image"), - # ] - # ) - # tta = TestTimeAugmentation(transforms, batch_size=5, num_workers=0, inferrer_fn=lambda x: x, orig_key="image") - # tta(self.get_data(1, (20, 20), include_label=False)) - if __name__ == "__main__": unittest.main() diff --git a/tests/test_zoom.py b/tests/test_zoom.py index 1a7694072e..dee3565ba4 100644 --- a/tests/test_zoom.py +++ b/tests/test_zoom.py @@ -17,7 +17,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoom -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion VALID_CASES = [(1.5, "nearest"), (1.5, "nearest"), (0.8, "bilinear"), (0.8, "area")] @@ -29,7 +29,9 @@ class TestZoom(NumpyImageTestCase2D): def test_correct_results(self, zoom, mode): for p in TEST_NDARRAYS: zoom_fn = Zoom(zoom=zoom, mode=mode, keep_size=False) - zoomed = zoom_fn(p(self.imt[0])) + im = p(self.imt[0]) + zoomed = zoom_fn(im) + test_local_inversion(zoom_fn, zoomed, im) _order = 0 if mode.endswith("linear"): _order = 1 @@ -37,17 +39,21 @@ def test_correct_results(self, zoom, mode): for channel in self.imt[0]: expected.append(zoom_scipy(channel, zoom=zoom, mode="nearest", order=_order, prefilter=False)) expected = np.stack(expected).astype(np.float32) - assert_allclose(zoomed, p(expected), atol=1.0) + assert_allclose(zoomed, p(expected), atol=1.0, type_test=False) def test_keep_size(self): for p in TEST_NDARRAYS: zoom_fn = Zoom(zoom=[0.6, 0.6], keep_size=True, align_corners=True) - zoomed = zoom_fn(p(self.imt[0]), mode="bilinear") - assert_allclose(zoomed.shape, self.imt.shape[1:]) + im = p(self.imt[0]) + zoomed = zoom_fn(im, mode="bilinear") + assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False) + test_local_inversion(zoom_fn, zoomed, im) zoom_fn = Zoom(zoom=[1.3, 1.3], keep_size=True) - zoomed = zoom_fn(p(self.imt[0])) - assert_allclose(zoomed.shape, self.imt.shape[1:]) + im = p(self.imt[0]) + zoomed = zoom_fn(im) + assert_allclose(zoomed.shape, self.imt.shape[1:], type_test=False) + test_local_inversion(zoom_fn, zoomed, im) @parameterized.expand(INVALID_CASES) def test_invalid_inputs(self, zoom, mode, raises): diff --git a/tests/test_zoomd.py b/tests/test_zoomd.py index 87a5cec22b..231ed4c6e0 100644 --- a/tests/test_zoomd.py +++ b/tests/test_zoomd.py @@ -16,7 +16,7 @@ from scipy.ndimage import zoom as zoom_scipy from monai.transforms import Zoomd -from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose +from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose, test_local_inversion VALID_CASES = [(1.5, "nearest", False), (0.3, "bilinear", False), (0.8, "bilinear", False)] @@ -29,7 +29,9 @@ def test_correct_results(self, zoom, mode, keep_size): key = "img" zoom_fn = Zoomd(key, zoom=zoom, mode=mode, keep_size=keep_size) for p in TEST_NDARRAYS: - zoomed = zoom_fn({key: p(self.imt[0])}) + im = p(self.imt[0]) + zoomed = zoom_fn({key: im}) + test_local_inversion(zoom_fn, zoomed, {key: im}, key) _order = 0 if mode.endswith("linear"): _order = 1 @@ -38,7 +40,7 @@ def test_correct_results(self, zoom, mode, keep_size): ] expected = np.stack(expected).astype(np.float32) - assert_allclose(zoomed[key], p(expected), atol=1.0) + assert_allclose(zoomed[key], p(expected), atol=1.0, 
diff --git a/tests/testing_data/inference.json b/tests/testing_data/inference.json
index a6f85e43dd..46aa206d03 100644
--- a/tests/testing_data/inference.json
+++ b/tests/testing_data/inference.json
@@ -38,14 +38,6 @@
             "_target_": "EnsureChannelFirstd",
             "keys": "image"
         },
-        {
-            "_target_": "FromMetaTensord",
-            "keys": "image"
-        },
-        {
-            "_target_": "ToNumpyd",
-            "keys": "image"
-        },
         {
             "_target_": "ScaleIntensityd",
             "keys": "image"
@@ -54,10 +46,6 @@
             "_target_": "RandRotated",
             "_disabled_": true,
             "keys": "image"
-        },
-        {
-            "_target_": "EnsureTyped",
-            "keys": "image"
         }
     ]
 },
@@ -96,11 +84,6 @@
             "keys": "pred",
             "argmax": true
         },
-        {
-            "_target_": "ToMetaTensord",
-            "keys": "pred",
-            "meta_keys": "image"
-        },
         {
             "_target_": "SaveImaged",
             "keys": "pred",
diff --git a/tests/testing_data/inference.yaml b/tests/testing_data/inference.yaml
index c1b860a4fa..90f0bb35b9 100644
--- a/tests/testing_data/inference.yaml
+++ b/tests/testing_data/inference.yaml
@@ -29,17 +29,11 @@ preprocessing:
       keys: image
     - _target_: EnsureChannelFirstd
       keys: image
-    - _target_: FromMetaTensord
-      keys: image
-    - _target_: ToNumpyd
-      keys: image
     - _target_: ScaleIntensityd
       keys: image
     - _target_: RandRotated
       _disabled_: true
       keys: image
-    - _target_: EnsureTyped
-      keys: image
 dataset:
   _target_: need override
   data: "@_meta_#datalist"
@@ -67,9 +61,6 @@ postprocessing:
   - _target_: AsDiscreted
     keys: pred
     argmax: true
-  - _target_: ToMetaTensord
-    keys: pred
-    meta_keys: image
   - _target_: SaveImaged
     keys: pred
     output_dir: "@_meta_#output_dir"
diff --git a/tests/utils.py b/tests/utils.py
index a0114fb2df..6a690c56f1 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -708,12 +708,29 @@ def query_memory(n=2):
     return ",".join(f"{int(x)}" for x in ids)
 
 
+def test_local_inversion(invertible_xform, to_invert, im, dict_key=None):
+    """test that invertible_xform can bring to_invert back to im"""
+    im_item = im if dict_key is None else im[dict_key]
+    if not isinstance(im_item, MetaTensor):
+        return
+    im_inv = invertible_xform.inverse(to_invert)
+    if dict_key:
+        im_inv = im_inv[dict_key]
+        im = im[dict_key]
+    np.testing.assert_array_equal(im_inv.applied_operations, [])
+    assert_allclose(im_inv.shape, im.shape)
+    assert_allclose(im_inv.affine, im.affine, atol=1e-3, rtol=1e-3)
+
+
 TEST_TORCH_TENSORS: Tuple[Callable] = (torch.as_tensor,)  # type: ignore
 if torch.cuda.is_available():
     gpu_tensor: Callable = partial(torch.as_tensor, device="cuda")
     TEST_NDARRAYS = TEST_TORCH_TENSORS + (gpu_tensor,)  # type: ignore
 
-_metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": torch.eye(4) * 2})
+DEFAULT_TEST_AFFINE = torch.tensor(
+    [[2.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 0.0, 2.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
+)
+_metatensor_creator = partial(MetaTensor, meta={"a": "b", "affine": DEFAULT_TEST_AFFINE})
 TEST_NDARRAYS_NO_META_TENSOR: Tuple[Callable] = (np.array,) + TEST_TORCH_TENSORS  # type: ignore
 TEST_NDARRAYS: Tuple[Callable] = TEST_NDARRAYS_NO_META_TENSOR + (_metatensor_creator,)  # type: ignore
 TEST_TORCH_AND_META_TENSORS: Tuple[Callable] = TEST_TORCH_TENSORS + (_metatensor_creator,)  # type: ignore
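Taken together, the new helper gives every spatial test the same contract: invert the output and you must get the input's shape and affine back, with nothing left on the applied_operations stack. Typical standalone usage, mirroring what tests/test_zoom.py now does:

import torch
from monai.data import MetaTensor
from monai.transforms import Zoom
from tests.utils import test_local_inversion

im = MetaTensor(torch.rand(1, 16, 16), meta={"affine": torch.eye(4)})
zoom_fn = Zoom(zoom=0.8, keep_size=False)
zoomed = zoom_fn(im)
test_local_inversion(zoom_fn, zoomed, im)  # asserts shape/affine restored, applied_operations emptied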