Tuning inside Adagrid is a scary thing to do. This document is a summary of the various problems I've run into. 

First, some basics. We have three different groups of thresholds. $i$ is a tile index, $j$ is a bootstrap index.
1. The original sample, $\lambda^*_i$, and its grid-wise minimum $\lambda^{**}$.
2. $N_B$ global bootstraps $\lambda_{i, B_j}^*$ and their grid-wise minima $\lambda_{B_j}^{**}$. In the code, information about these bootstraps is prefixed with `B_`.
3. $N_b$ tile-wise investigation bootstraps $\lambda_{i, b_j}^*$ and their tile-wise minima $\lambda_{i}^{**}$. In the code, information about these bootstraps is prefixed with `twb_`, standing for "tile-wise bootstrap".

For each of these tuning problems, we tune at TIE level $\alpha_0 = \alpha - C_{\alpha}$ where $C_{\alpha}$ is the TIE consumed by continuous simulation extension. The C stands for "cost" and in the code this is called `alpha_cost`. 

The different problems I've run into so far:
- Impossible tuning. This occurs when $\alpha_0 < 2 / (K+1)$. In this situation we can't tune because there are too few test statistics. We need to either run more simulations (increase $K$) or refine (increase $\alpha_0$).
- It's possible to have a tile where `twb_min_lam` is large (e.g. 1) but `B_lam` is small (e.g. 0.015).
	- These tiles have too much variance, but there's no way to detect them because our tile-wise bootstrap didn't turn up any evidence of danger.
	- It's not possible to completely remove this possibility because there's always some randomness.
	- This partially suggests I'm using a baseline of too few simulations or too-large tiles. This is fixable: I bumped up the baseline $K$ to 4096. (Note: the `AdaParams` cell below currently uses `init_K=2**11`, i.e. 2048.)
	- Another option would be to use a new bootstrap in some way to get a new sample?
- Part of the problem is tiles for which $\alpha_0$ is very small, so the tuning result is something like index 2 of the batch, which of course results in high variance. The simple fix is to make $\alpha_0$ larger. Is there a smooth way to do this?

In [1]:
# Notebook display setup (see confirm.outlaw.nb_util); runs before the other imports.
import confirm.outlaw.nb_util as nb_util

nb_util.setup_nb(pretty=True)

# Standard library and numerical stack.
import time
import jax
import os
import re
import pickle
import numpy as np
import jax.numpy as jnp
import scipy.spatial
import matplotlib.pyplot as plt
# Project modules: grid construction, the Lewis design, binomial bounds, checkpoint helpers.
from confirm.mini_imprint import grid
from confirm.lewislib import grid as lewgrid
from confirm.lewislib import lewis, batch
from confirm.mini_imprint import binomial, checkpoint

import confirm.mini_imprint.lewis_drivers as lts

# rich's print; used below to render the criterion report in the driver loop.
from rich import print as rprint

# Lewis45 design configuration for the active run. `name` also selects the
# output directory used for the table cache and for checkpoints.
name = "4d_full"
params = dict(
    n_arms=4,
    # Sample sizes and interim schedule.
    n_stage_1=50,
    n_stage_2=100,
    n_stage_1_interims=2,
    n_stage_1_add_per_interim=100,
    n_stage_2_add_per_interim=100,
    # Decision thresholds.
    stage_1_futility_threshold=0.15,
    stage_1_efficacy_threshold=0.7,
    stage_2_futility_threshold=0.2,
    stage_2_efficacy_threshold=0.95,
    inter_stage_futility_threshold=0.6,
    posterior_difference_threshold=0,
    rejection_threshold=0.05,
    # Randomness and internal table/simulation settings.
    key=jax.random.PRNGKey(0),
    n_table_pts=20,
    n_pr_sims=100,
    n_sig2_sims=20,
    batch_size=2**12,
    cache_tables=f"./{name}/lei_cache.pkl",
)

# Alternative 3-arm configuration, identical to the above except for `n_arms`
# and `name`. To use it, uncomment these two lines:
# name = "3d_smaller2"
# params = {**params, "n_arms": 3, "cache_tables": f"./{name}/lei_cache.pkl"}

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
n_arms = params["n_arms"]

