In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [2]:
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Apply the same affine rescaling to both splits: an input value of 0 maps
# to -1 and 255 maps to +1.
X_train, X_test = [(pixels - 127.5) / 127.5 for pixels in (X_train, X_test)]

# Labels are kept for the first `n_labeled` training images only; the rest of
# the training images are used without their labels.
n_labeled = 100
X_labeled = X_train[:n_labeled]
X_unlabeled = X_train[n_labeled:]
y_labeled = y_train[:n_labeled]

In [3]:
class Generator:
    """Convolutional generator: noise vector -> single-channel image in [-1, 1].

    A dense layer produces a 7x7x128 volume which is then passed through three
    transposed convolutions with 'same' padding (strides 1, 2, 2); the last one
    has a single filter and a tanh activation.

    NOTE(review): layers are invoked without a `training` argument, so
    BatchNormalization runs in whatever mode Keras selects by default --
    confirm this is the intended behaviour before using this generator.
    """

    def __init__(self):
        keras_layers = tf.keras.layers

        # Layers are stored in application order; `generate` simply chains them.
        self.layers = [
            keras_layers.Dense(7*7*128, activation=tf.nn.relu, name="fc_generator"),
            keras_layers.Reshape((7, 7, 128)),

            # Block 1: stride 1 keeps the spatial size.
            keras_layers.Conv2DTranspose(filters=64,
                                         kernel_size=(5,5),
                                         strides=(1,1),
                                         padding='same',
                                         name="conv_t_generator1"),
            keras_layers.BatchNormalization(),
            keras_layers.ReLU(),

            # Block 2: stride 2 doubles the spatial size.
            keras_layers.Conv2DTranspose(filters=32,
                                         kernel_size=(5,5),
                                         strides=(2,2),
                                         padding='same',
                                         name="conv_t_generator2"),
            keras_layers.BatchNormalization(),
            keras_layers.ReLU(),

            # Output: stride 2 doubles again, one channel squashed by tanh.
            keras_layers.Conv2DTranspose(filters=1,
                                         kernel_size=(5,5),
                                         strides=(2,2),
                                         activation=tf.nn.tanh,
                                         padding='same',
                                         name="conv_t_generator3"),
        ]

    def generate(self, rand_noise):
        """Feed `rand_noise` through every layer in order and return the result."""
        output = rand_noise
        for apply_layer in self.layers:
            output = apply_layer(output)
        return output
    
class GeneratorMLP:
    """Fully-connected generator: noise vector -> (28, 28, 1) image in [-1, 1].

    Two ReLU hidden layers (128 and 512 units) feed a 28*28-unit tanh layer
    whose output is reshaped to a single-channel image.
    """

    def __init__(self):
        keras_layers = tf.keras.layers

        # Stored in application order; `generate` chains them.
        self.layers = [
            keras_layers.Dense(128, activation=tf.nn.relu, name="fc_generator1"),
            keras_layers.Dense(512, activation=tf.nn.relu, name="fc_generator2"),
            keras_layers.Dense(28*28, activation=tf.nn.tanh, name="fc_generator3"),
            keras_layers.Reshape((28, 28, 1)),
        ]

    def generate(self, rand_noise):
        """Feed `rand_noise` through every layer in order and return the result."""
        output = rand_noise
        for apply_layer in self.layers:
            output = apply_layer(output)
        return output

In [4]:
class Classifier:
    """Convolutional discriminator producing a single logit per image.

    Three (dropout -> strided conv with leaky ReLU) stages are followed by a
    flatten and a 1-unit dense layer without activation.

    NOTE(review): no padding is specified on the convolutions. For 28x28
    inputs the third convolution appears to receive a 3x3 map, which is
    smaller than its 4x4 kernel -- confirm the intended input size / padding
    before using this class.
    """

    def __init__(self):
        keras_layers = tf.keras.layers

        # Stored in application order; `classify` chains them.
        self.layers = [
            keras_layers.Dropout(rate=0.3),
            keras_layers.Conv2D(filters=32,
                                kernel_size=(5,5),
                                strides=(3,3),
                                activation=tf.nn.leaky_relu,
                                name="conv_discriminator1"),

            keras_layers.Dropout(rate=0.3),
            keras_layers.Conv2D(filters=64,
                                kernel_size=(4,4),
                                strides=(2,2),
                                activation=tf.nn.leaky_relu,
                                name="conv_discriminator2"),

            keras_layers.Dropout(rate=0.3),
            keras_layers.Conv2D(filters=128,
                                kernel_size=(4,4),
                                strides=(2,2),
                                activation=tf.nn.leaky_relu,
                                name="conv_discriminator3"),

            keras_layers.Flatten(),
            keras_layers.Dense(units=1, name="fc_discriminator"),
        ]

    def classify(self, image):
        """Feed `image` through every layer in order and return the final output."""
        output = image
        for apply_layer in self.layers:
            output = apply_layer(output)
        return output
    
