Skip to content

Commit

Permalink
Merge pull request #443 from aleju/fix_resize
Browse files Browse the repository at this point in the history
Fix resize breaking input dtypes #442
  • Loading branch information
aleju committed Sep 26, 2019
2 parents 847d00c + 9b37b66 commit 75a14ee
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 1 deletion.
3 changes: 3 additions & 0 deletions changelogs/master/fixes/20190926_fixed_resize_dtype.md
@@ -0,0 +1,3 @@
* Fixed `Resize` always returning an `uint8` array during image augmentation
if the input was a single numpy array and all augmented images had the
same shape. #442 #443
2 changes: 1 addition & 1 deletion imgaug/augmenters/size.py
Expand Up @@ -591,7 +591,7 @@ def _augment_images(self, images, random_state, parents, hooks):
if not isinstance(images, list):
all_same_size = (len(set([image.shape for image in result])) == 1)
if all_same_size:
result = np.array(result, dtype=np.uint8)
result = np.array(result, dtype=images.dtype)

return result

Expand Down
41 changes: 41 additions & 0 deletions test/augmenters/test_size.py
Expand Up @@ -707,6 +707,47 @@ def test_get_parameters(self):
assert params[0].value == 1
assert params[1].value == "nearest"

def test_dtypes_roughly(self):
# most of the dtype testing is done for imresize_many_images()
# so we focus here on a rough test that merely checks if the dtype
# does not change

# these dtypes should be kept in sync with imresize_many_images()
dtypes = [
"uint8",
"uint16",
"int8",
"int16",
"float16",
"float32",
"float64",
"bool"
]

for dt in dtypes:
for ip in ["nearest", "cubic"]:
aug = iaa.Resize({"height": 10, "width": 20}, interpolation=ip)
for is_list in [False, True]:
with self.subTest(dtype=dt, interpolation=ip,
is_list=is_list):
image = np.full((9, 19, 3), 1, dtype=dt)
images = [image, image]
if not is_list:
images = np.array(images, dtype=dt)

images_aug = aug(images=images)

if is_list:
assert isinstance(images_aug, list)
else:
assert ia.is_np_array(images_aug)

assert len(images_aug) == 2
for image_aug in images_aug:
assert image_aug.dtype.name == dt
assert image_aug.shape == (10, 20, 3)
assert np.all(image_aug >= 1 - 1e-4)


class TestPad(unittest.TestCase):
def setUp(self):
Expand Down

0 comments on commit 75a14ee

Please sign in to comment.