In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv, ive
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def k(a, da, t, q, sigma=2):
    """Return the constant ``0.5 * (q - 0.5*sigma - da(t))`` used by ``psi``.

    Parameters
    ----------
    a : callable
        Boundary ``a(t)``. Not used here; accepted so the signature mirrors
        the leading arguments of ``psi``.
    da : callable
        Presumably the time derivative of the boundary (``HSDM_2D_likelihood``
        passes ``lambda t: 0`` for its constant boundary); evaluated at ``t``.
    t : float
        Evaluation time.
    q : float
        Dimension parameter (``HSDM_2D_likelihood`` passes ``mu.shape[0]``).
    sigma : float, optional
        Scale parameter shared with ``psi`` (default 2).

    Returns
    -------
    float
    """
    return 0.5 * (q - 0.5 * sigma - da(t))

def psi(a, da, t, z, tau, q, sigma=2):
    """Kernel of the integral equation solved by ``ie_bessel_fpt``.

    With ``c = sigma*(t - tau)``, ``x = 2*sqrt(a(t)*z)/c`` and ``nu = q/sigma``,
    this returns::

        exp(-(a(t) + z)/c) / c
          * (a(t)/z) ** (0.5*(q - sigma)/sigma)
          * [ (da(t) - a(t)/(t - tau) + k(...)) * I_{nu-1}(x)
              + sqrt(a(t)*z)/(t - tau) * I_nu(x) ]

    ``I_nu(x)`` grows like ``exp(x)``, and ``exp(-(a + z)/c)`` underflows for
    small ``t - tau``. The product is therefore evaluated through the exact
    identity ``exp(-(a + z)/c) * I_nu(x) == exp(-(sqrt(a) - sqrt(z))**2/c) *
    ive(nu, x)``, where ``ive(nu, x) = I_nu(x) * exp(-x)``. This form stays
    finite for every ``x`` and needs no branch.

    Parameters
    ----------
    a, da : callable
        Boundary and (presumably) its time derivative, both evaluated at ``t``.
    t : float
        Current time; must be greater than ``tau`` (``t - tau`` is a divisor).
    z : float
        Value at time ``tau``. It must be positive because ``a(t)/z`` is
        computed. ``ie_bessel_fpt`` passes either its start value or
        ``a(j*dt)``.
    tau : float
        Earlier time the kernel is conditioned on.
    q : float
        Dimension parameter.
    sigma : float, optional
        Scale parameter (default 2).

    Returns
    -------
    float
    """
    a_t = a(t)
    lag = t - tau
    c = sigma * lag
    root_az = np.sqrt(a_t * z)
    # The Bessel argument must stay 2*sqrt(a*z)/c. Substituting (a + z)/c
    # (as a previous large-x branch did) is wrong whenever a(t) != z.
    x = 2 * root_az / c
    nu = q / sigma

    scale = np.exp(-(np.sqrt(a_t) - np.sqrt(z)) ** 2 / c) / c
    power = (a_t / z) ** (0.5 * (q - sigma) / sigma)
    coeff = da(t) - a_t / lag + k(a, da, t, q, sigma)
    return scale * power * (coeff * ive(nu - 1, x) + (root_az / lag) * ive(nu, x))

def ie_bessel_fpt(a, da, q, z, sigma=2, dt=0.1, T_max=2):
    """Solve the integral equation for ``g`` on a uniform grid and interpolate it.

    On ``t_n = n*dt`` for ``n = 0 .. int(T_max/dt) + 1`` the recursion is::

        g(t_0) = 0
        g(t_n) = -2*psi(t_n | z, 0)
                 + 2*dt * sum_{j=1}^{n-1} g(t_j) * psi(t_n | a(t_j), t_j)

    (for ``n = 1`` the sum is empty). Given the name, ``g`` is presumably the
    first-passage-time density through the boundary ``a``. The double loop
    costs O(N^2) ``psi`` evaluations for N grid points.

    Parameters
    ----------
    a, da : callable
        Boundary and (presumably) its derivative, forwarded to ``psi``.
    q : float
        Dimension parameter forwarded to ``psi``.
    z : float
        Start value forwarded to ``psi`` at time 0 (must be positive).
    sigma : float, optional
        Scale parameter forwarded to ``psi`` (default 2).
    dt : float, optional
        Grid step (default 0.1).
    T_max : float, optional
        Horizon. The grid always extends one step past ``int(T_max/dt)*dt``.

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant of ``g``. Queries outside the grid raise
        ``ValueError`` (interp1d's default bounds checking).
    """
    last_step = int(T_max / dt) + 1
    times = [0, dt]
    values = [0, -2 * psi(a, da, dt, z, 0, q, sigma)]

    for step in range(2, last_step + 1):
        t_now = step * dt
        # Accumulate in order (explicit loop, not sum()) so the floating-point
        # result matches the plain left-to-right recursion.
        acc = -2 * psi(a, da, t_now, z, 0, q, sigma)
        for j in range(1, step):
            t_prev = j * dt
            acc += 2 * dt * values[j] * psi(a, da, t_now, a(t_prev), t_prev, q, sigma)
        values.append(acc)
        times.append(t_now)

    return interp1d(np.asarray(times), np.asarray(values))

