# Flow Matching

This notebook implements several simulation-free methods for learning flow models. Each model maps a source distribution $q_0$ to a target distribution $q_1$:
* Conditional Flow Matching (CFM), which relates to prior work as follows:
    * It matches the zero-step formulation of "Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow" [(Liu et al. 2023)](https://openreview.net/forum?id=XVjTT1nw5z) when $\sigma = 0$.
    * It is similar to "Stochastic Interpolants" [(Albergo et al. 2023)](https://openreview.net/forum?id=li7qeBbCR1t) with a non-variance-preserving interpolant.
    * It is similar to "Flow Matching" [(Lipman et al. 2023)](https://openreview.net/forum?id=PqvMRDCJT9t), but conditions on both source and target.
* Optimal Transport CFM (OT-CFM), which directly optimizes for dynamic optimal transport (WIP)
* Variational Flow Matching (VFM)

In [2]:
import time
import torch

import matplotlib.pyplot as plt

from Engine import *
from pathlib import Path
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

savedir = "Results/Moons"
Path(savedir).mkdir(parents=True, exist_ok=True)

### Conditional Flow Matching

First we implement the basic conditional flow matching. As in the paper, we have
$$
\begin{align}
z &= (x_0, x_1) \\
q(z) &= q(x_0)q(x_1) \\
p_t(x \mid z) &= \mathcal{N}(x \mid t x_1 + (1 - t) x_0, \sigma^2) \\
u_t(x | z) &= x_1 - x_0
\end{align}
$$
When $\sigma = 0$, this is equivalent to zero steps of rectified flow. We find that a small $\sigma$ helps to regularize the problem, although your mileage may vary.
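The definitions above can be turned into a sampling routine directly. The following is a minimal sketch, not the `CFM` class from `Engine` (whose implementation is not shown here); the helper name `sample_conditional_flow` is hypothetical, and it assumes $t$ is drawn uniformly on $[0, 1]$.

```python
import torch


def sample_conditional_flow(x0, x1, sigma=0.1):
    """Hypothetical helper: draw t, x_t ~ p_t(x | z), and the target u_t(x | z)."""
    # One time value per (x0, x1) pair, assumed uniform on [0, 1].
    t = torch.rand(x0.shape[0], dtype=x0.dtype)
    # Mean of p_t(x | z): straight-line interpolation between x0 and x1.
    mu_t = t[:, None] * x1 + (1 - t[:, None]) * x0
    # Isotropic Gaussian noise with standard deviation sigma.
    xt = mu_t + sigma * torch.randn_like(x0)
    # The conditional flow is constant in both t and x.
    ut = x1 - x0
    return t, xt, ut


x0 = torch.randn(4, 2)
x1 = torch.randn(4, 2) + 3.0
t, xt, ut = sample_conditional_flow(x0, x1, sigma=0.0)
```

With `sigma=0.0` the sample `xt` lies exactly on the line segment between `x0` and `x1`, which is the rectified-flow case mentioned above.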

In [None]:
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = CFM(sigma=sigma)

start = time.time()
for k in range(20000):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

        node = NeuralODE(torch_wrapper(model), solver="euler")
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(1024),
                t_span=torch.linspace(0, 1, 100),
            )
            plot_trajectories(traj=traj.cpu().numpy(), output=f"{savedir}/CFM_{k+1}.png")
        
        evaluate(traj[-1].cpu(), sample_moons(1024))
            
torch.save(model, f"{savedir}/CFM.pt")
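The evaluation step in the cell above integrates the learned vector field with a fixed-step Euler solver via `NeuralODE`. As a rough sketch of what that integration amounts to, assuming the model takes `[x, t]` concatenated along the last dimension as in the training loop, the same trajectory can be written by hand. The helper name `euler_trajectory` and the constant toy field are hypothetical stand-ins, not part of `torchdyn` or `Engine`.

```python
import torch


def euler_trajectory(velocity, x0, n_points=100):
    """Hypothetical helper: integrate dx/dt = velocity([x, t]) from t=0 to t=1."""
    ts = torch.linspace(0, 1, n_points)
    x = x0
    xs = [x0]
    with torch.no_grad():
        for t_now, t_next in zip(ts[:-1], ts[1:]):
            # Same input layout as the training loop: time appended as a last column.
            t_col = t_now.expand(x.shape[0], 1)
            # Explicit Euler update: x <- x + dt * v(x, t).
            x = x + (t_next - t_now) * velocity(torch.cat([x, t_col], dim=-1))
            xs.append(x)
    return torch.stack(xs)  # shape: (n_points, batch, dim)


# Toy check with a constant unit velocity field in place of the trained MLP.
def constant_field(xt):
    return torch.ones_like(xt[:, :-1])


traj = euler_trajectory(constant_field, torch.zeros(8, 2))
```

With 100 time points the solver takes 99 steps, and `traj[-1]` plays the role of the generated samples passed to `evaluate`.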

### Optimal Transport Conditional Flow Matching

Next we implement optimal transport conditional flow matching. As in the paper, here we have
$$
\begin{align}
z &= (x_0, x_1) \\
q(z) &= \pi(x_0, x_1) \\
p_t(x \mid z) &= \mathcal{N}(x \mid t x_1 + (1 - t) x_0, \sigma^2) \\
u_t(x | z) &= x_1 - x_0
\end{align}
$$
where $\pi$ is the joint distribution given by an exact optimal transport plan. We first sample random $x_0, x_1$, then resample pairs according to the optimal transport matrix, computed with the Python Optimal Transport (POT) package. We use the 2-Wasserstein distance with an $L^2$ ground distance for equivalence with dynamic optimal transport.
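The pairing step can be illustrated without the `OT_CFM` class from `Engine`. The sketch below is an assumption-laden stand-in: instead of POT it uses SciPy's `linear_sum_assignment`, which gives an exact solution here because, for equal-size batches with uniform weights, an optimal plan can be taken to be a permutation. It also returns that permutation deterministically rather than resampling from the plan. The helper name `ot_pair` is hypothetical.

```python
import torch
from scipy.optimize import linear_sum_assignment


def ot_pair(x0, x1):
    """Hypothetical helper: reorder x1 so that (x0[i], x1[i]) follows a minibatch OT plan."""
    # Squared Euclidean cost between every source and target point.
    cost = torch.cdist(x0, x1) ** 2
    # Exact assignment; valid as an OT plan for equal batch sizes and uniform weights.
    row, col = linear_sum_assignment(cost.numpy())
    return x0[torch.as_tensor(row)], x1[torch.as_tensor(col)]


torch.manual_seed(0)
x0 = torch.randn(64, 2)
x1 = torch.randn(64, 2) + 2.0
x0_ot, x1_ot = ot_pair(x0, x1)

# Average squared displacement before and after re-pairing.
cost_random = ((x0 - x1) ** 2).sum(dim=1).mean()
cost_ot = ((x0_ot - x1_ot) ** 2).sum(dim=1).mean()
```

The re-paired batch can then be fed to the same interpolation and regression steps as in plain CFM; only the coupling $q(z)$ changes.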

In [None]:
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = OT_CFM(sigma=sigma)

start = time.time()
for k in range(20000):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    vt = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(torch_wrapper(model), solver="euler")
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(1024),
                t_span=torch.linspace(0, 1, 100),
            )
            plot_trajectories(traj=traj.cpu().numpy(), output=f"{savedir}/OT-CFM_{k+1}.png")
        
        evaluate(traj[-1].cpu(), sample_moons(1024))
            
torch.save(model, f"{savedir}/OT-CFM.pt")

### Lipman's Optimal Transport Conditional Flow Matching

In [None]:
%%time
sigma = 0.1
dim = 2
batch_size = 256 * 4
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = LOT_CFM(sigma=sigma)

start = time.time()
for k in range(20000):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    psi_t = FM.psi_t(x1, x0, t)
    vt = model(torch.cat([psi_t, t[:, None]], dim=-1))
    loss = torch.mean((vt - ut) ** 2)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(torch_wrapper(model), solver="euler")
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(1024),
                t_span=torch.linspace(0, 1, 100),
            )
            plot_trajectories(traj=traj.cpu().numpy(), output=f"{savedir}/LOT-CFM_{k+1}.png")
        
        evaluate(traj[-1].cpu(), sample_moons(1024))
            
torch.save(model, f"{savedir}/LOT-CFM.pt")

### Variational Flow Matching

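Next we implement variational flow matching. In the cell below, the model output `mean` is regressed onto the conditional flow `ut` with a mean-squared error; a Gaussian log-likelihood loss on `x1` is left commented out.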
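The commented-out loss in the cell below is a Gaussian negative log-likelihood of `x1` under $\mathcal{N}(\text{mean}, \text{var})$. As a small worked check, assuming a fixed isotropic variance, that loss differs from a sum of squared errors only by a scale factor and an additive constant. The function name `gaussian_nll` is hypothetical and the tensors are random placeholders.

```python
import math

import torch


def gaussian_nll(target, mean, var):
    """Negative log-likelihood of target under N(mean, var * I), summed over all entries."""
    return 0.5 * torch.sum(math.log(2 * math.pi * var) + (target - mean) ** 2 / var)


torch.manual_seed(0)
target = torch.randn(16, 2)
mean = torch.randn(16, 2)
var = 0.1 ** 2

nll = gaussian_nll(target, mean, var)
sse = torch.sum((target - mean) ** 2)
# With var fixed, NLL = SSE / (2 * var) + constant, so both have the same minimizer.
constant = 0.5 * target.numel() * math.log(2 * math.pi * var)
```

Because the variance is fixed, minimizing either quantity gives the same `mean`; the two losses differ only in gradient scale.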

In [8]:
%%time
sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
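optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, eps=1e-12)
# Note: scheduler.step() runs every iteration, so with T_max=50 the learning rate
# oscillates with a period of about 2 * T_max = 100 iterations instead of decaying once.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50, eta_min=0)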
FM = VFM(sigma=sigma)

start = time.time()
for k in range(200000):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    mean = model(torch.cat([xt, t[:, None]], dim=-1))

    # var = torch.tensor(sigma)**2

    loss = torch.mean((mean - ut) ** 2)

    # loss = -0.5 * torch.sum(-torch.log(2 * torch.pi * var) - ((x1 - mean)**2 / var))

    loss.backward()
    optimizer.step()
    scheduler.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end
        node = NeuralODE(torch_wrapper(model), solver="euler")
        with torch.no_grad():
            traj = node.trajectory(
                sample_8gaussians(1024),
                t_span=torch.linspace(0, 1, 100),
            )
            plot_trajectories(traj=traj.cpu().numpy(), output=f"{savedir}/VFM_{k+1}.png")
        
        evaluate(traj[-1].cpu(), sample_moons(1024))
            
torch.save(model, f"{savedir}/VFM.pt")

5000: loss 39.635 time 36.31
Fréchet Distance: 0.7450. Hausdorff Distance: 2.4823. Energy Distance: 0.1982.

10000: loss 17.081 time 38.51
Fréchet Distance: 0.0527. Hausdorff Distance: 2.5969. Energy Distance: 0.1030.

15000: loss 16.758 time 35.85
Fréchet Distance: 0.1534. Hausdorff Distance: 3.5174. Energy Distance: 0.0945.

20000: loss 19.866 time 40.05
Fréchet Distance: 0.0544. Hausdorff Distance: 2.4354. Energy Distance: 0.1098.

25000: loss 98.194 time 43.29
Fréchet Distance: 0.0101. Hausdorff Distance: 2.3636. Energy Distance: 0.0781.

30000: loss 612.881 time 41.46
Fréchet Distance: 0.4497. Hausdorff Distance: 5.5201. Energy Distance: 0.2798.

35000: loss 33.825 time 41.73
Fréchet Distance: 1.3140. Hausdorff Distance: 2.2207. Energy Distance: 0.3483.

40000: loss 37.171 time 43.05
Fréchet Distance: 0.1697. Hausdorff Distance: 2.5171. Energy Distance: 0.1142.

45000: loss 115.347 time 43.80
Fréchet Distance: 0.1534. Hausdorff Distance: 2.2687. Energy Distance: 0.0850.

50000: lo

KeyboardInterrupt: 
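The log above reports an energy distance between generated and target samples. The `evaluate` helper lives in `Engine` and its implementation is not shown, so the following is only a sketch of one common definition, $2\,\mathbb{E}\|X - Y\| - \mathbb{E}\|X - X'\| - \mathbb{E}\|Y - Y'\|$, estimated from two sample sets. The values it produces are not claimed to match the numbers printed above.

```python
import torch


def energy_distance(x, y):
    """Sample estimate of 2 E|X-Y| - E|X-X'| - E|Y-Y'| using all pairwise distances."""
    d_xy = torch.cdist(x, y).mean()
    d_xx = torch.cdist(x, x).mean()
    d_yy = torch.cdist(y, y).mean()
    return 2 * d_xy - d_xx - d_yy


torch.manual_seed(0)
a = torch.randn(256, 2)
b = torch.randn(256, 2) + 1.0

same = energy_distance(a, a)     # identical sample sets
shifted = energy_distance(a, b)  # second set shifted by 1 in each coordinate
```

The estimate is zero for identical sample sets and grows as the two sets move apart, which is the behavior one would want from a metric tracking how close the generated moons are to the target.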