In [1]:
import sys
sys.path.insert(0, '../../Sinkhorn/src')

import jax
import jax.numpy as jnp
from ifst import Grid, LinearProblem, Sinkhorn
import functools

# Notebook-wide PRNG key, seeded with 42.
# NOTE(review): a later cell rebinds `key` with seed 0.
key = jax.random.PRNGKey(seed=42)

In [2]:
# Create two 1024x1024 images
size = 1024

# Create coordinate grid on [0, 1] x [0, 1]. meshgrid uses its default 'xy'
# indexing, so X varies along axis 1 and Y varies along axis 0.
x = jnp.linspace(0, 1, size)
y = jnp.linspace(0, 1, size)
X, Y = jnp.meshgrid(x, y)

# Image 1: Gaussian centered at (0.3, 0.3)
image1 = jnp.exp(-100 * ((X - 0.3)**2 + (Y - 0.3)**2))

# Image 2: Gaussian centered at (0.7, 0.7)
image2 = jnp.exp(-100 * ((X - 0.7)**2 + (Y - 0.7)**2))

# Normalize to probability distributions
image1 = image1 / jnp.sum(image1)
image2 = image2 / jnp.sum(image2)

# Flatten the images (row-major) into the marginal weight vectors
a = image1.ravel()
b = image2.ravel()

# Sanity-check the marginals before handing them to the solver, so a bad
# normalization fails here rather than as a silent non-convergence.
assert a.shape == b.shape == (size * size,), "marginals must match the grid size"
assert abs(float(a.sum()) - 1.0) < 1e-4, "marginal a must sum to 1"
assert abs(float(b.sum()) - 1.0) < 1e-4, "marginal b must sum to 1"

# Create Grid geometry
grid_geom = Grid(
    grid_size=(size, size),
    epsilon=0.01,  # Regularization parameter
)

# Create the linear OT problem
ot_problem = LinearProblem(
    geom=grid_geom,
    a=a,
    b=b
)

# Create and run the Sinkhorn solver
solver = Sinkhorn(
    threshold=1e-3,
    max_iterations=1000,
    lse_mode=True  # Use log-sum-exp mode for stability
)

# Solve the problem
output = solver(ot_problem)

# Extract the dual potential f.
# NOTE(review): f is the entropic dual potential, not the Brenier potential
# itself. For a squared-Euclidean ground cost c(x, y) = ||x - y||^2, the
# Brenier potential would be phi(x) = (||x||^2 - f(x)) / 2, giving the map
# T(x) = x - grad f(x) / 2. Confirm which cost ifst.Grid uses before converting.
brenier_potential = output.f

print("Brenier potential computed!")
print(f"Shape: {brenier_potential.shape}")
print(f"Converged: {output.converged}")
print(f"Sinkhorn distance: {output.reg_ot_cost:.6f}")

Brenier potential computed!
Shape: (1048576,)
Converged: True
Sinkhorn distance: 0.332439


In [7]:
# --- Step 1: Define Your Operator (Modified) ---
# CHANGED: Now accepts side_length as an argument.
def T_star_operator(potential: jax.Array, side_length: int) -> jax.Array:
    """
    Apply the T*_{F,p} operator to a flattened potential field.

    The current rule works in four steps:

    1. Reshape the input into a ``side_length x side_length`` grid.
    2. Add to each row the row above it, using ``jnp.roll`` along axis 0.
       The roll wraps, so the last row is added to row 0.
    3. Scale the sum by 0.45.
    4. Flatten the result again.

    Each output entry is therefore at most 0.9 times the largest input
    magnitude.

    NOTE(review): this shift-and-scale rule looks like a stand-in for the
    actual T*_{F,p} operator. TODO: confirm before relying on results.

    Args:
        potential: Flat array with ``side_length**2`` entries.
        side_length: Side of the square grid. It is a Python int so that it
            can be a static argument under ``jax.jit``.

    Returns:
        Flat array with the same number of entries as ``potential``.
    """
    grid = jnp.reshape(potential, (side_length, side_length))
    # Row i of `rolled` is row i-1 of `grid` (row 0 receives the last row).
    rolled = jnp.roll(grid, 1, axis=0)
    return (0.45 * (grid + rolled)).reshape(-1)


# --- Step 2: Define the JIT-Compiled Computation Function (Modified) ---
# CHANGED: Added side_length to the decorator and function signature.
@functools.partial(jax.jit, static_argnames=['num_iterations', 'side_length'])
def compute_auxiliary_potential(phi_F: jax.Array, num_iterations: int, side_length: int) -> jax.Array:
    """
    Compute the auxiliary potential by summing a truncated power series.

    Evaluates

        psi_F = sum_{k=0}^{num_iterations-1} (T*)^k phi_F

    where T* is ``T_star_operator`` applied on a ``side_length x side_length``
    grid.

    Args:
        phi_F: Flattened potential with ``side_length**2`` entries.
        num_iterations: Number of series terms to sum (static). ``0`` yields
            an all-zero result.
        side_length: Side length of the square grid (static). It is forwarded
            to ``T_star_operator`` for reshaping.

    Returns:
        Flat array of the same length as ``phi_F`` holding the partial sum.
        Integer inputs are promoted to JAX's default float dtype.
    """
    # T_star_operator scales by a float, so the series is floating-point even
    # for integer input. Promote up front because fori_loop requires the loop
    # carry to keep one dtype across iterations.
    phi_F = jnp.asarray(phi_F, dtype=jnp.result_type(phi_F, 0.0))

    def _accumulate(_, carry):
        # One series step: add the current term, then apply T* to get the next.
        psi_F, current_term = carry
        return psi_F + current_term, T_star_operator(current_term, side_length=side_length)

    # lax.fori_loop traces the body once. A Python for-loop under jit would
    # unroll num_iterations copies into the compiled program, so compile time
    # would grow with the number of terms.
    psi_F, _ = jax.lax.fori_loop(
        0, num_iterations, _accumulate, (jnp.zeros_like(phi_F), phi_F)
    )
    return psi_F


