# Notebook 5: Two-Swarm Convergence with HoloViews

**Visualization Stack**: HoloViews + Panel with Bokeh backend (NO matplotlib!)

**Goal**: Visualize two independent swarms converging to the same quasi-stationary distribution (QSD), with:
- Framework-correct Lyapunov functions
- **Proper boundary handling** (walkers die and revive)
- Interactive visualizations using HoloViz stack

**Key Features**:
1. Bounded domain with walker death/resurrection
2. Track alive/dead walker counts
3. Visualize the effect of cloning (recovery of the alive-walker count)
4. Interactive plots with Bokeh backend

## Setup and Imports

In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.insert(0, '../src')

import torch
import numpy as np
import pandas as pd

# HoloViews stack (NO matplotlib!)
import holoviews as hv
from holoviews import opts
import panel as pn

# Enable Bokeh backend
hv.extension('bokeh')
pn.extension()

from tqdm.notebook import tqdm

# Import experiment code
from fragile.experiments import create_multimodal_potential
from fragile.geometric_gas import (
    GeometricGas,
    GeometricGasParams,
    LocalizationKernelParams,
    AdaptiveParams,
)
from fragile.euclidean_gas import LangevinParams
from fragile.bounds import TorchBounds

# Import framework-correct Lyapunov functions
from fragile.lyapunov import (
    compute_internal_variance_position,
    compute_internal_variance_velocity,
    compute_total_lyapunov,
)

# Seed both random number generators used in this notebook for reproducibility.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Confirm the environment: imports succeeded, then report library versions.
print("✓ Imports successful")
for _lib_name, _lib_module in (("HoloViews", hv), ("Panel", pn)):
    print(f"✓ {_lib_name} version: {_lib_module.__version__}")

## 1. Create Target Potential and Define Bounds

In [None]:
# Symmetric box used both to place the mixture and as the walker domain.
# Keeping it in one place prevents the potential range and the bounds from
# drifting apart.
domain_low, domain_high = -8.0, 8.0

# Create multimodal potential
potential, target_mixture = create_multimodal_potential(
    dims=2,
    n_gaussians=3,
    bounds_range=(domain_low, domain_high),
    seed=42,
)

# Extract mixture parameters
centers = target_mixture.centers
stds = target_mixture.stds
weights = target_mixture.weights
dims = target_mixture.dims

# **IMPORTANT**: walkers outside these bounds are treated as dead and must be
# revived by cloning. The bounds are sized by ``dims`` so they always match
# the dimensionality of the potential.
bounds = TorchBounds(
    low=torch.full((dims,), domain_low),
    high=torch.full((dims,), domain_high),
)

print("✓ Created multimodal potential")
print(f"  Centers: {centers.tolist()}")
print(f"  Weights: {weights.tolist()}")
print(f"\n✓ Bounds defined: [{bounds.low.tolist()}, {bounds.high.tolist()}]")
print("  Walkers leaving these bounds will DIE and be RESURRECTED!")

### Visualize Target QSD with Bounds

In [None]:
# Evaluation grid spanning the walker domain. It is taken from ``bounds`` so
# the plot always matches the region where walkers are alive. Only the first
# two coordinates are plotted.
x_lo, y_lo = bounds.low.tolist()[:2]
x_hi, y_hi = bounds.high.tolist()[:2]
n_grid = 200
x_range = np.linspace(x_lo, x_hi, n_grid)
y_range = np.linspace(y_lo, y_hi, n_grid)
X, Y = np.meshgrid(x_range, y_range)
grid_points = torch.tensor(np.stack([X.ravel(), Y.ravel()], axis=1), dtype=torch.float32)

# Evaluate potential on the grid
Z_potential = potential.evaluate(grid_points).detach().numpy().reshape(X.shape)

# Target QSD is proportional to exp(-beta * U). Subtract min(U) before
# exponentiating: the constant cancels in the normalisation, but it prevents
# overflow (very negative U) and total underflow to 0/0 = NaN (large U).
beta = 1.0
Z_qsd = np.exp(-beta * (Z_potential - Z_potential.min()))
Z_qsd = Z_qsd / Z_qsd.sum()

