# 🌌 Variational Autoencoder (VAE)

Welcome to **Variational Autoencoders**! In this notebook, we'll explore probabilistic generative modeling, learn meaningful latent representations, and generate new data from learned distributions.

## What you'll learn:
- Probabilistic encoder-decoder architecture
- Reparameterization trick for backpropagation
- KL divergence and ELBO optimization
- Latent space exploration and interpolation

Let's explore the latent space! 🚀

In [None]:
# Import libraries
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.losses import binary_crossentropy

# Fix the NumPy and TensorFlow seeds first so every later cell is repeatable
tf.random.set_seed(42)
np.random.seed(42)

# Plot theme used by all figures in this notebook
plt.style.use('seaborn-v0_8')

# Report the runtime environment
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {len(tf.config.list_physical_devices('GPU')) > 0}")

In [None]:
# Fetch MNIST (images + integer labels) through the Keras dataset helper
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Scale pixel values into [0, 1], then flatten each image into one row vector
x_train = (x_train.astype('float32') / 255.0).reshape(len(x_train), -1)
x_test = (x_test.astype('float32') / 255.0).reshape(len(x_test), -1)

print(f"Training data shape: {x_train.shape}")
print(f"Test data shape: {x_test.shape}")
print(f"Data range: [{x_train.min():.2f}, {x_train.max():.2f}]")

# Preview the first ten training digits on a 2 x 5 grid
fig, axes = plt.subplots(2, 5, figsize=(12, 6))
for i, ax in enumerate(axes.flat):
    # Each row is a flattened image; restore the 28 x 28 layout for display
    ax.imshow(x_train[i].reshape(28, 28), cmap='gray')
    ax.set_title(f'Digit: {y_train[i]}')
    ax.axis('off')

