In [16]:
import torch
import pyro
import pyro.distributions as dist
import argparse
import matplotlib
import matplotlib.pyplot as plt
from torch.distributions.constraints import positive

import logging
import os

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle
import scipy
from scipy.stats import truncnorm
import cloudpickle as cpickle

import pyro
import ssms
import lanfactory
torch.set_default_dtype(torch.float32)

import lanfactory
from copy import deepcopy

smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.1')

pyro.enable_validation(True)
pyro.set_rng_seed(9)
logging.basicConfig(format='%(message)s', level=logging.INFO)

import math
from numbers import Real
from numbers import Number

import torch
from torch.distributions import constraints
from torch.distributions.exp_family import ExponentialFamily
from torch.distributions.utils import _standard_normal, broadcast_all

import jax
#jax.config.update('jax_platform_name', 'cpu')
from jax import numpy as jnp

In [6]:
import platform
import psutil
p = psutil.subprocess.run(['lscpu'], capture_output=True, text=True)


In [17]:
test = pickle.load(open('/gpfs/data/frankmj/afengler/proj_lan_varinf/LAN_varinf/data/hierarchical/ddm_nsamples_1000_nsubjects_20_nparams_200_stdfracdenom_6/dict_variational_pyro_dataidx_9cpu_nparticles_1.pickle', 'rb'))


In [20]:
test = pickle.load(open('/gpfs/data/frankmj/afengler/proj_lan_varinf/LAN_varinf/data/hierarchical/ddm_nsamples_1000_nsubjects_20_nparams_200_stdfracdenom_6/arviz_variational_pyro_dataidx_9cpu_nparticles_1.pickle', 'rb'))


In [21]:
test

In [43]:
def sim_wrap(theta = torch.zeros(0), 
             model = 'ddm', 
             out_type = 'tensor', # 'tensor', 'numpy'
             ):
    if type(theta) == np.ndarray:
        pass
    else:
        theta = theta.numpy().astype(np.float32)
    out = ssms.basic_simulators.simulator(theta = theta,
                                          model = model,
                                          n_samples = 1,
                                          delta_t = 0.001,
                                          max_t = 20.0,
                                          no_noise = False,
                                          bin_dim = None,
                                          bin_pointwise = False)
    
    out = np.hstack([out['rts'].astype(np.float32),
                     out['choices'].astype(np.float32)])
                    
    if out_type == 'tensor':
        return torch.tensor(out)
    elif out_type == 'numpy':
        return out           
    else:
        return 'out_type = ' + out_type + ' is not supported'
    
def get_parameter_from_trunc_normal(model_config = None,
                                    param_str = None,
                                    std_denominator = 6.,
                                    range_shrinkage = 0.0,
                                    n_parameter_sets = 1,
                                    fixed_parameters = {}):
    
    if type(fixed_parameters) == dict:
        if param_str in fixed_parameters.keys():
            return np.array(fixed_parameters[param_str])
        
    param_idx = model_config["params"].index(param_str)
    min_ = model_config['param_bounds'][0][param_idx]
    max_ = model_config['param_bounds'][1][param_idx]
    range_ = max_ - min_
    mean_ = (max_ + min_) / 2
    min_adj = mean_ - (0.5 * (1. - range_shrinkage)) * range_
    max_adj = mean_ + (0.5 * (1. - range_shrinkage)) * range_

    std_adj = (max_adj - min_adj) / std_denominator

    a, b = (min_adj - mean_) / std_adj, (max_adj - mean_) / std_adj

    out = truncnorm.rvs(a, 
                        b, 
                        loc = mean_, 
                        scale = std_adj, 
                        size = n_parameter_sets)
    return out
    

