Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
53f7385
fixes integration tests
wyli May 31, 2022
a62d029
[MONAI] code formatting
monai-bot May 31, 2022
98c0df0
original spatial shape
wyli May 31, 2022
8de2fd5
fixes tests
wyli Jun 1, 2022
40a0226
flip/flipd
wyli Jun 1, 2022
7be5fd7
rand flip/flipd
wyli Jun 1, 2022
9cbf068
rotate/rotated
wyli Jun 1, 2022
1e9e033
rand rotate/rotated
wyli Jun 1, 2022
a09e657
Merge branch 'feature/MetaTensor' into some-spatial
wyli Jun 1, 2022
e4a4643
RandAxisFlip/RandAxisFlipd
wyli Jun 1, 2022
0ac653a
fixes local var
wyli Jun 1, 2022
c0de4d6
Merge branch 'feature/MetaTensor' into some-spatial
wyli Jun 1, 2022
2b3c96c
consistency tests
wyli Jun 1, 2022
c655b4d
Merge branch 'feature/MetaTensor' into some-spatial
wyli Jun 6, 2022
b5f8433
test tests.test_rand_axis_flip
wyli Jun 6, 2022
7e01968
fixes tests
wyli Jun 7, 2022
8af8014
error -> warnings
wyli Jun 7, 2022
31ae028
fixes tests
wyli Jun 7, 2022
46de0be
update tests
wyli Jun 7, 2022
254f347
adds resize/resized
wyli Jun 7, 2022
a1ca082
adds zoom/zoomd/randzoom/randzoomd
wyli Jun 7, 2022
0484410
fixes test inverse
wyli Jun 7, 2022
201a415
fixes inverse
wyli Jun 7, 2022
1fdbb32
fixes typing
wyli Jun 7, 2022
5bbf995
resume collate
wyli Jun 7, 2022
0f4a46c
fixes unit tests
wyli Jun 7, 2022
6bad2ac
affine/affined
wyli Jun 8, 2022
ca745cc
randaffine/randaffined
wyli Jun 8, 2022
4789c81
invertd transform
wyli Jun 8, 2022
e104891
fixes unit tests
wyli Jun 8, 2022
e56816c
fixes unit test
wyli Jun 8, 2022
b96b582
fixes tests
wyli Jun 8, 2022
6a326b8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 8, 2022
b7c6225
simpler tests
wyli Jun 8, 2022
519340d
rotate90/rotate90d
wyli Jun 8, 2022
eea12a4
update invertd tests
wyli Jun 8, 2022
c677b00
update tests
wyli Jun 9, 2022
24ed8d9
update tests
wyli Jun 9, 2022
89a5926
update revertd
wyli Jun 9, 2022
4559858
typing fixes
wyli Jun 9, 2022
a994f72
resume testtimeaug tests
wyli Jun 9, 2022
59d66e1
enable tests
wyli Jun 9, 2022
4d43324
update
wyli Jun 9, 2022
9900b31
update convert dtype
wyli Jun 9, 2022
4137cac
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 9, 2022
d34dd7c
fixes compatible inverse
wyli Jun 9, 2022
d6a1a0b
update to ignore check pop
wyli Jun 9, 2022
0cb78bf
fixes cpp ext metatensor
wyli Jun 9, 2022
5da66fe
Merge branch 'feature/MetaTensor' into some-spatial
wyli Jun 9, 2022
8b3e52d
autofix
wyli Jun 9, 2022
4152b95
fixes type convertion
wyli Jun 10, 2022
549e491
update randrotate90d tests
wyli Jun 12, 2022
83c58d1
revert unecessary changes
wyli Jun 12, 2022
f607055
sliding window inferer to preserve type
wyli Jun 12, 2022
b7f6511
tests sliding window
wyli Jun 13, 2022
9c04da4
update integration tests
wyli Jun 13, 2022
4df1563
fixes torch.mode
wyli Jun 13, 2022
39c79af
[MONAI] code formatting
monai-bot Jun 13, 2022
af2ae1d
fixes tests
wyli Jun 13, 2022
2dff3f1
fixes return types
wyli Jun 13, 2022
bc52321
fixes unit tests
wyli Jun 13, 2022
325a9b4
Merge branch 'feature/MetaTensor' into some-spatial
wyli Jun 13, 2022
9d6fdfe
fixes integration tests
wyli Jun 14, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 18 additions & 18 deletions monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,7 +395,7 @@ def __init__(
self.zoomer = Zoom(zoom=zoom, keep_size=keep_size, **kwargs)
self.keep_size = keep_size

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)

