In [11]:
import math
from operator import mul


%env JAX_PLATFORM_NAME=cuda

import warnings
from functools import partial, reduce

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = jnp.squeeze(dset["t"])
    # labels are -1 and 1, convert to 0 and 1
    labels = (labels + 1) / 2
    n, data_dim = x.shape
    print(f"Data shape: {x.shape}")

    # randomly shuffle the data
    perm = jax.random.permutation(jr.PRNGKey(0), n)
    x = x[perm]
    labels = labels[perm]

    n_train = min(int(n * 0.8), 1000)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model(x, labels):
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        b = numpyro.sample("b", dist.Normal(jnp.zeros((1,)), 1.0 / alpha))
        logits = jnp.sum(W * x + b, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    return model, (x_train, labels_train, x_test, labels_test)


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**16):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    prod = reduce(mul, x.shape[1:], 1)
    if prod >= 4:
        max_len = int(max_len / math.sqrt(prod))

    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]
    alpha = jnp.expand_dims(alpha, alpha.ndim)
    b = dct["b"]
    return jnp.concatenate([alpha, b, dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


def flatten_samples(samples):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    # remove alpha
    samples = samples[..., 1:]
    return jnp.reshape(samples, (-1, samples.shape[-1]))


def predict(x, samples):
    sum = jnp.sum(samples[:, 2:] * x + samples[:, 1:2], axis=-1)
    # apply sigmoid
    return 1.0 / (1.0 + jnp.exp(-sum))


def test_accuracy(x_test, labels_test, samples):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    sample_dim = samples.shape[-1]
    samples = jnp.reshape(samples, (-1, sample_dim))
    if samples.shape[0] > 2**10:
        samples = samples[: 2**10]

    func = jax.jit(jax.vmap(lambda x: predict(x, samples), in_axes=0, out_axes=0))
    predictions = func(x_test)
    assert predictions.shape == (
        labels_test.shape[0],
        samples.shape[0],
    ), f"{predictions.shape} != {(labels_test.shape[0], samples.shape[0])}"

    labels_test = jnp.reshape(labels_test, (labels_test.shape[0], 1))
    is_correct = jnp.abs(predictions - labels_test) < 0.5
    accuracy_per_sample = jnp.mean(is_correct, axis=0)

    avg_accuracy = jnp.mean(accuracy_per_sample)

    len10 = int(0.1 * accuracy_per_sample.shape[0])
    best_sorted = jnp.sort(accuracy_per_sample)[len10:]
    accuracy_best90 = jnp.mean(best_sorted)
    return avg_accuracy, accuracy_best90


def eval_logreg(
    samples,
    evals_per_sample=None,
    ground_truth=None,
    num_iters_w2=0,
    x_test=None,
    labels_test=None,
):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)

    sample_dim = samples.shape[-1]
    reshaped_with_alpha = jnp.reshape(samples, (-1, sample_dim))
    vars = jnp.var(reshaped_with_alpha, axis=0)
    means = jnp.mean(reshaped_with_alpha, axis=0)
    result_str = f"means: {means},\nvars:  {vars}"

    samples_with_alpha = samples
    samples = samples[..., 1:]
    reshaped = jnp.reshape(samples, (-1, sample_dim - 1))

    ess = diagnostics.effective_sample_size(samples)
    avg_ess = 1 / jnp.mean(1 / jnp.stack(jtu.tree_leaves(ess)))
    ess_per_sample = avg_ess / reshaped.shape[0]
    result_str += (
        f"\nEffective sample size: {avg_ess:.4},"
        f" ess per sample: {ess_per_sample:.4}"
    )
    if evals_per_sample is not None:
        result_str += f", grad evals per sample: {evals_per_sample:.4}"

    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str += f"\nEnergy dist v self: {energy_self:.4}"

    if ground_truth is not None:
        ground_truth = ground_truth[..., 1:]
        energy_gt = energy_distance(reshaped, ground_truth)
        result_str += f", energy dist vs ground truth: {energy_gt:.4}"
    if num_iters_w2 > 0 and ground_truth is not None:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"

    if x_test is not None and labels_test is not None:
        acc_error, acc_best90 = test_accuracy(x_test, labels_test, samples_with_alpha)
        result_str += (
            f"\nTest_accuracy: {acc_error:.4}, top 90% accuracy: {acc_best90:.4}"
        )
    else:
        acc_error, acc_best90 = None, None

    print(result_str)

    result_dict = {
        "ess": avg_ess,
        "ess_per_sample": ess_per_sample,
        "energy_v_self": energy_self,
        "grad_evals_per_sample": evals_per_sample,
        "test_accuracy": acc_error,
        "top90_accuracy": acc_best90,
    }

    return result_str, result_dict


dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "flare_solar"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]
Data shape: (144, 9)


