In [1]:
import flax.linen as nn
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from datasets import load_dataset
import jax.random as random

from aevb.core import AEVB


# Data Processing Functions ----------------------------------
def one_hot_encode(x, k):
    """Create a float32 one-hot encoding of ``x`` with ``k`` columns.

    ``x`` is indexed as ``x[:, None]`` and compared against ``0..k-1``; every
    position where the label equals the column index is 1.0, all others 0.0.
    Labels outside ``[0, k)`` therefore produce an all-zero row.
    """
    class_ids = jnp.arange(k)
    is_match = x[:, None] == class_ids
    return jnp.array(is_match, dtype=jnp.float32)


@jax.jit
def prepare_data(X):
    """Flatten each example in ``X`` to a vector and divide the values by 255.

    Args:
        X: array whose leading axis indexes examples; all remaining axes are
            flattened into one. (Presumably 8-bit image intensities, so the
            division maps them into [0, 1] -- confirm against the data source.)

    Returns:
        A pair ``(X_flat, num_examples)`` where ``X_flat`` has shape
        ``(num_examples, prod(X.shape[1:]))``. Because the function is jitted,
        ``num_examples`` comes back as a JAX scalar rather than a Python int.
    """
    num_examples = X.shape[0]
    # Shapes are static while tracing, so this is plain Python arithmetic and
    # works for any per-example shape (28x28 MNIST gives 784 as before). Using
    # an explicit product instead of ``-1`` keeps empty batches reshapeable.
    num_pixels = int(np.prod(X.shape[1:]))
    X = X.reshape(num_examples, num_pixels)
    X = X / 255.0

    return X, num_examples


