In [1]:
import time

import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import optax
import seaborn as sns

import mfvi

In [2]:
def loglikelihood_fn(params, batch):
    """Total Gaussian log-likelihood of ``batch`` under a unit-scale normal.

    Args:
        params: indexable whose first entry is used as the normal's location.
        batch: observations; the log-density is evaluated elementwise.

    Returns:
        The sum over all elements of ``batch`` of the log-density.
    """
    location = params[0]
    return stats.norm.logpdf(batch, loc=location, scale=1).sum()

def prior_fn(params):
    """Log-density of the prior on ``params[0]``: a normal with loc 10, scale 2.

    Args:
        params: indexable whose first entry is the value being scored.

    Returns:
        The log-density of ``params[0]`` under that normal.
    """
    prior_loc, prior_scale = 10, 2
    return stats.norm.logpdf(params[0], loc=prior_loc, scale=prior_scale)

@jax.jit
def logjoint_fn(params, batch):
    """Jitted log-joint: log-prior of ``params`` plus log-likelihood of ``batch``.

    NOTE(review): the likelihood term is summed over whatever ``batch`` is
    passed in, with no rescaling by dataset size -- confirm this is intended
    when ``batch`` is a minibatch.
    """
    log_prior = prior_fn(params)
    log_likelihood = loglikelihood_fn(params, batch)
    return log_prior + log_likelihood


def batch_data(rng_key, data, batch_size, data_size):
    """Yield an endless stream of random minibatches drawn from ``data``.

    Each minibatch gathers ``batch_size`` entries of ``data`` along its leading
    axis, at indices sampled from ``range(data_size)`` with
    ``jax.random.choice`` (default settings: uniform, with replacement).

    Args:
        rng_key: JAX PRNG key that seeds the stream of batches.
        data: array-like indexed along its first axis.
        batch_size: number of entries per minibatch.
        data_size: exclusive upper bound of the sampled indices.

    Yields:
        An array of shape ``(batch_size, *data.shape[1:])``.
    """
    # Convert once so each batch is a single vectorised gather rather than a
    # Python loop that pulls one row at a time off the device.
    data = jnp.asarray(data)
    candidate_idxs = jnp.arange(data_size)
    while True:
        # Carry one half of the split forward and spend the other on sampling,
        # so no key is ever used both to sample and to derive further keys.
        rng_key, sample_key = jax.random.split(rng_key)
        idxs = jax.random.choice(
            key=sample_key, a=candidate_idxs, shape=(batch_size,)
        )
        yield data[idxs]

In [3]:
# Build the mean-field VI object from the jitted log-joint and an optax SGD
# optimizer with learning rate 1e-3. Later cells use its `.init` and `.step`.
mfvi_ = mfvi.meanfield_vi(
    logjoint_fn, optax.sgd(1e-3)
)

In [None]:
# Set-up for a single-step smoke test of the VI object.

# Synthetic observations: 500 standard-normal draws (fixed seed 1), shifted by
# 10 * 2 = 20.
# NOTE(review): the prior in `prior_fn` is centred at 10 -- confirm the shift
# of 20 is intended.
data = jax.random.normal(jax.random.PRNGKey(1), shape=(500, 1)) + 10 * 2

# Initial variational state, started from position 1.0.
pos = jnp.array([1.0])
mfvi_state = mfvi_.init(pos)
optimizer = optax.sgd(1e-3)  # not consumed in this cell

# `key` is kept for the VI step; `subkey` seeds the minibatch stream.
key, subkey = jax.random.split(jax.random.PRNGKey(123))
batches = batch_data(subkey, data, batch_size=100, data_size=500)

(MFVIState(mu=Array([1.9803663], dtype=float32), rho=Array([-1.8186588], dtype=float32), opt_state=(EmptyState(), EmptyState())),
 MFVIInfo(elbo=Array(19717.18, dtype=float32)),
 Array([3933081256,  374430633], dtype=uint32))

In [None]:
# Draw one minibatch from the generator and run a single VI step on it.
batch = next(batches)

# The trailing 1 is presumably the number of Monte Carlo samples (compare
# `n_samples` in the training loop) -- TODO confirm against mfvi.
# The call is left as the cell's last expression so its result is displayed.
mfvi_.step(key, mfvi_state, batch, 1)

In [None]:
# Full run: re-initialise, take `num_steps` VI steps on random minibatches while
# timing batch creation and the update separately, then plot posterior samples.
key = jax.random.PRNGKey(123)
optimizer = optax.sgd(1e-3)
pos = jnp.array([1.])
# NOTE(review): this cell uses the module-level mfvi.init / mfvi.step, whereas
# the earlier cells go through the object returned by mfvi.meanfield_vi --
# confirm both APIs exist in the installed mfvi.
mfvi_state = mfvi.init(pos, optimizer)


num_steps = 100
n_samples = 5
key, subkey = jax.random.split(key)
# 500 standard-normal draws (fixed seed 1) shifted by 10 * 2 = 20.
data = jax.random.normal(jax.random.PRNGKey(1), shape=(500,1)) + 10 * 2
batches = batch_data(subkey, data, 100, 500)


for _ in range(num_steps):

    # JAX dispatches work asynchronously, so block on the result before reading
    # the clock; otherwise only the dispatch overhead is measured.
    start = time.perf_counter()
    batch = jax.block_until_ready(next(batches))
    end = time.perf_counter()
    print(f"Getting next batch time: {end-start}")

    start = time.perf_counter()
    mfvi_state, mfvi_info, key = mfvi.step(key, mfvi_state, logjoint_fn, optimizer, batch, n_samples)
    jax.block_until_ready(mfvi_state)
    end = time.perf_counter()
    print(f"Doing jax stuff time: {end-start}")
    #print(mfvi_info.elbo)



# Draw 50 samples from the fitted approximation and show their density.
meanfield_params = mfvi_state.mu, mfvi_state.rho
posterior_samples, _ = mfvi.meanfield_sample(meanfield_params, key, 50)
sns.kdeplot(posterior_samples)