In [1]:
from confirm.outlaw.nb_util import setup_nb

setup_nb()
import jax
import numpy as np
import jax.numpy as jnp
from confirm.lewislib import lewis

# Configuration used during simulation
name = "4d_full"
params = {
    "n_arms": 4,
    "n_stage_1": 50,
    "n_stage_2": 100,
    "n_stage_1_interims": 2,
    "n_stage_1_add_per_interim": 100,
    "n_stage_2_add_per_interim": 100,
    "stage_1_futility_threshold": 0.15,
    "stage_1_efficacy_threshold": 0.7,
    "stage_2_futility_threshold": 0.2,
    "stage_2_efficacy_threshold": 0.95,
    "inter_stage_futility_threshold": 0.6,
    "posterior_difference_threshold": 0,
    "rejection_threshold": 0.05,
    "key": jax.random.PRNGKey(0),
    "n_table_pts": 20,
    "n_pr_sims": 100,
    "n_sig2_sims": 20,
    "batch_size": int(2**12),
    "cache_tables": f"./{name}/lei_cache.pkl",
}
lei_obj = lewis.Lewis45(**params)
n_arm_samples = int(lei_obj.unifs_shape()[0])

2022-11-09 17:26:45.275028: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:267] failed call to cuInit: CUDA_ERROR_SYSTEM_DRIVER_MISMATCH: system has unsupported display driver / cuda driver combination
2022-11-09 17:26:45.275169: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:313] kernel version 520.61.5 does not match DSO version 515.65.7 -- cannot find working devices in this configuration


In [2]:
import confirm.mini_imprint.lewis_drivers as lts

N = 256
K = 2**18
unifs = jax.random.uniform(jax.random.PRNGKey(0), (K, n_arm_samples, 4))
unifs_order = jnp.arange(n_arm_samples)
theta = np.random.rand(N, 4)
null_truth = np.ones((N, 3), dtype=bool)
sim_sizes = np.full(N, 2**13)
lam = 0.05

## Is copying the uniforms array expensive?

In [3]:
%%time
unifs.sum(axis=1).block_until_ready()

CPU times: user 32.4 ms, sys: 4.15 ms, total: 36.5 ms
Wall time: 96.2 ms


