In [None]:
import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt
import numpy as np

In [21]:
def generate_gaussian_samples(mu, cov, n_samples, seed=0):
    """Draw ``n_samples`` points from a multivariate normal N(mu, cov).

    Standard-normal draws are coloured with the Cholesky factor of ``cov``
    and shifted by ``mu``.

    Args:
        mu: mean vector; its first dimension gives the sample dimension.
        cov: covariance matrix passed to ``jnp.linalg.cholesky``.
        n_samples: number of rows in the returned array.
        seed: integer seed used to build the JAX PRNG key.

    Returns:
        Array of shape ``(n_samples, mu.shape[0])``.
    """
    dim = mu.shape[0]
    chol_factor = jnp.linalg.cholesky(cov)

    rng = jax.random.PRNGKey(seed)
    white_noise = jax.random.normal(rng, shape=(n_samples, dim))

    # Row-vector convention: each row z becomes mu + z @ L^T.
    return mu + white_noise @ chol_factor.T

In [22]:
def potential_grad(sample, mu, cov):
    """Return ``(sample - mu) @ inv(cov)`` without forming the inverse.

    For a symmetric ``cov`` this is the gradient of the Gaussian potential
    ``0.5 * (x - mu)^T cov^{-1} (x - mu)`` evaluated at every row of
    ``sample``.

    Args:
        sample: array whose last axis has length ``d`` (a single point or a
            batch of points).
        mu: mean, broadcastable against ``sample``.
        cov: ``(d, d)`` matrix.

    Returns:
        Array with the same shape as ``sample - mu``.
    """
    centered = sample - mu

    # Flatten any leading axes so the right-hand side is a plain (d, n)
    # matrix; this keeps 1-D and batched inputs working alike.
    flat = centered.reshape(-1, centered.shape[-1])

    # x @ inv(C) == solve(C^T, x^T)^T. Solving the linear system is better
    # conditioned than inverting, and using C^T keeps the result identical
    # to right-multiplication by the inverse even if C is not symmetric.
    solved = jnp.linalg.solve(cov.T, flat.T).T

    return solved.reshape(centered.shape)

In [23]:
def evolve_langevin(samples, mu, cov, eps=1e-2, n_evolution=5, seed=0):
    """Advance ``samples`` by ``n_evolution`` unadjusted Langevin steps.

    Every step applies
    ``x <- x - eps * potential_grad(x, mu, cov) + sqrt(2 * eps) * noise``
    with fresh standard-normal ``noise``.

    Args:
        samples: starting points; returned unchanged when ``n_evolution`` is 0.
        mu: mean forwarded to ``potential_grad``.
        cov: covariance forwarded to ``potential_grad``.
        eps: step size.
        n_evolution: number of Langevin steps.
        seed: integer seed for the noise PRNG key. The key is rebuilt from
            this seed on every call, so repeated calls reuse the same noise.

    Returns:
        Array with the same shape as ``samples``.
    """
    rng = jax.random.PRNGKey(seed)
    noise_scale = jnp.sqrt(2 * eps)
    current = samples

    for _ in range(n_evolution):
        # Split sequentially so each step consumes its own sub-key.
        rng, step_key = jax.random.split(rng)
        drift = eps * potential_grad(current, mu, cov)
        diffusion = noise_scale * jax.random.normal(step_key, shape=current.shape)
        current = current - drift + diffusion

    return current

In [24]:
from ott.tools import sinkhorn_divergence

def compute_sinkhorn_divergence(x, y, epsilon=0.1):
    """Return the Sinkhorn divergence between point clouds ``x`` and ``y``.

    Thin wrapper around ``sinkhorn_divergence.sinkdiv``: that call yields a
    pair, and only its first element (the divergence value) is returned.

    Args:
        x: first point cloud.
        y: second point cloud.
        epsilon: forwarded to ``sinkdiv`` as its ``epsilon`` keyword.

    Returns:
        The divergence value produced by ``sinkdiv``.
    """
    # The second element of the pair is not needed here.
    value, _ = sinkhorn_divergence.sinkdiv(x, y, epsilon=epsilon)
    return value

