# Neural Circuit Policies for Robotics

This notebook demonstrates how to use wiring patterns for robotics applications:
- State estimation
- Control policies
- Sensor fusion
- Motion planning

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Wiring, NCP, AutoNCP

## 1. Custom Robotics Wiring

Create a wiring pattern specifically for robotics applications:

In [None]:
class RoboticsWiring(Wiring):
    """Randomly wired three-layer pattern for robotics applications.

    Neuron indices are laid out in this order:
    - Control layer (motor neurons): ``0 .. control_neurons - 1``
    - State estimation layer (command neurons): the next ``state_neurons``
    - Sensor processing layer (inter neurons): the last ``sensor_neurons``

    Connectivity (every synapse is added with polarity ``1``):
    - Each sensor neuron projects to ``sensor_fanout`` distinct state neurons
    - Each sensor neuron gets, with probability 0.2, one direct
      sensor-to-motor synapse (reflex)
    - ``state_recurrent`` state-to-state synapses are drawn at random for
      state memory (source and target are drawn independently, so
      self-loops and repeated pairs are possible)
    - Each control neuron receives from ``control_fanin`` distinct state
      neurons

    The output dimension is set to ``control_neurons``.

    NOTE(review): only neuron-to-neuron synapses are created here; no
    sensory (external input -> neuron) synapses are added. Confirm that the
    ``Wiring`` base class / model connects the external inputs on its own.

    Args:
        sensor_neurons: Number of sensor-processing neurons.
        state_neurons: Number of state-estimation neurons.
        control_neurons: Number of control (output) neurons.
        sensor_fanout: State neurons targeted by each sensor neuron.
        state_recurrent: Number of random recurrent state synapses.
        control_fanin: State neurons feeding each control neuron.
        seed: Optional seed for a private ``np.random.RandomState`` so the
            wiring is reproducible. ``None`` (default) draws from NumPy's
            global random state, as before.

    Raises:
        ValueError: If a requested fan-out/fan-in exceeds the number of
            state neurons, or recurrent synapses are requested without any
            state neurons.
    """

    def __init__(
        self,
        sensor_neurons: int,
        state_neurons: int,
        control_neurons: int,
        sensor_fanout: int = 4,
        state_recurrent: int = 3,
        control_fanin: int = 4,
        seed: "int | None" = None,
    ):
        # Fail early with a readable message in exactly the situations where
        # sampling without replacement (or from an empty layer) would fail.
        if sensor_neurons > 0 and sensor_fanout > state_neurons:
            raise ValueError(
                f"sensor_fanout ({sensor_fanout}) cannot exceed "
                f"state_neurons ({state_neurons})"
            )
        if control_neurons > 0 and control_fanin > state_neurons:
            raise ValueError(
                f"control_fanin ({control_fanin}) cannot exceed "
                f"state_neurons ({state_neurons})"
            )
        if state_recurrent > 0 and state_neurons <= 0:
            raise ValueError(
                "state_recurrent requires at least one state neuron"
            )

        total_units = sensor_neurons + state_neurons + control_neurons
        super().__init__(total_units)

        # Store configuration
        self.sensor_neurons = sensor_neurons
        self.state_neurons = state_neurons
        self.control_neurons = control_neurons
        self.sensor_fanout = sensor_fanout
        self.state_recurrent = state_recurrent
        self.control_fanin = control_fanin

        # Set output dimension (control neurons)
        self.set_output_dim(control_neurons)

        # Define neuron ranges: control first, then state, then sensor
        self.control_range = range(control_neurons)
        self.state_range = range(
            control_neurons,
            control_neurons + state_neurons
        )
        self.sensor_range = range(
            control_neurons + state_neurons,
            total_units
        )

        # The np.random module and a RandomState expose the same sampling
        # functions, so the builders can use either interchangeably.
        rng = np.random if seed is None else np.random.RandomState(seed)

        # Connect layers (the call order fixes the random draw order)
        self._build_sensor_connections(rng)
        self._build_state_connections(rng)
        self._build_control_connections(rng)

    def _build_sensor_connections(self, rng=None):
        """Build connections from the sensor layer.

        Args:
            rng: Random source providing ``choice`` and ``random_sample``;
                defaults to NumPy's global random state.
        """
        rng = np.random if rng is None else rng
        state_ids = list(self.state_range)
        control_ids = list(self.control_range)

        # Connect sensors to state estimation
        for src in self.sensor_range:
            targets = rng.choice(
                state_ids,
                size=self.sensor_fanout,
                replace=False
            )
            for dest in targets:
                self.add_synapse(src, int(dest), 1)

        # Direct sensor-to-motor connections (reflexes)
        for src in self.sensor_range:
            if rng.random_sample() < 0.2:  # 20% chance of reflex connection
                dest = rng.choice(control_ids)
                self.add_synapse(src, int(dest), 1)

    def _build_state_connections(self, rng=None):
        """Build recurrent connections inside the state estimation layer.

        Args:
            rng: Random source providing ``choice``; defaults to NumPy's
                global random state.
        """
        rng = np.random if rng is None else rng
        state_ids = list(self.state_range)

        # Recurrent connections for state memory
        for _ in range(self.state_recurrent):
            src = rng.choice(state_ids)
            dest = rng.choice(state_ids)
            self.add_synapse(int(src), int(dest), 1)

    def _build_control_connections(self, rng=None):
        """Build connections from the state layer into the control layer.

        Args:
            rng: Random source providing ``choice``; defaults to NumPy's
                global random state.
        """
        rng = np.random if rng is None else rng
        state_ids = list(self.state_range)

        # Connect state estimation to control
        for dest in self.control_range:
            sources = rng.choice(
                state_ids,
                size=self.control_fanin,
                replace=False
            )
            for src in sources:
                self.add_synapse(int(src), dest, 1)

