In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import jax
import jax.numpy as jnp
import data_gen as dg
import optax

In [2]:
def rbf_kernel(x, y, gamma=1.0):

    print("w", "\n")
    print(jnp.sum(x**2, axis = 1), "\n")
    x_norm = jnp.sum(x**2, axis=1).reshape(-1, 1)
    y_norm = jnp.sum(y**2, axis=1).reshape(1, -1)
    sq_dists = x_norm + y_norm - 2 * x @ y.T

    return jnp.exp(-gamma * sq_dists)

In [3]:
def mcmc_loss(samples, h, J, beta, n_steps=10, seed=0, gamma=1.0):

    n_samples, d = samples.shape
    key = jax.random.PRNGKey(seed)
    keys_all = dg.generate_all_keys(seed + 1, n_steps, n_samples)

    samples_evolved = samples.copy()
    for t in range(n_steps):
        samples_evolved = dg.apply_glauber_to_all(keys_all[t], samples_evolved, h, J, beta)

    k_xx = rbf_kernel(samples, samples, gamma)
    k_yy = rbf_kernel(samples_evolved, samples_evolved, gamma)
    k_xy = rbf_kernel(samples, samples_evolved, gamma)

    mmd = jnp.mean(k_xx) + jnp.mean(k_yy) - 2 * jnp.mean(k_xy)
    
    return mmd

In [4]:
def mcmc_gradients(samples, h, J, beta):

    grad_loss = jax.grad(mcmc_loss, argnums=(1, 2))
    
    return grad_loss(samples, h, J, beta)

In [5]:
def mcmc_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = mcmc_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(jnp.diag(J))
    params["J"] = J

    return params, opt_state

In [6]:
def optimize_mcmc(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = mcmc_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = mcmc_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [7]:
d = 3

n_samples = 100000

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_samples, h=h, J=J, n_steps=1000, beta=1, seed=42)

step 0:[0.13600000739097595, -0.12300000339746475, -0.17000000178813934]
[0.843000054359436, -0.7590000629425049, -0.64000004529953]
0.13300001621246338

step 500:[0.9440000653266907, -0.9040000438690186, -0.6480000019073486]
[0.9830000400543213, -0.9410000443458557, -0.6610000133514404]
0.0036666791420429945



In [None]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722]
True J: [[ 0.         -1.1946154  -0.47133768]
 [-1.1946154   0.         -0.44069302]
 [-0.47133768 -0.44069302  0.        ]]


: 

In [None]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_mcmc(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

w 

[3 3 3 ... 3 3 3] 

w 

