# Hierarchical Variational Autoencoders (VAEs): A Brief Explanation

## 1. Introduction to VAEs:
Variational Autoencoders (VAEs) are a class of generative models that learn to represent complex data distributions in a simpler, lower-dimensional latent space. They are composed of two main components:
- **Encoder**: Maps data to a probabilistic latent space, typically Gaussian.
- **Decoder**: Reconstructs the data from the latent variables.

The VAE training objective, the negative evidence lower bound (ELBO), has two terms:
- **Reconstruction Loss**: Measures how closely the generated data resembles the input.
- **KL Divergence Loss**: Regularizes the latent space by encouraging the encoder's distribution over latent variables to match a prior distribution (usually a standard Gaussian).
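
Written out, the two terms sum to the negative ELBO. The symbols below are assumptions, since the text above does not fix a notation: $q_\phi(z \mid x)$ is the encoder, $p_\theta(x \mid z)$ the decoder, and $p(z)$ the prior.

$$
\mathcal{L}(x) = \underbrace{-\mathbb{E}_{q_\phi(z \mid x)}\left[\log p_\theta(x \mid z)\right]}_{\text{reconstruction loss}} + \underbrace{D_{\mathrm{KL}}\left(q_\phi(z \mid x) \,\|\, p(z)\right)}_{\text{KL divergence loss}}
$$

When the encoder outputs a diagonal Gaussian with mean $\mu$ and variance $\sigma^2$ and the prior is $\mathcal{N}(0, I)$, the KL term has the closed form

$$
D_{\mathrm{KL}} = -\frac{1}{2} \sum_{j} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right).
$$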

## 2. What Are Hierarchical Priors?:
A standard VAE places a single, simple prior on the whole latent vector, usually an isotropic standard Gaussian, under which the latent dimensions are independent. However, real-world data often has dependencies at multiple scales. For instance, in images, high-level features (like object shapes) influence lower-level features (like textures or pixel details).

A **hierarchical prior** splits the latent variables into multiple levels and makes the prior on each lower level depend on the levels above it, instead of fixing one prior for all of them. This allows the model to capture multi-scale dependencies, because each level's latent variables are conditioned on the ones above it.
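
For two levels, this conditioning can be written as a factorized prior. The indexing is an assumption made for illustration: $z_1$ is the top (high-level) group and $z_2$ the bottom (low-level) group, and $\mu_p$ and $\sigma_p^2$ stand for learned functions of $z_1$.

$$
p(z_1, z_2) = p(z_1)\, p(z_2 \mid z_1), \qquad p(z_1) = \mathcal{N}(0, I), \qquad p(z_2 \mid z_1) = \mathcal{N}\left(\mu_p(z_1),\ \operatorname{diag}\left(\sigma_p^2(z_1)\right)\right)
$$

Only the top-level prior is fixed. The prior on $z_2$ shifts and rescales with the sampled value of $z_1$, which is how high-level features influence low-level ones.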

## 3. How Hierarchical VAEs Work:
In a hierarchical VAE, the latent space is organized into multiple layers, with each layer capturing different levels of abstraction in the data:
- The **top layer** captures high-level features (e.g., objects, shapes).
- The **bottom layer** captures low-level features (e.g., textures, pixel-level details).

For example, in a 2-layer hierarchical VAE:
- The first (top) layer might capture global object shapes (high-level features).
- The second (bottom) layer would capture local details (like textures or edges), conditioned on the first layer's latent variables.

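Each layer has its own prior: the top layer's prior is fixed, while each lower layer's prior is conditioned on the latent variables sampled above it. The latent variables are sampled using the reparameterization trick, so gradients can flow through the sampling step. Because the layers are connected probabilistically, the model can generate more structured and realistic data.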
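
The layer-by-layer sampling described above can be sketched in plain Python. This is a minimal illustration rather than a trained model: `conditional_prior` is a hypothetical stand-in for a learned network, and its coefficients are arbitrary.

```python
import math
import random


def reparameterize(mu, logvar, rng):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1), elementwise."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]


def conditional_prior(z1):
    """Hypothetical stand-in for a learned network that maps z1 to the
    mean and log-variance of p(z2 | z1)."""
    mu = [0.5 * v for v in z1]
    logvar = [-1.0 for _ in z1]  # arbitrary choice: narrower than the top level
    return mu, logvar


def sample_latents(dim, rng):
    """Ancestral sampling: top layer first, then the layer conditioned on it."""
    z1 = reparameterize([0.0] * dim, [0.0] * dim, rng)  # z1 ~ N(0, I)
    mu2, logvar2 = conditional_prior(z1)
    z2 = reparameterize(mu2, logvar2, rng)  # z2 ~ p(z2 | z1)
    return z1, z2


rng = random.Random(0)
z1, z2 = sample_latents(dim=2, rng=rng)
print("top layer:", z1)
print("bottom layer:", z2)
```

In a real model the decoder would then consume both `z1` and `z2`, and `conditional_prior` would be a small neural network trained jointly with the encoder and decoder.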

## 4. Training Hierarchical VAEs:
The training process involves the following steps:
1. **Encoding**: The encoder computes the parameters (mean and variance) for the latent variables at each layer and samples from these distributions.
2. **Decoding**: The decoder reconstructs the input data using all levels of latent variables.
3. **Loss Calculation**:
   - **Reconstruction Loss**: Measures the difference between the input data and the reconstructed output.
   - **KL Divergence Loss**: Regularizes the latent space with one KL term per layer, pulling each layer’s approximate posterior toward its prior. For lower layers, that prior is conditioned on the layers above.
4. **Backpropagation**: Gradients are computed for both losses, and the model parameters are updated using an optimization method (e.g., Adam optimizer).
5. **Repeat**: The process continues for several epochs, with the model learning better representations of data at multiple scales.
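
The loss in step 3 can be sketched for two layers as a reconstruction term plus one KL term per layer. The sketch below is framework-free and assumes diagonal Gaussian posteriors and priors; `two_level_loss` and its argument names are hypothetical, and every distribution is passed as a `(mu, logvar)` pair.

```python
import math


def gaussian_kl(mu_q, logvar_q, mu_p, logvar_p):
    """KL(q || p) between two diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        total += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return total


def two_level_loss(recon_loss, q1, q2, p2_given_z1):
    """Negative ELBO for two layers: reconstruction + KL(top) + KL(bottom)."""
    dim = len(q1[0])
    # Top layer: posterior q(z1 | x) against the fixed N(0, I) prior.
    kl_top = gaussian_kl(q1[0], q1[1], [0.0] * dim, [0.0] * dim)
    # Bottom layer: posterior q(z2 | x, z1) against the conditional prior p(z2 | z1).
    kl_bottom = gaussian_kl(q2[0], q2[1], p2_given_z1[0], p2_given_z1[1])
    return recon_loss + kl_top + kl_bottom


loss = two_level_loss(
    recon_loss=10.0,
    q1=([1.0], [0.0]),           # top posterior: mean 1, variance 1
    q2=([0.5], [0.0]),           # bottom posterior
    p2_given_z1=([0.5], [0.0]),  # conditional prior equal to the bottom posterior
)
print(loss)  # → 10.5
```

Here the bottom-layer KL is zero because the posterior equals its conditional prior, so the only penalty comes from the top layer's mean being shifted away from zero.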

## 5. Benefits of Hierarchical Priors:
- **Improved Representations**: Hierarchical priors enable the model to learn richer, more structured latent representations that capture multi-level dependencies in data.
- **Better Generalization**: By modeling high-level and low-level features at separate levels, hierarchical VAEs may generalize better on complex datasets.
- **More Expressive**: They can handle more complex and diverse data distributions (e.g., images, text, videos) by allowing each level of the latent space to represent different types of features.

## 6. Applications:
- **Image Generation**: Hierarchical VAEs can generate realistic images by capturing both the overall structure (e.g., object shapes) and detailed features (e.g., textures).
- **Text Generation**: In natural language, hierarchical VAEs can learn dependencies between high-level topics and low-level words, producing more coherent and contextually rich text.
- **Anomaly Detection**: By learning a multi-scale representation, hierarchical VAEs can better detect anomalies since outliers are less likely to fit the complex multi-level structure of normal data.
- **Semi-Supervised Learning**: Hierarchical VAEs help learn from both labeled and unlabeled data by leveraging the structured latent space, improving performance when labeled data is scarce.

## 7. Challenges:
- **Model Complexity**: Hierarchical VAEs are more complex than standard VAEs, requiring careful design and tuning of the latent layers and priors.
- **Optimization**: Training hierarchical VAEs can be more difficult due to the increased number of parameters and dependencies between latent variables.
- **Computational Cost**: With multiple layers and priors, these models may require more memory and longer training times.

## 8. Key Differences from Standard VAEs:
- **Prior Structure**: Standard VAEs use a simple Gaussian prior for each latent variable, while hierarchical VAEs define priors across multiple layers of latent variables, capturing complex dependencies.
- **Modeling Power**: Hierarchical VAEs can model more complex relationships in the data, whereas standard VAEs assume that the latent variables are independent under the prior.

## 9. Future Directions:
Research on hierarchical VAEs is still evolving. Potential improvements include:
- **Handling Temporal Data**: Extending hierarchical VAEs for sequential data (e.g., videos, time series).
- **Multimodal Learning**: Combining text, images, and audio for richer generative models.
- **Better Optimization**: Developing more efficient training methods for hierarchical structures.

### Conclusion:
Hierarchical VAEs offer a powerful extension to standard VAEs by introducing multiple levels of latent variables, each capturing different aspects of data. They provide a more flexible and expressive framework for generative modeling, particularly for complex datasets like images and text. By learning multi-scale representations, hierarchical VAEs can generate more realistic data, improve generalization, and handle a wide variety of tasks, from image synthesis to anomaly detection.


In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt

# Encoder
class Encoder(tf.keras.Model):
    def __init__(self, latent_dim=2):
        super(Encoder, self).__init__()
        self.dense1 = layers.Dense(400, activation='relu')
        self.mean = layers.Dense(latent_dim)
        self.logvar = layers.Dense(latent_dim)
        
    def call(self, inputs):
        x = self.dense1(inputs)
        return self.mean(x), self.logvar(x)


# Decoder
class Decoder(tf.keras.Model):
    def __init__(self, latent_dim=2):
        super(Decoder, self).__init__()
        self.dense1 = layers.Dense(400, activation='relu')
        self.output_layer = layers.Dense(784, activation='sigmoid')
        
    def call(self, inputs):
        x = self.dense1(inputs)
        return self.output_layer(x)


# Hierarchical VAE Model with Latent Layers
class HierarchicalVAE(tf.keras.Model):
    def __init__(self, latent_dim=2, num_latent_layers=2):
        super(HierarchicalVAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)
        self.num_latent_layers = num_latent_layers
        self.latent_layers = [layers.Dense(latent_dim) for _ in range(num_latent_layers)]

    def reparameterize(self, mu, logvar):
        epsilon = tf.random.normal(shape=tf.shape(mu))
        return mu + tf.exp(0.5 * logvar) * epsilon

    def call(self, inputs):
        mu, logvar = self.encoder(inputs)
        z = self.reparameterize(mu, logvar)
        
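        # Simplification: these Dense layers are deterministic transforms of the
        # single sampled z, not separate stochastic levels with their own priors.
        for latent_layer in self.latent_layers:
            z = latent_layer(z)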
            
        return self.decoder(z), mu, logvar


# Loss function
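def compute_loss(recon_x, x, mu, logvar):
    # binary_crossentropy averages over the last axis and returns shape (batch,),
    # so multiply by the pixel count to get the summed per-example loss.
    num_pixels = tf.cast(tf.shape(x)[-1], tf.float32)
    cross_entropy = tf.keras.losses.binary_crossentropy(x, recon_x) * num_pixels
    cross_entropy = tf.reduce_mean(cross_entropy)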
    
    kl_divergence = -0.5 * tf.reduce_mean(
        tf.reduce_sum(1 + logvar - tf.square(mu) - tf.exp(logvar), axis=1)
    )
    
    return cross_entropy + kl_divergence


# Loading MNIST dataset
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0
x_test = x_test / 255.0
x_train = x_train.reshape(-1, 784).astype(np.float32)
x_test = x_test.reshape(-1, 784).astype(np.float32)

# Create a TensorFlow Dataset
train_dataset = tf.data.Dataset.from_tensor_slices(x_train).shuffle(60000).batch(64)
test_dataset = tf.data.Dataset.from_tensor_slices(x_test).batch(64)

# Initialize model and optimizer
model = HierarchicalVAE(latent_dim=2, num_latent_layers=3)
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Training loop
epochs = 10
for epoch in range(epochs):
    train_loss = 0
    for data in train_dataset:
        with tf.GradientTape() as tape:
            recon_batch, mu, logvar = model(data)
            loss = compute_loss(recon_batch, data, mu, logvar)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        train_loss += loss

    # Convert the accumulated tensor to a Python float for a readable log line
    print(f'Epoch {epoch+1}, Loss: {float(train_loss) / len(train_dataset):.4f}')


# Visualization of reconstructed images
def visualize_reconstruction(model, dataset):
    for data in dataset.take(1):
        recon_batch, _, _ = model(data)
        recon_batch = tf.reshape(recon_batch, (-1, 28, 28))
        plt.imshow(recon_batch[0].numpy(), cmap='gray')
        plt.show()

visualize_reconstruction(model, test_dataset)
