In [1]:
from confirm.outlaw.nb_util import setup_nb

setup_nb()

import io

import numpy as np
import pandas as pd

import confirm.imprint as ip

In [2]:
import jax
import jax.numpy as jnp


@jax.jit
def _sim(samples, theta, null_truth):
    p = jax.scipy.special.expit(theta)
    stats = jnp.sum(samples[None, :] < p[:, None], axis=2) / samples.shape[1]
    return jnp.where(
        null_truth[:, None, 0],
        1 - stats,
        jnp.inf,
    )


def unifs(seed, *, shape, dtype):
    return jax.random.uniform(jax.random.PRNGKey(seed), shape=shape, dtype=dtype)


class Binom1D:
    def __init__(self, cache, seed, max_K, *, n):
        self.family = "binomial"
        self.family_params = {"n": n}
        self.dtype = jnp.float32

        # cache_key = f'samples-{seed}-{max_K}-{n}-{self.dtype}'
        # if cache_key in cache:
        #     self.samples = cache[cache_key]
        # else:
        #     key = jax.random.PRNGKey(seed)
        #     self.samples = jax.random.uniform(key, shape=(max_K, n), dtype=self.dtype)
        #     cache.update({cache_key: self.samples})
        #
        self.samples = cache(unifs)(seed, shape=(max_K, n), dtype=self.dtype)

    def sim_batch(self, begin_sim, end_sim, theta, null_truth, detailed=False):
        return _sim(self.samples[begin_sim:end_sim], theta, null_truth)

In [3]:
class Cache:
    def __init__(self):
        self._cache = {}

    def __call__(self, func, safety=2, serialize=False):
        def wrapper(*args, **kwargs):
            key = (func, args, tuple(kwargs.items()))
            if key in self._cache:
                return self._cache[key]
            else:
                result = func(*args, **kwargs)
                self._cache[key] = result
                return result

        return wrapper

In [4]:
import os
import hashlib
import confirm


def hash_confirm_code():
    confirm_path = os.path.dirname(confirm.__file__)
    hashes = []
    hash_md5 = hashlib.md5()
    for path, subdirs, files in os.walk(confirm_path):
        for fn in files:
            if not fn.endswith(".py"):
                continue
            with open(os.path.join(path, fn), "rb") as f:
                for chunk in iter(lambda: f.read(4096), b""):
                    hash_md5.update(chunk)
    hash_md5.hexdigest()

In [8]:
import glob

glob.glob("../../confirm/**/*.py")

