<a href="https://colab.research.google.com/github/AlessandroFornasier/diffusion-models/blob/main/autoencoders/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [51]:
import torch
import torch.nn as nn

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class AutoencoderState:
  """
  Autoencoder state.

  Attributes:
   - x (Optional[torch.Tensor]): Input data
   - z (Optional[torch.Tensor]): Latent space sample
   - x_hat (Optional[torch.Tensor]): Reconstructed data
  """
  x : Optional[torch.Tensor] = None                         # Input
  z : Optional[torch.Tensor] = None                         # Latent space sample
  x_hat : Optional[torch.Tensor] = None                     # Reconstructed input


class Autoencoder(nn.Module):
  """
  Autoencoder class.

  Args:
    dims (List[int]): Dimensions of the layers.
    binary (bool): Flag to indicate binary data [0, 1]
  """

  def __init__(self, dims : List[int], binary : bool = True) -> None:
    super(Autoencoder, self).__init__()

    self.binary = binary

    """
    Encoder

    Activation function:
      SiLU

    Note:
      Latent space has no activation function
    """
    layers = []
    for idim, odim in zip(dims[:-2], dims[1:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(dims[-2], dims[-1]))
    self.encoder = nn.Sequential(*layers)


    """
    Decoder

    Activation function for hidden layers:
      SiLU

    Note:
      If the autoencoder is binary the activation function for output layer is a sigmoid, which normalizes to (0, 1)
    """
    layers = []
    for idim, odim in zip(dims[-1:1:-1], dims[-2:0:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(dims[1], dims[0]))
    if self.binary:
      layers.append(nn.Sigmoid())
    self.decoder = nn.Sequential(*layers)


    def encode(self, x) -> torch.Tensor:
      """
      Encodes the input data into the latent space.

      Args:
        x (torch.Tensor): Input data.

      Returns:
        z (torch.Tensor): Encoded data, latent space.
      """
      return self.encoder(x)

    def decode(self, z) -> torch.Tensor:
      """
      Decodes the latent space data.

      Args:
        z (torch.Tensor): Latent space data.

      Returns:
        x_hat (torch.Tensor): Decoded data, output space.
      """
      return self.decoder(z)

    def forward(self, x) -> torch.Tensor:
      """
      Forward pass of the autoencoder
      """
      state = AutoencoderState(x)
      state.z = self.encode(x)
      state.x_hat = self.decode(state.z)
      return state


@dataclass
class VAEState(AutoencoderState):
  """
  Variational autoencoder state.

  Attributes:
   - x (Optional[torch.Tensor]): Input data
   - z (Optional[torch.Tensor]): Latent space sample
   - x_hat (Optional[torch.Tensor]): Reconstructed data
   - dist (Optional[torch.distributions.Distribution]): Encoder Gaussian distribution
  """
  dist : Optional[torch.distributions.Distribution] = None  # Latent space distribution


class VAE(Autoencoder):
  """
  Variational autoencoder class.

  Args:
    dims (List[int]): Dimensions of the layers.
    binary (bool): Flag to indicate binary data [0, 1]

  Note:
    A VAE is trained by maximizing ELBO:
    - Reconstruction loss (MSE ~ cross entropy)
    - KL divergence

  Refernce:
    - https://hunterheidenreich.com/posts/modern-variational-autoencoder-in-pytorch/
    - https://github.com/pytorch/examples/blob/main/vae/main.py
  """

  def __init__(self, dims : List[int], binary : bool = True) -> None:
    dims[-1] *= 2 # Mean and variance
    super().__init__(dims, binary)

    self.softplus = nn.Softplus()

  def encode(self, x, eps: float = 1e-6) -> torch.distributions.Distribution:
    """
    Encodes the input data into the latent space.

    Args:
      x (torch.Tensor): Input data.
      eps (float): Small value to avoid numerical instability.

    Returns:
      dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.

    Note:
      Learning logvar improves numerical stability since var is smaller than zero and tipically smaller than once. Hence logvar is within (-inf, log(1)).
      Softplus + epsilon (softplus(x) = \log(1 + \exp(x))) is used to get sigma instead of directly exponentiating while ensuring numerical stability
    """
    x = self.encoder(x)
    mu, logvar = torch.tensor_split(x, 2, dim=-1)
    var = self.softplus(logvar) + eps
    return torch.distributions.MultivariateNormal(mu, scale_tril=torch.diag_embed(var)) # Use scale_tril as it is more efficient


  def reparametrize(self, dist) -> torch.Tensor:
    """
    Perform sampling via the reparametrization trick

    Args:
      dist (torch.distributions.MultivariateNormal): Normal distribution of the encoded data.

    Returns:
      z (torch.Tensor): Sampled data from the latent space z = mu + sigma * epsilon. With epsilon ~ N(0,I)
    """
    return dist.rsample()

  def decode(self, z) -> torch.Tensor:
    """
    Decodes the data from the latent space to the original input space.

    Args:
      z (torch.Tensor): Data in the latent space.

    Returns:
      x_hat (torch.Tensor): Reconstructed data in the original input space.
    """
    return self.decoder(z)

  def forward(self, x) -> VAEState:
    """
    Performs a forward pass of the VAE.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      state (VAEState): state of the VAE.
    """
    state = VAEState(x)
    state.dist = self.encode(state.x)
    state.z = self.reparametrize(state.dist)
    state.x_hat = self.decode(state.z)
    return state

In [19]:
from tqdm import tqdm
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from typing import Callable, Dict

class AutoencoderTrainer:
  """
  Trainer class for training an autoencoder model.

  Args:
    device (torch.device): The device (CPU or GPU) to run the model on.
    model (nn.Module): The autoencoder model to be trained.
    loss (Callable[[AutoencoderState], Dict[torch.Tensor]]): A callable loss function that takes the model's state and returns a scalar loss value.
    optimizer (Optimizer): The optimizer for training the model.
    epochs (int): Number of epochs
    writer (Optional[SummaryWriter]): Optional TensorBoard writer for logging training metrics.
  """
  def __init__(
    self,
    device: torch.device,
    model: nn.Module,
    loss: Callable[[AutoencoderState], Dict[str, torch.Tensor]],
    optimizer: Optimizer,
    epochs: int,
    writer: Optional[SummaryWriter] = None,
  ) -> None:
    self.device = device
    self.model = model.to(self.device)
    self.loss = loss
    self.optimizer = optimizer
    self.epochs = epochs
    self.writer = writer

  def train(self, dataloader: DataLoader) -> None:
    """
    Trains the autoencoder model on the given dataset.

    Args:
      dataloader (DataLoader): The DataLoader for loading training data.
    """
    self.model.train()

    step = 0
    for epoch in range(self.epochs):
      epoch_loss = 0.0
      progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

      for data, _ in progress_bar:
        data = data.to(self.device)

        state = self.model(data)
        losses = self.loss(state, self.model.binary)

        loss = sum(losses.values())

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        average_loss = epoch_loss / len(dataloader)
        progress_bar.set_postfix(loss=batch_loss)

        if self.writer:
          self.writer.add_scalar("Train/Loss/Batch", batch_loss, step)
          self.writer.add_scalar("Train/Loss/Epoch", average_loss, step)
          for name, loss in losses.items():
            self.writer.add_scalar(f"Train/Loss/{name}", loss, step)

        step += 1
        print(f"Epoch [{epoch+1}/{epochs}] \| Batch loss: {batch_loss:.4f} \| Epoch Loss: {epoch_loss:.4f} \| Avg Loss: {average_loss:.4f}")

  def test(self, dataloader: DataLoader) -> None:
    """
    Test the autoencoder model on the given dataset.

    Args:
      dataloader (DataLoader): The DataLoader for loading training data.
    """
    self.model.eval()

    average_loss = 0.0
    with torch.no_grad():
      for data, _ in tqdm(dataloader, desc="Testing"):
        data = data.to(self.device)

        state = self.model(data)
        losses = self.loss(state, self.model.binary)
        loss = sum(losses.values())

        average_loss += loss.item()

    average_loss /= len(dataloader)
    if self.writer:
      self.writer.add_scalar("Test/Loss/Average", average_loss)
      for name, loss in losses.items():
        self.writer.add_scalar(f"Test/Loss/Average/{name}", loss / len(dataloader))

    print(f"Average test loss: {average_loss:.4f}")



class VAETrainer(AutoencoderTrainer):
  """
  Trainer class for training a VAE model.

  Args:
    device (torch.device): The device (CPU or GPU) to run the model on.
    model (nn.Module): The autoencoder model to be trained.
    loss (Callable[[VAEState], Dict[torch.Tensor]]): A callable loss function that takes the model's state and returns a scalar loss value.
    optimizer (Optimizer): The optimizer for training the model.
    epochs (int): Number of epochs
    writer (Optional[SummaryWriter]): Optional TensorBoard writer for logging training metrics.
  """
  def __init__(
    self,
    device: torch.device,
    model: nn.Module,
    loss: Callable[[VAEState], Dict[str, torch.Tensor]],
    optimizer: Optimizer,
    epochs: int,
    writer: Optional[SummaryWriter] = None,
  ) -> None:
    super().__init__(
      device,
      model,
      loss,
      optimizer,
      epochs,
      writer
    )

In [22]:
from torchvision import datasets
from torchvision.transforms import v2

class MNISTLoader:
  """
  A utility class for loading and preprocessing the MNIST dataset using a custom transformation pipeline.

  The transformation pipeline includes:
    - Conversion to a PyTorch image tensor.
    - Scaling pixel values from [0, 255] to [0.0, 1.0].
    - Flattening the image into a 1D tensor and shifting the values to [-0.5, 0.5].

  Args:
    batch_size (int): Number of samples per batch in the DataLoader.
  """
  def __init__(self, batch_size: int) -> None:
    self.batch_size = batch_size
    self.transform = v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True), v2.Lambda(lambda x: x.view(-1) - 0.5)])

  def get_dataloader(self, train: bool) -> DataLoader:
    data = datasets.MNIST('./data/MNIST', download=True, train=train, transform=self.transform)
    return DataLoader(data, batch_size=self.batch_size, shuffle=True)

