In [55]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import jax
import jax.numpy as jnp
import data_gen as dg
import optax

In [56]:
def flip_spin(sigma, i):
    return sigma.at[i].set(-sigma[i])

In [57]:
def local_energy_diff(sigma, i, h, j):
    """
    Compute energy difference E(new_sigma) - E(sigma) when flipping spin i
    """
    delta = 2 * sigma[i] * (h[i] + jnp.dot(j[i], sigma) - j[i, i] * sigma[i])
    return delta

In [58]:
def mpf_loss_single_sample(sigma, h, j, beta=1.0):
    """
    Compute the MPF loss for a single sample sigma using an explicit loop
    """
    d = sigma.shape[0]
    loss = 0.0

    for i in range(d):
        delta_e = local_energy_diff(sigma, i, h, j)
        loss += jnp.exp(-0.5 * beta * delta_e)

    return loss

In [59]:
def compute_energy_differences_all_sites(sigma, h, j):
    """
    Return vector of energy differences for all single-spin flips of one configuration
    """
    d = sigma.shape[0]
    diffs = []

    for i in range(d):
        delta_e = local_energy_diff(sigma, i, h, j)
        diffs.append(delta_e)

    return jnp.array(diffs)

In [60]:
def mpf_loss_per_sample(sigma, h, j, beta):
    """
    MPF loss for one sample (no inner function, no lambda)
    """
    delta_e = compute_energy_differences_all_sites(sigma, h, j)
    return jnp.sum(jnp.exp(-0.5 * beta * delta_e))

In [61]:
def mpf_loss(samples, h, j, beta=1.0):
    """
    Average MPF loss over all samples
    """
    return jnp.mean(jax.vmap(mpf_loss_per_sample, in_axes=(0, None, None, None))(samples, h, j, beta))

In [62]:
def mpf_gradients(samples, h, j, beta=1.0):
    """
    Compute gradients of MPF loss w.r.t. h and j
    """
    grad_loss = jax.grad(mpf_loss, argnums=(1, 2))
    return grad_loss(samples, h, j, beta)

In [63]:
def mpf_step(params, opt_state, samples, optimizer, beta):
    grad_h, grad_j = mpf_gradients(samples, params["h"], params["j"], beta)

    grads = {
        "h": grad_h,
        "j": grad_j
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    j = params["j"]
    j = 0.5 * (j + j.T)
    j = j - jnp.diag(jnp.diag(j))
    params["j"] = j

    return params, opt_state

In [64]:
def optimize_mpf(samples, h_init, j_init, n_steps=1000, lr=1e-2, beta=1.0):
    params = {
        "h": h_init,
        "j": j_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = mpf_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = mpf_loss(samples, params["h"], params["j"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["j"]

In [68]:
d = 3
n_samples = 100000

h, j = dg.generate_log_concave_ising_params(d, sigma_h=1, sigma_j=0.5, seed=0)

samples = dg.generate_ising_data(n_samples, h=h, j=j, n_steps=1000, beta=1, seed=42)

step 0:[0.13100001215934753, -0.11100000888109207, -0.19500000774860382]
[0.8360000252723694, -0.7870000600814819, -0.6600000262260437]
0.14533334970474243

step 500:[0.9600000381469727, -0.9460000395774841, -0.7820000648498535]
[0.9850000739097595, -0.968000054359436, -0.7880000472068787]
0.0009999871253967285



In [69]:
print("True h:", h)
print("True J:", j)

True h: [ 1.0040143 -0.9063372 -0.7481722]
True J: [[ 0.         -1.225061   -0.35485688]
 [-1.225061    0.         -0.02253926]
 [-0.35485688 -0.02253926  0.        ]]


In [70]:
key = jax.random.PRNGKey(0)
h0 = 0.1 * jax.random.normal(key, shape=(d,))
j0 = 0.1 * jax.random.normal(key, shape=(d, d))
j0 = 0.5 * (j0 + j0.T)
j0 = j0 - jnp.diag(jnp.diag(j0))

h_est, j_est = optimize_mpf(samples, h0, j0, n_steps=1000, lr=0.1, beta=1.0)

print("Estimated h:", jnp.round(h_est, 3))
print("Estimated J:", jnp.round(j_est, 3))


step 0 | loss = 2.517921
step 100 | loss = 1.061966
step 200 | loss = 1.061922
step 300 | loss = 1.061922
step 400 | loss = 1.061922
step 500 | loss = 1.061922
step 600 | loss = 1.061922
step 700 | loss = 1.061923
step 800 | loss = 1.061922
step 900 | loss = 1.061923
Estimated h: [ 1.021      -0.901      -0.76600003]
Estimated J: [[ 0.         -1.22       -0.33800003]
 [-1.22        0.         -0.028     ]
 [-0.33800003 -0.028       0.        ]]
