In [1]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv, ive
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_4D(a, mu, ndt, sigma=1, dt=0.001):
    """Simulate one trial of a drift-diffusion walk until it leaves a boundary.

    The state starts at the origin (an array of zeros shaped like ``mu``) and
    is advanced with steps ``mu*dt + sigma*sqrt(dt)*N(0, 1)`` for as long as
    its Euclidean norm is strictly below ``a(elapsed_time)``.

    Parameters
    ----------
    a : callable
        Boundary radius as a function of elapsed time.
    mu : numpy.ndarray
        Drift vector; components 0..3 are read when computing the angles.
    ndt : float
        Offset added to the simulated stopping time in the returned value.
    sigma : float, optional
        Scale of the Gaussian increments.
    dt : float, optional
        Time step of the discretisation.

    Returns
    -------
    tuple
        ``(stopping_time + ndt, (theta1, theta2, theta3))`` where the three
        angles are obtained from the final state with ``arctan2``.
    """
    position = np.zeros(mu.shape)
    elapsed = 0

    # Both factors are constant across steps, so compute them once.
    drift_step = mu*dt
    noise_scale = sigma*np.sqrt(dt)

    while True:
        # Written as "not (norm < boundary)" so the exit condition is exactly
        # the negation of the continuation test.
        if not np.linalg.norm(position, 2) < a(elapsed):
            break
        position += drift_step + noise_scale*np.random.normal(0, 1, mu.shape)
        elapsed += dt

    # Partial sums of squares, accumulated from the last coordinate backwards.
    tail_sq_2 = position[3]**2 + position[2]**2
    tail_sq_1 = tail_sq_2 + position[1]**2

    angles = (
        np.arctan2(np.sqrt(tail_sq_1), position[0]),
        np.arctan2(np.sqrt(tail_sq_2), position[1]),
        np.arctan2(position[3], position[2]),
    )

    return elapsed+ndt, angles

In [3]:
def k(a, da, t, q, sigma=2):
    """Return ``0.5 * (q - 0.5*sigma - da(t))``.

    ``a`` is accepted for signature symmetry with ``psi`` but is not used.
    """
    return 0.5 * (q - 0.5*sigma - da(t))

def psi(a, da, t, z, tau, q, sigma=2):
    """Evaluate the kernel used by the integral-equation solver.

    Parameters
    ----------
    a, da : callable
        Boundary function and its time derivative, both evaluated at ``t``.
    t : float
        Current time; must be strictly greater than ``tau``.
    z : float
        State value at time ``tau`` (must be positive; it is a divisor).
    tau : float
        Earlier time.
    q : float
        Order parameter; the Bessel orders used are ``q/sigma - 1`` and
        ``q/sigma``. (Presumably the dimension of the process; confirm
        against the callers.)
    sigma : float, optional
        Scale parameter.

    Returns
    -------
    float
        ``term1 * term2 * (term3 * term4 + term5)`` as defined in the body.

    Notes
    -----
    The textbook form multiplies ``exp(-(a+z)/s)`` by ``iv(v, x)`` with
    ``s = sigma*(t - tau)`` and ``x = 2*sqrt(a*z)/s``. ``iv`` overflows for
    large ``x``, so the exponentially scaled ``ive(v, x) = iv(v, x)*exp(-x)``
    is used instead, with the leftover factor folded into the Gaussian term::

        exp(-(a+z)/s) * iv(v, x) == exp(-(sqrt(a) - sqrt(z))**2 / s) * ive(v, x)

    This identity holds for every ``a, z >= 0``. The previous large-argument
    branch evaluated ``ive`` at ``(a+z)/s`` and omitted the exponential, which
    agrees with the identity only when ``a(t) == z``.
    """
    kk = k(a, da, t, q, sigma)

    boundary = a(t)
    slope = da(t)
    elapsed = t - tau
    scale = sigma * elapsed
    root = np.sqrt(boundary * z)
    bessel_arg = 2 * root / scale
    order = q / sigma

    # Gaussian factor left over after moving exp(bessel_arg) into `ive`.
    term1 = np.exp(-(np.sqrt(boundary) - np.sqrt(z))**2 / scale) / scale
    term2 = (boundary / z)**(0.5 * (q - sigma) / sigma)
    term3 = slope - (boundary / elapsed) + kk
    term4 = ive(order - 1, bessel_arg)
    term5 = (root / elapsed) * ive(order, bessel_arg)

    return term1 * term2 * (term3 * term4 + term5)

