In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv
from scipy.special import jn_zeros
from scipy.optimize import minimize
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_4D(a, mu, ndt, sigma=1, dt=0.001):
    """Simulate a single trial of a 4-D diffusion that stops at a boundary.

    The state starts at the origin and is advanced with Euler-Maruyama steps
    (drift ``mu*dt`` plus Gaussian noise scaled by ``sigma*sqrt(dt)``) until
    its Euclidean norm is no longer below the boundary ``a(elapsed_time)``.

    Parameters
    ----------
    a : callable
        Boundary radius as a function of elapsed time; called as ``a(t)``.
    mu : numpy.ndarray
        Drift vector. Its ``shape`` sets the state shape; components 0..3 are
        read to compute the angles, so it needs at least four entries.
    ndt : float
        Non-decision time added to the simulated hitting time.
    sigma : float, optional
        Diffusion coefficient (default 1).
    dt : float, optional
        Integration step (default 0.001).

    Returns
    -------
    tuple
        ``(rt, (theta1, theta2, theta3))`` where ``rt`` is the hitting time
        plus ``ndt`` and the thetas are the hyperspherical angles of the
        final state, each obtained with ``arctan2``.
    """
    position = np.zeros(mu.shape)
    elapsed = 0

    # Both quantities are constant across steps, so compute them once.
    drift_step = mu*dt
    noise_scale = sigma*np.sqrt(dt)

    while np.linalg.norm(position, 2) < a(elapsed):
        position += drift_step + noise_scale*np.random.normal(0, 1, mu.shape)
        elapsed += dt

    # Convert the exit point to hyperspherical angles.
    tail_sq = position[3]**2 + position[2]**2
    theta1 = np.arctan2(np.sqrt(tail_sq + position[1]**2), position[0])
    theta2 = np.arctan2(np.sqrt(tail_sq), position[1])
    theta3 = np.arctan2(position[3], position[2])

    return elapsed+ndt, (theta1, theta2, theta3)

In [3]:
# Memo of Bessel-function zeros keyed by (order, count). The zeros depend only
# on ``nu`` and ``n``, yet computing them with mpmath is by far the most
# expensive part of ``series_bessel_fpt``; an optimiser that calls it once per
# objective evaluation would otherwise recompute the same zeros every time.
_BESSEL_ZERO_CACHE = {}


def _bessel_zeros(nu, n):
    """Return the first ``n`` positive zeros of J_nu as a read-only float array (memoised)."""
    key = (float(nu), int(n))
    zeros = _BESSEL_ZERO_CACHE.get(key)
    if zeros is None:
        zeros = np.asarray([float(besseljzero(nu, k + 1)) for k in range(n)])
        # The array is shared between calls, so guard it against accidental edits.
        zeros.setflags(write=False)
        _BESSEL_ZERO_CACHE[key] = zeros
    return zeros


def series_bessel_fpt(t, a=1, sigma=1, nu=0, n=100):
    """Evaluate a truncated Bessel series on a time grid and return an interpolant.

    For every grid time the function computes

        sigma**2 / (2**nu * a**2 * gamma(nu + 1))
            * sum_k  z_k**(nu + 1) / J_{nu+1}(z_k) * exp(-z_k**2 * sigma**2 / (2 * a**2) * t)

    where ``z_k`` are the first ``n`` positive zeros of ``J_nu``.
    (Presumably this is the zero-drift first-passage-time density to a
    hypersphere of radius ``a``; confirm against the model reference.)

    Parameters
    ----------
    t : array_like, 1-D
        Time grid on which the series is evaluated.
    a : float, optional
        Boundary radius.
    sigma : float, optional
        Diffusion coefficient.
    nu : float, optional
        Order of the Bessel function whose zeros are used.
    n : int, optional
        Number of series terms.

    Returns
    -------
    scipy.interpolate.interp1d
        Interpolant through the evaluated grid values, built with the
        ``interp1d`` defaults, so querying it outside ``[t.min(), t.max()]``
        raises ``ValueError``.
    """
    t = np.asarray(t, dtype=float)
    zeros = _bessel_zeros(nu, n)

    weights = zeros**(nu+1) / jv(nu+1, zeros)
    rates = -(zeros**2 * sigma**2) / (2*a**2)

    # Rows index time points, columns index series terms.
    series = (weights * np.exp(rates * t[:, None])).sum(axis=1)
    fpt = sigma**2 / (2**nu * a**2 * gamma(nu + 1)) * series

    return interp1d(t, fpt)

