In [17]:
# -*- coding: utf-8 -*-
"""
This script provides a complete implementation of the Recurrent Marked Temporal
Point Process (RMTPP) model using Flax's nnx module. It includes data
simulation, model definition, training, and prediction.
"""

import functools
from typing import Dict, Tuple, List

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import seaborn as sns
from flax import nnx
from jax import random
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Constants for data simulation and model configuration
N_SEQS = 200  # Increased for a larger test set
N_MARKS = 5
MAX_SEQ_LEN = 50
HIDDEN_SIZE = 32
LEARNING_RATE = 1e-3
EPOCHS = 200
BATCH_SIZE = 16
TEST_SPLIT_RATIO = 0.2


def simulate_data(
    n_sequences: int, max_len: int, n_marks: int
) -> List[Dict[str, jnp.ndarray]]:
    """
    Simulates marked temporal point process data.

    A simple Hawkes-like process is simulated where certain event marks can
    excite or inhibit other marks, influencing the time to the next event.

    Args:
        n_sequences: The number of sequences to generate.
        max_len: The maximum length of any sequence.
        n_marks: The number of unique event types (marks).

    Returns:
        A list of sequences. Each sequence is a dictionary containing the
        event times, marks, and a mask for padding.
    """
    key = random.PRNGKey(42)
    sequences = []

    # Create an influence matrix: mark_i -> mark_j
    key, subkey = random.split(key)
    influence_matrix = random.uniform(
        subkey, (n_marks, n_marks), minval=-0.5, maxval=1.0
    )

    for _ in range(n_sequences):
        key, subkey_len, subkey_time, subkey_mark = random.split(key, 4)
        seq_len = random.randint(subkey_len, (), minval=10, maxval=max_len)

        times = np.zeros(max_len, dtype=np.float32)
        marks = np.zeros(max_len, dtype=np.int32)

        # Initialize first event
        marks[0] = random.randint(subkey_mark, (), 0, n_marks)
        times[0] = random.exponential(subkey_time, ()) * 0.1

        # Generate subsequent events
        for i in range(1, seq_len):
            key, subkey_time, subkey_mark = random.split(key, 3)

            # Intensity is influenced by the previous mark
            last_mark = marks[i - 1]
            base_rate = 1.0
            # Next event's rate is influenced by all possible next marks
            rates = base_rate + influence_matrix[last_mark, :]
            rates = jnp.maximum(rates, 0.1)  # Ensure positive rates

            # Sample time from exponential dist with the sum of rates
            total_rate = jnp.sum(rates)
            inter_event_time = random.exponential(subkey_time, ()) / total_rate
            times[i] = times[i - 1] + inter_event_time

            # Sample next mark based on individual rates
            probabilities = rates / total_rate
            marks[i] = random.choice(subkey_mark, n_marks, p=probabilities)

        mask = np.arange(max_len) < seq_len
        sequences.append(
            {
                "times": jnp.asarray(times),
                "marks": jnp.asarray(marks),
                "mask": jnp.asarray(mask),
            }
        )
    return sequences


