In [56]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import jax
import jax.numpy as jnp
import data_gen as dg
import optax

In [57]:
def conditional_prob(sigma, i, h, J, beta=1.0):

    local_field = h[i] + jnp.dot(J[i], sigma) - J[i, i] * sigma[i]
    exponent = -2 * beta * sigma[i] * local_field

    return 1.0 / (1.0 + jnp.exp(exponent))

In [58]:
def log_conditional_prob(sigma, i, h, J, beta=1.0):

    return jnp.log(conditional_prob(sigma, i, h, J, beta) + 1e-12)

In [59]:
def compute_logp_given_mu(mu, samples, i, h, J, beta=1.0):

    sigma = samples[mu]
    
    return log_conditional_prob(sigma, i, h, J, beta)

In [60]:
def compute_logp_all_mu(mus, samples, i, h, J, beta=1.0):
    
    return jax.vmap(compute_logp_given_mu, in_axes=(0, None, None, None, None, None))(mus, samples, i, h, J, beta)

In [61]:
def log_pseudolikelihood_site(samples, i, h, J, beta=1.0):

    M = samples.shape[0]
    mus = jnp.arange(M)
    logps = compute_logp_all_mu(mus, samples, i, h, J, beta)
    
    return jnp.mean(logps)

In [62]:
def total_log_pseudolikelihood(samples, h, J, beta=1.0):
    
    d = samples.shape[1]
    sites = jnp.arange(d)
    site_logps = jax.vmap(log_pseudolikelihood_site, in_axes=(None, 0, None, None, None))(samples, sites, h, J, beta)

    return jnp.sum(site_logps)

In [63]:
def pl_loss(samples, h, J, beta=1.0):
    
    return -total_log_pseudolikelihood(samples, h, J, beta)

In [64]:
def pl_gradients(samples, h, J, beta):

    grad_loss = jax.grad(pl_loss, argnums=(1, 2))

    return grad_loss(samples, h, J, beta)

In [65]:
def pl_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = pl_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(jnp.diag(J))
    params["J"] = J

    return params, opt_state

In [66]:
def optimize_pl(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = pl_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = pl_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [67]:
d = 3

n_samples = 100000

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_samples, h=h, J=J, n_steps=1000, beta=1, seed=42)

step 0:[0.13600000739097595, -0.12300000339746475, -0.17000000178813934]
[0.843000054359436, -0.7590000629425049, -0.64000004529953]
0.13300001621246338

step 500:[0.9440000653266907, -0.9040000438690186, -0.6480000019073486]
[0.9830000400543213, -0.9410000443458557, -0.6610000133514404]
0.0036666791420429945



In [68]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722]
True J: [[ 0.         -1.1946154  -0.47133768]
 [-1.1946154   0.         -0.44069302]
 [-0.47133768 -0.44069302  0.        ]]


In [69]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_pl(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

step 0 | loss = 1.688235
step 100 | loss = 0.679922
step 200 | loss = 0.679900
step 300 | loss = 0.679900
step 400 | loss = 0.679900
step 500 | loss = 0.679900
step 600 | loss = 0.679900
step 700 | loss = 0.679900
step 800 | loss = 0.679900
step 900 | loss = 0.679900
estimated h: [ 1.0040001 -0.896     -0.748    ]
estimated J: [[ 0.         -1.1910001  -0.47400004]
 [-1.1910001   0.         -0.44000003]
 [-0.47400004 -0.44000003  0.        ]]
