Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented Perspective transform equal to imgaug #790

Merged
merged 10 commits into from
Dec 16, 2020
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ Spatial-level transforms will simultaneously change both an input image as well
| [NoOp](https://albumentations.ai/docs/api_reference/core/transforms_interface/#albumentations.core.transforms_interface.NoOp) | ✓ | ✓ | ✓ | ✓ |
| [OpticalDistortion](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.OpticalDistortion) | ✓ | ✓ | | |
| [PadIfNeeded](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.PadIfNeeded) | ✓ | ✓ | ✓ | ✓ |
| [Perspective](https://albumentations.ai/docs/api_reference/augmentations/geometric/transforms/#albumentations.augmentations.geometric.transforms.Perspective) | ✓ | ✓ | ✓ | ✓ |
| [RandomCrop](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomCrop) | ✓ | ✓ | ✓ | ✓ |
| [RandomCropNearBBox](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomCropNearBBox) | ✓ | ✓ | ✓ | ✓ |
| [RandomGridShuffle](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.RandomGridShuffle) | ✓ | ✓ | | |
Expand Down
50 changes: 50 additions & 0 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from scipy.ndimage.filters import gaussian_filter

from ..bbox_utils import denormalize_bbox, normalize_bbox
from ..functional import angle_2pi_range, preserve_channel_dim, _maybe_process_in_chunks, preserve_shape


Expand Down Expand Up @@ -310,3 +311,52 @@ def longest_max_size(img, max_size, interpolation):
@preserve_channel_dim
def smallest_max_size(img, max_size, interpolation):
return _func_max_size(img, max_size, interpolation, min)


@preserve_channel_dim
def perspective(img, matrix, max_width, max_height, border_val, border_mode, keep_size, interpolation):
Dipet marked this conversation as resolved.
Show resolved Hide resolved
h, w = img.shape[:2]
perspective_func = _maybe_process_in_chunks(
cv2.warpPerspective,
M=matrix,
dsize=(max_width, max_height),
borderMode=border_mode,
borderValue=border_val,
flags=interpolation,
)
warped = perspective_func(img)

if keep_size:
return resize(warped, h, w, interpolation=interpolation)

return warped


def perspective_bbox(bbox, h, w, matrix, max_width, max_height):
Dipet marked this conversation as resolved.
Show resolved Hide resolved
bbox = denormalize_bbox(bbox, h, w)

points = np.array(
[[bbox[0], bbox[1]], [bbox[2], bbox[1]], [bbox[2], bbox[2]], [bbox[0], bbox[2]]], dtype=np.float32
)

points = cv2.perspectiveTransform(points.reshape([1, 4, 2]), matrix)

bbox = [np.min(points[..., 0]), np.min(points[..., 1]), np.max(points[..., 0]), np.max(points[..., 1])]
return normalize_bbox(bbox, max_height, max_width)


def rotation2DMatrixToEulerAngles(matrix):
return np.arctan2(matrix[1, 0], matrix[0, 0])


def perspective_keypoint(keypoint, h, w, matrix, max_width, max_height, keep_size):
x, y, angle, scale = keypoint

keypoint = np.array([x, y], dtype=np.float32).reshape([1, 1, 2])

x, y = cv2.perspectiveTransform(keypoint, matrix)[0, 0]
angle += rotation2DMatrixToEulerAngles(matrix[:2, :2])
if not keep_size:
scale += max(max_height / h, max_width / w)

return x, y, angle, scale
182 changes: 181 additions & 1 deletion albumentations/augmentations/geometric/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from . import functional as F
from ...core.transforms_interface import DualTransform, to_tuple

__all__ = ["ShiftScaleRotate", "ElasticTransform"]
__all__ = ["ShiftScaleRotate", "ElasticTransform", "Perspective"]


class ShiftScaleRotate(DualTransform):
Expand Down Expand Up @@ -190,3 +190,183 @@ def get_params(self):

def get_transform_init_args_names(self):
return ("alpha", "sigma", "alpha_affine", "interpolation", "border_mode", "value", "mask_value", "approximate")


class Perspective(DualTransform):
"""Perform a random four point perspective transform of the input.

Args:
scale ((float, float): standard deviation of the normal distributions. These are used to sample
the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
keep_size (bool): Whether to resize image’s back to their original size after applying the perspective
transform. If set to False, the resulting images may end up having different shapes
and will always be a list, never an array. Default: True
pad_mode (OpenCV flag): OpenCV border mode.
pad_val (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
Default: 0
mask_pad_val (int, float, list of int, list of float): padding value for mask
if border_mode is cv2.BORDER_CONSTANT. Default: 0
fit_output (bool): If True, the image plane size and position will be adjusted to still capture
the whole image after perspective transformation. (Followed by image resizing if keep_size is set to True.)
Otherwise, parts of the transformed image may be outside of the image plane.
This setting should not be set to True when using large scale values as it could lead to very large images.
Default: False
p (float): probability of applying the transform. Default: 0.5.

Targets:
image, mask, keypoints, bboxes

Image types:
uint8, float32
"""

def __init__(
self,
scale=(0.05, 0.1),
keep_size=True,
pad_mode=cv2.BORDER_CONSTANT,
pad_val=0,
mask_pad_val=0,
fit_output=False,
interpolation=cv2.INTER_LINEAR,
always_apply=False,
p=0.5,
):
super().__init__(always_apply, p)
self.scale = to_tuple(scale, 0)
self.keep_size = keep_size
self.pad_mode = pad_mode
self.pad_val = pad_val
self.mask_pad_val = mask_pad_val
self.fit_output = fit_output
self.interpolation = interpolation

def apply(self, img, matrix=None, max_height=None, max_width=None, **params):
return F.perspective(
img, matrix, max_width, max_height, self.pad_val, self.pad_mode, self.keep_size, params["interpolation"]
)

def apply_to_bbox(self, bbox, matrix=None, max_height=None, max_width=None, **params):
return F.perspective_bbox(bbox, params["rows"], params["cols"], matrix, max_width, max_height)

def apply_to_keypoint(self, keypoint, matrix=None, max_height=None, max_width=None, **params):
return F.perspective_keypoint(
keypoint, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size
)

@property
def targets_as_params(self):
return ["image"]

def get_params_dependent_on_targets(self, params):
h, w = params["image"].shape[:2]

scale = np.random.uniform(*self.scale)
points = np.random.normal(0, scale, [4, 2])
points = np.mod(np.abs(points), 1)

# top left -- no changes needed, just use jitter
# top right
points[1, 0] = 1.0 - points[1, 0] # w = 1.0 - jitter
# bottom right
points[2] = 1.0 - points[2] # w = 1.0 - jitt
# bottom left
points[3, 1] = 1.0 - points[3, 1] # h = 1.0 - jitter

points[:, 0] *= w
points[:, 1] *= h

# Obtain a consistent order of the points and unpack them individually.
# Warning: don't just do (tl, tr, br, bl) = _order_points(...)
# here, because the reordered points is used further below.
points = self._order_points(points)
(tl, tr, br, bl) = points

# compute the width of the new image, which will be the
# maximum distance between bottom-right and bottom-left
# x-coordiates or the top-right and top-left x-coordinates
min_width = None
max_width = None
while min_width is None or min_width < 2:
width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
max_width = int(max(width_top, width_bottom))
min_width = int(min(width_top, width_bottom))
if min_width < 2:
step_size = (2 - min_width) / 2
tl[0] -= step_size
tr[0] += step_size
bl[0] -= step_size
br[0] += step_size

# compute the height of the new image, which will be the maximum distance between the top-right
# and bottom-right y-coordinates or the top-left and bottom-left y-coordinates
min_height = None
max_height = None
while min_height is None or min_height < 2:
height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
max_height = int(max(height_right, height_left))
min_height = int(min(height_right, height_left))
if min_height < 2:
step_size = (2 - min_height) / 2
tl[1] -= step_size
tr[1] -= step_size
bl[1] += step_size
br[1] += step_size

# now that we have the dimensions of the new image, construct
# the set of destination points to obtain a "birds eye view",
# (i.e. top-down view) of the image, again specifying points
# in the top-left, top-right, bottom-right, and bottom-left order
# do not use width-1 or height-1 here, as for e.g. width=3, height=2
# the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
dst = np.array([[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]], dtype=np.float32)

# compute the perspective transform matrix and then apply it
m = cv2.getPerspectiveTransform(points, dst)

if self.fit_output:
m, max_width, max_height = self._expand_transform(m, (h, w))

return {"matrix": m, "max_height": max_height, "max_width": max_width, "interpolation": self.interpolation}

@classmethod
def _expand_transform(cls, matrix, shape):
height, width = shape
# do not use width-1 or height-1 here, as for e.g. width=3, height=2, max_height
# the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
rect = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)
dst = cv2.perspectiveTransform(np.array([rect]), matrix)[0]

# get min x, y over transformed 4 points
# then modify target points by subtracting these minima => shift to (0, 0)
dst -= dst.min(axis=0, keepdims=True)
dst = np.around(dst, decimals=0)

matrix_expanded = cv2.getPerspectiveTransform(rect, dst)
max_width, max_height = dst.max(axis=0)
return matrix_expanded, int(max_width), int(max_height)

@classmethod
def _order_points(cls, pts):
# initialzie a list of coordinates that will be ordered such that the first entry in the list is the top-left,
# the second entry is the top-right, the third is the bottom-right, and the fourth is the bottom-left
pts_ordered = np.zeros((4, 2), dtype=np.float32)

# the top-left point will have the smallest sum, whereas the bottom-right point will have the largest sum
pointwise_sum = pts.sum(axis=1)
pts_ordered[0] = pts[np.argmin(pointwise_sum)]
pts_ordered[2] = pts[np.argmax(pointwise_sum)]

# now, compute the difference between the points, the top-right point will have the smallest difference,
# whereas the bottom-left will have the largest difference
diff = np.diff(pts, axis=1)
pts_ordered[1] = pts[np.argmin(diff)]
pts_ordered[3] = pts[np.argmax(diff)]

# return the ordered coordinates
return pts_ordered

def get_transform_init_args_names(self):
return ("scale", "keep_size", "pad_mode", "pad_val", "mask_pad_val", "fit_output", "interpolation")
19 changes: 7 additions & 12 deletions albumentations/imgaug/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from ..augmentations.bbox_utils import convert_bboxes_from_albumentations, convert_bboxes_to_albumentations
from ..augmentations.keypoints_utils import convert_keypoints_from_albumentations, convert_keypoints_to_albumentations
from ..core.transforms_interface import BasicTransform, DualTransform, ImageOnlyTransform, to_tuple
from ..augmentations import Perspective

import warnings

__all__ = [
"BasicIAATransform",
Expand Down Expand Up @@ -313,7 +316,7 @@ def get_transform_init_args_names(self):
return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")


class IAAPerspective(DualIAATransform):
class IAAPerspective(Perspective):
"""Perform a random four point perspective transform of the input.

Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
Expand All @@ -327,14 +330,6 @@ class IAAPerspective(DualIAATransform):
image, mask
"""

def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5):
super(IAAPerspective, self).__init__(always_apply, p)
self.scale = to_tuple(scale, 1.0)
self.keep_size = keep_size

@property
def processor(self):
return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size)

