In [8]:
import sys
sys.path.insert(0, '../../Sinkhorn/src')

import jax
import jax.numpy as jnp
from ifst import Grid, LinearProblem, Sinkhorn
import functools

# Set random seed
key = jax.random.PRNGKey(42)

In [9]:
# Create two 1024x1024 images
size = 1024

# Create coordinate grid
x = jnp.linspace(0, 1, size)
y = jnp.linspace(0, 1, size)
X, Y = jnp.meshgrid(x, y)

# Image 1: Gaussian centered at (0.3, 0.3)
image1 = jnp.exp(-100 * ((X - 0.3)**2 + (Y - 0.3)**2))

# Image 2: Gaussian centered at (0.7, 0.7)
image2 = jnp.exp(-100 * ((X - 0.7)**2 + (Y - 0.7)**2))

# Normalize to probability distributions
image1 = image1 / jnp.sum(image1)
image2 = image2 / jnp.sum(image2)

# Flatten the images
a = image1.ravel()
b = image2.ravel()

# Create Grid geometry
grid_geom = Grid(
    grid_size=(size, size),
    epsilon=0.01,  # Regularization parameter
)

# Create the linear OT problem
ot_problem = LinearProblem(
    geom=grid_geom,
    a=a,
    b=b
)

# Create and run the Sinkhorn solver
solver = Sinkhorn(
    threshold=1e-3,
    max_iterations=1000,
    lse_mode=True  # Use log-sum-exp mode for stability
)

# Solve the problem
output = solver(ot_problem)

# Extract the Brenier potential (dual potential f)
brenier_potential = output.f

print(f"Brenier potential computed!")
print(f"Shape: {brenier_potential.shape}")
print(f"Converged: {output.converged}")
print(f"Sinkhorn distance: {output.reg_ot_cost:.6f}")

Brenier potential computed!
Shape: (1048576,)
Converged: True
Sinkhorn distance: 0.332439


In [10]:
# --- Step 1: Define Your Operator (Modified) ---
# CHANGED: Now accepts side_length as an argument.
def T_star_operator(potential: jax.Array, side_length: int) -> jax.Array:
    """
    Applies the T*_{F,p} operator to a potential field.
    """
    # CHANGED: The calculation is removed from here.
    potential_2d = potential.reshape((side_length, side_length))
    
    shifted_potential = jnp.roll(potential_2d, shift=1, axis=0)
    operated_potential_2d = (potential_2d + shifted_potential) * 0.45
    
    return operated_potential_2d.flatten()


# --- Step 2: Define the JIT-Compiled Computation Function (Modified) ---
# CHANGED: Added side_length to the decorator and function signature.
@functools.partial(jax.jit, static_argnames=['num_iterations', 'side_length'])
def compute_auxiliary_potential(phi_F: jax.Array, num_iterations: int, side_length: int) -> jax.Array:
    """
    Computes the auxiliary potential by summing the truncated power series.
    """
    psi_F = jnp.zeros_like(phi_F)
    current_term = phi_F
    
    for _ in range(num_iterations):
        psi_F += current_term
        # CHANGED: Pass side_length to the operator.
        current_term = T_star_operator(current_term, side_length=side_length)
        
    return psi_F


@functools.partial(jax.jit, static_argnames=['side_length'])
def compute_gradient_field(potential_flat: jax.Array, side_length: int) -> jax.Array:
    """
    Computes the gradient of a scalar field.

    Args:
        potential_flat: A flattened 1D array representing the scalar field.
        side_length: The size of one side of the 2D grid.

    Returns:
        A JAX array of shape (2, side_length, side_length) representing the
        gradient vector field, where [0,:,:] is the Y-derivative and
        [1,:,:] is the X-derivative.
    """
    # Step 1: Reshape the potential from flat to a 2D grid
    potential_2d = potential_flat.reshape((side_length, side_length))

    # Step 2: Compute the gradient using the built-in JAX function
    # This returns a list of 2 arrays: [gradient_along_axis_0, gradient_along_axis_1]
    # For a 2D image/grid, axis 0 is the Y-direction and axis 1 is the X-direction.
    grad_y, grad_x = jnp.gradient(potential_2d)

    # Step 3: Stack the two components into a single vector field array
    # The shape will be (2, 1024, 1024)
    gradient_field = jnp.stack([grad_y, grad_x], axis=0)
    
    return gradient_field

