In [2]:
from numpy import hstack
from numpy import zeros
from numpy import ones
import numpy as np
from numpy.random import rand
from numpy.random import randn
from keras.datasets import mnist
from keras.optimizers import Adam
from keras.losses import mse, binary_crossentropy
from keras.models import Sequential
from keras.models import Model
from keras.layers import *
from numpy.random import randint
from matplotlib import pyplot as plt
from keras import backend as K
import os

In [3]:
def sampling(args):
    """Reparameterisation trick: return ``mu + exp(0.5 * log_var) * eps``.

    ``args`` is the pair ``[z_mean, z_log_var]``; ``eps`` is drawn from a
    standard normal with the same (batch, latent_dim) shape as ``z_mean``.
    """
    mu, log_var = args
    eps = K.random_normal(shape=(K.shape(mu)[0], K.int_shape(mu)[1]))
    return mu + K.exp(0.5 * log_var) * eps

def define_encoder(latent_dim):
    """Build the VAE encoder mapping a 784-vector to ``[z_mean, z_log_var, z]``.

    ``z`` is produced by the ``sampling`` Lambda layer from ``z_mean`` and
    ``z_log_var``; all three tensors have width ``latent_dim``.
    """
    x_in = Input(shape=(784,))
    h = x_in
    # Two LeakyReLU hidden layers, then a ReLU bottleneck.
    for width in (256, 128):
        h = Dense(width)(h)
        h = LeakyReLU(alpha=0.2)(h)
    h = Dense(64, activation='relu')(h)

    z_mean = Dense(latent_dim)(h)
    z_log_var = Dense(latent_dim)(h)
    z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
    return Model(x_in, [z_mean, z_log_var, z])

def define_generator(latent_dim):
    """Build the generator (also used as the VAE decoder).

    Maps a ``latent_dim`` vector through Dense -> LeakyReLU -> BatchNorm
    stages of width 64, 128 and 256 to 784 sigmoid outputs in [0, 1].
    """
    z_in = Input(shape=(latent_dim,))
    h = z_in
    for width in (64, 128, 256):
        h = Dense(width)(h)
        h = LeakyReLU(alpha=0.2)(h)
        h = BatchNormalization(momentum=0.8)(h)
    pixels = Dense(784, activation='sigmoid')(h)
    return Model(z_in, pixels)

def define_discriminator():
    """Build and compile the discriminator: 784-vector -> probability of "real".

    Three LeakyReLU hidden layers (256, 128, 64 units) feed a single sigmoid
    unit. Compiled with binary cross-entropy and Adam(lr=0.0002, beta_1=0.5).
    """
    inputs = Input(shape=(784,))
    x = inputs
    # Each layer must consume the previous activation so that all three
    # hidden layers lie on the path to the output (wiring every Dense to
    # ``inputs`` leaves the wider layers disconnected and untrained).
    for width in (256, 128, 64):
        x = Dense(width)(x)
        x = LeakyReLU(alpha=0.2)(x)
    outputs = Dense(1, activation='sigmoid')(x)
    model = Model(inputs, outputs)
    opt = Adam(lr=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt)
    return model

def define_gan(generator, discriminator):
    """Stack ``generator -> discriminator`` into the model used to train the generator.

    ``discriminator.trainable`` is set to False before this combined model is
    compiled, so its ``train_on_batch`` updates only the generator's weights.
    The discriminator's own model was compiled earlier (in
    ``define_discriminator``) and, since Keras reads ``trainable`` at compile
    time, can still be trained directly.
    """
    discriminator.trainable = False
    validity = discriminator(generator.output)
    model = Model(generator.input, validity)
    model.compile(loss='binary_crossentropy',
                  optimizer=Adam(lr=0.0008, beta_1=0.5))
    return model

def load_real_samples():
    """Load MNIST and return ``(x_train, y_train, x_test, y_test)``.

    Images are cast to float32, scaled to [0, 1] by dividing by 255, and
    flattened to one row per image; labels are returned as loaded.
    """
    def _flatten_scaled(images):
        scaled = images.astype('float32') / 255
        return scaled.reshape((len(scaled), -1))

    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    return (_flatten_scaled(train_images), train_labels,
            _flatten_scaled(test_images), test_labels)

def generate_real_samples(dataset, n):
    """Draw ``n`` rows of ``dataset`` uniformly with replacement, labelled real.

    Returns ``(x, y)`` where ``y`` is an ``(n, 1)`` array of ones.
    """
    picks = randint(0, dataset.shape[0], n)
    return dataset[picks], ones((n, 1))

