## Parameter recovery

We simulated data for a set of designs and a range of ground-truth parameters. We now try to recover those parameters by fitting the model to the simulated data.
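
The model fitted below combines hyperbolic discounting with a logistic choice rule. As a minimal sketch of that rule in plain NumPy (the function name `choice_prob` and the example amounts and delays are illustrative assumptions, not part of the simulation code), the probability of a response of 1 on each trial can be written as:

```python
import numpy as np

def choice_prob(kappa, gamma, delay_amt, delay_wait, immed_amt, immed_wait):
    """Probability of response = 1 under hyperbolic discounting with a logistic choice rule."""
    # Hyperbolically discounted subjective value of each option
    sv_delay = np.asarray(delay_amt, dtype=float) / (1 + kappa * np.asarray(delay_wait, dtype=float))
    sv_immed = np.asarray(immed_amt, dtype=float) / (1 + kappa * np.asarray(immed_wait, dtype=float))
    # gamma scales how strongly the value difference drives the choice
    return 1 / (1 + np.exp(-gamma * (sv_delay - sv_immed)))

# Hypothetical trials: 20 now versus 40 after 0, 10, or 100 time units
p = choice_prob(kappa=0.1, gamma=0.5,
                delay_amt=[40, 40, 40], delay_wait=[0, 10, 100],
                immed_amt=[20, 20, 20], immed_wait=[0, 0, 0])
print(p)
```

With kappa = 0.1, a delay of 10 halves the value of the delayed amount, so the second trial is an exact indifference point (probability 0.5). Shorter delays push the probability towards 1 and longer delays towards 0.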

In [1]:
# Built-in/Generic Imports
import os
import glob

# Libs
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az

import logging
logger = logging.getLogger("pymc")
logger.setLevel(logging.ERROR)

In [2]:

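def estimate_bhm(index,design_df,choices):
    """Fit the hyperbolic-discounting choice model to one simulated subject.

    Returns the posterior means (kappa_hat, gamma_hat). `index` is currently unused.
    """

    # save_dir = os.path.join('simul','bhm')
    
    # Amounts and waiting times of the delayed and immediate options on each trial
    delay_amt = design_df['cdd_delay_amt'].values
    delay_wait = design_df['cdd_delay_wait'].values
    immed_amt = design_df['cdd_immed_amt'].values
    immed_wait = design_df['cdd_immed_wait'].values
    # Each call fits a single subject, so every trial maps to subject index 0
    subj_id = [0]*len(choices)
    # We fit a separate model for each subject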
    with pm.Model() as model_simple:

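        # Hyperprior on the location of kappa (Beta parameterized by its mean and SD)
        mu_kappa_hyper = pm.Beta('mu_kappa_hyper',mu=0.02,sigma=0.01)
        # A single hyper SD is shared by kappa and gamma
        sd_hyper = pm.LogNormal('sd_hyper',sigma=1)

        # Subject-level discount rate (kappa) and choice sensitivity (gamma).
        # Note: LogNormal's mu is on the log scale, so mu_kappa_hyper shifts log(kappa).
        n_subj = np.size(np.unique(subj_id))
        kappa = pm.LogNormal('kappa',mu=mu_kappa_hyper,sigma=sd_hyper,shape=n_subj)
        gamma = pm.HalfNormal('gamma',sigma=sd_hyper,shape=n_subj)
        
        # Hyperbolically discounted subjective value of each option
        sv_delay = delay_amt/(1+(kappa[subj_id]*delay_wait))
        sv_immed = immed_amt/(1+(kappa[subj_id]*immed_wait))
        # Probability of response = 1: logistic in the value difference, scaled by gamma
        prob = pm.Deterministic('prob', 1 / (1 + pm.math.exp(-gamma[subj_id] * (sv_delay - sv_immed))))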

        y_1 = pm.Bernoulli('y_1',p=prob,observed=choices)

        # Draw posterior samples with NUTS: 2 chains, 1000 draws each after 100 tuning steps
        trace = pm.sample(1000, tune=100, cores=2,target_accept=0.98,progressbar=False)


    # az.summary returns a pandas DataFrame indexed by variable name (e.g. 'kappa[0]'),
    # so look up rows by label with .loc rather than by integer position.
    summary= az.summary(trace,round_to=10)
    kappa_hat = summary['mean'].loc['kappa[0]']
    gamma_hat = summary['mean'].loc['gamma[0]']
    return kappa_hat,gamma_hat
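
Besides the posterior mean, `az.summary` reports convergence diagnostics such as `r_hat`, which is worth checking given the short tuning phase and the use of only two chains. As a rough illustration of what that statistic measures, here is a minimal NumPy sketch of the basic (non-split, non-rank-normalized) Gelman-Rubin statistic. The function name `basic_rhat` and the synthetic chains are assumptions for illustration; ArviZ's default `r_hat` uses a more refined rank-normalized variant.

```python
import numpy as np

def basic_rhat(chains):
    """Basic Gelman-Rubin R-hat for an array of shape (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    num_samples = chains.shape[1]
    # Variance of the chain means, scaled by the number of draws
    between_chain_variance = num_samples * np.var(chains.mean(axis=1), ddof=1)
    # Average of the per-chain sample variances
    within_chain_variance = np.mean(np.var(chains, axis=1, ddof=1))
    return np.sqrt((between_chain_variance / within_chain_variance + num_samples - 1) / num_samples)

rng = np.random.default_rng(0)
# Two chains sampling the same distribution: R-hat should be close to 1
mixed = rng.normal(0.0, 1.0, size=(2, 1000))
# Two chains stuck in different regions: R-hat should be well above 1
stuck = mixed + np.array([[0.0], [3.0]])

rhat_mixed = basic_rhat(mixed)
rhat_stuck = basic_rhat(stuck)
print(rhat_mixed, rhat_stuck)
```

If a quantity is constant within every chain, the within-chain variance is zero and the ratio is undefined, which is one situation in which this kind of expression triggers a NumPy runtime warning.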




In [3]:

fn = os.path.join('simul','ground_truth.csv')
params_df = pd.read_csv(fn,index_col=0)

fn = os.path.join('simul','design_set.csv')
design_df = pd.read_csv(fn,index_col=0)

simulated_data = sorted(glob.glob(os.path.join('simul','response','*.csv')))
kappa_hat,gamma_hat = [],[]

# Fit each simulated subject separately and collect the posterior-mean estimates
for index,fn in enumerate(simulated_data):
    print(fn)
    df = pd.read_csv(fn,index_col=0)
    choices = df['response']
    kh,gh = estimate_bhm(index,design_df,choices)
    kappa_hat.append(kh)
    gamma_hat.append(gh)

# Attach the estimates to the ground-truth table (assumes its rows follow the sorted file order)
params_df['kappa_bhm'] = kappa_hat
params_df['gamma_bhm'] = gamma_hat
params_df

simul/response/p0000.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0001.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0002.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 5 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0003.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0004.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0005.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 4 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

simul/response/p0006.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 5 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0007.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

simul/response/p0008.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

simul/response/p0009.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

simul/response/p0010.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0011.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0012.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 2 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0013.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 4 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0014.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0015.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0016.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 4 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


simul/response/p0017.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]
Sampling 2 chains for 100 tune and 1_000 draw iterations (200 + 2_000 draws total) took 3 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
  (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)

simul/response/p0018.csv


Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [mu_kappa_hyper, sd_hyper, kappa, gamma]


In [None]:
params_df
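
To summarize how well the parameters were recovered, each estimate can be compared with its ground-truth value. The sketch below computes the Pearson correlation and root-mean-square error for one pair of columns. The helper `recovery_metrics`, the toy numbers, and the ground-truth column name `kappa` mentioned afterwards are assumptions, since the layout of `ground_truth.csv` is not shown here.

```python
import numpy as np

def recovery_metrics(true_values, estimates):
    """Return (Pearson r, RMSE) between ground-truth and estimated parameters."""
    true_values = np.asarray(true_values, dtype=float)
    estimates = np.asarray(estimates, dtype=float)
    # Correlation captures whether the ordering of subjects is recovered
    r = np.corrcoef(true_values, estimates)[0, 1]
    # RMSE captures how far the estimates are from the true values
    rmse = np.sqrt(np.mean((estimates - true_values) ** 2))
    return r, rmse

# Toy numbers for illustration only, not results from the simulation
true_kappa = [0.01, 0.02, 0.05, 0.10]
est_kappa = [0.02, 0.03, 0.06, 0.11]
r, rmse = recovery_metrics(true_kappa, est_kappa)
print(r, rmse)
```

With the notebook's data this could be called as `recovery_metrics(params_df['kappa'], params_df['kappa_bhm'])`, assuming the ground-truth column is named `kappa`. A high correlation combined with a large RMSE would suggest that the ranking of subjects is recovered but the estimates are biased.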