In [1]:
import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv, ive
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_4D(a, mu, eta, ndt, sigma=1, dt=0.001):
    x = np.zeros(mu.shape)
    
    rt = 0
    
    mu_t = np.random.normal(mu, eta) 
    while np.linalg.norm(x, 2) < a(rt):
        x += mu_t*dt + sigma*np.sqrt(dt)*np.random.normal(0, 1, mu.shape)
        rt += dt
    
    theta1 = np.arctan2(np.sqrt(x[3]**2 + x[2]**2 + x[1]**2), x[0])
    theta2 = np.arctan2(np.sqrt(x[3]**2 + x[2]**2), x[1])
    theta3 = np.arctan2(x[3], x[2])
    
    
    return rt+ndt, (theta1, theta2, theta3)

In [3]:
def k(a, da, t, q, sigma=2):
    """Drift-adjustment term k(t) = (q - sigma/2 - a'(t)) / 2 used by `psi`.

    Parameters
    ----------
    a : callable
        Boundary a(t). Unused here; kept so `k` and `psi` share a signature.
    da : callable
        Time derivative of the boundary, a'(t).
    t : float
        Time at which the term is evaluated.
    q : float
        Number of dimensions (the likelihood passes mu.shape[0]).
    sigma : float, optional
        Scale constant of the process (default 2).
    """
    return 0.5 * (q - 0.5*sigma - da(t))

def psi(a, da, t, z, tau, q, sigma=2):
    """Kernel psi(a(t), t | z, tau) of the first-passage integral equation.

    With s = sigma*(t - tau), nu = q/sigma, x = 2*sqrt(a(t)*z)/s this returns

        1/s * exp(-(a+z)/s) * (a/z)**((q - sigma)/(2*sigma))
            * [(a'(t) - a/(t - tau) + k) * I_{nu-1}(x)
               + sqrt(a*z)/(t - tau) * I_nu(x)]

    where a = a(t) and I is the modified Bessel function of the first kind.
    The kernel appears to be built from a squared-Bessel-type transition
    density (Bessel order q/sigma - 1) -- NOTE(review): confirm against the
    reference derivation.

    Parameters
    ----------
    a, da : callable
        Boundary a(t) and its time derivative.
    t : float
        Current time; must be greater than `tau`.
    z : float
        Starting value at time `tau`; must be positive (it appears as a/z).
    tau : float
        Starting time.
    q : float
        Number of dimensions.
    sigma : float, optional
        Scale constant (default 2).

    Returns
    -------
    float
        Value of the kernel.
    """
    kk = k(a, da, t, q, sigma)
    a_t = a(t)
    lag = t - tau
    s = sigma * lag
    nu = q / sigma

    # Argument of both modified Bessel functions.
    x = 2*np.sqrt(a_t*z)/s

    # exp(-(a+z)/s) * iv(v, x) == exp(-(sqrt(a) - sqrt(z))**2 / s) * ive(v, x),
    # because ive(v, x) = iv(v, x) * exp(-x) for x >= 0. The right-hand form
    # neither overflows (iv grows like e**x) nor underflows (exp(-(a+z)/s))
    # for large x. The Bessel argument must stay x; evaluating ive at (a+z)/s
    # is only correct when a(t) == z.
    gauss = np.exp(-(np.sqrt(a_t) - np.sqrt(z))**2 / s)

    term1 = gauss / s
    term2 = (a_t/z)**(0.5*(q - sigma)/sigma)
    term3 = da(t) - a_t/lag + kk
    term4 = ive(nu - 1, x)
    term5 = (np.sqrt(a_t*z)/lag) * ive(nu, x)

    return term1 * term2 * (term3 * term4 + term5)

