* [tutorial#1](https://blog.tensorflow.org/2019/03/variational-autoencoders-with.html)
    * An implementation of VAEs using TensorFlow Probability layers
* [tutorial#2](https://www.tensorflow.org/tutorials/generative/cvae)
    * An implementation of a convolutional VAE from the official TensorFlow tutorials

In [None]:
import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras as tfk

In [None]:
# Set to True to run the small smoke-test cells guarded by `if debug_mode:`.
debug_mode = False

# Function definitions

In [None]:
def createPriorDistribution(encoded_size, isBound=False):
    """Build a factorised standard-normal prior over the latent space.

    Args:
        encoded_size: Number of latent dimensions.
        isBound: If True, each N(0, 1) component is pushed through a Tanh
            bijector, so every latent coordinate lies in (-1, 1).

    Returns:
        A ``tfd.Independent`` distribution that treats the ``encoded_size``
        components as a single event.
    """
    dists = tfp.distributions

    base = dists.Normal(loc=tf.zeros(encoded_size), scale=1)
    if isBound:
        # Squash each unbounded component into (-1, 1).
        base = dists.TransformedDistribution(
            distribution=base, bijector=tfp.bijectors.Tanh())

    # Reinterpret the per-dimension batch axis as the event axis.
    return dists.Independent(base, reinterpreted_batch_ndims=1)

In [None]:
if debug_mode:
    # Smoke-test both prior variants; the unbounded one ends up bound to
    # `prior` for the debug cells that follow.
    prior = createPriorDistribution(encoded_size=16, isBound=True)
    prior = createPriorDistribution(encoded_size=16, isBound=False)

In [None]:
def createEncoder(input_shape, prior):
    """Build a probabilistic encoder producing a full-covariance Gaussian.

    A single linear Dense layer emits the parameters of a
    ``MultivariateNormalTriL`` over the latent space. A
    ``KLDivergenceRegularizer`` towards ``prior`` is attached as the
    activity regulariser of that distribution layer.

    NOTE(review): with a Tanh-bounded prior (``isBound=True``), encoder
    samples outside (-1, 1) may give non-finite prior log-probabilities in
    the KL estimate -- confirm before using that configuration.

    Args:
        input_shape: Shape of one input sample, e.g. ``(2,)``.
        prior: Latent prior; ``prior.event_shape[0]`` sets the latent size.

    Returns:
        A ``tfk.Sequential`` model whose output is a distribution layer.
    """
    layers = tfk.layers
    prob_layers = tfp.layers

    latent_dim = prior.event_shape[0]
    n_params = prob_layers.MultivariateNormalTriL.params_size(latent_dim)
    kl_penalty = prob_layers.KLDivergenceRegularizer(prior, weight=1.0)

    return tfk.Sequential([
        layers.InputLayer(input_shape=input_shape),
        layers.Dense(n_params, activation=None),
        prob_layers.MultivariateNormalTriL(
            latent_dim, activity_regularizer=kl_penalty),
    ])

In [None]:
if debug_mode:
    # 3-D toy inputs; `input_shape` is reused by the later debug cells.
    input_shape = (3,)
    encoder = createEncoder(input_shape, prior)

In [None]:
def createDecoder(input_shape, prior):
    """Build a probabilistic decoder mapping latent codes to a Gaussian over inputs.

    A single linear Dense layer emits the parameters of a full-covariance
    ``MultivariateNormalTriL`` over the reconstructed input.

    Args:
        input_shape: Shape of one input sample. Only ``input_shape[0]`` is
            used, so inputs are treated as flat vectors of length ``n``.
        prior: Latent prior. Only ``prior.event_shape[0]`` (the latent size)
            is used.

    Returns:
        A ``tfk.Sequential`` model mapping ``(batch, encoded_size)`` latent
        codes to a distribution over ``(batch, n)`` reconstructions.
    """
    tfpl = tfp.layers
    tfkl = tfk.layers

    encoded_size = prior.event_shape[0]
    n = input_shape[0]

    decoder = tfk.Sequential([
        # Input: (batch, encoded_size) latent codes.
        tfkl.InputLayer(input_shape=[encoded_size]),
        # MultivariateNormalTriL needs n means plus n * (n + 1) / 2 entries
        # of the lower-triangular scale matrix.
        tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(n), activation=None),
        # Output: distribution with loc of shape (n,) and scale_tril (n, n).
        tfpl.MultivariateNormalTriL(n),
    ])
    return decoder

In [None]:
if debug_mode:
    # Reuses `input_shape` and `prior` from the earlier debug cells.
    decoder = createDecoder(input_shape, prior)

In [None]:
def createTrainDataset(input_shape, sample_size=2**7, batch_size=2**5):
    """Build a shuffled, batched training set drawn from two Gaussian blobs.

    ``sample_size`` points are drawn around -1 and another ``sample_size``
    around +1 in every coordinate (unit-variance normal noise). Each element
    is an ``(x, x)`` pair, so the model's target is its own input.

    Args:
        input_shape: Shape of one sample, a 1-tuple ``(d,)``.
        sample_size: Points drawn per blob; the dataset holds
            ``2 * sample_size`` points in total.
        batch_size: Batch size of the returned dataset.

    Returns:
        A ``tf.data.Dataset`` yielding ``(x, x)`` with ``x`` of shape
        ``(batch_size, d)``; the final batch may be smaller.

    Raises:
        ValueError: If ``input_shape`` does not have exactly one entry.
    """
    if len(input_shape) != 1:
        raise ValueError(f"input_shape must be a 1-tuple (d,), got {input_shape!r}")
    dim = input_shape[0]

    # Scalar offsets broadcast over every coordinate, so any d works; for
    # d == 2 this equals adding [-1, -1] / [1, 1].
    x1 = tf.random.normal(shape=(sample_size, dim)) - 1.0
    x2 = tf.random.normal(shape=(sample_size, dim)) + 1.0
    x = tf.concat((x1, x2), axis=0)

    # The data is ordered blob-1 then blob-2, so the shuffle buffer must hold
    # the whole dataset; a buffer of only `sample_size` would bias the early
    # batches of every epoch toward the first blob.
    n_total = 2 * sample_size
    return (tf.data.Dataset.from_tensor_slices((x, x))
            .shuffle(n_total)
            .batch(batch_size))

In [None]:
def createEvalDataset(input_shape, sample_size=2**7):
    """Draw two held-out Gaussian clusters matching the training blobs.

    Args:
        input_shape: Shape of one sample, a 1-tuple ``(d,)``.
        sample_size: Number of points per cluster.

    Returns:
        ``(x1, x2)``: tensors of shape ``(sample_size, d)``, centred at -1
        and +1 in every coordinate respectively (unit-variance noise).

    Raises:
        ValueError: If ``input_shape`` does not have exactly one entry.
    """
    if len(input_shape) != 1:
        raise ValueError(f"input_shape must be a 1-tuple (d,), got {input_shape!r}")
    dim = input_shape[0]

    # Scalar offsets broadcast over all coordinates (same values as
    # subtracting / adding a tensor of ones).
    x1 = tf.random.normal(shape=(sample_size, dim)) - 1.0
    x2 = tf.random.normal(shape=(sample_size, dim)) + 1.0
    return x1, x2

In [None]:
def createTrainDatasetForDebug(input_shape):
    """Build a small batched dataset of standard-normal ``(x, x)`` pairs.

    The data is plain N(0, I) noise, which makes it usable for
    smoke-testing the training loop with any ``input_shape[0]``.

    Returns:
        A ``tf.data.Dataset`` of 2**7 shuffled samples in batches of 2**5.
        Each element is ``(x, x)`` with ``x`` of shape
        ``(batch, input_shape[0])``.
    """
    n_samples, n_per_batch = 2**7, 2**5

    samples = tf.random.normal(shape=(n_samples, input_shape[0]))
    dataset = tf.data.Dataset.from_tensor_slices((samples, samples))
    # Buffer covers the whole dataset, giving a full shuffle.
    dataset = dataset.shuffle(n_samples)
    return dataset.batch(n_per_batch)

In [None]:
if debug_mode:
    # Smoke-test the case-study dataset builders with the 2-D shape they are
    # used with (results discarded).
    createTrainDataset((2,))
    createEvalDataset((2,))
    # Training data for the debug models, using the `input_shape` set in an
    # earlier debug cell.
    train_dataset = createTrainDatasetForDebug(input_shape)

In [None]:
def trainVAEs(encoder, decoder, train_dataset, epochs, learning_rate=1e-3):
    """Chain ``encoder`` and ``decoder`` into one model and fit it.

    The compiled loss is the negative log-likelihood of each target under
    the distribution returned by ``decoder``. If ``encoder`` was built by
    ``createEncoder``, its KL activity regulariser adds the KL term, so the
    total training loss is the negative ELBO.

    Args:
        encoder: Keras model whose first output is a latent distribution.
        decoder: Keras model mapping latent samples to a distribution.
        train_dataset: ``tf.data.Dataset`` of ``(input, target)`` pairs.
        epochs: Number of training epochs.
        learning_rate: Adam learning rate (default 1e-3, as before).

    Returns:
        The fitted end-to-end Keras model. Its ``history`` attribute holds
        the training history.
    """
    vae = tfk.Model(inputs=encoder.inputs,
                    outputs=decoder(encoder.outputs[0]))

    def negative_log_likelihood(x, rv_x):
        # rv_x is the distribution emitted by the decoder's final layer.
        return -rv_x.log_prob(x)

    vae.compile(optimizer=tf.optimizers.Adam(learning_rate=learning_rate),
                loss=negative_log_likelihood)

    vae.fit(train_dataset, epochs=epochs)
    return vae

In [None]:
if debug_mode:
    # Short 8-epoch smoke run of the full training loop on the debug setup.
    trainVAEs(encoder, decoder, train_dataset, epochs=8)

# Case study

In [None]:
import matplotlib.pylab as plt

In [None]:
input_shape = (2,)

# Two-blob 2-D data: a shuffled training set plus two held-out clusters.
train_dataset = createTrainDataset(input_shape)
x1, x2 = createEvalDataset(input_shape)

# A 2-D unbounded latent space, so the codes can be plotted directly.
# Construction order is kept as-is (decoder before encoder).
prior = createPriorDistribution(encoded_size=2, isBound=False)
decoder = createDecoder(input_shape, prior)
encoder = createEncoder(input_shape, prior)

trainVAEs(encoder, decoder, train_dataset, epochs=2**8)

In [None]:
# Encode each held-out cluster and draw one latent sample per point.
z1 = encoder(x1).sample(sample_shape=()).numpy() # (128, 2)
z2 = encoder(x2).sample(sample_shape=()).numpy() # (128, 2)

# Decode the sampled codes and draw one reconstruction per point.
x1hat = decoder(z1).sample().numpy()
x2hat = decoder(z2).sample().numpy()

# Figure 1: latent codes -- red for the x1 cluster, blue for the x2 cluster.
plt.figure()
plt.plot(z1[:,0], z1[:,1], 'ro')
plt.plot(z2[:,0], z2[:,1], 'bo')

# Figure 2: original points (black + for x1, black x for x2) overlaid with
# their reconstructions (hollow red / blue circles).
plt.figure()
plt.plot(x1[:,0], x1[:,1], 'k+')
plt.plot(x2[:,0], x2[:,1], 'kx')
plt.plot(x1hat[:,0], x1hat[:,1], 'ro', markerfacecolor="none")
plt.plot(x2hat[:,0], x2hat[:,1], 'bo', markerfacecolor="none")