In [1]:
import numpy as np
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose
from keras.layers import Activation, LeakyReLU, BatchNormalization
from keras.initializers import RandomNormal
import matplotlib.pyplot as plt

In [2]:
def def_discriminator(im_shape=(28,28,1)):
  """Build and compile the discriminator.

  Two strided 4x4 convolutions (64 then 128 filters), each followed by batch
  normalization and LeakyReLU(0.2), downsample the image. A single linear
  unit then produces an unbounded score.

  Args:
    im_shape: shape of one input image, (height, width, channels).

  Returns:
    A keras ``Sequential`` model compiled with MSE loss and
    Adam(learning_rate=0.0002, beta_1=0.5).
  """
  weight_init = RandomNormal(stddev=0.02)
  disc = Sequential()
  for stage, n_filters in enumerate((64, 128)):
    # Only the very first layer is told the input shape.
    shape_kwargs = {'input_shape': im_shape} if stage == 0 else {}
    disc.add(Conv2D(n_filters, (4,4), strides=(2,2), padding='same',
                    kernel_initializer=weight_init, **shape_kwargs))
    disc.add(BatchNormalization())
    disc.add(LeakyReLU(alpha=0.2))

  disc.add(Flatten())
  # Linear (unsquashed) output paired with MSE loss gives the least-squares GAN objective.
  disc.add(Dense(1, activation='linear', kernel_initializer=weight_init))

  disc.compile(loss='mse', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
  return disc

In [3]:
def def_generator(latent_dim):
  """Build the (uncompiled) generator that maps latent vectors to images.

  A dense layer expands the latent vector into a 7x7x256 feature map. Two
  stride-2 transposed convolutions (128 then 64 filters) upsample it to
  14x14 and then 28x28. A final 7x7 convolution with tanh yields a
  single-channel image.

  Args:
    latent_dim: length of each input latent vector.

  Returns:
    An uncompiled keras ``Sequential`` model. It is trained only through the
    combined model built by ``def_gan``.
  """
  weight_init = RandomNormal(stddev=0.02)
  seed_shape = (7, 7, 256)
  gen = Sequential()
  gen.add(Dense(seed_shape[0] * seed_shape[1] * seed_shape[2],
                kernel_initializer=weight_init, input_dim=latent_dim))
  gen.add(BatchNormalization())
  gen.add(Activation('relu'))
  gen.add(Reshape(seed_shape))

  # Each stage doubles the spatial size: 7 -> 14 -> 28.
  for n_filters in (128, 64):
    gen.add(Conv2DTranspose(n_filters, (4,4), strides=(2,2), padding='same',
                            kernel_initializer=weight_init))
    gen.add(BatchNormalization())
    gen.add(Activation('relu'))

  gen.add(Conv2D(1, (7,7), padding='same', kernel_initializer=weight_init))
  # tanh bounds pixels to [-1, 1], matching the scaling in load_real_samples.
  gen.add(Activation('tanh'))
  return gen

In [4]:
def def_gan(generator, discriminator):
  """Stack generator and discriminator into the model used to train the generator.

  Side effect: sets ``discriminator.trainable = False``. The combined model
  then only updates generator weights, while the discriminator keeps its own
  earlier compile for its standalone updates.

  Args:
    generator: model mapping latent vectors to images.
    discriminator: model scoring images.

  Returns:
    A ``Sequential`` of [generator, discriminator] compiled with MSE loss and
    Adam(learning_rate=0.0002, beta_1=0.5).
  """
  discriminator.trainable = False
  combined = Sequential([generator, discriminator])
  combined.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='mse')
  return combined

In [5]:
def load_real_samples():
  """Load the MNIST training images, scaled for a tanh-output generator.

  Returns:
    float32 array with a trailing channel axis added, pixel values mapped
    from [0, 255] to [-1, 1].
  """
  (train_images, _), _ = load_data()
  pixels = train_images[..., np.newaxis].astype('float32')
  return (pixels - 127.5) / 127.5

