# Benchmarking to inform the Model API design.

I had several questions that are important to answer so that we don't
hamstring our performance once we commit to a Model API:

1. Does slicing and copying a large array of uniforms (or other pregenerated random variates) cause performance issues?
    - No, it does not. The cost is irrelevant, even though I had previously
      suspected it was a significant cause of performance problems. Both the
      microbenchmark and the later full benchmark demonstrate this, on CPU and
      GPU alike.
2. What batch size should we use?
    - Larger when the simulation function is slower. (duh)
    - 1024 sims x 64 pts is reasonable for something like Lei.
    - On CPU: large-ish, but it's actually faster to split the work into several
      batches than to run everything at once. This is unsurprising and is due to
      cache-friendliness. The ideal was 32768 sims and 128 grid points in a single batch.
    - On GPU: TODO (not yet measured).
3. Are the concatenations in our current GPU code problematic?
    - On CPU: Yes, concatenation is bad for performance, especially when we're
      double batching over both grid points and simulations since we incur a
      concatenation for each outer batch.
    - On GPU: CHECK THIS! We expect this to be especially bad here because we
      force blocking and copy data from GPU to CPU.
4. Does it help to include the summation of rejections inside the jitted function call? Or can we factor that out into the calling code?
    - On CPU, we get better performance when we include the summation inside the jit call.
    - On GPU: CHECK THIS!

In [1]:
from confirm.outlaw.nb_util import setup_nb

setup_nb()

import time
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

from confirm.lewislib import batch

## Is copying the uniforms array expensive?



In [2]:
_key = jax.random.PRNGKey(0)
# One large array plus two smaller arrays. The smaller ones have the same shapes as
# the slices ``unifs[:10000]`` and ``unifs[:, :, :1]``, so operating on a slice can
# be compared with operating on an unsliced array of identical shape.
unifs = jax.random.uniform(_key, (256000, 350, 4))
unifs10 = jax.random.uniform(_key, (10000, 350, 4))
unifs2d = jax.random.uniform(_key, (256000, 350, 1))

In [3]:
def _timed(label, fn, report):
    """Call ``fn()``; if ``report`` is true, print ``label`` and the elapsed seconds.

    ``fn`` is expected to block on its own result (``block_until_ready``) so that
    the measured interval covers the device work and not just JAX's dispatch.
    Uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time`` for interval timing. Returns whatever ``fn`` returns.
    """
    start = time.perf_counter()
    out = fn()
    if report:
        print(label, time.perf_counter() - start)
    return out


def _batched_row_sum(arr, step=10000):
    """Sum ``arr`` over axis 1 in chunks of ``step`` rows and concatenate the pieces."""
    # Slicing clamps at the end of the array, so the final chunk may be shorter.
    parts = [arr[i : i + step].sum(axis=1) for i in range(0, arr.shape[0], step)]
    return jnp.concatenate(parts).block_until_ready()


for k in range(4):
    # The first pass is not reported (presumably a warm-up that absorbs one-off
    # compilation/allocation costs).
    report = k >= 1

    out1 = _timed("sum", lambda: unifs.sum(axis=1).block_until_ready(), report)
    copy = _timed("copy", lambda: unifs.copy().block_until_ready(), report)
    _timed("slicesum", lambda: unifs[:-1].sum(axis=1).block_until_ready(), report)
    out2 = _timed("batched sum", lambda: _batched_row_sum(unifs), report)
    # The chunked sum must agree with the single full-array sum.
    np.testing.assert_allclose(out1, out2)

    _timed("sum10k no slice", lambda: unifs10.sum(axis=1).block_until_ready(), report)
    _timed(
        "sum10k with slice",
        lambda: unifs[:10000].sum(axis=1).block_until_ready(),
        report,
    )
    _timed("sum2d no slice", lambda: unifs2d.sum(axis=1).block_until_ready(), report)
    _timed(
        "sum2d with slice",
        lambda: unifs[:, :, :1].sum(axis=1).block_until_ready(),
        report,
    )

sum 0.04343724250793457
copy 0.31673121452331543
slicesum 0.22755980491638184
batched sum 0.21455097198486328
sum10k no slice 0.0024559497833251953
sum10k with slice 0.009029865264892578
sum2d no slice 0.008322954177856445
sum2d with slice 0.07137775421142578
sum 0.04260683059692383
copy 0.30614209175109863
slicesum 0.2239699363708496
batched sum 0.20693397521972656
sum10k no slice 0.0021729469299316406
sum10k with slice 0.007977962493896484
sum2d no slice 0.008790016174316406
sum2d with slice 0.06688880920410156
sum 0.03954720497131348
copy 0.29372501373291016
slicesum 0.2198009490966797
batched sum 0.2037220001220703
sum10k no slice 0.0016851425170898438
sum10k with slice 0.008710145950317383
sum2d no slice 0.009112834930419922
sum2d with slice 0.07049393653869629


