# Neural Circuit Policies: Performance Benchmarks

This notebook compares the performance of our MLX implementation against other frameworks (PyTorch, TensorFlow) across different scenarios:

- Training speed
- Inference latency
- Memory usage
- Scaling with sequence length

In [1]:
# Install required packages if not present
try:
    import memory_profiler
except ImportError:
    %pip install memory_profiler

import mlx.core as mx
import mlx.nn as nn
import torch
import tensorflow as tf
import numpy as np
import time
import matplotlib.pyplot as plt

try:
    from memory_profiler import profile
except ImportError:
    print("Warning: memory_profiler not available. Memory profiling will be disabled.")
    def profile(func):
        return func

from ncps.mlx import CfC as MLXCfC
from ncps.torch import CfC as TorchCfC
from ncps.tf import CfC as TFCfC

AttributeError: module 'keras._tf_keras.keras.layers' has no attribute 'AbstractRNNCell'

## Computation Efficiency

In [None]:
def benchmark_computation(model_sizes=[32, 64, 128, 256]):
    """Estimate forward-pass FLOPs and throughput of MLX CfC models.

    For every entry in ``model_sizes`` an ``AutoNCP`` wiring with that many
    units (and ``size // 4`` outputs) is built, wrapped in a ``CfC`` model and
    passed to ``profile_wiring`` with a batch size of 64, a sequence length
    of 32 and 100 runs.

    Args:
        model_sizes: Iterable of unit counts to benchmark. It is only
            iterated over, never modified.

    Returns:
        A list of dicts, one per successfully benchmarked size, with keys
        ``'size'``, ``'flops'`` (estimated operation count for one batch) and
        ``'throughput'`` (``flops`` divided by the profiler's
        ``'forward_time'``, or 0 when that time is not positive). Sizes that
        raise are reported with ``print`` and omitted. If the ``ncps``
        imports fail, a message is printed and ``[]`` is returned.
    """
    try:
        # Imported lazily so a missing dependency degrades to an empty result
        # instead of breaking the whole notebook.
        from ncps.wirings import AutoNCP
        from ncps.mlx import CfC
        from ncps.mlx.profiling import profile_wiring

        n_batch, n_steps = 64, 32
        measurements = []

        for n_units in model_sizes:
            try:
                ncp_wiring = AutoNCP(units=n_units, output_size=n_units//4)
                cfc_model = CfC(wiring=ncp_wiring)

                timing, connectivity = profile_wiring(
                    wiring=ncp_wiring,
                    model=cfc_model,
                    batch_size=n_batch,
                    seq_length=n_steps,
                    num_runs=100,
                )

                # Approximation: every synapse costs one multiply and one add
                # per time step, for every step of every sequence in the batch.
                n_synapses = connectivity['avg_in_degree'] * ncp_wiring.units
                estimated_flops = n_synapses * 2 * n_steps * n_batch

                # Time unit is whatever profile_wiring reports -- TODO confirm.
                elapsed = timing['forward_time']
                if elapsed > 0:
                    rate = estimated_flops / elapsed
                else:
                    rate = 0
            except Exception as err:
                # Best effort: report the failing size and move on to the next.
                print(f"Error benchmarking size {n_units}: {str(err)}")
            else:
                measurements.append({
                    'size': n_units,
                    'flops': estimated_flops,
                    'throughput': rate,
                })

        return measurements

    except ImportError as err:
        print(f"Missing required imports: {str(err)}")
        return []

# Run computation benchmark
comp_results = benchmark_computation()

# Plot results: estimated FLOPs (left) and throughput (right) against model size.
comp_fig, (flops_ax, throughput_ax) = plt.subplots(1, 2, figsize=(10, 5))
comp_sizes = [entry['size'] for entry in comp_results]

# Left panel: how the estimated operation count grows with the number of units.
flops_ax.plot(comp_sizes, [entry['flops'] for entry in comp_results], marker='o')
flops_ax.set_xlabel('Model Size')
flops_ax.set_ylabel('FLOPS')
flops_ax.set_title('Computational Complexity')
flops_ax.grid(True)

# Right panel: estimated operations divided by the measured forward time.
throughput_ax.plot(comp_sizes, [entry['throughput'] for entry in comp_results], marker='o')
throughput_ax.set_xlabel('Model Size')
throughput_ax.set_ylabel('Throughput (FLOPS/s)')
throughput_ax.set_title('Computational Efficiency')
throughput_ax.grid(True)

comp_fig.tight_layout()
plt.show()