class RMTPP(nnx.Module):
    """
        Recurrent Marked Temporal Point Process (RMTPP) model.

        This class implements the architecture described in the paper, using an RNN
        to learn a representation of event history and predict the time and mark
    of
        the next event.
    """

    def __init__(
        self,
        hidden_size: int,
        n_marks: int,
        *,
        rngs: nnx.Rngs,
    ):
        """
        Initializes the RMTPP model layers.

        Args:
            hidden_size: The dimensionality of the hidden state vector.
            n_marks: The number of unique event marks.
            rngs: JAX random number generators.
        """
        self.embedding = nnx.Embed(
            num_embeddings=n_marks, features=hidden_size, rngs=rngs
        )
        self.recurrent_update_w = nnx.Linear(
            in_features=hidden_size, out_features=hidden_size, use_bias=False, rngs=rngs
        )
        self.recurrent_update_h = nnx.Linear(
            in_features=hidden_size, out_features=hidden_size, use_bias=True, rngs=rngs
        )
        self.recurrent_update_t = nnx.Linear(
            in_features=1, out_features=hidden_size, use_bias=False, rngs=rngs
        )

        # Output layers for intensity function parameters and mark prediction
        self.v_proj = nnx.Linear(hidden_size, 1, use_bias=False, rngs=rngs)
        self.w_proj = nnx.Linear(hidden_size, 1, use_bias=False, rngs=rngs)
        self.b_proj = nnx.Linear(hidden_size, 1, use_bias=True, rngs=rngs)
        self.mark_proj = nnx.Linear(hidden_size, n_marks, use_bias=True, rngs=rngs)

    def __call__(
        self, times: jnp.ndarray, marks: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Performs the forward pass of the model.

        Args:
            times: A (batch, seq_len) array of event times.
            marks: A (batch, seq_len) array of event marks.

        Returns:
            A tuple containing:
            - Predicted logits for the next mark.
            - The 'v' parameter of the intensity function.
            - The 'w' parameter of the intensity function.
            - The 'b' parameter of the intensity function.
        """
        # Calculate inter-event times (d_j = t_j - t_{j-1})
        inter_event_times = jnp.diff(times, prepend=0)
        inter_event_times = jnp.expand_dims(inter_event_times, -1)

        # Scan over the sequence to update the hidden state
        def rnn_step(h_prev, inputs):
            mark_j, time_j = inputs
            w_emb = self.embedding(mark_j)

            # Equation (9) from the paper
            h_j = jax.nn.relu(
                self.recurrent_update_w(w_emb)
                + self.recurrent_update_h(h_prev)
                + self.recurrent_update_t(time_j)
            )
            return h_j, h_j

        # Initialize hidden state
        batch_size = times.shape[0]
        h0 = jnp.zeros((batch_size, self.recurrent_update_h.in_features))

        # Prepare inputs for scan: shape must be (seq_len, batch_size, ...)
        marks_scannable = marks.T
        inter_event_times_scannable = jnp.transpose(inter_event_times, (1, 0, 2))

        # Run the RNN
        scan_fn = nnx.scan(rnn_step)
        _, h_sequence = scan_fn(h0, (marks_scannable, inter_event_times_scannable))
        h_sequence = h_sequence.transpose((1, 0, 2))

        # Project hidden states to get predictions
        mark_logits = self.mark_proj(h_sequence)

        # Stabilize the outputs for v and b to prevent overflow in the exp function
        v_unconstrained = self.v_proj(h_sequence)
        b_unconstrained = self.b_proj(h_sequence)
        v = jax.nn.tanh(v_unconstrained) * 5.0
        b = jax.nn.tanh(b_unconstrained) * 5.0

        # Softplus ensures w > 0, and epsilon prevents division by zero.
        w = jax.nn.softplus(self.w_proj(h_sequence)) + 1e-6

        return mark_logits, v, w, b


def compute_loss(model: RMTPP, batch: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """
    Computes the negative log-likelihood loss for a batch of sequences.

    Args:
        model: The RMTPP model instance.
        batch: A dictionary containing 'times', 'marks', and 'mask'.

    Returns:
        The total loss for the batch.
    """
    times, marks, mask = batch["times"], batch["marks"], batch["mask"]

    # The last valid event in each sequence has no future event to predict.
    # We create a new mask to exclude the last event of each sequence from loss calculation.
    is_last_event = jnp.cumsum(mask, axis=1) == jnp.sum(mask, axis=1, keepdims=True)
    prediction_mask = mask & ~is_last_event

    # Shift targets for prediction: we predict event j+1 from event j
    targets_marks = jnp.roll(marks, -1, axis=1)
    targets_times = jnp.roll(times, -1, axis=1)
    inter_event_times_target = targets_times - times

    mark_logits, v, w, b = model(times, marks)
    v, w, b = v.squeeze(-1), w.squeeze(-1), b.squeeze(-1)

    # 1. Mark Loss (Cross-Entropy)
    mark_loss = optax.softmax_cross_entropy_with_integer_labels(
        mark_logits, targets_marks
    )
    # Apply mask to ignore loss from padded elements and the last real event
    mark_loss = jnp.sum(mark_loss * prediction_mask) / jnp.sum(prediction_mask)

    # 2. Time Loss (Negative Log-Likelihood)
    # Using the more stable formulation: -log(lambda) + integral(lambda)
    log_lambda = v + w * inter_event_times_target + b

    integral_lambda = (1 / w) * (
        jnp.exp(v + w * inter_event_times_target + b) - jnp.exp(v + b)
    )

    time_loss = -log_lambda + integral_lambda
    time_loss = jnp.sum(time_loss * prediction_mask) / jnp.sum(prediction_mask)

    return mark_loss + time_loss


@functools.partial(nnx.jit, static_argnums=(0, 1))
def train_step(
    graphdef: nnx.GraphDef,
    tx: optax.GradientTransformation,
    optimizer_state: optax.OptState,
    params: nnx.State,
    batch: Dict[str, jnp.ndarray],
) -> Tuple[float, optax.OptState, nnx.State]:
    """
    Performs a single training step within a JIT-compiled function.
    This version works with a split model (GraphDef and State) for compatibility
    with JAX transformations.

    Args:
        graphdef: The static graph definition of the model.
        tx: The Optax optimizer.
        optimizer_state: The current state of the optimizer.
        params: The learnable parameters of the model.
        batch: A batch of training data.

    Returns:
        A tuple of the loss, updated optimizer state, and updated parameters.
    """

    def loss_fn(current_params: nnx.State) -> float:
        """Calculates loss for the given parameters."""
        # Reconstruct the model inside the loss function
        model = nnx.merge(graphdef, current_params)
        return compute_loss(model, batch)

    # Use jax.value_and_grad directly on the function of params
    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params)

    updates, optimizer_state = tx.update(grads, optimizer_state, params)
    params = optax.apply_updates(params, updates)
    return loss, optimizer_state, params


@functools.partial(nnx.jit, static_argnums=(0,))
def predict_next_event(
    graphdef: nnx.GraphDef,
    params: nnx.State,
    history_times: jnp.ndarray,
    history_marks: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """
    Predicts the time and mark of the next event given a history.

    Args:
        graphdef: The static graph definition of the model.
        params: The trained parameters of the model.
        history_times: A (seq_len,) array of event times in the history.
        history_marks: A (seq_len,) array of event marks in the history.

    Returns:
        A tuple containing the predicted absolute time, the predicted mark,
        and the intensity parameters (v, w, b) for the last step.
    """
    # Reconstruct the model inside the JIT'd function
    model = nnx.merge(graphdef, params)

    # Add a batch dimension for the model
    times_batch = jnp.expand_dims(history_times, 0)
    marks_batch = jnp.expand_dims(history_marks, 0)

    # Get model outputs for the history
    mark_logits, v, w, b = model(times_batch, marks_batch)

    # Get the outputs corresponding to the *last* event in the sequence
    last_mark_logits = mark_logits[0, -1, :]
    last_v = v[0, -1, 0]
    last_w = w[0, -1, 0]
    last_b = b[0, -1, 0]

    # 1. Predict next mark (the one with the highest probability)
    predicted_mark = jnp.argmax(last_mark_logits)

    # 2. Predict next time
    # Instead of numerical integration for the mean, we compute the median
    # which has a simple analytical form.
    # delta_t = (1/w) * log(1 + (w * log(2)) / exp(v+b))
    A = last_v + last_b
    predicted_inter_event_time = (1 / last_w) * jnp.log(
        1 + (last_w * jnp.log(2.0)) / jnp.exp(A)
    )

    last_time = history_times[-1]
    predicted_time = last_time + predicted_inter_event_time

    return predicted_time, predicted_mark, last_v, last_w, last_b


def train_model(
    graphdef: nnx.GraphDef,
    tx: optax.GradientTransformation,
    params: nnx.State,
    sequences: List[Dict[str, jnp.ndarray]],
) -> Tuple[nnx.State, List[float]]:
    """
    Handles the main training loop for the RMTPP model.

    Args:
        graphdef: The static graph definition of the model.
        tx: The Optax optimizer.
        params: The initial learnable parameters of the model.
        sequences: The list of training sequences.

    Returns:
        A tuple containing the trained parameters and the list of losses per
        epoch.
    """
    print("2. Starting model training...")
    optimizer_state = tx.init(params)
    losses = []
    n_batches = len(sequences) // BATCH_SIZE

    for epoch in range(EPOCHS):
        epoch_loss = 0.0
        # Simple batching
        np.random.shuffle(sequences)
        for i in range(n_batches):
            batch_list = sequences[i * BATCH_SIZE : (i + 1) * BATCH_SIZE]

            # Pad sequences to be the same length for batching
            batch = {
                "times": jnp.stack([s["times"] for s in batch_list]),
                "marks": jnp.stack([s["marks"] for s in batch_list]),
                "mask": jnp.stack([s["mask"] for s in batch_list]),
            }

            loss_val, optimizer_state, params = train_step(
                graphdef, tx, optimizer_state, params, batch
            )
            epoch_loss += loss_val

        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{EPOCHS}, Avg. Loss: {avg_loss:.4f}")

    print("3. Training complete.")
    return params, losses


def evaluate_and_plot(
    graphdef: nnx.GraphDef,
    params: nnx.State,
    test_sequences: List[Dict[str, jnp.ndarray]],
    losses: List[float],
):
    """
    Runs a comprehensive evaluation on the test set and generates result plots.

    Args:
        graphdef: The static graph definition of the model.
        params: The trained parameters of the model.
        test_sequences: A list of sequences for evaluation.
        losses: A list of losses from training for plotting.
    """
    # --- 1. Plot Training Loss ---
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title("Training Loss Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Negative Log-Likelihood Loss")
    plt.grid(True)
    plt.savefig("training_loss.png")
    print("   - Saved training loss plot to 'training_loss.png'")
    plt.close()

    # --- 2. Gather Predictions from Test Set ---
    print("\n4. Gathering predictions from the test set...")
    all_pred_marks, all_actual_marks = [], []
    all_pred_times, all_actual_times = [], []
    intensity_params_for_plot = None

    for i, seq in enumerate(test_sequences):
        actual_seq_len = int(jnp.sum(seq["mask"]))
        if actual_seq_len < 10:
            continue

        history_len = actual_seq_len // 2
        actual_event_index = history_len

        history_times = seq["times"][:history_len]
        history_marks = seq["marks"][:history_len]

        pred_time, pred_mark, v, w, b = predict_next_event(
            graphdef, params, history_times, history_marks
        )

        actual_time = seq["times"][actual_event_index]
        actual_mark = seq["marks"][actual_event_index]

        all_pred_marks.append(pred_mark)
        all_actual_marks.append(actual_mark)
        all_pred_times.append(pred_time)
        all_actual_times.append(actual_time)

        # Save params from the first sequence for the intensity plot
        if i == 0:
            intensity_params_for_plot = {
                "v": v,
                "w": w,
                "b": b,
                "actual_dt": actual_time - history_times[-1],
            }

    # --- 3. Generate and Save Plots ---
    print("5. Generating evaluation plots...")

    # Plot single prediction example
    plot_prediction_example(
        history_times,
        history_marks,
        all_actual_times[0],
        all_actual_marks[0],
        all_pred_times[0],
        all_pred_marks[0],
    )

    # Plot Time Prediction Error Histogram
    time_errors = np.array(all_pred_times) - np.array(all_actual_times)
    plt.figure(figsize=(10, 6))
    sns.histplot(time_errors, kde=True)
    plt.title("Distribution of Time Prediction Errors (Predicted - Actual)")
    plt.xlabel("Time Error")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.savefig("time_error_distribution.png")
    print("   - Saved time error distribution plot.")
    plt.close()

    # Plot Confusion Matrix
    cm = confusion_matrix(all_actual_marks, all_pred_marks, labels=range(N_MARKS))
    disp = ConfusionMatrixDisplay(
        confusion_matrix=cm, display_labels=[f"Mark {i}" for i in range(N_MARKS)]
    )
    disp.plot(cmap=plt.cm.Blues)
    plt.title("Mark Prediction Confusion Matrix")
    plt.savefig("confusion_matrix.png")
    print("   - Saved confusion matrix plot.")
    plt.close()

    # Plot Intensity Function
    plot_intensity_function(intensity_params_for_plot)


def plot_prediction_example(h_times, h_marks, a_time, a_mark, p_time, p_mark):
    """Plots a single prediction example."""
    plt.figure(figsize=(12, 6))
    plt.scatter(h_times, h_marks, c="blue", marker="o", label="History Events")
    plt.scatter(a_time, a_mark, c="green", marker="*", s=200, label="Actual Next Event")
    plt.scatter(
        p_time, p_mark, c="red", marker="x", s=200, label="Predicted Next Event"
    )
    plt.yticks(range(N_MARKS), [f"Mark {i}" for i in range(N_MARKS)])
    plt.xlabel("Time")
    plt.ylabel("Event Mark")
    plt.title("RMTPP Prediction vs. Actual Event")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.savefig("prediction_vs_actual.png")
    print("   - Saved prediction example plot.")
    plt.close()


def plot_intensity_function(params):
    """Plots the learned conditional intensity function for one example."""
    v, w, b = params["v"], params["w"], params["b"]
    actual_dt = params["actual_dt"]

    lambda_func = lambda t: jnp.exp(v + w * t + b)

    max_t = float(actual_dt) * 2.0  # Plot up to twice the actual delta_t
    t_range = np.linspace(0, max_t, 200)
    intensity_values = [lambda_func(t) for t in t_range]

    plt.figure(figsize=(10, 6))
    plt.plot(t_range, intensity_values, label="Learned Intensity $\lambda^*(t)$")
    plt.axvline(
        x=actual_dt,
        color="green",
        linestyle="--",
        label=f"Actual Event Time ($\Delta t={actual_dt:.2f}$)",
    )
    plt.xlabel("Time since last event ($\Delta t$)")
    plt.ylabel("Conditional Intensity $\lambda^*(t)$")
    plt.title("Learned Conditional Intensity Function")
    plt.legend()
    plt.grid(True)
    plt.savefig("intensity_function.png")
    print("   - Saved intensity function plot.")
    plt.close()


def main():
    """Main function to run data simulation, model training, and plotting."""
    print("1. Simulating data...")
    sequences = simulate_data(N_SEQS, MAX_SEQ_LEN, N_MARKS)
    np.random.shuffle(sequences)

    # Split data into training and testing sets
    test_size = int(len(sequences) * TEST_SPLIT_RATIO)
    train_sequences = sequences[test_size:]
    test_sequences = sequences[:test_size]

    # Initialize model and optimizer
    model = RMTPP(hidden_size=HIDDEN_SIZE, n_marks=N_MARKS, rngs=nnx.Rngs(0))

    # Chain the optimizer with gradient clipping to prevent exploding gradients
    tx = optax.chain(optax.clip(1.0), optax.adam(LEARNING_RATE))

    # Split the model into its static definition (graphdef) and dynamic state (params).
    graphdef, params = nnx.split(model, nnx.Param)

    # Train the model
    trained_params, losses = train_model(graphdef, tx, params, train_sequences)

    # Evaluate the model and plot results
    evaluate_and_plot(graphdef, trained_params, test_sequences, losses)


if __name__ == "__main__":
    main()

1. Simulating data...
2. Starting model training...
Epoch 20/200, Avg. Loss: 0.8651
Epoch 40/200, Avg. Loss: 0.8349
Epoch 60/200, Avg. Loss: 0.8102
Epoch 80/200, Avg. Loss: 0.7948
Epoch 100/200, Avg. Loss: 0.7818
Epoch 120/200, Avg. Loss: 0.7674
Epoch 140/200, Avg. Loss: 0.7533
Epoch 160/200, Avg. Loss: 0.7397
Epoch 180/200, Avg. Loss: 0.7251
Epoch 200/200, Avg. Loss: 0.7145
3. Training complete.
   - Saved training loss plot to 'training_loss.png'

4. Gathering predictions from the test set...
5. Generating evaluation plots...
   - Saved prediction example plot.
   - Saved time error distribution plot.
   - Saved confusion matrix plot.
   - Saved intensity function plot.


In [18]:
# -*- coding: utf-8 -*-
"""
This script implements an advanced version of the RMTPP model adapted for
modeling multichannel neural waveform data. It uses a CNN encoder-decoder
architecture to handle the high-dimensional waveform "marks".
"""

import functools
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Tuple

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
import seaborn as sns
from flax import nnx
from jax import random
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix


# ==============================================================================
# 1. Configuration
# ==============================================================================
@dataclass
class Config:
    """Configuration object for the RMTPP Waveform model and training."""

    # Data Simulation
    n_sequences: int = 300
    n_marks: int = 5  # Number of distinct waveform templates
    max_seq_len: int = 50
    n_channels: int = 8  # Number of channels on the linear probe
    n_samples: int = 40  # Number of time samples per waveform clip

    # Model Architecture
    hidden_size: int = 64  # Increased hidden size for more complex data
    embedding_size: int = 32  # Size of the compressed waveform vector

    # Training
    learning_rate: float = 1e-3
    epochs: int = 200
    batch_size: int = 16
    test_split_ratio: float = 0.2
    seed: int = 42
    output_dir: Path = Path("./results_waveform")


# ==============================================================================
# 2. Data Simulation (Generating Waveforms)
# ==============================================================================
def simulate_waveform_data(
    config: Config,
) -> List[Dict[str, jnp.ndarray]]:
    """
    Simulates sequences of multichannel waveform clips.

    Each "mark" is now a (n_channels, n_samples) array representing a spike
    waveform as recorded on a linear probe.

    Args:
        config: The configuration object.

    Returns:
        A list of sequences, each a dictionary containing event times,
        waveform clips, and a padding mask.
    """
    key = random.PRNGKey(config.seed)
    sequences = []

    # Generate N waveform templates
    key, subkey = random.split(key)
    templates = random.normal(
        subkey, (config.n_marks, config.n_channels, config.n_samples)
    )

    for i in range(config.n_marks):
        # Create some structure in the templates
        peak_time = config.n_samples // 3 + i
        peak_channel = i % config.n_channels
        amp = 2.0 + (i / config.n_marks)
        for c in range(config.n_channels):
            # Make amplitude decay away from the peak channel
            dist = jnp.abs(c - peak_channel)
            channel_amp = amp * jnp.exp(-dist / 2.0)
            t = jnp.linspace(-5, 5, config.n_samples)
            # A simple Gabor-like wavelet
            templates = templates.at[i, c].set(
                channel_amp
                * jnp.exp(-((t - (peak_time - t.mean())) ** 2))
                * jnp.cos(1.5 * (t - (peak_time - t.mean())))
            )

    for _ in range(config.n_sequences):
        key, subkey_len, subkey_time, subkey_mark, subkey_noise = random.split(key, 5)
        seq_len = random.randint(subkey_len, (), minval=20, maxval=config.max_seq_len)

        times = np.zeros(config.max_seq_len, dtype=np.float32)
        # The marks are now high-dimensional arrays
        marks = np.zeros(
            (
                config.max_seq_len,
                config.n_channels,
                config.n_samples,
            ),
            dtype=np.float32,
        )

        last_mark_idx = 0
        times[0] = random.exponential(subkey_time, ()) * 0.1
        marks[0] = (
            templates[last_mark_idx]
            + random.normal(subkey_noise, templates[0].shape) * 0.1
        )

        for i in range(1, seq_len):
            key, subkey_time, subkey_mark, subkey_noise = random.split(key, 4)

            # Transition probability for next waveform type
            probs = jnp.ones(config.n_marks) * 0.2
            probs = probs.at[last_mark_idx].set(0.8)
            next_mark_idx = random.choice(subkey_mark, config.n_marks, p=probs)

            inter_event_time = random.exponential(subkey_time, ()) * (
                1.0 + 0.5 * (next_mark_idx / config.n_marks)
            )
            times[i] = times[i - 1] + inter_event_time
            marks[i] = (
                templates[next_mark_idx]
                + random.normal(subkey_noise, templates[0].shape) * 0.1
            )
            last_mark_idx = next_mark_idx

        mask = np.arange(config.max_seq_len) < seq_len
        sequences.append(
            {
                "times": jnp.asarray(times),
                "marks": jnp.asarray(marks),
                "mask": jnp.asarray(mask),
            }
        )
    return sequences


# ==============================================================================
# 3. Waveform Encoder and Decoder Modules
# ==============================================================================
class WaveformEncoder(nnx.Module):
    """A CNN to encode a waveform clip into a feature vector."""

    def __init__(self, embedding_size: int, n_channels: int, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(
            in_features=n_channels,
            out_features=16,
            kernel_size=(3, 3),
            padding="SAME",
            rngs=rngs,
        )
        self.conv2 = nnx.Conv(
            in_features=16,
            out_features=32,
            kernel_size=(3, 3),
            padding="SAME",
            rngs=rngs,
        )
        self.dense = nnx.Linear(
            in_features=32 * n_channels, out_features=embedding_size, rngs=rngs
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """
        Args:
            x: Waveform clip with shape (batch, channels, samples).
        """
        # Flax Conv expects channels-last format
        x = x.transpose((0, 2, 1))
        x = nnx.relu(self.conv1(x))
        x = nnx.avg_pool(x, window_shape=(2,), strides=(2,))
        x = nnx.relu(self.conv2(x))
        x = nnx.avg_pool(x, window_shape=(2,), strides=(2,))
        # Global average pooling over the time dimension
        x = jnp.mean(x, axis=1)
        x = self.dense(x)
        return x


class WaveformDecoder(nnx.Module):
    """A Transposed CNN to decode a feature vector into a waveform clip."""

    def __init__(
        self,
        hidden_size: int,
        n_channels: int,
        n_samples: int,
        *,
        rngs: nnx.Rngs,
    ):
        self.n_channels = n_channels
        self.n_samples = n_samples
        self.dense = nnx.Linear(
            in_features=hidden_size,
            out_features=32 * (n_samples // 4) * n_channels,
            rngs=rngs,
        )
        self.deconv1 = nnx.ConvTranspose(
            in_features=32,
            out_features=16,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="SAME",
            rngs=rngs,
        )
        self.deconv2 = nnx.ConvTranspose(
            in_features=16,
            out_features=n_channels,
            kernel_size=(4, 4),
            strides=(2, 2),
            padding="SAME",
            rngs=rngs,
        )

    def __call__(self, h: jnp.ndarray) -> jnp.ndarray:
        """
        Args:
            h: Hidden state vector of shape (batch, hidden_size).
        """
        x = self.dense(h)
        # Reshape to an image-like format for deconvolution
        x = x.reshape(
            (
                h.shape[0],
                self.n_samples // 4,
                self.n_channels,
                32,
            )
        )
        x = nnx.relu(self.deconv1(x))
        x = self.deconv2(x)
        # Transpose back to (batch, channels, samples)
        x = x.transpose((0, 2, 1))
        return x


# ==============================================================================
# 4. Main RMTPP Model for Waveforms
# ==============================================================================
class RMTPP_Waveform(nnx.Module):
    """RMTPP model adapted for waveform data."""

    def __init__(self, config: Config, *, rngs: nnx.Rngs):
        self.encoder = WaveformEncoder(
            embedding_size=config.embedding_size,
            n_channels=config.n_channels,
            rngs=rngs,
        )
        self.time_proj = nnx.Linear(
            in_features=1, out_features=config.hidden_size, rngs=rngs
        )
        self.gru_cell = nnx.GRUCell(
            in_features=config.embedding_size,
            hidden_features=config.hidden_size,
            rngs=rngs,
        )
        self.decoder = WaveformDecoder(
            hidden_size=config.hidden_size,
            n_channels=config.n_channels,
            n_samples=config.n_samples,
            rngs=rngs,
        )
        # Output layers for time prediction
        self.v_proj = nnx.Linear(config.hidden_size, 1, use_bias=False, rngs=rngs)
        self.w_proj = nnx.Linear(config.hidden_size, 1, use_bias=False, rngs=rngs)
        self.b_proj = nnx.Linear(config.hidden_size, 1, use_bias=True, rngs=rngs)

    def __call__(
        self, times: jnp.ndarray, marks: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
        """
        Performs the forward pass of the model.
        Now returns the predicted waveform instead of logits.
        """
        inter_event_times = jnp.diff(times, prepend=0)
        inter_event_times = jnp.expand_dims(inter_event_times, -1)

        def rnn_step(h_prev, inputs):
            mark_j, time_j = inputs
            # Encode the waveform clip to get a feature vector
            mark_emb = self.encoder(mark_j)
            time_emb = self.time_proj(time_j)
            # Input to the GRU is the mark embedding
            # Note: A more complex model could combine these differently
            h_j = self.gru_cell(h_prev, mark_emb + time_emb)
            return h_j, h_j

        batch_size, seq_len, _, _ = marks.shape
        h0 = self.gru_cell.initialize_carry(
            (batch_size, self.gru_cell.in_features), rngs=self.rngs
        )

        # We need to scan over marks and times together
        marks_scannable = jnp.transpose(marks, (1, 0, 2, 3))
        inter_event_times_scannable = jnp.transpose(inter_event_times, (1, 0, 2))

        scan_fn = nnx.scan(rnn_step)
        _, h_sequence = scan_fn(h0, (marks_scannable, inter_event_times_scannable))
        h_sequence = h_sequence.transpose((1, 0, 2))

        # Decode the hidden states to predict the next waveform
        predicted_waveforms = self.decoder(h_sequence)

        v_unconstrained = self.v_proj(h_sequence)
        b_unconstrained = self.b_proj(h_sequence)
        v = jax.nn.tanh(v_unconstrained) * 5.0
        b = jax.nn.tanh(b_unconstrained) * 5.0
        w = jax.nn.softplus(self.w_proj(h_sequence)) + 1e-6

        return predicted_waveforms, v, w, b


# ==============================================================================
# 5. Training and Evaluation Logic
# ==============================================================================
def compute_loss(model: RMTPP_Waveform, batch: Dict[str, jnp.ndarray]) -> jnp.ndarray:
    """Computes the combined loss for a batch of sequences."""
    times, marks, mask = batch["times"], batch["marks"], batch["mask"]
    is_last_event = jnp.cumsum(mask, axis=1) == jnp.sum(mask, axis=1, keepdims=True)
    prediction_mask = mask & ~is_last_event

    # Target for time is the next inter-event time
    targets_times = jnp.roll(times, -1, axis=1)
    inter_event_times_target = targets_times - times

    # Target for mark is the next waveform in the sequence
    targets_marks = jnp.roll(marks, -1, axis=1)

    predicted_waveforms, v, w, b = model(times, marks)
    v, w, b = v.squeeze(-1), w.squeeze(-1), b.squeeze(-1)

    # 1. Mark Loss (Mean Squared Error for waveforms)
    mark_loss = optax.squared_error(predicted_waveforms, targets_marks)
    # Average over channel and sample dimensions
    mark_loss = jnp.mean(mark_loss, axis=(-2, -1))
    mark_loss = jnp.sum(mark_loss * prediction_mask) / jnp.sum(prediction_mask)

    # 2. Time Loss (Negative Log-Likelihood)
    log_lambda = v + w * inter_event_times_target + b
    integral_lambda = (1 / w) * (
        jnp.exp(v + w * inter_event_times_target + b) - jnp.exp(v + b)
    )
    time_loss = -log_lambda + integral_lambda
    time_loss = jnp.sum(time_loss * prediction_mask) / jnp.sum(prediction_mask)

    return mark_loss + time_loss


# --- Training and Prediction functions are largely unchanged ---
@functools.partial(nnx.jit, static_argnums=(0, 1))
def train_step(
    graphdef: nnx.GraphDef,
    tx: optax.GradientTransformation,
    optimizer_state: optax.OptState,
    params: nnx.State,
    batch: Dict[str, jnp.ndarray],
) -> Tuple[float, optax.OptState, nnx.State]:
    """Performs a single training step."""

    def loss_fn(current_params: nnx.State) -> float:
        model = nnx.merge(graphdef, current_params)
        return compute_loss(model, batch)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(params)
    updates, optimizer_state = tx.update(grads, optimizer_state, params)
    params = optax.apply_updates(params, updates)
    return loss, optimizer_state, params


@functools.partial(nnx.jit, static_argnums=(0,))
def predict_next_event(
    graphdef: nnx.GraphDef,
    params: nnx.State,
    history_times: jnp.ndarray,
    history_marks: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Predicts the time and waveform of the next event."""
    model = nnx.merge(graphdef, params)
    times_batch = jnp.expand_dims(history_times, 0)
    marks_batch = jnp.expand_dims(history_marks, 0)
    predicted_waveforms, v, w, b = model(times_batch, marks_batch)
    last_v, last_w, last_b = v[0, -1, 0], w[0, -1, 0], b[0, -1, 0]

    # Prediction for the waveform is the decoder output at the last step
    predicted_waveform = predicted_waveforms[0, -1]

    # Time prediction logic is unchanged
    A = last_v + last_b
    predicted_inter_event_time = (1 / last_w) * jnp.log(
        1 + (last_w * jnp.log(2.0)) / jnp.exp(A)
    )
    predicted_time = history_times[-1] + predicted_inter_event_time
    return predicted_time, predicted_waveform


# --- Main loop and evaluation plotting functions need updating ---
def train_model(
    graphdef: nnx.GraphDef,
    tx: optax.GradientTransformation,
    params: nnx.State,
    sequences: List[Dict[str, jnp.ndarray]],
    config: Config,
) -> Tuple[nnx.State, List[float]]:
    """Handles the main training loop."""
    print("2. Starting model training...")
    optimizer_state = tx.init(params)
    losses = []
    n_batches = len(sequences) // config.batch_size
    for epoch in range(config.epochs):
        epoch_loss = 0.0
        np.random.shuffle(sequences)
        for i in range(n_batches):
            batch_list = sequences[i * config.batch_size : (i + 1) * config.batch_size]
            batch = {
                "times": jnp.stack([s["times"] for s in batch_list]),
                "marks": jnp.stack([s["marks"] for s in batch_list]),
                "mask": jnp.stack([s["mask"] for s in batch_list]),
            }
            loss_val, optimizer_state, params = train_step(
                graphdef, tx, optimizer_state, params, batch
            )
            epoch_loss += loss_val
        avg_loss = epoch_loss / n_batches
        losses.append(avg_loss)
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{config.epochs}, Avg. Loss: {avg_loss:.4f}")
    print("3. Training complete.")
    return params, losses


def evaluate_and_plot(
    graphdef: nnx.GraphDef,
    params: nnx.State,
    test_sequences: List[Dict[str, jnp.ndarray]],
    losses: List[float],
    config: Config,
):
    """Runs evaluation and generates result plots for waveform data."""
    # Plot Training Loss
    plt.figure(figsize=(10, 5))
    plt.plot(losses)
    plt.title("Training Loss Over Epochs")
    plt.xlabel("Epoch")
    plt.ylabel("Combined Loss (Time NLL + Waveform MSE)")
    plt.grid(True)
    plt.savefig(config.output_dir / "training_loss.png")
    print(f"\n   - Saved training loss plot to '{config.output_dir}/'")
    plt.close()

    # Gather Predictions
    print("\n4. Gathering predictions from the test set...")
    all_pred_times, all_actual_times = [], []
    first_pred_waveform, first_actual_waveform = None, None

    for i, seq in enumerate(test_sequences):
        actual_seq_len = int(jnp.sum(seq["mask"]))
        if actual_seq_len < 10:
            continue
        history_len = actual_seq_len // 2
        actual_event_index = history_len
        history_times = seq["times"][:history_len]
        history_marks = seq["marks"][:history_len]

        pred_time, pred_waveform = predict_next_event(
            graphdef, params, history_times, history_marks
        )

        actual_time = seq["times"][actual_event_index]
        actual_waveform = seq["marks"][actual_event_index]

        all_pred_times.append(pred_time)
        all_actual_times.append(actual_time)
        if i == 0:
            first_pred_waveform = pred_waveform
            first_actual_waveform = actual_waveform

    # --- Calculate and Report Metrics ---
    print("\n5. Calculating aggregate test metrics...")
    time_errors = np.array(all_pred_times) - np.array(all_actual_times)
    mae = np.mean(np.abs(time_errors))
    rmse = np.sqrt(np.mean(np.square(time_errors)))
    print(f"   - Time Prediction MAE:    {mae:.4f}")
    print(f"   - Time Prediction RMSE:   {rmse:.4f}")

    # --- Generate and Save Plots ---
    print("\n6. Generating evaluation plots...")
    plot_time_error_distribution(time_errors, config)
    plot_waveform_comparison(first_actual_waveform, first_pred_waveform, config)


def plot_time_error_distribution(time_errors: np.ndarray, config: Config):
    """Plots the distribution of time prediction errors."""
    plt.figure(figsize=(10, 6))
    sns.histplot(time_errors, kde=True)
    plt.title("Distribution of Time Prediction Errors (Predicted - Actual)")
    plt.xlabel("Time Error")
    plt.ylabel("Frequency")
    plt.grid(True)
    plt.savefig(config.output_dir / "time_error_distribution.png")
    print("   - Saved time error distribution plot.")
    plt.close()


def plot_waveform_comparison(
    actual_waveform: jnp.ndarray,
    pred_waveform: jnp.ndarray,
    config: Config,
):
    """Plots the actual vs. predicted waveform for one example."""
    fig, axes = plt.subplots(
        config.n_channels,
        1,
        figsize=(10, 2 * config.n_channels),
        sharex=True,
    )
    fig.suptitle("Waveform Prediction vs. Actual", fontsize=16)
    time_axis = np.arange(config.n_samples)

    for i in range(config.n_channels):
        ax = axes[i]
        ax.plot(
            time_axis,
            actual_waveform[i, :],
            "b-",
            label="Actual Waveform",
        )
        ax.plot(
            time_axis,
            pred_waveform[i, :],
            "r--",
            label="Predicted Waveform",
        )
        ax.set_ylabel(f"Channel {i}")
        ax.grid(True, linestyle="--", alpha=0.6)
    axes[-1].set_xlabel("Time Samples")
    handles, labels = axes[-1].get_legend_handles_labels()
    fig.legend(handles, labels, loc="upper right")
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    plt.savefig(config.output_dir / "waveform_comparison.png")
    print("   - Saved waveform comparison plot.")
    plt.close()


def main():
    """Main function to run simulation, training, and plotting."""
    config = Config()
    config.output_dir.mkdir(exist_ok=True, parents=True)
    print("1. Simulating waveform data...")
    sequences = simulate_waveform_data(config)
    np.random.shuffle(sequences)
    test_size = int(len(sequences) * config.test_split_ratio)
    train_sequences, test_sequences = sequences[test_size:], sequences[:test_size]
    model = RMTPP_Waveform(config=config, rngs=nnx.Rngs(config.seed))
    tx = optax.chain(optax.clip(1.0), optax.adam(config.learning_rate))
    graphdef, params = nnx.split(model, nnx.Param)
    trained_params, losses = train_model(graphdef, tx, params, train_sequences, config)
    evaluate_and_plot(graphdef, trained_params, test_sequences, losses, config)


if __name__ == "__main__":
    main()

1. Simulating waveform data...


ValueError: Non-exhaustive filters, got a non-empty remainder: [38;2;79;201;177mFlatState[0m[38;2;255;213;3m([[0m[38;2;105;105;105m[0m
  [38;2;156;220;254m[0m[38;2;212;212;212m[0m[38;2;255;213;3m([0m[38;2;255;213;3m([0m[38;2;207;144;120m'gru_cell'[0m, [38;2;207;144;120m'rngs'[0m, [38;2;207;144;120m'default'[0m, [38;2;207;144;120m'count'[0m[38;2;255;213;3m)[0m, [38;2;79;201;177mVariableState[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (4 B)[0m
    [38;2;156;220;254mtype[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngCount[0m,
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray(21, dtype=uint32),
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m
  [38;2;255;213;3m)[0m[38;2;255;213;3m)[0m,
  [38;2;156;220;254m[0m[38;2;212;212;212m[0m[38;2;255;213;3m([0m[38;2;255;213;3m([0m[38;2;207;144;120m'gru_cell'[0m, [38;2;207;144;120m'rngs'[0m, [38;2;207;144;120m'default'[0m, [38;2;207;144;120m'key'[0m[38;2;255;213;3m)[0m, [38;2;79;201;177mVariableState[0m[38;2;255;213;3m([0m[38;2;105;105;105m # 1 (8 B)[0m
    [38;2;156;220;254mtype[0m[38;2;212;212;212m=[0m[38;2;79;201;177mRngKey[0m,
    [38;2;156;220;254mvalue[0m[38;2;212;212;212m=[0mArray((), dtype=key<fry>) overlaying:
    [ 0 42],
    [38;2;156;220;254mtag[0m[38;2;212;212;212m=[0m[38;2;207;144;120m'default'[0m
  [38;2;255;213;3m)[0m[38;2;255;213;3m)[0m
[38;2;255;213;3m])[0m.
Use `...` to match all remaining elements.