In [1]:
%env JAX_PLATFORM_NAME=cuda

import warnings
from functools import partial

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import numpyro
import ot
import scipy
from jax import Array
from numpyro import diagnostics, distributions as dist
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


# Hide FutureWarning messages for the rest of the session.
warnings.simplefilter("ignore", FutureWarning)

# Print arrays with 3 decimals and without scientific notation.
jnp.set_printoptions(precision=3, suppress=True)
# Switch JAX to 64-bit precision; done before any array is created.
jax.config.update("jax_enable_x64", True)
# Lists the CUDA devices JAX can see (presumably raises if there are none).
print(jax.devices("cuda"))


def get_model_and_data(data, name):
    """Build the logistic-regression model and a train/test split for a dataset.

    Args:
        data: Mapping from dataset name to a container whose ``[0, 0]`` entry
            provides ``"x"`` (a 2-D feature array) and ``"t"`` (labels that
            are expected to be -1 / +1).
        name: Key of the dataset to pull out of ``data``.

    Returns:
        A pair ``(model, (x_train, labels_train, x_test, labels_test))``.
        The split is taken after a fixed-key shuffle; the training part is
        80% of the data, capped at 500 points, and the rest is the test part.

    Note:
        In the model the bias ``b`` is added to every feature term before the
        sum over features, so it contributes ``data_dim * b`` to each logit.
        ``predict`` in this file uses the same expression.
    """
    entry = data[name][0, 0]
    features = entry["x"]
    # Stored targets are -1 / +1; shift and scale them onto 0 / 1.
    targets = (jnp.squeeze(entry["t"]) + 1) / 2
    num_points, data_dim = features.shape
    print(f"Data shape: {features.shape}")

    # Shuffle with a fixed key so that the split is reproducible.
    order = jr.permutation(jr.PRNGKey(0), num_points)
    features, targets = features[order], targets[order]

    n_train = min(int(num_points * 0.8), 500)
    split = (
        features[:n_train],
        targets[:n_train],
        features[n_train:],
        targets[n_train:],
    )

    def model(x, labels):
        # alpha acts as a shared precision-like scale for the W and b priors.
        alpha = numpyro.sample("alpha", dist.Exponential(0.01))
        prior_scale = 1.0 / alpha
        W = numpyro.sample("W", dist.Normal(jnp.zeros(data_dim), prior_scale))
        b = numpyro.sample("b", dist.Normal(jnp.zeros((1,)), prior_scale))
        logits = jnp.sum(W * x + b, axis=-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    return model, split


def compute_w2(x1, x2, num_iters):
    """Exact optimal-transport cost between two uniformly weighted sample sets.

    Args:
        x1: Source samples, one sample per row.
        x2: Target samples, one sample per row.
        num_iters: Forwarded to ``ot.emd2`` as ``numItermax``.

    Returns:
        The value returned by ``ot.emd2`` for the pairwise cost matrix from
        ``ot.dist``.

    NOTE(review): ``ot.dist`` is called without a metric, so presumably its
    default (squared Euclidean) applies and the result is the *squared*
    2-Wasserstein distance rather than W2 itself -- confirm against the POT
    documentation before comparing with published W2 values.
    """
    src = np.array(x1)
    tgt = np.array(x2)
    n_src, n_tgt = src.shape[0], tgt.shape[0]
    # Every sample carries the same mass.
    src_mass = np.ones(n_src) / n_src
    tgt_mass = np.ones(n_tgt) / n_tgt
    cost_matrix = ot.dist(src, tgt)
    return ot.emd2(src_mass, tgt_mass, cost_matrix, numItermax=num_iters)


@partial(jax.jit, static_argnames=("max_len",))
def energy_distance(x: Array, y: Array, max_len: int = 2**15):
    """Energy distance ``2 E|X-Y| - E|X-X'| - E|Y-Y'|`` between two sample sets.

    Args:
        x: Samples along the leading axis; remaining axes are one sample.
        y: Samples with the same rank and per-sample shape as ``x``.
        max_len: Only the first ``max_len`` samples of each input are used.

    Returns:
        A scalar array. Pairwise distances are Euclidean norms over all
        non-leading axes (absolute differences for 1-D inputs), and the
        within-set means include the zero self-distances.
    """
    assert y.ndim == x.ndim
    assert x.shape[1:] == y.shape[1:]
    # Slicing beyond the end is a no-op, so this caps both inputs.
    x, y = x[:max_len], y[:max_len]
    feature_axes = tuple(range(1, x.ndim))

    def _mean_dist_to_point(points, anchor):
        assert points.ndim == anchor.ndim + 1, f"{points.ndim} != {anchor.ndim + 1}"
        delta = points - anchor
        if feature_axes:
            # Collapse every per-sample axis into a Euclidean norm.
            delta = jnp.sqrt(jnp.sum(delta**2, axis=feature_axes))
        return jnp.mean(jnp.abs(delta))

    def _mean_pairwise(a, b):
        assert a.ndim == b.ndim
        # One mean distance per row of ``b``, each taken against all of ``a``.
        per_anchor = jax.vmap(_mean_dist_to_point, in_axes=(None, 0))(a, b)
        return jnp.mean(per_anchor)

    cross = _mean_pairwise(x, y)
    return 2 * cross - _mean_pairwise(x, x) - _mean_pairwise(y, y)


def dict_to_array(dct: dict):
    """Pack the ``alpha``, ``b`` and ``W`` entries into a single array.

    ``alpha`` gets a new trailing axis so it can be joined with the others;
    the result is ``[alpha, b, W]`` concatenated along the last axis.
    """
    alpha_column = dct["alpha"][..., None]
    pieces = [alpha_column, dct["b"], dct["W"]]
    return jnp.concatenate(pieces, axis=-1)


# Jitted dict_to_array mapped over the leading axis of every dict entry.
vec_dict_to_array = jax.jit(jax.vmap(dict_to_array, in_axes=0, out_axes=0))


def predict(x, samples):
    """Sigmoid prediction for one input ``x`` under every row of ``samples``.

    Each row of ``samples`` is read as ``[alpha, b, W...]``: column 1 is the
    bias and columns 2 onward are the weights; column 0 is not used. The bias
    is added to each weighted feature before summing, as in the model.
    """
    weights = samples[:, 2:]
    bias = samples[:, 1:2]
    logits = jnp.sum(weights * x + bias, axis=-1)
    # Logistic sigmoid, written out explicitly.
    return 1.0 / (1.0 + jnp.exp(-logits))


def test_accuracy(x_test, labels_test, samples):
    """Classification accuracy of posterior samples on held-out data.

    Args:
        x_test: Test inputs, one row per point.
        labels_test: Test labels (assumed to be 0 / 1), one per row of
            ``x_test``.
        samples: Either a dict of sample sites or an array whose last axis
            is ``[alpha, b, W...]``. Only the first 1024 flattened samples
            are evaluated.

    Returns:
        ``(avg_accuracy, accuracy_best90)``: the mean per-sample accuracy,
        and the same mean after discarding the worst 10% of samples.
    """
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)
    flat = jnp.reshape(samples, (-1, samples.shape[-1]))
    # Slicing is a no-op when there are fewer than 2**10 rows.
    flat = flat[: 2**10]

    batched_predict = jax.jit(jax.vmap(lambda row: predict(row, flat)))
    probs = batched_predict(x_test)
    expected_shape = (labels_test.shape[0], flat.shape[0])
    assert probs.shape == expected_shape, f"{probs.shape} != {expected_shape}"

    # A prediction counts as correct when it is within 0.5 of the label.
    targets = jnp.reshape(labels_test, (labels_test.shape[0], 1))
    hits = jnp.abs(probs - targets) < 0.5
    per_sample = jnp.mean(hits, axis=0)

    num_dropped = int(0.1 * per_sample.shape[0])
    top90 = jnp.sort(per_sample)[num_dropped:]
    return jnp.mean(per_sample), jnp.mean(top90)


