# Data Augmentation Using Generative Adversarial Networks (GANs)

This notebook performs data augmentation using Generative Adversarial Networks (GANs). Run it only after running the notebook `prepare_dataset_and_create_project_structure.ipynb`, which should be available in the same directory.

---
## Prerequisites

1. Ensure that the Cityscapes dataset is downloaded and placed in a directory named **dataset**. More information about the dataset can be found in the notebook `prepare_dataset_and_create_project_structure.ipynb`, which should be available in the same directory.
2. Run all the cells in the notebook `prepare_dataset_and_create_project_structure.ipynb` to explore the dataset, process it, and create the project structure that this notebook expects.

### 1. Check Python version

It is crucial to ensure that the notebook runs on the correct version of Python to guarantee proper functionality. 

In [None]:
import sys

# Compare integers rather than strings: as strings, ('3', '10') >= ('3', '7') is False.
assert sys.version_info >= (3, 7), "[ERROR] The notebooks are tested on Python 3.7 and higher. Please update your Python to evaluate the code"

### 2. Check Notebook server has access to all required resources

In [None]:
from pathlib import Path

dataset_folder = Path("dataset")
dataset_folder = Path.joinpath(Path.cwd(), dataset_folder)

if not dataset_folder.exists():
    raise FileNotFoundError("[ERROR] Add `{}` folder in the current directory (`{}`)".format(dataset_folder.name, Path.cwd()))

In [None]:
dataset_preparation_notebook = Path("prepare_dataset_and_create_project_structure.ipynb")
dataset_preparation_notebook = Path.joinpath(Path.cwd(), dataset_preparation_notebook)

if not dataset_preparation_notebook.exists():
    raise FileNotFoundError("[ERROR] The notebook `{}` is unavailable in the current directory (`{}`). Please download and run the notebook `{}` before this notebook to ensure proper results.".format(dataset_preparation_notebook.name, Path.cwd(), dataset_preparation_notebook.name))

In [None]:
test_dataset = Path.joinpath(dataset_folder, "test_dataset")
test_dataset_A = Path.joinpath(test_dataset, "A")
test_dataset_B = Path.joinpath(test_dataset, "B")

test_dataset_overall = [test_dataset, test_dataset_A, test_dataset_B]

for dataset in test_dataset_overall:
    if not dataset.exists():
        raise FileNotFoundError("[ERROR] The folder `{}` is unavailable. Please run the notebook `prepare_dataset_and_create_project_structure.ipynb` available in the current directory (`{}`) before running this notebook.".format(dataset.name, Path.cwd()))

In [None]:
training_dataset = Path.joinpath(dataset_folder, "training_dataset")
training_dataset_A = Path.joinpath(training_dataset, "A")
training_dataset_B = Path.joinpath(training_dataset, "B")

training_dataset_overall = [training_dataset, training_dataset_A, training_dataset_B]

for dataset in training_dataset_overall:
    if not dataset.exists():
        raise FileNotFoundError("[ERROR] The folder `{}` is unavailable. Please run the notebook `prepare_dataset_and_create_project_structure.ipynb` available in the current directory (`{}`) before running this notebook.".format(dataset.name, Path.cwd()))

In [None]:
validatation_dataset = Path.joinpath(dataset_folder, "validatation_dataset")
validatation_dataset_A = Path.joinpath(validatation_dataset, "A")
validatation_dataset_B = Path.joinpath(validatation_dataset, "B")

validatation_dataset_overall = [validatation_dataset, validatation_dataset_A, validatation_dataset_B]

for dataset in validatation_dataset_overall:
    if not dataset.exists():
        raise FileNotFoundError("[ERROR] The folder `{}` is unavailable. Please run the notebook `prepare_dataset_and_create_project_structure.ipynb` available in the current directory (`{}`) before running this notebook.".format(dataset.name, Path.cwd()))
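
The three folder checks above follow the same pattern, so they could be folded into a single helper. The sketch below is one possible way to do that, not part of the project: `find_missing_folders` is a hypothetical name, and the demonstration runs against a temporary directory rather than the real `dataset` folder.

```python
from pathlib import Path
import tempfile

def find_missing_folders(root, splits, subfolders=("A", "B")):
    """Return the expected folders under `root` that do not exist.

    Hypothetical helper, used here for illustration only.
    """
    missing = []
    for split in splits:
        split_dir = Path(root) / split
        # Check the split folder itself, then each of its sub-folders.
        for folder in [split_dir] + [split_dir / sub for sub in subfolders]:
            if not folder.is_dir():
                missing.append(folder)
    return missing

# Demonstration on a throwaway directory so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "training_dataset" / "A").mkdir(parents=True)
    missing = find_missing_folders(tmp, ["training_dataset"])
    missing_names = [p.relative_to(tmp).as_posix() for p in missing]
    print(missing_names)
```

In the notebook itself, the same call could be made with `dataset_folder` as the root and the three split names as `splits`, raising `FileNotFoundError` if the returned list is not empty.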

---
## Introduction

One of the biggest bottlenecks in creating generalized deep learning models is a scarcity of high-quality data. The collection of high-quality data and its conversion is expensive. Most of the data collection methods are labor-intensive and error-prone, requiring considerable editing afterward to clean the data. Since large amounts of data are needed to achieve generalized deep learning models, standard data augmentation methods are routinely used to increase the dataset's generalizability. Data augmentation methods are also used when the datasets are imbalanced, improving the model's overall performance.

Generative Adversarial Networks, popularly known as GANs, offer a newer approach to data augmentation. Generating artificial training data is useful not only for imbalanced datasets but also when the original dataset contains sensitive information. In such cases, it is desirable to avoid using the original data as much as possible (for example, medical data).

This project proposes a GAN architecture based on this [paper](https://arxiv.org/abs/1611.07004) to perform data augmentation using the popular image-to-image translation method. GANs trained this way learn both the mapping from an input image to an output image and a loss function for training that mapping. The same generic approach can therefore be applied to problems that would traditionally require very different loss formulations. In this project, we demonstrate that the approach can effectively synthesize photos from label maps. To evaluate the proposed GAN architecture, we use a standard dataset named Cityscapes, which focuses on semantic understanding of urban street scenes. The dataset contains 5000 images with detailed annotations and 20000 images with coarse annotations, in addition to the original images. Some sample images from the dataset are presented below:

![Sample Image from Cityscape Dataset](https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/stuttgart02-2040x500.png)
![Sample Image from Cityscape Dataset](https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/stuttgart00-2040x500.png)
![Sample Image from Cityscape Dataset](https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/stuttgart04-2040x500.png)
![Sample Image from Cityscape Dataset](https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/stuttgart01-2040x500.png)

<center>Image Courtesy: Cityscapes Dataset (Link: https://www.cityscapes-dataset.com/)</center>

---
## Background Theory

With the advancements in deep learning, the most striking successes have involved discriminative models, usually those that map a high-dimensional, rich sensory input to a class label. These striking successes have primarily been based on the backpropagation and dropout algorithms, using piecewise linear units. However, deep generative models have had less impact due to the challenges of approximating many probabilistic computations that occur due to the usage of piecewise linear units in the generative context. 

Generative Adversarial Networks, popularly known as GANs, are a class of machine learning frameworks that sidestep these difficulties by pitting the generative model against an adversary. In other words, a GAN is a framework in which two neural networks, a generator and a discriminator, compete against each other in a zero-sum game (i.e., one network's gain is the other network's loss). The generator learns to create images that look real, while the discriminator learns to tell real images apart from fakes. Competition in this game drives both networks to improve until the counterfeits are indistinguishable from the genuine images. An overview of a Generative Adversarial Network is shown below:

![Overview of Generative Adversarial Network](https://developers.google.com/machine-learning/gan/images/gan_diagram.svg)

<center>Image Courtesy: Google Developers (Link: https://developers.google.com/machine-learning/gan/gan_structure)</center>
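
The zero-sum game described in this section is commonly written as a minimax objective. In the standard unconditional formulation, with $G$ the generator, $D$ the discriminator, $x$ a real image and $z$ a noise vector, it reads:

$$
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big] + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
$$

The discriminator tries to maximize $V$ by assigning high scores to real images and low scores to generated ones, while the generator tries to minimize it by making $D(G(z))$ as large as possible.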

To be added

---
## Proposed Solution

To be added

### 0. Imports

In [None]:
import os
import glob
import datetime
import numpy as np 
import scipy as sp
from imageio import imread
from skimage import transform
import matplotlib.pyplot as plt
from keras.optimizers import Adam
from keras.models import Sequential, Model
# LeakyReLU, UpSampling2D and Conv2D can be imported from keras.layers directly;
# the keras.layers.advanced_activations and keras.layers.convolutional paths
# are not available in newer Keras releases.
from keras.layers import LeakyReLU, UpSampling2D, Conv2D
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, Concatenate

In [None]:
# Function to load data

def load_data(mode):
    """Return the sorted paths of all .jpg files for the requested split."""
    dataset_root = training_dataset if mode == "training" else validatation_dataset

    # Build the pattern with pathlib so it works on any OS, not only on Windows.
    dataset_expr = str(dataset_root / "**" / "*.jpg")

    dataset_paths = glob.glob(dataset_expr, recursive=True)
    dataset_paths = sorted(dataset_paths)

    return dataset_paths
    
print("******************")
test = load_data("training")
print(test)

print("\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n")

print("******************")
test = load_data("validation")
print(test)

In [None]:
def load_data(dataset_name,batch_size=1, is_val=False):
        data_type = "train" if not is_val else "val"
        # `glob` is imported as a module, so the function is glob.glob
        path = glob.glob('../input/%s/%s/%s/*' % (dataset_name,dataset_name, data_type))

        batch_images = np.random.choice(path, size=batch_size)
        img_res=(128,128)
        imgs_A = []
        imgs_B = []
        for img_path in batch_images:
            img = imread(img_path)

            h, w, _ = img.shape
            _w = int(w/2)
            # because in the edges2shoes and maps dataset the input image comes before the ground truth.
            if (dataset_name=="edges2shoes" or dataset_name=="maps"):
                img_A, img_B = img[:, _w:, :],img[:, :_w, :] 
            else:  
                img_A, img_B = img[:, :_w, :], img[:, _w:, :]
            # decreasing the resolution 
            img_A = transform.resize(img_A, img_res)  #Ground Truth image
            img_B = transform.resize(img_B, img_res)  #Input image

            # If training => do random flip , this is a trick to avoid overfitting 
            if not is_val and np.random.random() < 0.5:
                img_A = np.fliplr(img_A)
                img_B = np.fliplr(img_B)

            imgs_A.append(img_A)
            imgs_B.append(img_B)
            
        
        imgs_A = np.array(imgs_A)/127.5 - 1.  #normalizing the images
        imgs_B = np.array(imgs_B)/127.5 - 1.

        return imgs_A, imgs_B
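
The split-and-normalize step in the cell above can be checked in isolation on a synthetic array. The following is a minimal sketch under stated assumptions: `split_and_normalize` is a hypothetical helper, pixel values are assumed to be floats in the 0-255 range, and the resize and random-flip steps are left out.

```python
import numpy as np

def split_and_normalize(paired, input_first=False):
    """Split a side-by-side image pair and scale pixel values to [-1, 1].

    `paired` is an (H, W, 3) array with values in the 0-255 range.
    Hypothetical helper for illustration; it skips the resize step.
    """
    half_w = paired.shape[1] // 2
    left, right = paired[:, :half_w, :], paired[:, half_w:, :]
    # By default the left half is the ground truth (A) and the right half
    # is the input (B); `input_first=True` swaps them, as the cell above
    # does for the edges2shoes and maps datasets.
    img_A, img_B = (right, left) if input_first else (left, right)
    return img_A / 127.5 - 1.0, img_B / 127.5 - 1.0

# Synthetic 4x8 "pair": left half black (0), right half white (255).
paired = np.zeros((4, 8, 3), dtype=np.float64)
paired[:, 4:, :] = 255.0

img_A, img_B = split_and_normalize(paired)
print(img_A.shape, img_A.min(), img_B.max())
```

Dividing by 127.5 and subtracting 1 maps 0 to -1 and 255 to 1, which matches the `tanh` output range of the generator defined later in the notebook.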

In [None]:
def load_batch(dataset_name, batch_size=1, is_val=False):
        data_type = "train" if not is_val else "val"
        # `glob` is imported as a module, so the function is glob.glob
        path = glob.glob('../input/%s/%s/%s/*' % (dataset_name, dataset_name, data_type))

        # Number of full batches in the dataset (not the batch size itself)
        n_batches = len(path) // batch_size
        img_res = (128, 128)
        for i in range(n_batches):
            batch = path[i*batch_size:(i+1)*batch_size]
            imgs_A, imgs_B = [], []
            for img in batch:
                img = imread(img)
                h, w, _ = img.shape
                half_w = int(w/2)
                # because in the edges2shoes and maps dataset the input image comes before the ground truth.
                if (dataset_name == "edges2shoes" or dataset_name == "maps"):
                    img_A, img_B = img[:, half_w:, :], img[:, :half_w, :]
                else:
                    img_A, img_B = img[:, :half_w, :], img[:, half_w:, :]
                img_A = transform.resize(img_A, img_res)  # Ground truth image
                img_B = transform.resize(img_B, img_res)  # Input image

                # When training => do random flip, this is a trick to avoid overfitting
                if not is_val and np.random.random() > 0.5:
                    img_A = np.fliplr(img_A)
                    img_B = np.fliplr(img_B)

                imgs_A.append(img_A)
                imgs_B.append(img_B)
            # Normalizing the images from the 0-255 range to [-1, 1]
            imgs_A = np.array(imgs_A)/127.5 - 1.
            imgs_B = np.array(imgs_B)/127.5 - 1.

            yield imgs_A, imgs_B


def imread(path):
        # scipy.misc.imread and np.float have been removed from SciPy and NumPy,
        # so read with imageio and keep the 0-255 float range that the
        # normalization above expects.
        import imageio
        return imageio.imread(path, pilmode='RGB').astype(np.float64)

In [None]:
def build_generator():
        """U-Net Generator"""

        def conv2d(layer_input, filters, f_size=4, bn=True):
            """Layers used during downsampling"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        def deconv2d(layer_input, skip_input, filters, f_size=4, dropout_rate=0):
            """Layers used during upsampling"""
            u = UpSampling2D(size=2)(layer_input)
            u = Conv2D(filters, kernel_size=f_size, strides=1, padding='same', activation='relu')(u)
            if dropout_rate:
                u = Dropout(dropout_rate)(u)
            u = BatchNormalization(momentum=0.8)(u)
            u = Concatenate()([u, skip_input]) #skip connection
            return u

        
        d0 = Input(shape=img_shape)

        # Downsampling
        d1 = conv2d(d0, gf, bn=False)
        d2 = conv2d(d1, gf*2)
        d3 = conv2d(d2, gf*4)
        d4 = conv2d(d3, gf*8)
        d5 = conv2d(d4, gf*8)
        d6 = conv2d(d5, gf*8)
        d7 = conv2d(d6, gf*8)

        # Upsampling
        u1 = deconv2d(d7, d6, gf*8)
        u2 = deconv2d(u1, d5, gf*8)
        u3 = deconv2d(u2, d4, gf*8)
        u4 = deconv2d(u3, d3, gf*4)
        u5 = deconv2d(u4, d2, gf*2)
        u6 = deconv2d(u5, d1, gf)

        u7 = UpSampling2D(size=2)(u6)
        output_img = Conv2D(channels, kernel_size=4, strides=1, padding='same', activation='tanh')(u7)

        return Model(d0, output_img)
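
To see why seven downsampling steps fit a 128x128 input, it can help to trace the feature-map sizes through the generator by hand. The sketch below does this in plain Python, with no Keras dependency. `unet_shape_trace` is a hypothetical helper that assumes stride-2 `same` convolutions on the way down and 2x upsampling followed by a skip concatenation on the way up, as in the cell above.

```python
def unet_shape_trace(size=128, base_filters=64, n_down=7):
    """Trace (spatial size, channels) through a U-Net like the one above.

    Illustrative only: it mirrors the filter schedule gf, 2gf, 4gf, 8gf, ...
    and the skip connections, but builds no layers.
    """
    # Encoder: each step halves the size; filters double up to 8 * base_filters.
    down = []
    for i in range(n_down):
        size //= 2
        down.append((size, base_filters * min(2 ** i, 8)))

    # Decoder: each step doubles the size, applies as many filters as the
    # matching encoder step, then concatenates that encoder output.
    up = []
    for skip_size, skip_filters in reversed(down[:-1]):
        size *= 2
        assert size == skip_size  # skip connection shapes must agree
        up.append((size, skip_filters + skip_filters))
    return down, up

down, up = unet_shape_trace()
print("encoder:", down)
print("decoder:", up)
```

With the defaults, the encoder bottoms out at a 1x1 feature map and the decoder ends at 64x64, which is why the generator applies one more `UpSampling2D` before the final `tanh` convolution to get back to 128x128.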

In [None]:
def build_discriminator():
        # a small function to make one layer of the discriminator
        def d_layer(layer_input, filters, f_size=4, bn=True):
            """Discriminator layer"""
            d = Conv2D(filters, kernel_size=f_size, strides=2, padding='same')(layer_input)
            d = LeakyReLU(alpha=0.2)(d)
            if bn:
                d = BatchNormalization(momentum=0.8)(d)
            return d

        img_A = Input(shape=img_shape)
        img_B = Input(shape=img_shape)

        # Concatenate image and conditioning image by channels to produce input
        combined_imgs = Concatenate(axis=-1)([img_A, img_B])

        d1 = d_layer(combined_imgs, df, bn=False)
        d2 = d_layer(d1, df*2)
        d3 = d_layer(d2, df*4)
        d4 = d_layer(d3, df*8)

        validity = Conv2D(1, kernel_size=4, strides=1, padding='same')(d4)

        return Model([img_A, img_B], validity)
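
The discriminator above outputs a grid of per-patch scores, not a single number. The sketch below shows how the grid size follows from the four stride-2 layers and how matching target arrays can be built. It is a sketch only: `patch_output_size` is a hypothetical helper, and the batch size of 2 is arbitrary.

```python
import math
import numpy as np

def patch_output_size(img_size, n_strided_layers=4, stride=2):
    """Spatial size after `n_strided_layers` stride-2 'same' convolutions."""
    size = img_size
    for _ in range(n_strided_layers):
        size = math.ceil(size / stride)  # 'same' padding rounds up
    return size

# The final stride-1 'same' convolution keeps this size unchanged.
patch = patch_output_size(128)
disc_patch = (patch, patch, 1)

# Per-patch targets for a batch of two image pairs:
# ones mark "real" patches, zeros mark "generated" patches.
batch_size = 2
valid = np.ones((batch_size,) + disc_patch)
fake = np.zeros((batch_size,) + disc_patch)
print(disc_patch, valid.shape)
```

For a 128x128 input this gives an 8x8 grid, the same value the notebook computes as `int(img_rows / 2**4)`.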

In [None]:
# Input shape
img_rows = 128
img_cols = 128
channels = 3
img_shape = (img_rows, img_cols, channels)


# Calculate output shape of D (PatchGAN)
patch = int(img_rows / 2**4)
disc_patch = (patch, patch, 1)

# Number of filters in the first layer of G and D
gf = 64
df = 64

optimizer = Adam(0.0002, 0.5)

# Build and compile the discriminator
discriminator = build_discriminator()
discriminator.compile(loss='mse',
            optimizer=optimizer,
            metrics=['accuracy'])

# Build the generator
generator = build_generator()

# Input images and their conditioning images
img_A = Input(shape=img_shape)
img_B = Input(shape=img_shape)

# By conditioning on B generate a fake version of A
fake_A = generator(img_B)

# For the combined model we will only train the generator
discriminator.trainable = False

# The discriminator determines the validity of translated image / condition pairs
valid = discriminator([fake_A, img_B])

combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
combined.compile(loss=['mse', 'mae'],
                              loss_weights=[1, 100],
                              optimizer=optimizer)
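
Reading the `compile` call above, the generator is trained on a weighted sum of a least-squares adversarial term (`mse` against an all-ones target) and an L1 reconstruction term (`mae` against the ground truth). Writing $A$ for the ground-truth image and $B$ for the conditioning image, with the first mean taken over discriminator patches and the second over pixels, one way to write this loss is:

$$
\mathcal{L}_G = 1 \cdot \operatorname{mean}\Big[\big(D(G(B), B) - 1\big)^2\Big] + 100 \cdot \operatorname{mean}\Big[\,\big|A - G(B)\big|\,\Big]
$$

The weights 1 and 100 come from `loss_weights=[1, 100]`. The discriminator is frozen in this combined model, so only the generator's weights are updated by this loss.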


In [None]:
def show_images( dataset_name,epoch, batch_i):
        
        r, c = 3, 3

        imgs_A, imgs_B = load_data(dataset_name,batch_size=3, is_val=True)
        fake_A = generator.predict(imgs_B)

        gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])

        # Rescale images 0 - 1
        gen_imgs = 0.5 * gen_imgs + 0.5

        titles = ['Input', 'Output', 'Ground Truth']
        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt])
                axs[i, j].set_title(titles[i])
                axs[i,j].axis('off')
                cnt += 1
        plt.show()
        plt.close()

In [None]:
def train( dataset_name,epochs, batch_size=1, show_interval=10):

        start_time = datetime.datetime.now()

        # Adversarial loss ground truths
        valid = np.ones((batch_size,) + disc_patch)
        fake = np.zeros((batch_size,) + disc_patch)

        for epoch in range(epochs):
            for batch_i, (imgs_A, imgs_B) in enumerate(load_batch(dataset_name,batch_size)):

                
                #  Train Discriminator
                

                # Condition on B and generate a translated version
                fake_A = generator.predict(imgs_B)

                # Train the discriminators (original images = real / generated = Fake)
                d_loss_real = discriminator.train_on_batch([imgs_A, imgs_B], valid)
                d_loss_fake = discriminator.train_on_batch([fake_A, imgs_B], fake)
                d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

               
                #  Train Generator
                g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])

                elapsed_time = datetime.datetime.now() - start_time
                
            # Print the progress every 10 epochs
            if epoch % 10 == 0:
                print("[Epoch %d/%d]  [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (
                    epoch, epochs, d_loss[0], 100*d_loss[1], g_loss[0], elapsed_time))
            # If at show interval => show generated image samples
            if epoch % show_interval == 0:
                    show_images(dataset_name,epoch, batch_i)
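
The progress line in the training loop indexes `d_loss[0]` and `d_loss[1]`, which assumes that `train_on_batch` returns a `[loss, accuracy]` pair for a model compiled with `metrics=['accuracy']`. The sketch below reproduces only the averaging step with made-up stand-in numbers, so it runs without Keras.

```python
import numpy as np

# Stand-in values in the [loss, accuracy] layout the training loop assumes;
# the numbers are made up for illustration.
d_loss_real = [0.30, 0.80]
d_loss_fake = [0.50, 0.60]

# Element-wise mean of the two discriminator updates, as in the loop above.
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
print("[D loss: %f, acc: %.1f%%]" % (d_loss[0], 100 * d_loss[1]))
```

Averaging the two updates gives one loss and one accuracy figure per batch, covering both the real and the generated pairs.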