# Flow-Based Swap Proposals for Faster Molecular Sampling  
*An intuition-first, CPU-friendly tutorial*

**Goals**

1. *Motivation* — why sampling Boltzmann distributions of molecules is hard.  
2. *Parallel Tempering (PT)* quick refresher.  
3. *Normalizing Flows (RealNVP)* from first principles, with visualisations.  
4. *Toy experiment* on a 2-D Gaussian Mixture Model (GMM):  
   * Train a flow between two temperatures.  
   * Compare swap acceptance rates: naïve PT vs flow-guided PT.  
5. *Bridge* to real peptides (Alanine Dipeptide) and the T-GePT thesis.

> **Runtime:** < 10 min on CPU.

In [None]:
# Cell purpose: import minimal dependencies and set global settings
import torch
import numpy as np
import matplotlib.pyplot as plt

# Use CPU throughout for portability
device = torch.device("cpu")

# Plot styling
plt.rcParams.update({
    "figure.facecolor": "white",
    "axes.grid": True,
    "axes.spines.right": False,
    "axes.spines.top": False,
})

# Seed torch's RNG (all sampling below uses it) so runs are reproducible
torch.manual_seed(0)

## 1  Why is molecular sampling hard?

* The Boltzmann density  
  $$
    \pi_\beta(\mathbf{x}) \propto e^{-\beta U(\mathbf{x})}
  $$
  is sharply multimodal: the energy $U$ has **many deep wells** (metastable states).

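* Local proposals (e.g. Langevin) **mix slowly** between wells: the time needed to escape over a barrier of height $\Delta U$ grows roughly like $e^{\beta\,\Delta U}$.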

* **Parallel Tempering (PT)** runs $K$ replicas at inverse temperatures  
  $\beta_1 < \beta_2 < \dots < \beta_K$ and occasionally swaps them.

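* Unfortunately, a *blind swap* $(x^{(k)},x^{(k+1)}) \mapsto (x^{(k+1)},x^{(k)})$
  is accepted with probability
  $$
    \min\Bigl\{1,\; \exp\bigl[(\beta_{k+1}-\beta_k)\bigl(U(x^{(k+1)}) - U(x^{(k)})\bigr)\bigr]\Bigr\},
  $$
  which is non-negligible only if the two distributions **overlap strongly**.
  On tough systems the acceptance may fall below 1 %.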
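The blind-swap rule fits in a few lines. This is a minimal sketch with a hypothetical helper name; the energies are illustrative numbers, chosen to show how quickly the acceptance collapses when a cold replica sitting in a well meets a hot replica that is high in energy:

```python
import math

def swap_acceptance(beta_a: float, beta_b: float, U_a: float, U_b: float) -> float:
    """Metropolis acceptance probability for exchanging the configurations of two
    replicas: replica a (inverse temperature beta_a) currently has energy U_a,
    replica b (inverse temperature beta_b) has energy U_b."""
    # log of (target after exchange) / (target before exchange):
    # [-beta_a*U_b - beta_b*U_a] - [-beta_a*U_a - beta_b*U_b]
    log_alpha = (beta_a - beta_b) * (U_a - U_b)
    return math.exp(min(0.0, log_alpha))

# Cold replica (beta = 1) in a well, hot replica (T = 50, beta = 0.02) high in energy
print(swap_acceptance(1.0, 0.02, U_a=-5.0, U_b=20.0))   # ≈ 2.3e-11
```

The exponent scales with $\Delta\beta\,\Delta U$; because energies are extensive, $\Delta U$ grows with system size, which is why blind swaps degrade for larger molecules.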

### Idea: learn a transport map that *morphs* a hot sample into a cold-like one
A *normalizing flow* does exactly that while still allowing us to write down an
exact proposal density, so we can keep the Metropolis–Hastings correction and
**preserve detailed balance**.
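The acceptance rule behind this idea can be sketched in plain Python. The swap proposal pairs a bijection $f$ (cold → hot) with its inverse, $(x_{\text{cold}}, x_{\text{hot}}) \mapsto (f^{-1}(x_{\text{hot}}), f(x_{\text{cold}}))$. Applying this map twice returns the starting pair, so the Metropolis–Hastings ratio is the ratio of target densities times the absolute Jacobian determinant of the map. The 1-D Gaussian targets and the closed-form map $f(x)=\sqrt{T}\,x$ below are illustrative assumptions, chosen because the exact transport is known; the function names are hypothetical.

```python
import math

def log_normal(x: float, var: float) -> float:
    """log N(x; 0, var) in 1-D."""
    return -0.5 * x * x / var - 0.5 * math.log(2.0 * math.pi * var)

def flow_swap_log_alpha(x_cold, x_hot, log_p_cold, log_p_hot,
                        f, f_inv, log_det_f, log_det_f_inv):
    """log MH ratio for the deterministic swap (x_cold, x_hot) -> (f_inv(x_hot), f(x_cold)).
    The map is its own inverse, so the ratio is the target ratio times |det J|."""
    y_cold, y_hot = f_inv(x_hot), f(x_cold)
    return (log_p_cold(y_cold) + log_p_hot(y_hot)
            - log_p_cold(x_cold) - log_p_hot(x_hot)
            + log_det_f(x_cold) + log_det_f_inv(x_hot))

T = 50.0
log_p_cold = lambda x: log_normal(x, 1.0)   # toy cold target N(0, 1)
log_p_hot = lambda x: log_normal(x, T)      # toy hot target N(0, T)

# Exact transport map for this toy: f(x) = sqrt(T) * x, with constant log-Jacobians
f, f_inv = (lambda x: math.sqrt(T) * x), (lambda y: y / math.sqrt(T))
ld_f, ld_f_inv = (lambda x: 0.5 * math.log(T)), (lambda y: -0.5 * math.log(T))

# The identity map recovers the blind swap
ident = lambda x: x
zero = lambda x: 0.0

x_cold, x_hot = 0.3, -6.0
print(flow_swap_log_alpha(x_cold, x_hot, log_p_cold, log_p_hot, f, f_inv, ld_f, ld_f_inv))  # ≈ 0
print(flow_swap_log_alpha(x_cold, x_hot, log_p_cold, log_p_hot, ident, ident, zero, zero))  # ≈ -17.6
```

With the exact map the log-ratio is zero, so every swap is accepted; with the identity map the same function reduces to the blind swap and the proposal is almost always rejected. A learned flow sits somewhere between these two extremes.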

## 2  A 2-D GMM toy landscape

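We first build a simple 5-mode Gaussian mixture.  
At low temperature ($T=1$) each mode is narrow; at high temperature ($T=50$) the
modes blur and overlap. As a simple stand-in for tempering, the high-$T$ mixture
inflates every component covariance by a factor $T$.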
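The high-$T$ GMM in this notebook is built by scaling each Cholesky factor by $\sqrt{T}$, i.e. inflating each component covariance by $T$. For a single Gaussian this is exactly the tempered density $\pi^{1/T}$; for a mixture it is only an approximation. A small NumPy sketch in 1-D makes this concrete (the two-mode mixture, its variance, and the grid are illustrative choices, not the notebook's GMM):

```python
import numpy as np

def log_gauss(x, mu, var):
    """Elementwise log N(x; mu, var) in 1-D."""
    return -0.5 * (x - mu) ** 2 / var - 0.5 * np.log(2 * np.pi * var)

def log_mixture(x, mus, var, weights):
    """log sum_i w_i N(x; mu_i, var), computed stably with logaddexp."""
    comps = np.stack([np.log(w) + log_gauss(x, m, var) for m, w in zip(mus, weights)])
    return np.logaddexp.reduce(comps, axis=0)

T, var = 50.0, 0.2
x = np.linspace(-6.0, 6.0, 201)
mus, w = [-2.0, 2.0], [0.5, 0.5]

# If "inflate the variance by T" were exact tempering, log pi(x)/T - log pi_inflated(x)
# would be the same constant (a normalising offset) at every x.
single_gap = log_gauss(x, 0.0, var) / T - log_gauss(x, 0.0, var * T)
mixture_gap = log_mixture(x, mus, var, w) / T - log_mixture(x, mus, var * T, w)

print(np.ptp(single_gap))    # ~0: exact for one Gaussian
print(np.ptp(mixture_gap))   # clearly > 0: only approximate for a mixture
```

For the swap experiments this does not matter: both PT variants below evaluate the same `log_prob(x, T)`, so they sample the same (approximately tempered) high-$T$ target.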

In [None]:
# Cell purpose: define a simple GMM class + helper to plot samples
class GMM2D:
    def __init__(self):
        # 5 centres manually chosen for visual clarity
        self.loc = torch.tensor([
            [-2.5,  0.8],
            [-0.7,  1.5],
            [ 1.2,  0.3],
            [ 0.1, -1.7],
            [ 3.0, -1.3]
        ], dtype=torch.float32)
        # Cholesky factors of slightly eccentric covariances. MultivariateNormal
        # expects scale_tril to be *lower* triangular, so off-diagonal terms go
        # below the diagonal (an upper entry would make sample() and log_prob() disagree).
        self.scale_tril = torch.tensor([
            [[0.40, 0.00], [ 0.05, 0.45]],
            [[0.10, 0.00], [ 0.00, 0.15]],
            [[0.35, 0.00], [ 0.00, 0.35]],
            [[0.25, 0.00], [-0.07, 0.18]],
            [[0.12, 0.00], [ 0.00, 0.25]],
        ])
        self.cat = torch.distributions.Categorical(torch.ones(5) / 5.)

    def _mixture(self, T: float = 1.0):
        # High T: multiply each Cholesky factor by sqrt(T), i.e. inflate every
        # component covariance by T (exact tempering per component, approximate for the mixture)
        comp = torch.distributions.MultivariateNormal(
            self.loc, scale_tril=self.scale_tril * np.sqrt(T), validate_args=False
        )
        return torch.distributions.MixtureSameFamily(self.cat, comp)

    def sample(self, n: int, T: float = 1.0):
        return self._mixture(T).sample((n,))

    def log_prob(self, x: torch.Tensor, T: float = 1.0):
        return self._mixture(T).log_prob(x)

def plot_gmm_samples(gmm, T, n=1000, ax=None, title=""):
    if ax is None:
        _, ax = plt.subplots()
    s = gmm.sample(n, T).numpy()
    ax.scatter(s[:, 0], s[:, 1], s=7, alpha=0.5)
    ax.set_title(f"{title} (T={T})")
    ax.set_xlim(-5, 5); ax.set_ylim(-5, 5)

In [None]:
# Cell purpose: visualise the GMM at T=1 and T=50
gmm = GMM2D()
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
plot_gmm_samples(gmm, T=1.0, n=1500, ax=ax1, title="Low temperature")
plot_gmm_samples(gmm, T=50.0, n=1500, ax=ax2, title="High temperature")
plt.show()

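## 3  Naïve PT baseline

We run two replicas:

* Replica 0: low temperature ($T=1$, $\beta=1$)  
* Replica 1: high temperature ($T=50$, $\beta=1/50$)

We alternate **local MALA moves** (Metropolis-adjusted Langevin) with a **blind swap**
attempt every `swap_interval` steps, and record the empirical swap acceptance rate.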
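As an independent sanity check of the MALA proposal and its Metropolis–Hastings correction, here is a NumPy sketch on a 1-D standard normal, where the correct answer (mean 0, variance 1) is known. The target, step size `h`, and chain counts are illustrative assumptions; the proposal $x' = x + \tfrac{h}{2}\nabla\log\pi(x) + \sqrt{h}\,\xi$ matches the one used in the cell below.

```python
import numpy as np

def mala_standard_normal(n_chains=20_000, n_steps=200, h=0.5, seed=0):
    """MALA on log pi(x) = -x^2/2 (so grad = -x), many independent chains in parallel."""
    rng = np.random.default_rng(seed)
    log_pi = lambda x: -0.5 * x ** 2
    grad = lambda x: -x
    # log density of x_to ~ N(x_from + (h/2) grad(x_from), h), dropping the constant
    log_q = lambda x_to, x_from: -0.5 / h * (x_to - x_from - 0.5 * h * grad(x_from)) ** 2

    x = np.zeros(n_chains)
    n_accept = 0
    for _ in range(n_steps):
        prop = x + 0.5 * h * grad(x) + np.sqrt(h) * rng.standard_normal(n_chains)
        # MH ratio: target ratio times reverse / forward proposal densities
        log_alpha = log_pi(prop) - log_pi(x) + log_q(x, prop) - log_q(prop, x)
        accept = np.log(rng.uniform(size=n_chains)) < log_alpha
        x = np.where(accept, prop, x)
        n_accept += accept.sum()
    return x, n_accept / (n_chains * n_steps)

samples, acc = mala_standard_normal()
print(samples.mean(), samples.var(), acc)   # mean ≈ 0, variance ≈ 1
```

Note that the reverse density is evaluated with the gradient at the *proposed* point; reusing the gradient at the current point breaks detailed balance.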

In [None]:
# Cell purpose: minimal PT with Langevin + blind swap, CPU-only
def mala_step(x, log_prob, step_size=0.1):
    """One Metropolis-adjusted Langevin (MALA) step, vectorised over the batch.

    Proposal: x' = x + (h/2) * grad log pi(x) + sqrt(h) * xi,  xi ~ N(0, I),  h = step_size.
    """
    def logp_and_grad(y):
        y = y.detach().requires_grad_(True)
        lp = log_prob(y)
        grad = torch.autograd.grad(lp.sum(), y)[0]
        return lp.detach(), grad.detach()

    def log_q(x_to, x_from, grad_from):
        # log N(x_to; x_from + (h/2) grad_from, h I), dropping the constant
        diff = x_to - x_from - 0.5 * step_size * grad_from
        return -0.5 / step_size * (diff ** 2).sum(dim=-1)

    x = x.detach()
    lp, grad = logp_and_grad(x)
    proposal = x + 0.5 * step_size * grad + step_size ** 0.5 * torch.randn_like(x)
    lp_prop, grad_prop = logp_and_grad(proposal)

    # MH correction: forward and reverse Langevin densities, each using the
    # gradient at its own starting point
    log_alpha = (lp_prop - lp) + log_q(x, proposal, grad_prop) - log_q(proposal, x, grad)
    accept = torch.log(torch.rand_like(log_alpha)) < log_alpha
    x_new = torch.where(accept.unsqueeze(-1), proposal, x)
    return x_new, accept.float().mean().item()

def run_pt_baseline(n_steps=500, swap_interval=10):
    # start each replica from its own distribution
    x_low = gmm.sample(1, T=1.0)
    x_high = gmm.sample(1, T=50.0)
    swap_acc = []

    for step in range(n_steps):
        # local moves
        x_low, _ = mala_step(x_low, lambda y: gmm.log_prob(y, T=1.0), step_size=0.05)
        x_high, _ = mala_step(x_high, lambda y: gmm.log_prob(y, T=50.0), step_size=0.05)

        # swap attempt
        if step % swap_interval == 0:
            log_alpha = (
                gmm.log_prob(x_high, T=1.0) + gmm.log_prob(x_low, T=50.0)
                - gmm.log_prob(x_low, T=1.0) - gmm.log_prob(x_high, T=50.0)
            )
            accept = torch.rand(()) < torch.exp(log_alpha)
            if accept:
                x_low, x_high = x_high.clone(), x_low.clone()
            swap_acc.append(float(accept))
    return np.mean(swap_acc)

naive_rate = run_pt_baseline()
print(f"Naïve PT swap acceptance ≈ {naive_rate:.4f}")

## 4  Normalizing Flows in a Nutshell

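* A **flow** is an invertible map $f:\ \mathbf{z}\to\mathbf{x}$ with a
  tractable Jacobian determinant.

* If $\mathbf{z}\sim \rho(\mathbf{z}) = \mathcal N(0,I)$ then  
  $$
      p(\mathbf{x}) = \rho\bigl(f^{-1}(\mathbf{x})\bigr)
      \,\Bigl|\det \partial_{\mathbf{x}} f^{-1}(\mathbf{x}) \Bigr|.
  $$

* Nothing forces the base density to be Gaussian. In this notebook the "base" is the
  low-$T$ distribution: `flow.forward` maps low-$T$ → high-$T$ samples and
  `flow.inverse` maps high-$T$ → low-$T$ samples.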
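A quick numerical check of the change-of-variables formula, as a sketch: the 1-D flow $f(z)=\sinh z$ is an illustrative choice (not used elsewhere in the notebook) because its inverse and derivative have closed forms. The pushed-forward density should integrate to one, and its CDF at $x=1$ should equal the base CDF at $z=f^{-1}(1)$.

```python
import math
import numpy as np

def rho(z):
    """Standard normal base density."""
    return np.exp(-0.5 * z ** 2) / np.sqrt(2.0 * np.pi)

# Illustrative flow x = f(z) = sinh(z): f^{-1}(x) = arcsinh(x), d f^{-1}/dx = 1/sqrt(1 + x^2)
def p_x(x):
    return rho(np.arcsinh(x)) / np.sqrt(1.0 + x ** 2)

grid = np.linspace(-200.0, 200.0, 400_001)
dx = grid[1] - grid[0]
dens = p_x(grid)

total_mass = dens.sum() * dx                     # Riemann sum, should be ~1
cdf_at_1 = dens[grid <= 1.0].sum() * dx          # P(x <= 1) by quadrature
cdf_exact = 0.5 * (1.0 + math.erf(math.asinh(1.0) / math.sqrt(2.0)))  # P(z <= asinh 1)
print(total_mass, cdf_at_1, cdf_exact)
```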

### RealNVP coupling layer

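Keep half the variables fixed and transform the rest with an **affine** map whose
log-scale $s$ and shift $t$ are computed from the passive half:

$$
\begin{aligned}
    y_a &= x_a \\
    y_b &= x_b \odot \exp\bigl(s(x_a)\bigr) + t(x_a)
\end{aligned}
$$

Because the Jacobian is triangular, the log-determinant is simply $\sum_i s_i(x_a)$,
and the inverse is $x_b = \bigl(y_b - t(y_a)\bigr)\odot\exp\bigl(-s(y_a)\bigr)$.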
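The triangular-Jacobian claim can be checked with autograd. This sketch uses hypothetical stand-ins for the networks in 2-D, $s(x_a)=\tanh x_a$ and $t(x_a)=x_a^2$, and compares $\log|\det J|$ from `torch.linalg.slogdet` with $s(x_a)$:

```python
import torch

# Hypothetical scale and shift "networks"; any functions of x_a would do
def s_fn(x_a):
    return torch.tanh(x_a)

def t_fn(x_a):
    return x_a ** 2

def coupling(x):
    """2-D affine coupling: y_a = x_a, y_b = x_b * exp(s(x_a)) + t(x_a)."""
    x_a, x_b = x[0], x[1]
    return torch.stack([x_a, x_b * torch.exp(s_fn(x_a)) + t_fn(x_a)])

x = torch.tensor([0.7, -1.3], dtype=torch.float64)
J = torch.autograd.functional.jacobian(coupling, x)   # 2x2, lower triangular
sign, logabsdet = torch.linalg.slogdet(J)
print(J)
print(logabsdet.item(), s_fn(x[0]).item())            # equal up to rounding
```

The off-diagonal entry $\partial y_b/\partial x_a$ can be arbitrarily complicated, yet it never enters the determinant; that is what makes coupling layers cheap.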

In [None]:
# Cell purpose: define a super-minimal RealNVP for 2-D
class AffineCoupling(torch.nn.Module):
    def __init__(self, mask, hidden_dim=128):
        super().__init__()
        self.register_buffer("mask", mask)   # 1 = passive (kept), 0 = transformed
        in_dim = mask.numel()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(in_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, in_dim * 2),  # outputs [s, t]
        )
        # Zero-init the last layer so every coupling starts as the identity map
        torch.nn.init.zeros_(self.net[-1].weight)
        torch.nn.init.zeros_(self.net[-1].bias)

    def _scale_shift(self, x_masked):
        s, t = self.net(x_masked).chunk(2, dim=-1)
        return torch.tanh(s) * 2.0, t         # soft clamp of the log-scale to (-2, 2)

    def forward(self, x):
        x_masked = x * self.mask
        s, t = self._scale_shift(x_masked)
        y = x_masked + (1 - self.mask) * (x * torch.exp(s) + t)
        logdet = ((1 - self.mask) * s).sum(-1)
        return y, logdet

    def inverse(self, y):
        y_masked = y * self.mask              # passive half is unchanged by forward()
        s, t = self._scale_shift(y_masked)
        x = y_masked + (1 - self.mask) * ((y - t) * torch.exp(-s))
        logdet = -((1 - self.mask) * s).sum(-1)
        return x, logdet

class RealNVP(torch.nn.Module):
    def __init__(self, n_couplings=8, hidden=128):
        super().__init__()
        # Alternate masks so both coordinates get transformed. (Alternating masks
        # *plus* a flip after each block would cancel out and leave x[0] untouched.)
        masks = [torch.tensor([1., 0.]), torch.tensor([0., 1.])] * (n_couplings // 2)
        self.blocks = torch.nn.ModuleList(
            [AffineCoupling(m, hidden_dim=hidden) for m in masks]
        )

    def forward(self, x):
        """Low-T -> high-T direction: returns f(x) and log|det J_f(x)|."""
        logdet = torch.zeros(x.size(0))
        z = x
        for c in self.blocks:
            z, ld = c(z)
            logdet = logdet + ld
        return z, logdet

    def inverse(self, z):
        """High-T -> low-T direction: returns f^{-1}(z) and log|det J_{f^{-1}}(z)|."""
        logdet = torch.zeros(z.size(0))
        x = z
        for c in reversed(self.blocks):
            x, ld = c.inverse(x)
            logdet = logdet + ld
        return x, logdet

In [None]:
# Cell purpose: tiny training loop (≤500 steps) to learn 50→1 transport
flow = RealNVP(n_couplings=8, hidden=128)
optimiser = torch.optim.Adam(flow.parameters(), lr=1e-3)

n_steps = 500      # keep tiny for CPU demo
batch = 128
T_hi = 50.0

for step in range(1, n_steps + 1):
    # --- Sample high / low batches
    x_hi = gmm.sample(batch, T=T_hi)
    x_lo = gmm.sample(batch, T=1.0)

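    # High → low: push high-T samples through f^{-1} and score them under the low-T target
    y_lo, ld_inv = flow.inverse(x_hi)
    loss_hi = -(gmm.log_prob(y_lo, T=1.0) + ld_inv).mean()

    # Low → high: push low-T samples through f and score them under the high-T target
    # (training both directions helps the map cover all modes)
    y_hi, ld_fwd = flow.forward(x_lo)
    loss_lo = -(gmm.log_prob(y_hi, T=T_hi) + ld_fwd).mean()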

    loss = loss_hi + loss_lo
    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if step % 100 == 0:
        print(f"[{step}/{n_steps}] loss = {loss.item():.4f}")

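Why these two losses make sense: each is a KL divergence up to a constant that does not depend on the flow. Write $f$ for `flow.forward`, $\pi_1$ and $\pi_{T}$ for the low- and high-temperature GMM densities ($T$ = `T_hi`), and $q$ for the density of $y = f^{-1}(x)$ with $x\sim\pi_{T}$. By the change of variables, $q(y) = \pi_{T}(x)\,\bigl|\det \partial_{x} f^{-1}(x)\bigr|^{-1}$, hence

$$
\mathrm{KL}\bigl(q \,\|\, \pi_1\bigr)
= \mathbb{E}_{x\sim\pi_{T}}\Bigl[\log \pi_{T}(x) - \log\bigl|\det \partial_{x} f^{-1}(x)\bigr| - \log \pi_1\bigl(f^{-1}(x)\bigr)\Bigr]
= \mathcal{L}_{\text{hi}} + \mathbb{E}_{x\sim\pi_{T}}\bigl[\log \pi_{T}(x)\bigr],
$$

where $\mathcal{L}_{\text{hi}}$ is the expectation that `loss_hi` estimates from a batch, and the last term is a constant. The same argument gives $\mathcal{L}_{\text{lo}} = \mathrm{KL}\bigl(f_{\#}\pi_1 \,\|\, \pi_{T}\bigr) + \text{const}$, and because KL is invariant under a shared bijection this equals $\mathrm{KL}(\pi_1 \,\|\, q) + \text{const}$: the mass-covering direction, which complements the mode-seeking first term.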
In [None]:
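# Cell purpose: plot mapped points vs true low-T points
with torch.no_grad():
    hi_samples = gmm.sample(2000, T=T_hi)
    mapped, _ = flow.inverse(hi_samples)      # high-T -> low-T through the flow
    lo_samples = gmm.sample(2000, T=1.0)      # reference samples from the true low-T GMM

fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].scatter(lo_samples[:, 0].numpy(), lo_samples[:, 1].numpy(), s=6, alpha=0.4, label="true low-T")
ax[0].scatter(mapped[:, 0].numpy(), mapped[:, 1].numpy(), s=6, alpha=0.4, label="mapped from high-T")
ax[0].set_title("Low-T target (T=1.0)")
ax[0].set_xlim(-5, 5); ax[0].set_ylim(-5, 5)
ax[0].legend()

plot_gmm_samples(gmm, T=T_hi, n=2000, ax=ax[1], title="Original high-T samples")
plt.tight_layout()
plt.show()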

In [None]:
# Cell purpose: rerun PT, but propose swaps between the two temperatures through the learned flow
def run_pt_flow(n_steps=500, swap_interval=10):
    x_low = gmm.sample(1, T=1.0)
    x_high = gmm.sample(1, T=T_hi)
    swap_acc = []

    for step in range(n_steps):
        # local MALA moves as before
        x_low, _ = mala_step(x_low, lambda y: gmm.log_prob(y, T=1.0), step_size=0.05)
        x_high, _ = mala_step(x_high, lambda y: gmm.log_prob(y, T=T_hi), step_size=0.05)

        # flow-based swap: (x_low, x_high) -> (f^{-1}(x_high), f(x_low)).
        # Applying this map twice gives back (x_low, x_high), so the MH ratio is the
        # target ratio times |det J| of the map: + log|det J_f(x_low)| + log|det J_{f^{-1}}(x_high)|.
        if step % swap_interval == 0:
            with torch.no_grad():
                y_high, ld_fwd = flow.forward(x_low)       # proposed new high-T state
                y_low,  ld_inv = flow.inverse(x_high)      # proposed new low-T state

            log_alpha = (
                gmm.log_prob(y_low,  T=1.0)  + gmm.log_prob(y_high, T=T_hi)
                - gmm.log_prob(x_low, T=1.0) - gmm.log_prob(x_high, T=T_hi)
                + ld_fwd + ld_inv
            )
            accept = torch.rand(()) < torch.exp(log_alpha)
            if accept:
                x_low, x_high = y_low.clone(), y_high.clone()
            swap_acc.append(float(accept))
    return np.mean(swap_acc)

flow_rate = run_pt_flow()
print(f"Flow-guided PT swap acceptance ≈ {flow_rate:.4f}")
print(f"Acceptance ratio (flow / naïve) ≈ ×{flow_rate / max(1e-9, naive_rate):.1f}")

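Each run attempts only `n_steps // swap_interval` = 50 swaps, so both printed rates are noisy. A minimal sketch for attaching a binomial standard error; it assumes `run_pt_baseline` / `run_pt_flow` are changed to return the raw `swap_acc` list instead of its mean (a hypothetical modification), and the input below is illustrative, not a result from this notebook:

```python
import math

def acceptance_with_se(accepts):
    """Empirical acceptance rate and its binomial standard error sqrt(p(1-p)/n)."""
    n = len(accepts)
    p = sum(accepts) / n
    return p, math.sqrt(p * (1.0 - p) / n)

# Illustrative input: 10 accepted swaps out of 40 attempts
p, se = acceptance_with_se([1.0] * 10 + [0.0] * 30)
print(f"{p:.3f} ± {se:.3f}")   # 0.250 ± 0.068
```

If no swap is accepted in $n$ attempts, $p=0$ and the standard error collapses to zero, which is misleading; the "rule of three" gives an approximate 95 % upper bound of $3/n$ instead (0.06 for 50 attempts).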
## 5  What have we learned?

* On the toy GMM, a lightweight 2-D RealNVP can **raise swap acceptance**
  substantially (compare the two printed rates) without changing the local dynamics.

* The same principle scales to high-D molecules **if** we:
  * design expressive flows (many couplings, conditioning on atom type);
  * keep the Metropolis correction to guarantee correctness.

### ✈️ Towards Alanine Dipeptide

* Replace `GMM2D` by `AldpBoltzmann` (60-D internal coordinates).  
* Train a flow on **pairs of temperatures that were bottlenecks** in naïve PT.  
* Integrate learned swaps into `ParallelTempering` for genuine sampling gains.
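One way to pick the temperature pairs mentioned above, sketched under assumptions: a geometric ladder (a common choice for PT) and a fixed acceptance threshold below which a pair counts as a bottleneck. The function names, the 300 K to 600 K range, the threshold, and the rates are hypothetical, illustrative values, not measurements.

```python
import numpy as np

def geometric_ladder(T_min: float, T_max: float, K: int) -> np.ndarray:
    """K temperatures with a constant ratio T_{k+1} / T_k."""
    return T_min * (T_max / T_min) ** (np.arange(K) / (K - 1))

def bottleneck_pairs(temps, swap_rates, threshold=0.05):
    """Adjacent pairs (T_k, T_{k+1}) whose measured swap acceptance is below threshold.
    swap_rates[k] is the acceptance between temps[k] and temps[k + 1]."""
    return [(float(temps[k]), float(temps[k + 1]))
            for k, rate in enumerate(swap_rates) if rate < threshold]

temps = geometric_ladder(300.0, 600.0, 5)
rates = [0.30, 0.02, 0.25, 0.01]   # illustrative numbers only
print(bottleneck_pairs(temps, rates))   # the 2nd and 4th adjacent pairs
```

Only the flagged pairs would then get a dedicated flow; well-overlapping pairs can keep blind swaps.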

See the follow-up notebooks in this series for the molecular case!

## 6  Further Reading

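* **Normalizing Flows**  
  * Dinh et al., *Density estimation using Real NVP* (2016)  
  * Papamakarios et al., *Normalizing flows for probabilistic modeling and inference* (2019)

* **Replica Exchange / PT**  
  * Hukushima & Nemoto, *Exchange Monte Carlo method and application to spin glass simulations* (1996)  
  * Earl & Deem, *Parallel tempering: Theory, applications, and new perspectives* (2005)

* **Flow-based molecular sampling**  
  * Noé et al., *Boltzmann generators: Sampling equilibrium states of many-body systems with deep learning* (2019)  
  * Klein et al., *Timewarp: Transferable acceleration of molecular dynamics by learning time-coarsened dynamics* (NeurIPS 2023)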

Happy sampling! 🎉