In [26]:
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import ImageDataGenerator

## We can perform standard image augmentations with the Keras `ImageDataGenerator` class. For this we use the CIFAR-10 dataset.

In [41]:
cifar10 = tf.keras.datasets.cifar10
(x_train, y_train),(x_test,y_test) = cifar10.load_data()
print(x_train.shape)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
(50000, 32, 32, 3)


## We first set up an `ImageDataGenerator` that performs the augmentations we specify. In this case we perform random horizontal and vertical flips, along with random rotations of up to 40 degrees in either direction.

In [42]:
datagen = ImageDataGenerator(vertical_flip=True,    # random vertical flipping of training images
                             horizontal_flip=True,  # random horizontal flipping of training images
                             rotation_range=40)     # random rotations of up to 40 degrees either way
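To make the three settings concrete, here is a minimal NumPy sketch of the random choices they control. It assumes each flip happens independently with probability 0.5 and that the rotation angle is drawn uniformly from [-40, 40] degrees. The function name `random_flip_and_angle` is hypothetical, and the sketch only picks the angle; the rotation itself is carried out internally by Keras.

```python
import numpy as np

def random_flip_and_angle(image, rng, horizontal_flip=True, vertical_flip=True, rotation_range=40):
    """Hypothetical sketch of the random choices behind the generator settings above.

    image: array of shape (height, width, channels).
    Returns the (possibly) flipped image and a rotation angle in degrees.
    The rotation is NOT applied here; this only shows how the angle is chosen.
    """
    if horizontal_flip and rng.random() < 0.5:
        image = image[:, ::-1, :]   # mirror left-right (reverse the width axis)
    if vertical_flip and rng.random() < 0.5:
        image = image[::-1, :, :]   # mirror top-bottom (reverse the height axis)
    angle = rng.uniform(-rotation_range, rotation_range)  # assumed uniform, in degrees
    return image, angle

rng = np.random.default_rng(0)
img = np.arange(6).reshape(2, 3, 1)   # tiny 2x3 single-channel "image"
out, angle = random_flip_and_angle(img, rng)
print(out[:, :, 0], angle)
```

Each call can return any of the four flip combinations, so the same training image can appear in several orientations across epochs.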

## Set up the model

In [45]:
model = tf.keras.models.Sequential()

model.add(layers.Flatten(input_shape = (32,32,3)))
model.add(layers.Dense(64, activation = 'relu'))
model.add(layers.Dropout(0.2))
model.add(layers.Dense(10,activation='softmax'))

model.compile(optimizer ='adam',
             loss = 'sparse_categorical_crossentropy',
             metrics = ['accuracy'])
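As a quick sanity check on the model size, the parameter counts follow from simple arithmetic: a `Dense` layer has one weight per input-unit pair plus one bias per unit, while `Flatten` and `Dropout` have no trainable parameters. The helper `dense_params` is hypothetical, written only for this calculation.

```python
def dense_params(n_inputs, n_units):
    # one weight per (input, unit) pair, plus one bias per unit
    return n_inputs * n_units + n_units

flat = 32 * 32 * 3               # Flatten: 3072 inputs, no trainable parameters
hidden = dense_params(flat, 64)  # Dense(64, activation='relu')
output = dense_params(64, 10)    # Dense(10, activation='softmax'); Dropout adds none
total = hidden + output
print(flat, hidden, output, total)  # 3072 196672 650 197322
```

`model.summary()` should report the same total of 197,322 parameters, almost all of them in the first dense layer.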

## We use the generator's `flow` method to pass the image data through our `ImageDataGenerator`, which augments each batch before feeding it to the model. The pixel values are first scaled from 0-255 to the range [0, 1].

In [46]:
x_train,x_test = x_train/255.0, x_test/255.0
model.fit(datagen.flow(x_train,y_train), epochs= 5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x1a66480d940>
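The `flow` step can be pictured with a small stand-in generator. This is a hypothetical sketch, not the Keras implementation: it assumes shuffling once per epoch and a batch size of 32, and applies a per-image augmentation function (here only a random horizontal flip) as each batch is built. The names `augmented_batches` and `hflip` are illustrative, and the data is random noise with CIFAR-10-like shapes.

```python
import math
import numpy as np

def augmented_batches(x, y, augment, batch_size=32, seed=0):
    """Hypothetical stand-in for datagen.flow: one shuffled pass over the data,
    applying `augment` to every image before the batch is yielded."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(x))              # shuffle once per epoch
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]    # the last batch may be smaller
        xb = np.stack([augment(x[i], rng) for i in idx])
        yield xb, y[idx]                         # labels are unchanged by flips

def hflip(image, rng):
    # random horizontal flip with probability 0.5
    return image[:, ::-1, :] if rng.random() < 0.5 else image

# 100 fake images with CIFAR-10 shapes, scaled to [0, 1] as in the cell above
x = np.random.default_rng(1).integers(0, 256, size=(100, 32, 32, 3)) / 255.0
y = np.arange(100).reshape(-1, 1)               # (N, 1) labels, like cifar10.load_data()
batches = list(augmented_batches(x, y, hflip))
print(len(batches), batches[-1][0].shape)       # 4 (4, 32, 32, 3)
print(math.ceil(50000 / 32))                    # 1563
```

With 50,000 training images and batches of 32, one epoch covers 1,563 batches, the last one holding only 16 images.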

In [47]:
model.evaluate(x_test,y_test)



[1.8057101964950562, 0.08910000324249268]

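## The results are horrendous (a test accuracy of about 9%, no better than guessing among ten classes), but the point was to show the use of image augmentation. One-hot encoding the labels would not help on its own: `sparse_categorical_crossentropy` already accepts the integer labels directly, and one-hot labels would only require switching the loss to `categorical_crossentropy`. Better first steps are tuning the model and adding other layers.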
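The point about label encoding can be checked numerically. The sketch below computes cross-entropy both ways in plain NumPy; `sparse_ce` and `categorical_ce` are hypothetical re-derivations of the formulas, not the Keras implementations. Integer labels with the sparse form give the same loss as one-hot labels with the categorical form.

```python
import numpy as np

def sparse_ce(probs, labels):
    # labels: integer class ids, shape (N,) or (N, 1) as returned by cifar10.load_data()
    labels = np.asarray(labels).reshape(-1)
    # pick the predicted probability of the true class for each sample
    return -np.mean(np.log(probs[np.arange(len(labels)), labels]))

def categorical_ce(probs, one_hot):
    # one_hot: shape (N, num_classes); only the true-class term survives the sum
    return -np.mean(np.sum(one_hot * np.log(probs), axis=1))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])     # softmax outputs: two samples, three classes
labels = np.array([[0], [2]])           # integer labels
one_hot = np.eye(3)[labels.reshape(-1)] # the same labels, one-hot encoded
print(sparse_ce(probs, labels), categorical_ce(probs, one_hot))
```

Both calls print the same value, so the choice between the two encodings is about convenience, not accuracy.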