From 26224025458b6fac52f72efc85e25340cb85abe8 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 14:32:00 +0100 Subject: [PATCH 01/11] collate , decollate, dataset, dataloader, out= Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 3 ++ monai/data/utils.py | 18 ++++++++-- tests/test_meta_tensor.py | 73 +++++++++++++++++++++++++++++++++++---- 3 files changed, 85 insertions(+), 9 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 30270d89e2..00943ba433 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -102,6 +102,9 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: if kwargs is None: kwargs = {} ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + # if `out` has been used as argument, metadata is not copied, nothing to do. + if "out" in kwargs: + return ret # e.g., __repr__ returns a string if not isinstance(ret, torch.Tensor): return ret diff --git a/monai/data/utils.py b/monai/data/utils.py index 495daf15e2..5a3c1235a2 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,6 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike +from monai.data.meta_tensor import MetaTensor from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -346,9 +347,15 @@ def list_data_collate(batch: Sequence): ret = {} for k in elem: key = k - ret[key] = default_collate([d[key] for d in data]) - return ret - return default_collate(data) + data_for_batch = [d[key] for d in data] + ret[key] = default_collate(data_for_batch) + if isinstance(ret[key], MetaTensor) and all(isinstance(d, MetaTensor) for d in data_for_batch): + ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + else: + ret = default_collate(data) + if isinstance(ret, MetaTensor) and all(isinstance(d, MetaTensor) for d in data): + ret.meta = list_data_collate([i.meta for i in data]) + return ret except RuntimeError as re: re_str = str(re) if "equal size" in re_str: @@ -466,6 +473,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) + # if of type MetaTensor, decollate the metadata and affines + if isinstance(batch, MetaTensor): + metas = decollate_batch(batch.meta) + for i in range(len(out_list)): + out_list[i].meta = metas[i] if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index c18ef08b85..90bfb8d835 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -21,8 +21,10 @@ import torch from parameterized import parameterized +from monai.data import DataLoader, Dataset from monai.data.meta_obj import get_track_meta, get_track_transforms, set_track_meta, set_track_transforms from monai.data.meta_tensor import MetaTensor +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda @@ -261,12 +263,71 @@ def test_amp(self): im_conv2 = conv(im) self.check(im_conv2, im_conv, ids=False, rtol=1e-4, atol=1e-3) - # TODO - # collate - # decollate - # dataset - # dataloader - # matplotlib + def test_out(self): + """Test when `out` is given as an argument.""" + m1, _ = self.get_im() + m1_orig = deepcopy(m1) + m2, _ = self.get_im() + m3, _ = self.get_im() + torch.add(m2, m3, out=m1) + m1_add = m2 + m3 + + assert_allclose(m1, m1_add) + aff1, aff1_orig = m1.affine, m1_orig.affine + assert_allclose(aff1, aff1_orig) + meta1 = {k: v for k, v in m1.meta.items() if k != "affine"} + meta1_orig = {k: v for k, v in m1_orig.meta.items() if k != "affine"} + self.assertEqual(meta1, meta1_orig) + + @parameterized.expand(TESTS) + def test_collate(self, device, dtype): + numel = 3 + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(numel)] + collated = list_data_collate(ims) + # tensor + self.assertIsInstance(collated, MetaTensor) + expected_shape = (numel,) + tuple(ims[0].shape) + self.assertTupleEqual(tuple(collated.shape), expected_shape) + for i, im in enumerate(ims): + self.check(im, ims[i], ids=True) + # affine + self.assertIsInstance(collated.affine, torch.Tensor) + expected_shape = (numel,) + tuple(ims[0].affine.shape) + self.assertTupleEqual(tuple(collated.affine.shape), expected_shape) + + @parameterized.expand(TESTS) + def test_dataset(self, device, dtype): + ims = [self.get_im(device=device, dtype=dtype)[0] for _ in range(4)] + ds = Dataset(ims) + for i, im in enumerate(ds): + self.check(im, ims[i], ids=True) + + @parameterized.expand(DTYPES) + def test_dataloader(self, dtype): + batch_size = 5 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + expected_im_shape = (batch_size,) + tuple(ims[0].shape) + expected_affine_shape = (batch_size,) + tuple(ims[0].affine.shape) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + for batch in dl: + self.assertIsInstance(batch, MetaTensor) + self.assertTupleEqual(tuple(batch.shape), expected_im_shape) + self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + + @parameterized.expand(DTYPES) + def test_decollate(self, dtype): + batch_size = 3 + ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] + ds = Dataset(ims) + dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) + batch = next(iter(dl)) + decollated = decollate_batch(batch) + self.assertIsInstance(decollated, list) + self.assertEqual(len(decollated), batch_size) + for elem, im in zip(decollated, ims): + self.assertIsInstance(elem, MetaTensor) + self.check(elem, im, ids=False) if __name__ == "__main__": From 19e68c9ccfa2da781464f7247ae0b406047f0bb9 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 15:37:39 +0100 Subject: [PATCH 02/11] mypy Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/data/utils.py b/monai/data/utils.py index 5a3c1235a2..18a467803b 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -473,11 +473,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) - # if of type MetaTensor, decollate the metadata and affines + # if of type MetaTensor, decollate the metadata if isinstance(batch, MetaTensor): metas = decollate_batch(batch.meta) for i in range(len(out_list)): - out_list[i].meta = metas[i] + out_list[i].meta = metas[i] # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) From d017918de51ff8e3ab5048d534a14fb535ff797e Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 16:27:18 +0100 Subject: [PATCH 03/11] skip decollation for pytorch 1.7 Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 90bfb8d835..651c9e0d22 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -27,7 +27,7 @@ from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix from monai.utils.module import pytorch_after -from tests.utils import TEST_DEVICES, assert_allclose, skip_if_no_cuda +from tests.utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32]] TESTS = [] @@ -315,6 +315,7 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + @SkipIfBeforePyTorchVersion((1, 8)) @parameterized.expand(DTYPES) def test_decollate(self, dtype): batch_size = 3 From b36cd10c1c623b926d8e9dd6aca6acf091a00db0 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 16:39:53 +0100 Subject: [PATCH 04/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 651c9e0d22..e14695c3e1 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -315,8 +315,8 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) - @SkipIfBeforePyTorchVersion((1, 8)) @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) def test_decollate(self, dtype): batch_size = 3 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] From a8f0373aa728ec36452610071fc8f668cabf7efb Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 14 Apr 2022 18:48:08 +0100 Subject: [PATCH 05/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index e14695c3e1..8fd31b58b6 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -303,6 +303,7 @@ def test_dataset(self, device, dtype): self.check(im, ims[i], ids=True) @parameterized.expand(DTYPES) + @SkipIfBeforePyTorchVersion((1, 8)) def test_dataloader(self, dtype): batch_size = 5 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] From 12afd4a4d760f8b9376b6bdda3ff90c0f6c4a60a Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:20:55 +0100 Subject: [PATCH 06/11] add batch index testing Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 13 +++++ monai/data/meta_tensor.py | 100 ++++++++++++++++++++++++++++++++---- monai/data/utils.py | 13 +++-- tests/test_meta_tensor.py | 104 +++++++++++++++++++++++++++++++++----- 4 files changed, 201 insertions(+), 29 deletions(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 0e213f130b..00e10ca816 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -111,6 +111,7 @@ class MetaObj: def __init__(self): self._meta: dict = self.get_default_meta() + self._is_batch: bool = False @staticmethod def flatten_meta_objs(args: Sequence[Any]) -> list[MetaObj]: @@ -176,6 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) + self.is_batch = input_objs[0].is_batch def get_default_meta(self) -> dict: """Get the default meta. @@ -194,6 +196,7 @@ def __repr__(self) -> str: out += "".join(f"\t{k}: {v}\n" for k, v in self.meta.items()) else: out += "None" + out += f"\nIs batch?: {self.is_batch}" return out @@ -206,3 +209,13 @@ def meta(self) -> dict: def meta(self, d: dict) -> None: """Set the meta.""" self._meta = d + + @property + def is_batch(self) -> bool: + """Return whether object is part of batch or not.""" + return self._is_batch + + @is_batch.setter + def is_batch(self, val: bool) -> None: + """Set whether object is part of batch or not.""" + self._is_batch = val diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index f1d87bd5f0..3a3fdc47db 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,11 +13,12 @@ import warnings from copy import deepcopy -from typing import Callable +from typing import Callable, Sequence import torch from monai.data.meta_obj import MetaObj, get_track_meta, get_track_transforms +from monai.data.utils import decollate_batch, list_data_collate from monai.utils.enums import PostFix __all__ = ["MetaTensor"] @@ -59,6 +60,14 @@ class MetaTensor(MetaObj, torch.Tensor): `torch.jit.trace(net, im.as_tensor())`. - A warning will be raised if in the constructor `affine` is not `None` and `meta` already contains the key `affine`. + - You can query whether the `MetaTensor` is a batch with the `is_batch` attribute. + - With a batch of data, `batch[0]` will return the 0th image + with the 0th metadata. When the batch dimension is non-singleton, e.g., + `batch[:, 0]`, `batch[..., -1]` and `batch[1:3]`, then all (or a subset in the + last example) of the metadata will be returned, and `is_batch` will return `True`. + - When creating a batch with this class, use `monai.data.DataLoader` as opposed + to `torch.utils.data.DataLoader`, as this will take care of collating the + metadata properly. """ @staticmethod @@ -101,24 +110,93 @@ def _copy_attr(self, attribute: str, input_objs: list[MetaObj], default_fn: Call if isinstance(val, torch.Tensor): setattr(self, attribute, val.to(self.device)) + @staticmethod + def update_meta(rets: Sequence, func, args, kwargs): + """Update the metadata from the output of `__torch_function__`. + The output could be a single object, or a sequence of them. Hence, they get + converted to a sequence if necessary and then processed by iterating across them. + + For each element, if not of type `MetaTensor`, then nothing to do + """ + out = [] + metas = None + for idx, ret in enumerate(rets): + # if not `MetaTensor`, nothing to do. + if not isinstance(ret, MetaTensor): + pass + # if not tracking, convert to `torch.Tensor`. + elif not (get_track_meta() or get_track_transforms()): + ret = ret.as_tensor() + # else, handle the `MetaTensor` metadata. + else: + meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) + ret._copy_meta(meta_args) + + # If we have a batch of data, then we need to be careful if a slice of + # the data is returned. Depending on how the data are indexed, we return + # some or all of the metadata, and the return object may or may not be a + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`.) + if ret.is_batch: + # only decollate metadata once + if metas is None: + metas = decollate_batch(ret.meta) + # if indexing e.g., `batch[0]` + if func == torch.Tensor.__getitem__: + idx = args[1] + if isinstance(idx, Sequence): + idx = idx[0] + # if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the + # first element will be `slice(None, None, None)` and `Ellipsis`, + # respectively. Don't need to do anything with the metadata. + if idx not in (slice(None, None, None), Ellipsis): + meta = metas[idx] + # if using e.g., `batch[0:2]`, then `is_batch` should still be + # `True`. Also re-collate the remaining elements. + if isinstance(meta, list) and len(meta) > 1: + ret.meta = list_data_collate(meta) + # if using e.g., `batch[0]` or `batch[0, 1]`, then return single + # element from batch, and set `is_batch` to `False`. + else: + ret.meta = meta + ret.is_batch = False + # `unbind` is used for `next(iter(batch))`. Also for `decollate_batch`. + # But we only want to split the batch if the `unbind` is along the 0th + # dimension. + elif func == torch.Tensor.unbind: + if len(args) > 1: + dim = args[1] + elif "dim" in kwargs: + dim = kwargs["dim"] + else: + dim = 0 + if dim == 0: + ret.meta = metas[idx] + ret.is_batch = False + + ret.affine = ret.affine.to(ret.device) + out.append(ret) + # if the input was a tuple, then return it as a tuple + return tuple(out) if isinstance(rets, tuple) else out + @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: """Wraps all torch functions.""" if kwargs is None: kwargs = {} - ret: MetaTensor = super().__torch_function__(func, types, args, kwargs) + ret = super().__torch_function__(func, types, args, kwargs) # if `out` has been used as argument, metadata is not copied, nothing to do. if "out" in kwargs: return ret - # e.g., __repr__ returns a string - if not isinstance(ret, torch.Tensor): - return ret - if not (get_track_meta() or get_track_transforms()): - return ret.as_tensor() - meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values())) - ret._copy_meta(meta_args) - ret.affine = ret.affine.to(ret.device) - return ret + # we might have 1 or multiple outputs. Might be MetaTensor, might be something + # else (e.g., `__repr__` returns a string). + # Convert to list (if necessary), process, and at end remove list if one was added. + if not isinstance(ret, Sequence): + ret = [ret] + unpack = True + else: + unpack = False + ret = MetaTensor.update_meta(ret, func, args, kwargs) + return ret[0] if unpack else ret def get_default_affine(self, dtype=torch.float64) -> torch.Tensor: return torch.eye(4, device=self.device, dtype=dtype) diff --git a/monai/data/utils.py b/monai/data/utils.py index 18a467803b..7f67088a06 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -28,7 +28,7 @@ from torch.utils.data._utils.collate import default_collate from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor, PathLike -from monai.data.meta_tensor import MetaTensor +from monai.data.meta_obj import MetaObj from monai.networks.layers.simplelayers import GaussianFilter from monai.utils import ( MAX_SEED, @@ -349,12 +349,14 @@ def list_data_collate(batch: Sequence): key = k data_for_batch = [d[key] for d in data] ret[key] = default_collate(data_for_batch) - if isinstance(ret[key], MetaTensor) and all(isinstance(d, MetaTensor) for d in data_for_batch): + if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch): ret[key].meta = list_data_collate([i.meta for i in data_for_batch]) + ret[key].is_batch = True else: ret = default_collate(data) - if isinstance(ret, MetaTensor) and all(isinstance(d, MetaTensor) for d in data): + if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data): ret.meta = list_data_collate([i.meta for i in data]) + ret.is_batch = True return ret except RuntimeError as re: re_str = str(re) @@ -473,11 +475,12 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): if batch.ndim == 0: return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) - # if of type MetaTensor, decollate the metadata - if isinstance(batch, MetaTensor): + # if of type MetaObj, decollate the metadata + if isinstance(batch, MetaObj): metas = decollate_batch(batch.meta) for i in range(len(out_list)): out_list[i].meta = metas[i] # type: ignore + out_list[i].is_batch = False if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 8fd31b58b6..7968b8202e 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -61,6 +61,17 @@ def check_ids(self, a, b, should_match): comp = self.assertEqual if should_match else self.assertNotEqual comp(id(a), id(b)) + def check_meta(self, a: MetaTensor, b: MetaTensor) -> None: + self.assertEqual(a.is_batch, b.is_batch) + meta_a, meta_b = a.meta, b.meta + # need to split affine from rest of metadata + aff_a = meta_a.get("affine", None) + aff_b = meta_b.get("affine", None) + assert_allclose(aff_a, aff_b) + meta_a = {k: v for k, v in meta_a.items() if k != "affine"} + meta_b = {k: v for k, v in meta_b.items() if k != "affine"} + self.assertEqual(meta_a, meta_b) + def check( self, out: torch.Tensor, @@ -89,12 +100,7 @@ def check( # check meta and affine are equal and affine is on correct device if isinstance(orig, MetaTensor) and isinstance(out, MetaTensor) and meta: - orig_meta_no_affine = deepcopy(orig.meta) - del orig_meta_no_affine["affine"] - out_meta_no_affine = deepcopy(out.meta) - del out_meta_no_affine["affine"] - self.assertEqual(orig_meta_no_affine, out_meta_no_affine) - assert_allclose(out.affine, orig.affine) + self.check_meta(orig, out) self.assertTrue(str(device) in str(out.affine.device)) if check_ids: self.check_ids(out.affine, orig.affine, ids) @@ -273,11 +279,7 @@ def test_out(self): m1_add = m2 + m3 assert_allclose(m1, m1_add) - aff1, aff1_orig = m1.affine, m1_orig.affine - assert_allclose(aff1, aff1_orig) - meta1 = {k: v for k, v in m1.meta.items() if k != "affine"} - meta1_orig = {k: v for k, v in m1_orig.meta.items() if k != "affine"} - self.assertEqual(meta1, meta1_orig) + self.check_meta(m1, m1_orig) @parameterized.expand(TESTS) def test_collate(self, device, dtype): @@ -308,14 +310,90 @@ def test_dataloader(self, dtype): batch_size = 5 ims = [self.get_im(dtype=dtype)[0] for _ in range(batch_size * 2)] ds = Dataset(ims) - expected_im_shape = (batch_size,) + tuple(ims[0].shape) - expected_affine_shape = (batch_size,) + tuple(ims[0].affine.shape) + im_shape = tuple(ims[0].shape) + affine_shape = tuple(ims[0].affine.shape) + expected_im_shape = (batch_size,) + im_shape + expected_affine_shape = (batch_size,) + affine_shape dl = DataLoader(ds, num_workers=batch_size, batch_size=batch_size) for batch in dl: self.assertIsInstance(batch, MetaTensor) self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + def test_indexing(self): + """ + Check the metadata is returned in the expected format depending on whether + the input `MetaTensor` is a batch of data or not. + """ + ims = [self.get_im()[0] for _ in range(5)] + data = list_data_collate(ims) + + # check that when using non-batch data, metadata is copied wholly when indexing + # or iterating across data. + im = ims[0] + self.check_meta(im[0], im) + self.check_meta(next(iter(im)), im) + + # index + d = data[0] + self.check(d, ims[0], ids=False) + + # iter + d = next(iter(data)) + self.check(d, ims[0], ids=False) + + # complex indexing + + # `is_batch==True`, should have subset of image and metadata. + d = data[1:3] + self.check(d, list_data_collate(ims[1:3]), ids=False) + + # is_batch==True, should have subset of image and same metadata as `[1:3]`. + d = data[1:3, 0] + self.check(d, list_data_collate([i[0] for i in ims[1:3]]), ids=False) + + # `is_batch==False`, should have first metadata and subset of first image. + d = data[0, 0] + self.check(d, ims[0][0], ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[:, 0] + self.check(d, list_data_collate([i[0] for i in ims]), ids=False) + + # `is_batch==True`, should have all metadata and subset of all images. + d = data[..., -1] + self.check(d, list_data_collate([i[..., -1] for i in ims]), ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==False`, tuple split along batch dim. Should have individual + # metadata. + d = data.unbind(dim=0) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), len(ims)) + for _d, _im in zip(d, ims): + self.check(_d, _im, ids=False) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + + # `is_batch==True`, tuple split along non-batch dim. Should have all metadata. + d = data.unbind(dim=-1) + self.assertIsInstance(d, tuple) + self.assertEqual(len(d), ims[0].shape[-1]) + for _d in d: + self.check_meta(_d, data) + @parameterized.expand(DTYPES) @SkipIfBeforePyTorchVersion((1, 8)) def test_decollate(self, dtype): From fb9b10f0991ba1596f656ef674c2c788ab52d891 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:30:02 +0100 Subject: [PATCH 07/11] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 4 ++-- monai/data/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 3a3fdc47db..9bfcb0cfae 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -13,7 +13,7 @@ import warnings from copy import deepcopy -from typing import Callable, Sequence +from typing import Any, Callable, Sequence import torch @@ -179,7 +179,7 @@ def update_meta(rets: Sequence, func, args, kwargs): return tuple(out) if isinstance(rets, tuple) else out @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None) -> torch.Tensor: + def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any: """Wraps all torch functions.""" if kwargs is None: kwargs = {} diff --git a/monai/data/utils.py b/monai/data/utils.py index 7f67088a06..2bd7b49731 100644 --- a/monai/data/utils.py +++ b/monai/data/utils.py @@ -476,11 +476,11 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None): return batch.item() if detach else batch out_list = torch.unbind(batch, dim=0) # if of type MetaObj, decollate the metadata - if isinstance(batch, MetaObj): + if isinstance(batch, MetaObj) and all(isinstance(i, MetaObj) for i in out_list): metas = decollate_batch(batch.meta) for i in range(len(out_list)): out_list[i].meta = metas[i] # type: ignore - out_list[i].is_batch = False + out_list[i].is_batch = False # type: ignore if out_list[0].ndim == 0 and detach: return [t.item() for t in out_list] return list(out_list) From 487578441a7067cf52a7ef6aeb91e100d8d02b0c Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:31:21 +0100 Subject: [PATCH 08/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_obj.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 00e10ca816..e38e009e96 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -177,7 +177,7 @@ def _copy_meta(self, input_objs: list[MetaObj]) -> None: id_in = id(input_objs[0]) if len(input_objs) > 0 else None deep_copy = id(self) != id_in self._copy_attr("meta", input_objs, self.get_default_meta, deep_copy) - self.is_batch = input_objs[0].is_batch + self.is_batch = input_objs[0].is_batch if len(input_objs) > 0 else False def get_default_meta(self) -> dict: """Get the default meta. From b307d466f63d51488e5a5cdcc255b7a9ca32070b Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:44:24 +0100 Subject: [PATCH 09/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- tests/test_meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py index 7968b8202e..05356fcc84 100644 --- a/tests/test_meta_tensor.py +++ b/tests/test_meta_tensor.py @@ -320,6 +320,7 @@ def test_dataloader(self, dtype): self.assertTupleEqual(tuple(batch.shape), expected_im_shape) self.assertTupleEqual(tuple(batch.affine.shape), expected_affine_shape) + @SkipIfBeforePyTorchVersion((1, 9)) def test_indexing(self): """ Check the metadata is returned in the expected format depending on whether From e40553a5e54b481f09dae24656844f649d8c6832 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:45:16 +0100 Subject: [PATCH 10/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index 9bfcb0cfae..e3fb7846ae 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -55,6 +55,7 @@ class MetaTensor(MetaObj, torch.Tensor): assert m2.affine == affine Notes: + - Requires pytorch 1.9 or newer for full compatibility. - Older versions of pytorch (<=1.8), `torch.jit.trace(net, im)` may not work if `im` is of type `MetaTensor`. This can be resolved with `torch.jit.trace(net, im.as_tensor())`. From f2c254877663e5774703d0eb1dbe822f47b4027d Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 20 Apr 2022 14:59:51 +0100 Subject: [PATCH 11/11] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/data/meta_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/data/meta_tensor.py b/monai/data/meta_tensor.py index e3fb7846ae..9196f0186c 100644 --- a/monai/data/meta_tensor.py +++ b/monai/data/meta_tensor.py @@ -136,7 +136,7 @@ def update_meta(rets: Sequence, func, args, kwargs): # If we have a batch of data, then we need to be careful if a slice of # the data is returned. Depending on how the data are indexed, we return # some or all of the metadata, and the return object may or may not be a - # batch of data (e.g., `batch[:,-1]` versus `batch[0]`.) + # batch of data (e.g., `batch[:,-1]` versus `batch[0]`). if ret.is_batch: # only decollate metadata once if metas is None: