# Install
```sh
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install dm-haiku optax
python3 -m pip install tensorflow[and-cuda]
pip install tensorflow_datasets
pip install matplotlib
pip install jupyter
```

In [16]:
import haiku as hk
import jax
import optax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import tensorflow.compat.v2 as tf
import tensorflow_datasets as tfds
import pickle
import os
import imageio

# `tf` is the `tensorflow.compat.v2` namespace here; explicitly switch on
# v2 behaviour before any tf.data pipeline is built.
tf.enable_v2_behavior()

# Record the library versions this notebook ran with.
print("JAX version {}".format(jax.__version__))
print("Haiku version {}".format(hk.__version__))
print("TF version {}".format(tf.__version__))

JAX version 0.4.25
Haiku version 0.0.12
TF version 2.16.1


# Download Cifar10 data

In [17]:
# cifar10 = tfds.as_numpy(tfds.load("cifar10", split="train+test", batch_size=-1))
# del cifar10["id"], cifar10["label"]
# jax.tree_util.tree_map(lambda x: f'{x.dtype.name}{list(x.shape)}', cifar10)

In [18]:
# print(cifar10.keys())

# Load BouncingBall data (read from the local `./images` folder)

In [19]:
def load_images_from_folder(folder):
    """Read every regular file in ``folder`` as an image and stack them.

    Files are visited in sorted (lexicographic) filename order so that the
    returned array, and any positional train/valid/test split taken from it,
    is the same on every run; ``os.listdir`` on its own returns entries in
    arbitrary order. Entries that are not regular files (for example a
    ``.ipynb_checkpoints`` sub-directory) are skipped instead of being handed
    to the image reader.

    NOTE: lexicographic order puts ``img_10`` before ``img_2``; zero-pad the
    file names if temporal order matters.

    Args:
        folder: Path of the directory that contains the image files.

    Returns:
        ``np.array`` built from the list of decoded images. When the folder
        holds no files the result is an empty array of shape ``(0,)``.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if not os.path.isfile(path):
            # Directories and other non-file entries cannot be decoded.
            continue
        img = imageio.imread(path)
        # Defensive check kept from the original implementation.
        if img is not None:
            images.append(img)
    return np.array(images)

# Load the custom (BouncingBall) dataset: every image found in `image_folder`
# is read into one NumPy array. The first axis indexes images (it is read as
# `custom_images.shape[0]` further down).
image_folder = './images'
custom_images = load_images_from_folder(image_folder)

  img = imageio.imread(os.path.join(folder, filename))


# Load the data into Numpy

In [20]:
# train_data_dict = jax.tree_util.tree_map(lambda x: x[:40000], cifar10)
# valid_data_dict = jax.tree_util.tree_map(lambda x: x[40000:50000], cifar10)
# test_data_dict = jax.tree_util.tree_map(lambda x: x[50000:], cifar10)

In [21]:
# BouncingBall: positional train / validation / test split.
total_images = custom_images.shape[0]

# 70-15-15 split. Both boundaries are truncated with int(), so any rounding
# remainder ends up in the test slice.
train_end = int(total_images * 0.7)
valid_end = train_end + int(total_images * 0.15)

# Contiguous slices in load order; the images are not shuffled before the
# split.
train_images = custom_images[:train_end]
valid_images = custom_images[train_end:valid_end]
test_images = custom_images[valid_end:]

In [22]:
# Sanity check: total number of images and size of the training slice.
print(total_images, train_images.shape[0])

2000 1400


## Create dictionaries


In [23]:
# Wrap each split in a dict keyed by 'image'; the tf.data pipelines and
# `forward` both look the images up under that key.
train_data_dict = {'image': train_images}
valid_data_dict = {'image': valid_images}
test_data_dict = {'image': test_images}
# Variance of the training pixels after scaling by 1/255. It is passed to
# VQVAEModel as `data_variance`, but the loss line that divides by it is
# currently commented out there.
train_data_variance = np.var(train_data_dict['image'] / 255.0)

In [24]:
# # CIFAR10
# def cast_and_normalise_images(data_dict):
#   """Convert images to floating point with the range [-0.5, 0.5]"""
#   data_dict['image'] = (tf.cast(data_dict['image'], tf.float32) / 255.0) - 0.5
#   return data_dict

# train_data_variance = np.var(train_data_dict['image'] / 255.0)
# print('train data variance: %s' % train_data_variance)

In [25]:
def cast_and_normalise_images(features):
    """Cast ``features['image']`` to float32 and divide it by 255.

    The dict is updated in place and the same dict is returned, which is the
    shape ``tf.data.Dataset.map`` expects from a mapping function.

    Args:
        features: Dict with an 'image' entry. The scaling assumes uint8
            pixel values in [0, 255] (assumption carried over from the
            original comment; not checked here).

    Returns:
        ``features`` with the 'image' entry replaced by the scaled tensor.
    """
    as_float = tf.cast(features['image'], tf.float32)
    features['image'] = as_float / 255.0
    return features


# Encoder & Decoder Architecture

In [26]:
class ResidualStack(hk.Module):
  """Sequence of residual blocks followed by a final ReLU.

  Each block computes ``ReLU -> 3x3 conv -> ReLU -> 1x1 conv`` and adds the
  result to its input. The 3x3 conv has ``num_residual_hiddens`` output
  channels and the 1x1 conv maps back to ``num_hiddens`` channels, so the
  input is expected to have ``num_hiddens`` channels for the addition to work.

  Args:
    num_hiddens: Output channels of every 1x1 conv.
    num_residual_layers: Number of residual blocks.
    num_residual_hiddens: Output channels of every 3x3 conv.
    name: Optional Haiku module name.
  """

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super().__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    # One (3x3 conv, 1x1 conv) pair per residual block, created in block
    # order so the layer names stay "res3x3_<i>" / "res1x1_<i>".
    self._layers = [
        (
            hk.Conv2D(
                output_channels=num_residual_hiddens,
                kernel_shape=(3, 3),
                stride=(1, 1),
                name="res3x3_%d" % block_index),
            hk.Conv2D(
                output_channels=num_hiddens,
                kernel_shape=(1, 1),
                stride=(1, 1),
                name="res1x1_%d" % block_index),
        )
        for block_index in range(num_residual_layers)
    ]

  def __call__(self, inputs):
    """Apply all residual blocks to ``inputs`` and a ReLU to the sum."""
    out = inputs
    for conv3x3, conv1x1 in self._layers:
      residual = conv1x1(jax.nn.relu(conv3x3(jax.nn.relu(out))))
      out += residual
    # Resnet V1 style: activation after the last skip connection.
    return jax.nn.relu(out)


class Encoder(hk.Module):
  """Convolutional encoder: three convs with ReLU, then a residual stack.

  The first two convs use 4x4 kernels with stride 2, the third a 3x3 kernel
  with stride 1. In the recorded run, 64x64 inputs produced 16x16 latents.

  Args:
    num_hiddens: Channels of the second and third conv (the first conv uses
      ``num_hiddens // 2``) and of the residual stack.
    num_residual_layers: Number of blocks in the residual stack.
    num_residual_hiddens: Hidden channels inside each residual block.
    name: Optional Haiku module name.
  """

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super().__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    hiddens = self._num_hiddens
    self._enc_1 = hk.Conv2D(
        output_channels=hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_1")
    self._enc_2 = hk.Conv2D(
        output_channels=hiddens,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="enc_2")
    self._enc_3 = hk.Conv2D(
        output_channels=hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="enc_3")
    self._residual_stack = ResidualStack(
        hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)

  def __call__(self, x):
    """Encode ``x``; every conv is followed by a ReLU."""
    h = x
    for conv in (self._enc_1, self._enc_2, self._enc_3):
      h = jax.nn.relu(conv(h))
    return self._residual_stack(h)


class Decoder(hk.Module):
  """Convolutional decoder: 3x3 conv, residual stack, two transposed convs.

  Both transposed convs use 4x4 kernels with stride 2. The last one emits
  3 channels and has no activation after it.

  Args:
    num_hiddens: Channels of the first conv and of the residual stack (the
      first transposed conv uses ``num_hiddens // 2``).
    num_residual_layers: Number of blocks in the residual stack.
    num_residual_hiddens: Hidden channels inside each residual block.
    name: Optional Haiku module name.
  """

  def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
               name=None):
    super().__init__(name=name)
    self._num_hiddens = num_hiddens
    self._num_residual_layers = num_residual_layers
    self._num_residual_hiddens = num_residual_hiddens

    hiddens = self._num_hiddens
    self._dec_1 = hk.Conv2D(
        output_channels=hiddens,
        kernel_shape=(3, 3),
        stride=(1, 1),
        name="dec_1")
    self._residual_stack = ResidualStack(
        hiddens,
        self._num_residual_layers,
        self._num_residual_hiddens)
    self._dec_2 = hk.Conv2DTranspose(
        output_channels=hiddens // 2,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_2")
    self._dec_3 = hk.Conv2DTranspose(
        output_channels=3,
        kernel_shape=(4, 4),
        stride=(2, 2),
        name="dec_3")

  def __call__(self, x):
    """Decode latents ``x`` into a reconstruction (no final activation)."""
    features = self._residual_stack(self._dec_1(x))
    upsampled = jax.nn.relu(self._dec_2(features))
    return self._dec_3(upsampled)
    

class VQVAEModel(hk.Module):
  """Full VQ-VAE: encoder -> pre-VQ conv -> vector quantiser -> decoder.

  Args:
    encoder: Module mapping images to feature maps.
    decoder: Module mapping quantised latents back to images.
    vqvae: Quantiser called as ``vqvae(z, is_training=...)``; its output must
      provide the 'quantize' and 'loss' entries.
    pre_vq_conv1: Module applied to the encoder output before quantisation.
    data_variance: Stored on the module; the loss normalisation that would
      use it is currently disabled (see ``__call__``).
    name: Optional Haiku module name.
  """

  def __init__(self, encoder, decoder, vqvae, pre_vq_conv1,
               data_variance, name=None):
    super().__init__(name=name)
    self._encoder = encoder
    self._decoder = decoder
    self._vqvae = vqvae
    self._pre_vq_conv1 = pre_vq_conv1
    self._data_variance = data_variance

  def __call__(self, inputs, is_training):
    """Run the model on ``inputs`` and return outputs plus losses.

    Returns:
      Dict with 'z' (pre-quantisation latents), 'x_recon', 'loss'
      (reconstruction error + quantiser loss), 'recon_error' and the raw
      'vq_output' of the quantiser.
    """
    # Debug shape logging (presumably only emitted while tracing when the
    # caller is jitted -- confirm).
    print(f"x shape: {inputs.shape}")

    encoded = self._encoder(inputs)
    z = self._pre_vq_conv1(encoded)
    print(f"z_e shape: {z.shape}")

    vq_output = self._vqvae(z, is_training=is_training)
    quantized = vq_output['quantize']
    print(f"z_q shape: {quantized.shape}")

    x_recon = self._decoder(quantized)

    # Plain mean squared error. Dividing by self._data_variance (normalised
    # MSE) is intentionally left disabled, as in the original code.
    residual = x_recon - inputs
    recon_error = jnp.mean(residual ** 2)

    loss = recon_error + vq_output['loss']
    return {
        'z': z,
        'x_recon': x_recon,
        'loss': loss,
        'recon_error': recon_error,
        'vq_output': vq_output,
    }

# Build Model and train

In [27]:
# Set hyper-parameters.
batch_size = 32
image_size = 64 # BouncingBall (not referenced by the code visible in this notebook)

# Number of optimizer steps. NOTE: the upstream remark "100k steps should take
# < 30 minutes on a modern (>= 2017) GPU" refers to 100k steps; this notebook
# runs 30k.
num_training_updates = 30000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2
# These hyper-parameters define the size of the model (number of parameters and layers).
# The hyper-parameters in the paper were (for ImageNet):
# batch_size = 128
# image_size = 128
# num_hiddens = 128
# num_residual_hiddens = 32
# num_residual_layers = 2

# This value is not that important; 64 usually works.
# It does not change the capacity of the information bottleneck.
embedding_dim = 64

# The higher this value, the higher the capacity of the information bottleneck.
num_embeddings = 512

# commitment_cost should be set appropriately. It is often useful to try a
# couple of values. It mostly depends on the scale of the reconstruction cost
# (log p(x|z)): if the reconstruction cost is 100x higher, the
# commitment_cost should be multiplied by the same factor.
commitment_cost = 0.25

# Use EMA updates for the codebook (instead of the Adam optimizer).
# This typically converges faster and makes the model less dependent on the
# choice of optimizer. EMA updates were not used in the VQ-VAE paper (they
# were developed afterwards). See the Appendix of the paper for more details.
vq_use_ema = True

# This is only used for EMA updates.
decay = 0.99

# Step size passed to optax.adam.
learning_rate = 1e-3


# Data loading.
# Training pipeline: normalise, shuffle (buffer of 10000), repeat forever and
# emit full batches only (drop_remainder=True keeps the batch shape fixed).
# prefetch(-1) presumably selects tf.data's automatic buffer size -- confirm.
# tfds.as_numpy makes the iterator yield NumPy arrays instead of tf tensors.
train_dataset = tfds.as_numpy(
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))
# Validation pipeline: a single unshuffled pass; the last batch may be
# smaller than batch_size because drop_remainder is not set.
valid_dataset = tfds.as_numpy(
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))

# Build modules.
def forward(data, is_training):
  """Build the VQ-VAE from the module-level hyper-parameters and apply it.

  Intended to be wrapped by ``hk.transform_with_state``; all Haiku modules
  are created inside this function.

  Args:
    data: Dict with an 'image' entry holding the input batch.
    is_training: Forwarded to the model (and from there to the quantiser).

  Returns:
    The output dict of ``VQVAEModel.__call__``.
  """
  enc = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
  dec = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
  # 1x1 conv that maps encoder features to the embedding dimension.
  to_vq = hk.Conv2D(
      output_channels=embedding_dim,
      kernel_shape=(1, 1),
      stride=(1, 1),
      name="to_vq")

  # Arguments shared by both quantiser variants.
  quantizer_kwargs = dict(
      embedding_dim=embedding_dim,
      num_embeddings=num_embeddings,
      commitment_cost=commitment_cost)
  if vq_use_ema:
    quantizer = hk.nets.VectorQuantizerEMA(decay=decay, **quantizer_kwargs)
  else:
    quantizer = hk.nets.VectorQuantizer(**quantizer_kwargs)

  vqvae_model = VQVAEModel(enc, dec, quantizer, to_vq,
                           data_variance=train_data_variance)
  return vqvae_model(data['image'], is_training)

# Rebind `forward` to its Haiku-transformed version: from here on it is used
# through `forward.init(...)` and `forward.apply(...)`, which thread the
# parameters and the module state explicitly.
forward = hk.transform_with_state(forward)
# Adam with the learning rate configured above.
optimizer = optax.adam(learning_rate)

@jax.jit
def train_step(params, state, opt_state, data):
  """Run one optimisation step.

  Args:
    params: Current Haiku parameters.
    state: Current Haiku module state.
    opt_state: Current optimizer state.
    data: Batch dict passed to ``forward.apply``.

  Returns:
    Tuple ``(params, state, opt_state, model_output)`` with the updated
    parameters, module state and optimizer state plus the model's output
    dict for this batch.
  """
  def loss_with_aux(p, s, batch):
    # Return the scalar loss first; the full output and the new module state
    # ride along as auxiliary data so jax.grad passes them through.
    outputs, new_state = forward.apply(p, s, None, batch, is_training=True)
    return outputs['loss'], (outputs, new_state)

  grad_fn = jax.grad(loss_with_aux, has_aux=True)
  grads, (model_output, state) = grad_fn(params, state, data)

  updates, opt_state = optimizer.update(grads, opt_state)
  new_params = optax.apply_updates(params, updates)

  return new_params, state, opt_state, model_output

In [28]:
%%time

# Per-step training metrics (one entry per optimizer step).
train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

rng = jax.random.PRNGKey(42)
train_dataset_iter = iter(train_dataset)
# The first batch is consumed for parameter/state initialisation only.
params, state = forward.init(rng, next(train_dataset_iter), is_training=True)
opt_state = optimizer.init(params)

for step in range(1, num_training_updates + 1):
  data = next(train_dataset_iter)
  params, state, opt_state, train_results = (
      train_step(params, state, opt_state, data))

  # Copy only the four logged scalars to the host. `train_results` also
  # holds full tensors ('z', 'x_recon', vq_output['quantize'], ...);
  # fetching the whole tree every step would transfer all of them for
  # nothing. `train_results` itself is left as returned by `train_step`.
  step_metrics = jax.device_get({
      'loss': train_results['loss'],
      'recon_error': train_results['recon_error'],
      'perplexity': train_results['vq_output']['perplexity'],
      'vq_loss': train_results['vq_output']['loss'],
  })
  train_losses.append(step_metrics['loss'])
  train_recon_errors.append(step_metrics['recon_error'])
  train_perplexities.append(step_metrics['perplexity'])
  train_vqvae_loss.append(step_metrics['vq_loss'])

  # Report running means over the last 100 steps.
  if step % 100 == 0:
    print(f'[Step {step}/{num_training_updates}] ' + 
          ('train loss: %f ' % np.mean(train_losses[-100:])) +
          ('recon_error: %.9f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.9f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.9f' % np.mean(train_vqvae_loss[-100:])))

z_e shape: (32, 16, 16, 64)
z_q shape: (32, 16, 16, 64)
z_e shape: (32, 16, 16, 64)
z_q shape: (32, 16, 16, 64)


KeyboardInterrupt: 

In [None]:
# File used by the save / load cells below.
# NOTE(review): the name contains an em dash ("—") between "DeepMind" and
# "BouncingBall", not a hyphen -- confirm this is intentional, since it is
# easy to mistype when referring to the file elsewhere.
checkpoint_path = './checkpoints/VQ-VAE-DeepMind—BouncingBall.pkl'

# Save the model


In [None]:
# Bundle the trained parameters and the Haiku module state into one
# picklable dict. Assumes `params` and `state` hold the values to save.
model_dict = {
    'params': params,
    'state': state,
}

# The target file name is `checkpoint_path`, defined in the previous cell.


# Saving is currently disabled; uncomment the lines below to write
# `model_dict` to `checkpoint_path` with pickle.
# with open(checkpoint_path, 'wb') as file:
#     pickle.dump(model_dict, file)

# print("Model saved successfully.")

# Load the model (optional)

In [None]:
# Use pickle to load the model_dict from the file.
# SECURITY: pickle.load executes arbitrary code embedded in the file; only
# load checkpoints you created yourself.
# NOTE: the save code in the previous cell is commented out, so this open()
# fails unless the checkpoint file already exists from an earlier run.
with open(checkpoint_path, 'rb') as file:
    loaded_model_dict = pickle.load(file)

# Extract the params and state
loaded_params = loaded_model_dict['params']
loaded_state = loaded_model_dict['state']

# Overwrite the in-memory model with the loaded one; any freshly trained
# `params` / `state` are replaced.
params = loaded_params
state = loaded_state
print("Model loaded successfully.")


# Plot loss

For BouncingBall, the reconstruction error can get as low as about 0.00001 (1e-5).

In [None]:
# Left panel: per-step reconstruction error on a log scale.
f = plt.figure(figsize=(16,8))
ax = f.add_subplot(1,2,1)
ax.plot(train_recon_errors)
ax.set_yscale('log')
# The recorded recon_error is the plain mean squared error: the division by
# the data variance is commented out in VQVAEModel.__call__, so labelling the
# curve as normalised MSE ("NMSE") would be wrong.
ax.set_title('MSE.')

# Right panel: per-step codebook perplexity.
ax = f.add_subplot(1,2,2)
ax.plot(train_perplexities)
ax.set_title('Average codebook usage (perplexity).')

# View reconstructions

In [None]:
# Reconstructions: take one batch from each pipeline. A fresh iterator is
# created for each, so the validation batch is the first batch of that set.
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of 
# using EMA the codebook is not updated.
# forward.apply returns (model_output, new_state); index [0] selects the
# model output dict, from which only the reconstruction is kept.
train_reconstructions = forward.apply(params, state, rng, train_batch, is_training=False)[0]['x_recon']
valid_reconstructions = forward.apply(params, state, rng, valid_batch, is_training=False)[0]['x_recon']



def convert_batch_to_image_grid(image_batch, rows=4, cols=8):
    """Tile the first ``rows * cols`` images of a batch into one image.

    Image ``k`` of the batch lands in grid row ``k // cols`` and grid column
    ``k % cols``. Images beyond the first ``rows * cols`` are ignored.

    Args:
        image_batch: Array of shape (B, H, W, C).
        rows: Number of tile rows in the grid.
        cols: Number of tile columns in the grid.

    Returns:
        Array of shape (rows * H, cols * W, C).

    Raises:
        AssertionError: If the batch holds fewer than ``rows * cols`` images.
    """
    batch, height, width, channels = image_batch.shape
    tile_count = rows * cols
    assert batch >= tile_count, "Not enough images to fill the grid"

    tiles = image_batch[:tile_count].reshape(rows, cols, height, width, channels)
    # Bring each tile row's pixel rows next to each other:
    # (rows, cols, H, W, C) -> (rows, H, cols, W, C).
    interleaved = tiles.swapaxes(1, 2)
    return interleaved.reshape(rows * height, cols * width, channels)

# 2x2 figure: originals on the left, reconstructions on the right; training
# data in the top row, validation data in the bottom row. Assumes
# 'train_batch', 'train_reconstructions', 'valid_batch', and
# 'valid_reconstructions' are available.
f = plt.figure(figsize=(16, 16))

# (images, title) for subplot positions 1..4, in row-major order.
panels = (
    (train_batch['image'], 'Training Data Originals'),
    (train_reconstructions, 'Training Data Reconstructions'),
    (valid_batch['image'], 'Validation Data Originals'),
    (valid_reconstructions, 'Validation Data Reconstructions'),
)

for position, (images, title) in enumerate(panels, start=1):
    ax = f.add_subplot(2, 2, position)
    ax.imshow(convert_batch_to_image_grid(images), interpolation='nearest')
    ax.set_title(title)
    # Acts on the current axes, i.e. the subplot that was just added.
    plt.axis('off')

plt.show()
