# In [None]:  (Jupyter cell prompt left over from the notebook export)
"""
Theoretical-Informed Neural Network (TINN) for Cognitive Model Parameter Estimation
from Experimental Reaction Time Data

Q1 Journal Paper: Multi-condition parameter estimation using physics-informed deep learning
Experimental data from FDBNCRW2008 dataset with speed, neutral, and accuracy conditions
"""

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from time import time
from mpl_toolkits.mplot3d import Axes3D
import scipy.optimize
import pandas as pd
from scipy.stats import gaussian_kde
from scipy.optimize import minimize

# =============================================================================
# Configuration
# =============================================================================

# TINN Architecture
# These values are handed to TINN_CognitiveNet in main(). The network builds
# NUM_HIDDEN_LAYERS - 1 dense layers using ACTIVATION / KERNEL_INITIALIZER,
# followed by one softplus layer and a linear output layer.
NUM_HIDDEN_LAYERS = 4
NUM_NEURONS_PER_LAYER = 30
ACTIVATION = 'tanh'
KERNEL_INITIALIZER = 'glorot_normal'
# NOTE(review): OUTPUT_DIM is not referenced anywhere else in the visible code;
# main() relies on the model's own default of output_dim=3.
OUTPUT_DIM = 3  # For three experimental conditions

# Training Parameters
# NUM_EPOCHS is the number of optimizer steps main() requests from solver.fit().
NUM_EPOCHS = 20000
# Step boundaries and per-interval learning rates given to
# PiecewiseConstantDecay in TINN_CognitiveSolver.fit(); LEARNING_RATES holds
# one more value than LEARNING_RATE_SCHEDULE (one rate per interval).
LEARNING_RATE_SCHEDULE = [1000, 15000]
LEARNING_RATES = [0.05, 0.001, 0.0005]

# Domain Parameters
# THRESHOLD is the default half-width of the x domain in
# prepare_training_data(): x spans [-THRESHOLD, THRESHOLD].
THRESHOLD = 1.0
# The collocation grid uses N_COLLOCATION + 1 points along each of t and x.
N_COLLOCATION = 100
# Default sample counts for the boundary and initial condition points built
# by _prepare_boundary_conditions().
N_BOUNDARY = 500
N_INITIAL = 50

# =============================================================================
# TINN Architecture for Multi-Condition Cognitive Modeling
# =============================================================================

class TINN_CognitiveNet(tf.keras.Model):
    """
    Dense network with trainable cognitive-model scalars attached.

    The forward pass maps a batch of inputs to ``output_dim`` columns (one per
    experimental condition by default). Besides the layer weights, the model
    owns five scalar trainable weights: a shared drift parameter, one noise
    parameter per condition, and a non-decision parameter. All five start at 1.

    ``lb`` and ``ub`` are stored unchanged so that other code can read the
    domain bounds from the model.
    """

    def __init__(self, lb, ub, output_dim=3, num_hidden_layers=5,
                 num_neurons_per_layer=50, activation='tanh',
                 kernel_initializer='glorot_normal', **kwargs):
        super().__init__(**kwargs)

        self.num_hidden_layers = num_hidden_layers
        self.output_dim = output_dim
        self.lb = lb
        self.ub = ub

        # Layer stack: (num_hidden_layers - 1) layers with the requested
        # activation/initializer, then one softplus layer, then a linear head.
        dense_stack = []
        for _ in range(self.num_hidden_layers - 1):
            dense_stack.append(
                tf.keras.layers.Dense(
                    num_neurons_per_layer,
                    activation=tf.keras.activations.get(activation),
                    kernel_initializer=kernel_initializer,
                )
            )
        self.hidden_layers = dense_stack
        self.final_hidden = tf.keras.layers.Dense(
            num_neurons_per_layer, activation='softplus'
        )
        self.output_layer = tf.keras.layers.Dense(output_dim)

        # Scalar cognitive-model parameters, created in a fixed order and
        # exposed as attributes named after the weight.
        scalar_weights = (
            "drift_param",
            "noise_speed",
            "noise_neutral",
            "noise_accuracy",
            "non_decision_param",
        )
        for weight_name in scalar_weights:
            setattr(
                self,
                weight_name,
                self.add_weight(
                    name=weight_name, initializer="ones",
                    trainable=True, dtype=tf.float32
                ),
            )

        # Per-epoch parameter traces, filled in by the solver during training.
        history_attributes = (
            "drift_history",
            "noise_speed_history",
            "noise_neutral_history",
            "noise_accuracy_history",
            "non_decision_history",
        )
        for history_name in history_attributes:
            setattr(self, history_name, [])

    def call(self, X):
        """Run X through every hidden layer, the softplus layer and the output layer."""
        activations = X
        for dense in self.hidden_layers:
            activations = dense(activations)
        return self.output_layer(self.final_hidden(activations))

    @property
    def drift_rate(self):
        """Drift rate: softplus of the raw drift parameter, hence always positive."""
        return tf.math.softplus(self.drift_param)

    @property
    def non_decision_time(self):
        """Sigmoid of the raw non-decision parameter, i.e. a value in (0, 1)."""
        return tf.sigmoid(self.non_decision_param)

