In [1]:
# Load modules
import hddm
import tensorflow as tf
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import scipy as scp
import psutil
from time import time
from copy import deepcopy
import os
import pickle
import argparse
import yaml as yml
from multiprocessing import Pool
from functools import partial
from scipy.stats.stats import pearsonr

# HDDM specific imports
from hddm.model_config import model_config

In [2]:
def choice_percentage_check(data = None,
                            perc_cutoff = 0.05,
                            response_options = [0, 1]):

    for resp in response_options:
        if np.sum(data.response == resp) < int(data.shape[0] * perc_cutoff):
            #print(np.sum(data.response == resp) / data.shape[0])
            return 0
        else:
            #print(np.sum(data.response == resp) / data.shape[0])
            pass
    return 1

def param_buffer_check(theta = None,
                       model = 'ddm',
                       buffer_perc = 0.1):
    for param in model_config[model]['params']:
        tmp_idx = model_config[model]['params'].index(param)

        a = model_config[model]['param_bounds'][0][tmp_idx]
        b = model_config[model]['param_bounds'][1][tmp_idx]

        tmp_dist = b - a
        tmp_dist_proposed  = theta[param].values[0] - a

        if (tmp_dist_proposed < (buffer_perc * (tmp_dist))) or (tmp_dist_proposed > ((1-buffer_perc) * (tmp_dist))):
            return 0
    return 1

def added_checks(theta = None,
                 check_dictionary = {'equal': {},
                                     'higher': {'v': -5},
                                     'lower': {'v': 5}}):
    for key_test_type_ in check_dictionary.keys():
        if check_dictionary[key_test_type_] is not None:
            for key_ in check_dictionary[key_test_type_]:
                if key_test_type_ == 'equal':
                    if theta[key_].values[0] != check_dictionary[key_test_type_][key_]:
                        return 0
                    
                elif key_test_type_ == 'higher':
                    if theta[key_].values[0] < check_dictionary[key_test_type_][key_]:
                        return 0

                elif key_test_type_ == 'lower':
                    if theta[key_].values[0] > check_dictionary[key_test_type_][key_]:
                        return 0
    return 1

In [5]:
def make_ground_truth_dataset(models = ['ddm'],
                              n_datasets = 50,
                              p_outlier = 0.0,
                              buffer_perc = 0.1,
                              choice_percentage_cutoff = 0.05,
                              added_check_dictionary = {'lower': {'v': 10}}, # 'lower', 'equal', 'higher'
                              response_coding = 'hddm',
                              save = True,
                              save_folder = 'data/parameter_recovery'
                              ):

    if response_coding == 'hddm':
        response_options = [0, 1]

    assert response_coding == 'hddm', 'Response coding needs to be set to hddm at this point. No alternatives allowed!'
            
    if type(models) is not list:
        models = [models]

    for model in models:
        param_list = hddm.model_config.model_config[model]['params']
        data_dict = {}
        cnt = 0
        simple_cnt = 0
        cnt_blocked = 0

        while cnt < n_datasets:
            param_df = hddm.simulators.make_parameter_vectors_nn(model = model,
                                                                 n_parameter_vectors = 1)
            model_data = hddm.simulators.simulator_single_subject(model = model,
                                                                parameters = param_df.loc[0, :][param_list].values, 
                                                                p_outlier = p_outlier,
                                                                n_samples = 1000)
            
            cp_check_ = choice_percentage_check(data = model_data[0], perc_cutoff = choice_percentage_cutoff, response_options = response_options)
            buffer_check_ = param_buffer_check(theta = param_df,
                                            buffer_perc = buffer_perc)
            added_checks_ = added_checks(theta = param_df,
                                        check_dictionary = added_check_dictionary)
            
            if cp_check_ and buffer_check_ and added_checks_:
                data_dict[cnt] = {}
                data_dict[cnt]['dataset'] = model_data[0]
                data_dict[cnt]['dataset']['subj_idx'] = cnt
                data_dict[cnt]['param_dict'] = model_data[1]
                cnt += 1
                #print(cnt)
            else:
                #print('defective dataset')
                pass
            simple_cnt += 1

            if cnt % 10 == 0:
                print(str(cnt), ' out of ', str(simple_cnt), ' parameters passed the filters!')
                cnt_blocked = cnt

        if save:
            pickle.dump(data_dict, open(save_folder + '/param_recov_dataset_' + model + '.pickle', 'wb'))
    return 1

In [6]:
make_ground_truth_dataset(models = ['ddm', 'ornstein', 'weibull', 'angle', 'levy'],
                          n_datasets = 100,
                          p_outlier = 0.0,
                          buffer_perc = 0.1,
                          choice_percentage_cutoff = 0.05,
                          added_check_dictionary = {'lower': {'v': 10}}, # 'lower', 'equal', 'higher'
                          response_coding = 'hddm',
                          save = True,
                          save_folder = 'data/parameter_recovery')

10  out of  50  parameters passed the filters!
10  out of  51  parameters passed the filters!
10  out of  52  parameters passed the filters!
10  out of  53  parameters passed the filters!
10  out of  54  parameters passed the filters!
10  out of  55  parameters passed the filters!
10  out of  56  parameters passed the filters!
10  out of  57  parameters passed the filters!
10  out of  58  parameters passed the filters!
10  out of  59  parameters passed the filters!
20  out of  96  parameters passed the filters!
20  out of  97  parameters passed the filters!
20  out of  98  parameters passed the filters!
30  out of  148  parameters passed the filters!
30  out of  149  parameters passed the filters!
30  out of  150  parameters passed the filters!
40  out of  176  parameters passed the filters!
40  out of  177  parameters passed the filters!
40  out of  178  parameters passed the filters!
40  out of  179  parameters passed the filters!
40  out of  180  parameters passed the filters!
40  o

1