# Phase 2: Formation Transitions

**Prerequisite**: Complete `phase1_line_formation.ipynb` first.

**Goal**: Train the NCA to transition from one formation to another based on a goal signal.

**Key Difference from Phase 1**:
- Phase 1: Seed → Formation (growing from nothing)
- Phase 2: Formation A → Formation B (morphing between established shapes)

**Formations**:
- Line (horizontal band)
- Phalanx (deep rectangular block)
- Square (hollow defensive square)
- Wedge (triangle/arrow)
- Column (vertical band)

**Training Approach**:
1. Start from a random established formation (not seed)
2. Provide a NEW goal signal for a different formation
3. Train NCA to morph into the target formation
4. Mass must be conserved during transition

**Success Criteria**:
- Can transition between any pair of formations
- Responds correctly to rotation signals
- Mass conservation maintained during transition

In [None]:
# Setup
import subprocess
import sys
import os

# Repository that provides the battle-nca package used by the rest of this notebook.
REPO_URL = "https://github.com/JackHopkins/FormationHNCA.git"

# Pick the checkout location from whichever well-known root directory exists,
# falling back to the user's home directory.
if os.path.exists("/content"):
    REPO_DIR = "/content/FormationHNCA"
elif os.path.exists("/workspace"):
    REPO_DIR = "/workspace/FormationHNCA"
else:
    REPO_DIR = os.path.expanduser("~/FormationHNCA")

if os.path.exists(REPO_DIR):
    print(f"Pulling latest changes in {REPO_DIR}...")
    result = subprocess.run(["git", "-C", REPO_DIR, "pull"], capture_output=True, text=True)
    if result.returncode == 0:
        print(result.stdout or "Already up to date.")
    else:
        # Updating is best-effort: keep going with the existing checkout, but make
        # the failure visible instead of reporting "Already up to date.".
        print(f"WARNING: git pull failed (exit code {result.returncode}); using existing checkout.")
        if result.stderr:
            print(result.stderr.strip())
