From 0f3ddf43d87534e63c3e822d63918cd9a6de6b63 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 11 Nov 2022 09:16:02 +0000 Subject: [PATCH] update invert Signed-off-by: Wenqi Li --- monai/transforms/post/array.py | 20 ++++++++++++---- monai/transforms/post/dictionary.py | 36 +++++++++++++++-------------- tests/test_invert.py | 4 +++- tests/test_invertd.py | 3 ++- 4 files changed, 39 insertions(+), 24 deletions(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index 34709daa42..3bc0be6391 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -27,6 +27,7 @@ from monai.networks.layers import GaussianFilter, apply_filter, separable_filtering from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import Transform +from monai.transforms.utility.array import ToTensor from monai.transforms.utils import ( convert_applied_interp_mode, fill_holes, @@ -788,16 +789,18 @@ def __init__( self, transform: Optional[InvertibleTransform] = None, nearest_interp: Union[bool, Sequence[bool]] = True, - device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", - post_func: Union[Callable, Sequence[Callable]] = lambda x: x, + device: Union[str, torch.device, None] = None, + post_func: Optional[Callable] = None, + to_tensor: Union[bool, Sequence[bool]] = True, ) -> None: """ Args: transform: the previously applied transform. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. - device: move the inverted results to a target device before `post_func`, default to "cpu". - post_func: postprocessing for the inverted MetaTensor, should be a callable function. + device: move the inverted results to a target device before `post_func`, default to `None`. + post_func: postprocessing for the inverted result, should be a callable function. + to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. """ if not isinstance(transform, InvertibleTransform): raise ValueError("transform is not invertible, can't invert transform for the data.") @@ -805,6 +808,8 @@ def __init__( self.nearest_interp = nearest_interp self.device = device self.post_func = post_func + self.to_tensor = to_tensor + self._totensor = ToTensor() def __call__(self, data): if not isinstance(data, MetaTensor): @@ -817,7 +822,12 @@ def __call__(self, data): data = data.detach() inverted = self.transform.inverse(data) - inverted = self.post_func(inverted.to(self.device)) + if self.to_tensor and not isinstance(inverted, MetaTensor): + inverted = self._totensor(inverted) + if isinstance(inverted, torch.Tensor): + inverted = inverted.to(device=self.device) + if callable(self.post_func): + inverted = self.post_func(inverted) return inverted diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 78d84a0bd1..e883310510 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -547,13 +547,17 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]): class Invertd(MapTransform): """ - Utility transform to automatically invert the previously applied transforms. + Utility transform to invert the previously applied transforms. Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it - to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata - dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``. + to the data stored at ``keys``. - A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``. + ``Invertd``'s output will also include a copy of the metadata + dictionary (originally from ``orig_meta_keys`` or the metadata of ``orig_keys``), + with the relevant fields inverted and stored at ``meta_keys``. + + A typical usage is to apply the inverse of the preprocessing (``transform=preprocessings``) on + input ``orig_keys=image`` to the model predictions ``keys=pred``. A detailed usage example is available in the tutorial: https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py @@ -580,8 +584,8 @@ def __init__( meta_key_postfix: str = DEFAULT_POST_FIX, nearest_interp: Union[bool, Sequence[bool]] = True, to_tensor: Union[bool, Sequence[bool]] = True, - device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu", - post_func: Union[Callable, Sequence[Callable]] = lambda x: x, + device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]], None] = None, + post_func: Union[Callable, Sequence[Callable], None] = None, allow_missing_keys: bool = False, ) -> None: """ @@ -609,7 +613,7 @@ def __init__( to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. It also can be a list of bool, each matches to the `keys` data. device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data. + default to None, it also can be a list of string or `torch.device`, each matches to the `keys` data. post_func: post processing for the inverted data, should be a callable function. It also can be a list of callable, each matches to the `keys` data. allow_missing_keys: don't raise exception if key is missing. @@ -694,16 +698,14 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: inverted = self.transform.inverse(input_dict) # save the inverted data - if to_tensor and not isinstance(inverted[orig_key], MetaTensor): - inverted_data = self._totensor(inverted[orig_key]) - else: - inverted_data = inverted[orig_key] - if isinstance(inverted_data, np.ndarray) and torch.device(device).type != "cpu": - raise ValueError("Inverted data with type of 'numpy.ndarray' do not support GPU.") - elif isinstance(inverted_data, torch.Tensor): - d[key] = post_func(inverted_data.to(device)) - else: - d[key] = post_func(inverted_data) + inverted_data = inverted[orig_key] + if to_tensor and not isinstance(inverted_data, MetaTensor): + inverted_data = self._totensor(inverted_data) + if isinstance(inverted_data, np.ndarray) and device is not None and torch.device(device).type != "cpu": + raise ValueError(f"Inverted data with type of 'numpy.ndarray' support device='cpu', got {device}.") + if isinstance(inverted_data, torch.Tensor): + inverted_data = inverted_data.to(device=device) + d[key] = post_func(inverted_data) if callable(post_func) else inverted_data # save the invertd applied_operations if it's in the source dict if InvertibleTransform.trace_key(orig_key) in d: d[InvertibleTransform.trace_key(orig_key)] = inverted_data.applied_operations diff --git a/tests/test_invert.py b/tests/test_invert.py index 4bd648c264..410f170b4f 100644 --- a/tests/test_invert.py +++ b/tests/test_invert.py @@ -64,7 +64,9 @@ def test_invert(self): dataset = Dataset(data, transform=transform) self.assertIsInstance(transform.inverse(dataset[0]), MetaTensor) loader = DataLoader(dataset, num_workers=num_workers, batch_size=1) - inverter = Invert(transform=transform, nearest_interp=True, device="cpu") + inverter = Invert( + transform=transform, nearest_interp=True, device="cpu", post_func=lambda x: torch.as_tensor(x) + ) for d in loader: d = decollate_batch(d) diff --git a/tests/test_invertd.py b/tests/test_invertd.py index afa7958e9a..37db4f66c6 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -77,7 +77,8 @@ def test_invert(self): transform=transform, orig_keys=["label", "label"], nearest_interp=True, - device="cpu", + device=None, + post_func=lambda x: torch.as_tensor(x), ) inverter_1 = Invertd(