# Create robotics wiring.
# control_neurons sets the wiring's output dimension, so it has to match the
# three target channels used for state estimation in section 2
# (position, velocity, acceleration); otherwise the squared-error loss
# cannot broadcast the prediction against the targets.
wiring = RoboticsWiring(
    sensor_neurons=16,   # Process sensor inputs
    state_neurons=32,    # Estimate robot state
    control_neurons=3    # One output per estimated state variable
)

# Create model
model = CfC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

## 2. State Estimation

Use the model for state estimation from noisy sensor data:

In [None]:
def generate_robot_data(n_samples=1000, seq_length=50):
    """Generate simulated robot data for state estimation.

    The true trajectory is the same for every sample: position ``sin(t)``,
    velocity ``cos(t)`` and acceleration ``-sin(t)`` for ``t`` spanning
    ``[0, 4*pi]``. Each quantity is observed by three independent noisy
    sensors.

    Args:
        n_samples: Number of sequences to generate.
        seq_length: Number of time steps per sequence.

    Returns:
        A tuple ``(X, y)`` of ``mx.array``:
        - ``X`` with shape ``(n_samples, seq_length, 9)``: noisy readings,
          channels 0-2 position (noise std 0.1), 3-5 velocity (std 0.2),
          6-8 acceleration (std 0.3)
        - ``y`` with shape ``(n_samples, seq_length, 3)``: true
          ``[position, velocity, acceleration]``
    """
    # True trajectory, stacked as (seq_length, 3)
    t = np.linspace(0, 4*np.pi, seq_length)
    true_state = np.stack([np.sin(t), np.cos(t), -np.sin(t)], axis=-1)

    # Same true state for every sample
    y = np.broadcast_to(true_state, (n_samples, seq_length, 3)).copy()

    # Three sensors per quantity: repeat each state channel three times so
    # the clean signal lines up with the 9 sensor channels, then add noise
    # with a per-channel standard deviation. (Adding the 1-D trajectory
    # directly to a (seq_length, 3) noise block does not broadcast.)
    noise_std = np.repeat([0.1, 0.2, 0.3], 3)
    clean = np.repeat(y, 3, axis=-1)
    X = clean + np.random.normal(0.0, noise_std, (n_samples, seq_length, 9))

    return mx.array(X), mx.array(y)

# Generate data
X_train, y_train = generate_robot_data()
X_test, y_test = generate_robot_data(n_samples=100)

# Train model
# Adam is provided by mlx.optimizers (imported as `optim`), not by mlx.nn.
optimizer = optim.Adam(learning_rate=0.001)

def train_step(model, x, y):
    """Run one full-batch training step with a mean-squared-error loss.

    Uses the module-level ``optimizer`` that is current at call time, so
    rebinding ``optimizer`` between calls changes which optimizer is used.

    Args:
        model: Model to train; called as ``model(x)``.
            NOTE(review): assumes the call returns only the prediction
            array, with the same trailing shape as ``y`` — confirm.
        x: Input batch.
        y: Target batch.

    Returns:
        The scalar loss computed before the parameter update.
    """
    def loss_fn(model, x, y):
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    # MLX is lazy: force the updated parameters and optimizer state to be
    # computed now instead of letting the graph grow across steps.
    mx.eval(model.parameters(), optimizer.state)
    return loss