# Null hypotheses: one hyperplane through the origin per non-control arm k,
# with normal e_0 - e_k (row k-1 of [ones | -I]).
ns = np.hstack([np.ones((n_arms - 1, 1)), -np.eye(n_arms - 1)])
null_hypos = [grid.HyperPlane(normal, 0) for normal in ns]

# Symmetry planes between consecutive non-control arms: normal e_k - e_{k+1}
# for k = 1, ..., n_arms - 2.
eye = np.eye(n_arms)
symmetry = [grid.HyperPlane(eye[k] - eye[k + 1], 0) for k in range(1, n_arms - 1)]

# Initial cartesian grid on the box [theta_min, theta_max]^n_arms.
theta_min, theta_max = -1.0, 1.0
init_grid_size = 8
theta, radii = grid.cartesian_gridpts(
    np.full(n_arms, theta_min),
    np.full(n_arms, theta_max),
    np.full(n_arms, init_grid_size),
)

# `g_raw` is built from the raw points only; `g` additionally receives the null
# hypotheses and symmetry planes, with pruning enabled.
g_raw = grid.build_grid(theta, radii)
g = grid.build_grid(
    theta,
    radii,
    null_hypos=null_hypos,
    symmetry_planes=symmetry,
    should_prune=True,
)

In [3]:
import adastate
from criterion import Criterion

# Instantiate the Lewis45 design from the configuration dict.
lei_obj = lewis.Lewis45(**params)
# The leading entry of the design's uniform-draw shape; later passed as the
# sample count to the Holder bound routines.
unif_dims = lei_obj.unifs_shape()
n_arm_samples = int(unif_dims[0])

In [4]:
# Adagrid run configuration. Field semantics are defined by adastate.AdaParams;
# the inline notes are read off the field names -- TODO confirm against adastate.
P = adastate.AdaParams(
    init_K=2**11,  # presumably the initial simulation count per tile
    n_K_double=8,
    alpha_target=0.025,  # target type I error level
    grid_target=0.002,
    bias_target=0.002,
    nB_global=50,  # presumably N_B, the number of global bootstraps
    nB_tile=50,  # presumably N_b, the number of tile-wise bootstraps
    step_size=2**14,
    tuning_min_idx=20,
)

# Build the run's data and persist it together with the parameters.
D = adastate.init_data(P, lei_obj, 0)
adastate.save(f"./{name}/data_params.pkl", (P, D))

In [7]:
# Resume from the most recent checkpoint for this run if one exists;
# otherwise start from the initial grid.
load_iter = "latest"
S, load_iter, fn = adastate.load(name, load_iter)
if S is None:
    print("initializing")
    S = adastate.init_state(P, g)

# The driver loop stops as soon as no tile is flagged, so flag tile 0 to
# guarantee at least one step/criterion/refinement pass after loading.
S.todo[0] = True

loading checkpoint 4d_full/124.pkl


In [8]:
R = adastate.AdaRunner(P, lei_obj)
iter_max = 10000
# Seconds per unit of simulation effort, measured on the previous step.
# Unknown before the first step, so the first prediction prints "inf".
cost_per_sim = np.inf
for II in range(load_iter + 1, iter_max):
    n_todo = np.sum(S.todo)
    # No tile is flagged for simulation: nothing left to do.
    if n_todo == 0:
        break

    print(f"starting iteration {II} with {n_todo} tiles to process")
    total_effort = np.sum(S.sim_sizes[S.todo])
    predicted_time = total_effort * cost_per_sim
    print(f"runtime prediction: {predicted_time:.2f}")

    start = time.time()
    R.step(P, S, D)
    # Read the clock once so the printed step time and the cost estimate used
    # for the next runtime prediction are the same measurement.
    step_time = time.time() - start
    cost_per_sim = step_time / total_effort
    print(f"step took {step_time:.2f}s")

    start = time.time()
    adastate.save(f"{name}/{II}.pkl", S)
    # Remove older checkpoint files selected by checkpoint.exponential_delete.
    for old_i in checkpoint.exponential_delete(II, base=1):
        fp = f"{name}/{old_i}.pkl"
        if os.path.exists(fp):
            os.remove(fp)
    print(f"checkpointing took {time.time() - start:.2f}s")

    start = time.time()
    cr = Criterion(lei_obj, P, S, D)
    print(f'criterion took {time.time() - start:.2f}s')
    rprint(cr.report)

    start = time.time()
    # Skip refinement on the last allowed iteration: no later step would
    # simulate the new tiles.
    if (np.sum(cr.which_refine) > 0 or np.sum(cr.which_deepen) > 0) and II != iter_max - 1:
        # Deepen before refining: both masks were computed on the current
        # (pre-refinement) tiles.
        S.sim_sizes[cr.which_deepen] *= 2
        S.todo[cr.which_deepen] = True

        S = S.refine(P, cr.which_refine, null_hypos, symmetry)
        print(f"refinement took {time.time() - start:.2f}s")

