In [1]:
import jax
import jax.numpy as jnp
import jax.random as random
import numpy as np
import optax
from datasets import load_dataset

from aevb.aevb import AevbEngine
from aevb._src.dist import normal


# Data Processing Functions ----------------------------------
def one_hot_encode(x, k):
    """Return a float32 one-hot encoding of the labels in ``x`` using ``k`` classes.

    Row ``i`` of the result holds 1.0 in column ``x[i]`` and 0.0 elsewhere.
    A label outside ``[0, k)`` matches no column and yields an all-zero row.
    """
    class_ids = jnp.arange(k)
    # Broadcasting a column of labels against the row of class ids gives a
    # boolean (len(x), k) membership mask.
    hits = x[:, None] == class_ids
    return jnp.array(hits, dtype=jnp.float32)


@jax.jit
def prepare_data(X):
    """Flatten a stack of 28x28 images into rows and divide pixel values by 255.

    Returns a ``(flattened, count)`` pair where ``flattened`` has shape
    ``(X.shape[0], 784)`` and ``count`` is ``X.shape[0]``. Because the function
    is jitted, ``count`` is handed back as a JAX scalar rather than a Python
    int, so callers needing an int must convert it.
    """
    n_rows = X.shape[0]
    flat = X.reshape(n_rows, 28 * 28)
    return flat / 255.0, n_rows


def data_stream(seed, data, batch_size, data_size):
    """Yield shuffled mini-batches of ``data`` indefinitely.

    Every pass draws a fresh permutation of ``range(data_size)`` and yields
    ``data[batch_idx]`` for consecutive ``batch_size``-long slices of it. The
    last batch of a pass is shorter when ``batch_size`` does not divide
    ``data_size``.

    Args:
        seed: Seed for the NumPy ``RandomState`` that drives the shuffling.
        data: Array-like that can be indexed with an integer index array.
        batch_size: Maximum number of examples per batch.
        data_size: Number of examples to sample indices from.

    Yields:
        ``data[batch_idx]`` for each batch of shuffled indices.

    Raises:
        ValueError: If ``data_size`` is not positive or ``batch_size`` is
            negative; either would otherwise loop forever without yielding.
        ZeroDivisionError: If ``batch_size`` is zero.
    """
    if data_size <= 0 or batch_size < 0:
        raise ValueError(
            "data_stream needs data_size > 0 and batch_size > 0, got "
            f"data_size={data_size}, batch_size={batch_size}"
        )

    rng = np.random.RandomState(seed)
    # Exact integer ceil-division: avoids rounding the ratio through a
    # float32 device computation just to count batches.
    num_batches = -(-data_size // batch_size)
    while True:
        perm = rng.permutation(data_size)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield data[batch_idx]


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Generative Model and Recognition Feature Extractor --------------------
def recognition_init(key, latent_dim, data_dim, hidden_dim=100):
    """Create initial parameters and state for the recognition network.

    The network has one shared dense layer (``data_dim -> hidden_dim``)
    followed by two dense heads (``hidden_dim -> latent_dim``) named ``"mu"``
    and ``"logvar"``. All weights and biases are standard-normal draws scaled
    by 0.1.

    Args:
        key: JAX PRNG key; it is split into six sub-keys, one per array.
        latent_dim: Output width of the ``"mu"`` and ``"logvar"`` heads.
        data_dim: Input width of the shared layer.
        hidden_dim: Width of the shared hidden layer. The default of 100
            reproduces the previously hard-coded size.

    Returns:
        A ``(params, state)`` tuple. ``params`` maps ``"shared"``, ``"mu"``
        and ``"logvar"`` to ``{"W": ..., "b": ...}`` dicts; ``state`` is an
        empty dict.
    """
    w1key, b1key, w2key, b2key, w3key, b3key = random.split(key, 6)

    def _dense(wkey, bkey, n_in, n_out):
        # One dense layer's weight matrix and bias vector, N(0, 0.1^2) entries.
        return {
            "W": random.normal(wkey, (n_in, n_out)) * 0.1,
            "b": random.normal(bkey, (n_out,)) * 0.1,
        }

    params = {
        "shared": _dense(w1key, b1key, data_dim, hidden_dim),
        "mu": _dense(w2key, b2key, hidden_dim, latent_dim),
        "logvar": _dense(w3key, b3key, hidden_dim, latent_dim),
    }
    return params, {}


def recognition_apply(params, state, input, train):
    """Compute the location and scale produced by the recognition network.

    ``input`` goes through the shared dense layer and a ReLU, then through the
    ``"mu"`` and ``"logvar"`` heads. ``state`` and ``train`` are accepted but
    not used here.

    Returns:
        ``({"loc": mu, "scale": sigma}, {})`` where ``sigma`` is
        ``exp(0.5 * logvar)``.
    """

    def affine(layer, h):
        return jnp.dot(h, layer["W"]) + layer["b"]

    hidden = jax.nn.relu(affine(params["shared"], input))
    loc = affine(params["mu"], hidden)
    log_variance = affine(params["logvar"], hidden)
    # Half the log-variance is the log standard deviation.
    scale = jnp.exp(log_variance * 0.5)
    return {"loc": loc, "scale": scale}, {}


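# Disabled: recognition_apply is called directly on batched input (jnp.dot
# contracts over the last axis), so a per-example vmap looks unnecessary.
# TODO: confirm and remove this line if it is not needed.
# recognition_apply = jax.vmap(recognition_apply, in_axes=(None, None, 0, None))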


def generative_init(key, latent_dim, data_dim):
    """Create initial parameters and state for the single-layer generative model.

    Returns ``({"w": W, "b": b}, {})`` where ``W`` has shape
    ``(latent_dim, data_dim)`` and ``b`` has shape ``(data_dim,)``. Both are
    standard-normal draws scaled by 0.1, each from its own sub-key of ``key``.
    """
    weight_key, bias_key = random.split(key)
    params = {
        "w": random.normal(weight_key, (latent_dim, data_dim)) * 0.1,
        "b": random.normal(bias_key, (data_dim,)) * 0.1,
    }
    return params, {}


def generative_apply(params, state, input, train: bool):
    """Map ``input`` to the location and scale of the generative model.

    The location is ``relu(input @ w + b)`` and the scale is the constant 1.
    ``state`` and ``train`` are accepted but not used here; the returned state
    is an empty dict.
    """
    affine_out = jnp.dot(input, params["w"]) + params["b"]
    return {"loc": jax.nn.relu(affine_out), "scale": 1}, {}

In [3]:
# Hyperparameters and configuration ---------------------------------------
seed = 1
batch_size = 500
data_dim = 784  # flattened 28 x 28 images, matching prepare_data
latent_dim = 4
optimizer = optax.adam(1e-3)

# Prepare Data --------------------------------------------------------------
mnist_data = load_dataset("mnist")
data_train = mnist_data["train"]

# Stack the per-example images into one array, then flatten and rescale it.
X_train = np.stack([np.array(record["image"]) for record in data_train])
X_train, N_train = prepare_data(X_train)

# prepare_data is jitted, so the example count comes back as a JAX scalar;
# .item() turns it into a plain Python int for the batch iterator.
n = N_train.item()
batches = data_stream(seed, X_train, batch_size, n)

In [4]:
# Initialise the recognition and generative networks from separate PRNG keys.
rec_params, rec_state = recognition_init(random.key(0), latent_dim, data_dim)
gen_params, gen_state = generative_init(random.key(1), latent_dim, data_dim)

# Assemble the engine from the two apply functions and the optimizer.
aevb_engine = AevbEngine.from_applys(
    # problem dimensions
    latent_dim=latent_dim,
    data_dim=data_dim,
    # generative side
    gen_prior="unit_normal",
    gen_loglik="normal",
    gen_apply=generative_apply,
    # recognition side
    rec_dist="normal",
    rec_apply=recognition_apply,
    # optimisation
    optimizer=optimizer,
    n_samples=1,
)



In [5]:
# Build the engine's state object from the initial recognition/generative
# parameters and states created above.
# NOTE(review): presumably this also creates the optimizer state — confirm
# against AevbEngine.init.
aevb_state = aevb_engine.init(rec_params, rec_state, gen_params, gen_state)

In [6]:
# Run one engine step on the first two training examples; the displayed result
# is a tuple whose first element is an AevbState.
# NOTE(review): random.key(0) is the same key already passed to
# recognition_init; JAX keys should not be reused — consider a fresh key here.
# NOTE(review): leaving the call as the last expression prints every parameter
# array; bind the result to a name if the full dump is not wanted.
aevb_engine.step(random.key(0), aevb_state, X_train[0:2])

(AevbState(rec_params={'logvar': {'W': Array([[-2.20809914e-02,  2.92238984e-02, -6.05521090e-02,
         -2.32900679e-02],
        [-1.13535397e-01,  1.71265136e-02,  2.21996710e-01,
         -1.77103207e-01],
        [ 7.19098598e-02,  2.39095222e-02,  1.87002733e-01,
         -9.83726010e-02],
        [-1.42196730e-01, -5.79608157e-02, -1.72340588e-05,
          4.40230891e-02],
        [-2.93076336e-01, -1.83322560e-02,  4.90383729e-02,
          2.22229123e-01],
        [-9.72286090e-02, -1.82023451e-01,  7.36403093e-02,
          1.64915785e-01],
        [-6.05688170e-02,  1.21684976e-01, -1.95530534e-01,
         -8.89911801e-02],
        [-6.95139989e-02, -5.11346124e-02,  8.24787281e-03,
         -1.56331398e-02],
        [ 2.70374902e-02,  1.16931893e-01,  7.29086921e-02,
         -1.04550473e-01],
        [ 6.46673962e-02, -1.96751691e-02,  4.38791960e-02,
         -1.98730212e-02],
        [ 2.24984419e-02, -6.18143938e-02,  7.84693807e-02,
         -1.96547821e-01],
     