In [1]:
import os

import numpy as np
import pandas as pd
from tqdm import tqdm

from mpmath import besseljzero
from scipy.special import gamma
from scipy.special import jv, iv, ive
from scipy.optimize import differential_evolution
from scipy.interpolate import interp1d

from scipy.stats import pearsonr
from sklearn.metrics import r2_score

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
def simulate_HSDM_4D(a, mu, eta, ndt, sigma=1, dt=0.001):
    """Simulate one trial of the hyperspherical diffusion model by Euler steps.

    A drift vector is drawn once per trial from N(mu, eta) (element-wise), then
    a random walk starting at the origin is advanced in steps of ``dt`` until
    its Euclidean norm reaches ``a(elapsed_time)``.

    Parameters
    ----------
    a : callable
        Boundary radius as a function of elapsed decision time.
    mu : numpy.ndarray
        Mean drift vector; its shape sets the walk's dimension (only the first
        four coordinates are used for the returned angles).
    eta : float
        Standard deviation of the per-trial drift draw.
    ndt : float
        Non-decision time added to the simulated decision time.
    sigma : float
        Diffusion coefficient of the per-step Gaussian noise.
    dt : float
        Euler step size.

    Returns
    -------
    tuple
        ``(decision_time + ndt, (theta1, theta2, theta3))`` where the angles are
        the hyperspherical coordinates of the hitting point: theta1 is measured
        from coordinate 0, theta2 from coordinate 1, theta3 = atan2(x3, x2).
    """
    drift = np.random.normal(mu, eta)       # one drift draw per trial
    position = np.zeros(mu.shape)
    elapsed = 0

    step_drift = drift*dt
    step_noise_sd = sigma*np.sqrt(dt)
    while np.linalg.norm(position, 2) < a(elapsed):
        position += step_drift + step_noise_sd*np.random.normal(0, 1, mu.shape)
        elapsed += dt

    # Squared-norm of the trailing coordinates, accumulated in the same order
    # as the nested radii below.
    tail = position[3]**2 + position[2]**2
    polar_1 = np.arctan2(np.sqrt(tail + position[1]**2), position[0])
    polar_2 = np.arctan2(np.sqrt(tail), position[1])
    azimuth = np.arctan2(position[3], position[2])

    return elapsed + ndt, (polar_1, polar_2, azimuth)

In [3]:
def k(a, da, t, q, sigma=2):
    """Auxiliary coefficient used inside ``psi``: ``(q - sigma/2 - da(t)) / 2``.

    ``a`` is not used; it is accepted so the signature parallels ``psi``'s.
    ``da`` is the time derivative of the boundary, evaluated at ``t``.
    """
    half_sigma = sigma / 2
    return (q - half_sigma - da(t)) / 2