# =============================================================================
# TINN Cognitive Model Solver
# =============================================================================

class TINN_CognitiveSolver:
    """
    Solver for cognitive model parameter estimation using TINN.

    The training loss is the sum of
      * the mean squared PDE residual at the collocation points (one residual
        per condition, each with its own noise parameter),
      * the mean squared error on the initial-condition and boundary points,
      * for every condition and response type, the mean squared difference
        between the model's flux at the matching boundary and a Gaussian KDE
        of the observed reaction times.

    ``reaction_times_dict`` must provide the keys ``speed_correct``,
    ``speed_incorrect``, ``neutral_correct``, ``neutral_incorrect``,
    ``accuracy_correct`` and ``accuracy_incorrect`` before a loss is
    evaluated; each maps to a sequence of reaction times (possibly empty).
    """

    def __init__(self, model, collocation_points, reaction_times_dict):
        """
        Args:
            model: network object providing ``drift_rate``,
                ``non_decision_time``, the three ``noise_*`` weights, ``lb``,
                ``ub`` and ``trainable_variables``, and callable on a batch of
                (t, x) rows.
            collocation_points: 2-D array or tensor; column 0 is used as t and
                column 1 as x.
            reaction_times_dict: mapping from condition key to reaction times.
        """
        self.model = model
        self.t = collocation_points[:, 0:1]
        self.x = collocation_points[:, 1:2]
        self.reaction_times = reaction_times_dict

        # Training history
        self.loss_history = []
        self.iteration = 0

        # Known from construction on, so values derived from it (for example
        # the reported non-decision time) are meaningful before the first
        # loss evaluation as well.
        self.min_reaction_time = self._find_min_reaction_time()

        # (output_index, upper_boundary) -> (data object, length, targets).
        # The KDE targets depend only on the data, not on the network, so they
        # are built once instead of on every training step.
        self._target_cache = {}

    def _find_min_reaction_time(self):
        """Smallest reaction time over all conditions, or 0.1 if there are none."""
        return min(
            (rt for times in self.reaction_times.values() for rt in times),
            default=0.1,
        )

    def compute_pde_residuals(self):
        """
        Compute PDE residuals for all three conditions using automatic differentiation.

        Returns:
            Tuple ``(speed, neutral, accuracy)`` of residual tensors evaluated
            at the collocation points.
        """
        with tf.GradientTape(persistent=True) as tape:
            tape.watch(self.t)
            tape.watch(self.x)

            # Compute solutions and gradients for all conditions
            s = self.model(tf.stack([self.t[:, 0], self.x[:, 0]], axis=1))
            u_speed = s[:, 0:1]
            u_neutral = s[:, 1:2]
            u_accuracy = s[:, 2:3]

            # First x-derivatives are taken while the tape is still recording
            # so that they can be differentiated again below.
            u_speed_x = tape.gradient(u_speed, self.x)
            u_neutral_x = tape.gradient(u_neutral, self.x)
            u_accuracy_x = tape.gradient(u_accuracy, self.x)

        u_speed_t = tape.gradient(u_speed, self.t)
        u_speed_xx = tape.gradient(u_speed_x, self.x)
        u_neutral_t = tape.gradient(u_neutral, self.t)
        u_neutral_xx = tape.gradient(u_neutral_x, self.x)
        u_accuracy_t = tape.gradient(u_accuracy, self.t)
        u_accuracy_xx = tape.gradient(u_accuracy_x, self.x)

        # A persistent tape holds its resources until it is released.
        del tape

        # Return residuals for all conditions
        return (
            self._pde_residual(u_speed, u_speed_t, u_speed_x, u_speed_xx, self.model.noise_speed),
            self._pde_residual(u_neutral, u_neutral_t, u_neutral_x, u_neutral_xx, self.model.noise_neutral),
            self._pde_residual(u_accuracy, u_accuracy_t, u_accuracy_x, u_accuracy_xx, self.model.noise_accuracy)
        )

    def _pde_residual(self, u, u_t, u_x, u_xx, noise_param):
        """
        Fokker-Planck equation residual.

        Evaluates ``u_t + drift * u_x - 0.5 * noise**2 * u_xx``; ``u`` itself
        does not enter the expression.
        """
        return u_t + self.model.drift_rate * u_x - 0.5 * (noise_param**2) * u_xx

    def loss_function(self, boundary_data, boundary_values):
        """
        Comprehensive loss function combining PDE residuals and data fitting.

        Args:
            boundary_data: ``[X_initial, X_boundary]`` input points.
            boundary_values: ``[u_initial, u_boundary]`` target values for
                those points.

        Returns:
            Scalar tensor: PDE + initial + boundary + data-fitting loss.
        """
        # PDE residual losses
        residual_speed, residual_neutral, residual_accuracy = self.compute_pde_residuals()
        pde_loss = (
            tf.reduce_mean(tf.square(residual_speed)) +
            tf.reduce_mean(tf.square(residual_neutral)) +
            tf.reduce_mean(tf.square(residual_accuracy))
        )

        # Initial and boundary condition losses
        u_pred_initial = self.model(boundary_data[0])
        initial_loss = tf.reduce_mean(tf.square(boundary_values[0] - u_pred_initial))

        u_pred_boundary = self.model(boundary_data[1])
        boundary_loss = tf.reduce_mean(tf.square(boundary_values[1] - u_pred_boundary))

        # Data fitting losses for all conditions
        data_losses = self._compute_data_fitting_losses()

        # Combined loss (the physics terms currently carry a weight of 1.0)
        total_loss = 1.0 * (pde_loss + initial_loss + boundary_loss) + data_losses

        return total_loss

    def _compute_data_fitting_losses(self):
        """
        Compute data fitting losses for all experimental conditions.

        Correct responses are compared with the flux at the upper boundary
        (``model.ub[1]``), incorrect responses with the flux at the lower
        boundary (``model.lb[1]``). Output column 0/1/2 is used for the
        speed/neutral/accuracy condition.
        """
        # Cheap refresh that keeps the value in sync if the dictionary was
        # edited after construction.
        self.min_reaction_time = self._find_min_reaction_time()

        total_data_loss = 0.0
        for output_index, condition in enumerate(('speed', 'neutral', 'accuracy')):
            total_data_loss += self._compute_condition_loss(
                self.reaction_times[f'{condition}_correct'], self.model.ub[1], output_index, True
            )
            total_data_loss += self._compute_condition_loss(
                self.reaction_times[f'{condition}_incorrect'], self.model.lb[1], output_index, False
            )

        return total_data_loss

    def _compute_condition_loss(self, reaction_times, boundary_position, output_index, upper_boundary):
        """
        Compute loss for a single condition.

        Args:
            reaction_times: observed reaction times for one condition and
                response type.
            boundary_position: x position of the boundary to evaluate.
            output_index: network output column for this condition.
            upper_boundary: True for the upper boundary, False for the lower.

        Returns:
            Scalar tensor with the mean squared difference between predicted
            flux and empirical density. Zero when the data cannot support a
            density estimate (empty, a single observation, or all values
            identical).
        """
        targets = self._empirical_targets(reaction_times, output_index, upper_boundary)
        if targets is None:
            return tf.constant(0.0, dtype=tf.float32)
        times_tensor, empirical_tensor = targets

        # Shift observed times by the current non-decision time estimate.
        adjusted_times = times_tensor - self.model.non_decision_time * self.min_reaction_time

        # Compute predicted flux
        predicted_flux = self._compute_boundary_flux(adjusted_times, boundary_position, output_index, upper_boundary)

        return tf.reduce_mean(tf.square(predicted_flux - empirical_tensor))

    def _empirical_targets(self, reaction_times, output_index, upper_boundary):
        """
        Return cached ``(times_tensor, empirical_tensor)`` for one condition.

        The cache entry is rebuilt when a different data object or a different
        number of observations is seen for the same condition. Returns None
        when no density estimate is possible.
        """
        cache_key = (output_index, bool(upper_boundary))
        n_obs = len(reaction_times)
        entry = self._target_cache.get(cache_key)
        if entry is not None and entry[0] is reaction_times and entry[1] == n_obs:
            return entry[2]

        targets = self._build_empirical_targets(reaction_times, upper_boundary)
        # Keeping a reference to the data object makes the identity check safe.
        self._target_cache[cache_key] = (reaction_times, n_obs, targets)
        return targets

    def _build_empirical_targets(self, reaction_times, upper_boundary):
        """Sort the data and evaluate its Gaussian KDE at the observed times."""
        if len(reaction_times) == 0:
            return None

        sorted_times = np.sort(reaction_times)

        # gaussian_kde needs at least two observations with non-zero spread;
        # otherwise it raises. Skip such a condition instead of aborting.
        if len(sorted_times) < 2 or sorted_times[0] == sorted_times[-1]:
            print(
                f"Skipping data-fitting term: {len(sorted_times)} observation(s) "
                "without spread cannot support a density estimate."
            )
            return None

        times_tensor = tf.constant(
            sorted_times.reshape((len(sorted_times), 1)), tf.float32
        )

        # Empirical distribution using KDE
        # NOTE(review): the previous expression multiplied and divided by the
        # same sample count, which cancels. If the density was meant to be
        # scaled by the share of this response type within its condition, the
        # divisor should be the condition's total trial count - confirm.
        kde = gaussian_kde(sorted_times)
        empirical_density = kde(sorted_times)[:, np.newaxis]
        empirical_tensor = tf.convert_to_tensor(empirical_density, dtype=tf.float32)

        # Lower-boundary targets are negated; presumably because the flux
        # expression is negative for probability leaving through the lower
        # boundary - confirm.
        if not upper_boundary:
            empirical_tensor = -empirical_tensor

        return times_tensor, empirical_tensor

    def _compute_boundary_flux(self, times_tensor, boundary_pos, output_index, upper_boundary):
        """
        Compute flux at boundary using finite differences.

        The x-derivative is a one-sided three-point difference that uses the
        boundary value and two points inside the domain (step 0.02 at the
        upper boundary, 0.05 at the lower one). The returned flux is
        ``drift * p - 0.5 * noise**2 * p_x`` at the boundary.
        """
        batch_size = tf.shape(times_tensor)[0]

        # Boundary points
        x_boundary = tf.ones((batch_size, 1), dtype=tf.float32) * boundary_pos
        X_boundary = tf.concat([times_tensor, x_boundary], axis=1)
        p_boundary = self.model(X_boundary)[:, output_index:output_index+1]

        # Points for derivative calculation; the offset points into the domain
        dx = 0.02 if upper_boundary else 0.05
        offset = -dx if upper_boundary else dx

        x_offset1 = tf.ones((batch_size, 1), dtype=tf.float32) * (boundary_pos + offset)
        X_offset1 = tf.concat([times_tensor, x_offset1], axis=1)
        p_offset1 = self.model(X_offset1)[:, output_index:output_index+1]

        x_offset2 = tf.ones((batch_size, 1), dtype=tf.float32) * (boundary_pos + 2 * offset)
        X_offset2 = tf.concat([times_tensor, x_offset2], axis=1)
        p_offset2 = self.model(X_offset2)[:, output_index:output_index+1]

        # Finite difference derivative (backward at the upper boundary,
        # forward at the lower boundary)
        if upper_boundary:
            p_x = (3 * p_boundary - 4 * p_offset1 + p_offset2) / (2 * dx)
        else:
            p_x = (-3 * p_boundary + 4 * p_offset1 - p_offset2) / (2 * dx)

        # Get appropriate noise parameter
        if output_index == 0:
            noise_param = self.model.noise_speed
        elif output_index == 1:
            noise_param = self.model.noise_neutral
        else:
            noise_param = self.model.noise_accuracy

        # Boundary flux
        flux = self.model.drift_rate * p_boundary - 0.5 * (noise_param**2) * p_x

        return flux

    def compute_gradients(self, boundary_data, boundary_values):
        """Return ``(loss, gradients)`` of the loss w.r.t. the model's trainable variables."""
        with tf.GradientTape() as tape:
            loss = self.loss_function(boundary_data, boundary_values)

        gradients = tape.gradient(loss, self.model.trainable_variables)
        return loss, gradients

    def train_step(self, optimizer, boundary_data, boundary_values):
        """Apply one optimizer update and return the loss it was computed from."""
        loss, gradients = self.compute_gradients(boundary_data, boundary_values)
        optimizer.apply_gradients(zip(gradients, self.model.trainable_variables))
        return loss

    def fit(self, boundary_data, boundary_values, epochs=1000):
        """
        Train the TINN model.

        Uses Adam with a piecewise-constant learning rate taken from the
        module-level LEARNING_RATE_SCHEDULE and LEARNING_RATES, and records
        loss and parameter values after every step.
        """
        lr_schedule = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
            LEARNING_RATE_SCHEDULE, LEARNING_RATES
        )
        optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

        print("Starting TINN cognitive model fitting...")
        start_time = time()

        for epoch in range(epochs):
            loss = self.train_step(optimizer, boundary_data, boundary_values)
            self._training_callback(epoch, loss)

        training_time = time() - start_time
        print(f'\nTraining completed in {training_time:.2f} seconds')

    def _training_callback(self, epoch, loss):
        """Record loss and parameter values; print a summary every 100 epochs."""
        current_loss = loss.numpy()
        drift = self.model.drift_rate.numpy()
        noise_speed = self.model.noise_speed.numpy()
        noise_neutral = self.model.noise_neutral.numpy()
        noise_accuracy = self.model.noise_accuracy.numpy()
        # Reported non-decision time is the sigmoid output scaled by the
        # smallest observed reaction time.
        non_decision = self.model.non_decision_time.numpy() * self.min_reaction_time

        # Store parameter history
        self.model.drift_history.append(drift)
        self.model.noise_speed_history.append(noise_speed)
        self.model.noise_neutral_history.append(noise_neutral)
        self.model.noise_accuracy_history.append(noise_accuracy)
        self.model.non_decision_history.append(non_decision)

        if epoch % 100 == 0:
            print(
                f'Epoch {epoch:05d}: Loss = {current_loss:10.8e}, '
                f'Drift = {drift:10.8e}, '
                f'Noise (S/N/A) = ({noise_speed:10.8e}/{noise_neutral:10.8e}/{noise_accuracy:10.8e}), '
                f'Non-decision = {non_decision:10.8e}'
            )

        self.loss_history.append(current_loss)
        self.iteration += 1

