In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv
from scipy.special import jn_zeros
from scipy.optimize import minimize
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_3D(a, mu, ndt, sigma=1, dt=0.001, rng=None):
    """Simulate one trial of the 3-D hyperspherical diffusion model.

    A 3-D random walk with drift ``mu`` and noise scale ``sigma`` starts at the
    origin and is advanced with Euler-Maruyama steps of size ``dt`` until its
    Euclidean norm reaches the (possibly time-varying) boundary ``a(t)``.

    Parameters
    ----------
    a : callable
        Boundary radius as a function of elapsed decision time.
    mu : array-like, shape (3,)
        Drift vector.
    ndt : float
        Non-decision time added to the decision time.
    sigma : float, default 1
        Noise scale of the increments.
    dt : float, default 0.001
        Integration step.
    rng : numpy.random.Generator, optional
        Source of the normal increments. ``None`` (default) draws from the
        global ``np.random`` state, exactly as before; pass a seeded Generator
        for reproducible, independent simulations.

    Returns
    -------
    rt : float
        Decision time plus ``ndt``.
    (theta1, theta2) : tuple of float
        Angles of the exit point: ``theta1`` in [0, pi] is measured from the
        first axis, ``theta2`` in [-pi, pi] is the angle in the plane of the
        second and third axes.
    """
    normal = np.random.normal if rng is None else rng.normal
    mu = np.asarray(mu, dtype=float)
    x = np.zeros(mu.shape)

    rt = 0
    # The boundary is checked before each step, so the returned point is the
    # first position whose norm is >= a(rt) (it may overshoot slightly).
    while np.linalg.norm(x, 2) < a(rt):
        x += mu*dt + sigma*np.sqrt(dt)*normal(0, 1, mu.shape)
        rt += dt

    theta1 = np.arctan2(np.sqrt(x[2]**2 + x[1]**2), x[0])
    theta2 = np.arctan2(x[2], x[1])

    return rt+ndt, (theta1, theta2)

In [3]:
# Cache of Bessel-function zeros keyed by (order, count). The zeros depend only
# on ``nu`` and ``n``, so recomputing them with mpmath on every likelihood
# evaluation (i.e. on every optimizer step) is pure overhead.
_BESSEL_ZEROS_CACHE = {}


def _bessel_zeros(nu, n):
    """Return the first ``n`` positive zeros of J_nu as a read-only float array (cached)."""
    key = (float(nu), int(n))
    zeros = _BESSEL_ZEROS_CACHE.get(key)
    if zeros is None:
        zeros = np.asarray([float(besseljzero(nu, k)) for k in range(1, int(n) + 1)])
        zeros.setflags(write=False)  # shared across calls; guard against mutation
        _BESSEL_ZEROS_CACHE[key] = zeros
    return zeros


def series_bessel_fpt(t, a=1, sigma=1, nu=0, n=100):
    """Tabulate the truncated Bessel-series first-passage-time density and interpolate it.

    Evaluates, at every grid point ``t``,

        f(t) = sigma^2 / (2^nu a^2 Gamma(nu+1))
               * sum_{k=1..n} j_k^(nu+1) / J_{nu+1}(j_k) * exp(-j_k^2 sigma^2 t / (2 a^2)),

    where ``j_k`` is the k-th positive zero of J_nu. The HSDM likelihood uses
    this as the drift-free factor of the trial density (with nu = (d-2)/2).

    Parameters
    ----------
    t : 1-D array-like of float
        Grid of times at which the series is evaluated.
    a : float, default 1
        Boundary radius.
    sigma : float, default 1
        Noise scale.
    nu : float, default 0
        Bessel order.
    n : int, default 100
        Number of series terms.

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant of the tabulated density. With interp1d's defaults,
        querying outside ``[min(t), max(t)]`` raises ``ValueError``.

    NOTE(review): the truncated series is likely inaccurate (possibly
    negative) at very small ``t``; the likelihood only queries t > 0.001.
    """
    t = np.asarray(t, dtype=float)
    zeros = _bessel_zeros(nu, n)
    weights = zeros**(nu+1) / jv(nu+1, zeros)
    decay = (zeros**2 * sigma**2) / (2*a**2)
    # One (len(t), n) matrix of exponentials instead of a Python loop over t.
    series = np.exp(-np.outer(t, decay)) @ weights
    fpt = sigma**2/(2**nu * a**2 * gamma(nu + 1)) * series

    return interp1d(t, fpt)

