In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv
from scipy.special import jn_zeros
from scipy.optimize import minimize
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_3D(a, mu, ndt, sigma=1, dt=0.001):
    x = np.zeros(mu.shape)
    
    rt = 0
    
    while np.linalg.norm(x, 2) < a(rt):
        x += mu*dt + sigma*np.sqrt(dt)*np.random.normal(0, 1, mu.shape)
        rt += dt
    
    theta1 = np.arctan2(np.sqrt(x[2]**2 + x[1]**2), x[0])
    theta2 = np.arctan2(x[2], x[1])   
    
    return rt+ndt, (theta1, theta2)

In [3]:
def series_bessel_fpt(t, a=1, sigma=1, nu=0, n=100):
    zeros = np.asarray([float(besseljzero(nu, i+1)) for i in range(n)])
    fpt = np.zeros(t.shape)
    
    for i in range(t.shape[0]):
        series = np.sum((zeros**(nu+1)/jv(nu+1, zeros)) * np.exp(-(zeros**2 * sigma**2)/(2*a**2)*t[i]))
        fpt[i] = sigma**2/(2**nu * a**2 * gamma(nu + 1)) * series
        
    return interp1d(t, fpt)

In [4]:
def HSDM_3D_likelihood(prms, RT, Theta, N_series):
    a = prms[0]
    ndt = prms[1]
    mu = np.array([prms[2], prms[3], prms[4]])
    
    tt = np.arange(0.001, max(RT)+0.02, 0.02)
    fpt = series_bessel_fpt(tt, a, sigma=1, nu=(mu.shape[0]-2)/2, n=N_series)
    
    log_lik = 0
    for i in range(len(RT)):
        rt, theta = RT[i], Theta[i]
        if rt - ndt > 0.001:
            mu_dot_x0 = mu[0]*np.cos(theta[0])
            mu_dot_x1 = mu[1]*np.sin(theta[0])*np.cos(theta[1]) 
            mu_dot_x2 = mu[2]*np.sin(theta[0])*np.sin(theta[1])
            term1 = prms[0] * (mu_dot_x0 + mu_dot_x1 + mu_dot_x2)
            term2 = 0.5 * np.linalg.norm(mu, 2)**2 * (rt-ndt)
            
            density = np.exp(term1 - term2) * fpt(rt - ndt)
            
            if 0.1**14 < density:
                log_lik += -np.log(density)
            else:
                log_lik += -np.log(0.1**14)
        else:
            log_lik += -np.log(0.1**14)
        
    return log_lik

In [5]:
recovery_df = {'threshold_true': [],
               'threshold_estimate': [],
               'ndt_true': [],
               'ndt_estimate': [],
               'mu1_true': [],
               'mu1_estimate': [],
               'mu2_true': [],
               'mu2_estimate': [],
               'mu3_true': [],
               'mu3_estimate': []}

min_threshold = 0.5
max_threshold = 5

min_ndt = 0.1
max_ndt = 1

min_mu = -3.5
max_mu = 3.5

N_series = 250

In [6]:
for n in tqdm(range(150)):
    threshold = np.random.uniform(min_threshold, max_threshold)
    a = lambda t: threshold
    ndt = np.random.uniform(min_ndt, max_ndt)
    mu = np.array([np.random.uniform(min_mu, max_mu), 
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])
    
    
    recovery_df['threshold_true'].append(threshold)
    recovery_df['ndt_true'].append(ndt)
    recovery_df['mu1_true'].append(mu[0])
    recovery_df['mu2_true'].append(mu[1])
    recovery_df['mu3_true'].append(mu[2])
    
    RT = []
    Theta = []

    for i in range(250):
        rt, theta = simulate_HSDM_3D(a, mu, ndt)
        RT.append(rt)
        Theta.append(theta)
    
    min_ans = minimize(HSDM_3D_likelihood,
                       args=(RT, Theta, N_series), 
                       x0=np.array([np.random.uniform(min_threshold, max_threshold),
                                    np.random.uniform(min_ndt, max_ndt), 
                                    np.random.uniform(min_mu, max_mu),
                                    np.random.uniform(min_mu, max_mu),
                                    np.random.uniform(min_mu, max_mu)]),
                       method='Nelder-Mead', 
                       bounds=[(min_threshold, max_threshold), (min_ndt, max_ndt),
                               (min_mu, max_mu), (min_mu, max_mu), (min_mu, max_mu)])
    
    recovery_df['threshold_estimate'].append(min_ans.x[0])
    recovery_df['ndt_estimate'].append(min_ans.x[1])
    recovery_df['mu1_estimate'].append(min_ans.x[2])
    recovery_df['mu2_estimate'].append(min_ans.x[3])
    recovery_df['mu3_estimate'].append(min_ans.x[4])
    
