Skip to content

Commit

Permalink
brightness_contrast_adjust optimization (#329)
Browse files Browse the repository at this point in the history
* brightness_contrast_adjust optimization

* benchmark: Return brightness_contrast_adjust to standard behaviour
  • Loading branch information
Dipet authored and ternaus committed Sep 11, 2019
1 parent b612786 commit 4e12c6e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 11 deletions.
24 changes: 16 additions & 8 deletions albumentations/augmentations/functional.py
Expand Up @@ -1101,18 +1101,24 @@ def gauss_noise(image, gauss):
return image + gauss


def _brightness_contrast_adjust_non_uint(img, alpha=1, beta=0):
@clipped
def _brightness_contrast_adjust_non_uint(img, alpha=1, beta=0, beta_by_max=False):
dtype = img.dtype
img = img.astype('float32')

if alpha != 1:
img *= alpha
if beta != 0:
img += beta * np.mean(img)
if beta_by_max:
max_value = MAX_VALUES_BY_DTYPE[dtype]
img += beta * max_value
else:
img += beta * np.mean(img)
return img


@preserve_shape
def _brightness_contrast_adjust_uint(img, alpha=1, beta=0):
def _brightness_contrast_adjust_uint(img, alpha=1, beta=0, beta_by_max=False):
dtype = np.dtype('uint8')

max_value = MAX_VALUES_BY_DTYPE[dtype]
Expand All @@ -1122,19 +1128,21 @@ def _brightness_contrast_adjust_uint(img, alpha=1, beta=0):
if alpha != 1:
lut *= alpha
if beta != 0:
lut += beta * np.mean(img)
if beta_by_max:
lut += beta * max_value
else:
lut += beta * np.mean(img)

lut = np.clip(lut, 0, max_value).astype(dtype)
img = cv2.LUT(img, lut)
return img


@clipped
def brightness_contrast_adjust(img, alpha=1, beta=0):
def brightness_contrast_adjust(img, alpha=1, beta=0, beta_by_max=False):
if img.dtype == np.uint8:
return _brightness_contrast_adjust_uint(img, alpha, beta)
return _brightness_contrast_adjust_uint(img, alpha, beta, beta_by_max)
else:
return _brightness_contrast_adjust_non_uint(img, alpha, beta)
return _brightness_contrast_adjust_non_uint(img, alpha, beta, beta_by_max)


@clipped
Expand Down
13 changes: 10 additions & 3 deletions albumentations/augmentations/transforms.py
Expand Up @@ -1993,6 +1993,8 @@ class RandomBrightnessContrast(ImageOnlyTransform):
If limit is a single float, the range will be (-limit, limit). Default: 0.2.
contrast_limit ((float, float) or float): factor range for changing contrast.
If limit is a single float, the range will be (-limit, limit). Default: 0.2.
brightness_by_max (Boolean): If True adjust contrast by image dtype maximum,
else adjust contrast by image mean.
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -2002,13 +2004,18 @@ class RandomBrightnessContrast(ImageOnlyTransform):
uint8, float32
"""

def __init__(self, brightness_limit=0.2, contrast_limit=0.2, always_apply=False, p=0.5):
def __init__(self, brightness_limit=0.2, contrast_limit=0.2, brightness_by_max=None, always_apply=False, p=0.5):
super(RandomBrightnessContrast, self).__init__(always_apply, p)
self.brightness_limit = to_tuple(brightness_limit)
self.contrast_limit = to_tuple(contrast_limit)
self.brightness_by_max = brightness_by_max

if brightness_by_max is None:
DeprecationWarning('In the version 0.4.0 default behavior of RandomBrightnessContrast '
'brightness_by_max will be changed to True.')

def apply(self, img, alpha=1., beta=0., **params):
return F.brightness_contrast_adjust(img, alpha, beta)
return F.brightness_contrast_adjust(img, alpha, beta, self.brightness_by_max)

def get_params(self):
return {
Expand All @@ -2017,7 +2024,7 @@ def get_params(self):
}

def get_transform_init_args_names(self):
return ('brightness_limit', 'contrast_limit')
return ('brightness_limit', 'contrast_limit', 'brightness_by_max')


class RandomBrightness(RandomBrightnessContrast):
Expand Down

0 comments on commit 4e12c6e

Please sign in to comment.