In [1]:
import datetime
import os
import warnings

# JAX reads its configuration environment variables when `jax` is first
# imported. The platform pin must therefore be set *before* `import jax`
# (and before importing any project module that imports jax). The previous
# `%env` line ran after the imports, so on a fresh kernel it had no effect.
os.environ["JAX_PLATFORM_NAME"] = "cuda"

import jax
import jax.numpy as jnp

# Enable float64 before importing project code, in case it creates arrays
# at import time. Arrays created before this flag is set stay 32-bit.
jax.config.update("jax_enable_x64", True)

from logreg_full_run import run_logreg_dataset

warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=4, suppress=True)
# Fails loudly if no CUDA device is available, which is the intent here.
print(jax.devices("cuda"))

# Benchmark datasets to run, in order.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
import os

# Each run writes to its own timestamped file so earlier results are never
# overwritten.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_dir = "mcmc_data"
results_filename = f"{results_dir}/results_{timestamp}.txt"

# open(..., "w") does not create parent directories. Make sure the output
# folder exists so this cell also works on a fresh checkout.
os.makedirs(results_dir, exist_ok=True)

# Create the results file with a header. run_logreg_dataset receives the
# filename and presumably appends its per-dataset summary to it (TODO confirm).
with open(results_filename, "w") as f:
    f.write("Results\n\n")

# Run run_logreg_dataset on every dataset in `names`, with a banner line
# before each one and a blank line after it.
for name in names:
    print(f"==================== {name} ====================")
    run_logreg_dataset(name, results_filename, timestamp)
    print()

Data shape: (5300, 2)
GT energy bias: 1.374e-05, test acc: 0.5468, test acc top 90%: 0.5547
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 36.61it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 68.50it/s] 



ESS per sample: 0.9391, grad evals per sample: 7.627, GEPS/ESS: 8.122
Energy dist vs ground truth: 9.843e-06, Wasserstein-2 error: 0.0004232
Test_accuracy: 0.5506, top 90% accuracy: 0.5558
LMC:
Target time-interval between samples for LMC: 0.09642
QUICSORT:


100.00%|██████████| [00:00<00:00, 159.53%/s]
100.00%|██████████| [00:03<00:00, 29.43%/s]



ESS per sample: 0.9263, grad evals per sample: 7.32, GEPS/ESS: 7.903
Energy dist vs ground truth: 2.213e-05, Wasserstein-2 error: 0.0004446
Test_accuracy: 0.5499, top 90% accuracy: 0.5561
Euler:


100.00%|██████████| [00:01<00:00, 80.42%/s]
100.00%|██████████| [00:05<00:00, 17.92%/s]



ESS per sample: 0.3913, grad evals per sample: 6.309, GEPS/ESS: 16.12
Energy dist vs ground truth: 7.375, Wasserstein-2 error: 114.9
Test_accuracy: 0.507, top 90% accuracy: 0.5184

Data shape: (263, 9)
GT energy bias: 1.405e-04, test acc: 0.6599, test acc top 90%: 0.6655
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 43.03it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 72.83it/s] 



ESS per sample: 1.163, grad evals per sample: 12.66, GEPS/ESS: 10.89
Energy dist vs ground truth: 0.0001956, Wasserstein-2 error: 0.1366
Test_accuracy: 0.663, top 90% accuracy: 0.6682
LMC:
Target time-interval between samples for LMC: 0.16
QUICSORT:


100.00%|██████████| [00:01<00:00, 85.55%/s]
100.00%|██████████| [00:05<00:00, 17.56%/s]



ESS per sample: 0.4182, grad evals per sample: 12.12, GEPS/ESS: 28.99
Energy dist vs ground truth: 0.0001647, Wasserstein-2 error: 0.1379
Test_accuracy: 0.6607, top 90% accuracy: 0.6663
Euler:


100.00%|██████████| [00:02<00:00, 43.35%/s]
100.00%|██████████| [00:10<00:00,  9.76%/s]



