# Vector Quantized Variational Autoencoders (VQ-VAEs)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class VQVAE(nn.Module):
    """Convolutional variational autoencoder with a Gaussian latent space.

    The encoder is three stride-2 convolutions followed by a linear layer
    that outputs the mean and log-variance of the latent distribution. The
    decoder is a linear layer followed by three stride-2 transposed
    convolutions and a sigmoid.

    NOTE(review): despite the class name, no vector quantization is
    performed here. The latent is sampled with the reparameterization
    trick and ``codebook_size`` is accepted but never used.
    """

    def __init__(
        self,
        input_dim: "tuple[int, int, int]",
        n_channels: int,
        latent_dim: int,
        codebook_size: int,
    ):
        """
        Build the encoder and decoder networks.

        Args:
            input_dim: input shape indexed as (channels, height, width);
                only the height and width entries are read
            n_channels: number of channels of the input (and of the output)
            latent_dim: size of the latent vector
            codebook_size: currently unused; kept for interface compatibility

        Raises:
            ValueError: if the height or width is not a positive integer
        """
        super().__init__()

        self.latent_dim = latent_dim

        height, width = int(input_dim[1]), int(input_dim[2])
        if height < 1 or width < 1:
            raise ValueError(
                f"input height and width must be positive, got {height}x{width}"
            )

        # A stride-2 convolution with kernel 3 and padding 1 maps a side of
        # length n to ceil(n / 2). Record the size after every stage so the
        # flattened feature count is exact for any input size, not only for
        # multiples of 8.
        heights, widths = [height], [width]
        for _ in range(3):
            heights.append((heights[-1] + 1) // 2)
            widths.append((widths[-1] + 1) // 2)
        feat_h, feat_w = heights[-1], widths[-1]
        n_features = 384 * feat_h * feat_w

        def out_pad(stage: int) -> "tuple[int, int]":
            # A transposed convolution (kernel 3, stride 2, padding 1) maps n
            # to 2 * n - 1 + output_padding. Use 1 for an even target size and
            # 0 for an odd one so each stage undoes the matching encoder stage.
            return (1 - heights[stage] % 2, 1 - widths[stage] % 2)

        self.encoder = nn.Sequential(
            nn.Conv2d(n_channels, 96, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(96),
            nn.Conv2d(96, 192, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            nn.Conv2d(192, 384, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(384),
            nn.Flatten(),
            # Twice the latent size: one half is the mean, the other the
            # log-variance (see ``forward``).
            nn.Linear(n_features, 2 * latent_dim),
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, n_features),
            nn.Unflatten(1, (384, feat_h, feat_w)),
            nn.ConvTranspose2d(
                384, 192, kernel_size=3, stride=2, padding=1,
                output_padding=out_pad(2),
            ),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            nn.ConvTranspose2d(
                192, 96, kernel_size=3, stride=2, padding=1,
                output_padding=out_pad(1),
            ),
            nn.ReLU(),
            nn.BatchNorm2d(96),
            nn.ConvTranspose2d(
                96, n_channels, kernel_size=3, stride=2, padding=1,
                output_padding=out_pad(0),
            ),
            nn.Sigmoid(),
        )

    def sample(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """
        Sample from the latent space using the reparameterization trick.

        Args:
            mu: mean of the latent distribution
            logvar: log variance of the latent distribution

        Returns:
            z: ``mu + eps * std`` with ``eps`` drawn from a standard normal
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(
        self, x: torch.Tensor
    ) -> "tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]":
        """
        Forward pass through the VAE.

        Args:
            x: input data

        Returns:
            A pair ``(x_hat, (mu, logvar))``: the reconstructed data and the
            mean and log variance produced by the encoder.
        """
        encoded = self.encoder(x)
        # The encoder emits 2 * latent_dim values per sample; split them into
        # the mean and the log-variance halves.
        mu, logvar = encoded.split(self.latent_dim, dim=1)
        z = self.sample(mu, logvar)
        return self.decoder(z), (mu, logvar)