def ie_bessel_fpt(a, da, q, z, sigma=2, dt=0.1, T_max=2):
    """First-passage-time density through boundary a(t), from an integral equation.

    Solves

        g(t) = -2 psi(a(t), t | z, 0)
               + 2 * integral_0^t g(tau) psi(a(t), t | a(tau), tau) dtau

    on the grid t_n = n*dt, n = 0 .. int(T_max/dt) + 1, with g(0) = 0. The
    integral is approximated by a rectangle rule over the previously computed
    points tau_j = j*dt, j = 1 .. n-1.

    Parameters
    ----------
    a, da : callable
        Boundary a(t) and its time derivative, passed through to `psi`.
    q : float
        Number of dimensions, passed through to `psi`.
    z : float
        Starting value of the process. It must be positive because `psi`
        divides by it; the likelihood uses a small positive value.
    sigma : float, optional
        Scale constant passed through to `psi` (default 2).
    dt : float, optional
        Grid spacing (default 0.1).
    T_max : float, optional
        The grid extends to at least T_max (default 2).

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant of g over the grid. With interp1d's default
        settings it raises ValueError outside [0, last grid point].
    """
    last = int(T_max/dt) + 1
    density = [0]

    for n in range(1, last + 1):
        t_n = n*dt
        # Source term from the starting point z at time 0.
        acc = -2 * psi(a, da, t_n, z, 0, q, sigma)
        # Rectangle-rule contribution of the density already computed.
        for j in range(1, n):
            tau_j = j*dt
            acc += 2 * dt * density[j] * psi(a, da, t_n, a(tau_j), tau_j, q, sigma)
        density.append(acc)

    grid = np.arange(last + 1) * dt
    return interp1d(grid, np.asarray(density))

In [4]:
def HSDM_4D_likelihood(prms, RT, Theta):
    """Negative log-likelihood of the 4-D hyperspherical diffusion model.

    Despite the name, the value returned is a *negative* log-likelihood, so it
    can be handed directly to a minimiser such as differential_evolution.

    Parameters
    ----------
    prms : sequence of 7 floats
        (threshold, ndt, eta, mu1, mu2, mu3, mu4). threshold**2 is used as the
        constant boundary for the first-passage solver.
    RT : sequence of float
        Response times (decision time plus ndt).
    Theta : sequence of 3-tuples
        Exit angles (theta1, theta2, theta3) per trial, in the same convention
        as simulate_HSDM_4D.

    Returns
    -------
    float
        Sum over trials of -log(density). A density that is not above
        0.1**14 (including NaN) contributes -log(0.1**14) instead. A trial
        with rt - ndt <= 0.001 contributes the same penalty.
    """
    radius = prms[0]
    ndt = prms[1]
    eta2 = prms[2]**2
    mu = np.array([prms[3], prms[4], prms[5], prms[6]])

    boundary = lambda t: radius**2   # constant boundary on the squared radius
    slope = lambda t: 0              # its time derivative
    floor = 0.1**14                  # smallest per-trial density accepted

    # First-passage-time density on a grid reaching max(RT). Presumably this
    # is the zero-drift density; the drift enters via the per-coordinate
    # factors below -- NOTE(review): confirm.
    fpt = ie_bessel_fpt(boundary, slope, mu.shape[0], 0.000001,
                        dt=0.01, T_max=max(RT))

    total = 0
    for idx, rt in enumerate(RT):
        theta = Theta[idx]
        decision_time = rt - ndt
        if not decision_time > 0.001:
            # Decision time too short (or non-positive): fixed penalty.
            total += -np.log(floor)
            continue

        # Cartesian exit point on the sphere of radius `radius`.
        r_sin0 = radius*np.sin(theta[0])
        r_sin01 = r_sin0*np.sin(theta[1])
        exit_point = (radius*np.cos(theta[0]),
                      r_sin0*np.cos(theta[1]),
                      r_sin01*np.cos(theta[2]),
                      r_sin01*np.sin(theta[2]))

        # One factor per coordinate. Presumably this comes from integrating
        # the drift over N(mu_i, eta**2) -- NOTE(review): confirm derivation.
        spread = eta2*decision_time + 1
        fixed = 1/(np.sqrt(spread))
        density = 1
        for coord, m in zip(exit_point, mu):
            exponent = -0.5*m**2/eta2 + 0.5*(coord*eta2 + m)**2 / (eta2*spread)
            density = density * (fixed*np.exp(exponent))
        density = density * fpt(decision_time)

        total += -np.log(density) if floor < density else -np.log(floor)

    return total

