# In [None]:
"""
Theoretical-Informed Neural Network (TINN) for Parameter Estimation
in Ornstein-Uhlenbeck Drift Diffusion Model (OU-DDM)
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from time import time
from scipy.stats import gaussian_kde


# =============================================================================
# Configuration and Constants
# =============================================================================

# True parameters for OU-DDM
DRIFT_0 = 0.6              # constant drift term (drift0 in drift_ou)
BETA = 0.3                 # coefficient of -x in the drift (beta in drift_ou)
# The absorbing bounds sit at +BOUNDARY_SEPARATION and -BOUNDARY_SEPARATION
# (see simulate_ou_ddm_rk4), i.e. this is the distance from 0 to EACH bound;
# the total distance between the bounds is 2 * BOUNDARY_SEPARATION.
BOUNDARY_SEPARATION = 2.0
STARTING_POINT = 0.0
NOISE_STD = 1.0
NON_DECISION_TIME = 0.7    # added to every simulated decision time

# Simulation parameters
NUM_SIMULATIONS = 5000
TIME_STEPS = 20000         # per-trial step cap; trials that never cross are dropped
DT = 0.01

# TINN training parameters
NUM_EPOCHS = 30000
NUM_HIDDEN_LAYERS = 4
NUM_NEURONS_PER_LAYER = 50

# Spatial domain of the Fokker-Planck problem. It is derived from
# BOUNDARY_SEPARATION because the boundary flux is matched to the simulated
# first-passage times at x = +/-XMAX, so the PDE domain must end exactly at
# the simulator's absorbing bounds.
XMIN = -BOUNDARY_SEPARATION
XMAX = BOUNDARY_SEPARATION
N_COLLOCATION = 200
N_BOUNDARY = 100
N_INITIAL = 50

# Width parameter of the Gaussian that approximates the initial Dirac delta
# (the Gaussian has variance 2 * DELTA).
DELTA = 7.8e-2

# Loss weights
LAMBDA_PDE = 1.0
LAMBDA_IC = 10.0
LAMBDA_BC = 10.0
LAMBDA_DATA = 100.0


# =============================================================================
# Ornstein-Uhlenbeck DDM Simulation
# =============================================================================

def drift_ou(drift0, beta, xx):
    """Return the OU drift ``drift0 - beta * xx`` at position ``xx``.

    Works element-wise when ``xx`` is an array.
    """
    position_term = beta * xx
    return drift0 - position_term


def simulate_ou_ddm_rk4(num_simulations, time_steps, boundary_separation,
                       starting_point, noise_std, dt, drift0, beta, rng=None):
    """Simulate first-passage times of an OU-DDM between two absorbing bounds.

    Each trial starts at ``starting_point`` and follows

        dX = (drift0 - beta * X) dt + noise_std dW

    until it reaches ``+boundary_separation`` (upper) or
    ``-boundary_separation`` (lower), or until ``time_steps - 1`` steps have
    been taken. The deterministic drift is advanced with a classical RK4
    step; the noise enters once per step as a single Wiener increment with
    variance ``noise_std**2 * dt``. (Drawing a fresh noise sample inside each
    RK4 stage and averaging with weights 1/6, 2/6, 2/6, 1/6 would shrink the
    per-step noise variance to 10/36 of the intended value.)

    Args:
        num_simulations: number of independent trials.
        time_steps: step cap per trial (the loop runs steps 1 .. time_steps-1).
        boundary_separation: distance from 0 to each absorbing bound.
        starting_point: initial value of the decision variable.
        noise_std: diffusion coefficient of the Wiener term.
        dt: integration step size.
        drift0, beta: drift parameters passed to ``drift_ou``.
        rng: optional random source with a ``standard_normal()`` method
            (e.g. ``numpy.random.Generator``). Defaults to the global
            ``numpy.random`` state, so ``np.random.seed`` still applies.

    Returns:
        ``(times_upper, times_lower, trajectories_upper, trajectories_lower)``:
        numpy arrays of crossing times (``step * dt``) and the matching lists
        of visited positions. Trials that never reach a bound are discarded.
    """
    if rng is None:
        rng = np.random
    noise_scale = noise_std * np.sqrt(dt)

    decision_times_upper = []
    decision_times_lower = []
    trajectories_upper = []
    trajectories_lower = []

    for _ in range(num_simulations):
        decision_variable = starting_point
        trajectory = [decision_variable]

        for step in range(1, time_steps):
            x = decision_variable

            # RK4 stages for the deterministic part of the SDE only.
            k1 = drift_ou(drift0, beta, x)
            k2 = drift_ou(drift0, beta, x + 0.5 * dt * k1)
            k3 = drift_ou(drift0, beta, x + 0.5 * dt * k2)
            k4 = drift_ou(drift0, beta, x + dt * k3)
            deterministic_step = dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6

            # Exactly one Wiener increment per step: variance noise_std**2 * dt.
            decision_variable = x + deterministic_step + noise_scale * rng.standard_normal()
            trajectory.append(decision_variable)

            if decision_variable >= boundary_separation:
                decision_times_upper.append(step * dt)
                trajectories_upper.append(trajectory)
                break
            if decision_variable <= -boundary_separation:
                decision_times_lower.append(step * dt)
                trajectories_lower.append(trajectory)
                break
        # A trial that exhausts ``time_steps`` without crossing is dropped.

    return (np.array(decision_times_upper), np.array(decision_times_lower),
            trajectories_upper, trajectories_lower)


# =============================================================================
# Neural Network Architecture
# =============================================================================

class PINN_NeuralNet(tf.keras.Model):
    """Fully connected network used as the PINN trial solution u(t, x).

    Inputs are rescaled column-wise from the box [lb, ub] to [-1, 1] before
    the first hidden layer. Without this, the raw time coordinate (which can
    span many seconds) pushes tanh units into saturation.
    """

    def __init__(self, lb, ub, output_dim=1, num_hidden_layers=5,
                 num_neurons_per_layer=50, activation='tanh',
                 kernel_initializer='glorot_normal', **kwargs):
        """
        Args:
            lb: lower corner of the input domain, one entry per input column.
            ub: upper corner of the input domain; must differ from ``lb`` in
                every column (it is used as a divisor when scaling inputs).
            output_dim: width of the output layer.
            num_hidden_layers: number of hidden Dense layers.
            num_neurons_per_layer: width of each hidden layer.
            activation: activation of the hidden layers.
            kernel_initializer: initializer of the hidden-layer kernels.
            **kwargs: forwarded to ``tf.keras.Model``.
        """
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.lb = lb
        self.ub = ub

        # Define NN architecture
        self.hidden_layers = [
            tf.keras.layers.Dense(num_neurons_per_layer,
                                  activation=activation,
                                  kernel_initializer=kernel_initializer)
            for _ in range(num_hidden_layers)
        ]
        self.out = tf.keras.layers.Dense(output_dim)

    def _scale_inputs(self, X):
        """Affinely map each input column from [lb, ub] to [-1, 1]."""
        lb = tf.cast(self.lb, X.dtype)
        ub = tf.cast(self.ub, X.dtype)
        return 2.0 * (X - lb) / (ub - lb) - 1.0

    def call(self, X):
        """Forward pass: scale inputs to [-1, 1], then apply the dense stack."""
        Z = self._scale_inputs(X)
        for layer in self.hidden_layers:
            Z = layer(Z)
        return self.out(Z)


class PINNIdentificationNet(PINN_NeuralNet):
    """PINN_NeuralNet extended with the trainable OU-DDM parameters.

    Adds scalar float32 variables ``lambd0`` (constant drift), ``beta``
    (coefficient of -x in the drift), ``sig`` (noise scale) and ``tt0``
    (unconstrained value that F_P_PINNIdentification maps to a non-decision
    time via ``sigmoid(tt0) * minRT``), plus one history list per parameter.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        def scalar_param(label, start):
            # Trainable float32 scalar; attribute assignment registers it with Keras.
            return tf.Variable(start, name=label, trainable=True, dtype=tf.float32)

        # Starting guesses for the optimisation.
        self.lambd0 = scalar_param("lambd0", 0.5)
        self.beta = scalar_param("beta", 0.2)
        self.sig = scalar_param("sig", 0.8)
        self.tt0 = scalar_param("tt0", 0.5)

        # Per-iteration parameter histories, filled during training.
        self.lambd0_list, self.beta_list, self.sig_list, self.tt0_list = [], [], [], []