# Create HoloViews plot (Bokeh backend)
qsd_image = hv.Image(
    (x_range, y_range, Z_qsd),
    kdims=['x1', 'x2'],
    vdims='density'
).opts(
    cmap='plasma',
    colorbar=True,
    width=600,
    height=600,
    title='Target QSD with Boundaries',
    xlabel='x₁',
    ylabel='x₂',
    tools=['hover'],
)

# Mode markers. The mixture weight is carried as a value dimension, so the
# marker size is driven by a dim expression instead of a bare array.
mode_points = hv.Points(
    np.column_stack([
        centers[:, :2].detach().cpu().numpy(),
        weights.detach().cpu().numpy(),
    ]),
    vdims=['weight'],
    label='Modes'
).opts(
    size=hv.dim('weight') * 30,
    color='red',
    marker='star',
    line_color='white',
    line_width=2,
)

# Boundary box drawn from the same bounds the walkers are tested against
boundary_box = hv.Rectangles(
    [(x_lo, y_lo, x_hi, y_hi)],
    label='Valid Domain'
).opts(
    color=None,
    line_color='cyan',
    line_width=3,
    line_dash='dashed',
)

# Combine
target_plot = qsd_image * mode_points * boundary_box

target_plot

## 2. Initialize Two Swarms with Bounds

In [None]:
# Simulation size: walkers per swarm, and number of steps to run.
N, n_steps = 100, 3000

def measurement_fn(x):
    """Measurement passed to ``GeometricGas``: the negated potential at ``x``.

    Lower potential energy therefore maps to a larger measurement value.
    """
    energy = potential.evaluate(x)
    return -energy

# Walkers outside ``bounds`` are treated as dead (and revived via cloning),
# so the bounds are handed to the gas together with the dynamics settings.
langevin_cfg = LangevinParams(gamma=1.0, beta=1.0, delta_t=0.05)
localization_cfg = LocalizationKernelParams(rho=2.0, kernel_type="gaussian")
adaptive_cfg = AdaptiveParams(
    epsilon_F=0.05,
    nu=0.02,
    epsilon_Sigma=0.01,
    rescale_amplitude=1.0,
    sigma_var_min=0.1,
    viscous_length_scale=2.0,
)

params = GeometricGasParams(
    N=N,
    d=dims,
    potential=potential,
    langevin=langevin_cfg,
    localization=localization_cfg,
    adaptive=adaptive_cfg,
    bounds=bounds,
    device="cpu",
    dtype="float32",
)

# Two gas instances built from identical parameters; each one evolves its
# own state below.
gas1 = GeometricGas(params, measurement_fn=measurement_fn)
gas2 = GeometricGas(params, measurement_fn=measurement_fn)

# Swarm 1: uniform in [4, 6) x [4, 6) (upper right, inside the domain),
# with small Gaussian velocities.
x1_init = 4.0 + 2.0 * torch.rand(N, dims)
v1_init = 0.1 * torch.randn(N, dims)
state1 = gas1.initialize_state(x1_init, v1_init)

# Swarm 2: uniform in [-6, -4) x [-6, -4) (lower left, inside the domain).
x2_init = 2.0 * torch.rand(N, dims) - 6.0
v2_init = 0.1 * torch.randn(N, dims)
state2 = gas2.initialize_state(x2_init, v2_init)

# Sanity check: both starting boxes lie inside the domain, so all walkers
# are expected to be alive here.
alive1_init, alive2_init = (
    bounds.contains(swarm_state.x).sum().item() for swarm_state in (state1, state2)
)

print("✓ Initialized two swarms with bounds")
print(f"  Swarm 1: {alive1_init}/{N} walkers alive initially")
print(f"  Swarm 2: {alive2_init}/{N} walkers alive initially")

### Visualize Initial Configuration

In [None]:
# Alive masks: walkers inside ``bounds`` are alive, the rest are dead.
alive_mask1 = bounds.contains(state1.x)
alive_mask2 = bounds.contains(state2.x)


