In [1]:
import sys
import os
sys.path.append(os.path.abspath('..'))

import jax
import jax.numpy as jnp
import data_gen as dg
import optax

In [2]:
def flip_spin(sigma, i):
    
    return sigma.at[i].set(-sigma[i])

In [3]:
def hamiltonian(sigma, h, J, beta):

    interaction_term = jnp.dot(sigma, J @ sigma)
    field_term = jnp.dot(h, sigma)
    energy = - field_term - 0.5 * interaction_term
    
    return beta * energy

In [4]:
def flip_energy_diff(sigma, i, h, J, beta):

    sigma_flip = flip_spin(sigma, i)
    delta = hamiltonian(sigma_flip, h, J, beta) - hamiltonian(sigma, h, J, beta)
    
    return delta

In [5]:
def compute_energy_differences_all_sites(sigma, h, J, beta):
    
    d = sigma.shape[0]
    diffs = []

    for i in range(d):
        delta_e = flip_energy_diff(sigma, i, h, J, beta)
        diffs.append(delta_e)

    return jnp.array(diffs)

In [6]:
def csm_loss_per_sample(sigma, h, J, beta):

    delta_e = compute_energy_differences_all_sites(sigma, h, J, beta)

    return jnp.sum(jnp.exp(-2 * delta_e) - jnp.exp(delta_e) + 1) 

In [7]:
def csm_loss(samples, h, J, beta):

    return jnp.mean(jax.vmap(csm_loss_per_sample, in_axes=(0, None, None, None))(samples, h, J, beta))

In [None]:
def csm_gradients(samples, h, J, beta):

    grad_loss = jax.grad(csm_loss, argnums=(1, 2))
    
    return grad_loss(samples, h, J, beta)

In [9]:
def csm_step(params, opt_state, samples, optimizer, beta):
    
    grad_h, grad_J = csm_gradients(samples, params["h"], params["J"], beta)

    grads = {
        "h": grad_h,
        "J": grad_J
    }

    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)

    J = params["J"]
    J = 0.5 * (J + J.T)
    J = J - jnp.diag(jnp.diag(J))
    params["J"] = J

    return params, opt_state

In [None]:
def optimize_csm(samples, h_init, J_init, n_steps=1000, lr=1e-2, beta=1.0):
    
    params = {
        "h": h_init,
        "J": J_init
    }

    optimizer = optax.adam(lr)
    opt_state = optimizer.init(params)

    for t in range(n_steps):
        params, opt_state = csm_step(params, opt_state, samples, optimizer, beta)

        if t % 100 == 0:
            loss_val = csm_loss(samples, params["h"], params["J"], beta)
            print(f"step {t} | loss = {loss_val:.6f}")

    return params["h"], params["J"]

In [11]:
d = 3

n_samples = 100000

h, J = dg.generate_ising_params(d, sigma_h=1, sigma_J=0.5, seed=0)

samples = dg.generate_ising_data(n_samples, h=h, J=J, n_steps=1000, beta=1, seed=42)

step 0:[0.13600000739097595, -0.12300000339746475, -0.17000000178813934]
[0.843000054359436, -0.7590000629425049, -0.64000004529953]
0.13300001621246338

step 500:[0.9440000653266907, -0.9040000438690186, -0.6480000019073486]
[0.9830000400543213, -0.9410000443458557, -0.6610000133514404]
0.0036666791420429945



In [12]:
print("True h:", h)
print("True J:", J)

True h: [ 1.0040143 -0.9063372 -0.7481722]
True J: [[ 0.         -1.1946154  -0.47133768]
 [-1.1946154   0.         -0.44069302]
 [-0.47133768 -0.44069302  0.        ]]


In [13]:
key = jax.random.PRNGKey(0)

h0 = 0.1 * jax.random.normal(key, shape=(d,))
J0 = 0.1 * jax.random.normal(key, shape=(d, d))
J0 = 0.5 * (J0 + J0.T)
J0 = J0 - jnp.diag(jnp.diag(J0))

h_est, J_est = optimize_csm(samples, h0, J0, n_steps=1000, lr=0.1, beta=1.0)

print("Estimated h:", jnp.round(h_est, 3))
print("Estimated J:", jnp.round(J_est, 3))

step 0 | loss = 0.805071
step 100 | loss = -51.617310
step 200 | loss = -51.623222
step 300 | loss = -51.623295
step 400 | loss = -51.623318
step 500 | loss = -51.623287
step 600 | loss = -51.623310
step 700 | loss = -51.623295
step 800 | loss = -51.623325
step 900 | loss = -51.623299
Estimated h: [ 0.855 -0.734 -0.432]
Estimated J: [[ 0.         -0.9990001  -0.47300002]
 [-0.9990001   0.         -0.42800003]
 [-0.47300002 -0.42800003  0.        ]]
