Skip to content

Commit

Permalink
Add support for non-8-bit images (#31)
Browse files Browse the repository at this point in the history
  • Loading branch information
creafz committed Jul 21, 2018
1 parent 80b8526 commit 6bf45c9
Show file tree
Hide file tree
Showing 10 changed files with 850 additions and 63 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ augmented = augmentation(**data)
image, mask, whatever_data, additional = augmented["image"], augmented["mask"], augmented["whatever_data"], augmented["additional"]
```

See `example.ipynb`
See [`example.ipynb`](notebooks/example.ipynb)


## Installation
You can use pip to install albumentations:
Expand All @@ -80,6 +81,10 @@ You can use this [Google Colaboratory notebook](https://colab.research.google.co
to adjust image augmentation parameters and see the resulting images.


## Working with non-8-bit images
[`example_16_bit_tiff.ipynb`](notebooks/example_16_bit_tiff.ipynb) shows how albumentations can be used to work with non-8-bit images (such as 16-bit and 32-bit TIFF images).


## Benchmarking results
To run the benchmark yourself follow the instructions in [benchmark/README.md](benchmark/README.md)

Expand Down
50 changes: 44 additions & 6 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import numpy as np
from scipy.ndimage.filters import gaussian_filter

MAX_CLIPPING_VALUES = {
# Value that represents "full intensity" for each supported image dtype.
# Used as the upper clipping bound and as the scale factor when converting
# images to and from floating point.
MAX_VALUES_BY_DTYPE = {
    np.dtype(np.uint8): 2 ** 8 - 1,
    np.dtype(np.uint16): 2 ** 16 - 1,
    np.dtype(np.uint32): 2 ** 32 - 1,
    np.dtype(np.float32): 1.0,
}


Expand Down Expand Up @@ -122,37 +125,43 @@ def clipped(func):
@wraps(func)
def wrapped_function(img, *args, **kwargs):
dtype = img.dtype
maxval = MAX_CLIPPING_VALUES.get(dtype, 1.0)
maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0)
return clip(func(img, *args, **kwargs), dtype, maxval)

return wrapped_function


def shift_hsv(img, hue_shift, sat_shift, val_shift):
    """Shift the hue, saturation and value channels of an RGB image.

    The image is converted to HSV with OpenCV, each channel is offset by the
    corresponding shift, and the result is converted back to RGB.

    Args:
        img: RGB image. uint8 images are processed with OpenCV's 8-bit HSV
            ranges (hue period 180, saturation/value up to 255); any other
            dtype is treated as a floating point image (hue period 360,
            saturation/value up to 1.0).
        hue_shift: Offset added to the hue channel; the result is wrapped
            around the hue circle.
        sat_shift: Offset added to the saturation channel; the result is
            clipped to the valid range.
        val_shift: Offset added to the value channel; the result is clipped
            to the valid range.

    Returns:
        The shifted RGB image with the same dtype as the input.
    """
    dtype = img.dtype
    is_uint8 = dtype == np.uint8
    # OpenCV stores hue as degrees / 2 for 8-bit images and as plain degrees
    # for float32 images, so the wrap-around period depends on the dtype.
    hue_period = 180 if is_uint8 else 360
    max_sat_val = 255 if is_uint8 else 1.0

    img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
    if is_uint8:
        # Widen to a signed type so negative shifts do not wrap around.
        img = img.astype(np.int32)
    hue, sat, val = cv2.split(img)

    hue = cv2.add(hue, hue_shift)
    # Wrap (not reflect) values that left the valid hue range: a hue of -10
    # must become period - 10, which is `hue + period`.
    hue = np.where(hue < 0, hue + hue_period, hue)
    hue = np.where(hue > hue_period, hue - hue_period, hue)
    hue = hue.astype(dtype)

    sat = clip(cv2.add(sat, sat_shift), dtype, max_sat_val)
    val = clip(cv2.add(val, val_shift), dtype, max_sat_val)

    img = cv2.merge((hue, sat, val)).astype(dtype)
    img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
    return img


@clipped
def shift_rgb(img, r_shift, g_shift, b_shift):
    """Add a per-channel offset to the first three channels of an image.

    The `clipped` decorator clips the returned array and passes the input's
    original dtype to `clip`, so this function only has to perform the
    addition.

    Args:
        img: Image whose last axis holds at least three channels
            (channel 0, 1 and 2 receive r_shift, g_shift and b_shift).
        r_shift: Offset added to channel 0.
        g_shift: Offset added to channel 1.
        b_shift: Offset added to channel 2.

    Returns:
        A new array with the shifted channels; the input array is left
        untouched.
    """
    if img.dtype == np.uint8:
        # Widen to a signed type so the addition cannot wrap around before
        # the result is clipped. `astype` already returns a fresh array.
        img = img.astype('int32')
        r_shift, g_shift, b_shift = np.int32(r_shift), np.int32(g_shift), np.int32(b_shift)
    else:
        # The in-place additions below would otherwise modify the caller's
        # image, so work on a copy.
        img = img.copy()
    img[..., 0] += r_shift
    img[..., 1] += g_shift
    img[..., 2] += b_shift
    return img


def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
if img.dtype != np.uint8:
raise TypeError('clahe supports only uint8 inputs')
img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
img[:, :, 0] = clahe.apply(img[:, :, 0])
Expand Down Expand Up @@ -190,6 +199,9 @@ def blur(img, ksize):


def median_blur(img, ksize):
    """Apply a median blur to an image using OpenCV.

    Args:
        img: Input image.
        ksize: Aperture size of the median filter.

    Returns:
        The blurred image as returned by `cv2.medianBlur`.

    Raises:
        ValueError: If `img` is a float32 or uint16 image and `ksize` is not
            3 or 5. OpenCV supports larger apertures only for 8-bit images,
            so the restriction is reported here with a readable message
            instead of surfacing as a low-level OpenCV error.
    """
    restricted_dtypes = (np.dtype('float32'), np.dtype('uint16'))
    if img.dtype in restricted_dtypes and ksize not in {3, 5}:
        raise ValueError(
            'Invalid ksize value {}. For a {} image the only valid ksize values are 3 and 5'.format(
                ksize, img.dtype))
    return cv2.medianBlur(img, ksize)


Expand All @@ -202,6 +214,8 @@ def motion_blur(img, ksize):


def jpeg_compression(img, quality):
    """Round-trip an image through JPEG encoding at the given quality.

    Args:
        img: uint8 image to compress.
        quality: Value passed to OpenCV as `cv2.IMWRITE_JPEG_QUALITY`.

    Returns:
        The image decoded back from the in-memory JPEG buffer.

    Raises:
        TypeError: If `img` is not a uint8 array.
    """
    if not img.dtype == np.uint8:
        raise TypeError('jpeg_compression supports only uint8 inputs')

    encode_params = (cv2.IMWRITE_JPEG_QUALITY, quality)
    _, jpeg_buffer = cv2.imencode('.jpg', img, encode_params)
    decoded = cv2.imdecode(jpeg_buffer, cv2.IMREAD_UNCHANGED)
    return decoded
Expand Down Expand Up @@ -370,6 +384,30 @@ def to_gray(img):
return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)


def to_float(img, max_value=None):
    """Convert an image to float32 and divide it by a maximum value.

    Args:
        img: Input image array.
        max_value: Divisor applied after the float32 conversion. When
            omitted, it is looked up in `MAX_VALUES_BY_DTYPE` using the
            image's dtype.

    Returns:
        A float32 array equal to `img / max_value`.

    Raises:
        RuntimeError: If `max_value` is omitted and the image dtype has no
            entry in `MAX_VALUES_BY_DTYPE`.
    """
    scale = max_value
    if scale is None:
        source_dtype = img.dtype
        try:
            scale = MAX_VALUES_BY_DTYPE[source_dtype]
        except KeyError:
            message = (
                "Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by "
                "passing the max_value argument"
            ).format(source_dtype)
            raise RuntimeError(message)
    as_float = img.astype('float32')
    return as_float / scale


def from_float(img, dtype, max_value=None):
    """Scale a float image by a maximum value and cast it to `dtype`.

    Args:
        img: Floating point image array.
        dtype: Target dtype. Anything accepted by `np.dtype` works
            (a dtype instance, a NumPy scalar type such as `np.uint8`, or a
            name such as `'uint8'`).
        max_value: Multiplier applied before the cast. When omitted, it is
            looked up in `MAX_VALUES_BY_DTYPE` using the target dtype.

    Returns:
        `img * max_value` cast to `dtype`.

    Raises:
        RuntimeError: If `max_value` is omitted and the target dtype has no
            entry in `MAX_VALUES_BY_DTYPE`.
    """
    # The lookup table is keyed by np.dtype instances; normalize so that
    # equivalent spellings like 'uint8' or np.uint8 hit the same entry.
    dtype = np.dtype(dtype)
    if max_value is None:
        try:
            max_value = MAX_VALUES_BY_DTYPE[dtype]
        except KeyError:
            raise RuntimeError(
                'Can\'t infer the maximum value for dtype {}. You need to specify the maximum value manually by '
                'passing the max_value argument'.format(dtype)
            )
    return (img * max_value).astype(dtype)


def bbox_vflip(bbox, cols, rows):
    """Mirror a bounding box's first coordinate within `cols`.

    The first element of the result is `cols - bbox[0] - bbox[2]`; every
    remaining element of `bbox` is carried over unchanged. `rows` is accepted
    but not used.

    Returns:
        A tuple with the mirrored first coordinate followed by `bbox[1:]`.
    """
    # NOTE(review): the arithmetic pairs bbox[0]/bbox[2] with `cols`, which
    # looks like an x/width mirror -- confirm this matches the intended
    # flip direction and the bbox layout used by callers.
    start, extent = bbox[0], bbox[2]
    mirrored_start = cols - start - extent
    untouched = tuple(bbox[1:])
    return (mirrored_start,) + untouched

Expand Down

0 comments on commit 6bf45c9

Please sign in to comment.