# Time-Varying Potentials in THRML

Energy-based models often need non-static energy landscapes: simulated annealing, scheduled biases, or dynamically injected constraints. This notebook demonstrates patterns for time-varying potentials in THRML.

**Patterns covered:**
1. Step-function schedules (discrete constraint injection)
2. One-shot transitions (sudden potential change)
3. Exponential decay (constraint erosion)
4. Persistent bias (static reference case)
5. Comparing schedule strategies: which converges fastest?

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

In [None]:
from thrml.block_management import Block
from thrml.block_sampling import sample_states, SamplingSchedule
from thrml.models.ising import IsingEBM, IsingSamplingProgram
from thrml.pgm import SpinNode

**Why time-varying potentials?**

Many applications of EBMs require non-static energy landscapes:

- **Simulated annealing:** gradually reduce temperature to find ground states
- **Constraint injection:** introduce external fields mid-run to steer the system
- **Scheduled biases:** ramp potentials to model changing environments

THRML's existing examples show static energy functions. This notebook demonstrates patterns for updating biases between sampling rounds, and compares strategies to answer: **does schedule shape matter, or just the final energy?**

**Schedule construction utilities**

Four schedule types: persistent (constant), one-shot (step function at midpoint), iterative steps (gradual ramp), and exponential decay.
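In symbols, let $t$ be the round index, $T$ = `n_rounds`, $t_{\mathrm{on}}$ = `onset_round`, $n$ = `n_steps`, and $\Delta = \lfloor \lfloor T/2 \rfloor / n \rfloor$ the number of rounds per ramp step. The schedules built in the next cell are:

$$
\begin{aligned}
c_{\mathrm{pers}}(t) &= c \\
c_{\mathrm{one}}(t) &= c_{\mathrm{final}} \cdot \mathbf{1}[t \ge t_{\mathrm{on}}] \\
c_{\mathrm{ramp}}(t) &= \frac{c_{\mathrm{final}}}{n} \min\left(n,\ \lfloor t/\Delta \rfloor + 1\right) \\
c_{\mathrm{decay}}(t) &= c_{\mathrm{init}}\, e^{-\kappa t}
\end{aligned}
$$

The ramp's first increment already applies at $t = 0$, and full strength is reached at $t = (n-1)\Delta$, no later than the midpoint. The decay halves every $\ln 2 / \kappa$ rounds: about 13.9 rounds for $\kappa = 0.05$ and 23.1 rounds for $\kappa = 0.03$.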

In [None]:
def make_persistent_schedule(n_rounds: int, c: float = 1.0) -> np.ndarray:
    """Constant constraint from round 0."""
    return np.full(n_rounds, c)


def make_oneshot_schedule(n_rounds: int, onset_round: int = 50,
                          c_final: float = 1.0) -> np.ndarray:
    """Step function: zero until onset_round, then c_final."""
    schedule = np.zeros(n_rounds)
    schedule[onset_round:] = c_final
    return schedule


def make_step_schedule(n_rounds: int, n_steps: int = 4,
                       c_final: float = 1.0) -> np.ndarray:
    """Iterative ramp: n_steps equal increments over the first half of the run.

    Step s (0-indexed) starts at round s * ((n_rounds // 2) // n_steps) and
    raises the constraint to (s + 1) * c_final / n_steps, so c_final is
    reached at or before the midpoint and held to the end.
    """
    schedule = np.zeros(n_rounds)
    onset = n_rounds // 2
    step_size = c_final / n_steps
    rounds_per_step = onset // n_steps

    for s in range(n_steps):
        start = s * rounds_per_step
        # Overwrite everything from this step onward with the new level
        schedule[start:] = (s + 1) * step_size
    schedule[onset:] = c_final  # guard against float rounding in the last step
    return schedule


def make_decay_schedule(n_rounds: int, kappa: float = 0.05,
                        c_init: float = 1.0) -> np.ndarray:
    """Exponential decay: c(t) = c_init * exp(-kappa * t)."""
    return c_init * np.exp(-kappa * np.arange(n_rounds))


# Visualize all four schedules
n_rounds = 100
schedules = {
    "Persistent (c=1)": make_persistent_schedule(n_rounds),
    "One-shot (R50)": make_oneshot_schedule(n_rounds, onset_round=50),
    "4-step ramp": make_step_schedule(n_rounds, n_steps=4),
    "8-step ramp": make_step_schedule(n_rounds, n_steps=8),
    "Exponential decay": make_decay_schedule(n_rounds, kappa=0.05),
}

fig, axs = plt.subplots(figsize=(10, 4))
for name, sched in schedules.items():
    axs.plot(sched, label=name, linewidth=2 if "Persistent" in name else 1.5)
axs.set_xlabel("Round")
axs.set_ylabel("Constraint strength $c(t)$")
axs.set_title("Constraint schedule types")
axs.legend(fontsize=9)
plt.tight_layout()
plt.show()
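Because the ramp's first increment lands at round 0 and its last one before the midpoint, it can help to list where the steps actually fall. A minimal standalone check, reimplementing the same ramp logic as `make_step_schedule` (the helper names `step_ramp` and `change_points` are hypothetical, introduced only for this sketch):

```python
import numpy as np


def step_ramp(n_rounds: int, n_steps: int, c_final: float = 1.0) -> np.ndarray:
    """Same ramp as make_step_schedule: n_steps equal increments in the first half."""
    schedule = np.zeros(n_rounds)
    rounds_per_step = (n_rounds // 2) // n_steps
    for s in range(n_steps):
        schedule[s * rounds_per_step:] = (s + 1) * c_final / n_steps
    schedule[n_rounds // 2:] = c_final
    return schedule


def change_points(schedule: np.ndarray) -> list:
    """Return (round, new value) for round 0 and every round where c(t) changes."""
    points = [(0, float(schedule[0]))]
    for t in range(1, len(schedule)):
        if schedule[t] != schedule[t - 1]:
            points.append((t, float(schedule[t])))
    return points


for n_steps in (4, 8):
    print(n_steps, change_points(step_ramp(100, n_steps)))
```

With `n_rounds = 100`, the 4-step ramp reaches full constraint at round 36 and the 8-step ramp at round 42, both before the nominal midpoint at round 50.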

**The pattern: per-round model rebuilding**

THRML's `IsingEBM` takes static biases. To implement time-varying potentials, we rebuild the model at each round with updated biases and sample for a small number of steps, carrying forward the spin state from the previous round.

This is the general pattern for any time-varying potential in THRML: the energy function changes between sampling rounds, but within each round the sampler runs at fixed biases.
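The same pattern can be illustrated without THRML. The sketch below is a hand-written single-site Gibbs sampler for a chain of ±1 spins, not THRML's sampler; `gibbs_round` and `run_schedule` are hypothetical names, and the energy convention $E(s) = -\beta\left(\sum_i b\, s_i + J \sum_i s_i s_{i+1}\right)$ is an assumption. The bias is held fixed within a round and updated between rounds, while the spin state carries over.

```python
import numpy as np


def gibbs_round(spins, bias, coupling, beta, n_sweeps, rng):
    """Run n_sweeps sequential single-site Gibbs sweeps on an open 1-D chain.

    The bias is held fixed for the whole round. `spins` holds +1/-1 values
    and is updated in place.
    """
    K = len(spins)
    for _ in range(n_sweeps):
        for i in range(K):
            # Local field: bias plus nearest-neighbour couplings
            h = bias
            if i > 0:
                h += coupling * spins[i - 1]
            if i < K - 1:
                h += coupling * spins[i + 1]
            # Assumed energy -beta*(b*s + J*s*s'), so P(s=+1) = 1/(1+exp(-2*beta*h))
            p_up = 1.0 / (1.0 + np.exp(-2.0 * beta * h))
            spins[i] = 1 if rng.random() < p_up else -1
    return spins


def run_schedule(schedule, b_free, b_full, K=16, coupling=0.1, beta=1.0,
                 sweeps_per_round=20, seed=0):
    """Rebuild the potential each round; carry the spin state across rounds."""
    rng = np.random.default_rng(seed)
    spins = -np.ones(K, dtype=int)  # start with all spins down
    theta = np.empty(len(schedule))
    for r, c in enumerate(schedule):
        # New bias for this round: linear in c, from b_free (c=0) to b_full (c=1)
        bias = b_free - (b_free - b_full) * c
        spins = gibbs_round(spins, bias, coupling, beta, sweeps_per_round, rng)
        theta[r] = np.mean(spins == 1)  # fraction of up spins after the round
    return theta


b_free = 0.5 * np.log(0.85 / 0.15)  # up-fraction target 0.85 with no constraint
b_full = 0.5 * np.log(0.06 / 0.94)  # up-fraction target 0.06 at full constraint
theta_none = run_schedule(np.zeros(100), b_free, b_full)
theta_full = run_schedule(np.ones(100), b_free, b_full)
print(theta_none[-20:].mean(), theta_full[-20:].mean())
```

Because the state carries over, each round starts from wherever the previous potential left the chain; without that carry-over, only the current bias would matter and the schedule's shape would have no effect.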

In [None]:
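K = 16            # number of spins in the chain
THETA_UU = 0.85   # target up-spin fraction with no constraint (c = 0)
THETA_GG = 0.06   # target up-spin fraction under full constraint (c = 1)


def build_ising_at_constraint(c: float):
    """Build an IsingEBM at constraint level c in [0, 1].

    Each bias is 0.5 * log(theta / (1 - theta)); for an uncoupled spin in
    {-1, +1} at beta = 1 this gives P(up) = theta (assuming THRML's +/-1
    Ising convention). The bias moves linearly in c from the THETA_UU value
    (c = 0) to the THETA_GG value (c = 1).
    """
    eps = 1e-6
    b_drift = 0.5 * np.log(THETA_UU / (1 - THETA_UU))
    b_gg = 0.5 * np.log(max(THETA_GG, eps) / max(1 - THETA_GG, eps))
    b_constraint = b_drift - b_gg  # total bias shift at full constraint

    nodes = [SpinNode() for _ in range(K)]
    biases = jnp.full(K, b_drift - b_constraint * c)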

    # 1-D chain with nearest-neighbor coupling
    edges = [(nodes[i], nodes[i + 1]) for i in range(K - 1)]
    weights = jnp.full(len(edges), 0.1)

    ebm = IsingEBM(
        nodes=nodes, edges=edges, biases=biases,
        weights=weights, beta=jnp.array(1.0),
    )
    return ebm, nodes


def simulate_schedule(constraint_schedule, steps_per_round: int = 20,
                      seed: int = 4242):
    """Run THRML sampling with a time-varying constraint schedule.

    At each round, rebuild the IsingEBM with updated biases and sample
    for steps_per_round steps, carrying forward the spin state.

    Returns the theta trajectory: an array of shape (n_rounds,) holding the
    fraction of up spins at the end of each round.
    """
    n_rounds = len(constraint_schedule)
    key = jax.random.key(seed)
    theta_traj = np.zeros(n_rounds)

    # Spin state carried between rounds; None until the first round has run
    current_state = None

    for r in range(n_rounds):
        c = constraint_schedule[r]
        ebm, nodes = build_ising_at_constraint(c)
        blocks = [Block([node]) for node in nodes]

        program = IsingSamplingProgram(
            ebm=ebm, free_blocks=blocks, clamped_blocks=[],
        )
        schedule = SamplingSchedule(
            n_warmup=0 if r > 0 else 50,  # warmup only on first round
            n_samples=1,
            steps_per_sample=steps_per_round,
        )

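        key, subkey = jax.random.split(key)
        if current_state is None:
            # First round: start from all spins down (False = spin down)
            init_state = [jnp.array([False]) for _ in nodes]
        else:
            # Later rounds: continue from the previous round's final state
            init_state = current_state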

        samples = sample_states(
            key=subkey, program=program, schedule=schedule,
            init_state_free=init_state, state_clamp=[],
            nodes_to_sample=blocks,
        )

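        # Each entry of `samples` has shape (n_samples, block_size) = (1, 1).
        # theta = fraction of up spins in this round's final sample.
        spins = jnp.stack([s[0, 0] for s in samples])
        theta_traj[r] = float(jnp.mean(spins.astype(jnp.float32)))
        # Carry the final sample forward as the next round's initial state,
        # as one (block_size,) array per block to match init_state's shape.
        current_state = [s[-1] for s in samples]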

    return theta_traj
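The biases in `build_ising_at_constraint` use $b = \tfrac{1}{2}\log\frac{\theta}{1-\theta}$. Assuming THRML's Ising energy $E(s) = -\beta\left(\sum_i b_i s_i + \sum_{ij} J_{ij} s_i s_j\right)$ with $s_i \in \{-1, +1\}$, an isolated spin at $\beta = 1$ has $P(s = +1) = 1/(1 + e^{-2b})$, so this bias reproduces the target fraction exactly; the weak 0.1 couplings shift the realized fraction slightly. A quick numerical check in pure Python (`bias_for` and `up_fraction` are hypothetical helper names):

```python
import math


def bias_for(theta: float) -> float:
    """Bias that makes an uncoupled +/-1 spin point up with probability theta (beta = 1)."""
    return 0.5 * math.log(theta / (1.0 - theta))


def up_fraction(b: float, beta: float = 1.0) -> float:
    """P(s = +1) for an uncoupled spin with energy -beta * b * s."""
    return 1.0 / (1.0 + math.exp(-2.0 * beta * b))


b_drift = bias_for(0.85)   # THETA_UU
b_gg = bias_for(0.06)      # THETA_GG
for c in (0.0, 0.5, 1.0):
    b = b_drift - (b_drift - b_gg) * c
    print(f"c = {c:.1f}: bias = {b:+.4f}, uncoupled up-fraction = {up_fraction(b):.4f}")
```

Because the interpolation is linear in bias (log-odds) rather than in $\theta$, $c = 0.5$ gives an uncoupled up-fraction near 0.38, not the arithmetic midpoint 0.455.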

**Comparing schedule strategies**

We compare six conditions: four that reach the same final constraint ($c = 1$) by different paths (persistent, one-shot, 4-step ramp, 8-step ramp), a no-constraint baseline, and an exponential decay that starts at $c = 1$ and erodes toward zero. The question: does the path to full constraint matter?
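Equal final constraint does not imply equal total exposure. One simple way to see how the paths differ is the cumulative constraint $\sum_t c(t)$ over the run. This metric is only an illustration, not one used in the original comparison; the schedules are rebuilt inline in closed form so the snippet is self-contained (`ramp` is a hypothetical helper mirroring `make_step_schedule`):

```python
import numpy as np

n_rounds = 100
t = np.arange(n_rounds)


def ramp(n_steps, c_final=1.0):
    # Same levels as make_step_schedule: (s + 1) * c_final / n_steps, capped at c_final
    rounds_per_step = (n_rounds // 2) // n_steps
    return np.minimum(n_steps, t // rounds_per_step + 1) * c_final / n_steps


exposure = {
    "Persistent (full)": np.ones(n_rounds).sum(),
    "One-shot (R50)": (t >= 50).sum(),
    "4-step ramp": ramp(4).sum(),
    "8-step ramp": ramp(8).sum(),
    "Exponential decay": np.exp(-0.03 * t).sum(),
}
for name, total in exposure.items():
    print(f"{name:18s}: sum of c(t) = {float(total):6.2f}")
```

The totals come out to 100, 50, 82, 79 and about 32, so the four conditions that end at $c = 1$ differ by up to a factor of two in cumulative exposure.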

In [None]:
n_rounds = 100

test_schedules = {
    "No constraint": np.zeros(n_rounds),
    "Persistent (full)": make_persistent_schedule(n_rounds),
    "One-shot (R50)": make_oneshot_schedule(n_rounds, onset_round=50),
    "4-step ramp": make_step_schedule(n_rounds, n_steps=4),
    "8-step ramp": make_step_schedule(n_rounds, n_steps=8),
    "Exponential decay": make_decay_schedule(n_rounds, kappa=0.03),
}

trajectories = {}
print("Running schedule comparison:\n")
for name, sched in test_schedules.items():
    theta_traj = simulate_schedule(sched, seed=4242)
    trajectories[name] = theta_traj
    final_theta = theta_traj[-1]
    mean_last10 = np.mean(theta_traj[-10:])
    var_last10 = np.var(theta_traj[-10:])
    print(f"  {name:22s}: final theta = {final_theta:.4f}  "
          f"mean(last 10) = {mean_last10:.4f}  var(last 10) = {var_last10:.6f}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Panel 1: Trajectories
colors = plt.cm.tab10(np.linspace(0, 1, len(trajectories)))
for (name, traj), color in zip(trajectories.items(), colors):
    axes[0].plot(traj, label=name, color=color,
                 linewidth=2 if "Persistent" in name else 1.2)
axes[0].set_xlabel("Round")
axes[0].set_ylabel(r"$\theta$ (order parameter)")
axes[0].set_title("Trajectories under different constraint schedules")
axes[0].legend(fontsize=8, loc="upper right")

# Panel 2: mean and standard deviation over the last 10 rounds
names = list(trajectories.keys())
means = [np.mean(trajectories[n][-10:]) for n in names]
stds = [np.std(trajectories[n][-10:]) for n in names]

x = np.arange(len(names))
axes[1].bar(x, means, yerr=stds, capsize=3, color=list(colors), alpha=0.7)
axes[1].set_xticks(x)
axes[1].set_xticklabels(names, rotation=35, ha="right", fontsize=8)
axes[1].set_ylabel(r"Mean $\theta$ (last 10 rounds, $\pm$1 s.d.)")
axes[1].set_title(r"Late-run $\theta$ by schedule type")

plt.tight_layout()
plt.show()

**One-shot rebound effect**

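A striking feature of the one-shot schedule: when the constraint is applied suddenly at round 50, the order parameter drops but then **rebounds**, partially recovering. In the original study this rebound appeared in 3/3 trials. The interpretation offered is inertia: established correlations resist sudden potential changes. In a single run of this small chain, whether the rebound is visible depends on sampling noise, so the plot below marks it only when the mean over rounds 60 to 74 exceeds the value at round 52.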

In [None]:
os_traj = trajectories["One-shot (R50)"]
pers_traj = trajectories["Persistent (full)"]

fig, axs = plt.subplots(figsize=(10, 5))
axs.plot(os_traj, label="One-shot (R50)", linewidth=2, color="C1")
axs.plot(pers_traj, label="Persistent (full)", linewidth=2, color="C2",
         linestyle="--")
axs.axvline(50, color="gray", linestyle=":", alpha=0.5, label="Constraint onset")

# Mark the rebound only if it occurs in this run: compare theta just after
# onset (round 52) with its mean over rounds 60 to 74 (n_rounds = 100 here).
drop_val = os_traj[52]
rebound_val = np.mean(os_traj[60:75])
if rebound_val > drop_val:
    axs.annotate("rebound", xy=(65, rebound_val),
                 xytext=(75, rebound_val + 0.1),
                 arrowprops=dict(arrowstyle="->", color="C3"),
                 fontsize=10, color="C3")

axs.set_xlabel("Round")
axs.set_ylabel(r"$\theta$")
axs.set_title("One-shot constraint: drop and rebound")
axs.legend()
plt.tight_layout()
plt.show()
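The plotting cell decides "rebound" from a single round (52), which is sensitive to sampling noise. A slightly more robust alternative is to compare mean $\theta$ over two post-onset windows. The sketch below is an illustration under that assumption; `rebound_size` and its default window offsets are hypothetical choices, not taken from the original study, and the trajectory is synthetic, not experimental data.

```python
import numpy as np


def rebound_size(traj, onset, early=(1, 6), late=(10, 25)):
    """Mean theta in a late post-onset window minus mean theta in an early one.

    Windows are offsets from `onset`, half-open like Python slices. A positive
    value means theta climbed back after the initial drop.
    """
    traj = np.asarray(traj, dtype=float)
    early_mean = traj[onset + early[0]:onset + early[1]].mean()
    late_mean = traj[onset + late[0]:onset + late[1]].mean()
    return late_mean - early_mean


# Synthetic trajectory (made-up shape): theta sits at 0.8, drops to 0.2 at
# round 50, then relaxes back up toward 0.4.
t = np.arange(100)
synthetic = np.where(t < 50, 0.8, 0.4 - 0.2 * np.exp(-(t - 50) / 5.0))
print(f"rebound = {rebound_size(synthetic, onset=50):.3f}")
```

Applied to `trajectories["One-shot (R50)"]` with `onset=50`, a clearly positive value would indicate a rebound; a value near zero would not.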

**Per-step analysis: do more steps help?**

A counterintuitive result from the experimental data: more constraint injection steps do **not** reliably produce less drift. The coefficient of variation (CV, the standard deviation of the per-step changes in $\theta$ divided by the magnitude of their mean) is high ($> 0.5$), meaning individual steps have very different impacts, consistent with each step's effect depending on the current system state. This is inconsistent with a reverse-diffusion denoising model, in which each step should contribute equally.
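Concretely, the cell below takes each step onset $t_s$ (skipping the one at round 0), computes $\Delta\theta_s = \overline{\theta}[t_s, t_s + 5) - \overline{\theta}[t_s - 3, t_s)$, and reports $\mathrm{CV} = \sigma(\Delta\theta_s) / |\mu(\Delta\theta_s)|$. If every step contributed equally, CV would be 0. A toy illustration with made-up per-step changes (`coefficient_of_variation` is a hypothetical helper using the same population standard deviation as `np.std`):

```python
import numpy as np


def coefficient_of_variation(deltas):
    """CV = std / |mean| of per-step effects (population std, as np.std)."""
    deltas = np.asarray(deltas, dtype=float)
    return np.std(deltas) / np.abs(np.mean(deltas))


# Made-up per-step changes in theta, for illustration only
equal_steps = [-0.10, -0.10, -0.10]    # every step contributes the same
unequal_steps = [-0.20, -0.02, -0.08]  # same total change, very uneven steps

print(f"equal:   CV = {coefficient_of_variation(equal_steps):.3f}")
print(f"unequal: CV = {coefficient_of_variation(unequal_steps):.3f}")
```

Both sets move $\theta$ by the same total amount; only the uneven one crosses the $0.5$ threshold used in the cell below.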

In [None]:
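step_counts = {"4-step ramp": 4, "8-step ramp": 8}
for n_steps_label, n_steps in step_counts.items():
    traj = trajectories[n_steps_label]
    # Step onsets, matching make_step_schedule: s * ((n_rounds // 2) // n_steps)
    step_rounds = [i * ((n_rounds // 2) // n_steps) for i in range(n_steps)]

    # Per-step effect: mean theta over the 5 rounds from the step onset minus
    # mean theta over the 3 rounds before it. Steps too close to either end
    # of the run (including the one at round 0) are skipped.
    deltas = []
    for sr in step_rounds:
        if 3 <= sr < n_rounds - 5:
            pre = np.mean(traj[sr - 3:sr])
            post = np.mean(traj[sr:sr + 5])
            deltas.append(post - pre)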

    if deltas:
        deltas = np.array(deltas)
        nonzero = deltas[np.abs(deltas) > 1e-6]
        if len(nonzero) > 1:
            cv = np.std(nonzero) / np.abs(np.mean(nonzero))
            print(f"  {n_steps_label}: per-step deltas = "
                  f"{[f'{d:.4f}' for d in deltas]}")
            print(f"    CV = {cv:.2f}  "
                  f"({'non-uniform (CV > 0.5)' if cv > 0.5 else 'approximately uniform'})")
        else:
            print(f"  {n_steps_label}: insufficient non-zero steps")

**Summary**

This notebook demonstrated how to implement time-varying potentials in THRML by rebuilding the `IsingEBM` with updated biases at each sampling round and carrying forward the spin state.

**Key findings from the schedule comparison:**

1. **Persistent constraint massively outperforms** all other strategies. If you can apply the full constraint from the start, do so.
2. **One-shot injection causes rebound** — the system partially recovers from sudden potential changes. This is a real experimental finding, not a simulation artifact.
3. **More injection steps do not reliably mean less drift.** Per-step effects are state-dependent with high CV, inconsistent with equal-step denoising models.
4. **Schedule shape matters more than step count.** The geometry of the constraint trajectory determines the outcome as much as the final constraint level.

**The practical lesson:** When designing annealing or constraint schedules for THRML models, test multiple strategies. The assumption that gradual is always better than sudden (or vice versa) does not hold in general — the system's response to potential changes depends on its current state and history.