In [1]:
import time

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


matplotlib.use("Agg")  # noqa: E402


# a two-layer bayesian neural network with computational flow
# given by D_X => D_H => D_H => D_Y where D_H is the number of
# hidden units. (note we indicate tensor dimensions in the comments)
def model(X, Y, D_H, D_Y=1):
    N, D_X = X.shape

    # sample first layer (we put unit normal priors on all weights)
    w1 = numpyro.sample("w1", dist.Normal(jnp.zeros((D_X, D_H)), jnp.ones((D_X, D_H))))
    assert w1.shape == (D_X, D_H), f"Expected shape {(D_X, D_H)}, got {w1.shape}"
    z1 = jnp.tanh(jnp.matmul(X, w1))  # <= first layer of activations
    assert z1.shape == (N, D_H)

    # sample second layer
    w2 = numpyro.sample("w2", dist.Normal(jnp.zeros((D_H, D_H)), jnp.ones((D_H, D_H))))
    assert w2.shape == (D_H, D_H)
    z2 = jnp.tanh(jnp.matmul(z1, w2))  # <= second layer of activations
    assert z2.shape == (N, D_H)

    # sample final layer of weights and neural network output
    w3 = numpyro.sample("w3", dist.Normal(jnp.zeros((D_H, D_Y)), jnp.ones((D_H, D_Y))))
    assert w3.shape == (D_H, D_Y)
    z3 = jnp.matmul(z2, w3)  # <= output of the neural network
    assert z3.shape == (N, D_Y)

    if Y is not None:
        assert z3.shape == Y.shape

    # we put a prior on the observation noise
    prec_obs = numpyro.sample("prec_obs", dist.Gamma(3.0, 1.0))
    sigma_obs = 1.0 / jnp.sqrt(prec_obs)

    # observe data
    with numpyro.plate("data", N):
        # note we use to_event(1) because each observation has shape (1,)
        numpyro.sample("Y", dist.Normal(z3, sigma_obs).to_event(1), obs=Y)


# helper function for HMC inference
def run_inference(model, args, rng_key, X, Y, D_H):
    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(
        kernel,
        num_warmup=args.num_warmup,
        num_samples=args.num_samples,
        num_chains=args.num_chains,
        thinning=args.thinning,
        chain_method="vectorized",
        progress_bar=True,
    )
    mcmc.run(rng_key, X, Y, D_H, extra_fields=("num_steps",))
    elapsed = time.time() - start
    num_steps = mcmc.get_extra_fields()["num_steps"]
    geps = sum(num_steps) / (args.num_samples * args.num_chains)
    mcmc.print_summary()
    print(f"\nNUTS elapsed time: {elapsed:.4}, grad evals per sample: {geps:.4}")
    return mcmc.get_samples(group_by_chain=True)


def run_inference_lmc(model, args, rng_key, X, Y, D_H):
    n, chains, warmup = args.num_samples, args.num_chains, args.num_warmup
    start = time.time()
    samples, geps = run_lmc_numpyro(
        rng_key,
        model,
        (X, Y, D_H),
        num_particles=chains,
        chain_len=n,
        warmup_mult=warmup,
        tol=args.tol,
        chain_sep=0.25 * args.thinning,
        warmup_tol_mult=64.0,
    )
    elapsed = time.time() - start
    print(f"LMC elapsed time: {elapsed:.4}, grad evals per sample: {geps:.4}")
    return samples


# helper function for prediction
def predict(model, rng_key, samples, X, D_H):
    model = handlers.substitute(handlers.seed(model, rng_key), samples)
    # note that Y will be sampled in the model because we pass Y=None here
    model_trace = handlers.trace(model).get_trace(X=X, Y=None, D_H=D_H)
    return model_trace["Y"]["value"]


# create artificial regression dataset
def get_data(N=50, D_X=3, sigma_obs=0.05, N_test=500):
    D_Y = 1  # create 1d outputs
    np.random.seed(0)
    X = jnp.linspace(-1, 1, N)
    X = jnp.power(X[:, np.newaxis], jnp.arange(D_X))
    W = 0.5 * np.random.randn(D_X)
    Y = jnp.dot(X, W) + 0.5 * jnp.power(0.5 + X[:, 1], 2.0) * jnp.sin(4.0 * X[:, 1])
    Y += sigma_obs * np.random.randn(N)
    Y = Y[:, np.newaxis]
    Y -= jnp.mean(Y)
    Y /= jnp.std(Y)

    assert X.shape == (N, D_X)
    assert Y.shape == (N, D_Y)

    X_test = jnp.linspace(-1.3, 1.3, N_test)
    X_test = jnp.power(X_test[:, np.newaxis], jnp.arange(D_X))

    return X, Y, X_test


