# Advanced Neural Circuit Policies with MLX

This notebook demonstrates advanced features of the CfC (Closed-form Continuous-time) implementation in MLX, focusing on the CfCRNN for complex sequence processing.

In [None]:
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from ncps.mlx import CfCRNN

## Create a Bidirectional Sequence Model

We'll create a model that processes sequences in both directions using CfCRNN.

In [None]:
class BidirectionalModel(nn.Module):
    """Two-layer bidirectional CfCRNN followed by a linear read-out.

    The read-out consumes the features of the last time step only and maps
    them to ``output_size`` values.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Build the recurrent stack and the read-out layer.

        Args:
            input_size: Number of features per time step.
            hidden_size: Hidden width handed to ``CfCRNN``.
            output_size: Number of values predicted per sequence.
        """
        super().__init__()
        recurrent_config = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=2,
            bidirectional=True,
            backbone_units=64,
            backbone_layers=2,
        )
        self.rnn = CfCRNN(**recurrent_config)
        # A bidirectional RNN emits features for both directions, so the
        # read-out is sized for twice the hidden width.
        readout_width = 2 * hidden_size
        self.output_layer = nn.Linear(readout_width, output_size)

    def __call__(self, x, time_delta=None):
        """Predict from the last time step of the recurrent output.

        Args:
            x: Input sequences; indexed below as (batch, time, ...).
            time_delta: Optional per-step time information forwarded
                unchanged to ``CfCRNN``.

        Returns:
            The read-out layer applied to the final time step.
        """
        # The recurrent states returned alongside the sequence are not used.
        sequence_output, _ = self.rnn(x, time_delta=time_delta)
        last_step = sequence_output[:, -1]
        return self.output_layer(last_step)

## Generate Complex Sequence Data

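We'll create a sequence prediction task built from a sum of three sinusoids at different frequencies. Each input sequence has two features (the signal and its finite-difference derivative), and the target is the next value of both.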

In [None]:
def generate_complex_data(batch_size, seq_length):
    """Sample next-step prediction windows from a multi-sinusoid signal.

    The signal is ``sin(t) + 0.5*sin(2t) + 0.25*sin(4t)``. Each sample is a
    randomly positioned window of ``seq_length`` consecutive points: the first
    ``seq_length - 1`` points are the input, the last point is the target.

    Args:
        batch_size: Number of windows to draw.
        seq_length: Window length including the target point (must be >= 2).

    Returns:
        A tuple ``(X, y)`` of ``mx.array``:
        ``X`` has shape ``(batch_size, seq_length - 1, 2)`` holding the signal
        and its per-sample finite difference; ``y`` has shape
        ``(batch_size, 2)`` holding the next signal value and its finite
        difference.

    Raises:
        ValueError: If ``seq_length`` is smaller than 2.
    """
    if seq_length < 2:
        raise ValueError("seq_length must be at least 2")

    # The signal has to be longer than one window; otherwise there is no valid
    # start index to draw (np.random.randint rejects an empty range). Two
    # windows' worth of samples keeps the 8*pi-per-window sampling step.
    step = 8 * np.pi / (seq_length - 1)
    total_length = 2 * seq_length
    t = step * np.arange(total_length)

    # Create the signal from multiple frequency components
    x = np.sin(t) + 0.5 * np.sin(2 * t) + 0.25 * np.sin(4 * t)
    # Differentiate once on the full signal so inputs and targets use the
    # same estimate (a per-window gradient is one-sided at the window edges).
    dx = np.gradient(x)

    # Upper bound is exclusive, so the last window ends on the final sample.
    starts = np.random.randint(0, total_length - seq_length + 1, size=batch_size)
    window_idx = starts[:, None] + np.arange(seq_length)

    signal = x[window_idx]
    slope = dx[window_idx]

    # Inputs: every point but the last; targets: the last point.
    X = np.stack([signal[:, :-1], slope[:, :-1]], axis=-1)
    y = np.stack([signal[:, -1], slope[:, -1]], axis=-1)

    return mx.array(X), mx.array(y)

# Draw a single example window so we can inspect both input features.
X, y = generate_complex_data(batch_size=1, seq_length=100)

# One curve per feature channel of the first (only) sequence in the batch.
plt.figure(figsize=(12, 4))
for feature_idx, feature_label in enumerate(('Signal', 'Derivative')):
    plt.plot(X[0, :, feature_idx], label=feature_label)
plt.title('Example Training Sequence')
plt.legend()
plt.show()

## Train the Model with Time-Aware Processing

In [None]:
# Create model and optimizer
model = BidirectionalModel(input_size=2, hidden_size=32, output_size=2)
# Optimizers live in mlx.optimizers (imported as optim); mlx.nn has no Adam.
optimizer = optim.Adam(learning_rate=0.001)


def loss_fn(model, X, y, time_delta=None):
    """Return the mean squared error between the model prediction and y."""
    pred = model(X, time_delta=time_delta)
    return mx.mean((pred - y) ** 2)


# Function returning both the loss and the gradients w.r.t. the model
loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

# Training loop
batch_size = 32
seq_length = 50
n_epochs = 100

losses = []
for epoch in range(n_epochs):
    X, y = generate_complex_data(batch_size, seq_length)

    # Create variable time deltas in [1.0, 1.1) to demonstrate time-aware
    # processing. `shape` must be passed by keyword: the first positional
    # argument of mx.random.uniform is `low`, not the shape.
    # NOTE(review): assumes CfCRNN accepts a (batch, time) time_delta - confirm.
    time_delta = 1.0 + 0.1 * mx.random.uniform(shape=(batch_size, seq_length - 1))

    # Compute loss and gradients
    loss, grads = loss_and_grad_fn(model, X, y, time_delta)

    # Update parameters
    optimizer.update(model, grads)
    # MLX is lazy: force the updated parameters and optimizer state to be
    # computed now instead of letting the graph grow across iterations.
    mx.eval(model.parameters(), optimizer.state)

    loss_value = float(loss)
    losses.append(loss_value)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value}")

# Visualise how the per-epoch loss evolved during training.
plt.figure(figsize=(10, 4))
plt.plot(losses)
# A logarithmic y-axis keeps late, small improvements visible.
plt.yscale('log')
plt.grid(True)
plt.xlabel('Epoch')
plt.ylabel('MSE Loss')
plt.title('Training Loss')
plt.show()

## Evaluate Model Predictions

In [None]:
# Draw one held-out window; its inputs span 199 steps (200 minus the target).
X_test, y_test = generate_complex_data(batch_size=1, seq_length=200)
# Constant unit time steps at test time.
time_delta_test = mx.ones((1, 199))

# Run the trained model on the test window.
pred = model(X_test, time_delta=time_delta_test)

# One panel per output channel:
# (subplot position, channel, input label, predicted label, true label, title)
panels = (
    (211, 0, 'Input Signal', 'Predicted Signal', 'True Signal',
     'Signal Prediction'),
    (212, 1, 'Input Derivative', 'Predicted Derivative', 'True Derivative',
     'Derivative Prediction'),
)

plt.figure(figsize=(12, 8))

for position, channel, input_label, pred_label, true_label, panel_title in panels:
    plt.subplot(position)
    plt.plot(X_test[0, :, channel], label=input_label)
    # Predicted and true next values are drawn one step past the input.
    plt.plot(len(X_test[0]), float(pred[0, channel]), 'ro', label=pred_label)
    plt.plot(len(X_test[0]), float(y_test[0, channel]), 'go', label=true_label)
    plt.legend()
    plt.title(panel_title)

plt.tight_layout()
plt.show()