In [52]:
import sys
import math
from pathlib import Path
from typing import Optional, Tuple

import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from typing import Literal
from jaxopt import LBFGS

In [53]:
def _read_samples(path: str | Path) -> Tuple[jnp.ndarray, jnp.ndarray]:

    """
    returns
    -------
    freq : jnp.ndarray, shape (num_conf,) - sampling frequencies
    configs : jnp.ndarray, shape (num_conf, num_spins) - spin configurations encoded as ±1
    """

    df = pd.read_csv(path, header=None)
    arr = df.values
    freq = arr[:, 0].astype(np.float32)
    configs = arr[:, 1:].astype(np.float32)
    
    return jnp.asarray(freq), jnp.asarray(configs)

In [54]:
def _read_adjacency(path: str | Path, n: int) -> jnp.ndarray:

    """
    returns
    -------
    adj : jnp.ndarray, shape (n, n) - binary mask for allowed couplings
    raises
    ------
    ValueError if the loaded matrix is not n x n
    """

    df = pd.read_csv(path, header=None)
    adj = df.values.astype(np.float32)
    if adj.shape != (n, n):
        raise ValueError(f"adjacency must be {n} x {n}, got {adj.shape}")
    
    return jnp.asarray(adj)

In [55]:
def _compute_lambda(alpha: float, n_spins: int, n_samples: int) -> float:

    """
    compute the regularization strength λ from
    the user-supplied coefficient alpha

    parameters
    ----------
    alpha : float
        user-chosen coefficient (0 < alpha ≤ 1 typically)
    n_spins : int
        number of spins in the system
    n_samples : int
        total number of samples in the histogram

    returns
    -------
    float
        lambda value used in the l1 penalty term
    """

    return alpha * math.sqrt(math.log((n_spins ** 2) / 0.05) / n_samples)

In [56]:
def _rise_loss(h: jnp.ndarray) -> jnp.ndarray:
    """
    rise loss: exp(-h)
    """
    return jnp.exp(-h)


def _logrise_loss(h: jnp.ndarray) -> jnp.ndarray:
    """
    log-rise loss: same exponential inside, the log is taken
    in the caller for numerical stability
    """
    return jnp.exp(-h)


def _rple_loss(h: jnp.ndarray) -> jnp.ndarray:
    """
    rple loss: log(1 + exp(-2h))
    """
    return jnp.log1p(jnp.exp(-2.0 * h))


def _mpf1_loss(h: jnp.ndarray) -> jnp.ndarray:
    """
    l_mpf ∝ exp(-ΔE / 2) with ΔE = 2h in the node-centric notation,
    hence exp(-h)
    """
    return jnp.exp(-h)  # equivalent to rise when written per‑node

def _mpf2_loss(
    h: jnp.ndarray,           # (n_samples,)
    configs: jnp.ndarray,     # (n_samples, num_spins)
    w_full: jnp.ndarray,      # (num_spins,)
    s: int,                   # indice dello spin centrale
) -> jnp.ndarray:

    central = configs[:, s:s+1]                         # shape (n_samples, 1)
    contrib = jnp.exp(-h[:, None]                       # -h
                      + w_full[None, :] *               # + w_sk ⋅
                        central * configs)              #   s_s s_k
    contrib = contrib.at[:, s].set(0.0)                 # esclude k == s
    return contrib.sum(axis=1)                          # somma su k


def _cms1_loss(h: jnp.ndarray, tau: float = 1.0) -> jnp.ndarray:
    
    return 0

In [57]:
from jaxopt import LBFGS
import jax.numpy as jnp
from typing import Optional

