# HDDM Stan model fitting

Imports

In [None]:
import pandas as pd
import numpy as np
import stan
import nest_asyncio
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# Jupyter already runs an asyncio event loop; nest_asyncio patches asyncio so
# loops can be nested (presumably required because stan.build / sample drive
# their own event loop — see pystan docs). It does not enable multithreading.
nest_asyncio.apply()

In [None]:
plt.rcParams.update({'savefig.dpi': 300})  # resolution used by plt.savefig

## Stan model code

In [None]:
# Stan program (string passed to stan.build): hierarchical Wiener diffusion
# model with a trial-level regression on the boundary.
#   - Participant-level non-decision time (bounded to [0, 0.3]), baseline
#     boundary (bounded to [0, 3]), drift rate and condition effect on drift
#     are drawn from population-level normal distributions.
#   - Trial boundary = participants_alpha[p] + alpha_ne*pre_ne
#     + alpha_pre_acc*pre_acc + alpha_ne_pre_acc*pre_ne*pre_acc; the three
#     alpha_* slopes are shared by all participants (not hierarchical).
#   - Trial drift = participants_delta[p] + participants_delta_cond[p]*condition.
#   - Starting-point bias is fixed at 0.5.
# "ndt trick": when |y| <= ndt the likelihood is evaluated at y = +ndt instead
# of the observed value, so such trials are always scored with the y >= 0
# (upper-boundary) density regardless of the response sign.
# NOTE(review): Stan's wiener density is defined for y > ndt; confirm the Stan
# version in use accepts y == ndt, otherwise these trials reject the proposal.
# NOTE(review): nothing keeps the trial-level boundary positive once the slopes
# are added to the baseline.
# Data requirements: each participant's trials must be one contiguous, 1-based
# range participants_trials_slices[p] = [first, last]; n_conditions and
# participant are declared but not used by the model.
HDDM_delta_decomposed_ndt_trick_boundary_informed_simple = """
functions {
    real participant_level_diffusion_lpdf(
        vector y, 
        real boundary,
        real boundary_ne,
        real boundary_pre_acc,
        real boundary_ne_pre_acc,
        real ndt, 
        real bias, 
        real drift, 
        real drift_cond, 
        vector condition, 
        vector pre_ne,
        vector pre_acc,
        int n_trials
    ) {
        vector[n_trials] participant_level_likelihood;
        
        for (t in 1:n_trials) {
            if (abs(y[t]) - ndt > 0) {
                participant_level_likelihood[t] = diffusion_lpdf(y[t] | boundary  + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t], ndt, bias, drift + drift_cond*condition[t]);
            } else {
                participant_level_likelihood[t] = diffusion_lpdf(ndt | boundary  + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t], ndt, bias, drift + drift_cond*condition[t]);
            }
        }
        return(sum(participant_level_likelihood));
    }
      
    /* Wiener diffusion log-PDF for a single response (adapted from brms 1.10.2)
    * Arguments:
    *   Y: acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    *   boundary: boundary separation parameter > 0
    *   ndt: non-decision time parameter > 0
    *   bias: initial bias parameter in [0, 1]
    *   drift: drift rate parameter
    * Returns:
    *   a scalar to be added to the log posterior
    */
    real diffusion_lpdf(real Y, real boundary, real ndt, real bias, real drift) {
        if (Y >= 0) {
            return wiener_lpdf( abs(Y) | boundary, ndt, bias, drift );
        } else {
            return wiener_lpdf( abs(Y) | boundary, ndt, 1-bias, -drift );
        }
    }
}

data {
    int<lower=1> N; // Number of trial-level observations
    int<lower=1> n_conditions; // Number of conditions (congruent and incongruent)
    int<lower=1> n_participants; // Number of participants

    array[n_participants, 2] int participants_trials_slices; // slices TODO
    vector[N] y; // acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    vector[N] condition; // Contrast coded condition: -1 for erroneous and 1 for correct response respectively
    vector[N] pre_acc; // Contrast coded accuracy on previous trial
    vector[N] pre_ne; // centered correct/error negativity on previous trial
    array[N] int<lower=1> participant; // Participant index
}

parameters {
    vector<lower=0, upper=0.3>[n_participants] participants_ter; // Participant-level Non-decision time
    vector<lower=0, upper=3> [n_participants] participants_alpha; // Participant-level Boundary parameter (speed-accuracy tradeoff) // remove bound
    // vector<lower=0, upper=1>[n_participants] participants_beta; // Participant-level Start point bias towards choice A
    vector[n_participants] participants_delta; // Participant-level drift-rate
    vector[n_participants] participants_delta_cond; // Per-participant condition-level drift-rate adjustment 
        
    real<lower=0> ter; // Hierarchical non-decision time
    real<lower=0, upper=3> alpha; // Hierarchical boundary parameter (speed-accuracy tradeoff)
    // real beta; // Hierarchical start point bias towards choice A
    real delta; // Hierarchical drift-rate
    real delta_cond; // Hierarchical drift-rate adjustment 
    
    real<lower=0> ter_sd; // Between-participants variability in non-decision time
    real<lower=0> alpha_sd; // Between-participants variability in boundary parameter (speed-accuracy tradeoff)
    // real<lower=0> beta_sd; // Between-participants variability in start point bias towards choice A
    real<lower=0> delta_sd; // Between-participants variability in drift-rate
    real<lower=0> delta_cond_sd; // Between-participants variability in effect of condition
    
    
    // Non Hierarchical
    real alpha_ne;
    real alpha_pre_acc; 
    real alpha_ne_pre_acc; 
}

model {

    // ##########
    // Between-participant variability priors
    // ##########
    ter_sd ~ gamma(.3,1);
    alpha_sd ~ gamma(1,1);
    delta_sd ~ gamma(1,1);
    delta_cond_sd ~ gamma(1,1);

    // ##########
    // Hierarchical parameters priors
    // ##########
    ter ~ normal(.1, .2);
    alpha ~ normal(1, 1) T[0, 3];
    delta ~ normal(0, 2);
    delta_cond ~ normal(0, 2);

    // ##########
    // Non Hierarchical boundary parameters
    // ##########
    alpha_ne ~ normal(0, 0.1); // try (0, 0.1) or truncate
    alpha_pre_acc ~ normal(0,0.2);  // try (0, 0.2) or truncate
    alpha_ne_pre_acc ~ normal(0, 0.1); // try (0, 0.1) or truncate


    // ##########
    // Participant-level DDM parameter priors
    // ##########
    for (p in 1:n_participants) {

        // Participant-level non-decision time
        participants_ter[p] ~ normal(ter, ter_sd) T[0, .3];

        // Participant-level boundary parameter (speed-accuracy tradeoff)
        participants_alpha[p] ~ normal(alpha, alpha_sd) T[0, 3];

        //Participant-level drift rate
        participants_delta[p] ~ normal(delta, delta_sd);
        
        //Participant-level condition_adjustment
        participants_delta_cond[p] ~ normal(delta_cond, delta_cond_sd);  
                
        
        target += participant_level_diffusion_lpdf( y[participants_trials_slices[p][1]:participants_trials_slices[p][2]] | participants_alpha[p], alpha_ne, alpha_pre_acc, alpha_ne_pre_acc, participants_ter[p], 0.5, participants_delta[p], participants_delta_cond[p], condition[participants_trials_slices[p][1]:participants_trials_slices[p][2]], pre_ne[participants_trials_slices[p][1]:participants_trials_slices[p][2]],pre_acc[participants_trials_slices[p][1]:participants_trials_slices[p][2]], (participants_trials_slices[p][2] - participants_trials_slices[p][1] + 1));         
    }
}
"""