## What batch sizes should we use?

In [114]:
def simulator(p, unifs):
    """Return the fraction of draws in one simulation that fall below ``p``.

    ``unifs`` is indexed as (at least) 2-D, and ``p[None, :]`` is broadcast
    against it before the elementwise comparison.
    """
    below = unifs[:, :] < p[None, :]
    return below.sum() / unifs.size


def stat(theta, null_truth, unifs):
    """Per-simulation test statistics at a single grid point.

    ``theta`` is mapped through the logistic function (``expit``) to
    probabilities, and ``simulator`` is vmapped over the leading axis of
    ``unifs``. When ``null_truth[0]`` is false, every statistic is replaced by
    the sentinel ``100.0``, which prevents it from counting as a rejection for any
    ``lam <= 100``.
    """
    probs = jax.scipy.special.expit(theta)
    per_sim = jax.vmap(simulator, in_axes=(None, 0))(probs, unifs)
    return jnp.where(null_truth[0], per_sim, 100.0)


# ``stat`` mapped over a leading grid axis of ``theta`` / ``null_truth``; the same
# ``unifs`` are shared by every grid point.
_stat_over_grid = jax.vmap(stat, in_axes=(0, 0, None))
statv = jax.jit(_stat_over_grid)


@partial(jax.jit, static_argnums=0)
def stat_sum(lam, theta, null_truth, unifs):
    """Count, per grid point, the simulations whose statistic is below ``lam``.

    ``lam`` is a static (hashable) argument, so each distinct value triggers a
    recompile. Returns an integer array with one entry per grid point.
    """
    rejected = _stat_over_grid(theta, null_truth, unifs) < lam
    return rejected.sum(axis=-1)

In [119]:
# Number of grid points (rows of ``theta`` / ``null_truth``).
N = 1024
# Rejection threshold: statistics strictly below ``lam`` count as rejections.
lam = 0.05
theta = np.random.rand(N, 4)
# Every null hypothesis is marked true at every grid point.
null_truth = np.full((N, 3), True)
# NOTE(review): not used by the benchmark cells visible here.
sim_sizes = np.full(N, 1 << 13)

In the cell below, I'm comparing several things:
1. How does the batch size affect runtime? (The computed rejection counts themselves should not depend on the batch size.)
    - Caution: I think this is quite different on GPU versus CPU. JAX on GPU
      pipelines GPU calls as long as we don't block and wait for the results.
      So, smaller batches are acceptable on the GPU.

In [120]:
def simple(unifs_chunk, _1, _2):
    """Baseline: one ``stat_sum`` call over the full grid and every simulation.

    The two trailing parameters are accepted only so that all benchmark variants
    share the ``(unifs, sim_batch_size, grid_batch_size)`` call signature; they
    are ignored here. Blocks on the result so that the caller's timing is accurate.
    """
    rejections = stat_sum(lam, theta, null_truth, unifs_chunk)
    return rejections.block_until_ready()


def batched(unifs_chunk, sim_batch_size, grid_batch_size):
    """Double-batched ``stat_sum`` built from the project's ``batch.batch`` helper.

    Judging by the ``in_axes`` arguments, the inner ``batch.batch`` splits along
    axis 0 of ``unifs`` in chunks of ``sim_batch_size``, and the outer one splits
    ``theta`` / ``null_truth`` along axis 0 in chunks of ``grid_batch_size``.

    NOTE(review): how ``batch.batch`` combines per-chunk outputs is not visible
    here. If it concatenates along axis 0, the per-sim-batch rejection counts
    would be stacked rather than summed; confirm that the result matches
    ``simple``.
    """
    batched_stat_sum = batch.batch(
        batch.batch(stat_sum, sim_batch_size, in_axes=(None, None, None, 0)),
        grid_batch_size,
        in_axes=(None, 0, 0, None),
    )
    out = batched_stat_sum(lam, theta, null_truth, unifs_chunk)
    # Every other variant blocks on its result before returning. Do the same here
    # so JAX's asynchronous dispatch does not make this variant look faster than it
    # is. The result may already be a materialized NumPy array, hence the check.
    if hasattr(out, "block_until_ready"):
        out = out.block_until_ready()
    return out


