In [1]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow

In [2]:
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.concatenate((X_train, X_test)) / 255.

In [3]:
class Encoder:
    
    def __init__(self, dim_latent):
        self.dim_latent = dim_latent
        
    def encode_conv(self, input_image):
        with tf.variable_scope("encoder"):
            x = tf.keras.layers.Conv2D(filters=32,
                                       kernel_size=(4,4),
                                       strides=(2,2),
                                       padding="valid",
                                       activation=tf.nn.relu,
                                       name="conv1")(input_image)
            x = tf.keras.layers.Conv2D(filters=64,
                                       kernel_size=(2,2),
                                       strides=(2,2),
                                       padding="valid",
                                       activation=tf.nn.relu,
                                       name="conv2")(x)
            x = tf.keras.layers.Conv2D(filters=64,
                                       kernel_size=(3,3),
                                       strides=(1,1),
                                       padding="valid",
                                       activation=tf.nn.relu,
                                       name="conv3")(x)

            flat = tf.keras.layers.Flatten()(x)

            mean = tf.keras.layers.Dense(units=self.dim_latent,
                                         name="mean")(flat)
            std = tf.keras.layers.Dense(units=self.dim_latent,
                                        name="std")(flat)
            
            sample_normal = tf.random_normal(tf.shape(std))
            
            sample_latent = mean + std * sample_normal
            
            return sample_latent, mean, std
        
    def encode_mlp(self, input_image):
        with tf.variable_scope("encoder"):
            flat = tf.keras.layers.Flatten()(input_image)

            
            x = tf.keras.layers.Dense(units=256,
                                      activation=tf.nn.relu,
                                      name="fc_encoder1")(flat)
            
            mean = tf.keras.layers.Dense(units=self.dim_latent,
                                         name="mean")(x)
            log_var = tf.keras.layers.Dense(units=self.dim_latent,
                                        name="std")(x)
            
            sample_normal = tf.random_normal(tf.shape(log_var))
            
            sample_latent = mean + tf.exp(log_var / 2) * sample_normal
            
            return sample_latent, mean, log_var

In [4]:
class Decoder:
    
    def __init__(self, dim_latent):
        self.dim_latent = dim_latent
        
    def decode_conv(self, latent_vector):
        with tf.variable_scope("decoder"):
            x = tf.keras.layers.Dense(units=16,
                                      name="fc_decoder")(latent_vector)
            
            x = tf.reshape(x, (-1, 4, 4, 1))
            
            x = tf.keras.layers.Conv2DTranspose(filters=64,
                                                kernel_size=(3,3),
                                                strides=(1,1),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv1")(x)
            x = tf.keras.layers.Conv2DTranspose(filters=64,
                                                kernel_size=(2,2),
                                                strides=(2,2),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv2")(x)
            x = tf.keras.layers.Conv2DTranspose(filters=32,
                                                kernel_size=(4,4),
                                                strides=(2,2),
                                                padding="valid",
                                                activation=tf.nn.relu,
                                                name="deconv3")(x)

            flat = tf.keras.layers.Flatten()(x)
            x = tf.keras.layers.Dense(units=28*28,
                                      activation=tf.nn.sigmoid,
                                      name="fc_decoder")(flat)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))
            
            return reconstruction
        
    def decode_mlp(self, latent_vector):
        with tf.variable_scope("decoder"):
            x = tf.keras.layers.Dense(units=256,
                                      activation=tf.nn.relu,
                                      name="fc_decoder1")(latent_vector)
            
            x = tf.keras.layers.Dense(units=28*28,
                                      activation=tf.nn.sigmoid,
                                      name="fc_decoder3")(x)
            reconstruction = tf.reshape(x, (-1, 28, 28, 1))
            
            return reconstruction

In [18]:
class VAE:
    def __init__(self, input_im_shape, dim_latent):
        # Remove self where not needed
        self.input_im_shape = input_im_shape
        self.dim_latent = dim_latent
        
        self.encoder = Encoder(dim_latent)
        self.decoder = Decoder(dim_latent)

        self.original_image = tf.placeholder(tf.float32, (None, *(self.input_im_shape)), name="original_image")

        self.batch_size = tf.placeholder(tf.int64, None, name="batch_size")
        self.dataset = tf.data.Dataset.from_tensor_slices(self.original_image).shuffle(10000).batch(self.batch_size).repeat()
        self.iterator = self.dataset.make_initializable_iterator()

        self.original_image_exp = tf.expand_dims(self.iterator.get_next(), -1)

        self.latent_vec, mean, log_var = self.encoder.encode_mlp(self.original_image_exp)

        self.reconstruction = self.decoder.decode_mlp(self.latent_vec)

        # Losses
#         self.reconstruction_loss = tf.reduce_mean(tf.math.squared_difference(self.reconstruction,
#                                                                              self.original_image_exp))
        self.reconstruction_loss = tf.reduce_mean(tf.keras.backend.binary_crossentropy(self.original_image_exp,
                                                                                       self.reconstruction))
            
        self.coeff_latent_loss = tf.placeholder(tf.float32, None, name="coeff_latent_loss")
        self.latent_loss = 0.5 * tf.reduce_mean(mean ** 2 + tf.exp(log_var) - log_var - 1)
            
        self.loss = self.reconstruction_loss + self.coeff_latent_loss * self.latent_loss

        # Optimization
        self.learning_rate = tf.placeholder(tf.float32, None, name="learning_rate")
        self.optimizer = tf.train.AdamOptimizer(self.learning_rate)
        self.train_op = self.optimizer.minimize(self.loss)

        # Summaries   
        tf.summary.scalar("reconstruction_loss", self.reconstruction_loss)
        tf.summary.scalar("latent_loss", self.latent_loss)
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("mean", mean)
        tf.summary.scalar("log_var", log_var)
        tf.summary.image("train_images", self.original_image_exp, 16)
        tf.summary.image("reconstructions", self.reconstruction, 16)
        self.merged_summaries = tf.summary.merge_all()


    def train(self, X_train, batch_size, nb_steps, learning_rate, sess):
        saver = tf.train.Saver()
        summary_writer = tf.summary.FileWriter("./tensorboard/", sess.graph)

        sess.run(self.iterator.initializer, feed_dict={self.original_image: X_train,
                                                       self.batch_size: batch_size})

        for step in range(1, nb_steps + 1):
            coeff_latent_loss = min(0.15 * step / 100000, 0.15)
