Common imports

In [None]:
import os
import numpy as np

# Visualization
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
%matplotlib inline

In [None]:
# Tensorflow imports
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# Root folder of the training images.
image_folder = os.path.join('datasets', 'face_dataset_train_images')

# Every image is resized to this height/width (in pixels) when loaded.
img_height = 250
img_width = 250

# Two target classes: "me" and "not me".
num_classes = 2

Look at the data

In [None]:
# Load the images under `image_folder` as a shuffled dataset with categorical
# labels; the fixed seed keeps the shuffle order repeatable.
dataset = keras.preprocessing.image_dataset_from_directory(
    image_folder,
    label_mode='categorical',
    image_size=(img_height, img_width),
    shuffle=True,
    seed=42,
)

In [None]:
# Keep the dataset's class names in a module-level variable; later cells pass
# it to get_classname() to turn a categorical label into a class name.
class_names = dataset.class_names
# Bare expression so the notebook displays the list as the cell output.
class_names

In [None]:
# Helper function to get classname of the image
def get_classname(class_name,mask):
  '''
    Returns the element of 'class_name' located at the index of the
    maximum value of 'mask'.
    Used to get the class name from a categorical (one-hot) label.

    Parameters:
        class_name (array-like): Target array of class names
        mask (array-like): Mask array, elements must be numbers;
            must have the same length as 'class_name'
    Returns:
        One of the elements of 'class_name'
    Raises:
        AssertionError: if 'class_name' and 'mask' differ in length

    >>> get_classname(['first', 'second'], [0, 1])
    'second'
    >>> get_classname(['first', 'second', 'third'], [1, 0, 0])
    'first'
    '''
  # Validate against the *parameter*, not a global that happens to be
  # called 'class_names' in the notebook.
  assert len(class_name) == len(mask),"the arrays must of the same length"
  # np.asarray also accepts array-like objects such as tensors.
  best_index = np.asarray(mask).argmax(axis=0)
  return class_name[best_index]

In [None]:
sqrt_img = 2 #images per row/col
# the square root of the total number of images shown

plt.figure(figsize=(8,8))
# A single batch is enough to fill the grid; taking several batches would
# just redraw the same subplots over and over.
for images,labels in dataset.take(1):
  # Do not index past the end of a batch smaller than the grid.
  num_shown = min(sqrt_img**2, images.shape[0])
  for index in range(num_shown):
    # grid 'sqrt_img' x 'sqrt_img'
    plt.subplot(sqrt_img,sqrt_img,index+1)
    # Pixel values are divided by 255 to bring them into imshow's float range.
    plt.imshow(images[index]/255)
    class_name = get_classname(class_names,labels[index])
    # Show which class the image belongs to.
    plt.title(class_name)
    plt.axis("off")

Data Augmentation

In [None]:
# Number of images per batch.
batch_size = 16

In [None]:
# Generator that applies random augmentations to every image it yields.
train_datagen = ImageDataGenerator(
    # geometric transformations
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=False,
    # pixels created by a transformation are filled with the 'nearest' mode
    fill_mode='nearest',
    # brightness is varied within this range
    brightness_range=(0.7, 1),
)

# Stream augmented batches straight from the image folder,
# with categorical labels.
train_generator = train_datagen.flow_from_directory(
    image_folder,
    class_mode='categorical',
    batch_size=batch_size,
    target_size=(img_height, img_width),
)

In [None]:
# To see next augmented image
# Use the built-in next() (iterator protocol) instead of the generator's
# own .next() method, which is not available on every Keras version.
image,label = next(train_generator)

plt.figure(figsize=(6,6))
# 'image' is a whole batch; show its first element, scaled by 1/255.
plt.imshow(image[0]/255)
plt.title("Augmented image from ImageDataGenerator")
plt.axis("off")
plt.show()

Generate n samples for each image

In [None]:
# Number of augmented samples to generate for each source image.
n=10
# Folder that receives the generated (augmented) images.
aug_img_folder = os.path.join('datasets','face_dataset_train_aug_image')
# exist_ok=True creates the folder only when it is missing, so the cell can
# be re-run safely.
os.makedirs(aug_img_folder, exist_ok=True)

In [None]:
# classes: 'me' and 'not_me'
image_folder_to_generate = os.path.join(image_folder,'me')
image_folder_to_save = os.path.join(aug_img_folder,'me')
# Create the output folder if needed; safe to re-run.
os.makedirs(image_folder_to_save, exist_ok=True)

# List the folder once so the progress total and the loop see the same files;
# sorting makes the processing order repeatable.
filenames = sorted(os.listdir(image_folder_to_generate))
total = len(filenames) #number of files in folder
for i,filename in enumerate(filenames):
  print("Step {} of {}".format(i+1,total))
  # for each image in folder: read it
  image_path = os.path.join(image_folder_to_generate,filename)
  # target_size is (height, width); the channel axis is added by img_to_array
  image = keras.preprocessing.image.load_img(image_path,target_size=(img_height,img_width))
  image = keras.preprocessing.image.img_to_array(image)
  # shape from (250,250,3) to (1,250,250,3)
  image = np.expand_dims(image,axis=0)

  # create ImageDataGenerator object for it
  current_image_gen = train_datagen.flow(image,
                                         batch_size=1,
                                         save_to_dir=image_folder_to_save,
                                         save_prefix=filename,
                                         save_format='jpg')
  # generate n samples: each next() call produces one augmented image and,
  # because save_to_dir is set, writes it to disk. Using range(n) also
  # terminates cleanly when n is 0.
  for _ in range(n):
    next(current_image_gen)
  print('\tGenerate {} samples for file {}'.format(n,filename))
print("\nTotal number images generated = {}".format(n*total))
