In [1]:
from numpy import zeros, ones, expand_dims, hstack
from numpy.random import randn, randint
from keras.datasets.mnist import load_data
from tensorflow.keras.optimizers import Adam
from keras.initializers import RandomNormal
from tensorflow.keras.utils import to_categorical
from keras.models import Model
from keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose,\
LeakyReLU, BatchNormalization, Activation
import matplotlib.pyplot as plt

In [2]:
def define_discriminator(n_cat, in_shape=(28,28,1)):
    """Build the InfoGAN discriminator and its auxiliary Q network.

    Both models share one convolutional trunk (the same layer objects), so
    updating the trunk through either model affects the other.

    Args:
        n_cat: number of categorical control codes the Q head predicts.
        in_shape: shape of a single input image.

    Returns:
        (d_model, q_model):
          - d_model maps an image to a sigmoid real/fake score and is compiled
            here with binary cross-entropy and Adam(lr=0.0002, beta_1=0.5).
          - q_model maps an image to a softmax over ``n_cat`` codes and is left
            uncompiled here.
    """
    weight_init = RandomNormal(stddev=0.02)
    image = Input(shape=in_shape)

    # Trunk layout: (filters, strides, append BatchNormalization afterwards?)
    trunk_spec = (
        (64, (2,2), False),
        (128, (2,2), True),
        (256, (1,1), True),
    )
    features = image
    for filters, strides, with_bn in trunk_spec:
        features = Conv2D(filters, (4,4), strides=strides, padding='same',
                          kernel_initializer=weight_init)(features)
        features = LeakyReLU(alpha=0.1)(features)
        if with_bn:
            features = BatchNormalization()(features)
    features = Flatten()(features)

    # Real/fake head: compiled on its own so it can be trained directly.
    validity = Dense(1, activation='sigmoid')(features)
    d_model = Model(image, validity)
    d_model.compile(loss='binary_crossentropy',
                    optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

    # Q head: predicts which categorical code produced the image.
    code_hidden = Dense(128)(features)
    code_hidden = LeakyReLU(alpha=0.1)(code_hidden)
    code_hidden = BatchNormalization()(code_hidden)
    code_probs = Dense(n_cat, activation='softmax')(code_hidden)
    q_model = Model(image, code_probs)
    return d_model, q_model

In [3]:
def define_generator(gen_input_size):
    """Build the generator that maps a latent+code vector to a 28x28x1 image.

    Args:
        gen_input_size: length of the input vector (noise dims + code dims).

    Returns:
        An uncompiled Model whose output passes through tanh, i.e. lies in [-1, 1].
    """
    weight_init = RandomNormal(stddev=0.02)
    latent = Input(shape=(gen_input_size, ))

    def relu_then_bn(tensor):
        # Every hidden stage below is followed by ReLU and then BatchNormalization.
        tensor = Activation('relu')(tensor)
        return BatchNormalization()(tensor)

    x = Dense(512*7*7, kernel_initializer=weight_init)(latent)
    x = relu_then_bn(x)
    x = Reshape((7,7,512))(x)
    x = Conv2D(128, (4,4), padding='same', kernel_initializer=weight_init)(x)
    x = relu_then_bn(x)
    # Stride-2 'same' transposed convs upsample 7x7 -> 14x14 -> 28x28.
    x = Conv2DTranspose(64, (4,4), strides=(2,2), padding='same', kernel_initializer=weight_init)(x)
    x = relu_then_bn(x)
    x = Conv2DTranspose(1, (4,4), strides=(2,2), padding='same', kernel_initializer=weight_init)(x)
    image = Activation('tanh')(x)
    return Model(latent, image)

In [4]:
def define_gan(g_model, d_model, q_model):
    """Chain the generator into the discriminator and Q heads for generator updates.

    The composite model has two outputs:
      - the discriminator's real/fake score, trained with binary cross-entropy;
      - the Q head's code prediction, trained with categorical cross-entropy.

    Args:
        g_model: generator from define_generator.
        d_model: compiled discriminator from define_discriminator.
        q_model: Q network from define_discriminator.

    Returns:
        The compiled composite Model, with the same input as g_model.
    """
    # Freeze the discriminator inside the composite. d_model was compiled earlier,
    # and this relies on Keras keeping the trainable state captured at compile time
    # for d_model's own training. NOTE(review): confirm for the installed Keras version.
    d_model.trainable = False
    fake_images = g_model.output
    validity = d_model(fake_images)
    predicted_codes = q_model(fake_images)
    composite = Model(g_model.input, [validity, predicted_codes])
    composite.compile(loss=['binary_crossentropy', 'categorical_crossentropy'],
                      optimizer=Adam(learning_rate=0.0002, beta_1=0.5))
    return composite

In [5]:
def load_real_samples():
    """Load the MNIST training images as float32 scaled to [-1, 1].

    A trailing channel axis is added, and the labels are discarded.
    """
    (train_images, _), _ = load_data()
    images = expand_dims(train_images, axis=-1).astype('float32')
    # Map pixel values in [0, 255] onto [-1, 1] to match the generator's tanh range.
    return (images - 127.5) / 127.5

In [6]:
def generate_real_samples(dataset, n_samples):
    """Draw ``n_samples`` rows from ``dataset`` uniformly with replacement.

    Returns:
        (images, labels): images are dataset rows; labels is an
        (n_samples, 1) array of ones, the "real" target for the discriminator.
    """
    picks = randint(0, len(dataset), n_samples)
    labels = ones((n_samples, 1))
    return dataset[picks], labels

In [19]:
def generate_latent_points(latent_dim, n_cat, n_samples):
    """Sample generator inputs: Gaussian noise next to random one-hot codes.

    Returns:
        (z_input, cat_codes):
          - z_input has shape (n_samples, latent_dim + n_cat);
          - cat_codes is the (n_samples, n_cat) one-hot block, also used as the
            Q-head target.
    """
    noise = randn(latent_dim * n_samples).reshape((n_samples, latent_dim))
    labels = randint(0, n_cat, n_samples)
    one_hot = to_categorical(labels, num_classes=n_cat)
    return hstack((noise, one_hot)), one_hot

In [14]:
def generate_fake_samples(generator, latent_dim, n_cat, n_samples):
    """Generate ``n_samples`` images with random codes, labelled as fake.

    Args:
        generator: generator model from define_generator.
        latent_dim: number of noise dimensions.
        n_cat: number of categorical code dimensions.
        n_samples: number of images to generate.

    Returns:
        (images, y): generator outputs and an (n_samples, 1) array of zeros.
    """
    z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples)
    # verbose=0: this runs on every training step, and Keras's per-call progress
    # bar would otherwise flood the notebook output.
    images = generator.predict(z_input, verbose=0)
    y = zeros((n_samples, 1))
    return images, y

