# Computer Vision with Neural Circuit Policies

This notebook demonstrates how to use wiring patterns for computer vision tasks:
- Feature extraction
- Image classification
- Object detection
- Visual attention

In [None]:
from typing import List

import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Wiring

## 1. Visual Processing Wiring

Create a wiring pattern for visual processing with local receptive fields:

In [None]:
class VisionWiring(Wiring):
    """Wiring pattern for visual processing.

    The units form a stack of feature maps, one per entry in ``channels``.
    Layer 0 has the input resolution and every following layer is
    downsampled by ``stride``. Units are numbered layer by layer; inside a
    layer the unit for ``(row, col, channel)`` sits at
    ``layer_offset + (row * width + col) * n_channels + channel``.

    Architecture:
    - Local receptive fields
    - Feature hierarchies
    - Skip connections

    Note that both connection builders only wire channel 0 of each spatial
    position (the index arithmetic never adds a channel offset).
    """

    def __init__(
        self,
        input_height: int,
        input_width: int,
        channels: List[int],
        kernel_size: int = 3,
        stride: int = 2
    ):
        """Build the feature-map layout and its synapses.

        Args:
            input_height: Height of the first feature map.
            input_width: Width of the first feature map.
            channels: Channel count of every layer, first layer first.
            kernel_size: Controls the receptive-field window in the next
                layer (``kernel_size // 2`` cells on each side).
            stride: Downsampling factor between consecutive layers.

        Raises:
            ValueError: If ``channels`` is empty or any size argument is
                smaller than 1.
        """
        # Fail early with a clear message instead of an IndexError /
        # ZeroDivisionError deep inside the layout code.
        if len(channels) == 0:
            raise ValueError("channels must contain at least one layer")
        if input_height < 1 or input_width < 1:
            raise ValueError("input_height and input_width must be >= 1")
        if kernel_size < 1:
            raise ValueError("kernel_size must be >= 1")
        if stride < 1:
            raise ValueError("stride must be >= 1")

        # Calculate total units needed (must be known before the base
        # class is initialised)
        feature_maps = self._get_feature_maps(
            input_height,
            input_width,
            channels,
            stride
        )
        total_units = sum(h * w * c for h, w, c in feature_maps)

        super().__init__(total_units)

        # Store configuration
        self.input_height = input_height
        self.input_width = input_width
        self.channels = channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.feature_maps = feature_maps

        # Set output dimension to the channel count of the last feature map.
        # NOTE(review): which unit indices the base class treats as outputs
        # is not visible here -- confirm they fall in the last feature map.
        self.set_output_dim(channels[-1])

        # Build connectivity
        self._build_local_connections()
        self._build_skip_connections()

    def _get_feature_maps(self, h, w, channels, stride):
        """Return one ``(height, width, channels)`` tuple per layer.

        The first layer keeps the given size; each later layer is
        ``ceil(size / stride)`` of the previous one.
        """
        maps = []
        for c in channels:
            maps.append((h, w, c))
            h = (h - 1) // stride + 1
            w = (w - 1) // stride + 1
        return maps

    def _layer_offsets(self):
        """Return the index of the first unit of every layer."""
        offsets = []
        start = 0
        for h, w, c in self.feature_maps:
            offsets.append(start)
            start += h * w * c
        return offsets

    def _get_receptive_field(self, h, w, layer):
        """Return the ``(row, col)`` cells around ``(h, w)`` in ``layer``.

        The window spans ``kernel_size // 2`` cells on each side of the
        centre and is clipped to the feature-map bounds (so an even
        ``kernel_size`` yields a window of ``kernel_size + 1`` cells).
        """
        half = self.kernel_size // 2
        layer_h, layer_w, _ = self.feature_maps[layer]
        rows = range(max(0, h - half), min(layer_h, h + half + 1))
        cols = range(max(0, w - half), min(layer_w, w + half + 1))
        return [(i, j) for i in rows for j in cols]

    def _build_local_connections(self):
        """Connect every layer to the next one through local windows.

        Source positions are sampled every ``stride`` cells; each one is
        connected to the receptive field around its downsampled position
        in the next layer.
        """
        offsets = self._layer_offsets()
        stride = self.stride
        for layer in range(len(self.feature_maps) - 1):
            h, w, c = self.feature_maps[layer]
            _, next_w, next_c = self.feature_maps[layer + 1]
            src_base = offsets[layer]
            dest_base = offsets[layer + 1]

            for i in range(0, h, stride):
                for j in range(0, w, stride):
                    src_idx = src_base + (i * w + j) * c

                    # Receptive field in the next layer, centred on the
                    # downsampled position of (i, j)
                    field = self._get_receptive_field(
                        i // stride, j // stride, layer + 1
                    )

                    for ni, nj in field:
                        dest_idx = dest_base + (ni * next_w + nj) * next_c
                        self.add_synapse(src_idx, dest_idx, 1)

    def _build_skip_connections(self):
        """Connect every layer to the layer two steps ahead.

        Positions are sampled every ``stride * 2`` cells, and each one is
        connected to the single matching cell two layers up.
        """
        offsets = self._layer_offsets()
        step = self.stride * 2
        for layer in range(len(self.feature_maps) - 2):
            h, w, c = self.feature_maps[layer]
            _, skip_w, skip_c = self.feature_maps[layer + 2]
            src_base = offsets[layer]
            skip_base = offsets[layer + 2]

            # Sparse skip connections
            for i in range(0, h, step):
                for j in range(0, w, step):
                    src_idx = src_base + (i * w + j) * c
                    dest_idx = skip_base + ((i // step) * skip_w + j // step) * skip_c
                    self.add_synapse(src_idx, dest_idx, 1)

# Create vision model
# Four-layer hierarchy over a 32x32 input: the spatial size halves at each
# layer (32 -> 16 -> 8 -> 4) while the channel count doubles.
# NOTE(review): this layout adds up to 122,880 units; if the base Wiring
# keeps a dense units x units adjacency matrix this will be very large --
# confirm before running on a laptop.
wiring = VisionWiring(
    input_height=32, input_width=32,
    channels=[64, 128, 256, 512],
    kernel_size=3, stride=2,
)

# Recurrent model that uses the wiring above for its connectivity
model = CfC(wiring=wiring, activation="relu")

## 2. Image Classification

Train the model for image classification:

In [None]:
def generate_image_data(n_samples=1000, size=32, n_classes=10):
    """Generate synthetic image data.

    Classes in the lower half (``class_id < n_classes // 2``) are drawn as
    a filled circle of random radius and colour centred in the image; the
    remaining classes are drawn as a grey-level sinusoidal texture with
    random frequency and phase.

    Args:
        n_samples: Number of images to generate.
        size: Height and width of each (square) image.
        n_classes: Number of classes; labels are one-hot encoded.

    Returns:
    - Images (with simple patterns), flattened to
      ``(n_samples, size * size, 3)``
    - Class labels, one-hot, ``(n_samples, n_classes)``
    """
    X = np.zeros((n_samples, size, size, 3))
    y = np.zeros((n_samples, n_classes))

    # Pixel coordinate grids, shared by every sample. Broadcasting a
    # (size, 1) column against a (1, size) row replaces the per-pixel
    # Python loops.
    rows = np.arange(size).reshape(size, 1)
    cols = np.arange(size).reshape(1, size)
    center = size // 2
    dist = np.sqrt((rows - center) ** 2 + (cols - center) ** 2)

    for i in range(n_samples):
        # Generate random class
        class_id = np.random.randint(n_classes)
        y[i, class_id] = 1

        # Generate pattern based on class. The random draws happen in the
        # same order as in a per-pixel implementation, so a seeded run
        # produces the same data.
        if class_id < n_classes // 2:
            # Geometric patterns: filled circle in a random colour
            radius = np.random.randint(5, 15)
            color = np.random.rand(3)
            X[i, dist < radius] = color
        else:
            # Textural patterns: same value written to all three channels
            freq = np.random.rand() * 10
            phase = np.random.rand() * 2 * np.pi
            val = np.sin(freq * rows / size + phase) * np.cos(freq * cols / size)
            X[i] = ((val + 1) / 2)[:, :, np.newaxis]

    # Reshape for model input (batch, time, features): every pixel becomes
    # one time step with its 3 colour values as features
    X = X.reshape(n_samples, size * size, 3)

    return mx.array(X), mx.array(y)

# Generate data
X_train, y_train = generate_image_data()
X_test, y_test = generate_image_data(n_samples=100)

# Train model
# MLX optimizers live in mlx.optimizers (imported as `optim`), not in mlx.nn.
optimizer = optim.Adam(learning_rate=0.001)

def train_step(model, x, y):
    """Run a single training step and return the loss.

    Computes the mean squared error between the model output at the last
    time step and ``y``, then applies one update of the module-level
    ``optimizer`` to ``model``.

    Args:
        model: The model to train; called as ``model(x)``.
        x: Input batch (assumed ``(batch, time, features)`` -- the last
            time step is selected with ``[:, -1]``).
        y: Target batch, compared element-wise with the final output.

    Returns:
        The loss value computed before the update.
    """
    def loss_fn(model, x, y):
        # Get final output
        pred = model(x)[:, -1]
        return mx.mean((pred - y) ** 2)

    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    # MLX evaluates lazily: materialize the new parameters and optimizer
    # state now so the pending computation graph does not keep growing
    # across steps.
    mx.eval(model.parameters(), optimizer.state)
    return loss

# Training loop: 100 full-batch steps, recording the loss of every step
losses = []
for epoch in range(100):
    loss = train_step(model, X_train, y_train)
    losses.append(float(loss))

    # Report progress on every tenth epoch only
    if epoch % 10 != 0:
        continue
    print(f"Epoch {epoch}, Loss: {float(loss):.4f}")

# Plot results: loss curve on the left, true vs. predicted class on the right
fig, (loss_ax, pred_ax) = plt.subplots(1, 2, figsize=(15, 5))

# Left panel: training loss per epoch
loss_ax.plot(losses)
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Training Loss')
loss_ax.grid(True)

# Right panel: predicted class against true class for the test set.
# Points on the dashed diagonal are correct predictions.
predictions = model(X_test)[:, -1]
true_classes = mx.argmax(y_test, axis=1)
predicted_classes = mx.argmax(predictions, axis=1)
pred_ax.scatter(true_classes, predicted_classes, alpha=0.5)
pred_ax.plot([0, 9], [0, 9], 'r--')
pred_ax.set_xlabel('True Class')
pred_ax.set_ylabel('Predicted Class')
pred_ax.set_title('Classification Results')
pred_ax.grid(True)

plt.tight_layout()
plt.show()

## 3. Feature Visualization

Visualize learned features and receptive fields:

In [None]:
def visualize_receptive_fields(model, layer_idx=0):
    """Visualize receptive fields of neurons.

    The model is probed with one image per pixel position, in which only
    that pixel is set to 1 on all three channels. The model output at the
    last time step is recorded for every probe, which gives, for each
    output neuron, a map of its response over pixel positions. The maps of
    the first 16 output neurons are plotted in a 4x4 grid.

    Args:
        model: Model exposing ``model.cell.wiring.feature_maps``; called
            with an input of shape ``(1, h * w, 3)``.
        layer_idx: Feature map whose height and width define the probe
            grid. The probes are fed to the model as input images, so this
            presumably only makes sense for the input-resolution layer (0)
            -- TODO confirm.

    Note:
        Runs ``h * w`` forward passes, each over ``h * w`` time steps.
    """
    # Get layer information
    h, w, _ = model.cell.wiring.feature_maps[layer_idx]
    n_pixels = h * w

    # Probe with a single active pixel at a time. The probe is built in
    # NumPy and converted afterwards, which avoids relying on in-place
    # update syntax of the array library.
    responses = []
    for pixel in range(n_pixels):
        stimulus = np.zeros((1, n_pixels, 3), dtype=np.float32)
        stimulus[0, pixel] = 1.0

        # Response at the last time step, flattened to one value per output
        response = model(mx.array(stimulus))
        responses.append(np.array(response[0, -1]).reshape(-1))

    # fields[i, j, n] is the response of output n to the pixel at (i, j)
    fields = np.stack(responses).reshape(h, w, -1)

    # Plot receptive fields (at most 16, fewer if the model has fewer outputs)
    n_shown = min(16, fields.shape[-1])
    plt.figure(figsize=(15, 15))
    for idx in range(n_shown):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(fields[:, :, idx])
        plt.axis('off')
        plt.title(f'Neuron {idx}')

    plt.suptitle('Receptive Fields')
    plt.tight_layout()
    plt.show()

# Visualize receptive fields on the input-resolution grid (layer 0)
visualize_receptive_fields(model, layer_idx=0)

## Analysis

The vision wiring pattern demonstrates several advantages:

1. **Feature Learning**
   - Hierarchical feature extraction
   - Local receptive fields
   - Multi-scale processing

2. **Classification Performance**
   - Good pattern recognition
   - Scale invariance
   - Translation invariance

3. **Architecture Benefits**
   - Efficient parameter sharing
   - Skip connections
   - Natural feature hierarchy

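Key considerations for vision tasks:
- Balance receptive field size: a larger `kernel_size` gives each unit more spatial context but adds more synapses
- Manage feature map sizes: every layer contributes height × width × channels units, so the wiring grows quickly with input resolution and channel count
- Handle spatial relationships: images are flattened to a sequence of pixels before being fed to the model, so keep track of how (row, column) positions map to sequence and unit indices
- Keep computation efficient: prefer strided, sparse connections over dense ones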