In [1]:
import datetime
import sys
import warnings

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import scipy
from get_model import get_model_and_data  # noqa: F401
from lmc import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.evaluation.eval_gt_logreg import eval_gt_logreg
from mcmc.evaluation.metrics import adjust_max_len
from mcmc.experiment_main import run_experiment
from mcmc.progressive import (
    ProgressiveEvaluator,
    ProgressiveLMC,
    ProgressiveLogger,
    ProgressiveNUTS,
)


%env JAX_PLATFORM_NAME=cuda
warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))

dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
names = [
    "banana",
    # "breast_cancer",
    # "diabetis",
    # "flare_solar",
    # "german",
    # "heart",
    # "image",
    # "ringnorm",
    # "splice",
    # "thyroid",
    # "titanic",
    # "twonorm",
    # "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# Timestamp shared by the result pickles and the log file of this run.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def get_result_filename(model_name: str) -> str:
    """Return the pickle path for `model_name`'s results in this run."""
    return f"progressive_results/{model_name}_{timestamp}.pkl"


ground_truth_dirname = "ground_truth"

evaluator = ProgressiveEvaluator()
logger = ProgressiveLogger(log_filename=f"progressive_results/log_{timestamp}.txt")
logger.start_log(timestamp)
nuts_method = ProgressiveNUTS(num_warmup=20, chain_len=2**6)
quic_method = ProgressiveLMC(2**5, 0.5, 0.07, diffrax.QUICSORT(0.1))
euler_method = ProgressiveLMC(2**5, 0.5, 0.03, diffrax.Euler())
methods = [nuts_method, quic_method, euler_method]

# Per-dataset (QUICSORT tol, chain_sep) overrides; other datasets use the default.
tols_and_seps = {
    "banana": (0.04, 0.3),
    "splice": (0.01, 1.0),
    "flare_solar": (0.06, 2.0),
    "image": (0.07, 1.0),
    "waveform": (0.07, 1.0),
}
DEFAULT_TOL_AND_SEP = (0.08, 0.5)
# Euler's tol and chain_sep are derived from QUICSORT's by these fixed divisors.
EULER_TOL_DIVISOR = 40
EULER_SEP_DIVISOR = 20

for name in names:
    model, model_args, test_args = get_model_and_data(dataset, name)
    # NOTE(review): presumably model_args[0] is the (n_samples, n_features) data
    # matrix and +1 accounts for a bias weight — confirm against get_model_and_data.
    data_dim = model_args[0].shape[1] + 1
    num_particles = adjust_max_len(2**14, data_dim)
    config = {
        "num_particles": num_particles,
        "test_args": test_args,
    }
    quic_tol, chain_sep = tols_and_seps.get(name, DEFAULT_TOL_AND_SEP)
    # The same method objects are reused for every dataset, so their tol /
    # chain_sep attributes are overwritten on each iteration.
    quic_method.tol, quic_method.chain_sep = quic_tol, chain_sep
    euler_method.tol = quic_tol / EULER_TOL_DIVISOR
    euler_method.chain_sep = chain_sep / EULER_SEP_DIVISOR

    run_experiment(
        jr.key(0),
        model,
        model_args,
        name,
        methods,
        config,
        evaluator,
        logger,
        ground_truth_dirname,
        eval_gt_logreg,
        get_result_filename,
    )

Data shape: (5300, 2)

GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556



warmup: 100%|██████████| 20/20 [00:21<00:00,  1.07s/it]
sample: 100%|██████████| 44/44 [00:28<00:00,  1.56it/s]


NUTS: acc: 0.5506, acc top 90%: 0.557, energy: 1.658e-05, w2: 3.582e-04


100.00%|██████████| [00:04<00:00, 24.05%/s]


QUICSORT: acc: 0.5509, acc top 90%: 0.5575, energy: 2.590e-05, w2: 3.747e-04


100.00%|██████████| [00:05<00:00, 18.98%/s]


Euler: acc: 0.5444, acc top 90%: 0.5541, energy: 7.631e-03, w2: 1.584e-02