In [4]:
def HSDM_3D_likelihood(prms, RT, Theta, N_series):
    """Negative log-likelihood of the 3-D HSDM over a set of trials.

    Parameters
    ----------
    prms : sequence of float, length 5
        ``[a, ndt, mu1, mu2, mu3]``: boundary radius, non-decision time and
        drift vector.
    RT : sequence of float
        Response times (decision time + ndt).
    Theta : sequence of (theta1, theta2)
        Exit angles: ``theta1`` measured from the first axis, ``theta2`` the
        angle in the plane of the second and third axes.
    N_series : int
        Number of terms passed to ``series_bessel_fpt``.

    Returns
    -------
    float
        Sum over trials of ``-log(density)``, where the per-trial density is
        ``exp(a * mu.u(theta) - |mu|^2 * (rt - ndt) / 2) * fpt(rt - ndt)`` with
        ``sigma = 1``. Densities that are not greater than ``0.1**14``
        (including NaN and negative values) are replaced by that floor.
        Trials with ``rt - ndt <= 0.001`` contribute ``-log(0.1**14)``.
        Returns 0.0 when there are no trials.

    NOTE(review): the density omits the sin(theta1)/(4*pi) factor of the
    uniform exit-angle distribution; presumably intentional because it does
    not depend on ``prms`` and so does not move the optimum.
    """
    floor = 0.1**14  # lower bound on per-trial density; keeps the log finite
    rt = np.asarray(RT, dtype=float)
    if rt.size == 0:
        return 0.0
    theta = np.asarray(Theta, dtype=float).reshape(-1, 2)

    a = prms[0]
    ndt = prms[1]
    mu = np.array([prms[2], prms[3], prms[4]])

    # Drift-free first-passage density tabulated on a grid covering every RT.
    tt = np.arange(0.001, rt.max()+0.02, 0.02)
    fpt = series_bessel_fpt(tt, a, sigma=1, nu=(mu.shape[0]-2)/2, n=N_series)

    decision_time = rt - ndt
    valid = decision_time > 0.001  # the tabulated grid starts at 0.001
    t_valid = decision_time[valid]
    theta1 = theta[valid, 0]
    theta2 = theta[valid, 1]

    # mu . u(theta): drift projected onto the unit exit direction.
    mu_dot_u = (mu[0]*np.cos(theta1)
                + mu[1]*np.sin(theta1)*np.cos(theta2)
                + mu[2]*np.sin(theta1)*np.sin(theta2))
    log_drift_factor = a*mu_dot_u - 0.5*np.linalg.norm(mu, 2)**2 * t_valid
    density = np.exp(log_drift_factor) * fpt(t_valid)
    # `>` is False for NaN, so NaN densities are floored just like the scalar check.
    density = np.where(density > floor, density, floor)

    n_invalid = rt.size - t_valid.size
    return float(-np.sum(np.log(density)) - n_invalid*np.log(floor))

In [5]:
# Containers for the recovery study: one list per (parameter, true/estimate) pair.
recovery_df = {'threshold_true': [],
               'threshold_estimate': [],
               'ndt_true': [],
               'ndt_estimate': [],
               'mu1_true': [],
               'mu1_estimate': [],
               'mu2_true': [],
               'mu2_estimate': [],
               'mu3_true': [],
               'mu3_estimate': []}

# Ranges for sampling the true parameters; the same ranges are used as the
# optimizer bounds in the recovery loop.
min_threshold = 0.5
max_threshold = 5

min_ndt = 0.1
max_ndt = 1

min_mu = -3.5
max_mu = 3.5

# Number of Bessel-series terms used for the first-passage-time density.
N_series = 150

# Seed the global NumPy RNG (used for the true parameters, the simulations and
# the random optimizer starts) so Restart & Run All reproduces the results.
SEED = 2024
np.random.seed(SEED)

In [6]:
# Parameter-recovery study: draw "true" parameters, simulate a dataset, refit it
# with Nelder-Mead from a random start, and record true vs. estimated values.
# Rows go into a fresh list, so re-running this cell rebuilds `recovery_df`
# instead of appending to a DataFrame left over from a previous run.
N_DATASETS = 150  # number of simulated datasets (rows of recovery_df)
N_TRIALS = 250    # trials per simulated dataset

fit_bounds = [(min_threshold, max_threshold), (min_ndt, max_ndt),
              (min_mu, max_mu), (min_mu, max_mu), (min_mu, max_mu)]

