# In[22]:
from numpy import expand_dims, mean, ones
import numpy as np
from numpy.random import randn, randint
from keras.datasets import mnist
from keras import backend
from keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, Conv2D, Conv2DTranspose, LeakyReLU, BatchNormalization 
from keras.initializers import RandomNormal
from keras.constraints import Constraint
import matplotlib.pyplot as plt
from tqdm import tqdm

# In[None]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Append a trailing channel axis, then map pixel values from [0, 255] to
# [-1, 1] so real images share the range of the generator's tanh output.
X_train = (expand_dims(X_train, axis=-1).astype(np.float32) - 127.5) / 127.5

class ClipConstraint(Constraint):
    """Keras weight constraint clamping every weight into [-clip_value, clip_value].

    This is the weight-clipping mechanism of the original WGAN, used to keep
    the critic approximately Lipschitz.
    """

    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, weights):
        bound = self.clip_value
        return backend.clip(weights, -bound, bound)

    def get_config(self):
        # Lets Keras serialize / re-create the constraint.
        return dict(clip_value=self.clip_value)
 
def wasserstein_loss(y_true, y_pred):
    """Return the mean of label-signed critic scores.

    With labels of +1 / -1, minimizing this drives scores for +1-labelled
    inputs down and scores for -1-labelled inputs up.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)
 
def define_critic(in_shape=(28,28,1)):
    """Build and compile the WGAN critic.

    Architecture: two (Conv2D 64x4x4 stride 2 -> BatchNorm -> LeakyReLU 0.2)
    stages, then Flatten and a single linear Dense unit (no activation), so
    the output is an unbounded score.

    Args:
        in_shape: shape of a single input image, (H, W, C).

    Returns:
        A Sequential model compiled with ``wasserstein_loss`` and RMSprop(lr=5e-5).
    """
    weight_init = RandomNormal(stddev=0.02)
    clip = ClipConstraint(0.01)
    critic = Sequential()
    for stage in range(2):
        conv_kwargs = dict(strides=(2,2), padding='same',
                           kernel_initializer=weight_init, kernel_constraint=clip)
        if stage == 0:
            # Only the first layer declares the model's input shape.
            conv_kwargs['input_shape'] = in_shape
        critic.add(Conv2D(64, (4,4), **conv_kwargs))
        critic.add(BatchNormalization())
        critic.add(LeakyReLU(alpha=0.2))
    critic.add(Flatten())
    # NOTE(review): only the conv kernels carry the clip constraint; the Dense
    # weights, conv biases and BatchNorm parameters are not clipped.
    critic.add(Dense(1))
    critic.compile(loss=wasserstein_loss, optimizer=RMSprop(lr=0.00005))
    return critic
 
def define_generator(latent_dim):
    """Build the (uncompiled) generator mapping latent vectors to images.

    A Dense projection is reshaped to a 7x7x128 feature map. Two stride-2
    transposed convolutions (each followed by BatchNorm and LeakyReLU 0.2)
    upsample it, and a final 7x7 single-filter Conv2D with tanh produces one
    channel in [-1, 1].

    Args:
        latent_dim: length of the input noise vector.

    Returns:
        An uncompiled Sequential model.
    """
    weight_init = RandomNormal(stddev=0.02)
    gen = Sequential()
    gen.add(Dense(128 * 7 * 7, kernel_initializer=weight_init, input_dim=latent_dim))
    gen.add(LeakyReLU(alpha=0.2))
    gen.add(Reshape((7, 7, 128)))
    for _ in range(2):
        gen.add(Conv2DTranspose(128, (4,4), strides=(2,2), padding='same', kernel_initializer=weight_init))
        gen.add(BatchNormalization())
        gen.add(LeakyReLU(alpha=0.2))
    gen.add(Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=weight_init))
    return gen
 
def define_gan(generator, critic):
    """Stack generator -> critic into a compiled model for generator updates.

    Every critic layer except BatchNormalization is marked non-trainable
    before this stack is compiled, so ``train_on_batch`` on the result does
    not update those critic weights.

    NOTE(review): BatchNormalization layers are deliberately left trainable,
    so generator updates through this model can also change the critic's
    BatchNorm parameters. Confirm this is intended.

    Args:
        generator: generator model (first stage).
        critic: critic model (second stage); its layers' ``trainable`` flags
            are modified in place.

    Returns:
        A Sequential model compiled with ``wasserstein_loss`` and RMSprop(lr=5e-5).
    """
    frozen = [layer for layer in critic.layers
              if not isinstance(layer, BatchNormalization)]
    for layer in frozen:
        layer.trainable = False
    stacked = Sequential([generator, critic])
    stacked.compile(loss=wasserstein_loss, optimizer=RMSprop(lr=0.00005))
    return stacked
 
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=10, n_batch=128, n_critic=5):
    """Train a weight-clipped WGAN by alternating critic and generator updates.

    Each step runs ``n_critic`` critic updates. Each critic update trains on a
    half-batch of real samples labelled +1 and then on a half-batch of
    generated samples labelled -1. The step ends with one generator update
    through ``gan_model`` on a full batch of latent vectors labelled +1. The
    most recent losses are shown on the tqdm progress bar.

    Args:
        g_model: generator model; its ``predict`` produces the fake samples.
        c_model: compiled critic, updated with ``train_on_batch``.
        gan_model: compiled generator->critic stack, updated with ``train_on_batch``.
        dataset: array of real samples indexed along axis 0.
        latent_dim: length of each latent noise vector.
        n_epochs: number of nominal passes over ``dataset``.
        n_batch: generator batch size; each critic update uses ``n_batch // 2``.
        n_critic: number of critic updates per generator update.
    """
    n_samples = dataset.shape[0]
    bat_per_epo = n_samples // n_batch
    n_steps = bat_per_epo * n_epochs
    half_batch = n_batch // 2
    # Target labels never change; build them once instead of every iteration.
    y_real = ones((half_batch, 1))
    y_fake = -ones((half_batch, 1))
    y_gan = ones((n_batch, 1))
    # Defined up front so the progress display works even when n_critic == 0.
    c_loss_real = c_loss_fake = float('nan')
    progress = tqdm(range(n_steps))
    for _ in progress:
        for _ in range(n_critic):
            X_real = dataset[randint(0, n_samples, half_batch)]
            c_loss_real = c_model.train_on_batch(X_real, y_real)
            z = randn(half_batch, latent_dim)
            # Use the generator passed in, not the module-level ``generator``
            # global. verbose=0 keeps predict from printing inside the tqdm loop.
            X_fake = g_model.predict(z, verbose=0)
            c_loss_fake = c_model.train_on_batch(X_fake, y_fake)
        X_gan = randn(n_batch, latent_dim)
        g_loss = gan_model.train_on_batch(X_gan, y_gan)
        progress.set_postfix(c_real=c_loss_real, c_fake=c_loss_fake, g=g_loss)

# Build the three models and run training. The global names are kept as-is,
# presumably because other cells (or an unpatched train()) read them.
latent_dim = 100
critic, generator = define_critic(), define_generator(latent_dim)
gan_model = define_gan(generator, critic)
train(g_model=generator, c_model=critic, gan_model=gan_model,
      dataset=X_train, latent_dim=latent_dim)