From a0efc23f96ccf5dcf0bddde96cf6b07b8bb5872b Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 21 Dec 2021 22:48:09 +0800 Subject: [PATCH 1/6] [DLMED] fix oneof Signed-off-by: Nic Ma --- monai/transforms/compose.py | 9 +++----- tests/test_one_of.py | 41 +++++++++++++++++++------------------ 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index b75a18dec1..165d9b732f 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -256,12 +256,9 @@ def inverse(self, data): # and then remove the OneOf transform self.pop_transform(data, key) if index is None: - raise RuntimeError("No invertible transforms have been applied") + # no invertible transforms have been applied + return data - # if applied transform is not InvertibleTransform, throw error _transform = self.transforms[index] - if not isinstance(_transform, InvertibleTransform): - raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).") - # apply the inverse - return _transform.inverse(data) + return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data diff --git a/tests/test_one_of.py b/tests/test_one_of.py index a7cd09f10b..7087c8a8df 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -139,32 +139,33 @@ def _match(a, b): _match(p, f) @parameterized.expand(TEST_INVERSES) - def test_inverse(self, transform, should_be_ok): + def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} fwd_data = transform(data) - if not should_be_ok: - with self.assertRaises(RuntimeError): - transform.inverse(fwd_data) - return - - for k in KEYS: - t = fwd_data[TraceableTransform.trace_key(k)][-1] - # make sure the OneOf index was stored - self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) + + if invertible: + for k in KEYS: + t = fwd_data[TraceableTransform.trace_key(k)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - for k in KEYS: - # check transform was removed - self.assertTrue( - len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) - ) - # check data is same as original (and different from forward) - self.assertEqual(fwd_inv_data[k], data[k]) - self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + if invertible: + for k in KEYS: + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) + self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data, fwd_inv_data) def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1)) From 240a7bdece1d42b749e749fdfdb0766fbec84f13 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 21 Dec 2021 22:54:31 +0800 Subject: [PATCH 2/6] [DLMED] add more unit tests Signed-off-by: Nic Ma --- tests/test_one_of.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 7087c8a8df..29d13d7d0c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -12,9 +12,18 @@ import unittest from copy import deepcopy +import numpy as np from parameterized import parameterized -from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform +from monai.transforms import ( + InvertibleTransform, + OneOf, + RandScaleIntensityd, + RandShiftIntensityd, + Resized, + TraceableTransform, + Transform, +) from monai.transforms.compose import Compose from monai.transforms.transform import MapTransform from monai.utils.enums import TraceKeys @@ -167,6 +176,25 @@ def test_inverse(self, transform, invertible): # if not invertible, should not change the data self.assertDictEqual(fwd_data, fwd_inv_data) + def test_inverse_compose(self): + transform = Compose( + [ + Resized(keys="img", spatial_size=[100, 100, 100]), + OneOf( + [ + RandScaleIntensityd(keys="img", factors=0.5, prob=1.0), + RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0), + ] + ), + ] + ) + transform.set_random_state(seed=0) + result = transform({"img": np.ones((1, 101, 102, 103))}) + + result = transform.inverse(result) + # invert to the original spatial shape + self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103)) + def test_one_of(self): p = OneOf((A(), B(), C()), (1, 2, 1)) counts = [0] * 3 From 649a7c59e094283be9625b7b496c53b5884265ab Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 07:02:18 +0800 Subject: [PATCH 3/6] [DLMED] update index Signed-off-by: Nic Ma --- monai/transforms/compose.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 165d9b732f..ad3f35fa79 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -236,9 +236,7 @@ def __call__(self, data): data = apply_transform(_transform, data, self.map_items, self.unpack_items) # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, Mapping): - for key in data.keys(): - if self.trace_key(key) in data: - self.push_transform(data, key, extra_info={"index": index}) + self.push_transform(data, self.__class__.__name__, extra_info={"index": index}) return data def inverse(self, data): @@ -248,16 +246,14 @@ def inverse(self, data): raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") # loop until we get an index and then break (since they'll all be the same) - index = None - for key in data.keys(): - if self.trace_key(key) in data: - # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] - # and then remove the OneOf transform - self.pop_transform(data, key) - if index is None: - # no invertible transforms have been applied - return data + key = self.__class__.__name__ + if self.trace_key(key) not in data: + raise RuntimeError("can not find the index of transform have been applied.") + + # get the index of the applied OneOf transform + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] + # and then remove the OneOf transform + self.pop_transform(data, key) _transform = self.transforms[index] # apply the inverse From c6c3a35f451828570216622ee42c9a25caa1a769 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 07:33:48 +0800 Subject: [PATCH 4/6] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/transforms/compose.py | 10 ++++++---- tests/test_one_of.py | 36 +++++++++++++++++------------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ad3f35fa79..ea89da3e9b 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,6 +13,7 @@ """ import warnings +from copy import deepcopy from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np @@ -245,16 +246,17 @@ def inverse(self, data): if not isinstance(data, Mapping): raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") + d = deepcopy(dict(data)) # loop until we get an index and then break (since they'll all be the same) key = self.__class__.__name__ - if self.trace_key(key) not in data: + if self.trace_key(key) not in d: raise RuntimeError("can not find the index of transform have been applied.") # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] + index = self.get_most_recent_transform(d, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform - self.pop_transform(data, key) + self.pop_transform(d, key) _transform = self.transforms[index] # apply the inverse - return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data + return _transform.inverse(d) if isinstance(_transform, InvertibleTransform) else d diff --git a/tests/test_one_of.py b/tests/test_one_of.py index 29d13d7d0c..dc20865892 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -150,31 +150,29 @@ def _match(a, b): @parameterized.expand(TEST_INVERSES) def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} + key = OneOf.__name__ fwd_data = transform(data) - - if invertible: - for k in KEYS: - t = fwd_data[TraceableTransform.trace_key(k)][-1] - # make sure the OneOf index was stored - self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) + t = fwd_data[TraceableTransform.trace_key(key)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], key) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - if invertible: - for k in KEYS: - # check transform was removed - self.assertTrue( - len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) - ) - # check data is same as original (and different from forward) - self.assertEqual(fwd_inv_data[k], data[k]) + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(key)]) < len(fwd_data[TraceableTransform.trace_key(key)]) + ) + # check data is same as original (and different from forward) + for k, v in data.items(): + if invertible: + self.assertEqual(fwd_inv_data[k], v) self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) - else: - # if not invertible, should not change the data - self.assertDictEqual(fwd_data, fwd_inv_data) + else: + # if not invertible, should not change the data + self.assertEqual(fwd_inv_data[k], fwd_data[k]) def test_inverse_compose(self): transform = Compose( From 5466c1bf82d95566b34a96621c787f1a67df613c Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 23:37:52 +0800 Subject: [PATCH 5/6] Revert "[DLMED] update according to comments" This reverts commit c6c3a35f451828570216622ee42c9a25caa1a769. Signed-off-by: Nic Ma --- monai/transforms/compose.py | 10 ++++------ tests/test_one_of.py | 36 +++++++++++++++++++----------------- 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ea89da3e9b..ad3f35fa79 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -13,7 +13,6 @@ """ import warnings -from copy import deepcopy from typing import Any, Callable, Mapping, Optional, Sequence, Union import numpy as np @@ -246,17 +245,16 @@ def inverse(self, data): if not isinstance(data, Mapping): raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") - d = deepcopy(dict(data)) # loop until we get an index and then break (since they'll all be the same) key = self.__class__.__name__ - if self.trace_key(key) not in d: + if self.trace_key(key) not in data: raise RuntimeError("can not find the index of transform have been applied.") # get the index of the applied OneOf transform - index = self.get_most_recent_transform(d, key)[TraceKeys.EXTRA_INFO]["index"] + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] # and then remove the OneOf transform - self.pop_transform(d, key) + self.pop_transform(data, key) _transform = self.transforms[index] # apply the inverse - return _transform.inverse(d) if isinstance(_transform, InvertibleTransform) else d + return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data diff --git a/tests/test_one_of.py b/tests/test_one_of.py index dc20865892..29d13d7d0c 100644 --- a/tests/test_one_of.py +++ b/tests/test_one_of.py @@ -150,29 +150,31 @@ def _match(a, b): @parameterized.expand(TEST_INVERSES) def test_inverse(self, transform, invertible): data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)} - key = OneOf.__name__ fwd_data = transform(data) - t = fwd_data[TraceableTransform.trace_key(key)][-1] - # make sure the OneOf index was stored - self.assertEqual(t[TraceKeys.CLASS_NAME], key) - # make sure index exists and is in bounds - self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) + + if invertible: + for k in KEYS: + t = fwd_data[TraceableTransform.trace_key(k)][-1] + # make sure the OneOf index was stored + self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__) + # make sure index exists and is in bounds + self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform)) # call the inverse fwd_inv_data = transform.inverse(fwd_data) - # check transform was removed - self.assertTrue( - len(fwd_inv_data[TraceableTransform.trace_key(key)]) < len(fwd_data[TraceableTransform.trace_key(key)]) - ) - # check data is same as original (and different from forward) - for k, v in data.items(): - if invertible: - self.assertEqual(fwd_inv_data[k], v) + if invertible: + for k in KEYS: + # check transform was removed + self.assertTrue( + len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)]) + ) + # check data is same as original (and different from forward) + self.assertEqual(fwd_inv_data[k], data[k]) self.assertNotEqual(fwd_inv_data[k], fwd_data[k]) - else: - # if not invertible, should not change the data - self.assertEqual(fwd_inv_data[k], fwd_data[k]) + else: + # if not invertible, should not change the data + self.assertDictEqual(fwd_data, fwd_inv_data) def test_inverse_compose(self): transform = Compose( From c081351e4e685f66e5c2efb59ab5a49e83f54a68 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Wed, 22 Dec 2021 23:38:09 +0800 Subject: [PATCH 6/6] Revert "[DLMED] update index" This reverts commit 649a7c59e094283be9625b7b496c53b5884265ab. Signed-off-by: Nic Ma --- monai/transforms/compose.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index ad3f35fa79..165d9b732f 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -236,7 +236,9 @@ def __call__(self, data): data = apply_transform(_transform, data, self.map_items, self.unpack_items) # if the data is a mapping (dictionary), append the OneOf transform to the end if isinstance(data, Mapping): - self.push_transform(data, self.__class__.__name__, extra_info={"index": index}) + for key in data.keys(): + if self.trace_key(key) in data: + self.push_transform(data, key, extra_info={"index": index}) return data def inverse(self, data): @@ -246,14 +248,16 @@ def inverse(self, data): raise RuntimeError("Inverse only implemented for Mapping (dictionary) data") # loop until we get an index and then break (since they'll all be the same) - key = self.__class__.__name__ - if self.trace_key(key) not in data: - raise RuntimeError("can not find the index of transform have been applied.") - - # get the index of the applied OneOf transform - index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] - # and then remove the OneOf transform - self.pop_transform(data, key) + index = None + for key in data.keys(): + if self.trace_key(key) in data: + # get the index of the applied OneOf transform + index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"] + # and then remove the OneOf transform + self.pop_transform(data, key) + if index is None: + # no invertible transforms have been applied + return data _transform = self.transforms[index] # apply the inverse