In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv, ive
from scipy.optimize import minimize
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_4D(a, mu, ndt, sigma=1, dt=0.001):
    x = np.zeros(mu.shape)
    
    rt = 0
    
    while np.linalg.norm(x, 2) < a(rt):
        x += mu*dt + sigma*np.sqrt(dt)*np.random.normal(0, 1, mu.shape)
        rt += dt
    
    theta1 = np.arctan2(np.sqrt(x[3]**2 + x[2]**2 + x[1]**2), x[0])
    theta2 = np.arctan2(np.sqrt(x[3]**2 + x[2]**2), x[1])
    theta3 = np.arctan2(x[3], x[2])
    
    
    return rt+ndt, (theta1, theta2, theta3)

In [3]:
def k(a, da, t, q, sigma=2):
    return 0.5 * (q - 0.5*sigma - da(t))

def psi(a, da, t, z, tau, q, sigma=2):
    kk = k(a, da, t, q, sigma)
    
    if 2*np.sqrt(a(t)*z)/(sigma*(t-tau))<=700:
        term1 = 1./(sigma*(t - tau)) * np.exp(- (a(t) + z)/(sigma*(t-tau)))
        term2 = (a(t)/z)**(0.5*(q-sigma)/sigma)
        term3 = da(t) - (a(t)/(t-tau)) + kk
        term4 = iv(q/sigma-1, 2*np.sqrt(a(t)*z)/(sigma*(t-tau)))
        term5 = (np.sqrt(a(t)*z)/(t-tau)) * iv(q/sigma, 2*np.sqrt(a(t)*z)/(sigma*(t-tau)))
    else:
        term1 = 1./(sigma*(t - tau))
        term2 = (a(t)/z)**(0.5*(q-sigma)/sigma)
        term3 = da(t) - (a(t)/(t-tau)) + kk
        term4 = ive(q/sigma-1, (a(t) + z)/(sigma*(t-tau)))
        term5 = (np.sqrt(a(t)*z)/(t-tau)) * ive(q/sigma, (a(t) + z)/(sigma*(t-tau)))
    
    return term1 * term2 * (term3 * term4 + term5)

def ie_bessel_fpt(a, da, q, z, sigma=2, dt=0.1, T_max=2):
    g = [0]
    T = [0]
    g.append(-2*psi(a, da, dt, z, 0, q, sigma))
    T.append(dt)
    
    for n in range(2, int(T_max/dt)+2):
        s = -2 * psi(a, da, n*dt, z, 0, q, sigma)

        for j in range(1, n):
            if a(j*dt) == 0:
                continue
            
            s += 2 * dt * g[j] * psi(a, da, n*dt, a(j*dt), j*dt, q, sigma)

        g.append(s)
        T.append(n*dt)
        
    g = np.asarray(g)
    T = np.asarray(T)
    
    gt = interp1d(T, g)
    return gt

In [4]:
def HSDM_4D_likelihood(prms, RT, Theta):
    a = lambda t: prms[0] - prms[1]*t
    a2 = lambda t: (a(t))**2
    da2 = lambda t: -2*prms[1] * a(t)
    
    ndt = prms[2]
    mu = np.array([prms[3], prms[4], prms[5], prms[6]])
    
    T_max = min(max(RT), prms[0]/prms[1])
    fpt = ie_bessel_fpt(a2, da2, mu.shape[0], 0.000001, 
                        dt=0.02, T_max=T_max)
    
    log_lik = 0
    for i in range(len(RT)):
        rt, theta = RT[i], Theta[i]
        if rt - ndt > 0.0001 and rt - ndt < T_max:
            mu_dot_x0 = mu[0]*np.cos(theta[0])
            mu_dot_x1 = mu[1]*np.sin(theta[0])*np.cos(theta[1]) 
            mu_dot_x2 = mu[2]*np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2])
            mu_dot_x3 = mu[3]*np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) 
            term1 = a(rt - ndt) * (mu_dot_x0 + mu_dot_x1 + mu_dot_x2 + mu_dot_x3)
            term2 = 0.5 * np.linalg.norm(mu, 2)**2 * (rt-ndt)
            
            density = np.exp(term1 - term2) * fpt(rt - ndt)
            
            if 0.1**14 < density:
                log_lik += -np.log(density)
            else:
                log_lik += -np.log(0.1**14)
        else:
            log_lik += -np.log(0.1**14)
        
    return log_lik

