# Apple Silicon Optimization Guide

This notebook demonstrates how to optimize Neural Circuit Policies for Apple Silicon processors:

- Neural Engine Optimization
- Memory Management
- Performance Profiling
- Hardware-Specific Tuning

In [None]:
from time import time

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from ncps.mlx import CfC, CfCCell, LTC, LTCCell
from ncps.mlx.advanced_profiling import MLXProfiler, quick_profile
from ncps.wirings import Random, NCP, AutoNCP

## 1. Neural Engine Optimization

MLX automatically leverages the Neural Engine for supported operations. Here's how to optimize for it:

In [None]:
def _make_forward(model, compiled):
    """Build the forward callable that is handed to the profiler.

    Both variants share the signature ``forward(x, training=True)`` so the
    profiler can call them the same way.

    Args:
        model: Callable model invoked as ``model(x, training=...)``.
        compiled: When true, route the call through ``mx.compile``.

    Returns:
        A callable ``forward(x, training=True)``.
    """
    def eager(x, training=True):
        return model(x, training=training)

    if not compiled:
        return eager

    # ``mx.compile`` accepts ``(fun, inputs, outputs, shapeless)`` and has no
    # ``static_argnums`` option. The Python-level ``training`` flag is baked
    # into a closure instead, with one compiled function cached per flag value.
    compiled_by_mode = {}

    def forward(x, training=True):
        mode = bool(training)
        if mode not in compiled_by_mode:
            def run(inputs):
                return model(inputs, training=mode)

            compiled_by_mode[mode] = mx.compile(run)
        return compiled_by_mode[mode](x)

    return forward


def optimize_for_neural_engine(sizes=(32, 64, 128)):
    """Profile CfC models across hidden sizes, batch sizes and compilation modes.

    For every hidden size a CfC model on an ``AutoNCP`` wiring is built, then
    profiled for batch sizes 16, 32, 64 and 128, once eagerly and once through
    ``mx.compile``.

    Args:
        sizes: Iterable of hidden sizes (number of wiring units) to test.
            An empty iterable yields an empty result list.

    Returns:
        A list with one dict per (size, batch size, compiled) combination,
        holding the keys ``'size'``, ``'batch_size'``, ``'compiled'``,
        ``'tflops'``, ``'memory'`` and ``'time'``. The last three are copied
        from the profiler's ``'tflops'``, ``'peak_usage'`` and ``'time_mean'``
        entries.
    """
    results = []
    batch_sizes = (16, 32, 64, 128)

    # Test different model configurations
    for hidden_size in sizes:
        # Create wiring with power-of-2 sizes for efficiency
        wiring = AutoNCP(
            units=hidden_size,
            output_size=hidden_size // 4,
            sparsity_level=0.5
        )

        # Create model with Neural Engine-friendly configuration
        model = CfC(
            cell=CfCCell(
                wiring=wiring,
                activation="tanh",
                backbone_units=[hidden_size, hidden_size],  # Power of 2 sizes
                backbone_layers=2
            ),
            return_sequences=True
        )

        # Test different batch sizes
        for batch_size in batch_sizes:
            profiler = MLXProfiler(model)

            # Profile with and without compilation
            for compiled in (False, True):
                forward = _make_forward(model, compiled)

                stats = profiler.profile_compute(
                    batch_size=batch_size,
                    seq_length=16,  # Power of 2
                    num_runs=50,
                    forward_fn=forward
                )

                memory_stats = profiler.profile_memory(
                    batch_size=batch_size
                )

                results.append({
                    'size': hidden_size,
                    'batch_size': batch_size,
                    'compiled': compiled,
                    'tflops': stats['tflops'],
                    'memory': memory_stats['peak_usage'],
                    'time': stats['time_mean']
                })

    return results

# Run Neural Engine optimization
ne_results = optimize_for_neural_engine()

