<a href="https://colab.research.google.com/github/AlessandroFornasier/diffusion-models/blob/main/autoencoders/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from dataclasses import dataclass
from typing import List, Optional

In [6]:
class Autoencoder(nn.Module):
  """
  Autoencoder class.

  The encoder is a stack of Linear layers mapping dims[0] -> ... -> dims[-1] with SiLU
  activations on the hidden layers; the latent (last encoder) layer has no activation.
  The decoder is the exact mirror of the encoder, mapping dims[-1] -> ... -> dims[0].

  Args:
    dims (List[int]): Dimensionality of the layers, from the input (dims[0]) to the latent
      space (dims[-1]). Must contain at least two entries.
    binary (bool): Flag to indicate binary data in [0, 1]. If True, the decoder output is
      passed through a sigmoid, which normalizes it to (0, 1).

  Raises:
    ValueError: If dims has fewer than two entries.
  """
  def __init__(self, dims : List[int], binary : bool = True) -> None:
    super().__init__()

    if len(dims) < 2:
      raise ValueError(f"dims must contain at least two entries, got {list(dims)}")

    self.binary = binary

    # Encoder: dims[0] -> ... -> dims[-1]. SiLU on hidden layers, no activation on the latent layer.
    self.encoder = nn.Sequential(*self._mlp_layers(dims))

    # Decoder: mirror of the encoder, dims[-1] -> ... -> dims[0]. Building it from the reversed
    # dims with the same helper guarantees its first layer consumes the latent size dims[-1].
    layers = self._mlp_layers(list(dims)[::-1])
    if self.binary:
      layers.append(nn.Sigmoid())
    self.decoder = nn.Sequential(*layers)

  @staticmethod
  def _mlp_layers(sizes : List[int]) -> List[nn.Module]:
    """
    Builds Linear layers sizes[0] -> ... -> sizes[-1], with a SiLU after every layer but the last.

    Args:
      sizes (List[int]): Layer widths, at least two entries.

    Returns:
      layers (List[nn.Module]): Layers ready to be wrapped in nn.Sequential.
    """
    layers : List[nn.Module] = []
    for idim, odim in zip(sizes[:-2], sizes[1:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(sizes[-2], sizes[-1]))
    return layers

  def encode(self, x) -> torch.Tensor:
    """
    Encodes the input data into the latent space.

    Args:
      x (torch.Tensor): Input data, last dimension of size dims[0].

    Returns:
      z (torch.Tensor): Encoded data, latent space (last dimension of size dims[-1]).
    """
    return self.encoder(x)

  def decode(self, z) -> torch.Tensor:
    """
    Decodes the latent space data.

    Args:
      z (torch.Tensor): Latent space data, last dimension of size dims[-1].

    Returns:
      x_hat (torch.Tensor): Decoded data, output space (last dimension of size dims[0]).
    """
    return self.decoder(z)

  def forward(self, x) -> torch.Tensor:
    """
    Forward pass of the autoencoder: reconstructs x by encoding then decoding it.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      x_hat (torch.Tensor): Reconstructed data.
    """
    return self.decode(self.encode(x))

In [9]:
@dataclass
class VAEState:
  """
  Container for the intermediate results of one VAE forward pass.

  Every field starts as None and is filled in as the pass progresses.

  Attributes:
    x (Optional[torch.Tensor]): Data fed to the encoder.
    dist (Optional[torch.distributions.Distribution]): Distribution over the latent space produced by the encoder.
    z (Optional[torch.Tensor]): Sample drawn from `dist`.
    x_hat (Optional[torch.Tensor]): Decoder output, i.e. the reconstruction of `x`.
  """
  # Data fed to the encoder
  x : Optional[torch.Tensor] = None
  # Latent-space distribution returned by the encoder
  dist : Optional[torch.distributions.Distribution] = None
  # Latent sample drawn from dist
  z : Optional[torch.Tensor] = None
  # Reconstruction of x
  x_hat : Optional[torch.Tensor] = None


class VAE(Autoencoder):
  """
  Variational autoencoder class.

  For every latent dimension the encoder outputs a mean and an unconstrained scale parameter;
  the latter is mapped through softplus (+ eps) to a positive standard deviation. The decoder
  maps a latent sample of size dims[-1] back to dims[0].

  Args:
    dims (List[int]): Dimensionality of the layers, from the input (dims[0]) to the latent
      space (dims[-1]). The caller's list is not modified.
    binary (bool): Flag to indicate binary data in [0, 1] (sigmoid on the decoder output).

  Note:
    A VAE is trained by maximizing the ELBO, i.e. minimizing:
    - Reconstruction loss (e.g. binary cross-entropy for binary data, MSE otherwise)
    - KL divergence between the encoder distribution and the latent prior

  Reference:
    - https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/
    - https://github.com/pytorch/examples/blob/main/vae/main.py
  """
  def __init__(self, dims : List[int], binary : bool = True) -> None:
    dims = list(dims)  # Copy so the caller's list is never mutated
    super().__init__(dims, binary)

    # The encoder must output a mean AND a scale parameter per latent dimension, while the
    # decoder consumes a latent sample of size dims[-1]. Widen only the last encoder layer
    # (doubling dims[-1] before building the base would also widen the decoder input).
    self.encoder[-1] = nn.Linear(dims[-2], 2 * dims[-1])

    self.softplus = nn.Softplus()

  def encode(self, x, eps: float = 1e-6) -> torch.distributions.Distribution:
    """
    Encodes the input data into a distribution over the latent space.

    Args:
      x (torch.Tensor): Input data.
      eps (float): Small value added to the standard deviation to keep it strictly positive.

    Returns:
      dist (torch.distributions.MultivariateNormal): Diagonal Gaussian over the latent space.

    Note:
      The second half of the encoder output is an unconstrained parameter. It is mapped to the
      standard deviation sigma with softplus (softplus(x) = log(1 + exp(x))) plus eps, which keeps
      sigma positive without the overflow risk of directly exponentiating.
    """
    h = self.encoder(x)
    mu, raw_scale = torch.tensor_split(h, 2, dim=-1)
    scale = self.softplus(raw_scale) + eps  # standard deviation (sigma), not variance
    # scale_tril is the Cholesky factor of the covariance: diag(sigma) for a diagonal Gaussian.
    return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(scale))

  def reparametrize(self, dist) -> torch.Tensor:
    """
    Perform sampling via the reparametrization trick.

    Args:
      dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.

    Returns:
      z (torch.Tensor): Sample z = mu + sigma * epsilon with epsilon ~ N(0, I); differentiable
        with respect to the distribution parameters.
    """
    return dist.rsample()

  def decode(self, z) -> torch.Tensor:
    """
    Decodes the data from the latent space to the original input space.

    Args:
      z (torch.Tensor): Data in the latent space (last dimension of size dims[-1]).

    Returns:
      x_hat (torch.Tensor): Reconstructed data in the original input space.
    """
    return self.decoder(z)

  def forward(self, x) -> VAEState:
    """
    Performs a forward pass of the VAE: encode, sample, decode.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      state (VAEState): state of the VAE (input, latent distribution, sample, reconstruction).
    """
    state = VAEState(x)
    state.dist = self.encode(state.x)
    state.z = self.reparametrize(state.dist)
    state.x_hat = self.decode(state.z)
    return state