In [1]:
%env JAX_PLATFORM_NAME=cuda

from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


jnp.set_printoptions(precision=3, suppress=True)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))


def get_model_and_data(data, name):
    dset = data[name][0, 0]
    x = dset["x"]
    labels = jnp.squeeze(dset["t"])
    # labels are -1 and 1, convert to 0 and 1
    labels = (labels + 1) / 2
    n, data_dim = x.shape
    print(f"Data shape: {x.shape}")

    # randomly shuffle the data
    perm = jax.random.permutation(jr.PRNGKey(0), n)
    x = x[perm]
    labels = labels[perm]

    n_train = min(int(n * 0.8), 1000)
    x_train = x[:n_train]
    labels_train = labels[:n_train]
    x_test = x[n_train:]
    labels_test = labels[n_train:]

    def model(x, labels):
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), 1.0 / alpha))
        b = numpyro.sample("b", dist.Normal(jnp.zeros((1,)), 1.0 / alpha))
        logits = jnp.sum(W * x + b, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    return model, (x_train, labels_train, x_test, labels_test)


def compute_w2(x1, x2, num_iters):
    source_samples = np.array(x1)
    target_samples = np.array(x2)
    source_weights = np.ones(source_samples.shape[0]) / source_samples.shape[0]
    target_weights = np.ones(target_samples.shape[0]) / target_samples.shape[0]
    mm = ot.dist(source_samples, target_samples)
    return ot.emd2(source_weights, target_weights, mm, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    if x.shape[0] > max_len:
        x = x[:max_len]
    if y.shape[0] > max_len:
        y = y[:max_len]

    @partial(jax.vmap, in_axes=(None, 0))
    def _dist_single(_x, _y_single):
        assert _x.ndim == _y_single.ndim + 1, f"{_x.ndim} != {_y_single.ndim + 1}"
        diff = _x - _y_single
        if x.ndim > 1:
            # take the norm over all axes except the first one
            diff = jnp.sqrt(jnp.sum(diff**2, axis=tuple(range(1, diff.ndim))))
        return jnp.mean(jnp.abs(diff))

    def dist(_x, _y):
        assert _x.ndim == _y.ndim
        return jnp.mean(_dist_single(_x, _y))

    return 2 * dist(x, y) - dist(x, x) - dist(y, y)


def dict_to_array(dct: dict):
    alpha = dct["alpha"]
    alpha = jnp.expand_dims(alpha, alpha.ndim)
    b = dct["b"]
    return jnp.concatenate([alpha, b, dct["W"]], axis=-1)


vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


def predict(x, samples):
    sum = jnp.sum(samples[:, 2:] * x + samples[:, 1:2], axis=-1)
    # apply sigmoid
    return 1.0 / (1.0 + jnp.exp(-sum))


def test_accuracy(x_test, labels_test, samples):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    sample_dim = samples.shape[-1]
    samples = jnp.reshape(samples, (-1, sample_dim))
    if samples.shape[0] > 2**10:
        samples = samples[: 2**10]

    func = jax.jit(jax.vmap(lambda x: predict(x, samples), in_axes=0, out_axes=0))
    predictions = func(x_test)
    assert predictions.shape == (
        labels_test.shape[0],
        samples.shape[0],
    ), f"{predictions.shape} != {(labels_test.shape[0], samples.shape[0])}"

    labels_test = jnp.reshape(labels_test, (labels_test.shape[0], 1))
    is_correct = jnp.abs(predictions - labels_test) < 0.5
    accuracy_per_sample = jnp.mean(is_correct, axis=0)

    avg_accuracy = jnp.mean(accuracy_per_sample)

    len90 = int(0.9 * accuracy_per_sample.shape[0])
    best_sorted = jnp.sort(accuracy_per_sample)[:len90]
    accuracy_best90 = jnp.mean(best_sorted)
    return avg_accuracy, accuracy_best90


def eval_logreg(
    samples,
    evals_per_sample,
    ground_truth=None,
    num_iters_w2=0,
    x_test=None,
    labels_test=None,
):
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    ess = diagnostics.effective_sample_size(samples)
    print(f"Effective sample size: {ess}")
    # print(
    #     f"Gradient evals per effective sample:"
    #     f" {evals_per_sample * samples.shape[0]/ess}"
    # )
    sample_dim = samples.shape[-1]
    reshaped = jnp.reshape(samples, (-1, sample_dim))

    half_len = reshaped.shape[0] // 2
    energy_self = energy_distance(reshaped[:half_len], reshaped[half_len:])
    result_str = f"Energy dist vs self: {energy_self:.4}"

    if ground_truth is not None:
        energy_gt = energy_distance(reshaped, ground_truth)
        result_str += f", energy dist vs ground truth: {energy_gt:.4}"
    if num_iters_w2 > 0 and ground_truth is not None:
        w2 = compute_w2(reshaped, ground_truth, num_iters_w2)
        result_str += f", Wasserstein-2: {w2:.4}"

    if x_test is not None and labels_test is not None:
        acc_error, acc_best90 = test_accuracy(x_test, labels_test, samples)
        result_str += (
            f", test_accuracy: {acc_error:.4}, top 90% accuracy: {acc_best90:.4}"
        )
    print(result_str)


dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "banana"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]
Data shape: (5300, 2)


In [3]:
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# # thin the samples to 2**15
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
gt_logreg = gt_logreg[::4]
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")

Energy bias: 0.0018354477882169817
Ground truth shape: (65536, 4)
test accuracy: (Array(0.551, dtype=float32), Array(0.547, dtype=float32))


In [4]:
num_chains = 2**6
num_samples_per_chain = 2**9
warmup_len = 2**9

In [None]:
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(0),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.5,
    tol=0.5,
    warmup_mult=warmup_len,
    warmup_tol_mult=50,
)
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

