In [2]:
import datetime
import sys
import warnings

import jax
import jax.numpy as jnp
from get_model import get_model_and_data  # noqa: F401
from main import run_simple_lmc_numpyro  # noqa: F401
from numpyro.infer import MCMC, NUTS, Predictive  # noqa: F401
from progressive_full_run import run_progressive_logreg


%env JAX_PLATFORM_NAME=cuda
warnings.simplefilter("ignore", FutureWarning)

jnp.set_printoptions(precision=3, suppress=True, threshold=sys.maxsize)
jax.config.update("jax_enable_x64", True)
print(jax.devices("cuda"))

names = [
    "banana",
    "breast_cancer",
    "diabetis",
    "flare_solar",
    "german",
    "heart",
    "image",
    "ringnorm",
    "splice",
    "thyroid",
    "titanic",
    "twonorm",
    "waveform",
]

env: JAX_PLATFORM_NAME=cuda
[CudaDevice(id=0)]


In [3]:
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_filename = f"progressive_results/log_{timestamp}.txt"
with open(log_filename, "w") as f:
    f.write(f"Results for time {timestamp}\n\n")

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
for name in names:
    print(f"==================== {name} ====================")
    run_progressive_logreg(name, log_filename, timestamp)
    print()

Data shape: (5300, 2)
GT energy bias: 1.374e-05, test acc: 0.5468, test acc top 90%: 0.5547


100.00%|██████████| [00:18<00:00,  5.47%/s]


QUICSORT acc: 0.5517, acc top 90%: 0.557, energy: 5.419e-05, w2: 3.961e-04


100.00%|██████████| [00:31<00:00,  3.19%/s]


Euler acc: 0.5093, acc top 90%: 0.5209, energy: 5.655e-01, w2: 1.368e+00


warmup: 100%|██████████| 20/20 [00:39<00:00,  1.96s/it]
sample: 100%|██████████| 44/44 [00:56<00:00,  1.28s/it]


NUTS acc: 0.5509, acc top 90%: 0.5568, energy: 7.520e-06, w2: 3.826e-04

Data shape: (263, 9)
GT energy bias: 2.373e-04, test acc: 0.6599, test acc top 90%: 0.6655


100.00%|██████████| [00:17<00:00,  5.82%/s]


QUICSORT acc: 0.6651, acc top 90%: 0.6713, energy: 2.021e-04, w2: 1.483e-01


100.00%|██████████| [00:39<00:00,  2.53%/s]


Euler acc: 0.6566, acc top 90%: 0.665, energy: 2.755e-01, w2: 1.169e+00


warmup: 100%|██████████| 20/20 [00:14<00:00,  1.34it/s]
sample: 100%|██████████| 44/44 [00:19<00:00,  2.25it/s]


NUTS acc: 0.6641, acc top 90%: 0.6726, energy: 2.228e-04, w2: 1.490e-01

Data shape: (768, 8)
GT energy bias: 1.654e-04, test acc: 0.7777, test acc top 90%: 0.7802


100.00%|██████████| [00:22<00:00,  4.35%/s]


QUICSORT acc: 0.7786, acc top 90%: 0.7812, energy: 1.394e-04, w2: 4.334e-02


100.00%|██████████| [00:45<00:00,  2.18%/s]


Euler acc: 0.7074, acc top 90%: 0.7198, energy: 9.727e-01, w2: 4.079e+00


warmup: 100%|██████████| 20/20 [00:34<00:00,  1.75s/it]
sample: 100%|██████████| 44/44 [00:54<00:00,  1.24s/it]


NUTS acc: 0.7788, acc top 90%: 0.7814, energy: 1.338e-04, w2: 4.328e-02

Data shape: (144, 9)
GT energy bias: 1.909e-03, test acc: 0.6165, test acc top 90%: 0.6244


100.00%|██████████| [00:14<00:00,  6.88%/s]


QUICSORT acc: 0.6295, acc top 90%: 0.6385, energy: 6.967e-02, w2: 2.117e+00


100.00%|██████████| [00:36<00:00,  2.76%/s]


Euler acc: 0.5914, acc top 90%: 0.5968, energy: 4.254e+00, w2: 9.623e+01


warmup: 100%|██████████| 20/20 [01:10<00:00,  3.51s/it]
sample: 100%|██████████| 44/44 [02:55<00:00,  3.98s/it]


NUTS acc: 0.6191, acc top 90%: 0.6269, energy: 1.461e-02, w2: 3.028e+00

Data shape: (1000, 20)
GT energy bias: 3.364e-04, test acc: 0.7838, test acc top 90%: 0.7875


100.00%|██████████| [00:47<00:00,  2.13%/s]


QUICSORT acc: 0.7848, acc top 90%: 0.7884, energy: 3.173e-04, w2: 1.614e-01


100.00%|██████████| [01:31<00:00,  1.09%/s]


Euler acc: 0.7162, acc top 90%: 0.724, energy: 1.331e+00, w2: 7.621e+00


warmup: 100%|██████████| 20/20 [01:06<00:00,  3.31s/it]
sample: 100%|██████████| 44/44 [02:20<00:00,  3.18s/it]


NUTS acc: 0.7842, acc top 90%: 0.7878, energy: 2.858e-04, w2: 1.610e-01

Data shape: (270, 13)
GT energy bias: 4.896e-04, test acc: 0.8003, test acc top 90%: 0.8086


100.00%|██████████| [00:22<00:00,  4.46%/s]


QUICSORT acc: 0.8039, acc top 90%: 0.8118, energy: 3.685e-04, w2: 3.799e-01


100.00%|██████████| [00:51<00:00,  1.92%/s]


Euler acc: 0.7983, acc top 90%: 0.8073, energy: 9.073e-02, w2: 9.021e-01


warmup: 100%|██████████| 20/20 [00:18<00:00,  1.10it/s]
sample: 100%|██████████| 44/44 [00:22<00:00,  1.94it/s]


NUTS acc: 0.8038, acc top 90%: 0.812, energy: 3.556e-04, w2: 3.772e-01

Data shape: (2086, 18)
GT energy bias: 1.013e-03, test acc: 0.8213, test acc top 90%: 0.8227


100.00%|██████████| [00:47<00:00,  2.12%/s]


QUICSORT acc: 0.8224, acc top 90%: 0.824, energy: 6.639e-04, w2: 7.524e-01


100.00%|██████████| [01:28<00:00,  1.13%/s]


Euler acc: 0.6833, acc top 90%: 0.6998, energy: 2.523e+00, w2: 1.213e+01


2024-09-29 12:15:30.778025: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 4.18GiB (4485233875 bytes) by rematerialization; only reduced to 4.39GiB (4718592000 bytes), down from 4.39GiB (4718592000 bytes) originally
2024-09-29 12:15:31.137893: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 3.95GiB (4240816537 bytes) by rematerialization; only reduced to 4.39GiB (4718592000 bytes), down from 4.39GiB (4718592000 bytes) originally
warmup: 100%|██████████| 20/20 [06:18<00:00, 18.90s/it]
sample: 100%|██████████| 44/44 [07:03<00:00,  9.63s/it]
2024-09-29 12:29:03.559442: W external/xla/xla/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 6.75GiB (rounded to 7249603328)requested by op 
2024-09-29 12:29:03.559877: W external/xla/xla/tsl/framework/bfc_allocator.cc:494] ************************____________________________________________________________________________


ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 7249603088 bytes.