DeviceArray([[179.80161698, 171.71986841, 182.57861407, 175.14109228],
             [182.33056452, 169.87411049, 178.04835406, 175.63145328],
             [175.99040328, 185.59461593, 177.51426566, 179.69973176],
             [177.13297839, 172.99227953, 169.86119223, 168.91658085],
             [171.89859426, 181.44848379, 179.28002138, 177.31247947],
             [179.43708104, 172.75405059, 188.08694166, 175.59197108],
             [171.98213949, 176.62782851, 181.80152753, 184.41477242],
             [179.26737369, 171.73704153, 184.15515405, 171.13365942],
             [176.69139802, 171.91147518, 184.71368942, 182.71569347],
             [181.70612486, 176.24510559, 176.96877983, 175.30222542],
             ...,
             [165.7079974 , 177.41280957, 190.44120925, 173.16103124],
             [180.24945033, 176.29029716, 171.97620631, 174.27696988],
             [179.24719995, 175.40236968, 179.11769934, 179.31873293],
             [175.08198462, 179.53311256, 186.79704052, 178

In [4]:
%%time
outs = []
for i in range(0, unifs.shape[0], 10000):
    begin_idx = i
    end_idx = min(i + 10000, unifs.shape[0])
    outs.append(unifs[begin_idx:end_idx].sum(axis=1))
outs = jnp.concatenate(outs).block_until_ready()

CPU times: user 196 ms, sys: 13.8 ms, total: 210 ms
Wall time: 371 ms


In [5]:
%%time
unifs[:100000].sum(axis=1).block_until_ready()

CPU times: user 52.9 ms, sys: 2.19 ms, total: 55.1 ms
Wall time: 90.4 ms


DeviceArray([[179.80161698, 171.71986841, 182.57861407, 175.14109228],
             [182.33056452, 169.87411049, 178.04835406, 175.63145328],
             [175.99040328, 185.59461593, 177.51426566, 179.69973176],
             [177.13297839, 172.99227953, 169.86119223, 168.91658085],
             [171.89859426, 181.44848379, 179.28002138, 177.31247947],
             [179.43708104, 172.75405059, 188.08694166, 175.59197108],
             [171.98213949, 176.62782851, 181.80152753, 184.41477242],
             [179.26737369, 171.73704153, 184.15515405, 171.13365942],
             [176.69139802, 171.91147518, 184.71368942, 182.71569347],
             [181.70612486, 176.24510559, 176.96877983, 175.30222542],
             ...,
             [176.10235683, 168.37626266, 170.48436907, 171.50836705],
             [180.55719205, 175.48620749, 173.35058797, 177.0754814 ],
             [174.80902265, 177.25224545, 171.39816258, 178.78827998],
             [173.96520158, 175.34039108, 173.14592125, 175

In [6]:
%%time
outs = []
for i in range(0, 100000, 10000):
    begin_idx = i
    end_idx = min(i + 10000, unifs.shape[0])
    outs.append(unifs[begin_idx:end_idx].sum(axis=1))
outs = jnp.concatenate(outs).block_until_ready()

CPU times: user 43.2 ms, sys: 3.54 ms, total: 46.7 ms
Wall time: 80.2 ms


In [7]:
%%time
unifs[:30000].sum(axis=1).block_until_ready()

CPU times: user 53.5 ms, sys: 2.78 ms, total: 56.3 ms
Wall time: 90.4 ms


DeviceArray([[179.80161698, 171.71986841, 182.57861407, 175.14109228],
             [182.33056452, 169.87411049, 178.04835406, 175.63145328],
             [175.99040328, 185.59461593, 177.51426566, 179.69973176],
             [177.13297839, 172.99227953, 169.86119223, 168.91658085],
             [171.89859426, 181.44848379, 179.28002138, 177.31247947],
             [179.43708104, 172.75405059, 188.08694166, 175.59197108],
             [171.98213949, 176.62782851, 181.80152753, 184.41477242],
             [179.26737369, 171.73704153, 184.15515405, 171.13365942],
             [176.69139802, 171.91147518, 184.71368942, 182.71569347],
             [181.70612486, 176.24510559, 176.96877983, 175.30222542],
             ...,
             [179.14507797, 175.59830491, 168.03510018, 169.15966343],
             [179.1704259 , 175.13724082, 168.31696202, 182.99072149],
             [177.66093114, 174.56513581, 165.35283681, 175.80492763],
             [168.65633582, 170.51986776, 180.63847469, 167

In [8]:
%%time
outs = []
for i in range(0, 30000, 10000):
    begin_idx = i
    end_idx = min(i + 10000, unifs.shape[0])
    outs.append(unifs[begin_idx:end_idx].sum(axis=1))
outs = jnp.concatenate(outs).block_until_ready()

CPU times: user 25.1 ms, sys: 561 µs, total: 25.7 ms
Wall time: 58.9 ms


In [9]:
import gc

gc.collect()

0

In [10]:
lts.memory_status("hi")

hi memory usage 3.026335682
hi buffer sizes [(30000, 4), (30000, 4), (100000, 4), (262144, 4), (350,), (262144, 350, 4), (10, 160000, 3), (10, 4, 20), (10,), (10,), (4,), (10, 160000, 3), (10, 160000, 3), (10, 4, 20), (10,), (10,), (4,), (10, 160000, 3), (10, 4, 20), (10,), (10,), (4,), (4,), (3, 4), (4,), (1,), (20,), (20,), (20,), (19,), (20,), (4,), (1,), (15,), (15,), (15,), (2,)]


## What batch sizes??

In [4]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=8192,
    grid_batch_size=64,
)

simulating with K=8192 and n_tiles=512 and batch_size=(64, 8192)
simulation runtime 12.518173217773438
CPU times: user 14.3 s, sys: 653 ms, total: 15 s
Wall time: 12.5 s


In [23]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=8192,
    grid_batch_size=512,
)

simulating with K=8192 and n_tiles=512 and batch_size=(512, 8192)
simulation runtime 2.4775424003601074
CPU times: user 2.44 s, sys: 47.8 ms, total: 2.49 s
Wall time: 2.48 s


In [43]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=1024,
    grid_batch_size=64,
)

simulating with K=8192 and n_tiles=512 and batch_size=(64, 1024)
simulation runtime 2.4350481033325195
CPU times: user 2.42 s, sys: 72.4 ms, total: 2.49 s
Wall time: 2.44 s


In [64]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=2048,
    grid_batch_size=128,
)

simulating with K=8192 and n_tiles=256 and batch_size=(128, 2048)
simulation runtime 1.3524508476257324
CPU times: user 1.35 s, sys: 19.6 ms, total: 1.37 s
Wall time: 1.36 s


In [12]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=8192,
    grid_batch_size=128,
)

simulating with K=8192 and n_tiles=512 and batch_size=(128, 8192)
simulation runtime 2.251185178756714
CPU times: user 2.1 s, sys: 157 ms, total: 2.26 s
Wall time: 2.26 s


In [49]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=512,
    grid_batch_size=128,
)

simulating with K=8192 and n_tiles=512 and batch_size=(128, 512)
simulation runtime 2.4038445949554443
CPU times: user 2.4 s, sys: 49.8 ms, total: 2.45 s
Wall time: 2.41 s


In [16]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=64,
    grid_batch_size=64,
)

simulating with K=8192 and n_tiles=512 and batch_size=(64, 64)
simulation runtime 4.788881063461304
CPU times: user 4.86 s, sys: 523 ms, total: 5.38 s
Wall time: 4.8 s


In [17]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes[:64],
    lam,
    theta[:64],
    null_truth[:64],
    unifs,
    unifs_order,
    sim_batch_size=1024,
    grid_batch_size=64,
)

simulating with K=8192 and n_tiles=64 and batch_size=(64, 1024)
simulation runtime 0.4780433177947998
CPU times: user 449 ms, sys: 36 ms, total: 485 ms
Wall time: 479 ms


In [118]:
import time
import gc
import jax.numpy as jnp
from confirm.lewislib import batch
from confirm.mini_imprint.lewis_drivers import get_sim_size_groups


def simulator(p, unifs, unifs_order):
    return jnp.sum(unifs[:, :] < p[None, :]) / unifs.size, 1, 0


simulatev = jax.vmap(simulator, in_axes=(None, 0, None))


def stat(lei_obj, theta, null_truth, unifs, unifs_order):
    p = jax.scipy.special.expit(theta)
    test_stats, best_arms, _ = simulatev(p, unifs, unifs_order)
    false_test_stats = jnp.where(null_truth[best_arms - 1], test_stats, 100.0)
    return false_test_stats


statv = jax.jit(jax.vmap(stat, in_axes=(None, 0, 0, None, None)), static_argnums=(0,))


@jax.jit
def sumstats(stats, lam):
    return jnp.sum(stats < lam, axis=-1)


def rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=1024,
    grid_batch_size=64,
):
    outs = []
    for (_, idx, stats) in _stats_backend(
        lei_obj,
        sim_sizes,
        theta,
        null_truth,
        unifs,
        unifs_order,
        sim_batch_size,
        grid_batch_size,
    ):
        outs.append(sumstats(stats, lam))
    return jnp.concatenate(outs)


