Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 36 additions & 42 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,28 +517,25 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]):
class Invertd(MapTransform):
"""
Utility transform to automatically invert the previously applied transforms.
When applying preprocessing transforms on a orig_key(like: `image`, `label`, etc.), we record the context
information of applied transforms in a dictionary in the input data dictionary with the key
"{orig_key}_transforms". This transform will extract the transform context information of `orig_keys`
then invert the transforms(got from this context information) on the `keys` data.
Typical usage is to invert the preprocessing transforms(applied on input `image`) on the model `pred` data.

The output of the inverted data and metadata will be stored at `keys` and `meta_keys` respectively.
To correctly invert the transforms, the information of the previously applied transforms should be
available at `orig_keys`, and the original metadata at `orig_meta_keys`.
(`meta_key_postfix` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".)
Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it
to the data stored at ``keys``. ``Invertd``'s output will also include a copy of the metadata
dictionary (originally from ``orig_meta_keys``), with the relevant fields inverted and stored at ``meta_keys``.

A typical usage is to apply the inverse of the preprocessing on input ``image`` to the model ``pred``.

A detailed usage example is available in the tutorial:
https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py

Note:
According to the `collate_fn`, this transform may return a list of Tensor without batch dim,
thus some following transforms may not support a list of Tensor, and users can leverage the
`post_func` arg for basic processing logic.

This transform needs to extract the context information of applied transforms and the meta data
dictionary from the input data dictionary, then use some numpy arrays in them to computes the inverse
logic, so please don't move `data["{orig_key}_transforms"]` and `data["{orig_meta_key}"]` to GPU device.
- The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively.
- To correctly invert the transforms, the information of the previously applied transforms should be
available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``.
(``meta_key_postfix`` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".)
see also: :py:class:`monai.transforms.TraceableTransform`.
- The transform will not change the content in ``orig_keys`` and ``orig_meta_key``.
These keys are only used to represent the data status of ``key`` before inverting.

"""

Expand All @@ -558,37 +555,32 @@ def __init__(
) -> None:
"""
Args:
keys: the key of expected data in the dict, invert transforms on it, in-place operation.
it also can be a list of keys, will invert transform for each of them, like: ["pred", "pred_class2"].
transform: the previous callable transform that applied on input data.
orig_keys: the key of the original input data in the dict. will get the applied transform information
for this input data, then invert them for the expected data with `keys`.
It can also be a list of keys, each matches to the `keys` data.
meta_keys: explicitly indicate the key for the inverted meta data dictionary.
the meta data is a dictionary object which contains: filename, original_shape, etc.
it can be a sequence of string, map to the `keys`.
if None, will try to construct meta_keys by `{key}_{meta_key_postfix}`.
orig_meta_keys: the key of the meta data of original input data, will get the `affine`, `data_shape`, etc.
the meta data is a dictionary object which contains: filename, original_shape, etc.
it can be a sequence of string, map to the `keys`.
if None, will try to construct meta_keys by `{orig_key}_{meta_key_postfix}`.
meta data will also be inverted and stored in `meta_keys`.
keys: the key of expected data in the dict, the inverse of ``transforms`` will be applied on it in-place.
It also can be a list of keys, will apply the inverse transform respectively.
transform: the transform applied to ``orig_key``, its inverse will be applied on ``key``.
orig_keys: the key of the original input data in the dict.
the transform trace information of ``transforms`` should be stored at ``{orig_keys}_transforms``.
It can also be a list of keys, each matches the ``keys``.
meta_keys: The key to output the inverted meta data dictionary.
The meta data is a dictionary optionally containing: filename, original_shape.
It can be a sequence of strings, maps to ``keys``.
If None, will try to create a meta data dict with the default key: `{key}_{meta_key_postfix}`.
orig_meta_keys: the key of the meta data of original input data.
The meta data is a dictionary optionally containing: filename, original_shape.
It can be a sequence of strings, maps to the `keys`.
If None, will try to create a meta data dict with the default key: `{orig_key}_{meta_key_postfix}`.
This meta data dict will also be included in the inverted dict, stored in `meta_keys`.
meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the
meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`.
default is `meta_dict`, the meta data is a dictionary object.
Comment thread
wyli marked this conversation as resolved.
For example, to handle orig_key `image`, read/write `affine` matrices from the
metadata `image_meta_dict` dictionary's `affine` field.
the inverted meta dict will be stored with key: "{key}_{meta_key_postfix}".
meta data from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``"meta_dict"``.
nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
default to `True`. If `False`, use the same interpolation mode as the original transform.
it also can be a list of bool, each matches to the `keys` data.
It also can be a list of bool, each matches to the `keys` data.
to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
it also can be a list of bool, each matches to the `keys` data.
It also can be a list of bool, each matches to the `keys` data.
device: if converted to Tensor, move the inverted results to target device before `post_func`,
default to "cpu", it also can be a list of string or `torch.device`,
each matches to the `keys` data.
default to "cpu", it also can be a list of string or `torch.device`, each matches to the `keys` data.
post_func: post processing for the inverted data, should be a callable function.
it also can be a list of callable, each matches to the `keys` data.
It also can be a list of callable, each matches to the `keys` data.
allow_missing_keys: don't raise exception if key is missing.

"""
Expand Down Expand Up @@ -645,10 +637,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
input = d[key]
if isinstance(input, torch.Tensor):
input = input.detach()
# construct the input dict data for BatchInverseTransform

# construct the input dict data
input_dict = {orig_key: input, transform_key: transform_info}
orig_meta_key = orig_meta_key or f"{orig_key}_{meta_key_postfix}"
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if orig_meta_key in d:
input_dict[orig_meta_key] = d[orig_meta_key]

Expand All @@ -657,8 +649,10 @@ def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:

# save the inverted data
d[key] = post_func(self._totensor(inverted[orig_key]).to(device) if to_tensor else inverted[orig_key])

# save the inverted meta dict
if orig_meta_key in d:
meta_key = meta_key or f"{key}_{meta_key_postfix}"
d[meta_key] = inverted.get(orig_meta_key)

return d
Expand Down
6 changes: 6 additions & 0 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ def __call__(
else:
if self.axcodes is None:
raise ValueError("Incompatible values: axcodes=None and as_closest_canonical=True.")
if sr < len(self.axcodes):
warnings.warn(
f"axcodes ('{self.axcodes}') length is smaller than the number of input spatial dimensions D={sr}.\n"
f"{self.__class__.__name__}: input spatial shape is {spatial_shape}, num. channels is {data_array.shape[0]},"
"please make sure the input is in the channel-first format."
)
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
if len(dst) < sr:
raise ValueError(
Expand Down