plt.suptitle('🔢 MNIST Dataset Samples', fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# VAE Architecture
class VAE(keras.Model):
    """Fully connected variational autoencoder.

    The encoder maps a flat input vector to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent code
    back to a vector of per-feature values in (0, 1) (sigmoid output).

    Args:
        latent_dim: Size of the latent code.
        intermediate_dim: Width of each hidden Dense layer.
        original_dim: Length of the flattened input vector
            (784 for 28 x 28 MNIST images).
        **kwargs: Forwarded to ``keras.Model``.
    """

    def __init__(self, latent_dim=20, intermediate_dim=512, original_dim=784, **kwargs):
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        self.intermediate_dim = intermediate_dim
        self.original_dim = original_dim

        # Encoder trunk shared by the mean and log-variance heads.
        # keras.Input is used instead of InputLayer(input_shape=...), whose
        # `input_shape` argument is deprecated in Keras 3.
        self.encoder = keras.Sequential([
            keras.Input(shape=(original_dim,)),
            layers.Dense(intermediate_dim, activation='relu'),
            layers.Dense(intermediate_dim, activation='relu'),
        ])

        # Latent space parameters (linear heads, no activation)
        self.z_mean = layers.Dense(latent_dim)
        self.z_log_var = layers.Dense(latent_dim)

        # Decoder mirrors the encoder and ends in a sigmoid so the outputs
        # can be scored with binary cross-entropy
        self.decoder = keras.Sequential([
            keras.Input(shape=(latent_dim,)),
            layers.Dense(intermediate_dim, activation='relu'),
            layers.Dense(intermediate_dim, activation='relu'),
            layers.Dense(original_dim, activation='sigmoid'),
        ])

    def encode(self, x):
        """Return ``(z_mean, z_log_var)`` for a batch of flat inputs."""
        h = self.encoder(x)
        return self.z_mean(h), self.z_log_var(h)

    def reparameterize(self, z_mean, z_log_var):
        """Sample ``z = mean + std * eps`` with ``eps ~ N(0, I)``.

        Drawing the noise separately keeps the sample differentiable with
        respect to ``z_mean`` and ``z_log_var``.
        """
        # Take shape and dtype from z_mean so the noise always matches it
        epsilon = tf.random.normal(shape=tf.shape(z_mean), dtype=z_mean.dtype)
        # exp(0.5 * log_var) is the standard deviation
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

    def decode(self, z):
        """Map a batch of latent codes to reconstructions."""
        return self.decoder(z)

    def call(self, x):
        """Forward pass: returns ``(reconstruction, z_mean, z_log_var)``."""
        z_mean, z_log_var = self.encode(x)
        z = self.reparameterize(z_mean, z_log_var)
        reconstruction = self.decode(z)
        return reconstruction, z_mean, z_log_var

# Create VAE model
# LATENT_DIM is read again by the latent-space visualization and summary cells
LATENT_DIM = 20
# intermediate_dim is left at the class default
vae = VAE(latent_dim=LATENT_DIM)

print(f"✅ VAE model created with latent dimension: {LATENT_DIM}")

In [None]:
# VAE Loss Function
def vae_loss(x, reconstruction, z_mean, z_log_var):
    """Compute the VAE loss: reconstruction loss plus KL divergence.

    Args:
        x: Batch of flat input vectors with values in [0, 1].
        reconstruction: Decoder output for the same batch, same shape as ``x``.
        z_mean: Latent means produced by the encoder.
        z_log_var: Latent log-variances produced by the encoder.

    Returns:
        A tuple ``(total, reconstruction, kl)`` of scalar tensors, each the
        mean over the batch.
    """
    # Keras' binary_crossentropy already averages over the last axis and
    # returns one value per sample, so there is no axis 1 left to sum over.
    # Scaling the mean by the number of features recovers the per-sample sum.
    feature_count = tf.cast(tf.shape(x)[-1], reconstruction.dtype)
    reconstruction_loss = binary_crossentropy(x, reconstruction) * feature_count

    # Closed-form KL divergence between N(z_mean, exp(z_log_var)) and N(0, I),
    # summed over the latent dimensions
    kl_loss = -0.5 * tf.reduce_sum(
        1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=1
    )

    return (
        tf.reduce_mean(reconstruction_loss + kl_loss),
        tf.reduce_mean(reconstruction_loss),
        tf.reduce_mean(kl_loss),
    )

# Adam optimizer shared by every training step
optimizer = keras.optimizers.Adam(learning_rate=1e-3)


@tf.function
def train_step(x):
    """Run one gradient update on batch ``x``.

    Returns the ``(total, reconstruction, kl)`` losses computed for the batch
    before the weights are updated.
    """
    # Record the forward pass so the loss can be differentiated
    with tf.GradientTape() as tape:
        total_loss, recon_loss, kl_loss = vae_loss(x, *vae(x))

    # Differentiate the total loss w.r.t. all model weights and apply the update
    weights = vae.trainable_variables
    optimizer.apply_gradients(zip(tape.gradient(total_loss, weights), weights))

    return total_loss, recon_loss, kl_loss


print("✅ Loss function and training step defined!")

In [None]:
# Training loop
EPOCHS = 50
BATCH_SIZE = 128

# Create dataset. The shuffle buffer spans the whole training set so each
# epoch sees a uniformly random ordering; a small buffer only shuffles locally.
train_dataset = tf.data.Dataset.from_tensor_slices(x_train)
train_dataset = train_dataset.shuffle(buffer_size=len(x_train)).batch(BATCH_SIZE)

# Per-epoch loss history, stored as plain Python floats
train_losses = {'total': [], 'reconstruction': [], 'kl': []}

print(f"🚀 Starting VAE training...")
print(f"Epochs: {EPOCHS}, Batch Size: {BATCH_SIZE}")

for epoch in range(EPOCHS):
    # Per-batch losses collected during this epoch
    epoch_losses = {'total': [], 'reconstruction': [], 'kl': []}

    for batch in train_dataset:
        total_loss, recon_loss, kl_loss = train_step(batch)
        epoch_losses['total'].append(total_loss)
        epoch_losses['reconstruction'].append(recon_loss)
        epoch_losses['kl'].append(kl_loss)

    # Average the batch losses and convert to floats so the history does not
    # hold on to tensors and is easy to plot and format
    avg_total = float(tf.reduce_mean(epoch_losses['total']))
    avg_recon = float(tf.reduce_mean(epoch_losses['reconstruction']))
    avg_kl = float(tf.reduce_mean(epoch_losses['kl']))

    train_losses['total'].append(avg_total)
    train_losses['reconstruction'].append(avg_recon)
    train_losses['kl'].append(avg_kl)

    # Report progress every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}: Total Loss = {avg_total:.2f}, "
              f"Recon Loss = {avg_recon:.2f}, KL Loss = {avg_kl:.2f}")