In [25]:
def lm_loss(samples, mu, cov, eps=1e-2, n_evolution=5, seed=0, epsilon=0.1):
    """Langevin-matching loss for a candidate Gaussian ``N(mu, cov)``.

    ``samples`` are evolved with ``evolve_langevin`` under the candidate
    parameters, and the Sinkhorn divergence between the original and the
    evolved samples is returned.

    Args:
        samples: observed points.
        mu: candidate mean.
        cov: candidate covariance.
        eps: Langevin step size forwarded to ``evolve_langevin``.
        n_evolution: number of Langevin steps forwarded to ``evolve_langevin``.
        seed: noise seed forwarded to ``evolve_langevin``.
        epsilon: forwarded to ``compute_sinkhorn_divergence``.

    Returns:
        The Sinkhorn divergence between ``samples`` and their evolution.

    The keyword defaults equal the defaults of the functions they are
    forwarded to, so ``lm_loss(samples, mu, cov)`` behaves as before.
    """
    evolved_samples = evolve_langevin(
        samples, mu, cov, eps=eps, n_evolution=n_evolution, seed=seed
    )

    return compute_sinkhorn_divergence(samples, evolved_samples, epsilon=epsilon)

In [27]:
def lm_optimize_with_tracking(samples, true_mu, true_cov, n_epochs=1000, lr=1e-2, seed=10, eps=1e-2, plot_every=2):
    """Fit a Gaussian to ``samples`` by Langevin matching, logging progress.

    The trainable parameters are a mean ``mu`` and a square matrix ``A``;
    the covariance is parameterised as ``A.T @ A``. Each epoch evolves
    ``samples`` with ``evolve_langevin`` under the current parameters and
    takes one Adam step on the Sinkhorn divergence between the original and
    the evolved samples.

    Args:
        samples: ``(n, d)`` array of observations.
        true_mu: reference mean, used only for error tracking.
        true_cov: reference covariance, used only for error tracking.
        n_epochs: number of optimisation steps.
        lr: Adam learning rate.
        seed: integer seed for the random initialisation of ``mu`` and ``A``.
        eps: Langevin step size used inside the loss.
        plot_every: the loss is printed every ``plot_every`` epochs and on
            the last epoch (despite the name, nothing is plotted here).

    Returns:
        ``(final_mu, final_cov, history)``. ``history`` is a dict of
        per-epoch lists with keys ``"loss"``, ``"mu_l2"``, ``"cov_l2"``
        (spectral norm) and ``"cov_fro"`` (Frobenius norm). The recorded
        loss is evaluated before that epoch's update; the error norms are
        evaluated after it.
    """
    # Unpacking also enforces that `samples` is two-dimensional.
    _, d = samples.shape

    key = jax.random.PRNGKey(seed)
    key_mu, key_A = jax.random.split(key)

    params = {
        "mu": jax.random.normal(key_mu, shape=(d,)),
        "A": jax.random.normal(key_A, shape=(d, d)),
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    def loss_fn(params, samples):
        mu = params["mu"]
        A = params["A"]
        cov = A.T @ A
        # Evolve here rather than through lm_loss so that `eps` is actually
        # honoured; previously the argument was accepted but never used.
        evolved = evolve_langevin(samples, mu, cov, eps=eps)
        return compute_sinkhorn_divergence(samples, evolved)

    # Build the differentiated function once instead of on every epoch.
    # NOTE(review): wrapping this in jax.jit would likely speed training up
    # considerably; confirm the Sinkhorn solver traces cleanly first.
    loss_and_grad = jax.value_and_grad(loss_fn)

    history = {
        "loss": [],
        "mu_l2": [],
        "cov_l2": [],
        "cov_fro": [],
    }

    for epoch in range(n_epochs):
        loss_val, grads = loss_and_grad(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        # Tracking: distances between the updated estimate and the reference.
        mu_val = params["mu"]
        A_val = params["A"]
        cov_val = A_val.T @ A_val

        history["loss"].append(loss_val)
        history["mu_l2"].append(jnp.linalg.norm(true_mu - mu_val))
        history["cov_l2"].append(jnp.linalg.norm(true_cov - cov_val, ord=2))
        history["cov_fro"].append(jnp.linalg.norm(true_cov - cov_val, ord='fro'))

        if epoch % plot_every == 0 or epoch == n_epochs - 1:
            print(f"epoch {epoch} | loss = {loss_val:.6f}")

    final_mu = params["mu"]
    final_cov = params["A"].T @ params["A"]

    return final_mu, final_cov, history


In [28]:
# Ground-truth parameters of the 3-D Gaussian used for the experiment.
mu = jnp.array([0.0, 1.0, -1.0])

# Symmetric covariance with unit variances and positive correlations.
_cov_rows = [
    [1.0, 0.5, 0.2],
    [0.5, 1.0, 0.3],
    [0.2, 0.3, 1.0],
]
cov = jnp.array(_cov_rows)

In [None]:
# Draw synthetic observations from the ground-truth Gaussian.
samples = generate_gaussian_samples(mu, cov, n_samples=2000, seed=0)

# Recover (mu, cov) from the samples; the true values are passed only so
# that reconstruction errors can be tracked during training.
final_mu, final_cov, history = lm_optimize_with_tracking(
    samples,
    mu,
    cov,
    n_epochs=800,
    lr=0.05,
    seed=0,
)

epoch 0 | loss = 0.049816
epoch 2 | loss = 0.040484
epoch 4 | loss = 0.038063
epoch 6 | loss = 0.037248
epoch 8 | loss = 0.036910
epoch 10 | loss = 0.036734
epoch 12 | loss = 0.036623
epoch 14 | loss = 0.036537
epoch 16 | loss = 0.036460
epoch 18 | loss = 0.036385
epoch 20 | loss = 0.036306
epoch 22 | loss = 0.036209
epoch 24 | loss = 0.036089
epoch 26 | loss = 0.035952
epoch 28 | loss = 0.035820
epoch 30 | loss = 0.035714
epoch 32 | loss = 0.035645
epoch 34 | loss = 0.035611
epoch 36 | loss = 0.035599
epoch 38 | loss = 0.035589
epoch 40 | loss = 0.035561
epoch 42 | loss = 0.035513
epoch 44 | loss = 0.035449
epoch 46 | loss = 0.035378
epoch 48 | loss = 0.035302
epoch 50 | loss = 0.035227
epoch 52 | loss = 0.035151
epoch 54 | loss = 0.035074
epoch 56 | loss = 0.034988
epoch 58 | loss = 0.034891
epoch 60 | loss = 0.034780
epoch 62 | loss = 0.034659
epoch 64 | loss = 0.034524
epoch 66 | loss = 0.034377
epoch 68 | loss = 0.034227
epoch 70 | loss = 0.034079
epoch 72 | loss = 0.033939
epoch 

In [None]:
jnp.set_printoptions(precision=4, suppress=True)

# For the mean and then the covariance: print the reference value, the
# estimate, and the norm of their difference.
for reference, estimate in ((mu, final_mu), (cov, final_cov)):
    error = jnp.linalg.norm(reference - estimate)
    print(jnp.round(reference, 4), "\n\n", jnp.round(estimate, 4), error, "\n\n\n")

[0. 0.] 

 [0.0171 0.0117] 0.020761693 



[[2.001 2.   ]
 [2.    2.001]] 

 [[3.5716 3.564 ]
 [3.564  3.5769]] 3.1372166 





In [None]:
# Convert each tracked series to a NumPy array for plotting.
history_np = {}
for metric_name, series in history.items():
    history_np[metric_name] = np.array(series)

In [None]:
fig, (loss_ax, err_ax) = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: training loss per optimisation step.
loss_ax.plot(history_np["loss"], label="langevin matching loss")
loss_ax.set_xlabel("step")
loss_ax.set_ylabel("loss")
loss_ax.set_title("loss evolution")
loss_ax.grid(True)
loss_ax.legend()

# Right panel: parameter reconstruction errors on a log scale.
error_curves = (
    ("mu_l2", r"$\| \mu - \hat{\mu} \|_2$"),
    ("cov_l2", r"$\| \Sigma - \hat{\Sigma} \|_2$"),
    ("cov_fro", r"$\| \Sigma - \hat{\Sigma} \|_F$"),
)
for metric_name, curve_label in error_curves:
    err_ax.plot(history_np[metric_name], label=curve_label)
err_ax.set_yscale("log")
err_ax.set_xlabel("step")
err_ax.set_title("reconstruction error")
err_ax.grid(True)
err_ax.legend()

fig.tight_layout()
plt.show()