In [6]:
def generate_real_samples(dataset, n_samples):
  """Draw ``n_samples`` random rows of ``dataset`` (with replacement), labelled real.

  Returns:
    (X, y): the selected rows and a (n_samples, 1) array of ones.
  """
  picks = np.random.randint(0, dataset.shape[0], n_samples)
  return dataset[picks], np.ones((n_samples, 1))

In [7]:
def generate_latent_points(latent_dim, n_samples):
  """Sample latent vectors for the generator from a standard normal distribution.

  The original used ``np.random.rand`` (uniform on [0, 1)). That confines
  every latent point to the positive orthant, unlike the Gaussian prior GAN
  generators conventionally sample from.

  Args:
    latent_dim: length of each latent vector.
    n_samples: number of vectors to draw.

  Returns:
    float array of shape (n_samples, latent_dim).
  """
  return np.random.randn(n_samples, latent_dim)

In [8]:
def generate_fake_samples(generator, latent_dim, n_samples):
  """Generate ``n_samples`` images with ``generator`` and label them fake.

  Args:
    generator: keras model mapping latent vectors to images.
    latent_dim: length of each latent vector.
    n_samples: number of images to generate.

  Returns:
    (X, y): the generated images and a (n_samples, 1) array of zeros.
  """
  x_input = generate_latent_points(latent_dim, n_samples)
  # verbose=0 suppresses the per-call progress bar. train() calls this once
  # per step, so the default verbosity floods the notebook output.
  X = generator.predict(x_input, verbose=0)
  y = np.zeros((n_samples, 1))
  return X, y

