Skip to content

Commit

Permalink
RGBShift optimization (#328)
Browse files Browse the repository at this point in the history
  • Loading branch information
Dipet authored and ternaus committed Sep 2, 2019
1 parent 61c9f8c commit c3cc277
Showing 1 changed file with 41 additions and 9 deletions.
50 changes: 41 additions & 9 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,18 +338,50 @@ def solarize(img, threshold=128):
return result_img


def _shift_image_uint8(img, value):
max_value = MAX_VALUES_BY_DTYPE[img.dtype]

lut = np.arange(0, max_value + 1).astype('float32')
lut += value

lut = np.clip(lut, 0, max_value).astype(img.dtype)
return cv2.LUT(img, lut)


@preserve_shape
def _shift_rgb_uint8(img, r_shift, g_shift, b_shift):
if r_shift == g_shift == b_shift:
h, w, c = img.shape
img = img.reshape([h, w * c])

return _shift_image_uint8(img, r_shift)

result_img = np.empty_like(img)
shifts = [r_shift, g_shift, b_shift]
for i, shift in enumerate(shifts):
result_img[..., i] = _shift_image_uint8(img[..., i], shift)

return result_img


@clipped
def _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift):
if r_shift == g_shift == b_shift:
return img + r_shift

result_img = np.empty_like(img)
shifts = [r_shift, g_shift, b_shift]
for i, shift in enumerate(shifts):
result_img[..., i] = img[..., i] + shift

return result_img


def shift_rgb(img, r_shift, g_shift, b_shift):
if img.dtype == np.uint8:
img = img.astype('int32')
r_shift, g_shift, b_shift = np.int32(r_shift), np.int32(g_shift), np.int32(b_shift)
else:
# Make a copy of the input image since we don't want to modify it directly
img = img.copy()
img[..., 0] += r_shift
img[..., 1] += g_shift
img[..., 2] += b_shift
return img
return _shift_rgb_uint8(img, r_shift, g_shift, b_shift)

return _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift)


def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
Expand Down

0 comments on commit c3cc277

Please sign in to comment.