In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf

In [2]:
# Load MNIST. The VAE is unsupervised, so train and test images are merged into
# a single training array (kept under the name X_train used by later cells);
# the labels are not used.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# Cast to float32 (the dtype of the model's image placeholder) before scaling:
# this halves memory compared with the default float64 promotion of `/ 255.`.
# Dividing by 255 maps 8-bit pixel values to [0, 1] (assumes uint8 input).
X_train = np.concatenate((X_train, X_test)).astype(np.float32) / 255.

In [3]:
class Encoder:
    """Encodes images into a diagonal-Gaussian posterior over the latent space.

    Both ``encode_conv`` and ``encode_mlp`` return ``(sample_latent, mean,
    log_var)``: the posterior mean, its log-variance, and one sample drawn with
    the reparameterization trick ``mean + exp(log_var / 2) * eps`` where
    ``eps ~ N(0, I)``. Both paths return a log-variance, so they are
    interchangeable with a KL term written in terms of ``log_var``.

    Keras layers are created on first use and cached on the instance. Calling
    an ``encode_*`` method more than once therefore reuses the same weights
    instead of silently building a new, untrained network.
    """

    def __init__(self, dim_latent):
        """
        Args:
            dim_latent: dimensionality of the latent vector.
        """
        self.dim_latent = dim_latent
        # Cache key -> Keras layer instance (see ``_layer``).
        self._layers = {}

    def _layer(self, key, factory):
        """Return the layer cached under ``key``, building it with ``factory()`` on first use."""
        if key not in self._layers:
            self._layers[key] = factory()
        return self._layers[key]

    @staticmethod
    def _reparameterize(mean, log_var):
        """Sample ``mean + exp(log_var / 2) * eps`` with ``eps ~ N(0, I)``."""
        sample_normal = tf.random_normal(tf.shape(log_var))
        return mean + tf.exp(log_var / 2) * sample_normal

    def encode_conv(self, input_image):
        """Convolutional encoder.

        Args:
            input_image: image batch; presumably (batch, 28, 28, channels) —
                TODO confirm (the valid-padded conv stack maps 28x28 to 4x4).

        Returns:
            ``(sample_latent, mean, log_var)``, each with a last axis of size
            ``dim_latent``.
        """
        with tf.variable_scope("encoder"):
            x = self._layer("conv/conv1", lambda: tf.keras.layers.Conv2D(
                filters=32, kernel_size=(4, 4), strides=(2, 2), padding="valid",
                activation=tf.nn.relu, name="conv1"))(input_image)
            x = self._layer("conv/conv2", lambda: tf.keras.layers.Conv2D(
                filters=64, kernel_size=(2, 2), strides=(2, 2), padding="valid",
                activation=tf.nn.relu, name="conv2"))(x)
            x = self._layer("conv/conv3", lambda: tf.keras.layers.Conv2D(
                filters=64, kernel_size=(3, 3), strides=(1, 1), padding="valid",
                activation=tf.nn.relu, name="conv3"))(x)

            flat = self._layer("conv/flatten", tf.keras.layers.Flatten)(x)

            mean = self._layer("conv/mean", lambda: tf.keras.layers.Dense(
                units=self.dim_latent, name="mean"))(flat)
            # Predict the log-variance rather than a raw std: exp() keeps the
            # implied std positive, and it matches what encode_mlp returns.
            log_var = self._layer("conv/log_var", lambda: tf.keras.layers.Dense(
                units=self.dim_latent, name="log_var"))(flat)

            return self._reparameterize(mean, log_var), mean, log_var

    def encode_mlp(self, input_image):
        """Fully-connected encoder (one 256-unit hidden layer).

        Args:
            input_image: image batch; flattened per example before the dense layers.

        Returns:
            ``(sample_latent, mean, log_var)``, each with a last axis of size
            ``dim_latent``.
        """
        with tf.variable_scope("encoder"):
            flat = self._layer("mlp/flatten", tf.keras.layers.Flatten)(input_image)

            x = self._layer("mlp/fc_encoder1", lambda: tf.keras.layers.Dense(
                units=256, activation=tf.nn.relu, name="fc_encoder1"))(flat)

            mean = self._layer("mlp/mean", lambda: tf.keras.layers.Dense(
                units=self.dim_latent, name="mean"))(x)
            # This layer outputs the log-variance. Its Keras name "std" is kept
            # unchanged so variable names (and presumably existing checkpoints)
            # are not affected.
            log_var = self._layer("mlp/log_var", lambda: tf.keras.layers.Dense(
                units=self.dim_latent, name="std"))(x)

            return self._reparameterize(mean, log_var), mean, log_var

