<a href="https://colab.research.google.com/github/Ryan0v0/nninn/blob/master/vq_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Step1: Splitting up neural net params into chunks

In [2]:
import haiku as hk
import jax
import jax.numpy as jnp

# Define neural network architecture
class NeuralNetwork(hk.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        
    def __call__(self, x):
        x = hk.Linear(100)(x)
        x = jax.nn.relu(x)
        
        x = hk.Linear(100)(x)
        x = jax.nn.relu(x)
        
        x = hk.Linear(1)(x)
        return x

def net_fn(x):
    return NeuralNetwork()(x)

# Create an instance of the neural network
net = hk.transform(net_fn)

# Initialize neural network parameters
params_rng = jax.random.PRNGKey(42)
dummy_input = jnp.ones([1, 10])
params = net.init(params_rng, dummy_input)

# Initialize weights as non-negative
params = jax.tree_map(lambda x: jax.nn.relu(x), params)

# Split neural network parameters into chunks
chunk_size = 1000
param_chunks = []

for param in jax.tree_leaves(params):
    flattened_param = param.ravel()
    n_chunks = len(flattened_param) // chunk_size + (len(flattened_param) % chunk_size > 0)
    chunks = jnp.array_split(flattened_param, n_chunks)
    param_chunks.extend(chunks)

# Print the number of parameter chunks
print("Number of parameter chunks:", len(param_chunks))
print("Parameter chunks:", param_chunks)


  for param in jax.tree_leaves(params):


Number of parameter chunks: 15
Parameter chunks: [Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32), Array([0.06073479, 0.        , 0.        , 0.04616429, 0.        ,
       0.        , 0.        , 0.4006773 , 0.11116457, 0.        ,
       0.        , 0.08436938, 0.1989393 , 0.02855882, 0.05032298,
       0.        , 0.        , 0.        , 0.        , 0.        ,
       0.        , 0.3577101 , 0.3701997 , 0.26633295, 0.268086  ,
       0.        , 0.        , 0.2998543 , 0.        , 0.08337022,
       0.        , 0.        , 0.2681047 , 0.        , 0.        ,
       0.        

In [3]:
import jax
import numpy as np

# Convert param_chunks to a numpy array
param_chunks_np = np.concatenate([jax.device_get(chunk) for chunk in param_chunks])

print(type(param_chunks_np))
print("size:", param_chunks_np.shape)

<class 'numpy.ndarray'>
size: (11301,)


# Step2: learning a mapping from each chunk to an integer via VQ-VAE

## Load Data

using the weights of the above neural network as input.

In [4]:
data_variance = np.var(param_chunks_np / 255.0)

print(data_variance)

8.001368e-08


## Vector Quantizer Layer

This layer takes a tensor to be quantized. The channel dimension will be used as the space in which to quantize. All other dimensions will be flattened and will be seen as different examples to quantize.

The output tensor will have the same shape as the input.

As an example for a `BCHW` tensor of shape `[16, 64, 32, 32]`, we will first convert it to an `BHWC` tensor of shape `[16, 32, 32, 64]` and then reshape it into `[16384, 64]` and all `16384` vectors of size `64`  will be quantized independently. In otherwords, the channels are used as the space in which to quantize. All other dimensions will be flattened and be seen as different examples to quantize, `16384` in this case.

In [5]:
"""Haiku implementation of VQ-VAE https://arxiv.org/abs/1711.00937."""

from typing import Any, Optional

import haiku as hk

from haiku._src import base
from haiku._src import initializers
from haiku._src import module
from haiku._src import moving_averages

import jax
import jax.numpy as jnp


# If you are forking replace this with `import haiku as hk`.
# pylint: disable=invalid-name
class hk:
  get_parameter = base.get_parameter
  get_state = base.get_state
  set_state = base.set_state
  initializers = initializers
  ExponentialMovingAverage = moving_averages.ExponentialMovingAverage
  Module = module.Module
# pylint: enable=invalid-name
del base, initializers, module, moving_averages


class VectorQuantizer(hk.Module):
  """Haiku module representing the VQ-VAE layer.

  Implements the algorithm presented in
  "Neural Discrete Representation Learning" by van den Oord et al.
  https://arxiv.org/abs/1711.00937

  Input any tensor to be quantized. Last dimension will be used as space in
  which to quantize. All other dimensions will be flattened and will be seen
  as different examples to quantize.

  The output tensor will have the same shape as the input.

  For example a tensor with shape ``[16, 32, 32, 64]`` will be reshaped into
  ``[16384, 64]`` and all ``16384`` vectors (each of ``64`` dimensions)  will be
  quantized independently.

  Attributes:
    embedding_dim: integer representing the dimensionality of the tensors in the
      quantized space. Inputs to the modules must be in this format as well.
    num_embeddings: integer, the number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms (see
      equation 4 in the paper - this variable is Beta).
  """

  def __init__(
      self,
      embedding_dim: int,
      num_embeddings: int,
      commitment_cost: float,
      dtype: Any = jnp.float32,
      name: Optional[str] = None,
      cross_replica_axis: Optional[str] = None,
  ):
    """Initializes a VQ-VAE module.

    Args:
      embedding_dim: dimensionality of the tensors in the quantized space.
        Inputs to the modules must be in this format as well.
      num_embeddings: number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      dtype: dtype for the embeddings variable, defaults to ``float32``.
      name: name of the module.
      cross_replica_axis: If not ``None``, it should be a string representing
        the axis name over which this module is being run within a
        :func:`jax.pmap`. Supplying this argument means that perplexity is
        calculated across all replicas on that axis.
    """
    super().__init__(name=name)
    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    self.commitment_cost = commitment_cost
    self.cross_replica_axis = cross_replica_axis

    self._embedding_shape = [embedding_dim, num_embeddings]
    self._embedding_dtype = dtype

  @property
  def embeddings(self):
    initializer = hk.initializers.VarianceScaling(distribution="uniform")
    return hk.get_parameter(
        "embeddings",
        self._embedding_shape,
        self._embedding_dtype,
        init=initializer)

  def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to ``embedding_dim``. All
        other leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data.

    Returns:
      dict: Dictionary containing the following keys and values:
        * ``quantize``: Tensor containing the quantized version of the input.
        * ``loss``: Tensor containing the loss to optimize.
        * ``perplexity``: Tensor containing the perplexity of the encodings.
        * ``encodings``: Tensor containing the discrete encodings, ie which
          element of the quantized space each input element was mapped to.
        * ``encoding_indices``: Tensor containing the discrete encoding indices,
          ie which element of the quantized space each input element was mapped
          to.
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])

    distances = (
        jnp.sum(jnp.square(flat_inputs), 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, self.embeddings) +
        jnp.sum(jnp.square(self.embeddings), 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)

    e_latent_loss = jnp.mean(
        jnp.square(jax.lax.stop_gradient(quantized) - inputs))
    q_latent_loss = jnp.mean(
        jnp.square(quantized - jax.lax.stop_gradient(inputs)))
    loss = q_latent_loss + self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

  def quantize(self, encoding_indices):
    """Returns embedding tensor for a batch of indices."""
    w = self.embeddings.swapaxes(1, 0)
    w = jax.device_put(w)  # Required when embeddings is a NumPy array.
    return w[(encoding_indices,)]


class VectorQuantizerEMA(hk.Module):
  r"""Haiku module representing the VQ-VAE layer.

  Implements a slightly modified version of the algorithm presented in
  "Neural Discrete Representation Learning" by van den Oord et al.
  https://arxiv.org/abs/1711.00937

  The difference between :class:`VectorQuantizerEMA` and
  :class:`VectorQuantizer` is that this module uses
  :class:`~haiku.ExponentialMovingAverage`\ s to update the embedding vectors
  instead of an auxiliary loss. This has the advantage that the embedding
  updates are independent of the choice of optimizer (SGD, RMSProp, Adam, K-Fac,
  ...) used for the encoder, decoder and other parts of the architecture. For
  most experiments the EMA version trains faster than the non-EMA version.

  Input any tensor to be quantized. Last dimension will be used as space in
  which to quantize. All other dimensions will be flattened and will be seen
  as different examples to quantize.

  The output tensor will have the same shape as the input.

  For example a tensor with shape ``[16, 32, 32, 64]`` will be reshaped into
  ``[16384, 64]`` and all ``16384`` vectors (each of 64 dimensions)  will be
  quantized independently.

  Attributes:
    embedding_dim: integer representing the dimensionality of the tensors in
      the quantized space. Inputs to the modules must be in this format as well.
    num_embeddings: integer, the number of vectors in the quantized space.
    commitment_cost: scalar which controls the weighting of the loss terms
      (see equation 4 in the paper).
    decay: float, decay for the moving averages.
    epsilon: small float constant to avoid numerical instability.
  """

  def __init__(
      self,
      embedding_dim,
      num_embeddings,
      commitment_cost,
      decay,
      epsilon: float = 1e-5,
      dtype: Any = jnp.float32,
      cross_replica_axis: Optional[str] = None,
      name: Optional[str] = None,
  ):
    """Initializes a VQ-VAE EMA module.

    Args:
      embedding_dim: integer representing the dimensionality of the tensors in
        the quantized space. Inputs to the modules must be in this format as
        well.
      num_embeddings: integer, the number of vectors in the quantized space.
      commitment_cost: scalar which controls the weighting of the loss terms
        (see equation 4 in the paper - this variable is Beta).
      decay: float between 0 and 1, controls the speed of the Exponential Moving
        Averages.
      epsilon: small constant to aid numerical stability, default ``1e-5``.
      dtype: dtype for the embeddings variable, defaults to ``float32``.
      cross_replica_axis: If not ``None``, it should be a string representing
        the axis name over which this module is being run within a
        :func:`jax.pmap`. Supplying this argument means that cluster statistics
        and the perplexity are calculated across all replicas on that axis.
      name: name of the module.
    """
    super().__init__(name=name)
    if not 0 <= decay <= 1:
      raise ValueError("decay must be in range [0, 1]")

    self.embedding_dim = embedding_dim
    self.num_embeddings = num_embeddings
    self.decay = decay
    self.commitment_cost = commitment_cost
    self.epsilon = epsilon
    self.cross_replica_axis = cross_replica_axis

    self._embedding_shape = [embedding_dim, num_embeddings]
    self._dtype = dtype

    self._ema_cluster_size = hk.ExponentialMovingAverage(
        decay=self.decay, name="ema_cluster_size")
    self._ema_dw = hk.ExponentialMovingAverage(decay=self.decay, name="ema_dw")

  @property
  def embeddings(self):
    initializer = hk.initializers.VarianceScaling(distribution="uniform")
    return hk.get_state(
        "embeddings", self._embedding_shape, self._dtype, init=initializer)

  @property
  def ema_cluster_size(self):
    self._ema_cluster_size.initialize([self.num_embeddings], self._dtype)
    return self._ema_cluster_size

  @property
  def ema_dw(self):
    self._ema_dw.initialize(self._embedding_shape, self._dtype)
    return self._ema_dw

  def __call__(self, inputs, is_training):
    """Connects the module to some inputs.

    Args:
      inputs: Tensor, final dimension must be equal to ``embedding_dim``. All
        other leading dimensions will be flattened and treated as a large batch.
      is_training: boolean, whether this connection is to training data. When
        this is set to ``False``, the internal moving average statistics will
        not be updated.

    Returns:
      dict: Dictionary containing the following keys and values:
        * ``quantize``: Tensor containing the quantized version of the input.
        * ``loss``: Tensor containing the loss to optimize.
        * ``perplexity``: Tensor containing the perplexity of the encodings.
        * ``encodings``: Tensor containing the discrete encodings, ie which
          element of the quantized space each input element was mapped to.
        * ``encoding_indices``: Tensor containing the discrete encoding indices,
          ie which element of the quantized space each input element was mapped
          to.
    """
    flat_inputs = jnp.reshape(inputs, [-1, self.embedding_dim])
    embeddings = self.embeddings

    distances = (
        jnp.sum(jnp.square(flat_inputs), 1, keepdims=True) -
        2 * jnp.matmul(flat_inputs, embeddings) +
        jnp.sum(jnp.square(embeddings), 0, keepdims=True))

    encoding_indices = jnp.argmax(-distances, 1)
    encodings = jax.nn.one_hot(encoding_indices,
                               self.num_embeddings,
                               dtype=distances.dtype)

    # NB: if your code crashes with a reshape error on the line below about a
    # Tensor containing the wrong number of values, then the most likely cause
    # is that the input passed in does not have a final dimension equal to
    # self.embedding_dim. Ideally we would catch this with an Assert but that
    # creates various other problems related to device placement / TPUs.
    encoding_indices = jnp.reshape(encoding_indices, inputs.shape[:-1])
    quantized = self.quantize(encoding_indices)
    e_latent_loss = jnp.mean(
        jnp.square(jax.lax.stop_gradient(quantized) - inputs))

    if is_training:
      cluster_size = jnp.sum(encodings, axis=0)
      if self.cross_replica_axis:
        cluster_size = jax.lax.psum(
            cluster_size, axis_name=self.cross_replica_axis)
      updated_ema_cluster_size = self.ema_cluster_size(cluster_size)

      dw = jnp.matmul(flat_inputs.T, encodings)
      if self.cross_replica_axis:
        dw = jax.lax.psum(dw, axis_name=self.cross_replica_axis)
      updated_ema_dw = self.ema_dw(dw)

      n = jnp.sum(updated_ema_cluster_size)
      updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon) /
                                  (n + self.num_embeddings * self.epsilon) * n)

      normalised_updated_ema_w = (
          updated_ema_dw / jnp.reshape(updated_ema_cluster_size, [1, -1]))

      hk.set_state("embeddings", normalised_updated_ema_w)
      loss = self.commitment_cost * e_latent_loss

    else:
      loss = self.commitment_cost * e_latent_loss

    # Straight Through Estimator
    quantized = inputs + jax.lax.stop_gradient(quantized - inputs)
    avg_probs = jnp.mean(encodings, 0)
    if self.cross_replica_axis:
      avg_probs = jax.lax.pmean(avg_probs, axis_name=self.cross_replica_axis)
    perplexity = jnp.exp(-jnp.sum(avg_probs * jnp.log(avg_probs + 1e-10)))

    return {
        "quantize": quantized,
        "loss": loss,
        "perplexity": perplexity,
        "encodings": encodings,
        "encoding_indices": encoding_indices,
        "distances": distances,
    }

  def quantize(self, encoding_indices):
    """Returns embedding tensor for a batch of indices."""
    w = self.embeddings.swapaxes(1, 0)
    w = jax.device_put(w)  # Required when embeddings is a NumPy array.
    return w[(encoding_indices,)]

