# DDM stan model fitting

Imports

In [None]:
import pandas as pd
import numpy as np
import stan
import nest_asyncio
import seaborn as sns
import matplotlib.pyplot as plt
from scipy import stats

# nest_asyncio allows nested asyncio event loops, so pystan can run inside the notebook's already-running loop
nest_asyncio.apply()

In [None]:
# Save figures at print resolution
plt.rcParams.update({'savefig.dpi': 300})

## Stan model code

In [None]:
# Stan program: participant-level Wiener diffusion model with one drift rate
# per participant and condition. `condition` indexes the columns of the
# drift-rate matrix, so it must hold 1-based condition indices.
DDM_code = """
functions {
  /* Wiener diffusion log-PDF for a single response (adapted from brms 1.10.2)
   * Arguments:
   *   Y: acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
   *   boundary: boundary separation parameter > 0
   *   ndt: non-decision time parameter > 0
   *   bias: initial bias parameter in [0, 1]
   *   drift: drift rate parameter
   * Returns:
   *   a scalar to be added to the log posterior
   */
   real diffusion_lpdf(real Y, real boundary,
                              real ndt, real bias, real drift) {

    // The sign of Y only encodes the response; wiener_lpdf itself needs a positive RT.
    if (Y >= 0) {
        return wiener_lpdf( abs(Y) | boundary, ndt, bias, drift );
    } else {
        // Incorrect response: mirror start point and drift to evaluate the other boundary
        return wiener_lpdf( abs(Y) | boundary, ndt, 1-bias, -drift );
    }

   }
}

data {
    int<lower=1> N; // Number of trial-level observations
    int<lower=1> n_conditions; // Number of conditions (congruent and incongruent)
    int<lower=1> n_participants; // Number of participants

    array[N] real y; // acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    array[N] int<lower=1, upper=n_participants> participant; // Participant index
    array[N] int<lower=1, upper=n_conditions> condition; // Condition index
}

parameters {
    vector<lower=0, upper=0.3>[n_participants] participants_ter; // Participant-level Non-decision time
    vector<lower=0, upper=3>[n_participants] participants_alpha; // Participant-level Boundary parameter (speed-accuracy tradeoff)
    vector<lower=0, upper=1>[n_participants] participants_beta; // Participant-level Start point bias towards choice A
    matrix[n_participants,n_conditions] participants_condition_delta; // Participant-level and condition-level drift rate to choice A
}

model {
    // ##########
    // Participant-level DDM parameter priors
    // ##########
    for (p in 1:n_participants) {

        // Participant-level non-decision time
        participants_ter[p] ~ normal(.05, .2) T[0, .3];

        // Participant-level boundary parameter (speed-accuracy tradeoff)
        participants_alpha[p] ~ normal(1., 1.) T[0, 3];

        //Participant-level start point bias towards choice A
        participants_beta[p] ~ normal(.5, .1) T[0, 1];

        //Participant-level and condition-level drift rate
        for (c in 1:n_conditions) {
            participants_condition_delta[p,c] ~ normal(0., 1.); // Participant-level and condition-level drift rate
        }
    }

    // Wiener likelihood
    for (i in 1:N) {

        target += diffusion_lpdf( y[i] | participants_alpha[participant[i]], participants_ter[participant[i]], participants_beta[participant[i]], participants_condition_delta[participant[i], condition[i]]);
    }
}
"""

