In [1]:
# Cell 1: environment, imports and global configuration for the benchmark run.
#
# JAX reads its flag defaults (platform selection, x64, ...) from environment
# variables when `jax` is first imported, so the platform must be chosen BEFORE
# `import jax`. The original `%env JAX_PLATFORM_NAME=cuda` ran after the import
# and therefore did not influence JAX. `JAX_PLATFORMS` is the current name of
# this setting (`JAX_PLATFORM_NAME` is the deprecated one).
# NOTE: if the kernel already imported jax earlier, restart it for this to apply.
import datetime
import os
import warnings

os.environ["JAX_PLATFORMS"] = "cuda"

# Silence FutureWarnings, including any emitted while importing the libraries below.
warnings.simplefilter("ignore", FutureWarning)

import jax
import jax.numpy as jnp

# Enable 64-bit floats before importing the project module, so that any arrays
# it might create at import time are float64 too.
jax.config.update("jax_enable_x64", True)
jnp.set_printoptions(precision=4, suppress=True)

from logreg_full_run import run_logreg_dataset

# Fails loudly if no CUDA device is visible, rather than silently running on CPU.
print(jax.devices("cuda"))

# Datasets to benchmark; each name is passed verbatim to `run_logreg_dataset`.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# Cell 2: run every dataset in `names` and collect metrics into one
# timestamped results file.
import os

RESULTS_DIR = "mcmc_data"

timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
results_filename = f"{RESULTS_DIR}/results_{timestamp}.txt"

# Nothing in this notebook creates the results directory; without this, the
# open(..., "w") below raises FileNotFoundError on a fresh checkout.
os.makedirs(RESULTS_DIR, exist_ok=True)

# Create (or truncate) the results file with a header.
with open(results_filename, "w", encoding="utf-8") as f:
    f.write("Results\n\n")

for name in names:
    print(f"==================== {name} ====================")
    # Presumably run_logreg_dataset appends its metrics to `results_filename`
    # and uses `timestamp` to tag saved artefacts. TODO confirm in logreg_full_run.
    run_logreg_dataset(name, results_filename, timestamp)
    print()

Data shape: (5300, 2)
GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 44.45it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 87.35it/s] 



ESS per sample: 0.9635, grad evals per sample: 7.626, GEPS/ESS: 7.915
Energy dist vs ground truth: 7.912e-06, Wasserstein-2 error: 0.0004198
Test_accuracy: 0.5492, top 90% accuracy: 0.5558
LMC:
Target time-interval between samples for LMC: 0.09641
QUICSORT:


100.00%|██████████| [00:00<00:00, 170.92%/s]
100.00%|██████████| [00:03<00:00, 32.38%/s]



ESS per sample: 0.9302, grad evals per sample: 7.32, GEPS/ESS: 7.87
Energy dist vs ground truth: 3.462e-05, Wasserstein-2 error: 0.0004376
Test_accuracy: 0.5506, top 90% accuracy: 0.5563
Euler:


100.00%|██████████| [00:02<00:00, 34.15%/s]
100.00%|██████████| [00:13<00:00,  7.52%/s]



ESS per sample: 0.5159, grad evals per sample: 16.28, GEPS/ESS: 31.55
Energy dist vs ground truth: 2.396, Wasserstein-2 error: 12.85
Test_accuracy: 0.5072, top 90% accuracy: 0.5183

Data shape: (263, 9)
GT energy bias: 9.460e-05, test acc: 0.6627, test acc top 90%: 0.6677
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 47.68it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 84.90it/s] 



ESS per sample: 1.192, grad evals per sample: 12.11, GEPS/ESS: 10.17
Energy dist vs ground truth: 9.42e-05, Wasserstein-2 error: 0.1419
Test_accuracy: 0.6586, top 90% accuracy: 0.6643
LMC:
Target time-interval between samples for LMC: 0.1531
QUICSORT:


100.00%|██████████| [00:00<00:00, 106.15%/s]
100.00%|██████████| [00:03<00:00, 25.22%/s]



