# Stable Diffusion VAE - Variational Auto-Encoder

A Variational Auto-Encoder (VAE) is a type of Auto-Encoder, a neural network trained with an unsupervised technique. VAEs are widely used in image generation, mainly in latent-diffusion and GAN-based models.


## AutoEncoder


### Overview

To understand the VAE, we must first understand the Auto-Encoder.

An Auto-Encoder is an unsupervised model primarily used for dimensionality reduction.

We cannot feed 4K images directly to the models during training, as they would be too large and impractical to work with. To solve this, we need a model that reduces each image to a small, fixed-size vector that retains rich features from the original.

Auto-Encoders are designed to compress images into a latent space and reconstruct the original image from it, so they lack the ability to generate new and varied outputs.

This is where the VAE comes in. Rather than memorizing patterns, it can generate variations of images. The VAE does not produce the latent directly: it first predicts a probability distribution, and only then samples the latent from it.

By learning and predicting the probability distribution of the latent space, the VAE encodes each data point to a range defined by mean and variance, rather than a fixed vector. It generates "pools" of the data points in the latent space, rather than a single vector in that space. This allows variation within the pools, rather than fixing it to a single point.

### Loss function

To train the VAE, we use a Reconstruction loss (MSE) combined with a scaled KL Divergence.

#### Reconstruction Loss

The Reconstruction loss is the MSE (Mean Squared Error) between the original and reconstructed images, defined below:

$$\operatorname{Reconstruction\ Loss} = \frac{1}{N} \ \sum^N_{i=1}(x_i - \hat{x_i})^2$$
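As a quick check of the formula, here is a minimal, dependency-free sketch of the MSE on a handful of made-up pixel values. The helper name `mse` and the numbers are illustrative assumptions; the notebook itself relies on `nn.MSELoss` during training.

```python
def mse(original, reconstructed):
    """Mean squared error between two equally sized flat sequences."""
    if len(original) != len(reconstructed):
        raise ValueError("inputs must have the same length")
    n = len(original)
    return sum((x - x_hat) ** 2 for x, x_hat in zip(original, reconstructed)) / n


# Hypothetical pixel intensities and their reconstruction.
pixels = [0.0, 0.5, 1.0, 0.25]
recon = [0.1, 0.5, 0.8, 0.25]

loss = mse(pixels, recon)
print(loss)  # approximately 0.0125: (0.01 + 0 + 0.04 + 0) / 4
```

A perfect reconstruction gives a loss of exactly 0, and every squared error contributes equally regardless of where in the image it occurs.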

#### KL Divergence

The KL Divergence (Kullback-Leibler divergence) is a measure of how different the learned distribution is from a standard normal distribution, defined as:

$$\operatorname{KL}(q(z|x) || p(z)) = -\frac{1}{2} \sum^J_{j=1} (1 + \log \sigma^2_j - \mu^2_j - \sigma^2_j)$$

Where:

- $\operatorname{KL}(q(z|x) || p(z))$ is the KL divergence from p of z to q of z given x. In other words, how different is the encoder's learned distribution from a standard normal distribution.
- $q(z|x)$ (q of z given x) is the approximate/learned distribution.
- $p(z)$ is the target/reference distribution.
- $\mu$ is the mean.
- $\sigma^2$ is the variance.
- $J$ is the number of latent dimensions.

This loss prevents the encoder from mapping each input to a very different, arbitrary location in latent space, and encourages the learned distribution to stay close to $\mathcal{N}(0, 1)$. The $\mu^2_j$ term penalizes means far from 0, while the remaining terms, $\sigma^2_j - \log \sigma^2_j - 1$ once the sign is distributed, are zero at $\sigma^2_j = 1$ and grow as the variance moves away from 1 in either direction.
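To see how each term behaves, the sketch below evaluates the closed-form KL for a few hand-picked means and log-variances. It is plain Python rather than the notebook's PyTorch code, and `kl_to_standard_normal` is a hypothetical helper name, not part of the model.

```python
import math


def kl_to_standard_normal(mean, log_var):
    """KL(q || N(0, I)) for a diagonal Gaussian q.

    Same expression as the formula above with the leading minus sign
    distributed: 0.5 * sum(mu^2 + sigma^2 - log(sigma^2) - 1).
    """
    return 0.5 * sum(
        mu ** 2 + math.exp(lv) - lv - 1.0 for mu, lv in zip(mean, log_var)
    )


# q already equals N(0, 1) in every dimension: no penalty.
kl_match = kl_to_standard_normal([0.0, 0.0], [0.0, 0.0])

# Moving one mean away from 0 is penalised by the mu^2 term.
kl_shifted = kl_to_standard_normal([2.0, 0.0], [0.0, 0.0])

# Shrinking one variance to 0.1 is penalised by sigma^2 - log(sigma^2) - 1.
kl_narrow = kl_to_standard_normal([0.0, 0.0], [math.log(0.1), 0.0])

print(kl_match, kl_shifted, kl_narrow)
```

The three cases give 0, 2, and roughly 0.70, so both a shifted mean and a collapsed variance are pushed back toward the standard normal.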

#### Final Loss

The final loss combines the Reconstruction loss with a weighted KL Divergence loss. It can be read as the sum of two questions: how well can we rebuild the images, and how well structured is the latent space? If we used only the Reconstruction loss, the encoder would be free to map each image to arbitrary, distant points in latent space; the latent space would lack structure, interpolate poorly, and encourage overfitting. If we used only the KL Divergence, every latent code would collapse to $\mathcal{N}(0, 1)$ (the standard Normal/Gaussian distribution) and carry no information about the images: a perfect distribution, but useless for reconstruction.

The combination of both allows the encoder to follow a regular distribution pattern, and compress enough information to reconstruct the images.

$$\operatorname{loss} = \operatorname{Reconstruction\ Loss} + \beta \ \operatorname{KL}(q(z|x) || p(z))$$

Where $\beta$ controls how strongly the KL term is weighted relative to the reconstruction term.


## Importing packages


In [None]:
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import functional as F

import torchvision
from torchvision import transforms

import os
import math
import shutil
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

## Functions


In [None]:
def split_dataset(
    source_dir: str,
    train_dir: str,
    test_dir: str,
    test_size: float = 0.2,
    random_state: int = 42,
) -> None:
    image_files = [
        f
        for f in os.listdir(source_dir)
        if f.lower().endswith((".png", ".jpg", ".jpeg"))
    ]

    train_files, test_files = train_test_split(
        image_files, test_size=test_size, random_state=random_state
    )

    os.makedirs(train_dir, exist_ok=True)
    os.makedirs(test_dir, exist_ok=True)

    for file in train_files:
        shutil.copy(os.path.join(source_dir, file), os.path.join(train_dir, file))

    for file in test_files:
        shutil.copy(os.path.join(source_dir, file), os.path.join(test_dir, file))

    print(
        f"Dataset split complete. {len(train_files)} training images, {len(test_files)} test images."
    )

## Classes


### Self-Attention Block


This block implements the self-attention algorithm from [Attention is all you need](https://arxiv.org/abs/1706.03762).

$$
\text{Attention}(Q, K, V)= \operatorname{softmax}(\frac{Q\cdot K^T}{\sqrt{d_k}})\cdot V
$$

For this implementation there are a few quirks, so they'll be described below:

- `in_proj` and `out_proj`: These layers are the first and final transformations applied to the data. The input projection layer projects the input embeddings into the query, key and value spaces for the attention computation, whereas the output projection layer projects the attention-weighted values back to the original embedding dimension.

- _Interim shape_: The interim shape is a reorganization that explicitly separates the embedding dimension into multiple attention heads. This is essential for multi-head attention, where each head learns a different representation subspace independently. The interim shape in the code is `[batch_size, seq_len, self.n_heads, self.d_heads]`.

- `.chunk` splits a tensor into equal-sized pieces along a specified dimension. In the context of the code, it takes the output of the input projection and splits it into three equal tensors along the last dimension. `[batch_size, seq_len, 3 * embed_dim] -> 3 * [batch_size, seq_len, embed_dim]`

- `q`, `k` and `v` are transposed to make parallel computation possible across all attention heads. This conforms with PyTorch's broadcasting rules for matmul, which are:

  1. The multiplication only happens in the last two dimensions.
  2. All leading dimensions are treated as batch dimensions.
  3. The operation broadcasts across these batch dimensions.

  - The actual computation process is as follows:
  - ```text
    q @ k.transpose(-1, -2)

    For each combination of (batch_idx, head_idx):
    Take q[batch_idx, head_idx, :, :] -> shape (seq_len, d_head)
    Take k[batch_idx, head_idx, :, :] -> shape (seq_len, d_head)
    Compute: q[batch_idx, head_idx] @ k[batch_idx, head_idx].T
    Result: (seq_len, seq_len) attention scores
    ```

#### Forward pass step-by-step

1. Input projection and chunking
   - Projects input into query, key, and value spaces.
2. Reshape to interim shape
   - Splits the embedding dimension into `n_heads` separate subspaces, enabling multi-head attention where each head learns different patterns.
3. Transpose for parallel computation
   - Moves heads to the second dimension so all heads can be computed in parallel.
4. Compute attention scores
   - Computes similarity scores between every pair of positions in the sequence.
5. Apply causal mask
   - For autoregressive models, prevents positions from attending to future positions by setting their scores to negative infinity, which becomes a weight of 0 after softmax.
6. Scale and normalize
   - Divides the dot products by $\sqrt{d_{head}}$ to prevent them from growing too large, then applies softmax so each row of weights sums to 1.
7. Compute weighted values
   - Aggregates values from all positions, weighted by attention probabilities.
8. Transpose back
   - Reverses the earlier transpose to prepare for concatenating all heads back together.
9. Reshape to original format
   - Concatenates all attention heads back into a single embedding dimension, merging the learned representations from all heads.
10. Output Projection
    - Final linear transformation that allows the model to learn how to combine information from different heads optimally.
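The core of the steps above (scores, causal mask, scaling, softmax, weighted values) can be traced on tiny inputs with a dependency-free sketch. It assumes a single head, skips the learned `in_proj`/`out_proj` layers, and uses plain lists instead of tensors, so it illustrates the arithmetic rather than reproducing the PyTorch class.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention(q, k, v, causal=False):
    """Single-head scaled dot-product attention on lists of row vectors."""
    d_k = len(q[0])
    outputs, weights = [], []
    for i, q_i in enumerate(q):
        scores = []
        for j, k_j in enumerate(k):
            if causal and j > i:
                scores.append(-math.inf)  # future position: masked out
            else:
                dot = sum(a * b for a, b in zip(q_i, k_j))
                scores.append(dot / math.sqrt(d_k))
        probs = softmax(scores)
        weights.append(probs)
        outputs.append(
            [sum(p * v_j[c] for p, v_j in zip(probs, v)) for c in range(len(v[0]))]
        )
    return outputs, weights


# Hypothetical 3-token sequence with 2-dimensional queries, keys and values.
q = k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]

out, w = attention(q, k, v, causal=True)
print(w[0])  # [1.0, 0.0, 0.0]: the first token can only attend to itself
```

Note that the mask is strict: position `i` may still attend to itself, which keeps every row of scores from being entirely negative infinity.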


In [None]:
class SelfAttention(nn.Module):
    def __init__(
        self,
        n_heads,
        embed_dim,
        in_proj_bias=True,
        out_proj_bias=True,
    ) -> None:
        super().__init__()

        self.n_heads = n_heads

        self.in_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=in_proj_bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=out_proj_bias)

        self.d_heads = embed_dim // n_heads

    def forward(self, x: torch.Tensor, causal_mask=False) -> torch.Tensor:
        batch_size, seq_len, channels = x.shape

        interim_shape = (batch_size, seq_len, self.n_heads, self.d_heads)

        # Input projection and chunking
        q, k, v = self.in_proj(x).chunk(3, dim=-1)

        # Reshaping tensors to interim shape
        q = q.view(interim_shape)
        k = k.view(interim_shape)
        v = v.view(interim_shape)

        # Transposing
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Computing first part of attention
        weight = q @ k.transpose(-1, -2)  # Matmul

        # Applying mask
        if causal_mask:
            # Mask where the upper triangle (strictly above the main diagonal) is True
            mask = torch.ones_like(weight, dtype=torch.bool).triu(1)

            # Fill masked positions with -inf in place, so softmax gives them weight 0
            weight.masked_fill_(mask, -torch.inf)

        # Dividing by square root of d_heads as described in paper
        weight /= math.sqrt(self.d_heads)

        weight = F.softmax(weight, dim=-1)

        # Final computing of attention
        output = weight @ v

        # Returning to original shape
        output = output.transpose(1, 2)

        # Changing the shape to the shape of out_proj
        output = output.reshape((batch_size, seq_len, channels))

        output = self.out_proj(output)

        return output

### Attention Block


This block still draws mostly on the original Transformer paper, but it is tailored to vision-based attention. The first sign of this is that the block uses `GroupNorm` instead of `LayerNorm`.

Both PyTorch layers implement the same operation; the key difference is that `GroupNorm` normalizes over groups of channels (32 groups in the code below), whereas `LayerNorm` normalizes over the entire channel dimension. Both layers implement the following formula:

$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$$

Where:

- $x$ is the input.
- $\mu$ is the input's mean.
- $\sigma^2$ is the input's variance.
- $\gamma$ is a learnable scale parameter.
- $\beta$ is a learnable shift parameter.
- $\epsilon$ is a small constant for numerical stability.

The difference is that `GroupNorm` computes $\mu$ and $\sigma^2$ separately within each group.
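The effect of per-group statistics is easiest to see on a small example. The sketch below is a simplification under stated assumptions: it normalizes the channel values of a single spatial position, fixes $\gamma = 1$ and $\beta = 0$, and uses the biased variance. The helper names `normalize` and `group_norm` are illustrative, not PyTorch APIs.

```python
import math


def normalize(values, eps=1e-5):
    """Zero-mean, unit-variance normalisation with gamma = 1 and beta = 0."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
    return [(v - mean) / math.sqrt(var + eps) for v in values]


def group_norm(channels, num_groups, eps=1e-5):
    """Normalise each contiguous group of channels with its own mean and variance."""
    size = len(channels) // num_groups
    out = []
    for g in range(num_groups):
        out.extend(normalize(channels[g * size:(g + 1) * size], eps))
    return out


# Hypothetical activations: the last four channels are 10x larger than the first four.
x = [1.0, 2.0, 3.0, 4.0, 10.0, 20.0, 30.0, 40.0]

per_group = group_norm(x, num_groups=2)  # statistics computed per group
single = group_norm(x, num_groups=1)     # one set of statistics for all channels

print([round(v, 3) for v in per_group])
print([round(v, 3) for v in single])
```

With two groups, both halves end up with nearly identical normalized values despite the tenfold scale difference; with a single group, the large channels dominate the shared mean and variance.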

#### Forward pass step-by-step

1. Save the residual connection
   - Saves the original data to add as residual connection post self-attention layer.
2. Group Normalization
   - Group normalization for activation stability, so that attention scores don't explode or vanish.
3. Reshape to sequence format
   - Flattens the height and width dimensions so each image becomes a sequence of pixels.
4. Transpose to Sequence format
   - Makes channels the last dimension to conform with the self-attention layer's expected format.
5. Apply Self-Attention
   - Allows each spatial location to "see" the entire image, capturing global structure.
6. Transpose back to spatial format
   - Reverse the transpose.
7. Reshape back to 2D spatial
   - Reverse the image flattening.
8. Sum with residual connection


In [None]:
class AttentionBlock(nn.Module):
    def __init__(self, channels) -> None:
        super().__init__()
        self.group_norm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x.clone()

        x = self.group_norm(x)

        n_samples, channels, height, width = x.shape

        x = x.view((n_samples, channels, height * width))

        x = x.transpose(-1, -2)

        x = self.attention(x)

        x = x.transpose(-1, -2)

        x = x.view((n_samples, channels, height, width))

        x += residual

        return x

### Residual Block


Residual Blocks, on the other hand, don't come from the Attention Is All You Need paper; they originate from the ResNet architecture, first proposed in [Deep Residual Learning for Image Recognition](https://arxiv.org/pdf/1512.03385). The block is composed of two sequential weight layers whose output is added to the block's input.

Although the idea is simple enough, the implementation has some steps worth going over:

- `in_channels == out_channels`: This `if` block handles the case where the number of channels changes. We cannot add the original input to the output of the second weight layer if the channel counts differ, so in that case we pass the input through a 1x1 convolution (a `Conv2d` with `kernel_size=1`) that only changes the number of channels and preserves the spatial dimensions. When the channel counts already match, the input goes through a `torch.nn.Identity` layer, which performs no operation ($\operatorname{Identity}(x) = x$).
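To make the shortcut concrete, here is a minimal sketch of what a 1x1 convolution computes: a linear mix of the input channels at every pixel, with the height and width untouched. It uses nested lists instead of tensors, and `conv1x1` and all values are hypothetical illustrations rather than the layer PyTorch actually runs.

```python
def conv1x1(image, weight, bias):
    """Apply a 1x1 convolution.

    image:  [C_in][H][W] nested lists
    weight: [C_out][C_in], one scalar per input/output channel pair
    bias:   [C_out]
    """
    c_in, height, width = len(image), len(image[0]), len(image[0][0])
    return [
        [
            [
                bias[o] + sum(weight[o][c] * image[c][y][x] for c in range(c_in))
                for x in range(width)
            ]
            for y in range(height)
        ]
        for o in range(len(weight))
    ]


# Hypothetical 2-channel, 2x2 input projected to 3 channels.
image = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, 20.0], [30.0, 40.0]],
]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0, 0.5]

shortcut = conv1x1(image, weight, bias)
shape = (len(shortcut), len(shortcut[0]), len(shortcut[0][0]))
print(shape)  # (3, 2, 2)
```

Because only the channel count changes, the result can be added element-wise to a main-path output of the same shape.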

#### Forward pass step-by-step

1. Save the residual path
2. Group Normalization 1
3. Activation function
4. First convolution
5. Group Normalization 2
6. Second Convolution
7. Residual Connection


In [None]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels) -> None:
        super().__init__()
        self.group_norm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.group_norm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        if in_channels == out_channels:
            self.residual_layer = nn.Identity()
        else:
            self.residual_layer = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, padding=0
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residue = x.clone()

        x = self.group_norm_1(x)
        x = F.selu(x)
        x = self.conv_1(x)
        x = self.group_norm_2(x)
        x = self.conv_2(x)

        return x + self.residual_layer(residue)

### Encoder


This is one of the two main components of the VAE, responsible for compressing high-dimensional input data into a low-dimensional latent representation. Rather than mapping the data to a single point, it maps them to a probability distribution in latent space.

The idea of this block is that the sequential stack of layers outputs a tensor with 8 channels, which is then split in two: the first 4 channels represent the mean and the last 4 the log variance of the latent distribution. Once the channels have been split, we draw a latent through the sampling equation:

$$z = \mu + \epsilon \cdot \sigma$$

Where $\mu$ is the mean, $\epsilon$ is random noise drawn from $\mathcal{N}(0, 1)$, and $\sigma$ is the standard deviation, recovered from the log variance as $\sigma = \exp(\frac{1}{2} \log \sigma^2)$. At first neither $\mu$ nor $\log \sigma^2$ represents the true mean and variance, but during training the model learns these values for a given input.

Like most of the code blocks so far, this one has its peculiarities:

- Inherits from `torch.nn.Sequential` instead of `torch.nn.Module`: this behaves like chaining layers with `nn.Sequential`; it automatically registers parameters and supports layer indexing and iteration (the main reason for using it here).
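The sampling equation can be checked empirically with a small, dependency-free sketch. It assumes the encoder outputs a log variance, so the standard deviation is recovered as $\exp(0.5 \cdot \log \sigma^2)$, and it draws $\epsilon$ from a standard normal. The function `sample_latent` and the chosen values are illustrative only.

```python
import math
import random


def sample_latent(mean, log_var, rng):
    """Draw z = mu + eps * sigma with eps ~ N(0, 1), per latent dimension."""
    z = []
    for mu, lv in zip(mean, log_var):
        sigma = math.exp(0.5 * lv)  # log-variance -> standard deviation
        eps = rng.gauss(0.0, 1.0)   # the noise does not depend on mu or sigma
        z.append(mu + eps * sigma)
    return z


rng = random.Random(0)
mean = [1.0, -2.0]
log_var = [math.log(0.25), math.log(4.0)]  # i.e. sigma = 0.5 and sigma = 2.0

samples = [sample_latent(mean, log_var, rng) for _ in range(20000)]
n = len(samples)
est_mean = [sum(s[d] for s in samples) / n for d in range(2)]
est_std = [
    math.sqrt(sum((s[d] - est_mean[d]) ** 2 for s in samples) / n) for d in range(2)
]
print(est_mean, est_std)
```

The empirical mean and standard deviation land close to the requested values, which is what makes the trick useful: the randomness lives in $\epsilon$ while $\mu$ and $\sigma$ stay ordinary, differentiable quantities.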

#### Forward pass step-by-step

1. Loop through sequential layers
   - The `for` loop iterates through all layers in the class. Inside the loop there is special handling for stride-2 convolutions: we add one pixel of padding on the right and bottom of the input, so that a stride-2 convolution with `padding=0` halves an even spatial size exactly.
2. Layer by layer transformation
   - Just applies the layers to the data.
3. Splitting channels
   - The outputs of the convolutional layer are chunked into two tensors, each with 4 channels. The first one is the mean, and the second one the log variance.
4. Clamp Log-Variance
   - Using `torch.clamp`, we clamp the values to the range $[-30, 20]$. This is done to prevent $\log(\sigma^2)$ from becoming too extreme, ensuring numerical stability.
5. Sample from distribution
   - Reparameterization trick: sampling $z = \mu + \epsilon \cdot \sigma$ keeps the randomness in $\epsilon$, allowing backpropagation through the sampling step.
6. Scale the latent space
   - Multiplies the sampled latent by the constant `0.18215`, which the Decoder later divides out.
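The effect of the asymmetric padding in step 1 can be verified with the standard convolution output-size formula. The sketch below assumes 56x56 inputs and 3x3 kernels with stride 2; `conv_out` is a hypothetical helper introduced only for this arithmetic.

```python
def conv_out(size, kernel, stride, pad_before, pad_after):
    """Output length of a convolution along one spatial axis."""
    return (size + pad_before + pad_after - kernel) // stride + 1


size = 56
sizes = [size]
for _ in range(3):  # three stride-2 convolutions in a row
    # One extra pixel on the right/bottom only, i.e. F.pad(x, (0, 1, 0, 1))
    size = conv_out(size, kernel=3, stride=2, pad_before=0, pad_after=1)
    sizes.append(size)

print(sizes)  # [56, 28, 14, 7]

# Without the extra pixel the first convolution would not halve the input.
unpadded = conv_out(56, kernel=3, stride=2, pad_before=0, pad_after=0)
print(unpadded)  # 27
```

Under these assumptions each downsampling stage halves the resolution exactly, so a 56x56 image ends up as a 7x7 latent grid.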


In [None]:
class Encoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(3, 128, kernel_size=3, padding=1),
            ResidualBlock(128, 128),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=0),
            ResidualBlock(128, 256),
            ResidualBlock(256, 256),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=0),
            ResidualBlock(256, 512),
            ResidualBlock(512, 512),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=0),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            nn.GroupNorm(32, 512),
            nn.SiLU(),
            nn.Conv2d(512, 8, kernel_size=3, padding=1),
            nn.Conv2d(8, 8, kernel_size=1, padding=0),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = input

        for module in self:
            if isinstance(module, nn.Conv2d) and module.stride == (2, 2):
                x = F.pad(x, (0, 1, 0, 1))
            x = module(x)

        mean, log_variance = torch.chunk(x, 2, dim=1)

        log_variance = torch.clamp(log_variance, -30, 20)

        # Reparameterization: sigma = exp(0.5 * log(sigma^2)), eps ~ N(0, 1)
        std = torch.exp(0.5 * log_variance)
        eps = torch.randn_like(std)
        x = mean + eps * std

        x *= 0.18215

        return x

### Decoder


The Decoder is the mirror of the Encoder. While the Encoder compresses, the decoder reconstructs by reversing the Encoder's operations.

#### Forward Pass step-by-step

1. Receive latent code
   - The compressed representation to expand back to an image.
2. Undo scaling
   - Reverses the unit variance scaling from the Encoder.
3. Progressive upsampling
4. Spatial Upsampling
5. Channel reduction
6. Output


In [None]:
class Decoder(nn.Sequential):
    def __init__(self):
        super().__init__(
            nn.Conv2d(4, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            ResidualBlock(512, 256),
            ResidualBlock(256, 256),
            ResidualBlock(256, 256),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            ResidualBlock(256, 128),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            nn.GroupNorm(32, 128),
            nn.SiLU(),
            nn.Conv2d(128, 3, kernel_size=3, padding=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # Undo the Encoder's scaling without modifying the caller's tensor in place
        x = input / 0.18215

        for module in self:
            x = module(x)

        return x

### VAE


Finally, the VAE architecture chains the encoder and the decoder sequentially.


In [None]:
class VAE(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)

        return decoded, encoded

## Downloading and preprocessing data


In [None]:
!gdown 1KXRTB_q4uub_XOHecpsQjE4Kmv76sZbV

!unzip -q all-dogs.zip

!mkdir models

In [None]:
source_dir = "./all-dogs"
train_dir = "./data/train/dogs"
test_dir = "./data/test/dogs"

split_dataset(source_dir, train_dir, test_dir)

## Training model


In [None]:
# Overall configurations
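# Prefer Apple MPS, then CUDA, and fall back to CPU when neither is available
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")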

# Hyperparameters
num_epochs = 100
learning_rate = 1e-4
beta = 0.00025
batch_size = 4
accumulation_steps = 1
effective_batch_size = batch_size * accumulation_steps

In [None]:
# Data loading
transform = transforms.Compose(
    [
        transforms.Resize((56, 56)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)
dataset = torchvision.datasets.ImageFolder(root="./data/train", transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

Dividing the loss by `accumulation_steps` simulates bigger batches: gradients are accumulated over $n$ batches, and the optimizer only steps after every $n$-th batch.
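Why the division matters can be shown with a one-parameter model whose gradient is easy to write by hand. This is a plain-Python sketch under simplifying assumptions (equal-sized micro-batches, a mean-reduced squared error, and a hypothetical helper `grad_mse`), not the training loop itself.

```python
def grad_mse(w, xs, ys):
    """d/dw of mean((w * x - y)^2) over one batch."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


w = 0.5
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x for x in xs]

# Gradient when all 8 samples are processed as one batch.
full_batch_grad = grad_mse(w, xs, ys)

# Same samples as 4 micro-batches of 2, accumulating scaled gradients.
accumulation_steps, micro = 4, 2
accumulated = 0.0
for step in range(accumulation_steps):
    chunk = slice(step * micro, (step + 1) * micro)
    # Scaling the loss by 1 / accumulation_steps scales its gradient the same way.
    accumulated += grad_mse(w, xs[chunk], ys[chunk]) / accumulation_steps

print(full_batch_grad, accumulated)
```

Both values agree, so under these assumptions the optimizer sees the same gradient as it would with the larger batch. Without the division, the accumulated gradient would be `accumulation_steps` times too large.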


In [None]:
# Training loop
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

train_losses = []

for epoch in range(num_epochs):
    model.train()
    train_loss = 0

    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)

        reconstructed, encoded = model(images)
        recon_loss = nn.MSELoss()(reconstructed, images)

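        # `encoded` is the Encoder's sampled, scaled latent (4 channels), so this
        # chunk yields two 2-channel halves that are used as mean and log-variance
        mean, log_variance = torch.chunk(encoded, 2, dim=1)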

        kl_div = -0.5 * torch.sum(1 + log_variance - mean.pow(2) - log_variance.exp())
        loss = recon_loss + beta * kl_div

        loss = loss / accumulation_steps

        loss.backward()

        if (i + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        train_loss += loss.item() * accumulation_steps

        print(
            f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], '
            f'Loss: {loss.item()*accumulation_steps:.4f}, Recon Loss: {recon_loss.item():.4f}, KL Div: {kl_div.item():.4f}'
        )

        with torch.no_grad():
            sample_image = images[0].unsqueeze(0)
            sample_reconstructed = model(sample_image)[0]

            # Undo Normalize((0.5, ...), (0.5, ...)) on the reconstruction before saving
            sample_reconstructed = (sample_reconstructed * 0.5) + 0.5

            torchvision.utils.save_image(sample_reconstructed, "reconstructed.png")

    train_losses.append(train_loss / len(dataloader))

    torch.save(model.state_dict(), f=f"models/vae_model_epoch_{epoch + 1}.pth")

print("Training finished!")

In [None]:
# Plotting loss curve
plt.figure(figsize=(10, 15))
plt.plot(train_losses, label="Training Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("VAE Loss over Time")
plt.legend()
plt.show()