def _initial_swarm_layers(positions, alive, alive_color, alive_label, dead_label):
    """Return the (alive, dead) scatter layers for one swarm's positions.

    Alive walkers are drawn as outlined circles in ``alive_color``. Dead
    walkers are drawn as faint light-gray crosses. When no walker is dead,
    an explicit empty (0, 2) array is plotted instead.
    """
    alive_layer = hv.Points(positions[alive].detach().numpy(), label=alive_label).opts(
        size=8,
        color=alive_color,
        alpha=0.6,
        line_color='black',
        line_width=0.5,
    )
    dead = ~alive
    dead_xy = positions[dead].detach().numpy() if dead.any() else np.empty((0, 2))
    dead_layer = hv.Points(dead_xy, label=dead_label).opts(
        size=8,
        color='lightgray',
        alpha=0.3,
        marker='x',
    )
    return alive_layer, dead_layer


# Swarm 1 is blue and swarm 2 is red; dead walkers of both are gray.
swarm1_alive, swarm1_dead = _initial_swarm_layers(
    state1.x, alive_mask1, 'blue', 'Swarm 1 (alive)', 'Swarm 1 (dead)'
)
swarm2_alive, swarm2_dead = _initial_swarm_layers(
    state2.x, alive_mask2, 'red', 'Swarm 2 (alive)', 'Swarm 2 (dead)'
)

# Overlay the walkers on the target density, domain box and mode markers.
init_plot = (
    qsd_image * boundary_box * mode_points *
    swarm1_alive * swarm1_dead * swarm2_alive * swarm2_dead
).opts(
    title='Initial Configuration: Two Swarms with Boundaries',
    width=700,
    height=700,
)

init_plot

## 3. Run Simulation with Boundary Tracking

In [None]:
# Storage for metrics: one list per column, with one entry per recording step.
metrics = {
    'time': [],
    # Swarm 1
    'V_total_1': [],
    'V_var_x_1': [],
    'V_var_v_1': [],
    'n_alive_1': [],  # Track alive walker count!
    # Swarm 2
    'V_total_2': [],
    'V_var_x_2': [],
    'V_var_v_2': [],
    'n_alive_2': [],  # Track alive walker count!
    # Inter-swarm
    'com_distance': [],
}

# Steps at which full position snapshots are stored. Times beyond the run
# length are dropped and duplicates removed (e.g. when n_steps equals one of
# the fixed times). This guarantees the simulation loop fills every key in
# ``snapshots``, so no placeholder with ``None`` states is left behind.
snapshot_times = sorted({t for t in (0, 100, 500, 1000, 2000, n_steps) if 0 <= t <= n_steps})
snapshots = {t: {'state1': None, 'state2': None, 'alive1': None, 'alive2': None}
             for t in snapshot_times}

def compute_metrics_with_bounds(state1, state2, time):
    """Append one row of diagnostics for both swarms to the global ``metrics``.

    For each swarm, walkers inside ``bounds`` count as alive. The position
    and velocity variance components are evaluated on that alive subset,
    and their sum is stored as ``V_total``. The distance between the two
    alive centres of mass is stored as ``com_distance``, or NaN when either
    swarm has no alive walkers.

    Args:
        state1: State of swarm 1 (its ``.x`` holds walker positions).
        state2: State of swarm 2.
        time: Step index written to the ``'time'`` column.
    """
    def _swarm_summary(state):
        # Alive = inside the domain; only these walkers enter the variances.
        alive = bounds.contains(state.x)
        var_x = compute_internal_variance_position(state, alive)
        var_v = compute_internal_variance_velocity(state, alive)
        return alive, alive.sum().item(), var_x, var_v

    alive_a, count_a, var_x_a, var_v_a = _swarm_summary(state1)
    alive_b, count_b, var_x_b, var_v_b = _swarm_summary(state2)

    # The centre-of-mass distance is undefined when a swarm has no alive walkers.
    if count_a == 0 or count_b == 0:
        com_distance = float('nan')
    else:
        centre_a = state1.x[alive_a].mean(dim=0)
        centre_b = state2.x[alive_b].mean(dim=0)
        com_distance = torch.norm(centre_a - centre_b).item()

    row = {
        'time': time,
        'V_total_1': (var_x_a + var_v_a).item(),
        'V_var_x_1': var_x_a.item(),
        'V_var_v_1': var_v_a.item(),
        'n_alive_1': count_a,
        'V_total_2': (var_x_b + var_v_b).item(),
        'V_var_x_2': var_x_b.item(),
        'V_var_v_2': var_v_b.item(),
        'n_alive_2': count_b,
        'com_distance': com_distance,
    }
    for column, value in row.items():
        metrics[column].append(value)

