# Neural Circuit Policy Benchmarks

This notebook benchmarks different wiring patterns (random dense, random sparse, NCP, and AutoNCP) across common tasks:
- Sequence Prediction
- Time Series Classification
- Control Tasks
- Real-time Processing (single-step latency, measured as part of the control benchmark)

In [None]:
from time import time

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np

from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.advanced_profiling import MLXProfiler, quick_profile

## 1. Sequence Prediction

Benchmark sequence prediction: each model reads a sinusoid and predicts the matching cosine at every time step.

In [None]:
def generate_sequence_data(n_samples=1000, seq_length=50):
    """Build sine input sequences paired with cosine target sequences.

    Each sample gets its own frequency (1.0 plus Gaussian noise with std 0.1)
    and a uniform random phase in ``[0, 2*pi)``.

    Args:
        n_samples: Number of sequences to generate.
        seq_length: Time steps per sequence, spread evenly over ``[0, 4*pi]``.

    Returns:
        Tuple ``(X, y)`` of ``mx.array``, each of shape
        ``(n_samples, seq_length, 1)``.
    """
    inputs = np.zeros((n_samples, seq_length, 1))
    targets = np.zeros((n_samples, seq_length, 1))
    # The time grid is the same for every sample, so build it once.
    grid = np.linspace(0, 4 * np.pi, seq_length)

    for idx in range(n_samples):
        omega = 1.0 + 0.1 * np.random.randn()
        offset = 2 * np.pi * np.random.rand()
        angle = omega * grid + offset
        inputs[idx, :, 0] = np.sin(angle)
        # cos(angle) is the time derivative of the input divided by omega.
        targets[idx, :, 0] = np.cos(angle)

    return mx.array(inputs), mx.array(targets)

def benchmark_sequence_prediction():
    """Train and profile each wiring on the sine -> cosine sequence task.

    Every model is trained full-batch for 100 epochs with Adam
    (learning rate 1e-3) on 1000 generated sequences and evaluated on 100
    freshly generated ones.

    Returns:
        dict mapping model name to a dict with keys ``'train_time'``
        (seconds), ``'train_loss'`` (per-epoch MSE list), ``'test_loss'``
        (MSE on the test set), and ``'tflops'`` / ``'memory'`` as reported
        by ``quick_profile``.
    """
    X_train, y_train = generate_sequence_data()
    X_test, y_test = generate_sequence_data(n_samples=100)

    # Every wiring exposes exactly one output so predictions line up with the
    # single-feature target instead of silently broadcasting against it.
    models = {
        'Random Dense': CfC(Random(units=100, output_dim=1, sparsity_level=0.2)),
        'Random Sparse': CfC(Random(units=100, output_dim=1, sparsity_level=0.8)),
        'NCP': CfC(NCP(
            inter_neurons=50,
            command_neurons=30,
            motor_neurons=1,
            sensory_fanout=5,
            inter_fanout=5,
            recurrent_command_synapses=10,
            motor_fanin=5
        )),
        'AutoNCP': CfC(AutoNCP(units=100, output_size=1))
    }

    def loss_fn(model, x, y):
        """Mean squared error over all samples, time steps and features."""
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    results = {}
    for name, model in models.items():
        optimizer = optim.Adam(learning_rate=0.001)
        # nn.value_and_grad differentiates w.r.t. the module's trainable
        # parameters; mx.value_and_grad expects a plain function instead.
        loss_and_grad = nn.value_and_grad(model, loss_fn)
        train_losses = []
        start = time()

        for _ in range(100):
            loss, grads = loss_and_grad(model, X_train, y_train)
            optimizer.update(model, grads)
            # MLX is lazy: materialize the update each step so the timing
            # reflects real work and the compute graph does not accumulate.
            mx.eval(model.parameters(), optimizer.state)
            train_losses.append(float(loss))

        train_time = time() - start

        # Evaluate on held-out sequences.
        pred = model(X_test)
        test_loss = float(mx.mean((pred - y_test) ** 2))

        stats = quick_profile(model)

        results[name] = {
            'train_time': train_time,
            'train_loss': train_losses,
            'test_loss': test_loss,
            'tflops': stats['compute']['tflops'],
            'memory': stats['memory']['peak_usage']
        }

    return results

# Run benchmark
sequence_results = benchmark_sequence_prediction()

# Plot results. Matplotlib's categorical axis needs a real list of labels;
# passing the dict_keys view directly raises a TypeError.
sequence_names = list(sequence_results)

plt.figure(figsize=(15, 5))

plt.subplot(131)
for name, result in sequence_results.items():
    plt.plot(result['train_loss'], label=name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(132)
plt.bar(sequence_names,
        [r['train_time'] for r in sequence_results.values()])
plt.xticks(rotation=45)
plt.ylabel('Time (s)')
plt.title('Training Time')

plt.subplot(133)
plt.bar(sequence_names,
        [r['tflops'] for r in sequence_results.values()])
plt.xticks(rotation=45)
plt.ylabel('TFLOPS')
plt.title('Compute Efficiency')

plt.tight_layout()
plt.show()

## 2. Time Series Classification

Benchmark classification: each model labels a whole sequence using its output at the final time step.

In [None]:
def generate_classification_data(n_samples=1000, seq_length=50, n_classes=5):
    """Generate synthetic time-series classification data.

    Each sample gets a uniformly random class, and its waveform depends only
    on that class:

    * 0: sine wave ``sin(t)``
    * 1: square wave ``sign(sin(t))``
    * 2: sawtooth in ``[-pi, pi)``
    * 3: random walk normalised by ``sqrt(seq_length)``
    * k >= 4: sine wave ``sin((k - 2) * t)``, i.e. frequencies 2, 3, ...,
      so that no two classes share a waveform distribution.

    Args:
        n_samples: Number of sequences to generate.
        seq_length: Time steps per sequence, spread evenly over ``[0, 4*pi]``.
        n_classes: Number of classes.

    Returns:
        Tuple ``(X, y)`` of ``mx.array``. ``X`` has shape
        ``(n_samples, seq_length, 1)``. ``y`` is one-hot with shape
        ``(n_samples, n_classes)``.
    """
    X = np.zeros((n_samples, seq_length, 1))
    y = np.zeros((n_samples, n_classes))
    # The time grid is identical for every sample.
    t = np.linspace(0, 4 * np.pi, seq_length)

    for i in range(n_samples):
        class_id = np.random.randint(n_classes)
        y[i, class_id] = 1

        if class_id == 0:
            # Sine wave
            X[i, :, 0] = np.sin(t)
        elif class_id == 1:
            # Square wave
            X[i, :, 0] = np.sign(np.sin(t))
        elif class_id == 2:
            # Sawtooth
            X[i, :, 0] = t % (2 * np.pi) - np.pi
        elif class_id == 3:
            # Random walk
            X[i, :, 0] = np.cumsum(np.random.randn(seq_length)) / np.sqrt(seq_length)
        else:
            # Each remaining class gets its own frequency. If they all shared
            # the random-walk branch, those labels would be indistinguishable.
            # k - 2 >= 2, so this never collides with class 0's frequency of 1.
            X[i, :, 0] = np.sin((class_id - 2) * t)

    return mx.array(X), mx.array(y)

def benchmark_classification():
    """Train and profile each wiring on the time-series classification task.

    Every model is trained full-batch for 100 epochs with Adam
    (learning rate 1e-3) on 1000 generated sequences, classifying from the
    output at the final time step. It is then evaluated on 100 fresh
    sequences.

    Returns:
        dict mapping model name to a dict with keys ``'train_time'``
        (seconds), ``'train_loss'`` (per-epoch cross-entropy list),
        ``'accuracy'`` (test accuracy in ``[0, 1]``), and ``'tflops'`` /
        ``'memory'`` as reported by ``quick_profile``.
    """
    X_train, y_train = generate_classification_data()
    X_test, y_test = generate_classification_data(n_samples=100)

    # Targets are one-hot; cross-entropy and accuracy use class indices.
    train_labels = mx.argmax(y_train, axis=1)
    test_labels = mx.argmax(y_test, axis=1)

    # Every wiring exposes one output per class (5). Otherwise the Random
    # wirings' 100 outputs cannot be compared with the 5-class targets.
    models = {
        'Random Dense': CfC(Random(units=100, output_dim=5, sparsity_level=0.2)),
        'Random Sparse': CfC(Random(units=100, output_dim=5, sparsity_level=0.8)),
        'NCP': CfC(NCP(
            inter_neurons=50,
            command_neurons=30,
            motor_neurons=5,
            sensory_fanout=5,
            inter_fanout=5,
            recurrent_command_synapses=10,
            motor_fanin=5
        )),
        'AutoNCP': CfC(AutoNCP(units=100, output_size=5))
    }

    def loss_fn(model, x, labels):
        """Mean cross-entropy of the final-time-step logits."""
        logits = model(x)[:, -1]  # Use final output
        return mx.mean(nn.losses.cross_entropy(logits, labels))

    results = {}
    for name, model in models.items():
        optimizer = optim.Adam(learning_rate=0.001)
        # nn.value_and_grad differentiates w.r.t. the module's trainable
        # parameters; mx.value_and_grad expects a plain function instead.
        loss_and_grad = nn.value_and_grad(model, loss_fn)
        train_losses = []
        start = time()

        for _ in range(100):
            loss, grads = loss_and_grad(model, X_train, train_labels)
            optimizer.update(model, grads)
            # MLX is lazy: materialize the update each step so the timing
            # reflects real work and the compute graph does not accumulate.
            mx.eval(model.parameters(), optimizer.state)
            train_losses.append(float(loss))

        train_time = time() - start

        # Evaluate: the predicted class is the arg-max of the final-step logits.
        pred = model(X_test)[:, -1]
        correct = mx.argmax(pred, axis=1) == test_labels
        accuracy = float(mx.mean(correct.astype(mx.float32)))

        stats = quick_profile(model)

        results[name] = {
            'train_time': train_time,
            'train_loss': train_losses,
            'accuracy': accuracy,
            'tflops': stats['compute']['tflops'],
            'memory': stats['memory']['peak_usage']
        }

    return results

# Run benchmark
classification_results = benchmark_classification()

# Plot results. Matplotlib's categorical axis needs a real list of labels;
# passing the dict_keys view directly raises a TypeError.
classification_names = list(classification_results)

plt.figure(figsize=(15, 5))

plt.subplot(131)
for name, result in classification_results.items():
    plt.plot(result['train_loss'], label=name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(132)
plt.bar(classification_names,
        [r['accuracy'] for r in classification_results.values()])
plt.xticks(rotation=45)
plt.ylabel('Accuracy')
plt.title('Classification Accuracy')

plt.subplot(133)
plt.bar(classification_names,
        [r['memory'] for r in classification_results.values()])
plt.xticks(rotation=45)
# NOTE(review): assumes quick_profile reports peak_usage in MB; confirm.
plt.ylabel('Memory (MB)')
plt.title('Memory Usage')

plt.tight_layout()
plt.show()

## 3. Control Tasks

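Benchmark control: each model maps position and velocity on a circular trajectory to the required acceleration, and its single-step inference latency is profiled.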

In [None]:
def generate_control_data(n_samples=1000, seq_length=50):
    """Create state/control pairs for one lap around a circle.

    For every sample, a radius (1.0 plus Gaussian noise with std 0.1) and a
    uniform random starting phase are drawn. The inputs are the position and
    velocity along the circle. The targets are the acceleration that keeps
    the point on it.

    Args:
        n_samples: Number of trajectories.
        seq_length: Time steps per trajectory, spread evenly over ``[0, 2*pi]``.

    Returns:
        Tuple ``(X, y)`` of ``mx.array``. ``X`` has shape
        ``(n_samples, seq_length, 4)`` with features
        ``[pos_x, pos_y, vel_x, vel_y]``. ``y`` has shape
        ``(n_samples, seq_length, 2)`` with ``[force_x, force_y]``.
    """
    states = np.zeros((n_samples, seq_length, 4))
    controls = np.zeros((n_samples, seq_length, 2))
    # The time grid is the same for every trajectory.
    steps = np.linspace(0, 2 * np.pi, seq_length)

    for k in range(n_samples):
        r = 1.0 + 0.1 * np.random.randn()
        theta = steps + 2 * np.pi * np.random.rand()
        cos_th, sin_th = np.cos(theta), np.sin(theta)

        # Position on the circle.
        states[k, :, 0] = r * cos_th
        states[k, :, 1] = r * sin_th
        # Velocity: time derivative of position.
        states[k, :, 2] = -r * sin_th
        states[k, :, 3] = r * cos_th
        # Target control: time derivative of velocity (acceleration).
        controls[k, :, 0] = -r * cos_th
        controls[k, :, 1] = -r * sin_th

    return mx.array(states), mx.array(controls)

def benchmark_control():
    """Train each wiring on the circular-trajectory control task and profile latency.

    Every model is trained full-batch for 100 epochs with Adam
    (learning rate 1e-3) on 1000 generated trajectories and evaluated on 100
    fresh ones. Single-step latency is then profiled with batch size 1 and
    sequence length 1 over 1000 runs.

    Returns:
        dict mapping model name to a dict with keys ``'train_time'``
        (seconds), ``'train_loss'`` (per-epoch MSE list), ``'test_loss'``
        (MSE on the test set), and ``'latency'`` / ``'latency_std'``
        (milliseconds, from ``MLXProfiler.profile_compute``).
    """
    X_train, y_train = generate_control_data()
    X_test, y_test = generate_control_data(n_samples=100)

    # Every wiring exposes exactly two outputs (force_x, force_y). Otherwise
    # the Random wirings' 100 outputs cannot be compared with the 2-dim target.
    models = {
        'Random Dense': CfC(Random(units=100, output_dim=2, sparsity_level=0.2)),
        'Random Sparse': CfC(Random(units=100, output_dim=2, sparsity_level=0.8)),
        'NCP': CfC(NCP(
            inter_neurons=50,
            command_neurons=30,
            motor_neurons=2,
            sensory_fanout=5,
            inter_fanout=5,
            recurrent_command_synapses=10,
            motor_fanin=5
        )),
        'AutoNCP': CfC(AutoNCP(units=100, output_size=2))
    }

    def loss_fn(model, x, y):
        """Mean squared error over all samples, time steps and control dims."""
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    results = {}
    for name, model in models.items():
        optimizer = optim.Adam(learning_rate=0.001)
        # nn.value_and_grad differentiates w.r.t. the module's trainable
        # parameters; mx.value_and_grad expects a plain function instead.
        loss_and_grad = nn.value_and_grad(model, loss_fn)
        train_losses = []
        start = time()

        for _ in range(100):
            loss, grads = loss_and_grad(model, X_train, y_train)
            optimizer.update(model, grads)
            # MLX is lazy: materialize the update each step so the timing
            # reflects real work and the compute graph does not accumulate.
            mx.eval(model.parameters(), optimizer.state)
            train_losses.append(float(loss))

        train_time = time() - start

        # Evaluate on held-out trajectories.
        pred = model(X_test)
        test_loss = float(mx.mean((pred - y_test) ** 2))

        # Profile real-time performance: one sample, one time step.
        profiler = MLXProfiler(model)
        latency_stats = profiler.profile_compute(
            batch_size=1,
            seq_length=1,
            num_runs=1000
        )

        results[name] = {
            'train_time': train_time,
            'train_loss': train_losses,
            'test_loss': test_loss,
            # Presumably profile_compute reports seconds; scaled to ms here.
            'latency': latency_stats['time_mean'] * 1000,
            'latency_std': latency_stats['time_std'] * 1000
        }

    return results

# Run benchmark
control_results = benchmark_control()

# Plot results. Matplotlib's categorical axis needs a real list of labels;
# passing the dict_keys view directly raises a TypeError.
control_names = list(control_results)

plt.figure(figsize=(15, 5))

plt.subplot(131)
for name, result in control_results.items():
    plt.plot(result['train_loss'], label=name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.legend()
plt.grid(True)

plt.subplot(132)
plt.bar(control_names,
        [r['test_loss'] for r in control_results.values()])
plt.xticks(rotation=45)
plt.ylabel('Test Loss')
plt.title('Control Performance')

plt.subplot(133)
plt.bar(control_names,
        [r['latency'] for r in control_results.values()],
        yerr=[r['latency_std'] for r in control_results.values()])
plt.xticks(rotation=45)
plt.ylabel('Latency (ms)')
plt.title('Real-time Performance')

plt.tight_layout()
plt.show()

## Benchmark Summary

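The observations below summarize typical outcomes. Results depend on hardware and random initialization, so confirm them against the plots produced above:

1. **Sequence Prediction**
   - NCP tends to perform best on tasks with long-term temporal dependencies
   - Dense random wiring works well for short sequences
   - AutoNCP balances accuracy and efficiency

2. **Classification**
   - Sparse wirings reach competitive accuracy
   - Peak memory usage varies significantly between wirings
   - Training times differ noticeably between wirings

3. **Control Tasks**
   - Per-step latency (batch size 1, a single time step) is the critical metric for real-time use
   - Latency varies by wiring pattern
   - There is a trade-off between control accuracy (test loss) and latency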

Recommendations:
- Use NCP for complex temporal tasks
- Consider AutoNCP when you need a balance of accuracy and efficiency
- Choose the sparsity level based on the task's accuracy and latency requirements
- For real-time applications, measure per-step latency on the target hardware