# Vector Quantized Variational Autoencoders (VQ-VAEs)

In [97]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [98]:
class VQVAE(nn.Module):
    def __init__(
        self,
        input_dim: int,
        n_channels: int,
        latent_dim: int,
        codebook_size: int,
    ):
        super().__init__()
        
        self.latent_dim = latent_dim
        
        self.encoder = nn.Sequential(
            nn.Conv2d(n_channels, 96, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(96),
            nn.Conv2d(96, 192, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            nn.Conv2d(192, 384, kernel_size=3, stride=2, padding=1), 
            nn.ReLU(),
            nn.BatchNorm2d(384),
            nn.Conv2d(384, latent_dim, kernel_size=3),
            nn.ReLU(),
            nn.BatchNorm2d(latent_dim)
        )

        self.codebook = nn.Parameter(torch.zeros(codebook_size, latent_dim))
        
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 384, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(384),
            nn.ConvTranspose2d(384, 192, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(192),
            nn.ConvTranspose2d(192, 96, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(96),
            nn.ConvTranspose2d(96, n_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),
        )
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the VAE.

        Args:
            x: input data

        Returns:
            x_hat: reconstructed data
        """
        z_encoder = self.encoder(x)
        z_q = self._quantize_encoder_output(z_encoder)
        return self.decoder(z_q)
    
    def _quantize_encoder_output(self, z_e: torch.Tensor) -> torch.Tensor:
        """
        Quantizing the encoded tensor by snapping its elements to the closest codebook 
        entry.

        Args:
            z_e: encoded representation

        Returns:
            z_q: quantized representation
        """
        batch_size, latent_dim, h, w = z_e.shape
        encoded = z_e.permute(0, 2, 3, 1).reshape(batch_size*h*w, latent_dim)
        quantized = self.codebook[torch.argmin(torch.cdist(encoded, self.codebook), dim=1)]
        z_q = quantized.reshape(batch_size, h, w, latent_dim).permute(0, 3, 1, 2)
        return z_q
    


In [99]:
# Create dummy tensors and model for testing
batch_size = 2
n_channels = 3
height = 64
width = 64
latent_dim = 128
codebook_size = 512

# Create dummy input tensor
dummy_input = torch.randn(batch_size, n_channels, height, width)

# Create dummy model
dummy_model = VQVAE(
    input_dim=(n_channels, height, width),
    latent_dim=latent_dim,
    n_channels=n_channels,
    codebook_size=codebook_size
)

# Run forward pass to see debug statements
print("Input shape:", dummy_input.shape)
print("Running forward pass...")
output = dummy_model(dummy_input)
print("Output shape:", output.shape)


Input shape: torch.Size([2, 3, 64, 64])
Running forward pass...


RuntimeError: mat1 and mat2 shapes cannot be multiplied (1536x6 and 128x24576)