[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/aldomunaretto/immune_deep_learning/blob/main/notebooks/02_CNN/09_CNN_data_augmentation.ipynb)

# Convolutional Neural Networks

## Data Augmentation


---

<a id="index"></a>
## Index

* [0. Context](#section0)
* [1. Data Augmentation](#section1)
* [2. Native MNIST Data](#section2)
* [3. Feature Standardization](#section3)
* [4. ZCA Whitening](#section4)
* [5. Random Rotations](#section5)
* [6. Random Shifts](#section6)
* [7. Random Flips](#section7)
* [8. Save Augmented Images](#section8)
* [9. Final Tips](#section9)

---
<a id="section0"></a>
## Context

In this lesson, you will learn how to prepare and augment image data with Keras:
* How to use the Keras API for image augmentation.
* How to perform feature standardization.
* How to apply ZCA whitening to your images.
* How to augment data with random rotations, shifts, and flips of images.
* How to save augmented image data to disk.

In [None]:
import tensorflow as tf
# Remove warning
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

---
<a id="section1"></a>
## Data Augmentation

Keras provides the `ImageDataGenerator` class for image data preparation and augmentation. This includes capabilities such as:
* Feature standardization.
* ZCA whitening.
* Random rotation, shifts, and flips.
* Dimension reordering.
* Saving augmented images to disk.


![Data Augmentation](https://raw.githubusercontent.com/aldomunaretto/immune_deep_learning/main/image/notebooks/data-augmentation.jpeg)

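The image generator object is created as follows:
```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator()
```

Once the `ImageDataGenerator` object is created, fit it to your data by calling the `fit()` function with the training set. This computes the dataset statistics required by feature standardization and ZCA whitening; purely random transformations such as rotations, shifts, and flips do not depend on it.
```python
datagen.fit(X_train)
```

The `flow()` function configures the batch size and returns an iterator that yields batches of augmented images and labels. Call `next()` on it to retrieve a single batch.
```python
X_batch, y_batch = next(datagen.flow(X_train, y_train, batch_size=32))
```

Finally, we can train with the data generator by passing it to the model's `fit()` function, together with the number of batches per epoch and the total number of epochs. Older Keras versions used `fit_generator()` for this; it is deprecated in TensorFlow 2 in favor of `fit()`.

```python
# `model` is assumed to be a compiled Keras model
model.fit(datagen.flow(X_train, y_train, batch_size=32),
          steps_per_epoch=len(X_train) // 32, epochs=100)
```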

You can learn more about the Keras image data generator API in the [official documentation](http://keras.io/preprocessing/image/).
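
To make the workflow above more concrete, here is a minimal NumPy sketch of what an augmenting batch generator does conceptually: shuffle the data, cut it into batches, and apply a random transformation to every image. The names `simple_flow` and `random_hflip` are hypothetical helpers written for this illustration; they are not part of the Keras API, and the random array only stands in for real images.

```python
import numpy as np

def simple_flow(x, y, batch_size, augment_fn, rng):
    """Yield shuffled, augmented (x_batch, y_batch) pairs indefinitely."""
    n = len(x)
    while True:
        order = rng.permutation(n)  # new shuffle at the start of each epoch
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            x_batch = np.stack([augment_fn(img, rng) for img in x[idx]])
            yield x_batch, y[idx]

def random_hflip(img, rng):
    """Flip an image of shape (height, width, channels) left-right half of the time."""
    return img[:, ::-1] if rng.random() < 0.5 else img

rng = np.random.default_rng(0)
x = rng.random((10, 28, 28, 1)).astype("float32")  # stand-in for 10 images
y = np.arange(10)                                   # stand-in labels

batches = simple_flow(x, y, batch_size=4, augment_fn=random_hflip, rng=rng)
x_batch, y_batch = next(batches)
print(x_batch.shape, y_batch.shape)  # (4, 28, 28, 1) (4,)
```

Because the generator never stops on its own, the training loop has to be told how many batches make up one epoch, which is the role of `steps_per_epoch`.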


Here is an example that uses the `ImageDataGenerator` class from Keras to load and preprocess images.

1. Import the necessary libraries

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt


2. Prepare the Dataset: We will use the TensorFlow flowers dataset for this example. You can replace it with your own dataset if you prefer.

In [None]:
import tensorflow_datasets as tfds

# Download and prepare the dataset
(dataset, ds_info) = tfds.load('tf_flowers', split='train', with_info=True, as_supervised=True)


3. Configure the Image Data Generator: Set up the `ImageDataGenerator` instance to augment the data.

In [None]:
datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

4. Create a function to display augmented images

In [None]:
def plot_images(images_arr):
    fig, axes = plt.subplots(1, 5, figsize=(20,20))
    axes = axes.flatten()
    for img, ax in zip(images_arr, axes):
        ax.imshow(img)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

5. Apply data augmentation to a sample of images

In [None]:
# Retrieve a sample of images from the dataset
for image, label in dataset.take(1):
    image = tf.image.resize(image, (150, 150))
    image = tf.expand_dims(image, 0)  # Add an extra dimension for the batch.

# Generate a batch of augmented images.
augmented_images = [next(datagen.flow(image, batch_size=1))[0].astype('uint8') for _ in range(5)]

# Display the augmented images.
plot_images(augmented_images)


---
<a id="section2"></a>
## Native MNIST Data

As a baseline for comparison, let's take a look at the first 9 images of the MNIST training set.

In [None]:
# Plot of images as baseline for comparison
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# create a grid of 3x3 images
for i in range(0, 9):
    plt.subplot(330 + 1 + i)
    plt.imshow(X_train[i], cmap=plt.get_cmap('gray'))
# show the plot
plt.show()

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section3"></a>
# <font color="#004D7F" size=6>3. Feature Standardization</font>

Pixel values can also be standardized across the entire dataset.

The statistics are computed over all samples (every image in the dataset) rather than per image, so that the dataset as a whole ends up with zero mean and unit standard deviation. Here the pixel values play the role that a feature column plays in a tabular dataset.

We can perform the standardization by setting the `featurewise_center` and `featurewise_std_normalization` arguments of the `ImageDataGenerator` class.
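
As a rough illustration of what `featurewise_center` and `featurewise_std_normalization` amount to, the following NumPy sketch computes one mean and one standard deviation per channel over the whole dataset and applies them to every image. The random array is only a stand-in for the scaled MNIST images, and the small constant in the denominator is an assumption added to avoid division by zero.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((100, 28, 28, 1)).astype("float32")  # stand-in for scaled images

# Statistics over samples, rows, and columns: one value per channel.
mean = x.mean(axis=(0, 1, 2), keepdims=True)
std = x.std(axis=(0, 1, 2), keepdims=True)

# Center and scale every image with the same dataset-wide statistics.
x_std = (x - mean) / (std + 1e-6)

print(mean.shape)                 # (1, 1, 1, 1)
print(x_std.mean(), x_std.std())  # close to 0 and 1
```

The same statistics must be reused for validation and test images, which is why the generator is fitted on the training data only.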

In [None]:
# Standardize images across the dataset, mean=0, stdev=1
from tensorflow.keras.datasets import mnist
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(featurewise_center=True, featurewise_std_normalization=True)

# fit parameters from data
datagen.fit(X_train)

# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break


---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section4"></a>
# <font color="#004D7F" size=6>4. ZCA Whitening</font>

A whitening transform of an image is a linear algebra operation that reduces the redundancy in the matrix of pixel values.

For images, ZCA whitening tends to give better results than PCA whitening: it keeps all the original dimensions and, unlike PCA, produces transformed images that still resemble their originals.

To perform ZCA whitening, set the `zca_whitening` argument to `True`.
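
The linear algebra behind ZCA whitening can be sketched in a few lines of NumPy. This is the textbook formulation (center each pixel, decompose the covariance matrix, rescale along its eigenvectors, rotate back), not a copy of the Keras implementation; the function name `zca_whiten` and the `eps` parameter are assumptions made for this example.

```python
import numpy as np

def zca_whiten(x, eps=1e-6):
    """Whiten a batch of images; returns an array with the same shape as x."""
    flat = x.reshape(len(x), -1).astype("float64")
    flat = flat - flat.mean(axis=0)          # center every pixel position
    sigma = flat.T @ flat / len(flat)        # pixel covariance matrix
    u, s, _ = np.linalg.svd(sigma)
    # W = U diag(1 / sqrt(s + eps)) U^T; the final U^T rotates back to pixel space
    w = (u / np.sqrt(s + eps)) @ u.T
    return (flat @ w).reshape(x.shape)

rng = np.random.default_rng(0)
x = rng.random((500, 4, 4, 1))               # stand-in for small grayscale images
x_white = zca_whiten(x)

flat_white = x_white.reshape(len(x), -1)
cov = flat_white.T @ flat_white / len(x)
print(np.allclose(cov, np.eye(16), atol=1e-3))  # True
```

After whitening, the covariance between pixels is approximately the identity matrix, which is the precise sense in which the redundancy has been removed. Note that the covariance matrix has one row and column per pixel, so the cost grows quickly with image size.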

<div class="alert alert-block alert-info">

<i class="fa fa-info-circle" aria-hidden="true"></i>
You can find more information about [ZCA](http://www.cs.toronto.edu/~kriz/learning-features-2009-TR.pdf) here.

</div>


In [None]:
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(zca_whitening=True)

# fit parameters from data
datagen.fit(X_train)

# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section5"></a>
# <font color="#004D7F" size=6>5. Random Rotations</font>

Sometimes the images in your sample data appear at different, varying rotations in the scene.

We can create random rotations of the MNIST digits of up to 90 degrees by setting the `rotation_range` argument.
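
To show what a random rotation does to the pixel grid, here is a simplified NumPy sketch. It draws an angle uniformly from `[-rotation_range, rotation_range]` and rotates the image about its center with nearest-neighbor sampling, filling uncovered pixels with zeros. Keras applies a more general affine transform with interpolation and a configurable `fill_mode`, so this is only a stand-in; `rotate_nearest` and `random_rotation` are hypothetical helper names.

```python
import numpy as np

def rotate_nearest(img, angle_deg):
    """Rotate a 2-D image about its center, filling uncovered pixels with 0."""
    h, w = img.shape
    theta = np.deg2rad(angle_deg)
    cy, cx = (h - 1) / 2, (w - 1) / 2
    rows, cols = np.indices((h, w))
    y, x = rows - cy, cols - cx
    # Inverse mapping: look up the source pixel for every output pixel.
    src_c = np.rint(np.cos(theta) * x + np.sin(theta) * y + cx).astype(int)
    src_r = np.rint(-np.sin(theta) * x + np.cos(theta) * y + cy).astype(int)
    valid = (src_r >= 0) & (src_r < h) & (src_c >= 0) & (src_c < w)
    out = np.zeros_like(img)
    out[valid] = img[src_r[valid], src_c[valid]]
    return out

def random_rotation(img, rotation_range, rng):
    """Rotate by an angle drawn uniformly from [-rotation_range, rotation_range]."""
    angle = rng.uniform(-rotation_range, rotation_range)
    return rotate_nearest(img, angle)

rng = np.random.default_rng(0)
img = rng.random((28, 28))                   # stand-in for one MNIST digit
rotated = random_rotation(img, rotation_range=90, rng=rng)
print(rotated.shape)  # (28, 28)
```

The output keeps the input size, so corners that rotate out of the frame are lost and the newly exposed corners are filled in.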

In [None]:
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(rotation_range=90)

# fit parameters from data
datagen.fit(X_train)

# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section6"></a>
# <font color="#004D7F" size=6>6. Random Shifts</font>

Keras supports random horizontal and vertical shifts through the `width_shift_range` and `height_shift_range` arguments.
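
The effect of a shift range of `0.2` can be sketched in NumPy as follows: the image is moved by up to 20% of its height and width, and the pixels that become uncovered are filled with a constant. This sketch uses whole-pixel shifts and zero filling for simplicity, which is an assumption of the example rather than the exact Keras behavior; `shift_image` and `random_shift` are hypothetical helpers.

```python
import numpy as np

def shift_image(img, dy, dx, fill=0.0):
    """Shift a 2-D image by dy rows and dx columns, padding with `fill`."""
    h, w = img.shape
    out = np.full_like(img, fill)
    src_r0, src_r1 = max(0, -dy), min(h, h - dy)
    src_c0, src_c1 = max(0, -dx), min(w, w - dx)
    dst_r0, dst_c0 = max(0, dy), max(0, dx)
    if src_r1 > src_r0 and src_c1 > src_c0:
        out[dst_r0:dst_r0 + (src_r1 - src_r0),
            dst_c0:dst_c0 + (src_c1 - src_c0)] = img[src_r0:src_r1, src_c0:src_c1]
    return out

def random_shift(img, width_shift_range, height_shift_range, rng):
    """Shift by a random whole number of pixels within the given fractions."""
    h, w = img.shape
    max_dy = int(height_shift_range * h)
    max_dx = int(width_shift_range * w)
    dy = int(rng.integers(-max_dy, max_dy + 1))
    dx = int(rng.integers(-max_dx, max_dx + 1))
    return shift_image(img, dy, dx)

rng = np.random.default_rng(0)
img = rng.random((28, 28))                   # stand-in for one MNIST digit
shifted = random_shift(img, 0.2, 0.2, rng)
print(shifted.shape)  # (28, 28)
```

For a 28x28 image, a range of `0.2` corresponds to at most 5 whole pixels in each direction in this sketch.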

In [None]:
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(width_shift_range=0.2, height_shift_range=0.2)

# fit parameters from data
datagen.fit(X_train)

# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section7"></a>
# <font color="#004D7F" size=6>7. Random Flips</font>

Keras supports random flipping along the vertical and horizontal axes through the `vertical_flip` and `horizontal_flip` arguments.
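
A random flip is the simplest of these transformations and can be sketched with plain NumPy slicing. In this sketch each enabled flip is applied independently with probability 0.5; `random_flip` is a hypothetical helper written for illustration.

```python
import numpy as np

def random_flip(img, rng, horizontal_flip=True, vertical_flip=True):
    """Randomly mirror a 2-D image left-right and/or top-bottom."""
    if horizontal_flip and rng.random() < 0.5:
        img = img[:, ::-1]   # mirror left-right
    if vertical_flip and rng.random() < 0.5:
        img = img[::-1, :]   # mirror top-bottom
    return img

rng = np.random.default_rng(0)
img = np.arange(6).reshape(2, 3)
flipped = random_flip(img, rng)
print(flipped.shape)  # (2, 3)
```

With both flags enabled, each image ends up in one of four possible orientations.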

In [None]:
# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(vertical_flip=True, horizontal_flip=True)

# fit parameters from data
datagen.fit(X_train)

# configure batch size and retrieve one batch of images
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section8"></a>
# <font color="#004D7F" size=6>8. Save Augmented Images</font>

The directory, the file name prefix, and the image file type can be specified in the `flow()` function before training.

The following example demonstrates this and writes 9 images to a `MNIST` subdirectory with the prefix `augmented` and the PNG file type.
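
Once the cell below has run, the written files can be inspected with a small helper such as the following. It assumes that the saved file names start with the prefix followed by an underscore; `list_saved_images` is a hypothetical function for this illustration, and the directory and prefix match the values used in the example.

```python
from pathlib import Path

def list_saved_images(directory, prefix, ext="png"):
    """Return the sorted file names in `directory` matching `<prefix>_*.<ext>`."""
    return sorted(p.name for p in Path(directory).glob(f"{prefix}_*.{ext}"))

# Only list files if the output directory has already been created.
if Path("MNIST").is_dir():
    saved = list_saved_images("MNIST", "augmented")
    print(len(saved), "augmented images found")
```

Looking at the saved files is a convenient way to check that the augmentation parameters produce plausible images.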

In [None]:
# Save augmented images to file
import os

# load data
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# reshape to be [samples][height][width][channels]
X_train = X_train.reshape((X_train.shape[0], 28, 28, 1))
X_test = X_test.reshape((X_test.shape[0], 28, 28, 1))

# convert from int to float
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# Normalize values in range from 0 to 1
X_train /= 255
X_test /= 255

# define data preparation
datagen = ImageDataGenerator(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest'
)

# fit parameters from data
datagen.fit(X_train)

# create the output directory if it does not already exist
os.makedirs('MNIST', exist_ok=True)

# configure batch size, save the augmented images to disk, and retrieve one batch
for X_batch, y_batch in datagen.flow(X_train, y_train, batch_size=9, save_to_dir='MNIST', save_prefix='augmented', save_format='png'):
  for i in range(0,9):
    plt.subplot(330+1+i)
    plt.imshow(X_batch[i].reshape(28,28), cmap=plt.get_cmap('gray'))
  plt.show()
  break

---
<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<a id="section9"></a>
# <font color="#004D7F" size=6>9. Final Tips</font>

Here are some tips for getting the most out of this technique:
* **Review the dataset**. Take some time to review your dataset in detail.
* **Review the augmentations**. Inspect sample images after the augmentation has been applied.
* **Evaluate a set of transformations**. Try more than one data preparation scheme.

<div style="text-align: right"> <font size=5> <a href="#index"><i class="fa fa-arrow-circle-up" aria-hidden="true" style="color:#004D7F"></i></a></font></div>

---

<div style="text-align: right"> <font size=6><i class="fa fa-coffee" aria-hidden="true" style="color:#004D7F"></i> </font></div>