In [1]:
import datetime
import os
import warnings

# JAX reads its environment-variable configuration when it is first imported,
# so this must be set *before* `import jax`; setting it afterwards (e.g. with
# an `%env` magic placed after the import) has no effect on platform choice.
os.environ["JAX_PLATFORM_NAME"] = "cuda"
# Install the filter before the heavy imports so FutureWarnings raised while
# importing JAX / project code are silenced too.
warnings.simplefilter("ignore", FutureWarning)

import jax  # noqa: E402  (must follow the environment setup above)
import jax.numpy as jnp  # noqa: E402

# Enable float64 before importing project code, so any JAX arrays it may create
# at import time already use 64-bit precision.
jax.config.update("jax_enable_x64", True)

from logreg_full_run import run_logreg_dataset  # noqa: E402

jnp.set_printoptions(precision=4, suppress=True)
# Sanity check: fails loudly if no CUDA device is visible.
print(jax.devices("cuda"))

# Logistic-regression benchmark datasets, run in this order by the next cell.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
import os

# Timestamp the run so every execution writes to fresh results files.
time = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_filename = f"mcmc_data/results_{time}.txt"
results_dict_filename = f"mcmc_data/results_dict_{time}.pkl"

# On a fresh checkout the output directory may not exist yet; without this,
# the open() below raises FileNotFoundError.
os.makedirs(os.path.dirname(results_filename), exist_ok=True)

# Create (truncate) the human-readable results file with a header.
with open(results_filename, "w") as f:
    f.write("Results\n\n")

# Run the full NUTS / LMC comparison on each dataset; run_logreg_dataset is
# given both output paths and is presumably responsible for appending to them.
# NOTE(review): in the saved run below, LMC diverged on "splice"
# (energy dist vs ground truth 21.33, Wasserstein-2 699.5) — investigate.
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name, results_filename, results_dict_filename)
    print()

Data shape: (5300, 2)
GT energy bias: 1.374e-05, test acc: 0.5468, test acc top 90%: 0.5547


warmup: 100%|██████████| 128/128 [00:03<00:00, 36.58it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 69.52it/s] 


NUTS:

Effective sample size: 6.155e+04, ess per sample: 0.9391, grad evals per sample: 7.627
Energy dist v self: 1.596e-05, energy dist vs ground truth: 9.843e-06, Wasserstein-2: 0.0004232
Test_accuracy: 0.5506, top 90% accuracy: 0.5558
LMC:
Target chain separation: 0.09642


100.00%|██████████| [00:00<00:00, 158.77%/s]
100.00%|██████████| [00:03<00:00, 29.57%/s]



Effective sample size: 6.071e+04, ess per sample: 0.9263, grad evals per sample: 7.32
Energy dist v self: 7.002e-06, energy dist vs ground truth: 2.213e-05, Wasserstein-2: 0.0004446
Test_accuracy: 0.5499, top 90% accuracy: 0.5561
Energy distance between LMC and NUTS: 1.4581e-05

Data shape: (263, 9)
GT energy bias: 2.373e-04, test acc: 0.6599, test acc top 90%: 0.6655


warmup: 100%|██████████| 128/128 [00:02<00:00, 45.94it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 85.05it/s] 


NUTS:

Effective sample size: 7.62e+04, ess per sample: 1.163, grad evals per sample: 12.66
Energy dist v self: 0.0003218, energy dist vs ground truth: 0.0003513, Wasserstein-2: 0.1546
Test_accuracy: 0.663, top 90% accuracy: 0.6682
LMC:
Target chain separation: 0.16


100.00%|██████████| [00:01<00:00, 85.29%/s]
100.00%|██████████| [00:05<00:00, 17.50%/s]



Effective sample size: 2.741e+04, ess per sample: 0.4182, grad evals per sample: 12.12
Energy dist v self: 0.0003059, energy dist vs ground truth: 0.0003111, Wasserstein-2: 0.1608
Test_accuracy: 0.6607, top 90% accuracy: 0.6663
Energy distance between LMC and NUTS: 0.00037744

Data shape: (768, 8)
GT energy bias: 1.654e-04, test acc: 0.7777, test acc top 90%: 0.7802


warmup: 100%|██████████| 128/128 [00:03<00:00, 37.96it/s] 
sample: 100%|██████████| 256/256 [00:04<00:00, 62.13it/s] 


