# FlowerGan



This notebook takes flower images from this dataset:
http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html

and feeds them into a GAN.


*Note: portions of this notebook are based on a [notebook from Jeff Heaton](https://github.com/jeffheaton/t81_558_deep_learning/blob/master/t81_558_class_07_2_Keras_gan.ipynb), part of his [course](https://github.com/jeffheaton/t81_558_deep_learning) on Deep Learning.*

## Google Drive

This code should be run on a GPU; it will be very slow on a CPU alone. The following code mounts your Google Drive for use with Google CoLab. If you are not using CoLab, the following code will not work.

In [None]:

try:
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    COLAB = True
    print("Note: using Google CoLab")
    %tensorflow_version 2.x
except ImportError:
    print("Note: not using Google CoLab")
    COLAB = False
    
%cd drive/My Drive/research/deep_learning/GDL_code

## imports

The following packages will be used to implement a basic GAN system in Python/Keras.

In [None]:
import numpy as np
import os
import time
import matplotlib.pyplot as plt
from models.GAN import GAN
from utils.loaders import load_flowers, hms_string

## configuration

These constants define how the GAN will be created for this example. Higher resolutions need more memory and lead to longer run times. For Google CoLab (with GPU), 128x128 is the highest resolution that can be used, due to memory limits. The resolution was originally specified as a multiple of 32 through **GENERATE_RES** (1 is 32, 2 is 64, etc.); in the cell below that line is commented out and **GENERATE_SQUARE** is set directly to 128.

To run this you will need training data.  The training data can be any collection of images.  I suggest using training data from the following two locations.  Simply unzip and combine to a common directory.  This directory should be uploaded to Google Drive (if you are using CoLab). The constant **DATA_PATH** defines where these images are stored.
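
To make the resolution and memory trade-off concrete, here is a minimal sketch of how the side length follows from a resolution factor and how large the training array becomes. The helper `dataset_megabytes` and the figure of 10000 images are illustrative assumptions, not part of this notebook's code, and the estimate covers only the training array, not the model's own memory use.

```python
import numpy as np

GENERATE_RES = 4                     # resolution factor (example value)
GENERATE_SQUARE = 32 * GENERATE_RES  # side length in pixels
IMAGE_CHANNELS = 3

def dataset_megabytes(n_images, side, channels, dtype=np.float32):
    """Approximate size of an array shaped (n_images, side, side, channels)."""
    n_bytes = n_images * side * side * channels * np.dtype(dtype).itemsize
    return n_bytes / 1024 ** 2

# Hypothetical dataset size, used only to show how memory grows with resolution.
N_IMAGES = 10000
for res in (1, 2, 3, 4):
    side = 32 * res
    mb = dataset_megabytes(N_IMAGES, side, IMAGE_CHANNELS)
    print(f"GENERATE_RES={res}: {side}x{side}px, about {mb:.0f} MB")
```

Memory grows with the square of the side length, so doubling the resolution roughly quadruples the size of the training array.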

In [None]:
# Generation resolution - must be square.
# Training data is also scaled to this.
# Note: GENERATE_RES of 4 or higher may exhaust Google CoLab's memory and has not
# been tested extensively.
#GENERATE_RES = 3 # Generation resolution factor (1=32, 2=64, 3=96, 4=128, etc.)
GENERATE_SQUARE = 128 #32 * GENERATE_RES # rows/cols (should be square)
BLOCK_SQUARE = 450
IMAGE_CHANNELS = 3

# Configuration
DATA_PATH = '/content/drive/My Drive/research/deep_learning/GDL_code/data/flower_gen'
EPOCHS = 200
BATCH_SIZE = 64
#VIRTUAL_BATCH_SIZE = 16 # using virtual batch normalization
#BUFFER_SIZE = 60000
PRINT_EVERY_N_BATCHES = 5

print(f"Will generate {GENERATE_SQUARE}px square images.")

# run params
SECTION = 'gan'
RUN_ID = '0003'
DATA_NAME = 'flower_gen'
RUN_FOLDER = 'run/{}/'.format(SECTION)
RUN_FOLDER += '_'.join([RUN_ID, DATA_NAME])

# Create the run folder and its subfolders; makedirs also creates missing
# parent directories and does nothing if a folder already exists.
for subfolder in ('viz', 'images', 'weights'):
    os.makedirs(os.path.join(RUN_FOLDER, subfolder), exist_ok=True)

mode = 'build'  # or 'load'

## load & preprocess images

Next we load and preprocess the images. This can take a while, so the processed data is stored as a binary file that can be reloaded quickly on later runs. It is most efficient to perform this operation only once. The image dimensions are encoded into the filename of the binary file because it must be regenerated if they change.

In [None]:
# The image set may have over 10000 images; initial preprocessing can take over an hour.
# Because of the time needed, a preprocessed NumPy file is saved.
# Note that the file is large enough to cause problems for some versions of Pickle,
# so NumPy binary files are used.
training_data = load_flowers(
    DATA_PATH, GENERATE_SQUARE, GENERATE_SQUARE, BLOCK_SQUARE, BLOCK_SQUARE, IMAGE_CHANNELS)
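
The caching idea described above can be sketched as follows. This is not the implementation of `load_flowers`, which lives in `utils.loaders` and is not shown here; `load_or_build` and `fake_preprocess` are hypothetical names, and random data stands in for the slow image loading step.

```python
import os
import tempfile
import numpy as np

def load_or_build(cache_dir, width, height, channels, build_fn):
    """Return (data, cache_hit). Builds and saves the data on a cache miss."""
    # The dimensions are part of the filename, so changing them forces a rebuild.
    name = f"training_data_{width}_{height}_{channels}.npy"
    cache_path = os.path.join(cache_dir, name)
    if os.path.isfile(cache_path):
        return np.load(cache_path), True
    data = build_fn(width, height, channels)
    np.save(cache_path, data)
    return data, False

def fake_preprocess(width, height, channels):
    # Stand-in for the slow load-and-resize step (hypothetical).
    rng = np.random.default_rng(0)
    shape = (10, height, width, channels)
    return rng.uniform(-1, 1, size=shape).astype(np.float32)

with tempfile.TemporaryDirectory() as tmp:
    first, hit_first = load_or_build(tmp, 128, 128, 3, fake_preprocess)
    second, hit_second = load_or_build(tmp, 128, 128, 3, fake_preprocess)

print(hit_first, hit_second, second.shape)
```

The first call builds and saves the array; the second finds the file and skips the build.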

## architecture

The code below creates the generator and discriminator.

In [None]:
gan = GAN(input_dim = (GENERATE_SQUARE,GENERATE_SQUARE,IMAGE_CHANNELS)
        , discriminator_conv_filters = [64,64,128,128,256]
        , discriminator_conv_kernel_size = [5,5,5,5,5]
        , discriminator_conv_strides = [1,2,2,2,2]
        , discriminator_batch_norm_momentum = 0.8
        , discriminator_activation = 'leaky_relu'
        , discriminator_dropout_rate = 0.4
        , discriminator_learning_rate = 0.0002
        , generator_initial_dense_layer_size = (8, 8, 256)
        , generator_upsample = [2,2,2,2,1]
        , generator_conv_filters = [256,128,128,64,IMAGE_CHANNELS]
        , generator_conv_kernel_size = [5,5,5,5,5]
        , generator_conv_strides = [1,1,1,1,1]
        , generator_batch_norm_momentum = 0.8
        , generator_activation = 'leaky_relu'
        , generator_dropout_rate = 0.25
        , generator_learning_rate = 0.00015
        , optimiser = 'adam'
        , z_dim = 100 # Size vector to generate images from (latent space)
        , virtual_batch_size = None
        , label_smoothing = 0.1
        , preview_rows = 5 # Preview image
        , preview_cols = 5 # Preview image
        )

if mode == 'build':
    gan.save(RUN_FOLDER)
else:
    gan.load_weights(os.path.join(RUN_FOLDER, 'weights/weights.h5'))
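
The `label_smoothing = 0.1` argument above suggests that the discriminator's targets are softened. How the `GAN` class applies it is not shown in this notebook; the sketch below assumes one common scheme, one-sided label smoothing, where only the targets for real images are reduced. The function `make_targets` is a hypothetical name.

```python
import numpy as np

def make_targets(batch_size, label_smoothing=0.1):
    """Discriminator targets with one-sided label smoothing (assumed scheme)."""
    # Real images get a target slightly below 1; fake images stay at 0.
    real = np.full((batch_size, 1), 1.0 - label_smoothing, dtype=np.float32)
    fake = np.zeros((batch_size, 1), dtype=np.float32)
    return real, fake

real, fake = make_targets(batch_size=4, label_smoothing=0.1)
print(real.ravel(), fake.ravel())
```

Softer targets for real images are often used to keep the discriminator from becoming overconfident.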

In [None]:
gan.discriminator.summary()

In [None]:
gan.generator.summary()
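
As a quick check on the architecture parameters, the spatial sizes can be worked out by hand. The sketch below assumes that each generator upsampling step multiplies the side length by its factor, and that the convolutions use 'same' padding so that only the stride changes the size. Both are assumptions about the `GAN` class, whose source is not shown here; the summaries above are the authoritative reference.

```python
def generator_output_side(initial_side, upsample_factors):
    """Side length after applying each upsampling factor in turn."""
    side = initial_side
    for factor in upsample_factors:
        side *= factor
    return side

def discriminator_sides(input_side, strides):
    """Side length after each conv layer, assuming 'same' padding."""
    sides = []
    side = input_side
    for stride in strides:
        side = -(-side // stride)  # ceiling division
        sides.append(side)
    return sides

# Values taken from the GAN(...) call in this notebook.
gen_side = generator_output_side(8, [2, 2, 2, 2, 1])
disc_sides = discriminator_sides(128, [1, 2, 2, 2, 2])

print(gen_side)    # 128
print(disc_sides)  # [128, 64, 32, 16, 8]
```

Under these assumptions, the 8x8 initial dense layer grows to 128x128, which matches **GENERATE_SQUARE**, and the discriminator reduces a 128x128 input to 8x8 before its final layers.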

## training

In [None]:
start = time.time()

gan.train(
    training_data
    , batch_size=BATCH_SIZE
    , epochs=EPOCHS
    , run_folder=RUN_FOLDER
    , print_every_n_batches=PRINT_EVERY_N_BATCHES
)

elapsed = time.time()-start
print(f'Training time: {hms_string(elapsed)}')

In [None]:
fig = plt.figure()
plt.plot([x[0] for x in gan.d_losses], color='black', linewidth=0.25)

plt.plot([x[1] for x in gan.d_losses], color='green', linewidth=0.25)
plt.plot([x[2] for x in gan.d_losses], color='red', linewidth=0.25)
plt.plot([x[0] for x in gan.g_losses], color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('loss', fontsize=16)

plt.xlim(0, 200)
plt.ylim(0, 5)

plt.show()

In [None]:
fig = plt.figure()
plt.plot([x[3] for x in gan.d_losses], color='black', linewidth=0.25)
plt.plot([x[4] for x in gan.d_losses], color='green', linewidth=0.25)
plt.plot([x[5] for x in gan.d_losses], color='red', linewidth=0.25)
plt.plot([x[1] for x in gan.g_losses], color='orange', linewidth=0.25)

plt.xlabel('batch', fontsize=18)
plt.ylabel('accuracy', fontsize=16)

plt.xlim(0, 200)

plt.show()