In [44]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvolutionalVAE(nn.Module):
    """
    Convolutional variational autoencoder for image-shaped inputs.

    The encoder applies three stride-2 convolutions (each followed by ReLU and
    batch norm), flattens the result, and projects it to ``2 * latent_dim``
    values that are split into the mean and log-variance of the latent
    distribution. The decoder mirrors this with a linear layer followed by
    three stride-2 transposed convolutions and a final sigmoid, so the
    reconstruction has values in (0, 1).

    Args:
        input_dim: input shape as ``(channels, height, width)``. Only the
            height and width entries are read here; the channel count is taken
            from ``n_channels``. Height and width must be positive multiples
            of 8 because the network down/up-samples by a factor of 2 three
            times.
        n_channels: number of channels of the input (and reconstructed) image.
        conv_dim: number of feature maps produced by the first convolution;
            deeper layers use ``2 * conv_dim`` and ``4 * conv_dim``.
        latent_dim: size of the latent vector ``z``.

    Raises:
        ValueError: if the height or width in ``input_dim`` is not a positive
            multiple of 8.
    """

    # Total spatial reduction of the encoder: three stride-2 convolutions.
    _DOWNSAMPLE_FACTOR = 8

    def __init__(
        self,
        input_dim: tuple[int, int, int],
        n_channels: int,
        conv_dim: int,
        latent_dim: int,
    ):
        super().__init__()

        height, width = input_dim[1], input_dim[2]
        factor = self._DOWNSAMPLE_FACTOR
        # With kernel=3, stride=2, padding=1 the encoder yields ceil(size / 8)
        # rows/cols while the decoder reproduces exactly 8 * (size // 8), so
        # the two only agree when the size is a multiple of 8. Fail here
        # rather than with an opaque shape mismatch inside forward().
        if height <= 0 or width <= 0 or height % factor or width % factor:
            raise ValueError(
                f"input height and width must be positive multiples of {factor}, "
                f"got height={height}, width={width}"
            )

        self.latent_dim = latent_dim

        # Spatial size and flattened length of the deepest feature map.
        # Computed from explicit per-axis sizes instead of a chained
        # `a * b // 8 * c // 8` expression, whose result depends on
        # left-to-right evaluation order.
        feat_h = height // factor
        feat_w = width // factor
        flat_dim = 4 * conv_dim * feat_h * feat_w

        self.encoder = nn.Sequential(
            nn.Conv2d(n_channels, conv_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(conv_dim),
            nn.Conv2d(conv_dim, 2 * conv_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * conv_dim),
            nn.Conv2d(2 * conv_dim, 4 * conv_dim, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(4 * conv_dim),
            nn.Flatten(),
            # Two values per latent dimension: mean and log-variance.
            nn.Linear(flat_dim, 2 * latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, flat_dim),
            nn.Unflatten(1, (4 * conv_dim, feat_h, feat_w)),
            # output_padding=1 makes each transposed convolution exactly
            # double the spatial size, undoing one encoder stage.
            nn.ConvTranspose2d(4 * conv_dim, 2 * conv_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * conv_dim),
            nn.ConvTranspose2d(2 * conv_dim, conv_dim, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(conv_dim),
            nn.ConvTranspose2d(conv_dim, n_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )

    def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Sample from the latent space using the reparameterization trick.

        The noise is drawn independently of the parameters and then shifted
        and scaled, so gradients can flow back through ``mu`` and ``logvar``.

        Args:
            mu: mean of the latent distribution
            logvar: log variance of the latent distribution (same shape as mu)

        Returns:
            z: a sample ``mu + eps * std`` with ``eps ~ N(0, I)``
        """
        std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the VAE.

        Encodes ``x``, draws one latent sample, and decodes it. Only the
        reconstruction is returned; ``mu`` and ``logvar`` are not exposed by
        this method.

        Args:
            x: input data of shape (batch, n_channels, height, width)

        Returns:
            x_hat: reconstructed data, same shape as ``x``, values in (0, 1)
        """
        encoded = self.encoder(x)
        # First latent_dim columns are the mean, the remaining ones the log-variance.
        mu, logvar = encoded.split(self.latent_dim, dim=1)
        z = self.sample(mu, logvar)
        return self.decoder(z)

In [45]:
# Checking that dimensionality is correct: the reconstruction must have the
# same shape as the input batch.
# Note: the batch size is not a constructor argument of ConvolutionalVAE; it
# is simply the leading dimension of the tensor passed to the model.
vae = ConvolutionalVAE(
    input_dim=(3, 32, 32),
    n_channels=3,
    conv_dim=96,
    latent_dim=128,
)

random_data = torch.randn(4, 3, 32, 32)
x_hat = vae(random_data)

# Fail loudly on a fresh "Restart & Run All" if the architecture is changed
# in a way that breaks the input/output shape contract.
assert x_hat.shape == random_data.shape, (x_hat.shape, random_data.shape)

print("x_hat shape", x_hat.shape)

x_hat shape torch.Size([4, 3, 32, 32])


So we've adjusted the architecture and added in the sampling/reparameterization trick to allow the flow of gradients. What's left?

The other difference between the VAE and the AE is the loss function. Our loss now consists of two parts: a reconstruction term and the KL
divergence of the latent distribution from a standard normal distribution.