starting iteration 125 with 1 tiles to process
runtime prediction: inf
tuning for 32768 simulations with 1 tiles and batch size (64, 1024)
0.8734002113342285
0.3879389762878418
step took 3.35s
checkpointing took 7.52s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.08160567283630371
0.2525291442871094
criterion took 11.08s


refinement took 6.32s
starting iteration 126 with 127848 tiles to process
runtime prediction: 29025.09
tuning for 2048 simulations with 116960 tiles and batch size (64, 1024)
96.89662957191467
19.538331747055054
tuning for 4096 simulations with 10888 tiles and batch size (64, 1024)
17.563320875167847
2.470937490463257
step took 137.82s
checkpointing took 7.13s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.08059406280517578
0.25332045555114746
criterion took 11.71s


refinement took 7.13s
starting iteration 127 with 22502 tiles to process
runtime prediction: 38.36
tuning for 2048 simulations with 6384 tiles and batch size (64, 1024)
5.336625814437866
1.3505454063415527
tuning for 4096 simulations with 16118 tiles and batch size (64, 1024)
25.87602686882019
3.5235230922698975
step took 36.90s
checkpointing took 7.26s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09521079063415527
0.26537060737609863
criterion took 11.09s


refinement took 6.91s
starting iteration 128 with 36978 tiles to process
runtime prediction: 50.04
tuning for 2048 simulations with 21584 tiles and batch size (64, 1024)
17.937230348587036
3.7743823528289795
tuning for 4096 simulations with 15394 tiles and batch size (64, 1024)
24.701953411102295
3.3779497146606445
step took 50.71s
checkpointing took 6.81s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09720706939697266
0.2738652229309082
criterion took 10.92s


refinement took 7.81s
starting iteration 129 with 243902 tiles to process
runtime prediction: 241.82
tuning for 2048 simulations with 238068 tiles and batch size (64, 1024)
197.3535394668579
39.63521456718445
tuning for 4096 simulations with 5834 tiles and batch size (64, 1024)
9.452295064926147
1.4555296897888184
step took 249.82s
checkpointing took 7.43s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09488844871520996
0.27478981018066406
criterion took 11.17s


refinement took 12.35s
starting iteration 130 with 75863 tiles to process
runtime prediction: 89.57
tuning for 2048 simulations with 62184 tiles and batch size (64, 1024)
51.58075714111328
10.490442276000977
tuning for 4096 simulations with 13679 tiles and batch size (64, 1024)
21.96653151512146
3.0932881832122803
step took 88.29s
checkpointing took 7.01s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09368181228637695
0.25615906715393066
criterion took 11.25s


refinement took 7.45s
starting iteration 131 with 109125 tiles to process
runtime prediction: 119.99
tuning for 2048 simulations with 96560 tiles and batch size (64, 1024)
79.8993809223175
14.597189664840698
tuning for 4096 simulations with 12565 tiles and batch size (64, 1024)


