In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import initialize_model

from mcmc import run_lmc_numpyro


jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = dset["t"]
    n, data_dim = x.shape
    n_train = min(int(n / 2), 500)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model():
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        logits = jnp.sum(W * x_train, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels_train)

    return model, (x_train, labels_train, x_test, labels_test)


data = scipy.io.loadmat("mcmc_data/benchmarks.mat")

data_name = "diabetis"
model_logreg, data_split = get_model_and_data(data, data_name)

model = initialize_model(jr.PRNGKey(0), model_logreg)

print(jax.devices("cuda"))


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]
    return jnp.concatenate([jnp.expand_dims(alpha, alpha.ndim), dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


def eval_logreg(samples, evals_per_sample, ground_truth=None, num_iters_w2=0):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    ess = diagnostics.effective_sample_size(samples)
    print(f"Effective sample size: {ess}")
    # print(
    #     f"Gradient evals per effective sample:"
    #     f" {evals_per_sample * samples.shape[0]/ess}"
    # )
    if ground_truth is None:
        return
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))
    energy_gt = energy_distance(reshaped, ground_truth)
    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str = (
        f"Energy dist vs ground truth: {energy_gt:.4}, vs self: {energy_self:.4}"
    )
    if num_iters_w2 > 0:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"
    print(result_str)

[cuda(id=0)]


In [2]:
# gt_nuts = MCMC(NUTS(model_logreg, step_size=2**-3),
#                num_warmup=2**13, num_samples=2**15)
# gt_nuts.run(jr.PRNGKey(0))
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# np.save(f"mcmc_data/{data_name}_ground_truth.npy", gt_logreg)
gt_logreg = np.load(f"mcmc_data/{data_name}_ground_truth.npy")
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")

Energy bias: 0.0
Ground truth shape: (32768, 9)


In [3]:
num_chains = 2**6
num_samples_per_chain = 2**7
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(0),
    model,
    num_chains,
    num_samples_per_chain,
    chain_sep=0.25,
    tol=0.1,
    warmup_mult=128.0,
    warmup_tol_mult=4.0,
)
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

LMC: gradient evaluations per output: 92.6
{'W': (64, 128, 8), 'alpha': (64, 128)}


In [4]:
eval_logreg(out_logreg_lmc, steps_logreg_lmc, gt_logreg)

Effective sample size: [668.592 162.907 162.907 162.907 162.907 162.907 162.907 162.907 162.907]
Energy dist vs ground truth: 1.601e+11, vs self: 0.01475


In [5]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=2**3,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"]
geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
print(geps_nuts)

sample: 100%|██████████| 136/136 [00:15<00:00,  8.79it/s]


11.0029296875


In [6]:
eval_logreg(out_logreg_nuts, geps_nuts, gt_logreg)

Effective sample size: [567.858  62.241  62.236  62.244  62.235  62.234  62.245  62.256  62.24 ]
Energy dist vs ground truth: 1.601e+11, vs self: 2.425e+03


In [2]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(data, name)
    model = initialize_model(jr.PRNGKey(0), model_logreg)

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        ground_truth = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=0.125), num_warmup=2**13, num_samples=2**15
        )
        gt_nuts.run(jr.PRNGKey(0))
        ground_truth = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    size_gt_half = int(ground_truth.shape[0] // 2)
    gt_energy_bias = energy_distance(
        ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    )
    print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**7
    num_samples_per_chain = 2**8

    out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
        jr.PRNGKey(0),
        model,
        num_chains,
        num_samples_per_chain,
        chain_sep=0.25,
        tol=0.1,
        warmup_mult=128.0,
        warmup_tol_mult=8.0,
    )
    eval_logreg(out_logreg_lmc, steps_logreg_lmc, ground_truth)

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=2**8,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"]
    geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
    print(f"NUTS: Gradient evals per output: {geps_nuts:.4}")
    eval_logreg(out_logreg_nuts, geps_nuts, ground_truth)

In [3]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [4]:
import warnings


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for name in names:
        print(f"==================== {name} ====================")
        run_logreg_dataset(name)
        print()