In [None]:
import numpy as np
from keras import Sequential
from keras.datasets.mnist import load_data
from keras.layers import Conv2D, LeakyReLU, Dropout, Flatten, Dense, Reshape, Conv2DTranspose
from keras.optimizers import Adam
from matplotlib import pyplot
from numpy import zeros, ones, vstack, expand_dims
from numpy.random import randn, randint
import os
import tensorflow as tf

In [None]:
def load_real_samples():
    """Return the MNIST training images as float32 scaled to [0, 1].

    The test split and all labels are discarded. A trailing channel axis of
    size 1 is appended so the result can be fed straight into the Conv2D
    discriminator (expected shape (N, 28, 28, 1) for Keras' MNIST).
    """
    (images, _), _ = load_data()
    with_channel = expand_dims(images, axis=-1)
    return with_channel.astype('float32') / 255.0
trainX = load_real_samples()

In [None]:
# Every figure and model checkpoint below is written under saved/ (relative to the
# kernel's working directory). exist_ok avoids the separate isdir check.
os.makedirs("saved/", exist_ok=True)
# Shown as the cell output so the reader knows where saved/ was created.
os.getcwd()

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# Applied to every visible GPU (TF requires the setting to be consistent across
# GPUs). On a CPU-only machine the list is empty and the loop is a no-op, whereas
# indexing [0] would raise IndexError.
physical_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

In [None]:
# Preview the first ten training digits in a single row and save the figure.
for panel in range(1, 11):
    pyplot.subplot(1, 10, panel)
    pyplot.axis('off')
    pyplot.imshow(trainX[panel - 1, :, :, 0], cmap='gray_r')
pyplot.savefig("saved/digits_example.png", bbox_inches='tight')
pyplot.close()

In [None]:
def define_disc(in_shape=(28,28,1)):
    """Build and compile the discriminator.

    Architecture: two 3x3, stride-2 Conv2D layers with 64 filters. Each is
    followed by LeakyReLU(0.2) and Dropout(0.4). The output is flattened into a
    single sigmoid unit (real = 1, fake = 0).

    Compiled with binary cross-entropy and Adam(learning_rate=0.0002,
    beta_1=0.5), and reports accuracy.

    Parameters
    ----------
    in_shape : shape of one input image, default (28, 28, 1).
    """
    model = Sequential([
        Conv2D(64, (3,3), strides=(2,2), padding='same', input_shape=in_shape),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Conv2D(64, (3,3), strides=(2,2), padding='same'),
        LeakyReLU(alpha=0.2),
        Dropout(0.4),
        Flatten(),
        Dense(1, activation='sigmoid'),
    ])
    opt = Adam(learning_rate=0.0002, beta_1=0.5)
    model.compile(loss='binary_crossentropy', optimizer=opt, metrics=['accuracy'])
    return model

In [None]:
# Build a standalone discriminator only to inspect its layer stack.
model = define_disc(in_shape=(28, 28, 1))
model.summary()

