Skip to content

Commit

Permalink
Merge pull request #553 from aleju/randaugment
Browse files Browse the repository at this point in the history
Add RandAugment
  • Loading branch information
aleju committed Jan 11, 2020
2 parents 8d39175 + af5d247 commit 00789f1
Show file tree
Hide file tree
Showing 8 changed files with 455 additions and 4 deletions.
4 changes: 4 additions & 0 deletions changelogs/master/added/20200105_discretize_round.md
@@ -0,0 +1,4 @@
# Added `round` Parameter to `Discretize` #553

Added the parameter `round` to `imgaug.parameters.Discretize`. The parameter
defaults to `True`, i.e. the default behaviour of `Discretize` did not change.
8 changes: 8 additions & 0 deletions changelogs/master/added/20200106_randaugment.md
@@ -0,0 +1,8 @@
# Add RandAugment #553

Added a RandAugment augmenter, similar to the one described in the paper
"RandAugment: Practical automated data augmentation with a reduced
search space".

* Added module `imgaug.augmenters.collections`
* Added augmenter `imgaug.augmenters.collections.RandAugment`.
34 changes: 34 additions & 0 deletions checks/check_randaugment.py
@@ -0,0 +1,34 @@
from __future__ import print_function, division, absolute_import

import numpy as np

import imgaug as ia
import imgaug.augmenters as iaa


def main():
image = ia.quokka(0.25)

for N in [1, 2]:
print("N=%d" % (N,))

images_aug = []
for M in [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20]:
images_aug.extend(
iaa.RandAugment(n=N, m=M, random_state=1)(images=[image] * 10)
)
ia.imshow(ia.draw_grid(images_aug, cols=10))

for M in [0, 1, 2, 4, 8, 10]:
print("M=%d" % (M,))
aug = iaa.RandAugment(m=M, random_state=1)

images_aug = []
for _ in np.arange(6):
images_aug.extend(aug(images=[image] * 16))

ia.imshow(ia.draw_grid(images_aug, cols=16, rows=6))


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions imgaug/augmenters/__init__.py
Expand Up @@ -5,6 +5,7 @@
from imgaug.augmenters.artistic import *
from imgaug.augmenters.blend import *
from imgaug.augmenters.blur import *
from imgaug.augmenters.collections import *
from imgaug.augmenters.color import *
from imgaug.augmenters.contrast import *
from imgaug.augmenters.convolutional import *
Expand Down
288 changes: 288 additions & 0 deletions imgaug/augmenters/collections.py
@@ -0,0 +1,288 @@
"""Augmenters that are collections of other augmenters.
List of augmenters:
* :class:`RandAugment`
"""
from __future__ import print_function, division, absolute_import

import numpy as np

from .. import parameters as iap
from .. import random as iarandom
from . import meta
from . import arithmetic
from . import flip
from . import pillike
from . import size as sizelib


