In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from datasets import load_dataset
import jax.random as random
from jax.random import PRNGKey, split

from aevb.core import AEVB


# Data Processing Functions ----------------------------------
def one_hot_encode(x, k):
    "Create a one-hot encoding of x of size k."
    return jnp.array(x[:, None] == jnp.arange(k), dtype=jnp.float32)


@jax.jit
def prepare_data(X):
    num_examples = X.shape[0]
    num_pixels = 28 * 28
    X = X.reshape(num_examples, num_pixels)
    X = X / 255.0

    return X, num_examples


def data_stream(seed, data, batch_size, data_size):
    """Return an iterator over batches of data."""
    rng = np.random.RandomState(seed)
    num_batches = int(jnp.ceil(data_size / batch_size))
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


class MlpEncoder(eqx.Module):
    norm1: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2 = random.split(key)
        self.norm1 = eqx.nn.BatchNorm(input_size=32, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=10, out_features=32, key=key1)
        self.linear2 = eqx.nn.Linear(in_features=32, out_features=3, key=key2)

    def __call__(self, x, state):
        x = self.linear1(x)
        x, state = self.norm1(x, state)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state


class MlpDecoder(eqx.Module):
    norm1: eqx.nn.BatchNorm
    linear1: eqx.nn.Linear
    linear2: eqx.nn.Linear

    def __init__(self, key):
        key1, key2 = random.split(key)
        self.norm1 = eqx.nn.BatchNorm(input_size=8, axis_name="batch")
        self.linear1 = eqx.nn.Linear(in_features=3, out_features=8, key=key1)
        self.linear2 = eqx.nn.Linear(in_features=8, out_features=10, key=key2)

    def __call__(self, x, state):
        x = self.linear1(x)
        x, state = self.norm1(x, state)
        x = jax.nn.relu(x)
        x = self.linear2(x)
        return x, state


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from aevb._src.eqx import convert_eqx_model
init, apply = convert_eqx_model(MlpEncoder, random.key(0))

rec_params, rec_state = init()
model = apply(params=rec_params, state=rec_state, input=jnp.ones((100, 10)), train=True)


2024-03-25 01:43:59.074849: W external/xla/xla/service/gpu/nvptx_compiler.cc:742] The NVIDIA driver's CUDA version is 12.0 which is older than the ptxas CUDA version (12.4.99). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [4]:
latent_dim = 3
optimizer = optax.adam(1e-3)
init, step, sample_data = AEVB(
    latent_dim=latent_dim,
    generative_model=(MlpDecoder, {"rng_key": random.key(0)}),
    recognition_model=(MlpEncoder, {"rng_key": random.key(1)}),
    optimizer=optimizer,
    n_samples=15,
    nn_lib="equinox",
)       

In [5]:
aevb_state = init()

step(random.key(2), aevb_state, jnp.ones((100, 10)))

ValueError: too many values to unpack (expected 2)

In [4]:
# Main Function --------------------------------
def main(save_samples_pth: str):

    # Prepare Data
    mnist_data = load_dataset("mnist")
    data_train = mnist_data["train"]

    X_train = np.stack([np.array(example["image"]) for example in data_train])
    X_train, N_train = prepare_data(X_train)

    seed = 1
    n = N_train.item()
    batch_size = 100
    batches = data_stream(seed, X_train, batch_size, n)

    # Create AEVB inference engine
    latent_dim = 3
    optimizer = optax.adam(1e-3)

    init, step, sample_data = AEVB(
        latent_dim=latent_dim,
        generative_model=(MlpEncoder, {}),
        recognition_model=(MlpDecoder, {}),
        optimizer=optimizer,
        n_samples=15,
        nn_lib="equinox",
    )

    # Run AEVB
    key = PRNGKey(1242)
    num_steps = 10000
    eval_every = 100

    key, init_key = split(key)
    state = init(init_key)

    key, *training_keys = split(key, num_steps + 1)
    for i, rng_key in enumerate(training_keys):
        batch = next(batches)
        state, info = step(rng_key, state, batch)
        if i % eval_every == 0:
            print(f"Step {i} | loss: {info.loss} | nll: {info.nll} | kl: {info.kl}")

    # Random Data Samples of Learned Generative Model
    key, data_samples_key = split(key)
    samples, _ = sample_data(data_samples_key, state.gen_params, state.gen_state, 5)
    fig, axs = plt.subplots(5, 1)
    for i, s in enumerate(samples):
        axs[i].imshow(s.reshape(28, 28))
    plt.savefig(save_samples_pth, format="png")


if __name__ == "__main__":
    from time import localtime, strftime

    now = strftime("%Y-%m-%d %H:%M:%S", localtime())
    main(f"./samples-{now}.png")


TypeError: dot_general requires contracting dimensions to have the same shape, got (3,) and (100,).