Skip to content

Commit

Permalink
Add allow_shifted flag. (#1239)
Browse files Browse the repository at this point in the history
Co-authored-by: Mikhail Druzhinin <dipet@gmail.com>
Co-authored-by: Vladimir Iglovikov <ternaus@users.noreply.github.com>
  • Loading branch information
3 people committed Aug 2, 2022
1 parent 8e958a3 commit c820592
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
51 changes: 46 additions & 5 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,6 +1215,8 @@ class MotionBlur(Blur):
Args:
blur_limit (int): maximum kernel size for blurring the input image.
Should be in range [3, inf). Default: (3, 7).
allow_shifted (bool): if set to true creates non shifted kernels only,
otherwise creates randomly shifted kernels. Default: True.
p (float): probability of applying the transform. Default: 0.5.
Targets:
Expand All @@ -1224,6 +1226,22 @@ class MotionBlur(Blur):
uint8, float32
"""

def __init__(
self,
blur_limit: Union[int, Sequence[int]] = 7,
allow_shifted: bool = True,
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(blur_limit=blur_limit, always_apply=always_apply, p=p)
self.allow_shifted = allow_shifted

if not allow_shifted and self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1:
raise ValueError(f"Blur limit must be odd when centered=True. Got: {self.blur_limit}")

def get_transform_init_args_names(self):
return super().get_transform_init_args_names() + ("allow_shifted",)

def apply(self, img, kernel=None, **params):
return F.convolve(img, kernel=kernel)

Expand All @@ -1232,12 +1250,35 @@ def get_params(self):
if ksize <= 2:
raise ValueError("ksize must be > 2. Got: {}".format(ksize))
kernel = np.zeros((ksize, ksize), dtype=np.uint8)
xs, xe = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
if xs == xe:
ys, ye = random.sample(range(ksize), 2)
x1, x2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
if x1 == x2:
y1, y2 = random.sample(range(ksize), 2)
else:
ys, ye = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
cv2.line(kernel, (xs, ys), (xe, ye), 1, thickness=1)
y1, y2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)

def make_odd_val(v1, v2):
len_v = abs(v1 - v2) + 1
if len_v % 2 != 1:
if v2 > v1:
v2 -= 1
else:
v1 -= 1
return v1, v2

if not self.allow_shifted:
x1, x2 = make_odd_val(x1, x2)
y1, y2 = make_odd_val(y1, y2)

xc = (x1 + x2) / 2
yc = (y1 + y2) / 2

center = ksize / 2 - 0.5
dx = xc - center
dy = yc - center
x1, x2 = [int(i - dx) for i in [x1, x2]]
y1, y2 = [int(i - dy) for i in [y1, y2]]

cv2.line(kernel, (x1, y1), (x2, y2), 1, thickness=1)

# Normalize kernel
kernel = kernel.astype(np.float32) / np.sum(kernel)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1223,3 +1223,31 @@ def test_bbox_clipping_perspective():
bboxes = np.array([[0, 0, 100, 100, 1]])
res = transform(image=image, bboxes=bboxes)["bboxes"]
assert len(res) == 0


@pytest.mark.parametrize("seed", [i for i in range(10)])
def test_motion_blur_allow_shifted(seed):
random.seed(seed)

transform = A.MotionBlur(allow_shifted=False)
kernel = transform.get_params()["kernel"]

center = kernel.shape[0] / 2 - 0.5

def check_center(vector):
start = None
end = None

for i, v in enumerate(vector):
if start is None and v != 0:
start = i
elif start is not None and v == 0:
end = i
break
if end is None:
end = len(vector)

assert (end + start - 1) / 2 == center

check_center(kernel.sum(axis=0))
check_center(kernel.sum(axis=1))

0 comments on commit c820592

Please sign in to comment.