In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [None]:
# Load MNIST and merge the official train and test splits into one training
# set: nothing is evaluated on held-out data in this notebook.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.concatenate((X_train, X_test))
# Map pixel values x -> (x - 127.5) / 127.5 so that real images share the
# [-1, 1] range of the generators' tanh output. Casting to float32 first keeps
# the result in float32 (the dtype of the image placeholder) instead of
# letting NumPy promote the whole array to float64.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
y_train = np.concatenate((y_train, y_test))

In [None]:
class Generator:
    """Transposed-convolution generator for the AC-GAN.

    A dense layer expands the concatenated (noise, class) vector to a 7x7x128
    feature map; three transposed convolutions with 'same' padding and strides
    1, 2 and 2 then upsample it to a 28x28x1 image. The last layer applies
    tanh, so generated pixel values lie in [-1, 1].
    """

    def __init__(self):
        """Instantiate the layers in the order in which they are applied."""
        fc = tf.keras.layers.Dense(7*7*128, activation=tf.nn.relu, name="fc_generator")

        reshape = tf.keras.layers.Reshape((7, 7, 128))

        conv_t1 = tf.keras.layers.Conv2DTranspose(filters=64,
                                                  kernel_size=(5,5),
                                                  strides=(1,1),
                                                  padding='same',
                                                  name="conv_t_generator1")
        bn1 = tf.keras.layers.BatchNormalization()
        act1 = tf.keras.layers.ReLU()

        conv_t2 = tf.keras.layers.Conv2DTranspose(filters=32,
                                                  kernel_size=(5,5),
                                                  strides=(2,2),
                                                  padding='same',
                                                  name="conv_t_generator2")
        bn2 = tf.keras.layers.BatchNormalization()
        act2 = tf.keras.layers.ReLU()

        conv_t3 = tf.keras.layers.Conv2DTranspose(filters=1,
                                                  kernel_size=(5,5),
                                                  strides=(2,2),
                                                  activation=tf.nn.tanh,
                                                  padding='same',
                                                  name="conv_t_generator3")

        self.layers = [fc, reshape, conv_t1, bn1, act1, conv_t2, bn2, act2, conv_t3]
        # Layers whose behaviour depends on the Keras `training` flag.
        self._training_aware = [bn1, bn2]

    def _call_layer(self, layer, x, training):
        """Apply one layer, forwarding `training` only to layers that use it."""
        if any(layer is aware for aware in self._training_aware):
            return layer(x, training=training)
        return layer(x)

    def generate(self, rand_noise, class_vec, training=True):
        """Generate images from noise vectors and class encodings.

        Args:
            rand_noise: 2-D tensor of noise vectors, one row per sample.
            class_vec: 2-D tensor of class encodings (the caller passes
                one-hot rows); concatenated to the noise along axis 1.
            training: forwarded to the BatchNormalization layers. Defaults to
                True because, without an explicit flag, Keras falls back to
                its global learning phase, which is "inference" unless fed,
                and batch normalisation would then never use batch statistics.

        Returns:
            The output tensor of the last layer (tanh-activated images).
        """
        # NOTE(review): in TF1 graph mode the moving-average updates of the
        # BatchNormalization layers are only applied if their update ops are
        # run explicitly; confirm before calling this with training=False.
        x = tf.concat([rand_noise, class_vec], axis=1)
        for layer in self.layers:
            x = self._call_layer(layer, x, training)
        return x
    
class GeneratorMLP:
    """Fully-connected alternative to the convolutional generator.

    Three dense layers (128 and 512 ReLU units, then 28*28 tanh units) followed
    by a reshape to a 28x28x1 image.
    """

    def __init__(self):
        """Build the dense stack and the final reshape, in application order."""
        dense_specs = (
            (128, tf.nn.relu, "fc_generator1"),
            (512, tf.nn.relu, "fc_generator2"),
            (28*28, tf.nn.tanh, "fc_generator3"),
        )
        self.layers = [
            tf.keras.layers.Dense(units, activation=activation, name=layer_name)
            for units, activation, layer_name in dense_specs
        ]
        # Turn the flat 784-vector back into an image with one channel.
        self.layers.append(tf.keras.layers.Reshape((28, 28, 1)))

    def generate(self, rand_noise, class_vec):
        """Concatenate noise and class encoding along axis 1 and run every layer."""
        out = tf.concat([rand_noise, class_vec], axis=1)
        for apply_layer in self.layers:
            out = apply_layer(out)
        return out