def _reconstruct_single_spin(
    s: int,
    freq: jnp.ndarray,
    configs: jnp.ndarray,
    method: str,
    lam: float,
    adj_row: Optional[jnp.ndarray],
    maxiter: int = 500,
    tol: float = 1e-6,
) -> jnp.ndarray:
    """
    reconstruct the coupling vector w_{s,·} for a single spin s

    minimises  loss_method(w) + lam * sum_{k != s} |w_k|  with jaxopt's LBFGS;
    the field entry w_s is not penalised

    parameters
    ----------
    s : int
        index of the spin being reconstructed (0 <= s < num_spins)
    freq : jnp.ndarray, shape (num_conf,)
        sampling frequency of each configuration
    configs : jnp.ndarray, shape (num_conf, num_spins)
        spin configurations; the losses assume ±1 encoding
    method : str
        one of "RISE", "logRISE", "RPLE", "MPF1", "MPF2", "CMS1"
    lam : float
        strength of the l1 penalty
    adj_row : jnp.ndarray of shape (num_spins,), or None
        optional structural mask: couplings w_k with adj_row[k] == 0 (k != s)
        are fixed to zero and excluded from the optimisation
    maxiter : int, default 500
        maximum number of LBFGS iterations
    tol : float, default 1e-6
        LBFGS stopping tolerance

    returns
    -------
    jnp.ndarray, shape (num_spins,)
        entry s is the local field term, entry k != s the coupling w_{s,k}

    raises
    ------
    ValueError
        if method is not a supported name (checked before any optimisation)
    """
    # function-scope import keeps this cell self-contained
    from jax.scipy.special import logsumexp

    supported = ("RISE", "logRISE", "RPLE", "MPF1", "MPF2", "CMS1")
    if method not in supported:
        raise ValueError(f"unknown method: {method}")

    num_spins = configs.shape[1]
    weights = freq / freq.sum()  # empirical probability of each configuration

    # ----- local statistics -----
    # column k != s holds s_s * s_k (coupling term); column s is overwritten
    # with s_s itself so that w_s acts as the local field
    y = configs[:, s]
    nodal_stat = (y[:, None] * configs).at[:, s].set(y).astype(jnp.float32)

    # ----- masks -----
    l1_mask = jnp.ones(num_spins, dtype=jnp.float32).at[s].set(0.0)  # field is unpenalised
    if adj_row is None:
        zero_mask = jnp.zeros(num_spins, dtype=bool)
    else:
        zero_mask = (adj_row == 0) & (jnp.arange(num_spins) != s)
    free_idx = jnp.where(~zero_mask)[0]
    l1_mask_free = l1_mask[free_idx]

    def embed(w_free):
        """scatter the free parameters into a full length-num_spins vector"""
        return jnp.zeros(num_spins, dtype=jnp.float32).at[free_idx].set(w_free)

    # ----- smooth part of the loss -----
    def loss_smooth(w_free):
        w_full = embed(w_free)
        h = nodal_stat @ w_full

        if method == "logRISE":
            # log(sum_c weights_c * exp(-h_c)) via logsumexp: the plain
            # log(sum(...)) form overflows as soon as exp(-h) does
            return logsumexp(-h, b=weights)

        if method == "RISE":
            per_sample = _rise_loss(h)
        elif method == "RPLE":
            per_sample = _rple_loss(h)
        elif method == "MPF1":
            per_sample = _mpf1_loss(h)
        elif method == "MPF2":
            per_sample = _mpf2_loss(h, configs, w_full, s)
        else:  # "CMS1" - the only name left after the validation above
            per_sample = _cms1_loss(h)
        return (weights * per_sample).sum()

    # ----- total objective (smooth + λ‖w‖₁) -----
    def objective(w_free):
        return loss_smooth(w_free) + lam * jnp.sum(l1_mask_free * jnp.abs(w_free))

    # ----- lbfgs optimisation -----
    # NOTE(review): |w| is non-differentiable at 0 and LBFGS treats the
    # objective as smooth, so pruned couplings typically end up small but not
    # exactly zero; a proximal solver (e.g. jaxopt.ProximalGradient with
    # prox_lasso) would give exact sparsity
    init_w = jnp.zeros((free_idx.size,), dtype=jnp.float32)
    solver = LBFGS(fun=objective, maxiter=maxiter, tol=tol)
    w_opt_free = solver.run(init_w).params

    # ----- re-insert into full vector -----
    return embed(w_opt_free)