ESS per sample: 0.4494, grad evals per sample: 11.11, GEPS/ESS: 24.72
Energy dist vs ground truth: 0.3474, Wasserstein-2 error: 1.765
Test_accuracy: 0.6502, top 90% accuracy: 0.6608

Data shape: (768, 8)
GT energy bias: 9.340e-05, test acc: 0.7777, test acc top 90%: 0.7802
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 35.91it/s] 
sample: 100%|██████████| 256/256 [00:04<00:00, 62.33it/s] 



ESS per sample: 1.136, grad evals per sample: 11.94, GEPS/ESS: 10.52
Energy dist vs ground truth: 0.0001099, Wasserstein-2 error: 0.04171
Test_accuracy: 0.7773, top 90% accuracy: 0.7799
LMC:
Target time-interval between samples for LMC: 0.151
QUICSORT:


100.00%|██████████| [00:00<00:00, 100.26%/s]
100.00%|██████████| [00:04<00:00, 22.24%/s]



ESS per sample: 0.6165, grad evals per sample: 9.984, GEPS/ESS: 16.19
Energy dist vs ground truth: 9.262e-05, Wasserstein-2 error: 0.04077
Test_accuracy: 0.7774, top 90% accuracy: 0.7799
Euler:


100.00%|██████████| [00:01<00:00, 51.00%/s]
100.00%|██████████| [00:08<00:00, 11.31%/s]



ESS per sample: 0.4793, grad evals per sample: 9.984, GEPS/ESS: 20.83
Energy dist vs ground truth: 5.71, Wasserstein-2 error: 43.73
Test_accuracy: 0.6768, top 90% accuracy: 0.6858

Data shape: (144, 9)
GT energy bias: 3.103e-03, test acc: 0.6165, test acc top 90%: 0.6244
NUTS:


warmup: 100%|██████████| 128/128 [00:10<00:00, 12.17it/s]
sample: 100%|██████████| 256/256 [00:05<00:00, 45.97it/s]



ESS per sample: 0.8218, grad evals per sample: 98.48, GEPS/ESS: 119.8
Energy dist vs ground truth: 0.003401, Wasserstein-2 error: 1.78
Test_accuracy: 0.616, top 90% accuracy: 0.6238
LMC:
Target time-interval between samples for LMC: 1.245
QUICSORT:


100.00%|██████████| [00:08<00:00, 11.57%/s]
100.00%|██████████| [00:38<00:00,  2.60%/s]



ESS per sample: 0.1594, grad evals per sample: 80.31, GEPS/ESS: 503.9
Energy dist vs ground truth: 0.002446, Wasserstein-2 error: 1.862
Test_accuracy: 0.6176, top 90% accuracy: 0.6251
Euler:


100.00%|██████████| [00:17<00:00,  5.77%/s]
100.00%|██████████| [01:16<00:00,  1.31%/s]



ESS per sample: 0.02594, grad evals per sample: 79.3, GEPS/ESS: 3.057e+03
Energy dist vs ground truth: 2.666e+05, Wasserstein-2 error: 3.608e+15
Test_accuracy: 0.5664, top 90% accuracy: 0.5794

Data shape: (1000, 20)
GT energy bias: 3.190e-04, test acc: 0.7838, test acc top 90%: 0.7875
NUTS:


warmup: 100%|██████████| 128/128 [00:04<00:00, 28.36it/s]
sample: 100%|██████████| 256/256 [00:04<00:00, 53.59it/s] 



ESS per sample: 1.26, grad evals per sample: 15.81, GEPS/ESS: 12.55
Energy dist vs ground truth: 0.0002894, Wasserstein-2 error: 0.1641
Test_accuracy: 0.7832, top 90% accuracy: 0.7867
LMC:
Target time-interval between samples for LMC: 0.1999
QUICSORT:


100.00%|██████████| [00:01<00:00, 69.69%/s]
100.00%|██████████| [00:06<00:00, 16.57%/s]



