In [45]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [46]:
def grad_log_gaussian(x, mu, L):

    cov = L @ L.T
    delta = x - mu
    
    return -jnp.linalg.solve(cov, delta)

In [47]:
def langevin_step(x, key, mu, L, step_size=1e-1):

    noise = jax.random.normal(key, shape=x.shape)
    grad = grad_log_gaussian(x, mu, L)
    
    return x + step_size * grad + jnp.sqrt(2 * step_size) * noise

In [48]:
def evolve_one_sample(x, keys_i, mu, L, step_size, n_steps):

    for t in range(n_steps):
        x = langevin_step(x, keys_i[t], mu, L, step_size)

    return x

In [49]:
def evolve_samples(params, samples, n_evolution=1, step_size=1e-1, seed=0):

    mu = params["mu"]
    L = params["L"]
    n_samples, d = samples.shape

    key = jax.random.PRNGKey(seed)
    keys = jax.random.split(key, n_samples * n_evolution).reshape((n_samples, n_evolution, 2))

    new_samples = []
    for i in range(n_samples):
        x_i = samples[i]
        keys_i = keys[i]
        x_new = evolve_one_sample(x_i, keys_i, mu, L, step_size, n_evolution)
        new_samples.append(x_new)

    return jnp.stack(new_samples)

In [50]:
def rbf_kernel(first_entry, second_entry, sigma=1.0):

    diff = second_entry - first_entry
    
    return jnp.exp(-jnp.linalg.norm(diff)/(sigma**2))

In [51]:
def compute_term(samples_1, samples_2, sigma = 1.0):

    n_samples = len(samples_1)

    s = 0

    for i in range(n_samples):
        for j in range(n_samples):
            s += rbf_kernel(samples_1[i], samples_2[j], sigma)
        
    return s

In [52]:
def mmd_loss(samples, evolved_samples, sigma = 1.0):

    k_xx = compute_term(samples, samples, sigma)
    k_yy = compute_term(evolved_samples, evolved_samples, sigma)
    k_xy = compute_term(samples, evolved_samples, sigma)

    n_data = len(samples)

    mmd = (1/(n_data**2)) * (k_xx + k_yy - 2 * k_xy)
    print(mmd)

    return mmd

In [53]:
def langevin_matching_loss(params, samples):

    evolved_samples = evolve_samples(params, samples)
    loss = mmd_loss(samples, evolved_samples)
    
    return loss

In [None]:
def optimize_langevin_matching(samples, n_steps=1000, lr=1e-2, seed=0):

    d = samples.shape[1]
    key = jax.random.PRNGKey(seed)
    key_mu, key_L = jax.random.split(key)

    mu_init = jax.random.normal(key_mu, shape=(d,))
    L_init = jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d))

    L_init = jnp.tril(L_init)

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    loss_grad_fn = jax.value_and_grad(langevin_matching_loss)

    for step in range(n_steps):
        loss_val, grads = loss_grad_fn(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        params["L"] = jnp.tril(params["L"])

        if step % 2 == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

    return params

In [None]:
mu, cov = dg.generate_gaussian_params(d=5, sigma_mu=0.1, sigma_cov=0.2, seed=0)
samples = dg.generate_gaussian_data(mu, cov, n_samples=50, seed=1)

params_hat = optimize_langevin_matching(samples, n_steps=1000, lr=1e-2)

mu_hat = params_hat["mu"]
precision_hat = params_hat["L"] @ params_hat["L"].T
cov_hat = jnp.linalg.inv(precision_hat)

Traced<float32[]>with<JVPTrace> with
  primal = Array(0.03633267, dtype=float32)
  tangent = Traced<float32[]>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=60658, in_tracers=(Traced<float32[]:JaxprTrace>, Traced<float32[]:JaxprTrace>), out_tracer_refs=[<weakref at 0x31ef54590; to 'jax._src.interpreters.partial_eval.JaxprTracer' at 0x31ef54550>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { [34;1mlambda [39;22m; a[35m:f32[][39m b[35m:f32[][39m. [34;1mlet[39;22m c[35m:f32[][39m = mul b a [34;1min [39;22m(c,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue), 'in_layouts': (None, None), 'out_shardings': (UnspecifiedValue,), 'out_layouts': (None,), 'donated_invars': (False, False), 'ctx_mesh': None, 'name': 'multiply', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x31ef49bd0>, ctx=JaxprEqnContext(co

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x110a6da90>>
Traceback (most recent call last):
  File "/Users/lucaraffo/CFM/cfm_env/lib/python3.13/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(
KeyboardInterrupt: 


Traced<float32[]>with<JVPTrace> with
  primal = Array(nan, dtype=float32)
  tangent = Traced<float32[]>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=60658, in_tracers=(Traced<float32[]:JaxprTrace>, Traced<float32[]:JaxprTrace>), out_tracer_refs=[<weakref at 0x31ff5be20; to 'jax._src.interpreters.partial_eval.JaxprTracer' at 0x31ff5bde0>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { [34;1mlambda [39;22m; a[35m:f32[][39m b[35m:f32[][39m. [34;1mlet[39;22m c[35m:f32[][39m = mul b a [34;1min [39;22m(c,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue), 'in_layouts': (None, None), 'out_shardings': (UnspecifiedValue,), 'out_layouts': (None,), 'donated_invars': (False, False), 'ctx_mesh': None, 'name': 'multiply', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 0x31ff522c0>, ctx=JaxprEqnContext(compute_t

In [None]:
print(mu, "\n\n", mu_hat, "\n\n\n")
print(cov, "\n\n", cov_hat, "\n\n\n")