In [4]:
def HSDM_4D_likelihood(prms, RT, Theta, N_series):
    """Penalised negative log-likelihood of the 4-D model for a set of trials.

    Despite the name, the returned value is the *negative* log-likelihood
    summed over trials, i.e. the quantity to minimise.

    Parameters
    ----------
    prms : sequence
        ``[a, ndt, mu1, mu2, mu3, mu4]``: boundary radius, non-decision time
        and the four drift components.
    RT : sequence of float
        Observed response times (indexable; ``max`` and ``len`` are used).
    Theta : sequence
        Per-trial angle triples ``(theta1, theta2, theta3)``.
    N_series : int
        Number of terms passed to ``series_bessel_fpt``.

    Returns
    -------
    float
        Sum over trials of ``-log(density)``. A trial contributes the fixed
        penalty ``-log(0.1**14)`` instead when its decision time
        ``rt - ndt`` is not above 0.001 or when its density is not above
        ``0.1**14``.
    """
    a, ndt = prms[0], prms[1]
    mu = np.array([prms[2], prms[3], prms[4], prms[5]])

    # Tabulate the series once on a 0.02-spaced grid covering all RTs.
    tt = np.arange(0.001, max(RT)+0.02, 0.02)
    fpt = series_bessel_fpt(tt, a, sigma=1, nu=(mu.shape[0]-2)/2, n=N_series)

    density_floor = 0.1**14
    penalty = -np.log(density_floor)
    half_sq_norm = 0.5 * np.linalg.norm(mu, 2)**2

    log_lik = 0
    for trial in range(len(RT)):
        rt, theta = RT[trial], Theta[trial]
        decision_time = rt - ndt

        if not (decision_time > 0.001):
            # Decision time outside the tabulated range: fixed penalty.
            log_lik += penalty
            continue

        sin1, sin2 = np.sin(theta[0]), np.sin(theta[1])

        # Drift projected on the exit direction, one Cartesian axis at a time.
        along0 = mu[0]*np.cos(theta[0])
        along1 = mu[1]*sin1*np.cos(theta[1])
        along2 = mu[2]*sin1*sin2*np.cos(theta[2])
        along3 = mu[3]*sin1*sin2*np.sin(theta[2])

        exponent = a * (along0 + along1 + along2 + along3) - half_sq_norm * decision_time
        density = np.exp(exponent) * fpt(decision_time)

        log_lik += -np.log(density) if density_floor < density else penalty

    return log_lik

In [5]:
# One '<parameter>_true' / '<parameter>_estimate' pair of (initially empty)
# columns per model parameter, in the order the parameters are reported.
recovery_df = {
    f'{parameter}_{kind}': []
    for parameter in ('threshold', 'ndt', 'mu1', 'mu2', 'mu3', 'mu4')
    for kind in ('true', 'estimate')
}

# Sampling ranges for the ground-truth parameters; the same ranges are used
# as optimiser bounds and for drawing the optimiser's starting point.
min_threshold, max_threshold = 0.5, 5
min_ndt, max_ndt = 0.1, 1
min_mu, max_mu = -3.5, 3.5

# Number of terms kept in the Bessel series.
N_series = 250

In [6]:
# Parameter-recovery study: draw ground-truth parameters, simulate a dataset,
# fit it by minimising the negative log-likelihood, and record true vs.
# estimated values.
#
# NOTE: no random seed is set, so every run draws different parameters.
n_datasets = 150   # number of synthetic datasets to simulate and fit
n_trials = 250     # simulated trials per dataset

# Column layout is read from the existing container: iterating it yields the
# dict keys on a fresh kernel and the DataFrame columns on a re-run. Rows are
# collected locally and the frame is built once at the end, so re-running this
# cell (or interrupting it half-way) never leaves `recovery_df` with columns
# of unequal length.
recovery_columns = list(recovery_df)
recovery_rows = []

for n in tqdm(range(n_datasets)):
    # Ground-truth parameters for this dataset.
    threshold = np.random.uniform(min_threshold, max_threshold)

    def a(t, _level=threshold):
        """Constant boundary: same radius at every time ``t``."""
        return _level

    ndt = np.random.uniform(min_ndt, max_ndt)
    mu = np.array([np.random.uniform(min_mu, max_mu) for _ in range(4)])

    # Simulate the dataset.
    RT = []
    Theta = []
    for i in range(n_trials):
        rt, theta = simulate_HSDM_4D(a, mu, ndt)
        RT.append(rt)
        Theta.append(theta)

    # Random starting point drawn from the same ranges as the true parameters.
    x0 = np.array([np.random.uniform(min_threshold, max_threshold),
                   np.random.uniform(min_ndt, max_ndt),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])

    min_ans = minimize(HSDM_4D_likelihood,
                       args=(RT, Theta, N_series),
                       x0=x0,
                       method='Nelder-Mead',
                       bounds=[(min_threshold, max_threshold), (min_ndt, max_ndt),
                               (min_mu, max_mu), (min_mu, max_mu),
                               (min_mu, max_mu), (min_mu, max_mu)])

    # Store the complete row only after the fit has finished.
    recovery_rows.append({'threshold_true': threshold,
                          'threshold_estimate': min_ans.x[0],
                          'ndt_true': ndt,
                          'ndt_estimate': min_ans.x[1],
                          'mu1_true': mu[0],
                          'mu1_estimate': min_ans.x[2],
                          'mu2_true': mu[1],
                          'mu2_estimate': min_ans.x[3],
                          'mu3_true': mu[2],
                          'mu3_estimate': min_ans.x[4],
                          'mu4_true': mu[3],
                          'mu4_estimate': min_ans.x[5]})