class ClassifierMLP:
    """Fully-connected classifier that also exposes its penultimate features.

    The feature extractor flattens the image and applies two
    (dropout -> dense with leaky ReLU) stages of 128 and 64 units. The
    classification head applies one more dropout followed by a dense layer
    with `nb_classes` outputs and no activation (raw logits).
    """

    def __init__(self, nb_classes):
        """Create the layers.

        Args:
            nb_classes: number of output logits of the classification head.
        """
        flatten = tf.keras.layers.Flatten()

        dropout1 = tf.keras.layers.Dropout(rate=0.3)
        fc1 = tf.keras.layers.Dense(128, activation=tf.nn.leaky_relu, name="fc_discriminator1")

        dropout2 = tf.keras.layers.Dropout(rate=0.3)
        fc2 = tf.keras.layers.Dense(64, activation=tf.nn.leaky_relu, name="fc_discriminator2")

        dropout3 = tf.keras.layers.Dropout(rate=0.3)
        fc3 = tf.keras.layers.Dense(nb_classes, name="fc_discriminator3")

        self.layers_features = [flatten, dropout1, fc1, dropout2, fc2]
        self.layers_classif = [dropout3, fc3]

    def classify(self, image, is_training=True):
        """Run the classifier on a batch of images.

        Args:
            image: input batch passed to the first layer.
            is_training: when True dropout is applied; when False the dropout
                layers are run in inference mode (no units dropped).

        Returns:
            A `(features, logits)` pair: the output of the feature extractor
            and the output of the classification head applied to it.
        """
        # Extract features
        features = self._run_layers(self.layers_features, image, is_training)
        # Classification
        logits = self._run_layers(self.layers_classif, features, is_training)
        return features, logits

    @staticmethod
    def _run_layers(layers, x, is_training):
        """Chain `layers` over `x`, forwarding the training flag to Dropout only."""
        for layer in layers:
            # Dropout is the only layer here whose behaviour differs between
            # training and inference, so it is the only one given the flag;
            # the other layers keep being called with the input alone.
            if isinstance(layer, tf.keras.layers.Dropout):
                x = layer(x, training=is_training)
            else:
                x = layer(x)
        return x

