In [110]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [111]:
def score_matching_loss_boltzmann(params, samples):

    h = params["h"]
    J = params["J"]

    linear_terms = samples @ J.T + h
    squared_norms = jnp.sum(linear_terms ** 2, axis=1)

    loss = 0.5 * jnp.mean(squared_norms) + jnp.trace(J)

    return loss


In [112]:
def compute_loss_and_grad(params, samples):

    loss_val, grads = jax.value_and_grad(score_matching_loss_boltzmann)(params, samples)
    
    return loss_val, grads

In [113]:
def opt_step(params, opt_state, samples, optimizer):

    loss_val, grads = compute_loss_and_grad(params, samples)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    params["J"] = (params["J"].T + params["J"])/2

    return params, opt_state, loss_val

In [114]:
def optimize_score_matching(samples, h_init, J_init, lr=1e-2, n_steps=1000):

    params = {"h": h_init, "J": J_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for step in range(n_steps):
        params, opt_state, loss_val = opt_step(params, opt_state, samples, optimizer)
        if step % 100 == 0:
            print(f"Step {step} | Loss: {loss_val:.6f}")

    return params["h"], params["J"]

In [115]:
d = 3
n_samples = 1000

h_true, J_true = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)
samples = dg.generate_ising_data(n_samples, h=h_true, J=J_true, n_steps=1000, beta=1, seed=42)

step 0:[0.10600000619888306, -0.13600000739097595, -0.20600001513957977]
[0.8520000576972961, -0.7360000610351562, -0.628000020980835]
0.09199999272823334

step 500:[0.9580000638961792, -0.906000018119812, -0.6320000290870667]
[0.9830000400543213, -0.9440000653266907, -0.6640000343322754]
0.015000025741755962



In [116]:
key = jax.random.PRNGKey(0)
h_init = jax.random.normal(key, shape=(d,))
J_init = jax.random.normal(key, shape=(d, d))
J_init = 0.5 * (J_init + J_init.T)

h_opt, J_opt = optimize_score_matching(samples, h_init, J_init, 1e-2, 1000)

Step 0 | Loss: 11.070559
Step 100 | Loss: 2.576656
Step 200 | Loss: 0.455281
Step 300 | Loss: -1.002035
Step 400 | Loss: -2.298135
Step 500 | Loss: -3.478536
Step 600 | Loss: -4.549845
Step 700 | Loss: -5.512217
Step 800 | Loss: -6.369025
Step 900 | Loss: -7.127464


In [117]:
print(h_opt, "\n\n", h_true, "\n\n\n")
print(J_opt, "\n\n", J_true, "\n\n\n")

[ 2.0818355 -2.0969296 -0.970043 ] 

 [ 1.0040143 -0.9063372 -0.7481722] 



[[-4.523883   -2.138714   -0.25920716]
 [-2.138714   -4.601226   -0.17933486]
 [-0.25920716 -0.17933486 -1.6713101 ]] 

 [[ 0.         -1.1946154  -0.47133768]
 [-1.1946154   0.         -0.44069302]
 [-0.47133768 -0.44069302  0.        ]] 