ESS per sample: 1.733, grad evals per sample: 12.62, GEPS/ESS: 7.28
Energy dist vs ground truth: 0.0002908, Wasserstein-2 error: 0.1552
Test_accuracy: 0.784, top 90% accuracy: 0.7877
Euler:


100.00%|██████████| [00:02<00:00, 36.51%/s]
100.00%|██████████| [00:11<00:00,  8.53%/s]



ESS per sample: 0.6149, grad evals per sample: 12.62, GEPS/ESS: 20.52
Energy dist vs ground truth: 11.76, Wasserstein-2 error: 153.0
Test_accuracy: 0.6844, top 90% accuracy: 0.6944

Data shape: (270, 13)
GT energy bias: 2.740e-04, test acc: 0.8003, test acc top 90%: 0.8086
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 46.10it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 81.02it/s] 



ESS per sample: 1.276, grad evals per sample: 12.35, GEPS/ESS: 9.68
Energy dist vs ground truth: 0.0002391, Wasserstein-2 error: 0.3472
Test_accuracy: 0.8012, top 90% accuracy: 0.809
LMC:
Target time-interval between samples for LMC: 0.1561
QUICSORT:


100.00%|██████████| [00:01<00:00, 97.61%/s]
100.00%|██████████| [00:04<00:00, 22.72%/s]



ESS per sample: 0.3769, grad evals per sample: 10.05, GEPS/ESS: 26.66
Energy dist vs ground truth: 0.0002913, Wasserstein-2 error: 0.3586
Test_accuracy: 0.7999, top 90% accuracy: 0.8081
Euler:


100.00%|██████████| [00:01<00:00, 50.07%/s]
100.00%|██████████| [00:08<00:00, 11.39%/s]



ESS per sample: 0.406, grad evals per sample: 10.05, GEPS/ESS: 24.74
Energy dist vs ground truth: 0.1634, Wasserstein-2 error: 1.479
Test_accuracy: 0.7872, top 90% accuracy: 0.7991

Data shape: (2086, 18)
GT energy bias: 9.415e-04, test acc: 0.8213, test acc top 90%: 0.8227
NUTS:


warmup: 100%|██████████| 128/128 [00:12<00:00,  9.97it/s]
sample: 100%|██████████| 256/256 [00:14<00:00, 17.22it/s]



ESS per sample: 1.087, grad evals per sample: 83.42, GEPS/ESS: 76.77
Energy dist vs ground truth: 0.0007904, Wasserstein-2 error: 0.7503
Test_accuracy: 0.8215, top 90% accuracy: 0.8227
LMC:
Target time-interval between samples for LMC: 1.055
QUICSORT:


100.00%|██████████| [00:07<00:00, 12.76%/s]
100.00%|██████████| [00:34<00:00,  2.88%/s]



ESS per sample: 0.8699, grad evals per sample: 67.82, GEPS/ESS: 77.96
Energy dist vs ground truth: 0.001078, Wasserstein-2 error: 0.748
Test_accuracy: 0.8213, top 90% accuracy: 0.8225
Euler:


100.00%|██████████| [00:15<00:00,  6.55%/s]
100.00%|██████████| [01:06<00:00,  1.51%/s]



ESS per sample: 1.209, grad evals per sample: 66.8, GEPS/ESS: 55.28
Energy dist vs ground truth: 33.92, Wasserstein-2 error: 1.012e+03
Test_accuracy: 0.6346, top 90% accuracy: 0.654

Data shape: (7400, 20)
GT energy bias: 2.306e-04, test acc: 0.7576, test acc top 90%: 0.7587
NUTS:


warmup: 100%|██████████| 128/128 [00:04<00:00, 30.17it/s]
sample: 100%|██████████| 256/256 [00:05<00:00, 49.37it/s]



ESS per sample: 1.766, grad evals per sample: 11.64, GEPS/ESS: 6.591
Energy dist vs ground truth: 0.0002409, Wasserstein-2 error: 0.1193
Test_accuracy: 0.7575, top 90% accuracy: 0.7585
LMC:
Target time-interval between samples for LMC: 0.1471
QUICSORT:


100.00%|██████████| [00:01<00:00, 96.58%/s]
100.00%|██████████| [00:04<00:00, 21.08%/s]



ESS per sample: 1.449, grad evals per sample: 9.938, GEPS/ESS: 6.859
Energy dist vs ground truth: 0.0002417, Wasserstein-2 error: 0.1151
Test_accuracy: 0.7577, top 90% accuracy: 0.7588
Euler:


100.00%|██████████| [00:01<00:00, 50.78%/s]
100.00%|██████████| [00:09<00:00, 10.95%/s]



ESS per sample: 0.4768, grad evals per sample: 9.934, GEPS/ESS: 20.83
Energy dist vs ground truth: 7.19, Wasserstein-2 error: 76.07
Test_accuracy: 0.6579, top 90% accuracy: 0.6783

Data shape: (2991, 60)
GT energy bias: 8.853e-04, test acc: 0.8263, test acc top 90%: 0.8275
NUTS:


warmup: 100%|██████████| 128/128 [00:51<00:00,  2.50it/s]
sample: 100%|██████████| 256/256 [00:32<00:00,  7.77it/s]



ESS per sample: 1.399, grad evals per sample: 150.0, GEPS/ESS: 107.2
Energy dist vs ground truth: 0.000778, Wasserstein-2 error: 0.8642
Test_accuracy: 0.8264, top 90% accuracy: 0.8275
LMC:
Target time-interval between samples for LMC: 0.474
QUICSORT:


100.00%|██████████| [00:19<00:00,  5.01%/s]
100.00%|██████████| [01:27<00:00,  1.15%/s]



ESS per sample: 1.092, grad evals per sample: 120.8, GEPS/ESS: 110.6
Energy dist vs ground truth: 0.0008142, Wasserstein-2 error: 0.8711
Test_accuracy: 0.8264, top 90% accuracy: 0.8276
Euler:


100.00%|██████████| [00:36<00:00,  2.74%/s]
100.00%|██████████| [02:37<00:00,  1.58s/%]



ESS per sample: 0.01488, grad evals per sample: 119.8, GEPS/ESS: 8.051e+03
Energy dist vs ground truth: 142.4, Wasserstein-2 error: 1.049e+04
Test_accuracy: 0.5175, top 90% accuracy: 0.5252

Data shape: (215, 5)
GT energy bias: 2.076e-04, test acc: 0.8006, test acc top 90%: 0.8077
NUTS:


warmup: 100%|██████████| 128/128 [00:04<00:00, 31.40it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 83.97it/s] 



ESS per sample: 0.9994, grad evals per sample: 12.4, GEPS/ESS: 12.41
Energy dist vs ground truth: 0.000176, Wasserstein-2 error: 0.08486
Test_accuracy: 0.7996, top 90% accuracy: 0.8065
LMC:
Target time-interval between samples for LMC: 0.1568
QUICSORT:


100.00%|██████████| [00:01<00:00, 97.77%/s] 
100.00%|██████████| [00:04<00:00, 22.86%/s]



ESS per sample: 0.2626, grad evals per sample: 10.05, GEPS/ESS: 38.29
Energy dist vs ground truth: 0.0005494, Wasserstein-2 error: 0.1002
Test_accuracy: 0.7997, top 90% accuracy: 0.8062
Euler:


100.00%|██████████| [00:02<00:00, 49.80%/s]
100.00%|██████████| [00:08<00:00, 11.41%/s]



ESS per sample: 0.2862, grad evals per sample: 10.05, GEPS/ESS: 35.13
Energy dist vs ground truth: 0.449, Wasserstein-2 error: 1.815
Test_accuracy: 0.7817, top 90% accuracy: 0.7898

Data shape: (24, 3)
GT energy bias: 6.719e-04, test acc: 0.5094, test acc top 90%: 0.528
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 40.98it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 91.29it/s] 



