diff --git a/monai/data/dataset.py b/monai/data/dataset.py
index d63ff32293..066cec41b7 100644
--- a/monai/data/dataset.py
+++ b/monai/data/dataset.py
@@ -230,6 +230,8 @@ def __init__(
         pickle_protocol: int = DEFAULT_PROTOCOL,
         hash_transform: Callable[..., bytes] | None = None,
         reset_ops_id: bool = True,
+        track_meta: bool = False,
+        weights_only: bool = True,
     ) -> None:
         """
         Args:
@@ -264,7 +266,17 @@ def __init__(
                 When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                 This is useful for skipping the transform instance checks when inverting applied operations
                 using the cached content and with re-created transform instances.
-
+            track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`.
+                defaults to `False`. Cannot be used with `weights_only=True`.
+            weights_only: keyword argument passed to `torch.load` when reading cached files.
+                defaults to `True`. When set to `True`, `torch.load` restricts loading to tensors and
+                other safe objects. Setting this to `False` is required for loading `MetaTensor`
+                objects saved with `track_meta=True`; however, this creates the possibility of remote
+                code execution through `torch.load`, so be aware of the security implications of doing so.
+
+        Raises:
+            ValueError: When both `track_meta=True` and `weights_only=True`, since this combination
+                prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration.
         """
         super().__init__(data=data, transform=transform)
         self.cache_dir = Path(cache_dir) if cache_dir is not None else None
@@ -280,6 +292,13 @@ def __init__(
         if hash_transform is not None:
             self.set_transform_hash(hash_transform)
         self.reset_ops_id = reset_ops_id
+        if track_meta and weights_only:
+            raise ValueError(
+                "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. "
+                "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`."
+            )
+        self.track_meta = track_meta
+        self.weights_only = weights_only
 
     def set_transform_hash(self, hash_xform_func: Callable[..., bytes]):
         """Get hashable transforms, and then hash them. Hashable transforms
@@ -377,7 +396,7 @@ def _cachecheck(self, item_transformed):
 
         if hashfile is not None and hashfile.is_file():  # cache hit
             try:
-                return torch.load(hashfile, weights_only=True)
+                return torch.load(hashfile, weights_only=self.weights_only)
             except PermissionError as e:
                 if sys.platform != "win32":
                     raise e
@@ -398,7 +417,7 @@ def _cachecheck(self, item_transformed):
             with tempfile.TemporaryDirectory() as tmpdirname:
                 temp_hash_file = Path(tmpdirname) / hashfile.name
                 torch.save(
-                    obj=convert_to_tensor(_item_transformed, convert_numeric=False),
+                    obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta),
                     f=temp_hash_file,
                     pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD),
                     pickle_protocol=self.pickle_protocol,
diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py
index 7bf1245592..ca62cdb184 100644
--- a/tests/data/test_persistentdataset.py
+++ b/tests/data/test_persistentdataset.py
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import os
 import tempfile
 import unittest
@@ -20,7 +21,7 @@
 import torch
 from parameterized import parameterized
 
-from monai.data import PersistentDataset, json_hashing
+from monai.data import MetaTensor, PersistentDataset, json_hashing
 from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform
 
 TEST_CASE_1 = [
@@ -43,9 +44,16 @@
 
 TEST_CASE_3 = [None, (128, 128, 128)]
 
+TEST_CASE_4 = [True, False, False, MetaTensor]
+
+TEST_CASE_5 = [True, True, True, None]
+
+TEST_CASE_6 = [False, False, False, torch.Tensor]
+
+TEST_CASE_7 = [False, True, False, torch.Tensor]
 
 
-class _InplaceXform(Transform):
+class _InplaceXform(Transform):
     def __call__(self, data):
         if data:
             data[0] = data[0] + np.pi
@@ -55,7 +63,6 @@ def __call__(self, data):
 
 
 class TestDataset(unittest.TestCase):
-
     def test_cache(self):
         """testing no inplace change to the hashed item"""
         items = [[list(range(i))] for i in range(5)]
@@ -168,6 +175,31 @@ def test_different_transforms(self):
         l2 = ((im1 - im2) ** 2).sum() ** 0.5
         self.assertGreater(l2, 1)
 
+    @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7])
+    def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type):
+        """
+        Ensure expected behavior for all combinations of `track_meta` and `weights_only`.
+        """
+        test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4))
+        with tempfile.TemporaryDirectory() as tempdir:
+            nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz"))
+            test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}]
+            transform = Compose([LoadImaged(keys=["image"])])
+            cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data")
+
+            cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext()
+            with cm:
+                test_dataset = PersistentDataset(
+                    data=test_data,
+                    transform=transform,
+                    cache_dir=cache_dir,
+                    track_meta=track_meta,
+                    weights_only=weights_only,
+                )
+
+                im = test_dataset[0]["image"]
+                self.assertIsInstance(im, expected_type)
+
 
 if __name__ == "__main__":
     unittest.main()