Skip to content

Commit

Permalink
Switch KeepSizeByResize to augment_batch() interface
Browse files Browse the repository at this point in the history
  • Loading branch information
aleju committed Oct 14, 2019
1 parent 6117146 commit cac59ea
Showing 1 changed file with 111 additions and 142 deletions.
253 changes: 111 additions & 142 deletions imgaug/augmenters/size.py
Original file line number Diff line number Diff line change
Expand Up @@ -2433,6 +2433,117 @@ def _validate_param(val, allow_same_as_images):
self.interpolation_segmaps = _validate_param(interpolation_segmaps,
True)

def _augment_batch(self, batch, random_state, parents, hooks):
with batch.propagation_hooks_ctx(self, hooks, parents):
images_were_array = None
if batch.images is not None:
images_were_array = ia.is_np_array(batch.images)
shapes_orig = self.get_shapes(batch)

samples = self._draw_samples(batch.nb_items, random_state)

batch = self.children.augment_batch(
batch, parents=parents + [self], hooks=hooks)

if batch.images is not None:
batch.images = self._keep_size_images(
batch.images, shapes_orig["images"], images_were_array,
samples)

if batch.heatmaps is not None:
# dont use shapes_orig["images"] because they might be None
batch.heatmaps = self._keep_size_maps(
batch.heatmaps, shapes_orig["heatmaps"],
shapes_orig["heatmaps_arr"], samples[1])

if batch.segmentation_maps is not None:
# dont use shapes_orig["images"] because they might be None
batch.segmentation_maps = self._keep_size_maps(
batch.segmentation_maps, shapes_orig["segmentation_maps"],
shapes_orig["segmentation_maps_arr"], samples[2])

for augm_name in ["keypoints", "bounding_boxes", "polygons",
"line_strings"]:
augm_value = getattr(batch, augm_name)
if augm_value is not None:
func = functools.partial(
self._keep_size_keypoints,
shapes_orig=shapes_orig[augm_name],
interpolations=samples[0])
cbaois = self._apply_to_cbaois_as_keypoints(augm_value,
func)
setattr(batch, augm_name, cbaois)
return batch

@classmethod
def _keep_size_images(cls, images, shapes_orig, images_were_array,
samples):
interpolations, _, _ = samples

gen = zip(images, interpolations, shapes_orig)
result = []
for image, interpolation, input_shape in gen:
if interpolation == KeepSizeByResize.NO_RESIZE:
result.append(image)
else:
result.append(
ia.imresize_single_image(image, input_shape[0:2],
interpolation))

if images_were_array:
# note here that NO_RESIZE can have led to different shapes
nb_shapes = len(set([image.shape for image in result]))
if nb_shapes == 1:
result = np.array(result, dtype=images.dtype)

return result

@classmethod
def _keep_size_maps(cls, augmentables, shapes_orig_images,
shapes_orig_arrs, interpolations):
result = []
gen = zip(augmentables, interpolations,
shapes_orig_arrs, shapes_orig_images)
for augmentable, interpolation, arr_shape_orig, img_shape_orig in gen:
if interpolation == "NO_RESIZE":
result.append(augmentable)
else:
augmentable = augmentable.resize(
arr_shape_orig[0:2], interpolation=interpolation)
augmentable.shape = img_shape_orig
result.append(augmentable)

return result

@classmethod
def _keep_size_keypoints(cls, kps_aug, shapes_orig, interpolations):
result = []
gen = zip(kps_aug, interpolations, shapes_orig)
for kps_aug, interpolation, input_shape in gen:
if interpolation == KeepSizeByResize.NO_RESIZE:
result.append(kps_aug)
else:
result.append(kps_aug.on(input_shape))

return result

@classmethod
def get_shapes(cls, batch):
result = dict()
augms = batch.get_augmentables()
for augm_name, augm_value, augm_attr_name in augms:
result[augm_name] = [cell.shape for cell in augm_value]

if batch.heatmaps is not None:
result["heatmaps_arr"] = [
cell.arr_0to1.shape for cell in batch.heatmaps]

if batch.segmentation_maps is not None:
result["segmentation_maps_arr"] = [
cell.arr.shape for cell in batch.segmentation_maps]

return result

