From 559c383b5e9c00fac2af521f63292c8262abf609 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 18:34:31 +0000 Subject: [PATCH 1/6] adds traceable API Signed-off-by: Wenqi Li --- monai/data/test_time_augmentation.py | 4 +- monai/transforms/__init__.py | 2 +- monai/transforms/compose.py | 8 +- monai/transforms/croppad/batch.py | 4 +- monai/transforms/croppad/dictionary.py | 10 +- monai/transforms/inverse.py | 148 +++++++++++++++---------- monai/transforms/post/dictionary.py | 5 +- monai/utils/__init__.py | 1 + monai/utils/enums.py | 9 +- tests/test_inverse.py | 4 +- tests/test_one_of.py | 9 +- tests/test_traceable_transform.py | 53 +++++++++ 12 files changed, 178 insertions(+), 79 deletions(-) create mode 100644 tests/test_traceable_transform.py diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 037ccc0652..52eab22b73 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -20,7 +20,7 @@ from monai.data.dataset import Dataset from monai.data.utils import list_data_collate, pad_list_data_collate from monai.transforms.compose import Compose -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode @@ -168,7 +168,7 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = self.orig_key + InverseKeys.KEY_SUFFIX + transform_key = TraceableTransform.transform_key(self.orig_key) # create inverter inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 879e8d9a1a..cf52cde264 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -197,7 +197,7 @@ ThresholdIntensityD, ThresholdIntensityDict, ) -from .inverse import InvertibleTransform +from .inverse import InvertibleTransform, TraceableTransform from .inverse_batch_transform import BatchInverseTransform, Decollated from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 766005ffff..9a75121a6b 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -28,7 +28,7 @@ apply_transform, ) from monai.utils import MAX_SEED, ensure_tuple, get_seed -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys __all__ = ["Compose", "OneOf"] @@ -237,7 +237,7 @@ def __call__(self, data): # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, Mapping): for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: + if self.transform_key(key) in data: self.push_transform(data, key, extra_info={"index": index}) return data @@ -250,9 +250,9 @@ def inverse(self, data): # loop until we get an index and then break (since they'll all be the same) index = None for key in data.keys(): - if key + InverseKeys.KEY_SUFFIX in data: + if self.transform_key(key) in data: # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"] + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform self.pop_transform(data, key) if index is None: diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 86e32d1a1b..e2a566c27a 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -21,7 +21,7 @@ from monai.data.utils import list_data_collate from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.utils.enums import InverseKeys, Method, NumpyPadMode __all__ = ["PadListDataCollate"] @@ -115,7 +115,7 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: d = deepcopy(data) for key in d: - transform_key = str(key) + InverseKeys.KEY_SUFFIX + transform_key = TraceableTransform.transform_key(key) if transform_key in d: transform = d[transform_key][-1] if not isinstance(transform, Dict): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 58e40a3e3b..0e82d1dbea 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -39,7 +39,7 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( allow_missing_keys_mode, @@ -776,8 +776,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore - cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) # type: ignore + cropped[TraceableTransform.transform_key(key)][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[TraceableTransform.transform_key(key)][-1][InverseKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -792,8 +792,8 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse for key in self.key_iterator(d): - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper) + d[TraceableTransform.transform_key(key)][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[TraceableTransform.transform_key(key)][-1][InverseKeys.ID] = id(self.cropper) context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext with context_manager(self.cropper): return self.cropper.inverse(d) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 57a443241b..604dfd558a 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -8,18 +8,80 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from typing import Hashable, Optional, Tuple +import os +from typing import Hashable, Mapping, Optional, Tuple import torch -from monai.transforms.transform import RandomizableTransform, Transform -from monai.utils.enums import InverseKeys +from monai.transforms.transform import Transform +from monai.utils.enums import TraceKeys + +__all__ = ["TraceableTransform", "InvertibleTransform"] + + +class TraceableTransform(Transform): + """ + Maintains a stack of applied transforms. The stack is inserted as pairs of + `transform_key: list of transforms` to each data dictionary. + + The ``__call__`` method of this transform class must be implemented so + that the transformation information for each key is stored when + ``__call__`` is called. If the transforms were applied to keys "image" and + "label", there will be two extra keys in the dictionary: "image_transforms" + and "label_transforms" (based on `TraceKeys.KEY_SUFFIX`). Each list + contains a list of the transforms applied to that key. -__all__ = ["InvertibleTransform"] + The information in ``data[key_transform]`` will be compatible with the + default collate since it only stores strings, numbers and arrays. + + `tracing` could be enabled by `self.set_tracing` or setting + `MONAI_TRACE_TRANSFORM` when initializing the class. + """ + tracing = False if os.environ.get("MONAI_TRACE_TRANSFORM", "1") == "0" else True -class InvertibleTransform(Transform): + def set_tracing(self, tracing: bool) -> None: + """Set whether to trace transforms.""" + self.tracing = tracing + + @staticmethod + def transform_key(key: Hashable = None): + """The key to store the stack of applied transforms.""" + if key is None: + return TraceKeys.KEY_SUFFIX + return str(key) + TraceKeys.KEY_SUFFIX + + def push_transform( + self, data: Mapping, key: Hashable = None, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None + ) -> None: + """PUsh to a stack of applied transforms for that key.""" + if not self.tracing: + return + info = {TraceKeys.CLASS_NAME: self.__class__.__name__, TraceKeys.ID: id(self)} + if orig_size is not None: + info[TraceKeys.ORIG_SIZE] = orig_size + elif key in data and hasattr(data[key], "shape"): + info[TraceKeys.ORIG_SIZE] = data[key].shape[1:] + if extra_info is not None: + info[TraceKeys.EXTRA_INFO] = extra_info + # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) + if hasattr(self, "_do_transform"): # RandomizableTransform + info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore + # If this is the first, create list + if self.transform_key(key) not in data: + if not isinstance(data, dict): + data = dict(data) + data[self.transform_key(key)] = [] + data[self.transform_key(key)].append(info) + + def pop_transform(self, data: Mapping, key: Hashable = None): + """Remove the most recent applied transform.""" + if not self.tracing: + return + return data.get(self.transform_key(key), []).pop() + + +class InvertibleTransform(TraceableTransform): """Classes for invertible transforms. This class exists so that an ``invert`` method can be implemented. This allows, for @@ -27,28 +89,21 @@ class InvertibleTransform(Transform): and after be returned to their original size before saving to file for comparison in an external viewer. - When the ``__call__`` method is called, the transformation information for each key is - stored. If the transforms were applied to keys "image" and "label", there will be two - extra keys in the dictionary: "image_transforms" and "label_transforms". Each list - contains a list of the transforms applied to that key. When the ``inverse`` method is - called, the inverse is called on each key individually, which allows for different - parameters being passed to each label (e.g., different interpolation for image and - label). + When the ``inverse`` method is called: - When the ``inverse`` method is called, the inverse transforms are applied in a last- - in-first-out order. As the inverse is applied, its entry is removed from the list - detailing the applied transformations. That is to say that during the forward pass, - the list of applied transforms grows, and then during the inverse it shrinks back - down to an empty list. + - the inverse is called on each key individually, which allows for + different parameters being passed to each label (e.g., different + interpolation for image and label). - The information in ``data[key_transform]`` will be compatible with the default collate - since it only stores strings, numbers and arrays. + - the inverse transforms are applied in a last- in-first-out order. As + the inverse is applied, its entry is removed from the list detailing + the applied transformations. That is to say that during the forward + pass, the list of applied transforms grows, and then during the + inverse it shrinks back down to an empty list. We currently check that the ``id()`` of the transform is the same in the forward and inverse directions. This is a useful check to ensure that the inverses are being - processed in the correct order. However, this may cause issues if the ``id()`` of the - object changes (such as multiprocessing on Windows). If you feel this issue affects - you, please raise a GitHub issue. + processed in the correct order. Note to developers: When converting a transform to an invertible transform, you need to: @@ -63,48 +118,29 @@ class InvertibleTransform(Transform): """ - def push_transform( - self, data: dict, key: Hashable, extra_info: Optional[dict] = None, orig_size: Optional[Tuple] = None - ) -> None: - """Append to list of applied transforms for that key.""" - key_transform = str(key) + InverseKeys.KEY_SUFFIX - info = {InverseKeys.CLASS_NAME: self.__class__.__name__, InverseKeys.ID: id(self)} - if orig_size is not None: - info[InverseKeys.ORIG_SIZE] = orig_size - elif hasattr(data[key], "shape"): - info[InverseKeys.ORIG_SIZE] = data[key].shape[1:] - if extra_info is not None: - info[InverseKeys.EXTRA_INFO] = extra_info - # If class is randomizable transform, store whether the transform was actually performed (based on `prob`) - if isinstance(self, RandomizableTransform): - info[InverseKeys.DO_TRANSFORM] = self._do_transform - # If this is the first, create list - if key_transform not in data: - data[key_transform] = [] - data[key_transform].append(info) - - def check_transforms_match(self, transform: dict) -> None: + def check_transforms_match(self, transform: Mapping) -> None: """Check transforms are of same instance.""" - if transform[InverseKeys.ID] == id(self): + xform_name = transform.get(TraceKeys.CLASS_NAME, "") + xform_id = transform.get(TraceKeys.ID, "") + if xform_id == id(self): return # basic check if multiprocessing uses 'spawn' (objects get recreated so don't have same ID) - if ( - torch.multiprocessing.get_start_method() in ("spawn", None) - and transform[InverseKeys.CLASS_NAME] == self.__class__.__name__ - ): + if torch.multiprocessing.get_start_method() in ("spawn", None) and xform_name == self.__class__.__name__: return - raise RuntimeError("Should inverse most recently applied invertible transform first") + raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") + + def get_most_recent_transform(self, data: Mapping, key: Hashable = None): + """alias of self.peek_transform""" + return self.peek_transform(data, key) - def get_most_recent_transform(self, data: dict, key: Hashable) -> dict: + def peek_transform(self, data: Mapping, key: Hashable = None): """Get most recent transform.""" - transform = dict(data[str(key) + InverseKeys.KEY_SUFFIX][-1]) + if not self.tracing: + raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") + transform = data[self.transform_key(key)][-1] self.check_transforms_match(transform) return transform - def pop_transform(self, data: dict, key: Hashable) -> None: - """Remove most recent transform.""" - data[str(key) + InverseKeys.KEY_SUFFIX].pop() - def inverse(self, data: dict) -> dict: """ Inverse of ``__call__``. diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 596b4b3a21..e96f1599c9 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -24,7 +24,7 @@ from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.data.csv_saver import CSVSaver -from monai.transforms.inverse import InvertibleTransform +from monai.transforms.inverse import InvertibleTransform, TraceableTransform from monai.transforms.post.array import ( Activations, AsDiscrete, @@ -40,7 +40,6 @@ from monai.transforms.utility.array import ToTensor from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys __all__ = [ "ActivationsD", @@ -600,7 +599,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.device, self.post_func, ): - transform_key = f"{orig_key}{InverseKeys.KEY_SUFFIX}" + transform_key = TraceableTransform.transform_key(orig_key) if transform_key not in d: warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") continue diff --git a/monai/utils/__init__.py b/monai/utils/__init__.py index 42eba2e67f..3fa3b2d5f6 100644 --- a/monai/utils/__init__.py +++ b/monai/utils/__init__.py @@ -30,6 +30,7 @@ NumpyPadMode, PytorchPadMode, SkipMode, + TraceKeys, TransformBackends, UpsampleMode, Weight, diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 51eb5da5d6..363dc16ac6 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -11,6 +11,8 @@ from enum import Enum +from monai.utils.aliases import alias + __all__ = [ "NumpyPadMode", "GridSampleMode", @@ -26,6 +28,7 @@ "ChannelMatching", "SkipMode", "Method", + "TraceKeys", "InverseKeys", "CommonKeys", "ForwardMode", @@ -208,7 +211,8 @@ class ForwardMode(Enum): EVAL = "eval" -class InverseKeys: +@alias("InverseKeys") +class TraceKeys: """Extra meta data keys used for inverse transforms.""" CLASS_NAME = "class" @@ -244,3 +248,6 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" + + +InverseKeys = TraceKeys diff --git a/tests/test_inverse.py b/tests/test_inverse.py index d662441494..fc91795dce 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -59,13 +59,13 @@ Spacingd, SpatialCropd, SpatialPadd, + TraceableTransform, Transposed, Zoomd, allow_missing_keys_mode, convert_inverse_interp_mode, ) from monai.utils import first, get_seed, optional_import, set_determinism -from monai.utils.enums import InverseKeys from tests.utils import make_nifti_image, make_rand_affine if TYPE_CHECKING: @@ -466,7 +466,7 @@ def test_inverse_inferred_seg(self, extra_transform): labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = "label" + InverseKeys.KEY_SUFFIX + label_transform_key = TraceableTransform.transform_key("label") segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 9fe9f193a3..6a68d4f91c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -14,7 +14,7 @@ from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, Transform +from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform from monai.utils.enums import InverseKeys @@ -148,7 +148,7 @@ def test_inverse(self, transform, should_be_ok): return for k in KEYS: - t = fwd_data[k + InverseKeys.KEY_SUFFIX][-1] + t = fwd_data[TraceableTransform.transform_key(k)][-1] # make sure the OneOf index was stored self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) # make sure index exists and is in bounds @@ -159,7 +159,10 @@ def test_inverse(self, transform, should_be_ok): for k in KEYS: # check transform was removed - self.assertTrue(len(fwd_inv_data[k + InverseKeys.KEY_SUFFIX]) < len(fwd_data[k + InverseKeys.KEY_SUFFIX])) + self.assertTrue( + len(fwd_inv_data[TraceableTransform.transform_key(k)]) + < len(fwd_data[TraceableTransform.transform_key(k)]) + ) # check data is same as original (and different from forward) self.assertEqual(fwd_inv_data[k], data[k]) self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py new file mode 100644 index 0000000000..0f1640f8be --- /dev/null +++ b/tests/test_traceable_transform.py @@ -0,0 +1,53 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from monai.transforms.inverse import TraceableTransform + + +class _TraceTest(TraceableTransform): + def __call__(self, data): + self.push_transform(data) + return data + + def pop(self, data): + self.pop_transform(data) + return data + + +class TestTraceable(unittest.TestCase): + def test_default(self): + expected_key = "_transforms" + a = _TraceTest() + self.assertEqual(a.transform_key(), expected_key) + + data = {"image": "test"} + data = a(data) # adds to the stack + self.assertTrue(isinstance(data[expected_key], list)) + self.assertEqual(data[expected_key][0]["class"], "_TraceTest") + + data = a(data) # adds to the stack + self.assertEqual(len(data[expected_key]), 2) + self.assertEqual(data[expected_key][-1]["class"], "_TraceTest") + + with self.assertRaises(IndexError): + a.pop({"test": "test"}) # no stack in the data + data = a.pop(data) + data = a.pop(data) + self.assertEqual(data[expected_key], []) + + with self.assertRaises(IndexError): # no more items + a.pop(data) + + +if __name__ == "__main__": + unittest.main() From d2c65f33dd7c24b430aa79de2ae84b8f5770581b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 18:35:05 +0000 Subject: [PATCH 2/6] drop peek Signed-off-by: Wenqi Li --- monai/transforms/inverse.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 604dfd558a..46a485ffd0 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -130,10 +130,6 @@ def check_transforms_match(self, transform: Mapping) -> None: raise RuntimeError(f"Error inverting the most recently applied invertible transform {xform_name} {xform_id}.") def get_most_recent_transform(self, data: Mapping, key: Hashable = None): - """alias of self.peek_transform""" - return self.peek_transform(data, key) - - def peek_transform(self, data: Mapping, key: Hashable = None): """Get most recent transform.""" if not self.tracing: raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") From 2e4a325e45c38479a17af123f22b77aeca23ddb5 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 18:38:02 +0000 Subject: [PATCH 3/6] deprecate inversekeys Signed-off-by: Wenqi Li --- monai/utils/enums.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 363dc16ac6..9d643c3ebc 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -11,7 +11,7 @@ from enum import Enum -from monai.utils.aliases import alias +from monai.utils.deprecate_utils import deprecated __all__ = [ "NumpyPadMode", @@ -211,9 +211,8 @@ class ForwardMode(Enum): EVAL = "eval" -@alias("InverseKeys") class TraceKeys: - """Extra meta data keys used for inverse transforms.""" + """Extra meta data keys used for traceable transforms.""" CLASS_NAME = "class" ID = "id" @@ -224,6 +223,26 @@ class TraceKeys: NONE = "none" +@deprecated(since="0.7.0", msg_suffix="use monai.utils.TraceKeys instead.") +class InverseKeys: + """ + Extra meta data keys used for inverse transforms. + + .. deprecated:: 0.7.0 + Use :class:`monai.utils.TraceKeys` instead. + + """ + + CLASS_NAME = "class" + ID = "id" + ORIG_SIZE = "orig_size" + EXTRA_INFO = "extra_info" + DO_TRANSFORM = "do_transforms" + KEY_SUFFIX = "_transforms" + NONE = "none" + + + class CommonKeys: """ A set of common keys for dictionary based supervised training process. @@ -249,5 +268,3 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" - -InverseKeys = TraceKeys From e098676efd3a8a36ca1133c09b94247bb27b594f Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 18:41:12 +0000 Subject: [PATCH 4/6] inversekeys -> tracekeys Signed-off-by: Wenqi Li --- monai/data/test_time_augmentation.py | 4 +- monai/transforms/croppad/batch.py | 4 +- monai/transforms/croppad/dictionary.py | 46 +++++----- monai/transforms/spatial/dictionary.py | 118 ++++++++++++------------- monai/transforms/utility/dictionary.py | 6 +- monai/transforms/utils.py | 22 ++--- monai/utils/enums.py | 6 +- tests/test_decollate.py | 4 +- tests/test_one_of.py | 6 +- 9 files changed, 107 insertions(+), 109 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index 52eab22b73..f4ee2a46bf 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -24,7 +24,7 @@ from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode -from monai.utils.enums import CommonKeys, InverseKeys +from monai.utils.enums import CommonKeys, TraceKeys from monai.utils.module import optional_import if TYPE_CHECKING: @@ -188,7 +188,7 @@ def __call__( transform_info = batch_data.get(transform_key, None) if transform_info is None: # no invertible transforms, adding dummy info for identity invertible - transform_info = [[InverseKeys.NONE] for _ in range(self.batch_size)] + transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)] if self.nearest_interp: transform_info = convert_inverse_interp_mode( trans_info=deepcopy(transform_info), mode="nearest", align_corners=None diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index e2a566c27a..7d15115ab6 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -22,7 +22,7 @@ from monai.data.utils import list_data_collate from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad from monai.transforms.inverse import InvertibleTransform, TraceableTransform -from monai.utils.enums import InverseKeys, Method, NumpyPadMode +from monai.utils.enums import Method, NumpyPadMode, TraceKeys __all__ = ["PadListDataCollate"] @@ -120,7 +120,7 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: transform = d[transform_key][-1] if not isinstance(transform, Dict): continue - if transform.get(InverseKeys.CLASS_NAME) == PadListDataCollate.__name__: + if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__: d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key]) # fallback to image size # remove transform d[transform_key].pop() diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 0e82d1dbea..6ee13ff67a 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -52,7 +52,7 @@ ) from monai.utils import ImageMetaKey as Key from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys __all__ = [ "PadModeSequence", @@ -163,7 +163,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] if self.padder.method == Method.SYMMETRIC: current_size = d[key].shape[1:] roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)] @@ -239,7 +239,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) roi_start = np.array(self.padder.spatial_border) # Need to convert single value to [min1,min2,...] if roi_start.size == 1: @@ -247,7 +247,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd # need to convert [min1,max1,min2,...] to [min1,min2,...] elif roi_start.size == 2 * orig_size.size: roi_start = roi_start[::2] - roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start + roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end) # Apply inverse transform @@ -315,7 +315,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) roi_start = np.floor((current_size - orig_size) / 2) roi_end = orig_size + roi_start @@ -384,7 +384,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)]) @@ -440,7 +440,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -497,7 +497,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) pad_to_start = np.floor((orig_size - current_size) / 2).astype(int) # in each direction, if original size is even and current size is odd, += 1 @@ -594,12 +594,12 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] random_center = self.random_center pad_to_start = np.empty((len(orig_size)), dtype=np.int32) pad_to_end = np.empty((len(orig_size)), dtype=np.int32) if random_center: - for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]): + for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]): pad_to_start[i] = _slice[0] pad_to_end[i] = orig_size[i] - _slice[1] else: @@ -776,8 +776,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[TraceableTransform.transform_key(key)][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore - cropped[TraceableTransform.transform_key(key)][-1][InverseKeys.ID] = id(self) # type: ignore + cropped[TraceableTransform.transform_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[TraceableTransform.transform_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -792,8 +792,8 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse for key in self.key_iterator(d): - d[TraceableTransform.transform_key(key)][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[TraceableTransform.transform_key(key)][-1][InverseKeys.ID] = id(self.cropper) + d[TraceableTransform.transform_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[TraceableTransform.transform_key(key)][-1][TraceKeys.ID] = id(self.cropper) context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext with context_manager(self.cropper): return self.cropper.inverse(d) @@ -877,9 +877,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) cur_size = np.asarray(d[key].shape[1:]) - extra_info = transform[InverseKeys.EXTRA_INFO] + extra_info = transform[TraceKeys.EXTRA_INFO] box_start = np.asarray(extra_info["box_start"]) box_end = np.asarray(extra_info["box_end"]) # first crop the padding part @@ -999,9 +999,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size) # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1179,9 +1179,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1364,9 +1364,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE]) current_size = np.asarray(d[key].shape[1:]) - center = transform[InverseKeys.EXTRA_INFO]["center"] + center = transform[TraceKeys.EXTRA_INFO]["center"] cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore # get required pad to start and end pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)]) @@ -1432,7 +1432,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - orig_size = np.array(transform[InverseKeys.ORIG_SIZE]) + orig_size = np.array(transform[TraceKeys.ORIG_SIZE]) current_size = np.array(d[key].shape[1:]) # Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding. # Instead, we first pad any smaller dimensions, and then we crop any larger dimensions. diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8bfdd6fd52..6a1e2041f7 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -63,7 +63,7 @@ first, ) from monai.utils.deprecate_utils import deprecated_arg -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys from monai.utils.module import optional_import from monai.utils.type_conversion import convert_data_type, convert_to_dst_type @@ -257,7 +257,7 @@ def __call__( "old_affine": old_affine, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, orig_size=original_spatial_shape, ) @@ -275,12 +275,12 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd + "Please raise a github issue if you need this feature" ) # Create inverse transform - meta_data = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] - old_affine = np.array(transform[InverseKeys.EXTRA_INFO]["old_affine"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] - orig_size = transform[InverseKeys.ORIG_SIZE] + meta_data = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] + old_affine = np.array(transform[TraceKeys.EXTRA_INFO]["old_affine"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] orig_pixdim = np.sqrt(np.sum(np.square(old_affine), 0))[:-1] inverse_transform = Spacing(orig_pixdim, diagonal=self.spacing_transform.diagonal) # Apply inverse @@ -289,7 +289,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd affine=meta_data["affine"], # type: ignore mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == InverseKeys.NONE else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, dtype=dtype, output_spatial_shape=orig_size, ) @@ -381,8 +381,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - meta_data: Dict = d[transform[InverseKeys.EXTRA_INFO]["meta_key"]] # type: ignore - orig_affine = transform[InverseKeys.EXTRA_INFO]["old_affine"] + meta_data: Dict = d[transform[TraceKeys.EXTRA_INFO]["meta_key"]] # type: ignore + orig_affine = transform[TraceKeys.EXTRA_INFO]["old_affine"] orig_axcodes = nib.orientations.aff2axcodes(orig_affine) inverse_transform = Orientation( axcodes=orig_axcodes, as_closest_canonical=False, labels=self.ornt_transform.labels @@ -499,9 +499,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - num_times_rotated = transform[InverseKeys.EXTRA_INFO]["rand_k"] + num_times_rotated = transform[TraceKeys.EXTRA_INFO]["rand_k"] num_times_to_rotate = 4 - num_times_rotated inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes) # Apply inverse @@ -563,7 +563,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N key, extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) d[key] = self.resizer(d[key], mode=mode, align_corners=align_corners) @@ -573,14 +573,14 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd d = deepcopy(dict(data)) for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + orig_size = transform[TraceKeys.ORIG_SIZE] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Create inverse transform inverse_transform = Resize( spatial_size=orig_size, mode=mode, - align_corners=None if align_corners == InverseKeys.NONE else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Apply inverse transform d[key] = inverse_transform(d[key]) @@ -690,11 +690,11 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) - orig_size = transform[InverseKeys.ORIG_SIZE] + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -853,12 +853,12 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # if transform was not performed and spatial size is None, nothing to do. - if transform[InverseKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: - orig_size = transform[InverseKeys.ORIG_SIZE] + if transform[TraceKeys.DO_TRANSFORM] or self.rand_affine.spatial_size is not None: + orig_size = transform[TraceKeys.ORIG_SIZE] # Create inverse transform - fwd_affine = transform[InverseKeys.EXTRA_INFO]["affine"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] + fwd_affine = transform[TraceKeys.EXTRA_INFO]["affine"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] inv_affine = np.linalg.inv(fwd_affine) affine_grid = AffineGrid(affine=inv_affine) @@ -1221,7 +1221,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Inverse is same as forward d[key] = self.flipper(d[key], randomize=False) # Remove the applied transform @@ -1274,8 +1274,8 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: - flipper = Flip(spatial_axis=transform[InverseKeys.EXTRA_INFO]["axis"]) + if transform[TraceKeys.DO_TRANSFORM]: + flipper = Flip(spatial_axis=transform[TraceKeys.EXTRA_INFO]["axis"]) # Inverse is same as forward d[key] = flipper(d[key]) # Remove the applied transform @@ -1350,7 +1350,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d @@ -1360,17 +1360,17 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == InverseKeys.NONE else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) img_t: torch.Tensor @@ -1378,7 +1378,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd transform_t: torch.Tensor transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore - out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE]).squeeze(0) + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) d[key] = out # Remove the applied transform @@ -1482,7 +1482,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N "rot_mat": rot_mat, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d @@ -1492,19 +1492,19 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key, dtype in self.key_iterator(d, self.dtype): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - fwd_rot_mat = transform[InverseKeys.EXTRA_INFO]["rot_mat"] - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + fwd_rot_mat = transform[TraceKeys.EXTRA_INFO]["rot_mat"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inv_rot_mat = np.linalg.inv(fwd_rot_mat) xform = AffineTransform( normalized=False, mode=mode, padding_mode=padding_mode, - align_corners=False if align_corners == InverseKeys.NONE else align_corners, + align_corners=False if align_corners == TraceKeys.NONE else align_corners, reverse_indexing=True, ) img_t: torch.Tensor @@ -1512,7 +1512,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd transform_t: torch.Tensor transform_t, *_ = convert_to_dst_type(inv_rot_mat, img_t) # type: ignore output: torch.Tensor - out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[InverseKeys.ORIG_SIZE]).squeeze(0) + out = xform(img_t.unsqueeze(0), transform_t, spatial_size=transform[TraceKeys.ORIG_SIZE]).squeeze(0) out, *_ = convert_to_dst_type(out, dst=d[key], dtype=out.dtype) d[key] = out # Remove the applied transform @@ -1582,7 +1582,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N extra_info={ "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) d[key] = self.zoomer(d[key], mode=mode, padding_mode=padding_mode, align_corners=align_corners) @@ -1595,18 +1595,18 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd # Create inverse transform zoom = np.array(self.zoomer.zoom) inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == InverseKeys.NONE else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) @@ -1701,7 +1701,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N "zoom": self.rand_zoom._zoom, "mode": mode.value if isinstance(mode, Enum) else mode, "padding_mode": padding_mode.value if isinstance(padding_mode, Enum) else padding_mode, - "align_corners": align_corners if align_corners is not None else InverseKeys.NONE, + "align_corners": align_corners if align_corners is not None else TraceKeys.NONE, }, ) return d @@ -1711,22 +1711,22 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Check if random transform was actually performed (based on `prob`) - if transform[InverseKeys.DO_TRANSFORM]: + if transform[TraceKeys.DO_TRANSFORM]: # Create inverse transform - zoom = np.array(transform[InverseKeys.EXTRA_INFO]["zoom"]) - mode = transform[InverseKeys.EXTRA_INFO]["mode"] - padding_mode = transform[InverseKeys.EXTRA_INFO]["padding_mode"] - align_corners = transform[InverseKeys.EXTRA_INFO]["align_corners"] + zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"]) + mode = transform[TraceKeys.EXTRA_INFO]["mode"] + padding_mode = transform[TraceKeys.EXTRA_INFO]["padding_mode"] + align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"] inverse_transform = Zoom(zoom=(1 / zoom).tolist(), keep_size=self.rand_zoom.keep_size) # Apply inverse d[key] = inverse_transform( d[key], mode=mode, padding_mode=padding_mode, - align_corners=None if align_corners == InverseKeys.NONE else align_corners, + align_corners=None if align_corners == TraceKeys.NONE else align_corners, ) # Size might be out by 1 voxel so pad - d[key] = SpatialPad(transform[InverseKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore + d[key] = SpatialPad(transform[TraceKeys.ORIG_SIZE], mode="edge")(d[key]) # type: ignore # Remove the applied transform self.pop_transform(d, key) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1412790227..719be5d714 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -62,7 +62,7 @@ from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.transforms.utils_pytorch_numpy_unification import concatenate from monai.utils import convert_to_numpy, ensure_tuple, ensure_tuple_rep -from monai.utils.enums import InverseKeys, TransformBackends +from monai.utils.enums import TraceKeys, TransformBackends from monai.utils.type_conversion import convert_to_dst_type __all__ = [ @@ -645,7 +645,7 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: for key in self.key_iterator(d): transform = self.get_most_recent_transform(d, key) # Create inverse transform - fwd_indices = np.array(transform[InverseKeys.EXTRA_INFO]["indices"]) + fwd_indices = np.array(transform[TraceKeys.EXTRA_INFO]["indices"]) inv_indices = np.argsort(fwd_indices) inverse_transform = Transpose(inv_indices.tolist()) # Apply inverse @@ -1051,7 +1051,7 @@ def __call__(self, data): return super().__call__(data) def _inverse_transform(self, transform_info: Dict, data: Any, func: Callable): - return self._lambd(data, func=func) if transform_info[InverseKeys.DO_TRANSFORM] else data + return self._lambd(data, func=func) if transform_info[TraceKeys.DO_TRANSFORM] else data class LabelToMaskd(MapTransform): diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index e7a3ea2ab1..0b7479cfb0 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -38,9 +38,9 @@ from monai.utils import ( GridSampleMode, InterpolateMode, - InverseKeys, NumpyPadMode, PytorchPadMode, + TraceKeys, deprecated_arg, ensure_tuple, ensure_tuple_rep, @@ -1194,7 +1194,7 @@ def allow_missing_keys_mode(transform: Union[MapTransform, Compose, Tuple[MapTra def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_corners: Optional[bool] = None): """ Change the interpolation mode when inverting spatial transforms, default to "nearest". - This function modifies trans_info's `InverseKeys.EXTRA_INFO`. + This function modifies trans_info's `TraceKeys.EXTRA_INFO`. See also: :py:class:`monai.transform.inverse.InvertibleTransform` @@ -1207,21 +1207,21 @@ def convert_inverse_interp_mode(trans_info: List, mode: str = "nearest", align_c interp_modes = [i.value for i in InterpolateMode] + [i.value for i in GridSampleMode] # set to string for DataLoader collation - align_corners_ = InverseKeys.NONE if align_corners is None else align_corners + align_corners_ = TraceKeys.NONE if align_corners is None else align_corners for item in ensure_tuple(trans_info): - if InverseKeys.EXTRA_INFO in item: - orig_mode = item[InverseKeys.EXTRA_INFO].get("mode", None) + if TraceKeys.EXTRA_INFO in item: + orig_mode = item[TraceKeys.EXTRA_INFO].get("mode", None) if orig_mode is not None: if orig_mode[0] in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = [mode for _ in range(len(mode))] elif orig_mode in interp_modes: - item[InverseKeys.EXTRA_INFO]["mode"] = mode - if "align_corners" in item[InverseKeys.EXTRA_INFO]: - if issequenceiterable(item[InverseKeys.EXTRA_INFO]["align_corners"]): - item[InverseKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] + item[TraceKeys.EXTRA_INFO]["mode"] = mode + if "align_corners" in item[TraceKeys.EXTRA_INFO]: + if issequenceiterable(item[TraceKeys.EXTRA_INFO]["align_corners"]): + item[TraceKeys.EXTRA_INFO]["align_corners"] = [align_corners_ for _ in range(len(mode))] else: - item[InverseKeys.EXTRA_INFO]["align_corners"] = align_corners_ + item[TraceKeys.EXTRA_INFO]["align_corners"] = align_corners_ return trans_info diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 9d643c3ebc..1f7518c4a9 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -223,13 +223,13 @@ class TraceKeys: NONE = "none" -@deprecated(since="0.7.0", msg_suffix="use monai.utils.TraceKeys instead.") +@deprecated(since="0.7.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") class InverseKeys: """ Extra meta data keys used for inverse transforms. .. deprecated:: 0.7.0 - Use :class:`monai.utils.TraceKeys` instead. + Use :class:`monai.utils.enums.TraceKeys` instead. """ @@ -242,7 +242,6 @@ class InverseKeys: NONE = "none" - class CommonKeys: """ A set of common keys for dictionary based supervised training process. @@ -267,4 +266,3 @@ class TransformBackends(Enum): TORCH = "torch" NUMPY = "numpy" - diff --git a/tests/test_decollate.py b/tests/test_decollate.py index f35988e215..763bf808d0 100644 --- a/tests/test_decollate.py +++ b/tests/test_decollate.py @@ -38,7 +38,7 @@ from monai.transforms.inverse_batch_transform import Decollated from monai.transforms.spatial.dictionary import RandAffined, RandRotate90d from monai.utils import optional_import, set_determinism -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys from tests.utils import make_nifti_image _, has_nib = optional_import("nibabel") @@ -105,7 +105,7 @@ def check_match(self, in1, in2): k1, k2 = k1.value, k2.value self.check_match(k1, k2) # Transform ids won't match for windows with multiprocessing, so don't check values - if k1 == InverseKeys.ID and sys.platform in ["darwin", "win32"]: + if k1 == TraceKeys.ID and sys.platform in ["darwin", "win32"]: continue self.check_match(v1, v2) elif isinstance(in1, (list, tuple)): diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 6a68d4f91c..8dc2abf127 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -17,7 +17,7 @@ from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform -from monai.utils.enums import InverseKeys +from monai.utils.enums import TraceKeys class X(Transform): @@ -150,9 +150,9 @@ def test_inverse(self, transform, should_be_ok): for k in KEYS: t = fwd_data[TraceableTransform.transform_key(k)][-1] # make sure the OneOf index was stored - self.assertEqual(t[InverseKeys.CLASS_NAME], OneOf.__name__) + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) # make sure index exists and is in bounds - self.assertTrue(0 <= t[InverseKeys.EXTRA_INFO]["index"] < len(transform)) + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) From 668e01bcd80d49d5d884fe2e535604370ee452b0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 11 Nov 2021 18:55:13 +0000 Subject: [PATCH 5/6] update trace_key Signed-off-by: Wenqi Li --- monai/data/test_time_augmentation.py | 4 ++-- monai/transforms/compose.py | 4 ++-- monai/transforms/croppad/batch.py | 4 ++-- monai/transforms/croppad/dictionary.py | 10 +++++----- monai/transforms/inverse.py | 14 +++++++------- monai/transforms/post/dictionary.py | 4 ++-- tests/test_inverse.py | 2 +- tests/test_one_of.py | 5 ++--- tests/test_traceable_transform.py | 2 +- 9 files changed, 24 insertions(+), 25 deletions(-) diff --git a/monai/data/test_time_augmentation.py b/monai/data/test_time_augmentation.py index f4ee2a46bf..948e85e131 100644 --- a/monai/data/test_time_augmentation.py +++ b/monai/data/test_time_augmentation.py @@ -20,7 +20,7 @@ from monai.data.dataset import Dataset from monai.data.utils import list_data_collate, pad_list_data_collate from monai.transforms.compose import Compose -from monai.transforms.inverse import InvertibleTransform, TraceableTransform +from monai.transforms.inverse import InvertibleTransform from monai.transforms.inverse_batch_transform import BatchInverseTransform from monai.transforms.transform import Randomizable from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode @@ -168,7 +168,7 @@ def __call__( ds = Dataset(data_in, self.transform) dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate) - transform_key = TraceableTransform.transform_key(self.orig_key) + transform_key = InvertibleTransform.trace_key(self.orig_key) # create inverter inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 9a75121a6b..d0d14bcc2d 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -237,7 +237,7 @@ def __call__(self, data): # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, Mapping): for key in data.keys(): - if self.transform_key(key) in data: + if self.trace_key(key) in data: self.push_transform(data, key, extra_info={"index": index}) return data @@ -250,7 +250,7 @@ def inverse(self, data): # loop until we get an index and then break (since they'll all be the same) index = None for key in data.keys(): - if self.transform_key(key) in data: + if self.trace_key(key) in data: # get the index of the applied OneOf transform index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform diff --git a/monai/transforms/croppad/batch.py b/monai/transforms/croppad/batch.py index 7d15115ab6..0cdaa3a2d8 100644 --- a/monai/transforms/croppad/batch.py +++ b/monai/transforms/croppad/batch.py @@ -21,7 +21,7 @@ from monai.data.utils import list_data_collate from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad -from monai.transforms.inverse import InvertibleTransform, TraceableTransform +from monai.transforms.inverse import InvertibleTransform from monai.utils.enums import Method, NumpyPadMode, TraceKeys __all__ = ["PadListDataCollate"] @@ -115,7 +115,7 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]: d = deepcopy(data) for key in d: - transform_key = TraceableTransform.transform_key(key) + transform_key = InvertibleTransform.trace_key(key) if transform_key in d: transform = d[transform_key][-1] if not isinstance(transform, Dict): diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 6ee13ff67a..9a0f898842 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -39,7 +39,7 @@ SpatialCrop, SpatialPad, ) -from monai.transforms.inverse import InvertibleTransform, TraceableTransform +from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform, Randomizable from monai.transforms.utils import ( allow_missing_keys_mode, @@ -776,8 +776,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashab cropped = self.cropper(d) # self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd for key in self.key_iterator(cropped): - cropped[TraceableTransform.transform_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore - cropped[TraceableTransform.transform_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore + cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore + cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore # add `patch_index` to the meta data for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): meta_key = meta_key or f"{key}_{meta_key_postfix}" @@ -792,8 +792,8 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd # Need to revert that since we're calling RandSpatialCropd's inverse for key in self.key_iterator(d): - d[TraceableTransform.transform_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ - d[TraceableTransform.transform_key(key)][-1][TraceKeys.ID] = id(self.cropper) + d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__ + d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper) context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext with context_manager(self.cropper): return self.cropper.inverse(d) diff --git a/monai/transforms/inverse.py b/monai/transforms/inverse.py index 46a485ffd0..90ceba7489 100644 --- a/monai/transforms/inverse.py +++ b/monai/transforms/inverse.py @@ -22,7 +22,7 @@ class TraceableTransform(Transform): """ Maintains a stack of applied transforms. The stack is inserted as pairs of - `transform_key: list of transforms` to each data dictionary. + `trace_key: list of transforms` to each data dictionary. The ``__call__`` method of this transform class must be implemented so that the transformation information for each key is stored when @@ -45,7 +45,7 @@ def set_tracing(self, tracing: bool) -> None: self.tracing = tracing @staticmethod - def transform_key(key: Hashable = None): + def trace_key(key: Hashable = None): """The key to store the stack of applied transforms.""" if key is None: return TraceKeys.KEY_SUFFIX @@ -68,17 +68,17 @@ def push_transform( if hasattr(self, "_do_transform"): # RandomizableTransform info[TraceKeys.DO_TRANSFORM] = self._do_transform # type: ignore # If this is the first, create list - if self.transform_key(key) not in data: + if self.trace_key(key) not in data: if not isinstance(data, dict): data = dict(data) - data[self.transform_key(key)] = [] - data[self.transform_key(key)].append(info) + data[self.trace_key(key)] = [] + data[self.trace_key(key)].append(info) def pop_transform(self, data: Mapping, key: Hashable = None): """Remove the most recent applied transform.""" if not self.tracing: return - return data.get(self.transform_key(key), []).pop() + return data.get(self.trace_key(key), []).pop() class InvertibleTransform(TraceableTransform): @@ -133,7 +133,7 @@ def get_most_recent_transform(self, data: Mapping, key: Hashable = None): """Get most recent transform.""" if not self.tracing: raise RuntimeError("Transform Tracing must be enabled to get the most recent transform.") - transform = data[self.transform_key(key)][-1] + transform = data[self.trace_key(key)][-1] self.check_transforms_match(transform) return transform diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index e96f1599c9..9644dc4f32 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -24,7 +24,7 @@ from monai.config import KeysCollection from monai.config.type_definitions import NdarrayOrTensor from monai.data.csv_saver import CSVSaver -from monai.transforms.inverse import InvertibleTransform, TraceableTransform +from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( Activations, AsDiscrete, @@ -599,7 +599,7 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: self.device, self.post_func, ): - transform_key = TraceableTransform.transform_key(orig_key) + transform_key = InvertibleTransform.trace_key(orig_key) if transform_key not in d: warnings.warn(f"transform info of `{orig_key}` is not available or no InvertibleTransform applied.") continue diff --git a/tests/test_inverse.py b/tests/test_inverse.py index fc91795dce..992a919065 100644 --- a/tests/test_inverse.py +++ b/tests/test_inverse.py @@ -466,7 +466,7 @@ def test_inverse_inferred_seg(self, extra_transform): labels = data["label"].to(device) segs = model(labels).detach().cpu() - label_transform_key = TraceableTransform.transform_key("label") + label_transform_key = TraceableTransform.trace_key("label") segs_dict = {"label": segs, label_transform_key: data[label_transform_key]} segs_dict_decollated = decollate_batch(segs_dict) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 8dc2abf127..7537062569 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -148,7 +148,7 @@ def test_inverse(self, transform, should_be_ok): return for k in KEYS: - t = fwd_data[TraceableTransform.transform_key(k)][-1] + t = fwd_data[TraceableTransform.trace_key(k)][-1] # make sure the OneOf index was stored self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) # make sure index exists and is in bounds @@ -160,8 +160,7 @@ def test_inverse(self, transform, should_be_ok): for k in KEYS: # check transform was removed self.assertTrue( - len(fwd_inv_data[TraceableTransform.transform_key(k)]) - < len(fwd_data[TraceableTransform.transform_key(k)]) + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) ) # check data is same as original (and different from forward) self.assertEqual(fwd_inv_data[k], data[k]) diff --git a/tests/test_traceable_transform.py b/tests/test_traceable_transform.py index 0f1640f8be..b4e9c509f7 100644 --- a/tests/test_traceable_transform.py +++ b/tests/test_traceable_transform.py @@ -28,7 +28,7 @@ class TestTraceable(unittest.TestCase): def test_default(self): expected_key = "_transforms" a = _TraceTest() - self.assertEqual(a.transform_key(), expected_key) + self.assertEqual(a.trace_key(), expected_key) data = {"image": "test"} data = a(data) # adds to the stack From bf85972aaf908c077b3fc5417a7f763c9e5907bd Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 12 Nov 2021 08:05:03 +0000 Subject: [PATCH 6/6] update based on comments Signed-off-by: Wenqi Li --- monai/utils/enums.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/utils/enums.py b/monai/utils/enums.py index 1f7518c4a9..c059a4a5e2 100644 --- a/monai/utils/enums.py +++ b/monai/utils/enums.py @@ -223,12 +223,12 @@ class TraceKeys: NONE = "none" -@deprecated(since="0.7.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") +@deprecated(since="0.8.0", msg_suffix="use monai.utils.enums.TraceKeys instead.") class InverseKeys: """ Extra meta data keys used for inverse transforms. - .. deprecated:: 0.7.0 + .. deprecated:: 0.8.0 Use :class:`monai.utils.enums.TraceKeys` instead. """