recovery_rows = []
for dataset_idx in tqdm(range(N_DATASETS)):
    threshold = np.random.uniform(min_threshold, max_threshold)
    a = lambda t, radius=threshold: radius  # constant, time-independent boundary
    ndt = np.random.uniform(min_ndt, max_ndt)
    mu = np.array([np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])

    trials = [simulate_HSDM_3D(a, mu, ndt) for _trial in range(N_TRIALS)]
    RT = [rt for rt, _theta in trials]
    Theta = [theta for _rt, theta in trials]

    # Random starting point drawn from the same ranges as the true parameters.
    x0 = np.array([np.random.uniform(min_threshold, max_threshold),
                   np.random.uniform(min_ndt, max_ndt),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])
    min_ans = minimize(HSDM_3D_likelihood,
                       args=(RT, Theta, N_series),
                       x0=x0,
                       method='Nelder-Mead',
                       bounds=fit_bounds)

    recovery_rows.append({'threshold_true': threshold,
                          'threshold_estimate': min_ans.x[0],
                          'ndt_true': ndt,
                          'ndt_estimate': min_ans.x[1],
                          'mu1_true': mu[0],
                          'mu1_estimate': min_ans.x[2],
                          'mu2_true': mu[1],
                          'mu2_estimate': min_ans.x[3],
                          'mu3_true': mu[2],
                          'mu3_estimate': min_ans.x[4]})

recovery_df = pd.DataFrame(recovery_rows)

100%|████████████████████████████| 150/150 [6:08:50<00:00, 147.53s/it]


In [7]:
# Display the true vs. estimated parameters for every simulated dataset.
recovery_df

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate
0,2.445474,2.532671,0.164011,0.155420,-3.386274,-3.499999,-1.492547,-1.449987,0.433016,0.361861
1,3.117165,3.140063,0.892896,0.869209,3.037887,2.987335,-1.696465,-1.776658,-2.986426,-2.967020
2,2.315522,2.180256,0.108553,0.127248,-2.682811,-2.546266,-1.198112,-1.136595,-3.059908,-2.941244
3,4.488260,4.567201,0.506245,0.571821,1.456814,1.623724,0.439113,0.354301,-1.691600,-1.850352
4,2.543233,2.503997,0.504439,0.522007,-2.570532,-2.483641,-0.409616,-0.375058,0.047079,0.026178
...,...,...,...,...,...,...,...,...,...,...
145,4.701030,4.977795,0.281930,0.199771,-0.876935,-0.865101,1.629700,1.666668,1.501565,1.582305
146,0.500648,2.543024,0.177193,0.789774,3.499408,-2.292772,-1.821980,-3.353737,-3.299560,-2.372305
147,2.244178,0.500000,0.277160,0.565481,-2.391563,-1.081067,2.879844,1.286292,-0.244812,-0.057112
148,0.699176,0.718306,0.936166,0.940462,2.762461,3.500000,1.671904,1.434693,-3.429415,-3.500000


In [8]:
# Pearson correlation matrix of all columns; the *_true vs *_estimate
# entries summarize how well each parameter is recovered.
recovery_df.corr(method='pearson')

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate
threshold_true,1.0,0.825599,0.077231,0.051817,-0.076031,-0.005053,-0.030338,-0.038908,0.071038,0.086135
threshold_estimate,0.825599,1.0,0.022727,-0.068654,-0.046264,0.052097,-0.089155,-0.113263,0.086701,0.096374
ndt_true,0.077231,0.022727,1.0,0.787879,0.109547,0.200786,-0.088911,-0.107601,-0.013698,-0.000628
ndt_estimate,0.051817,-0.068654,0.787879,1.0,0.04499,0.026945,0.042927,-0.001683,-0.050211,-0.037315
mu1_true,-0.076031,-0.046264,0.109547,0.04499,1.0,0.831298,-0.018903,0.013154,0.057104,0.05799
mu1_estimate,-0.005053,0.052097,0.200786,0.026945,0.831298,1.0,-0.07901,-0.059263,0.097119,0.086984
mu2_true,-0.030338,-0.089155,-0.088911,0.042927,-0.018903,-0.07901,1.0,0.921284,0.008301,0.040919
mu2_estimate,-0.038908,-0.113263,-0.107601,-0.001683,0.013154,-0.059263,0.921284,1.0,-0.00913,-0.002275
mu3_true,0.071038,0.086701,-0.013698,-0.050211,0.057104,0.097119,0.008301,-0.00913,1.0,0.942363
mu3_estimate,0.086135,0.096374,-0.000628,-0.037315,0.05799,0.086984,0.040919,-0.002275,0.942363,1.0


In [9]:
# Persist the recovery results; the file name records the number of series terms used.
recovery_csv_path = f'Series_3d_recovery_{N_series}.csv'
recovery_df.to_csv(recovery_csv_path)