Skip to content

Commit

Permalink
Changed downscale interpolation to avoid aliasing (#584)
Browse files Browse the repository at this point in the history
* Changed downscale interpolation to avoid aliasing

* add 2 interpolation arguments

* Add interpolation params as tuple

* Save interpolation into single param

* Go back to use random instead of np.random for sampling

* Fix downscale tests

* Interpolation class for Downscale

* Interpolation class for Downscale

* Remove abc from interface

* Install types for mypy

* Install types-dataclasses

* Remove dataclasses

Co-authored-by: Nathan Hubens <nathan@Nounou.local>
Co-authored-by: Dipet <dipetm@gmail.com>
  • Loading branch information
3 people committed Aug 18, 2022
1 parent a271a09 commit e96530e
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 21 deletions.
10 changes: 6 additions & 4 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1083,14 +1083,16 @@ def to_gray(img):


@preserve_shape
def downscale(img, scale, interpolation=cv2.INTER_NEAREST):
def downscale(img, scale, down_interpolation=cv2.INTER_AREA, up_interpolation=cv2.INTER_LINEAR):
h, w = img.shape[:2]

need_cast = interpolation != cv2.INTER_NEAREST and img.dtype == np.uint8
need_cast = (
up_interpolation != cv2.INTER_NEAREST or down_interpolation != cv2.INTER_NEAREST
) and img.dtype == np.uint8
if need_cast:
img = to_float(img)
downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=interpolation)
upscaled = cv2.resize(downscaled, (w, h), interpolation=interpolation)
downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=down_interpolation)
upscaled = cv2.resize(downscaled, (w, h), interpolation=up_interpolation)
if need_cast:
upscaled = from_float(np.clip(upscaled, 0, 1), dtype=np.dtype("uint8"))
return upscaled
Expand Down
67 changes: 51 additions & 16 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DualTransform,
ImageOnlyTransform,
NoOp,
ScaleFloatType,
to_tuple,
)
from ..core.utils import format_args
Expand Down Expand Up @@ -1701,7 +1702,11 @@ class Downscale(ImageOnlyTransform):
Args:
scale_min (float): lower bound on the image scale. Should be < 1.
scale_max (float): lower bound on the image scale. Should be .
interpolation: cv2 interpolation method. cv2.INTER_NEAREST by default
interpolation: cv2 interpolation method. Could be:
- single cv2 interpolation flag - selected method will be used for downscale and upscale.
- dict(downscale=flag, upscale=flag)
- Downscale.Interpolation(downscale=flag, upscale=flag) -
Default: Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)
Targets:
image
Expand All @@ -1710,34 +1715,64 @@ class Downscale(ImageOnlyTransform):
uint8, float32
"""

class Interpolation:
def __init__(self, *, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_NEAREST):
self.downscale = downscale
self.upscale = upscale

def __init__(
self,
scale_min=0.25,
scale_max=0.25,
interpolation=cv2.INTER_NEAREST,
always_apply=False,
p=0.5,
scale_min: float = 0.25,
scale_max: float = 0.25,
interpolation: Optional[Union[int, Interpolation, Dict[str, int]]] = None,
always_apply: bool = False,
p: float = 0.5,
):
super(Downscale, self).__init__(always_apply, p)
if interpolation is None:
self.interpolation = self.Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)
warnings.warn(
"Using default interpolation INTER_NEAREST, which is sub-optimal."
"Please specify interpolation mode for downscale and upscale explicitly."
"For additional information see this PR https://github.com/albumentations-team/albumentations/pull/584"
)
elif isinstance(interpolation, int):
self.interpolation = self.Interpolation(downscale=interpolation, upscale=interpolation)
elif isinstance(interpolation, self.Interpolation):
self.interpolation = interpolation
elif isinstance(interpolation, dict):
self.interpolation = self.Interpolation(**interpolation)
else:
raise ValueError(
"Wrong interpolation data type. Supported types: `Optional[Union[int, Interpolation, Dict[str, int]]]`."
f" Got: {type(interpolation)}"
)

if scale_min > scale_max:
raise ValueError("Expected scale_min be less or equal scale_max, got {} {}".format(scale_min, scale_max))
if scale_max >= 1:
raise ValueError("Expected scale_max to be less than 1, got {}".format(scale_max))
self.scale_min = scale_min
self.scale_max = scale_max
self.interpolation = interpolation

def apply(self, image, scale, interpolation, **params):
return F.downscale(image, scale=scale, interpolation=interpolation)
def apply(self, img: np.ndarray, scale: Optional[float] = None, **params) -> np.ndarray:
return F.downscale(
img,
scale=scale,
down_interpolation=self.interpolation.downscale,
up_interpolation=self.interpolation.upscale,
)

def get_params(self):
return {
"scale": random.uniform(self.scale_min, self.scale_max),
"interpolation": self.interpolation,
}
def get_params(self) -> Dict[str, Any]:
return {"scale": random.uniform(self.scale_min, self.scale_max)}

def get_transform_init_args_names(self):
return "scale_min", "scale_max", "interpolation"
def get_transform_init_args_names(self) -> Tuple[str, str]:
return "scale_min", "scale_max"

def _to_dict(self) -> Dict[str, Any]:
result = super()._to_dict()
result["interpolation"] = {"upscale": self.interpolation.upscale, "downscale": self.interpolation.downscale}
return result


class Lambda(NoOp):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ def test_downscale(interpolation):

for img in (img_float, img_uint):
transformed = aug(image=img)["image"]
func_applied = F.downscale(img, scale=0.5, interpolation=interpolation)
func_applied = F.downscale(img, scale=0.5, down_interpolation=interpolation, up_interpolation=interpolation)
np.testing.assert_almost_equal(transformed, func_applied)


Expand Down

0 comments on commit e96530e

Please sign in to comment.