def psi(a, da, t, z, tau, q, sigma=2):
    """Kernel term used by ``ie_bessel_fpt``'s recursion.

    With ``s = t - tau``, ``w = 2*sqrt(a(t)*z)/(sigma*s)`` and
    ``u = (a(t) + z)/(sigma*s)`` this returns

        exp(-u)/(sigma*s) * (a(t)/z)**((q - sigma)/(2*sigma))
        * [ (da(t) - a(t)/s + k) * I_{q/sigma-1}(w) + sqrt(a(t)*z)/s * I_{q/sigma}(w) ]

    where ``k = k(a, da, t, q, sigma)`` and ``I_v`` is the modified Bessel
    function of the first kind.

    Parameters
    ----------
    a, da : callable
        Boundary and its time derivative (``HSDM_4D_likelihood`` passes the
        squared threshold radius and 0).
    t : float
        Time at which the boundary is evaluated.
    z : float or numpy.ndarray
        Starting value (must be > 0: it appears in a denominator).
    tau : float or numpy.ndarray
        Starting time; requires ``t > tau``.
    q : float
        Index parameter (``HSDM_4D_likelihood`` passes the number of drift
        components).
    sigma : float
        Scale parameter, default 2.

    Returns
    -------
    float or numpy.ndarray
        The kernel value; broadcasts over array ``z``/``tau``.
    """
    kk = k(a, da, t, q, sigma)
    a_t = a(t)
    elapsed = t - tau
    scale = sigma * elapsed
    root = np.sqrt(a_t * z)

    w = 2 * root / scale       # Bessel argument
    u = (a_t + z) / scale      # exponent of the Gaussian-like prefactor

    # exp(-u) * I_v(w) == exp(w - u) * ive(v, w) exactly (ive(v, w) = I_v(w)e^{-w}
    # for w >= 0).  Since u >= w (AM-GM), exp(w - u) <= 1, so this form never
    # overflows and needs no large-argument branch.  (Evaluating the Bessel
    # functions at u instead of w is only valid in the special case z == a(t).)
    term1 = np.exp(w - u) / scale
    term2 = (a_t / z) ** (0.5 * (q - sigma) / sigma)
    term3 = da(t) - a_t / elapsed + kk
    term4 = ive(q / sigma - 1, w)
    term5 = (root / elapsed) * ive(q / sigma, w)

    return term1 * term2 * (term3 * term4 + term5)

def ie_bessel_fpt(a, da, q, z, sigma=2, dt=0.1, T_max=2):
    """Tabulate the recursion built on ``psi`` and return a linear interpolant.

    On the grid ``t_n = n*dt`` for ``n = 0 .. N`` with
    ``N = max(int(T_max/dt) + 1, 1)``:

        g_0 = 0
        g_n = -2*psi(t_n | z, 0) + 2*dt * sum_{j=1}^{n-1} g_j * psi(t_n | a(t_j), t_j)

    ``HSDM_4D_likelihood`` uses the result as the first-passage-time density
    (presumably a discretised integral equation for the boundary ``a``; confirm
    against the source derivation).  The cost is O(N^2) ``psi`` calls.

    Returns
    -------
    scipy.interpolate.interp1d
        Linear interpolant over ``[0, N*dt]``; querying outside that range
        raises ``ValueError`` (interp1d's default bounds checking).
    """
    last_step = max(int(T_max/dt) + 1, 1)
    grid = [0]
    values = [0]

    for n in range(1, last_step + 1):
        t_n = n*dt
        acc = -2 * psi(a, da, t_n, z, 0, q, sigma)
        for j in range(1, n):
            acc += 2 * dt * values[j] * psi(a, da, t_n, a(j*dt), j*dt, q, sigma)
        values.append(acc)
        grid.append(t_n)

    return interp1d(np.asarray(grid), np.asarray(values))

