In [2]:
import datetime
import sys
import warnings

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import scipy
import scipy.io
from get_model import get_model_and_data  # noqa: F401
from lmc import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.evaluation.eval_gt_logreg import eval_gt_logreg
from mcmc.evaluation.metrics import adjust_max_len
from mcmc.experiment_main import run_experiment
from mcmc.progressive import (
    ProgressiveEvaluator,
    ProgressiveLMC,
    ProgressiveLogger,
    ProgressiveNUTS,
)
from mcmc.utils import compute_gt_logreg


# Select the GPU as JAX's default platform (keeping CPU available as a
# fallback). JAX reads JAX_PLATFORM_NAME / JAX_PLATFORMS from the environment
# only when it is first imported, so setting them with `%env` after
# `import jax` has no effect; update the config flag directly instead. This
# must happen before any backend is initialised (the `jax.devices` call below
# triggers initialisation).
jax.config.update("jax_platforms", "cuda,cpu")
warnings.simplefilter("ignore", FutureWarning)

# Compact, never-truncated printing of JAX arrays.
jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
# Enable float64 before any arrays are created.
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))

# Benchmark datasets stored in a single MATLAB file; `names` lists the dataset
# names iterated over by the experiment loop.
dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [3]:
# Toggle for PID (adaptive) step-size control in the LMC runs; read by
# `make_pid` and used to tag result filenames.
USE_PID = False

RESULTS_DIR = "progressive_results"
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def get_result_filename(name):
    """Return the pickle path for this run's results on dataset ``name``.

    The filename tag records whether PID step-size control was enabled, so
    fixed-step runs are no longer labelled "pid".
    """
    step_tag = "pid" if USE_PID else "nopid"
    return f"{RESULTS_DIR}/{name}_{step_tag}_{timestamp}.pkl"


def get_prev_result_filename(name):
    """Return a glob pattern matching any results file for ``name``."""
    return f"{RESULTS_DIR}/{name}_*.pkl"


ground_truth_dirname = "ground_truth"

evaluator = ProgressiveEvaluator()
logger = ProgressiveLogger(log_filename=f"{RESULTS_DIR}/log_{timestamp}.txt")
logger.start_log(timestamp)

# NUTS baseline (variable name kept as-is because later code refers to it).
# The positional arguments are forwarded to ProgressiveNUTS unchanged; their
# meaning is defined in mcmc.progressive.
nust = ProgressiveNUTS(20, 2**6, get_prev_result_filename)

def make_pid(
    atol,
    dt0,
    *,
    enabled=None,
    dtmax=0.5,
    dtmin_ratio=10,
    pcoeff=0.1,
    icoeff=0.4,
):
    """Build a diffrax PID step-size controller, or ``None`` when disabled.

    Args:
        atol: Absolute tolerance for the controller (``rtol`` is fixed to 0).
        dt0: Initial step size of the run; the minimum step is derived from
            it as ``dt0 / dtmin_ratio``.
        enabled: Whether to build a controller. ``None`` (the default) defers
            to the module-level ``USE_PID`` flag, matching the original
            behaviour; pass a bool to override it explicitly.
        dtmax: Maximum allowed step size.
        dtmin_ratio: Divisor applied to ``dt0`` to obtain ``dtmin``.
        pcoeff: Proportional coefficient of the controller.
        icoeff: Integral coefficient of the controller.

    Returns:
        A ``diffrax.PIDController``, or ``None`` when PID control is disabled.
        The result is stored under the ``"pid"`` key of the LMC kwargs.
    """
    if enabled is None:
        enabled = USE_PID
    if not enabled:
        return None
    return diffrax.PIDController(
        atol=atol,
        rtol=0.0,
        dtmax=dtmax,
        dtmin=dt0 / dtmin_ratio,
        pcoeff=pcoeff,
        icoeff=icoeff,
    )


# Baseline LMC configurations. The per-dataset loop overwrites "dt0",
# "chain_sep" and "pid" in each method's `lmc_kwargs` before every run.
quic_kwargs = dict(
    chain_len=2**5,
    chain_sep=1.0,
    dt0=0.07,
    solver=diffrax.QUICSORT(0.1),
    pid=make_pid(0.1, 0.07),
)
quic = ProgressiveLMC(quic_kwargs)