In [4]:
class Decoder:
    """Decodes latent vectors into image batches of shape (-1, 28, 28, 1).

    Both ``decode_*`` methods end in a sigmoid, so pixel values lie in (0, 1).

    Keras layers are created on first use and cached on the instance, so every
    call to a ``decode_*`` method shares the same weights. Without this, a
    second call (e.g. decoding prior samples for visualisation) would build a
    fresh, untrained decoder.
    """

    def __init__(self, dim_latent=None):
        """
        Args:
            dim_latent: optional latent size, stored for symmetry with
                ``Encoder``. The dense layers infer their input size from the
                tensor they are called on, so it is not required.
        """
        self.dim_latent = dim_latent
        # Cache key -> Keras layer instance (see ``_layer``).
        self._layers = {}

    def _layer(self, key, factory):
        """Return the layer cached under ``key``, building it with ``factory()`` on first use."""
        if key not in self._layers:
            self._layers[key] = factory()
        return self._layers[key]

    def decode_conv(self, latent_vector):
        """Transposed-convolution decoder.

        The 16-unit projection is reshaped to a 4x4x1 map and upsampled by
        valid-padded transposed convolutions (4 -> 6 -> 12 -> 26). A dense layer
        then produces the 28x28 output.

        Args:
            latent_vector: batch of latent vectors.

        Returns:
            Reconstruction of shape (-1, 28, 28, 1) with values in (0, 1).
        """
        with tf.variable_scope("decoder"):
            x = self._layer("conv/fc_decoder", lambda: tf.keras.layers.Dense(
                units=16, name="fc_decoder"))(latent_vector)

            x = tf.reshape(x, (-1, 4, 4, 1))

            x = self._layer("conv/deconv1", lambda: tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=(3, 3), strides=(1, 1), padding="valid",
                activation=tf.nn.relu, name="deconv1"))(x)
            x = self._layer("conv/deconv2", lambda: tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=(2, 2), strides=(2, 2), padding="valid",
                activation=tf.nn.relu, name="deconv2"))(x)
            x = self._layer("conv/deconv3", lambda: tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=(4, 4), strides=(2, 2), padding="valid",
                activation=tf.nn.relu, name="deconv3"))(x)

            flat = self._layer("conv/flatten", tf.keras.layers.Flatten)(x)
            # Distinct name: the input projection above is already "fc_decoder".
            x = self._layer("conv/fc_decoder_out", lambda: tf.keras.layers.Dense(
                units=28 * 28, activation=tf.nn.sigmoid,
                name="fc_decoder_out"))(flat)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))

            return reconstruction

    def decode_mlp(self, latent_vector):
        """Fully-connected decoder (one 256-unit hidden layer).

        Args:
            latent_vector: batch of latent vectors.

        Returns:
            Reconstruction of shape (-1, 28, 28, 1) with values in (0, 1).
        """
        with tf.variable_scope("decoder"):
            x = self._layer("mlp/fc_decoder1", lambda: tf.keras.layers.Dense(
                units=256, activation=tf.nn.relu,
                name="fc_decoder1"))(latent_vector)

            x = self._layer("mlp/fc_decoder3", lambda: tf.keras.layers.Dense(
                units=28 * 28, activation=tf.nn.sigmoid,
                name="fc_decoder3"))(x)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))

            return reconstruction

