In [1]:
%env JAX_PLATFORM_NAME=cuda

import warnings
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


# Suppress FutureWarning messages for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)

# Print arrays with 3 decimals and without scientific notation.
jnp.set_printoptions(precision=3, suppress=True)
# Turn on the 64-bit mode of JAX (flag "jax_enable_x64").
jax.config.update("jax_enable_x64", True)
# Show the devices of the "cuda" backend.
print(jax.devices("cuda"))


def get_model_and_data(data, name):
    """Build the logistic-regression model and a train/test split for one dataset.

    Args:
        data: mapping from dataset name to a record that is indexed with
            ``[0, 0]`` and then exposes the fields ``"x"`` (features) and
            ``"t"`` (labels).
        name: key of the dataset inside ``data``.

    Returns:
        A pair ``(model, (x_train, labels_train, x_test, labels_test))``.
        ``model(x, labels)`` is a NumPyro model with the sample sites
        ``"alpha"``, ``"W"``, ``"b"`` and the observed site ``"obs"``.
    """
    entry = data[name][0, 0]
    features = entry["x"]
    # labels are -1 and 1, convert to 0 and 1
    targets = (jnp.squeeze(entry["t"]) + 1) / 2
    num_points, data_dim = features.shape
    print(f"Data shape: {features.shape}")

    # shuffle with a fixed key so that the split is the same on every call
    order = jr.permutation(jr.PRNGKey(0), num_points)
    features = features[order]
    targets = targets[order]

    # 80% of the points go to the training set, capped at 500 points
    n_train = min(int(num_points * 0.8), 500)
    split = (
        features[:n_train],
        targets[:n_train],
        features[n_train:],
        targets[n_train:],
    )

    def model(x, labels):
        inv_scale = numpyro.sample("alpha", dist.Exponential(0.01))
        prior_scale = 1.0 / inv_scale
        weights = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), prior_scale))
        bias = numpyro.sample("b", dist.Normal(jnp.zeros((1,)), prior_scale))
        # NOTE(review): the bias is broadcast onto every feature column before
        # the sum, so it contributes data_dim * b to the logit. `predict` uses
        # the same expression; change both together if this is unintended.
        logits = jnp.sum(weights * x + bias, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    return model, split


def compute_w2(x1, x2, num_iters):
    """Optimal-transport cost between two uniformly weighted sample sets.

    Args:
        x1: source samples, one sample per row.
        x2: target samples, one sample per row.
        num_iters: forwarded to ``ot.emd2`` as ``numItermax``.

    Returns:
        The value returned by ``ot.emd2`` for the cost matrix ``ot.dist(x1, x2)``.
        NOTE(review): ``ot.dist`` is called with its default metric, which the
        POT docs list as squared Euclidean, so this would be the squared
        2-Wasserstein distance - confirm before reporting it as W2.
    """
    src = np.array(x1)
    tgt = np.array(x2)
    n_src, n_tgt = src.shape[0], tgt.shape[0]
    # every sample carries the same mass
    src_weights = np.ones(n_src) / n_src
    tgt_weights = np.ones(n_tgt) / n_tgt
    cost_matrix = ot.dist(src, tgt)
    return ot.emd2(src_weights, tgt_weights, cost_matrix, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    """Energy distance ``2 E|X - Y| - E|X - X'| - E|Y - Y'|`` between two sample sets.

    The first axis of ``x`` and ``y`` indexes samples; any further axes are
    treated as one vector whose Euclidean norm gives the pairwise distance.
    Each expectation is a plain mean over all pairs, including the
    zero-distance pairs of a sample with itself.

    Args:
        x: first sample set.
        y: second sample set; same rank and trailing shape as ``x``.
        max_len: only the first ``max_len`` samples of each input are used.

    Returns:
        A scalar array with the energy distance.
    """
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    x = x[:max_len] if x.shape[0] > max_len else x
    y = y[:max_len] if y.shape[0] > max_len else y
    has_feature_axes = x.ndim > 1

    def _mean_dist_to_point(points, anchor):
        # mean distance from every row of `points` to the single sample `anchor`
        assert points.ndim == anchor.ndim + 1, f"{points.ndim} != {anchor.ndim + 1}"
        delta = points - anchor
        if has_feature_axes:
            # take the norm over all axes except the first one
            feature_axes = tuple(range(1, delta.ndim))
            delta = jnp.sqrt(jnp.sum(delta**2, axis=feature_axes))
        return jnp.mean(jnp.abs(delta))

    per_anchor = jax.vmap(_mean_dist_to_point, in_axes=(None, 0))

    def mean_pairwise(a, b):
        assert a.ndim == b.ndim
        return jnp.mean(per_anchor(a, b))

    cross_term = mean_pairwise(x, y)
    return 2 * cross_term - mean_pairwise(x, x) - mean_pairwise(y, y)


def dict_to_array(dct: dict):
    """Concatenate one sample dictionary into an array ordered as [alpha, b, W].

    ``dct["alpha"]`` gets a new trailing axis so that it can be joined with
    ``dct["b"]`` and ``dct["W"]`` along the last axis.
    """
    alpha_column = jnp.expand_dims(dct["alpha"], -1)
    pieces = [alpha_column, dct["b"], dct["W"]]
    return jnp.concatenate(pieces, axis=-1)


# Jitted version of dict_to_array that is mapped over the leading axis of every entry.
vec_dict_to_array = jax.jit(jax.vmap(dict_to_array))


def predict(x, samples):
    """Predicted probability for one input ``x`` under every row of ``samples``.

    Column 0 of ``samples`` is ignored, column 1 is used as the bias and the
    columns from 2 onward as the weights. The bias is added to every feature
    term before the sum, so it enters the logit once per feature.

    Returns:
        One sigmoid output per row of ``samples``.
    """
    weights = samples[:, 2:]
    bias = samples[:, 1:2]
    logits = jnp.sum(weights * x + bias, axis=-1)
    # logistic sigmoid
    return 1.0 / (1.0 + jnp.exp(-logits))


def test_accuracy(x_test, labels_test, samples):
    """Classification accuracy of posterior samples on a test set.

    Args:
        x_test: test inputs, one row per point.
        labels_test: test labels; a prediction counts as correct when it is
            within 0.5 of the label.
        samples: sample array whose last axis is the parameter vector, or a
            dict that is first converted with ``vec_dict_to_array``. Only the
            first ``2**10`` flattened samples are used.

    Returns:
        ``(avg_accuracy, accuracy_best90)``: the mean per-sample accuracy, and
        the mean after dropping the 10% of samples with the lowest accuracy.
    """
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    # merge all leading axes into a single sample axis
    flat = jnp.reshape(samples, (-1, samples.shape[-1]))
    max_samples = 2**10
    if flat.shape[0] > max_samples:
        flat = flat[:max_samples]
    n_test, n_samples = labels_test.shape[0], flat.shape[0]

    batched_predict = jax.jit(jax.vmap(lambda point: predict(point, flat)))
    predictions = batched_predict(x_test)
    expected_shape = (n_test, n_samples)
    assert predictions.shape == expected_shape, f"{predictions.shape} != {expected_shape}"

    label_column = jnp.reshape(labels_test, (n_test, 1))
    is_correct = jnp.abs(predictions - label_column) < 0.5
    accuracy_per_sample = jnp.mean(is_correct, axis=0)
    avg_accuracy = jnp.mean(accuracy_per_sample)

    # discard the worst tenth of the samples before averaging again
    n_dropped = int(0.1 * n_samples)
    accuracy_best90 = jnp.mean(jnp.sort(accuracy_per_sample)[n_dropped:])
    return avg_accuracy, accuracy_best90


def eval_logreg(
    samples,
    evals_per_sample=None,
    ground_truth=None,
    num_iters_w2=0,
    x_test=None,
    labels_test=None,
):
    """Print and return a text summary of a set of posterior samples.

    Args:
        samples: sample array whose last axis is the parameter vector (column 0
            is alpha), or a dict that is converted with ``vec_dict_to_array``.
        evals_per_sample: optional number appended to the effective-sample-size
            line as "grad evals per sample".
        ground_truth: optional reference samples; its first column is dropped
            before comparing.
        num_iters_w2: when positive (and ``ground_truth`` is given) the result
            of ``compute_w2`` with this iteration limit is reported as well.
        x_test: optional test inputs for ``test_accuracy``.
        labels_test: optional test labels for ``test_accuracy``.

    Returns:
        The summary string that was printed.
    """
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)

    sample_dim = samples.shape[-1]
    flat_with_alpha = jnp.reshape(samples, (-1, sample_dim))
    means = jnp.mean(flat_with_alpha, axis=0)
    variances = jnp.var(flat_with_alpha, axis=0)
    pieces = [f"means: {means},\nvars:  {variances}"]

    # every statistic below ignores the alpha column; accuracy still needs it
    samples_with_alpha = samples
    params = samples[..., 1:]
    ess = diagnostics.effective_sample_size(params)
    # harmonic mean of the effective sample sizes
    avg_ess = 1 / jnp.mean(1 / jnp.stack(jtu.tree_leaves(ess)))
    pieces.append(f"\nEffective sample size: {avg_ess}")
    if evals_per_sample is not None:
        pieces.append(f", grad evals per sample: {evals_per_sample:.4}")

    flat_params = jnp.reshape(params, (-1, sample_dim - 1))

    # energy distance between the two halves of the flattened samples
    midpoint = flat_params.shape[0] // 2
    energy_self = energy_distance(flat_params[:midpoint], flat_params[midpoint:])
    pieces.append(f"\nEnergy dist v self: {energy_self:.4}")

    have_gt = ground_truth is not None
    if have_gt:
        ground_truth = ground_truth[..., 1:]
        energy_gt = energy_distance(flat_params, ground_truth)
        pieces.append(f", energy dist vs ground truth: {energy_gt:.4}")
    if num_iters_w2 > 0 and have_gt:
        w2 = compute_w2(flat_params, ground_truth, num_iters_w2)
        pieces.append(f", Wasserstein-2: {w2:.4}")

    if x_test is not None and labels_test is not None:
        acc_error, acc_best90 = test_accuracy(x_test, labels_test, samples_with_alpha)
        pieces.append(
            f"\nTest_accuracy: {acc_error:.4}, top 90% accuracy: {acc_best90:.4}"
        )

    result_str = "".join(pieces)
    print(result_str)

    return result_str


# Load the benchmark collection and build the model and data split for one dataset.
dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "banana"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]
Data shape: (5300, 2)