In [4]:
def HSDM_4D_likelihood(prms, RT, Theta, *, split_time=4.5, fine_dt=0.02,
                       coarse_dt=0.1, z0=0.000001, floor=0.1**14):
    """Negative log-likelihood of the 4-D HSDM with across-trial drift variability.

    Parameters
    ----------
    prms : sequence of 7 floats
        ``[threshold, ndt, eta, mu_0, mu_1, mu_2, mu_3]``.  ``threshold`` is the
        boundary radius (``ie_bessel_fpt`` receives its square); ``eta`` is the
        standard deviation of the per-trial drift around ``mu``.
    RT : sequence of float
        Response times (decision time + ``ndt``).
    Theta : sequence of 3-sequences
        Hitting angles ``(theta1, theta2, theta3)`` per trial, in the convention
        returned by ``simulate_HSDM_4D``.
    split_time, fine_dt, coarse_dt : float, keyword-only
        Decision times ``<= split_time`` are scored with a first-passage density
        tabulated on a ``fine_dt`` grid; longer ones with a ``coarse_dt`` grid,
        which bounds the O(N^2) cost of ``ie_bessel_fpt`` on long trials.
    z0 : float, keyword-only
        Starting value passed to ``ie_bessel_fpt`` (``psi`` divides by it).
    floor : float, keyword-only
        Lower clamp on each trial's density before the log.  Trials whose
        decision time is ``<= 0.001`` contribute ``-log(floor)`` directly.

    Returns
    -------
    float
        Sum over trials of ``-log(density)``; 0.0 for an empty data set.
    """
    radius = prms[0]
    ndt = prms[1]
    eta2 = prms[2]**2
    mu = np.array([prms[3], prms[4], prms[5], prms[6]])
    a = lambda t: radius**2   # squared-radius boundary for the FPT solver
    da = lambda t: 0          # boundary is constant in time

    rts = np.asarray(RT, dtype=float)
    if rts.size == 0:
        return 0.0
    angles = np.asarray(Theta, dtype=float).reshape(-1, 3)

    t_all = rts - ndt
    valid = t_all > 0.001
    t = t_all[valid]

    # First-passage density at each valid decision time.  The fine grid keeps
    # the original extent min(max(RT), split_time); the coarse grid (extent
    # max(RT)) is only solved when some decision time actually needs it.
    rt_max = rts.max()
    fine = t <= split_time
    fpt_vals = np.empty_like(t)
    if fine.any():
        fpt = ie_bessel_fpt(a, da, mu.shape[0], z0,
                            dt=fine_dt, T_max=min(rt_max, split_time))
        fpt_vals[fine] = fpt(t[fine])
    if (~fine).any():
        fpt_l = ie_bessel_fpt(a, da, mu.shape[0], z0,
                              dt=coarse_dt, T_max=rt_max)
        fpt_vals[~fine] = fpt_l(t[~fine])

    # Cartesian hitting point; column d pairs with mu[d].  theta1 is measured
    # from the axis paired with mu[3], so the column order is reversed relative
    # to the angle nesting.
    th1, th2, th3 = angles[valid].T
    sin1, sin2 = np.sin(th1), np.sin(th2)
    coords = radius * np.column_stack([
        sin1 * sin2 * np.sin(th3),
        sin1 * sin2 * np.cos(th3),
        sin1 * np.cos(th2),
        np.cos(th1),
    ])

    # Per-dimension factor 1/sqrt(eta^2 t + 1) * exp(...): this appears to be
    # the closed-form average over a Gaussian per-trial drift N(mu, eta^2).
    var_factor = (eta2 * t + 1)[:, None]
    exponents = (-0.5 * mu**2 / eta2
                 + 0.5 * (coords * eta2 + mu)**2 / (eta2 * var_factor))
    per_dim = np.exp(exponents) / np.sqrt(var_factor)
    density = np.prod(per_dim, axis=1) * fpt_vals

    # NaN or non-positive densities fall back to the floor (comparison is False).
    clipped = np.where(density > floor, density, floor)
    n_invalid = rts.size - t.size
    return float(-np.log(clipped).sum() - n_invalid * np.log(floor))

In [5]:
# Sampling ranges for the generating parameters; the recovery loop also uses
# them as the differential_evolution search bounds.
min_threshold, max_threshold = 1, 3
min_ndt, max_ndt = 0.1, 1
min_eta, max_eta = 0.1, 1
min_mu, max_mu = -2.5, 2.5

# CSV that receives the true vs. estimated parameters.
file_name = '_Recovery_data/IE_4d_recovery_dvar.csv'

In [6]:
# Parameter-recovery loop: draw generating parameters, simulate a data set,
# refit it with differential evolution and store true vs. estimated values.
n_datasets = 1              # synthetic data sets to simulate and fit
n_trials = 500              # trials per data set
append_to_existing = False  # True: keep rows already stored in `file_name`

out_dir = os.path.dirname(file_name)
if out_dir:
    os.makedirs(out_dir, exist_ok=True)
previous_results = (pd.read_csv(file_name, index_col=0)
                    if append_to_existing and os.path.exists(file_name) else None)
