In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import initialize_model

from mcmc import run_lmc


jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = dset["t"]
    n, data_dim = x.shape
    n_train = min(int(n / 2), 500)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model():
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        logits = jnp.sum(W * x_train, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels_train)

    return model, (x_train, labels_train, x_test, labels_test)


data = scipy.io.loadmat("mcmc_data/benchmarks.mat")

data_name = "diabetis"
model_logreg, data_split = get_model_and_data(data, data_name)

logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]

    return jnp.concatenate([jnp.expand_dims(alpha, alpha.ndim), dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


@jax.jit
def array_to_dict(arr: jnp.ndarray):
    return {"alpha": arr[0], "W": arr[1:]}


@jax.jit
def potential_fn(arr: jnp.ndarray):
    dct = array_to_dict(arr)
    return logreg_info.potential_fn(dct)


arr0 = dict_to_array(logreg_info.param_info.z)

print(jax.devices("cuda"))


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def eval_logreg(samples, evals_per_sample, ground_truth=None, num_iters_w2=0):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    ess = diagnostics.effective_sample_size(samples)
    print(f"Effective sample size: {ess}")
    # print(
    #     f"Gradient evals per effective sample:"
    #     f" {evals_per_sample * samples.shape[0]/ess}"
    # )
    if ground_truth is None:
        return
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))
    energy_gt = energy_distance(reshaped, ground_truth)
    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str = (
        f"Energy dist vs ground truth: {energy_gt:.4}, vs self: {energy_self:.4}"
    )
    if num_iters_w2 > 0:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"
    print(result_str)

[cuda(id=0)]


In [2]:
gt_nuts = MCMC(NUTS(model_logreg, step_size=2**-3), num_warmup=2**13, num_samples=2**15)
gt_nuts.run(jr.PRNGKey(0))
gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
np.save(f"mcmc_data/{data_name}_ground_truth.npy", gt_logreg)
gt_logreg = np.load(f"mcmc_data/{data_name}_ground_truth.npy")
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias:.4}")

sample:  17%|█▋        | 5783/33792 [27:53<2:15:07,  3.45it/s, 1023 steps of size 1.92e-15. acc. prob=0.92]


KeyboardInterrupt: 

In [18]:
num_chains = 2**7
num_samples_per_chain = 2**8
out_logreg_lmc, steps_logreg_lmc = run_lmc(
    jr.PRNGKey(0),
    potential_fn,
    arr0,
    num_chains,
    num_samples_per_chain,
    chain_sep=0.25,
    tol=0.1,
    warmup_mult=128.0,
    warmup_tol_mult=4.0,
)
print(out_logreg_lmc.shape)

LMC: Steps warmup: 2300.2734375, steps mcmc: 24664.6640625, gradient evaluations per output: 210.66357421875
(128, 256, 9)


In [23]:
eval_logreg(out_logreg_lmc, steps_logreg_lmc, gt_logreg)

Effective sample size: [427.2759 323.3381 323.3381 323.3381 323.3381 323.3381 323.3381 323.3382
 323.3381]:.4
Gradient evals per effective sample: [63.109  83.3955 83.3955 83.3955 83.3955 83.3955 83.3955 83.3955 83.3955]
Energy distance vs ground truth: 1.216e+08, energy distance vs self: 0.007298


In [8]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=2**3,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"]
geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
print(geps_nuts)

sample: 100%|██████████| 264/264 [00:59<00:00,  4.44it/s]


10.700408935546875


In [9]:
eval_logreg(out_logreg_nuts, geps_nuts, gt_logreg)

Effective sample size: [    nan 65.1355     nan]
Gradient evals per effective sample: [    nan 21.0277     nan]
W2 distance: 0.03975801461724085


  check_result(result_code)


