# In [2]:
from numpy import hstack
from numpy import zeros
from numpy import ones
import numpy as np
from numpy.random import rand
from numpy.random import randn
from keras.datasets import fashion_mnist, mnist
from keras.optimizers import Adam
from keras.losses import mse, binary_crossentropy
from keras.models import Sequential
from keras.models import Model
from keras.layers import *
from numpy.random import randint
from matplotlib import pyplot as plt
import os

def define_generator(latent_dim):
    """Build the MLP generator that maps latent noise to flattened images.

    Three hidden blocks of 256, 512 and 1024 units are used. Each block is
    Dense -> LeakyReLU(alpha=0.2) -> BatchNormalization(momentum=0.8). They
    feed a 784-unit ``tanh`` output layer, so every generated value lies in
    [-1, 1]. This matches the range ``load_real_samples`` maps real pixels
    into.

    Args:
        latent_dim: length of the noise vector the model accepts.

    Returns:
        An uncompiled keras ``Model`` from ``(latent_dim,)`` to ``(784,)``.
    """
    noise = Input(shape=(latent_dim,))
    hidden = noise
    for width in (256, 512, 1024):
        hidden = Dense(width)(hidden)
        hidden = LeakyReLU(alpha=0.2)(hidden)
        hidden = BatchNormalization(momentum=0.8)(hidden)
    image = Dense(784, activation='tanh')(hidden)
    return Model(noise, image)