NUTS:

Effective sample size: 7.442e+04, ess per sample: 1.136, grad evals per sample: 11.94
Energy dist v self: 0.0001961, energy dist vs ground truth: 0.0001615, Wasserstein-2: 0.04626
Test_accuracy: 0.7773, top 90% accuracy: 0.7799
LMC:
Target chain separation: 0.151


100.00%|██████████| [00:01<00:00, 97.39%/s]
100.00%|██████████| [00:04<00:00, 22.11%/s]



Effective sample size: 4.041e+04, ess per sample: 0.6165, grad evals per sample: 9.984
Energy dist v self: 0.0001412, energy dist vs ground truth: 0.0001502, Wasserstein-2: 0.04492
Test_accuracy: 0.7774, top 90% accuracy: 0.7799
Energy distance between LMC and NUTS: 0.00014471

Data shape: (144, 9)
GT energy bias: 1.909e-03, test acc: 0.6165, test acc top 90%: 0.6244


warmup: 100%|██████████| 128/128 [00:10<00:00, 12.17it/s]
sample: 100%|██████████| 256/256 [00:05<00:00, 45.87it/s]


NUTS:

Effective sample size: 5.386e+04, ess per sample: 0.8218, grad evals per sample: 98.48
Energy dist v self: 0.003037, energy dist vs ground truth: 0.002298, Wasserstein-2: 2.075
Test_accuracy: 0.616, top 90% accuracy: 0.6238
LMC:
Target chain separation: 1.245


100.00%|██████████| [00:08<00:00, 11.55%/s]
100.00%|██████████| [00:38<00:00,  2.59%/s]



Effective sample size: 1.045e+04, ess per sample: 0.1594, grad evals per sample: 80.31
Energy dist v self: 0.03459, energy dist vs ground truth: 0.01002, Wasserstein-2: 2.318
Test_accuracy: 0.6176, top 90% accuracy: 0.6251
Energy distance between LMC and NUTS: 0.0080092

Data shape: (1000, 20)


sample: 100%|██████████| 9216/9216 [00:17<00:00, 515.47it/s]


GT energy bias: 3.364e-04, test acc: 0.7838, test acc top 90%: 0.7875


warmup: 100%|██████████| 128/128 [00:03<00:00, 33.98it/s] 
sample: 100%|██████████| 256/256 [00:04<00:00, 54.78it/s] 


NUTS:

Effective sample size: 8.259e+04, ess per sample: 1.26, grad evals per sample: 15.81
Energy dist v self: 0.0003344, energy dist vs ground truth: 0.0002966, Wasserstein-2: 0.1706
Test_accuracy: 0.7832, top 90% accuracy: 0.7867
LMC:
Target chain separation: 0.1999


100.00%|██████████| [00:01<00:00, 70.43%/s]
100.00%|██████████| [00:06<00:00, 16.53%/s]



Effective sample size: 1.136e+05, ess per sample: 1.733, grad evals per sample: 12.62
Energy dist v self: 0.0002058, energy dist vs ground truth: 0.0002981, Wasserstein-2: 0.1594
Test_accuracy: 0.784, top 90% accuracy: 0.7877
Energy distance between LMC and NUTS: 0.00026025

Data shape: (270, 13)


sample: 100%|██████████| 9216/9216 [00:11<00:00, 781.22it/s] 


GT energy bias: 4.896e-04, test acc: 0.8003, test acc top 90%: 0.8086


warmup: 100%|██████████| 128/128 [00:03<00:00, 37.24it/s]
sample: 100%|██████████| 256/256 [00:02<00:00, 88.30it/s] 


NUTS:

Effective sample size: 8.362e+04, ess per sample: 1.276, grad evals per sample: 12.35
Energy dist v self: 0.0008113, energy dist vs ground truth: 0.0005652, Wasserstein-2: 0.3856
Test_accuracy: 0.8012, top 90% accuracy: 0.809
LMC:
Target chain separation: 0.1561


100.00%|██████████| [00:01<00:00, 96.52%/s]
100.00%|██████████| [00:04<00:00, 22.55%/s]



Effective sample size: 2.47e+04, ess per sample: 0.3769, grad evals per sample: 10.05
Energy dist v self: 0.0005494, energy dist vs ground truth: 0.0005259, Wasserstein-2: 0.4069
Test_accuracy: 0.7999, top 90% accuracy: 0.8081
Energy distance between LMC and NUTS: 0.00055192

