# Practical 2: Vector Quantization for Learning Representations
---

**Tutorial overview:**
In this tutorial you will learn about and implement a Vector Quantization layer for learning discrete representations of images. This tutorial is adapted from the VQ-VAE example in the [Haiku repository](https://github.com/deepmind/dm-haiku).


**Tutorial outline:**
- [What is vector quantization?](#what-is-vq)
- [Setup](#setup)
  - Install and import packages
  - Get MNIST dataset
- [Implementing the model](#implementing-the-model)
  - Implementing the VQ layer
  - Implementing the encoder and decoder
- [Training](#training)
- [Analysis](#analysis)


## What is vector quantization? <a class="anchor" id="what-is-vq"></a>
---

[Vector quantization](https://en.wikipedia.org/wiki/Vector_quantization) (VQ) is a technique that maps continuous vectors to a finite set of discrete vectors, called a "codebook". VQ and its variants are commonly used for data compression.
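
To make the definition concrete, here is a minimal NumPy sketch of quantizing a few 2-D vectors against a tiny codebook. The codebook values and input vectors are made up for illustration; they are not taken from the model trained later.

```python
import numpy as np

# Hypothetical 2-D codebook with four entries (one row per code).
codebook = np.array([[0.0, 0.0],
                     [0.0, 1.0],
                     [1.0, 0.0],
                     [1.0, 1.0]])

# Continuous vectors to quantize.
vectors = np.array([[0.1, 0.2],
                    [0.9, 0.8],
                    [0.2, 0.7]])

# Squared Euclidean distance from every vector to every codebook entry,
# shape [num_vectors, num_codes].
diffs = vectors[:, None, :] - codebook[None, :, :]
sq_dists = np.sum(diffs ** 2, axis=-1)

# The discrete code is the index of the nearest entry; the quantized
# vector is that entry itself.
codes = np.argmin(sq_dists, axis=1)
quantized = codebook[codes]

print(codes)  # [0 3 1]
print(quantized)
```

Storing only `codes` (plus the shared codebook) instead of the original vectors is what makes VQ useful for compression.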

In representation learning, we'll often train a model specifically for the purpose of extracting a representation that it implicitly learns (e.g., in one of its hidden layers).

VQ has been used as a layer in deep learning in order to turn models that would otherwise learn continuous representations into ones that learn discrete ones, including for generative models like VAEs (variational autoencoders) and GANs (generative adversarial networks), which you'll see later this week. The 2017 paper [Neural Discrete Representation Learning](https://arxiv.org/abs/1711.00937) introduces VQ-VAE, which makes the encoder output discrete codes. In this practical, you will implement the VQ layer introduced in the paper.


Many state-of-the-art generative models for images and video, such as [DALL-E](https://arxiv.org/abs/2102.12092) and [Phenaki](https://sites.research.google/phenaki/), rely on first learning discrete representations. These discrete representations can then be modeled using powerful autoregressive language models. Something to think about: why else might a discrete representation be advantageous over a continuous one?

In the following section we'll see how the VQ operation works and how to implement it.


#### Quick note on autoencoders
Below, you'll train a VQ-VAE, but you won't need to know anything about VAEs for this Colab (you will learn about them later this week). For now, just ensure you understand the plain autoencoder, where an **encoder** and a **decoder** are composed and trained together to reconstruct the input $x$, as follows:
$$
\begin{align*}
x &= \text{input (e.g. image)} \\
\text{bottleneck } b &= E(x)
      &&\text{for some appropriate encoder }E\\
\text{prediction (reconstruction) is } \hat{x} &= D(b)
      &&\text{for some appropriate decoder }D \\
\text{loss} &= \mathcal L(x, \hat x)  &&\text{with some reconstruction loss } \mathcal L
\end{align*}
$$
The loss $\mathcal L$ encourages $\hat x$ to be similar to $x$; for images this may simply be the squared error.

After training, the value at the bottleneck, $b = E(x)$, can then be used as a learned representation in a downstream task. In this setup, without further changes, $b$ will be a continuous representation of $x$.
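
The equations above can be sketched in a few lines of NumPy. This is only an illustration under simplifying assumptions: the encoder and decoder are single untrained linear maps, and the sizes and names (`W_enc`, `W_dec`, `encode`, `decode`) are hypothetical rather than part of the model built later.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: a flattened 32x32 image and an 8-dimensional bottleneck.
input_dim, bottleneck_dim = 32 * 32, 8

# Untrained linear encoder E and decoder D (weights are random here;
# training would adjust them to reduce the loss).
W_enc = rng.normal(scale=0.01, size=(input_dim, bottleneck_dim))
W_dec = rng.normal(scale=0.01, size=(bottleneck_dim, input_dim))

def encode(x):
    return x @ W_enc  # b = E(x)

def decode(b):
    return b @ W_dec  # x_hat = D(b)

x = rng.uniform(-1.0, 1.0, size=(4, input_dim))  # batch of 4 fake images
b = encode(x)
x_hat = decode(b)
loss = np.mean((x - x_hat) ** 2)  # squared-error reconstruction loss

print(b.shape, x_hat.shape, loss)
```

Here `b` is a continuous vector per image; the VQ layer introduced below is what turns such a bottleneck into a discrete one.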

## Setup <a class="anchor" id="setup"></a>

### Install and import packages

In [None]:
!pip install -q dm-haiku optax

In [None]:
import haiku as hk
import jax
import matplotlib.pyplot as plt
import optax
from jax import jit
from jax import numpy as jnp
from jax import random
import numpy as np
# TensorFlow used only for datasets:
from tensorflow.keras import datasets
import tensorflow as tf

### Get MNIST dataset

In [None]:
def load_mnist():
    (x_train, _), (x_test, _) = datasets.mnist.load_data()

    # Rescale images to [-1, 1]
    x_train = (x_train.astype(np.float32) / 255.0) * 2.0 - 1.0
    x_train = np.expand_dims(x_train, axis=-1)
    x_test = (x_test.astype(np.float32) / 255.0) * 2.0 - 1.0
    x_test = np.expand_dims(x_test, axis=-1)

    # For convenience later on, we'll pad MNIST from 28x28 to 32x32.
    # Black pixels are -1.
    pad_width = ((2, 2), (2, 2))
    x_train = np.pad(x_train, ((0, 0),) + pad_width + ((0, 0),), mode='constant', constant_values=-1)
    x_test = np.pad(x_test, ((0, 0),) + pad_width + ((0, 0),), mode='constant', constant_values=-1)

    return x_train, x_test

mnist_train, mnist_test = load_mnist()
print(mnist_train.shape, mnist_test.shape)

In [None]:
# Display the first few examples in the training set.
fig, ax = plt.subplots()
ax.imshow(np.hstack(mnist_train[:10].squeeze()), cmap='Greys_r')
fig.show()

## Implementing the model <a class="anchor" id="implementing-the-model"></a>

Haiku, like other neural network libraries, divides large computational graphs into _modules_ or _layers_ that together can be composed into an arbitrarily complex model. We'll start by implementing a module for vector quantization.

### Vector quantization layer

The vector quantization layer keeps track of one embedding vector for each discrete code in its codebook.

On the forward pass, the layer finds the nearest embedding for each vector in the input and replaces that vector with it. The index of the chosen embedding is the corresponding discrete code.
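
The nearest-embedding search can be written with a single matrix multiply by expanding the squared distance, $\|x - e\|^2 = \|x\|^2 - 2\,x \cdot e + \|e\|^2$. The NumPy sketch below checks this against a brute-force computation. It assumes the codebook is stored with one embedding per column, and the sizes are arbitrary illustration values.

```python
import numpy as np

rng = np.random.default_rng(0)
num_vectors, embedding_dim, num_embeddings = 5, 3, 4

flat_inputs = rng.normal(size=(num_vectors, embedding_dim))
# Codebook stored with one embedding per column, i.e. shape
# [embedding_dim, num_embeddings].
embeddings = rng.normal(size=(embedding_dim, num_embeddings))

# Brute force: materialize every pairwise difference.
diffs = flat_inputs[:, :, None] - embeddings[None, :, :]
brute_force = np.sum(diffs ** 2, axis=1)

# Expanded form: ||x||^2 - 2 x.e + ||e||^2, using one matrix multiply.
expanded = (np.sum(flat_inputs ** 2, axis=1, keepdims=True)
            - 2 * flat_inputs @ embeddings
            + np.sum(embeddings ** 2, axis=0, keepdims=True))

# Index of the nearest embedding, and the embedding itself, per input.
encoding_indices = np.argmin(expanded, axis=1)
quantized = embeddings.T[encoding_indices]

print(np.allclose(brute_force, expanded))
print(encoding_indices.shape, quantized.shape)
```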

Since this operation is not differentiable, on the backward pass, the gradients are passed through the original continuous representations (bypassing the discrete embeddings), using the straight-through estimator.
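
The straight-through estimator can be sketched in JAX as follows. The codebook, the helper `quantize_straight_through`, and `toy_loss` (a stand-in for a decoder plus reconstruction loss) are hypothetical names used only for this illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2-entry codebook, one row per code.
codebook = jnp.array([[0.0, 0.0],
                      [1.0, 1.0]])

def quantize_straight_through(z, codebook):
    # Nearest codebook row for each input vector.
    sq_dists = jnp.sum((z[:, None, :] - codebook[None, :, :]) ** 2, axis=-1)
    nearest = codebook[jnp.argmin(sq_dists, axis=1)]
    # Forward value equals `nearest`; the gradient w.r.t. z is the identity
    # because the stop_gradient term is treated as a constant.
    return z + jax.lax.stop_gradient(nearest - z)

def toy_loss(z, codebook):
    # Stand-in for a decoder plus reconstruction loss.
    return jnp.sum(quantize_straight_through(z, codebook) ** 2)

z = jnp.array([[0.8, 0.9]])
quantized = quantize_straight_through(z, codebook)
grad_z, grad_codebook = jax.grad(toy_loss, argnums=(0, 1))(z, codebook)

print(quantized)      # the nearest codebook row
print(grad_z)         # gradient evaluated at the quantized value
print(grad_codebook)  # all zeros: no gradient reaches the codebook this way
```

In this sketch the codebook receives no gradient from the downstream loss, which is why a VQ layer needs a separate loss term to update its embeddings.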

In [None]:
class VectorQuantizer(hk.Module):
    def __init__(self, embedding_dim, num_embeddings, commitment_cost):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.commitment_cost = commitment_cost

    def __call__(self, inputs, is_training):
        flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
        embeddings = hk.get_parameter(
            "embeddings",
            [self.embedding_dim, self.num_embeddings],
            init=hk.initializers.RandomUniform())

        # Quantization operation: compute (squared) distances, then find the
        # indices of the nearest neighbors. The original vectors are then
        # replaced with the corresponding entries in the codebook.

        # START OF SECTION TO FILL IN
        # Expanding ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2 avoids the need to
        # expand_dims on flat_inputs and broadcast over all pairs:
        distances = (jnp.sum(jnp.square(flat_inputs), axis=1, keepdims=True) -
                     2 * jnp.matmul(flat_inputs, embeddings) +
                     jnp.sum(jnp.square(embeddings), axis=0, keepdims=True))

        # The one-hot encodings are returned alongside the indices, which is
        # convenient for computing code-usage statistics later.
        encoding_indices = jnp.argmin(distances, axis=1)
        encodings = jax.nn.one_hot(encoding_indices, self.num_embeddings)
        quantized = jnp.take(embeddings.T, encoding_indices, axis=0)
        # alternative: quantized = embeddings.T[encoding_indices]

        # Straight-through estimator: identity in the forward pass
        # (== quantized), but the gradient w.r.t. quantized_st is passed
        # unchanged to flat_inputs. The raw `quantized` is kept under its own
        # name so that the losses below can still update the codebook.
        quantized_st = flat_inputs + jax.lax.stop_gradient(quantized - flat_inputs)
        # END OF SECTION TO FILL IN

        # Losses: besides the reconstruction loss (already implemented below, as
        # VAEs have not yet been covered; you will see them later in the week),
        # we need two more losses for the VQ operation:
        # FILL IN THE BLANKS:
        # - e_latent_loss: encourages the encoder's output (i.e., the continuous
        #   representations) to be close to the quantized embeddings
        e_latent_loss = jnp.mean(jnp.square(jax.lax.stop_gradient(quantized) - flat_inputs))
        # - q_latent_loss: updates the learned embeddings in the codebook to
        #   better represent the continuous representations.
        q_latent_loss = jnp.mean(jnp.square(quantized - jax.lax.stop_gradient(flat_inputs)))

        loss = q_latent_loss + self.commitment_cost * e_latent_loss

        return {
            "quantize": jnp.reshape(quantized_st, inputs.shape),
            "loss": loss,
            "encodings": encodings,
            "encoding_indices": jnp.reshape(encoding_indices, inputs.shape[:-1])
        }

### Convolutional encoder and decoder

The architecture here closely follows that in the [VQ-VAE paper](https://arxiv.org/abs/1711.00937).

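The encoder consists of several strided convolutional layers followed by a stack of ResNet-style residual blocks. The decoder is built similarly, using transposed convolutions to upsample back to the input resolution.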
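
To see how the spatial resolution changes through the network, the short sketch below traces the height/width through three stride-2 convolutions and three stride-2 transposed convolutions. It assumes "SAME" padding for both layer types and that the residual blocks (stride 1) leave the size unchanged. The helper names are hypothetical.

```python
import math

def same_conv_size(size, stride):
    # Output length of a strided convolution with "SAME" padding.
    return math.ceil(size / stride)

def same_conv_transpose_size(size, stride):
    # Output length of a strided transposed convolution with "SAME" padding.
    return size * stride

size = 32                    # padded MNIST height/width
encoder_strides = [2, 2, 2]  # enc_1, enc_2, enc_3
decoder_strides = [2, 2, 2]  # dec_1, dec_2, dec_3

encoder_sizes = [size]
for stride in encoder_strides:
    encoder_sizes.append(same_conv_size(encoder_sizes[-1], stride))

decoder_sizes = [encoder_sizes[-1]]
for stride in decoder_strides:
    decoder_sizes.append(same_conv_transpose_size(decoder_sizes[-1], stride))

print(encoder_sizes)  # [32, 16, 8, 4]
print(decoder_sizes)  # [4, 8, 16, 32]
```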



In [None]:
class ResidualStack(hk.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(ResidualStack, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._layers = []
    for i in range(num_residual_layers):
      conv3 = hk.Conv2D(
          output_channels=num_residual_hiddens,
          kernel_shape=(3, 3),
          stride=(1, 1),
          name="res3x3_%d" % i)
      conv1 = hk.Conv2D(
          output_channels=num_hiddens,
          kernel_shape=(1, 1),
          stride=(1, 1),
          name="res1x1_%d" % i)
      self._layers.append((conv3, conv1))

  def __call__(self, inputs):
    h = inputs
    for conv3, conv1 in self._layers:
      conv3_out = conv3(jax.nn.relu(h))
      conv1_out = conv1(jax.nn.relu(conv3_out))
      h += conv1_out
    return jax.nn.relu(h)  # Resnet V1 style


class Encoder(hk.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Encoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._enc_1 = hk.Conv2D(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = hk.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = hk.Conv2D(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),  # stride 2 (was (1, 1)): 32x32 inputs end up as a 4x4 latent grid
        name="enc_3")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    h = jax.nn.relu(self._enc_1(x))
    h = jax.nn.relu(self._enc_2(h))
    h = jax.nn.relu(self._enc_3(h))
    return self._residual_stack(h)


class Decoder(hk.Module):
  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super(Decoder, self).__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    self._dec_1 = hk.Conv2DTranspose(
        output_channels=self._num_hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_1")
    self._residual_stack = ResidualStack(
        self._num_hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = hk.Conv2DTranspose(
        output_channels=self._num_hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = hk.Conv2DTranspose(
        output_channels=1,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")

  def __call__(self, x):
    h = self._dec_1(x)
    h = self._residual_stack(h)
    h = jax.nn.relu(self._dec_2(h))
    x_recon = self._dec_3(h)
    return x_recon


class VQVAEModel(hk.Module):
  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,
               data_variance, name=None):
    super(VQVAEModel, self).__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    z = self._pre_vq_conv1(self._encoder(inputs))
    vq_output = self._vqvae(z, is_training=is_training)
    x_recon = self._decoder(vq_output['quantize'])
    recon_error = jnp.mean((x_recon - inputs) ** 2) / self._data_variance
    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

In [None]:
batch_size = 64
num_hiddens = 64
num_residual_hiddens = 32
num_residual_layers = 2
embedding_dim = 32
num_embeddings = 16
commitment_cost = 0.25
learning_rate = 1e-4

# For scaling reconstruction error
train_data_variance = np.var(mnist_train)

# Build modules.
def forward(data, is_training):
    encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    pre_vq_conv1 = hk.Conv2D(
        output_channels=embedding_dim,
        kernel_shape=(1, 1),
        stride=(1, 1),
        name="to_vq")

    # Use the VectorQuantizer class implemented above.
    vq_vae = VectorQuantizer(
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
        commitment_cost=commitment_cost)

    model = VQVAEModel(encoder, decoder, vq_vae, pre_vq_conv1, data_variance=train_data_variance)

    return model(data, is_training)

forward = hk.transform_with_state(forward)
optimizer = optax.adam(learning_rate)

@jax.jit
def train_step(params, state, opt_state, data):
    def adapt_forward(params, state, data):
        model_output, state = forward.apply(params, state, None, data, is_training=True)
        loss = model_output['loss']
        return loss, (model_output, state)

    grads, (model_output, state) = jax.grad(adapt_forward, has_aux=True)(params, state, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, state, opt_state, model_output

## Training <a class="anchor" id="training"></a>

In [None]:
# Make minibatch iterators over the MNIST data using TensorFlow dataset API.
train_dataset = tf.data.Dataset.from_tensor_slices(mnist_train).shuffle(10000).repeat().batch(batch_size)
valid_dataset = tf.data.Dataset.from_tensor_slices(mnist_test).batch(batch_size)

In [None]:
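# Initialize model parameters, model state and optimizer state.
# With hk.transform_with_state, init returns a (params, state) pair.
dummy_input = jnp.zeros((1, 32, 32, 1))
params, state = forward.init(jax.random.PRNGKey(42), dummy_input, is_training=True)
opt_state = optimizer.init(params)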

Training for 20000 updates takes approximately 2 minutes with a T4 GPU.

In [None]:
%%time
num_training_updates = 20000

# Lists to keep track of metrics
train_losses = []
train_recon_errors = []
train_vqvae_loss = []

# Initialization
rng = jax.random.PRNGKey(42)
train_dataset_iter = iter(train_dataset)

# Initialize model parameters and optimizer state
dummy_data = next(train_dataset_iter).numpy()
params, state = forward.init(rng, dummy_data, is_training=True)
opt_state = optimizer.init(params)

# Training loop
for step in range(1, num_training_updates + 1):
    data = next(train_dataset_iter).numpy()
    params, state, opt_state, train_results = train_step(params, state, opt_state, data)

    train_results = jax.device_get(train_results)
    train_losses.append(train_results['loss'])
    train_recon_errors.append(train_results['recon_error'])
    train_vqvae_loss.append(train_results['vq_output']['loss'])

    if step % 1000 == 0:
        print(
            f'[Step {step}/{num_training_updates}] ' +
            f'train loss: {np.mean(train_losses[-100:]):.3f} ' +
            f'recon_error: {np.mean(train_recon_errors[-100:]):.3f} ' +
            f'vqvae loss: {np.mean(train_vqvae_loss[-100:]):.3f}'
        )


In [None]:
def plot_reconstructions(originals, recons, n=8):
    '''Plots original and reconstructed images side by side'''

    plt.figure(figsize=(20, 5))
    for i in range(n):
        # Display original (images are in [-1, 1]; rescale to [0, 1] for display)
        ax = plt.subplot(2, n, i + 1)
        plt.imshow(np.clip((originals[i].squeeze() + 1) / 2, 0, 1), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title('Original')

        # Display reconstruction (same rescaling as the original)
        ax = plt.subplot(2, n, i + 1 + n)
        plt.imshow(np.clip((recons[i].squeeze() + 1) / 2, 0, 1), cmap='gray')
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.set_title('Reconstruction')

    plt.tight_layout()
    plt.show()

# Obtain a batch from the test set
test_batch = next(iter(valid_dataset)).numpy()

# Get the reconstructions using the trained VQ-VAE
rng = jax.random.PRNGKey(42)
result, _ = forward.apply(params, state, rng, test_batch, is_training=False)
reconstructions = result['x_recon']

# Plot some inputs alongside reconstructions:
plot_reconstructions(test_batch, reconstructions)

In [None]:
# Obtain the discrete representation of the first digit
discrete_representation = result['vq_output']['encoding_indices'][0]  # Taking the first element

# Display the discrete representation
plt.figure(figsize=(10, 10))
plt.imshow(discrete_representation, cmap='tab20', aspect='auto')
plt.colorbar()
plt.title('Discrete Representation')
plt.show()

## Analysis <a class="anchor" id="analysis"></a>

1. Fill in the blanks above in the `VectorQuantizer` class and make sure the code runs correctly. What is the shape of the learned discrete representation? (See the visualization and/or try printing it from the `VectorQuantizer` class.)
2. The codebook consists of a fixed set of embeddings. How would the model behave if the size of this codebook (number of embeddings) is increased or decreased significantly?
3. Discuss potential real-world applications where learned discrete representations may be advantageous.
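
As a possible starting point for questions 1 and 2, the sketch below counts how many bits a grid of code indices needs compared with raw 8-bit pixels, and measures how evenly the codes are used. The 4x4 grid shape is an assumption to check against your own output, and the random `encoding_indices` are placeholder data standing in for `result['vq_output']['encoding_indices']`.

```python
import numpy as np

num_embeddings = 16  # codebook size used above
grid_shape = (4, 4)  # assumed latent grid for one 32x32 image

# Bits needed to store one image as code indices vs. as 8-bit pixels.
bits_per_code = np.log2(num_embeddings)
latent_bits = int(np.prod(grid_shape) * bits_per_code)
pixel_bits = 32 * 32 * 8
print(latent_bits, pixel_bits)  # 64 8192

# Placeholder batch of encoding indices (random, for illustration only).
rng = np.random.default_rng(0)
encoding_indices = rng.integers(0, num_embeddings, size=(64,) + grid_shape)

# Fraction of positions assigned to each code, and the perplexity
# exp(entropy), which ranges from 1 (one code used) to num_embeddings
# (all codes used equally often).
counts = np.bincount(encoding_indices.ravel(), minlength=num_embeddings)
usage = counts / counts.sum()
perplexity = np.exp(-np.sum(usage * np.log(usage + 1e-10)))
print(usage.round(3), perplexity)
```

A perplexity far below `num_embeddings` on real encodings would suggest that part of the codebook is going unused.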