In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

In [3]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [72]:
def generate_dataset_from_directory(path, count, target_size = (244, 244), 
    generator = None, batch_size = 5):
    """
    Collect "count" images (and their labels) from the provided directory by
    repeatedly drawing batches from a keras ImageDataGenerator directory 
    flow. Labels come from the flow, which derives them from the folder 
    containing each image; any augmentation configured on the generator is 
    applied by the flow as the batches are produced.
    
    Args:
        path: The directory in which the images you would like to sample from
        are contained.
        count: The amount of images you would like in the final dataset.
        target_size: The size all of the images should be resized to.
            default: (244, 244)
        generator: ImageDataGenerator to use for sampling and augmentation of
        images. Any object exposing a compatible flow_from_directory() method
        works.
            default: None, meaning a fresh ImageDataGenerator() is created
            for this call.
        batch_size: Number of images requested from the flow per step.
            default: 5
    
    Returns:
        images: List of length "count" containing the images drawn from the 
        specified directory.
        labels: List of length "count" with the label for each image in 
        "images" (one-hot with keras' default class_mode - NOTE(review): 
        confirm if a different class_mode is ever needed).
    
    Raises:
        ValueError: If the flow yields an empty batch, since the requested 
        count could then never be reached.
    """
    # Build the default lazily so that calls do not share one generator 
    # instance created when the function was defined.
    if generator is None:
        generator = ImageDataGenerator()
    
    # Set up flow from provided directory, using the generator that was 
    # passed in rather than relying on a global left in the kernel.
    flow = generator.flow_from_directory(path, target_size = target_size, 
        batch_size = batch_size)
    
    # Set up loop variables.
    images = []
    labels = []
    
    while len(images) < count:
        next_images, next_labels = next(flow)
        if len(next_images) == 0:
            # Without this guard an empty directory would loop forever.
            raise ValueError(
                'No images were produced from directory: {}'.format(path))
        # Only keep as many images as are still needed, so that the final
        # batch never overshoots "count".
        remaining = count - len(images)
        images.extend(next_images[:remaining])
        labels.extend(next_labels[:remaining])
    
    return images, labels

In [73]:
# Augmentation settings: horizontal flips, height/width shifts of 0.05, a
# zoom range of 0.2 and brightness factors drawn from [0.9, 1.5].
data_generator = ImageDataGenerator(
    horizontal_flip = True,
    height_shift_range = 0.05,
    width_shift_range = 0.05,
    zoom_range = 0.2,
    brightness_range = [0.9, 1.5],
)

# Build the training set (2000 images) and the test set (500 images).
# NOTE(review): the same augmenting generator is used for the test split as
# for the training split - confirm that augmenting test images is intended.
X_train, Y_train = generate_dataset_from_directory(
    path = 'Train',
    count = 2000,
    target_size = (244, 244),
    generator = data_generator,
)
X_test, Y_test = generate_dataset_from_directory(
    path = 'Test',
    count = 500,
    target_size = (244, 244),
    generator = data_generator,
)

Found 3444 images belonging to 150 classes.
Found 3376 images belonging to 150 classes.


In [74]:
# Sanity check: each image list must line up with its label list; the last
# two entries show the resulting dataset sizes.
(
    len(X_train) == len(Y_train),
    len(X_test) == len(Y_test),
    len(X_train),
    len(X_test),
)

(True, True, 2000, 500)