In [None]:
# Stan program (string passed to stan.build): hierarchical Wiener diffusion
# model in which condition also shifts the boundary.
#   - Trial boundary = participants_alpha[p] + participants_alpha_cond[p]*condition
#     + alpha_ne*pre_ne + alpha_pre_acc*pre_acc + alpha_ne_pre_acc*pre_ne*pre_acc
#     + alpha_ne_cond*pre_ne*condition + alpha_pre_acc_cond*pre_acc*condition
#     + alpha_ne_pre_acc_cond*pre_ne*pre_acc*condition.
#     The baseline boundary and its condition effect are hierarchical (per
#     participant); the six pre_ne / pre_acc slopes are shared by everyone.
#   - Trial drift = participants_delta[p] + participants_delta_cond[p]*condition.
#   - Non-decision time is per participant (bounded to [0, 0.3]); the
#     starting-point bias is fixed at 0.5.
# "ndt trick": when |y| <= ndt the likelihood is evaluated at y = +ndt instead
# of the observed value, so such trials are always scored with the y >= 0
# (upper-boundary) density regardless of the response sign.
# NOTE(review): Stan's wiener density is defined for y > ndt; confirm the Stan
# version in use accepts y == ndt, otherwise these trials reject the proposal.
# NOTE(review): nothing keeps the trial-level boundary positive once the
# condition and slope terms are added.
# Data requirements: each participant's trials must be one contiguous, 1-based
# range participants_trials_slices[p] = [first, last]; n_conditions and
# participant are declared but not used by the model.
HDDM_delta_decomposed_ndt_trick_boundary_informed_condition_simple = """
functions {
    real participant_level_diffusion_lpdf(
        vector y, 
        real boundary,
        real boundary_cond,
        real boundary_ne,
        real boundary_pre_acc,
        real boundary_ne_pre_acc,
        real boundary_ne_cond,
        real boundary_pre_acc_cond,
        real boundary_ne_pre_acc_cond,
        real ndt, 
        real bias, 
        real drift, 
        real drift_cond, 
        vector condition, 
        vector pre_ne,
        vector pre_acc,
        int n_trials
    ) {
        vector[n_trials] participant_level_likelihood;
        
        for (t in 1:n_trials) {
            if (abs(y[t]) - ndt > 0) {
                participant_level_likelihood[t] = diffusion_lpdf(y[t] | boundary  + boundary_cond*condition[t] + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t] + boundary_ne_cond*pre_ne[t]*condition[t] + boundary_pre_acc_cond*pre_acc[t]*condition[t] + boundary_ne_pre_acc_cond*pre_ne[t]*pre_acc[t]*condition[t], ndt, bias, drift + drift_cond*condition[t]);
            } else {
                participant_level_likelihood[t] = diffusion_lpdf(ndt | boundary  + boundary_cond*condition[t] + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t] + boundary_ne_cond*pre_ne[t]*condition[t] + boundary_pre_acc_cond*pre_acc[t]*condition[t] + boundary_ne_pre_acc_cond*pre_ne[t]*pre_acc[t]*condition[t], ndt, bias, drift + drift_cond*condition[t]);
            }
        }
        return(sum(participant_level_likelihood));
    }
      
    /* Wiener diffusion log-PDF for a single response (adapted from brms 1.10.2)
    * Arguments:
    *   Y: acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    *   boundary: boundary separation parameter > 0
    *   ndt: non-decision time parameter > 0
    *   bias: initial bias parameter in [0, 1]
    *   drift: drift rate parameter
    * Returns:
    *   a scalar to be added to the log posterior
    */
    real diffusion_lpdf(real Y, real boundary, real ndt, real bias, real drift) {
        if (Y >= 0) {
            return wiener_lpdf( abs(Y) | boundary, ndt, bias, drift );
        } else {
            return wiener_lpdf( abs(Y) | boundary, ndt, 1-bias, -drift );
        }
    }
}

data {
    int<lower=1> N; // Number of trial-level observations
    int<lower=1> n_conditions; // Number of conditions (congruent and incongruent)
    int<lower=1> n_participants; // Number of participants

    array[n_participants, 2] int participants_trials_slices; // slices TODO
    vector[N] y; // acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    vector[N] condition; // Contrast coded condition: -1 for erroneous and 1 for correct response respectively
    vector[N] pre_acc; // Contrast coded accuracy on previous trial
    vector[N] pre_ne; // centered correct/error negativity on previous trial
    array[N] int<lower=1> participant; // Participant index
}

parameters {
    vector<lower=0, upper=0.3>[n_participants] participants_ter; // Participant-level Non-decision time
    vector<lower=0, upper=3>[n_participants] participants_alpha; // Participant-level Boundary parameter (speed-accuracy tradeoff)
    vector[n_participants] participants_alpha_cond; // Per-participant condition-level boundary adjustment 
    vector[n_participants] participants_delta; // Participant-level drift-rate
    vector[n_participants] participants_delta_cond; // Per-participant condition-level drift-rate adjustment 
        
    real<lower=0> ter; // Hierarchical non-decision time
    real<lower=0, upper=3> alpha; // Hierarchical boundary parameter (speed-accuracy tradeoff)
    real alpha_cond;
    real delta; // Hierarchical drift-rate
    real delta_cond; // Hierarchical drift-rate adjustment 
    
    real<lower=0> ter_sd; // Between-participants variability in non-decision time
    real<lower=0> alpha_sd; // Between-participants variability in boundary parameter (speed-accuracy tradeoff)
    real<lower=0> alpha_cond_sd;
    real<lower=0> delta_sd; // Between-participants variability in drift-rate
    real<lower=0> delta_cond_sd; // Between-participants variability in effect of condition
    
    
    // Non Hierarchical
    real alpha_ne;
    real alpha_pre_acc; 
    real alpha_ne_pre_acc; 
    
    real alpha_ne_cond;
    real alpha_pre_acc_cond; 
    real alpha_ne_pre_acc_cond; 
}

model {

    // ##########
    // Between-participant variability priors
    // ##########
    ter_sd ~ gamma(.3,1);
    alpha_sd ~ gamma(1,1);
    alpha_cond_sd ~ gamma(1,1); // 0.3
    delta_sd ~ gamma(1,1);
    delta_cond_sd ~ gamma(1,1);

    // ##########
    // Hierarchical parameters priors
    // ##########
    ter ~ normal(.1, .2);
    alpha ~ normal(1, 1) T[0, 3];
    alpha_cond ~ normal(0, 1);  // 0.2
    delta ~ normal(0, 2);
    delta_cond ~ normal(0, 2);

    // ##########
    // Non Hierarchical boundary parameters
    // ##########
    alpha_ne ~ normal(0, 0.1); 
    alpha_pre_acc ~ normal(0,0.2);  
    alpha_ne_pre_acc ~ normal(0, 0.1);
    
    alpha_ne_cond ~ normal(0, 0.1);
    alpha_pre_acc_cond ~ normal(0,0.2);  
    alpha_ne_pre_acc_cond ~ normal(0, 0.1);


    // ##########
    // Participant-level DDM parameter priors
    // ##########
    for (p in 1:n_participants) {

        // Participant-level non-decision time
        participants_ter[p] ~ normal(ter, ter_sd) T[0, .3];

        // Participant-level boundary parameter (speed-accuracy tradeoff)
        participants_alpha[p] ~ normal(alpha, alpha_sd) T[0, 3];
        
        participants_alpha_cond[p] ~ normal(alpha_cond, alpha_cond_sd);

        //Participant-level drift rate
        participants_delta[p] ~ normal(delta, delta_sd);
        
        //Participant-level condition_adjustment
        participants_delta_cond[p] ~ normal(delta_cond, delta_cond_sd);  
                
        
        target += participant_level_diffusion_lpdf( y[participants_trials_slices[p][1]:participants_trials_slices[p][2]] | participants_alpha[p], participants_alpha_cond[p], alpha_ne, alpha_pre_acc, alpha_ne_pre_acc, alpha_ne_cond, alpha_pre_acc_cond, alpha_ne_pre_acc_cond, participants_ter[p], 0.5, participants_delta[p], participants_delta_cond[p], condition[participants_trials_slices[p][1]:participants_trials_slices[p][2]], pre_ne[participants_trials_slices[p][1]:participants_trials_slices[p][2]],pre_acc[participants_trials_slices[p][1]:participants_trials_slices[p][2]], (participants_trials_slices[p][2] - participants_trials_slices[p][1] + 1));         
    }
}
"""