In [11]:
# --- Step 3: Example Usage (Modified) ---

print(f"JAX is running on: {jax.devices()[0].device_kind}")

grid_size = 1024
key = jax.random.PRNGKey(0)
phi_F = jax.random.normal(key, (grid_size * grid_size,))

K = 20 

# CHANGED: Calculate side_length ONCE, here in plain Python.
# int() is fine here because we are not inside a JIT context.
side_length = int(phi_F.shape[0]**0.5)
print(f"Calculated side_length: {side_length}")

print("\nCompiling and running the function...")
# CHANGED: Pass the pre-calculated side_length to the function.
auxiliary_potential = compute_auxiliary_potential(phi_F, num_iterations=K, side_length=side_length)

auxiliary_potential.block_until_ready()
print("Computation complete.")
print(f"Shape of output potential: {auxiliary_potential.shape}")

JAX is running on: NVIDIA GeForce RTX 2070 SUPER
Calculated side_length: 1024

Compiling and running the function...
Computation complete.
Shape of output potential: (1048576,)


In [12]:
# --- Example Usage ---

# Compute the gradient vector field
print("Compiling and running the gradient computation...")
grad_vector_field = compute_gradient_field(auxiliary_potential, side_length=grid_size)

# Block until the computation is finished on the GPU to see the result
grad_vector_field.block_until_ready()

print("Computation complete.")
print(f"Shape of the input potential (flat): {auxiliary_potential.shape}")
print(f"Shape of the output gradient field: {grad_vector_field.shape}")
print("\nThe output shape (2, 1024, 1024) represents:")
print(f"  - Gradient Y-component shape: {grad_vector_field[0].shape}")
print(f"  - Gradient X-component shape: {grad_vector_field[1].shape}")

Compiling and running the gradient computation...
Computation complete.
Shape of the input potential (flat): (1048576,)
Shape of the output gradient field: (2, 1024, 1024)

The output shape (2, 1024, 1024) represents:
  - Gradient Y-component shape: (1024, 1024)
  - Gradient X-component shape: (1024, 1024)


In [13]:
@functools.partial(jax.jit, static_argnames=['d'])
def _compute_transformed_coords(f: jnp.ndarray, base_grid: jnp.ndarray, d: int) -> jnp.ndarray:
    """
    Compute transformed coordinates for pulling back a vector field through f.
    
    Args:
        f: Transformation matrix (3, 3)
        base_grid: Base grid in homogeneous coordinates (d*d, 3)
        d: Grid dimension
        
    Returns:
        Pixel coordinates (2, d, d) for map_coordinates
    """
    # Apply transformation f to grid points
    transformed = f @ base_grid.T  # (3, d*d)
    
    # De-homogenize
    x_norm = transformed[0] / transformed[2]  # (d*d,)
    y_norm = transformed[1] / transformed[2]  # (d*d,)
    
    # Convert to pixel indices [0, d-1]
    x_pixel = x_norm * (d - 1.0)
    y_pixel = y_norm * (d - 1.0)
    
    # Stack as (y, x) for map_coordinates
    coords = jnp.stack([
        y_pixel.reshape(d, d),
        x_pixel.reshape(d, d)
    ], axis=0)
    
    return coords


@functools.partial(jax.jit, static_argnames=['d'])
def _pullback_vector_field(T: jnp.ndarray, f: jnp.ndarray, base_grid: jnp.ndarray, d: int) -> jnp.ndarray:
    """
    Pull back vector field T by transformation f.
    
    Computes T ∘ f, i.e., T(f(x)) at each grid point x.
    
    Args:
        T: Vector field (2, d, d) where T[0] is y-component, T[1] is x-component
        f: Transformation matrix (3, 3)
        base_grid: Base grid in homogeneous coordinates (d*d, 3)
        d: Grid dimension
        
    Returns:
        Pulled-back vector field (2, d, d)
    """
    # Get transformed coordinates
    coords = _compute_transformed_coords(f, base_grid, d)
    
    # Pull back each component of the vector field
    T_y_pulled = jax.scipy.ndimage.map_coordinates(
        T[0], coords, order=1, mode='constant', cval=0.0
    )
    T_x_pulled = jax.scipy.ndimage.map_coordinates(
        T[1], coords, order=1, mode='constant', cval=0.0
    )
    
    # Stack components
    T_pulled = jnp.stack([T_y_pulled, T_x_pulled], axis=0)
    
    return T_pulled