def ie_bessel_fpt(a, da, q, z, sigma=2, dt=0.1, T_max=2):
    """Solve the discretised integral equation built from ``psi`` on a time grid.

    The grid is ``0, dt, 2*dt, ..., (int(T_max/dt) + 1)*dt``. The value at time
    0 is fixed to 0; every later node ``n`` is

        g[n] = -2*psi(t_n | z, 0) + 2*dt * sum_{j=1}^{n-1} g[j] * psi(t_n | a(t_j), t_j)

    Presumably ``g`` approximates a first-passage-time density (as the name
    suggests); confirm against the derivation this notebook follows.

    Parameters
    ----------
    a, da : callable
        Boundary function and its derivative, forwarded to ``psi``.
    q : float
        Order parameter forwarded to ``psi``.
    z : float
        Starting state forwarded to ``psi`` for the free term.
    sigma : float, optional
        Scale parameter forwarded to ``psi``.
    dt : float, optional
        Grid spacing.
    T_max : float, optional
        The grid extends one node past ``int(T_max/dt)*dt``.

    Returns
    -------
    scipy.interpolate.interp1d
        Interpolant of the node values over the grid (``interp1d`` defaults).
    """
    times = [0, dt]
    density = [0, -2*psi(a, da, dt, z, 0, q, sigma)]

    for n in range(2, int(T_max/dt)+2):
        t_n = n*dt
        total = -2 * psi(a, da, t_n, z, 0, q, sigma)

        # Nodes 1..n-1: node 0 is skipped, matching the lower summation limit.
        for tau_j, g_j in zip(times[1:], density[1:]):
            total += 2 * dt * g_j * psi(a, da, t_n, a(tau_j), tau_j, q, sigma)

        density.append(total)
        times.append(t_n)

    return interp1d(np.asarray(times), np.asarray(density))

In [4]:
def HSDM_4D_likelihood(prms, RT, Theta):
    """Objective function: summed ``-log(density)`` over all observations.

    Parameters
    ----------
    prms : sequence
        ``prms[0]`` is the threshold (squared before being passed to
        ``ie_bessel_fpt`` as the boundary), ``prms[1]`` is the offset
        subtracted from every entry of ``RT``, ``prms[2:6]`` are the four
        drift components.
    RT : sequence of float
        Observed times; ``max(RT)`` sets the solver horizon.
    Theta : sequence
        One triple of angles per observation, indexed in step with ``RT``.

    Returns
    -------
    float
        Sum of per-observation penalties. An observation contributes
        ``-log(density)`` when ``RT[i] - prms[1] > 0.001`` and the density
        exceeds ``0.1**14``; otherwise it contributes ``-log(0.1**14)``.
    """
    a = lambda t: prms[0]**2
    da = lambda t: 0
    ndt = prms[1]
    mu = np.array([prms[idx] for idx in range(2, 6)])

    fpt = ie_bessel_fpt(a, da, mu.shape[0], 0.0001,
                        dt=0.05, T_max=max(RT))

    # Penalty used whenever an observation has no usable density.
    floor_penalty = -np.log(0.1**14)
    # Independent of the observation, so evaluated once.
    half_sq_norm = 0.5 * np.linalg.norm(mu, 2)**2

    log_lik = 0
    for i, rt in enumerate(RT):
        theta = Theta[i]
        elapsed = rt - ndt

        if not elapsed > 0.001:
            log_lik += floor_penalty
            continue

        sin0, sin1 = np.sin(theta[0]), np.sin(theta[1])
        proj0 = mu[0]*np.cos(theta[0])
        proj1 = mu[1]*sin0*np.cos(theta[1])
        proj2 = mu[2]*sin0*sin1*np.cos(theta[2])
        proj3 = mu[3]*sin0*sin1*np.sin(theta[2])

        exponent = prms[0] * (proj0 + proj1 + proj2 + proj3) - half_sq_norm * elapsed
        density = np.exp(exponent) * fpt(elapsed)

        # Non-positive, NaN or tiny densities all fall back to the floor.
        log_lik += -np.log(density) if 0.1**14 < density else floor_penalty

    return log_lik

