# Data Augmentation with Masking

Data Augmentation is a regularisation technique used to reduce overfitting when training Computer Vision models. Adjustments are made to the original images in the training dataset before they are used in training. Example adjustments include translating, cropping, scaling, rotating, and changing brightness and contrast. We do this to reduce the model's dependence on spurious characteristics. For example, if the training data only contains faces that fill 1/4 of the image, a model trained without data augmentation might unhelpfully learn that faces can only be that size.

Certain Computer Vision tasks (like Image Segmentation) require the use of 'masks', and we have to take extra care when using these in conjunction with data augmentation techniques. An image mask is typically a binary or greyscale image that overlays a base image and highlights particular regions of it. In Image Segmentation, the training data contains a mask for each object class of interest (e.g. floor, table, person).
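Some segmentation datasets store annotations as a single label map of integer class ids rather than as separate mask images. Assuming that layout, the per-class binary masks described above can be derived with NumPy broadcasting. The helper name `label_map_to_masks` and the class names are hypothetical, chosen for illustration.

```python
import numpy as np

def label_map_to_masks(label_map, n_classes):
    """Convert an (H, W) map of integer class ids into an (H, W, n_classes) stack of binary masks."""
    # Compare each pixel's class id with every class index; broadcasting yields (H, W, n_classes)
    return (label_map[..., None] == np.arange(n_classes)).astype(np.uint8)

# Hypothetical 3x4 label map: 0 = floor, 1 = table, 2 = person
label_map = np.array([[0, 0, 1, 1],
                      [0, 2, 2, 1],
                      [0, 2, 2, 0]])
masks = label_map_to_masks(label_map, n_classes=3)
print(masks.shape)  # (3, 4, 3)
```

In this layout every pixel belongs to exactly one class, so the masks are mutually exclusive. The random masks generated later in this notebook do not have that property, which is fine for demonstrating augmentation.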

When we adjust the base image as part of data augmentation, we also need to apply exactly the same operation to the associated masks. For example, after applying a horizontal flip to the base image, we also need to flip the masks to preserve the correspondence between the base image and its masks.
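The flip example can be sketched in NumPy, assuming height x width x channels (HWC) arrays. The key point is that a single random draw decides whether to flip, and that same decision is applied to both arrays. If the image and masks each drew their own random number, they could end up inconsistently flipped. `random_hflip` is a hypothetical helper, not an MXNet function.

```python
import numpy as np

rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(20, 15, 3))  # H x W x C image
masks = rng.integers(0, 2, size=(20, 15, 5))    # H x W x n_masks binary masks

def random_hflip(image, masks, p=0.5, rng=rng):
    """Flip image and masks together along the width axis with probability p."""
    if rng.random() < p:  # one random decision shared by image and masks
        # Reverse axis 1 (width) for both arrays so image pixel (r, c) still lines up with mask pixel (r, c)
        return image[:, ::-1, :], masks[:, ::-1, :]
    return image, masks

flipped_image, flipped_masks = random_hflip(image, masks, p=1.0)
```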

In [1]:
import mxnet as mx
import numpy as np

## 1) Generate example image and mask layers

In [2]:
img_height = 20
img_width = 15

In [3]:
# Generate 5 random mask layers
n_masks = 5
masks = mx.nd.array(np.random.randint(low=0, high=2, size=(img_height, img_width, n_masks)))
masks.shape

(20, 15, 5)

In [4]:
# Generate a random 3 channel image
n_channels = 3
image = mx.nd.array(np.random.randint(low=0, high=256, size=(img_height, img_width, n_channels)))
image.shape

(20, 15, 3)

## 2) Concatenate image and masks

In [5]:
# Concatenate on the channels dim to obtain an 8-channel image
# (3 channels for the original image, plus 5 mask layers)
image_w_masks = mx.nd.concat(image, masks, dim=2)
image_w_masks.shape

(20, 15, 8)
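The stack-augment-split pattern used in this notebook can be wrapped in a small helper. The sketch below is NumPy-only, and `apply_jointly` is a hypothetical name. It assumes the positional op only moves pixels around and works for any number of channels. A 90-degree rotation stands in for a real augmentation.

```python
import numpy as np

def apply_jointly(positional_op, image, masks):
    """Run one positional op on image and masks together by stacking them on the channel axis."""
    n_channels = image.shape[2]
    stacked = np.concatenate([image, masks], axis=2)  # (H, W, C + n_masks)
    out = positional_op(stacked)                      # one call, so one set of random choices for all channels
    return out[:, :, :n_channels], out[:, :, n_channels:]

image = np.random.randint(0, 256, size=(20, 15, 3)).astype(np.float32)
masks = np.random.randint(0, 2, size=(20, 15, 5)).astype(np.float32)
# Example op: rotate by 90 degrees in the height/width plane
aug_image, aug_masks = apply_jointly(lambda x: np.rot90(x, k=1, axes=(0, 1)), image, masks)
print(aug_image.shape, aug_masks.shape)  # (15, 20, 3) (15, 20, 5)
```

Because the split relies only on the known channel count, the image channels must come first when stacking, as they do in `mx.nd.concat(image, masks, dim=2)` above.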

## 3) Perform Positional Augmentations

In [6]:
# Random crop
crop_height = 5
crop_width = 4
# Watch out: width before height in the size param!
aug_image_w_masks, crop_box = mx.image.random_crop(image_w_masks, size=(crop_width, crop_height))

In [7]:
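# Deterministic resize: the shorter side is scaled to resize_size, keeping the aspect ratio
resize_size = 10
# Note: the default interpolation (interp=2, area) can blend neighbouring mask values, so the
# masks may no longer be strictly binary; interp=0 (nearest neighbour) would keep them binary
aug_image_w_masks = mx.image.resize_short(aug_image_w_masks, size=resize_size)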
# Add more translation/scale/rotation augmentations here...

In [8]:
aug_image_w_masks.shape

(12, 10, 8)
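The `(12, 10, 8)` shape follows from scaling the shorter side (width 4) to 10 and scaling the longer side by the same factor, rounded down: 5 × 10 // 4 = 12. Below is a NumPy sketch of a short-side resize using nearest-neighbour sampling, which only copies existing values and so keeps binary masks binary. The integer rounding rule is our assumption about how `resize_short` computes the output size, chosen because it reproduces the shape above. The hypothetical `resize_short_nearest` uses a simple floor-based index mapping rather than OpenCV's exact sampling.

```python
import numpy as np

def resize_short_nearest(src, size):
    """Resize an HWC array so its shorter side equals `size`, using nearest-neighbour sampling."""
    h, w = src.shape[:2]
    if h > w:
        new_h, new_w = size * h // w, size
    else:
        new_h, new_w = size, size * w // h
    # Map each output row/column back to a source index (floor of the scaled position)
    rows = np.minimum((np.arange(new_h) * h / new_h).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) * w / new_w).astype(int), w - 1)
    return src[rows[:, None], cols[None, :], :]

crop = np.random.randint(0, 2, size=(5, 4, 8))
out = resize_short_nearest(crop, 10)
print(out.shape)  # (12, 10, 8)
```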

## 4) Split image and masks 

In [9]:
aug_image = aug_image_w_masks[:, :, :n_channels]
aug_masks = aug_image_w_masks[:, :, n_channels:]

print("aug_image shape: " + str(aug_image.shape))
print("aug_masks shape: " + str(aug_masks.shape))

aug_image shape: (12, 10, 3)
aug_masks shape: (12, 10, 5)
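If a smoothing interpolation was used in the resize step, the split-off masks may contain fractional values. One simple remedy, sketched here as an assumption rather than part of the original pipeline, is to threshold them back to {0, 1}. `binarize_masks` and the 0.5 cut-off are illustrative choices.

```python
import numpy as np

def binarize_masks(masks, threshold=0.5):
    """Snap interpolated mask values back to {0, 1}."""
    return (masks >= threshold).astype(np.uint8)

# Hypothetical interpolated mask values, e.g. after a bilinear or area resize
blurred = np.array([[0.0, 0.25],
                    [0.5, 0.9]])
binary = binarize_masks(blurred)
print(binary)  # [[0 0] [1 1]]
```

For mutually exclusive class masks, taking the `argmax` across the mask channels is an alternative that guarantees exactly one class per pixel.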


## 5) Perform color augmentation (optional)

In [10]:
# Colour augmentation should only be applied to the image, not to the mask layers.
# BrightnessJitterAug is a callable augmenter class (its logic lives in the `__call__` method),
# so we instantiate it first and then call it on the image.
aug = mx.image.BrightnessJitterAug(brightness=1)
aug_image = aug(aug_image)

print("aug_image shape: " + str(aug_image.shape))

aug_image shape: (12, 10, 3)
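As we understand it, `BrightnessJitterAug(brightness)` multiplies the whole image by one random factor alpha = 1 + u, where u is drawn uniformly from [-brightness, brightness]. With `brightness=1`, alpha can range from near 0 (almost black) up to 2 (double brightness). Below is a NumPy sketch of that idea, assuming float pixel values. Like the augmenter as we understand it, it applies no clipping, so values can exceed 255. The hypothetical `brightness_jitter` touches only the image, and the masks are left alone.

```python
import numpy as np

def brightness_jitter(image, brightness, rng=np.random.default_rng()):
    """Scale every pixel by one random factor alpha drawn from [1 - brightness, 1 + brightness]."""
    alpha = 1.0 + rng.uniform(-brightness, brightness)
    return image * alpha  # no clipping: values may leave the [0, 255] range

image = np.full((12, 10, 3), 100.0)
masks = np.ones((12, 10, 5))
jittered = brightness_jitter(image, brightness=1)  # masks are deliberately not passed in
```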