def _stats_backend(
    lei_obj,
    sim_sizes,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=1024,
    grid_batch_size=64,
):
    batched_statv = batch.batch(
        batch.batch(
            statv, sim_batch_size, in_axes=(None, None, None, 0, None), out_axes=(1,)
        ),
        grid_batch_size,
        in_axes=(None, 0, 0, None, None),
    )

    for size, idx in get_sim_size_groups(sim_sizes):
        print(
            f"simulating with K={size} and n_tiles={idx.sum()}"
            f" and batch_size=({grid_batch_size}, {sim_batch_size})"
        )
        start = time.time()
        unifs_chunk = unifs[:size]
        stats = batched_statv(
            lei_obj, theta[idx], null_truth[idx], unifs_chunk, unifs_order
        )
        print("simulation runtime", time.time() - start)

        yield (size, idx, stats)

In [119]:
from functools import partial


@partial(jax.jit, static_argnums=(0,))
def stat_sum(lei_obj, lam, theta, null_truth, unifs, unifs_order):
    stats = jax.vmap(stat, in_axes=(None, 0, 0, None, None))(
        lei_obj, theta, null_truth, unifs, unifs_order
    )
    return jnp.sum(stats < lam, axis=-1)

In [151]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (K, 350, 4))

In [196]:
%%time
unifs_chunk = unifs[:1024]
res = stat_sum(
    lei_obj, 0.6, theta, null_truth, unifs_chunk, unifs_order
).block_until_ready()

