diff --git a/albumentations/core/composition.py b/albumentations/core/composition.py index 7a78e43f7..d9f53c4a2 100644 --- a/albumentations/core/composition.py +++ b/albumentations/core/composition.py @@ -48,6 +48,12 @@ def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) class BaseCompose(metaclass=SerializableMeta): def __init__(self, transforms: TransformsSeqType, p: float): + if isinstance(transforms, (BaseCompose, BasicTransform)): + warnings.warn( + "transforms is single transform, but a sequence is expected! Transform will be wrapped into list." + ) + transforms = [transforms] + self.transforms = transforms self.p = p @@ -276,7 +282,7 @@ class OneOf(BaseCompose): def __init__(self, transforms: TransformsSeqType, p: float = 0.5): super(OneOf, self).__init__(transforms, p) - transforms_ps = [t.p for t in transforms] + transforms_ps = [t.p for t in self.transforms] s = sum(transforms_ps) self.transforms_ps = [t / s for t in transforms_ps] @@ -308,7 +314,7 @@ def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, super(SomeOf, self).__init__(transforms, p) self.n = n self.replace = replace - transforms_ps = [t.p for t in transforms] + transforms_ps = [t.p for t in self.transforms] s = sum(transforms_ps) self.transforms_ps = [t / s for t in transforms_ps] @@ -347,9 +353,9 @@ def __init__( if first is None or second is None: raise ValueError("You must set both first and second or set transforms argument.") transforms = [first, second] - elif len(transforms) != 2: - warnings.warn("Length of transforms is not equal to 2.") super(OneOrOther, self).__init__(transforms, p) + if len(self.transforms) != 2: + warnings.warn("Length of transforms is not equal to 2.") def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: if self.replay_mode: diff --git a/tests/test_core.py b/tests/test_core.py index 18e677bb9..5d69d2989 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,26 +1,43 @@ from __future__ import absolute_import +import typing from unittest import mock -from unittest.mock import Mock, MagicMock, call +from unittest.mock import MagicMock, Mock, call import cv2 import numpy as np import pytest -from albumentations.core.transforms_interface import to_tuple, ImageOnlyTransform, DualTransform +from albumentations import ( + BasicTransform, + Blur, + Crop, + HorizontalFlip, + MedianBlur, + Normalize, + PadIfNeeded, + Resize, + Rotate, +) from albumentations.augmentations.bbox_utils import check_bboxes from albumentations.core.composition import ( - OneOrOther, + BaseCompose, + BboxParams, Compose, + KeypointParams, OneOf, + OneOrOther, SomeOf, PerChannel, ReplayCompose, - KeypointParams, - BboxParams, Sequential, ) -from albumentations import HorizontalFlip, Rotate, Blur, MedianBlur, PadIfNeeded, Crop +from albumentations.core.transforms_interface import ( + DualTransform, + ImageOnlyTransform, + to_tuple, +) +from .utils import get_filtered_transforms def test_one_or_other(): @@ -332,3 +349,24 @@ def test_bbox_params_is_not_set(image, bboxes): with pytest.raises(ValueError) as exc_info: t(image=image, bboxes=bboxes) assert str(exc_info.value) == "bbox_params must be specified for bbox transformations" + + +@pytest.mark.parametrize( + "compose_transform", get_filtered_transforms((BaseCompose,), custom_arguments={SomeOf: {"n": 1}}) +) +@pytest.mark.parametrize( + "inner_transform", + [(Normalize, {}), (Resize, {"height": 100, "width": 100})] + + get_filtered_transforms((BaseCompose,), custom_arguments={SomeOf: {"n": 1}}), # type: ignore +) +def test_single_transform_compose( + compose_transform: typing.Tuple[typing.Type[BaseCompose], dict], + inner_transform: typing.Tuple[typing.Union[typing.Type[BaseCompose], typing.Type[BasicTransform]], dict], +): + compose_cls, compose_kwargs = compose_transform + cls, kwargs = inner_transform + transform = cls(transforms=[], **kwargs) if issubclass(cls, BaseCompose) else cls(**kwargs) + + with pytest.warns(UserWarning): + res_transform = compose_cls(transforms=transform, **compose_kwargs) # type: ignore + assert isinstance(res_transform.transforms, list) diff --git a/tests/utils.py b/tests/utils.py index b5dea110c..13d1fb8ca 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,6 @@ +import inspect import random import typing -import inspect import numpy as np from io import StringIO @@ -73,7 +73,9 @@ def get_filtered_transforms( result = [] for name, cls in inspect.getmembers(albumentations): - if not inspect.isclass(cls) or not issubclass(cls, albumentations.BasicTransform): + if not inspect.isclass(cls) or not issubclass( + cls, (albumentations.BasicTransform, albumentations.BaseCompose) + ): continue if "DeprecationWarning" in inspect.getsource(cls) or "FutureWarning" in inspect.getsource(cls):