# VQ-VAE
(vector quantised variational auto-encoder)

Based on DeepMind's [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937).

In [3]:
import tensorflow as tf
import os

In [4]:
# Toy dimensions for the walkthrough below.
# The graph feeds the input straight through as the encoder output and compares it
# element-wise with the embedding vectors, so INPUT_SIZE must equal CODE_SIZE.
INPUT_SIZE = 3
BATCH_SIZE = 2

CODE_SIZE = 3  # length of a single vector in e (the embedding space)
EMBEDDING_COUNT = 2  # number of embedding vectors in e

In [16]:
# Nearest-embedding lookup graph: each encoder output row is replaced by the
# embedding vector closest to it in squared Euclidean distance.
graph = tf.Graph()

with graph.as_default():
    images = tf.placeholder(tf.float32, shape=[BATCH_SIZE, INPUT_SIZE], name='images')

    with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
        # Identity stand-in: the input is passed through unchanged as the encoder output.
        encoder_out = images

    with tf.variable_scope('embedding'):
        # Encoder rows are compared element-wise with embedding vectors, so both must
        # have CODE_SIZE entries. Fail here with a readable message rather than with
        # an opaque shape error from one of the ops below.
        encoder_width = encoder_out.shape.as_list()[-1]
        if encoder_width != CODE_SIZE:
            raise ValueError(
                'encoder output width ({}) must equal CODE_SIZE ({})'.format(
                    encoder_width, CODE_SIZE))

        embedding_space = tf.placeholder(tf.float32, shape=[EMBEDDING_COUNT, CODE_SIZE], name='embedding_space')

        # [EMBEDDING_COUNT, CODE_SIZE] -> [1, EMBEDDING_COUNT, CODE_SIZE]
        #                              -> [BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE]
        # Every batch entry sees the full set of embedding vectors.
        embedding_space_batch = tf.tile(tf.expand_dims(embedding_space, 0), [BATCH_SIZE, 1, 1])

        # [BATCH_SIZE, CODE_SIZE] -> [BATCH_SIZE, 1, CODE_SIZE]
        #                         -> [BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE]
        # Each encoder row is repeated once per embedding vector.
        encoder_tiled = tf.tile(tf.expand_dims(encoder_out, 1), [1, EMBEDDING_COUNT, 1])

        differences = tf.subtract(embedding_space_batch, encoder_tiled)
        # Sum of squared differences (no square root is taken); the argmin below is
        # the same as it would be for the true L2 distance.
        l2_distances = tf.reduce_sum(tf.square(differences), axis=2, name='l2_distances')
        # Index of the closest embedding vector for each batch entry.
        e_index = tf.argmin(l2_distances, axis=1, name='e_index')
        # Look up the selected embedding vectors: shape [BATCH_SIZE, CODE_SIZE].
        code = tf.gather(embedding_space, e_index, axis=0, name='lookup_result')

    init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
    # No summaries are registered in this graph, so merge_all() yields None here.
    summary_merged = tf.summary.merge_all()

In [13]:
# Static shape of the tiled encoder output; expected (BATCH_SIZE, EMBEDDING_COUNT, CODE_SIZE).
encoder_tiled.shape

TensorShape([Dimension(2), Dimension(2), Dimension(3)])

In [18]:
# Session bound to the graph built above; run the grouped initializer once up front.
sess = tf.Session(graph=graph)
sess.run(init)

# Hand-picked toy values: one row per batch entry / per embedding vector.
input_val = [
    [1, 2, 3],
    [4, 5, 6],
]
embedding_space_val = [
    [0, 2, 4],
    [1, 3, 5],
]
    

def get_val(tensor):
    """Evaluate `tensor` in `sess`, feeding the toy input and embedding values.

    Both placeholders are always fed with the module-level `input_val` and
    `embedding_space_val`, so any tensor in the graph can be inspected.
    """
    feeds = {
        images: input_val,
        embedding_space: embedding_space_val,
    }
    return sess.run(tensor, feed_dict=feeds)

with sess.as_default():
    divider = "----------------"

    # Raw values that are fed into the placeholders.
    for caption, fed_value in (("input", input_val), ("embeddings", embedding_space_val)):
        print(caption)
        print(fed_value)

    # Numbered walkthrough steps; each lists the (caption, tensor) pairs that are
    # evaluated and printed, in order.
    walkthrough = (
        ("1", (("encoder tiled", encoder_tiled),)),
        ("2", (("embedding space batch", embedding_space_batch),)),
        ("3", (("distances", l2_distances), ("argmin indices", e_index))),
        ("4", (("code; gather bs", code),)),
    )
    for step_no, step_outputs in walkthrough:
        print(divider)
        print(step_no)
        for caption, shown_tensor in step_outputs:
            print(caption)
            print(get_val(shown_tensor))



input
[[1, 2, 3], [4, 5, 6]]
embeddings
[[0, 2, 4], [1, 3, 5]]
----------------
1
encoder tiled
[[[1. 2. 3.]
  [1. 2. 3.]]

 [[4. 5. 6.]
  [4. 5. 6.]]]
----------------
2
embedding space batch
[[[0. 2. 4.]
  [1. 3. 5.]]

 [[0. 2. 4.]
  [1. 3. 5.]]]
----------------
3
distances
[[ 2.  5.]
 [29. 14.]]
argmin indices
[0 1]
----------------
4
code; gather bs
[[0. 2. 4.]
 [1. 3. 5.]]