In [5]:
# Recovery-study configuration.
# `recovery_df` starts as a column -> list mapping. Its key order fixes the
# column order of the results table built by the recovery loop.
recovery_df = {'threshold_true': [],
               'threshold_estimate': [],
               'ndt_true': [],
               'ndt_estimate': [],
               'eta_true': [],
               'eta_estimate':[],
               'mu1_true': [],
               'mu1_estimate': [],
               'mu2_true': [],
               'mu2_estimate': [],
               'mu3_true': [],
               'mu3_estimate': [],
               'mu4_true': [],
               'mu4_estimate': []}

# Ranges from which the true parameters are drawn. The recovery loop also
# uses them as the differential_evolution search bounds.
min_threshold = 0.5
max_threshold = 3

min_ndt = 0.1
max_ndt = 1

min_eta = 0.1
max_eta = 1

min_mu = -3
max_mu = 3

# Seed for NumPy's global RNG, which drives the parameter draws and the
# simulations. Assumption to verify for your SciPy version: differential_evolution
# also falls back to this global RNG when it is given no seed.
# Keep None (fresh entropy every run) when pooling results across sessions:
# a fixed seed would regenerate identical datasets. Set an int for an
# exactly reproducible run.
RANDOM_SEED = None
np.random.seed(RANDOM_SEED)

In [6]:
# Parameter recovery. For each of N_DATASETS datasets:
#   1. draw true parameters uniformly from the configured ranges;
#   2. simulate N_TRIALS trials;
#   3. refit all seven parameters with differential evolution;
#   4. record the true and estimated values.
# Results go into a fresh dict keyed by the columns of `recovery_df`. That
# object is either the dict from the configuration cell or the DataFrame left
# by an earlier run of this cell. So re-running this cell replaces the
# previous results instead of failing on DataFrame columns.
N_DATASETS = 5
N_TRIALS = 500
param_names = ['threshold', 'ndt', 'eta', 'mu1', 'mu2', 'mu3', 'mu4']
# Search bounds, in the parameter order HSDM_4D_likelihood expects.
fit_bounds = ([(min_threshold, max_threshold), (min_ndt, max_ndt), (min_eta, max_eta)]
              + [(min_mu, max_mu)] * 4)

recovery_results = {column: [] for column in recovery_df}

for n in tqdm(range(N_DATASETS)):
    # Draws happen in a fixed order (threshold, ndt, eta, mu1..mu4), so a
    # seeded RNG reproduces the same datasets.
    threshold = np.random.uniform(min_threshold, max_threshold)
    a = lambda t: threshold  # constant boundary; used within this iteration only
    ndt = np.random.uniform(min_ndt, max_ndt)
    eta = np.random.uniform(min_eta, max_eta)
    mu = np.array([np.random.uniform(min_mu, max_mu) for _ in range(4)])

    RT = []
    Theta = []
    for i in range(N_TRIALS):
        rt, theta = simulate_HSDM_4D(a, mu, eta, ndt)
        RT.append(rt)
        Theta.append(theta)

    min_ans = differential_evolution(HSDM_4D_likelihood,
                                     args=(RT, Theta),
                                     bounds=fit_bounds)

    true_values = [threshold, ndt, eta, *mu]
    for name, true_value, estimate in zip(param_names, true_values, min_ans.x):
        recovery_results[f'{name}_true'].append(true_value)
        recovery_results[f'{name}_estimate'].append(estimate)

recovery_df = pd.DataFrame(recovery_results)

100%|████████████████████████████████████████████| 5/5 [23:45<00:00, 285.10s/it]


