In [31]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [32]:
def mpf_loss_per_sample(sigma, h, J, beta):

    delta_e = isg.flips_energy_diff(sigma, h, J, beta)
    
    return jnp.sum(jnp.exp(-0.5 * delta_e))

In [33]:
def mpf_loss(samples, h, J, beta):

    return jnp.mean(jax.vmap(mpf_loss_per_sample, in_axes=(0, None, None, None))(samples, h, J, beta))

In [34]:
def mpf_gradients(samples, h, J, beta):

    grad_loss = jax.grad(mpf_loss, argnums=(1, 2))
    
    return grad_loss(samples, h, J, beta)

In [35]:
def mpf_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = mpf_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(J)
    params["J"] = J

    return params, opt_state

In [36]:
def optimize_mpf(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = mpf_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = mpf_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [37]:
d = 5

n_init = 100

n_replicas = 500

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_init, n_replicas, h=h, J=J, n_steps_equil=10000, n_steps_final=1000, n_prints = 5000, beta=1, seed=12)

step 0:[-0.10000000149011612, 0.020000001415610313, -0.24000000953674316, -0.2600000202655792, 0.020000001415610313]
[0.7350000143051147, -0.503000020980835, -0.5509999990463257, -0.8060000538825989, -0.6200000047683716]
0.2369999885559082

step 5000:[0.9800000190734863, -0.9800000190734863, -0.9800000190734863, 0.7200000286102295, -0.9800000190734863]
[0.9980000257492065, -1.0, -0.9720000624656677, 0.7860000133514404, -0.9980000257492065]
0.010799991898238659

step 10000:[0.9600000381469727, -0.9600000381469727, -0.8800000548362732, 0.7000000476837158, -0.9600000381469727]
[0.9980000257492065, -1.0, -0.9700000286102295, 0.7540000081062317, -0.9980000257492065]
0.0151999955996871



In [38]:
print("true h:", h)
print("true J:", J)

true h: [ 1.0040143 -0.9063372 -0.7481722 -1.1713669 -0.8712328]
true J: [[ 0.         -1.6071162   0.00798594  0.04462025 -0.93398416]
 [-1.6071162   0.          0.63058895 -1.2504177   0.90409863]
 [ 0.00798594  0.63058895  0.         -0.40897474  0.47866577]
 [ 0.04462025 -1.2504177  -0.40897474  0.         -0.5732635 ]
 [-0.93398416  0.90409863  0.47866577 -0.5732635   0.        ]]


In [59]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_mpf(samples[850:900], h0, J0, n_steps=1000, lr=0.1, beta=1.0)

step 0 | loss = 3.263413
step 100 | loss = 0.550793
step 200 | loss = 0.547042
step 300 | loss = 0.545379
step 400 | loss = 0.544514
step 500 | loss = 0.544006
step 600 | loss = 0.543680
step 700 | loss = 0.543457
step 800 | loss = 0.543298
step 900 | loss = 0.543179


In [60]:
print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))
print("\n\n")
print("loss h:", jnp.linalg.norm(h - h_est), "loss J:", jnp.linalg.norm(J - J_est))

estimated h: [ 2.371      -1.9200001  -2.203      -0.53400004 -2.197     ]
estimated J: [[ 0.        -2.102     -2.094      0.393     -2.296    ]
 [-2.102      0.         2.065     -0.56       2.1950002]
 [-2.094      2.065      0.        -0.446      2.249    ]
 [ 0.393     -0.56      -0.446      0.        -0.356    ]
 [-2.296      2.1950002  2.249     -0.356      0.       ]]



loss h: 2.6790671 loss J: 5.2959337
