In [9]:
from numpy import zeros, ones, hstack
from numpy.random import rand, randn, seed
import keras
from keras.models import Sequential
from keras.layers import Dense
from matplotlib import pyplot

In [2]:
# define generator model

def define_generator(latent_dim, n_outputs=2):
    """Build the (uncompiled) generator network.

    A latent vector of length ``latent_dim`` goes through one 15-unit ReLU
    hidden layer (He-uniform initialised) and then a linear layer with
    ``n_outputs`` units.

    Args:
        latent_dim: length of each latent input vector.
        n_outputs: number of values emitted per sample (default 2).

    Returns:
        A ``Sequential`` model named 'Generator_Model'. It is not compiled
        here; it is only trained as part of the combined model built by
        ``define_gan``.
    """
    hidden_layer = Dense(15,
                         activation='relu',
                         kernel_initializer='he_uniform',
                         input_dim=latent_dim)
    output_layer = Dense(n_outputs, activation='linear')

    return Sequential([hidden_layer, output_layer], name='Generator_Model')

In [17]:
# define standalone discriminator model

def define_discriminator_model(n_inputs=2):
    """Build and compile the standalone real/fake discriminator.

    One 25-unit ReLU hidden layer (He-uniform initialised) feeds a single
    sigmoid unit that scores how "real" an input point looks.

    Args:
        n_inputs: number of features per input sample (default 2).

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy, the Adam
        optimizer and an accuracy metric, so ``evaluate`` returns
        ``[loss, accuracy]``.
    """
    discriminator = Sequential([
        Dense(25, activation='relu', kernel_initializer='he_uniform',
              input_dim=n_inputs),
        Dense(1, activation='sigmoid'),
    ])

    discriminator.compile(optimizer='adam',
                          loss='binary_crossentropy',
                          metrics=['accuracy'])
    return discriminator

In [4]:
# define the combined generator and discriminator model, for updating the generator

def define_gan(generator, discriminator):
    """Chain generator -> discriminator into the model that trains the generator.

    Side effect: sets ``discriminator.trainable = False`` on the object passed
    in, before compiling. The combined model is therefore compiled with the
    discriminator frozen and only updates the generator. The flag stays False
    on the caller's discriminator after this returns.
    NOTE(review): depending on the Keras version, that lingering flag can also
    stop ``discriminator.train_on_batch`` from updating weights; callers that
    train the discriminator on its own should set ``trainable = True`` first.

    Args:
        generator: model mapping latent vectors to samples.
        discriminator: model mapping samples to a real/fake score.

    Returns:
        A ``Sequential`` model compiled with binary cross-entropy and Adam.
    """
    # freeze before compile() so the combined model is built without the
    # discriminator's weights in its trainable set
    discriminator.trainable = False

    combined = Sequential([generator, discriminator])
    combined.compile(optimizer='adam', loss='binary_crossentropy')
    return combined

In [5]:
# generate n real samples with class labels

def generate_real_samples(n):
    """Sample ``n`` "real" points from the curve y = x**2.

    x is drawn uniformly from [-0.5, 0.5) using NumPy's global RNG.

    Args:
        n: number of samples.

    Returns:
        X: array of shape (n, 2) whose columns are [x, x**2].
        y: array of shape (n, 1) filled with ones (label 1 = real).
    """
    x_col = (rand(n) - 0.5).reshape(n, 1)
    samples = hstack((x_col, x_col * x_col))
    labels = ones((n, 1))
    return samples, labels

In [6]:
# generate points in latent space as input for the generator

def generate_latent_points(latent_dim, n):
    """Draw ``n`` latent vectors from a standard normal distribution.

    Uses NumPy's global RNG (``randn``).

    Args:
        latent_dim: length of each latent vector.
        n: number of vectors.

    Returns:
        Array of shape (n, latent_dim).
    """
    return randn(latent_dim * n).reshape(n, latent_dim)

In [26]:
# use the generator to generate n fake examples with "fake" class labels

def generate_fake_samples(gen, latent_dim, n):
    """Generate ``n`` samples with the generator and label them 0 (fake).

    Args:
        gen: generator model; must accept input of shape (n, latent_dim).
        latent_dim: length of each latent vector fed to ``gen``.
        n: number of samples to produce.

    Returns:
        X: the generator's predictions for ``n`` random latent points.
        y: array of shape (n, 1) filled with zeros (label 0 = fake).
    """
    x_input = generate_latent_points(latent_dim, n)

    # verbose=0: this is called twice per training iteration; recent Keras
    # versions otherwise print a progress bar on every call and flood the
    # notebook output.
    X = gen.predict(x_input, verbose=0)

    y = zeros((n, 1))
    return X, y

