In [36]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [37]:
def sm_loss(params, samples):
    """Score-matching objective for a Gaussian with precision ``L @ L.T``.

    With ``P = L @ L.T`` and ``mu`` taken from ``params``, the returned scalar is
    ``0.5 * mean_i[(x_i - mu)^T P^2 (x_i - mu)] - trace(P)``.

    Args:
        params: Mapping with entries ``"mu"`` (the mean) and ``"L"`` (the factor
            from which the precision matrix is built).
        samples: Array whose rows are observations; the quadratic form is
            reduced over axis 1, giving one value per row.

    Returns:
        The scalar loss value.
    """
    factor = params["L"]
    precision = factor @ factor.T
    precision_sq = precision @ precision

    deviations = samples - params["mu"]
    # Row-wise quadratic form (x - mu)^T P^2 (x - mu), one entry per sample.
    per_sample = ((deviations @ precision_sq) * deviations).sum(axis=1)

    return 0.5 * per_sample.mean() - jnp.trace(precision)

In [38]:
def optimize_score_matching(samples, n_steps=1000, lr=1e-2, seed=0, log_every=20):
    """Fit a Gaussian mean and lower-triangular precision factor by score matching.

    Minimises ``sm_loss`` with Adam, starting from a random mean and a
    near-identity lower-triangular factor. After every update the factor is
    projected back onto lower-triangular matrices with ``jnp.tril``.

    Args:
        samples: 2-D array-like; ``samples.shape[1]`` is used as the dimension ``d``.
        n_steps: Number of Adam updates to run.
        lr: Adam learning rate.
        seed: Seed for the PRNG key used to draw the initial parameters.
        log_every: Print the loss every this many steps (the value printed is the
            loss evaluated before that step's update). ``None`` or ``0``
            disables logging. Defaults to 20.

    Returns:
        Dict with keys ``"mu"`` (shape ``(d,)``) and ``"L"`` (shape ``(d, d)``,
        lower triangular).
    """
    # Coerce once so the jitted step always receives an array argument.
    samples = jnp.asarray(samples)
    d = samples.shape[1]

    key_mu, key_L = jax.random.split(jax.random.PRNGKey(seed))
    mu_init = jax.random.normal(key_mu, shape=(d,))
    # Start from a slightly perturbed identity, restricted to the lower triangle.
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    loss_grad_fn = jax.value_and_grad(sm_loss)

    # Compile the whole update once instead of dispatching every op eagerly on
    # each iteration. `samples` is an explicit argument so it is traced rather
    # than baked into the compiled program as a constant.
    @jax.jit
    def train_step(params, opt_state, samples):
        loss_val, grads = loss_grad_fn(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # Projection step: zero out whatever the update wrote above the diagonal.
        params = {**params, "L": jnp.tril(params["L"])}
        return params, opt_state, loss_val

    for step in range(n_steps):
        params, opt_state, loss_val = train_step(params, opt_state, samples)

        if log_every and step % log_every == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

    return params

In [39]:
# Ground truth: parameters from the generator, with mean and covariance both
# scaled up by a factor of 10 before sampling.
mu, cov = dg.generate_gaussian_params(d=5, sigma_mu=0.1, sigma_cov=0.2, seed=0)
mu, cov = mu * 10, cov * 10

# Only n_samples=10 draws are requested for the fit.
samples = dg.generate_gaussian_data(mu, cov, n_samples=10, seed=90)

params_hat = optimize_score_matching(samples, n_steps=1000, lr=1e-2)

# Unpack the estimate: mean, precision L L^T, and covariance as its inverse.
mu_hat = params_hat["mu"]
precision_hat = jnp.matmul(params_hat["L"], params_hat["L"].T)
cov_hat = jnp.linalg.inv(precision_hat)

Step    0 | Loss: 0.040803
Step   20 | Loss: -2.703758
Step   40 | Loss: -3.844007
Step   60 | Loss: -5.027158
Step   80 | Loss: -6.266546
Step  100 | Loss: -7.700078
Step  120 | Loss: -9.402672
Step  140 | Loss: -11.275608
Step  160 | Loss: -13.190903
Step  180 | Loss: -15.032191
Step  200 | Loss: -16.705515
Step  220 | Loss: -18.142570
Step  240 | Loss: -19.304667
Step  260 | Loss: -20.190701
Step  280 | Loss: -20.818161
Step  300 | Loss: -21.236916
Step  320 | Loss: -21.493399
Step  340 | Loss: -21.641241
Step  360 | Loss: -21.720255
Step  380 | Loss: -21.760279
Step  400 | Loss: -21.779266
Step  420 | Loss: -21.787777
Step  440 | Loss: -21.791601
Step  460 | Loss: -21.788071
Step  480 | Loss: -21.793234
Step  500 | Loss: -21.794003
Step  520 | Loss: -21.794117
Step  540 | Loss: -21.794119
Step  560 | Loss: -21.794109
Step  580 | Loss: -21.794090
Step  600 | Loss: -21.794107
Step  620 | Loss: -21.791725
Step  640 | Loss: -21.793653
Step  660 | Loss: -21.794104
Step  680 | Loss: -21.

In [40]:
jnp.set_printoptions(precision=4, suppress=True)

# For the mean and then the covariance: true value, estimate, and the norm of
# their difference.
for reference, estimate in ((mu, mu_hat), (cov, cov_hat)):
    print(jnp.round(reference, 4), "\n\n", jnp.round(estimate, 4), jnp.linalg.norm(reference - estimate), "\n\n\n")

[ 1.004  -0.9063 -0.7482 -1.1714 -0.8712] 

 [ 0.6351 -0.1987 -0.5707 -1.265  -1.4019] 0.9790969 



[[ 4.443   1.6469 -0.4872  1.2632  1.6904]
 [ 1.6469  3.7204 -0.4334 -0.8381 -0.5858]
 [-0.4872 -0.4334  1.4431  0.122  -0.803 ]
 [ 1.2632 -0.8381  0.122   2.6474  0.5179]
 [ 1.6904 -0.5858 -0.803   0.5179  2.1397]] 

 [[ 3.7927  1.3043  0.1804  2.3854  1.4806]
 [ 1.3043  1.2011  0.2075  0.3584  0.2086]
 [ 0.1804  0.2075  0.5738  0.1043 -0.3786]
 [ 2.3854  0.3584  0.1043  2.908   0.8019]
 [ 1.4806  0.2086 -0.3786  0.8019  1.0992]] 4.228209 



