In [None]:
import ssms
def _simulate_rts_choices(theta_np, model, n_samples, delta_t, max_t):
    """Run one ssms simulation and return `rts` and `choices` joined on the last axis (float32 numpy)."""
    sim_out = ssms.basic_simulators.simulator(theta = theta_np,
                                              model = model,
                                              n_samples = n_samples,
                                              delta_t = delta_t,
                                              max_t = max_t,
                                              no_noise = False,
                                              bin_dim = None,
                                              bin_pointwise = False)
    return np.concatenate([sim_out['rts'].astype(np.float32),
                           sim_out['choices'].astype(np.float32)], axis = -1)


def sim_wrap(theta = torch.zeros(0), model = 'ddm', n_samples = 1, output_format = 'torch',
             delta_t = 0.001, max_t = 20.0):
    """Simulate an SSM via `ssms` and return reaction times and choices as a torch tensor.

    Args:
        theta: torch tensor of model parameters. All size-1 dimensions are squeezed away
            first; the squeezed tensor must be 1-, 2- or 3-dimensional. A 3-D tensor is
            treated as a batch along dim 0 and simulated one slice at a time.
        model: ssms model name, passed straight to the simulator.
        n_samples: `n_samples` argument forwarded to the simulator on every call.
        output_format: currently ignored — the result is always a torch tensor.
        delta_t: simulator time step (default 0.001, the previously hard-coded value).
        max_t: simulator maximum time (default 20.0, the previously hard-coded value).

    Returns:
        float32 torch tensor holding the simulator's `rts` and `choices` concatenated along
        the last axis; for 3-D input the per-slice results are stacked along a new dim 0.
        NOTE(review): the exact shape of `rts`/`choices` is defined by ssms — confirm.

    Raises:
        NotImplementedError: if the squeezed theta is not 1-, 2- or 3-dimensional.
    """
    # .squeeze() drops *every* size-1 dim, so e.g. (n, 1, k) becomes (n, k).
    theta = theta.squeeze()

    def to_numpy(t):
        # detach/cpu so tensors that require grad or live on a GPU can still be converted
        return t.detach().cpu().numpy().astype(np.float32)

    if theta.dim() == 3:
        per_slice = [_simulate_rts_choices(to_numpy(theta[i, ...]), model, n_samples, delta_t, max_t)
                     for i in range(theta.shape[0])]
        return torch.tensor(np.stack(per_slice))

    if theta.dim() in (1, 2):
        return torch.tensor(_simulate_rts_choices(to_numpy(theta), model, n_samples, delta_t, max_t))

    raise NotImplementedError("theta should be of dimensionality 1, 2 or 3 after squeezing")
        
            
def model_maker(model = 'ddm'):
    """Build a pyro model function for the ssms model named `model`.

    The returned `ssm_model(num_trials, data, network)`:
      * draws one sample site per name in `ssms.config.model_config[model]['params']`, each
        with a Uniform prior bounded by the same index of `param_bounds[0]` (lower) and
        `param_bounds[1]` (upper);
      * stacks those draws along the last dim and observes `data` under
        `SSMDist(params, num_trials, network, model)` inside `pyro.plate("data", num_trials)`;
      * returns the value of the "obs" sample site.
    """
    model_config = ssms.config.model_config[model]

    def ssm_model(num_trials, data, network):
        lower_bounds = model_config['param_bounds'][0]
        upper_bounds = model_config['param_bounds'][1]
        # enumerate yields each parameter's position directly; list.index() rescanned the
        # list every iteration and would return the first match for a repeated name.
        param_list = [pyro.sample(param, dist.Uniform(lower_bounds[idx], upper_bounds[idx]))
                      for idx, param in enumerate(model_config['params'])]

        with pyro.plate("data", num_trials):
            return pyro.sample("obs",
                               SSMDist(torch.stack(param_list, dim = -1),
                                       num_trials,
                                       network,
                                       model),
                               obs = data)
    return ssm_model

In [None]:
from scipy.stats import truncnorm
# Parameter recovery: draw parameter vectors from a truncated normal centred on the
# midpoint of each parameter's allowed range, then plot their distributions.
# NOTE(review): `model_config` (module level) and `plt` are not defined in this part of the
# notebook — presumably set in an earlier cell (e.g. ssms.config.model_config[model]); confirm.
import numpy as np