In [58]:
def inverse_ising(
    method: str,
    regularizing_value: float,
    symmetrization: str,
    file_samples_histo: str | Path,
    file_reconstruction: str | Path = "reconstruction.csv",
    adjacency_path: Optional[str | Path] = None,
) -> np.ndarray:
    """
    reconstruct the full coupling matrix from a sample histogram and save it

    each row s is obtained independently with _reconstruct_single_spin; the
    stacked matrix is optionally symmetrised, written to file_reconstruction
    as a headerless CSV, and returned

    parameters
    ----------
    method : str
        loss name forwarded to _reconstruct_single_spin (surrounding
        whitespace is stripped; the name is case-sensitive)
    regularizing_value : float
        coefficient alpha passed to _compute_lambda to derive the l1 strength
    symmetrization : str
        "Y" to replace W by (W + W.T) / 2, "N" to keep it asymmetric
        (case-insensitive, surrounding whitespace ignored)
    file_samples_histo : str or Path
        CSV of [freq, spin_1, ..., spin_n] rows
    file_reconstruction : str or Path
        destination CSV for the reconstructed matrix
    adjacency_path : str or Path, optional
        n x n mask; zero entries fix the corresponding off-diagonal couplings to 0

    returns
    -------
    np.ndarray, shape (num_spins, num_spins)
        row s is w_{s,·}; the diagonal holds the local field terms

    raises
    ------
    ValueError
        if symmetrization is not "Y"/"N", the sample file has no spin columns,
        or the frequencies do not sum to a positive number
    """
    method = method.strip()
    symmetrization = symmetrization.strip().upper()
    if symmetrization not in ("Y", "N"):
        # reject values such as "YES" or typos instead of silently skipping symmetrization
        raise ValueError(f"symmetrization must be 'Y' or 'N', got {symmetrization!r}")

    freq, configs = _read_samples(file_samples_histo)
    num_spins = configs.shape[1]
    if num_spins == 0:
        raise ValueError(f"no spin columns found in {file_samples_histo}")
    num_samples = float(freq.sum())
    if num_samples <= 0:
        raise ValueError(f"sample frequencies must sum to a positive number, got {num_samples}")

    adj = None if adjacency_path is None else _read_adjacency(adjacency_path, num_spins)

    lam = _compute_lambda(regularizing_value, num_spins, num_samples)
    print(f"λ = {lam:.5g}  (reg = {regularizing_value})")

    rows = []
    for s in range(num_spins):
        print(f"[{s+1}/{num_spins}] reconstruction spin {s}")
        adj_row = None if adj is None else adj[s]
        rows.append(_reconstruct_single_spin(s, freq, configs, method, lam, adj_row))

    W = jnp.stack(rows)  # (num_spins, num_spins)

    if symmetrization == "Y":
        # average w_sk with w_ks; the diagonal is left unchanged
        W = 0.5 * (W + W.T)

    W_np = np.asarray(W)
    pd.DataFrame(W_np).to_csv(file_reconstruction, header=False, index=False)
    print(f"matrix saved in '{file_reconstruction}'")

    return W_np

In [59]:
import pandas as pd
from pathlib import Path

# ────────────────────────────────────────────────
method              = "MPF2"       # "RISE", "logRISE", "RPLE", "MPF1" or "MPF2" (case-sensitive; "CMS1" is not implemented)
regularizing_value  = 0.2          # coefficient α (0 < α ≤ 1 is typical)
symmetrization      = "Y"          # "Y" = symmetrize; "N" = keep asymmetric
file_samples_histo  = "output_samples.csv"   # CSV with [freq, spin1, spin2, …]
file_reconstruction = "reconstruction.csv"   # output file for the estimated matrix
adjacency_path      = None          # set to a path if you have structural constraints
# ────────────────────────────────────────────────

# quick path checks: fail before any (slow) reconstruction work starts
file_samples_histo = Path(file_samples_histo)
if not file_samples_histo.exists():
    raise FileNotFoundError(f"{file_samples_histo} not found")

if adjacency_path is not None:
    adjacency_path = Path(adjacency_path)
    if not adjacency_path.exists():
        raise FileNotFoundError(f"{adjacency_path} not found")

print("parameters set manually:")
print(f"  method              = {method}")
print(f"  regularizing_value  = {regularizing_value}")
print(f"  symmetrization      = {symmetrization}")
print(f"  file_samples_histo  = {file_samples_histo}")
print(f"  file_reconstruction = {file_reconstruction}")
print(f"  adjacency_path      = {adjacency_path}")

# ---- call inverse_ising ----
W = inverse_ising(
    method=method,
    regularizing_value=regularizing_value,
    symmetrization=symmetrization,
    file_samples_histo=file_samples_histo,
    file_reconstruction=file_reconstruction,
    adjacency_path=adjacency_path,
)

print("reconstruction finished. matrix W:")
W  # Jupyter will render the matrix


