In [2]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [3]:
def mpf_loss_per_sample(sigma, h, J, beta):

    delta_e = isg.flips_energy_diff(sigma, h, J, beta)
    
    return jnp.sum(jnp.exp(-0.5 * delta_e))

In [4]:
def mpf_loss(samples, h, J, beta):

    return jnp.mean(jax.vmap(mpf_loss_per_sample, in_axes=(0, None, None, None))(samples, h, J, beta))

In [5]:
def mpf_gradients(samples, h, J, beta):

    grad_loss = jax.grad(mpf_loss, argnums=(1, 2))
    
    return grad_loss(samples, h, J, beta)

In [6]:
def mpf_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = mpf_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(jnp.diag(J))
    params["J"] = J

    return params, opt_state

In [7]:
def optimize_mpf(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = mpf_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = mpf_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [None]:
d = 3

n_samples = 100000

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_samples, h=h, J=J, n_steps=1000, beta=1, seed=42)

step 0:[0.13600000739097595, -0.12300000339746475, -0.17000000178813934]
[0.843000054359436, -0.7590000629425049, -0.64000004529953]
0.13300001621246338

step 500:[0.9440000653266907, -0.9040000438690186, -0.6480000019073486]
[0.9830000400543213, -0.9410000443458557, -0.6610000133514404]
0.0036666791420429945



In [9]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722]
True J: [[ 0.         -1.1946154  -0.47133768]
 [-1.1946154   0.         -0.44069302]
 [-0.47133768 -0.44069302  0.        ]]


In [10]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_mpf(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

step 0 | loss = 2.599569
step 100 | loss = 1.311869
step 200 | loss = 1.311852
step 300 | loss = 1.311852
step 400 | loss = 1.311852
step 500 | loss = 1.311852
step 600 | loss = 1.311852
step 700 | loss = 1.311852
step 800 | loss = 1.311852
step 900 | loss = 1.311852
estimated h: [ 1.005      -0.892      -0.74300003]
estimated J: [[ 0.         -1.1910001  -0.47400004]
 [-1.1910001   0.         -0.43600002]
 [-0.47400004 -0.43600002  0.        ]]
