In [1]:
import sys
import warnings

import jax
import jax.numpy as jnp
import jax.random as jr
import diffrax
import datetime

from mcmc.logreg_utils import get_model_and_data  # noqa: F401
from lmc import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401

from mcmc.bayes_nn.bnn_evaluator import ProgBNNEvaluator, BNNLogger
from mcmc.bayes_nn.bnn_utils import get_model_and_data, get_gt_bnn, eval_gt_bnn
from mcmc.progressive import ProgressiveNUTS, ProgressiveLMC
from mcmc.metrics import adjust_max_len
from mcmc.experiment_main import run_experiment


# JAX reads JAX_PLATFORM_NAME from the environment when it is imported, so a
# `%env` magic placed after `import jax` (above) cannot change the platform.
# Set the config option directly instead; it must happen before the first
# backend query (`jax.devices` at the end of this cell).
jax.config.update("jax_platform_name", "cuda")

# Silence FutureWarnings from here on. Warnings already emitted while the
# imports above were executing are not affected by this filter.
warnings.simplefilter("ignore", FutureWarning)

# Compact array printing: 3 decimals, no scientific notation, never elided.
jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
# Enable 64-bit dtypes; set before any arrays are created in later cells.
jax.config.update("jax_enable_x64", True)

# Display the CUDA devices visible to JAX (also fails fast if none exist).
jax.devices("cuda")

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# `get_model_and_data` is imported twice in the first cell; the later import
# from `mcmc.bayes_nn.bnn_utils` shadows the `mcmc.logreg_utils` one, so this
# loads the BNN model and its data.
model, model_args, test_args = get_model_and_data()

# One timestamp per notebook run ties the result pickles and the log together.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")


def get_result_filename(name):
    """Return the pickle path for this run's results of experiment ``name``."""
    # NOTE(review): the "_pid_" tag is written even though USE_PID is False
    # below — confirm whether it should reflect the controller setting.
    return f"progressive_results/{name}_pid_{timestamp}.pkl"


def get_prev_result_filename(name):
    """Return a glob pattern matching result pickles of ``name`` from any run."""
    return f"progressive_results/{name}_*.pkl"


# presumably feature count + 1 for a bias term — TODO confirm this is the
# intended dimension for the BNN (its parameter count is larger).
data_dim = model_args[0].shape[1] + 1
num_particles = adjust_max_len(2**6, data_dim)
config = {
    "num_particles": num_particles,
    "test_args": test_args,
}

evaluator = ProgBNNEvaluator()
logger = BNNLogger(log_filename=f"bnn_results/log_{timestamp}.txt")
logger.start_log(timestamp)

# The first argument matches the 20-step warmup progress bar in the output;
# the second is presumably the chain length (cf. "chain_len": 2**5 for the
# LMC sampler) — confirm against ProgressiveNUTS.
nust = ProgressiveNUTS(20, 2**5)

# Toggle for adaptive (PID) step-size control of the LMC solver; read by make_pid.
USE_PID = False


def make_pid(atol, dt0, enabled=None, dtmax=0.5, pcoeff=0.1, icoeff=0.4):
    """Build the adaptive step-size controller for the LMC solver, or ``None``.

    Args:
        atol: Absolute tolerance for ``diffrax.PIDController``. The relative
            tolerance is fixed at 0, so only ``atol`` sets the error target.
        dt0: Initial step size; the controller's minimum step is ``dt0 / 10``.
        enabled: Whether to build a controller. ``None`` (the default) falls
            back to the notebook-level ``USE_PID`` flag, which preserves the
            original behaviour of this helper.
        dtmax: Largest step the controller may take.
        pcoeff: Proportional gain of the PID controller.
        icoeff: Integral gain of the PID controller.

    Returns:
        A ``diffrax.PIDController``, or ``None`` when disabled (presumably
        meaning the sampler then uses fixed steps of ``dt0`` — confirm in
        ProgressiveLMC).
    """
    if enabled is None:
        # Resolved at call time so flipping USE_PID and re-running works.
        enabled = USE_PID
    if not enabled:
        return None
    return diffrax.PIDController(
        atol=atol,
        rtol=0.0,
        dtmax=dtmax,
        dtmin=dt0 / 10,
        pcoeff=pcoeff,
        icoeff=icoeff,
    )


# The LMC step size feeds both the solver config ("dt0") and the optional PID
# controller's dtmin, so it is defined once to keep the two in sync.
QUIC_DT0 = 0.07
QUIC_PID_ATOL = 0.1
SEED = 0

quic_kwargs = {
    "chain_len": 2**5,
    "chain_sep": 1.0,
    "dt0": QUIC_DT0,
    "solver": diffrax.QUICSORT(0.1),
    "pid": make_pid(QUIC_PID_ATOL, QUIC_DT0),
}
quic = ProgressiveLMC(quic_kwargs)

# Run the NUTS and QUICSORT samplers on the "bnn" problem. Results are
# presumably written to paths from get_result_filename, with progress
# reported through `logger`.
run_experiment(
    jr.key(SEED),
    model,
    model_args,
    "bnn",
    [nust, quic],
    config,
    evaluator,
    logger,
    get_gt_bnn,
    eval_gt_bnn,
    get_result_filename,
)


sample energy bias: 5.108e-01, mean_err: 1.192, pred_energy_err: 36.9



warmup: 100%|██████████| 20/20 [00:05<00:00,  3.33it/s]
sample: 100%|██████████| 12/12 [00:05<00:00,  2.25it/s]


vec_predict. Sample shape: {'prec_obs': (32,), 'w1': (32, 3, 4), 'w2': (32, 4, 4), 'w3': (32, 4, 1)}, X shape: (500, 3)


AssertionError: 