# VQ-VAE
(vector quantised variational auto-encoder)

Based on DeepMind's [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937)

In [1]:
import tensorflow as tf
import os

# Load MNIST through the TensorFlow tutorial helper, using "MNIST_data/" as the
# data directory. Labels are requested one-hot encoded (one_hot=True).
# NOTE: the recorded output of this cell shows deprecation notices for this
# loader that recommend tf.data instead.
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [2]:
# --- data geometry ---
IMG_DIM = 28                      # side length of one (square) image
INPUT_SIZE = IMG_DIM * IMG_DIM    # length of one flattened image (784)
BATCH_SIZE = 250                  # number of images fed per training step

# --- vector-quantisation settings ---
CODE_SIZE = 32          # dimensionality of a single vector in the embedding space e
EMBEDDING_COUNT = 128   # number of vectors in the embedding space
BETA = 0.1              # weight of the encoder-to-embedding matching term in the encoder loss

In [3]:
# Dedicated graph for this model: all ops are added under `graph.as_default()`
# and the training session is created with `graph=graph`.
graph = tf.Graph()

def tf_summary_image(name, flat_tensor):
    """Register an image summary called `name` for a batch of flattened images.

    `flat_tensor` is reshaped to [-1, IMG_DIM, IMG_DIM] and given a trailing
    axis of size 1 before being handed to `tf.summary.image`. The summary op is
    only registered with the graph; nothing is returned.
    """
    square_images = tf.reshape(flat_tensor, [-1, IMG_DIM, IMG_DIM])
    single_channel = tf.expand_dims(square_images, -1)
    tf.summary.image(name, single_channel)

def get_params(x):
    """Return the `[kernel, bias]` variables of the dense layer that produced `x`.

    The layer path is recovered from the tensor name by dropping its first
    '/'-separated component and its last one (the op name). The variables are
    then fetched with `tf.get_variable`, so this must be called while the
    variable scope the layer was built in is still active and allows reuse.

    NOTE(review): dropping exactly one leading component assumes the layer
    sits one scope level deep, as it does everywhere in this notebook.
    """
    layer_path = '/'.join(x.name.split('/')[1:-1])
    return [tf.get_variable(layer_path + suffix) for suffix in ('/kernel', '/bias')]

