Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions monai/transforms/intensity/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,8 @@ class RandGaussianSmooth(RandomizableTransform):

"""

backend = GaussianSmooth.backend

def __init__(
self,
sigma_x: Tuple[float, float] = (0.25, 1.5),
Expand Down
11 changes: 7 additions & 4 deletions monai/transforms/post/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from monai.networks.layers import GaussianFilter
from monai.transforms.transform import Transform
from monai.transforms.utils import fill_holes, get_largest_connected_component_mask
from monai.utils import deprecated_arg, ensure_tuple, look_up_option
from monai.utils import TransformBackends, deprecated_arg, ensure_tuple, look_up_option

__all__ = [
"Activations",
Expand Down Expand Up @@ -57,6 +57,8 @@ class Activations(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(self, sigmoid: bool = False, softmax: bool = False, other: Optional[Callable] = None) -> None:
self.sigmoid = sigmoid
self.softmax = softmax
Expand Down Expand Up @@ -134,6 +136,8 @@ class AsDiscrete(Transform):

"""

backend = [TransformBackends.TORCH]

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
Expand Down Expand Up @@ -655,9 +659,8 @@ def __call__(
prob_map = torch.as_tensor(prob_map, dtype=torch.float)
self.filter.to(prob_map)
prob_map = self.filter(prob_map)
else:
if not isinstance(prob_map, torch.Tensor):
prob_map = prob_map.copy()
elif not isinstance(prob_map, torch.Tensor):
prob_map = prob_map.copy()

if isinstance(prob_map, torch.Tensor):
prob_map = prob_map.detach().cpu().numpy()
Expand Down
4 changes: 4 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,8 @@ class Activationsd(MapTransform):
Add activation layers to the input data specified by `keys`.
"""

backend = Activations.backend

def __init__(
self,
keys: KeysCollection,
Expand Down Expand Up @@ -126,6 +128,8 @@ class AsDiscreted(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`.
"""

backend = AsDiscrete.backend

@deprecated_arg("n_classes", since="0.6")
def __init__(
self,
Expand Down
6 changes: 3 additions & 3 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,7 @@ def __call__(self, img: np.ndarray) -> np.ndarray:
if img.ndim == 4 and img.shape[0] == 1:
img = np.squeeze(img, axis=0)

result = []
# merge labels 1 (tumor non-enh) and 4 (tumor enh) to TC
result.append(np.logical_or(img == 1, img == 4))
result = [np.logical_or(img == 1, img == 4)]
# merge labels 1 (tumor non-enh) and 4 (tumor enh) and 2 (large edema) to WT
result.append(np.logical_or(np.logical_or(img == 1, img == 4), img == 2))
# label 4 is ET
Expand Down Expand Up @@ -1125,6 +1123,8 @@ class ToDevice(Transform):

"""

backend = [TransformBackends.TORCH]

def __init__(self, device: Union[torch.device, str], **kwargs) -> None:
"""
Args:
Expand Down
5 changes: 3 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,8 +701,7 @@ class SelectItemsd(MapTransform):
"""

def __call__(self, data):
result = {key: data[key] for key in self.key_iterator(data)}
return result
return {key: data[key] for key in self.key_iterator(data)}


class SqueezeDimd(MapTransform):
Expand Down Expand Up @@ -1465,6 +1464,8 @@ class ToDeviced(MapTransform):
Dictionary-based wrapper of :py:class:`monai.transforms.ToDevice`.
"""

backend = [TransformBackends.TORCH]

def __init__(
self,
keys: KeysCollection,
Expand Down
29 changes: 19 additions & 10 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1370,20 +1370,29 @@ def get_transform_backends():
continue
unique_transforms.append(obj)

if isclass(obj) and issubclass(obj, Transform):
if n in [
"Transform",
if (
isclass(obj)
and issubclass(obj, Transform)
and n
not in [
"BatchInverseTransform",
"Compose",
"Decollated",
"InvertD",
"InvertibleTransform",
"Lambda",
"LambdaD",
"Compose",
"RandomizableTransform",
"MapTransform",
"OneOf",
"BatchInverseTransform",
"InverteD",
]:
continue

"PadListDataCollate",
"RandLambda",
"RandLambdaD",
"RandTorchVisionD",
"RandomizableTransform",
"TorchVisionD",
"Transform",
]
):
backends[n] = [
TransformBackends.TORCH in obj.backend,
TransformBackends.NUMPY in obj.backend,
Expand Down