In this notebook, we implement a Variational Autoencoder (VAE) using Keras and TensorFlow. The VAE is trained on the Fashion MNIST dataset, which consists of 28×28 grayscale images of clothing items in ten classes. We then visualize the latent space and generate new samples from the learned distribution.

Standard autoencoders map each input to a single point in latent space, which limits their generative capabilities.

## Variational Part: Probabilistic Latent Space
Instead of mapping each input to a fixed point, a VAE:
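	•	Maps inputs to a distribution (commonly Gaussian) in latent space. For each input, the encoder predicts a mean ($\mu$) and a spread for this distribution. The code in this notebook parameterizes the spread as the log-variance ($\log \sigma^2$) rather than the standard deviation ($\sigma$) directly, which keeps the output unconstrained.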
	•	Samples a random point from this distribution to pass to the decoder.

This approach makes the latent space continuous and smooth, so new yet realistic samples can be generated by decoding points drawn from its high-probability region.
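
The sampling step described above can be sketched in plain NumPy. This is a minimal illustration, not the notebook's Keras code: the function name `sample_latent` and the example values for the mean and standard deviation are assumptions made up for the demonstration.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_latent(mu, log_var, rng):
    """Draw z = mu + sigma * eps, with eps ~ N(0, I) and sigma = exp(0.5 * log_var)."""
    eps = rng.standard_normal(mu.shape)
    return mu + np.exp(0.5 * log_var) * eps

# Hypothetical encoder output for a single input (illustrative values only)
mu = np.array([1.0, -2.0])
sigma = np.array([0.5, 0.1])
log_var = np.log(sigma ** 2)

# Repeated sampling should recover the predicted mean and standard deviation
samples = np.stack([sample_latent(mu, log_var, rng) for _ in range(20000)])
print("empirical mean:", samples.mean(axis=0))
print("empirical std: ", samples.std(axis=0))
```

Because the randomness enters only through `eps`, the sample is a differentiable function of `mu` and `log_var`, which is what allows gradients to flow back into the encoder during training.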

## Loss Function
The training objective combines two goals:
	•	Reconstruction Loss: Ensures the decoder can faithfully reconstruct the original input from samples in the latent space.
	•	KL Divergence Loss: Encourages the learned latent distributions to be close to a standard normal distribution, ensuring a well-behaved latent space and enabling generative capabilities.
Together, these drive the VAE to reconstruct the data well while keeping the latent space regular enough to sample from.
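
As a rough numerical illustration of the two terms, the sketch below computes them in NumPy. For a diagonal Gaussian with mean $\mu$ and variance $\sigma^2$, the KL divergence to a standard normal has the closed form $\tfrac{1}{2}\sum_j (\mu_j^2 + \sigma_j^2 - \log \sigma_j^2 - 1)$. The function names and the tiny 2×2 "image" are assumptions chosen for the example, and the reconstruction term assumes a Bernoulli (binary cross-entropy) likelihood.

```python
import numpy as np

def kl_to_standard_normal(mu, log_var):
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over latent dims."""
    return 0.5 * np.sum(mu ** 2 + np.exp(log_var) - log_var - 1.0, axis=-1)

def bernoulli_reconstruction(x, x_hat, eps=1e-7):
    """Binary cross-entropy summed over all pixels of each image."""
    x_hat = np.clip(x_hat, eps, 1.0 - eps)
    per_pixel = -(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))
    return per_pixel.reshape(len(x), -1).sum(axis=1)

# Two hypothetical latent distributions: one equal to the prior, one shifted
mu = np.array([[0.0, 0.0], [1.0, -1.0]])
log_var = np.zeros((2, 2))
kl = kl_to_standard_normal(mu, log_var)
print("KL terms:", kl)  # zero when the posterior equals the prior

# A toy 2x2 binary image, with an exact and an uninformative reconstruction
x = np.array([[[0.0, 1.0], [1.0, 0.0]]])
recon_exact = bernoulli_reconstruction(x, x)
recon_flat = bernoulli_reconstruction(x, np.full_like(x, 0.5))
print("reconstruction (exact):", recon_exact)
print("reconstruction (all 0.5):", recon_flat)
```

The total objective is the sum of the two terms, so a model can lower the KL term by moving every posterior toward the prior, but only at the cost of a worse reconstruction term.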

## What Problem Does VAE Solve?

Traditional autoencoders can struggle to generate new, realistic samples because their latent space may develop “holes” or regions with no valid representations. VAEs, by enforcing a continuous probabilistic latent space, enable:
	•	Smooth, meaningful latent representations: Points in the well-covered region of the latent space tend to decode to plausible samples, and nearby points decode to similar outputs.
	•	Data Generation: New data can be generated by sampling randomly from the latent space.
	•	Interpolation: It’s possible to smoothly interpolate between data points via the latent space (e.g., morphing one image into another).
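
Interpolation can be sketched as a straight line between two latent codes. The helper name `interpolate` and the two latent vectors below are hypothetical; in practice the endpoints would come from encoding two real images.

```python
import numpy as np

def interpolate(z_start, z_end, steps):
    """Return `steps` latent vectors evenly spaced on the line from z_start to z_end."""
    t = np.linspace(0.0, 1.0, steps)[:, None]
    return (1.0 - t) * z_start + t * z_end

# Hypothetical 2D latent codes for two inputs (illustrative values only)
z_a = np.array([-1.5, 0.5])
z_b = np.array([2.0, -1.0])

path = interpolate(z_a, z_b, steps=8)
print(path.shape)
print(path)
```

Passing each row of `path` through a trained decoder would then produce the intermediate frames of the morph.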

This makes VAEs very useful for:
	•	Generating new images, sounds, or text similar to the training data.
	•	Anomaly detection (outliers fall outside the learned data distribution).
	•	Data compression and denoising.

In [None]:
import numpy as np
import tensorflow as tf
from keras.models import Model
from keras.datasets import fashion_mnist
from keras.layers import Input, Dense, Flatten, Reshape, Conv2D, Conv2DTranspose, Normalization

import matplotlib.pyplot as plt

## Load the Fashion MNIST dataset

In [None]:
# Load Fashion MNIST, zero-pad the 28x28 images to 32x32, and scale pixels to [0, 1]
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = np.pad(x_train, ((0, 0), (2, 2), (2, 2)), mode='constant')
x_test = np.pad(x_test, ((0, 0), (2, 2), (2, 2)), mode='constant')
x_train = x_train.astype("float32") / 255.
x_test = x_test.astype("float32") / 255.
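
To see what the padding and scaling do without downloading the dataset, the sketch below applies the same `np.pad` call to a small stand-in array. The array `batch` is synthetic and only mimics the shape and dtype of the real images.

```python
import numpy as np

# Synthetic stand-in for two 28x28 uint8 images (all pixels at full intensity)
batch = np.full((2, 28, 28), 255, dtype=np.uint8)

# Pad 2 pixels of zeros on every side of the height and width axes; leave the batch axis alone
padded = np.pad(batch, ((0, 0), (2, 2), (2, 2)), mode='constant')
scaled = padded.astype("float32") / 255.

print(padded.shape)                 # spatial size grows from 28 to 32
print(scaled.min(), scaled.max())   # border is 0.0, original pixels are 1.0
```

A side length of 32 is convenient here because it can be halved three times by stride-2 convolutions without any rounding.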

In [None]:
labels = [
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle boot'
]

In [None]:
# VAE hyperparameters
input_shape = (32, 32, 1)  # Fashion MNIST padded to 32x32
latent_dim = 2  # 2D latent space for easy visualization
intermediate_dim = 512
batch_size = 128
epochs = 30

In [None]:
# Reparameterization trick
def sampling(args):
    """Reparameterization trick by sampling from an isotropic unit Gaussian.
    
    Arguments:
        args (tensor): mean and log of variance of Q(z|X)
        
    Returns:
        z (tensor): sampled latent vector
    """
    z_mean, z_log_var = args
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    # Sample epsilon from a standard normal distribution
    # (tf.random.normal is used because it is available across TensorFlow 2.x releases)
    epsilon = tf.random.normal(shape=(batch, dim))
    # Return sampled latent vector: z = mean + std * epsilon
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon

In [None]:
# Build the encoder
inputs = Input(shape=input_shape, name='encoder_input')

# Convolutional layers for feature extraction
x = Conv2D(32, 3, activation='relu', strides=2, padding='same')(inputs)
x = Conv2D(64, 3, activation='relu', strides=2, padding='same')(x)
x = Conv2D(128, 3, activation='relu', strides=2, padding='same')(x)
x = Flatten()(x)
x = Dense(intermediate_dim, activation='relu')(x)

# Generate latent distribution parameters
z_mean = Dense(latent_dim, name='z_mean')(x)
z_log_var = Dense(latent_dim, name='z_log_var')(x)

# Use reparameterization trick to sample from latent distribution
z = tf.keras.layers.Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

# Instantiate encoder model
encoder = Model(inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()

In [None]:
# Build the decoder
latent_inputs = Input(shape=(latent_dim,), name='z_sampling')

# Dense layers to expand from latent space
x = Dense(intermediate_dim, activation='relu')(latent_inputs)
x = Dense(4 * 4 * 128, activation='relu')(x)
x = Reshape((4, 4, 128))(x)

# Transpose convolutional layers for upsampling
x = Conv2DTranspose(128, 3, activation='relu', strides=2, padding='same')(x)
x = Conv2DTranspose(64, 3, activation='relu', strides=2, padding='same')(x)
x = Conv2DTranspose(32, 3, activation='relu', strides=2, padding='same')(x)

# Output layer
outputs = Conv2DTranspose(1, 3, activation='sigmoid', padding='same', name='decoder_output')(x)

# Instantiate decoder model
decoder = Model(latent_inputs, outputs, name='decoder')
decoder.summary()
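
The spatial sizes in the encoder and decoder can be checked by hand. The sketch below assumes the usual Keras behavior for `padding='same'`: a strided convolution produces `ceil(size / stride)` outputs, and a strided transposed convolution produces `size * stride`. The helper names are made up for this example.

```python
import math

def conv_same_out(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    """Spatial output size of a transposed convolution with padding='same'."""
    return size * stride

# Encoder: three stride-2 convolutions starting from a 32x32 input
encoder_sizes = [32]
for _ in range(3):
    encoder_sizes.append(conv_same_out(encoder_sizes[-1], 2))

# Flattened feature size after the last convolution (128 filters)
flat_features = encoder_sizes[-1] ** 2 * 128

# Decoder: three stride-2 transposed convolutions starting from the 4x4 reshape
decoder_sizes = [4]
for _ in range(3):
    decoder_sizes.append(conv_transpose_same_out(decoder_sizes[-1], 2))

print("encoder:", encoder_sizes)
print("flattened features:", flat_features)
print("decoder:", decoder_sizes)
```

This is why the decoder's first dense layer has `4 * 4 * 128` units: it mirrors the shape of the encoder's last feature map, and the final stride-1 transposed convolution keeps the 32×32 size while reducing the channels to one.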

In [None]:
# Custom layer that passes its inputs through and registers the KL divergence as a model loss
class KLDivergenceLayer(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Closed-form KL(N(mean, var) || N(0, I)): sum over latent dimensions...
        kl_per_sample = -0.5 * tf.reduce_sum(
            1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1
        )
        # ...then average over the batch
        self.add_loss(tf.reduce_mean(kl_per_sample))
        return z_mean, z_log_var

# Rebuild the VAE with proper loss handling
# Add KL loss layer
z_mean_kl, z_log_var_kl = KLDivergenceLayer()([z_mean, z_log_var])

# Use the outputs with KL loss added
z_with_kl = tf.keras.layers.Lambda(sampling, output_shape=(latent_dim,), name='z_with_kl')([z_mean_kl, z_log_var_kl])

# Rebuild encoder with KL loss
encoder = Model(inputs, [z_mean_kl, z_log_var_kl, z_with_kl], name='encoder')

# Instantiate VAE model with KL loss properly integrated
outputs = decoder(encoder(inputs)[2])  # Use the sampled z from encoder
vae = Model(inputs, outputs, name='vae')  # convolutional VAE, so no 'mlp' in the name

print("VAE Model Architecture:")
vae.summary()

In [None]:
# Define the reconstruction part of the VAE loss
def vae_loss(inputs, outputs):
    """Reconstruction loss only; the KL term is added separately by KLDivergenceLayer."""
    # Binary crossentropy, averaged over the channel axis -> shape (batch, 32, 32)
    reconstruction_loss = tf.keras.losses.binary_crossentropy(inputs, outputs)
    # Scale by the pixel count so the final mean equals a per-image sum over pixels
    reconstruction_loss *= input_shape[0] * input_shape[1]
    
    return tf.reduce_mean(reconstruction_loss)

# We'll add the KL loss as a separate loss to the model
# This is a better approach for VAEs in Keras
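
The effect of the scaling in `vae_loss` can be checked with a NumPy stand-in. This sketch assumes, as Keras does for `binary_crossentropy`, that the per-pixel loss is first averaged over the last (channel) axis. The random arrays are synthetic and exist only for the check.

```python
import numpy as np

rng = np.random.default_rng(0)
height, width = 32, 32

# Synthetic binary targets and predictions in (0, 1), shaped like a batch of 4 images
x = rng.integers(0, 2, size=(4, height, width, 1)).astype("float64")
x_hat = rng.uniform(0.05, 0.95, size=x.shape)

per_pixel = -(x * np.log(x_hat) + (1.0 - x) * np.log(1.0 - x_hat))

# Mimic vae_loss: mean over channels, scale by pixel count, mean over everything else
scaled_mean = (per_pixel.mean(axis=-1) * height * width).mean()

# Direct alternative: sum over all pixels of each image, then mean over the batch
summed = per_pixel.sum(axis=(1, 2, 3)).mean()

print(scaled_mean, summed)
```

The two numbers agree, so the scaled loss behaves like a per-image sum of pixel errors. That matters for the balance with the KL term: a per-pixel mean would be roughly a thousand times smaller and the KL penalty would dominate.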

In [None]:
# Compile the VAE model
# Use the pixel-scaled reconstruction loss defined above (vae_loss);
# the KL term is added automatically by the KLDivergenceLayer
vae.compile(optimizer='adam', loss=vae_loss, metrics=['mae'])

print("VAE model compiled successfully!")

## Train the VAE

Now we'll train the Variational Autoencoder on the Fashion MNIST dataset. The training process will optimize both the reconstruction quality and the regularization of the latent space.

In [None]:
# Prepare data for training (add channel dimension)
x_train = x_train.reshape(x_train.shape[0], 32, 32, 1)
x_test = x_test.reshape(x_test.shape[0], 32, 32, 1)

print(f"Training data shape: {x_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Training data range: [{x_train.min():.3f}, {x_train.max():.3f}]")

In [None]:
# Define callbacks for better training
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', 
    patience=5, 
    restore_best_weights=True
)

reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', 
    factor=0.5, 
    patience=3, 
    min_lr=1e-7
)

# Train the VAE
print("Starting VAE training...")
history = vae.fit(
    x_train, x_train,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=(x_test, x_test),
    callbacks=[early_stopping, reduce_lr],
    verbose=1
)
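
To build intuition for the `patience=5` setting, the sketch below simulates a simplified version of the early-stopping rule on a made-up validation curve. It is an approximation of the idea rather than the Keras implementation: it ignores options such as `min_delta`, and both the function name and the loss values are hypothetical.

```python
def early_stop_epoch(val_losses, patience):
    """Return (stop_epoch, best_epoch), 0-indexed; stop_epoch is None if training never stops early."""
    best = float("inf")
    best_epoch = None
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best value: remember it and reset the patience counter
            best, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch

# Hypothetical validation losses: improvement stalls after the third epoch
val_losses = [300.0, 280.0, 275.0, 276.0, 277.0, 275.5, 276.2, 278.0, 279.0]

stop_epoch, best_epoch = early_stop_epoch(val_losses, patience=5)
print("stopped at epoch index:", stop_epoch)
print("best epoch index:", best_epoch)
```

With `restore_best_weights=True`, the model would end up with the weights from the best epoch rather than from the epoch at which training stopped.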

In [None]:
# Plot training history
plt.figure(figsize=(12, 4))

plt.subplot(1, 2, 1)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('VAE Training Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(history.history['mae'], label='Training MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.title('VAE Training MAE')
plt.xlabel('Epoch')
plt.ylabel('Mean Absolute Error')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## Visualize the Latent Space

Now let's explore the 2D latent space learned by our VAE. We'll encode the test images and visualize how different fashion items are distributed in the latent space.