Skip to content

Commit

Permalink
Added crop augmentation (#41)
Browse files Browse the repository at this point in the history
* Added crop augmentations

* bugfix
  • Loading branch information
ternaus committed Aug 6, 2018
1 parent cec846e commit 78528d9
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
30 changes: 30 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,36 @@ def shift_scale_rotate(img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR
return img


def crop(img, x_min, y_min, x_max, y_max):
    """Crop a rectangular region out of an image.

    The region is taken with ordinary slicing, so the lower bounds are
    inclusive and the upper bounds are exclusive: the result has
    ``y_max - y_min`` rows and ``x_max - x_min`` columns.

    Args:
        img: image whose first two ``shape`` entries are (height, width) and
            which supports 2-D slicing.
        x_min (int): left column of the region (inclusive).
        y_min (int): top row of the region (inclusive).
        x_max (int): right column of the region (exclusive).
        y_max (int): bottom row of the region (exclusive).

    Returns:
        ``img[y_min:y_max, x_min:x_max]``.

    Raises:
        ValueError: if the region is empty or inverted, or if it does not lie
            entirely inside the image.
    """
    height, width = img.shape[:2]
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            'We should have x_min < x_max and y_min < y_max. But we got'
            ' (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})'.format(
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max
            )
        )

    # The upper bounds are exclusive slice ends, so they may be equal to the
    # image size (a crop touching the right/bottom edge is valid); only values
    # strictly greater than the size fall outside the image.
    if x_min < 0 or x_max > width or y_min < 0 or y_max > height:
        raise ValueError(
            'Values for crop should be non negative and not exceed image sizes'
            ' (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max},'
            ' height = {height}, width = {width})'.format(
                x_min=x_min,
                x_max=x_max,
                y_min=y_min,
                y_max=y_max,
                height=height,
                width=width
            )
        )

    return img[y_min:y_max, x_min:x_max]


def center_crop(img, crop_height, crop_width):
height, width = img.shape[:2]
if height < crop_height or width < crop_width:
Expand Down
34 changes: 32 additions & 2 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
'RandomRotate90', 'Rotate', 'ShiftScaleRotate', 'CenterCrop', 'OpticalDistortion', 'GridDistortion',
'ElasticTransform', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift', 'RandomBrightness', 'RandomContrast',
'MotionBlur', 'MedianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg', 'ToGray',
'JpegCompression', 'Cutout', 'ToFloat', 'FromFloat']
'JpegCompression', 'Cutout', 'ToFloat', 'FromFloat', 'Crop']


class PadIfNeeded(DualTransform):
Expand All @@ -26,7 +26,6 @@ class PadIfNeeded(DualTransform):
Image types:
uint8, float32
TODO: add application to boxes
"""

def __init__(self, min_height=1024, min_width=1024, p=1.0):
Expand All @@ -41,6 +40,33 @@ def apply_to_bbox(self, bbox, **params):
pass


class Crop(DualTransform):
    """Cut a fixed rectangular window out of the input.

    Args:
        x_min (int): x coordinate of the window's upper left corner
        y_min (int): y coordinate of the window's upper left corner
        x_max (int): x coordinate of the window's lower right corner
        y_max (int): y coordinate of the window's lower right corner
        p: passed unchanged to the ``DualTransform`` constructor

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(self, x_min=0, y_min=0, x_max=1024, y_max=1024, p=1.0):
        super(Crop, self).__init__(p)
        # Upper left corner of the window.
        self.x_min, self.y_min = x_min, y_min
        # Lower right corner of the window.
        self.x_max, self.y_max = x_max, y_max

    def apply(self, img, **params):
        # The window is fixed at construction time; ``params`` is accepted
        # but not consulted here.
        window = dict(
            x_min=self.x_min,
            y_min=self.y_min,
            x_max=self.x_max,
            y_max=self.y_max,
        )
        return F.crop(img, **window)


class VerticalFlip(DualTransform):
"""Flips the input vertically around the x-axis.
Expand Down Expand Up @@ -296,6 +322,7 @@ class OpticalDistortion(DualTransform):
Image types:
uint8, float32
"""

def __init__(self, distort_limit=0.05, shift_limit=0.05, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, p=0.5):
super(OpticalDistortion, self).__init__(p)
Expand All @@ -321,6 +348,7 @@ class GridDistortion(DualTransform):
Image types:
uint8, float32
"""

def __init__(self, num_steps=5, distort_limit=0.3, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, p=0.5):
super(GridDistortion, self).__init__(p)
Expand Down Expand Up @@ -351,6 +379,7 @@ class ElasticTransform(DualTransform):
Image types:
uint8, float32
"""

def __init__(self, alpha=1, sigma=50, alpha_affine=50, interpolation=cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REFLECT_101, p=0.5):
super(ElasticTransform, self).__init__(p)
Expand Down Expand Up @@ -633,6 +662,7 @@ class MedianBlur(Blur):
Image types:
uint8, float32
"""

def __init__(self, blur_limit=7, p=0.5):
super(MedianBlur, self).__init__(p)
self.blur_limit = to_tuple(blur_limit, 3)
Expand Down
2 changes: 1 addition & 1 deletion docs/probabilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ About probabilities.
Default probability values
******************************

**Compose**, **PadIfNeeded**, **CenterCrop**, **RandomCrop**, **Normalize**, **ToFloat**, **FromFloat**, **ToTensor** have default
**Compose**, **PadIfNeeded**, **CenterCrop**, **RandomCrop**, **Crop**, **Normalize**, **ToFloat**, **FromFloat**, **ToTensor** have default
probability values equal to **1**. All others are equal to **0.5**


Expand Down

0 comments on commit 78528d9

Please sign in to comment.