Skip to content

Commit

Permalink
Added longest max size augmentation (#44)
Browse files Browse the repository at this point in the history
* Added longest max size augmentation

* pep fixes
  • Loading branch information
ternaus committed Aug 13, 2018
1 parent bf9d218 commit d99db62
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 5 deletions.
11 changes: 11 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,17 @@ def blur(img, ksize):
return cv2.blur(img, (ksize, ksize))


def longest_max_size(img, longest_max_size, interpolation):
    """Resize ``img`` so that its longest side equals ``longest_max_size``.

    Both sides are multiplied by the same factor, so the aspect ratio is kept.

    Args:
        img: array whose first two ``shape`` entries are (height, width).
        longest_max_size: target length of the longer side.
        interpolation: interpolation flag passed through to ``cv2.resize``.

    Returns:
        The resized image, or ``img`` itself (same object) when the longest
        side already has the requested length.
    """
    height, width = img.shape[:2]

    scale = longest_max_size / float(max(width, height))
    if scale == 1.0:
        return img

    # Round rather than truncate: ``dim * scale`` can land just below the
    # intended integer because of floating-point error (for example
    # 49 * (1 / 49.0) evaluates to 0.9999999999999999), and int() would then
    # produce a side one pixel short of the target (or a zero-sized side).
    # int(round(...)) is used so the result is an int on Python 2 as well.
    # The size tuple is ordered (width, height), as cv2.resize expects.
    out_size = tuple(int(round(dim * scale)) for dim in (width, height))
    return cv2.resize(img, out_size, interpolation=interpolation)


def median_blur(img, ksize):
if img.dtype == np.float32 and ksize not in {3, 5}:
raise ValueError(
Expand Down
28 changes: 24 additions & 4 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
'RandomRotate90', 'Rotate', 'ShiftScaleRotate', 'CenterCrop', 'OpticalDistortion', 'GridDistortion',
'ElasticTransform', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift', 'RandomBrightness', 'RandomContrast',
'MotionBlur', 'MedianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg', 'ToGray',
'JpegCompression', 'Cutout', 'ToFloat', 'FromFloat', 'Crop', 'RandomScale']
'JpegCompression', 'Cutout', 'ToFloat', 'FromFloat', 'Crop', 'RandomScale', 'LongestMaxSize']


class PadIfNeeded(DualTransform):
Expand All @@ -37,9 +37,6 @@ def __init__(self, min_height=1024, min_width=1024, p=1.0):
# Delegates to F.pad, forwarding the instance's min_height and min_width; extra **params are ignored here.
def apply(self, img, **params):
return F.pad(img, min_height=self.min_height, min_width=self.min_width)

def apply_to_bbox(self, bbox, **params):
pass


class Crop(DualTransform):
"""Crops region from image
Expand Down Expand Up @@ -153,6 +150,29 @@ def apply(self, img, **params):
return F.transpose(img)


class LongestMaxSize(DualTransform):
    """Rescale an image so that its longest side is equal to max_size, keeping the aspect ratio of the initial image.

    Args:
        max_size (int): maximum size of the image after the transformation. Default: 1024.
        interpolation: interpolation flag forwarded to ``F.longest_max_size``. Default: cv2.INTER_LINEAR.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(self, max_size=1024, interpolation=cv2.INTER_LINEAR, p=1):
        super(LongestMaxSize, self).__init__(p)
        self.interpolation = interpolation
        self.max_size = max_size

    def apply(self, img, **params):
        # F.longest_max_size requires both the target size and the interpolation
        # flag (neither has a default), so the values configured in __init__
        # must be forwarded explicitly; calling it with only img raises TypeError.
        return F.longest_max_size(img, longest_max_size=self.max_size, interpolation=self.interpolation)


class RandomRotate90(DualTransform):
"""Randomly rotates the input by 90 degrees zero or more times.
Expand Down
2 changes: 1 addition & 1 deletion docs/probabilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ About probabilities.
Default probability values
******************************

**Compose**, **PadIfNeeded**, **CenterCrop**, **RandomCrop**, **Crop**, **Normalize**, **ToFloat**, **FromFloat**, **ToTensor** have default
**Compose**, **PadIfNeeded**, **CenterCrop**, **RandomCrop**, **Crop**, **Normalize**, **ToFloat**, **FromFloat**, **ToTensor**, **LongestMaxSize** have default
probability values equal to **1**. All others are equal to **0.5**.


Expand Down
16 changes: 16 additions & 0 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,22 @@ def test_scale(target):
assert np.array_equal(scaled, expected)


@pytest.mark.parametrize('target', ['image', 'mask'])
def test_longest_max_size(target):
    """A 4x3 input limited to a longest side of 3 must match the expected 3x2 array."""
    # Values 1..12 laid out row by row in 4 rows of 3 columns.
    source = np.arange(1, 13, dtype=np.uint8).reshape(4, 3)
    wanted = np.array([[2, 3], [6, 7], [10, 11]], dtype=np.uint8)

    source, wanted = convert_2d_to_target_format([source, wanted], target=target)
    result = F.longest_max_size(source, longest_max_size=3, interpolation=cv2.INTER_LINEAR)
    assert np.array_equal(result, wanted)


def test_from_float_unknown_dtype():
img = np.ones((100, 100, 3), dtype=np.float32)
with pytest.raises(RuntimeError) as exc_info:
Expand Down

0 comments on commit d99db62

Please sign in to comment.