recovery_rows = []

for n in tqdm(range(n_datasets)):
    threshold = np.random.uniform(min_threshold, max_threshold)
    a = lambda t: threshold
    ndt = np.random.uniform(min_ndt, max_ndt)
    eta = np.random.uniform(min_eta, max_eta)
    mu = np.array([np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu),
                   np.random.uniform(min_mu, max_mu)])

    RT = []
    Theta = []
    for i in range(n_trials):
        rt, theta = simulate_HSDM_4D(a, mu, eta, ndt)
        RT.append(rt)
        Theta.append(theta)

    min_ans = differential_evolution(HSDM_4D_likelihood,
                                     args=(RT, Theta),
                                     bounds=[(min_threshold, max_threshold),
                                             (min_ndt, max_ndt), (min_eta, max_eta),
                                             (min_mu, max_mu), (min_mu, max_mu),
                                             (min_mu, max_mu), (min_mu, max_mu)])

    if min_ans.success:
        # The likelihood's prms[3..6] correspond to the simulator's
        # mu[3], mu[2], mu[1], mu[0] (its angle convention reverses the axis
        # order), hence the reversed indexing of the true values.
        recovery_rows.append({'threshold_true': threshold,
                              'threshold_estimate': min_ans.x[0],
                              'ndt_true': ndt,
                              'ndt_estimate': min_ans.x[1],
                              'eta_true': eta,
                              'eta_estimate': min_ans.x[2],
                              'mu1_true': mu[3],
                              'mu1_estimate': min_ans.x[3],
                              'mu2_true': mu[2],
                              'mu2_estimate': min_ans.x[4],
                              'mu3_true': mu[1],
                              'mu3_estimate': min_ans.x[5],
                              'mu4_true': mu[0],
                              'mu4_estimate': min_ans.x[6]})

        # tqdm.write keeps the progress bar intact (plain print splits it).
        tqdm.write(f'{eta} {min_ans.x[2]}')

        # Checkpoint: rewrite the file with every row collected so far, so an
        # interrupted run keeps its earlier results instead of only the last one.
        recovery_df = pd.DataFrame(recovery_rows)
        if previous_results is not None:
            recovery_df = pd.concat([previous_results, recovery_df]).reset_index(drop=True)
        recovery_df.to_csv(file_name)

  2%|▋                               | 1/50 [01:20<1:05:26, 80.12s/it]

0.8767882691946569 1.0


  4%|█▏                             | 2/50 [09:22<4:13:13, 316.53s/it]

0.3894857874182703 0.42151543473429076


  6%|█▊                             | 3/50 [10:27<2:38:07, 201.86s/it]

0.38246501373795744 0.3754221912453365


  8%|██▍                            | 4/50 [12:02<2:02:19, 159.56s/it]

0.3747188893506799 0.3203287958374855


 10%|███                            | 5/50 [20:35<3:35:26, 287.27s/it]

0.681570231712616 0.7002307254577741


 12%|███▋                           | 6/50 [21:59<2:39:46, 217.88s/it]

0.34071750655801397 0.1


 14%|████▎                          | 7/50 [25:31<2:34:58, 216.24s/it]

0.5627800385293538 0.784889370415139


 16%|████▉                          | 8/50 [28:09<2:18:21, 197.65s/it]

0.8901843466064353 0.908647782781985


 18%|█████▌                         | 9/50 [33:20<2:39:07, 232.87s/it]

0.7561513830102243 0.8415114626513712


 20%|██████                        | 10/50 [37:56<2:44:12, 246.32s/it]

0.6573785090160121 0.3014780732258877


 22%|██████▌                       | 11/50 [40:56<2:26:54, 226.01s/it]

0.4962952715240411 0.16971397214221434


 24%|███████▏                      | 12/50 [45:16<2:29:37, 236.24s/it]

0.9876677978239199 0.9893767702577952


 26%|███████▊                      | 13/50 [48:00<2:12:13, 214.42s/it]