In [None]:
class Discriminator:
    """Convolutional discriminator with an auxiliary classification head.

    A shared trunk (dropout -> conv -> dropout -> conv -> flatten) feeds two
    dense heads that have no activation and therefore return raw logits: one
    real/fake logit and one logit per class.
    """

    def __init__(self, num_classes=None):
        """Create the trunk and both heads.

        Args:
            num_classes: number of outputs of the classification head. When
                omitted, the notebook-level variable `nb_classes` is used,
                which must be defined before the first instantiation.
        """
        if num_classes is None:
            # Backward-compatible default: read the notebook-level setting.
            num_classes = nb_classes

        dropout1 = tf.keras.layers.Dropout(rate=0.4)
        conv1 = tf.keras.layers.Conv2D(filters=32,
                                       kernel_size=(5,5),
                                       strides=(3,3),
                                       activation=tf.nn.leaky_relu,
                                       name="conv_discriminator1")

        dropout2 = tf.keras.layers.Dropout(rate=0.2)
        conv2 = tf.keras.layers.Conv2D(filters=64,
                                       kernel_size=(5,5),
                                       strides=(3,3),
                                       activation=tf.nn.leaky_relu,
                                       name="conv_discriminator2")

        flatten = tf.keras.layers.Flatten()

        # Real/fake head
        fc_discr = tf.keras.layers.Dense(1, name="fc_discriminator")
        # Class head
        fc_class = tf.keras.layers.Dense(num_classes, name="fc_classifier")

        self.layers_common = [dropout1, conv1, dropout2, conv2, flatten]
        self.layers_discr = [fc_discr]
        self.layers_class = [fc_class]
        # Layers whose behaviour depends on the Keras `training` flag.
        self._training_aware = [dropout1, dropout2]

    def _call_layer(self, layer, x, training):
        """Apply one layer, forwarding `training` only to layers that use it."""
        if any(layer is aware for aware in self._training_aware):
            return layer(x, training=training)
        return layer(x)

    def discriminate(self, image, training=True):
        """Run the shared trunk and both heads on a batch of images.

        Args:
            image: input tensor handed to the first trunk layer.
            training: forwarded to the Dropout layers. Defaults to True
                because, without an explicit flag, Keras falls back to its
                global learning phase, which is "inference" unless fed, and
                the dropout layers would then do nothing during training.

        Returns:
            Tuple `(x_discr, x_class)`: the real/fake logit and the class
            logits computed from the same trunk features.
        """
        x = image
        for layer in self.layers_common:
            x = self._call_layer(layer, x, training)

        x_discr = x
        x_class = x
        for layer in self.layers_discr:
            x_discr = self._call_layer(layer, x_discr, training)
        for layer in self.layers_class:
            x_class = self._call_layer(layer, x_class, training)

        return x_discr, x_class
    
class DiscriminatorMLP:
    """Fully-connected discriminator with an auxiliary classification head.

    The shared trunk flattens the image and applies dropout, one 128-unit
    leaky-ReLU dense layer and a second dropout. Two dense heads without
    activation then return raw logits: one real/fake logit and one logit per
    class.
    """

    def __init__(self, num_classes=None):
        """Create the trunk and both heads.

        Args:
            num_classes: number of outputs of the classification head. When
                omitted, the notebook-level variable `nb_classes` is used,
                which must be defined before the first instantiation.
        """
        if num_classes is None:
            # Backward-compatible default: read the notebook-level setting.
            num_classes = nb_classes

        flatten = tf.keras.layers.Flatten()

        dropout1 = tf.keras.layers.Dropout(rate=0.4)
        fc1 = tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu, name="fc_discriminator1")

        dropout2 = tf.keras.layers.Dropout(rate=0.2)

        # Real/fake head
        fc_discr = tf.keras.layers.Dense(1, name="fc_discriminator2")
        # Class head
        fc_class = tf.keras.layers.Dense(num_classes, name="fc_classifier2")

        self.layers_common = [flatten, dropout1, fc1, dropout2]
        self.layers_discr = [fc_discr]
        self.layers_class = [fc_class]
        # Layers whose behaviour depends on the Keras `training` flag.
        self._training_aware = [dropout1, dropout2]

    def _call_layer(self, layer, x, training):
        """Apply one layer, forwarding `training` only to layers that use it."""
        if any(layer is aware for aware in self._training_aware):
            return layer(x, training=training)
        return layer(x)

    def discriminate(self, image, training=True):
        """Run the shared trunk and both heads on a batch of images.

        Args:
            image: input tensor handed to the first trunk layer.
            training: forwarded to the Dropout layers. Defaults to True
                because, without an explicit flag, Keras falls back to its
                global learning phase, which is "inference" unless fed, and
                the dropout layers would then do nothing during training.

        Returns:
            Tuple `(x_discr, x_class)`: the real/fake logit and the class
            logits computed from the same trunk features.
        """
        x = image
        for layer in self.layers_common:
            x = self._call_layer(layer, x, training)

        x_discr = x
        x_class = x
        for layer in self.layers_discr:
            x_discr = self._call_layer(layer, x_discr, training)
        for layer in self.layers_class:
            x_class = self._call_layer(layer, x_class, training)

        return x_discr, x_class

