In [None]:
# Setup: Pull latest code and install package with CUDA support
# Run this cell first. If imports fail, restart kernel and run again.

import os
import shutil
import subprocess
import sys

REPO_URL = "https://github.com/JackHopkins/FormationHNCA.git"

# Detect environment and set repo path
if os.path.exists("/content"):  # Google Colab
    REPO_DIR = "/content/FormationHNCA"
elif os.path.exists("/workspace"):  # Lambda Labs / similar
    REPO_DIR = "/workspace/FormationHNCA"
else:
    REPO_DIR = os.path.expanduser("~/FormationHNCA")

# Clone or pull latest
if os.path.exists(REPO_DIR):
    print(f"Pulling latest changes in {REPO_DIR}...")
    result = subprocess.run(["git", "-C", REPO_DIR, "pull"], capture_output=True, text=True)
    if result.returncode == 0:
        print(result.stdout or "Already up to date.")
    else:
        # Best-effort: keep working with the existing checkout (e.g. offline, or local
        # edits blocking the merge), but surface git's error instead of hiding it.
        print(f"WARNING: git pull failed (exit code {result.returncode}); using existing checkout.")
        print(result.stderr)
else:
    print(f"Cloning repository to {REPO_DIR}...")
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

os.chdir(REPO_DIR)

# Install JAX. The CUDA 12 build is only useful (and only installable everywhere) when an
# NVIDIA driver is present; detect that via `nvidia-smi` and fall back to the CPU build.
if shutil.which("nvidia-smi"):
    print("Installing JAX with CUDA support...")
    jax_requirement = "jax[cuda12]"
else:
    print("nvidia-smi not found; installing CPU-only JAX...")
    jax_requirement = "jax"
subprocess.run([sys.executable, "-m", "pip", "install", "-q", jax_requirement], check=True)

# Install package
print("Installing battle-nca package...")
subprocess.run([sys.executable, "-m", "pip", "install", "-q", "-e", "."], check=True)

# Add src to path as fallback (an editable install normally makes this unnecessary)
src_path = os.path.join(REPO_DIR, "src")
if src_path not in sys.path:
    sys.path.insert(0, src_path)

print(f"\nWorking directory: {os.getcwd()}")
print("Setup complete!")

# Battle NCA Visualization

This notebook provides visualization tools for Battle NCA:
- Formation evolution animations
- Channel-wise (multi-channel) visualizations
- Two-army battle simulations
- Damage and regeneration animations
- Side-by-side comparisons of multiple stochastic runs

In [None]:
# Core imports
import sys
sys.path.insert(0, 'src')

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.colors import LinearSegmentedColormap
from IPython.display import HTML, display
import pickle
from pathlib import Path

# Battle NCA imports
from battle_nca.hierarchy import ChildNCA, HierarchicalNCA
from battle_nca.hierarchy.child_nca import create_army_seed, CHILD_CHANNELS
from battle_nca.hierarchy.hnca import create_battle_scenario, BattleSimulator
from battle_nca.combat import FormationTargets
from battle_nca.utils.visualization import (
    render_state,
    render_battle,
    create_animation,
    create_battle_animation,
    visualize_channels,
    plot_training_curves
)

# Render matplotlib animations as interactive JavaScript players when displayed inline,
# then report which accelerators JAX can see (useful to confirm the CUDA install worked).
plt.rcParams.update({'animation.html': 'jshtml'})
print(f"JAX devices: {jax.devices()}")

## Load Model

In [None]:
# Configuration
GRID_SIZE = 64
NUM_CHANNELS = 24
HIDDEN_DIM = 128

checkpoint_path = Path('checkpoints/battle_nca_trained.pkl')

# Build the model and the seed exactly once; all later cells reuse `child_nca` and `seed`.
child_nca = ChildNCA(num_channels=NUM_CHANNELS, hidden_dim=HIDDEN_DIM)