def _capture_snapshot(s1, s2):
    """Return cloned copies of both swarms' positions and alive masks.

    Cloning keeps the stored tensors independent of any later updates to
    the live state tensors.
    """
    return {
        'state1': s1.x.clone(),
        'state2': s2.x.clone(),
        'alive1': bounds.contains(s1.x).clone(),
        'alive2': bounds.contains(s2.x).clone(),
    }


# Metrics are recorded every ``metric_every`` steps, and always at the final
# step so the reported "final" values describe the end state even when
# n_steps is not a multiple of metric_every.
metric_every = 10
snapshot_steps = set(snapshot_times)

# Initial metrics and snapshot
compute_metrics_with_bounds(state1, state2, 0)
snapshots[0] = _capture_snapshot(state1, state2)

print("Running simulation with boundary tracking...\n")

# Main loop
for step in tqdm(range(n_steps), desc="Simulation"):
    # Advance both swarms by one step (the first return value is unused here)
    _, state1 = gas1.step(state1)
    _, state2 = gas2.step(state2)

    t = step + 1
    if t % metric_every == 0 or t == n_steps:
        compute_metrics_with_bounds(state1, state2, t)

    if t in snapshot_steps:
        snapshots[t] = _capture_snapshot(state1, state2)

print("\n✓ Simulation complete!")
print(f"  Final alive counts: Swarm 1 = {metrics['n_alive_1'][-1]}/{N}, "
      f"Swarm 2 = {metrics['n_alive_2'][-1]}/{N}")

## 4. Alive Walker Tracking

Number of walkers alive over time; the recovery of this count after deaths shows how effectively cloning revives walkers.

In [None]:
# Tabular view of the recorded metrics: one row per recording step.
df = pd.DataFrame(metrics)

# One curve per swarm showing how many walkers are inside the domain.
alive_curve1, alive_curve2 = (
    hv.Curve(df, kdims='time', vdims=column, label=legend).opts(color=colour, line_width=2)
    for column, legend, colour in (
        ('n_alive_1', 'Swarm 1', 'blue'),
        ('n_alive_2', 'Swarm 2', 'red'),
    )
)

# Dashed reference line at the full swarm size N.
n_ref = hv.HLine(N, label=f'Total ({N})').opts(color='gray', line_dash='dashed', line_width=1)

alive_plot = (alive_curve1 * alive_curve2 * n_ref).opts(
    width=800,
    height=400,
    title='Alive Walker Count Over Time',
    xlabel='Time (steps)',
    ylabel='Number of Alive Walkers',
    legend_position='bottom_right',
    tools=['hover'],
)

alive_plot

## 5. Lyapunov Function Decay (Log Scale)

In [None]:
def _lyap_log_curve(column, legend, colour):
    """Log-y curve of ``column`` against time, keeping strictly positive rows only."""
    # A log axis cannot display zeros; the ``> 0`` filter also drops NaNs.
    positive_rows = df[df[column] > 0]
    return hv.Curve(positive_rows, kdims='time', vdims=column, label=legend).opts(
        color=colour,
        line_width=2,
        logy=True,
    )


lyap_curve1 = _lyap_log_curve('V_total_1', 'Swarm 1', 'blue')
lyap_curve2 = _lyap_log_curve('V_total_2', 'Swarm 2', 'red')

lyap_plot = (lyap_curve1 * lyap_curve2).opts(
    width=800,
    height=400,
    title='Framework-Correct Lyapunov Function Decay (Log Scale)',
    xlabel='Time (steps)',
    ylabel='V_total (N-normalized)',
    legend_position='top_right',
    tools=['hover'],
)

lyap_plot

## 6. Variance Components

In [None]:
def _variance_log_curve(column, legend, colour, **extra_opts):
    """Log-y curve of one variance column, restricted to strictly positive rows."""
    return hv.Curve(df[df[column] > 0], kdims='time', vdims=column, label=legend).opts(
        color=colour,
        line_width=2,
        logy=True,
        **extra_opts,
    )