# zoom box
Expand All @@ -408,7 +408,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
box_key,
extra_info={"zoom": self.zoomer.zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
)
d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)(
d[box_key] = ZoomBox(zoom=self.zoomer.zoom, keep_size=self.keep_size)( # type: ignore
d[box_key], src_spatial_size=src_spatial_size
)

Expand All @@ -431,7 +431,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -461,7 +461,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"])
src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
box_inverse_transform = ZoomBox(zoom=(1 / zoom).tolist(), keep_size=self.zoomer.keep_size)
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # type: ignore

# Remove the applied transform
self.pop_transform(d, key)
Expand Down Expand Up @@ -545,7 +545,7 @@ def set_random_state(
self.rand_zoom.set_random_state(seed, state)
return self

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
Expand All @@ -568,7 +568,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
box_key,
extra_info={"zoom": self.rand_zoom._zoom, "src_spatial_size": src_spatial_size, "type": "box_key"},
)
d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)(
d[box_key] = ZoomBox(zoom=self.rand_zoom._zoom, keep_size=self.keep_size)( # type: ignore
d[box_key], src_spatial_size=src_spatial_size
)

Expand All @@ -595,7 +595,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
d = deepcopy(dict(data))

for key in self.key_iterator(d):
Expand Down Expand Up @@ -626,7 +626,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
zoom = np.array(transform[TraceKeys.EXTRA_INFO]["zoom"])
src_spatial_size = transform[TraceKeys.EXTRA_INFO]["src_spatial_size"]
box_inverse_transform = ZoomBox(zoom=(1.0 / zoom).tolist(), keep_size=self.rand_zoom.keep_size)
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size)
d[key] = box_inverse_transform(d[key], src_spatial_size=src_spatial_size) # type: ignore

# Remove the applied transform
self.pop_transform(d, key)
Expand Down Expand Up @@ -667,7 +667,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
d = dict(data)

for key in self.image_keys:
d[key] = self.flipper(d[key])
d[key] = self.flipper(d[key]) # type: ignore
self.push_transform(d, key, extra_info={"type": "image_key"})

for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
Expand All @@ -685,7 +685,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd

# flip image, copied from monai.transforms.spatial.dictionary.Flipd
if key_type == "image_key":
d[key] = self.flipper(d[key])
d[key] = self.flipper(d[key]) # type: ignore

# flip boxes
if key_type == "box_key":
Expand Down Expand Up @@ -743,7 +743,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N

for key in self.image_keys:
if self._do_transform:
d[key] = self.flipper(d[key], randomize=False)
d[key] = self.flipper(d[key], randomize=False) # type: ignore
self.push_transform(d, key, extra_info={"type": "image_key"})

for box_key, box_ref_image_key in zip(self.box_keys, self.box_ref_image_keys):
Expand All @@ -763,7 +763,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
if transform[TraceKeys.DO_TRANSFORM]:
# flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
if key_type == "image_key":
d[key] = self.flipper(d[key], randomize=False)
d[key] = self.flipper(d[key], randomize=False) # type: ignore

# flip boxes
if key_type == "box_key":
Expand Down Expand Up @@ -1271,7 +1271,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable
self.push_transform(d, key, extra_info={"spatial_size": spatial_size, "type": "box_key"})

for key in self.image_keys:
d[key] = self.img_rotator(d[key])
d[key] = self.img_rotator(d[key]) # type: ignore
self.push_transform(d, key, extra_info={"type": "image_key"})
return d

Expand All @@ -1285,7 +1285,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd

if key_type == "image_key":
inverse_transform = Rotate90(num_times_to_rotate, self.img_rotator.spatial_axes)
d[key] = inverse_transform(d[key])
d[key] = inverse_transform(d[key]) # type: ignore
if key_type == "box_key":
spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
inverse_transform = RotateBox90(num_times_to_rotate, self.box_rotator.spatial_axes)
Expand Down Expand Up @@ -1329,7 +1329,7 @@ def __init__(
super().__init__(self.image_keys + self.box_keys, prob, max_k, spatial_axes, allow_missing_keys)
self.box_ref_image_keys = ensure_tuple_rep(box_ref_image_keys, len(self.box_keys))

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]:
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable, NdarrayOrTensor]: # type: ignore
self.randomize()
d = dict(data)

Expand Down Expand Up @@ -1357,11 +1357,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable

for key in self.image_keys:
if self._do_transform:
d[key] = img_rotator(d[key])
d[key] = img_rotator(d[key]) # type: ignore
self.push_transform(d, key, extra_info={"rand_k": self._rand_k, "type": "image_key"})
return d

