# PyTorch Tutorial: Neural Cellular Automata

Neural Cellular Automata (NCA) combine the local update rules of cellular automata with learnable neural networks. They can learn to grow patterns, regenerate from damage, and create complex textures from simple rules.

## Learning Objectives
- Understand classical cellular automata (Conway's Game of Life)
- Make cellular automata differentiable for training
- Implement perception using Sobel filters
- Build and train a Neural CA to grow patterns
- Understand regeneration and self-organization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Classical Cellular Automata

A cellular automaton is a grid of cells, where each cell's next state depends only on its current state and neighbors.

**Conway's Game of Life rules**:
- A live cell with 2 or 3 live neighbors survives
- A dead cell with exactly 3 live neighbors becomes alive
- All other cells die or stay dead
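The three rules collapse into one boolean expression on the live-neighbor count: a cell is alive next step if it has exactly 3 live neighbors, or if it is already alive and has exactly 2. The minimal NumPy sketch below checks this on a "blinker", which should flip between horizontal and vertical every step. `life_step` is a hypothetical helper for illustration; like the notebook's convolution with `padding=1`, it treats cells outside the grid as dead.

```python
import numpy as np

def life_step(grid: np.ndarray) -> np.ndarray:
    """One Game of Life step on a 0/1 array; cells beyond the edge count as dead."""
    padded = np.pad(grid, 1)  # zero border = dead cells outside the grid
    h, w = grid.shape
    # Sum the 8 shifted copies of the grid to count each cell's live neighbors
    neighbors = sum(
        padded[1 + dy : 1 + dy + h, 1 + dx : 1 + dx + w]
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    alive = grid == 1
    # Birth (dead, exactly 3) and survival (alive, 2 or 3) combined into one rule
    return ((neighbors == 3) | (alive & (neighbors == 2))).astype(grid.dtype)

# A "blinker": three cells in a row oscillate with period 2
blinker = np.zeros((5, 5), dtype=np.int64)
blinker[2, 1:4] = 1
step1 = life_step(blinker)  # should be vertical
step2 = life_step(step1)    # should be horizontal again
```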

### Why Cellular Automata Matter

Despite their simplicity, CA exhibit **emergent complexity** — complex global patterns arise from simple local rules. This is the same principle behind:
- **Biological morphogenesis**: How a single cell (embryo) grows into a complex organism using only local chemical signals
- **Swarm intelligence**: How ant colonies build complex structures with no central coordinator
- **Self-organization**: How crystals, snowflakes, and neural patterns form

The key insight is that **local rules + iteration = global structure**. Neural CA take this further by *learning* the local rules from data instead of hand-designing them.

In [None]:
def game_of_life_step(grid: torch.Tensor) -> torch.Tensor:
    """One step of Conway's Game of Life."""
    kernel = torch.tensor([[1, 1, 1],
                           [1, 0, 1],
                           [1, 1, 1]], dtype=torch.float32, device=grid.device)
    kernel = kernel.view(1, 1, 3, 3)
    
    # Count live neighbors; padding=1 treats cells beyond the border as dead
    grid_4d = grid.float().view(1, 1, *grid.shape)
    neighbors = F.conv2d(grid_4d, kernel, padding=1)[0, 0]
    
    birth = (grid == 0) & (neighbors == 3)
    survive = (grid == 1) & ((neighbors == 2) | (neighbors == 3))
    
    return (birth | survive).float()

# Create initial state with a glider
grid_size = 32
initial = torch.zeros(grid_size, grid_size)
glider = torch.tensor([[0, 1, 0],
                       [0, 0, 1],
                       [1, 1, 1]], dtype=torch.float32)
initial[5:8, 5:8] = glider

# Run simulation
history = [initial.clone()]
state = initial.clone()
for _ in range(50):
    state = game_of_life_step(state)
    history.append(state.clone())

# Visualize
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, step in enumerate([0, 10, 20, 30, 50]):
    axes[i].imshow(history[step].numpy(), cmap='binary')
    axes[i].set_title(f'Step {step}')
    axes[i].axis('off')
plt.suptitle("Conway's Game of Life - Glider")
plt.tight_layout()
plt.show()

## 2. Making CA Differentiable

To learn CA rules, we need:
1. **Continuous state**: Not just 0/1, but real values (so gradients can flow)
2. **Smooth update function**: Learnable neural network (replaces hand-coded rules)
3. **Differentiable loss**: Compare to target pattern (MSE between generated and target image)
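Why continuous state and smooth updates matter can be seen directly in autograd: a hard threshold such as `neighbors == 3` produces a boolean result that is cut off from the computation graph, while a smooth function passes gradients back. The sigmoid "soft threshold" with steepness 10 below is only an illustrative assumption; the notebook itself replaces the hand-coded rule with a neural network rather than a sigmoid.

```python
import torch

x = torch.tensor([0.2, 0.6, 0.9], requires_grad=True)

# Hard rule: a comparison yields a bool tensor with no gradient path back to x
hard = (x > 0.5).float()

# Smooth rule: a steep sigmoid behaves like a threshold but is differentiable
soft = torch.sigmoid((x - 0.5) * 10.0)
soft.sum().backward()  # populates x.grad with nonzero values
```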

### The Multi-Channel State

Each cell isn't just a single number — it has **multiple channels**:
- **Channels 0-3**: RGBA (what we see — red, green, blue, alpha)
- **Channels 4-15**: Hidden channels (internal state the CA uses to coordinate)

The hidden channels are like invisible "chemical signals" that cells pass to neighbors. Biologically, this mirrors how real cells communicate through chemical gradients (morphogens) that aren't directly visible.

### The Alpha Channel as "Alive" Signal

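Channel 3 (alpha) decides which cells are alive. Alpha itself is continuous, but the alive/dead decision derived from it is binary: a cell counts as alive if the largest alpha value in its 3x3 neighborhood (itself included) exceeds a threshold (0.1 in `NeuralCA.get_alive_mask`). Cells that are not alive are zeroed out, so growth can only spread outward from existing live cells and empty space stays empty, which gives the pattern clear boundaries.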
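The alive rule is a 3x3 max-pool over the alpha channel followed by a threshold. The sketch below repeats the logic of `get_alive_mask` as a standalone function; the 7x7 grid and the faint 0.05 alpha value are arbitrary choices for illustration. A single live cell should make its whole 3x3 neighborhood alive, while an alpha below the threshold should keep nothing alive.

```python
import torch
import torch.nn.functional as F

def alive_mask(state: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
    """Same rule as NeuralCA.get_alive_mask: max alpha in the 3x3 neighborhood > threshold."""
    alpha = state[:, 3:4]  # keep the channel dimension: (N, 1, H, W)
    return (F.max_pool2d(alpha, kernel_size=3, stride=1, padding=1) > threshold).float()

state = torch.zeros(1, 16, 7, 7)
state[0, 3, 3, 3] = 1.0   # one live cell in the middle
state[0, 3, 0, 6] = 0.05  # faint alpha below the threshold

mask = alive_mask(state)  # shape (1, 1, 7, 7)
```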

In [None]:
NUM_CHANNELS = 16  # 4 visible (RGBA) + 12 hidden

def create_seed(size: int, num_channels: int = NUM_CHANNELS) -> torch.Tensor:
    """Create a seed state with a single active cell in the center."""
    state = torch.zeros(1, num_channels, size, size)
    center = size // 2
    state[0, 3, center, center] = 1.0  # Alpha channel
    return state

def get_visible_state(state: torch.Tensor) -> torch.Tensor:
    """Extract RGBA channels from state."""
    rgba = state[:, :4, :, :]
    return torch.clamp(rgba, 0, 1)

# Visualize seed
seed = create_seed(64)
visible = get_visible_state(seed)

plt.figure(figsize=(6, 6))
plt.imshow(visible[0, 3].numpy(), cmap='gray')
plt.title('Seed State (Alpha Channel)')
plt.colorbar()
plt.show()

print(f"State shape: {seed.shape}")

## 3. Perception: Sobel Filters

Each cell needs to perceive its neighborhood. We use **Sobel filters** to detect gradients.

### Why Sobel Filters Instead of Raw Neighbor Values?

Sobel filters give each cell richer information about its neighborhood:

- **Identity filter**: "What is my own value?" (the cell's current state)
- **Sobel-X filter**: "Are my left neighbors different from my right neighbors?" (horizontal gradient)
- **Sobel-Y filter**: "Are my top neighbors different from my bottom neighbors?" (vertical gradient)

This is the same as asking: "Am I at an edge? Which direction is the edge?" — exactly the information needed to grow patterns with defined boundaries.

**Connection to biology**: Real cells sense concentration gradients of signaling molecules. A cell at the edge of a tissue "feels" a steep gradient (lots of signal on one side, none on the other). Sobel filters are a simple mathematical model of this gradient sensing.
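The edge-sensing idea can be checked on a tiny image with a vertical edge (empty left half, filled right half). The sketch uses the same normalized Sobel kernels as `PerceptionModule` (the notebook's `sobel_y` is exactly the transpose of `sobel_x`). Along an interior row, the horizontal response should be positive only at the edge and the vertical response should be zero. Note the negative value in the last column: zero padding makes the region beyond the border look empty, so the right border also reads as an edge.

```python
import torch
import torch.nn.functional as F

# Same normalized kernels as PerceptionModule
sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]) / 8.0
sobel_y = sobel_x.t()  # [[-1,-2,-1],[0,0,0],[1,2,1]] / 8: top-to-bottom change

# A 6x6 image: left half 0, right half 1 -> a vertical edge between columns 2 and 3
img = torch.zeros(1, 1, 6, 6)
img[..., 3:] = 1.0

kernels = torch.stack([sobel_x, sobel_y]).unsqueeze(1)  # shape (2, 1, 3, 3)
grad = F.conv2d(img, kernels, padding=1)               # shape (1, 2, 6, 6)

row = 2  # interior row, away from the zero-padded top and bottom borders
gx = grad[0, 0, row]  # horizontal gradient along the row
gy = grad[0, 1, row]  # vertical gradient along the row
```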

### Why Not Just Use Raw Convolution?

We could let the network learn arbitrary 3x3 filters. But Sobel filters provide a strong inductive bias:
- They decompose neighborhood information into interpretable components (value + gradients)
- This reduces the number of parameters the network needs to learn
- The perception output has `num_channels × 3` dimensions (identity + x-gradient + y-gradient for each channel)
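With `groups=num_channels`, the perception output is interleaved per input channel: `[c0_identity, c0_sobel_x, c0_sobel_y, c1_identity, ...]`. The sketch below rebuilds the same filter bank as `PerceptionModule`, shrunk to 4 channels for readability, and checks that layout: every third output channel starting at 0 should reproduce the input, and output channel 1 should equal channel 0 filtered with Sobel-X.

```python
import torch
import torch.nn.functional as F

num_channels = 4  # smaller than NUM_CHANNELS, for illustration only
identity = torch.zeros(3, 3)
identity[1, 1] = 1.0
sobel_x = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]]) / 8.0
sobel_y = sobel_x.t()

# Same construction as PerceptionModule: 3 fixed filters per channel, applied depthwise
filters = torch.stack([identity, sobel_x, sobel_y]).repeat(num_channels, 1, 1).view(-1, 1, 3, 3)

x = torch.randn(2, num_channels, 8, 8)
out = F.conv2d(x, filters, padding=1, groups=num_channels)  # (2, 3 * num_channels, 8, 8)
```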

In [None]:
class PerceptionModule(nn.Module):
    """Perceive local neighborhood using Sobel filters."""
    
    def __init__(self, num_channels: int = NUM_CHANNELS):
        super().__init__()
        
        sobel_x = torch.tensor([[-1, 0, 1],
                                [-2, 0, 2],
                                [-1, 0, 1]], dtype=torch.float32) / 8.0
        
        sobel_y = torch.tensor([[-1, -2, -1],
                                [0, 0, 0],
                                [1, 2, 1]], dtype=torch.float32) / 8.0
        
        identity = torch.tensor([[0, 0, 0],
                                 [0, 1, 0],
                                 [0, 0, 0]], dtype=torch.float32)
        
        filters = torch.stack([identity, sobel_x, sobel_y])
        # Fixed (non-learnable) filters: a buffer moves with .to(device) but is not a parameter
        self.register_buffer('filters', filters.repeat(num_channels, 1, 1).view(-1, 1, 3, 3))
        self.num_channels = num_channels
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        perceived = F.conv2d(x, self.filters, padding=1, groups=self.num_channels)
        return perceived

# Demonstrate perception
perception = PerceptionModule()
test_state = torch.zeros(1, NUM_CHANNELS, 32, 32)
test_state[0, 3, 10:22, 10:22] = 1.0

perceived = perception(test_state)
print(f"Input shape: {test_state.shape}")
print(f"Perceived shape: {perceived.shape}")

## 4. Neural CA Update Rule

### Architecture Design Choices

The update network uses **1x1 convolutions** — these are equivalent to applying the same MLP to every cell independently. This enforces **spatial uniformity**: every cell uses the exact same update rule, just like in classical CA.
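The "same MLP applied to every cell" equivalence can be verified numerically: a 1x1 `nn.Conv2d` gives the same result as an `nn.Linear` with the same weights applied to each cell's channel vector. The dimensions below (48 in, 16 out) mirror the perception and state sizes used in this notebook, but any sizes would do.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
in_dim, out_dim = 48, 16
conv = nn.Conv2d(in_dim, out_dim, kernel_size=1)
linear = nn.Linear(in_dim, out_dim)

# Copy the 1x1 conv weights into the linear layer: (out, in, 1, 1) -> (out, in)
with torch.no_grad():
    linear.weight.copy_(conv.weight.view(out_dim, in_dim))
    linear.bias.copy_(conv.bias)

x = torch.randn(2, in_dim, 5, 5)  # (batch, channels, H, W)
conv_out = conv(x)
# Move channels last so nn.Linear acts on each cell's feature vector, then move them back
linear_out = linear(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
```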

**Key design decisions**:
- **Zero-initialized last layer**: The update starts as "do nothing" (delta = 0). This means the model starts stable and gradually learns to make changes — much easier to train than starting with random updates.
- **Stochastic update mask**: During training, only a random subset of cells update each step. This forces cells to be robust to neighbors updating at different times, which is critical for stable long-term growth and regeneration.
- **Residual connection**: The update is additive (`state = state + delta`), not a replacement. This makes training more stable and allows small, incremental changes.
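- **Alive masking**: A cell survives a step only if it is alive both before and after the update, meaning some alpha value in its 3x3 neighborhood exceeds the threshold (0.1); all other cells are forced to zero. Growth can therefore only spread outward from live cells, and the background stays exactly zero, which creates sharp boundaries.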
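Two of these decisions can be checked in isolation. The sketch below rebuilds the same `update_net` as `NeuralCA` and feeds it random stand-in features instead of a real `PerceptionModule` output (an assumption for brevity). Because the last layer is zero-initialized, the first residual update should leave the state exactly unchanged, and the stochastic mask should select roughly `update_rate` of the cells. Zero weights in the last layer do not block learning: the gradient with respect to those weights depends on the hidden activations, which are nonzero.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
num_channels, hidden_dim = 16, 128
update_net = nn.Sequential(
    nn.Conv2d(num_channels * 3, hidden_dim, 1),
    nn.ReLU(),
    nn.Conv2d(hidden_dim, num_channels, 1),
)
# Zero-initialize the last layer, as in NeuralCA
nn.init.zeros_(update_net[-1].weight)
nn.init.zeros_(update_net[-1].bias)

state = torch.rand(1, num_channels, 32, 32)
perceived = torch.randn(1, num_channels * 3, 32, 32)  # stand-in for the perception output
delta = update_net(perceived)

# Stochastic update: each cell fires independently with probability update_rate
update_rate = 0.5
update_mask = (torch.rand_like(state[:, :1]) < update_rate).float()  # (1, 1, H, W), broadcast over channels
new_state = state + delta * update_mask  # residual update
```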

In [None]:
class NeuralCA(nn.Module):
    """Neural Cellular Automaton."""
    
    def __init__(self, num_channels: int = NUM_CHANNELS, hidden_dim: int = 128):
        super().__init__()
        
        self.num_channels = num_channels
        self.perception = PerceptionModule(num_channels)
        
        perception_dim = num_channels * 3
        
        self.update_net = nn.Sequential(
            nn.Conv2d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, num_channels, 1),
        )
        
        nn.init.zeros_(self.update_net[-1].weight)
        nn.init.zeros_(self.update_net[-1].bias)
    
    def get_alive_mask(self, state: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
        alpha = state[:, 3:4, :, :]
        alive = F.max_pool2d(alpha, 3, stride=1, padding=1) > threshold
        return alive.float()
    
    def forward(self, state: torch.Tensor, update_rate: float = 0.5) -> torch.Tensor:
        pre_alive = self.get_alive_mask(state)
        perceived = self.perception(state)
        delta = self.update_net(perceived)
        
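        # Stochastic update: each cell fires with probability `update_rate`. The mask is
        # applied at inference too, so evaluation runs the same dynamics the model was trained on.
        update_mask = (torch.rand_like(state[:, :1, :, :]) < update_rate).float()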
        
        state = state + delta * update_mask
        post_alive = self.get_alive_mask(state)
        alive_mask = pre_alive * post_alive
        state = state * alive_mask
        
        return state
    
    def run_steps(self, state: torch.Tensor, steps: int) -> torch.Tensor:
        for _ in range(steps):
            state = self(state)
        return state

nca = NeuralCA().to(device)
print(f"Model parameters: {sum(p.numel() for p in nca.parameters()):,}")

## 5. Training to Grow a Pattern

### Pool-Based Training (Why It's Necessary)

Naive training (always start from seed, grow for N steps, compute loss) has a problem: the model only learns to grow from a fresh seed. It doesn't learn to **persist** or **recover from damage**.

Pool-based training solves this:
1. Maintain a pool of N states (e.g., 256), initialized as seeds
2. Sample a batch from the pool (these could be partially-grown states from previous iterations)
3. Run for a random number of steps (64-96, not fixed — prevents the model from "memorizing" a specific step count)
4. Compute loss against target and backpropagate
5. Put the updated states back in the pool (so next iteration starts from where we left off)
6. Replace the worst-performing sample with a fresh seed (prevents the pool from degrading)

**Why this works**: The model is trained on states at all stages of growth (fresh seeds, partially grown, fully grown), so it learns a stable attractor that works regardless of starting state.
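The pool bookkeeping (sample, continue from stored states, write back, reseed the worst sample) can be sketched without a model. Here `fake_ca_run` is a hypothetical placeholder for running the NCA for several steps, the target is random, and the sizes are shrunk for illustration. Sampling with `torch.randperm` avoids drawing the same pool slot twice in one batch.

```python
import torch

torch.manual_seed(0)
pool_size, batch_size, num_channels, size = 16, 4, 16, 8

def make_seed() -> torch.Tensor:
    """Single live cell in the center (alpha channel), as in create_seed."""
    seed = torch.zeros(num_channels, size, size)
    seed[3, size // 2, size // 2] = 1.0
    return seed

def fake_ca_run(batch: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for running the NCA for several steps."""
    return batch + 0.1 * torch.rand_like(batch)

target = torch.rand(4, size, size)  # hypothetical RGBA target
pool = make_seed().repeat(pool_size, 1, 1, 1)

for _ in range(3):
    idx = torch.randperm(pool_size)[:batch_size]  # sample without replacement
    batch = fake_ca_run(pool[idx])                # continue from stored states
    per_sample_loss = ((batch[:, :4] - target) ** 2).mean(dim=(1, 2, 3))
    pool[idx] = batch                             # write results back into the pool
    worst = idx[per_sample_loss.argmax()].item()
    pool[worst] = make_seed()                     # reseed the worst sample
```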

### Training Stability Tips

- **Gradient clipping** (max norm 1.0): training backpropagates through dozens of recurrent CA steps, so errors compound and gradients can explode
- **Random step counts** (64-96): Prevents overfitting to a fixed number of steps
- **Pool replacement**: Injecting fresh seeds prevents the pool from becoming stale
- **Learning rate**: 2e-3 with Adam, higher than is typical for large networks; the small update network (fewer than 10,000 learnable parameters) tolerates it
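Two of these tips are easy to get subtly wrong. `torch.nn.utils.clip_grad_norm_` rescales all gradients together so that their combined L2 norm is at most `max_norm`, and it returns the norm measured before clipping. `torch.randint`'s upper bound is exclusive, so drawing 64 to 96 steps needs `high=97`. The tiny linear model and the deliberately large loss below are stand-ins used only to produce big gradients.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 10)  # stand-in model
loss = (model(torch.randn(32, 10)) * 100.0).pow(2).mean()  # deliberately large loss
loss.backward()

# Returns the total gradient norm before clipping, then scales grads in place
total_norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
total_norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))

# The upper bound is exclusive: this draws an integer in [64, 96]
num_steps = torch.randint(64, 97, (1,)).item()
```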

In [None]:
def create_target_circle(size: int = 64) -> torch.Tensor:
    """Create a simple circle target pattern."""
    target = torch.zeros(1, 4, size, size)
    center = size // 2
    
    y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    y, x = y.float(), x.float()
    
    dist_from_center = torch.sqrt((x - center)**2 + (y - center)**2)
    circle_radius = size * 0.35
    circle_mask = dist_from_center < circle_radius
    
    target[0, 0][circle_mask] = 1.0  # R
    target[0, 1][circle_mask] = 0.5  # G
    target[0, 2][circle_mask] = 0.0  # B
    target[0, 3][circle_mask] = 1.0  # A
    
    return target

target = create_target_circle(64).to(device)

plt.figure(figsize=(4, 4))
img = target[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(img)
plt.title('Target Pattern')
plt.axis('off')
plt.show()

In [None]:
def train_nca(model, target, num_epochs=500, pool_size=256, batch_size=4, lr=2e-3):
    """Train NCA using a pool of states."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    size = target.shape[-1]
    pool = create_seed(size).repeat(pool_size, 1, 1, 1).to(device)
    losses = []
    
    for epoch in range(num_epochs):
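        # Sample without replacement so no pool slot appears twice in one batch
        batch_indices = torch.randperm(pool_size)[:batch_size]
        batch = pool[batch_indices].clone()
        # randint's upper bound is exclusive, so this draws 64 to 96 steps
        num_steps = torch.randint(64, 97, (1,)).item()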
        
        for _ in range(num_steps):
            batch = model(batch)
        
        visible = get_visible_state(batch)
        loss = F.mse_loss(visible, target.expand(batch_size, -1, -1, -1))
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        with torch.no_grad():
            pool[batch_indices] = batch.detach()
            batch_losses = F.mse_loss(visible, target.expand(batch_size, -1, -1, -1), reduction='none')
            batch_losses = batch_losses.mean(dim=(1, 2, 3))
            worst_idx = batch_indices[batch_losses.argmax().item()].item()
            pool[worst_idx] = create_seed(size).to(device).squeeze(0)
        
        losses.append(loss.item())
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")
    
    return losses

nca = NeuralCA().to(device)
print("Training Neural CA...")
losses = train_nca(nca, target, num_epochs=500)

## 6. Visualize Growth

In [None]:
def visualize_growth(model, size=64, steps=100):
    """Visualize the NCA growing from seed."""
    model.eval()
    state = create_seed(size).to(device)
    history = [get_visible_state(state).cpu()]
    
    with torch.no_grad():
        for _ in range(steps):
            state = model(state)
            history.append(get_visible_state(state).cpu())
    
    fig, axes = plt.subplots(1, 6, figsize=(15, 3))
    for i, step in enumerate([0, 20, 40, 60, 80, 100]):
        img = history[step][0].permute(1, 2, 0).numpy()
        axes[i].imshow(img)
        axes[i].set_title(f'Step {step}')
        axes[i].axis('off')
    
    plt.suptitle('NCA Growth from Seed')
    plt.tight_layout()
    plt.show()

visualize_growth(nca)

## 7. Regeneration: Self-Repair from Damage

### Why Regeneration Emerges Naturally

Regeneration isn't explicitly trained in this notebook: nothing in `train_nca` damages the pool samples. When it appears, it is a side effect of the training procedure:

1. **Local rules**: Each cell only sees its 3x3 neighborhood. It doesn't know about the global pattern. It just follows its local update rule.
2. **Same rule everywhere**: Every cell runs the identical neural network. There's no "center cell" with special instructions.
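3. **Pool-based training**: The model sees partially grown and fully grown states during training, not only fresh seeds, so it must keep moving states at many stages toward the target. This can carry over to recovering from damage, although more reliable regeneration is typically obtained by also damaging some pool samples during training.
4. **Stable attractor**: If training succeeds, the target pattern becomes an approximately stable fixed point of the learned dynamics. Small perturbations (damage) then tend to be corrected, because the local rules push the state back toward the attractor.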

**Analogy**: Imagine a crowd of people, each following the rule "stand 2 feet from your nearest neighbor." If you remove a few people, the remaining crowd naturally fills the gaps — not because anyone was told to, but because the local rule creates a stable global configuration.
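To probe regeneration beyond the square cut used in `test_regeneration`, one option is to erase a circular region across all channels, visible and hidden alike. The hypothetical helper below does that; the same kind of function could also be applied to some pool samples inside a training loop if regeneration is to be trained explicitly, which this notebook does not do.

```python
import torch

def damage_circle(state: torch.Tensor, cy: int, cx: int, radius: float) -> torch.Tensor:
    """Hypothetical helper: return a copy of `state` with every channel zeroed inside a circle."""
    _, _, h, w = state.shape
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    inside = ((y - cy) ** 2 + (x - cx) ** 2).float() < radius ** 2  # (H, W) bool mask
    damaged = state.clone()
    damaged[:, :, inside] = 0.0  # erase RGBA and hidden channels alike
    return damaged

state = torch.ones(2, 16, 64, 64)
damaged = damage_circle(state, cy=32, cx=32, radius=10)
```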

### Connection to Biological Regeneration

Regeneration in real organisms (like salamanders regrowing limbs) is often described in similar terms:
- Cells communicate mainly with nearby cells (chemical signals)
- Every cell runs the same DNA "program"
- The target pattern (body plan) can be viewed as a stable attractor of the cell dynamics
- Damage disrupts the local state, and the local rules drive recovery

Neural CA provide a simplified computational model of this biological process.

In [None]:
def test_regeneration(model, size=64, grow_steps=100, regen_steps=100):
    """Test NCA regeneration after damage."""
    model.eval()
    
    state = create_seed(size).to(device)
    with torch.no_grad():
        for _ in range(grow_steps):
            state = model(state)
    
    grown = state.clone()
    damaged = state.clone()
    damaged[:, :, 20:44, 20:44] = 0
    
    regenerating = damaged.clone()
    regen_history = [get_visible_state(regenerating).cpu()]
    
    with torch.no_grad():
        for _ in range(regen_steps):
            regenerating = model(regenerating)
            regen_history.append(get_visible_state(regenerating).cpu())
    
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    
    axes[0, 0].imshow(get_visible_state(grown)[0].permute(1, 2, 0).cpu().numpy())
    axes[0, 0].set_title('Fully Grown')
    axes[0, 1].imshow(get_visible_state(damaged)[0].permute(1, 2, 0).cpu().numpy())
    axes[0, 1].set_title('After Damage')
    for i in range(3):
        axes[0, i+2].axis('off')
    
    for i, step in enumerate([0, 25, 50, 75, 100]):
        axes[1, i].imshow(regen_history[step][0].permute(1, 2, 0).numpy())
        axes[1, i].set_title(f'Regen Step {step}')
    
    for ax in axes.flat:
        ax.axis('off')
    
    plt.suptitle('NCA Regeneration After Damage')
    plt.tight_layout()
    plt.show()

test_regeneration(nca)

## 8. Concepts to Know

### Classical CA vs Neural CA

| Aspect | Classical CA (Game of Life) | Neural CA |
|--------|---------------------------|-----------|
| Rules | Hand-designed, fixed | Learned via gradient descent |
| State | Discrete (0/1) | Continuous (real-valued) |
| Trainable | No | Yes (differentiable) |
| Complexity | Fixed by the hand-designed rule | Learned from a target pattern |
| Channels | 1 | Multiple (16 here: RGBA + 12 hidden, for coordination) |

### Why Sobel Filters for Perception?

Sobel filters detect gradients in the neighborhood. They give each cell three pieces of information per channel:
- **Identity**: "What is my current value?"
- **Sobel-X**: "What is the horizontal gradient?" (left-right difference)
- **Sobel-Y**: "What is the vertical gradient?" (top-bottom difference)

This is computationally cheap and biologically inspired — real cells sense chemical concentration gradients to determine their position and role during development.

### How Pool-Based Training Works

The pool is like a "memory" of training states. Instead of always starting from a fresh seed, the model trains on states at various stages of development. This teaches:
- **Growth**: How to develop from a seed
- **Persistence**: How to maintain a stable pattern
- **Recovery**: How to steer states at many stages of growth back toward the target. This can help with healing from damage, but the pool here holds partially grown states, not damaged ones

### Why NCAs Regenerate

If training succeeds, the target pattern becomes an approximately **stable attractor** of the learned dynamics:
- Local rules push moderately perturbed states back toward the target
- No cell "knows" the global pattern; regeneration is purely emergent
- A similar principle has been proposed as a model of biological regeneration in organisms like planaria and salamanders
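What "stable attractor" means can be shown with a toy analogy that is not an NCA: a hand-written update that moves a vector 30% of the way toward a fixed target each step. The target is its only fixed point, so a "damaged" state shrinks its error by a factor of 0.7 per step. A learned NCA rule is local and has no such guarantee; this only illustrates the concept.

```python
import torch

target = torch.tensor([1.0, 0.5, 0.0])

def toy_update(x: torch.Tensor) -> torch.Tensor:
    """Toy contraction x <- x + 0.3 * (target - x); target is its unique fixed point."""
    return x + 0.3 * (target - x)

x = target.clone()
x[0] = 0.0  # "damage": knock one component away from the target
errors = []
for _ in range(30):
    x = toy_update(x)
    errors.append((x - target).abs().max().item())  # shrinks roughly as 0.7 ** step
```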

### Applications of Neural CA

- **Texture synthesis**: Generate seamless textures from small examples
- **Pattern growth**: Grow complex images from single-pixel seeds
- **Self-repairing systems**: Distributed systems that recover from component failure
- **Morphogenesis modeling**: Understanding how organisms develop from embryos
- **Generative art**: Creating organic, evolving visual artworks
- **Decentralized computing**: Algorithms with no central controller, robust to local failures

## 9. Key Takeaways

1. **Neural CA** combine cellular automata with learnable neural networks
2. **Perception** uses Sobel filters to sense local gradients
3. **Update rule** is a small network applied to each cell
4. **Residual updates** with stochastic masks enable stable growth
5. **Pool-based training** teaches persistence and robustness, not just growth from a seed
6. **Regeneration** can emerge from local rules; explicitly damaging states during training makes it more reliable
7. **Self-organization** produces complex patterns from simple seeds