# Neural Circuit Policy Wiring Patterns

This notebook demonstrates how to use different wiring patterns with liquid neural networks in MLX. We'll cover:
- Fully connected networks
- Random sparse networks
- Neural Circuit Policy (NCP) architectures
- Automatic NCP construction with AutoNCP
- Custom wiring patterns
- A training comparison across wiring patterns

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Wiring, FullyConnected, Random, NCP, AutoNCP

## 1. Fully Connected Networks

Let's start with a fully connected network, where every neuron is connected to every other neuron (and, with `self_connections=True`, to itself):

In [None]:
# Create fully connected wiring: 32 units, an output dimension of 10, and
# self-connections enabled.
wiring = FullyConnected(
    units=32,
    output_dim=10,
    self_connections=True
)

# Create CfC model with this wiring
# NOTE(review): backbone_units / backbone_layers presumably configure a hidden
# network inside the cell -- confirm against the ncps.mlx CfC signature.
model = CfC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

# Generate sample data: a random-normal batch of shape
# (batch_size, seq_length, input_dim) = (16, 20, 8).
# `x` is reused as the input by the later wiring examples in this notebook.
batch_size = 16
seq_length = 20
input_dim = 8
x = mx.random.normal((batch_size, seq_length, input_dim))

# Process data: run the model on the whole batch and report the result's shape
output = model(x)
print(f"Output shape: {output.shape}")

## 2. Random Sparse Networks

Random sparse networks have fewer connections, which can improve efficiency and generalization:

In [None]:
# Create random sparse wiring with the same unit count and output dimension as
# the fully connected example, but with sparsity_level=0.5.
wiring = Random(
    units=32,
    output_dim=10,
    sparsity_level=0.5  # 50% of possible connections
)

# Create LTC model with this wiring (same keyword arguments as the CfC example)
model = LTC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

# Process data: `x` is the sample batch created in the fully connected example
output = model(x)
print(f"Output shape: {output.shape}")
# Report how many synapses the wiring holds, as exposed by `synapse_count`
print(f"Number of synapses: {wiring.synapse_count}")

## 3. Neural Circuit Policy (NCP)

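NCPs use a structured architecture with distinct neuron types. The wiring below configures inter, command, and motor neurons, together with the fan-out and fan-in between them: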

In [None]:
# Create NCP wiring with 16 inter, 8 command and 4 motor neurons.
# NOTE(review): the fanout / fanin / recurrent arguments presumably set how
# many synapses are drawn per neuron between the groups -- confirm against the
# ncps.mlx NCP documentation.
wiring = NCP(
    inter_neurons=16,
    command_neurons=8,
    motor_neurons=4,
    sensory_fanout=4,
    inter_fanout=4,
    recurrent_command_synapses=3,
    motor_fanin=4
)

# Create CfC model with NCP wiring
model = CfC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

# Process data: `x` is the sample batch created in the fully connected example
output = model(x)
print(f"Output shape: {output.shape}")

# Print neuron types: one line per neuron index in [0, wiring.units)
for i in range(wiring.units):
    print(f"Neuron {i}: {wiring.get_type_of_neuron(i)}")

## 4. Automatic NCP

AutoNCP simplifies NCP creation with automatic architecture selection:

In [None]:
# Create AutoNCP wiring from a total unit count, an output size and a sparsity
# level. Note the keyword here is `output_size`, whereas FullyConnected and
# Random take `output_dim`.
wiring = AutoNCP(
    units=32,
    output_size=4,
    sparsity_level=0.5
)

# Create model with AutoNCP wiring
model = CfC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

# Process data: `x` is the sample batch created in the fully connected example
output = model(x)
print(f"Output shape: {output.shape}")

## 5. Custom Wiring

You can create custom wiring patterns by subclassing the Wiring class:

In [None]:
class LayeredWiring(Wiring):
    """Custom wiring with layered, feed-forward connectivity.

    Neurons are numbered consecutively, layer by layer, so layer ``k`` occupies
    the index range that starts at ``sum(layer_sizes[:k])``. Every neuron gets
    synapses to randomly chosen, distinct neurons of the layer that follows it;
    no other synapses are added.
    """

    def __init__(self, layer_sizes, connections_per_layer=2):
        """Build the layered wiring.

        Args:
            layer_sizes: Sequence with the number of neurons in each layer,
                first layer first. Must contain at least one entry.
            connections_per_layer: Number of outgoing synapses drawn for each
                neuron into the next layer. Capped at the size of that layer,
                because targets are sampled without replacement.
        """
        total_units = sum(layer_sizes)
        super().__init__(total_units)

        self.layer_sizes = layer_sizes
        self.connections_per_layer = connections_per_layer
        # The output dimension equals the size of the last layer.
        # NOTE(review): which neuron indices the model reads as outputs is
        # decided by the library, not here -- confirm they are the last layer.
        self.set_output_dim(layer_sizes[-1])

        # Walk over consecutive (layer, next layer) pairs.
        layer_start = 0
        for current_size, next_size in zip(layer_sizes, layer_sizes[1:]):
            next_start = layer_start + current_size
            # Sampling without replacement cannot yield more targets than the
            # next layer contains; cap the fan-out instead of raising.
            fanout = min(connections_per_layer, next_size)

            for src in range(layer_start, next_start):
                targets = np.random.choice(
                    range(next_start, next_start + next_size),
                    size=fanout,
                    replace=False
                )

                for dst in targets:
                    # Pass a plain int index; polarity argument is 1
                    # (presumably excitatory in the ncps convention -- confirm).
                    self.add_synapse(src, int(dst), 1)

            layer_start = next_start

# Create custom layered wiring: three layers of 16, 8 and 4 neurons (28 units
# in total), each neuron drawing 2 targets in the next layer. This results in
# (16 + 8) * 2 = 48 add_synapse calls.
wiring = LayeredWiring(
    layer_sizes=[16, 8, 4],
    connections_per_layer=2
)

# Create model with custom wiring
model = CfC(
    wiring=wiring,
    activation="tanh",
    backbone_units=[64],
    backbone_layers=1
)

# Process data: `x` is the sample batch created in the fully connected example
output = model(x)
print(f"Output shape: {output.shape}")
# Report how many synapses the wiring holds, as exposed by `synapse_count`
print(f"Number of synapses: {wiring.synapse_count}")

## 6. Training with Different Wirings

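Let's compare how different wiring patterns perform on a simple task: given a sine wave with random phase and frequency, predict the matching cosine wave at every time step.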

In [None]:
def create_sine_data(samples=1000, seq_length=50):
    """Build a sine-to-cosine sequence regression dataset.

    Each sample is a sine wave evaluated at ``seq_length`` evenly spaced points
    on [0, 4*pi], with a random phase in [0, 2*pi) and a frequency drawn around
    1.0 (standard deviation 0.1). The target is the cosine with the same
    frequency and phase, i.e. the input's time derivative up to the factor
    ``freq``.

    Args:
        samples: Number of sequences to generate.
        seq_length: Number of time steps per sequence.

    Returns:
        Tuple ``(X, y)`` of MLX arrays, each built from a NumPy array of shape
        ``(samples, seq_length, 1)``.
    """
    time_axis = np.linspace(0, 4*np.pi, seq_length)
    shape = (samples, seq_length, 1)
    inputs = np.zeros(shape)
    targets = np.zeros(shape)

    for row in range(samples):
        # Draw order matters for reproducibility: phase first, then frequency.
        phase = np.random.rand() * 2 * np.pi
        freq = 1.0 + 0.1 * np.random.randn()
        inputs[row, :, 0] = np.sin(freq * time_axis + phase)
        targets[row, :, 0] = np.cos(freq * time_axis + phase)

    return mx.array(inputs), mx.array(targets)

# Generate data: 1000 training sequences (the default) and 100 test sequences,
# each with the default length of 50 steps.
# NOTE(review): X_test / y_test are not used by any cell visible in this
# notebook -- either evaluate on them or drop them.
X_train, y_train = create_sine_data()
X_test, y_test = create_sine_data(samples=100)

# Create models with different wirings. All three use 32 units and a single
# output; the dict keys double as the legend labels in the loss plot.
models = {
    'Fully Connected': CfC(FullyConnected(32, output_dim=1)),
    'Random Sparse': CfC(Random(32, output_dim=1, sparsity_level=0.5)),
    'NCP': CfC(AutoNCP(32, output_size=1, sparsity_level=0.5))
}

# Training function
def train_model(model, X, y, epochs=100):
    """Train ``model`` with full-batch Adam on a mean-squared-error loss.

    Args:
        model: Module invoked as ``model(x)``; its parameters are updated in
            place by the optimizer.
        X: Input batch, passed to the model unchanged on every epoch.
        y: Targets; must be broadcast-compatible with ``model(X)``.
        epochs: Number of full-batch update steps.

    Returns:
        list[float]: The loss of each epoch, measured before that epoch's
        parameter update.
    """
    # Optimizers live in ``mlx.optimizers``; ``mlx.nn`` does not provide Adam.
    import mlx.optimizers as optim

    optimizer = optim.Adam(learning_rate=0.001)
    losses = []

    def loss_fn(model, x, y):
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    loss_and_grad = nn.value_and_grad(model, loss_fn)

    for epoch in range(epochs):
        loss, grads = loss_and_grad(model, X, y)
        optimizer.update(model, grads)
        # MLX evaluates lazily: materialize the updated parameters and
        # optimizer state now so each step is actually computed per epoch.
        mx.eval(model.parameters(), optimizer.state)

        loss_value = float(loss)
        losses.append(loss_value)

        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

    return losses

# Train and compare models: each model is trained with train_model's default
# number of epochs, and its per-epoch loss list is stored under the model name.
results = {}
for name, model in models.items():
    print(f"\nTraining {name} model...")
    losses = train_model(model, X_train, y_train)
    results[name] = losses

# Plot training curves: one line per wiring pattern, loss against epoch index
plt.figure(figsize=(10, 6))
for name, losses in results.items():
    plt.plot(losses, label=name)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss by Wiring Pattern')
plt.legend()
plt.grid(True)
plt.show()

## Analysis

Different wiring patterns have different strengths:

1. **Fully Connected**
   - Maximum expressivity
   - Higher memory usage
   - May overfit on small datasets

2. **Random Sparse**
   - Better generalization
   - More efficient
   - May miss important connections

3. **NCP**
   - Structured connectivity
   - Good balance of efficiency and performance
   - Inspired by biological neural circuits

Choose the wiring pattern based on your specific needs:
- Use fully connected for small networks where expressivity is key
- Use random sparse for large networks where efficiency matters
- Use NCP for structured problems with clear hierarchies