Skip to content

Commit

Permalink
Add Defocus and ZoomBlur transforms (#551)
Browse files Browse the repository at this point in the history
* Add Defocus and ZoomBlur transforms from 'Benchmarking Neural Network Robustness to Common Corruptions and Perturbations' (https://arxiv.org/abs/1903.12261)

* Add missing links to Defocus and ZoomBlur documentation

* Changes according to code review

* Remove redundant get_params_dependent_on_targets

* Add serialization tests + small fixes

* typing

* Fix __all__

Co-authored-by: Eugene Khvedchenya <ekhvedchenya@gmail.com>
Co-authored-by: Dipet <dipetm@gmail.com>
  • Loading branch information
3 people committed Aug 23, 2022
1 parent a0aebc2 commit a28dbb8
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 16 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [ChannelDropout](https://albumentations.ai/docs/api_reference/augmentations/dropout/channel_dropout/#albumentations.augmentations.dropout.channel_dropout.ChannelDropout)
- [ChannelShuffle](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ChannelShuffle)
- [ColorJitter](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ColorJitter)
- [Defocus](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Defocus)
- [Downscale](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Downscale)
- [Emboss](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Emboss)
- [Equalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Equalize)
Expand Down Expand Up @@ -167,6 +168,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [ToGray](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToGray)
- [ToSepia](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ToSepia)
- [UnsharpMask](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.UnsharpMask)
- [ZoomBlur](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.ZoomBlur)

### Spatial-level transforms
Spatial-level transforms will simultaneously change both an input image as well as additional targets such as masks, bounding boxes, and keypoints. The following table shows which additional targets are supported by each transform.
Expand Down
52 changes: 52 additions & 0 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from functools import wraps
from itertools import product
from math import ceil
from typing import Callable, Optional, Sequence, Union
from warnings import warn

Expand Down Expand Up @@ -68,6 +69,8 @@
"to_float",
"to_gray",
"unsharp_mask",
"scale",
"resize",
]

MAX_VALUES_BY_DTYPE = {
Expand Down Expand Up @@ -1565,3 +1568,52 @@ def spatter(
raise ValueError("Unsupported spatter mode: " + str(mode))

return img * 255


def defocus(img: np.ndarray, radius: int, alias_blur: float) -> np.ndarray:
L = np.arange(-max(8, radius), max(8, radius) + 1)
ksize = 3 if radius <= 8 else 5

X, Y = np.meshgrid(L, L)
aliased_disk = np.array((X**2 + Y**2) <= radius**2, dtype=np.float32)
aliased_disk /= np.sum(aliased_disk)

kernel = gaussian_blur(aliased_disk, ksize, sigma=alias_blur)
return convolve(img, kernel=kernel)


@preserve_channel_dim
def resize(img: np.ndarray, height: int, width: int, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
img_height, img_width = img.shape[:2]
if height == img_height and width == img_width:
return img
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
return resize_fn(img)


@preserve_channel_dim
def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
height, width = img.shape[:2]
new_height, new_width = int(height * scale), int(width * scale)
return resize(img, new_height, new_width, interpolation)


def central_zoom(img: np.ndarray, zoom_factor: int) -> np.ndarray:
h, w = img.shape[:2]
h_ch, w_ch = ceil(h / zoom_factor), ceil(w / zoom_factor)
h_top, w_top = (h - h_ch) // 2, (w - w_ch) // 2

img = scale(img[h_top : h_top + h_ch, w_top : w_top + w_ch], zoom_factor, cv2.INTER_LINEAR)
h_trim_top, w_trim_top = (img.shape[0] - h) // 2, (img.shape[1] - w) // 2
return img[h_trim_top : h_trim_top + h, w_trim_top : w_trim_top + w]


@clipped
def zoom_blur(img: np.ndarray, zoom_factors: Sequence[int]) -> np.ndarray:
out = np.zeros_like(img, dtype=np.float32)
for zoom_factor in zoom_factors:
out += central_zoom(img, zoom_factor)

img = ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype)

return img
18 changes: 2 additions & 16 deletions albumentations/augmentations/geometric/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
clipped,
preserve_channel_dim,
preserve_shape,
resize,
scale,
)

__all__ = [
Expand Down Expand Up @@ -377,22 +379,6 @@ def elastic_transform(
return remap_fn(img)


@preserve_channel_dim
def resize(img, height, width, interpolation=cv2.INTER_LINEAR):
img_height, img_width = img.shape[:2]
if height == img_height and width == img_width:
return img
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
return resize_fn(img)


@preserve_channel_dim
def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
height, width = img.shape[:2]
new_height, new_width = int(height * scale), int(width * scale)
return resize(img, new_height, new_width, interpolation)


def keypoint_scale(keypoint: KeypointType, scale_x: float, scale_y: float) -> KeypointType:
"""Scales a keypoint by scale_x and scale_y.
Expand Down
99 changes: 99 additions & 0 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
ImageOnlyTransform,
NoOp,
ScaleFloatType,
ScaleIntType,
to_tuple,
)
from ..core.utils import format_args
Expand Down Expand Up @@ -73,6 +74,8 @@
"AdvancedBlur",
"PixelDropout",
"Spatter",
"Defocus",
"ZoomBlur",
]


Expand Down Expand Up @@ -2855,3 +2858,99 @@ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, A

def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]:
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode"


class Defocus(ImageOnlyTransform):
"""
Apply defocus transform. See https://arxiv.org/abs/1903.12261.
Args:
radius ((int, int) or int): range for radius of defocusing.
If limit is a single int, the range will be [1, limit]. Default: (3, 10).
alias_blur ((float, float) or float): range for alias_blur of defocusing (sigma of gaussian blur).
If limit is a single float, the range will be (0, limit). Default: (0.1, 0.5).
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
Any
"""

def __init__(
self,
radius: ScaleIntType = (3, 10),
alias_blur: ScaleFloatType = (0.1, 0.5),
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.radius = to_tuple(radius, low=1)
self.alias_blur = to_tuple(alias_blur, low=0)

if self.radius[0] <= 0:
raise ValueError("Parameter radius must be positive")

if self.alias_blur[0] < 0:
raise ValueError("Parameter alias_blur must be non-negative")

def apply(self, img: np.ndarray, radius: int = 3, alias_blur: float = 0.5, **params) -> np.ndarray:
return F.defocus(img, radius, alias_blur)

def get_params(self) -> Dict[str, Any]:
return {
"radius": random_utils.randint(self.radius[0], self.radius[1] + 1),
"alias_blur": random_utils.uniform(self.alias_blur[0], self.alias_blur[1]),
}

def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("radius", "alias_blur")


class ZoomBlur(ImageOnlyTransform):
"""
Apply zoom blur transform. See https://arxiv.org/abs/1903.12261.
Args:
max_factor ((float, float) or float): range for max factor for blurring.
If max_factor is a single float, the range will be (1, limit). Default: (1, 1.31).
All max_factor values should be larger than 1.
step_factor ((float, float) or float): If single float will be used as step parameter for np.arange.
If tuple of float step_factor will be in range `[step_factor[0], step_factor[1])`. Default: (0.01, 0.03).
All step_factor values should be positive.
p (float): probability of applying the transform. Default: 0.5.
Targets:
image
Image types:
Any
"""

def __init__(
self,
max_factor: ScaleFloatType = 1.31,
step_factor: ScaleFloatType = (0.01, 0.03),
always_apply: bool = False,
p: float = 0.5,
):
super().__init__(always_apply, p)
self.max_factor = to_tuple(max_factor, low=1.0)
self.step_factor = to_tuple(step_factor, step_factor)

if self.max_factor[0] < 1:
raise ValueError("Max factor must be larger or equal 1")
if self.step_factor[0] <= 0:
raise ValueError("Step factor must be positive")

def apply(self, img: np.ndarray, zoom_factors: np.ndarray = None, **params) -> np.ndarray:
return F.zoom_blur(img, zoom_factors)

def get_params(self) -> Dict[str, Any]:
max_factor = random.uniform(self.max_factor[0], self.max_factor[1])
step_factor = random.uniform(self.step_factor[0], self.step_factor[1])
return {"zoom_factors": np.arange(1.0, max_factor, step_factor)}

def get_transform_init_args_names(self) -> Tuple[str, str]:
return ("max_factor", "step_factor")
2 changes: 2 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,8 @@ def test_augmentations_serialization(augmentation_cls, params, p, seed, image, m
[A.PixelDropout, {"dropout_prob": 0.1, "per_channel": False, "drop_value": None, "mask_drop_value": 15}],
[A.RandomCropFromBorders, dict(crop_left=0.2, crop_right=0.3, crop_top=0.05, crop_bottom=0.5)],
[A.Spatter, dict(mean=0.2, std=0.1, gauss_sigma=3, cutout_threshold=0.4, intensity=0.7, mode="mud")],
[A.Defocus, {"radius": (5, 7), "alias_blur": (0.2, 0.6)}],
[A.ZoomBlur, {"max_factor": (1.56, 1.7), "step_factor": (0.02, 0.04)}],
]

AUGMENTATION_CLS_EXCEPT = {
Expand Down

0 comments on commit a28dbb8

Please sign in to comment.