In [None]:
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import os

In [None]:
# load data: (images, labels) tuples for the training and test splits
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to [samples][height][width][channels]: channels-last with a single
# grayscale channel, i.e. the 4-D layout ImageDataGenerator works on.
# Cast to float32 so the generator's standardization / whitening arithmetic
# runs in floating point instead of on the raw integer pixels.
X_train = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
X_test = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

In [None]:
# 1. Feature standardization: shift inputs to zero mean and divide by the
#    standard deviation, with both statistics taken over the dataset.
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
)

# The featurewise options are data-dependent; fit() computes the mean and
# std they use from the training images.
datagen.fit(X_train)

In [None]:
# 2. ZCA whitening: decorrelate the pixel features of each image.
datagen = ImageDataGenerator(
    zca_whitening=True,
)

# Whitening is data-dependent: fit() estimates the whitening matrix from the
# flattened training images (28 * 28 * 1 = 784 features per image).
# NOTE(review): Keras switches on featurewise_center when zca_whitening is
# requested and emits a warning about it -- expected here.
datagen.fit(X_train)

In [None]:
# 3. Random rotation: each generated image is rotated by a random angle
#    within +/- 90 degrees.
#    NOTE(review): rotations this large can make some digits ambiguous
#    (e.g. 6 vs 9) -- fine for a visual demo, questionable for training.
datagen = ImageDataGenerator(rotation_range=90)

# No fit() here: fit() only computes statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening. None of those are set,
# so there are no parameters to fit and calling it would be dead work.

In [None]:
# 4. Random shifts: translate each generated image by up to 20% of its
#    width horizontally and 20% of its height vertically.
datagen = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2)

# No fit() here: fit() only computes statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening. None of those are set,
# so there are no parameters to fit and calling it would be dead work.

In [None]:
# 5. Random flips: each generated image may be mirrored left-right and/or
#    top-bottom at random.
#    NOTE(review): mirrored digits are generally not valid digits, so these
#    flips are not label-preserving for MNIST -- illustrative only.
datagen = ImageDataGenerator(horizontal_flip=True, vertical_flip=True)

# No fit() here: fit() only computes statistics for featurewise_center,
# featurewise_std_normalization and zca_whitening. None of those are set,
# so there are no parameters to fit and calling it would be dead work.

In [None]:
# Draw one augmented batch of 9 images. Every cell above rebinds `datagen`,
# so this shows whichever augmentation cell was run last.
X_batch, y_batch = next(datagen.flow(X_train, y_train, batch_size=9))

# Show the batch as a 3x3 grid, each image titled with its label.
fig, axes = plt.subplots(3, 3)
for ax, image, label in zip(axes.flat, X_batch, y_batch):
    # drop the trailing single-channel axis: (28, 28, 1) -> (28, 28)
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(str(label))
    ax.axis('off')
fig.tight_layout()
plt.show()

In [None]:
# Save one augmented batch to the 'images' folder and plot it.
# exist_ok=True keeps this cell re-runnable: a bare os.makedirs raises
# FileExistsError when 'images' is left over from a previous run.
os.makedirs('images', exist_ok=True)

# With save_to_dir set, the generator writes each augmented image to disk
# (PNG files prefixed 'aug') as it produces a batch. Drawing a single batch
# with next() therefore saves exactly these 9 images.
X_batch, y_batch = next(datagen.flow(X_train, y_train, batch_size=9,
                                     save_to_dir='images', save_prefix='aug',
                                     save_format='png'))

# Show the saved batch as a 3x3 grid, each image titled with its label.
fig, axes = plt.subplots(3, 3)
for ax, image, label in zip(axes.flat, X_batch, y_batch):
    # drop the trailing single-channel axis: (28, 28, 1) -> (28, 28)
    ax.imshow(image.squeeze(), cmap='gray')
    ax.set_title(str(label))
    ax.axis('off')
fig.tight_layout()
plt.show()