recovery_df = pd.DataFrame(recovery_rows, columns=recovery_columns)

100%|███████████████████████████████████| 150/150 [3:27:48<00:00, 83.13s/it]


In [7]:
# Show the recovery table: one row per simulated dataset, with the true and
# the estimated value of every parameter side by side.
recovery_df

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
0,2.489391,2.374237,0.228935,0.255423,3.011704,3.109683,-1.467541,-1.554401,-2.218748,-2.118139,-3.433076,-3.499994
1,4.812167,4.868996,0.113142,0.100005,-1.667708,-1.624517,-0.068537,0.043324,-1.690678,-1.664535,2.727342,2.863536
2,2.942051,3.033480,0.159553,0.154711,0.372851,0.386705,3.490944,3.446594,-3.131266,-3.139138,1.434539,1.384794
3,2.755465,1.957853,0.198421,0.495260,0.182085,0.077212,2.863043,1.811294,-0.760503,-0.188503,2.936075,2.451529
4,3.604415,3.638560,0.143397,0.100000,1.209562,1.155201,0.309609,0.272802,2.353546,2.293834,-0.989765,-0.935354
...,...,...,...,...,...,...,...,...,...,...,...,...
145,1.365854,1.382072,0.356208,0.358574,3.294580,3.442777,-0.171985,-0.395924,-1.417428,-1.233740,1.575345,1.390446
146,3.716339,2.190823,0.750126,1.000000,3.130151,3.500000,-2.732246,-2.954996,-3.128652,-3.223654,-3.420109,-3.490237
147,1.489365,1.517284,0.802264,0.804656,2.630464,2.641330,2.406872,2.139105,3.022772,3.093834,2.476160,2.283366
148,0.559872,5.000000,0.275573,0.282573,2.638833,3.500000,1.186290,3.500000,-2.946036,-3.500000,-3.461176,-3.500000


In [8]:
# Pairwise correlations between all columns (pandas default: Pearson). The
# entries pairing a '*_true' column with its '*_estimate' column indicate how
# well that parameter is recovered.
recovery_df.corr()

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
threshold_true,1.0,0.777201,-0.079008,-0.188581,-0.057034,-0.05802,-0.109582,-0.161684,-0.022384,-0.041718,-0.116715,-0.040996
threshold_estimate,0.777201,1.0,-0.229532,-0.291022,-0.031181,-0.045551,-0.074389,-0.090357,-0.020569,-0.087235,-0.111934,-0.100015
ndt_true,-0.079008,-0.229532,1.0,0.729141,0.075177,0.136768,-0.011215,-0.042254,-0.010662,-0.051553,-0.031268,0.059968
ndt_estimate,-0.188581,-0.291022,0.729141,1.0,0.113982,0.070616,-0.139144,-0.088197,-0.033928,-0.069345,0.038025,-0.022243
mu1_true,-0.057034,-0.031181,0.075177,0.113982,1.0,0.902068,0.013561,0.044056,0.007185,0.016021,-0.016121,-0.037835
mu1_estimate,-0.05802,-0.045551,0.136768,0.070616,0.902068,1.0,0.079046,0.104062,-0.021307,0.026734,-0.055876,-0.026871
mu2_true,-0.109582,-0.074389,-0.011215,-0.139144,0.013561,0.079046,1.0,0.921625,-0.041947,0.029458,0.154997,0.166617
mu2_estimate,-0.161684,-0.090357,-0.042254,-0.088197,0.044056,0.104062,0.921625,1.0,0.004309,0.018227,0.122983,0.135411
mu3_true,-0.022384,-0.020569,-0.010662,-0.033928,0.007185,-0.021307,-0.041947,0.004309,1.0,0.859091,0.084363,0.093992
mu3_estimate,-0.041718,-0.087235,-0.051553,-0.069345,0.016021,0.026734,0.029458,0.018227,0.859091,1.0,0.106198,0.094602


In [9]:
# Save the recovery table to the working directory; the file name embeds
# N_series, the number of series terms used in the fit.
recovery_df.to_csv('Series_4d_recovery_{}.csv'.format(N_series))