# VAE Architecture

This notebook walks through the architecture of a Variational Autoencoder (VAE). It explains the two main components, the encoder and the decoder, builds each one in PyTorch, combines them into a full model, defines the loss, and trains the result on MNIST.

## Encoder

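The encoder is a neural network that maps each input to a probability distribution over the latent space rather than to a single point. Its output is the parameters of that distribution, typically the mean and log variance of a diagonal Gaussian. Predicting the log variance instead of the variance lets the network output any real number while the implied variance stays positive.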

In [1]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Maps an input vector to the mean and log variance of q(z | x)."""

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        # Two heads share the hidden layer, one per distribution parameter
        self.fc2_mean = nn.Linear(hidden_dim, latent_dim)
        self.fc2_log_var = nn.Linear(hidden_dim, latent_dim)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        h = self.relu(self.fc1(x))
        mean = self.fc2_mean(h)
        log_var = self.fc2_log_var(h)
        return mean, log_var
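
To make the log-variance parameterization concrete, here is a minimal sketch in plain Python rather than PyTorch, so the numbers can be checked by hand. The helper name `std_from_log_var` is hypothetical and not part of the notebook's code; it applies the same conversion the model uses when it samples.

```python
import math

def std_from_log_var(log_var):
    """Convert a log-variance value into a standard deviation."""
    # var = exp(log_var), so std = sqrt(var) = exp(0.5 * log_var)
    return math.exp(0.5 * log_var)

# Any real-valued network output, negative or positive, gives a positive std.
for log_var in (-4.0, 0.0, 2.0):
    print(log_var, std_from_log_var(log_var))
```

A log variance of 0 corresponds to a standard deviation of exactly 1, which is the spread of the standard normal prior.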

## Decoder

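The decoder is a neural network that takes a sample from the latent distribution and maps it back to the original data space, producing the reconstructed input. Here the final sigmoid keeps every output between 0 and 1, which matches pixel intensities scaled to [0, 1] and is required by the binary cross-entropy reconstruction loss used later.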

In [2]:
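class Decoder(nn.Module):
    """Maps a latent sample z back to data space, with outputs between 0 and 1."""

    def __init__(self, latent_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        # Sigmoid squashes each output so it can be read as a pixel intensity
        self.sigmoid = nn.Sigmoid()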
    
    def forward(self, z):
        h = self.relu(self.fc1(z))
        x_reconstructed = self.sigmoid(self.fc2(h))
        return x_reconstructed
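
The role of the final sigmoid can be illustrated without PyTorch. The sketch below is a plain-Python stand-in, assuming a handful of made-up pre-activation values; the `sigmoid` function here is a hypothetical helper, not the `nn.Sigmoid` module used in the decoder.

```python
import math

def sigmoid(a):
    """Logistic function, written to avoid overflow for large |a|."""
    if a >= 0:
        return 1.0 / (1.0 + math.exp(-a))
    e = math.exp(a)
    return e / (1.0 + e)

# Made-up pre-activations standing in for the output of the last linear layer.
pre_activations = [-6.0, -1.0, 0.0, 1.0, 6.0]
pixels = [sigmoid(a) for a in pre_activations]
print(pixels)
```

Every value lands between 0 and 1 and the ordering of the inputs is preserved, so the decoder output can be compared directly with pixel intensities in [0, 1].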

## VAE Model

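The VAE model combines the encoder and decoder and adds a sampling step between them. Sampling uses the reparameterization trick: instead of drawing z directly from a Gaussian with the predicted mean and variance, the model draws epsilon from a standard normal and computes z = mean + std * epsilon, so gradients can flow back through the mean and log variance to the encoder.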

In [3]:
class VAE(nn.Module):
    """Encoder and decoder joined by a reparameterized sampling step."""

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x):
        mean, log_var = self.encoder(x)
        # Reparameterization trick: z = mean + std * epsilon, epsilon ~ N(0, I)
        std = torch.exp(0.5 * log_var)
        epsilon = torch.randn_like(std)
        z = mean + std * epsilon
        x_reconstructed = self.decoder(z)
        return x_reconstructed, mean, log_var
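
As a sanity check on the sampling step, the following plain-Python sketch draws many samples for a single latent dimension and compares their statistics with the requested mean and standard deviation. The function `reparameterize`, the seed, and the chosen parameter values are assumptions made for illustration only.

```python
import math
import random

def reparameterize(mean, log_var, rng):
    """Draw z from N(mean, exp(log_var)) as a deterministic function of noise."""
    std = math.exp(0.5 * log_var)
    epsilon = rng.gauss(0.0, 1.0)  # the only source of randomness
    return mean + std * epsilon

rng = random.Random(0)  # fixed seed so the run is repeatable
mean, log_var = 1.5, math.log(0.25)  # variance 0.25, so std 0.5
samples = [reparameterize(mean, log_var, rng) for _ in range(20000)]

sample_mean = sum(samples) / len(samples)
sample_std = math.sqrt(sum((s - sample_mean) ** 2 for s in samples) / len(samples))
print(sample_mean, sample_std)  # close to 1.5 and 0.5
```

Because the randomness is confined to `epsilon`, the sample is a differentiable function of the mean and log variance, which is what lets backpropagation reach the encoder.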

## Loss Function

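The loss function for a VAE consists of two terms: the reconstruction loss and the KL divergence. The reconstruction loss measures how well the decoder reconstructs the input from a latent sample; here it is binary cross-entropy summed over all pixels. The KL divergence measures how far the encoder's distribution for each input is from the prior, here a standard normal, and has a closed form for diagonal Gaussians. Minimizing the sum of the two terms is equivalent to maximizing the evidence lower bound (ELBO).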
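
Written out for a single input, the two terms take the following form. The notation is an assumption introduced here for illustration: $q(z \mid x)$ is the encoder's distribution, $p(x \mid z)$ the decoder's likelihood, $p(z)$ the prior, and $d$ the latent dimension.

$$
\mathcal{L}(x) = -\mathbb{E}_{q(z \mid x)}\left[\log p(x \mid z)\right] + D_{\mathrm{KL}}\left(q(z \mid x) \,\|\, p(z)\right)
$$

With $q(z \mid x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and $p(z) = \mathcal{N}(0, I)$, the KL term has the closed form

$$
D_{\mathrm{KL}}\left(q(z \mid x) \,\|\, p(z)\right) = -\frac{1}{2} \sum_{j=1}^{d} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
$$

In the code, $\mu_j$ corresponds to `mean`, $\log \sigma_j^2$ to `log_var`, and $\sigma_j^2$ to `log_var.exp()`. The expectation in the first term is approximated with the single sample of z drawn in the forward pass.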

In [4]:
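def vae_loss(x, x_reconstructed, mean, log_var):
    """Single-sample estimate of the negative ELBO, summed over the batch."""
    # Binary cross-entropy summed over every pixel of every sample in the batch
    reconstruction_loss = nn.functional.binary_cross_entropy(x_reconstructed, x, reduction='sum')
    # Closed-form KL divergence between N(mean, exp(log_var)) and the prior N(0, I)
    kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())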
    return reconstruction_loss + kl_divergence
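
One way to build confidence in the closed-form KL expression is to compare it with a sampling-based estimate. The sketch below does this in plain Python for a single latent dimension; the function names, seed, and parameter values are hypothetical choices for illustration and are not part of the notebook's code.

```python
import math
import random

def kl_closed_form(mean, log_var):
    """KL( N(mean, exp(log_var)) || N(0, 1) ) for one latent dimension."""
    return -0.5 * (1 + log_var - mean ** 2 - math.exp(log_var))

def kl_monte_carlo(mean, log_var, n_samples, rng):
    """Estimate the same KL as the average of log q(z) - log p(z) for z drawn from q."""
    var = math.exp(log_var)
    std = math.sqrt(var)
    total = 0.0
    for _ in range(n_samples):
        z = mean + std * rng.gauss(0.0, 1.0)
        log_q = -0.5 * (math.log(2 * math.pi) + log_var + (z - mean) ** 2 / var)
        log_p = -0.5 * (math.log(2 * math.pi) + z ** 2)
        total += log_q - log_p
    return total / n_samples

rng = random.Random(0)  # fixed seed so the run is repeatable
exact = kl_closed_form(1.0, math.log(0.5))
estimate = kl_monte_carlo(1.0, math.log(0.5), 50000, rng)
print(exact, estimate)  # the two values should agree to about two decimals
```

At mean 0 and log variance 0 the closed form evaluates to zero, since the encoder's distribution then equals the prior. Any other setting gives a positive penalty.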

## Training the VAE

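Let's train the VAE on MNIST, a dataset of 28x28 grayscale images of handwritten digits. Each image is flattened to a 784-dimensional vector, the latent space is 2-dimensional, and the model is trained for 10 epochs with Adam. The loss printed after each epoch is the average per image.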

In [5]:
from torchvision import datasets, transforms
import torch.optim as optim

# Load MNIST. ToTensor scales pixels to [0, 1]; the Lambda flattens each 28x28 image to 784 values
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
# The test split is loaded here but is not used by the training loop below
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

# Define the VAE model
input_dim = 28 * 28
hidden_dim = 256
latent_dim = 2
encoder = Encoder(input_dim, hidden_dim, latent_dim)
decoder = Decoder(latent_dim, hidden_dim, input_dim)
vae = VAE(encoder, decoder)

# Define the optimizer
optimizer = optim.Adam(vae.parameters(), lr=1e-3)

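# Train the model. The loss is summed within each batch, so dividing the
# epoch total by the dataset size gives the average loss per image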
num_epochs = 10
for epoch in range(num_epochs):
    vae.train()
    train_loss = 0
    for x, _ in train_loader:
        optimizer.zero_grad()
        x_reconstructed, mean, log_var = vae(x)
        loss = vae_loss(x, x_reconstructed, mean, log_var)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
    print(f'Epoch {epoch + 1}, Loss: {train_loss / len(train_loader.dataset)}')

Epoch 1, Loss: 182.64177789713543
Epoch 2, Loss: 164.48488276367186
Epoch 3, Loss: 161.2271118815104
Epoch 4, Loss: 159.09910853678386
Epoch 5, Loss: 157.53306800944011
Epoch 6, Loss: 156.28730290527344
Epoch 7, Loss: 155.26126284179688
Epoch 8, Loss: 154.44241954752604
Epoch 9, Loss: 153.73177485351562
Epoch 10, Loss: 153.18850033365885
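
The reported numbers are per-image averages because the loss uses a summed reduction and the epoch total is divided by the dataset size. This matters when the last batch is smaller than the rest: with 60,000 training images and a batch size of 64, the final batch holds only 32 images. The sketch below uses made-up per-sample losses in plain Python to show how this pattern differs from averaging the per-batch means.

```python
# Made-up per-sample losses for 10 samples, split into batches of 4, 4, and 2.
per_sample = [3.0, 1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0, 10.0, 10.0]
batch_size = 4
batches = [per_sample[i:i + batch_size] for i in range(0, len(per_sample), batch_size)]

# Pattern used in the training loop: sum within each batch, divide once by N.
summed = sum(sum(b) for b in batches)
per_sample_avg = summed / len(per_sample)

# Alternative: average each batch, then average the batch means.
batch_means = [sum(b) / len(b) for b in batches]
mean_of_means = sum(batch_means) / len(batch_means)

print(per_sample_avg)  # 4.4, the true per-sample average
print(mean_of_means)   # about 5.33, because the short last batch is over-weighted
```

Dividing the summed loss by the number of samples gives every image equal weight regardless of how the data is batched.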