recovery_df = pd.DataFrame(recovery_df)

100%|█████████████████████████████████████████| 150/150 [3:26:52<00:00, 82.75s/it]


In [7]:
recovery_df

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate
0,3.192941,3.148793,0.610135,0.606508,2.851088,2.820996,2.645554,2.649220,-2.207012,-2.154169
1,2.207279,2.429406,0.144319,0.100000,1.997185,2.077014,0.804097,0.845160,-3.333568,-3.500000
2,4.071013,4.259570,0.182345,0.100000,-2.742439,-2.556222,-3.231009,-3.059486,0.832804,0.855717
3,3.285207,3.139824,0.876976,0.928341,0.666838,0.694480,1.567617,1.516116,-1.037872,-1.096138
4,1.680086,1.681089,0.343995,0.355204,0.169603,0.146222,3.124326,3.195164,-2.026703,-2.137408
...,...,...,...,...,...,...,...,...,...,...
145,4.952261,4.939724,0.396030,0.477438,0.685845,0.761195,2.704597,2.894478,1.793769,1.864061
146,1.848311,1.983089,0.324690,0.349254,2.919781,3.251222,-2.464588,-3.500000,-1.110184,-1.356782
147,3.065443,2.870472,0.278074,0.352210,0.904332,0.910534,1.931205,1.850032,-1.710729,-1.752619
148,1.382504,1.390763,0.548792,0.556503,2.514708,2.746689,-2.670994,-2.831821,1.971786,1.876971


In [8]:
recovery_df.corr()

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate
threshold_true,1.0,0.87577,0.010757,0.031627,0.026963,-0.010477,0.085842,0.039738,0.066812,0.041663
threshold_estimate,0.87577,1.0,0.028199,-0.101339,0.008972,-0.041841,0.047933,0.08893,0.083406,0.099207
ndt_true,0.010757,0.028199,1.0,0.846135,-0.018532,-0.030408,0.033677,0.02267,0.007699,0.008874
ndt_estimate,0.031627,-0.101339,0.846135,1.0,-0.029698,-0.009982,0.002654,0.02349,-0.070384,-0.041366
mu1_true,0.026963,0.008972,-0.018532,-0.029698,1.0,0.978219,-0.019997,-0.051765,0.085487,0.083304
mu1_estimate,-0.010477,-0.041841,-0.030408,-0.009982,0.978219,1.0,-0.048422,-0.067644,0.110211,0.105409
mu2_true,0.085842,0.047933,0.033677,0.002654,-0.019997,-0.048422,1.0,0.929614,0.068919,0.057068
mu2_estimate,0.039738,0.08893,0.02267,0.02349,-0.051765,-0.067644,0.929614,1.0,0.04996,0.077556
mu3_true,0.066812,0.083406,0.007699,-0.070384,0.085487,0.110211,0.068919,0.04996,1.0,0.976697
mu3_estimate,0.041663,0.099207,0.008874,-0.041366,0.083304,0.105409,0.057068,0.077556,0.976697,1.0


In [9]:
recovery_df.to_csv('Series_3d_recovery_{}.csv'.format(N_series))