In [1]:
%env JAX_PLATFORM_NAME=cuda

import warnings
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = jnp.squeeze(dset["t"])
    # labels are -1 and 1, convert to 0 and 1
    labels = (labels + 1) / 2
    n, data_dim = x.shape
    print(f"Data shape: {x.shape}")

    # randomly shuffle the data
    perm = jax.random.permutation(jr.PRNGKey(0), n)
    x = x[perm]
    labels = labels[perm]

    n_train = min(int(n * 0.8), 500)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model(x, labels):
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        b = numpyro.sample("b", dist.Normal(jnp.zeros((1,)), 1.0 / alpha))
        logits = jnp.sum(W * x + b, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    return model, (x_train, labels_train, x_test, labels_test)


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]
    alpha = jnp.expand_dims(alpha, alpha.ndim)
    b = dct["b"]
    return jnp.concatenate([alpha, b, dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


def predict(x, samples):
    sum = jnp.sum(samples[:, 2:] * x + samples[:, 1:2], axis=-1)
    # apply sigmoid
    return 1.0 / (1.0 + jnp.exp(-sum))


def test_accuracy(x_test, labels_test, samples):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    sample_dim = samples.shape[-1]
    samples = jnp.reshape(samples, (-1, sample_dim))
    if samples.shape[0] > 2**10:
        samples = samples[: 2**10]

    func = jax.jit(jax.vmap(lambda x: predict(x, samples), in_axes=0, out_axes=0))
    predictions = func(x_test)
    assert predictions.shape == (
        labels_test.shape[0],
        samples.shape[0],
    ), f"{predictions.shape} != {(labels_test.shape[0], samples.shape[0])}"

    labels_test = jnp.reshape(labels_test, (labels_test.shape[0], 1))
    is_correct = jnp.abs(predictions - labels_test) < 0.5
    accuracy_per_sample = jnp.mean(is_correct, axis=0)

    avg_accuracy = jnp.mean(accuracy_per_sample)

    len10 = int(0.1 * accuracy_per_sample.shape[0])
    best_sorted = jnp.sort(accuracy_per_sample)[len10:]
    accuracy_best90 = jnp.mean(best_sorted)
    return avg_accuracy, accuracy_best90


def eval_logreg(
    samples,
    evals_per_sample,
    ground_truth=None,
    num_iters_w2=0,
    x_test=None,
    labels_test=None,
):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    # remove the alpha parameter
    samples_with_alpha = samples
    samples = samples[..., 1:]
    ess = diagnostics.effective_sample_size(samples)
    avg_ess = 1 / jnp.mean(1 / jnp.stack(jtu.tree_leaves(ess)))
    print(f"Effective sample size: {avg_ess}")
    # print(
    #     f"Gradient evals per effective sample:"
    #     f" {evals_per_sample * samples.shape[0]/ess}"
    # )
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))

    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str = f"Energy dist vs self: {energy_self:.4}"

    if ground_truth is not None:
        ground_truth = ground_truth[..., 1:]
        energy_gt = energy_distance(reshaped, ground_truth)
        result_str += f", energy dist vs ground truth: {energy_gt:.4}"
    if num_iters_w2 > 0 and ground_truth is not None:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"

    if x_test is not None and labels_test is not None:
        acc_error, acc_best90 = test_accuracy(x_test, labels_test, samples_with_alpha)
        result_str += (
            f", test_accuracy: {acc_error:.4}, top 90% accuracy: {acc_best90:.4}"
        )
    print(result_str)


dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "banana"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]
Data shape: (5300, 2)


In [2]:
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# # thin the samples to 2**15
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
gt_logreg = gt_logreg[::4]
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")

Energy bias: 0.251009472403517
Ground truth shape: (16384, 4)
test accuracy: (Array(0.532, dtype=float32), Array(0.543, dtype=float32))


In [3]:
flattened_gt = jnp.reshape(gt_logreg, (-1, 4))
print(flattened_gt.shape)
print(jnp.var(flattened_gt, axis=0))
print(jnp.mean(flattened_gt, axis=0))

(16384, 4)
[3464.127    0.002    0.002    0.002]
[57.174 -0.045 -0.008 -0.008]


In [4]:
num_chains = 2**5
num_samples_per_chain = 2**10
total_samples = num_chains * num_samples_per_chain
warmup_len = 2**12

In [5]:
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(3),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.2,
    tol=0.5,
    warmup_mult=warmup_len,
    warmup_tol_mult=20,
    use_adaptive=True,
)
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

100.00%|██████████| [02:18<00:00,  1.39s/%]
100.00%|██████████| [04:18<00:00,  2.59s/%]


LMC: gradient evaluations per output: 149.7
{'W': (32, 1024, 2), 'alpha': (32, 1024), 'b': (32, 1024, 1)}


In [6]:
eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, gt_logreg, x_test=x_test, labels_test=labels_test
)
out_arr_lmc = vec_dict_to_array(out_logreg_lmc)
flattened_lmc = jnp.reshape(out_arr_lmc, (total_samples, 4))
vars_lmc = jnp.var(flattened_lmc, axis=0)
means_lmc = jnp.mean(flattened_lmc, axis=0)
print(f"vars: {vars_lmc}, means: {means_lmc}")

Effective sample size: 9518.90316636304
Energy dist vs self: 0.0001116, energy dist vs ground truth: 0.0108, test_accuracy: 0.5386, top 90% accuracy: 0.5495
vars: [420.794   0.005   0.008   0.008], means: [-9.681 -0.073 -0.021 -0.02 ]


