In [28]:
import time
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import numpy as np

In [29]:
# NOTE(review): the training call further down passes its epoch count explicitly,
# so this constant is not read by it — confirm which value is intended.
epochs = 50
number_of_examples = 16   # images in the per-epoch monitoring grid
batch_size = 32
latent_dim = 512
image_size = (32, 32) # h x w
number_of_classes = 10

# Seed TensorFlow's global RNG so the fixed monitoring noise below (and TF random
# ops drawn after it) is reproducible under Restart & Run All.
random_seed = 42
tf.random.set_seed(random_seed)

# Fixed noise/label pair re-used after every epoch so generated samples are comparable.
image_seed = tf.random.normal([number_of_examples, latent_dim])
# Cycle through the class ids (0..9, 0..5 for 16 examples) so each class appears in the grid.
label_seed = tf.range(number_of_examples) % number_of_classes

In [30]:
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
# Cast to float32 first: subtracting a Python float from the uint8 array would
# otherwise promote the whole training set to float64 (twice the memory).
x_train = (x_train.astype("float32") - 127.5) / 127.5

In [31]:
# Reshuffle on every pass so each epoch sees the examples in a different order
# (without this, every epoch iterates the batches in the same fixed order).
# drop_remainder discards a final partial batch so every batch has exactly
# `batch_size` examples.
dataset = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train), reshuffle_each_iteration=True)
    .batch(batch_size, drop_remainder=True)
)

In [32]:
def make_generator_model(latent_dim):
    """Build the class-conditional generator.

    The noise vector is multiplied element-wise by a learned per-class
    embedding, reshaped into a 1x1 feature map and upsampled by transposed
    convolutions: 1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32x3.

    Args:
        latent_dim: length of the noise vector (and of the label embedding).

    Returns:
        A keras.Model taking ``[noise, label]`` (noise of shape
        ``(batch, latent_dim)``, integer class ids of shape ``(batch, 1)``)
        and returning images passed through tanh, i.e. in [-1, 1].
    """
    noise_in = keras.layers.Input(shape=[latent_dim])
    label_in = keras.layers.Input(shape=(1, ))

    embedded = keras.layers.Embedding(number_of_classes, latent_dim)(label_in)
    class_vector = keras.layers.Flatten()(embedded)

    # Condition the noise on the class by element-wise product with its embedding.
    conditioned = keras.layers.Multiply()([noise_in, class_vector])
    x = keras.layers.Reshape((1, 1, latent_dim))(conditioned)

    # (filters, kernel, stride, padding) for each Conv2DTranspose -> BN -> ReLU stage.
    upsampling_stages = (
        (512, 4, 1, 'valid'),   # 1x1   -> 4x4
        (256, 3, 2, 'same'),    # 4x4   -> 8x8
        (128, 3, 2, 'same'),    # 8x8   -> 16x16
    )
    for filters, kernel, stride, padding in upsampling_stages:
        x = keras.layers.Conv2DTranspose(filters, kernel, stride, padding)(x)
        x = keras.layers.BatchNormalization(momentum=0.9)(x)
        x = keras.layers.ReLU()(x)

    # Final stage: 16x16 -> 32x32 with 3 (RGB) channels, no normalization.
    x = keras.layers.Conv2DTranspose(3, 3, 2, 'same')(x)
    rgb = keras.layers.Activation("tanh")(x)

    return keras.Model(inputs=[noise_in, label_in], outputs=rgb)

In [33]:
# Instantiate the generator sized by the configured noise length.
generator = make_generator_model(latent_dim=latent_dim)