Data shape: (2086, 18)


sample: 100%|██████████| 9216/9216 [00:59<00:00, 153.95it/s]


GT energy bias: 1.013e-03, test acc: 0.8213, test acc top 90%: 0.8227


warmup: 100%|██████████| 128/128 [00:12<00:00,  9.93it/s]
sample: 100%|██████████| 256/256 [00:14<00:00, 17.24it/s]


NUTS:

Effective sample size: 7.121e+04, ess per sample: 1.087, grad evals per sample: 83.42
Energy dist v self: 0.0009327, energy dist vs ground truth: 0.0008437, Wasserstein-2: 0.7705
Test_accuracy: 0.8215, top 90% accuracy: 0.8227
LMC:
Target chain separation: 1.055


100.00%|██████████| [00:07<00:00, 12.75%/s]
100.00%|██████████| [00:34<00:00,  2.89%/s]



Effective sample size: 5.701e+04, ess per sample: 0.8699, grad evals per sample: 67.82
Energy dist v self: 0.0007647, energy dist vs ground truth: 0.001134, Wasserstein-2: 0.7543
Test_accuracy: 0.8213, top 90% accuracy: 0.8225
Energy distance between LMC and NUTS: 0.00084101

Data shape: (7400, 20)


sample: 100%|██████████| 9216/9216 [00:11<00:00, 788.73it/s] 


GT energy bias: 2.372e-04, test acc: 0.7576, test acc top 90%: 0.7587


warmup: 100%|██████████| 128/128 [00:04<00:00, 30.20it/s]
sample: 100%|██████████| 256/256 [00:05<00:00, 49.24it/s]


NUTS:

Effective sample size: 1.157e+05, ess per sample: 1.766, grad evals per sample: 11.64
Energy dist v self: 0.0002976, energy dist vs ground truth: 0.0002888, Wasserstein-2: 0.123
Test_accuracy: 0.7575, top 90% accuracy: 0.7585
LMC:
Target chain separation: 0.1471


100.00%|██████████| [00:01<00:00, 94.80%/s]
100.00%|██████████| [00:04<00:00, 20.97%/s]



Effective sample size: 9.495e+04, ess per sample: 1.449, grad evals per sample: 9.938
Energy dist v self: 0.0002222, energy dist vs ground truth: 0.0002463, Wasserstein-2: 0.1185
Test_accuracy: 0.7577, top 90% accuracy: 0.7588
Energy distance between LMC and NUTS: 0.00027642

Data shape: (2991, 60)


sample: 100%|██████████| 9216/9216 [02:04<00:00, 73.86it/s]


GT energy bias: 1.667e-03, test acc: 0.8263, test acc top 90%: 0.8275


warmup: 100%|██████████| 128/128 [00:51<00:00,  2.51it/s]
sample: 100%|██████████| 256/256 [00:32<00:00,  7.84it/s]


NUTS:

Effective sample size: 9.167e+04, ess per sample: 1.399, grad evals per sample: 150.0
Energy dist v self: 0.001355, energy dist vs ground truth: 0.001503, Wasserstein-2: 0.9041
Test_accuracy: 0.8264, top 90% accuracy: 0.8275
LMC:
Target chain separation: 1.896


100.00%|██████████| [00:19<00:00,  5.05%/s]
100.00%|██████████| [01:26<00:00,  1.16%/s]



Effective sample size: 7.291e+04, ess per sample: 1.112, grad evals per sample: 120.8
Energy dist v self: 0.03112, energy dist vs ground truth: 21.33, Wasserstein-2: 699.5
Test_accuracy: 0.7229, top 90% accuracy: 0.7471
Energy distance between LMC and NUTS: 21.596

Data shape: (215, 5)


sample: 100%|██████████| 9216/9216 [00:14<00:00, 623.97it/s]


GT energy bias: 2.711e-04, test acc: 0.8006, test acc top 90%: 0.8077


warmup: 100%|██████████| 128/128 [00:03<00:00, 41.48it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 88.58it/s] 


NUTS:

Effective sample size: 6.549e+04, ess per sample: 0.9994, grad evals per sample: 12.4
Energy dist v self: 0.0006131, energy dist vs ground truth: 0.0002351, Wasserstein-2: 0.092
Test_accuracy: 0.7996, top 90% accuracy: 0.8065
LMC:
Target chain separation: 0.1568