def eval_logreg(
    samples,
    evals_per_sample=None,
    ground_truth=None,
    num_iters_w2=0,
    x_test=None,
    labels_test=None,
):
    """Summarise a set of logistic-regression posterior samples as text.

    Args:
        samples: Dict of sample sites or an array whose last axis is
            ``[alpha, b, W...]``.
        evals_per_sample: Optional number appended to the report as
            "grad evals per sample".
        ground_truth: Optional reference samples with the same last axis;
            enables the energy distance (and W2) against them.
        num_iters_w2: When positive and ``ground_truth`` is given, passed to
            ``compute_w2`` as its iteration limit.
        x_test: Optional test inputs for the accuracy report.
        labels_test: Optional test labels for the accuracy report.

    Returns:
        The report string, which is also printed.
    """
    if isinstance(samples, dict):
        samples = vec_dict_to_array(samples)

    sample_dim = samples.shape[-1]
    flat_all = jnp.reshape(samples, (-1, sample_dim))
    means = jnp.mean(flat_all, axis=0)
    variances = jnp.var(flat_all, axis=0)
    report = [f"means: {means},\nvars:  {variances}"]

    # Apart from the accuracy check, the diagnostics drop column 0 (alpha).
    params = samples[..., 1:]
    ess = diagnostics.effective_sample_size(params)
    # Harmonic mean of the per-dimension effective sample sizes.
    avg_ess = 1 / jnp.mean(1 / jnp.stack(jtu.tree_leaves(ess)))
    report.append(f"\nEffective sample size: {avg_ess}")
    if evals_per_sample is not None:
        report.append(f", grad evals per sample: {evals_per_sample:.4}")

    flat_params = jnp.reshape(params, (-1, sample_dim - 1))

    # Energy distance between the first and second half of the samples.
    midpoint = flat_params.shape[0] // 2
    energy_self = energy_distance(flat_params[:midpoint], flat_params[midpoint:])
    report.append(f"\nEnergy dist v self: {energy_self:.4}")

    if ground_truth is not None:
        ground_truth = ground_truth[..., 1:]
        energy_gt = energy_distance(flat_params, ground_truth)
        report.append(f", energy dist vs ground truth: {energy_gt:.4}")
    if num_iters_w2 > 0 and ground_truth is not None:
        w2 = compute_w2(flat_params, ground_truth, num_iters_w2)
        report.append(f", Wasserstein-2: {w2:.4}")

    if x_test is not None and labels_test is not None:
        # Accuracy uses the full sample array, alpha column included.
        acc_error, acc_best90 = test_accuracy(x_test, labels_test, samples)
        report.append(
            f"\nTest_accuracy: {acc_error:.4}, top 90% accuracy: {acc_best90:.4}"
        )

    result_str = "".join(report)
    print(result_str)

    return result_str


