In [25]:
import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [26]:
def plotImages(images_arr, ncols=10):
    """Show images side by side in a single row of `ncols` panels.

    Only the first `ncols` images are drawn; when there are fewer images than
    panels, the remaining panels are left blank.

    Parameters
    ----------
    images_arr : iterable
        Images in any form accepted by ``Axes.imshow`` (e.g. a batch array or
        a list of arrays).
    ncols : int, optional
        Number of panels in the row (default 10, the previously hard-coded value).
    """
    # squeeze=False keeps `axes` a 2-D array even when ncols == 1, so flatten() always works
    _, axes = plt.subplots(1, ncols, figsize=(20, 20), squeeze=False)
    axes = axes.flatten()
    # hide every panel first, so panels without an image don't show empty frames/ticks
    for ax in axes:
        ax.axis("off")
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
    plt.tight_layout()
    plt.show()

In [27]:
def saveNumpyImage(numpy_image, dest_path):
    """Save an image array to `dest_path` using Pillow.

    Pillow picks the output format from the extension of `dest_path`.
    Missing parent directories are created first, so callers do not have to
    prepare the output tree.

    Parameters
    ----------
    numpy_image : numpy.ndarray
        Image data accepted by ``PIL.Image.fromarray`` (e.g. a uint8 array of
        shape (H, W) or (H, W, 3)).
    dest_path : str
        Destination file path.
    """
    parent_dir = os.path.dirname(dest_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    Image.fromarray(numpy_image).save(dest_path)

In [28]:

def augmentData(src_dir, dest_dir, image_generator, augmented_images_count=5):
    """Copy every image in `src_dir` to `dest_dir` along with random augmentations.

    For each source file ``name.ext`` this writes ``name_0.jpg`` ...
    ``name_{augmented_images_count - 1}.jpg`` (augmentations drawn from
    `image_generator`). It then writes a byte-for-byte copy of ``name.ext``.

    Files whose copy already exists in `dest_dir` are skipped, so an
    interrupted run can be resumed. The original is copied *last*, so its
    presence means all of its augmented variants were written.

    NOTE(review): if `src_dir` contains both ``x.jpg`` and ``x_0.jpg``, the
    augmented ``x_0.jpg`` written for ``x.jpg`` makes the real ``x_0.jpg``
    look already processed, so it is never copied. Confirm that source names
    never end in ``_<digit>``.

    Parameters
    ----------
    src_dir : str
        Directory containing the source images (subdirectories are ignored).
    dest_dir : str
        Output directory; created if missing.
    image_generator : ImageDataGenerator
        Keras generator whose ``flow`` produces the random augmentations.
    augmented_images_count : int, optional
        Number of augmented variants written per source image.
    """
    os.makedirs(dest_dir, exist_ok=True)
    for filename in sorted(os.listdir(src_dir)):
        src_path = os.path.join(src_dir, filename)
        dest_path = os.path.join(dest_dir, filename)
        # skip subdirectories, and images already handled by a previous run
        if not os.path.isfile(src_path) or os.path.isfile(dest_path):
            continue
        # splitext keeps dotted stems intact ("img.v2.jpg" -> "img.v2"), unlike split('.')[0]
        stem = os.path.splitext(filename)[0]

        # Decode with Pillow as 8-bit RGB. plt.imread returns floats in [0, 1] for PNGs
        # (which the uint8 cast below would turn black) and 2-D arrays for grayscale
        # images (which `flow` rejects, since it needs a rank-4 batch).
        with Image.open(src_path) as img:
            image = np.expand_dims(np.array(img.convert("RGB")), 0)

        # augment image, then save each variant with a numeric suffix
        aug_iter = image_generator.flow(image)
        for index in range(augmented_images_count):
            aug_image = next(aug_iter)[0]
            # batches are floating point: clip before the cast so out-of-range values can't wrap around
            aug_image = np.clip(np.rint(aug_image), 0, 255).astype(np.uint8)
            saveNumpyImage(aug_image, os.path.join(dest_dir, f"{stem}_{index}.jpg"))

        # copy the original last (lossless, no re-encoding); its presence marks this file as done
        shutil.copyfile(src_path, dest_path)
        


In [29]:
# NOTE : train_samples_count + valid_samples_count + test_samples_count must not exceed the number of files in src_dir
# NOTE : existing files in dest_dir/{train,valid,test}/<c> are never removed, so start from empty
#        class directories (they are created if missing)
def prepareModelData(src_dir, dest_dir, c, train_samples_count, valid_samples_count, test_samples_count):
    """Split the files of one class directory into train / valid / test folders.

    Files in `src_dir` are taken in sorted filename order, so the split is
    reproducible. They are assigned as follows:

    - the first `train_samples_count` go to ``dest_dir/train/<c>``;
    - the next `valid_samples_count` go to ``dest_dir/valid/<c>``;
    - the next `test_samples_count` go to ``dest_dir/test/<c>``;
    - any remaining files also go to ``dest_dir/train/<c>``.

    Files are copied byte-for-byte (no decode/re-encode). Subdirectories of
    `src_dir` are ignored.

    NOTE(review): if `src_dir` holds augmented variants of the same original
    (as produced by `augmentData`), those variants can land in different
    splits. Valid/test are then not independent of train.

    Parameters
    ----------
    src_dir : str
        Directory holding the images of class `c`.
    dest_dir : str
        Root directory that receives the ``train``/``valid``/``test`` subdirectories.
    c : str
        Class name; used as the subdirectory name under each split.
    train_samples_count, valid_samples_count, test_samples_count : int
        Number of files assigned to each split.

    Returns
    -------
    None. If the requested counts exceed the number of files, a message is
    printed and nothing is copied.
    """
    filenames = sorted(
        name for name in os.listdir(src_dir)
        if os.path.isfile(os.path.join(src_dir, name))
    )
    print(f"{c}: {len(filenames)} files in {src_dir}")

    valid_end = train_samples_count + valid_samples_count
    test_end = valid_end + test_samples_count
    if test_end > len(filenames):
        print(f"samples counts too large: {test_end} requested but only {len(filenames)} files in {src_dir}")
        return

    for split in ("train", "valid", "test"):
        os.makedirs(os.path.join(dest_dir, split, c), exist_ok=True)

    for index, filename in enumerate(filenames):
        if index < train_samples_count:
            split = "train"
        elif index < valid_end:
            split = "valid"
        elif index < test_end:
            split = "test"
        else:
            split = "train"  # files beyond the requested total also go to training
        shutil.copyfile(
            os.path.join(src_dir, filename),
            os.path.join(dest_dir, split, c, filename),
        )


In [30]:
# Class labels; each name is also the per-class subdirectory name under data/
classes = [
    "andromeda",
    "merlin",
    "morgana",
    "perseus",
]


In [31]:
# Random augmentation settings used to create extra training variants of each image
gen = ImageDataGenerator(
    # geometric jitter
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.2,
    # mirroring: left-right only, never upside-down
    horizontal_flip=True,
    vertical_flip=False,
    # colour jitter
    channel_shift_range=20.0,
)

In [32]:

# For every class, write the originals plus 10 random variants per image
# into data/all/augmented/<class>
for c in classes:
    augmentData(
        src_dir=f"data/all/{c}",
        dest_dir=f"data/all/augmented/{c}",
        image_generator=gen,
        augmented_images_count=10,
    )

In [33]:
# Split each class's augmented pool into data/{train,valid,test}/<class>
# NOTE: the split runs on the *augmented* pool, so variants of one original image can
# end up in both train and valid/test; validation/test scores may be optimistic.
for c in classes:
    prepareModelData(
        src_dir=f"data/all/augmented/{c}",
        dest_dir="data",
        c=c,
        train_samples_count=900,
        valid_samples_count=100,
        test_samples_count=30,
    )

1089
1034
1320
1034