In [None]:
# Stan program: participant-level Wiener diffusion model whose drift rate is
# decomposed into a participant baseline plus a per-participant adjustment
# that is multiplied by the (contrast-coded) condition value.
DDM_delta_decomposed_code = """
functions {
  /* Wiener diffusion log-PDF for a single response (adapted from brms 1.10.2)
   * Arguments:
   *   Y: acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
   *   boundary: boundary separation parameter > 0
   *   ndt: non-decision time parameter > 0
   *   bias: initial bias parameter in [0, 1]
   *   drift: drift rate parameter
   * Returns:
   *   a scalar to be added to the log posterior
   */
   real diffusion_lpdf(real Y, real boundary,
                              real ndt, real bias, real drift) {

    // The sign of Y only encodes the response; wiener_lpdf itself needs a positive RT.
    if (Y >= 0) {
        return wiener_lpdf( abs(Y) | boundary, ndt, bias, drift );
    } else {
        // Incorrect response: mirror start point and drift to evaluate the other boundary
        return wiener_lpdf( abs(Y) | boundary, ndt, 1-bias, -drift );
    }

   }
}

data {
    int<lower=1> N; // Number of trial-level observations
    int<lower=1> n_conditions; // Number of conditions (congruent and incongruent)
    int<lower=1> n_participants; // Number of participants

    array[N] real y; // acc*rt in seconds (negative and positive RTs for incorrect and correct responses respectively)
    array[N] int condition; // Contrast coded condition: -1 for erroneous and 1 for correct response respectively
    array[N] int<lower=1, upper=n_participants> participant; // Participant index
}

parameters {
    vector<lower=0, upper=0.3>[n_participants] participants_ter; // Participant-level Non-decision time
    vector<lower=0, upper=3>[n_participants] participants_alpha; // Participant-level Boundary parameter (speed-accuracy tradeoff)
    vector<lower=0, upper=1>[n_participants] participants_beta; // Participant-level Start point bias towards choice A
    vector[n_participants] participants_delta; // Participant-level drift-rate
    vector[n_participants] participants_delta_theta; // Per-participant condition-level drift-rate adjustment
}

model {
    // ##########
    // Participant-level DDM parameter priors
    // ##########
    for (p in 1:n_participants) {

        // Participant-level non-decision time
        participants_ter[p] ~ normal(.1, .2) T[0, .3];

        // Participant-level boundary parameter (speed-accuracy tradeoff)
        participants_alpha[p] ~ normal(1., 1.) T[0, 3];

        //Participant-level start point bias towards choice A
        participants_beta[p] ~ normal(.5, .2) T[0, 1];

        //Participant-level drift rate
        participants_delta[p] ~ normal(0, 1);

        //Participant-level condition_adjustment
        participants_delta_theta[p] ~ normal(0, 1);

    }

    // Wiener likelihood
    for (i in 1:N) {

        target += diffusion_lpdf( y[i] | participants_alpha[participant[i]], participants_ter[participant[i]], participants_beta[participant[i]], participants_delta[participant[i]] + participants_delta_theta[participant[i]]*condition[i]);
    }
}
"""

## Read and prepare data

In [None]:
# Load the trial-level data, then discard the 'Unnamed: 0' column
df = pd.read_csv('two_participants_test_set.csv')
df = df.drop(columns='Unnamed: 0')

# Quick look: which columns contain missing values, and the first rows
display(df.isna().any())
display(df.head())

Remove trials with RT ≤ 100 ms so that the model can converge (such short RTs are problematic for the non-decision time parameter)

In [None]:
# Keep only the trials whose RT exceeds 0.1 s
df_rts_truncated = df.loc[df['rt'] > 0.1]

Prepare data for Stan

In [None]:
# Trial-level inputs for Stan, taken from the truncated data
y = df_rts_truncated['y'].to_numpy()
participant = df_rts_truncated['participant_index'].to_numpy()
condition = df_rts_truncated['condition'].to_numpy()
# NOTE(review): DDM_code indexes by condition, so it presumably needs a
# 1-based 'condition_index' column here instead - confirm before using it.

# Design sizes are the numbers of distinct values present in the data
n_participants = np.unique(participant).size
n_conditions = np.unique(condition).size

print(f"Number of participants: {n_participants}\nNumber of conditions: {n_conditions}")

In [None]:
# Data passed to the Stan program; the keys are the names declared in its `data` block
data = dict(
    N=len(y),
    n_conditions=n_conditions,
    n_participants=n_participants,
    y=y,
    participant=participant,
    condition=condition,
)