# Load the benchmark collection and build the model / split for one dataset.
# These names are notebook globals reused by the cells below.
dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "banana"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[cuda(id=0)]
Data shape: (5300, 2)


In [2]:
# Path of the cached ground-truth samples for the selected dataset.
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# Recipe that produced the cached file (long single-chain NUTS run); kept
# disabled because the samples are loaded from disk below.
# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# # thin the samples to 2**15
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
# Keep every 4th sample.
gt_logreg = gt_logreg[::4]
# Energy distance between the two halves of the ground truth, reported as
# the "energy bias".
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")

Energy bias: 0.251009472403517
Ground truth shape: (16384, 4)
test accuracy: (Array(0.532, dtype=float32), Array(0.543, dtype=float32))


In [3]:
# Per-column variance and mean of the ground-truth samples.
# The column count comes from the array itself instead of a hard-coded 4,
# so the cell also works for datasets with a different number of features.
flattened_gt = jnp.reshape(gt_logreg, (-1, gt_logreg.shape[-1]))
print(flattened_gt.shape)
print(jnp.var(flattened_gt, axis=0))
print(jnp.mean(flattened_gt, axis=0))

(16384, 4)
[3464.127    0.002    0.002    0.002]
[57.174 -0.045 -0.008 -0.008]


In [4]:
# Settings shared by the LMC and NUTS runs in the following cells.
num_chains = 2**5  # 32 chains
num_samples_per_chain = 2**10  # 1024 samples per chain
warmup_len = 2**13  # used as warmup_mult for LMC and num_warmup for NUTS

In [5]:
# Run the project's LMC sampler on the training data with a fixed PRNG key.
# NOTE(review): chain_sep, tol, warmup_mult and warmup_tol_mult are defined by
# mcmc.run_lmc_numpyro, which is not shown here -- check it before tuning.
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(3),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=0.3,
    tol=0.05,
    warmup_mult=warmup_len,
    warmup_tol_mult=16,
    use_adaptive=False,
)
# Show the shape of every entry in the returned sample pytree.
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

100.00%|██████████| [00:16<00:00,  6.11%/s]
100.00%|██████████| [00:41<00:00,  2.40%/s]


LMC: gradient evaluations per output: 20.04
{'W': (32, 1024, 2), 'alpha': (32, 1024), 'b': (32, 1024, 1)}


In [6]:
# Report diagnostics for the LMC samples against the ground truth and the
# test split; steps_logreg_lmc is passed as evals_per_sample.
_ = eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, gt_logreg, x_test=x_test, labels_test=labels_test
)

means: [-34.706  -0.124  -0.042  -0.043],
vars:  [2280.214    0.771    1.778    1.917]
Effective sample size: 24883.356050420618
Energy dist vs self: 0.003162, energy dist vs ground truth: 0.05075, test_accuracy: 0.5422, top 90% accuracy: 0.5526, Gradient evals per sample: 20.04


