In this notebook we are going to fit 4 different hierarchical DDMs to Alie's data that allow CPE to linearly modulate 1/4 of the main DDM parameters: [drift rate, bias, non-decision time, boundary separation]. 

We will perform model comparison using ELPD-LOO (posterior predictive checks using leave-one-out cross validation), to determine which model fits best. 

Then we will plot the posteriors to see how, exactly, CPE affects the parameter in the winning model. 

In [None]:
import hssm 
hssm.set_floatX("float32")

# Basics
import os
import sys
import time
from matplotlib import pyplot as plt
import arviz as az  # Visualization
import pytensor  # Graph-based tensor library
import hssm
import pandas as pd
import numpy as np
import scipy as sp
from tqdm import tqdm
import seaborn as sns
from matplotlib.backends.backend_pdf import PdfPages

# import ssms.basic_simulators # Model simulators
import hddm_wfpt
import bambi as bmb

# Setting float precision in pytensor
pytensor.config.floatX = "float32"

# from jax.config import config

import numpyro

numpyro.set_host_device_count(4)


import warnings
warnings.filterwarnings('ignore')

In [None]:
base_dir = '/sc/arion'
# change to your directory
save_dir_plots = f'{base_dir}/projects/guLab/Salman/MemoryBanditManuscript/Figures/Exp1'

Load the data 

In [None]:
alie_df = pd.read_csv('/sc/arion/projects/guLab/Salman/Prolific/rt_choice_df_02272024.csv')

alie_df = alie_df[['subj_id', 'choice_t1', 'RT_t1', 'cpe_t']]

alie_df.rename(columns={'choice_t1':'response', 
                       'RT_t1':'rt',
                       'subj_id':'subj_idx'}, inplace=True)

Construct the models

In [None]:
hier_v_cpe_randint = hssm.HSSM(
    model="ddm",
    p_outlier = None,
    lapse = None, 
    data=alie_df[['rt', 'response', 'subj_idx', 'cpe_t']] ,
    include=[
        {
            "name": "v",
                "formula": "v ~ 1 + (1|subj_idx) + cpe_t",
                "link": "identity",
            },
    ],
)

hier_t_cpe_randint = hssm.HSSM(
    model="ddm",
    p_outlier = None,
    lapse = None, 
    data=alie_df[['rt', 'response', 'subj_idx', 'cpe_t']] ,
    include=[
        {
            "name": "t",
            "formula": "t ~ 1 + (1|subj_idx) + cpe_t",
            "link": "identity",
        },
    ],
)

hier_z_cpe_randint = hssm.HSSM(
    model="ddm",
    p_outlier = None,
    lapse = None, 
    data=alie_df[['rt', 'response', 'subj_idx', 'cpe_t']] ,
    include=[
        {
            "name": "z",
            "formula": "z ~ 1 + (1|subj_idx) + cpe_t",
            "link": "identity",
        },
    ],
)

hier_a_cpe_randint = hssm.HSSM(
    model="ddm",
    p_outlier = None,
    lapse = None, 
    data=alie_df[['rt', 'response', 'subj_idx', 'cpe_t']] ,
    include=[
        {
            "name": "a",
            "formula": "a ~ 1 + (1|subj_idx) + cpe_t",
            "link": "identity",
        },
    ],
)

model_types = {'hier_v_cpe_randint':hier_v_cpe_randint,
               'hier_t_cpe_randint':hier_t_cpe_randint, 
               'hier_z_cpe_randint':hier_z_cpe_randint,  
               'hier_a_cpe_randint':hier_a_cpe_randint}

Plot the model structure 

In [None]:
hier_v_cpe_randint.model.graph()

Run the models (will take time and memory)

In [None]:
if __name__ == "__main__":
    model_res = {f'{x}':[] for x in model_types.keys()}

    # change to your directory
    output_dir = '/sc/arion/projects/guLab/Salman/Prolific'

    for model_key, model in model_types.items(): 
        idata = model.sample(
            sampler='nuts_numpyro',
            chains = 4,
            cores = 4,
            draws = 5000,
            tune = 5000,
            idata_kwargs=dict(log_likelihood=True))

        # Save them here: 
        az.to_netcdf(idata, f"{output_dir}/{model_key}_model")

        model_res[model_key] = idata

    df_comp_loo = az.compare(model_res, ic='loo')

Compare all models

In [None]:
az.plot_compare(df_comp_loo)


Compare just the two best models

In [None]:
dict_you_want = {key: model_res[key] for key in ['hier_v_cpe_randint', 'hier_z_cpe_randint']}
df_comp_loo = az.compare(dict_you_want, ic='loo')
az.plot_compare(df_comp_loo)


Plot the posteriors for the winning model 

In [None]:
az.plot_trace(model_res['hier_v_cpe_randint'])