In [2]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(data, name)
    logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)
    arr0 = dict_to_array(logreg_info.param_info.z)

    @jax.jit
    def _potential_fn(arr: jnp.ndarray):
        dct = array_to_dict(arr)
        return logreg_info.potential_fn(dct)

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        ground_truth = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=0.125), num_warmup=2**13, num_samples=2**15
        )
        gt_nuts.run(jr.PRNGKey(0))
        ground_truth = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    size_gt_half = int(ground_truth.shape[0] // 2)
    gt_energy_bias = energy_distance(
        ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    )
    print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**7
    num_samples_per_chain = 2**8

    out_logreg_lmc, steps_logreg_lmc = run_lmc(
        jr.PRNGKey(0),
        _potential_fn,
        arr0,
        num_chains,
        num_samples_per_chain,
        chain_sep=0.25,
        tol=0.1,
        warmup_mult=128.0,
        warmup_tol_mult=8.0,
    )
    eval_logreg(out_logreg_lmc, steps_logreg_lmc, ground_truth)

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=2**8,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"]
    geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
    print(f"NUTS: Gradient evals per output: {geps_nuts:.4}")
    eval_logreg(out_logreg_nuts, geps_nuts, ground_truth)

In [3]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [4]:
import warnings


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for name in names:
        print(f"==================== {name} ====================")
        run_logreg_dataset(name)
        print()



sample: 100%|██████████| 40960/40960 [02:09<00:00, 316.86it/s, 7 steps of size 6.03e-01. acc. prob=0.92]


Ground truth energy bias: 0.0002429
LMC: gradient evaluations per output: 337.6
Effective sample size: [ 2413.727  7492.875 20134.305]
Energy dist vs ground truth: 3.438, vs self: 0.0001769


sample: 100%|██████████| 512/512 [01:20<00:00,  6.32it/s]


NUTS: Gradient evals per output: 7.066
Effective sample size: [11042.872 33694.531 22180.266]
Energy dist vs ground truth: 0.0001309, vs self: 0.0004369



sample: 100%|██████████| 40960/40960 [01:55<00:00, 354.57it/s, 1 steps of size 3.90e-08. acc. prob=0.86]  


Ground truth energy bias: 2.925e+08
LMC: gradient evaluations per output: 29.86
Effective sample size: [568.704 323.338 323.339 323.338 323.343 323.341 323.338 323.344 323.299
 323.342]
Energy dist vs ground truth: 9.922e+10, vs self: 0.005908


sample: 100%|██████████| 512/512 [01:56<00:00,  4.38it/s]


NUTS: Gradient evals per output: 42.95
Effective sample size: [64.421    nan 64.125 64.125    nan 64.125 64.125 64.125 64.125    nan]
Energy dist vs ground truth: 2.803e+22, vs self: 1.101e+23



sample: 100%|██████████| 40960/40960 [3:24:13<00:00,  3.34it/s, 1023 steps of size 3.78e-16. acc. prob=0.82]  


Ground truth energy bias: 0.0
LMC: gradient evaluations per output: 71.52
Effective sample size: [572.847 323.338 323.338 323.338 323.338 323.338 323.338 323.338 323.338]
Energy dist vs ground truth: 1.601e+11, vs self: 0.004434


sample: 100%|██████████| 512/512 [04:05<00:00,  2.09it/s]


NUTS: Gradient evals per output: 176.6
Effective sample size: [   nan 64.125 64.125 64.125 64.125    nan 64.125 64.125    nan]
Energy dist vs ground truth: 4.116e+18, vs self: 1.645e+19



sample: 100%|██████████| 40960/40960 [09:19<00:00, 73.15it/s, 63 steps of size 7.83e-02. acc. prob=0.96]


Ground truth energy bias: 0.0001228


jax.pure_callback failed
Traceback (most recent call last):
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/callback.py", line 77, in pure_callback_impl
    return callback(*args)
           ^^^^^^^^^^^^^^^
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/callback.py", line 65, in __call__
    return tree_util.tree_leaves(self.callback_func(*args, **kwargs))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/equinox/_errors.py", line 70, in raises
    raise EqxRuntimeError(msgs[_index.item()])
equinox._errors.EqxRuntimeError: The maximum number of solver steps was reached. Try increasing `max_steps`.
E0426 23:06:31.385251   12784 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", l

ValueError: not enough values to unpack (expected 2, got 1)