# Position variance: solid lines
var_x1 = _variance_log_curve('V_var_x_1', 'V_Var,x (Swarm 1)', 'blue')
var_x2 = _variance_log_curve('V_var_x_2', 'V_Var,x (Swarm 2)', 'red')

# Velocity variance: dashed lines
var_v1 = _variance_log_curve('V_var_v_1', 'V_Var,v (Swarm 1)', 'blue', line_dash='dashed')
var_v2 = _variance_log_curve('V_var_v_2', 'V_Var,v (Swarm 2)', 'red', line_dash='dashed')

variance_plot = (var_x1 * var_x2 * var_v1 * var_v2).opts(
    width=800,
    height=400,
    title='Variance Components (Position vs Velocity)',
    xlabel='Time (steps)',
    ylabel='Variance (N-normalized, log scale)',
    legend_position='right',
    tools=['hover'],
)

variance_plot

## 7. Inter-Swarm Distance

In [None]:
# Distance between the swarms' alive centres of mass. Only strictly positive
# values are kept: NaN rows (a swarm had no alive walkers) and exact zeros
# cannot be drawn on a log axis. This matches the ``> 0`` filtering used by
# the other log-scale plots.
distance_curve = hv.Curve(
    df[df['com_distance'] > 0],
    kdims='time',
    vdims='com_distance',
    label='Center of Mass Distance'
).opts(
    color='purple',
    line_width=2,
    logy=True,
)

distance_plot = distance_curve.opts(
    width=800,
    height=400,
    title='Inter-Swarm Convergence',
    xlabel='Time (steps)',
    ylabel='||μ_x^(1) - μ_x^(2)|| (log scale)',
    tools=['hover'],
)

distance_plot

## 8. Visual Evolution with Alive/Dead Walkers

In [None]:
def create_snapshot_plot(time_idx):
    """Render one stored snapshot with the alive and dead walkers of both swarms.

    Args:
        time_idx: Step number, used as the key into the global ``snapshots``.

    Returns:
        A HoloViews overlay of the target-QSD image, domain box and mode
        markers, plus alive/dead walker layers for each swarm. The title
        reports each swarm's alive count.
    """
    snap = snapshots[time_idx]
    xy_a, xy_b = (snap[key].detach().numpy() for key in ('state1', 'state2'))
    live_a, live_b = (snap[key].numpy() for key in ('alive1', 'alive2'))

    def _dead_xy(xy, live):
        # Fall back to an explicit empty (0, 2) array when no walker is dead.
        return xy[~live] if (~live).any() else np.empty((0, 2))

    alive_style = dict(size=6, alpha=0.6, line_color='black', line_width=0.5)
    dead_style = dict(size=6, alpha=0.3, marker='x')

    s1_alive = hv.Points(xy_a[live_a], label='Swarm 1 (alive)').opts(color='blue', **alive_style)
    s1_dead = hv.Points(_dead_xy(xy_a, live_a), label='Swarm 1 (dead)').opts(color='lightblue', **dead_style)
    s2_alive = hv.Points(xy_b[live_b], label='Swarm 2 (alive)').opts(color='red', **alive_style)
    s2_dead = hv.Points(_dead_xy(xy_b, live_b), label='Swarm 2 (dead)').opts(color='lightcoral', **dead_style)

    title = (
        f'Time t = {time_idx} (Alive: S1={live_a.sum()}/{len(live_a)}, '
        f'S2={live_b.sum()}/{len(live_b)})'
    )
    overlay = (
        qsd_image * boundary_box * mode_points *
        s1_alive * s1_dead * s2_alive * s2_dead
    )
    return overlay.opts(
        title=title,
        width=550,
        height=550,
        xlim=(-9, 9),
        ylim=(-9, 9),
    )

# Only plot snapshots the simulation actually filled in. A placeholder entry
# whose state is still None would otherwise crash ``create_snapshot_plot``.
filled_times = sorted(t for t, snap in snapshots.items() if snap['state1'] is not None)
snapshot_plots = [create_snapshot_plot(t) for t in filled_times]

