In [17]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

import matplotlib.pyplot as plt

In [18]:
def potential_grad(samples, mu, cov):
    """Return ``cov^{-1} (x - mu)`` for every row ``x`` of ``samples``.

    This is the gradient of the quadratic potential
    ``0.5 * (x - mu)^T cov^{-1} (x - mu)``. The linear system is solved
    directly instead of forming the inverse of ``cov``.

    Args:
        samples: Points to evaluate at; callers in this notebook pass an
            ``(n, d)`` array with one sample per row.
        mu: Mean vector, broadcast against the rows of ``samples``.
        cov: ``(d, d)`` covariance matrix.

    Returns:
        Array with the same layout as ``samples`` holding one gradient per row.
    """
    # solve() treats the columns of its right-hand side as separate systems,
    # so the centred samples are transposed going in and coming out.
    centered_cols = jnp.transpose(samples - mu)
    solved_cols = jnp.linalg.solve(cov, centered_cols)
    return jnp.transpose(solved_cols)

In [19]:
def evolve_langevin(samples, mu, cov, eps=1e-2, n_evolution=1, seed=0):
    """Apply ``n_evolution`` Langevin steps to ``samples``.

    Each step performs
    ``x <- x - eps * potential_grad(x, mu, cov) + sqrt(2 * eps) * xi``
    with ``xi`` drawn from a standard normal of the same shape as ``x``.

    Args:
        samples: Starting points; returned unchanged when ``n_evolution`` is 0.
        mu: Mean passed through to ``potential_grad``.
        cov: Covariance passed through to ``potential_grad``.
        eps: Step size of the update.
        n_evolution: Number of steps to take.
        seed: Integer used to build the PRNG key, so repeated calls with the
            same seed draw the same noise.

    Returns:
        The evolved samples, same shape as ``samples``.
    """
    rng = jax.random.PRNGKey(seed)
    noise_scale = jnp.sqrt(2 * eps)
    state = samples

    for _step in range(n_evolution):
        # Consume a fresh subkey per step so the noise draws are independent.
        rng, step_rng = jax.random.split(rng)
        drift = potential_grad(state, mu, cov)
        xi = jax.random.normal(step_rng, shape=state.shape)
        state = state - eps * drift + noise_scale * xi

    return state

In [20]:
def lm_loss(params, samples):
    """Score-matching-style loss evaluated on Langevin-evolved samples.

    The model is a Gaussian with mean ``params["mu"]`` and precision
    ``P = L @ L.T`` where ``L = params["L"]``. ``samples`` are first passed
    through ``evolve_langevin`` under the model's own parameters (using that
    function's default step size, step count and seed), and the value

        0.5 * mean_i[(x_i - mu)^T P^2 (x_i - mu)] - trace(P)

    is returned, where ``x_i`` are the evolved samples.

    Args:
        params: Mapping with entries ``"mu"`` and ``"L"``.
        samples: Data array with one sample per row.

    Returns:
        Scalar loss.
    """
    mean, factor = params["mu"], params["L"]
    prec = factor @ factor.T

    # evolve_langevin is parameterised by a covariance, so invert the precision.
    moved = evolve_langevin(samples, mean, jnp.linalg.inv(prec))

    resid = moved - mean
    prec_sq = prec @ prec
    # Row-wise quadratic form resid_i^T P^2 resid_i.
    per_sample = jnp.sum((resid @ prec_sq) * resid, axis=1)

    return 0.5 * jnp.mean(per_sample) - jnp.trace(prec)

In [21]:
def optimize_score_matching(samples, n_steps=1000, lr=1e-2, seed=0, log_every=20):
    """Fit a Gaussian mean and precision factor by minimising ``lm_loss`` with Adam.

    The precision is parameterised as ``L @ L.T``. ``L`` is initialised as a
    slightly perturbed identity and is projected back onto lower-triangular
    matrices after every optimiser step.

    Args:
        samples: ``(n, d)`` data array; ``d`` is read from ``samples.shape[1]``.
        n_steps: Number of Adam updates to perform.
        lr: Adam learning rate.
        seed: Integer seed for the random initialisation of ``mu`` and ``L``.
        log_every: Print the loss every ``log_every`` steps (starting at step
            0). Pass ``0`` or ``None`` to silence logging.

    Returns:
        Dict with keys ``"mu"`` (shape ``(d,)``) and ``"L"`` (shape ``(d, d)``,
        lower triangular).
    """
    d = samples.shape[1]
    key_mu, key_L = jax.random.split(jax.random.PRNGKey(seed))

    mu_init = jax.random.normal(key_mu, shape=(d,))
    # Start near the identity and keep only the lower triangle.
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # Compile the loss/gradient once; without jit every step re-traces the
    # computation and dispatches it op by op.
    loss_grad_fn = jax.jit(jax.value_and_grad(lm_loss))

    for step in range(n_steps):
        loss_val, grads = loss_grad_fn(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)

        # The gradient is dense, so re-project L onto lower-triangular form.
        params["L"] = jnp.tril(params["L"])

        if log_every and step % log_every == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

    return params

In [22]:
# Ground-truth parameters and a synthetic dataset from the project's data_gen
# helpers. NOTE(review): the helpers are not visible here; presumably `mu` has
# shape (d,), `cov` shape (d, d) and `samples` shape (n_samples, d) — confirm.
mu, cov = dg.generate_gaussian_params(d=1, sigma_mu=0.1, sigma_cov=0.2, seed=0)
samples = dg.generate_gaussian_data(mu, cov, n_samples=500, seed=0)

# Fit the model with 300 Adam steps; the optimiser prints the loss as it runs.
params_hat = optimize_score_matching(samples, n_steps=300, lr=1e-2)

# The optimiser returns the mean and the factor L of the precision (L @ L.T);
# invert the precision to get a covariance comparable with the true `cov`.
mu_hat = params_hat["mu"]
precision_hat = params_hat["L"] @ params_hat["L"].T
cov_hat = jnp.linalg.inv(precision_hat)

Step    0 | Loss: -0.464309
Step   20 | Loss: -0.671532
Step   40 | Loss: -1.050031
Step   60 | Loss: -1.660648
Step   80 | Loss: -2.131709
Step  100 | Loss: -2.161872
Step  120 | Loss: -2.167278
Step  140 | Loss: -2.168020
Step  160 | Loss: -2.168081
Step  180 | Loss: -2.168081
Step  200 | Loss: -2.168082
Step  220 | Loss: -2.168082
Step  240 | Loss: -2.168082
Step  260 | Loss: -2.168082
Step  280 | Loss: -2.168082


In [23]:
# Print arrays with 4 decimals and without scientific notation.
jnp.set_printoptions(precision=4, suppress=True)

# Each line shows: true value, estimate, and the norm of their difference
# (jnp.linalg.norm with its default `ord`).
print(jnp.round(mu, 4), "\n\n", jnp.round(mu_hat, 4), jnp.linalg.norm(mu - mu_hat), "\n\n\n")
print(jnp.round(cov, 4), "\n\n", jnp.round(cov_hat, 4), jnp.linalg.norm(cov - cov_hat), "\n\n\n")

[0.1004] 

 [0.0837] 0.016668439 



[[0.2486]] 

 [[0.22]] 0.02860327 



