In [22]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [None]:
def score_matching_loss_boltzmann(params, samples):

    h = params["h"]
    J = params["J"]

    linear_terms = samples @ J.T + h
    squared_norms = jnp.sum(linear_terms ** 2, axis=1)

    loss = 0.5 * jnp.mean(squared_norms)
    # loss += jnp.trace(J)

    return loss

In [24]:
def compute_loss_and_grad(params, samples):

    loss_val, grads = jax.value_and_grad(score_matching_loss_boltzmann)(params, samples)
    
    return loss_val, grads

In [30]:
def opt_step(params, opt_state, samples, optimizer):

    loss_val, grads = compute_loss_and_grad(params, samples)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    params["J"] = (params["J"].T + params["J"])/2
    params["J"] = params["J"] - jnp.diag(params["J"])

    return params, opt_state, loss_val

In [31]:
def optimize_score_matching(samples, h_init, J_init, lr=1e-2, n_steps=1000):

    params = {"h": h_init, "J": J_init}
    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for step in range(n_steps):
        params, opt_state, loss_val = opt_step(params, opt_state, samples, optimizer)
        if step % 100 == 0:
            print(f"step {step} | loss: {loss_val:.6f}")

    return params["h"], params["J"]

In [32]:
d = 5

n_init = 100

n_replicas = 500

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_init, n_replicas, h=h, J=J, n_steps_equil=10000, n_steps_final=5000, n_prints = 5000, beta=1, seed=12)

step 0:[-0.10000000149011612, 0.020000001415610313, -0.24000000953674316, -0.2600000202655792, 0.020000001415610313]
[0.7350000143051147, -0.503000020980835, -0.5509999990463257, -0.8060000538825989, -0.6200000047683716]
0.2369999885559082

step 5000:[0.9800000190734863, -0.9800000190734863, -0.9800000190734863, 0.7200000286102295, -0.9800000190734863]
[0.9980000257492065, -1.0, -0.9720000624656677, 0.7860000133514404, -0.9980000257492065]
0.010799991898238659

step 10000:[0.9600000381469727, -0.9600000381469727, -0.8800000548362732, 0.7000000476837158, -0.9600000381469727]
[0.9980000257492065, -1.0, -0.9700000286102295, 0.7540000081062317, -0.9980000257492065]
0.0151999955996871



In [28]:
key = jax.random.PRNGKey(0)
h_init = jax.random.normal(key, shape=(d,))
J_init = jax.random.normal(key, shape=(d, d))
J_init = 0.5 * (J_init + J_init.T)

h_opt, J_opt = optimize_score_matching(samples, h_init, J_init, 1e-2, 3000)

Step 0 | Loss: 17.960045
Step 100 | Loss: -0.058899
Step 200 | Loss: -3.561385
Step 300 | Loss: -6.304195
Step 400 | Loss: -9.011396
Step 500 | Loss: -11.723924
Step 600 | Loss: -14.452819
Step 700 | Loss: -17.190744
Step 800 | Loss: -19.926098
Step 900 | Loss: -22.649315
Step 1000 | Loss: -25.354252
Step 1100 | Loss: -28.037573
Step 1200 | Loss: -30.697935
Step 1300 | Loss: -33.335300
Step 1400 | Loss: -35.950424
Step 1500 | Loss: -38.544601
Step 1600 | Loss: -41.119400
Step 1700 | Loss: -43.676487
Step 1800 | Loss: -46.217571
Step 1900 | Loss: -48.744244
Step 2000 | Loss: -51.257908
Step 2100 | Loss: -53.759758
Step 2200 | Loss: -56.250732
Step 2300 | Loss: -58.731461
Step 2400 | Loss: -61.202347
Step 2500 | Loss: -63.663517
Step 2600 | Loss: -66.114868
Step 2700 | Loss: -68.556381
Step 2800 | Loss: -70.987350
Step 2900 | Loss: -73.407257


In [29]:
print(h_opt, "\n\n", h, "\n\n\n")
print(J_opt, "\n\n", J, "\n\n\n")

[ 2.4387586 -1.9842377 -3.249127  -0.4370167 -2.3038075] 

 [ 1.0040143 -0.9063372 -0.7481722 -1.1713669 -0.8712328] 



[[-2.2221886e+01 -8.1820107e+00 -2.6123929e+00  1.5795812e-02
  -8.9816332e+00]
 [-8.1820107e+00 -2.6416790e+01  4.8081450e+00 -1.1404092e+00
   1.0640941e+01]
 [-2.6123929e+00  4.8081450e+00 -1.4792762e+01 -7.0579946e-01
   3.1199098e+00]
 [ 1.5795812e-02 -1.1404092e+00 -7.0579946e-01 -2.7112818e+00
  -7.3596281e-01]
 [-8.9816332e+00  1.0640941e+01  3.1199098e+00 -7.3596281e-01
  -2.5624599e+01]] 

 [[ 0.         -1.6071162   0.00798594  0.04462025 -0.93398416]
 [-1.6071162   0.          0.63058895 -1.2504177   0.90409863]
 [ 0.00798594  0.63058895  0.         -0.40897474  0.47866577]
 [ 0.04462025 -1.2504177  -0.40897474  0.         -0.5732635 ]
 [-0.93398416  0.90409863  0.47866577 -0.5732635   0.        ]] 



