Skip to content

Commit

Permalink
Do not modify the input image in Cutout and RGBShift (#39)
Browse files Browse the repository at this point in the history
  • Loading branch information
creafz committed Jul 26, 2018
1 parent 8ace9ff commit 4fdbfc5
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 2 deletions.
5 changes: 5 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def normalize(img, mean, std, max_pixel_value=255.0):


def cutout(img, num_holes, max_h_size, max_w_size):
# Make a copy of the input image since we don't want to modify it directly
img = img.copy()
height, width = img.shape[:2]

for n in range(num_holes):
Expand Down Expand Up @@ -154,6 +156,9 @@ def shift_rgb(img, r_shift, g_shift, b_shift):
if img.dtype == np.uint8:
img = img.astype('int32')
r_shift, g_shift, b_shift = np.int32(r_shift), np.int32(g_shift), np.int32(b_shift)
else:
# Make a copy of the input image since we don't want to modify it directly
img = img.copy()
img[..., 0] += r_shift
img[..., 1] += g_shift
img[..., 2] += b_shift
Expand Down
2 changes: 1 addition & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,7 @@ class Cutout(ImageOnlyTransform):
"""

def __init__(self, num_holes=0, max_h_size=0, max_w_size=0, p=0.5):
def __init__(self, num_holes=8, max_h_size=8, max_w_size=8, p=0.5):
super(Cutout, self).__init__(p)
self.num_holes = num_holes
self.max_h_size = max_h_size
Expand Down
88 changes: 87 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
Rotate, ShiftScaleRotate, CenterCrop, OpticalDistortion, GridDistortion, ElasticTransform, ToGray, RandomGamma, \
JpegCompression, HueSaturationValue, RGBShift, RandomBrightness, RandomContrast, Blur, MotionBlur, MedianBlur, \
GaussNoise, CLAHE, ChannelShuffle, InvertImg, IAAEmboss, IAASuperpixels, IAASharpen, IAAAdditiveGaussianNoise, \
IAAPiecewiseAffine, IAAPerspective, Cutout
IAAPiecewiseAffine, IAAPerspective, Cutout, Normalize, ToFloat, FromFloat


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
Expand All @@ -32,6 +32,7 @@
[RandomGamma, {}],
[ToGray, {}],
[Cutout, {}],
[GaussNoise, {}],
])
def test_image_only_augmentations(augmentation_cls, params, image, mask):
aug = augmentation_cls(p=1, **params)
Expand All @@ -56,6 +57,7 @@ def test_image_only_augmentations(augmentation_cls, params, image, mask):
[JpegCompression, {}],
[ToGray, {}],
[Cutout, {}],
[GaussNoise, {}],
])
def test_image_only_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
aug = augmentation_cls(p=1, **params)
Expand Down Expand Up @@ -132,3 +134,87 @@ def test_torch_to_tensor_augmentations(image, mask):
data = aug(image=image, mask=mask)
assert data['image'].dtype == torch.float32
assert data['mask'].dtype == torch.float32


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[JpegCompression, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
[RandomBrightness, {}],
[RandomContrast, {}],
[Blur, {}],
[MotionBlur, {}],
[MedianBlur, {}],
[GaussNoise, {}],
[CLAHE, {}],
[ChannelShuffle, {}],
[InvertImg, {}],
[RandomGamma, {}],
[ToGray, {}],
[Cutout, {}],
[PadIfNeeded, {}],
[VerticalFlip, {}],
[HorizontalFlip, {}],
[Flip, {}],
[Transpose, {}],
[RandomRotate90, {}],
[Rotate, {}],
[ShiftScaleRotate, {}],
[OpticalDistortion, {}],
[GridDistortion, {}],
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[Normalize, {}],
[GaussNoise, {}],
[ToFloat, {}],
[FromFloat, {}],
])
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
image_copy = image.copy()
mask_copy = mask.copy()
aug = augmentation_cls(p=1, **params)
aug(image=image, mask=mask)
assert np.array_equal(image, image_copy)
assert np.array_equal(mask, mask_copy)


@pytest.mark.parametrize(['augmentation_cls', 'params'], [
[Cutout, {}],
[HueSaturationValue, {}],
[RGBShift, {}],
[RandomBrightness, {}],
[RandomContrast, {}],
[Blur, {}],
[MotionBlur, {}],
[MedianBlur, {'blur_limit': (3, 5)}],
[GaussNoise, {}],
[ChannelShuffle, {}],
[InvertImg, {}],
[RandomGamma, {}],
[ToGray, {}],
[Cutout, {}],
[PadIfNeeded, {}],
[VerticalFlip, {}],
[HorizontalFlip, {}],
[Flip, {}],
[Transpose, {}],
[RandomRotate90, {}],
[Rotate, {}],
[ShiftScaleRotate, {}],
[OpticalDistortion, {}],
[GridDistortion, {}],
[ElasticTransform, {}],
[CenterCrop, {'height': 10, 'width': 10}],
[RandomCrop, {'height': 10, 'width': 10}],
[Normalize, {}],
[GaussNoise, {}],
[ToFloat, {}],
[FromFloat, {}],
])
def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image):
float_image_copy = float_image.copy()
aug = augmentation_cls(p=1, **params)
aug(image=float_image)
assert np.array_equal(float_image, float_image_copy)

0 comments on commit 4fdbfc5

Please sign in to comment.