In [2]:
# Location of the stored reference samples for the selected dataset.
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# Kept for reference: how the ground-truth file was generated with NUTS.
# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# # thin the samples to 2**15
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
# keep every 4th stored sample
gt_logreg = gt_logreg[::4]
# energy distance between the two halves of the reference samples, used as a
# baseline for the distances reported later
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")

Energy bias: 0.251009472403517
Ground truth shape: (16384, 4)
test accuracy: (Array(0.532, dtype=float32), Array(0.543, dtype=float32))


In [3]:
# Per-column variance and mean of the reference samples (4 columns per sample).
flattened_gt = jnp.reshape(gt_logreg, (-1, 4))
print(flattened_gt.shape)
print(jnp.var(flattened_gt, axis=0))
print(jnp.mean(flattened_gt, axis=0))

(16384, 4)
[3464.127    0.002    0.002    0.002]
[57.174 -0.045 -0.008 -0.008]


In [4]:
# Settings shared by the LMC and NUTS runs in the following cells.
num_chains = 2**5  # 32 chains
num_samples_per_chain = 2**10  # 1024 samples per chain
warmup_len = 2**13  # 8192

In [5]:
# Sample with run_lmc_numpyro from the local `mcmc` module.
# The second return value is later passed to eval_logreg as `evals_per_sample`
# (presumably gradient evaluations per output sample - confirm in mcmc.py).
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(3),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.3,
    tol=0.05,
    warmup_mult=warmup_len,
    warmup_tol_mult=16,
    use_adaptive=False,
)
# Show the shape of every returned sample array.
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