In [36]:
def make_discriminator_model():
    """Build the class-conditional discriminator.

    A 32x32x3 image is downsampled by strided convolutions to a 1x1 map with
    ``latent_dim`` channels. The flattened features are multiplied
    element-wise by a learned per-class embedding and mapped by ``Dense(1)``
    to a single unactivated score (the loss uses ``from_logits=True``).

    Returns:
        A keras.Model taking ``[image, label]`` (images of shape
        ``(batch, 32, 32, 3)``, integer class ids of shape ``(batch, 1)``)
        and returning one logit per example.
    """
    image_in = keras.layers.Input(shape=(32, 32, 3))
    label_in = keras.layers.Input(shape=(1, ))

    # Additive input noise (Keras applies it only in training mode).
    x = keras.layers.GaussianNoise(0.1)(image_in)

    # Conv -> BN -> ReLU stages, each halving the spatial size: 32 -> 16 -> 8 -> 4.
    for filters in (128, 256, 512):
        x = keras.layers.Conv2D(filters, 3, 2, 'same')(x)
        x = keras.layers.BatchNormalization(momentum=0.9)(x)
        x = keras.layers.ReLU()(x)

    # A 4x4 'valid' convolution collapses the 4x4 map to 1x1 x latent_dim.
    x = keras.layers.Conv2D(latent_dim, 4, 1, 'valid')(x)
    x = keras.layers.ReLU()(x)
    image_features = keras.layers.Flatten()(x)

    embedded = keras.layers.Embedding(number_of_classes, latent_dim)(label_in)
    class_vector = keras.layers.Flatten()(embedded)

    # Condition on the class by element-wise product of image features and embedding.
    joint = keras.layers.Multiply()([image_features, class_vector])
    joint = keras.layers.Dropout(0.3)(joint)

    logit = keras.layers.Dense(1)(joint)
    return keras.Model(inputs=[image_in, label_in], outputs=logit)

In [37]:
# Instantiate the discriminator; it reads latent_dim and number_of_classes from the config cell.
discriminator = make_discriminator_model()

In [38]:
# The discriminator ends in Dense(1) with no activation, so the loss works on logits.
cross_entropy = keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real_output, fake_output, real_label=0.9):
    """Binary cross-entropy for the discriminator with one-sided label smoothing.

    Args:
        real_output: discriminator logits for real images.
        fake_output: discriminator logits for generated images.
        real_label: target for real images (default 0.9; use 1.0 to disable
            smoothing). Fake images always target 0.

    Returns:
        Sum of the real-image and fake-image losses.
    """
    # Build targets with TF ops so they follow the logits' runtime shape and dtype.
    # np.full(x.shape, ...) would require a fully static shape inside tf.function
    # and creates float64/int64 constants.
    real_targets = real_label * tf.ones_like(real_output)
    fake_targets = tf.zeros_like(fake_output)

    real_loss = cross_entropy(real_targets, real_output)
    fake_loss = cross_entropy(fake_targets, fake_output)
    return real_loss + fake_loss

def generator_loss(fake_output):
    """Non-saturating generator loss: push the discriminator to label fakes as real (1)."""
    return cross_entropy(tf.ones_like(fake_output), fake_output)

# The discriminator uses a 4x larger learning rate than the generator; both use beta_1=0.5.
generator_optimizer = keras.optimizers.Adam(1e-4, beta_1=0.5)
discriminator_optimizer = keras.optimizers.Adam(4e-4, beta_1=0.5)

In [39]:
@tf.function
def train_step(images):
    """Run one simultaneous generator/discriminator update on a batch.

    Args:
        images: an ``(images, labels)`` pair as produced by ``dataset``. The
            real labels also condition the generator, so each fake is scored
            against the same class as the real image it is paired with.

    Returns:
        ``(gen_loss, disc_loss)`` for this batch.
    """
    train_images, train_labels = images

    # Size the noise from the actual batch rather than the global `batch_size`,
    # so a batch of any other size does not mismatch the labels in the generator.
    current_batch_size = tf.shape(train_images)[0]
    noise = tf.random.normal([current_batch_size, latent_dim])

    # Two tapes so each network's gradients are taken w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator((noise, train_labels), training=True)

        real_output = discriminator((train_images, train_labels), training=True)
        fake_output = discriminator((generated_images, train_labels), training=True)

        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(
        zip(gradients_of_generator, generator.trainable_variables)
    )
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables)
    )

    return (gen_loss, disc_loss)

