In [None]:
from tensorflow.keras.datasets import mnist

In [None]:
# Fetch the MNIST train/test splits (images plus labels).
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Cast the images to float32 first, then divide by 255 so the division is
# carried out in float32.
x_train, x_test = (
    pixels.astype('float32') / 255
    for pixels in (x_train, x_test)
)

In [None]:
# The image arrays are indexed (sample, row, column): axis 1 counts rows
# (height) and axis 2 counts columns (width). Reading them the other way
# round is invisible for square images but would make the reshape below
# reinterpret the pixel buffer with transposed dimensions for non-square ones.
image_height = x_train.shape[1]
image_width = x_train.shape[2]
num_channels = 1

# Append an explicit trailing channel axis: (samples, height, width, channels).
x_train = x_train.reshape(x_train.shape[0], image_height, image_width, num_channels)
x_test = x_test.reshape(x_test.shape[0], image_height, image_width, num_channels)

# Per-sample shape handed to the encoder's Input layer.
input_shape = (image_height, image_width, num_channels)

In [None]:
import tensorflow as tf

In [None]:
latent_dim = 8

input_img = tf.keras.layers.Input(shape=input_shape, name="encoder_input")

# Four 3x3 ReLU convolutions with "same" padding, given as (filters, strides).
# Only the second convolution uses stride 2; the others use the default
# stride of 1.
x = input_img
for conv_filters, conv_strides in ((32, (1, 1)), (64, (2, 2)), (64, (1, 1)), (64, (1, 1))):
    x = tf.keras.layers.Conv2D(
        filters=conv_filters,
        kernel_size=3,
        strides=conv_strides,
        padding="same",
        activation="relu",
    )(x)

# Record the feature-map shape before flattening; the decoder uses it to
# size its first Dense layer and its Reshape layer.
conv_shape = tf.keras.backend.int_shape(x)

x = tf.keras.layers.Flatten()(x)
x = tf.keras.layers.Dense(
    32,
    activation="relu",
    kernel_initializer="he_normal",
)(x)

In [None]:
# Two parallel linear Dense heads on the shared encoder features, each with
# latent_dim units. The sampling function and the KL term both use
# exp(z_log_sigma) as a variance, i.e. this head is treated as a log-variance.
z_mu, z_log_sigma = (
    tf.keras.layers.Dense(latent_dim, name=head_name)(x)
    for head_name in ("z_mu", "z_log_sigma")
)

In [None]:
def sampling(args):
    """Draw a latent sample with the reparameterization trick.

    Args:
        args: A pair ``(z_mu, z_log_sigma)`` of tensors with identical shape.
            ``z_log_sigma`` is interpreted as a log-variance: the standard
            deviation used here is ``sqrt(exp(z_log_sigma))``.

    Returns:
        ``z_mu + std * epsilon`` where ``epsilon`` is standard-normal noise
        with the same shape as ``z_mu``.
    """
    z_mu, z_log_sigma = args
    # Take the noise shape from z_mu itself so the function does not depend
    # on a notebook-level global for the latent size.
    epsilon = tf.keras.backend.random_normal(shape=tf.keras.backend.shape(z_mu), mean=0., stddev=1.)
    # exp(0.5 * v) equals exp(v) ** 0.5 mathematically, but avoids forming
    # exp(v) first, which overflows for much smaller values of v.
    return z_mu + tf.keras.backend.exp(0.5 * z_log_sigma) * epsilon

In [None]:
# Wrap the sampling function in a Lambda layer so the stochastic latent code
# is part of the Keras graph.
z = tf.keras.layers.Lambda(
    sampling,
    output_shape=(latent_dim,),
    name="z",
)([z_mu, z_log_sigma])

In [None]:
# Encoder model: maps an input image to the sampled latent code z.
encoder = tf.keras.models.Model(
    inputs=input_img,
    outputs=z,
    name="encoder",
)
encoder.summary()

In [None]:
decoder_input = tf.keras.layers.Input(shape=(latent_dim,), name="decoder_input")

# Expand the latent vector back to the encoder's pre-flatten feature map
# (conv_shape without its leading batch entry).
feat_h, feat_w, feat_c = conv_shape[1:]
x = tf.keras.layers.Dense(
    feat_h * feat_w * feat_c,
    activation="relu",
    kernel_initializer="he_normal",
)(decoder_input)
x = tf.keras.layers.Reshape((feat_h, feat_w, feat_c))(x)

# The stride-2 transposed convolution mirrors the encoder's stride-2
# convolution.
x = tf.keras.layers.Conv2DTranspose(
    filters=64,
    kernel_size=3,
    strides=(2, 2),
    padding="same",
    activation="relu",
)(x)
x = tf.keras.layers.Conv2DTranspose(
    filters=32,
    kernel_size=3,
    padding="same",
    activation="relu",
)(x)

