In [9]:
"""
Theoretical-Informed Neural Network (TINN) for Parameter Estimation
in Time-Dependent Drift Diffusion Model (DDM)
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from time import time
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import pandas as pd
from scipy.stats import gaussian_kde
from scipy.optimize import minimize
import math


# =============================================================================
# Configuration and Constants
# =============================================================================

# True parameters for time-dependent DDM
MU_C = 0.5
A = 20
TAU = 150
ALPHA = 2
BOUNDARY_SEPARATION = 2.0
STARTING_POINT = 0.0
NOISE_STD = 1.0
NON_DECISION_TIME = 0.7

# Simulation parameters
NUM_SIMULATIONS = 2000
TIME_STEPS = 20000
DT = 0.01

# TINN training parameters
NUM_EPOCHS = 50000
NUM_HIDDEN_LAYERS = 4
NUM_NEURONS_PER_LAYER = 30
THRESHOLD = 2.0

# Domain parameters
XMIN = -2.0
XMAX = 2.0
N_COLLOCATION = 100
N_BOUNDARY = 500
N_INITIAL = 100

# Physical parameters
DELTA = 7.8e-2


# =============================================================================
# DDM Simulation with Time-Dependent Drift
# =============================================================================

def time_dependent_drift(time, A, mu0, tau, alpha):
    """Define time-dependent drift coefficient."""
    z = mu0 + A * np.exp(-time/tau) * (time * np.exp(1) / (alpha - 1) / tau)**(alpha - 1) * ((alpha - 1)/time - 1/tau)
    return z


def rk4_step(drift_function, t, dt, A, mu0, tau, alpha):
    """RK4 integration step for drift function."""
    k1 = drift_function(t, A, mu0, tau, alpha)
    k2 = drift_function(t + 0.5 * dt, A, mu0, tau, alpha)
    k3 = drift_function(t + 0.5 * dt, A, mu0, tau, alpha)
    k4 = drift_function(t + dt, A, mu0, tau, alpha)
    drift_rate = (k1 + 2*k2 + 2*k3 + k4) / 6
    return drift_rate


def simulate_time_dependent_ddm(num_simulations, time_steps, boundary_separation, noise_std, dt, A, mu0, tau, alpha):
    """Simulate DDM with time-dependent drift using RK4 integration."""
    decision_times_upper = []
    decision_times_lower = []
    no_decision = []

    for i in range(num_simulations):
        decision_variable = STARTING_POINT
        for t in range(1, time_steps):
            # Use RK4 to get drift rate at time step t
            drift_rate = rk4_step(time_dependent_drift, t * dt, dt, A, mu0, tau, alpha)

            # Update decision variable with drift and noise
            noise = np.random.normal(0, noise_std)
            decision_variable += drift_rate * dt + noise * np.sqrt(dt)

            # Check boundary crossings
            if decision_variable >= boundary_separation:
                decision_times_upper.append(t * dt)
                break
            elif decision_variable <= -boundary_separation:
                decision_times_lower.append(t * dt)
                break
        else:
            # If no decision is made within time_steps
            no_decision.append(time_steps * dt)

    return decision_times_upper, decision_times_lower


# =============================================================================
# Neural Network Architecture
# =============================================================================

class PINN_NeuralNet(tf.keras.Model):
    """Base architecture for Physics-Informed Neural Network."""

    def __init__(self, lb, ub, output_dim=1, num_hidden_layers=5,
                 num_neurons_per_layer=50, activation='tanh',
                 kernel_initializer='glorot_normal', **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.lb = lb
        self.ub = ub

        # Define NN architecture
        self.hidden = [tf.keras.layers.Dense(num_neurons_per_layer,
                             activation=tf.keras.activations.get(activation),
                             kernel_initializer=kernel_initializer)
                           for _ in range(self.num_hidden_layers-1)]
        self.hidden1 = tf.keras.layers.Dense(num_neurons_per_layer, activation='softplus')
        self.out = tf.keras.layers.Dense(output_dim)

    def call(self, X):
        """Forward-pass through neural network."""
        Z = X
        for i in range(self.num_hidden_layers-1):
            Z = self.hidden[i](Z)
        return self.out(Z)


class PINNIdentificationNet(PINN_NeuralNet):
    """PINN with trainable parameters for system identification."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Initialize trainable parameters
        self.mu = self.add_weight(name="mu", initializer="ones", trainable=True, dtype=tf.float32)
        self.a1 = self.add_weight(name="a1", initializer=tf.constant_initializer(15), trainable=True, dtype=tf.float32)
        self.tau = self.add_weight(name="tau", initializer=tf.constant_initializer(100), trainable=True, dtype=tf.float32)
        self.sig = self.add_weight(name="sig", initializer="ones", trainable=True, dtype=tf.float32)
        self.tt0 = self.add_weight(name="tt0", initializer="ones", trainable=True, dtype=tf.float32)

        # History tracking
        self.mu_list = []
        self.a1_list = []
        self.tau_list = []
        self.sig_list = []
        self.tt0_list = []


