In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dense, Flatten, BatchNormalization, LeakyReLU, ReLU, Reshape, Dropout
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras import backend
from tensorflow.keras.constraints import Constraint
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
import os

In [None]:
# Eager execution has to be switched on explicitly only in TF 1.x. TF 2.x runs eagerly by
# default and no longer exposes `tf.enable_eager_execution`. Call it only when it exists
# and eager mode is not already active, so the cell also works on TF 2.x and on re-runs.
if hasattr(tf, "enable_eager_execution") and not tf.executing_eagerly():
    tf.enable_eager_execution()
tf.executing_eagerly()

In [None]:
# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# (`tf` is already imported in the first cell; no need to re-import it here.)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

In [None]:
# Pre-trained classifier used inside train_gan to label generated digits
# (its argmax is offset by 5 there, so it presumably covers digits 5-9 — confirm).
evaluation_model_path = "gan_evaluation/gan_evaluation_model.h5"
saved_model = tf.keras.models.load_model(evaluation_model_path)

In [None]:
# Training hyperparameters:
#   latent_shape - length of the generator's input noise vector
#   batch_size   - images per batch (real and generated)
#   n_epochs     - passes over the training dataset
latent_shape, batch_size, n_epochs = 50, 32, 50

In [None]:
(x_train_full, y_train_full), (x_test_full, y_test_full) = tf.keras.datasets.mnist.load_data()

# Find indices of labels 5 to 9
train_index = np.flatnonzero(y_train_full >= 5)

(x_train, y_train) = (x_train_full[train_index], y_train_full[train_index])

# Scale pixels from [0, 255] to [-1, 1] and add a trailing channel axis.
# This matches the generator's tanh output range. With [0, 1] inputs the critic could
# tell real from fake images by pixel sign alone.
x_train_n = np.expand_dims((x_train.astype('float32') - 127.5) / 127.5, axis=-1)

num_classes = 5

# convert class vectors to binary class matrices (digits 5-9 shifted to 0-4).
# The tf.data pipeline below contains images only.
y_train_b = to_categorical(y_train - 5, num_classes)

# Shuffle across the whole training set; a small buffer only gives a partial shuffle.
# drop_remainder=True keeps every batch at exactly batch_size images, which the
# training loop relies on when it builds its label tensors.
dataset = tf.data.Dataset.from_tensor_slices(x_train_n)
dataset = dataset.shuffle(buffer_size=len(x_train_n)).batch(batch_size, drop_remainder=True).prefetch(1)

In [None]:
# Shared kernel initializer: zero-mean normal with standard deviation 0.02.
init = RandomNormal(stddev=0.02)
# Bound for WGAN weight clipping.
clip_value = 0.01

class ClipConstraint(Constraint):
    """Weight constraint that clips every element into [-clip_value, clip_value].

    Used as a ``kernel_constraint`` on the critic's layers to implement WGAN
    weight clipping.

    Args:
        clip_value: positive bound; weights are clipped to the closed interval
            [-clip_value, clip_value].
    """

    def __init__(self, clip_value):
        self.clip_value = clip_value

    def __call__(self, w):
        return backend.clip(w, -self.clip_value, self.clip_value)

    def get_config(self):
        """Return the constructor arguments so a saved model can rebuild this constraint.

        Without this, the base class returns an empty config. Reloading a saved model
        would then call ``ClipConstraint()`` without the required ``clip_value``.
        """
        return {'clip_value': self.clip_value}

In [None]:
# Single constraint instance shared by all clipped critic layers; uses the
# notebook-level clip_value instead of repeating the literal.
clip = ClipConstraint(clip_value)

In [None]:
critic = tf.keras.Sequential(name="critic")

# Downsample 28x28x1 -> 14x14x64.
critic.add(Conv2D(64, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init, kernel_constraint=clip, input_shape=(28, 28, 1)))
critic.add(BatchNormalization())
critic.add(LeakyReLU(alpha=0.2))

# Downsample 14x14x64 -> 7x7x64.
critic.add(Conv2D(64, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init, kernel_constraint=clip))
critic.add(BatchNormalization())
critic.add(LeakyReLU(alpha=0.2))

critic.add(Flatten())
# Linear output: a WGAN critic emits an unbounded score, not a probability (no sigmoid).
# Its kernel is clipped like the conv kernels, so every weight matrix stays inside the
# [-clip_value, clip_value] box that weight clipping relies on.
# NOTE(review): BatchNormalization gamma/beta are still unclipped.
critic.add(Dense(1, kernel_constraint=clip))

In [None]:
# Print the critic's layers, output shapes and parameter counts.
critic.summary()

In [None]:
generator = tf.keras.Sequential(name="generator")

# Project the latent noise vector to a 7x7x128 feature map. Uses the same normal
# initializer as every other weight layer in the notebook.
generator.add(Dense(128*7*7, kernel_initializer=init, input_shape=[latent_shape]))
generator.add(LeakyReLU(alpha=0.2))
generator.add(Reshape((7, 7, 128)))

# Upsample 7x7 -> 14x14.
generator.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init))
generator.add(BatchNormalization())
generator.add(LeakyReLU(alpha=0.2))

# Upsample 14x14 -> 28x28.
generator.add(Conv2DTranspose(128, (4, 4), strides=(2, 2), padding="same", kernel_initializer=init))
generator.add(BatchNormalization())
generator.add(LeakyReLU(alpha=0.2))

# Single-channel 28x28 image in [-1, 1] (tanh).
generator.add(Conv2DTranspose(1, (7, 7), padding="same", kernel_initializer=init, activation="tanh"))


In [None]:
# Print the generator's layers, output shapes and parameter counts.
generator.summary()