print("\n🎉 Training completed!")

In [None]:
# Visualize training progress
fig, axes = plt.subplots(1, 2, figsize=(15, 6))

# Loss curves
axes[0].plot(train_losses['total'], label='Total Loss', alpha=0.8)
axes[0].plot(train_losses['reconstruction'], label='Reconstruction Loss', alpha=0.8)
axes[0].plot(train_losses['kl'], label='KL Divergence', alpha=0.8)
axes[0].set_title('📉 VAE Training Losses')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Test reconstructions
test_sample = x_test[:10]
reconstructions, _, _ = vae(test_sample)

# test_sample is a NumPy slice (it has no .numpy() method); only the model
# output is a tensor and needs converting
original_images = test_sample.reshape(-1, 28, 28)
reconstructed_images = reconstructions.numpy().reshape(-1, 28, 28)

# Show original vs reconstructed as a compact montage:
# top row = originals, bottom row = reconstructions
comparison = np.vstack([
    np.hstack(list(original_images)),
    np.hstack(list(reconstructed_images)),
])
axes[1].imshow(comparison, cmap='gray')
axes[1].axis('off')
axes[1].set_title('🔄 Original vs Reconstructed')

plt.tight_layout()
plt.show()

# Detailed reconstruction comparison
fig, axes = plt.subplots(2, 10, figsize=(20, 6))

for i, (original, reconstructed) in enumerate(zip(original_images, reconstructed_images)):
    # Original
    axes[0, i].imshow(original, cmap='gray')
    axes[0, i].set_title('Original')
    axes[0, i].axis('off')

    # Reconstructed
    axes[1, i].imshow(reconstructed, cmap='gray')
    axes[1, i].set_title('Reconstructed')
    axes[1, i].axis('off')

plt.suptitle('🔄 VAE Reconstructions', fontsize=16)
plt.tight_layout()
plt.show()

print(f"\n📊 Final Training Results:")
print(f"Total Loss: {train_losses['total'][-1]:.2f}")
print(f"Reconstruction Loss: {train_losses['reconstruction'][-1]:.2f}")
print(f"KL Divergence: {train_losses['kl'][-1]:.2f}")

In [None]:
# Generate new samples
def generate_samples(vae, num_samples=16):
    """Decode ``num_samples`` latent codes drawn from a standard normal."""
    latent_shape = (num_samples, vae.latent_dim)
    return vae.decode(tf.random.normal(shape=latent_shape))

# Draw 16 new digits from the prior and lay them out on a 4 x 4 grid
generated_samples = generate_samples(vae, 16)

fig, axes = plt.subplots(4, 4, figsize=(12, 12))
for ax, sample in zip(axes.flat, generated_samples):
    # Each decoded sample is a flat vector; restore the image layout
    ax.imshow(sample.numpy().reshape(28, 28), cmap='gray')
    ax.axis('off')

plt.suptitle('🎲 Generated Samples from Random Latent Codes', fontsize=16)
plt.tight_layout()
plt.show()

In [None]:
# Latent space visualization (2D projection)
# Encode the first 1000 test digits once; both display modes use the means
test_sample = x_test[:1000]
test_labels = y_test[:1000]
z_mean, _ = vae.encode(test_sample)
latent_means = z_mean.numpy()