In [15]:
def summarize_performance(step, g_model, gan_model, latent_dim, n_cat, n_samples=100):
    """Save a grid of generated images and checkpoint both models.

    Writes three files:
      - ``generated_plot_XXXX.png``;
      - ``model_XXXX.h5`` (the generator);
      - ``gan_model_XXXX.h5`` (the composite model).
    XXXX is ``step + 1`` formatted with ``%04d``.

    Args:
        step: zero-based training step.
        g_model: generator to sample from and save.
        gan_model: composite model to save.
        latent_dim: number of noise dimensions.
        n_cat: number of categorical code dimensions.
        n_samples: number of images in the grid; laid out in a square grid of
            side ceil(sqrt(n_samples)), which is 10x10 for the default 100.
    """
    from math import ceil, sqrt

    X, _ = generate_fake_samples(g_model, latent_dim, n_cat, n_samples)
    X = (X + 1) / 2.0  # tanh output [-1, 1] -> [0, 1] for display
    side = ceil(sqrt(n_samples))
    # Draw on a fresh figure so leftover state from earlier plots cannot leak in.
    fig = plt.figure()
    for i in range(n_samples):
        plt.subplot(side, side, 1 + i)
        plt.axis('off')
        plt.imshow(X[i, :, :, 0], cmap='gray_r')
    filename1 = 'generated_plot_%04d.png' % (step + 1)
    fig.savefig(filename1)  # pyplot has no `save`; savefig writes the image
    plt.close(fig)
    filename2 = 'model_%04d.h5' % (step + 1)
    g_model.save(filename2)
    filename3 = 'gan_model_%04d.h5' % (step + 1)
    gan_model.save(filename3)
    print('>Saved: %s, %s, and %s' % (filename1, filename2, filename3))

In [16]:
def train(g_model, d_model, q_model, dataset, latent_dim, n_cat, n_epochs=100, n_batch=64, gan_model=None):
    """Run the InfoGAN training loop.

    Each step does three updates:
      1. train d_model on half a batch of real images;
      2. train d_model on half a batch of generated images;
      3. update the generator and Q head through the composite model on a full
         batch, with "real" targets and the sampled one-hot codes.
    Every 10 epochs, summarize_performance saves sample images and checkpoints.

    Args:
        g_model: generator from define_generator.
        d_model: compiled discriminator from define_discriminator.
        q_model: despite its name, this slot receives the composite model built
            by define_gan (that is what the notebook's call passes). It is used
            for the generator update unless ``gan_model`` is given.
        dataset: real images, indexed along axis 0.
        latent_dim: number of noise dimensions.
        n_cat: number of categorical code dimensions.
        n_epochs: number of passes over ``dataset``.
        n_batch: full batch size; each discriminator update uses half of it.
        gan_model: optional composite model from define_gan. Takes precedence
            over the third positional argument.
    """
    # Use the model handed in, never a notebook-global, so the function does not
    # depend on hidden kernel state.
    if gan_model is None:
        gan_model = q_model
    batch_per_epoch = dataset.shape[0] // n_batch
    n_steps = batch_per_epoch * n_epochs
    half_batch = n_batch // 2
    save_every = batch_per_epoch * 10  # checkpoint every 10 epochs
    for i in range(n_steps):
        X_real, y_real = generate_real_samples(dataset, half_batch)
        d_loss1 = d_model.train_on_batch(X_real, y_real)
        X_fake, y_fake = generate_fake_samples(g_model, latent_dim, n_cat, half_batch)
        d_loss2 = d_model.train_on_batch(X_fake, y_fake)
        z_input, cat_codes = generate_latent_points(latent_dim, n_cat, n_batch)
        # Label generated images as real so the generator learns to fool d_model.
        y_gan = ones((n_batch, 1))
        _, g_1, g_2 = gan_model.train_on_batch(z_input, [y_gan, cat_codes])
        print('>%d, d[%.3f, %.3f], g[%.3f], q[%.3f]' % (i+1, d_loss1, d_loss2, g_1, g_2))
        if (i + 1) % save_every == 0:
            # Arguments must follow summarize_performance's signature
            # (step, g_model, gan_model, latent_dim, n_cat): no dataset argument.
            summarize_performance(i, g_model, gan_model, latent_dim, n_cat)