ESS per sample: 0.4016, grad evals per sample: 10.02, GEPS/ESS: 24.94
Energy dist vs ground truth: 6.859e-05, Wasserstein-2 error: 0.1393
Test_accuracy: 0.6603, top 90% accuracy: 0.6656
Euler:


100.00%|██████████| [00:04<00:00, 21.32%/s]
100.00%|██████████| [00:20<00:00,  4.78%/s]



ESS per sample: 0.4237, grad evals per sample: 25.02, GEPS/ESS: 59.06
Energy dist vs ground truth: 0.007589, Wasserstein-2 error: 0.1935
Test_accuracy: 0.66, top 90% accuracy: 0.6658

Data shape: (768, 8)
GT energy bias: 4.407e-05, test acc: 0.7772, test acc top 90%: 0.7798
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 44.90it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 76.83it/s] 



ESS per sample: 1.155, grad evals per sample: 11.93, GEPS/ESS: 10.32
Energy dist vs ground truth: 3.515e-05, Wasserstein-2 error: 0.04258
Test_accuracy: 0.7778, top 90% accuracy: 0.7803
LMC:
Target time-interval between samples for LMC: 0.1508
QUICSORT:


100.00%|██████████| [00:00<00:00, 107.12%/s]
100.00%|██████████| [00:04<00:00, 24.87%/s]



ESS per sample: 0.618, grad evals per sample: 9.984, GEPS/ESS: 16.15
Energy dist vs ground truth: 4.192e-05, Wasserstein-2 error: 0.04122
Test_accuracy: 0.7775, top 90% accuracy: 0.7801
Euler:


100.00%|██████████| [00:04<00:00, 21.36%/s]
100.00%|██████████| [00:20<00:00,  4.93%/s]



ESS per sample: 0.7565, grad evals per sample: 23.94, GEPS/ESS: 31.64
Energy dist vs ground truth: 1.119, Wasserstein-2 error: 3.511
Test_accuracy: 0.718, top 90% accuracy: 0.7252

Data shape: (144, 9)
GT energy bias: 8.068e-04, test acc: 0.6169, test acc top 90%: 0.6247
NUTS:


warmup: 100%|██████████| 128/128 [00:09<00:00, 13.79it/s]
sample: 100%|██████████| 256/256 [00:04<00:00, 51.27it/s] 



ESS per sample: 0.8171, grad evals per sample: 97.47, GEPS/ESS: 119.3
Energy dist vs ground truth: 0.0008303, Wasserstein-2 error: 2.106
Test_accuracy: 0.6151, top 90% accuracy: 0.6226
LMC:
Target time-interval between samples for LMC: 1.232
QUICSORT:


100.00%|██████████| [00:07<00:00, 12.94%/s]
100.00%|██████████| [00:33<00:00,  3.01%/s]



ESS per sample: 0.1538, grad evals per sample: 78.13, GEPS/ESS: 508.1
Energy dist vs ground truth: 0.0009825, Wasserstein-2 error: 5.009
Test_accuracy: 0.6169, top 90% accuracy: 0.6246
Euler:


100.00%|██████████| [00:40<00:00,  2.46%/s]
100.00%|██████████| [02:56<00:00,  1.76s/%]



ESS per sample: 0.139, grad evals per sample: 195.3, GEPS/ESS: 1.405e+03
Energy dist vs ground truth: 8.793, Wasserstein-2 error: 82.27
Test_accuracy: 0.6052, top 90% accuracy: 0.6119

Data shape: (1000, 20)
GT energy bias: 1.110e-04, test acc: 0.7837, test acc top 90%: 0.7872
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 36.22it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 75.02it/s] 



ESS per sample: 1.255, grad evals per sample: 15.79, GEPS/ESS: 12.58
Energy dist vs ground truth: 9.884e-05, Wasserstein-2 error: 0.1668
Test_accuracy: 0.7834, top 90% accuracy: 0.787
LMC:
Target time-interval between samples for LMC: 0.1996
QUICSORT:


100.00%|██████████| [00:01<00:00, 75.27%/s]
100.00%|██████████| [00:05<00:00, 18.14%/s]



ESS per sample: 1.732, grad evals per sample: 12.62, GEPS/ESS: 7.283
Energy dist vs ground truth: 0.0001003, Wasserstein-2 error: 0.1626
Test_accuracy: 0.7841, top 90% accuracy: 0.7875
Euler:


100.00%|██████████| [00:06<00:00, 16.29%/s]
100.00%|██████████| [00:26<00:00,  3.76%/s]



ESS per sample: 1.019, grad evals per sample: 31.53, GEPS/ESS: 30.94
Energy dist vs ground truth: 3.048, Wasserstein-2 error: 15.65
Test_accuracy: 0.7045, top 90% accuracy: 0.7134

Data shape: (270, 13)
GT energy bias: 1.212e-04, test acc: 0.8014, test acc top 90%: 0.8094
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 50.70it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 92.91it/s] 



ESS per sample: 1.241, grad evals per sample: 12.61, GEPS/ESS: 10.16
Energy dist vs ground truth: 7.32e-05, Wasserstein-2 error: 0.354
Test_accuracy: 0.8015, top 90% accuracy: 0.8096
LMC:
Target time-interval between samples for LMC: 0.1595
QUICSORT:


100.00%|██████████| [00:00<00:00, 101.95%/s]
100.00%|██████████| [00:04<00:00, 23.26%/s]



ESS per sample: 0.3837, grad evals per sample: 10.09, GEPS/ESS: 26.3
Energy dist vs ground truth: 9.588e-05, Wasserstein-2 error: 0.3542
Test_accuracy: 0.7993, top 90% accuracy: 0.8073
Euler:


100.00%|██████████| [00:04<00:00, 20.59%/s]
100.00%|██████████| [00:21<00:00,  4.74%/s]



ESS per sample: 0.3962, grad evals per sample: 25.22, GEPS/ESS: 63.65
Energy dist vs ground truth: 0.005055, Wasserstein-2 error: 0.4197
Test_accuracy: 0.7973, top 90% accuracy: 0.8055

Data shape: (2086, 18)
GT energy bias: 3.996e-04, test acc: 0.821, test acc top 90%: 0.8223
NUTS:


warmup: 100%|██████████| 128/128 [00:07<00:00, 17.18it/s]
sample: 100%|██████████| 256/256 [00:08<00:00, 29.87it/s]



ESS per sample: 1.063, grad evals per sample: 84.1, GEPS/ESS: 79.11
Energy dist vs ground truth: 0.0004884, Wasserstein-2 error: 0.7849
Test_accuracy: 0.8207, top 90% accuracy: 0.822
LMC:
Target time-interval between samples for LMC: 1.063
QUICSORT:


100.00%|██████████| [00:06<00:00, 14.41%/s]
100.00%|██████████| [00:30<00:00,  3.31%/s]



ESS per sample: 0.8605, grad evals per sample: 67.93, GEPS/ESS: 78.94
Energy dist vs ground truth: 0.0003285, Wasserstein-2 error: 0.753
Test_accuracy: 0.8212, top 90% accuracy: 0.8225
Euler:


100.00%|██████████| [00:35<00:00,  2.79%/s]
100.00%|██████████| [02:34<00:00,  1.54s/%]



ESS per sample: 0.2963, grad evals per sample: 167.8, GEPS/ESS: 566.3
Energy dist vs ground truth: 13.09, Wasserstein-2 error: 143.8
Test_accuracy: 0.6906, top 90% accuracy: 0.7003

Data shape: (7400, 20)
GT energy bias: 1.145e-04, test acc: 0.7575, test acc top 90%: 0.7586
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 40.75it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 68.90it/s] 



ESS per sample: 1.744, grad evals per sample: 11.77, GEPS/ESS: 6.748
Energy dist vs ground truth: 7.695e-05, Wasserstein-2 error: 0.1219
Test_accuracy: 0.7572, top 90% accuracy: 0.7584
LMC:
Target time-interval between samples for LMC: 0.1488
QUICSORT:


100.00%|██████████| [00:00<00:00, 102.98%/s]
100.00%|██████████| [00:04<00:00, 23.13%/s]



ESS per sample: 1.53, grad evals per sample: 9.961, GEPS/ESS: 6.512
Energy dist vs ground truth: 0.0001049, Wasserstein-2 error: 0.121
Test_accuracy: 0.7574, top 90% accuracy: 0.7585
Euler:


100.00%|██████████| [00:04<00:00, 21.55%/s]
100.00%|██████████| [00:20<00:00,  4.92%/s]



ESS per sample: 0.8361, grad evals per sample: 23.88, GEPS/ESS: 28.55
Energy dist vs ground truth: 1.436, Wasserstein-2 error: 6.195
Test_accuracy: 0.683, top 90% accuracy: 0.7006

Data shape: (2991, 60)
GT energy bias: 3.211e-04, test acc: 0.8264, test acc top 90%: 0.8275
NUTS:


warmup: 100%|██████████| 128/128 [00:26<00:00,  4.86it/s]
sample: 100%|██████████| 256/256 [00:16<00:00, 15.10it/s]



ESS per sample: 1.406, grad evals per sample: 153.8, GEPS/ESS: 109.4
Energy dist vs ground truth: 0.0002113, Wasserstein-2 error: 0.8448
Test_accuracy: 0.8264, top 90% accuracy: 0.8276
LMC:
Target time-interval between samples for LMC: 0.4862
QUICSORT:


100.00%|██████████| [00:16<00:00,  6.21%/s]
100.00%|██████████| [01:09<00:00,  1.44%/s]



ESS per sample: 0.9728, grad evals per sample: 123.5, GEPS/ESS: 126.9
Energy dist vs ground truth: 0.0004557, Wasserstein-2 error: 0.8592
Test_accuracy: 0.8265, top 90% accuracy: 0.8278
Euler:


100.00%|██████████| [01:18<00:00,  1.28%/s]
100.00%|██████████| [05:38<00:00,  3.38s/%]



ESS per sample: 0.9689, grad evals per sample: 307.6, GEPS/ESS: 317.5
Energy dist vs ground truth: 132.7, Wasserstein-2 error: 6.147e+03
Test_accuracy: 0.5446, top 90% accuracy: 0.5554

Data shape: (215, 5)
GT energy bias: 6.158e-05, test acc: 0.7996, test acc top 90%: 0.8062
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 32.62it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 95.90it/s] 



ESS per sample: 0.9857, grad evals per sample: 12.54, GEPS/ESS: 12.72
Energy dist vs ground truth: 0.0001008, Wasserstein-2 error: 0.08194
Test_accuracy: 0.8001, top 90% accuracy: 0.8067
LMC:
Target time-interval between samples for LMC: 0.1585
QUICSORT:


100.00%|██████████| [00:00<00:00, 102.60%/s]
100.00%|██████████| [00:04<00:00, 23.41%/s]



ESS per sample: 0.2659, grad evals per sample: 10.08, GEPS/ESS: 37.9
Energy dist vs ground truth: 0.0001657, Wasserstein-2 error: 0.09241
Test_accuracy: 0.7989, top 90% accuracy: 0.8057
Euler:


100.00%|██████████| [00:04<00:00, 20.83%/s]
100.00%|██████████| [00:20<00:00,  4.77%/s]



ESS per sample: 0.2721, grad evals per sample: 25.19, GEPS/ESS: 92.58
Energy dist vs ground truth: 0.002401, Wasserstein-2 error: 0.1175
Test_accuracy: 0.7981, top 90% accuracy: 0.8048

Data shape: (24, 3)
GT energy bias: 8.428e-05, test acc: 0.5117, test acc top 90%: 0.5328
NUTS:


warmup: 100%|██████████| 128/128 [00:02<00:00, 47.70it/s] 
sample: 100%|██████████| 256/256 [00:02<00:00, 96.33it/s] 