def IFSgradient(F, p, T, rho_F):
    """
    Compute gradients for IFS optimization.
    
    For each transformation f_i in F, computes:
        grad_i = p_i * rho_F(x) * T(f_i(x))
    
    where T(f_i(x)) is the pull-back of the vector field T by f_i.
    
    Args:
        F: List of n transformation matrices (each 3x3 JAX array)
        p: Probability vector (n,) JAX array
        T: Gradient vector field (2, d, d) where T[0] is y-component, T[1] is x-component
        rho_F: Fixed measure (d, d) JAX array
        
    Returns:
        Fgrads: List of n gradient vector fields, each (2, d, d)
        pgrad: Gradient w.r.t. probabilities (n,) - to be implemented
    """
    n = len(F)
    d = rho_F.shape[0]
    
    # Create base grid (same as in FixedMeasureSolver)
    y_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    x_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    grid_y, grid_x = jnp.meshgrid(y_coords, x_coords, indexing='ij')
    
    base_grid = jnp.stack([
        grid_x.ravel(),
        grid_y.ravel(),
        jnp.ones(d * d, dtype=jnp.float32)
    ], axis=1)  # (d*d, 3)
    
    # Compute gradients for each transformation
    Fgrads = []
    for i in range(n):
        f_i = F[i]
        p_i = p[i]
        
        # Pull back T by f_i: T(f_i(x))
        T_pulled = _pullback_vector_field(T, f_i, base_grid, d)
        
        # Multiply by p_i and rho_F(x): p_i * rho_F(x) * T(f_i(x))
        # Broadcasting: (2, d, d) * scalar * (d, d) -> (2, d, d)
        grad_i = p_i * rho_F[None, :, :] * T_pulled
        
        Fgrads.append(grad_i)
    
    # TODO: Compute gradient w.r.t. p (probability vector)
    # For now, placeholder
    pgrad = jnp.zeros_like(p)
    
    return Fgrads, pgrad

## IFS Gradient Computation

This cell implements the gradient computation for IFS optimization using surrogate gradients.

### Mathematical Background

For an IFS defined by transformations $f_1, \ldots, f_n$ with probabilities $p_1, \ldots, p_n$ and fixed measure $\rho_F$, the gradient with respect to each transformation $f_i$ is:

$$\text{grad}_i = p_i \cdot \rho_F(x) \cdot T(f_i(x))$$

where:
- $T$ is a vector field (the gradient of the Brenier potential)
- $T(f_i(x))$ is the pull-back of $T$ by the transformation $f_i$
- $\rho_F(x)$ is the fixed measure
- $p_i$ is the probability weight

### Implementation Details

The key computational challenge is computing $T(f_i(x))$ efficiently:

1. **Grid transformation**: Since $f_i$ is affine, we can compute where each grid point maps to
2. **Interpolation**: Use `map_coordinates` to pull back the vector field $T$
3. **Pointwise multiplication**: Multiply by $p_i$ and $\rho_F(x)$

The implementation reuses the grid transformation logic from the `ifs_solver` package for efficiency and consistency.

### Performance Notes

- The pull-back operation uses bilinear interpolation (order=1)
- Each component of the vector field is pulled back separately
- The base grid is created once and reused for all transformations
- The functions are JIT-compiled for optimal performance

In [15]:
# Test IFSgradient function
print("Testing IFSgradient function...")
print("=" * 70)

# Create test IFS (Sierpinski triangle)
F_test = [
    jnp.array([
        [0.5, 0.0, 0.0],
        [0.0, 0.5, 0.0],
        [0.0, 0.0, 1.0]
    ], dtype=jnp.float32),
    jnp.array([
        [0.5, 0.0, 0.5],
        [0.0, 0.5, 0.0],
        [0.0, 0.0, 1.0]
    ], dtype=jnp.float32),
    jnp.array([
        [0.5, 0.0, 0.0],
        [0.0, 0.5, 0.5],
        [0.0, 0.0, 1.0]
    ], dtype=jnp.float32)
]

p_test = jnp.array([1/3, 1/3, 1/3], dtype=jnp.float32)