In [5]:
class VAE:
    """Variational auto-encoder on single-channel images, built as a TF1 graph.

    Shuffled mini-batches come from an initializable ``tf.data`` iterator fed
    with the whole image array. They are encoded with ``Encoder.encode_mlp``
    and decoded with ``decode_mlp``. The loss is binary cross-entropy plus a
    weighted KL term.

    Placeholders fed at run time:
        * ``original_image``, ``batch_size`` — when initialising the iterator;
        * ``learning_rate``, ``coeff_latent_loss`` — on every training step.
    """

    def __init__(self, input_im_shape, dim_latent):
        """
        Args:
            input_im_shape: shape of one image without a channel axis, e.g.
                ``(28, 28)``; a channel axis is appended inside the graph.
            dim_latent: dimensionality of the latent space.
        """
        self.input_im_shape = input_im_shape
        self.dim_latent = dim_latent

        self.encoder = Encoder(dim_latent)
        # The decoder's dense layers infer their input size from the latent
        # tensor, so no dimension argument is passed.
        self.decoder = Decoder()

        # Input pipeline: the full image array is fed once, through this
        # placeholder, when the iterator is initialised (see ``train``).
        self.original_image = tf.placeholder(tf.float32, (None, *self.input_im_shape),
                                             name="original_image")
        self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
        self.dataset = (tf.data.Dataset.from_tensor_slices(self.original_image)
                        .shuffle(10000)
                        .batch(self.batch_size)
                        .repeat())
        self.iterator = self.dataset.make_initializable_iterator()

        # Append a channel axis: (batch, *input_im_shape) -> (batch, *input_im_shape, 1).
        self.original_image_exp = tf.expand_dims(self.iterator.get_next(), -1)

        self.latent_vec, mean, log_var = self.encoder.encode_mlp(self.original_image_exp)

        self.reconstruction = self.decoder.decode_mlp(self.latent_vec)

        # Losses: per-pixel binary cross-entropy (the decoder ends in a sigmoid)
        # plus KL(N(mean, exp(log_var)) || N(0, I)), both averaged over all
        # elements. The KL weight is fed per step (see ``_kl_coefficient``).
        self.reconstruction_loss = tf.reduce_mean(
            tf.keras.backend.binary_crossentropy(self.original_image_exp, self.reconstruction))

        self.coeff_latent_loss = tf.placeholder(tf.float32, None, name="coeff_latent_loss")
        self.latent_loss = 0.5 * tf.reduce_mean(mean ** 2 + tf.exp(log_var) - log_var - 1)

        self.loss = self.reconstruction_loss + self.coeff_latent_loss * self.latent_loss

        # Optimisation
        self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = self.optimizer.minimize(self.loss)

        # Decode 16 draws from the N(0, I) prior for TensorBoard.
        # NOTE(review): this is a second decode_mlp call; it shows the *trained*
        # decoder only if Decoder reuses its layers across calls.
        self.latent_samples = tf.random_normal((16, self.dim_latent))
        self.reconstructions_from_samples = self.decoder.decode_mlp(self.latent_samples)

        # Summaries
        tf.summary.scalar("reconstruction_loss", self.reconstruction_loss)
        tf.summary.scalar("latent_loss", self.latent_loss)
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("mean", tf.reduce_mean(mean))
        tf.summary.scalar("log_var", tf.reduce_mean(log_var))
        tf.summary.image("train_images", self.original_image_exp, 16)
        tf.summary.image("reconstructions", self.reconstruction, 16)
        tf.summary.image("reconstructions_from_samples", self.reconstructions_from_samples, 16)
        self.merged_summaries = tf.summary.merge_all()

    @staticmethod
    def _kl_coefficient(step, kl_max=0.5, kl_warmup_steps=100000):
        """Linear KL warm-up: ramps from 0 to ``kl_max`` over ``kl_warmup_steps`` steps, then stays flat."""
        return min(kl_max * step / kl_warmup_steps, kl_max)

    def train(self, X_train, batch_size, nb_steps, learning_rate, save_every, sess,
              ckpt_path="./model/model.ckpt", log_dir="./tensorboard/",
              kl_max=0.5, kl_warmup_steps=100000):
        """Run ``nb_steps`` Adam steps, checkpointing and logging every ``save_every`` steps.

        Args:
            X_train: array of images matching ``original_image``'s shape.
            batch_size: mini-batch size fed to the iterator.
            nb_steps: number of optimisation steps.
            learning_rate: Adam learning rate.
            save_every: steps between checkpoints / TensorBoard summaries.
            sess: ``tf.Session`` whose variables are already initialised.
            ckpt_path: checkpoint path passed to ``Saver.save`` (parent dir is created).
            log_dir: TensorBoard event directory.
            kl_max, kl_warmup_steps: KL warm-up schedule (see ``_kl_coefficient``).
        """
        import os

        saver = tf.train.Saver()
        ckpt_dir = os.path.dirname(ckpt_path)
        if ckpt_dir:
            # Saver.save does not create missing parent directories.
            os.makedirs(ckpt_dir, exist_ok=True)
        summary_writer = tf.summary.FileWriter(log_dir, sess.graph)

        try:
            sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                           self.batch_size: batch_size})

            for step in range(1, nb_steps + 1):
                feed = {self.learning_rate: learning_rate,
                        self.coeff_latent_loss: self._kl_coefficient(step, kl_max, kl_warmup_steps)}

                if step % save_every == 0:
                    # Summaries (incl. image summaries) are only evaluated when
                    # they are actually written.
                    _, summaries = sess.run([self.train_op, self.merged_summaries],
                                            feed_dict=feed)
                    print("Save and write summaries")
                    saver.save(sess, ckpt_path)
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.train_op, feed_dict=feed)
        finally:
            # Flush pending events even if training is interrupted (e.g. KeyboardInterrupt).
            summary_writer.close()

    def restore(self, ckpt_file, sess):
        """Restore all saveable variables of the current graph from ``ckpt_file`` into ``sess``."""
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)