In [None]:
# Stan program used for the fit below (passed to stan.build): hierarchical
# Wiener diffusion model in which EVERY boundary coefficient is hierarchical.
#   - Trial boundary = participants_alpha[p]
#     + participants_alpha_cond[p]*condition
#     + participants_alpha_ne[p]*pre_ne + participants_alpha_pre_acc[p]*pre_acc
#     + participants_alpha_ne_pre_acc[p]*pre_ne*pre_acc
#     + participants_alpha_ne_cond[p]*pre_ne*condition
#     + participants_alpha_pre_acc_cond[p]*pre_acc*condition
#     + participants_alpha_ne_pre_acc_cond[p]*pre_ne*pre_acc*condition.
#     Each participants_alpha_* coefficient is drawn from normal(alpha_*, alpha_*_sd).
#   - Trial drift = participants_delta[p] + participants_delta_cond[p]*condition.
#   - Non-decision time is per participant (bounded to [0, 0.3]); the
#     starting-point bias is fixed at 0.5.
# "ndt trick": when |y| <= ndt the likelihood is evaluated at y = +ndt instead
# of the observed value, so such trials are always scored with the y >= 0
# (upper-boundary) density regardless of the response sign.
# NOTE(review): Stan's wiener density is defined for y > ndt; confirm the Stan
# version in use accepts y == ndt, otherwise these trials reject the proposal.
# NOTE(review): nothing keeps the trial-level boundary positive once the
# condition and slope terms are added.
# Data requirements: each participant's trials must be one contiguous, 1-based
# range participants_trials_slices[p] = [first, last]; n_conditions and
# participant are declared but not used by the model.
HDDM_delta_decomposed_ndt_trick_boundary_informed_condition_hdd = """
functions {
    real participant_level_diffusion_lpdf(
        vector y, 
        real boundary,
        real boundary_cond,
        real boundary_ne,
        real boundary_pre_acc,
        real boundary_ne_pre_acc,
        real boundary_ne_cond,
        real boundary_pre_acc_cond,
        real boundary_ne_pre_acc_cond,
        real ndt, 
        real bias, 
        real drift, 
        real drift_cond, 
        vector condition, 
        vector pre_ne,
        vector pre_acc,
        int n_trials
    ) {
        vector[n_trials] participant_level_likelihood;
        
        for (t in 1:n_trials) {
            if (abs(y[t]) - ndt > 0) {
                participant_level_likelihood[t] = diffusion_lpdf(y[t] | boundary  + boundary_cond*condition[t] + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t] + boundary_ne_cond*pre_ne[t]*condition[t] + boundary_pre_acc_cond*pre_acc[t]*condition[t] + boundary_ne_pre_acc_cond*pre_ne[t]*pre_acc[t]*condition[t], ndt, bias, drift + drift_cond*condition[t]);
            } else {
                participant_level_likelihood[t] = diffusion_lpdf(ndt | boundary  + boundary_cond*condition[t] + boundary_ne*pre_ne[t] + boundary_pre_acc*pre_acc[t] + boundary_ne_pre_acc*pre_ne[t]*pre_acc[t] + boundary_ne_cond*pre_ne[t]*condition[t] + boundary_pre_acc_cond*pre_acc[t]*condition[t] + boundary_ne_pre_acc_cond*pre_ne[t]*pre_acc[t]*condition[t], ndt, bias, drift + drift_cond*condition[t]);
            }
        }
        return(sum(participant_level_likelihood));
    }
      
    /* Wiener diffusion log-PDF for a single response (adapted from brms 1.10.2)
    * Arguments:
    *   Y: acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    *   boundary: boundary separation parameter > 0
    *   ndt: non-decision time parameter > 0
    *   bias: initial bias parameter in [0, 1]
    *   drift: drift rate parameter
    * Returns:
    *   a scalar to be added to the log posterior
    */
    real diffusion_lpdf(real Y, real boundary, real ndt, real bias, real drift) {
        if (Y >= 0) {
            return wiener_lpdf( abs(Y) | boundary, ndt, bias, drift );
        } else {
            return wiener_lpdf( abs(Y) | boundary, ndt, 1-bias, -drift );
        }
    }
}

data {
    int<lower=1> N; // Number of trial-level observations
    int<lower=1> n_conditions; // Number of conditions (congruent and incongruent)
    int<lower=1> n_participants; // Number of participants

    array[n_participants, 2] int participants_trials_slices; // slices TODO
    vector[N] y; // acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    vector[N] condition; // Contrast coded condition: -1 for erroneous and 1 for correct response respectively
    vector[N] pre_acc; // Contrast coded accuracy on previous trial
    vector[N] pre_ne; // centered correct/error negativity on previous trial
    array[N] int<lower=1> participant; // Participant index
}

parameters {
    vector<lower=0, upper=0.3>[n_participants] participants_ter; // Participant-level Non-decision time
    vector<lower=0, upper=3>[n_participants] participants_alpha; // Participant-level Boundary parameter (speed-accuracy tradeoff)
    vector[n_participants] participants_alpha_cond; // Per-participant condition-level boundary adjustment
    vector[n_participants] participants_alpha_ne;  
    vector[n_participants] participants_alpha_pre_acc; 
    vector[n_participants] participants_alpha_ne_pre_acc; 
    vector[n_participants] participants_alpha_ne_cond; 
    vector[n_participants] participants_alpha_pre_acc_cond; 
    vector[n_participants] participants_alpha_ne_pre_acc_cond; 
 
    vector[n_participants] participants_delta; // Participant-level drift-rate
    vector[n_participants] participants_delta_cond; // Per-participant condition-level drift-rate adjustment 
        
    real<lower=0> ter; // Hierarchical non-decision time
    real<lower=0, upper=3> alpha; // Hierarchical boundary parameter (speed-accuracy tradeoff)
    real alpha_cond;
    real delta; // Hierarchical drift-rate
    real delta_cond; // Hierarchical drift-rate adjustment 
    
    real alpha_ne;
    real alpha_pre_acc; 
    real alpha_ne_pre_acc; 
    real alpha_ne_cond;
    real alpha_pre_acc_cond; 
    real alpha_ne_pre_acc_cond; 
    
    real<lower=0> ter_sd; // Between-participants variability in non-decision time
    real<lower=0> alpha_sd; // Between-participants variability in boundary parameter (speed-accuracy tradeoff)
    real<lower=0> alpha_cond_sd;
    real<lower=0> alpha_ne_sd;
    real<lower=0> alpha_pre_acc_sd;
    real<lower=0> alpha_ne_pre_acc_sd;
    real<lower=0> alpha_ne_cond_sd;
    real<lower=0> alpha_pre_acc_cond_sd;
    real<lower=0> alpha_ne_pre_acc_cond_sd;

    real<lower=0> delta_sd; // Between-participants variability in drift-rate
    real<lower=0> delta_cond_sd; // Between-participants variability in effect of condition
    
}

model {

    // ##########
    // Between-participant variability priors
    // ##########
    ter_sd ~ gamma(.3,1);
    alpha_sd ~ gamma(1,1);
    alpha_cond_sd ~ gamma(1,1); // 0.3
    
    alpha_ne_sd ~ gamma(1,1); // works quite nice with 0.3, 1 and really nice with (1,1)
    alpha_pre_acc_sd ~ gamma(1,1); // works with 1, 1 and really nice with (1,1)
    alpha_ne_pre_acc_sd ~ gamma(1,1); // works with 0.3, 1 and really nice with (1,1)
    alpha_ne_cond_sd ~ gamma(1,1); // works with 0.3, 1 and really nice with (1,1)
    alpha_pre_acc_cond_sd ~ gamma(1,1); // works with 1, 1 and really nice with (1,1)
    alpha_ne_pre_acc_cond_sd ~ gamma(1,1); // works with 0.3, 1 and really nice with (1,1)

    delta_sd ~ gamma(1,1);
    delta_cond_sd ~ gamma(1,1);

    // ##########
    // Hierarchical parameters priors
    // ##########
    ter ~ normal(.1, .2);
    alpha ~ normal(1, 1) T[0, 3];
    alpha_cond ~ normal(0, 1);  
    delta ~ normal(0, 2);
    delta_cond ~ normal(0, 2);

    alpha_ne ~ normal(0, 0.5); // was 0.5 and works super, with 1 worser
    alpha_pre_acc ~ normal(0, 0.5);  // was 0.5 and works super, with 1 worser
    alpha_ne_pre_acc ~ normal(0, 0.5); // was 0.5 and works super, with 1 worser
    
    alpha_ne_cond ~ normal(0, 0.5); // was 0.5 and works super, with 1 worser
    alpha_pre_acc_cond ~ normal(0, 0.5); // was 0.5 and works super, with 1 worser 
    alpha_ne_pre_acc_cond ~ normal(0, 0.5); // was 0.5 and works super, with 1 worser


    // ##########
    // Participant-level DDM parameter priors
    // ##########
    for (p in 1:n_participants) {

        // Participant-level non-decision time
        participants_ter[p] ~ normal(ter, ter_sd) T[0, .3];

        // Participant-level boundary parameter (speed-accuracy tradeoff)
        participants_alpha[p] ~ normal(alpha, alpha_sd) T[0, 3];
        
        participants_alpha_cond[p] ~ normal(alpha_cond, alpha_cond_sd);

        //Participant-level drift rate
        participants_delta[p] ~ normal(delta, delta_sd);
        
        //Participant-level condition_adjustment
        participants_delta_cond[p] ~ normal(delta_cond, delta_cond_sd);  
        
        participants_alpha_ne[p] ~ normal(alpha_ne, alpha_ne_sd);
        participants_alpha_pre_acc[p] ~ normal(alpha_pre_acc, alpha_pre_acc_sd);
        participants_alpha_ne_pre_acc[p] ~ normal(alpha_ne_pre_acc, alpha_ne_pre_acc_sd);
        participants_alpha_ne_cond[p] ~ normal(alpha_ne_cond, alpha_ne_cond_sd);
        participants_alpha_pre_acc_cond[p] ~ normal(alpha_pre_acc_cond, alpha_pre_acc_cond_sd);
        participants_alpha_ne_pre_acc_cond[p] ~ normal(alpha_ne_pre_acc_cond, alpha_ne_pre_acc_cond_sd);

                
        
        target += participant_level_diffusion_lpdf( y[participants_trials_slices[p][1]:participants_trials_slices[p][2]] | participants_alpha[p], participants_alpha_cond[p], participants_alpha_ne[p], participants_alpha_pre_acc[p], participants_alpha_ne_pre_acc[p], participants_alpha_ne_cond[p], participants_alpha_pre_acc_cond[p], participants_alpha_ne_pre_acc_cond[p], participants_ter[p], 0.5, participants_delta[p], participants_delta_cond[p], condition[participants_trials_slices[p][1]:participants_trials_slices[p][2]], pre_ne[participants_trials_slices[p][1]:participants_trials_slices[p][2]],pre_acc[participants_trials_slices[p][1]:participants_trials_slices[p][2]], (participants_trials_slices[p][2] - participants_trials_slices[p][1] + 1));         
    }
}
"""