In [17]:
# InfoGAN sizes: 10 categorical control codes plus a 62-dimensional noise vector.
n_cat = 10
latent_dim = 62
gen_input_size = latent_dim + n_cat  # the generator reads noise and one-hot code concatenated

d_model, q_model = define_discriminator(n_cat)
g_model = define_generator(gen_input_size)
gan_model = define_gan(g_model, d_model, q_model)

dataset = load_real_samples()

In [20]:
# The composite gan_model (from define_gan) goes in train's third positional slot.
train(g_model, d_model, gan_model, dataset, latent_dim, n_cat, n_epochs=100, n_batch=64)

In [None]:
# Load a saved generator checkpoint and sample images from it
from math import sqrt
from keras.models import load_model
import numpy as np

In [None]:
def create_plot(examples, n_examples):
    """Show the first ``n_examples`` images in a square grid.

    Args:
        examples: image batch indexed as ``examples[i, :, :, 0]``.
        n_examples: number of images to draw. The grid side is
            ceil(sqrt(n_examples)), so non-square counts also fit.
    """
    from math import ceil, sqrt

    # plt.subplot needs integer grid dimensions. sqrt() returns a float, which
    # recent Matplotlib versions reject.
    side = ceil(sqrt(n_examples))
    for i in range(n_examples):
        plt.subplot(side, side, 1 + i)
        plt.axis('off')
        plt.imshow(examples[i, :, :, 0], cmap='gray_r')
    plt.show()

In [None]:
# Sample a 10x10 grid from a saved generator checkpoint using random categorical codes.
# Presumably the final checkpoint written by train() with default settings.
model = load_model('model_93700.h5')
n_cat = 10
latent_dim = 62
n_samples = 100
z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples)
X = model.predict(z_input)
X = (X + 1) / 2.0  # tanh output [-1, 1] -> [0, 1]; `X + 1 / 2.0` would only add 0.5
create_plot(X, n_samples)

In [None]:
# Sampling with a fixed categorical control code
def generate_latent_points(latent_dim, n_cat, n_samples, digit=None):
    """Sample generator inputs, optionally fixing the categorical control code.

    This redefines the earlier 3-argument generate_latent_points. ``digit``
    defaults to None, in which case codes are drawn at random as before. That
    keeps the training helpers (generate_fake_samples, train) working if their
    cells are re-run after this one.

    Args:
        latent_dim: number of Gaussian noise dimensions.
        n_cat: number of categorical code dimensions.
        n_samples: number of input vectors to generate.
        digit: code index to use for every sample, or None for random codes.

    Returns:
        [z_input, cat_codes]:
          - z_input has shape (n_samples, latent_dim + n_cat);
          - cat_codes is the (n_samples, n_cat) one-hot block.

    Raises:
        ValueError: if ``digit`` is not in [0, n_cat). A negative index would
            otherwise wrap around silently when building the one-hot codes.
    """
    if digit is not None and not 0 <= digit < n_cat:
        raise ValueError('digit must be in [0, %d), got %r' % (n_cat, digit))
    z_latent = randn(latent_dim * n_samples).reshape(n_samples, latent_dim)
    if digit is None:
        cat_codes = randint(0, n_cat, n_samples)
    else:
        cat_codes = np.full(n_samples, digit)
    cat_codes = to_categorical(cat_codes, num_classes=n_cat)
    z_input = hstack((z_latent, cat_codes))
    return [z_input, cat_codes]

In [None]:
# Sample a 10x10 grid from the saved generator with every sample using categorical code `digit`.
# NOTE: training never uses the MNIST labels (load_real_samples discards them), so code
# index `digit` is not guaranteed to render as that digit. Check the plot.
model = load_model('model_93700.h5')
n_cat = 10
latent_dim = 62
n_samples = 100
digit = 2
z_input, _ = generate_latent_points(latent_dim, n_cat, n_samples, digit)
X = model.predict(z_input)
X = (X + 1) / 2.0  # tanh output [-1, 1] -> [0, 1]; `X + 1 / 2.0` would only add 0.5
create_plot(X, n_samples)