<a href="https://colab.research.google.com/github/AlessandroFornasier/diffusion-models/blob/main/autoencoders/vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class Autoencoder(nn.Module):
  """
  Fully connected autoencoder with a mirrored encoder and decoder.

  The encoder maps dims[0] -> dims[1] -> ... -> dims[-1] (latent space) and
  the decoder maps the same sizes in reverse, dims[-1] -> ... -> dims[0].

  Args:
    dims (List[int]): Dimensionality of the layers, from the input size
      (dims[0]) down to the latent size (dims[-1]). Needs at least two entries.
    binary (bool): Flag to indicate binary data [0, 1]. When True a Sigmoid is
      appended to the decoder so that the output lies in (0, 1).

  Raises:
    ValueError: If dims has fewer than two entries.
  """
  def __init__(self, dims : List[int], binary : bool = True):
    super().__init__()

    # Both the encoder and the decoder index dims[-2]; fail with a clear
    # message instead of an IndexError coming from the slicing below.
    if len(dims) < 2:
      raise ValueError("dims must contain at least an input size and a latent size")

    self.binary = binary

    # Encoder
    #
    # Activation function:
    #   SiLU
    #
    # Note:
    #   Latent space has no activation function
    layers = []
    for idim, odim in zip(dims[:-2], dims[1:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(dims[-2], dims[-1]))
    self.encoder = nn.Sequential(*layers)

    # Decoder
    #
    # Activation function for hidden layers:
    #   SiLU
    #
    # Activation function for output layer:
    #   Sigmoid (normalizes (0, 1)), only when binary is True
    #
    # The hidden layers walk dims backwards starting from the latent size:
    # (dims[-1], dims[-2]), ..., (dims[2], dims[1]). The last projection
    # (dims[1], dims[0]) is added separately because it is not followed by SiLU.
    layers = []
    for idim, odim in zip(dims[:0:-1], dims[-2:0:-1]):
      layers.append(nn.Linear(idim, odim))
      layers.append(nn.SiLU())
    layers.append(nn.Linear(dims[1], dims[0]))
    if self.binary:
      layers.append(nn.Sigmoid())
    self.decoder = nn.Sequential(*layers)

  def encode(self, x):
    """
    Encodes the input data into the latent space.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      torch.Tensor: Encoded data, latent space.
    """
    return self.encoder(x)

  def decode(self, z):
    """
    Decodes the latent space data.

    Args:
      z (torch.Tensor): Latent space data.

    Returns:
      torch.Tensor: Decoded data, output space.
    """
    return self.decoder(z)

  def forward(self, x):
    """
    Forward pass of the autoencoder: encodes x and decodes the result.

    Args:
      x (torch.Tensor): Input data.

    Returns:
      torch.Tensor: Reconstruction of the input data.
    """
    return self.decode(self.encode(x))

In [None]:
class VAE(Autoencoder):
  """
  Variational autoencoder skeleton built on top of Autoencoder.

  The encoder and decoder layers are created by Autoencoder.__init__. The
  methods below override the Autoencoder ones but are placeholders that are
  not implemented yet: they raise NotImplementedError instead of silently
  returning None.

  Args:
    dims (List[int]): Dimensionality of the layers.
    binary (bool): Flag to indicate binary data [0, 1]
  """
  def __init__(self, dims, binary : bool = True):
    # Forward binary so the documented flag actually reaches the base class.
    super().__init__(dims, binary)

  def encode(self, x, eps: float = 1e-9):
    """
    Placeholder for the VAE encoder step (not implemented yet).

    Args:
      x: Input data.
      eps (float): Presumably a small constant for numerical stability;
        TODO confirm once the method is implemented.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError("VAE.encode is not implemented yet")

  def reparametrize(self, dist):
    """
    Placeholder for the reparametrization step (not implemented yet).

    Args:
      dist: Presumably the distribution produced by encode; TODO confirm once
        the method is implemented.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError("VAE.reparametrize is not implemented yet")

  def decode(self, z):
    """
    Placeholder for the VAE decoder step (not implemented yet).

    Args:
      z: Latent space data.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError("VAE.decode is not implemented yet")

  def forward(self, x, compute_loss: bool = True):
    """
    Placeholder for the VAE forward pass (not implemented yet).

    Args:
      x: Input data.
      compute_loss (bool): Presumably toggles whether a loss is computed
        along with the output; TODO confirm once the method is implemented.

    Raises:
      NotImplementedError: Always.
    """
    raise NotImplementedError("VAE.forward is not implemented yet")