# DDM Simulation with TensorFlow Profiling

This notebook demonstrates how to profile a Drift-Diffusion Model (DDM) simulation in TensorFlow by tracing its computation graph, and how to visualize that graph with TensorBoard.

**Import libraries**

In [None]:
import tensorflow as tf
from datetime import datetime
import os
import numpy as np
from pathlib import Path

# Report which TensorFlow build is in use and whether it can see at least one GPU.
print("TensorFlow version:", tf.__version__)
_visible_gpus = tf.config.list_physical_devices('GPU')
print("GPU available:", bool(_visible_gpus))

## TensorFlow DDM simulation function

In [None]:
def ddm_tensorflow(n_trials=1000, max_t=10.0, dt=0.001, drift=0.5, noise=0.1, bound=1.0):
    """
    Drift Diffusion Model (DDM) simulation in TensorFlow.

    Each trial accumulates evidence X(t) = sum(v*dt + sigma*sqrt(dt)*N(0, 1)),
    starting from X(0) = 0. The decision is made at the *first* time the
    evidence reaches +bound or -bound (first-passage time); whatever the
    evidence does after that step is ignored.

    Parameters:
    -----------
    n_trials : int
        Number of simulated trials
    max_t : float
        Maximum decision time in seconds; trials that never reach a bound
        time out at this value
    dt : float
        Time step size in seconds (must be positive)
    drift : float
        Drift rate (evidence accumulation rate)
    noise : float
        Noise standard deviation
    bound : float
        Decision boundary; the upper bound is +bound, the lower is -bound

    Returns:
    --------
    A `tf.function` taking no arguments that returns (responses, reaction_times):
        responses      : int32 tensor of shape [n_trials];
                         1 = upper bound reached first,
                         0 = lower bound reached first,
                         -1 = no decision within max_t
        reaction_times : float32 tensor of shape [n_trials];
                         first-crossing step * dt, or max_t for undecided trials

    Raises:
    -------
    ValueError
        If dt is not positive or max_t is negative.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    if max_t < 0:
        raise ValueError("max_t must be non-negative")

    # round() rather than plain int(): e.g. 0.3 / 0.1 == 2.9999999999999996,
    # which int() would truncate to 2 steps instead of 3.
    n_steps = int(round(max_t / dt))

    @tf.function
    def simulate():
        # Noise increments: sigma * sqrt(dt) * N(0, 1) for every trial and step.
        noise_terms = noise * tf.sqrt(dt) * tf.random.normal([n_trials, n_steps])

        # The drift increment v*dt is the same scalar at every step, so it is
        # broadcast instead of materializing another [n_trials, n_steps] tensor.
        evidence = tf.cumsum(noise_terms + drift * dt, axis=1)

        # Prepend X(0) = 0 so column k holds the evidence after k steps (time k*dt).
        evidence = tf.pad(evidence, [[0, 0], [1, 0]], constant_values=0.0)

        # Boundary crossings at every step (not yet restricted to the first one).
        crossed_up = evidence >= bound
        crossed_down = evidence <= -bound

        # First-passage index per boundary: the smallest step index where the
        # boundary is reached, or the sentinel `never` if it is never reached.
        # reduce_min is used because tf.argmax does not guarantee which index
        # it returns on ties.
        never = n_steps + 1
        step_idx = tf.range(never, dtype=tf.int32)
        first_up = tf.reduce_min(tf.where(crossed_up, step_idx, never), axis=1)
        first_down = tf.reduce_min(tf.where(crossed_down, step_idx, never), axis=1)
        first_any = tf.minimum(first_up, first_down)
        decided = first_any < never

        # Reaction time = first crossing step * dt, or max_t on timeout.
        rts = tf.where(
            decided,
            dt * tf.cast(first_any, tf.float32),
            tf.constant(max_t, dtype=tf.float32),
        )

        # The response is whichever boundary was reached FIRST. A later crossing
        # of the opposite bound does not change it. If both are reached at the
        # same step (only possible when bound <= 0), report the upper bound.
        responses = tf.where(
            decided,
            tf.cast(first_up <= first_down, tf.int32),
            tf.constant(-1, dtype=tf.int32),
        )

        return responses, rts

    return simulate


# TensorFlow DDM function with name scopes (groups the traced graph into readable sections in TensorBoard; this cell redefines `ddm_tensorflow` from the cell above)

In [None]:
def ddm_tensorflow(n_trials=1000, max_t=10.0, dt=0.001, drift=0.5, noise=0.1, bound=1.0):
    """
    TensorFlow DDM implementation with scoped operations for better visualization.

    Every op is placed under a `tf.name_scope`, so the traced graph groups into
    labelled sections in TensorBoard. Each trial accumulates evidence
    X(t) = sum(v*dt + sigma*sqrt(dt)*N(0, 1)) from X(0) = 0. The decision is
    the FIRST bound reached (+bound or -bound); later evidence is ignored.

    Parameters:
    -----------
    n_trials : int
        Number of simulated trials
    max_t : float
        Maximum decision time in seconds (timeout value for undecided trials)
    dt : float
        Time step size in seconds (must be positive)
    drift : float
        Drift rate
    noise : float
        Noise standard deviation
    bound : float
        Decision boundary magnitude

    Returns:
    --------
    A `tf.function` taking no arguments that returns (responses, reaction_times):
        responses      : int32 [n_trials]; 1 = upper first, 0 = lower first,
                         -1 = no decision within max_t
        reaction_times : float32 [n_trials]; first-crossing step * dt, or max_t

    Raises:
    -------
    ValueError
        If dt is not positive or max_t is negative.
    """
    if dt <= 0:
        raise ValueError("dt must be positive")
    if max_t < 0:
        raise ValueError("max_t must be non-negative")

    # round() avoids truncation from float error (e.g. 0.3 / 0.1 -> 2.999...).
    n_steps = int(round(max_t / dt))

    @tf.function
    def simulate():
        with tf.name_scope('DDM_Simulation'):
            # Named constants so the parameters show up as labelled graph nodes.
            with tf.name_scope('Parameters'):
                n_trials_const = tf.constant(n_trials, dtype=tf.int32, name='n_trials')
                n_steps_const = tf.constant(n_steps, dtype=tf.int32, name='n_steps')
                dt_const = tf.constant(dt, dtype=tf.float32, name='dt')
                drift_const = tf.constant(drift, dtype=tf.float32, name='drift')
                noise_const = tf.constant(noise, dtype=tf.float32, name='noise')
                bound_const = tf.constant(bound, dtype=tf.float32, name='bound')
                max_t_const = tf.constant(max_t, dtype=tf.float32, name='max_t')

            # Noise increments sigma * sqrt(dt) * N(0, 1) for all trials and steps.
            with tf.name_scope('Noise_Generation'):
                shape = tf.stack([n_trials_const, n_steps_const], name='shape')
                noise_terms = noise_const * tf.sqrt(dt_const) * tf.random.normal(shape)

            # Per-step drift v * dt. It is a scalar that broadcasts over the noise,
            # so no second [n_trials, n_steps] tensor is allocated.
            with tf.name_scope('Drift_Terms'):
                drift_step = drift_const * dt_const

            # Accumulate evidence over time: X(t) = sum(v*dt + sigma*dW).
            with tf.name_scope('Evidence_Accumulation'):
                evidence = tf.cumsum(noise_terms + drift_step, axis=1)

            # Prepend X(0) = 0 so column k is the evidence after k steps.
            with tf.name_scope('Initial_Condition'):
                evidence = tf.pad(evidence, [[0, 0], [1, 0]], constant_values=0.0)

            # Crossings of each bound at every step.
            with tf.name_scope('Boundary_Checks'):
                crossed_up = evidence >= bound_const
                crossed_down = evidence <= -bound_const

            # First step at which each bound is reached (`never` if it is not).
            # reduce_min gives a well-defined first index; tf.argmax does not
            # guarantee which index it returns on ties.
            with tf.name_scope('First_Passage'):
                never = n_steps_const + 1
                step_idx = tf.range(never)
                first_up = tf.reduce_min(tf.where(crossed_up, step_idx, never), axis=1)
                first_down = tf.reduce_min(tf.where(crossed_down, step_idx, never), axis=1)
                first_any = tf.minimum(first_up, first_down)
                decided = first_any < never

            # Reaction time = first crossing step * dt, or max_t on timeout.
            with tf.name_scope('Response_Time_Calculation'):
                rts = tf.where(
                    decided,
                    dt_const * tf.cast(first_any, tf.float32),
                    max_t_const,
                )

            # Response = bound reached FIRST (upper wins a same-step tie, which
            # only happens when bound <= 0); -1 when no bound is reached.
            with tf.name_scope('Response_Determination'):
                responses = tf.where(
                    decided,
                    tf.cast(first_up <= first_down, tf.int32),
                    tf.constant(-1, dtype=tf.int32),
                )
        return responses, rts
    return simulate


# Run the simulation with graph tracing
Create the simulation function, set up a timestamped log directory, trace one run, and export the graph for TensorBoard.

In [None]:
# Cell 3: Run and Visualize
# Runs the DDM once while tracing its graph, writes the graph to a timestamped
# TensorBoard log directory, and prints summary statistics of the results.
from datetime import datetime
import os
import numpy as np

# Simulation settings. max_t is needed again for the summary below, so it is
# defined here and passed explicitly (previously it was referenced but never
# defined in this cell, which raised a NameError).
n_trials = 10000
max_t = 10.0

# Clear any previous runs
tf.keras.backend.clear_session()

# Create a log directory with timestamp
log_dir = os.path.join("tf_logs", f"ddm_{datetime.now().strftime('%Y%m%d_%H%M%S')}")
os.makedirs(log_dir, exist_ok=True)

# Get the simulation function
simulate = ddm_tensorflow(n_trials=n_trials, max_t=max_t)

# Turn tracing on BEFORE the first call so that the tf.function trace is recorded.
tf.summary.trace_on(graph=True)
responses, rts = simulate()

# Save the graph, then close the writer so the event file handle is released.
writer = tf.summary.create_file_writer(log_dir)
with writer.as_default():
    tf.summary.trace_export(name="DDM_Simulation", step=0)
tf.summary.trace_off()
writer.close()

# Convert results to numpy arrays
responses_np = responses.numpy()
rts_np = rts.numpy()

# Decided trials are identified by response rather than `rts < max_t`, so a
# crossing on the very last step (rt == max_t) still counts as a decision.
decided = responses_np != -1

# Print results
print("DDM Simulation Complete!")
print(f"- Total trials: {len(responses_np)}")
if decided.any():
    print(f"- Mean response time: {rts_np[decided].mean():.3f} seconds")
else:
    # Avoids taking the mean of an empty array (NaN plus a RuntimeWarning).
    print("- Mean response time: n/a (no trial reached a bound)")
print(f"- Upper bound choices: {np.mean(responses_np == 1)*100:.1f}%")
print(f"- Lower bound choices: {np.mean(responses_np == 0)*100:.1f}%")
print(f"- No decision: {np.mean(responses_np == -1)*100:.1f}%")
print("\nView the computation graph with:")
print(f"tensorboard --logdir {log_dir}")
print("In TensorBoard, look for the 'DDM_Simulation' graph")