## Build and fit the model

In [None]:
# Sampler settings: number of chains, warmup iterations and draws per chain.
# These names are read again by the extraction and plotting cells below.
num_chains = 4
warmup = 1000
num_samples = 10000

# todo - chains' init
# Build the decomposed-drift Stan model against the prepared data (fixed seed)
posterior = stan.build(DDM_delta_decomposed_code, data=data, random_seed=42)
posterior
# fit = posterior.sample(num_chains=num_chains, num_samples=num_samples, num_warmup = warmup)

In [None]:
# Warmup draws are deliberately not saved: the cells below (chain labelling,
# trace plots, reshaping into (samples, chains)) assume exactly `num_samples`
# draws per chain, and warmup draws do not belong in posterior summaries.
fit = posterior.sample(num_chains=num_chains, num_samples=num_samples, num_warmup=warmup)
fit

Extract samples and chains

In [None]:
fit_df = fit.to_frame()

# Add the chain number of every draw to the dataframe. See: link_to_pull_request
# Rows cycle through the chains (0, 1, ..., num_chains - 1, 0, 1, ...). The
# number of draws per chain is taken from the frame itself, so the labels
# always have the right length.
draws_per_chain, leftover_rows = divmod(len(fit_df), num_chains)
if leftover_rows:
    raise ValueError(
        f"{len(fit_df)} draws cannot be split evenly across {num_chains} chains"
    )
fit_df.insert(0, "chain__", np.tile(np.arange(num_chains), draws_per_chain))

fit_df.head()

## Check model

### Summary of the results

In [None]:
# Names of the model's constrained parameters; used below to select columns of fit_df
variables_to_track = list(posterior.constrained_param_names)

In [None]:
# overall summary
# Summary statistics pooled over all chains, one row per tracked parameter
fit_df.loc[:, variables_to_track].describe().transpose()

In [None]:
# summary by chain
# Same summary statistics, computed separately within each chain
fit_df.groupby(['chain__'])[variables_to_track].describe().transpose()

Posterior and chains plots

In [None]:
# Long format: one row per (draw, parameter) so seaborn can facet by parameter.
# Every column that is not a tracked parameter is kept as an identifier.
tracked = set(variables_to_track)
id_columns = [column for column in fit_df.columns if column not in tracked]
melted_df = pd.melt(fit_df, id_vars=id_columns, var_name='parameter_name', value_name='draws')

# Both grids share the same layout: one panel per parameter, coloured by chain.
# FacetGrid creates its own figure, so no separate plt.figure() call is needed.
facet_options = dict(
    col="parameter_name",
    col_wrap=2,
    sharex=False,
    sharey=False,
    aspect=2,
    hue='chain__',
)

# Posterior histograms with a KDE overlay
g = sns.FacetGrid(melted_df, **facet_options)
g.map_dataframe(sns.histplot, x="draws", kde=True)
g.add_legend()
# plt.savefig('parameters_posteriors.png', bbox_inches='tight')

# Trace plots: draw value against iteration number within each chain
g = sns.FacetGrid(melted_df, **facet_options)
g.map_dataframe(sns.lineplot, x=np.arange(0, num_samples), y="draws")
g.add_legend()
# plt.savefig('chains.png', bbox_inches='tight')

plt.show()

### Diagnostics