In [9]:
def summarize_performance(epoch, g_model, latent_dim, n_samples=100):
  """Save a grid of generated images and a checkpoint of the generator.

  Args:
    epoch: zero-based counter used in the output filenames (``train`` passes
      the step index); files are numbered ``epoch + 1``.
    g_model: generator model to sample from and save.
    latent_dim: length of each latent vector.
    n_samples: number of images to draw. They are laid out in a near-square
      grid (10x10 for the default 100).

  Side effects:
    Writes 'generated_plot_%06d.png' and 'model_%06d.h5' to the working
    directory and prints both filenames.
  """
  X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
  # The generator ends in tanh, so outputs lie in [-1, 1]. Map to [0, 1].
  # (The original `X + 1 / 2` evaluated as `X + 0.5`.)
  X = (X + 1) / 2.0

  # Derive the grid from n_samples instead of hard-coding 10x10, which
  # raised for n_samples > 100.
  n_cols = max(1, int(np.ceil(np.sqrt(n_samples))))
  n_rows = max(1, -(-n_samples // n_cols))

  fig = plt.figure()
  for i in range(n_samples):
    ax = fig.add_subplot(n_rows, n_cols, 1 + i)
    ax.axis('off')
    # Fixed [0, 1] range so tiles share one intensity scale instead of each
    # being auto-normalized to its own min/max.
    ax.imshow(X[i, :, :, 0], cmap='gray_r', vmin=0.0, vmax=1.0)
  filename1 = 'generated_plot_%06d.png' % (epoch + 1)
  fig.savefig(filename1)
  plt.close(fig)

  filename2 = 'model_%06d.h5' % (epoch + 1)
  g_model.save(filename2)
  print('saved %s and %s' % (filename1, filename2))

In [10]:
def plot_history(d1_hist, d2_hist, g_hist):
  """Plot the three loss histories and save them to 'plot_loss.png'.

  Args:
    d1_hist: discriminator losses on real batches, one value per step.
    d2_hist: discriminator losses on fake batches, one value per step.
    g_hist: combined-model (generator) losses, one value per step.
  """
  # Use a dedicated figure rather than the implicit pyplot "current figure",
  # so leftover plotting state from elsewhere cannot leak into this plot.
  fig, ax = plt.subplots()
  ax.plot(d1_hist, label='d1loss')
  ax.plot(d2_hist, label='d2loss')
  ax.plot(g_hist, label='gan loss')
  ax.set_xlabel('training step')
  ax.set_ylabel('loss')
  ax.set_title('GAN training losses')
  ax.legend()
  filename = 'plot_loss.png'
  fig.savefig(filename)
  plt.close(fig)
  print('saved %s' % (filename))

In [11]:
def train(g_model, d_model, gan_model, dataset, latent_dim, n_epochs=20, n_batch=64):
  """Alternate discriminator and generator updates for ``n_epochs`` epochs.

  Each step:
    1. trains ``d_model`` on half a batch of real images (label 1), then on
       half a batch of generated images (label 0);
    2. trains ``gan_model`` on a full batch of latent points labelled 1, which
       updates the generator.
  After every epoch's worth of steps, ``summarize_performance`` saves sample
  images and a generator checkpoint. The loss history plot is written once,
  after training finishes.

  Args:
    g_model, d_model, gan_model: generator, discriminator and combined model.
    dataset: array of real images indexed along axis 0.
    latent_dim: length of each latent vector.
    n_epochs: number of passes' worth of steps over ``dataset``.
    n_batch: full batch size; each discriminator update uses half of it.
  """
  bat_per_epoch = dataset.shape[0] // n_batch
  n_steps = bat_per_epoch * n_epochs
  half_batch = n_batch // 2
  d1_hist, d2_hist, g_hist = [], [], []
  # The generator's target labels never change, so build them once.
  y_gan = np.ones((n_batch, 1))

  for i in range(n_steps):
    X_real, y_real = generate_real_samples(dataset, half_batch)
    X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)

    d_loss1 = d_model.train_on_batch(X_real, y_real)
    d_loss2 = d_model.train_on_batch(X_fake, y_fake)

    g_input = generate_latent_points(latent_dim, n_batch)
    g_loss = gan_model.train_on_batch(g_input, y_gan)

    d1_hist.append(d_loss1)
    d2_hist.append(d_loss2)
    g_hist.append(g_loss)

    if (i + 1) % bat_per_epoch == 0:
      summarize_performance(i, g_model, latent_dim)

  # Previously indented inside the loop, which re-rendered and re-saved the
  # plot on every single step. Once, after training, is enough.
  plot_history(d1_hist, d2_hist, g_hist)

In [None]:
import time

# Build the models, load MNIST and run the full training loop.
latent_dim = 100
discriminator = def_discriminator()
generator = def_generator(latent_dim)
gan_model = def_gan(generator, discriminator)
dataset = load_real_samples()
print(dataset.shape)
# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time.time() can jump if the system clock is adjusted.
start_time = time.perf_counter()
train(generator, discriminator, gan_model, dataset, latent_dim)
end_time = time.perf_counter()
print('Total Time', (end_time - start_time))

In [None]:
from keras.models import load_model

In [None]:
def plot_generated_images(examples, n):
  """Display the first n*n images of ``examples`` in an n-by-n grid.

  Args:
    examples: array indexed as examples[i, row, col, channel]; only channel 0
      is drawn.
    n: number of rows and columns in the grid.
  """
  n_cells = n * n
  for cell in range(n_cells):
    plt.subplot(n, n, cell + 1)
    plt.axis('off')
    image = examples[cell, :, :, 0]
    plt.imshow(image, cmap='gray_r')
  plt.show()

In [None]:
# Presumably the last checkpoint written by summarize_performance for the
# default train() settings. Update the name if n_epochs/n_batch change.
model = load_model('model_018740.h5')
# Read the latent size from the loaded generator instead of hard-coding 100,
# so this cell does not silently depend on the training cell's constant.
latent_size = model.input_shape[-1]
n_grid = 10
latent_points = generate_latent_points(latent_size, n_grid * n_grid)
X = model.predict(latent_points, verbose=0)
plot_generated_images(X, n_grid)