diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index d44fb6b3fa..b89196207b 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -517,28 +517,25 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]): class Invertd(MapTransform): """ Utility transform to automatically invert the previously applied transforms. - When applying preprocessing transforms on a orig_key(like: `image`, `label`, etc.), we record the context - information of applied transforms in a dictionary in the input data dictionary with the key - "{orig_key}_transforms". This transform will extract the transform context information of `orig_keys` - then invert the transforms(got from this context information) on the `keys` data. - Typical usage is to invert the preprocessing transforms(applied on input `image`) on the model `pred` data. - The output of the inverted data and metadata will be stored at `keys` and `meta_keys` respectively. - To correctly invert the transforms, the information of the previously applied transforms should be - available at `orig_keys`, and the original metadata at `orig_meta_keys`. - (`meta_key_postfix` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it + to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata + dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``. + + A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``. 
A detailed usage example is available in the tutorial: https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py Note: - According to the `collate_fn`, this transform may return a list of Tensor without batch dim, - thus some following transforms may not support a list of Tensor, and users can leverage the - `post_func` arg for basic processing logic. - This transform needs to extract the context information of applied transforms and the meta data - dictionary from the input data dictionary, then use some numpy arrays in them to computes the inverse - logic, so please don't move `data["{orig_key}_transforms"]` and `data["{orig_meta_key}"]` to GPU device. + - The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively. + - To correctly invert the transforms, the information of the previously applied transforms should be + available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``. + (``meta_key_postfix`` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".) + see also: :py:class:`monai.transforms.TraceableTransform`. + - The transform will not change the content in ``orig_keys`` and ``orig_meta_keys``. + These keys are only used to represent the data status of ``keys`` before inverting. """ @@ -558,37 +555,32 @@ def __init__( ) -> None: """ Args: - keys: the key of expected data in the dict, invert transforms on it, in-place operation. - it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"]. - transform: the previous callable transform that applied on input data. - orig_keys: the key of the original input data in the dict. will get the applied transform information - for this input data, then invert them for the expected data with `keys`. - It can also be a list of keys, each matches to the `keys` data. 
- meta_keys: explicitly indicate the key for the inverted meta data dictionary. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`. - orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc. - the meta data is a dictionary object which contains: filename, original_shape, etc. - it can be a sequence of string, map to the `keys`. - if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`. - meta data will also be inverted and stored in `meta_keys`. + keys: the key of expected data in the dict, the inverse of ``transform`` will be applied on it in-place. + It also can be a list of keys, will apply the inverse transform respectively. + transform: the transform applied to ``orig_keys``, its inverse will be applied on ``keys``. + orig_keys: the key of the original input data in the dict. + the transform trace information of ``transform`` should be stored at ``{orig_keys}_transforms``. + It can also be a list of keys, each matches the ``keys``. + meta_keys: The key to output the inverted meta data dictionary. + The meta data is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, maps to ``keys``. + If None, will try to create a meta data dict with the default key: `{key}_{meta_key_postfix}`. + orig_meta_keys: the key of the meta data of original input data. + The meta data is a dictionary optionally containing: filename, original_shape. + It can be a sequence of strings, maps to the `keys`. + If None, will try to create a meta data dict with the default key: `{orig_key}_{meta_key_postfix}`. + This meta data dict will also be included in the inverted dict, stored in `meta_keys`. 
meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the - meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. - default is `meta_dict`, the meta data is a dictionary object. - For example, to handle orig_key `image`, read/write `affine` matrices from the - metadata `image_meta_dict` dictionary's `affine` field. - the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}". + meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``"meta_dict"``. nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms, default to `True`. If `False`, use the same interpolation mode as the original transform. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`. - it also can be a list of bool, each matches to the `keys` data. + It also can be a list of bool, each matches to the `keys` data. device: if converted to Tensor, move the inverted results to target device before `post_func`, - default to "cpu", it also can be a list of string or `torch.device`, - each matches to the `keys` data. + default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data. post_func: post processing for the inverted data, should be a callable function. - it also can be a list of callable, each matches to the `keys` data. + It also can be a list of callable, each matches to the `keys` data. allow_missing_keys: don't raise exception if key is missing. 
""" @@ -645,10 +637,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: input = d[key] if isinstance(input, torch.Tensor): input = input.detach() - # construct the input dict data for BatchInverseTransform + + # construct the input dict data input_dict = {orig_key: input, transform_key: transform_info} orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}" - meta_key = meta_key or f"{key}_{meta_key_postfix}" if orig_meta_key in d: input_dict[orig_meta_key] = d[orig_meta_key] @@ -657,8 +649,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]: # save the inverted data d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key]) + # save the inverted meta dict if orig_meta_key in d: + meta_key = meta_key or f"{key}_{meta_key_postfix}" d[meta_key] = inverted.get(orig_meta_key) return d diff --git a/monai/transforms/spatial/array.py b/monai/transforms/spatial/array.py index 0f11dc4390..2ed2ae42c7 100644 --- a/monai/transforms/spatial/array.py +++ b/monai/transforms/spatial/array.py @@ -489,6 +489,12 @@ def __call__( else: if self.axcodes is None: raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.") + if sr < len(self.axcodes): + warnings.warn( + f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n" + f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]}," + "please make sure the input is in the channel-first format." + ) dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels) if len(dst) < sr: raise ValueError(