diff --git a/monai/transforms/intensity/array.py b/monai/transforms/intensity/array.py index d2d8cac1ad..fd4aceb376 100644 --- a/monai/transforms/intensity/array.py +++ b/monai/transforms/intensity/array.py @@ -1076,6 +1076,8 @@ class RandGaussianSmooth(RandomizableTransform): """ + backend = GaussianSmooth.backend + def __init__( self, sigma_x: Tuple[float, float] = (0.25, 1.5), diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index d20f368109..eb1860eed1 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -25,7 +25,7 @@ from monai.networks.layers import GaussianFilter from monai.transforms.transform import Transform from monai.transforms.utils import fill_holes, get_largest_connected_component_mask -from monai.utils import deprecated_arg, ensure_tuple, look_up_option +from monai.utils import TransformBackends, deprecated_arg, ensure_tuple, look_up_option __all__ = [ "Activations", @@ -57,6 +57,8 @@ class Activations(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None) -> None: self.sigmoid = sigmoid self.softmax = softmax @@ -134,6 +136,8 @@ class AsDiscrete(Transform): """ + backend = [TransformBackends.TORCH] + @deprecated_arg("n_classes", since="0.6") def __init__( self, @@ -655,9 +659,8 @@ def __call__( prob_map = torch.as_tensor(prob_map, dtype=torch.float) self.filter.to(prob_map) prob_map = self.filter(prob_map) - else: - if not isinstance(prob_map, torch.Tensor): - prob_map = prob_map.copy() + elif not isinstance(prob_map, torch.Tensor): + prob_map = prob_map.copy() if isinstance(prob_map, torch.Tensor): prob_map = prob_map.detach().cpu().numpy() diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 97ae4ec5a9..4ca07da949 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -86,6 +86,8 @@ class Activationsd(MapTransform): Add activation layers to the input data specified by `keys`. """ + backend = Activations.backend + def __init__( self, keys: KeysCollection, @@ -126,6 +128,8 @@ class AsDiscreted(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`. """ + backend = AsDiscrete.backend + @deprecated_arg("n_classes", since="0.6") def __init__( self, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index ef20cf6f3a..ffc3b99cb5 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -887,9 +887,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray: if img.ndim == 4 and img.shape[0] == 1: img = np.squeeze(img, axis=0) - result = [] - # merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC - result.append(np.logical_or(img == 1, img == 4)) + result = [np.logical_or(img == 1, img == 4)] # merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2)) # label 4 is ET @@ -1125,6 +1123,8 @@ class ToDevice(Transform): """ + backend = [TransformBackends.TORCH] + def __init__(self, device: Union[torch.device, str], **kwargs) -> None: """ Args: diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 0a4df58647..cefb654698 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -701,8 +701,7 @@ class SelectItemsd(MapTransform): """ def __call__(self, data): - result = {key: data[key] for key in self.key_iterator(data)} - return result + return {key: data[key] for key in self.key_iterator(data)} class SqueezeDimd(MapTransform): @@ -1465,6 +1464,8 @@ class ToDeviced(MapTransform): Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`. """ + backend = [TransformBackends.TORCH] + def __init__( self, keys: KeysCollection, diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 18657d0122..67fdb70a8c 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1370,20 +1370,29 @@ def get_transform_backends(): continue unique_transforms.append(obj) - if isclass(obj) and issubclass(obj, Transform): - if n in [ - "Transform", + if ( + isclass(obj) + and issubclass(obj, Transform) + and n + not in [ + "BatchInverseTransform", + "Compose", + "Decollated", + "InvertD", "InvertibleTransform", "Lambda", "LambdaD", - "Compose", - "RandomizableTransform", + "MapTransform", "OneOf", - "BatchInverseTransform", - "InverteD", - ]: - continue - + "PadListDataCollate", + "RandLambda", + "RandLambdaD", + "RandTorchVisionD", + "RandomizableTransform", + "TorchVisionD", + "Transform", + ] + ): backends[n] = [ TransformBackends.TORCH in obj.backend, TransformBackends.NUMPY in obj.backend,