['../../confirm/lewislib/grid.py',
 '../../confirm/lewislib/__init__.py',
 '../../confirm/lewislib/lewis.py',
 '../../confirm/lewislib/jax_wrappers.py',
 '../../confirm/lewislib/table.py',
 '../../confirm/berrylib/util.py',
 '../../confirm/berrylib/quadrature.py',
 '../../confirm/berrylib/constants.py',
 '../../confirm/berrylib/__init__.py',
 '../../confirm/berrylib/fast_inla.py',
 '../../confirm/berrylib/dirty_bayes.py',
 '../../confirm/berrylib/imprint.py',
 '../../confirm/berrylib/mcmc.py',
 '../../confirm/berrylib/batch_run.py',
 '../../confirm/berrylib/fast_math.py',
 '../../confirm/berrylib/binomial.py',
 '../../confirm/models/ztest.py',
 '../../confirm/models/__init__.py',
 '../../confirm/models/fisher_exact.py',
 '../../confirm/models/binom1d.py',
 '../../confirm/outlaw/quad.py',
 '../../confirm/outlaw/numpyro_interface.py',
 '../../confirm/outlaw/inla.py',
 '../../confirm/outlaw/nb_util.py',
 '../../confirm/outlaw/__init__.py',
 '../../confirm/outlaw/berry.py',
 '../../confirm

In [4]:
cache = Cache()

In [5]:
%%time
unifs(0, shape=(10, 10), dtype=jnp.float32)

CPU times: user 97.6 ms, sys: 6.93 ms, total: 105 ms
Wall time: 103 ms


DeviceArray([[0.02379167, 0.8527204 , 0.8132185 , 0.5140263 , 0.17172801, 0.8026866 , 0.5124631 ,
              0.34838438, 0.50526905, 0.3370521 ],
             [0.10868239, 0.10520637, 0.83827364, 0.78986526, 0.34059846, 0.8349273 , 0.24575627,
              0.21387374, 0.02423227, 0.5617423 ],
             [0.28066766, 0.94366455, 0.61214995, 0.7383388 , 0.52419806, 0.65466726, 0.41012764,
              0.24028647, 0.74443066, 0.03544927],
             [0.851014  , 0.02434528, 0.47239733, 0.72706807, 0.35055435, 0.6274171 , 0.61077535,
              0.06525731, 0.8091929 , 0.21307838],
             [0.6465323 , 0.3245015 , 0.5538883 , 0.8849807 , 0.9591211 , 0.83856845, 0.48919427,
              0.11810577, 0.16933143, 0.83657074],
             [0.587505  , 0.6867087 , 0.95522237, 0.5797727 , 0.28024232, 0.34749162, 0.5199702 ,
              0.9811766 , 0.5645981 , 0.2446456 ],
             [0.68722725, 0.9616587 , 0.480047  , 0.88953114, 0.7083205 , 0.948612  , 0.67764974,
        

In [32]:
t.results().write_results(coverdir=".")


functions called:
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/abc.py, modulename: abc, funcname: ABCMeta.__instancecheck__
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/collections/__init__.py, modulename: __init__, funcname: _make
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/collections/__init__.py, modulename: __init__, funcname: _replace
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/contextlib.py, modulename: contextlib, funcname: _GeneratorContextManager.__enter__
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/contextlib.py, modulename: contextlib, funcname: _GeneratorContextManager.__exit__
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/contextlib.py, modulename: contextlib, funcname: _GeneratorContextManagerBase.__init__
filename: /Users/tbent/.mambaforge/envs/confirm/lib/python3.10/contextlib.py, modulename: contextlib, funcname: helper
filename: /Users/tbent/.mambaforge/envs/

In [27]:
%%timeit
model = Binom1D(cache, 0, 1000000, n=100)

197 ms ± 4.66 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [7]:
cache._cache

{(<function __main__.unifs(seed, *, shape, dtype)>,
  (0,),
  (('shape', (100, 1000)),
   ('dtype',
    jax.numpy.float32))): DeviceArray([[0.28537035, 0.32794476, 0.7018368 , 0.99795973, 0.80373716, 0.8437431 , 0.18751788,
               0.5537597 , 0.35716057, 0.633845  , ..., 0.9791453 , 0.98088837, 0.15622127,
               0.6792512 , 0.9664891 , 0.65000224, 0.5286546 , 0.06554615, 0.6344644 , 0.16242659],
              [0.60463357, 0.6501305 , 0.34219182, 0.62337744, 0.5855551 , 0.823779  , 0.17784536,
               0.37527883, 0.46417534, 0.9261869 , ..., 0.53743875, 0.526199  , 0.73024786,
               0.64630795, 0.3048563 , 0.8993064 , 0.27745914, 0.583395  , 0.1201272 , 0.26367867],
              [0.5007554 , 0.5009476 , 0.45099878, 0.18468988, 0.15988815, 0.35168993, 0.7108166 ,
               0.85837865, 0.76702344, 0.06741774, ..., 0.6344186 , 0.5593307 , 0.5593804 ,
               0.76436687, 0.85874534, 0.10786879, 0.9393954 , 0.5567074 , 0.04266787, 0.17117405],
  

## old

In [3]:
g = ip.cartesian_grid(theta_min=[-1], theta_max=[1], null_hypos=[ip.hypo("x0 < 0")])
rej_df = ip.validate(Binom1D, g, 0.5, K=2**10, model_kwargs={"n": 100})
rej_df

TypeError: Binom1D.__init__() missing 1 required positional argument: 'max_K'

In [26]:
model = Binom1D(0, 2**18, n=100)
db = ip.db.DuckDB.connect()

In [27]:
%%time
samples = pd.DataFrame(model.samples)

CPU times: user 106 µs, sys: 11 µs, total: 117 µs
Wall time: 116 µs


In [35]:
%%timeit
memfile = io.BytesIO()
np.save(memfile, samples.values)

14.3 ms ± 169 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [39]:
%%timeit
memfile.seek(0)
s2 = np.load(memfile)

12.5 ms ± 143 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [37]:
np.all(s2 == samples.values)

True

In [30]:
%%timeit
db.con.execute("drop table samples")
db.con.execute("create table samples as select * from samples")

125 ms ± 1.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
