In [None]:
import os
import sys

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

# Make the project-local modules in the parent directory importable.
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

In [None]:
def potential_grad(samples, mu, cov):
    """Gradient of the Gaussian potential U(x) = 1/2 (x - mu)^T cov^{-1} (x - mu).

    For a symmetric ``cov`` the gradient is ``cov^{-1} (x - mu)``. It is
    obtained with a linear solve rather than an explicit inverse.

    Args:
        samples: points at which to evaluate the gradient, one per row.
        mu: mean vector, broadcast against every row of ``samples``.
        cov: covariance matrix; must be invertible.

    Returns:
        Array with the same shape as ``samples``. Row i is cov^{-1}(samples[i] - mu).
    """
    centered = samples - mu
    # Stack the centred samples as right-hand-side columns so a single solve
    # handles every sample, then transpose back to one row per sample.
    rhs_columns = jnp.transpose(centered)
    return jnp.transpose(jnp.linalg.solve(cov, rhs_columns))

In [3]:
def evolve_langevin(samples, mu, cov, eps=1e-2, n_evolution=1, seed=0):
    """Advance ``samples`` by ``n_evolution`` unadjusted Langevin steps.

    Each step applies
        x <- x - eps * grad U(x) + sqrt(2 * eps) * xi,   xi ~ N(0, I),
    where grad U is the Gaussian potential gradient from ``potential_grad``.

    Args:
        samples: starting points.
        mu: mean of the Gaussian potential.
        cov: covariance of the Gaussian potential.
        eps: step size.
        n_evolution: number of Langevin steps to take (0 returns ``samples``).
        seed: PRNG seed. The same seed always yields the same noise sequence.

    Returns:
        Array with the same shape as ``samples`` after the final step.
    """
    rng = jax.random.PRNGKey(seed)
    noise_scale = jnp.sqrt(2 * eps)
    x = samples

    for _step in range(n_evolution):
        rng, step_key = jax.random.split(rng)
        drift = eps * potential_grad(x, mu, cov)
        x = x - drift + noise_scale * jax.random.normal(step_key, shape=x.shape)

    return x

In [None]:
from ott.geometry import pointcloud
from ott.tools import sinkhorn_divergence

def lm_loss(samples, mu, cov, eps=1e-2, n_evolution=10, seed=0, sinkhorn_epsilon=None):
    """Langevin-matching loss for the Gaussian model N(mu, cov).

    NOTE(review): the original cell declared ``lm_loss`` with no body, which is a
    SyntaxError that breaks Restart & Run All. This body reconstructs the
    presumed intent from the surrounding notebook: ``lm_optimize`` plots the
    original samples against Langevin-evolved ones, and this cell imports
    ``sinkhorn_divergence``. Confirm that it matches the intended objective.

    Idea: if (mu, cov) describe the data distribution, that distribution is
    stationary under Langevin dynamics with potential
    U(x) = 1/2 (x - mu)^T cov^{-1} (x - mu). Evolving the samples should then
    leave their empirical distribution approximately unchanged. The loss is
    the Sinkhorn divergence between the original and the evolved samples.

    Args:
        samples: data points, one per row.
        mu: model mean.
        cov: model covariance; must be invertible (it is passed to a linear solve).
        eps: Langevin step size.
        n_evolution: number of Langevin steps.
        seed: PRNG seed for the Langevin noise. It is fixed so the loss is a
            deterministic function of (mu, cov).
        sinkhorn_epsilon: entropic regularisation passed to the point-cloud
            geometry. ``None`` keeps the library default.

    Returns:
        Scalar divergence. Gradients flow to ``mu`` and ``cov`` through the
        Langevin evolution.
    """
    evolved = evolve_langevin(samples, mu, cov, eps=eps, n_evolution=n_evolution, seed=seed)

    geom_kwargs = {} if sinkhorn_epsilon is None else {"epsilon": sinkhorn_epsilon}
    out = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, evolved, samples, **geom_kwargs
    )
    # NOTE(review): the return type differs across OTT releases. It is either a
    # SinkhornDivergenceOutput NamedTuple whose first field is ``divergence``, or
    # a ``(divergence, output)`` pair. Index 0 is the divergence in both cases.
    # Verify against the installed version.
    return out[0] if isinstance(out, tuple) else out.divergence

