<a href="https://colab.research.google.com/github/KirkDCO/HandsOnML_Exercises/blob/main/Ch17_Q11.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!cp -r drive/MyDrive/Colab\ Notebooks/HandsOnML/Ch17_Q10/flowers sample_data/.

In [None]:
!mkdir sample_data/flowers_small
!cp -r sample_data/flowers/sunflower sample_data/flowers_small/.

In [None]:
!mkdir sample_data/flowers_small
!cp -r drive/MyDrive/Colab\ Notebooks/HandsOnML/Ch17_Q10/flowers/sunflower sample_data/flowers_small/.

In [None]:
# imports and globals

import tensorflow as tf
from tensorflow import keras
K = keras.backend

import numpy as np
import matplotlib.pyplot as plt
import random

BATCH_SIZE = 32  # images per batch yielded by the training dataset

# Every image is resized to this size on load; the discriminator expects it too.
IMG_WIDTH, IMG_HEIGHT = 48, 48

ENCODING_SIZE = 512  # length of the random noise vector fed to the generator

In [None]:
# Unlabeled training set built from the sunflower-only folder.
# With labels=None each element is just a batch of image tensors; pixel values
# are presumably raw [0, 255] (train_gan divides them by 255).
training_generator = tf.keras.preprocessing.image_dataset_from_directory(
  "sample_data/flowers_small",
  labels = None,
  image_size = (IMG_HEIGHT, IMG_WIDTH),
  batch_size = BATCH_SIZE,
  seed = 42,
).prefetch(1)  # prepare the next batch while the current one is being trained on

Found 1027 files belonging to 1 classes.


In [None]:
# Reset Keras global state and seed TF *before* any layer is created, so the
# initial weights (not only the training noise) are reproducible.
keras.backend.clear_session()
tf.random.set_seed(42)

# The generator doubles the spatial size three times (stride-2 transposed
# convolutions), so it starts from a feature map 1/8 of the target image size
# (6 x 6 for 48 x 48 images).
assert IMG_HEIGHT % 8 == 0 and IMG_WIDTH % 8 == 0, "IMG_HEIGHT and IMG_WIDTH must be multiples of 8"
base_height, base_width = IMG_HEIGHT // 8, IMG_WIDTH // 8

generator = keras.Sequential([
  keras.layers.Dense(128 * base_height * base_width, activation = "selu", input_shape = [ENCODING_SIZE]),
  keras.layers.Reshape([base_height, base_width, 128]),
  keras.layers.BatchNormalization(),
  keras.layers.Conv2DTranspose(filters = 128, kernel_size = 3, strides = 2,
                               padding = "same", activation = "selu",
                               kernel_initializer='lecun_normal'),
  keras.layers.BatchNormalization(),
  keras.layers.Conv2DTranspose(filters = 64, kernel_size = 3, strides = 2,
                               padding = "same", activation = "selu",
                               kernel_initializer='lecun_normal'),
  keras.layers.BatchNormalization(),
  keras.layers.Conv2DTranspose(filters = 32, kernel_size = 3, strides = 2,
                               padding = "same", activation = "selu",
                               kernel_initializer='lecun_normal'),
  # sigmoid keeps generated pixels in [0, 1], the same range as the real
  # images after train_gan divides them by 255
  keras.layers.Conv2DTranspose(filters = 3, kernel_size = 3, strides = 1,
                               padding = 'same', activation = 'sigmoid')
])
generator.summary()

# Keras image tensors are channels-last: (height, width, channels).
discriminator = keras.Sequential([
  keras.layers.Conv2D(32, input_shape = [IMG_HEIGHT, IMG_WIDTH, 3], kernel_size = 3,
                      strides = 1, padding = 'same', activation = keras.layers.LeakyReLU(0.2)),
  keras.layers.Dropout(0.25),
  keras.layers.Conv2D(64, kernel_size = 3, strides = 2,
                      padding = 'same', activation = keras.layers.LeakyReLU(0.2)),
  keras.layers.Dropout(0.25),
  keras.layers.Conv2D(128, kernel_size = 3, strides = 2,
                      padding = 'same', activation = keras.layers.LeakyReLU(0.2)),
  keras.layers.Dropout(0.25),
  keras.layers.Flatten(),
  keras.layers.Dense(1, activation = 'sigmoid')
])
discriminator.summary()

# Stacked model: train_gan fits the generator through the discriminator with it.
gan = keras.models.Sequential([generator, discriminator])
gan.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense (Dense)               (None, 4608)              2363904   
                                                                 
 reshape (Reshape)           (None, 6, 6, 128)         0         
                                                                 
 batch_normalization (BatchN  (None, 6, 6, 128)        512       
 ormalization)                                                   
                                                                 
 conv2d_transpose (Conv2DTra  (None, 12, 12, 128)      147584    
 nspose)                                                         
                                                                 
 batch_normalization_1 (Batc  (None, 12, 12, 128)      512       
 hNormalization)                                                 
                                                        