0.5918680827350249 0.7835932188860455


 28%|████████▍                     | 14/50 [54:17<2:38:09, 263.58s/it]

0.5678691436292976 0.7109746446324723


 30%|█████████                     | 15/50 [57:17<2:18:59, 238.28s/it]

0.7445779912121714 0.9174844349301325


 32%|█████████▌                    | 16/50 [59:12<1:54:07, 201.39s/it]

0.6556574921755556 0.5599674489402291


 34%|█████████▌                  | 17/50 [1:00:23<1:29:09, 162.11s/it]

0.7165914090792083 1.0


 36%|██████████                  | 18/50 [1:01:25<1:10:19, 131.87s/it]

0.28911477424076065 0.44109988012809975


 38%|██████████▋                 | 19/50 [1:03:18<1:05:14, 126.29s/it]

0.22887895763799748 0.1


 40%|███████████▏                | 20/50 [1:05:39<1:05:23, 130.79s/it]

0.1055837383474254 0.1


 42%|███████████▊                | 21/50 [1:07:56<1:04:03, 132.52s/it]

0.6020993820331643 0.2532024186228405


 44%|█████████████▏                | 22/50 [1:09:49<59:04, 126.59s/it]

0.7805365736578652 1.0


 46%|████████████▉               | 23/50 [1:12:27<1:01:19, 136.28s/it]

0.23845557904106537 0.49244066045067025


 48%|█████████████▍              | 24/50 [1:14:55<1:00:27, 139.51s/it]

0.5925718650745722 0.601998933558765


 50%|██████████████              | 25/50 [1:17:31<1:00:14, 144.58s/it]

0.32356229158946986 0.7042065272344928


 52%|██████████████▌             | 26/50 [1:20:30<1:01:57, 154.90s/it]

0.13561749048700775 0.1


 54%|███████████████             | 27/50 [1:23:29<1:02:11, 162.25s/it]

0.20415391279541203 0.7030434375180111


 56%|████████████████▊             | 28/50 [1:24:55<51:03, 139.26s/it]

0.7180945237306747 0.7975180464796847


 58%|█████████████████▍            | 29/50 [1:26:20<43:04, 123.05s/it]

0.858914652681155 0.1


 60%|████████████████▊           | 30/50 [1:34:57<1:20:22, 241.10s/it]

0.9916323051178945 0.9634603575628179


 62%|█████████████████▎          | 31/50 [1:38:55<1:16:02, 240.15s/it]

0.4754091025940805 0.628336207627848


 64%|███████████████████▏          | 32/50 [1:39:46<55:05, 183.63s/it]

0.5692294465280942 0.1


 66%|███████████████████▊          | 33/50 [1:41:17<44:07, 155.71s/it]

0.9623699295340004 0.9912602080636013


 68%|████████████████████▍         | 34/50 [1:42:24<34:27, 129.24s/it]

0.8412324224578319 0.7450175083059268


 70%|█████████████████████         | 35/50 [1:45:32<36:40, 146.68s/it]

0.33984639793582466 0.2742984539598381


 72%|█████████████████████▌        | 36/50 [1:49:42<41:29, 177.82s/it]

0.6516256073150575 0.1


 74%|██████████████████████▏       | 37/50 [1:52:30<37:51, 174.75s/it]

0.373084829306693 0.35954963374535176


 76%|██████████████████████▊       | 38/50 [1:54:49<32:48, 164.01s/it]

0.9182796767413506 0.9973025244962223


 78%|███████████████████████▍      | 39/50 [1:55:55<24:41, 134.68s/it]

0.7218155514910929 0.1


 80%|████████████████████████      | 40/50 [1:58:25<23:12, 139.30s/it]

0.35332202986064376 0.8318245072790181


 80%|████████████████████████      | 40/50 [2:02:27<30:36, 183.69s/it]


KeyboardInterrupt: 