In [5]:
recovery_df = {'b0_true': [],
               'b0_estimate': [],
               'lambda_true': [],
               'lambda_estimate': [],
               'ndt_true': [],
               'ndt_estimate': [],
               'mu1_true': [],
               'mu1_estimate': [],
               'mu2_true': [],
               'mu2_estimate': [],
               'mu3_true': [],
               'mu3_estimate': [],
               'mu4_true': [],
               'mu4_estimate': []}

min_b0 = 2
max_b0 = 5

min_lambda = .1
max_lambda = 2

min_ndt = 0.1
max_ndt = 1

min_mu = -3
max_mu = 3

In [6]:
for n in tqdm(range(100)):
    b0 = np.random.uniform(min_b0, max_b0)
    lamb = np.random.uniform(min_lambda, max_lambda)
    a = lambda t: b0 - lamb*t
    
    ndt = np.random.uniform(min_ndt, max_ndt)
    mu = np.array([np.random.uniform(min_mu, max_mu), 
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])
    
    RT = []
    Theta = []
    
    for i in range(500):
        rt, theta = simulate_HSDM_4D(a, mu, ndt)
        RT.append(rt)
        Theta.append(theta)
    
    min_ans = differential_evolution(HSDM_4D_likelihood,
                                     args=(RT, Theta), 
                                     bounds=[(min_b0, max_b0), (min_lambda, max_lambda),
                                             (min_ndt, max_ndt), (min_mu, max_mu), 
                                             (min_mu, max_mu), (min_mu, max_mu), (min_mu, max_mu)])
    
    min_ans = minimize(HSDM_4D_likelihood,
                       args=(RT, Theta),
                       method='nelder-mead',
                       x0=min_ans.x,
                       bounds=[(min_b0, max_b0), (min_lambda, max_lambda), (min_ndt, max_ndt),
                                (min_mu, max_mu), (min_mu, max_mu), (min_mu, max_mu), (min_mu, max_mu)])    
    
    if min_ans.success:
        recovery_df['b0_true'].append(b0)
        recovery_df['lambda_true'].append(lamb)
        recovery_df['ndt_true'].append(ndt)
        recovery_df['mu1_true'].append(mu[0])
        recovery_df['mu2_true'].append(mu[1])
        recovery_df['mu3_true'].append(mu[2])
        recovery_df['mu4_true'].append(mu[3])

        recovery_df['b0_estimate'].append(min_ans.x[0])
        recovery_df['lambda_estimate'].append(min_ans.x[1])
        recovery_df['ndt_estimate'].append(min_ans.x[2])
        recovery_df['mu1_estimate'].append(min_ans.x[3])
        recovery_df['mu2_estimate'].append(min_ans.x[4])
        recovery_df['mu3_estimate'].append(min_ans.x[5])
        recovery_df['mu4_estimate'].append(min_ans.x[6])

100%|██████████████████████████████████████| 100/100 [9:03:28<00:00, 326.08s/it]


In [7]:
recovery_df = pd.DataFrame(recovery_df)
recovery_df.corr()