n_samples = 1000 # samples requested from the simulator per sim_wrap call (see the next cell)
# NOTE(review): this was previously described as n_samples_go + n_samples_nogo = 2 * n_samples_go;
# nothing visible here splits go / no-go trials — confirm whether that still applies.
n_parameter_samples = 200 # number of parameter vectors to sample and run the recovery over
std_denominator = 6 # truncnorm std = (upper - lower) / std_denominator
param_sampling_seed = 0 # seed for the parameter draws, so the recovery set is reproducible

# One generator shared by all parameters: seeding each rvs call with the same int would
# give every parameter identical standardized draws.
param_rng = np.random.default_rng(param_sampling_seed)

parameter_samples_dict = {}

fig, ax = plt.subplots()
for idx, param in enumerate(model_config["params"]):
    myclip_a = model_config["param_bounds"][0][idx]
    myclip_b = model_config["param_bounds"][1][idx]

    my_mean = myclip_a + (1/2) * (myclip_b - myclip_a)  # midpoint of the allowed range
    my_std = (myclip_b - myclip_a) / std_denominator

    # truncnorm expects its clip points in standard-deviation units relative to loc
    a, b = (myclip_a - my_mean) / my_std, (myclip_b - my_mean) / my_std

    parameter_samples_dict[param] = truncnorm.rvs(a,
                                                  b,
                                                  loc = my_mean,
                                                  scale = my_std,
                                                  size = n_parameter_samples,
                                                  random_state = param_rng)

    ax.hist(parameter_samples_dict[param], alpha = 0.5, histtype = 'step', label = param)

ax.set(title = "Sampled parameter values", xlabel = "parameter value", ylabel = "count")
ax.legend();

In [None]:
# Smoke-test the simulator on the first sampled parameter vector (one value per parameter,
# in the dict's insertion order).
# NOTE(review): `model` is not defined in this part of the notebook — presumably set in an
# earlier cell; confirm.
out = sim_wrap(
    theta = torch.tensor(np.asarray([samples[0] for samples in parameter_samples_dict.values()],
                                    dtype = np.float32)),
    model = model,
    n_samples = n_samples,
)

In [None]:
def hierarchical_noncentered_ddm_model(num_subjects, num_trials, data):
    """Non-centred hierarchical DDM written against the `npy` API (presumably numpyro — confirm).

    For each DDM parameter (v, a, z, t):
      * a group mean ``*_mu_mu`` gets a Uniform prior and a group scale ``*_mu_std`` a
        HalfNormal(1.) prior;
      * inside the ``subjects`` plate each subject draws a standard-normal offset
        ``*_subj_z``, and ``*_subj = mu_mu + subj_z * mu_std`` is recorded as a
        deterministic site.
    ``data`` is observed under ``MyDDMh(v_subj, a_subj, z_subj, t_subj)`` inside a nested
    ``data`` plate of size ``num_trials``; the value of that ``obs`` site is returned.

    NOTE(review): nothing here keeps a_subj / t_subj positive or z_subj inside (0, 1);
    presumably MyDDMh tolerates or rejects such values — confirm.
    """
    # One row per DDM parameter, in the positional order the values are passed to MyDDMh:
    # (group-mean site, group-scale site, offset site, subject site, Uniform bounds)
    site_table = (
        ("v_mu_mu", "v_mu_std", "v_subj_z", "v_subj", (-3, 3)),
        ("a_mu_mu", "a_mu_std", "a_subj_z", "a_subj", (0.3, 2.5)),
        ("z_mu_mu", "z_mu_std", "z_subj_z", "z_subj", (0.1, 0.9)),
        ("t_mu_mu", "t_mu_std", "t_subj_z", "t_subj", (0.0, 2.0)),
    )

    # Group level: for each parameter sample its mean, then its scale (site order matters
    # for PRNG key consumption, so mean/scale stay interleaved per parameter).
    group_level = []
    for mean_site, scale_site, _, _, (low, high) in site_table:
        group_mean = npy.sample(mean_site, dist.Uniform(low, high))
        group_scale = npy.sample(scale_site, dist.HalfNormal(1.))
        group_level.append((group_mean, group_scale))

    with npy.plate("subjects", num_subjects):
        # Every standard-normal offset is drawn before any subject value is formed.
        offsets = [npy.sample(row[2], dist.Normal(0.0, 1.0)) for row in site_table]

        subject_values = []
        for row, (group_mean, group_scale), offset in zip(site_table, group_level, offsets):
            subject_values.append(npy.deterministic(row[3], group_mean + (offset * group_scale)))

        with npy.plate("data", num_trials):
            return npy.sample("obs", MyDDMh(*subject_values), obs = data)