# Neural Circuit Policies with MLX on Apple Silicon

This notebook demonstrates how to run Neural Circuit Policies (NCPs) on Apple Silicon using MLX. It covers:

- Neural Engine Optimization
- Hardware-Specific Features
- Performance Monitoring
- Advanced Architectures

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, CfCCell
from ncps.wirings import AutoNCP
from ncps.mlx.advanced_profiling import MLXProfiler
from ncps.tests.configs.device_configs import get_device_config

## 1. Device Configuration

First, let's get the optimal configuration for our Apple Silicon device:

In [None]:
# Look up the configuration for the current device and report its
# recommended settings, one per line.
config = get_device_config()
print(
    f"Detected device: {config.device_type}",
    f"Optimal batch size: {config.get_optimal_batch_size()}",
    f"Optimal hidden size: {config.get_optimal_hidden_size()}",
    f"Optimal backbone: {config.get_optimal_backbone()}",
    sep="\n",
)

## 2. Model Creation

Create a model sized for the detected device:

In [None]:
class OptimizedSequenceModel(nn.Module):
    """CfC sequence model whose layer sizes come from a device configuration.

    The recurrent core is a ``CfC`` layer built around a ``CfCCell`` wired
    with ``AutoNCP``; a linear head is applied to the last time step of the
    recurrent output.
    """

    def __init__(self, config):
        """Build the recurrent core and the linear head.

        Args:
            config: Device configuration providing
                ``get_optimal_hidden_size()`` and ``get_optimal_backbone()``.
        """
        super().__init__()

        # Wiring: full unit count, with a quarter of it as the output size.
        ncp_wiring = AutoNCP(
            units=config.get_optimal_hidden_size(),
            output_size=config.get_optimal_hidden_size() // 4,
        )

        # Recurrent cell, then the sequence-level wrapper around it.
        recurrent_cell = CfCCell(
            wiring=ncp_wiring,
            activation="tanh",
            backbone_units=config.get_optimal_backbone(),
            backbone_layers=2,
        )
        self.cfc = CfC(
            cell=recurrent_cell,
            return_sequences=True,
            return_state=True,
        )

        # NOTE(review): the head takes `get_optimal_hidden_size()` input
        # features while the wiring's output_size is a quarter of that —
        # confirm which width the CfC layer actually emits.
        self.output_layer = nn.Linear(
            config.get_optimal_hidden_size(),
            config.get_optimal_hidden_size() // 4,
        )

    def __call__(self, x, time_delta=None, initial_state=None):
        """Run the recurrent core and project its last time step.

        Args:
            x: Input sequence, forwarded to the CfC layer.
            time_delta: Optional value forwarded as ``time_delta``.
            initial_state: Optional value forwarded as ``initial_state``.

        Returns:
            A ``(prediction, states)`` pair: the linear head applied to the
            final step of the CfC output, and the states the CfC returned.
        """
        sequence, states = self.cfc(
            x,
            time_delta=time_delta,
            initial_state=initial_state,
        )
        final_step = sequence[:, -1]
        return self.output_layer(final_step), states

# Create model
model = OptimizedSequenceModel(config)


def _forward_impl(x):
    """Array-only forward pass; this is the function handed to mx.compile."""
    return model(x)


# mx.compile accepts the function plus optional `inputs`/`outputs`/`shapeless`;
# it has no `static_argnums` option. The model's state is passed as an
# implicit input so the compiled graph reads the current parameters instead
# of treating them as constants captured at trace time.
_compiled_forward = mx.compile(_forward_impl, inputs=model.state)


def forward(x, training=False):
    """Run the compiled forward pass of ``model`` on ``x``.

    Args:
        x: Input batch passed to the model.
        training: Kept for call-site compatibility only. The model's
            ``__call__`` takes no ``training`` argument, so the flag is not
            forwarded; switch modes with ``model.train()`` / ``model.eval()``.

    Returns:
        Whatever ``model(x)`` returns, i.e. an ``(output, states)`` pair.
    """
    return _compiled_forward(x)

## 3. Performance Profiling

Profile model performance on the device:

In [None]:
def profile_performance(model, config):
    """Profile compute and memory statistics of ``model`` per batch size.

    Args:
        model: Model handed to ``MLXProfiler``.
        config: Device configuration; its ``batch_sizes`` attribute is
            iterated to choose the batch sizes to profile.

    Returns:
        A list with one dict per batch size, with the keys ``'batch_size'``,
        ``'tflops'``, ``'memory'`` and ``'bandwidth'``.
    """
    profiler = MLXProfiler(model)
    seq_length = 16

    # The profiler is driven purely by batch size / sequence length, so no
    # input tensor is built here (the previous one was never used).
    results = []
    for batch_size in config.batch_sizes:
        compute_stats = profiler.profile_compute(
            batch_size=batch_size,
            seq_length=seq_length,
            num_runs=100,
        )
        memory_stats = profiler.profile_memory(batch_size=batch_size)

        results.append({
            'batch_size': batch_size,
            'tflops': compute_stats['tflops'],
            'memory': memory_stats['peak_usage'],
            'bandwidth': memory_stats['bandwidth'],
        })

    return results

# Profile performance
results = profile_performance(model, config)

# One row per subplot:
# (position, result key, config threshold attribute, threshold label,
#  y-axis label, title)
metric_panels = (
    (131, 'tflops', 'min_tflops', 'Minimum Required',
     'TFLOPS', 'Neural Engine Performance'),
    (132, 'memory', 'memory_budget', 'Memory Budget',
     'Memory (MB)', 'Memory Usage'),
    (133, 'bandwidth', 'min_bandwidth', 'Minimum Required',
     'Bandwidth (GB/s)', 'Memory Bandwidth'),
)

# Plot each metric against batch size, with its threshold as a dashed line
plt.figure(figsize=(15, 5))
profiled_batch_sizes = [row['batch_size'] for row in results]

for position, metric, limit_attr, limit_label, y_label, panel_title in metric_panels:
    plt.subplot(position)
    plt.plot(
        profiled_batch_sizes,
        [row[metric] for row in results],
        marker='o'
    )
    plt.axhline(
        y=getattr(config, limit_attr),
        color='r',
        linestyle='--',
        label=limit_label
    )
    plt.xlabel('Batch Size')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

## 4. Advanced Features

Demonstrate advanced features like time-aware processing and state management:

In [None]:
def demonstrate_features(model, config):
    """Run the model three ways and plot the resulting states side by side.

    The three runs are: a plain forward pass, a pass with random
    ``time_delta`` values, and a pass with an explicit all-zero
    ``initial_state``. The states of the first two runs, and their
    difference, are drawn as heatmaps.

    Args:
        model: Model exposing ``cfc.cell.input_size`` and ``cfc.cell.units``
            and returning an ``(outputs, states)`` pair when called.
        config: Device configuration providing ``get_optimal_batch_size()``.
    """
    batch_size = config.get_optimal_batch_size()
    seq_length = 16
    cell = model.cfc.cell

    def _draw_heatmap(position, values, title):
        # One heatmap panel with the shared colour map and axis labels.
        plt.subplot(position)
        plt.imshow(values, aspect='auto', cmap='RdBu')
        plt.colorbar()
        plt.title(title)
        plt.xlabel('Time Step')
        plt.ylabel('Neuron Index')

    # Random input batch
    x = mx.random.normal((batch_size, seq_length, cell.input_size))

    # 1. Plain forward pass
    _, states = model(x)

    # 2. Forward pass with per-step time deltas drawn from [0.5, 1.5)
    time_delta = mx.random.uniform(
        low=0.5,
        high=1.5,
        shape=(batch_size, seq_length)
    )
    _, states_time = model(x, time_delta=time_delta)

    # 3. Forward pass starting from an explicit zero state
    initial_state = mx.zeros((batch_size, cell.units))
    _, _ = model(x, initial_state=initial_state)

    # Visualize states
    plt.figure(figsize=(15, 5))
    _draw_heatmap(131, states[0].T, 'Regular States')
    _draw_heatmap(132, states_time[0].T, 'Time-Aware States')
    _draw_heatmap(133, (states_time[0] - states[0]).T, 'State Differences')

    plt.tight_layout()
    plt.show()

# Run the demonstration on the model created above
demonstrate_features(model, config)

## 5. Hardware-Specific Insights

Based on our experiments:

1. **Neural Engine Performance**
   - Compilation provides significant speedup
   - Power-of-2 sizes are optimal
   - Batch size affects utilization
   - Device-specific scaling

2. **Memory Management**
   - Unified memory is efficient
   - Bandwidth scales with batch size
   - Memory usage is predictable
   - Device limits are respected

3. **Optimization Tips**
   - Use device-specific configs
   - Enable compilation
   - Monitor performance
   - Balance resources

4. **Device-Specific Settings**
   - M1: 32-64 batch size
   - M1 Pro/Max: 64-128 batch size
   - M1 Ultra: 128-256 batch size
   - Adjust based on model size