In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import pickle
import time
import seaborn as sns
from scipy.optimize import fmin, minimize, LinearConstraint, Bounds

from efficient_fpt.multi_stage_cy import compute_loss_parallel, print_num_threads


data = pickle.load(open("addm_data_20250223-044706.pkl", "rb"))

In [2]:
CYTHON_TYPE = np.float64

a = data["a"]
b = data["b"]
x0 = data["x0"]
mu1_true = data["mu1"]
mu2_true = data["mu2"]
sigma = data["sigma"]
T = data["T"]

mu_true_data = data["mu_data_padded"].astype(CYTHON_TYPE)
sacc_data = data["sacc_data_padded"].astype(CYTHON_TYPE)
length_data = data["d_data"].astype(np.int32)
rt_data = data["decision_data"][:, 0].astype(CYTHON_TYPE)
choice_data = data["decision_data"][:, 1].astype(np.int32)

num_data, max_d = mu_true_data.shape
# set data range if necessary
start_index, end_index = 0, 50000
num_data = end_index - start_index
mu_true_data, sacc_data, length_data, rt_data, choice_data = mu_true_data[start_index:end_index], sacc_data[start_index:end_index], length_data[start_index:end_index], rt_data[start_index:end_index], choice_data[start_index:end_index]
flag_data = np.isclose(mu_true_data[:, 0], mu2_true).astype(np.int32)


In [3]:
mu1_true_data = np.full(num_data, mu1_true, dtype=CYTHON_TYPE)
mu2_true_data = np.full(num_data, mu2_true, dtype=CYTHON_TYPE)

In [4]:
num_iter = 10
start_time = time.time()
for _ in range(num_iter):
    loss = compute_loss_parallel(mu1_true_data, mu2_true_data, rt_data, choice_data, flag_data, sacc_data, length_data, max_d, sigma, a, b, x0)
end_time = time.time()
print(f"Likelihood evaluation time: {(end_time - start_time) / num_iter:.3f} s")

Likelihood evaluation time: 0.272 s


In [5]:
print("\n")
print_num_threads()
print("# data =", num_data)

# Constraint optimization for searching all parameters
print("\nNumerical optimization for mu1, mu2, a, b, x0:")
method = "trust-constr"
print("Using " + method)
func = lambda paras: compute_loss_parallel(np.full(num_data, paras[0], dtype=CYTHON_TYPE), 
                                           np.full(num_data, paras[1], dtype=CYTHON_TYPE), 
                                           rt_data, choice_data, flag_data, sacc_data, length_data, max_d, sigma, paras[2], paras[3], paras[4])
bounds = Bounds([0, -np.inf, 0, 0, -np.inf], [np.inf, 0, np.inf, np.inf, np.inf])
con = LinearConstraint([[0, 0, 1, -np.max(rt_data), 0], [0, 0, 1, 0, 1], [0, 0, 1, 0, -1]], lb=[0, 0, 0], ub=[np.inf, np.inf, np.inf])
initial_guess = [0, 0, 1, 0.1, 0]
print("Initial guess:", initial_guess)
print()
start_time = time.time()
paras_opt_result = minimize(func, x0=initial_guess, bounds=bounds, constraints=con, method=method, options={"verbose": 1})
print(f"Total time: {time.time() - start_time:.3f} seconds")
print(paras_opt_result)



Number of available threads: 64
# data = 50000

Numerical optimization for mu1, mu2, a, b, x0:
Using trust-constr
Initial guess: [0, 0, 1, 0.1, 0]

`xtol` termination condition is satisfied.
Number of iterations: 76, function evaluations: 438, CG iterations: 212, optimality: 4.59e-08, constraint violation: 0.00e+00, execution time: 1.2e+02 s.
Total time: 118.973 seconds
           message: `xtol` termination condition is satisfied.
           success: True
            status: 2
               fun: 1.8697804425160607
                 x: [ 9.995e-01 -8.053e-01  2.105e+00  3.025e-01 -1.970e-01]
               nit: 76
              nfev: 438
              njev: 73
              nhev: 0
          cg_niter: 212
      cg_stop_cond: 4
              grad: [ 4.470e-08 -1.192e-07  7.785e-08 -1.937e-07 -1.490e-08]
   lagrangian_grad: [ 2.132e-08 -4.590e-08  2.116e-08  2.982e-09 -1.155e-08]
            constr: [array([ 3.277e-01,  1.908e+00,  2.302e+00]), array([ 9.995e-01, -8.053e-01,  2.105e+00

In [6]:
theta_hat = paras_opt_result['x']
print("True and estimated value of parameters:")
print(f"mu1: {mu1_true:.3f}, {theta_hat[0]:.3f}")
print(f"mu2: {mu2_true:.3f}, {theta_hat[1]:.3f}")
print(f"a: {a:.3f}, {theta_hat[2]:.3f}")
print(f"b: {b:.3f}, {theta_hat[3]:.3f}")
print(f"x0: {x0:.3f}, {theta_hat[4]:.3f}")


True and estimated value of parameters:
mu1: 1.000, 0.999
mu2: -0.800, -0.805
a: 2.100, 2.105
b: 0.300, 0.303
x0: -0.200, -0.197


#### Bootstrap confidence intervals

In [None]:
n_boot = 100
alpha = 0.05  # 95% confidence interval

start_time = time.time()
bootstrap_estimates = np.zeros((n_boot, 5))
# Bootstrap resampling
for i in range(n_boot):
    if i % 100 == 0:
        print(f"Bootstrap iteration {i} takes {time.time() - start_time:.3f} seconds")
    indices = np.random.choice(len(rt_data), size=len(rt_data), replace=True)
    # Resample data
    rt_boot = rt_data[indices]
    choice_boot = choice_data[indices]
    flag_boot = flag_data[indices]
    sacc_boot = sacc_data[indices, :]
    length_boot = length_data[indices]
    # Define the function for resampled data
    func_boot = lambda paras: compute_loss_parallel(
        paras[0], paras[1], rt_boot, choice_boot, flag_boot, sacc_boot, length_boot, max_d, sigma, paras[2], paras[3], paras[4]
    )
    # Optimize for bootstrap sample
    result = minimize(func_boot, x0=theta_hat, bounds=bounds, constraints=con, method=method)
    if result.success:
        bootstrap_estimates[i, :] = result.x
    else:
        print(f"Warning: Optimization failed for bootstrap iteration {i}")

print(f"Total time: {time.time() - start_time:.3f} seconds")
# np.save("bootstrap_estimates.npy", bootstrap_estimates)

Bootstrap iteration 0 takes 0.000 seconds
Total time: 14198.076 seconds


In [7]:
# Compute percentiles for pivotal confidence intervals
bootstrap_means = np.mean(bootstrap_estimates, axis=0)
lower_pivot = 2 * bootstrap_means - np.percentile(bootstrap_estimates, 100 * (1 - alpha / 2), axis=0)
upper_pivot = 2 * bootstrap_means - np.percentile(bootstrap_estimates, 100 * (alpha / 2), axis=0)

# Store results
bootstrap_ci = np.vstack((lower_pivot, upper_pivot)).T

In [8]:
bootstrap_ci

array([[ 0.98549772,  1.01573224],
       [-0.81820473, -0.79173476],
       [ 2.09468029,  2.11860388],
       [ 0.2982302 ,  0.30719499],
       [-0.20827249, -0.18829406]])