Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
ColorNormalizeAug now accepts mxnet NDArrays as well as stated (wrong…
Browse files Browse the repository at this point in the history
…ly) in the documentation. (#11606)

* ColorNormalizeAug now accepts mxnet NDArrays as well.

According to the doc of mx.image.ColorNormalizeAug mean and std must be mxnet NDArrays though in reality it accepts only numpy arrays. This commit allows ColorNormalizeAug to accept both numpy and mxnet ndarrays.

Extended the image unit tests to cover this functionality.

* Fixed missing brackets

* Simplified assert call according to pylint

* Fixed linting issue

* Fixed namespace issue

* Fixed std and mean assignment

* Fixed isinstance python3 mode

* Added with_seed annotation
  • Loading branch information
ifeherva authored and zhreshold committed Jul 9, 2018
1 parent 0d5ebe1 commit 459a891
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,4 @@ List of Contributors
* [Jesse Brizzi](https://github.com/jessebrizzi)
* [Hang Zhang](http://hangzh.com)
* [Kou Ding](https://github.com/chinakook)
* [Istvan Fehervari](https://github.com/ifeherva)
12 changes: 6 additions & 6 deletions python/mxnet/image/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -839,8 +839,8 @@ class ColorNormalizeAug(Augmenter):
"""
def __init__(self, mean, std):
super(ColorNormalizeAug, self).__init__(mean=mean, std=std)
self.mean = nd.array(mean) if mean is not None else None
self.std = nd.array(std) if std is not None else None
self.mean = mean if mean is None or isinstance(mean, nd.NDArray) else nd.array(mean)
self.std = std if std is None or isinstance(std, nd.NDArray) else nd.array(std)

def __call__(self, src):
"""Augmenter body"""
Expand Down Expand Up @@ -999,14 +999,14 @@ def CreateAugmenter(data_shape, resize=0, rand_crop=False, rand_resize=False, ra
auglist.append(RandomGrayAug(rand_gray))

if mean is True:
mean = np.array([123.68, 116.28, 103.53])
mean = nd.array([123.68, 116.28, 103.53])
elif mean is not None:
assert isinstance(mean, np.ndarray) and mean.shape[0] in [1, 3]
assert isinstance(mean, (np.ndarray, nd.NDArray)) and mean.shape[0] in [1, 3]

if std is True:
std = np.array([58.395, 57.12, 57.375])
std = nd.array([58.395, 57.12, 57.375])
elif std is not None:
assert isinstance(std, np.ndarray) and std.shape[0] in [1, 3]
assert isinstance(std, (np.ndarray, nd.NDArray)) and std.shape[0] in [1, 3]

if mean is not None or std is not None:
auglist.append(ColorNormalizeAug(mean, std))
Expand Down
15 changes: 13 additions & 2 deletions tests/python/unittest/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import mxnet as mx
import numpy as np
from mxnet.test_utils import *
from common import assertRaises
from common import assertRaises, with_seed
import shutil
import tempfile
import unittest
Expand Down Expand Up @@ -153,8 +153,19 @@ def test_imageiter(self):
for batch in test_iter:
pass


@with_seed()
def test_augmenters(self):
# ColorNormalizeAug
mean = np.random.rand(3) * 255
std = np.random.rand(3) + 1
width = np.random.randint(100, 500)
height = np.random.randint(100, 500)
src = np.random.rand(height, width, 3) * 255.
# We test numpy and mxnet NDArray inputs
color_norm_aug = mx.image.ColorNormalizeAug(mean=mx.nd.array(mean), std=std)
out_image = color_norm_aug(mx.nd.array(src))
assert_almost_equal(out_image.asnumpy(), (src - mean) / std, atol=1e-3)

# only test if all augmenters will work
# TODO(Joshua Zhang): verify the augmenter outputs
im_list = [[0, x] for x in TestImage.IMAGES]
Expand Down

0 comments on commit 459a891

Please sign in to comment.