In [55]:
import jax
import jax.numpy as jnp
import optax
import time


import matplotlib.pyplot as plt
import numpy as np

In [56]:
def all_ising_states(d):


    grid01 = jnp.indices((2,) * d)
    grid01 = grid01.reshape(d, -1).T 
    states = 2 * grid01 - 1


    return states.astype(jnp.int8)

In [57]:
def hamiltonian(spins, h, J):


    spins = jnp.asarray(spins)
    h = jnp.asarray(h)
    J = jnp.asarray(J)

    
    return - jnp.dot(h, spins) - 0.5 * spins @ (J @ spins)

In [58]:
def sample_ising_bruteforce(h, J, beta, n_samples, key):
    d = h.shape[0]
    states_int = all_ising_states(d)
    states = states_int.astype(jnp.float32)

    J = J - jnp.diag(jnp.diag(J))

    field_term = - (states @ h)
    sJ = states @ J
    pair_term = -0.5 * jnp.sum(states * sJ, axis=1)
    energies = field_term + pair_term

    weights = jnp.exp(-beta * energies)
    probs = weights / jnp.sum(weights)

    idx = jax.random.choice(key, states.shape[0], shape=(n_samples,), p=probs)
    return states_int[idx]

In [59]:
def local_field(spins, h, J):


    return jnp.asarray(h) + J @ jnp.asarray(spins)

In [60]:
def glauber_step(spins, key, h, J, beta):


    d = spins.shape[0]

    key_idx, key_u, key_new = jax.random.split(key, 3)

    site = jax.random.randint(key_idx, (), 0, d)

    local_fields = local_field(spins, h, J)

    p_plus = jax.nn.sigmoid(2.0 * beta * local_fields[site])

    u = jax.random.uniform(key_u)
    new_spin_value = jnp.where(u < p_plus, 1, -1)

    spins_new = spins.at[site].set(new_spin_value)


    return spins_new, key_new

In [61]:
def glauber_steps(spins, key, h, J, beta, n_steps):


    for i in range(int(n_steps)):
        spins, key = glauber_step(spins, key, h, J, beta)


    return spins

In [62]:
def evolve_mcmc(samples, h, J, beta = 1.0, n_steps = 10_000):


    seed = int(time.time_ns() & 0xFFFFFFFF)
    key = jax.random.PRNGKey(seed)


    spins = samples
    B = spins.shape[0]

    for _ in range(int(n_steps)):
        key, sub = jax.random.split(key)
        subkeys = jax.random.split(sub, B)
        spins, _ = jax.vmap(lambda s, k: glauber_step(s, k, h, J, beta))(spins, subkeys)
        

    return spins

In [63]:
def pairwise_squared_distances(x, y):


    x_norm = jnp.sum(x**2, axis=1).reshape(-1, 1)
    y_norm = jnp.sum(y**2, axis=1).reshape(1, -1)


    return x_norm + y_norm - 2 * jnp.dot(x, y.T)


def sinkhorn(a, b, C, epsilon=0.1, n_iters=100):


    K = jnp.exp(-C / epsilon)
    u = jnp.ones_like(a)
    v = jnp.ones_like(b)


    for _ in range(n_iters):
        u = a / (K @ v + 1e-9)
        v = b / (K.T @ u + 1e-9)


    transport_plan = jnp.outer(u, v) * K
    return jnp.sum(transport_plan * C)


def compute_sinkhorn(samples, evolved_samples, epsilon=0.1, n_iters=12):


    n = samples.shape[0]
    m = evolved_samples.shape[0]

    
    a = jnp.ones(n) / n
    b = jnp.ones(m) / m


    C_xy = pairwise_squared_distances(samples, evolved_samples)
    C_xx = pairwise_squared_distances(samples, samples)
    C_yy = pairwise_squared_distances(evolved_samples, evolved_samples)


    sink_xy = sinkhorn(a, b, C_xy, epsilon, n_iters)
    sink_xx = sinkhorn(a, a, C_xx, epsilon, n_iters)
    sink_yy = sinkhorn(b, b, C_yy, epsilon, n_iters)


    return sink_xy - 0.5 * (sink_xx + sink_yy)

In [64]:
def lm_loss(samples, h, J):


    evolved_samples = evolve_mcmc(samples, h, J)


    return compute_sinkhorn(samples, evolved_samples)

