In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow

In [2]:
import tensorflow as tf
# Load MNIST through Keras; load_data() returns the train and test splits,
# each unpacked here into an (inputs, labels) pair.
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()

In [3]:
class Encoder:
    """Encoder half of the VAE, built with TF1 graph-mode Keras layers.

    Two interchangeable architectures are offered (`encode_conv` and
    `encode_mlp`). Both create their layers inside the "encoder" variable
    scope and finish with the same pair of linear heads ("mean" and "std")
    followed by a sampling step.
    """

    def __init__(self, dim_latent):
        # Number of units of the "mean" and "std" heads, i.e. the size of
        # the latent vector.
        self.dim_latent = dim_latent

    def _gaussian_head(self, features):
        """Add the "mean"/"std" heads on `features` and draw a latent sample.

        Meant to be called from inside the "encoder" variable scope.

        Returns:
            (sample_latent, mean, std) where
            sample_latent = mean + std * noise and noise is drawn with
            tf.random_normal using the dynamic shape of `std`.
        """
        mean = tf.keras.layers.Dense(units=self.dim_latent, name="mean")(features)
        std = tf.keras.layers.Dense(units=self.dim_latent, name="std")(features)

        # Sampling is written as a deterministic function of (mean, std)
        # plus external noise, so both heads appear in the sample's graph.
        noise = tf.random_normal(tf.shape(std))
        return mean + std * noise, mean, std

    def encode_conv(self, input_image):
        """Convolutional encoder: three ReLU Conv2D layers, then the heads.

        Args:
            input_image: image batch tensor fed to the first Conv2D layer
                (presumably (batch, height, width, channels) -- confirm
                against the caller).

        Returns:
            (sample_latent, mean, std), see `_gaussian_head`.
        """
        # One row per layer: (filters, kernel_size, strides, name).
        conv_stack = (
            (32, (4, 4), (2, 2), "conv1"),
            (64, (2, 2), (2, 2), "conv2"),
            (64, (3, 3), (1, 1), "conv3"),
        )

        with tf.variable_scope("encoder"):
            feature_map = input_image
            for n_filters, kernel, stride, layer_name in conv_stack:
                conv = tf.keras.layers.Conv2D(filters=n_filters,
                                              kernel_size=kernel,
                                              strides=stride,
                                              padding="valid",
                                              activation=tf.nn.relu,
                                              name=layer_name)
                feature_map = conv(feature_map)

            flattened = tf.keras.layers.Flatten()(feature_map)
            return self._gaussian_head(flattened)

    def encode_mlp(self, input_image):
        """Fully connected encoder: Flatten, Dense(512), Dense(128), heads.

        Args:
            input_image: image batch tensor; it is flattened before the
                first dense layer.

        Returns:
            (sample_latent, mean, std), see `_gaussian_head`.
        """
        # One row per hidden layer: (units, name).
        hidden_stack = ((512, "fc_encoder1"), (128, "fc_encoder2"))

        with tf.variable_scope("encoder"):
            hidden = tf.keras.layers.Flatten()(input_image)
            for width, layer_name in hidden_stack:
                hidden = tf.keras.layers.Dense(units=width,
                                               activation=tf.nn.relu,
                                               name=layer_name)(hidden)
            return self._gaussian_head(hidden)