def synthetic_data_single_subject(model = 'ddm',
                                  model_config = {},
                                  n_parameter_sets = 200,
                                  n_samples = 1000,
                                  std_as_fraction_of_range = 0.1666666,
                                  fixed_parameters = {'z': 0.5},
                                  save_file = None,
                                 ):

    # scaler for the standard deviation applied,
    # when sampling parameters in the allowed range
    std_denominator = 1 / std_as_fraction_of_range 

    # Make parameters
    parameter_samples_dict = {}
    for param in model_config["params"]:
        parameter_samples_dict[param] = get_parameter_from_trunc_normal(model_config = model_config,
                                                                        param_str = param,
                                                                        std_denominator = std_denominator,
                                                                        n_parameter_sets = n_parameter_sets,
                                                                        fixed_parameters = fixed_parameters)

    # Generate rt, choice data
    out_dict = {}
    out_dict['model'] = model 
    out_dict['data'] = {}
    for i in range(n_parameter_sets):
        out_dict['data'][i] = {}
        out_dict['data'][i]['gt_params'] = {key_: parameter_samples_dict[key_][i] for key_ in parameter_samples_dict.keys()}
        data_tmp = sim_wrap(theta = np.tile(np.array([out_dict['data'][i]['gt_params'][key_] for key_ in out_dict['data'][i]['gt_params'].keys()]).astype(np.float32), ((n_samples, 1))),
                            model = model,
                            out_type = 'numpy')
        
        # pyro ready data
        out_dict['data'][i]['pyro'] = torch.tensor(data_tmp)
        
        # numpyro ready data
        out_dict['data'][i]['numpy'] = data_tmp
        
        # numpyro ready data
        out_dict['data'][i]['numpyro'] = jnp.asarray(data_tmp)
        
        # preprocess data to be hddm ready
        data_tmp_pd = pd.DataFrame(data_tmp, columns = ['rt', 'response'])
        data_tmp_pd['subj_idx'] = 0
        
        out_dict['data'][i]['hddm'] = data_tmp_pd
        
        if i % 10 == 0:
            print(i, ' of ', n_parameter_sets, ' datasets finished')
        
    if save_file is not None:
        cpickle.dump(out_dict, open(save_file, 'wb'))
    
    return out_dict       
    
def synthetic_data_hierarchical(model = 'ddm',
                                model_config = {},
                                n_parameter_sets = 200,
                                n_subjects = 20,
                                n_samples_by_subject = 1000,
                                std_as_fraction_of_range = 0.1666666,
                                fixed_parameters = {'z': 0.5},
                                save_file = None,
                                ):
    
    # Make parameter sets
    param_dict = {}
    for i in range(n_parameter_sets):
        param_dict[i] = {}
        for param in model_config['params']:
            proceed = 0
            param_idx = model_config['params'].index(param)
            while not proceed:
                param_mu = get_parameter_from_trunc_normal(model_config = model_config,
                                                           param_str = param,
                                                           std_denominator = 6.,
                                                           range_shrinkage = 0.2,
                                                           n_parameter_sets = 1,
                                                           fixed_parameters = fixed_parameters)
                
                if type(fixed_parameters) == dict:
                    if param in fixed_parameters.keys():
                        param_std = 0
                    else: 
                        param_std = np.random.uniform(low = 0.05, 
                                                      high = (model_config['param_bounds'][1][param_idx] - \
                                                                 model_config['param_bounds'][0][param_idx]) / 10,
                                                      size = 1)
                        # param_std = scipy.stats.halfnorm.rvs(loc = 0, 
                        #                                      scale = 0.15) # potentially fix       
                else:
                    param_std = np.random.uniform(low = 0.05, 
                                                  high = (model_config['param_bounds'][1][param_idx] - \
                                                            model_config['param_bounds'][0][param_idx]) / 10,
                                                  size = 1)  
                    
                    # param_std = scipy.stats.halfnorm.rvs(loc = 0, 
                    #                                      scale = 0.15) # potentially fix

                params_subj = np.random.normal(loc = param_mu,
                                               scale = param_std,
                                               size = n_subjects)

                # check that none of the individual parameters are violating bounds
                if (params_subj < model_config['param_bounds'][0][model_config['params'].index(param)]).sum() == 0 and \
                        (params_subj > model_config['param_bounds'][1][model_config['params'].index(param)]).sum() == 0:
                    
                    param_dict[i][param + '_mu_mu'] = param_mu.astype(np.float32)
                    param_dict[i][param + '_mu_std'] =  param_std.astype(np.float32)
                    param_dict[i][param + '_subj'] = params_subj.astype(np.float32)
                    proceed = 1
                else: 
                    print('rejected ', param, '\n', 
                          'with ', param + '_mu_mu', '\n',
                          param_mu,
                          'with ', param + '_mu_std', '\n',
                          param_std,
                         'with ', param + '_subj', '\n',
                         params_subj)
    
    # Make data
    out_dict = {}
    out_dict['model'] = model
    out_dict['data'] = {}
    for i in range(n_parameter_sets):
        out_dict['data'][i] = {}
        out_dict['data'][i]['gt_params'] = deepcopy(param_dict[i])
        
        theta_tmp = np.repeat(np.stack([param_dict[i][param + '_subj'] for param in model_config['params']]).T, n_samples_by_subject, axis = 0)
        
        data_tmp = sim_wrap(theta = theta_tmp.astype(np.float32),
                            model = model,
                            out_type = 'numpy')
        
        # pyro data
        data_pyro = torch.tensor(np.resize(data_tmp, (n_subjects, n_samples_by_subject, 2)).swapaxes(0, 1).astype(np.float32))
        out_dict['data'][i]['pyro'] = data_pyro
        
        # numpyro data
        data_numpyro = data_pyro.numpy().astype(np.float32)
        out_dict['data'][i]['numpy'] = data_numpyro
        
        # jax
        data_jax = jnp.asarray(data_numpyro)
        out_dict['data'][i]['numpyro'] = data_jax
        
        
        # hddm data
        data_hddm = pd.DataFrame(data_tmp, columns = ['rt', 'response'])
        #print('data_hddm shape: ', data_hddm.shape)
        data_hddm['subj_idx'] = np.array([i for i in range(n_subjects)]).astype(np.float32).repeat(n_samples_by_subject)
        data_hddm['subj_idx'] = data_hddm['subj_idx'].astype(int)
        out_dict['data'][i]['hddm'] = data_hddm
        
        if i % 10 == 0:
            print(i, ' of ', n_parameter_sets, ' finished')
            
    if save_file is not None:
        cpickle.dump(out_dict, open(save_file, 'wb'))
    
    return out_dict

