# Setup

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm.auto as tqdm
%matplotlib widget

In [None]:
def grab(x):
    """Return tensor ``x`` as a NumPy array, detached from autograd and moved to CPU.

    For a tensor that already lives on the CPU the returned array shares memory
    with ``x`` (``Tensor.numpy`` does not copy).
    """
    x_cpu = x.detach().cpu()
    return x_cpu.numpy()

In [None]:
def compute_ess(w):
    """Normalized effective sample size of importance weights ``w``.

    ESS = (mean w)^2 / mean(w^2). For positive weights this lies in (0, 1],
    reaching 1 exactly when all weights are equal.
    """
    first_moment = w.mean()
    second_moment = (w ** 2).mean()
    return first_moment ** 2 / second_moment

# Phi4 theory

Copied from Lecture 1:

In [None]:
class Phi4Action:
    """Lattice phi^4 action with periodic boundaries on an Nd-dimensional lattice.

    S[phi] = sum_x [ (Nd + m2/2) phi_x^2 + (lam/24) phi_x^4 - sum_mu phi_x phi_{x+mu} ],
    which equals sum_x [ 1/2 sum_mu (phi_{x+mu} - phi_x)^2 + m2/2 phi_x^2 + lam/4! phi_x^4 ].

    Attributes:
        m2, lam: bare mass-squared and quartic coupling.
        grad: callable phi -> dS/dphi, built with torch.func.grad.
    """
    def __init__(self, m2, lam):
        self.m2 = m2
        self.lam = lam
        self.grad = torch.func.grad(self.value)

    def value(self, phi):
        """Scalar action of one configuration ``phi`` of shape (Lx, Ly, ...).

        The number of lattice dimensions is taken from ``phi`` itself.
        """
        n_dim = phi.dim()
        action = ((n_dim + self.m2/2) * phi**2 + (self.lam/24) * phi**4).sum()
        for mu in range(n_dim):
            # hopping term phi(x) phi(x + mu_hat); torch.roll wraps around -> periodic BCs
            action = action - (phi * torch.roll(phi, shifts=-1, dims=mu)).sum()
        return action

In [None]:
def leapfrog_update(phi, pi, action, *, dt, n_leap):
    """Integrate Hamilton's equations with the leapfrog scheme, IN PLACE.

    Drift-first scheme: half drift, then (kick, full drift) repeated n_leap-1 times,
    then a final kick and half drift. That is n_leap force evaluations in total
    (at least one even if n_leap < 1). ``phi`` and ``pi`` are modified in place;
    nothing is returned.

    Args:
        phi: field configuration tensor (updated in place).
        pi: conjugate momentum tensor of the same shape (updated in place).
        action: object whose ``grad(phi)`` returns dS/dphi.
        dt: step size.
        n_leap: number of leapfrog steps.
    """
    def kick():
        pi.sub_(dt * action.grad(phi))

    def drift(step):
        phi.add_(step * pi)

    drift(dt/2)
    for _ in range(n_leap - 1):
        kick()
        drift(dt)
    kick()
    drift(dt/2)

Modified to take the lattice size $L$ as a parameter and to return the sampled configurations $\phi$ together with their action values:

In [None]:
def run_hmc(action, *, L, n_therm, n_iter, n_meas, dt=0.10, n_leap=10, seed=1234):
    """Sample an L x L lattice field with Hybrid Monte Carlo.

    Args:
        action: object with ``value(phi)`` (scalar action) and ``grad(phi)``.
        L: lattice extent; configurations have shape (L, L).
        n_therm: number of thermalization trajectories (not recorded).
        n_iter: number of trajectories after thermalization.
        n_meas: record every n_meas-th post-thermalization trajectory.
        dt, n_leap: leapfrog step size and number of steps per trajectory.
        seed: seeds torch (initial field and momenta) AND the NumPy generator used
            for the Metropolis accept/reject step, so the whole chain is reproducible.

    Returns:
        dict with
          'meas':     (n_samples,) lattice average of each recorded configuration,
          'phis':     (n_samples, L, L) recorded configurations,
          'actions':  (n_samples,) action of each recorded configuration,
          'acc_rate': fraction of accepted proposals over all trajectories
                      (thermalization included).

    Raises:
        ValueError: if the settings would record no sample at all.
    """
    if n_meas < 1 or n_iter < n_meas:
        raise ValueError(
            f'need n_meas >= 1 and n_iter >= n_meas to record any sample '
            f'(got n_iter={n_iter}, n_meas={n_meas})')
    torch.manual_seed(seed)
    # local generator: the accept/reject step no longer depends on global np.random state
    rng = np.random.default_rng(seed)
    phi = 0.1 * torch.randn((L, L))  # small random initial field
    S = action.value(phi)
    n_acc = 0
    n_tot = 0
    meas, phis, actions = [], [], []
    for i in tqdm.tqdm(range(-n_therm, n_iter)):
        new_phi = phi.clone()
        pi = torch.randn_like(phi)
        K = (pi**2 / 2).sum()
        # leapfrog_update integrates in place: new_phi and pi hold the proposal afterwards
        leapfrog_update(new_phi, pi, action, dt=dt, n_leap=n_leap)
        Sp = action.value(new_phi)
        Kp = (pi**2 / 2).sum()
        dH = grab(Sp + Kp - S - K)
        n_tot += 1
        # Metropolis: accept with probability min(1, exp(-dH)). Clipping the exponent at 0
        # avoids overflow warnings for large negative dH; np.minimum propagates NaN -> reject.
        if rng.random() < np.exp(np.minimum(0.0, -dH)):
            phi = new_phi
            S = Sp
            n_acc += 1
        if i >= 0 and (i + 1) % n_meas == 0:
            meas.append(grab(phi.mean()))
            phis.append(grab(phi))
            actions.append(grab(S))
    return dict(
        meas=np.stack(meas),
        phis=np.stack(phis),
        actions=np.stack(actions),
        acc_rate=n_acc / n_tot,
    )

# Flow

In [None]:
class Velocity(torch.nn.Module):
    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, padding_mode='circular')
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(2, 8, **conv_kwargs),
            torch.nn.SiLU(),
            torch.nn.Conv2d(8, 8, **conv_kwargs),
            torch.nn.SiLU(),
            torch.nn.Conv2d(8, 1, **conv_kwargs),
        )
    def value(self, x, t):
        inp = torch.stack([x, torch.ones_like(x)*t])
        return self.net(inp)[0]
    def div(self, x, t):
        shape = x.shape
        x_flat = x.flatten()
        def eval_flat(y):
            x = y.reshape(shape)
            inp = torch.stack([x, torch.ones_like(x)*t])
            return self.net(inp)[0].flatten()
        J = torch.func.jacfwd(eval_flat)(x_flat)
        return torch.trace(J)

In [None]:
# Smoke test: the untrained velocity field and its divergence on a random 4x4 field.
velocity = Velocity()
phi = torch.randn((4, 4))
print(velocity.value(phi, 1.0), velocity.div(phi, 1.0), sep='\n')

Copied from Lecture 2:

In [None]:
def flow(x, velocity, *, n_step, tf=1.0, inverse=False):
    """Integrate dx/dt = v(x, t) over [0, tf] with ``n_step`` explicit Euler steps.

    Alongside x it accumulates logJ ~= integral of div v(x(t), t) dt, the log-Jacobian
    of the FORWARD map. The same quantity is returned in both directions; callers use
    ``logq = logr(x_r) - logJ`` for either.

    Args:
        x: one configuration (batch with torch.func.vmap).
        velocity: object exposing ``value(x, t)`` and ``div(x, t)``.
        n_step: number of Euler steps; dt = tf / n_step.
        tf: final time.
        inverse: if True, walk the same time grid backwards (approximate inverse map).

    Returns:
        (x_out, logJ)
    """
    dt = tf / n_step
    ts = dt * torch.arange(n_step)
    logJ = torch.tensor(0.0)
    if inverse:
        for t in ts.flip(0):
            x = x - dt * velocity.value(x, t)
            # after the step x approximates x(t), so div is evaluated at the same
            # (x, t) pair the forward pass uses
            logJ = logJ + dt * velocity.div(x, t)
    else:
        for t in ts:
            # Explicit Euler for the augmented ODE (x, logJ): both increments use the
            # state at the START of the step. Updating x first would pair x(t + dt)
            # with time t in the divergence.
            logJ = logJ + dt * velocity.div(x, t)
            x = x + dt * velocity.value(x, t)
    return x, logJ

In [None]:
# Batch `flow` over the leading (sample) axis of x; the velocity model is shared
# (in_dims=None). chunk_size caps how many samples are evaluated at once, so a
# large batch does not exhaust memory.
flow_batch = torch.func.vmap(flow, (0, None), chunk_size=512)
flow_batch(torch.randn((1, 4, 4)), velocity, n_step=10)

In [None]:
# Our target lives in the broken phase (m^2 = -0.5, lambda = 1.5).
# Plain-HMC history of the lattice-averaged field serves as a baseline.
target = Phi4Action(-0.5, 1.5)
res_hmc = run_hmc(target, L=4, n_therm=100, n_iter=1000, n_meas=1, dt=0.05, n_leap=20)
fig, ax = plt.subplots(1, 1, figsize=(4, 2.5))
ax.plot(res_hmc['meas'])
ax.set_xlabel('HMC trajectory')
ax.set_ylabel(r'$\langle \phi \rangle$')
fig.tight_layout()
plt.show()

We will approach the target by chaining 5 flows. Each flow is trained to transport between neighbouring theories on the path from the $m^2 = 0, \lambda = 1.5$ theory (symmetric phase) to the $m^2 = -0.5, \lambda = 1.5$ theory (broken phase), with $m^2$ decreasing in steps of $0.1$:
$$
(m^2, \lambda) = (0, 1.5) \; \longrightarrow \; (-0.1, 1.5) \; \longrightarrow \; (-0.2, 1.5) \; \longrightarrow \; (-0.3, 1.5) \; \longrightarrow \; (-0.4, 1.5) \; \longrightarrow \; (-0.5, 1.5)
$$

In [None]:
# One action per point of the m^2 schedule 0 -> -0.5 (lambda fixed at 1.5);
# each adjacent pair defines one flow-training problem.
targets = [Phi4Action(m2=m2, lam=1.5) for m2 in (0.0, -0.1, -0.2, -0.3, -0.4, -0.5)]

In [None]:
def _plot_training_history(loss_hist, ess_hist):
    """Plot loss and ESS against optimizer step in two side-by-side panels."""
    fig, axes = plt.subplots(1, 2, figsize=(8, 3))
    axes[0].plot(loss_hist)
    axes[1].plot(ess_hist)
    axes[0].set_ylabel('loss')
    axes[1].set_ylabel('ess')
    for ax in axes:
        ax.set_xlabel('step')
    fig.tight_layout()
    return fig


def train_model(prior, target, *, batch_size=4, n_iter=1000, L=4, n_step=10, lr=3e-4, seed=1234):
    """Train a Velocity flow transporting `prior` samples towards `target`.

    Minimizes the reverse KL divergence E_q[log q - log p] with the path-gradient
    estimator. log q of each flowed sample is recomputed by the inverse flow with
    frozen parameters, so gradients reach the parameters only through the sample
    positions. Prior samples come from one HMC run on `prior` (1000 configurations).

    Args:
        prior, target: actions providing ``value`` (prior also ``grad``, for HMC).
        batch_size: prior samples per optimizer step, drawn with replacement.
        n_iter: number of optimizer steps.
        L: lattice extent used for the prior HMC run.
        n_step: Euler steps of the flow during training.
        lr: Adam learning rate.
        seed: seeds torch (model init) and the NumPy generator for batch indices.

    Returns:
        dict(model=trained Velocity, loss=(n_iter,) losses, ess=(n_iter,) ESS values).
    """
    torch.manual_seed(seed)
    rng = np.random.default_rng(seed)
    model = Velocity()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    batched_flow = torch.func.vmap(flow, in_dims=(0, None))
    target_batch = torch.func.vmap(target.value)
    prior_batch = torch.func.vmap(prior.value)
    prior_samples = run_hmc(prior, L=L, n_therm=100, n_iter=1000, n_meas=1, dt=0.05, n_leap=20)['phis']
    loss_hist, ess_hist = [], []
    for i in tqdm.tqdm(range(n_iter)):
        opt.zero_grad()
        inds = rng.integers(len(prior_samples), size=batch_size)
        xr = torch.tensor(prior_samples[inds])
        # forward flow: x depends on the parameters (its logJ is not needed here)
        x, _ = batched_flow(xr, model, n_step=n_step)
        logp = -target_batch(x)
        # path-gradient log q: inverse flow with frozen parameters
        model.requires_grad_(False)
        try:
            xr2, logJ2 = batched_flow(x, model, n_step=n_step, inverse=True)
        finally:
            model.requires_grad_(True)
        logq = -prior_batch(xr2) - logJ2
        loss = (logq - logp).mean()  # reverse KL up to the constant log Z
        loss.backward()
        opt.step()
        with torch.no_grad():
            logw = logp - logq
            # ESS is invariant to rescaling w; shifting by the max avoids exp overflow
            ess = compute_ess((logw - logw.max()).exp())
        loss_hist.append(grab(loss))
        ess_hist.append(grab(ess))
        if (i+1) % 25 == 0:
            print(f'Step {i+1}: Loss {grab(loss)} ESS {grab(ess)}')
    _plot_training_history(loss_hist, ess_hist)
    return dict(model=model, loss=np.asarray(loss_hist), ess=np.asarray(ess_hist))

In [None]:
# One flow per adjacent pair of the schedule: targets[k] -> targets[k+1].
results = [
    train_model(src, dst, batch_size=32, n_iter=10)
    for src, dst in zip(targets[:-1], targets[1:])
]

# Evaluate

## L = 4

In [None]:
# Prior ensemble at m^2 = 0, lambda = 1.5: 5000 post-thermalization HMC
# trajectories, keeping every 2nd configuration.
prior = Phi4Action(m2=0.0, lam=1.5)
hmc_prior_phis = run_hmc(prior, L=4, n_therm=100, n_iter=5000, n_meas=2, dt=0.05, n_leap=20)['phis']
phi_r = torch.as_tensor(hmc_prior_phis)
logr = torch.func.vmap(prior.value)(phi_r).neg()

In [None]:
def apply_flows(phi_r):
    """Push prior samples through every trained flow in turn, tracking importance weights.

    Relies on the notebook globals ``results`` (trained models), ``targets`` (actions
    along the m^2 schedule) and ``flow_batch`` (vmapped flow, integrated here with
    n_step=100). Stage k multiplies the running weight by
    exp(-S_k(x_k) + S_{k-1}(x_{k-1}) + logJ_k) and prints the resulting ESS.

    Returns:
        dict(samples=[x_0, ..., x_K], ws=[w_0, ..., w_K]) with x_0 = phi_r, w_0 = 1.
    """
    samples = [phi_r]
    ws = [torch.ones(phi_r.shape[0])]
    stages = zip(tqdm.tqdm(results), targets[:-1], targets[1:])
    for res, action_prev, action_next in stages:
        x_prev = samples[-1]
        # no gradients needed here, so skip building the autograd graph
        with torch.no_grad():
            x_next, logJ = flow_batch(x_prev.clone(), res['model'], n_step=100)
        dlogq = -logJ
        dlogp = -torch.func.vmap(action_next.value)(x_next) + torch.func.vmap(action_prev.value)(x_prev)
        ws.append(ws[-1] * (dlogp - dlogq).exp())
        samples.append(x_next)
        print(f'ess: {compute_ess(ws[-1])}')
    return dict(samples=samples, ws=ws)

In [None]:
# Transport the prior ensemble through the whole chain of trained flows.
res_L4 = apply_flows(phi_r=phi_r)

In [None]:
# Reference: direct HMC at every m^2 of the `targets` schedule, 2500 measurements
# each (5000 trajectories, one kept every n_meas=2). Iterating `targets` keeps this
# list aligned with res_L4['samples'] in the histogram cell below.
res_hmc_L4 = [
    run_hmc(action, L=4, n_therm=100, n_iter=5000, n_meas=2, dt=0.05, n_leap=20)
    for action in targets
]
fig, ax = plt.subplots(1, 1, figsize=(6, 3))
for action, res in zip(targets, res_hmc_L4):
    ax.plot(res['meas'], label=f'$m^2 = {action.m2}$')
ax.set_xlabel('measurement')
ax.set_ylabel(r'$\langle \phi \rangle$')
ax.legend(fontsize='x-small')
fig.tight_layout()
plt.show()

In [None]:
# Distribution of the lattice-averaged field at each m^2: flowed samples before
# (dashed) and after (filled) reweighting, compared with direct HMC (red).
fig, axes = plt.subplots(2, 3, figsize=(9, 5), sharex=True)
axes = axes.flatten()
bins = np.linspace(-2, 2, 21)
for ax, action, phi, w, hmc in zip(axes, targets, res_L4['samples'], res_L4['ws'], res_hmc_L4):
    phi_mean = grab(phi.flatten(1).mean(-1))  # one value per configuration
    ax.hist(phi_mean, color='k', density=True, histtype='step', bins=bins, linestyle='--', label='samples')
    ax.hist(phi_mean, weights=grab(w), density=True, bins=bins, label='reweighted')
    ax.hist(hmc['meas'], density=True, bins=bins, alpha=0.5, color='xkcd:red', label='HMC')
    ax.set_title(f'$m^2 = {action.m2}$', fontsize='small')
for ax in axes[3:]:
    ax.set_xlabel(r'$\langle \phi \rangle$')
axes[0].legend(fontsize='x-small')
fig.tight_layout()
plt.show()

## L = 8

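Because the velocity field is built entirely from (circularly padded) convolutions, it does not depend on the lattice size. The _same_ trained flow models can therefore be applied directly to the $L = 8$ theory. This is one of the exercises!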