In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, BatchNormalization, LeakyReLU, ReLU, Reshape
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import backend
# from tensorflow.keras.optimizers import SGD
# from tensorflow.keras.utils import to_categorical

In [None]:
# TF 1.x starts in graph mode and needs eager execution switched on explicitly
# (before any graph is built); TF 2.x is eager by default and no longer exposes
# `tf.enable_eager_execution` at top level, so only enable it when it is off.
# Eager mode is required later: the training loop iterates a tf.data.Dataset
# with a plain Python for-loop.
if not tf.executing_eagerly():
    tf.compat.v1.enable_eager_execution()
tf.executing_eagerly()

In [None]:
import tensorflow as tf

# Enable on-demand GPU memory allocation instead of reserving all memory up front.
# This has to happen before any GPU is initialised (otherwise RuntimeError is
# raised), and memory growth must be identical on every GPU.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Raised when the GPUs were already initialised; report and carry on.
        print(err)

In [None]:
# Load a Keras model stored as HDF5. Presumably an evaluator for generated digits
# (judging by its path) — NOTE(review): it is not used in the cells visible here.
saved_model = tf.keras.models.load_model(filepath="gan_evaluation/gan_evaluation_model.h5")

In [None]:
# Length of the generator's noise vector, and number of real (and of fake)
# images per training step.
latent_shape, batch_size = 100, 32

In [None]:
(x_train_full, y_train_full), (x_test_full, y_test_full) = tf.keras.datasets.mnist.load_data()

# Keep only the digits 5-9 for training. flatnonzero always returns a 1-D index
# array (np.squeeze(np.argwhere(...)) would give a 0-d index for a single match).
train_index = np.flatnonzero(y_train_full >= 5)
(x_train, y_train) = (x_train_full[train_index], y_train_full[train_index])

# Scale pixels from [0, 255] to [-1, 1] so real images share the range of the
# generator's tanh output; with [0, 1] the critic could tell real from fake just
# by looking for negative pixel values.
# NOTE(review): if later cells feed these arrays to `saved_model`, confirm the
# input range that model was trained on.
x_train_n = np.expand_dims((x_train.astype('float32') - 127.5) / 127.5, axis=-1)

num_classes = 5

# Reshuffle on every pass (tf.data's default for shuffle) and drop the final
# partial batch: the training loop builds its labels for exactly `batch_size`
# real images per batch.
dataset = tf.data.Dataset.from_tensor_slices(x_train_n)
dataset = dataset.shuffle(len(x_train_n)).batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
# WGAN critic: maps a (28, 28, 1) image to an unbounded real-valued score.
# Each stride-2 "same" convolution halves the spatial size (28 -> 14 -> 7 -> 4).
critic = tf.keras.Sequential(name="critic")
critic.add(Conv2D(64, (4, 4), strides=(2, 2), padding="same", input_shape=(28, 28, 1), name="c_conv_1"))
critic.add(LeakyReLU(alpha=0.2, name="c_leaky_1"))
critic.add(Conv2D(32, (4, 4), strides=(2, 2), padding="same", name="c_conv_2"))
critic.add(BatchNormalization(name="c_batch_norm_2"))
critic.add(LeakyReLU(alpha=0.2, name="c_leaky_2"))
critic.add(Conv2D(16, (4, 4), strides=(2, 2), padding="same", name="c_conv_3"))
critic.add(BatchNormalization(name="c_batch_norm_3"))
critic.add(LeakyReLU(alpha=0.2, name="c_leaky_3"))
critic.add(Flatten(name="c_flatten"))
critic.add(Dense(128, name="c_dense_1"))
# Non-linearity between the dense layers; without it c_dense_1 and c_dense_out
# compose into a single linear map.
critic.add(LeakyReLU(alpha=0.2, name="c_leaky_4"))
# Linear output: a Wasserstein critic returns a raw score, so no activation here.
critic.add(Dense(1, name="c_dense_out"))

In [None]:
# Print the critic's layers, output shapes and parameter counts.
critic.summary()

In [None]:
# Generator: maps a `latent_shape`-long noise vector to a (28, 28, 1) image.
# Dense -> (7, 7, 128) feature map, then two stride-2 transposed convolutions
# upsample 7 -> 14 -> 28.
generator = tf.keras.Sequential(name="generator")
generator.add(Dense(100, input_shape=[latent_shape], name="g_dense_in"))
generator.add(ReLU(name="g_relu_1"))
# input_shape only matters on the first layer, so it is not repeated here.
generator.add(Dense(128 * 7 * 7, name="g_dense_2"))
generator.add(ReLU(name="g_relu_2"))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (4, 4), strides=(2, 2), padding="same", name="g_convT_1"))
generator.add(BatchNormalization(name="g_batch_norm_1"))
generator.add(ReLU(name="g_relu_3"))
# tanh is the final layer so generated pixels stay in [-1, 1]; a normalisation
# layer after it would rescale the output and break that bound.
generator.add(Conv2DTranspose(1, (4, 4), strides=(2, 2), padding="same", activation="tanh", name="g_convT_2"))

