# Verifying Fluctuation Theorems in THRML

**Extracting Peclet numbers, Crooks ratios, and entropy production**

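THRML samples energy-based models (EBMs) with block Gibbs updates, each of which
satisfies detailed balance with respect to the model's Boltzmann distribution.
This notebook shows how to extract thermodynamic observables from
THRML sample trajectories and test fluctuation theorem predictions against them.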

**Observables:**
1. Drift velocity $v$ and diffusion coefficient $D$
2. Peclet number $\text{Pe} = |v| / D$ (directed vs diffusive transport)
3. Crooks ratio: $P(+\Delta E) / P(-\Delta E) = \exp(\beta \Delta E)$
4. Entropy production rate $\dot{S}$

**References:** Crooks (1999), Jarzynski (1997), Hack et al. (2022).

In [None]:
import jax
import jax.numpy as jnp
import jax.random
import matplotlib.pyplot as plt
import numpy as np

from thrml.block_management import Block
from thrml.block_sampling import sample_states, SamplingSchedule
from thrml.models.ising import IsingEBM, IsingSamplingProgram
from thrml.pgm import SpinNode

**Background**

For any system satisfying detailed balance, the **Crooks fluctuation theorem** relates
forward and reverse transition probabilities:

$$\frac{P(\sigma)}{P(-\sigma)} = e^{\sigma}$$

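where $\sigma$ is the entropy production per transition. For an Ising model with energy
$E(\mathbf{s}) = -\beta\left(\sum_i h_i s_i + \sum_{ij} J_{ij} s_i s_j\right)$,
the entropy production for a transition $\mathbf{s} \to \mathbf{s}'$ is
$\sigma = \beta \cdot [E(\mathbf{s}) - E(\mathbf{s}')]$.
Because $E$ as defined here already contains the factor $\beta$, the explicit prefactor
is redundant: it changes nothing at $\beta = 1$, the value used in this notebook, but
should be dropped at any other temperature.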
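The link between detailed balance and $\sigma$ can be checked on a single heat-bath update. The sketch below is plain NumPy and does not call THRML; the helper names `energy` and `flip_probability` are illustrative assumptions. It uses $\beta = 1$, so the difference of the $\beta$-scaled energies equals $\sigma$.

```python
import numpy as np


def energy(s, h, J, beta):
    """Open-chain energy E(s) = -beta * (h * sum_i s_i + J * sum_i s_i s_{i+1})."""
    return -beta * (h * np.sum(s) + J * np.sum(s[:-1] * s[1:]))


def flip_probability(s, i, h, J, beta):
    """Probability that a heat-bath update of site i ends in the flipped state."""
    s_new = s.copy()
    s_new[i] = -s[i]
    dE = energy(s_new, h, J, beta) - energy(s, h, J, beta)
    # Heat bath: P(flipped) = exp(-E_new) / (exp(-E_new) + exp(-E_old)).
    return 1.0 / (1.0 + np.exp(dE)), s_new


rng = np.random.default_rng(0)
h, J, beta = 0.3, 1.0, 1.0
s = rng.choice([-1.0, 1.0], size=8)
i = 3  # site to update (arbitrary choice)

p_fwd, s_new = flip_probability(s, i, h, J, beta)
p_rev, _ = flip_probability(s_new, i, h, J, beta)

log_ratio = np.log(p_fwd / p_rev)
sigma = energy(s, h, J, beta) - energy(s_new, h, J, beta)
print(f"ln(p_fwd / p_rev) = {log_ratio:.6f}   E(s) - E(s') = {sigma:.6f}")
```

The two printed numbers agree for any configuration and site, because the heat-bath acceptance ratio reduces algebraically to $e^{E(\mathbf{s}) - E(\mathbf{s}')}$.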

The **Peclet number** $\text{Pe} = |v| / D$ measures the ratio of directed transport
to diffusion. $\text{Pe} > 1$ indicates driven (non-diffusive) dynamics;
$\text{Pe} \approx 0$ indicates thermal equilibrium.

**Build a 1D Ising chain**

We use a simple 1D ferromagnetic chain with an external field $h$.
When $h > 0$, the field biases spins toward $+1$, creating directed transport
during relaxation from a cold (all $-1$) initial state.
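As a quick arithmetic check of how strongly the field favors $+1$, the sketch below compares the energies of the all-down and all-up configurations. It uses the same energy convention as the Background section and the parameter values from the next cell; `chain_energy` is an illustrative helper, not a THRML function.

```python
import numpy as np

N, h, J, beta = 32, 0.3, 1.0, 1.0  # same values as the notebook cell below


def chain_energy(s):
    """E(s) = -beta * (h * sum_i s_i + J * sum_i s_i s_{i+1}) for an open chain."""
    return -beta * (h * np.sum(s) + J * np.sum(s[:-1] * s[1:]))


E_down = chain_energy(-np.ones(N))  # cold start: every spin is -1
E_up = chain_energy(np.ones(N))     # fully aligned with the field

print(f"E(all down) = {E_down:.2f}")
print(f"E(all up)   = {E_up:.2f}")
print(f"gap         = {E_down - E_up:.2f}  (equals 2 * beta * N * h)")
```

The coupling term is identical in both states, so the whole gap comes from the field term and grows linearly with $h$ and $N$.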

In [None]:
N = 32  # number of spins
seed = 42

# Create 1D chain
nodes = [SpinNode() for _ in range(N)]
edges = [(nodes[i], nodes[i + 1]) for i in range(N - 1)]

# External field strength (drives spins toward +1)
h_field = 0.3
biases = jnp.full(N, h_field)

# Nearest-neighbor ferromagnetic coupling
J = 1.0
weights = jnp.full(N - 1, J)

# Inverse temperature
beta = jnp.array(1.0)

model = IsingEBM(nodes, edges, biases, weights, beta)
print(f"Ising chain: {N} spins, h={h_field}, J={J}, beta={float(beta)}")

In [None]:
def ising_energy(spins, biases, weights, N, beta):
    """Compute Ising energy from a spin configuration.

    spins: boolean array of shape (N,). True = +1, False = -1.
    Returns scalar energy.
    """
    s = 2.0 * spins.astype(jnp.float32) - 1.0  # {False,True} -> {-1,+1}
    bias_energy = jnp.sum(biases * s)
    coupling_energy = jnp.sum(weights * s[:-1] * s[1:])  # nearest-neighbor
    return -beta * (bias_energy + coupling_energy)


def magnetization(spins):
    """Mean magnetization m = (1/N) sum(s_i) in [-1, +1]."""
    return jnp.mean(2.0 * spins.astype(jnp.float32) - 1.0)

**Sample trajectories**

We sample two trajectories:
1. **Relaxation** — start from all spins down ($m = -1$), let the system relax toward equilibrium under the field $h > 0$. This shows directed transport ($\text{Pe} > 0$).
2. **Equilibrium** — start from a pre-equilibrated state. Fluctuations are diffusive ($\text{Pe} \approx 0$).
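To make explicit what one sweep does, here is a minimal NumPy stand-in for an even/odd heat-bath sweep on an open chain. It is a sketch under the energy convention above, not THRML's implementation, and `gibbs_sweep` is a hypothetical helper name.

```python
import numpy as np


def gibbs_sweep(s, h, J, beta, rng):
    """Update even sites, then odd sites, of a chain of -1/+1 spins in place."""
    n = s.size
    for parity in (0, 1):
        idx = np.arange(parity, n, 2)
        # Neighbor values, with a zero contribution beyond the open boundaries.
        left = np.where(idx > 0, s[np.maximum(idx - 1, 0)], 0.0)
        right = np.where(idx < n - 1, s[np.minimum(idx + 1, n - 1)], 0.0)
        field = h + J * (left + right)
        # Conditional probability of +1 given the neighbors.
        p_up = 1.0 / (1.0 + np.exp(-2.0 * beta * field))
        s[idx] = np.where(rng.random(idx.size) < p_up, 1.0, -1.0)
    return s


rng = np.random.default_rng(42)
N, n_sweeps = 32, 500
s = -np.ones(N)  # cold start, all spins down
traj = np.empty((n_sweeps, N))
for t in range(n_sweeps):
    traj[t] = gibbs_sweep(s, h=0.3, J=1.0, beta=1.0, rng=rng)

mag = traj.mean(axis=1)
print(f"m after first sweep: {mag[0]:+.3f}, mean m over last 200 sweeps: {mag[-200:].mean():+.3f}")
```

Sites of the same parity are never neighbors on a chain, which is why each half of the sweep can be updated in one vectorized step.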

In [None]:
# Chromatic blocking for 1D chain (2 colors: even/odd sites)
even_nodes = [nodes[i] for i in range(0, N, 2)]
odd_nodes = [nodes[i] for i in range(1, N, 2)]
free_blocks = [Block(even_nodes), Block(odd_nodes)]

program = IsingSamplingProgram(model, free_blocks, [])

# --- Relaxation trajectory: cold start (all spins down) ---
n_samples = 500
init_cold = [
    jnp.zeros(len(even_nodes), dtype=jnp.bool_),
    jnp.zeros(len(odd_nodes), dtype=jnp.bool_),
]
schedule_relax = SamplingSchedule(0, n_samples, 1)  # no warmup, record every sweep

key = jax.random.key(seed)
key, subkey = jax.random.split(key)
samples_relax = sample_states(
    subkey, program, schedule_relax, init_cold, [], [Block(nodes)]
)
traj_relax = samples_relax[0]  # shape (n_samples, N)

# --- Equilibrium trajectory: long warmup first ---
schedule_eq = SamplingSchedule(2000, n_samples, 1)  # 2000 warmup sweeps
key, subkey = jax.random.split(key)
samples_eq = sample_states(
    subkey, program, schedule_eq, init_cold, [], [Block(nodes)]
)
traj_eq = samples_eq[0]

print(f"Trajectory shapes: relaxation {traj_relax.shape}, equilibrium {traj_eq.shape}")

**Magnetization dynamics**

In [None]:
# Compute magnetization time series
mag_relax = jax.vmap(magnetization)(traj_relax)
mag_eq = jax.vmap(magnetization)(traj_eq)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(np.array(mag_relax), linewidth=0.8)
ax1.set_xlabel("Gibbs sweep")
ax1.set_ylabel("Magnetization m")
ax1.set_title("Relaxation (cold start)")
ax1.axhline(y=np.mean(mag_eq), color="r", linestyle="--", alpha=0.5, label="equilibrium")
ax1.legend()

ax2.plot(np.array(mag_eq), linewidth=0.8)
ax2.set_xlabel("Gibbs sweep")
ax2.set_ylabel("Magnetization m")
ax2.set_title("Equilibrium (after warmup)")

plt.tight_layout()
plt.show()

**Peclet number extraction**

The Peclet number is computed from the magnetization trajectory:

$$\text{Pe} = \frac{|\langle \Delta m \rangle|}{\text{Var}(\Delta m) / 2}$$

where $\Delta m = m(t+1) - m(t)$. During relaxation $\text{Pe} > 1$
(directed transport dominates). At equilibrium $\text{Pe} \approx 0$
(pure diffusion).
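The estimator can be sanity-checked on synthetic series whose behavior is known. The sketch below is an illustration only: the drift, noise level, and autoregressive coefficient are arbitrary assumptions, and `peclet` is a hypothetical helper. It compares a biased random walk with a stationary AR(1) series.

```python
import numpy as np


def peclet(m):
    """Pe = |<dm>| / (Var(dm) / 2) for a 1D time series m."""
    dm = np.diff(m)
    return np.abs(dm.mean()) / (dm.var() / 2.0)


rng = np.random.default_rng(0)
n = 5000

# Driven case: steps drawn from N(0.05, 0.1^2), so Pe should be near 0.05 / 0.005 = 10.
driven = np.cumsum(rng.normal(0.05, 0.1, size=n))

# Stationary case: AR(1) noise around zero, with no net drift.
stationary = np.zeros(n)
for t in range(1, n):
    stationary[t] = 0.9 * stationary[t - 1] + 0.1 * rng.normal()

pe_driven = peclet(driven)
pe_stationary = peclet(stationary)
print(f"Pe (biased walk): {pe_driven:.2f}")
print(f"Pe (stationary):  {pe_stationary:.4f}")
```

For a stationary series the mean increment telescopes to (last value minus first value) divided by the number of steps, so the estimate shrinks as the trajectory gets longer.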

In [None]:
def compute_peclet(mag_trajectory, window=None):
    """Compute Peclet number from a magnetization time series.

    If window is given, compute Pe in sliding windows of that size.
    Otherwise compute a single Pe for the entire trajectory.
    """
    mag = np.array(mag_trajectory)
    dm = np.diff(mag)

    if window is None:
        v = np.mean(dm)
        D = np.var(dm) / 2.0
        return np.abs(v) / D if D > 1e-12 else 0.0

    # Sliding window
    pe_values = []
    centers = []
    for start in range(0, len(dm) - window + 1, max(1, window // 2)):
        chunk = dm[start : start + window]
        v = np.mean(chunk)
        D = np.var(chunk) / 2.0
        pe = np.abs(v) / D if D > 1e-12 else 0.0
        pe_values.append(pe)
        centers.append(start + window // 2)
    return np.array(centers), np.array(pe_values)


# Global Pe
pe_relax = compute_peclet(mag_relax)
pe_eq = compute_peclet(mag_eq)
print(f"Peclet number (relaxation): {pe_relax:.3f}")
print(f"Peclet number (equilibrium): {pe_eq:.3f}")

# Sliding-window Pe shows Pe decaying as system equilibrates
centers, pe_sliding = compute_peclet(mag_relax, window=50)

fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(centers, pe_sliding, "o-", markersize=4)
ax.axhline(y=1.0, color="r", linestyle="--", alpha=0.5, label="Pe = 1")
ax.set_xlabel("Gibbs sweep")
ax.set_ylabel("Peclet number")
ax.set_title("Pe during relaxation (sliding window)")
ax.legend()
plt.tight_layout()
plt.show()

**Crooks fluctuation theorem verification**

The Crooks theorem predicts that for energy changes $\Delta E$ along the trajectory:

$$\ln \frac{P(+\Delta E)}{P(-\Delta E)} = \beta \cdot \Delta E$$

We verify this by histogramming $\Delta E$ values and checking the log-ratio
against the theoretical prediction.
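The histogram log-ratio estimator can be validated on data that satisfy the relation by construction. A Gaussian variable with mean $\mu$ and variance $s^2$ obeys $\ln[P(x)/P(-x)] = 2\mu x / s^2$, so choosing $\mu = \beta s^2 / 2$ gives a slope of exactly $\beta$. The sketch below uses synthetic samples, not THRML output; the bin range and the count threshold of 50 are assumptions chosen to keep the fit stable.

```python
import numpy as np

rng = np.random.default_rng(1)
beta = 1.0
s2 = 1.0                 # variance of the synthetic variable
mu = beta * s2 / 2.0     # mean chosen so that the log-ratio slope equals beta
x = rng.normal(mu, np.sqrt(s2), size=200_000)

# Symmetric bins so that bin k of the reversed histogram mirrors bin k of the forward one.
bins = np.linspace(-4.0, 4.0, 41)
centers = 0.5 * (bins[:-1] + bins[1:])
hist_fwd, _ = np.histogram(x, bins=bins)
hist_rev, _ = np.histogram(-x, bins=bins)

# Keep only bins with enough counts on both sides.
mask = (hist_fwd > 50) & (hist_rev > 50)
log_ratio = np.log(hist_fwd[mask] / hist_rev[mask])
slope = np.polyfit(centers[mask], log_ratio, 1)[0]
print(f"fitted slope: {slope:.3f} (constructed to equal beta = {beta:.3f})")
```

Running the same estimator on known-good data first helps separate estimator artifacts, such as sparse tail bins, from genuine deviations in a sampled trajectory.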

In [None]:
# Compute energy at each step of the equilibrium trajectory
energy_fn = lambda s: ising_energy(s, biases, weights, N, beta)
energies_eq = jax.vmap(energy_fn)(traj_eq)
dE = np.array(jnp.diff(energies_eq))

# Histogram of energy changes
n_bins = 40
dE_range = max(abs(dE.min()), abs(dE.max()))
bins = np.linspace(-dE_range, dE_range, n_bins + 1)
bin_centers = 0.5 * (bins[:-1] + bins[1:])

hist_fwd, _ = np.histogram(dE, bins=bins)
hist_rev, _ = np.histogram(-dE, bins=bins)

# Crooks ratio (only where both histograms have counts)
mask = (hist_fwd > 5) & (hist_rev > 5)  # require sufficient statistics
log_ratio = np.log(hist_fwd[mask] / hist_rev[mask])
predicted = float(beta) * bin_centers[mask]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Energy change distribution
ax1.hist(dE, bins=50, alpha=0.7, density=True)
ax1.set_xlabel(r"$\Delta E$")
ax1.set_ylabel("Density")
ax1.set_title("Energy change distribution (equilibrium)")

# Crooks verification
ax2.scatter(bin_centers[mask], log_ratio, s=20, label="Measured")
ax2.plot(bin_centers[mask], predicted, "r-", label=r"$\beta \Delta E$ (theory)")
ax2.set_xlabel(r"$\Delta E$")
ax2.set_ylabel(r"$\ln[P(+\Delta E) / P(-\Delta E)]$")
ax2.set_title("Crooks fluctuation theorem")
ax2.legend()

plt.tight_layout()
plt.show()

# Quantitative check: slope should equal beta
if len(bin_centers[mask]) > 2:
    slope = np.polyfit(bin_centers[mask], log_ratio, 1)[0]
    print(f"Measured slope: {slope:.3f} (expected: {float(beta):.3f})")

**Entropy production rate**

The entropy production per step is $\sigma = -\beta \cdot \Delta E$, the energy released
to the bath scaled by $\beta$, matching the definition in the Background section.
The **integral fluctuation theorem** predicts $\langle e^{-\sigma} \rangle = 1$.
The mean entropy production $\langle \sigma \rangle \geq 0$ (second law).

In [None]:
# Entropy production for both trajectories (sigma = -beta * dE)
energies_relax = jax.vmap(energy_fn)(traj_relax)
dE_relax = np.array(jnp.diff(energies_relax))
sigma_relax = -float(beta) * dE_relax
sigma_eq = -float(beta) * dE

print("Relaxation trajectory:")
print(f"  Mean entropy production: {np.mean(sigma_relax):.4f} (should be > 0, energy decreasing)")
print(f"  <exp(-sigma)>: {np.mean(np.exp(-sigma_relax)):.4f} (Jarzynski: should -> 1)")

print("\nEquilibrium trajectory:")
print(f"  Mean entropy production: {np.mean(sigma_eq):.4f} (should be ~ 0)")
print(f"  <exp(-sigma)>: {np.mean(np.exp(-sigma_eq)):.4f} (Jarzynski: should -> 1)")

# Cumulative entropy production
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(np.cumsum(sigma_relax), label="Relaxation", linewidth=0.8)
ax.plot(np.cumsum(sigma_eq), label="Equilibrium", linewidth=0.8)
ax.set_xlabel("Gibbs sweep")
ax.set_ylabel(r"Cumulative $\sigma$")
ax.set_title("Cumulative entropy production")
ax.legend()
plt.tight_layout()
plt.show()

**Varying the bias field**

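The Peclet number is expected to increase with the bias field $h$. At $h = 0$ (no external drive)
the model is symmetric under a global spin flip, so equilibrium dynamics have $\text{Pe} \approx 0$;
the measured value can be larger because every run below starts from the all-down state and
includes the initial transient. As $h$ increases, directed transport dominates.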
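For reference, the equilibrium magnetization that each run relaxes toward can be estimated from the standard transfer-matrix result for the infinite 1D Ising chain, $m = \sinh(\beta h) / \sqrt{\sinh^2(\beta h) + e^{-4 \beta J}}$. The notebook's chain is finite with open ends, so treat this as an approximation; `exact_magnetization` is an illustrative helper name.

```python
import numpy as np


def exact_magnetization(h, J=1.0, beta=1.0):
    """Magnetization per spin of the infinite 1D Ising chain (transfer-matrix result)."""
    sh = np.sinh(beta * h)
    return sh / np.sqrt(sh**2 + np.exp(-4.0 * beta * J))


h_values = np.array([0.0, 0.1, 0.3, 0.5, 1.0, 2.0])
m_exact = exact_magnetization(h_values)
for h_val, m_val in zip(h_values, m_exact):
    # Distance from the cold start at m = -1 to the equilibrium value.
    print(f"h = {h_val:.1f}: m_eq = {m_val:.3f}, distance from cold start = {m_val + 1.0:.3f}")
```

The distance from $m = -1$ to the target grows with $h$ and saturates near 2 once the field dominates the thermal fluctuations.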

In [None]:
h_values = [0.0, 0.1, 0.3, 0.5, 1.0, 2.0]
pe_values = []

for h_val in h_values:
    biases_h = jnp.full(N, h_val)
    model_h = IsingEBM(nodes, edges, biases_h, weights, beta)
    program_h = IsingSamplingProgram(model_h, free_blocks, [])

    schedule_h = SamplingSchedule(0, 200, 1)
    key, subkey = jax.random.split(key)
    samples_h = sample_states(subkey, program_h, schedule_h, init_cold, [], [Block(nodes)])
    mag_h = jax.vmap(magnetization)(samples_h[0])
    pe_h = compute_peclet(mag_h)
    pe_values.append(pe_h)
    print(f"  h = {h_val:.1f}: Pe = {pe_h:.3f}")

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(h_values, pe_values, "o-")
ax.axhline(y=1.0, color="r", linestyle="--", alpha=0.5, label="Pe = 1")
ax.set_xlabel("External field h")
ax.set_ylabel("Peclet number")
ax.set_title("Pe vs bias field strength")
ax.legend()
plt.tight_layout()
plt.show()

**Discussion**

This notebook demonstrates extracting three thermodynamic observables from THRML
sample trajectories:

- **Peclet number** distinguishes directed transport ($\text{Pe} > 1$) from
  diffusion ($\text{Pe} \approx 0$). During relaxation, sampling is driven;
  at equilibrium, it is diffusive.
- **Crooks ratio** tests whether the sampled trajectory obeys the fluctuation
  theorem: $\ln[P(+\Delta E)/P(-\Delta E)] = \beta \Delta E$.
- **Entropy production** is positive during relaxation (second law) and
  zero-mean at equilibrium. The integral fluctuation theorem
  $\langle e^{-\sigma} \rangle = 1$ provides an independent consistency check.
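The integral fluctuation theorem and the bound $\langle \sigma \rangle \geq 0$ can be verified exactly on a small Markov chain. The sketch below uses the general stationary-state definition $\sigma = \ln[\pi(s) P(s \to s') / (\pi(s') P(s' \to s))]$, which is an assumption of this example and differs from the energy-based quantity computed in the notebook. The 4-state transition matrix is random and purely illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)

# Random row-stochastic transition matrix with strictly positive entries.
P = rng.random((4, 4))
P /= P.sum(axis=1, keepdims=True)

# Stationary distribution: left eigenvector of P with eigenvalue 1.
evals, evecs = np.linalg.eig(P.T)
pi = np.real(evecs[:, np.argmax(np.real(evals))])
pi /= pi.sum()

# Joint probability of observing the transition s -> s' in the stationary state.
flux = pi[:, None] * P
sigma = np.log(flux / flux.T)  # per-transition entropy production

ift = np.sum(flux * np.exp(-sigma))  # <exp(-sigma)>
mean_sigma = np.sum(flux * sigma)    # <sigma>, a Kullback-Leibler divergence
print(f"<exp(-sigma)> = {ift:.12f}")
print(f"<sigma>       = {mean_sigma:.6f}")
```

Here $\langle e^{-\sigma} \rangle$ equals 1 identically, and $\langle \sigma \rangle$ vanishes only when the chain satisfies detailed balance, in which case every $\sigma$ is zero.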

These observables are generic — they apply to any EBM sampled by THRML,
not just the 1D Ising chain used here. They are useful for:
- Verifying that a custom EBM satisfies detailed balance
- Diagnosing whether sampling has reached equilibrium ($\text{Pe} \to 0$)
- Measuring irreversibility in biased or driven systems
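As one way to turn the equilibration diagnostic into a number, the sketch below reports the first sliding window whose Peclet number falls below a threshold. The helper `first_equilibrated_window`, the threshold of 1, and the synthetic relaxation curve are all assumptions made for illustration.

```python
import numpy as np


def first_equilibrated_window(mag, window=50, threshold=1.0):
    """Return the center of the first half-overlapping window with Pe < threshold, or None."""
    dm = np.diff(np.asarray(mag))
    step = max(1, window // 2)
    for start in range(0, len(dm) - window + 1, step):
        chunk = dm[start : start + window]
        D = np.var(chunk) / 2.0
        if D > 1e-12 and np.abs(np.mean(chunk)) / D < threshold:
            return start + window // 2
    return None


# Synthetic relaxation from m = -1 toward m = 0.9 with time constant 30 sweeps.
rng = np.random.default_rng(3)
t = np.arange(500)
mag = 0.9 - 1.9 * np.exp(-t / 30.0) + 0.05 * rng.normal(size=t.size)

t_eq = first_equilibrated_window(mag)
print(f"first window with Pe < 1 is centered at sweep {t_eq}")
```

A single window below the threshold is a weak criterion, since windowed estimates are noisy; requiring several consecutive windows would be a more conservative choice.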

### References

- Crooks, G. E. (1999). Entropy production fluctuation theorem and the
  nonequilibrium work relation for free energy differences. *Phys. Rev. E* 60(3).
- Jarzynski, C. (1997). Nonequilibrium equality for free energy differences.
  *Phys. Rev. Lett.* 78(14).
- Hack, R. et al. (2022). The Crooks fluctuation theorem for general
  Markov chains. *arXiv:2208.11722*.