In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from numba import njit
from tqdm.notebook import tqdm
import pystan
import stan_utility

In [None]:
import multiprocessing
# PyStan 2 samples chains in parallel (n_jobs > 1) via multiprocessing and
# presumably needs the "fork" start method (macOS defaults to "spawn" since
# Python 3.8). force=True keeps this cell idempotent: without it, re-running
# the cell raises RuntimeError("context has already been set").
multiprocessing.set_start_method("fork", force=True)

## Simulators

### Static

In [None]:
def dynamic_prior(batch_size, rng=None):
    """
    Draws ``batch_size`` parameter vectors from the diffusion model prior.

    Parameters
    ----------
    batch_size : int
        Number of independent prior draws (rows).
    rng : numpy.random.Generator, optional
        Source of randomness. If None (default), NumPy's global random state
        is used, exactly as before. Pass a seeded Generator for reproducible
        draws.

    Returns
    -------
    np.ndarray of shape (batch_size, 12)
        Columns 0-3:  drift rates v_1..v_4   ~ Gamma(shape=2.5, rate=1.5)
        Column 4:     boundary separation a  ~ Gamma(shape=4.0, rate=3.0)
        Column 5:     non-decision time ndt  ~ Gamma(shape=1.5, rate=5.0)
        Columns 6-11: hyper-parameters       ~ Uniform(0.01, 0.1)
                      (not used by the static simulator; presumably step
                      SDs for a dynamic variant - confirm)

    The Gamma priors match the Stan model's ``gamma(shape, rate)`` priors;
    NumPy parameterizes the Gamma by scale = 1 / rate.
    """
    gen = np.random if rng is None else rng
    v = gen.gamma(2.5, 1 / 1.5, (batch_size, 4))
    a = gen.gamma(4.0, 1 / 3.0, batch_size)
    ndt = gen.gamma(1.5, 1 / 5.0, batch_size)
    hyper_params = gen.uniform(0.01, 0.1, (batch_size, 6))

    return np.c_[v, a, ndt, hyper_params]


@njit
def context_gen(batch_size, n_obs, n_conditions=4):
    """
    Generates a randomly ordered condition label for every trial of every data set.

    Parameters
    ----------
    batch_size : int
        Number of data sets (rows).
    n_obs : int
        Number of trials per data set (columns). It no longer has to be a
        multiple of ``n_conditions``. The old ``np.repeat`` version produced
        fewer than ``n_obs`` labels in that case, and the row assignment failed.
    n_conditions : int, default 4
        Number of conditions. The static simulator indexes four drift rates,
        so keep the default unless the simulator is extended.

    Returns
    -------
    np.ndarray of int32, shape (batch_size, n_obs)
        Labels in 1..n_conditions. Each row is an independent shuffle of a
        balanced label vector, so condition counts differ by at most one.
    """
    context = np.zeros((batch_size, n_obs), dtype=np.int32)
    # Cycle 1, 2, ..., n_conditions, 1, 2, ... to get exactly n_obs labels.
    x = (np.arange(n_obs) % n_conditions + 1).astype(np.int32)
    for i in range(batch_size):
        # Shuffle in place. Each row is copied out before the next shuffle,
        # so the rows are independent permutations.
        np.random.shuffle(x)
        context[i] = x
    return context


@njit
def diffusion_trial(v, a, ndt, zr=0.5, dt=0.001, s=1.0, max_iter=1e4):
    """
    Simulates a single reaction time from a simple drift-diffusion process.

    Euler-Maruyama discretization of dX = v dt + s dW, starting at
    X_0 = zr * a and absorbed at the boundaries 0 (lower) and a (upper).

    Parameters
    ----------
    v : float
        Drift rate.
    a : float
        Boundary separation.
    ndt : float
        Non-decision time, added to the simulated decision time.
    zr : float
        Relative starting point (fraction of ``a``).
    dt : float
        Time step.
    s : float
        Noise standard deviation per unit time (diffusion coefficient).
        The per-step noise SD is s * sqrt(dt). The previous sqrt(dt * s)
        agreed with this only at the default s = 1.
    max_iter : float
        Maximum number of Euler steps. Previously accepted but ignored, so a
        trial could loop indefinitely.

    Returns
    -------
    float
        Signed RT: +(decision time + ndt) if the upper boundary is reached,
        -(decision time + ndt) if the lower boundary is reached. If
        ``max_iter`` steps pass without absorption, the trial is truncated at
        ``max_iter * dt`` and reported with a positive sign, because x is
        still above 0.

    NOTE(review): under @njit, np.random draws from Numba's own generator
    state. np.random.seed called from plain Python may not make this
    reproducible; confirm against the Numba docs.
    """
    n_iter = 0
    x = a * zr
    c = s * np.sqrt(dt)

    while x > 0 and x < a and n_iter < max_iter:

        # DDM equation
        x += v * dt + c * np.random.randn()

        n_iter += 1

    rt = n_iter * dt
    return rt + ndt if x > 0 else -(rt + ndt)


