diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index 26601de76b..53f1009a76 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -201,7 +201,7 @@ ThresholdIntensityDict, ) from .inverse import InvertibleTransform, TraceableTransform -from .inverse_batch_transform import BatchInverseTransform, Decollated +from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict from .io.array import SUPPORTED_READERS, LoadImage, SaveImage from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict from .nvtx import ( @@ -249,6 +249,8 @@ AsDiscreted, AsDiscreteDict, Ensembled, + EnsembleD, + EnsembleDict, FillHolesD, FillHolesd, FillHolesDict, @@ -283,7 +285,17 @@ RandSmoothFieldAdjustIntensity, SmoothField, ) -from .smooth_field.dictionary import RandSmoothDeformd, RandSmoothFieldAdjustContrastd, RandSmoothFieldAdjustIntensityd +from .smooth_field.dictionary import ( + RandSmoothDeformd, + RandSmoothDeformD, + RandSmoothDeformDict, + RandSmoothFieldAdjustContrastd, + RandSmoothFieldAdjustContrastD, + RandSmoothFieldAdjustContrastDict, + RandSmoothFieldAdjustIntensityd, + RandSmoothFieldAdjustIntensityD, + RandSmoothFieldAdjustIntensityDict, +) from .spatial.array import ( Affine, AffineGrid, diff --git a/monai/transforms/inverse_batch_transform.py b/monai/transforms/inverse_batch_transform.py index ae0317cea8..cc77a199dd 100644 --- a/monai/transforms/inverse_batch_transform.py +++ b/monai/transforms/inverse_batch_transform.py @@ -23,7 +23,7 @@ from monai.transforms.transform import MapTransform, Transform from monai.utils import first -__all__ = ["BatchInverseTransform", "Decollated"] +__all__ = ["BatchInverseTransform", "Decollated", "DecollateD", "DecollateDict"] class _BatchInverseDataset(Dataset): @@ -151,3 +151,6 @@ def __call__(self, data: Union[Dict, List]): d[key] = data[key] return decollate_batch(d, detach=self.detach, pad=self.pad_batch, fill_value=self.fill_value) + + +DecollateD = DecollateDict = Decollated diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index a7ffcb19bf..d44fb6b3fa 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -49,6 +49,8 @@ "AsDiscreteDict", "AsDiscreted", "Ensembled", + "EnsembleD", + "EnsembleDict", "FillHolesD", "FillHolesDict", "FillHolesd", @@ -752,3 +754,4 @@ def get_saver(self): ProbNMSD = ProbNMSDict = ProbNMSd SaveClassificationD = SaveClassificationDict = SaveClassificationd VoteEnsembleD = VoteEnsembleDict = VoteEnsembled +EnsembleD = EnsembleDict = Ensembled diff --git a/monai/transforms/smooth_field/dictionary.py b/monai/transforms/smooth_field/dictionary.py index c129d14f32..24890140cc 100644 --- a/monai/transforms/smooth_field/dictionary.py +++ b/monai/transforms/smooth_field/dictionary.py @@ -26,7 +26,17 @@ from monai.utils import GridSampleMode, GridSamplePadMode, InterpolateMode, ensure_tuple_rep from monai.utils.enums import TransformBackends -__all__ = ["RandSmoothFieldAdjustContrastd", "RandSmoothFieldAdjustIntensityd", "RandSmoothDeformd"] +__all__ = [ + "RandSmoothFieldAdjustContrastd", + "RandSmoothFieldAdjustIntensityd", + "RandSmoothDeformd", + "RandSmoothFieldAdjustContrastD", + "RandSmoothFieldAdjustIntensityD", + "RandSmoothDeformD", + "RandSmoothFieldAdjustContrastDict", + "RandSmoothFieldAdjustIntensityDict", + "RandSmoothDeformDict", +] InterpolateModeType = Union[InterpolateMode, str] @@ -276,3 +286,8 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Mapping[Hashable d[key] = self.trans(d[key], False, self.trans.device) return d + + +RandSmoothDeformD = RandSmoothDeformDict = RandSmoothDeformd +RandSmoothFieldAdjustIntensityD = RandSmoothFieldAdjustIntensityDict = RandSmoothFieldAdjustIntensityd +RandSmoothFieldAdjustContrastD = RandSmoothFieldAdjustContrastDict = RandSmoothFieldAdjustContrastd diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 5ec4aa9ff1..ea520c59f3 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -10,6 +10,7 @@ # limitations under the License. import glob +import inspect import os import unittest @@ -33,6 +34,24 @@ def test_public_api(self): mod.append(code_folder) self.assertEqual(sorted(monai.__all__), sorted(mod)) + def test_transform_api(self): + """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" + to_exclude = {"MapTransform"} # except for these transforms + xforms = { + name: obj + for name, obj in monai.transforms.__dict__.items() + if inspect.isclass(obj) and issubclass(obj, monai.transforms.MapTransform) + } + names = sorted(x for x in xforms if x not in to_exclude) + remained = set(names) + for n in names: + if not n.endswith("d"): + continue + basename = n[:-1] # Transformd basename is Transform + for postfix in ("D", "d", "Dict"): + remained.remove(f"{basename}{postfix}") + self.assertFalse(remained) + if __name__ == "__main__": unittest.main()