## Keras `ImageDataGenerator`
[More info](https://www.tensorflow.org/api_docs/python/tf/keras/preprocessing/image/ImageDataGenerator)

We will take a look at how we can load images for our image classification problem.

1. Loading images from a folder (one sub-folder per class)
2. Loading images from a DataFrame of image paths and labels (e.g. read from a CSV file)

##### Loading images from a folder

In [1]:
import os

import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Training images are rescaled by 1/255 and randomly augmented (rotation,
# horizontal/vertical shifts, flips, zoom) whenever batches are generated.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    zoom_range=0.2,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# Each sub-folder of dataset/train is one class; labels are one-hot encoded.
train_generator = train_datagen.flow_from_directory(
    directory=os.path.join('dataset', 'train'),
    class_mode='categorical',
    color_mode='rgb',
    target_size=(256, 256),
)

In [None]:
# Validation/test images are only rescaled by 1/255 -- no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)

valid_generator = test_datagen.flow_from_directory(
    os.path.join(
        'dataset',
        'valid'
    ),
    target_size=(256, 256),
    color_mode='rgb',
    class_mode='categorical'
)

# class_mode=None yields images only (no labels), as needed for model.predict.
# shuffle=False keeps batches in the same order as test_generator.filenames so
# each prediction can be matched back to its file; the default shuffle=True
# would silently scramble that correspondence.
test_generator = test_datagen.flow_from_directory(
    os.path.join(
        'dataset',
        'test'
    ),
    target_size=(256, 256),
    color_mode='rgb',
    class_mode=None,
    shuffle=False
)

##### Loading from a DataFrame (image paths + labels, e.g. from a CSV file)

In [None]:
def create_dataframe(dir_path: str) -> pd.DataFrame:
    ''' Build a DataFrame of image paths and labels for flow_from_dataframe.

        Parameters:
        - dir_path: str
            path to the directory containing one sub-folder per class

        Returns:
        - pd.DataFrame
            column 'image_path' holds the path of each file and column
            'label' holds the class index as a string ("0", "1", ...).
            Class indices follow the sorted order of the sub-folder names
            (the same order flow_from_directory uses), so directories that
            contain the same set of class sub-folders get consistent labels.
    '''
    # os.listdir order is arbitrary and may differ between directories, which
    # would map the same class to different labels in train/valid/test; sort
    # to make the class -> label assignment deterministic. Only directories
    # are classes, so stray files (e.g. .DS_Store) are skipped.
    sub_folders = sorted(
        entry for entry in os.listdir(dir_path)
        if os.path.isdir(os.path.join(dir_path, entry))
    )

    images_path = []
    labels = []

    for i, sub_folder in enumerate(sub_folders):
        class_dir = os.path.join(dir_path, sub_folder)
        for image_name in sorted(os.listdir(class_dir)):
            image_path = os.path.join(class_dir, image_name)
            # Nested directories are not images; only regular files are kept.
            if not os.path.isfile(image_path):
                continue
            labels.append(str(i))
            images_path.append(image_path)

    return pd.DataFrame(
        {
            'image_path': images_path,
            'label': labels
        }
    )

In [None]:
# One (image_path, label) row per training image.
train_df = create_dataframe(dir_path=os.path.join('dataset', 'train'))

train_df.head()

In [None]:
# One (image_path, label) row per validation image.
valid_df = create_dataframe(dir_path=os.path.join('dataset', 'valid'))

valid_df.head()

In [None]:
# One (image_path, label) row per test image.
test_df = create_dataframe(dir_path=os.path.join('dataset', 'test'))

test_df.head()

In [None]:
# Rescale by 1/255 and apply random rotation, shifts, flips and zoom to
# training images whenever batches are generated.
train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1. / 255,
    zoom_range=0.2,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)

# File paths come from train_df['image_path'] and string class labels from
# train_df['label']; labels are one-hot encoded.
train_generator = train_datagen.flow_from_dataframe(
    train_df,
    x_col='image_path',
    y_col='label',
    class_mode='categorical',
    color_mode='rgb',
    target_size=(256, 256),
)

In [None]:
# Validation/test images are only rescaled by 1/255 -- no augmentation.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255
)

valid_generator = test_datagen.flow_from_dataframe(
    dataframe=valid_df,
    x_col='image_path',
    y_col='label',
    target_size=(256, 256),
    color_mode='rgb',
    class_mode='categorical'
)

# class_mode=None yields images only (no labels), as needed for model.predict.
# shuffle=False keeps batches in the same order as test_generator.filenames so
# each prediction can be matched back to its file; the default shuffle=True
# would silently scramble that correspondence.
test_generator = test_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='image_path',
    target_size=(256, 256),
    color_mode='rgb',
    class_mode=None,
    shuffle=False
)

##### Let's see the image generators

##### Visualize images

In [None]:
# Load one batch once: every `valid_generator[0]` call reads and preprocesses
# a whole batch from disk, so indexing it inside the loop repeats that work.
batch_images, batch_labels = valid_generator[0]
# Map the one-hot argmax index back to the class name the generator assigned
# (class_indices is name -> index).
index_to_class = {
    index: name for name, index in valid_generator.class_indices.items()
}

plt.figure(figsize=(20, 9))

# A small dataset may produce a batch with fewer than 6 images.
for i in range(min(6, len(batch_images))):
    plt.subplot(2, 3, i+1)
    plt.imshow(batch_images[i])
    plt.title(label=f'Label - {index_to_class[int(np.argmax(batch_labels[i]))]}')
    plt.grid(visible=False)
plt.tight_layout()
plt.show()

In [None]:
# Load one batch once instead of re-reading it from disk for every image.
batch_images, batch_labels = valid_generator[0]
# Map the one-hot argmax index back to its class name (class_indices is
# name -> index).
index_to_class = {
    index: name for name, index in valid_generator.class_indices.items()
}

fig, axes = plt.subplots(2, 3, figsize=(20, 9))
axes = axes.flatten()

for i, ax in enumerate(axes):
    # The batch may hold fewer images than there are axes; blank the rest.
    if i >= len(batch_images):
        ax.axis('off')
        continue
    ax.imshow(batch_images[i])
    ax.set_title(label=f'Label - {index_to_class[int(np.argmax(batch_labels[i]))]}')
    ax.grid(visible=False)

plt.tight_layout()
plt.show()