## Parameter recovery

We simulated data for a set of designs and various ground-truth parameters. Now we will try to estimate those parameters from the simulated data.

In [1]:
# Built-in/Generic Imports
import os,sys
import glob
import time

# Libs
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az

import logging
logger = logging.getLogger("pymc")
logger.setLevel(logging.ERROR)

In [2]:
def estimate_bhm(subj_id=[],design_df=[],choices=[],type='single',random_seed=None):
    """Fit a hierarchical Bayesian hyperbolic-discounting choice model and return posterior means.

    Each trial offers a delayed and an immediate option. An option's value is
    ``amount / (1 + kappa * wait)`` and the probability of the observed choice
    being 1 is the logistic function of ``gamma * (delayed value - immediate value)``.
    ``kappa`` and ``gamma`` are estimated per subject.

    Parameters
    ----------
    subj_id : sequence
        One subject label per trial. Labels are mapped internally onto
        contiguous indices, so they do not need to be 0..n-1.
    design_df : pandas.DataFrame
        One row per trial with the columns ``cdd_delay_amt``, ``cdd_delay_wait``,
        ``cdd_immed_amt`` and ``cdd_immed_wait``.
    choices : sequence
        Observed binary choice per trial, aligned with ``design_df`` rows.
    type : {'single', 'aggregate'}
        ``'single'`` returns scalars for the first subject; ``'aggregate'``
        returns one estimate per subject.
    random_seed : optional
        Forwarded to ``pm.sample`` to make a run reproducible. The default
        ``None`` leaves the sampler unseeded.

    Returns
    -------
    kappa_hat, gamma_hat
        Posterior means. Scalars when ``type='single'``; lists ordered by sorted
        unique subject label when ``type='aggregate'``.

    Raises
    ------
    ValueError
        If ``type`` is not one of the supported values, or ``subj_id`` is empty.
    """
    # Reject a bad mode up front instead of after an expensive sampling run.
    if type not in ('single','aggregate'):
        raise ValueError("type must be 'single' or 'aggregate', got {!r}".format(type))

    # Map subject labels to 0..n_subj-1 so they can index the per-subject
    # parameter vectors regardless of how the caller numbered the subjects.
    subj_labels,subj_idx = np.unique(np.asarray(subj_id),return_inverse=True)
    subj_idx = np.ravel(subj_idx)
    n_subj = subj_labels.size
    if n_subj == 0:
        raise ValueError('subj_id is empty: at least one trial is required')

    delay_amt = design_df['cdd_delay_amt'].values
    delay_wait = design_df['cdd_delay_wait'].values
    immed_amt = design_df['cdd_immed_amt'].values
    immed_wait = design_df['cdd_immed_wait'].values

    with pm.Model() as model_simple:

        # Hyperparameters for kappa and gamma
        # estimated from MLE approximations : np.exp(-3.60) = 0.0273, np.sqrt(1.71)=1.308
        mu_kappa_hyper = pm.Beta('mu_kappa_hyper',mu=np.exp(-3.60),sigma=0.01)
        sd_kappa_hyper = pm.Normal('sd_kappa_hyper',mu=np.sqrt(1.71),sigma=0.1)
        # estimated from MLE approximations : np.sqrt(2.30) = 1.517
        sd_gamma_hyper = pm.Normal('sd_hyper',mu=np.sqrt(2.30),sigma=0.1)

        # NOTE(review): LogNormal's mu is the mean of log(kappa); here it is
        # centred near exp(-3.60) rather than -3.60 -- confirm this is intended.
        kappa = pm.LogNormal('kappa',mu=mu_kappa_hyper,sigma=sd_kappa_hyper,shape=n_subj)
        gamma = pm.HalfNormal('gamma',sigma=sd_gamma_hyper,shape=n_subj)

        # Expand the per-subject parameters to one value per trial.
        kappa_trial = kappa[subj_idx]
        gamma_trial = gamma[subj_idx]

        value_delay = delay_amt/(1+(kappa_trial*delay_wait))
        value_immed = immed_amt/(1+(kappa_trial*immed_wait))

        # Not wrapped in pm.Deterministic: the per-trial probabilities are never
        # read back, and recording them would store one value per trial per draw.
        prob = 1 / (1 + pm.math.exp(-gamma_trial * (value_delay - value_immed)))

        y_1 = pm.Bernoulli('y_1',p=prob,observed=choices)

        trace = pm.sample(10000, tune=1000, cores=5,target_accept=0.99,progressbar=False,
                          random_seed=random_seed)

    # az.summary returns a pandas DataFrame indexed by labels such as 'kappa[0]'.
    # Restricting var_names avoids summarising the hyperparameters needlessly.
    summary = az.summary(trace,var_names=['kappa','gamma'],round_to=10)
    posterior_mean = summary['mean']

    if type=='single':
        kappa_hat = posterior_mean.loc['kappa[{}]'.format(0)]
        gamma_hat = posterior_mean.loc['gamma[{}]'.format(0)]
    else:
        # Deterministic order: position i corresponds to the i-th sorted unique label.
        kappa_hat = [posterior_mean.loc['kappa[{}]'.format(i)] for i in range(n_subj)]
        gamma_hat = [posterior_mean.loc['gamma[{}]'.format(i)] for i in range(n_subj)]
    return kappa_hat,gamma_hat