In [4]:
class Decoder:
    """Decoder half of the VAE, built with TF1 graph-mode Keras layers.

    Two interchangeable architectures are offered (`decode_conv` and
    `decode_mlp`). Both create their layers inside the "decoder" variable
    scope and return a sigmoid-activated tensor reshaped to (-1, 28, 28, 1).
    """

    def __init__(self, dim_latent):
        # Stored for symmetry with Encoder; the decode methods take the
        # latent tensor directly and do not read this attribute.
        self.dim_latent = dim_latent

    def decode_conv(self, latent_vector):
        """Transposed-convolution decoder.

        The latent vector is projected to 16 units, viewed as a 4x4
        single-channel map, passed through three ReLU Conv2DTranspose
        layers, flattened, and mapped by a sigmoid dense layer to 28*28
        values.

        Args:
            latent_vector: latent tensor fed to the first dense layer
                (presumably (batch, dim_latent) -- confirm against caller).

        Returns:
            Reconstruction tensor reshaped to (-1, 28, 28, 1).
        """
        with tf.variable_scope("decoder"):
            x = tf.keras.layers.Dense(units=16,
                                      name="fc_decoder")(latent_vector)

            x = tf.reshape(x, (-1, 4, 4, 1))

            x = tf.keras.layers.Conv2DTranspose(filters=64,
                                                kernel_size=(3,3),
                                                strides=(1,1),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv1")(x)
            x = tf.keras.layers.Conv2DTranspose(filters=64,
                                                kernel_size=(2,2),
                                                strides=(2,2),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv2")(x)
            x = tf.keras.layers.Conv2DTranspose(filters=32,
                                                kernel_size=(4,4),
                                                strides=(2,2),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv3")(x)

            flat = tf.keras.layers.Flatten()(x)
            # This output layer used to be named "fc_decoder" as well,
            # clashing with the input projection above. It now has its own
            # name so graph, TensorBoard and checkpoint entries are
            # unambiguous.
            x = tf.keras.layers.Dense(units=28*28,
                                      activation=tf.nn.sigmoid,
                                      name="fc_decoder_out")(flat)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))

            return reconstruction

    def decode_mlp(self, latent_vector):
        """Fully connected decoder: Dense(128), Dense(512), Dense(28*28).

        The two hidden layers use ReLU; the output layer uses a sigmoid.

        Args:
            latent_vector: latent tensor fed to the first dense layer.

        Returns:
            Reconstruction tensor reshaped to (-1, 28, 28, 1).
        """
        with tf.variable_scope("decoder"):
            x = tf.keras.layers.Dense(units=128,
                                      activation=tf.nn.relu,
                                      name="fc_decoder1")(latent_vector)

            x = tf.keras.layers.Dense(units=512,
                                      activation=tf.nn.relu,
                                      name="fc_decoder2")(x)

            x = tf.keras.layers.Dense(units=28*28,
                                      activation=tf.nn.sigmoid,
                                      name="fc_decoder3")(x)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))

            return reconstruction