def get_transform_init_args_names(self):
return ("scale", "keep_size")
def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, **kwargs):
warnings.warn("This augmentation is deprecated. Please use Perspective instead")
super().__init__(scale=scale, keep_size=keep_size, always_apply=always_apply, p=p, **kwargs)
9 changes: 9 additions & 0 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
ColorJitter,
FDA,
HistogramMatching,
Perspective,
)


Expand Down Expand Up @@ -191,6 +192,7 @@ def test_image_only_augmentations_with_float_values(augmentation_cls, params, fl
[ISONoise, {}],
[RandomGridShuffle, {}],
[GridDropout, {}],
[Perspective, {}],
],
)
def test_dual_augmentations(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -221,6 +223,7 @@ def test_dual_augmentations(augmentation_cls, params, image, mask):
[RandomSizedCrop, {"min_max_height": (4, 8), "height": 10, "width": 10}],
[RandomGridShuffle, {}],
[GridDropout, {}],
[Perspective, {}],
],
)
def test_dual_augmentations_with_float_values(augmentation_cls, params, float_image, mask):
Expand Down Expand Up @@ -310,6 +313,7 @@ def test_imgaug_dual_augmentations(augmentation_cls, image, mask):
FDA,
{"reference_images": [np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)], "read_fn": lambda x: x},
],
[Perspective, {}],
],
)
def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -377,6 +381,7 @@ def test_augmentations_wont_change_input(augmentation_cls, params, image, mask):
FDA,
{"reference_images": [np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)], "read_fn": lambda x: x},
],
[Perspective, {}],
],
)
def test_augmentations_wont_change_float_input(augmentation_cls, params, float_image):
Expand Down Expand Up @@ -425,6 +430,7 @@ def test_augmentations_wont_change_float_input(augmentation_cls, params, float_i
{"reference_images": [np.random.randint(0, 256, [100, 100], dtype=np.uint8)], "read_fn": lambda x: x},
],
[FDA, {"reference_images": [np.random.randint(0, 256, [100, 100], dtype=np.uint8)], "read_fn": lambda x: x}],
[Perspective, {}],
],
)
def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -500,6 +506,7 @@ def test_augmentations_wont_change_shape_grayscale(augmentation_cls, params, ima
FDA,
{"reference_images": [np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)], "read_fn": lambda x: x},
],
[Perspective, {}],
],
)
def test_augmentations_wont_change_shape_rgb(augmentation_cls, params, image, mask):
Expand Down Expand Up @@ -562,6 +569,7 @@ def test_mask_fill_value(augmentation_cls, params):
[RandomBrightnessContrast, {}],
[MultiplicativeNoise, {}],
[GridDropout, {}],
[Perspective, {}],
],
)
def test_multichannel_image_augmentations(augmentation_cls, params):
Expand Down Expand Up @@ -589,6 +597,7 @@ def test_multichannel_image_augmentations(augmentation_cls, params):
[RandomBrightnessContrast, {}],
[MultiplicativeNoise, {}],
[GridDropout, {}],
[Perspective, {}],
],
)
def test_multichannel_image_augmentations_diff_channels(augmentation_cls, params):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_imgaug.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
convert_bboxes_from_albumentations,
convert_bboxes_to_albumentations,
)
from albumentations.imgaug.transforms import IAAPiecewiseAffine, IAAPerspective, IAAFliplr, IAAFlipud
from albumentations.imgaug.transforms import IAAPiecewiseAffine, IAAFliplr, IAAFlipud


@pytest.mark.parametrize("augmentation_cls", [IAAPiecewiseAffine, IAAPerspective, IAAFliplr])
@pytest.mark.parametrize("augmentation_cls", [IAAPiecewiseAffine, IAAFliplr])
def test_imagaug_dual_augmentations_are_deterministic(augmentation_cls, image):
aug = augmentation_cls(p=1)
mask = np.copy(image)
Expand Down
Loading