# =============================================================================
# Data Preparation and Visualization
# =============================================================================

def load_experimental_data(file_path):
    """
    Load experimental data from RData file.

    Args:
        file_path: path of the RData file to read.

    Returns:
        The object stored under the name ``"data"`` in the file. When
        ``pyreadr`` is not installed, or ``file_path`` does not point to an
        existing file, the synthetic data dictionary from
        ``create_synthetic_data()`` is returned instead and a message is
        printed.

    Raises:
        KeyError: if the file contains no object named ``"data"``.
    """
    try:
        import pyreadr
    except ImportError:
        print("pyreadr not available. Using synthetic data for demonstration.")
        return create_synthetic_data()

    # Local import: only this function needs the module.
    import os

    # Handle a missing file the same way as a missing reader, so an unedited
    # placeholder path does not abort the demonstration run.
    if not os.path.isfile(file_path):
        print(f"Data file not found: {file_path}. Using synthetic data for demonstration.")
        return create_synthetic_data()

    result = pyreadr.read_r(file_path)
    return result["data"]

def create_synthetic_data():
    """
    Generate reproducible placeholder reaction times for all six conditions.

    Seeds NumPy's global random generator with 42, then draws 100 values per
    condition from a gamma distribution (shape 2, condition-specific scale)
    and shifts them by 0.2.

    Returns:
        Dict mapping each ``<condition>_<correct|incorrect>`` key to an array
        of 100 reaction times.
    """
    np.random.seed(42)
    n_trials = 100

    # Draw order matters for reproducibility: keep this sequence.
    gamma_scales = (
        ('speed_correct', 0.3),
        ('speed_incorrect', 0.4),
        ('neutral_correct', 0.35),
        ('neutral_incorrect', 0.45),
        ('accuracy_correct', 0.4),
        ('accuracy_incorrect', 0.5),
    )

    return {
        key: np.random.gamma(2, scale, n_trials) + 0.2
        for key, scale in gamma_scales
    }