# =============================================================================
# PINN Solver
# =============================================================================

class F_P_PINNIdentification:
    """Fokker-Planck PINN solver that identifies OU-DDM parameters.

    ``model`` approximates the density u(t, x) of the decision variable. The
    training loss (see ``loss_fn``) combines the Fokker-Planck residual on
    the collocation points, initial/boundary-condition misfits, and a match
    between the probability flux through each absorbing boundary and a KDE
    of the observed response times. The non-decision time is parameterised
    as ``sigmoid(model.tt0) * minRT``, so it stays in (0, minRT) and every
    shifted response time is non-negative.
    """

    def __init__(self, model, X_r):
        """
        Args:
            model: network exposing trainable ``lambd0``, ``beta``, ``sig``,
                ``tt0`` and the history lists ``lambd0_list``, ``beta_list``,
                ``sig_list``, ``tt0_list``.
            X_r: collocation points with t in column 0 and x in column 1.
        """
        self.model = model
        self.t = X_r[:, 0:1]
        self.x = X_r[:, 1:2]
        self.hist = []
        self.iter = 0
        self.minRT = None
        self.current_loss = None

    def _non_decision_time(self):
        """Return the constrained non-decision time ``sigmoid(tt0) * minRT``."""
        return tf.math.sigmoid(self.model.tt0) * self.minRT

    def get_r(self):
        """Return the Fokker-Planck residual at the collocation points.

        For dX = (lambd0 - beta*x) dt + sig dW the density obeys
        u_t = -d/dx[(lambd0 - beta*x) u] + 0.5*sig**2 * u_xx, i.e.
        u_t + (lambd0 - beta*x) u_x - beta*u - 0.5*sig**2 * u_xx = 0.
        """
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(self.t)
            tape.watch(self.x)
            u = self.model(tf.concat([self.t, self.x], axis=1))
            # u_x is computed while recording so that its x-derivative exists.
            u_x = tape.gradient(u, self.x)
        u_xx = tape.gradient(u_x, self.x)
        u_t = tape.gradient(u, self.t)
        del tape

        drift = self.model.lambd0 - self.model.beta * self.x
        return (u_t + drift * u_x
                - 0.5 * self.model.sig ** 2 * u_xx
                - self.model.beta * u)

    def loss_fn(self, X, xmax, rt1, rt2, u_data):
        """Return the weighted total loss.

        Args:
            X: ``[X_0, X_b]`` initial-condition and boundary input points.
            xmax: position of the upper absorbing boundary (lower is -xmax).
            rt1, rt2: response times that ended at the upper / lower bound.
            u_data: ``[u_0, u_b]`` targets at ``X_0`` and ``X_b``.

        Raises:
            ValueError: if ``rt1`` and ``rt2`` are both empty.
        """
        if self.minRT is None:
            all_rts = np.concatenate([np.ravel(rt1), np.ravel(rt2)])
            if all_rts.size == 0:
                raise ValueError("loss_fn needs at least one response time in rt1 or rt2")
            # Plain Python float so it combines cleanly with float32 tensors.
            self.minRT = float(np.min(all_rts))

        # 1. PDE residual loss
        loss_pde = tf.reduce_mean(tf.square(self.get_r()))

        # 2. Initial condition loss
        loss_ic = tf.reduce_mean(tf.square(u_data[0] - self.model(X[0])))

        # 3. Boundary condition loss
        loss_bc = tf.reduce_mean(tf.square(u_data[1] - self.model(X[1])))

        # 4. First passage time data loss
        loss_data = self._compute_data_loss(rt1, rt2, xmax)

        return (LAMBDA_PDE * loss_pde
                + LAMBDA_IC * loss_ic
                + LAMBDA_BC * loss_bc
                + LAMBDA_DATA * loss_data)

    def _compute_data_loss(self, rt1, rt2, xmax):
        """Sum the boundary-flux losses at the upper (+xmax) and lower (-xmax) bound.

        A bound with fewer than two response times contributes nothing,
        because ``gaussian_kde`` cannot be fitted to a single sample.
        """
        t0 = self._non_decision_time()
        loss = 0.0
        for rts, x_boundary, upper in ((rt1, xmax, True), (rt2, -xmax, False)):
            tspace = np.sort(np.ravel(rts))
            if tspace.size < 2:
                continue
            # Decision time = RT - t0, non-negative because t0 < minRT <= RT.
            t = tf.constant(tspace.reshape(-1, 1), 'float32') - t0
            loss += self._compute_boundary_flux_loss(t, x_boundary, tspace,
                                                     rt1, rt2, upper=upper)
        return loss

    def _compute_boundary_flux_loss(self, t, x_boundary, tspace, rt1, rt2, upper=True):
        """Return the MSE between predicted and empirical boundary flux.

        Predicted flux: J = (lambd0 - beta*x) u - 0.5*sig**2 * u_x at
        ``x = x_boundary`` and times ``t``. Empirical flux: a KDE of
        ``tspace`` scaled by the fraction of all trials that ended at this
        bound, negated for the lower bound (outflow there points to -x).
        """
        x_points = tf.ones_like(t) * x_boundary

        with tf.GradientTape() as tape:
            tape.watch(x_points)
            # The model input must be built *inside* the tape after watching
            # x_points; otherwise x_points -> u is not recorded and
            # tape.gradient returns None.
            u_boundary = self.model(tf.concat([t, x_points], axis=1))
        u_x = tape.gradient(u_boundary, x_points)

        drift = self.model.lambd0 - self.model.beta * x_boundary
        J_pred = drift * u_boundary - 0.5 * self.model.sig ** 2 * u_x

        fraction = len(tspace) / (len(rt1) + len(rt2))
        sign = 1.0 if upper else -1.0
        J_empirical = sign * fraction * gaussian_kde(tspace)(tspace)
        J_empirical = tf.constant(J_empirical.reshape(-1, 1), dtype=tf.float32)

        return tf.reduce_mean(tf.square(J_pred - J_empirical))

    def get_grad(self, X, xmax, rt1, rt2, u_data):
        """Return ``(loss, grads)`` w.r.t. all trainable variables of the model."""
        with tf.GradientTape() as tape:
            loss = self.loss_fn(X, xmax, rt1, rt2, u_data)

        grads = tape.gradient(loss, self.model.trainable_variables)
        return loss, grads

    def solve_with_TFoptimizer(self, optimizer, X, xmax, rt1, rt2, u_data, N=1001):
        """Run ``N`` optimizer steps, logging via ``_callback``; stop on NaN loss."""
        @tf.function
        def train_step():
            loss, grads = self.get_grad(X, xmax, rt1, rt2, u_data)
            optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
            return loss

        for _ in range(N):
            loss = train_step()
            self.current_loss = loss.numpy()
            self._callback()

            # Early stopping if loss becomes NaN
            if np.isnan(self.current_loss):
                print("Training stopped due to NaN loss")
                break

    def _callback(self):
        """Record parameter and loss history; print progress every 100 iterations."""
        drift0 = self.model.lambd0.numpy()
        beta = self.model.beta.numpy()
        # Only sig**2 enters the model, so the sign of sig is meaningless.
        sigma = abs(self.model.sig.numpy())
        t0 = self._non_decision_time().numpy()

        self.model.lambd0_list.append(drift0)
        self.model.beta_list.append(beta)
        self.model.sig_list.append(sigma)
        self.model.tt0_list.append(t0)

        if self.iter % 100 == 0:
            print(f'Iter {self.iter:05d}: Loss = {self.current_loss:10.8e}')
            print(f'         drift0 = {drift0:.4f}, '
                  f'beta = {beta:.4f}, '
                  f'sigma = {sigma:.4f}, '
                  f't0 = {t0:.4f}')

        self.hist.append(self.current_loss)
        self.iter += 1

    def get_parameters(self):
        """Return the current estimates as plain floats.

        ``sigma`` is reported as ``abs(sig)`` because only ``sig**2`` enters
        the model.

        Raises:
            RuntimeError: if called before ``loss_fn`` has set ``minRT``.
        """
        if self.minRT is None:
            raise RuntimeError("minRT is not set yet; run training (loss_fn) first")
        return {
            'drift0': float(self.model.lambd0.numpy()),
            'beta': float(self.model.beta.numpy()),
            'sigma': abs(float(self.model.sig.numpy())),
            't0': float(self._non_decision_time().numpy()),
        }


