Skip to content

Commit

Permalink
PadIfNeeded: add random padding (#1160)
Browse files Browse the repository at this point in the history
* Add random padding to PadIfNeeded

* fix formatting issues

Co-authored-by: Martin Simonovsky <Martin.Simonovsky@HeidelbergEngineering.com>
Co-authored-by: Mikhail Druzhinin <dipetm@gmail.com>
  • Loading branch information
3 people committed Jun 11, 2022
1 parent 81c2fa0 commit d492fe3
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 2 deletions.
11 changes: 10 additions & 1 deletion albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ class PadIfNeeded(DualTransform):
pad_width_divisor (int): if not None, ensures image width is dividable by value of this argument.
position (Union[str, PositionType]): Position of the image. should be PositionType.CENTER or
PositionType.TOP_LEFT or PositionType.TOP_RIGHT or PositionType.BOTTOM_LEFT or PositionType.BOTTOM_RIGHT.
Default: PositionType.CENTER.
or PositionType.RANDOM. Default: PositionType.CENTER.
border_mode (OpenCV flag): OpenCV border mode.
value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
mask_value (int, float,
Expand All @@ -112,6 +112,7 @@ class PositionType(Enum):
TOP_RIGHT = "top_right"
BOTTOM_LEFT = "bottom_left"
BOTTOM_RIGHT = "bottom_right"
RANDOM = "random"

def __init__(
self,
Expand Down Expand Up @@ -259,6 +260,14 @@ def __update_position_params(
h_bottom = 0
w_right = 0

elif self.position == PadIfNeeded.PositionType.RANDOM:
h_pad = h_top + h_bottom
w_pad = w_left + w_right
h_top = random.randint(0, h_pad)
h_bottom = h_pad - h_top
w_left = random.randint(0, w_pad)
w_right = w_pad - w_left

return h_top, h_bottom, w_left, w_right


Expand Down
9 changes: 8 additions & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,9 +664,12 @@ def test_pad_if_needed(augmentation_cls: Type[A.PadIfNeeded], params: Dict, imag
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "top_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "bottom_left"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "bottom_right"}, (5, 6)],
[{"min_height": 10, "min_width": 12, "border_mode": 0, "value": 1, "position": "random"}, (5, 6)],
],
)
def test_pad_if_needed_position(params, image_shape):
random.seed(42)

image = np.zeros(image_shape)
pad = A.PadIfNeeded(**params)
image_padded = pad(image=image)["image"]
Expand All @@ -691,10 +694,14 @@ def test_pad_if_needed_position(params, image_shape):
true_result[-image_shape[0] :, : image_shape[1]] = 0
assert (image_padded == true_result).all()

if params["position"] == "bottom_right":
elif params["position"] == "bottom_right":
true_result[-image_shape[0] :, -image_shape[1] :] = 0
assert (image_padded == true_result).all()

elif params["position"] == "random":
true_result[0:5, -7:-1] = 0
assert (image_padded == true_result).all()


@pytest.mark.parametrize(
["points"],
Expand Down

0 comments on commit d492fe3

Please sign in to comment.