diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py
index 2c29937aa4..8d4f08d282 100644
--- a/monai/apps/detection/transforms/dictionary.py
+++ b/monai/apps/detection/transforms/dictionary.py
@@ -15,7 +15,6 @@ Class names are ended with 'd' to denote dictionary-based transforms.
 """
 
 from copy import deepcopy
-from enum import Enum
 from typing import Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
@@ -37,15 +36,15 @@
 from monai.config import KeysCollection, SequenceStr
 from monai.config.type_definitions import NdarrayOrTensor
 from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
+from monai.data.meta_tensor import MetaTensor, get_track_meta
 from monai.data.utils import orientation_ras_lps
-from monai.transforms import Flip, RandFlip, RandRotate90d, RandZoom, Rotate90, SpatialCrop, SpatialPad, Zoom
+from monai.transforms import Flip, RandFlip, RandRotate90d, RandZoom, Rotate90, SpatialCrop, Zoom
 from monai.transforms.inverse import InvertibleTransform
 from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
 from monai.transforms.utils import generate_pos_neg_label_crop_centers, map_binary_to_indices
-from monai.utils import ImageMetaKey as Key
 from monai.utils import InterpolateMode, NumpyPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
 from monai.utils.enums import PostFix, TraceKeys
-from monai.utils.type_conversion import convert_data_type
+from monai.utils.type_conversion import convert_data_type, convert_to_tensor
 
 __all__ = [
     "ConvertBoxModed",
@@ -131,8 +130,8 @@ def __init__(
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
-            self.push_transform(d, key, extra_info={"src": self.converter.src_mode, "dst": self.converter.dst_mode})
             d[key] = self.converter(d[key])
+            self.push_transform(d, key, extra_info={"src": self.converter.src_mode, "dst": self.converter.dst_mode})
         return d
 
     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -186,8 +185,8 @@ def __init__(
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
         for key in self.key_iterator(d):
-            self.push_transform(d, key, extra_info={"mode": self.converter.mode})
             d[key] = self.converter(d[key])
+            self.push_transform(d, key, extra_info={"mode": self.converter.mode})
         return d
 
     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -242,23 +241,28 @@ def __init__(
                 "Please provide a single key for box_ref_image_keys.\
 All boxes of box_keys are attached to box_ref_image_keys."
             )
+        self.box_ref_image_keys = box_ref_image_keys
         self.image_meta_key = image_meta_key or f"{box_ref_image_keys}_{image_meta_key_postfix}"
         self.converter_to_image_coordinate = AffineBox()
         self.affine_lps_to_ras = affine_lps_to_ras
 
-    def extract_affine(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Tuple[NdarrayOrTensor, NdarrayOrTensor]:
+    def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> Tuple[NdarrayOrTensor, torch.Tensor]:
         d = dict(data)
         meta_key = self.image_meta_key
-        # extract affine matrix from meta_data
-        if meta_key not in d:
+        # extract affine matrix from metadata
+        if isinstance(d[self.box_ref_image_keys], MetaTensor):
+            meta_dict = d[self.box_ref_image_keys].meta  # type: ignore
+        elif meta_key in d:
+            meta_dict = d[meta_key]
+        else:
             raise ValueError(f"{meta_key} is not found. Please check whether it is the correct the image meta key.")
-        if "affine" not in d[meta_key]:
+        if "affine" not in meta_dict:
             raise ValueError(
                 f"'affine' is not found in {meta_key}. \
 Please check whether it is the correct the image meta key."
             )
-        affine: NdarrayOrTensor = d[meta_key]["affine"]  # type: ignore
+        affine: NdarrayOrTensor = meta_dict["affine"]  # type: ignore
         if self.affine_lps_to_ras:  # RAS affine
             affine = orientation_ras_lps(affine)
@@ -272,11 +276,11 @@ def extract_affine(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Tuple[Ndar
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
-        affine, inv_affine_t = self.extract_affine(data)
+        affine, inv_affine_t = self.extract_affine(data)  # type: ignore
         for key in self.key_iterator(d):
-            self.push_transform(d, key, extra_info={"affine": affine})
             d[key] = self.converter_to_image_coordinate(d[key], affine=inv_affine_t)
+            self.push_transform(d, key, extra_info={"affine": affine})
         return d
 
     def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
@@ -329,11 +333,11 @@ def __init__(
 
     def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
         d = dict(data)
-        affine, inv_affine_t = self.extract_affine(data)
+        affine, inv_affine_t = self.extract_affine(data)  # type: ignore
         for key in self.key_iterator(d):
-            self.push_transform(d, key, extra_info={"affine": inv_affine_t})
             d[key] = self.converter_to_world_coordinate(d[key], affine=affine)
+            self.push_transform(d, key, extra_info={"affine": inv_affine_t})
         return d
@@ -401,58 +405,32 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
             src_spatial_size = d[box_ref_image_key].shape[1:]
             dst_spatial_size = [int(round(z * ss)) for z, ss in zip(self.zoomer.zoom, src_spatial_size)]  # type: ignore
             self.zoomer.zoom = [ds / float(ss) for ss, ds in zip(src_spatial_size, dst_spatial_size)]
+            d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)(
+                d[box_key], src_spatial_size=src_spatial_size
+            )
             self.push_transform(
                 d,
                 box_key,
                 extra_info={"zoom": self.zoomer.zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
             )
-            d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)(
-                d[box_key], src_spatial_size=src_spatial_size
-            )
 
-        # zoom image, copied from monai.transforms.spatial.dictionary.Zoomd
+        # zoom image
         for key, mode, padding_mode, align_corners in zip(
             self.image_keys, self.mode, self.padding_mode, self.align_corners
         ):
-            self.push_transform(
-                d,
-                key,
-                extra_info={
-                    "mode": mode.value if isinstance(mode, Enum) else mode,
-                    "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode,
-                    "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
-                    "original_shape": d[key].shape[1:],
-                    "type": "image_key",
-                },
-            )
             d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
-        d = deepcopy(dict(data))
+        d = dict(data)
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform[TraceKeys.EXTRA_INFO].get("type", "image_key")
 
             # zoom image, copied from monai.transforms.spatial.dictionary.Zoomd
             if key_type == "image_key":
-                # Create inverse transform
-                zoom = np.array(self.zoomer.zoom)
-                inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size)
-                mode = transform[TraceKeys.EXTRA_INFO]["mode"]
-                padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
-                align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
-                # Apply inverse
-                d[key] = inverse_transform(
-                    d[key],
-                    mode=mode,
-                    padding_mode=padding_mode,
-                    align_corners=None if align_corners == TraceKeys.NONE else align_corners,
-                )
-                # Size might be out by 1 voxel so pad
-                orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
-                d[key] = SpatialPad(orig_shape, mode="edge")(d[key])
+                d[key] = self.zoomer.inverse(d[key])
 
             # zoom boxes
             if key_type == "box_key":
@@ -460,9 +438,8 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
                 src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
                 box_inverse_transform = ZoomBox(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size)
                 d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
-
-            # Remove the applied transform
-            self.pop_transform(d, key)
+                # Remove the applied transform
+                self.pop_transform(d, key)
 
         return d
@@ -561,35 +538,28 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
             dst_spatial_size = [int(round(z * ss)) for z, ss in zip(self.rand_zoom._zoom, src_spatial_size)]
             self.rand_zoom._zoom = [ds / float(ss) for ss, ds in zip(src_spatial_size, dst_spatial_size)]
 
+            d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)(
+                d[box_key], src_spatial_size=src_spatial_size
+            )
             self.push_transform(
                 d,
                 box_key,
                 extra_info={"zoom": self.rand_zoom._zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
             )
-            d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)(
-                d[box_key], src_spatial_size=src_spatial_size
-            )
 
         # zoom image, copied from monai.transforms.spatial.dictionary.RandZoomd
         for key, mode, padding_mode, align_corners in zip(
             self.image_keys, self.mode, self.padding_mode, self.align_corners
         ):
             if self._do_transform:
-                self.push_transform(
-                    d,
-                    key,
-                    extra_info={
-                        "zoom": self.rand_zoom._zoom,
-                        "mode": mode.value if isinstance(mode, Enum) else mode,
-                        "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode,
-                        "align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
-                        "original_shape": d[key].shape[1:],
-                        "type": "image_key",
-                    },
-                )
                 d[key] = self.rand_zoom(
                     d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners, randomize=False
                 )
+            else:
+                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
+            if get_track_meta():
+                xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
+                self.push_transform(d[key], extra_info=xform)
         return d
@@ -597,26 +567,15 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         d = dict(data)
 
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform[TraceKeys.EXTRA_INFO].get("type", "image_key")
             # Check if random transform was actually performed (based on `prob`)
             if transform[TraceKeys.DO_TRANSFORM]:
                 # zoom image, copied from monai.transforms.spatial.dictionary.Zoomd
                 if key_type == "image_key":
-                    zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"])
-                    mode = transform[TraceKeys.EXTRA_INFO]["mode"]
-                    padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"]
-                    align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
-                    inverse_transform = Zoom(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size)
-                    d[key] = inverse_transform(
-                        d[key],
-                        mode=mode,
-                        padding_mode=padding_mode,
-                        align_corners=None if align_corners == TraceKeys.NONE else align_corners,
-                    )
-                    # Size might be out by 1 voxel so pad
-                    orig_shape = transform[TraceKeys.EXTRA_INFO]["original_shape"]
-                    d[key] = SpatialPad(orig_shape, mode="edge")(d[key])
+                    xform = self.pop_transform(d[key])
+                    d[key].applied_operations.append(xform[TraceKeys.EXTRA_INFO])  # type: ignore
+                    d[key] = self.rand_zoom.inverse(d[key])
 
                 # zoom boxes
                 if key_type == "box_key":
@@ -625,9 +584,8 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
                     src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
                     box_inverse_transform = ZoomBox(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size)
                     d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
-
-            # Remove the applied transform
-            self.pop_transform(d, key)
+                    # Remove the applied transform
+                    self.pop_transform(d, key)
 
         return d
@@ -666,7 +624,6 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
 
         for key in self.image_keys:
            d[key] = self.flipper(d[key])
-            self.push_transform(d, key, extra_info={"type": "image_key"})
 
         for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
             spatial_size = d[box_ref_image_key].shape[1:]
@@ -678,20 +635,19 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         d = dict(data)
 
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform.get(TraceKeys.EXTRA_INFO, {}).get("type", "image_key")
 
             # flip image, copied from monai.transforms.spatial.dictionary.Flipd
             if key_type == "image_key":
-                d[key] = self.flipper(d[key])
+                d[key] = self.flipper.inverse(d[key])
 
             # flip boxes
             if key_type == "box_key":
                 spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
                 d[key] = self.box_flipper(d[key], spatial_size)
-
-            # Remove the applied transform
-            self.pop_transform(d, key)
+                # Remove the applied transform
+                self.pop_transform(d, key)
 
         return d
@@ -725,14 +681,13 @@ def __init__(
         RandomizableTransform.__init__(self, prob)
 
         self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))
-        self.flipper = RandFlip(prob=1.0, spatial_axis=spatial_axis)
+        self.flipper = Flip(spatial_axis=spatial_axis)
         self.box_flipper = FlipBox(spatial_axis=spatial_axis)
 
     def set_random_state(
         self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None
     ) -> "RandFlipBoxd":
         super().set_random_state(seed, state)
-        self.flipper.set_random_state(seed, state)
         return self
 
     def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
@@ -741,8 +696,12 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torc
 
         for key in self.image_keys:
             if self._do_transform:
-                d[key] = self.flipper(d[key], randomize=False)
-            self.push_transform(d, key, extra_info={"type": "image_key"})
+                d[key] = self.flipper(d[key])
+            else:
+                d[key] = convert_to_tensor(d[key], track_meta=get_track_meta())
+            if get_track_meta():
+                xform_info = self.pop_transform(d[key], check=False) if self._do_transform else {}
+                self.push_transform(d[key], extra_info=xform_info)
 
         for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
             spatial_size = d[box_ref_image_key].shape[1:]
@@ -755,13 +714,14 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
         d = dict(data)
 
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform[TraceKeys.EXTRA_INFO].get("type", "image_key")
             # Check if random transform was actually performed (based on `prob`)
             if transform[TraceKeys.DO_TRANSFORM]:
                 # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
                 if key_type == "image_key":
-                    d[key] = self.flipper(d[key], randomize=False)
+                    with self.flipper.trace_transform(False):
+                        d[key] = self.flipper(d[key])
 
                 # flip boxes
                 if key_type == "box_key":
@@ -769,7 +729,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
                     d[key] = self.box_flipper(d[key], spatial_size)
 
             # Remove the applied transform
-            self.pop_transform(d, key)
+            self.pop_transform(d, key, check=False)
 
         return d
@@ -1114,7 +1074,7 @@ def __init__(
         self.allow_smaller = allow_smaller
 
     def generate_fg_center_boxes_np(self, boxes: NdarrayOrTensor, image_size: Sequence[int]) -> np.ndarray:
-        # We don't require crop center to be whthin the boxes.
+        # We don't require crop center to be within the boxes.
         # As along as the cropped patch contains a box, it is considered as a foreground patch.
         # Positions within extended_boxes are crop centers for foreground patches
         spatial_dims = len(image_size)
@@ -1211,13 +1171,6 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> List[Dict[Hashable,
             for label_key, cropped_labels_i in zip(self.label_keys, cropped_labels):
                 results[i][label_key] = cropped_labels_i
 
-            # add `patch_index` to the meta data
-            for key, meta_key, meta_key_postfix in zip(self.image_keys, self.meta_keys, self.meta_key_postfix):
-                meta_key = meta_key or f"{key}_{meta_key_postfix}"
-                if meta_key not in results[i]:
-                    results[i][meta_key] = {}  # type: ignore
-                results[i][meta_key][Key.PATCH_INDEX] = i  # type: ignore
-
         return results
@@ -1269,25 +1222,23 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t
 
         for key in self.image_keys:
             d[key] = self.img_rotator(d[key])
-            self.push_transform(d, key, extra_info={"type": "image_key"})
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
         d = dict(data)
 
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform[TraceKeys.EXTRA_INFO].get("type", "image_key")
             num_times_to_rotate = 4 - self.img_rotator.k
             if key_type == "image_key":
-                inverse_transform = Rotate90(num_times_to_rotate, self.img_rotator.spatial_axes)
-                d[key] = inverse_transform(d[key])
+                d[key] = self.img_rotator.inverse(d[key])
             if key_type == "box_key":
                 spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
                 inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes)
                 d[key] = inverse_transform(d[key], spatial_size)
-            self.pop_transform(d, key)
+                self.pop_transform(d, key)
 
         return d
@@ -1354,8 +1305,14 @@ def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Mapping[Hashable, t
 
         for key in self.image_keys:
             if self._do_transform:
-                d[key] = img_rotator(d[key])
-                self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"})
+                d[key] = (
+                    img_rotator(d[key])
+                    if self._do_transform
+                    else convert_to_tensor(d[key], track_meta=get_track_meta())
+                )
+                if get_track_meta():
+                    xform = self.pop_transform(d[key], check=False) if self._do_transform else {}
+                    self.push_transform(d[key], extra_info=xform)
         return d
 
     def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
@@ -1364,21 +1321,21 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch
             return d
 
         for key in self.key_iterator(d):
-            transform = self.get_most_recent_transform(d, key)
-            key_type = transform[TraceKeys.EXTRA_INFO]["type"]
+            transform = self.get_most_recent_transform(d, key, check=False)
+            key_type = transform[TraceKeys.EXTRA_INFO].get("type", "image_key")
             # Check if random transform was actually performed (based on `prob`)
             if transform[TraceKeys.DO_TRANSFORM]:
-                num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"]
-                num_times_to_rotate = 4 - num_times_rotated
                 # flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
                 if key_type == "image_key":
-                    inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
-                    d[key] = inverse_transform(d[key])
+                    xform = self.pop_transform(d, key, check=False)
+                    d[key] = Rotate90().inverse_transform(d[key], xform[TraceKeys.EXTRA_INFO])
 
                 if key_type == "box_key":
+                    num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"]
+                    num_times_to_rotate = 4 - num_times_rotated
                     spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
                     inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes)
                     d[key] = inverse_transform(d[key], spatial_size)
-            self.pop_transform(d, key)
+                    self.pop_transform(d, key)
 
         return d
diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py
index d604a8d2e8..67e1938454 100644
--- a/monai/transforms/post/dictionary.py
+++ b/monai/transforms/post/dictionary.py
@@ -668,10 +668,11 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
                     inverted_data = self._totensor(inverted[orig_key])
                 else:
                     inverted_data = inverted[orig_key]
-                if config.USE_META_DICT and InvertibleTransform.trace_key(orig_key) in d:
-                    d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations
                 d[key] = post_func(inverted_data.to(device))
-            # save the inverted meta dict
+            # save the invertd applied_operations if it's in the source dict
+            if InvertibleTransform.trace_key(orig_key) in d:
+                d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations
+            # save the inverted meta dict if it's in the source dict
             if orig_meta_key in d:
                 meta_key = meta_key or f"{key}_{meta_key_postfix}"
                 d[meta_key] = inverted.get(orig_meta_key)
diff --git a/tests/test_box_transform.py b/tests/test_box_transform.py
index ad82b080f1..d597ec76bb 100644
--- a/tests/test_box_transform.py
+++ b/tests/test_box_transform.py
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 import unittest
+from copy import deepcopy
 
 import numpy as np
 import torch
@@ -17,12 +18,22 @@
 from monai.apps.detection.transforms.box_ops import convert_mask_to_box
 from monai.apps.detection.transforms.dictionary import (
+    AffineBoxToImageCoordinated,
+    AffineBoxToWorldCoordinated,
     BoxToMaskd,
+    ClipBoxToImaged,
     ConvertBoxModed,
+    FlipBoxd,
     MaskToBoxd,
     RandCropBoxByPosNegLabeld,
+    RandFlipBoxd,
+    RandRotateBox90d,
+    RandZoomBoxd,
+    RotateBox90d,
+    ZoomBoxd,
 )
-from monai.transforms import CastToTyped
+from monai.data.meta_tensor import MetaTensor
+from monai.transforms import CastToTyped, Invertd
 from tests.utils import TEST_NDARRAYS, assert_allclose
 
 TESTS_3D = []
@@ -140,157 +151,181 @@ def test_value_3d(
                 convert_result["boxes"], expected_convert_result, type_test=True, device_test=True, atol=1e-3
             )
 
-            # invert_transform_convert_mode = Invertd(
-            #     keys=["boxes"], transform=transform_convert_mode, orig_keys=["boxes"]
-            # )
-            # data_back = invert_transform_convert_mode(convert_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # # test ZoomBoxd
-            # transform_zoom = ZoomBoxd(
-            #     image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=False
-            # )
-            # zoom_result = transform_zoom(data)
-            # assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3)
-            # invert_transform_zoom = Invertd(
-            #     keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"]
-            # )
-            # data_back = invert_transform_zoom(zoom_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            # assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # transform_zoom = ZoomBoxd(
-            #     image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=True
-            # )
-            # zoom_result = transform_zoom(data)
-            # assert_allclose(
-            #     zoom_result["boxes"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3
-            # )
-            #
-            # # test RandZoomBoxd
-            # transform_zoom = RandZoomBoxd(
-            #     image_keys="image",
-            #     box_keys="boxes",
-            #     box_ref_image_keys="image",
-            #     prob=1.0,
-            #     min_zoom=(0.3,) * 3,
-            #     max_zoom=(3.0,) * 3,
-            #     keep_size=False,
-            # )
-            # zoom_result = transform_zoom(data)
-            # invert_transform_zoom = Invertd(
-            #     keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"]
-            # )
-            # data_back = invert_transform_zoom(zoom_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
-            # assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
-            # transform_affine = AffineBoxToImageCoordinated(box_keys="boxes", box_ref_image_keys="image")
-            # with self.assertRaises(Exception) as context:
-            #     transform_affine(data)
-            # self.assertTrue("Please check whether it is the correct the image meta key." in str(context.exception))
-            #
-            # data["image_meta_dict"] = {"affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))}
-            # affine_result = transform_affine(data)
-            # assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=0.01)
-            # invert_transform_affine = Invertd(keys=["boxes"], transform=transform_affine, orig_keys=["boxes"])
-            # data_back = invert_transform_affine(affine_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
-            # invert_transform_affine = AffineBoxToWorldCoordinated(box_keys="boxes", box_ref_image_keys="image")
-            # data_back = invert_transform_affine(affine_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
-            #
-            # # test FlipBoxd
-            # transform_flip = FlipBoxd(
-            #     image_keys="image", box_keys="boxes", box_ref_image_keys="image", spatial_axis=[0, 1, 2]
-            # )
-            # flip_result = transform_flip(data)
-            # assert_allclose(flip_result["boxes"], expected_flip_result, type_test=True, device_test=True, atol=1e-3)
-            # invert_transform_flip = Invertd(
-            #     keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"]
-            # )
-            # data_back = invert_transform_flip(flip_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            # assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # # test RandFlipBoxd
-            # for spatial_axis in [(0,), (1,), (2,), (0, 1), (1, 2)]:
-            #     transform_flip = RandFlipBoxd(
-            #         image_keys="image",
-            #         box_keys="boxes",
-            #         box_ref_image_keys="image",
-            #         prob=1.0,
-            #         spatial_axis=spatial_axis,
-            #     )
-            #     flip_result = transform_flip(data)
-            #     invert_transform_flip = Invertd(
-            #         keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"]
-            #     )
-            #     data_back = invert_transform_flip(flip_result)
-            #     assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            #     assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # # test ClipBoxToImaged
-            # transform_clip = ClipBoxToImaged(
-            #     box_keys="boxes", box_ref_image_keys="image", label_keys=["labels", "scores"], remove_empty=True
-            # )
-            # clip_result = transform_clip(data)
-            # assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)
-            # assert_allclose(clip_result["labels"], data["labels"][1:], type_test=True, device_test=True, atol=1e-3)
-            # assert_allclose(clip_result["scores"], data["scores"][1:], type_test=True, device_test=True, atol=1e-3)
-            #
-            # transform_clip = ClipBoxToImaged(
-            #     box_keys="boxes", box_ref_image_keys="image", label_keys=[], remove_empty=True
-            # )  # corner case when label_keys is empty
-            # clip_result = transform_clip(data)
-            # assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)
-            #
-            # # test RandCropBoxByPosNegLabeld
-            # transform_crop = RandCropBoxByPosNegLabeld(
-            #     image_keys="image", box_keys="boxes", label_keys=["labels", "scores"], spatial_size=2, num_samples=3
-            # )
-            # crop_result = transform_crop(data)
-            # assert len(crop_result) == 3
-            # for ll in range(3):
-            #     assert_allclose(
-            #         crop_result[ll]["boxes"].shape[0],
-            #         crop_result[ll]["labels"].shape[0],
-            #         type_test=True,
-            #         device_test=True,
-            #         atol=1e-3,
-            #     )
-            #     assert_allclose(
-            #         crop_result[ll]["boxes"].shape[0],
-            #         crop_result[ll]["scores"].shape[0],
-            #         type_test=True,
-            #         device_test=True,
-            #         atol=1e-3,
-            #     )
-            #
-            # # test RotateBox90d
-            # transform_rotate = RotateBox90d(
-            #     image_keys="image", box_keys="boxes", box_ref_image_keys="image", k=1, spatial_axes=[0, 1]
-            # )
-            # rotate_result = transform_rotate(data)
-            # assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3)
-            # invert_transform_rotate = Invertd(
-            #     keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
-            # )
-            # data_back = invert_transform_rotate(rotate_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            # assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
-            #
-            # transform_rotate = RandRotateBox90d(
-            #     image_keys="image", box_keys="boxes", box_ref_image_keys="image", prob=1.0, max_k=3, spatial_axes=[0, 1]
-            # )
-            # rotate_result = transform_rotate(data)
-            # invert_transform_rotate = Invertd(
-            #     keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
-            # )
-            # data_back = invert_transform_rotate(rotate_result)
-            # assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
-            # assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+            invert_transform_convert_mode = Invertd(
+                keys=["boxes"], transform=transform_convert_mode, orig_keys=["boxes"]
+            )
+            data_back = invert_transform_convert_mode(convert_result)
+            if "boxes_transforms" in data_back:  # if the transform is tracked in dict:
+                self.assertEqual(data_back["boxes_transforms"], [])  # it should be updated
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+
+            # test ZoomBoxd
+            transform_zoom = ZoomBoxd(
+                image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=False
+            )
+            zoom_result = transform_zoom(data)
+            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
+            assert_allclose(zoom_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=1e-3)
+            invert_transform_zoom = Invertd(
+                keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"]
+            )
+            data_back = invert_transform_zoom(zoom_result)
+            self.assertEqual(data_back["image"].applied_operations, [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+
+            transform_zoom = ZoomBoxd(
+                image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=True
+            )
+            zoom_result = transform_zoom(data)
+            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
+            assert_allclose(
+                zoom_result["boxes"], expected_zoom_keepsize_result, type_test=True, device_test=True, atol=1e-3
+            )
+
+            # test RandZoomBoxd
+            transform_zoom = RandZoomBoxd(
+                image_keys="image",
+                box_keys="boxes",
+                box_ref_image_keys="image",
+                prob=1.0,
+                min_zoom=(0.3,) * 3,
+                max_zoom=(3.0,) * 3,
+                keep_size=False,
+            )
+            zoom_result = transform_zoom(data)
+            self.assertEqual(len(zoom_result["image"].applied_operations), 1)
+            invert_transform_zoom = Invertd(
+                keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"]
+            )
+            data_back = invert_transform_zoom(zoom_result)
+            self.assertEqual(data_back["image"].applied_operations, [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
+            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+
+            # test AffineBoxToImageCoordinated, AffineBoxToWorldCoordinated
+            transform_affine = AffineBoxToImageCoordinated(box_keys="boxes", box_ref_image_keys="image")
+            if not isinstance(data["image"], MetaTensor):  # metadict should be undefined and it's an exception
+                with self.assertRaises(Exception) as context:
+                    transform_affine(deepcopy(data))
+                self.assertTrue("Please check whether it is the correct the image meta key." in str(context.exception))
+
+            data["image"] = MetaTensor(data["image"], meta={"affine": torch.diag(1.0 / torch.Tensor([0.5, 3, 1.5, 1]))})
+            affine_result = transform_affine(data)
+            if "boxes_transforms" in affine_result:
+                self.assertEqual(len(affine_result["boxes_transforms"]), 1)
+            assert_allclose(affine_result["boxes"], expected_zoom_result, type_test=True, device_test=True, atol=0.01)
+            invert_transform_affine = Invertd(keys=["boxes"], transform=transform_affine, orig_keys=["boxes"])
+            data_back = invert_transform_affine(affine_result)
+            if "boxes_transforms" in data_back:
+                self.assertEqual(data_back["boxes_transforms"], [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
+            invert_transform_affine = AffineBoxToWorldCoordinated(box_keys="boxes", box_ref_image_keys="image")
+            data_back = invert_transform_affine(affine_result)
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=0.01)
+
+            # test FlipBoxd
+            transform_flip = FlipBoxd(
+                image_keys="image", box_keys="boxes", box_ref_image_keys="image", spatial_axis=[0, 1, 2]
+            )
+            flip_result = transform_flip(data)
+            if "boxes_transforms" in flip_result:
+                self.assertEqual(len(flip_result["boxes_transforms"]), 1)
+            assert_allclose(flip_result["boxes"], expected_flip_result, type_test=True, device_test=True, atol=1e-3)
+            invert_transform_flip = Invertd(
+                keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"]
+            )
+            data_back = invert_transform_flip(flip_result)
+            if "boxes_transforms" in data_back:
+                self.assertEqual(data_back["boxes_transforms"], [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+
+            # test RandFlipBoxd
+            for spatial_axis in [(0,), (1,), (2,), (0, 1), (1, 2)]:
+                transform_flip = RandFlipBoxd(
+                    image_keys="image",
+                    box_keys="boxes",
+                    box_ref_image_keys="image",
+                    prob=1.0,
+                    spatial_axis=spatial_axis,
+                )
+                flip_result = transform_flip(data)
+                if "boxes_transforms" in flip_result:
+                    self.assertEqual(len(flip_result["boxes_transforms"]), 1)
+                invert_transform_flip = Invertd(
+                    keys=["image", "boxes"], transform=transform_flip, orig_keys=["image", "boxes"]
+                )
+                data_back = invert_transform_flip(flip_result)
+                if "boxes_transforms" in data_back:
+                    self.assertEqual(data_back["boxes_transforms"], [])
+                assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+                assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+
+            # test ClipBoxToImaged
+            transform_clip = ClipBoxToImaged(
+                box_keys="boxes", box_ref_image_keys="image", label_keys=["labels", "scores"], remove_empty=True
+            )
+            clip_result = transform_clip(data)
+            assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)
+            assert_allclose(clip_result["labels"], data["labels"][1:], type_test=True, device_test=True, atol=1e-3)
+            assert_allclose(clip_result["scores"], data["scores"][1:], type_test=True, device_test=True, atol=1e-3)
+
+            transform_clip = ClipBoxToImaged(
+                box_keys="boxes", box_ref_image_keys="image", label_keys=[], remove_empty=True
+            )  # corner case when label_keys is empty
+            clip_result = transform_clip(data)
+            assert_allclose(clip_result["boxes"], expected_clip_result, type_test=True, device_test=True, atol=1e-3)
+
+            # test RandCropBoxByPosNegLabeld
+            transform_crop = RandCropBoxByPosNegLabeld(
+                image_keys="image", box_keys="boxes", label_keys=["labels", "scores"], spatial_size=2, num_samples=3
+            )
+            crop_result = transform_crop(data)
+            assert len(crop_result) == 3
+            for ll in range(3):
+                assert_allclose(
+                    crop_result[ll]["boxes"].shape[0],
+                    crop_result[ll]["labels"].shape[0],
+                    type_test=True,
+                    device_test=True,
+                    atol=1e-3,
+                )
+                assert_allclose(
+                    crop_result[ll]["boxes"].shape[0],
+                    crop_result[ll]["scores"].shape[0],
+                    type_test=True,
+                    device_test=True,
+                    atol=1e-3,
+                )
+
+            # test RotateBox90d
+            transform_rotate = RotateBox90d(
+                image_keys="image", box_keys="boxes", box_ref_image_keys="image", k=1, spatial_axes=[0, 1]
+            )
+            rotate_result = transform_rotate(data)
+            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
+            assert_allclose(rotate_result["boxes"], expected_rotate_result, type_test=True, device_test=True, atol=1e-3)
+            invert_transform_rotate = Invertd(
+                keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
+            )
+            data_back = invert_transform_rotate(rotate_result)
+            self.assertEqual(data_back["image"].applied_operations, [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
+
+            transform_rotate = RandRotateBox90d(
+                image_keys="image", box_keys="boxes", box_ref_image_keys="image", prob=1.0, max_k=3, spatial_axes=[0, 1]
+            )
+            rotate_result = transform_rotate(data)
+            self.assertEqual(len(rotate_result["image"].applied_operations), 1)
+            invert_transform_rotate = Invertd(
+                keys=["image", "boxes"], transform=transform_rotate, orig_keys=["image", "boxes"]
+            )
+            data_back = invert_transform_rotate(rotate_result)
+            self.assertEqual(data_back["image"].applied_operations, [])
+            assert_allclose(data_back["boxes"], data["boxes"], type_test=False, device_test=False, atol=1e-3)
+            assert_allclose(data_back["image"], data["image"], type_test=False, device_test=False, atol=1e-3)
 
     def test_crop_shape(self):
         tt = RandCropBoxByPosNegLabeld(
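
A minimal round-trip sketch of the MetaTensor-based tracking that this patch switches the detection transforms to (the image shape and box values below are illustrative assumptions, not taken from the tests above):

    import torch

    from monai.apps.detection.transforms.dictionary import ZoomBoxd
    from monai.data.meta_tensor import MetaTensor
    from monai.transforms import Invertd

    # hypothetical inputs: a channel-first 3D image and one xyzxyz box
    data = {
        "image": MetaTensor(torch.rand(1, 32, 32, 32)),
        "boxes": torch.tensor([[2.0, 2.0, 2.0, 10.0, 10.0, 10.0]]),
    }

    transform_zoom = ZoomBoxd(
        image_keys="image", box_keys="boxes", box_ref_image_keys="image", zoom=[0.5, 3, 1.5], keep_size=False
    )
    zoom_result = transform_zoom(data)
    # the forward transform is now recorded on the MetaTensor itself, not in a separate meta dict
    assert len(zoom_result["image"].applied_operations) == 1

    invert = Invertd(keys=["image", "boxes"], transform=transform_zoom, orig_keys=["image", "boxes"])
    data_back = invert(zoom_result)
    # inversion pops the recorded operation, leaving an empty stack
    assert data_back["image"].applied_operations == []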