In [7]:
# True vs. estimated parameters side by side, one row per simulated dataset.
recovery_df

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,eta_true,eta_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
0,1.293249,1.327734,0.92854,0.926192,0.208775,0.1,-1.097418,-1.091312,1.358304,1.374346,2.286987,2.268142,2.967237,2.880509
1,2.145888,1.973983,0.859941,0.899944,0.59896,0.215803,0.030441,0.078386,-2.099103,-2.078556,-1.820041,-1.806755,1.915581,1.856958
2,1.722134,1.711204,0.614894,0.609394,0.29841,0.1,1.550843,1.405455,2.435683,2.32185,-2.688826,-2.620245,-0.100464,-0.120283
3,1.35153,1.321854,0.103183,0.116307,0.219748,0.1,1.654393,1.600805,1.872562,1.984876,-2.323047,-2.3481,-0.072341,-0.100474
4,2.939803,2.541756,0.953311,0.984084,0.901557,0.429686,2.460695,2.193653,-2.676407,-2.338371,0.037241,-0.000238,-1.152377,-0.974255


In [8]:
# Save the recovery table to CSV.
# To pool this run with results saved by earlier sessions (for example, to get
# more datasets for the correlation analysis), set APPEND_TO_EXISTING = True.
# The CSV must already exist in that case. Run the cell only once per session
# when appending: re-running it would concatenate the same rows again.
file_name = 'IE_4d_recovery_dvar.csv'
APPEND_TO_EXISTING = False
if APPEND_TO_EXISTING:
    old_recovery_data = pd.read_csv(file_name, index_col=0)
    recovery_df = pd.concat([old_recovery_data,
                             recovery_df]).reset_index(drop=True)
recovery_df.to_csv(file_name)

In [9]:
# Pearson correlations between all columns. The (x_true, x_estimate) entries
# show how well each parameter is recovered.
recovery_df.corr(method='pearson')

Unnamed: 0,threshold_true,threshold_estimate,ndt_true,ndt_estimate,eta_true,eta_estimate,mu1_true,mu1_estimate,mu2_true,mu2_estimate,mu3_true,mu3_estimate,mu4_true,mu4_estimate
threshold_true,1.0,0.996739,0.524773,0.555515,0.989047,0.963834,0.539312,0.520425,-0.862315,-0.853104,-0.015463,-0.019732,-0.534376,-0.505978
threshold_estimate,0.996739,1.0,0.552321,0.580819,0.978144,0.943092,0.52682,0.50678,-0.841362,-0.835367,-0.031079,-0.033998,-0.523671,-0.495964
ndt_true,0.524773,0.552321,1.0,0.998554,0.54679,0.517642,-0.34703,-0.38317,-0.596922,-0.61337,0.619192,0.625336,0.323713,0.355443
ndt_estimate,0.555515,0.580819,0.998554,1.0,0.581541,0.54856,-0.330827,-0.365551,-0.638037,-0.654289,0.601862,0.607386,0.310508,0.342771
eta_true,0.989047,0.978144,0.54679,0.581541,1.0,0.975884,0.457624,0.441282,-0.925228,-0.915967,0.046805,0.041214,-0.447812,-0.417617
eta_estimate,0.963834,0.943092,0.517642,0.54856,0.975884,1.0,0.493533,0.471607,-0.881003,-0.861019,0.171085,0.163774,-0.492603,-0.459824
mu1_true,0.539312,0.52682,-0.34703,-0.330827,0.457624,0.493533,1.0,0.998577,-0.153069,-0.123334,-0.536258,-0.542912,-0.997673,-0.997668
mu1_estimate,0.520425,0.50678,-0.38317,-0.365551,0.441282,0.471607,0.998577,1.0,-0.14582,-0.11685,-0.572202,-0.57913,-0.993032,-0.994456
mu2_true,-0.862315,-0.841362,-0.596922,-0.638037,-0.925228,-0.881003,-0.153069,-0.14582,1.0,0.998471,-0.165892,-0.159209,0.129935,0.099184
mu2_estimate,-0.853104,-0.835367,-0.61337,-0.654289,-0.915967,-0.861019,-0.123334,-0.11685,0.998471,1.0,-0.152838,-0.146856,0.09913,0.069004