def generate_latent_points(latent_dim, n):
    """Return ``n`` standard-normal latent vectors as an ``(n, latent_dim)`` array."""
    return randn(n * latent_dim).reshape((n, latent_dim))

def generate_fake_samples(generator, latent_dim, n):
    """Generate ``n`` images from random latent points, labelled fake.

    Returns ``(X, y)`` where ``X`` is ``generator.predict`` output and ``y``
    is an ``(n, 1)`` array of zeros.
    """
    images = generator.predict(generate_latent_points(latent_dim, n))
    return images, zeros((n, 1))



In [4]:
def plot_results(*args,
                 batch_size=128,
                 model_name="vae_mnist"):
    """Save two diagnostic plots of a 2-D latent space into ``model_name/``.

    ``args`` must be ``(encoder, decoder, x_test, y_test)``.

    * ``vae_mean.png``: scatter of the encoder's ``z_mean`` on ``x_test``,
      coloured by ``y_test``.
    * ``digits_over_latent.png``: a 30x30 mosaic of 28x28 decoder outputs for
      latent points on a regular grid over [-4, 4] x [-4, 4].

    Assumes a 2-D latent space: only ``z_mean[:, 0]`` / ``z_mean[:, 1]`` are
    plotted and grid points are 2-vectors.
    """
    encoder, decoder, x_test, y_test = args
    os.makedirs(model_name, exist_ok=True)

    # Latent means coloured by digit label.
    filename = os.path.join(model_name, "vae_mean.png")
    z_mean, _, _ = encoder.predict(x_test,
                                   batch_size=batch_size)
    plt.figure(figsize=(12, 10))
    plt.scatter(z_mean[:, 0], z_mean[:, 1], c=y_test)
    plt.colorbar()
    plt.xlabel("Dimension 1")
    plt.ylabel("Dimension 2")
    plt.savefig(filename)

    # Decoder outputs over a regular latent grid.
    filename = os.path.join(model_name, "digits_over_latent.png")
    n = 30
    digit_size = 28
    grid_x = np.linspace(-4, 4, n)
    grid_y = np.linspace(-4, 4, n)[::-1]  # top row of the mosaic = largest z[1]

    # Decode all n*n grid points in one batched predict() instead of n*n
    # single-sample calls. Row-major order: row i*n + j is (grid_x[j], grid_y[i]).
    zx, zy = np.meshgrid(grid_x, grid_y)
    z_grid = np.column_stack([zx.ravel(), zy.ravel()])
    decoded = decoder.predict(z_grid, batch_size=batch_size)
    # (n*n, 784) -> (n, n, 28, 28) -> (n, 28, n, 28) -> (n*28, n*28) mosaic,
    # so tile (i, j) lands at rows i*28.., cols j*28...
    figure = (decoded.reshape(n, n, digit_size, digit_size)
              .transpose(0, 2, 1, 3)
              .reshape(n * digit_size, n * digit_size))

    plt.figure(figsize=(10, 10))
    start_range = digit_size // 2
    end_range = n * digit_size + start_range + 1
    pixel_range = np.arange(start_range, end_range, digit_size)
    sample_range_x = np.round(grid_x, 1)
    sample_range_y = np.round(grid_y, 1)
    plt.xticks(pixel_range, sample_range_x)
    plt.yticks(pixel_range, sample_range_y)
    plt.xlabel("z[0]")
    plt.ylabel("z[1]")
    plt.imshow(figure, cmap='Greys_r')
    plt.savefig(filename)

def sample_images(g_model, latent_dim, r=5, c=5):
    """Show an ``r`` x ``c`` grid of generator samples from N(0, I) latent noise.

    ``g_model`` is expected to come from ``define_generator``, whose last
    layer is a sigmoid. Its outputs are therefore already in [0, 1]; no
    tanh-style ``0.5 * x + 0.5`` rescale is applied. Images are drawn on a
    fixed [0, 1] gray scale so intensities are comparable across samples.
    """
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = g_model.predict(noise).reshape(-1, 28, 28)

    # squeeze=False keeps ``axs`` 2-D even when r or c is 1.
    _, axs = plt.subplots(r, c, squeeze=False)
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img, cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    plt.show()
    