100.00%|██████████| [00:16<00:00,  6.11%/s]
100.00%|██████████| [00:41<00:00,  2.40%/s]


LMC: gradient evaluations per output: 20.04
{'W': (32, 1024, 2), 'alpha': (32, 1024), 'b': (32, 1024, 1)}


In [6]:
# Summarize the LMC samples against the reference samples and the test split.
_ = eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, gt_logreg, x_test=x_test, labels_test=labels_test
)

means: [-34.706  -0.124  -0.042  -0.043],
vars:  [2280.214    0.771    1.778    1.917]
Effective sample size: 24883.356050420618
Energy dist vs self: 0.003162, energy dist vs ground truth: 0.05075, test_accuracy: 0.5422, top 90% accuracy: 0.5526, Gradient evals per sample: 20.04


In [8]:
# Run NUTS with all chains vectorized, counting the steps taken in both
# warmup and sampling so the cost per kept sample can be reported.
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.warmup(
    jr.PRNGKey(2),
    x_train,
    labels_train,
    extra_fields=("num_steps",),
    collect_warmup=True,
)
# jnp.sum reduces the whole array in one call; the builtin sum would walk the
# array element by element in Python and only yields a scalar for 1-D input.
warmup_steps = jnp.sum(nuts.get_extra_fields()["num_steps"])
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = jnp.sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
# total steps (warmup + sampling) divided by the number of kept samples
geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
print(geps_nuts)

warmup: 100%|██████████| 1024/1024 [00:11<00:00, 88.90it/s] 
sample: 100%|██████████| 1024/1024 [00:06<00:00, 159.68it/s]


21.57537841796875


In [9]:
# Show the shape of every NUTS sample array (fetched with group_by_chain=True).
print(jtu.tree_map(lambda x: x.shape, out_logreg_nuts))

{'W': (32, 1024, 2), 'alpha': (32, 1024), 'b': (32, 1024, 1)}


In [12]:
# Summarize the NUTS samples against the reference samples and the test split.
_ = eval_logreg(
    out_logreg_nuts, geps_nuts, gt_logreg, x_test=x_test, labels_test=labels_test
)