if LATENT_DIM > 2:
    # Use t-SNE for 2D projection. The resulting axes are t-SNE components,
    # not individual latent dimensions, so they are labelled accordingly.
    tsne = TSNE(n_components=2, random_state=42)
    z_2d = tsne.fit_transform(latent_means)
    plot_title = '🌌 Latent Space Visualization (t-SNE projection)'
    axis_name = 't-SNE Component'
else:
    # Direct 2D visualization of the first two latent dimensions
    z_2d = latent_means
    plot_title = '🌌 2D Latent Space'
    axis_name = 'Latent Dimension'

# Plot latent space, colored by digit label
plt.figure(figsize=(12, 10))
scatter = plt.scatter(z_2d[:, 0], z_2d[:, 1], c=test_labels, cmap='tab10', alpha=0.7)
plt.colorbar(scatter, label='Digit')
plt.title(plot_title)
plt.xlabel(f'{axis_name} 1')
plt.ylabel(f'{axis_name} 2')
plt.grid(True, alpha=0.3)
plt.show()

In [None]:
# Latent space interpolation
def interpolate_latent(vae, start_img, end_img, steps=10):
    """Interpolate between two images in latent space.

    Both images are encoded to their latent means, intermediate codes are
    formed by linear interpolation, and each code is decoded.

    Args:
        vae: Model exposing ``encode`` and ``decode``.
        start_img: Array for the first image; reshaped to a single-row batch.
        end_img: Array for the second image; reshaped to a single-row batch.
        steps: Number of frames to produce, endpoints included. With a
            single step only the start image's code is decoded; zero or
            fewer steps produce an empty list.

    Returns:
        A list of ``steps`` decoded images as NumPy arrays.
    """
    # Encode images to latent space; only the means are used (the second
    # output of encode is ignored)
    start_z, _ = vae.encode(start_img.reshape(1, -1))
    end_z, _ = vae.encode(end_img.reshape(1, -1))

    interpolated_images = []
    for i in range(steps):
        # Guard the single-step case, where steps - 1 would divide by zero
        alpha = i / (steps - 1) if steps > 1 else 0.0
        interpolated_z = (1 - alpha) * start_z + alpha * end_z
        decoded = vae.decode(interpolated_z)
        interpolated_images.append(decoded[0].numpy())

    return interpolated_images

# Pick the first test image of a "0" and the first of a "9" as endpoints
start_idx, end_idx = (np.where(y_test == digit)[0][0] for digit in (0, 9))
start_img, end_img = x_test[start_idx], x_test[end_idx]

interpolated = interpolate_latent(vae, start_img, end_img, 10)

# Show the ten interpolation frames side by side
fig, axes = plt.subplots(1, 10, figsize=(20, 4))
for i, (ax, img) in enumerate(zip(axes, interpolated)):
    ax.imshow(img.reshape(28, 28), cmap='gray')
    ax.set_title(f'Step {i+1}')
    ax.axis('off')

plt.suptitle('🌈 Latent Space Interpolation (0 → 9)', fontsize=16)
plt.tight_layout()
plt.show()

# Final recap of the run
print(f"\n📊 VAE Summary:")
print(f"Latent Dimension: {LATENT_DIM}")
print(f"Training Epochs: {EPOCHS}")
print(f"Final Loss: {train_losses['total'][-1]:.2f}")
print(f"Model Parameters: {vae.count_params():,}")

## 🎉 Congratulations!

You've successfully implemented and trained a Variational Autoencoder! Here's what you've accomplished:

✅ **VAE Architecture**: Built probabilistic encoder-decoder  
✅ **Reparameterization Trick**: Enabled backpropagation through sampling  
✅ **ELBO Optimization**: Balanced reconstruction and regularization  
✅ **Latent Space**: Explored meaningful representations  
✅ **Generation**: Created new samples from learned distribution  
✅ **Interpolation**: Smooth transitions in latent space  

### 🚀 Next Steps:
1. Try β-VAE for better disentanglement
2. Implement Conditional VAE for controlled generation
3. Experiment with different latent dimensions
4. Move on to **Project 09: Transformer for Language Modeling**

Ready for the attention revolution? Let's build Transformers! 🤖