In [None]:
def _build_vae(e_model, g_model):
    """Chain ``e_model`` (encoder) and ``g_model`` (decoder) into a compiled VAE.

    The loss (scaled reconstruction BCE + KL term) is attached with
    ``add_loss``, so the returned model is trained via
    ``train_on_batch(x, None)``. It reuses, and therefore updates, the
    weights of both sub-models.
    """
    inputs = Input(shape=(784,))
    z_mean, z_log_var, z = e_model(inputs)
    outputs = g_model(z)
    vae = Model(inputs, outputs)
    # Keras' binary_crossentropy averages over the last axis; multiply by the
    # pixel count to get a per-image sum on the same scale as the KL term.
    reconstruction_loss = binary_crossentropy(inputs, outputs) * 784
    # KL(N(z_mean, exp(z_log_var)) || N(0, I)), summed over latent dims.
    kl_loss = -0.5 * K.sum(
        1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    vae.add_loss(K.mean(reconstruction_loss + kl_loss))
    vae.compile(optimizer=Adam(0.001, 0.5))
    return vae


def _show_reconstructions(vae, x_test, n=10):
    """Plot the first ``n`` test images (top row) above their VAE reconstructions."""
    originals = x_test[:n]
    # Only the displayed images are run through the model.
    decoded = vae.predict(originals)
    plt.figure(figsize=(20, 4))
    for k in range(n):
        for row, img in ((0, originals[k]), (1, decoded[k])):
            ax = plt.subplot(2, n, row * n + k + 1)
            ax.imshow(img.reshape(28, 28), cmap='gray')
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.show()


def train(e_model, g_model, d_model, gan_model, latent_dim, dataset, n_epochs=200, n_batch=128):
    """Alternate VAE, discriminator and generator updates for ``n_epochs`` rounds.

    Each round runs:
      * 6 VAE steps on ``n_batch`` real images from ``dataset``;
      * 4 discriminator steps on ``n_batch // 2`` images each: real -> 1,
        generator samples from random latent points -> 0, VAE
        reconstructions -> 0;
      * 2 generator steps through ``gan_model`` (target 1) on
        ``2 * n_batch`` inputs: random latent points, then sampled encoder
        codes ``z`` of real images.
    Every 10th round shows generator samples and test-set reconstructions.
    Afterwards the generator is saved to ``cgan_generator.h5`` and
    ``plot_results`` writes latent-space plots.

    Returns ``g_model``.
    """
    half_batch = n_batch // 2
    vae = _build_vae(e_model, g_model)
    # Only the test split is needed here (monitoring and final plots);
    # training batches are drawn from ``dataset``.
    _, _, x_test, y_test = load_real_samples()

    for epoch in range(n_epochs):
        # 1) VAE: reconstruction + KL on real images.
        for _ in range(6):
            x_batch, _ = generate_real_samples(dataset, n_batch)
            vae_loss = vae.train_on_batch(x_batch, None)

        # 2) Discriminator: real vs. prior samples vs. reconstructions.
        for _ in range(4):
            x_real, y_real = generate_real_samples(dataset, half_batch)
            x_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            x_recon = vae.predict(x_real)
            d_loss1 = d_model.train_on_batch(x_real, y_real)
            d_loss2 = d_model.train_on_batch(x_fake, y_fake)
            d_loss3 = d_model.train_on_batch(x_recon, y_fake)

        # 3) Generator through the frozen discriminator.
        x_gan = generate_latent_points(latent_dim, n_batch * 2)
        y_gan = ones((n_batch * 2, 1))
        g_loss1 = gan_model.train_on_batch(x_gan, y_gan)
        x_real, _ = generate_real_samples(dataset, n_batch * 2)
        # predict() returns numpy [z_mean, z_log_var, z]; feed the sampled z.
        # (Calling e_model(x_real) on a numpy array does not yield codes.)
        x_code = e_model.predict(x_real)[2]
        g_loss2 = gan_model.train_on_batch(x_code, y_gan)

        print('>%d, d1=%.3f, d2=%.3f, d3=%.3f, vae=%.3f, g1=%.3f, g2=%.3f' %
              (epoch, d_loss1, d_loss2, d_loss3, vae_loss, g_loss1, g_loss2))
        if epoch % 10 == 0:
            sample_images(g_model, latent_dim)
            _show_reconstructions(vae, x_test)

    g_model.save('cgan_generator.h5')

    plot_results(e_model, g_model, x_test, y_test)
    return g_model
            
# Driver: load MNIST, build the encoder / discriminator / generator / GAN
# stack, then run the VAE-GAN training loop.
latent_dim = 2  # 2-D latent space, which plot_results draws directly
dataset, _, x_test, y_test = load_real_samples()
encoder = define_encoder(latent_dim)
discriminator = define_discriminator()
generator = define_generator(latent_dim)
gan_model = define_gan(generator, discriminator)
train(encoder, generator, discriminator, gan_model,
      latent_dim=latent_dim, dataset=dataset)

  'be expecting any data to be passed to {0}.'.format(name))
