In [83]:
import tensorflow as tf

from tensorflow.keras import Sequential
from tensorflow.keras.layers import Reshape
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import UpSampling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense

from tensorflow import random
from tensorflow import GradientTape


import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image_dataset_from_directory
import time

from IPython import display

In [84]:
# Training hyperparameters.
epochs = 20
number_of_examples = 16  # NOTE(review): appears unused; previews use num_examples_to_generate
batch_size = 16
latent_dim = 512
image_size = (64, 64)  # (height, width); generator and discriminator are built for 64x64 RGB
num_examples_to_generate = 16
SEED = 42  # global TF seed + dataset shuffling seed, so runs are repeatable

tf.random.set_seed(SEED)

data_dir = r'../input/animal-faces/afhq/train/'

# label_mode=None yields image batches only (no labels): the GAN needs no class information.
dataset = image_dataset_from_directory(
    data_dir, label_mode=None, image_size=image_size, batch_size=batch_size, seed=SEED
)

In [85]:
# Scale pixels from [0, 255] to [-1, 1] so real images share the generator's tanh range.
# NOTE: this rebinds `dataset`; re-running this cell without re-running the loading cell
# would normalise the data twice.
dataset = (
    dataset
    .map(lambda images: (images - 127.5) / 127.5, num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)  # overlap input preprocessing with the training step
)

In [86]:
def make_generator_model(latent_dim, upsample_filters=(256, 128, 64, 32)):
    """Build the DCGAN-style generator.

    The latent vector is reshaped to a 1x1 feature map and expanded to 4x4 by
    a 4x4 'valid' transposed convolution. Each entry of `upsample_filters`
    then adds one stride-2 'same' transposed convolution that doubles height
    and width (4 -> 8 -> 16 -> 32 -> 64 with the defaults). A final 'same'
    convolution projects to 3 channels, and `tanh` bounds the pixels to
    [-1, 1].

    Args:
        latent_dim: length of the input latent vector.
        upsample_filters: filter count of each stride-2 upsampling block.

    Returns:
        A `tf.keras.Model` mapping (batch, latent_dim) vectors to images of
        shape (batch, S, S, 3) with S = 4 * 2 ** len(upsample_filters).
    """
    latent = Input(shape=[latent_dim])
    hidden = Reshape((1, 1, latent_dim))(latent)

    # 1x1 -> 4x4
    hidden = Conv2DTranspose(filters=512, kernel_size=4, strides=(1, 1),
                             padding='valid', activation="leaky_relu")(hidden)

    # Each block doubles the spatial resolution (no BatchNormalization between blocks).
    for filters in upsample_filters:
        hidden = Conv2DTranspose(filters=filters, kernel_size=3, strides=(2, 2),
                                 padding='same', activation="leaky_relu")(hidden)

    # ToRGB at full resolution (64x64 with the default blocks).
    rgb = Conv2D(filters=3, kernel_size=4, strides=(1, 1), padding='same')(hidden)
    rgb = Activation("tanh")(rgb)

    return Model(inputs=latent, outputs=[rgb])

In [87]:
generator = make_generator_model(latent_dim)

# Sanity check: decode a single random latent vector with the freshly built generator.
noise = tf.random.normal(shape=[1, latent_dim])
generated_image = generator(noise, training=False)

# Undo the [-1, 1] scaling of the tanh output before displaying it.
preview_pixels = generated_image[0].numpy() * 127.5 + 127.5
plt.imshow(preview_pixels.astype("uint32"))

In [88]:
def make_discriminator_model(input_shape=(64, 64, 3)):
    """Build the DCGAN-style discriminator.

    The network maps each image to a single unnormalised score (a logit).
    It deliberately has no final sigmoid. The notebook's loss is
    `tf.keras.losses.BinaryCrossentropy(from_logits=True)`, which applies the
    sigmoid itself. Applying it twice squashes the scores into (0, 1) and
    cripples both the loss range and the gradients.

    Args:
        input_shape: (height, width, channels) of the images to score.
            Defaults to the generator's 64x64 RGB output. After the four
            stride-2 convolutions, ceil(size / 16) must be at least 4 in each
            spatial dimension for the final 4x4 'valid' convolution.

    Returns:
        A `tf.keras.Model` mapping images to a (batch, 1) tensor of logits.
    """
    images = Input(shape=input_shape)

    # FromRGB
    hidden = Conv2D(filters=32, kernel_size=4, strides=1, padding="same",
                    activation="leaky_relu")(images)

    # Each stride-2 block halves height/width (64 -> 32 -> 16 -> 8 -> 4 by default).
    for filters in (64, 128, 256, 512):
        hidden = Conv2D(filters=filters, kernel_size=3, strides=2, padding="same",
                        activation="leaky_relu")(hidden)

    hidden = Conv2D(filters=512, kernel_size=4, strides=1, padding="valid",
                    activation="leaky_relu")(hidden)
    hidden = Flatten()(hidden)

    # Raw logit: the from_logits=True loss applies the sigmoid.
    logit = Dense(1)(hidden)

    return Model(inputs=[images], outputs=[logit])

In [89]:
discriminator = make_discriminator_model()

# Smoke test: score the generator preview image with the freshly built discriminator.
decision = discriminator(generated_image)
print(decision)

