In [23]:
import matplotlib.pyplot as plt
import sys
import numpy as np
import tqdm
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras import initializers

In [2]:
randomDim = 10
(X_train, _), (_, _) = mnist.load_data()
X_train = (X_train.astype(np.float32) - 127.5)/127.5
X_train = X_train.reshape(60000, 784)

In [27]:
generator = Sequential()
generator.add(Dense(256, input_dim=randomDim))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(784, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.01, beta_1=2e-4))

In [28]:
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784,
kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.01, beta_1=2e-4))

In [29]:
# Combined network
discriminator.trainable = False
ganInput = Input(shape=(randomDim,))
x = generator(ganInput)
ganOutput = discriminator(x)
gan = Model(inputs=ganInput, outputs=ganOutput)
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.01, beta_1=2e-4))


In [35]:
dLosses = []
gLosses = []
def train(epochs=1, batchSize=128):
  batchCount = int(X_train.shape[0] / batchSize)
  print ('Epochs:', epochs)
  print ('Batch size:', batchSize)
  print ('Batches per epoch:', batchCount)
  for e in range(1, epochs+1):
    print ('-'*15, 'Epoch %d' % e, '-'*15)
    for _ in range(batchCount):
      noise = np.random.normal(0, 1, size=[batchSize, randomDim])
      imageBatch = X_train[np.random.randint(0, X_train.shape[0], size=batchSize)]
      generatedImages = generator.predict(noise)
      X = np.concatenate([imageBatch, generatedImages])
      yDis = np.zeros(2*batchSize)
      yDis[:batchSize] = 0.9
      discriminator.trainable = True
      dloss = discriminator.train_on_batch(X, yDis)
      noise = np.random.normal(0, 1, size=[batchSize, randomDim])
      yGen = np.ones(batchSize)
      discriminator.trainable = False
      gloss = gan.train_on_batch(noise, yGen)
      dLosses.append(dloss)
      gLosses.append(gloss)
      if e == 1 or e % 20 == 0:
        saveGeneratedImages(e)

In [33]:
# Plot the loss from each batch
def plotLoss(epoch):
 plt.figure(figsize=(10, 8))
 plt.plot(dLosses, label='Discriminitive loss')
 plt.plot(gLosses, label='Generative loss')
 plt.xlabel('Epoch')
 plt.ylabel('Loss')
 plt.legend()
 plt.savefig('gan_loss_epoch_%d.png' % epoch)

In [34]:
# Create a wall of generated MNIST images
def saveGeneratedImages(epoch, examples=100, dim=(10, 10), figsize=(10, 10)):
 noise = np.random.normal(0, 1, size=[examples, randomDim])
 generatedImages = generator.predict(noise)
 generatedImages = generatedImages.reshape(examples, 28, 28)
 plt.figure(figsize=figsize)
 for i in range(generatedImages.shape[0]):
  plt.subplot(dim[0], dim[1], i+1)
  plt.imshow(generatedImages[i], interpolation='nearest',
  cmap='gray_r')
  plt.axis('off')
  plt.tight_layout()
  plt.savefig('gan_generated_image_epoch_%d.png' % epoch)


In [None]:
train(10)
Epochs: 10
Batchsize :  128
Batchesperepoch: 468

Epochs: 10
Batch size: 128
Batches per epoch: 468
--------------- Epoch 1 ---------------


  plt.figure(figsize=figsize)


