# Natural Language Processing with Neural Circuit Policies

This notebook demonstrates how to use custom Neural Circuit Policy (NCP) wiring patterns for NLP tasks:
- Text classification
- Sequence modeling
- Attention mechanisms
- Language modeling

In [None]:
import matplotlib.pyplot as plt
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from ncps.mlx import CfC, LTC
from ncps.mlx.wirings import Wiring

## 1. Language Model Wiring

Create a wiring pattern for language modeling with attention:

In [None]:
class LanguageWiring(Wiring):
    """Wiring pattern for language processing.

    The ``4 * hidden_size`` neurons are split into four contiguous groups of
    ``hidden_size`` neurons each, in this order: queries, keys, values and
    outputs.

    Intended architecture:
    - Token embeddings
    - Multi-head attention
    - Position-wise processing

    Connectivity actually built:
    - Within each head, every query neuron connects to every key neuron of
      the same head, and every key neuron to every value neuron of the same
      head.
    - Every value neuron connects to every output neuron.
    - Query neuron ``i`` connects to output neuron ``i`` and, for even ``i``,
      also to output neuron ``i + 1`` (skip connection) when that neuron
      exists.

    Args:
        hidden_size: Number of neurons in each of the four groups.
        num_heads: Number of attention heads; must evenly divide
            ``hidden_size``.
        vocab_size: Vocabulary size; passed to ``set_output_dim``.
        max_seq_length: Stored as ``self.max_seq_length``; not used by the
            wiring itself.

    Raises:
        ValueError: If ``num_heads`` is not positive, or ``hidden_size`` is
            not a positive multiple of ``num_heads``.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        vocab_size: int,
        max_seq_length: int = 512
    ):
        # Validate up front: a non-dividing head count would silently leave
        # the trailing Q/K/V neurons without any attention connections.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if hidden_size <= 0 or hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden_size ({hidden_size}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )

        # Size calculations
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads
        self.vocab_size = vocab_size
        self.max_seq_length = max_seq_length

        # Total units needed for Q, K, V projections and output
        total_units = hidden_size * 4
        super().__init__(total_units)

        # Define component ranges (contiguous, non-overlapping neuron indices)
        self.query_range = range(0, hidden_size)
        self.key_range = range(hidden_size, hidden_size * 2)
        self.value_range = range(hidden_size * 2, hidden_size * 3)
        self.output_range = range(hidden_size * 3, hidden_size * 4)

        # Set output dimension.
        # NOTE(review): vocab_size can exceed self.units (e.g. 10000 vs
        # 4 * 256 = 1024 in the example below); confirm that ncps.mlx Wiring
        # accepts an output_dim larger than the number of units.
        self.set_output_dim(vocab_size)

        # Build connectivity
        self._build_attention_connections()
        self._build_position_connections()

    def _build_attention_connections(self):
        """Build per-head Q->K and K->V synapses, then V->output synapses."""
        # The third add_synapse argument is presumably the synapse polarity
        # (+1), following the ncps convention -- confirm for ncps.mlx.
        for head in range(self.num_heads):
            offset = head * self.head_size

            q_start = self.query_range.start + offset
            q_end = q_start + self.head_size

            k_start = self.key_range.start + offset
            k_end = k_start + self.head_size

            v_start = self.value_range.start + offset
            v_end = v_start + self.head_size

            # Query-Key connections (dense within the head)
            for q in range(q_start, q_end):
                for k in range(k_start, k_end):
                    self.add_synapse(q, k, 1)

            # Key-Value connections (dense within the head)
            for k in range(k_start, k_end):
                for v in range(v_start, v_end):
                    self.add_synapse(k, v, 1)

            # Value-Output connections (each value neuron reaches every output)
            for v in range(v_start, v_end):
                for o in self.output_range:
                    self.add_synapse(v, o, 1)

    def _build_position_connections(self):
        """Build position-wise query->output synapses plus skip connections."""
        for i in range(self.hidden_size):
            # Connect query neuron i to its corresponding output neuron
            self.add_synapse(
                self.query_range.start + i, self.output_range.start + i, 1
            )

            # Skip connection from every even neuron to the next output
            # neuron. The bound check keeps the target inside output_range
            # when hidden_size is odd (otherwise i + 1 == hidden_size would
            # address a neuron past the last unit).
            if i % 2 == 0 and i + 1 < self.hidden_size:
                self.add_synapse(
                    self.query_range.start + i,
                    self.output_range.start + i + 1,
                    1,
                )

# Create language model: 256 hidden neurons per group, split into 8 heads
# of 256 // 8 = 32 neurons each, with a 10000-token output vocabulary.
wiring = LanguageWiring(hidden_size=256, num_heads=8, vocab_size=10000)

# CfC network whose connectivity comes from the custom wiring, using GELU.
model = CfC(wiring=wiring, activation="gelu")

## 2. Text Classification

Train the model on a synthetic text-classification task:

In [None]:
def generate_text_data(n_samples=1000, seq_length=50, vocab_size=1000, n_classes=5,
                       dtype=np.float32):
    """Generate a synthetic token-sequence classification dataset.

    Each sample gets a uniformly random class label, and its (otherwise
    uniformly random) token sequence is overwritten by a class pattern:

    - class 0: every even position holds one repeated token drawn from
      ``[0, vocab_size // 5)``;
    - class 1: an increasing run starting in ``[0, vocab_size // 2)``,
      clipped at ``vocab_size - 1``;
    - class 2: even positions drawn from the lower half of the vocabulary,
      odd positions from the upper half;
    - every other class: blocks of up to 5 tokens drawn from a normal
      distribution (std 2) around a random centre token, wrapped modulo
      ``vocab_size``.

    NOTE: every class >= 3 uses the same last pattern, so with the default
    ``n_classes=5`` classes 3 and 4 are indistinguishable by construction.

    Args:
        n_samples: Number of sequences to generate.
        seq_length: Number of tokens per sequence.
        vocab_size: Vocabulary size (width of the one-hot token encoding).
        n_classes: Number of class labels.
        dtype: NumPy dtype of both one-hot arrays before conversion to MLX.
            float32 halves the memory of NumPy's float64 default, which
            matters because the token tensor holds
            ``n_samples * seq_length * vocab_size`` elements.

    Returns:
        Tuple ``(X, y)`` of MLX arrays: ``X`` is the one-hot token tensor of
        shape ``(n_samples, seq_length, vocab_size)`` and ``y`` the one-hot
        labels of shape ``(n_samples, n_classes)``.
    """
    X = np.random.randint(0, vocab_size, (n_samples, seq_length))
    y = np.zeros((n_samples, n_classes), dtype=dtype)

    # Number of even / odd positions; X[i, ::2] has ceil(seq_length / 2)
    # slots and X[i, 1::2] has floor(seq_length / 2).
    n_even = (seq_length + 1) // 2
    n_odd = seq_length // 2

    # Generate patterns based on token distributions
    for i in range(n_samples):
        # Assign class based on token patterns
        pattern_type = np.random.randint(n_classes)
        y[i, pattern_type] = 1

        if pattern_type == 0:
            # Repeated tokens
            token = np.random.randint(0, vocab_size // 5)
            X[i, ::2] = token
        elif pattern_type == 1:
            # Increasing sequence
            start = np.random.randint(0, vocab_size // 2)
            X[i] = np.minimum(start + np.arange(seq_length), vocab_size - 1)
        elif pattern_type == 2:
            # Alternating high/low; draw exactly as many values as there are
            # slots, otherwise the assignment raises for even seq_length.
            X[i, ::2] = np.random.randint(0, vocab_size // 2, n_even)
            X[i, 1::2] = np.random.randint(vocab_size // 2, vocab_size, n_odd)
        else:
            # Random with local structure
            for j in range(0, seq_length, 5):
                token = np.random.randint(0, vocab_size)
                # The final block is shorter than 5 when seq_length % 5 != 0.
                block_len = min(5, seq_length - j)
                X[i, j:j + block_len] = (
                    np.random.normal(token, 2, block_len).astype(int) % vocab_size
                )

    # Convert to one-hot in one vectorized scatter instead of a Python loop
    # over every (sample, position) pair.
    X_onehot = np.zeros((n_samples, seq_length, vocab_size), dtype=dtype)
    np.put_along_axis(X_onehot, X[..., np.newaxis], 1, axis=-1)

    return mx.array(X_onehot), mx.array(y)

# Generate train/test splits; the one-hot token width must equal the
# wiring's vocabulary size (n_samples=1000 is the function's default).
# NOTE(review): with vocab_size=10000 the training tensor alone has
# 1000 * 50 * 10000 elements -- consider a smaller vocabulary for this demo.
X_train, y_train = generate_text_data(n_samples=1000, vocab_size=wiring.vocab_size)
X_test, y_test = generate_text_data(n_samples=100, vocab_size=wiring.vocab_size)

# Train model
# MLX optimizers live in mlx.optimizers (imported as optim), not in mlx.nn.
optimizer = optim.Adam(learning_rate=0.001)

def train_step(model, x, y):
    """Run one full-batch optimization step and return its loss.

    Uses the module-level ``optimizer``. Only the final time step of the
    model's output sequence is scored. Of that step, only the first
    ``y.shape[-1]`` output units are read as class scores, because the
    wiring's output dimension (the vocabulary size) is wider than the
    number of classes. Those scores are compared against the one-hot
    targets with mean squared error.

    Args:
        model: The sequence model being trained (updated in place).
        x: Input batch.
        y: One-hot targets of shape ``(batch, n_classes)``.

    Returns:
        The scalar loss array for this step.
    """
    n_classes = y.shape[-1]

    def loss_fn(model, x, y):
        # Get the last time step's output, restricted to the class units
        pred = model(x)[:, -1, :n_classes]
        return mx.mean((pred - y) ** 2)

    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    # MLX evaluates lazily: force the updated parameters and optimizer state
    # now so each step's update is computed rather than deferred.
    mx.eval(model.parameters(), optimizer.state)
    return loss

# Training loop: one full-batch step per epoch for 100 epochs, recording the
# loss as a Python float and printing it every 10th epoch.
losses = []
for epoch in range(100):
    loss = train_step(model, X_train, y_train)
    loss_value = float(loss)
    losses.append(loss_value)
    if not epoch % 10:
        print(f"Epoch {epoch}, Loss: {loss_value:.4f}")

# Plot results
plt.figure(figsize=(15, 5))

# Plot training loss
plt.subplot(121)
plt.plot(losses)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss')
plt.grid(True)

# Plot confusion matrix
plt.subplot(122)
# Derive the class count from the labels instead of hard-coding 5.
n_classes = y_test.shape[-1]
# Only the first n_classes output units of the final time step are read as
# class scores, so every predicted index is a valid confusion-matrix column
# (the full output is vocab_size wide).
predictions = model(X_test)[:, -1, :n_classes]
# Vectorized argmax, then convert to NumPy so the results can index the
# NumPy confusion matrix directly.
true_classes = np.array(mx.argmax(y_test, axis=-1))
pred_classes = np.array(mx.argmax(predictions, axis=-1))
confusion = np.zeros((n_classes, n_classes))
# np.add.at accumulates correctly when the same (true, pred) pair repeats.
np.add.at(confusion, (true_classes, pred_classes), 1)

plt.imshow(confusion, cmap='Blues')
plt.colorbar()
plt.xlabel('Predicted Class')
plt.ylabel('True Class')
plt.title('Confusion Matrix')

plt.tight_layout()
plt.show()

## 3. Attention Visualization

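Visualize per-head output patterns. Note that `visualize_attention` plots slices of the model's output sequence grouped into head-sized chunks. It does not read true attention weights.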

In [None]:
def visualize_attention(model, text_input, wiring=None):
    """Plot per-head slices of the model's output sequence as heatmaps.

    Despite the name, this does not read attention weights. It runs
    ``model(text_input)`` and, for each head ``h``, shows output columns
    ``[h * head_size, (h + 1) * head_size)`` of the first batch item as a
    (time step x unit) image.

    Args:
        model: Model to run. When ``wiring`` is not given, the wiring is read
            from ``model.cell.wiring`` (assumed to expose ``num_heads`` and
            ``head_size`` -- confirm for the ncps.mlx CfC in use).
        text_input: Input batch; only the first item is plotted.
        wiring: Optional object providing ``num_heads`` and ``head_size``,
            e.g. a ``LanguageWiring`` instance.
    """
    if wiring is None:
        wiring = model.cell.wiring
    num_heads = wiring.num_heads
    head_size = wiring.head_size

    # Convert to NumPy explicitly so matplotlib receives a plain array.
    # NOTE(review): assumes model(...) returns a single (batch, seq, units)
    # array -- confirm for the ncps.mlx model in use.
    outputs = np.array(model(text_input))

    # Lay heads out up to 4 per row so the grid grows with num_heads
    # instead of being limited to a fixed 2x4 layout.
    n_cols = max(1, min(4, num_heads))
    n_rows = max(1, -(-num_heads // n_cols))  # ceiling division

    plt.figure(figsize=(15, 2.5 * n_rows))
    for head in range(num_heads):
        plt.subplot(n_rows, n_cols, head + 1)

        # Get head-specific slice of the first sample's outputs
        head_attention = outputs[0, :, head * head_size:(head + 1) * head_size]
        plt.imshow(head_attention, cmap='viridis')
        plt.title(f'Head {head}')
        plt.axis('off')

    plt.suptitle('Attention Patterns')
    plt.tight_layout()
    plt.show()

# Visualize the first test sequence; slicing (rather than indexing) keeps
# the leading batch axis so the input shape matches what the model expects.
sample_input = X_test[:1]
visualize_attention(model, sample_input)

## Analysis

The language wiring pattern is intended to illustrate several potential advantages:

1. **Attention Mechanism**
   - Multi-head attention
   - Position-wise processing
   - Long-range dependencies

2. **Classification Performance**
   - Pattern recognition
   - Sequence understanding
   - Context integration

3. **Architecture Benefits**
   - Sparse, head-local attention connectivity
   - Position-wise connections (no explicit positional encoding is added)
   - Skip connections

Key considerations for NLP tasks:
- Balance attention heads
- Handle sequence lengths
- Manage vocabulary size
- Efficient computation