# PyTorch Tutorial: Neural Cellular Automata

Neural Cellular Automata (NCA) combine the local update rules of cellular automata with learnable neural networks. They can learn to grow patterns, regenerate from damage, and create complex textures from simple rules.

## Learning Objectives
- Understand classical cellular automata (Conway's Game of Life)
- Make cellular automata differentiable for training
- Implement perception using Sobel filters
- Build and train a Neural CA to grow patterns
- Understand regeneration and self-organization

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

torch.manual_seed(42)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Classical Cellular Automata

A cellular automaton is a grid of cells that all update simultaneously. Each cell's next state depends only on its own current state and the states of its immediate neighbors.

**Conway's Game of Life rules** (each cell counts its 8 neighbors):
- A live cell with 2 or 3 live neighbors survives
- A dead cell with exactly 3 live neighbors becomes alive
- All other cells die or stay dead

In [None]:
def game_of_life_step(grid: torch.Tensor) -> torch.Tensor:
    """One step of Conway's Game of Life."""
    kernel = torch.tensor([[1, 1, 1],
                          [1, 0, 1],
                          [1, 1, 1]], dtype=torch.float32)
    kernel = kernel.view(1, 1, 3, 3)
    
    grid_4d = grid.float().view(1, 1, *grid.shape)
    neighbors = F.conv2d(grid_4d, kernel, padding=1).squeeze()
    
    birth = (grid == 0) & (neighbors == 3)
    survive = (grid == 1) & ((neighbors == 2) | (neighbors == 3))
    
    return (birth | survive).float()

# Create initial state with a glider
grid_size = 32
initial = torch.zeros(grid_size, grid_size)
glider = torch.tensor([[0, 1, 0],
                       [0, 0, 1],
                       [1, 1, 1]], dtype=torch.float32)
initial[5:8, 5:8] = glider

# Run simulation
history = [initial.clone()]
state = initial.clone()
for _ in range(50):
    state = game_of_life_step(state)
    history.append(state.clone())

# Visualize
fig, axes = plt.subplots(1, 5, figsize=(15, 3))
for i, step in enumerate([0, 10, 20, 30, 50]):
    axes[i].imshow(history[step].numpy(), cmap='binary')
    axes[i].set_title(f'Step {step}')
    axes[i].axis('off')
plt.suptitle("Conway's Game of Life - Glider")
plt.tight_layout()
plt.show()
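`F.conv2d(..., padding=1)` pads with zeros, so cells beyond the border count as permanently dead and a glider's motion breaks down once it reaches the edge. A common alternative is a torus, where the edges wrap around. A minimal sketch, assuming wrap-around boundaries are wanted; `game_of_life_step_torus` is a hypothetical name, and it counts neighbors with `torch.roll` instead of a convolution:

```python
import torch

def game_of_life_step_torus(grid: torch.Tensor) -> torch.Tensor:
    """One Game of Life step on a torus (hypothetical helper: edges wrap around)."""
    g = grid.float()
    # Sum the 8 shifted copies of the grid; torch.roll wraps values past the border
    neighbors = sum(
        torch.roll(g, shifts=(dy, dx), dims=(0, 1))
        for dy in (-1, 0, 1)
        for dx in (-1, 0, 1)
        if (dy, dx) != (0, 0)
    )
    birth = (g == 0) & (neighbors == 3)
    survive = (g == 1) & ((neighbors == 2) | (neighbors == 3))
    return (birth | survive).float()

# The same glider as above, on an 8x8 torus
grid = torch.zeros(8, 8)
grid[0:3, 0:3] = torch.tensor([[0, 1, 0],
                               [0, 0, 1],
                               [1, 1, 1]], dtype=torch.float32)
state = grid.clone()
for _ in range(4):
    state = game_of_life_step_torus(state)
# After 4 steps the glider has moved one cell down and one cell right
print(torch.equal(state, torch.roll(grid, shifts=(1, 1), dims=(0, 1))))  # True
```

On an 8×8 torus the glider comes back to its starting cells after 32 steps (8 diagonal moves of one cell, 4 steps each).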

## 2. Making CA Differentiable

To learn CA rules, we need:
1. **Continuous state**: Not just 0/1, but real values
2. **Smooth update function**: Learnable neural network
3. **Differentiable loss**: Compare to target pattern
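Why points 1 and 2 matter is easy to see in isolation: a hard threshold like the comparisons in `game_of_life_step` produces no gradient, while a smooth stand-in does. The steep sigmoid below is only an illustration of a smooth function; the NCA in this notebook uses a learned network instead.

```python
import torch

x = torch.tensor([0.3, 0.7], requires_grad=True)

# Hard threshold (as in Game of Life): comparisons return bool, detached from the graph
hard = (x > 0.5).float()
print(hard.requires_grad)   # False: no gradient can reach x

# Smooth relaxation: a steep sigmoid approximates the step but is differentiable
soft = torch.sigmoid(10.0 * (x - 0.5))
soft.sum().backward()
print(x.grad)               # nonzero for both entries
```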

In [None]:
NUM_CHANNELS = 16  # 4 visible (RGBA) + 12 hidden

def create_seed(size: int, num_channels: int = NUM_CHANNELS) -> torch.Tensor:
    """Create a seed state with a single active cell in the center."""
    state = torch.zeros(1, num_channels, size, size)
    center = size // 2
    state[0, 3, center, center] = 1.0  # Alpha channel
    return state

def get_visible_state(state: torch.Tensor) -> torch.Tensor:
    """Extract RGBA channels from state."""
    rgba = state[:, :4, :, :]
    return torch.clamp(rgba, 0, 1)

# Visualize seed
seed = create_seed(64)
visible = get_visible_state(seed)

plt.figure(figsize=(6, 6))
plt.imshow(visible[0, 3].numpy(), cmap='gray')
plt.title('Seed State (Alpha Channel)')
plt.colorbar()
plt.show()

print(f"State shape: {seed.shape}")

## 3. Perception: Sobel Filters

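Each cell perceives its neighborhood through three fixed 3×3 filters applied to every channel: the identity (the cell's own value) and the Sobel-x and Sobel-y filters (the horizontal and vertical gradients). With 16 channels this gives a 48-dimensional perception vector per cell.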
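To see what a Sobel filter reports, here is a standalone check on a small image with a single vertical edge (the image is made up for illustration). The x-gradient is 0.5 on the two columns next to the edge and 0 in flat regions; the last column reads -0.5 because `padding=1` pads with zeros, so the border itself looks like a falling edge.

```python
import torch
import torch.nn.functional as F

sobel_x = torch.tensor([[-1., 0., 1.],
                        [-2., 0., 2.],
                        [-1., 0., 1.]]) / 8.0

# 8x8 image: columns 0-3 are 0, columns 4-7 are 1 (a vertical edge)
img = torch.zeros(1, 1, 8, 8)
img[..., 4:] = 1.0

gx = F.conv2d(img, sobel_x.view(1, 1, 3, 3), padding=1)
# Row 4 (away from the top/bottom border); values: 0, 0, 0, 0.5, 0.5, 0, 0, -0.5
print(gx[0, 0, 4])
```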

In [None]:
class PerceptionModule(nn.Module):
    """Perceive local neighborhood using Sobel filters."""
    
    def __init__(self, num_channels: int = NUM_CHANNELS):
        super().__init__()
        
        sobel_x = torch.tensor([[-1, 0, 1],
                                [-2, 0, 2],
                                [-1, 0, 1]], dtype=torch.float32) / 8.0
        
        sobel_y = torch.tensor([[-1, -2, -1],
                                [0, 0, 0],
                                [1, 2, 1]], dtype=torch.float32) / 8.0
        
        identity = torch.tensor([[0, 0, 0],
                                 [0, 1, 0],
                                 [0, 0, 0]], dtype=torch.float32)
        
        filters = torch.stack([identity, sobel_x, sobel_y])
        # Fixed filters: a buffer moves with .to(device) but is not a trainable parameter
        self.register_buffer(
            'filters',
            filters.repeat(num_channels, 1, 1).view(-1, 1, 3, 3)
        )
        self.num_channels = num_channels
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        perceived = F.conv2d(x, self.filters, padding=1, groups=self.num_channels)
        return perceived

# Demonstrate perception
perception = PerceptionModule()
test_state = torch.zeros(1, NUM_CHANNELS, 32, 32)
test_state[0, 3, 10:22, 10:22] = 1.0

perceived = perception(test_state)
print(f"Input shape: {test_state.shape}")
print(f"Perceived shape: {perceived.shape}")

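## 4. Neural CA Update Rule

Every step applies the same rule at every cell: perceive the 3×3 neighborhood, pass the perception vector through a small per-cell network (two 1×1 convolutions), and add the output to the state as a residual update. Updates are stochastic: each cell is updated with probability `update_rate` (0.5 by default), so cells do not need a global clock. A cell is then kept only if its 3×3 neighborhood contains a cell with alpha above 0.1 both before and after the update; all other cells are zeroed, which keeps empty space empty. The last layer is zero-initialized, so an untrained model produces zero updates.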

In [None]:
class NeuralCA(nn.Module):
    """Neural Cellular Automaton."""
    
    def __init__(self, num_channels: int = NUM_CHANNELS, hidden_dim: int = 128):
        super().__init__()
        
        self.num_channels = num_channels
        self.perception = PerceptionModule(num_channels)
        
        perception_dim = num_channels * 3
        
        self.update_net = nn.Sequential(
            nn.Conv2d(perception_dim, hidden_dim, 1),
            nn.ReLU(),
            nn.Conv2d(hidden_dim, num_channels, 1),
        )
        
        nn.init.zeros_(self.update_net[-1].weight)
        nn.init.zeros_(self.update_net[-1].bias)
    
    def get_alive_mask(self, state: torch.Tensor, threshold: float = 0.1) -> torch.Tensor:
        """A cell is alive if any cell in its 3x3 neighborhood has alpha > threshold."""
        alpha = state[:, 3:4, :, :]
        alive = F.max_pool2d(alpha, 3, stride=1, padding=1) > threshold
        return alive.float()
    
    def forward(self, state: torch.Tensor, update_rate: float = 0.5) -> torch.Tensor:
        pre_alive = self.get_alive_mask(state)
        perceived = self.perception(state)
        delta = self.update_net(perceived)
        
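        # Stochastic update: each cell fires with probability update_rate at every step,
        # in training and at inference alike, so rollouts follow the trained dynamics
        update_mask = (torch.rand_like(state[:, :1, :, :]) < update_rate).float()
        
        state = state + delta * update_mask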
        post_alive = self.get_alive_mask(state)
        alive_mask = pre_alive * post_alive
        state = state * alive_mask
        
        return state
    
    def run_steps(self, state: torch.Tensor, steps: int) -> torch.Tensor:
        for _ in range(steps):
            state = self(state)
        return state

nca = NeuralCA().to(device)
print(f"Model parameters: {sum(p.numel() for p in nca.parameters()):,}")

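## 5. Training to Grow a Pattern

The target is an RGBA image of an orange disk (R=1, G=0.5, B=0, alpha=1) with radius 0.35 × the grid size, on a transparent background. Every training rollout traces back to the single-cell seed, so the model has to learn a rule that grows this disk and then holds it.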

In [None]:
def create_target_circle(size: int = 64) -> torch.Tensor:
    """Create a simple circle target pattern."""
    target = torch.zeros(1, 4, size, size)
    center = size // 2
    
    y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
    y, x = y.float(), x.float()
    
    dist_from_center = torch.sqrt((x - center)**2 + (y - center)**2)
    circle_radius = size * 0.35
    circle_mask = dist_from_center < circle_radius
    
    target[0, 0][circle_mask] = 1.0  # R
    target[0, 1][circle_mask] = 0.5  # G
    target[0, 2][circle_mask] = 0.0  # B
    target[0, 3][circle_mask] = 1.0  # A
    
    return target

target = create_target_circle(64).to(device)

plt.figure(figsize=(4, 4))
img = target[0].permute(1, 2, 0).cpu().numpy()
plt.imshow(img)
plt.title('Target Pattern')
plt.axis('off')
plt.show()

In [None]:
def train_nca(model, target, num_epochs=500, pool_size=256, batch_size=4, lr=2e-3):
    """Train NCA using a pool of states."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    size = target.shape[-1]
    pool = create_seed(size).repeat(pool_size, 1, 1, 1).to(device)
    losses = []
    
    for epoch in range(num_epochs):
        batch_indices = torch.randint(0, pool_size, (batch_size,))
        batch = pool[batch_indices].clone()
        num_steps = torch.randint(64, 96, (1,)).item()
        
        for _ in range(num_steps):
            batch = model(batch)
        
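        # Loss on the raw RGBA channels: clamping would give zero gradient to
        # cells that overshoot [0, 1], so they could never be pulled back
        visible = batch[:, :4]
        loss = F.mse_loss(visible, target.expand(batch_size, -1, -1, -1))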
        
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        
        with torch.no_grad():
            pool[batch_indices] = batch.detach()
            batch_losses = F.mse_loss(visible, target.expand(batch_size, -1, -1, -1), reduction='none')
            batch_losses = batch_losses.mean(dim=(1, 2, 3))
            # .item() gives a plain int, so indexing the CPU batch_indices works on any device
            worst_idx = batch_indices[batch_losses.argmax().item()]
            pool[worst_idx] = create_seed(size).to(device).squeeze(0)
        
        losses.append(loss.item())
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.6f}")
    
    return losses

nca = NeuralCA().to(device)
print("Training Neural CA...")
losses = train_nca(nca, target, num_epochs=500)

## 6. Visualize Growth

In [None]:
def visualize_growth(model, size=64, steps=100):
    """Visualize the NCA growing from seed."""
    model.eval()
    state = create_seed(size).to(device)
    history = [get_visible_state(state).cpu()]
    
    with torch.no_grad():
        for _ in range(steps):
            state = model(state)
            history.append(get_visible_state(state).cpu())
    
    fig, axes = plt.subplots(1, 6, figsize=(15, 3))
    # Six evenly spaced snapshots, so any value of `steps` works
    for i, step in enumerate(torch.linspace(0, steps, 6).long().tolist()):
        img = history[step][0].permute(1, 2, 0).numpy()
        axes[i].imshow(img)
        axes[i].set_title(f'Step {step}')
        axes[i].axis('off')
    
    plt.suptitle('NCA Growth from Seed')
    plt.tight_layout()
    plt.show()

visualize_growth(nca)
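Snapshots show growth, but not whether the pattern persists beyond the 64 to 95 steps seen in training. One way to check is to record the error after every step of a longer rollout. A sketch; `rollout_mse` is a hypothetical helper that accepts any module mapping a state to the next state:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.no_grad()
def rollout_mse(model: nn.Module, seed: torch.Tensor, target: torch.Tensor, steps: int) -> list:
    """Run `model` from `seed` for `steps` steps, recording MSE(RGBA, target) after each."""
    state = seed.clone()
    errors = []
    for _ in range(steps):
        state = model(state)
        rgba = state[:, :4].clamp(0.0, 1.0)  # same view as get_visible_state
        errors.append(F.mse_loss(rgba, target).item())
    return errors
```

For example, `errors = rollout_mse(nca, create_seed(64).to(device), target, steps=300)` followed by `plt.semilogy(errors)`. A curve that climbs again after the training window suggests the target is not yet a stable attractor and more training may help.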

## 7. Regeneration: Self-Repair from Damage

In [None]:
def test_regeneration(model, size=64, grow_steps=100, regen_steps=100):
    """Test NCA regeneration after damage."""
    model.eval()
    
    state = create_seed(size).to(device)
    with torch.no_grad():
        for _ in range(grow_steps):
            state = model(state)
    
    grown = state.clone()
    damaged = state.clone()
    # Damage: zero every channel in a 24x24 square at the center of the pattern
    damaged[:, :, 20:44, 20:44] = 0
    
    regenerating = damaged.clone()
    regen_history = [get_visible_state(regenerating).cpu()]
    
    with torch.no_grad():
        for _ in range(regen_steps):
            regenerating = model(regenerating)
            regen_history.append(get_visible_state(regenerating).cpu())
    
    fig, axes = plt.subplots(2, 5, figsize=(15, 6))
    
    axes[0, 0].imshow(get_visible_state(grown)[0].permute(1, 2, 0).cpu().numpy())
    axes[0, 0].set_title('Fully Grown')
    axes[0, 1].imshow(get_visible_state(damaged)[0].permute(1, 2, 0).cpu().numpy())
    axes[0, 1].set_title('After Damage')
    
    # Five evenly spaced snapshots, so any value of `regen_steps` works
    for i, step in enumerate(torch.linspace(0, regen_steps, 5).long().tolist()):
        axes[1, i].imshow(regen_history[step][0].permute(1, 2, 0).numpy())
        axes[1, i].set_title(f'Regen Step {step}')
    
    for ax in axes.flat:
        ax.axis('off')
    
    plt.suptitle('NCA Regeneration After Damage')
    plt.tight_layout()
    plt.show()

test_regeneration(nca)
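Pool training mainly teaches persistence, and regeneration from it alone is often partial. A common remedy, used in the original Growing NCA work, is to damage a few of the best (lowest-loss) samples in each batch before the rollout, so the model is trained to repair them. A minimal sketch, assuming circular cutouts with random center and radius and square states; `make_circle_masks` and `damage_lowest_loss` are hypothetical helpers, not part of the code above:

```python
import torch

def make_circle_masks(n: int, size: int) -> torch.Tensor:
    """Masks of shape (n, 1, size, size): 0 inside a random circle, 1 elsewhere."""
    coords = torch.linspace(-1.0, 1.0, size)
    y, x = torch.meshgrid(coords, coords, indexing='ij')  # each (size, size)
    center = torch.rand(n, 2, 1, 1) - 0.5                 # centers in [-0.5, 0.5)
    radius = 0.1 + 0.3 * torch.rand(n, 1, 1)              # radii in [0.1, 0.4)
    dist = torch.sqrt((x - center[:, 0]) ** 2 + (y - center[:, 1]) ** 2)
    return (dist >= radius).float().unsqueeze(1)

def damage_lowest_loss(batch: torch.Tensor, losses: torch.Tensor, k: int = 1) -> torch.Tensor:
    """Zero a random disk (all channels) in the k samples with the lowest loss."""
    idx = losses.argsort()[:k]
    masks = make_circle_masks(k, batch.shape[-1]).to(batch.device)
    damaged = batch.clone()
    damaged[idx] = damaged[idx] * masks  # mask broadcasts over the channel dimension
    return damaged
```

In `train_nca` this would go right after `batch = pool[batch_indices].clone()`, with `losses` computed from `batch[:, :4]` against the target using `reduction='none'` and `.mean(dim=(1, 2, 3))`. Damaging the lowest-loss samples means the model practices repairing patterns that were already close to the target.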

## 8. FAANG Interview Questions

### Q1: What is a Neural Cellular Automaton and how does it differ from classical CA?

**Answer**:

**Classical CA** (e.g., Game of Life):
- Fixed, hand-designed rules
- Discrete states (0 or 1)
- Non-differentiable

**Neural CA**:
- Learnable rules (neural network)
- Continuous states (real values)
- Fully differentiable, so the rules can be trained with gradient descent

---

### Q2: Why use Sobel filters for perception?

**Answer**:

Sobel filters detect gradients in the neighborhood:
- Identity: Current cell value
- Sobel-X: Horizontal gradient
- Sobel-Y: Vertical gradient

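Gradients tell a cell in which direction neighboring values increase, which it needs to orient growth from only a 3×3 view. The filters are fixed; only the update network is learned.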

---

### Q3: How does pool-based training work?

**Answer**:

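1. Maintain a pool of N states (256 here), all initialized to the seed
2. Sample a batch from the pool
3. Run the CA for a random number of steps (64 to 95 here)
4. Compute the loss and update the model
5. Write the evolved states back into the pool
6. Replace the highest-loss sample in the batch with a fresh seed

Because pool states are reused, later batches start from states that have already evolved for hundreds of steps. The model is therefore trained to keep the pattern stable, not just to reach it at the end of one rollout, while reseeding keeps it able to grow from scratch.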
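The steps above can be packaged as a small class. This is a sketch, not the notebook's code: `SamplePool` is a hypothetical name, and it samples indices with `torch.randperm` (without replacement), whereas `train_nca` uses `torch.randint`, which can draw the same slot twice in one batch.

```python
import torch

class SamplePool:
    """Minimal sample pool for NCA training (hypothetical helper)."""

    def __init__(self, seed: torch.Tensor, pool_size: int):
        # seed has shape (C, H, W); every slot starts as a copy of it
        self.seed = seed.clone()
        self.states = seed.unsqueeze(0).repeat(pool_size, 1, 1, 1)

    def sample(self, batch_size: int):
        # Sample without replacement so no slot appears twice in one batch
        idx = torch.randperm(self.states.shape[0])[:batch_size]
        return idx, self.states[idx].clone()

    def commit(self, idx: torch.Tensor, batch: torch.Tensor, losses: torch.Tensor) -> None:
        # Write evolved states back so later batches continue from them
        self.states[idx] = batch.detach()
        # Reset the highest-loss sample to the seed so growth from scratch is not forgotten
        worst = idx[losses.argmax().item()]
        self.states[worst] = self.seed
```

Usage would look like `pool = SamplePool(create_seed(64)[0].to(device), 256)`, then `idx, batch = pool.sample(4)`, a rollout, and `pool.commit(idx, batch, per_sample_losses)`.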

---

### Q4: Why do NCAs regenerate?

**Answer**:

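- Local rules: each cell only sees its 3×3 neighborhood, so there is no global blueprint for damage to destroy
- Same update everywhere: any surviving region runs the rule that grew the pattern in the first place
- Pool training exposes the model to partially grown and long-running states, which pushes the target toward being a stable attractor of the dynamics
- Caveat: regeneration learned this way is often partial; it becomes much more reliable when damaged states are included in training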

---

### Q5: Applications of Neural CA?

**Answer**:

- Texture synthesis
- Pattern growth
- Self-repair systems
- Morphogenesis modeling
- Generative art
- Decentralized computing

## 9. Key Takeaways

1. **Neural CA** combine cellular automata with learnable neural networks
2. **Perception** uses Sobel filters to sense local gradients
3. **Update rule** is a small network applied to each cell
4. **Residual updates** with stochastic masks enable stable growth
5. **Pool-based training** tests persistence and robustness
6. **Regeneration** can emerge from local rules trained for persistence, and becomes more robust when damage is part of training
7. **Self-organization** produces complex patterns from simple seeds