Modified from `jko_test.ipynb`

In [3]:
from jko_lab import *
import jax
from jax import numpy as jnp
from jax import lax
from sinkhorn import sinkhorn_flow
from animate import animate_hist_flow
import numpy as np
# Switch JAX to 64-bit floats/ints (its default is 32-bit).
# NOTE(review): JAX recommends enabling x64 before any arrays are created; the
# project modules above (jko_lab, sinkhorn, animate) are imported first, so any
# arrays they build at import time would presumably stay 32-bit -- TODO confirm.
jax.config.update("jax_enable_x64", True)

In [None]:
# Uniform grid of n points on [0, 1]; l is the spacing between neighbours.
n = 100
x = jnp.linspace(0.0, 1.0, num=n)
l = x[1] - x[0]

# Ground cost on the grid: C[i, j] = (x[i] - x[j])**2.
X, Y = jnp.meshgrid(x, x, indexing="ij")
C = jnp.square(X - Y)

# External potential V(x) entering the flow
def V(x):
    """Evaluate the potential at the points ``x``.

    Active choice: the zero potential (pure heat equation), returned as an
    array of zeros with the same shape and dtype as ``x``.

    Alternatives kept for experimentation (replace the return expression):
      * double-well: ``100 * (x - 0.3)**2 * (x - 0.7)**2``
      * harmonic:    ``50 * (x - 0.5)**2``
    """
    zero_potential = jnp.zeros_like(x)
    return zero_potential

# Potential sampled on the grid x; this array is what gets passed to sinkhorn_flow.
v_vals = V(x)  

def entropy(r):
    """Return ``-sum(r * log(r))`` over all entries of ``r``.

    Entries are floored at 2.2e-16 before the logarithm is taken, so zero
    (or negative) entries contribute a negligible finite term instead of
    producing NaN or -inf. No grid-spacing weight is applied to the sum.
    """
    floored = jnp.clip(r, 2.2e-16)
    r_log_r = floored * jnp.log(floored)
    return -r_log_r.sum()


# Initial density: mixture of two Gaussian bumps, weight 0.7 centred at 0.25
# (std 0.03) and weight 0.3 centred at 0.75 (std 0.04).
r0 = (
    0.7 * jax.scipy.stats.norm.pdf(x, 0.25, 0.03)
    + 0.3 * jax.scipy.stats.norm.pdf(x, 0.75, 0.04)
)
# Floor the tails so that no entry is smaller than 1e-12.
r0 = jnp.clip(r0, 1e-12, None)
# Rescale so that sum(r0) * l == 1 (unit mass w.r.t. the grid spacing).
r0 = r0 / (r0.sum() * l)

# Hyperparameters handed to sinkhorn_flow.
# NOTE(review): meanings are presumed from the names and the call below --
# tau looks like the JKO time step and epsilon is passed as `reg`; confirm
# against sinkhorn.py.
tau = 0.1
epsilon = 0.01

# b0: all-ones vector with one entry per row of X (i.e. per grid point).
p = len(X)
b0 = jnp.ones((p,))

# Run the flow starting from r0; the result is consumed below as a stack of
# densities with one row per step.
rs = sinkhorn_flow(
    r0, C, b0, v_vals, l, tau,
    reg=epsilon,
    steps=25,
    iters=200,
)

# Reference density: constant on the grid, scaled so that sum(b) * l == 1.
b = jnp.ones_like(x) / (n * l)

# Per-row diagnostics of the flow: Euclidean distance to b and entropy.
dists = jnp.linalg.norm(rs - b[jnp.newaxis, :], axis=1)
Hs = jnp.array(list(map(entropy, rs)))

# Print a few checkpoints, skipping any index beyond the number of rows.
print("[Flow-Entropy] Entropy should increase, distance to uniform should decrease:")
for step in (0, 1, 5, 10, 20, 24):
    if step >= len(Hs):
        continue
    print(f" k={step:2d}: H={float(Hs[step]):.6f}, ||rho-b||={float(dists[step]):.6f}")

[Flow-Entropy] Entropy should increase, distance to uniform should decrease:
 k= 0: H=-137.654594, ||rho-b||=20.517648
 k= 1: H=0.955045, ||rho-b||=0.359932
 k= 5: H=1.018629, ||rho-b||=0.068550
 k=10: H=1.018629, ||rho-b||=0.068550
 k=20: H=1.018629, ||rho-b||=0.068550
 k=24: H=1.018629, ||rho-b||=0.068550


In [5]:
# Build the animation of the flow rows in rs over the grid x, with b passed as
# the target density.
# NOTE(review): `interval` is presumably the per-frame delay in milliseconds --
# confirm against animate.py.
html = animate_hist_flow(
    mu_list=rs,
    x=np.asarray(x),
    target=b,  # or None
    interval=150,
    title="JKO Flow (entropy)",
)

# Bare last expression so the notebook renders the returned object.
html