In [5]:
# Container for the recovery results: one empty list per column, in the order
# the columns should appear (true value followed by its estimate).
recovery_df = {column: [] for column in (
    'threshold_true', 'threshold_estimate',
    'ndt_true', 'ndt_estimate',
    'mu1_true', 'mu1_estimate',
    'mu2_true', 'mu2_estimate',
    'mu3_true', 'mu3_estimate',
    'mu4_true', 'mu4_estimate',
)}

# (lower, upper) limits for each parameter family.
min_threshold, max_threshold = 0.5, 6
min_ndt, max_ndt = 0.1, 1
min_mu, max_mu = -6, 6

In [6]:
# Parameter-recovery loop: draw a random parameter set, simulate 50 trials
# from it, fit the model with differential evolution, and record the true and
# estimated values side by side.
#
# Results are collected in a local list and turned into a DataFrame once at
# the end. The cell therefore no longer appends into whatever `recovery_df`
# currently is, so it can be re-run after `recovery_df` has become a DataFrame.
recovery_rows = []

for n in tqdm(range(50)):
    threshold = np.random.uniform(min_threshold, max_threshold)
    a = lambda t: threshold
    ndt = np.random.uniform(min_ndt, max_ndt)
    mu = np.array([np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])

    RT = []
    Theta = []

    for i in range(50):
        rt, theta = simulate_HSDM_4D(a, mu, ndt)
        RT.append(rt)
        Theta.append(theta)

    min_ans = differential_evolution(HSDM_4D_likelihood,
                                     args=(RT, Theta),
                                     bounds=[(min_threshold, max_threshold), (min_ndt, max_ndt),
                                             (min_mu, max_mu), (min_mu, max_mu),
                                             (min_mu, max_mu), (min_mu, max_mu)])

    # Key order fixes the column order of the resulting DataFrame.
    recovery_rows.append({
        'threshold_true': threshold,
        'threshold_estimate': min_ans.x[0],
        'ndt_true': ndt,
        'ndt_estimate': min_ans.x[1],
        'mu1_true': mu[0],
        'mu1_estimate': min_ans.x[2],
        'mu2_true': mu[1],
        'mu2_estimate': min_ans.x[3],
        'mu3_true': mu[2],
        'mu3_estimate': min_ans.x[4],
        'mu4_true': mu[3],
        'mu4_estimate': min_ans.x[5],
    })


recovery_df = pd.DataFrame(recovery_rows)

100%|███████████████████████████████████| 50/50 [15:58<00:00, 19.17s/it]