In [None]:
def plot_multiple_images(images, n_cols=None):
    """Display a batch of images in a grid of 3x3-inch cells.

    Args:
        images: sequence of images with a ``shape`` (NumPy array or tensor).
            A trailing channel axis of size 1 is squeezed away so grayscale
            images render as 2-D.
        n_cols: number of grid columns; defaults to a single row holding
            every image.

    Returns without drawing anything when ``images`` is empty.
    """
    if len(images) == 0:
        return
    images = np.asarray(images)
    n_cols = n_cols or len(images)
    n_rows = (len(images) - 1) // n_cols + 1
    if images.shape[-1] == 1:
        images = np.squeeze(images, axis=-1)
    plt.figure(figsize=(n_cols * 3, n_rows * 3))
    for index, image in enumerate(images, start=1):
        plt.subplot(n_rows, n_cols, index)
        # cmap only affects single-channel images; RGB images ignore it.
        plt.imshow(image, cmap="binary")
        plt.axis("off")

def exponential_decay_fn(epoch, lr0=0.001, s=20):
  """Learning rate for ``epoch`` under exponential decay.

  Computes ``lr0 * 0.1 ** (epoch / s)``: the rate starts at ``lr0`` and is
  divided by 10 every ``s`` epochs. With the defaults it is 1e-3 at epoch 0
  and 1e-4 at epoch 20.
  """
  return lr0 * 0.1 ** (epoch / s)

def train_gan(gan, dataset, BATCH_SIZE, ENCODING_SIZE, n_epochs=50, n_images_to_plot=9):
    """Train a GAN assembled as ``Sequential([generator, discriminator])``.

    For every batch: (1) fit the discriminator on generated images labelled 0
    and real images labelled 1, then (2) fit the generator through ``gan`` so
    the discriminator outputs 1 for fresh generated images. Both optimizers
    follow ``exponential_decay_fn(epoch)``. After each epoch a few of the last
    batch's generated images are plotted.

    Args:
        gan: compiled model whose ``layers`` are ``(generator, discriminator)``;
            ``discriminator`` must also be compiled on its own.
        dataset: iterable of unlabelled image batches; pixels are divided by 255
            here (presumably raw [0, 255] values).
        BATCH_SIZE: unused — each batch's size is read from the batch itself so
            a smaller final batch works. Kept for call compatibility.
        ENCODING_SIZE: length of the noise vector fed to the generator.
        n_epochs: number of passes over ``dataset``.
        n_images_to_plot: how many generated images to show after each epoch.

    Raises:
        ValueError: if ``dataset`` yields no batches.
    """
    generator, discriminator = gan.layers
    for epoch in range(n_epochs):
        print("Epoch {}/{}".format(epoch + 1, n_epochs))
        # The schedule depends only on the epoch, so set it once per epoch.
        learning_rate = exponential_decay_fn(epoch)
        K.set_value(discriminator.optimizer.learning_rate, learning_rate)
        K.set_value(gan.optimizer.learning_rate, learning_rate)

        generated_images = None
        for X_batch in dataset:
            print("=", end = '')
            X_batch = X_batch / 255
            n_real = len(X_batch)

            # phase 1 - training the discriminator
            noise = tf.random.normal(shape=[n_real, ENCODING_SIZE])
            generated_images = generator(noise)
            X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
            y1 = tf.constant([[0.]] * n_real + [[1.]] * n_real)
            discriminator.trainable = True
            discriminator.train_on_batch(X_fake_and_real, y1)

            # phase 2 - training the generator
            noise = tf.random.normal(shape=[n_real, ENCODING_SIZE])
            y2 = tf.constant([[1.]] * n_real)
            discriminator.trainable = False
            gan.train_on_batch(noise, y2)
        print()  # finish the progress-bar line

        if generated_images is None:
            raise ValueError("dataset yielded no batches")
        # Show the generator's output, not the mixed fake+real training batch.
        plot_multiple_images(generated_images[:n_images_to_plot], 3)
        plt.show()

In [None]:
# Compile both models and train.
# NOTE: clearing the Keras session here would not give a clean run — the
# generator/discriminator already exist and would not be rebuilt. Reset state
# and seed before building the models for reproducible weights. Seeding here
# still makes the training noise reproducible.
tf.random.set_seed(42)

N_EPOCHS = 500

# The discriminator is compiled while trainable so its own train_on_batch
# updates it; it is then frozen before `gan` is compiled so that
# gan.train_on_batch only updates the generator.
discriminator.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.Nadam())
discriminator.trainable = False
gan.compile(loss = 'binary_crossentropy', optimizer = keras.optimizers.Nadam())

train_gan(gan, training_generator, BATCH_SIZE, ENCODING_SIZE, n_epochs = N_EPOCHS)