# HDDM models informed with EEG and pre-trial accuracy

Imports

In [None]:
import cmdstanpy
# Run cmdstanpy's CmdStan installer with its default arguments.
cmdstanpy.install_cmdstan()

In [None]:
from cmdstanpy import CmdStanModel
import os
import numpy as np
import pandas as pd
from datetime import datetime
import pickle
import json
import time
from contextlib import redirect_stdout
import seaborn as sns
# Optionally set the path to a custom CmdStan installation
# set_cmdstan_path('/stan/math_HOW-TO-USE/cmdstan-ddm-7pm')
# cmdstan_path()

# set Stan globals
# os.environ['STAN_NUM_THREADS'] = "12"

## Define the model

In [None]:
# Identifier of the model variant; it is interpolated into the Stan file name
# here and into the output directories further down.
name = 'drift_boundary_pre2_ncond_tbb'
# File name of the Stan program for this variant.
model_name = 'wiener_{}_model.stan'.format(name)

print('Processing model: {}'.format(model_name))

## Compile the model

In [None]:
def compile_model(stan_file, max_retries=5, retry_delay=5):
    """Compile a Stan program with threading enabled, retrying on failure.

    Every attempt forces a fresh compilation (``force_compile=True``).

    Args:
        stan_file: Path of the ``.stan`` file handed to ``CmdStanModel``.
        max_retries: Maximum number of compilation attempts.
        retry_delay: Seconds to sleep after a failed attempt before retrying.

    Returns:
        A ``(model, compiled)`` tuple: the compiled ``CmdStanModel`` and
        ``True`` on success, or ``(None, False)`` when every attempt failed
        or when ``max_retries`` allows no attempt at all.
    """
    failed_attempts = 0

    while failed_attempts < max_retries:
        try:
            model = CmdStanModel(
                stan_file=stan_file,
                cpp_options={'STAN_THREADS': True},
                force_compile=True,
            )
        except Exception as e:
            print(f"Error compiling model: {e}")
            failed_attempts += 1
            if failed_attempts >= max_retries:
                print("Max retries reached. Exiting.")
                return None, False
            print(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            # First successful compilation wins.
            return model, True

    # Reached only when the loop body never ran (non-positive max_retries).
    return None, False

In [None]:
# Relative path of the Stan program for the selected model.
stan_file = os.path.join('../models/TBB_models', model_name)
# compile_model returns (model, compiled); the model is None when compilation failed.
hddm_model, compiled = compile_model(stan_file)
# Bare expression: shows the compilation status as the cell output.
compiled

In [None]:
# Show the information reported by exe_info() for the compiled model.
# NOTE(review): hddm_model is None when compile_model failed, in which case
# this call raises AttributeError — confirm whether a guard is wanted here.
hddm_model.exe_info()

## Define data file

In [None]:
# JSON file with the model input data; it is parsed below and also passed
# unchanged to the sampler.
data_file = os.path.join('../data/', 'stahl_acc_data_standarized.json')

### Read the data

In [None]:
# Parse the JSON input so it can be inspected from Python.
with open(data_file, 'r') as fh:
    data = json.load(fh)

In [None]:
# One row per entry of the input data: the participant index and the absolute
# value of the corresponding ``y`` entry (stored as ``rt``).
data_df = pd.DataFrame({
    'participant_index': data['participant'],
    'rt': np.abs(np.asarray(data['y'])),
})

In [None]:
# Bare expression: shows the table as the cell output.
data_df

In [None]:
# Descriptive statistics of ``rt`` per participant, shown without row truncation.
with pd.option_context('display.max_rows', None):
    display(data_df.groupby('participant_index').describe())

## Fit the model

Fit parameters

In [None]:
# Sampler settings; fit_model() reads these module-level names.
num_chains = 1       # chains (also used as parallel_chains)
warmup = 10          # iter_warmup
num_samples = 10     # iter_sampling
thin = 5             # thinning interval
adapt_delta = 0.99   # adapt_delta passed to the sampler
random_state = 42    # sampler seed

Define initial values

In [None]:
n_participants = data['n_participants']

# Smallest RT of each participant, in sorted order of participant id.
# ``rt`` already holds absolute values (see the construction of data_df), so
# no further abs() is needed. Slots stay at 0 if there are fewer distinct ids
# than n_participants.
min_rt = np.zeros(n_participants)
for idx, participant_idx in enumerate(np.unique(data['participant'])):
    is_participant = data_df['participant_index'] == participant_idx
    min_rt[idx] = np.min(data_df.loc[is_participant, 'rt'].to_numpy())

# One dict of initial values per chain, passed to the sampler as ``inits``.
# NOTE(review): the draws use NumPy's global RNG, which is not seeded in the
# visible code, so the initial values differ between runs even though
# ``random_state`` is passed to the sampler — confirm that this is intended.
initials = []
for _ in range(num_chains):
    chain_init = {
        'ter_sd': np.random.uniform(.01, .2),
        'alpha_sd': np.random.uniform(.01, 1.),
        'alpha_cond_sd': np.random.uniform(.01, 1.), # <- was 0.5
        'delta_sd': np.random.uniform(.1, 3.),
        'delta_cond_sd': np.random.uniform(.1, 3.),

        'alpha_ne_sd': np.random.uniform(.01, .2), # <- works quite nice with .01, .2, works with .01, 1
        'delta_ne_sd': np.random.uniform(.001, .2), # 0.2 ###########

        'alpha_ne_pre_acc_sd': np.random.uniform(.01, .2), # <- works quite nice with .01, .2, works with .01, 1
        'delta_ne_pre_acc_sd': np.random.uniform(.001, .2), # 0.2 ###########

        'ter': np.random.uniform(0.05, .4),
        'alpha': np.random.uniform(1, 2), #0.2 ## <- does not work with < 1
        'alpha_cond': np.random.uniform(-.5, .5), # <- was -.1, .1 and works a little bit better
        'delta': np.random.uniform(-4., 4.),
        'delta_cond': np.random.uniform(-4., 4.),

        'alpha_ne': np.random.uniform(-.05, .05), # <- does not work with -0.1, 0.1
        'alpha_pre_acc': np.random.uniform(-0.1, .1),
        'alpha_ne_pre_acc': np.random.uniform(-.05, .05), # does not work with -0.1, 0.1
        'alpha_ne_cond': np.random.uniform(-.05, .05), # <- does not work with -0.1, 0.1
        'alpha_pre_acc_cond': np.random.uniform(-0.1, .1),
        'alpha_ne_pre_acc_cond': np.random.uniform(-.05, .05), # does not work with -0.1, 0.1

        'delta_ne': np.random.uniform(-.1, .1),
        'delta_pre_acc': np.random.uniform(-.5, .5),
        'delta_ne_pre_acc': np.random.uniform(-.1, .1),
        'delta_ne_cond': np.random.uniform(-.1, .1),
        'delta_pre_acc_cond': np.random.uniform(-.5, .5),
        'delta_ne_pre_acc_cond': np.random.uniform(-.1, .1),

        # Each participant's value is drawn below half of that participant's
        # smallest RT. Drawn in one vectorised call (array-valued upper bound)
        # instead of drawing from U(0.05, 0.4) and overwriting every element.
        'participants_ter': np.random.uniform(0., min_rt / 2),
        'participants_alpha': np.random.uniform(1, 2., size=n_participants), ## <- does not work with <1
        'participants_alpha_cond': np.random.uniform(-0.5, .5, size=n_participants), # <- was -.1, .1 and works a little bit better
        'participants_delta': np.random.uniform(-4., 4., size=n_participants),
        'participants_delta_cond': np.random.uniform(-4., 4., size=n_participants),

        'participants_alpha_ne': np.random.uniform(-.05, .05, size=n_participants),
        'participants_delta_ne': np.random.uniform(-.1, .1, size=n_participants), #########

        'participants_alpha_ne_pre_acc': np.random.uniform(-.05, .05, size=n_participants),
        'participants_delta_ne_pre_acc': np.random.uniform(-.05, .05, size=n_participants), #########
    }

    initials.append(chain_init)

print(min_rt)
# n_participants = data['n_participants']

# min_rt = np.zeros(n_participants)
# for idx, participant_idx in enumerate(np.unique(data['participant'])):
#     participant_rts = data_df[data_df['participant_index'] == participant_idx]['rt'].to_numpy()
#     min_rt[idx] = np.min(abs(participant_rts))

# initials = []
# for c in range(0, num_chains):
#     chain_init = {               
#         'ter_sd': np.random.uniform(.01, .2),
#         'varsigma_sd': np.random.uniform(.01, 1.),
#         'varsigma_cond_sd': np.random.uniform(.01, 1.), # <- was 0.5
#         'delta_sd': np.random.uniform(.1, 3.),
#         'delta_cond_sd': np.random.uniform(.1, 3.),
        
#         'varsigma_ne_sd': np.random.uniform(.01, .2), # <- works quite nice with .01, .2, works with .01, 1 
#         'delta_ne_sd': np.random.uniform(.001, .2), # 0.2 ###########

#         'varsigma_ne_pre_acc_sd': np.random.uniform(.01, .2), # <- works quite nice with .01, .2, works with .01, 1 
#         'delta_ne_pre_acc_sd': np.random.uniform(.001, .2), # 0.2 ###########

#         'ter': np.random.uniform(0.05, .4),
#         'varsigma': np.random.uniform(1, 2), #0.2 ## <- does not work with < 1
#         'varsigma_cond': np.random.uniform(-.5, .5), # <- was -.1, .1 and works a little bit better
#         'delta': np.random.uniform(-4., 4.),
#         'delta_cond': np.random.uniform(-4., 4.),

#         'varsigma_ne': np.random.uniform(-.05, .05), # <- does not work with -0.1, 0.1
#         'varsigma_pre_acc': np.random.uniform(-0.1, .1), 
#         'varsigma_ne_pre_acc': np.random.uniform(-.05, .05), # does not work with -0.1, 0.1
#         'varsigma_ne_cond': np.random.uniform(-.05, .05), # <- does not work with -0.1, 0.1
#         'varsigma_pre_acc_cond': np.random.uniform(-0.1, .1), 
#         'varsigma_ne_pre_acc_cond': np.random.uniform(-.05, .05), # does not work with -0.1, 0.1

#         'delta_ne': np.random.uniform(-.1, .1),
#         'delta_pre_acc': np.random.uniform(-.5, .5),
#         'delta_ne_pre_acc': np.random.uniform(-.1, .1),
#         'delta_ne_cond': np.random.uniform(-.1, .1),
#         'delta_pre_acc_cond': np.random.uniform(-.5, .5),
#         'delta_ne_pre_acc_cond': np.random.uniform(-.1, .1),
        
#         'participants_ter': np.random.uniform(0.05, .4, size=n_participants),
#         'participants_varsigma': np.random.uniform(1, 2., size=n_participants), ## <- does not work with <1
#         'participants_varsigma_cond': np.random.uniform(-0.5, .5, size=n_participants), # <- was -.1, .1 and works a little bit better 
#         'participants_delta': np.random.uniform(-4., 4., size=n_participants),
#         'participants_delta_cond': np.random.uniform(-4., 4., size=n_participants),
        
#         'participants_varsigma_ne': np.random.uniform(-.05, .05, size=n_participants),
#         'participants_delta_ne': np.random.uniform(-.1, .1, size=n_participants), #########

#         'participants_varsigma_ne_pre_acc': np.random.uniform(-.05, .05, size=n_participants),
#         'participants_delta_ne_pre_acc': np.random.uniform(-.05, .05, size=n_participants), #########
#     }
#     for p in range(0, n_participants):
#         chain_init['participants_ter'][p] = np.random.uniform(0., min_rt[p]/2)

#     initials.append(chain_init)

# print(min_rt)

Perform fit

In [None]:
def fit_model(model, data_file, name, max_retries=5, retry_delay=5):
    """Run ``model.sample`` with the notebook's settings, retrying on failure.

    The sampler settings (``num_chains``, ``random_state``, ``thin``,
    ``adapt_delta``, ``initials``, ``warmup``, ``num_samples``) are read from
    module-level names rather than passed in.

    Args:
        model: Object exposing ``sample(...)`` (the compiled model).
        data_file: Passed to ``sample`` as ``data``.
        name: Model identifier interpolated into the output directory.
        max_retries: Maximum number of sampling attempts.
        retry_delay: Seconds to sleep after a failed attempt before retrying.

    Returns:
        Whatever ``model.sample`` returned, or ``None`` when every attempt
        failed or when ``max_retries`` allows no attempt at all.
    """
    failed_attempts = 0

    while failed_attempts < max_retries:
        try:
            # The log file is opened in append mode (created if missing), but
            # nothing in this block writes to it.
            with open('jupyter_logs.txt', 'a'):
                result = model.sample(
                    data=data_file,
                    chains=num_chains,
                    seed=random_state,
                    thin=thin,
                    adapt_delta=adapt_delta,
                    inits=initials,
                    iter_warmup=warmup,
                    iter_sampling=num_samples,
                    parallel_chains=num_chains,
                    threads_per_chain=12,
                    max_treedepth=12,
                    show_progress=True,
                    show_console=True,
                    output_dir=f'../../plgrid_results/ncond_models/stahl/acc/stahl_acc_ncond_{name}_1/',
                )
        except Exception as e:
            print(f"Error sampling model: {e}")
            failed_attempts += 1
            if failed_attempts >= max_retries:
                print("Max retries reached. Exiting.")
                return None
            print(f"Retrying in {retry_delay} seconds...")
            time.sleep(retry_delay)
        else:
            # First successful run wins.
            return result

    # Reached only when the loop body never ran (non-positive max_retries).
    return None

In [None]:
# Sample only when compilation succeeded; this cell does not assign ``fit``
# otherwise. fit_model() itself returns None when every attempt failed.
if compiled:
    fit = fit_model(hddm_model, data_file, name)
    

# with open('jupyter_logs.txt', 'a') as f:
#     with redirect_stdout(f):
#         start = time.time()
#         fit = hddm_model.sample(
#             data=data_file,
#             chains=num_chains, 
#             seed=random_state,
#             thin=thin,
#             adapt_delta=adapt_delta,
#             inits=initials, 
#             iter_warmup=warmup, 
#             iter_sampling=num_samples,
#             parallel_chains=num_chains,
#             threads_per_chain= 12,
#             max_treedepth=12,
#             show_progress=True,
#             show_console=True,
#             output_dir=f'../../plgrid_results/ncond_models/stahl/acc/stahl_acc_ncond_{name}_1/'
#         )
#         end = time.time()

# print(f'Fitting took: {end - start}')

In [None]:
# Print the report returned by fit.diagnose().
# NOTE(review): fit is None when fit_model() exhausted its retries, in which
# case this raises AttributeError — confirm whether a guard is wanted.
print(fit.diagnose())

In [None]:
# Summary table of the fit; written to CSV in the next cell.
summary_df = fit.summary()

In [None]:
# Write the summary table to the working directory.
summary_df.to_csv('test_priors2_summary.csv')

In [None]:
# 7 - changed initials (as boundary(and main effects - the same as boundary, no sd priors || 1.23
# 6 - changed initials and main effects, and sd priors ||1.6 for delta_ne_sd
# 5 - changed initials and main effects, no sd priors || 1.10 for delta_ne_sd
# 4 - changed initials and main effects, and sd priors
# 3 - ?
# 2 - main effects, and sd priors
# 1 - changed main effects

In [None]:
# Draws as a DataFrame, then a trace of ``delta_ne_pre_acc_sd`` against the
# iteration column, with one line per chain.
fit_df = fit.draws_pd()

sns.lineplot(data=fit_df, x='iter__', y='delta_ne_pre_acc_sd', hue='chain__')

Save the MCMC fit object

In [None]:
# Save the sampler's CSV files into a directory named after the model and the
# warmup / sample / thin settings.
# NOTE(review): this path ('../plgrid_results/ncond_models_stahl/...') differs
# from the output_dir used in fit_model ('../../plgrid_results/ncond_models/stahl/...')
# — confirm which location is intended.
fit.save_csvfiles(dir=f'../plgrid_results/ncond_models_stahl/acc/stahl_acc_{name}_warmup-{warmup}_samples-{num_samples}_thin-{thin}-6/')