In [141]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [142]:
def sm_loss(params, samples):
    """Empirical score-matching objective for a Gaussian model.

    The precision matrix is built as ``Lambda = L @ L.T`` from the factor
    stored under ``params["L"]``, and the value returned is

        0.5 * mean_i[(x_i - mu)^T Lambda^2 (x_i - mu)] - trace(Lambda)

    Args:
        params: Mapping with entries ``"mu"`` (mean) and ``"L"`` (factor of
            the precision matrix).
        samples: Observations, one per row; the quadratic form is reduced
            along axis 1 (assumes shape ``(n_samples, d)`` -- confirm
            against the caller).

    Returns:
        Scalar loss value.
    """
    factor = params["L"]
    precision = factor @ factor.T
    residuals = samples - params["mu"]

    # Row-wise quadratic form r^T Lambda^2 r, evaluated for every sample.
    projected = residuals @ (precision @ precision)
    per_sample = jnp.sum(projected * residuals, axis=1)

    return 0.5 * jnp.mean(per_sample) - jnp.trace(precision)

In [143]:
def optimize_score_matching(samples, n_steps=1000, lr=1e-2, seed=0):
    """Minimise ``sm_loss`` over ``mu`` and a lower-triangular ``L`` with Adam.

    Args:
        samples: 2-D array-like of observations, one per row; the number of
            columns sets the dimension ``d`` of the parameters.
        n_steps: Number of Adam updates to perform.
        lr: Learning rate handed to ``optax.adam``.
        seed: Seed for the PRNG key used to draw the initial parameters.

    Returns:
        Dict with keys ``"mu"`` (shape ``(d,)``) and ``"L"`` (shape
        ``(d, d)``, lower triangular) after ``n_steps`` updates.

    Side effects:
        Prints the loss evaluated before the update on every 20th step.
    """
    # Accept NumPy arrays / nested lists as well as JAX arrays.
    samples = jnp.asarray(samples)
    d = samples.shape[1]

    key_mu, key_L = jax.random.split(jax.random.PRNGKey(seed))

    mu_init = jax.random.normal(key_mu, shape=(d,))
    # Identity plus small Gaussian noise, restricted to the lower triangle.
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    # Compile the full update once; without jit every step would dispatch
    # each op of the loss, gradient and optimiser update eagerly.
    @jax.jit
    def _step(params, opt_state, batch):
        loss_val, grads = jax.value_and_grad(sm_loss)(params, batch)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # Project back onto lower-triangular matrices after the Adam update.
        params = {**params, "L": jnp.tril(params["L"])}
        return params, opt_state, loss_val

    for step in range(n_steps):
        params, opt_state, loss_val = _step(params, opt_state, samples)

        if step % 20 == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

    return params

In [144]:
# Reference parameters and samples come from the project's data_gen helpers.
mu, cov = dg.generate_gaussian_params(d=5, sigma_mu=0.1, sigma_cov=0.2, seed=0)
samples = dg.generate_gaussian_data(mu, cov, n_samples=500, seed=0)

# Fit mu and the factor L by score matching.
params_hat = optimize_score_matching(samples, n_steps=300, lr=1e-2)

# Recover the estimated precision (L @ L.T) and invert it for the covariance.
mu_hat, L_hat = params_hat["mu"], params_hat["L"]
precision_hat = L_hat @ L_hat.T
cov_hat = jnp.linalg.inv(precision_hat)

Step    0 | Loss: -2.458913
Step   20 | Loss: -4.712379
Step   40 | Loss: -6.239665
Step   60 | Loss: -7.747056
Step   80 | Loss: -9.349491
Step  100 | Loss: -11.127122
Step  120 | Loss: -13.563088
Step  140 | Loss: -17.102310
Step  160 | Loss: -20.661938
Step  180 | Loss: -23.261869
Step  200 | Loss: -24.852919
Step  220 | Loss: -25.711205
Step  240 | Loss: -26.141741
Step  260 | Loss: -26.374632
Step  280 | Loss: -26.482742


In [145]:
jnp.set_printoptions(precision=4, suppress=True)

# For the mean and then the covariance: print the reference value, the
# estimate, and the norm of their difference.
for truth, estimate in ((mu, mu_hat), (cov, cov_hat)):
    print(jnp.round(truth, 4), "\n\n", jnp.round(estimate, 4), jnp.linalg.norm(truth - estimate), "\n\n\n")

[ 0.1004 -0.0906 -0.0748 -0.1171 -0.0871] 

 [ 0.0993 -0.105  -0.0623 -0.1255 -0.096 ] 0.022669934 



[[ 0.4443  0.1647 -0.0487  0.1263  0.169 ]
 [ 0.1647  0.372  -0.0433 -0.0838 -0.0586]
 [-0.0487 -0.0433  0.1443  0.0122 -0.0803]
 [ 0.1263 -0.0838  0.0122  0.2647  0.0518]
 [ 0.169  -0.0586 -0.0803  0.0518  0.214 ]] 

 [[ 0.361   0.1421 -0.0485  0.0959  0.1368]
 [ 0.1421  0.3628 -0.0479 -0.0788 -0.0702]
 [-0.0485 -0.0479  0.1501  0.015  -0.075 ]
 [ 0.0959 -0.0788  0.015   0.2575  0.0315]
 [ 0.1368 -0.0702 -0.075   0.0315  0.2057]] 0.1157184 