In [8]:
# NUTS baseline with the same number of chains and samples as the LMC run.
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
# Warmup is run as a separate call with collect_warmup=True so that its
# num_steps can be read back and counted towards the cost.
# NOTE(review): presumably the following `run` continues from the adapted
# warmup state instead of warming up again -- see the NumPyro MCMC docs.
nuts.warmup(
    jr.PRNGKey(2),
    x_train,
    labels_train,
    extra_fields=("num_steps",),
    collect_warmup=True,
)
# jnp.sum reduces on the device in one call; the builtin sum would iterate
# the array element by element in Python.
warmup_steps = jnp.sum(nuts.get_extra_fields()["num_steps"])
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = jnp.sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
# Steps (warmup included) per returned sample.
geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
print(geps_nuts)

warmup: 100%|██████████| 1024/1024 [00:11<00:00, 88.90it/s] 
sample: 100%|██████████| 1024/1024 [00:06<00:00, 159.68it/s]


21.57537841796875


In [9]:
# Show the shape of every entry in the NUTS sample dict.
print(jtu.tree_map(lambda x: x.shape, out_logreg_nuts))

{'W': (32, 1024, 2), 'alpha': (32, 1024), 'b': (32, 1024, 1)}


In [12]:
# Report the same diagnostics for the NUTS samples; geps_nuts is passed as
# evals_per_sample.
_ = eval_logreg(
    out_logreg_nuts, geps_nuts, gt_logreg, x_test=x_test, labels_test=labels_test
)

means: [58.19  -0.045 -0.008 -0.008],
vars:  [3766.227    0.002    0.002    0.002]
Effective sample size: 5715.352966006341
Energy dist vs self: 4.118e-05, energy dist vs ground truth: 7.018e-06, test_accuracy: 0.5263, top 90% accuracy: 0.537


In [5]:
def run_logreg_dataset(name, results_filename=None):
    """Run LMC and NUTS on one benchmark dataset and report both.

    The dataset is looked up by ``name`` in the notebook-level ``dataset``
    mapping. Both samplers use 32 chains of 1024 samples with a warmup
    setting of 4096, and each result is summarised with ``eval_logreg``
    (without ground truth).

    Args:
        name: Key of the dataset inside ``dataset``.
        results_filename: Optional path. When given, the dataset name and
            both report strings are appended to that file.

    Returns:
        None. Reports are printed and optionally appended to the file.
    """
    model_logreg, data_split = get_model_and_data(dataset, name)
    x_train, labels_train, x_test, labels_test = data_split

    # NOTE: the ground-truth comparison is currently disabled (both
    # evaluations below pass ground_truth=None). The recipe is kept here for
    # reference; it would additionally need `import os`.
    # Compute ground truth
    # ground_truth_filename = f"mcmc_data/{name}_ground_truth.npy"
    # if os.path.isfile(ground_truth_filename):
    #     ground_truth = np.load(ground_truth_filename)
    # else:
    #     gt_nuts = MCMC(
    #         NUTS(model_logreg, step_size=1.0), num_warmup=2**13, num_samples=2**18
    #     )
    #     gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
    #     ground_truth = vec_dict_to_array(gt_nuts.get_samples())
    #     np.save(f"mcmc_data/{name}_ground_truth.npy", ground_truth)
    # # thin the samples to 2**16
    # factor = int(ground_truth.shape[0] // 2**16)
    # ground_truth = ground_truth[::factor]
    # assert ground_truth.shape == (
    #     2**16,
    #     2 + x_train.shape[1],
    # ), f"ground_truth.shape: {ground_truth.shape}"
    #
    # size_gt_half = int(ground_truth.shape[0] // 2)
    # gt_energy_bias = energy_distance(
    #     ground_truth[:size_gt_half], ground_truth[size_gt_half:]
    # )
    # print(f"Ground truth energy bias: {gt_energy_bias:.4}")
    num_chains = 2**5
    num_samples_per_chain = 2**10
    warmup_len = 2**12

    # NOTE(review): the second return value is used as "grad evals per
    # sample" below; its exact definition lives in mcmc.run_lmc_numpyro.
    out_logreg_lmc, geps_lmc = run_lmc_numpyro(
        jr.PRNGKey(3),
        model_logreg,
        (x_train, labels_train),
        num_chains,
        num_samples_per_chain,
        chain_sep=0.3,
        tol=0.05,
        warmup_mult=warmup_len,
        warmup_tol_mult=4,
        use_adaptive=False,
    )

    eval_lmc = eval_logreg(
        out_logreg_lmc,
        geps_lmc,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    nuts = MCMC(
        NUTS(model_logreg),
        num_warmup=warmup_len,
        num_samples=num_samples_per_chain,
        num_chains=num_chains,
        chain_method="vectorized",
    )
    # Warmup is a separate call with collect_warmup=True so that its
    # num_steps can be read back and counted towards the NUTS cost.
    nuts.warmup(
        jr.PRNGKey(2),
        x_train,
        labels_train,
        extra_fields=("num_steps",),
        collect_warmup=True,
    )
    # jnp.sum reduces on the device in one call; the builtin sum would
    # iterate the array element by element in Python.
    warmup_steps = jnp.sum(nuts.get_extra_fields()["num_steps"])
    nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
    out_logreg_nuts = nuts.get_samples(group_by_chain=True)
    num_steps_nuts = jnp.sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
    # Steps (warmup included) per returned sample.
    geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
    print("NUTS:")
    eval_nuts = eval_logreg(
        out_logreg_nuts,
        geps_nuts,
        ground_truth=None,
        x_test=x_test,
        labels_test=labels_test,
    )

    if results_filename is not None:
        # Append so that results of successive datasets accumulate in one file.
        with open(results_filename, "a", encoding="utf-8") as f:
            f.write(f"{name}\n")
            f.write(f"LMC: {eval_lmc}\n")
            f.write(f"NUTS: {eval_nuts}\n\n")

In [6]:
# Datasets to benchmark; each entry is passed to run_logreg_dataset.
# NOTE(review): presumably these must match the keys in benchmarks.mat exactly
# (including spellings such as "diabetis") -- verify before renaming any.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

In [7]:
# Results go to a new file whose name carries the current date and time.
import datetime


timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_filename = f"mcmc_data/results_{timestamp}.txt"

# Create the file with a header; run_logreg_dataset appends to it.
with open(results_filename, "w") as f:
    f.write("Results\n\n")

# with warnings.catch_warnings():
#     warnings.simplefilter("ignore")
for name in names:
    banner = f"==================== {name} ===================="
    print(banner)
    run_logreg_dataset(name, results_filename)
    print()

Data shape: (5300, 2)


100.00%|██████████| [00:30<00:00,  3.23%/s]
70.96%|███████   | [00:28<00:11,  2.47%/s]E0726 16:19:02.423946   35629 pjrt_stream_executor_client.cc:2985] Execution of replica 0 failed: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/usr/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
  File "/usr/lib/pyt