def plot(predictions, X, Y, X_test, use_lmc=False):
    # compute mean prediction and confidence interval around median
    mean_prediction = jnp.mean(predictions, axis=0)
    percentiles = np.percentile(predictions, [5.0, 95.0], axis=0)

    # make plots
    fig, ax = plt.subplots(figsize=(8, 6), constrained_layout=True)

    # plot training data
    ax.plot(X[:, 1], Y[:, 0], "kx")
    # plot 90% confidence level of predictions
    ax.fill_between(
        X_test[:, 1], percentiles[0, :], percentiles[1, :], color="lightblue"
    )
    # plot mean prediction
    ax.plot(X_test[:, 1], mean_prediction, "blue", ls="solid", lw=2.0)
    ax.set(xlabel="X", ylabel="Y", title="Mean predictions with 90% CI")

    filename = "bnn_plot_lmc.pdf" if use_lmc else "bnn_plot.pdf"
    plt.savefig(filename)


def main(args, use_lmc=False):
    N, D_X, D_H = args.num_data, 3, args.num_hidden
    X, Y, X_test = get_data(N=N, D_X=D_X)
    num_samples, num_chains = args.num_samples, args.num_chains

    # do inference
    rng_key, rng_key_predict = jr.split(jr.PRNGKey(0))

    inference_fun = run_inference_lmc if use_lmc else run_inference
    samples = inference_fun(model, args, rng_key, X, Y, D_H)

    assert jtu.tree_all(
        jtu.tree_map(lambda x: x.shape[:2] == (num_chains, num_samples), samples)
    ), f"Expected shape (num_chains, num_samples, ...) for all samples, got {jtu.tree_map(lambda x: x.shape, samples)}"
    flat_samples = jtu.tree_map(
        lambda x: jnp.reshape(x, (num_samples * num_chains,) + x.shape[2:]), samples
    )

    # predict Y_test at inputs X_test
    vmap_args = (
        flat_samples,
        jr.split(rng_key_predict, num_samples * num_chains),
    )
    predictions = jax.jit(
        jax.vmap(
            lambda _samples, _rng_key: predict(model, _rng_key, _samples, X_test, D_H)
        )
    )(*vmap_args)
    prediction = predictions[..., 0]

    plot(prediction, X, Y, X_test, use_lmc=use_lmc)

    return samples, predictions


class Args:
    def __init__(
        self,
        num_data=50,
        num_hidden=5,
        num_samples=2**9,
        num_warmup=128,
        num_chains=2**6,
        tol=0.5,
        thinning=1,
    ):
        self.num_data = num_data
        self.num_hidden = num_hidden
        self.num_samples = num_samples
        self.num_warmup = num_warmup
        self.num_chains = num_chains
        self.tol = tol
        self.thinning = thinning


def save_samples(samples, filename):
    with open(filename, "wb") as f:
        jnp.savez(f, **samples)


def load_samples(filename):
    with open(filename, "rb") as f:
        npz = jnp.load(f)
        samples = {k: npz[k] for k in npz.keys()}

    return samples

In [26]:
args = Args(
    num_data=50,
    num_hidden=5,
    num_samples=2**9,
    num_warmup=2048,
    num_chains=2**6,
    tol=0.1,
)
samples_lmc, predictions_lmc = main(args, use_lmc=True)
save_samples(samples_lmc, "samples_lmc.npz")

100.00%|██████████| [04:01<00:00,  2.42s/%]
100.00%|██████████| [01:52<00:00,  1.12s/%]


LMC: gradient evaluations per output: 189.8
LMC elapsed time: 383.2, grad evals per sample: 189.8


In [27]:
args = Args(
    num_data=50, num_hidden=5, num_samples=2**9, num_warmup=2048, num_chains=2**6
)
samples_nuts, predictions_nuts = main(args)
save_samples(samples_nuts, "samples_nuts.npz")