def data_stream(seed, data, batch_size, data_size):
    """Return an endless iterator over shuffled mini-batches of ``data``.

    Each pass draws a fresh permutation of ``range(data_size)`` from a
    ``np.random.RandomState(seed)`` and yields ``data[idx]`` for consecutive
    slices of that permutation. The last batch of a pass is shorter when
    ``batch_size`` does not divide ``data_size``.

    Args:
        seed: seed for the NumPy ``RandomState`` that drives the shuffling.
        data: array-like supporting integer-array indexing along its first axis.
        batch_size: maximum number of examples per batch; must be positive.
        data_size: number of examples to draw indices from; must be positive.

    Yields:
        ``data[batch_idx]`` for each batch, forever.

    Raises:
        ValueError: if ``batch_size`` or ``data_size`` is not positive. Because
            this is a generator, the error surfaces on the first ``next()``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    if data_size <= 0:
        # With no data the epoch loop below would spin forever without yielding.
        raise ValueError(f"data_size must be positive, got {data_size}")

    rng = np.random.RandomState(seed)
    while True:
        perm = rng.permutation(data_size)
        # Integer stepping yields ceil(data_size / batch_size) batches without a
        # float round-trip (the float32 ceil loses precision for large sizes).
        for start in range(0, data_size, batch_size):
            batch_idx = perm[start : start + batch_size]
            yield data[batch_idx]



# Generative Model and Recognition Feature Extractor --------------------
def recognition_init(key, latent_dim, data_dim, hidden_dim=100, scale=0.1):
    """Initialise the parameters of the recognition (encoder) network.

    The network has one shared hidden layer followed by two linear heads, one
    for the mean and one for the log-variance of the latent distribution. All
    weights and biases are drawn from a standard normal and multiplied by
    ``scale``.

    Args:
        key: JAX PRNG key; split into six sub-keys, one per parameter array.
        latent_dim: output size of the ``mu`` and ``logvar`` heads.
        data_dim: input size of the shared layer.
        hidden_dim: width of the shared hidden layer (default 100).
        scale: multiplier applied to every normal draw (default 0.1).

    Returns:
        Nested dict ``{"shared": {...}, "mu": {...}, "logvar": {...}}`` where
        each inner dict holds a weight matrix ``"W"`` and bias vector ``"b"``.
    """
    # Key order is part of the contract: changing it changes the initial values.
    w1key, b1key, w2key, b2key, w3key, b3key = random.split(key, 6)

    def _scaled_normal(subkey, shape):
        # Small-magnitude Gaussian initialisation shared by all parameters.
        return random.normal(subkey, shape) * scale

    return {
        "shared": {
            "W": _scaled_normal(w1key, (data_dim, hidden_dim)),
            "b": _scaled_normal(b1key, (hidden_dim,)),
        },
        "mu": {
            "W": _scaled_normal(w2key, (hidden_dim, latent_dim)),
            "b": _scaled_normal(b2key, (latent_dim,)),
        },
        "logvar": {
            "W": _scaled_normal(w3key, (hidden_dim, latent_dim)),
            "b": _scaled_normal(b3key, (latent_dim,)),
        },
    }

def recognition_apply(params, state, input, train):
    """Run the recognition network and return ``((mu, sigma), new_state)``.

    ``input`` goes through the shared affine layer and a ReLU, then through two
    affine heads. The ``logvar`` head output is converted to a standard
    deviation via ``exp(0.5 * logvar)``. ``state`` and ``train`` are accepted
    but not used here; the returned state is always an empty dict.
    """
    shared, mu_head, logvar_head = params["shared"], params["mu"], params["logvar"]

    hidden = jax.nn.relu(jnp.dot(input, shared["W"]) + shared["b"])
    mu = jnp.dot(hidden, mu_head["W"]) + mu_head["b"]
    logvar = jnp.dot(hidden, logvar_head["W"]) + logvar_head["b"]

    # Half the log-variance is the log of the standard deviation.
    sigma = jnp.exp(0.5 * logvar)
    return (mu, sigma), {}

#recognition_apply = jax.vmap(recognition_apply, in_axes=(None, None, 0, None))

def generative_init(key, latent_dim, data_dim):
    """Initialise the single affine layer of the generative (decoder) model.

    Splits ``key`` in two and draws a ``(latent_dim, data_dim)`` weight matrix
    and a ``(data_dim,)`` bias from a standard normal scaled by 0.1.

    Returns:
        The tuple ``(W, b)``.
    """
    weight_key, bias_key = random.split(key)
    weight = 0.1 * random.normal(weight_key, (latent_dim, data_dim))
    bias = 0.1 * random.normal(bias_key, (data_dim,))
    return weight, bias


def generative_apply(params, state, input, train: bool):
    """Apply the generative model: an affine map of ``input`` followed by ReLU.

    ``params`` is the ``(W, b)`` pair. ``state`` and ``train`` are accepted but
    unused; the returned state is always an empty dict.
    """
    weight, bias = params
    activation = jax.nn.relu(jnp.dot(input, weight) + bias)
    return activation, {}

#generative_apply = jax.vmap(generative_apply, in_axes=(None, None, 0, None))


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Load MNIST and keep only the training split.
mnist_data = load_dataset("mnist")
data_train = mnist_data["train"]

# Stack every training image into one array, then flatten and rescale it.
X_train, N_train = prepare_data(
    np.stack([np.array(example["image"]) for example in data_train])
)

# Mini-batch stream configuration.
seed = 1
batch_size = 500
n = N_train.item()  # prepare_data is jitted, so the count arrives as an array scalar
batches = data_stream(seed, X_train, batch_size, n)


2024-03-29 12:39:00.303176: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [8]:
# Model dimensions and optimiser.
data_dim = 784
latent_dim = 4
optimizer = optax.adam(1e-3)

# The engine must be built with the same latent size used to initialise the
# parameters below, so pass the variable rather than repeating the literal.
engine = AEVB(
    latent_dim=latent_dim,
    generative_model=generative_apply,
    recognition_model=recognition_apply,
    optimizer=optimizer,
    n_samples=15,
)
# Neither apply function carries mutable state; both receive this empty dict.
state = {}

rec_params = recognition_init(random.key(0), latent_dim, data_dim)
gen_params = generative_init(random.key(1), latent_dim, data_dim)


# Run AEVB
key = random.key(1242)
num_steps = 5000
eval_every = 100

# NOTE(review): init_key is not consumed by engine.init below. The split is kept
# so the key stream used for training and sampling stays unchanged.
key, init_key = random.split(key)
aevb_state = engine.init(rec_params, state, gen_params, state)

In [9]:
# Split off one PRNG key per training step; the first sub-key replaces `key`
# and is carried forward for the sampling steps below.
key, *training_keys = random.split(key, num_steps + 1)
for i, rng_key in enumerate(training_keys):
    # `batches` is an endless generator, so next() never raises StopIteration.
    batch = next(batches)
    aevb_state, info = engine.step(rng_key, aevb_state, batch)
    # Report the loss and its nll / kl parts every `eval_every` steps.
    if i % eval_every == 0:
        print(f"Step {i} | loss: {info.loss} | nll: {info.nll} | kl: {info.kl}")

# Random Data Samples of Learned Generative Model
key, data_samples_key = random.split(key)
x_samples = engine.util.sample_data(data_samples_key, aevb_state, n_samples=5)

# Encode/Decode samples using Learned Recognition and Generative Models
key, encode_key = random.split(key)
z_samples = engine.util.encode(encode_key, aevb_state, x_samples, n_samples=30)
# Average over axis 0 before decoding -- presumably the axis indexing the 30
# latent draws per input; confirm against engine.util.encode.
z_means = z_samples.mean(axis=0)
x_recon = engine.util.decode(aevb_state, z_means)

Step 0 | loss: 42776.68359375 | nll: 42218.33203125 | kl: 558.3516235351562
Step 100 | loss: 28623.5625 | nll: 27586.064453125 | kl: 1037.497802734375
Step 200 | loss: 25083.638671875 | nll: 24162.140625 | kl: 921.49755859375
Step 300 | loss: 23333.6484375 | nll: 22307.55859375 | kl: 1026.09033203125
Step 400 | loss: 21774.603515625 | nll: 20699.970703125 | kl: 1074.6329345703125
Step 500 | loss: 21348.267578125 | nll: 20323.4453125 | kl: 1024.8223876953125
Step 600 | loss: 20740.91796875 | nll: 19722.33984375 | kl: 1018.5777587890625
Step 700 | loss: 20558.134765625 | nll: 19605.55859375 | kl: 952.5765380859375
Step 800 | loss: 20578.52734375 | nll: 19617.986328125 | kl: 960.54150390625
Step 900 | loss: 19497.7109375 | nll: 18625.97265625 | kl: 871.7384033203125
Step 1000 | loss: 20018.634765625 | nll: 19155.015625 | kl: 863.61865234375
Step 1100 | loss: 19849.416015625 | nll: 19000.89453125 | kl: 848.5208740234375
Step 1200 | loss: 20180.15625 | nll: 19333.263671875 | kl: 846.8924560