Commit a5bbfed

Merge pull request #69 from albu/enhance_bbox_pipeline

Enhance bbox pipeline

albu committed Sep 19, 2018
2 parents 8a7353c + eab24a4 commit a5bbfed
Showing 11 changed files with 283 additions and 153 deletions.
4 changes: 2 additions & 2 deletions albumentations/__init__.py
@@ -1,9 +1,9 @@
from __future__ import absolute_import

__version__ = '0.0.14'
__version__ = '0.0.15'

from .core.composition import *
from .core.transforms_interface import *
from .augmentations.transforms import *
from .augmentations.bbox import *
from .augmentations.bbox_utils import *
from .imgaug.transforms import *
albumentations/augmentations/bbox_utils.py
@@ -1,5 +1,6 @@
from __future__ import division

import numpy as np

__all__ = ['normalize_bbox', 'denormalize_bbox', 'normalize_bboxes', 'denormalize_bboxes', 'calculate_bbox_area',
'filter_bboxes_by_visibility', 'convert_bbox_to_albumentations', 'convert_bbox_from_albumentations',
@@ -42,41 +43,47 @@ def calculate_bbox_area(bbox, rows, cols):
return area


def filter_bboxes_by_visibility(img, bboxes, transformed_img, transformed_bboxes, threshold):
def filter_bboxes_by_visibility(original_shape, bboxes, transformed_shape, transformed_bboxes,
threshold=0., min_area=0.):
"""Filter bounding boxes and return only those boxes whose visibility after transformation is above
the threshold.
the threshold and minimal area of the bounding box in pixels is more than min_area.
Args:
img (np.array): original image
original_shape (tuple): original image shape
bboxes (list): original bounding boxes
transformed_img (np.array): transformed image
transformed_shape (tuple): transformed image shape
transformed_bboxes (list): transformed bounding boxes
threshold (float): visibility threshold. Should be a value in the range [0.0, 1.0].
min_area (float): Minimal area threshold.
"""
img_height, img_width = img.shape[:2]
transformed_img_height, transformed_img_width = transformed_img.shape[:2]
img_height, img_width = original_shape[:2]
transformed_img_height, transformed_img_width = transformed_shape[:2]

visible_bboxes = []
for bbox, transformed_bbox in zip(bboxes, transformed_bboxes):
if not all(0.0 <= value <= 1.0 for value in transformed_bbox[:4]):
continue
bbox_area = calculate_bbox_area(bbox, img_height, img_width)
transformed_bbox_area = calculate_bbox_area(transformed_bbox, transformed_img_height, transformed_img_width)
if transformed_bbox_area < min_area:
continue
visibility = transformed_bbox_area / bbox_area
if visibility >= threshold:
visible_bboxes.append(transformed_bbox)
return visible_bboxes
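For context, a minimal usage sketch of the new shape-based signature (the shapes, boxes and threshold below are invented for illustration; the boxes are in the normalized albumentations format):

    # hypothetical values: a 100x200 image vertically cropped to 50x200
    original_shape = (100, 200, 3)
    transformed_shape = (50, 200, 3)
    bboxes = [[0.1, 0.1, 0.5, 0.9]]
    transformed_bboxes = [[0.1, 0.1, 0.5, 0.4]]
    # keep only boxes that retain at least 30% of their original pixel area
    visible = filter_bboxes_by_visibility(original_shape, bboxes,
                                          transformed_shape, transformed_bboxes,
                                          threshold=0.3)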


def convert_bbox_to_albumentations(shape, bbox, source_format):
def convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity=False):
"""Convert a bounding box from a format specified in `source_format` to the format used by albumentations:
normalized coordinates of bottom-left and top-right corners of the bounding box in a form of
`[x_min, y_min, x_max, y_max]` e.g. `[0.15, 0.27, 0.67, 0.5]`.
Args:
shape (tuple): input image shape. Image must have at least 2 dims
bbox (list): bounding box
source_format (str): format of the bounding box. Should be 'coco' or 'pascal_voc'.
check_validity (bool): check if all boxes are valid boxes
rows (int): image height
cols (int): image width
Note:
The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
@@ -90,25 +97,29 @@ def convert_bbox_to_albumentations(shape, bbox, source_format):
raise ValueError(
"Unknown source_format {}. Supported formats are: 'coco' and 'pascal_voc'".format(source_format)
)
img_height, img_width = shape[:2]
if source_format == 'coco':
x_min, y_min, width, height = bbox[:4]
x_max = x_min + width
y_max = y_min + height
else:
x_min, y_min, x_max, y_max = bbox[:4]
bbox = [x_min, y_min, x_max, y_max] + list(bbox[4:])
bbox = normalize_bbox(bbox, img_height, img_width)
bbox = normalize_bbox(bbox, rows, cols)
if check_validity:
check_bbox(bbox)
return bbox
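A small sketch of the new argument order, using an invented coco box on a 200x400 (rows x cols) image:

    coco_bbox = [40, 20, 100, 60]  # coco: [x_min, y_min, width, height] in pixels
    albu_bbox = convert_bbox_to_albumentations(coco_bbox, 'coco', rows=200, cols=400,
                                               check_validity=True)
    # expected to be roughly [0.1, 0.1, 0.35, 0.4] (normalized corners)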


def convert_bbox_from_albumentations(shape, bbox, target_format):
def convert_bbox_from_albumentations(bbox, target_format, rows, cols, check_validity=False):
"""Convert a bounding box from the format used by albumentations to a format, specified in `target_format`.
Args:
shape (tuple): input image shape. Image must have at least 2 dims
bbox (list): bounding box with coordinates in the format used by albumentations
target_format (str): required format of the output bounding box. Should be 'coco' or 'pascal_voc'.
check_validity (bool): check if all boxes are valid boxes
rows (int): image height
cols (int): image width
Note:
The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
@@ -122,8 +133,9 @@ def convert_bbox_from_albumentations(shape, bbox, target_format):
raise ValueError(
"Unknown target_format {}. Supported formats are: 'coco' and 'pascal_voc'".format(target_format)
)
img_height, img_width = shape[:2]
bbox = denormalize_bbox(bbox, img_height, img_width)
if check_validity:
check_bbox(bbox)
bbox = denormalize_bbox(bbox, rows, cols)
if target_format == 'coco':
x_min, y_min, x_max, y_max = bbox[:4]
width = x_max - x_min
@@ -132,13 +144,13 @@ def convert_bbox_from_albumentations(shape, bbox, target_format):
return bbox
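The inverse direction, with the same illustrative values:

    albu_bbox = [0.1, 0.1, 0.35, 0.4]
    coco_bbox = convert_bbox_from_albumentations(albu_bbox, 'coco', rows=200, cols=400)
    # roughly [40, 20, 100, 60] (pixels)
    voc_bbox = convert_bbox_from_albumentations(albu_bbox, 'pascal_voc', rows=200, cols=400)
    # roughly [40, 20, 140, 80] (pixel corners)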


def convert_bboxes_to_albumentations(shape, bboxes, source_format):
def convert_bboxes_to_albumentations(bboxes, source_format, rows, cols, check_validity=False):
"""Convert a list bounding boxes from a format specified in `source_format` to the format used by albumentations
"""
return [convert_bbox_to_albumentations(shape, bbox, source_format) for bbox in bboxes]
return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes]


def convert_bboxes_from_albumentations(shape, bboxes, target_format):
def convert_bboxes_from_albumentations(bboxes, target_format, rows, cols, check_validity=False):
"""Convert a list of bounding boxes from the format used by albumentations to a format, specified
in `target_format`.
@@ -147,4 +159,60 @@ def convert_bboxes_from_albumentations(shape, bboxes, target_format):
bboxes (list): List of bounding boxes with coordinates in the format used by albumentations
target_format (str): required format of the output bounding box. Should be 'coco' or 'pascal_voc'.
"""
return [convert_bbox_from_albumentations(shape, bbox, target_format) for bbox in bboxes]
return [convert_bbox_from_albumentations(bbox, target_format, rows, cols, check_validity) for bbox in bboxes]
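The list variants simply broadcast the per-box call; a tiny round-trip sketch with invented boxes:

    boxes = [[40, 20, 100, 60], [10, 10, 30, 30]]  # coco, pixels
    albu_boxes = convert_bboxes_to_albumentations(boxes, 'coco', rows=200, cols=400,
                                                  check_validity=True)
    restored = convert_bboxes_from_albumentations(albu_boxes, 'coco', rows=200, cols=400)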


def check_bbox(bbox):
"""Check if bbox boundaries are in range 0, 1 and minimums are lesser then maximums"""
for name, value in zip(['x_min', 'y_min', 'x_max', 'y_max'], bbox[:4]):
if not 0 <= value <= 1:
raise ValueError(
'Expected {name} for bbox {bbox} '
'to be in the range [0.0, 1.0], got {value}.'.format(
bbox=bbox,
name=name,
value=value,
)
)
x_min, y_min, x_max, y_max = bbox[:4]
if x_max <= x_min:
raise ValueError('x_max is less than or equal to x_min for bbox {bbox}.'.format(
bbox=bbox,
))
if y_max <= y_min:
raise ValueError('y_max is less than or equal to y_min for bbox {bbox}.'.format(
bbox=bbox,
))
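Behaviour in a nutshell (values invented):

    check_bbox([0.1, 0.2, 0.5, 0.6])    # passes silently
    check_bbox([-0.1, 0.2, 0.5, 0.6])   # ValueError: x_min outside [0.0, 1.0]
    check_bbox([0.5, 0.2, 0.4, 0.6])    # ValueError: x_max is less than or equal to x_min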


def check_bboxes(bboxes):
"""Check if bboxes boundaries are in range 0, 1 and minimums are lesser then maximums"""
for bbox in bboxes:
check_bbox(bbox)


def filter_bboxes(bboxes, rows, cols, min_area=0., min_visibility=0.):
"""Remove bounding boxes that either lie outside of the visible area by more then min_visibility
or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size.
Args:
min_area (float): minimum area of a bounding box. All bounding boxes whose visible area in pixels
is less than this value will be removed. Default: 0.0.
min_visibility (float): minimum fraction of its original area a bounding box must retain to remain in the list. Default: 0.0.
rows (int): Image rows.
cols (int): Image cols.
"""
resulting_boxes = []
for bbox in bboxes:
if min_visibility:
transformed_box_area = calculate_bbox_area(bbox, rows, cols)
bbox[:4] = np.clip(bbox[:4], 0, 1.)
clipped_box_area = calculate_bbox_area(bbox, rows, cols)
if not transformed_box_area or clipped_box_area / transformed_box_area < min_visibility:
continue
else:
bbox[:4] = np.clip(bbox[:4], 0, 1.)
if min_area and calculate_bbox_area(bbox, rows, cols) < min_area:
continue
resulting_boxes.append(bbox)
return resulting_boxes
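A sketch of the combined clipping and filtering on a 100x200 image (boxes and thresholds invented; the fifth element mimics a label field appended by the compose pipeline):

    boxes = [[-0.2, 0.1, 0.5, 0.9, 'car'],          # sticks out on the left
             [0.400, 0.400, 0.401, 0.401, 'dog']]   # tiny box
    kept = filter_bboxes(boxes, rows=100, cols=200, min_area=25, min_visibility=0.4)
    # the first box is clipped to [0, 0.1, 0.5, 0.9] and kept (about 71% of its area survives),
    # the second is dropped because its pixel area falls below min_area
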
16 changes: 2 additions & 14 deletions albumentations/augmentations/functional.py
@@ -8,8 +8,7 @@
import numpy as np
from scipy.ndimage.filters import gaussian_filter

from albumentations.augmentations.bbox import denormalize_bbox, normalize_bbox, calculate_bbox_area

from albumentations.augmentations.bbox_utils import denormalize_bbox, normalize_bbox

MAX_VALUES_BY_DTYPE = {
np.dtype('uint8'): 255,
@@ -537,7 +536,7 @@ def crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols):
bbox = denormalize_bbox(bbox, rows, cols)
x_min, y_min, x_max, y_max = bbox
x1, y1, x2, y2 = crop_coords
cropped_bbox = [max(x_min, x1) - x1, max(y_min, y1) - y1, min(x_max, x2) - x1, min(y_max, y2) - y1]
cropped_bbox = [x_min - x1, y_min - y1, x_max - x1, y_max - y1]
return normalize_bbox(cropped_bbox, crop_height, crop_width)
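Note that the crop no longer clamps a box to the crop window; out-of-crop coordinates are kept and cleaned up later by filter_bboxes. An illustrative call (values invented):

    # a box at [10, 10, 80, 80] px on a 100x100 image, cropped to its bottom-right 50x50 quarter
    bbox = normalize_bbox([10, 10, 80, 80], rows=100, cols=100)
    cropped = crop_bbox_by_coords(bbox, crop_coords=(50, 50, 100, 100),
                                  crop_height=50, crop_width=50, rows=100, cols=100)
    # roughly [-0.8, -0.8, 0.6, 0.6] in the crop's normalized coordinates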


@@ -596,14 +595,3 @@ def bbox_transpose(bbox, axis, rows, cols):
if axis == 1:
bbox = [1 - y_max, 1 - x_max, 1 - y_min, 1 - x_min]
return bbox


def filter_bboxes(bboxes, min_area, rows, cols):
filtered_bboxes = []
for bbox in bboxes:
if not all(0.0 <= value <= 1.0 for value in bbox[:4]):
continue
if min_area and calculate_bbox_area(bbox, rows, cols) < min_area:
continue
filtered_bboxes.append(bbox)
return filtered_bboxes
26 changes: 1 addition & 25 deletions albumentations/augmentations/transforms.py
@@ -13,7 +13,7 @@
'ElasticTransform', 'HueSaturationValue', 'PadIfNeeded', 'RGBShift', 'RandomBrightness', 'RandomContrast',
'MotionBlur', 'MedianBlur', 'GaussNoise', 'CLAHE', 'ChannelShuffle', 'InvertImg', 'ToGray',
'JpegCompression', 'Cutout', 'ToFloat', 'FromFloat', 'Crop', 'RandomScale', 'LongestMaxSize', 'Resize',
'FilterBboxes', 'RandomSizedCrop']
'RandomSizedCrop']


class PadIfNeeded(DualTransform):
@@ -1008,27 +1008,3 @@ def __init__(self, dtype='uint16', max_value=None, p=1.0):

def apply(self, img, **params):
return F.from_float(img, self.dtype, self.max_value)


class FilterBboxes(DualTransform):
"""Remove bounding boxes that either lie outside of the visible area or whose area in pixels is under
the threshold set by `min_area`.
Args:
min_area (float): minimum area of a bounding box. All bounding boxes whose visible area in pixels
is less than this value will be removed. Default: 0.0.
p (float): probability of applying the transform. Default: 1.0.
Targets:
bboxes
"""

def __init__(self, min_area=0.0, p=1.0):
super(FilterBboxes, self).__init__(p)
self.min_area = min_area

def apply(self, img, **params):
return img

def apply_to_bboxes(self, bboxes, **params):
return F.filter_bboxes(bboxes, self.min_area, **params)
86 changes: 82 additions & 4 deletions albumentations/core/composition.py
@@ -3,26 +3,104 @@
import random

import numpy as np
from albumentations.augmentations.bbox_utils import convert_bboxes_from_albumentations, \
convert_bboxes_to_albumentations, filter_bboxes


__all__ = ['Compose', 'OneOf', 'OneOrOther']
__all__ = ['Compose', 'OneOf', 'OneOrOther', 'ComposeWithBoxes']


class Compose(object):
"""Compose transforms together."""
"""Compose transforms together.
def __init__(self, transforms, p=1.0):
Args:
transforms (list): list of transformations to compose.
preprocessing_transforms (list): list of transforms to run before the main transforms
p (float): probability of applying the whole list of transforms. Default: 1.0.
"""

def __init__(self, transforms, preprocessing_transforms=[], p=1.0):
self.preprocessing_transforms = preprocessing_transforms
self.transforms = [t for t in transforms if t is not None]
self.p = p

def __call__(self, **data):
if random.random() < self.p:
need_to_run = random.random() < self.p
return self.run_transforms_if_needed(need_to_run, data)

def run_transforms_if_needed(self, need_to_run, data):
for t in self.preprocessing_transforms:
data = t(**data)
if need_to_run:
for t in self.transforms:
data = t(**data)
return data
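A hedged sketch of the new preprocessing hook (HorizontalFlip and Resize are stand-ins for any transforms, their parameters are illustrative): preprocessing transforms run on every call, the main list only when the composed probability check fires.

    import numpy as np
    from albumentations import Compose, Resize, HorizontalFlip

    image = np.zeros((100, 200, 3), dtype=np.uint8)
    aug = Compose([HorizontalFlip(p=0.5)],
                  preprocessing_transforms=[Resize(128, 128, p=1)],
                  p=0.9)
    result = aug(image=image)  # Resize is always applied, the flip only inside the 0.9 branch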


class ComposeWithBoxes(Compose):
"""Compose transforms and handle all transformations regrading bounding boxes
Args:
transforms (list): list of transformations to compose.
preprocessing_transforms (list): list of transforms to run before the main transforms
bbox_format (str): format of bounding boxes. Should be 'coco' or 'pascal_voc'.
label_fields (list): list of fields that are joined with boxes, e.g. labels. Should be same type as boxes.
min_area (float): minimum area of a bounding box. All bounding boxes whose visible area in pixels
is less than this value will be removed. Default: 0.0.
min_visibility (float): minimum fraction of its original area a bounding box must retain to remain in the list. Default: 0.0.
p (float): probability of applying the whole list of transforms. Default: 1.0.
"""

def __init__(self, transforms, bbox_format, label_fields=[], min_area=0., min_visibility=0.,
preprocessing_transforms=[], p=1.0):
super(ComposeWithBoxes, self).__init__(transforms, preprocessing_transforms, p=p)
self.bbox_format = bbox_format
self.label_fields = label_fields
self.min_area = min_area
self.min_visibility = min_visibility

def __call__(self, **data):
need_to_run = random.random() < self.p
if self.preprocessing_transforms or need_to_run:
if 'bboxes' not in data:
raise Exception('Please name field with bounding boxes `bboxes`')
if self.label_fields:
for field in self.label_fields:
bboxes_with_added_field = []
for bbox, field_value in zip(data['bboxes'], data[field]):
bboxes_with_added_field.append(list(bbox) + [field_value])
data['bboxes'] = bboxes_with_added_field

rows, cols = data['image'].shape[:2]
data['bboxes'] = convert_bboxes_to_albumentations(data['bboxes'], self.bbox_format, rows, cols,
check_validity=True)

data = self.run_transforms_if_needed(need_to_run, data)

rows, cols = data['image'].shape[:2]
data['bboxes'] = filter_bboxes(data['bboxes'], rows, cols, self.min_area, self.min_visibility)

data['bboxes'] = convert_bboxes_from_albumentations(data['bboxes'], self.bbox_format, rows, cols,
check_validity=True)

if self.label_fields:
for idx, field in enumerate(self.label_fields):
field_values = []
for bbox in data['bboxes']:
field_values.append(bbox[4 + idx])
data[field] = field_values
data['bboxes'] = [bbox[:4] for bbox in data['bboxes']]
return data
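A usage sketch of the new box-aware compose (the transforms, thresholds and data are illustrative; `labels` is an example label field):

    import numpy as np
    from albumentations import ComposeWithBoxes, VerticalFlip, RandomRotate90

    aug = ComposeWithBoxes([VerticalFlip(p=0.5), RandomRotate90(p=0.5)],
                           bbox_format='coco',
                           label_fields=['labels'],
                           min_area=16,
                           min_visibility=0.2)

    image = np.zeros((300, 400, 3), dtype=np.uint8)
    augmented = aug(image=image,
                    bboxes=[[97, 12, 150, 200]],   # coco: [x_min, y_min, width, height]
                    labels=[1])
    # augmented['bboxes'] come back in coco format and augmented['labels'] stays
    # in sync with the boxes that survived filtering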


class OneOf(object):
"""Select on of transforms to apply
Args:
transforms (list): list of transformations to compose.
p (float): probability of applying selected transform. Default: 0.5.
"""

def __init__(self, transforms, p=0.5):
self.transforms = transforms
self.p = p
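For completeness, the usual way to nest OneOf in a compose (transform choices are illustrative):

    from albumentations import Compose, OneOf, MotionBlur, MedianBlur, HorizontalFlip

    aug = Compose([OneOf([MotionBlur(p=1), MedianBlur(blur_limit=3, p=1)], p=0.5),
                   HorizontalFlip(p=0.5)])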