## Read and prepare data

In [None]:
# Load the trial-level data; 'Unnamed: 0' is presumably a saved pandas index.
df = (
    pd.read_csv('twentythree_participants_post_eeg_many_test_set.csv')
    .drop(columns='Unnamed: 0')
)

# Which columns contain missing values, then the frame itself.
display(df.isna().any())
display(df)

Remove trials with NaNs

In [None]:
# Drop every trial (row) that has a missing value in any column.
df_no_nans = df.dropna(axis=0, how='any')

# Confirm no column still contains missing values.
display(df_no_nans.isna().any())
display(df_no_nans)

Remove trials with RT <= 100 ms so the model can converge (such fast responses conflict with the non-decision-time parameter)

In [None]:
# Keep only trials slower than `threshold` (seconds, same unit as 'rt').
threshold = 0.1
# .copy() so later column assignments act on an independent frame instead of a
# filtered view of df_no_nans (which would raise SettingWithCopyWarning).
df_rts_truncated = df_no_nans[df_no_nans['rt'] > threshold].copy()

df_rts_truncated

Optionally keep only trials that belong to CCXP sequences (currently disabled)

In [None]:
# df_rts_truncated = df_rts_truncated[df_rts_truncated['is_in_sequence'] == True]
# df_rts_truncated