def prepare_training_data(reaction_times_dict, threshold=THRESHOLD):
    """
    Build the (t, x) domain, collocation grid and boundary/initial data.

    The x domain is ``[-threshold, threshold]``; the t domain runs from 0 to
    the largest reaction time found in ``reaction_times_dict`` (0 if every
    entry is empty).

    Returns:
        ``(collocation_points, boundary_data, boundary_values, lb, ub)`` where
        ``lb``/``ub`` are ``[t, x]`` lower/upper bounds and the collocation
        points form a regular grid with N_COLLOCATION + 1 points per axis.
    """
    # Spatial extent
    upper_x = threshold
    lower_x = -upper_x

    # Latest observed reaction time sets the end of the time axis
    condition_maxima = [
        max(times) for times in reaction_times_dict.values() if len(times) > 0
    ]
    latest_time = max([0] + condition_maxima)

    lb = tf.constant([0.0, lower_x], dtype=tf.float32)
    ub = tf.constant([latest_time, upper_x], dtype=tf.float32)

    # Regular grid of collocation points, one row per (t, x) pair
    points_per_axis = N_COLLOCATION + 1
    t_axis = np.linspace(lb[0], ub[0], points_per_axis)
    x_axis = np.linspace(lb[1], ub[1], points_per_axis)
    t_mesh, x_mesh = np.meshgrid(t_axis, x_axis)
    grid_rows = np.column_stack((t_mesh.ravel(), x_mesh.ravel()))
    collocation_points = tf.constant(grid_rows, dtype=tf.float32)

    # Initial and boundary condition samples
    boundary_data, boundary_values = _prepare_boundary_conditions(lb, ub)

    return collocation_points, boundary_data, boundary_values, lb, ub