In [5]:
class GAN:
    """Semi-supervised GAN expressed as a TF1 graph.

    The classifier is trained with a supervised cross-entropy on the labeled
    batch plus a real-vs-fake term on unlabeled and generated batches. The
    generator is trained with feature matching: it minimises the squared
    distance between the mean classifier features of unlabeled and generated
    images.
    """

    def __init__(self, im_shape, dim_noise, nb_classes):
        """Build the input pipelines, losses, train ops, summaries and saver.

        Args:
            im_shape: shape of one image without the channel axis (a channel
                axis of size 1 is appended inside the graph).
            dim_noise: length of the random normal vector fed to the generator.
            nb_classes: number of real classes predicted by the classifier.
        """
        self.im_shape = im_shape
        self.dim_noise = dim_noise
        self.nb_classes = nb_classes

        with tf.variable_scope("GAN"):
            self.generator = GeneratorMLP()
            self.classifier = ClassifierMLP(self.nb_classes)

            self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")

            # Labeled data from mnist
            self.labeled_image = tf.placeholder(tf.float32, (None, *(self.im_shape)), name="original_image")
            self.image_label = tf.placeholder(tf.int32, (None,), name="image_label")
            dataset_labeled = tf.data.Dataset.from_tensor_slices((self.labeled_image, self.image_label)).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator_labeled = dataset_labeled.make_initializable_iterator()

            batch_images_labeled, batch_labels = self.iterator_labeled.get_next()
            batch_images_labeled = tf.expand_dims(batch_images_labeled, -1)
            batch_labels_oh = tf.one_hot(batch_labels, depth=self.nb_classes)

            # Unlabeled data from mnist
            self.unlabeled_image = tf.placeholder(tf.float32, (None, *(self.im_shape)), name="unlabeled_image")
            dataset_unlabled = tf.data.Dataset.from_tensor_slices(self.unlabeled_image).shuffle(10000).batch(self.batch_size).repeat()
            self.iterator_unlabeled = dataset_unlabled.make_initializable_iterator()

            batch_images_unlabeled = self.iterator_unlabeled.get_next()
            batch_images_unlabeled = tf.expand_dims(batch_images_unlabeled, -1)

            # Sample and generate fake images (as many as in the unlabeled batch)
            with tf.variable_scope("generator"):
                bs = tf.shape(batch_images_unlabeled)[0]
                rand_noise = tf.random_normal((bs, self.dim_noise), name="rand_noise")
                generated_images = self.generator.generate(rand_noise)

            # Use classifier (the same layers, hence the same weights, for all inputs)
            with tf.variable_scope("classifier"):
                features_labeled, class_logits_labeled = self.classifier.classify(batch_images_labeled)
                features_unlabeled, class_logits_unlabeled = self.classifier.classify(batch_images_unlabeled)
                features_fake, class_logits_fake = self.classifier.classify(generated_images)

                # Real-vs-fake logit. With Z(x) = sum_k exp(l_k(x)) the
                # probability of "real" is D(x) = Z(x) / (Z(x) + 1), and
                # sigmoid(log Z(x)) = D(x). The logit expected by
                # sigmoid_cross_entropy_with_logits is therefore log Z(x),
                # i.e. the log-sum-exp of the class logits -- not log D(x),
                # which would cap the implied probability below 0.5.
                # The "real" term uses the unlabeled batch so the classifier
                # actually learns from the unlabeled images.
                discrim_real = tf.reduce_logsumexp(class_logits_unlabeled, axis=1)
                discrim_fake = tf.reduce_logsumexp(class_logits_fake, axis=1)

                mean_features_unlabeled = tf.reduce_mean(features_unlabeled, axis=0)
                mean_features_fake = tf.reduce_mean(features_fake, axis=0)

            # Compute losses
            loss_classif_supervised = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=class_logits_labeled,
                                                                                                labels=batch_labels_oh))
            # Targets: 0.9 for real images (instead of 1), 0 for generated ones.
            loss_classif_unsupervised = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_real,
                                                                                               labels=.9 * tf.ones_like(discrim_real))) + \
                                        tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=discrim_fake,
                                                                                               labels=tf.zeros_like(discrim_fake)))
            loss_classifier = loss_classif_supervised + loss_classif_unsupervised

            # Feature matching between unlabeled and generated batches
            loss_generator = tf.reduce_mean((mean_features_unlabeled - mean_features_fake) ** 2)

            accuracy_classif_supervised = tf.reduce_mean(tf.cast(tf.equal(batch_labels,
                                                                          tf.argmax(class_logits_labeled, axis=1, output_type=tf.int32)),
                                                                 tf.float32))

            # Separate trainable variables
            generator_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GAN/generator")
            classifier_variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope="GAN/classifier")

            # Optimization: one Adam optimizer per network, each restricted to
            # its own variables.
            self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
            optimizer_generator = tf.train.AdamOptimizer(self.learning_rate)
            optimizer_classifier = tf.train.AdamOptimizer(self.learning_rate)

            self.classifier_train_op = optimizer_classifier.minimize(loss_classifier, var_list=classifier_variables)
            self.generator_train_op = optimizer_generator.minimize(loss_generator, var_list=generator_variables)

            # Validate classification on test set
            with tf.variable_scope("validation"):
                self.val_image = tf.placeholder(tf.float32, (None, *(self.im_shape)), name="val_image")
                self.val_label = tf.placeholder(tf.int32, (None,), name="val_label")
                self.val_batch_size = tf.placeholder(tf.int64, None, name="val_batch_size")
                val_dataset = tf.data.Dataset.from_tensor_slices((self.val_image, self.val_label)).shuffle(10000).batch(self.val_batch_size).repeat()
                self.iterator_val = val_dataset.make_initializable_iterator()

                val_batch_images, val_batch_labels = self.iterator_val.get_next()
                val_batch_images = tf.expand_dims(val_batch_images, -1)

                _, val_classif_logits = self.classifier.classify(val_batch_images, is_training=False)

                self.val_accuracy_classif = tf.reduce_mean(tf.cast(tf.equal(val_batch_labels,
                                                                            tf.argmax(val_classif_logits, axis=1, output_type=tf.int32)),
                                                                   tf.float32))

            # Summaries
            tf.summary.scalar("loss_classif_supervised", loss_classif_supervised)
            tf.summary.scalar("loss_classif_unsupervised", loss_classif_unsupervised)
            tf.summary.scalar("loss_classifier", loss_classifier)
            tf.summary.scalar("loss_generator", loss_generator)
            tf.summary.scalar("accuracy_classif_supervised", accuracy_classif_supervised)
            tf.summary.scalar("val_accuracy_classif", self.val_accuracy_classif)
            # Generated images are mapped from [-1, 1] to [0, 1] for display.
            tf.summary.image("generated_images", (generated_images + 1) / 2, 16)
            self.merged_summaries = tf.summary.merge_all()

            self.saver = tf.train.Saver()

    def train(self, X_labeled, y_labeled, X_unlabeled, X_test, y_test,
              batch_size, val_batch_size, nb_steps, learning_rate, classifier_steps, log_every, save_every, sess,
              log_dir="./tensorboard/", ckpt_path="./model/model.ckpt"):
        """Run the alternating classifier / generator training loop.

        Args:
            X_labeled, y_labeled: labeled images and their integer labels.
            X_unlabeled: images used without labels.
            X_test, y_test: images and labels used for the validation accuracy.
            batch_size: batch size of the labeled and unlabeled pipelines.
            val_batch_size: batch size of the validation pipeline.
            nb_steps: number of training steps (one generator update each).
            learning_rate: value fed to the learning-rate placeholder.
            classifier_steps: classifier updates performed per training step.
            log_every: write summaries every this many steps.
            save_every: save a checkpoint every this many steps.
            sess: session in which variables are already initialised.
            log_dir: directory receiving the TensorBoard event files.
            ckpt_path: checkpoint prefix passed to the saver.
        """
        import os

        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)
        try:
            sess.run(self.iterator_labeled.initializer, feed_dict={self.labeled_image: X_labeled,
                                                                   self.image_label: y_labeled,
                                                                   self.batch_size: batch_size})
            sess.run(self.iterator_unlabeled.initializer, feed_dict={self.unlabeled_image: X_unlabeled,
                                                                     self.batch_size: batch_size})
            sess.run(self.iterator_val.initializer, feed_dict={self.val_image: X_test,
                                                               self.val_label: y_test,
                                                               self.val_batch_size: val_batch_size})

            step_feed = {self.learning_rate: learning_rate,
                         self.batch_size: batch_size}

            for step in range(1, nb_steps + 1):
                # Train discriminator
                for _ in range(classifier_steps):
                    sess.run(self.classifier_train_op, feed_dict=step_feed)

                # Train generator
                sess.run(self.generator_train_op, feed_dict=step_feed)

                if step % log_every == 0:
                    # Eval and summaries
                    _, summaries = sess.run([self.val_accuracy_classif, self.merged_summaries],
                                            feed_dict=step_feed)
                    print("Write summaries")
                    summary_writer.add_summary(summaries, step)

                if step % save_every == 0:
                    print("Save model")
                    # Make sure the target directory exists rather than
                    # relying on the saver to create it.
                    ckpt_dir = os.path.dirname(ckpt_path)
                    if ckpt_dir:
                        os.makedirs(ckpt_dir, exist_ok=True)
                    self.saver.save(sess, ckpt_path)
        finally:
            # Flush pending events and release the file handle even when
            # training is interrupted.
            summary_writer.close()

    def restore(self, sess, ckpt_file):
        """Load all saved variables from `ckpt_file` into `sess`."""
        self.saver.restore(sess, ckpt_file)