ESS per sample: 1.017, grad evals per sample: 11.4, GEPS/ESS: 11.21
Energy dist vs ground truth: 0.0002036, Wasserstein-2 error: 0.02694
Test_accuracy: 0.516, top 90% accuracy: 0.539
LMC:
Target time-interval between samples for LMC: 0.1441
QUICSORT:


100.00%|██████████| [00:00<00:00, 108.78%/s]
100.00%|██████████| [00:04<00:00, 23.92%/s]



ESS per sample: 0.2326, grad evals per sample: 9.898, GEPS/ESS: 42.56
Energy dist vs ground truth: 0.0003779, Wasserstein-2 error: 0.03114
Test_accuracy: 0.5111, top 90% accuracy: 0.5332
Euler:


100.00%|██████████| [00:01<00:00, 56.89%/s]
100.00%|██████████| [00:08<00:00, 12.06%/s]



ESS per sample: 0.1305, grad evals per sample: 9.898, GEPS/ESS: 75.83
Energy dist vs ground truth: 6.228e+10, Wasserstein-2 error: 6.405e+23
Test_accuracy: 0.4998, top 90% accuracy: 0.5108

Data shape: (7400, 20)
GT energy bias: 4.978e-04, test acc: 0.9721, test acc top 90%: 0.9725
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 32.59it/s]
sample: 100%|██████████| 256/256 [00:04<00:00, 52.88it/s] 



ESS per sample: 1.453, grad evals per sample: 12.09, GEPS/ESS: 8.317
Energy dist vs ground truth: 0.0006075, Wasserstein-2 error: 0.6886
Test_accuracy: 0.9719, top 90% accuracy: 0.9724
LMC:
Target time-interval between samples for LMC: 0.1528
QUICSORT:


100.00%|██████████| [00:01<00:00, 93.23%/s]
100.00%|██████████| [00:04<00:00, 21.08%/s]



ESS per sample: 0.4041, grad evals per sample: 10.01, GEPS/ESS: 24.76
Energy dist vs ground truth: 0.0006843, Wasserstein-2 error: 0.7096
Test_accuracy: 0.9721, top 90% accuracy: 0.9725
Euler:


100.00%|██████████| [00:02<00:00, 49.24%/s]
100.00%|██████████| [00:09<00:00, 10.96%/s]



ESS per sample: 0.3889, grad evals per sample: 10.0, GEPS/ESS: 25.72
Energy dist vs ground truth: 0.1902, Wasserstein-2 error: 2.116
Test_accuracy: 0.9674, top 90% accuracy: 0.9686

Data shape: (5000, 21)
GT energy bias: 5.078e-04, test acc: 0.8748, test acc top 90%: 0.8756
NUTS:


warmup: 100%|██████████| 128/128 [00:06<00:00, 20.59it/s]
sample: 100%|██████████| 256/256 [00:08<00:00, 31.14it/s]



ESS per sample: 1.355, grad evals per sample: 32.64, GEPS/ESS: 24.1
Energy dist vs ground truth: 0.0006622, Wasserstein-2 error: 0.4236
Test_accuracy: 0.8747, top 90% accuracy: 0.8755
LMC:
Target time-interval between samples for LMC: 0.4127
QUICSORT:


100.00%|██████████| [00:03<00:00, 32.86%/s]
100.00%|██████████| [00:13<00:00,  7.23%/s]



ESS per sample: 2.119, grad evals per sample: 27.42, GEPS/ESS: 12.94
Energy dist vs ground truth: 0.0004867, Wasserstein-2 error: 0.4145
Test_accuracy: 0.8751, top 90% accuracy: 0.876
Euler:


100.00%|██████████| [00:05<00:00, 17.23%/s]
100.00%|██████████| [00:25<00:00,  3.90%/s]



ESS per sample: 1.138, grad evals per sample: 26.41, GEPS/ESS: 23.21
Energy dist vs ground truth: 39.65, Wasserstein-2 error: 836.1
Test_accuracy: 0.777, top 90% accuracy: 0.7921

