In [1]:
import datetime
import os
import sys
import warnings

# JAX reads its JAX_* environment variables when it is first imported (diffrax
# imports JAX as well), so the platform must be chosen before those imports.
os.environ["JAX_PLATFORM_NAME"] = "cuda"

import diffrax
import jax
import jax.numpy as jnp
import jax.random as jr
import scipy.io  # explicit submodule import; still binds the `scipy` name
from get_model import get_model_and_data  # noqa: F401
from lmc import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.evaluation.eval_gt_logreg import eval_gt_logreg
from mcmc.experiment_main import run_experiment
from mcmc.progressive import (
    ProgressiveEvaluator,
    ProgressiveLMC,
    ProgressiveLogger,
    ProgressiveNUTS,
)

warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
# Enable 64-bit floats; done here, before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)
# Show the visible CUDA devices (this call raises if the CUDA backend is unavailable).
print(jax.devices("cuda"))

# Benchmark collection; get_model_and_data(dataset, name) extracts one entry per name below.
dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# Run identifier embedded in every result and log filename produced by this session.
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"


def get_result_filename(model_name: str) -> str:
    """Return the pickle path for one model's results from this run.

    The file lives under ``progressive_results/``. Its name is the model name
    followed by the module-level ``timestamp``, so runs started at different
    times write to distinct files.
    """
    stem = f"{model_name}_{timestamp}"
    return f"progressive_results/{stem}.pkl"


ground_truth_dirname = "ground_truth"

evaluator = ProgressiveEvaluator()
logger = ProgressiveLogger(log_filename=f"progressive_results/{timestamp}.log")
nuts_method = ProgressiveNUTS(num_warmup=20, chain_len=2**6)
# Alias for the previous (misspelled) name, kept in case other cells still use it.
nust_method = nuts_method
quic_method = ProgressiveLMC(2**5, 0.5, 0.07, diffrax.QUICSORT(0.1))
euler_method = ProgressiveLMC(2**5, 0.5, 0.03, diffrax.Euler())
methods = [nuts_method, quic_method, euler_method]

# Per-dataset (QUICSORT tol, chain_sep) overrides; datasets not listed use the defaults below.
tols_and_seps = {
    "banana": (0.04, 0.3),
    "splice": (0.01, 1.0),
    "flare_solar": (0.06, 2.0),
    "image": (0.07, 1.0),
    "waveform": (0.07, 1.0),
}
DEFAULT_QUIC_TOL = 0.08
DEFAULT_CHAIN_SEP = 0.5
# The Euler run derives its (smaller) tol and chain_sep from the QUICSORT values.
EULER_TOL_DIVISOR = 40
EULER_CHAIN_SEP_DIVISOR = 20

for name in names:
    model, (x_train, labels_train, x_test, labels_test) = get_model_and_data(
        dataset, name
    )
    model_args = (x_train, labels_train)
    config = {
        "num_particles": 2**14,
        "x_test": x_test,
        "labels_test": labels_test,
    }
    quic_tol, chain_sep = tols_and_seps.get(
        name, (DEFAULT_QUIC_TOL, DEFAULT_CHAIN_SEP)
    )
    # The method objects are shared across datasets, so their settings are
    # reassigned on every iteration rather than carried over from the last one.
    quic_method.tol, quic_method.chain_sep = quic_tol, chain_sep
    euler_method.tol = quic_tol / EULER_TOL_DIVISOR
    euler_method.chain_sep = chain_sep / EULER_CHAIN_SEP_DIVISOR

    # The same PRNG key is passed for every dataset.
    run_experiment(
        jr.key(0),
        model,
        model_args,
        name,
        methods,
        config,
        evaluator,
        logger,
        ground_truth_dirname,
        eval_gt_logreg,
        get_result_filename,
    )

Data shape: (5300, 2)
GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556
GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556



warmup: 100%|██████████| 20/20 [00:21<00:00,  1.08s/it]
  0%|          | 0/44 [00:03<?, ?it/s]


KeyboardInterrupt: 