In [68]:
def mcmc_optimize_with_tracking(
    samples,
    h_true,
    J_true,
    n_epochs=1000,
    base_lr=1e-2,
    seed=10,
    plot_every=2,
    use_lr_schedule=True
):
    

    n, d = samples.shape


    key = jax.random.PRNGKey(seed)
    key_h, key_J = jax.random.split(key)


    h = jax.random.normal(key_h, shape=(d,))
    J = jax.random.normal(key_J, shape=(d, d))
    J = J - jnp.diag(jnp.diag(J))


    params = {
        "h": h,
        "J": J
    }


    if use_lr_schedule:
        lr_schedule = optax.linear_schedule(
            init_value=1e-2,
            end_value=base_lr,
            transition_steps=500
        )
        optimizer = optax.adam(learning_rate=lr_schedule)
    else:
        optimizer = optax.adam(learning_rate=base_lr)


    opt_state = optimizer.init(params)


    def loss_fn(params, samples):
        h = params["h"]
        J = params["J"]
        return lm_loss(samples, h, J)


    history = {
        "loss": [],
        "h_l2": [],
        "J_fro": [],
    }


    for epoch in range(n_epochs):
        loss_val, grads = jax.value_and_grad(loss_fn)(params, samples)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        params = optax.apply_updates(params, updates)
        params["J"] = params["J"] - jnp.diag(jnp.diag(params["J"]))


        h_val = params["h"]
        J_val = params["J"]


        history["loss"].append(loss_val)
        history["h_l2"].append(jnp.linalg.norm(h_true - h_val))
        history["J_fro"].append(jnp.linalg.norm(J_true - J_val, ord='fro'))


        if epoch % plot_every == 0 or epoch == n_epochs - 1:
            print(f"epoch {epoch} | loss = {loss_val:.6f}")


    final_h = params["h"]
    final_J = params["J"]


    return final_h, final_J, history

In [69]:
key = jax.random.key(0)
d = 8
h_true = jnp.array([0.30, -0.10, 0.00, 0.20, -0.40, 0.15, 0.00, -0.25], dtype=jnp.float32)

J_true = jnp.array([
    [ 0.00,  0.50, -0.20,  0.00,  0.00,  0.05,  0.00,  0.10],
    [ 0.50,  0.00,  0.30, -0.10,  0.00, -0.05, -0.15,  0.00],
    [-0.20,  0.30,  0.00, -0.25,  0.40,  0.05,  0.00,  0.00],
    [ 0.00, -0.10, -0.25,  0.00,  0.20, -0.35,  0.12,  0.00],
    [ 0.00,  0.00,  0.40,  0.20,  0.00,  0.10,  0.00, -0.30],
    [ 0.05, -0.05,  0.05, -0.35,  0.10,  0.00,  0.25,  0.00],
    [ 0.00, -0.15,  0.00,  0.12,  0.00,  0.25,  0.00, -0.20],
    [ 0.10,  0.00,  0.00,  0.00, -0.30,  0.00, -0.20,  0.00]
], dtype=jnp.float32)

beta = 1.0

samples = sample_ising_bruteforce(h_true, J_true, beta, n_samples=50, key=key)
print(samples)

[[-1 -1 -1 -1  1  1  1 -1]
 [-1 -1 -1 -1 -1  1 -1  1]
 [ 1  1 -1 -1 -1  1  1  1]
 [ 1 -1 -1  1  1 -1 -1  1]
 [ 1 -1 -1 -1  1  1 -1 -1]
 [ 1  1  1 -1 -1  1 -1 -1]
 [ 1  1 -1  1 -1 -1 -1 -1]
 [-1  1 -1  1 -1 -1  1 -1]
 [-1 -1  1  1 -1  1  1  1]
 [ 1  1  1 -1 -1  1 -1 -1]
 [-1 -1 -1 -1 -1 -1  1  1]
 [ 1  1  1  1  1 -1  1 -1]
 [-1  1  1  1 -1 -1  1 -1]
 [ 1 -1 -1  1 -1 -1 -1 -1]
 [-1 -1 -1  1 -1 -1  1 -1]
 [-1 -1 -1  1 -1 -1 -1 -1]
 [-1 -1 -1  1  1  1 -1 -1]
 [-1 -1  1  1  1  1 -1 -1]
 [ 1 -1 -1  1 -1  1 -1 -1]
 [ 1  1  1  1  1 -1  1 -1]
 [ 1  1  1  1  1 -1  1 -1]
 [-1 -1 -1 -1 -1  1  1  1]
 [ 1 -1 -1  1 -1 -1  1  1]
 [-1 -1  1 -1 -1  1  1  1]
 [ 1 -1 -1  1 -1 -1 -1  1]
 [ 1 -1 -1 -1 -1 -1  1 -1]
 [-1 -1 -1  1 -1 -1  1  1]
 [-1 -1  1  1 -1 -1  1  1]
 [ 1  1 -1  1  1 -1  1 -1]
 [ 1  1 -1  1 -1  1  1  1]
 [ 1  1  1  1 -1 -1  1 -1]
 [ 1 -1 -1 -1 -1 -1 -1 -1]
 [ 1 -1 -1  1 -1  1  1 -1]
 [ 1 -1 -1  1 -1 -1 -1  1]
 [ 1 -1  1  1 -1  1  1 -1]
 [ 1 -1 -1  1 -1 -1  1 -1]
 [-1 -1  1 -1 -1  1  1 -1]
 

In [70]:
final_h, final_J, history = mcmc_optimize_with_tracking(samples, h_true, J_true)

epoch 0 | loss = 0.286969
epoch 2 | loss = 0.188276
epoch 4 | loss = 0.071930
epoch 6 | loss = 0.098596
epoch 8 | loss = 0.283458
epoch 10 | loss = 0.247863
epoch 12 | loss = 0.307593
epoch 14 | loss = 0.304495
epoch 16 | loss = 0.338417
epoch 18 | loss = 0.148208
epoch 20 | loss = 0.245422
epoch 22 | loss = 0.157679


KeyboardInterrupt: 