def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: # type: ignore
d = deepcopy(dict(data))
if self._rand_k % 4 == 0:
return d
Expand All @@ -1376,7 +1376,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, Nd
# flip image, copied from monai.transforms.spatial.dictionary.RandFlipd
if key_type == "image_key":
inverse_transform = Rotate90(num_times_to_rotate, self.spatial_axes)
d[key] = inverse_transform(d[key])
d[key] = inverse_transform(d[key]) # type: ignore
if key_type == "box_key":
spatial_size = transform[TraceKeys.EXTRA_INFO]["spatial_size"]
inverse_transform = RotateBox90(num_times_to_rotate, self.spatial_axes)
Expand Down
8 changes: 6 additions & 2 deletions monai/data/image_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -269,6 +270,9 @@ def resample_if_needed(
resampler = SpatialResample(mode=mode, padding_mode=padding_mode, align_corners=align_corners, dtype=dtype)
output_array = resampler(data_array[None], dst_affine=target_affine, spatial_size=output_spatial_shape)
# convert back at the end
if isinstance(output_array, MetaTensor):
warnings.warn("ignoring the tracking transform info.")
output_array.applied_operations = []
data_array, *_ = convert_data_type(output_array, output_type=orig_type) # type: ignore
affine, *_ = convert_data_type(output_array.affine, output_type=orig_type) # type: ignore
return data_array[0], affine
Expand Down Expand Up @@ -764,11 +768,11 @@ def resample_and_clip(
_min, _max = np.min(data), np.max(data)
if len(data.shape) == 3:
data = np.moveaxis(data, -1, 0) # to channel first
data = xform(data) # type: ignore
data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0] # type: ignore
data = np.moveaxis(data, 0, -1)
else: # (H, W)
data = np.expand_dims(data, 0) # make a channel
data = xform(data)[0] # type: ignore
data = convert_data_type(xform(data), np.ndarray, drop_meta=True)[0][0] # type: ignore
if mode != InterpolateMode.NEAREST:
data = np.clip(data, _min, _max)
return data
Expand Down
8 changes: 5 additions & 3 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def __repr__(self) -> str:
@property
def meta(self) -> dict:
"""Get the meta."""
return self._meta
return self._meta if hasattr(self, "_meta") else self.get_default_meta()

@meta.setter
def meta(self, d) -> None:
Expand All @@ -195,7 +195,9 @@ def meta(self, d) -> None:
@property
def applied_operations(self) -> list:
"""Get the applied operations."""
return self._applied_operations
if hasattr(self, "_applied_operations"):
return self._applied_operations
return self.get_default_applied_operations()

@applied_operations.setter
def applied_operations(self, t) -> None:
Expand All @@ -215,7 +217,7 @@ def pop_applied_operation(self) -> Any:
@property
def is_batch(self) -> bool:
"""Return whether object is part of batch or not."""
return self._is_batch
return self._is_batch if hasattr(self, "_is_batch") else False

@is_batch.setter
def is_batch(self, val: bool) -> None:
Expand Down
28 changes: 22 additions & 6 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
# else, handle the `MetaTensor` metadata.
else:
meta_args = MetaObj.flatten_meta_objs(list(args) + list(kwargs.values()))
# this is not implemented but the network arch may run into this case:
# if func == torch.cat and any(m.is_batch if hasattr(m, "is_batch") else False for m in meta_args):
# raise NotImplementedError("torch.cat is not implemented for batch of MetaTensors.")
ret._copy_meta(meta_args)

# If we have a batch of data, then we need to be careful if a slice of
Expand All @@ -195,17 +198,17 @@ def update_meta(rets: Sequence, func, args, kwargs) -> Sequence:
metas = decollate_batch(ret.meta)
# if indexing e.g., `batch[0]`
if func == torch.Tensor.__getitem__:
idx = args[1]
if isinstance(idx, Sequence):
idx = idx[0]
batch_idx = args[1]
if isinstance(batch_idx, Sequence):
batch_idx = batch_idx[0]
# if using e.g., `batch[:, -1]` or `batch[..., -1]`, then the
# first element will be `slice(None, None, None)` and `Ellipsis`,
# respectively. Don't need to do anything with the metadata.
if idx not in (slice(None, None, None), Ellipsis):
meta = metas[idx]
if batch_idx not in (slice(None, None, None), Ellipsis):
meta = metas[batch_idx]
# if using e.g., `batch[0:2]`, then `is_batch` should still be
# `True`. Also re-collate the remaining elements.
if isinstance(meta, list) and len(meta) > 1:
if isinstance(meta, list):
ret.meta = list_data_collate(meta)
# if using e.g., `batch[0]` or `batch[0, 1]`, then return single
# element from batch, and set `is_batch` to `False`.
Expand Down Expand Up @@ -243,6 +246,19 @@ def __torch_function__(cls, func, types, args=(), kwargs=None) -> Any:
# we might have 1 or multiple outputs. Might be MetaTensor, might be something
# else (e.g., `__repr__` returns a string).
# Convert to list (if necessary), process, and at end remove list if one was added.
if (
hasattr(torch, "return_types")
and hasattr(func, "__name__")
and hasattr(torch.return_types, func.__name__)
and isinstance(getattr(torch.return_types, func.__name__), type)
and isinstance(ret, getattr(torch.return_types, func.__name__))
):
# for torch.max(torch.tensor(1.0), dim=0), the return type is named-tuple like
out_items = MetaTensor.update_meta(ret, func, args, kwargs)
for idx in range(ret.n_fields):
ret[idx].meta = out_items[idx].meta
ret[idx].applied_operations = out_items[idx].applied_operations
return ret
if isinstance(ret, (str, bytes)) or not isinstance(ret, Sequence):
ret = [ret]
unpack = True
Expand Down
13 changes: 10 additions & 3 deletions monai/data/png_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@
import numpy as np

from monai.transforms.spatial.array import Resize
from monai.utils import InterpolateMode, deprecated, ensure_tuple_rep, look_up_option, optional_import
from monai.utils import (
InterpolateMode,
convert_data_type,
deprecated,
ensure_tuple_rep,
look_up_option,
optional_import,
)

Image, _ = optional_import("PIL", name="Image")

Expand Down Expand Up @@ -74,9 +81,9 @@ def write_png(
if scale is not None:
data = np.clip(data, 0.0, 1.0) # png writer only can scale data in range [0, 1]
if scale == np.iinfo(np.uint8).max:
data = (scale * data).astype(np.uint8, copy=False)
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint8, drop_meta=True)[0]
elif scale == np.iinfo(np.uint16).max:
data = (scale * data).astype(np.uint16, copy=False)
data = convert_data_type((scale * data), np.ndarray, dtype=np.uint16, drop_meta=True)[0]
else:
raise ValueError(f"Unsupported scale: {scale}, available options are [255, 65535]")

Expand Down
36 changes: 22 additions & 14 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,24 @@ def dev_collate(batch, level: int = 1, logger_name: str = "dev_collate"):
return


def collate_meta_tensor(batch):
"""collate a sequence of meta tensor sequences/dictionaries into
a single batched metatensor or a dictionary of batched metatensor"""
if not isinstance(batch, Sequence):
raise NotImplementedError()
elem_0 = first(batch)
if isinstance(elem_0, MetaObj):
collated = default_collate(batch)
collated.meta = default_collate([i.meta or TraceKeys.NONE for i in batch])
collated.applied_operations = [i.applied_operations or TraceKeys.NONE for i in batch]
collated.is_batch = True
return collated
if isinstance(elem_0, Mapping):
return {k: collate_meta_tensor([d[k] for d in batch]) for k in elem_0}
# no more recursive search for MetaTensor
return default_collate(batch)


def list_data_collate(batch: Sequence):
"""
Enhancement for PyTorch DataLoader default collate.
Expand All @@ -411,19 +429,9 @@ def list_data_collate(batch: Sequence):
for k in elem:
key = k
data_for_batch = [d[key] for d in data]
ret[key] = default_collate(data_for_batch)
if isinstance(ret[key], MetaObj) and all(isinstance(d, MetaObj) for d in data_for_batch):
meta_list = [i.meta or TraceKeys.NONE for i in data_for_batch]
ret[key].meta = default_collate(meta_list)
ops_list = [i.applied_operations or TraceKeys.NONE for i in data_for_batch]
ret[key].applied_operations = default_collate(ops_list)
ret[key].is_batch = True
ret[key] = collate_meta_tensor(data_for_batch)
else:
ret = default_collate(data)
if isinstance(ret, MetaObj) and all(isinstance(d, MetaObj) for d in data):
ret.meta = default_collate([i.meta or TraceKeys.NONE for i in data])
ret.applied_operations = default_collate([i.applied_operations or TraceKeys.NONE for i in data])
ret.is_batch = True
ret = collate_meta_tensor(data)
return ret
except RuntimeError as re:
re_str = str(re)
Expand Down Expand Up @@ -550,7 +558,7 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
if isinstance(t, MetaObj):
t.meta = m
t.is_batch = False
for t, m in zip(out_list, decollate_batch(batch.applied_operations)):
for t, m in zip(out_list, batch.applied_operations):
if isinstance(t, MetaObj):
t.applied_operations = m
t.is_batch = False
Expand Down Expand Up @@ -848,7 +856,7 @@ def to_affine_nd(r: Union[np.ndarray, int], affine: NdarrayTensor, dtype=np.floa
an (r+1) x (r+1) matrix (tensor or ndarray depends on the input ``affine`` data type)

"""
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True)[0]
affine_np = convert_data_type(affine, output_type=np.ndarray, dtype=dtype, wrap_sequence=True, drop_meta=True)[0]
affine_np = affine_np.copy()
if affine_np.ndim != 2:
raise ValueError(f"affine must have 2 dimensions, got {affine_np.ndim}.")
Expand Down
Loading