# Image Augmentation

<font color='steelblue'>

<font size = 5>
    <strong>Image Augmentation Example</strong><br><br>
    Examples of a number of operations that can be applied to an image<br><br>
</font>
</font>

<font color = 'grey'>
<font size = 4>
    <b>The following examples are included:</b><br>
    <ol>
        <li>Read images from a folder and set class names</li>
        <li>Apply image augmentation as these images are read</li>
        <li>Single Image processing</li>
        <ul>
            <li>Read image using matplotlib library</li>
            <li>Rotate image</li>
            <li>Width and Height Shifting</li>
            <li>Brightness, Shear and Zoom</li>
            <li>Vertical Flip</li>
            <li>Convert to Gray Scale</li>
        </ul>
    </ol>
</font>

</font>

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

#plt.style.use('seaborn-whitegrid')    # grids in the plots
import warnings
warnings.filterwarnings('ignore')

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from matplotlib.pyplot import imread, imshow, subplots, show
import matplotlib.image as mpimg

## Import Images from Folder<br>

<font size = 3>

In image processing, the data is often organized into sub-folders, typically one sub-folder per class.<br>

**The following example shows how to:**
- Explore the folders and sub-folders
- Extract the class names (image categories)
    
</font>
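As a self-contained illustration of these two steps, the sketch below walks a `root/<class>/<file>` layout, keeps only sub-directories as class names (so stray files such as `.DS_Store` are ignored), and pairs every file with the index of its class. The helper `list_images_by_class` and the toy class and file names are hypothetical; the tree is created in a temporary directory so the example runs anywhere.

```python
import tempfile
from pathlib import Path


def list_images_by_class(root):
    """Return sorted class names and (path, label_index) pairs for a root/<class>/<file> layout."""
    root = Path(root)
    # only sub-directories count as classes; top-level files are skipped
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    samples = []
    for label, name in enumerate(class_names):
        # sort the files so the order is reproducible across runs and platforms
        for f in sorted((root / name).iterdir()):
            if f.is_file():
                samples.append((str(f), label))
    return class_names, samples


# usage: build a tiny throwaway tree (hypothetical class and file names)
with tempfile.TemporaryDirectory() as tmp:
    for cls in ["dogs", "cats"]:
        (Path(tmp) / cls).mkdir()
        for i in range(2):
            (Path(tmp) / cls / f"{cls}_{i}.jpg").touch()
    (Path(tmp) / "notes.txt").touch()  # stray top-level file, not a class
    names, samples = list_images_by_class(tmp)
    print(names)         # ['cats', 'dogs']
    print(len(samples))  # 4
```

Sorting both the class folders and the files inside them gives a stable order, which makes it easy to line the file list up with the labels later on.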

In [None]:
# explore the folders & sub-folders
import os

for dirpath, dirnames, filenames in os.walk("../datasets/images/"):
    print(f'There are {len(dirnames)} directories and {len(filenames)} files in {dirpath}')

In [None]:
# define path
imagesPath = "../datasets/images/types/"

In [None]:
import pathlib
dataDir = pathlib.Path(imagesPath)

# keep only sub-directories as class names (skips stray files such as .DS_Store)
classNames = np.array(sorted([item.name for item in dataDir.glob('*') if item.is_dir()]))
print(classNames)

In [None]:
# function to plot images
def plotImages(imgFiles, classnames):
    """
    plotImages: Plot images based on the file names provided
    imgFiles: list of file names (full path), grouped by class in the order of classnames
    classnames: class names associated with the files
    NOTE: this function assumes that there are exactly 5 files per class
    """
    nrows = int(np.ceil(len(imgFiles) / 5))
    for i, filename in enumerate(imgFiles):
        cidx = i // 5                      # 5 consecutive files per class
        plt.subplot(nrows, 5, i + 1)
        data = mpimg.imread(filename)
        plt.imshow(data)
        # show the class name and the shape of the image
        plt.title(f"{classnames[cidx]}\n{data.shape}")
        plt.axis('off')

In [None]:
# function to plot the images of a single class
def plotPerClass(imgFiles, cname):
    """
    plotPerClass: Plot the images that belong to one class
    imgFiles: list of file names (full path), grouped by class in the order of classNames
    cname: class name
    NOTE: this function assumes that there are exactly 5 files per class
    """
    # locate the class; its 5 files start at position (class index * 5)
    matches = np.where(classNames == cname)[0]
    if len(matches) == 0:
        raise ValueError(f"Unknown class name: {cname}")
    startindex = matches[0] * 5
    filenames = imgFiles[startindex:startindex + 5]

    for i, afile in enumerate(filenames):
        plt.subplot(1, 5, i + 1)
        data = mpimg.imread(afile)
        plt.imshow(data)
        plt.title(f"{cname}\n{data.shape}")
        plt.axis('off')

In [None]:
# get all the file names to be plotted, grouped by class
# (sorted, so the order should match flow_from_directory with shuffle = False)
imgFiles = []
for name in classNames:
    folder = os.path.join(imagesPath, name)
    for afile in sorted(os.listdir(folder)):
        imgFiles.append(os.path.join(folder, afile))

In [None]:
plt.figure(figsize = (18, 18))
plotImages(imgFiles, classNames)

## Image Augmentation<br>

<font size = "3">

**Data augmentation (here, image augmentation) randomly alters the training data. This helps by:**
- adding diversity to the data, so that the model learns to generalize better
- showing the data in more "real world"-like scenarios (e.g. different angles, positions and zoom levels)

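**The example below uses the following parameters:**

- **rotation_range:** randomly rotate the image by up to ± rotation_range degrees
- **shear_range:** randomly shear the image; the value is the maximum shear angle in degrees
- **zoom_range:** randomly zoom in or out; a float z gives zoom factors in [1 - z, 1 + z]
- **width_shift_range:** randomly shift the image horizontally; a float below 1 is a fraction of the total width
- **height_shift_range:** randomly shift the image vertically; a float below 1 is a fraction of the total height
- **horizontal_flip:** randomly flip the image horizontally (mirror it left to right)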

<font size = "3">
    
[Documentation - Data Augmentation](https://www.tensorflow.org/tutorials/images/data_augmentation)
</font>
</font>
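Each parameter above describes a range from which a fresh random value is drawn for every image. The sketch below imitates that sampling with Python's `random` module. The function `sample_transform` and its dictionary keys are hypothetical, the defaults mirror the `imgAugmented` settings in the next cell (with the 224 x 224 target size), and the 0.5 flip probability and independent per-axis zoom are assumptions about ImageDataGenerator's behaviour rather than a copy of its source.

```python
import random


def sample_transform(rotation_range=20, width_shift_range=0.4, height_shift_range=0.3,
                     shear_range=0.3, zoom_range=0.5, horizontal_flip=True,
                     height=224, width=224, rng=random):
    """Draw one random set of transform parameters (a sketch, not the Keras source)."""
    return {
        # rotation angle in degrees, uniform in [-rotation_range, +rotation_range]
        "theta": rng.uniform(-rotation_range, rotation_range),
        # shifts: a float below 1 is a fraction of the image size, converted to pixels here
        "tx": rng.uniform(-width_shift_range, width_shift_range) * width,
        "ty": rng.uniform(-height_shift_range, height_shift_range) * height,
        # shear angle in degrees
        "shear": rng.uniform(-shear_range, shear_range),
        # a float zoom_range z means zoom factors in [1 - z, 1 + z], drawn per axis (assumed)
        "zx": rng.uniform(1 - zoom_range, 1 + zoom_range),
        "zy": rng.uniform(1 - zoom_range, 1 + zoom_range),
        # flip with probability 0.5 when enabled (assumed)
        "flip_horizontal": horizontal_flip and rng.random() < 0.5,
    }


params = sample_transform(rng=random.Random(42))
print(params)
```

Because a new set of values is drawn for every image, the same source image produces a different augmented version each time it is read.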

In [None]:
imgAugmented = ImageDataGenerator(rotation_range =  20,
                                 shear_range = 0.3,
                                 zoom_range = 0.5,
                                 width_shift_range = 0.4,
                                 height_shift_range = 0.3, 
                                 horizontal_flip = True)

### Apply the image data generator while reading the files from directory

In [None]:
images = imgAugmented.flow_from_directory(imagesPath,
                                          target_size = (224, 224),
                                          shuffle = False, 
                                          batch_size = 5)
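Why does the first label of a batch identify the class of the whole batch in the plots below? With `shuffle = False`, flow_from_directory should yield the files class by class in sorted order, so a batch size equal to the number of files per class (5 here) lines each batch up with exactly one class. The pure-Python sketch below checks that reasoning; the helper `classes_per_batch` and the class counts are hypothetical, and the sorted class-by-class order is the stated assumption.

```python
def classes_per_batch(files_per_class, batch_size):
    """Return the sorted list of classes appearing in each batch.

    files_per_class: dict mapping class name -> number of files (hypothetical counts).
    Assumes files are produced class by class in sorted class order (shuffle=False).
    """
    labels = [cls for cls in sorted(files_per_class) for _ in range(files_per_class[cls])]
    return [sorted(set(labels[i:i + batch_size])) for i in range(0, len(labels), batch_size)]


print(classes_per_batch({"cats": 5, "dogs": 5, "birds": 5}, batch_size=5))
# [['birds'], ['cats'], ['dogs']]
print(classes_per_batch({"cats": 4, "dogs": 6}, batch_size=5))
# [['cats', 'dogs'], ['dogs']]
```

The second call shows the limitation: as soon as the classes have different sizes, a batch can mix classes and the label of its first image no longer describes the rest.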

In [None]:
print(type(images))

In [None]:
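# function to plot augmented images
def plotAugImages(imgs, classindex):
    """Plot up to 5 augmented images of class classNames[classindex] in one row."""
    for i in range(min(5, len(imgs))):
        plt.subplot(1, 5, i + 1)
        # pixel values are floats in the 0-255 range; cast for display
        plt.imshow(imgs[i].astype('uint8'))
        plt.title(f"{classNames[classindex]}\n{imgs[i].shape}")
        plt.axis('off')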

### Plot the augmented images

In [None]:
plt.figure(figsize = (18, 18))
augImages, augLabels = next(images)
# shuffle = False and 5 files per class, so all 5 images in this batch share one class
idx = tf.argmax(augLabels[0]).numpy()
plotAugImages(augImages, idx)

### Compare with original images

In [None]:
plt.figure(figsize = (18, 18))
cname = classNames[idx]
plotPerClass(imgFiles, cname)

In [None]:
plt.figure(figsize = (18, 18))
augImages, augLabels = next(images)
# the next batch holds the 5 images of the next class
idx = tf.argmax(augLabels[0]).numpy()
plotAugImages(augImages, idx)

In [None]:
plt.figure(figsize = (18, 18))
cname = classNames[idx]
plotPerClass(imgFiles, cname)

## Single image processing

### Read the image

In [None]:
image = imread('../datasets/home-4.jpg')
# image = imread('../datasets/dog.jpg')

In [None]:
# Shape of our RGB image: (height, width, channels)
image.shape

In [None]:
print(type(image))

In [None]:
# Creating a dataset which contains just one image
images = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))
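ImageDataGenerator works on batches, i.e. 4-D arrays of shape `(batch, height, width, channels)`, which is why the single image is reshaped above. The short sketch below, using a hypothetical all-zero stand-in array instead of the real photo, shows that `reshape`, `np.expand_dims` and `np.newaxis` indexing all add the same leading batch axis.

```python
import numpy as np

# stand-in for an (height, width, channels) RGB image
dummy = np.zeros((4, 6, 3), dtype=np.uint8)

batch_reshape = dummy.reshape((1,) + dummy.shape)  # explicit reshape, as in the cell above
batch_expand = np.expand_dims(dummy, axis=0)       # same result, no shape arithmetic
batch_newaxis = dummy[np.newaxis, ...]             # same result via indexing
print(batch_reshape.shape, batch_expand.shape, batch_newaxis.shape)
# (1, 4, 6, 3) (1, 4, 6, 3) (1, 4, 6, 3)
```

`np.expand_dims` is often the easiest to read, since it states the intent (add an axis) without spelling out the other dimensions.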

In [None]:
print(type(images))

In [None]:
images.shape

In [None]:
imshow(images[0])
show()

In [None]:
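# Function to fit and plot images defined by the data_generator parameter
def plot(data_generator, nimages = 4):
    """
    Plot `nimages` augmented versions of the single-image dataset `images`,
    generated by an ImageDataGenerator object.
    """
    # fit() is only needed for featurewise options, but it is harmless here
    data_generator.fit(images)
    image_iterator = data_generator.flow(images, batch_size = 1)

    # Plot the images given by the iterator (squeeze=False keeps axes 2-D even for nimages=1)
    fig, axes = subplots(nrows = 1, ncols = nimages, figsize = (18, 18), squeeze = False)
    for ax in axes[0]:
        batch = next(image_iterator)
        # clip to the 8-bit range and cast so imshow treats values as 0-255
        ax.imshow(np.clip(batch[0], 0, 255).astype('uint8'))
        ax.axis('off')
    show()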

## Image Rotation<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    rotation_range - randomly rotates the image by an angle drawn from the range [-rotation_range, +rotation_range] degrees
</span>
</font>
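To make the rotation itself concrete, here is a minimal NumPy sketch of rotating an image about its centre with nearest-neighbour sampling. It uses inverse mapping: for every output pixel it computes which source pixel lands there. The function `rotate_nearest` is hypothetical; pixels that come from outside the original frame are filled with 0 (similar to `fill_mode='constant'` with `cval=0`, whereas ImageDataGenerator defaults to `fill_mode='nearest'`), and positive angles turn the array clockwise in row-down coordinates.

```python
import numpy as np


def rotate_nearest(img, degrees):
    """Rotate an (H, W) or (H, W, C) array about its centre; out-of-frame pixels become 0."""
    h, w = img.shape[:2]
    theta = np.deg2rad(degrees)
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    ys, xs = np.indices((h, w))
    cos, sin = np.cos(theta), np.sin(theta)
    # inverse mapping: source coordinates for every output pixel
    src_x = np.rint(cos * (xs - cx) + sin * (ys - cy) + cx).astype(int)
    src_y = np.rint(-sin * (xs - cx) + cos * (ys - cy) + cy).astype(int)
    valid = (src_x >= 0) & (src_x < w) & (src_y >= 0) & (src_y < h)
    out = np.zeros_like(img)
    out[valid] = img[src_y[valid], src_x[valid]]
    return out


demo = np.arange(9).reshape(3, 3)
print(rotate_nearest(demo, 90))
# [[6 3 0]
#  [7 4 1]
#  [8 5 2]]
```

For angles that are not multiples of 90 degrees, the corners of the output fall outside the source frame, which is where the fill mode becomes visible.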

In [None]:
data_generator = ImageDataGenerator(rotation_range = 45)

In [None]:
plot(data_generator)

## Width Shifting<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    width_shift_range - randomly shifts the image to the left or the right. A float between 0.0 and 1.0 is the upper bound of the shift as a fraction of the total width (0.3 allows shifts of up to 30% of the width)
</span>
</font>

In [None]:
data_generator = ImageDataGenerator(width_shift_range=0.3)

In [None]:
plot(data_generator)

## Height Shifting<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    height_shift_range - randomly shifts the image up or down. A float between 0.0 and 1.0 is the upper bound of the shift as a fraction of the total height
</span>
</font>
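Both shift parameters work the same way: a random fraction of the width or height is converted to a whole number of pixels, and the image is moved by that amount. The sketch below shows this with NumPy slicing. The names `shift_image` and `random_shift` are hypothetical, vacated pixels are filled with 0 (ImageDataGenerator's default `fill_mode='nearest'` would instead repeat the edge pixels), and the conversion by rounding is an assumption.

```python
import numpy as np


def shift_image(img, dx, dy):
    """Shift img right by dx and down by dy pixels (negative values shift left/up).
    Vacated pixels are filled with 0."""
    h, w = img.shape[:2]
    out = np.zeros_like(img)
    if abs(dx) >= w or abs(dy) >= h:
        return out  # shifted completely out of the frame
    # copy the overlapping region from its source position to its destination
    src_y, dst_y = slice(max(0, -dy), h - max(0, dy)), slice(max(0, dy), h - max(0, -dy))
    src_x, dst_x = slice(max(0, -dx), w - max(0, dx)), slice(max(0, dx), w - max(0, -dx))
    out[dst_y, dst_x] = img[src_y, src_x]
    return out


def random_shift(img, width_shift_range, height_shift_range, rng):
    """Shift by a random fraction of the width and height, converted to whole pixels."""
    h, w = img.shape[:2]
    dx = int(round(rng.uniform(-width_shift_range, width_shift_range) * w))
    dy = int(round(rng.uniform(-height_shift_range, height_shift_range) * h))
    return shift_image(img, dx, dy)


demo = np.arange(1, 10).reshape(3, 3)
print(shift_image(demo, 1, 0))
# [[0 1 2]
#  [0 4 5]
#  [0 7 8]]
shifted = random_shift(demo, 0.3, 0.4, np.random.default_rng(0))
```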

In [None]:
data_generator = ImageDataGenerator(height_shift_range=0.4)

In [None]:
plot(data_generator)

## Brightness<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    brightness_range - a tuple of two floats from which a brightness factor is picked at random. A factor of 0.0 produces a black image, 1.0 leaves the image unchanged, and values above 1.0 make it brighter. The range (0.2, 0.6) used below therefore darkens the image.
</span>
</font>
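To our understanding, ImageDataGenerator applies the brightness factor through Pillow's `ImageEnhance.Brightness`, which blends the image with a black image. The NumPy sketch below approximates that effect by multiplying the pixel values by the factor and clipping to the 8-bit range; `adjust_brightness` is a hypothetical helper, and its rounding may differ slightly from Pillow's.

```python
import numpy as np


def adjust_brightness(img, factor):
    """Scale pixel intensities by factor: 0.0 -> black, 1.0 -> unchanged, > 1.0 -> brighter."""
    out = img.astype(np.float32) * factor
    # clip to the valid 8-bit range before casting back
    return np.clip(out, 0, 255).astype(np.uint8)


rng = np.random.default_rng(0)
demo = np.full((2, 2, 3), 200, dtype=np.uint8)
factor = rng.uniform(0.2, 0.6)  # same range as the example below
darker = adjust_brightness(demo, factor)
print(f"factor {factor:.2f}: pixel value 200 -> {darker[0, 0, 0]}")
```

Since both ends of (0.2, 0.6) are below 1.0, every image generated in the next cell comes out darker than the original.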

In [None]:
data_generator = ImageDataGenerator(brightness_range=(0.2, 0.6))

In [None]:
plot(data_generator)

## Shear<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    shear_range - unlike rotation, where the whole image is turned, shearing keeps one axis fixed and slants the image along the other axis at a shear angle. shear_range specifies the maximum angle of the slant in degrees
</span>
</font>
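Shearing can be written as a linear map on pixel coordinates. The display below shows a generic horizontal shear by an angle $\varphi$ that keeps $y$ fixed; it illustrates the idea and is not necessarily the exact matrix or sign convention that ImageDataGenerator uses internally.

$$
\begin{pmatrix} x' \\ y' \end{pmatrix}
=
\begin{pmatrix} 1 & \tan\varphi \\ 0 & 1 \end{pmatrix}
\begin{pmatrix} x \\ y \end{pmatrix},
\qquad x' = x + y\tan\varphi, \quad y' = y
$$

For example, with $\varphi = 30^\circ$ ($\tan 30^\circ \approx 0.577$), the point $(0, 10)$ moves to about $(5.77, 10)$: points further from the $x$-axis slide further sideways, which produces the slant.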

In [None]:
data_generator = ImageDataGenerator(shear_range=30.0)

In [None]:
plot(data_generator)

## Zoom<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
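    zoom_range - a zoom factor less than 1.0 magnifies (zooms into) the image, whereas a factor greater than 1.0 zooms out. zoom_range sets the range from which the random zoom factor is drawn: a float z means [1 - z, 1 + z], and a pair [lower, upper] is used as given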
</span>
</font>
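The "less than 1.0 zooms in" convention follows from inverse mapping: each output pixel samples the source at `centre + factor * offset`, so a factor below 1.0 only reads pixels near the centre and spreads them over the whole output. The NumPy sketch below shows this with nearest-neighbour sampling; `zoom_nearest` is a hypothetical helper, it uses one factor for both axes, and it fills pixels sampled from outside the frame with 0.

```python
import numpy as np


def zoom_nearest(img, factor):
    """Zoom about the centre; factor < 1 zooms in, factor > 1 zooms out (outside pixels become 0)."""
    h, w = img.shape[:2]
    cy, cx = (h - 1) / 2.0, (w - 1) / 2.0
    # inverse mapping: each output row/column samples the source at centre + factor * offset
    src_y = np.rint(cy + factor * (np.arange(h) - cy)).astype(int)
    src_x = np.rint(cx + factor * (np.arange(w) - cx)).astype(int)
    valid_y = (src_y >= 0) & (src_y < h)
    valid_x = (src_x >= 0) & (src_x < w)
    out = np.zeros_like(img)
    out[np.ix_(valid_y, valid_x)] = img[np.ix_(src_y[valid_y], src_x[valid_x])]
    return out


demo = np.arange(16).reshape(4, 4)
print(zoom_nearest(demo, 0.5))  # the central 2 x 2 block, enlarged
print(zoom_nearest(demo, 2.0))  # a shrunken copy with a zero border
```

With `zoom_range=[0.5, 1.5]` as in the next cell, each generated image lands somewhere between these two extremes.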

In [None]:
data_generator = ImageDataGenerator(zoom_range=[0.5, 1.5])

In [None]:
plot(data_generator)

## Vertical Flip<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    vertical_flip - when True, images are randomly flipped upside down (top to bottom); each image has a 50% chance of being flipped, so some of the generated images may look unchanged
</span>
</font>
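A random flip is simply a coin toss followed by reversing the row order. The sketch below shows this with NumPy; `random_vertical_flip` is a hypothetical helper, and the default probability of 0.5 reflects our understanding of how ImageDataGenerator treats `vertical_flip=True`. `img[::-1]` is equivalent to `np.flipud(img)`.

```python
import numpy as np


def random_vertical_flip(img, rng, p=0.5):
    """Return img flipped upside down with probability p, otherwise unchanged."""
    if rng.random() < p:
        return img[::-1]  # reverse the row order: the top row becomes the bottom row
    return img


rng = np.random.default_rng(0)
demo = np.arange(6).reshape(3, 2)
results = [random_vertical_flip(demo, rng) for _ in range(4)]
n_flipped = sum(np.array_equal(r, demo[::-1]) for r in results)
print(f"{n_flipped} of 4 images flipped")
```

This is also why plotting only two generated images can show two identical, unflipped copies.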

In [None]:
data_generator = ImageDataGenerator(vertical_flip=True)

In [None]:
plot(data_generator, 2)

## Convert Image to Gray Scale<br>
<font color='gray'>
<span style="font-family:Arial; font-size:14pt; font-style:bold">
    The following code converts an RGB image to gray scale as a weighted sum of its red, green and blue channels
</span>
</font>

In [None]:
# weighted sum of the colour channels (luma weights close to ITU-R BT.601's 0.299, 0.587, 0.114):
# 0.2989 * Red + 0.5870 * Green + 0.1140 * Blue
def rgb2gray(rgb):
    return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140])

In [None]:
gray = rgb2gray(image)

In [None]:
imshow(gray, cmap=plt.get_cmap('gray'))
show()
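To see how the weights act on individual pixels, the worked example below applies the same weights as `rgb2gray` to a few pure colours. Green has the largest weight, so pure green maps to the brightest gray of the three primaries, and because the weights sum to 0.9999, white maps to just under 255. The pixel values are illustrative, not taken from the image above.

```python
import numpy as np

# the same channel weights as rgb2gray above
weights = np.array([0.2989, 0.5870, 0.1140])
pixels = np.array([[255, 0, 0],       # pure red
                   [0, 255, 0],       # pure green
                   [0, 0, 255],       # pure blue
                   [255, 255, 255]])  # white
gray_values = pixels @ weights
print(gray_values.round(2))
```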