In [None]:
# adapted from https://github.com/mdnunez/pyhddmjags/tree/master
def diagnostic(insamples):
    """
    Compute convergence diagnostics for posterior samples.

    For every monitored variable this returns two versions of Rhat (measure of
    convergence, less is better with an approximate 1.10 cutoff) and Neff
    (number of effective samples), together with the posterior mean and
    standard deviation. Note that 'rhat' (split chains) is more diagnostic
    than 'oldrhat' according to Gelman et al. (2014).

    Reference for preferred Rhat calculation (split chains) and number of effective sample calculation:
        Gelman, A., Carlin, J. B., Stern, H. S., Dunson, D. B., Vehtari, A. & Rubin, D. B. (2014).
        Bayesian data analysis (Third Edition). CRC Press:
        Boca Raton, FL

    Reference for original Rhat calculation:
        Gelman, A., Carlin, J., Stern, H., & Rubin D., (2004).
        Bayesian Data Analysis (Second Edition). Chapman & Hall/CRC:
        Boca Raton, FL.


    Parameters
    ----------
    insamples: dict
        Sampled values of monitored variables as a dictionary where keys
        are variable names and values are numpy arrays with shape:
        (dim_1, ..., dim_n, iterations, chains). dim_1, ..., dim_n describe the
        shape of the variable in the model. Keys starting with '_' are skipped.

    Returns
    -------
    dict:
        For each variable a dict with keys 'mean', 'std', 'oldrhat', 'rhat'
        and 'neff'. Also prints the maximum (split) Rhat and the minimum Neff
        across all variables.
    """

    result = {}  # Initialize dictionary
    # Per-key extremes; entries of skipped keys keep their initial values (0 / inf)
    maxrhatsold = np.zeros((len(insamples.keys())), dtype=float)
    maxrhatsnew = np.zeros((len(insamples.keys())), dtype=float)
    minneff = np.ones((len(insamples.keys())), dtype=float)*np.inf
    allkeys ={} # Maps position in the arrays above to the variable name
    keyindx = 0
    for key in insamples.keys():
        if key[0] != '_':
            result[key] = {}

            possamps = insamples[key]

            # Number of chains
            nchains = possamps.shape[-1]

            # Number of samples per chain
            nsamps = possamps.shape[-2]

            # Number of variables per key
            nvars = np.prod(possamps.shape[0:-2])

            # Pool the samples of all chains along the last axis
            allsamps = np.reshape(possamps, possamps.shape[:-2] + (nchains * nsamps,))

            # Reshape data to produce R_hatnew: every chain is split into a
            # first and a second half, giving 2 * nchains half-length chains.
            # NOTE(review): with an odd nsamps the second half is one draw
            # longer than the first and the assignment below fails with a shape
            # mismatch - confirm callers always pass an even number of draws.
            possampsnew = np.empty(possamps.shape[:-2] + (int(nsamps/2), nchains * 2,))
            newc=0
            for c in range(nchains):
                possampsnew[...,newc] = np.take(np.take(possamps,np.arange(0,int(nsamps/2)),axis=-2),c,axis=-1)
                possampsnew[...,newc+1] = np.take(np.take(possamps,np.arange(int(nsamps/2),nsamps),axis=-2),c,axis=-1)
                newc += 2

            # Index of variables
            varindx = np.arange(nvars).reshape(possamps.shape[0:-2])

            # Flatten the parameter dimensions: (variable, sample, chain)
            alldata = np.reshape(possamps, (nvars, nsamps, nchains))

            # Mean of each chain for rhat
            chainmeans = np.mean(possamps, axis=-2)
            # Mean of each chain for rhatnew
            chainmeansnew = np.mean(possampsnew, axis=-2)
            # Global mean of each parameter for rhat
            globalmean = np.mean(chainmeans, axis=-1)
            globalmeannew = np.mean(chainmeansnew, axis=-1)
            result[key]['mean'] = globalmean
            # Posterior std is computed over the pooled samples of all chains
            result[key]['std'] = np.std(allsamps, axis=-1)
            globalmeanext = np.expand_dims(
                globalmean, axis=-1)  # Expand the last dimension
            globalmeanext = np.repeat(
                globalmeanext, nchains, axis=-1)  # For differencing
            globalmeanextnew = np.expand_dims(
                globalmeannew, axis=-1)  # Expand the last dimension
            globalmeanextnew = np.repeat(
                globalmeanextnew, nchains*2, axis=-1)  # For differencing
            # Between-chain variance for rhat
            between = np.sum(np.square(chainmeans - globalmeanext),
                             axis=-1) * nsamps / (nchains - 1.)
            # Mean of the variances of each chain for rhat
            within = np.mean(np.var(possamps, axis=-2), axis=-1)
            # Total estimated variance for rhat
            totalestvar = (1. - (1. / nsamps)) * \
                          within + (1. / nsamps) * between
            # Rhat (original Gelman-Rubin statistic)
            temprhat = np.sqrt(totalestvar / within)
            maxrhatsold[keyindx] = np.nanmax(temprhat) # Ignore NANs
            allkeys[keyindx] = key
            result[key]['oldrhat'] = temprhat
            # Between-chain variance for rhatnew
            betweennew = np.sum(np.square(chainmeansnew - globalmeanextnew),
                                axis=-1) * (nsamps/2) / ((nchains*2) - 1.)
            # Mean of the variances of each chain for rhatnew
            withinnew = np.mean(np.var(possampsnew, axis=-2), axis=-1)
            # Total estimated variance
            totalestvarnew = (1. - (1. / (nsamps/2))) * \
                             withinnew + (1. / (nsamps/2)) * betweennew
            # Rhatnew (Gelman-Rubin statistic from Gelman et al., 2013)
            temprhatnew = np.sqrt(totalestvarnew / withinnew)
            maxrhatsnew[keyindx] = np.nanmax(temprhatnew) # Ignore NANs
            result[key]['rhat'] = temprhatnew
            # Number of effective samples from Gelman et al. (2013) 286-288
            neff = np.empty(possamps.shape[0:-2])
            for v in range(0, nvars):
                whereis = np.where(varindx == v)
                rho_hat = []
                rho_hat_even = 0
                rho_hat_odd = 0
                t = 2
                # Autocorrelations are estimated in pairs (odd lag t-1, even
                # lag t); the loop stops once the latest pair sums to a
                # negative value. That last pair has already been appended to
                # rho_hat and is therefore part of the sum below.
                while (t < nsamps - 2) & (float(rho_hat_even) + float(rho_hat_odd) >= 0):
                    # above equation (11.7) in Gelman et al., 2013
                    variogram_odd = np.mean(np.mean(np.power(alldata[v,(t-1):nsamps,:] - alldata[v,0:(nsamps-t+1),:],2),axis=0))

                    # Equation (11.7) in Gelman et al., 2013
                    rho_hat_odd = 1 - np.divide(variogram_odd, 2*totalestvar[whereis]).item()
                    rho_hat.append(rho_hat_odd)

                    # above equation (11.7) in Gelman et al., 2013
                    variogram_even = np.mean(np.mean(np.power(alldata[v,t:nsamps,:] - alldata[v,0:(nsamps-t),:],2),axis=0))

                    # Equation (11.7) in Gelman et al., 2013
                    rho_hat_even = 1 - np.divide(variogram_even, 2*totalestvar[whereis]).item()
                    rho_hat.append(rho_hat_even)

                    t += 2
                rho_hat = np.asarray(rho_hat)
                # Equation (11.8) in Gelman et al., 2013
                neff[whereis] = np.divide(nchains*nsamps, 1 + 2*np.sum(rho_hat))
            result[key]['neff'] = np.round(neff)
            minneff[keyindx] = np.nanmin(np.round(neff))
            keyindx += 1

            # Geweke statistic?
    # print("Maximum old Rhat was %3.2f for variable %s" % (np.max(maxrhatsold),allkeys[np.argmax(maxrhatsold)]))
    # Report the worst split-Rhat and the smallest Neff, with the variable and index where they occur.
    # NOTE(review): if no key was processed, allkeys is empty and the lookups below raise KeyError.
    maxrhatkey = allkeys[np.argmax(maxrhatsnew)]
    maxrhatindx = np.unravel_index( np.argmax(result[maxrhatkey]['rhat']) , result[maxrhatkey]['rhat'].shape)
    print("Maximum Rhat was %3.2f for variable %s at index %s" % (np.max(maxrhatsnew), maxrhatkey, maxrhatindx))
    minneffkey = allkeys[np.argmin(minneff)]
    minneffindx = np.unravel_index( np.argmin(result[minneffkey]['neff']) , result[minneffkey]['neff'].shape)
    print("Minimum number of effective samples was %d for variable %s at index %s" % (np.min(minneff), minneffkey, minneffindx))
    return result