CPU times: user 4 ms, sys: 195 µs, total: 4.19 ms
Wall time: 6.16 ms


In [154]:
batched_stat_sum = batch.batch(
    batch.batch(
        stat_sum, 8192, in_axes=(None, None, None, None, 0, None), out_axes=(1,)
    ),
    256,
    in_axes=(None, None, 0, 0, None, None),
)

In [158]:
%%time
unifs_chunk = unifs[:8192]
stats = batched_stat_sum(
    lei_obj, 0.05, theta, null_truth, unifs[:8192], unifs_order
).block_until_ready()

CPU times: user 71.7 ms, sys: 20.1 ms, total: 91.8 ms
Wall time: 126 ms


In [161]:
%%time
unifs_chunk = unifs[:8192]
stats = statv(lei_obj, theta, null_truth, unifs_chunk, unifs_order)
rej = jnp.sum(stats < lam, axis=-1).block_until_ready()

CPU times: user 1.93 ms, sys: 4.03 ms, total: 5.96 ms
Wall time: 33.7 ms


In [162]:
batched_statv = batch.batch(
    batch.batch(statv, 8192, in_axes=(None, None, None, 0, None), out_axes=(1,)),
    256,
    in_axes=(None, 0, 0, None, None),
)

In [164]:
%%time
unifs_chunk = unifs[:8192]
stats = batched_statv(lei_obj, theta, null_truth, unifs_chunk, unifs_order)
rej = jnp.sum(stats < lam, axis=-1).block_until_ready()

CPU times: user 32.5 ms, sys: 64.3 ms, total: 96.7 ms
Wall time: 131 ms


In [132]:
batched_statv2 = batch.batch(
    batch.batch(statv, 2048, in_axes=(None, None, None, 0, None), out_axes=(1,)),
    128,
    in_axes=(None, 0, 0, None, None),
)

In [165]:
%%time
unifs_chunk = unifs[:8192]
stats = batched_statv2(lei_obj, theta, null_truth, unifs_chunk, unifs_order)
rej = jnp.sum(stats < lam, axis=-1).block_until_ready()

CPU times: user 142 ms, sys: 32.2 ms, total: 174 ms
Wall time: 165 ms


In [187]:
%%time
unifs_chunk = unifs[:8192]
stats = statv(lei_obj, theta, null_truth, unifs_chunk, unifs_order)
rej = np.sum(stats < lam, axis=-1).block_until_ready()

CPU times: user 4.94 ms, sys: 36 µs, total: 4.98 ms
Wall time: 35.9 ms


In [201]:
%%time
unifs_chunk = unifs[:1024]
rejs = []
j_step = 128
i_step = 1024
for i in range(0, unifs_chunk.shape[0], i_step):
    i_begin = i
    i_end = i + i_step
    for j in range(0, theta.shape[0], j_step):
        j_begin = j
        j_end = min(theta.shape[0], j + j_step)
        subunifs_chunk = unifs_chunk[i_begin:i_end]
        # rejs.append(stat_sum(lei_obj, lam, theta[j_begin:j_end], null_truth[j_begin:j_end], subunifs_chunk, unifs_order))
        stats = statv(
            lei_obj,
            theta[j_begin:j_end],
            null_truth[j_begin:j_end],
            subunifs_chunk,
            unifs_order,
        )
        rejs.append(jnp.sum(stats < lam, axis=-1))
