<a href="https://colab.research.google.com/github/AlessandroFornasier/diffusion-models/blob/main/autoencoders/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn

from dataclasses import dataclass
from typing import List, Optional

In [2]:
@dataclass
class AutoencoderState:
  """
  Container for the tensors produced by an autoencoder forward pass.

  Every field defaults to None, so a state can be created from the input alone
  and filled in as the forward pass progresses.

  Attributes:
    x (Optional[torch.Tensor]): Input batch.
    z (Optional[torch.Tensor]): Latent-space sample computed from `x`.
    x_hat (Optional[torch.Tensor]): Reconstruction of `x` obtained from `z`.
  """

  # Field order matters: the state is constructed positionally, e.g. AutoencoderState(x).
  x: Optional[torch.Tensor] = None
  z: Optional[torch.Tensor] = None
  x_hat: Optional[torch.Tensor] = None


class Autoencoder(nn.Module):
  """
  Fully-connected autoencoder.

  The encoder maps dims[0] -> dims[1] -> ... -> dims[-1] (the latent size) and the
  decoder mirrors it, mapping dims[-1] -> ... -> dims[1] -> dims[0].

  Args:
    dims (List[int]): Layer widths, from the input dimension dims[0] to the latent
      dimension dims[-1]. Must contain at least two entries.
    binary (bool): Flag to indicate binary data [0, 1]. When True the decoder output
      goes through a sigmoid, which maps it into (0, 1).

  Raises:
    ValueError: If `dims` has fewer than two entries.
  """

  def __init__(self, dims : List[int], binary : bool = True) -> None:
    super().__init__()

    if len(dims) < 2:
      raise ValueError(f"dims must contain at least an input and a latent dimension, got {dims}")

    self.binary = binary

    # Encoder: SiLU after every hidden Linear layer; the latent layer has no activation.
    self.encoder = nn.Sequential(*self._mlp(list(dims)))

    # Decoder: the same construction over the reversed widths, so the first layer takes
    # a latent vector of size dims[-1] and the last one produces dims[0] outputs.
    # Binary data additionally gets a sigmoid on the output layer.
    layers = self._mlp(list(dims)[::-1])
    if self.binary:
      layers.append(nn.Sigmoid())
    self.decoder = nn.Sequential(*layers)

  @staticmethod
  def _mlp(widths: List[int]) -> List[nn.Module]:
    """
    Builds Linear layers widths[0] -> widths[1] -> ... -> widths[-1], with a SiLU
    after every Linear layer except the last one.

    Args:
      widths (List[int]): Layer widths, at least two entries.

    Returns:
      layers (List[nn.Module]): Layers in application order.
    """
    layers: List[nn.Module] = []
    for idim, odim in zip(widths[:-2], widths[1:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(widths[-2], widths[-1]))
    return layers

  def encode(self, x) -> torch.Tensor:
    """
    Encodes the input data into the latent space.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      z (torch.Tensor): Encoded data, latent space.
    """
    return self.encoder(x)

  def decode(self, z) -> torch.Tensor:
    """
    Decodes the latent space data.

    Args:
      z (torch.Tensor): Latent space data.

    Returns:
      x_hat (torch.Tensor): Decoded data, output space.
    """
    return self.decoder(z)

  def forward(self, x) -> "AutoencoderState":
    """
    Forward pass of the autoencoder: encode `x`, then decode the latent code.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      state (AutoencoderState): Holds the input `x`, the latent `z` and the reconstruction `x_hat`.
    """
    state = AutoencoderState(x)
    state.z = self.encode(x)
    state.x_hat = self.decode(state.z)
    return state

In [3]:
@dataclass
class VAEState(AutoencoderState):
  """
  Forward-pass state of a variational autoencoder.

  Adds the encoder distribution to the fields inherited from AutoencoderState.

  Attributes:
    x (Optional[torch.Tensor]): Input batch (inherited).
    z (Optional[torch.Tensor]): Latent-space sample (inherited).
    x_hat (Optional[torch.Tensor]): Reconstruction (inherited).
    dist (Optional[torch.distributions.Distribution]): Gaussian produced by the
      encoder, from which `z` is drawn.
  """

  # Declared after the inherited fields, so VAEState(x) still binds the input positionally.
  dist: Optional[torch.distributions.Distribution] = None


class VAE(Autoencoder):
  """
  Variational autoencoder with a diagonal Gaussian posterior q(z | x).

  Args:
    dims (List[int]): Layer widths, from the input dimension dims[0] to the latent
      dimension dims[-1]. The list is not modified.
    binary (bool): Flag to indicate binary data [0, 1].

  Note:
    A VAE is trained by maximizing the ELBO, i.e. by minimizing the sum of
    - a reconstruction loss (e.g. binary cross entropy for binary data, MSE otherwise)
    - the KL divergence between q(z | x) and the prior p(z)
    This class only produces the quantities those terms need (x, z, x_hat, dist).

  References:
    - https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/
    - https://github.com/pytorch/examples/blob/main/vae/main.py
  """

  def __init__(self, dims : List[int], binary : bool = True) -> None:
    super().__init__(dims, binary)

    # q(z | x) needs a mean and a scale per latent dimension, so the last encoder layer
    # is replaced by one with 2 * dims[-1] outputs. Doing it here instead of
    # `dims[-1] *= 2` leaves the caller's list untouched and lets Autoencoder build the
    # decoder from the unmodified widths: a latent sample z has dims[-1] entries.
    self.encoder[-1] = nn.Linear(dims[-2], 2 * dims[-1])

    # Maps the raw scale output to a strictly positive value (module form of softplus).
    self.softplus = nn.Softplus()

  def encode(self, x, eps: float = 1e-6) -> torch.distributions.Distribution:
    """
    Encodes the input data into a Gaussian over the latent space.

    Args:
      x (torch.Tensor): Input data.
      eps (float): Lower bound added to the standard deviation for numerical stability.

    Returns:
      dist (torch.distributions.MultivariateNormal): Diagonal Gaussian q(z | x).

    Note:
      The encoder output is split in half along the last dimension: the first half is
      the mean, the second half is mapped through softplus(v) = log(1 + exp(v)) plus
      eps. That value is the standard deviation (not the variance). It is placed on the
      diagonal of the Cholesky factor `scale_tril`. Softplus keeps it positive without
      the overflow risk of exponentiating directly.
    """
    h = self.encoder(x)
    mu, raw_scale = torch.tensor_split(h, 2, dim=-1)
    std = self.softplus(raw_scale) + eps
    # scale_tril avoids a Cholesky decomposition of a covariance matrix.
    return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(std))

  def reparametrize(self, dist) -> torch.Tensor:
    """
    Perform sampling via the reparametrization trick.

    Args:
      dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.

    Returns:
      z (torch.Tensor): Differentiable sample z = mu + L @ epsilon, with epsilon ~ N(0, I)
        and L the scale_tril factor (here diagonal, i.e. sigma * epsilon).
    """
    return dist.rsample()

  def decode(self, z) -> torch.Tensor:
    """
    Decodes the data from the latent space to the original input space.

    Args:
      z (torch.Tensor): Data in the latent space.

    Returns:
      x_hat (torch.Tensor): Reconstructed data in the original input space.
    """
    return self.decoder(z)

  def forward(self, x) -> VAEState:
    """
    Performs a forward pass of the VAE.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      state (VAEState): Input, posterior distribution, latent sample and reconstruction.
    """
    state = VAEState(x)
    state.dist = self.encode(state.x)
    state.z = self.reparametrize(state.dist)
    state.x_hat = self.decode(state.z)
    return state

