In [76]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [77]:
def evolve_langevin(x0, key, mu, L, eps=1e-2, n_steps=10):
    """
    Evolve initial points x0 under Langevin dynamics defined by Gaussian potential.

    The potential is V(x) = 0.5 * (x - mu)^T P (x - mu) with P = L @ L.T, and
    every step applies the unadjusted (Euler-Maruyama) update

        x <- x - eps * grad V(x) + sqrt(2 * eps) * xi,    xi ~ N(0, I)

    Parameters:
        x0: [N, d] array of particles
        key: PRNGKey; it is split once per step to draw that step's noise
        mu: [d,] mean of target Gaussian
        L: [d, d] lower-triangular matrix such that L @ L.T = precision
        eps: step size
        n_steps: number of Langevin steps

    Returns:
        x: evolved particles [N, d] after n_steps updates
    """
    # Neither quantity depends on the particle state, so build them once up
    # front rather than inside the scanned step.
    precision = L @ L.T
    noise_scale = jnp.sqrt(2 * eps)

    def grad_V(x):
        # Gradient of the quadratic potential, applied row-wise to the batch.
        return (x - mu) @ precision.T

    def scan_fn(carry, _):
        x, key = carry
        key, subkey = jax.random.split(key)
        noise = jax.random.normal(subkey, shape=x.shape)
        x_new = x - eps * grad_V(x) + noise_scale * noise
        # Only the final state is returned to the caller, so emit no per-step
        # output; this keeps scan from stacking an [n_steps, N, d] trajectory.
        return (x_new, key), None

    (x_final, _), _ = jax.lax.scan(scan_fn, (x0, key), None, length=n_steps)
    return x_final


In [78]:
def sm_loss(params, samples):
    """
    Score-matching loss of a Gaussian model evaluated on a batch of samples.

    With precision P = L @ L.T the returned value is

        0.5 * mean_n[(x_n - mu)^T P^2 (x_n - mu)] - trace(P)

    Parameters:
        params: dict with entries "mu" ([d,] mean) and "L" ([d, d] factor of
            the precision matrix)
        samples: [N, d] array of data points

    Returns:
        Scalar loss value.
    """
    center, factor = params["mu"], params["L"]
    precision = factor @ factor.T

    residuals = samples - center
    precision_sq = precision @ precision
    # Row-wise quadratic form r^T P^2 r, one value per sample.
    per_sample = ((residuals @ precision_sq) * residuals).sum(axis=1)

    return 0.5 * jnp.mean(per_sample) - jnp.trace(precision)

