# Neural Circuit Policy Visualization Guide

This notebook demonstrates how to use visualization tools to analyze and optimize neural circuit policies:
- Wiring Pattern Analysis
- Performance Visualization
- Profiling Visualization
- Comparative Analysis across wiring patterns

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import matplotlib.pyplot as plt
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Random, NCP, AutoNCP
from ncps.mlx.advanced_profiling import MLXProfiler
from ncps.mlx.visualization import WiringVisualizer, PerformanceVisualizer, ProfileVisualizer, plot_comparison

## 1. Wiring Pattern Analysis

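Visualize the structure of four wiring patterns (random wirings at two sparsity levels, a hand-configured NCP, and an AutoNCP) using wiring, connectivity-matrix, degree and path-length plots: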

In [None]:
def analyze_wiring_pattern(wiring, name):
    """Print a heading for ``wiring`` and draw its four structural plots.

    The plots are drawn in this order: wiring diagram, connectivity matrix,
    degree distributions, path-length distribution. Each plot is preceded
    by its own printed sub-heading.

    Args:
        wiring: Wiring object handed to ``WiringVisualizer``.
        name: Human-readable label shown in the top heading.
    """
    print(f"\nAnalyzing {name}:")
    viz = WiringVisualizer(wiring)

    # (sub-heading, visualizer method name, keyword arguments).
    # Methods are looked up only when their turn comes, so the call order
    # matches a straight sequence of print/plot statements.
    plot_plan = (
        ("\nWiring Pattern:", "plot_wiring", {"figsize": (8, 8)}),
        ("\nConnectivity Matrix:", "plot_connectivity_matrix", {}),
        ("\nDegree Distributions:", "plot_degree_distribution", {}),
        ("\nPath Length Distribution:", "plot_path_lengths", {}),
    )
    for heading, method_name, options in plot_plan:
        print(heading)
        getattr(viz, method_name)(**options)

# Four 50-neuron wirings to compare: two random baselines that differ only in
# sparsity_level, a hand-configured NCP (25 inter + 15 command + 10 motor
# neurons), and an AutoNCP with 10 outputs.
wirings = [
    ("Random Dense", Random(units=50, sparsity_level=0.2)),
    ("Random Sparse", Random(units=50, sparsity_level=0.8)),
    (
        "NCP",
        NCP(
            inter_neurons=25,
            command_neurons=15,
            motor_neurons=10,
            sensory_fanout=3,
            inter_fanout=3,
            recurrent_command_synapses=5,
            motor_fanin=3,
        ),
    ),
    ("AutoNCP", AutoNCP(units=50, output_size=10, sparsity_level=0.5)),
]

# Draw the structural plots for each wiring, in list order.
for name, wiring in wirings:
    analyze_wiring_pattern(wiring=wiring, name=name)

## 2. Performance Visualization

Track and visualize performance metrics during training:

In [None]:
def train_and_visualize(model, name, X, y, num_epochs=50):
    """Train ``model`` with full-batch Adam on an MSE loss and plot its metrics.

    Each epoch performs one gradient step on all of ``X``/``y``. It then
    records the loss alongside profiler measurements in a
    ``PerformanceVisualizer``. After training, the metric curves and two
    pairwise correlation plots are drawn.

    Args:
        model: MLX module to train in place (a ``CfC`` in this notebook).
        name: Label used in the printed headings.
        X: Input batch passed directly to ``model``.
        y: Targets; the loss is ``mean((model(X) - y) ** 2)``.
        num_epochs: Number of optimizer steps (one per epoch).

    Returns:
        ``visualizer.history``, the metrics recorded by the
        ``PerformanceVisualizer``.
    """
    # Optimizers live in ``mlx.optimizers``; ``mlx.nn`` has no ``Adam``.
    optimizer = optim.Adam(learning_rate=0.001)
    visualizer = PerformanceVisualizer()
    profiler = MLXProfiler(model)

    def loss_fn(model, x, y):
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    # ``nn.value_and_grad`` differentiates w.r.t. the module's trainable
    # parameters; ``mx.value_and_grad`` expects a plain function, not a
    # module. Build the wrapper once, outside the training loop.
    loss_and_grad_fn = nn.value_and_grad(model, loss_fn)

    for epoch in range(num_epochs):
        # Training step
        loss, grads = loss_and_grad_fn(model, X, y)
        optimizer.update(model, grads)
        # MLX evaluates lazily: materialize the updated parameters and
        # optimizer state every step so the compute graph does not keep
        # growing across epochs.
        mx.eval(model.parameters(), optimizer.state)
        loss_value = float(loss)

        # Profile performance (runs every epoch, so it adds its own cost
        # to each training step).
        stats = profiler.profile_compute(batch_size=32)
        memory_stats = profiler.profile_memory()

        # Record metrics
        visualizer.add_metrics(
            loss=loss_value,
            memory=memory_stats['peak_usage'],
            time=stats['time_mean'],
            tflops=stats['tflops'],
        )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}: Loss = {loss_value:.4f}")

    # Plot metrics
    print(f"\n{name} Performance Metrics:")
    visualizer.plot_metrics(rolling_window=5)

    # Plot correlations
    print(f"\n{name} Metric Correlations:")
    visualizer.plot_correlation('loss', 'tflops')
    visualizer.plot_correlation('memory', 'time')

    return visualizer.history

# Random synthetic data. NOTE(review): this presumably uses a
# (batch, time, features) layout, i.e. 1000 sequences x 10 steps x 8
# features; confirm against CfC's expected input layout.
X = mx.random.normal((1000, 10, 8))
# NOTE(review): targets have a last dimension of 1, while the wirings above
# may emit more outputs (e.g. 10 motor neurons); the MSE then relies on
# broadcasting. Verify this is intended.
y = mx.random.normal((1000, 10, 1))

# Train one freshly built CfC per wiring and keep its metric history by name.
results = {}
for name, wiring in wirings:
    model = CfC(wiring=wiring)
    history = train_and_visualize(model=model, name=name, X=X, y=y)
    results[name] = history

# Overlay the recorded histories of all models.
print("\nComparison of Models:")
plot_comparison(results, metrics=['loss', 'tflops', 'memory'])

## 3. Profiling Visualization

Visualize detailed profiling results:

In [None]:
def profile_and_visualize(model, name):
    """Profile ``model`` at several batch sizes, then plot the profiles.

    For each batch size in 1, 16, 32 and 64, the compute, memory and stream
    profilers are run in that order. Compute profiling uses a sequence length
    of 10 and 50 runs. Afterwards the compute, memory and stream profile
    plots are drawn under a heading that carries ``name``.

    Args:
        model: Model handed to ``MLXProfiler``.
        name: Label printed above the plots.
    """
    profiler = MLXProfiler(model)
    # NOTE(review): the visualizer is built from the same profiler, so the
    # measurements below are presumably read from it when plotting. Confirm.
    visualizer = ProfileVisualizer(profiler)

    for bs in (1, 16, 32, 64):
        profiler.profile_compute(batch_size=bs, seq_length=10, num_runs=50)
        profiler.profile_memory(batch_size=bs)
        profiler.profile_stream(batch_size=bs)

    print(f"\n{name} Profiling Results:")

    # (sub-heading, ProfileVisualizer method name), drawn in this order.
    plot_plan = (
        ("\nCompute Profile:", "plot_compute_profile"),
        ("\nMemory Profile:", "plot_memory_profile"),
        ("\nStream Profile:", "plot_stream_profile"),
    )
    for heading, method_name in plot_plan:
        print(heading)
        getattr(visualizer, method_name)()

# Build a new CfC around each wiring and profile it, in list order.
for name, wiring in wirings:
    model = CfC(wiring=wiring)
    profile_and_visualize(model=model, name=name)

## Visualization Insights

Based on our analysis:

1. **Wiring Patterns**
   - Dense random wirings show high connectivity
   - Sparse random wirings reduce complexity
   - NCP provides structured connectivity
   - AutoNCP balances structure and sparsity

2. **Performance Metrics**
   - Training convergence varies
   - Memory usage correlates with density
   - Compute efficiency depends on batch size
   - Trade-offs between metrics are evident

3. **Profiling Results**
   - Batch size impacts efficiency
   - Memory patterns differ by architecture
   - Stream operations show optimization opportunities
   - Performance characteristics guide tuning

Recommendations:
- Use visualization tools regularly
- Monitor multiple metrics
- Consider trade-offs in design
- Optimize based on profiling