In [3]:
import os
import time
import datetime
import pickle
import numpy as np

import pymc as pm
import pytensor
import pytensor.tensor as pt
import arviz as az

from pytensor.graph.op import Op
from efficient_fpt.multi_stage_cy import compute_loss_parallel, print_num_threads


# =====================================================================
# Config
# =====================================================================
DATA_PATH      = "addm_data_20251015-163921.pkl"
START_INDEX    = 0
END_INDEX      = 1000   # adjust if you want fewer trials
NUM_THREADS    = 32      # threads for compute_loss_parallel
N_TIMES        = 20      # repetitions for timing loglik/logp
N_DRAWS_BENCH  = 100     # draws for pm.sample benchmark

# Optional: keep BLAS single-threaded so OpenMP (your Cython) dominates
os.environ.setdefault("MKL_NUM_THREADS", "1")
os.environ.setdefault("OPENBLAS_NUM_THREADS", "1")


# =====================================================================
# Helper: timing function
# =====================================================================
def time_func(func, *args, n=20):
    t0 = time.perf_counter()
    for _ in range(n):
        func(*args)
    t1 = time.perf_counter()
    return (t1 - t0) / n


# =====================================================================
# Start full-cell timer
# =====================================================================
cell_t0 = time.perf_counter()
print("Cell started at:", datetime.datetime.now().strftime("%H:%M:%S"))


# =====================================================================
# Load data
# =====================================================================
data = pickle.load(open(DATA_PATH, "rb"))

DATA_TYPE = np.float64

# True parameters
a_true     = float(data["a"])
b_true     = float(data["b"])
x0_true    = float(data["x0"])
eta_true   = float(data["eta"])
kappa_true = float(data["kappa"])

r1_full    = data["r1_data"]
r2_full    = data["r2_data"]
flag_full  = data["flag_data"].astype(np.int32)

sigma = float(data["sigma"])
T     = float(data["T"])

mu_full     = data["mu_array_padded_data"].astype(DATA_TYPE)
sacc_full   = data["sacc_array_padded_data"].astype(DATA_TYPE)
length_full = data["d_data"].astype(np.int32)
rt_full     = data["decision_data"][:, 0].astype(DATA_TYPE)
choice_full = data["decision_data"][:, 1].astype(np.int32)

num_data_full, max_d = mu_full.shape

# Subset trials
start_index = START_INDEX
end_index   = min(END_INDEX, num_data_full)
idx         = slice(start_index, end_index)

r1_data     = r1_full[idx]
r2_data     = r2_full[idx]
flag_data   = flag_full[idx]
sacc_data   = sacc_full[idx]
length_data = length_full[idx]
rt_data     = rt_full[idx]
choice_data = choice_full[idx]

num_data = len(rt_data)
M = float(np.max(rt_data))

print(f"Using {num_data} trials, max_d={max_d}, M={M:.4f}")
print_num_threads()


# =====================================================================
# Direct log-likelihood (baseline)
# =====================================================================
def loglik_direct(eta, kappa, a, b, x0, num_threads=NUM_THREADS):
    """Direct call to compute_loss_parallel, as in your custom code."""
    mu1 = kappa * (r1_data - eta * r2_data)
    mu2 = kappa * (eta * r1_data - r2_data)
    nll = compute_loss_parallel(
        mu1,
        mu2,
        rt_data,
        choice_data,
        flag_data,
        sacc_data,
        length_data,
        max_d,
        sigma,
        a,
        b,
        x0,
        num_threads=num_threads,
    )
    return -num_data * nll


theta_true_vec = np.array(
    [eta_true, kappa_true, a_true, b_true, x0_true], dtype=float
)
eta0, kappa0, a0, b0, x00 = theta_true_vec

print("\n--- Warm-up direct loglik ---")
_ = loglik_direct(eta0, kappa0, a0, b0, x00)