# Build the whole VQ-VAE model (encoder -> nearest-embedding lookup -> decoder),
# its losses, the three training ops and the summary/init ops inside `graph`.
with graph.as_default():
    # The batch dimension is fixed, so every feed must contain exactly
    # BATCH_SIZE flattened images of length INPUT_SIZE.
    images = tf.placeholder(tf.float32, shape=[BATCH_SIZE, INPUT_SIZE], name='images')
    
    tf_summary_image('input', images)
    
    # Encoder: two dense layers mapping an image to a CODE_SIZE vector.
    # The tanh on the output bounds every component to (-1, 1).
    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        encoder_dense1 = tf.layers.dense(images, 128, activation=tf.nn.relu, name='dense1')
        encoder_out = tf.layers.dense(encoder_dense1, CODE_SIZE, activation=tf.tanh, name='out')
        
        # get_params looks variables up with tf.get_variable, so it has to be
        # called while the 'encoder' scope is still open.
        encoder_vars = get_params(encoder_dense1) + get_params(encoder_out)
    
    # Vector quantisation: replace each encoder output by its nearest embedding.
    with tf.variable_scope('embedding'):
        # Table of EMBEDDING_COUNT embedding vectors, each of length CODE_SIZE.
        embedding_space = tf.get_variable(name='space', shape=[EMBEDDING_COUNT, CODE_SIZE], initializer=tf.random_normal_initializer)
        # Both operands are expanded to [BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE] so that
        # every encoder output can be compared against every embedding vector:
        # the table is repeated per batch row, each encoder output is repeated per embedding.
        embedding_space_batch = tf.reshape(tf.tile(embedding_space, [BATCH_SIZE, 1]), [BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE])
        encoder_tiled = tf.reshape(tf.tile(encoder_out, [1, EMBEDDING_COUNT]), [BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE])
        differences = tf.subtract(embedding_space_batch, encoder_tiled)
        # Squared L2 distance (no square root) summed over the code dimension.
        l2_distances = tf.reduce_sum(tf.square(differences), axis=2, name='l2_distances')
        # Index of the closest embedding for each batch row.
        e_index = tf.argmin(l2_distances, axis=1, name='e_index')
        code = tf.gather(embedding_space, e_index, axis=0, name='lookup_result')
        # Straight-through estimator: the forward value equals `code`, but because the
        # difference is wrapped in stop_gradient, gradients flow only into `encoder_out`.
        code_stop_grad = tf.stop_gradient(code - encoder_out) + encoder_out

    # Decoder: two dense layers mapping the quantised code back to INPUT_SIZE values.
    with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
        decoder_dense1 = tf.layers.dense(code_stop_grad, 128, activation=tf.nn.relu, name='dense1')
        # NOTE(review): tanh bounds the reconstruction to (-1, 1); presumably the MNIST
        # pixels fed in are in [0, 1] - confirm that this output range is intended.
        reconstruction = tf.layers.dense(decoder_dense1, INPUT_SIZE, activation=tf.tanh, name='reconstruction')
        tf_summary_image('reconstruction', reconstruction)
        
        # Same constraint as for the encoder: collect the variables inside the scope.
        decoder_vars = get_params(decoder_dense1) + get_params(reconstruction)
        
    # losses
    # Mean squared error between the input images and their reconstruction.
    loss_enc_dec = tf.reduce_mean(tf.square(tf.subtract(images, reconstruction)))
    # Mean squared error between the encoder output and the embedding selected for it.
    # No stop_gradient is applied here; which side gets updated is decided by the
    # `var_list` of each optimizer below.
    loss_enc_vq = tf.reduce_mean(tf.square(tf.subtract(encoder_out, code)))
    
    tf.summary.scalar('reconstruction_loss', loss_enc_dec)
    tf.summary.scalar('embedding_loss', loss_enc_vq)
    
    # optimizer
    # Three independent SGD steps, each restricted to one group of variables:
    #   encoder    - reconstruction loss plus BETA times the embedding-matching loss
    #   decoder    - reconstruction loss only
    #   embeddings - embedding-matching loss only, with a smaller learning rate
    opt_encoder = tf.train.GradientDescentOptimizer(1e-2).minimize(loss_enc_dec + BETA * loss_enc_vq, var_list=encoder_vars)
    opt_decoder = tf.train.GradientDescentOptimizer(1e-2).minimize(loss_enc_dec, var_list=decoder_vars)
    opt_embeddings = tf.train.GradientDescentOptimizer(5e-4).minimize(loss_enc_vq, var_list=[embedding_space])
    
    # misc
    # One op that initialises both global and local variables, and one op that
    # evaluates all summaries registered above.
    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    summary_merged = tf.summary.merge_all()

In [4]:
# Sanity check: static shape of the tiled encoder output built in the graph cell.
encoder_tiled.shape

TensorShape([Dimension(250), Dimension(128), Dimension(32)])

In [5]:
MODEL_NAME = 'vqvae_007c'  # also names the TensorBoard log sub-directory
NO_STEPS = 10000           # number of training steps (one batch each)

# The session is intentionally left open after training so that cells further
# down the notebook can keep evaluating tensors with the trained weights.
sess = tf.Session(graph=graph)
sess.run(init)
log_writer = tf.summary.FileWriter('tf_logs/%s' % MODEL_NAME, sess.graph)

# Every step evaluates the merged summaries and applies all three updates
# (encoder, decoder, embeddings) on the same batch.
train_fetches = [summary_merged, opt_encoder, opt_decoder, opt_embeddings]

try:
    for step in range(NO_STEPS):
        # Labels are returned by the loader but are not used for training.
        batch_images, batch_labels = mnist.train.next_batch(BATCH_SIZE)
        summary, _, _, _ = sess.run(
            train_fetches,
            feed_dict={
                images: batch_images
            }
        )
        log_writer.add_summary(summary, global_step=step)
finally:
    # FileWriter buffers events and writes them asynchronously; flush so the
    # last steps reach disk even when the loop is interrupted from the notebook.
    log_writer.flush()