In [7]:
# Wrap in a DataFrame so `.corr()` is available even if `recovery_df` is still
# a dict of column lists; the contents are unchanged if it is already a DataFrame.
recovery_df = pd.DataFrame(recovery_df)
# Pairwise correlation matrix of all columns. The true/estimate entries of the
# same parameter are the ones that indicate recovery quality.
recovery_df.corr()

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
threshold_true,1.0,0.97777,0.109322,0.127765,0.127889,0.141045,0.016227,0.023341,0.055538,0.02438,0.276834,0.279753
threshold_estimate,0.97777,1.0,0.158712,0.145628,0.143732,0.154863,0.032842,0.04408,0.069996,0.039692,0.307801,0.313443
ndt_true,0.109322,0.158712,1.0,0.975564,-0.170074,-0.171245,0.188747,0.19999,0.08787,0.10204,-0.107335,-0.10311
ndt_estimate,0.127765,0.145628,0.975564,1.0,-0.177372,-0.176159,0.169533,0.172933,0.109551,0.121735,-0.134512,-0.129781
mu1_true,0.127889,0.143732,-0.170074,-0.177372,1.0,0.991888,-0.221103,-0.217659,0.011558,-0.021947,0.109342,0.10395
mu1_estimate,0.141045,0.154863,-0.171245,-0.176159,0.991888,1.0,-0.223446,-0.218515,-0.008342,-0.040407,0.129737,0.127654
mu2_true,0.016227,0.032842,0.188747,0.169533,-0.221103,-0.223446,1.0,0.991607,-0.074752,-0.060866,-0.240382,-0.246911
mu2_estimate,0.023341,0.04408,0.19999,0.172933,-0.217659,-0.218515,0.991607,1.0,-0.083931,-0.061613,-0.222746,-0.229906
mu3_true,0.055538,0.069996,0.08787,0.109551,0.011558,-0.008342,-0.074752,-0.083931,1.0,0.987977,0.012661,0.009505
mu3_estimate,0.02438,0.039692,0.10204,0.121735,-0.021947,-0.040407,-0.060866,-0.061613,0.987977,1.0,-0.021155,-0.024391


In [8]:
# Append this run's results to the CSV of earlier runs and write it back.
# NOTE: after this cell `recovery_df` holds the combined table, so running the
# cell a second time appends the combined table to the file again.
file_name = '_Recovery_data/IE_4d_recovery_50_05.csv'

# Make sure the target directory exists so the final write cannot fail on a
# fresh checkout.
os.makedirs(os.path.dirname(file_name), exist_ok=True)

if os.path.exists(file_name):
    old_recovery_data = pd.read_csv(file_name, index_col=0)
else:
    # First run: nothing to merge with, use an empty frame with the same columns.
    old_recovery_data = recovery_df.iloc[0:0]

recovery_df = pd.concat([old_recovery_data,
                         recovery_df]).reset_index(drop=True)
recovery_df.to_csv(file_name)

In [9]:
# Display `recovery_df` as left by the preceding cell (earlier runs read from
# the CSV followed by this run's rows).
recovery_df

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
0,0.955647,0.768212,0.479945,0.535619,5.439806,6.000000,1.269493,1.557971,-3.132506,-4.594671,-0.277591,-1.265591
1,2.233960,2.260804,0.545410,0.549004,-2.401942,-2.788096,2.840203,2.531524,-0.208332,-0.210184,-3.168785,-3.181227
2,2.300508,2.114477,0.163589,0.224445,1.126567,0.916940,5.702452,6.000000,-3.151843,-3.704167,1.795141,2.479317
3,5.977190,5.947861,0.254729,0.389449,-2.473553,-2.921790,4.753852,5.537138,0.505275,0.072380,-1.716602,-2.062735
4,3.160574,2.846658,0.293936,0.412698,3.100788,3.485988,5.062195,6.000000,-1.145540,-1.740464,-0.446529,-0.523940
...,...,...,...,...,...,...,...,...,...,...,...,...
295,0.537408,0.644490,0.153031,0.165690,-1.971534,-2.518332,-0.184871,0.035771,-5.385164,-6.000000,-1.205725,-1.116729
296,3.934655,4.507958,0.528554,0.516906,-3.680675,-3.954843,-5.661021,-6.000000,5.489436,6.000000,5.486504,6.000000
297,4.865870,4.814764,0.578389,0.710119,3.375253,4.065767,-4.170853,-4.728915,2.332859,2.867401,0.290630,0.485440
298,3.156013,3.608352,0.987429,0.952088,1.982967,1.915769,0.841563,0.728991,-5.666322,-5.625449,-5.150596,-5.063268
