Skip to content

Commit

Permalink
Added Cutout augmentation (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
ternaus committed Jul 12, 2018
1 parent fa0fd0a commit 5af82d3
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
16 changes: 16 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,22 @@ def normalize(img, mean, std, max_pixel_value=255.0):
return img


def cutout(img, num_holes, max_h_size, max_w_size):
height, width = img.shape[:2]

for n in range(num_holes):
y = np.random.randint(height)
x = np.random.randint(width)

y1 = np.clip(y - max_h_size // 2, 0, height)
y2 = np.clip(y + max_h_size // 2, 0, height)
x1 = np.clip(x - max_w_size // 2, 0, width)
x2 = np.clip(x + max_w_size // 2, 0, width)

img[y1: y2, x1: x2] = 0
return img


def rotate(img, angle, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101):
height, width = img.shape[:2]
matrix = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0)
Expand Down
42 changes: 34 additions & 8 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
'RandomRotate90', 'Rotate', 'ShiftScaleRotate', 'CenterCrop', 'OpticalDistortion', 'GridDistortion',
'ElasticTransform', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift', 'RandomBrightness', 'RandomContrast',
'MotionBlur', 'MedianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg', 'ToGray',
'JpegCompression']
'JpegCompression', 'Cutout']


class PadIfNeeded(DualTransform):
Expand Down Expand Up @@ -316,8 +316,8 @@ class Normalize(ImageOnlyTransform):
"""Divides pixel values by 255 = 2**8 - 1, subtracts mean per channel and divides by std per channel
Args:
mean (float, float, float) - mean values
std (float, float, float) - std values
mean (float, float, float): mean values
std (float, float, float): std values
Targets:
image
Expand All @@ -332,15 +332,41 @@ def apply(self, image, **params):
return F.normalize(image, self.mean, self.std)


class Cutout(ImageOnlyTransform):
"""CoarseDropout of the square regions in the image
Args:
num_holes (int): number of regions to zero out
max_h_size (int): maximum height of the hole
max_w_size (int): maximum width of the hole
Targets:
image
Reference:
| https://arxiv.org/abs/1708.04552
| https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
"""

def __init__(self, num_holes=0, max_h_size=0, max_w_size=0, p=1.0):
super(Cutout, self).__init__(p)
self.num_holes = num_holes
self.max_h_size = max_h_size
self.max_w_size = max_w_size

def apply(self, image, **params):
return F.cutout(image, self.num_holes, self.max_h_size, self.max_w_size)


class JpegCompression(ImageOnlyTransform):
"""Decreases Jpeg compression of an image.
Was essential part of the [IEEE's Signal Processing Society - Camera Model Identification Challenge]
(https://www.kaggle.com/c/sp-society-camera-model-identification)
Args:
quality_lower (float) - lower bound on the jpeg quality. Should be in [0, 100] range
quality_upper (float) - lower bound on the jpeg quality. Should be in [0, 100] range
quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range
quality_upper (float): lower bound on the jpeg quality. Should be in [0, 100] range
Targets:
image
Expand Down

0 comments on commit 5af82d3

Please sign in to comment.