meta tensor #4077
Merged
Changes from all commits (33 commits):
e15ec97  meta tensor (rijobro)
20f1b47  Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR (rijobro)
accca59  fixes (rijobro)
a15d462  Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR (rijobro)
36a1e75  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
f71d60f  deep_copy (rijobro)
51e6b3b  fixes (rijobro)
1e9f370  torchscript (rijobro)
76e0cce  test pickling torchscript amp (rijobro)
de02d7f  fxies (rijobro)
8a60df1  test fixes (rijobro)
70bbcdb  fixes (rijobro)
eabd3bf  typos (rijobro)
4af4d50  fixes (rijobro)
4a2c211  fix test (rijobro)
6373711  fix (rijobro)
07f5117  fix? (rijobro)
723f475  affine lives inside meta (rijobro)
b25440c  move affine in meta (rijobro)
3e2eb02  flake8 (rijobro)
b6a1f56  Merge remote-tracking branch 'MONAI/dev' into MetaTensor_1st_PR (rijobro)
32cc5b3  fixes (rijobro)
7fe23fd  fixes (rijobro)
7f43b00  fixes (rijobro)
840e7df  pytorch min version 1.7 (rijobro)
e2256da  Revert "pytorch min version 1.7" (rijobro)
79ce908  Merge remote-tracking branch 'MONAI/dev' into MetaTensor_transforms (rijobro)
e3e567c  torch 1.9 (rijobro)
87b572e  Merge branch 'dev' into MetaTensor_1st_PR (wyli)
56bc6d5  remove __init__, correct docstring (rijobro)
09eb602  adds docs (wyli)
7064843  Merge pull request #3 from wyli/adds-docs (rijobro)
81404ed  Merge branch 'dev' into MetaTensor_1st_PR (rijobro)
First new file (defines MetaObj; imported as monai.data.meta_obj):

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from copy import deepcopy
from typing import Any, Callable, Sequence

_TRACK_META = True
_TRACK_TRANSFORMS = True

__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"]


def set_track_meta(val: bool) -> None:
    """
    Set whether metadata is tracked. If `True`, metadata will be associated with its
    data via subclasses of `MetaObj`. If `False`, data will be returned with empty
    metadata.

    If both `set_track_meta` and `set_track_transforms` are set to `False`, then
    standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)
    as opposed to our enhanced objects.

    By default, this is `True`, and most users will want to leave it this way.
    However, if you are experiencing any problems regarding metadata, and aren't
    interested in preserving metadata, then you can disable it.
    """
    global _TRACK_META
    _TRACK_META = val


def set_track_transforms(val: bool) -> None:
    """
    Set whether transforms are tracked. If `True`, applied transforms will be
    associated with their data via subclasses of `MetaObj`. If `False`, transforms
    won't be tracked.

    If both `set_track_meta` and `set_track_transforms` are set to `False`, then
    standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)
    as opposed to our enhanced objects.

    By default, this is `True`, and most users will want to leave it this way.
    However, if you are experiencing any problems regarding transforms, and aren't
    interested in preserving transforms, then you can disable it.
    """
    global _TRACK_TRANSFORMS
    _TRACK_TRANSFORMS = val


def get_track_meta() -> bool:
    """
    Return whether metadata is tracked. If `True`, metadata will be associated with
    its data via subclasses of `MetaObj`. If `False`, data will be returned with
    empty metadata.

    If both `set_track_meta` and `set_track_transforms` are set to `False`, then
    standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)
    as opposed to our enhanced objects.

    By default, this is `True`, and most users will want to leave it this way.
    However, if you are experiencing any problems regarding metadata, and aren't
    interested in preserving metadata, then you can disable it.
    """
    return _TRACK_META


def get_track_transforms() -> bool:
    """
    Return whether transforms are tracked. If `True`, applied transforms will be
    associated with their data via subclasses of `MetaObj`. If `False`, transforms
    won't be tracked.

    If both `set_track_meta` and `set_track_transforms` are set to `False`, then
    standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`)
    as opposed to our enhanced objects.

    By default, this is `True`, and most users will want to leave it this way.
    However, if you are experiencing any problems regarding transforms, and aren't
    interested in preserving transforms, then you can disable it.
    """
    return _TRACK_TRANSFORMS


class MetaObj:
    """
    Abstract base class that stores data as well as any extra metadata.

    This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple
    inheritance.

    Metadata is stored in the form of a dictionary.

    Behavior should be the same as the extended class (e.g., `torch.Tensor` or
    `np.ndarray`) aside from the extended meta functionality.

    Copying of information:

        * For `c = a + b`, the auxiliary data (e.g., metadata) will be copied from
          the first instance of `MetaObj`.
    """

    _meta: dict

    @staticmethod
    def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]:
        """
        Recursively flatten input and return all instances of `MetaObj` as a single
        list. This means that for both `torch.add(a, b)` and `torch.stack([a, b])`
        (and their numpy equivalents), we return `[a, b]` if both `a` and `b` are of
        type `MetaObj`.

        Args:
            args: sequence of inputs to be flattened.
        Returns:
            list of nested `MetaObj` from input.
        """
        out = []
        for a in args:
            if isinstance(a, (list, tuple)):
                out += MetaObj.flatten_meta_objs(a)
            elif isinstance(a, MetaObj):
                out.append(a)
        return out

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        """
        Copy an attribute from the first in a list of `MetaObj`. In the case of
        `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so
        check them all. Copy the first to `self`.

        We also perform a deep copy of the data if desired.

        Args:
            attribute: string corresponding to the attribute to be copied (e.g., `meta`).
            input_objs: list of `MetaObj`. We'll copy the attribute from the first one
                that contains that particular attribute.
            default_fn: if none of `input_objs` have the attribute that we're
                interested in, then use this default function (e.g., `lambda: {}`).
            deep_copy: should the attribute be deep copied? See `_copy_meta`.

        Returns:
            Returns `None`, but `self` should be updated to have the copied attribute.
        """
        attributes = [getattr(i, attribute) for i in input_objs]
        if len(attributes) > 0:
            val = attributes[0]
            if deep_copy:
                val = deepcopy(val)
            setattr(self, attribute, val)
        else:
            setattr(self, attribute, default_fn())

    def _copy_meta(self, input_objs: list[MetaObj]) -> None:
        """
        Copy metadata from a list of `MetaObj`. For a given attribute, we copy the
        adjunct data from the first element in the list containing that attribute.

        If there has been a change in `id` (e.g., `a = b + c`), then deep copy;
        else (e.g., `a += 1`), don't.

        Args:
            input_objs: list of `MetaObj` to copy data from.
        """
        id_in = id(input_objs[0]) if len(input_objs) > 0 else None
        deep_copy = id(self) != id_in
        self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy)

    def get_default_meta(self) -> dict:
        """Get the default meta.

        Returns:
            default metadata.
        """
        return {}

    def __repr__(self) -> str:
        """String representation of class."""
        out: str = super().__repr__()

        out += "\nMetaData\n"
        if self.meta is not None:
            out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items())
        else:
            out += "None"

        return out

    @property
    def meta(self) -> dict:
        """Get the meta."""
        return self._meta

    @meta.setter
    def meta(self, d: dict) -> None:
        """Set the meta."""
        self._meta = d
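A toy, torch-free sketch of the two behaviors described in the docstrings above: recursive flattening of nested inputs, and copying metadata from the first `MetaObj` found. `SimpleMetaObj` and the standalone `flatten_meta_objs` below are hypothetical stand-ins written for illustration; they are not part of the PR.

```python
from copy import deepcopy


class SimpleMetaObj:
    """Hypothetical stand-in for MetaObj: just holds a meta dict."""

    def __init__(self, meta=None):
        self.meta = meta if meta is not None else {}


def flatten_meta_objs(args):
    """Recursively collect SimpleMetaObj instances from nested lists/tuples,
    mirroring MetaObj.flatten_meta_objs."""
    out = []
    for a in args:
        if isinstance(a, (list, tuple)):
            out += flatten_meta_objs(a)
        elif isinstance(a, SimpleMetaObj):
            out.append(a)
    return out


a = SimpleMetaObj({"some": "info"})
b = SimpleMetaObj({"other": "data"})

# Nested inputs are flattened, as they would be for torch.stack([a, b]):
flat = flatten_meta_objs([1, [a, "x"], (b,)])
print([o.meta for o in flat])  # [{'some': 'info'}, {'other': 'data'}]

# The output of an op deep-copies meta from the first MetaObj found,
# which is what _copy_meta/_copy_attr do when the output is a new object:
out = SimpleMetaObj(deepcopy(flat[0].meta) if flat else {})
print(out.meta)            # {'some': 'info'}
print(out.meta is a.meta)  # False: deep-copied, not shared
```

The real `_copy_meta` additionally skips the deep copy for in-place ops (e.g., `a += 1`), where the output object has the same `id` as the first input.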
Second new file (defines MetaTensor):

# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from typing import Callable

import torch

from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms
from monai.utils.enums import PostFix

__all__ = ["MetaTensor"]


class MetaTensor(MetaObj, torch.Tensor):
    """
    Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for
    metadata.

    Metadata is stored in the form of a dictionary. Nested within it, an affine
    matrix will be stored, in the form of a `torch.Tensor`.

    Behavior should be the same as `torch.Tensor` aside from the extended meta
    functionality.

    Copying of information:

        * For `c = a + b`, the auxiliary data (e.g., metadata) will be copied from
          the first instance of `MetaTensor`.

    Example:
        .. code-block:: python

            import torch
            from monai.data import MetaTensor

            t = torch.tensor([1, 2, 3])
            affine = torch.eye(4) * 100
            meta = {"some": "info"}
            m = MetaTensor(t, affine=affine, meta=meta)
            m2 = m + m
            assert isinstance(m2, MetaTensor)
            assert m2.meta["some"] == "info"
            assert torch.all(m2.affine == affine)

    Notes:
        - With older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may
          not work if `im` is of type `MetaTensor`. This can be resolved with
          `torch.jit.trace(net, im.as_tensor())`.
        - A warning will be raised if in the constructor `affine` is not `None` and
          `meta` already contains the key `affine`.
    """

    @staticmethod
    def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor:
        """
        If `meta` is given, use it. Else, if `meta` exists in the input tensor, use
        it. Else, use the default value. Similar for the affine, except this could
        come from four places.
        Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`.
        """
        out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls)  # type: ignore
        # set meta
        if meta is not None:
            out.meta = meta
        elif isinstance(x, MetaObj):
            out.meta = x.meta
        else:
            out.meta = out.get_default_meta()
        # set the affine
        if affine is not None:
            if "affine" in out.meta:
                warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
            out.affine = affine
        elif "affine" in out.meta:
            pass  # nothing to do
        elif isinstance(x, MetaTensor):
            out.affine = x.affine
        else:
            out.affine = out.get_default_affine()
        out.affine = out.affine.to(out.device)

        return out

    def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None:
        super()._copy_attr(attribute, input_objs, default_fn, deep_copy)
        val = getattr(self, attribute)
        if isinstance(val, torch.Tensor):
            setattr(self, attribute, val.to(self.device))

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor:
        """Wraps all torch functions."""
        if kwargs is None:
            kwargs = {}
        ret: MetaTensor = super().__torch_function__(func, types, args, kwargs)
        # e.g., __repr__ returns a string
        if not isinstance(ret, torch.Tensor):
            return ret
        if not (get_track_meta() or get_track_transforms()):
            return ret.as_tensor()
        meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
        ret._copy_meta(meta_args)
        ret.affine = ret.affine.to(ret.device)
        return ret

    def get_default_affine(self) -> torch.Tensor:
        return torch.eye(4, device=self.device)

    def as_tensor(self) -> torch.Tensor:
        """
        Return the `MetaTensor` as a `torch.Tensor`.
        It is OS dependent as to whether this will be a deep copy or not.
        """
        return self.as_subclass(torch.Tensor)  # type: ignore

    def as_dict(self, key: str) -> dict:
        """
        Get the object as a dictionary for backwards compatibility.

        Args:
            key: base key to store the main data. The key for the metadata will be
                determined using `PostFix.meta`.

        Returns:
            A dictionary consisting of two keys, the main data (stored under `key`)
            and the metadata.
        """
        return {key: self.as_tensor(), PostFix.meta(key): self.meta}

    @property
    def affine(self) -> torch.Tensor:
        """Get the affine."""
        return self.meta["affine"]  # type: ignore

    @affine.setter
    def affine(self, d: torch.Tensor) -> None:
        """Set the affine."""
        self.meta["affine"] = d
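The four-way affine priority implemented in `MetaTensor.__new__` (the `affine` argument, then `meta["affine"]`, then the input's affine, then the default) can be sketched without torch. `resolve_affine` is a hypothetical helper written here for illustration, with plain strings standing in for affine tensors:

```python
import warnings


def resolve_affine(affine, meta, x_affine, default):
    """Hypothetical helper mirroring MetaTensor.__new__'s priority:
    affine arg > meta['affine'] > x.affine > default."""
    if affine is not None:
        if "affine" in meta:
            # Same situation the real constructor warns about.
            warnings.warn("Setting affine, but the applied meta contains an affine. This will be overwritten.")
        return affine
    if "affine" in meta:
        return meta["affine"]
    if x_affine is not None:
        return x_affine
    return default


# Explicit argument wins (and triggers the overwrite warning):
print(resolve_affine("arg", {"affine": "meta"}, "x", "eye4"))  # arg
# Then the meta dict:
print(resolve_affine(None, {"affine": "meta"}, "x", "eye4"))   # meta
# Then the input tensor's affine:
print(resolve_affine(None, {}, "x", "eye4"))                   # x
# Finally the default (a 4x4 identity in the real class):
print(resolve_affine(None, {}, None, "eye4"))                  # eye4
```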