In [79]:
def optimize_score_matching(samples, n_steps=1000, lr=1e-2, seed=0, n_enhance=200, eps=1e-2,
                            n_langevin_steps=10):
    """
    Fit a Gaussian (mean mu, precision L @ L.T) by minimising sm_loss with Adam,
    periodically enlarging the training set with Langevin-evolved samples.

    Parameters:
        samples: [N, d] array of training points
        n_steps: number of optimizer iterations
        lr: Adam learning rate
        seed: seed for parameter initialisation and Langevin noise
        n_enhance: enhance the sample set every n_enhance iterations; a falsy
            value (0 or None) disables enhancement
        eps: Langevin step size used during enhancement
        n_langevin_steps: number of Langevin steps applied per enhancement

    Returns:
        dict with the fitted "mu" ([d,]) and lower-triangular "L" ([d, d]).
    """
    d = samples.shape[1]
    key = jax.random.PRNGKey(seed)
    key_mu, key_L, key_langevin = jax.random.split(key, 3)

    # Random mean; L starts as a slightly perturbed identity, kept lower-triangular.
    mu_init = jax.random.normal(key_mu, shape=(d,))
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    loss_grad_fn = jax.value_and_grad(sm_loss)

    for step in range(n_steps):
        loss_val, grads = loss_grad_fn(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # Project back onto lower-triangular matrices after every update.
        params["L"] = jnp.tril(params["L"])

        if step % 20 == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

        # Enhance samples every n_enhance steps. A falsy n_enhance disables this
        # (and avoids a modulo-by-zero); the final iteration is skipped because
        # samples added there would never be used.
        is_last_step = (step + 1) >= n_steps
        if n_enhance and not is_last_step and (step + 1) % n_enhance == 0:
            key_langevin, subkey = jax.random.split(key_langevin)
            new_samples = evolve_langevin(
                samples, subkey, params["mu"], params["L"],
                eps=eps, n_steps=n_langevin_steps,
            )
            # Evolved copies are appended to the originals, so the sample set
            # doubles at every enhancement.
            samples = jnp.concatenate([samples, new_samples], axis=0)

    return params


In [None]:
def optimize_score_matching_2(samples, n_steps=1000, lr=1e-2, seed=0,
                            delta=1e-3, n_enhance=100, eps=1e-2,
                            n_langevin_steps=10):
    """
    Fit a Gaussian (mean mu, precision L @ L.T) by minimising sm_loss with Adam,
    enlarging the training set with Langevin-evolved samples once the
    parameters have nearly stopped moving.

    An enhancement is triggered when at least n_enhance iterations have passed
    since the previous one (or since the start) and the parameter change of
    the current iteration has norm below delta.

    Parameters:
        samples: [N, d] array of training points
        n_steps: number of optimizer iterations
        lr: Adam learning rate
        seed: seed for parameter initialisation and Langevin noise
        delta: threshold on the per-iteration parameter change
        n_enhance: minimum number of iterations between two enhancements
        eps: Langevin step size used during enhancement
        n_langevin_steps: number of Langevin steps applied per enhancement

    Returns:
        dict with the fitted "mu" ([d,]) and lower-triangular "L" ([d, d]).
    """
    d = samples.shape[1]
    key = jax.random.PRNGKey(seed)
    key_mu, key_L, key_langevin = jax.random.split(key, 3)

    # Random mean; L starts as a slightly perturbed identity, kept lower-triangular.
    mu_init = jax.random.normal(key_mu, shape=(d,))
    L_init = jnp.tril(jnp.eye(d) + 0.01 * jax.random.normal(key_L, shape=(d, d)))

    params = {"mu": mu_init, "L": L_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)
    loss_grad_fn = jax.value_and_grad(sm_loss)

    prev_params = params
    last_enhance_step = 0  # iteration at which the last enhancement happened

    for step in range(n_steps):
        loss_val, grads = loss_grad_fn(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        # Project back onto lower-triangular matrices after every update.
        params["L"] = jnp.tril(params["L"])

        if step % 20 == 0:
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

        # Test the cheap integer cooldown first: the parameter-change norm (and
        # the device-to-host sync needed to branch on it) is only evaluated
        # once an enhancement is actually allowed.
        if (step - last_enhance_step) >= n_enhance:
            # Size of this iteration's parameter change.
            delta_mu = jnp.linalg.norm(params["mu"] - prev_params["mu"])
            delta_L = jnp.linalg.norm(params["L"] - prev_params["L"])
            delta_theta = jnp.sqrt(delta_mu**2 + delta_L**2)

            # Enhance only when the parameters have (almost) stopped moving.
            if delta_theta < delta:
                key_langevin, subkey = jax.random.split(key_langevin)
                new_samples = evolve_langevin(
                    samples, subkey, params["mu"], params["L"],
                    eps=eps, n_steps=n_langevin_steps,
                )
                # Evolved copies are appended to the originals, so the sample
                # set doubles at every enhancement.
                samples = jnp.concatenate([samples, new_samples], axis=0)
                last_enhance_step = step
                print(f" ↳ Enhanced at step {step}, Δθ = {delta_theta:.2e}, total samples = {samples.shape[0]}")

        prev_params = params

    return params

In [81]:
# Reference parameters for the experiment (d=5); both are scaled by 10 below.
# NOTE(review): presumably dg.generate_gaussian_params returns a mean vector and
# a covariance matrix — confirm against data_gen.
mu, cov = dg.generate_gaussian_params(d=5, sigma_mu=0.1, sigma_cov=0.2, seed=0)
mu *= 10
cov *= 10
# Only 10 data points are requested for the fit.
samples = dg.generate_gaussian_data(mu, cov, n_samples=10, seed=90)

# Fit with the variant that enhances the sample set once parameters stabilise.
params_hat = optimize_score_matching_2(samples, n_steps=1000, lr=1e-2)

# The model is parameterised by the precision factor L, so the estimated
# covariance is obtained by inverting L @ L.T.
mu_hat = params_hat["mu"]
precision_hat = params_hat["L"] @ params_hat["L"].T
cov_hat = jnp.linalg.inv(precision_hat)

Step    0 | Loss: 0.040803
Step   20 | Loss: -2.703758
Step   40 | Loss: -3.844007
Step   60 | Loss: -5.027158
Step   80 | Loss: -6.266546
Step  100 | Loss: -7.700078
Step  120 | Loss: -9.402672
Step  140 | Loss: -11.275608
Step  160 | Loss: -13.190903
Step  180 | Loss: -15.032191
Step  200 | Loss: -16.705515
Step  220 | Loss: -18.142570
Step  240 | Loss: -19.304667
Step  260 | Loss: -20.190701
Step  280 | Loss: -20.818161
Step  300 | Loss: -21.236916
Step  320 | Loss: -21.493399
Step  340 | Loss: -21.641241
Step  360 | Loss: -21.720255
Step  380 | Loss: -21.760279
Step  400 | Loss: -21.779266
Step  420 | Loss: -21.787777
Step  440 | Loss: -21.791601
Step  460 | Loss: -21.788071
 ↳ Enhanced at step 467, Δθ = 1.00e-03, total samples = 20
Step  480 | Loss: -18.866550
Step  500 | Loss: -19.309046
Step  520 | Loss: -19.415136
Step  540 | Loss: -19.450861
Step  560 | Loss: -19.463461
Step  580 | Loss: -19.468405
Step  600 | Loss: -19.470299
 ↳ Enhanced at step 612, Δθ = 9.85e-04, total samp

In [82]:
# Print arrays with 4 decimals and without scientific notation.
jnp.set_printoptions(precision=4, suppress=True)

# Reference mean, fitted mean, then the norm of their difference.
print(jnp.round(mu, 4), "\n\n", jnp.round(mu_hat, 4), jnp.linalg.norm(mu - mu_hat), "\n\n\n")
# Reference covariance, fitted covariance, then the norm of their difference.
print(jnp.round(cov, 4), "\n\n", jnp.round(cov_hat, 4), jnp.linalg.norm(cov - cov_hat), "\n\n\n")

[ 1.004  -0.9063 -0.7482 -1.1714 -0.8712] 

 [ 0.6264 -0.247  -0.6261 -1.1498 -1.4398] 0.95704716 



[[ 4.443   1.6469 -0.4872  1.2632  1.6904]
 [ 1.6469  3.7204 -0.4334 -0.8381 -0.5858]
 [-0.4872 -0.4334  1.4431  0.122  -0.803 ]
 [ 1.2632 -0.8381  0.122   2.6474  0.5179]
 [ 1.6904 -0.5858 -0.803   0.5179  2.1397]] 

 [[ 3.9809  1.3209 -0.0283  2.5885  1.696 ]
 [ 1.3209  1.2494 -0.      0.3454  0.3211]
 [-0.0283 -0.      0.8584  0.07   -0.576 ]
 [ 2.5885  0.3454  0.07    2.9484  0.9746]
 [ 1.696   0.3211 -0.576   0.9746  1.3228]] 4.1126084 