Prepare flat (1-D) arrays for Stan, plus the range of trials that belongs to each participant

In [None]:
# Flatten the trial-level columns into 1-D arrays for Stan.
y = df_rts_truncated['y'].to_numpy()
condition = df_rts_truncated['condition'].to_numpy()
pre_acc = df_rts_truncated['pre_acc'].to_numpy()
participant_index = df_rts_truncated['participant_index'].to_numpy()

n_participants = len(np.unique(participant_index))
n_conditions = len(np.unique(condition))

# participants_trials_slices[p] = [first, last]: 1-based, inclusive row range of
# participant p's trials, used by the Stan model to slice its data vectors.
# Slicing is only valid when each participant's trials are contiguous.
#
# pre_ne: previous-trial EEG measure (pre_ne_FCz) z-scored within participant.
# Values are written back at that participant's own row positions so they stay
# aligned with y / condition / pre_acc whatever order participants appear in.
participants_trials_slices = []
raw_ne = df_rts_truncated['pre_ne_FCz'].to_numpy(dtype=float)
pre_ne = np.empty_like(raw_ne)
for index in np.unique(participant_index):
    indices = np.flatnonzero(participant_index == index)
    if indices[-1] - indices[0] + 1 != indices.size:
        raise ValueError(
            f"Trials of participant {index} are not contiguous; sort the data "
            "by participant_index before building the Stan slices."
        )
    participants_trials_slices.append([indices[0] + 1, indices[-1] + 1])

    participants_ne = raw_ne[indices]
    centered = participants_ne - np.mean(participants_ne)
    ne_sd = np.std(participants_ne)
    # A constant signal (e.g. a single trial) would give 0/0 -> NaN; keep it centred.
    pre_ne[indices] = centered / ne_sd if ne_sd > 0 else centered

participants_trials_slices = np.array(participants_trials_slices)
# Column name kept for compatibility although the values are z-scored.
# assign() returns a new frame, so no SettingWithCopyWarning on a filtered view.
df_rts_truncated = df_rts_truncated.assign(pre_ne_FCz_centered=pre_ne)

Check the distributions of the standardized EEG predictor

In [None]:
sns.histplot(data=pre_ne)  # pooled distribution over all participants

In [None]:
# One histogram (+ KDE) of the EEG predictor per participant ID, coloured by
# previous-trial accuracy; rows are sorted so panels appear in ID order.
ne_grid_options = dict(col="ID", col_wrap=2, sharex=False, sharey=False, aspect=2)
g = sns.FacetGrid(df_rts_truncated.sort_values(['ID']), **ne_grid_options)

g.map_dataframe(sns.histplot, x="pre_ne_FCz_centered", hue='pre_acc', kde=True)

In [None]:
# Sanity check of the sizes handed to Stan.
for label, shape in (("y data", y.shape), ("condition data", condition.shape)):
    print(f"Shape of {label}: {shape}")
print(f"Number of participants: {n_participants}")
print(f"Number of conditions: {n_conditions}")
print(f"Participants trial slices shape: {participants_trials_slices.shape}")

In [None]:
# Data passed to Stan; the keys must match the names in the model's data block.
data_2d = dict(
    N=len(y),
    participants_trials_slices=participants_trials_slices,
    n_conditions=n_conditions,
    n_participants=n_participants,
    y=y,
    condition=condition,
    pre_ne=pre_ne,
    pre_acc=pre_acc,
    participant=participant_index,
)

## Build and fit the model

In [None]:
# Compile the fully hierarchical boundary-regression model against the data.
model_code = HDDM_delta_decomposed_ndt_trick_boundary_informed_condition_hdd
posterior = stan.build(model_code, data=data_2d, random_seed=42)

In [None]:
# Sampler settings.
num_chains = 4
warmup = 2000
num_samples = 1000
thin = 1

# Seeded generator (same seed as stan.build) so the initial values are reproducible.
rng = np.random.default_rng(42)

# Fastest absolute RT of each participant. groupby sorts its keys, so the order
# matches np.unique(participant_index), i.e. the order of participants_trials_slices.
min_rt = (
    df_rts_truncated['y'].abs()
    .groupby(df_rts_truncated['participant_index'])
    .min()
    .to_numpy()
)
# Participant-level non-decision times must start below every RT of that
# participant (at most half the fastest RT) AND inside the model's declared
# bounds <lower=0, upper=0.3>; an init outside the bounds is rejected by Stan.
ter_init_upper = np.minimum(min_rt / 2, 0.3)