In [None]:
# Print the generator's layers, output shapes and parameter counts.
generator.summary()

In [None]:
# Chain generator -> critic: feeding noise to `gan` yields the critic's score
# for the generated images, which is what the generator is trained against.
gan = tf.keras.Sequential(name="gan")
gan.add(generator)
gan.add(critic)
gan.summary()

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Wasserstein loss: the mean of label-weighted critic scores.

    Minimising it lowers the scores of samples labelled +1 and raises the
    scores of samples labelled -1.

    Args:
        y_true: label tensor (+1 / -1 per sample).
        y_pred: critic scores, broadcast-compatible with ``y_true``.

    Returns:
        Scalar mean of ``y_true * y_pred``.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

In [None]:
# RMSprop learning rate shared by the critic and the combined model.
LEARNING_RATE = 0.00005

# The critic is compiled while trainable so critic.train_on_batch updates it.
critic.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=LEARNING_RATE))

# Freeze the critic *before* compiling the combined model, so that
# gan.train_on_batch only updates the generator. tf.keras takes the trainable
# state into account at compile time (TF 2.x restores the compiled state on every
# train_on_batch call), so toggling critic.trainable later is not enough.
critic.trainable = False
gan.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=LEARNING_RATE))

In [None]:
def train_gan(gan, dataset, batch_size, latent_shape, n_epochs=50, n_critic=5, clip_value=0.01):
    """Train a weight-clipping WGAN assembled as ``Sequential([generator, critic])``.

    For every real batch yielded by ``dataset``:

    1. The critic is updated ``n_critic`` times. Each update uses the same real
       batch and fresh noise. After each update its trainable weights are
       clipped to ``[-clip_value, clip_value]``.
    2. The generator is updated once through ``gan``.

    After each epoch one line ``>epoch, c=<critic loss>, g=<generator loss>`` is
    printed, showing the losses from the epoch's last batch.

    Args:
        gan: compiled ``tf.keras.Sequential`` whose two layers are
            ``(generator, critic)``. The critic must also be compiled on its own.
        dataset: iterable of real image batches, each with exactly
            ``batch_size`` samples (e.g. batched with ``drop_remainder=True``).
        batch_size: number of real and of generated samples per step.
        latent_shape: length of the generator's noise vector.
        n_epochs: number of passes over ``dataset``.
        n_critic: critic updates per generator update (must be >= 1).
        clip_value: bound for critic weight clipping.

    Raises:
        ValueError: if ``n_critic < 1`` or ``dataset`` yields no batches.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1")
    c_hist, g_hist = list(), list()
    generator, critic = gan.layers

    # Labels are the same for every step, so build them once.
    # Critic step: generated images are labelled +1 and real ones -1, so
    # minimising mean(y * score) pushes fake scores down and real scores up.
    y_critic = tf.constant([[1.]] * batch_size + [[-1.]] * batch_size)
    # Generator step: generated images are labelled as *real* (-1), so minimising
    # the loss raises the critic's score for them, i.e. the generator learns to
    # fool the critic.
    y_generator = tf.constant([[-1.]] * batch_size)

    for epoch in range(n_epochs):
        n_batches = 0
        for X_batch in dataset:
            n_batches += 1
            # phase 1 - training the critic
            c_tmp = list()
            critic.trainable = True
            for _ in range(n_critic):
                noise = tf.random.normal(shape=[batch_size, latent_shape])
                # NOTE(review): called without training=True, so generator
                # BatchNormalization may use moving statistics here but batch
                # statistics inside gan.train_on_batch — confirm this is intended.
                generated_images = generator(noise)
                X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)
                c_loss = critic.train_on_batch(X_fake_and_real, y_critic)
                c_tmp.append(c_loss)

                # Clip only trainable parameters. BatchNormalization moving
                # mean/variance are not learned parameters, and clipping them
                # would corrupt inference-mode normalisation.
                for weight in critic.trainable_weights:
                    weight.assign(tf.clip_by_value(weight, -clip_value, clip_value))

            c_hist.append(np.mean(c_tmp))
            # phase 2 - training the generator (critic frozen)
            noise = tf.random.normal(shape=[batch_size, latent_shape])
            critic.trainable = False
            g_loss = gan.train_on_batch(noise, y_generator)
            g_hist.append(g_loss)
        if n_batches == 0:
            raise ValueError("dataset yielded no batches")
        # Report the 1-based epoch number rather than the total epoch count.
        print('>%d, c=%.3f, g=%.3f' % (epoch + 1, c_hist[-1], g_hist[-1]))

In [None]:
# Run training with the default number of epochs and critic updates.
train_gan(gan=gan, dataset=dataset, batch_size=batch_size, latent_shape=latent_shape)