In [90]:
# Shared binary cross-entropy loss object. With from_logits=True it applies the
# sigmoid itself, so discriminator outputs must be raw logits, not probabilities.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def discriminator_loss(real_output, fake_output):
    """Discriminator loss: real images should be scored 1 and generated images 0."""
    real_term = cross_entropy(tf.ones_like(real_output), real_output)
    fake_term = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return real_term + fake_term


def generator_loss(fake_output):
    """Generator loss: the generator does well when its fakes are scored as real (1)."""
    all_real = tf.ones_like(fake_output)
    return cross_entropy(all_real, fake_output)

In [91]:
# Separate Adam optimizers, same learning rate; other Adam settings stay at Keras defaults.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)

In [92]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

# Group both networks and their optimizers so `checkpoint.save` writes their state together.
checkpoint = tf.train.Checkpoint(
    generator_optimizer=generator_optimizer,
    discriminator_optimizer=discriminator_optimizer,
    generator=generator,
    discriminator=discriminator,
)

In [93]:
# Fixed latent vectors reused after every epoch so the preview grids are comparable over time.
seed = tf.random.normal(shape=[num_examples_to_generate, latent_dim])

In [94]:
@tf.function
def train_step(images):
    """Run one adversarial update of the generator and the discriminator.

    Args:
        images: a batch of real images, expected to be scaled to [-1, 1].
            The first dimension is the batch size. The final batch of an
            epoch may be smaller than the global `batch_size`.

    Returns:
        `(gen_loss, disc_loss)`: the scalar loss tensors for this step.
    """
    # Draw exactly as many fakes as there are real images in this batch; using
    # the fixed global `batch_size` mismatches the (smaller) last batch.
    current_batch_size = tf.shape(images)[0]
    noise = random.normal([current_batch_size, latent_dim])

    with GradientTape() as gen_tape, GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)

        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, discriminator.trainable_variables
    )

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables)
    )

    return (gen_loss, disc_loss)

In [95]:
os.makedirs("epochs", exist_ok=True)


def plot_grid_of_images(images, epoch, output_dir="epochs"):
    """Show `images` in a square grid and save it as ``<output_dir>/image_at_epoch_XXXX.png``.

    Args:
        images: batch of images in [-1, 1] (e.g. the generator's tanh output),
            indexable along the first axis. Tensors and NumPy arrays both work.
        epoch: integer embedded (zero-padded to 4 digits) in the file name.
        output_dir: directory for the PNG; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    count = images.shape[0]
    # Square-ish grid: 16 images -> 4x4, and more than 16 no longer overflows a fixed 4x4 grid.
    n_cols = max(1, int(np.ceil(np.sqrt(count))))
    n_rows = max(1, int(np.ceil(count / n_cols)))

    fig = plt.figure(figsize=(8, 8))
    for i in range(count):
        plt.subplot(n_rows, n_cols, i + 1)
        pixels = np.asarray(images[i]) * 127.5 + 127.5
        plt.imshow(np.clip(pixels, 0, 255).astype("uint8"))
        plt.axis('off')

    plt.savefig(os.path.join(output_dir, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # Release the figure so repeated calls (one per epoch) do not accumulate open figures.
    plt.close(fig)

In [96]:
def train(dataset, epochs, checkpoint_every=15):
    """Train the GAN for `epochs` passes over `dataset`.

    After every epoch the epoch time is printed, and a preview grid generated
    from the fixed `seed` latents is shown and saved via `plot_grid_of_images`
    (file index = 1-based epoch number, matching the printed epochs). A
    checkpoint is saved every `checkpoint_every` epochs and always after the
    final epoch, so the last epochs are never lost.

    Args:
        dataset: iterable of real image batches scaled to [-1, 1].
        epochs: number of passes over `dataset`.
        checkpoint_every: checkpoint period in epochs; 0/None saves only after
            the final epoch.

    Returns:
        `(generator_losses, discriminator_losses)`: NumPy arrays of shape
        (epochs, batches_per_epoch) holding the per-step losses.
    """
    generator_history = []
    discriminator_history = []

    for epoch in range(epochs):
        start = time.time()
        epoch_gen_losses = []
        epoch_disc_losses = []

        for batch, image_batch in enumerate(dataset):
            gen_loss, disc_loss = train_step(image_batch)
            epoch_gen_losses.append(float(gen_loss))
            epoch_disc_losses.append(float(disc_loss))

            if batch % 100 == 0:
                # The losses are already batch means (BinaryCrossentropy's default
                # reduction); report them as-is rather than dividing by an image dimension.
                print(f"Epoch {epoch + 1} Batch {batch} "
                      f"Generator loss {epoch_gen_losses[-1]:.4f} "
                      f"Discriminator loss {epoch_disc_losses[-1]:.4f}")

        generator_history.append(epoch_gen_losses)
        discriminator_history.append(epoch_disc_losses)

        is_last_epoch = (epoch + 1 == epochs)
        is_periodic = bool(checkpoint_every) and (epoch + 1) % checkpoint_every == 0
        if is_periodic or is_last_epoch:
            checkpoint.save(file_prefix=checkpoint_prefix)

        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

        # Preview for the GIF. The grid after the last epoch is the final result,
        # so no extra post-loop plot is needed.
        example_images = generator(seed, training=False)
        plot_grid_of_images(example_images, epoch + 1)

    return np.array(generator_history), np.array(discriminator_history)

In [None]:
# Train the GAN on the normalised dataset; the trailing ';' keeps IPython from
# echoing the call's return value as cell output.
train(dataset=dataset, epochs=epochs);