# One dict of initial values per chain; keys are the model's parameter names.
initials = []
for _ in range(num_chains):
    initials.append({
        'ter_sd': rng.uniform(.01, .2),
        'alpha_sd': rng.uniform(.01, 1.),
        'alpha_cond_sd': rng.uniform(.01, 1.),
        'delta_sd': rng.uniform(.1, 3.),
        'delta_cond_sd': rng.uniform(.1, 3.),

        'alpha_ne_sd': rng.uniform(.01, 1),
        'alpha_pre_acc_sd': rng.uniform(.01, 1),
        'alpha_ne_pre_acc_sd': rng.uniform(.01, 1),
        'alpha_ne_cond_sd': rng.uniform(.01, 1),
        'alpha_pre_acc_cond_sd': rng.uniform(.01, 1),
        'alpha_ne_pre_acc_cond_sd': rng.uniform(.01, 1),

        'ter': rng.uniform(0.05, .3),
        'alpha': rng.uniform(1, 2),
        'alpha_cond': rng.uniform(-.5, .5),
        'delta': rng.uniform(-4., 4.),
        'delta_cond': rng.uniform(-4., 4.),

        'alpha_ne': rng.uniform(-.05, .05),
        'alpha_pre_acc': rng.uniform(-0.1, .1),
        'alpha_ne_pre_acc': rng.uniform(-.05, .05),
        'alpha_ne_cond': rng.uniform(-.05, .05),
        'alpha_pre_acc_cond': rng.uniform(-0.1, .1),
        'alpha_ne_pre_acc_cond': rng.uniform(-.05, .05),

        # Element-wise upper bound per participant (see ter_init_upper above).
        'participants_ter': rng.uniform(0., ter_init_upper),
        'participants_alpha': rng.uniform(1, 2., size=n_participants),
        'participants_alpha_cond': rng.uniform(-0.5, .5, size=n_participants),
        'participants_delta': rng.uniform(-4., 4., size=n_participants),
        'participants_delta_cond': rng.uniform(-4., 4., size=n_participants),

        'participants_alpha_ne': rng.uniform(-.05, .05, size=n_participants),
        'participants_alpha_pre_acc': rng.uniform(-0.1, .1, size=n_participants),
        'participants_alpha_ne_pre_acc': rng.uniform(-.05, .05, size=n_participants),
        'participants_alpha_ne_cond': rng.uniform(-.05, .05, size=n_participants),
        'participants_alpha_pre_acc_cond': rng.uniform(-0.1, .1, size=n_participants),
        'participants_alpha_ne_pre_acc_cond': rng.uniform(-.05, .05, size=n_participants),
    })

print(min_rt)

In [None]:
# Draw posterior samples; warm-up draws are discarded and every `thin`-th draw kept.
sampling_options = {
    'num_chains': num_chains,
    'num_samples': num_samples,
    'num_warmup': warmup,
    'save_warmup': False,
    'init': initials,
    'num_thin': thin,
}
fit = posterior.sample(**sampling_options)

Extract samples and chains

In [None]:
fit_df = fit.to_frame()

# Label every draw with its chain. The label pattern (0, 1, ..., C-1, 0, 1, ...)
# assumes to_frame() interleaves chains draw by draw.
# See: https://github.com/stan-dev/pystan/pull/333
# The per-chain draw count is taken from the frame itself rather than from
# num_samples // thin, which need not equal the saved count when thin does not
# divide num_samples.
if len(fit_df) % num_chains:
    raise ValueError(
        f"{len(fit_df)} draws cannot be split evenly over {num_chains} chains"
    )
samples_saved = len(fit_df) // num_chains
fit_df.insert(0, "chain__", np.tile(np.arange(num_chains), samples_saved))

fit_df.head()

## Check model

### Summary of the results

In [None]:
# Names of all constrained model parameters (sampler diagnostics excluded).
variables_to_track = [*posterior.constrained_param_names]

In [None]:
# overall summary
fit_df.loc[:, variables_to_track].describe().transpose()

In [None]:
# summary by chain
per_chain_draws = fit_df.groupby(['chain__'])[variables_to_track]
per_chain_draws.describe().transpose()

Posterior distributions and per-chain trace plots

In [None]:
# Long format: one row per (draw, parameter) for the tracked model parameters;
# all other columns (chain label, sampler diagnostics) are carried along as ids.
tracked = set(variables_to_track)  # built once, not once per column
melted_df = pd.melt(
    fit_df,
    id_vars=[column for column in fit_df.columns if column not in tracked],
    var_name='parameter_name',
    value_name='draws',
)
# Draw index within each (parameter, chain); melt keeps the original row order,
# so this is the x-axis of the trace plots without relying on samples_saved.
melted_df['draw_index'] = melted_df.groupby(['parameter_name', 'chain__']).cumcount()

facet_options = dict(
    col="parameter_name",
    col_wrap=3,
    sharex=False,
    sharey=False,
    aspect=1.5,
    hue='chain__',
)

# Posterior histograms (+ KDE), one colour per chain.
g = sns.FacetGrid(melted_df, **facet_options)
g.map_dataframe(sns.histplot, x="draws", kde=True)
g.add_legend()
# plt.savefig('hddm_parameters_posteriors_trick_cutoff.png', bbox_inches='tight')

# Trace plots: sampled value against draw index, one line per chain.
g = sns.FacetGrid(melted_df, **facet_options)
g.map_dataframe(sns.lineplot, x="draw_index", y="draws")
g.add_legend()
# plt.savefig('hddm_chains_trick_cutoff.png', bbox_inches='tight')

plt.show()

### Diagnostics