euler_kwargs = dict(
    chain_len=2**5,
    chain_sep=0.5,
    dt0=0.03,
    solver=diffrax.Euler(),
    pid=make_pid(0.1, 0.03),
)
euler = ProgressiveLMC(euler_kwargs)

# Methods handed to run_experiment. `euler` is constructed (and reconfigured
# in the loop) but is not included here, so it is not passed to
# run_experiment.
methods = [nust, quic]

# Per-dataset overrides; datasets not listed fall back to the defaults given
# to `.get(...)` in the experiment loop.
dt0s = dict(banana=0.04, splice=0.01, flare_solar=0.08)
seps = dict(
    banana=0.3,
    splice=1.0,
    flare_solar=2.0,
    image=1.0,
    waveform=1.0,
)
atols = {}  # no per-dataset atol overrides at present


# Run every method on every benchmark dataset with per-dataset step settings.
for name in names:
    model, model_args, test_args = get_model_and_data(dataset, name)
    # NOTE(review): the +1 presumably accounts for an intercept term — confirm
    # against get_model_and_data.
    data_dim = model_args[0].shape[1] + 1
    num_particles = adjust_max_len(2**14, data_dim)
    config = {
        "num_particles": num_particles,
        "test_args": test_args,
    }

    # Per-dataset overrides with fallback defaults.
    quic_dt0 = dt0s.get(name, 0.07)
    chain_sep = seps.get(name, 0.5)
    atol = atols.get(name, 0.8)

    quic.lmc_kwargs["dt0"] = quic_dt0
    quic.lmc_kwargs["chain_sep"] = chain_sep
    quic.lmc_kwargs["pid"] = make_pid(atol, quic_dt0)

    # Euler runs with a finer step and chain separation than QUICSORT. Build
    # its PID controller from the step size Euler actually uses: previously it
    # was given quic_dt0 / 20 while dt0 was quic_dt0 / 40, so the controller's
    # dtmin was derived from a different step size than the run's dt0.
    euler_dt0 = quic_dt0 / 40
    euler.lmc_kwargs["dt0"] = euler_dt0
    euler.lmc_kwargs["chain_sep"] = chain_sep / 20
    euler.lmc_kwargs["pid"] = make_pid(atol, euler_dt0)

    # The same PRNG key is used for every dataset.
    run_experiment(
        jr.key(0),
        model,
        model_args,
        name,
        methods,
        config,
        evaluator,
        logger,
        compute_gt_logreg,
        ground_truth_dirname,
        eval_gt_logreg,
        get_result_filename,
    )

Data shape: (5300, 2)

GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556

NUTS: acc: 0.5506, acc top 90%: 0.557, energy: 1.658e-05, w2: 3.582e-04


