In [None]:
# Setup: Pull latest code and install package with CUDA support
# Run this cell first. If imports fail, restart kernel and run again.

import subprocess
import sys
import os

REPO_URL = "https://github.com/JackHopkins/FormationHNCA.git"

# Detect environment and set repo path
if os.path.exists("/content"):  # Google Colab
    REPO_DIR = "/content/FormationHNCA"
elif os.path.exists("/workspace"):  # Lambda Labs / similar
    REPO_DIR = "/workspace/FormationHNCA"
else:
    REPO_DIR = os.path.expanduser("~/FormationHNCA")

# Clone or pull latest
if os.path.exists(REPO_DIR):
    print(f"Pulling latest changes in {REPO_DIR}...")
    result = subprocess.run(["git", "-C", REPO_DIR, "pull"], capture_output=True, text=True)
    if result.returncode == 0:
        print(result.stdout or "Already up to date.")
    else:
        # A failed pull (no network, local modifications, directory is not a
        # git checkout, ...) leaves stdout empty, so it must not be reported
        # as "Already up to date.". The pull stays best-effort: warn and carry
        # on with whatever is already on disk.
        print(f"WARNING: git pull failed (exit code {result.returncode}); "
              "continuing with the existing checkout.")
        pull_output = (result.stderr or result.stdout).strip()
        if pull_output:
            print(pull_output)
else:
    print(f"Cloning repository to {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

# Later cells use paths relative to the repository root (e.g. 'src', 'checkpoints/').
os.chdir(REPO_DIR)

# Install JAX with CUDA support (for H100/GPU)
print("Installing JAX with CUDA support...")
subprocess.run([
    sys.executable, "-m", "pip", "install", "-q",
    "jax[cuda12]"
], check=True)

# Install package (editable, from the repository root we just chdir'd into)
print("Installing battle-nca package...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-e", "."], check=True)