rej = jnp.concatenate(rejs).block_until_ready()

CPU times: user 1.85 ms, sys: 4.13 ms, total: 5.98 ms
Wall time: 6.53 ms


Bad pipe message: %s [b'@\x03iTd-r\xedm\xdaw\xac\x12Bm,\x96\x9f R&\x90\xb7V\x1f\x0f\x9e/\xbbQ \x9d\xb5F\x95\xf1\xd1\x9a\x15\xfcA\x17\xe8\xc4]\x12\xb7|\x89oG\x00\x08\x13\x02\x13\x03\x13\x01\x00\xff\x01\x00\x00\x8f\x00\x00\x00\x0e\x00\x0c\x00\x00\t127.0.0.1\x00\x0b\x00\x04\x03\x00\x01\x02\x00\n\x00\x0c\x00\n\x00\x1d\x00\x17\x00\x1e\x00\x19\x00\x18\x00#\x00\x00\x00\x16\x00\x00\x00\x17\x00\x00\x00\r\x00\x1e\x00\x1c\x04\x03\x05\x03\x06\x03\x08\x07\x08\x08\x08\t\x08\n\x08\x0b\x08\x04\x08\x05\x08\x06\x04\x01\x05\x01\x06\x01\x00+\x00\x03\x02\x03\x04\x00-\x00\x02\x01\x01\x003\x00&\x00$\x00\x1d\x00 5\\h\xd7&\x9cC\\>C\x99\x14\x0b\xa6\x08\xc3:\x96Q.Q\xdf\xc9']
Bad pipe message: %s [b'\x8e\x1d\x9bX\xc6\xcf\x96I]1\x1f\x93\x8b\x82\xd79+\x84\x00\x00|\xc0,\xc00\x00\xa3\x00\x9f\xcc\xa9\xcc\xa8\xcc\xaa\xc0\xaf\xc0\xad\xc0\xa3\xc0\x9f\xc0]\xc0a\xc0W\xc0S\xc0+\xc0/\x00\xa2\x00\x9e\xc0\xae\xc0\xac\xc0\xa2\xc0\x9e\xc0\\\xc0`\xc0V']
Bad pipe message: %s [b'\xc2\x13;\xf2\x14\xb2\x05\xb5\x95\xcb\x89\xd7\x9d\x1f

In [40]:
%%time
rejs = lts.rej_runner(
    lei_obj,
    sim_sizes,
    lam,
    theta,
    null_truth,
    unifs,
    unifs_order,
    sim_batch_size=2048,
    grid_batch_size=128,
)

simulating with K=8192 and n_tiles=256 and batch_size=(128, 2048)
simulation runtime 1.266707420349121
CPU times: user 1.24 s, sys: 44.4 ms, total: 1.28 s
Wall time: 1.27 s


In [3]:
unifs = jax.random.uniform(jax.random.PRNGKey(10), (K, n_arm_samples, 4))

In [19]:
%load_ext line_profiler

The line_profiler extension is already loaded. To reload it, use:
  %reload_ext line_profiler


In [21]:
%lprun -T output.log -f rej_runner -f _stats_backend rej_runner(lei_obj, sim_sizes, lam, theta, null_truth, unifs, unifs_order, sim_batch_size=1024, grid_batch_size=64)
open("output.log").read()

simulating with K=8192 and n_tiles=512 and batch_size=(64, 1024)
simulation runtime 2.30135178565979

*** Profile printout saved to text file 'output.log'. 




Timer unit: 1e-06 s

Total time: 2.30714 s
File: /tmp/ipykernel_18094/1071322944.py
Function: rej_runner at line 21

Line #      Hits         Time  Per Hit   % Time  Line Contents
    21                                           def rej_runner(
    22                                               lei_obj,
    23                                               sim_sizes,
    24                                               lam,
    25                                               theta,
    26                                               null_truth,
    27                                               unifs,
    28                                               unifs_order,
    29                                               sim_batch_size=1024,
    30                                               grid_batch_size=64,
    31                                           ):
    32         1          1.0      1.0      0.0      outs = []
    33         3    2303281.0 767760.3     99.8      for (