In [44]:
def make_fixparamstr(fixed_param_dict = {}):
    tmp = [key_ for key_ in fixed_param_dict.keys()]
    mystr = ''
    for str_ in tmp:
        mystr += str_
    return mystr

In [31]:
np.random.uniform(size = 1)

array([0.01037415])

## Single Subject

In [45]:
# Metaparameters
model = 'ddm'
model_config = ssms.config.model_config[model]
n_parameter_sets = 200
n_samples = 1000
stdfracdenom = 6
std_as_fraction_of_range = 1/stdfracdenom
fixed_param_dict = {}

fixparamstr = make_fixparamstr(fixed_param_dict)

if fixparamstr == '':
    # Derived file name
    save_file = 'data/single_subject/' + model + '_nsamples_' + str(n_samples) + '_nparams_' + \
                    str(n_parameter_sets) + '_stdfracdenom_' + str(stdfracdenom) + '.pickle'
else:
    save_file = 'data/single_subject/' + model + '_nsamples_' + str(n_samples) + '_nparams_' + \
                    str(n_parameter_sets) + '_stdfracdenom_' + str(stdfracdenom) + '_fixed_' + fixparamstr + '.pickle'

In [46]:
out_dict = synthetic_data_single_subject(model = model,
                                         model_config = model_config,
                                         n_parameter_sets = n_parameter_sets,
                                         n_samples = n_samples,
                                         std_as_fraction_of_range = std_as_fraction_of_range,
                                         fixed_parameters = fixed_param_dict,
                                         save_file = save_file)

0  of  200  datasets finished
10  of  200  datasets finished
20  of  200  datasets finished
30  of  200  datasets finished
40  of  200  datasets finished
50  of  200  datasets finished
60  of  200  datasets finished
70  of  200  datasets finished
80  of  200  datasets finished
90  of  200  datasets finished
100  of  200  datasets finished
110  of  200  datasets finished
120  of  200  datasets finished
130  of  200  datasets finished
140  of  200  datasets finished
150  of  200  datasets finished
160  of  200  datasets finished
170  of  200  datasets finished
180  of  200  datasets finished
190  of  200  datasets finished


## Hierarchical

In [47]:
# Metaparameters
model = 'ddm'
model_config = ssms.config.model_config[model]
n_parameter_sets = 200
n_subjects = 20
n_samples_by_subject = 500
stdfracdenom = 6
std_as_fraction_of_range = 1 / stdfracdenom
fixed_param_dict = {}
fixparamstr = make_fixparamstr(fixed_param_dict)


if fixparamstr == '':
    save_file = 'data/hierarchical/' + model + '_nsamples_' + str(n_samples) + '_nsubjects_' + str(n_subjects) + \
                    '_nparams_' + str(n_parameter_sets) + '_stdfracdenom_' + str(stdfracdenom) + '.pickle'
else:
    save_file = 'data/hierarchical/' + model + '_nsamples_' + str(n_samples) + '_nsubjects_' + str(n_subjects) + \
                    '_nparams_' + str(n_parameter_sets) + '_stdfracdenom_' + str(stdfracdenom) + '_fixed_' + fixparamstr + '.pickle'

In [48]:
synthetic_data_hierarchical(model = model,
                            model_config = model_config,
                            n_parameter_sets = n_parameter_sets,
                            n_subjects = n_subjects,
                            n_samples_by_subject = n_samples_by_subject,
                            std_as_fraction_of_range = std_as_fraction_of_range,
                            fixed_parameters = fixed_param_dict,
                            save_file = save_file,
                            )

