From e15ec97178fc59ef7fe26f1e3c1edd9b674a727c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 5 Apr 2022 18:13:20 +0100 Subject: [PATCH 01/26] meta tensor Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/__init__.py | 2 + monai/data/meta_obj.py | 172 ++++++++++++++++++++++++++++++++++++ monai/data/meta_tensor.py | 74 ++++++++++++++++ tests/test_meta_tensor.py | 178 ++++++++++++++++++++++++++++++++++++++ tests/utils.py | 5 ++ 5 files changed, 431 insertions(+) create mode 100644 monai/data/meta_obj.py create mode 100644 monai/data/meta_tensor.py create mode 100644 tests/test_meta_tensor.py diff --git a/monai/data/__init__.py b/monai/data/__init__.py index bed194d2f4..300b21a57d 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -46,6 +46,8 @@ resolve_writer, ) from .iterable_dataset import CSVIterableDataset, IterableDataset, ShuffleBuffer +from .meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms +from .meta_tensor import MetaTensor from .nifti_saver import NiftiSaver from .nifti_writer import write_nifti from .png_saver import PNGSaver diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py new file mode 100644 index 0000000000..fd95c9737a --- /dev/null +++ b/monai/data/meta_obj.py @@ -0,0 +1,172 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from copy import deepcopy +from typing import Any, Callable, Sequence + +import numpy as np +import torch + +_TRACK_META = True +_TRACK_TRANSFORMS = True + +__all__ = ["get_track_meta", "get_track_transforms", "set_track_meta", "set_track_transforms", "MetaObj"] + + +def set_track_meta(val: bool) -> None: + """ + Boolean to set whether metadata is tracked. If `True`, + `MetaTensor` will be returned where appropriate. If `False`, + `torch.Tensor` will be returned instead. + """ + global _TRACK_META + _TRACK_META = val + + +def set_track_transforms(val: bool) -> None: + """ + Boolean to set whether transforms are tracked. + """ + global _TRACK_TRANSFORMS + _TRACK_TRANSFORMS = val + + +def get_track_meta() -> bool: + """ + Get track meta data boolean. + """ + global _TRACK_META + return _TRACK_META + + +def get_track_transforms() -> bool: + """ + Get track transform boolean. + """ + global _TRACK_TRANSFORMS + return _TRACK_TRANSFORMS + + +class MetaObj: + """ + Class that stores meta and affine. + + We store the affine as its own element, so that this can be updated by + transforms. All other meta data that we don't plan on touching until we + need to save the image to file lives in `meta`. + + This allows for subclassing `np.ndarray` and `torch.Tensor`. + + Copying metadata: + * For `c = a + b`, then the meta data will be copied from the first + instance of `MetaImage`. 
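That copy-from-the-first-input rule is the behaviour the tests later in this patch rely on. As a minimal sketch of what it means for the `MetaTensor` subclass added below (shapes and metadata values here are only illustrative):

    import torch
    from monai.data import MetaTensor

    a = MetaTensor(torch.rand(1, 8, 8), affine=torch.eye(4), meta={"fname": "a.nii"})
    b = MetaTensor(torch.rand(1, 8, 8), affine=torch.eye(4), meta={"fname": "b.nii"})

    c = a + b          # any torch op combining several MetaTensors
    print(type(c))     # MetaTensor
    print(c.meta)      # metadata taken from `a`, the first MetaTensor input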
+ """ + + _meta: dict + _affine: torch.Tensor + + def set_initial_val(self, attribute: str, input_arg: Any, input_tensor: MetaObj, default_fn: Callable) -> None: + """ + Set the initial value. Try to use input argument, but if this is None + and there is a MetaImage input, then copy that. Failing both these two, + use a default value. + """ + if input_arg is None: + input_arg = getattr(input_tensor, attribute, None) + if input_arg is None: + input_arg = default_fn(self) + setattr(self, attribute, input_arg) + + @staticmethod + def get_tensors_or_arrays(args: Sequence[Any]) -> list[MetaObj]: + """ + Recursively extract all instances of `MetaObj`. + Works for `torch.add(a, b)`, `torch.stack([a, b])` and numpy equivalents. + """ + out = [] + for a in args: + if isinstance(a, (list, tuple)): + out += MetaObj.get_tensors_or_arrays(a) + elif isinstance(a, MetaObj): + out.append(a) + return out + + def _copy_attr( + self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool + ) -> None: + """ + Copy an attribute from the first in a list of `MetaObj` + In the cases `torch.add(a, b)` and `torch.add(input=a, other=b)`, + both `a` and `b` could be `MetaObj` or `torch.Tensor` so check + them all. Copy the first to the output, and make sure on correct + device. + Might have the MetaObj nested in list, e.g., `torch.stack([a, b])`. + """ + attributes = [getattr(i, attribute) for i in input_objs] + if len(attributes) > 0: + val = attributes[0] + if deepcopy_required: + val = deepcopy(val) + if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): + val = val.to(self.device) + setattr(self, attribute, val) + else: + setattr(self, attribute, default_fn()) + + def _copy_meta(self, input_meta_objs: list[MetaObj]) -> None: + """ + Copy meta data from a list of `MetaObj`. + If there has been a change in `id` (e.g., `a+b`), then deepcopy. Else (e.g., `a+=1`), don't. + """ + id_in = id(input_meta_objs[0]) if len(input_meta_objs) > 0 else None + deepcopy_required = id(self) != id_in + attributes = ("affine", "meta") + default_fns: tuple[Callable, ...] = (self.get_default_affine, self.get_default_meta) + for attribute, default_fn in zip(attributes, default_fns): + self._copy_attr(attribute, input_meta_objs, default_fn, deepcopy_required) + + def get_default_meta(self) -> dict: + return {} + + def get_default_affine(self) -> torch.Tensor | np.ndarray: + raise NotImplementedError() + + def __repr__(self) -> str: + """String representation of class.""" + out: str = super().__repr__() + + out += f"\nAffine\n{self.affine}" + + out += "\nMetaData\n" + if self.meta is not None: + out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) + else: + out += "None" + + return out + + @property + def affine(self) -> torch.Tensor: + return self._affine + + @affine.setter + def affine(self, d: torch.Tensor) -> None: + self._affine = d + + @property + def meta(self) -> dict: + return self._meta + + @meta.setter + def meta(self, d: dict): + self._meta = d diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py new file mode 100644 index 0000000000..843f42518d --- /dev/null +++ b/monai/data/meta_tensor.py @@ -0,0 +1,74 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import warnings + +import torch + +from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms + +__all__ = ["MetaTensor"] + + +class MetaTensor(MetaObj, torch.Tensor): + """ + Class that extends upon `torch.Tensor`, adding support for meta data. + + We store the affine as its own element, so that this can be updated by + transforms. All other meta data that we don't plan on touching until we + need to save the image to file lives in `meta`. + + Behavior should be the same as `torch.Tensor` aside from the extended + functionality. + + Copying metadata: + * For `c = a + b`, then the meta data will be copied from the first + instance of `MetaTensor`. + """ + + @staticmethod + def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: + return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore + + def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: + """If `affine` is given, use it. Else, if `affine` exists in the input tensor, use it. Else, use + the default value. The same is true for `meta` and `transforms`.""" + self.set_initial_val("affine", affine, x, self.get_default_affine) + self.set_initial_val("meta", meta, x, self.get_default_meta) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + """Wraps all torch functions.""" + if kwargs is None: + kwargs = {} + ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + # e.g., __repr__ returns a string + if not isinstance(ret, torch.Tensor): + return ret + if not (get_track_meta() or get_track_transforms()): + return ret.as_tensor() + meta_args = MetaObj.get_tensors_or_arrays(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + return ret + + def get_default_affine(self) -> torch.Tensor: + return torch.eye(4, device=self.device) + + def as_tensor(self) -> torch.Tensor: + """ + Return the `MetaTensor` as a `torch.Tensor`. + It is OS dependent as to whether this will be a deep copy or not. + """ + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return torch.tensor(self) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py new file mode 100644 index 0000000000..5a772d930c --- /dev/null +++ b/tests/test_meta_tensor.py @@ -0,0 +1,178 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
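The tests that follow exercise the class defined above; the basic round trip they check looks roughly like this (the file name and shape are made-up illustrative values):

    import torch
    from monai.data import MetaTensor

    t = torch.rand(1, 10, 8)
    m = MetaTensor(t, affine=torch.eye(4), meta={"fname": "img.nii"})

    assert isinstance(m, torch.Tensor)   # usable anywhere a plain tensor is
    plain = m.as_tensor()                # drop back down to a torch.Tensor
    assert not isinstance(plain, MetaTensor)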
+ +import random +import string +import unittest +from copy import deepcopy +from typing import Optional, Union + +import torch +from parameterized import parameterized + +from monai.data.meta_tensor import MetaTensor +from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda + +DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] +TESTS = [] +for _device in TEST_DEVICES: + for _dtype in DTYPES: + TESTS.append((*_device, *_dtype)) + + +def rand_string(min_len=5, max_len=10): + str_size = random.randint(min_len, max_len) + chars = string.ascii_letters + string.punctuation + return "".join(random.choice(chars) for _ in range(str_size)) + + +class TestMetaTensor(unittest.TestCase): + @staticmethod + def get_im(shape=None, dtype=None, device=None): + if shape is None: + shape = shape = (1, 10, 8) + affine = torch.randint(0, 10, (4, 4)) + meta = {"fname": rand_string()} + t = torch.rand(shape) + if dtype is not None: + t = t.to(dtype) + if device is not None: + t = t.to(device) + m = MetaTensor(t.clone(), affine, meta) + return m, t + + def check_ids(self, a, b, should_match): + comp = self.assertEqual if should_match else self.assertNotEqual + comp(id(a), id(b)) + + def check( + self, + out: torch.Tensor, + orig: torch.Tensor, + *, + shape: bool = True, + vals: bool = True, + ids: bool = True, + device: Optional[Union[str, torch.device]] = None, + meta: bool = True, + ): + if device is None: + device = orig.device + + # check the image + self.assertIsInstance(out, type(orig)) + if shape: + assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape)) + if vals: + assert_allclose(out, orig) + self.check_ids(out, orig, ids) + self.assertTrue(str(device) in str(out.device)) + + # check meta and affine are equal and affine is on correct device + if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: + self.assertEqual(out.meta, orig.meta) + assert_allclose(out.affine, orig.affine) + self.assertTrue(str(device) in str(out.affine.device)) + self.check_ids(out.affine, orig.affine, ids) + self.check_ids(out.meta, orig.meta, ids) + + @parameterized.expand(TESTS) + def test_as_tensor(self, device, dtype): + m, t = self.get_im(device=device, dtype=dtype) + t2 = m.as_tensor() + self.assertIsInstance(t2, torch.Tensor) + self.assertNotIsInstance(t2, MetaTensor) + self.assertIsInstance(m, MetaTensor) + self.check(t, t2, ids=False) + + @parameterized.expand(TESTS) + def test_constructor(self, device, dtype): + m, t = self.get_im(device=device, dtype=dtype) + m2 = MetaTensor(t.clone(), m.affine, m.meta) + self.check(m, m2, ids=False, meta=False) + + @parameterized.expand(TESTS) + @skip_if_no_cuda + def test_to_cuda(self, device, dtype): + """Test `to`, `cpu` and `cuda`. 
For `to`, check args and kwargs.""" + orig, _ = self.get_im(device=device, dtype=dtype) + m = orig.clone() + m = m.to("cuda") + self.check(m, orig, ids=False, device="cuda") + m = m.cpu() + self.check(m, orig, ids=False, device="cpu") + m = m.cuda() + self.check(m, orig, ids=False, device="cuda") + m = m.to("cpu") + self.check(m, orig, ids=False, device="cpu") + m = m.to(device="cuda") + self.check(m, orig, ids=False, device="cuda") + m = m.to(device="cpu") + self.check(m, orig, ids=False, device="cpu") + + @parameterized.expand(TESTS) + def test_copy(self, device, dtype): + m, _ = self.get_im(device=device, dtype=dtype) + # shallow copy + a = m + self.check(a, m, ids=True) + # deepcopy + a = deepcopy(m) + self.check(a, m, ids=False) + # clone + a = m.clone() + self.check(a, m, ids=False) + + @parameterized.expand(TESTS) + def test_add(self, device, dtype): + m1, t1 = self.get_im(device=device, dtype=dtype) + m2, t2 = self.get_im(device=device, dtype=dtype) + self.check(m1 + m2, t1 + t2, ids=False) + self.check(torch.add(m1, m2), t1 + t2, ids=False) + self.check(torch.add(input=m1, other=m2), t1 + t2, ids=False) + self.check(torch.add(m1, other=m2), t1 + t2, ids=False) + m3 = deepcopy(m2) + t3 = deepcopy(m2) + m3 += 3 + t3 += 3 + self.check(m3, t3, ids=False) + # check torch.Tensor+MetaTensor and MetaTensor+torch.Tensor + self.check(torch.add(m1, t2), t1 + t2, ids=False) + self.check(torch.add(t2, m1), t1 + t2, ids=False) + + @parameterized.expand(TEST_DEVICES) + def test_conv(self, device): + im, _ = self.get_im((1, 3, 10, 8, 12), device=device) + conv = torch.nn.Conv3d(im.shape[1], 5, 3, device=device) + out = conv(im) + self.check(out, im, shape=False, vals=False, ids=False) + + @parameterized.expand(TESTS) + def test_stack(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + stacked = torch.stack(ims) + self.assertIsInstance(stacked, MetaTensor) + assert_allclose(stacked.affine, ims[0].affine) + self.assertEqual(stacked.meta, ims[0].meta) + + # TODO + # collate + # decollate + # dataset + # dataloader + # torchscript + # matplotlib + # pickling + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/utils.py b/tests/utils.py index 3065f9b3df..9c7a9984a8 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -715,5 +715,10 @@ def query_memory(n=2): TEST_NDARRAYS = TEST_NDARRAYS + (gpu_tensor,) # type: ignore +TEST_DEVICES = [[torch.device("cpu")]] +if torch.cuda.is_available(): + TEST_DEVICES.append([torch.device("cuda")]) + + if __name__ == "__main__": print(query_memory()) From accca59aef9c3afddf574fd6489fa042b5548496 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 15:51:36 +0100 Subject: [PATCH 02/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 178 +++++++++++++++++++++++++++++--------- monai/data/meta_tensor.py | 54 ++++++++---- tests/test_meta_tensor.py | 11 +++ 3 files changed, 185 insertions(+), 58 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index fd95c9737a..ea271bee59 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -14,9 +14,6 @@ from copy import deepcopy from typing import Any, Callable, Sequence -import numpy as np -import torch - _TRACK_META = True _TRACK_TRANSFORMS = True @@ -25,9 +22,17 @@ def set_track_meta(val: bool) -> None: """ - Boolean to set whether metadata is tracked. 
If `True`, - `MetaTensor` will be returned where appropriate. If `False`, - `torch.Tensor` will be returned instead. + Boolean to set whether meta data is tracked. If `True`, meta data will be associated + its data by using subclasses of `MetaObj`. If `False`, then data will be returned + with empty meta data. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding meta data, and aren't interested in + preserving meta data, then you can disable it. """ global _TRACK_META _TRACK_META = val @@ -35,7 +40,17 @@ def set_track_meta(val: bool) -> None: def set_track_transforms(val: bool) -> None: """ - Boolean to set whether transforms are tracked. + Boolean to set whether transforms are tracked. If `True`, applied transforms will be + associated its data by using subclasses of `MetaObj`. If `False`, then transforms + won't be tracked. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding transforms, and aren't interested in + preserving transforms, then you can disable it. """ global _TRACK_TRANSFORMS _TRACK_TRANSFORMS = val @@ -43,60 +58,107 @@ def set_track_transforms(val: bool) -> None: def get_track_meta() -> bool: """ - Get track meta data boolean. + Return the boolean as to whether meta data is tracked. If `True`, meta data will be + associated its data by using subclasses of `MetaObj`. If `False`, then data will be + returned with empty meta data. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding meta data, and aren't interested in + preserving meta data, then you can disable it. """ - global _TRACK_META return _TRACK_META def get_track_transforms() -> bool: """ - Get track transform boolean. + Return the boolean as to whether transforms are tracked. If `True`, applied + transforms will be associated its data by using subclasses of `MetaObj`. If `False`, + then transforms won't be tracked. + + If both `set_track_meta` and `set_track_transforms` are set to + `False`, then standard data objects will be returned (e.g., `torch.Tensor` and + `np.ndarray`) as opposed to our enhanced objects. + + By default, this is `True`, and most users will want to leave it this way. However, + if you are experiencing any problems regarding transforms, and aren't interested in + preserving transforms, then you can disable it. """ - global _TRACK_TRANSFORMS return _TRACK_TRANSFORMS class MetaObj: """ - Class that stores meta and affine. + Abstract base class that stores data as well as any extra meta data and an affine + transformation matrix. + + This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple + inheritance. + + Meta data is stored in the form of of a dictionary. Affine matrices are stored in + the form of e.g., `torch.Tensor` or `np.ndarray`. 
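For the concrete `MetaTensor` subclass this storage looks roughly as follows (the file name is only an illustrative value):

    import torch
    from monai.data import MetaTensor

    m = MetaTensor(torch.rand(1, 4, 4), meta={"fname": "img.nii"})
    print(m.meta)    # the metadata dict that was passed in
    print(m.affine)  # no affine was given, so the default torch.eye(4) on the data's device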
We store the affine as its own element, so that this can be updated by transforms. All other meta data that we don't plan on touching until we need to save the image to file lives in `meta`. - This allows for subclassing `np.ndarray` and `torch.Tensor`. + Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) + aside from the extended meta functionality. - Copying metadata: - * For `c = a + b`, then the meta data will be copied from the first - instance of `MetaImage`. + Copying of information: + * For `c = a + b`, then auxiliary data (e.g., meta data) will be copied from the + first instance of `MetaObj`. """ _meta: dict - _affine: torch.Tensor + _affine: Any - def set_initial_val(self, attribute: str, input_arg: Any, input_tensor: MetaObj, default_fn: Callable) -> None: + def _set_initial_val(self, attribute: str, input_arg: Any, input_obj: Any, default_fn: Callable) -> None: """ - Set the initial value. Try to use input argument, but if this is None - and there is a MetaImage input, then copy that. Failing both these two, - use a default value. + Set the initial value of an attribute (e.g., `meta` or `affine`). + First, try to set `attribute` using `input_arg`. But if `input_arg` is `None`, + then we try to copy the value from `input_obj`. But if value is also `None`, + then we finally fall back on using the `default_fn`. + + Args: + attribute: string corresponding to attribute we want to set (e.g., `meta` or + `affine`). + input_arg: the value we would like `attribute` to take or `None` if not + given. + input_obj: if `input_arg` is `None`, try to copy `attribute` from + `input_obj`, if it is present and not `None`. + default_fn: function to be used if all previous arguments return `None`. + Default meta data might be empty dictionary so could be as simple as + `lambda: {}`. + Returns: + Returns `None`, but `self` should have the updated `attribute`. """ if input_arg is None: - input_arg = getattr(input_tensor, attribute, None) + input_arg = getattr(input_obj, attribute, None) if input_arg is None: input_arg = default_fn(self) setattr(self, attribute, input_arg) @staticmethod - def get_tensors_or_arrays(args: Sequence[Any]) -> list[MetaObj]: + def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: """ - Recursively extract all instances of `MetaObj`. - Works for `torch.add(a, b)`, `torch.stack([a, b])` and numpy equivalents. + Recursively flatten input and return all instances of `MetaObj` as a single + list. This means that for both `torch.add(a, b)`, `torch.stack([a, b])` (and + their numpy equivalents), we return `[a, b]` if both `a` and `b` are of type + `MetaObj`. + + Args: + args: Sequence of inputs to be flattened. + Returns: + list of nested `MetaObj` from input. """ out = [] for a in args: if isinstance(a, (list, tuple)): - out += MetaObj.get_tensors_or_arrays(a) + out += MetaObj.flatten_meta_objs(a) elif isinstance(a, MetaObj): out.append(a) return out @@ -105,40 +167,66 @@ def _copy_attr( self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool ) -> None: """ - Copy an attribute from the first in a list of `MetaObj` - In the cases `torch.add(a, b)` and `torch.add(input=a, other=b)`, - both `a` and `b` could be `MetaObj` or `torch.Tensor` so check - them all. Copy the first to the output, and make sure on correct - device. - Might have the MetaObj nested in list, e.g., `torch.stack([a, b])`. + Copy an attribute from the first in a list of `MetaObj`. 
In the case of + `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so + check them all. Copy the first to `self`. + + We also perform a deep copy of the data if desired. + + Args: + attribute: string corresponding to attribute to be copied (e.g., `meta` or + `affine`). + input_objs: List of `MetaObj`. We'll copy the attribute from the first one + that contains that particular attribute. + default_fn: If none of `input_objs` have the attribute that we're + interested in, then use this default function (e.g., `lambda: {}`.) + deepcopy_required: Should the attribute be deep copied? See `_copy_meta`. + + Returns: + Returns `None`, but `self` should be updated to have the copied attribute. """ attributes = [getattr(i, attribute) for i in input_objs] if len(attributes) > 0: val = attributes[0] if deepcopy_required: val = deepcopy(val) - if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): - val = val.to(self.device) setattr(self, attribute, val) else: setattr(self, attribute, default_fn()) - def _copy_meta(self, input_meta_objs: list[MetaObj]) -> None: + def _copy_meta(self, input_objs: list[MetaObj]) -> None: """ - Copy meta data from a list of `MetaObj`. - If there has been a change in `id` (e.g., `a+b`), then deepcopy. Else (e.g., `a+=1`), don't. + Copy meta data from a list of `MetaObj`. For a given attribute, we copy the + adjunct data from the first element in the list containing that attribute. + + If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., + `a+=1`), then don't. + + Args: + input_objs: list of `MetaObj` to copy data from. + """ - id_in = id(input_meta_objs[0]) if len(input_meta_objs) > 0 else None + id_in = id(input_objs[0]) if len(input_objs) > 0 else None deepcopy_required = id(self) != id_in attributes = ("affine", "meta") default_fns: tuple[Callable, ...] = (self.get_default_affine, self.get_default_meta) for attribute, default_fn in zip(attributes, default_fns): - self._copy_attr(attribute, input_meta_objs, default_fn, deepcopy_required) + self._copy_attr(attribute, input_objs, default_fn, deepcopy_required) def get_default_meta(self) -> dict: + """Get the default meta. + + Returns: + default meta data. + """ return {} - def get_default_affine(self) -> torch.Tensor | np.ndarray: + def get_default_affine(self) -> Any: + """Get the default affine. + + Returns: + default affine. + """ raise NotImplementedError() def __repr__(self) -> str: @@ -156,17 +244,21 @@ def __repr__(self) -> str: return out @property - def affine(self) -> torch.Tensor: + def affine(self) -> Any: + """Get the affine.""" return self._affine @affine.setter - def affine(self, d: torch.Tensor) -> None: + def affine(self, d: Any) -> None: + """Set the affine.""" self._affine = d @property def meta(self) -> dict: + """Get the meta.""" return self._meta @meta.setter - def meta(self, d: dict): + def meta(self, d: dict) -> None: + """Set the meta.""" self._meta = d diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 843f42518d..4d60655070 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -11,7 +11,7 @@ from __future__ import annotations -import warnings +from typing import Callable, Optional import torch @@ -22,29 +22,55 @@ class MetaTensor(MetaObj, torch.Tensor): """ - Class that extends upon `torch.Tensor`, adding support for meta data. + Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for meta + data. + + Meta data is stored in the form of of a dictionary. 
Affine matrices are stored in + the form of `torch.Tensor`. We store the affine as its own element, so that this can be updated by transforms. All other meta data that we don't plan on touching until we need to save the image to file lives in `meta`. Behavior should be the same as `torch.Tensor` aside from the extended - functionality. + meta functionality. + + Copying of information: + * For `c = a + b`, then auxiliary data (e.g., meta data) will be copied from the + first instance of `MetaTensor`. + + Example: + .. code-block:: python - Copying metadata: - * For `c = a + b`, then the meta data will be copied from the first - instance of `MetaTensor`. + import torch + from monai.data import MetaTensor + + t = torch.tensor([1,2,3]) + meta = {"some": "info"} + affine = torch.eye(4) + m = MetaTensor(t, meta=meta, affine=affine) + m2 = m+m + assert isinstance(m2, MetaTensor) + assert m2.meta == meta """ @staticmethod - def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: + def __new__(cls, x, affine: Optional[torch.Tensor] = None, meta: Optional[dict] = None, *args, **kwargs) -> MetaTensor: return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore - def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: + def __init__(self, x, affine: Optional[torch.Tensor] = None, meta: Optional[dict] = None) -> None: """If `affine` is given, use it. Else, if `affine` exists in the input tensor, use it. Else, use - the default value. The same is true for `meta` and `transforms`.""" - self.set_initial_val("affine", affine, x, self.get_default_affine) - self.set_initial_val("meta", meta, x, self.get_default_meta) + the default value. The same is true for `meta`.""" + self._set_initial_val("affine", affine, x, self.get_default_affine) + self._set_initial_val("meta", meta, x, self.get_default_meta) + + def _copy_attr( + self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool + ) -> None: + super()._copy_attr(attribute, input_objs, default_fn, deepcopy_required) + val = getattr(self, attribute) + if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): + setattr(self, attribute, val.to(self.device)) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: @@ -57,7 +83,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: return ret if not (get_track_meta() or get_track_transforms()): return ret.as_tensor() - meta_args = MetaObj.get_tensors_or_arrays(list(args) + list(kwargs.values())) + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) ret._copy_meta(meta_args) return ret @@ -69,6 +95,4 @@ def as_tensor(self) -> torch.Tensor: Return the `MetaTensor` as a `torch.Tensor`. It is OS dependent as to whether this will be a deep copy or not. 
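The `__torch_function__` hook above is what carries the metadata through ordinary torch calls, including ones that take sequences of tensors; a sketch mirroring the `test_stack` case in the tests added earlier in this series (values illustrative):

    import torch
    from monai.data import MetaTensor

    ims = [MetaTensor(torch.rand(1, 4, 4), meta={"idx": i}) for i in range(3)]
    stacked = torch.stack(ims)

    assert isinstance(stacked, MetaTensor)
    assert stacked.meta == ims[0].meta   # copied from the first MetaObj found in the inputs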
""" - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - return torch.tensor(self) + return self.as_subclass(torch.Tensor) # type: ignore diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 5a772d930c..b0e9f94b01 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -18,6 +18,7 @@ import torch from parameterized import parameterized +from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda @@ -164,6 +165,16 @@ def test_stack(self, device, dtype): assert_allclose(stacked.affine, ims[0].affine) self.assertEqual(stacked.meta, ims[0].meta) + def test_get_set_meta_fns(self): + set_track_meta(False) + self.assertEqual(get_track_meta(), False) + set_track_meta(True) + self.assertEqual(get_track_meta(), True) + set_track_transforms(False) + self.assertEqual(get_track_transforms(), False) + set_track_transforms(True) + self.assertEqual(get_track_transforms(), True) + # TODO # collate # decollate From 36a1e754a7a26b77282372e215a9aa883c1dd9af Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 7 Apr 2022 14:52:09 +0000 Subject: [PATCH 03/26] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- monai/data/meta_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 4d60655070..61a9857769 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -55,10 +55,10 @@ class MetaTensor(MetaObj, torch.Tensor): """ @staticmethod - def __new__(cls, x, affine: Optional[torch.Tensor] = None, meta: Optional[dict] = None, *args, **kwargs) -> MetaTensor: + def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore - def __init__(self, x, affine: Optional[torch.Tensor] = None, meta: Optional[dict] = None) -> None: + def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: """If `affine` is given, use it. Else, if `affine` exists in the input tensor, use it. Else, use the default value. The same is true for `meta`.""" self._set_initial_val("affine", affine, x, self.get_default_affine) From f71d60f8481a86880770441f22dfc94bf5d0267d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 16:05:08 +0100 Subject: [PATCH 04/26] deep_copy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 10 +++++----- monai/data/meta_tensor.py | 4 ++-- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index ea271bee59..714164049c 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -164,7 +164,7 @@ def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: return out def _copy_attr( - self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool + self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool ) -> None: """ Copy an attribute from the first in a list of `MetaObj`. In the case of @@ -180,7 +180,7 @@ def _copy_attr( that contains that particular attribute. 
default_fn: If none of `input_objs` have the attribute that we're interested in, then use this default function (e.g., `lambda: {}`.) - deepcopy_required: Should the attribute be deep copied? See `_copy_meta`. + deep_copy: Should the attribute be deep copied? See `_copy_meta`. Returns: Returns `None`, but `self` should be updated to have the copied attribute. @@ -188,7 +188,7 @@ def _copy_attr( attributes = [getattr(i, attribute) for i in input_objs] if len(attributes) > 0: val = attributes[0] - if deepcopy_required: + if deep_copy: val = deepcopy(val) setattr(self, attribute, val) else: @@ -207,11 +207,11 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: """ id_in = id(input_objs[0]) if len(input_objs) > 0 else None - deepcopy_required = id(self) != id_in + deep_copy = id(self) != id_in attributes = ("affine", "meta") default_fns: tuple[Callable, ...] = (self.get_default_affine, self.get_default_meta) for attribute, default_fn in zip(attributes, default_fns): - self._copy_attr(attribute, input_objs, default_fn, deepcopy_required) + self._copy_attr(attribute, input_objs, default_fn, deep_copy) def get_default_meta(self) -> dict: """Get the default meta. diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 61a9857769..8d9c19b712 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -65,9 +65,9 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No self._set_initial_val("meta", meta, x, self.get_default_meta) def _copy_attr( - self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deepcopy_required: bool + self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool ) -> None: - super()._copy_attr(attribute, input_objs, default_fn, deepcopy_required) + super()._copy_attr(attribute, input_objs, default_fn, deep_copy) val = getattr(self, attribute) if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) From 51e6b3bcefae3cbad084a1e749f043590330ffd6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 16:16:40 +0100 Subject: [PATCH 05/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 4 +--- monai/data/meta_tensor.py | 6 ++---- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 714164049c..7259afbf01 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -163,9 +163,7 @@ def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: out.append(a) return out - def _copy_attr( - self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool - ) -> None: + def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: """ Copy an attribute from the first in a list of `MetaObj`. 
In the case of `torch.add(a, b)`, both `a` and `b` could be `MetaObj` or something else, so diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 8d9c19b712..80f5372456 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -11,7 +11,7 @@ from __future__ import annotations -from typing import Callable, Optional +from typing import Callable import torch @@ -64,9 +64,7 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No self._set_initial_val("affine", affine, x, self.get_default_affine) self._set_initial_val("meta", meta, x, self.get_default_meta) - def _copy_attr( - self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool - ) -> None: + def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: super()._copy_attr(attribute, input_objs, default_fn, deep_copy) val = getattr(self, attribute) if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): From 1e9f37044332df83c57ba91e7997a6d2e1ef2343 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 16:44:55 +0100 Subject: [PATCH 06/26] torchscript Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 12 +++++++++++- tests/utils.py | 4 ++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index b0e9f94b01..ab179414e6 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -175,12 +175,22 @@ def test_get_set_meta_fns(self): set_track_transforms(True) self.assertEqual(get_track_transforms(), True) + @parameterized.expand(TEST_DEVICES) + def test_torchscript(self, device): + shape = (1, 3, 10, 8) + im, _ = self.get_im(shape, device=device) + im2 = im.clone() + conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) + traced_fn = torch.jit.trace(conv, im) + out = traced_fn(im2) + self.assertIsInstance(out, MetaTensor) + self.check(out, conv(im), ids=False) + # TODO # collate # decollate # dataset # dataloader - # torchscript # matplotlib # pickling diff --git a/tests/utils.py b/tests/utils.py index 9c7a9984a8..c4f9bd1b70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -102,8 +102,8 @@ def assert_allclose( if isinstance(desired, torch.Tensor) or isinstance(actual, torch.Tensor): if device_test: np.testing.assert_equal(str(actual.device), str(desired.device), "torch device check") # type: ignore - actual = actual.cpu().numpy() if isinstance(actual, torch.Tensor) else actual - desired = desired.cpu().numpy() if isinstance(desired, torch.Tensor) else desired + actual = actual.detach().cpu().numpy() if isinstance(actual, torch.Tensor) else actual + desired = desired.detach().cpu().numpy() if isinstance(desired, torch.Tensor) else desired np.testing.assert_allclose(actual, desired, *args, **kwargs) From 76e0cceb743b0d723a98d0db643d192618b2769a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 17:11:25 +0100 Subject: [PATCH 07/26] test pickling torchscript amp Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 34 ++++++++++++++++++++++++++++++---- 1 file changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index ab179414e6..eb139c99c3 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -11,6 +11,7 @@ import 
random import string +import tempfile import unittest from copy import deepcopy from typing import Optional, Union @@ -179,12 +180,38 @@ def test_get_set_meta_fns(self): def test_torchscript(self, device): shape = (1, 3, 10, 8) im, _ = self.get_im(shape, device=device) - im2 = im.clone() conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) + im_conv = conv(im) traced_fn = torch.jit.trace(conv, im) - out = traced_fn(im2) + # try and use it + out = traced_fn(im) self.assertIsInstance(out, MetaTensor) - self.check(out, conv(im), ids=False) + self.check(out, im_conv, ids=False) + # save it, load it, use it + with tempfile.NamedTemporaryFile() as fname: + torch.jit.save(traced_fn, f=fname.name) + traced_fn2 = torch.jit.load(fname.name) + out2 = traced_fn2(im) + self.assertIsInstance(out2, MetaTensor) + self.check(out2, im_conv, ids=False) + + def test_pickling(self): + m, _ = self.get_im() + with tempfile.NamedTemporaryFile() as fname: + torch.save(m, fname.name) + m2 = torch.load(fname.name) + self.check(m2, m, ids=False) + + @skip_if_no_cuda + def test_amp(self): + shape = (1, 3, 10, 8) + device = "cuda" + im, _ = self.get_im(shape, device=device) + conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) + im_conv = conv(im) + with torch.autocast(str(device)): + im_conv2 = conv(im) + self.check(im_conv2, im_conv, ids=False) # TODO # collate @@ -192,7 +219,6 @@ def test_torchscript(self, device): # dataset # dataloader # matplotlib - # pickling if __name__ == "__main__": From 8a60df1f70b639960c28e0114c19e92ed67a0526 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 18:01:59 +0100 Subject: [PATCH 08/26] test fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 3 ++- tests/test_meta_tensor.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 80f5372456..3c02dbad52 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -63,11 +63,12 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No the default value. 
The same is true for `meta`.""" self._set_initial_val("affine", affine, x, self.get_default_affine) self._set_initial_val("meta", meta, x, self.get_default_meta) + self.affine = self.affine.to(self.device) def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: super()._copy_attr(attribute, input_objs, default_fn, deep_copy) val = getattr(self, attribute) - if isinstance(self, torch.Tensor) and isinstance(val, torch.Tensor): + if isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) @classmethod diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index eb139c99c3..fa3c8ff91e 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -65,6 +65,7 @@ def check( ids: bool = True, device: Optional[Union[str, torch.device]] = None, meta: bool = True, + **kwargs, ): if device is None: device = orig.device @@ -74,7 +75,7 @@ def check( if shape: assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape)) if vals: - assert_allclose(out, orig) + assert_allclose(out, orig, **kwargs) self.check_ids(out, orig, ids) self.assertTrue(str(device) in str(out.device)) @@ -211,7 +212,7 @@ def test_amp(self): im_conv = conv(im) with torch.autocast(str(device)): im_conv2 = conv(im) - self.check(im_conv2, im_conv, ids=False) + self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) # TODO # collate From 70bbcdbea1ccd498b28f0aad19985c2397376995 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 17:45:17 +0100 Subject: [PATCH 09/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index fa3c8ff91e..d6009d4829 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -9,6 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import os import random import string import tempfile @@ -189,18 +190,20 @@ def test_torchscript(self, device): self.assertIsInstance(out, MetaTensor) self.check(out, im_conv, ids=False) # save it, load it, use it - with tempfile.NamedTemporaryFile() as fname: - torch.jit.save(traced_fn, f=fname.name) - traced_fn2 = torch.jit.load(fname.name) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "im.pt") + torch.jit.save(traced_fn, f=fname) + traced_fn2 = torch.jit.load(fname) out2 = traced_fn2(im) self.assertIsInstance(out2, MetaTensor) self.check(out2, im_conv, ids=False) def test_pickling(self): m, _ = self.get_im() - with tempfile.NamedTemporaryFile() as fname: - torch.save(m, fname.name) - m2 = torch.load(fname.name) + with tempfile.TemporaryDirectory() as tmp_dir: + fname = os.path.join(tmp_dir, "im.pt") + torch.save(m, fname) + m2 = torch.load(fname) self.check(m2, m, ids=False) @skip_if_no_cuda @@ -210,7 +213,7 @@ def test_amp(self): im, _ = self.get_im(shape, device=device) conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) im_conv = conv(im) - with torch.autocast(str(device)): + with torch.cuda.amp.autocast(): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) From de02d7f66599f4bf240000c249d47215e09fb1c0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 17:54:43 +0100 Subject: [PATCH 10/26] fxies Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index d6009d4829..36b582631c 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -155,7 +155,8 @@ def test_add(self, device, dtype): @parameterized.expand(TEST_DEVICES) def test_conv(self, device): im, _ = self.get_im((1, 3, 10, 8, 12), device=device) - conv = torch.nn.Conv3d(im.shape[1], 5, 3, device=device) + conv = torch.nn.Conv3d(im.shape[1], 5, 3) + conv.to(device) out = conv(im) self.check(out, im, shape=False, vals=False, ids=False) @@ -182,7 +183,8 @@ def test_get_set_meta_fns(self): def test_torchscript(self, device): shape = (1, 3, 10, 8) im, _ = self.get_im(shape, device=device) - conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) + conv = torch.nn.Conv2d(im.shape[1], 5, 3) + conv.to(device) im_conv = conv(im) traced_fn = torch.jit.trace(conv, im) # try and use it @@ -211,7 +213,8 @@ def test_amp(self): shape = (1, 3, 10, 8) device = "cuda" im, _ = self.get_im(shape, device=device) - conv = torch.nn.Conv2d(im.shape[1], 5, 3, device=device) + conv = torch.nn.Conv2d(im.shape[1], 5, 3) + conv.to(device) im_conv = conv(im) with torch.cuda.amp.autocast(): im_conv2 = conv(im) From eabd3bf6a9161d2a40b7f0d81439cb0c6597af50 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 14:00:23 +0100 Subject: [PATCH 11/26] typos Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 30 +++++++++++++++--------------- monai/data/meta_tensor.py | 6 +++--- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 7259afbf01..9513277160 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -22,17 +22,17 @@ def set_track_meta(val: bool) -> None: """ - Boolean to set whether meta data is tracked. 
If `True`, meta data will be associated + Boolean to set whether metadata is tracked. If `True`, metadata will be associated its data by using subclasses of `MetaObj`. If `False`, then data will be returned - with empty meta data. + with empty metadata. If both `set_track_meta` and `set_track_transforms` are set to `False`, then standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`) as opposed to our enhanced objects. By default, this is `True`, and most users will want to leave it this way. However, - if you are experiencing any problems regarding meta data, and aren't interested in - preserving meta data, then you can disable it. + if you are experiencing any problems regarding metadata, and aren't interested in + preserving metadata, then you can disable it. """ global _TRACK_META _TRACK_META = val @@ -58,17 +58,17 @@ def set_track_transforms(val: bool) -> None: def get_track_meta() -> bool: """ - Return the boolean as to whether meta data is tracked. If `True`, meta data will be + Return the boolean as to whether metadata is tracked. If `True`, metadata will be associated its data by using subclasses of `MetaObj`. If `False`, then data will be - returned with empty meta data. + returned with empty metadata. If both `set_track_meta` and `set_track_transforms` are set to `False`, then standard data objects will be returned (e.g., `torch.Tensor` and `np.ndarray`) as opposed to our enhanced objects. By default, this is `True`, and most users will want to leave it this way. However, - if you are experiencing any problems regarding meta data, and aren't interested in - preserving meta data, then you can disable it. + if you are experiencing any problems regarding metadata, and aren't interested in + preserving metadata, then you can disable it. """ return _TRACK_META @@ -92,24 +92,24 @@ def get_track_transforms() -> bool: class MetaObj: """ - Abstract base class that stores data as well as any extra meta data and an affine + Abstract base class that stores data as well as any extra metadata and an affine transformation matrix. This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance. - Meta data is stored in the form of of a dictionary. Affine matrices are stored in + Metadata is stored in the form of a dictionary. Affine matrices are stored in the form of e.g., `torch.Tensor` or `np.ndarray`. We store the affine as its own element, so that this can be updated by - transforms. All other meta data that we don't plan on touching until we + transforms. All other metadata that we don't plan on touching until we need to save the image to file lives in `meta`. Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) aside from the extended meta functionality. Copying of information: - * For `c = a + b`, then auxiliary data (e.g., meta data) will be copied from the + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the first instance of `MetaObj`. """ @@ -131,7 +131,7 @@ def _set_initial_val(self, attribute: str, input_arg: Any, input_obj: Any, defau input_obj: if `input_arg` is `None`, try to copy `attribute` from `input_obj`, if it is present and not `None`. default_fn: function to be used if all previous arguments return `None`. - Default meta data might be empty dictionary so could be as simple as + Default metadata might be empty dictionary so could be as simple as `lambda: {}`. Returns: Returns `None`, but `self` should have the updated `attribute`. 
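The precedence spelled out in that docstring (explicit argument first, then the attribute of the input object, then the default) is what `MetaTensor.__init__` relies on; roughly, with illustrative values:

    import torch
    from monai.data import MetaTensor

    src = MetaTensor(torch.rand(2, 2), meta={"fname": "src.nii"})

    m1 = MetaTensor(src)                        # no `meta` given: taken from `src`
    m2 = MetaTensor(src, meta={"kind": "new"})  # an explicit `meta` argument wins
    m3 = MetaTensor(torch.rand(2, 2))           # neither available: default empty dict
    print(m1.meta, m2.meta, m3.meta)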
@@ -194,7 +194,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call def _copy_meta(self, input_objs: list[MetaObj]) -> None: """ - Copy meta data from a list of `MetaObj`. For a given attribute, we copy the + Copy metadata from a list of `MetaObj`. For a given attribute, we copy the adjunct data from the first element in the list containing that attribute. If there has been a change in `id` (e.g., `a=b+c`), then deepcopy. Else (e.g., @@ -215,7 +215,7 @@ def get_default_meta(self) -> dict: """Get the default meta. Returns: - default meta data. + default metadata. """ return {} diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 3c02dbad52..6195d56b5d 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -25,18 +25,18 @@ class MetaTensor(MetaObj, torch.Tensor): Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for meta data. - Meta data is stored in the form of of a dictionary. Affine matrices are stored in + Metadata is stored in the form of a dictionary. Affine matrices are stored in the form of `torch.Tensor`. We store the affine as its own element, so that this can be updated by - transforms. All other meta data that we don't plan on touching until we + transforms. All other metadata that we don't plan on touching until we need to save the image to file lives in `meta`. Behavior should be the same as `torch.Tensor` aside from the extended meta functionality. Copying of information: - * For `c = a + b`, then auxiliary data (e.g., meta data) will be copied from the + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the first instance of `MetaTensor`. Example: From 4af4d50f9e5e0c478f2bd147d1c44779aa40c635 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 15:01:41 +0100 Subject: [PATCH 12/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 24 ++++++++++++++++++++++++ tests/test_meta_tensor.py | 20 ++++++++++++++++---- 2 files changed, 40 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 6195d56b5d..c862fe3394 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -16,6 +16,7 @@ import torch from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -52,6 +53,11 @@ class MetaTensor(MetaObj, torch.Tensor): m2 = m+m assert isinstance(m2, MetaTensor) assert m2.meta == meta + + Notes: + - Depending on your version of pytorch, `torch.jit.trace(net, im)` may or may + not work if `im` is of type `MetaTensor`. This can be resolved with + `torch.jit.trace(net, im.as_tensor)`. """ @staticmethod @@ -95,3 +101,21 @@ def as_tensor(self) -> torch.Tensor: It is OS dependent as to whether this will be a deep copy or not. """ return self.as_subclass(torch.Tensor) # type: ignore + + def as_dict(self, key: str) -> dict: + """ + Get the object as a dictionary for backwards compatibility. + + Args: + key: Base key to store main data. The key for the metadata will be + determined using `PostFix.meta`. + + + Return: + A dictionary consisting of two keys, the main data (stored under `key`) and + the metadata. The affine will be stored with the metadata for backwards + compatibility. 
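A sketch of the round trip `as_dict` is aiming at, mirroring the `test_as_dict` case added to the tests below (the key name "im" and the metadata values are arbitrary):

    import torch
    from monai.data import MetaTensor
    from monai.utils.enums import PostFix

    m = MetaTensor(torch.rand(1, 10, 8), affine=torch.eye(4), meta={"fname": "img.nii"})
    d = m.as_dict("im")

    img, meta = d["im"], d[PostFix.meta("im")]  # plain tensor + metadata dict
    affine = meta.pop("affine")                 # the affine travels inside the metadata
    m2 = MetaTensor(img, affine, meta)          # rebuild an equivalent MetaTensor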
+ """ + meta = self.meta + meta["affine"] = self.affine + return {key: self.as_tensor(), PostFix.meta(key): self.meta} diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 36b582631c..56dd42099c 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -22,6 +22,7 @@ from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor +from monai.utils.enums import PostFix from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] @@ -66,6 +67,7 @@ def check( ids: bool = True, device: Optional[Union[str, torch.device]] = None, meta: bool = True, + check_ids: bool = True, **kwargs, ): if device is None: @@ -77,7 +79,8 @@ def check( assert_allclose(torch.as_tensor(out.shape), torch.as_tensor(orig.shape)) if vals: assert_allclose(out, orig, **kwargs) - self.check_ids(out, orig, ids) + if check_ids: + self.check_ids(out, orig, ids) self.assertTrue(str(device) in str(out.device)) # check meta and affine are equal and affine is on correct device @@ -85,8 +88,9 @@ def check( self.assertEqual(out.meta, orig.meta) assert_allclose(out.affine, orig.affine) self.assertTrue(str(device) in str(out.affine.device)) - self.check_ids(out.affine, orig.affine, ids) - self.check_ids(out.meta, orig.meta, ids) + if check_ids: + self.check_ids(out.affine, orig.affine, ids) + self.check_ids(out.meta, orig.meta, ids) @parameterized.expand(TESTS) def test_as_tensor(self, device, dtype): @@ -97,6 +101,14 @@ def test_as_tensor(self, device, dtype): self.assertIsInstance(m, MetaTensor) self.check(t, t2, ids=False) + def test_as_dict(self): + m, _ = self.get_im() + m_dict = m.as_dict("im") + im, meta = m_dict["im"], m_dict[PostFix.meta("im")] + affine = meta.pop("affine") + m2 = MetaTensor(im, affine, meta) + self.check(m2, m, check_ids=False) + @parameterized.expand(TESTS) def test_constructor(self, device, dtype): m, t = self.get_im(device=device, dtype=dtype) @@ -186,7 +198,7 @@ def test_torchscript(self, device): conv = torch.nn.Conv2d(im.shape[1], 5, 3) conv.to(device) im_conv = conv(im) - traced_fn = torch.jit.trace(conv, im) + traced_fn = torch.jit.trace(conv, im.as_tensor()) # try and use it out = traced_fn(im) self.assertIsInstance(out, MetaTensor) From 4a2c211c33322c13cb6a67fbba10cc92c3f0527f Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 15:59:01 +0100 Subject: [PATCH 13/26] fix test Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 56dd42099c..3833aa353c 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -14,6 +14,7 @@ import string import tempfile import unittest +import warnings from copy import deepcopy from typing import Optional, Union @@ -201,7 +202,15 @@ def test_torchscript(self, device): traced_fn = torch.jit.trace(conv, im.as_tensor()) # try and use it out = traced_fn(im) - self.assertIsInstance(out, MetaTensor) + self.assertIsInstance(out, torch.Tensor) + if not isinstance(out, MetaTensor): + warnings.warn( + "When calling `nn.Module(MetaTensor) on a module traced with " + "`torch.jit.trace`, your version of pytorch returns a " + "`torch.Tensor` instead of a `MetaTensor`. 
Consider upgrading " + "your pytorch version if this is important to you." + ) + im_conv = im.as_tensor() self.check(out, im_conv, ids=False) # save it, load it, use it with tempfile.TemporaryDirectory() as tmp_dir: From 6373711d3621a167d2e15761bd2af6b79c96b2bd Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 17:49:05 +0100 Subject: [PATCH 14/26] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 3833aa353c..d9d4319fe5 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -210,7 +210,7 @@ def test_torchscript(self, device): "`torch.Tensor` instead of a `MetaTensor`. Consider upgrading " "your pytorch version if this is important to you." ) - im_conv = im.as_tensor() + im_conv = im_conv.as_tensor() self.check(out, im_conv, ids=False) # save it, load it, use it with tempfile.TemporaryDirectory() as tmp_dir: @@ -218,7 +218,7 @@ def test_torchscript(self, device): torch.jit.save(traced_fn, f=fname) traced_fn2 = torch.jit.load(fname) out2 = traced_fn2(im) - self.assertIsInstance(out2, MetaTensor) + self.assertIsInstance(out2, type(out)) self.check(out2, im_conv, ids=False) def test_pickling(self): From 07f511759acf8c188c11cd53347bf1189b15e788 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 18:58:37 +0100 Subject: [PATCH 15/26] fix? Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index d9d4319fe5..7fc0fddbd2 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -200,26 +200,22 @@ def test_torchscript(self, device): conv.to(device) im_conv = conv(im) traced_fn = torch.jit.trace(conv, im.as_tensor()) - # try and use it - out = traced_fn(im) - self.assertIsInstance(out, torch.Tensor) - if not isinstance(out, MetaTensor): - warnings.warn( - "When calling `nn.Module(MetaTensor) on a module traced with " - "`torch.jit.trace`, your version of pytorch returns a " - "`torch.Tensor` instead of a `MetaTensor`. Consider upgrading " - "your pytorch version if this is important to you." - ) - im_conv = im_conv.as_tensor() - self.check(out, im_conv, ids=False) # save it, load it, use it with tempfile.TemporaryDirectory() as tmp_dir: fname = os.path.join(tmp_dir, "im.pt") torch.jit.save(traced_fn, f=fname) - traced_fn2 = torch.jit.load(fname) - out2 = traced_fn2(im) - self.assertIsInstance(out2, type(out)) - self.check(out2, im_conv, ids=False) + traced_fn = torch.jit.load(fname) + out = traced_fn(im) + self.assertIsInstance(out, torch.Tensor) + if not isinstance(out, MetaTensor): + warnings.warn( + "When calling `nn.Module(MetaTensor) on a module traced with " + "`torch.jit.trace`, your version of pytorch returns a " + "`torch.Tensor` instead of a `MetaTensor`. Consider upgrading " + "your pytorch version if this is important to you." 
+ ) + im_conv = im_conv.as_tensor() + self.check(out, im_conv, ids=False) def test_pickling(self): m, _ = self.get_im() From 723f475ab610795821192e9392420b799ca672d5 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 20:02:09 +0100 Subject: [PATCH 16/26] affine lives inside meta Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 50 +++++------------------------------- monai/data/meta_tensor.py | 54 +++++++++++++++++++++++++++------------ tests/test_meta_tensor.py | 28 +++++++++++++++++--- 3 files changed, 68 insertions(+), 64 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 9513277160..ff77245dae 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -92,18 +92,14 @@ def get_track_transforms() -> bool: class MetaObj: """ - Abstract base class that stores data as well as any extra metadata and an affine - transformation matrix. + Abstract base class that stores data as well as any extra metadata. This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance. - Metadata is stored in the form of a dictionary. Affine matrices are stored in - the form of e.g., `torch.Tensor` or `np.ndarray`. - - We store the affine as its own element, so that this can be updated by - transforms. All other metadata that we don't plan on touching until we - need to save the image to file lives in `meta`. + Metadata is stored in the form of a dictionary. Nested, an affine matrix will be + stored. This should be the same type as the original data (e.g., `torch.Tensor` or + `np.ndarray`). Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) aside from the extended meta functionality. @@ -114,33 +110,6 @@ class MetaObj: """ _meta: dict - _affine: Any - - def _set_initial_val(self, attribute: str, input_arg: Any, input_obj: Any, default_fn: Callable) -> None: - """ - Set the initial value of an attribute (e.g., `meta` or `affine`). - First, try to set `attribute` using `input_arg`. But if `input_arg` is `None`, - then we try to copy the value from `input_obj`. But if value is also `None`, - then we finally fall back on using the `default_fn`. - - Args: - attribute: string corresponding to attribute we want to set (e.g., `meta` or - `affine`). - input_arg: the value we would like `attribute` to take or `None` if not - given. - input_obj: if `input_arg` is `None`, try to copy `attribute` from - `input_obj`, if it is present and not `None`. - default_fn: function to be used if all previous arguments return `None`. - Default metadata might be empty dictionary so could be as simple as - `lambda: {}`. - Returns: - Returns `None`, but `self` should have the updated `attribute`. - """ - if input_arg is None: - input_arg = getattr(input_obj, attribute, None) - if input_arg is None: - input_arg = default_fn(self) - setattr(self, attribute, input_arg) @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: @@ -206,10 +175,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: """ id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in - attributes = ("affine", "meta") - default_fns: tuple[Callable, ...] 
= (self.get_default_affine, self.get_default_meta) - for attribute, default_fn in zip(attributes, default_fns): - self._copy_attr(attribute, input_objs, default_fn, deep_copy) + self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) def get_default_meta(self) -> dict: """Get the default meta. @@ -231,8 +197,6 @@ def __repr__(self) -> str: """String representation of class.""" out: str = super().__repr__() - out += f"\nAffine\n{self.affine}" - out += "\nMetaData\n" if self.meta is not None: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) @@ -244,12 +208,12 @@ def __repr__(self) -> str: @property def affine(self) -> Any: """Get the affine.""" - return self._affine + return self.meta.get("affine", None) @affine.setter def affine(self, d: Any) -> None: """Set the affine.""" - self._affine = d + self.meta["affine"] = d @property def meta(self) -> dict: diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index c862fe3394..a103341931 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -11,6 +11,7 @@ from __future__ import annotations +import warnings from typing import Callable import torch @@ -26,12 +27,8 @@ class MetaTensor(MetaObj, torch.Tensor): Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for meta data. - Metadata is stored in the form of a dictionary. Affine matrices are stored in - the form of `torch.Tensor`. - - We store the affine as its own element, so that this can be updated by - transforms. All other metadata that we don't plan on touching until we - need to save the image to file lives in `meta`. + Metadata is stored in the form of a dictionary. Nested, an affine matrix will be + stored. This should be in the form of `torch.Tensor`. Behavior should be the same as `torch.Tensor` aside from the extended meta functionality. @@ -47,9 +44,9 @@ class MetaTensor(MetaObj, torch.Tensor): from monai.data import MetaTensor t = torch.tensor([1,2,3]) - meta = {"some": "info"} affine = torch.eye(4) - m = MetaTensor(t, meta=meta, affine=affine) + meta = {"some": "info"} + m = MetaTensor(t, affine=affine, meta=meta) m2 = m+m assert isinstance(m2, MetaTensor) assert m2.meta == meta @@ -58,6 +55,8 @@ class MetaTensor(MetaObj, torch.Tensor): - Depending on your version of pytorch, `torch.jit.trace(net, im)` may or may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor)`. + - A warning will be raised if in the constructor `affine` is not `None` and + `meta` already contains the key `affine`. """ @staticmethod @@ -65,10 +64,30 @@ def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: - """If `affine` is given, use it. Else, if `affine` exists in the input tensor, use it. Else, use - the default value. The same is true for `meta`.""" - self._set_initial_val("affine", affine, x, self.get_default_affine) - self._set_initial_val("meta", meta, x, self.get_default_meta) + """ + If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. + Else, use the default value. Similar for the affin, except this could come from + four places. + Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. 
+ """ + # set meta + if meta is not None: + self.meta = meta + elif isinstance(x, MetaObj): + self.meta = x.meta + else: + self.meta = self.get_default_meta() + # set the affine + if affine is not None: + if "affine" in self.meta: + warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.") + self.affine = affine + elif "affine" in self.meta: + pass # nothing to do + elif isinstance(x, MetaObj): + self.affine = x.affine + else: + self.affine = self.get_default_affine() self.affine = self.affine.to(self.device) def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: @@ -110,12 +129,13 @@ def as_dict(self, key: str) -> dict: key: Base key to store main data. The key for the metadata will be determined using `PostFix.meta`. - Return: A dictionary consisting of two keys, the main data (stored under `key`) and - the metadata. The affine will be stored with the metadata for backwards - compatibility. + the metadata. """ - meta = self.meta - meta["affine"] = self.affine return {key: self.as_tensor(), PostFix.meta(key): self.meta} + + @MetaObj.affine.setter #  type: ignore + def affine(self, d: torch.Tensor) -> None: + """Set the affine.""" + self.meta["affine"] = d.to(self.device) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 7fc0fddbd2..7dc633eb66 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -86,7 +86,11 @@ def check( # check meta and affine are equal and affine is on correct device if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: - self.assertEqual(out.meta, orig.meta) + orig_meta_no_affine = deepcopy(orig.meta) + del orig_meta_no_affine["affine"] + out_meta_no_affine = deepcopy(out.meta) + del out_meta_no_affine["affine"] + self.assertEqual(orig_meta_no_affine, out_meta_no_affine) assert_allclose(out.affine, orig.affine) self.assertTrue(str(device) in str(out.affine.device)) if check_ids: @@ -113,8 +117,16 @@ def test_as_dict(self): @parameterized.expand(TESTS) def test_constructor(self, device, dtype): m, t = self.get_im(device=device, dtype=dtype) - m2 = MetaTensor(t.clone(), m.affine, m.meta) + # construct from pre-existing + m1 = MetaTensor(m.clone()) + self.check(m, m1, ids=False, meta=False) + # meta already has affine + m2 = MetaTensor(t.clone(), meta=m.meta) self.check(m, m2, ids=False, meta=False) + # meta dosen't have affine + affine = m.meta.pop("affine") + m3 = MetaTensor(t.clone(), affine=affine, meta=m.meta) + self.check(m, m3, ids=False, meta=False) @parameterized.expand(TESTS) @skip_if_no_cuda @@ -135,6 +147,12 @@ def test_to_cuda(self, device, dtype): m = m.to(device="cpu") self.check(m, orig, ids=False, device="cpu") + @skip_if_no_cuda + def test_affine_device(self): + m, _ = self.get_im() # device="cuda") + m.affine = torch.eye(4) + self.assertEqual(m.device, m.affine.device) + @parameterized.expand(TESTS) def test_copy(self, device, dtype): m, _ = self.get_im(device=device, dtype=dtype) @@ -157,7 +175,7 @@ def test_add(self, device, dtype): self.check(torch.add(input=m1, other=m2), t1 + t2, ids=False) self.check(torch.add(m1, other=m2), t1 + t2, ids=False) m3 = deepcopy(m2) - t3 = deepcopy(m2) + t3 = deepcopy(t2) m3 += 3 t3 += 3 self.check(m3, t3, ids=False) @@ -179,7 +197,9 @@ def test_stack(self, device, dtype): ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] stacked = torch.stack(ims) self.assertIsInstance(stacked, MetaTensor) - assert_allclose(stacked.affine, 
ims[0].affine) + orig_affine = ims[0].meta.pop("affine") + stacked_affine = stacked.meta.pop("affine") + assert_allclose(orig_affine, stacked_affine) self.assertEqual(stacked.meta, ims[0].meta) def test_get_set_meta_fns(self): From b25440ce04601aeeb69bb12c6025a40620489b87 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Fri, 8 Apr 2022 20:30:51 +0100 Subject: [PATCH 17/26] move affine in meta Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 25 ++----------------------- monai/data/meta_tensor.py | 9 +++++++-- 2 files changed, 9 insertions(+), 25 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index ff77245dae..84dbe25a69 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -97,9 +97,7 @@ class MetaObj: This allows for subclassing `torch.Tensor` and `np.ndarray` through multiple inheritance. - Metadata is stored in the form of a dictionary. Nested, an affine matrix will be - stored. This should be the same type as the original data (e.g., `torch.Tensor` or - `np.ndarray`). + Metadata is stored in the form of a dictionary. Behavior should be the same as extended class (e.g., `torch.Tensor` or `np.ndarray`) aside from the extended meta functionality. @@ -141,8 +139,7 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call We also perform a deep copy of the data if desired. Args: - attribute: string corresponding to attribute to be copied (e.g., `meta` or - `affine`). + attribute: string corresponding to attribute to be copied (e.g., `meta`). input_objs: List of `MetaObj`. We'll copy the attribute from the first one that contains that particular attribute. default_fn: If none of `input_objs` have the attribute that we're @@ -185,14 +182,6 @@ def get_default_meta(self) -> dict: """ return {} - def get_default_affine(self) -> Any: - """Get the default affine. - - Returns: - default affine. 
- """ - raise NotImplementedError() - def __repr__(self) -> str: """String representation of class.""" out: str = super().__repr__() @@ -205,16 +194,6 @@ def __repr__(self) -> str: return out - @property - def affine(self) -> Any: - """Get the affine.""" - return self.meta.get("affine", None) - - @affine.setter - def affine(self, d: Any) -> None: - """Set the affine.""" - self.meta["affine"] = d - @property def meta(self) -> dict: """Get the meta.""" diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index a103341931..d82d0d259e 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -135,7 +135,12 @@ def as_dict(self, key: str) -> dict: """ return {key: self.as_tensor(), PostFix.meta(key): self.meta} - @MetaObj.affine.setter #  type: ignore + @property + def affine(self) -> torch.Tensor: + """Get the affine.""" + return self.meta["affine"] + + @affine.setter def affine(self, d: torch.Tensor) -> None: """Set the affine.""" - self.meta["affine"] = d.to(self.device) + self.meta["affine"] = d From 3e2eb023fce1fcf14ee61ea1df9126a42c92c562 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 09:51:03 +0100 Subject: [PATCH 18/26] flake8 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index d82d0d259e..c2401b6f16 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -84,7 +84,7 @@ def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = No self.affine = affine elif "affine" in self.meta: pass # nothing to do - elif isinstance(x, MetaObj): + elif isinstance(x, MetaTensor): self.affine = x.affine else: self.affine = self.get_default_affine() From 32cc5b397f566a01b654b81be57416d1c4380344 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 09:58:56 +0100 Subject: [PATCH 19/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index c2401b6f16..f5c1288479 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -109,6 +109,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: return ret.as_tensor() meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) ret._copy_meta(meta_args) + ret.affine = ret.affine.to(ret.device) return ret def get_default_affine(self) -> torch.Tensor: @@ -138,7 +139,7 @@ def as_dict(self, key: str) -> dict: @property def affine(self) -> torch.Tensor: """Get the affine.""" - return self.meta["affine"] + return self.meta["affine"] # type: ignore @affine.setter def affine(self, d: torch.Tensor) -> None: From 7fe23fd27da95361e606899283d5c03bacc62aa0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 10:36:32 +0100 Subject: [PATCH 20/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 4 +++- tests/test_meta_tensor.py | 8 +++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index f5c1288479..b129ff839b 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -54,7 +54,9 
@@ class MetaTensor(MetaObj, torch.Tensor): Notes: - Depending on your version of pytorch, `torch.jit.trace(net, im)` may or may not work if `im` is of type `MetaTensor`. This can be resolved with - `torch.jit.trace(net, im.as_tensor)`. + `torch.jit.trace(net, im.as_tensor())`. + - Depending on your version of pytorch `torch.save(m, fname); m=torch.load(fname)` + may return a `torch.Tensor` instead of MetaTensor`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. """ diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 7dc633eb66..3a006dc34d 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -24,8 +24,11 @@ from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor from monai.utils.enums import PostFix +from monai.utils.module import get_torch_version_tuple from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda +PT_VER_MAJ, PT_VER_MIN = get_torch_version_tuple() + DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] for _device in TEST_DEVICES: @@ -227,7 +230,7 @@ def test_torchscript(self, device): traced_fn = torch.jit.load(fname) out = traced_fn(im) self.assertIsInstance(out, torch.Tensor) - if not isinstance(out, MetaTensor): + if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7: warnings.warn( "When calling `nn.Module(MetaTensor) on a module traced with " "`torch.jit.trace`, your version of pytorch returns a " @@ -243,6 +246,9 @@ def test_pickling(self): fname = os.path.join(tmp_dir, "im.pt") torch.save(m, fname) m2 = torch.load(fname) + if not isinstance(m2, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7: + warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.") + m = m.as_tensor() self.check(m2, m, ids=False) @skip_if_no_cuda From 7f43b00b1a31634f545ad609e249b351c687a354 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 10:40:42 +0100 Subject: [PATCH 21/26] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 4 ++-- tests/test_meta_tensor.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index b129ff839b..c8194bd076 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -52,10 +52,10 @@ class MetaTensor(MetaObj, torch.Tensor): assert m2.meta == meta Notes: - - Depending on your version of pytorch, `torch.jit.trace(net, im)` may or may + - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - - Depending on your version of pytorch `torch.save(m, fname); m=torch.load(fname)` + - For older versions of pytorch (<=1.7), `torch.save(m, fname); m=torch.load(fname)` may return a `torch.Tensor` instead of MetaTensor`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. 
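A minimal sketch of the tracing workaround described in the notes above, assuming a MONAI build that includes the `MetaTensor` added in this series; `net` and `im` are illustrative placeholders rather than objects defined in the patch:

.. code-block:: python

    import torch

    from monai.data import MetaTensor

    net = torch.nn.Conv2d(1, 5, 3)
    im = MetaTensor(torch.rand(1, 1, 10, 10), affine=torch.eye(4), meta={"some": "info"})

    # trace on the plain tensor view; tracing directly on a MetaTensor may fail
    # with older versions of pytorch
    traced = torch.jit.trace(net, im.as_tensor())

    # the traced module still accepts a MetaTensor, but depending on the pytorch
    # version the output may come back as a plain torch.Tensor
    out = traced(im)
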
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 3a006dc34d..0448e403db 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -230,7 +230,7 @@ def test_torchscript(self, device): traced_fn = torch.jit.load(fname) out = traced_fn(im) self.assertIsInstance(out, torch.Tensor) - if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 7: + if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 8: warnings.warn( "When calling `nn.Module(MetaTensor) on a module traced with " "`torch.jit.trace`, your version of pytorch returns a " From 840e7dfe6c7ec0dd5104df01d7003a6a1ddc548e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 10:59:20 +0100 Subject: [PATCH 22/26] pytorch min version 1.7 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- .github/workflows/cron.yml | 4 +--- .github/workflows/pythonapp-min.yml | 4 +--- .github/workflows/pythonapp.yml | 2 +- CHANGELOG.md | 2 +- docs/requirements.txt | 4 ++-- environment-dev.yml | 2 +- monai/apps/mmars/mmars.py | 2 +- monai/data/torchscript_utils.py | 4 ++-- pyproject.toml | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 11 files changed, 13 insertions(+), 17 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index f76fe01699..7a2e0b5bc0 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -15,7 +15,7 @@ jobs: runs-on: [self-hosted, linux, x64, common] strategy: matrix: - pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest] + pytorch-version: [1.7.1, 1.8.1, 1.9.1, latest] steps: - uses: actions/checkout@v2 - name: Install the dependencies @@ -25,8 +25,6 @@ jobs: python -m pip uninstall -y torch torchvision if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch torchvision - elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then - python -m pip install torch==1.6.0 torchvision==0.7.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 torchvision==0.8.2 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index c3294c2b2a..0c9e63c7c4 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -119,7 +119,7 @@ jobs: strategy: fail-fast: false matrix: - pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, 1.10.1, latest] + pytorch-version: [1.7.1, 1.8.1, 1.9.1, 1.10.1, latest] timeout-minutes: 40 steps: - uses: actions/checkout@v2 @@ -148,8 +148,6 @@ jobs: # min. 
requirements if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch - elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then - python -m pip install torch==1.6.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index cf251c2293..38c96b3b0d 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -137,7 +137,7 @@ jobs: # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml - python -m pip install torch>=1.6 torchvision + python -m pip install torch>=1.7 torchvision - name: Check packages run: | pip uninstall monai diff --git a/CHANGELOG.md b/CHANGELOG.md index 3f55ded72f..20a1583274 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -403,7 +403,7 @@ the postprocessing steps should be used before calling the metrics methods * Progress bar with tqdm ### Changed -* Now fully compatible with PyTorch 1.6 +* Now fully compatible with PyTorch 1.7 * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.08-py3` from `nvcr.io/nvidia/pytorch:20.03-py3` * Code contributions now require signing off on the [Developer Certificate of Origin (DCO)](https://developercertificate.org/) * Major work in type hinting finished diff --git a/docs/requirements.txt b/docs/requirements.txt index f9749e9e36..71d899f748 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ --f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl -torch>=1.6 +-f https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp37-cp37m-linux_x86_64.whl +torch>=1.7 pytorch-ignite==0.4.8 numpy>=1.17 itk>=5.2 diff --git a/environment-dev.yml b/environment-dev.yml index a361262930..18ed4e1aa8 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,7 +5,7 @@ channels: - conda-forge dependencies: - numpy>=1.17 - - pytorch>=1.6 + - pytorch>=1.7 - coverage>=5.5 - parameterized - setuptools>=50.3.0,!=60.0.0 diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index f389f7ad33..53d3a414ad 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -135,7 +135,7 @@ def download_mmar( if has_home: mmar_dir = Path(get_dir()) / "mmars" else: - raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.6+ ?") + raise ValueError("mmar_dir=None, but no suitable default directory computed. 
Upgrade Pytorch to 1.7+ ?") mmar_dir = Path(mmar_dir) if api: model_dict = _get_all_ngc_models(item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}") diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 61477e8ca9..77c8ab35f1 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -80,7 +80,7 @@ def save_net_with_metadata( json_data = json.dumps(metadict) - # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + # Pytorch>1.7 can use dictionaries directly, otherwise need to use special map object if pytorch_after(1, 7): extra_files = {METADATA_FILENAME: json_data.encode()} @@ -123,7 +123,7 @@ def load_net_with_metadata( Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ - # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object + # Pytorch>1.7 can use dictionaries directly, otherwise need to use special map object if pytorch_after(1, 7): extra_files = {f: "" for f in more_extra_files} extra_files[METADATA_FILENAME] = "" diff --git a/pyproject.toml b/pyproject.toml index 03e9f49ab5..eea4ebf9b1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ requires = [ "wheel", "setuptools", - "torch>=1.6", + "torch>=1.7", "ninja", ] diff --git a/requirements.txt b/requirements.txt index e4ea34b5d4..14eb2b30e9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.6 +torch>=1.7 numpy>=1.17 diff --git a/setup.cfg b/setup.cfg index a7d597d6bd..12f974ca6d 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ setup_requires = torch ninja install_requires = - torch>=1.6 + torch>=1.7 numpy>=1.17 [options.extras_require] From e2256da517cd41c2a657b3421c15ff35f240d4d2 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 10:59:20 +0100 Subject: [PATCH 23/26] Revert "pytorch min version 1.7" This reverts commit 840e7dfe6c7ec0dd5104df01d7003a6a1ddc548e. 
Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- .github/workflows/cron.yml | 4 +++- .github/workflows/pythonapp-min.yml | 4 +++- .github/workflows/pythonapp.yml | 2 +- CHANGELOG.md | 2 +- docs/requirements.txt | 4 ++-- environment-dev.yml | 2 +- monai/apps/mmars/mmars.py | 2 +- monai/data/torchscript_utils.py | 4 ++-- pyproject.toml | 2 +- requirements.txt | 2 +- setup.cfg | 2 +- 11 files changed, 17 insertions(+), 13 deletions(-) diff --git a/.github/workflows/cron.yml b/.github/workflows/cron.yml index 7a2e0b5bc0..f76fe01699 100644 --- a/.github/workflows/cron.yml +++ b/.github/workflows/cron.yml @@ -15,7 +15,7 @@ jobs: runs-on: [self-hosted, linux, x64, common] strategy: matrix: - pytorch-version: [1.7.1, 1.8.1, 1.9.1, latest] + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, latest] steps: - uses: actions/checkout@v2 - name: Install the dependencies @@ -25,6 +25,8 @@ jobs: python -m pip uninstall -y torch torchvision if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch torchvision + elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then + python -m pip install torch==1.6.0 torchvision==0.7.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 torchvision==0.8.2 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then diff --git a/.github/workflows/pythonapp-min.yml b/.github/workflows/pythonapp-min.yml index 0c9e63c7c4..c3294c2b2a 100644 --- a/.github/workflows/pythonapp-min.yml +++ b/.github/workflows/pythonapp-min.yml @@ -119,7 +119,7 @@ jobs: strategy: fail-fast: false matrix: - pytorch-version: [1.7.1, 1.8.1, 1.9.1, 1.10.1, latest] + pytorch-version: [1.6.0, 1.7.1, 1.8.1, 1.9.1, 1.10.1, latest] timeout-minutes: 40 steps: - uses: actions/checkout@v2 @@ -148,6 +148,8 @@ jobs: # min. 
requirements if [ ${{ matrix.pytorch-version }} == "latest" ]; then python -m pip install torch + elif [ ${{ matrix.pytorch-version }} == "1.6.0" ]; then + python -m pip install torch==1.6.0 elif [ ${{ matrix.pytorch-version }} == "1.7.1" ]; then python -m pip install torch==1.7.1 elif [ ${{ matrix.pytorch-version }} == "1.8.1" ]; then diff --git a/.github/workflows/pythonapp.yml b/.github/workflows/pythonapp.yml index 38c96b3b0d..cf251c2293 100644 --- a/.github/workflows/pythonapp.yml +++ b/.github/workflows/pythonapp.yml @@ -137,7 +137,7 @@ jobs: # install the latest pytorch for testing # however, "pip install monai*.tar.gz" will build cpp/cuda with an isolated # fresh torch installation according to pyproject.toml - python -m pip install torch>=1.7 torchvision + python -m pip install torch>=1.6 torchvision - name: Check packages run: | pip uninstall monai diff --git a/CHANGELOG.md b/CHANGELOG.md index 20a1583274..3f55ded72f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -403,7 +403,7 @@ the postprocessing steps should be used before calling the metrics methods * Progress bar with tqdm ### Changed -* Now fully compatible with PyTorch 1.7 +* Now fully compatible with PyTorch 1.6 * Base Docker image upgraded to `nvcr.io/nvidia/pytorch:20.08-py3` from `nvcr.io/nvidia/pytorch:20.03-py3` * Code contributions now require signing off on the [Developer Certificate of Origin (DCO)](https://developercertificate.org/) * Major work in type hinting finished diff --git a/docs/requirements.txt b/docs/requirements.txt index 71d899f748..f9749e9e36 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,5 @@ --f https://download.pytorch.org/whl/cpu/torch-1.7.0%2Bcpu-cp37-cp37m-linux_x86_64.whl -torch>=1.7 +-f https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl +torch>=1.6 pytorch-ignite==0.4.8 numpy>=1.17 itk>=5.2 diff --git a/environment-dev.yml b/environment-dev.yml index 18ed4e1aa8..a361262930 100644 --- a/environment-dev.yml +++ b/environment-dev.yml @@ -5,7 +5,7 @@ channels: - conda-forge dependencies: - numpy>=1.17 - - pytorch>=1.7 + - pytorch>=1.6 - coverage>=5.5 - parameterized - setuptools>=50.3.0,!=60.0.0 diff --git a/monai/apps/mmars/mmars.py b/monai/apps/mmars/mmars.py index 53d3a414ad..f389f7ad33 100644 --- a/monai/apps/mmars/mmars.py +++ b/monai/apps/mmars/mmars.py @@ -135,7 +135,7 @@ def download_mmar( if has_home: mmar_dir = Path(get_dir()) / "mmars" else: - raise ValueError("mmar_dir=None, but no suitable default directory computed. Upgrade Pytorch to 1.7+ ?") + raise ValueError("mmar_dir=None, but no suitable default directory computed. 
Upgrade Pytorch to 1.6+ ?") mmar_dir = Path(mmar_dir) if api: model_dict = _get_all_ngc_models(item.get(Keys.NAME, f"{item}") if isinstance(item, Mapping) else f"{item}") diff --git a/monai/data/torchscript_utils.py b/monai/data/torchscript_utils.py index 77c8ab35f1..61477e8ca9 100644 --- a/monai/data/torchscript_utils.py +++ b/monai/data/torchscript_utils.py @@ -80,7 +80,7 @@ def save_net_with_metadata( json_data = json.dumps(metadict) - # Pytorch>1.7 can use dictionaries directly, otherwise need to use special map object + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object if pytorch_after(1, 7): extra_files = {METADATA_FILENAME: json_data.encode()} @@ -123,7 +123,7 @@ def load_net_with_metadata( Returns: Triple containing loaded object, metadata dict, and extra files dict containing other file data if present """ - # Pytorch>1.7 can use dictionaries directly, otherwise need to use special map object + # Pytorch>1.6 can use dictionaries directly, otherwise need to use special map object if pytorch_after(1, 7): extra_files = {f: "" for f in more_extra_files} extra_files[METADATA_FILENAME] = "" diff --git a/pyproject.toml b/pyproject.toml index eea4ebf9b1..03e9f49ab5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,7 +2,7 @@ requires = [ "wheel", "setuptools", - "torch>=1.7", + "torch>=1.6", "ninja", ] diff --git a/requirements.txt b/requirements.txt index 14eb2b30e9..e4ea34b5d4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -torch>=1.7 +torch>=1.6 numpy>=1.17 diff --git a/setup.cfg b/setup.cfg index 12f974ca6d..a7d597d6bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -24,7 +24,7 @@ setup_requires = torch ninja install_requires = - torch>=1.7 + torch>=1.6 numpy>=1.17 [options.extras_require] From e3e567c427ac074aeab00973bf0f8d84113a17da Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Mon, 11 Apr 2022 15:54:19 +0100 Subject: [PATCH 24/26] torch 1.9 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 0448e403db..1721e7d2b9 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -230,7 +230,7 @@ def test_torchscript(self, device): traced_fn = torch.jit.load(fname) out = traced_fn(im) self.assertIsInstance(out, torch.Tensor) - if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 8: + if not isinstance(out, MetaTensor) and PT_VER_MAJ == 1 and PT_VER_MIN <= 9: warnings.warn( "When calling `nn.Module(MetaTensor) on a module traced with " "`torch.jit.trace`, your version of pytorch returns a " From 56bc6d5d42ee58af54342bb3e8d427d07b08dfb1 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Tue, 12 Apr 2022 16:05:20 +0100 Subject: [PATCH 25/26] remove __init__, correct docstring Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index c8194bd076..465fc9d93b 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -44,53 +44,52 @@ class MetaTensor(MetaObj, torch.Tensor): from monai.data import MetaTensor t = torch.tensor([1,2,3]) - affine = torch.eye(4) + affine = torch.eye(4) * 100 meta = {"some": "info"} m = MetaTensor(t, 
affine=affine, meta=meta) m2 = m+m assert isinstance(m2, MetaTensor) - assert m2.meta == meta + assert m2.meta["some"] == "info" + assert m2.affine = affine Notes: - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. - - For older versions of pytorch (<=1.7), `torch.save(m, fname); m=torch.load(fname)` - may return a `torch.Tensor` instead of MetaTensor`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. """ @staticmethod def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: - return torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore - - def __init__(self, x, affine: torch.Tensor | None = None, meta: dict | None = None) -> None: """ If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. Else, use the default value. Similar for the affin, except this could come from four places. Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. """ + out: MetaTensor = torch.as_tensor(x, *args, **kwargs).as_subclass(cls) # type: ignore # set meta if meta is not None: - self.meta = meta + out.meta = meta elif isinstance(x, MetaObj): - self.meta = x.meta + out.meta = x.meta else: - self.meta = self.get_default_meta() + out.meta = out.get_default_meta() # set the affine if affine is not None: - if "affine" in self.meta: + if "affine" in out.meta: warnings.warn("Setting affine, but the applied meta contains an affine. " "This will be overwritten.") - self.affine = affine - elif "affine" in self.meta: + out.affine = affine + elif "affine" in out.meta: pass # nothing to do elif isinstance(x, MetaTensor): - self.affine = x.affine + out.affine = x.affine else: - self.affine = self.get_default_affine() - self.affine = self.affine.to(self.device) + out.affine = out.get_default_affine() + out.affine = out.affine.to(out.device) + + return out def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Callable, deep_copy: bool) -> None: super()._copy_attr(attribute, input_objs, default_fn, deep_copy) From 09eb602a1f90d0697f31b7f619662c50c433ce2d Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Tue, 12 Apr 2022 16:58:45 +0100 Subject: [PATCH 26/26] adds docs Signed-off-by: Wenqi Li --- docs/source/data.rst | 11 +++++++++++ monai/data/meta_obj.py | 4 +++- monai/data/meta_tensor.py | 16 ++++++++-------- 3 files changed, 22 insertions(+), 9 deletions(-) diff --git a/docs/source/data.rst b/docs/source/data.rst index e76eb53f39..0910001783 100644 --- a/docs/source/data.rst +++ b/docs/source/data.rst @@ -271,3 +271,14 @@ ThreadDataLoader TestTimeAugmentation ~~~~~~~~~~~~~~~~~~~~ .. autoclass:: monai.data.TestTimeAugmentation + + +Meta Object +----------- +.. automodule:: monai.data.meta_obj + :members: + +MetaTensor +---------- +.. autoclass:: monai.data.MetaTensor + :members: diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 84dbe25a69..d60ec6e473 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -103,8 +103,10 @@ class MetaObj: aside from the extended meta functionality. Copying of information: + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the - first instance of `MetaObj`. + first instance of `MetaObj`. 
+ """ _meta: dict diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 465fc9d93b..c5b95f8d08 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -24,8 +24,7 @@ class MetaTensor(MetaObj, torch.Tensor): """ - Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for meta - data. + Class that inherits from both `torch.Tensor` and `MetaObj`, adding support for metadata. Metadata is stored in the form of a dictionary. Nested, an affine matrix will be stored. This should be in the form of `torch.Tensor`. @@ -34,8 +33,9 @@ class MetaTensor(MetaObj, torch.Tensor): meta functionality. Copying of information: + * For `c = a + b`, then auxiliary data (e.g., metadata) will be copied from the - first instance of `MetaTensor`. + first instance of `MetaTensor`. Example: .. code-block:: python @@ -50,21 +50,21 @@ class MetaTensor(MetaObj, torch.Tensor): m2 = m+m assert isinstance(m2, MetaTensor) assert m2.meta["some"] == "info" - assert m2.affine = affine + assert m2.affine == affine Notes: - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may - not work if `im` is of type `MetaTensor`. This can be resolved with - `torch.jit.trace(net, im.as_tensor())`. + not work if `im` is of type `MetaTensor`. This can be resolved with + `torch.jit.trace(net, im.as_tensor())`. - A warning will be raised if in the constructor `affine` is not `None` and - `meta` already contains the key `affine`. + `meta` already contains the key `affine`. """ @staticmethod def __new__(cls, x, affine: torch.Tensor | None = None, meta: dict | None = None, *args, **kwargs) -> MetaTensor: """ If `meta` is given, use it. Else, if `meta` exists in the input tensor, use it. - Else, use the default value. Similar for the affin, except this could come from + Else, use the default value. Similar for the affine, except this could come from four places. Priority: `affine`, `meta["affine"]`, `x.affine`, `get_default_affine`. """
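As a rough illustration of the constructor priority described in the docstring above (a sketch assuming the behaviour introduced in this patch series; the tensors and metadata values are invented for the example):

.. code-block:: python

    import torch

    from monai.data import MetaTensor

    t = torch.tensor([1.0, 2.0, 3.0])
    aff = torch.eye(4) * 2

    # 1. an explicit `affine` argument takes priority (and a warning is expected
    #    if `meta` already carries its own "affine" entry)
    m1 = MetaTensor(t, affine=aff, meta={"some": "info"})

    # 2. otherwise an affine already stored in `meta` is kept as-is
    m2 = MetaTensor(t, meta={"affine": torch.eye(4) * 3})

    # 3. otherwise, constructing from an existing MetaTensor reuses its affine
    m3 = MetaTensor(m1)

    # 4. failing all of the above, the default affine is used
    m4 = MetaTensor(t)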