ESS per sample: 0.9901, grad evals per sample: 11.41, GEPS/ESS: 11.53
Energy dist vs ground truth: 0.0001159, Wasserstein-2 error: 0.02753
Test_accuracy: 0.5162, top 90% accuracy: 0.5349
LMC:
Target time-interval between samples for LMC: 0.1443
QUICSORT:


100.00%|██████████| [00:00<00:00, 111.91%/s]
100.00%|██████████| [00:03<00:00, 25.61%/s]



ESS per sample: 0.2331, grad evals per sample: 9.898, GEPS/ESS: 42.46
Energy dist vs ground truth: 8.259e-05, Wasserstein-2 error: 0.02848
Test_accuracy: 0.5203, top 90% accuracy: 0.5386
Euler:


100.00%|██████████| [00:04<00:00, 21.73%/s]
100.00%|██████████| [00:20<00:00,  4.79%/s]



ESS per sample: 0.2146, grad evals per sample: 23.73, GEPS/ESS: 110.6
Energy dist vs ground truth: 0.2855, Wasserstein-2 error: 10.04
Test_accuracy: 0.5246, top 90% accuracy: 0.5445

Data shape: (7400, 20)
GT energy bias: 2.240e-04, test acc: 0.9719, test acc top 90%: 0.9724
NUTS:


warmup: 100%|██████████| 128/128 [00:03<00:00, 41.91it/s] 
sample: 100%|██████████| 256/256 [00:03<00:00, 74.10it/s] 



ESS per sample: 1.441, grad evals per sample: 12.23, GEPS/ESS: 8.491
Energy dist vs ground truth: 0.0002312, Wasserstein-2 error: 0.717
Test_accuracy: 0.9719, top 90% accuracy: 0.9724
LMC:
Target time-interval between samples for LMC: 0.1546
QUICSORT:


100.00%|██████████| [00:01<00:00, 97.86%/s]
100.00%|██████████| [00:04<00:00, 23.06%/s]



ESS per sample: 0.4066, grad evals per sample: 10.03, GEPS/ESS: 24.67
Energy dist vs ground truth: 0.0002498, Wasserstein-2 error: 0.7042
Test_accuracy: 0.9719, top 90% accuracy: 0.9724
Euler:


100.00%|██████████| [00:04<00:00, 20.86%/s]
100.00%|██████████| [00:21<00:00,  4.69%/s]



ESS per sample: 0.409, grad evals per sample: 25.07, GEPS/ESS: 61.28
Energy dist vs ground truth: 0.00729, Wasserstein-2 error: 0.816
Test_accuracy: 0.9711, top 90% accuracy: 0.9716

Data shape: (5000, 21)
GT energy bias: 2.431e-04, test acc: 0.8749, test acc top 90%: 0.8758
NUTS:


warmup: 100%|██████████| 128/128 [00:04<00:00, 31.33it/s]
sample: 100%|██████████| 256/256 [00:06<00:00, 42.60it/s] 



ESS per sample: 1.361, grad evals per sample: 32.79, GEPS/ESS: 24.09
Energy dist vs ground truth: 0.0001958, Wasserstein-2 error: 0.4343
Test_accuracy: 0.8748, top 90% accuracy: 0.8757
LMC:
Target time-interval between samples for LMC: 0.4146
QUICSORT:


100.00%|██████████| [00:02<00:00, 36.22%/s]
100.00%|██████████| [00:12<00:00,  7.93%/s]



ESS per sample: 2.113, grad evals per sample: 27.45, GEPS/ESS: 12.99
Energy dist vs ground truth: 0.0001771, Wasserstein-2 error: 0.4152
Test_accuracy: 0.8751, top 90% accuracy: 0.876
Euler:


100.00%|██████████| [00:13<00:00,  7.60%/s]
100.00%|██████████| [00:56<00:00,  1.76%/s]



ESS per sample: 1.72, grad evals per sample: 65.57, GEPS/ESS: 38.11
Energy dist vs ground truth: 15.87, Wasserstein-2 error: 132.6
Test_accuracy: 0.8044, top 90% accuracy: 0.8175