# Training loop: 100 full-batch steps, logging every 10th epoch
losses = []
for epoch in range(100):
    loss = train_step(model, X_train, y_train)
    losses.append(float(loss))

    if epoch % 10 != 0:
        continue
    print(f"Epoch {epoch}, Loss: {float(loss):.4f}")

# Plot results: loss curve on the left, predictions on the right
fig, (loss_ax, pred_ax) = plt.subplots(1, 2, figsize=(15, 5))

# Training loss
loss_ax.plot(losses)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.grid(True)

# Predicted vs. true position for the first test sequence
predictions = model(X_test)
pred_ax.plot(y_test[0, :, 0], label='True')
pred_ax.plot(predictions[0, :, 0], '--', label='Predicted')
pred_ax.set_xlabel('Time Step')
pred_ax.set_ylabel('Position')
pred_ax.set_title('State Estimation')
pred_ax.legend()
pred_ax.grid(True)

fig.tight_layout()
plt.show()

## 3. Control Policy

Use the model for control policy learning:

In [None]:
def generate_control_data(n_samples=1000, seq_length=50):
    """Build training data for the control policy.

    Every sample is one revolution of a circle with a randomly perturbed
    radius (mean 1.0, std 0.1) and a uniformly random starting phase.

    Returns:
    - Robot state, shape (n_samples, seq_length, 4): [pos_x, pos_y, vel_x, vel_y]
    - Optimal control actions, shape (n_samples, seq_length, 2): [force_x, force_y]
    """
    t = np.linspace(0, 2*np.pi, seq_length)
    states = np.zeros((n_samples, seq_length, 4))
    controls = np.zeros((n_samples, seq_length, 2))

    for idx in range(n_samples):
        # Draw order (radius, then phase) is kept so that seeded runs
        # produce the same data.
        radius = 1.0 + 0.1 * np.random.randn()
        phase = 2 * np.pi * np.random.rand()

        angle = t + phase
        x_pos = radius * np.cos(angle)
        y_pos = radius * np.sin(angle)

        # Velocity is the time derivative of position on the circle;
        # the control (acceleration) points back towards the centre.
        states[idx] = np.stack([x_pos, y_pos, -y_pos, x_pos], axis=-1)
        controls[idx] = np.stack([-x_pos, -y_pos], axis=-1)

    return mx.array(states), mx.array(controls)

# Create control policy model
control_wiring = RoboticsWiring(
    sensor_neurons=8,    # Process state input
    state_neurons=16,    # Internal representation
    control_neurons=2    # Control actions
)

control_model = CfC(
    wiring=control_wiring,
    activation="tanh"
)

# Generate data
X_train, y_train = generate_control_data()
X_test, y_test = generate_control_data(n_samples=100)

# Train model
# A fresh optimizer is bound to the module-level name that train_step reads.
# Adam is provided by mlx.optimizers (imported as `optim`), not by mlx.nn.
optimizer = optim.Adam(learning_rate=0.001)
losses = []

for epoch in range(100):
    loss = train_step(control_model, X_train, y_train)
    losses.append(float(loss))
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {float(loss):.4f}")

# Plot results: loss curve on the left, control field on the right
fig, (loss_ax, traj_ax) = plt.subplots(1, 2, figsize=(15, 5))

# Training loss
loss_ax.plot(losses)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.grid(True)

# Desired trajectory of the first test sample, with the predicted control
# drawn as an arrow at every 5th time step
predictions = control_model(X_test)
traj_ax.plot(X_test[0, :, 0], X_test[0, :, 1], label='Desired')
traj_ax.quiver(
    X_test[0, ::5, 0],
    X_test[0, ::5, 1],
    predictions[0, ::5, 0],
    predictions[0, ::5, 1],
    color='r',
    label='Control'
)
traj_ax.set_xlabel('X Position')
traj_ax.set_ylabel('Y Position')
traj_ax.set_title('Control Policy')
traj_ax.legend()
traj_ax.grid(True)

fig.tight_layout()
plt.show()

## Analysis

The robotics wiring pattern provides several advantages:

1. **State Estimation**
   - Effectively fuses multiple sensor inputs
   - Handles noisy measurements
   - Maintains temporal consistency

2. **Control Policy**
   - Learns smooth control actions
   - Adapts to different trajectories
   - Handles multi-dimensional control

3. **Architecture Benefits**
   - Direct sensor-to-motor connections for fast reflexes
   - State estimation layer for memory and filtering
   - Structured connectivity for better learning

Key considerations for robotics applications:
- Balancing reactivity and planning
- Handling multiple time scales
- Robustness to sensor noise and delays
- Efficient computation for real-time control