In [None]:
def define_gen(latent_dim):
    """Build the (uncompiled) generator: latent vector -> 28x28x1 image.

    A Dense layer projects the latent vector to a 7x7x128 feature map. Two
    stride-2 transposed convolutions then upsample it 7 -> 14 -> 28. A final 7x7
    sigmoid Conv2D produces one channel in [0, 1], the same range as the
    scaled training images.

    The model is not compiled here; it is trained through the stacked GAN model.

    Parameters
    ----------
    latent_dim : length of the input latent vector.
    """
    stack = [
        Dense(128*7*7, input_dim=latent_dim),
        LeakyReLU(0.2),
        Reshape((7,7,128)),
    ]
    for _ in range(2):
        stack.append(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same'))
        stack.append(LeakyReLU(0.2))
    stack.append(Conv2D(1, (7,7), activation='sigmoid', padding='same'))
    return Sequential(stack)

In [None]:
# Standalone generator with a 100-dimensional latent space, built for inspection.
gen_model = define_gen(latent_dim=100)
gen_model.summary()

In [None]:
def generate_latent_points(latent_dim, n):
    """Draw ``n`` latent vectors from a standard normal distribution.

    Uses NumPy's global RNG. Returns an array of shape ``(n, latent_dim)``.
    """
    return np.random.randn(latent_dim * n).reshape(n, latent_dim)

In [None]:
def generate_fake_samples(g_model, latent_dim, n_samples):
    """Generate ``n_samples`` images with ``g_model`` and label them 0 ("fake").

    Parameters
    ----------
    g_model : generator model mapping ``(n, latent_dim)`` latent vectors to images.
    latent_dim : length of each latent vector.
    n_samples : number of images to generate.

    Returns
    -------
    (X, y) : the generated images and an ``(n_samples, 1)`` array of zeros.
    """
    x_input = generate_latent_points(latent_dim, n_samples)
    # verbose=0: this is called on every training step. Keras' default progress
    # bar in interactive sessions would flood the notebook output.
    X = g_model.predict(x_input, verbose=0)
    y = np.zeros((n_samples, 1))
    return X, y

In [None]:
def generate_real_samples(n_samples, dataset=None):
    """Sample ``n_samples`` real images (with replacement) and label them 1.

    Parameters
    ----------
    n_samples : number of images to draw.
    dataset : array of images indexed along axis 0. Defaults to the
        notebook-level ``trainX``. Passing it explicitly avoids relying on that
        hidden global.

    Returns
    -------
    (X, y) : the sampled images and an ``(n_samples, 1)`` array of ones.
    """
    if dataset is None:
        dataset = trainX
    ix = randint(0, dataset.shape[0], n_samples)
    X = dataset[ix]
    y = ones((n_samples, 1))
    return X, y


In [None]:
# Outputs of the untrained generator: 25 are generated, the first 10 are shown.
generatedX, _ = generate_fake_samples(gen_model, 100, 25)
for col, image in enumerate(generatedX[:10], start=1):
    pyplot.subplot(1, 10, col)
    pyplot.axis('off')
    pyplot.imshow(image[:, :, 0], cmap='gray_r')

In [None]:
def define_gan(generator, discriminator):
    """Stack ``generator`` -> ``discriminator`` into one model used to train the generator.

    The discriminator is marked non-trainable before the stack is compiled, so
    the combined model only updates generator weights. Loss is binary
    cross-entropy with Adam(learning_rate=0.0002, beta_1=0.5).
    """
    # NOTE(review): the standalone discriminator was compiled (in define_disc)
    # before this flag flip. This relies on Keras capturing `trainable` at compile
    # time so the discriminator still trains on its own (tf.keras 2.x behaviour).
    # Confirm on the installed Keras version.
    discriminator.trainable = False
    gan = Sequential([generator, discriminator])
    gan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')
    return gan

In [None]:
def train_gan(g_model, d_model, gan_model, latent_dim, epochs=100, batch_size=256):
    """Alternately train the discriminator and the generator on ``trainX``.

    Each epoch has ``len(trainX) // batch_size`` steps. In each step:

    1. The discriminator is updated on ``batch_size // 2`` real images
       (label 1) plus ``batch_size // 2`` generated images (label 0).
    2. The generator is updated through ``gan_model`` on ``batch_size`` latent
       vectors labelled 1, i.e. it is rewarded for fooling the discriminator.

    Parameters
    ----------
    g_model : generator model (latent vector -> image).
    d_model : compiled discriminator with an accuracy metric.
    gan_model : compiled generator+discriminator stack.
    latent_dim : length of the latent vectors fed to the generator.
    epochs : number of epochs.
    batch_size : samples per discriminator and per generator update.

    Returns
    -------
    (acc_real_arr, acc_fake_arr, d_loss_arr, g_loss_arr)
        Discriminator accuracy on real and fake samples, measured at the start
        of every epoch. Also the mean discriminator and generator loss over
        every epoch.
    """
    # At least one step per epoch. Real samples are drawn with replacement, so
    # this stays valid even when the dataset is smaller than one batch.
    batch_per_epo = max(1, trainX.shape[0] // batch_size)
    # Each discriminator batch is half real and half fake, so it must be derived
    # from batch_size, not from the number of batches per epoch.
    half_batch = max(1, batch_size // 2)
    acc_real_arr = []
    acc_fake_arr = []
    d_loss_arr = []
    g_loss_arr = []

    for i in range(epochs):
        # Evaluated before this epoch's updates: entry i reflects i completed epochs.
        acc_real, acc_fake = summarize_performance(i, g_model, d_model, latent_dim)
        acc_real_arr.append(acc_real)
        acc_fake_arr.append(acc_fake)

        d_epoch_losses = []
        g_epoch_losses = []
        for _ in range(batch_per_epo):
            # Discriminator step on a mixed real/fake batch.
            realX, realy = generate_real_samples(half_batch)
            fakeX, fakey = generate_fake_samples(g_model, latent_dim, half_batch)
            X, y = vstack((realX, fakeX)), vstack((realy, fakey))
            d_loss, _ = d_model.train_on_batch(X, y)
            # Generator step: label generated images as real so the generator
            # learns to fool the discriminator.
            X_gan = generate_latent_points(latent_dim, batch_size)
            y_gan = ones((batch_size, 1))
            g_loss = gan_model.train_on_batch(X_gan, y_gan)
            d_epoch_losses.append(d_loss)
            g_epoch_losses.append(g_loss)

        # Record the epoch mean rather than the single (noisy) first-batch loss.
        d_mean = float(np.mean(d_epoch_losses))
        g_mean = float(np.mean(g_epoch_losses))
        d_loss_arr.append(d_mean)
        g_loss_arr.append(g_mean)
        print('>%d, %d/%d, d=%.3f, g=%.3f' % (i+1, batch_per_epo, batch_per_epo, d_mean, g_mean))

    return acc_real_arr, acc_fake_arr, d_loss_arr, g_loss_arr

In [None]:
def summarize_performance(epoch, g_model, d_model, latent_dim, n_samples=100):
    """Report discriminator accuracy and periodically checkpoint the generator.

    The discriminator is evaluated on ``n_samples`` real images and
    ``n_samples`` freshly generated images, and both accuracies are printed.
    Every 20th epoch (including epoch 0), the generated images are plotted via
    ``save_plot`` and the generator is saved to
    ``saved/generator_model_<epoch>.h5``.

    Returns
    -------
    (acc_real, acc_fake) : discriminator accuracy on the real and fake samples.
    """
    real_images, real_labels = generate_real_samples(n_samples)
    _, acc_real = d_model.evaluate(real_images, real_labels, verbose=0)
    fake_images, fake_labels = generate_fake_samples(g_model, latent_dim, n_samples)
    _, acc_fake = d_model.evaluate(fake_images, fake_labels, verbose=0)
    print('>Accuracy real: %.0f%%, fake: %.0f%%' % (acc_real*100, acc_fake*100))

    is_checkpoint_epoch = (epoch % 20 == 0)
    if is_checkpoint_epoch:
        checkpoint_path = 'saved/generator_model_%03d.h5' % (epoch)
        save_plot(fake_images, epoch)
        g_model.save(checkpoint_path)

    return acc_real, acc_fake

In [None]:
def save_plot(examples, epoch):
    """Save up to 10 images from ``examples`` in one row to saved/generated_plot_e<epoch>.png.

    Parameters
    ----------
    examples : image batch indexed as ``examples[i, :, :, 0]``, e.g. generator output.
        If it holds fewer than 10 images, only those are drawn.
    epoch : epoch number, embedded zero-padded (3 digits) in the file name.
    """
    n_show = min(10, len(examples))
    # A dedicated figure, so anything left on pyplot's current figure cannot leak
    # into the saved image.
    fig = pyplot.figure()
    for i in range(n_show):
        ax = fig.add_subplot(1, 10, 1 + i)
        ax.axis('off')
        ax.imshow(examples[i, :, :, 0], cmap='gray_r')
    filename = 'saved/generated_plot_e%03d.png' % (epoch)
    fig.savefig(filename, bbox_inches='tight')
    pyplot.close(fig)

In [None]:
# Seed Python, NumPy and the TF/Keras backend so weight initialisation, latent
# sampling and batch selection are repeatable. GPU kernels may still be
# nondeterministic. tf.keras.utils.set_random_seed requires TF >= 2.7.
SEED = 1
tf.keras.utils.set_random_seed(SEED)

latent_dim = 100
EPOCHS = 200  # the final checkpoint/plot cells below use file names ending in 200
d_model = define_disc()
g_model = define_gen(latent_dim)
gan_model = define_gan(g_model, d_model)
metrics = train_gan(g_model, d_model, gan_model, latent_dim, epochs=EPOCHS)

In [None]:
# Discriminator accuracy on real and generated images, measured at the start of
# each epoch by train_gan. The dashed 0.5 line is the "ideal" point where the
# discriminator can no longer tell the two apart.
fig, ax = pyplot.subplots()
ax.plot(metrics[0], label='accuracy real')
ax.plot(metrics[1], label='accuracy fake')
ax.axhline(0.5, color='gray', linestyle='--', label='ideal accuracy')
ax.set_title("Discriminator accuracy")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.legend()
fig.savefig('saved/acc_plot')
pyplot.close(fig)

In [None]:
# Discriminator and generator losses, recorded once per epoch by train_gan.
fig, ax = pyplot.subplots()
ax.plot(metrics[2], label='discriminator')
ax.plot(metrics[3], label='generator')
ax.set_title("GAN training loss")
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.legend()
fig.savefig('saved/loss_plot')
pyplot.close(fig)

In [None]:
# Final generator checkpoint after the 200-epoch training run.
final_g_path = 'saved/generator_model_200.h5'
g_model.save(final_g_path)

In [None]:
# Sample 10 digits from the trained generator. Reuse save_plot so the final
# snapshot (saved/generated_plot_e200.png) matches the per-epoch ones, and use
# latent_dim instead of repeating the magic 100.
generatedX, _ = generate_fake_samples(g_model, latent_dim, 10)
save_plot(generatedX, 200)

In [None]:
# Persist the trained discriminator under saved/.
final_d_path = 'saved/discriminator_model_200.h5'
d_model.save(final_d_path)