@functools.partial(jax.jit, static_argnames=['side_length'])
def compute_gradient_field(potential_flat: jax.Array, side_length: int, spacing: float = 1.0) -> jax.Array:
    """
    Computes the gradient of a scalar field on a square grid.

    Derivatives come from ``jnp.gradient``, which uses central differences in
    the interior and one-sided differences at the borders.

    Args:
        potential_flat: A flattened 1D array representing the scalar field,
            with ``side_length**2`` entries in row-major order.
        side_length: The size of one side of the 2D grid (static).
        spacing: Distance between neighbouring grid points, applied to both
            axes. The default 1.0 gives derivatives per grid index, which is
            the original behavior. For a grid built with
            ``jnp.linspace(0, 1, side_length)``, pass
            ``1 / (side_length - 1)`` to get derivatives in those coordinates.

    Returns:
        A JAX array of shape (2, side_length, side_length) representing the
        gradient vector field, where [0,:,:] is the Y-derivative and
        [1,:,:] is the X-derivative.
    """
    # Step 1: Reshape the potential from flat to a 2D grid
    potential_2d = potential_flat.reshape((side_length, side_length))

    # Step 2: Differentiate along both axes. jnp.gradient returns one array per
    # axis, in axis order. Axis 0 (rows) is Y and axis 1 (columns) is X; this
    # matches meshgrid's default 'xy' layout. A single scalar spacing is
    # applied to every axis.
    grad_y, grad_x = jnp.gradient(potential_2d, spacing)

    # Step 3: Stack the two components into one (2, side_length, side_length) field
    return jnp.stack([grad_y, grad_x], axis=0)

In [8]:
# --- Step 3: Example Usage ---

print(f"JAX is running on: {jax.devices()[0].device_kind}")

grid_size = 1024
# Key for this demo's random input potential.
# NOTE(review): this rebinds the notebook-wide `key` created with seed 42 in
# the first cell; consider a distinct name if later cells expect that key.
key = jax.random.PRNGKey(0)
phi_F = jax.random.normal(key, (grid_size * grid_size,))

K = 20  # number of power-series terms summed by compute_auxiliary_potential

# Recover the grid side from the flat length in plain Python, outside jit, so
# it can be passed as a static argument. round() guards against sqrt rounding.
# The assert rejects non-square inputs with a clear message instead of failing
# later in a reshape inside the jitted function.
side_length = round(phi_F.shape[0] ** 0.5)
assert side_length * side_length == phi_F.shape[0], (
    f"phi_F has {phi_F.shape[0]} entries, which is not a perfect square"
)
print(f"Calculated side_length: {side_length}")

print("\nCompiling and running the function...")
auxiliary_potential = compute_auxiliary_potential(phi_F, num_iterations=K, side_length=side_length)

# JAX dispatches asynchronously; wait for the result before reporting.
auxiliary_potential.block_until_ready()
print("Computation complete.")
print(f"Shape of output potential: {auxiliary_potential.shape}")

JAX is running on: NVIDIA GeForce RTX 2070 SUPER
Calculated side_length: 1024

Compiling and running the function...
Computation complete.
Shape of output potential: (1048576,)


In [9]:
# --- Example Usage ---

# Compute the gradient vector field of the auxiliary potential, on the same
# grid side that was used to build it.
print("Compiling and running the gradient computation...")
grad_vector_field = compute_gradient_field(auxiliary_potential, side_length=side_length)

# JAX dispatches asynchronously; block until the device finishes.
grad_vector_field.block_until_ready()

print("Computation complete.")
print(f"Shape of the input potential (flat): {auxiliary_potential.shape}")
print(f"Shape of the output gradient field: {grad_vector_field.shape}")
# Report the actual shape rather than a hard-coded literal.
print(f"\nThe output shape {grad_vector_field.shape} represents:")
print(f"  - Gradient Y-component shape: {grad_vector_field[0].shape}")
print(f"  - Gradient X-component shape: {grad_vector_field[1].shape}")

Compiling and running the gradient computation...
Computation complete.
Shape of the input potential (flat): (1048576,)
Shape of the output gradient field: (2, 1024, 1024)

The output shape (2, 1024, 1024) represents:
  - Gradient Y-component shape: (1024, 1024)
  - Gradient X-component shape: (1024, 1024)


In [None]:
def IFSGrad(p, F, mu_F, psi):
    