# Single-channel output with a sigmoid activation.
x = tf.keras.layers.Conv2D(
    filters=1,
    kernel_size=3,
    padding='same',
    activation='sigmoid',
    name='decoder_output',
)(x)

decoder = tf.keras.models.Model(inputs=decoder_input, outputs=x, name='decoder')
decoder.summary()

In [None]:
# Feed the sampled latent code z through the decoder model; the resulting
# tensor is the reconstruction that is later passed to the loss layer together
# with the original input.
decoder_output = decoder(z)

In [None]:
class VAELossLayer(tf.keras.layers.Layer):
    """Pass-through layer that registers the VAE training loss.

    The layer is called with ``[x, decoder_output, z_mu, z_log_sigma]`` and
    returns ``x`` unchanged. Its only effect is to attach the batch-mean VAE
    loss to the model through ``add_loss``.
    """

    def vae_loss(self, x, decoder_output, z_mu, z_log_sigma):
        """Compute the per-sample VAE loss (reconstruction + KL divergence).

        Args:
            x: Batch of input images; the first axis is the batch axis.
            decoder_output: Reconstructions with the same shape as ``x``.
            z_mu: Latent means, one row per sample.
            z_log_sigma: Latent log-variances, one row per sample
                (``exp(z_log_sigma)`` is used as the variance in the KL term).

        Returns:
            A 1-D tensor holding one loss value per sample.
        """
        # Flatten each image separately. Flattening the whole batch into one
        # vector would yield a single cross-entropy value shared by every
        # sample instead of one value per sample.
        batch_size = tf.shape(x)[0]
        x_flat = tf.reshape(x, (batch_size, -1))
        output_flat = tf.reshape(decoder_output, (batch_size, -1))

        # binary_crossentropy averages over the last axis. Multiplying by the
        # number of pixels turns that mean into a per-image sum, putting the
        # reconstruction term on the same scale as the KL term, which is
        # summed over latent dimensions. Without this factor the KL term
        # dominates the objective.
        mean_bce = tf.keras.metrics.binary_crossentropy(x_flat, output_flat)
        pixels_per_image = tf.cast(tf.shape(x_flat)[-1], mean_bce.dtype)
        reconstruction_loss = pixels_per_image * mean_bce

        # KL divergence between N(z_mu, exp(z_log_sigma)) and N(0, I), summed
        # over the latent dimensions.
        kl_loss = -0.5 * tf.keras.backend.sum(
            1 + z_log_sigma - tf.keras.backend.square(z_mu) - tf.keras.backend.exp(z_log_sigma),
            axis=-1,
        )

        return reconstruction_loss + kl_loss

    def call(self, inputs):
        """Register the batch-mean VAE loss and return the input images.

        Args:
            inputs: Sequence ``[x, decoder_output, z_mu, z_log_sigma]``.

        Returns:
            ``x``, unchanged.
        """
        x, decoder_output, z_mu, z_log_sigma = inputs
        per_sample_loss = self.vae_loss(x, decoder_output, z_mu, z_log_sigma)
        # Reduce to a scalar explicitly rather than relying on how Keras
        # aggregates a per-sample tensor passed to add_loss; the mean keeps
        # the loss magnitude independent of the batch size.
        self.add_loss(tf.keras.backend.mean(per_sample_loss))
        return x

# Route the input image, its reconstruction and both latent heads through the
# loss layer; the layer returns the input image and registers the loss.
loss_layer = VAELossLayer()
y = loss_layer([input_img, decoder_output, z_mu, z_log_sigma])

vae = tf.keras.models.Model(
    inputs=[input_img],
    outputs=[y],
    name='vae',
)

# No compiled loss: the objective is the one the loss layer adds via add_loss.
vae.compile(loss=None, optimizer='adam')
vae.summary()

In [None]:
# Train for 10 epochs with mini-batches of 32. No targets are passed because
# the model is compiled with loss=None and its loss is registered via add_loss.
vae.fit(
    x=x_train,
    y=None,
    batch_size=32,
    epochs=10,
)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Decode a single latent point whose coordinates are all 1. The vector is
# sized from latent_dim so this cell keeps working if the latent size changes.
sample_vector = np.ones((1, latent_dim), dtype="float32")

decoded_example = decoder.predict(sample_vector)

# The decoder output is laid out as (batch, height, width, channels); read the
# image size from it instead of hard-coding 28 x 28.
img_height, img_width = decoded_example.shape[1:3]

# One sample with one channel, so dropping those axes leaves a
# (height, width) image in the row-major order that imshow expects.
decoded_example_reshaped = decoded_example.reshape(img_height, img_width)

# figsize is in inches: at dpi=100 the figure is img_width x img_height pixels.
fig = plt.figure(figsize=(img_width/100, img_height/100), dpi=100)
ax = fig.add_subplot(111)
ax.imshow(decoded_example_reshaped, cmap='gray')
ax.axis('off')
plt.show()