means: [58.19  -0.045 -0.008 -0.008],
vars:  [3766.227    0.002    0.002    0.002]
Effective sample size: 5715.352966006341
Energy dist vs self: 4.118e-05, energy dist vs ground truth: 7.018e-06, test_accuracy: 0.5263, top 90% accuracy: 0.537


In [2]:
def run_logreg_dataset(name, results_filename=None):
    """Run LMC and NUTS on one benchmark dataset and report both summaries.

    The dataset is taken from the module-level ``dataset`` mapping. Both
    samplers use 2**5 chains with 2**10 samples per chain; each result is
    summarized with ``eval_logreg`` (without ground truth).

    Args:
        name: key of the dataset inside ``dataset``.
        results_filename: optional path; when given, the dataset name and both
            summaries are appended to that file.
    """
    model_logreg, data_split = get_model_and_data(dataset, name)
    x_train, labels_train, x_test, labels_test = data_split

    # Compute ground truth
    # ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    # if os.path.isfile(ground_truth_filename):
    #     ground_truth = np.load(ground_truth_filename)
    # else:
    #     gt_nuts = MCMC(
    #         NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**18
    #     )
    #     gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
    #     ground_truth = vec_dict_to_array(gt_nuts.get_samples())
    #     np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    # # thin the samples to 2**16
    # factor = int(ground_truth.shape[0] // 2**16)
    # ground_truth = ground_truth[::factor]
    # assert ground_truth.shape == (
    #     2**16,
    #     2 + x_train.shape[1],
    # ), f"ground_truth.shape: {ground_truth.shape}"
    #
    # size_gt_half = int(ground_truth.shape[0] // 2)
    # gt_energy_bias = energy_distance(
    #     ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    # )
    # print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**5
    num_samples_per_chain = 2**10
    warmup_len = 2**12

    # LMC run; the second return value is reported as "grad evals per sample"
    out_logreg_lmc, geps_lmc = run_lmc_numpyro(
        jr.PRNGKey(3),
        model_logreg,
        (x_train, labels_train),
        num_chains,
        num_samples_per_chain,
        chain_sep=0.3,
        tol=0.05,
        warmup_mult=warmup_len,
        warmup_tol_mult=4,
        use_adaptive=False,
    )

    eval_lmc = eval_logreg(
        out_logreg_lmc,
        geps_lmc,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    # NUTS run with the same number of chains and samples
    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=warmup_len,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.warmup(
        jr.PRNGKey(2),
        x_train,
        labels_train,
        extra_fields=("num_steps",),
        collect_warmup=True,
    )
    # jnp.sum reduces the whole array in one call; the builtin sum would walk
    # the array element by element in Python.
    warmup_steps = jnp.sum(nuts.get_extra_fields()["num_steps"])
    nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = jnp.sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
    # total steps (warmup + sampling) divided by the number of kept samples
    geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
    print("NUTS:")
    eval_nuts = eval_logreg(
        out_logreg_nuts,
        geps_nuts,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    if results_filename is not None:
        # append, so that one file collects the results of all datasets
        with open(results_filename, "a") as f:
            f.write(f"{name}\n")
            f.write(f"LMC: {eval_lmc}\n")
            f.write(f"NUTS: {eval_nuts}\n\n")

In [3]:
# Dataset keys that the driver loop passes to run_logreg_dataset.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [4]:
# make a file for the results, which has date and time in the name
import datetime


# The timestamp in the file name gives every run its own results file.
results_filename = (
    f"mcmc_data/results_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.txt"
)

# create the results file
with open(results_filename, "w") as f:
    f.write("Results\n\n")

# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
# Run both samplers on every dataset; each call appends its summaries to the file.
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name, results_filename)
    print()

Data shape: (5300, 2)


100.00%|██████████| [00:30<00:00,  3.23%/s]
100.00%|██████████| [00:41<00:00,  2.41%/s]


LMC: gradient evaluations per output: 26.04
means: [-38.685  -0.127  -0.053  -0.037],
vars:  [2785.765    1.051    2.313    1.315]
Effective sample size: 25824.6440927419, grad evals per sample: 26.04
Energy dist v self: 0.0001288
Test_accuracy: 0.5437, top 90% accuracy: 0.5537


warmup: 100%|██████████| 4096/4096 [00:24<00:00, 168.78it/s]
sample: 100%|██████████| 1024/1024 [00:07<00:00, 129.09it/s]


NUTS:
means: [58.63  -0.045 -0.009 -0.008],
vars:  [4033.952    0.002    0.002    0.002]
Effective sample size: 5615.085475543447, grad evals per sample: 46.75
Energy dist v self: 1.552e-05
Test_accuracy: 0.5357, top 90% accuracy: 0.5466

Data shape: (263, 9)


100.00%|██████████| [00:31<00:00,  3.14%/s]
100.00%|██████████| [00:41<00:00,  2.40%/s]


LMC: gradient evaluations per output: 26.04
means: [-1.086 -0.925 -0.676  1.228  0.284  2.989 -1.551  4.056 -0.994 -0.028
 -1.09 ],
vars:  [ 1.108  1.886  1.5    1.254  1.618  6.507  6.94  10.686  1.36   0.487
  3.71 ]
Effective sample size: 19911.019149260137, grad evals per sample: 26.04
Energy dist v self: 0.0006025
Test_accuracy: 0.6356, top 90% accuracy: 0.6552


warmup: 100%|██████████| 4096/4096 [00:16<00:00, 251.28it/s]
sample: 100%|██████████| 1024/1024 [00:05<00:00, 173.57it/s]


NUTS:
means: [ 4.592 -0.124 -0.098  0.108  0.053  0.295 -0.211  0.425 -0.112 -0.001
 -0.133],
vars:  [2.504 0.    0.024 0.027 0.02  0.026 0.022 0.031 0.019 0.021 0.017]
Effective sample size: 31398.56769901632, grad evals per sample: 40.68
Energy dist v self: 6.009e-05
Test_accuracy: 0.6588, top 90% accuracy: 0.6635

Data shape: (768, 8)


100.00%|██████████| [00:32<00:00,  3.09%/s]
100.00%|██████████| [00:41<00:00,  2.38%/s]


LMC: gradient evaluations per output: 26.04
means: [-2.144 -1.239  4.427 14.44  -3.232 -1.058 -0.598  8.987  2.509  1.854],
vars:  [ 1.364  4.632  7.931 99.812 19.078  6.379 20.258 49.16   4.966 19.189]
Effective sample size: 20284.231899834205, grad evals per sample: 26.04
Energy dist v self: 0.00203
Test_accuracy: 0.7015, top 90% accuracy: 0.7249


warmup: 100%|██████████| 4096/4096 [00:14<00:00, 275.01it/s]
sample: 100%|██████████| 1024/1024 [00:05<00:00, 203.86it/s]


NUTS:
means: [ 2.102 -0.098  0.404  1.142 -0.25  -0.038 -0.102  0.7    0.171  0.106],
vars:  [0.274 0.    0.017 0.021 0.013 0.017 0.015 0.021 0.013 0.016]
Effective sample size: 43251.54723584915, grad evals per sample: 36.73
Energy dist v self: 7.334e-05
Test_accuracy: 0.7721, top 90% accuracy: 0.7742

Data shape: (144, 9)


100.00%|██████████| [00:32<00:00,  3.12%/s]
100.00%|██████████| [00:42<00:00,  2.36%/s]


LMC: gradient evaluations per output: 26.04
means: [  -4.79    66.186    0.834    0.743   -0.62     3.351 -310.052   18.519
    0.16     1.115  207.591],
vars:  [   0.061   51.825    1.675    1.302    0.998   10.993  940.152  705.102
    0.233    1.662 2462.391]
Effective sample size: 72.64679882644724, grad evals per sample: 26.04
Energy dist v self: 1.297
Test_accuracy: 0.6213, top 90% accuracy: 0.6292


warmup: 100%|██████████| 4096/4096 [00:32<00:00, 126.20it/s]
sample: 100%|██████████| 1024/1024 [00:11<00:00, 89.57it/s]


NUTS:
means: [21.593  0.05   0.018  0.024 -0.     0.046 -0.017  0.041  0.002  0.023
  0.012],
vars:  [191.924   0.002   0.008   0.006   0.007   0.012   0.016   0.015   0.003
   0.007   0.01 ]
Effective sample size: 10168.12940199541, grad evals per sample: 66.59
Energy dist v self: 0.0001643
Test_accuracy: 0.5515, top 90% accuracy: 0.5519

Data shape: (1000, 20)


100.00%|██████████| [00:32<00:00,  3.08%/s]
100.00%|██████████| [00:43<00:00,  2.29%/s]


LMC: gradient evaluations per output: 26.04
means: [ -2.686  -4.129 -41.821  19.505 -14.172 -10.053  14.341 -20.658 -12.894
  10.557 -17.654 -11.339   5.684  26.635  -2.316  -8.647 -12.04    3.658
  -1.503   1.718  -8.435  -4.901],
vars:  [  0.244  50.412 525.827 158.221 119.716  37.103  84.575 120.678  42.671
  24.141  79.601  23.476  22.559 113.338  15.823  21.419  27.051  34.649
  24.454   9.278  25.796  21.462]
Effective sample size: 10806.724448469908, grad evals per sample: 26.04
Energy dist v self: 0.014
Test_accuracy: 0.6896, top 90% accuracy: 0.7162


warmup: 100%|██████████| 4096/4096 [00:19<00:00, 209.48it/s]
sample: 100%|██████████| 1024/1024 [00:06<00:00, 154.77it/s]


NUTS:
means: [ 4.262 -0.052 -0.619  0.284 -0.211 -0.142  0.188 -0.292 -0.17   0.146
 -0.249 -0.164  0.085  0.351 -0.046 -0.112 -0.144  0.069 -0.018  0.023
 -0.117 -0.084],
vars:  [0.687 0.    0.014 0.014 0.013 0.01  0.015 0.012 0.012 0.012 0.011 0.011
 0.011 0.015 0.012 0.009 0.012 0.012 0.012 0.012 0.013 0.014]
Effective sample size: 37076.5409434621, grad evals per sample: 46.65
Energy dist v self: 8.993e-05
Test_accuracy: 0.7456, top 90% accuracy: 0.7485

Data shape: (270, 13)


100.00%|██████████| [00:32<00:00,  3.11%/s]
100.00%|██████████| [00:41<00:00,  2.40%/s]


LMC: gradient evaluations per output: 26.04
means: [-1.376 -0.14  -1.048  5.911  5.865  2.765  2.206 -0.244  2.472 -1.056
  3.589  2.423  2.85   9.186  5.086],
vars:  [ 0.397  0.47   3.846 12.453 14.796  3.256  2.724  0.936  4.031  7.907
  6.579  7.524  3.756 26.5   17.924]
Effective sample size: 11481.772492453989, grad evals per sample: 26.04
Energy dist v self: 0.002634
Test_accuracy: 0.7755, top 90% accuracy: 0.7904


warmup: 100%|██████████| 4096/4096 [00:16<00:00, 252.63it/s]
sample: 100%|██████████| 1024/1024 [00:05<00:00, 195.34it/s]


NUTS:
means: [ 2.073 -0.022 -0.11   0.635  0.719  0.277  0.212 -0.043  0.271 -0.143
  0.414  0.316  0.317  1.017  0.645],
vars:  [0.247 0.    0.047 0.06  0.045 0.038 0.04  0.04  0.039 0.057 0.042 0.054
 0.052 0.071 0.042]
Effective sample size: 36304.656066844735, grad evals per sample: 41.67
Energy dist v self: 0.0001574
Test_accuracy: 0.8018, top 90% accuracy: 0.8098

Data shape: (2086, 18)


100.00%|██████████| [00:32<00:00,  3.11%/s]
100.00%|██████████| [00:42<00:00,  2.35%/s]


LMC: gradient evaluations per output: 26.04
means: [ -3.557   2.543  -9.403  24.006  -3.228  -0.05  -11.55   -6.894  26.311
  29.348  44.994  52.155  33.247  39.87   49.798  -7.988 -35.621 -55.294
  60.947 -19.575],
vars:  [   0.06     2.86     4.479   16.51     0.951    0.673    4.185    7.456
   11.2     29.97   445.427  433.269  422.339  346.188  141.605  254.486
  116.795 1861.795   22.687   42.254]
Effective sample size: 104.13890329575774, grad evals per sample: 26.04
Energy dist v self: 2.916
Test_accuracy: 0.7824, top 90% accuracy: 0.7878


warmup: 100%|██████████| 4096/4096 [01:39<00:00, 41.37it/s]
sample: 100%|██████████| 1024/1024 [00:16<00:00, 62.59it/s]


NUTS:
means: [ 1.333  0.015 -0.246  0.48  -0.088  0.012 -0.319 -0.226  0.687  0.568
  0.629  0.824  0.513  0.57   1.239 -0.263 -0.794  0.215  1.38  -0.772],
vars:  [0.09  0.    0.013 0.027 0.014 0.011 0.035 0.137 0.049 0.283 0.533 0.505
 0.564 0.526 0.19  0.409 0.184 0.566 0.051 0.073]
Effective sample size: 32978.794853877844, grad evals per sample: 268.1
Energy dist v self: 0.0004903
Test_accuracy: 0.821, top 90% accuracy: 0.8234

Data shape: (7400, 20)


100.00%|██████████| [00:33<00:00,  3.00%/s]
100.00%|██████████| [00:43<00:00,  2.31%/s]


LMC: gradient evaluations per output: 26.04
means: [ -3.069  -0.446 -17.122 -15.326 -33.29  -13.957 -23.081 -13.571 -24.035
 -14.64  -28.931 -18.124 -31.797 -10.115 -17.534 -23.706 -29.174 -22.228
 -21.944 -19.437 -25.332 -20.262],
vars:  [0.029 8.697 3.877 3.363 7.437 3.49  3.526 2.651 5.296 1.94  5.698 3.379
 6.175 3.713 3.05  3.396 5.352 3.576 4.604 3.553 6.227 4.428]
Effective sample size: 3039.9328890670954, grad evals per sample: 26.04
Energy dist v self: 0.002125
Test_accuracy: 0.7178, top 90% accuracy: 0.7251


warmup: 100%|██████████| 4096/4096 [00:19<00:00, 205.17it/s]
sample: 100%|██████████| 1024/1024 [00:04<00:00, 240.78it/s]


NUTS:
means: [ 3.308 -0.004 -0.24  -0.226 -0.446 -0.21  -0.31  -0.187 -0.322 -0.182
 -0.395 -0.232 -0.41  -0.163 -0.237 -0.306 -0.39  -0.289 -0.292 -0.271
 -0.371 -0.285],
vars:  [0.383 0.    0.012 0.012 0.014 0.013 0.012 0.012 0.013 0.013 0.012 0.01
 0.013 0.012 0.012 0.013 0.013 0.013 0.011 0.012 0.012 0.013]
Effective sample size: 55065.06793305351, grad evals per sample: 44.53
Energy dist v self: 8.174e-05
Test_accuracy: 0.7528, top 90% accuracy: 0.7543

Data shape: (2991, 60)


100.00%|██████████| [00:33<00:00,  3.03%/s]
100.00%|██████████| [00:44<00:00,  2.25%/s]


LMC: gradient evaluations per output: 26.04
means: [  -3.993  -12.968    8.151   41.462  -37.299  -40.901    9.544   48.638
   24.348  -18.873   38.083   16.989   11.6      7.718   64.416    5.973
   33.586    3.756   16.251   81.785    0.301   63.685   14.852   27.781
   20.084   40.565   24.339   45.294    2.063  117.626 -236.195 -104.312
 -109.094  168.712 -105.085  -78.3    -55.425   66.396  -39.499   12.088
  -62.53    43.023    1.591   -4.417   -2.586  -25.059  -22.006   60.642
  -10.101    0.306  -33.31   -25.042   28.785   -4.73    -5.824   72.103
  -37.277    5.236   -3.83    -6.86   -29.276    3.688],
vars:  [   78.417    71.166    32.906   464.141   376.439   446.145    33.232
   630.187   173.329   120.475   424.659    90.608    57.242    31.004
  1065.8      64.148   359.856    81.556    81.687  1771.219    55.915
  1088.181   127.824   210.367   150.885   471.862   277.556   614.053
    13.078  3592.419 14637.719  2835.889  3143.596  7649.423  2892.39
  1666.049   828.73 

warmup: 100%|██████████| 4096/4096 [02:42<00:00, 25.22it/s]
sample: 100%|██████████| 1024/1024 [00:34<00:00, 29.76it/s]


NUTS:
means: [ 3.155 -0.081  0.004  0.207 -0.147 -0.188  0.028  0.234  0.085 -0.068
  0.223  0.096  0.034  0.027  0.295  0.088  0.229  0.099  0.081  0.41
  0.046  0.343  0.123  0.167  0.114  0.205  0.187  0.257 -0.016  0.56
 -1.259 -0.479 -0.595  0.903 -0.565 -0.427 -0.279  0.341 -0.196  0.057
 -0.224  0.192 -0.03  -0.015 -0.056 -0.067 -0.08   0.259 -0.029  0.004
 -0.133 -0.108  0.094 -0.014 -0.023  0.319 -0.17  -0.004 -0.011 -0.055
 -0.15   0.044],
vars:  [0.189 0.001 0.017 0.017 0.019 0.016 0.017 0.018 0.016 0.018 0.017 0.016
 0.016 0.016 0.02  0.017 0.018 0.017 0.018 0.019 0.016 0.017 0.018 0.018
 0.017 0.016 0.016 0.019 0.016 0.022 0.03  0.028 0.025 0.023 0.02  0.019
 0.019 0.017 0.018 0.016 0.021 0.017 0.016 0.017 0.017 0.018 0.018 0.018
 0.018 0.016 0.018 0.018 0.018 0.017 0.017 0.02  0.019 0.016 0.017 0.017
 0.016 0.018]
Effective sample size: 39403.62681640876, grad evals per sample: 436.0
Energy dist v self: 0.0001822
Test_accuracy: 0.8027, top 90% accuracy: 0.8047

Data shape

100.00%|██████████| [00:31<00:00,  3.14%/s]
100.00%|██████████| [00:42<00:00,  2.37%/s]


LMC: gradient evaluations per output: 26.04
means: [-1.688  0.167 -1.506  4.008  2.163 14.596  3.852],
vars:  [  0.357   0.15    1.411   7.387   1.733 158.914   8.371]
Effective sample size: 324.0979055522625, grad evals per sample: 26.04
Energy dist v self: 0.009433
Test_accuracy: 0.8398, top 90% accuracy: 0.843


warmup: 100%|██████████| 4096/4096 [00:30<00:00, 135.30it/s]
sample: 100%|██████████| 1024/1024 [00:07<00:00, 129.85it/s]


NUTS:
means: [ 0.235  0.041 -1.144  3.162  1.805 10.492  3.003],
vars:  [0.008 0.013 0.226 0.767 0.379 7.981 1.635]
Effective sample size: 9204.68229441554, grad evals per sample: 60.02
Energy dist v self: 0.001321
Test_accuracy: 0.8449, top 90% accuracy: 0.8485

Data shape: (24, 3)


100.00%|██████████| [00:32<00:00,  3.10%/s]
100.00%|██████████| [00:41<00:00,  2.40%/s]


LMC: gradient evaluations per output: 26.04
means: [-12.56   -0.069  -0.556   1.144  -0.451],
vars:  [789.835  11.374  58.521  23.746  37.226]
Effective sample size: 2018.134778299716, grad evals per sample: 26.04
Energy dist v self: 0.00486
Test_accuracy: 0.4695, top 90% accuracy: 0.4894


warmup: 100%|██████████| 4096/4096 [00:34<00:00, 119.68it/s]
sample: 100%|██████████| 1024/1024 [00:15<00:00, 67.66it/s]


NUTS:
means: [85.228  0.005 -0.006  0.019 -0.004],
vars:  [5555.305    0.002    0.004    0.004    0.004]
Effective sample size: 7466.438693134033, grad evals per sample: 53.89
Energy dist v self: 2.853e-05
Test_accuracy: 0.5182, top 90% accuracy: 0.5334

Data shape: (7400, 20)


100.00%|██████████| [00:33<00:00,  2.97%/s]
100.00%|██████████| [00:43<00:00,  2.28%/s]


LMC: gradient evaluations per output: 26.04
means: [-3.03  -0.151 35.745 16.692 27.576 37.237 16.476 17.835  9.554 28.122
 16.402 24.036 23.182 13.6   38.546 20.301 41.443 20.34  20.699 24.923
 18.506 19.982],
vars:  [  0.401   0.196 389.252 131.447 241.379 486.568 133.951 110.961 101.567
 266.482 154.662 203.677 199.302 104.039 488.689 177.997 589.072 139.646
 158.133 232.159 150.672 154.835]
Effective sample size: 162.69401997540487, grad evals per sample: 26.04
Energy dist v self: 0.1892
Test_accuracy: 0.9666, top 90% accuracy: 0.9679


warmup: 100%|██████████| 4096/4096 [00:27<00:00, 151.02it/s]
sample: 100%|██████████| 1024/1024 [00:06<00:00, 161.18it/s]


NUTS:
means: [0.893 0.017 1.351 0.975 1.232 1.18  0.91  0.751 0.892 1.202 1.063 1.079
 1.057 0.824 1.489 1.148 1.758 1.209 0.925 1.331 0.955 1.006],
vars:  [0.046 0.    0.224 0.152 0.244 0.274 0.145 0.139 0.178 0.205 0.151 0.171
 0.165 0.192 0.243 0.191 0.304 0.163 0.157 0.177 0.184 0.184]
Effective sample size: 14309.547377947856, grad evals per sample: 52.4
Energy dist v self: 0.0003405
Test_accuracy: 0.9678, top 90% accuracy: 0.9685

Data shape: (5000, 21)


100.00%|██████████| [00:33<00:00,  3.02%/s]
100.00%|██████████| [00:42<00:00,  2.33%/s]


LMC: gradient evaluations per output: 26.04
means: [ -2.752  -5.051   8.147   0.933  -3.148   1.613  13.652   2.375   2.702
  14.972  23.312  25.357  23.992  13.002  -7.265  -4.148 -26.316 -28.2
 -26.912 -19.039  -7.302 -12.733   7.515],
vars:  [0.028 2.81  1.913 1.185 1.277 1.731 3.053 3.194 5.26  4.347 7.129 5.677
 5.743 4.239 2.449 4.954 7.361 6.085 7.208 5.698 2.682 2.205 1.356]
Effective sample size: 1853.3920673438893, grad evals per sample: 26.04
Energy dist v self: 0.00916
Test_accuracy: 0.8539, top 90% accuracy: 0.8585


warmup: 100%|██████████| 4096/4096 [00:32<00:00, 127.85it/s]
sample: 100%|██████████| 1024/1024 [00:08<00:00, 115.75it/s]


NUTS:
means: [ 2.384 -0.132  0.205  0.015 -0.068  0.051  0.296  0.111  0.136  0.381
  0.626  0.624  0.594  0.323 -0.13  -0.186 -0.585 -0.674 -0.643 -0.529
 -0.207 -0.32   0.189],
vars:  [0.196 0.    0.02  0.021 0.024 0.031 0.045 0.045 0.05  0.047 0.046 0.041
 0.038 0.039 0.038 0.052 0.063 0.055 0.048 0.033 0.027 0.024 0.023]
Effective sample size: 34196.62190399509, grad evals per sample: 90.17
Energy dist v self: 0.0001486
Test_accuracy: 0.866, top 90% accuracy: 0.8674