def _prepare_boundary_conditions(lb, ub, n_boundary=N_BOUNDARY, n_initial=N_INITIAL):
    """
    Sample the points and target values for the initial and boundary conditions.

    Initial condition: ``n_initial`` points at ``t = lb[0]`` (an even grid over
    x plus the point x = 0) with a Gaussian-shaped target centred at 0.
    Boundary condition: ``n_boundary`` points with uniformly random t and x
    chosen at random as either the lower or the upper bound, with target 0.

    Returns:
        ``([X_initial, X_boundary], [u_initial, u_boundary])``.
    """
    t_low, x_low = lb[0], lb[1]
    t_high, x_high = ub[0], ub[1]

    def gaussian_profile(x, delta=7.8e-2, x0=0.0):
        # Gaussian bump whose width is controlled by delta
        amplitude = 1 / (2 * np.sqrt(np.pi * delta))
        return amplitude * tf.math.exp(-((x - x0)**2) / (4 * delta))

    # --- Initial condition points ---
    even_grid = np.linspace(x_low, x_high, n_initial - 1, dtype=np.float32)
    grid_with_origin = np.sort(np.concatenate([even_grid, [0.0]]))
    x_initial = tf.constant(grid_with_origin.reshape((n_initial, 1)), dtype=tf.float32)
    t_initial = tf.ones((n_initial, 1), dtype=tf.float32) * t_low
    X_initial = tf.concat([t_initial, x_initial], axis=1)
    u_initial = gaussian_profile(x_initial)

    # --- Boundary condition points ---
    t_boundary = tf.random.uniform((n_boundary, 1), t_low, t_high, dtype=tf.float32)
    # 0 selects the lower bound, 1 the upper bound, each with probability 0.5
    side = tf.keras.backend.random_bernoulli(
        (n_boundary, 1), 0.5, dtype=tf.float32
    )
    x_boundary = x_low + (x_high - x_low) * side
    X_boundary = tf.concat([t_boundary, x_boundary], axis=1)
    u_boundary = tf.zeros_like(x_boundary, dtype=tf.float32)

    return [X_initial, X_boundary], [u_initial, u_boundary]

