Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import Transform
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import (
convert_applied_interp_mode,
fill_holes,
Expand Down Expand Up @@ -788,23 +789,27 @@ def __init__(
self,
transform: Optional[InvertibleTransform] = None,
nearest_interp: Union[bool, Sequence[bool]] = True,
device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
device: Union[str, torch.device, None] = None,
post_func: Optional[Callable] = None,
to_tensor: Union[bool, Sequence[bool]] = True,
) -> None:
"""
Args:
transform: the previously applied transform.
nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
default to `True`. If `False`, use the same interpolation mode as the original transform.
device: move the inverted results to a target device before `post_func`, default to "cpu".
post_func: postprocessing for the inverted MetaTensor, should be a callable function.
device: move the inverted results to a target device before `post_func`, default to `None`.
post_func: postprocessing for the inverted result, should be a callable function.
to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
"""
if not isinstance(transform, InvertibleTransform):
raise ValueError("transform is not invertible, can't invert transform for the data.")
self.transform = transform
self.nearest_interp = nearest_interp
self.device = device
self.post_func = post_func
self.to_tensor = to_tensor
self._totensor = ToTensor()

def __call__(self, data):
if not isinstance(data, MetaTensor):
Expand All @@ -817,7 +822,12 @@ def __call__(self, data):

data = data.detach()
inverted = self.transform.inverse(data)
inverted = self.post_func(inverted.to(self.device))
if self.to_tensor and not isinstance(inverted, MetaTensor):
inverted = self._totensor(inverted)
if isinstance(inverted, torch.Tensor):
inverted = inverted.to(device=self.device)
if callable(self.post_func):
inverted = self.post_func(inverted)
return inverted


Expand Down
36 changes: 19 additions & 17 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,13 +547,17 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]):

class Invertd(MapTransform):
"""
Utility transform to automatically invert the previously applied transforms.
Utility transform to invert the previously applied transforms.

Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it
to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata
dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``.
to the data stored at ``keys``.

A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``.
``Invertd``'s output will also include a copy of the metadata
dictionary (originally from ``orig_meta_keys`` or the metadata of ``orig_keys``),
with the relevant fields inverted and stored at ``meta_keys``.

A typical usage is to apply the inverse of the preprocessing (``transform=preprocessings``) on
input ``orig_keys=image`` to the model predictions ``keys=pred``.

A detailed usage example is available in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py
Expand All @@ -580,8 +584,8 @@ def __init__(
meta_key_postfix: str = DEFAULT_POST_FIX,
nearest_interp: Union[bool, Sequence[bool]] = True,
to_tensor: Union[bool, Sequence[bool]] = True,
device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
post_func: Union[Callable, Sequence[Callable]] = lambda x: x,
device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]], None] = None,
post_func: Union[Callable, Sequence[Callable], None] = None,
allow_missing_keys: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -609,7 +613,7 @@ def __init__(
to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
It also can be a list of bool, each matches to the `keys` data.
device: if converted to Tensor, move the inverted results to target device before `post_func`,
default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data.
default to None, it also can be a list of string or `torch.device`, each matches to the `keys` data.
post_func: post processing for the inverted data, should be a callable function.
It also can be a list of callable, each matches to the `keys` data.
allow_missing_keys: don't raise exception if key is missing.
Expand Down Expand Up @@ -694,16 +698,14 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
inverted = self.transform.inverse(input_dict)

# save the inverted data
if to_tensor and not isinstance(inverted[orig_key], MetaTensor):
inverted_data = self._totensor(inverted[orig_key])
else:
inverted_data = inverted[orig_key]
if isinstance(inverted_data, np.ndarray) and torch.device(device).type != "cpu":
raise ValueError("Inverted data with type of 'numpy.ndarray' do not support GPU.")
elif isinstance(inverted_data, torch.Tensor):
d[key] = post_func(inverted_data.to(device))
else:
d[key] = post_func(inverted_data)
inverted_data = inverted[orig_key]
if to_tensor and not isinstance(inverted_data, MetaTensor):
inverted_data = self._totensor(inverted_data)
if isinstance(inverted_data, np.ndarray) and device is not None and torch.device(device).type != "cpu":
raise ValueError(f"Inverted data with type of 'numpy.ndarray' support device='cpu', got {device}.")
if isinstance(inverted_data, torch.Tensor):
inverted_data = inverted_data.to(device=device)
d[key] = post_func(inverted_data) if callable(post_func) else inverted_data
# save the invertd applied_operations if it's in the source dict
if InvertibleTransform.trace_key(orig_key) in d:
d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations
Expand Down
4 changes: 3 additions & 1 deletion tests/test_invert.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,9 @@ def test_invert(self):
dataset = Dataset(data, transform=transform)
self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor)
loader = DataLoader(dataset, num_workers=num_workers, batch_size=1)
inverter = Invert(transform=transform, nearest_interp=True, device="cpu")
inverter = Invert(
transform=transform, nearest_interp=True, device="cpu", post_func=lambda x: torch.as_tensor(x)
)

for d in loader:
d = decollate_batch(d)
Expand Down
3 changes: 2 additions & 1 deletion tests/test_invertd.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def test_invert(self):
transform=transform,
orig_keys=["label", "label"],
nearest_interp=True,
device="cpu",
device=None,
post_func=lambda x: torch.as_tensor(x),
)

inverter_1 = Invertd(
Expand Down