In [1]:
import datetime
import glob
import os
import sys
import warnings

import jax
import jax.numpy as jnp
from get_model import get_model_and_data  # noqa: F401
from main import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401
from progressive_full_run import run_progressive_logreg


# Select CUDA as JAX's default platform via the config API. JAX picks up its
# JAX_* environment options when it is imported, so setting JAX_PLATFORM_NAME
# after `import jax` (e.g. with `%env`) is silently ignored. This update must
# run before anything initialises a backend, i.e. before `jax.devices` below.
jax.config.update("jax_platform_name", "cuda")
# Use 64-bit floats; like the platform choice, this belongs at startup.
jax.config.update("jax_enable_x64", True)

warnings.simplefilter("ignore", FutureWarning)

# Print arrays with 3 decimals in fixed-point notation and never elide entries.
jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
print(jax.devices("cuda"))

# Benchmark datasets processed (in this order) by the sweep cell.
names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [2]:
# Directory holding the saved per-dataset result dicts and the run logs.
results_dir = "progressive_results"
os.makedirs(results_dir, exist_ok=True)

# One timestamped log file shared by every dataset in this sweep; its path and
# the timestamp are handed to run_progressive_logreg for each dataset.
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"{results_dir}/log_{timestamp}.txt"
with open(log_filename, "w") as f:
    f.write(f"Results for time {timestamp}\n\n")

# The loop must sit INSIDE the context manager: previously it was dedented, so
# the "ignore" filter was installed and immediately reverted before the sweep.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    for name in names:
        pattern = f"{results_dir}/result_dict_{name}_*.pkl"
        candidates = glob.glob(pattern)
        if not candidates:
            # Fail with a clear message instead of an IndexError on [-1].
            raise FileNotFoundError(
                f"No saved result dict matching {pattern!r} for dataset {name!r}"
            )
        # Use the most recently modified result dict for this dataset.
        latest_dict = max(candidates, key=os.path.getmtime)
        print(f"==================== {name} ====================")
        run_progressive_logreg(
            name,
            log_filename,
            timestamp,
            nuts_dict_filename=latest_dict,
        )
        print()

Data shape: (5300, 2)
GT energy bias: 2.455e-05, test acc: 0.5498, test acc top 90%: 0.556


100.00%|██████████| [00:04<00:00, 24.18%/s]


QUICSORT acc: 0.5509, acc top 90%: 0.5575, energy: 2.590e-05, w2: 3.747e-04


100.00%|██████████| [00:05<00:00, 19.18%/s]


Euler acc: 0.5444, acc top 90%: 0.5541, energy: 7.631e-03, w2: 1.584e-02
NUTS acc: 0.5509, acc top 90%: 0.5568, energy: 7.520e-06, w2: 3.826e-04

Data shape: (263, 9)
GT energy bias: 9.460e-05, test acc: 0.6627, test acc top 90%: 0.6677


100.00%|██████████| [00:01<00:00, 51.61%/s]


QUICSORT acc: 0.666, acc top 90%: 0.6718, energy: 1.156e-04, w2: 1.355e-01


100.00%|██████████| [00:03<00:00, 29.10%/s]


Euler acc: 0.675, acc top 90%: 0.6818, energy: 2.128e-01, w2: 4.192e-01
NUTS acc: 0.6641, acc top 90%: 0.6726, energy: 2.228e-04, w2: 1.490e-01

Data shape: (768, 8)
GT energy bias: 4.407e-05, test acc: 0.7772, test acc top 90%: 0.7798


100.00%|██████████| [00:02<00:00, 42.04%/s]


QUICSORT acc: 0.7809, acc top 90%: 0.7836, energy: 1.135e-04, w2: 3.959e-02


100.00%|██████████| [00:03<00:00, 26.21%/s]


Euler acc: 0.7761, acc top 90%: 0.7791, energy: 4.000e-01, w2: 4.312e-01
NUTS acc: 0.7788, acc top 90%: 0.7814, energy: 1.338e-04, w2: 4.328e-02

Data shape: (144, 9)
GT energy bias: 8.068e-04, test acc: 0.6169, test acc top 90%: 0.6247


100.00%|██████████| [00:09<00:00, 10.16%/s]


QUICSORT acc: 0.6297, acc top 90%: 0.6389, energy: 6.087e-03, w2: 1.790e+00


100.00%|██████████| [00:19<00:00,  5.13%/s]


Euler acc: 0.6276, acc top 90%: 0.6363, energy: 2.851e+00, w2: 3.977e+01
NUTS acc: 0.6191, acc top 90%: 0.6269, energy: 1.461e-02, w2: 3.028e+00

Data shape: (1000, 20)
GT energy bias: 1.110e-04, test acc: 0.7837, test acc top 90%: 0.7872


100.00%|██████████| [00:02<00:00, 38.62%/s]


QUICSORT acc: 0.7855, acc top 90%: 0.7891, energy: 3.034e-04, w2: 1.567e-01


100.00%|██████████| [00:03<00:00, 25.33%/s]


Euler acc: 0.7769, acc top 90%: 0.7815, energy: 2.661e-01, w2: 8.403e-01
NUTS acc: 0.7842, acc top 90%: 0.7878, energy: 2.858e-04, w2: 1.610e-01

