# VAE Fundamentals: Mathematical Foundation and Implementation

## Learning Objectives
- Understand the mathematical derivation of VAEs
- Learn the reparameterization trick
- Implement a basic VAE from scratch
- Analyze the Evidence Lower Bound (ELBO)
- Compare VAE with standard autoencoders

## What Makes VAEs Special?

Unlike standard autoencoders that learn deterministic mappings, VAEs learn probabilistic encodings. This enables:
1. **Generation**: Sample new data from the learned distribution
2. **Interpolation**: Smooth transitions in latent space
3. **Uncertainty**: Quantify confidence in representations
4. **Regularization**: Structured latent space through KL divergence
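
The deterministic-versus-probabilistic distinction can be sketched with hypothetical, untrained toy encoders. The names `ae_encoder`, `vae_mu` and `vae_logvar` and the layer sizes are illustrative assumptions, not part of the model built later. The standard autoencoder maps an input to one fixed code. The VAE encoder instead outputs the mean and log-variance of a Gaussian and draws a fresh code on every call.

```python
import torch
import torch.nn as nn

# Hypothetical toy encoders (untrained; sizes chosen only for illustration)
ae_encoder = nn.Linear(4, 2)   # standard AE: x -> z directly
vae_mu = nn.Linear(4, 2)       # VAE: x -> mean of q(z|x)
vae_logvar = nn.Linear(4, 2)   # VAE: x -> log-variance of q(z|x)

x = torch.randn(1, 4)

def ae_encode(x):
    # Deterministic: the same input always yields the same code
    return ae_encoder(x)

def vae_encode(x):
    # Probabilistic: sample z ~ N(mu, sigma^2) on each call
    mu, logvar = vae_mu(x), vae_logvar(x)
    return mu + torch.exp(0.5 * logvar) * torch.randn_like(mu)

with torch.no_grad():
    z_ae_1, z_ae_2 = ae_encode(x), ae_encode(x)
    z_vae_1, z_vae_2 = vae_encode(x), vae_encode(x)

print(torch.equal(z_ae_1, z_ae_2))    # one fixed code per input
print(torch.equal(z_vae_1, z_vae_2))  # a distribution of codes per input
```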

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm

# Set style and random seeds
plt.style.use('seaborn-v0_8')
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Mathematical Foundation

### The Generative Process
VAEs model data generation as:
1. Sample latent code: $z \sim p(z)$ (prior)
2. Generate data: $x \sim p(x|z)$ (likelihood)
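
The two steps above amount to ancestral sampling: draw the latent first, then draw the data conditioned on it. Below is a minimal sketch that assumes a hypothetical one-layer decoder (`W`, `b`) whose sigmoid outputs are Bernoulli means for each pixel, which matches the BCE reconstruction loss used later for MNIST. The weights are random rather than trained.

```python
import torch

latent_dim, data_dim = 2, 6

# Hypothetical, untrained decoder parameters (illustration only)
W = torch.randn(data_dim, latent_dim)
b = torch.zeros(data_dim)

def sample_x(num_samples):
    # Step 1: z ~ p(z) = N(0, I)
    z = torch.randn(num_samples, latent_dim)
    # Step 2: x ~ p(x|z), here independent Bernoullis with means sigmoid(Wz + b)
    probs = torch.sigmoid(z @ W.T + b)
    x = torch.bernoulli(probs)
    return z, x

z, x = sample_x(5)
print(z.shape, x.shape)
```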

### The Inference Problem
We want to maximize: $\log p(x) = \log \int p(x|z)p(z)dz$

For a neural-network likelihood $p(x|z)$ this integral has no closed form and is expensive to estimate by sampling, so we use variational inference with an approximate posterior $q(z|x)$.
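
A toy model helps show why sampling from the prior is a poor way to estimate $\log p(x)$. The sketch below assumes a 1-D linear-Gaussian model, $z \sim N(0,1)$ and $x|z \sim N(z, s^2)$. Here the marginal is known exactly as $p(x) = N(0, 1+s^2)$, so the estimate can be compared against the truth. The values of `s` and `x` are hypothetical. With a sharply peaked $p(x|z)$, few prior samples land where it is large, so the naive estimate $\log \frac{1}{K}\sum_k p(x|z_k)$ is unreliable unless $K$ is huge. That is one motivation for proposing $z$ from $q(z|x)$ instead.

```python
import math
import torch

s = 0.1   # observation noise std (hypothetical toy value)
x = 2.5   # a single observed data point

def log_normal(v, mean, std):
    # log N(v; mean, std^2); works elementwise when v or mean is a tensor
    return -0.5 * math.log(2 * math.pi * std**2) - 0.5 * ((v - mean) / std) ** 2

# Closed form for this toy model only: p(x) = N(0, 1 + s^2)
exact = log_normal(x, 0.0, math.sqrt(1 + s**2))

def naive_mc_log_px(num_samples):
    z = torch.randn(num_samples)        # z_k ~ p(z)
    log_lik = log_normal(x, z, s)       # log p(x | z_k)
    # log (1/K) sum_k p(x | z_k), computed stably with logsumexp
    return (torch.logsumexp(log_lik, dim=0) - math.log(num_samples)).item()

print(f"exact log p(x): {exact:.3f}")
for k in [10, 1_000, 1_000_000]:
    print(f"K = {k:>9}: naive estimate {naive_mc_log_px(k):.3f}")
```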

In [None]:
# Visualize the VAE framework
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Prior distribution p(z)
z_samples = np.random.normal(0, 1, 1000)
axes[0].hist(z_samples, bins=30, alpha=0.7, density=True, color='blue')
axes[0].set_title('Prior p(z) ~ N(0,1)')
axes[0].set_xlabel('z')
axes[0].set_ylabel('Density')

# Approximate posterior q(z|x)
mu, sigma = 0.5, 0.8
z_posterior = np.random.normal(mu, sigma, 1000)
axes[1].hist(z_posterior, bins=30, alpha=0.7, density=True, color='red')
axes[1].set_title(f'Posterior q(z|x) ~ N({mu},{sigma}²)')
axes[1].set_xlabel('z')
axes[1].set_ylabel('Density')

# KL divergence visualization
z_range = np.linspace(-3, 4, 100)
prior_pdf = (1/np.sqrt(2*np.pi)) * np.exp(-0.5 * z_range**2)
posterior_pdf = (1/(sigma*np.sqrt(2*np.pi))) * np.exp(-0.5 * ((z_range - mu)/sigma)**2)

axes[2].plot(z_range, prior_pdf, label='p(z)', color='blue')
axes[2].plot(z_range, posterior_pdf, label='q(z|x)', color='red')
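# Shaded area = overlap of the two densities (a visual cue only; the KL value is computed below)
axes[2].fill_between(z_range, 0, np.minimum(prior_pdf, posterior_pdf), alpha=0.3, color='green')
axes[2].set_title('q(z|x) vs. p(z): the two arguments of KL(q||p)')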
axes[2].set_xlabel('z')
axes[2].set_ylabel('Density')
axes[2].legend()

plt.tight_layout()
plt.show()

# Calculate KL divergence analytically for Gaussians
kl_div = 0.5 * (sigma**2 + mu**2 - 1 - 2*np.log(sigma))
print(f"KL divergence between N({mu},{sigma}²) and N(0,1): {kl_div:.4f}")

## 2. Evidence Lower Bound (ELBO) Derivation

Starting from the log-likelihood:

$$\log p(x) = \mathbb{E}_{q(z|x)}[\log p(x)] = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x,z)}{p(z|x)}\right]$$

$$= \mathbb{E}_{q(z|x)}\left[\log \frac{p(x,z)q(z|x)}{p(z|x)q(z|x)}\right]$$

$$= \mathbb{E}_{q(z|x)}\left[\log \frac{p(x,z)}{q(z|x)}\right] + \mathbb{E}_{q(z|x)}\left[\log \frac{q(z|x)}{p(z|x)}\right]$$

$$= \text{ELBO} + \text{KL}(q(z|x) || p(z|x))$$

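Since KL ≥ 0, we have $\log p(x) \geq \text{ELBO}$.

Factoring the joint as $p(x,z) = p(x|z)\,p(z)$ splits the ELBO into the two terms the code below computes:

$$\text{ELBO} = \mathbb{E}_{q(z|x)}\left[\log \frac{p(x|z)\,p(z)}{q(z|x)}\right] = \mathbb{E}_{q(z|x)}[\log p(x|z)] - \text{KL}(q(z|x) || p(z))$$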
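
The identity $\log p(x) = \text{ELBO} + \text{KL}(q(z|x) || p(z|x))$ can be checked numerically. The sketch assumes the same kind of toy linear-Gaussian model ($z \sim N(0,1)$, $x|z \sim N(z, s^2)$), where the marginal, the exact posterior and every expectation have closed forms. The values of `s`, `x`, `m` and `v` are arbitrary illustrative choices, and $q = N(m, v)$ is any Gaussian, not a trained encoder.

```python
import math

s, x = 0.5, 1.2   # hypothetical noise std and observation
m, v = 0.3, 0.4   # an arbitrary Gaussian q(z|x) = N(m, v)

# Exact marginal: p(x) = N(0, 1 + s^2)
log_px = -0.5 * math.log(2 * math.pi * (1 + s**2)) - x**2 / (2 * (1 + s**2))

# ELBO = E_q[log p(x|z)] - KL(q(z|x) || p(z)), both in closed form here
expected_loglik = -0.5 * math.log(2 * math.pi * s**2) - ((x - m)**2 + v) / (2 * s**2)
kl_q_prior = 0.5 * (v + m**2 - 1 - math.log(v))
elbo = expected_loglik - kl_q_prior

# Exact posterior p(z|x) = N(x / (1 + s^2), s^2 / (1 + s^2))
post_mean, post_var = x / (1 + s**2), s**2 / (1 + s**2)
kl_q_post = 0.5 * (math.log(post_var / v) + (v + (m - post_mean)**2) / post_var - 1)

print(f"log p(x)           = {log_px:.6f}")
print(f"ELBO + KL(q||post) = {elbo + kl_q_post:.6f}")
print(f"gap KL(q||post)    = {kl_q_post:.6f}")
```

Setting `m, v = post_mean, post_var` makes the gap zero, so the ELBO becomes exactly $\log p(x)$.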

In [None]:
def compute_elbo_components(x, mu, logvar, recon_x):
    """
    Compute ELBO components for analysis
    """
    # Reconstruction term: E[log p(x|z)]
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL divergence: KL(q(z|x) || p(z))
    # For q(z|x) = N(μ, σ²I) and p(z) = N(0, I):
    # KL = 0.5 * sum(σ² + μ² - 1 - log(σ²))
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # ELBO = E[log p(x|z)] - KL(q(z|x) || p(z))
    # We minimize the negative ELBO
    elbo = -recon_loss - kl_div
    
    return {
        'elbo': elbo.item(),
        'reconstruction_loss': recon_loss.item(),
        'kl_divergence': kl_div.item(),
        'total_loss': (-elbo).item()  # What we actually minimize
    }

# Demonstrate ELBO computation with dummy data
batch_size, latent_dim = 32, 10
x_dummy = torch.rand(batch_size, 784)  # Flattened 28x28 images
mu_dummy = torch.randn(batch_size, latent_dim)
logvar_dummy = torch.randn(batch_size, latent_dim)
recon_dummy = torch.sigmoid(torch.randn(batch_size, 784))

elbo_components = compute_elbo_components(x_dummy, mu_dummy, logvar_dummy, recon_dummy)

print("ELBO Components (dummy data):")
for key, value in elbo_components.items():
    print(f"{key}: {value:.4f}")

## 3. The Reparameterization Trick

**Problem**: We can't backpropagate through stochastic sampling $z \sim q(z|x)$

**Solution**: Reparameterize as $z = \mu + \sigma \odot \epsilon$ where $\epsilon \sim N(0,I)$

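All randomness now lives in $\epsilon$, whose distribution does not depend on $\mu$ or $\sigma$. For a fixed draw of $\epsilon$, $z$ is a deterministic, differentiable function of the encoder outputs, so gradients flow from the loss back through $z$ to $\mu$ and $\sigma$.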
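
The reparameterization gradient is easy to check against an objective with a known answer. For $z \sim N(\mu, \sigma^2)$, $\mathbb{E}[z^2] = \mu^2 + \sigma^2$, so $\partial/\partial\mu = 2\mu$ and $\partial/\partial \log\sigma = 2\sigma^2$. The sketch below parameterizes by `log_sigma` rather than the notebook's log-variance. The two are equivalent, and this choice is only for readability. The parameter values are arbitrary.

```python
import torch

mu = torch.tensor(1.5, requires_grad=True)
log_sigma = torch.tensor(-0.5, requires_grad=True)

# Noise is drawn independently of the parameters...
eps = torch.randn(100_000)
# ...so z is a differentiable function of mu and log_sigma
z = mu + torch.exp(log_sigma) * eps
loss = (z ** 2).mean()   # Monte Carlo estimate of E[z^2]
loss.backward()

sigma = torch.exp(log_sigma).item()
print(f"dL/dmu:        MC {mu.grad.item():.3f}  exact {2 * mu.item():.3f}")
print(f"dL/dlog_sigma: MC {log_sigma.grad.item():.3f}  exact {2 * sigma**2:.3f}")
```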

In [None]:
def reparameterize(mu, logvar):
    """
    Reparameterization trick: z = μ + σ * ε, where ε ~ N(0, I)
    """
    std = torch.exp(0.5 * logvar)  # σ = exp(0.5 * log(σ²))
    eps = torch.randn_like(std)    # ε ~ N(0, I)
    return mu + eps * std

def reparameterize_no_trick(mu, logvar):
    """
    Direct sampling with torch.normal: no useful gradient flows back to mu or logvar
    """
    std = torch.exp(0.5 * logvar)
    return torch.normal(mu, std)

# Demonstrate gradient flow
mu = torch.tensor([0.0, 1.0], requires_grad=True)
logvar = torch.tensor([0.0, 0.5], requires_grad=True)

print("Testing gradient flow:")

# With reparameterization trick
z_reparam = reparameterize(mu, logvar)
loss_reparam = z_reparam.sum()
loss_reparam.backward()
print(f"With reparameterization - mu.grad: {mu.grad}")

# Reset gradients
mu.grad = None
logvar.grad = None

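# Without reparameterization, no useful gradient reaches mu: depending on the
# PyTorch version, torch.normal either reports zero gradients or raises an error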
try:
    z_direct = reparameterize_no_trick(mu, logvar)
    loss_direct = z_direct.sum()
    loss_direct.backward()
    print(f"Without reparameterization - mu.grad: {mu.grad}")
except RuntimeError as e:
    print(f"Error without reparameterization: {e}")

# Visualize the effect of reparameterization
mu_test = torch.tensor([0.0])
logvar_test = torch.tensor([1.0])

samples_reparam = [reparameterize(mu_test, logvar_test).item() for _ in range(1000)]
samples_direct = [torch.normal(mu_test, torch.exp(0.5 * logvar_test)).item() for _ in range(1000)]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].hist(samples_reparam, bins=30, alpha=0.7, density=True, label='Reparameterized')
axes[0].set_title('Reparameterized Sampling')
axes[0].set_xlabel('z')
axes[0].set_ylabel('Density')

axes[1].hist(samples_direct, bins=30, alpha=0.7, density=True, label='Direct', color='orange')
axes[1].set_title('Direct Sampling')
axes[1].set_xlabel('z')
axes[1].set_ylabel('Density')

plt.tight_layout()
plt.show()

print(f"Both methods produce the same distribution:")
print(f"Reparameterized - Mean: {np.mean(samples_reparam):.3f}, Std: {np.std(samples_reparam):.3f}")
print(f"Direct - Mean: {np.mean(samples_direct):.3f}, Std: {np.std(samples_direct):.3f}")

## 4. Basic VAE Implementation

In [None]:
class BasicVAE(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):
        super(BasicVAE, self).__init__()
        
        # Encoder
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )
        
        # Latent space parameters
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid()  # Output in [0,1] for MNIST
        )
        
        self.latent_dim = latent_dim
    
    def encode(self, x):
        """Encode input to latent parameters"""
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar
    
    def reparameterize(self, mu, logvar):
        """Reparameterization trick"""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        """Decode latent code to reconstruction"""
        return self.decoder(z)
    
    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z)
        return recon_x, mu, logvar
    
    def sample(self, num_samples, device):
        """Generate new samples from prior"""
        z = torch.randn(num_samples, self.latent_dim).to(device)
        return self.decode(z)

# Initialize model
model = BasicVAE(input_dim=784, hidden_dim=400, latent_dim=20).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

# Print model architecture
print("\nModel Architecture:")
print(model)

## 5. Loss Function Implementation
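
The KL term in the loss uses the closed-form expression for a diagonal Gaussian against $N(0, I)$. As a sanity check, the sketch below compares that expression with a Monte Carlo estimate of $\mathbb{E}_q[\log q(z) - \log p(z)]$. It also compares against PyTorch's registered analytic KL for two `Normal` distributions. The `mu` and `logvar` values are arbitrary stand-ins for encoder outputs.

```python
import torch

mu = torch.tensor([0.5, -1.0, 0.0])
logvar = torch.tensor([-0.4, 0.3, 0.0])

# Closed form, summed over latent dimensions (same expression as vae_loss)
kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# Monte Carlo: KL(q||p) = E_q[log q(z) - log p(z)], with z sampled from q
q = torch.distributions.Normal(mu, torch.exp(0.5 * logvar))
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(mu))
z = q.sample((200_000,))                                  # shape (200000, 3)
kl_mc = (q.log_prob(z) - p.log_prob(z)).sum(dim=1).mean()

# torch.distributions' analytic Normal-Normal KL (elementwise), summed
kl_lib = torch.distributions.kl_divergence(q, p).sum()

print(f"closed form:         {kl_closed.item():.4f}")
print(f"Monte Carlo:         {kl_mc.item():.4f}")
print(f"torch.distributions: {kl_lib.item():.4f}")
```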

In [None]:
def vae_loss(recon_x, x, mu, logvar, beta=1.0):
    """
    VAE loss function: -ELBO = Reconstruction Loss + β * KL Divergence
    
    Args:
        recon_x: Reconstructed input
        x: Original input
        mu: Mean of latent distribution
        logvar: Log variance of latent distribution
        beta: Weight for KL divergence (β-VAE)
    """
    # Reconstruction loss (negative log-likelihood)
    # For binary data, use BCE; for continuous data, use MSE
    recon_loss = F.binary_cross_entropy(recon_x, x, reduction='sum')
    
    # KL divergence between q(z|x) and p(z) = N(0,I)
    # KL(N(μ,σ²)||N(0,1)) = 0.5 * (σ² + μ² - 1 - log(σ²))
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    # Total loss
    total_loss = recon_loss + beta * kl_loss
    
    return total_loss, recon_loss, kl_loss

def vae_loss_mse(recon_x, x, mu, logvar, beta=1.0):
    """VAE loss with MSE reconstruction (for continuous data)"""
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon_loss + beta * kl_loss, recon_loss, kl_loss

# Test loss function with dummy data
x_test = torch.rand(10, 784).to(device)
recon_test, mu_test, logvar_test = model(x_test)

loss, recon_loss, kl_loss = vae_loss(recon_test, x_test, mu_test, logvar_test)

print(f"Loss components (test):")
print(f"Total loss: {loss.item():.4f}")
print(f"Reconstruction loss: {recon_loss.item():.4f}")
print(f"KL divergence: {kl_loss.item():.4f}")
print(f"KL per dimension: {kl_loss.item() / (x_test.size(0) * model.latent_dim):.4f}")

## 6. Data Loading and Preprocessing

In [None]:
# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.view(-1))  # Flatten to 784 dimensions
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")

# Visualize some samples
sample_batch, sample_labels = next(iter(train_loader))
print(f"Batch shape: {sample_batch.shape}")
print(f"Data range: [{sample_batch.min():.3f}, {sample_batch.max():.3f}]")

# Plot sample images
fig, axes = plt.subplots(2, 8, figsize=(12, 4))
for i in range(8):
    # Original
    axes[0, i].imshow(sample_batch[i].view(28, 28), cmap='gray')
    axes[0, i].set_title(f'Label: {sample_labels[i]}')
    axes[0, i].axis('off')
    
    # Histogram of pixel values
    axes[1, i].hist(sample_batch[i].numpy(), bins=20, alpha=0.7)
    axes[1, i].set_title('Pixel Distribution')
    axes[1, i].set_xlabel('Pixel Value')

plt.tight_layout()
plt.show()

## 7. Training the VAE

In [None]:
def train_vae(model, train_loader, optimizer, epoch, beta=1.0):
    model.train()
    train_loss = 0
    train_recon_loss = 0
    train_kl_loss = 0
    
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        
        # Forward pass
        recon_batch, mu, logvar = model(data)
        
        # Compute loss
        loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        train_recon_loss += recon_loss.item()
        train_kl_loss += kl_loss.item()
        
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\t'
                  f'Loss: {loss.item() / len(data):.6f}')
    
    avg_loss = train_loss / len(train_loader.dataset)
    avg_recon = train_recon_loss / len(train_loader.dataset)
    avg_kl = train_kl_loss / len(train_loader.dataset)
    
    print(f'====> Epoch: {epoch} Average loss: {avg_loss:.4f} '
          f'Recon: {avg_recon:.4f} KL: {avg_kl:.4f}')
    
    return avg_loss, avg_recon, avg_kl

def test_vae(model, test_loader, beta=1.0):
    model.eval()
    test_loss = 0
    test_recon_loss = 0
    test_kl_loss = 0
    
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            
            loss, recon_loss, kl_loss = vae_loss(recon_batch, data, mu, logvar, beta)
            test_loss += loss.item()
            test_recon_loss += recon_loss.item()
            test_kl_loss += kl_loss.item()
    
    test_loss /= len(test_loader.dataset)
    test_recon_loss /= len(test_loader.dataset)
    test_kl_loss /= len(test_loader.dataset)
    
    print(f'====> Test set loss: {test_loss:.4f} '
          f'Recon: {test_recon_loss:.4f} KL: {test_kl_loss:.4f}')
    
    return test_loss, test_recon_loss, test_kl_loss

# Initialize optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
train_losses = []
test_losses = []
recon_losses = []
kl_losses = []

print("Starting VAE training...")
for epoch in range(1, num_epochs + 1):
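    # Beta annealing (KL warm-up): raise the KL weight linearly to 1 over the first 5 epochs.
    # While beta < 1, the reported losses are a weighted objective, not the exact negative ELBO.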
    beta = min(1.0, epoch / 5.0)
    
    train_loss, train_recon, train_kl = train_vae(model, train_loader, optimizer, epoch, beta)
    test_loss, test_recon, test_kl = test_vae(model, test_loader, beta)
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    recon_losses.append(train_recon)
    kl_losses.append(train_kl)
    
    print(f"Beta: {beta:.3f}\n")

print("Training completed!")

## 8. Analyzing Training Progress
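
The analysis cell divides the per-sample KL by `model.latent_dim`, which averages over dimensions. A per-dimension breakdown is often more informative. Dimensions whose average KL stays near zero carry no information about $x$ and are effectively unused, often called inactive units or a symptom of posterior collapse. The helper `kl_per_latent_dim` below is a hypothetical name. It is shown on toy inputs where dimension 1 has collapsed to the prior ($\mu = 0$, $\log\sigma^2 = 0$).

```python
import torch

def kl_per_latent_dim(mu, logvar):
    """Batch-averaged KL(q(z_j|x) || N(0,1)), one value per latent dimension j."""
    kl = 0.5 * (mu.pow(2) + logvar.exp() - 1 - logvar)  # shape (batch, latent_dim)
    return kl.mean(dim=0)

# Toy encoder outputs: dimension 0 is informative, dimension 1 matches the prior
mu = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])
logvar = torch.tensor([[-2.0, 0.0], [-2.0, 0.0]])
kl = kl_per_latent_dim(mu, logvar)
print(kl)
```

For the trained model, the same helper can be applied to `mu, logvar = model.encode(data)` for a test batch under `torch.no_grad()`.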

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Total loss
axes[0, 0].plot(train_losses, label='Train', marker='o')
axes[0, 0].plot(test_losses, label='Test', marker='s')
axes[0, 0].set_title('Total Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# Loss components
axes[0, 1].plot(recon_losses, label='Reconstruction', marker='o', color='blue')
axes[0, 1].plot(kl_losses, label='KL Divergence', marker='s', color='red')
axes[0, 1].set_title('Loss Components')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# KL/Reconstruction ratio
kl_recon_ratio = np.array(kl_losses) / np.array(recon_losses)
axes[1, 0].plot(kl_recon_ratio, marker='o', color='green')
axes[1, 0].set_title('KL/Reconstruction Ratio')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Ratio')
axes[1, 0].grid(True, alpha=0.3)

# KL per latent dimension
kl_per_dim = np.array(kl_losses) / model.latent_dim
axes[1, 1].plot(kl_per_dim, marker='o', color='purple')
axes[1, 1].axhline(y=1.0, color='red', linestyle='--', alpha=0.7, label='KL=1 per dim')
axes[1, 1].set_title('KL Divergence per Latent Dimension')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('KL per Dimension')
axes[1, 1].legend()
axes[1, 1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final training metrics:")
print(f"Total loss: {train_losses[-1]:.4f}")
print(f"Reconstruction loss: {recon_losses[-1]:.4f}")
print(f"KL divergence: {kl_losses[-1]:.4f}")
print(f"KL per dimension: {kl_per_dim[-1]:.4f}")

## 9. Evaluating Reconstructions

In [None]:
def visualize_reconstructions(model, test_loader, num_samples=8):
    model.eval()
    
    with torch.no_grad():
        for data, labels in test_loader:
            data = data.to(device)
            recon, mu, logvar = model(data)
            
            # Plot original vs reconstruction
            fig, axes = plt.subplots(3, num_samples, figsize=(15, 6))
            
            for i in range(num_samples):
                # Original
                axes[0, i].imshow(data[i].cpu().view(28, 28), cmap='gray')
                axes[0, i].set_title(f'Original (Label: {labels[i]})')
                axes[0, i].axis('off')
                
                # Reconstruction
                axes[1, i].imshow(recon[i].cpu().view(28, 28), cmap='gray')
                axes[1, i].set_title('Reconstruction')
                axes[1, i].axis('off')
                
                # Difference
                diff = torch.abs(data[i] - recon[i]).cpu().view(28, 28)
                im = axes[2, i].imshow(diff, cmap='hot')
                axes[2, i].set_title(f'Diff (MSE: {F.mse_loss(recon[i], data[i]):.3f})')
                axes[2, i].axis('off')
            
            plt.tight_layout()
            plt.show()
            break

visualize_reconstructions(model, test_loader)

# Compute reconstruction statistics
def compute_reconstruction_stats(model, test_loader):
    model.eval()
    mse_errors = []
    
    with torch.no_grad():
        for data, _ in test_loader:
            data = data.to(device)
            recon, _, _ = model(data)
            
            # Compute MSE for each sample
            mse = F.mse_loss(recon, data, reduction='none').mean(dim=1)
            mse_errors.extend(mse.cpu().numpy())
    
    return np.array(mse_errors)

mse_errors = compute_reconstruction_stats(model, test_loader)

plt.figure(figsize=(10, 6))
plt.hist(mse_errors, bins=50, alpha=0.7, edgecolor='black')
plt.axvline(np.mean(mse_errors), color='red', linestyle='--', 
           label=f'Mean: {np.mean(mse_errors):.4f}')
plt.axvline(np.median(mse_errors), color='green', linestyle='--', 
           label=f'Median: {np.median(mse_errors):.4f}')
plt.title('Distribution of Reconstruction Errors')
plt.xlabel('MSE')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

print(f"Reconstruction statistics:")
print(f"Mean MSE: {np.mean(mse_errors):.6f}")
print(f"Std MSE: {np.std(mse_errors):.6f}")
print(f"Min MSE: {np.min(mse_errors):.6f}")
print(f"Max MSE: {np.max(mse_errors):.6f}")

## 10. Generating New Samples

In [None]:
def generate_samples(model, num_samples=16):
    model.eval()
    
    with torch.no_grad():
        # Sample from prior N(0,I)
        z = torch.randn(num_samples, model.latent_dim).to(device)
        samples = model.decode(z)
        
        # Plot generated samples
        fig, axes = plt.subplots(4, 4, figsize=(8, 8))
        axes = axes.flatten()
        
        for i in range(num_samples):
            axes[i].imshow(samples[i].cpu().view(28, 28), cmap='gray')
            axes[i].set_title(f'Sample {i+1}')
            axes[i].axis('off')
        
        plt.suptitle('Generated Samples from VAE')
        plt.tight_layout()
        plt.show()
        
        return samples

generated_samples = generate_samples(model, 16)

# Compare with real samples
real_samples, _ = next(iter(test_loader))

fig, axes = plt.subplots(2, 8, figsize=(12, 4))

for i in range(8):
    # Real samples
    axes[0, i].imshow(real_samples[i].view(28, 28), cmap='gray')
    axes[0, i].set_title('Real')
    axes[0, i].axis('off')
    
    # Generated samples
    axes[1, i].imshow(generated_samples[i].cpu().view(28, 28), cmap='gray')
    axes[1, i].set_title('Generated')
    axes[1, i].axis('off')

plt.suptitle('Real vs Generated Samples')
plt.tight_layout()
plt.show()

## 🎯 Key Takeaways

### Mathematical Insights:
1. **ELBO Maximization**: VAEs maximize a lower bound on the log-likelihood
2. **Reparameterization Trick**: Enables gradient flow through stochastic nodes
3. **KL Regularization**: Pulls each approximate posterior $q(z|x)$ toward the prior $p(z)$, which helps samples drawn from the prior decode to plausible data
4. **Trade-off**: Balance between reconstruction quality and latent regularity

### Implementation Insights:
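1. **Beta Annealing**: Start with a small KL weight and raise it gradually, so the model learns to use the latent code before the full KL penalty applies (this helps avoid posterior collapse)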
2. **Loss Components**: Monitor reconstruction and KL losses separately
3. **Latent Dimensions**: More dimensions = more capacity but harder optimization
4. **Architecture**: Encoder/decoder symmetry often works well

### Practical Insights:
1. **Generation Quality**: VAEs produce smoother but blurrier samples than GANs
2. **Latent Space**: Continuous and interpolatable latent representations
3. **Training Stability**: More stable than GANs, less prone to mode collapse
4. **Applications**: Great for representation learning and anomaly detection
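
The interpolation property can be tried directly. Decode evenly spaced points on the straight line between two latent codes. The sketch below uses a hypothetical stand-in decoder with the same shapes as `BasicVAE` (20 → 784), so it runs on its own. With the trained notebook model you would pass `model.decode` and two encoder means instead. `interpolate_latents` is an illustrative name, not an existing API.

```python
import torch
import torch.nn as nn

def interpolate_latents(decoder, z_start, z_end, steps=8):
    """Decode points on the straight line from z_start to z_end."""
    alphas = torch.linspace(0.0, 1.0, steps).unsqueeze(1)   # (steps, 1)
    z_path = (1 - alphas) * z_start + alphas * z_end         # (steps, latent_dim)
    with torch.no_grad():
        return decoder(z_path)

# Hypothetical untrained decoder with BasicVAE-like shapes (20 -> 784)
decoder = nn.Sequential(nn.Linear(20, 784), nn.Sigmoid())
z_a, z_b = torch.randn(20), torch.randn(20)
frames = interpolate_latents(decoder, z_a, z_b, steps=8)
print(frames.shape)  # each row can be reshaped to 28x28 for plotting
```

With the trained model, `mu_a, _ = model.encode(x_a)` and `mu_b, _ = model.encode(x_b)` for two flattened test images give endpoints whose decoded path typically morphs smoothly from one digit to the other.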

## 📝 Exercises

### Beginner:
1. Modify the latent dimension and observe the effect on generation quality
2. Try different beta values and analyze the reconstruction-regularization trade-off
3. Implement VAE with MSE loss instead of BCE

### Intermediate:
1. Add batch normalization to the encoder and decoder
2. Implement learning rate scheduling
3. Create a convolutional VAE for CIFAR-10

### Advanced:
1. Implement β-VAE with automatic beta tuning
2. Add skip connections to the decoder
3. Implement hierarchical VAE with multiple latent levels

## 🔗 Next Steps

Now that you understand VAE fundamentals, you're ready to explore:
1. **Advanced VAE Variants**: β-VAE, WAE, VQ-VAE
2. **Latent Space Analysis**: Disentanglement and interpolation
3. **Applications**: Anomaly detection, data augmentation
4. **Comparisons**: VAE vs GAN trade-offs

**Next Notebook**: [VAE Architecture Design](./02_vae_architectures.ipynb) →