# =============================================================================
# PINN Solver
# =============================================================================

class PINNSolver:
    """Solver for Physics-Informed Neural Networks."""

    def __init__(self, model, X_r):
        self.model = model
        self.t = X_r[:, 0:1]
        self.x = X_r[:, 1:2]
        self.hist = []
        self.iter = 0

    def get_r(self):
        """Compute PDE residual using automatic differentiation."""
        with tf.GradientTape(persistent=False) as tape:
            tape.watch(self.t)
            tape.watch(self.x)

            u = self.model(tf.stack([self.t[:, 0], self.x[:, 0]], axis=1))
            u_x = tape.gradient(u, self.x)

        u_t = tape.gradient(u, self.t)
        u_xx = tape.gradient(u_x, self.x)

        del tape
        return self.fun_r(self.t, self.x, u, u_t, u_x, u_xx)

    def loss_fn(self, X, xmax, rt1, rt2, u):
        """Compute total loss function."""
        # PDE residual loss
        r = self.get_r()
        phi_r = tf.reduce_mean(tf.square(r))

        # Initial condition loss
        u_pred_0 = self.model(X[0])
        loss_0 = tf.reduce_mean(tf.square(u[0] - u_pred_0))

        # Boundary condition loss
        u_pred_1 = self.model(X[1])
        loss_b = tf.reduce_mean(tf.square(u[1] - u_pred_1))

        # First passage time loss for upper boundary
        tspace1 = np.sort(rt1)
        tspace_tf1 = tf.constant(tspace1.reshape((len(rt1), 1)), 'float32')
        minRT1 = min(rt1)
        self.minRT = minRT1
        t1 = tspace_tf1 - tf.math.sigmoid(self.model.tt0) * self.minRT

        # Compute boundary flux using finite differences
        xspace = np.ones((len(tspace1), 1)) * xmax
        X_bound = tf.concat([t1, tf.constant(xspace, 'float32')], 1)
        p_i = self.model(X_bound)

        xspace1 = np.ones((len(tspace1), 1)) * (xmax - 0.02)
        X1 = tf.concat([t1, tf.constant(xspace1, 'float32')], 1)
        p_ii = self.model(X1)

        xspace2 = np.ones((len(tspace1), 1)) * (xmax - 0.04)
        X2 = tf.concat([t1, tf.constant(xspace2, 'float32')], 1)
        p_iii = self.model(X2)

        p_x = (3 * p_i - 4 * p_ii + p_iii) / (2 * 0.02)

        # Compute flux using Fokker-Planck equation
        muc1 = tf.nn.softplus(self.model.mu)
        mut1 = self.model.a1 * tf.math.exp(-t1 / self.model.tau) * (t1 * math.exp(1) / self.model.tau) * (1 / t1 - 1 / self.model.tau)
        J1 = (muc1 + mut1) * p_i - 0.5 * (self.model.sig)**2 * p_x

        # Compare with empirical distribution
        kde1 = gaussian_kde(tspace1)
        p_kde1 = len(tspace1) * kde1(tspace1)[:, np.newaxis] / (len(rt2) + len(rt1))
        KDE_loss1 = tf.reduce_mean(tf.square(J1 - tf.convert_to_tensor(p_kde1, dtype=tf.float32)))

        # First passage time loss for lower boundary
        tspace_tf2 = tf.constant(np.sort(rt2).reshape((len(rt2), 1)), 'float32')
        t2 = tspace_tf2 - tf.math.sigmoid(self.model.tt0) * self.minRT

        # Compute boundary flux for lower boundary
        xspace20 = np.ones((len(rt2), 1)) * (-xmax)
        X20 = tf.concat([t2, tf.constant(xspace20, 'float32')], 1)
        p_i2 = self.model(X20)

        xspace21 = np.ones((len(rt2), 1)) * (-xmax + 0.02)
        X21 = tf.concat([t2, tf.constant(xspace21, 'float32')], 1)
        p_ii2 = self.model(X21)

        xspace22 = np.ones((len(rt2), 1)) * (-xmax + 0.04)
        X22 = tf.concat([t2, tf.constant(xspace22, 'float32')], 1)
        p_iii2 = self.model(X22)

        p_x2 = (-3 * p_i2 + 4 * p_ii2 - p_iii2) / (2 * 0.02)

        mut2 = self.model.a1 * tf.math.exp(-t2 / self.model.tau) * (t2 * math.exp(1) / self.model.tau) * (1 / t2 - 1 / self.model.tau)
        J2 = (muc1 + mut2) * p_i2 - 0.5 * (self.model.sig)**2 * p_x2

        kde2 = gaussian_kde(rt2)
        p_kde2 = -len(rt2) * kde2(rt2)[:, np.newaxis] / (len(rt2) + len(rt1))
        KDE_loss2 = tf.reduce_mean(tf.square(J2 - tf.convert_to_tensor(p_kde2, dtype=tf.float32)))

        return 100 * (phi_r + loss_0 + loss_b) + KDE_loss1 + KDE_loss2

    def get_grad(self, X, xmax, rt1, rt2, u):
        """Compute gradients of loss function."""
        with tf.GradientTape(persistent=True) as tape:
            loss = self.loss_fn(X, xmax, rt1, rt2, u)

        g = tape.gradient(loss, self.model.trainable_variables)
        del tape
        return loss, g

    def solve_with_TFoptimizer(self, optimizer, X, xmax, rt1, rt2, u, N=1001):
        """Solve using TensorFlow optimizer."""
        @tf.function
        def train_step():
            loss, grad_theta = self.get_grad(X, xmax, rt1, rt2, u)
            optimizer.apply_gradients(zip(grad_theta, self.model.trainable_variables))
            return loss

        for i in range(N):
            loss = train_step()
            self.current_loss = loss.numpy()
            self.callback()

    def callback(self, xr=None):
        """Callback function for training progress."""
        if self.iter % 50 == 0:
            print(f'It {self.iter:05d}: loss = {self.current_loss:10.8e}')
        self.hist.append(self.current_loss)
        self.iter += 1