def late_concat_batch(unifs_chunk, sim_batch_size, grid_batch_size):
    """Double-batched ``stat_sum`` that concatenates partial results once at the end.

    Loops over simulation batches (outer) and grid batches (inner), calling the
    jitted ``stat_sum`` on each tile. Each tile holds rejection counts for one
    slice of simulations, so tiles from different simulation batches must be
    *summed* (not merely concatenated) to get one total per grid point.

    Returns an integer array of shape ``(theta.shape[0],)``, the same as a single
    unbatched ``stat_sum`` call. Blocks until the result is ready.
    """
    n_grid = theta.shape[0]
    rejs = []
    for i in range(0, unifs_chunk.shape[0], sim_batch_size):
        # Slicing clamps at the end, so no explicit min() is needed. The slice
        # only depends on i, so take it once per simulation batch.
        sub_unifs = unifs_chunk[i : i + sim_batch_size]
        for j in range(0, n_grid, grid_batch_size):
            rejs.append(
                stat_sum(
                    lam,
                    theta[j : j + grid_batch_size],
                    null_truth[j : j + grid_batch_size],
                    sub_unifs,
                )
            )
    # rejs is ordered simulation-batch-major, and each simulation batch contributes
    # exactly n_grid entries. Reshape to (n_sim_batches, n_grid) and sum over the
    # simulation batches.
    totals = jnp.concatenate(rejs).reshape(-1, n_grid).sum(axis=0)
    return totals.block_until_ready()


def late_concat_stat_then_sum(unifs_chunk, sim_batch_size, grid_batch_size):
    """Double-batched rejection counting with the threshold/sum outside the jit.

    For each (simulation batch, grid batch) tile, the jitted ``statv`` produces
    per-simulation statistics. Thresholding against ``lam`` and summing are then
    done with separate ``jnp`` calls. This measures whether fusing the summation
    into the jitted call matters.

    Each tile holds counts for one slice of simulations, so tiles from different
    simulation batches are summed to get one total per grid point. Returns an
    integer array of shape ``(theta.shape[0],)``. Blocks until the result is ready.
    """
    n_grid = theta.shape[0]
    rejs = []
    for i in range(0, unifs_chunk.shape[0], sim_batch_size):
        # Slicing clamps at the end, so no explicit min() is needed. The slice
        # only depends on i, so take it once per simulation batch.
        sub_unifs = unifs_chunk[i : i + sim_batch_size]
        for j in range(0, n_grid, grid_batch_size):
            stats = statv(
                theta[j : j + grid_batch_size],
                null_truth[j : j + grid_batch_size],
                sub_unifs,
            )
            rejs.append(jnp.sum(stats < lam, axis=-1))
    # rejs is ordered simulation-batch-major, and each simulation batch contributes
    # exactly n_grid entries. Reshape to (n_sim_batches, n_grid) and sum over the
    # simulation batches.
    totals = jnp.concatenate(rejs).reshape(-1, n_grid).sum(axis=0)
    return totals.block_until_ready()

In [121]:
# Benchmark variants, keyed by the label printed in the timing output. Insertion
# order matters: the sweep cells run the variants in ``fncs.keys()`` order.
fncs = {
    "simple": simple,
    "batched": batched,
    "late_concat_batch": late_concat_batch,
    "late_concat_stat_then_sum": late_concat_stat_then_sum,
}

In [122]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (32768, 1, 4))

run_keys = list(fncs.keys())
for sim_batch_size in [1024, 2048, 4096, 8192, 16384, 32768]:
    print(" ")
    for grid_batch_size in [32, 64, 128, 256, 512, 1024]:
        print(" ")
        # Two passes per configuration; only the second is reported (the first
        # presumably absorbs jit compilation for the new batch shapes).
        for k in range(2):
            for run_key in run_keys:
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start = time.perf_counter()
                result = fncs[run_key](unifs, sim_batch_size, grid_batch_size)
                elapsed = time.perf_counter() - start
                if k >= 1:
                    print(
                        f"{run_key} ({sim_batch_size}, {grid_batch_size}) = {elapsed}"
                    )

 
 
simple (1024, 32) = 0.018233060836791992
batched (1024, 32) = 2.3820278644561768
late_concat_batch (1024, 32) = 0.8279647827148438
late_concat_stat_then_sum (1024, 32) = 0.8557298183441162
 
simple (1024, 64) = 0.017095088958740234
batched (1024, 64) = 1.1565570831298828
late_concat_batch (1024, 64) = 0.4140758514404297
late_concat_stat_then_sum (1024, 64) = 0.4697558879852295
 