100.00%|██████████| [00:04<00:00, 24.13%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.5509, acc top 90%: 0.5575, energy: 2.590e-05, w2: 3.747e-04
Data shape: (263, 9)

GT energy bias: 9.460e-05, test acc: 0.6627, test acc top 90%: 0.6677

NUTS: acc: 0.6635, acc top 90%: 0.6696, energy: 9.450e-05, w2: 1.353e-01


100.00%|██████████| [00:02<00:00, 45.24%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.6659, acc top 90%: 0.6717, energy: 1.175e-04, w2: 1.356e-01
Data shape: (768, 8)

GT energy bias: 4.407e-05, test acc: 0.7772, test acc top 90%: 0.7798

NUTS: acc: 0.7785, acc top 90%: 0.7811, energy: 5.960e-05, w2: 4.012e-02


100.00%|██████████| [00:02<00:00, 37.01%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.7809, acc top 90%: 0.7836, energy: 8.609e-05, w2: 3.985e-02
Data shape: (144, 9)

GT energy bias: 8.068e-04, test acc: 0.6169, test acc top 90%: 0.6247

NUTS: acc: 0.6176, acc top 90%: 0.6254, energy: 1.782e-02, w2: 1.985e+00


100.00%|██████████| [00:06<00:00, 14.46%/s]


avg accepted: 775.00, avg rejected: 0.00
QUICSORT: acc: 0.6269, acc top 90%: 0.6357, energy: 6.101e-03, w2: 1.788e+00
Data shape: (1000, 20)

GT energy bias: 1.110e-04, test acc: 0.7837, test acc top 90%: 0.7872

NUTS: acc: 0.7844, acc top 90%: 0.7879, energy: 1.479e-04, w2: 1.621e-01


100.00%|██████████| [00:02<00:00, 33.64%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.7852, acc top 90%: 0.7888, energy: 2.317e-04, w2: 1.584e-01
Data shape: (270, 13)

GT energy bias: 1.212e-04, test acc: 0.8014, test acc top 90%: 0.8094

NUTS: acc: 0.8034, acc top 90%: 0.8114, energy: 1.403e-04, w2: 3.397e-01


100.00%|██████████| [00:02<00:00, 36.36%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.806, acc top 90%: 0.8141, energy: 1.557e-04, w2: 3.383e-01
Data shape: (2086, 18)

GT energy bias: 3.996e-04, test acc: 0.821, test acc top 90%: 0.8223

NUTS: acc: 0.8214, acc top 90%: 0.8227, energy: 5.527e-04, w2: 7.458e-01


100.00%|██████████| [00:05<00:00, 17.89%/s]


avg accepted: 465.00, avg rejected: 0.00
QUICSORT: acc: 0.8237, acc top 90%: 0.8252, energy: 4.670e-04, w2: 7.302e-01
Data shape: (7400, 20)

GT energy bias: 1.145e-04, test acc: 0.7575, test acc top 90%: 0.7586

NUTS: acc: 0.7577, acc top 90%: 0.7588, energy: 1.401e-04, w2: 1.171e-01


100.00%|██████████| [00:03<00:00, 31.82%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.7584, acc top 90%: 0.7596, energy: 3.407e-04, w2: 1.139e-01
Data shape: (2991, 60)

GT energy bias: 3.211e-04, test acc: 0.8264, test acc top 90%: 0.8275

NUTS: acc: 0.8262, acc top 90%: 0.8274, energy: 2.715e-02, w2: 9.307e-01


100.00%|██████████| [02:02<00:00,  1.22s/%]


avg accepted: 3100.00, avg rejected: 0.00
QUICSORT: acc: 0.8269, acc top 90%: 0.8281, energy: 3.790e-04, w2: 8.340e-01
Data shape: (215, 5)

GT energy bias: 6.158e-05, test acc: 0.7996, test acc top 90%: 0.8062

NUTS: acc: 0.801, acc top 90%: 0.8077, energy: 1.071e-04, w2: 8.024e-02


100.00%|██████████| [00:01<00:00, 56.72%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.8256, acc top 90%: 0.8315, energy: 1.124e-04, w2: 8.255e-02
Data shape: (24, 3)

GT energy bias: 8.428e-05, test acc: 0.5117, test acc top 90%: 0.5328

NUTS: acc: 0.5205, acc top 90%: 0.5419, energy: 6.772e-05, w2: 2.508e-02


100.00%|██████████| [00:01<00:00, 72.63%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.5266, acc top 90%: 0.5447, energy: 6.559e-05, w2: 2.529e-02
Data shape: (7400, 20)

GT energy bias: 2.240e-04, test acc: 0.9719, test acc top 90%: 0.9724

NUTS: acc: 0.9722, acc top 90%: 0.9726, energy: 3.368e-04, w2: 6.814e-01


100.00%|██████████| [00:03<00:00, 31.84%/s]


avg accepted: 248.00, avg rejected: 0.00
QUICSORT: acc: 0.9721, acc top 90%: 0.9726, energy: 3.493e-04, w2: 6.945e-01
Data shape: (5000, 21)

GT energy bias: 2.431e-04, test acc: 0.8749, test acc top 90%: 0.8758

NUTS: acc: 0.8753, acc top 90%: 0.8761, energy: 2.966e-04, w2: 4.117e-01


100.00%|██████████| [00:06<00:00, 16.22%/s]


avg accepted: 465.00, avg rejected: 0.00
QUICSORT: acc: 0.8756, acc top 90%: 0.8765, energy: 3.020e-04, w2: 4.065e-01
