In [11]:
from functools import partial
import jax
import jax.numpy as jnp
import optax

LOG_2PI = jnp.log(2.0 * jnp.pi)

def kl_diag_standard_normal(mu, log_sigma):
    sigma2 = jnp.exp(2.0 * log_sigma)
    return 0.5 * jnp.sum(sigma2 + mu**2 - 2.0 * log_sigma - 1.0)

def mc_neg_elbo(key, mu, log_sigma, num_mc: int):
    d = mu.shape[0]
    eps = jax.random.normal(key, shape=(num_mc, d))
    sigma = jnp.exp(log_sigma)
    z = mu[None, :] + sigma[None, :] * eps

    logp = -0.5 * jnp.sum(z**2, axis=-1) - 0.5 * d * LOG_2PI
    logq = -0.5 * jnp.sum(eps**2, axis=-1) - jnp.sum(log_sigma) - 0.5 * d * LOG_2PI
    return -jnp.mean(logp - logq)

@partial(jax.jit, static_argnames=("optimizer",))
def step_exact(opt_state, params, optimizer):
    mu, log_sigma = params
    loss, grads = jax.value_and_grad(kl_diag_standard_normal, argnums=(0, 1))(mu, log_sigma)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return opt_state, params, loss

@partial(jax.jit, static_argnames=("optimizer", "num_mc"))
def step_mc(key, opt_state, params, optimizer, num_mc: int):
    mu, log_sigma = params
    key, subkey = jax.random.split(key)

    def loss_fn(mu_, log_sigma_):
        return mc_neg_elbo(subkey, mu_, log_sigma_, num_mc)

    loss, grads = jax.value_and_grad(loss_fn, argnums=(0, 1))(mu, log_sigma)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return key, opt_state, params, loss


def run(mode: str = "mc", seed: int = 0, d: int = 100, steps: int = 2000,
        lr: float = 1e-2, num_mc: int = 8, init_log_sigma: float = -2.0, log_every: int = 200):
    """
    mode: "exact" or "mc"
    """
    key = jax.random.key(seed)

    mu = jnp.zeros((d,))
    log_sigma = jnp.ones((d,)) * init_log_sigma
    params = (mu, log_sigma)

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    means = []
    stds = []
    ts = []
    for t in range(1, steps + 1):
        if mode == "exact":
            opt_state, params, loss = step_exact(opt_state, params, optimizer)
        elif mode == "mc":
            key, opt_state, params, loss = step_mc(key, opt_state, params, optimizer, num_mc)
        else:
            raise ValueError("mode must be 'exact' or 'mc'")

        if (t % log_every) == 0 or t == 1:
            mu, log_sigma = params
            sigma = jnp.exp(log_sigma)
            kl = kl_diag_standard_normal(mu, log_sigma)
            print(
                f"t={t:5d}  loss={float(loss): .6f}  KL={float(kl): .6f}  "
                f"|mu|={float(jnp.linalg.norm(mu)):.4f}  "
                f"mean(sigma)={float(jnp.mean(sigma)):.4f}  std(sigma)={float(jnp.std(sigma)):.4f}  "
                f"min(sigma)={float(jnp.min(sigma)):.4e}"
            )
            means.append(jnp.mean(sigma))
            stds.append(jnp.std(sigma))
            ts.append(t)

    return params, means, stds, ts

In [8]:
# Sanity: exact optimization should nail sigma ~ 1 quickly from almost any init.
run(mode="exact", seed=0, d=100, steps=2000, lr=5e-2, init_log_sigma=-3.0, log_every=200)

t=    1  loss= 250.123932  KL= 245.136993  |mu|=0.0000  mean(sigma)=0.0523  std(sigma)=0.0000  min(sigma)=5.2340e-02
t=  200  loss= 0.000006  KL= 0.000006  |mu|=0.0000  mean(sigma)=0.9998  std(sigma)=0.0000  min(sigma)=9.9983e-01
t=  400  loss= 0.000000  KL= 0.000000  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t=  600  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t=  800  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t= 1000  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t= 1200  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t= 1400  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  min(sigma)=1.0000e+00
t= 1600  loss=-0.000003  KL=-0.000003  |mu|=0.0000  mean(sigma)=1.0000  std(sigma)=0.0000  m

(Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32),
 Array([1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1.6452576e-08,
        1.6452576e-08, 1.6452576e-08, 1.6452576e-08, 1

In [None]:
# MC: should work with reasonable settings.

run(mode="mc", seed=0, d=100, steps=10000, lr=5e-2, num_mc=32, init_log_sigma=-3.0, log_every=1000)

t=    1  loss= 247.486664  KL= 245.261993  |mu|=0.5000  mean(sigma)=0.0523  std(sigma)=0.0000  min(sigma)=5.2340e-02
t= 1000  loss= 0.192264  KL= 0.468204  |mu|=0.5972  mean(sigma)=0.9942  std(sigma)=0.0534  min(sigma)=8.6873e-01
t= 2000  loss= 0.876277  KL= 0.541946  |mu|=0.7078  mean(sigma)=1.0031  std(sigma)=0.0542  min(sigma)=8.9238e-01
t= 3000  loss= 0.530626  KL= 0.421085  |mu|=0.5891  mean(sigma)=0.9875  std(sigma)=0.0478  min(sigma)=8.7415e-01
t= 4000  loss= 0.397608  KL= 0.446191  |mu|=0.6177  mean(sigma)=1.0044  std(sigma)=0.0503  min(sigma)=8.5669e-01
t= 5000  loss= 0.617996  KL= 0.552384  |mu|=0.6356  mean(sigma)=1.0020  std(sigma)=0.0593  min(sigma)=8.6286e-01
t= 6000  loss= 0.483220  KL= 0.488588  |mu|=0.6183  mean(sigma)=0.9901  std(sigma)=0.0534  min(sigma)=8.5783e-01
t= 7000  loss= 0.482189  KL= 0.475805  |mu|=0.6490  mean(sigma)=0.9951  std(sigma)=0.0512  min(sigma)=8.8723e-01
t= 8000  loss= 0.461131  KL= 0.519379  |mu|=0.6794  mean(sigma)=1.0065  std(sigma)=0.0536  m

(Array([ 0.05222048, -0.09547757, -0.08973138, -0.00486247, -0.06041062,
        -0.05572452,  0.00952461, -0.00474449,  0.01430413, -0.01745423,
         0.00565786,  0.08210675,  0.02284317,  0.05148947, -0.055523  ,
         0.07317773,  0.00569338,  0.01393594,  0.01632994,  0.10512311,
         0.0303864 ,  0.08212869, -0.1063923 ,  0.08140229, -0.05069764,
        -0.01297907, -0.07561885, -0.08708009,  0.08924918,  0.05945631,
         0.0725781 , -0.01374277,  0.01198772,  0.01489643, -0.01569547,
         0.08592975,  0.00278923, -0.05594056, -0.04632309,  0.12979361,
         0.01708827,  0.13152191, -0.02218023, -0.03593712, -0.02776361,
        -0.02619023,  0.02860598, -0.11970337, -0.04563605,  0.10448401,
         0.00321919, -0.04123145, -0.02309494,  0.02316009, -0.05259848,
         0.03408407,  0.00239959, -0.03576782,  0.01027049, -0.07080129,
         0.03262068, -0.05093395,  0.03949349,  0.03585717, -0.12591767,
        -0.06085879, -0.09696774, -0.03278483, -0.0

In [6]:
# MC stress test: tiny num_mc + big lr can produce bad dynamics (often looks like “collapse”).
run(mode="mc", seed=0, d=100, steps=4000, lr=2e-1, num_mc=1, init_log_sigma=-3.0, log_every=200)


t=    1  loss= 246.082306  KL= 232.184998  |mu|=2.0000  mean(sigma)=0.0608  std(sigma)=0.0000  min(sigma)=6.0810e-02
t=  200  loss= 8.199280  KL= 11.450817  |mu|=3.0370  mean(sigma)=0.9815  std(sigma)=0.2500  min(sigma)=3.7768e-01
t=  400  loss= 8.570023  KL= 12.263913  |mu|=3.1032  mean(sigma)=0.9925  std(sigma)=0.2690  min(sigma)=3.8712e-01
t=  600  loss= 4.533768  KL= 11.666977  |mu|=3.2454  mean(sigma)=0.9606  std(sigma)=0.2394  min(sigma)=3.3763e-01
t=  800  loss= 8.456177  KL= 12.849049  |mu|=3.1989  mean(sigma)=1.0054  std(sigma)=0.2809  min(sigma)=5.6365e-01
t= 1000  loss= 18.115662  KL= 16.669365  |mu|=3.3841  mean(sigma)=1.0090  std(sigma)=0.3383  min(sigma)=3.6990e-01
t= 1200  loss= 24.594284  KL= 15.183155  |mu|=2.8522  mean(sigma)=0.9744  std(sigma)=0.3180  min(sigma)=2.1452e-01
t= 1400  loss= 14.525162  KL= 12.611357  |mu|=2.9606  mean(sigma)=0.9385  std(sigma)=0.2665  min(sigma)=2.1321e-01
t= 1600  loss= 17.679977  KL= 13.534353  |mu|=2.8683  mean(sigma)=0.9475  std(sigm

(Array([ 0.26047188, -0.37172276,  0.00252672, -0.19129899, -0.20676397,
        -0.7970484 , -0.29160658, -0.06943002, -0.44504952, -0.08637094,
         0.18900146, -0.20127475, -0.38819084,  0.07795054, -0.4408967 ,
        -0.14589953, -0.04657221, -0.06250014,  0.3746062 , -0.34293354,
        -0.21741675, -0.17951053,  0.22295861, -0.4065317 ,  0.00509515,
        -0.00165432, -0.5605479 , -0.64103186, -0.29320174,  0.28177407,
        -0.44035345, -0.06821295,  0.6495677 , -0.2204929 , -0.3080792 ,
        -0.00634776,  0.09931718, -0.05755906, -0.20617934,  0.3373213 ,
         0.0447189 ,  0.22640753, -0.13294555, -0.7889059 ,  0.01509628,
        -0.23003182,  0.16644815,  0.1881523 , -0.17196876, -0.23686506,
        -0.05457807,  0.22352776, -0.13154888, -0.09709449, -0.16866939,
         0.1696515 ,  0.29132098, -0.7004078 ,  0.04433437, -0.40666303,
         0.18765667,  0.4787695 ,  0.16759159,  0.11209597, -0.436815  ,
         0.29356977, -0.24123292, -0.40653935,  0.0