In [6]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import sys
from typing import Literal
import optax
from scipy.optimize import minimize

In [8]:
def reconstruct_all_params(samples_hist, method: Literal['RISE', 'logRISE', 'RPLE'] = 'RISE',
                           lambda_val=0.1, symmetrize=True):

    num_conf, num_col = samples_hist.shape
    num_spins = num_col - 1
    num_samples = jnp.sum(samples_hist[:, 0])
    reconstruction = jnp.zeros((num_spins, num_spins))

    def RISEobjective(h): return jnp.exp(-h)
    def logRISEobjective(h): return jnp.log(jnp.sum(jnp.exp(-h)))
    def RPLEobjective(h): return jnp.log(1 + jnp.exp(-2 * h))

    if method == 'RISE':
        obj_fun = RISEobjective
    elif method == 'logRISE':
        obj_fun = RISEobjective  # log will be applied outside
    elif method == 'RPLE':
        obj_fun = RPLEobjective
    else:
        raise ValueError("Invalid method")

    for i in range(num_spins):
        y = samples_hist[:, 1 + i]
        X = samples_hist[:, 1:]
        X = X * y[:, None]
        weights = samples_hist[:, 0] / num_samples

        def loss_fn(w):
            wx = X @ w
            L = jnp.sum(weights * obj_fun(wx))
            if method == "logRISE":
                L = jnp.log(L)
            l1 = lambda_val * jnp.sum(jnp.abs(jnp.delete(w, i)))  # exclude diagonal from penalty
            return L + l1

        loss_and_grad = jax.value_and_grad(loss_fn)
        x0 = np.zeros(num_spins)
        res = minimize(lambda w: loss_and_grad(jnp.array(w)), x0, method='L-BFGS-B', jac=True)
        reconstruction = reconstruction.at[i, :].set(res.x)

    if symmetrize:
        reconstruction = 0.5 * (reconstruction + reconstruction.T)

    return reconstruction


In [9]:
d = 4
J = jnp.array([[ 0.0,  0.8,  0.0,  0.0],
               [ 0.8,  0.0, -0.8,  0.0],
               [ 0.0, -0.8,  0.0,  0.8],
               [ 0.0,  0.0,  0.8,  0.0]])
h = jnp.array([0.2, 0.0, -0.1, 0.0])
adj = J + jnp.diag(h)

In [None]:
df = pd.read_csv("output_samples.csv", header=None)

samples_hist = jnp.array(df.values)

Array([[ 472,   -1,   -1,   -1,   -1],
       [ 101,   -1,   -1,   -1,    1],
       [ 393,   -1,   -1,    1,   -1],
       [2094,   -1,   -1,    1,    1],
       [ 495,   -1,    1,   -1,   -1],
       [  92,   -1,    1,   -1,    1],
       [  19,   -1,    1,    1,   -1],
       [  76,   -1,    1,    1,    1],
       [ 149,    1,   -1,   -1,   -1],
       [  36,    1,   -1,   -1,    1],
       [ 119,    1,   -1,    1,   -1],
       [ 617,    1,   -1,    1,    1],
       [3804,    1,    1,   -1,   -1],
       [ 752,    1,    1,   -1,    1],
       [ 140,    1,    1,    1,   -1],
       [ 641,    1,    1,    1,    1]], dtype=int32)

In [16]:
recon = reconstruct_all_params(samples_hist, method="RISE", lambda_val=0.2)

In [17]:
recon

Array([[ 1.3086606e+01, -3.6402307e-06, -2.0923433e-06,  1.7275198e-08],
       [-3.6402307e-06,  1.4473771e+01,  7.4842372e-09, -2.8839651e-07],
       [-2.0923433e-06,  7.4842372e-09,  1.7166050e+01,  5.3871624e-08],
       [ 1.7275198e-08, -2.8839651e-07,  5.3871624e-08,  1.7046007e+01]],      dtype=float32)