Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion monai/apps/deepedit/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from monai.config import KeysCollection
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.utils import optional_import
from monai.utils.enums import PostFix

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -68,7 +69,7 @@ def __call__(self, data):
d = dict(data)
current_shape = d[self.ref_image].shape[1:]

factor = np.divide(current_shape, d["image_meta_dict"]["dim"][1:4])
factor = np.divide(current_shape, d[PostFix.meta("image")]["dim"][1:4])
pos_clicks, neg_clicks = d["foreground"], d["background"]

pos = np.multiply(pos_clicks, factor).astype(int, copy=False).tolist() if len(pos_clicks) else []
Expand Down
15 changes: 9 additions & 6 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,13 @@
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box, is_positive
from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import
from monai.utils.enums import PostFix

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")

DEFAULT_POST_FIX = PostFix.meta()


# Transforms to support Training for Deepgrow models
class FindAllValidSlicesd(Transform):
Expand Down Expand Up @@ -391,7 +394,7 @@ def __init__(
channel_indices: Optional[IndexSelection] = None,
margin: int = 0,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix="meta_dict",
meta_key_postfix=DEFAULT_POST_FIX,
start_coord_key: str = "foreground_start_coord",
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
Expand Down Expand Up @@ -493,7 +496,7 @@ def __init__(
spatial_dims: int = 2,
slice_key: str = "slice",
meta_keys: Optional[str] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
dimensions: Optional[int] = None,
):
self.ref_image = ref_image
Expand Down Expand Up @@ -604,7 +607,7 @@ def __init__(
spatial_size,
margin=20,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix="meta_dict",
meta_key_postfix=DEFAULT_POST_FIX,
start_coord_key: str = "foreground_start_coord",
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
Expand Down Expand Up @@ -720,7 +723,7 @@ def __init__(
guidance: str,
ref_image: str,
meta_keys: Optional[str] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
cropped_shape_key: str = "foreground_cropped_shape",
) -> None:
self.guidance = guidance
Expand Down Expand Up @@ -803,7 +806,7 @@ def __init__(
mode: Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str] = InterpolateMode.NEAREST,
align_corners: Union[Sequence[Optional[bool]], Optional[bool]] = None,
meta_keys: Optional[str] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
start_coord_key: str = "foreground_start_coord",
end_coord_key: str = "foreground_end_coord",
original_shape_key: str = "foreground_original_shape",
Expand Down Expand Up @@ -907,7 +910,7 @@ def __init__(
guidance="guidance",
axis: int = 0,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
):
super().__init__(keys, allow_missing_keys)
Expand Down
5 changes: 4 additions & 1 deletion monai/data/dataset_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@
from monai.data.dataset import Dataset
from monai.transforms import concatenate
from monai.utils import convert_data_type
from monai.utils.enums import PostFix

DEFAULT_POST_FIX = PostFix.meta()


class DatasetSummary:
Expand All @@ -42,7 +45,7 @@ def __init__(
image_key: Optional[str] = "image",
label_key: Optional[str] = "label",
meta_key: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
num_workers: int = 0,
**kwargs,
):
Expand Down
6 changes: 4 additions & 2 deletions monai/data/test_time_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from monai.transforms.inverse_batch_transform import BatchInverseTransform
from monai.transforms.transform import Randomizable
from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils.enums import CommonKeys, TraceKeys
from monai.utils.enums import CommonKeys, PostFix, TraceKeys
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type

Expand All @@ -37,6 +37,8 @@

__all__ = ["TestTimeAugmentation"]

DEFAULT_POST_FIX = PostFix.meta()


def _identity(x):
return x
Expand Down Expand Up @@ -106,7 +108,7 @@ def __init__(
orig_key=CommonKeys.LABEL,
nearest_interp: bool = True,
orig_meta_keys: Optional[str] = None,
meta_key_postfix="meta_dict",
meta_key_postfix=DEFAULT_POST_FIX,
return_full_data: bool = False,
progress: bool = True,
) -> None:
Expand Down
4 changes: 2 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,14 +344,14 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):

batch_data = {
"image": torch.rand((2,1,10,10)),
"image_meta_dict": {"scl_slope": torch.Tensor([0.0, 0.0])}
DictPostFix.meta("image"): {"scl_slope": torch.Tensor([0.0, 0.0])}
}
out = decollate_batch(batch_data)
print(len(out))
>>> 2

print(out[0])
>>> {'image': tensor([[[4.3549e-01...43e-01]]]), 'image_meta_dict': {'scl_slope': 0.0}}
>>> {'image': tensor([[[4.3549e-01...43e-01]]]), DictPostFix.meta("image"): {'scl_slope': 0.0}}

batch_data = [torch.rand((2,1,10,10)), torch.rand((2,3,5,5))]
out = decollate_batch(batch_data)
Expand Down
11 changes: 6 additions & 5 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import TraceKeys
from monai.utils.enums import PostFix, TraceKeys

__all__ = [
"PadModeSequence",
Expand Down Expand Up @@ -105,6 +105,7 @@

NumpyPadModeSequence = Union[Sequence[Union[NumpyPadMode, str]], NumpyPadMode, str]
PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str]
DEFAULT_POST_FIX = PostFix.meta()


class SpatialPadd(MapTransform, InvertibleTransform):
Expand Down Expand Up @@ -755,7 +756,7 @@ def __init__(
random_center: bool = True,
random_size: bool = True,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
) -> None:
MapTransform.__init__(self, keys, allow_missing_keys)
Expand Down Expand Up @@ -951,7 +952,7 @@ def __init__(
num_samples: int = 1,
center_coord_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
):
MapTransform.__init__(self, keys, allow_missing_keys)
Expand Down Expand Up @@ -1103,7 +1104,7 @@ def __init__(
fg_indices_key: Optional[str] = None,
bg_indices_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_smaller: bool = False,
allow_missing_keys: bool = False,
) -> None:
Expand Down Expand Up @@ -1302,7 +1303,7 @@ def __init__(
image_threshold: float = 0.0,
indices_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_smaller: bool = False,
allow_missing_keys: bool = False,
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
from monai.transforms.utils import is_positive
from monai.utils import ensure_tuple, ensure_tuple_rep
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import PostFix

__all__ = [
"RandGaussianNoised",
Expand Down Expand Up @@ -147,6 +148,8 @@
"RandKSpaceSpikeNoiseDict",
]

DEFAULT_POST_FIX = PostFix.meta()


class RandGaussianNoised(RandomizableTransform, MapTransform):
"""
Expand Down Expand Up @@ -285,7 +288,7 @@ def __init__(
offset: float,
factor_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -342,7 +345,7 @@ def __init__(
offsets: Union[Tuple[float, float], float],
factor_key: Optional[str] = None,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
prob: float = 0.1,
allow_missing_keys: bool = False,
) -> None:
Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,12 @@
from monai.transforms.io.array import LoadImage, SaveImage
from monai.transforms.transform import MapTransform
from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix

__all__ = ["LoadImaged", "LoadImageD", "LoadImageDict", "SaveImaged", "SaveImageD", "SaveImageDict"]

DEFAULT_POST_FIX = PostFix.meta()


class LoadImaged(MapTransform):
"""
Expand Down Expand Up @@ -66,7 +69,7 @@ def __init__(
reader: Optional[Union[ImageReader, str]] = None,
dtype: DtypeLike = np.float32,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
overwriting: bool = False,
image_only: bool = False,
allow_missing_keys: bool = False,
Expand Down Expand Up @@ -216,7 +219,7 @@ def __init__(
self,
keys: KeysCollection,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
output_dir: Union[Path, str] = "./",
output_postfix: str = "trans",
output_ext: str = ".nii.gz",
Expand Down
7 changes: 5 additions & 2 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from monai.transforms.utility.array import ToTensor
from monai.transforms.utils import allow_missing_keys_mode, convert_inverse_interp_mode
from monai.utils import deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import PostFix

__all__ = [
"ActivationsD",
Expand Down Expand Up @@ -77,6 +78,8 @@
"VoteEnsembled",
]

DEFAULT_POST_FIX = PostFix.meta()


class Activationsd(MapTransform):
"""
Expand Down Expand Up @@ -544,7 +547,7 @@ def __init__(
orig_keys: KeysCollection,
meta_keys: Optional[KeysCollection] = None,
orig_meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
nearest_interp: Union[bool, Sequence[bool]] = True,
to_tensor: Union[bool, Sequence[bool]] = True,
device: Union[Union[str, torch.device], Sequence[Union[str, torch.device]]] = "cpu",
Expand Down Expand Up @@ -669,7 +672,7 @@ def __init__(
self,
keys: KeysCollection,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
saver: Optional[CSVSaver] = None,
output_dir: PathLike = "./",
filename: str = "predictions.csv",
Expand Down
7 changes: 4 additions & 3 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
fall_back_tuple,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TraceKeys
from monai.utils.enums import PostFix, TraceKeys
from monai.utils.module import optional_import
from monai.utils.type_conversion import convert_data_type, convert_to_dst_type

Expand Down Expand Up @@ -128,6 +128,7 @@
GridSamplePadModeSequence = Union[Sequence[Union[GridSamplePadMode, str]], GridSamplePadMode, str]
InterpolateModeSequence = Union[Sequence[Union[InterpolateMode, str]], InterpolateMode, str]
PadModeSequence = Union[Sequence[Union[NumpyPadMode, PytorchPadMode, str]], NumpyPadMode, PytorchPadMode, str]
DEFAULT_POST_FIX = PostFix.meta()


class Spacingd(MapTransform, InvertibleTransform):
Expand Down Expand Up @@ -156,7 +157,7 @@ def __init__(
align_corners: Union[Sequence[bool], bool] = False,
dtype: Optional[Union[Sequence[DtypeLike], DtypeLike]] = np.float64,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
) -> None:
"""
Expand Down Expand Up @@ -315,7 +316,7 @@ def __init__(
as_closest_canonical: bool = False,
labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")),
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
) -> None:
"""
Expand Down
8 changes: 5 additions & 3 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from monai.transforms.utils import extreme_points_to_image, get_extreme_points
from monai.transforms.utils_pytorch_numpy_unification import concatenate
from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep
from monai.utils.enums import TraceKeys, TransformBackends
from monai.utils.enums import PostFix, TraceKeys, TransformBackends
from monai.utils.type_conversion import convert_to_dst_type

__all__ = [
Expand Down Expand Up @@ -180,6 +180,8 @@
"ClassesToIndicesDict",
]

DEFAULT_POST_FIX = PostFix.meta()


class Identityd(MapTransform):
"""
Expand Down Expand Up @@ -291,7 +293,7 @@ def __init__(
self,
keys: KeysCollection,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
strict_check: bool = True,
) -> None:
"""
Expand Down Expand Up @@ -1464,7 +1466,7 @@ def __init__(
mask_keys: Optional[KeysCollection] = None,
channel_wise: bool = False,
meta_keys: Optional[KeysCollection] = None,
meta_key_postfix: str = "meta_dict",
meta_key_postfix: str = DEFAULT_POST_FIX,
allow_missing_keys: bool = False,
) -> None:
super().__init__(keys, allow_missing_keys)
Expand Down
1 change: 1 addition & 0 deletions monai/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
Method,
MetricReduction,
NumpyPadMode,
PostFix,
PytorchPadMode,
SkipMode,
TraceKeys,
Expand Down
Loading