else:
    print(f"Cloning repository to {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

os.chdir(REPO_DIR)

print("Installing JAX with CUDA support...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "jax[cuda12]"], check=True)

print("Installing battle-nca package...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-e", "."], check=True)

# Make the package importable from the source tree as well.
src_path = os.path.join(REPO_DIR, "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"\nWorking directory: {os.getcwd()}")
print("Setup complete!")

In [None]:
import jax
import gc

# Start from a clean slate: drop unreachable Python objects and JAX's compilation caches.
gc.collect()
jax.clear_caches()

available_devices = jax.devices()
print(f"JAX version: {jax.__version__}")
print(f"JAX devices: {available_devices}")

# Report whether the first device is a GPU.
on_gpu = available_devices[0].platform == 'gpu'
print("GPU acceleration enabled!" if on_gpu else "WARNING: Running on CPU.")

In [None]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
import matplotlib.pyplot as plt
import numpy as np
from functools import partial
import time
import pickle
from pathlib import Path

from battle_nca.core import NCA, perceive
from battle_nca.core.nca import create_seed
from battle_nca.hierarchy import ChildNCA, ParentNCA, HierarchicalNCA
from battle_nca.hierarchy.child_nca import create_army_seed, CHILD_CHANNELS
from battle_nca.combat import FormationTargets, create_formation_target, rotate_formation
from battle_nca.combat.formations import FormationTypes
from battle_nca.training import NCAPool, Trainer, TrainingConfig
from battle_nca.training.optimizers import create_optimizer, normalize_gradients

print("All imports successful!")

## Configuration

In [None]:
# Spatial size of the (square) grid and number of state channels per cell.
GRID_SIZE = 64
NUM_CHANNELS = 24

# RESET: Set True to start Phase 2 from Phase 1 checkpoint
# Set False to resume Phase 2 training
RESET = True

# NOTE(review): within this notebook, min_steps/max_steps are only echoed in the
# summary print below and the damage_* fields are not referenced at all; the
# training cell rolls out a fixed FIXED_NUM_STEPS. Confirm before relying on them.
config = TrainingConfig(
    batch_size=32,
    pool_size=1024,
    min_steps=128,
    max_steps=192,
    learning_rate=2e-3,
    gradient_clip=1.0,
    damage_samples=3,
    damage_start_epoch=1000,  # Later damage since transitions are harder
    log_interval=100,
    checkpoint_interval=500
)

PHASE2_EPOCHS = 5000
SEED = 42

# Rotation augmentation
USE_ROTATION_AUGMENTATION = True
ROTATION_CONTINUOUS = True  # True: any angle in [0, 2*pi); False: multiples of 90 degrees

# Checkpoints
PHASE1_CHECKPOINT = Path('checkpoints/phase1_line.pkl')
PHASE2_CHECKPOINT = Path('checkpoints/phase2_transitions.pkl')
# parents=True so the directory is created even when intermediate folders are missing.
PHASE2_CHECKPOINT.parent.mkdir(parents=True, exist_ok=True)

print("Phase 2: Formation Transitions")
print("Task: Formation A -> Formation B (not seed -> formation)")
print(f"Steps per transition: {config.min_steps}-{config.max_steps}")

## Load Phase 1 Model

In [None]:
# Child NCA module. NOTE(review): these hyperparameters presumably have to match
# the ones used to produce the Phase 1 checkpoint loaded below - confirm.
child_nca = ChildNCA(
    use_circular_padding=True,
    fire_rate=0.5,
    hidden_dim=128,
    num_channels=NUM_CHANNELS,
)

# Seed for initialization only: a 4x4 spawn region around the grid centre.
_grid_mid = GRID_SIZE // 2
seed = create_army_seed(
    height=GRID_SIZE,
    width=GRID_SIZE,
    team_color=(1.0, 0.0, 0.0),
    unit_type=0,
    formation_id=0,
    spawn_region=(_grid_mid - 2, _grid_mid + 2, _grid_mid - 2, _grid_mid + 2),
)

print(f"Seed shape: {seed.shape}")

In [None]:
# Root PRNG key for the notebook; later cells repeatedly split `key` to obtain
# fresh, independent subkeys.
key = jax.random.PRNGKey(SEED)
# NOTE(review): `init_key` is not consumed anywhere in the visible notebook
# (parameters come from a checkpoint rather than a fresh initialisation).
key, init_key = jax.random.split(key)

def count_params(params):
    """Return the total number of scalar entries across all leaves of a parameter pytree."""
    return jax.tree_util.tree_reduce(
        lambda running_total, leaf: running_total + leaf.size, params, 0
    )

def _load_checkpoint(path):
    """Unpickle and return the checkpoint dictionary stored at `path`."""
    # NOTE(security): pickle.load can execute arbitrary code. Only load
    # checkpoint files that were produced by these notebooks / a trusted source.
    with open(path, 'rb') as f:
        return pickle.load(f)


def _load_phase1_checkpoint():
    """Load the Phase 1 checkpoint, raising FileNotFoundError with guidance if it is missing."""
    if not PHASE1_CHECKPOINT.exists():
        print("ERROR: Phase 1 checkpoint not found!")
        print("Please run phase1_line_formation.ipynb first.")
        raise FileNotFoundError(f"Missing: {PHASE1_CHECKPOINT}")
    print(f"Loading Phase 1 checkpoint from {PHASE1_CHECKPOINT}...")
    return _load_checkpoint(PHASE1_CHECKPOINT)


if RESET:
    # Fresh Phase 2 run: always start from the Phase 1 weights.
    loaded_checkpoint = _load_phase1_checkpoint()
    params = loaded_checkpoint['params']
    print(f"  Loaded Phase 1 model with {count_params(params):,} parameters")
    # Tolerate checkpoints that were saved without a 'metrics' entry.
    phase1_metrics = loaded_checkpoint.get('metrics', {})
    print(f"  Phase 1 best loss: {phase1_metrics.get('best_loss', 'N/A')}")
elif PHASE2_CHECKPOINT.exists():
    # Resume: continue from the previously saved Phase 2 weights.
    print(f"Loading Phase 2 checkpoint from {PHASE2_CHECKPOINT}...")
    loaded_checkpoint = _load_checkpoint(PHASE2_CHECKPOINT)
    params = loaded_checkpoint['params']
    epochs_done = len(loaded_checkpoint.get('metrics', {}).get('losses', []))
    print(f"  Loaded Phase 2 model, {epochs_done} epochs done")
else:
    # Resume requested but nothing to resume from: fall back to Phase 1,
    # with the same missing-file guidance as the RESET path.
    print("No Phase 2 checkpoint. Loading Phase 1...")
    loaded_checkpoint = _load_phase1_checkpoint()
    params = loaded_checkpoint['params']

print(f"Total parameters: {count_params(params):,}")

## Create Formation Targets

In [None]:
# Canonical (unrotated) target pattern for every formation. The insertion order
# defines the integer formation index used by the rest of the notebook.
targets = {
    'line': FormationTargets.line(GRID_SIZE, GRID_SIZE),
    'phalanx': FormationTargets.phalanx(GRID_SIZE, GRID_SIZE, depth=8),
    'square': FormationTargets.square(GRID_SIZE, GRID_SIZE, thickness=3),
    'wedge': FormationTargets.wedge(GRID_SIZE, GRID_SIZE),
    'column': FormationTargets.column(GRID_SIZE, GRID_SIZE, col_width=4),
}

formation_names = list(targets)
num_formations = len(formation_names)

# One panel per formation, showing channel 3 (alpha) of the target.
fig, axes = plt.subplots(1, num_formations, figsize=(3*num_formations, 3))
for panel, formation_name in zip(axes, formation_names):
    alpha_map = targets[formation_name][..., 3]
    panel.imshow(alpha_map, cmap='gray', vmin=0, vmax=1)
    panel.set_title(f'{formation_name}')
    panel.axis('off')
plt.suptitle('Formation Targets', fontsize=14)
plt.tight_layout()
plt.show()

## Create Starting Formation Pool

Instead of starting from seeds, we create a pool of established formations. Each training iteration will:
1. Sample a starting formation (already in shape)
2. Pick a target formation (sampled uniformly from all formations, so it can occasionally coincide with the starting one)
3. Train the NCA to transition between them

In [None]:
def create_formation_state(target, key):
    """Build a full NCA state whose visible channels already show `target`.

    Layout of the returned (GRID_SIZE, GRID_SIZE, NUM_CHANNELS) array:
    - channels 0-3: RGBA copied from `target`
    - channels 4 and 5 (health, morale): 1.0 where target alpha > 0.1, else 0.0
    - channels 15-23: uniform noise in [-0.1, 0.1) drawn from `key`
    - every other channel: 0.0
    """
    # Cells count as occupied when the target alpha exceeds 0.1.
    occupied = jnp.where(target[..., 3] > 0.1, 1.0, 0.0)

    # Small random perturbation of the hidden channels, for pool diversity.
    hidden_noise = jax.random.uniform(
        key, (GRID_SIZE, GRID_SIZE, 9), minval=-0.1, maxval=0.1
    )

    formation_state = jnp.zeros((GRID_SIZE, GRID_SIZE, NUM_CHANNELS))
    formation_state = formation_state.at[..., 0:4].set(target[..., :4])
    # Health (4) and morale (5) share the same occupancy mask.
    formation_state = formation_state.at[..., 4:6].set(occupied[..., None])
    formation_state = formation_state.at[..., 15:24].set(hidden_noise)
    return formation_state


def create_transition_pool(pool_size, key):
    """Create a pool of established formations for transition training.

    Each sample in the pool is an already-formed formation, not a seed. When
    USE_ROTATION_AUGMENTATION is set, every sample is rotated by its own random
    angle in [0, 2*pi).

    Args:
        pool_size: Number of states to create; must be positive.
        key: JAX PRNG key; split three ways per sample (carry, angle, noise).

    Returns:
        A tuple ``(states, formation_ids)``: the stacked pool states and, for
        each state, the index into ``formation_names`` of its formation.

    Raises:
        ValueError: If ``pool_size`` is not positive.
    """
    if pool_size <= 0:
        raise ValueError(f"pool_size must be positive, got {pool_size}")

    # Equal-sized contiguous block per formation, then the leftover slots
    # cycle through the formations in index order.
    samples_per_formation, leftover = divmod(pool_size, num_formations)
    formation_order = [
        formation_idx
        for formation_idx in range(num_formations)
        for _ in range(samples_per_formation)
    ]
    formation_order.extend(slot % num_formations for slot in range(leftover))

    pool_states = []
    for formation_idx in formation_order:
        key, angle_key, noise_key = jax.random.split(key, 3)
        target = targets[formation_names[formation_idx]]

        # Random rotation for variety
        if USE_ROTATION_AUGMENTATION:
            angle = float(jax.random.uniform(angle_key, (), minval=0, maxval=2*jnp.pi))
            target = rotate_formation(target, angle)

        pool_states.append(create_formation_state(target, noise_key))

    return jnp.stack(pool_states), jnp.array(formation_order)


# Build the initial pool of established formations with a dedicated subkey.
key, pool_key = jax.random.split(key)
transition_pool, pool_formation_ids = create_transition_pool(config.pool_size, pool_key)

print(f"Created transition pool: {transition_pool.shape}")
print("Formation distribution in pool:")
# Count how many pool entries belong to each formation index.
formation_counts = np.bincount(np.asarray(pool_formation_ids), minlength=num_formations)
for name, count in zip(formation_names, formation_counts):
    print(f"  {name}: {count}")

In [None]:
# Visualize samples from the transition pool
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Hide every panel's frame up front so unused panels stay blank when the pool
# holds fewer than 10 states.
for ax in axes.flat:
    ax.axis('off')

key, viz_key = jax.random.split(key)
sample_indices = np.asarray(jax.random.permutation(viz_key, config.pool_size)[:10])

# NOTE: the loop variable is deliberately not called `state`; that global name
# holds the TrainState created in the training-setup cell, and overwriting it
# here would break training if this cell is re-run afterwards.
for ax, idx in zip(axes.flat, sample_indices):
    sample_state = transition_pool[idx]
    sample_formation = int(pool_formation_ids[idx])

    ax.imshow(sample_state[..., 3], cmap='gray', vmin=0, vmax=1)
    ax.set_title(f'{formation_names[sample_formation]}')

plt.suptitle('Transition Pool Samples (starting formations)', fontsize=14)
plt.tight_layout()
plt.show()

## Training Setup

In [None]:
# Optimizer configured from the training config.
optimizer = create_optimizer(learning_rate=config.learning_rate, gradient_clip=config.gradient_clip)

# Bundle the loaded parameters, the model's apply function and the optimizer
# into a single Flax TrainState that the training step updates.
state = train_state.TrainState.create(
    tx=optimizer,
    params=params,
    apply_fn=child_nca.apply,
)

print("Optimizer ready")

In [None]:
# Training step - OPTIMIZED
# 1. Fixed num_steps to avoid recompilation
# 2. Vectorized checkpoint computation

# Number of evenly spaced points along each rollout at which the loss is evaluated.
NUM_CHECKPOINTS = 4
# Weight of the mass-conservation term relative to the formation (MSE) term.
MASS_CONSERVATION_WEIGHT = 0.5

# Every rollout uses this constant step count, so the jitted training step is
# traced once instead of being recompiled for different rollout lengths.
FIXED_NUM_STEPS = 160

@jax.jit
def train_step(state, batch, target, key, formation_signal, initial_mass):
    """Run one gradient update for a batch of formation transitions.

    The batch is rolled out for FIXED_NUM_STEPS NCA steps while the goal signal
    is written into the parent-signal channels before every step. The loss is
    evaluated at NUM_CHECKPOINTS evenly spaced steps, with linearly increasing
    weights (0.1, 0.2, 0.3, 0.4 for four checkpoints), and combines:
    - formation loss: MSE between the RGBA channels and ``target``
    - mass conservation loss: mean relative deviation of total alpha from
      ``initial_mass``, scaled by MASS_CONSERVATION_WEIGHT

    Args:
        state: Flax TrainState holding parameters and optimizer state.
        batch: Starting states, indexed as (batch, height, width, channel).
        target: Target pattern; only its first four channels (RGBA) are used.
            Assumed to be (height, width, >=4) - TODO confirm.
        key: PRNG key, split into one subkey per rollout step.
        formation_signal: Goal signal; its first two channels are written into
            the state and the full signal is passed to the NCA.
        initial_mass: Per-sample total alpha of ``batch`` before the rollout.

    Returns:
        ``(state, loss, outputs, form_loss, mass_loss)`` where ``outputs`` is
        the full final state of the rollout.
    """
    num_steps = FIXED_NUM_STEPS
    checkpoint_interval = num_steps // NUM_CHECKPOINTS

    # Last step of each of the NUM_CHECKPOINTS equal segments, kept in range.
    checkpoint_ramp = jnp.arange(1, NUM_CHECKPOINTS + 1)
    checkpoint_indices = jnp.clip(checkpoint_ramp * checkpoint_interval - 1, 0, num_steps - 1)
    # Later checkpoints matter more; weights form a normalised linear ramp.
    checkpoint_weights = checkpoint_ramp / jnp.sum(checkpoint_ramp)

    # Match the other cells, which compare against the RGBA slice of the target.
    target_rgba = target[..., :4]
    signal_channels = slice(
        CHILD_CHANNELS.PARENT_SIGNAL_START, CHILD_CHANNELS.PARENT_SIGNAL_END
    )

    def loss_fn(params):
        def step(carry, subkey):
            state_in = carry.at[..., signal_channels].set(formation_signal[..., :2])
            new_state = child_nca.apply(
                {'params': params}, state_in, subkey, parent_signal=formation_signal
            )
            # The loss only reads RGBA, so stack just those channels per step
            # instead of the full state.
            return new_state, new_state[..., :4]

        keys = jax.random.split(key, num_steps)
        final, rgba_history = jax.lax.scan(step, batch, keys)

        # (checkpoint, batch, height, width, rgba)
        checkpoint_rgba = rgba_history[checkpoint_indices]

        # Formation loss: per-checkpoint MSE, combined with the ramp weights.
        checkpoint_mses = jnp.mean((checkpoint_rgba - target_rgba) ** 2, axis=(1, 2, 3, 4))
        formation_loss = jnp.sum(checkpoint_weights * checkpoint_mses)

        # Mass conservation loss: relative change of total alpha per sample.
        checkpoint_masses = jnp.sum(checkpoint_rgba[..., 3], axis=(2, 3))
        relative_errors = (
            jnp.abs(checkpoint_masses - initial_mass[None, :])
            / (initial_mass[None, :] + 1e-6)
        )
        mass_conservation_loss = jnp.mean(relative_errors)

        total_loss = formation_loss + MASS_CONSERVATION_WEIGHT * mass_conservation_loss
        return total_loss, (final, formation_loss, mass_conservation_loss)

    (loss, (outputs, form_loss, mass_loss)), grads = jax.value_and_grad(
        loss_fn, has_aux=True
    )(state.params)
    grads = normalize_gradients(grads)
    state = state.apply_gradients(grads=grads)
    return state, loss, outputs, form_loss, mass_loss


def create_formation_signal(batch_size, height, width, formation_idx, angle, num_formations=5):
    """Create goal-conditioning signal for target formation.

    Args:
        batch_size, height, width: Leading dimensions of the returned signal.
        formation_idx: Index of the target formation in ``[0, num_formations)``.
        angle: Target rotation in radians; ``[0, 2*pi]`` maps to ``[-1, 1]``.
        num_formations: Total number of formations used for normalisation.

    Returns:
        Array of shape ``(batch_size, height, width, 2)``; channel 0 holds the
        formation index scaled to ``[-1, 1]`` and channel 1 the scaled angle,
        both constant across the batch and the grid.
    """
    # With a single formation there is nothing to interpolate between; avoid the
    # division by zero and encode index 0 as -1, like the first formation otherwise.
    index_span = max(num_formations - 1, 1)
    formation_val = (formation_idx / index_span) * 2 - 1
    angle_val = (angle / jnp.pi) - 1
    signal = jnp.zeros((batch_size, height, width, 2))
    signal = signal.at[..., 0].set(formation_val)
    signal = signal.at[..., 1].set(angle_val)
    return signal


# Summarise the training-step settings defined in this cell.
for summary_line in (
    "Training step ready (OPTIMIZED).",
    f"  - Fixed steps: {FIXED_NUM_STEPS}",
    "  - Vectorized checkpoint computation",
    f"  - Mass conservation weight: {MASS_CONSERVATION_WEIGHT}",
):
    print(summary_line)

## Phase 2 Training: Formation Transitions

In [None]:
banner = "=" * 60
print(banner)
print("PHASE 2: Formation Transitions")
print(banner)
print("Task: Start from Formation A, transition to Formation B")
print(f"Formations: {formation_names}")
print(f"Fixed steps per iteration: {FIXED_NUM_STEPS}")
print()

# Early-stopping patience (epochs), minimum loss improvement that resets it,
# and how often (epochs) a diagnostic figure is drawn.
PATIENCE, MIN_DELTA, VIZ_INTERVAL = 2000, 1e-6, 500

# Per-epoch history.
losses, form_losses, mass_losses = [], [], []
times = []
transition_history = []

# Best-so-far tracking for early stopping.
best_loss = float('inf')
best_params = state.params
epochs_without_improvement = 0

# Working copies of the pool; the loop writes rollout results back into these.
current_pool = transition_pool.copy()
current_pool_ids = pool_formation_ids.copy()

print("Compiling JIT (first epoch will be slow)...")

# Main Phase 2 loop. Each epoch: sample a batch of pool states, draw ONE target
# formation (and rotation) shared by the whole batch, take a gradient step, and
# write the rolled-out states back into the pool as future starting points.
for epoch in range(PHASE2_EPOCHS):
    start_time = time.time()
    # NOTE(review): subkey3 is split off but never used below.
    key, subkey1, subkey2, subkey3, subkey4, subkey5 = jax.random.split(key, 6)
    
    # Sample batch from pool
    batch_indices = jax.random.permutation(subkey1, config.pool_size)[:config.batch_size]
    batch = current_pool[batch_indices]
    batch_source_formations = current_pool_ids[batch_indices]
    
    # Compute initial mass for mass conservation
    initial_mass = jnp.sum(batch[..., 3], axis=(1, 2))  # (B,)
    
    # Pick a TARGET formation
    # The index is drawn uniformly over all formations, so it can coincide with
    # the current formation of some (or all) samples in the batch.
    target_idx = int(jax.random.randint(subkey2, (), 0, num_formations))
    target = targets[formation_names[target_idx]]
    
    # Rotation augmentation for target
    if USE_ROTATION_AUGMENTATION:
        if ROTATION_CONTINUOUS:
            angle = float(jax.random.uniform(subkey5, (), minval=0, maxval=2 * jnp.pi))
        else:
            # Discrete mode: one of 0, 90, 180 or 270 degrees.
            angle = float(jax.random.randint(subkey5, (), 0, 4)) * (jnp.pi / 2)
        target = rotate_formation(target, angle)
    else:
        angle = 0.0
    
    # Goal signal for target formation
    formation_signal = create_formation_signal(
        config.batch_size, GRID_SIZE, GRID_SIZE, target_idx, angle, num_formations
    )
    
    batch_mass_before = float(jnp.mean(initial_mass))
    
    # Visualization
    # Drawn before the update, so it shows the batch as sampled from the pool.
    if epoch % VIZ_INTERVAL == 0 and epoch > 0:
        fig, axes = plt.subplots(1, 4, figsize=(16, 4))
        
        start_alpha = np.array(batch[0, ..., 3])
        source_name = formation_names[int(batch_source_formations[0])]
        axes[0].imshow(start_alpha, cmap='gray', vmin=0, vmax=1)
        axes[0].set_title(f'Starting: {source_name}\n(mass: {np.sum(start_alpha):.0f})')
        axes[0].axis('off')
        
        axes[1].imshow(np.array(target[..., 3]), cmap='gray', vmin=0, vmax=1)
        axes[1].set_title(f'Target: {formation_names[target_idx]}\n(angle: {np.degrees(angle):.0f})')
        axes[1].axis('off')
        
        axes[2].text(0.5, 0.5, f"{source_name}\n->\n{formation_names[target_idx]}", 
                     ha='center', va='center', fontsize=16)
        axes[2].set_title('Transition')
        axes[2].axis('off')
        
        # Histogram of the alpha mass of (up to 32) samples of the current batch;
        # the panel is titled 'Pool Health' but only covers this batch.
        pool_masses = [float(jnp.sum(batch[i, ..., 3])) for i in range(min(32, config.batch_size))]
        axes[3].hist(pool_masses, bins=20, color='steelblue', edgecolor='black')
        axes[3].set_xlabel('Mass')
        axes[3].set_title('Pool Health')
        
        plt.suptitle(f'Epoch {epoch} | Best: {best_loss:.6f}')
        plt.tight_layout()
        plt.show()
    
    # Training step - pass initial_mass for conservation loss
    state, loss, outputs, form_loss, mass_loss = train_step(
        state, batch, target, subkey4, formation_signal, initial_mass
    )
    
    # Converting to Python floats pulls the values off the device.
    loss_val = float(loss)
    form_loss_val = float(form_loss)
    mass_loss_val = float(mass_loss)
    
    # Ratio of mean alpha mass after vs. before the rollout (1.0 = conserved).
    batch_mass_after = float(jnp.mean(jnp.sum(outputs[..., 3], axis=(1, 2))))
    mass_retention = batch_mass_after / (batch_mass_before + 1e-6)
    
    # Divergence guard: roll back to the best parameters seen so far and stop.
    if np.isnan(loss_val) or np.isinf(loss_val):
        print(f"\nEarly stopping at epoch {epoch}: loss is {loss_val}")
        state = state.replace(params=best_params)
        break
    
    # Best-loss tracking is per batch, and every batch has its own random target.
    if loss_val < best_loss - MIN_DELTA:
        best_loss = loss_val
        best_params = state.params
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
    
    # best_params is only restored on the two early-stopping paths; when the
    # loop runs to completion, `state` keeps the parameters of the last epoch.
    if epochs_without_improvement >= PATIENCE:
        print(f"\nEarly stopping at epoch {epoch}: no improvement for {PATIENCE} epochs")
        state = state.replace(params=best_params)
        break
    
    # Update pool
    # The rolled-out states replace the sampled entries and are relabelled with
    # the target formation they were trained towards.
    current_pool = current_pool.at[batch_indices].set(outputs)
    current_pool_ids = current_pool_ids.at[batch_indices].set(target_idx)
    
    elapsed = time.time() - start_time
    losses.append(loss_val)
    form_losses.append(form_loss_val)
    mass_losses.append(mass_loss_val)
    times.append(elapsed)
    
    # Only the FIRST sample's source formation is recorded, although the batch
    # mixes source formations; the history is a per-epoch sample, not a full count.
    source_formation = int(batch_source_formations[0])
    transition_history.append((source_formation, target_idx))
    
    if epoch % config.log_interval == 0:
        src = formation_names[source_formation]
        tgt = formation_names[target_idx]
        print(f"Epoch {epoch:4d}: loss={loss_val:.6f} (form:{form_loss_val:.6f}, mass:{mass_loss_val:.4f}), " +
              f"retention={mass_retention:.1%}, {src}->{tgt}, {elapsed:.2f}s")

print(f"\nPhase 2 complete. Best loss: {best_loss:.6f}")
# Skip the first epoch in the average because it includes JIT compilation.
if len(times) > 1:
    print(f"Average time per epoch (after JIT): {np.mean(times[1:]):.3f}s")

## Training Curves

In [None]:
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Moving-average window (in epochs) for the smoothed overlays.
window = 100
averaging_kernel = np.ones(window)/window

# (axis, series, raw-curve kwargs, smoothed-curve format, y label, title, log scale?)
curve_panels = [
    (axes[0, 0], losses, {}, 'r-', 'Total Loss', 'Total Loss', True),
    (axes[0, 1], form_losses, {'color': 'blue'}, 'b-', 'Formation Loss', 'Formation Loss', True),
    (axes[1, 0], mass_losses, {'color': 'green'}, 'g-',
     'Mass Conservation Error', 'Mass Conservation Loss', False),
]

for panel, series, raw_kwargs, smooth_fmt, y_label, panel_title, use_log in curve_panels:
    panel.plot(series, alpha=0.3, **raw_kwargs)
    # The smoothed overlay needs more than one full window of data.
    if len(series) > window:
        smoothed = np.convolve(series, averaging_kernel, mode='valid')
        panel.plot(range(window-1, len(series)), smoothed, smooth_fmt, linewidth=2)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(y_label)
    panel.set_title(panel_title)
    if use_log:
        panel.set_yscale('log')

# Transition matrix: how often each recorded (source, target) pair occurred.
transition_counts = np.zeros((num_formations, num_formations))
if transition_history:
    source_ids, target_ids = zip(*transition_history)
    np.add.at(transition_counts, (list(source_ids), list(target_ids)), 1)

matrix_panel = axes[1, 1]
im = matrix_panel.imshow(transition_counts, cmap='Blues')
formation_ticks = range(num_formations)
matrix_panel.set_xticks(formation_ticks)
matrix_panel.set_yticks(formation_ticks)
matrix_panel.set_xticklabels(formation_names, rotation=45, ha='right')
matrix_panel.set_yticklabels(formation_names)
matrix_panel.set_xlabel('Target')
matrix_panel.set_ylabel('Source')
matrix_panel.set_title('Transition Counts')
plt.colorbar(im, ax=matrix_panel)

plt.tight_layout()
plt.show()

## Save Checkpoint

In [None]:
# Static description of the model and task, stored alongside the weights.
run_config = {
    'grid_size': GRID_SIZE,
    'num_channels': NUM_CHANNELS,
    'hidden_dim': 128,
    'goal_conditioned': True,
    'phase': 2,
    'task': 'transitions',
    'num_formations': num_formations,
    'formation_names': formation_names,
}

# Training history collected by the Phase 2 loop.
run_metrics = {
    'losses': losses,
    'form_losses': form_losses,
    'mass_losses': mass_losses,
    'best_loss': best_loss,
    'transition_history': transition_history,
}

checkpoint = {'params': state.params, 'config': run_config, 'metrics': run_metrics}

with PHASE2_CHECKPOINT.open('wb') as f:
    pickle.dump(checkpoint, f)

print(f"Checkpoint saved to {PHASE2_CHECKPOINT}")
print(f"Best loss: {best_loss:.6f}")

---

# Evaluation: Formation Transitions

Test that the model can transition between any pair of formations.

In [None]:
def run_transition(start_state, params, key, target_formation_idx, target_angle, num_steps=150):
    """Roll the NCA forward from `start_state` towards a goal formation.

    Before every step the first two goal-signal channels are written into the
    state's parent-signal channels. Returns a list of `num_steps + 1` states:
    `start_state` followed by the state after each step.
    """
    # Un-batched goal signal for a single grid.
    goal = create_formation_signal(
        1, GRID_SIZE, GRID_SIZE, target_formation_idx, target_angle, num_formations
    )[0]
    goal_channels = goal[..., :2]
    signal_slots = slice(CHILD_CHANNELS.PARENT_SIGNAL_START, CHILD_CHANNELS.PARENT_SIGNAL_END)

    trajectory = [start_state]
    for _ in range(num_steps):
        key, step_key = jax.random.split(key)
        conditioned = trajectory[-1].at[..., signal_slots].set(goal_channels)
        next_state = child_nca.apply(
            {'params': params}, conditioned, step_key, parent_signal=goal
        )
        trajectory.append(next_state)

    return trajectory

print("Transition function ready.")

In [None]:
# Test all pairwise transitions (5x5 = 25 pairs)
print("Testing all formation transitions...")

# Rows index the source formation, columns the target formation.
transition_mses = np.zeros((num_formations, num_formations))
transition_retentions = np.zeros_like(transition_mses)

key, eval_key = jax.random.split(key)

for src_idx, src_name in enumerate(formation_names):
    # One unrotated starting state per source, reused for every target.
    eval_key, state_key = jax.random.split(eval_key)
    start_state = create_formation_state(targets[src_name], state_key)
    source_mass = float(jnp.sum(start_state[..., 3]))

    for tgt_idx, tgt_name in enumerate(formation_names):
        eval_key, run_key = jax.random.split(eval_key)

        # Only the final state of the 150-step rollout (target angle 0) is scored.
        final = run_transition(start_state, state.params, run_key, tgt_idx, 0.0, num_steps=150)[-1]

        # RGBA mean squared error against the unrotated target.
        rgba_error = np.array(targets[tgt_name][..., :4]) - np.array(final[..., :4])
        transition_mses[src_idx, tgt_idx] = float(np.mean(rgba_error**2))

        # Total alpha after the transition relative to the starting total.
        transition_retentions[src_idx, tgt_idx] = float(jnp.sum(final[..., 3])) / source_mass

print("Done!")

In [None]:
# Visualize transition matrices
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# (axis, matrix, colormap, colour limits, title, cell annotation format)
matrix_panels = [
    (axes[0], transition_mses, 'Reds', {'vmin': 0},
     'Transition MSE (lower is better)', '{:.3f}'),
    (axes[1], transition_retentions, 'Greens', {'vmin': 0.5, 'vmax': 1.5},
     'Mass Retention (1.0 = perfect)', '{:.2f}'),
]

formation_ticks = range(num_formations)
for panel, matrix, colormap, colour_limits, panel_title, cell_format in matrix_panels:
    image = panel.imshow(matrix, cmap=colormap, **colour_limits)
    panel.set_xticks(formation_ticks)
    panel.set_yticks(formation_ticks)
    panel.set_xticklabels(formation_names, rotation=45, ha='right')
    panel.set_yticklabels(formation_names)
    panel.set_xlabel('Target Formation')
    panel.set_ylabel('Source Formation')
    panel.set_title(panel_title)
    plt.colorbar(image, ax=panel)

    # Write each cell's value on top of it (x = column, y = row).
    for (i, j), cell_value in np.ndenumerate(matrix):
        panel.text(j, i, cell_format.format(cell_value), ha='center', va='center', fontsize=8)

plt.tight_layout()
plt.show()

In [None]:
# Visualize specific transitions
test_transitions = [
    ('line', 'wedge'),
    ('wedge', 'square'),
    ('square', 'column'),
    ('column', 'phalanx'),
]

# Rollout steps shown as columns; the rollout runs up to the last one.
timesteps = [0, 30, 60, 90, 120, 150]
rollout_steps = max(timesteps)

# squeeze=False keeps `axes` two-dimensional even for a single transition.
fig, axes = plt.subplots(
    len(test_transitions), len(timesteps),
    figsize=(18, 3*len(test_transitions)), squeeze=False
)

for row, (src_name, tgt_name) in enumerate(test_transitions):
    tgt_idx = formation_names.index(tgt_name)

    eval_key, state_key, run_key = jax.random.split(eval_key, 3)
    start_state = create_formation_state(targets[src_name], state_key)
    trajectory = run_transition(
        start_state, state.params, run_key, tgt_idx, 0.0, num_steps=rollout_steps
    )

    for col, t in enumerate(timesteps):
        panel = axes[row, col]
        if t < len(trajectory):
            panel.imshow(trajectory[t][..., 3], cmap='gray', vmin=0, vmax=1)
        # axis('off') hides axis labels, so a y-label would never be drawn; the
        # row's transition name goes into the first panel's title instead.
        panel_title = f't={t}'
        if col == 0:
            panel_title = f'{src_name} -> {tgt_name}\n{panel_title}'
        panel.set_title(panel_title)
        panel.axis('off')

plt.suptitle('Formation Transitions Over Time', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Test rotation during transition
print("Testing line -> wedge with different target rotations...")

angles_test = [0, np.pi/4, np.pi/2, np.pi]
angle_labels = ['0', '45', '90', '180']
wedge_idx = formation_names.index('wedge')

# Top row: rotated targets. Bottom row: what the model produced.
fig, axes = plt.subplots(2, len(angles_test), figsize=(3*len(angles_test), 6))

# Every angle starts from the same unrotated line formation.
eval_key, state_key = jax.random.split(eval_key)
start_state = create_formation_state(targets['line'], state_key)

for col, angle in enumerate(angles_test):
    label = angle_labels[col]
    target_panel, result_panel = axes[0, col], axes[1, col]
    eval_key, run_key = jax.random.split(eval_key)

    # Target wedge at this angle, and the rollout conditioned on the same angle.
    angle_rad = float(angle)
    rotated_target = rotate_formation(targets['wedge'], angle_rad)
    final = run_transition(start_state, state.params, run_key, wedge_idx, angle_rad, num_steps=150)[-1]

    target_panel.imshow(rotated_target[..., 3], cmap='gray', vmin=0, vmax=1)
    target_panel.set_title(f'Target @ {label}')
    target_panel.axis('off')

    # Here the MSE is computed on the alpha channel only.
    alpha_error = np.array(rotated_target[..., 3]) - np.array(final[..., 3])
    mse = np.mean(alpha_error**2)
    result_panel.imshow(final[..., 3], cmap='gray', vmin=0, vmax=1)
    result_panel.set_title(f'Result (MSE: {mse:.4f})')
    result_panel.axis('off')

plt.suptitle('Line -> Wedge with Rotation', fontsize=14)
plt.tight_layout()
plt.show()

## Success Criteria Check

In [None]:
rule = "=" * 60
print(rule)
print("PHASE 2 SUCCESS CRITERIA: Formation Transitions")
print(rule)

# Aggregate statistics over the full source x target matrices.
avg_mse = np.mean(transition_mses)
max_mse = np.max(transition_mses)
avg_retention = np.mean(transition_retentions)
min_retention = np.min(transition_retentions)

# Off-diagonal entries are the real transitions (source != target).
off_diag_mask = ~np.eye(num_formations, dtype=bool)
avg_transition_mse = np.mean(transition_mses[off_diag_mask])
avg_transition_retention = np.mean(transition_retentions[off_diag_mask])

print("\nOverall Metrics:")
print(f"  Average MSE (all pairs): {avg_mse:.6f}")
print(f"  Average MSE (transitions only): {avg_transition_mse:.6f}")
print(f"  Worst MSE: {max_mse:.6f}")
print(f"  Average retention: {avg_retention:.1%}")
print(f"  Worst retention: {min_retention:.1%}")

print("\nPer-Transition Analysis:")
for (src_idx, tgt_idx), mse in np.ndenumerate(transition_mses):
    if src_idx == tgt_idx:
        continue
    src_name, tgt_name = formation_names[src_idx], formation_names[tgt_idx]
    ret = transition_retentions[src_idx, tgt_idx]
    if mse < 0.02 and ret > 0.85:
        status = "OK"
    else:
        status = "NEEDS WORK"
    print(f"  {src_name:8s} -> {tgt_name:8s}: MSE={mse:.4f}, Retention={ret:.1%} [{status}]")

# Success criteria: (label, measured value, threshold, passed?).
# Retention is only bounded from below here; values above 1.0 also pass.
print("\n" + rule)
criteria = [
    ("Avg transition MSE < 0.02", avg_transition_mse, 0.02, avg_transition_mse < 0.02),
    ("Worst MSE < 0.05", max_mse, 0.05, max_mse < 0.05),
    ("Avg retention > 85%", avg_transition_retention, 0.85, avg_transition_retention > 0.85),
    ("Worst retention > 70%", min_retention, 0.70, min_retention > 0.70),
]

for name, value, threshold, passed in criteria:
    verdict = "PASS" if passed else "FAIL"
    print(f"{name}: {value:.4f} (threshold: {threshold}) [{verdict}]")
all_passed = all(passed for *_, passed in criteria)

print(rule)
if not all_passed:
    print("SOME CRITERIA FAILED - Consider more training")
    print("\nTry: Set RESET=False and run training again")
else:
    print("ALL CRITERIA PASSED - Ready for Phase 3!")
    print("\nNext: Implement combat dynamics with adversarial armies")