# =============================================================================
# Utility Functions
# =============================================================================

def dirac_delta_function(x, delta, x0):
    """Smooth stand-in for a Dirac delta centred at ``x0``.

    Evaluates the normalised Gaussian density with mean ``x0`` and variance
    ``2 * delta``; smaller ``delta`` gives a sharper peak.
    """
    normalizer = 1 / (2 * np.sqrt(np.pi * delta))
    exponent = -((x - x0) ** 2) / (4 * delta)
    return normalizer * tf.math.exp(exponent)


def initial_condition(x, x0=None, delta=None):
    """Initial density p(0, x): a narrow Gaussian approximation of a Dirac delta.

    Args:
        x: evaluation points.
        x0: centre of the delta; defaults to STARTING_POINT so the PDE
            initial condition follows the simulator's starting point.
        delta: width parameter; defaults to DELTA.
    """
    if x0 is None:
        x0 = STARTING_POINT
    if delta is None:
        delta = DELTA
    return dirac_delta_function(x, delta, x0)


def prepare_training_data(lb, ub, n_collocation=100, n_boundary=100, n_initial=50):
    """Build PINN training inputs on the box [lb, ub] (column 0 = t, column 1 = x).

    Returns:
        Xgrid: float32 tensor of a regular (n_collocation + 1)**2 grid of
            collocation points.
        [X_0, X_b]: initial points (t = lb[0], n_initial evenly spaced x) and
            boundary points (n_boundary // 2 uniformly random times on
            x = ub[1], followed by the same number on x = lb[1]).
        [u_0, u_b]: targets -- initial_condition(x) at X_0, zeros at X_b.
    """
    t_lo, x_lo = lb[0], lb[1]
    t_hi, x_hi = ub[0], ub[1]
    per_side = n_boundary // 2

    # Collocation points on a tensor-product grid.
    grid_t, grid_x = np.meshgrid(np.linspace(t_lo, t_hi, n_collocation + 1),
                                 np.linspace(x_lo, x_hi, n_collocation + 1))
    Xgrid = tf.constant(np.column_stack((grid_t.flatten(), grid_x.flatten())), 'float32')

    def sample_wall(x_wall):
        # Uniformly random times paired with one fixed spatial position.
        t_wall = tf.random.uniform((per_side, 1), t_lo, t_hi, dtype='float32')
        x_const = tf.ones((per_side, 1), dtype='float32') * x_wall
        return tf.concat([t_wall, x_const], axis=1)

    # Upper wall is sampled before the lower wall.
    X_b = tf.concat([sample_wall(x_hi), sample_wall(x_lo)], axis=0)
    u_b = tf.zeros((2 * per_side, 1), 'float32')

    # Initial-condition points at the start of the time window.
    t_init = tf.ones((n_initial, 1), dtype='float32') * t_lo
    x_init = tf.reshape(tf.linspace(x_lo, x_hi, n_initial), (n_initial, 1))
    u_0 = initial_condition(x_init)
    X_0 = tf.concat([t_init, x_init], axis=1)

    return Xgrid, [X_0, X_b], [u_0, u_b]