In [3]:
def HSDM_2D_likelihood(prms, RT, Theta):
    """Negative log-likelihood of the 2-D HSDM for response times and angles.

    Parameters
    ----------
    prms : sequence of float
        ``(threshold, ndt, mu_x, mu_y)``. The boundary given to
        ``ie_bessel_fpt`` is the constant ``threshold**2`` (derivative 0).
        ``ndt`` is subtracted from every RT, and ``(mu_x, mu_y)`` is the drift
        vector.
    RT : sequence of float
        Response times. ``max(RT)`` sets the FPT grid horizon, so RT must be
        non-empty.
    Theta : sequence of float
        Response angles (radians; fed to cos/sin), indexed like ``RT``.

    Returns
    -------
    float
        Sum over trials of ``-log(density)``, where::

            density = exp(threshold * (mu . (cos theta, sin theta))
                          - 0.5 * |mu|**2 * (rt - ndt)) * fpt(rt - ndt)

        A trial contributes ``-log(1e-14)`` instead when ``rt - ndt`` is not
        above 0.001, or when its density is not above 1e-14 (including NaN).

    NOTE(review): if ``fpt`` is the radial first-passage-time density, the
    joint (angle, time) density may also need a 1/(2*pi) factor. That adds a
    constant to the result, so fitted parameters are unaffected but G2 values
    are shifted. Confirm against the model derivation.
    """
    threshold = prms[0]
    ndt = prms[1]
    mu = np.array([prms[2], prms[3]])
    floor = 0.1**14

    fpt = ie_bessel_fpt(lambda t: threshold**2, lambda t: 0, mu.shape[0], 0.0001,
                        dt=0.05, T_max=max(RT))
    # 0.5*|mu|^2 does not depend on the trial, so compute it once.
    drift_decay = 0.5 * np.linalg.norm(mu, 2)**2

    neg_log_lik = 0
    for i in range(len(RT)):
        rt, theta = RT[i], Theta[i]
        decision_time = rt - ndt
        if not decision_time > 0.001:
            neg_log_lik -= np.log(floor)
            continue

        projection = mu[0]*np.cos(theta) + mu[1]*np.sin(theta)
        density = np.exp(threshold * projection - drift_decay * decision_time) * fpt(decision_time)
        neg_log_lik -= np.log(density if floor < density else floor)

    return neg_log_lik

In [4]:
# Load the Kvam (2019) orientation-judgment trials and keep only uncued trials.
data = (
    pd.read_csv('_data/Kvam_2019_orientation_judgments_data.csv')
    .loc[lambda df: df['isCued'] == 0]
    .reset_index(drop=True)
)

In [5]:
# Summarize the sample: number of participants, then each participant's trial
# count (in order of first appearance), tab-separated on one line.
n_participants = data.Participant.nunique()
print('Data from {} participants in an orientation judgment task'.format(n_participants))
print('Number of trials for each participant is as follows:')
for par in data['Participant'].unique():
    n_trials = int((data['Participant'] == par).sum())
    print(n_trials, end=',\t')

Data from 12 participants in an orientation judgment task
Number of trials for each participant is as follows:
599,	477,	476,	465,	480,	480,	480,	476,	475,	456,	466,	472,	

In [6]:
# Preview the filtered trial table (only rows with isCued == 0 remain).
data

