In [1]:
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS
from numpyro.infer.util import initialize_model

from mcmc import run_lmc


jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = dset["t"]
    n, data_dim = x.shape
    n_train = min(int(n / 2), 500)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model():
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        logits = jnp.sum(W * x_train, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels_train)

    return model, (x_train, labels_train, x_test, labels_test)


data = scipy.io.loadmat("mcmc_data/benchmarks.mat")

data_name = "diabetis"
model_logreg, data_split = get_model_and_data(data, data_name)

logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]

    return jnp.concatenate([jnp.expand_dims(alpha, alpha.ndim), dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


@jax.jit
def array_to_dict(arr: jnp.ndarray):
    return {"alpha": arr[0], "W": arr[1:]}


@jax.jit
def potential_fn(arr: jnp.ndarray):
    dct = array_to_dict(arr)
    return logreg_info.potential_fn(dct)


arr0 = dict_to_array(logreg_info.param_info.z)

print(jax.devices("cuda"))


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def eval_logreg(samples, evals_per_sample, ground_truth=None, num_iters_w2=0):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    ess = diagnostics.effective_sample_size(samples)
    print(f"Effective sample size: {ess}")
    # print(
    #     f"Gradient evals per effective sample:"
    #     f" {evals_per_sample * samples.shape[0]/ess}"
    # )
    if ground_truth is None:
        return
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))
    energy_gt = energy_distance(reshaped, ground_truth)
    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str = (
        f"Energy dist vs ground truth: {energy_gt:.4}, vs self: {energy_self:.4}"
    )
    if num_iters_w2 > 0:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"
    print(result_str)

[cuda(id=0)]


In [2]:
gt_nuts = MCMC(NUTS(model_logreg, step_size=2**-4), num_warmup=2**10, num_samples=2**15)
gt_nuts.run(jr.PRNGKey(0))
gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
np.save(f"mcmc_data/{data_name}_ground_truth.npy", gt_logreg)
gt_logreg = np.load(f"mcmc_data/{data_name}_ground_truth.npy")
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias:.4}")

sample:  17%|█▋        | 5783/33792 [27:53<2:15:07,  3.45it/s, 1023 steps of size 1.92e-15. acc. prob=0.92]


KeyboardInterrupt: 

In [18]:
num_chains = 2**7
num_samples_per_chain = 2**8
out_logreg_lmc, steps_logreg_lmc = run_lmc(
    jr.PRNGKey(0),
    potential_fn,
    arr0,
    num_chains,
    num_samples_per_chain,
    chain_sep=0.25,
    tol=0.1,
    warmup_mult=128.0,
    warmup_tol_mult=4.0,
)
print(out_logreg_lmc.shape)

LMC: Steps warmup: 2300.2734375, steps mcmc: 24664.6640625, gradient evaluations per output: 210.66357421875
(128, 256, 9)


In [23]:
eval_logreg(out_logreg_lmc, steps_logreg_lmc, gt_logreg)

Effective sample size: [427.2759 323.3381 323.3381 323.3381 323.3381 323.3381 323.3381 323.3382
 323.3381]:.4
Gradient evals per effective sample: [63.109  83.3955 83.3955 83.3955 83.3955 83.3955 83.3955 83.3955 83.3955]
Energy distance vs ground truth: 1.216e+08, energy distance vs self: 0.007298


In [8]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=2**3,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"]
geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
print(geps_nuts)

sample: 100%|██████████| 264/264 [00:59<00:00,  4.44it/s]


10.700408935546875


In [9]:
eval_logreg(out_logreg_nuts, geps_nuts, gt_logreg)

Effective sample size: [    nan 65.1355     nan]
Gradient evals per effective sample: [    nan 21.0277     nan]
W2 distance: 0.03975801461724085


  check_result(result_code)