In [2]:
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# # thin the samples to 2**15
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
gt_logreg = gt_logreg[::4]
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")
flattened_gt = jnp.reshape(gt_logreg, (-1, 4))
print(flattened_gt.shape)
print(jnp.var(flattened_gt, axis=0))
print(jnp.mean(flattened_gt, axis=0))

FileNotFoundError: [Errno 2] No such file or directory: 'mcmc_data/banana_ground_truth.npy'

In [23]:
num_chains = 2**5
num_samples_per_chain = 2**11
warmup_len = 2**13

In [24]:
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(2),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.2,
    tol=0.01,
    warmup_mult=warmup_len,
    warmup_tol_mult=32,
    use_adaptive=False,
)
out_logreg_lmc["alpha"] = jnp.exp(out_logreg_lmc["alpha"])
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

100.00%|██████████| [00:27<00:00,  3.69%/s]
100.00%|██████████| [04:13<00:00,  2.53s/%]

{'W': (32, 2048, 9), 'alpha': (32, 2048), 'b': (32, 2048, 1)}





In [33]:
flat_lmc = flatten_samples(out_logreg_lmc)
print(flat_lmc.shape)
outlier_positions = jnp.abs(flat_lmc) > 400
outlier_positions = jnp.any(outlier_positions, axis=1)
outliers = flat_lmc[outlier_positions]
print(outliers)
print(outliers.shape)

(65536, 10)
[]
(0, 10)


In [26]:
_ = eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, x_test=x_test, labels_test=labels_test
)

means: [   0.013   44.606    0.313    0.27    -0.25     1.263 -208.002   12.058
    0.021    0.382  129.781],
vars:  [   0.      51.064    0.793    0.483    0.502    6.237  949.076  495.334
    0.033    0.608 1168.32 ]
Effective sample size: 61.44, ess per sample: 0.0009375, grad evals per sample: 45.06
Energy dist v self: 2.66
Test_accuracy: 0.6211, top 90% accuracy: 0.6296


In [27]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.warmup(
    jr.PRNGKey(2),
    x_train,
    labels_train,
    extra_fields=("num_steps",),
    collect_warmup=True,
)
warmup_steps = sum(nuts.get_extra_fields()["num_steps"])
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
print(geps_nuts)
print(jtu.tree_map(lambda x: x.shape, out_logreg_nuts))

warmup: 100%|██████████| 8192/8192 [00:56<00:00, 146.28it/s]
sample: 100%|██████████| 2048/2048 [00:29<00:00, 69.19it/s] 


63.75968933105469
{'W': (32, 2048, 9), 'alpha': (32, 2048), 'b': (32, 2048, 1)}


In [28]:
_ = eval_logreg(out_logreg_nuts, geps_nuts, x_test=x_test, labels_test=labels_test)

means: [22.333  0.05   0.018  0.025 -0.     0.045 -0.018  0.04   0.002  0.023
  0.012],
vars:  [237.374   0.002   0.008   0.007   0.007   0.012   0.018   0.015   0.003
   0.007   0.01 ]
Effective sample size: 1.826e+04, ess per sample: 0.2786, grad evals per sample: 63.76
Energy dist v self: 3.95e-05
Test_accuracy: 0.5516, top 90% accuracy: 0.5518


In [29]:
flat_nuts = flatten_samples(out_logreg_nuts)
enenrgy_dist = energy_distance(flat_nuts, flat_lmc)
print(enenrgy_dist)

446.56067682614116


In [3]:
import pickle