In [None]:
# adapted from https://github.com/mdnunez/pyhddmjags/tree/master
def diagnostic(insamples):
    """
    Returns two versions of Rhat (measure of convergence, less is better with an approximate
    1.10 cutoff) and Neff, number of effective samples). Note that 'rhat' is more diagnostic than 'oldrhat' according to 
    Gelman et al. (2014).

    Reference for preferred Rhat calculation (split chains) and number of effective sample calculation: 
        Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A. & Rubin, D. B. (2014). 
        Bayesian data analysis (Third Edition). CRC Press:
        Boca Raton, FL

    Reference for original Rhat calculation:
        Gelman, A., Carlin, J., Stern, H., & Rubin D., (2004).
        Bayesian Data Analysis (Second Edition). Chapman & Hall/CRC:
        Boca Raton, FL.


    Parameters
    ----------
    insamples: dic
        Sampled values of monitored variables as a dictionary where keys
        are variable names and values are numpy arrays with shape:
        (dim_1, dim_n, iterations, chains). dim_1, ..., dim_n describe the
        shape of variable in JAGS model.

    Returns
    -------
    dict:
        rhat, oldrhat, neff, posterior mean, and posterior std for each variable. Prints maximum Rhat and minimum Neff across all variables
    """

    result = {}  # Initialize dictionary
    maxrhatsold = np.zeros((len(insamples.keys())), dtype=float)
    maxrhatsnew = np.zeros((len(insamples.keys())), dtype=float)
    minneff = np.ones((len(insamples.keys())), dtype=float)*np.inf
    allkeys ={} # Initialize dictionary
    keyindx = 0
    for key in insamples.keys():
        if key[0] != '_':
            result[key] = {}

            possamps = insamples[key]

            # Number of chains
            nchains = possamps.shape[-1]

            # Number of samples per chain
            nsamps = possamps.shape[-2]

            # Number of variables per key
            nvars = np.prod(possamps.shape[0:-2])

            # Reshape data
            allsamps = np.reshape(possamps, possamps.shape[:-2] + (nchains * nsamps,))

            # Reshape data to preduce R_hatnew
            possampsnew = np.empty(possamps.shape[:-2] + (int(nsamps/2), nchains * 2,))
            newc=0
            for c in range(nchains):
                possampsnew[...,newc] = np.take(np.take(possamps,np.arange(0,int(nsamps/2)),axis=-2),c,axis=-1)
                possampsnew[...,newc+1] = np.take(np.take(possamps,np.arange(int(nsamps/2),nsamps),axis=-2),c,axis=-1)
                newc += 2

            # Index of variables
            varindx = np.arange(nvars).reshape(possamps.shape[0:-2])

            # Reshape data
            alldata = np.reshape(possamps, (nvars, nsamps, nchains))

            # Mean of each chain for rhat
            chainmeans = np.mean(possamps, axis=-2)
            # Mean of each chain for rhatnew
            chainmeansnew = np.mean(possampsnew, axis=-2)
            # Global mean of each parameter for rhat
            globalmean = np.mean(chainmeans, axis=-1)
            globalmeannew = np.mean(chainmeansnew, axis=-1)
            result[key]['mean'] = globalmean
            result[key]['std'] = np.std(allsamps, axis=-1)
            globalmeanext = np.expand_dims(
                globalmean, axis=-1)  # Expand the last dimension
            globalmeanext = np.repeat(
                globalmeanext, nchains, axis=-1)  # For differencing
            globalmeanextnew = np.expand_dims(
                globalmeannew, axis=-1)  # Expand the last dimension
            globalmeanextnew = np.repeat(
                globalmeanextnew, nchains*2, axis=-1)  # For differencing
            # Between-chain variance for rhat
            between = np.sum(np.square(chainmeans - globalmeanext),
                             axis=-1) * nsamps / (nchains - 1.)
            # Mean of the variances of each chain for rhat
            within = np.mean(np.var(possamps, axis=-2), axis=-1)
            # Total estimated variance for rhat
            totalestvar = (1. - (1. / nsamps)) * \
                          within + (1. / nsamps) * between
            # Rhat (original Gelman-Rubin statistic)
            temprhat = np.sqrt(totalestvar / within)
            maxrhatsold[keyindx] = np.nanmax(temprhat) # Ignore NANs
            allkeys[keyindx] = key
            result[key]['oldrhat'] = temprhat
            # Between-chain variance for rhatnew
            betweennew = np.sum(np.square(chainmeansnew - globalmeanextnew),
                                axis=-1) * (nsamps/2) / ((nchains*2) - 1.)
            # Mean of the variances of each chain for rhatnew
            withinnew = np.mean(np.var(possampsnew, axis=-2), axis=-1)
            # Total estimated variance
            totalestvarnew = (1. - (1. / (nsamps/2))) * \
                             withinnew + (1. / (nsamps/2)) * betweennew
            # Rhatnew (Gelman-Rubin statistic from Gelman et al., 2013)
            temprhatnew = np.sqrt(totalestvarnew / withinnew)
            maxrhatsnew[keyindx] = np.nanmax(temprhatnew) # Ignore NANs
            result[key]['rhat'] = temprhatnew
            # Number of effective samples from Gelman et al. (2013) 286-288
            neff = np.empty(possamps.shape[0:-2])
            for v in range(0, nvars):
                whereis = np.where(varindx == v)
                rho_hat = []
                rho_hat_even = 0
                rho_hat_odd = 0
                t = 2
                while (t < nsamps - 2) & (float(rho_hat_even) + float(rho_hat_odd) >= 0):
                    # above equation (11.7) in Gelman et al., 2013
                    variogram_odd = np.mean(np.mean(np.power(alldata[v,(t-1):nsamps,:] - alldata[v,0:(nsamps-t+1),:],2),axis=0))
                    
                    # Equation (11.7) in Gelman et al., 2013
                    rho_hat_odd = 1 - np.divide(variogram_odd, 2*totalestvar[whereis]).item()
                    rho_hat.append(rho_hat_odd)
                    
                    # above equation (11.7) in Gelman et al., 2013
                    variogram_even = np.mean(np.mean(np.power(alldata[v,t:nsamps,:] - alldata[v,0:(nsamps-t),:],2),axis=0)) 
                    
                    # Equation (11.7) in Gelman et al., 2013
                    rho_hat_even = 1 - np.divide(variogram_even, 2*totalestvar[whereis]).item() 
                    rho_hat.append(rho_hat_even)
                    
                    t += 2
                rho_hat = np.asarray(rho_hat)
                # Equation (11.8) in Gelman et al., 2013
                neff[whereis] = np.divide(nchains*nsamps, 1 + 2*np.sum(rho_hat)) 
            result[key]['neff'] = np.round(neff)
            minneff[keyindx] = np.nanmin(np.round(neff))
            keyindx += 1

            # Geweke statistic?
    # print("Maximum old Rhat was %3.2f for variable %s" % (np.max(maxrhatsold),allkeys[np.argmax(maxrhatsold)]))
    maxrhatkey = allkeys[np.argmax(maxrhatsnew)]
    maxrhatindx = np.unravel_index( np.argmax(result[maxrhatkey]['rhat']) , result[maxrhatkey]['rhat'].shape)
    print("Maximum Rhat was %3.2f for variable %s at index %s" % (np.max(maxrhatsnew), maxrhatkey, maxrhatindx))
    minneffkey = allkeys[np.argmin(minneff)]
    minneffindx = np.unravel_index( np.argmin(result[minneffkey]['neff']) , result[minneffkey]['neff'].shape)
    print("Minimum number of effective samples was %d for variable %s at index %s" % (np.min(minneff), minneffkey, minneffindx))
    return result

In [None]:
def _diagnostic_column_name(key, index):
    """Return ``key`` followed by the 1-based entries of ``index``, dot-separated.

    ``index == ()`` (scalar parameter) yields just ``key``; ``(0,)`` yields
    ``key.1``; ``(0, 2)`` yields ``key.1.3``.
    """
    return '.'.join([str(key), *(str(i + 1) for i in index)])


def models_diagnostics_dict_to_df(models_diagnostics):
    """Flatten the nested output of ``diagnostic`` into a single DataFrame.

    Parameters
    ----------
    models_diagnostics : dict
        ``{parameter_name: {statistic_name: ndarray}}`` where every statistic
        array (e.g. 'mean', 'std', 'oldrhat', 'rhat', 'neff') has the shape
        of the parameter. Parameters of any rank are supported, including
        scalars (shape ``()``).

    Returns
    -------
    pandas.DataFrame
        One row per statistic, one column per parameter element. Columns are
        named ``name`` for scalars and ``name.i``, ``name.i.j``, ... (1-based,
        Stan-style) otherwise, in C order. An empty input gives an empty
        DataFrame.
    """
    frames = []
    for key, main_data in models_diagnostics.items():
        stat_names = list(main_data.keys())
        # All statistics share the parameter's shape; 'mean' is always present.
        param_shape = np.shape(main_data['mean'])
        frames.append(
            pd.DataFrame(
                {
                    _diagnostic_column_name(key, index): [
                        # np.asarray(...)[()] also works for 0-d (scalar) arrays
                        np.asarray(main_data[stat])[index] for stat in stat_names
                    ]
                    for index in np.ndindex(*param_shape)
                },
                index=stat_names,
            )
        )

    if not frames:
        return pd.DataFrame()
    # A single concat instead of growing the frame inside the loop
    return pd.concat(frames, axis=1)