# =============================================================================
# Fokker-Planck PINN Solver
# =============================================================================

class F_P_PINNSolver(PINNSolver):
    """PINN solver for Fokker-Planck equation."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def get_r(self):
        """Compute PDE residual with persistent gradient tape."""
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(self.t)
            tape.watch(self.x)

            u = self.model(tf.stack([self.t[:, 0], self.x[:, 0]], axis=1))
            u_x = tape.gradient(u, self.x)

        u_t = tape.gradient(u, self.t)
        u_xx = tape.gradient(u_x, self.x)

        del tape
        return self.fun_r(self.t, self.x, u, u_t, u_x, u_xx)


class F_P_PINNIdentification(F_P_PINNSolver):
    """Fokker-Planck PINN for parameter identification."""

    def fun_r(self, t, x, u, u_t, u_x, u_xx):
        """Residual of the time-dependent Fokker-Planck equation."""
        muc1 = tf.nn.softplus(self.model.mu)
        mut = self.model.a1 * tf.math.exp(-t / self.model.tau) * (t * math.exp(1) / self.model.tau) * (1 / (t + 1e-4) - 1 / self.model.tau)

        return u_t + (muc1 + mut) * u_x - 0.5 * (self.model.sig)**2 * u_xx

    def callback(self, xr=None):
        """Enhanced callback with parameter tracking."""
        mu = tf.nn.softplus(self.model.mu).numpy()
        a_1 = self.model.a1.numpy()
        tau_0 = self.model.tau.numpy()
        sigma = self.model.sig.numpy()
        tt0 = (tf.math.sigmoid(self.model.tt0) * self.minRT).numpy()

        # Track parameter history
        self.model.mu_list.append(mu)
        self.model.a1_list.append(a_1)
        self.model.tau_list.append(tau_0)
        self.model.sig_list.append(sigma)
        self.model.tt0_list.append(tt0)

        if self.iter % 50 == 0:
            print(f'It {self.iter:05d}: loss = {self.current_loss:10.8e} '
                  f'drift0 = {mu:10.8e} A = {a_1:10.8e} tau = {tau_0:10.8e} '
                  f'sigma = {sigma:10.8e} t0 = {tt0:10.8e}')

        self.hist.append(self.current_loss)
        self.iter += 1

    # Parameter getter methods
    def mu_f(self):
        return tf.nn.softplus(self.model.mu).numpy()

    def a1_f(self):
        return self.model.a1.numpy()

    def tau_f(self):
        return self.model.tau.numpy()

    def sigma_f(self):
        return self.model.sig.numpy()

    def tt0_f(self):
        return (tf.math.sigmoid(self.model.tt0) * self.minRT).numpy()


# =============================================================================
# Utility Functions
# =============================================================================

def dirac_delta_function(x, delta, x0):
    """Approximate Dirac delta function as Gaussian distribution."""
    return 1 / (2 * np.sqrt(np.pi * delta)) * tf.math.exp(-((x - x0)**2) / (4 * delta))


def initial_condition(x):
    """Define initial condition for Fokker-Planck equation."""
    return dirac_delta_function(x, DELTA, 0.0)


# =============================================================================
# Main Execution
# =============================================================================

def main():
    """Main function to run parameter estimation."""
    print("TINN Parameter Estimation for Time-Dependent DDM")
    print("=" * 50)

    # Generate synthetic data
    print("Generating synthetic DDM data with time-dependent drift...")
    decision_times1, decision_times2 = simulate_time_dependent_ddm(
        NUM_SIMULATIONS, TIME_STEPS, BOUNDARY_SEPARATION, NOISE_STD, DT, A, MU_C, TAU, ALPHA)

    # Add non-decision time
    rt1 = np.sort(decision_times1) + NON_DECISION_TIME
    rt2 = np.sort(decision_times2) + NON_DECISION_TIME

    print(f"Generated {len(rt1)} upper and {len(rt2)} lower boundary crossings")

    # Prepare domain
    maxrt1 = max(rt1)
    lb = tf.constant([0, XMIN], dtype='float32')
    ub = tf.constant([maxrt1, XMAX], dtype='float32')

    # Create collocation points
    tspace = np.linspace(lb[0], ub[0], N_COLLOCATION + 1)
    xspace = np.linspace(lb[1], ub[1], N_COLLOCATION + 1)
    T, X = np.meshgrid(tspace, xspace)
    Xgrid = np.vstack([T.flatten(), X.flatten()]).T
    Xgrid = tf.constant(Xgrid, 'float32')

    # Prepare boundary conditions
    t_b = tf.random.uniform((N_BOUNDARY, 1), lb[0], ub[0], dtype='float32')
    x_b = lb[1] + (ub[1] - lb[1]) * tf.keras.backend.random_bernoulli((N_BOUNDARY, 1), 0.5, dtype='float32')
    X_b = tf.concat([t_b, x_b], axis=1)
    u_b = tf.zeros(tf.shape(x_b), 'float32')

    # Prepare initial condition
    t_0 = tf.ones((N_INITIAL, 1), dtype='float32') * lb[0]
    x_0 = np.linspace(lb[1], ub[1], N_INITIAL, dtype='float32')
    x_0 = tf.convert_to_tensor(x_0, dtype='float32')
    x_0 = tf.reshape(x_0, [N_INITIAL, 1])
    u_0 = initial_condition(x_0)
    X_0 = tf.concat([t_0, x_0], axis=1)

    # Collect training data
    X_param = [X_0, X_b]
    u_param = [u_0, u_b]

    # Initialize model and solver
    model = PINNIdentificationNet(lb, ub,
                                 num_hidden_layers=NUM_HIDDEN_LAYERS,
                                 num_neurons_per_layer=NUM_NEURONS_PER_LAYER,
                                 activation='tanh',
                                 kernel_initializer='glorot_normal')
    model.build(input_shape=(None, 2))

    solver = F_P_PINNIdentification(model, Xgrid)

    # Set up optimizer
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000, 25000], [0.01, 0.001, 0.0005])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

    # Train the model
    print("Starting TINN training...")
    t0 = time()
    solver.solve_with_TFoptimizer(optimizer, X_param, XMAX, rt1, rt2, u_param, N=NUM_EPOCHS)
    training_time = time() - t0
    print(f'\nComputation time: {training_time:.2f} seconds')

    # Display results
    print("\nTrue parameters:")
    print(f"  mu_c = {MU_C:.4f}, A = {A:.4f}, tau = {TAU:.4f}, sigma = {NOISE_STD:.4f}, t0 = {NON_DECISION_TIME:.4f}")

    print("\nEstimated parameters:")
    print(f"  mu_c = {solver.mu_f():.4f}, A = {solver.a1_f():.4f}, tau = {solver.tau_f():.4f}, "
          f"sigma = {solver.sigma_f():.4f}, t0 = {solver.tt0_f():.4f}")

    print("\nParameter estimation completed successfully!")


if __name__ == "__main__":
    main()

TINN Parameter Estimation for Time-Dependent DDM
Generating synthetic DDM data with time-dependent drift...
Generated 1945 upper and 55 lower boundary crossings
Starting TINN training...




It 00000: loss = 1.00310333e+02 drift0 = 1.30596101e+00 A = 1.49899998e+01 tau = 1.00010002e+02 sigma = 9.90000129e-01 t0 = 7.33020127e-01
It 00050: loss = 8.28577518e+00 drift0 = 1.18844807e+00 A = 1.48238039e+01 tau = 1.00174683e+02 sigma = 1.07543075e+00 t0 = 7.51646817e-01
It 00100: loss = 2.32839966e+00 drift0 = 9.01532173e-01 A = 1.43684969e+01 tau = 1.00630516e+02 sigma = 5.98695278e-01 t0 = 7.44464993e-01
It 00150: loss = 9.45198834e-01 drift0 = 6.73040509e-01 A = 1.39097481e+01 tau = 1.01080048e+02 sigma = 5.10393500e-01 t0 = 7.22146988e-01
It 00200: loss = 5.69334388e-01 drift0 = 5.70665956e-01 A = 1.36189528e+01 tau = 1.01358330e+02 sigma = 4.59266335e-01 t0 = 7.03143120e-01
It 00250: loss = 4.54322994e-01 drift0 = 5.18520415e-01 A = 1.33967581e+01 tau = 1.01566628e+02 sigma = 4.52559859e-01 t0 = 6.81832612e-01
It 00300: loss = 4.02258396e-01 drift0 = 4.89319086e-01 A = 1.32095413e+01 tau = 1.01738991e+02 sigma = 4.67434376e-01 t0 = 6.56294107e-01
It 00350: loss = 3.68340850