In [None]:
class ACGAN:
    """Auxiliary-classifier GAN assembled as a TensorFlow 1.x static graph.

    The constructor builds the input pipeline, generator, discriminator,
    losses, optimizers, summaries and saver. `train` runs the alternating
    optimisation loop in a caller-provided session; `restore` reloads saved
    weights.
    """

    def __init__(self, original_im_shape, dim_noise, nb_classes):
        """Build the complete training graph in the current default graph.

        Args:
            original_im_shape: shape of one training image without the channel
                axis; a trailing channel axis is added before the discriminator.
            dim_noise: length of the Gaussian noise vector given to the generator.
            nb_classes: number of label classes, used for one-hot encoding and
                for sampling the classes of generated images.
        """
        self.original_im_shape = original_im_shape
        self.dim_noise = dim_noise
        self.nb_classes = nb_classes

        with tf.variable_scope("AC-GAN"):
            # NOTE: Discriminator() is built without arguments and takes its
            # class count from the notebook-level `nb_classes`, which therefore
            # has to equal the `nb_classes` passed to this constructor.
            self.generator = Generator()
            self.discriminator = Discriminator()

            # Input pipeline: the whole dataset is fed once through these
            # placeholders when the iterator is initialised (see `train`).
            self.original_image = tf.placeholder(tf.float32, (None, *(self.original_im_shape)), name="original_image")
            self.image_label = tf.placeholder(tf.int32, (None,), name="image_label")
            self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
            self.dataset = tf.data.Dataset.from_tensor_slices((self.original_image, self.image_label)).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator = self.dataset.make_initializable_iterator()

            self.batch_images, self.batch_labels = self.iterator.get_next()
            # Add the channel axis expected by the convolutional discriminator.
            self.original_image_exp = tf.expand_dims(self.batch_images, -1)
            self.batch_labels_oh = tf.one_hot(self.batch_labels, depth=self.nb_classes)

            # Sample as many (noise, class) pairs as there are real images in
            # the current batch and generate fake images from them.
            with tf.variable_scope("generator"):
                bs = tf.shape(self.original_image_exp)[0]
                self.rand_noise = tf.random_normal((bs, self.dim_noise), name="rand_noise")
                self.rand_class = tf.random.uniform((bs,), maxval=self.nb_classes, dtype=tf.int32, name="rand_class")
                self.rand_class_oh = tf.one_hot(self.rand_class, depth=self.nb_classes, dtype=tf.float32, name="rand_class_oh")
                self.generated_images = self.generator.generate(self.rand_noise, self.rand_class_oh)

            # The same discriminator layers score real and generated images.
            with tf.variable_scope("discriminator"):
                self.prob_true_real, self.class_real = self.discriminator.discriminate(self.original_image_exp)
                self.prob_true_fake, self.class_fake = self.discriminator.discriminate(self.generated_images)

            # Auxiliary classification loss. Every image belongs to exactly one
            # class, so the class logits are trained with softmax cross-entropy
            # on the integer labels rather than with independent per-class
            # sigmoids on the one-hot vectors.
            self.loss_classif = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.batch_labels,
                                                               logits=self.class_real) +
                tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.rand_class,
                                                               logits=self.class_fake))
            # Generator: make the discriminator label fakes as real.
            self.loss_generator = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.prob_true_fake,
                                                                                         labels=tf.ones_like(self.prob_true_fake)))
            # Discriminator: real -> 0.9 (one-sided label smoothing), fake -> 0.
            self.loss_discriminator = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.prob_true_real,
                                                                                             labels=.9 * tf.ones_like(self.prob_true_real)) +
                                                     tf.nn.sigmoid_cross_entropy_with_logits(logits=self.prob_true_fake,
                                                                                             labels=tf.zeros_like(self.prob_true_fake)))
            # Both players also minimise the classification loss.
            self.loss_discriminator += self.loss_classif
            self.loss_generator += self.loss_classif

            # Each optimizer only updates the variables of its own sub-network,
            # selected by the name prefix given by the scopes above.
            self.generator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="AC-GAN/generator")
            self.discriminator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="AC-GAN/discriminator")

            # Optimization
            self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
            self.optimizer_generator = tf.train.AdamOptimizer(self.learning_rate)
            self.optimizer_discriminator = tf.train.AdamOptimizer(self.learning_rate)

            self.generator_train_op = self.optimizer_generator.minimize(self.loss_generator, var_list=self.generator_variables)
            self.discriminator_train_op = self.optimizer_discriminator.minimize(self.loss_discriminator, var_list=self.discriminator_variables)

            # Summaries; generated images are mapped from [-1, 1] back to [0, 1].
            tf.summary.scalar("loss_generator", self.loss_generator)
            tf.summary.scalar("loss_discriminator", self.loss_discriminator)
            tf.summary.scalar("loss_classif", self.loss_classif)
            tf.summary.image("generated_images", (self.generated_images + 1) / 2, 16)
            self.merged_summaries = tf.summary.merge_all()

            self.saver = tf.train.Saver()

    def train(self, X_train, y_train, batch_size, nb_steps, learning_rate, discriminator_steps, log_every, save_every, sess):
        """Run the alternating discriminator/generator optimisation.

        Args:
            X_train: array of training images fed to the image placeholder.
            y_train: array of integer labels aligned with `X_train`.
            batch_size: number of real images per batch.
            nb_steps: number of generator updates to perform.
            learning_rate: learning rate fed to both Adam optimizers.
            discriminator_steps: discriminator updates per generator update.
            log_every: write TensorBoard summaries every this many steps.
            save_every: save a checkpoint every this many steps.
            sess: session in which the model variables are already initialised.
        """
        summary_writer = tf.summary.FileWriter("./tensorboard/", sess.graph)
        try:
            sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                           self.image_label: y_train,
                                                           self.batch_size: batch_size})

            step_feed = {self.learning_rate: learning_rate,
                         self.batch_size: batch_size}

            for step in range(1, nb_steps + 1):
                # Train discriminator
                for _ in range(discriminator_steps):
                    sess.run(self.discriminator_train_op, feed_dict=step_feed)

                # Train generator. Summaries (including the encoded sample
                # images) are only evaluated on the steps that write them.
                if step % log_every == 0:
                    _, summaries = sess.run([self.generator_train_op, self.merged_summaries],
                                            feed_dict=step_feed)
                    print("Write summaries")
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.generator_train_op, feed_dict=step_feed)

                if step % save_every == 0:
                    print("Save model")
                    # Make sure the target directory exists before saving;
                    # MakeDirs succeeds if it is already there.
                    tf.gfile.MakeDirs("./model")
                    self.saver.save(sess, "./model/model.ckpt")
        finally:
            # Flush pending events and release the file handle even when
            # training is interrupted.
            summary_writer.close()

    def restore(self, sess, ckpt_file):
        """Load the variables stored in checkpoint `ckpt_file` into `sess`."""
        self.saver.restore(sess, ckpt_file)

In [None]:
# --- Data and model dimensions ---
nb_classes = 10                # number of label classes
dim_noise = 100                # length of the generator's noise vector
original_im_shape = (28, 28)   # image height and width, without channel axis

# --- Optimisation ---
learning_rate = 4e-4
batch_size = 128
discriminator_steps = 1        # discriminator updates per generator update

# --- Schedule, counted in generator updates ---
nb_steps = 50000
save_every = 1000
log_every = 200

In [None]:
# Start from an empty default graph so that this cell can be re-run: building
# the model a second time in the same graph would create a second copy whose
# variable names no longer start with "AC-GAN/", while the model selects its
# trainable variables by exactly that prefix.
tf.reset_default_graph()
acgan = ACGAN(original_im_shape, dim_noise, nb_classes)

In [None]:
# Train from scratch in a fresh session. To resume from a checkpoint instead,
# call acgan.restore(sess, "./model/model.ckpt") right after the initializer.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    acgan.train(
        X_train=X_train,
        y_train=y_train,
        batch_size=batch_size,
        nb_steps=nb_steps,
        learning_rate=learning_rate,
        discriminator_steps=discriminator_steps,
        log_every=log_every,
        save_every=save_every,
        sess=sess,
    )