# Create test fixed measure (uniform for simplicity)
d_test = 128
rho_F_test = jnp.ones((d_test, d_test), dtype=jnp.float32) / (d_test * d_test)

# Create test gradient field (using the one we computed earlier, downsampled)
# For testing, we'll create a simple gradient field
T_test = grad_vector_field[:, ::8, ::8]  # Downsample from 1024 to 128
print(f"T_test shape: {T_test.shape}")
print(f"rho_F_test shape: {rho_F_test.shape}")

# Compute gradients
print("\nComputing IFS gradients...")
Fgrads, pgrad = IFSgradient(F_test, p_test, T_test, rho_F_test)

print(f"\nResults:")
print(f"Number of F gradients: {len(Fgrads)}")
print(f"Shape of each F gradient: {Fgrads[0].shape}")
print(f"pgrad shape: {pgrad.shape}")

# Check properties
print(f"\nGradient statistics:")
for i, grad in enumerate(Fgrads):
    print(f"  F[{i}] gradient: mean={jnp.mean(grad):.6e}, max={jnp.max(jnp.abs(grad)):.6e}")

print("\n✓ IFSgradient function works!")

Testing IFSgradient function...
T_test shape: (2, 128, 128)
rho_F_test shape: (128, 128)

Computing IFS gradients...

Results:
Number of F gradients: 3
Shape of each F gradient: (2, 128, 128)
pgrad shape: (3,)

Gradient statistics:
  F[0] gradient: mean=-6.546317e-07, max=1.710885e-04
  F[1] gradient: mean=-5.915949e-07, max=1.542660e-04
  F[2] gradient: mean=-4.123000e-07, max=2.041800e-04

✓ IFSgradient function works!


In [16]:
# Optimized version using vmap for parallelization
@functools.partial(jax.jit, static_argnames=['d'])
def IFSgradient_optimized(F, p, T, rho_F, d):
    """
    Optimized IFS gradient computation using vmap.
    
    Args:
        F: Stacked transformation matrices (n, 3, 3)
        p: Probability vector (n,)
        T: Gradient vector field (2, d, d)
        rho_F: Fixed measure (d, d)
        d: Grid dimension (static argument for JIT)
        
    Returns:
        Fgrads: Gradient vector fields (n, 2, d, d)
        pgrad: Gradient w.r.t. probabilities (n,)
    """
    # Create base grid once
    y_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    x_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    grid_y, grid_x = jnp.meshgrid(y_coords, x_coords, indexing='ij')
    
    base_grid = jnp.stack([
        grid_x.ravel(),
        grid_y.ravel(),
        jnp.ones(d * d, dtype=jnp.float32)
    ], axis=1)  # (d*d, 3)
    
    # Vectorized pull-back across all transformations
    # vmap over axis 0 of F
    T_pulled_all = jax.vmap(
        lambda f: _pullback_vector_field(T, f, base_grid, d)
    )(F)  # (n, 2, d, d)
    
    # Compute gradients: p_i * rho_F * T_pulled_i
    # Broadcasting: (n, 1, 1, 1) * (1, 1, d, d) * (n, 2, d, d)
    Fgrads = p[:, None, None, None] * rho_F[None, None, :, :] * T_pulled_all
    
    # TODO: Gradient w.r.t. p
    pgrad = jnp.zeros_like(p)
    
    return Fgrads, pgrad


# Test the optimized version
print("\nTesting optimized IFSgradient...")
print("=" * 70)

# Stack F for the optimized version
F_test_stacked = jnp.stack(F_test, axis=0)  # (3, 3, 3)

print(f"F_test_stacked shape: {F_test_stacked.shape}")

# Compute gradients (optimized)
import time
start = time.perf_counter()
Fgrads_opt, pgrad_opt = IFSgradient_optimized(F_test_stacked, p_test, T_test, rho_F_test, d=d_test)
Fgrads_opt[0].block_until_ready()
time_opt = time.perf_counter() - start

print(f"\nOptimized results:")
print(f"Fgrads shape: {Fgrads_opt.shape}")
print(f"pgrad shape: {pgrad_opt.shape}")
print(f"Time: {time_opt:.4f}s")

# Compare with original version
print(f"\nComparing with original implementation...")
start = time.perf_counter()
Fgrads_orig, pgrad_orig = IFSgradient(F_test, p_test, T_test, rho_F_test)
Fgrads_orig[0].block_until_ready()
time_orig = time.perf_counter() - start
print(f"Original time: {time_orig:.4f}s")

