# Vector Quantization

### Sources:

Neural Discrete Representation Learning: van den Oord et al. (2018)

Understanding Vector Quantization in VQ-VAE: https://huggingface.co/blog/ariG23498/understand-vq



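The key idea of vector quantization is to map a continuous latent space onto a discrete one. Each continuous encoder output is replaced by the closest vector from a finite, learned set of embedding vectors (the codebook).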

A latent embedding space $e \in \mathbb{R}^{K \times D}$ is defined, where $K$ is the size of the discrete latent space (so each discrete latent variable is a $K$-way categorical) and $D$ is the dimensionality of each latent embedding vector. There are $K$ embedding vectors $e_i \in \mathbb{R}^D$, $i \in \{1, 2, \dots, K\}$.

The model takes an input $x$, and the encoder produces a continuous output $z_e(x)$. The discrete latent variables $z$ are then computed by a nearest-neighbour look-up in the shared embedding space $e$:

\begin{equation}
q(z = k \mid x) = \begin{cases}
1 & \text{if } k = \operatorname{argmin}_j \| z_e(x) - e_j \|_2 \\
0 & \text{otherwise}
\end{cases}
\end{equation}
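A minimal sketch of this look-up, assuming a toy codebook of $K = 4$ random vectors in $D = 2$ dimensions and six random stand-in encoder outputs (none of these values come from the paper). `torch.cdist` is used here only as a convenient way to get all pairwise Euclidean distances; each one-hot row of `q` is the posterior $q(z = k \mid x)$ for one encoder output.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

K, D = 4, 2                      # toy codebook size K and embedding dimension D
codebook = torch.randn(K, D)     # e: one row per embedding vector e_j
z_e = torch.randn(6, D)          # six stand-in encoder outputs z_e(x), one per row

# Euclidean distance from every encoder output to every embedding vector: shape (6, K)
dists = torch.cdist(z_e, codebook, p=2)

# k = argmin_j ||z_e(x) - e_j||_2 for each encoder output
k = dists.argmin(dim=1)

# q(z = k | x) is one-hot: probability 1 on the nearest embedding, 0 elsewhere
q = F.one_hot(k, num_classes=K).float()
print(k)
print(q)
```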

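where $z_e$ is the output of the encoder. The posterior $q(z \mid x)$ is therefore a one-hot distribution that puts all its mass on the closest embedding vector. Finding that vector requires the squared Euclidean distance between the encoder output and every embedding vector. Writing $z_i$ for an element of the continuous representation and $z_{qi}$ for the corresponding element of a candidate quantized vector, the squared difference for one element expands as: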

\begin{equation}
(z_i - z_{qi})^2 = z_i^2 + z_{qi}^2 - 2z_i z_{qi}
\end{equation}

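Averaging over all $N$ elements,

\begin{equation}
\text{MSE} = \frac{1}{N} \sum_{i=1}^{N} \left( z_i^2 + z_{qi}^2 - 2 z_i z_{qi} \right)
\end{equation}

\begin{equation}
= \frac{1}{N} \left( \sum_{i=1}^{N} z_i^2 + \sum_{i=1}^{N} z_{qi}^2 - 2 \sum_{i=1}^{N} z_i z_{qi} \right)
\end{equation}

Using vector notation,

\begin{equation}
\text{MSE} = \frac{1}{N} \left( \|\mathbf{z}\|^2 + \|\mathbf{z}_q\|^2 - 2\, \mathbf{z} \cdot \mathbf{z}_q \right)
\end{equation}

This form needs only squared norms and a dot product, so the distances between many encoder outputs and all $K$ embedding vectors can be computed with a single matrix multiplication; the `distances` computation in `VQEmbedding` below uses exactly this expansion.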
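The identity can be checked numerically. The sketch below uses random vectors (an assumption for illustration only) and compares the direct mean of squared differences with the expanded form. It then applies the same expansion to a whole matrix of stand-in encoder outputs `Z` and embedding vectors `E` at once, which is what makes the nearest-neighbour search cheap.

```python
import torch

torch.manual_seed(0)

z = torch.randn(8)     # a continuous vector z with N = 8 elements
z_q = torch.randn(8)   # a candidate quantized vector z_q of the same size
N = z.numel()

# Direct definition: mean of the element-wise squared differences
mse_direct = ((z - z_q) ** 2).mean()

# Expanded form: (||z||^2 + ||z_q||^2 - 2 z . z_q) / N
mse_expanded = (z.dot(z) + z_q.dot(z_q) - 2 * z.dot(z_q)) / N
print(mse_direct.item(), mse_expanded.item())

# The same identity gives all pairwise squared distances between M encoder
# outputs (rows of Z) and K embeddings (rows of E) with one matrix product.
Z = torch.randn(5, 3)
E = torch.randn(4, 3)
pairwise = (Z ** 2).sum(dim=1, keepdim=True) + (E ** 2).sum(dim=1) - 2 * Z @ E.T  # (5, 4)
direct = ((Z[:, None, :] - E[None, :, :]) ** 2).sum(dim=-1)                      # (5, 4)
```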



The input to the decoder is the embedding vector $e_k$ picked out by this one-hot assignment, so that:

\begin{equation}
z_q(x) = e_k
\end{equation}

where $k = \text{argmin}_j ||z_e(x) - e_j||_2$.
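Selecting $e_k$ is just row indexing into the codebook, and it gives the same result as multiplying the one-hot vector $q(z \mid x)$ by the embedding matrix. A small sketch with random toy tensors (hypothetical sizes $K = 4$, $D = 3$):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

K, D = 4, 3
codebook = torch.randn(K, D)   # embedding vectors e_1..e_K as rows
z_e = torch.randn(5, D)        # five stand-in encoder outputs

# Nearest embedding index k for each encoder output
k = torch.cdist(z_e, codebook).argmin(dim=1)

# z_q(x) = e_k: select the nearest codebook row directly ...
z_q_index = codebook[k]
# ... which matches multiplying the one-hot q(z | x) by the codebook
z_q_onehot = F.one_hot(k, num_classes=K).float() @ codebook
```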

The argmin in this definition is not differentiable, so no real gradient is defined for $z_q(x)$ with respect to $z_e(x)$: gradients from the decoder cannot flow back through the quantization step to the encoder. The fix proposed by van den Oord et al. is the straight-through estimator, which simply copies the gradient at the decoder input $z_q(x)$ to the encoder output $z_e(x)$.

During the forward pass, the nearest embedding $z_q(x)$ is passed to the decoder; during the backward pass, the gradient $\nabla_z L$ is passed unaltered to the encoder, so the rest of the network can be updated to lower the reconstruction loss. This works because the encoder output and the decoder input have the same dimensionality (both live in the same $D$-dimensional space).
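A minimal sketch of the straight-through trick in PyTorch, assuming a random toy codebook, a random tensor standing in for the encoder output, and a made-up squared-error loss standing in for the decoder. `detach()` removes the quantization residual from the autograd graph, so the forward value is $z_q$ while the backward pass treats the quantizer as the identity; the same line appears in `VQEmbedding.forward`.

```python
import torch

torch.manual_seed(0)

codebook = torch.randn(4, 3)                 # toy codebook, K = 4, D = 3
z_e = torch.randn(2, 3, requires_grad=True)  # stand-in for the encoder output

# Forward: nearest-neighbour quantization (argmin is not differentiable)
k = torch.cdist(z_e.detach(), codebook).argmin(dim=1)
z_q = codebook[k]

# Straight-through: value equals z_q, but d(z_st)/d(z_e) is the identity
z_st = z_e + (z_q - z_e).detach()

# A toy "decoder loss" on the decoder input (hypothetical target)
target = torch.randn(2, 3)
loss = ((z_st - target) ** 2).sum()
loss.backward()

# The gradient at the decoder input, dL/dz_q = 2 (z_q - target),
# is copied unchanged onto the encoder output
expected_grad = 2 * (z_q - target)
print(z_e.grad)
```

Because the forward value equals `z_q` while the graph only contains `z_e`, the gradient arriving at the decoder input lands on the encoder output unchanged.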

The right panel of the figure below (from van den Oord et al.) shows how this gradient can push the encoder output towards a different embedding vector, so that it is discretized differently in the next forward pass.

The authors of the original VQ-VAE paper add two further terms to the reconstruction objective, giving the overall loss:

\begin{equation}
L = -\log p(x \mid z_q(x)) + \|\text{sg}[z_e(x)] - e\|_2^2 + \beta \|z_e(x) - \text{sg}[e]\|_2^2
\end{equation}

The first term is the reconstruction loss, which optimizes the decoder and, through the straight-through estimator, the encoder.

The second term is a dictionary learning (codebook) loss. It penalises the distance between the encoder output and the selected embedding vector $e$, moving the embedding vectors towards the encoder outputs. Here sg stands for the 'stop gradient' operator: it is the identity in the forward pass but has zero partial derivatives, so its argument is treated as a constant. The encoder output is therefore held fixed by this term, which updates only the embeddings; the encoder is trained by the reconstruction and commitment losses.

Because the volume of the embedding space is dimensionless, it can grow arbitrarily if the embeddings do not train as fast as the encoder parameters. The third term, the commitment loss, ensures that the encoder commits to an embedding and that its output does not grow. $\beta$ is the weight of the commitment loss; the authors report that results did not vary for $0.1 \leq \beta \leq 2.0$ and use $\beta = 0.25$ in their experiments.
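The two auxiliary terms can be sketched with PyTorch's `detach()` standing in for sg. This toy example (random codebook and encoder output, $\beta = 0.25$, reconstruction term left out) checks which tensor each term actually updates. It uses `F.mse_loss`, i.e. a mean rather than the summed squared norm in the equation, as the `VQEmbedding` implementation below also does; this only rescales the terms.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

beta = 0.25                                       # commitment weight
codebook = torch.randn(4, 3, requires_grad=True)  # embedding vectors e_j (toy)
z_e = torch.randn(6, 3, requires_grad=True)       # stand-in for the encoder output

# Nearest embedding for each encoder output
k = torch.cdist(z_e.detach(), codebook.detach()).argmin(dim=1)
e = codebook[k]

# ||sg[z_e(x)] - e||^2: detach() plays the role of sg, so only the codebook is updated
codebook_loss = F.mse_loss(e, z_e.detach())
# beta * ||z_e(x) - sg[e]||^2: only the encoder output is updated
commitment_loss = beta * F.mse_loss(z_e, e.detach())

codebook_loss.backward()
codebook_grad_after_codebook_loss = codebook.grad.clone()
encoder_grad_after_codebook_loss = z_e.grad  # None: sg blocked the path to z_e

commitment_loss.backward()  # adds a gradient to z_e only; codebook.grad is untouched
```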



![Vector Quantization](./figures/VQ-VAE_architecture.png)
*From van den Oord et al. (2018)*

## Implementation of VQEmbedding

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class VQEmbedding(nn.Module):
    """Vector-quantization layer: maps each D-dim encoder vector to its nearest codebook entry."""

    def __init__(self, num_embeddings, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim      # D: dimensionality of each embedding vector
        self.num_embeddings = num_embeddings    # K: number of embedding vectors (codebook size)

        # Codebook e of shape (K, D), initialised uniformly in [-1/K, 1/K]
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.num_embeddings, 1.0 / self.num_embeddings)

    def forward(self, z, commitment_cost):
        b, c, h, w = z.shape  # (batch_size, embedding_dim, height, width); c must equal embedding_dim

        z_channel_last = z.permute(0, 2, 3, 1)  # (batch_size, height, width, embedding_dim)
        z_flattened = z_channel_last.reshape(b * h * w, self.embedding_dim)  # (batch_size * height * width, embedding_dim)

        # Squared Euclidean distances ||z||^2 + ||e_j||^2 - 2 z . e_j, shape (b*h*w, K)
        distances = (
            torch.sum(z_flattened**2, dim=-1, keepdim=True)
            + torch.sum(self.embedding.weight.t()**2, dim=0, keepdim=True)
            - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
        )

        # Index k of the nearest embedding vector for every spatial position, shape (b*h*w,)
        encoding_indices = torch.argmin(distances, dim=-1)

        # Look up the quantized vectors z_q = e_k and restore the (b, D, h, w) layout
        z_q = self.embedding(encoding_indices)
        z_q = z_q.reshape(b, h, w, self.embedding_dim)
        z_q = z_q.permute(0, 3, 1, 2)

        # Codebook loss ||sg[z_e] - e||^2 plus commitment loss beta * ||z_e - sg[e]||^2
        loss = F.mse_loss(z_q, z.detach()) + commitment_cost * F.mse_loss(z_q.detach(), z)

        # Straight-through estimator: forward value is z_q, gradient w.r.t. z is the identity
        z_q = z + (z_q - z).detach()

        return z_q, loss, encoding_indices
```

```python
batch_size = 5
num_embeddings = 3   # K: number of embedding vectors
embedding_dim = 3    # D: channels of the encoder output
h, w = 3, 3

# Stand-in for encoded inputs from an image: (batch_size, embedding_dim, height, width)
z_e = torch.randn(batch_size, embedding_dim, h, w)
commitment_cost = 0.25

vqe = VQEmbedding(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

z_q, loss, encoding_indices = vqe(z_e, commitment_cost)
```

```python
print(encoding_indices.shape, z_q.shape, loss)
```

Output (the loss value varies between runs because `z_e` is random):

```
torch.Size([45]) torch.Size([5, 3, 3, 3]) tensor(1.1047, grad_fn=<AddBackward0>)
```
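`encoding_indices` holds one code index per spatial position of every image, flattened to `batch_size * h * w = 45` entries. A small sketch of how those indices map back onto a grid of discrete latents; it re-derives them from random stand-in tensors with the same toy sizes so that the block runs on its own (the names `codebook` and `code_map` are illustrative, not part of `VQEmbedding`):

```python
import torch

torch.manual_seed(0)

b, D, h, w, K = 5, 3, 3, 3, 3   # same toy sizes as the cells above
z_e = torch.randn(b, D, h, w)   # stand-in encoder output, channels first
codebook = torch.randn(K, D)    # stand-in for the learned embedding vectors

# Flatten the same way VQEmbedding.forward does: (b, D, h, w) -> (b*h*w, D)
z_flat = z_e.permute(0, 2, 3, 1).reshape(b * h * w, D)

# Nearest-embedding index for every spatial position, shape (45,)
sq_dists = ((z_flat[:, None, :] - codebook[None, :, :]) ** 2).sum(dim=-1)
indices = sq_dists.argmin(dim=1)

# Undo the flattening: one discrete code per (image, row, column)
code_map = indices.reshape(b, h, w)
print(indices.shape, code_map.shape)  # torch.Size([45]) torch.Size([5, 3, 3])
```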
