From 552fd10853a0e44823d13a06d214a166ec104e9b Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Wed, 27 Jul 2022 22:32:57 +0100
Subject: [PATCH 1/2] set no ops id

Signed-off-by: Wenqi Li
---
 monai/data/dataset.py                | 12 +++++++++++-
 monai/transforms/__init__.py         |  1 +
 monai/transforms/utils.py            | 16 ++++++++++++++++
 tests/test_inverse.py                |  5 +++++
 tests/test_metatensor_integration.py |  7 ++++++-
 5 files changed, 39 insertions(+), 2 deletions(-)

diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index 15bbf01e3d..eb41e968e4 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -31,7 +31,15 @@
 from torch.utils.data import Subset

 from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
-from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform, convert_to_contiguous
+from monai.transforms import (
+    Compose,
+    Randomizable,
+    ThreadUnsafe,
+    Transform,
+    apply_transform,
+    convert_to_contiguous,
+    reset_ops_id,
+)
 from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
 from monai.utils.misc import first

@@ -304,6 +312,7 @@ def _pre_transform(self, item_transformed):
             # this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
             item_transformed = apply_transform(_xform, item_transformed)
+        reset_ops_id(item_transformed)
         return item_transformed

     def _post_transform(self, item_transformed):
@@ -466,6 +475,7 @@ def _pre_transform(self, item_transformed):
                 break
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
             item_transformed = apply_transform(_xform, item_transformed)
+        reset_ops_id(item_transformed)
         return item_transformed

     def _post_transform(self, item_transformed):
diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py
index 713f848f86..bcb2f56670 100644
--- a/monai/transforms/__init__.py
+++ b/monai/transforms/__init__.py
@@ -608,6 +608,7 @@
     rescale_array,
     rescale_array_int_max,
     rescale_instance_array,
+    reset_ops_id,
     resize_center,
     sync_meta_info,
     weighted_patch_samples,
diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py
index d6ba666767..ae550e7ce6 100644
--- a/monai/transforms/utils.py
+++ b/monai/transforms/utils.py
@@ -110,6 +110,7 @@
     "scale_affine",
     "attach_hook",
     "sync_meta_info",
+    "reset_ops_id",
 ]


@@ -1286,6 +1287,21 @@ def convert_applied_interp_mode(trans_info, mode: str = "nearest", align_corners
     return trans_info


+def reset_ops_id(data):
+    """Find MetaTensors in the list or dict `data` and (in-place) set ``TraceKeys.ID`` to ``TraceKeys.NONE``."""
+    if isinstance(data, (list, tuple)):
+        return [reset_ops_id(d) for d in data]
+    if isinstance(data, monai.data.MetaTensor):
+        data.applied_operations = reset_ops_id(data.applied_operations)
+        return data
+    if not isinstance(data, Mapping):
+        return data
+    data = dict(data)
+    if TraceKeys.ID in data:
+        data[TraceKeys.ID] = TraceKeys.NONE
+    return {k: reset_ops_id(v) for k, v in data.items()}
+
+
 def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequence[int], int]):
     """
     Compute the target spatial size which should be divisible by `k`.
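Note (illustrative, not part of the patch): the new `reset_ops_id` utility recursively walks lists, tuples, mappings and `MetaTensor.applied_operations`, replacing any `TraceKeys.ID` entry with `TraceKeys.NONE`, so that inversion no longer needs the original transform instances. A minimal sketch of that behavior on a hypothetical trace record, using the real `TraceKeys` constants and assuming this patch is applied:

    from monai.data.utils import TraceKeys
    from monai.transforms import reset_ops_id

    # a hypothetical applied_operations record; real records are built by the
    # transform tracing machinery and also carry ORIG_SIZE, EXTRA_INFO, etc.
    trace = [
        {TraceKeys.CLASS_NAME: "Flip", TraceKeys.ID: 139872103},
        [{TraceKeys.CLASS_NAME: "Zoom", TraceKeys.ID: 139872104}],  # nested lists are recursed
    ]

    cleaned = reset_ops_id(trace)
    # every ID field is rewritten to TraceKeys.NONE; other fields are untouched
    assert cleaned[0][TraceKeys.ID] == TraceKeys.NONE
    assert cleaned[1][0][TraceKeys.ID] == TraceKeys.NONE
    assert cleaned[0][TraceKeys.CLASS_NAME] == "Flip"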
diff --git a/tests/test_inverse.py b/tests/test_inverse.py
index 82902a09eb..192a8d345e 100644
--- a/tests/test_inverse.py
+++ b/tests/test_inverse.py
@@ -12,6 +12,7 @@
 import random
 import sys
 import unittest
+from copy import deepcopy
 from functools import partial
 from typing import TYPE_CHECKING, List, Tuple
 from unittest.case import skipUnless
@@ -63,6 +64,7 @@
     Zoomd,
     allow_missing_keys_mode,
     convert_applied_interp_mode,
+    reset_ops_id,
 )
 from monai.utils import first, get_seed, optional_import, set_determinism
 from tests.utils import make_nifti_image, make_rand_affine
@@ -497,10 +499,13 @@ def test_inverse_inferred_seg(self, extra_transform):
         resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
         with resizer.trace_transform(False):
             seg_metatensor = resizer(seg_metatensor)
+        no_ops_id_tensor = reset_ops_id(deepcopy(seg_metatensor))

         with allow_missing_keys_mode(transforms):
             inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
+            inv_seg_1 = transforms.inverse({"label": no_ops_id_tensor})["label"]

         self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
+        self.assertEqual(inv_seg_1.shape[1:], test_data[0]["label"].shape)

         # # Inverse of batch
         # batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
diff --git a/tests/test_metatensor_integration.py b/tests/test_metatensor_integration.py
index 3af5f56c4a..6e8d5f40a3 100644
--- a/tests/test_metatensor_integration.py
+++ b/tests/test_metatensor_integration.py
@@ -12,6 +12,7 @@
 import os
 import tempfile
 import unittest
+from copy import deepcopy

 import numpy as np
 from parameterized import parameterized
@@ -20,7 +21,7 @@
 from monai.bundle import ConfigParser
 from monai.data import CacheDataset, DataLoader, MetaTensor, decollate_batch
 from monai.data.utils import TraceKeys
-from monai.transforms import InvertD, SaveImageD
+from monai.transforms import InvertD, SaveImageD, reset_ops_id
 from monai.utils import optional_import, set_determinism
 from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config

@@ -80,6 +81,10 @@ def test_transforms(self, case_id):
         self.assertTrue(len(tracked_cls) <= len(test_cls))  # tracked items should be no more than the compose items.
         with tempfile.TemporaryDirectory() as tempdir:  # test writer
             SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(loaded)
+        test_data = reset_ops_id(deepcopy(loaded))
+        for val in test_data.values():
+            if isinstance(val, MetaTensor) and val.applied_operations:
+                self.assertEqual(val.applied_operations[-1][TraceKeys.ID], TraceKeys.NONE)

         # test inverse
         inv = InvertD(keys, orig_keys=keys, transform=test_case, nearest_interp=True)

From 1fa85f1a870da61f202bb8a41a48bdb01d8ca8c7 Mon Sep 17 00:00:00 2001
From: Wenqi Li
Date: Mon, 1 Aug 2022 10:11:46 +0100
Subject: [PATCH 2/2] make reset optional in datasets

Signed-off-by: Wenqi Li
---
 monai/data/dataset.py | 42 +++++++++++++++++++++++++++++++++++++-----
 1 file changed, 37 insertions(+), 5 deletions(-)

diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index eb41e968e4..2e28fb1926 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -215,6 +215,7 @@ def __init__(
         pickle_module: str = "pickle",
         pickle_protocol: int = DEFAULT_PROTOCOL,
         hash_transform: Optional[Callable[..., bytes]] = None,
+        reset_ops_id: bool = True,
     ) -> None:
         """
         Args:
@@ -245,6 +246,10 @@
             hash_transform: a callable to compute hash from the transform information when caching.
                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
+            reset_ops_id: whether to set ``TraceKeys.ID`` to ``TraceKeys.NONE``, defaults to ``True``.
+                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
+                This is useful for skipping the transform instance checks when inverting applied operations
+                using the cached content together with re-created transform instances.

         """
         if not isinstance(transform, Compose):
@@ -262,6 +267,7 @@ def __init__(
         self.transform_hash = ""
         if hash_transform is not None:
             self.set_transform_hash(hash_transform)
+        self.reset_ops_id = reset_ops_id

     def set_transform_hash(self, hash_xform_func):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -312,7 +318,8 @@ def _pre_transform(self, item_transformed):
             # this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
             _xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
             item_transformed = apply_transform(_xform, item_transformed)
-        reset_ops_id(item_transformed)
+        if self.reset_ops_id:
+            reset_ops_id(item_transformed)
         return item_transformed

     def _post_transform(self, item_transformed):
@@ -354,9 +361,8 @@ def _cachecheck(self, item_transformed):

         Warning:
             The current implementation does not encode transform information as part of the
-            hashing mechanism used for generating cache names. If the transforms applied are
-            changed in any way, the objects in the cache dir will be invalid. The hash for the
-            cache is ONLY dependant on the input filename paths.
+            hashing mechanism used for generating cache names when `hash_transform` is None.
+            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.

         """
         hashfile = None
@@ -418,6 +424,8 @@ def __init__(
         hash_func: Callable[..., bytes] = pickle_hashing,
         pickle_module: str = "pickle",
         pickle_protocol: int = DEFAULT_PROTOCOL,
+        hash_transform: Optional[Callable[..., bytes]] = None,
+        reset_ops_id: bool = True,
     ) -> None:
         """
         Args:
@@ -446,6 +454,13 @@ def __init__(
             pickle_protocol: can be specified to override the default protocol, default to `2`.
                 this arg is used by `torch.save`, for more details, please check:
                 https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
+            hash_transform: a callable to compute hash from the transform information when caching.
+                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
+                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
+            reset_ops_id: whether to set ``TraceKeys.ID`` to ``TraceKeys.NONE``, defaults to ``True``.
+                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
+                This is useful for skipping the transform instance checks when inverting applied operations
+                using the cached content together with re-created transform instances.
""" super().__init__( @@ -455,6 +470,8 @@ def __init__( hash_func=hash_func, pickle_module=pickle_module, pickle_protocol=pickle_protocol, + hash_transform=hash_transform, + reset_ops_id=reset_ops_id, ) self.cache_n_trans = cache_n_trans @@ -521,6 +538,8 @@ def __init__( db_name: str = "monai_cache", progress: bool = True, pickle_protocol=pickle.HIGHEST_PROTOCOL, + hash_transform: Optional[Callable[..., bytes]] = None, + reset_ops_id: bool = True, lmdb_kwargs: Optional[dict] = None, ) -> None: """ @@ -540,11 +559,24 @@ def __init__( progress: whether to display a progress bar. pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL. https://docs.python.org/3/library/pickle.html#pickle-protocols + hash_transform: a callable to compute hash from the transform information when caching. + This may reduce errors due to transforms changing during experiments. Default to None (no hash). + Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`. + reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``. + When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. + This is useful for skipping the transform instance checks when inverting applied operations + using the cached content and with re-created transform instances. lmdb_kwargs: additional keyword arguments to the lmdb environment. for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class """ super().__init__( - data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, pickle_protocol=pickle_protocol + data=data, + transform=transform, + cache_dir=cache_dir, + hash_func=hash_func, + pickle_protocol=pickle_protocol, + hash_transform=hash_transform, + reset_ops_id=reset_ops_id, ) self.progress = progress if not self.cache_dir: