Skip to content
Merged
9 changes: 3 additions & 6 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,12 +256,9 @@ def inverse(self, data):
# and then remove the OneOf transform
self.pop_transform(data, key)
if index is None:
raise RuntimeError("No invertible transforms have been applied")
# no invertible transforms have been applied
return data

# if applied transform is not InvertibleTransform, throw error
_transform = self.transforms[index]
if not isinstance(_transform, InvertibleTransform):
raise RuntimeError(f"Applied OneOf transform is not invertible (applied index: {index}).")

# apply the inverse
return _transform.inverse(data)
return _transform.inverse(data) if isinstance(_transform, InvertibleTransform) else data
71 changes: 50 additions & 21 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,18 @@
import unittest
from copy import deepcopy

import numpy as np
from parameterized import parameterized

from monai.transforms import InvertibleTransform, OneOf, TraceableTransform, Transform
from monai.transforms import (
InvertibleTransform,
OneOf,
RandScaleIntensityd,
RandShiftIntensityd,
Resized,
TraceableTransform,
Transform,
)
from monai.transforms.compose import Compose
from monai.transforms.transform import MapTransform
from monai.utils.enums import TraceKeys
Expand Down Expand Up @@ -139,32 +148,52 @@ def _match(a, b):
_match(p, f)

@parameterized.expand(TEST_INVERSES)
def test_inverse(self, transform, should_be_ok):
def test_inverse(self, transform, invertible):
data = {k: (i + 1) * 10.0 for i, k in enumerate(KEYS)}
fwd_data = transform(data)
if not should_be_ok:
with self.assertRaises(RuntimeError):
transform.inverse(fwd_data)
return

for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))

if invertible:
for k in KEYS:
t = fwd_data[TraceableTransform.trace_key(k)][-1]
# make sure the OneOf index was stored
self.assertEqual(t[TraceKeys.CLASS_NAME], OneOf.__name__)
# make sure index exists and is in bounds
self.assertTrue(0 <= t[TraceKeys.EXTRA_INFO]["index"] < len(transform))

# call the inverse
fwd_inv_data = transform.inverse(fwd_data)

for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
if invertible:
for k in KEYS:
# check transform was removed
self.assertTrue(
len(fwd_inv_data[TraceableTransform.trace_key(k)]) < len(fwd_data[TraceableTransform.trace_key(k)])
)
# check data is same as original (and different from forward)
self.assertEqual(fwd_inv_data[k], data[k])
self.assertNotEqual(fwd_inv_data[k], fwd_data[k])
else:
# if not invertible, should not change the data
self.assertDictEqual(fwd_data, fwd_inv_data)

def test_inverse_compose(self):
transform = Compose(
[
Resized(keys="img", spatial_size=[100, 100, 100]),
OneOf(
[
RandScaleIntensityd(keys="img", factors=0.5, prob=1.0),
RandShiftIntensityd(keys="img", offsets=0.5, prob=1.0),
]
),
]
)
transform.set_random_state(seed=0)
result = transform({"img": np.ones((1, 101, 102, 103))})

result = transform.inverse(result)
# invert to the original spatial shape
self.assertTupleEqual(result["img"].shape, (1, 101, 102, 103))

def test_one_of(self):
p = OneOf((A(), B(), C()), (1, 2, 1))
Expand Down