In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from scipy.stats import pearsonr
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

from numba import jit, float64, int64

In [2]:
@jit(nopython=True)
def f(x, t, z, tau, delta, sigma=1):
    """Gaussian transition density of a drifted Wiener process.

    Density at position ``x`` and time ``t`` for a process that was at ``z`` at
    time ``tau``, i.e. a normal density with mean ``z + delta*(t - tau)`` and
    variance ``sigma**2 * (t - tau)``. Requires ``t > tau``.
    """
    elapsed = t - tau
    # Distance of x from the drifted mean z + delta*elapsed.
    deviation = x - z - delta * elapsed
    normaliser = 1/np.sqrt(2 * np.pi * sigma**2 * elapsed)
    exponent = -deviation**2 / (2 * sigma**2 * elapsed)
    return normaliser * np.exp(exponent)

@jit(nopython=True)
def psi(threshold, lamda, t, z, tau, delta, sigma=1):
    """Integral-equation kernel for the exponentially collapsing boundary.

    With ``a(t) = threshold * exp(-lamda * t)`` this returns
    ``0.5 * f(a(t), t | z, tau) * [a'(t) - delta - (a(t) - z - delta*(t - tau)) / (t - tau)]``.
    A negative ``threshold`` gives the mirrored (lower) boundary. ``sigma`` only
    enters through ``f``. Requires ``t > tau``.
    """
    decay = np.exp(-lamda*t)
    boundary = threshold * decay          # a(t)
    slope = -lamda * threshold * decay    # a'(t)
    elapsed = t - tau
    half_density = 0.5*f(boundary, t, z, tau, delta, sigma)
    bracket = slope - delta - (boundary - z - delta * elapsed)/elapsed
    return half_density * bracket

