In [23]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [None]:
def score_matching_loss(params, samples):

    mu, var = params
    diffs = samples[:, 0] - mu
    term1 = (diffs**2) / (var**2)
    term2 = -2.0 / var
    loss = 0.5 * jnp.mean(term1 + term2)

    return loss

In [25]:
def score_matching_step(params, opt_state, samples, optimizer):
    
    loss_val, grads = jax.value_and_grad(score_matching_loss)(params, samples)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss_val

In [26]:
def optimize_score_matching(samples, mu_init, var_init, n_steps=1000, lr=1e-2):

    params = jnp.array([mu_init, var_init])
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state, loss_val = score_matching_step(params, opt_state, samples, optimizer)

        if t % 100 == 0:
            mu, var = params
            print(f"step {t} | loss = {loss_val:.6f} | mu = {mu:.4f} | sigma^2 = {var:.4f}")

    return params

In [40]:
d = 1
n_samples = 1000
mu_true, var_true = dg.generate_gaussian_params(d=1, sigma_mu=5.0, sigma_cov=1.0, seed=0)
samples = dg.generate_gaussian_data(mu_true, var_true, n_samples=n_samples, seed=1)

mu_init = -3.0
var_init = 0.5

params_hat = optimize_score_matching(samples, mu_init, var_init, n_steps=8000, lr=0.1)
mu_hat, var_hat = params_hat

print(f"\nStima finale: mu = {mu_hat:.4f}, sigma^2 = {var_hat:.4f}")
print(f"Valore vero: mu = {mu_true[0]:.4f}, sigma^2 = {var_true[0, 0]:.4f}")


step 0 | loss = 136.697784 | mu = -2.9000 | sigma^2 = 0.6000
step 100 | loss = 2.159341 | mu = 0.4500 | sigma^2 = 2.2630
step 200 | loss = 0.725074 | mu = 2.1308 | sigma^2 = 2.5123
step 300 | loss = 0.260046 | mu = 3.2936 | sigma^2 = 2.6411
step 400 | loss = 0.102932 | mu = 4.0402 | sigma^2 = 2.7220
step 500 | loss = 0.050101 | mu = 4.4826 | sigma^2 = 2.7865
step 600 | loss = 0.030001 | mu = 4.7246 | sigma^2 = 2.8470
step 700 | loss = 0.019179 | mu = 4.8476 | sigma^2 = 2.9074
step 800 | loss = 0.010966 | mu = 4.9062 | sigma^2 = 2.9684
step 900 | loss = 0.003728 | mu = 4.9325 | sigma^2 = 3.0299
step 1000 | loss = -0.002894 | mu = 4.9438 | sigma^2 = 3.0917
step 1100 | loss = -0.008986 | mu = 4.9484 | sigma^2 = 3.1533
step 1200 | loss = -0.014587 | mu = 4.9502 | sigma^2 = 3.2147
step 1300 | loss = -0.019733 | mu = 4.9509 | sigma^2 = 3.2756
step 1400 | loss = -0.024459 | mu = 4.9511 | sigma^2 = 3.3361
step 1500 | loss = -0.028799 | mu = 4.9512 | sigma^2 = 3.3960
step 1600 | loss = -0.03278