In [None]:
class VAE:
    """Variational auto-encoder graph (TF1 graph mode) with a training loop.

    Constructing an instance adds to the default graph: the input
    placeholder, the MLP encoder and decoder, the reconstruction and latent
    losses, an Adam training op and the merged TensorBoard summaries.
    """

    def __init__(self, input_im_shape, dim_latent):
        """Build the graph.

        Args:
            input_im_shape: shape of one input image without batch and
                channel axes; a trailing channel axis is appended here.
            dim_latent: size of the latent vector.
        """
        # Remove self where not needed
        self.input_im_shape = input_im_shape
        self.dim_latent = dim_latent

        self.encoder = Encoder(dim_latent)
        self.decoder = Decoder(dim_latent)

        self.original_image = tf.placeholder(tf.float32, (None, *(self.input_im_shape)), name="original_image")
        # Append a channel axis before handing the images to the encoder.
        self.original_image_exp = tf.expand_dims(self.original_image, -1)

        self.latent_vec, mean, std = self.encoder.encode_mlp(self.original_image_exp)

        self.reconstruction = self.decoder.decode_mlp(self.latent_vec)

        # Losses
        # The weight of the latent term is fed at run time so the training
        # loop can anneal it.
        self.coeff_latent_loss = tf.placeholder(tf.float32, (), name="coeff_latent_loss")
        self.reconstruction_loss = tf.reduce_mean(tf.math.squared_difference(self.reconstruction,
                                                                             self.original_image_exp))
        # KL divergence between N(mean, std^2) and N(0, 1), averaged over
        # all elements.
        self.latent_loss = 0.5 * tf.reduce_mean(mean ** 2 + std ** 2 - tf.log(std ** 2) - 1)
        self.loss = self.reconstruction_loss + self.coeff_latent_loss * self.latent_loss

        # Optimization
        self.learning_rate = tf.placeholder(tf.float32, (), name="learning_rate")
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = self.optimizer.minimize(self.loss)

        # Summaries
        tf.summary.scalar("reconstruction_loss", self.reconstruction_loss)
        tf.summary.scalar("latent_loss", self.latent_loss)
        tf.summary.scalar("loss", self.loss)
        tf.summary.image("reconstructions", self.reconstruction, 16)
        self.merged_summaries = tf.summary.merge_all()

    def gen_input(self, X_train, nb_steps, batch_size):  # TODO clean
        """Yield `nb_steps` consecutive mini-batches of `X_train` / 255.

        Batches are taken in order, without shuffling. When the next full
        batch would run past the end of `X_train`, iteration wraps around
        to the start, so a trailing partial batch is never yielded. If
        `batch_size` exceeds `len(X_train)`, every batch is the whole array.

        Args:
            X_train: sliceable array of samples supporting division by a
                scalar (e.g. a numpy array).
            nb_steps: number of batches to yield.
            batch_size: number of samples per batch.
        """
        bin_batch = 0
        for _ in range(nb_steps):
            batch = X_train[bin_batch * batch_size : (bin_batch + 1) * batch_size] / 255
            # `<=` rather than `<`: a following batch that ends exactly at
            # len(X_train) is still complete and must not be skipped.
            bin_batch = bin_batch + 1 if (bin_batch + 2) * batch_size <= len(X_train) else 0
            yield batch

    def train(self, X_train, batch_size, nb_steps, learning_rate, sess):
        """Run `nb_steps` optimisation steps inside `sess`.

        The latent-loss coefficient starts at 0, stays there for the first
        2000 steps, and then grows by 1 / (5 * (nb_steps - 2000)) per step.
        Every 1000 steps the model is saved to "./model/model.ckpt" and the
        merged summaries of that step are written to "./tensorboard/".

        Args:
            X_train: training images, passed to `gen_input`.
            batch_size: samples per step.
            nb_steps: total number of training steps.
            learning_rate: value fed to the learning-rate placeholder.
            sess: an open tf.Session whose variables are initialised.
        """
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter("./tensorboard/", sess.graph)

        coeff_latent_loss = 0.
        try:
            batches = self.gen_input(X_train, nb_steps, batch_size)
            for step, batch in enumerate(batches, start=1):
                feed_dict = {self.original_image: batch,
                             self.learning_rate: learning_rate,
                             self.coeff_latent_loss: coeff_latent_loss}

                if step % 1000 == 0:
                    # Summaries (including the image summary) are only
                    # evaluated on the steps where they are written.
                    _, summaries = sess.run([self.train_op, self.merged_summaries],
                                            feed_dict=feed_dict)
                    # Save and write summaries
                    saver.save(sess, "./model/model.ckpt")
                    summary_writer.add_summary(summaries, step)
                else:
                    sess.run(self.train_op, feed_dict=feed_dict)

                if step > 2000:  # TODO clean
                    coeff_latent_loss += 1 / (5 * (nb_steps - 2000))
        finally:
            # Flush pending events and release the event file even if
            # training is interrupted.
            summary_writer.close()

In [None]:
# Hyper-parameters used by the cells below.
im_shape = (28, 28)  # per-image shape without a channel axis (VAE appends one)
dim_latent = 32  # size of the latent vector
batch_size = 128  # samples per training step
learning_rate = 1e-4  # fed to the Adam optimizer's learning-rate placeholder

In [None]:
# Build the VAE graph (placeholders, encoder/decoder, losses, summaries).
vae = VAE(im_shape, dim_latent)

Instructions for updating:
Colocations handled automatically by placer.


In [None]:
# Train in a fresh session: initialise every graph variable, then run
# 20000 training steps on the MNIST training images.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    vae.train(X_train, batch_size, 20000, learning_rate, sess)