# Plot results: three side-by-side panels built from the profiling records
plt.figure(figsize=(15, 5))

# Plot TFLOPS for the 64-unit model, eager vs. compiled
plt.subplot(131)
for compiled in [False, True]:
    data = [r for r in ne_results if r['size'] == 64 and r['compiled'] == compiled]
    plt.plot(
        [d['batch_size'] for d in data],
        [d['tflops'] for d in data],
        marker='o',
        # Plain conditional expression: backslash-escaped quotes inside an
        # f-string replacement field are a syntax error.
        label='Compiled' if compiled else 'Uncompiled'
    )
plt.xlabel('Batch Size')
plt.ylabel('TFLOPS')
plt.title('Neural Engine Performance')
plt.legend()
plt.grid(True)

# Plot Memory Usage per hidden size (uncompiled runs only)
plt.subplot(132)
sizes = [32, 64, 128]
for size in sizes:
    data = [r for r in ne_results if r['size'] == size and not r['compiled']]
    plt.plot(
        [d['batch_size'] for d in data],
        [d['memory'] for d in data],
        marker='o',
        label=f'Size {size}'
    )
plt.xlabel('Batch Size')
plt.ylabel('Memory (MB)')
plt.title('Memory Usage')
plt.legend()
plt.grid(True)

# Plot Execution Time for the 64-unit model, eager vs. compiled
plt.subplot(133)
for compiled in [False, True]:
    data = [r for r in ne_results if r['size'] == 64 and r['compiled'] == compiled]
    plt.plot(
        [d['batch_size'] for d in data],
        [d['time']*1000 for d in data],  # scale the recorded time by 1000 for the ms axis
        marker='o',
        label='Compiled' if compiled else 'Uncompiled'
    )
plt.xlabel('Batch Size')
plt.ylabel('Time (ms)')
plt.title('Execution Time')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## 2. Memory Management

Apple Silicon's unified memory is shared by the CPU and GPU, so MLX workloads still need careful memory management:

In [None]:
class MemoryOptimizedTrainer:
    """Mini-batch trainer that records loss, wall time and memory per epoch.

    The model is trained with Adam on a mean-squared-error objective.

    Args:
        model: Model called as ``model(x, training=...)`` whose parameters
            are updated in place by the optimizer.
        learning_rate: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, learning_rate=0.001):
        self.model = model
        # Optimizers live in ``mlx.optimizers``; ``mlx.nn`` does not provide Adam.
        self.optimizer = optim.Adam(learning_rate=learning_rate)

    # NOTE: not wrapped in ``mx.compile``. That function has no
    # ``static_argnums`` option, and this method returns a function rather
    # than arrays, so there is nothing here for it to trace.
    def train_step(self, training=True):
        """Return a function computing the MSE loss and parameter gradients.

        Args:
            training: Flag forwarded to the model call inside the loss.

        Returns:
            A callable ``fn(model, x, y) -> (loss, grads)`` where ``grads``
            are taken with respect to the trainable parameters of
            ``self.model``.
        """
        def loss_fn(model, x, y):
            pred = model(x, training=training)
            return mx.mean((pred - y) ** 2)

        # ``nn.value_and_grad`` differentiates w.r.t. the module's trainable
        # parameters; ``mx.value_and_grad`` expects a plain function instead.
        return nn.value_and_grad(self.model, loss_fn)

    def train(self, X, y, batch_size=32, epochs=10):
        """Train the model and collect per-epoch metrics.

        Samples that do not fill a complete batch are dropped each epoch.

        Args:
            X: Input array indexed along its first axis by sample.
            y: Target array with the same first-axis length as ``X``.
            batch_size: Number of samples per optimization step.
            epochs: Number of passes over the data.

        Returns:
            A dict with the lists ``'loss'`` (mean batch loss), ``'memory'``
            (the profiler's ``'peak_usage'``) and ``'time'`` (seconds per
            epoch), each holding one entry per epoch.
        """
        history = {'loss': [], 'memory': [], 'time': []}
        n_batches = len(X) // batch_size

        # Loop-invariant: build the loss/gradient function once, not per batch.
        loss_and_grad_fn = self.train_step()

        for epoch in range(epochs):
            epoch_start = time()
            epoch_loss = 0

            # Shuffle data
            indices = mx.random.permutation(len(X))
            X = X[indices]
            y = y[indices]

            for i in range(n_batches):
                start_idx = i * batch_size
                end_idx = start_idx + batch_size

                batch_x = X[start_idx:end_idx]
                batch_y = y[start_idx:end_idx]

                # Compute loss and gradients
                loss, grads = loss_and_grad_fn(self.model, batch_x, batch_y)

                # Update weights
                self.optimizer.update(self.model, grads)

                # MLX is lazy: force the parameter and optimizer-state updates
                # now so the pending graph does not grow across batches.
                mx.eval(self.model.parameters(), self.optimizer.state)

                epoch_loss += float(loss)

            # Record metrics
            history['loss'].append(epoch_loss / n_batches)
            history['time'].append(time() - epoch_start)

            # Profile memory
            profiler = MLXProfiler(self.model)
            memory_stats = profiler.profile_memory(batch_size=batch_size)
            history['memory'].append(memory_stats['peak_usage'])

            print(f"Epoch {epoch+1}/{epochs}, Loss: {history['loss'][-1]:.4f}")

        return history

# Random sample data: 1000 sequences of length 16 with 8 input features,
# paired with one target value per time step.
X = mx.random.normal((1000, 16, 8))
y = mx.random.normal((1000, 16, 1))

# Each entry is (label, hidden size, batch size).
configs = [
    ('Small', 32, 32),
    ('Medium', 64, 64),
    ('Large', 128, 128)
]

# Training history of every configuration, keyed by its label.
results = {}
for name, hidden_size, batch_size in configs:
    print(f"\nTraining {name} model...")

    # Assemble the model from the inside out: wiring -> cell -> CfC.
    wiring = AutoNCP(units=hidden_size, output_size=1)
    cell = CfCCell(
        wiring=wiring,
        backbone_units=[hidden_size],
        backbone_layers=1
    )
    model = CfC(cell=cell)

    # Train for five epochs and keep the recorded metrics.
    trainer = MemoryOptimizedTrainer(model)
    history = trainer.train(X, y, batch_size=batch_size, epochs=5)
    results[name] = history

# Plot results
plt.figure(figsize=(15, 5))

# One panel per recorded metric:
# (subplot position, history key, y-axis label, title, extra line options).
panels = [
    (131, 'loss', 'Loss', 'Training Loss', {}),
    (132, 'memory', 'Memory (MB)', 'Peak Memory Usage', {'marker': 'o'}),
    (133, 'time', 'Time (s)', 'Training Time per Epoch', {'marker': 'o'}),
]

for position, metric, y_label, title, line_options in panels:
    plt.subplot(position)
    # Draw one curve per trained configuration.
    for name, result in results.items():
        plt.plot(result[metric], label=name, **line_options)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

## 3. Hardware-Specific Recommendations

Based on our experiments:

1. **Neural Engine Optimization**
   - Use power-of-2 sizes for tensors
   - Enable MLX compilation
   - Batch sizes: 32-128 work best
   - Use static shapes when possible

2. **Memory Management**
   - Leverage unified memory
   - Clear unused variables
   - Use appropriate batch sizes
   - Monitor memory usage

3. **Performance Tips**
   - Profile your specific device
   - Use MLX's lazy evaluation
   - Enable operator fusion
   - Monitor hardware utilization

4. **Device-Specific Settings**
   - M1: 32-64 batch size
   - M1 Pro/Max: 64-128 batch size
   - M1 Ultra: 128-256 batch size
   - Adjust based on your model size