In [6]:
# Model / data hyper-parameters
im_shape, dim_latent = (28, 28), 16    # image (H, W) without channel axis; latent size

# Optimisation hyper-parameters
batch_size, learning_rate = 256, 4e-4  # mini-batch size; Adam step size
save_every = 2500                      # steps between checkpoints and TensorBoard summaries

In [7]:
# Build on a fresh default graph. Re-running this cell would otherwise add a
# second VAE to the same graph, and tf.summary.merge_all() would then also
# collect the first model's summaries (whose placeholders are never fed).
tf.reset_default_graph()
vae = VAE(im_shape, dim_latent)

Instructions for updating:
Colocations handled automatically by placer.


In [8]:
with tf.Session() as sess:
    # Fresh variables, then train; VAE.train checkpoints and writes summaries
    # every `save_every` steps.
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    vae.train(X_train,
              batch_size=batch_size,
              nb_steps=400000,
              learning_rate=learning_rate,
              save_every=save_every,
              sess=sess)

Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries


KeyboardInterrupt: 

In [None]:
# Random sampling from the latent space: decode draws from the N(0, I) prior.
nb_samples = 64
rng = np.random.RandomState(0)  # fixed seed so the figure is reproducible
with tf.Session() as sess:
    vae.restore("./model/model.ckpt", sess)
    latent_samples = rng.randn(nb_samples, dim_latent)

    # Feeding the intermediate latent tensor short-circuits the encoder and the
    # data iterator, so only the decoder is evaluated.
    generated_images = sess.run(vae.reconstruction,
                                feed_dict={vae.latent_vec: latent_samples})

# One grid figure instead of nb_samples separate figures.
n_cols = int(np.ceil(np.sqrt(nb_samples)))
n_rows = int(np.ceil(nb_samples / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols, n_rows))
for ax in np.ravel(axes):
    ax.axis("off")
for ax, im in zip(np.ravel(axes), generated_images):
    ax.imshow(np.squeeze(im), vmin=0, vmax=1, cmap="gray")
fig.suptitle("Samples decoded from the N(0, I) prior")
plt.show()