In [12]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [13]:
def score_matching_loss(params, samples):
    """Empirical score-matching objective for a univariate Gaussian model.

    With model score psi(x) = -(x - mu) / var, the per-sample objective
    0.5 * psi(x)**2 + d(psi)/dx equals 0.5 * ((x - mu)**2 / var**2 - 2 / var),
    which is what gets averaged here.

    Args:
        params: Pair that unpacks into ``(mu, log_var)``; the variance is
            parameterised through its logarithm, so it stays positive.
        samples: Array indexed as ``samples[:, 0]``; only the first column
            is used.

    Returns:
        Scalar loss averaged over the samples.
    """
    mu, log_var = params
    sigma2 = jnp.exp(log_var)
    residuals = samples[:, 0] - mu

    # Squared model score: (x - mu)^2 / sigma^4.
    squared_score = residuals**2 / sigma2**2
    # Twice the derivative of the score w.r.t. x; the factor 2 is undone by the 0.5 below.
    score_derivative = -2.0 / sigma2

    return 0.5 * jnp.mean(squared_score + score_derivative)

In [14]:
def score_matching_step(params, opt_state, samples, optimizer):
    """Apply a single optimizer update to the score-matching parameters.

    Args:
        params: Current parameters, passed through to ``score_matching_loss``.
        opt_state: Current optimizer state.
        samples: Data forwarded to ``score_matching_loss``.
        optimizer: Object exposing ``update(grads, opt_state)``; its updates
            are applied with ``optax.apply_updates``.

    Returns:
        Tuple ``(new_params, new_opt_state, loss_val)``. ``loss_val`` is the
        loss evaluated at the *incoming* ``params``, i.e. before the update.
    """
    loss_and_grad = jax.value_and_grad(score_matching_loss)
    loss_val, grads = loss_and_grad(params, samples)

    updates, next_opt_state = optimizer.update(grads, opt_state)
    next_params = optax.apply_updates(params, updates)

    return next_params, next_opt_state, loss_val

In [15]:
def optimize_score_matching(samples, mu_init, log_var_init, n_steps=1000, lr=1e-2):
    """Minimise the score-matching loss over (mu, log_var) with Adam.

    Args:
        samples: Data forwarded to ``score_matching_step`` at every iteration.
        mu_init: Starting value for the mean.
        log_var_init: Starting value for the log-variance.
        n_steps: Number of optimizer steps to run.
        lr: Learning rate handed to ``optax.adam``.

    Returns:
        Array ``[mu, log_var]`` after the last step (the initial values when
        ``n_steps`` is 0).
    """
    adam = optax.adam(lr)
    params = jnp.array([mu_init, log_var_init])
    state = adam.init(params)

    for t in range(n_steps):
        params, state, loss_val = score_matching_step(params, state, samples, adam)

        # Report progress only every 100th step.
        if t % 100 != 0:
            continue

        # loss_val was evaluated before this step's update, whereas mu and
        # log_var below are the already-updated parameters.
        mu, log_var = params
        print(f"step {t} | loss = {loss_val:.6f} | mu = {mu:.4f} | sigma^2 = {jnp.exp(log_var):.4f}")

    return params

In [17]:
# --- Data ---------------------------------------------------------------
# Draw a true mean via the project's data_gen helpers, then sample from a
# Gaussian with that mean and identity covariance.
# NOTE(review): the exact meaning of sigma_mu / sigma_cov is defined in data_gen — confirm there.
d = 1
n_samples = 1000
mu_true, _ = dg.generate_gaussian_params(d=d, sigma_mu=5.0, sigma_cov=0.0, seed=0)
samples = dg.generate_gaussian_data(mu_true, jnp.eye(d), n_samples=n_samples, seed=1)

# --- Initialisation -----------------------------------------------------
# Starting point for the optimizer; the variance is given on the log scale.
mu_init = -3.0
log_var_init = jnp.log(0.5)

# --- Optimisation -------------------------------------------------------
params_hat = optimize_score_matching(
    samples,
    mu_init,
    log_var_init,
    n_steps=5000,
    lr=0.1,
)
mu_hat, log_var_hat = params_hat

# Compare the estimate with the ground truth (true variance is 1 because the
# covariance passed to the sampler is the identity).
print(f"\nStima finale: mu = {mu_hat:.4f}, sigma^2 = {jnp.exp(log_var_hat):.4f}")
print(f"Valore vero: mu = {mu_true[0]:.4f}, sigma^2 = 1.0000")


step 0 | loss = 127.791931 | mu = -2.9000 | sigma^2 = 0.5526
step 100 | loss = 0.575678 | mu = -0.5892 | sigma^2 = 4.5099
step 200 | loss = 0.214989 | mu = -0.1169 | sigma^2 = 5.9693
step 300 | loss = 0.091996 | mu = 0.2438 | sigma^2 = 7.1376
step 400 | loss = 0.036084 | mu = 0.5438 | sigma^2 = 8.0773
step 500 | loss = 0.005830 | mu = 0.8074 | sigma^2 = 8.8311
step 600 | loss = -0.012640 | mu = 1.0481 | sigma^2 = 9.4257
step 700 | loss = -0.025046 | mu = 1.2740 | sigma^2 = 9.8782
step 800 | loss = -0.034121 | mu = 1.4908 | sigma^2 = 10.1991
step 900 | loss = -0.041325 | mu = 1.7026 | sigma^2 = 10.3951
step 1000 | loss = -0.047528 | mu = 1.9127 | sigma^2 = 10.4700
step 1100 | loss = -0.053312 | mu = 2.1241 | sigma^2 = 10.4249
step 1200 | loss = -0.059124 | mu = 2.3396 | sigma^2 = 10.2594
step 1300 | loss = -0.065375 | mu = 2.5620 | sigma^2 = 9.9711
step 1400 | loss = -0.072516 | mu = 2.7946 | sigma^2 = 9.5563
step 1500 | loss = -0.081128 | mu = 3.0407 | sigma^2 = 9.0098
step 1600 | loss