# =====================================================================
# PyTensor Op for the log-likelihood
# =====================================================================
class LogLikeFPT(Op):
    """
    PyTensor Op that wraps compute_loss_parallel.
    Input: theta = [eta, kappa, a, b, x0] (1D np array)
    Output: scalar log-likelihood.
    """
    itypes = [pt.dvector]
    otypes = [pt.dscalar]

    def __init__(
        self,
        rt_data,
        choice_data,
        r1_data,
        r2_data,
        flag_data,
        sacc_data,
        length_data,
        max_d,
        sigma,
        num_threads=NUM_THREADS,
    ):
        self.rt_data     = np.asarray(rt_data, dtype=np.float64)
        self.choice_data = np.asarray(choice_data, dtype=np.int32)
        self.r1_data     = np.asarray(r1_data, dtype=np.float64)
        self.r2_data     = np.asarray(r2_data, dtype=np.float64)
        self.flag_data   = np.asarray(flag_data, dtype=np.int32)
        self.sacc_data   = np.asarray(sacc_data, dtype=np.float64)
        self.length_data = np.asarray(length_data, dtype=np.int32)

        self.max_d = int(max_d)
        self.sigma = float(sigma)
        self.num_data = len(self.rt_data)
        self.num_threads = int(num_threads)

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        eta, kappa, a, b, x0 = theta

        mu1_data = kappa * (self.r1_data - eta * self.r2_data)
        mu2_data = kappa * (eta * self.r1_data - self.r2_data)

        nll = compute_loss_parallel(
            mu1_data,
            mu2_data,
            self.rt_data,
            self.choice_data,
            self.flag_data,
            self.sacc_data,
            self.length_data,
            self.max_d,
            self.sigma,
            a,
            b,
            x0,
            num_threads=self.num_threads,
        )
        loglik = -self.num_data * nll
        outputs[0][0] = np.array(loglik, dtype="float64")


loglike_op = LogLikeFPT(
    rt_data=rt_data,
    choice_data=choice_data,
    r1_data=r1_data,
    r2_data=r2_data,
    flag_data=flag_data,
    sacc_data=sacc_data,
    length_data=length_data,
    max_d=max_d,
    sigma=sigma,
    num_threads=NUM_THREADS,
)

theta_sym = pt.dvector("theta_sym")
ll_sym = loglike_op(theta_sym)
f_ll = pytensor.function([theta_sym], ll_sym)

print("\n--- Warm-up PyTensor Op ---")
_ = f_ll(theta_true_vec)

ll_direct = loglik_direct(eta0, kappa0, a0, b0, x00)
ll_op     = float(f_ll(theta_true_vec))
print(f"Direct loglik(true)      = {ll_direct:.6f}")
print(f"LogLikeFPT loglik(true)  = {ll_op:.6f}")


# =====================================================================
# Build PyMC model and compile full logp
# =====================================================================
with pm.Model() as model:
    eta     = pm.Beta("eta", alpha=2.0, beta=2.0)
    kappa   = pm.Gamma("kappa", alpha=2.0, beta=4.0)  # scale = 1/4
    a_param = pm.Gamma("a", alpha=2.0, beta=1.0)      # scale = 1

    b_raw   = pm.Beta("b_raw", alpha=2.0, beta=2.0)
    b_param = pm.Deterministic("b", b_raw * a_param / M)

    x0_raw   = pm.Beta("x0_raw", alpha=2.0, beta=2.0)
    x0_param = pm.Deterministic("x0", -a_param + 2.0 * a_param * x0_raw)

    theta = pt.stack([eta, kappa, a_param, b_param, x0_param])
    pm.Potential("loglik", loglike_op(theta))

    # compile full logp (priors + potential)
    logp_fn = model.compile_fn(model.logp(sum=True))
    # get a valid starting point in transformed space
    ip = model.initial_point()

print("\n--- Warm-up full PyMC logp ---")
logp_val = float(logp_fn(ip))
print(f"Full PyMC logp(initial_point) = {logp_val:.6f}")


# =====================================================================
# Timing: direct vs Op vs full logp
# =====================================================================
print("\n=== Timing log-likelihood / logp evaluations ===")
t_direct = time_func(loglik_direct, eta0, kappa0, a0, b0, x00, NUM_THREADS, n=N_TIMES)
print(f"Direct loglik:       {t_direct*1000:.3f} ms per call")