In [None]:
# Combined model: noise -> generator -> critic score. Used to train the generator.
gan = tf.keras.Sequential(name="gan")
gan.add(generator)
gan.add(critic)
gan.summary()

In [None]:
def wasserstein_loss(y_true, y_pred):
    """Mean of the element-wise product of target signs and critic scores.

    In this notebook the targets are +1 for generated and -1 for real images.
    Minimizing the loss therefore pushes real scores up and fake scores down.
    """
    signed_scores = y_true * y_pred
    return backend.mean(signed_scores)

In [None]:
# NOTE(review): the WGAN paper recommends RMSprop without momentum and a smaller
# learning rate (5e-5); these values are kept as-is.
# Compile the critic while it is trainable so critic.train_on_batch updates its weights.
critic.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0005, momentum=0.9))
# Freeze the critic *before* compiling the combined model. Keras takes the trainable
# weights at compile time, so gan.train_on_batch only updates the generator.
critic.trainable = False
gan.compile(loss=wasserstein_loss, optimizer=RMSprop(learning_rate=0.0005, momentum=0.9))

In [None]:
def _save_sample_grid(generated_images, evaluator, epoch, out_dir, label_offset, n_cols=8):
    """Plot generated images titled "<predicted label>-<probability>" and save as a PNG."""
    # scale from [-1,1] to [0,1]
    images = (np.asarray(generated_images) + 1) / 2.0
    y_prob = evaluator.predict(images)

    # Grid sized from the batch (4x8 for 32 images) instead of a fixed 4x8.
    n_samples = images.shape[0]
    n_cols = min(n_cols, n_samples)
    n_rows = (n_samples + n_cols - 1) // n_cols
    fig, axes = plt.subplots(n_rows, n_cols, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_samples:
            continue  # unused cell in the last row
        label = np.argmax(y_prob[i, :]) + label_offset
        prob = np.max(y_prob[i, :])
        ax.imshow(images[i, :, :, 0], cmap='gray_r')
        ax.set_title("{}-{:.2f}".format(label, prob), size=5, pad=1)

    filename1 = 'eval_gan_plot_%04d.png' % (epoch + 1)
    fig.savefig(os.path.join(out_dir, filename1), dpi=200)
    plt.close(fig)


def train_gan(gan, dataset, batch_size, latent_shape, n_epochs=n_epochs, n_critic=5,
              evaluator=None, out_dir="gan_evaluation", label_offset=5):
    """Train a weight-clipped WGAN and save samples and checkpoints after every epoch.

    For each real batch, the critic is updated ``n_critic`` times. Each update uses
    a fresh generated batch (target +1) concatenated with that same real batch
    (target -1). The generator is then updated once through ``gan`` with target -1.
    After each epoch, the last generated batch is labelled by ``evaluator``, plotted
    and saved to ``out_dir``. The full GAN model and the critic weights are saved too.

    Args:
        gan: compiled Sequential model whose layers are ``[generator, critic]``.
        dataset: iterable of real image batches. Each batch must hold exactly
            ``batch_size`` images, because the label tensors are built for that size.
        batch_size: number of real and of generated images per step.
        latent_shape: length of the generator's noise vector.
        n_epochs: number of passes over ``dataset``.
        n_critic: critic updates per generator update (must be >= 1).
        evaluator: classifier used to label generated samples. Defaults to the
            notebook-level ``saved_model``.
        out_dir: directory for plots and checkpoints; created if missing.
        label_offset: added to the evaluator's argmax to get the digit. Defaults to 5
            because the training data is restricted to digits 5-9.

    Raises:
        ValueError: if ``n_critic`` < 1 or ``dataset`` yields no batches.
    """
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1")
    if evaluator is None:
        evaluator = saved_model  # notebook-level classifier loaded earlier
    os.makedirs(out_dir, exist_ok=True)

    c_hist, g_hist = list(), list()
    generator, critic = gan.layers

    # Targets never change between steps, so build them once.
    # Critic: +1 for the generated half, -1 for the real half.
    y1 = tf.constant([[1.]] * batch_size + [[-1.]] * batch_size)
    # Generator: wants its samples scored like real ones.
    y2 = tf.constant([[-1.]] * batch_size)

    for epoch in range(n_epochs):
        generated_images, g_loss = None, None
        for X_batch in dataset:

            # phase 1 - training the critic (reuses the same real batch n_critic times)
            c_tmp = list()
            for _ in range(n_critic):
                noise = tf.random.normal(shape=[batch_size, latent_shape])
                generated_images = generator(noise)

                X_fake_and_real = tf.concat([generated_images, X_batch], axis=0)

                critic.trainable = True
                c_tmp.append(critic.train_on_batch(X_fake_and_real, y1))

            c_hist.append(np.mean(c_tmp))

            # phase 2 - training the generator through the frozen critic
            noise = tf.random.normal(shape=[batch_size, latent_shape])
            critic.trainable = False
            g_loss = gan.train_on_batch(noise, y2)
            g_hist.append(g_loss)

        if g_loss is None:
            raise ValueError("dataset yielded no batches; nothing to train on")

        print('>%d, c=%.3f, g=%.3f' % (epoch, c_hist[-1], g_loss))

        # predictions + plot of the last generated batch
        _save_sample_grid(generated_images, evaluator, epoch, out_dir, label_offset)

        # save the full GAN model (generator + critic)
        filename2 = os.path.join(out_dir, 'attack_gan_model_%04d.h5' % (epoch + 1))
        gan.save(filename2)
        # save the critic weights
        filename3 = os.path.join(out_dir, 'attack_critic_model_weights_%04d.h5' % (epoch + 1))
        critic.save_weights(filename3)
        

In [None]:
# Run the full training loop with the notebook-level hyperparameters.
train_gan(gan, dataset, batch_size=batch_size, latent_shape=latent_shape)