In [18]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [19]:
def score_matching_loss(theta, samples):
    """Score-matching style loss for a single location parameter.

    Only the first column of ``samples`` and the first entry of ``theta`` are
    read. The result is half the mean squared deviation of that column from
    ``theta[0]``, shifted by the constant ``-1.0``.

    NOTE(review): this looks like the Hyvarinen objective for a 1-D Gaussian
    with unit variance -- confirm against the intended model.

    Args:
        theta: Indexable parameter array; ``theta[0]`` is the location estimate.
        samples: Array with at least two axes; column 0 holds the observations.

    Returns:
        Scalar loss value.
    """
    deviation = samples[:, 0] - theta[0]
    mean_sq_deviation = jnp.mean(jnp.square(deviation))
    # The -1.0 offset does not depend on theta, so it leaves gradients untouched.
    return 0.5 * mean_sq_deviation - 1.0

In [20]:
def score_matching_step(params, opt_state, samples, optimizer):
    """Run one optimizer update on the score-matching loss.

    Args:
        params: Current parameters; passed as the first argument of
            ``score_matching_loss`` and differentiated against.
        opt_state: Optimizer state that belongs to ``optimizer``.
        samples: Observations forwarded unchanged to ``score_matching_loss``.
        optimizer: Object exposing ``update(grads, state, params)`` in the
            optax gradient-transformation style.

    Returns:
        Tuple ``(params, opt_state, loss_val)``: the updated parameters, the
        updated optimizer state, and the loss evaluated at the *incoming*
        parameters (i.e. before this update was applied).
    """
    loss_val, grads = jax.value_and_grad(score_matching_loss)(params, samples)
    # Hand the current params to the optimizer as well: transformations such as
    # weight decay require them, while plain Adam simply ignores the argument.
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss_val

In [21]:
def optimize_score_matching(samples, theta_init, n_steps=1000, lr=1e-2, log_every=100):
    """Fit the scalar parameter by minimising the score-matching loss with Adam.

    Args:
        samples: Observations forwarded to ``score_matching_step``.
        theta_init: Starting value of the parameter. Integers are accepted and
            converted to floating point.
        n_steps: Number of optimizer steps to run.
        lr: Learning rate handed to ``optax.adam``.
        log_every: Print progress every ``log_every`` steps (starting at step
            0). Pass ``0`` or ``None`` to silence logging.

    Returns:
        The fitted parameter as a scalar (first entry of the parameter array).
    """
    # Force a floating dtype: an integer ``theta_init`` would otherwise build an
    # integer array, which ``jax.value_and_grad`` refuses to differentiate.
    params = jnp.asarray([theta_init], dtype=float)
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state, loss_val = score_matching_step(params, opt_state, samples, optimizer)

        # Note: ``loss_val`` was evaluated before this step's update, whereas
        # ``params`` already includes it.
        if log_every and t % log_every == 0:
            print(f"step {t} | loss = {loss_val:.6f} | theta = {params[0]:.4f}")

    return params[0]

In [23]:
# Problem setup. ``d`` drives both the parameter draw and the identity
# covariance below; keep it at 1, since ``score_matching_loss`` only reads
# column 0 of the samples and entry 0 of theta.
d = 1
n_samples = 1000

# The second value returned by ``generate_gaussian_params`` is not needed here:
# the data are generated with an identity covariance instead.
mu_true, _ = dg.generate_gaussian_params(d=d, sigma_mu=5.0, sigma_cov=0.0, seed=0)
samples = dg.generate_gaussian_data(mu_true, jnp.eye(d), n_samples=n_samples, seed=1)

# Start the optimisation away from the true value and run Adam.
theta_init = -5.0
theta_hat = optimize_score_matching(samples, theta_init, n_steps=1000, lr=0.1)

print(f"estimated theta = {theta_hat:.4f}, true theta = {mu_true[0]:.4f}")


step 0 | loss = 49.431808 | theta = -4.9000
step 100 | loss = 2.018369 | theta = 2.7982
step 200 | loss = -0.475407 | theta = 4.8425
step 300 | loss = -0.487354 | theta = 4.9902
step 400 | loss = -0.487356 | theta = 4.9919
step 500 | loss = -0.487356 | theta = 4.9919
step 600 | loss = -0.487356 | theta = 4.9919
step 700 | loss = -0.487356 | theta = 4.9919
step 800 | loss = -0.487356 | theta = 4.9919
step 900 | loss = -0.487356 | theta = 4.9919
estimated theta = 4.9919, true theta = 5.0201
