In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from scipy.stats import pearsonr
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

from numba import jit, float64, int64

In [2]:
@jit(nopython=True)
def f(x, t, z, tau, delta, sigma=1):
    term1 = 1/np.sqrt(2 * np.pi * sigma**2 * (t-tau))
    term2 = -(x - z - delta * (t-tau))**2 / (2 * sigma**2 * (t-tau))
    return term1 * np.exp(term2)

@jit(nopython=True)
def psi(threshold, lamda, t, z, tau, delta, sigma=1):
    db = (-lamda*threshold)/(1 + lamda*t)**2
    term1 = 0.5*f(threshold /(1 + lamda*t), t, z, tau, delta, sigma)
    term2 = db - delta - ((threshold /(1 + lamda*t)) - z - delta * (t-tau))/(t-tau)
    return term1 * term2

@jit(nopython=True)
def fpt(threshold, lamda, delta, z=0, sigma=1, dt=0.02, T_max=5):
    gu = np.zeros((int(T_max/dt)+2,))
    gl = np.zeros((int(T_max/dt)+2,))
    T = np.zeros((int(T_max/dt)+2,))
    
    gu[1] = -2*psi(threshold, lamda, dt, z, 0, delta, sigma)
    gl[1] =  2*psi(-threshold, lamda, dt, z, 0, delta, sigma)
    T[1] = dt
    
    for n in range(2, int(T_max/dt)+2):
        su = -2 * psi( threshold, lamda, n*dt, z, 0, delta, sigma)
        sl =  2 * psi(-threshold, lamda, n*dt, z, 0, delta, sigma)
        
        for j in range(1, n):
            if (threshold /(1 + lamda*j*dt)) == 0:
                continue
            
            psi_n_j_pp = psi( threshold, lamda, n*dt,  threshold /(1 + lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_pn = psi( threshold, lamda, n*dt, -threshold /(1 + lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_np = psi(-threshold, lamda, n*dt,  threshold /(1 + lamda*j*dt), j*dt, delta, sigma)
            psi_n_j_nn = psi(-threshold, lamda, n*dt, -threshold /(1 + lamda*j*dt), j*dt, delta, sigma)
            
            su +=  2 * dt * (gu[j] * psi_n_j_pp + gl[j] * psi_n_j_pn)
            sl += -2 * dt * (gu[j] * psi_n_j_np + gl[j] * psi_n_j_nn)
            
        gu[n] = su
        gl[n] = sl
        T[n] = (n*dt)
    return gu, gl, T

In [3]:
def CDDM_likelihood(prms, RT, Contrast):
    ub = lambda t: prms[0] /(1 + prms[1]*t)
    lb = lambda t: -1*ub(t)
    dub = lambda t: (-prms[1]*prms[0])/(1 + prms[1]*t)**2
    dlb = lambda t: -1*dub(t)

    delta0 = prms[2]
    delta1 = prms[3]
    t0 = prms[4]
    
    T_max = np.max(np.abs(RT))
    gu, gl, TT = fpt(prms[0], prms[1], 0, z=0, dt=0.02, T_max=T_max)
    
    gtup = interp1d(TT, gu)
    gtlp = interp1d(TT, gl)
    
    ll = 0
    for i in range(len(RT)):
        if np.abs(RT[i])-t0 > 0:
            
            delta = delta0 - delta1*np.log(Contrast[i])
            
            if RT[i]>=0:
                exp_term = np.exp(delta*ub(np.abs(RT[i])-t0) - 0.5*delta**2*(np.abs(RT[i])-t0))
                density = exp_term*gtup(np.abs(RT[i])-t0)
            else:                
                exp_term = np.exp(delta*lb(np.abs(RT[i])-t0) - 0.5*delta**2*(np.abs(RT[i])-t0))
                density = exp_term*gtlp(np.abs(RT[i])-t0)
                
            if density>1e-14:
                ll += -np.log(density)
            else:
                ll += -np.log(1e-14) 
        else:
            ll += -np.log(1e-14)
    
    return ll

In [4]:
data = pd.read_csv('../../_Data/Study1.csv', 
                   index_col=0).reset_index(drop=True)

data = data.sort_values(by=['participant', 'trials', 'event'])

# data = data[data.condition == 'speed']
data = data[data.condition == 'accuracy']
data = data[data.event == 3] # This event corresponds to decision time and the rest correspond to non-decision time
data = data[data['Duration']<data['rt']].reset_index(drop=True)

data['rt']/=1000
data['Duration']/=1000

In [5]:
data['participant'].nunique()

26

In [6]:
data

Unnamed: 0,participant,trials,event,component,Duration,event_name,rt,condition,side,contrast,response,correct,rec_sat,rec_cont
0,S10_epo,279,3,0,0.277,stimulus/36,0.426758,accuracy,right,36,right,True,1.0,0.36
1,S10_epo,280,3,0,0.312,stimulus/9,0.451172,accuracy,right,9,right,True,1.0,0.09
2,S10_epo,281,3,0,0.305,stimulus/32,0.476562,accuracy,left,32,left,True,1.0,0.32
3,S10_epo,282,3,0,0.448,stimulus/62,0.611328,accuracy,left,62,left,True,1.0,0.62
4,S10_epo,283,3,0,0.288,stimulus/24,0.485352,accuracy,left,24,left,True,1.0,0.24
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
13763,S9_epo,832,3,0,0.262,stimulus/88,0.460938,accuracy,left,88,right,False,1.0,0.88
13764,S9_epo,833,3,0,0.339,stimulus/81,0.486328,accuracy,left,81,left,True,1.0,0.81
13765,S9_epo,834,3,0,0.320,stimulus/92,0.484375,accuracy,left,92,left,True,1.0,0.92
13766,S9_epo,835,3,0,0.452,stimulus/72,0.619141,accuracy,right,72,right,True,1.0,0.72


In [7]:
prms_dc = {'sbj':[],
           'b0':[],
           'lambda':[],
           'delta0':[],
           'delta1':[],
           't0':[],
           'G2':[],
           'BIC':[]}

min_b0 = 0.5
max_b0 = 4

min_lambda = .01
max_lambda = 4

min_ndt = 0.05
max_ndt = 1

min_delta0 = -3
max_delta0 = 5

min_delta1 = -3
max_delta1 = 5

In [8]:
for sbj in tqdm(data.participant.unique()):
    sbj_data = data[data['participant']==sbj]
    choice = 2*sbj_data.correct.values.astype(np.int64)-1
    RT = choice*sbj_data.rt.values
    Contranst = sbj_data.contrast.values

    min_ans = differential_evolution(CDDM_likelihood,
                                         args=(RT, Contranst),
                                         bounds=[(min_b0, max_b0), (min_lambda, max_lambda), 
                                                 (min_delta0, max_delta0), (min_delta1, max_delta1), 
                                                 (min_ndt, max_ndt)])

    min_ans = minimize(CDDM_likelihood,
                       args=(RT, Contranst),
                       method='nelder-mead',
                       x0=min_ans.x,
                       bounds=[(min_b0, max_b0), (min_lambda, max_lambda), 
                               (min_delta0, max_delta0), (min_delta1, max_delta1), 
                               (min_ndt, max_ndt)])
    prms_dc['sbj'].append(sbj)
    prms_dc['b0'].append(min_ans.x[0])
    prms_dc['lambda'].append(min_ans.x[1])
    prms_dc['delta0'].append(min_ans.x[2])
    prms_dc['delta1'].append(min_ans.x[3])
    prms_dc['t0'].append(min_ans.x[4])
    prms_dc['G2'].append(2*min_ans.fun)
    prms_dc['BIC'].append(2*min_ans.fun + 5 * np.log(RT.shape[0]))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 26/26 [09:54<00:00, 22.85s/it]


In [9]:
prms_df = pd.DataFrame(prms_dc)

prms_df.to_csv('_prms/bhyp_{}.csv'.format(data.condition.unique()[0]), index=False)