# Augmentation

### First, let's visualize the transformations being applied

In [2]:
import os

from skimage import io
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [3]:
# Demo augmentation pipeline.
# In real training these transforms are applied on the fly and never written to
# disk; here the results are saved only so we can look at what each one does.
datagen = ImageDataGenerator(
    rotation_range=45,        # random rotation drawn from [-45, +45] degrees
    width_shift_range=0.2,    # horizontal shift, as a fraction of the image width
    height_shift_range=0.2,   # vertical shift, as a fraction of the image height
    shear_range=0.2,          # shear intensity
    zoom_range=0.2,           # random zoom factor in [0.8, 1.2]
    horizontal_flip=True,     # randomly mirror images left-right
    fill_mode='constant',     # also try 'nearest', 'reflect', 'wrap'
    cval=125,                 # fill value for pixels exposed by the transforms (fill_mode='constant')
)

In [5]:
# Load the demo image from disk as a NumPy array.
# NOTE(review): expected shape is (256, 256, 3) (height, width, RGB) — confirm if the image changes.
img_path = './images/einstein_mona_lisa/einstein_original.jpg'
x = io.imread(img_path)

In [6]:
# Keras' .flow() expects a 4-D batch (N, H, W, C), so prepend a batch axis,
# e.g. (256, 256, 3) -> (1, 256, 256, 3).
# The ndim guard makes the cell idempotent: re-running it no longer keeps
# stacking extra leading axes onto `x` (which would break .flow()).
# NOTE(review): assumes a 3-D (H, W, C) image; a 2-D grayscale array is left untouched.
if x.ndim == 3:
    x = x.reshape((1, ) + x.shape)

In [8]:
# Save augmented versions of the single image `x` to disk, then stop.
# .flow() works on an in-memory array (one image here, so each batch holds just
# one augmented image even though batch_size=16); it also accepts larger arrays.
# The generator never ends on its own, so we break after N_AUG_BATCHES batches.
N_AUG_BATCHES = 20
AUG_DIR = '../image-processing/images/augmented'
os.makedirs(AUG_DIR, exist_ok=True)  # save_to_dir must exist before images are written

for i, batch in enumerate(datagen.flow(x, batch_size=16,
                                       save_to_dir=AUG_DIR,
                                       save_prefix='aug',
                                       save_format='png'), start=1):
    # `>=` (not `>`) so exactly N_AUG_BATCHES batches are generated and saved;
    # the previous `i > 20` pulled (and saved) one extra batch.
    if i >= N_AUG_BATCHES:
        break  # otherwise the generator would loop indefinitely

In [11]:
# Reading images from a folder tree needs .flow_from_directory() instead of .flow().
# Each sub-folder of `directory` becomes one class, which is what enables
# multi-class augmentation. The generator is endless, so stop after 32 batches.
aug_source = datagen.flow_from_directory(
    directory='../image-processing/images/einstein_mona_lisa/',
    batch_size=16,
    target_size=(256, 256),
    color_mode="rgb",
    save_to_dir='../image-processing/images/augmented',
    save_prefix='aug',
    save_format='png',
)
for i, batch in enumerate(aug_source, start=1):
    if i == 32:
        break

Found 22 images belonging to 2 classes.


### Let's see how to use augmentation for training and validation

In [None]:
# Step 1: define the transformations to apply.

# Training pipeline: multiply pixel values by 1/255, then apply random augmentation.
train_datagen = ImageDataGenerator(
    rescale=1. / 255,
    rotation_range=45,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation pipeline: only the same rescaling, so validation images are
# scored without augmentation (other operations can be tried here too).
validation_datagen = ImageDataGenerator(rescale=1. / 255)

In [None]:
# Step 2: tell each pipeline where its images live and how to load them.

# `batch_size` is also used by the training cell below. It was previously never
# defined in this notebook, so a fresh kernel raised NameError here.
# 16 matches the batch size used in the demo cells above; tune as needed.
batch_size = 16
image_size = (150, 150)  # every image is resized to 150x150 when loaded

train_generator = train_datagen.flow_from_directory(
    'cell_images',            # input directory: one sub-folder per class
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary',      # binary labels, as needed for a binary_crossentropy loss
)

# Same loading options for the validation data, using validation_datagen.
validation_generator = validation_datagen.flow_from_directory(
    'cell_validation',
    target_size=image_size,
    batch_size=batch_size,
    class_mode='binary',
)

In [None]:
# Step 3: once `model` has been defined and compiled (not shown in this
# notebook), train on the augmented training stream and validate on the
# validation stream.
# ModelCheckpoint is imported from tensorflow.keras at the top of the notebook:
# mixing standalone `keras` callbacks with a tf.keras model breaks under TF2.

# Keep a checkpoint whenever validation accuracy improves; the file name
# records the epoch and the validation accuracy.
# NOTE(review): 'val_acc' is the logged key when the model is compiled with
# metrics=['acc']; with metrics=['accuracy'] TF2 logs 'val_accuracy' instead —
# keep `monitor` and the filename placeholder in sync with compile().
os.makedirs('saved_models', exist_ok=True)
filepath = "saved_models/weights-improvement-{epoch:02d}-{val_acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
                             save_best_only=True, mode='max')
callbacks_list = [checkpoint]

# fit() accepts generators directly; fit_generator() is deprecated.
model.fit(
    train_generator,
    steps_per_epoch=2000 // batch_size,   # floor division: whole batches in 2000 training samples
    epochs=5,
    validation_data=validation_generator,
    validation_steps=800 // batch_size,   # whole batches in 800 validation samples
    callbacks=callbacks_list,
)