@njit
def static_diffusion_process(prior_samples, context, n_obs):
    """
    Simulates one data set of ``n_obs`` trials from a static diffusion model.

    Parameters
    ----------
    prior_samples : np.ndarray, shape (12,)
        One row of ``dynamic_prior``. The first half is [v_1..v_4, a, ndt].
        The second half (hyper-parameters) is ignored by the static model.
    context : np.ndarray of int, shape (n_obs,)
        Condition code (1..4) of each trial. It selects the drift rate.
    n_obs : int
        Number of trials to simulate.

    Returns
    -------
    (np.ndarray of shape (n_obs, 2), np.ndarray of shape (6,))
        Column 0 holds the signed RTs and column 1 the condition codes,
        followed by the parameter half that was used.
    """
    params_t, _hyper_unused = np.split(prior_samples, 2, axis=-1)

    rt = np.zeros(n_obs)
    for trial in range(n_obs):
        # Condition codes are 1-based; drift rates live at indices 0..3.
        drift = params_t[context[trial] - 1]
        rt[trial] = diffusion_trial(drift, params_t[4], params_t[5])

    return np.vstack((rt, context)).T, params_t
    

@njit
def static_batch_simulator(prior_samples, n_obs):
    """
    Simulates a batch of static diffusion-model data sets.

    Parameters
    ----------
    prior_samples : np.ndarray, shape (n_sets, 12)
        Prior draws, one row per data set (see ``dynamic_prior``).
    n_obs : int
        Trials per data set.

    Returns
    -------
    sim_data : np.ndarray, shape (n_sets, n_obs, 2)
        Signed RT and condition code for every trial.
    theta : np.ndarray, shape (n_sets, n_obs, 6)
        The static parameters of each data set, repeated for every trial.
    """
    n_sets = prior_samples.shape[0]
    # Draw all condition labels up front, before any trial is simulated.
    context = context_gen(n_sets, n_obs)
    sim_data = np.zeros((n_sets, n_obs, 2))
    theta = np.zeros((n_sets, n_obs, 6))

    for b in range(n_sets):
        data_b, params_b = static_diffusion_process(prior_samples[b],
                                                    context[b],
                                                    n_obs)
        sim_data[b] = data_b
        # params_b has shape (6,) and is broadcast over all n_obs trials.
        theta[b] = params_b
    return sim_data, theta

In [None]:
# Trials per simulated data set (800 splits evenly over the four conditions)
# and number of data sets per simulated batch.
N_OBS, batch_size = 800, 1

In [None]:
# Smoke test: simulate `batch_size` data sets and check the shape of the
# per-trial parameter array, which should be (batch_size, N_OBS, 6).
prior_draws = dynamic_prior(batch_size)
sim_data, params_t = static_batch_simulator(prior_draws, N_OBS)
params_t.shape

## Stan modeling

In [None]:
# Stan program for the static diffusion model. The priors mirror dynamic_prior:
# Stan's gamma is (shape, rate), while NumPy's is (shape, scale = 1 / rate).
stan_model = """
data {
  int<lower=0> N;
  real<lower=0> rt[N];
  int<lower=0,upper=1> correct[N];
  int<lower=1,upper=4> context[N];
}

parameters {
  real<lower=0> v[4];
  real<lower=0> a;
  // The Wiener density requires ndt < rt on every trial. Bound ndt by the
  // fastest observed RT so the sampler never proposes an invalid value.
  // The bound is data-only, so the posterior is unchanged.
  real<lower=0, upper=min(rt)> ndt;
}

model {
  // Priors
  v ~ gamma(2.5, 1.5);
  a ~ gamma(4.0, 3.0);
  ndt ~ gamma(1.5, 5.0);

  for (n in 1:N) {
     if (correct[n] == 1) {
        // upper-boundary response
        rt[n] ~ wiener(a, ndt, 0.5, v[context[n]]);
     }
     else {
        // lower-boundary response: mirrored process (bias 1 - z, drift -v)
        rt[n] ~ wiener(a, ndt, 1 - 0.5, -v[context[n]]);
     }
  }
}
"""