# Seed: a small red square of cells around the centre of the grid.
seed = create_army_seed(
    GRID_SIZE, GRID_SIZE,
    team_color=(1.0, 0.0, 0.0),
    spawn_region=(GRID_SIZE//2-2, GRID_SIZE//2+2, GRID_SIZE//2-2, GRID_SIZE//2+2)
)

if checkpoint_path.exists():
    # SECURITY: pickle.load can execute arbitrary code — only load checkpoints you trust.
    # NOTE(review): the checkpoint must have been trained with the same NUM_CHANNELS /
    # HIDDEN_DIM as above, otherwise `apply` will fail on shape mismatches.
    with open(checkpoint_path, 'rb') as f:
        checkpoint = pickle.load(f)
    params = checkpoint['params']
    print("Loaded trained model")
else:
    print("No checkpoint found, initializing fresh model")
    # Initialise from the same seed used for every rollout below.
    variables = child_nca.init(jax.random.PRNGKey(0), seed, jax.random.PRNGKey(1))
    params = variables['params']

## Formation Evolution Animation

In [None]:
# Run model and collect trajectory
key = jax.random.PRNGKey(42)
state = seed
trajectory = [state]

for i in range(100):
    key, subkey = jax.random.split(key)
    state = child_nca.apply({'params': params}, state, subkey)
    trajectory.append(state)

trajectory = jnp.stack(trajectory)
print(f"Trajectory shape: {trajectory.shape}")

In [None]:
# Create RGBA animation
anim = create_animation(
    trajectory,
    mode='rgba',
    fps=15,
    title='Formation Growth'
)

# Display in notebook
HTML(anim.to_jshtml())

In [None]:
# Animate only the alpha ("alive") channel to see where the formation occupies the grid.
anim_alpha = create_animation(trajectory, mode='alpha', fps=15, title='Alpha Channel Evolution')

HTML(anim_alpha.to_jshtml())

## Multi-Channel Visualization

In [None]:
# Create custom multi-channel animation
fig, axes = plt.subplots(2, 3, figsize=(12, 8))

channel_configs = [
    ('RGBA', slice(0, 4), 'rgba'),
    ('Alpha', 3, 'gray'),
    ('Health', 4, 'RdYlGn'),
    ('Morale', 5, 'RdYlGn'),
    ('Velocity X', 7, 'coolwarm'),
    ('Hidden Mean', slice(15, 24), 'viridis'),
]

ims = []
for ax, (name, ch, cmap) in zip(axes.flat, channel_configs):
    ax.set_title(name)
    ax.axis('off')
    
    if name == 'RGBA':
        data = np.clip(trajectory[0, ..., :4], 0, 1)
        rgb = data[..., :3]
        alpha = data[..., 3:4]
        display_data = rgb * alpha + (1 - alpha)
        im = ax.imshow(display_data)
    elif name == 'Hidden Mean':
        im = ax.imshow(trajectory[0, ..., ch].mean(axis=-1), cmap=cmap)
    else:
        im = ax.imshow(trajectory[0, ..., ch], cmap=cmap)
    ims.append(im)

plt.tight_layout()

def update(frame):
    for im, (name, ch, _) in zip(ims, channel_configs):
        if name == 'RGBA':
            data = np.clip(trajectory[frame, ..., :4], 0, 1)
            rgb = data[..., :3]
            alpha = data[..., 3:4]
            im.set_array(rgb * alpha + (1 - alpha))
        elif name == 'Hidden Mean':
            im.set_array(trajectory[frame, ..., ch].mean(axis=-1))
        else:
            im.set_array(trajectory[frame, ..., ch])
    return ims

multi_anim = animation.FuncAnimation(
    fig, update, frames=len(trajectory),
    interval=100, blit=True
)

HTML(multi_anim.to_jshtml())

## Two-Army Battle Visualization

In [None]:
# Create battle scenario
scenario = create_battle_scenario(grid_size=GRID_SIZE, cluster_size=4)

red_child = scenario['red_child']
blue_child = scenario['blue_child']

# Visualize initial positions
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

axes[0].imshow(np.clip(red_child[..., :4], 0, 1))
axes[0].set_title('Red Army Initial')
axes[0].axis('off')

axes[1].imshow(np.clip(blue_child[..., :4], 0, 1))
axes[1].set_title('Blue Army Initial')
axes[1].axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Initialize separate models for each army
red_nca = ChildNCA(num_channels=NUM_CHANNELS, hidden_dim=128)
blue_nca = ChildNCA(num_channels=NUM_CHANNELS, hidden_dim=128)

# Use same params for both (in practice, could be different)
red_params = params
blue_params = params

# Run battle simulation
key = jax.random.PRNGKey(789)
red_state = red_child
blue_state = blue_child

red_trajectory = [red_state]
blue_trajectory = [blue_state]

for step in range(100):
    key, k1, k2 = jax.random.split(key, 3)

    # Synchronous update: each army reacts to the opponent's state from the *previous*
    # step. Updating sequentially (blue seeing red's already-updated state) would give
    # red a systematic first-mover advantage.
    next_red = red_nca.apply(
        {'params': red_params},
        red_state,
        k1,
        enemy_state=blue_state
    )
    next_blue = blue_nca.apply(
        {'params': blue_params},
        blue_state,
        k2,
        enemy_state=red_state
    )
    red_state, blue_state = next_red, next_blue

    red_trajectory.append(red_state)
    blue_trajectory.append(blue_state)

red_trajectory = jnp.stack(red_trajectory)
blue_trajectory = jnp.stack(blue_trajectory)

print(f"Battle trajectories: {red_trajectory.shape}")

In [None]:
# Animate both armies' trajectories side by side using the project's helper.
battle_anim = create_battle_animation(red_trajectory, blue_trajectory, fps=15)

HTML(battle_anim.to_jshtml())

In [None]:
# Combined battlefield view: both armies overlaid in one square panel.
# The figure/axes created here are drawn into by the animation code that follows.
fig, ax = plt.subplots(figsize=(10, 10))

def combine_armies(red, blue):
    """Overlay two army states into a single RGB image on a white background.

    Only the first four channels (RGBA) of each state are used, clipped to [0, 1].
    Colours are blended with each army's alpha as the weight (a small epsilon avoids
    division by zero where neither army is present), and the pixel's overall opacity is
    the larger of the two alphas.

    Returns an array with the same leading shape as the inputs and 3 trailing channels.
    """
    red_rgba, blue_rgba = (np.clip(np.array(arr)[..., :4], 0, 1) for arr in (red, blue))

    red_alpha = red_rgba[..., 3:4]
    blue_alpha = blue_rgba[..., 3:4]

    weighted_sum = red_rgba[..., :3] * red_alpha + blue_rgba[..., :3] * blue_alpha
    blended_rgb = weighted_sum / (red_alpha + blue_alpha + 0.001)

    coverage = np.maximum(red_alpha, blue_alpha)

    # Composite over white: fully transparent pixels become 1.0.
    return blended_rgb * coverage + (1 - coverage)

im = ax.imshow(combine_armies(red_trajectory[0], blue_trajectory[0]))
ax.axis('off')
ax.set_title('Battlefield')


def update_battlefield(frame):
    """Redraw the overlaid battlefield and title for time step `frame`."""
    im.set_array(combine_armies(red_trajectory[frame], blue_trajectory[frame]))
    ax.set_title(f'Battlefield - Step {frame}')
    return [im]


# blit=False: the title changes every frame but lies outside the blitted axes region,
# so blitting would leave a stale title during live playback.
combined_anim = animation.FuncAnimation(
    fig, update_battlefield, frames=len(red_trajectory),
    interval=100, blit=False
)

combined_html = HTML(combined_anim.to_jshtml())
plt.close(fig)  # prevent the inline backend from also rendering a static copy of the figure
combined_html

## Damage and Regeneration Animation

In [None]:
# Phase 1: grow the formation from the seed for 80 steps.
key = jax.random.PRNGKey(111)
state = seed

for i in range(80):
    key, step_key = jax.random.split(key)
    state = child_nca.apply({'params': params}, state, step_key)

grown_state = state

# Phase 2: carve a disc of radius 12 out of the grid centre, zeroing every channel there.
cy = cx = GRID_SIZE // 2
radius = 12
y, x = np.ogrid[:GRID_SIZE, :GRID_SIZE]
mask = (y - cy)**2 + (x - cx)**2 <= radius**2

damaged = np.array(grown_state)  # writable host copy of the JAX array
damaged[mask] = 0.0
damaged_state = jnp.array(damaged)

# Phase 3: let the NCA run another 80 steps from the damaged state, recording each frame.
state = damaged_state
regen_trajectory = [state]

for i in range(80):
    key, step_key = jax.random.split(key)
    state = child_nca.apply({'params': params}, state, step_key)
    regen_trajectory.append(state)

regen_trajectory = jnp.stack(regen_trajectory)
print(f"Regeneration trajectory: {regen_trajectory.shape}")

In [None]:
# Create regeneration animation with comparison:
# [before damage] | [regeneration, animated] | [target alpha mask]
fig, axes = plt.subplots(1, 3, figsize=(15, 5))


def composite_on_white(frame_state):
    """Alpha-composite the first four (RGBA) channels of a state onto a white background."""
    rgba = np.clip(np.asarray(frame_state)[..., :4], 0, 1)
    rgb, alpha = rgba[..., :3], rgba[..., 3:4]
    return rgb * alpha + (1 - alpha)


# Before damage (same rendering as the animated panel so they are directly comparable)
axes[0].imshow(composite_on_white(grown_state))
axes[0].set_title('Before Damage')
axes[0].axis('off')

# Regeneration (animated); initial frame uses the same composite as every update frame
im_regen = axes[1].imshow(composite_on_white(regen_trajectory[0]))
axes[1].set_title('Regenerating...')
axes[1].axis('off')

# Target
# NOTE(review): assumes the loaded model was trained towards the `line` formation.
target = FormationTargets.line(GRID_SIZE, GRID_SIZE)
axes[2].imshow(target[..., 3], cmap='gray', vmin=0, vmax=1)
axes[2].set_title('Target')
axes[2].axis('off')


def update_regen(frame):
    """Redraw the regenerating panel and its title for time step `frame`."""
    im_regen.set_array(composite_on_white(regen_trajectory[frame]))
    axes[1].set_title(f'Regenerating... (Step {frame})')
    return [im_regen]


# blit=False: the per-frame title lies outside the blitted axes region.
regen_anim = animation.FuncAnimation(
    fig, update_regen, frames=len(regen_trajectory),
    interval=100, blit=False
)

regen_html = HTML(regen_anim.to_jshtml())
plt.close(fig)  # prevent the inline backend from also rendering a static copy of the figure
regen_html

## Formation Evolution Across Multiple Runs

In [None]:
# Show the formation at several growth steps for independent stochastic runs.
# Every run starts from the same seed; runs differ only in the random keys drawn per step.
# (Showing genuinely *different* formations would require training on multiple targets.)

STEPS_TO_SHOW = [0, 20, 40, 60, 80]  # must be sorted ascending
NUM_RUNS = 2

# squeeze=False keeps `axes` 2-D even when NUM_RUNS or len(STEPS_TO_SHOW) is 1.
fig, axes = plt.subplots(NUM_RUNS, len(STEPS_TO_SHOW), figsize=(15, 6), squeeze=False)

key = jax.random.PRNGKey(222)

for run in range(NUM_RUNS):
    state = seed
    current_step = 0

    for col, step in enumerate(STEPS_TO_SHOW):
        # Advance from the previously displayed step up to `step`.
        for _ in range(step - current_step):
            key, subkey = jax.random.split(key)
            state = child_nca.apply({'params': params}, state, subkey)
        current_step = step

        ax = axes[run, col]
        ax.imshow(np.clip(state[..., :4], 0, 1))
        ax.set_title(f'Step {step}')
        ax.axis('off')

fig.suptitle('Multiple Runs Showing Formation Evolution')
fig.tight_layout()
plt.show()

## Save Animations

In [None]:
# Save animations to files. Off by default: encoding is slow and MP4 output needs ffmpeg.
SAVE_ANIMATIONS = False
output_dir = Path('outputs')

if SAVE_ANIMATIONS:
    output_dir.mkdir(exist_ok=True)

    # Save as GIF
    gif_path = output_dir / 'formation_evolution.gif'
    anim.save(str(gif_path), writer='pillow', fps=15)
    print(f"Saved to {gif_path}")

    # Save as MP4 (requires ffmpeg on PATH)
    mp4_path = output_dir / 'battle_simulation.mp4'
    battle_anim.save(str(mp4_path), writer='ffmpeg', fps=15)
    print(f"Saved to {mp4_path}")
else:
    print("Set SAVE_ANIMATIONS = True to export animations")

## Summary

This notebook demonstrated:
1. **Formation animations** - Watching NCA patterns grow and stabilize
2. **Multi-channel views** - Understanding internal state representations
3. **Battle simulations** - Two armies interacting
4. **Regeneration** - Recovery after damage

**Tips for better visualizations:**
- Increase `GRID_SIZE` to 200 for larger battles
- Adjust `fps` parameter for animation speed
- Use `create_battle_scenario()` for custom starting positions