From 98eafe0cb5388f689f53209bde47e3939916c5e8 Mon Sep 17 00:00:00 2001 From: Mikhail Druzhinin Date: Sun, 12 Jun 2022 08:51:12 +0300 Subject: [PATCH] Fix safe rotate targets (#1109) * Fix safe rotate targets * Save all keypoints and bboxes * Fix worng angle error * Accurate safe rotate * Tests * Fix mask interpolation, tests, remove unused code * Remove dummy test * Fix mypy --- .../augmentations/geometric/functional.py | 144 +++++++----------- .../augmentations/geometric/rotate.py | 96 +++++++++--- .../augmentations/keypoints_utils.py | 4 +- tests/test_transforms.py | 97 ++++++++++-- 4 files changed, 210 insertions(+), 131 deletions(-) diff --git a/albumentations/augmentations/geometric/functional.py b/albumentations/augmentations/geometric/functional.py index 7713b86b7..b5525e0f1 100644 --- a/albumentations/augmentations/geometric/functional.py +++ b/albumentations/augmentations/geometric/functional.py @@ -44,7 +44,6 @@ "safe_rotate", "bbox_safe_rotate", "keypoint_safe_rotate", - "safe_rotate_enlarged_img_size", "piecewise_affine", "to_distance_maps", "from_distance_maps", @@ -539,112 +538,75 @@ def bbox_affine( @preserve_channel_dim def safe_rotate( img: np.ndarray, - angle: int = 0, - interpolation: int = cv2.INTER_LINEAR, - value: int = None, + matrix: np.ndarray, + interpolation: int, + value: Optional[int] = None, border_mode: int = cv2.BORDER_REFLECT_101, -): - - old_rows, old_cols = img.shape[:2] - - # getRotationMatrix2D needs coordinates in reverse order (width, height) compared to shape - image_center = (old_cols / 2, old_rows / 2) - - # Rows and columns of the rotated image (not cropped) - new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols) - - # Rotation Matrix - rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) - - # Shift the image to create padding - rotation_mat[0, 2] += new_cols / 2 - image_center[0] - rotation_mat[1, 2] += new_rows / 2 - image_center[1] - - # CV2 Transformation function - warp_affine_fn = _maybe_process_in_chunks( +) -> np.ndarray: + h, w = img.shape[:2] + warp_fn = _maybe_process_in_chunks( cv2.warpAffine, - M=rotation_mat, - dsize=(new_cols, new_rows), + M=matrix, + dsize=(w, h), flags=interpolation, borderMode=border_mode, borderValue=value, ) + return warp_fn(img) - # rotate image with the new bounds - rotated_img = warp_affine_fn(img) - - # Resize image back to the original size - resized_img = resize(img=rotated_img, height=old_rows, width=old_cols, interpolation=interpolation) - - return resized_img - - -def bbox_safe_rotate(bbox, angle, rows, cols): - old_rows = rows - old_cols = cols - - # Rows and columns of the rotated image (not cropped) - new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols) - - col_diff = int(np.ceil(abs(new_cols - old_cols) / 2)) - row_diff = int(np.ceil(abs(new_rows - old_rows) / 2)) - - # Normalize shifts - norm_col_shift = col_diff / new_cols - norm_row_shift = row_diff / new_rows - # shift bbox - shifted_bbox = ( - bbox[0] + norm_col_shift, - bbox[1] + norm_row_shift, - bbox[2] + norm_col_shift, - bbox[3] + norm_row_shift, +def bbox_safe_rotate( + bbox: Tuple[float, float, float, float], matrix: np.ndarray, cols: int, rows: int +) -> Tuple[float, float, float, float]: + x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols) + points = np.array( + [ + [x1, y1, 1], + [x2, y1, 1], + [x2, y2, 1], + [x1, y2, 1], + ] ) + points = points @ matrix.T + x1 = points[:, 0].min() + x2 = points[:, 0].max() + y1 = points[:, 1].min() + y2 = 
points[:, 1].max() - rotated_bbox = bbox_rotate(bbox=shifted_bbox, angle=angle, rows=new_rows, cols=new_cols) - - # Bounding boxes are scale invariant, so this does not need to be rescaled to the old size - return rotated_bbox - - -def keypoint_safe_rotate(keypoint, angle, rows, cols): - old_rows = rows - old_cols = cols - - # Rows and columns of the rotated image (not cropped) - new_rows, new_cols = safe_rotate_enlarged_img_size(angle=angle, rows=old_rows, cols=old_cols) - - col_diff = int(np.ceil(abs(new_cols - old_cols) / 2)) - row_diff = int(np.ceil(abs(new_rows - old_rows) / 2)) - - # Shift keypoint - shifted_keypoint = (keypoint[0] + col_diff, keypoint[1] + row_diff, keypoint[2], keypoint[3]) + def fix_point(pt1: float, pt2: float, max_val: float) -> Tuple[float, float]: + # In my opinion, these errors should be very low, around 1-2 pixels. + if pt1 < 0: + return 0, pt2 + pt1 + if pt2 > max_val: + return pt1 - (pt2 - max_val), max_val + return pt1, pt2 - # Rotate keypoint - rotated_keypoint = keypoint_rotate(shifted_keypoint, angle, rows=new_rows, cols=new_cols) + x1, x2 = fix_point(x1, x2, cols) + y1, y2 = fix_point(y1, y2, rows) - # Scale the keypoint - return keypoint_scale(rotated_keypoint, old_cols / new_cols, old_rows / new_rows) + return normalize_bbox((x1, y1, x2, y2), rows, cols) -def safe_rotate_enlarged_img_size(angle: float, rows: int, cols: int): - - deg_angle = abs(angle) - - # The rotation angle - angle = np.deg2rad(deg_angle % 90) - - # The width of the frame to contain the rotated image - r_cols = cols * np.cos(angle) + rows * np.sin(angle) +def keypoint_safe_rotate( + keypoint: Tuple[float, float, float, float], + matrix: np.ndarray, + angle: float, + scale_x: float, + scale_y: float, + cols: int, + rows: int, +) -> Tuple[float, float, float, float]: + x, y, a, s = keypoint + point = np.array([[x, y, 1]]) + x, y = (point @ matrix.T)[0] - # The height of the frame to contain the rotated image - r_rows = cols * np.sin(angle) + rows * np.cos(angle) + # To avoid problems with float errors + x = np.clip(x, 0, cols - 1) + y = np.clip(y, 0, rows - 1) - # The above calculations work as is for 0<90 degrees, and for 90<180 the cols and rows are flipped - if deg_angle > 90: - return int(r_cols), int(r_rows) - else: - return int(r_rows), int(r_cols) + a += angle + s *= max(scale_x, scale_y) + return x, y, a, s @clipped diff --git a/albumentations/augmentations/geometric/rotate.py b/albumentations/augmentations/geometric/rotate.py index 9e7fc0b76..4c3087b1b 100644 --- a/albumentations/augmentations/geometric/rotate.py +++ b/albumentations/augmentations/geometric/rotate.py @@ -1,4 +1,6 @@ +import math import random +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import cv2 import numpy as np @@ -135,13 +137,13 @@ class SafeRotate(DualTransform): def __init__( self, - limit=90, - interpolation=cv2.INTER_LINEAR, - border_mode=cv2.BORDER_REFLECT_101, - value=None, - mask_value=None, - always_apply=False, - p=0.5, + limit: Union[float, Tuple[float, float]] = 90, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + mask_value: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + always_apply: bool = False, + p: float = 0.5, ): super(SafeRotate, self).__init__(always_apply, p) self.limit = to_tuple(limit) @@ -150,24 +152,70 @@ def __init__( self.value = value self.mask_value = mask_value - def apply(self, img, angle=0, 
interpolation=cv2.INTER_LINEAR, **params): - return F.safe_rotate( - img=img, value=self.value, angle=angle, interpolation=interpolation, border_mode=self.border_mode - ) + def apply(self, img: np.ndarray, matrix: np.ndarray = None, **params) -> np.ndarray: + return F.safe_rotate(img, matrix, self.interpolation, self.value, self.border_mode) - def apply_to_mask(self, img, angle=0, **params): - return F.safe_rotate( - img=img, value=self.mask_value, angle=angle, interpolation=cv2.INTER_NEAREST, border_mode=self.border_mode - ) + def apply_to_mask(self, img: np.ndarray, matrix: np.ndarray = None, **params) -> np.ndarray: + return F.safe_rotate(img, matrix, cv2.INTER_NEAREST, self.mask_value, self.border_mode) - def get_params(self): - return {"angle": random.uniform(self.limit[0], self.limit[1])} - - def apply_to_bbox(self, bbox, angle=0, **params): - return F.bbox_safe_rotate(bbox=bbox, angle=angle, rows=params["rows"], cols=params["cols"]) + def apply_to_bbox( + self, bbox: Tuple[float, float, float, float], cols: int = 0, rows: int = 0, **params + ) -> Tuple[float, float, float, float]: + return F.bbox_safe_rotate(bbox, params["matrix"], cols, rows) - def apply_to_keypoint(self, keypoint, angle=0, **params): - return F.keypoint_safe_rotate(keypoint, angle=angle, rows=params["rows"], cols=params["cols"]) - - def get_transform_init_args_names(self): + def apply_to_keypoint( + self, + keypoint: Tuple[float, float, float, float], + angle: float = 0, + scale_x: float = 0, + scale_y: float = 0, + cols: int = 0, + rows: int = 0, + **params + ) -> Tuple[float, float, float, float]: + return F.keypoint_safe_rotate(keypoint, params["matrix"], angle, scale_x, scale_y, cols, rows) + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + angle = random.uniform(self.limit[0], self.limit[1]) + + image = params["image"] + h, w = image.shape[:2] + + # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides + image_center = (w / 2, h / 2) + + # Rotation Matrix + rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + + # rotation calculates the cos and sin, taking absolutes of those. 
+ abs_cos = abs(rotation_mat[0, 0]) + abs_sin = abs(rotation_mat[0, 1]) + + # find the new width and height bounds + new_w = math.ceil(h * abs_sin + w * abs_cos) + new_h = math.ceil(h * abs_cos + w * abs_sin) + + scale_x = w / new_w + scale_y = h / new_h + + # Shift the image to create padding + rotation_mat[0, 2] += new_w / 2 - image_center[0] + rotation_mat[1, 2] += new_h / 2 - image_center[1] + + # Rescale to original size + scale_mat = np.diag(np.ones(3)) + scale_mat[0, 0] *= scale_x + scale_mat[1, 1] *= scale_y + _tmp = np.diag(np.ones(3)) + _tmp[:2] = rotation_mat + _tmp = scale_mat @ _tmp + rotation_mat = _tmp[:2] + + return {"matrix": rotation_mat, "angle": angle, "scale_x": scale_x, "scale_y": scale_y} + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str]: return ("limit", "interpolation", "border_mode", "value", "mask_value") diff --git a/albumentations/augmentations/keypoints_utils.py b/albumentations/augmentations/keypoints_utils.py index 9250d70c7..ccb6b6942 100644 --- a/albumentations/augmentations/keypoints_utils.py +++ b/albumentations/augmentations/keypoints_utils.py @@ -160,11 +160,11 @@ def convert_keypoint_from_albumentations( # type (tuple, str, int, int, bool, bool) -> tuple if target_format not in keypoint_formats: raise ValueError("Unknown target_format {}. Supported formats are: {}".format(target_format, keypoint_formats)) - if check_validity: - check_keypoint(keypoint, rows, cols) (x, y, angle, scale), tail = keypoint[:4], tuple(keypoint[4:]) angle = angle_to_2pi_range(angle) + if check_validity: + check_keypoint((x, y, angle, scale), rows, cols) if angle_in_degrees: angle = math.degrees(angle) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index ec09549f1..78038c704 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -26,20 +26,6 @@ def test_transpose_both_image_and_mask(): assert augmented["mask"].shape == (6, 8) -@pytest.mark.parametrize("interpolation", [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC]) -def test_safe_rotate_interpolation(interpolation): - image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8) - mask = np.random.randint(low=0, high=2, size=(100, 100), dtype=np.uint8) - aug = A.SafeRotate(limit=(45, 45), interpolation=interpolation, p=1) - data = aug(image=image, mask=mask) - expected_image = FGeometric.safe_rotate(image, 45, interpolation=interpolation, border_mode=cv2.BORDER_REFLECT_101) - expected_mask = FGeometric.safe_rotate( - mask, 45, interpolation=cv2.INTER_NEAREST, border_mode=cv2.BORDER_REFLECT_101 - ) - assert np.array_equal(data["image"], expected_image) - assert np.array_equal(data["mask"], expected_mask) - - @pytest.mark.parametrize("interpolation", [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC]) def test_rotate_interpolation(interpolation): image = np.random.randint(low=0, high=256, size=(100, 100, 3), dtype=np.uint8) @@ -1012,3 +998,86 @@ def test_advanced_blur_float_uint8_diff_less_than_two(val_uint8): def test_advanced_blur_raises_on_incorrect_params(params): with pytest.raises(ValueError): A.AdvancedBlur(**params) + + +@pytest.mark.parametrize( + ["angle", "targets", "expected"], + [ + [ + -10, + { + "bboxes": [ + [0, 0, 5, 5, 0], + [195, 0, 200, 5, 0], + [195, 95, 200, 100, 0], + [0, 95, 5, 99, 0], + ], + "keypoints": [ + [0, 0, 0, 0], + [199, 0, 10, 10], + [199, 99, 20, 20], + [0, 99, 30, 30], + ], + }, + { + "bboxes": [ + [15.65896994771262, 0.2946228229078849, 21.047137067150473, 4.617219579173327, 0], + 
[194.29851584295034, 25.564320319214918, 199.68668296238818, 29.88691707548036, 0], + [178.9528629328495, 95.38278042082668, 184.34103005228735, 99.70537717709212, 0], + [0.47485022613917677, 70.11308292451965, 5.701484157049652, 73.70074852182076, 0], + ], + "keypoints": [ + [16.466635890349504, 0.2946228229078849, 147.04220486917677, 0.0], + [198.770582727028, 26.08267308836993, 157.04220486917674, 9.30232558139535], + [182.77879706281766, 98.84085782583904, 167.04220486917674, 18.6046511627907], + [0.4748502261391767, 73.05280756037699, 177.04220486917674, 27.90697674418604], + ], + }, + ], + [ + 10, + { + "bboxes": [ + [0, 0, 5, 5, 0], + [195, 0, 200, 5, 0], + [195, 95, 200, 100, 0], + [0, 95, 5, 99, 0], + ], + "keypoints": [ + [0, 0, 0, 0], + [199, 0, 10, 10], + [199, 99, 20, 20], + [0, 99, 30, 30], + ], + }, + { + "bboxes": [ + [0.3133170376117963, 25.564320319214918, 5.701484157049649, 29.88691707548036, 0], + [178.9528629328495, 0.2946228229078862, 184.34103005228735, 4.617219579173327, 0], + [194.29851584295034, 70.11308292451965, 199.68668296238818, 74.43567968078509, 0], + [15.658969947712617, 95.38278042082668, 20.88560387862309, 98.97044601812779, 0], + ], + "keypoints": [ + [0.3133170376117963, 26.212261280658684, 212.95779513082323, 0.0], + [182.6172638742903, 0.42421101519664006, 222.95779513082323, 9.30232558139535], + [198.60904953850064, 73.18239575266574, 232.9577951308232, 18.6046511627907], + [16.305102701822126, 98.97044601812779, 242.9577951308232, 27.906976744186046], + ], + }, + ], + ], +) +def test_safe_rotate(angle: float, targets: dict, expected: dict): + image = np.empty([100, 200, 3], dtype=np.uint8) + t = A.Compose( + [ + A.SafeRotate(limit=(angle, angle), border_mode=0, value=0, p=1), + ], + bbox_params=A.BboxParams(format="pascal_voc", min_visibility=0.0), + keypoint_params=A.KeypointParams("xyas"), + p=1, + ) + res = t(image=image, **targets) + + for key, value in expected.items(): + assert np.allclose(np.array(value), np.array(res[key])), key
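
Note (not part of the patch above): a minimal standalone sketch, under stated assumptions, of the technique this PR switches to. Instead of rotating into an enlarged canvas and resizing back, SafeRotate now composes a single affine matrix (rotation about the image center, a shift that keeps the rotated content inside the frame, and a rescale back to the original size) and reuses that one matrix for the image/mask, the bounding-box corners, and the keypoints. The helper name safe_rotate_matrix and the example values below are illustrative only and are not part of the albumentations API.

import math
from typing import Tuple

import cv2
import numpy as np


def safe_rotate_matrix(h: int, w: int, angle: float) -> Tuple[np.ndarray, float, float]:
    # Rotation about the image center (mirrors get_params_dependent_on_targets).
    center = (w / 2, h / 2)
    rot = cv2.getRotationMatrix2D(center, angle, 1.0)

    # Size of the rotated-but-uncropped image.
    abs_cos, abs_sin = abs(rot[0, 0]), abs(rot[0, 1])
    new_w = math.ceil(h * abs_sin + w * abs_cos)
    new_h = math.ceil(h * abs_cos + w * abs_sin)

    # Shift so the rotated content lies fully inside the enlarged canvas,
    # then compose a rescale back to the original (w, h).
    rot[0, 2] += new_w / 2 - center[0]
    rot[1, 2] += new_h / 2 - center[1]
    scale = np.diag([w / new_w, h / new_h, 1.0])
    matrix = (scale @ np.vstack([rot, [0.0, 0.0, 1.0]]))[:2]
    return matrix, w / new_w, h / new_h


if __name__ == "__main__":
    h, w = 100, 200
    angle_deg = 10.0
    image = np.zeros((h, w, 3), dtype=np.uint8)
    matrix, scale_x, scale_y = safe_rotate_matrix(h, w, angle_deg)

    # Image / mask path: a single warpAffine with the combined matrix (as in F.safe_rotate).
    rotated = cv2.warpAffine(image, matrix, dsize=(w, h), borderMode=cv2.BORDER_CONSTANT)

    # Bbox path: transform the four corners and take the new min/max (as in bbox_safe_rotate).
    x1, y1, x2, y2 = 0.0, 0.0, 5.0, 5.0
    corners = np.array([[x1, y1, 1], [x2, y1, 1], [x2, y2, 1], [x1, y2, 1]])
    warped = corners @ matrix.T
    print(warped[:, 0].min(), warped[:, 1].min(), warped[:, 0].max(), warped[:, 1].max())

    # Keypoint path: transform the point, clip to the frame to absorb float error,
    # then update angle and scale (as in keypoint_safe_rotate).
    x, y, a, s = 199.0, 0.0, 0.0, 1.0
    px, py = (np.array([[x, y, 1]]) @ matrix.T)[0]
    px, py = np.clip(px, 0, w - 1), np.clip(py, 0, h - 1)
    a += angle_deg            # angle update; units follow whatever the surrounding pipeline tracks
    s *= max(scale_x, scale_y)
    print(px, py, a, s)

The key design point is that the bbox and keypoint transforms use exactly the same matrix as the image warp, so the targets can no longer drift out of sync with the pixels the way the old enlarge-then-resize code allowed.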