Unnamed: 0,b0_true,b0_estimate,lambda_true,lambda_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
b0_true,1.0,0.884774,0.012622,-0.249036,0.049831,0.134694,-0.015502,-0.017164,0.201504,0.196411,-0.127076,-0.117661,0.062552,0.061739
b0_estimate,0.884774,1.0,-0.188778,-0.187494,0.050102,0.03629,0.047806,0.04268,0.178559,0.173288,-0.142916,-0.133991,0.062642,0.057324
lambda_true,0.012622,-0.188778,1.0,0.818067,0.119382,0.190065,-0.048918,-0.04744,0.041882,0.043982,-0.067154,-0.063081,-0.072726,-0.076806
lambda_estimate,-0.249036,-0.187494,0.818067,1.0,0.130448,0.091806,0.044119,0.042456,0.031055,0.033618,-0.093187,-0.091435,-0.089048,-0.100092
ndt_true,0.049831,0.050102,0.119382,0.130448,1.0,0.975218,0.001463,0.002797,0.135581,0.127444,0.089479,0.096123,-0.082797,-0.082441
ndt_estimate,0.134694,0.03629,0.190065,0.091806,0.975218,1.0,-0.02301,-0.020302,0.159623,0.151929,0.079904,0.086984,-0.073382,-0.071633
mu1_true,-0.015502,0.047806,-0.048918,0.044119,0.001463,-0.02301,1.0,0.998793,-0.037565,-0.045847,-0.051277,-0.045132,-0.053983,-0.056074
mu1_estimate,-0.017164,0.04268,-0.04744,0.042456,0.002797,-0.020302,0.998793,1.0,-0.044134,-0.052146,-0.042877,-0.037208,-0.056453,-0.058644
mu2_true,0.201504,0.178559,0.041882,0.031055,0.135581,0.159623,-0.037565,-0.044134,1.0,0.99914,-0.239961,-0.238092,-0.153328,-0.158095
mu2_estimate,0.196411,0.173288,0.043982,0.033618,0.127444,0.151929,-0.045847,-0.052146,0.99914,1.0,-0.237315,-0.236186,-0.148185,-0.152862


In [8]:
file_name = '_Recovery_data/IE_4d_recovery_linear_500_02.csv'
old_recovery_data = pd.read_csv(file_name, index_col=0)
recovery_df = pd.concat([old_recovery_data, 
                         recovery_df]).reset_index(drop=True)
recovery_df.to_csv(file_name)

In [9]:
recovery_df

Unnamed: 0,b0_true,b0_estimate,lambda_true,lambda_estimate,ndt_true,ndt_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
0,3.422841,3.769044,0.815469,1.251666,0.287800,0.256546,-0.585328,-0.676373,-2.832156,-2.761936,-0.131006,-0.070862,2.997387,2.882531
1,2.814087,2.486088,1.685831,1.367489,0.313983,0.359983,1.244327,1.212591,-1.198178,-1.385737,-1.514009,-1.588650,-1.580492,-1.602869
2,4.570897,4.355990,1.367959,1.152510,0.156177,0.201790,2.754176,2.824698,1.789119,1.780158,1.285610,1.473571,-1.221456,-1.274268
3,2.280701,2.354089,1.906957,2.000000,0.113419,0.104908,2.910419,2.823183,-1.507368,-1.404064,2.704223,2.705331,0.775741,0.820171
4,2.192476,2.128044,1.772249,1.908600,0.973431,0.991255,0.609228,0.651968,-0.743107,-0.678064,-2.327720,-2.287195,-2.159586,-2.232033
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
295,3.089591,2.938726,0.377076,0.324561,0.939712,1.000000,1.396021,1.479454,-0.135786,-0.080920,-0.380484,-0.482018,0.379330,0.380260
296,2.749646,2.819186,1.877150,1.919034,0.951353,0.953036,-2.188780,-2.179164,-1.887883,-1.872944,-0.146213,-0.179366,-1.580299,-1.527193
297,3.630904,4.097311,0.129542,0.521190,0.613756,0.577421,-2.601748,-2.594564,2.129972,2.127311,-1.497618,-1.416354,-2.405355,-2.347214
298,2.984006,2.842214,1.189734,0.930759,0.627716,0.640421,2.924702,2.999835,1.033432,1.011489,-2.496244,-2.553327,-0.365767,-0.288793