# Check if results match
for i in range(len(F_test)):
    diff = jnp.max(jnp.abs(Fgrads_opt[i] - Fgrads_orig[i]))
    print(f"Max difference F[{i}]: {diff:.2e}")

if time_orig > time_opt:
    speedup = time_orig / time_opt
    print(f"\n✓ Optimized version is {speedup:.2f}x faster!")
else:
    print(f"\n(First call includes compilation overhead)")
    
print("\n✓ Optimized IFSgradient works!")


Testing optimized IFSgradient...
F_test_stacked shape: (3, 3, 3)

Optimized results:
Fgrads shape: (3, 2, 128, 128)
pgrad shape: (3,)
Time: 0.5309s

Comparing with original implementation...
Original time: 0.0059s
Max difference F[0]: 0.00e+00
Max difference F[1]: 0.00e+00
Max difference F[2]: 0.00e+00

(First call includes compilation overhead)

✓ Optimized IFSgradient works!


In [17]:
@functools.partial(jax.jit, static_argnames=['d'])
def _pullback_scalar_field(psi: jnp.ndarray, f: jnp.ndarray, base_grid: jnp.ndarray, d: int) -> jnp.ndarray:
    """
    Pull back scalar field psi by transformation f.
    
    Computes psi ∘ f, i.e., psi(f(x)) at each grid point x.
    
    Args:
        psi: Scalar field (d, d)
        f: Transformation matrix (3, 3)
        base_grid: Base grid in homogeneous coordinates (d*d, 3)
        d: Grid dimension
        
    Returns:
        Pulled-back scalar field (d, d)
    """
    # Get transformed coordinates
    coords = _compute_transformed_coords(f, base_grid, d)
    
    # Pull back the scalar field using bilinear interpolation
    psi_pulled = jax.scipy.ndimage.map_coordinates(
        psi, coords, order=1, mode='constant', cval=0.0
    )
    
    return psi_pulled


@functools.partial(jax.jit, static_argnames=['d'])
def compute_p_gradient(F, rho_F, psi, d):
    """
    Compute gradient with respect to probability vector p.
    
    For each p_i, the gradient is:
        grad_p[i] = ∫_X rho_F(x) * psi(f_i(x)) dx
    
    where the integral is a sum over the discrete grid X.
    
    Args:
        F: Stacked transformation matrices (n, 3, 3)
        rho_F: Fixed measure (d, d)
        psi: Auxiliary potential (d, d) - scalar field
        d: Grid dimension (static argument)
        
    Returns:
        pgrad: Gradient w.r.t. probabilities (n,)
    """
    # Create base grid once
    y_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    x_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    grid_y, grid_x = jnp.meshgrid(y_coords, x_coords, indexing='ij')
    
    base_grid = jnp.stack([
        grid_x.ravel(),
        grid_y.ravel(),
        jnp.ones(d * d, dtype=jnp.float32)
    ], axis=1)  # (d*d, 3)
    
    # Vectorized pull-back of psi across all transformations
    # vmap over axis 0 of F
    psi_pulled_all = jax.vmap(
        lambda f: _pullback_scalar_field(psi, f, base_grid, d)
    )(F)  # (n, d, d)
    
    # Compute integral: sum over grid of rho_F(x) * psi(f_i(x))
    # Broadcasting: (1, d, d) * (n, d, d) -> (n, d, d)
    integrand = rho_F[None, :, :] * psi_pulled_all  # (n, d, d)
    
    # Sum over spatial dimensions to get gradient for each p_i
    pgrad = jnp.sum(integrand, axis=(1, 2))  # (n,)
    
    return pgrad