In [None]:
def models_diagnostics_dict_to_df(models_diagnostics):
    """Flatten a diagnostics dictionary into a single DataFrame.

    Parameters
    ----------
    models_diagnostics : dict
        Maps a variable name to a dict of statistics (e.g. 'mean', 'std',
        'rhat', ...). Every statistic is an array with the shape of the
        variable; the 'mean' entry is required and defines that shape.

    Returns
    -------
    pandas.DataFrame
        One row per statistic and one column per scalar element. Columns are
        named ``<variable>.<i>[.<j>...]`` with 1-based indices; a scalar
        (0-dimensional) variable gets the plain column name ``<variable>``.
        An empty input gives an empty DataFrame.
    """
    frames = []
    for key, main_data in models_diagnostics.items():
        stat_names = list(main_data.keys())
        columns = {}
        # np.ndindex walks every element in row-major order for any number of
        # dimensions; for a 0-d variable it yields the single empty index ().
        for idx in np.ndindex(np.shape(main_data['mean'])):
            label = ".".join([str(key)] + [str(i + 1) for i in idx])
            columns[label] = [np.asarray(main_data[name])[idx] for name in stat_names]
        frames.append(pd.DataFrame(columns, index=stat_names))

    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, axis=1)

In [None]:
def flip_stan_out(fit, parameters=None, n_chains=None):
    """Split the pooled draws of each parameter into (samples, chains) axes.

    Parameters
    ----------
    fit : mapping
        Object indexable by parameter name that returns an array whose last
        axis holds the pooled draws of all chains (e.g. a pystan fit).
    parameters : iterable of str, optional
        Names of the parameters to extract. With ``None`` nothing is
        extracted and an empty dict is returned.
    n_chains : int, optional
        Number of chains the draws come from. Defaults to the notebook-level
        ``num_chains`` setting.

    Returns
    -------
    dict
        Maps each parameter name to an array of shape
        ``(*param_dims, n_samples, n_chains)``.

    Raises
    ------
    ValueError
        If the number of pooled draws is not a multiple of ``n_chains``.
    """
    results = {}
    if parameters is None:
        return results

    if n_chains is None:
        n_chains = num_chains  # sampling setting defined at notebook level

    for parameter in parameters:
        print(f"Processing: {parameter} ")
        samples = np.asarray(fit[parameter])

        # Take the draws per chain from the data rather than from a global,
        # so the reshape stays valid whatever number of draws was saved.
        draws_per_chain, leftover = divmod(samples.shape[-1], n_chains)
        if leftover:
            raise ValueError(
                f"{parameter}: {samples.shape[-1]} draws cannot be split "
                f"evenly across {n_chains} chains"
            )

        # reshape from (n_params, n_samples*n_chains) to (n_params, n_samples, n_chains)
        # NOTE(review): assumes the chain index varies fastest along the
        # pooled axis - confirm against the ordering returned by fit[parameter].
        results[parameter] = samples.reshape(
            samples.shape[:-1] + (draws_per_chain, n_chains),
            order='C'
        )

    return results