In [2]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(data, name)
    logreg_info = initialize_model(jr.PRNGKey(0), model_logreg)
    arr0 = dict_to_array(logreg_info.param_info.z)

    @jax.jit
    def _potential_fn(arr: jnp.ndarray):
        dct = array_to_dict(arr)
        return logreg_info.potential_fn(dct)

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        ground_truth = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=0.125), num_warmup=2**13, num_samples=2**15
        )
        gt_nuts.run(jr.PRNGKey(0))
        ground_truth = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    size_gt_half = int(ground_truth.shape[0] // 2)
    gt_energy_bias = energy_distance(
        ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    )
    print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**7
    num_samples_per_chain = 2**8

    out_logreg_lmc, steps_logreg_lmc = run_lmc(
        jr.PRNGKey(0),
        _potential_fn,
        arr0,
        num_chains,
        num_samples_per_chain,
        chain_sep=0.25,
        tol=0.1,
        warmup_mult=128.0,
        warmup_tol_mult=8.0,
    )
    eval_logreg(out_logreg_lmc, steps_logreg_lmc, ground_truth)

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=2**8,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.run(jr.PRNGKey(0), extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"]
    geps_nuts = sum(num_steps_nuts) / len(num_steps_nuts)
    print(f"NUTS: Gradient evals per output: {geps_nuts}")
    eval_logreg(out_logreg_nuts, geps_nuts, ground_truth)

In [3]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [4]:
import warnings


with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for name in names:
        print(f"==================== {name} ====================")
        run_logreg_dataset(name)

LMC
LMC: Steps warmup: 600.5859375, steps mcmc: 9082.40625, gradient evaluations per output: 75.64837646484375
Effective sample size: [  528.9647  3286.149  13097.2888]
Gradient evals per effective sample: [18.3056  2.9466  0.7393]
W2 distance: 0.023220977744452573
NUTS


sample: 100%|██████████| 320/320 [00:57<00:00,  5.56it/s]


Gradient evals per effective sample (NUTS): 13.71722412109375
Effective sample size: [ 6194.9014 27028.3926 24080.6835]
Gradient evals per effective sample: [0.2834 0.065  0.0729]
W2 distance: 0.00028803654533254326
LMC
LMC
LMC: Steps warmup: 25.6640625, steps mcmc: 910.875, gradient evaluations per output: 7.31671142578125
Effective sample size: [4788.6521  323.302   323.3166  323.2347  323.3078  323.2946  323.2509
  323.4151  319.7926  323.3168]
Gradient evals per effective sample: [0.1956 2.8968 2.8967 2.8974 2.8967 2.8969 2.8973 2.8958 2.9286 2.8967]
W2 distance: 540404996204079.0
NUTS


sample: 100%|██████████| 320/320 [00:32<00:00,  9.76it/s]


Gradient evals per effective sample (NUTS): 8.668121337890625
Effective sample size: [    nan     nan 65.2196     nan     nan     nan     nan     nan     nan
     nan]
Gradient evals per effective sample: [   nan    nan 17.012    nan    nan    nan    nan    nan    nan    nan]
W2 distance: 314604860600664.06
LMC


sample: 100%|██████████| 16640/16640 [00:36<00:00, 450.75it/s, 1 steps of size 1.29e-06. acc. prob=0.67]  


LMC
LMC: Steps warmup: 62.484375, steps mcmc: 1994.8359375, gradient evaluations per output: 16.07281494140625
Effective sample size: [573.7187 323.3002 323.2995 323.3002 323.3014 323.3013 323.3005 323.3007
 323.3003]
Gradient evals per effective sample: [3.5859 6.3635 6.3635 6.3635 6.3635 6.3635 6.3635 6.3635 6.3635]
W2 distance: 1200989517942678.5
NUTS


sample: 100%|██████████| 320/320 [01:00<00:00,  5.26it/s]


Gradient evals per effective sample (NUTS): 7.62481689453125
Effective sample size: [     nan 196.1794 196.1808 196.1828 196.1723 196.1621 196.1643 196.1705
 196.1765]
Gradient evals per effective sample: [   nan 4.9749 4.9749 4.9748 4.9751 4.9754 4.9753 4.9751 4.975 ]
W2 distance: 4459274847206856.5
LMC


sample: 100%|██████████| 16640/16640 [03:49<00:00, 72.54it/s, 63 steps of size 7.53e-02. acc. prob=0.94]


LMC
LMC: Steps warmup: 155.0390625, steps mcmc: 3010.21875, gradient evaluations per output: 24.72857666015625
Effective sample size: [ 698.3806 1616.1114 1647.3904 2554.9489 1671.9595 1785.5674 2131.7512
 2338.3801 2191.8755  827.9338]
Gradient evals per effective sample: [4.5323 1.9586 1.9214 1.2389 1.8931 1.7727 1.4848 1.3536 1.4441 3.8231]
W2 distance: 0.36185241246931377
NUTS


sample: 100%|██████████| 320/320 [00:16<00:00, 19.75it/s]


Gradient evals per effective sample (NUTS): 53.47589111328125
Effective sample size: [13450.8355 10240.8361 18660.7307 23661.4046 22750.6907 20257.2226
 22369.4588 22118.4988 28380.4042 26857.4759]
Gradient evals per effective sample: [0.5089 0.6684 0.3668 0.2893 0.3009 0.3379 0.306  0.3095 0.2412 0.2549]
W2 distance: 0.0263435363146163
LMC


sample: 100%|██████████| 16640/16640 [00:51<00:00, 322.76it/s, 1 steps of size 1.18e-06. acc. prob=0.33]   


LMC
LMC: Steps warmup: 85.265625, steps mcmc: 2081.6015625, gradient evaluations per output: 16.92864990234375
Effective sample size: [5710.5647  323.3007  323.3041  323.2999  323.3076  323.2988  323.3017
  323.3027  323.2986  323.052   323.2998  323.305   323.3088  323.3003
  323.3063  323.3049  323.3002  323.3024  323.309   323.2983  323.2973]
Gradient evals per effective sample: [0.3794 6.7023 6.7023 6.7023 6.7022 6.7024 6.7023 6.7023 6.7024 6.7075
 6.7023 6.7022 6.7022 6.7023 6.7022 6.7022 6.7023 6.7023 6.7022 6.7024
 6.7024]
W2 distance: 3.393477821155428e+24
NUTS


sample: 100%|██████████| 320/320 [01:29<00:00,  3.59it/s]


Gradient evals per effective sample (NUTS): 9.33795166015625
Effective sample size: [ 65.464  206.1236      nan 206.101  206.0576      nan 206.1223      nan
 206.0072      nan 206.0749 205.7906 197.6787 206.0995 206.2341      nan
      nan 206.0888      nan      nan 190.1711]
Gradient evals per effective sample: [18.2582  5.7987     nan  5.7994  5.8006     nan  5.7988     nan  5.802
     nan  5.8001  5.8081  6.0465  5.7994  5.7956     nan     nan  5.7997
     nan     nan  6.2852]
W2 distance: 3.3186743708373977e+24
LMC


sample: 100%|██████████| 16640/16640 [00:34<00:00, 477.32it/s, 7 steps of size 5.25e-14. acc. prob=0.66] 


LMC
LMC: Steps warmup: 42.1640625, steps mcmc: 1483.6875, gradient evaluations per output: 11.92071533203125
Effective sample size: [1414.7265  323.3028  323.2343  323.4078  323.3107  323.3177  323.2816
  323.1749  323.2894  323.1315  323.3381  323.2948  323.3067  323.2774]
Gradient evals per effective sample: [1.0785 4.7196 4.7206 4.718  4.7195 4.7194 4.7199 4.7214 4.7198 4.7221
 4.7191 4.7197 4.7195 4.7199]
W2 distance: 8.482035947069806e+28
NUTS


sample: 100%|██████████| 320/320 [00:32<00:00,  9.81it/s]


Gradient evals per effective sample (NUTS): 8.098663330078125
Effective sample size: [     nan      nan 657.0863 856.8401      nan      nan 901.309  861.8158
      nan 910.4092      nan      nan      nan      nan]
Gradient evals per effective sample: [   nan    nan 1.5776 1.2098    nan    nan 1.1501 1.2028    nan 1.1386
    nan    nan    nan    nan]
W2 distance: 8.701650165076928e+28
LMC


sample: 100%|██████████| 16640/16640 [22:21<00:00, 12.41it/s, 255 steps of size 1.23e-02. acc. prob=0.93]


LMC
LMC: Steps warmup: 615.046875, steps mcmc: 11311.875, gradient evaluations per output: 93.1790771484375
Effective sample size: [  723.6322  3410.5624 33827.7544 33948.3572 13908.3542 17067.3401
 32070.6369 23602.6354 39665.0891   680.5408   741.2097   849.3059
   794.5553  1544.7576   792.7175  1413.0643  3209.898  19503.4566
 18116.6526]
Gradient evals per effective sample: [16.482   3.4971  0.3526  0.3513  0.8575  0.6988  0.3719  0.5053  0.3007
 17.5257 16.0912 14.0431 15.0108  7.7209 15.0456  8.4405  3.7157  0.6115
  0.6583]
W2 distance: 1.3025857361000306
NUTS


sample: 100%|██████████| 320/320 [03:44<00:00,  1.43it/s]


Gradient evals per effective sample (NUTS): 402.3988037109375
Effective sample size: [17560.1831 38723.1214 38069.9341 37662.2673 37388.7044 36407.0894
 38416.1394 36735.6245 37455.7208 17507.6692 17810.295  18170.3777
 16267.176  11776.9843 11799.6105 11661.0834 19650.3976 38315.9265
 30417.2787]
Gradient evals per effective sample: [2.9332 1.3301 1.353  1.3676 1.3776 1.4148 1.3408 1.4021 1.3751 2.942
 2.892  2.8347 3.1663 4.3735 4.3651 4.417  2.6212 1.3443 1.6933]
W2 distance: 0.028828142973814908
LMC


sample: 100%|██████████| 16640/16640 [00:38<00:00, 433.82it/s, 7 steps of size 4.33e-01. acc. prob=0.88]


LMC
LMC: Steps warmup: 405.2578125, steps mcmc: 4325.484375, gradient evaluations per output: 36.95892333984375
Effective sample size: [  3607.6973   2599.8099  39796.7409  30625.6768  99429.4684  50584.0341
  28101.4976  55753.1025  44419.7559  51062.2544  44552.4141 108917.2809
  38147.6377  21997.2671  27875.6335  45483.1593  23317.8221  41734.3787
  40939.5226  42408.7979  57361.5279]
Gradient evals per effective sample: [1.3113 1.8196 0.1189 0.1545 0.0476 0.0935 0.1683 0.0849 0.1065 0.0926
 0.1062 0.0434 0.124  0.2151 0.1697 0.104  0.2029 0.1134 0.1156 0.1116
 0.0825]
W2 distance: 0.08947221389612854
NUTS


sample: 100%|██████████| 320/320 [01:02<00:00,  5.10it/s]


Gradient evals per effective sample (NUTS): 44.67138671875
Effective sample size: [ 3720.4145 34619.5969 32705.1965 36623.73   34188.8458 34810.0174
 35329.5534 36547.4325 33986.4087 34763.0171 35004.6241 31553.9146
 34387.0663 35957.3309 36035.1444 32403.422  36520.8932 35625.0128
 33753.406  35298.171  36067.9402]
Gradient evals per effective sample: [1.5369 0.1652 0.1748 0.1561 0.1672 0.1643 0.1618 0.1565 0.1682 0.1645
 0.1633 0.1812 0.1663 0.159  0.1587 0.1765 0.1566 0.1605 0.1694 0.162
 0.1585]
W2 distance: 0.001283718426240755
LMC


sample: 100%|██████████| 16640/16640 [00:57<00:00, 287.42it/s, 1 steps of size 1.67e-04. acc. prob=0.99]   


LMC
LMC: Steps warmup: 30.0, steps mcmc: 798.84375, gradient evaluations per output: 6.475341796875
Effective sample size: [1782.3168  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3002  323.3003  323.3003  323.3003
  323.3003  323.3003  323.3003  323.3003  323.3003]
Gradient evals per effective sample: [0.465  2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637
 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637
 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637 2.5637
 2.5637

sample: 100%|██████████| 320/320 [01:18<00:00,  4.09it/s]


Gradient evals per effective sample (NUTS): 5.389129638671875
Effective sample size: [65.6532     nan     nan 76.0958     nan     nan     nan     nan     nan
     nan 76.0962     nan 76.0963     nan 76.0963     nan     nan     nan
     nan 76.0975 76.0958     nan     nan     nan     nan 76.0959     nan
 76.096  76.0956     nan 76.0951     nan     nan 76.0963     nan     nan
 76.0955     nan     nan     nan     nan 76.0954     nan     nan 76.0965
     nan 76.0959 76.0965 76.0958 76.0957 76.0966     nan     nan     nan
     nan 76.096      nan     nan     nan 76.0961     nan]
Gradient evals per effective sample: [10.5069     nan     nan  9.065      nan     nan     nan     nan     nan
     nan  9.065      nan  9.0649     nan  9.0649     nan     nan     nan
     nan  9.0648  9.065      nan     nan     nan     nan  9.065      nan
  9.065   9.065      nan  9.0651     nan     nan  9.0649     nan     nan
  9.065      nan     nan     nan     nan  9.065      nan     nan  9.0649
     nan  9.065  

sample: 100%|██████████| 16640/16640 [00:43<00:00, 378.96it/s, 3 steps of size 4.57e-07. acc. prob=0.89]  


LMC
LMC: Steps warmup: 21.0, steps mcmc: 777.0, gradient evaluations per output: 6.234375
Effective sample size: [1779.974   323.302   323.3012  323.2986  323.3025  323.3011]
Gradient evals per effective sample: [0.4483 2.4683 2.4683 2.4683 2.4683 2.4683]
W2 distance: 3.0464607324680264e+18
NUTS


sample: 100%|██████████| 320/320 [00:15<00:00, 20.57it/s]


Gradient evals per effective sample (NUTS): 3.94195556640625
Effective sample size: [     nan      nan      nan      nan 182.5514      nan]
Gradient evals per effective sample: [  nan   nan   nan   nan 2.764   nan]
W2 distance: 1.6840457900237903e+18
LMC


sample: 100%|██████████| 16640/16640 [00:40<00:00, 407.78it/s, 1 steps of size 2.69e-01. acc. prob=0.63] 


LMC
LMC: Steps warmup: 12.0, steps mcmc: 2089.0078125, gradient evaluations per output: 16.41412353515625
Effective sample size: [  381.9657  5799.2398 13774.3345  9335.5848]
Gradient evals per effective sample: [5.5005 0.3623 0.1525 0.2251]
W2 distance: 0.06714026281601781
NUTS


sample: 100%|██████████| 320/320 [00:10<00:00, 30.59it/s]


Gradient evals per effective sample (NUTS): 11.40740966796875
Effective sample size: [  326.841   3389.2428 14219.9902 14664.0662]
Gradient evals per effective sample: [4.4675 0.4308 0.1027 0.0996]
W2 distance: 20.766105157542803
LMC


sample: 100%|██████████| 16640/16640 [1:23:20<00:00,  3.33it/s, 1023 steps of size 4.49e-16. acc. prob=0.47]


LMC
LMC: Steps warmup: 40.59375, steps mcmc: 1553.203125, gradient evaluations per output: 12.4515380859375
Effective sample size: [760.1273 323.3    323.3001 323.3    323.3    323.3    323.3    323.3
 323.3    323.3001 323.3    323.2999 323.2999 323.2999 323.3001 323.3
 323.3    323.3    323.3    323.3    323.3   ]
Gradient evals per effective sample: [2.0967 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298
 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298 4.9298
 4.9298]
W2 distance: 1.4889128894639722e+24
NUTS


sample: 100%|██████████| 320/320 [00:43<00:00,  7.28it/s]


Gradient evals per effective sample (NUTS): 3.792510986328125
Effective sample size: [     nan      nan      nan 129.4707      nan      nan      nan      nan
      nan 129.4675      nan 129.4731      nan      nan 129.4689      nan
      nan 129.4727      nan      nan 129.4715]
Gradient evals per effective sample: [   nan    nan    nan 3.7494    nan    nan    nan    nan    nan 3.7495
    nan 3.7494    nan    nan 3.7495    nan    nan 3.7494    nan    nan
 3.7494]
W2 distance: 1.4843556079383466e+24
LMC


sample: 100%|██████████| 16640/16640 [1:23:37<00:00,  3.32it/s, 1023 steps of size 2.15e-15. acc. prob=0.19]


LMC
LMC: Steps warmup: 30.0, steps mcmc: 781.7578125, gradient evaluations per output: 6.34185791015625
Effective sample size: [1392.787   323.3003  323.3005  323.2975  323.3502  323.3044  323.303
  323.2576  323.2774  323.302   323.3014  323.2743  323.3023  323.3026
  323.302   323.3028  323.3014  323.3064  323.3054  323.2991  323.2992
  323.3093]
Gradient evals per effective sample: [0.5828 2.5108 2.5108 2.5109 2.5105 2.5108 2.5108 2.5112 2.511  2.5108
 2.5108 2.511  2.5108 2.5108 2.5108 2.5108 2.5108 2.5108 2.5108 2.5109
 2.5109 2.5108]
W2 distance: 9.646766406379034e+21
NUTS


sample: 100%|██████████| 320/320 [00:47<00:00,  6.78it/s]


Gradient evals per effective sample (NUTS): 3.255615234375
Effective sample size: [    nan     nan     nan     nan     nan     nan     nan 64.1405     nan
 73.1303     nan     nan     nan 73.0969     nan     nan 73.1399     nan
 72.8051     nan     nan     nan]
Gradient evals per effective sample: [   nan    nan    nan    nan    nan    nan    nan 6.497     nan 5.6983
    nan    nan    nan 5.7009    nan    nan 5.6976    nan 5.7238    nan
    nan    nan]
W2 distance: 1.028961149034422e+22
LMC