# =============================================================================
# Analysis and Visualization
# =============================================================================

def analyze_results(solver, true_params, rt1, rt2):
    """Print a true-vs-estimated parameter table and plot training diagnostics.

    Args:
        solver: trained solver providing ``get_parameters()``, the loss
            history ``hist`` and per-iteration histories on ``solver.model``
            (``lambd0_list``, ``beta_list``, ``sig_list``, ``tt0_list``).
        true_params: mapping with keys 'drift0', 'beta', 'sigma' and 't0'.
        rt1, rt2: upper/lower response times; not used here, accepted so the
            call signature stays stable.
    """
    estimated = solver.get_parameters()

    print("\n" + "="*60)
    print("PARAMETER ESTIMATION RESULTS")
    print("="*60)
    print(f"{'Parameter':<12} {'True':<10} {'Estimated':<12} {'Error':<10} {'Relative Error':<15}")
    print("-"*60)

    for param in ['drift0', 'beta', 'sigma', 't0']:
        true_val = true_params[param]
        est_val = estimated[param]
        error = abs(true_val - est_val)
        # Build the value and '%' together before padding so the sign sits
        # next to the number; relative error is undefined for a zero target.
        if true_val != 0:
            rel_error = f"{error / abs(true_val) * 100:.2f}%"
        else:
            rel_error = "n/a"

        print(f"{param:<12} {true_val:<10.4f} {est_val:<12.4f} {error:<10.4f} {rel_error:<15}")

    print("="*60)

    # Plot training history
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.semilogy(solver.hist)
    plt.xlabel('Iteration')
    plt.ylabel('Loss')
    plt.title('Training Loss History')
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    histories = (
        ('drift0', solver.model.lambd0_list),
        ('beta', solver.model.beta_list),
        ('sigma', solver.model.sig_list),
        ('t0', solver.model.tt0_list),
    )
    for label, history in histories:
        (line,) = plt.plot(history, label=label)
        # Dashed target line in the same colour as the trace it belongs to.
        plt.axhline(y=true_params[label], color=line.get_color(), linestyle='--', alpha=0.7)
    plt.xlabel('Iteration')
    plt.ylabel('Parameter Value')
    plt.title('Parameter Convergence')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


