In [1]:
import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import ot
import scipy
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import initialize_model

from mcmc import run_lmc


jnp.set_printoptions(precision=4, suppress=True)
jax.config.update("jax_enable_x64", True)


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = dset["t"]
    n, data_dim = x.shape
    n_train = min(int(n / 2), 500)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model():
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        logits = jnp.sum(W * x_train, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels_train)

    return model, (x_train, labels_train, x_test, labels_test)


data = scipy.io.loadmat("mcmc_data/benchmarks.mat")

model_logreg, data_split = get_model_and_data(data, "banana")

logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]

    return jnp.concatenate([jnp.expand_dims(alpha, alpha.ndim), dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


@jax.jit
def array_to_dict(arr: jnp.ndarray):
    return {"alpha": arr[0], "W": arr[1:]}


@jax.jit
def potential_fn(arr: jnp.ndarray):
    dct = array_to_dict(arr)
    return logreg_info.potential_fn(dct)


arr0 = dict_to_array(logreg_info.param_info.z)

print(jax.devices("cuda"))


def compute_w2(x1, x2):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=1e5)


def eval_logreg(samples, evals_per_sample, ground_truth=None):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    ess = diagnostics.effective_sample_size(samples)
    print(f"Effective sample size: {ess}")
    print(
        f"Gradient evals per effective sample:"
        f" {evals_per_sample * samples.shape[0]/ess}"
    )
    if ground_truth is None:
        return
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))

    w2 = compute_w2(reshaped, ground_truth)
    print(f"W2 distance: {w2}")

[cuda(id=0)]


In [6]:
# gt_nuts = MCMC(NUTS(model_logreg, step_size=0.125),
#                num_warmup=2**8, num_samples=2**14)
# gt_nuts.run(jr.PRNGKey(0))
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# np.save("mcmc_data/banana_ground_truth.npy", gt_logreg)
gt_logreg = np.load("mcmc_data/banana_ground_truth.npy")

In [6]:
num_chains = 2**7
num_samples_per_chain = 2**8
out_logreg_lmc, steps_logreg_lmc = run_lmc(
    jr.PRNGKey(0),
    potential_fn,
    arr0,
    num_chains,
    num_samples_per_chain,
    0.1,
    5.0,
    warmup_mult=32.0,
    warmup_tol_mult=8.0,
)
print(out_logreg_lmc.shape)

Steps warmup: 32.5078125, steps mcmc: 883.0078125, gradient evaluations per output: 7.1524658203125
(128, 256, 10)


In [7]:
eval_logreg(out_logreg_lmc, steps_logreg_lmc, gt_logreg)

Effective sample size: [  749.9436  5597.1612 14985.1067]
Gradient evals per effective sample: [12.8984  1.7282  0.6455]
W2 distance: 0.02424249648306309


  check_result(result_code)


In [8]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=2**3,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"]
geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
print(geps_nuts)

sample: 100%|██████████| 264/264 [00:59<00:00,  4.44it/s]


10.700408935546875


In [9]:
eval_logreg(out_logreg_nuts, geps_nuts, gt_logreg)

Effective sample size: [    nan 65.1355     nan]
Gradient evals per effective sample: [    nan 21.0277     nan]
W2 distance: 0.03975801461724085


  check_result(result_code)


In [2]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(data, name)
    logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)
    arr0 = dict_to_array(logreg_info.param_info.z)

    @jax.jit
    def _potential_fn(arr: jnp.ndarray):
        dct = array_to_dict(arr)
        return logreg_info.potential_fn(dct)

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        gt_logreg = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=0.125), num_warmup=2**8, num_samples=2**14
        )
        gt_nuts.run(jr.PRNGKey(0))
        gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", gt_logreg)

    num_chains = 2**7
    num_samples_per_chain = 2**8

    print("LMC")
    out_logreg_lmc, steps_logreg_lmc = run_lmc(
        jr.PRNGKey(0),
        _potential_fn,
        arr0,
        num_chains,
        num_samples_per_chain,
        0.1,
        5.0,
        warmup_mult=32.0,
        warmup_tol_mult=16.0,
    )
    eval_logreg(out_logreg_lmc, steps_logreg_lmc, gt_logreg)

    print("NUTS")
    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=2**6,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"]
    geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
    print(f"Gradient evals per effective sample (NUTS): {geps_nuts}")
    eval_logreg(out_logreg_nuts, geps_nuts, gt_logreg)

    print("LMC")

In [3]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [None]:
import warnings


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for name in names:
        print(f"==================== {name} ====================")
        run_logreg_dataset(name)

Steps warmup: 25.6640625, steps mcmc: 910.875, gradient evaluations per output: 7.31671142578125


sample: 100%|██████████| 320/320 [00:32<00:00,  9.75it/s]
sample: 100%|██████████| 16640/16640 [00:56<00:00, 295.60it/s, 3 steps of size 8.17e-06. acc. prob=0.83]   


NUTS
Effective sample size: [    nan     nan 65.2196     nan     nan     nan     nan     nan     nan
     nan]
Gradient evals per effective sample: [   nan    nan 17.012    nan    nan    nan    nan    nan    nan    nan]