def define_discriminator():
    """Build and compile the MLP discriminator for flattened 784-value images.

    Three hidden blocks of 1024, 512 and 256 units are chained. Each block
    is Dense -> LeakyReLU(alpha=0.2) -> Dropout(0.3). A single sigmoid unit
    then scores how likely the input is to be a real image.

    Returns:
        A keras ``Model`` compiled with binary cross-entropy loss and
        ``Adam(0.001, 0.5)``.
    """
    inputs = Input(shape=(784,))
    x = inputs
    for units in (1024, 512, 256):
        # Each block must consume the previous block's output (``x``).
        # Feeding ``inputs`` into every Dense layer would leave the
        # 1024- and 512-unit blocks disconnected from the output.
        x = Dense(units)(x)
        x = LeakyReLU(alpha=0.2)(x)
        x = Dropout(0.3)(x)
    outputs = Dense(1, activation='sigmoid')(x)
    model = Model(inputs, outputs)
    opt = Adam(0.001, 0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

def define_gan(generator, discriminator):
    """Stack ``generator`` and ``discriminator`` into one trainable model.

    The combined model takes the generator's input (latent noise), runs it
    through the generator, and returns the discriminator's score for the
    generated output. The discriminator's ``trainable`` flag is set to
    ``False`` before this model is compiled, so training it is intended to
    update only the generator.

    Args:
        generator: keras model whose output the discriminator accepts.
        discriminator: keras model; it is mutated (``trainable = False``).

    Returns:
        The combined keras ``Model``, compiled with binary cross-entropy
        and ``Adam(0.001, 0.5)``.
    """
    # NOTE(review): this relies on Keras snapshotting ``trainable`` at compile
    # time (Keras 2 behaviour). The discriminator was compiled earlier, while
    # still trainable, so it keeps learning when trained on its own. Confirm
    # this for the installed Keras version.
    discriminator.trainable = False
    score = discriminator(generator.output)
    combined = Model(generator.input, score)
    combined.compile(loss='binary_crossentropy', optimizer=Adam(0.001, 0.5))
    return combined

def load_real_samples():
    """Load MNIST, scale it to [-1, 1] and flatten each image to 784 values.

    Pixels are mapped with ``(p - 127.5) / 127.5``. Assuming the raw pixels
    are in [0, 255], this gives [-1, 1], matching the generator's ``tanh``
    output range.

    Returns:
        ``(x_train, y_train, x_test, y_test)``. The ``x`` arrays are float32
        with shape ``(num_images, 784)``. The ``y`` arrays are the labels
        exactly as returned by ``mnist.load_data()``.
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    # Each split is scaled and reshaped from its own array, so x_test
    # really holds the test images and lines up row-for-row with y_test.
    x_train = (x_train.astype(np.float32) - 127.5) / 127.5
    x_test = (x_test.astype(np.float32) - 127.5) / 127.5
    # -1 lets each split keep its own image count instead of hard-coding 60000.
    x_train = x_train.reshape(-1, 784)
    x_test = x_test.reshape(-1, 784)
    return x_train, y_train, x_test, y_test

def generate_real_samples(dataset, n):
    """Draw ``n`` random rows from ``dataset`` and label them as real.

    Row indices are drawn uniformly from ``[0, dataset.shape[0])`` with
    replacement, using NumPy's global random state.

    Returns:
        ``(x, y)``, where ``x`` is ``dataset`` indexed by those ``n``
        indices and ``y`` is an ``(n, 1)`` array of ones.
    """
    picks = randint(0, dataset.shape[0], n)
    return dataset[picks], ones((n, 1))

def generate_latent_points(latent_dim, n):
    """Sample ``n`` standard-normal latent vectors of length ``latent_dim``.

    Uses NumPy's global random state.

    Returns:
        A float array with shape ``(n, latent_dim)``.
    """
    return randn(latent_dim * n).reshape(n, latent_dim)

def generate_fake_samples(generator, latent_dim, n):
    """Produce ``n`` generator outputs and label them as fake.

    Args:
        generator: model with a ``predict`` method that accepts an
            ``(n, latent_dim)`` array of latent points.
        latent_dim: length of each latent vector.
        n: number of samples to generate.

    Returns:
        ``(X, y)``, where ``X`` is ``generator.predict`` applied to fresh
        latent points and ``y`` is an ``(n, 1)`` array of zeros.
    """
    images = generator.predict(generate_latent_points(latent_dim, n))
    return images, zeros((n, 1))

def plot_results(*args,
                 batch_size=128,
                 model_name="vae_mnist"):
    """Save two diagnostic plots for a VAE with a 2-D latent space.

    This is not called by the GAN code in this notebook. It writes:

    1. ``<model_name>/vae_mean.png``: a scatter of the first two components
       of ``z_mean`` (the first of the three outputs of ``encoder.predict``),
       coloured by ``y_test``.
    2. ``<model_name>/digits_over_latent.png``: a 30x30 mosaic of 28x28
       decoder outputs, one per latent point on an evenly spaced grid over
       [-4, 4] x [-4, 4]. The y axis runs from +4 at the top to -4.

    Args:
        *args: must unpack to ``(encoder, decoder, x_test, y_test)``.
        batch_size: batch size passed to the ``predict`` calls.
        model_name: output directory, created if it does not exist.
    """
    encoder, decoder, x_test, y_test = args
    os.makedirs(model_name, exist_ok=True)

    filename = os.path.join(model_name, "vae_mean.png")
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.savefig(filename)

    filename = os.path.join(model_name, "digits_over_latent.png")
    n = 30
    digit_size = 28
    figure = np.zeros((digit_size * n, digit_size * n))
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]

    # Decode the whole grid in one predict call instead of n * n
    # single-sample calls. Row i * n + j holds the point (grid_x[j], grid_y[i]).
    z_grid = np.array([[xi, yi] for yi in grid_y for xi in grid_x])
    x_decoded = decoder.predict(z_grid, batch_size=batch_size)

    for i in range(n):
        for j in range(n):
            digit = x_decoded[i * n + j].reshape(digit_size, digit_size)
            figure[i * digit_size: (i + 1) * digit_size,
                   j * digit_size: (j + 1) * digit_size] = digit

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    # The last tile's centre is at (n - 1) * digit_size + start_range.
    # Stopping there gives exactly n tick positions for the n labels below;
    # using n * digit_size added an (n+1)-th tick past the image edge.
    end_range = (n - 1) * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)

def sample_images(g_model, latent_dim):
    """Display a 5x5 grid of images produced by ``g_model`` from random noise.

    The function draws 25 standard-normal latent vectors and runs them
    through ``g_model.predict``. It rescales the outputs with
    ``0.5 * v + 0.5`` and shows each one as a 28x28 grayscale tile with its
    axes hidden.

    NOTE(review): the rescale assumes ``g_model`` emits values in [-1, 1]
    (e.g. a tanh output layer), so the tiles land in [0, 1].
    """
    rows, cols = 5, 5
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    samples = 0.5 * g_model.predict(noise) + 0.5
    tiles = samples.reshape(-1, 28, 28)

    _, axes = plt.subplots(rows, cols)
    # axes.flat walks the grid row by row, matching tile order.
    for idx, ax in enumerate(axes.flat):
        ax.imshow(tiles[idx, :, :], cmap='gray')
        ax.axis('off')
    plt.show()
    
def train(g_model, d_model, gan_model, latent_dim, dataset, n_epochs=10000, n_batch=128):
    """Run the alternating GAN training loop and save the generator.

    Despite its name, ``n_epochs`` counts single-batch iterations, not
    passes over ``dataset``. Each iteration does the following:

    1. Train ``d_model`` on ``n_batch // 2`` random real rows of ``dataset``
       (label 1), then on ``n_batch // 2`` generated samples (label 0).
    2. Train ``gan_model`` on ``n_batch`` fresh latent points labelled 1.
       Presumably ``gan_model`` comes from ``define_gan`` with the
       discriminator frozen, so this step updates the generator.
    3. Print the three losses.

    A 5x5 sample grid is shown every 200 iterations, starting at iteration 0.
    After the loop the generator is saved to ``cgan_generator.h5``.

    Args:
        g_model: generator model.
        d_model: compiled discriminator model.
        gan_model: compiled generator -> discriminator model.
        latent_dim: length of each latent vector.
        dataset: array of real samples, one per row.
        n_epochs: number of single-batch training iterations.
        n_batch: batch size for the generator update; the discriminator sees
            half of it from each of the real and fake sources.

    Returns:
        The trained ``g_model``.
    """
    half_batch = n_batch // 2
    # ``dataset`` already holds the real samples, so MNIST is not reloaded here.
    for i in range(n_epochs):
        x_real, y_real = generate_real_samples(dataset, half_batch)
        x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
        d_loss1 = d_model.train_on_batch(x_real, y_real)
        d_loss2 = d_model.train_on_batch(x_fake, y_fake)

        x_gan = generate_latent_points(latent_dim, n_batch)
        # Label generated samples as real so the generator learns to fool d_model.
        y_gan = ones((n_batch, 1))
        g_loss1 = gan_model.train_on_batch(x_gan, y_gan)

        print('>%d, d1=%.3f, d2=%.3f, g1=%.3f' %
              (i, d_loss1, d_loss2, g_loss1))
        if i % 200 == 0:
            sample_images(g_model, latent_dim)

    g_model.save('cgan_generator.h5')

    return g_model
            
latent_dim = 100

# Build the discriminator (compiled while still trainable) and the generator,
# then combine them. define_gan freezes the discriminator inside the
# combined model.
discriminator = define_discriminator()
generator = define_generator(latent_dim)
gan_model = define_gan(generator, discriminator)

# Only the scaled training images are used for GAN training.
dataset, _, x_test, y_test = load_real_samples()
train(
    g_model=generator,
    d_model=discriminator,
    gan_model=gan_model,
    latent_dim=latent_dim,
    dataset=dataset,
)

# Output hidden; open in https://colab.research.google.com to view.