# Three panels per row (a 2x3 grid for the default six snapshots)
evolution_grid = hv.Layout(snapshot_plots).cols(3).opts(
    title='Two-Swarm Evolution with Boundary Handling'
)

evolution_grid

## 9. Summary Dashboard

Combine all key metrics into an interactive Panel dashboard

In [None]:
# Create summary text
summary_text = f"""
## Two-Swarm Convergence Summary

**Framework**: N-normalized Lyapunov functions from 03_cloning.md

**Initial Conditions**:
- Swarm 1: Upper right, {metrics['n_alive_1'][0]}/{N} alive
- Swarm 2: Lower left, {metrics['n_alive_2'][0]}/{N} alive
- Initial distance: {metrics['com_distance'][0]:.4f}

**Final State** (t={metrics['time'][-1]}):
- Swarm 1: {metrics['n_alive_1'][-1]}/{N} alive, V_total={metrics['V_total_1'][-1]:.2f}
- Swarm 2: {metrics['n_alive_2'][-1]}/{N} alive, V_total={metrics['V_total_2'][-1]:.2f}
- Final distance: {metrics['com_distance'][-1]:.4f}
- Distance reduction: {100 * (1 - metrics['com_distance'][-1]/metrics['com_distance'][0]):.1f}%

**Key Observations**:
1. ✓ Boundaries enforced (walkers die outside [-8,8]×[-8,8])
2. ✓ Cloning resurrects dead walkers
3. ✓ Both swarms converge toward same modes
4. ✓ Framework-correct N-normalized Lyapunov functions
"""

# Create dashboard
dashboard = pn.Column(
    pn.pane.Markdown(summary_text),
    pn.Row(alive_plot, lyap_plot),
    pn.Row(variance_plot, distance_plot),
    evolution_grid,
)

dashboard

## 10. Validation Checks

In [None]:
print("=" * 70)
print("VALIDATION CHECKS")
print("=" * 70)

# Check 1: Were there any deaths?
# NOTE: alive counts exist only at recording steps, so deaths that occur and
# are repaired between two recordings are not visible here.
min_alive_1 = min(metrics['n_alive_1'])
min_alive_2 = min(metrics['n_alive_2'])

if min_alive_1 < N or min_alive_2 < N:
    print("✓ Boundaries enforced: Some walkers died!")
    print(f"  Swarm 1 min alive: {min_alive_1}/{N}")
    print(f"  Swarm 2 min alive: {min_alive_2}/{N}")
else:
    print("⚠️  No deaths observed - walkers may not have reached boundaries")

# Check 2: Did alive counts recover?
final_alive_1 = metrics['n_alive_1'][-1]
final_alive_2 = metrics['n_alive_2'][-1]

if final_alive_1 >= 0.8 * N and final_alive_2 >= 0.8 * N:
    print("\n✓ Cloning works: Most walkers alive at end")
    print(f"  Final alive: S1={final_alive_1}/{N}, S2={final_alive_2}/{N}")
else:
    print("\n⚠️  Low alive count at end - may need parameter tuning")

# Check 3: Did swarms converge?
init_dist = metrics['com_distance'][0]
final_dist = metrics['com_distance'][-1]

if np.isfinite(init_dist) and np.isfinite(final_dist) and init_dist != 0:
    reduction = 100 * (1 - final_dist / init_dist)
    if reduction > 50:
        print(f"\n✓ Inter-swarm convergence: {reduction:.1f}% distance reduction")
    else:
        print(f"\n⚠️  Weak convergence: {reduction:.1f}% distance reduction")
else:
    # A NaN distance means one swarm had no alive walkers at that recording
    # step. A zero initial distance would make the ratio divide by zero.
    reduction = float('nan')
    print("\n⚠️  Convergence undefined: centre-of-mass distance is NaN "
          "or the initial distance is zero")

# Check 4: Framework correctness
print("\n✓ Framework-correct Lyapunov functions used (N-normalized)")
print("  V_Var,x and V_Var,v computed with alive walkers only")
print("  N-normalization ensures N-uniform drift inequalities")

print("\n" + "=" * 70)
print("All validation checks complete!")
print("=" * 70)