# Fitting Energy-Based Models from Experimental Data

This notebook shows how to fit THRML Ising model parameters so that the model reproduces experimentally measured equilibria.

**Method:** fit the two bias parameters against the closed-form mean-field prediction $\theta = \sigma(2b)$, then validate against THRML samples.

Relates to [#29](https://github.com/extropic-ai/thrml/issues/29).

In [None]:
import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
import numpy as np
from scipy.optimize import minimize
from scipy.special import expit  # sigmoid

from thrml.block_management import Block
from thrml.block_sampling import sample_states, SamplingSchedule
from thrml.models.ising import IsingEBM, IsingSamplingProgram
from thrml.pgm import SpinNode

**The fitting problem**

We have a system with $K$ binary spins. The mean spin fraction $\theta = \frac{1}{K}\sum_i \mathbb{1}[s_i = +1]$ is the observable. Two competing forces shape $\theta$:

- **Drive** (bias $b_\alpha > 0$): pushes spins toward $+1$
- **Constraint** (bias $b_\gamma > 0$, scaled by $c \in [0,1]$): pushes spins toward $-1$

Net per-spin bias: $b = b_\alpha - c \cdot b_\gamma$. Mean-field equilibrium: $\theta^* = \sigma(2b) = 1/(1 + e^{-2b})$.
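
The sigmoid form can be checked on a single uncoupled spin. The sketch below assumes the energy convention $E(s) = -\beta\, b\, s$ with $s \in \{-1, +1\}$, which is what the formula $\theta^* = \sigma(2b)$ implies at $\beta = 1$; the helper names are illustrative and not part of THRML.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def p_up_single_spin(b, beta=1.0):
    """P(s = +1) for one spin, assuming energy E(s) = -beta * b * s."""
    w_up = math.exp(beta * b)     # Boltzmann weight of s = +1
    w_down = math.exp(-beta * b)  # Boltzmann weight of s = -1
    return w_up / (w_up + w_down)


# The two columns should agree for any bias b.
for b_test in (-1.0, 0.0, 0.5):
    print(b_test, p_up_single_spin(b_test), sigmoid(2.0 * b_test))
```

Dividing numerator and denominator by $e^{b}$ gives $1/(1 + e^{-2b})$, which is where the factor of 2 comes from.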

**Experimental targets:**

| Condition | Constraint $c$ | Target $\theta$ |
|-----------|:--------------:|:---------------:|
| A (no constraint) | 0.0 | 0.80 |
| B (partial) | 0.5 | 0.26 |
| C (full constraint) | 1.0 | 0.00 |

Condition C's target of exactly 0 cannot be reached, since $\sigma(2b) > 0$ for any finite $b$. The fit treats it as a limit: the loss floors the target at $10^{-3}$.

In [None]:
K = 16  # number of spins in the chain
seed = 42

# Experimental targets
conditions = [
    {"name": "A (no constraint)", "c": 0.0, "target_theta": 0.80},
    {"name": "B (partial)", "c": 0.5, "target_theta": 0.26},
    {"name": "C (full constraint)", "c": 1.0, "target_theta": 0.00},
]

# Held-out validation targets (not used during fitting)
validation = [
    {"name": "V1 (c=0.25)", "c": 0.25, "target_theta": 0.55},
    {"name": "V2 (c=0.75)", "c": 0.75, "target_theta": 0.10},
]

**Parameterized model**

We model the system as an Ising chain in which all $K$ spins share the same net bias $b = b_\alpha - c \cdot b_\gamma$, with a weak nearest-neighbor ferromagnetic coupling ($J = 0.02/K$) for coherence.

In [None]:
def build_model(b_alpha, b_gamma, c, K=16):
    """Build an Ising EBM with drive bias b_alpha and constraint bias b_gamma.

    Net per-spin bias: b = b_alpha - c * b_gamma
    Equilibrium: theta = sigma(2*b) = 1/(1+exp(-2*b))
    """
    nodes = [SpinNode() for _ in range(K)]
    edges = [(nodes[i], nodes[i + 1]) for i in range(K - 1)]

    net_bias = b_alpha - c * b_gamma
    biases = jnp.full(K, net_bias)

    # Weak ferromagnetic coupling for spin coherence
    J_within = 0.02 / K
    edge_weights = jnp.full(K - 1, J_within)

    beta = jnp.array(1.0)
    model = IsingEBM(nodes, edges, biases, edge_weights, beta)
    return model, nodes


def sample_theta(b_alpha, b_gamma, c, K=16, n_samples=200, n_warmup=500, rng_key=None):
    """Sample from the model and return (mean, std) of the per-sample theta."""
    model, nodes = build_model(b_alpha, b_gamma, c, K)

    # Chromatic blocking (2 colors for 1D chain)
    even = [nodes[i] for i in range(0, K, 2)]
    odd = [nodes[i] for i in range(1, K, 2)]
    free_blocks = [Block(even), Block(odd)]

    program = IsingSamplingProgram(model, free_blocks, [])
    schedule = SamplingSchedule(n_warmup, n_samples, 5)

    init_state = [
        jnp.zeros(len(even), dtype=jnp.bool_),
        jnp.zeros(len(odd), dtype=jnp.bool_),
    ]

    if rng_key is None:
        rng_key = jax.random.key(0)

    samples = sample_states(
        rng_key, program, schedule, init_state, [], [Block(nodes)]
    )
    spins = samples[0]  # shape (n_samples, K), boolean
    theta = jnp.mean(spins.astype(jnp.float32), axis=-1)  # per-sample theta
    return float(jnp.mean(theta)), float(jnp.std(theta))
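
For a short chain, the effect of the coupling on $\theta$ can be checked exactly by enumerating all $2^K$ states. This is a standalone sketch, independent of THRML. It assumes the energy $E(s) = -\beta\,(b \sum_i s_i + J \sum_i s_i s_{i+1})$ on an open chain, and the names `exact_theta`, `mean_field_theta`, and `K_small` are illustrative.

```python
import itertools
import math


def exact_theta(b, J, K, beta=1.0):
    """Exact mean fraction of +1 spins on an open chain, by brute-force enumeration."""
    z = 0.0            # partition function
    weighted_up = 0.0  # sum over states of weight * fraction of +1 spins
    for spins in itertools.product((-1, 1), repeat=K):
        field = b * sum(spins)
        coupling = J * sum(spins[i] * spins[i + 1] for i in range(K - 1))
        w = math.exp(beta * (field + coupling))
        z += w
        weighted_up += w * sum(1 for s in spins if s == 1) / K
    return weighted_up / z


def mean_field_theta(b):
    return 1.0 / (1.0 + math.exp(-2.0 * b))


K_small = 10  # 2**10 = 1024 states, cheap to enumerate
for J_test in (0.0, 0.02 / 16, 0.5):
    print(J_test, exact_theta(0.3, J_test, K_small), mean_field_theta(0.3))
```

At $J = 0$ the two values coincide; with the notebook's weak coupling they stay close, and a strong coupling pulls them apart.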

**Analytic mean-field solution**

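Because the coupling is weak ($J = 0.02/K$, shrinking further as $K$ grows), the spins are nearly independent and the mean-field approximation $\theta^* = \sigma(2b)$ is tight; at $J = 0$ it is exact. We fit $b_\alpha$ and $b_\gamma$ by minimizing the squared error of the analytic prediction with Nelder-Mead. No THRML sampling is required, just three evaluations of $\sigma$ (one per condition) per loss evaluation.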

In [None]:
def analytic_theta(b_alpha, b_gamma, c):
    """Mean-field prediction: theta = sigma(2*(b_alpha - c*b_gamma))."""
    return float(expit(2.0 * (b_alpha - c * b_gamma)))


def analytic_loss(params, conditions):
    b_alpha, b_gamma = params
    eps = 1e-3  # floor for target=0.00
    total = 0.0
    thetas = []
    for cond in conditions:
        target = max(cond["target_theta"], eps)
        theta = analytic_theta(b_alpha, b_gamma, cond["c"])
        total += (theta - target) ** 2
        thetas.append(theta)
    # Rank-order penalty: theta must strictly decrease as c increases
    for i in range(len(thetas) - 1):
        if thetas[i] <= thetas[i + 1]:
            total += 1.0
    return total


result = minimize(analytic_loss, [0.5, 2.0], args=(conditions,), method="Nelder-Mead",
                  options={"xatol": 1e-6, "fatol": 1e-8, "maxiter": 10_000})
b_alpha_fit, b_gamma_fit = result.x

print(f"Fitted: b_alpha={b_alpha_fit:.4f}, b_gamma={b_gamma_fit:.4f}")
print(f"Analytic loss: {result.fun:.6f}")
print()
print(f"{'Condition':<20} {'Target':>8} {'Analytic':>9}")
print("-" * 40)
for cond in conditions:
    theta_mf = analytic_theta(b_alpha_fit, b_gamma_fit, cond["c"])
    print(f"{cond['name']:<20} {cond['target_theta']:>8.3f} {theta_mf:>9.3f}")
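
As a rough cross-check on the numerical fit, the mean-field relation can be inverted in closed form: $\theta = \sigma(2b)$ gives $b = \tfrac{1}{2}\,\mathrm{logit}(\theta)$. The sketch below solves for the two parameters from conditions A and B only, then predicts condition C. It is not the least-squares fit above, so the values may differ slightly from the Nelder-Mead result. The `_cf` names are illustrative.

```python
import math


def logit(p):
    return math.log(p / (1.0 - p))


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


theta_a, theta_b = 0.80, 0.26  # targets for conditions A (c=0) and B (c=0.5)

b_alpha_cf = 0.5 * logit(theta_a)           # at c = 0 the net bias is b_alpha
b_net_b = 0.5 * logit(theta_b)              # net bias at c = 0.5
b_gamma_cf = (b_alpha_cf - b_net_b) / 0.5   # from b = b_alpha - c * b_gamma

# Prediction for condition C (c = 1), whose target of 0 is unreachable.
theta_c_pred = sigmoid(2.0 * (b_alpha_cf - 1.0 * b_gamma_cf))

print(f"b_alpha={b_alpha_cf:.4f}, b_gamma={b_gamma_cf:.4f}, theta_C={theta_c_pred:.4f}")
```

If the Nelder-Mead result lands far from these values, that would suggest the optimizer stalled or the rank-order penalty interfered.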

**THRML validation**

With fitted parameters in hand, we run one THRML sampling pass per condition to confirm the sampler matches the mean-field prediction. This is a single round of sampling — not an optimization loop.

In [None]:
key = jax.random.key(seed)

print(f"{'Condition':<20} {'Target':>8} {'Analytic':>9} {'THRML':>8} {'Δ (MF)':>8}")
print("-" * 56)

thrml_thetas = []
for cond in conditions:
    key, subkey = jax.random.split(key)
    theta_mf = analytic_theta(b_alpha_fit, b_gamma_fit, cond["c"])
    theta_thrml, _ = sample_theta(b_alpha_fit, b_gamma_fit, cond["c"], rng_key=subkey)
    thrml_thetas.append(theta_thrml)
    delta_mf = abs(theta_thrml - theta_mf)
    print(f"{cond['name']:<20} {cond['target_theta']:>8.3f} {theta_mf:>9.3f} {theta_thrml:>8.3f} {delta_mf:>8.3f}")

rank_ok = all(thrml_thetas[i] > thrml_thetas[i + 1] for i in range(len(thrml_thetas) - 1))
print(f"\nRank order preserved: {rank_ok}")
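
To judge whether a given Δ (MF) is meaningful or just sampling noise, it helps to have a rough standard error for the sampled $\theta$. The sketch below assumes independent spins and independent samples, which is only approximately true here (weak coupling, 5 sampler steps between recorded samples), so treat the result as a lower bound on the noise. The helper name is illustrative.

```python
import math


def theta_standard_error(theta, n_samples, K):
    """Binomial standard error of the mean theta over n_samples * K spin readings."""
    return math.sqrt(theta * (1.0 - theta) / (n_samples * K))


# Defaults used by sample_theta: n_samples=200, K=16
for theta_test in (0.80, 0.26, 0.03):
    print(theta_test, theta_standard_error(theta_test, n_samples=200, K=16))
```

For $\theta = 0.8$ this gives roughly 0.007, so differences of about 0.01 between THRML and the mean-field prediction are within one or two standard errors.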

In [None]:
print("Held-out validation:")
print(f"{'Condition':<20} {'Target':>8} {'Analytic':>9} {'THRML':>8} {'Error':>8}")
print("-" * 56)

for cond in validation:
    key, subkey = jax.random.split(key)
    theta_mf = analytic_theta(b_alpha_fit, b_gamma_fit, cond["c"])
    theta_thrml, _ = sample_theta(b_alpha_fit, b_gamma_fit, cond["c"], rng_key=subkey)
    err = abs(theta_thrml - cond["target_theta"])
    print(f"{cond['name']:<20} {cond['target_theta']:>8.3f} {theta_mf:>9.3f} {theta_thrml:>8.3f} {err:>8.3f}")

In [None]:
# Theta vs constraint strength: THRML samples, mean-field, and targets
c_dense = np.linspace(0, 1, 20)
theta_thrml_curve = []

for c_val in c_dense:
    key, subkey = jax.random.split(key)
    theta, _ = sample_theta(b_alpha_fit, b_gamma_fit, c_val, rng_key=subkey)
    theta_thrml_curve.append(theta)

theta_mf_curve = [analytic_theta(b_alpha_fit, b_gamma_fit, c) for c in c_dense]

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(c_dense, theta_thrml_curve, "b-", linewidth=2, label="THRML samples")
ax.plot(c_dense, theta_mf_curve, "k--", alpha=0.5, label="Analytic (mean-field)")

# Training targets
ax.scatter([cond["c"] for cond in conditions],
           [cond["target_theta"] for cond in conditions],
           c="red", s=80, zorder=5, label="Training targets")
# Validation targets
ax.scatter([cond["c"] for cond in validation],
           [cond["target_theta"] for cond in validation],
           c="green", s=80, marker="^", zorder=5, label="Validation targets")

ax.set_xlabel("Constraint strength c")
ax.set_ylabel(r"Equilibrium $\theta$")
ax.set_title(f"Fitted model  (b_alpha={b_alpha_fit:.3f}, b_gamma={b_gamma_fit:.3f})")
ax.legend()
plt.tight_layout()
plt.show()

**Discussion**

General pattern for fitting THRML Ising parameters from experimental data:

1. **Mean-field pre-fit:** minimize $\sum_i (\sigma(2b_i) - \theta_i^{\text{target}})^2$ over the bias parameters, where $b_i = b_\alpha - c_i b_\gamma$. Fast (no sampling) and accurate when the coupling is weak.
2. **THRML validation:** one sampling pass per condition confirms that the sampler matches the mean-field prediction.
3. **Out-of-sample test:** held-out conditions verify that the parameters generalize.

**When this works:** the observable $\theta$ is smooth in the parameters, the coupling is weak enough (here $J = 0.02/K$) that mean-field is tight, and the energy function is well-specified.

**When it doesn't:** high-dimensional parameter spaces (use gradient-based methods instead of Nelder-Mead), or regimes where mean-field is a poor approximation (strong coupling). For neural-network-parameterized EBMs, consider contrastive divergence via `equinox` + `optax`.
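
As an illustration of the gradient-based alternative, here is a minimal sketch of plain gradient descent on the same mean-field squared-error loss, with the gradient derived by hand so the block has no dependencies. It omits the rank-order penalty (which is not differentiable) and the target floor. The learning rate and step count are illustrative choices, not tuned values. In the notebook itself, `jax.grad` could supply the gradient instead.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def mf_loss_and_grad(b_alpha, b_gamma, data):
    """Squared-error loss on theta = sigmoid(2 * (b_alpha - c * b_gamma)) and its gradient."""
    loss, g_alpha, g_gamma = 0.0, 0.0, 0.0
    for c, target in data:
        theta = sigmoid(2.0 * (b_alpha - c * b_gamma))
        resid = theta - target
        slope = theta * (1.0 - theta)          # derivative of sigmoid(z) w.r.t. z
        loss += resid ** 2
        g_alpha += 4.0 * resid * slope         # chain rule with dz/db_alpha = 2
        g_gamma += -4.0 * c * resid * slope    # chain rule with dz/db_gamma = -2c
    return loss, g_alpha, g_gamma


# (c, target theta) pairs copied from the training conditions
mf_data = [(0.0, 0.80), (0.5, 0.26), (1.0, 0.00)]

b_alpha_gd, b_gamma_gd = 0.5, 2.0  # same starting point as the Nelder-Mead fit
initial_loss, _, _ = mf_loss_and_grad(b_alpha_gd, b_gamma_gd, mf_data)

learning_rate = 0.5
for _ in range(2000):
    _, g_a, g_g = mf_loss_and_grad(b_alpha_gd, b_gamma_gd, mf_data)
    b_alpha_gd -= learning_rate * g_a
    b_gamma_gd -= learning_rate * g_g

final_loss, _, _ = mf_loss_and_grad(b_alpha_gd, b_gamma_gd, mf_data)
print(f"b_alpha={b_alpha_gd:.4f}, b_gamma={b_gamma_gd:.4f}, loss={final_loss:.6f}")
```

With only two parameters this has no advantage over Nelder-Mead; the point is that the same loop scales to many parameters, where derivative-free search becomes slow.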