@jit(nopython=True)
def fpt(threshold, lamda, delta, z=0, sigma=1, dt=0.02, T_max=5):
    """First-passage-time densities for two symmetric, exponentially collapsing boundaries.

    The upper boundary is ``threshold * exp(-lamda * t)`` and the lower boundary
    is its mirror image ``-threshold * exp(-lamda * t)``. The process starts at
    ``z`` at time 0 with drift ``delta``; ``sigma`` is forwarded to ``psi``.

    Densities are built recursively on the grid ``t_n = n * dt``. Each new value
    is a "direct" kernel term from the start point plus a sum over every earlier
    grid point ``t_j`` (j < n), weighted by the densities already computed there.
    NOTE(review): this looks like the two-boundary integral-equation scheme of
    Buonocore, Giorno, Nobile & Ricciardi (1990); confirm against the reference
    the code was written from.

    Parameters
    ----------
    threshold : float
        Boundary height at t = 0.
    lamda : float
        Exponential collapse rate of both boundaries.
    delta : float
        Drift rate.
    z : float, optional
        Starting point of the process (default 0).
    sigma : float, optional
        Diffusion coefficient passed through to ``psi`` (default 1).
    dt : float, optional
        Grid step (default 0.02).
    T_max : float, optional
        Time horizon. The last grid time is ``(int(T_max/dt) + 1) * dt``.

    Returns
    -------
    gu, gl, T : 1-D float arrays of length ``int(T_max/dt) + 2``
        ``gu[n]`` / ``gl[n]`` are the densities of first hitting the upper /
        lower boundary at time ``T[n] = n * dt``. Index 0 (t = 0) stays 0.
    """
    gu = np.zeros((int(T_max/dt)+2,))
    gl = np.zeros((int(T_max/dt)+2,))
    T = np.zeros((int(T_max/dt)+2,))
    
    # First grid point: no earlier crossings yet, so only the direct term from
    # the start point (z at time 0) contributes.
    gu[1] = -2*psi(threshold, lamda, dt, z, 0, delta, sigma)
    gl[1] =  2*psi(-threshold, lamda, dt, z, 0, delta, sigma)
    T[1] = dt
    
    for n in range(2, int(T_max/dt)+2):
        # Direct terms from the start point; note the opposite signs for the
        # upper (-2*psi) and lower (+2*psi) boundaries.
        su = -2 * psi( threshold, lamda, n*dt, z, 0, delta, sigma)
        sl =  2 * psi(-threshold, lamda, n*dt, z, 0, delta, sigma)
        
        # Add the contribution of every earlier grid time t_j = j*dt at which
        # the path could have been on either boundary.
        for j in range(1, n):
            # Skip j when the boundary at t_j is exactly 0 (threshold == 0 or
            # exp underflow). NOTE(review): presumably to avoid evaluating the
            # kernel from a fully collapsed boundary; confirm intent.
            if threshold * np.exp(-lamda*j*dt) == 0:
                continue
            
            # Suffix naming: first sign = boundary evaluated at t_n
            # (p = upper, n = lower); second sign = boundary occupied at t_j.
            psi_n_j_pp = psi( threshold, lamda, n*dt,  threshold * np.exp(-lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_pn = psi( threshold, lamda, n*dt, -threshold * np.exp(-lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_np = psi(-threshold, lamda, n*dt,  threshold * np.exp(-lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_nn = psi(-threshold, lamda, n*dt, -threshold * np.exp(-lamda*j*dt), j*dt, delta, sigma)
            
            # Rectangle-rule quadrature of the integral term (step dt), again
            # with mirrored signs for the upper and lower equations.
            su +=  2 * dt * (gu[j] * psi_n_j_pp + gl[j] * psi_n_j_pn)
            sl += -2 * dt * (gu[j] * psi_n_j_np + gl[j] * psi_n_j_nn)
            
        gu[n] = su
        gl[n] = sl
        T[n] = (n*dt)
    return gu, gl, T

In [3]:
def CDDM_likelihood(prms, RT, Contrast, Z, dt=0.02, min_density=1e-14):
    """Negative log-likelihood of a collapsing-bound DDM with a log-normal NDT term.

    The boundaries are ``±b0 * exp(-lambda * t)``. First-passage densities are
    computed once for zero drift with ``fpt``. Each trial's drift is then
    re-introduced through the factor ``exp(delta * a(t) - 0.5 * delta**2 * t)``,
    where ``a(t)`` is the boundary that was hit. Observed non-decision times
    ``Z`` are scored under a log-normal whose mean is ``t0``.

    Parameters
    ----------
    prms : sequence of 6 floats
        ``[b0, lambda, delta0, delta1, t0, sigma]``:
        initial boundary height, boundary collapse rate, drift intercept, drift
        slope over the transformed contrast, non-decision time subtracted from
        each RT, and the log-normal spread of ``Z``.
    RT : array-like of float
        Signed response times. Values ``>= 0`` are scored at the upper
        boundary and negative values at the lower boundary.
    Contrast : array-like of float
        Per-trial contrast. The drift is
        ``delta0 - delta1 * log((c - 0.025) / (c + 0.025))``, so values must
        exceed 0.025 for a finite drift.
    Z : array-like of float
        Per-trial observed non-decision time (must be positive).
    dt : float, optional
        Grid step for ``fpt`` (default 0.02, the previously hard-coded value).
    min_density : float, optional
        Floor for tiny, negative or NaN densities (default 1e-14, as before).
        Trials whose ``|RT| - t0 <= 0`` also contribute ``-log(min_density)``.

    Returns
    -------
    float
        Summed negative log-likelihood over all trials; ``0.0`` for no trials.
    """
    b0, lamda, delta0, delta1, t0, sig = prms[:6]

    RT = np.asarray(RT, dtype=float)
    Contrast = np.asarray(Contrast, dtype=float)
    Z = np.asarray(Z, dtype=float)
    if RT.size == 0:
        return 0.0

    floor_nll = -np.log(min_density)

    # Zero-drift first-passage densities on a grid that covers the longest RT.
    gu, gl, TT = fpt(b0, lamda, 0, z=0, dt=dt, T_max=np.max(np.abs(RT)))

    decision_time = np.abs(RT) - t0
    valid = decision_time > 0
    # Trials whose RT does not exceed t0 (or are NaN) only receive the floor penalty.
    n_invalid = RT.size - np.count_nonzero(valid)

    dec_t = decision_time[valid]
    upper = RT[valid] >= 0
    c = Contrast[valid]
    z_obs = Z[valid]

    delta = delta0 - delta1 * np.log((c - 0.025) / (c + 0.025))

    # -log of a log-normal pdf with mu = log(t0) - sig**2/2, i.e. mean(Z) = t0.
    ndt_nll = (0.5 * (np.log(z_obs) - np.log(t0) + 0.5 * sig**2)**2 / sig**2
               + 0.5 * np.log(2 * np.pi * sig**2 * z_obs**2))

    # Position of the boundary that was hit, at the decision time.
    boundary = b0 * np.exp(-lamda * dec_t)
    boundary = np.where(upper, boundary, -boundary)

    # Linear interpolation of the zero-drift densities (TT is strictly
    # increasing and spans [0, >= max|RT|]).
    g0 = np.where(upper, np.interp(dec_t, TT, gu), np.interp(dec_t, TT, gl))
    density = np.exp(delta * boundary - 0.5 * delta**2 * dec_t) * g0

    # Floor tiny, negative or NaN densities (NaN > x is False, so NaN is floored too).
    dec_nll = -np.log(np.where(density > min_density, density, min_density))

    return float(ndt_nll.sum() + dec_nll.sum() + n_invalid * floor_nll)

In [4]:
# Load the Study 1 trials and keep only what the model needs:
#   * one condition ('speed'; use 'accuracy' instead to fit the other condition),
#   * event 3, whose Duration is the decision-time component (the other events
#     correspond to non-decision time),
#   * trials whose decision Duration is shorter than the total response time.
# Times are then divided by 1000 (presumably ms -> s) and contrast by 100
# (presumably percent -> fraction).
data = (
    pd.read_csv('../../_Data/Study1.csv', index_col=0)
    .reset_index(drop=True)
    .sort_values(by=['participant', 'trials', 'event'])
    .loc[lambda frame: frame['condition'] == 'speed']
    .loc[lambda frame: frame['event'] == 3]
    .loc[lambda frame: frame['Duration'] < frame['rt']]
    .reset_index(drop=True)
    .assign(
        rt=lambda frame: frame['rt'] / 1000,
        Duration=lambda frame: frame['Duration'] / 1000,
        contrast=lambda frame: frame['contrast'] / 100,
    )
)

In [5]:
# Number of participants remaining after filtering.
data.participant.nunique()

26

In [6]:
# Show the filtered trial table (rt and Duration were divided by 1000 and contrast by 100 at load time).
data

Unnamed: 0,participant,trials,event,component,Duration,event_name,rt,condition,side,contrast,response,correct,rec_sat,rec_cont
0,S10_epo,0,3,0,0.358,stimulus/74,0.494141,speed,left,0.74,right,False,0.0,0.74
1,S10_epo,1,3,0,0.372,stimulus/79,0.476562,speed,left,0.79,right,False,0.0,0.79
2,S10_epo,2,3,0,0.279,stimulus/93,0.404297,speed,left,0.93,left,True,0.0,0.93
3,S10_epo,3,3,0,0.216,stimulus/50,0.337891,speed,left,0.50,left,True,0.0,0.50
4,S10_epo,4,3,0,0.313,stimulus/42,0.462891,speed,left,0.42,right,False,0.0,0.42
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
14299,S9_epo,1111,3,0,0.197,stimulus/15,0.399414,speed,right,0.15,right,True,0.0,0.15
14300,S9_epo,1112,3,0,0.206,stimulus/81,0.366211,speed,left,0.81,left,True,0.0,0.81
14301,S9_epo,1113,3,0,0.389,stimulus/86,0.518555,speed,left,0.86,left,True,0.0,0.86
14302,S9_epo,1114,3,0,0.208,stimulus/13,0.417969,speed,right,0.13,right,True,0.0,0.13


In [7]:
# Per-participant fit results: one (initially empty) list per output column.
# The fitting loop appends to these lists, and they later become a DataFrame.
prms_dc = {column: [] for column in ('sbj', 'b0', 'lambda', 'delta0', 'delta1',
                                     't0', 'sigma', 'mean_z', 'std_z', 'G2', 'BIC')}

# (lower, upper) search bounds for the free parameters of CDDM_likelihood.
min_b0, max_b0 = 0.5, 4              # initial boundary height
min_lambda, max_lambda = .01, 4      # boundary collapse rate
min_ndt, max_ndt = 0.05, 1           # non-decision time t0
min_delta0, max_delta0 = -3, 5       # drift intercept
min_delta1, max_delta1 = -3, 5       # drift slope over the transformed contrast
min_sig, max_sig = 0.01, 2           # log-normal spread of observed non-decision time

In [8]:
# Fit every participant: a global differential-evolution search, followed by a
# bounded Nelder-Mead refinement started at the DE optimum.
DE_SEED = 0  # fixed seed so the stochastic DE search (and the saved table) is reproducible

# (lower, upper) search bounds, in the parameter order expected by CDDM_likelihood.
fit_bounds = [(min_b0, max_b0), (min_lambda, max_lambda),
              (min_delta0, max_delta0), (min_delta1, max_delta1),
              (min_ndt, max_ndt), (min_sig, max_sig)]
fit_param_names = ['b0', 'lambda', 'delta0', 'delta1', 't0', 'sigma']

# Empty the result lists first so re-running this cell does not append duplicate rows.
for column_values in prms_dc.values():
    column_values.clear()

for sbj in tqdm(data.participant.unique()):
    sbj_data = data[data['participant']==sbj]
    # Signed RT: positive for correct (upper-boundary) trials, negative for errors.
    choice = 2*sbj_data.correct.values.astype(np.int64)-1
    RT = choice*sbj_data.rt.values
    Contrast = sbj_data.contrast.values
    # Observed non-decision time: total RT minus the decision-stage (event 3) duration.
    Z = sbj_data.rt.values-sbj_data.Duration.values

    de_ans = differential_evolution(CDDM_likelihood,
                                    args=(RT, Contrast, Z),
                                    bounds=fit_bounds,
                                    seed=DE_SEED)

    min_ans = minimize(CDDM_likelihood,
                       args=(RT, Contrast, Z),
                       method='nelder-mead',
                       x0=de_ans.x,
                       bounds=fit_bounds)

    prms_dc['sbj'].append(sbj)
    for name, value in zip(fit_param_names, min_ans.x):
        prms_dc[name].append(value)
    prms_dc['mean_z'].append(np.mean(Z))
    prms_dc['std_z'].append(np.std(Z))
    prms_dc['G2'].append(2*min_ans.fun)
    # BIC = G2 + k*log(n), with k = number of free parameters.
    prms_dc['BIC'].append(2*min_ans.fun + len(fit_bounds) * np.log(RT.shape[0]))

100%|███████████████████████████████████████████| 26/26 [08:25<00:00, 19.46s/it]


In [9]:
from pathlib import Path

# Collect the per-participant fits into one table and save it.
prms_df = pd.DataFrame(prms_dc)

# NOTE: the file name hard-codes the 'speed' condition; keep it in sync with the
# condition filter in the data-loading cell.
prms_path = Path('_prms') / 'exp_speed_log.csv'
prms_path.parent.mkdir(parents=True, exist_ok=True)  # DataFrame.to_csv does not create missing folders
prms_df.to_csv(prms_path, index=False)

In [10]:
# Show the fitted parameters per participant (G2 = 2 * minimised negative log-likelihood; BIC adds 6 * log(n_trials)).
prms_df

Unnamed: 0,sbj,b0,lambda,delta0,delta1,t0,sigma,mean_z,std_z,G2,BIC
0,S10_epo,1.066733,1.951547,0.820207,1.128973,0.165513,0.26203,0.164701,0.049439,-2099.971785,-2062.068797
1,S11_epo,0.767083,0.619108,0.828023,0.798623,0.285259,0.300661,0.291594,0.076257,-935.695238,-897.749084
2,S12_epo,0.941237,0.823696,1.008109,0.017793,0.228601,0.439721,0.253327,0.099739,-591.564331,-553.650522
3,S13_epo,1.941767,1.174892,1.072178,1.713828,0.217192,0.267797,0.217273,0.069727,-1437.12121,-1399.175056
4,S14_epo,2.299114,3.841938,1.23998,0.93321,0.191949,0.27725,0.191765,0.056666,-2204.139013,-2166.192859
5,S15_epo,0.995031,2.033631,0.668686,1.267327,0.211133,0.323176,0.208444,0.063075,-1637.858987,-1599.977699
6,S16_epo,1.421882,1.98784,0.77161,1.354711,0.178651,0.34867,0.177056,0.069428,-1659.479999,-1621.555389
7,S17_epo,1.247473,1.187486,0.542924,1.183964,0.209905,0.357574,0.207032,0.077328,-1017.407741,-979.461588
8,S18_epo,0.730296,0.627546,0.409032,1.020894,0.209451,0.401285,0.212032,0.078819,-915.436551,-877.490397
9,S19_epo,0.557463,0.260501,0.370344,0.663424,0.198789,0.556761,0.191164,0.083277,-847.340959,-809.427151


In [12]:
# Largest contrast level remaining after filtering (after the /100 rescaling at load time).
data.contrast.max()

0.95