In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras import backend as K
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt

In [None]:
# Load and prepare the MNIST dataset

In [None]:
(x_train, _), (x_test, _) = mnist.load_data()  # labels are not used by the VAE
# Convert to float32 scaled by 1/255 (MNIST pixels span 0-255), then append a
# trailing channel axis so each split has shape (N, 28, 28, 1), matching the
# encoder's Input shape.
x_train = np.expand_dims(x_train.astype('float32') / 255., axis=-1)
x_test = np.expand_dims(x_test.astype('float32') / 255., axis=-1)

In [None]:
# Encoder Model

In [None]:
def build_encoder(latent_dim=2):
    """Build the MLP encoder that maps a 28x28x1 image to latent Gaussian parameters.

    Args:
        latent_dim: Dimensionality of the latent space.

    Returns:
        A Keras ``Model`` with outputs ``[z_mean, z_log_var]``, each of shape
        ``(batch, latent_dim)``.
    """
    image = layers.Input(shape=(28, 28, 1))
    hidden = layers.Flatten()(image)
    for units in (256, 128):
        hidden = layers.Dense(units, activation='relu')(hidden)
    # Two linear heads: the mean and the log-variance of q(z | x).
    z_mean = layers.Dense(latent_dim)(hidden)
    z_log_var = layers.Dense(latent_dim)(hidden)
    return models.Model(image, [z_mean, z_log_var])

In [None]:
# Reparameterization trick: draw z = mu + sigma * eps with eps ~ N(0, I), so gradients can flow back through mu and log_var

In [None]:
def sampling(args):
    """Draw z ~ N(z_mean, exp(z_log_var)) using the reparameterization trick.

    Args:
        args: A pair ``(z_mean, z_log_var)`` of tensors with identical shape.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon``, where ``epsilon`` is
        standard-normal noise with the same shape as ``z_mean``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from the *dynamic* shape of z_mean. Unlike
    # K.int_shape(...)[1], this does not need a statically known latent
    # dimension (int_shape yields None for unknown dims), and it works for
    # any batch size and any tensor rank.
    epsilon = K.random_normal(shape=K.shape(z_mean))
    # exp(0.5 * log_var) is the standard deviation.
    return z_mean + K.exp(0.5 * z_log_var) * epsilon

In [None]:
# Decoder Model
def build_decoder(latent_dim=2):
    """Build the MLP decoder that maps a latent vector back to a 28x28x1 image.

    Args:
        latent_dim: Dimensionality of the latent space.

    Returns:
        A Keras ``Model`` from ``(batch, latent_dim)`` to ``(batch, 28, 28, 1)``.
    """
    latent = layers.Input(shape=(latent_dim,))
    hidden = latent
    for units in (128, 256):
        hidden = layers.Dense(units, activation='relu')(hidden)
    # Sigmoid keeps every pixel in (0, 1), matching the binary cross-entropy
    # reconstruction term used by the VAE loss.
    pixels = layers.Dense(28 * 28, activation='sigmoid')(hidden)
    image = layers.Reshape((28, 28, 1))(pixels)
    return models.Model(latent, image)

In [None]:
# VAE Loss Layer
class VAELossLayer(layers.Layer):
    """Pass-through layer that attaches the VAE objective with ``add_loss``.

    The decoded image is returned unchanged. The layer exists only to register
    reconstruction + KL as a model loss, so ``compile`` needs no explicit loss.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        """Register the batch-mean VAE loss and return ``x_decoded`` unchanged.

        Args:
            inputs: Sequence ``[x, x_decoded, z_mean, z_log_var]``.

        Returns:
            ``x_decoded``.
        """
        x, x_decoded, z_mean, z_log_var = inputs

        # Per-sample reconstruction term. binary_crossentropy reduces the last
        # (channel) axis, and the sum then runs over the two spatial axes.
        per_pixel_bce = tf.keras.losses.binary_crossentropy(x, x_decoded)
        reconstruction = tf.reduce_sum(per_pixel_bce, axis=(1, 2))

        # Per-sample closed-form KL(q(z|x) || N(0, I)).
        kl_terms = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_divergence = -0.5 * tf.reduce_sum(kl_terms, axis=-1)

        self.add_loss(tf.reduce_mean(reconstruction + kl_divergence))
        return x_decoded

In [None]:
# VAE Model
def build_vae(encoder, decoder, latent_dim=2):
    """Wire encoder, sampler, decoder and loss layer into an end-to-end VAE.

    Args:
        encoder: Model mapping an image to ``[z_mean, z_log_var]``.
        decoder: Model mapping a latent vector to a reconstructed image.
        latent_dim: Latent dimensionality, used as the sampler's output shape.

    Returns:
        A Keras ``Model`` from an image to its reconstruction. The training
        loss is attached internally by ``VAELossLayer``.
    """
    inputs = layers.Input(shape=(28, 28, 1))
    z_mean, z_log_var = encoder(inputs)
    z = layers.Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])
    reconstruction = decoder(z)
    # VAELossLayer returns `reconstruction` as-is but registers the VAE loss.
    outputs = VAELossLayer()([inputs, reconstruction, z_mean, z_log_var])
    return models.Model(inputs, outputs)

In [None]:
# Build the Encoder, Decoder, and VAE models
SEED = 42
# Seed Python, NumPy and TensorFlow in one call. This makes weight init, the
# sampling noise, fit() shuffling and np.random-based generation repeatable
# under Restart & Run All (some GPU kernels may still be nondeterministic).
tf.keras.utils.set_random_seed(SEED)

latent_dim = 2
encoder = build_encoder(latent_dim)
decoder = build_decoder(latent_dim)
vae = build_vae(encoder, decoder, latent_dim)

In [None]:
# Compile the VAE model. `loss` is deliberately left as None: the objective
# is attached inside the model by VAELossLayer via add_loss.
vae.compile(optimizer='adam', loss=None)

In [None]:
# Train the VAE. No targets are passed because the loss is computed from the
# inputs inside the model (VAELossLayer). Keep the returned History so the
# training/validation loss curves can be inspected or plotted afterwards.
EPOCHS = 50
BATCH_SIZE = 128
history = vae.fit(
    x_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_data=(x_test, None),
)

In [None]:
# Generate New Data
def generate_new_data(decoder, latent_dim=2, n_samples=10, rng=None):
    """Decode random latent vectors drawn from the standard-normal prior.

    Args:
        decoder: Object with a ``predict`` method that maps an array of shape
            ``(n_samples, latent_dim)`` to images (e.g. the Keras decoder).
        latent_dim: Dimensionality of the latent space.
        n_samples: Number of images to generate. Defaults to 10, the value
            previously hard-coded.
        rng: Optional ``np.random.Generator`` (or ``RandomState``) for
            reproducible sampling. Defaults to NumPy's global random state,
            as before.

    Returns:
        Whatever ``decoder.predict`` returns for the sampled latent vectors.

    Raises:
        ValueError: If ``n_samples`` is less than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")
    # The np.random module and Generator/RandomState all expose .normal(size=...).
    sampler = np.random if rng is None else rng
    z_new = sampler.normal(size=(n_samples, latent_dim))
    return decoder.predict(z_new)