In [6]:
# Problem definition
im_shape = (28, 28)   # image shape without the channel axis
nb_classes = 10       # number of classes predicted by the classifier
dim_noise = 100       # length of the generator's input noise vector

# Optimisation
learning_rate = 1e-5
batch_size, val_batch_size = 128, 512
classifier_steps = 1  # classifier updates per generator update

# Schedule, all expressed in training steps
nb_steps = 500000
log_every = 1000
save_every = 20000

In [None]:
# Build the full GAN graph (input pipelines, losses, train ops, summaries).
gan = GAN(im_shape=im_shape, dim_noise=dim_noise, nb_classes=nb_classes)

In [None]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    # Resume from the previous checkpoint when there is one; on a fresh run
    # (no checkpoint written yet) keep the freshly initialised variables
    # instead of failing in restore().
    ckpt_file = "./model/model.ckpt"
    if tf.train.checkpoint_exists(ckpt_file):
        gan.restore(sess, ckpt_file)

    gan.train(X_labeled, y_labeled, X_unlabeled, X_test, y_test,
              batch_size, val_batch_size, nb_steps, learning_rate, classifier_steps,
              log_every, save_every, sess)

In [7]:
# Test only the classifier
# The baseline is built in its own graph so that tf.summary.merge_all() and
# the variable initializers below only cover ops created in this cell, even if
# another model has already been built in the default graph.
with tf.Graph().as_default(), tf.Session() as sess:
    # Labeled data from mnist
    labeled_image = tf.placeholder(tf.float32, (None, *(im_shape)), name="image")
    image_label = tf.placeholder(tf.int32, (None,), name="image_label")
    dataset_labeled = tf.data.Dataset.from_tensor_slices((labeled_image, image_label)).shuffle(10000).batch(128).repeat()
    iterator_labeled = dataset_labeled.make_initializable_iterator()

    batch_images_labeled, batch_labels = iterator_labeled.get_next()
    batch_images_labeled = tf.expand_dims(batch_images_labeled, -1)
    batch_labels_oh = tf.one_hot(batch_labels, depth=nb_classes)

    plain_classifier = ClassifierMLP(nb_classes)
    _, class_logits_labeled = plain_classifier.classify(batch_images_labeled)

    loss_classif_supervised = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=class_logits_labeled,
                                                                                        labels=batch_labels_oh))
    accuracy_classif_supervised = tf.reduce_mean(tf.cast(tf.equal(batch_labels,
                                                                  tf.argmax(class_logits_labeled, axis=1, output_type=tf.int32)),
                                                         tf.float32))

    # Validation branch: same classifier layers, fed from the test set.
    with tf.variable_scope("validation"):
        val_image = tf.placeholder(tf.float32, (None, *(im_shape)), name="val_image")
        val_label = tf.placeholder(tf.int32, (None,), name="val_label")
        val_dataset = tf.data.Dataset.from_tensor_slices((val_image, val_label)).shuffle(10000).batch(512).repeat()
        iterator_val = val_dataset.make_initializable_iterator()

        val_batch_images, val_batch_labels = iterator_val.get_next()
        val_batch_images = tf.expand_dims(val_batch_images, -1)

        _, val_classif_logits = plain_classifier.classify(val_batch_images, is_training=False)

        val_accuracy_classif = tf.reduce_mean(tf.cast(tf.equal(val_batch_labels,
                                                               tf.argmax(val_classif_logits, axis=1, output_type=tf.int32)),
                                                      tf.float32))

    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss_classif_supervised)

    tf.summary.scalar("accuracy_classif_supervised", accuracy_classif_supervised)
    tf.summary.scalar("val_accuracy_classif", val_accuracy_classif)
    merged_summaries = tf.summary.merge_all()

    sess.run(iterator_labeled.initializer, feed_dict={labeled_image: X_labeled,
                                                      image_label: y_labeled})
    sess.run(iterator_val.initializer, feed_dict={val_image: X_test,
                                                  val_label: y_test})
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())

    summary_writer = tf.summary.FileWriter("./tensorboard_plain_classif/", sess.graph)
    try:
        for step in range(10000):
            sess.run(train_op)

            if step % log_every == 0:
                # Eval and summaries
                val_acc, summaries = sess.run([val_accuracy_classif, merged_summaries])
                print(val_acc)
                summary_writer.add_summary(summaries, step)
    finally:
        # Flush pending events and release the file handle even if the loop
        # is interrupted.
        summary_writer.close()

Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
0.107421875
0.6660156
0.671875
0.6972656
0.68359375
0.70703125
0.65625
0.6953125
0.6796875
0.6777344
