Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 47 additions & 5 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@
from torch.utils.data import Subset

from monai.data.utils import SUPPORTED_PICKLE_MOD, convert_tables_to_dicts, pickle_hashing
from monai.transforms import Compose, Randomizable, ThreadUnsafe, Transform, apply_transform, convert_to_contiguous
from monai.transforms import (
Compose,
Randomizable,
ThreadUnsafe,
Transform,
apply_transform,
convert_to_contiguous,
reset_ops_id,
)
from monai.utils import MAX_SEED, deprecated_arg, get_seed, look_up_option, min_version, optional_import
from monai.utils.misc import first

Expand Down Expand Up @@ -207,6 +215,7 @@ def __init__(
pickle_module: str = "pickle",
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Optional[Callable[..., bytes]] = None,
reset_ops_id: bool = True,
) -> None:
"""
Args:
Expand Down Expand Up @@ -237,6 +246,10 @@ def __init__(
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.

"""
if not isinstance(transform, Compose):
Expand All @@ -254,6 +267,7 @@ def __init__(
self.transform_hash = ""
if hash_transform is not None:
self.set_transform_hash(hash_transform)
self.reset_ops_id = reset_ops_id

def set_transform_hash(self, hash_xform_func):
"""Get hashable transforms, and then hash them. Hashable transforms
Expand Down Expand Up @@ -304,6 +318,8 @@ def _pre_transform(self, item_transformed):
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = apply_transform(_xform, item_transformed)
if self.reset_ops_id:
reset_ops_id(item_transformed)
return item_transformed

def _post_transform(self, item_transformed):
Expand Down Expand Up @@ -345,9 +361,8 @@ def _cachecheck(self, item_transformed):

Warning:
The current implementation does not encode transform information as part of the
hashing mechanism used for generating cache names. If the transforms applied are
changed in any way, the objects in the cache dir will be invalid. The hash for the
cache is ONLY dependant on the input filename paths.
hashing mechanism used for generating cache names when `hash_transform` is None.
If the transforms applied are changed in any way, the objects in the cache dir will be invalid.

"""
hashfile = None
Expand Down Expand Up @@ -409,6 +424,8 @@ def __init__(
hash_func: Callable[..., bytes] = pickle_hashing,
pickle_module: str = "pickle",
pickle_protocol: int = DEFAULT_PROTOCOL,
hash_transform: Optional[Callable[..., bytes]] = None,
reset_ops_id: bool = True,
) -> None:
"""
Args:
Expand Down Expand Up @@ -437,6 +454,13 @@ def __init__(
pickle_protocol: can be specified to override the default protocol, default to `2`.
this arg is used by `torch.save`, for more details, please check:
https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.

"""
super().__init__(
Expand All @@ -446,6 +470,8 @@ def __init__(
hash_func=hash_func,
pickle_module=pickle_module,
pickle_protocol=pickle_protocol,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
)
self.cache_n_trans = cache_n_trans

Expand All @@ -466,6 +492,7 @@ def _pre_transform(self, item_transformed):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = apply_transform(_xform, item_transformed)
reset_ops_id(item_transformed)
return item_transformed

def _post_transform(self, item_transformed):
Expand Down Expand Up @@ -511,6 +538,8 @@ def __init__(
db_name: str = "monai_cache",
progress: bool = True,
pickle_protocol=pickle.HIGHEST_PROTOCOL,
hash_transform: Optional[Callable[..., bytes]] = None,
reset_ops_id: bool = True,
lmdb_kwargs: Optional[dict] = None,
) -> None:
"""
Expand All @@ -530,11 +559,24 @@ def __init__(
progress: whether to display a progress bar.
pickle_protocol: pickle protocol version. Defaults to pickle.HIGHEST_PROTOCOL.
https://docs.python.org/3/library/pickle.html#pickle-protocols
hash_transform: a callable to compute hash from the transform information when caching.
This may reduce errors due to transforms changing during experiments. Default to None (no hash).
Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
reset_ops_id: whether to set `TraceKeys.ID` to ``TraceKeys.NONE``, defaults to ``True``.
When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
This is useful for skipping the transform instance checks when inverting applied operations
using the cached content and with re-created transform instances.
lmdb_kwargs: additional keyword arguments to the lmdb environment.
for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
"""
super().__init__(
data=data, transform=transform, cache_dir=cache_dir, hash_func=hash_func, pickle_protocol=pickle_protocol
data=data,
transform=transform,
cache_dir=cache_dir,
hash_func=hash_func,
pickle_protocol=pickle_protocol,
hash_transform=hash_transform,
reset_ops_id=reset_ops_id,
)
self.progress = progress
if not self.cache_dir:
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -608,6 +608,7 @@
rescale_array,
rescale_array_int_max,
rescale_instance_array,
reset_ops_id,
resize_center,
sync_meta_info,
weighted_patch_samples,
Expand Down
16 changes: 16 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
"scale_affine",
"attach_hook",
"sync_meta_info",
"reset_ops_id",
]


Expand Down Expand Up @@ -1286,6 +1287,21 @@ def convert_applied_interp_mode(trans_info, mode: str = "nearest", align_corners
return trans_info


def reset_ops_id(data):
    """Recursively traverse `data` and set ``TraceKeys.ID`` entries to ``TraceKeys.NONE``.

    `data` may be a ``MetaTensor``, a list/tuple, a mapping, or any other object:

        - for a ``MetaTensor``, its ``applied_operations`` are processed (the tensor is
          updated in place and returned);
        - lists and tuples are rebuilt as new lists with each element processed;
        - mappings are shallow-copied to a ``dict`` before the ID key is overwritten;
        - anything else is returned unchanged.

    This removes traced transform-instance IDs so that cached/applied operations can be
    inverted with re-created transform instances.
    """
    # note: a MetaTensor is neither a list/tuple nor a Mapping, so check order is safe
    if isinstance(data, monai.data.MetaTensor):
        data.applied_operations = reset_ops_id(data.applied_operations)
        return data
    if isinstance(data, (list, tuple)):
        return [reset_ops_id(item) for item in data]
    if isinstance(data, Mapping):
        copied = dict(data)  # shallow copy so the caller's mapping object is untouched
        if TraceKeys.ID in copied:
            copied[TraceKeys.ID] = TraceKeys.NONE
        return {key: reset_ops_id(value) for key, value in copied.items()}
    # leaf value (scalar, string, ndarray, ...): nothing to reset
    return data


def compute_divisible_spatial_size(spatial_shape: Sequence[int], k: Union[Sequence[int], int]):
"""
Compute the target spatial size which should be divisible by `k`.
Expand Down
5 changes: 5 additions & 0 deletions tests/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import random
import sys
import unittest
from copy import deepcopy
from functools import partial
from typing import TYPE_CHECKING, List, Tuple
from unittest.case import skipUnless
Expand Down Expand Up @@ -63,6 +64,7 @@
Zoomd,
allow_missing_keys_mode,
convert_applied_interp_mode,
reset_ops_id,
)
from monai.utils import first, get_seed, optional_import, set_determinism
from tests.utils import make_nifti_image, make_rand_affine
Expand Down Expand Up @@ -497,10 +499,13 @@ def test_inverse_inferred_seg(self, extra_transform):
resizer = ResizeWithPadOrCrop(spatial_size=shape_before_extra_xform)
with resizer.trace_transform(False):
seg_metatensor = resizer(seg_metatensor)
no_ops_id_tensor = reset_ops_id(deepcopy(seg_metatensor))

with allow_missing_keys_mode(transforms):
inv_seg = transforms.inverse({"label": seg_metatensor})["label"]
inv_seg_1 = transforms.inverse({"label": no_ops_id_tensor})["label"]
self.assertEqual(inv_seg.shape[1:], test_data[0]["label"].shape)
self.assertEqual(inv_seg_1.shape[1:], test_data[0]["label"].shape)

# # Inverse of batch
# batch_inverter = BatchInverseTransform(transforms, loader, collate_fn=no_collation, detach=True)
Expand Down
7 changes: 6 additions & 1 deletion tests/test_metatensor_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import tempfile
import unittest
from copy import deepcopy

import numpy as np
from parameterized import parameterized
Expand All @@ -20,7 +21,7 @@
from monai.bundle import ConfigParser
from monai.data import CacheDataset, DataLoader, MetaTensor, decollate_batch
from monai.data.utils import TraceKeys
from monai.transforms import InvertD, SaveImageD
from monai.transforms import InvertD, SaveImageD, reset_ops_id
from monai.utils import optional_import, set_determinism
from tests.utils import assert_allclose, download_url_or_skip_test, testing_data_config

Expand Down Expand Up @@ -80,6 +81,10 @@ def test_transforms(self, case_id):
self.assertTrue(len(tracked_cls) <= len(test_cls)) # tracked items should be no more than the compose items.
with tempfile.TemporaryDirectory() as tempdir: # test writer
SaveImageD(keys, resample=False, output_dir=tempdir, output_postfix=case_id)(loaded)
test_data = reset_ops_id(deepcopy(loaded))
for val in test_data.values():
if isinstance(val, MetaTensor) and val.applied_operations:
self.assertEqual(val.applied_operations[-1][TraceKeys.ID], TraceKeys.NONE)

# test inverse
inv = InvertD(keys, orig_keys=keys, transform=test_case, nearest_interp=True)
Expand Down