# Update the optimized IFSgradient to not compute pgrad
@functools.partial(jax.jit, static_argnames=['d'])
def IFSgradient_F_only(F, p, T, rho_F, d):
    """
    Compute gradients with respect to F only (for surrogate gradient method).
    
    Note: Gradient w.r.t. p is computed separately using compute_p_gradient()
    since it requires the auxiliary potential psi, not the vector field T.
    
    Args:
        F: Stacked transformation matrices (n, 3, 3)
        p: Probability vector (n,)
        T: Gradient vector field (2, d, d)
        rho_F: Fixed measure (d, d)
        d: Grid dimension (static argument for JIT)
        
    Returns:
        Fgrads: Gradient vector fields (n, 2, d, d)
    """
    # Create base grid once
    y_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    x_coords = jnp.linspace(0.0, 1.0, d, dtype=jnp.float32)
    grid_y, grid_x = jnp.meshgrid(y_coords, x_coords, indexing='ij')
    
    base_grid = jnp.stack([
        grid_x.ravel(),
        grid_y.ravel(),
        jnp.ones(d * d, dtype=jnp.float32)
    ], axis=1)  # (d*d, 3)
    
    # Vectorized pull-back across all transformations
    T_pulled_all = jax.vmap(
        lambda f: _pullback_vector_field(T, f, base_grid, d)
    )(F)  # (n, 2, d, d)
    
    # Compute gradients: p_i * rho_F * T_pulled_i
    Fgrads = p[:, None, None, None] * rho_F[None, None, :, :] * T_pulled_all
    
    return Fgrads


# Test the p gradient computation
print("\nTesting compute_p_gradient...")
print("=" * 70)

# Create a test auxiliary potential (using the one computed earlier, downsampled)
psi_test = auxiliary_potential.reshape(grid_size, grid_size)[::8, ::8]  # Downsample to 128x128
print(f"psi_test shape: {psi_test.shape}")
print(f"rho_F_test shape: {rho_F_test.shape}")
print(f"F_test_stacked shape: {F_test_stacked.shape}")

# Compute p gradient
pgrad_test = compute_p_gradient(F_test_stacked, rho_F_test, psi_test, d=d_test)

print(f"\nResults:")
print(f"pgrad shape: {pgrad_test.shape}")
print(f"pgrad values: {pgrad_test}")

print(f"\nGradient statistics:")
print(f"  Mean: {jnp.mean(pgrad_test):.6e}")
print(f"  Max: {jnp.max(pgrad_test):.6e}")
print(f"  Min: {jnp.min(pgrad_test):.6e}")
print(f"  Std: {jnp.std(pgrad_test):.6e}")

print("\n✓ compute_p_gradient works!")


Testing compute_p_gradient...
psi_test shape: (128, 128)
rho_F_test shape: (128, 128)
F_test_stacked shape: (3, 3, 3)

Results:
pgrad shape: (3,)
pgrad values: [-0.04286745 -0.02920848  0.05633599]

Gradient statistics:
  Mean: -5.246644e-03
  Max: 5.633599e-02
  Min: -4.286745e-02
  Std: 4.390109e-02

✓ compute_p_gradient works!


In [18]:
# ==============================================================================
# COMPLETE WORKFLOW EXAMPLE: Computing all gradients for IFS optimization
# ==============================================================================

print("=" * 80)
print("COMPLETE IFS GRADIENT COMPUTATION WORKFLOW")
print("=" * 80)

# -----------------------------------------------------------------------------
# Step 1: Set up IFS parameters
# -----------------------------------------------------------------------------
print("\n[Step 1] Setting up IFS...")

# Example: Sierpinski triangle IFS
F_example = jnp.stack([
    jnp.array([[0.5, 0.0, 0.0],
               [0.0, 0.5, 0.0],
               [0.0, 0.0, 1.0]], dtype=jnp.float32),
    jnp.array([[0.5, 0.0, 0.5],
               [0.0, 0.5, 0.0],
               [0.0, 0.0, 1.0]], dtype=jnp.float32),
    jnp.array([[0.5, 0.0, 0.0],
               [0.0, 0.5, 0.5],
               [0.0, 0.0, 1.0]], dtype=jnp.float32)
], axis=0)  # Shape: (3, 3, 3)

p_example = jnp.array([1/3, 1/3, 1/3], dtype=jnp.float32)

print(f"  IFS transformations: {F_example.shape[0]} maps")
print(f"  Probabilities: {p_example}")

# -----------------------------------------------------------------------------
# Step 2: Compute fixed measure ρ_F (using FixedMeasureSolver from ifs_solver)
# -----------------------------------------------------------------------------
print("\n[Step 2] Computing fixed measure ρ_F...")

# For this example, we'll use a uniform distribution
# In practice, you'd use: solver.solve(F=F_example, p=p_example)
d_example = 128
rho_F_example = jnp.ones((d_example, d_example), dtype=jnp.float32) / (d_example**2)