In [7]:
nuts = MCMC(
    NUTS(model_logreg, step_size=1.0),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
# nuts.warmup(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
# warmup_steps = nuts.get_extra_fields()["num_steps"]
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"]
geps_nuts = sum(num_steps_nuts) / (num_chains * num_samples_per_chain)
print(geps_nuts)

sample: 100%|██████████| 5120/5120 [00:29<00:00, 173.69it/s]


10.07049560546875


In [8]:
eval_logreg(
    out_logreg_nuts, geps_nuts, gt_logreg, x_test=x_test, labels_test=labels_test
)
out_arr_nuts = vec_dict_to_array(out_logreg_nuts)
flattened_nuts = jnp.reshape(out_arr_nuts, (total_samples, 4))
vars_nuts = jnp.var(flattened_nuts, axis=0)
means_nuts = jnp.mean(flattened_nuts, axis=0)
print(f"vars: {vars_nuts}, means: {means_nuts}")

Effective sample size: 4931.994023257395
Energy dist vs self: 0.000405, energy dist vs ground truth: 3.029e-05, test_accuracy: 0.5341, top 90% accuracy: 0.5448
vars: [4498.434    0.002    0.002    0.002], means: [61.918 -0.043 -0.008 -0.008]


In [10]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(dataset, name)
    x_train, labels_train, x_test, labels_test = data_split

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        ground_truth = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**18
        )
        gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
        ground_truth = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    # thin the samples to 2**16
    factor = int(ground_truth.shape[0] // 2**16)
    ground_truth = ground_truth[::factor]
    assert ground_truth.shape == (
        2**16,
        2 + x_train.shape[1],
    ), f"ground_truth.shape: {ground_truth.shape}"

    size_gt_half = int(ground_truth.shape[0] // 2)
    gt_energy_bias = energy_distance(
        ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    )
    print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**6
    num_samples_per_chain = 2**9
    warmup_len = 2**9

    out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
        jr.PRNGKey(0),
        model_logreg,
        (x_train, labels_train),
        num_chains,
        num_samples_per_chain,
        chain_sep=0.5,
        tol=1.0,
        warmup_mult=warmup_len,
        warmup_tol_mult=50,
    )

    eval_logreg(
        out_logreg_lmc,
        steps_logreg_lmc,
        ground_truth,
        x_test=x_test,
        labels_test=labels_test,
    )

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=warmup_len,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.warmup(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    warmup_steps = nuts.get_extra_fields()["num_steps"]
    nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"] + warmup_steps
    geps_nuts = sum(num_steps_nuts) / (num_chains * num_samples_per_chain)
    print(f"NUTS: Gradient evals per output: {geps_nuts:.4}")
    eval_logreg(
        out_logreg_nuts, geps_nuts, ground_truth, x_test=x_test, labels_test=labels_test
    )

In [11]:
names = [
    # "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [12]:
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name)
    print()

Data shape: (263, 9)
Ground truth energy bias: 5.265e-05


100.00%|██████████| [00:14<00:00,  6.89%/s]
100.00%|██████████| [01:40<00:00,  1.00s/%]


LMC: gradient evaluations per output: 52.1
Effective sample size: 11045.111485292075
Energy dist vs self: 8.343e-05, energy dist vs ground truth: 3.783, test_accuracy: 0.6571, top 90% accuracy: 0.6521


warmup: 100%|██████████| 512/512 [00:06<00:00, 73.54it/s] 
sample: 100%|██████████| 512/512 [00:01<00:00, 270.96it/s]


NUTS: Gradient evals per output: 8.305
Effective sample size: 29451.629207132286
Energy dist vs self: 0.0001881, energy dist vs ground truth: 0.0001118, test_accuracy: 0.6583, top 90% accuracy: 0.6531

Data shape: (768, 8)


sample: 100%|██████████| 270336/270336 [10:37<00:00, 424.05it/s, 7 steps of size 6.38e-01. acc. prob=0.84] 


Ground truth energy bias: 3.868e-05


100.00%|██████████| [00:30<00:00,  3.23%/s]
100.00%|██████████| [02:23<00:00,  1.43s/%]


LMC: gradient evaluations per output: 77.08
Effective sample size: 11888.729979305777
Energy dist vs self: 9.536e-05, energy dist vs ground truth: 1.599, test_accuracy: 0.7776, top 90% accuracy: 0.7748


warmup: 100%|██████████| 512/512 [00:08<00:00, 61.25it/s] 
sample: 100%|██████████| 512/512 [00:01<00:00, 306.22it/s]


NUTS: Gradient evals per output: 7.072
Effective sample size: 42109.23992533085
Energy dist vs self: 0.0001622, energy dist vs ground truth: 3.794e-05, test_accuracy: 0.7779, top 90% accuracy: 0.7754

Data shape: (144, 9)


sample:  44%|████▍     | 120204/270336 [50:38<1:03:15, 39.55it/s, 7 steps of size 2.51e-01. acc. prob=0.81]E0607 19:27:11.763629   14861 pjrt_stream_executor_client.cc:2809] Execution of replica 0 failed: INTERNAL: Failed to complete all kernels launched on stream 0x4a82090: Could not synchronize CUDA stream: CUDA_ERROR_LAUNCH_FAILED: unspecified launch failure



ValueError: INTERNAL: Failed to complete all kernels launched on stream 0x4a82090: Could not synchronize CUDA stream: CUDA_ERROR_LAUNCH_FAILED: unspecified launch failure