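# Data Augmentation

Generates randomly horizontally shifted copies of every image in `starting_images/` and saves them to `augmented_images/`, using Keras' `ImageDataGenerator`.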

In [54]:
import os
from os import listdir
from os.path import isfile, join

from numpy import expand_dims
from keras.preprocessing.image import load_img
from keras.preprocessing.image import save_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from tqdm import tqdm

# Reached only if every import above succeeded without raising.
print("Libraries Imported")

Libraries Imported


### Functions

In [55]:
def augmentImages(images_path, save_dir_path,
                  extensions=(".png", ".jpg", ".jpeg", ".bmp", ".gif",
                              ".tif", ".tiff", ".webp")):
    """Augment every image file in a directory and save the results.

    Each image is passed to ``horizontalAugmentation``, which writes its
    horizontally shifted copies into ``save_dir_path``. Other augmentation
    kinds (e.g. vertical shifts) can be added to the per-image loop.

    Input:
        images_path: directory containing the starting images. Only regular
            files directly inside it are considered (no recursion).
        save_dir_path: directory where the augmented images will be saved.
            It is created if it does not exist.
        extensions: lower-case file extensions (with leading dot) treated as
            images. Other files, such as ``.DS_Store`` or text notes, are
            skipped instead of crashing the image loader. Pass ``None`` to
            process every regular file.

    Output:
        Augmented images written to ``save_dir_path``; returns None.
    """
    # Without this, saving the first augmented image fails on a fresh checkout.
    os.makedirs(save_dir_path, exist_ok=True)

    # Sorted so files are processed in the same order on every run
    # (listdir order is arbitrary).
    imgs_names = sorted(
        name for name in listdir(images_path)
        if isfile(join(images_path, name))
        and (extensions is None
             or os.path.splitext(name)[1].lower() in extensions)
    )

    for img_name in tqdm(imgs_names):
        # "/" is accepted as a separator on every OS Python supports, and it
        # keeps the path compatible with helpers that extract the file name
        # by splitting on "/".
        img_path = images_path + "/" + img_name
        horizontalAugmentation(img_path, save_dir_path)

def horizontalAugmentation(img_path, save_dir_path, n_samples=9,
                           shift_range=50, seed=None):
    """Save randomly horizontally shifted copies of a single image.

    Input:
        img_path: path of the starting image.
        save_dir_path: directory where the augmented images will be saved.
            It is created if it does not exist.
        n_samples: number of augmented copies to generate.
        shift_range: value passed to
            ``ImageDataGenerator(width_shift_range=...)``. Per the Keras docs,
            an int ``k`` draws an integer pixel shift from (-k, +k), so the
            copies differ from one another. A list such as ``[-50, 50]``
            would only ever pick one of the listed values, giving at most
            two distinct shifts.
        seed: optional seed forwarded to ``flow`` for reproducible shifts.

    Output:
        Files named ``<stem>_horiz_augm_<i>.png`` in ``save_dir_path`` for
        ``i`` in ``range(n_samples)``, where ``<stem>`` is the image file
        name without its directory and extension.
    """
    os.makedirs(save_dir_path, exist_ok=True)

    # File name without directory or extension. Unlike slicing off the last
    # 4 characters, this handles any extension length (".jpeg", ".tiff").
    stem = os.path.splitext(os.path.basename(img_path))[0]

    # Load the image as an array and add a leading batch axis of size one.
    data = img_to_array(load_img(img_path))
    samples = expand_dims(data, 0)

    datagen = ImageDataGenerator(width_shift_range=shift_range)
    batches = datagen.flow(samples, batch_size=1, seed=seed)

    for i in range(n_samples):
        batch = next(batches)

        # Cast to unsigned 8-bit integers so the array saves as a regular image.
        image = batch[0].astype('uint8')

        save_img(join(save_dir_path, stem + "_horiz_augm_" + str(i) + ".png"), image)

### Tests

In [56]:
# Example: horizontal-shift augmentation of a single image (uncomment to run)
#img_1 = "starting_images/cyberpunk_girl.png"
#horizontalAugmentation(img_1, "augmented_images")

In [57]:
#img_2 = "starting_images/universe.png"
#horizontalAugmentation(img_2, "temp")

In [58]:
# Augment every image found in "starting_images" and write the results to "augmented_images".
augmentImages(images_path="starting_images", save_dir_path="augmented_images")

100%|██████████| 5/5 [00:02<00:00,  1.98it/s]