In [None]:
# creates a dict [parameter_name] : array of shape (*n_params, n_samples, n_chains)
parameters = fit.param_names
extracted_samples_dict = flip_stan_out(fit, parameters)

Show model diagnostics

In [None]:
# Compute Rhat / Neff / mean / std per parameter (also prints the worst Rhat and Neff)
models_diagnostics = diagnostic(extracted_samples_dict)
# Flatten to a table and show it with one row per scalar parameter
models_diagnostics_df = models_diagnostics_dict_to_df(models_diagnostics)
display(models_diagnostics_df.T)

# save results
# models_diagnostics_df.T.to_csv('models_diagnostics.csv')

### Posterior distribution plots

In [None]:
# adapted from https://github.com/mdnunez/pyhddmjags/tree/master
def jellyfish(possamps):  # jellyfish plots
    """Draw a jellyfish plot of posterior samples on the current axes.

    Each scalar variable gets its own row: its posterior density, mirrored
    around a horizontal baseline, drawn once over the 99% credible interval
    (0.5th to 99.5th percentile) and once over the 95% credible interval
    (2.5th to 97.5th percentile). The mode, median and mean of the posterior
    are marked on top.

    Parameters
    ----------
    possamps : ndarray of posterior chains where the last dimension is
    the number of chains, the second to last dimension is the number of samples
    in each chain, all other dimensions describe the shape of the parameter
    """
    # Leading axes index the parameter; the trailing two are (samples, chains)
    param_shape = possamps.shape[0:-2]
    n_chains = possamps.shape[-1]
    draws_per_chain = possamps.shape[-2]
    n_variables = np.prod(param_shape)

    # One row per scalar variable with the draws of all chains pooled together
    pooled = np.reshape(possamps, (n_variables, n_chains * draws_per_chain))

    # Colours, and (colour, line width) of the 99% outline followed by the 95% outline
    teal = np.array([0, .7, .7])
    blue = np.array([0, 0, 1])
    orange = np.array([1, .3, 0])
    outline_styles = [(teal, 2), (blue, 5)]

    # Tick 0 stays unlabelled; variable v is drawn at height v + 1
    tick_labels = ['']

    for v in range(0, n_variables):
        # Label the row with the variable's (0-based) index along each parameter axis
        position = np.unravel_index(v, param_shape)
        tick_labels.append(''.join('_%i' % axis_index for axis_index in position))

        draws = pooled[v, :]
        density = stats.gaussian_kde(draws)
        quantiles = stats.scoreatpercentile(draws, (.5, 2.5, 97.5, 99.5))

        for level, (colour, width) in enumerate(outline_styles):
            # level 0 spans the outer percentiles, level 1 the inner ones
            grid = np.linspace(quantiles[level], quantiles[-1 - level], 100)
            heights = density(grid)

            # Scale the density so that its peak reaches a quarter of a row
            scaled = .25 * heights / np.max(heights)
            top = scaled + v + 1
            bottom = -scaled + v + 1

            outline = plt.plot(grid, top, grid, bottom)
            plt.setp(outline, color=colour, linewidth=width)

            if level != 1:
                continue

            # Mode: vertical bar where the density peaks
            peak = np.argmax(heights)
            mode_bar = plt.plot(np.array([1., 1.]) * grid[peak],
                                np.array([bottom[peak], top[peak]]))
            plt.setp(mode_bar, linewidth=3, color=orange)

            # Median: black circle on the baseline
            median_marker = plt.plot(np.median(draws), v + 1, 'ko')
            plt.setp(median_marker, markersize=10, color=[0., 0., 0.])

            # Mean: teal star on the baseline
            mean_marker = plt.plot(np.mean(draws), v + 1, '*')
            plt.setp(mean_marker, markersize=10, color=teal)

    # One labelled tick per variable
    plt.setp(plt.gca(), yticklabels=tick_labels, yticks=np.arange(0, n_variables + 1))

In [None]:
#Posterior distributions
for parameter in fit.param_names:
    plt.figure()
    jellyfish(extracted_samples_dict[parameter])
    plt.title(f'Posterior distributions of the {parameter}')
    plt.savefig(f'distributions_{parameter}.png', bbox_inches='tight')
    plt.show()