In [None]:
def plot_distributions(samples, evolved, epoch):
    """Overlay original and Langevin-evolved samples for a visual comparison.

    1-D data is drawn as overlaid density histograms and 2-D data as overlaid
    scatter plots. Other dimensions are unsupported: a message is printed and
    no figure is created.

    Args:
        samples: original samples, shape (n,) or (n, d).
        evolved: evolved samples with the same dimension as ``samples``.
        epoch: value shown in the figure title.
    """
    d = samples.shape[1] if samples.ndim > 1 else 1

    # Check the dimension before creating a figure. The original created the
    # figure first and then returned, leaving an empty canvas behind.
    if d not in (1, 2):
        print(f"Plot not implemented for dimension d = {d}")
        return

    samples_np = np.asarray(samples)
    evolved_np = np.asarray(evolved)

    fig, ax = plt.subplots(figsize=(6, 5))

    if d == 1:
        ax.hist(samples_np.ravel(), bins=50, alpha=0.5, label='Original', density=True)
        ax.hist(evolved_np.ravel(), bins=50, alpha=0.5, label='Evolved', density=True)
        ax.set_xlabel("x")
        ax.set_ylabel("density")
    else:
        ax.scatter(samples_np[:, 0], samples_np[:, 1], alpha=0.3, label='Original', s=10)
        ax.scatter(evolved_np[:, 0], evolved_np[:, 1], alpha=0.3, label='Evolved', s=10)
        ax.set_xlabel("x1")
        ax.set_ylabel("x2")

    ax.legend()
    ax.set_title(f"epoch {epoch}")
    ax.grid(True)
    fig.tight_layout()
    plt.show()


In [None]:
def lm_optimize(samples, n_epochs=1000, lr=1e-2, seed=10, eps=1e-2, n_evolution=10, plot_every=500):
    """Fit a Gaussian (mean, covariance) to ``samples`` by minimising ``lm_loss`` with Adam.

    The covariance is parameterised as ``A.T @ A``, so it stays positive
    semi-definite throughout optimisation.

    Args:
        samples: training data of shape (n, d).
        n_epochs: number of Adam steps.
        lr: Adam learning rate.
        seed: seeds the random initialisation of ``mu`` and ``A``. ``seed + 1``
            seeds the diagnostic Langevin evolution.
        eps: Langevin step size used ONLY for the diagnostic plots. The loss
            calls ``lm_loss(samples, mu, cov)`` with that function's own settings.
        n_evolution: number of Langevin steps for the diagnostic plots (see ``eps``).
        plot_every: print and plot diagnostics every ``plot_every`` epochs, and
            always at the last epoch. Values <= 0 keep only the final report.

    Returns:
        ``(mu, cov)``: the fitted mean of shape (d,) and covariance of shape (d, d).
    """
    _, d = samples.shape

    key = jax.random.PRNGKey(seed)
    key_mu, key_A = jax.random.split(key)
    params = {
        "mu": jax.random.normal(key_mu, shape=(d,)),
        "A": jax.random.normal(key_A, shape=(d, d)),
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    def loss_fn(params, samples):
        A = params["A"]
        cov = A.T @ A  # PSD by construction
        return lm_loss(samples, params["mu"], cov)

    # Build the gradient/update step once and compile it. The original
    # re-wrapped loss_fn with value_and_grad and ran it eagerly on every epoch.
    @jax.jit
    def train_step(params, opt_state, samples):
        loss_val, grads = jax.value_and_grad(loss_fn)(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state, loss_val

    for epoch in range(n_epochs):
        # loss_val is evaluated at the parameters *before* this epoch's update.
        params, opt_state, loss_val = train_step(params, opt_state, samples)

        periodic = plot_every > 0 and epoch % plot_every == 0
        if periodic or epoch == n_epochs - 1:
            mu_val = np.asarray(params["mu"])
            A_val = np.asarray(params["A"])
            cov_val = A_val.T @ A_val
            print(f"epoch {epoch} | loss = {float(loss_val):.6f} | mean = {mu_val} | cov =\n{cov_val}")

            evolved = evolve_langevin(samples, params["mu"], cov_val, eps=eps, n_evolution=n_evolution, seed=seed + 1)
            plot_distributions(samples, evolved, epoch)

    final_A = params["A"]
    return params["mu"], final_A.T @ final_A

In [None]:
# Experiment configuration: dimension, data size and optimiser budget.
d = 2
N_SAMPLES = 500
N_EPOCHS = 15_000
LEARNING_RATE = 1e-2

# Draw a random ground-truth Gaussian, then sample training data from it.
mu_true, cov_true = dg.generate_gaussian_params(d, sigma_mu=10, sigma_cov=2, seed=0)
samples = dg.generate_gaussian_data(mu_true, cov_true, n_samples=N_SAMPLES, seed=1)
print(mu_true, "\n\n", cov_true, "\n\n\n\n")

mu_hat, cov_hat = lm_optimize(samples, n_epochs=N_EPOCHS, lr=LEARNING_RATE)

# Compare the ground truth with the fitted parameters.
print("\n--------- results: ---------")
for _label, _value in (
    ("mu true:     ", mu_true),
    ("mu hat:  ", mu_hat),
    ("cov true:\n", cov_true),
    ("cov hat:\n", cov_hat),
):
    print(_label, _value)