Bad pipe message: %s [b"\x00 \xde\xf8\x8b\xbc\x00b\xd8n\x9e\xefH#\x8d\xec\xa4\xee \xfcR\xf5\xf3\xe0\x95D,B\xc6\x07\x92\xb9\x01\xa7\xd3\xf0\x0epo'Xl\x98\xcc\xe7\xd2\x16\x00\xb2k\xea\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00"]
Bad pipe message: %s [b'\xc2(;\xf0>\xf7\xa3o\x15.\xfa\xc5a\xf7\x8e\xc9\xb5\x80\x00\x00\xa6\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/']
Bad pipe message: %s [b'\xd4\xe0\xb5\x12\xa5\x13\xc3\x85K\xed\xee$l|F\xf3\xba\x91\x00\x00\xa2\xc0\x14\xc0\n\x009\x008\x007\x006\x00\x88\x00\x87\x00\x86\x00\x85\xc0\x19\x00:\x00\x89\xc0\x0f\xc0\x05\x005\x00\x84\xc0\x13\xc0\t\x003\x002\x001\x00']
Bad pipe message: %s [b'\x9a\x00\x99\x00\x98\x00\x97\x00E\x00D\x00C\x00B\xc0\x18\x004\x00\x9b\x00F\xc0\x0e\xc0\x04\x00/\x00\x96\x00A\x0

20.159343957901
2.759104013442993
step took 118.75s
checkpointing took 7.55s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.0948643684387207
0.26118898391723633
criterion took 11.64s


refinement took 7.91s
starting iteration 132 with 115826 tiles to process
runtime prediction: 124.38
tuning for 2048 simulations with 104196 tiles and batch size (64, 1024)
85.90844917297363
15.653781175613403
tuning for 4096 simulations with 11630 tiles and batch size (64, 1024)
18.610960960388184
2.6488096714019775
step took 124.18s
checkpointing took 8.31s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09390139579772949
0.2635467052459717
criterion took 14.06s


refinement took 9.24s
starting iteration 133 with 77571 tiles to process
runtime prediction: 88.20
tuning for 2048 simulations with 64608 tiles and batch size (64, 1024)
54.77169442176819
10.467864513397217
tuning for 4096 simulations with 12963 tiles and batch size (64, 1024)
20.813146829605103
2.8315486907958984
step took 90.04s
checkpointing took 7.31s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09694790840148926
0.2661004066467285
criterion took 11.47s


refinement took 7.36s
starting iteration 134 with 56439 tiles to process
runtime prediction: 69.99
tuning for 2048 simulations with 42504 tiles and batch size (64, 1024)
35.072375535964966
6.5060834884643555
tuning for 4096 simulations with 13935 tiles and batch size (64, 1024)
22.303184032440186
3.0821940898895264
step took 67.97s
checkpointing took 7.51s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.0812070369720459
0.25762462615966797
criterion took 12.13s


refinement took 15.77s
starting iteration 135 with 68318 tiles to process
runtime prediction: 79.28
tuning for 2048 simulations with 54548 tiles and batch size (64, 1024)
45.056532859802246
8.331151485443115
tuning for 4096 simulations with 13770 tiles and batch size (64, 1024)
22.127875328063965
3.0028235912323
step took 79.62s
checkpointing took 7.49s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.1012732982635498
0.2705380916595459
criterion took 12.00s


refinement took 8.13s
starting iteration 136 with 113612 tiles to process
runtime prediction: 121.78
tuning for 2048 simulations with 101668 tiles and batch size (64, 1024)
83.92931365966797
15.192634105682373
tuning for 4096 simulations with 11944 tiles and batch size (64, 1024)
19.135734796524048
2.622493267059326
step took 122.23s
checkpointing took 7.59s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.05370068550109863
0.2593722343444824
criterion took 11.51s


refinement took 7.77s
starting iteration 137 with 186132 tiles to process
runtime prediction: 189.31
tuning for 2048 simulations with 177796 tiles and batch size (64, 1024)
148.43761324882507
27.255930423736572
tuning for 4096 simulations with 8336 tiles and batch size (64, 1024)
13.465358257293701
1.9849934577941895
step took 192.90s
checkpointing took 7.61s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.0850679874420166
0.26450538635253906
criterion took 12.00s


refinement took 8.08s
starting iteration 138 with 192903 tiles to process
runtime prediction: 199.49
tuning for 2048 simulations with 184696 tiles and batch size (64, 1024)
153.5243170261383
27.75261902809143
tuning for 4096 simulations with 8207 tiles and batch size (64, 1024)
13.251951932907104
1.896531105041504
step took 198.36s
checkpointing took 7.44s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.08052206039428711
0.2706177234649658
criterion took 13.08s


refinement took 8.92s
starting iteration 139 with 71015 tiles to process
runtime prediction: 83.25
tuning for 2048 simulations with 57624 tiles and batch size (64, 1024)
47.61862897872925
8.893636226654053
tuning for 4096 simulations with 13391 tiles and batch size (64, 1024)
21.528321266174316
2.9316935539245605
step took 82.08s
checkpointing took 7.94s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.0798349380493164
0.25830745697021484
criterion took 12.39s


refinement took 21.75s
starting iteration 140 with 40941 tiles to process
runtime prediction: 54.55
tuning for 2048 simulations with 25784 tiles and batch size (64, 1024)
21.312519073486328
4.092013835906982
tuning for 4096 simulations with 15157 tiles and batch size (64, 1024)
24.257670402526855
3.2786309719085693
step took 53.94s
checkpointing took 8.18s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.11131882667541504
0.26521944999694824
criterion took 12.96s


refinement took 9.17s
starting iteration 141 with 112751 tiles to process
runtime prediction: 119.90
tuning for 2048 simulations with 100804 tiles and batch size (64, 1024)
83.21742296218872
15.157179832458496
tuning for 4096 simulations with 11947 tiles and batch size (64, 1024)
19.138533115386963
2.6225011348724365
step took 121.53s
checkpointing took 7.57s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.08553051948547363
0.25448083877563477
criterion took 12.87s


refinement took 9.52s
starting iteration 142 with 125243 tiles to process
runtime prediction: 133.28
tuning for 2048 simulations with 113740 tiles and batch size (64, 1024)
93.9915406703949
17.042383432388306
tuning for 4096 simulations with 11503 tiles and batch size (64, 1024)
18.4898419380188
2.5517001152038574
step took 133.57s
checkpointing took 7.99s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.07973361015319824
0.2555570602416992
criterion took 12.52s


refinement took 9.27s
starting iteration 143 with 177604 tiles to process
runtime prediction: 182.63
tuning for 2048 simulations with 168232 tiles and batch size (64, 1024)
140.37809801101685
25.5211820602417
tuning for 4096 simulations with 9372 tiles and batch size (64, 1024)
15.082818984985352
2.125535011291504
step took 184.88s
checkpointing took 7.97s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.08860182762145996
0.28585290908813477
criterion took 15.72s


refinement took 11.24s
starting iteration 144 with 226077 tiles to process
runtime prediction: 229.98
tuning for 2048 simulations with 219572 tiles and batch size (64, 1024)
181.37529397010803
32.739277362823486
tuning for 4096 simulations with 6505 tiles and batch size (64, 1024)
10.480717658996582
1.5572118759155273
step took 228.22s
checkpointing took 8.10s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.07948589324951172
0.27148866653442383
criterion took 13.55s


refinement took 10.13s
starting iteration 145 with 174264 tiles to process
runtime prediction: 180.15
tuning for 2048 simulations with 164940 tiles and batch size (64, 1024)
136.2464520931244
24.773197412490845
tuning for 4096 simulations with 9324 tiles and batch size (64, 1024)
14.926278591156006
2.083827257156372
step took 179.81s
checkpointing took 8.20s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09359169006347656
0.25945472717285156
criterion took 14.42s


refinement took 10.48s
starting iteration 146 with 127316 tiles to process
runtime prediction: 135.56
tuning for 2048 simulations with 116228 tiles and batch size (64, 1024)
95.99532318115234
17.518115520477295
tuning for 4096 simulations with 11088 tiles and batch size (64, 1024)
17.864882230758667
2.4578871726989746
step took 135.33s
checkpointing took 8.21s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09534597396850586
0.2783076763153076
criterion took 15.24s


refinement took 10.20s
starting iteration 147 with 87996 tiles to process
runtime prediction: 98.81
tuning for 2048 simulations with 74936 tiles and batch size (64, 1024)
61.8918182849884
11.407104969024658
tuning for 4096 simulations with 13060 tiles and batch size (64, 1024)
20.985332250595093
2.883843421936035
step took 98.48s
checkpointing took 8.48s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.0950937271118164
0.2793247699737549
criterion took 14.05s


refinement took 10.00s
starting iteration 148 with 32928 tiles to process
runtime prediction: 47.03
tuning for 2048 simulations with 17600 tiles and batch size (64, 1024)
14.557708501815796
2.837231159210205
tuning for 4096 simulations with 15328 tiles and batch size (64, 1024)
24.573076725006104
3.2899112701416016
step took 46.26s
checkpointing took 8.13s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.09479641914367676
0.2855570316314697
criterion took 14.07s


refinement took 9.48s
starting iteration 149 with 52304 tiles to process
runtime prediction: 64.23
tuning for 2048 simulations with 37604 tiles and batch size (64, 1024)
31.101231336593628
5.832622051239014
tuning for 4096 simulations with 14700 tiles and batch size (64, 1024)
23.548077821731567
3.1711807250976562
step took 64.76s
checkpointing took 8.41s
tuning for 2048 simulations with 1 tiles and batch size (1, 16384)
0.10033369064331055
0.27898335456848145
criterion took 13.21s


refinement took 9.51s
starting iteration 150 with 159588 tiles to process
runtime prediction: 162.84
tuning for 2048 simulations with 150696 tiles and batch size (64, 1024)


In [64]:
# NOTE(review): this cell relies on names that no visible cell defines:
# batched_rej, sim_sizes, overall_cv, unifs, unifs_order, bootstrap_cvs,
# pointwise_target_alpha, delta_validate, holderq, sim_cvs, ehbound.
# Confirm they exist (e.g. from an earlier workflow) before running.
typeI_sum = batched_rej(
    sim_sizes,
    (
        np.full(sim_sizes.shape[0], overall_cv),
        g.theta_tiles,
        g.null_truth,
    ),
    unifs,
    unifs_order,
)

# Calculate actual type I errors?
typeI_est, typeI_CI = binomial.zero_order_bound(
    typeI_sum, sim_sizes, delta_validate, 1.0
)
typeI_bound = typeI_est + typeI_CI

hob_upper = binomial.holder_odi_bound(
    typeI_bound, g.theta_tiles, g.vertices, n_arm_samples, holderq
)
sim_cost = typeI_CI
hob_empirical_cost = hob_upper - typeI_bound

# Save only after hob_upper has been computed. Previously it was pickled before
# assignment, which stored a stale value (or raised NameError).
savedata = [
    g,
    sim_sizes,
    bootstrap_cvs,
    typeI_sum,
    hob_upper,
    pointwise_target_alpha,
]
with open(f"{name}/final.pkl", "wb") as f:
    pickle.dump(savedata, f)

# Bare mid-cell expressions are not displayed in a notebook, so print them.
worst_idx = np.argmax(typeI_est)
worst_tile = g.theta_tiles[worst_idx]
print(typeI_est[worst_idx], worst_tile)
worst_cv_idx = np.argmin(sim_cvs)
print(
    typeI_est[worst_cv_idx],
    sim_cvs[worst_cv_idx],
    g.theta_tiles[worst_cv_idx],
    pointwise_target_alpha[worst_cv_idx],
)
plt.hist(typeI_est, bins=np.linspace(0.02, 0.025, 100))
plt.show()

# Sim point with one coordinate per arm, matching the active configuration
# (previously hard-coded to 3 coordinates).
theta_0 = np.full(n_arms, -1.0)
v = 0.1 * np.ones(theta_0.shape[0])     # displacement
f0 = 0.01                               # Type I Error at theta_0
fwd_solver = ehbound.ForwardQCPSolver(n=n_arm_samples)
q_opt = fwd_solver.solve(theta_0=theta_0, v=v, a=f0)  # optimal q
ehbound.q_holder_bound_fwd(q_opt, n_arm_samples, theta_0, v, f0)

running for size 1000 with 4721515 tiles took 1573.7245726585388