In [44]:
def plot_grid_of_images(images, epoch):
    """Display a batch of generated images in a grid, 8 images per row.

    Args:
        images: batch of shape ``(n, h, w, 3)`` with values in [-1, 1] (the
            generator's tanh range); a tf.Tensor or a numpy array.
        epoch: currently unused; kept so existing call sites keep working.
    """
    pixels = np.asarray(images)
    if pixels.shape[0] == 0:
        return

    # Map [-1, 1] back to [0, 255]. Round and clip before the uint8 cast so float
    # error just outside the range cannot wrap, and imshow gets a standard 8-bit image.
    pixels = np.clip(np.rint(pixels * 127.5 + 127.5), 0, 255).astype(np.uint8)

    # Size the grid from the batch: a fixed 8x8 grid fails for more than 64
    # images and leaves empty rows for smaller batches.
    n_cols = 8
    n_rows = -(-pixels.shape[0] // n_cols)  # ceiling division
    plt.figure(figsize=(n_cols, n_rows))

    for index, image in enumerate(pixels):
        plt.subplot(n_rows, n_cols, index + 1)
        plt.imshow(image, interpolation="none")
        plt.axis('off')

    plt.show()

In [41]:
def train(dataset, epochs):
    """Train the GAN for `epochs` passes over `dataset`.

    After each epoch, prints the epoch's wall-clock time and plots the
    generator's output for the fixed ``image_seed``/``label_seed`` pair, so
    samples are comparable across epochs.

    Args:
        dataset: iterable of ``(images, labels)`` batches consumed by ``train_step``.
        epochs: number of passes over `dataset`.

    Returns:
        ``(generator_losses, discriminator_losses)``: float arrays of shape
        ``(epochs, batches_per_epoch)`` with the per-batch losses (shape
        ``(0, 0)`` when ``epochs`` is 0).
    """
    generator_history = []
    discriminator_history = []

    for epoch in range(epochs):
        start = time.time()

        # Accumulate in lists; np.append in the loop copies the array every batch.
        epoch_gen_losses = []
        epoch_disc_losses = []
        for batch, image_batch in enumerate(dataset):
            gen_loss, disc_loss = train_step(image_batch)
            gen_value = float(gen_loss)
            disc_value = float(disc_loss)

            if batch % 500 == 0:
                # The losses are already reduced over the batch, so log them directly.
                # Dividing by image_batch[0].shape[1] would divide by the image height.
                print(f"Epoch {epoch + 1} Batch {batch} "
                      f"Generator loss {gen_value:.4f} "
                      f"Discriminator loss {disc_value:.4f}")

            epoch_gen_losses.append(gen_value)
            epoch_disc_losses.append(disc_value)

        generator_history.append(epoch_gen_losses)
        discriminator_history.append(epoch_disc_losses)
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))

        example_images = generator((image_seed, label_seed), training=False)
        plot_grid_of_images(example_images, epoch)

    if not generator_history:
        return (np.empty((0, 0), dtype=float), np.empty((0, 0), dtype=float))
    # One row per epoch, so the shape is the same for one epoch as for many.
    return (np.array(generator_history, dtype=float),
            np.array(discriminator_history, dtype=float))

In [42]:
# Train the model.
# NOTE(review): the epoch count is hard-coded here, so the `epochs` constant from
# the config cell is not used — confirm which value is intended.
generator_losses, discriminator_losses = train(dataset, epochs=20)

In [50]:
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
samples_per_class = 48

# One shared noise batch for every class, so the grids differ only by the conditioning label.
noises = tf.random.normal([samples_per_class, latent_dim])
for i, class_name in enumerate(class_names):
    print(class_name)

    class_labels = tf.constant(np.full(samples_per_class, i))
    example_images = generator((noises, class_labels), training=False)
    plot_grid_of_images(example_images, 0)

In [None]:
# Save both trained networks; the .h5 extension selects Keras' HDF5 format.
for trained_model, filename in ((generator, "generatorcifar10.h5"),
                                (discriminator, "discriminatorcifar10.h5")):
    trained_model.save(filename)