#             coeff_latent_loss = 0.15
            
            _, summaries = sess.run([self.train_op, self.merged_summaries],
                                    feed_dict={self.learning_rate: learning_rate,
                                               self.coeff_latent_loss: coeff_latent_loss})

            if step % 5000 == 0:
                print("Save and write summaries")
                saver.save(sess, "./model/model.ckpt")
                summary_writer.add_summary(summaries, step)
                
            if step % 5000 == 0:
                latent_samples = np.random.randn((nb_samples, dim_latent))

                generated_images = sess.run(vae.reconstruction,
                                            feed_dict={vae.latent_vec: latent_samples})
                eval_sum = tf.summary.image("rec_from_sample_latent", generated_images, 16)
                summary_writer.add_summary(eval_sum, step)
        
    def restore(self, ckpt_file, sess):
        saver = tf.train.Saver()
        saver.restore(sess, ckpt_file)

In [6]:
im_shape = (28, 28)
dim_latent = 16
batch_size = 256
learning_rate = 4e-4

In [19]:
vae = VAE(im_shape, dim_latent)

In [8]:
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    vae.train(X_train, batch_size, 300000, learning_rate, sess)

Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries
Save and write summaries


In [21]:
# Random spampling from latent space
nb_samples = 64
with tf.Session() as sess:
    vae.restore("./model/model.ckpt", sess)
    latent_samples = np.random.randn(nb_samples, dim_latent)

    generated_images = sess.run(vae.reconstruction,
                                feed_dict={vae.latent_vec: latent_samples})

    print(np.squeeze(generated_images[0]))

    for im in generated_images:
        plt.figure()
        plt.imshow(np.squeeze(im), vmin=0, vmax=1)
        plt.show()

INFO:tensorflow:Restoring parameters from ./model/model.ckpt


NotFoundError: Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key beta1_power_1 not found in checkpoint
	 [[node save_6/RestoreV2 (defined at <ipython-input-18-009371bbf2b9>:78) ]]
	 [[node save_6/RestoreV2 (defined at <ipython-input-18-009371bbf2b9>:78) ]]

Caused by op 'save_6/RestoreV2', defined at:
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/home/jason/.local/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 505, in start
    self.io_loop.start()
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/platform/asyncio.py", line 132, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 345, in run_forever
    self._run_once()
  File "/usr/lib/python3.5/asyncio/base_events.py", line 1312, in _run_once
    handle._run()
  File "/usr/lib/python3.5/asyncio/events.py", line 125, in _run
    self._callback(*self._args)
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/ioloop.py", line 758, in _run_callback
    ret = callback()
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/stack_context.py", line 300, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/gen.py", line 1233, in inner
    self.run()
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/gen.py", line 1147, in run
    yielded = self.gen.send(value)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 357, in process_one
    yield gen.maybe_future(dispatch(*args))
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 267, in dispatch_shell
    yield gen.maybe_future(handler(stream, idents, msg))
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 534, in execute_request
    user_expressions, allow_stdin,
  File "/home/jason/.local/lib/python3.5/site-packages/tornado/gen.py", line 326, in wrapper
    yielded = next(result)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 294, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/jason/.local/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2843, in run_cell
    raw_cell, store_history, silent, shell_futures)
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2869, in _run_cell
    return runner(coro)
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/async_helpers.py", line 67, in _pseudo_sync_runner
    coro.send(None)
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3044, in run_cell_async
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3209, in run_ast_nodes
    if (yield from self.run_code(code, result)):
  File "/home/jason/.local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3291, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-21-a0f366644799>", line 4, in <module>
    vae.restore("./model/model.ckpt", sess)
  File "<ipython-input-18-009371bbf2b9>", line 78, in restore
    saver = tf.train.Saver()
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 832, in __init__
    self.build()
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 844, in build
    self._build(self._filename, build_save=True, build_restore=True)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 881, in _build
    build_save=build_save, build_restore=build_restore)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 513, in _build_internal
    restore_sequentially, reshape)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 332, in _AddRestoreOps
    restore_sequentially)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 580, in bulk_restore
    return io_ops.restore_v2(filename_tensor, names, slices, dtypes)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/ops/gen_io_ops.py", line 1572, in restore_v2
    name=name)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
    op_def=op_def)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
    return func(*args, **kwargs)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3300, in create_op
    op_def=op_def)
  File "/home/jason/.local/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1801, in __init__
    self._traceback = tf_stack.extract_stack()

NotFoundError (see above for traceback): Restoring from checkpoint failed. This is most likely due to a Variable name or other graph key that is missing from the checkpoint. Please ensure that you have not altered the graph expected based on the checkpoint. Original error:

Key beta1_power_1 not found in checkpoint
	 [[node save_6/RestoreV2 (defined at <ipython-input-18-009371bbf2b9>:78) ]]
	 [[node save_6/RestoreV2 (defined at <ipython-input-18-009371bbf2b9>:78) ]]