sample: 100%|██████████| 2560/2560 [09:54<00:00,  4.30it/s]



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prec_obs     11.55      2.39     11.38      7.65     15.39  29352.87      1.00
   w1[0,0]     -0.00      1.15      0.01     -1.88      1.88   4047.20      1.01
   w1[0,1]     -0.01      1.15     -0.01     -1.94      1.87   3961.83      1.01
   w1[0,2]     -0.02      1.15     -0.03     -1.89      1.89   3310.26      1.02
   w1[0,3]      0.01      1.17      0.00     -1.87      1.95   4012.14      1.01
   w1[0,4]      0.03      1.17      0.06     -1.88      1.92   3626.19      1.02
   w1[1,0]     -0.00      1.14     -0.00     -1.84      1.73   3976.57      1.01
   w1[1,1]      0.02      1.15      0.05     -1.79      1.83   3482.02      1.02
   w1[1,2]      0.01      1.15      0.01     -1.80      1.78   3428.00      1.02
   w1[1,3]     -0.01      1.15     -0.02     -1.79      1.82   3682.64      1.02
   w1[1,4]     -0.00      1.14      0.01     -1.79      1.79   3704.13      1.01
   w1[2,0]     -0.00      1

In [2]:
args = Args(
    num_data=50,
    num_hidden=5,
    num_samples=2**16,
    num_warmup=2048,
    num_chains=2**6,
    thinning=8,
)
samples_nuts_precise, predictions_nuts_precise = main(args)
save_samples(samples_nuts_precise, "samples_nuts.npz")

sample: 100%|██████████| 67584/67584 [4:18:55<00:00,  4.35it/s]  



                mean       std    median      5.0%     95.0%     n_eff     r_hat
  prec_obs     11.54      2.39     11.38      7.61     15.35 477243.52      1.00
   w1[0,0]      0.00      1.16      0.01     -1.87      1.93 342172.61      1.00
   w1[0,1]      0.00      1.16      0.00     -1.88      1.92 342279.81      1.00
   w1[0,2]     -0.00      1.16     -0.00     -1.89      1.90 332054.52      1.00
   w1[0,3]     -0.00      1.16     -0.00     -1.90      1.89 345202.48      1.00
   w1[0,4]     -0.00      1.16     -0.00     -1.90      1.90 338795.20      1.00
   w1[1,0]     -0.00      1.14     -0.01     -1.81      1.78 316801.26      1.00
   w1[1,1]     -0.00      1.14     -0.01     -1.80      1.79 321233.90      1.00
   w1[1,2]      0.00      1.15      0.00     -1.79      1.80 312331.88      1.00
   w1[1,3]     -0.00      1.14     -0.00     -1.79      1.80 320548.45      1.00
   w1[1,4]      0.00      1.15     -0.00     -1.80      1.79 319182.07      1.00
   w1[2,0]     -0.01      1

AssertionError: Expected shape (num_chains, num_samples, ...) for all samples, got {'prec_obs': (64, 8192), 'w1': (64, 8192, 3, 5), 'w2': (64, 8192, 5, 5), 'w3': (64, 8192, 5, 1)}

In [None]:
plot(predictions_nuts_precise[..., 0], X, Y, X_test)

In [29]:
from numpyro import diagnostics


samples_lmc = load_samples("samples_lmc.npz")

ess_lmc = jtu.tree_map(
    lambda x: diagnostics.effective_sample_size(x) / x.shape[1], samples_lmc
)
print(ess_lmc)

samples_nuts = load_samples("samples_nuts.npz")
ess_nuts = jtu.tree_map(
    lambda x: diagnostics.effective_sample_size(x) / x.shape[1], samples_nuts
)
print(ess_nuts)

{'prec_obs': 35.47085303911096, 'w1': array([[0.96547994, 0.94590134, 0.88050268, 0.7010292 , 1.08598659],
       [1.09277597, 0.88502356, 1.02323194, 0.75362109, 0.87642086],
       [1.1048426 , 0.75331447, 0.90245502, 0.69098364, 0.75179445]]), 'w2': array([[3.81481028, 3.35927273, 3.2468056 , 3.58500049, 3.63945121],
       [3.33703389, 3.24539524, 3.79352066, 3.47304621, 3.61056228],
       [3.53048479, 3.88395913, 3.55241667, 3.50269356, 3.30924641],
       [2.79644306, 3.1397997 , 3.29180351, 2.99984435, 3.58892501],
       [3.54133879, 3.34176407, 3.65324352, 3.29120513, 3.28289685]]), 'w3': array([[1.02754679],
       [1.01293591],
       [1.13488737],
       [1.2329998 ],
       [1.09866073]])}
{'prec_obs': 57.329830620977035, 'w1': array([[7.90468294, 7.73795397, 6.46535417, 7.83621051, 7.08239748],
       [7.76674254, 6.80082134, 6.69530391, 7.19266347, 7.23463372],
       [7.17681314, 7.09333939, 6.24567394, 7.43929542, 7.02708214]]), 'w2': array([[32.23834754, 29.00150401,