In [52]:
import torch.nn.functional as Func

from datetime import datetime

class VAELoss:
  """
  VAE Loss callable class. The VAE loss is given by the ELBO,
  which is the the sum of the reconstruction loss and the KL divergence loss
  """
  def __binary_vae_loss(self, state: VAEState) -> Dict[str, torch.Tensor]:
    rl = Func.binary_cross_entropy(state.x_hat, state.x, reduction='none').sum(-1).mean() # Reconstruction loss
    target_dist = torch.distributions.MultivariateNormal(
      torch.zeros_like(state.z, device=state.z.device),
      scale_tril=torch.eye(state.z.shape[-1], device=state.z.device).unsqueeze(0).expand(state.z.shape[0], -1, -1),
    )
    kll = torch.distributions.kl.kl_divergence(state.dist, target_dist).mean() # KL loss
    return rl + kll

  def __call__(self, state: VAEState, binary: bool) -> Dict[str, torch.Tensor]:
    if binary:
      return self.__binary_vae_loss(self, state)
    else:
      raise NotImplementedError


dims = [28*28, 512, 128, 64, 24, 12, 4]
learning_rate = 1e-3
weight_decay = 1e-2
epochs = 50

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(dims=dims, binary=True)
loss = VAELoss
dataloader = MNISTLoader(batch_size=128)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
writer = SummaryWriter(f'./runs/MNIST/VAE_{datetime.now().strftime("%Y%m%d%H%M%S")}')
trainer = VAETrainer(device=device, model=model, loss=loss, optimizer=optimizer, epochs=epochs, writer=writer)

trainer.train(dataloader.get_dataloader(train=True))

decoder layers: [Linear(in_features=8, out_features=12, bias=True), SiLU(), Linear(in_features=12, out_features=24, bias=True), SiLU(), Linear(in_features=24, out_features=64, bias=True), SiLU(), Linear(in_features=64, out_features=128, bias=True), SiLU(), Linear(in_features=128, out_features=512, bias=True), SiLU(), Linear(in_features=512, out_features=784, bias=True), Sigmoid()]


Epoch 1/50:   0%|          | 0/469 [00:00<?, ?it/s]


RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x4 and 8x12)