In [1]:
import datetime
import sys
import warnings

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import scipy
from get_model import get_model_and_data  # noqa: F401
from lmc import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.evaluation.eval_gt_logreg import eval_gt_logreg
from mcmc.evaluation.metrics import adjust_max_len
from mcmc.experiment_main import run_experiment
from mcmc.progressive import (
    ProgressiveEvaluator,
    ProgressiveLMC,
    ProgressiveLogger,
    ProgressiveNUTS,
)


# Silence FutureWarnings emitted by the libraries imported above.
warnings.simplefilter("ignore", FutureWarning)

# Print whole arrays (no truncation) with 3 decimals and fixed-point notation.
jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
# Enable 64-bit dtypes for all subsequent JAX computations.
jax.config.update("jax_enable_x64", True)

# Run on the GPU. JAX picks up JAX_PLATFORM_NAME from the environment only when it
# is first imported, and `jax` (plus every project module that uses it) has already
# been imported above, so setting that variable here (e.g. via `%env`) comes too late.
# `jax_default_device` is consulted when computations are dispatched, so it can be
# set after import. `jax.devices("cuda")` raises if no CUDA backend is available.
cuda_devices = jax.devices("cuda")
jax.config.update("jax_default_device", cuda_devices[0])
print(cuda_devices)

# Benchmark datasets loaded from a MATLAB .mat file; each entry of `names` is looked
# up in it via `get_model_and_data(dataset, name)`.
dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# Session timestamp: stamps this run's result files and log so runs don't collide.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_dirname = "progressive_results"
ground_truth_dirname = "ground_truth"


def get_result_filename(name: str) -> str:
    """Return this session's results-pickle path for dataset ``name``."""
    return f"{results_dirname}/{name}_{timestamp}.pkl"


def get_prev_result_filename(name: str) -> str:
    """Return a glob pattern matching results pickles for ``name`` from any session.

    The session timestamp is replaced by a ``*`` wildcard.
    """
    return f"{results_dirname}/{name}_*.pkl"


evaluator = ProgressiveEvaluator()
logger = ProgressiveLogger(log_filename=f"{results_dirname}/log_pid_{timestamp}.txt")
logger.start_log(timestamp)

# NUTS baseline. It receives the previous-results pattern, presumably so it can reuse
# earlier runs instead of resampling (TODO confirm against ProgressiveNUTS).
nust = ProgressiveNUTS(20, 2**6, get_prev_result_filename)


def make_pid(
    atol: float,
    dt0: float,
    *,
    rtol: float = 0.0,
    dtmax: float = 0.5,
    dtmin_divisor: float = 5,
    pcoeff: float = 0.2,
    icoeff: float = 0.6,
):
    """Build the adaptive step-size controller used by the LMC samplers.

    Args:
        atol: Absolute error tolerance of the controller.
        dt0: Nominal step size. The controller's smallest allowed step is
            ``dt0 / dtmin_divisor``.
        rtol: Relative error tolerance. The default of 0 makes the tolerance purely
            absolute.
        dtmax: Largest step the controller may take.
        dtmin_divisor: Divisor applied to ``dt0`` to obtain the minimum step.
        pcoeff: Proportional coefficient of the PID controller.
        icoeff: Integral coefficient of the PID controller.

    Returns:
        A ``diffrax.PIDController`` configured with the settings above.
    """
    return diffrax.PIDController(
        atol=atol,
        rtol=rtol,
        dtmax=dtmax,
        dtmin=dt0 / dtmin_divisor,
        pcoeff=pcoeff,
        icoeff=icoeff,
    )


def _lmc_kwargs(solver, dt0, chain_sep, atol=0.1, chain_len=2**5):
    """Assemble the keyword dict for one ``ProgressiveLMC`` sampler.

    The PID controller is built from the same ``dt0`` stored under ``"dt0"``, so the
    two values always agree.
    """
    return {
        "chain_len": chain_len,
        "chain_sep": chain_sep,
        "dt0": dt0,
        "solver": solver,
        "pid": make_pid(atol, dt0),
    }


# Starting configurations. The per-dataset loop overwrites "dt0", "chain_sep" and "pid"
# on each sampler's lmc_kwargs before every run.
quic_kwargs = _lmc_kwargs(diffrax.QUICSORT(0.1), dt0=0.07, chain_sep=1.0)
quic = ProgressiveLMC(quic_kwargs)

euler_kwargs = _lmc_kwargs(diffrax.Euler(), dt0=0.03, chain_sep=0.5)
euler = ProgressiveLMC(euler_kwargs)

methods = [nust, quic, euler]

# Per-dataset overrides as (QUICSORT dt0, chain_sep, PID atol).
dt0s_seps_atols = dict(
    banana=(0.04, 0.5, 0.1),
    splice=(0.01, 1.0, 0.1),
    flare_solar=(0.06, 2.0, 0.1),
    image=(0.07, 1.0, 0.1),
    waveform=(0.07, 1.0, 0.1),
)

# Fallback (QUICSORT dt0, chain_sep, PID atol) for datasets without an override.
default_dt0_sep_atol = (0.08, 1.0, 0.1)

for name in names:
    model, model_args, test_args = get_model_and_data(dataset, name)
    # Number of feature columns plus one (presumably a bias weight — TODO confirm).
    data_dim = model_args[0].shape[1] + 1
    num_particles = adjust_max_len(2**14, data_dim)
    config = dict(num_particles=num_particles, test_args=test_args)

    quic_dt0, chain_sep, atol = dt0s_seps_atols.get(name, default_dt0_sep_atol)

    # QUICSORT uses the tuned step size and chain separation directly.
    quic.lmc_kwargs["dt0"] = quic_dt0
    quic.lmc_kwargs["chain_sep"] = chain_sep
    quic.lmc_kwargs["pid"] = make_pid(atol, quic_dt0)

    # Euler gets a 20x smaller step and a 10x smaller chain separation.
    euler_dt0 = quic_dt0 / 20
    euler.lmc_kwargs["dt0"] = euler_dt0
    euler.lmc_kwargs["chain_sep"] = chain_sep / 10
    euler.lmc_kwargs["pid"] = make_pid(atol, euler_dt0)

    # A fresh key built from seed 0 for every dataset keeps each run reproducible.
    seed_key = jr.key(0)
    run_experiment(
        seed_key,
        model,
        model_args,
        name,
        methods,
        config,
        evaluator,
        logger,
        ground_truth_dirname,
        eval_gt_logreg,
        get_result_filename,
    )

Data shape: (5300, 2)

GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556

Loading progressive_results/banana_2024-10-02_12-56-44.pkl
NUTS: acc: 0.5506, acc top 90%: 0.557, energy: 1.658e-05, w2: 3.582e-04
Loading progressive_results/banana_2024-10-02_12-56-44.pkl
QUICSORT: acc: 0.5509, acc top 90%: 0.5575, energy: 2.590e-05, w2: 3.747e-04
Loading progressive_results/banana_2024-10-02_12-56-44.pkl
Euler: acc: 0.5444, acc top 90%: 0.5541, energy: 7.631e-03, w2: 1.584e-02
