# Signal Processing with Neural Circuit Policies

This notebook demonstrates how to use wiring patterns for signal processing tasks:
- Filtering and denoising
- Frequency analysis
- Signal prediction
- Multi-scale decomposition

In [None]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Wiring

## 1. Multi-Scale Signal Processing

Create a wiring pattern for processing signals at multiple scales:

In [None]:
class SignalWiring(Wiring):
    """Wiring pattern for multi-band signal processing.

    Neuron layout (indices into the wiring's units):

    - ``0 .. output_size - 1``: output neurons (registered via
      ``set_output_dim``).
    - band ``i`` (for ``i in range(num_bands)``): the ``neurons_per_band``
      consecutive indices starting at ``output_size + i * neurons_per_band``.

    Connectivity (every synapse is added with polarity ``+1``):

    - Dense recurrent connections inside each band, without self-connections.
    - Sparse feed-forward connections from band ``i`` to band ``i + 1``: each
      neuron of band ``i`` projects to ``cross_band_fanout`` distinct, randomly
      chosen neurons of band ``i + 1``.
    - Every band neuron projects to every output neuron.

    Args:
        input_size: Number of input features. Accepted for API symmetry but
            not used: this class adds no sensory (input -> neuron) synapses.
        num_bands: Number of frequency bands.
        neurons_per_band: Number of neurons in each band.
        output_size: Number of output neurons.
        cross_band_fanout: Targets per neuron in the next band. It is capped
            at ``neurons_per_band``, so small bands no longer make
            ``np.random.choice`` raise. Defaults to 2, the former hard-coded
            value.
        seed: Optional seed for the cross-band sampling. ``None`` (the
            default) draws from NumPy's global random state, as before, so an
            earlier ``np.random.seed(...)`` still controls the wiring.
    """

    # NOTE(review): no sensory synapses are defined here. In the upstream
    # ncps package, wirings presumably add them in ``build(input_dim)`` via
    # ``add_sensory_synapse``; confirm how ncps.mlx routes inputs into this
    # wiring.

    def __init__(
        self,
        input_size: int,
        num_bands: int = 4,
        neurons_per_band: int = 16,
        output_size: int = 1,
        cross_band_fanout: int = 2,
        seed=None,
    ):
        total_units = num_bands * neurons_per_band + output_size
        super().__init__(total_units)

        # Store configuration
        self.num_bands = num_bands
        self.neurons_per_band = neurons_per_band
        self.output_size = output_size
        self.cross_band_fanout = cross_band_fanout

        # Output neurons occupy indices 0 .. output_size - 1.
        self.set_output_dim(output_size)

        # Band i occupies the neurons_per_band indices after the outputs.
        self.band_ranges = [
            range(
                output_size + i * neurons_per_band,
                output_size + (i + 1) * neurons_per_band
            )
            for i in range(num_bands)
        ]

        # The RNG is passed to the builder rather than stored on self, so the
        # wiring object stays free of non-serialisable state.
        rng = np.random if seed is None else np.random.default_rng(seed)

        # Build connectivity
        self._build_band_connections()
        self._build_cross_band_connections(rng)
        self._build_output_connections()

    def _build_band_connections(self):
        """Fully connect the neurons inside each band (no self-connections)."""
        for band_range in self.band_ranges:
            for src in band_range:
                for dest in band_range:
                    if src != dest:  # No self-connections
                        self.add_synapse(src, dest, 1)

    def _build_cross_band_connections(self, rng=None):
        """Sparsely connect each band to the next one.

        Args:
            rng: Object exposing ``choice(a, size, replace)``: either the
                ``np.random`` module or a ``np.random.Generator``. ``None``
                uses ``np.random``.
        """
        if rng is None:
            rng = np.random
        for current_band, next_band in zip(self.band_ranges, self.band_ranges[1:]):
            targets = list(next_band)
            # Sampling without replacement cannot draw more items than exist.
            fanout = min(self.cross_band_fanout, len(targets))
            for src in current_band:
                for dest in rng.choice(targets, size=fanout, replace=False):
                    # Convert numpy integer to a plain int index.
                    self.add_synapse(src, int(dest), 1)

    def _build_output_connections(self):
        """Connect every band neuron to every output neuron."""
        output_range = range(self.output_size)
        for band_range in self.band_ranges:
            for src in band_range:
                for dest in output_range:
                    self.add_synapse(src, dest, 1)

# Assemble the multi-band wiring: 1 input feature, 4 bands of 16 neurons
# each, and a single output neuron.
wiring = SignalWiring(input_size=1, num_bands=4, neurons_per_band=16, output_size=1)

# Wrap the wiring in a CfC recurrent model with tanh activation.
model = CfC(wiring=wiring, activation="tanh")

## 2. Signal Filtering

Train the model to filter out noise from signals:

In [None]:
def generate_noisy_signals(n_samples=1000, seq_length=100, noise_std=0.2):
    """Generate pairs of noisy and clean sine-wave signals.

    Each clean signal is ``amplitude * sin(2*pi*freq*t + phase)`` sampled at
    ``seq_length`` evenly spaced points on ``[0, 10]``, with
    ``freq ~ N(1.0, 0.1^2)``, ``amplitude ~ N(0.5, 0.1^2)`` and
    ``phase ~ U[0, 2*pi)``. The noisy version adds Gaussian noise with
    standard deviation ``noise_std``. Random draws come from NumPy's global
    random state, in the same order as before.

    Args:
        n_samples: Number of signals to generate.
        seq_length: Number of time steps per signal.
        noise_std: Standard deviation of the additive Gaussian noise. The
            default 0.2 is the previously hard-coded value.

    Returns:
        Tuple ``(X, y)`` of ``mx.array``, each of shape
        ``(n_samples, seq_length, 1)``: the noisy inputs and the clean
        targets.
    """
    t = np.linspace(0, 10, seq_length)

    X = np.zeros((n_samples, seq_length, 1))
    y = np.zeros((n_samples, seq_length, 1))

    for i in range(n_samples):
        # Randomised clean sine wave
        freq = 1.0 + 0.1 * np.random.randn()
        phase = 2 * np.pi * np.random.rand()
        amplitude = 0.5 + 0.1 * np.random.randn()
        clean = amplitude * np.sin(2 * np.pi * freq * t + phase)

        # Additive white Gaussian noise
        noise = np.random.normal(0, noise_std, seq_length)

        X[i, :, 0] = clean + noise
        y[i, :, 0] = clean

    return mx.array(X), mx.array(y)

# Generate training (1000 signals) and test (100 signals) data
X_train, y_train = generate_noisy_signals()
X_test, y_test = generate_noisy_signals(n_samples=100)

# Train model. MLX optimizers live in mlx.optimizers; mlx.nn has no Adam.
optimizer = optim.Adam(learning_rate=0.001)

def train_step(model, x, y):
    """Run one full-batch gradient step on ``model`` and return the loss.

    The loss is the mean squared error between ``model(x)`` and ``y``.
    Gradients are applied with the module-level ``optimizer``; rebind that
    global to switch optimizers between experiments.

    Args:
        model: MLX module to train (updated in place).
        x: Input batch.
        y: Target batch, broadcast-compatible with ``model(x)``.

    Returns:
        The scalar loss as an (already evaluated) ``mx.array``.
    """
    def loss_fn(model, x, y):
        # NOTE(review): assumes model(x) returns only the prediction tensor.
        # If this CfC returns (output, state), as in the upstream ncps API,
        # unpack it here. Confirm against ncps.mlx.
        pred = model(x)
        return mx.mean((pred - y) ** 2)

    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    # MLX is lazy: evaluate the updated parameters and optimizer state so the
    # compute graph does not keep growing from one step to the next.
    mx.eval(model.parameters(), optimizer.state, loss)
    return loss

# Full-batch training for 100 epochs, logging every 10th epoch.
losses = []
for epoch in range(100):
    loss = train_step(model, X_train, y_train)
    loss_value = float(loss)
    losses.append(loss_value)
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

# Side-by-side figure: loss curve (left) and one filtered test signal (right).
fig, (ax_loss, ax_filter) = plt.subplots(1, 2, figsize=(15, 5))

ax_loss.plot(losses)
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Training Loss')
ax_loss.grid(True)

predictions = model(X_test)
ax_filter.plot(X_test[0, :, 0], 'gray', alpha=0.5, label='Noisy')
ax_filter.plot(y_test[0, :, 0], 'b', label='Clean')
ax_filter.plot(predictions[0, :, 0], 'r--', label='Filtered')
ax_filter.set_xlabel('Time Step')
ax_filter.set_ylabel('Amplitude')
ax_filter.set_title('Signal Filtering')
ax_filter.legend()
ax_filter.grid(True)

fig.tight_layout()
plt.show()

## 3. Frequency Analysis

Train the model to decompose signals into frequency components:

In [None]:
class FrequencyWiring(Wiring):
    """Wiring pattern for separating a signal into frequency components.

    Neuron layout (indices into the wiring's units):

    - ``0 .. num_freqs - 1``: output neurons. Output ``i`` is the readout for
      frequency component ``i``.
    - group ``i`` (for ``i in range(num_freqs)``): the ``freq_neurons``
      consecutive indices starting at ``num_freqs + i * freq_neurons``,
      dedicated to component ``i``.

    Connectivity (every synapse is added with polarity ``+1``):

    - Dense recurrent connections inside each group (no self-connections).
      Every neuron of group ``i`` projects to output neuron ``i``.
    - "Harmonic" links between adjacent groups ``i`` and ``i + 1``: neuron
      ``k`` of one group connects to neuron ``k`` of the other, in both
      directions.

    Args:
        input_size: Number of input features. Accepted for API symmetry but
            not used: this class adds no sensory (input -> neuron) synapses.
        freq_neurons: Number of neurons per frequency group.
        num_freqs: Number of frequency components (and output neurons).
    """

    def __init__(
        self,
        input_size: int,
        freq_neurons: int = 32,
        num_freqs: int = 4
    ):
        # Dedicated output neurons first, then one group per frequency. The
        # previous layout made outputs 0..num_freqs-1 overlap group 0.
        total_units = num_freqs + freq_neurons * num_freqs
        super().__init__(total_units)

        # Store configuration
        self.freq_neurons = freq_neurons
        self.num_freqs = num_freqs

        # One output per frequency component
        self.set_output_dim(num_freqs)

        self.freq_ranges = [
            range(num_freqs + i * freq_neurons, num_freqs + (i + 1) * freq_neurons)
            for i in range(num_freqs)
        ]

        # Build frequency-specific connectivity
        self._build_freq_connections()
        self._build_harmonic_connections()

    def _build_freq_connections(self):
        """Densely connect each group and route it to its own output neuron."""
        for out_idx, group in enumerate(self.freq_ranges):
            for src in group:
                for dest in group:
                    if src != dest:  # No self-connections
                        self.add_synapse(src, dest, 1)
                self.add_synapse(src, out_idx, 1)

    def _build_harmonic_connections(self):
        """Link neuron k of each group to neuron k of the next group, both ways."""
        for lower, upper in zip(self.freq_ranges, self.freq_ranges[1:]):
            for a, b in zip(lower, upper):
                self.add_synapse(a, b, 1)
                self.add_synapse(b, a, 1)

def generate_mixed_signals(
    n_samples=1000,
    seq_length=100,
    freqs=(1, 2, 4, 8),
    sample_rate=None,
):
    """Generate sums of sine waves together with their separate components.

    For each sample, every frequency ``f`` in ``freqs`` contributes a sine
    wave with amplitude ``~ N(0.5, 0.1^2)`` and phase ``~ U[0, 2*pi)``. The
    input is the sum of the components; the target holds each component in
    its own channel.

    Signals are sampled at ``sample_rate`` Hz on the time axis
    ``t = arange(seq_length) / sample_rate`` (seconds). The previous axis,
    ``linspace(0, 10, 100)``, sampled at ~9.9 Hz. Its Nyquist limit
    (~4.95 Hz) is below 8 Hz, so the 8 Hz component aliased to ~1.9 Hz and
    was nearly indistinguishable from the 2 Hz one.

    Args:
        n_samples: Number of signals to generate.
        seq_length: Number of time steps per signal.
        freqs: Component frequencies in Hz (default 1, 2, 4, 8).
        sample_rate: Sampling rate in Hz. Defaults to ``4 * max(freqs)``,
            which is twice the Nyquist minimum.

    Returns:
        Tuple ``(X, y)`` of ``mx.array``:
        - ``X`` has shape ``(n_samples, seq_length, 1)`` and holds the mixed
          signal.
        - ``y`` has shape ``(n_samples, seq_length, len(freqs))`` and holds
          one component per channel.

    Raises:
        ValueError: If ``freqs`` is empty, or if ``sample_rate`` does not
            exceed twice the highest frequency (the components would alias).
    """
    freqs = list(freqs)
    if not freqs:
        raise ValueError("freqs must contain at least one frequency")
    max_freq = max(freqs)
    if sample_rate is None:
        sample_rate = 4.0 * max_freq
    if sample_rate <= 2 * max_freq:
        raise ValueError(
            f"sample_rate={sample_rate} Hz must exceed twice the highest "
            f"frequency ({max_freq} Hz) to avoid aliasing"
        )
    t = np.arange(seq_length) / sample_rate

    X = np.zeros((n_samples, seq_length, 1))
    y = np.zeros((n_samples, seq_length, len(freqs)))

    for i in range(n_samples):
        for j, freq in enumerate(freqs):
            amp = 0.5 + 0.1 * np.random.randn()
            phase = 2 * np.pi * np.random.rand()
            y[i, :, j] = amp * np.sin(2 * np.pi * freq * t + phase)

        # The mixed input is the sum of this sample's components.
        X[i, :, 0] = y[i].sum(axis=-1)

    return mx.array(X), mx.array(y)

# Component frequencies (Hz) mixed by generate_mixed_signals. Used below to
# size the model output and to label the plots.
component_freqs = [1, 2, 4, 8]

# Create frequency analysis model
freq_wiring = FrequencyWiring(
    input_size=1,
    freq_neurons=32,
    num_freqs=len(component_freqs)
)

freq_model = CfC(
    wiring=freq_wiring,
    activation="tanh"
)

# Generate data
X_train, y_train = generate_mixed_signals()
X_test, y_test = generate_mixed_signals(n_samples=100)

# Train model with a fresh optimizer. MLX optimizers live in mlx.optimizers;
# train_step reads this module-level `optimizer` name.
optimizer = optim.Adam(learning_rate=0.001)
losses = []

for epoch in range(100):
    loss = train_step(freq_model, X_train, y_train)
    losses.append(float(loss))

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {float(loss):.4f}")

# Plot results
plt.figure(figsize=(15, 10))

# Plot training loss
plt.subplot(221)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)

# Plot mixed signal
plt.subplot(222)
plt.plot(X_test[0, :, 0])
plt.xlabel('Time Step')
plt.ylabel('Amplitude')
plt.title('Mixed Signal')
plt.grid(True)

# Plot frequency components; label each channel with its actual frequency
# (the old f'{i+1}Hz' labels read 1, 2, 3, 4 Hz instead of 1, 2, 4, 8 Hz).
predictions = freq_model(X_test)
plt.subplot(223)
for i, freq in enumerate(component_freqs):
    plt.plot(y_test[0, :, i], label=f'True {freq}Hz')
plt.xlabel('Time Step')
plt.ylabel('Amplitude')
plt.title('True Components')
plt.legend()
plt.grid(True)

plt.subplot(224)
for i, freq in enumerate(component_freqs):
    plt.plot(predictions[0, :, i], '--', label=f'Pred {freq}Hz')
plt.xlabel('Time Step')
plt.ylabel('Amplitude')
plt.title('Predicted Components')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

## Analysis

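The signal processing wiring patterns are designed to provide the advantages below. This notebook trains on only a single noise level (σ = 0.2) and one fixed set of component frequencies (1, 2, 4, 8 Hz). Points such as adapting to different noise levels or handling overlapping frequencies are therefore design goals, not results measured here: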

1. **Filtering Performance**
   - Effectively removes noise
   - Preserves signal structure
   - Adapts to different noise levels

2. **Frequency Analysis**
   - Separates frequency components
   - Handles overlapping frequencies
   - Maintains phase relationships

3. **Architecture Benefits**
   - Multi-scale processing
   - Frequency-specific neurons
   - Cross-frequency interactions

Key considerations for signal processing:
- Balance between smoothing and detail preservation
- Handle varying signal-to-noise ratios
- Maintain temporal coherence
- Efficient real-time processing