In [3]:
# Load the ground-truth parameters and the design shared by every simulated subject.
fn = os.path.join('simul','ground_truth.csv')
params_df = pd.read_csv(fn,index_col=0)

fn = os.path.join('simul','design_set.csv')
design_df_single = pd.read_csv(fn,index_col=0)

# One CSV of simulated choices per subject; sorted so the order is stable.
simulated_data = sorted(glob.glob(os.path.join('simul','split','*/cdd/*.csv')))

# Fail before any sampling starts: the estimates are written into params_df
# row by row, so a wrong file count would otherwise only surface after all
# the single-subject fits have finished.
if not simulated_data:
    raise FileNotFoundError('No simulated data files found under {}'.format(
        os.path.join('simul','split')))
if len(simulated_data) != len(params_df):
    raise ValueError('Found {} simulated data files but {} ground-truth rows'.format(
        len(simulated_data),len(params_df)))

choice_col = 'cdd_choice'

tStep0 = time.time()

# single
kappa_hat,gamma_hat = [],[]
# aggregate
subj_id,choices,design_list = [],[],[]
for index,fn in enumerate(simulated_data):
    print(fn)
    df = pd.read_csv(fn,index_col=0)
    n_trials = len(df[choice_col])
    # Every subject is paired with the same design, trial by trial.
    if n_trials != len(design_df_single):
        raise ValueError('{} has {} trials but the design has {} rows'.format(
            fn,n_trials,len(design_df_single)))

    # single: fit this subject alone, so all trials share subject index 0
    kh,gh = estimate_bhm(subj_id=[0]*n_trials,
                         design_df=design_df_single,
                         choices=df[choice_col],type='single')
    kappa_hat += [kh]
    gamma_hat += [gh]

    # aggregate: accumulate trials across subjects, labelled by file position
    choices += df[choice_col].values.tolist()
    subj_id += [index]*n_trials
    design_list += [design_df_single]

print('Time to complete {} single BHM : {} minutes'.format(len(simulated_data),(time.time() - tStep0)/60.0))

tStep1 = time.time()
params_df['kappa_bhm_sing'] = kappa_hat
params_df['gamma_bhm_sing'] = gamma_hat

# Stack one copy of the design per subject so rows line up with `choices`.
design_df_agg = pd.concat(design_list,axis=0)
params_df['kappa_bhm_agg'],params_df['gamma_bhm_agg'] = estimate_bhm(
    subj_id=subj_id,design_df=design_df_agg,choices=choices,type='aggregate')

print('Time to complete {} aggregate BHM : {} minutes'.format(len(simulated_data),(time.time() - tStep1)/60.0))


fn = os.path.join('simul','parameter_estimate_bhm.csv')
print('Saving estimates to : {}'.format(fn))
params_df.to_csv(fn)


simul/split/p0000/cdd/p0000_cdd.csv
simul/split/p0001/cdd/p0001_cdd.csv
simul/split/p0002/cdd/p0002_cdd.csv
simul/split/p0003/cdd/p0003_cdd.csv
simul/split/p0004/cdd/p0004_cdd.csv
simul/split/p0005/cdd/p0005_cdd.csv
simul/split/p0006/cdd/p0006_cdd.csv
simul/split/p0007/cdd/p0007_cdd.csv
simul/split/p0008/cdd/p0008_cdd.csv
simul/split/p0009/cdd/p0009_cdd.csv
simul/split/p0010/cdd/p0010_cdd.csv
simul/split/p0011/cdd/p0011_cdd.csv
simul/split/p0012/cdd/p0012_cdd.csv
simul/split/p0013/cdd/p0013_cdd.csv
simul/split/p0014/cdd/p0014_cdd.csv
simul/split/p0015/cdd/p0015_cdd.csv
simul/split/p0016/cdd/p0016_cdd.csv
simul/split/p0017/cdd/p0017_cdd.csv
simul/split/p0018/cdd/p0018_cdd.csv
simul/split/p0019/cdd/p0019_cdd.csv
simul/split/p0020/cdd/p0020_cdd.csv
simul/split/p0021/cdd/p0021_cdd.csv
simul/split/p0022/cdd/p0022_cdd.csv
simul/split/p0023/cdd/p0023_cdd.csv
simul/split/p0024/cdd/p0024_cdd.csv
simul/split/p0025/cdd/p0025_cdd.csv
simul/split/p0026/cdd/p0026_cdd.csv
simul/split/p0027/cdd/p0027_