def plot_training_convergence(solver):
    """
    Draw a 2x2 figure of loss and parameter traces, save it as PDF and show it.

    Panels: training loss (log scale), drift rate, the three noise parameters,
    and non-decision time, each plotted against the epoch index. The figure is
    written to ``tinn_cognitive_convergence.pdf``.
    """
    _fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    loss_ax, drift_ax = axes[0, 0], axes[0, 1]
    noise_ax, non_decision_ax = axes[1, 0], axes[1, 1]
    net = solver.model

    def finish_panel(ax, ylabel, title):
        # Shared axis decoration for all four panels
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)

    # Loss history
    loss_ax.semilogy(solver.loss_history, 'k-', linewidth=2)
    finish_panel(loss_ax, 'Loss', 'Training Loss Convergence')

    # Drift rate convergence
    drift_ax.plot(net.drift_history, 'b-', linewidth=2)
    finish_panel(drift_ax, 'Drift Rate', 'Drift Rate Estimation')

    # Noise parameters convergence
    noise_traces = (
        (net.noise_speed_history, 'r-', 'Speed'),
        (net.noise_neutral_history, 'g-', 'Neutral'),
        (net.noise_accuracy_history, 'b-', 'Accuracy'),
    )
    for trace, line_format, label in noise_traces:
        noise_ax.plot(trace, line_format, linewidth=2, label=label)
    noise_ax.legend()
    finish_panel(noise_ax, 'Noise Parameters', 'Noise Parameters Estimation')

    # Non-decision time convergence
    non_decision_ax.plot(net.non_decision_history, 'm-', linewidth=2)
    finish_panel(non_decision_ax, 'Non-decision Time', 'Non-decision Time Estimation')

    plt.tight_layout()
    plt.savefig("tinn_cognitive_convergence.pdf", format="pdf", bbox_inches="tight")
    plt.show()

