diff --git a/docs/source/data.rst b/docs/source/data.rst index e76eb53f39..0910001783 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -271,3 +271,14 @@ ThreadDataLoader TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.TestTimeAugmentation + + +Meta Object +----------- +.. automodule:: monai.data.meta_obj + :members: + +MetaTensor +---------- +.. autoclass:: monai.data.MetaTensor + :members: diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 53aa3d3f46..cdab2a1037 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -46,6 +46,8 @@ resolve_writer, ) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer +from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms +from .meta_tensor import MetaTensor from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py new file mode 100644 index 0000000000..d60ec6e473 --- /dev/null +++ b/monai/data/meta_obj.py @@ -0,0 +1,207 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Callable, Sequence + +_TRACK_META = True +_TRACK_TRANSFORMS = True + +__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"] + + +def set_track_meta(val: bool) -> None: + """ + Boolean to set whether metadata is tracked. If `True`, metadata will be associated + its data by using subclasses of `MetaObj`. If `False`, then data will be returned + with empty metadata. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding metadata, and aren't interested in + preserving metadata, then you can disable it. + """ + global _TRACK_META + _TRACK_META = val + + +def set_track_transforms(val: bool) -> None: + """ + Boolean to set whether transforms are tracked. If `True`, applied transforms will be + associated its data by using subclasses of `MetaObj`. If `False`, then transforms + won't be tracked. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding transforms, and aren't interested in + preserving transforms, then you can disable it. + """ + global _TRACK_TRANSFORMS + _TRACK_TRANSFORMS = val + + +def get_track_meta() -> bool: + """ + Return the boolean as to whether metadata is tracked. If `True`, metadata will be + associated its data by using subclasses of `MetaObj`. If `False`, then data will be + returned with empty metadata. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding metadata, and aren't interested in + preserving metadata, then you can disable it. + """ + return _TRACK_META + + +def get_track_transforms() -> bool: + """ + Return the boolean as to whether transforms are tracked. If `True`, applied + transforms will be associated its data by using subclasses of `MetaObj`. If `False`, + then transforms won't be tracked. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding transforms, and aren't interested in + preserving transforms, then you can disable it. + """ + return _TRACK_TRANSFORMS + + +class MetaObj: + """ + Abstract base class that stores data as well as any extra metadata. + + This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple + inheritance. + + Metadata is stored in the form of a dictionary. + + Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) + aside from the extended meta functionality. + + Copying of information: + + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the + first instance of `MetaObj`. + + """ + + _meta: dict + + @staticmethod + def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: + """ + Recursively flatten input and return all instances of `MetaObj` as a single + list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and + their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type + `MetaObj`. + + Args: + args: Sequence of inputs to be flattened. + Returns: + list of nested `MetaObj` from input. + """ + out = [] + for a in args: + if isinstance(a, (list, tuple)): + out += MetaObj.flatten_meta_objs(a) + elif isinstance(a, MetaObj): + out.append(a) + return out + + def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: + """ + Copy an attribute from the first in a list of `MetaObj`. In the case of + `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so + check them all. Copy the first to `self`. + + We also perform a deep copy of the data if desired. + + Args: + attribute: string corresponding to attribute to be copied (e.g., `meta`). + input_objs: List of `MetaObj`. We'll copy the attribute from the first one + that contains that particular attribute. + default_fn: If none of `input_objs` have the attribute that we're + interested in, then use this default function (e.g., `lambda: {}`.) + deep_copy: Should the attribute be deep copied? See `_copy_meta`. + + Returns: + Returns `None`, but `self` should be updated to have the copied attribute. + """ + attributes = [getattr(i, attribute) for i in input_objs] + if len(attributes) > 0: + val = attributes[0] + if deep_copy: + val = deepcopy(val) + setattr(self, attribute, val) + else: + setattr(self, attribute, default_fn()) + + def _copy_meta(self, input_objs: list[MetaObj]) -> None: + """ + Copy metadata from a list of `MetaObj`. For a given attribute, we copy the + adjunct data from the first element in the list containing that attribute. + + If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., + `a+=1`), then don't. + + Args: + input_objs: list of `MetaObj` to copy data from. + + """ + id_in = id(input_objs[0]) if len(input_objs) > 0 else None + deep_copy = id(self) != id_in + self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) + + def get_default_meta(self) -> dict: + """Get the default meta. + + Returns: + default metadata. + """ + return {} + + def __repr__(self) -> str: + """String representation of class.""" + out: str = super().__repr__() + + out += "\nMetaData\n" + if self.meta is not None: + out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) + else: + out += "None" + + return out + + @property + def meta(self) -> dict: + """Get the meta.""" + return self._meta + + @meta.setter + def meta(self, d: dict) -> None: + """Set the meta.""" + self._meta = d diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py new file mode 100644 index 0000000000..c5b95f8d08 --- /dev/null +++ b/monai/data/meta_tensor.py @@ -0,0 +1,148 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings +from typing import Callable + +import torch + +from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.utils.enums import PostFix + +__all__ = ["MetaTensor"] + + +class MetaTensor(MetaObj, torch.Tensor): + """ + Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata. + + Metadata is stored in the form of a dictionary. Nested, an affine matrix will be + stored. This should be in the form of `torch.Tensor`. + + Behavior should be the same as `torch.Tensor` aside from the extended + meta functionality. + + Copying of information: + + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the + first instance of `MetaTensor`. + + Example: + .. code-block:: python + + import torch + from monai.data import MetaTensor + + t = torch.tensor([1,2,3]) + affine = torch.eye(4) * 100 + meta = {"some": "info"} + m = MetaTensor(t, affine=affine, meta=meta) + m2 = m+m + assert isinstance(m2, MetaTensor) + assert m2.meta["some"] == "info" + assert m2.affine == affine + + Notes: + - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may + not work if `im` is of type `MetaTensor`. This can be resolved with + `torch.jit.trace(net, im.as_tensor())`. + - A warning will be raised if in the constructor `affine` is not `None` and + `meta` already contains the key `affine`. + """ + + @staticmethod + def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: + """ + If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. + Else, use the default value. Similar for the affine, except this could come from + four places. + Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. + """ + out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore + # set meta + if meta is not None: + out.meta = meta + elif isinstance(x, MetaObj): + out.meta = x.meta + else: + out.meta = out.get_default_meta() + # set the affine + if affine is not None: + if "affine" in out.meta: + warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.") + out.affine = affine + elif "affine" in out.meta: + pass # nothing to do + elif isinstance(x, MetaTensor): + out.affine = x.affine + else: + out.affine = out.get_default_affine() + out.affine = out.affine.to(out.device) + + return out + + def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: + super()._copy_attr(attribute, input_objs, default_fn, deep_copy) + val = getattr(self, attribute) + if isinstance(val, torch.Tensor): + setattr(self, attribute, val.to(self.device)) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + """Wraps all torch functions.""" + if kwargs is None: + kwargs = {} + ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + # e.g., __repr__ returns a string + if not isinstance(ret, torch.Tensor): + return ret + if not (get_track_meta() or get_track_transforms()): + return ret.as_tensor() + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + ret.affine = ret.affine.to(ret.device) + return ret + + def get_default_affine(self) -> torch.Tensor: + return torch.eye(4, device=self.device) + + def as_tensor(self) -> torch.Tensor: + """ + Return the `MetaTensor` as a `torch.Tensor`. + It is OS dependent as to whether this will be a deep copy or not. + """ + return self.as_subclass(torch.Tensor) # type: ignore + + def as_dict(self, key: str) -> dict: + """ + Get the object as a dictionary for backwards compatibility. + + Args: + key: Base key to store main data. The key for the metadata will be + determined using `PostFix.meta`. + + Return: + A dictionary consisting of two keys, the main data (stored under `key`) and + the metadata. + """ + return {key: self.as_tensor(), PostFix.meta(key): self.meta} + + @property + def affine(self) -> torch.Tensor: + """Get the affine.""" + return self.meta["affine"] # type: ignore + + @affine.setter + def affine(self, d: torch.Tensor) -> None: + """Set the affine.""" + self.meta["affine"] = d diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py new file mode 100644 index 0000000000..1721e7d2b9 --- /dev/null +++ b/tests/test_meta_tensor.py @@ -0,0 +1,275 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +import string +import tempfile +import unittest +import warnings +from copy import deepcopy +from typing import Optional, Union + +import torch +from parameterized import parameterized + +from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms +from monai.data.meta_tensor import MetaTensor +from monai.utils.enums import PostFix +from monai.utils.module import get_torch_version_tuple +from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda + +PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple() + +DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] +TESTS = [] +for _device in TEST_DEVICES: + for _dtype in DTYPES: + TESTS.append((*_device, *_dtype)) + + +def rand_string(min_len=5, max_len=10): + str_size = random.randint(min_len, max_len) + chars = string.ascii_letters + string.punctuation + return "".join(random.choice(chars) for _ in range(str_size)) + + +class TestMetaTensor(unittest.TestCase): + @staticmethod + def get_im(shape=None, dtype=None, device=None): + if shape is None: + shape = shape = (1, 10, 8) + affine = torch.randint(0, 10, (4, 4)) + meta = {"fname": rand_string()} + t = torch.rand(shape) + if dtype is not None: + t = t.to(dtype) + if device is not None: + t = t.to(device) + m = MetaTensor(t.clone(), affine, meta) + return m, t + + def check_ids(self, a, b, should_match): + comp = self.assertEqual if should_match else self.assertNotEqual + comp(id(a), id(b)) + + def check( + self, + out: torch.Tensor, + orig: torch.Tensor, + *, + shape: bool = True, + vals: bool = True, + ids: bool = True, + device: Optional[Union[str, torch.device]] = None, + meta: bool = True, + check_ids: bool = True, + **kwargs, + ): + if device is None: + device = orig.device + + # check the image + self.assertIsInstance(out, type(orig)) + if shape: + assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape)) + if vals: + assert_allclose(out, orig, **kwargs) + if check_ids: + self.check_ids(out, orig, ids) + self.assertTrue(str(device) in str(out.device)) + + # check meta and affine are equal and affine is on correct device + if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: + orig_meta_no_affine = deepcopy(orig.meta) + del orig_meta_no_affine["affine"] + out_meta_no_affine = deepcopy(out.meta) + del out_meta_no_affine["affine"] + self.assertEqual(orig_meta_no_affine, out_meta_no_affine) + assert_allclose(out.affine, orig.affine) + self.assertTrue(str(device) in str(out.affine.device)) + if check_ids: + self.check_ids(out.affine, orig.affine, ids) + self.check_ids(out.meta, orig.meta, ids) + + @parameterized.expand(TESTS) + def test_as_tensor(self, device, dtype): + m, t = self.get_im(device=device, dtype=dtype) + t2 = m.as_tensor() + self.assertIsInstance(t2, torch.Tensor) + self.assertNotIsInstance(t2, MetaTensor) + self.assertIsInstance(m, MetaTensor) + self.check(t, t2, ids=False) + + def test_as_dict(self): + m, _ = self.get_im() + m_dict = m.as_dict("im") + im, meta = m_dict["im"], m_dict[PostFix.meta("im")] + affine = meta.pop("affine") + m2 = MetaTensor(im, affine, meta) + self.check(m2, m, check_ids=False) + + @parameterized.expand(TESTS) + def test_constructor(self, device, dtype): + m, t = self.get_im(device=device, dtype=dtype) + # construct from pre-existing + m1 = MetaTensor(m.clone()) + self.check(m, m1, ids=False, meta=False) + # meta already has affine + m2 = MetaTensor(t.clone(), meta=m.meta) + self.check(m, m2, ids=False, meta=False) + # meta dosen't have affine + affine = m.meta.pop("affine") + m3 = MetaTensor(t.clone(), affine=affine, meta=m.meta) + self.check(m, m3, ids=False, meta=False) + + @parameterized.expand(TESTS) + @skip_if_no_cuda + def test_to_cuda(self, device, dtype): + """Test `to`, `cpu` and `cuda`. For `to`, check args and kwargs.""" + orig, _ = self.get_im(device=device, dtype=dtype) + m = orig.clone() + m = m.to("cuda") + self.check(m, orig, ids=False, device="cuda") + m = m.cpu() + self.check(m, orig, ids=False, device="cpu") + m = m.cuda() + self.check(m, orig, ids=False, device="cuda") + m = m.to("cpu") + self.check(m, orig, ids=False, device="cpu") + m = m.to(device="cuda") + self.check(m, orig, ids=False, device="cuda") + m = m.to(device="cpu") + self.check(m, orig, ids=False, device="cpu") + + @skip_if_no_cuda + def test_affine_device(self): + m, _ = self.get_im() # device="cuda") + m.affine = torch.eye(4) + self.assertEqual(m.device, m.affine.device) + + @parameterized.expand(TESTS) + def test_copy(self, device, dtype): + m, _ = self.get_im(device=device, dtype=dtype) + # shallow copy + a = m + self.check(a, m, ids=True) + # deepcopy + a = deepcopy(m) + self.check(a, m, ids=False) + # clone + a = m.clone() + self.check(a, m, ids=False) + + @parameterized.expand(TESTS) + def test_add(self, device, dtype): + m1, t1 = self.get_im(device=device, dtype=dtype) + m2, t2 = self.get_im(device=device, dtype=dtype) + self.check(m1 + m2, t1 + t2, ids=False) + self.check(torch.add(m1, m2), t1 + t2, ids=False) + self.check(torch.add(input=m1, other=m2), t1 + t2, ids=False) + self.check(torch.add(m1, other=m2), t1 + t2, ids=False) + m3 = deepcopy(m2) + t3 = deepcopy(t2) + m3 += 3 + t3 += 3 + self.check(m3, t3, ids=False) + # check torch.Tensor+MetaTensor and MetaTensor+torch.Tensor + self.check(torch.add(m1, t2), t1 + t2, ids=False) + self.check(torch.add(t2, m1), t1 + t2, ids=False) + + @parameterized.expand(TEST_DEVICES) + def test_conv(self, device): + im, _ = self.get_im((1, 3, 10, 8, 12), device=device) + conv = torch.nn.Conv3d(im.shape[1], 5, 3) + conv.to(device) + out = conv(im) + self.check(out, im, shape=False, vals=False, ids=False) + + @parameterized.expand(TESTS) + def test_stack(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + stacked = torch.stack(ims) + self.assertIsInstance(stacked, MetaTensor) + orig_affine = ims[0].meta.pop("affine") + stacked_affine = stacked.meta.pop("affine") + assert_allclose(orig_affine, stacked_affine) + self.assertEqual(stacked.meta, ims[0].meta) + + def test_get_set_meta_fns(self): + set_track_meta(False) + self.assertEqual(get_track_meta(), False) + set_track_meta(True) + self.assertEqual(get_track_meta(), True) + set_track_transforms(False) + self.assertEqual(get_track_transforms(), False) + set_track_transforms(True) + self.assertEqual(get_track_transforms(), True) + + @parameterized.expand(TEST_DEVICES) + def test_torchscript(self, device): + shape = (1, 3, 10, 8) + im, _ = self.get_im(shape, device=device) + conv = torch.nn.Conv2d(im.shape[1], 5, 3) + conv.to(device) + im_conv = conv(im) + traced_fn = torch.jit.trace(conv, im.as_tensor()) + # save it, load it, use it + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "im.pt") + torch.jit.save(traced_fn, f=fname) + traced_fn = torch.jit.load(fname) + out = traced_fn(im) + self.assertIsInstance(out, torch.Tensor) + if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 9: + warnings.warn( + "When calling `nn.Module(MetaTensor) on a module traced with " + "`torch.jit.trace`, your version of pytorch returns a " + "`torch.Tensor` instead of a `MetaTensor`. Consider upgrading " + "your pytorch version if this is important to you." + ) + im_conv = im_conv.as_tensor() + self.check(out, im_conv, ids=False) + + def test_pickling(self): + m, _ = self.get_im() + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "im.pt") + torch.save(m, fname) + m2 = torch.load(fname) + if not isinstance(m2, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7: + warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.") + m = m.as_tensor() + self.check(m2, m, ids=False) + + @skip_if_no_cuda + def test_amp(self): + shape = (1, 3, 10, 8) + device = "cuda" + im, _ = self.get_im(shape, device=device) + conv = torch.nn.Conv2d(im.shape[1], 5, 3) + conv.to(device) + im_conv = conv(im) + with torch.cuda.amp.autocast(): + im_conv2 = conv(im) + self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) + + # TODO + # collate + # decollate + # dataset + # dataloader + # matplotlib + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 3065f9b3df..c4f9bd1b70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -102,8 +102,8 @@ def assert_allclose( if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): if device_test: np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore - actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual - desired = desired.cpu().numpy() if isinstance(desired, torch.Tensor) else desired + actual = actual.detach().cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.detach().cpu().numpy() if isinstance(desired, torch.Tensor) else desired np.testing.assert_allclose(actual, desired, *args, **kwargs) @@ -715,5 +715,10 @@ def query_memory(n=2): TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore +TEST_DEVICES = [[torch.device("cpu")]] +if torch.cuda.is_available(): + TEST_DEVICES.append([torch.device("cuda")]) + + if __name__ == "__main__": print(query_memory())