In [None]:
def flip_stan_out(fit, parameters=None, n_chains=None):
    """Split the flattened PyStan draws of each parameter into per-chain arrays.

    Parameters
    ----------
    fit : mapping
        Fitted model; ``fit[name]`` must return an ndarray whose last axis
        holds the draws of all chains flattened together. PyStan 3 ``Fit``
        objects flatten them chain by chain (all draws of chain 0, then all
        draws of chain 1, ...), i.e. in Fortran order.
    parameters : iterable of str, optional
        Names of the parameters to extract. When None nothing is extracted
        and an empty dict is returned.
    n_chains : int, optional
        Number of chains. Defaults to ``fit.num_chains``.

    Returns
    -------
    dict
        ``{parameter: ndarray of shape (*param_dims, n_samples, n_chains)}``.
    """
    results = {}
    if parameters is None:
        return results

    if n_chains is None:
        n_chains = fit.num_chains

    for parameter in parameters:
        print(f"Processing: {parameter} ")
        samples = np.asarray(fit[parameter])

        # The flat draw axis is chain-major, so draw i of chain j sits at
        # position i + n_samples * j. That is a Fortran-order split; a C-order
        # split would interleave all chains into every "chain" and make Rhat
        # blind to non-convergence. -1 lets numpy infer n_samples.
        results[parameter] = samples.reshape(
            samples.shape[:-1] + (-1, n_chains),
            order='F',
        )

    return results

In [None]:
# Map every Stan parameter name to its posterior samples, reshaped by
# flip_stan_out to (*param_dims, n_samples, n_chains).
parameters = fit.param_names
extracted_samples_dict = flip_stan_out(fit, parameters=parameters)

Show model diagnostics

In [None]:
# Per-parameter diagnostics (mean, std, oldrhat, rhat, neff); diagnostic()
# also prints the maximum Rhat and the minimum number of effective samples.
models_diagnostics = diagnostic(extracted_samples_dict)
# Statistics as rows, parameter elements as columns
models_diagnostics_df = models_diagnostics_dict_to_df(models_diagnostics)

# To keep the table, uncomment:
# models_diagnostics_df.T.to_csv('hddm_model_trick_cutoff_diagnostics.csv')

# Last expression of the cell, so Jupyter renders it (one row per parameter element)
models_diagnostics_df.T

### Posterior distribution plots

In [None]:
# adapted from https://github.com/mdnunez/pyhddmjags/tree/master
def jellyfish(possamps):  # jellyfish plots
    """Plot posterior samples as a jellyfish plot on the current axes.

    Each parameter element gets one row (y = 1, 2, ...) showing its
    posterior density mirrored over the row's horizontal axis. Two credible
    intervals are drawn: the 99% interval (.5th to 99.5th percentile, thin
    teal outline) and the 95% interval (2.5th to 97.5th percentile, thick
    blue outline). Inside the 95% band, three markers are drawn:
    - the approximate mode (argmax of the KDE on the plotted grid, orange bar);
    - the median (black dot);
    - the mean (teal star).

    Drawing goes to the current matplotlib axes; no figure is created,
    shown or saved here.

    Parameters
    ----------
    possamps : ndarray
        Posterior samples. The last dimension indexes chains, the second to
        last indexes samples within a chain, and any leading dimensions give
        the shape of the parameter (none for a scalar parameter). Samples of
        all chains are pooled.
    """
    # Leading dimensions describe the parameter; the last two are (samples, chains)
    param_shape = possamps.shape[:-2]
    nsamps, nchains = possamps.shape[-2:]

    # int(): np.prod(()) is the float 1.0, which reshape rejects (scalar parameters)
    nvars = int(np.prod(param_shape))

    # One row per parameter element (C order) holding all pooled samples
    alldata = np.reshape(possamps, (nvars, nsamps * nchains))

    # Plot properties; index 0 -> 99% interval, index 1 -> 95% interval
    line_widths = np.array([2, 5])
    teal = np.array([0, .7, .7])
    blue = np.array([0, 0, 1])
    orange = np.array([1, .3, 0])
    colors = [teal, blue]

    # Tick 0 stays unlabeled so that tick v + 1 carries the label of element v
    ylabels = ['']

    # np.ndindex yields parameter indices in the same C order as alldata's rows
    for v, index in enumerate(np.ndindex(*param_shape)):
        # Label such as '_0_3' (0-based indices); empty for a scalar parameter
        ylabels.append(''.join('_%i' % i for i in index))
        samples = alldata[v, :]

        # Posterior density estimate and the interval endpoints
        kde = stats.gaussian_kde(samples)
        bounds = stats.scoreatpercentile(samples, (.5, 2.5, 97.5, 99.5))
        for b in range(0, 2):
            # b == 0: .5th-99.5th percentile; b == 1: 2.5th-97.5th percentile
            x = np.linspace(bounds[b], bounds[-1 - b], 100)
            p = kde(x)

            # Scale the density so its peak spans +/- .25 around row v + 1
            maxp = np.max(p)
            upper = .25 * p / maxp + v + 1
            lower = -.25 * p / maxp + v + 1
            lines = plt.plot(x, upper, x, lower)
            plt.setp(lines, color=colors[b], linewidth=line_widths[b])
            if b == 1:
                # Mark the approximate mode (grid point of highest density)
                wheremaxp = np.argmax(p)
                mmode = plt.plot(np.array([1., 1.]) * x[wheremaxp],
                                 np.array([lower[wheremaxp], upper[wheremaxp]]))
                plt.setp(mmode, linewidth=3, color=orange)
                # Mark median
                mmedian = plt.plot(np.median(samples), v + 1, 'ko')
                plt.setp(mmedian, markersize=10, color=[0., 0., 0.])
                # Mark mean
                mmean = plt.plot(np.mean(samples), v + 1, '*')
                plt.setp(mmean, markersize=10, color=teal)

    # Fix the tick positions first, then label them (matplotlib requires this order)
    ax = plt.gca()
    ax.set_yticks(np.arange(0, nvars + 1))
    ax.set_yticklabels(ylabels)

In [None]:
# Posterior distributions: one jellyfish plot per Stan parameter, each drawn
# into a fresh figure and shown immediately.
for parameter in fit.param_names:
    plt.figure()
    jellyfish(extracted_samples_dict[parameter])
    plt.title(f'Posterior distributions of the {parameter}')
    # Uncomment to also write each figure to disk:
    # plt.savefig(f'hddm_distributions_trick_cutoff{parameter}.png', bbox_inches='tight')
    plt.show()