Skip to content

Commit

Permalink
downscale transform (#375)
Browse files Browse the repository at this point in the history
* downscale transform

* python2 fixes
  • Loading branch information
arsenyinfo authored and ternaus committed Sep 23, 2019
1 parent f38e153 commit df831d6
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
Expand Up @@ -115,6 +115,7 @@ Pixel-level transforms will change just an input image and will leave any additi
- [ChannelShuffle](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.ChannelShuffle)
- [CoarseDropout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.CoarseDropout)
- [Cutout](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Cutout)
- [Downscale](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Downscale)
- [Equalize](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.Equalize)
- [FromFloat](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.FromFloat)
- [GaussNoise](https://albumentations.readthedocs.io/en/latest/api/augmentations.html#albumentations.augmentations.transforms.GaussNoise)
Expand Down
14 changes: 14 additions & 0 deletions albumentations/augmentations/functional.py
Expand Up @@ -1315,6 +1315,20 @@ def to_gray(img):
return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)


@preserve_shape
def downscale(img, scale, interpolation=cv2.INTER_NEAREST):
h, w = img.shape[:2]

need_cast = interpolation != cv2.INTER_NEAREST and img.dtype == np.uint8
if need_cast:
img = to_float(img)
downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=interpolation)
upscaled = cv2.resize(downscaled, (w, h), interpolation=interpolation)
if need_cast:
upscaled = from_float(np.clip(upscaled, 0, 1), dtype=np.dtype("uint8"))
return upscaled


def to_float(img, max_value=None):
if max_value is None:
try:
Expand Down
42 changes: 39 additions & 3 deletions albumentations/augmentations/transforms.py
@@ -1,17 +1,17 @@
from __future__ import absolute_import, division

from types import LambdaType
import math
import random
import warnings
from enum import Enum
from types import LambdaType

import cv2
import numpy as np

from . import functional as F
from .bbox_utils import union_of_bboxes, denormalize_bbox, normalize_bbox
from ..core.transforms_interface import to_tuple, DualTransform, ImageOnlyTransform, NoOp
from .bbox_utils import denormalize_bbox, normalize_bbox, union_of_bboxes
from ..core.transforms_interface import DualTransform, ImageOnlyTransform, NoOp, to_tuple
from ..core.utils import format_args

__all__ = [
Expand Down Expand Up @@ -72,6 +72,7 @@
"Solarize",
"Equalize",
"Posterize",
"Downscale",
]


Expand Down Expand Up @@ -2753,6 +2754,41 @@ def get_transform_init_args(self):
return {"dtype": self.dtype.name, "max_value": self.max_value}


class Downscale(ImageOnlyTransform):
"""Decreases image quality by downscaling and upscaling back.
Args:
scale_min (float): lower bound on the image scale. Should be < 1.
scale_max (float): lower bound on the image scale. Should be .
interpolation: cv2 interpolation method. cv2.INTER_NEAREST by default
Targets:
image
Image types:
uint8, float32
"""

def __init__(self, scale_min=0.25, scale_max=0.25, interpolation=cv2.INTER_NEAREST, always_apply=False, p=0.5):
super(Downscale, self).__init__(always_apply, p)
assert scale_min <= scale_max, "Expected scale_min be less or equal scale_max, got {} {}".format(
scale_min, scale_max
)
assert scale_max < 1, "Expected scale_max to be less than 1, got {}".format(scale_max)
self.scale_min = scale_min
self.scale_max = scale_max
self.interpolation = interpolation

def apply(self, image, scale, interpolation, **params):
return F.downscale(image, scale=scale, interpolation=interpolation)

def get_params(self):
return {"scale": np.random.uniform(self.scale_min, self.scale_max), "interpolation": self.interpolation}

def get_transform_init_args_names(self):
return "scale_min", "scale_max", "interpolation"


class Lambda(NoOp):
"""A flexible transformation class for using user-defined transformation functions per targets.
Function signature must include **kwargs to accept optinal arguments like interpolation method, image size, etc:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_augmentations.py
Expand Up @@ -59,6 +59,7 @@
Equalize,
CropNonEmptyMaskIfExists,
LongestMaxSize,
Downscale,
)


Expand Down Expand Up @@ -92,6 +93,7 @@
[Solarize, {}],
[Posterize, {}],
[Equalize, {}],
[Downscale, {}],
],
)
def test_image_only_augmentations(augmentation_cls, params, image, mask):
Expand Down
19 changes: 17 additions & 2 deletions tests/test_functional.py
Expand Up @@ -2,11 +2,11 @@

import cv2
import numpy as np
from numpy.testing import assert_array_almost_equal_nulp
import pytest
from numpy.testing import assert_array_almost_equal_nulp

from albumentations.augmentations.bbox_utils import filter_bboxes
import albumentations.augmentations.functional as F
from albumentations.augmentations.bbox_utils import filter_bboxes
from .utils import convert_2d_to_target_format


Expand Down Expand Up @@ -879,6 +879,21 @@ def test_equalize_rgb_mask():
assert np.all(img_b == result_img[20:30, 20:30, 2])


@pytest.mark.parametrize("dtype", ["float32", "uint8"])
def test_downscale_ones(dtype):
img = np.ones((100, 100, 3), dtype=dtype)
downscaled = F.downscale(img, scale=0.5)
assert np.all(downscaled == img)


def test_downscale_random():
img = np.random.rand(100, 100, 3)
downscaled = F.downscale(img, scale=0.5)
assert downscaled.shape == img.shape
downscaled = F.downscale(img, scale=1)
assert np.all(img == downscaled)


def test_maybe_process_in_chunks():
image = np.random.randint(0, 256, (100, 100, 6), np.uint8)

Expand Down
1 change: 1 addition & 0 deletions tests/test_serialization.py
Expand Up @@ -59,6 +59,7 @@
[A.Solarize, {}],
[A.Posterize, {}],
[A.Equalize, {}],
[A.Downscale, {}],
],
)
@pytest.mark.parametrize("p", [0.5, 1])
Expand Down
15 changes: 14 additions & 1 deletion tests/test_transforms.py
@@ -1,8 +1,8 @@
from functools import partial
from multiprocessing.pool import Pool

import numpy as np
import cv2
import numpy as np
import pytest

import albumentations as A
Expand Down Expand Up @@ -398,3 +398,16 @@ def _test_crop(mask, crop, aug, n=1):
_test_crop(mask_4, crop_4, aug_4, n=5)
_test_crop(mask_5, crop_5, aug_5, n=1)
_test_crop(mask_6, crop_6, aug_6, n=10)


@pytest.mark.parametrize("interpolation", [cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC])
def test_downscale(interpolation):
img_float = np.random.rand(100, 100, 3)
img_uint = (img_float * 255).astype("uint8")

aug = A.Downscale(scale_min=0.5, scale_max=0.5, interpolation=interpolation, always_apply=True)

for img in (img_float, img_uint):
transformed = aug(image=img)["image"]
func_applied = F.downscale(img, scale=0.5, interpolation=interpolation)
np.testing.assert_almost_equal(transformed, func_applied)

0 comments on commit df831d6

Please sign in to comment.