# Variational Autoencoder

## Installation

You will need Python 3.10 or later and a working JAX installation. For example, to install JAX with NVIDIA GPU (CUDA 12) support:

In [None]:
# Install (or upgrade, via -U) JAX with the `cuda12` extra for NVIDIA GPU support.
%pip install -U "jax[cuda12]"

Then install CAX, including the optional dependencies needed to run the examples:

In [None]:
# Install CAX from its GitHub repository together with the optional `examples` extra.
%pip install "cax[examples] @ git+https://github.com/879f4cf7/cax.git"

## Imports

In [1]:
import jax
import jax.numpy as jnp
import mediapy
import optax
from cax.nn.vae import VAE, vae_loss
from datasets import load_dataset
from flax import nnx
from tqdm.auto import tqdm

## Configuration

In [None]:
# Reproducibility: a single seed drives both the explicit JAX key (split later
# for minibatch sampling and visualization) and the Flax NNX RNG streams that
# are handed to the model.
seed = 42

# Model shape. Inputs are 28x28 images; `features` is passed straight to `VAE`
# (presumably per-layer channel widths, starting with 1 input channel -- confirm
# against cax.nn.vae).
spatial_dims = (28, 28)
features = (1, 32, 32)
latent_size = 8

# Optimization hyperparameters.
batch_size = 32
learning_rate = 0.01

# Random state.
rngs = nnx.Rngs(seed)
key = jax.random.key(seed)

## Dataset

In [3]:
ds = load_dataset("ylecun/mnist")


def _to_float_images(images):
	"""Stack a sequence of images into a float32 array scaled by 1/255 with a trailing channel axis.

	Assumes 8-bit pixel values, so the result lies in [0, 1].
	"""
	return jnp.expand_dims(jnp.array(images, dtype=jnp.float32) / 255, axis=-1)


image_train = _to_float_images(ds["train"]["image"])
image_test = _to_float_images(ds["test"]["image"])

# Slice the rows first, then take the column. Accessing the full column
# (`ds["train"]["image"]`) can decode every image in the split, depending on
# the `datasets` version, just to display 8 of them.
mediapy.show_images(ds["train"][:8]["image"], width=128, height=128)

## Model

In [4]:
# Build the VAE from the configuration cell. `rngs` is handed to the model,
# presumably for parameter initialization and latent sampling -- confirm
# against cax.nn.vae.VAE.
vae = VAE(spatial_dims, features, latent_size, rngs)

In [5]:
# Count trainable parameters: total element count over every nnx.Param leaf.
params = nnx.state(vae, nnx.Param)
num_params = sum(leaf.size for leaf in jax.tree.leaves(params))
print("Number of params:", num_params)

Number of params: 2518513


## Train

### Optimizer

In [6]:
# Linearly anneal the learning rate from `learning_rate` down to 1% of it over
# 8,192 steps (the same count as `num_train_steps` in the main loop -- keep the
# two in sync if either changes).
lr_sched = optax.linear_schedule(
	init_value=learning_rate,
	end_value=learning_rate * 0.01,
	transition_steps=8192,
)

# Clip gradients to a global norm of 1.0, then apply Adam with the schedule.
tx = optax.chain(
	optax.clip_by_global_norm(1.0),
	optax.adam(learning_rate=lr_sched),
)
optimizer = nnx.Optimizer(vae, tx)

### Loss

In [7]:
@nnx.jit
def loss_fn(vae, image):
	"""Return the VAE objective for one batch of images.

	Runs the model forward and scores the result with `vae_loss`, which receives
	the reconstruction, the original `image`, and the posterior mean and
	log-variance produced by the model.
	"""
	recon, mu, log_var = vae(image)
	return vae_loss(recon, image, mu, log_var)

### Train step

In [8]:
@nnx.jit
def train_step(vae, optimizer, key):
	"""Run one optimization step on a random minibatch drawn from `image_train`.

	Args:
		vae: The model being trained; its parameters are updated through `optimizer`.
		optimizer: The `nnx.Optimizer` wrapping `vae`.
		key: PRNG key used to choose the minibatch indices.

	Returns:
		The minibatch loss, evaluated before the parameter update.
	"""
	# `jax.random.choice` samples with replacement by default, so a batch may
	# contain repeated images.
	# NOTE(review): `image_train` is a global closed over by this jitted function,
	# so JAX may embed it in the compiled program as a constant; passing it as
	# an argument would avoid that.
	indices = jax.random.choice(key, image_train.shape[0], shape=(batch_size,))
	grad_fn = nnx.value_and_grad(loss_fn)
	loss, grads = grad_fn(vae, image_train[indices])
	optimizer.update(grads)
	return loss

### Main loop

In [None]:
num_train_steps = 8_192
print_interval = 128

# Per-step losses; the progress bar shows the mean over the most recent window.
losses = []
pbar = tqdm(range(num_train_steps), desc="Training", unit="train_step")

for step in pbar:
	key, step_key = jax.random.split(key)
	loss = train_step(vae, optimizer, step_key)
	losses.append(loss)

	# Refresh the displayed average on the first step, every `print_interval`
	# steps, and on the final step.
	is_last_step = step == num_train_steps - 1
	if is_last_step or step % print_interval == 0:
		window = losses[-print_interval:]
		avg_loss = sum(window) / len(window)
		pbar.set_postfix({"Average Loss": f"{avg_loss:.6f}"})

## Visualize

In [25]:
# Draw latent vectors from a standard normal and decode them with the model.
num_samples = 8
key, subkey = jax.random.split(key)
z = jax.random.normal(subkey, (num_samples, latent_size))
image = vae.generate(z)

mediapy.show_images(image, width=128, height=128)

In [30]:
# Reconstruct a few random test images and show them next to the originals.
num_images = 8
key, subkey = jax.random.split(key)
image_index = jax.random.choice(subkey, image_test.shape[0], shape=(num_images,))
image = image_test[image_index]

# Vectorize the forward pass over the images. The RNG state is split into one
# stream per image (mapped along axis 0); all other model state is shared (None).
state_axes = nnx.StateAxes({nnx.RngState: 0, ...: None})


@nnx.split_rngs(splits=num_images)
@nnx.vmap(in_axes=(state_axes, 0))
def reconstruct(vae, image):
	return vae(image)


image_recon, _, _ = reconstruct(vae, image)

mediapy.show_images(image, width=128, height=128)
# A sigmoid is applied before display, which suggests the model returns
# logits -- confirm against cax.nn.vae.
mediapy.show_images(jax.nn.sigmoid(image_recon), width=128, height=128)