In [18]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import jax
import jax.numpy as jnp
import data_gen as dg
import optax

In [19]:
def rbf_kernel(x, y, gamma=1.0):

    sq_dists = jnp.linalg.norm(x - y)**2
    
    return jnp.exp(-gamma * sq_dists)

In [41]:
def mcmc_loss(samples, h, J, beta, n_steps=10, seed=0, gamma=1.0):

    n_samples, d = samples.shape
    key = jax.random.PRNGKey(seed)
    keys_all = dg.generate_all_keys(seed + 1, n_steps, n_samples)

    samples_evolved = samples.copy()
    for t in range(n_steps):
        samples_evolved = dg.apply_glauber_to_all(keys_all[t], samples_evolved, h, J, beta)

    k_xx = 0
    k_yy = 0
    k_xy = 0

    for i in range(n_samples):
        for j in range(n_samples):

            print(samples[i])
            
            k_xx += rbf_kernel(samples[i], samples[j], gamma)
            k_yy += rbf_kernel(samples_evolved[i], samples_evolved[j], gamma)
            k_xy += rbf_kernel(samples[i], samples_evolved[j], gamma)

    mmd = (k_xx) + (k_yy) - 2 * (k_xy)
    
    return mmd / (n_samples**2)

In [42]:
def mcmc_gradients(samples, h, J, beta):

    grad_loss = jax.grad(mcmc_loss, argnums=(1, 2))
    
    return grad_loss(samples, h, J, beta)

In [29]:
def mcmc_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = mcmc_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(J)
    params["J"] = J

    return params, opt_state

In [30]:
def optimize_mcmc(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = mcmc_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = mcmc_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")
            
    return params["h"], params["J"]

In [45]:
d = 5

n_init = 10

n_replicas = 5

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_init, n_replicas, h=h, J=J, n_steps_equil=10000, n_steps_final=1000, n_prints = 5000, beta=1, seed=12)

step 0:[-0.800000011920929, 0.20000000298023224, -0.6000000238418579, -0.6000000238418579, 0.4000000059604645]
[0.2710000276565552, 0.8050000667572021, -0.18900001049041748, -0.8940000534057617, 0.11300000548362732]
0.3012000024318695

step 5000:[1.0, -1.0, -1.0, 0.800000011920929, -1.0]
[0.9980000257492065, -1.0, -0.9750000238418579, 0.8030000329017639, -0.999000072479248]
0.005399990361183882

step 10000:[1.0, -1.0, -1.0, 0.800000011920929, -1.0]
[0.9980000257492065, -1.0, -0.9750000238418579, 0.8030000329017639, -0.999000072479248]
0.005399990361183882



In [32]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722 -1.1713669 -0.8712328]
True J: [[ 0.         -1.6071162   0.00798594  0.04462025 -0.93398416]
 [-1.6071162   0.          0.63058895 -1.2504177   0.90409863]
 [ 0.00798594  0.63058895  0.         -0.40897474  0.47866577]
 [ 0.04462025 -1.2504177  -0.40897474  0.         -0.5732635 ]
 [-0.93398416  0.90409863  0.47866577 -0.5732635   0.        ]]


In [43]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_mcmc(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

print("estimated h:", jnp.round(h_est, 3))
print("estimated J:", jnp.round(J_est, 3))

[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -1]
[ 1 -1 -1  1 -

KeyboardInterrupt: 

In [55]:
import numpy as np

unique = []
for i in range(len(samples)):
    if not any(np.array_equal(samples[i], x) for x in unique):
        unique.append(samples[i])


In [56]:
unique

[Array([ 1, -1, -1,  1, -1], dtype=int32),
 Array([ 1, -1, -1, -1, -1], dtype=int32)]