# Neural Circuit Policy Profiling Guide

This notebook demonstrates how to use the `ncps.mlx.profiling` tools to analyze and optimize neural circuit policies (NCPs):
- Memory usage analysis
- Performance profiling
- Connectivity analysis
- Optimization techniques

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.profiling import WiringProfiler, profile_wiring

## 1. Basic Profiling

Let's start by profiling a simple random wiring pattern:

In [None]:
# Build a 100-unit random wiring and a CfC model on top of it.
wiring = Random(units=100, sparsity_level=0.5)
model = CfC(wiring=wiring)

# One-shot profiling: returns performance stats and connectivity stats.
perf_stats, conn_stats = profile_wiring(
    wiring,
    model=model,
    batch_size=32,
    seq_length=10,
    num_runs=100,
)

print("Performance Statistics:")
for stat_name, stat_value in perf_stats.items():
    # Entries whose key mentions 'time' are scaled by 1000 and shown as ms
    # (assumes the profiler reports them in seconds — TODO confirm).
    if 'time' not in stat_name:
        print(f"{stat_name}: {stat_value:.2f}")
        continue
    print(f"{stat_name}: {stat_value*1000:.2f} ms")

print("\nConnectivity Statistics:")
for stat_name, stat_value in conn_stats.items():
    print(f"{stat_name}: {stat_value}")

## 2. Detailed Analysis

For a more detailed analysis, use the `WiringProfiler` class directly. Here we profile several forward and backward passes and then plot the recorded history:

In [None]:
# Create a profiler for the section-1 wiring (the wiring `model` was built on).
profiler = WiringProfiler(wiring)

# Profile several forward and backward passes. profile_forward/profile_backward
# receive only the model and a batch size, so they presumably synthesize their
# own inputs (TODO confirm) — no input/target tensors are created here.
for _ in range(5):
    fwd_stats = profiler.profile_forward(model, batch_size=32)
    bwd_stats = profiler.profile_backward(model, batch_size=32)

# Plot the history the profiler recorded during the calls above.
profiler.plot_history()

## 3. Connectivity Analysis

Let's compare the connectivity of several wiring patterns:

In [None]:
def analyze_wiring(name, wiring):
    """Print a labelled WiringProfiler summary for one wiring pattern.

    Args:
        name: Label printed in the header line above the summary.
        wiring: Wiring object handed to ``WiringProfiler``.
    """
    wiring_profiler = WiringProfiler(wiring)
    print(f"\n{name} Analysis:")
    summary_text = wiring_profiler.summary()
    print(summary_text)

# Compare several wiring patterns. Each has 100 neurons in total
# (for the NCP: 50 inter + 30 command + 20 motor neurons).
wirings = {
    'Random (Dense)': Random(units=100, sparsity_level=0.2),
    'Random (Sparse)': Random(units=100, sparsity_level=0.8),
    'NCP': NCP(
        inter_neurons=50,
        command_neurons=30,
        motor_neurons=20,
        sensory_fanout=5,
        inter_fanout=5,
        recurrent_command_synapses=10,
        motor_fanin=5
    ),
    'AutoNCP': AutoNCP(units=100, output_size=20, sparsity_level=0.5)
}

# Use a distinct loop name so the notebook-level `wiring` (on which `model`
# and `profiler` were built earlier) is not silently rebound to the last pattern.
for name, pattern in wirings.items():
    analyze_wiring(name, pattern)

## 4. Performance Optimization

Let's explore how sparsity level and batch size affect forward-pass time:

In [None]:
def benchmark_wiring(wiring, batch_sizes=(1, 16, 32, 64), num_runs=50):
    """Time the forward pass of a CfC model built on *wiring* at several batch sizes.

    Args:
        wiring: Wiring pattern used to build a fresh ``CfC`` model and profiler.
        batch_sizes: Iterable of batch sizes to benchmark. A tuple default is
            used so no mutable default object is shared between calls; any
            iterable (e.g. a list) is accepted.
        num_runs: Number of forward runs per batch size, passed through to
            ``WiringProfiler.profile_forward`` (default 50).

    Returns:
        list[dict]: One ``{'batch_size': ..., 'time': ...}`` entry per batch
        size, in input order, where ``time`` is the ``'mean'`` entry returned by
        ``profile_forward``.
    """
    model = CfC(wiring=wiring)
    profiler = WiringProfiler(wiring)

    results = []
    for batch_size in batch_sizes:
        fwd_stats = profiler.profile_forward(
            model,
            batch_size=batch_size,
            num_runs=num_runs,
        )
        results.append({
            'batch_size': batch_size,
            'time': fwd_stats['mean'],
        })

    return results

# Benchmark random wirings at several sparsity levels. A dict comprehension is
# used so the notebook-level `wiring` from earlier sections is not rebound.
sparsities = [0.2, 0.5, 0.8]
results = {
    sparsity: benchmark_wiring(Random(units=100, sparsity_level=sparsity))
    for sparsity in sparsities
}

# Plot forward time against batch size, one line per sparsity level.
plt.figure(figsize=(10, 5))
for sparsity, data in results.items():
    batch_sizes = [d['batch_size'] for d in data]
    times = [d['time'] * 1000 for d in data]  # Convert to ms
    plt.plot(batch_sizes, times, marker='o', label=f'Sparsity {sparsity}')

plt.xlabel('Batch Size')
plt.ylabel('Forward Time (ms)')
plt.title('Performance vs Batch Size for Different Sparsity Levels')
plt.legend()
plt.grid(True)
plt.show()

## 5. Memory Analysis

Let's analyze how memory usage scales with the number of units:

In [None]:
def analyze_memory_scaling(sizes=(50, 100, 200, 400, 800)):
    """Measure, plot and quadratically fit memory usage of random wirings.

    For each size, a ``Random`` wiring with ``sparsity_level=0.5`` is built and
    measured with ``WiringProfiler._measure_memory()``. The measurements are
    plotted against size and a degree-2 polynomial is fitted to them.

    Args:
        sizes: Iterable of unit counts to measure. A tuple default avoids a
            shared mutable default argument. At least three distinct sizes are
            required for the quadratic fit; otherwise the fit is skipped.
    """
    sizes = list(sizes)
    memories = []

    for size in sizes:
        wiring = Random(units=size, sparsity_level=0.5)
        profiler = WiringProfiler(wiring)
        # NOTE(review): `_measure_memory` is a private WiringProfiler helper;
        # the 'MB' axis label below assumes that unit — confirm.
        memories.append(profiler._measure_memory())

    plt.figure(figsize=(10, 5))
    plt.plot(sizes, memories, marker='o')
    plt.xlabel('Number of Units')
    plt.ylabel('Memory Usage (MB)')
    plt.title('Memory Scaling with Network Size')
    plt.grid(True)
    plt.show()

    # A degree-2 fit needs at least 3 distinct x values: np.polyfit raises on
    # empty input and gives a rank-deficient fit with fewer points.
    if len(set(sizes)) < 3:
        print("Need at least 3 distinct sizes to fit a quadratic curve; skipping fit.")
        return

    # Fit quadratic curve; coeffs[0] is the n^2 coefficient of the fit.
    coeffs = np.polyfit(sizes, memories, 2)
    print(f"Memory scaling approximately O(n^2) with coefficient: {coeffs[0]:.2e}")

analyze_memory_scaling()

## 6. Optimization Guidelines

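Based on the analysis above, here are general guidelines for optimizing neural circuit policies. Treat them as starting points and confirm each one by profiling your own model on your own hardware: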

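1. **Memory Optimization**
   - Choose the sparsity level deliberately rather than relying on a default
   - Keep the number of units no larger than the task requires; Section 5 shows how memory grows with network size
   - Consider memory-compute tradeoffs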

2. **Performance Optimization**
   - Choose batch sizes based on your hardware, using benchmarks like the one in Section 4
   - Optimize network topology
   - Use appropriate sparsity patterns

3. **Connectivity Optimization**
   - Design task-specific wiring patterns
   - Balance local and global connections
   - Consider information flow paths

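4. **Training Optimization**
   - Use batch sizes that fit comfortably in memory on your hardware
   - Monitor memory usage during training
   - Profile critical operations, such as the forward and backward passes profiled in Section 2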