def _draw_samples(self, nb_images, random_state):
rngs = random_state.duplicate(3)
interpolations = self.interpolation.draw_samples((nb_images,),
Expand Down Expand Up @@ -2479,148 +2590,6 @@ def _draw_samples(self, nb_images, random_state):

return interpolations, interpolations_heatmaps, interpolations_segmaps

def _is_propagating(self, augmentables, parents, hooks):
return (
hooks is None
or hooks.is_propagating(
augmentables, augmenter=self, parents=parents, default=True)
)

def _augment_images(self, images, random_state, parents, hooks):
input_was_array = ia.is_np_array(images)
if self._is_propagating(images, parents, hooks):
interpolations, _, _ = self._draw_samples(len(images),
random_state)
input_shapes = [image.shape[0:2] for image in images]

images_aug = self.children.augment_images(
images=images,
parents=parents + [self],
hooks=hooks
)

gen = zip(images_aug, interpolations, input_shapes)
result = []
for image_aug, interpolation, input_shape in gen:
if interpolation == KeepSizeByResize.NO_RESIZE:
result.append(image_aug)
else:
result.append(
ia.imresize_single_image(image_aug, input_shape[0:2],
interpolation))

if input_was_array:
# note here that NO_RESIZE can have led to different shapes
nb_shapes = len(set([image.shape for image in result]))
if nb_shapes == 1:
result = np.array(result, dtype=images.dtype)
else:
result = images
return result

def _augment_heatmaps(self, heatmaps, random_state, parents, hooks):
if self._is_propagating(heatmaps, parents, hooks):
nb_heatmaps = len(heatmaps)
_, interpolations_heatmaps, _ = self._draw_samples(
nb_heatmaps, random_state)
input_arr_shapes = [heatmaps_i.arr_0to1.shape
for heatmaps_i in heatmaps]

# augment according to if and else list
heatmaps_aug = self.children.augment_heatmaps(
heatmaps,
parents=parents + [self],
hooks=hooks
)

result = []
gen = zip(heatmaps, heatmaps_aug, interpolations_heatmaps,
input_arr_shapes)
for heatmap, heatmap_aug, interpolation, input_arr_shape in gen:
if interpolation == "NO_RESIZE":
result.append(heatmap_aug)
else:
heatmap_aug = heatmap_aug.resize(
input_arr_shape[0:2], interpolation=interpolation)
heatmap_aug.shape = heatmap.shape
result.append(heatmap_aug)
else:
result = heatmaps

return result

def _augment_segmentation_maps(self, segmaps, random_state, parents, hooks):
if self._is_propagating(segmaps, parents, hooks):
nb_segmaps = len(segmaps)
_, _, interpolations_segmaps = self._draw_samples(nb_segmaps,
random_state)
input_arr_shapes = [segmaps_i.arr.shape for segmaps_i in segmaps]

# augment according to if and else list
segmaps_aug = self.children.augment_segmentation_maps(
segmaps,
parents=parents + [self],
hooks=hooks
)

result = []
gen = zip(segmaps, segmaps_aug, interpolations_segmaps,
input_arr_shapes)
for segmaps, segmaps_aug, interpolation, input_arr_shape in gen:
if interpolation == "NO_RESIZE":
result.append(segmaps_aug)
else:
segmaps_aug = segmaps_aug.resize(
input_arr_shape[0:2], interpolation=interpolation)
segmaps_aug.shape = segmaps.shape
result.append(segmaps_aug)
else:
result = segmaps

return result

def _augment_keypoints(self, keypoints_on_images, random_state, parents,
hooks):
if self._is_propagating(keypoints_on_images, parents, hooks):
interpolations, _, _ = self._draw_samples(
len(keypoints_on_images), random_state)
input_shapes = [kpsoi_i.shape for kpsoi_i in keypoints_on_images]

# augment according to if and else list
kps_aug = self.children.augment_keypoints(
keypoints_on_images=keypoints_on_images,
parents=parents + [self],
hooks=hooks
)

result = []
gen = zip(keypoints_on_images, kps_aug, interpolations,
input_shapes)
for kps, kps_aug, interpolation, input_shape in gen:
if interpolation == KeepSizeByResize.NO_RESIZE:
result.append(kps_aug)
else:
result.append(kps_aug.on(input_shape))
else:
result = keypoints_on_images

return result

def _augment_polygons(self, polygons_on_images, random_state, parents,
hooks):
return self._augment_polygons_as_keypoints(
polygons_on_images, random_state, parents, hooks)

def _augment_line_strings(self, line_strings_on_images, random_state,
parents, hooks):
return self._augment_line_strings_as_keypoints(
line_strings_on_images, random_state, parents, hooks)

def _augment_bounding_boxes(self, bounding_boxes_on_images, random_state,
parents, hooks):
return self._augment_bounding_boxes_as_keypoints(
bounding_boxes_on_images, random_state, parents, hooks)

def _to_deterministic(self):
aug = self.copy()
aug.children = aug.children.to_deterministic()
Expand Down

0 comments on commit cac59ea

Please sign in to comment.