# Manifold Visualizer Demo

## Why Geometry Matters for Machine Learning

When we optimize neural networks or analyze data, we implicitly work in **parameter spaces** that have intrinsic geometry. Standard gradient descent assumes Euclidean geometry (flat space), but this is often wrong:

- **Covariance matrices** live on a curved manifold (SPD cone)
- **Probability distributions** form a statistical manifold with Fisher Information as the metric
- **Directed relationships** are naturally modeled by asymmetric (Finsler) metrics

This notebook visualizes these geometric structures to build intuition for:
1. Why Riemannian methods can outperform Euclidean ones
2. How asymmetry in data manifests geometrically  
3. What "natural gradient" means visually

## Setup

```bash
pip install plotly ipywidgets
```


In [1]:
import numpy as np
import sys
sys.path.insert(0, '..')

from diffgeo.viz import ManifoldVisualizer, SPDViz, FinslerViz

# Create main visualizer
viz = ManifoldVisualizer()


## 1. SPD Manifold Visualization

### The Geometry of Covariance Matrices

A **Symmetric Positive Definite (SPD)** matrix $A$ satisfies:
- $A = A^T$ (symmetric)
- $x^T A x > 0$ for all $x \neq 0$ (positive definite)

**Key insight**: Each SPD matrix defines an ellipsoid — the set $\{x : x^T A^{-1} x = 1\}$.

The eigenvalues $\lambda_i$ determine the ellipsoid's shape:
- **Principal axes** align with eigenvectors
- **Semi-axis lengths** are $\sqrt{\lambda_i}$, where $\lambda_i$ are the eigenvalues of $A$
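As a quick numerical check of this picture, here is a minimal NumPy sketch (the random matrix is illustrative, built as $LL^T + 0.5I$): scaling each eigenvector by $\sqrt{\lambda_i}$ gives points that lie exactly on $\{x : x^T A^{-1} x = 1\}$, and so does the image of any unit vector under $V\,\mathrm{diag}(\sqrt{\lambda})$.

```python
import numpy as np

rng = np.random.default_rng(0)
L = rng.normal(size=(3, 3))
A = L @ L.T + 0.5 * np.eye(3)          # SPD by construction

lam, V = np.linalg.eigh(A)             # columns of V are the principal axes
semi_axes = V * np.sqrt(lam)           # column i is sqrt(lam_i) * v_i

# Each semi-axis endpoint lies on the ellipsoid x^T A^{-1} x = 1
A_inv = np.linalg.inv(A)
on_surface = [e @ A_inv @ e for e in semi_axes.T]

# So does x = V diag(sqrt(lam)) u for any unit vector u
u = rng.normal(size=(100, 3))
u /= np.linalg.norm(u, axis=1, keepdims=True)
X = u @ semi_axes.T                    # row k is semi_axes @ u[k]
quad = np.einsum("ki,ij,kj->k", X, A_inv, X)

print(np.round(on_surface, 12))
print(np.abs(quad - 1).max())
```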

### Why This Matters

SPD matrices appear everywhere:
- **EEG/BCI**: Spatial covariance of electrode signals → brain state classification
- **DTI**: Diffusion tensors in brain imaging → nerve fiber tracking  
- **Finance**: Covariance of asset returns → portfolio optimization

The **affine-invariant Riemannian distance** on SPD matrices is:
$$d_R(A, B) = \left\| \log(A^{-1/2} B A^{-1/2}) \right\|_F = \sqrt{\sum_i \log^2 \lambda_i(A, B)}$$

where $\lambda_i(A,B)$ are generalized eigenvalues from $\det(B - \lambda A) = 0$.
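Both forms of $d_R$ can be checked against each other in a few lines. The sketch below uses only NumPy; `sym_fn`, `spd_dist_log`, `spd_dist_geig`, and `spd_geodesic` are illustrative helpers, not part of `diffgeo.viz`. It also uses the standard closed-form geodesic for this metric, $\gamma(t) = A^{1/2}(A^{-1/2} B A^{-1/2})^t A^{1/2}$, which covers a fraction $t$ of the distance from $A$ to $B$.

```python
import numpy as np

def sym_fn(S, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    S = 0.5 * (S + S.T)                         # remove round-off asymmetry
    w, V = np.linalg.eigh(S)
    return (V * fn(w)) @ V.T                    # V diag(fn(w)) V^T

def spd_dist_log(P, Q):
    """d_R(P, Q) = || log(P^{-1/2} Q P^{-1/2}) ||_F"""
    P_isqrt = sym_fn(P, lambda w: 1.0 / np.sqrt(w))
    return np.linalg.norm(sym_fn(P_isqrt @ Q @ P_isqrt, np.log), "fro")

def spd_dist_geig(P, Q):
    """d_R(P, Q) = sqrt(sum_i log^2 lambda_i), with det(Q - lambda P) = 0."""
    lam = np.linalg.eigvals(np.linalg.solve(P, Q)).real   # eigenvalues of P^{-1} Q
    return np.sqrt(np.sum(np.log(lam) ** 2))

def spd_geodesic(P, Q, t):
    """gamma(t) = P^{1/2} (P^{-1/2} Q P^{-1/2})^t P^{1/2}; gamma(0) = P, gamma(1) = Q."""
    P_sqrt = sym_fn(P, np.sqrt)
    P_isqrt = sym_fn(P, lambda w: 1.0 / np.sqrt(w))
    return P_sqrt @ sym_fn(P_isqrt @ Q @ P_isqrt, lambda w: w ** t) @ P_sqrt

rng = np.random.default_rng(0)

def random_spd(n):
    L = rng.normal(size=(n, n))
    return L @ L.T + 0.5 * np.eye(n)

P, Q = random_spd(3), random_spd(3)
G = rng.normal(size=(3, 3)) + 3 * np.eye(3)     # an (almost surely) invertible matrix

d_log, d_geig = spd_dist_log(P, Q), spd_dist_geig(P, Q)
print(d_log, d_geig)                               # the two formulas agree
print(spd_dist_log(G @ P @ G.T, G @ Q @ G.T))      # affine invariance: same value again
print(spd_dist_log(P, spd_geodesic(P, Q, 0.25)))   # a quarter of d_log
```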


In [2]:
# Generate sample covariance matrices
np.random.seed(42)

matrices = []
for i in range(4):
    # Random SPD matrix: L @ L.T is positive semi-definite; adding 0.5 * I makes it
    # strictly positive definite (all eigenvalues >= 0.5)
    L = np.random.randn(3, 3)
    mat = L @ L.T + 0.5 * np.eye(3)
    matrices.append(mat)

print("Generated 4 random SPD matrices")
print(f"Sample eigenvalues: {np.linalg.eigvalsh(matrices[0])}")


Generated 4 random SPD matrices
Sample eigenvalues: [0.78320384 1.28854214 5.84613967]


In [3]:
# Visualize as ellipsoids
fig = viz.explore_spd(matrices, mode='ellipsoids')
fig.show()


**What you're seeing above**: Each ellipsoid represents a covariance matrix. The arrows show principal axes (eigenvectors) with lengths proportional to $\sqrt{\lambda_i}$. Rotate the view to see the 3D structure!

- Elongated ellipsoids = high anisotropy (strong directional correlation)
- Spherical ellipsoids = isotropic (equal variance in all directions)


### Euclidean vs Riemannian Mean: The Swelling Problem

**Euclidean mean** (what you'd normally compute):
$$\bar{A}_{Euc} = \frac{1}{n}\sum_{i=1}^n A_i$$

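**Problem**: This causes **swelling**. Because $\log\det$ is concave on SPD matrices, $\det(\bar{A}_{Euc})$ is at least the *geometric* mean of the individual determinants, and strictly larger unless all the $A_i$ are equal. In DTI, this inflates the averaged diffusion tensor in a non-physical way and blurs the anisotropy that encodes nerve fiber direction.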

**Riemannian (Fréchet) mean** respects the manifold geometry:
$$\bar{A}_{Riem} = \arg\min_{A \in \text{SPD}} \sum_{i=1}^n d_R^2(A, A_i)$$

This is computed iteratively:
1. Project all matrices to tangent space at current estimate
2. Average in tangent space (which IS flat)
3. Map back to manifold via exponential map
4. Repeat until convergence

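**Result**: The determinant of the Riemannian mean equals the geometric mean of the input determinants, so it does not swell, and it tends to preserve anisotropy; in that sense it is geometrically "correct"!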
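The four steps above can be sketched in NumPy as follows. This is a minimal illustration with unit step size and no safeguards, not `diffgeo`'s implementation; `karcher_mean` and `sym_fn` are illustrative names. It also exposes the determinant property: the trace of the tangent-space average $T$ equals $\overline{\log\det A_i} - \log\det X$, so after every exponential-map step $\det X$ is exactly the geometric mean of the input determinants, while the Euclidean mean's determinant is larger.

```python
import numpy as np

def sym_fn(S, fn):
    """Apply a scalar function to a symmetric matrix via its eigendecomposition."""
    S = 0.5 * (S + S.T)
    w, V = np.linalg.eigh(S)
    return (V * fn(w)) @ V.T

def karcher_mean(mats, tol=1e-10, max_iter=200):
    """Affine-invariant Frechet (Karcher) mean by fixed-point iteration."""
    X = np.mean(mats, axis=0)                       # initial guess: Euclidean mean
    for _ in range(max_iter):
        X_sqrt = sym_fn(X, np.sqrt)
        X_isqrt = sym_fn(X, lambda w: 1.0 / np.sqrt(w))
        # Steps 1-2: log-map each matrix to the tangent space at X, then average
        T = np.mean([sym_fn(X_isqrt @ M @ X_isqrt, np.log) for M in mats], axis=0)
        # Step 3: exponential map back onto the manifold
        X = X_sqrt @ sym_fn(T, np.exp) @ X_sqrt
        # Step 4: stop once the tangent-space average is numerically zero
        if np.linalg.norm(T) < tol:
            break
    return X

rng = np.random.default_rng(42)
mats = []
for _ in range(4):
    L = rng.normal(size=(3, 3))
    mats.append(L @ L.T + 0.5 * np.eye(3))

euc = np.mean(mats, axis=0)
riem = karcher_mean(mats)
geo_mean_det = np.exp(np.mean([np.log(np.linalg.det(M)) for M in mats]))

print("det(Euclidean mean):   ", np.linalg.det(euc))
print("det(Riemannian mean):  ", np.linalg.det(riem))
print("geometric mean of dets:", geo_mean_det)
```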


In [4]:
# Compare means - notice the swelling effect!
fig = viz.explore_spd(matrices, mode='means')
fig.show()


In [5]:
# Geodesic interpolation between two matrices
fig = viz.explore_spd(matrices[:2], mode='geodesic')
fig.show()


**Interpreting the means comparison (`mode='means'`) above**:
- **Left panel**: The red ellipsoid (Euclidean mean) is visibly inflated relative to the inputs: this is the swelling effect!
- **Right panel**: The green ellipsoid (Riemannian mean) stays geometrically consistent with the inputs

Check the determinant values below each panel. The Euclidean mean's determinant exceeds the geometric mean of the input determinants, while the Riemannian mean's determinant equals it.


## 2. Finsler (Randers) Metric Visualization

### Beyond Riemannian: When Distances Are Asymmetric

**Riemannian metrics** are symmetric: $d(A, B) = d(B, A)$. But many real systems have **directional bias**:
- Social networks: Following vs being followed
- Thermodynamics: Entropy production (time's arrow)
- Causal graphs: Cause → effect, not effect → cause

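**Finsler geometry** generalizes Riemannian geometry by letting the norm on each tangent space be any smooth, strongly convex norm rather than one induced by an inner product. Such a norm need not be reversible, so $F(v) \neq F(-v)$ is allowed.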

### The Randers Metric

The **Randers metric** is the simplest non-trivial Finsler structure:
$$F(v) = \sqrt{v^T A v} + b^T v$$

Components:
- $A$: Symmetric positive-definite matrix (the "Riemannian core")
- $b$: Drift vector (the "wind")

**Physical interpretation**: Imagine walking on a windy day:
- $\sqrt{v^T A v}$: Your walking speed (determined by terrain $A$)
- $b^T v$: Wind resistance ($b^T v > 0$, extra cost) or assistance ($b^T v < 0$, reduced cost)

**Constraint**: For strong convexity, $|b|_A = \sqrt{b^T A^{-1} b} < 1$ (wind can't be faster than you walk).

### The Indicatrix

The **indicatrix** is the unit ball $\{v : F(v) = 1\}$:
- For Riemannian ($b=0$): Centered ellipsoid
- For Randers ($b \neq 0$): **Off-center** ellipsoid shifted opposite to drift
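The off-center shape can be made precise. Squaring $\sqrt{v^T A v} = 1 - b^T v$ gives $(v - c)^T M (v - c) = 1 + b^T M^{-1} b$ with $M = A - bb^T$ and center $c = -M^{-1} b$; $M$ is positive definite exactly when $|b|_A < 1$, so the indicatrix is a genuine ellipsoid centered at $c$. The NumPy sketch below (helper names are illustrative) samples that ellipsoid for the same $A$ and $b$ used in the next cell and checks that every sampled point has $F(v) = 1$ and that $b^T c < 0$.

```python
import numpy as np

A = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.5, 0.3],
              [0.0, 0.3, 1.0]])
b = np.array([0.4, 0.2, 0.0])

def randers_F(v):
    """Randers norm F(v) = sqrt(v^T A v) + b^T v."""
    return np.sqrt(v @ A @ v) + b @ v

M = A - np.outer(b, b)                      # positive definite since |b|_A < 1
c = -np.linalg.solve(M, b)                  # center of the indicatrix
r2 = 1.0 + b @ np.linalg.solve(M, b)        # squared "radius" in the M-norm

# Sample the ellipsoid (v - c)^T M (v - c) = r2 as v = c + sqrt(r2) M^{-1/2} u
w, V = np.linalg.eigh(M)
M_isqrt = (V / np.sqrt(w)) @ V.T
rng = np.random.default_rng(0)
u = rng.normal(size=(500, 3))
u /= np.linalg.norm(u, axis=1, keepdims=True)
points = c + np.sqrt(r2) * u @ M_isqrt      # M_isqrt is symmetric

F_vals = np.array([randers_F(v) for v in points])
print("center c =", c, " b . c =", b @ c)   # b . c < 0: shifted opposite to b
print("max |F(v) - 1| =", np.abs(F_vals - 1).max())
```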


In [6]:
# Create a Randers metric
A = np.array([
    [2.0, 0.5, 0.0],
    [0.5, 1.5, 0.3],
    [0.0, 0.3, 1.0]
])

# Drift vector - must satisfy |b|_A < 1 for strong convexity
b = np.array([0.4, 0.2, 0.0])

# Check validity
A_inv = np.linalg.inv(A)
b_norm_A = np.sqrt(b @ A_inv @ b)
print(f"Drift norm |b|_A = {b_norm_A:.3f} (must be < 1)")


Drift norm |b|_A = 0.296 (must be < 1)


In [7]:
# Visualize the indicatrix (unit ball)
# Notice how it's NOT centered at the origin due to asymmetry!
fig = viz.explore_finsler(A, b, mode='indicatrix')
fig.show()


In [8]:
# Asymmetry analysis - compare F(v) vs F(-v)
fig = viz.explore_finsler(A, b, mode='asymmetry')
fig.show()


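**Interpreting the indicatrix plot (`mode='indicatrix'`) above**:
- The **colored surface** is the Randers indicatrix: notice it's NOT centered at the origin!
- The **gray surface** is the Riemannian indicatrix ($b=0$) for comparison: centered and symmetric
- The **red arrow** shows the drift vector $b$

The indicatrix shifts **opposite** to the drift. Because $F(v)$ adds $b^T v$, a step along $b$ costs more than an equally long step along $-b$, so the unit ball reaches farther in the $-b$ direction. In the wind analogy the wind blows toward $-b$: moving "with the wind" means moving against $b$, and that is the direction in which you get furthest per unit cost.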


In [9]:
# Geodesic asymmetry - forward vs backward paths have different costs
fig = viz.explore_finsler(A, b, mode='geodesic')
fig.show()


**Understanding the asymmetry plot (`mode='asymmetry'`) above**:
- **Left (3D)**: Points on a sphere colored by asymmetry ratio $F(v)/F(-v)$
  - Blue = cheaper forward: $F(v) < F(-v)$, which happens when $b^T v < 0$ (pointing against the drift)
  - Red = cheaper backward: $F(v) > F(-v)$, which happens when $b^T v > 0$
- **Right (histogram)**: Distribution of asymmetry ratios
  - Red dashed line at 1.0 = symmetric
  - Spread away from 1.0 shows degree of asymmetry

For symmetric metrics (Riemannian), all points would be exactly at ratio = 1.
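The spread in that histogram is bounded. Writing $\beta = |b|_A$, Cauchy-Schwarz in the $A$-inner product gives $|b^T v| \le \beta \sqrt{v^T A v}$, so $F(v)/F(-v)$ always lies in $\left[\frac{1-\beta}{1+\beta}, \frac{1+\beta}{1-\beta}\right]$, with the maximum attained at $v \propto A^{-1} b$. A short NumPy check on randomly sampled directions, using the same $A$ and $b$ as above:

```python
import numpy as np

A = np.array([[2.0, 0.5, 0.0],
              [0.5, 1.5, 0.3],
              [0.0, 0.3, 1.0]])
b = np.array([0.4, 0.2, 0.0])
beta = np.sqrt(b @ np.linalg.solve(A, b))        # |b|_A

def F(v):
    """Randers norm F(v) = sqrt(v^T A v) + b^T v."""
    return np.sqrt(v @ A @ v) + b @ v

rng = np.random.default_rng(0)
dirs = rng.normal(size=(2000, 3))
dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
ratios = np.array([F(v) / F(-v) for v in dirs])

lo_bound, hi_bound = (1 - beta) / (1 + beta), (1 + beta) / (1 - beta)
v_star = np.linalg.solve(A, b)                   # direction where the ratio peaks

print("sampled ratio range:", ratios.min(), ratios.max())
print("theoretical bounds: ", lo_bound, hi_bound)
print("ratio at A^{-1} b:  ", F(v_star) / F(-v_star))
```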


## 3. Fisher Information Geometry

### The Natural Metric on Parameter Space

For a parametric family $p(x|\theta)$, the **Fisher Information Matrix** is:
$$F_{ij}(\theta) = \mathbb{E}\left[\frac{\partial \log p(x|\theta)}{\partial \theta_i} \frac{\partial \log p(x|\theta)}{\partial \theta_j}\right]$$

This measures how much information observations carry about parameters.
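For a concrete case, take a Gaussian $\mathcal{N}(\mu, \sigma^2)$ with $\theta = (\mu, \sigma)$, whose Fisher matrix is known in closed form: $F = \mathrm{diag}(1/\sigma^2,\ 2/\sigma^2)$. The sketch below (parameter values chosen arbitrarily) estimates the expectation in the definition by Monte Carlo, averaging outer products of the score over samples, and compares it with that closed form.

```python
import numpy as np

mu, sigma = 1.0, 2.0
rng = np.random.default_rng(0)
x = rng.normal(mu, sigma, size=400_000)

# Score components d/dtheta log p(x | mu, sigma) for theta = (mu, sigma)
score_mu = (x - mu) / sigma**2
score_sigma = -1.0 / sigma + (x - mu) ** 2 / sigma**3
scores = np.stack([score_mu, score_sigma], axis=1)

F_mc = scores.T @ scores / len(x)        # Monte Carlo estimate of E[s s^T]
F_exact = np.diag([1 / sigma**2, 2 / sigma**2])

print(np.round(F_mc, 3))
print(F_exact)
```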

### Why It's a Metric

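The Fisher Information is the covariance of the score, so it is positive semi-definite (positive definite for identifiable models). By **Chentsov's theorem**, it is also, up to a constant factor, the **unique** Riemannian metric on statistical models that is invariant under sufficient statistics (Markov morphisms); the original theorem covers finite sample spaces. In addition, it is:
1. **Covariant under reparameterization**: $F$ transforms as a tensor, so lengths and angles don't depend on how you label parameters
2. **Intrinsic to the statistical model**: Derived from the likelihood itself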

### Stiff vs Sloppy Directions

In high-dimensional models (neural networks, biological systems), the Fisher matrix typically has:

- **Stiff directions** (large eigenvalues): Small parameter changes → big prediction changes. Data strongly constrains these.
- **Sloppy directions** (small eigenvalues): Large parameter changes → negligible prediction changes. "Don't care" directions.

**Eigenvalue spectrum** often spans 6+ orders of magnitude! The model manifold is a "hyper-ribbon" — long and thin.

### Natural Gradient

Standard gradient descent: $\theta \leftarrow \theta - \eta \nabla_\theta L$

**Natural gradient** accounts for the Fisher metric:
$$\theta \leftarrow \theta - \eta F^{-1} \nabla_\theta L$$

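This takes equal-length steps in **information space** (measured, to second order, by the KL divergence between successive distributions) rather than in parameter space, and makes the update invariant to reparameterization. On ill-conditioned problems it can converge much faster than plain gradient descent.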
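A toy comparison, under a simplifying assumption: linear regression with Gaussian noise of unit variance, where the (empirical) Fisher matrix of the likelihood is $X^T X / n$ and happens to equal the Hessian of the squared-error loss. In that special case natural gradient with $\eta = 1$ coincides with Newton's method and lands on the least-squares solution in one step, while plain gradient descent crawls along the sloppy directions. The feature scales are arbitrary choices that make the problem ill-conditioned.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d = 500, 5
scales = np.array([1.0, 3.0, 10.0, 30.0, 100.0])   # badly scaled features
X = rng.normal(size=(n, d)) * scales
theta_true = rng.normal(size=d)
y = X @ theta_true + 0.1 * rng.normal(size=n)

def loss(theta):
    return 0.5 * np.mean((X @ theta - y) ** 2)

def grad(theta):
    return X.T @ (X @ theta - y) / n

fisher = X.T @ X / n                                 # Fisher matrix (unit noise variance)
print("condition number:", np.linalg.cond(fisher))

eta_gd = 1.0 / np.linalg.eigvalsh(fisher).max()      # a stable GD step size
theta_gd = np.zeros(d)
theta_ng = np.zeros(d)
for _ in range(100):
    theta_gd = theta_gd - eta_gd * grad(theta_gd)
    theta_ng = theta_ng - np.linalg.solve(fisher, grad(theta_ng))   # eta = 1

theta_ls = np.linalg.lstsq(X, y, rcond=None)[0]
print("GD loss:", loss(theta_gd), " natural-gradient loss:", loss(theta_ng))
```

For real networks the Fisher matrix differs from the Hessian and is too large to invert directly, which is why practical methods approximate $F$.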


In [10]:
# Generate a "sloppy" Fisher matrix (wide eigenvalue spectrum)
dim = 10
condition_number = 1000

# Eigenvalues spanning several orders of magnitude
eigenvalues = 10 ** np.linspace(0, np.log10(condition_number), dim)

# Random orthogonal matrix
Q, _ = np.linalg.qr(np.random.randn(dim, dim))
fisher_matrix = Q @ np.diag(eigenvalues) @ Q.T

print(f"Fisher matrix condition number: {condition_number}")
print(f"Eigenvalue range: {eigenvalues.min():.2f} to {eigenvalues.max():.2f}")


Fisher matrix condition number: 1000
Eigenvalue range: 1.00 to 1000.00


In [11]:
# Visualize stiff vs sloppy directions
fig = viz.explore_fisher(fisher_matrix, threshold=0.01)
fig.show()


## 4. Interactive Exploration (with ipywidgets)

If you have `ipywidgets` installed, you can interactively adjust parameters!


In [12]:
# Interactive Randers metric explorer
# Sliders let you adjust drift strength and direction
viz.interactive_randers(A=np.eye(3), initial_drift=0.3)



**Interpreting the Fisher geometry plot (`explore_fisher`) above**:
- **Left (3D)**: Principal directions colored by eigenvalue magnitude
  - Longer arrows = stiffer directions (large eigenvalues)
  - These are well-constrained by data
- **Right (bar chart)**: Eigenvalue spectrum on log scale
  - Red bars = stiff (above threshold)
  - Blue bars = sloppy (below threshold)

Notice the eigenvalues span **3 orders of magnitude** (condition number 1000)! This is the "sloppy model" phenomenon common in complex systems: in real sloppy models, only a few stiff parameter combinations strongly affect predictions.


In [13]:
# Interactive SPD explorer
# Adjust eigenvalue ranges to see how ellipsoids change
viz.interactive_spd(n_matrices=4)



In [14]:
# Interactive Fisher explorer
# Adjust condition number to see stiff/sloppy transition
viz.interactive_fisher(dim=8)



## 5. Using with Your Actual Manifolds

Connect the visualizer to your learned geometries:


In [15]:
# Example: Visualize a Randers metric from diffgeo
try:
    from diffgeo.geometry import RandersMetric, make_randers_spd
    import jax

    key = jax.random.PRNGKey(42)
    metric = make_randers_spd(dim=3, key=key, drift_strength=0.4)

    A_np = np.array(metric.A)
    b_np = np.array(metric.b)

    # This reports the Euclidean norm ||b||_2, which can differ from the
    # A-norm |b|_A = sqrt(b^T A^{-1} b) used in the validity condition |b|_A < 1
    print(f"Created RandersMetric with drift strength {np.linalg.norm(b_np):.3f}")
    fig = viz.explore_finsler(A_np, b_np)
    fig.show()
except ImportError as e:
    print(f"diffgeo.geometry or JAX not available: {e}")




Created RandersMetric with drift strength 0.501


## Summary: The Geometric Perspective

### Key Takeaways

| Geometry | Structure | Key Insight | Application |
|----------|-----------|-------------|-------------|
| **SPD Manifold** | Covariance matrices | Euclidean mean causes swelling; Riemannian mean preserves geometry | BCI, DTI, portfolio optimization |
| **Finsler/Randers** | Asymmetric distances | Indicatrix shifts opposite to drift; $F(v) \neq F(-v)$ | Directed graphs, causal inference |
| **Fisher Information** | Parameter space metric | Stiff/sloppy decomposition; eigenvalues span orders of magnitude | Natural gradient, model compression |

### The Unifying Theme: Covariance

All three geometries are about **how things co-vary**:
- **SPD**: Statistical covariance of random variables
- **Finsler**: Directional covariance (how cost varies with direction)  
- **Fisher**: How predictions covary with parameters

### Practical Impact

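1. **Optimization**: Natural gradient (using the Fisher metric) can converge in far fewer iterations than vanilla SGD on ill-conditioned problems, at the cost of estimating and inverting $F$
2. **Classification**: Riemannian methods on SPD matrices are among the strongest classical approaches for BCI classification
3. **Embedding**: Finsler MDS can preserve asymmetric relationships that Euclidean MDS, which assumes symmetric dissimilarities, must discard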

The visualizations in this notebook make these abstract concepts **tangible** — you can see the swelling, the drift, the sloppy directions.


## 6. Tensor Variance: Vectors vs Covectors (Gradients)

### The Fundamental Distinction

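In differential geometry, there are two types of "vector-like" objects. Below, $J$ is the change-of-basis matrix: the new basis vectors are $e'_j = \sum_i e_i J_{ij}$.

| Type | Symbol | Transform Rule | Example |
|------|--------|----------------|---------|
| **Contravariant** (Vector) | $v^i$ | $v' = J^{-1} v$ | Velocity, displacement |
| **Covariant** (Covector) | $\alpha_i$ | $\alpha' = \alpha J$ (row vector) | Gradient, price |

The two rules fit together so that the pairing is basis-independent: $\alpha' v' = \alpha J J^{-1} v = \alpha v$.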

**Key insight for ML**: Gradients are **covectors**, not vectors! They live in the cotangent space $T^*M$.

To update weights (which ARE vectors in parameter space), we need a **metric** to convert:
$$\Delta w = g^{-1}(\nabla L)$$

This is **index raising** — the musical isomorphism ♯ (sharp).
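These transformation rules can be checked numerically. In the NumPy sketch below, $J$ is a random change-of-basis matrix (new basis $e'_j = \sum_i e_i J_{ij}$), the metric transforms as $g' = J^T g J$, and we verify three things: the pairing $\alpha v$ is basis-independent; raising the index with the metric, $g^{-1}\alpha$, produces something that transforms like a vector; and naively treating gradient components as vector components (i.e. using $g = I$ in every basis) does not.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 3
J = rng.normal(size=(n, n)) + 3 * np.eye(n)   # change-of-basis matrix (invertible)
J_inv = np.linalg.inv(J)

v = rng.normal(size=n)          # vector components (contravariant)
alpha = rng.normal(size=n)      # covector components (covariant), e.g. a gradient
g = np.diag([4.0, 1.0, 0.25])   # a metric in the old basis (SPD)

# Transform rules for e'_j = sum_i e_i J_ij
v_new = J_inv @ v
alpha_new = alpha @ J
g_new = J.T @ g @ J             # the metric is a (0, 2) tensor

# 1) The pairing alpha(v) does not depend on the basis
print(np.isclose(alpha_new @ v_new, alpha @ v))

# 2) Raising the index with the metric yields a genuine vector
raised_old = np.linalg.solve(g, alpha)          # g^{-1} alpha
raised_new = np.linalg.solve(g_new, alpha_new)  # g'^{-1} alpha'
print(np.allclose(raised_new, J_inv @ raised_old))

# 3) Reusing the raw gradient components as a vector is not basis-independent
print(np.allclose(alpha_new, J_inv @ alpha))
```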

### Visual Intuition

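Imagine stretching the $x$ basis vector to twice its length (equivalently, measuring $x$ in units twice as large):
- **Vectors** (arrows): their components shrink, varying *contrary* to the basis; hence "contravariant"
- **Covectors** (level sets of a function): their components grow *with* the basis, so in the new coordinates the level sets crowd closer together along $x$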


In [16]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np

def visualize_vector_covector_duality():
    """
    Visualize how vectors and covectors transform differently under coordinate change.
    """
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            "Original Coordinates",
            "Stretched Coordinates (2x horizontal)",
            "Gradient → Update (with metric)"
        ),
        specs=[[{"type": "scatter"}, {"type": "scatter"}, {"type": "scatter"}]]
    )
    
    # Original vector and covector
    vector = np.array([1, 1])  # Points northeast
    covector_normal = np.array([1, 0.5])  # Gradient direction
    
    # Change-of-basis matrix: the new x basis vector is twice as long (e'_x = 2 e_x),
    # so a point's new x-coordinate is half the old one (x' = x/2, y' = y)
    J = np.array([[2, 0], [0, 1]])
    J_inv = np.linalg.inv(J)
    
    # Transform rules:
    # Vector (contravariant): v' = J^{-1} v  (shrinks in x because coordinates stretched)
    # Covector (covariant): α' = α @ J  (stretches in x to compensate)
    
    vector_transformed = J_inv @ vector
    covector_transformed = covector_normal @ J
    
    # Normalize for display
    covector_norm = covector_normal / np.linalg.norm(covector_normal)
    covector_trans_norm = covector_transformed / np.linalg.norm(covector_transformed)
    
    # --- Panel 1: Original ---
    # Vector (arrow)
    fig.add_trace(go.Scatter(
        x=[0, vector[0]], y=[0, vector[1]],
        mode='lines+markers',
        line=dict(color='blue', width=3),
        marker=dict(size=[5, 12], symbol=['circle', 'triangle-up']),
        name='Vector v'
    ), row=1, col=1)
    
    # Covector (shown as perpendicular lines - level sets)
    for t in np.linspace(-0.5, 1.5, 8):
        # Level set: α · x = t
        x_line = np.linspace(-0.5, 2, 50)
        y_line = (t - covector_normal[0] * x_line) / (covector_normal[1] + 1e-8)
        mask = (y_line > -0.5) & (y_line < 2)
        fig.add_trace(go.Scatter(
            x=x_line[mask], y=y_line[mask],
            mode='lines', line=dict(color='red', width=1, dash='dash'),
            showlegend=False
        ), row=1, col=1)
    
    # Covector direction arrow
    fig.add_trace(go.Scatter(
        x=[0, covector_norm[0]], y=[0, covector_norm[1]],
        mode='lines+markers',
        line=dict(color='red', width=3),
        marker=dict(size=[5, 12], symbol=['circle', 'arrow']),
        name='Covector α (gradient)'
    ), row=1, col=1)
    
    # --- Panel 2: Transformed ---
    fig.add_trace(go.Scatter(
        x=[0, vector_transformed[0]], y=[0, vector_transformed[1]],
        mode='lines+markers',
        line=dict(color='blue', width=3),
        marker=dict(size=[5, 12], symbol=['circle', 'triangle-up']),
        name="v' = J⁻¹v (shrunk)",
        showlegend=True
    ), row=1, col=2)
    
    # Transformed level sets
    for t in np.linspace(-0.5, 1.5, 8):
        x_line = np.linspace(-0.5, 2, 50)
        y_line = (t - covector_transformed[0] * x_line) / (covector_transformed[1] + 1e-8)
        mask = (y_line > -0.5) & (y_line < 2)
        fig.add_trace(go.Scatter(
            x=x_line[mask], y=y_line[mask],
            mode='lines', line=dict(color='orange', width=1, dash='dash'),
            showlegend=False
        ), row=1, col=2)
    
    fig.add_trace(go.Scatter(
        x=[0, covector_trans_norm[0]], y=[0, covector_trans_norm[1]],
        mode='lines+markers',
        line=dict(color='orange', width=3),
        marker=dict(size=[5, 12], symbol=['circle', 'arrow']),
        name="α' = αJ (stretched)"
    ), row=1, col=2)
    
    # --- Panel 3: Metric dualization ---
    # Show gradient → update conversion with a metric
    gradient = np.array([1, 0.3])  # Raw gradient (covector)
    
    # Different metrics give different updates
    metric_euclidean = np.eye(2)
    metric_anisotropic = np.array([[4, 0], [0, 1]])  # Stretch in x
    
    update_euclidean = np.linalg.inv(metric_euclidean) @ gradient
    update_anisotropic = np.linalg.inv(metric_anisotropic) @ gradient
    
    # Gradient
    fig.add_trace(go.Scatter(
        x=[0, gradient[0]], y=[0, gradient[1]],
        mode='lines+markers',
        line=dict(color='red', width=3),
        marker=dict(size=[5, 12]),
        name='∇L (gradient/covector)'
    ), row=1, col=3)
    
    # Euclidean update
    fig.add_trace(go.Scatter(
        x=[0, update_euclidean[0]], y=[0, update_euclidean[1]],
        mode='lines+markers',
        line=dict(color='green', width=3),
        marker=dict(size=[5, 12]),
        name='Δw (Euclidean metric)'
    ), row=1, col=3)
    
    # Anisotropic update
    fig.add_trace(go.Scatter(
        x=[0, update_anisotropic[0]], y=[0, update_anisotropic[1]],
        mode='lines+markers',
        line=dict(color='purple', width=3),
        marker=dict(size=[5, 12]),
        name='Δw (anisotropic metric)'
    ), row=1, col=3)
    
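    # Draw the anisotropic metric's unit ball {x : x^T M x = 1}, scaled by 0.5 for display.
    # Mapping the unit circle through M^{-1/2} = V diag(1/sqrt(eigvals)) V^T puts each
    # semi-axis along its eigenvector (for M = diag(4, 1): 0.5 along x, 1 along y).
    theta = np.linspace(0, 2*np.pi, 100)
    eigvals, eigvecs = np.linalg.eigh(metric_anisotropic)
    unit_circle = np.vstack([np.cos(theta), np.sin(theta)])
    ellipse = 0.5 * eigvecs @ np.diag(1/np.sqrt(eigvals)) @ eigvecs.T @ unit_circle
    fig.add_trace(go.Scatter(
        x=ellipse[0], y=ellipse[1],
        mode='lines', line=dict(color='gray', width=1, dash='dot'),
        name='Metric ellipse'
    ), row=1, col=3)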
    
    fig.update_xaxes(range=[-0.5, 2], row=1, col=1)
    fig.update_yaxes(range=[-0.5, 2], row=1, col=1)
    fig.update_xaxes(range=[-0.5, 2], row=1, col=2)
    fig.update_yaxes(range=[-0.5, 2], row=1, col=2)
    fig.update_xaxes(range=[-0.5, 1.5], row=1, col=3)
    fig.update_yaxes(range=[-0.5, 1.5], row=1, col=3)
    
    fig.update_layout(
        height=400, width=1200,
        title="Tensor Variance: Vectors (contravariant) vs Covectors (covariant)",
        showlegend=True
    )
    
    return fig

fig = visualize_vector_covector_duality()
fig.show()


**Interpreting the panels above**:

1. **Left (Original)**: Blue vector and red covector (gradient) in standard coordinates
2. **Middle (Transformed)**: After stretching x by 2:
   - Vector **shrinks** in x (contravariant: transforms with $J^{-1}$)
   - Covector level sets become **closer** in x (covariant: transforms with $J$)
3. **Right (Dualization)**: Same gradient, different metrics → different update directions!
   - Green: Euclidean metric ($g = I$) → update = gradient
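   - Purple: Anisotropic metric ($g = \mathrm{diag}(4, 1)$) → the update shrinks along $x$, where the metric says motion is expensive, so it rotates toward $y$

This is the idea behind the **natural gradient**: with the Fisher metric as $g$, the update is the steepest-descent direction measured in distribution space rather than in raw parameter coordinates, which can speed up convergence considerably on ill-conditioned problems.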


## 7. Parity and Reflection: Even vs Odd Tensors

### What Is Parity?

**Parity** describes how a quantity transforms under **reflection** (coordinate flip with $\det(R) = -1$):

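| Parity | Transform Rule (reflection $R$) | Example | Physical Meaning |
|--------|----------------|---------|------------------|
| **EVEN (+1)** | $T' = R\,T$ (ordinary rule) | Position, velocity, electric field | "True" vectors/tensors |
| **ODD (-1)** | $T' = \det(R)\,R\,T = -R\,T$ | Angular momentum, magnetic field | Pseudovectors (axial vectors) |

(Shown for vectors; a rank-$k$ tensor gets one factor of $R$ per index, and an ODD one also picks up the extra $\det(R)$.)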

### The Mirror Test

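Place a mirror in the $yz$-plane, so the reflection is $x \to -x$:
- **Polar vector** (EVEN): A velocity pointing along $+x$ (toward the mirror) appears in the mirror pointing along $-x$, while components parallel to the mirror are unchanged. This is the "expected" reflection.
- **Pseudovector** (ODD): A top spinning about the $z$-axis (parallel to the mirror) has angular momentum along $+z$. Its mirror image spins the **opposite way**, so by the right-hand rule its angular momentum points along $-z$, even though simply mirroring the arrow would leave it along $+z$. The extra sign is $\det(R) = -1$.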
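Angular momentum $L = r \times p$ makes the ODD rule concrete: for any orthogonal $R$, $(Rr) \times (Rp) = \det(R)\, R\,(r \times p)$. The sketch below computes the angular momentum in the mirror world directly and checks which rule it obeys (the random $r$ and $p$ are arbitrary).

```python
import numpy as np

R = np.diag([-1.0, 1.0, 1.0])                   # mirror in the yz-plane: x -> -x
rng = np.random.default_rng(0)
r, p = rng.normal(size=3), rng.normal(size=3)   # position and momentum (polar vectors)

L = np.cross(r, p)                              # angular momentum
L_mirror = np.cross(R @ r, R @ p)               # angular momentum in the mirror world

print(np.allclose(L_mirror, R @ L))                      # ordinary (EVEN) rule
print(np.allclose(L_mirror, np.linalg.det(R) * R @ L))   # pseudovector (ODD) rule
```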

### Why This Matters: Chirality

**Chiral objects** are not superimposable on their mirror images:
- Left hand vs right hand
- L-amino acids vs D-amino acids
- Many drug molecules (one enantiomer heals, the other harms!)

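A network whose inputs are invariant under reflections (atom types and bonds, or interatomic distances) **cannot distinguish** enantiomers: they have the same atoms, bonds, and pairwise distances. Adding features with **odd parity**, which flip sign under reflection (like `TwistedEmbed`, which takes an explicit orientation sign), makes the distinction possible!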
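A small illustration of why reflection-invariant features fail: a molecule and its mirror image have identical interatomic distances, but the signed volume spanned by three bonds from a central atom (a triple product, i.e. a pseudoscalar) changes sign. The four-atom geometry below is hypothetical, chosen only so that the three bond lengths differ; `pairwise_distances` and `signed_volume` are illustrative helpers.

```python
import numpy as np

def pairwise_distances(X):
    """Matrix of Euclidean distances between all pairs of atoms (rows of X)."""
    diff = X[:, None, :] - X[None, :, :]
    return np.linalg.norm(diff, axis=-1)

def signed_volume(X):
    """Pseudoscalar: triple product of the three bond vectors from atom 0."""
    return np.linalg.det(X[1:4] - X[0])

# Hypothetical stereocenter (atom 0) with three different substituents
mol = np.array([[0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.5, 0.0],
                [0.0, 0.0, 2.0]])
mirror = mol * np.array([-1.0, 1.0, 1.0])   # reflect x -> -x

print(np.allclose(pairwise_distances(mol), pairwise_distances(mirror)))
print(signed_volume(mol), signed_volume(mirror))
```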


In [17]:
def visualize_parity_reflection():
    """
    Visualize how EVEN and ODD parity objects transform under reflection.
    """
    fig = make_subplots(
        rows=1, cols=3,
        subplot_titles=(
            "Original",
            "Reflected (EVEN parity)",
            "Reflected (ODD parity)"
        ),
        specs=[[{"type": "scatter3d"}, {"type": "scatter3d"}, {"type": "scatter3d"}]]
    )
    
    # Create a simple 3D object (like a spiral/helix - chiral!)
    t = np.linspace(0, 2*np.pi, 100)
    x = np.cos(t)
    y = np.sin(t)
    z = t / (2*np.pi)  # Rising helix
    
    # Add handedness indicator (arrow showing rotation direction)
    arrow_t = np.pi
    arrow_x = [np.cos(arrow_t), np.cos(arrow_t) - 0.3*np.sin(arrow_t)]
    arrow_y = [np.sin(arrow_t), np.sin(arrow_t) + 0.3*np.cos(arrow_t)]
    arrow_z = [arrow_t/(2*np.pi), arrow_t/(2*np.pi)]
    
    # Original helix
    fig.add_trace(go.Scatter3d(
        x=x, y=y, z=z,
        mode='lines',
        line=dict(color='blue', width=5),
        name='Right-handed helix'
    ), row=1, col=1)
    
    fig.add_trace(go.Scatter3d(
        x=arrow_x, y=arrow_y, z=arrow_z,
        mode='lines+markers',
        line=dict(color='red', width=3),
        marker=dict(size=[3, 8], symbol=['circle', 'diamond']),
        name='Rotation direction'
    ), row=1, col=1)
    
    # Reflection matrix (reflect across yz-plane: x → -x)
    # EVEN parity: just flip x
    x_reflected = -x
    arrow_x_reflected = [-ax for ax in arrow_x]
    
    fig.add_trace(go.Scatter3d(
        x=x_reflected, y=y, z=z,
        mode='lines',
        line=dict(color='green', width=5),
        name='Reflected helix (EVEN)'
    ), row=1, col=2)
    
    fig.add_trace(go.Scatter3d(
        x=arrow_x_reflected, y=arrow_y, z=arrow_z,
        mode='lines+markers',
        line=dict(color='red', width=3),
        marker=dict(size=[3, 8], symbol=['circle', 'diamond']),
        showlegend=False
    ), row=1, col=2)
    
    # Add text annotation
    fig.add_trace(go.Scatter3d(
        x=[0], y=[0], z=[1.3],
        mode='text',
        text=['LEFT-handed now!'],
        textfont=dict(size=12, color='green'),
        showlegend=False
    ), row=1, col=2)
    
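    # ODD parity: reflect, then apply the extra sign det(R) = -1.
    # A reflection composed with a full sign flip is a proper rotation
    # (determinant +1), so it cannot change handedness. For display we use the
    # proper rotation (x, y, z) -> (-x, y, -z), i.e. 180 degrees about the y-axis.
    x_odd = -x  # Reflect
    z_odd = -z + 1  # Flip z, shifted up by 1 to keep the helix in view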
    
    fig.add_trace(go.Scatter3d(
        x=x_odd, y=y, z=z_odd,
        mode='lines',
        line=dict(color='purple', width=5),
        name='Reflected (ODD) - sign flipped!'
    ), row=1, col=3)
    
    # Arrow for ODD - also negated
    arrow_z_odd = [1 - az for az in arrow_z]
    fig.add_trace(go.Scatter3d(
        x=arrow_x_reflected, y=arrow_y, z=arrow_z_odd,
        mode='lines+markers',
        line=dict(color='orange', width=3),
        marker=dict(size=[3, 8], symbol=['circle', 'diamond']),
        showlegend=False
    ), row=1, col=3)
    
    fig.add_trace(go.Scatter3d(
        x=[0], y=[0], z=[1.3],
        mode='text',
        text=['Still RIGHT-handed!'],
        textfont=dict(size=12, color='purple'),
        showlegend=False
    ), row=1, col=3)
    
    # Update layout
    camera = dict(eye=dict(x=1.5, y=1.5, z=0.8))
    for col in [1, 2, 3]:
        fig.update_scenes(
            camera=camera,
            xaxis=dict(range=[-1.5, 1.5]),
            yaxis=dict(range=[-1.5, 1.5]),
            zaxis=dict(range=[-0.2, 1.5]),
            row=1, col=col
        )
    
    fig.update_layout(
        height=500, width=1200,
        title="Parity: How EVEN vs ODD objects transform under reflection",
        showlegend=True
    )
    
    return fig

fig = visualize_parity_reflection()
fig.show()


**Understanding the visualization above**:

- **Left (Original)**: A right-handed helix (like DNA or a right-hand screw)
- **Middle (EVEN reflection)**: Simple mirror reflection — the helix becomes **left-handed**! This is what happens to standard vectors.
- **Right (ODD reflection)**: Reflection + sign flip — the helix remains **right-handed**! This is what pseudovectors do.

### Parity Composition: The Z₂ Group

When composing layers, parities multiply:

| Outer | Inner | Result | Rule |
|-------|-------|--------|------|
| EVEN (+1) | EVEN (+1) | EVEN (+1) | (+1)(+1) = +1 |
| EVEN (+1) | ODD (-1) | ODD (-1) | (+1)(-1) = -1 |
| ODD (-1) | EVEN (+1) | ODD (-1) | (-1)(+1) = -1 |
| **ODD (-1)** | **ODD (-1)** | **EVEN (+1)** | (-1)(-1) = +1 |

Two ODD factors cancel, just as two sign flips do! This is tracked automatically in `diffgeo`:

```python
twisted = TwistedEmbed(...)  # ODD
layer = GeometricLinear(..., parity=Parity.ODD)  # ODD
net = layer @ twisted  # ODD × ODD = EVEN
```
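The composition table can be seen with toy linear layers. In this sketch, which is hypothetical and not `diffgeo`'s API, an ODD layer multiplies its output by an orientation sign $s = \pm 1$ and an EVEN layer ignores $s$; `odd_layer`, `even_layer`, and `compose` are illustrative names.

```python
import numpy as np

rng = np.random.default_rng(0)
W1, W2 = rng.normal(size=(4, 3)), rng.normal(size=(2, 4))

def odd_layer(W, x, s):
    """Toy ODD layer: output flips sign with the orientation s."""
    return s * (W @ x)

def even_layer(W, x, s):
    """Toy EVEN layer: output ignores the orientation s."""
    return W @ x

def compose(outer, inner, x, s):
    return outer(W2, inner(W1, x, s), s)

x = rng.normal(size=3)
oo_plus, oo_minus = compose(odd_layer, odd_layer, x, 1.0), compose(odd_layer, odd_layer, x, -1.0)
eo_plus, eo_minus = compose(even_layer, odd_layer, x, 1.0), compose(even_layer, odd_layer, x, -1.0)

print(np.allclose(oo_plus, oo_minus))    # ODD x ODD = EVEN: orientation-independent
print(np.allclose(eo_plus, -eo_minus))   # EVEN x ODD = ODD: flips sign with s
```

The sign bookkeeping relies on the layers being linear (or using odd activations such as `tanh`); a ReLU between the layers would break it, since $\mathrm{ReLU}(-x) \neq -\mathrm{ReLU}(x)$.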


## 8. Finsler Dualization: Asymmetric Gradient Updates

### The Problem with Symmetric Metrics

Standard optimization uses symmetric metrics (Euclidean or Riemannian):
$$\Delta w = g^{-1} \nabla L$$

But for **directed data** (causal graphs, time series), it can make sense for the gradient→update conversion to be **asymmetric** too!

### Randers Metric Dualization

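For the Randers metric $F(v) = \sqrt{v^T A v} + b^T v$, a simple drift-corrected dualization is:

$$\Delta w = A^{-1}(\nabla L - b)$$

This is a simplification: the exact Legendre map of $F$ is nonlinear and sends a covector $\xi$ to a vector parallel to $A^{-1}(\xi - F^*(\xi)\, b)$, where $F^*$ is the dual norm, so the two agree in direction when $F^*(\nabla L) = 1$.

The drift vector $b$ **biases** the update direction:
- Every update is offset by the same vector $-A^{-1} b$, so the map is no longer odd: $\Delta w(-\nabla L) \neq -\Delta w(\nabla L)$
- Under $F$ itself, steps along $b$ cost more ($b^T v > 0$) than equally long steps along $-b$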

This is what `FinslerLinear` does — it learns both the weights AND the drift!
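To see when the rule $\Delta w = A^{-1}(\nabla L - b)$ matches the exact Randers duality, here is a 2D NumPy sketch (the particular $A$, $b$, and $v_0$ are arbitrary, and `legendre` / `parallel` are illustrative helpers). It maps a unit vector $v_0$ ($F(v_0) = 1$) to its dual covector $\xi = \partial_v \tfrac12 F^2$, which has dual norm $F^*(\xi) = 1$, then checks which formulas recover the direction of $v_0$.

```python
import numpy as np

A = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([0.3, -0.2])
assert b @ np.linalg.solve(A, b) < 1.0        # |b|_A < 1: a valid Randers metric

def F(v):
    return np.sqrt(v @ A @ v) + b @ v

def legendre(v):
    """Exact Legendre map xi = d(F^2 / 2)/dv = F(v) (A v / |v|_A + b)."""
    return F(v) * (A @ v / np.sqrt(v @ A @ v) + b)

def parallel(x, y):
    """True if x and y point in the same direction."""
    return np.isclose(x @ y, np.linalg.norm(x) * np.linalg.norm(y))

v0 = np.array([0.7, -0.4])
v0 = v0 / F(v0)                 # put v0 on the indicatrix: F(v0) = 1
xi = legendre(v0)               # dual covector with F*(xi) = 1

print(parallel(np.linalg.solve(A, xi - b), v0))           # simplified rule, unit xi
print(parallel(np.linalg.solve(A, 2 * xi - b), v0))       # simplified rule, scaled xi
print(parallel(np.linalg.solve(A, 2 * xi - 2 * b), v0))   # exact rule: F*(2 xi) = 2
```

The simplified rule recovers the exact direction only for unit-dual-norm gradients; for a scaled gradient, the drift term has to be scaled by $F^*(\nabla L)$ as well.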


In [18]:
def visualize_finsler_dualization():
    """
    Visualize how Finsler (Randers) metric creates asymmetric gradient→update mapping.
    """
    fig = make_subplots(
        rows=1, cols=2,
        subplot_titles=(
            "Euclidean Dualization (symmetric)",
            "Finsler Dualization (asymmetric drift)"
        ),
        specs=[[{"type": "scatter"}, {"type": "scatter"}]]
    )
    
    # Generate gradient vectors on a grid
    n_arrows = 8
    angles = np.linspace(0, 2*np.pi, n_arrows, endpoint=False)
    
    gradients = np.array([[np.cos(a), np.sin(a)] for a in angles])
    
    # Metrics
    A = np.eye(2)  # Euclidean base
    drift = np.array([0.5, 0.2])  # Randers drift
    
    # Euclidean dualization: Δw = A^{-1} @ grad = grad (since A=I)
    updates_euclidean = gradients.copy()
    
    # Finsler dualization: Δw = A^{-1} @ (grad - drift_strength * drift)
    drift_strength = 0.6
    updates_finsler = gradients - drift_strength * drift
    
    # Normalize for display
    scale = 0.4
    
    # --- Panel 1: Euclidean ---
    # Draw unit circle (isotropic cost)
    theta = np.linspace(0, 2*np.pi, 100)
    fig.add_trace(go.Scatter(
        x=np.cos(theta), y=np.sin(theta),
        mode='lines', line=dict(color='lightblue', width=2),
        name='Unit cost (Euclidean)'
    ), row=1, col=1)
    
    # Draw gradients and updates (same for Euclidean)
    for i, (g, u) in enumerate(zip(gradients, updates_euclidean)):
        # Gradient arrow
        fig.add_trace(go.Scatter(
            x=[0, g[0]*scale], y=[0, g[1]*scale],
            mode='lines', line=dict(color='red', width=2),
            showlegend=(i==0), name='∇L (gradient)'
        ), row=1, col=1)
        
        # Update arrow (same direction for Euclidean)
        fig.add_trace(go.Scatter(
            x=[0, u[0]*scale*0.9], y=[0, u[1]*scale*0.9],
            mode='lines', line=dict(color='green', width=3, dash='dash'),
            showlegend=(i==0), name='Δw (update)'
        ), row=1, col=1)
    
    # --- Panel 2: Finsler ---
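    # Schematic unit-cost curve for Randers F(v) = |v| + b·v = 1.
    # The exact indicatrix is an ellipse (not a circle) centered at -(I - b b^T)^{-1} b,
    # i.e. displaced opposite to the drift; here we only draw a unit circle shifted
    # by -0.3 * b to suggest that displacement.
    center_shift = -drift * 0.3  # Schematic shift, not the exact center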
    
    fig.add_trace(go.Scatter(
        x=np.cos(theta) + center_shift[0], y=np.sin(theta) + center_shift[1],
        mode='lines', line=dict(color='lightblue', width=2),
        name='Unit cost (Finsler)', showlegend=True
    ), row=1, col=2)
    
    # Draw original Euclidean circle for reference
    fig.add_trace(go.Scatter(
        x=np.cos(theta), y=np.sin(theta),
        mode='lines', line=dict(color='gray', width=1, dash='dot'),
        name='Euclidean reference', showlegend=True
    ), row=1, col=2)
    
    # Draw drift vector
    fig.add_trace(go.Scatter(
        x=[0, drift[0]*0.6], y=[0, drift[1]*0.6],
        mode='lines+markers',
        line=dict(color='purple', width=4),
        marker=dict(size=[5, 10], symbol=['circle', 'arrow-wide']),
        name='Drift b'
    ), row=1, col=2)
    
    # Draw gradients and shifted updates
    for i, (g, u) in enumerate(zip(gradients, updates_finsler)):
        # Gradient arrow
        fig.add_trace(go.Scatter(
            x=[0, g[0]*scale], y=[0, g[1]*scale],
            mode='lines', line=dict(color='red', width=2),
            showlegend=False
        ), row=1, col=2)
        
        # Finsler update (shifted by drift)
        u_norm = u / (np.linalg.norm(u) + 1e-8) * scale
        fig.add_trace(go.Scatter(
            x=[0, u_norm[0]*0.9], y=[0, u_norm[1]*0.9],
            mode='lines', line=dict(color='orange', width=3, dash='dash'),
            showlegend=(i==0), name='Δw (Finsler update)'
        ), row=1, col=2)
    
    fig.update_xaxes(range=[-1.2, 1.2], row=1, col=1)
    fig.update_yaxes(range=[-1.2, 1.2], row=1, col=1)
    fig.update_xaxes(range=[-1.2, 1.2], row=1, col=2)
    fig.update_yaxes(range=[-1.2, 1.2], row=1, col=2)
    
    fig.update_layout(
        height=500, width=1000,
        title="Finsler Dualization: Drift Biases Update Direction",
        showlegend=True
    )
    
    return fig

fig = visualize_finsler_dualization()
fig.show()


**Understanding the visualization**:

- **Left (Euclidean)**: Red gradient arrows and green update arrows point in the **same direction**. The metric is symmetric: cost is the same in all directions.
- **Right (Finsler)**: Each update is the gradient minus $0.6\,b$, so the orange update arrows are systematically tilted **away from** the drift vector (purple), toward $-b$.

Notice the unit-cost circle is **shifted** toward $-b$ in the Finsler case (drawn schematically); this visualizes the asymmetry. It lies closer to the origin in the direction of $b$: a shorter step along $b$ already costs one unit, so moving along $b$ is more expensive and moving against it is cheaper.

### When to Use Finsler Metrics

1. **Time series**: Past → future is fundamentally different from future → past
2. **Causal graphs**: Cause → effect, not symmetric
3. **Network flows**: Traffic, social influence, supply chains have directionality
4. **Thermodynamics**: Entropy increase is favored over decrease


## 9. TwistedEmbed in Action: Chirality Detection

Let's see `TwistedEmbed` actually working! We'll compare it against `GeometricEmbed` for distinguishing chiral molecules.

### The Setup

- **Input**: Molecule token indices (same for both enantiomers)
- **Orientation**: +1 (right-handed) or -1 (left-handed)
- **Goal**: Distinguish L vs D forms (e.g., L-alanine vs D-alanine)


In [19]:
# TwistedEmbed vs GeometricEmbed for chirality

import jax
import jax.numpy as jnp
from diffgeo import TwistedEmbed, GeometricEmbed
from diffgeo.core.types import Parity

# Create both types of embeddings
twisted = TwistedEmbed(dEmbed=32, numEmbed=100)
standard = GeometricEmbed(dEmbed=32, numEmbed=100)

# Initialize with same random key for fair comparison
key = jax.random.PRNGKey(42)
twisted_weights = twisted.initialize(key)
standard_weights = standard.initialize(key)

# Molecule "tokens" (same for both enantiomers)
molecule_tokens = jnp.array([5, 12, 7, 3])  # e.g., C, N, H, O atoms

# Right-handed (orientation = +1) vs Left-handed (orientation = -1)
right_embed = twisted.forward(molecule_tokens, twisted_weights, orientation=+1.0)
left_embed = twisted.forward(molecule_tokens, twisted_weights, orientation=-1.0)

# Standard embedding (cannot see orientation!)
standard_right = standard.forward(molecule_tokens, standard_weights)
standard_left = standard.forward(molecule_tokens, standard_weights)  # Same tokens → same output

print("="*60)
print("CHIRALITY DETECTION COMPARISON")
print("="*60)
print(f"\n🔬 TwistedEmbed (parity={twisted.signature.parity.name}):")
print(f"   Right-handed norm: {jnp.linalg.norm(right_embed):.4f}")
print(f"   Left-handed norm:  {jnp.linalg.norm(left_embed):.4f}")
print(f"   ‖R - L‖ difference: {jnp.linalg.norm(right_embed - left_embed):.4f}")
print(f"   Can distinguish? {'✅ YES' if jnp.linalg.norm(right_embed - left_embed) > 0.01 else '❌ NO'}")

print(f"\n📦 GeometricEmbed (parity={standard.signature.parity.name}):")
print(f"   Right-handed norm: {jnp.linalg.norm(standard_right):.4f}")
print(f"   Left-handed norm:  {jnp.linalg.norm(standard_left):.4f}")
print(f"   ‖R - L‖ difference: {jnp.linalg.norm(standard_right - standard_left):.4f}")
print(f"   Can distinguish? {'✅ YES' if jnp.linalg.norm(standard_right - standard_left) > 0.01 else '❌ NO'}")

# Show the actual relationship
print("\n🧬 Mathematical relationship:")
print(f"   left_embed ≈ -right_embed? {jnp.allclose(left_embed, -right_embed)}")
print(f"   This is the ODD parity signature: T' = -T under reflection")


CHIRALITY DETECTION COMPARISON

🔬 TwistedEmbed (parity=ODD):
   Right-handed norm: 11.3137
   Left-handed norm:  11.3137
   ‖R - L‖ difference: 22.6274
   Can distinguish? ✅ YES

📦 GeometricEmbed (parity=EVEN):
   Right-handed norm: 11.3137
   Left-handed norm:  11.3137
   ‖R - L‖ difference: 0.0000
   Can distinguish? ❌ NO

🧬 Mathematical relationship:
   left_embed ≈ -right_embed? True
   This is the ODD parity signature: T' = -T under reflection
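The printed numbers are what the simplest ODD-parity construction would produce: scale an ordinary lookup by the orientation sign. The sketch below assumes that behavior for illustration; it is not diffgeo's implementation, and `odd_parity_embed` and the random table are hypothetical.

```python
import numpy as np

def odd_parity_embed(tokens, table, orientation):
    """Hypothetical ODD-parity embedding: a plain lookup scaled by the orientation sign (+1 or -1)."""
    return orientation * table[tokens]

rng = np.random.default_rng(0)
table = rng.normal(size=(100, 32))     # numEmbed x dEmbed lookup table
tokens = np.array([5, 12, 7, 3])       # same tokens for both enantiomers

right = odd_parity_embed(tokens, table, +1.0)
left = odd_parity_embed(tokens, table, -1.0)

print(np.allclose(left, -right))                                # reflection flips the sign
print(np.isclose(np.linalg.norm(right), np.linalg.norm(left)))  # norms are unchanged
print(np.linalg.norm(right - left) / np.linalg.norm(right))     # ||R - L|| = 2 ||R||
```

With this construction ‖R − L‖ = ‖2R‖ = 2‖R‖, matching the output above (22.6274 = 2 × 11.3137). An EVEN-parity lookup takes no orientation input at all, which is why `GeometricEmbed` returns identical vectors for both enantiomers.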


In [20]:
# Visualize the embedding space with PCA projection

from sklearn.decomposition import PCA

def visualize_chiral_embeddings():
    """Show how TwistedEmbed separates chiral pairs in embedding space."""
    
    # Generate embeddings for multiple "molecules" with both orientations
    n_molecules = 20
    keys = jax.random.split(jax.random.PRNGKey(123), n_molecules)
    
    right_embeddings = []
    left_embeddings = []
    
    for i in range(n_molecules):
        # Random molecule tokens
        tokens = jax.random.randint(keys[i], (4,), 0, 100)
        
        right = twisted.forward(tokens, twisted_weights, orientation=+1.0)
        left = twisted.forward(tokens, twisted_weights, orientation=-1.0)
        
        right_embeddings.append(np.array(right.flatten()))
        left_embeddings.append(np.array(left.flatten()))
    
    right_embeddings = np.array(right_embeddings)
    left_embeddings = np.array(left_embeddings)
    
    # PCA to 3D for visualization
    all_embeddings = np.vstack([right_embeddings, left_embeddings])
    pca = PCA(n_components=3)
    projected = pca.fit_transform(all_embeddings)
    
    right_proj = projected[:n_molecules]
    left_proj = projected[n_molecules:]
    
    # Create 3D scatter plot
    fig = go.Figure()
    
    # Right-handed molecules
    fig.add_trace(go.Scatter3d(
        x=right_proj[:, 0], y=right_proj[:, 1], z=right_proj[:, 2],
        mode='markers',
        marker=dict(size=8, color='blue', symbol='circle'),
        name='Right-handed (R)'
    ))
    
    # Left-handed molecules
    fig.add_trace(go.Scatter3d(
        x=left_proj[:, 0], y=left_proj[:, 1], z=left_proj[:, 2],
        mode='markers',
        marker=dict(size=8, color='red', symbol='diamond'),
        name='Left-handed (L)'
    ))
    
    # Draw lines connecting chiral pairs
    for i in range(n_molecules):
        fig.add_trace(go.Scatter3d(
            x=[right_proj[i, 0], left_proj[i, 0]],
            y=[right_proj[i, 1], left_proj[i, 1]],
            z=[right_proj[i, 2], left_proj[i, 2]],
            mode='lines',
            line=dict(color='gray', width=1),
            showlegend=False
        ))
    
    fig.update_layout(
        title="TwistedEmbed Embedding Space (PCA projection)<br>"
              "<sub>Lines connect chiral pairs — note they're perfectly separated!</sub>",
        scene=dict(
            xaxis_title="PC1",
            yaxis_title="PC2",
            zaxis_title="PC3"
        ),
        height=600, width=800
    )
    
    return fig

fig = visualize_chiral_embeddings()
fig.show()


**Interpretation**:

The 3D scatter plot shows each molecule's embedding, flattened to 128 dimensions (4 tokens × 32 embedding dimensions), projected to 3D via PCA:
- **Blue circles**: Right-handed enantiomers
- **Red diamonds**: Left-handed enantiomers
- **Gray lines**: Connect chiral pairs (same molecule, opposite handedness)

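Notice that each gray line passes through the origin, with the origin at its midpoint. This is ODD parity in action: because $e_L = -e_R$ exactly, the pooled embeddings have zero mean, PCA's centering step leaves the origin in place, and the linear projection sends $e_L$ to exactly minus the projection of $e_R$.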
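The claim that each pair's midpoint lands on the origin (not just near it) can be checked with synthetic data. In this sketch random vectors stand in for the right-handed embeddings, and PCA is done with NumPy's SVD instead of scikit-learn to keep it self-contained; scikit-learn's `PCA` also subtracts the mean before projecting, so the argument carries over.

```python
import numpy as np

rng = np.random.default_rng(123)
n_pairs, dim = 20, 128                    # 4 tokens x 32 dims, flattened
right = rng.normal(size=(n_pairs, dim))   # synthetic stand-ins for right-handed embeddings
left = -right                             # ODD parity: e_L = -e_R

X = np.vstack([right, left])
mean = X.mean(axis=0)                     # zero up to rounding: each pair cancels
Xc = X - mean

# PCA via SVD of the centered data; the rows of Vt are the principal axes.
_, _, Vt = np.linalg.svd(Xc, full_matrices=False)
projected = Xc @ Vt[:3].T

right_proj, left_proj = projected[:n_pairs], projected[n_pairs:]
midpoints = 0.5 * (right_proj + left_proj)
print(np.abs(mean).max(), np.abs(midpoints).max())  # both at floating-point noise level
```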


## 10. FinslerLinear: Dualization Deep Dive

Let's compare how `modula.atom.Linear` and `diffgeo.FinslerLinear` convert gradients to updates.

### The Key Difference

| Layer | Dualization Method | Result |
|-------|-------------------|--------|
| `Linear` | Newton-Schulz orthogonalization | Symmetric: negating the gradient negates the update |
| `FinslerLinear` | Finsler-aware orthogonalization | Asymmetric: updates are biased toward the drift |

The Finsler version adds two modifications:
1. **Drift shift**: Subtracts `drift_strength * drift` from the gradient
2. **Drift-biased orthogonalization**: Adds a small bias term during the Newton-Schulz iterations
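The symmetric/asymmetric distinction in the table can be reproduced with a small NumPy sketch of step 1 followed by a plain orthogonalization. It assumes the classic cubic Newton-Schulz iteration (modula's own iteration may use a different polynomial and applies an extra output scale), assumes the drift is a matrix shaped like the gradient, and omits step 2's bias term entirely. `newton_schulz` and `drift_shifted_dualize` are hypothetical names, not library functions.

```python
import numpy as np

def newton_schulz(G, steps=40):
    """Approximate the polar factor U V^T of G = U S V^T with the cubic
    Newton-Schulz iteration X <- 1.5 X - 0.5 X X^T X.

    Dividing by the Frobenius norm puts every singular value in (0, 1],
    inside the iteration's convergence region (0, sqrt(3)).
    """
    X = G / np.linalg.norm(G)
    for _ in range(steps):
        X = 1.5 * X - 0.5 * X @ X.T @ X
    return X

def drift_shifted_dualize(G, drift, drift_strength):
    """Hypothetical sketch of step 1 only: shift the gradient by the drift, then orthogonalize."""
    return newton_schulz(G - drift_strength * drift)

rng = np.random.default_rng(42)
G = rng.normal(size=(16, 8))          # gradient, shaped like the Linear(16, 8) weight
drift = rng.normal(size=(16, 8))      # hypothetical drift, same shape as G

Q = newton_schulz(G)
U, _, Vt = np.linalg.svd(G, full_matrices=False)
print(np.abs(Q - U @ Vt).max())           # small: agrees with the SVD polar factor
print(np.abs(Q.T @ Q - np.eye(8)).max())  # small: columns are orthonormal

# The plain iteration is an odd function: flipping the gradient flips the update...
print(np.allclose(newton_schulz(-G), -Q))
# ...while the drift shift breaks that symmetry.
F_pos = drift_shifted_dualize(G, drift, 0.5)
F_neg = drift_shifted_dualize(-G, drift, 0.5)
print(np.allclose(F_neg, -F_pos))
```

Because every Newton-Schulz step is an odd polynomial in `X`, the unshifted map satisfies `f(-G) = -f(G)`; adding a fixed drift before that map is what introduces a preferred direction.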


In [21]:
# Compare Linear vs FinslerLinear dualization

from modula.atom import Linear
from diffgeo import FinslerLinear

# Create layers
base_linear = Linear(16, 8)
finsler_linear = FinslerLinear(16, 8, drift_strength=0.5)

# Initialize weights
key = jax.random.PRNGKey(42)
base_weights = base_linear.initialize(key)
finsler_weights = finsler_linear.initialize(key)

# Create a gradient (simulating backprop)
k1, k2 = jax.random.split(key)
grad_matrix = jax.random.normal(k1, shape=(16, 8))

# Dualize with both methods
base_grads = [grad_matrix]
finsler_grads = [grad_matrix, finsler_weights[1]]  # Include drift

base_update = base_linear.dualize(base_grads, targetNorm=1.0)
finsler_update = finsler_linear.dualize(finsler_grads, targetNorm=1.0)

print("="*60)
print("DUALIZATION COMPARISON: Linear vs FinslerLinear")
print("="*60)

print(f"\n📊 Input gradient shape: {grad_matrix.shape}")
print(f"   Gradient Frobenius norm: {jnp.linalg.norm(grad_matrix):.4f}")

print(f"\n🔲 Linear (base Modula):")
print(f"   Update shape: {base_update[0].shape}")
print(f"   Update Frobenius norm: {jnp.linalg.norm(base_update[0]):.4f}")

# Check orthogonality: undo the sqrt(16/8) output scaling, then Q^T @ Q should be close to the 8x8 identity
Q_base = base_update[0] / jnp.sqrt(16/8)
ortho_error_base = jnp.linalg.norm(Q_base.T @ Q_base - jnp.eye(8))
print(f"   Orthogonality error: {ortho_error_base:.6f}")

print(f"\n🔷 FinslerLinear:")
print(f"   Weight update shape: {finsler_update[0].shape}")
print(f"   Weight update norm: {jnp.linalg.norm(finsler_update[0]):.4f}")
print(f"   Drift update norm: {jnp.linalg.norm(finsler_update[1]):.6f}")

Q_finsler = finsler_update[0] / jnp.sqrt(16/8)
ortho_error_finsler = jnp.linalg.norm(Q_finsler.T @ Q_finsler - jnp.eye(8))
print(f"   Orthogonality error: {ortho_error_finsler:.6f}")

# Compare directions
dot_product = jnp.sum(base_update[0] * finsler_update[0])
norm_product = jnp.linalg.norm(base_update[0]) * jnp.linalg.norm(finsler_update[0])
cosine_sim = dot_product / (norm_product + 1e-8)
print(f"\n📐 Cosine similarity (base vs Finsler): {cosine_sim:.4f}")
print(f"   Angular difference: {jnp.arccos(jnp.clip(cosine_sim, -1, 1)) * 180 / jnp.pi:.2f}°")


DUALIZATION COMPARISON: Linear vs FinslerLinear

📊 Input gradient shape: (16, 8)
   Gradient Frobenius norm: 11.1846

🔲 Linear (base Modula):
   Update shape: (16, 8)
   Update Frobenius norm: 3.9846
   Orthogonality error: 0.021841

🔷 FinslerLinear:
   Weight update shape: (16, 8)
   Weight update norm: 3.9847
   Drift update norm: 0.100000
   Orthogonality error: 0.021602

📐 Cosine similarity (base vs Finsler): 0.9995
   Angular difference: 1.84°


In [22]:
# Visualize the dualization matrices

def visualize_dualization_comparison():
    """Heatmap comparison of Linear vs Finsler dualization."""
    
    fig = make_subplots(
        rows=1, cols=4,
        subplot_titles=(
            "Input Gradient",
            "Linear Update",
            "Finsler Update",
            "Difference"
        )
    )
    
    # Convert to numpy for plotting
    grad_np = np.array(grad_matrix)
    base_np = np.array(base_update[0])
    finsler_np = np.array(finsler_update[0])
    diff_np = finsler_np - base_np
    
    # Common colorscale range
    vmax = max(np.abs(grad_np).max(), np.abs(base_np).max(), np.abs(finsler_np).max())
    
    # Gradient
    fig.add_trace(go.Heatmap(
        z=grad_np, colorscale='RdBu', zmin=-vmax, zmax=vmax,
        showscale=False
    ), row=1, col=1)
    
    # Linear update
    fig.add_trace(go.Heatmap(
        z=base_np, colorscale='RdBu', zmin=-vmax, zmax=vmax,
        showscale=False
    ), row=1, col=2)
    
    # Finsler update
    fig.add_trace(go.Heatmap(
        z=finsler_np, colorscale='RdBu', zmin=-vmax, zmax=vmax,
        showscale=False
    ), row=1, col=3)
    
    # Difference (with its own scale)
    fig.add_trace(go.Heatmap(
        z=diff_np, colorscale='Viridis',
        colorbar=dict(title="Diff")
    ), row=1, col=4)
    
    fig.update_layout(
        height=300, width=1100,
        title="Weight Update Comparison: Linear vs FinslerLinear<br>"
              "<sub>The difference (right) shows drift-induced bias in update direction</sub>"
    )
    
    return fig

fig = visualize_dualization_comparison()
fig.show()


## 11. Geometric Composition: Tracking Signatures Through Networks

One of the powerful features of `diffgeo` is **automatic signature tracking** through layer composition.

### What Gets Tracked

| Property | Meaning | Composition Rule |
|----------|---------|------------------|
| **Domain variance** | Input type (vector/covector) | Must match the codomain of the layer feeding into it |
| **Codomain variance** | Output type | Becomes the composite's codomain (from the last layer applied) |
| **Parity** | Even/odd under reflection | Multiply (Z₂ group: EVEN = +1, ODD = -1) |
| **Metric type** | Euclidean/Riemannian/Finsler | Propagates |
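The variance and parity rules in the table can be sketched with a toy signature type. `Parity`, `Variance`, and `Signature` below are hypothetical stand-ins that mirror the names printed later in this section; they are not diffgeo's classes, and the metric-type row is left out because its propagation rule isn't spelled out here.

```python
from dataclasses import dataclass
from enum import Enum

class Parity(Enum):
    EVEN = +1
    ODD = -1

class Variance(Enum):
    SCALAR = "scalar"
    CONTRAVARIANT = "contravariant"   # vectors
    COVARIANT = "covariant"           # covectors, e.g. gradients

@dataclass(frozen=True)
class Signature:
    domain: Variance
    codomain: Variance
    parity: Parity

    def __matmul__(self, inner: "Signature") -> "Signature":
        """Signature of `self @ inner`: apply `inner` first, then `self`."""
        if self.domain != inner.codomain:
            raise TypeError(f"cannot feed {inner.codomain.name} into a layer expecting {self.domain.name}")
        # Parities multiply as elements of Z2 = {+1, -1}.
        return Signature(inner.domain, self.codomain, Parity(self.parity.value * inner.parity.value))

embed = Signature(Variance.SCALAR, Variance.CONTRAVARIANT, Parity.ODD)
linear_even = Signature(Variance.CONTRAVARIANT, Variance.CONTRAVARIANT, Parity.EVEN)
linear_odd = Signature(Variance.CONTRAVARIANT, Variance.CONTRAVARIANT, Parity.ODD)

print((linear_even @ embed).parity.name)               # ODD
print((linear_odd @ embed).parity.name)                # EVEN
print((linear_odd @ linear_even @ embed).parity.name)  # EVEN

# A hypothetical layer that outputs covectors cannot feed a vector-input layer.
covector_producer = Signature(Variance.SCALAR, Variance.COVARIANT, Parity.EVEN)
try:
    linear_even @ covector_producer
    mismatch_caught = False
except TypeError as err:
    mismatch_caught = True
    print("rejected:", err)
```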

### Why This Matters

Without geometric tracking, you might accidentally:
- Feed a **gradient (covector)** into a layer expecting a **vector**
- Compose layers where parity cancels unexpectedly
- Lose track of which metric should be used for dualization

Let's see this in action:


In [23]:
# Geometric signature tracking through composition

from diffgeo import FinslerLinear, TwistedEmbed
from diffgeo.core.types import Parity, TensorVariance

# Create layers with different parities
embed_odd = TwistedEmbed(dEmbed=32, numEmbed=1000)  # ODD parity
finsler_even = FinslerLinear(64, 32, parity=Parity.EVEN)  # EVEN parity
finsler_odd = FinslerLinear(64, 32, parity=Parity.ODD)  # ODD parity

print("="*70)
print("GEOMETRIC SIGNATURE TRACKING")
print("="*70)

print("\n📦 Individual Layer Signatures:")
print(f"\n   TwistedEmbed:")
print(f"      Domain:   {embed_odd.signature.domain.name}")
print(f"      Codomain: {embed_odd.signature.codomain.name}")
print(f"      Parity:   {embed_odd.signature.parity.name} ({embed_odd.signature.parity.value:+d})")

print(f"\n   FinslerLinear (EVEN):")
print(f"      Domain:   {finsler_even.signature.domain.name}")
print(f"      Codomain: {finsler_even.signature.codomain.name}")
print(f"      Parity:   {finsler_even.signature.parity.name} ({finsler_even.signature.parity.value:+d})")

print(f"\n   FinslerLinear (ODD):")
print(f"      Domain:   {finsler_odd.signature.domain.name}")
print(f"      Codomain: {finsler_odd.signature.codomain.name}")
print(f"      Parity:   {finsler_odd.signature.parity.name} ({finsler_odd.signature.parity.value:+d})")

# Compose and see signatures propagate
print("\n" + "─"*70)
print("🔗 COMPOSITION EXAMPLES:")

# EVEN @ ODD = ODD
net1 = finsler_even @ embed_odd
print(f"\n   finsler_even @ embed_odd:")
print(f"      Parity: EVEN({+1}) × ODD({-1}) = {net1.signature.parity.name} ({net1.signature.parity.value:+d})")

# ODD @ ODD = EVEN (chirality cancels!)
net2 = finsler_odd @ embed_odd
print(f"\n   finsler_odd @ embed_odd:")
print(f"      Parity: ODD({-1}) × ODD({-1}) = {net2.signature.parity.name} ({net2.signature.parity.value:+d})")
print(f"      ⚠️  Chirality sensitivity CANCELLED! Two reflections = identity")

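# ODD @ EVEN @ ODD = EVEN: an EVEN layer in between does not stop the two ODD factors cancelling
# (finsler_mid is a 32→32 EVEN layer; the printout below labels it finsler_even)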
finsler_mid = FinslerLinear(32, 32, parity=Parity.EVEN)
net3 = finsler_odd @ finsler_mid @ embed_odd
print(f"\n   finsler_odd @ finsler_even @ embed_odd:")
print(f"      Parity: ODD × EVEN × ODD = {net3.signature.parity.name}")

print("\n" + "─"*70)
print("💡 KEY INSIGHT:")
print("   Track parity to avoid accidentally cancelling chirality sensitivity!")


GEOMETRIC SIGNATURE TRACKING

📦 Individual Layer Signatures:

   TwistedEmbed:
      Domain:   SCALAR
      Codomain: CONTRAVARIANT
      Parity:   ODD (-1)

   FinslerLinear (EVEN):
      Domain:   CONTRAVARIANT
      Codomain: CONTRAVARIANT
      Parity:   EVEN (+1)

   FinslerLinear (ODD):
      Domain:   CONTRAVARIANT
      Codomain: CONTRAVARIANT
      Parity:   ODD (-1)

──────────────────────────────────────────────────────────────────────
🔗 COMPOSITION EXAMPLES:

   finsler_even @ embed_odd:
      Parity: EVEN(1) × ODD(-1) = ODD (-1)

   finsler_odd @ embed_odd:
      Parity: ODD(-1) × ODD(-1) = EVEN (+1)
      ⚠️  Chirality sensitivity CANCELLED! Two reflections = identity

   finsler_odd @ finsler_even @ embed_odd:
      Parity: ODD × EVEN × ODD = EVEN

──────────────────────────────────────────────────────────────────────
💡 KEY INSIGHT:
   Track parity to avoid accidentally cancelling chirality sensitivity!


## 12. Summary: The Full Geometric Picture

This notebook has explored the geometric structures underlying `diffgeo`:

### Manifold Structures

| Structure | Visualization | Use Case |
|-----------|---------------|----------|
| **SPD Manifold** | Ellipsoids, geodesics | Covariance matrices (EEG, DTI, finance) |
| **Randers/Finsler** | Shifted indicatrix | Directed graphs, causal inference, time series |
| **Fisher Information** | Stiff/sloppy directions | Natural gradient, model compression |

### Neural Network Components

| Component | Geometric Property | Benefit |
|-----------|-------------------|---------|
| **FinslerLinear** | Asymmetric dualization | Learns directional bias in data |
| **TwistedEmbed** | ODD parity | Distinguishes chiral/handed objects |
| **GeometricSignature** | Tracks variance, parity | Type-safe geometric composition |
| **MetricTensor** | Riemannian index raising | Natural gradient optimization |

### Key Mathematical Ideas

1. **Tensor Variance**: Gradients are covectors, updates are vectors — the metric converts between them
2. **Parity**: Even vs odd behavior under reflection; composes multiplicatively (Z₂ group)
3. **Finsler Geometry**: Generalizes Riemannian geometry to allow asymmetric distances; the Randers metrics used here do so by adding a drift vector
4. **Dualization**: The central operation converting loss gradients to parameter updates
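Items 1 and 4 can be illustrated on a two-parameter quadratic: the gradient's components are a covector, and a metric turns them into an update vector by solving `G v = g`. This is plain NumPy, not diffgeo's `MetricTensor` API, and it assumes an idealized metric equal to the Hessian. For a least-squares loss with Gaussian noise, the Fisher information equals that Hessian up to a constant factor, which is why this is a reasonable stand-in for natural gradient.

```python
import numpy as np

# Quadratic loss L(w) = 0.5 * w^T H w with an ill-conditioned Hessian.
H = np.diag([100.0, 1.0])
w0 = np.array([1.0, 1.0])

g = H @ w0                         # gradient components: a covector

# Euclidean step: treat the covector's components as a vector directly.
lr = 0.01                          # must stay below 2 / 100 for stability along the stiff axis
w_euclid = w0 - lr * g

# Metric-aware step: raise the index with the metric G, i.e. solve G v = g.
G = H                              # idealized metric choice (assumption, see above)
v = np.linalg.solve(G, g)
w_metric = w0 - v

print(w_euclid)   # stiff coordinate solved, sloppy coordinate barely moved
print(w_metric)   # both coordinates reach the minimum at 0
```

The Euclidean learning rate is capped by the stiffest direction, so the sloppy direction crawls; raising the index with the metric rescales each direction. In practice the metric is only approximated, so real gains are smaller than this one-step convergence.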

### Practical Implications

- **Better optimization**: Natural gradient (metric-aware) updates can converge in fewer steps than vanilla SGD, especially on ill-conditioned problems
- **Richer representations**: Finsler metrics capture directional relationships
- **Chirality detection**: TwistedEmbed maps mirror-image molecules to opposite embeddings, given each molecule's orientation sign as input
- **Type safety**: Geometric signatures catch composition errors at definition time


In [25]:
# Quick reference: try the CLI demos!
print("""
╔══════════════════════════════════════════════════════════════════════╗
║                        What You've Learned                           ║
╠══════════════════════════════════════════════════════════════════════╣
║                                                                      ║
║  1. SPD Manifolds     → Covariance lives on a curved space           ║
║  2. Finsler Metrics   → Asymmetric costs for directed data           ║
║  3. Fisher Geometry   → Stiff/sloppy directions in parameter space   ║
║  4. Tensor Variance   → Gradients ≠ vectors (need metric to convert) ║
║  5. Parity            → EVEN vs ODD behavior under reflection        ║
║  6. TwistedEmbed      → Chirality detection via ODD parity           ║
║  7. FinslerLinear     → Drift-biased gradient updates                ║
║  8. Composition       → Signatures track through layer stacking      ║
║                                                                      ║
╠══════════════════════════════════════════════════════════════════════╣
║  Try the interactive CLI demos:                                      ║
║                                                                      ║
║    python -m diffgeo.cli demo spd       # SPD manifold operations    ║
║    python -m diffgeo.cli demo finsler   # Asymmetric Finsler metrics ║
║    python -m diffgeo.cli demo chiral    # Chirality with TwistedEmbed║
║    python -m diffgeo.cli benchmark      # Performance comparison     ║
║                                                                      ║
╚══════════════════════════════════════════════════════════════════════╝
""")



╔══════════════════════════════════════════════════════════════════════╗
║                        What You've Learned                           ║
╠══════════════════════════════════════════════════════════════════════╣
║                                                                      ║
║  1. SPD Manifolds     → Covariance lives on a curved space           ║
║  2. Finsler Metrics   → Asymmetric costs for directed data           ║
║  3. Fisher Geometry   → Stiff/sloppy directions in parameter space   ║
║  4. Tensor Variance   → Gradients ≠ vectors (need metric to convert) ║
║  5. Parity            → EVEN vs ODD behavior under reflection        ║
║  6. TwistedEmbed      → Chirality detection via ODD parity           ║
║  7. FinslerLinear     → Drift-biased gradient updates                ║
║  8. Composition       → Signatures track through layer stacking      ║
║                                                                      ║
╠═════════════════════════════════════════════════