t_op = time_func(f_ll, theta_true_vec, n=N_TIMES)
print(f"PyTensor Op:         {t_op*1000:.3f} ms per call "
      f"({t_op/t_direct:.2f}× vs direct)")

t_logp = time_func(logp_fn, ip, n=N_TIMES)
print(f"Full PyMC logp:      {t_logp*1000:.3f} ms per call "
      f"({t_logp/t_direct:.2f}× vs direct, {t_logp/t_op:.2f}× vs Op)")


# =====================================================================
# Benchmark pm.sample: with *full* timing around it
# =====================================================================
print("\n=== Benchmark pm.sample (with full timing) ===")
print("Calling pm.sample at:", datetime.datetime.now().strftime("%H:%M:%S"))

with model:
    sample_t0 = time.perf_counter()
    trace_bench = pm.sample(
        draws=N_DRAWS_BENCH,
        tune=0,
        chains=1,
        cores=1,
        step=pm.Metropolis(),
        progressbar=True,
        return_inferencedata=True,
    )
    sample_t1 = time.perf_counter()

print("pm.sample returned at:", datetime.datetime.now().strftime("%H:%M:%S"))
print(f"pm.sample wall time (sample_t1 - sample_t0): {sample_t1 - sample_t0:.1f} s")
print(f"pm.sample: {N_DRAWS_BENCH} draws => { (sample_t1 - sample_t0)/N_DRAWS_BENCH*1000:.1f} ms per draw")

acc = trace_bench.sample_stats["accepted"].values  # (1, draws, 5)
print("Bench chain overall accept rate:", acc.mean(axis=(0, 2)))


# =====================================================================
# End full-cell timing
# =====================================================================
cell_t1 = time.perf_counter()
print("Cell finished at:", datetime.datetime.now().strftime("%H:%M:%S"))
print(f"Full cell wall time (cell_t1 - cell_t0): {cell_t1 - cell_t0:.1f} s")


Cell started at: 03:50:54
Using 1000 trials, max_d=12, M=4.8855
Number of available threads: 64

--- Warm-up direct loglik ---

--- Warm-up PyTensor Op ---
Direct loglik(true)      = -1493.814546
LogLikeFPT loglik(true)  = -1493.814546

--- Warm-up full PyMC logp ---
Full PyMC logp(initial_point) = -1563.840746

=== Timing log-likelihood / logp evaluations ===
Direct loglik:       5.584 ms per call


PyTensor Op:         5.585 ms per call (1.00× vs direct)
Full PyMC logp:      5.493 ms per call (0.98× vs direct, 0.98× vs Op)

=== Benchmark pm.sample (with full timing) ===
Calling pm.sample at: 03:50:54


Sequential sampling (1 chains in 1 job)
CompoundStep
>Metropolis: [eta]
>Metropolis: [kappa]
>Metropolis: [a]
>Metropolis: [b_raw]
>Metropolis: [x0_raw]


Output()

Sampling 1 chain for 0 tune and 100 draw iterations (0 + 100 draws total) took 6 seconds.
Only one chain was sampled, this makes it impossible to run some convergence checks


pm.sample returned at: 03:51:01
pm.sample wall time (sample_t1 - sample_t0): 7.1 s
pm.sample: 100 draws => 70.6 ms per draw
Bench chain overall accept rate: [0.2 0.4 0.4 0.2 0.2 0.4 0.2 0.2 0.2 0.2 0.  0.  0.  0.  0.  0.2 0.  0.
 0.  0.  0.  0.  0.  0.2 0.  0.  0.2 0.2 0.2 0.  0.  0.  0.2 0.  0.  0.
 0.  0.  0.  0.  0.  0.  0.4 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
 0.  0.2 0.  0.4 0.  0.4 0.  0.  0.2 0.  0.4 0.  0.  0.  0.  0.  0.  0.
 0.2 0.2 0.  0.  0.2 0.  0.  0.4 0.  0.  0.  0.  0.2 0.2 0.  0.  0.  0.
 0.  0.  0.2 0.  0.  0.  0.  0.2 0.  0. ]
Cell finished at: 03:51:01
Full cell wall time (cell_t1 - cell_t0): 7.6 s