print(f"  Fixed measure shape: {rho_F_example.shape}")
print(f"  Total mass: {jnp.sum(rho_F_example):.6f}")

# -----------------------------------------------------------------------------
# Step 3: Prepare gradient field T (from Brenier potential)
# -----------------------------------------------------------------------------
print("\n[Step 3] Preparing gradient field T...")

# Downsample the gradient field we computed earlier
T_example = grad_vector_field[:, ::8, ::8]  # From 1024 to 128

print(f"  Gradient field shape: {T_example.shape}")
print(f"  Components: T[0] (y-gradient), T[1] (x-gradient)")

# -----------------------------------------------------------------------------
# Step 4: Prepare auxiliary potential ψ
# -----------------------------------------------------------------------------
print("\n[Step 4] Preparing auxiliary potential ψ...")

# Downsample the auxiliary potential
psi_example = auxiliary_potential.reshape(grid_size, grid_size)[::8, ::8]

print(f"  Auxiliary potential shape: {psi_example.shape}")
print(f"  Value range: [{jnp.min(psi_example):.4f}, {jnp.max(psi_example):.4f}]")

# -----------------------------------------------------------------------------
# Step 5: Compute gradient w.r.t. F
# -----------------------------------------------------------------------------
print("\n[Step 5] Computing gradients w.r.t. F...")

Fgrads_example = IFSgradient_F_only(F_example, p_example, T_example, rho_F_example, d=d_example)

print(f"  F gradients shape: {Fgrads_example.shape}")
print(f"  Interpretation: {Fgrads_example.shape[0]} transformations, each with 2D vector field")

for i in range(Fgrads_example.shape[0]):
    grad_norm = jnp.sqrt(jnp.sum(Fgrads_example[i]**2))
    print(f"    F[{i}] gradient L2 norm: {grad_norm:.4e}")

# -----------------------------------------------------------------------------
# Step 6: Compute gradient w.r.t. p
# -----------------------------------------------------------------------------
print("\n[Step 6] Computing gradients w.r.t. p...")

pgrad_example = compute_p_gradient(F_example, rho_F_example, psi_example, d=d_example)

print(f"  p gradient shape: {pgrad_example.shape}")
print(f"  p gradient values: {pgrad_example}")
print(f"  Gradient magnitude: {jnp.linalg.norm(pgrad_example):.4e}")

# -----------------------------------------------------------------------------
# Summary
# -----------------------------------------------------------------------------
print("\n" + "=" * 80)
print("GRADIENT COMPUTATION SUMMARY")
print("=" * 80)

print(f"""
Inputs:
  - IFS transformations F:     {F_example.shape}
  - Probabilities p:            {p_example.shape}
  - Fixed measure ρ_F:          {rho_F_example.shape}
  - Gradient field T:           {T_example.shape}
  - Auxiliary potential ψ:      {psi_example.shape}

Outputs:
  - Gradients w.r.t. F:         {Fgrads_example.shape}  (n vector fields)
  - Gradient w.r.t. p:          {pgrad_example.shape}   (n scalars)

Total gradients computed:       {Fgrads_example.shape[0]} (for F) + {pgrad_example.shape[0]} (for p)
                              = {Fgrads_example.shape[0] + pgrad_example.shape[0]} total gradient components

✓ All gradients computed successfully!
✓ Ready for optimization loop!
""")

COMPLETE IFS GRADIENT COMPUTATION WORKFLOW

[Step 1] Setting up IFS...
  IFS transformations: 3 maps
  Probabilities: [0.33333334 0.33333334 0.33333334]

[Step 2] Computing fixed measure ρ_F...
  Fixed measure shape: (128, 128)
  Total mass: 1.000000

[Step 3] Preparing gradient field T...
  Gradient field shape: (2, 128, 128)
  Components: T[0] (y-gradient), T[1] (x-gradient)

[Step 4] Preparing auxiliary potential ψ...
  Auxiliary potential shape: (128, 128)
  Value range: [-12.5424, 13.8567]

[Step 5] Computing gradients w.r.t. F...
  F gradients shape: (3, 2, 128, 128)
  Interpretation: 3 transformations, each with 2D vector field
    F[0] gradient L2 norm: 5.1133e-03
    F[1] gradient L2 norm: 5.0517e-03
    F[2] gradient L2 norm: 5.1160e-03