0.00%|          | [00:07<?, ?%/s]

In [11]:
eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, gt_logreg, x_test=x_test, labels_test=labels_test
)

Effective sample size: [ 598.865  691.028 4758.201 2369.001]
Energy dist vs self: 0.000138, energy dist vs ground truth: 14.22, test_accuracy: 0.5477, top 90% accuracy: 0.5445


In [7]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.warmup(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
warmup_steps = nuts.get_extra_fields()["num_steps"]
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = nuts.get_extra_fields()["num_steps"] + warmup_steps
geps_nuts = sum(num_steps_nuts) / (num_chains * num_samples_per_chain)
print(geps_nuts)

warmup: 100%|██████████| 512/512 [00:07<00:00, 69.32it/s] 
sample: 100%|██████████| 512/512 [00:01<00:00, 284.49it/s]


6.866943359375


In [12]:
eval_logreg(
    out_logreg_nuts, geps_nuts, gt_logreg, x_test=x_test, labels_test=labels_test
)

Effective sample size: [  757.298  5807.821 23396.194 19655.407]
Energy dist vs self: 0.01987, energy dist vs ground truth: 0.04302, test_accuracy: 0.5514, top 90% accuracy: 0.5478


In [2]:
import os


def run_logreg_dataset(name):
    model_logreg, data_split = get_model_and_data(dataset, name)
    x_train, labels_train, x_test, labels_test = data_split

    # Compute ground truth
    ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    if os.path.isfile(ground_truth_filename):
        ground_truth = np.load(ground_truth_filename)
    else:
        gt_nuts = MCMC(
            NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**18
        )
        gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
        ground_truth = vec_dict_to_array(gt_nuts.get_samples())
        np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    # thin the samples to 2**16
    factor = int(ground_truth.shape[0] // 2**16)
    ground_truth = ground_truth[::factor]
    assert ground_truth.shape == (
        2**16,
        2 + x_train.shape[1],
    ), f"ground_truth.shape: {ground_truth.shape}"

    size_gt_half = int(ground_truth.shape[0] // 2)
    gt_energy_bias = energy_distance(
        ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    )
    print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**6
    num_samples_per_chain = 2**9
    warmup_len = 2**9

    out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
        jr.PRNGKey(0),
        model_logreg,
        (x_train, labels_train),
        num_chains,
        num_samples_per_chain,
        chain_sep=0.5,
        tol=0.5,
        warmup_mult=warmup_len,
        warmup_tol_mult=50,
    )

    eval_logreg(
        out_logreg_lmc,
        steps_logreg_lmc,
        ground_truth,
        x_test=x_test,
        labels_test=labels_test,
    )

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=warmup_len,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    nuts.warmup(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    warmup_steps = nuts.get_extra_fields()["num_steps"]
    nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = nuts.get_extra_fields()["num_steps"] + warmup_steps
    geps_nuts = sum(num_steps_nuts) / (num_chains * num_samples_per_chain)
    print(f"NUTS: Gradient evals per output: {geps_nuts:.4}")
    eval_logreg(
        out_logreg_nuts, geps_nuts, ground_truth, x_test=x_test, labels_test=labels_test
    )

In [3]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [None]:
# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name)
    print()

Data shape: (5300, 2)
Ground truth energy bias: 0.001835


100.00%|██████████| [00:31<00:00,  3.19%/s]
100.00%|██████████| [07:41<00:00,  4.61s/%]


LMC: gradient evaluations per output: 142.9
Effective sample size: [ 926.834  684.512 4671.124 2378.13 ]
Energy dist vs self: 0.001877, energy dist vs ground truth: 14.04, test_accuracy: 0.5479, top 90% accuracy: 0.5447


warmup: 100%|██████████| 512/512 [00:07<00:00, 70.26it/s] 
sample: 100%|██████████| 512/512 [00:01<00:00, 285.76it/s]


NUTS: Gradient evals per output: 6.867
Effective sample size: [  757.298  5807.821 23396.194 19655.407]
Energy dist vs self: 0.01987, energy dist vs ground truth: 0.007543, test_accuracy: 0.5514, top 90% accuracy: 0.5478

Data shape: (263, 9)


sample: 100%|██████████| 270336/270336 [13:00<00:00, 346.30it/s, 15 steps of size 4.14e-01. acc. prob=0.92]


Ground truth energy bias: 5.265e-05


0.00%|          | [04:51<?, ?%/s]