# Phase 2: Formation Transitions with Advection

This notebook uses **advection-based mass transport** instead of direct alpha updates.

## Key Differences from Standard NCA

| Standard NCA | Advection NCA |
|--------------|---------------|
| NCA outputs alpha updates | NCA outputs velocity |
| Alpha can appear anywhere | Mass is transported |
| Conservation is soft loss | Conservation by construction |
| Movement is emergent | Movement is explicit |

## Physics

```
mass_new = mass - dt * outflow + dt * inflow

where:
  outflow = mass * |velocity|
  inflow = neighbor_mass * neighbor_velocity_toward_me
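```

Every unit of mass that leaves a cell (outflow) is credited to a neighbouring cell (inflow), so the grid total should stay constant from step to step; the conservation test below checks this numerically.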

In [3]:
# Setup: fetch/update the repository, install dependencies, make `src` importable
import subprocess
import sys
import os

REPO_URL = "https://github.com/JackHopkins/FormationHNCA.git"

# Choose a checkout location for the current host: /content (presumably Colab),
# /workspace (container hosts), otherwise the user's home directory.
if os.path.exists("/content"):
    REPO_DIR = "/content/FormationHNCA"
elif os.path.exists("/workspace"):
    REPO_DIR = "/workspace/FormationHNCA"
else:
    REPO_DIR = os.path.expanduser("~/FormationHNCA")

if os.path.exists(REPO_DIR):
    print(f"Pulling latest changes in {REPO_DIR}...")
    result = subprocess.run(["git", "-C", REPO_DIR, "pull"], capture_output=True, text=True)
    if result.returncode != 0:
        # Don't report a failed pull (local edits, no network, ...) as
        # "Already up to date."; show git's error and keep the existing checkout.
        print(f"WARNING: git pull failed with exit code {result.returncode}; using existing checkout.")
        print(result.stderr.strip())
    else:
        print(result.stdout or "Already up to date.")
else:
    print(f"Cloning repository to {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

os.chdir(REPO_DIR)

# NOTE(review): unpinned install; consider pinning jax[cuda12] for reproducibility.
print("Installing JAX with CUDA support...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "jax[cuda12]"], check=True)

print("Installing battle-nca package...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-e", "."], check=True)

# Make the package importable from source without relying on the editable install.
src_path = os.path.join(REPO_DIR, "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"\nWorking directory: {os.getcwd()}")
print("Setup complete!")

Pulling latest changes in /workspace/FormationHNCA...
Already up to date.

Installing JAX with CUDA support...
Installing battle-nca package...

Working directory: /workspace/FormationHNCA
Setup complete!


In [4]:
# Setup and imports
import sys
import os

# Do not hard-code a machine-specific absolute path here: it would shadow the
# checkout prepared by the setup cell. Fall back to `./src` relative to the
# current working directory only if it is not already on sys.path.
_local_src = os.path.join(os.getcwd(), "src")
if os.path.isdir(_local_src) and _local_src not in sys.path:
    sys.path.insert(0, _local_src)

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
import numpy as np
from functools import partial

print(f"JAX devices: {jax.devices()}")

JAX devices: [CudaDevice(id=0)]


In [5]:
# Import advection NCA components
from battle_nca.core.advection import advect_mass, check_mass_conservation
from battle_nca.hierarchy.advection_nca import (
    AdvectionNCA, 
    ADVECTION_CHANNELS,
    create_advection_seed,
    create_formation_from_alpha
)
from battle_nca.combat import create_formation_target, rotate_formation
from battle_nca.combat.formations import FormationTypes

# Summarise the channel layout exported by the advection module.
print("\n".join([
    f"Advection channels: {ADVECTION_CHANNELS.TOTAL}",
    f"  Mass: {ADVECTION_CHANNELS.MASS}",
    f"  Velocity: {ADVECTION_CHANNELS.VELOCITY_X}, {ADVECTION_CHANNELS.VELOCITY_Y}",
    f"  Hidden: {ADVECTION_CHANNELS.HIDDEN_START}-{ADVECTION_CHANNELS.HIDDEN_END}",
]))

ModuleNotFoundError: No module named 'battle_nca.core.advection'

In [None]:
# Configuration
GRID_SIZE = 64  # side length of the square simulation grid
NUM_CHANNELS = ADVECTION_CHANNELS.TOTAL  # per-cell state width (noted as 16 channels; verify against ADVECTION_CHANNELS)

print(f"Grid: {GRID_SIZE}x{GRID_SIZE}\nChannels: {NUM_CHANNELS}")

In [None]:
# Test advection physics: push an 8x8 block of unit mass with a uniform
# rightward velocity and check that advect_mass keeps the total constant.
print("Testing advection conservation...")

vel_x = 0.5 * jnp.ones((GRID_SIZE, GRID_SIZE))
vel_y = jnp.zeros_like(vel_x)
mass_before = jnp.zeros((GRID_SIZE, GRID_SIZE)).at[28:36, 28:36].set(1.0)

# One advection step
mass_after = advect_mass(mass_before, vel_x, vel_y, dt=0.5)

conserved, error = check_mass_conservation(mass_before, mass_after)
for label, field in (("before", mass_before), ("after", mass_after)):
    print(f"Mass {label}: {float(jnp.sum(field)):.4f}")
print(f"Conserved: {conserved}, Error: {error:.6f}")

# Ten advection steps from the same starting blob
mass_10 = mass_before
for _ in range(10):
    mass_10 = advect_mass(mass_10, vel_x, vel_y, dt=0.5)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = zip(axes, (mass_before, mass_after, mass_10), ('Before', 'After 1 step', 'After 10 steps'))
for ax, field, title in panels:
    ax.imshow(field, cmap='hot')
    ax.set_title(title)

conserved_10, error_10 = check_mass_conservation(mass_before, mass_10)
print(f"After 10 steps - Conserved: {conserved_10}, Error: {error_10:.6f}")

plt.tight_layout()
plt.show()

In [None]:
# Create formations: build each target and keep channel 3 (the alpha channel)
formations = {
    ftype.name: create_formation_target(
        GRID_SIZE, GRID_SIZE, ftype,
        center=(GRID_SIZE // 2, GRID_SIZE // 2),
        scale=0.6,
    )[..., 3]
    for ftype in (FormationTypes.LINE, FormationTypes.SQUARE, FormationTypes.WEDGE)
}

# Visualize formations with their total mass
fig, axes = plt.subplots(1, len(formations), figsize=(12, 4))
for panel, name in zip(axes, formations):
    alpha = formations[name]
    panel.imshow(alpha, cmap='hot')
    panel.set_title(f"{name}\nmass={float(jnp.sum(alpha)):.1f}")
plt.tight_layout()
plt.show()

In [None]:
# Initialize AdvectionNCA
advection_nca = AdvectionNCA(
    num_channels=NUM_CHANNELS,
    hidden_dim=64,
    fire_rate=0.5,
    advection_dt=0.25,
    advection_steps=2,
)

# PRNG: one key for parameter init, one reserved for training.
key = jax.random.PRNGKey(42)
init_key, train_key = jax.random.split(key)

# Seed state: spawn_region bounds are centre +/- 4 on both axes
_center = GRID_SIZE // 2
seed = create_advection_seed(
    GRID_SIZE, GRID_SIZE,
    spawn_region=(_center - 4, _center + 4, _center - 4, _center + 4),
)

# Zero placeholder parent signal (4 channels, the width create_goal_signal
# produces) used only so init can trace the model.
dummy_signal = jnp.zeros((GRID_SIZE, GRID_SIZE, 4))

variables = advection_nca.init(init_key, seed, jax.random.PRNGKey(0), parent_signal=dummy_signal)
params = variables['params']

param_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
seed_mass = float(jnp.sum(seed[..., ADVECTION_CHANNELS.MASS]))
print(f"AdvectionNCA parameters: {param_count:,}")
print(f"Seed shape: {seed.shape}")
print(f"Seed mass: {seed_mass:.2f}")

In [None]:
# Create training state
from battle_nca.training import create_optimizer

# Optimiser with gradient clipping and a warmup/decay learning-rate schedule
# (as configured by these keyword arguments).
# NOTE(review): decay_steps=4000 vs NUM_EPOCHS=2000 later; confirm the schedule
# is meant to be only partly decayed at the end of training.
optimizer_kwargs = dict(
    learning_rate=2e-3,
    gradient_clip=1.0,
    use_schedule=True,
    warmup_steps=200,
    decay_steps=4000,
)
optimizer = create_optimizer(**optimizer_kwargs)

state = train_state.TrainState.create(apply_fn=advection_nca.apply, params=params, tx=optimizer)

print("Optimizer created with warmup schedule")

In [None]:
# Create goal signal from target formation
def create_goal_signal(target_alpha, current_mass, blur_iterations=4):
    """Create the guidance signal that steers the NCA toward a target formation.

    The target is box-blurred into a smooth potential field; its gradient gives
    a per-cell direction toward denser target mass.

    Args:
        target_alpha: 2-D (H, W) array with the target mass map.
        current_mass: Current mass field. Not used by the computation; accepted
            so existing call sites keep working.
        blur_iterations: Number of edge-padded 3x3 box-blur passes used to build
            the potential field. Each pass spreads the field by one cell, so
            cells more than ``blur_iterations + 1`` cells (Chebyshev distance)
            from any target mass see a flat field and get a zero target
            velocity. Defaults to 4, the previously hard-coded value.

    Returns:
        (H, W, 4) array:
        - channel 0: target mass (``target_alpha`` unchanged)
        - channels 1-2: target velocity (x along axis 1, y along axis 0): the
          normalised gradient of the blurred target, pointing toward denser
          target mass; approximately unit length where the gradient is
          non-negligible and zero where the field is flat
        - channel 3: proximity: the blurred target scaled so its maximum is ~1
          (higher means closer to target mass; it is not a distance)

    Raises:
        ValueError: if ``blur_iterations`` is negative.
    """
    if blur_iterations < 0:
        raise ValueError(f"blur_iterations must be >= 0, got {blur_iterations}")

    def _box_blur(field):
        # One 3x3 mean filter, replicating edge values at the borders.
        padded = jnp.pad(field, ((1, 1), (1, 1)), mode='edge')
        return (
            padded[:-2, :-2] + padded[:-2, 1:-1] + padded[:-2, 2:] +
            padded[1:-1, :-2] + padded[1:-1, 1:-1] + padded[1:-1, 2:] +
            padded[2:, :-2] + padded[2:, 1:-1] + padded[2:, 2:]
        ) / 9.0

    # Blur target to create potential field
    blurred = target_alpha
    for _ in range(blur_iterations):
        blurred = _box_blur(blurred)

    # Central-difference gradient of the potential = target velocity
    padded = jnp.pad(blurred, ((1, 1), (1, 1)), mode='edge')
    target_vx = (padded[1:-1, 2:] - padded[1:-1, :-2]) / 2.0
    target_vy = (padded[2:, 1:-1] - padded[:-2, 1:-1]) / 2.0

    # Normalise to a direction; the 1e-8 keeps flat regions at 0 instead of NaN.
    mag = jnp.sqrt(target_vx**2 + target_vy**2 + 1e-8)
    target_vx = target_vx / mag
    target_vy = target_vy / mag

    # Proximity: blurred target rescaled to a maximum of ~1
    proximity = blurred / (jnp.max(blurred) + 1e-8)

    return jnp.stack([target_alpha, target_vx, target_vy, proximity], axis=-1)

# Test signal creation on the LINE formation and plot each signal channel
target_alpha = formations['LINE']
test_signal = create_goal_signal(target_alpha, seed[..., ADVECTION_CHANNELS.MASS])
print(f"Signal shape: {test_signal.shape}")

# (channel index, panel title, imshow styling)
_signal_panels = [
    (0, 'Target mass', dict(cmap='hot')),
    (1, 'Target Vx', dict(cmap='RdBu', vmin=-1, vmax=1)),
    (2, 'Target Vy', dict(cmap='RdBu', vmin=-1, vmax=1)),
    (3, 'Proximity', dict(cmap='viridis')),
]
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
for ax, (channel, title, style) in zip(axes, _signal_panels):
    ax.imshow(test_signal[..., channel], **style)
    ax.set_title(title)
plt.tight_layout()
plt.show()

In [None]:
# Training configuration
NUM_STEPS = 64      # NCA updates per trajectory (rollout length)
BATCH_SIZE = 8      # trajectories sampled from the pool per optimisation step
NUM_EPOCHS = 2000   # number of optimisation steps
VIZ_INTERVAL = 100  # epochs between diagnostic figures

print("Training config:")
print(f"  Steps per trajectory: {NUM_STEPS}\n  Batch size: {BATCH_SIZE}\n  Epochs: {NUM_EPOCHS}")

In [None]:
# Training step
@jax.jit
def train_step(state, batch, target_alpha, key):
    """Roll the NCA out for NUM_STEPS steps from `batch` and take one optimiser step.

    Args:
        state: flax ``TrainState`` (params + optimiser state).
        batch: batch of NCA states indexed as ``[..., channel]`` with
            ``ADVECTION_CHANNELS``; presumably (B, H, W, C) - TODO confirm.
        target_alpha: (H, W) target mass map.
        key: PRNG key, split into one key per rollout step.

    Loss:
        1. mass_loss: MSE between the final mass field and ``target_alpha``.
        2. velocity_loss (auxiliary, weight 0.1): squared error between the
           final velocity and the goal-signal direction, weighted by
           ``final_mass + 0.1`` so cells holding mass dominate.

    Returns:
        (new_state, total_loss, mass_loss, velocity_loss, final_state)
    """
    def loss_fn(params):
        step_keys = jax.random.split(key, NUM_STEPS)

        # The goal signal is computed once and held fixed for the whole rollout.
        signal = create_goal_signal(target_alpha, batch[..., ADVECTION_CHANNELS.MASS])

        def step(carry, step_key):
            new_state = advection_nca.apply(
                {'params': params}, carry, step_key, parent_signal=signal
            )
            # Only the final state is used, so emit no per-step output; returning
            # the state here would stack an extra (NUM_STEPS, ...) trajectory array.
            return new_state, None

        final_state, _ = jax.lax.scan(step, batch, step_keys)

        # Loss 1: final mass should match the target.
        # NOTE(review): if advection conserves mass, the final total equals the
        # initial total; when that differs from sum(target_alpha) this MSE cannot
        # reach zero - consider normalising the target mass.
        final_mass = final_state[..., ADVECTION_CHANNELS.MASS]
        mass_loss = jnp.mean((final_mass - target_alpha) ** 2)

        # Loss 2: velocity should align with the goal-signal direction
        final_vx = final_state[..., ADVECTION_CHANNELS.VELOCITY_X]
        final_vy = final_state[..., ADVECTION_CHANNELS.VELOCITY_Y]
        target_vx = signal[..., 1]
        target_vy = signal[..., 2]

        # Weight by mass, with a 0.1 floor so empty cells still get some signal.
        mass_weight = final_mass + 0.1
        velocity_loss = jnp.mean(mass_weight * ((final_vx - target_vx)**2 + (final_vy - target_vy)**2))

        total_loss = mass_loss + 0.1 * velocity_loss
        return total_loss, (final_state, mass_loss, velocity_loss)

    (loss, (final_state, mass_loss, vel_loss)), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)

    return state, loss, mass_loss, vel_loss, final_state

# jax.jit traces and compiles lazily, on the first call with concrete inputs.
print("Training step defined (compiled on first call)")

In [None]:
# Pool-based training: a persistent pool of NCA states that the training loop
# samples batches from and writes rollout results back into.
formation_names = list(formations.keys())

# Initialize pool with copies of the seed state
pool_size = 256
pool = jnp.stack([seed] * pool_size)

# Add ready-made formations to the pool for variety: formation i fills slots
# [(i + 1) * (pool_size // 4), (i + 1) * (pool_size // 4) + pool_size // 8).
slot_stride = pool_size // 4
slot_width = pool_size // 8
for i, alpha in enumerate(formations.values()):
    start_idx = (i + 1) * slot_stride
    end_idx = start_idx + slot_width
    # JAX clamps out-of-range slice bounds instead of raising, so an extra
    # formation would be silently truncated or dropped; fail loudly instead.
    if end_idx > pool_size:
        raise ValueError(
            f"Formation {i} would need pool slots {start_idx}:{end_idx} but "
            f"pool_size={pool_size}; increase pool_size or change the slot layout."
        )
    formation_state = create_formation_from_alpha(alpha)
    pool = pool.at[start_idx:end_idx].set(formation_state)

print(f"Pool size: {pool_size}")
print(f"Pool shape: {pool.shape}")

In [None]:
# Training loop
import time

losses = []
mass_losses = []
vel_losses = []

print("Starting training...")
print("="*50)

for epoch in range(NUM_EPOCHS):
    # Derive per-epoch randomness from `train_key` (split from the notebook seed
    # at initialisation) and use a separate key for every random draw, so the
    # rotation angle and the rollout noise are independent.
    epoch_key = jax.random.fold_in(train_key, epoch)
    batch_key, target_key, angle_key, rollout_key = jax.random.split(epoch_key, 4)

    # Select random batch (distinct indices) from pool
    batch_idx = jax.random.choice(batch_key, pool_size, (BATCH_SIZE,), replace=False)
    batch = pool[batch_idx]

    # Select random target formation
    target_idx = int(jax.random.randint(target_key, (), 0, len(formation_names)))
    target_name = formation_names[target_idx]
    target_alpha = formations[target_name]

    # Random rotation in [-30, 30) (units as rotate_formation expects; presumably degrees)
    angle = float(jax.random.uniform(angle_key, (), minval=-30, maxval=30))
    target_alpha = rotate_formation(target_alpha, angle)

    # Train step
    state, loss, m_loss, v_loss, final_states = train_step(
        state, batch, target_alpha, rollout_key
    )

    losses.append(float(loss))
    mass_losses.append(float(m_loss))
    vel_losses.append(float(v_loss))

    # Write rollout results back so later epochs continue from them
    pool = pool.at[batch_idx].set(final_states)

    # Logging
    if epoch % 50 == 0:
        # Ratio of final to initial total mass for the first batch sample
        init_mass = float(jnp.sum(batch[0, ..., ADVECTION_CHANNELS.MASS]))
        final_mass = float(jnp.sum(final_states[0, ..., ADVECTION_CHANNELS.MASS]))
        conservation = final_mass / (init_mass + 1e-8)

        print(f"Epoch {epoch:4d} | Loss: {loss:.4f} | Mass: {m_loss:.4f} | Vel: {v_loss:.4f} | Conservation: {conservation:.3f}")

    # Visualization
    if epoch % VIZ_INTERVAL == 0 and epoch > 0:
        fig, axes = plt.subplots(2, 4, figsize=(16, 8))

        # Top row: batch sample evolution
        axes[0, 0].imshow(batch[0, ..., ADVECTION_CHANNELS.MASS], cmap='hot')
        axes[0, 0].set_title('Initial')
        axes[0, 1].imshow(target_alpha, cmap='hot')
        axes[0, 1].set_title(f'Target ({target_name})')
        axes[0, 2].imshow(final_states[0, ..., ADVECTION_CHANNELS.MASS], cmap='hot')
        axes[0, 2].set_title('Final')
        axes[0, 3].imshow(jnp.abs(final_states[0, ..., ADVECTION_CHANNELS.MASS] - target_alpha), cmap='hot')
        axes[0, 3].set_title('Error')

        # Bottom row: velocity and losses
        axes[1, 0].imshow(final_states[0, ..., ADVECTION_CHANNELS.VELOCITY_X], cmap='RdBu', vmin=-1, vmax=1)
        axes[1, 0].set_title('Velocity X')
        axes[1, 1].imshow(final_states[0, ..., ADVECTION_CHANNELS.VELOCITY_Y], cmap='RdBu', vmin=-1, vmax=1)
        axes[1, 1].set_title('Velocity Y')
        axes[1, 2].plot(losses[-VIZ_INTERVAL:])
        axes[1, 2].set_title('Total Loss')
        axes[1, 3].plot(mass_losses[-VIZ_INTERVAL:], label='Mass')
        axes[1, 3].plot(vel_losses[-VIZ_INTERVAL:], label='Velocity')
        axes[1, 3].legend()
        axes[1, 3].set_title('Loss Components')

        plt.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; many are created over a long run

print("="*50)
print("Training complete!")

In [None]:
# Final evaluation: roll out every ordered pair of distinct formations for
# NUM_STEPS steps and report MSE to the target and the final/initial mass ratio.
print("\nFinal Evaluation")
print("="*50)

transition_pairs = [
    (src, dst) for src in formation_names for dst in formation_names if src != dst
]
for source_name, target_name in transition_pairs:
    # Source formation as a batch of one
    source_state = create_formation_from_alpha(formations[source_name])[None, ...]
    target_alpha = formations[target_name]

    key = jax.random.PRNGKey(0)
    signal = create_goal_signal(target_alpha, source_state[0, ..., ADVECTION_CHANNELS.MASS])

    current = source_state
    for _ in range(NUM_STEPS):
        key, subkey = jax.random.split(key)
        current = advection_nca.apply(
            {'params': state.params}, current, subkey, parent_signal=signal
        )

    final_mass = current[0, ..., ADVECTION_CHANNELS.MASS]
    mse = float(jnp.mean((final_mass - target_alpha)**2))

    init_total = float(jnp.sum(source_state[0, ..., ADVECTION_CHANNELS.MASS]))
    final_total = float(jnp.sum(final_mass))
    conservation = final_total / (init_total + 1e-8)

    print(f"{source_name:8s} -> {target_name:8s} | MSE: {mse:.4f} | Conservation: {conservation:.3f}")

In [None]:
# Visualize a transition: mass snapshots along a LINE -> SQUARE rollout
source_name = 'LINE'
target_name = 'SQUARE'

source_state = create_formation_from_alpha(formations[source_name])
target_alpha = formations[target_name]
signal = create_goal_signal(target_alpha, source_state[..., ADVECTION_CHANNELS.MASS])

# Snapshot every `snapshot_every` steps (ceil(NUM_STEPS / 8), at least 1) plus
# the final step, recording how many NCA steps were actually applied so panel
# titles match what is shown. This yields at most 1 + 8 = 9 snapshots.
snapshot_every = max(1, -(-NUM_STEPS // 8))
snapshots = [(0, source_state[..., ADVECTION_CHANNELS.MASS])]
current = source_state[None, ...]
key = jax.random.PRNGKey(0)

for i in range(NUM_STEPS):
    key, subkey = jax.random.split(key)
    current = advection_nca.apply(
        {'params': state.params}, current, subkey, parent_signal=signal
    )
    steps_done = i + 1
    if steps_done % snapshot_every == 0 or steps_done == NUM_STEPS:
        snapshots.append((steps_done, current[0, ..., ADVECTION_CHANNELS.MASS]))

# Plot snapshots on a 2x5 grid; unused panels stay blank
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
for ax in axes.flat:
    ax.axis('off')
for ax, (step_count, mass) in zip(axes.flat, snapshots):
    ax.imshow(mass, cmap='hot', vmin=0, vmax=1)
    total_mass = float(jnp.sum(mass))
    ax.set_title(f'Step {step_count}\nmass={total_mass:.1f}')

plt.suptitle(f'{source_name} -> {target_name} Transition', fontsize=14)
plt.tight_layout()
plt.show()