In [14]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [15]:
def log_pseudolikelihood_site(samples, i, h, J, beta=1.0):

    M = samples.shape[0]
    mus = jnp.arange(M)
    logps = isg.compute_logp_all_mu(mus, samples, i, h, J, beta)
    
    return jnp.mean(logps)

In [16]:
def total_log_pseudolikelihood(samples, h, J, beta=1.0):
    
    d = samples.shape[1]
    sites = jnp.arange(d)
    site_logps = jax.vmap(log_pseudolikelihood_site, in_axes=(None, 0, None, None, None))(samples, sites, h, J, beta)

    return jnp.sum(site_logps)

In [17]:
def pl_loss(samples, h, J, beta=1.0):
    
    return -total_log_pseudolikelihood(samples, h, J, beta)

In [18]:
def pl_gradients(samples, h, J, beta):

    grad_loss = jax.grad(pl_loss, argnums=(1, 2))

    return grad_loss(samples, h, J, beta)

In [19]:
def pl_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = pl_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(J)
    params["J"] = J

    return params, opt_state

In [68]:
def optimize_pl_enhanced(samples, h_init, J_init, n_steps=1000, lr=1e-2,
                          beta=1.0, delta=1e-2, n_enhance=100, seed=0):
    key = jax.random.PRNGKey(seed)
    key_enhance = key

    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    prev_params = params
    last_enhance_step = 0
    counter = 0

    for step in range(n_steps):
        params, opt_state = pl_step(params, opt_state, samples, optimizer, beta)

        if step % 100 == 0:
            loss_val = pl_loss(samples, params["h"], params["J"], beta)
            print(f"Step {step:4d} | Loss: {loss_val:.6f}")

        delta_h = jnp.linalg.norm(params["h"] - prev_params["h"])
        delta_J = jnp.linalg.norm(params["J"] - prev_params["J"])
        delta_theta = jnp.sqrt(delta_h**2 + delta_J**2)

        if delta_theta < delta and (step - last_enhance_step) >= n_enhance:
            key_enhance, subkey = jax.random.split(key_enhance)
            keys = jax.random.split(subkey, samples.shape[0])
            new_samples = dg.apply_glauber_to_all(keys, samples, params["h"], params["J"], beta)
            samples = jnp.concatenate([samples, new_samples], axis=0)
            last_enhance_step = step
            print(f" ↳ Enhanced at step {step}, Δθ = {delta_theta:.2e}, total samples = {samples.shape[0]}")

        prev_params = params

    return params["h"], params["J"]


In [21]:
d = 5

n_init = 100

n_replicas = 500

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_init, n_replicas, h=h, J=J, n_steps_equil=10000, n_steps_final=1000, n_prints = 5000, beta=1, seed=12)

step 0:[-0.10000000149011612, 0.020000001415610313, -0.24000000953674316, -0.2600000202655792, 0.020000001415610313]
[0.7350000143051147, -0.503000020980835, -0.5509999990463257, -0.8060000538825989, -0.6200000047683716]
0.2369999885559082

step 5000:[0.9800000190734863, -0.9800000190734863, -0.9800000190734863, 0.7200000286102295, -0.9800000190734863]
[0.9980000257492065, -1.0, -0.9720000624656677, 0.7860000133514404, -0.9980000257492065]
0.010799991898238659

step 10000:[0.9600000381469727, -0.9600000381469727, -0.8800000548362732, 0.7000000476837158, -0.9600000381469727]
[0.9980000257492065, -1.0, -0.9700000286102295, 0.7540000081062317, -0.9980000257492065]
0.0151999955996871



In [22]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722 -1.1713669 -0.8712328]
True J: [[ 0.         -1.6071162   0.00798594  0.04462025 -0.93398416]
 [-1.6071162   0.          0.63058895 -1.2504177   0.90409863]
 [ 0.00798594  0.63058895  0.         -0.40897474  0.47866577]
 [ 0.04462025 -1.2504177  -0.40897474  0.         -0.5732635 ]
 [-0.93398416  0.90409863  0.47866577 -0.5732635   0.        ]]


In [69]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_pl_enhanced(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

Step    0 | Loss: 1.983862
Step  100 | Loss: 0.435885
Step  200 | Loss: 0.431507
Step  300 | Loss: 0.428691
Step  400 | Loss: 0.427673
Step  500 | Loss: 0.429248
Step  600 | Loss: 0.429540
Step  700 | Loss: 0.427864
Step  800 | Loss: 0.428851
Step  900 | Loss: 0.427160


In [70]:
print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))
print("\n\n")
print("loss h:", jnp.linalg.norm(h - h_est), "loss J:", jnp.linalg.norm(J - J_est))

estimated h: [ 0.83900005 -1.0710001  -0.90500003 -1.4670001  -1.136     ]
estimated J: [[ 0.         -1.9660001  -0.675      -0.94000006 -1.09      ]
 [-1.9660001   0.          0.23       -2.0640001   0.83500004]
 [-0.679       0.22600001  0.         -0.54200006  0.052     ]
 [-0.89800006 -2.023      -0.49600002  0.         -0.96000004]
 [-1.093       0.832       0.053      -1.005       0.        ]]



loss h: 0.48682067 loss J: 2.3281848