In [None]:
# compile stan model
# `sm` is the compiled model used by loop_stan and by the single-fit cell below.
sm = pystan.StanModel(model_code=stan_model)

In [None]:
def to_stan(sim_data):
    """
    Converts one simulated data set into the data dict the Stan model expects.

    Parameters
    ----------
    sim_data : np.ndarray, shape (n_obs, 2)
        Column 0 holds signed RTs (negative = lower boundary). Column 1 holds
        condition codes.

    Returns
    -------
    dict
        'rt'      : float32 absolute RTs
        'correct' : int32, 1 where the signed RT is >= 0, else 0
        'context' : int32 condition codes
        'N'       : number of trials
    """
    signed_rt = sim_data[:, 0]
    return {
        'rt': np.abs(signed_rt).astype(np.float32),
        'correct': np.where(signed_rt >= 0, 1, 0).astype(np.int32),
        'context': sim_data[:, 1].astype(np.int32),
        'N': signed_rt.shape[0],
    }

def loop_stan(data, verbose=True, model=None):
    """
    Loop through data and obtain posteriors.

    Parameters
    ----------
    data : np.ndarray
        Batch of simulated data sets indexed along axis 0 (e.g. ``sim_data``
        of shape (batch, n_obs, 2)). Each entry is converted with ``to_stan``.
    verbose : bool
        Print a progress line after each data set.
    model : compiled pystan model, optional
        Model to sample from. Defaults to the notebook-level ``sm``.

    Returns
    -------
    list
        One ``fit.extract(permuted=True)`` result per data set.
    """
    if model is None:
        model = sm
    n_chains = 4

    stan_post_samples = []
    for idx in range(data.shape[0]):
        data_i = to_stan(data[idx])
        # to_stan exposes absolute RTs under 'rt'. The old 'x1' / 'x2' keys do
        # not exist and raised KeyError. Start ndt below the fastest RT because
        # Stan's wiener density requires ndt < rt.
        init = {'ndt': float(data_i['rt'].min()) * 0.75}
        fit = model.sampling(data=data_i,
                             iter=2000, chains=n_chains, n_jobs=4,
                             init=[init] * n_chains,
                             control=dict(adapt_delta=0.99, max_treedepth=15))
        stan_post_samples.append(fit.extract(permuted=True))
        if verbose:
            print(f'Finished estimating data set {idx+1}...')
    return stan_post_samples

In [None]:
# Simulate one data set and fit it with Stan.
prior_draws = dynamic_prior(1)
sim_data, params_t = static_batch_simulator(prior_draws, N_OBS)

stan_data = to_stan(sim_data[0])
# Start every chain with ndt below the fastest RT (Stan's wiener density
# requires ndt < rt).
ndt_init = stan_data['rt'].min() * 0.75
init = {'ndt': ndt_init}

fit = sm.sampling(data=stan_data,
                  iter=2000, chains=4, n_jobs=4, init=[init] * 4,
                  control=dict(adapt_delta=0.99, max_treedepth=15))

In [None]:
# Posterior means of the four drift rates v_1..v_4.
samples = fit.extract(permuted=True)
samples['v'].mean(axis=0)

In [None]:
# True drift rates of the simulated data set. The parameters are static, so
# the first trial row is representative.
params_t[0, 0, :4]

In [None]:
# Text summary of the fitted model (pystan's per-parameter posterior summary).
print(fit)

In [None]:
# Run the sampler diagnostics from the stan_utility helper module. Presumably
# these cover divergences, tree depth, E-BFMI, n_eff and Rhat; confirm in
# stan_utility.
stan_utility.check_all_diagnostics(fit)