In [5]:
import torch.optim as Optimizer
import torch.nn.functional as Func

from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Callable, Dict

class AutoencoderTrainer:
  """
  Trainer class for training an autoencoder model.

  Args:
    device (torch.device): The device (CPU or GPU) to run the model on.
    model (nn.Module): The autoencoder model to be trained; its output is passed to `loss`.
    loss (Callable[[AutoencoderState], Dict[str, torch.Tensor]]): Maps the model's state to a
      dict of named scalar loss terms; the optimized loss is their sum.
    dataloader (DataLoader): Training data, yielding (data, target) pairs; the target is ignored.
    optimizer (Optimizer.Optimizer): The optimizer for training the model.
    epochs (int): Number of epochs.
    writer (Optional[SummaryWriter]): Optional TensorBoard writer for logging training metrics.
  """
  def __init__(
    self,
    device: torch.device,
    model: nn.Module,
    loss: Callable[[AutoencoderState], Dict[str, torch.Tensor]],
    dataloader: DataLoader,
    optimizer: Optimizer.Optimizer,
    epochs: int,
    writer: Optional[SummaryWriter] = None,
  ) -> None:
    self.device = device
    self.model = model.to(self.device)
    self.loss = loss
    self.dataloader = dataloader
    self.optimizer = optimizer
    self.epochs = epochs
    self.writer = writer

  def train(self) -> None:
    """
    Runs the training loop for `self.epochs` epochs.

    Each batch: forward pass, sum of the loss terms, backward pass, optimizer step.
    Batch-level scalars are logged at a running global step. The epoch average is
    logged and printed once per epoch.
    """
    self.model.train()
    global_step = 0  # batches processed so far, across epochs (TensorBoard x-axis)

    for epoch in range(self.epochs):
      epoch_loss = 0.0
      num_batches = 0
      progress_bar = tqdm(self.dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")

      for data, _ in progress_bar:
        data = data.to(self.device)

        state = self.model(data)
        losses = self.loss(state)
        loss = sum(losses.values())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix(loss=batch_loss)

        if self.writer is not None:
          self.writer.add_scalar("Train/Loss/Batch", batch_loss, global_step)
          for name, term in losses.items():
            self.writer.add_scalar(f"Train/Loss/{name}", float(term), global_step)
        global_step += 1

      # Average over the batches actually seen; guards against an empty dataloader.
      average_loss = epoch_loss / max(num_batches, 1)
      if self.writer is not None:
        self.writer.add_scalar("Train/Loss/Epoch", average_loss, epoch)

      print(f"Epoch [{epoch+1}/{self.epochs}] | Epoch Loss: {epoch_loss:.4f} | Avg Loss: {average_loss:.4f}")

  def test(self) -> None:
    """Puts the model in evaluation mode; no evaluation loop is implemented yet."""
    self.model.eval()



class VAETrainer(AutoencoderTrainer):
  """
  Trainer class for training a VAE model.

  Args:
    device (torch.device): The device (CPU or GPU) to run the model on.
    model (nn.Module): The VAE model to be trained.
    loss (Callable[[VAEState], Dict[str, torch.Tensor]]): Maps the VAE state to a dict of named
      scalar loss terms (e.g. reconstruction and KL); the optimized loss is their sum.
    dataloader (DataLoader): The DataLoader for loading training data.
    optimizer (Optimizer.Optimizer): The optimizer for training the model.
    writer (Optional[SummaryWriter]): Optional TensorBoard writer for logging training metrics.
    epochs (int): Number of epochs. It comes after `writer` so that existing positional calls
      keep working. Defaults to 1.
  """
  def __init__(
    self,
    device: torch.device,
    model: nn.Module,
    loss: Callable[[VAEState], Dict[str, torch.Tensor]],
    dataloader: DataLoader,
    optimizer: Optimizer.Optimizer,
    writer: Optional[SummaryWriter] = None,
    epochs: int = 1,
  ) -> None:
    # Forward by keyword: the base signature puts `epochs` before `writer`, so a
    # positional call would place the writer in the `epochs` slot.
    super().__init__(
      device=device,
      model=model,
      loss=loss,
      dataloader=dataloader,
      optimizer=optimizer,
      epochs=epochs,
      writer=writer,
    )