Skip to content

Commit

Permalink
Warning if Compose got single transform, wrap transform into list (#1055
Browse files Browse the repository at this point in the history
)

* Warning if Compose got single transform, wrap transform into list

* Black
  • Loading branch information
Dipet committed Nov 6, 2021
1 parent 6741f51 commit 477156d
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
14 changes: 10 additions & 4 deletions albumentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType])

class BaseCompose(metaclass=SerializableMeta):
def __init__(self, transforms: TransformsSeqType, p: float):
if isinstance(transforms, (BaseCompose, BasicTransform)):
warnings.warn(
"transforms is single transform, but a sequence is expected! Transform will be wrapped into list."
)
transforms = [transforms]

self.transforms = transforms
self.p = p

Expand Down Expand Up @@ -276,7 +282,7 @@ class OneOf(BaseCompose):

def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
super(OneOf, self).__init__(transforms, p)
transforms_ps = [t.p for t in transforms]
transforms_ps = [t.p for t in self.transforms]
s = sum(transforms_ps)
self.transforms_ps = [t / s for t in transforms_ps]

Expand Down Expand Up @@ -308,7 +314,7 @@ def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True,
super(SomeOf, self).__init__(transforms, p)
self.n = n
self.replace = replace
transforms_ps = [t.p for t in transforms]
transforms_ps = [t.p for t in self.transforms]
s = sum(transforms_ps)
self.transforms_ps = [t / s for t in transforms_ps]

Expand Down Expand Up @@ -347,9 +353,9 @@ def __init__(
if first is None or second is None:
raise ValueError("You must set both first and second or set transforms argument.")
transforms = [first, second]
elif len(transforms) != 2:
warnings.warn("Length of transforms is not equal to 2.")
super(OneOrOther, self).__init__(transforms, p)
if len(self.transforms) != 2:
warnings.warn("Length of transforms is not equal to 2.")

def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
if self.replay_mode:
Expand Down
50 changes: 44 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,43 @@
from __future__ import absolute_import

import typing
from unittest import mock
from unittest.mock import Mock, MagicMock, call
from unittest.mock import MagicMock, Mock, call

import cv2
import numpy as np
import pytest

from albumentations.core.transforms_interface import to_tuple, ImageOnlyTransform, DualTransform
from albumentations import (
BasicTransform,
Blur,
Crop,
HorizontalFlip,
MedianBlur,
Normalize,
PadIfNeeded,
Resize,
Rotate,
)
from albumentations.augmentations.bbox_utils import check_bboxes
from albumentations.core.composition import (
OneOrOther,
BaseCompose,
BboxParams,
Compose,
KeypointParams,
OneOf,
OneOrOther,
SomeOf,
PerChannel,
ReplayCompose,
KeypointParams,
BboxParams,
Sequential,
)
from albumentations import HorizontalFlip, Rotate, Blur, MedianBlur, PadIfNeeded, Crop
from albumentations.core.transforms_interface import (
DualTransform,
ImageOnlyTransform,
to_tuple,
)
from .utils import get_filtered_transforms


def test_one_or_other():
Expand Down Expand Up @@ -332,3 +349,24 @@ def test_bbox_params_is_not_set(image, bboxes):
with pytest.raises(ValueError) as exc_info:
t(image=image, bboxes=bboxes)
assert str(exc_info.value) == "bbox_params must be specified for bbox transformations"


@pytest.mark.parametrize(
"compose_transform", get_filtered_transforms((BaseCompose,), custom_arguments={SomeOf: {"n": 1}})
)
@pytest.mark.parametrize(
"inner_transform",
[(Normalize, {}), (Resize, {"height": 100, "width": 100})]
+ get_filtered_transforms((BaseCompose,), custom_arguments={SomeOf: {"n": 1}}), # type: ignore
)
def test_single_transform_compose(
compose_transform: typing.Tuple[typing.Type[BaseCompose], dict],
inner_transform: typing.Tuple[typing.Union[typing.Type[BaseCompose], typing.Type[BasicTransform]], dict],
):
compose_cls, compose_kwargs = compose_transform
cls, kwargs = inner_transform
transform = cls(transforms=[], **kwargs) if issubclass(cls, BaseCompose) else cls(**kwargs)

with pytest.warns(UserWarning):
res_transform = compose_cls(transforms=transform, **compose_kwargs) # type: ignore
assert isinstance(res_transform.transforms, list)
6 changes: 4 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import inspect
import random
import typing
import inspect
import numpy as np

from io import StringIO
Expand Down Expand Up @@ -73,7 +73,9 @@ def get_filtered_transforms(
result = []

for name, cls in inspect.getmembers(albumentations):
if not inspect.isclass(cls) or not issubclass(cls, albumentations.BasicTransform):
if not inspect.isclass(cls) or not issubclass(
cls, (albumentations.BasicTransform, albumentations.BaseCompose)
):
continue

if "DeprecationWarning" in inspect.getsource(cls) or "FutureWarning" in inspect.getsource(cls):
Expand Down

0 comments on commit 477156d

Please sign in to comment.