In [1]:
import warnings

import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import scipy
from evaluation import (
    energy_distance,
    eval_logreg,
    flatten_samples,
    test_accuracy,
)
from get_model import get_model_and_data
from numpyro.infer import MCMC, NUTS

from mcmc import run_lmc_numpyro


%env JAX_PLATFORM_NAME=cuda
warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=4, suppress=True)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))

dataset = scipy.io.loadmat("mcmc_data/benchmarks.mat")
data_name = "flare_solar"
model_logreg, data_split = get_model_and_data(dataset, data_name)
x_train, labels_train, x_test, labels_test = data_split

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]
Data shape: (144, 9)


In [2]:
file_name = f"mcmc_data/{data_name}_ground_truth.npy"

# gt_nuts = MCMC(NUTS(model_logreg, step_size=1.0), num_warmup=2**14, num_samples=2**16)
# gt_nuts.run(jr.PRNGKey(0), x_train, labels_train)
# gt_logreg = vec_dict_to_array(gt_nuts.get_samples())
# np.save(file_name, gt_logreg)

gt_logreg = np.load(file_name)
size_gt_half = int(gt_logreg.shape[0] // 2)
energy_bias = energy_distance(gt_logreg[:size_gt_half], gt_logreg[size_gt_half:])
print(f"Energy bias: {energy_bias}")
print(f"Ground truth shape: {gt_logreg.shape}")
print(f"test accuracy: {test_accuracy(x_test, labels_test, gt_logreg)}")
flattened_gt = jnp.reshape(gt_logreg, (-1, 4))
print(flattened_gt.shape)
print(jnp.var(flattened_gt, axis=0))
print(jnp.mean(flattened_gt, axis=0))

FileNotFoundError: [Errno 2] No such file or directory: 'mcmc_data/flare_solar_ground_truth.npy'

In [3]:
num_chains = 2**4
num_samples_per_chain = 2**8
warmup_len = 2**9

In [3]:
out_logreg_lmc, steps_logreg_lmc = run_lmc_numpyro(
    jr.PRNGKey(2),
    model_logreg,
    (x_train, labels_train),
    num_chains,
    num_samples_per_chain,
    chain_sep=1.0,
    tol=0.2,
    warmup_mult=warmup_len,
    warmup_tol_mult=2,
    use_adaptive=False,
)
out_logreg_lmc["alpha"] = jnp.exp(out_logreg_lmc["alpha"])
print(jtu.tree_map(lambda x: x.shape, out_logreg_lmc))

100.00%|██████████| [02:27<00:00,  1.47s/%]
100.00%|██████████| [21:25<00:00, 12.86s/%]


{'W': (16, 4096, 9), 'alpha': (16, 4096), 'b': (16, 4096, 1)}


In [15]:
flat_lmc = flatten_samples(out_logreg_lmc)
outlier_positions = jnp.any(jnp.abs(flat_lmc) > 400, axis=1)
outliers = flat_lmc[outlier_positions]
print(outliers.shape)

(10, 10)


In [7]:
_ = eval_logreg(
    out_logreg_lmc, steps_logreg_lmc, x_test=x_test, labels_test=labels_test
)

means: [   1.287   38.399    0.242    0.207   -0.184    0.978 -175.385    3.333
    0.019    0.313  125.581],
vars:  [   54.842   494.392     0.408     0.21      0.191     0.59  10095.7
    69.699     0.029     0.275  5909.377]
Effective sample size: 43.88, ess per sample: 0.0006696, grad evals per sample: 112.6
Energy dist v self: 22.28
Test_accuracy: 0.6182, top 90% accuracy: 0.6266


In [10]:
lmc_last_sample = jtu.tree_map(lambda x: x[:num_chains, -1], out_logreg_lmc)
lmc_last_sample["alpha"] = jnp.log(lmc_last_sample["alpha"])
print(jtu.tree_map(lambda x: x.shape, lmc_last_sample))

{'W': (16, 9), 'alpha': (16,), 'b': (16, 1)}


In [5]:
nuts = MCMC(
    NUTS(model_logreg),
    num_warmup=warmup_len,
    num_samples=num_samples_per_chain,
    num_chains=num_chains,
    chain_method="vectorized",
)
nuts.warmup(
    jr.PRNGKey(2),
    x_train,
    labels_train,
    # init_params=lmc_last_sample,
    extra_fields=("num_steps",),
    collect_warmup=True,
)
nuts_steps_raw = nuts.get_extra_fields()["num_steps"]
warmup_steps = sum(nuts.get_extra_fields()["num_steps"])
nuts.run(jr.PRNGKey(2), x_train, labels_train, extra_fields=("num_steps",))
out_logreg_nuts = nuts.get_samples(group_by_chain=True)
num_steps_nuts = sum(nuts.get_extra_fields()["num_steps"]) + warmup_steps
geps_nuts = num_steps_nuts / (num_chains * num_samples_per_chain)
print(geps_nuts)
print(jtu.tree_map(lambda x: x.shape, out_logreg_nuts))

warmup: 100%|██████████| 512/512 [01:04<00:00,  7.99it/s]
sample: 100%|██████████| 256/256 [00:23<00:00, 10.78it/s]


926.51513671875
{'W': (16, 256, 9), 'b': (16, 256, 1)}


In [13]:
_ = eval_logreg(out_logreg_nuts, geps_nuts, x_test=x_test, labels_test=labels_test)

means: [22.47   0.05   0.018  0.025  0.     0.044 -0.015  0.038  0.002  0.023
  0.012],
vars:  [248.525   0.002   0.007   0.006   0.006   0.011   0.014   0.013   0.003
   0.007   0.009]
Effective sample size: 2.145e+04, ess per sample: 0.3273, grad evals per sample: 50.91
Energy dist v self: 5.708e-05
Test_accuracy: 0.5514, top 90% accuracy: 0.552


In [16]:
flat_nuts = flatten_samples(out_logreg_nuts)
enenrgy_dist = energy_distance(flat_nuts, flat_lmc)
print(enenrgy_dist)

241.1279099007155


In [4]:
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]