rejected  v 
 with  v_mu_mu 
 [2.39378215] with  v_mu_std 
 [0.55529239] with  v_subj 
 [3.00426657 2.66011134 2.37248533 2.80465549 2.35061792 2.06115466
 2.60625114 3.16858681 2.68036926 3.35334331 3.15674258 2.70155377
 3.05028057 2.38277656 2.6988561  2.3973199  2.22110427 2.67913183
 2.59100784 2.40359524]
rejected  a 
 with  a_mu_mu 
 [1.99729552] with  a_mu_std 
 [0.20097923] with  a_subj 
 [2.51946437 2.12115495 1.83894233 1.9697629  2.33059866 1.95367456
 2.23650459 1.92062849 1.78202235 1.86915941 1.50385253 2.14937269
 1.6525935  2.34536697 1.9222938  1.71163882 2.12437775 1.8377852
 2.13598625 2.09477827]
rejected  z 
 with  z_mu_mu 
 [0.20776407] with  z_mu_std 
 [0.074943] with  z_subj 
 [0.28029229 0.24884145 0.33776642 0.21238623 0.22563195 0.25525529
 0.1170057  0.21162344 0.15365761 0.26808488 0.2576184  0.21753802
 0.11175108 0.28808617 0.31633705 0.1567862  0.16831948 0.07240302
 0.15985459 0.15169105]
rejected  v 
 with  v_mu_mu 
 [2.20179258] with  v_mu_std 
 [0.4

{'model': 'ddm',
 'data': {0: {'gt_params': {'v_mu_mu': array([1.0052878], dtype=float32),
    'v_mu_std': array([0.24195193], dtype=float32),
    'v_subj': array([0.9452495 , 1.3748401 , 1.256722  , 1.0833569 , 0.7857757 ,
           0.9122426 , 1.4068842 , 0.8900736 , 0.92867965, 0.71578366,
           1.2655684 , 0.9830523 , 1.0457228 , 0.5867159 , 0.8760517 ,
           0.7052873 , 1.0443975 , 1.4080981 , 0.71332514, 0.5924883 ],
          dtype=float32),
    'a_mu_mu': array([1.8441465], dtype=float32),
    'a_mu_std': array([0.14398074], dtype=float32),
    'a_subj': array([1.7637   , 1.7824256, 1.8768406, 1.8244165, 1.8914915, 1.6932824,
           1.7024006, 1.8390669, 1.9703456, 1.9213043, 1.808007 , 1.8960593,
           1.732847 , 1.8154541, 1.616302 , 1.6372718, 1.4253577, 2.0566232,
           1.9298105, 1.9444742], dtype=float32),
    'z_mu_mu': array([0.55305517], dtype=float32),
    'z_mu_std': array([0.07713509], dtype=float32),
    'z_subj': array([0.56016976, 0.51359

In [61]:
np.random.uniform(low = 0.01, high = 0.00001)

0.009487975495173146

In [58]:
help(np.random.uniform)

Help on built-in function uniform:

uniform(...) method of numpy.random.mtrand.RandomState instance
    uniform(low=0.0, high=1.0, size=None)
    
    Draw samples from a uniform distribution.
    
    Samples are uniformly distributed over the half-open interval
    ``[low, high)`` (includes low, but excludes high).  In other words,
    any value within the given interval is equally likely to be drawn
    by `uniform`.
    
    .. note::
        New code should use the ``uniform`` method of a ``default_rng()``
        instance instead; please see the :ref:`random-quick-start`.
    
    Parameters
    ----------
    low : float or array_like of floats, optional
        Lower boundary of the output interval.  All values generated will be
        greater than or equal to low.  The default value is 0.
    high : float or array_like of floats
        Upper boundary of the output interval.  All values generated will be
        less than or equal to high.  The default value is 1.0.
    size 

In [68]:
ssms.config.model_config['weibull']

{'name': 'weibull',
 'params': ['v', 'a', 'z', 't', 'alpha', 'beta'],
 'param_bounds': [[-2.5, 0.3, 0.2, 0.001, 0.31, 0.31],
  [2.5, 2.5, 0.8, 2.0, 4.99, 6.99]],
 'boundary': <function ssms.basic_simulators.boundary_functions.weibull_cdf(t=1, alpha=1, beta=1)>,
 'n_params': 6,
 'default_params': [0.0, 1.0, 0.5, 0.001, 3.0, 3.0],
 'hddm_include': ['z', 'alpha', 'beta'],
 'nchoices': 2}