Skip to content

Commit

Permalink
Make apply_to_bboxes work with a list of bounding boxes (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
creafz committed Jul 26, 2018
1 parent 4fdbfc5 commit cec846e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 2 deletions.
5 changes: 4 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,14 @@ class DualTransform(BasicTransform):

@property
def targets(self):
return {'image': self.apply, 'mask': self.apply_to_mask, 'bboxes': self.apply_to_bbox}
return {'image': self.apply, 'mask': self.apply_to_mask, 'bboxes': self.apply_to_bboxes}

def apply_to_bbox(self, bbox, **params):
raise NotImplementedError

def apply_to_bboxes(self, bboxes, **params):
return [self.apply_to_bbox(bbox, **params) for bbox in bboxes]

def apply_to_mask(self, img, **params):
return self.apply(img, **{k: cv2.INTER_NEAREST if k == 'interpolation' else v for k, v in params.items()})

Expand Down
14 changes: 13 additions & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import cv2
import pytest

from albumentations import Transpose, Rotate, ShiftScaleRotate, OpticalDistortion, GridDistortion, ElasticTransform
from albumentations import Transpose, Rotate, ShiftScaleRotate, OpticalDistortion, GridDistortion, ElasticTransform,\
VerticalFlip, HorizontalFlip
import albumentations.augmentations.functional as F


Expand Down Expand Up @@ -87,3 +88,14 @@ def test_elastic_transform_interpolation(monkeypatch, interpolation):
random_state=np.random.RandomState(1111))
assert np.array_equal(data['image'], expected_image)
assert np.array_equal(data['mask'], expected_mask)


@pytest.mark.parametrize(['augmentation_cls', 'expected'], [
[VerticalFlip, [(194, 2, 5, 5), (120, 67, 35, 24)]],
[HorizontalFlip, [(1, 93, 5, 5), (45, 9, 35, 24)]],
])
def test_apply_to_bboxes(augmentation_cls, expected):
aug = augmentation_cls(p=1)
image = np.ones((100, 200, 3), dtype=np.uint8)
data = aug(image=image, bboxes=[(1, 2, 5, 5), (45, 67, 35, 24)])
assert np.array_equal(data['bboxes'], expected)

0 comments on commit cec846e

Please sign in to comment.