# Ricci Fragile Gas: 3D Visualization and Physics Application

This notebook provides interactive 3D visualizations of:
1. Walkers in flat Euclidean space
2. Walkers on the emergent Riemannian manifold
3. A real physics problem: Lennard-Jones cluster optimization

**Theory**: See `docs/source/12_fractal_gas.md`

**Implementation**: See `src/fragile/ricci_gas.py`

In [1]:
import sys
import numpy as np
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import holoviews as hv
from holoviews import opts
hv.extension('plotly')

# Set device to CPU by default
# To use CUDA, change this to: device = torch.device("cuda")
device = torch.device("cpu")

# IMPORTANT: Don't use torch.set_default_device() - it's unreliable
# Instead, explicitly pass device to all tensor creations

# Add parent directory to path
sys.path.insert(0, '..')

from src.fragile.ricci_gas import (
    RicciGas,
    RicciGasParams,
    SwarmState,
    create_ricci_gas_variants,
    compute_kde_density,
    compute_kde_hessian,
    compute_ricci_proxy_3d,
)

print("✓ Imports successful")
print(f"✓ Using device: {device}")

✓ Imports successful
✓ Using device: cpu


## 1. Initialize Ricci Gas

In [2]:
# Create Ricci Gas with moderate feedback strength and boundaries
params = RicciGasParams(
    epsilon_R=0.5,           # Feedback strength (try varying: 0.1, 0.5, 1.0, 2.0)
    kde_bandwidth=0.4,       # Smoothing length
    epsilon_Ric=0.01,        # Regularization
    force_mode="pull",       # Gravity: toward high curvature
    reward_mode="inverse",   # Anti-gravity: reward low curvature
    R_crit=15.0,             # Singularity threshold
    gradient_clip=10.0,      # Numerical stability
    epsilon_clone=100000,    # Cloning interaction range
    sigma_clone=0.2,         # Positional jitter for cloning
    x_min=-4.0,              # Lower boundary
    x_max=4.0,               # Upper boundary
)

# IMPORTANT: Pass device to RicciGas
gas = RicciGas(params, device=device)

print(f"Ricci Gas initialized:")
print(f"  Feedback strength α = {params.epsilon_R}")
print(f"  Smoothing length ℓ = {params.kde_bandwidth}")
print(f"  Force mode: {params.force_mode}")
print(f"  Reward mode: {params.reward_mode}")
print(f"  Boundaries: [{params.x_min}, {params.x_max}]")
print(f"  Cloning: epsilon={params.epsilon_clone}, sigma={params.sigma_clone}")
print(f"  Device: {gas.device}")

Ricci Gas initialized:
  Feedback strength α = 0.5
  Smoothing length ℓ = 0.4
  Force mode: pull
  Reward mode: inverse
  Boundaries: [-4.0, 4.0]
  Cloning: epsilon=100000.0, sigma=0.2
  Device: cpu


## 2. Initialize Swarm

Start with walkers placed uniformly at random in $[-2, 2]^3$, with small Gaussian velocities (std 0.1).

In [19]:
N = 1000  # Number of walkers
d = 3    # Dimension (always 3 for our implementation)

# Random initialization in [-2, 2]^3
torch.manual_seed(42)
x = torch.rand(N, d, device=device) * 4.0 - 2.0
v = torch.randn(N, d, device=device) * 0.1
s = torch.ones(N, device=device)

state = SwarmState(x=x, v=v, s=s)

print(f"Swarm initialized: {N} walkers in {d}D")
print(f"  Position range: [{x.min():.2f}, {x.max():.2f}]")
print(f"  Velocity std: {v.std():.3f}")

Swarm initialized: 1000 walkers in 3D
  Position range: [-1.99, 2.00]
  Velocity std: 0.100


## 3. Compute Initial Geometry
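`gas.compute_curvature` returns a per-walker Ricci proxy `R` and Hessian `H`. The actual code lives in `src/fragile/ricci_gas.py` and is not reproduced here. As an illustration only, the sketch below assumes a Gaussian kernel whose bandwidth $h$ plays the role of `kde_bandwidth`, and computes the analytic Hessian of the unnormalized density $\rho(y) = \frac{1}{N}\sum_j \exp(-\|y - x_j\|^2 / 2h^2)$. The real kernel, normalization, and the way `compute_ricci_proxy_3d` turns $H$ into $R$ may differ; the function names here are hypothetical.

```python
import numpy as np

def gaussian_kde_density(x, y, bandwidth):
    """Unnormalized Gaussian KDE: rho(y) = mean_j exp(-|y - x_j|^2 / (2 h^2))."""
    u = y[:, None, :] - x[None, :, :]                          # [M, N, d]
    return np.exp(-np.sum(u ** 2, axis=-1) / (2 * bandwidth ** 2)).mean(axis=1)

def gaussian_kde_hessian(x, y, bandwidth):
    """Analytic Hessian of gaussian_kde_density at each query point, shape [M, d, d].

    For K_j = exp(-|u_j|^2 / (2 h^2)) with u_j = y - x_j:
        d^2 K_j / dy dy^T = K_j * (u_j u_j^T / h^4 - I / h^2)
    """
    h2 = bandwidth ** 2
    u = y[:, None, :] - x[None, :, :]                          # [M, N, d]
    K = np.exp(-np.sum(u ** 2, axis=-1) / (2 * h2))            # [M, N]
    outer = u[..., :, None] * u[..., None, :]                  # [M, N, d, d]
    eye = np.eye(x.shape[1])
    return (K[..., None, None] * (outer / h2 ** 2 - eye / h2)).mean(axis=1)

# Usage: Hessian at a few query points for a random 3D swarm.
rng = np.random.default_rng(0)
walkers = rng.uniform(-2.0, 2.0, size=(200, 3))
queries = rng.uniform(-1.0, 1.0, size=(4, 3))
H = gaussian_kde_hessian(walkers, queries, bandwidth=0.4)
print(H.shape)
print(np.linalg.eigvalsh(H[0]))  # eigenvalues may have mixed signs (indefinite H)
```

Mixed-sign eigenvalues are exactly what the walker-0 output below shows, which matters later when $H$ is used to build a metric.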

In [20]:
# Compute Ricci curvature and Hessian
R, H = gas.compute_curvature(state, cache=True)

print("Emergent geometry computed:")
print(f"  Ricci curvature R:")
print(f"    Min:  {R.min():.3f}")
print(f"    Mean: {R.mean():.3f}")
print(f"    Max:  {R.max():.3f}")
print(f"  Hessian eigenvalues (sample walker 0):")
eigenvals = torch.linalg.eigvalsh(H[0])
print(f"    λ = [{eigenvals[0]:.3f}, {eigenvals[1]:.3f}, {eigenvals[2]:.3f}]")

Emergent geometry computed:
  Ricci curvature R:
    Min:  -0.075
    Mean: -0.010
    Max:  0.047
  Hessian eigenvalues (sample walker 0):
    λ = [-0.046, -0.028, 0.008]


## 4. Visualization 1: Walkers in Flat Space

Visualize walkers in Euclidean 3D, colored by Ricci curvature.

In [21]:
def plot_walkers_3d(state, title="Walkers in Flat Space"):
    """Plot walkers in 3D, colored by Ricci curvature."""
    x_np = state.x.detach().cpu().numpy()
    R_np = state.R.detach().cpu().numpy()
    alive = state.s.bool().cpu().numpy()
    
    # Alive walkers
    fig = go.Figure(data=[go.Scatter3d(
        x=x_np[alive, 0],
        y=x_np[alive, 1],
        z=x_np[alive, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=R_np[alive],
            colorscale='RdBu_r',
            colorbar=dict(title="Ricci R"),
            line=dict(width=0.5, color='black'),
        ),
        text=[f"R={R_np[i]:.3f}" for i in np.where(alive)[0]],
        hovertemplate="<b>Walker</b><br>" +
                      "x: %{x:.2f}<br>" +
                      "y: %{y:.2f}<br>" +
                      "z: %{z:.2f}<br>" +
                      "%{text}<extra></extra>",
        name="Alive",
    )])
    
    # Dead walkers (if any)
    if (~alive).any():
        fig.add_trace(go.Scatter3d(
            x=x_np[~alive, 0],
            y=x_np[~alive, 1],
            z=x_np[~alive, 2],
            mode='markers',
            marker=dict(size=3, color='gray', opacity=0.3),
            name="Dead",
        ))
    
    fig.update_layout(
        title=title,
        scene=dict(
            xaxis_title="x",
            yaxis_title="y",
            zaxis_title="z",
            aspectmode='cube',
        ),
        width=800,
        height=700,
    )
    
    return fig

fig = plot_walkers_3d(state, title="Initial Configuration: Walkers Colored by Ricci Curvature")
fig.show()

## 5. Run Dynamics

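Evolve the swarm with `gas.step()`, which applies cloning, the curvature computation, the Ricci force, the velocity/position update, and the boundary handling at every step.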

In [22]:
# Run dynamics and track statistics WITH CLONING
history = []
T = 1000  # Reduced from 3000 for faster execution and better animation
dt = 0.1

print("Running dynamics with cloning, boundaries, and tracking full history...")
print(f"  Using gas.step() method")
print(f"  Boundaries: [{params.x_min}, {params.x_max}]")
print(f"  Tracking positions for animation")

for t in range(T):
    # Use the step method (includes cloning, curvature, force, dynamics, boundaries)
    state = gas.step(state, dt=dt, gamma=0.9, noise_std=0.05, do_clone=True)
    
    # Track statistics
    alive = state.s.bool()
    alive_count = alive.sum().item()
    
    if alive_count > 0:
        variance = state.x[alive].var(dim=0).sum().item()
        R_mean = state.R[alive].mean().item()
        R_max = state.R[alive].max().item()
    else:
        variance = 0.0
        R_mean = 0.0
        R_max = 0.0
    
    history.append({
        't': t,
        'variance': variance,
        'R_mean': R_mean,
        'R_max': R_max,
        'alive_fraction': alive.float().mean().item(),
        'alive_count': alive_count,
        'x': state.x.clone().cpu(),  # Store positions for animation
        'R': state.R.clone().cpu(),  # Store curvature for coloring
        's': state.s.clone().cpu(),  # Store alive status
    })
    
    if t % 50 == 0:
        print(f"  t={t:3d}: alive={int(alive_count):3d}/{N}, var={variance:.3f}, R_mean={R_mean:.3f}, R_max={R_max:.3f}")

print("\nDynamics complete!")
print(f"  Final alive: {int(history[-1]['alive_count'])}/{N}")

Running dynamics with cloning, boundaries, and tracking full history...
  Using gas.step() method
  Boundaries: [-4.0, 4.0]
  Tracking positions for animation
  t=  0: alive=1000/1000, var=4.079, R_mean=-0.010, R_max=0.051
  t= 50: alive=994/1000, var=8.221, R_mean=-0.013, R_max=0.039
  t=100: alive=995/1000, var=9.154, R_mean=-0.012, R_max=0.027
  t=150: alive=994/1000, var=8.383, R_mean=-0.010, R_max=0.033
  t=200: alive=991/1000, var=8.946, R_mean=-0.013, R_max=0.030
  t=250: alive=982/1000, var=10.382, R_mean=-0.015, R_max=0.026
  t=300: alive=992/1000, var=10.129, R_mean=-0.014, R_max=0.028
  t=350: alive=993/1000, var=8.776, R_mean=-0.014, R_max=0.047
  t=400: alive=993/1000, var=8.947, R_mean=-0.012, R_max=0.039
  t=450: alive=991/1000, var=9.649, R_mean=-0.014, R_max=0.026
  t=500: alive=989/1000, var=9.096, R_mean=-0.014, R_max=0.038
  t=550: alive=993/1000, var=9.292, R_mean=-0.013, R_max=0.032
  t=600: alive=986/1000, var=9.940, R_mean=-0.015, R_max=0.032
  t=650: alive=989/

## 6. Visualization 2: Evolution Metrics

In [7]:
# Extract time series
t_vals = [h['t'] for h in history]
variance = [h['variance'] for h in history]
R_mean = [h['R_mean'] for h in history]
R_max = [h['R_max'] for h in history]
alive_frac = [h['alive_fraction'] for h in history]
alive_count = [h['alive_count'] for h in history]

# Create subplots - 2x2 grid
fig = make_subplots(
    rows=2, cols=2,
    subplot_titles=('Spatial Variance', 'Mean Ricci Curvature', 'Max Ricci Curvature', 'Alive Walkers'),
)

fig.add_trace(go.Scatter(x=t_vals, y=variance, mode='lines', name='Variance'), row=1, col=1)
fig.add_trace(go.Scatter(x=t_vals, y=R_mean, mode='lines', name='R_mean', line=dict(color='orange')), row=1, col=2)
fig.add_trace(go.Scatter(x=t_vals, y=R_max, mode='lines', name='R_max', line=dict(color='red')), row=2, col=1)
fig.add_trace(go.Scatter(x=t_vals, y=alive_count, mode='lines', name='Alive', line=dict(color='green')), row=2, col=2)

# Add horizontal line for initial count
fig.add_hline(y=N, line_dash="dash", line_color="gray", row=2, col=2)

fig.update_xaxes(title_text="Time", row=1, col=1)
fig.update_xaxes(title_text="Time", row=1, col=2)
fig.update_xaxes(title_text="Time", row=2, col=1)
fig.update_xaxes(title_text="Time", row=2, col=2)

fig.update_yaxes(title_text="Variance", row=1, col=1)
fig.update_yaxes(title_text="R_mean", row=1, col=2)
fig.update_yaxes(title_text="R_max", row=2, col=1)
fig.update_yaxes(title_text="Count", row=2, col=2)

fig.update_layout(height=700, showlegend=False, title_text="Evolution of Swarm Statistics (with Boundaries & Cloning)")
fig.show()

# Interpretation
final_var = variance[-1]
mean_alive = sum(alive_count) / len(alive_count)
min_alive = min(alive_count)

print(f"\n📊 Population Dynamics:")
print(f"  Initial: {N}")
print(f"  Final: {int(alive_count[-1])}")
print(f"  Mean: {mean_alive:.1f}")
print(f"  Min: {int(min_alive)}")

if final_var < 0.5 * variance[0]:
    print("\n📉 SUPERCRITICAL REGIME: Variance collapsed (possible phase transition)")
elif final_var > 0.8 * variance[0]:
    print("\n🌊 SUBCRITICAL REGIME: Variance stable (diffuse gas phase)")
else:
    print("\n⚖️  NEAR-CRITICAL: Intermediate behavior")


📊 Population Dynamics:
  Initial: 500
  Final: 493
  Mean: 496.0
  Min: 484

🌊 SUBCRITICAL REGIME: Variance stable (diffuse gas phase)
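The classification above compares a single final variance with the first one, while the logged variance fluctuates between roughly 8 and 10.4 from step to step. A minimal sketch that applies the same 0.5 / 0.8 thresholds to an average over the last `window` steps instead (the helper name and the window length of 100 are assumptions, not part of `ricci_gas.py`):

```python
import numpy as np

def classify_regime(variance, window=100, collapse=0.5, stable=0.8):
    """Apply the notebook's variance-ratio rule to a late-time average.

    ratio = mean(variance[-window:]) / variance[0]
      ratio < collapse -> "supercritical" (variance collapsed)
      ratio > stable   -> "subcritical"   (diffuse gas)
      otherwise        -> "near-critical"
    """
    v = np.asarray(variance, dtype=float)
    ratio = v[-window:].mean() / v[0]
    if ratio < collapse:
        return "supercritical", ratio
    if ratio > stable:
        return "subcritical", ratio
    return "near-critical", ratio

# Usage with the time series extracted in the cell above:
# regime, ratio = classify_regime(variance)
```

Averaging over a window trades a little responsiveness for robustness against a single noisy final step.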


In [8]:
# Create 3D animation of walker evolution
print("Creating 3D time evolution animation...")
print(f"  Total frames: {len(history)}")
print(f"  Sampling every 10th frame for performance")

# Sample frames for animation (every 10th frame)
frame_indices = list(range(0, len(history), 10))

# Find global min/max for consistent color scale
all_R = [h['R'][h['s'].bool()] for h in history if h['s'].sum() > 0]
R_min_global = min([r.min().item() for r in all_R])
R_max_global = max([r.max().item() for r in all_R])

# Create frames
frames = []
for idx in frame_indices:
    h = history[idx]
    alive_mask = h['s'].bool()
    x_np = h['x'][alive_mask].numpy()
    R_np = h['R'][alive_mask].numpy()
    
    frame = go.Frame(
        data=[go.Scatter3d(
            x=x_np[:, 0],
            y=x_np[:, 1],
            z=x_np[:, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=R_np,
                colorscale='RdBu_r',
                cmin=R_min_global,
                cmax=R_max_global,
                colorbar=dict(title="Ricci R"),
                line=dict(width=0.5, color='black'),
            ),
            text=[f"R={R_np[i]:.3f}" for i in range(len(R_np))],
            hovertemplate="<b>Walker</b><br>" +
                          "x: %{x:.2f}<br>" +
                          "y: %{y:.2f}<br>" +
                          "z: %{z:.2f}<br>" +
                          "%{text}<extra></extra>",
        )],
        name=str(idx),
        layout=go.Layout(
            title=f"3D Evolution (t={h['t']}, alive={int(h['alive_count'])}/{N})"
        )
    )
    frames.append(frame)

# Initial frame
h0 = history[0]
alive_mask_0 = h0['s'].bool()
x_np_0 = h0['x'][alive_mask_0].numpy()
R_np_0 = h0['R'][alive_mask_0].numpy()

fig = go.Figure(
    data=[go.Scatter3d(
        x=x_np_0[:, 0],
        y=x_np_0[:, 1],
        z=x_np_0[:, 2],
        mode='markers',
        marker=dict(
            size=5,
            color=R_np_0,
            colorscale='RdBu_r',
            cmin=R_min_global,
            cmax=R_max_global,
            colorbar=dict(title="Ricci R"),
            line=dict(width=0.5, color='black'),
        ),
        text=[f"R={R_np_0[i]:.3f}" for i in range(len(R_np_0))],
        hovertemplate="<b>Walker</b><br>" +
                      "x: %{x:.2f}<br>" +
                      "y: %{y:.2f}<br>" +
                      "z: %{z:.2f}<br>" +
                      "%{text}<extra></extra>",
    )],
    frames=frames,
)

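# Add boundary box: wireframe of the cube [x_min, x_max]^3 (12 edges).
# `None` entries break the polyline so each edge is drawn exactly once.
lo, hi = params.x_min, params.x_max
corners = [(lo, lo, lo), (hi, lo, lo), (hi, hi, lo), (lo, hi, lo),
           (lo, lo, hi), (hi, lo, hi), (hi, hi, hi), (lo, hi, hi)]
edges = [(0, 1), (1, 2), (2, 3), (3, 0),   # bottom face
         (4, 5), (5, 6), (6, 7), (7, 4),   # top face
         (0, 4), (1, 5), (2, 6), (3, 7)]   # vertical edges
box_x, box_y, box_z = [], [], []
for a, b in edges:
    for coords, axis in ((box_x, 0), (box_y, 1), (box_z, 2)):
        coords.extend([corners[a][axis], corners[b][axis], None])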

fig.add_trace(go.Scatter3d(
    x=box_x, y=box_y, z=box_z,
    mode='lines',
    line=dict(color='red', width=2, dash='dash'),
    name='Boundaries',
    showlegend=True,
    hoverinfo='skip'
))

# Animation controls
fig.update_layout(
    title=f"3D Walker Evolution with Boundaries (Click Play)",
    scene=dict(
        xaxis=dict(title="x", range=[-5, 5]),
        yaxis=dict(title="y", range=[-5, 5]),
        zaxis=dict(title="z", range=[-5, 5]),
        aspectmode='cube',
    ),
    width=900,
    height=800,
    updatemenus=[
        dict(
            type="buttons",
            showactive=False,
            buttons=[
                dict(
                    label="▶ Play",
                    method="animate",
                    args=[None, {
                        "frame": {"duration": 50, "redraw": True},
                        "fromcurrent": True,
                        "mode": "immediate",
                        "transition": {"duration": 50}
                    }]
                ),
                dict(
                    label="⏸ Pause",
                    method="animate",
                    args=[[None], {
                        "frame": {"duration": 0, "redraw": False},
                        "mode": "immediate",
                        "transition": {"duration": 0}
                    }]
                )
            ],
            x=0.1,
            y=1.15,
        )
    ],
    sliders=[{
        "active": 0,
        "steps": [
            {
                "args": [[f.name], {
                    "frame": {"duration": 0, "redraw": True},
                    "mode": "immediate",
                }],
                "label": str(history[int(f.name)]['t']),
                "method": "animate"
            }
            for f in frames
        ],
        "x": 0.1,
        "len": 0.85,
        "xanchor": "left",
        "y": 0.05,
        "yanchor": "top",
    }]
)

fig.show()

print(f"✓ Animation ready! ({len(frames)} frames)")
print(f"  Red dashed box shows boundaries [-4, 4]³")
print(f"  Walkers colored by Ricci curvature")
print(f"  Use Play button or slider to explore evolution")

Creating 3D time evolution animation...
  Total frames: 500
  Sampling every 10th frame for performance


✓ Animation ready! (50 frames)
  Red dashed box shows boundaries [-4, 4]³
  Walkers colored by Ricci curvature
  Use Play button or slider to explore evolution


## 6b. Time Evolution Animation

Watch the walkers evolve in 3D space, colored by Ricci curvature.

## 7. Visualization 3: Final Configuration

In [9]:
# Recompute final geometry
R_final, H_final = gas.compute_curvature(state, cache=True)

fig = plot_walkers_3d(state, title="Final Configuration: Emergent Structure")
fig.show()

print(f"\nFinal state:")
print(f"  Alive walkers: {state.s.sum():.0f}/{N}")
print(f"  Ricci range: [{R_final[state.s.bool()].min():.3f}, {R_final[state.s.bool()].max():.3f}]")
print(f"  Spatial std: {state.x[state.s.bool()].std(dim=0).mean():.3f}")


Final state:
  Alive walkers: 493/500
  Ricci range: [-0.221, 0.034]
  Spatial std: 1.829


## 8. Visualization 4: Emergent Manifold

Visualize the emergent Riemannian metric via the **metric tensor eigenvalues**.

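For each walker, the metric $g_i = H_i + \epsilon_\Sigma I$ defines local distances. Because the KDE Hessian $H_i$ can be indefinite (walker 0 in Section 3 has two negative eigenvalues), $g_i$ is positive definite only where $\lambda_{\min}(H_i) > -\epsilon_\Sigma$. We visualize:
- **Mean eigenvalue** (marker size): the overall scale of the local metric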
- **Anisotropy** (ratio of max/min eigenvalue)
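The quantities plotted below can be sketched on their own: eigenvalues of $g = H + \epsilon_\Sigma I$, a positive-definiteness flag, and an anisotropy ratio. This sketch assumes anisotropy is measured on eigenvalue magnitudes so that it stays positive for indefinite $H$; the helper name `metric_spectrum` is hypothetical, and $\epsilon_\Sigma = 0.01$ is taken from the default of `plot_emergent_manifold`.

```python
import numpy as np

def metric_spectrum(H, epsilon_sigma=0.01):
    """Eigen-summary of g = H + eps*I for a stack of symmetric Hessians H [N, 3, 3]."""
    G = H + epsilon_sigma * np.eye(H.shape[-1])   # broadcasts over the leading axis
    lam = np.linalg.eigvalsh(G)                   # ascending eigenvalues, shape [N, 3]
    is_metric = lam[:, 0] > 0                     # positive definite -> valid local metric
    mags = np.abs(lam)
    anisotropy = mags.max(axis=1) / (mags.min(axis=1) + 1e-12)
    return lam, anisotropy, is_metric

# Usage: a positive-definite Hessian next to one with walker 0's eigenvalues from Section 3.
H = np.stack([np.diag([0.1, 0.2, 0.4]), np.diag([-0.046, -0.028, 0.008])])
lam, aniso, ok = metric_spectrum(H)
print(lam, aniso, ok)
```

The first Hessian gives eigenvalues 0.11, 0.21, 0.41 (anisotropy ≈ 3.73). The second gives -0.036, -0.018, 0.018, so $g$ is not positive definite at a walker like walker 0.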

In [10]:
def plot_emergent_manifold(state, epsilon_Sigma=0.01):
    """Visualize emergent metric via eigenvalue ellipsoids."""
    x_np = state.x.detach().cpu().numpy()
    H_np = state.H.detach().cpu().numpy()
    alive = state.s.bool().cpu().numpy()
    
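    # Compute metric eigenvalues for all walkers (alive ones are selected when plotting)
    G = H_np + epsilon_Sigma * np.eye(3)  # g = H + ε I, broadcast over [N, 3, 3]
    eigenvals = np.linalg.eigvalsh(G)  # [N, 3], ascending order
    
    # Anisotropy = largest / smallest eigenvalue magnitude
    # (H can be indefinite, so the raw λ_max/λ_min could be negative or blow up)
    abs_eig = np.abs(eigenvals)
    anisotropy = abs_eig.max(axis=1) / (abs_eig.min(axis=1) + 1e-8)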
    
    fig = go.Figure()
    
    # Plot walkers, sized by mean eigenvalue, colored by anisotropy
    mean_eigval = eigenvals.mean(axis=1)
    
    fig.add_trace(go.Scatter3d(
        x=x_np[alive, 0],
        y=x_np[alive, 1],
        z=x_np[alive, 2],
        mode='markers',
        marker=dict(
            size=5 + 10 * (mean_eigval[alive] - mean_eigval[alive].min()) / (mean_eigval[alive].max() - mean_eigval[alive].min() + 1e-8),
            color=anisotropy[alive],
            colorscale='Viridis',
            colorbar=dict(title="Anisotropy<br>(λ_max/λ_min)"),
            line=dict(width=0.5, color='white'),
        ),
        text=[f"λ=[{eigenvals[i,0]:.2f}, {eigenvals[i,1]:.2f}, {eigenvals[i,2]:.2f}]" 
              for i in np.where(alive)[0]],
        hovertemplate="<b>%{text}</b><br>" +
                      "Anisotropy: %{marker.color:.2f}<extra></extra>",
        name="Metric Tensor",
    ))
    
    fig.update_layout(
        title="Emergent Riemannian Manifold: Metric Tensor Eigenvalues",
        scene=dict(
            xaxis_title="x",
            yaxis_title="y",
            zaxis_title="z",
            aspectmode='cube',
        ),
        width=800,
        height=700,
    )
    
    return fig

fig_manifold = plot_emergent_manifold(state, epsilon_Sigma=params.epsilon_Sigma)
fig_manifold.show()

print("\nInterpretation:")
print("  Marker size: Mean eigenvalue (local 'stiffness' of metric)")
print("  Marker color: Anisotropy (how elongated the metric ellipsoid is)")
print("  High anisotropy → Directional bias in geometry")


Interpretation:
  Marker size: Mean eigenvalue (local 'stiffness' of metric)
  Marker color: Anisotropy (how elongated the metric ellipsoid is)
  High anisotropy → Directional bias in geometry


## 9. Visualization 5: Curvature Field

Evaluate the Ricci curvature proxy on a regular 3D grid and render it as isosurfaces around its 75th percentile, overlaid on the walkers.

In [18]:
def create_curvature_isosurface(state, gas, grid_res=30, iso_level=None):
    """Create isosurface of Ricci curvature."""
    x_np = state.x.detach().cpu().numpy()
    alive = state.s.bool()
    
    # Define grid
    x_min, x_max = x_np[:, 0].min() - 1, x_np[:, 0].max() + 1
    y_min, y_max = x_np[:, 1].min() - 1, x_np[:, 1].max() + 1
    z_min, z_max = x_np[:, 2].min() - 1, x_np[:, 2].max() + 1
    
    x_grid = torch.linspace(x_min, x_max, grid_res, device=device)
    y_grid = torch.linspace(y_min, y_max, grid_res, device=device)
    z_grid = torch.linspace(z_min, z_max, grid_res, device=device)
    
    xx, yy, zz = torch.meshgrid(x_grid, y_grid, z_grid, indexing='ij')
    
    # Evaluation points
    x_eval = torch.stack([xx.flatten(), yy.flatten(), zz.flatten()], dim=-1)
    
    # Compute Hessian on grid (this is expensive!)
    print(f"Computing curvature on {len(x_eval)} grid points...")
    from src.fragile.ricci_gas import compute_kde_hessian, compute_ricci_proxy_3d
    
    H_grid = compute_kde_hessian(
        state.x,
        x_eval,
        gas.params.kde_bandwidth,
        alive,
    )
    R_grid = compute_ricci_proxy_3d(H_grid)
    R_grid = R_grid.reshape(grid_res, grid_res, grid_res)
    R_np = R_grid.detach().cpu().numpy()
    
    # Auto-select iso level if not provided
    if iso_level is None:
        iso_level = np.percentile(R_np, 75)  # 75th percentile
    
    print(f"Creating isosurface at R = {iso_level:.3f}")
    
    fig = go.Figure(data=go.Isosurface(
        x=xx.flatten().cpu().numpy(),
        y=yy.flatten().cpu().numpy(),
        z=zz.flatten().cpu().numpy(),
        value=R_np.flatten(),
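        # min/max keep isomin <= isomax even when iso_level is negative
        isomin=min(0.8 * iso_level, 1.2 * iso_level),
        isomax=max(0.8 * iso_level, 1.2 * iso_level),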
        surface_count=3,
        colorscale='RdBu_r',
        colorbar=dict(title="Ricci R"),
        opacity=0.3,
        name="Curvature",
    ))
    
    # Add walkers
    fig.add_trace(go.Scatter3d(
        x=x_np[alive.cpu(), 0],
        y=x_np[alive.cpu(), 1],
        z=x_np[alive.cpu(), 2],
        mode='markers',
        marker=dict(size=3, color='black'),
        name="Walkers",
    ))
    
    fig.update_layout(
        title="Ricci Curvature Isosurface with Walkers",
        scene=dict(
            xaxis_title="x",
            yaxis_title="y",
            zaxis_title="z",
            aspectmode='cube',
        ),
        width=900,
        height=800,
    )
    
    return fig

# This is expensive - use low resolution for demo
print("⚠️  Warning: Isosurface computation is expensive. Using low resolution (20³ grid).")
print("   Increase grid_res for higher quality (but slower).\n")

fig_iso = create_curvature_isosurface(state, gas, grid_res=20)
fig_iso.show()

   Increase grid_res for higher quality (but slower).

Computing curvature on 8000 grid points...
Creating isosurface at R = 0.004


## 10. Real Physics Problem: Lennard-Jones Cluster Optimization

**Problem**: Find the minimum-energy configuration of N particles interacting via the Lennard-Jones potential.

**Lennard-Jones Potential**:
$$
V_{LJ}(r) = 4\epsilon \left[ \left(\frac{\sigma}{r}\right)^{12} - \left(\frac{\sigma}{r}\right)^{6} \right]
$$

**Total energy**:
$$
E = \sum_{i<j} V_{LJ}(\|x_i - x_j\|)
$$

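**Challenge**: The energy landscape has a very large number of local minima, and their number grows rapidly with N. Global minima for small N are well established (e.g., N = 13 → centred icosahedron, E ≈ -44.327ε).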

**Hypothesis**: The Ricci Gas can discover low-energy configurations by exploring negative curvature regions (saddle points connecting basins).
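Setting $dV_{LJ}/dr = 0$ gives $(\sigma/r)^6 = 1/2$, i.e. the pair equilibrium distance $r^* = 2^{1/6}\sigma \approx 1.122\sigma$, where $V_{LJ}(r^*) = 4\epsilon(1/4 - 1/2) = -\epsilon$. This is the natural length scale to compare `kde_bandwidth` against. A quick numerical check on a fine grid, in reduced units $\epsilon = \sigma = 1$ (a sketch; `lj_pair` is a hypothetical helper):

```python
import numpy as np

def lj_pair(r, epsilon=1.0, sigma=1.0):
    """Lennard-Jones pair potential V_LJ(r) = 4*eps*((sigma/r)**12 - (sigma/r)**6)."""
    sr6 = (sigma / r) ** 6
    return 4.0 * epsilon * (sr6 ** 2 - sr6)

r = np.linspace(0.95, 3.0, 200_001)   # grid spacing ~1e-5
V = lj_pair(r)
r_star = r[np.argmin(V)]
print(f"numerical r* = {r_star:.5f}   analytic 2**(1/6) = {2 ** (1 / 6):.5f}")
print(f"V(r*) = {V.min():.6f}   expected -epsilon = -1")
```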

In [12]:
def lennard_jones_energy(x, epsilon=1.0, sigma=1.0):
    """
    Compute Lennard-Jones energy for a set of particle positions.
    
    Args:
        x: [N, 3] particle positions
        epsilon: Energy scale
        sigma: Length scale
    
    Returns:
        E: Total energy (scalar)
        E_per_pair: [N, N] pairwise energies
    """
    N = len(x)
    
    # Pairwise distances [N, N]
    diff = x.unsqueeze(0) - x.unsqueeze(1)  # [N, N, 3]
    r = diff.norm(dim=-1)  # [N, N]
    
    # Avoid self-interaction
    r = r + torch.eye(N, device=x.device) * 1e10
    
    # Lennard-Jones potential
    r6 = (sigma / r) ** 6
    r12 = r6 ** 2
    
    V_pair = 4 * epsilon * (r12 - r6)
    
    # Total energy (sum over upper triangle to avoid double counting)
    mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
    E = V_pair[mask].sum()
    
    return E, V_pair

def lennard_jones_force(x, epsilon=1.0, sigma=1.0):
    """
    Compute LJ force on each particle.
    
    Returns:
        F: [N, 3] forces
    """
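    # Work on a detached copy so the caller's tensor is not modified in place
    # and no autograd graph is chained across optimization steps.
    x = x.detach().requires_grad_(True)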
    E, _ = lennard_jones_energy(x, epsilon, sigma)
    
    F = -torch.autograd.grad(E, x)[0]
    
    return F

# Test
x_test = torch.randn(5, 3, device=device)
E_test, _ = lennard_jones_energy(x_test)
F_test = lennard_jones_force(x_test)

print(f"Lennard-Jones test:")
print(f"  5 particles: E = {E_test:.3f}")
print(f"  Force on particle 0: F = [{F_test[0,0]:.2f}, {F_test[0,1]:.2f}, {F_test[0,2]:.2f}]")

Lennard-Jones test:
  5 particles: E = 387.201
  Force on particle 0: F = [-3352.02, 5593.88, -2232.69]
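Because `lennard_jones_force` relies on autograd, a closed-form version is a useful cross-check. Differentiating $V_{LJ}$ gives the force on particle $i$ as $F_i = \sum_{j \ne i} 24\epsilon\,[2(\sigma/r_{ij})^{12} - (\sigma/r_{ij})^{6}]\,(x_i - x_j)/r_{ij}^2$. A NumPy sketch (the function name `lj_energy_forces` is hypothetical, not part of `ricci_gas.py`):

```python
import numpy as np

def lj_energy_forces(x, epsilon=1.0, sigma=1.0):
    """Closed-form Lennard-Jones energy and forces for positions x of shape [N, 3]."""
    i, j = np.triu_indices(len(x), k=1)        # each unordered pair (i < j) once
    rij = x[i] - x[j]                          # [P, 3] separation vectors
    r2 = np.sum(rij ** 2, axis=1)              # squared distances [P]
    sr6 = (sigma ** 2 / r2) ** 3               # (sigma / r)^6
    energy = np.sum(4.0 * epsilon * (sr6 ** 2 - sr6))
    # Force on i from j: 24*eps*(2*(sigma/r)^12 - (sigma/r)^6) / r^2 * (x_i - x_j)
    fij = (24.0 * epsilon * (2.0 * sr6 ** 2 - sr6) / r2)[:, None] * rij
    forces = np.zeros_like(x)
    np.add.at(forces, i, fij)                  # action on i ...
    np.add.at(forces, j, -fij)                 # ... equal and opposite reaction on j
    return energy, forces

# Usage on a small, well-separated configuration.
x = np.array([[0.0, 0.0, 0.0], [1.2, 0.0, 0.0], [0.0, 1.1, 0.0],
              [0.0, 0.0, 1.3], [1.1, 1.2, 0.9]])
E, F = lj_energy_forces(x)
print(f"E = {E:.4f}")
print("net force:", F.sum(axis=0))  # ~0 by Newton's third law
```

Summing each pair once and scattering with `np.add.at` keeps the cost at one evaluation per pair and makes momentum conservation hold by construction.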


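### Run Ricci Gas on Lennard-Jones Optimization

Here each of the 13 atoms is treated as one walker: the Ricci force is computed from the KDE of the atom positions, added to the LJ force, and the sum drives a damped velocity update with additive Gaussian noise.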

In [13]:
# Initialize cluster
N_atoms = 13  # Classic LJ13 problem (known global min: icosahedron)

torch.manual_seed(123)
x_lj = torch.randn(N_atoms, 3, device=device) * 2.0  # Random initial configuration
v_lj = torch.zeros(N_atoms, 3, device=device)
s_lj = torch.ones(N_atoms, device=device)

state_lj = SwarmState(x=x_lj, v=v_lj, s=s_lj)

# Ricci Gas parameters for LJ optimization
params_lj = RicciGasParams(
    epsilon_R=0.3,           # Moderate curvature force
    kde_bandwidth=0.5,       # Smoothing length (~0.45× the LJ pair spacing 2^(1/6)σ ≈ 1.12)
    force_mode="pull",       # Aggregate toward high curvature
    reward_mode="inverse",   # Reward low curvature (exploration)
    R_crit=None,             # No singularity killing for LJ
)

# IMPORTANT: Pass device to RicciGas
gas_lj = RicciGas(params_lj, device=device)

print(f"Lennard-Jones Cluster Optimization: N = {N_atoms}")
print(f"  Initial energy: {lennard_jones_energy(x_lj)[0]:.3f}")
print(f"  Known global minimum (LJ13): E ≈ -44.327")
print(f"\nRunning Ricci Gas + LJ dynamics...\n")

Lennard-Jones Cluster Optimization: N = 13
  Initial energy: -1.268
  Known global minimum (LJ13): E ≈ -44.327

Running Ricci Gas + LJ dynamics...



In [14]:
# Run optimization
history_lj = []
T_lj = 500
dt_lj = 0.05
gamma_lj = 0.8

best_E = float('inf')
best_x = None

for t in range(T_lj):
    # Compute Ricci geometry
    R_lj, H_lj = gas_lj.compute_curvature(state_lj, cache=True)
    F_ricci = gas_lj.compute_force(state_lj)
    
    # Compute LJ forces
    F_lj = lennard_jones_force(state_lj.x)
    
    # Combined dynamics: LJ force + Ricci curvature force
    F_total = F_lj + F_ricci
    
    # Langevin update
    state_lj.v = gamma_lj * state_lj.v + (1 - gamma_lj) * F_total + torch.randn_like(state_lj.v, device=device) * 0.1
    state_lj.x = state_lj.x + state_lj.v * dt_lj
    
    # Compute energy
    E_current, _ = lennard_jones_energy(state_lj.x)
    
    if E_current < best_E:
        best_E = E_current.item()
        best_x = state_lj.x.clone()
    
    # Track statistics
    history_lj.append({
        't': t,
        'E': E_current.item(),
        'E_best': best_E,
        'R_mean': R_lj.mean().item(),
        'R_max': R_lj.max().item(),
    })
    
    if t % 100 == 0:
        print(f"  t={t:3d}: E={E_current:.3f}, E_best={best_E:.3f}, R_mean={R_lj.mean():.2f}")

print(f"\nOptimization complete!")
print(f"  Best energy found: {best_E:.4f}")
print(f"  Known global min:  -44.327")
print(f"  Gap: {best_E - (-44.327):.4f}")

  t=  0: E=-1.373, E_best=-1.373, R_mean=-0.31
  t=100: E=-9.443, E_best=-9.443, R_mean=-0.26
  t=200: E=-6.016, E_best=-11.072, R_mean=-0.29
  t=300: E=-7.021, E_best=-11.072, R_mean=-0.28
  t=400: E=-7.906, E_best=-11.072, R_mean=-0.28

Optimization complete!
  Best energy found: -11.0720
  Known global min:  -44.327
  Gap: 33.2550


### Visualize LJ Optimization Results

In [15]:
# Plot energy evolution
t_lj = [h['t'] for h in history_lj]
E_lj = [h['E'] for h in history_lj]
E_best_lj = [h['E_best'] for h in history_lj]

fig = go.Figure()
fig.add_trace(go.Scatter(x=t_lj, y=E_lj, mode='lines', name='Current E', line=dict(color='blue', width=1)))
fig.add_trace(go.Scatter(x=t_lj, y=E_best_lj, mode='lines', name='Best E', line=dict(color='red', width=2)))
fig.add_hline(y=-44.327, line_dash="dash", annotation_text="Global minimum", line_color="green")

fig.update_layout(
    title="Lennard-Jones Cluster Optimization (LJ13)",
    xaxis_title="Iteration",
    yaxis_title="Energy",
    width=900,
    height=500,
)
fig.show()

In [16]:
# Visualize best configuration
x_best_np = best_x.detach().cpu().numpy()

# Recompute the energy of the best configuration (pairwise energies in V_best)
E_best, V_best = lennard_jones_energy(best_x)

fig = go.Figure()

# Draw atoms
fig.add_trace(go.Scatter3d(
    x=x_best_np[:, 0],
    y=x_best_np[:, 1],
    z=x_best_np[:, 2],
    mode='markers',
    marker=dict(size=15, color='blue', opacity=0.8, line=dict(width=2, color='darkblue')),
    name="Atoms",
))

# Draw bonds (for nearest neighbors, roughly r < 1.5σ)
diff = best_x.unsqueeze(0) - best_x.unsqueeze(1)
dist = diff.norm(dim=-1).cpu().detach().numpy()

bond_threshold = 1.5  # Rough cutoff for visualization
for i in range(N_atoms):
    for j in range(i+1, N_atoms):
        if dist[i, j] < bond_threshold:
            fig.add_trace(go.Scatter3d(
                x=[x_best_np[i, 0], x_best_np[j, 0]],
                y=[x_best_np[i, 1], x_best_np[j, 1]],
                z=[x_best_np[i, 2], x_best_np[j, 2]],
                mode='lines',
                line=dict(color='gray', width=2),
                showlegend=False,
            ))

fig.update_layout(
    title=f"Best LJ13 Configuration Found (E = {best_E:.3f})",
    scene=dict(
        xaxis_title="x",
        yaxis_title="y",
        zaxis_title="z",
        aspectmode='cube',
    ),
    width=800,
    height=700,
)
fig.show()

print(f"\nStructure analysis:")
print(f"  Center of mass: [{x_best_np.mean(axis=0)[0]:.3f}, {x_best_np.mean(axis=0)[1]:.3f}, {x_best_np.mean(axis=0)[2]:.3f}]")
print(f"  Radius of gyration: {np.sqrt(((x_best_np - x_best_np.mean(axis=0))**2).sum(axis=1).mean()):.3f}")
print(f"  Min pairwise distance: {dist[dist > 0].min():.3f}")
print(f"  Max pairwise distance: {dist.max():.3f}")


Structure analysis:
  Center of mass: [-0.576, 0.178, -0.482]
  Radius of gyration: 2.683
  Min pairwise distance: 1.076
  Max pairwise distance: 7.261
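The global statistics above do not say how close the structure is to the icosahedral target. Per-atom neighbor counts with the same 1.5σ bond cutoff give a quick fingerprint: in a centred icosahedron the central atom has 12 neighbors and each surface atom has 6 (5 on the surface plus the centre), for 42 bonds in total. A sketch, where `ideal_icosahedral_cluster` is a hypothetical reference builder and its radius of 1.1 is an assumed spacing close to $2^{1/6}\sigma$, not the relaxed LJ13 geometry:

```python
import numpy as np

def coordination_numbers(x, cutoff=1.5):
    """Per-atom neighbor counts and bonded pairs (i < j) with distance below cutoff."""
    dist = np.linalg.norm(x[:, None, :] - x[None, :, :], axis=-1)
    i, j = np.triu_indices(len(x), k=1)
    bonded = dist[i, j] < cutoff
    counts = np.zeros(len(x), dtype=int)
    np.add.at(counts, i[bonded], 1)
    np.add.at(counts, j[bonded], 1)
    return counts, list(zip(i[bonded].tolist(), j[bonded].tolist()))

def ideal_icosahedral_cluster(radius=1.1):
    """Hypothetical reference: centre atom + 12 vertices of a regular icosahedron."""
    phi = (1.0 + np.sqrt(5.0)) / 2.0
    verts = []
    for a in (-1.0, 1.0):
        for b in (-1.0, 1.0):
            # cyclic permutations of (0, ±1, ±phi)
            verts += [(0.0, a, b * phi), (a, b * phi, 0.0), (a * phi, 0.0, b)]
    verts = np.array(verts)
    verts *= radius / np.linalg.norm(verts[0])  # all vertices share this norm
    return np.vstack([np.zeros(3), verts])

counts, bonds = coordination_numbers(ideal_icosahedral_cluster())
print(counts, len(bonds))
```

Calling `coordination_numbers(x_best_np)` on the best configuration from the cell above gives the corresponding counts to compare against this reference.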


## 11. Summary and Next Steps

### What We've Demonstrated

1. **Flat Space Visualization**: Walkers colored by Ricci curvature
2. **Emergent Manifold**: Metric tensor eigenvalues showing geometric anisotropy
3. **Curvature Field**: 3D isosurfaces of Ricci scalar
4. **Phase Dynamics**: Evolution of spatial variance, Ricci curvature, and alive-walker count
5. **Real Physics**: Lennard-Jones cluster optimization guided by curvature

### Key Observations

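- **Phase behavior**: Depending on `epsilon_R` (α), the swarm is expected either to stay diffuse or to collapse; the run above (α = 0.5) stayed in the diffuse, subcritical regime
- **Curvature guidance**: High curvature regions attract, low curvature regions disperse
- **LJ optimization**: The best energy found here (-11.07) is far above the LJ13 global minimum (-44.327), so this run does not yet show that the Ricci force helps escape local minima via saddle points; a run with `epsilon_R=0` is the natural control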

### Experimental Directions

1. **Vary α**: Re-run with different `epsilon_R` values to find the phase transition
2. **Compare variants**: Test the 4 ablation study variants (Ricci, Aligned, Force-only, Reward-only)
3. **Larger LJ clusters**: Try N=19, 38, 55 (LJ38, whose global minimum is a truncated octahedron rather than an icosahedral structure, is a classic hard case)
4. **Other physics problems**:
   - Protein folding (coarse-grained)
   - Rigid body packing
   - Molecular docking

### To Run More Experiments

```bash
# Full experimental suite
python experiments/ricci_gas_experiments.py --experiment all

# Or modify this notebook's parameters and re-run!
```

In [17]:
print("\n" + "="*60)
print("  Ricci Fragile Gas: Visualization Complete")
print("="*60)
print(f"\n📊 Generated visualizations:")
print(f"  ✓ Walkers in flat space")
print(f"  ✓ Evolution metrics (variance, curvature, alive fraction)")
print(f"  ✓ Emergent Riemannian manifold (metric eigenvalues)")
print(f"  ✓ Curvature isosurface")
print(f"  ✓ Lennard-Jones cluster optimization")
print(f"\n🔬 Physics problem: LJ{N_atoms} cluster")
print(f"  Best energy: {best_E:.4f}")
print(f"  Gap to global: {best_E - (-44.327):.4f}")
print(f"\n📖 Theory: docs/source/12_fractal_gas.md")
print(f"💻 Code: src/fragile/ricci_gas.py")
print(f"\n🚀 Next: Try varying epsilon_R to explore phase transition!")
print("="*60)


  Ricci Fragile Gas: Visualization Complete

📊 Generated visualizations:
  ✓ Walkers in flat space
  ✓ Evolution metrics (variance, curvature, alive fraction)
  ✓ Emergent Riemannian manifold (metric eigenvalues)
  ✓ Curvature isosurface
  ✓ Lennard-Jones cluster optimization

🔬 Physics problem: LJ13 cluster
  Best energy: -11.0720
  Gap to global: 33.2550

📖 Theory: docs/source/12_fractal_gas.md
💻 Code: src/fragile/ricci_gas.py

🚀 Next: Try varying epsilon_R to explore phase transition!
