In [11]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [12]:
def score_matching_loss(params, samples):
    """Empirical score-matching objective for a Gaussian with precision ``L @ L.T``.

    Evaluates ``0.5 * mean_i(r_i^T Lambda^2 r_i) - trace(Lambda)`` where
    ``Lambda = L @ L.T`` and ``r_i = samples[i] - mu``. For a Gaussian model
    whose score is ``-Lambda (x - mu)``, the first term is half the mean
    squared norm of the score and the second is the divergence of the score.

    Args:
        params: Mapping with keys ``"mu"`` (mean vector) and ``"L"`` (square
            matrix factor of the precision).
        samples: 2-D array with one observation per row.

    Returns:
        Scalar loss value.
    """
    factor = params["L"]
    precision = factor @ factor.T
    residuals = samples - params["mu"]

    # Row-wise quadratic form r^T Lambda^2 r (squared norm of the model score).
    precision_sq = precision @ precision
    score_sq_norms = jnp.sum((residuals @ precision_sq) * residuals, axis=1)

    return 0.5 * jnp.mean(score_sq_norms) - jnp.trace(precision)

In [13]:
def optimize_score_matching(samples, n_steps=1000, lr=1e-2, seed=0, log_every=100):
    """Fit a Gaussian mean and precision factor by minimizing ``score_matching_loss`` with Adam.

    The precision matrix is parameterized as ``L @ L.T``; ``L`` is initialized
    lower-triangular and re-projected onto lower-triangular matrices after
    every optimizer update.

    Args:
        samples: 2-D array of shape ``(n_samples, d)``, one observation per row.
        n_steps: Number of Adam updates to perform.
        lr: Adam learning rate.
        seed: Seed for the PRNG key used to draw the initial parameters.
        log_every: Print the loss every ``log_every`` steps (starting at
            step 0). Pass ``0`` or ``None`` to disable logging.

    Returns:
        Dict with keys ``"mu"`` (shape ``(d,)``) and ``"L"`` (lower-triangular,
        shape ``(d, d)``) holding the fitted parameters.

    Raises:
        ValueError: If ``samples`` is not 2-dimensional.
    """
    samples = jnp.asarray(samples)
    if samples.ndim != 2:
        raise ValueError(
            f"samples must be a 2-D array of shape (n_samples, d); got shape {samples.shape}"
        )

    d = samples.shape[1]
    key_mu, key_L = jax.random.split(jax.random.PRNGKey(seed))

    # Start from a random mean and a slightly perturbed identity factor,
    # truncated to its lower triangle.
    mu_init = jax.random.normal(key_mu, shape=(d,))
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    loss_grad_fn = jax.value_and_grad(score_matching_loss)

    # Compile the whole update once instead of dispatching every primitive
    # from Python on each of the n_steps iterations.
    @jax.jit
    def train_step(params, opt_state, batch):
        loss_val, grads = loss_grad_fn(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # The gradient also touches the upper triangle of L; project it away
        # so L stays lower-triangular.
        params = {**params, "L": jnp.tril(params["L"])}
        return params, opt_state, loss_val

    for step in range(n_steps):
        # loss_val is the loss at the parameters *before* this step's update.
        params, opt_state, loss_val = train_step(params, opt_state, samples)

        if log_every and step % log_every == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

    return params

In [14]:
# Build a synthetic 5-dimensional problem. NOTE(review): presumably
# generate_gaussian_params returns the ground-truth mean and covariance and
# generate_gaussian_data draws samples from them -- confirm in data_gen.
mu, cov = dg.generate_gaussian_params(
    d=5,
    sigma_mu=0.1,
    sigma_cov=0.2,
    seed=0,
)
samples = dg.generate_gaussian_data(mu, cov, n_samples=5000, seed=1)

# Estimate the parameters by score matching.
params_hat = optimize_score_matching(samples, n_steps=1000, lr=1e-2)

# Rebuild the precision from its fitted factor, then invert it so the
# estimate can be compared against `cov` directly.
mu_hat = params_hat["mu"]
L_hat = params_hat["L"]
precision_hat = L_hat @ L_hat.T
cov_hat = jnp.linalg.inv(precision_hat)

Step    0 | Loss: -2.439819
Step  100 | Loss: -11.456070
Step  200 | Loss: -25.564577
Step  300 | Loss: -27.581137
Step  400 | Loss: -27.638535
Step  500 | Loss: -27.639832
Step  600 | Loss: -27.632603
Step  700 | Loss: -27.639864
Step  800 | Loss: -27.639862
Step  900 | Loss: -27.639833


In [15]:
# Print each reference quantity followed by its estimate: means first,
# then covariances.
for truth, estimate in ((mu, mu_hat), (cov, cov_hat)):
    print(truth, "\n\n", estimate, "\n\n\n")

[ 0.10040142 -0.09063372 -0.07481723 -0.11713669 -0.08712328] 

 [ 0.09848906 -0.10065973 -0.07100011 -0.12589708 -0.08577091] 



[[ 0.44429743  0.16468923 -0.04871666  0.1263152   0.16904442]
 [ 0.16468923  0.3720387  -0.04333964 -0.08381284 -0.05857849]
 [-0.04871666 -0.04333964  0.14431155  0.01220462 -0.08030083]
 [ 0.1263152  -0.08381284  0.01220462  0.2647391   0.05178968]
 [ 0.1690444  -0.05857849 -0.08030083  0.05178968  0.2139716 ]] 

 [[ 0.44445962  0.15949708 -0.04591568  0.12920672  0.1701761 ]
 [ 0.15949708  0.36336207 -0.04257188 -0.08282109 -0.05459233]
 [-0.04591568 -0.04257189  0.14183275  0.01025948 -0.07798707]
 [ 0.12920673 -0.08282107  0.01025947  0.2637142   0.05422248]
 [ 0.1701761  -0.05459231 -0.07798707  0.05422248  0.20998599]] 



