In [23]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import data_gen as dg
import ising as isg

import jax
import jax.numpy as jnp
import optax

In [24]:
def log_pseudolikelihood_site(samples, i, h, J, beta=1.0):

    M = samples.shape[0]
    mus = jnp.arange(M)
    logps = isg.compute_logp_all_mu(mus, samples, i, h, J, beta)
    
    return jnp.mean(logps)

In [25]:
def total_log_pseudolikelihood(samples, h, J, beta=1.0):
    
    d = samples.shape[1]
    sites = jnp.arange(d)
    site_logps = jax.vmap(log_pseudolikelihood_site, in_axes=(None, 0, None, None, None))(samples, sites, h, J, beta)

    return jnp.sum(site_logps)

In [26]:
def pl_loss(samples, h, J, beta=1.0):
    
    return -total_log_pseudolikelihood(samples, h, J, beta)

In [74]:
def pl_loss_reg(samples, h, J, beta=1.0, lmbda_h = 0.01, lmbda_J = 0.01):

    nll = -total_log_pseudolikelihood(samples, h, J, beta)

    penalty_h = lmbda_h * jnp.linalg.norm(h)
    penalty_J = lmbda_J * jnp.linalg.norm(J)

    return nll + penalty_h + penalty_J

In [45]:
def pl_gradients(samples, h, J, beta):

    grad_loss = jax.grad(pl_loss, argnums=(1, 2))

    return grad_loss(samples, h, J, beta)

In [29]:
def pl_gradients_reg(samples, h, J, beta):

    grad_loss = jax.grad(pl_loss_reg, argnums=(1, 2))

    return grad_loss(samples, h, J, beta)

In [30]:
def pl_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = pl_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(J)
    params["J"] = J

    return params, opt_state

In [31]:
def pl_step_reg(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = pl_gradients_reg(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(J)
    params["J"] = J

    return params, opt_state

In [63]:
def optimize_pl(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = pl_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = pl_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [52]:
def optimize_pl_reg(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.sgd(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = pl_step_reg(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = pl_loss_reg(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [65]:
d = 5

n_init = 100

n_replicas = 500

# h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)
h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=1, seed=0)

In [66]:
J = J.at[1, 0].set(0.0)
J = J.at[0, 1].set(0.0)
J = J.at[3, 4].set(0.0)
J = J.at[4, 3].set(0.0)

In [67]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722 -1.1713669 -0.8712328]
True J: [[ 0.          0.          0.01597188  0.08924049 -1.8679683 ]
 [ 0.          0.          1.2611779  -2.5008354   1.8081973 ]
 [ 0.01597188  1.2611779   0.         -0.8179495   0.95733154]
 [ 0.08924049 -2.5008354  -0.8179495   0.          0.        ]
 [-1.8679683   1.8081973   0.95733154  0.          0.        ]]


In [68]:
samples = dg.generate_ising_data(n_init, n_replicas, h=h, J=J, n_steps_equil=10000, n_steps_final=1000, n_prints = 5000, beta=1, seed=12)

step 0:[-0.14000000059604645, 0.04000000283122063, -0.18000000715255737, -0.1600000113248825, 0.04000000283122063]
[0.7220000624656677, -0.5790000557899475, -0.48600003123283386, -0.8130000233650208, -0.6100000143051147]
0.27320003509521484

step 5000:[0.940000057220459, -0.940000057220459, -0.940000057220459, 0.9000000357627869, -0.940000057220459]
[0.9930000305175781, -1.0, -0.9980000257492065, 0.9660000205039978, -1.0]
0.011799979023635387

step 10000:[0.9800000190734863, -0.9800000190734863, -0.9600000381469727, 0.9600000381469727, -0.9800000190734863]
[0.9940000176429749, -1.0, -0.999000072479248, 0.9730000495910645, -1.0]
0.010399997234344482



In [75]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_pl_reg(samples, h0, J0, n_steps=300, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

step 0 | loss = 1.704485
step 100 | loss = 0.159508
step 200 | loss = 0.152929
estimated h: [ 0.628      -0.44000003 -0.545       0.223      -0.51000005]
estimated J: [[ 0.         -0.523      -0.45400003  0.40500003 -0.64400005]
 [-0.523       0.          0.703      -0.583       0.67800003]
 [-0.45400003  0.703       0.         -0.532       0.66      ]
 [ 0.40500003 -0.583      -0.532       0.         -0.47100002]
 [-0.64400005  0.67800003  0.66       -0.47100002  0.        ]]


In [76]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_pl(samples[0:50], h0, J0, n_steps=300, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

step 0 | loss = 1.709411
step 100 | loss = 0.123135
step 200 | loss = 0.111475
estimated h: [ 0.68500006 -0.551      -0.66300005  0.14400001 -0.51900005]
estimated J: [[ 0.         -0.69600004 -0.61700004  0.49800003 -0.716     ]
 [-0.69600004  0.          0.72300005 -0.43800002  0.762     ]
 [-0.61700004  0.72300005  0.         -0.441       0.74500006]
 [ 0.49800003 -0.43800002 -0.441       0.         -0.46100003]
 [-0.716       0.762       0.74500006 -0.46100003  0.        ]]
