# In [8]:
"""
Theoretical-Informed Neural Network (TINN) for Parameter Estimation
in Drift Diffusion Model with Time-Dependent Drift
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from time import time
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import pandas as pd
from scipy.stats import gaussian_kde
from scipy.optimize import minimize

# =============================================================================
# Configuration and Constants
# =============================================================================

# Simulation parameters
NUM_SIMULATIONS = 5000      # number of simulated trials
TIME_STEPS = 20000          # maximum integration steps per trial (horizon = TIME_STEPS * DT)
DT = 0.01                   # integration step size
NON_DECISION_TIME = 0.7     # constant added to every simulated decision time to form an RT

# True parameters for DDM with time-dependent drift v(t) = V0 + V1 * t / (t + TAU)
V0_TRUE = 0.3
V1_TRUE = 0.5
TAU_TRUE = 0.8
BOUNDARY_SEPARATION = 2.0   # the simulator places absorbing boundaries at +/- this value
STARTING_POINT = 0.0        # initial value of the decision variable
NOISE_STD = 1.0             # diffusion coefficient (std of the Wiener increment per unit time)

# TINN training parameters
NUM_EPOCHS = 30000          # number of full-batch optimizer steps
NUM_HIDDEN_LAYERS = 4
NUM_NEURONS_PER_LAYER = 30
# Boundary position of the TINN domain. It must match the boundaries the data
# were simulated with, so derive it instead of repeating the literal.
THRESHOLD = BOUNDARY_SEPARATION

# =============================================================================
# DDM Simulation with Time-Dependent Drift
# =============================================================================

def drift_urgency(t):
    """Return the urgency-modulated drift v(t) = v0 + v1 * t / (t + tau).

    The coefficients are the module-level ground-truth constants V0_TRUE,
    V1_TRUE and TAU_TRUE. At t = 0 the drift equals v0; for tau > 0 it grows
    monotonically towards v0 + v1.
    """
    urgency = V1_TRUE * t / (t + TAU_TRUE)
    return V0_TRUE + urgency

def simulation_ddm_rk4(num_simulations, time_steps, boundary_separation, starting_point,
                       noise_std, dt, rng=None, drift_fn=None):
    """Simulate first-passage times of a DDM with time-dependent drift.

    The decision variable starts at ``starting_point`` and evolves as
    dX = v(t) dt + noise_std dW until it reaches +boundary_separation (upper)
    or -boundary_separation (lower). The drift does not depend on X, so the
    "RK4" stage combination reduces to integrating v(t) over each step with
    Simpson's rule plus one Gaussian increment per step. That is what is
    computed here, vectorised over all still-running trials.

    Args:
        num_simulations: number of independent trials.
        time_steps: maximum number of steps per trial. Trials that have not
            hit a boundary after ``time_steps`` steps are discarded.
        boundary_separation: the absorbing boundaries sit at +/- this value
            (i.e. it is the half-distance between them).
        starting_point: initial value of the decision variable.
        noise_std: standard deviation of the Wiener increment per unit time.
        dt: integration step size.
        rng: optional random source with a ``normal(loc, scale, size)`` method
            (e.g. ``np.random.default_rng(seed)``). Defaults to the global
            ``np.random`` state, so ``np.random.seed`` still controls it.
        drift_fn: optional drift v(t); defaults to ``drift_urgency``.

    Returns:
        (times_upper, times_lower): float arrays of crossing times in trial
        order. A crossing detected during step k (0-based) is timed at the end
        of that step, (k + 1) * dt.
    """
    if rng is None:
        rng = np.random
    if drift_fn is None:
        drift_fn = drift_urgency

    sqrt_dt = np.sqrt(dt)
    x = np.full(num_simulations, float(starting_point))
    # 0 = no crossing (yet), 1 = hit upper boundary, -1 = hit lower boundary.
    outcome = np.zeros(num_simulations, dtype=np.int8)
    crossing_time = np.zeros(num_simulations)
    running = np.arange(num_simulations)  # indices of trials still diffusing

    for step in range(time_steps):
        if running.size == 0:
            break
        t_now = step * dt
        # Simpson's rule over [t_now, t_now + dt]: same for every trial.
        drift_increment = dt * (drift_fn(t_now)
                                + 4.0 * drift_fn(t_now + dt / 2)
                                + drift_fn(t_now + dt)) / 6.0
        # Fresh noise only for the trials still running: no huge pre-allocated matrix.
        x[running] += drift_increment + rng.normal(0.0, noise_std, running.size) * sqrt_dt

        x_run = x[running]
        hit_upper = x_run >= boundary_separation
        # The upper boundary takes priority, as in the original if/elif.
        hit_lower = ~hit_upper & (x_run <= -boundary_separation)
        finished = hit_upper | hit_lower
        if finished.any():
            crossing_time[running[finished]] = (step + 1) * dt
            outcome[running[hit_upper]] = 1
            outcome[running[hit_lower]] = -1
            running = running[~finished]

    return crossing_time[outcome == 1], crossing_time[outcome == -1]

# =============================================================================
# TINN Architecture
# =============================================================================

class PINN_NeuralNet(tf.keras.Model):
    """Fully connected network u(t, x) used as the PINN surrogate.

    Args:
        lb, ub: lower/upper corners of the (t, x) domain. They are stored on the
            instance; NOTE(review): ``call`` does not use them to rescale inputs.
        output_dim: width of the final linear Dense layer.
        num_hidden_layers: ``num_hidden_layers - 1`` Dense layers with
            ``activation`` are built and applied in ``call``.
        num_neurons_per_layer: width of every hidden layer.
        activation: activation name or callable for the hidden layers.
        kernel_initializer: initializer for the hidden layers' kernels.
        **kwargs: forwarded to ``tf.keras.Model``.
    """

    def __init__(self, lb, ub, output_dim=1, num_hidden_layers=5,
                 num_neurons_per_layer=50, activation='tanh',
                 kernel_initializer='glorot_normal', **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.lb = lb
        self.ub = ub

        act_fn = tf.keras.activations.get(activation)
        stack = []
        for _ in range(self.num_hidden_layers - 1):
            stack.append(tf.keras.layers.Dense(num_neurons_per_layer,
                                               activation=act_fn,
                                               kernel_initializer=kernel_initializer))
        self.hidden = stack
        # NOTE(review): this softplus layer is constructed but never applied in call().
        self.hidden1 = tf.keras.layers.Dense(num_neurons_per_layer, activation='softplus')
        self.out = tf.keras.layers.Dense(output_dim)

    def call(self, X):
        """Apply the hidden stack, then the linear output layer, to X."""
        Z = X
        for layer in self.hidden:
            Z = layer(Z)
        return self.out(Z)

class PINNIdentificationNet(PINN_NeuralNet):
    """PINN_NeuralNet plus trainable scalars for the parameters to be identified.

    Adds the weights ``lambd0``, ``lambd1``, ``tau``, ``sig`` and ``tt0`` (all
    initialised to 1). It also adds one empty ``<name>_list`` per weight, in
    which a solver can record the per-iteration trace.
    """

    _PARAM_NAMES = ("lambd0", "lambd1", "tau", "sig", "tt0")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Scalar trainable parameters, created in a fixed order.
        for param_name in self._PARAM_NAMES:
            setattr(self, param_name,
                    self.add_weight(name=param_name, initializer="ones",
                                    trainable=True, dtype=tf.float32))

        # Per-parameter history lists (e.g. lambd0_list, ..., tt0_list).
        for param_name in self._PARAM_NAMES:
            setattr(self, f"{param_name}_list", [])

# =============================================================================
# TINN Solver
# =============================================================================

class PINNSolver:
    """Trainer for a physics-informed network fitted to DDM first-passage times.

    Subclasses must implement ``fun_r(t, x, u, u_t, u_x, u_xx)``, the PDE
    residual at the collocation points. ``loss_fn`` reads the trainable
    scalars ``lambd0``, ``lambd1``, ``tau``, ``sig`` and ``tt0`` from the model.
    """

    # Spacing of the one-sided finite-difference stencil for dp/dx at a boundary.
    _FD_STEP = 0.02

    def __init__(self, model, X_r):
        """
        Args:
            model: callable mapping an (N, 2) batch of (t, x) points to u(t, x).
            X_r: collocation points; column 0 holds t, column 1 holds x.
        """
        self.model = model
        # Store collocation points as (N, 1) columns
        self.t = X_r[:, 0:1]
        self.x = X_r[:, 1:2]
        # Loss history and global iteration counter (advanced by callback)
        self.hist = []
        self.iter = 0
        self.current_loss = 0.0

    def fun_r(self, t, x, u, u_t, u_x, u_xx):
        """PDE residual; must be provided by a subclass."""
        raise NotImplementedError("PINNSolver subclasses must implement fun_r")

    def get_r(self):
        """Return the PDE residual at the collocation points.

        The tape must be persistent because it is queried three times: u_x inside
        the recording context (so that u_xx can be differentiated from it), then
        u_t and u_xx. A non-persistent tape raises RuntimeError on the second call.
        """
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(self.t)
            tape.watch(self.x)
            u = self.model(tf.stack([self.t[:, 0], self.x[:, 0]], axis=1))
            # Taken inside the context so the tape also records u_x -> u_xx.
            u_x = tape.gradient(u, self.x)

        u_t = tape.gradient(u, self.t)
        u_xx = tape.gradient(u_x, self.x)
        del tape
        return self.fun_r(self.t, self.x, u, u_t, u_x, u_xx)

    def _flux_kde_loss(self, rt, x_boundary, step, sign, n_total):
        """MSE between the model's probability flux at one boundary and a KDE of its RTs.

        Args:
            rt: response times observed at this boundary.
            x_boundary: boundary position (+xmax or -xmax).
            step: signed finite-difference spacing pointing into the domain
                (-_FD_STEP at the upper boundary, +_FD_STEP at the lower one).
            sign: +1 at the upper boundary, -1 at the lower one (the outward flux
                through the lower boundary is negative).
            n_total: number of RTs over both boundaries. It scales each
                boundary's KDE into a defective density.
        """
        if len(rt) < 2:
            # gaussian_kde cannot be fitted to fewer than two points.
            return tf.constant(0.0, tf.float32)

        tspace = np.sort(rt)
        n = len(tspace)
        # Shift observed RTs by the learned non-decision time sigmoid(tt0) * minRT.
        t = tf.constant(tspace.reshape((n, 1)), 'float32') \
            - tf.math.sigmoid(self.model.tt0) * self.minRT

        def p_at(x_val):
            x_col = tf.constant(np.ones((n, 1)) * x_val, 'float32')
            return self.model(tf.concat([t, x_col], 1))

        p0 = p_at(x_boundary)
        p1 = p_at(x_boundary + step)
        p2 = p_at(x_boundary + 2 * step)
        # Second-order one-sided stencil: f'(x0) ~ (-3 f0 + 4 f1 - f2) / (2 step)
        p_x = (-3 * p0 + 4 * p1 - p2) / (2 * step)

        drift = tf.nn.softplus(self.model.lambd0) \
            + tf.nn.softplus(self.model.lambd1) * t / (t + self.model.tau)
        J = drift * p0 - 0.5 * (self.model.sig) ** 2 * p_x

        p_kde = sign * n * gaussian_kde(tspace)(tspace)[:, np.newaxis] / n_total
        p_kde_tensor = tf.convert_to_tensor(p_kde, dtype=tf.float32)
        return tf.reduce_mean(tf.square(J - p_kde_tensor))

    def loss_fn(self, X, xmax, rt1, rt2, u):
        """Total training loss.

        Args:
            X: [X_0, X_b], (t, x) points of the initial and boundary conditions.
            xmax: boundary position; the boundaries are at +xmax and -xmax.
            rt1, rt2: response times observed at the upper / lower boundary.
            u: [u_0, u_b], target values at X_0 and X_b.

        Returns:
            PDE-residual MSE + initial-condition MSE + boundary-condition MSE
            + upper-boundary flux/KDE MSE + lower-boundary flux/KDE MSE.

        Raises:
            ValueError: if neither rt1 nor rt2 contains any response time.
        """
        # PDE residual at the collocation points
        loss_r = tf.reduce_mean(tf.square(self.get_r()))

        # Initial and boundary conditions
        loss_0 = tf.reduce_mean(tf.square(u[0] - self.model(X[0])))
        loss_b = tf.reduce_mean(tf.square(u[1] - self.model(X[1])))

        rt1 = np.asarray(rt1, dtype=float)
        rt2 = np.asarray(rt2, dtype=float)
        if rt1.size + rt2.size == 0:
            raise ValueError("loss_fn needs at least one response time in rt1 or rt2")
        # Upper bound for the non-decision time; also read by subclasses' callbacks.
        self.minRT = np.min(np.concatenate([rt1, rt2]))
        n_total = rt1.size + rt2.size

        kde_loss1 = self._flux_kde_loss(rt1, xmax, -self._FD_STEP, 1, n_total)
        kde_loss2 = self._flux_kde_loss(rt2, -xmax, self._FD_STEP, -1, n_total)

        return (loss_r + loss_0 + loss_b) + kde_loss1 + kde_loss2

    def get_grad(self, X, xmax, rt1, rt2, u):
        """Return (loss, gradients of the loss w.r.t. model.trainable_variables)."""
        # A single gradient call, so a non-persistent tape is enough.
        with tf.GradientTape() as tape:
            loss = self.loss_fn(X, xmax, rt1, rt2, u)
        g = tape.gradient(loss, self.model.trainable_variables)
        return loss, g

    def solve_with_TFoptimizer(self, optimizer, X, xmax, rt1, rt2, u, N=1001):
        """Run N gradient-descent steps with a Keras optimizer.

        ``train_step`` is wrapped in tf.function. The Python/NumPy parts of
        loss_fn (sorting, KDE fitting, setting ``self.minRT``) therefore run while
        it is traced, and their results are captured as constants.
        """
        @tf.function
        def train_step():
            loss, grad_theta = self.get_grad(X, xmax, rt1, rt2, u)
            optimizer.apply_gradients(zip(grad_theta, self.model.trainable_variables))
            return loss

        for _ in range(N):
            loss = train_step()
            self.current_loss = loss.numpy()
            self.callback()

    def callback(self, xr=None):
        """Record the current loss, print it every 50 iterations and advance the counter."""
        if self.iter % 50 == 0:
            print('It {:05d}: loss = {:10.8e}'.format(self.iter, self.current_loss))
        self.hist.append(self.current_loss)
        self.iter += 1

# =============================================================================
# Fokker-Planck TINN Solver
# =============================================================================

class F_P_PINNSolver(PINNSolver):
    """Fokker-Planck solver whose residual derivatives come from one persistent tape.

    Construction is inherited unchanged from PINNSolver. Subclasses supply
    ``fun_r``.
    """

    def get_r(self):
        """Evaluate fun_r on the collocation points with u, u_t, u_x and u_xx."""
        t_col, x_col = self.t, self.x
        with tf.GradientTape(persistent=True) as g:
            g.watch(t_col)
            g.watch(x_col)
            points = tf.stack([t_col[:, 0], x_col[:, 0]], axis=1)
            u = self.model(points)
            # First x-derivative taken while recording, so it can be differentiated again.
            u_x = g.gradient(u, x_col)

        u_t = g.gradient(u, t_col)
        u_xx = g.gradient(u_x, x_col)
        del g
        return self.fun_r(t_col, x_col, u, u_t, u_x, u_xx)

class F_P_PINNIdentification(F_P_PINNSolver):
    """Fokker-Planck PINN that also identifies the DDM parameters.

    The model's trainable scalars parameterise
    drift v(t) = softplus(lambd0) + softplus(lambd1) * t / (t + tau),
    diffusion sig, and non-decision time sigmoid(tt0) * minRT.
    ``minRT`` is set by ``loss_fn``, so ``tt0f`` and ``callback`` need the
    loss to have been evaluated at least once.
    """

    def fun_r(self, t, x, u, u_t, u_x, u_xx):
        """Residual u_t + v(t) u_x - 0.5 sig^2 u_xx of the Fokker-Planck equation."""
        return u_t + (tf.nn.softplus(self.model.lambd0) + tf.nn.softplus(self.model.lambd1) * t / (t + self.model.tau)) * u_x - 0.5 * (self.model.sig)**2 * u_xx

    def callback(self, xr=None):
        """Record the current parameter estimates and loss; print them every 50 iterations."""
        v_0 = self.lambda0f()
        v_1 = self.lambda1f()
        tau_0 = self.tauf()
        sigma = self.sigmaf()
        tt0 = self.tt0f()

        self.model.lambd0_list.append(v_0)
        self.model.lambd1_list.append(v_1)
        self.model.tau_list.append(tau_0)
        self.model.sig_list.append(sigma)
        self.model.tt0_list.append(tt0)

        if self.iter % 50 == 0:
            print('It {:05d}: loss = {:10.8e} drift0 = {:10.8e} drift1 = {:10.8e} tau = {:10.8e} sigma = {:10.8e} t0 = {:10.8e}'.format(
                self.iter, self.current_loss, v_0, v_1, tau_0, sigma, tt0))

        self.hist.append(self.current_loss)
        self.iter += 1

    def lambda0f(self):
        """Estimated baseline drift v0 = softplus(lambd0)."""
        return tf.nn.softplus(self.model.lambd0).numpy()

    def lambda1f(self):
        """Estimated urgency drift amplitude v1 = softplus(lambd1)."""
        return tf.nn.softplus(self.model.lambd1).numpy()

    def tauf(self):
        """Estimated urgency time constant tau (unconstrained raw weight)."""
        return self.model.tau.numpy()

    def sigmaf(self):
        """Estimated diffusion coefficient.

        Only sig**2 enters the residual and the boundary flux, so the sign of
        the raw weight is not identifiable; its magnitude is reported.
        """
        return np.abs(self.model.sig.numpy())

    def tt0f(self):
        """Estimated non-decision time sigmoid(tt0) * minRT."""
        return (tf.math.sigmoid(self.model.tt0) * self.minRT).numpy()

# =============================================================================
# Main Execution
# =============================================================================

def main():
    """Simulate DDM data, fit the TINN to it and print true vs. estimated parameters.

    Returns:
        The trained F_P_PINNIdentification solver, e.g. for plotting. Its loss
        history is in ``hist`` and its parameter traces are in ``model.*_list``.

    Raises:
        ValueError: if fewer than two responses reached either boundary, because
            each boundary's KDE flux target needs at least two samples.
    """
    print("TINN Parameter Estimation for DDM with Time-Dependent Drift")
    print("=" * 60)

    # Generate synthetic data
    print("Generating synthetic DDM data...")
    decision_times1, decision_times2 = simulation_ddm_rk4(
        NUM_SIMULATIONS, TIME_STEPS, BOUNDARY_SEPARATION,
        STARTING_POINT, NOISE_STD, DT
    )
    rt1 = np.sort(decision_times1) + NON_DECISION_TIME
    rt2 = np.sort(decision_times2) + NON_DECISION_TIME
    print(f"Generated {len(rt1)} upper and {len(rt2)} lower boundary crossings")
    if len(rt1) < 2 or len(rt2) < 2:
        raise ValueError(
            "Need at least two responses at each boundary to fit the KDE flux "
            f"targets; got {len(rt1)} upper and {len(rt2)} lower."
        )

    # Domain: t in [0, largest RT], x in [-xmax, xmax]
    xmax = THRESHOLD
    xmin = -xmax
    lb = tf.constant([0, xmin], dtype='float32')
    ub = tf.constant([max(rt1.max(), rt2.max()), xmax], dtype='float32')

    # Collocation points on a regular (N + 1) x (N + 1) grid
    N = 100
    tspace = np.linspace(lb[0], ub[0], N + 1)
    xspace = np.linspace(lb[1], ub[1], N + 1)
    T, X = np.meshgrid(tspace, xspace)
    Xgrid = np.vstack([T.flatten(), X.flatten()]).T
    Xgrid = tf.constant(Xgrid, 'float32')

    def ddf(x, delta, x0):
        """Narrow normalised Gaussian at x0 (variance 2*delta), approximating a Dirac delta."""
        return 1 / (2 * np.sqrt(np.pi * delta)) * tf.math.exp(-((x - x0)**2) / (4 * delta))

    def fun_u_0(x):
        """Initial density: mass concentrated at the starting point."""
        return ddf(x, 7.8e-2, STARTING_POINT)

    # Boundary data: u = 0 on randomly chosen points of the two absorbing edges.
    N_b = 500
    t_b = tf.random.uniform((N_b, 1), lb[0], ub[0], dtype='float32')
    # Bernoulli(0.5) choice of edge, drawn from a uniform instead of the legacy
    # tf.keras.backend.random_bernoulli helper, which Keras 3 no longer exposes.
    edge = tf.cast(tf.random.uniform((N_b, 1), dtype='float32') < 0.5, 'float32')
    x_b = lb[1] + (ub[1] - lb[1]) * edge
    X_b = tf.concat([t_b, x_b], axis=1)
    u_b = tf.zeros(tf.shape(x_b), 'float32')

    # Initial condition data: N_0 - 1 evenly spaced points plus the starting point
    # itself, so the peak of the initial density is sampled.
    N_0 = 50
    t_0 = tf.ones((N_0, 1), dtype='float32') * lb[0]
    x_0 = np.linspace(lb[1], ub[1], N_0 - 1, dtype='float32')
    x_0 = np.sort(np.append(x_0, STARTING_POINT))
    x_0 = tf.reshape(tf.convert_to_tensor(x_0, dtype='float32'), [N_0, 1])
    u_0 = fun_u_0(x_0)
    X_0 = tf.concat([t_0, x_0], axis=1)

    X_param = [X_0, X_b]
    u_param = [u_0, u_b]

    # Initialize and train TINN
    model = PINNIdentificationNet(
        lb, ub, num_hidden_layers=NUM_HIDDEN_LAYERS,
        num_neurons_per_layer=NUM_NEURONS_PER_LAYER,
        activation='tanh', kernel_initializer='glorot_normal'
    )
    model.build(input_shape=(None, 2))

    f_p_Identification = F_P_PINNIdentification(model, Xgrid)

    # Learning rate 0.01 until step 1000, 0.001 until step 15000, then 0.0005
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000, 15000], [0.01, 0.001, 0.0005]
    )
    optim = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    print("Starting TINN training...")
    t0 = time()
    f_p_Identification.solve_with_TFoptimizer(
        optim, X_param, xmax, rt1, rt2, u_param, N=NUM_EPOCHS
    )
    print('\nComputation time: {} seconds'.format(time() - t0))

    # Display results
    print("\nTrue parameters:")
    print(f"  v0 = {V0_TRUE:.4f}, v1 = {V1_TRUE:.4f}, tau = {TAU_TRUE:.4f}")
    print(f"  sigma = {NOISE_STD:.4f}, t0 = {NON_DECISION_TIME:.4f}")

    print("\nEstimated parameters:")
    print(f"  v0 = {f_p_Identification.lambda0f():.4f}, v1 = {f_p_Identification.lambda1f():.4f}")
    print(f"  tau = {f_p_Identification.tauf():.4f}, sigma = {f_p_Identification.sigmaf():.4f}")
    print(f"  t0 = {f_p_Identification.tt0f():.4f}")

    # TODO: plot the loss history (f_p_Identification.hist) and the parameter
    # traces (model.lambd0_list, ..., model.tt0_list).
    return f_p_Identification

# Run the simulate-and-fit pipeline when this file is executed directly.
if __name__ == "__main__":
    main()

TINN Parameter Estimation for DDM with Time-Dependent Drift
Generating synthetic DDM data...
Generated 4623 upper and 377 lower boundary crossings
Starting TINN training...




It 00000: loss = 4.25366116e+00 drift0 = 1.30596101e+00 drift1 = 1.30596101e+00 tau = 1.00999987e+00 sigma = 9.90000188e-01 t0 = 6.81708694e-01
It 00050: loss = 9.32520777e-02 drift0 = 1.20074296e+00 drift1 = 1.20227349e+00 tau = 1.15451670e+00 sigma = 8.67243886e-01 t0 = 7.05361903e-01
It 00100: loss = 5.09773493e-02 drift0 = 1.19118965e+00 drift1 = 1.19620705e+00 tau = 1.16652966e+00 sigma = 7.94577837e-01 t0 = 7.05262125e-01
It 00150: loss = 3.58340181e-02 drift0 = 1.17908680e+00 drift1 = 1.19151306e+00 tau = 1.17845404e+00 sigma = 6.72353506e-01 t0 = 7.04825819e-01
It 00200: loss = 2.70782597e-02 drift0 = 1.16323912e+00 drift1 = 1.18616676e+00 tau = 1.19222176e+00 sigma = 5.94949543e-01 t0 = 7.03658342e-01
It 00250: loss = 2.06525680e-02 drift0 = 1.14547157e+00 drift1 = 1.18090630e+00 tau = 1.20638883e+00 sigma = 5.64004958e-01 t0 = 7.02029765e-01
It 00300: loss = 1.60370227e-02 drift0 = 1.12700534e+00 drift1 = 1.17616475e+00 tau = 1.22015929e+00 sigma = 5.57872653e-01 t0 = 7.00490