6 changes: 3 additions & 3 deletions monai/data/test_time_augmentation.py
@@ -24,7 +24,7 @@
from monai.transforms.inverse_batch_transform import BatchInverseTransform
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
-from monai.utils.enums import CommonKeys, InverseKeys
+from monai.utils.enums import CommonKeys, TraceKeys
from monai.utils.module import optional_import

if TYPE_CHECKING:
@@ -168,7 +168,7 @@ def __call__(
ds = Dataset(data_in, self.transform)
dl = DataLoader(ds, num_workers=self.num_workers, batch_size=self.batch_size, collate_fn=pad_list_data_collate)

-transform_key = self.orig_key + InverseKeys.KEY_SUFFIX
+transform_key = InvertibleTransform.trace_key(self.orig_key)

# create inverter
inverter = BatchInverseTransform(self.transform, dl, collate_fn=list_data_collate)
@@ -188,7 +188,7 @@ def __call__(
transform_info = batch_data.get(transform_key, None)
if transform_info is None:
# no invertible transforms, adding dummy info for identity invertible
-transform_info = [[InverseKeys.NONE] for _ in range(self.batch_size)]
+transform_info = [[TraceKeys.NONE] for _ in range(self.batch_size)]
if self.nearest_interp:
transform_info = convert_inverse_interp_mode(
trans_info=deepcopy(transform_info), mode="nearest", align_corners=None
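Note on the pattern above: this PR replaces manual key concatenation (orig_key + InverseKeys.KEY_SUFFIX) with the trace_key helper. A minimal sketch of the equivalence, assuming the TraceKeys definitions this PR relies on (KEY_SUFFIX == "_transforms"); illustrative only, not part of the diff:

from monai.transforms import InvertibleTransform
from monai.utils.enums import TraceKeys

orig_key = "image"
# old spelling: transform_key = orig_key + InverseKeys.KEY_SUFFIX
# new spelling, as in the hunk above:
transform_key = InvertibleTransform.trace_key(orig_key)
# both should resolve to the same dictionary key
assert transform_key == orig_key + TraceKeys.KEY_SUFFIX  # "image_transforms"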
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
@@ -197,7 +197,7 @@
ThresholdIntensityD,
ThresholdIntensityDict,
)
-from .inverse import InvertibleTransform
+from .inverse import InvertibleTransform, TraceableTransform
from .inverse_batch_transform import BatchInverseTransform, Decollated
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
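With the re-export above, TraceableTransform becomes importable from the package root alongside InvertibleTransform. A hedged usage sketch (the subclass relationship and the trace_key signature are assumed from the rest of this PR, not shown in this hunk):

from monai.transforms import InvertibleTransform, TraceableTransform

# InvertibleTransform presumably extends the traceable base with inverse support
print(issubclass(InvertibleTransform, TraceableTransform))  # expected: True
print(TraceableTransform.trace_key("label"))                # expected: "label_transforms"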
8 changes: 4 additions & 4 deletions monai/transforms/compose.py
@@ -28,7 +28,7 @@
apply_transform,
)
from monai.utils import MAX_SEED, ensure_tuple, get_seed
-from monai.utils.enums import InverseKeys
+from monai.utils.enums import TraceKeys

__all__ = ["Compose", "OneOf"]

@@ -237,7 +237,7 @@ def __call__(self, data):
# if the data is a mapping (dictionary), append the OneOf transform to the end
if isinstance(data, Mapping):
for key in data.keys():
-if key + InverseKeys.KEY_SUFFIX in data:
+if self.trace_key(key) in data:
self.push_transform(data, key, extra_info={"index": index})
return data

@@ -250,9 +250,9 @@ def inverse(self, data):
# loop until we get an index and then break (since they'll all be the same)
index = None
for key in data.keys():
-if key + InverseKeys.KEY_SUFFIX in data:
+if self.trace_key(key) in data:
# get the index of the applied OneOf transform
-index = self.get_most_recent_transform(data, key)[InverseKeys.EXTRA_INFO]["index"]
+index = self.get_most_recent_transform(data, key)[TraceKeys.EXTRA_INFO]["index"]
# and then remove the OneOf transform
self.pop_transform(data, key)
if index is None:
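The OneOf changes rely on each applied transform appending a dict of TraceKeys entries to the per-key trace list. A sketch of the round trip with hypothetical data (the "image_transforms" key assumes the default "_transforms" suffix):

from monai.utils.enums import TraceKeys

# hypothetical trace as OneOf.__call__ would leave it for key "image"
data = {
    "image": None,  # image payload elided
    "image_transforms": [
        {TraceKeys.CLASS_NAME: "OneOf", TraceKeys.EXTRA_INFO: {"index": 1}},
    ],
}
# what OneOf.inverse reads back before popping the record
index = data["image_transforms"][-1][TraceKeys.EXTRA_INFO]["index"]
assert index == 1  # the second bundled transform was the one applied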
6 changes: 3 additions & 3 deletions monai/transforms/croppad/batch.py
@@ -22,7 +22,7 @@
from monai.data.utils import list_data_collate
from monai.transforms.croppad.array import CenterSpatialCrop, SpatialPad
from monai.transforms.inverse import InvertibleTransform
-from monai.utils.enums import InverseKeys, Method, NumpyPadMode
+from monai.utils.enums import Method, NumpyPadMode, TraceKeys

__all__ = ["PadListDataCollate"]

@@ -115,12 +115,12 @@ def inverse(data: dict) -> Dict[Hashable, np.ndarray]:

d = deepcopy(data)
for key in d:
-transform_key = str(key) + InverseKeys.KEY_SUFFIX
+transform_key = InvertibleTransform.trace_key(key)
if transform_key in d:
transform = d[transform_key][-1]
if not isinstance(transform, Dict):
continue
-if transform.get(InverseKeys.CLASS_NAME) == PadListDataCollate.__name__:
+if transform.get(TraceKeys.CLASS_NAME) == PadListDataCollate.__name__:
d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key]) # fallback to image size
# remove transform
d[transform_key].pop()
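The inverse above only acts when the most recent trace entry was written by PadListDataCollate, cropping back to the recorded original size. A standalone sketch of that check under the same assumptions (the helper name undo_collate_pad is mine, not MONAI API):

from copy import deepcopy

from monai.transforms import CenterSpatialCrop, InvertibleTransform
from monai.utils.enums import TraceKeys

def undo_collate_pad(data: dict, key: str) -> dict:
    """Hypothetical helper mirroring PadListDataCollate.inverse for one key."""
    d = deepcopy(data)
    transform_key = InvertibleTransform.trace_key(key)
    if transform_key in d and isinstance(d[transform_key][-1], dict):
        transform = d[transform_key][-1]
        if transform.get(TraceKeys.CLASS_NAME) == "PadListDataCollate":
            # crop back to the size recorded when the batch was padded
            d[key] = CenterSpatialCrop(transform.get("orig_size", -1))(d[key])
            d[transform_key].pop()
    return d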
46 changes: 23 additions & 23 deletions monai/transforms/croppad/dictionary.py
@@ -52,7 +52,7 @@
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first
-from monai.utils.enums import InverseKeys
+from monai.utils.enums import TraceKeys

__all__ = [
"PadModeSequence",
@@ -163,7 +163,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = transform[InverseKeys.ORIG_SIZE]
+orig_size = transform[TraceKeys.ORIG_SIZE]
if self.padder.method == Method.SYMMETRIC:
current_size = d[key].shape[1:]
roi_center = [floor(i / 2) if r % 2 == 0 else (i - 1) // 2 for r, i in zip(orig_size, current_size)]
@@ -239,15 +239,15 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
roi_start = np.array(self.padder.spatial_border)
# Need to convert single value to [min1,min2,...]
if roi_start.size == 1:
roi_start = np.full((len(orig_size)), roi_start)
# need to convert [min1,max1,min2,...] to [min1,min2,...]
elif roi_start.size == 2 * orig_size.size:
roi_start = roi_start[::2]
-roi_end = np.array(transform[InverseKeys.ORIG_SIZE]) + roi_start
+roi_end = np.array(transform[TraceKeys.ORIG_SIZE]) + roi_start

inverse_transform = SpatialCrop(roi_start=roi_start, roi_end=roi_end)
# Apply inverse transform
@@ -315,7 +315,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
roi_start = np.floor((current_size - orig_size) / 2)
roi_end = orig_size + roi_start
@@ -384,7 +384,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(self.cropper.slices, orig_size)])
@@ -440,7 +440,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
# in each direction, if original size is even and current size is odd, += 1
@@ -497,7 +497,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
# in each direction, if original size is even and current size is odd, += 1
@@ -594,12 +594,12 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = transform[InverseKeys.ORIG_SIZE]
+orig_size = transform[TraceKeys.ORIG_SIZE]
random_center = self.random_center
pad_to_start = np.empty((len(orig_size)), dtype=np.int32)
pad_to_end = np.empty((len(orig_size)), dtype=np.int32)
if random_center:
-for i, _slice in enumerate(transform[InverseKeys.EXTRA_INFO]["slices"]):
+for i, _slice in enumerate(transform[TraceKeys.EXTRA_INFO]["slices"]):
pad_to_start[i] = _slice[0]
pad_to_end[i] = orig_size[i] - _slice[1]
else:
@@ -776,8 +776,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> List[Dict[Hashable, NdarrayOrTensor]]:
cropped = self.cropper(d)
# self.cropper will have added RandSpatialCropd to the list. Change to RandSpatialCropSamplesd
for key in self.key_iterator(cropped):
-cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore
-cropped[str(key) + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self) # type: ignore
+cropped[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.__class__.__name__ # type: ignore
+cropped[self.trace_key(key)][-1][TraceKeys.ID] = id(self) # type: ignore
# add `patch_index` to the meta data
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
@@ -792,8 +792,8 @@ def inverse(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
# We changed the transform name from RandSpatialCropd to RandSpatialCropSamplesd
# Need to revert that since we're calling RandSpatialCropd's inverse
for key in self.key_iterator(d):
-d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.CLASS_NAME] = self.cropper.__class__.__name__
-d[key + InverseKeys.KEY_SUFFIX][-1][InverseKeys.ID] = id(self.cropper)
+d[self.trace_key(key)][-1][TraceKeys.CLASS_NAME] = self.cropper.__class__.__name__
+d[self.trace_key(key)][-1][TraceKeys.ID] = id(self.cropper)
context_manager = allow_missing_keys_mode if self.allow_missing_keys else _nullcontext
with context_manager(self.cropper):
return self.cropper.inverse(d)
@@ -877,9 +877,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
cur_size = np.asarray(d[key].shape[1:])
-extra_info = transform[InverseKeys.EXTRA_INFO]
+extra_info = transform[TraceKeys.EXTRA_INFO]
box_start = np.asarray(extra_info["box_start"])
box_end = np.asarray(extra_info["box_end"])
# first crop the padding part
@@ -999,9 +999,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
-center = transform[InverseKeys.EXTRA_INFO]["center"]
+center = transform[TraceKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=center, roi_size=self.spatial_size)
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
@@ -1179,9 +1179,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
-center = transform[InverseKeys.EXTRA_INFO]["center"]
+center = transform[TraceKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
@@ -1364,9 +1364,9 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.asarray(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.asarray(transform[TraceKeys.ORIG_SIZE])
current_size = np.asarray(d[key].shape[1:])
-center = transform[InverseKeys.EXTRA_INFO]["center"]
+center = transform[TraceKeys.EXTRA_INFO]["center"]
cropper = SpatialCrop(roi_center=tuple(center), roi_size=self.spatial_size) # type: ignore
# get required pad to start and end
pad_to_start = np.array([s.indices(o)[0] for s, o in zip(cropper.slices, orig_size)])
@@ -1432,7 +1432,7 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
for key in self.key_iterator(d):
transform = self.get_most_recent_transform(d, key)
# Create inverse transform
-orig_size = np.array(transform[InverseKeys.ORIG_SIZE])
+orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
current_size = np.array(d[key].shape[1:])
# Unfortunately, we can't just use ResizeWithPadOrCrop with original size because of odd/even rounding.
# Instead, we first pad any smaller dimensions, and then we crop any larger dimensions.
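Every inverse hunk in this file follows one recipe: read TraceKeys.ORIG_SIZE from the most recent trace entry, compare it with the current spatial shape, then pad or crop back. A hedged distillation of the padding branch (simplified; channel-first array assumed, and the helper name invert_center_crop is mine):

import numpy as np

from monai.transforms import BorderPad
from monai.utils.enums import TraceKeys

def invert_center_crop(img: np.ndarray, transform: dict) -> np.ndarray:
    """Pad a center-cropped channel-first array back to its recorded size."""
    orig_size = np.array(transform[TraceKeys.ORIG_SIZE])
    current_size = np.array(img.shape[1:])  # spatial dims only
    pad_to_start = np.floor((orig_size - current_size) / 2).astype(int)
    # if the original size is even and the current size is odd, the extra
    # voxel goes at the end, matching the rounding in the hunks above
    pad_to_end = orig_size - current_size - pad_to_start
    # interleave to BorderPad's [start1, end1, start2, end2, ...] layout
    pad = [int(p) for p in np.stack([pad_to_start, pad_to_end], axis=1).flatten()]
    return BorderPad(spatial_border=pad)(img)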