In [27]:
# evaluate the discriminator on fresh real and fake points (optionally plot them)

def summarize_performance(epoch, gen, dis, latent_dim, n=100, plot=False):
    """Report the discriminator's accuracy on fresh real and generated samples.

    Prints ``epoch``, the accuracy on ``n`` real samples and the accuracy on
    ``n`` generated samples, separated by spaces.

    Args:
        epoch: label printed in front of the two accuracies.
        gen: generator model used to produce the fake samples.
        dis: compiled discriminator whose ``evaluate`` returns
            ``[loss, accuracy]``.
        latent_dim: length of the generator's latent input.
        n: number of real samples and of fake samples to draw (default 100).
        plot: if True, also scatter-plot the real (red) and generated (blue)
            points. Defaults to False, which keeps the text-only output.
    """
    x_real, y_real = generate_real_samples(n)
    _, acc_real = dis.evaluate(x_real, y_real, verbose=0)

    x_fake, y_fake = generate_fake_samples(gen, latent_dim, n)
    _, acc_fake = dis.evaluate(x_fake, y_fake, verbose=0)

    print(epoch, acc_real, acc_fake)

    if plot:
        fig, ax = pyplot.subplots()
        ax.scatter(x_real[:, 0], x_real[:, 1], color='red', label='real')
        ax.scatter(x_fake[:, 0], x_fake[:, 1], color='blue', label='generated')
        ax.set(title=f'Real vs. generated samples (epoch {epoch})',
               xlabel='x', ylabel='x**2')
        ax.legend()
        pyplot.show()
        # release the figure; this may be called repeatedly during training
        pyplot.close(fig)

In [28]:
# train the generator and discriminator

def train_gan(gen, dis, gan, latent_dim, n_epochs=10000, n_batch=128, n_eval=2000):
    """Alternate discriminator and generator updates.

    Each iteration:
      1. trains ``dis`` on ``n_batch // 2`` real samples (label 1), then on
         ``n_batch // 2`` generated samples (label 0);
      2. trains ``gan`` (generator -> frozen discriminator) on ``n_batch``
         latent points labelled 1, pushing the generator towards outputs the
         discriminator classifies as real.
    After every ``n_eval`` iterations, ``summarize_performance`` is called with
    the zero-based iteration index.

    Args:
        gen: generator model.
        dis: compiled standalone discriminator.
        gan: combined model from ``define_gan(gen, dis)``.
        latent_dim: length of the generator's latent input.
        n_epochs: number of iterations; each processes one batch, despite
            the name.
        n_batch: samples per generator update; the discriminator sees half
            this many real and half this many fake samples per iteration.
        n_eval: evaluation period in iterations; 0 disables evaluation.
    """
    half_batch = n_batch // 2

    for i in range(n_epochs):
        x_real, y_real = generate_real_samples(half_batch)
        x_fake, y_fake = generate_fake_samples(gen, latent_dim, half_batch)

        # define_gan() leaves dis.trainable = False. With Keras 3 that flag is
        # stored on the discriminator's variables, so leaving it False would
        # also stop dis.train_on_batch from updating them. Set it explicitly
        # for each phase: trainable for the discriminator's own updates...
        dis.trainable = True
        dis.train_on_batch(x_real, y_real)
        dis.train_on_batch(x_fake, y_fake)

        x_gan = generate_latent_points(latent_dim, n_batch)
        y_gan = ones((n_batch, 1))

        # ...and frozen while the combined model updates the generator.
        dis.trainable = False
        gan.train_on_batch(x_gan, y_gan)

        if n_eval and (i + 1) % n_eval == 0:
            summarize_performance(i, gen, dis, latent_dim)

In [None]:
# Seed every RNG up front so Restart Kernel -> Run All reproduces the same run.
# (Bit-exact results on GPU may additionally need backend determinism flags.)
SEED = 1
if hasattr(keras.utils, "set_random_seed"):
    # Keras >= 2.7: seeds Python's `random`, NumPy and the backend in one call
    keras.utils.set_random_seed(SEED)
else:
    # older Keras: at least make the NumPy-driven sampling reproducible
    seed(SEED)

# length of the generator's latent input vector
latent_dim = 5

discriminator = define_discriminator_model()
generator = define_generator(latent_dim)

gan_model = define_gan(generator, discriminator)

# train gan_model
train_gan(generator, discriminator, gan_model, latent_dim)