class RandAugment(meta.Sequential):
"""Apply RandAugment to inputs as described in the corresponding paper.
See paper::
Cubuk et al.
RandAugment: Practical automated data augmentation with a reduced
search space
.. note::
The paper contains essentially no hyperparameters for the individual
augmentation techniques. The hyperparameters used here come mostly
from the official code repository, which however seems to only contain
code for CIFAR10 and SVHN, not for ImageNet. So some guesswork was
involved and a few of the hyperparameters were also taken from
https://github.com/ildoonet/pytorch-randaugment/blob/master/RandAugment/augmentations.py .
This implementation deviates from the code repository for all PIL
enhance operations. In the repository these use a factor of
``0.1 + M*1.8/M_max``, which would lead to a factor of ``0.1`` for the
weakest ``M`` of ``M=0``. For e.g. ``Brightness`` that would result in
a basically black image. This definition is fine for AutoAugment (from
where the code and hyperparameters are copied), which optimizes
each transformation's ``M`` individually, but not for RandAugment,
which uses a single fixed ``M``. We hence redefine these
hyperparameters to ``1.0 + S * M * 0.9/M_max``, where ``S`` is
randomly either ``1`` or ``-1``.
We also note that it is not entirely clear which transformations
were used in the ImageNet experiments. The paper lists some
transformations in Figure 2, but names others in the text too (e.g.
crops, flips, cutout). While Figure 2 lists the Identity function,
this transformation seems to not appear in the repository (and in fact,
the function ``randaugment(N, M)`` doesn't seem to exist in the
repository either). So we also make a best guess here about what
transformations might have been used.
.. warning::
This augmenter only works with image data, not e.g. bounding boxes.
The used PIL-based affine transformations are not yet able to
process non-image data. (This augmenter uses PIL-based affine
transformations to ensure that outputs are as similar as possible
to the paper's implementation.)
Parameters
----------
n : int or tuple of int or list of int or imgaug.parameters.StochasticParameter or None, optional
Parameter ``N`` in the paper, i.e. number of transformations to apply.
The paper suggests ``N=2`` for ImageNet.
See also parameter ``n`` in :class:`imgaug.augmenters.meta.SomeOf`
for more details.
Note that horizontal flips (p=50%) and crops are always applied. This
parameter only determines how many of the other transformations
are applied per image.
m : int or tuple of int or list of int or imgaug.parameters.StochasticParameter or None, optional
Parameter ``M`` in the paper, i.e. magnitude/severity/strength of the
applied transformations in interval ``[0 .. 30]`` with ``M=0`` being
the weakest. The paper suggests for ImageNet ``M=9`` in case of
ResNet-50 and ``M=28`` in case of EfficientNet-B7.
This implementation uses a default value of ``(6, 12)``, i.e. the
value is uniformly sampled per image from the interval ``[6 .. 12]``.
This ensures greater diversity of transformations than using a single
fixed value.
* If ``int``: That value will always be used.
* If ``tuple`` ``(a, b)``: A random value will be uniformly sampled per
image from the discrete interval ``[a .. b]``.
* If ``list``: A random value will be picked from the list per image.
* If ``StochasticParameter``: For ``B`` images in a batch, ``B`` values
will be sampled per augmenter (provided the augmenter is dependent
on the magnitude).
cval : number or tuple of number or list of number or imgaug.ALL or imgaug.parameters.StochasticParameter, optional
The constant value to use when filling in newly created pixels.
See parameter `fillcolor` in
:class:`imgaug.augmenters.pillike.Affine` for details.
The paper's repository uses an RGB value of ``125, 122, 113``.
This implementation uses a single intensity value of ``128``, which
should work better for cases where input images don't have exactly
``3`` channels or come from a different dataset than used by the
paper.
Examples
--------
>>> import imgaug.augmenters as iaa
>>> aug = iaa.RandAugment(n=2, m=9)
Create a RandAugment augmenter similar to the suggested hyperparameters
in the paper.
>>> aug = iaa.RandAugment(m=30)
Create a RandAugment augmenter with maximum magnitude/strength.
>>> aug = iaa.RandAugment(m=(0, 9))
Create a RandAugment augmenter that applies its transformations with a
random magnitude between ``0`` (very weak) and ``9`` (recommended for
ImageNet and ResNet-50). ``m`` is sampled per transformation.
>>> aug = iaa.RandAugment(n=(0, 3))
Create a RandAugment augmenter that applies ``0`` to ``3`` of its
child transformations to images. Horizontal flips (p=50%) and crops are
always applied.
"""

_M_MAX = 30

# according to paper:
# N=2, M=9 is optimal for ImageNet with ResNet-50
# N=2, M=28 is optimal for ImageNet with EfficientNet-B7
# for cval they use [125, 122, 113]
def __init__(self, n=2, m=(6, 12), cval=128,
name=None, deterministic=False, random_state=None):
# pylint: disable=invalid-name
random_state = iarandom.RNG(random_state)

# we don't limit the value range to 10 here, because the paper
# gives several examples of using more than 10 for M
m = iap.handle_discrete_param(
m, "m", value_range=(0, None),
tuple_to_uniform=True, list_to_choice=True,
allow_floats=False)
self._m = m
self._cval = cval

# The paper says in Appendix A.2.3 "ImageNet", that they actually
# always execute Horizontal Flips and Crops first and only then a
# random selection of the other transformations.
# Hence, we split here into two groups.
# It's not really clear what crop parameters they use, so we
# choose [0..M] here.
initial_augs = self._create_initial_augmenters_list(m)
main_augs = self._create_main_augmenters_list(m, cval)

# assign random state to all child augmenters
for lst in [initial_augs, main_augs]:
for augmenter in lst:
augmenter.random_state = random_state

super(RandAugment, self).__init__(
[
meta.Sequential(initial_augs, random_state=random_state),
meta.SomeOf(n, main_augs, random_order=True,
random_state=random_state)
],
name=name, deterministic=deterministic, random_state=random_state
)

@classmethod
def _create_initial_augmenters_list(cls, m):
# pylint: disable=invalid-name
return [
flip.Fliplr(0.5),
sizelib.KeepSizeByResize(
# assuming that the paper implementation crops M pixels from
# 224px ImageNet images, we crop here a fraction of
# M*(M_max/224)
sizelib.Crop(
percent=iap.Divide(
iap.Uniform(0, m),
224,
elementwise=True),
sample_independently=True,
keep_size=False),
interpolation="linear"
)
]

@classmethod
def _create_main_augmenters_list(cls, m, cval):
# pylint: disable=invalid-name
m_max = cls._M_MAX