100.00%|██████████| [00:01<00:00, 96.97%/s] 
100.00%|██████████| [00:04<00:00, 22.88%/s]



Effective sample size: 1.721e+04, ess per sample: 0.2626, grad evals per sample: 10.05
Energy dist v self: 0.001257, energy dist vs ground truth: 0.0007926, Wasserstein-2: 0.1089
Test_accuracy: 0.7997, top 90% accuracy: 0.8062
Energy distance between LMC and NUTS: 0.00052573

Data shape: (24, 3)


sample: 100%|██████████| 9216/9216 [00:10<00:00, 913.48it/s] 


GT energy bias: 6.719e-04, test acc: 0.5094, test acc top 90%: 0.528


warmup: 100%|██████████| 128/128 [00:04<00:00, 30.07it/s]
sample: 100%|██████████| 256/256 [00:02<00:00, 90.66it/s] 


NUTS:

Effective sample size: 6.662e+04, ess per sample: 1.017, grad evals per sample: 11.4
Energy dist v self: 0.0001992, energy dist vs ground truth: 0.0002036, Wasserstein-2: 0.02694
Test_accuracy: 0.516, top 90% accuracy: 0.539
LMC:
Target chain separation: 0.1441


100.00%|██████████| [00:00<00:00, 108.37%/s]
100.00%|██████████| [00:04<00:00, 24.00%/s]



Effective sample size: 1.524e+04, ess per sample: 0.2326, grad evals per sample: 9.898
Energy dist v self: 0.0004072, energy dist vs ground truth: 0.0003779, Wasserstein-2: 0.03114
Test_accuracy: 0.5111, top 90% accuracy: 0.5332
Energy distance between LMC and NUTS: 6.4283e-05

Data shape: (7400, 20)


sample: 100%|██████████| 9216/9216 [00:12<00:00, 736.21it/s] 


GT energy bias: 5.726e-04, test acc: 0.9721, test acc top 90%: 0.9725


warmup: 100%|██████████| 128/128 [00:03<00:00, 32.22it/s]
sample: 100%|██████████| 256/256 [00:06<00:00, 40.61it/s] 


NUTS:

Effective sample size: 9.523e+04, ess per sample: 1.453, grad evals per sample: 12.09
Energy dist v self: 0.0006802, energy dist vs ground truth: 0.0007649, Wasserstein-2: 0.7076
Test_accuracy: 0.9719, top 90% accuracy: 0.9724
LMC:
Target chain separation: 0.1528


100.00%|██████████| [00:01<00:00, 92.25%/s]
100.00%|██████████| [00:04<00:00, 21.17%/s]



Effective sample size: 2.649e+04, ess per sample: 0.4041, grad evals per sample: 10.01
Energy dist v self: 0.000838, energy dist vs ground truth: 0.0008356, Wasserstein-2: 0.7298
Test_accuracy: 0.9721, top 90% accuracy: 0.9725
Energy distance between LMC and NUTS: 0.00085003

Data shape: (5000, 21)


sample: 100%|██████████| 9216/9216 [00:29<00:00, 308.93it/s]


GT energy bias: 5.707e-04, test acc: 0.8748, test acc top 90%: 0.8756


warmup: 100%|██████████| 128/128 [00:06<00:00, 21.06it/s]
sample: 100%|██████████| 256/256 [00:06<00:00, 37.34it/s]


NUTS:

Effective sample size: 8.877e+04, ess per sample: 1.355, grad evals per sample: 32.64
Energy dist v self: 0.0007087, energy dist vs ground truth: 0.0007921, Wasserstein-2: 0.4284
Test_accuracy: 0.8747, top 90% accuracy: 0.8755
LMC:
Target chain separation: 0.4127


100.00%|██████████| [00:03<00:00, 32.91%/s]
100.00%|██████████| [00:13<00:00,  7.25%/s]



Effective sample size: 1.388e+05, ess per sample: 2.119, grad evals per sample: 27.42
Energy dist v self: 0.0003701, energy dist vs ground truth: 0.0005828, Wasserstein-2: 0.4259
Test_accuracy: 0.8751, top 90% accuracy: 0.876
Energy distance between LMC and NUTS: 0.0003962