parameters set manually:
  method              = MPF2
  regularizing_value  = 0.2
  symmetrization      = Y
  file_samples_histo  = output_samples.csv
  file_reconstruction = reconstruction.csv
  adjacency_path      = None
λ = 0.017193  (reg = 0.2)
[1/9] reconstruction spin 0
[2/9] reconstruction spin 1
[3/9] reconstruction spin 2
[4/9] reconstruction spin 3
[5/9] reconstruction spin 4
[6/9] reconstruction spin 5
[7/9] reconstruction spin 6
[8/9] reconstruction spin 7
[9/9] reconstruction spin 8
matrix saved in 'reconstruction.csv'
reconstruction finished. matrix W:


array([[ 5.8559872e-02, -7.2853327e-02,  4.2882901e-01,  4.5894817e-01,
         9.7453659e-03,  1.0526966e-01,  4.2987093e-01,  3.3604153e-02,
         9.5957167e-02],
       [-7.2853327e-02, -3.6161834e-01,  4.2880201e-01,  6.5232506e-03,
         5.0824475e-01,  5.9816740e-02,  9.7716071e-02,  5.2247119e-01,
         1.1421292e-01],
       [ 4.2882901e-01,  4.2880201e-01, -5.1630497e-02,  9.2231229e-02,
         1.5912305e-01,  4.3174636e-01,  6.3337825e-02,  3.9527769e-04,
         5.2496445e-01],
       [ 4.5894817e-01,  6.5232506e-03,  9.2231229e-02,  3.0246025e-01,
         2.8797179e-01,  5.4722524e-01,  4.3666518e-01,  1.5365683e-01,
         3.7431318e-02],
       [ 9.7453659e-03,  5.0824475e-01,  1.5912305e-01,  2.8797179e-01,
         2.6684701e-02,  4.0678331e-01,  4.8374951e-02,  5.2049232e-01,
         7.8741908e-02],
       [ 1.0526966e-01,  5.9816740e-02,  4.3174636e-01,  5.4722524e-01,
         4.0678331e-01, -5.9307098e-02,  5.8199942e-02,  7.5365342e-02,
         4.

In [60]:
def reconstruction_error(
    reco_csv: str | Path,
    true_csv: str | Path,
    norm: Literal["fro", "l1", "l2", "max"] = "fro",
    ignore_diag: bool = False,
) -> float:
    """
    compare reconstructed couplings with ground-truth adjacency

    parameters
    ----------
    reco_csv : path to reconstruction.csv (estimated matrix)
    true_csv : path to input_adjacency.csv (ground-truth matrix)
    norm      : which matrix norm to compute
                "fro" → frobenius (default)
                "l1"  → entry-wise absolute sum
                "l2"  → spectral norm (largest singular value)
                "max" → max |diff|
    ignore_diag : if true, set the diagonal of both matrices to 0
                  before computing the norm (useful if fields h_i
                  are in the diagonal and you only care about J_ij)

    returns
    -------
    float - requested norm of (reco - true)
    """
    reco = pd.read_csv(reco_csv, header=None).values.astype(float)
    true = pd.read_csv(true_csv, header=None).values.astype(float)

    if reco.shape != true.shape:
        raise ValueError(f"shape mismatch: {reco.shape} vs {true.shape}")

    if ignore_diag:
        np.fill_diagonal(reco, 0.0)
        np.fill_diagonal(true, 0.0)

    diff = reco - true

    if norm == "fro":
        return np.linalg.norm(diff, "fro")
    elif norm == "l1":
        return np.sum(np.abs(diff))
    elif norm == "l2":
        return np.linalg.norm(diff, 2)
    elif norm == "max":
        return np.max(np.abs(diff))
    else:
        raise ValueError(f"unknown norm '{norm}'")

In [61]:
# NOTE(review): the frobenius error keeps the diagonal (the reconstructed
# field terms) while the l2 error drops it; confirm that the diagonal of
# input_adjacency.csv is meant to be compared against fields
err_fro = reconstruction_error(
    "reconstruction.csv",
    "input_adjacency.csv",
    norm="fro",
)
print("frobenius error =", err_fro)

err_l2 = reconstruction_error(
    "reconstruction.csv",
    "input_adjacency.csv",
    norm="l2",
    ignore_diag=True,
)
print("l2 error (off-diag) =", err_l2)

frobenius error = 0.7183166174238585
l2 error (off-diag) = 0.3687756773704446