simple (1024, 128) = 0.018488168716430664
batched (1024, 128) = 0.6148321628570557
late_concat_batch (1024, 128) = 0.22803211212158203
late_concat_stat_then_sum (1024, 128) = 0.26714372634887695
 
simple (1024, 256) = 0.018473148345947266
batched (1024, 256) = 0.3247189521789551
late_concat_batch (1024, 256) = 0.12849926948547363
late_concat_stat_then_sum (1024, 256) = 0.1658928394317627
 
simple (1024, 512) = 0.01674818992614746
batched (1024, 512) = 0.17656707763671875
late_concat_batch (1024, 512) = 0.08060789108276367
late_concat_stat_then_sum (1024, 512) = 0.10544490814208984
 
simple (1024, 1024) = 0.

In [125]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (16384, 5, 4))

run_keys = list(fncs.keys())
for sim_batch_size in [512, 1024, 2048, 4096, 8192, 16384]:
    print(" ")
    for grid_batch_size in [32, 64, 128, 256, 512, 1024]:
        print(" ")
        # Two passes per configuration; only the second is reported (the first
        # presumably absorbs jit compilation for the new batch shapes).
        for k in range(2):
            for run_key in run_keys:
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start = time.perf_counter()
                result = fncs[run_key](unifs, sim_batch_size, grid_batch_size)
                elapsed = time.perf_counter() - start
                if k >= 1:
                    print(
                        f"{run_key} ({sim_batch_size}, {grid_batch_size}) = {elapsed}"
                    )

 
 
simple (512, 32) = 0.01633596420288086
batched (512, 32) = 2.351898193359375
late_concat_batch (512, 32) = 0.8599920272827148
late_concat_stat_then_sum (512, 32) = 0.9273097515106201
 
simple (512, 64) = 0.016728878021240234
batched (512, 64) = 1.2322278022766113
late_concat_batch (512, 64) = 0.4427659511566162
late_concat_stat_then_sum (512, 64) = 0.49536800384521484
 
simple (512, 128) = 0.016174793243408203
batched (512, 128) = 0.6040279865264893
late_concat_batch (512, 128) = 0.219649076461792
late_concat_stat_then_sum (512, 128) = 0.2473299503326416
 
simple (512, 256) = 0.01509404182434082
batched (512, 256) = 0.3037099838256836
late_concat_batch (512, 256) = 0.12256622314453125
late_concat_stat_then_sum (512, 256) = 0.1425340175628662
 
simple (512, 512) = 0.015145063400268555
batched (512, 512) = 0.16445422172546387
late_concat_batch (512, 512) = 0.07249689102172852
late_concat_stat_then_sum (512, 512) = 0.09273791313171387
 
simple (512, 1024) = 0.016646146774291992
batche

In [126]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (16384, 31, 4))

run_keys = list(fncs.keys())
for sim_batch_size in [512, 1024, 2048, 4096, 8192, 16384]:
    print(" ")
    for grid_batch_size in [32, 64, 128, 256, 512, 1024]:
        print(" ")
        # Two passes per configuration; only the second is reported (the first
        # presumably absorbs jit compilation for the new batch shapes).
        for k in range(2):
            for run_key in run_keys:
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start = time.perf_counter()
                result = fncs[run_key](unifs, sim_batch_size, grid_batch_size)
                elapsed = time.perf_counter() - start
                if k >= 1:
                    print(
                        f"{run_key} ({sim_batch_size}, {grid_batch_size}) = {elapsed}"
                    )

 
 
simple (512, 32) = 0.1613318920135498
batched (512, 32) = 2.598283052444458
late_concat_batch (512, 32) = 1.108854055404663
late_concat_stat_then_sum (512, 32) = 1.1509339809417725
 
simple (512, 64) = 0.17508196830749512
batched (512, 64) = 1.4232990741729736
late_concat_batch (512, 64) = 0.6891357898712158
late_concat_stat_then_sum (512, 64) = 0.7098269462585449
 
simple (512, 128) = 0.1658649444580078
batched (512, 128) = 0.8247759342193604
late_concat_batch (512, 128) = 0.4614226818084717
late_concat_stat_then_sum (512, 128) = 0.48766207695007324
 
simple (512, 256) = 0.17583203315734863
batched (512, 256) = 0.5414307117462158
late_concat_batch (512, 256) = 0.3533968925476074
late_concat_stat_then_sum (512, 256) = 0.36873912811279297
 
simple (512, 512) = 0.2324049472808838
batched (512, 512) = 0.39293789863586426
late_concat_batch (512, 512) = 0.29862284660339355
late_concat_stat_then_sum (512, 512) = 0.3158531188964844
 