# =============================================================================
# Main Execution
# =============================================================================

def main():
    """Simulate OU-DDM data, fit the TINN, and report the parameter estimates.

    Raises:
        ValueError: if the PDE domain [XMIN, XMAX] does not coincide with the
            simulator's absorbing bounds at +/-BOUNDARY_SEPARATION.
        RuntimeError: if no simulated trial reached a boundary.
    """
    print("TINN Parameter Estimation for Ornstein-Uhlenbeck DDM")
    print("="*60)

    # Absorbing BCs are imposed at XMIN/XMAX (via lb/ub below) and the flux
    # is matched at x = +/-XMAX, so both must sit where the data were generated.
    if XMIN != -BOUNDARY_SEPARATION or XMAX != BOUNDARY_SEPARATION:
        raise ValueError("XMIN/XMAX must equal -/+BOUNDARY_SEPARATION "
                         f"(got {XMIN}, {XMAX} vs {BOUNDARY_SEPARATION})")

    # True parameters
    true_params = {
        'drift0': DRIFT_0,
        'beta': BETA,
        'sigma': NOISE_STD,
        't0': NON_DECISION_TIME
    }

    # Generate synthetic data
    print("Generating synthetic OU-DDM data...")
    decision_times1, decision_times2, _, _ = simulate_ou_ddm_rk4(
        NUM_SIMULATIONS, TIME_STEPS, BOUNDARY_SEPARATION, STARTING_POINT,
        NOISE_STD, DT, DRIFT_0, BETA)

    # Add non-decision time
    rt1 = decision_times1 + NON_DECISION_TIME
    rt2 = decision_times2 + NON_DECISION_TIME

    print(f"Generated {len(rt1)} upper and {len(rt2)} lower boundary crossings")

    all_rts = np.concatenate([rt1, rt2])
    if all_rts.size == 0:
        raise RuntimeError("No simulated trial reached a decision boundary; "
                           "increase TIME_STEPS or NUM_SIMULATIONS.")

    print(f"Mean RT upper: {np.mean(rt1):.3f}s, lower: {np.mean(rt2):.3f}s")

    # Prepare domain (10% buffer beyond the slowest observed response)
    max_time = float(np.max(all_rts))
    lb = tf.constant([0.0, XMIN], dtype='float32')
    ub = tf.constant([max_time * 1.1, XMAX], dtype='float32')

    # Prepare training data
    Xgrid, X_data, u_data = prepare_training_data(lb, ub, N_COLLOCATION, N_BOUNDARY, N_INITIAL)

    # Initialize model
    model = PINNIdentificationNet(lb, ub,
                                  num_hidden_layers=NUM_HIDDEN_LAYERS,
                                  num_neurons_per_layer=NUM_NEURONS_PER_LAYER,
                                  activation='tanh',
                                  kernel_initializer='glorot_normal')

    # Initialize solver
    solver = F_P_PINNIdentification(model, Xgrid)

    # Set up optimizer with learning rate schedule
    lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
        [1000, 5000, 15000], [1e-2, 5e-3, 1e-3, 5e-4])
    optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule, beta_1=0.9, beta_2=0.999)

    # Train the model
    print("\nStarting TINN training...")
    print("Initial parameters:")
    print(f"  drift0 = {model.lambd0.numpy():.4f}, beta = {model.beta.numpy():.4f}, "
          f"sigma = {model.sig.numpy():.4f}")

    start = time()
    solver.solve_with_TFoptimizer(optimizer, X_data, XMAX, rt1, rt2, u_data, N=NUM_EPOCHS)
    training_time = time() - start

    print(f'\nTraining completed in {training_time:.2f} seconds')

    # Analyze results
    analyze_results(solver, true_params, rt1, rt2)


# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()