# Advanced Profiling Guide for Neural Circuit Policies

This notebook demonstrates how to use advanced profiling tools with MLX integration:
- Compute profiling
- Memory profiling
- Stream profiling
- Performance optimization

In [None]:
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.advanced_profiling import MLXProfiler, quick_profile

## 1. Quick Profiling

Let's start with a quick overview of model performance:

In [None]:
# Profiling subject: a CfC cell on a 100-unit Random wiring
wiring = Random(units=100, sparsity_level=0.5)
model = CfC(wiring=wiring)

# A single call returns the 'compute', 'memory' and 'stream' statistics
stats = quick_profile(model, batch_size=32, seq_length=10, num_runs=100)

# Entries whose key contains 'time' are scaled by 1000 and labelled as
# milliseconds (i.e. treated as seconds); all other entries are printed
# with two decimals.
print("Compute Performance:")
for key, value in stats['compute'].items():
    print(f"{key}: {value*1000:.2f} ms" if 'time' in key else f"{key}: {value:.2f}")

# Every memory entry is printed with an MB suffix
print("\nMemory Usage:")
for key, value in stats['memory'].items():
    print(f"{key}: {value:.2f} MB")

# Stream entries: timings get the same ms treatment, everything else is
# printed unformatted.
print("\nStream Operations:")
for key, value in stats['stream'].items():
    print(f"{key}: {value*1000:.2f} ms" if 'time' in key else f"{key}: {value}")

## 2. Detailed Compute Analysis

Let's analyze computational performance in detail:

In [None]:
# Wrap the model from the previous cell in a reusable profiler
profiler = MLXProfiler(model)

# Sweep over batch sizes, recording mean time and TFLOPS for each one
batch_sizes = [1, 16, 32, 64, 128]
compute_results = []

for batch_size in batch_sizes:
    stats = profiler.profile_compute(batch_size=batch_size, seq_length=10, num_runs=50)
    compute_results.append(
        dict(batch_size=batch_size, time=stats['time_mean'], tflops=stats['tflops'])
    )

# Extract the plotted series once so both panels share the same x values
swept_batch_sizes = [entry['batch_size'] for entry in compute_results]
times_ms = [entry['time']*1000 for entry in compute_results]
tflops_values = [entry['tflops'] for entry in compute_results]

plt.figure(figsize=(15, 5))

# Left panel: compute time (scaled by 1000, labelled ms) per batch size
plt.subplot(1, 2, 1)
plt.plot(swept_batch_sizes, times_ms, marker='o')
plt.xlabel('Batch Size')
plt.ylabel('Time (ms)')
plt.title('Compute Time vs Batch Size')
plt.grid(True)

# Right panel: reported TFLOPS per batch size
plt.subplot(1, 2, 2)
plt.plot(swept_batch_sizes, tflops_values, marker='o')
plt.xlabel('Batch Size')
plt.ylabel('TFLOPS')
plt.title('Compute Efficiency vs Batch Size')
plt.grid(True)

plt.tight_layout()
plt.show()

## 3. Memory Analysis

Let's examine memory usage patterns:

In [None]:
def analyze_memory_scaling(sizes=(50, 100, 200, 400)):
    """Measure how memory use changes as the network gets larger.

    For every entry in ``sizes`` a fresh ``CfC`` model is built on a
    ``Random`` wiring with that many units (``sparsity_level=0.5``) and
    profiled with ``MLXProfiler.profile_memory`` at ``batch_size=32`` and
    ``seq_length=10``.

    Args:
        sizes: Iterable of unit counts to profile. The default is an
            immutable tuple, so the default value cannot be altered
            between calls.

    Returns:
        A list with one dict per size, in input order, holding the keys
        ``'size'``, ``'peak'`` (taken from ``stats['peak_usage']``) and
        ``'allocated'`` (taken from ``stats['total_allocated']``). An
        empty ``sizes`` yields an empty list.
    """
    results = []

    for size in sizes:
        # A new model and profiler per size, so each measurement covers
        # exactly one network.
        size_wiring = Random(units=size, sparsity_level=0.5)
        size_model = CfC(wiring=size_wiring)
        size_profiler = MLXProfiler(size_model)

        stats = size_profiler.profile_memory(
            batch_size=32,
            seq_length=10,
        )

        results.append({
            'size': size,
            'peak': stats['peak_usage'],
            'allocated': stats['total_allocated'],
        })

    return results

# Run the size sweep with its default sizes; both panels read from it
memory_results = analyze_memory_scaling()

# Extract the plotted series once
network_sizes = [entry['size'] for entry in memory_results]
peak_memory = [entry['peak'] for entry in memory_results]
allocated_memory = [entry['allocated'] for entry in memory_results]

plt.figure(figsize=(15, 5))

# Left panel: peak usage per network size
plt.subplot(1, 2, 1)
plt.plot(network_sizes, peak_memory, marker='o')
plt.xlabel('Network Size')
plt.ylabel('Peak Memory (MB)')
plt.title('Peak Memory Usage')
plt.grid(True)

# Right panel: total allocation per network size
plt.subplot(1, 2, 2)
plt.plot(network_sizes, allocated_memory, marker='o')
plt.xlabel('Network Size')
plt.ylabel('Total Allocated (MB)')
plt.title('Total Memory Allocated')
plt.grid(True)

plt.tight_layout()
plt.show()

## 4. Stream Analysis

Let's analyze stream operations and data transfers:

In [None]:
def analyze_stream_operations(seq_lengths=(10, 20, 50, 100)):
    """Collect stream statistics for several sequence lengths.

    Each sequence length is passed to ``profiler.profile_stream`` with
    ``batch_size=32``. Note that ``profiler`` is not a parameter: it is
    looked up at call time in the enclosing (notebook) namespace, so the
    cell that creates it must have been run first.

    Args:
        seq_lengths: Iterable of sequence lengths to profile. The default
            is an immutable tuple, so the default value cannot be altered
            between calls.

    Returns:
        A list with one dict per sequence length, in input order, holding
        the keys ``'seq_length'``, ``'kernel_time'``, ``'memory_time'``
        and ``'num_kernels'`` (the last three copied from the profiler's
        result under the same keys). An empty ``seq_lengths`` yields an
        empty list.
    """
    results = []

    for seq_length in seq_lengths:
        stats = profiler.profile_stream(
            batch_size=32,
            seq_length=seq_length,
        )

        results.append({
            'seq_length': seq_length,
            'kernel_time': stats['kernel_time'],
            'memory_time': stats['memory_time'],
            'num_kernels': stats['num_kernels'],
        })

    return results

# Run the sequence-length sweep with its default lengths
stream_results = analyze_stream_operations()

# Extract the plotted series once; timings are scaled by 1000 for the ms axis
sequence_lengths = [entry['seq_length'] for entry in stream_results]
kernel_times_ms = [entry['kernel_time']*1000 for entry in stream_results]
memory_times_ms = [entry['memory_time']*1000 for entry in stream_results]
kernel_counts = [entry['num_kernels'] for entry in stream_results]

plt.figure(figsize=(15, 5))

# Left panel: kernel time and memory time on a shared axis
plt.subplot(1, 2, 1)
plt.plot(sequence_lengths, kernel_times_ms, marker='o', label='Kernel Time')
plt.plot(sequence_lengths, memory_times_ms, marker='o', label='Memory Time')
plt.xlabel('Sequence Length')
plt.ylabel('Time (ms)')
plt.title('Operation Times')
plt.legend()
plt.grid(True)

# Right panel: number of kernels per sequence length
plt.subplot(1, 2, 2)
plt.plot(sequence_lengths, kernel_counts, marker='o')
plt.xlabel('Sequence Length')
plt.ylabel('Number of Kernels')
plt.title('Kernel Count')
plt.grid(True)

plt.tight_layout()
plt.show()

## 5. Performance Optimization

Let's compare different wiring patterns and configurations:

In [None]:
def compare_wirings():
    """Profile one CfC model per wiring pattern and gather headline metrics.

    Four wirings are built up front: two ``Random`` wirings of 100 units
    (sparsity levels 0.2 and 0.8), an explicit ``NCP`` with 50 inter,
    30 command and 20 motor neurons, and an ``AutoNCP`` with 100 units and
    20 outputs. Each one is wrapped in a ``CfC`` model and passed to
    ``quick_profile`` with that function's default settings.

    Returns:
        A list of dicts, one per wiring in the order listed above, with
        the keys ``'name'``, ``'compute_time'``, ``'memory_usage'`` and
        ``'tflops'``.
    """
    # All wirings are constructed, in this order, before any model is profiled
    dense_random = Random(units=100, sparsity_level=0.2)
    sparse_random = Random(units=100, sparsity_level=0.8)
    explicit_ncp = NCP(
        inter_neurons=50,
        command_neurons=30,
        motor_neurons=20,
        sensory_fanout=5,
        inter_fanout=5,
        recurrent_command_synapses=10,
        motor_fanin=5,
    )
    auto_ncp = AutoNCP(units=100, output_size=20, sparsity_level=0.5)

    labelled_wirings = {
        'Random Dense': dense_random,
        'Random Sparse': sparse_random,
        'NCP': explicit_ncp,
        'AutoNCP': auto_ncp,
    }

    def profile_one(label, pattern):
        """Build a CfC on ``pattern``, profile it and keep three metrics."""
        candidate = CfC(wiring=pattern)
        profiled = quick_profile(candidate)
        return {
            'name': label,
            'compute_time': profiled['compute']['time_mean'],
            'memory_usage': profiled['memory']['peak_usage'],
            'tflops': profiled['compute']['tflops'],
        }

    return [profile_one(label, pattern) for label, pattern in labelled_wirings.items()]

# Profile every wiring pattern once; all three panels read from this list
comparison_results = compare_wirings()

wiring_names = [entry['name'] for entry in comparison_results]

# One (bar heights, y-axis label, title) triple per panel, left to right
bar_panels = [
    (
        [entry['compute_time']*1000 for entry in comparison_results],
        'Compute Time (ms)',
        'Computation Time',
    ),
    (
        [entry['memory_usage'] for entry in comparison_results],
        'Memory Usage (MB)',
        'Memory Usage',
    ),
    (
        [entry['tflops'] for entry in comparison_results],
        'TFLOPS',
        'Compute Efficiency',
    ),
]

plt.figure(figsize=(15, 5))

for panel_index, (heights, y_label, panel_title) in enumerate(bar_panels, start=1):
    plt.subplot(1, 3, panel_index)
    plt.bar(wiring_names, heights)
    # Rotate the wiring names so the longer labels do not overlap
    plt.xticks(rotation=45)
    plt.ylabel(y_label)
    plt.title(panel_title)

plt.tight_layout()
plt.show()

## Performance Insights

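The points below summarize the trends this analysis is designed to reveal; check them against the plots produced above, since exact results depend on your hardware and MLX version: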

1. **Compute Performance**
   - Larger batch sizes improve TFLOPS
   - Dense patterns are faster for small networks
   - Sparse patterns scale better

2. **Memory Usage**
   - Memory scales quadratically with network size (number of units)
   - Sparsity significantly reduces memory
   - AutoNCP provides good balance

3. **Stream Operations**
   - Kernel time dominates for large sequences
   - Memory transfers increase with size
   - Batch processing helps efficiency

4. **Optimization Tips**
   - Choose sparsity based on size
   - Optimize batch size for hardware
   - Consider sequence length impact
   - Monitor memory allocation