In [1]:
import numpy as np
import jax
from jax import numpy as jnp
from jax import scipy as jsp
from adopy.base import GridSpace, Task, Model, Engine



In [2]:
def make_grid_design():
    """Build the design grid for the risk/ambiguity choice task.

    Returns a GridSpace with two tuple-keyed design entries:

    - ('p_var', 'a_var'): 9 risky rows (a_var held at 0) followed by
      6 ambiguous rows (p_var held at 0.5).
    - ('r_var', 'r_fix'): every pair from a 10-step reward ladder with
      r_var strictly greater than r_fix (45 rows).
    """
    probs = [.05, .10, .15, .20, .25, .30, .35, .40, .45]
    ambiguities = [.125, .25, .375, .5, .625, .75]

    # Risky rows fix ambiguity at 0; ambiguous rows fix probability at 0.5.
    prob_ambig_rows = [[p, 0] for p in probs]
    prob_ambig_rows += [[0.5, a] for a in ambiguities]

    ladder = [10, 15, 21, 31, 45, 66, 97, 141, 206, 300]
    # r_var is the outer loop so rows come out grouped by variable reward.
    reward_rows = [[hi, lo] for hi in ladder for lo in ladder if hi > lo]

    return GridSpace({
        ('p_var', 'a_var'): np.array(prob_ambig_rows),
        ('r_var', 'r_fix'): np.array(reward_rows),
    })


def make_grid_param():
    """Return a GridSpace with 11 evenly spaced values per model parameter.

    alpha spans [0, 3], beta spans [-3, 3], and gamma spans [0, 5].
    """
    n_points = 11
    return GridSpace({
        'alpha': np.linspace(0, 3, n_points),
        'beta': np.linspace(-3, 3, n_points),
        'gamma': np.linspace(0, 5, n_points),
    })


def make_grid_response():
    """Return the response grid: one binary variable, ``choice``, taking 0 or 1."""
    choices = [0, 1]
    return GridSpace(dict(choice=choices))


class ModelLinear(Model):
    """Linear subjective-value model for choice under risk and ambiguity.

    Subjective values:
        sv_var = (p_var - beta * a_var / 2) * r_var ** alpha
        sv_fix = 0.5 * r_fix ** alpha
    The probability of ``choice == 1`` is ``sigmoid(gamma * (sv_var - sv_fix))``.
    It increases with the variable option's subjective value.
    """

    @staticmethod
    @jax.jit
    def compute(choice, p_var, a_var, r_var, r_fix, alpha, beta, gamma):
        """Return the log-likelihood of ``choice`` under the linear model.

        All arguments are broadcast elementwise (scalars or arrays).

        Parameters
        ----------
        choice : binary response; 1 has probability sigmoid(logit), 0 has
            probability 1 - sigmoid(logit). Any other value gets -inf.
        p_var, a_var : probability and ambiguity level of the variable option.
        r_var, r_fix : rewards of the variable and fixed options.
        alpha : exponent applied to rewards.
        beta : weight of ambiguity (a_var / 2) subtracted from p_var.
        gamma : inverse temperature scaling the value difference.

        Returns
        -------
        Log-probability of ``choice``, same broadcast shape as the inputs.
        """
        sv_var = (p_var - beta * jnp.divide(a_var, 2)) * jnp.power(r_var, alpha)
        sv_fix = 0.5 * jnp.power(r_fix, alpha)
        logit = gamma * (sv_var - sv_fix)
        # With rewards up to 300, alpha up to 3 and gamma up to 5, |logit| can
        # reach ~1e8. Then exp() overflows, p_obs saturates to exactly 0 or 1,
        # and log(p_obs) becomes -inf. log_sigmoid evaluates
        # log P(choice=1) = log sigmoid(logit) and
        # log P(choice=0) = log sigmoid(-logit) stably without forming p_obs.
        log_p1 = jax.nn.log_sigmoid(logit)
        log_p0 = jax.nn.log_sigmoid(-logit)
        # Match bernoulli.logpmf semantics: responses outside {0, 1} have zero mass.
        return jnp.where(choice == 1, log_p1,
                         jnp.where(choice == 0, log_p0, -jnp.inf))

In [3]:
# Build the design, parameter, and response grids, then wire them into the
# task definition and the linear model.
grid_design, grid_param, grid_response = (
    make_grid_design(), make_grid_param(), make_grid_response())

task = Task(
    name='Choice under risk and ambiguity',
    designs=['p_var', 'a_var', 'r_var', 'r_fix'],
    responses=['choice'],
    grid_design=grid_design,
    grid_response=grid_response,
)

model = ModelLinear(
    name='Linear model',
    task=task,
    params=['alpha', 'beta', 'gamma'],
    grid_param=grid_param,
)



In [4]:
%time
engine = Engine(task=task, model=model)

CPU times: user 1 µs, sys: 0 ns, total: 1 µs
Wall time: 3.1 µs


In [5]:
%%timeit
d = engine.get_design()
y = np.random.randint(0, 1)
engine.update(d, {'choice': y})

7.62 ms ± 104 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