def _float_parameter(level, maxval):
maxval_norm = maxval / m_max
return iap.Multiply(level, maxval_norm, elementwise=True)

def _int_parameter(level, maxval):
# paper applies just int(), so we don't round here
return iap.Discretize(_float_parameter(level, maxval),
round=False)

# In the paper's code they use the definition from AutoAugment,
# which is 0.1 + M*1.8/10. But that results in 0.1 for M=0, i.e. for
# Brightness an almost black image, while M=5 would result in an
# unaltered image. For AutoAugment that may be fine, as M is optimized
# for each operation individually, but here we have only one fixed M
# for all operations. Hence, we rather set this to 1.0 +/- M*0.9/10,
# so that M=10 would result in 0.1 or 1.9.
def _enhance_parameter(level):
fparam = _float_parameter(level, 0.9)
return iap.Clip(
iap.Add(1.0, iap.RandomSign(fparam), elementwise=True),
0.1, 1.9
)

def _subtract(a, b):
return iap.Subtract(a, b, elementwise=True)

def _affine(*args, **kwargs):
kwargs["fillcolor"] = cval
if "center" not in kwargs:
kwargs["center"] = (0.0, 0.0)
return pillike.Affine(*args, **kwargs)

_rnd_s = iap.RandomSign
shear_max = np.rad2deg(0.3)

# we don't add vertical flips here, paper is not really clear about
# whether they used them or not
return [
meta.Identity(),
pillike.Autocontrast(cutoff=0),
pillike.Equalize(),
arithmetic.Invert(p=1.0),
# they use Image.rotate() for the rotation, which uses
# the image center as the rotation center
_affine(rotate=_rnd_s(_float_parameter(m, 30)),
center=(0.5, 0.5)),
# paper uses 4 - int_parameter(M, 4)
pillike.Posterize(
nb_bits=_subtract(
8,
iap.Clip(_int_parameter(m, 6), 0, 6)
)
),
# paper uses 256 - int_parameter(M, 256)
pillike.Solarize(
p=1.0,
threshold=iap.Clip(
_subtract(256, _int_parameter(m, 256)),
0, 256
)
),
pillike.EnhanceColor(_enhance_parameter(m)),
pillike.EnhanceContrast(_enhance_parameter(m)),
pillike.EnhanceBrightness(_enhance_parameter(m)),
pillike.EnhanceSharpness(_enhance_parameter(m)),
_affine(shear={"x": _rnd_s(_float_parameter(m, shear_max))}),
_affine(shear={"y": _rnd_s(_float_parameter(m, shear_max))}),
_affine(translate_percent={"x": _rnd_s(_float_parameter(m, 0.33))}),
_affine(translate_percent={"y": _rnd_s(_float_parameter(m, 0.33))}),
# paper code uses 20px on CIFAR (i.e. size 20/32), no information
# on ImageNet values so we just use the same values
arithmetic.Cutout(1,
size=iap.Clip(
_float_parameter(m, 20 / 32), 0, 20 / 32),
squared=True,
fill_mode="constant",
cval=cval),
pillike.FilterBlur(),
pillike.FilterSmooth()
]

def get_parameters(self):
"""See :func:`imgaug.augmenters.meta.Augmenter.get_parameters`."""
someof = self[1]
return [someof.n, self._m, self._cval]
13 changes: 10 additions & 3 deletions imgaug/parameters.py
Expand Up @@ -1737,6 +1737,9 @@ class Discretize(StochasticParameter):
other_param : imgaug.parameters.StochasticParameter
The other parameter, which's values are to be discretized.
round : bool, optional
Whether to round before converting to integer dtype.
Examples
--------
>>> import imgaug.parameters as iap
Expand All @@ -1745,10 +1748,12 @@ class Discretize(StochasticParameter):
Create a discrete standard gaussian distribution.
"""
def __init__(self, other_param):
def __init__(self, other_param, round=True):
# pylint: disable=redefined-builtin
super(Discretize, self).__init__()
_assert_arg_is_stoch_param("other_param", other_param)
self.other_param = other_param
self.round = round

def _draw_samples(self, size, random_state):
samples = self.other_param.draw_samples(size, random_state=random_state)
Expand All @@ -1768,14 +1773,16 @@ def _draw_samples(self, size, random_state):
# lower bound here -- shouldn't happen though
bitsize = max(bitsize, 8)
dtype = np.dtype("int%d" % (bitsize,))
return np.round(samples).astype(dtype)
if self.round:
samples = np.round(samples)
return samples.astype(dtype)

def __repr__(self):
return self.__str__()

def __str__(self):
opstr = str(self.other_param)
return "Discretize(%s)" % (opstr,)
return "Discretize(%s, round=%s)" % (opstr, str(self.round))


class Multiply(StochasticParameter):
Expand Down

0 comments on commit 00789f1

Please sign in to comment.