We are going to implement the VQ-VAE described in this paper: https://arxiv.org/pdf/1711.00937

    As a first experiment we compare VQ-VAE with normal VAEs (with continuous variables), as well as VIMCO [28] with independent Gaussian or categorical priors. We train these models using the same
    standard VAE architecture on CIFAR10, while varying the latent capacity (number of continuous or
    discrete latent variables, as well as the dimensionality of the discrete space K).
    
    
    The encoder consists of:
    
    - 2 strided convolutional layers with stride 2 and window size 4 × 4
    - two residual 3 × 3 blocks (implemented as ReLU, 3x3 conv, ReLU, 1x1 conv), all having 256 hidden units.

    The decoder similarly has:
    
    - two residual 3 × 3 blocks, followed by
    - two transposed convolutions with stride 2 and window size 4 × 4. 
    
    We use the ADAM optimiser [21] with learning rate 2e-4 and evaluate the performance after 250,000 steps with batch-size 128. For VIMCO we use 50 samples in the multi-sample training objective.

Unfortunately, the paper doesn't seem to specify the order of operations inside the residual blocks, i.e. at which point the identity (skip) connection is added. We'll assume it is added at the end.

The code is based on https://github.com/google-deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py, which is the official implementation.


In [75]:
import jax
import equinox as eqx
import equinox.nn as nn

class ResBlock(eqx.Module):
    """Residual block: conv -> BatchNorm -> ReLU -> conv -> BatchNorm, add input, ReLU.

    Both convolutions map ``dim`` -> ``dim`` channels with "SAME" padding, so the
    output has the same shape as the input and the identity skip can be added
    directly.

    NOTE: the VQ-VAE paper describes the block as ReLU, 3x3 conv, ReLU, 1x1 conv,
    without normalisation. This implementation adds BatchNorm and applies the
    final ReLU after the skip connection. Pass ``kernel_sizes=(3, 1)`` to at least
    match the paper's kernel sizes; the default ``(5, 5)`` keeps the original block.
    """

    layers: list  # [first Conv2d, second Conv2d]
    norm1: nn.BatchNorm  # normalises the output of layers[0]
    norm2: nn.BatchNorm  # normalises the output of layers[1]

    def __init__(self, dim, key, kernel_sizes=(5, 5), axis_name="batch2"):
        """Build the block.

        Args:
            dim: number of input and output channels.
            key: PRNG key, split between the two convolutions.
            kernel_sizes: ``(first, second)`` square kernel sizes of the two
                convolutions.
            axis_name: name of the vmapped batch axis the BatchNorm layers reduce
                over. The caller must ``jax.vmap`` with the same ``axis_name``,
                otherwise the axis is unbound.
        """
        key1, key2 = jax.random.split(key, 2)
        first_k, second_k = kernel_sizes

        self.layers = [
            nn.Conv2d(dim, dim, (first_k, first_k), padding="SAME", key=key1),
            nn.Conv2d(dim, dim, (second_k, second_k), padding="SAME", key=key2),
        ]
        self.norm1 = nn.BatchNorm(input_size=dim, axis_name=axis_name, momentum=0.9, dtype=jax.numpy.float32)
        self.norm2 = nn.BatchNorm(input_size=dim, axis_name=axis_name, momentum=0.9, dtype=jax.numpy.float32)

    def __call__(self, x, state):
        """Apply the block to a single (unbatched) example.

        Args:
            x: channel-first input with ``dim`` channels, as expected by
                ``eqx.nn.Conv2d`` for one example.
            state: equinox ``State`` holding the BatchNorm running statistics.

        Returns:
            ``(y, state)``: the output, with the same shape as ``x``, and the
            updated state.
        """
        y = self.layers[0](x)
        y, state = self.norm1(y, state)
        y = jax.nn.relu(y)

        y = self.layers[1](y)
        y, state = self.norm2(y, state)

        # Skip connection added at the end, then the final activation.
        y = jax.nn.relu(y + x)
        return y, state


class Encoder(eqx.Module):
    """VQ-VAE encoder: two 4x4 stride-2 convolutions followed by two ResBlocks.

    Each convolution uses "SAME" padding with stride 2, so it halves (rounding
    up) each spatial dimension. The channel count goes
    ``in_channels -> dim // 2 -> dim``, as in the Sonnet reference encoder. With
    ``dim=4`` and ``in_channels=1`` this is exactly the original 1 -> 2 -> 4
    network.
    """

    layers: list  # [Conv2d, Conv2d, ResBlock, ResBlock], applied in order

    def __init__(self, dim, key, in_channels=1):
        """Build the encoder.

        Args:
            dim: channel count of the latent output (and of the ResBlocks). The
                first convolution outputs ``dim // 2`` channels, so ``dim`` must
                be at least 2.
            key: PRNG key used to initialise all layers.
            in_channels: number of channels of the input image.

        Raises:
            ValueError: if ``dim < 2``.
        """
        if dim < 2:
            raise ValueError(f"dim must be >= 2, got {dim}")

        key1, key2, key3, key4 = jax.random.split(key, 4)
        hidden = dim // 2

        self.layers = [
            nn.Conv2d(in_channels, hidden, (4, 4), 2, padding="SAME", key=key1),
            nn.Conv2d(hidden, dim, (4, 4), 2, padding="SAME", key=key3),
            ResBlock(dim, key=key2),
            ResBlock(dim, key=key4),
        ]

    def __call__(self, x, state):
        """Encode one example; returns ``(latent, state)``.

        ``state`` carries the ResBlocks' BatchNorm statistics and is threaded
        through every stateful layer. The plain convolutions are stateless.
        """
        y = x
        for layer in self.layers:
            if isinstance(layer, ResBlock):
                y, state = layer(y, state)
            else:
                y = layer(y)

        return y, state

In [76]:
import optax

@eqx.filter_jit
def forward(model, x, state):
    """Apply ``model`` to every example of the batch ``x``.

    ``state`` (the BatchNorm statistics) is shared across the batch
    (``in_axes=None`` / ``out_axes=None``). The mapped axis is named "batch2"
    because that is the axis name the ResBlock BatchNorm layers reduce over.
    """
    return jax.vmap(model, axis_name="batch2", in_axes=(0, None), out_axes=(0, None))(x, state)


@eqx.filter_jit
@eqx.filter_grad(has_aux=True)
def loss(model, x, y, state):
    """Mean absolute error between ``model(x)`` and ``y``.

    Because of ``filter_grad(has_aux=True)``, a call returns ``(grads, state)``.
    These are the gradients w.r.t. the model's floating-point arrays and the
    updated BatchNorm state, not the loss value itself.
    """
    result, state = forward(model, x, state)
    mae = jax.numpy.mean(jax.numpy.abs(result - y))
    return mae, state


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(69), 3)

model, state = eqx.nn.make_with_state(Encoder)(dim=4, key=key1)

# ADAM with the paper's learning rate (2e-4). Only floating-point arrays are
# optimised, which matches what filter_grad differentiates.
optimizer = optax.adam(2e-4)
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))

x = jax.random.normal(key2, (10, 1, 100, 100))
# Target has the encoder's output shape. The two stride-2 convs take
# 100x100 -> 25x25, and the output has dim=4 channels. A 1-channel target would
# silently broadcast and hide shape bugs.
y = jax.random.normal(key3, (10, 4, 25, 25))

grads, state = loss(model, x, y, state)
updates, opt_state = optimizer.update(grads, opt_state, eqx.filter(model, eqx.is_inexact_array))
model = eqx.apply_updates(model, updates)

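Currently the two stride-2 convolutional layers downsample each spatial dimension by a factor of 4 (100×100 → 25×25), i.e. the number of spatial positions in the embedded picture shrinks by a factor of 4×4 = 16.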

In [32]:
class Decoder(eqx.Module):
    """VQ-VAE decoder: two ResBlocks followed by two 4x4 stride-2 transposed convolutions.

    Mirrors ``Encoder``. The transposed convolutions (stride 2, "SAME" padding)
    are intended to undo the encoder's downsampling. The channel count goes
    ``dim -> dim // 2 -> out_channels``; with ``dim=4`` and ``out_channels=1``
    this is exactly the original 4 -> 2 -> 1 network. A final sigmoid squashes
    the outputs into (0, 1).
    """

    layers: list  # [ResBlock, ResBlock, ConvTranspose2d, ConvTranspose2d], applied in order

    def __init__(self, dim, key, out_channels=1):
        """Build the decoder.

        Args:
            dim: channel count of the latent input (and of the ResBlocks). The
                first transposed convolution outputs ``dim // 2`` channels, so
                ``dim`` must be at least 2.
            key: PRNG key used to initialise all layers.
            out_channels: number of channels of the reconstructed image.

        Raises:
            ValueError: if ``dim < 2``.
        """
        if dim < 2:
            raise ValueError(f"dim must be >= 2, got {dim}")

        key1, key2, key3, key4 = jax.random.split(key, 4)
        hidden = dim // 2

        self.layers = [
            ResBlock(dim, key=key2),
            ResBlock(dim, key=key4),
            nn.ConvTranspose2d(dim, hidden, (4, 4), 2, padding="SAME", key=key1),
            nn.ConvTranspose2d(hidden, out_channels, (4, 4), 2, padding="SAME", key=key3),
        ]

    def __call__(self, x, state):
        """Decode one latent; returns ``(reconstruction, state)``.

        ``state`` carries the ResBlocks' BatchNorm statistics. The output passes
        through a sigmoid, so every value lies in (0, 1).
        """
        y = x
        for layer in self.layers:
            if isinstance(layer, ResBlock):
                y, state = layer(y, state)
            else:
                y = layer(y)

        y = jax.nn.sigmoid(y)

        return y, state

In [None]:
@eqx.filter_jit
def forward(model, x, state):
    """Apply ``model`` to every example of the batch ``x``, sharing ``state``.

    The mapped axis must be named "batch2": that is the axis name the BatchNorm
    layers inside ResBlock reduce over. Any other name leaves that axis unbound
    and the call fails.
    """
    return jax.vmap(model, axis_name="batch2", in_axes=(0, None), out_axes=(0, None))(x, state)


@eqx.filter_jit
@eqx.filter_grad(has_aux=True)
def loss(model, x, y, state):
    """Return ``(grads, state)`` for the mean absolute error between ``model(x)`` and ``y``."""
    result, state = forward(model, x, state)
    mae = jax.numpy.mean(jax.numpy.abs(result - y))
    return mae, state


key1, key2, key3 = jax.random.split(jax.random.PRNGKey(69), 3)

# Decoder.__init__ takes (dim, key); dim must equal the latent channel count.
model, state = eqx.nn.make_with_state(Decoder)(dim=4, key=key1)

# Latents at the encoder's output resolution (4 channels, 25x25). Target images
# are 1-channel 100x100 with values in [0, 1], matching the decoder's sigmoid.
latents = jax.random.normal(key2, (10, 4, 25, 25))
images = jax.random.uniform(key3, (10, 1, 100, 100))

grads, state = loss(model, latents, images, state)

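We now have the encoder and decoder of an autoencoder: inputs are encoded into a latent space and decoded back. In a standard variational autoencoder, the encoder would output not just a single latent vector but two values describing the distribution of each input: a mean and a variance. A VQ-VAE instead makes the latent space discrete. Each encoder output vector is replaced by its nearest entry in a learned codebook, which is what the quantizer below implements.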

In [57]:
import jax
import equinox as eqx
import jax.numpy as np

class Quantizer(eqx.Module):
    """Vector-quantisation bottleneck of a VQ-VAE.

    Holds a codebook of ``K`` vectors of dimension ``D``. The input is viewed as
    rows of length ``D``, and each row is replaced by its nearest codebook vector
    (squared Euclidean distance). The result is reshaped back to the input's
    shape. A straight-through estimator copies gradients from the quantised
    output to the input, so the encoder can be trained through the
    non-differentiable nearest-neighbour lookup.
    """

    K: int  # number of codebook vectors
    D: int  # dimensionality of each codebook vector
    codebook: jax.Array  # (K, D) learnable embedding table

    def __init__(self, num_vecs, num_dims, key):
        """Create a ``num_vecs`` x ``num_dims`` codebook initialised from ``key``."""
        self.K = num_vecs
        self.D = num_dims

        # The codebook vectors are parameters that move during training.
        self.codebook = jax.nn.initializers.variance_scaling(scale=1.0, mode="fan_in", distribution="uniform")(key, (self.K, self.D))

    def __call__(self, x, return_loss=False, commitment_cost=0.25):
        """Quantise ``x``.

        Args:
            x: array whose total size is a multiple of ``D``. It is reshaped
                (row-major) into rows of length ``D``.
            return_loss: if True, also return the VQ latent loss. The default
                False keeps the single-output behaviour.
            commitment_cost: weight (beta) of the commitment term; the paper
                uses 0.25.

        Returns:
            ``z_q`` with the same shape as ``x``. If ``return_loss`` is True,
            returns ``(z_q, loss)`` instead, where
            ``loss = mean((sg[x] - e)^2) + commitment_cost * mean((x - sg[e])^2)``.
            Here ``e`` are the selected codebook vectors and ``sg`` is
            stop-gradient. This loss is the only path through which the codebook
            receives gradients.
        """
        flat_x = np.reshape(x, (-1, self.D))

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, evaluated as an (N, K) table.
        a_squared = np.sum(flat_x**2, axis=-1, keepdims=True)  # (N, 1)
        b_squared = np.sum(self.codebook**2, axis=-1)[None, :]  # (1, K)
        distance = a_squared + b_squared - 2 * np.matmul(flat_x, self.codebook.T)

        # One index per row of flat_x. N is flat_x.shape[0], which generally
        # differs from x.shape[0].
        encoding_indices = np.argmin(distance, axis=-1)
        quantized = self.codebook[encoding_indices]

        # Straight-through estimator: the forward value is `quantized`, and the
        # gradient w.r.t. x is the identity.
        z_q = flat_x + jax.lax.stop_gradient(quantized - flat_x)
        z_q = np.reshape(z_q, x.shape)

        if not return_loss:
            return z_q

        codebook_loss = np.mean((jax.lax.stop_gradient(flat_x) - quantized) ** 2)
        commitment_loss = np.mean((flat_x - jax.lax.stop_gradient(quantized)) ** 2)
        return z_q, codebook_loss + commitment_cost * commitment_loss



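The quantizer above creates K codebook vectors in D-dimensional space. Each incoming vector is replaced by its nearest codebook vector (in squared L2 distance), and the straight-through estimator copies gradients from the quantized output back to the input. The VQ-VAE training losses (the codebook and commitment terms) are squared L2 distances between the incoming vectors and their matched codebook vectors.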

In [None]:
import jax
import jax.numpy as np

@eqx.filter_jit
def forward(model, x):
    """Quantise each example of the batch ``x`` with the straight-through Quantizer."""
    return jax.vmap(model)(x)


@eqx.filter_jit
def loss(model, x, commitment_cost=0.25):
    """VQ latent loss: codebook term + ``commitment_cost`` * commitment term.

    The earlier MSE between ``x`` and ``forward(model, x)`` has zero gradient
    everywhere. The straight-through output equals ``x + sg(q - x)``, so
    ``x - output`` is a constant under differentiation, and the codebook never
    learned. Here the nearest codebook vectors are looked up directly, so the
    codebook term trains the codebook and the commitment term trains the input.
    """
    flat_x = np.reshape(x, (-1, model.D))
    distance = (
        np.sum(flat_x**2, axis=-1, keepdims=True)
        + np.sum(model.codebook**2, axis=-1)[None, :]
        - 2 * np.matmul(flat_x, model.codebook.T)
    )
    quantized = model.codebook[np.argmin(distance, axis=-1)]

    codebook_loss = np.mean((jax.lax.stop_gradient(flat_x) - quantized) ** 2)
    commitment_loss = np.mean((flat_x - jax.lax.stop_gradient(quantized)) ** 2)
    return codebook_loss + commitment_cost * commitment_loss


key1, key2 = jax.random.split(jax.random.PRNGKey(69), 2)

model = Quantizer(10, 7 * 7, key=key1)
x = jax.random.normal(key2, (64, 4, 7, 7))

z_q = forward(model, x)  # straight-through output, same shape as x

# filter_grad differentiates only the floating-point arrays (the codebook), so
# the integer fields K and D no longer need jax.grad(..., allow_int=True).
grads = eqx.filter_grad(loss)(model, x)
print(grads)
