From 9c0e3752bf63b43d867d0c4f20f88822d8bad506 Mon Sep 17 00:00:00 2001 From: Mason Cleveland Date: Tue, 11 Nov 2025 11:15:13 -0500 Subject: [PATCH 1/7] Update PersistentDataset for MetaTensors Signed-off-by: Mason Cleveland --- monai/data/dataset.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index d63ff32293..b3f53d7b4e 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -230,6 +230,8 @@ def __init__( pickle_protocol: int = DEFAULT_PROTOCOL, hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, + track_meta: bool = False, + weights_only: bool = True ) -> None: """ Args: @@ -264,7 +266,10 @@ def __init__( When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors. This is useful for skipping the transform instance checks when inverting applied operations using the cached content and with re-created transform instances. - + track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`. + default to `False`. + weights_only: keyword argument passed to `torch.load` when reading cached files. + default to `True`. """ super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None @@ -280,6 +285,8 @@ def __init__( if hash_transform is not None: self.set_transform_hash(hash_transform) self.reset_ops_id = reset_ops_id + self.track_meta = track_meta + self.weights_only = weights_only def set_transform_hash(self, hash_xform_func: Callable[..., bytes]): """Get hashable transforms, and then hash them. Hashable transforms @@ -377,7 +384,7 @@ def _cachecheck(self, item_transformed): if hashfile is not None and hashfile.is_file(): # cache hit try: - return torch.load(hashfile, weights_only=True) + return torch.load(hashfile, weights_only=self.weights_only) except PermissionError as e: if sys.platform != "win32": raise e @@ -398,7 +405,7 @@ def _cachecheck(self, item_transformed): with tempfile.TemporaryDirectory() as tmpdirname: temp_hash_file = Path(tmpdirname) / hashfile.name torch.save( - obj=convert_to_tensor(_item_transformed, convert_numeric=False), + obj=convert_to_tensor(_item_transformed, convert_numeric=False, track_meta=self.track_meta), f=temp_hash_file, pickle_module=look_up_option(self.pickle_module, SUPPORTED_PICKLE_MOD), pickle_protocol=self.pickle_protocol, From 42afa85e72208afbf2a93899c6f83b8d38e4796c Mon Sep 17 00:00:00 2001 From: Mason Cleveland Date: Tue, 11 Nov 2025 21:22:30 -0500 Subject: [PATCH 2/7] codeformat autofix Signed-off-by: Mason Cleveland --- monai/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index b3f53d7b4e..f5f2661f4f 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -231,7 +231,7 @@ def __init__( hash_transform: Callable[..., bytes] | None = None, reset_ops_id: bool = True, track_meta: bool = False, - weights_only: bool = True + weights_only: bool = True, ) -> None: """ Args: From d16616cc8d7fab3036ba665048a6d4482608f719 Mon Sep 17 00:00:00 2001 From: Mason Cleveland Date: Tue, 11 Nov 2025 21:56:59 -0500 Subject: [PATCH 3/7] Add ValueError for track_meta and weights_only; Update docstring Signed-off-by: Mason Cleveland --- monai/data/dataset.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index f5f2661f4f..51deff3bf8 100644 --- 
a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -267,9 +267,15 @@ def __init__( This is useful for skipping the transform instance checks when inverting applied operations using the cached content and with re-created transform instances. track_meta: whether to track the meta information, if `True`, will convert to `MetaTensor`. - default to `False`. + default to `False`. Cannot be used with `weights_only=True`. weights_only: keyword argument passed to `torch.load` when reading cached files. - default to `True`. + default to `True`. When set to `True`, `torch.load` restricts loading to tensors and + other safe objects. Setting this to `False` is required for loading `MetaTensor` + objects saved with `track_meta=True`. + + Raises: + ValueError: When both `track_meta=True` and `weights_only=True`, since this combination + prevents cached MetaTensors from being reloaded and causes perpetual cache regeneration. """ super().__init__(data=data, transform=transform) self.cache_dir = Path(cache_dir) if cache_dir is not None else None @@ -285,6 +291,11 @@ def __init__( if hash_transform is not None: self.set_transform_hash(hash_transform) self.reset_ops_id = reset_ops_id + if track_meta and weights_only: + raise ValueError( + "Invalid argument combination: `track_meta=True` cannot be used with `weights_only=True`. " + "To cache and reload MetaTensors, set `track_meta=True` and `weights_only=False`." + ) self.track_meta = track_meta self.weights_only = weights_only From 341538d7299d6b062cd48d483a734b1b1fd98d60 Mon Sep 17 00:00:00 2001 From: Mason Cleveland Date: Tue, 11 Nov 2025 22:50:24 -0500 Subject: [PATCH 4/7] Add new test Signed-off-by: Mason Cleveland --- tests/data/test_persistentdataset.py | 44 +++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index 7bf1245592..b3009055d8 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -20,7 +20,7 @@ import torch from parameterized import parameterized -from monai.data import PersistentDataset, json_hashing +from monai.data import MetaTensor, PersistentDataset, json_hashing from monai.transforms import Compose, Flip, Identity, LoadImaged, SimulateDelayd, Transform TEST_CASE_1 = [ @@ -43,6 +43,14 @@ TEST_CASE_3 = [None, (128, 128, 128)] +TEST_CASE_4 = [True, False, False, MetaTensor] + +TEST_CASE_5 = [True, True, True, None] + +TEST_CASE_6 = [False, False, False, torch.Tensor] + +TEST_CASE_7 = [False, True, False, torch.Tensor] + class _InplaceXform(Transform): @@ -168,6 +176,40 @@ def test_different_transforms(self): l2 = ((im1 - im2) ** 2).sum() ** 0.5 self.assertGreater(l2, 1) + @parameterized.expand([TEST_CASE_4, TEST_CASE_5, TEST_CASE_6, TEST_CASE_7]) + def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_error, expected_type): + """ + Ensure expected behavior for all combinations of `track_meta` and `weights_only`. 
+ """ + test_image = nib.Nifti1Image(np.random.randint(0, 2, size=[128, 128, 128]).astype(float), np.eye(4)) + with tempfile.TemporaryDirectory() as tempdir: + nib.save(test_image, os.path.join(tempdir, "test_image.nii.gz")) + test_data = [{"image": os.path.join(tempdir, "test_image.nii.gz")}] + transform = Compose([LoadImaged(keys=["image"])]) + cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") + + if expected_error: + with self.assertRaises(ValueError): + PersistentDataset( + data=test_data, + transform=transform, + cache_dir=cache_dir, + track_meta=track_meta, + weights_only=weights_only, + ) + + else: + test_dataset = PersistentDataset( + data=test_data, + transform=transform, + cache_dir=cache_dir, + track_meta=track_meta, + weights_only=weights_only, + ) + + im = test_dataset[0]["image"] + self.assertIsInstance(im, expected_type) + if __name__ == "__main__": unittest.main() From 15f0eb599250f188d9a3768dac7053949f7ad454 Mon Sep 17 00:00:00 2001 From: "Mason C. Cleveland" <104479423+mccle@users.noreply.github.com> Date: Fri, 14 Nov 2025 15:39:18 -0500 Subject: [PATCH 5/7] Update PersistentDataset docstring Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mason C. Cleveland <104479423+mccle@users.noreply.github.com> --- monai/data/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 51deff3bf8..066cec41b7 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -271,7 +271,8 @@ def __init__( weights_only: keyword argument passed to `torch.load` when reading cached files. default to `True`. When set to `True`, `torch.load` restricts loading to tensors and other safe objects. Setting this to `False` is required for loading `MetaTensor` - objects saved with `track_meta=True`. + objects saved with `track_meta=True`, however this creates the possibility of remote + code execution through `torch.load` so be aware of the security implications of doing so. Raises: ValueError: When both `track_meta=True` and `weights_only=True`, since this combination From 0bacb8bbfdfc1a0824177b27fc5517b259b5bae9 Mon Sep 17 00:00:00 2001 From: "Mason C. Cleveland" <104479423+mccle@users.noreply.github.com> Date: Fri, 14 Nov 2025 15:40:21 -0500 Subject: [PATCH 6/7] Update tests/data/test_persistentdataset.py Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Signed-off-by: Mason C. 
Cleveland <104479423+mccle@users.noreply.github.com> --- tests/data/test_persistentdataset.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index b3009055d8..e521d50448 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -188,17 +188,9 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er transform = Compose([LoadImaged(keys=["image"])]) cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - if expected_error: - with self.assertRaises(ValueError): - PersistentDataset( - data=test_data, - transform=transform, - cache_dir=cache_dir, - track_meta=track_meta, - weights_only=weights_only, - ) - else: + cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext() + with cm: test_dataset = PersistentDataset( data=test_data, transform=transform, From facc345d08bdf416ce40c70642dd6579bacb9f71 Mon Sep 17 00:00:00 2001 From: mccle Date: Fri, 14 Nov 2025 15:46:41 -0500 Subject: [PATCH 7/7] Import contextlib for updated test Signed-off-by: mccle --- tests/data/test_persistentdataset.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/data/test_persistentdataset.py b/tests/data/test_persistentdataset.py index e521d50448..ca62cdb184 100644 --- a/tests/data/test_persistentdataset.py +++ b/tests/data/test_persistentdataset.py @@ -11,6 +11,7 @@ from __future__ import annotations +import contextlib import os import tempfile import unittest @@ -53,7 +54,6 @@ class _InplaceXform(Transform): - def __call__(self, data): if data: data[0] = data[0] + np.pi @@ -63,7 +63,6 @@ def __call__(self, data): class TestDataset(unittest.TestCase): - def test_cache(self): """testing no inplace change to the hashed item""" items = [[list(range(i))] for i in range(5)] @@ -188,7 +187,6 @@ def test_track_meta_and_weights_only(self, track_meta, weights_only, expected_er transform = Compose([LoadImaged(keys=["image"])]) cache_dir = os.path.join(os.path.join(tempdir, "cache"), "data") - cm = self.assertRaises(ValueError) if expected_error else contextlib.nullcontext() with cm: test_dataset = PersistentDataset(
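# A minimal usage sketch of the new `track_meta` / `weights_only` arguments
# introduced by this patch series (not itself part of the patches). The NIfTI
# file, cache directories, and transform chain below are illustrative
# assumptions, not code taken from the commits.
import os
import tempfile

import nibabel as nib
import numpy as np
import torch

from monai.data import MetaTensor, PersistentDataset
from monai.transforms import Compose, LoadImaged

with tempfile.TemporaryDirectory() as tempdir:
    # A small synthetic image stands in for real data.
    img_path = os.path.join(tempdir, "example_image.nii.gz")
    nib.save(nib.Nifti1Image(np.random.rand(32, 32, 32), np.eye(4)), img_path)
    data = [{"image": img_path}]
    transform = Compose([LoadImaged(keys=["image"])])

    # Default behaviour (track_meta=False, weights_only=True): items are cached
    # as plain tensors and reloaded through the restricted torch.load path.
    ds_plain = PersistentDataset(
        data=data, transform=transform, cache_dir=os.path.join(tempdir, "cache_plain")
    )
    _ = ds_plain[0]                  # first read populates the on-disk cache
    reloaded = ds_plain[0]["image"]  # second read comes back from the cache
    assert isinstance(reloaded, torch.Tensor) and not isinstance(reloaded, MetaTensor)

    # Caching MetaTensors needs track_meta=True together with weights_only=False;
    # the latter relaxes torch.load's safety checks, so only use trusted cache dirs.
    ds_meta = PersistentDataset(
        data=data,
        transform=transform,
        cache_dir=os.path.join(tempdir, "cache_meta"),
        track_meta=True,
        weights_only=False,
    )
    _ = ds_meta[0]
    assert isinstance(ds_meta[0]["image"], MetaTensor)

    # The unusable combination is rejected at construction time.
    try:
        PersistentDataset(
            data=data,
            transform=transform,
            cache_dir=os.path.join(tempdir, "cache_bad"),
            track_meta=True,
            weights_only=True,
        )
    except ValueError:
        pass  # raised because weights_only=True cannot reload cached MetaTensors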