# Gas Algorithms Comparison: Ricci vs Differentiable

This notebook compares two approaches to Fragile Gas optimization:

1. **Ricci Gas**: Geometry-driven exploration using KDE Hessian curvature
2. **Differentiable Gas**: End-to-end differentiable with Gumbel-softmax cloning
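
As a rough illustration of what "KDE Hessian curvature" can mean, the sketch below builds an unnormalized Gaussian kernel density estimate over the walkers and takes the trace of its Hessian at each walker as a scalar curvature proxy. The names `kde_density` and `curvature_proxy`, and the choice of the Hessian trace, are assumptions made for illustration; this notebook does not show how `RicciGas.compute_curvature` defines `R`.

```python
import torch

def kde_density(point: torch.Tensor, walkers: torch.Tensor, bandwidth: float) -> torch.Tensor:
    """Unnormalized Gaussian KDE evaluated at a single point of shape [d]."""
    sq_dist = ((point - walkers) ** 2).sum(dim=-1)  # [N]
    return torch.exp(-sq_dist / (2 * bandwidth**2)).mean()

def curvature_proxy(walkers: torch.Tensor, bandwidth: float = 0.5) -> torch.Tensor:
    """Trace of the KDE Hessian at each walker (illustrative proxy, not the RicciGas definition)."""
    traces = []
    for point in walkers:
        # Hessian of the density with respect to the evaluation point, shape [d, d]
        H = torch.autograd.functional.hessian(
            lambda p: kde_density(p, walkers, bandwidth), point
        )
        traces.append(torch.trace(H))
    return torch.stack(traces)

torch.manual_seed(0)
walkers = torch.randn(20, 2)
R = curvature_proxy(walkers)
print(R.shape)
```

The Python loop costs one Hessian per walker on top of the O(N) density evaluation, which is consistent with the O(N²) scaling listed for Ricci Gas in the comparison table.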

## Key Differences

| Feature | Ricci Gas | Differentiable Gas |
|---------|-----------|--------------------|
| Companion Selection | Deterministic (geometric) | Gumbel-softmax (stochastic) |
| Differentiability | Partial (KDE + Hessian) | Full (end-to-end) |
| Cloning | Hard (discrete) | Soft (weighted average) |
| Meta-optimization | Manual tuning | Gradient-based |
| Computational Cost | O(N²) for Hessian | O(N²) for logits |
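
The "soft cloning" row can be made concrete with a minimal sketch that uses `torch.nn.functional.gumbel_softmax`. The function `soft_clone` and the use of raw fitness as logits are assumptions for illustration; the actual logits and update rule inside `DifferentiableGas` are not shown in this notebook.

```python
import torch
import torch.nn.functional as F

def soft_clone(x: torch.Tensor, fitness: torch.Tensor, tau: float = 1.0):
    """Replace each walker by a Gumbel-softmax-weighted average of all walkers."""
    N = x.shape[0]
    logits = fitness.repeat(N, 1)  # [N, N]; row i holds walker i's logits over companions j
    weights = F.gumbel_softmax(logits, tau=tau, hard=False, dim=-1)  # each row sums to 1
    return weights @ x, weights

torch.manual_seed(0)
x = torch.randn(8, 2, requires_grad=True)
fitness = -(x**2).sum(dim=-1)  # higher fitness = lower sphere energy
x_new, weights = soft_clone(x, fitness, tau=0.5)
x_new.sum().backward()         # gradients reach x through the soft selection
print(x_new.shape, x.grad.shape)
```

Passing `hard=True` instead would return one-hot companion choices in the forward pass while still differentiating through the soft sample, which sits between the "hard" and "soft" cloning columns above.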

## Test Problems

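We define three potentials; the first two are run in the comparison below:
1. **Sphere**: Simple convex function
2. **Rastrigin**: Many local minima
3. **Lennard-Jones**: Physics-based clustering (defined below, but not included in the `problems` dictionary, so it is not benchmarked here)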
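
For reference, these are the three potentials as the code cell below implements them, with $d$ the dimension, $A = 10$, and $r_{ij} = \lVert x_i - x_j \rVert$:

$$
f_{\text{sphere}}(x) = \sum_{k=1}^{d} x_k^2
$$

$$
f_{\text{Rastrigin}}(x) = A\,d + \sum_{k=1}^{d} \left( x_k^2 - A \cos(2\pi x_k) \right)
$$

$$
V_{\text{LJ}}(x_1, \dots, x_N) = \sum_{i<j} 4\varepsilon \left[ \left( \frac{\sigma}{r_{ij}} \right)^{12} - \left( \frac{\sigma}{r_{ij}} \right)^{6} \right]
$$

Sphere and Rastrigin both attain their global minimum of 0 at the origin, and each Lennard-Jones pair term has its minimum of $-\varepsilon$ at $r_{ij} = 2^{1/6}\sigma$. Note that the first two map each walker to its own energy, while the Lennard-Jones energy is a single scalar for the whole configuration.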

In [None]:
import sys
import numpy as np
import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Set device
device = torch.device("cpu")  # Change to "cuda" if available

sys.path.insert(0, '..')

from src.fragile.ricci_gas import RicciGas, RicciGasParams, SwarmState as RicciState
from src.fragile.differentiable_gas import (
    DifferentiableGas,
    DifferentiableGasParams,
    SwarmState as DiffState,
)

print(f"✓ Using device: {device}")

## Define Test Problems

In [None]:
def sphere_potential(x: torch.Tensor) -> torch.Tensor:
    """Sphere function: f(x) = ||x||^2. Global minimum at origin."""
    return (x**2).sum(dim=-1)

def rastrigin_potential(x: torch.Tensor) -> torch.Tensor:
    """Rastrigin function: many local minima."""
    A = 10
    d = x.shape[-1]
    # torch.pi is full double precision and needs no device/dtype bookkeeping
    return A * d + (x**2 - A * torch.cos(2 * torch.pi * x)).sum(dim=-1)

def lennard_jones_potential(x: torch.Tensor) -> torch.Tensor:
    """Lennard-Jones energy of one configuration x of shape [N, d]; returns a scalar."""
    N = len(x)
    diff = x.unsqueeze(0) - x.unsqueeze(1)  # [N, N, d]
    r = diff.norm(dim=-1)  # [N, N]
    r = r + torch.eye(N, device=x.device) * 1e10  # Avoid self-interaction
    
    sigma = 1.0
    epsilon = 1.0
    r6 = (sigma / r) ** 6
    r12 = r6 ** 2
    V_pair = 4 * epsilon * (r12 - r6)
    
    mask = torch.triu(torch.ones(N, N, device=x.device), diagonal=1).bool()
    return V_pair[mask].sum()

problems = {
    "sphere": {"fn": sphere_potential, "bounds": (-5, 5), "d": 3, "N": 50},
    "rastrigin": {"fn": rastrigin_potential, "bounds": (-5, 5), "d": 2, "N": 60},
}

print("✓ Test problems defined")

## Setup Algorithms

In [None]:
def create_ricci_gas(problem_name, device):
    """Create a Ricci Gas. The same parameters are used for every problem (problem_name is unused)."""
    params = RicciGasParams(
        epsilon_R=0.3,
        kde_bandwidth=0.5,
        epsilon_Ric=0.01,
        force_mode="pull",
        reward_mode="inverse",
        gradient_clip=5.0,
    )
    return RicciGas(params, device=device)

def create_diff_gas(problem_name, device):
    """Create a Differentiable Gas. The same parameters are used for every problem (problem_name is unused)."""
    params = DifferentiableGasParams(
        selection_mode="combined",
        tau_init=2.0,
        tau_min=0.1,
        tau_anneal=0.995,
        clone_rate=0.15,
        elite_fraction=0.2,
        gamma=0.9,
        sigma_v=0.1,
    )
    return DifferentiableGas(params, device=device)

print("✓ Algorithm constructors ready")
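
To get a feel for the temperature parameters, the sketch below tabulates a schedule under the assumption that each step applies `tau <- max(tau_min, tau * tau_anneal)`. That update rule and the helper `tau_schedule` are guesses for illustration; the notebook does not show how `DifferentiableGas` anneals `tau`.

```python
def tau_schedule(tau_init: float, tau_min: float, tau_anneal: float, steps: int) -> list:
    """Temperature after 0..steps updates, assuming multiplicative decay floored at tau_min."""
    taus = [tau_init]
    for _ in range(steps):
        taus.append(max(tau_min, taus[-1] * tau_anneal))
    return taus

taus = tau_schedule(tau_init=2.0, tau_min=0.1, tau_anneal=0.995, steps=1000)
# First step at which the schedule has hit the floor
first_floor = next(t for t, tau in enumerate(taus) if tau <= 0.1)
print(f"tau after 100 steps: {taus[100]:.3f}")
print(f"steps until tau_min is reached: {first_floor}")
```

If the assumed rule is close to the real one, `tau` would still be about 1.2 after the 100 iterations used in this notebook, so the selection stays fairly soft throughout and `tau_min` never comes into play.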

## Run Comparison

In [None]:
def run_optimization(gas, problem, T=100, dt=0.1):
    """Run optimization and track performance."""
    history = []
    
    if isinstance(gas, RicciGas):
        # Ricci Gas
        N = problem["N"]
        d = problem["d"]
        bounds = problem["bounds"]
        
        torch.manual_seed(42)
        x = torch.rand(N, d, device=device) * (bounds[1] - bounds[0]) + bounds[0]
        v = torch.randn(N, d, device=device) * 0.1
        s = torch.ones(N, device=device)
        state = RicciState(x=x, v=v, s=s)
        
        for t in range(T):
            # Compute curvature
            R, H = gas.compute_curvature(state, cache=True)
            
            # Compute force
            force = gas.compute_force(state)
            
            # Langevin-style update: velocity damping 0.9, force gain 0.1, noise scale 0.05
            state.v = 0.9 * state.v + 0.1 * force + torch.randn_like(state.v, device=device) * 0.05
            state.x = state.x + state.v * dt
            
            # Compute potential
            V = problem["fn"](state.x)
            best_V = V.min().item()
            mean_V = V.mean().item()
            
            history.append({"t": t, "best": best_V, "mean": mean_V, "R_mean": R.mean().item()})
            
    else:
        # Differentiable Gas
        gas.potential = problem["fn"]
        torch.manual_seed(42)  # Seed as in the Ricci branch so the run is reproducible
        state = gas.init_state(N=problem["N"], d=problem["d"], bounds=problem["bounds"])
        
        for t in range(T):
            state = gas.step(state, dt=dt)
            
            # Reward = -potential
            best_V = -state.reward.max().item()
            mean_V = -state.reward.mean().item()
            
            history.append({"t": t, "best": best_V, "mean": mean_V, "tau": gas.tau})
    
    return history

print("Running comparison...")

results = {}
for prob_name, prob in problems.items():
    print(f"\n  {prob_name}:")
    
    # Ricci Gas
    gas_ricci = create_ricci_gas(prob_name, device)
    results[f"{prob_name}_ricci"] = run_optimization(gas_ricci, prob, T=100)
    print(f"    Ricci: best = {results[f'{prob_name}_ricci'][-1]['best']:.4f}")
    
    # Differentiable Gas
    gas_diff = create_diff_gas(prob_name, device)
    results[f"{prob_name}_diff"] = run_optimization(gas_diff, prob, T=100)
    print(f"    Diff:  best = {results[f'{prob_name}_diff'][-1]['best']:.4f}")

print("\n✓ Comparison complete")

## Visualize Results

In [None]:
for prob_name in problems.keys():
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=("Best Energy", "Mean Energy"),
    )
    
    # Ricci Gas
    hist_ricci = results[f"{prob_name}_ricci"]
    t_ricci = [h["t"] for h in hist_ricci]
    best_ricci = [h["best"] for h in hist_ricci]
    mean_ricci = [h["mean"] for h in hist_ricci]
    
    fig.add_trace(go.Scatter(
        x=t_ricci, y=best_ricci, name="Ricci Gas", line=dict(color="blue")
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=t_ricci, y=mean_ricci, name="Ricci Gas", line=dict(color="blue"), showlegend=False
    ), row=1, col=2)
    
    # Differentiable Gas
    hist_diff = results[f"{prob_name}_diff"]
    t_diff = [h["t"] for h in hist_diff]
    best_diff = [h["best"] for h in hist_diff]
    mean_diff = [h["mean"] for h in hist_diff]
    
    fig.add_trace(go.Scatter(
        x=t_diff, y=best_diff, name="Diff Gas", line=dict(color="red")
    ), row=1, col=1)
    fig.add_trace(go.Scatter(
        x=t_diff, y=mean_diff, name="Diff Gas", line=dict(color="red"), showlegend=False
    ), row=1, col=2)
    
    fig.update_xaxes(title_text="Iteration", row=1, col=1)
    fig.update_xaxes(title_text="Iteration", row=1, col=2)
    fig.update_yaxes(title_text="Energy", row=1, col=1)
    fig.update_yaxes(title_text="Energy", row=1, col=2)
    
    fig.update_layout(
        title=f"Comparison on {prob_name.capitalize()}",
        height=400,
        width=900,
    )
    
    fig.show()

## Summary

### Performance

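Both algorithms are expected to reach low-energy regions, but by different mechanisms:

- **Ricci Gas**: Exploits geometric structure via curvature. In the loop above, the potential is evaluated only after each update to record the energy; the walkers move under `gas.compute_force(state)`.
- **Differentiable Gas**: Balances fitness and diversity via Gumbel-softmax companion selection.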
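
One caveat when reading the numbers: `run_optimization` records the best energy of the current swarm at each iteration, and the printed summary uses only the last entry. A run that found a good point and then drifted away would therefore look worse than it was. A minimal sketch of a best-so-far reduction over the same list-of-dicts history format follows; the helper `best_so_far` and the example values are hypothetical.

```python
from itertools import accumulate

def best_so_far(history):
    """Running minimum of the per-iteration 'best' energy."""
    return list(accumulate((h["best"] for h in history), min))

# Hypothetical history in the same shape that run_optimization returns
example = [
    {"t": 0, "best": 3.0},
    {"t": 1, "best": 1.5},
    {"t": 2, "best": 2.0},
    {"t": 3, "best": 0.7},
]
print(best_so_far(example))  # [3.0, 1.5, 1.5, 0.7]
```

Plotting this curve next to the raw `best` trace shows whether an algorithm holds on to the good points it finds.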

### Differentiability Test

In [None]:
print("Testing end-to-end differentiability...\n")

# Setup
prob = problems["sphere"]
gas_diff = create_diff_gas("sphere", device)
gas_diff.potential = prob["fn"]

# Make the temperature a learnable leaf tensor
tau_param = torch.tensor(1.0, device=device, requires_grad=True)
gas_diff.tau = tau_param

state = gas_diff.init_state(N=30, d=3, bounds=(-3, 3))

# Short rollout
for _ in range(5):
    state = gas_diff.step(state, dt=0.1)

# Compute loss
loss = -state.reward.mean()  # Minimize potential = maximize reward

# Backprop
loss.backward()

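print(f"Loss: {loss.item():.4f}")
if tau_param.grad is None:
    # step() may have replaced gas.tau with a value detached from tau_param
    print("✗ No gradient reached tau_param; the rollout is not differentiable w.r.t. temperature.")
else:
    print(f"Temperature gradient: {tau_param.grad.item():.6f}")
    print("\n✓ Gradient flows through the entire rollout!")
    print("  This enables meta-optimization of hyperparameters.")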
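
A finite value from `backward()` shows that a gradient path exists, not that the gradient is correct. One way to check the latter is `torch.autograd.gradcheck`, which compares autograd against finite differences. The sketch below applies it to a toy rollout; `soft_rollout_loss` is a hypothetical stand-in and not `DifferentiableGas.step`, and it uses a deterministic softmax (no Gumbel noise) so that the check is repeatable.

```python
import torch

def soft_rollout_loss(tau: torch.Tensor, x0: torch.Tensor, steps: int = 3) -> torch.Tensor:
    """Mean sphere energy after a few toy soft-selection steps at temperature tau."""
    x = x0
    for _ in range(steps):
        reward = -(x**2).sum(dim=-1)                  # [N]
        weights = torch.softmax(reward / tau, dim=0)  # [N], sums to 1
        x = x + 0.5 * (weights @ x - x)               # move halfway to the weighted mean
    return (x**2).sum(dim=-1).mean()

torch.manual_seed(0)
# gradcheck needs double precision for its finite-difference comparison
x0 = torch.rand(10, 2, dtype=torch.float64) * 2 - 1
tau = torch.tensor(1.0, dtype=torch.float64, requires_grad=True)

ok = torch.autograd.gradcheck(lambda t: soft_rollout_loss(t, x0), (tau,), eps=1e-6, atol=1e-4)
print(f"gradcheck passed: {ok}")
```

With stochastic selection, the same check would need the random seed fixed inside the function so that every evaluation sees the same noise.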

## Conclusion

### When to use Ricci Gas:
- Geometric structure is important
- Exploring manifolds with varying curvature
- Physics-inspired problems

### When to use Differentiable Gas:
- Need end-to-end differentiability
- Want to meta-optimize hyperparameters
- Prefer soft, continuous operations
- Need gradient-based training of the optimization process itself
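
To illustrate what meta-optimizing a hyperparameter looks like in practice, here is a minimal sketch that tunes a temperature by gradient descent through a short rollout. The toy `soft_step`, the `log_tau` parameterization, and the optimizer settings are all assumptions for illustration; they are not the `DifferentiableGas` update.

```python
import torch

def sphere(x: torch.Tensor) -> torch.Tensor:
    return (x**2).sum(dim=-1)

def soft_step(x: torch.Tensor, tau: torch.Tensor) -> torch.Tensor:
    """Toy soft-cloning step: move each walker halfway to a softmax-weighted mean."""
    reward = -sphere(x)                           # [N]
    weights = torch.softmax(reward / tau, dim=0)  # [N]
    target = weights @ x                          # [d]
    return x + 0.5 * (target - x)

torch.manual_seed(0)
x0 = torch.rand(30, 3) * 6 - 3
log_tau = torch.tensor(0.0, requires_grad=True)  # tau = exp(log_tau) stays positive
opt = torch.optim.Adam([log_tau], lr=0.1)

losses = []
for _ in range(50):
    opt.zero_grad()
    x = x0
    for _ in range(5):                # short differentiable rollout
        x = soft_step(x, log_tau.exp())
    loss = sphere(x).mean()           # final mean energy is the meta-objective
    loss.backward()                   # gradient with respect to log_tau
    opt.step()
    losses.append(loss.item())

print(f"tau: 1.000 -> {log_tau.exp().item():.3f}")
print(f"loss: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Whether the tuned temperature ends up lower or higher depends on the landscape and the initial swarm, which is the reason to learn it rather than fix it by hand.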

### Future Work:
1. Hybrid approach: Ricci curvature as logits for Gumbel-softmax
2. Learned temperature schedules
3. Neural network parameterization of selection logits
4. Multi-objective optimization with Pareto fronts