simple (512, 1024) = 0.16391801834106445
batched (512, 

In [128]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (16384, 150, 4))

# Untimed call first, so one-off compilation is excluded from the timed call below.
fncs["simple"](unifs, 1, 1)

# perf_counter is monotonic and high-resolution, unlike time.time.
start = time.perf_counter()
fncs["simple"](unifs, 1, 1)
print(time.perf_counter() - start)

run_keys = ["batched", "late_concat_batch", "late_concat_stat_then_sum"]
for sim_batch_size in [2048, 4096, 8192, 16384]:
    print(" ")
    for grid_batch_size in [128, 256, 512, 1024]:
        print(" ")
        # Two passes per configuration; only the second is reported.
        for k in range(2):
            for run_key in run_keys:
                start = time.perf_counter()
                result = fncs[run_key](unifs, sim_batch_size, grid_batch_size)
                elapsed = time.perf_counter() - start
                if k >= 1:
                    print(
                        f"{run_key} ({sim_batch_size}, {grid_batch_size}) = {elapsed}"
                    )

22.354657888412476
 
 
batched (2048, 128) = 5.971358776092529
late_concat_batch (2048, 128) = 5.846956014633179
late_concat_stat_then_sum (2048, 128) = 5.888359069824219
 
batched (2048, 256) = 5.758255958557129
late_concat_batch (2048, 256) = 5.703609943389893
late_concat_stat_then_sum (2048, 256) = 5.749382972717285
 
batched (2048, 512) = 5.766067028045654
late_concat_batch (2048, 512) = 5.766237020492554
late_concat_stat_then_sum (2048, 512) = 5.792371034622192
 
batched (2048, 1024) = 6.105571746826172
late_concat_batch (2048, 1024) = 5.968732833862305
late_concat_stat_then_sum (2048, 1024) = 6.047160625457764
 
 
batched (4096, 128) = 5.980096817016602
late_concat_batch (4096, 128) = 5.837982177734375
late_concat_stat_then_sum (4096, 128) = 5.808861970901489
 
batched (4096, 256) = 5.766739368438721
late_concat_batch (4096, 256) = 5.745783090591431
late_concat_stat_then_sum (4096, 256) = 5.784740924835205
 
batched (4096, 512) = 6.136857986450195
late_concat_batch (4096, 512) = 

KeyboardInterrupt: 

In [129]:
unifs = jax.random.uniform(jax.random.PRNGKey(0), (16384, 150, 4))

run_keys = ["batched", "late_concat_batch", "late_concat_stat_then_sum"]
for sim_batch_size in [256, 512, 1024]:
    print(" ")
    for grid_batch_size in [32, 64, 128, 256]:
        print(" ")
        # Two passes per configuration; only the second is reported (the first
        # presumably absorbs jit compilation for the new batch shapes).
        for k in range(2):
            for run_key in run_keys:
                # perf_counter is monotonic and high-resolution, unlike time.time.
                start = time.perf_counter()
                result = fncs[run_key](unifs, sim_batch_size, grid_batch_size)
                elapsed = time.perf_counter() - start
                if k >= 1:
                    print(
                        f"{run_key} ({sim_batch_size}, {grid_batch_size}) = {elapsed}"
                    )

 
 
batched (256, 32) = 14.0449538230896
late_concat_batch (256, 32) = 11.150763988494873
late_concat_stat_then_sum (256, 32) = 11.255561113357544
 
batched (256, 64) = 8.600068092346191
late_concat_batch (256, 64) = 7.188723087310791
late_concat_stat_then_sum (256, 64) = 7.206992864608765
 
batched (256, 128) = 6.206650972366333
late_concat_batch (256, 128) = 5.513365983963013
late_concat_stat_then_sum (256, 128) = 5.526437044143677
 
batched (256, 256) = 5.517500162124634
late_concat_batch (256, 256) = 5.2202980518341064
late_concat_stat_then_sum (256, 256) = 5.21939492225647
 
 
batched (512, 32) = 8.675849914550781
late_concat_batch (512, 32) = 7.147514820098877
late_concat_stat_then_sum (512, 32) = 7.215574026107788
 
batched (512, 64) = 6.239650011062622
late_concat_batch (512, 64) = 5.512767314910889
late_concat_stat_then_sum (512, 64) = 5.5762341022491455
 
batched (512, 128) = 5.552695035934448
late_concat_batch (512, 128) = 5.230609893798828
late_concat_stat_then_sum (512, 12