In [3]:
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers import LeakyReLU
from keras.layers import UpSampling2D, Conv2D
from keras.models import Sequential, Model
# from keras.optimizers import RMSprop
import tensorflow as tf
import keras.backend as K
import matplotlib.pyplot as plt
import sys
import numpy as np

In [12]:
class WGAN():

  def __init__(self):
    self.img_rows = 28
    self.img_cols = 28
    self.channels = 1
    self.img_shape = (self.img_rows, self.img_cols, self.channels)
    self.latent_dim = 100
    # Following parameter and optimizer set as recommended in paper
    self.n_critic = 5
    self.clip_value = 0.01
    optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=0.00005)
    # Build and compile the critic
    self.critic = self.build_critic()
    self.critic.compile(loss=self.wasserstein_loss,
                        optimizer=optimizer,
                        metrics=['accuracy'])
    # Build the generator
    self.generator = self.build_generator()
    # The generator takes noise as input and generated imgs
    z = Input(shape=(self.latent_dim,))
    img = self.generator(z)
    # For the combined model we will only train the generator
    self.critic.trainable = False
    # The critic takes generated images as input and determines
    # validity
    valid = self.critic(img)
    # The combined model  (stacked generator and critic)
    self.combined = Model(z, valid)
    self.combined.compile(loss=self.wasserstein_loss,
                          optimizer=optimizer,
                          metrics=['accuracy'])

  def wasserstein_loss(self, y_true, y_pred):
    return K.mean(y_true * y_pred)

  def build_critic(self):
    model = Sequential()
    model.add(Conv2D(16, kernel_size=3, strides=2,
                     input_shape=self.img_shape, padding="same"))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
    model.add(ZeroPadding2D(padding=((0,1),(0,1))))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dropout(0.25))
    model.add(Flatten())
    model.add(Dense(1))
    # model.summary()
    img = Input(shape=self.img_shape)
    validity = model(img)
    return Model(img, validity)

  def build_generator(self):
    model = Sequential()
    model.add(Dense(128 * 7 * 7, activation="relu",
                    input_dim=self.latent_dim))
    model.add(Reshape((7, 7, 128)))
    model.add(UpSampling2D())
    model.add(Conv2D(128, kernel_size=4, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Activation("relu"))
    model.add(UpSampling2D())
    model.add(Conv2D(64, kernel_size=4, padding="same"))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Activation("relu"))
    model.add(Conv2D(self.channels, kernel_size=4, padding="same"))
    model.add(Activation("tanh"))
    # model.summary()
    noise = Input(shape=(self.latent_dim,))
    img = model(noise)
    return Model(noise, img)

  def generate_plot_image(self, test_noise, epoch):
    pre_images = self.generator(test_noise,training=False)
    fig = plt.figure(figsize=(4,8))
    for i in range(pre_images.shape[0]):
        plt.subplot(4,8,i+1)
        plt.imshow((pre_images[i,:,:,0] + 1 )/2,cmap='gray')
        plt.axis('off')
    plt.savefig('.\WGAN\image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()

  def train(self, epochs, batch_size=128, sample_interval=50):
    # Load the dataset
    (X_train, _), (_, _) = mnist.load_data()
    # Rescale -1 to 1
    X_train = (X_train.astype(np.float32) - 127.5) / 127.5
    X_train = np.expand_dims(X_train, axis=3)
    # Adversarial ground truths
    valid = -np.ones((batch_size, 1))
    fake = np.ones((batch_size, 1))
    for epoch in range(epochs+1):
        for _ in range(self.n_critic):
            #  Train Discriminator
            # Select a random batch of images
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            # Sample noise as generator input
            noise = np.random.normal(0, 1, (batch_size,
                                            self.latent_dim))
            # Generate a batch of new images
            gen_imgs = self.generator.predict(noise)
            # Train the critic
            d_loss_real = self.critic.train_on_batch(imgs, valid)
            d_loss_fake = self.critic.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
            # Clip critic weights
            for l in self.critic.layers:
               weights = l.get_weights()
               weights = [np.clip(w, -self.clip_value,
                          self.clip_value) for w in weights]
               l.set_weights(weights)

        #  Train Generator
        g_loss = self.combined.train_on_batch(noise, valid)
        # If at save interval => save generated image samples
        if epoch % sample_interval == 0:
           # Plot the progress
           print ("%d [D loss: %f] [G loss: %f]" % (epoch, 1 -
                   d_loss[0], 1 - g_loss[0]))
           self.generate_plot_image(noise, epoch)


In [13]:
wgan = WGAN()
wgan.train(epochs=8000, batch_size=32, sample_interval=100)

0 [D loss: 0.999974] [G loss: 1.000072]
100 [D loss: 0.999975] [G loss: 1.000001]
200 [D loss: 0.999975] [G loss: 1.000001]
300 [D loss: 0.999975] [G loss: 1.000001]
400 [D loss: 0.999975] [G loss: 1.000001]
500 [D loss: 0.999975] [G loss: 1.000000]
600 [D loss: 0.999975] [G loss: 1.000000]
700 [D loss: 0.999975] [G loss: 1.000001]
800 [D loss: 0.999975] [G loss: 1.000000]
900 [D loss: 0.999975] [G loss: 1.000000]
1000 [D loss: 0.999975] [G loss: 1.000001]
1100 [D loss: 0.999975] [G loss: 1.000001]
1200 [D loss: 0.999974] [G loss: 1.000000]
1300 [D loss: 0.999975] [G loss: 1.000001]
1400 [D loss: 0.999975] [G loss: 1.000000]
1500 [D loss: 0.999975] [G loss: 1.000001]
1600 [D loss: 0.999975] [G loss: 1.000002]
1700 [D loss: 0.999975] [G loss: 1.000002]
1800 [D loss: 0.999974] [G loss: 1.000001]
1900 [D loss: 0.999975] [G loss: 1.000001]
2000 [D loss: 0.999975] [G loss: 1.000001]
2100 [D loss: 0.999975] [G loss: 1.000000]
2200 [D loss: 0.999975] [G loss: 1.000002]
2300 [D loss: 0.999975]