# =============================================================================
# Main Execution
# =============================================================================

def main():
    """
    Run the full pipeline: load data, build the model, train, report, plot.

    Reaction times of the first subject are split by instruction condition and
    response correctness when the loaded data exposes a ``subj`` attribute;
    otherwise synthetic data is used.
    """
    separator = "=" * 60
    print("TINN Cognitive Model Parameter Estimation")
    print(separator)

    # Load experimental data
    data_file_path = '.../FDBNCRW2008.RData'
    experimental_data = load_experimental_data(data_file_path)

    # One (initially empty) list per condition and response type
    conditions = ('speed', 'neutral', 'accuracy')
    outcomes = (('correct', True), ('incorrect', False))
    reaction_times = {
        f'{condition}_{label}': []
        for condition in conditions
        for label, _ in outcomes
    }

    if not hasattr(experimental_data, 'subj'):
        # Use synthetic data if experimental data not available
        reaction_times = create_synthetic_data()
    else:
        # Process experimental data (example for first subject)
        subjects = np.unique(experimental_data.subj)
        if len(subjects) > 0:
            first_subject = subjects[0]
            for condition in conditions:
                sub_data = experimental_data.loc[
                    (experimental_data.subj == first_subject) &
                    (experimental_data.instruction == condition)
                ]
                for label, was_correct in outcomes:
                    reaction_times[f'{condition}_{label}'] = list(
                        sub_data.loc[sub_data.correct == was_correct].RT
                    )

    print("Data summary:")
    for key, values in reaction_times.items():
        if len(values) == 0:
            continue
        print(f"  {key}: {len(values)} trials, RT range: {min(values):.3f}-{max(values):.3f}")

    # Prepare training data
    collocation_points, boundary_data, boundary_values, lb, ub = prepare_training_data(reaction_times)

    # Build the network and the solver that trains it
    model = TINN_CognitiveNet(
        lb, ub,
        num_hidden_layers=NUM_HIDDEN_LAYERS,
        num_neurons_per_layer=NUM_NEURONS_PER_LAYER,
        activation=ACTIVATION,
        kernel_initializer=KERNEL_INITIALIZER
    )
    solver = TINN_CognitiveSolver(model, collocation_points, reaction_times)

    # Train model
    solver.fit(boundary_data, boundary_values, epochs=NUM_EPOCHS)

    # Display results
    non_decision_estimate = solver.model.non_decision_time.numpy() * solver.min_reaction_time
    print("\n" + separator)
    print("COGNITIVE MODEL PARAMETER ESTIMATION RESULTS")
    print(separator)
    print(f"Estimated drift rate: {model.drift_rate.numpy():.4f}")
    print(f"Estimated noise (Speed): {model.noise_speed.numpy():.4f}")
    print(f"Estimated noise (Neutral): {model.noise_neutral.numpy():.4f}")
    print(f"Estimated noise (Accuracy): {model.noise_accuracy.numpy():.4f}")
    print(f"Estimated non-decision time: {non_decision_estimate:.4f}")

    # Generate plots
    plot_training_convergence(solver)

    print("\nCognitive model fitting completed successfully!")

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()