In [1]:
import numpy as np
import pandas as pd

from adopy.base import JaxEngineV1

In [2]:
import jax
from jax.scipy.special import expit as inv_logit
from jax.scipy.stats import bernoulli

In [3]:
@jax.jit
def func_logistic_log_lik(choice, stimulus, guess_rate, lapse_rate, threshold, slope):
    """Log-likelihood of a binary choice under a logistic psychometric function.

    The probability of ``choice == 1`` is::

        guess_rate + (1 - guess_rate - lapse_rate) * expit(slope * (stimulus - threshold))

    Returns the Bernoulli log-pmf of ``choice`` evaluated at that probability.
    """
    # Signed distance of the stimulus from the threshold, scaled by the slope.
    scaled_distance = slope * (stimulus - threshold)
    # Share of the probability range left after the guess and lapse rates.
    span = 1 - guess_rate - lapse_rate
    prob_one = guess_rate + span * inv_logit(scaled_distance)
    return bernoulli.logpmf(choice, prob_one)

In [4]:
# Candidate stimulus values: 100 evenly spaced points between
# 20 * log10(0.05) and 20 * log10(400).
grid_design = dict(
    stimulus=np.linspace(20 * np.log10(0.05), 20 * np.log10(400), num=100),
)

# Parameter grid: threshold and slope are searched over; guess and lapse rates
# are each pinned to a single value.
grid_param = dict(
    threshold=np.linspace(20 * np.log10(0.1), 20 * np.log10(200), num=100),
    slope=np.linspace(0, 10, num=101),  # 101 points from 0 to 10 inclusive (slope 0 is kept)
    guess_rate=[0.5],
    lapse_rate=[0.05],
)

# The response is a binary choice.
grid_response = dict(choice=[0, 1])

In [5]:
%%time
# Build the engine from the names of the design, parameter and response
# variables, the jitted log-likelihood function, and the three grids.
# NOTE(review): `parameters` is listed in the same order as the trailing
# positional arguments of func_logistic_log_lik; whether JaxEngineV1 passes
# them by position or by name is not visible here — confirm before reordering.
engine = JaxEngineV1(
    designs=['stimulus'],
    parameters=['guess_rate', 'lapse_rate', 'threshold', 'slope'],
    responses=['choice'],
    model_func=func_logistic_log_lik,
    grid_design=grid_design,
    grid_param=grid_param,
    grid_response=grid_response,
)



CPU times: user 1.23 s, sys: 60.3 ms, total: 1.29 s
Wall time: 1.56 s


In [6]:
# First get_design() call after building the engine; presumably includes
# one-time JAX tracing/compilation overhead (TODO confirm).
%time d = engine.get_design()

CPU times: user 68.9 ms, sys: 4.71 ms, total: 73.6 ms
Wall time: 103 ms


In [7]:
# Repeated calls for a steady-state timing. %timeit executes the statement in
# its own scope, so the `d` assigned here does not rebind the notebook-level `d`.
%timeit d = engine.get_design()

1.09 ms ± 291 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
# First update() call, passing the design `d` and the value 1 (presumably the
# observed `choice` — confirm against JaxEngineV1.update).
%time engine.update(d, 1)

CPU times: user 273 ms, sys: 16.4 ms, total: 289 ms
Wall time: 319 ms


In [9]:
# NOTE(review): %timeit invokes engine.update(d, 1) many times (7 runs x 100
# loops per the output below). If update() mutates the engine's state, the
# engine afterwards reflects hundreds of identical observations — confirm, and
# rebuild the engine before using it for anything but timing.
%timeit engine.update(d, 1)

12 ms ± 266 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Timings once compiled (warm calls)

In [10]:
# Single timed get_design() call, run after the earlier calls above.
%time d = engine.get_design()

CPU times: user 1.93 ms, sys: 838 µs, total: 2.77 ms
Wall time: 2.5 ms


In [11]:
# Single timed update() call with the design from the previous cell and the value 1.
%time engine.update(d, 1)

CPU times: user 21.7 ms, sys: 3.41 ms, total: 25.1 ms
Wall time: 26.4 ms