# Add src to path as fallback
src_path = os.path.join(REPO_DIR, "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"\nWorking directory: {os.getcwd()}")
print("Setup complete!")

# Battle NCA Evaluation

This notebook provides a comprehensive evaluation of trained Battle NCA models, including:
- Formation quality metrics (MSE, IoU, coverage, precision, F1)
- Stability analysis (temporal variance, convergence ratio)
- Combat effectiveness metrics (army health, morale, fatigue, routing units)
- Regeneration capability testing (recovery after damage)

In [None]:
# Core imports
import sys
sys.path.insert(0, 'src')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import pickle
from pathlib import Path

# Battle NCA imports
from battle_nca.hierarchy import ChildNCA
from battle_nca.hierarchy.child_nca import create_army_seed, CHILD_CHANNELS
from battle_nca.combat import FormationTargets
from battle_nca.combat.formations import FormationTypes, measure_formation_quality
from battle_nca.utils.metrics import (
    compute_army_statistics,
    compute_formation_metrics,
    compute_trajectory_metrics,
    compute_stability_metrics,
    print_battle_summary
)
from battle_nca.utils.visualization import (
    render_state,
    visualize_channels,
    plot_training_curves
)

print(f"JAX devices: {jax.devices()}")

## Load Trained Model

In [None]:
# Load the trained checkpoint, or fall back to freshly initialised weights so
# the rest of the notebook can still be run end to end.
checkpoint_path = Path('checkpoints/battle_nca_trained.pkl')

if not checkpoint_path.exists():
    print("No checkpoint found. Please run 01_training.ipynb first.")
    print("Creating untrained model for demonstration...")

    # Fallback configuration used only when no checkpoint is available.
    model_config = dict(grid_size=64, num_channels=24, hidden_dim=128)

    # Initialise an untrained model to obtain a parameter tree.
    child_nca = ChildNCA(
        num_channels=model_config['num_channels'],
        hidden_dim=model_config['hidden_dim'],
    )
    seed = create_army_seed(model_config['grid_size'], model_config['grid_size'])
    variables = child_nca.init(jax.random.PRNGKey(0), seed, jax.random.PRNGKey(1))

    params = variables['params']
    metrics = None  # no training history without a checkpoint
else:
    # NOTE(review): pickle.load can execute arbitrary code; only open
    # checkpoint files that come from a trusted source.
    with checkpoint_path.open('rb') as f:
        checkpoint = pickle.load(f)

    params = checkpoint['params']
    model_config = checkpoint['config']
    metrics = checkpoint['metrics']

    print("Loaded trained model")
    print(f"  Grid size: {model_config['grid_size']}")
    print(f"  Channels: {model_config['num_channels']}")

In [None]:
# Rebuild the model from the (loaded or fallback) configuration.
GRID_SIZE = model_config['grid_size']
NUM_CHANNELS = model_config['num_channels']

child_nca = ChildNCA(num_channels=NUM_CHANNELS, hidden_dim=model_config['hidden_dim'], fire_rate=0.5)

# Seed state: red team, spawn-region bounds offset by 2 either side of the grid centre.
grid_center = GRID_SIZE // 2
seed = create_army_seed(
    GRID_SIZE,
    GRID_SIZE,
    team_color=(1.0, 0.0, 0.0),
    spawn_region=(grid_center - 2, grid_center + 2, grid_center - 2, grid_center + 2),
)

# Target formations evaluated below, keyed by name.
targets = dict(
    line=FormationTargets.line(GRID_SIZE, GRID_SIZE),
    phalanx=FormationTargets.phalanx(GRID_SIZE, GRID_SIZE, depth=8),
    square=FormationTargets.square(GRID_SIZE, GRID_SIZE),
)

## Formation Quality Evaluation

In [None]:
# Run model forward
key = jax.random.PRNGKey(42)

def run_model(state, params, key, num_steps=100):
    """Roll the NCA forward and record every state along the way.

    Args:
        state: Initial state; it becomes the first entry of the result.
        params: Model parameters passed to ``child_nca.apply`` under 'params'.
        key: JAX PRNG key; split once per step to get that step's key.
        num_steps: Number of update steps to apply.

    Returns:
        The initial state followed by the state after each step, stacked
        along a new leading axis (``num_steps + 1`` entries).
    """
    variables = {'params': params}
    states = [state]

    for _ in range(num_steps):
        # Sequential splitting keeps the random stream reproducible for a given key.
        key, step_key = jax.random.split(key)
        states.append(child_nca.apply(variables, states[-1], step_key))

    return jnp.stack(states)

# Run forward
trajectory = run_model(seed, params, key, num_steps=100)
print(f"Trajectory shape: {trajectory.shape}")

In [None]:
# Track formation metrics against the line target across the whole rollout.
target = targets['line']
traj_metrics = compute_trajectory_metrics(trajectory, target)

fig, axes = plt.subplots(2, 3, figsize=(15, 8))

# The first five panels (row-major order) are per-step curves.
curve_panels = [
    ('mse', 'MSE'),
    ('iou', 'IoU'),
    ('alive_cells', 'Alive Cells'),
    ('mean_health', 'Mean Health'),
    ('mean_morale', 'Mean Morale'),
]
for ax, (metric_key, panel_title) in zip(axes.flat, curve_panels):
    ax.plot(traj_metrics[metric_key])
    ax.set_title(panel_title)
    ax.set_xlabel('Step')

# The last panel shows channel 3 (alpha) of the final state instead of a curve.
final_alpha_ax = axes[1, 2]
final_alpha_ax.imshow(trajectory[-1, ..., 3], cmap='gray')
final_alpha_ax.set_title('Final Alpha')
final_alpha_ax.axis('off')

plt.tight_layout()
plt.suptitle('Formation Metrics Over Time', y=1.02)
plt.show()

In [None]:
# Score the final rollout state against every target formation.
final_state = trajectory[-1]

print("Formation Quality Metrics (Final State)")
print("=" * 50)

# Loop-local names are deliberately distinct from the notebook-level `metrics`
# (training metrics loaded from the checkpoint) and `target` (the line target
# used for the trajectory plots) so this cell does not overwrite either.
for name, formation_target in targets.items():
    formation_metrics = compute_formation_metrics(final_state, formation_target)
    print(f"\n{name.upper()}:")
    print(f"  MSE: {formation_metrics['mse']:.4f}")
    print(f"  IoU: {formation_metrics['iou']:.4f}")
    print(f"  Coverage: {formation_metrics['coverage']:.4f}")
    print(f"  Precision: {formation_metrics['precision']:.4f}")
    print(f"  F1: {formation_metrics['f1']:.4f}")

## Stability Analysis

In [None]:
# Longer rollout (200 steps) from the same seed to check the pattern holds.
key = jax.random.PRNGKey(123)
long_trajectory = run_model(seed, params, key, num_steps=200)

# Stability against the line target, computed with a 30-step window.
stability = compute_stability_metrics(long_trajectory, targets['line'], window=30)

print("Stability Metrics")
print("=" * 50)

# (label, key in `stability`, float format) for each reported value.
stability_report = [
    ("Temporal variance", 'temporal_variance', '.6f'),
    ("Mean final loss", 'mean_final_loss', '.6f'),
    ("Loss variance", 'loss_variance', '.6f'),
    ("Convergence ratio", 'convergence_ratio', '.4f'),
]
for label, metric_key, spec in stability_report:
    print(f"{label}: {stability[metric_key]:{spec}}")
print("  (< 1 = converging, > 1 = diverging)")

In [None]:
# Per-cell mean and spread of channel 3 (alpha) over the last 30 steps.
final_30 = long_trajectory[-30:]
alpha_window = final_30[..., 3]
alpha_mean = np.mean(alpha_window, axis=0)
alpha_std = np.std(alpha_window, axis=0)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
mean_ax, std_ax, target_ax = axes

mean_ax.imshow(alpha_mean, cmap='gray', vmin=0, vmax=1)
mean_ax.set_title('Alpha Mean (last 30 steps)')
mean_ax.axis('off')

# Keep the image handle so the colorbar can be attached to it.
std_image = std_ax.imshow(alpha_std, cmap='hot')
std_ax.set_title('Alpha Std (temporal variance)')
std_ax.axis('off')
plt.colorbar(std_image, ax=std_ax)

target_ax.imshow(targets['line'][..., 3], cmap='gray', vmin=0, vmax=1)
target_ax.set_title('Target')
target_ax.axis('off')

plt.tight_layout()
plt.show()

## Regeneration Testing

Test the model's ability to recover from damage.

In [None]:
# Grow the formation for 80 steps and keep only the last state.
key = jax.random.PRNGKey(456)
grown_state = run_model(seed, params, key, num_steps=80)[-1]

# Boolean disc of radius 10 centred on the grid.
cy = cx = GRID_SIZE // 2
radius = 10
y, x = np.ogrid[:GRID_SIZE, :GRID_SIZE]
mask = (x - cx)**2 + (y - cy)**2 <= radius**2

# Zero every channel inside the disc on a writable NumPy copy, then convert
# the result back to a JAX array.
damaged_state = np.array(grown_state, copy=True)
damaged_state[mask] = 0.0
damaged_state = jnp.array(damaged_state)

# Before / after / mask, side by side.
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
damage_panels = [
    (grown_state[..., :4], 'Before Damage', {}),
    (damaged_state[..., :4], 'After Damage', {}),
    (mask, 'Damage Mask', {'cmap': 'Reds', 'alpha': 0.5}),
]
for ax, (image, panel_title, imshow_kwargs) in zip(axes, damage_panels):
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(panel_title)
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Run recovery from the damaged state with a fresh subkey.
key, subkey = jax.random.split(key)
recovery_trajectory = run_model(damaged_state, params, subkey, num_steps=100)

# Compute metrics during recovery (against the line target)
recovery_metrics = compute_trajectory_metrics(recovery_trajectory, targets['line'])

# Plot recovery: one column per snapshot, RGBA on the top row, alpha below.
fig, axes = plt.subplots(2, 5, figsize=(15, 6))

# Indices into the trajectory; index 0 is the damaged state itself.
steps_to_show = [0, 10, 25, 50, 100]

for i, step in enumerate(steps_to_show):
    # RGBA (first four channels)
    axes[0, i].imshow(recovery_trajectory[step, ..., :4])
    axes[0, i].set_title(f'Step {step}')

    # Alpha (channel 3)
    axes[1, i].imshow(recovery_trajectory[step, ..., 3], cmap='gray', vmin=0, vmax=1)

# axis('off') would also hide the y-labels set below, so the row labels would
# never be drawn. Remove only the ticks and the frame instead.
for ax in axes.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

axes[0, 0].set_ylabel('RGBA')
axes[1, 0].set_ylabel('Alpha')

plt.suptitle('Regeneration After Damage')
plt.tight_layout()
plt.show()

In [None]:
# Recovery curves, each with a dashed line at its value right after damage.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (key in `recovery_metrics`, panel title, y-axis label)
recovery_panels = [
    ('mse', 'MSE During Recovery', 'MSE'),
    ('alive_cells', 'Alive Cells During Recovery', 'Alive Cells'),
]
for ax, (metric_key, panel_title, y_label) in zip(axes, recovery_panels):
    series = recovery_metrics[metric_key]
    ax.plot(series)
    ax.set_title(panel_title)
    ax.set_xlabel('Step')
    ax.set_ylabel(y_label)
    # First entry is the metric of the damaged state before any recovery step.
    ax.axhline(y=series[0], color='r', linestyle='--', label='Initial (damaged)')
    ax.legend()

plt.tight_layout()
plt.show()

# Recovery statistics: percentage reduction in MSE between the damaged state
# (first entry) and the end of the recovery rollout (last entry).
# float() gives plain Python scalars whatever container the metrics come in.
initial_mse = float(recovery_metrics['mse'][0])
final_mse = float(recovery_metrics['mse'][-1])

if initial_mse > 0:
    recovery_rate = (initial_mse - final_mse) / initial_mse * 100
else:
    # A zero baseline makes the relative improvement undefined; report NaN
    # instead of dividing by zero.
    recovery_rate = float('nan')

print(f"\nRecovery Statistics:")
print(f"  Initial MSE: {initial_mse:.4f}")
print(f"  Final MSE: {final_mse:.4f}")
print(f"  Recovery: {recovery_rate:.1f}%")

## Channel Analysis

In [None]:
# Show the first 15 channels of the last state from the long stability rollout.
final_long_state = long_trajectory[-1]
visualize_channels(final_long_state, channels=list(range(15)), show=True)

In [None]:
# Summarise the last state of the long rollout as army-level statistics.
stats = compute_army_statistics(long_trajectory[-1])

print("Army Statistics (Final State)")
print("=" * 40)

# (label, attribute on `stats`, format spec); an empty spec prints the value as-is.
army_report = [
    ("Total units", 'total_units', ''),
    ("Alive units", 'alive_units', ''),
    ("Average health", 'average_health', '.2%'),
    ("Average morale", 'average_morale', '+.2f'),
    ("Average fatigue", 'average_fatigue', '.2%'),
    ("Routing units", 'routing_units', ''),
    ("Unit density", 'unit_density', '.4f'),
]
for label, attr_name, spec in army_report:
    print(f"{label}: {getattr(stats, attr_name):{spec}}")

## Summary

This notebook evaluated:
1. **Formation quality** - How well the NCA matches target formations
2. **Stability** - How consistent the pattern remains over time
3. **Regeneration** - How well the NCA recovers from damage
4. **Channel analysis** - Understanding internal state representations

**Key metrics to monitor:**
- IoU > 0.5 indicates good formation matching
- Temporal variance < 0.01 indicates stable attractors
- Convergence ratio < 0.5 indicates fast convergence

**Next steps:**
- See `03_visualization.ipynb` for animations and visual analysis