[Step 6] Computing gradients w.r.t. p...
  p gradient shape: (3,)
  p gradient values: [-0.04286745 -0.02920848  0.05633599]
  Gradient magnitude: 7.6580e-02

GRADIENT COMPUTATION SUMMARY

Inputs:
  - IFS transformations F:     (3, 3, 3)
  - 

## API Summary

### Functions for IFS Gradient Computation

#### 1. **IFSgradient_F_only** - Compute gradients w.r.t. transformations F

```python
Fgrads = IFSgradient_F_only(F, p, T, rho_F, d)
```

**Inputs:**
- `F`: Stacked transformation matrices `(n, 3, 3)`
- `p`: Probability vector `(n,)`
- `T`: Gradient vector field `(2, d, d)` from Brenier potential
- `rho_F`: Fixed measure `(d, d)`
- `d`: Grid dimension (static, for JIT)

**Output:**
- `Fgrads`: Gradient vector fields `(n, 2, d, d)`

**Formula:** For each i: `Fgrads[i] = p[i] * rho_F(x) * T(f_i(x))`

---

#### 2. **compute_p_gradient** - Compute gradient w.r.t. probability vector p

```python
pgrad = compute_p_gradient(F, rho_F, psi, d)
```

**Inputs:**
- `F`: Stacked transformation matrices `(n, 3, 3)`
- `rho_F`: Fixed measure `(d, d)`
- `psi`: Auxiliary potential `(d, d)` (scalar field)
- `d`: Grid dimension (static, for JIT)

**Output:**
- `pgrad`: Gradient scalars `(n,)`

**Formula:** For each i: `pgrad[i] = ∫_X rho_F(x) * psi(f_i(x)) dx`

---

### Helper Functions (Internal)

- **`_compute_transformed_coords(f, base_grid, d)`** - Transform grid coordinates by affine map
- **`_pullback_vector_field(T, f, base_grid, d)`** - Pull back vector field through transformation
- **`_pullback_scalar_field(psi, f, base_grid, d)`** - Pull back scalar field through transformation

All functions are JIT-compiled for performance.

---

### Usage in Optimization Loop

```python
# Setup (once)
d = 128
F = jnp.stack([f1, f2, f3], axis=0)  # (3, 3, 3)
p = jnp.array([1/3, 1/3, 1/3])

# In each iteration:
# 1. Compute fixed measure
rho_F = solver.solve(F=F, p=p)

# 2. Compute OT and get potentials
T = compute_gradient_field(brenier_potential, d)
psi = auxiliary_potential.reshape(d, d)

# 3. Compute gradients
Fgrads = IFSgradient_F_only(F, p, T, rho_F, d)
pgrad = compute_p_gradient(F, rho_F, psi, d)

# 4. Update parameters (your optimization step)
F_new = F - learning_rate * process_F_gradients(Fgrads)
p_new = p - learning_rate * pgrad
```

## Probability Gradient Computation

This section implements the gradient computation for the probability vector **p**.

### Mathematical Background

For the probability $p_i$ of transformation $f_i$, the gradient is:

$$\nabla_{p_i} = \int_X \rho_F(x) \cdot \psi(f_i(x)) \, dx$$

where:
- $\rho_F(x)$ is the fixed measure
- $\psi$ is the auxiliary potential (scalar field)
- $\psi(f_i(x))$ is the pull-back of $\psi$ by transformation $f_i$
- The integral is a discrete sum over the grid

### Implementation Details

**Key differences from F gradient:**
- Uses the **scalar field** ψ (auxiliary potential), not the vector field T
- Requires pull-back of a scalar field instead of a vector field
- Result is a scalar for each $p_i$, not a vector field

**Computation steps:**
1. Pull back ψ through each transformation: $\psi(f_i(x))$
2. Multiply by fixed measure: $\rho_F(x) \cdot \psi(f_i(x))$
3. Integrate (sum) over the grid: $\sum_x \rho_F(x) \cdot \psi(f_i(x))$

### Separation from F Gradient

The gradients w.r.t. F and p are computed **separately** because:
- F gradient needs the **vector field** T (gradient of Brenier potential)
- p gradient needs the **scalar field** ψ (auxiliary potential)
- They will be handled differently in the optimization loop

This separation provides flexibility for different optimization strategies.