In [1]:
from numpy import expand_dims, mean, ones
from numpy.random import randn, randint
from keras.datasets.mnist import load_data
from keras import backend as K
from tensorflow.keras.optimizers import RMSprop
from keras.models import Sequential
from keras.layers import Dense, Reshape, Flatten, Conv2D
from tensorflow.keras.layers import Conv2DTranspose, LeakyReLU, BatchNormalization
from keras.initializers import RandomNormal
from keras.constraints import Constraint
import matplotlib.pyplot as plt

In [2]:
class ClipConstraint(Constraint):
    """Weight constraint that clamps every value into [-clip_value, clip_value]."""

    def __init__(self, clip_value):
        # Symmetric bound used by __call__.
        self.clip_value = clip_value

    def __call__(self, weights):
        """Return `weights` clipped element-wise to the symmetric bound."""
        bound = self.clip_value
        return K.clip(weights, -bound, bound)

    def get_config(self):
        """Return the constructor argument as a config dict."""
        return dict(clip_value=self.clip_value)

In [3]:
def wasserstein_loss(y_true, y_pred):
    """Return the mean of the element-wise product of labels and scores.

    In this notebook the labels are -1 for real samples and +1 for fake
    samples, so the label only sets the sign with which each score is averaged.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)

In [4]:
def define_critic(in_shape=(28,28,1)):
    """Build and compile the critic network.

    Two stride-2 4x4 convolutions (64 filters each), each followed by batch
    normalisation and LeakyReLU(0.2), then a flatten and a single linear
    output unit. Both convolution kernels share a RandomNormal(stddev=0.02)
    initializer and a ClipConstraint(0.01); the final Dense layer uses
    neither. The model is compiled with `wasserstein_loss` and RMSprop at a
    learning rate of 0.00005.

    Args:
        in_shape: shape of one input image, passed as `input_shape` to the
            first convolution.

    Returns:
        The compiled Sequential critic model.
    """
    weight_init = RandomNormal(stddev=0.02)
    clip = ClipConstraint(0.01)
    # Keyword arguments common to both downsampling convolutions.
    conv_kwargs = dict(strides=(2,2), padding='same',
                       kernel_initializer=weight_init, kernel_constraint=clip)
    critic_net = Sequential([
        Conv2D(64, (4,4), input_shape=in_shape, **conv_kwargs),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Conv2D(64, (4,4), **conv_kwargs),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Flatten(),
        Dense(1),  # linear score, no activation
    ])
    critic_net.compile(loss=wasserstein_loss,
                       optimizer=RMSprop(learning_rate=0.00005))
    return critic_net

In [5]:
def define_generator(latent_dim):
    """Build the (uncompiled) generator network.

    A dense layer projects the latent vector to 128*7*7 units, which are
    reshaped to (7, 7, 128). Two stride-2 transposed convolutions with
    'same' padding follow, each with batch normalisation and LeakyReLU(0.2),
    and a final 7x7 convolution with tanh produces a single output channel.
    Every kernel uses the same RandomNormal(stddev=0.02) initializer.

    Args:
        latent_dim: length of the latent input vector.

    Returns:
        The Sequential generator model (not compiled).
    """
    weight_init = RandomNormal(stddev=0.02)
    # Keyword arguments common to both upsampling layers.
    upsample_kwargs = dict(strides=(2,2), padding='same', kernel_initializer=weight_init)
    return Sequential([
        Dense(128*7*7, kernel_initializer=weight_init, input_dim=latent_dim),
        LeakyReLU(alpha=0.2),
        Reshape((7,7,128)),
        Conv2DTranspose(128, (4,4), **upsample_kwargs),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Conv2DTranspose(128, (4,4), **upsample_kwargs),
        BatchNormalization(),
        LeakyReLU(alpha=0.2),
        Conv2D(1, (7,7), activation='tanh', padding='same', kernel_initializer=weight_init),
    ])

In [6]:
def define_gan(generator, critic):
    """Stack generator and critic into one model used to update the generator.

    The critic's `trainable` flag is set to False before the combined model
    is compiled. Note that this mutates the `critic` object passed in.
    NOTE(review): the notebook still trains `critic` directly afterwards; this
    relies on Keras honouring the trainable state captured when each model was
    compiled -- confirm for the installed Keras version.

    Args:
        generator: model mapping latent points to images.
        critic: model scoring images.

    Returns:
        The compiled Sequential model generator -> critic, using
        `wasserstein_loss` and RMSprop at a learning rate of 0.00005.
    """
    critic.trainable = False
    stacked = Sequential([generator, critic])
    stacked.compile(loss=wasserstein_loss,
                    optimizer=RMSprop(learning_rate=0.00005))
    return stacked

In [7]:
def load_real_samples():
    """Load the MNIST training images labelled 7, scaled to [-1, 1].

    Returns:
        float32 array with a trailing channel axis added; pixel values are
        mapped with (x - 127.5) / 127.5.
    """
    (train_images, train_labels), (_, _) = load_data()
    # Keep only the digit-7 images and append the channel axis.
    sevens = expand_dims(train_images[train_labels == 7], axis=-1)
    sevens = sevens.astype('float32')
    return (sevens - 127.5) / 127.5

In [8]:
def generate_real_samples(dataset, n_samples):
    """Pick `n_samples` random rows of `dataset` (with replacement).

    Returns:
        (X, y): the selected samples and an (n_samples, 1) array of -1 labels,
        the label this notebook uses for real images.
    """
    picks = randint(0, dataset.shape[0], n_samples)
    labels = ones((n_samples, 1)) * -1
    return dataset[picks], labels

In [9]:
def generate_latent_points(latent_dim, n_samples):
    """Draw standard-normal latent vectors.

    Returns:
        Array of shape (n_samples, latent_dim).
    """
    flat_noise = randn(latent_dim * n_samples)
    return flat_noise.reshape(n_samples, latent_dim)

In [10]:
def generate_fake_samples(generator, latent_dim, n_samples):
    """Generate `n_samples` images from random latent points.

    Args:
        generator: model exposing `predict`; maps latent points to images.
        latent_dim: length of each latent vector.
        n_samples: number of images to generate.

    Returns:
        (X, y): the generated images and an (n_samples, 1) array of +1 labels,
        the label this notebook uses for fake images.
    """
    x_input = generate_latent_points(latent_dim, n_samples)
    # This is called several times per training step; silence predict's
    # per-call progress output so the training log stays readable.
    X = generator.predict(x_input, verbose=0)
    y = ones((n_samples, 1))
    return X, y

In [11]:
def summarize_performance(step, g_model, latent_dim, n_samples=100):
    """Save a grid of generated images and a snapshot of the generator.

    Writes 'generated_plot_<step+1>.png' and 'model_<step+1>.h5' (step number
    zero-padded to four digits) to the current directory.

    Args:
        step: zero-based training step; file names use step + 1.
        g_model: generator model; must support `predict` and `save`.
        latent_dim: length of the latent vectors fed to the generator.
        n_samples: number of images to generate and plot. The grid side is
            the smallest integer whose square holds all of them (10x10 for
            the default of 100).
    """
    from math import ceil, sqrt

    X, _ = generate_fake_samples(g_model, latent_dim, n_samples)
    # Generator output is tanh-scaled to [-1, 1]; map to [0, 1] for display.
    X = (X + 1) / 2.0

    # Size the grid from n_samples instead of assuming exactly 100 images.
    grid_side = max(1, ceil(sqrt(n_samples)))
    for i in range(n_samples):
        plt.subplot(grid_side, grid_side, 1+i)
        plt.axis('off')
        plt.imshow(X[i, :, :, 0], cmap='gray_r')

    filename1 = 'generated_plot_%04d.png' %(step+1)
    plt.savefig(filename1)
    plt.close()
    filename2 = 'model_%04d.h5' %(step+1)
    g_model.save(filename2)
    print('Saved: %s and %s' %(filename1, filename2))

In [12]:
def plot_history(d1_hist, d2_hist, g_hist):
    """Plot the three loss histories on one figure and save it.

    The figure is written to 'plot_loss.png' (overwriting any existing file)
    and then closed.

    Args:
        d1_hist: critic loss on real samples, one value per step.
        d2_hist: critic loss on fake samples, one value per step.
        g_hist: generator loss, one value per step.
    """
    curves = (
        (d1_hist, 'critic_real'),
        (d2_hist, 'critic_fake'),
        (g_hist, 'gen'),
    )
    for values, curve_label in curves:
        plt.plot(values, label=curve_label)
    plt.legend()
    plt.savefig('plot_loss.png')
    plt.close()

In [13]:
def train(g_model, c_model, gan_model, dataset, latent_dim, n_epochs=50, n_batch=64, n_critic=5):
    """Run the WGAN training loop.

    Each step updates the critic `n_critic` times (once on a half batch of
    real samples and once on a half batch of generated samples per update),
    then updates the generator once through `gan_model` on a full batch of
    latent points labelled -1. Per-step losses are printed; at the end of
    every epoch a sample grid and generator snapshot are saved and the loss
    curves are re-plotted.

    Args:
        g_model: generator model.
        c_model: compiled critic model.
        gan_model: compiled generator -> critic model.
        dataset: array of real training images.
        latent_dim: length of the latent vectors.
        n_epochs: number of passes over the dataset.
        n_batch: batch size for the generator update; critic updates use
            half of it per real/fake batch.
        n_critic: critic updates per generator update.
    """
    batch_per_epoch = int(dataset.shape[0] / n_batch)
    n_steps = batch_per_epoch * n_epochs
    half_batch = int(n_batch / 2)
    c1_hist, c2_hist, g_hist = list(), list(), list()
    for i in range(n_steps):
        # Critic losses for this step; averaged over the n_critic updates.
        c1_temp, c2_temp = list(), list()
        for _ in range(n_critic):
            X_real, y_real = generate_real_samples(dataset, half_batch)
            c_loss1 = c_model.train_on_batch(X_real, y_real)
            c1_temp.append(c_loss1)
            X_fake, y_fake = generate_fake_samples(g_model, latent_dim, half_batch)
            c_loss2 = c_model.train_on_batch(X_fake, y_fake)
            c2_temp.append(c_loss2)
        c1_hist.append(mean(c1_temp))
        c2_hist.append(mean(c2_temp))
        X_gan_input = generate_latent_points(latent_dim, n_batch)
        # Generated images are labelled -1 (the "real" label) for the generator update.
        y_gan = -ones((n_batch, 1))
        g_loss = gan_model.train_on_batch(X_gan_input, y_gan)
        g_hist.append(g_loss)
        print('%d, c1=%.3f, c2=%.3f, g=%.3f' %(i+1, c1_hist[-1], c2_hist[-1], g_loss))
        if (i+1) % batch_per_epoch == 0:
            summarize_performance(i, g_model, latent_dim)
            # Refresh the loss plot once per epoch rather than on every step.
            # The last step is always an epoch boundary, so the final
            # plot_loss.png still covers the full history.
            plot_history(c1_hist, c2_hist, g_hist)

In [14]:
# Build the models and load the training data.
# These names are module-level state used by the cells below.
latent_dim = 50  # length of the generator's latent input vector
critic = define_critic()
generator = define_generator(latent_dim)
# Note: define_gan sets critic.trainable = False on the critic object.
gan_model = define_gan(generator, critic)
dataset = load_real_samples()
print(dataset.shape)

In [None]:
# Run training with the default n_epochs=50, n_batch=64, n_critic=5.
# Writes generated_plot_*.png, model_*.h5 and plot_loss.png to the working directory.
train(generator, critic, gan_model, dataset, latent_dim)

In [None]:
def plot_generated(examples, n):
    """Show the first n*n images of `examples` in an n-by-n grid.

    Args:
        examples: image array indexed as [sample, row, col, channel]; only
            channel 0 is drawn.
        n: grid side length.
    """
    total = n * n
    for cell in range(1, total + 1):
        plt.subplot(n, n, cell)
        plt.axis('off')
        plt.imshow(examples[cell - 1, :, :, 0], cmap='gray_r')
    plt.show()

In [None]:
from keras.models import load_model
# Load a saved generator snapshot and plot a 5x5 grid of fresh samples.
# NOTE(review): snapshots are written as 'model_<step+1>.h5' at epoch
# boundaries by summarize_performance; confirm this file exists for the run
# that was actually executed.
model = load_model('model_14550.h5')
# 50 must match the latent_dim the generator was trained with; 25 = 5 * 5 images.
latent_points = generate_latent_points(50,25)
X = model.predict(latent_points)
plot_generated(X, 5)