Unnamed: 0,Participant,isSpeed,isCued,jitter,jitterLevel,cueDeflections,cueOrientation,targetOrientation,response,deviation,absoluteDeviation,RT,points
0,100,0,0,45,3,819,-999.0,2.191300,1.92170,-0.269550,0.269550,1.25950,117
1,100,0,0,30,2,819,-999.0,1.619600,1.78450,0.164840,0.164840,0.86939,135
2,100,0,0,15,1,819,-999.0,0.419540,0.26625,-0.153290,0.153290,0.80090,138
3,100,0,0,30,2,819,-999.0,2.077700,2.15700,0.079275,0.079275,1.00190,155
4,100,0,0,15,1,819,-999.0,1.724100,1.72770,0.003619,0.003619,2.36400,190
...,...,...,...,...,...,...,...,...,...,...,...,...,...
5797,210,1,0,15,1,819,-999.0,0.041314,0.36629,0.324980,0.324980,0.67718,179
5798,210,1,0,45,3,819,-999.0,2.772700,0.82885,1.197800,1.197800,0.64035,124
5799,210,1,0,45,3,819,-999.0,2.446400,2.07510,-0.371310,0.371310,0.68574,176
5800,210,1,0,15,1,819,-999.0,2.370100,2.34730,-0.022781,0.022781,0.64622,199


In [7]:
# Column-oriented accumulator for the per-condition fits. Each key gets its own
# empty list, and the key order becomes the column order of `estimation_df`.
estimated_prms = {column: [] for column in
                  ('sbj', 'isSpeed', 'jitter', 'threshold', 'mux', 'muy', 'ndt', 'G2')}

In [8]:
for par in tqdm(data.Participant.unique()):
    data_sbj = data[data['Participant'] == par].reset_index(drop=True)

    # One fit per (speed instruction, jitter level) cell of the design.
    for sp in range(2):
        for jit in [15, 30, 45]:
            in_cell = (data_sbj['isSpeed'] == sp) & (data_sbj['jitter'] == jit)
            cond_data = data_sbj[in_cell].reset_index(drop=True)
            RT = cond_data.RT.to_numpy()
            Theta = cond_data.deviation.to_numpy()

            # Bounds follow HSDM_2D_likelihood's parameter order:
            # threshold, ndt, mu_x, mu_y.
            min_ans = differential_evolution(
                HSDM_2D_likelihood,
                bounds=[(0.5, 5), (0.1, 1), (-6, 7.5), (-3, 3)],
                args=(RT, Theta),
            )

            fitted = {
                'sbj': par,
                'isSpeed': sp,
                'jitter': jit,
                'threshold': min_ans.x[0],
                'ndt': min_ans.x[1],
                'mux': min_ans.x[2],
                'muy': min_ans.x[3],
                'G2': 2 * min_ans.fun,
            }
            for column, value in fitted.items():
                estimated_prms[column].append(value)


100%|███████████████████████████████████████████| 12/12 [07:42<00:00, 38.54s/it]


In [9]:
# Merge this run into the saved best fits: for each condition keep whichever
# estimate has the lower G2. differential_evolution is called without a seed,
# so re-running the fitting cell and then this cell can only improve the table.
estimation_df = pd.DataFrame(estimated_prms)
try:
    best_fitting = pd.read_csv('Kvam_2019_best_estimation.csv', index_col=0)
except FileNotFoundError:
    # No saved table yet (e.g. fresh checkout): this run is the best so far.
    best_fitting = estimation_df.copy()
else:
    # Rows are matched by index, so both tables must list the same conditions
    # in the same order. Otherwise G2 values of different conditions would be
    # compared and copied.
    key_cols = ['sbj', 'isSpeed', 'jitter']
    assert np.array_equal(best_fitting[key_cols].to_numpy(),
                          estimation_df[key_cols].to_numpy()), \
        'saved fits and current fits do not cover the same conditions in the same order'
    improved = estimation_df['G2'] < best_fitting['G2']
    best_fitting.loc[improved] = estimation_df.loc[improved]

In [10]:
# Persist the merged best fits (the merge cell reads this file back on the next
# run), then display them.
best_fitting.to_csv(path_or_buf='Kvam_2019_best_estimation.csv')
best_fitting

Unnamed: 0,sbj,isSpeed,jitter,threshold,mux,muy,ndt,G2
0,100,0,15,3.615505,4.654520,0.042434,0.100000,-421.855021
1,100,0,30,3.431101,3.866256,0.067547,0.100000,-349.992532
2,100,0,45,1.820967,2.394901,-0.202487,0.401474,-195.723563
3,100,1,15,3.852865,7.500000,0.024946,0.140736,-598.742881
4,100,1,30,1.732421,5.764806,0.067050,0.375497,-500.802376
...,...,...,...,...,...,...,...,...
67,210,0,30,1.477225,3.497915,0.014716,0.356336,-263.607248
68,210,0,45,1.179448,2.517375,-0.108472,0.412790,-173.870384
69,210,1,15,2.300362,7.058793,-0.146039,0.285259,-444.046459
70,210,1,30,1.299603,4.251208,-0.296886,0.347738,-324.471656