We will also implement a slightly modified version  which will use exponential moving averages to update the embedding vectors instead of an auxillary loss. This has the advantage that the embedding updates are independent of the choice of optimizer for the encoder, decoder and other parts of the architecture. For most experiments the EMA version trains faster than the non-EMA version.

## Encoder & Decoder Architecture

The encoder and decoder architecture is based on a ResNet and is implemented below:

In [6]:
import haiku as hk
import jax
import jax.numpy as jnp

class Residual(hk.Module):
    def __init__(self, num_residual_hiddens):
        super(Residual, self).__init__()
        self.num_residual_hiddens = num_residual_hiddens

    def __call__(self, x):
        x = jnp.expand_dims(x, 2)  # Add an extra dimension
        residual = hk.Conv1D(self.num_residual_hiddens, 1)(x)
        residual = jax.nn.relu(residual)
        output = x + residual
        return jnp.squeeze(output, 2)  # Remove the extra dimension

class ResidualStack(hk.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(ResidualStack, self).__init__()
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens

    def __call__(self, x):
        for _ in range(self.num_residual_layers):
            x = Residual(self.num_residual_hiddens)(x)
        return jax.nn.relu(x)

class Encoder(hk.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Encoder, self).__init__()
        self.num_hiddens = num_hiddens
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens

    def __call__(self, inputs):
        x = hk.Linear(self.num_hiddens // 2)(inputs.reshape((inputs.shape[0], -1)))
        x = jax.nn.relu(x)
        x = hk.Linear(self.num_hiddens)(x)
        x = jax.nn.relu(x)
        return ResidualStack(self.num_hiddens, self.num_residual_layers, self.num_residual_hiddens)(x)

class Decoder(hk.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens):
        super(Decoder, self).__init__()
        self.num_hiddens = num_hiddens
        self.num_residual_layers = num_residual_layers
        self.num_residual_hiddens = num_residual_hiddens

    def __call__(self, inputs):
        x = hk.Linear(self.num_hiddens)(inputs.reshape((inputs.shape[0], -1)))
        x = ResidualStack(self.num_hiddens, self.num_residual_layers, self.num_residual_hiddens)(x)
        x = hk.Linear(self.num_hiddens // 2)(x)
        x = jax.nn.relu(x)
        return x

## Train

We use the hyperparameters from the author's code:

In [7]:
batch_size = 256
num_training_updates = 15000

num_hiddens = 128
num_residual_hiddens = 32
num_residual_layers = 2

embedding_dim = 64
num_embeddings = 512

commitment_cost = 0.25

decay = 0.99

learning_rate = 1e-3

In [8]:
def data_loader(data, batch_size=64, shuffle=True):
    n = len(data)
    indices = np.arange(n)
    
    # Shuffle the indices if shuffle is True
    if shuffle:
        np.random.shuffle(indices)
    
    # Generate batches
    for start_idx in range(0, n, batch_size):
        end_idx = min(start_idx + batch_size, n)
        batch_idx = indices[start_idx:end_idx]
        yield data[batch_idx]

# Use the custom data loader
batch_size = 64
my_training_loader = data_loader(param_chunks_np, batch_size=batch_size)

# Use the generator function
for batch_idx, data_batch in enumerate(my_training_loader):
    print("Batch Index:", batch_idx)
    print("Data:", data_batch)
    print()


Batch Index: 0
Data: [0.         0.0646355  0.1383781  0.         0.         0.03250742
 0.06139227 0.         0.07957268 0.1536034  0.         0.08873744
 0.44462025 0.         0.         0.03760948 0.         0.
 0.         0.01354564 0.05850566 0.         0.         0.
 0.         0.01923074 0.1142441  0.1020169  0.         0.
 0.04887822 0.         0.         0.         0.0636283  0.18371016
 0.02441926 0.48644787 0.         0.         0.         0.03620996
 0.         0.08870517 0.         0.         0.03035104 0.
 0.         0.08281113 0.         0.         0.15550774 0.
 0.         0.         0.02632775 0.31960234 0.02705066 0.02366519
 0.         0.         0.10951056 0.05181864]

Batch Index: 1
Data: [0.0645869  0.06940471 0.04132371 0.         0.         0.
 0.         0.03042846 0.07531554 0.         0.         0.
 0.11913878 0.         0.03965377 0.41436425 0.         0.07565757
 0.         0.10509455 0.49042365 0.04238994 0.         0.0136773
 0.35306725 0.         0.47267

In [9]:
class Model(hk.Module):
    def __init__(self, num_hiddens, num_residual_layers, num_residual_hiddens,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        super(Model, self).__init__()
        
        # Initialize the encoder
        self._encoder = Encoder(1, num_hiddens, num_residual_layers, num_residual_hiddens)
        print("Encoder:", self._encoder)
        
        # Initialize the pre VQ (Vector Quantization) linear layer
        self._pre_vq_linear = hk.Linear(embedding_dim)
        print("Pre VQ Linear Layer:", self._pre_vq_linear)
        
        # Initialize the VQ-VAE layer with or without exponential moving average (EMA) decay
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim, commitment_cost)
        print("VQ VAE Layer:", self._vq_vae)
            
        # Initialize the decoder
        self._decoder = Decoder(embedding_dim, num_hiddens, num_residual_layers, num_residual_hiddens)
        print("Decoder:", self._decoder)
        
    def __call__(self, x):
        # Forward pass through encoder
        z = self._encoder(x)
        z = jnp.reshape(z, (z.shape[0], -1))
        z = self._pre_vq_linear(z)
        
        # Forward pass through VQ-VAE layer
        loss, quantized, perplexity = self._vq_vae(z)
        
        # Forward pass through decoder
        x_recon = self._decoder(quantized)
        
        return loss, x_recon, perplexity

# Wrap the model as a Haiku function
def model_fn(x):
    model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
                  num_embeddings, embedding_dim,
                  commitment_cost, decay)
    return model(x)

# Create a JAX-transformed Haiku function
model = hk.transform(model_fn)

print(model)

Transformed(init=<function without_state.<locals>.init_fn at 0x7f74d4502ee0>, apply=<function without_state.<locals>.apply_fn at 0x7f74d4502940>)


In [10]:
from jax import grad, jit, value_and_grad
import optax
import numpy as np

# First, define the loss function
def loss_fn(params, data):
    vq_loss, data_recon, perplexity = model.apply(params, None, data)
    data_variance = np.var(data)
    recon_error = ((data_recon - data)**2).mean() / data_variance
    loss = recon_error + vq_loss
    return loss, (recon_error, perplexity)

# Use JIT (Just-In-Time compilation) for acceleration
@jit
def update(params, opt_state, data):
    (loss, aux), grads = value_and_grad(loss_fn, has_aux=True)(params, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, aux

# Set up the optimizer
optimizer = optax.adam(1e-3)
opt_state = optimizer.init(params)

train_res_recon_error = []
train_res_perplexity = []

# Assume training_loader is an iterator that can return batches of data
for i in range(num_training_updates):
    data = next(iter(training_loader))
    
    # Update model parameters
    params, opt_state, (recon_error, perplexity) = update(params, opt_state, data)
    
    train_res_recon_error.append(recon_error)
    train_res_perplexity.append(perplexity)
    
    if (i+1) % 100 == 0:
        print(f'{i+1} iterations')
        print(f'recon_error: {np.mean(train_res_recon_error[-100:]):.3f}')
        print(f'perplexity: {np.mean(train_res_perplexity[-100:]):.3f}')
        print()


NameError: name 'training_loader' is not defined

In [None]:
# model = Model(num_hiddens, num_residual_layers, num_residual_hiddens,
#              num_embeddings, embedding_dim,
#              commitment_cost, decay).to(device)

Encoder(
  (_linear_1): Linear(in_features=64, out_features=64, bias=True)
  (_linear_2): Linear(in_features=64, out_features=128, bias=True)
  (_residual_stack): ResidualStack(
    (_layers): ModuleList(
      (0): Residual(
        (_block): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv1d(128, 1, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
)
Linear(in_features=128, out_features=64, bias=True)
VectorQuantizerEMA(
  (_embedding): Embedding(512, 64)
)
Decoder(
  (_linear_1): Linear(in_features=64, out_features=128, bias=True)
  (_residual_stack): ResidualStack(
    (_layers): ModuleList(
      (0): Residual(
        (_block): Sequential(
          (0): ReLU(inplace=True)
          (1): Conv1d(128, 1, kernel_size=(1,), stride=(1,), bias=False)
        )
      )
    )
  )
  (_linear_2): Linear(in_features=128, out_features=64, bias=True)
)


# Test JAX

In [1]:
import jax
import jaxlib

print("jax version:", jax.__version__)
print("jaxlib version:", jaxlib.__version__)

jax version: 0.4.11
jaxlib version: 0.4.7


In [2]:
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.sin(x)

print("x:", x)
print("y:", y)

x: [1. 2. 3.]
y: [0.841471   0.90929747 0.14112   ]


In [3]:
from jax import grad

def f(x):
    return jnp.sin(x)

g = grad(f)

print("Gradient of f at x = 1.0:", g(1.0))

Gradient of f at x = 1.0: 0.5403023


In [12]:
!pip install -U dm-haiku

Collecting dm-haiku
  Obtaining dependency information for dm-haiku from https://files.pythonhosted.org/packages/df/ff/235c5bdf5d83f9013771a37dd926080400ecc1f0586f18583600dcf1540d/dm_haiku-0.0.10-py3-none-any.whl.metadata
  Downloading dm_haiku-0.0.10-py3-none-any.whl.metadata (18 kB)
Downloading dm_haiku-0.0.10-py3-none-any.whl (360 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m360.3/360.3 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m
[?25h[33mDEPRECATION: pyodbc 4.0.0-unsupported has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of pyodbc or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063[0m[33m
[0mInstalling collected packages: dm-haiku
  Attempting uninstall: dm-haiku
    Found existing installation: dm-haiku 0.0.9
    Uninstalling dm-haiku-0.0.9:
      Successf