def run_logreg_dataset(name, results_filename=None, results_dict_filename=None):
    model_logreg, data_split = get_model_and_data(dataset, name)
    x_train, labels_train, x_test, labels_test = data_split

    num_chains = 2**5
    num_samples_per_chain = 2**11
    warmup_len = 2**13

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=warmup_len,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.warmup(
        jr.PRNGKey(2),
        x_train,
        labels_train,
        extra_fields=("num_steps",),
        collect_warmup=True,
    )
    warmup_steps = sum(nuts.get_extra_fields()["num_steps"])
    nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
    geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
    print("NUTS:")
    eval_nuts_str, eval_nuts_dict = eval_logreg(
        out_logreg_nuts,
        geps_nuts,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    lmc_tol = 0.01
    chain_sep = (0.4 * num_steps_nuts / num_chains) * (
        lmc_tol / (num_samples_per_chain + 4 + warmup_len / 32)
    )
    print(f"Target chain separation: {chain_sep}")
    if chain_sep < 0.1:
        chain_sep = 0.1

    out_logreg_lmc, geps_lmc = run_lmc_numpyro(
        jr.PRNGKey(3),
        model_logreg,
        (x_train, labels_train),
        num_chains,
        num_samples_per_chain,
        chain_sep=chain_sep,
        tol=lmc_tol,
        warmup_mult=warmup_len,
        warmup_tol_mult=32,
        use_adaptive=False,
    )
    out_logreg_lmc["alpha"] = jnp.exp(out_logreg_lmc["alpha"])

    eval_lmc_str, eval_lmc_dict = eval_logreg(
        out_logreg_lmc,
        geps_lmc,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    # Compute energy distance between the two methods
    lmc_flat = flatten_samples(out_logreg_lmc)
    nuts_flat = flatten_samples(out_logreg_nuts)

    energy_dist = energy_distance(lmc_flat, nuts_flat)
    print(f"Energy distance between LMC and NUTS: {energy_dist:.5}")

    if results_filename is not None:
        with open(results_filename, "a") as f:
            f.write(f"{name}\n")
            f.write(f"LMC: {eval_lmc_str}\n\n")
            f.write(f"NUTS: {eval_nuts_str}\n\n")
            f.write(f"Energy distance: {energy_dist:.5}\n\n\n")

    results_dict = {
        "dataset_name": name,
        "LMC": eval_lmc_dict,
        "NUTS": eval_nuts_dict,
        "Energy distance": energy_dist,
    }

    if results_dict_filename is not None:
        with open(results_dict_filename, "wb") as f:
            pickle.dump(results_dict, f)

In [4]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [None]:
# make a file for the results, which has date and time in the name
import datetime


time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_filename = f"mcmc_data/results_{time}.txt"
results_dict_filename = f"mcmc_data/results_dict_{time}.pkl"

# create the results file
with open(results_filename, "w") as f:
    f.write("Results\n\n")

# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name, results_filename, results_dict_filename)
    print()

Data shape: (5300, 2)


warmup: 100%|██████████| 8192/8192 [00:26<00:00, 306.29it/s]
sample: 100%|██████████| 2048/2048 [00:08<00:00, 234.85it/s]


NUTS:
means: [16.761 -0.101 -0.03  -0.047],
vars:  [162.443   0.001   0.002   0.003]
Effective sample size: 2.806e+04, ess per sample: 0.4282, grad evals per sample: 32.16
Energy dist v self: 9.073e-06
Test_accuracy: 0.5509, top 90% accuracy: 0.5549
Target chain separation: 0.07134521907495668


100.00%|██████████| [00:13<00:00,  7.40%/s]
100.00%|██████████| [02:06<00:00,  1.27s/%]


means: [16.46  -0.111 -0.036 -0.056],
vars:  [322.838   0.283   0.189   0.261]
Effective sample size: 2.761e+04, ess per sample: 0.4214, grad evals per sample: 22.53
Energy dist v self: 0.0001146
Test_accuracy: 0.546, top 90% accuracy: 0.5544
Energy distance between LMC and NUTS: 0.001113

Data shape: (263, 9)


warmup: 100%|██████████| 8192/8192 [00:24<00:00, 329.15it/s]
sample: 100%|██████████| 2048/2048 [00:08<00:00, 230.53it/s]


NUTS:
means: [ 4.596 -0.123 -0.097  0.107  0.051  0.295 -0.211  0.425 -0.11  -0.002
 -0.133],
vars:  [2.486 0.    0.024 0.027 0.02  0.026 0.022 0.031 0.019 0.02  0.017]
Effective sample size: 6.341e+04, ess per sample: 0.9675, grad evals per sample: 37.85
Energy dist v self: 4.937e-05
Test_accuracy: 0.6586, top 90% accuracy: 0.6632
Target chain separation: 0.0839616957322357


100.00%|██████████| [00:13<00:00,  7.27%/s]
100.00%|██████████| [02:08<00:00,  1.28s/%]


means: [ 4.471 -0.161 -0.118  0.232  0.049  0.408 -0.268  0.76  -0.186 -0.009
 -0.175],
vars:  [2.798 0.219 0.24  1.154 0.162 1.159 0.64  9.702 0.671 0.041 0.305]
Effective sample size: 3.951e+03, ess per sample: 0.06029, grad evals per sample: 22.53
Energy dist v self: 0.0001789
Test_accuracy: 0.6587, top 90% accuracy: 0.664
Energy distance between LMC and NUTS: 0.009713

Data shape: (768, 8)


warmup: 100%|██████████| 8192/8192 [00:22<00:00, 368.61it/s]
sample: 100%|██████████| 2048/2048 [00:07<00:00, 282.18it/s]


NUTS:
means: [ 2.083 -0.105  0.371  1.136 -0.298 -0.015 -0.142  0.714  0.229  0.159],
vars:  [0.264 0.    0.013 0.017 0.012 0.014 0.012 0.017 0.011 0.014]
Effective sample size: 8.386e+04, ess per sample: 1.28, grad evals per sample: 34.5
Energy dist v self: 4.048e-05
Test_accuracy: 0.778, top 90% accuracy: 0.7805
Target chain separation: 0.07654422118717505


100.00%|██████████| [00:14<00:00,  7.11%/s]
100.00%|██████████| [02:11<00:00,  1.32s/%]


means: [ 2.08  -0.105  0.372  1.137 -0.299 -0.016 -0.142  0.715  0.229  0.159],
vars:  [0.271 0.    0.013 0.02  0.013 0.014 0.014 0.021 0.012 0.014]
Effective sample size: 2.932e+04, ess per sample: 0.4473, grad evals per sample: 22.53
Energy dist v self: 5.438e-05
Test_accuracy: 0.7776, top 90% accuracy: 0.7803
Energy distance between LMC and NUTS: 7.637e-05

Data shape: (144, 9)


warmup: 100%|██████████| 8192/8192 [00:54<00:00, 149.49it/s]
sample: 100%|██████████| 2048/2048 [00:29<00:00, 69.25it/s] 


NUTS:
means: [22.333  0.05   0.018  0.025 -0.     0.045 -0.018  0.04   0.002  0.023
  0.012],
vars:  [237.374   0.002   0.008   0.007   0.007   0.012   0.018   0.015   0.003
   0.007   0.01 ]
Effective sample size: 1.826e+04, ess per sample: 0.2786, grad evals per sample: 63.76
Energy dist v self: 3.95e-05
Test_accuracy: 0.5516, top 90% accuracy: 0.5518
Target chain separation: 0.14144263837738302


100.00%|██████████| [00:20<00:00,  4.89%/s]
100.00%|██████████| [03:12<00:00,  1.93s/%]


means: [   0.014   41.193    0.337    0.302   -0.272    1.368 -194.248   15.019
    0.021    0.406  120.766],
vars:  [   0.      27.697    1.053    0.785    0.682   10.246  468.104  580.589
    0.031    0.828 1216.831]
Effective sample size: 71.62, ess per sample: 0.001093, grad evals per sample: 33.58
Energy dist v self: 1.474
Test_accuracy: 0.6226, top 90% accuracy: 0.6313
Energy distance between LMC and NUTS: 412.16

Data shape: (1000, 20)


warmup: 100%|██████████| 8192/8192 [00:30<00:00, 266.03it/s]
sample: 100%|██████████| 2048/2048 [00:10<00:00, 197.31it/s]


NUTS:
means: [ 4.255 -0.053 -0.653  0.239 -0.341 -0.087  0.234 -0.328 -0.147  0.18
 -0.217 -0.13   0.064  0.244 -0.128 -0.151 -0.097  0.08   0.002  0.052
 -0.117 -0.156],
vars:  [0.577 0.    0.009 0.009 0.009 0.006 0.011 0.008 0.008 0.008 0.007 0.008
 0.008 0.009 0.009 0.006 0.008 0.008 0.008 0.007 0.008 0.011]
Effective sample size: 7.301e+04, ess per sample: 1.114, grad evals per sample: 41.93
Energy dist v self: 8.904e-05
Test_accuracy: 0.7818, top 90% accuracy: 0.7855
Target chain separation: 0.09300696625866552


100.00%|██████████| [00:14<00:00,  6.97%/s]
100.00%|██████████| [02:13<00:00,  1.34s/%]


means: [ 4.261 -0.053 -0.652  0.238 -0.341 -0.087  0.234 -0.327 -0.146  0.181
 -0.217 -0.129  0.063  0.245 -0.128 -0.151 -0.097  0.08   0.002  0.051
 -0.117 -0.156],
vars:  [0.608 0.    0.009 0.01  0.009 0.007 0.011 0.009 0.008 0.008 0.007 0.008
 0.008 0.009 0.009 0.007 0.008 0.009 0.008 0.007 0.008 0.011]
Effective sample size: 3.373e+04, ess per sample: 0.5147, grad evals per sample: 22.53
Energy dist v self: 6.986e-05
Test_accuracy: 0.7824, top 90% accuracy: 0.786
Energy distance between LMC and NUTS: 0.00010642

Data shape: (270, 13)


warmup: 100%|██████████| 8192/8192 [00:23<00:00, 342.36it/s]
sample: 100%|██████████| 2048/2048 [00:08<00:00, 232.54it/s]


NUTS:
means: [ 2.073 -0.022 -0.111  0.635  0.719  0.279  0.213 -0.043  0.271 -0.145
  0.415  0.315  0.317  1.018  0.645],
vars:  [0.245 0.    0.046 0.06  0.045 0.038 0.04  0.04  0.039 0.057 0.042 0.054
 0.052 0.07  0.043]
Effective sample size: 6.987e+04, ess per sample: 1.066, grad evals per sample: 38.43
Energy dist v self: 0.0001023
Test_accuracy: 0.8035, top 90% accuracy: 0.8111
Target chain separation: 0.08524808410961872


100.00%|██████████| [00:13<00:00,  7.15%/s]
100.00%|██████████| [02:06<00:00,  1.27s/%]


means: [ 1.992 -0.03  -0.464  1.706  1.564  0.835  0.69  -0.099  0.676 -0.052
  0.921  0.412  0.781  2.723  1.201],
vars:  [ 0.415  0.099  3.768 34.265 25.743  9.176  6.517  0.268  5.936  1.838
  8.584  3.085  6.67  85.814 14.693]
Effective sample size: 968.9, ess per sample: 0.01478, grad evals per sample: 22.53
Energy dist v self: 0.01377
Test_accuracy: 0.802, top 90% accuracy: 0.81
Energy distance between LMC and NUTS: 0.094604

Data shape: (2086, 18)


warmup: 100%|██████████| 8192/8192 [03:32<00:00, 38.46it/s]
sample: 100%|██████████| 2048/2048 [01:00<00:00, 33.73it/s]


NUTS:
means: [ 1.244  0.014 -0.435  0.503  0.05   0.01  -0.373  0.143  0.681  0.699
  0.684  0.886  0.583  0.61   1.246 -0.145 -1.016  0.106  1.631 -0.578],
vars:  [0.065 0.    0.008 0.014 0.007 0.007 0.018 0.026 0.026 0.177 0.594 0.562
 0.626 0.579 0.186 0.453 0.187 0.638 0.032 0.046]
Effective sample size: 4.92e+04, ess per sample: 0.7507, grad evals per sample: 346.4
Energy dist v self: 0.0006878
Test_accuracy: 0.8241, top 90% accuracy: 0.8254
Target chain separation: 0.7685093289644714


100.00%|██████████| [02:01<00:00,  1.22s/%]
100.00%|██████████| [17:57<00:00, 10.78s/%]


means: [ 1.25   0.014 -0.435  0.503  0.05   0.01  -0.372  0.144  0.679  0.696
  0.684  0.885  0.575  0.61   1.247 -0.14  -1.01   0.111  1.627 -0.578],
vars:  [0.067 0.    0.008 0.014 0.007 0.007 0.019 0.026 0.026 0.175 0.589 0.552
 0.611 0.558 0.183 0.447 0.184 0.619 0.032 0.046]
Effective sample size: 4.039e+04, ess per sample: 0.6164, grad evals per sample: 173.4
Energy dist v self: 0.001093
Test_accuracy: 0.8237, top 90% accuracy: 0.8249
Energy distance between LMC and NUTS: 0.0008031

Data shape: (7400, 20)


warmup: 100%|██████████| 8192/8192 [00:33<00:00, 245.87it/s]
sample:  66%|██████▌   | 1355/2048 [00:05<00:01, 596.20it/s]