Data shape: (270, 13)
GT energy bias: 1.212e-04, test acc: 0.8014, test acc top 90%: 0.8094


100.00%|██████████| [00:02<00:00, 41.81%/s]


QUICSORT acc: 0.8061, acc top 90%: 0.8141, energy: 1.553e-04, w2: 3.381e-01


100.00%|██████████| [00:04<00:00, 23.38%/s]


Euler acc: 0.8179, acc top 90%: 0.824, energy: 3.350e-01, w2: 9.836e-01
NUTS acc: 0.8038, acc top 90%: 0.812, energy: 3.556e-04, w2: 3.772e-01

Data shape: (2086, 18)
GT energy bias: 3.996e-04, test acc: 0.821, test acc top 90%: 0.8223


100.00%|██████████| [00:05<00:00, 17.81%/s]


QUICSORT acc: 0.8237, acc top 90%: 0.8252, energy: 4.670e-04, w2: 7.302e-01


100.00%|██████████| [00:08<00:00, 11.50%/s]


Euler acc: 0.7975, acc top 90%: 0.8036, energy: 7.774e-01, w2: 4.557e+00
NUTS acc: 0.8215, acc top 90%: 0.8228, energy: 5.669e-04, w2: 7.305e-01

Data shape: (7400, 20)
GT energy bias: 1.145e-04, test acc: 0.7575, test acc top 90%: 0.7586


100.00%|██████████| [00:02<00:00, 35.99%/s]


QUICSORT acc: 0.7587, acc top 90%: 0.7598, energy: 5.104e-04, w2: 1.122e-01


100.00%|██████████| [00:04<00:00, 24.26%/s]


Euler acc: 0.7646, acc top 90%: 0.7663, energy: 2.865e-01, w2: 4.334e-01
NUTS acc: 0.7578, acc top 90%: 0.7589, energy: 2.022e-04, w2: 1.160e-01

Data shape: (2991, 60)
GT energy bias: 3.211e-04, test acc: 0.8264, test acc top 90%: 0.8275


100.00%|██████████| [02:01<00:00,  1.22s/%]


QUICSORT acc: 0.8269, acc top 90%: 0.8281, energy: 3.790e-04, w2: 8.340e-01


100.00%|██████████| [03:04<00:00,  1.85s/%]


Euler acc: 0.7956, acc top 90%: 0.8014, energy: 2.480e+00, w2: 1.231e+01
NUTS acc: 0.8263, acc top 90%: 0.8275, energy: 2.341e-02, w2: 9.465e-01

Data shape: (215, 5)
GT energy bias: 6.158e-05, test acc: 0.7996, test acc top 90%: 0.8062


100.00%|██████████| [00:01<00:00, 64.44%/s]


QUICSORT acc: 0.8256, acc top 90%: 0.8315, energy: 1.128e-04, w2: 8.251e-02


100.00%|██████████| [00:02<00:00, 35.89%/s]


Euler acc: 0.7541, acc top 90%: 0.7609, energy: 1.275e+00, w2: 2.079e+00
NUTS acc: 0.8024, acc top 90%: 0.809, energy: 1.317e-04, w2: 8.280e-02

Data shape: (24, 3)
GT energy bias: 8.428e-05, test acc: 0.5117, test acc top 90%: 0.5328


100.00%|██████████| [00:01<00:00, 63.53%/s]


QUICSORT acc: 0.5217, acc top 90%: 0.5403, energy: 6.309e-05, w2: 2.562e-02


100.00%|██████████| [00:02<00:00, 35.73%/s]


Euler acc: 0.5293, acc top 90%: 0.5503, energy: 1.218e-02, w2: 5.192e-02
NUTS acc: 0.5195, acc top 90%: 0.539, energy: 1.545e-04, w2: 2.459e-02

Data shape: (7400, 20)
GT energy bias: 2.240e-04, test acc: 0.9719, test acc top 90%: 0.9724


100.00%|██████████| [00:02<00:00, 36.26%/s]


QUICSORT acc: 0.9721, acc top 90%: 0.9726, energy: 3.468e-04, w2: 6.940e-01


100.00%|██████████| [00:04<00:00, 24.40%/s]


Euler acc: 0.9769, acc top 90%: 0.9771, energy: 2.889e-01, w2: 1.381e+00
NUTS acc: 0.9721, acc top 90%: 0.9725, energy: 4.373e-04, w2: 6.820e-01

Data shape: (5000, 21)
GT energy bias: 2.431e-04, test acc: 0.8749, test acc top 90%: 0.8758


100.00%|██████████| [00:06<00:00, 16.26%/s]


QUICSORT acc: 0.8756, acc top 90%: 0.8765, energy: 3.020e-04, w2: 4.065e-01


100.00%|██████████| [00:09<00:00, 10.49%/s]


Euler acc: 0.8601, acc top 90%: 0.8654, energy: 6.933e-01, w2: 3.703e+00
NUTS acc: 0.8751, acc top 90%: 0.8759, energy: 4.402e-04, w2: 4.112e-01

