Permalink
Browse files

Fix bg aug errors; add bg aug tests; fix image not installed; change …

…some docu
  • Loading branch information...
aleju committed Aug 26, 2017
1 parent ae08785 commit b72df66dbda7ce0cd534d20fd7d0f284f33b9eb8
Showing with 289 additions and 36 deletions.
  1. +1 −0 MANIFEST.in
  2. +54 −23 imgaug/augmenters.py
  3. +3 −4 imgaug/imgaug.py
  4. BIN { → imgaug}/quokka.jpg
  5. +1 −0 setup.py
  6. +230 −9 tests/test.py
View
@@ -0,0 +1 @@
recursive-include imgaug *.py *.jpg *.gif *.ttf
View
@@ -139,10 +139,11 @@ def augment_batches(self, batches, hooks=None, background=False):
If true, hooks can currently not be used as that would require
pickling functions.
Returns
Yields
-------
augmented_batch : list
Corresponding list of batches of augmented images.
augmented_batch : ia.Batch or list of ia.KeypointsOnImage or list of (H,W,C) ndarray or list of (H,W) ndarray or (N,H,W,C) ndarray or (N,H,W) ndarray
Augmented images/keypoints.
Datatype usually matches the input datatypes per list element.
"""
assert isinstance(batches, list)
@@ -154,6 +155,7 @@ def augment_batches(self, batches, hooks=None, background=False):
batches_original_dts = []
for i, batch in enumerate(batches):
if isinstance(batch, ia.Batch):
batch.data = (i, batch.data)
batches_normalized.append(batch)
batches_original_dts.append("imgaug.Batch")
elif ia.is_np_array(batch):
@@ -169,37 +171,49 @@ def augment_batches(self, batches, hooks=None, background=False):
batches_original_dts.append("list_of_numpy_arrays")
elif isinstance(batch[0], ia.KeypointsOnImage):
batches_normalized.append(ia.Batch(keypoints=batch, data=i))
batches_original_dts.append("list_of_imgaug.KeypointsOfImage")
batches_original_dts.append("list_of_imgaug.KeypointsOnImage")
else:
raise Exception("Unknown datatype in batch[0]. Expected numpy array or imgaug.KeypointsOnImage, got %s." % (type(batch[0]),))
else:
raise Exception("Unknown datatype in of batch. Expected imgaug.Batch or numpy array or list of numpy arrays/imgaug.KeypointsOnImage. Got %s." % (type(batch),))
def unnormalize_batch(batch_aug):
if batch_aug.data is None:
return batch_aug
#if batch_aug.data is None:
# return batch_aug
#else:
i = batch_aug.data
# if input was ia.Batch, then .data has content (i, .data)
if isinstance(i, tuple):
i = i[0]
dt_orig = batches_original_dts[i]
if dt_orig == "imgaug.Batch":
batch_unnormalized = batch_aug
# change (i, .data) back to just .data
batch_unnormalized.data = batch_unnormalized.data[1]
elif dt_orig == "numpy_array":
batch_unnormalized = batch_aug.images_aug
elif dt_orig == "empty_list":
batch_unnormalized = []
elif dt_orig == "list_of_numpy_arrays":
batch_unnormalized = batch_aug.images_aug
elif dt_orig == "list_of_imgaug.KeypointsOnImage":
batch_unnormalized = batch_aug.keypoints_aug
else:
i = batch_aug.data
dt_orig = batches_original_dts[i]
if dt_orig == "imgaug.Batch":
batch_unnormalized = batch_aug
elif dt_orig == "numpy_array":
batch_unnormalized = batch_aug.images_aug
elif dt_orig == "empty_list":
batch_unnormalized = []
elif dt_orig == "list_of_numpy_arrays":
batch_unnormalized = batch_aug.images_aug
elif dt_orig == "list_of_imgaug.KeypointsOnImage":
batch_unnormalized = batch_aug.keypoints_aug
else:
raise Exception("Internal error. Unexpected value in dt_orig '%s'. This should never happen." % (dt_orig,))
return batch_unnormalized
raise Exception("Internal error. Unexpected value in dt_orig '%s'. This should never happen." % (dt_orig,))
return batch_unnormalized
if not background:
for batch_normalized in batches_normalized:
if batch_normalized.images is not None:
batch_augment_images = batch_normalized.images is not None
batch_augment_keypoints = batch_normalized.keypoints is not None
if batch_augment_images and batch_augment_keypoints:
augseq_det = self.to_deterministic() if not self.deterministic else self
batch_normalized.images_aug = augseq_det.augment_images(batch_normalized.images, hooks=hooks)
batch_normalized.keypoints_aug = augseq_det.augment_keypoints(batch_normalized.keypoints, hooks=hooks)
elif batch_augment_images:
batch_normalized.images_aug = self.augment_images(batch_normalized.images, hooks=hooks)
if batch_normalized.keypoints is not None:
elif batch_augment_keypoints:
batch_normalized.keypoints_aug = self.augment_keypoints(batch_normalized.keypoints, hooks=hooks)
batch_unnormalized = unnormalize_batch(batch_normalized)
yield batch_unnormalized
@@ -1905,6 +1919,23 @@ class Lambda(Augmenter):
random_state : int or np.random.RandomState or None, optional(default=None)
See `Augmenter.__init__()`
Examples
--------
>>> def func_images(images, random_state, parents, hooks):
>>> images[:, ::2, :, :] = 0
>>> return images
>>>
>>> def func_keypoints(keypoints_on_images, random_state, parents, hooks):
>>> return keypoints_on_images
>>>
>>> aug = iaa.Lambda(
>>> func_images=func_images,
>>> func_keypoints=func_keypoints
>>> )
Replaces every second row in images with black pixels and leaves keypoints
unchanged.
"""
def __init__(self, func_images, func_keypoints, name=None, deterministic=False, random_state=None):
View
@@ -27,7 +27,6 @@
# filepath to the quokka image
QUOKKA_FP = os.path.join(
os.path.dirname(os.path.abspath(__file__)),
"..",
"quokka.jpg"
)
@@ -823,12 +822,12 @@ def _augment_images_worker(self, augseq, queue_source, queue_result, source_fini
batch_augment_keypoints = batch.keypoints is not None and self.augment_keypoints
if batch_augment_images and batch_augment_keypoints:
augseq_det = augseq.to_deterministic()
augseq_det = augseq.to_deterministic() if not augseq.deterministic else augseq
batch.images_aug = augseq_det.augment_images(batch.images)
batch.keypoints_aug = augseq_det.augment_keypoints(batch.keypoints)
elif batch_augment_images is not None:
elif batch_augment_images:
batch.images_aug = augseq.augment_images(batch.images)
elif batch_augment_keypoints is not None:
elif batch_augment_keypoints:
batch.keypoints_aug = augseq.augment_keypoints(batch.keypoints)
# send augmented batch to output queue
File renamed without changes.
View
@@ -23,6 +23,7 @@
download_url="https://github.com/aleju/imgaug/archive/0.2.4.tar.gz",
install_requires=["scipy", "scikit-image>=0.11.0", "numpy>=1.7.0", "six"],
packages=find_packages(),
include_package_data=True,
license="MIT",
description="Image augmentation library for machine learning",
long_description=long_description,
Oops, something went wrong.

0 comments on commit b72df66

Please sign in to comment.