XlaRuntimeError: INTERNAL: CustomCall failed: CpuCallback error: Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel_launcher.py", line 18, in <module>
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/traitlets/config/application.py", line 1075, in launch_instance
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelapp.py", line 739, in start
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/tornado/platform/asyncio.py", line 205, in start
  File "/usr/lib/python3.11/asyncio/base_events.py", line 607, in run_forever
  File "/usr/lib/python3.11/asyncio/base_events.py", line 1922, in _run_once
  File "/usr/lib/python3.11/asyncio/events.py", line 80, in _run
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 545, in dispatch_queue
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 534, in process_one
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 437, in dispatch_shell
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 362, in execute_request
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/kernelbase.py", line 778, in execute_request
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/ipkernel.py", line 449, in do_execute
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/ipykernel/zmqshell.py", line 549, in run_cell
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3075, in run_cell
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3130, in _run_cell
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/async_helpers.py", line 129, in _pseudo_sync_runner
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3334, in run_cell_async
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3517, in run_ast_nodes
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3577, in run_code
  File "/tmp/ipykernel_35629/3213873017.py", line 14, in <module>
  File "/tmp/ipykernel_35629/260045252.py", line 36, in run_logreg_dataset
  File "/home/andy/PycharmProjects/diffrax_STLA/mcmc/main.py", line 41, in run_lmc_numpyro
  File "/home/andy/PycharmProjects/diffrax_STLA/mcmc/main.py", line 145, in run_lmc
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/equinox/_jit.py", line 206, in __call__
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/equinox/_module.py", line 1053, in __call__
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/equinox/_jit.py", line 200, in _call
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 179, in reraise_with_filtered_traceback
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 327, in cache_miss
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 185, in _python_pjit_helper
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/core.py", line 2834, in bind
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/core.py", line 420, in bind_with_trace
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/core.py", line 921, in process_primitive
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1635, in _pjit_call_impl
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1614, in call_impl_cache_miss
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1568, in _pjit_call_impl_python
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 335, in wrapper
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1244, in __call__
  File "/home/andy/PycharmProjects/diffrax_STLA/venv/lib/python3.11/site-packages/jax/_src/interpreters/mlir.py", line 2476, in _wrapped_callback
KeyboardInterrupt: 