# Flow Matching

In [1]:
import time
import torch

import matplotlib.pyplot as plt

from Engine import *
from pathlib import Path
from torchdyn.core import NeuralODE
from torchdyn.datasets import generate_moons

### Conditional Flow Matching

In [None]:
def trajectories(model, x_0, steps):
    """Integrate a learned velocity field from t=0 to t=1 with forward Euler.

    Args:
        model: callable that receives ``torch.cat([x_t, t], dim=-1)`` (the
            current state with the time appended as an extra last column)
            and returns a velocity that is added to ``x_t``.
        x_0: starting points. ``x_0.shape[0]`` is used as the batch size and
            the time column is appended on the last axis, so presumably
            ``(batch, dim)``.
        steps: number of equal Euler steps of size ``1 / steps``.

    Returns:
        Detached CPU tensor of shape ``(steps + 1, *x_0.shape)``. It holds
        the initial state followed by the state after every step.
    """
    x_t = x_0
    delta_t = 1 / steps
    trajectory = [x_t.detach().cpu()]
    for k in range(steps):
        # Build the time column on x_t's device and dtype, so torch.cat never
        # mixes CPU/GPU tensors or float32/float64.
        t = torch.full(
            (x_t.shape[0], 1), k / steps, device=x_t.device, dtype=x_t.dtype
        )
        v_t = model(torch.cat([x_t, t], dim=-1))
        x_t = x_t + v_t * delta_t
        # Detached snapshots make this safe to call outside torch.no_grad().
        trajectory.append(x_t.detach().cpu())
    return torch.stack(trajectory)

In [2]:
%%time

savedir = os.path.join(os.getcwd(), "Results/CFM")
Path(savedir).mkdir(parents=True, exist_ok=True)

# Hyper-parameters for the 8-Gaussians -> moons toy problem.
sigma, dim, batch_size = 0.1, 2, 256

model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = CFM(sigma=sigma)
criterion = torch.nn.MSELoss()

start = time.time()
for k in tqdm(range(20000)):
    optimizer.zero_grad()

    # One minibatch from the source and one from the target distribution
    # (evaluated left to right: 8 Gaussians first, then moons).
    x0, x1 = sample_8gaussians(batch_size), sample_moons(batch_size)
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    # Regress the network output (time appended as an extra column) onto ut.
    vt = model(torch.cat([xt, t.unsqueeze(-1)], dim=-1))
    loss = criterion(vt, ut)
    loss.backward()
    optimizer.step()

    # Only report and visualise every 5000 steps.
    if (k + 1) % 5000 != 0:
        continue

    end = time.time()
    print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
    start = end

    with torch.no_grad():
        traj = trajectories(model, sample_8gaussians(1024), steps=100)
        plot_trajectories(traj=traj)
        evaluate(traj[-1], sample_moons(1024))

torch.save(model, f"{savedir}/CFM.pt")

5000: loss 10.423 time 8.56


NameError: name 'torch_wrapper' is not defined

### Optimal Transport Conditional Flow Matching

In [3]:
%%time

savedir = os.path.join(os.getcwd(), "Results/OT-CFM")
Path(savedir).mkdir(parents=True, exist_ok=True)

# Same toy problem as the CFM cell, but with the OT-coupled flow matcher.
sigma, dim, batch_size = 0.1, 2, 256

model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = OT_CFM(sigma=sigma)
criterion = torch.nn.MSELoss()

start = time.time()
for k in tqdm(range(20000)):
    optimizer.zero_grad()

    # Source batch first, then target batch (left-to-right evaluation).
    x0, x1 = sample_8gaussians(batch_size), sample_moons(batch_size)
    t, xt, ut = FM.sample_location_and_conditional_flow(x0, x1)

    # Network output (time appended as an extra column) is regressed onto ut.
    vt = model(torch.cat([xt, t.unsqueeze(-1)], dim=-1))
    loss = criterion(vt, ut)
    loss.backward()
    optimizer.step()

    # Periodic progress report and visual check.
    if (k + 1) % 5000 != 0:
        continue

    end = time.time()
    print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
    start = end

    with torch.no_grad():
        traj = trajectories(model, sample_8gaussians(1024), steps=100)
        plot_trajectories(traj=traj.cpu().numpy())
        evaluate(traj[-1].cpu(), sample_moons(1024))

torch.save(model, f"{savedir}/OT-CFM.pt")

KeyboardInterrupt: 

### Variational Flow Matching

In [None]:
def trajectories(model, x_0, steps):
    """Integrate an endpoint-predicting model from t=0 to t=1 with forward Euler.

    At each step the model predicts the endpoint ``x1``. The velocity used is
    that of the straight line from the current state to that prediction,
    ``(x1 - x_t) / (1 - t)``.

    Args:
        model: callable that receives ``torch.cat([x_t, t], dim=-1)`` (the
            state with the time appended as an extra last column) and returns
            a predicted endpoint shaped like ``x_t``.
        x_0: starting points. ``x_0.shape[0]`` is used as the batch size, so
            presumably ``(batch, dim)``.
        steps: number of equal Euler steps of size ``1 / steps``.

    Returns:
        Detached CPU tensor of shape ``(steps + 1, *x_0.shape)``. It holds
        the initial state followed by the state after every step.
    """
    xt = x_0
    delta_t = 1 / steps
    trajectory = [xt.detach().cpu()]
    for k in range(steps):
        # Build the time column on xt's device and dtype, so torch.cat never
        # mixes CPU/GPU tensors or float32/float64.
        t = torch.full((xt.shape[0], 1), k / steps, device=xt.device, dtype=xt.dtype)
        x1 = model(torch.cat([xt, t], dim=-1))
        # t <= (steps - 1) / steps < 1, so the division is always finite.
        v_t = (x1 - xt) / (1 - t)
        xt = xt + v_t * delta_t
        # Detached snapshots make this safe to call outside torch.no_grad().
        trajectory.append(xt.detach().cpu())
    return torch.stack(trajectory)

In [5]:
%%time

savedir = os.path.join(os.getcwd(), "Results/VFM")
Path(savedir).mkdir(parents=True, exist_ok=True)

sigma = 0.1
dim = 2
batch_size = 256
model = MLP(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = CFM(sigma=sigma)
# Gaussian likelihood of the true endpoint x1 under the model's prediction,
# with a fixed variance sigma**2.
criterion = torch.nn.GaussianNLLLoss()

# The variance never changes, so allocate it once rather than every step.
# A fresh torch tensor has requires_grad=False, and GaussianNLLLoss clamps a
# clone of ``var``, so reusing this tensor across iterations is safe.
var = torch.full((batch_size, dim), sigma**2)

start = time.time()
for k in tqdm(range(20000)):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    # Only (t, xt) are used: the model predicts x1 rather than a velocity.
    t, xt, _ = FM.sample_location_and_conditional_flow(x0, x1)

    x1_pred = model(torch.cat([xt, t[:, None]], dim=-1))
    loss = criterion(x1_pred, x1, var)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

        with torch.no_grad():
            traj = trajectories(model, sample_8gaussians(1024), steps=100)
            plot_trajectories(traj=traj)
            evaluate(traj[-1], sample_moons(1024))

torch.save(model, f"{savedir}/VFM.pt")

KeyboardInterrupt: 

### Stochastic Gradients

In [None]:
class MLPWithScore(torch.nn.Module):
    """Two independent MLPs evaluated on the same input.

    ``forward`` returns ``(mu(x), score(x))``. The stochastic-gradient
    ``trajectories`` helper treats the first output as a predicted endpoint
    and the second as a score estimate.

    Args:
        dim: data dimensionality, forwarded to both MLPs.
        time_varying: forwarded to both MLPs.
    """

    def __init__(self, dim, time_varying):
        super().__init__()
        # Registration order (mu, then score) fixes the parameter and
        # state_dict order, as well as the order in which the two networks
        # are initialised.
        self.mu = MLP(dim=dim, time_varying=time_varying)
        self.score = MLP(dim=dim, time_varying=time_varying)

    def forward(self, x):
        """Return the pair ``(mu(x), score(x))``; ``mu`` is evaluated first."""
        return self.mu(x), self.score(x)


def g_t(t):
    """Return the time-dependent scale g(t) = e**t, applied elementwise to tensor ``t``."""
    return t.exp()


def trajectories(model, x_0, steps):
    """Integrate the score-corrected drift from t=0 to t=1 with forward Euler.

    At each step the model returns ``(mu_theta, score_theta)``. The update
    direction is ``(mu_theta - x_t) / (1 - t) + g_t(t)**2 / 2 * score_theta``.

    Args:
        model: callable that receives ``torch.cat([x_t, t], dim=-1)`` and
            returns a pair ``(mu_theta, score_theta)``, each shaped like ``x_t``.
        x_0: starting points. ``x_0.shape[0]`` is used as the batch size, so
            presumably ``(batch, dim)``.
        steps: number of equal Euler steps of size ``1 / steps``.

    Returns:
        Detached CPU tensor of shape ``(steps + 1, *x_0.shape)``. It holds
        the initial state followed by the state after every step.
    """
    xt = x_0
    delta_t = 1 / steps
    trajectory = [xt.detach().cpu()]
    for k in range(steps):
        # Build the time column on xt's device and dtype, so torch.cat never
        # mixes CPU/GPU tensors or float32/float64.
        t = torch.full((xt.shape[0], 1), k / steps, device=xt.device, dtype=xt.dtype)
        mu_theta, score_theta = model(torch.cat([xt, t], dim=-1))
        gt = g_t(t)
        # NOTE(review): only the deterministic drift is integrated; no
        # g(t) * sqrt(dt) noise term is added. Confirm this is the intended
        # sampler.
        v_tilde = ((mu_theta - xt) / (1 - t)) + ((gt**2 / 2) * score_theta)
        xt = xt + v_tilde * delta_t
        # Detached snapshots make this safe to call outside torch.no_grad().
        trajectory.append(xt.detach().cpu())
    return torch.stack(trajectory)

In [None]:
%%time

savedir = os.path.join(os.getcwd(), "Results/SG")
Path(savedir).mkdir(parents=True, exist_ok=True)

sigma = 0.1
dim = 2
batch_size = 256
model = MLPWithScore(dim=dim, time_varying=True)
optimizer = torch.optim.Adam(model.parameters())
FM = CFM(sigma=sigma)
criterion_v = torch.nn.GaussianNLLLoss()  # loss for the mu (endpoint) head
criterion_s = torch.nn.MSELoss()          # loss for the score head

# The base variance sigma**2 is constant, so allocate it once. A fresh tensor
# has requires_grad=False, and GaussianNLLLoss clamps a clone of its ``var``
# argument, so reusing it every step is safe.
var = torch.full((batch_size, dim), sigma**2)

start = time.time()
for k in tqdm(range(20000)):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size)

    t, xt, _ = FM.sample_location_and_conditional_flow(x0, x1)

    # xt must track gradients for the torch.autograd.grad call below.
    xt.requires_grad_(True)

    gt = g_t(t)

    mu_theta, score_theta = model(torch.cat([xt, t[:, None]], dim=-1))
    # The endpoint likelihood's variance is scaled by 1 / g(t)**2.
    loss_v = criterion_v(mu_theta, x1, var / pad_t_like_x(gt, var)**2)

    # NOTE(review): x1 does not depend on xt, so this gradient is identically
    # -1 and score_true is a constant tensor. That is unlikely to be the
    # intended conditional score; confirm the target before relying on loss_s.
    score_true = torch.autograd.grad((x1 - xt).sum(), xt, create_graph=True)[0]
    loss_s = criterion_s(score_theta, score_true / pad_t_like_x(gt, score_true))

    loss = loss_v + loss_s

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

        with torch.no_grad():
            traj = trajectories(model, sample_8gaussians(1024), steps=100)
            plot_trajectories(traj=traj)
            evaluate(traj[-1], sample_moons(1024))

torch.save(model, f"{savedir}/SG.pt")

### VFM with learned sigma

In [None]:
def trajectories(model, x_0, steps):
    """Integrate an endpoint-predicting model from t=0 to t=1 with forward Euler.

    At each step the model predicts the endpoint ``x1``. The velocity used is
    that of the straight line from the current state to that prediction,
    ``(x1 - x_t) / (1 - t)``.

    Args:
        model: callable that receives ``torch.cat([x_t, t], dim=-1)`` (the
            state with the time appended as an extra last column) and returns
            a predicted endpoint shaped like ``x_t``.
        x_0: starting points. ``x_0.shape[0]`` is used as the batch size, so
            presumably ``(batch, dim)``.
        steps: number of equal Euler steps of size ``1 / steps``.

    Returns:
        Detached CPU tensor of shape ``(steps + 1, *x_0.shape)``. It holds
        the initial state followed by the state after every step.
    """
    xt = x_0
    delta_t = 1 / steps
    trajectory = [xt.detach().cpu()]
    for k in range(steps):
        # Build the time column on xt's device and dtype, so torch.cat never
        # mixes CPU/GPU tensors or float32/float64.
        t = torch.full((xt.shape[0], 1), k / steps, device=xt.device, dtype=xt.dtype)
        x1 = model(torch.cat([xt, t], dim=-1))
        # t <= (steps - 1) / steps < 1, so the division is always finite.
        v_t = (x1 - xt) / (1 - t)
        xt = xt + v_t * delta_t
        # Detached snapshots make this safe to call outside torch.no_grad().
        trajectory.append(xt.detach().cpu())
    return torch.stack(trajectory)


class SigmaMLP(MLP):
    """MLP variant whose output is passed through a sigmoid, squashing it into (0, 1).

    Constructor arguments are forwarded unchanged to ``MLP``.
    """

    def __init__(self, dim, out_dim=None, w=64, time_varying=False):
        super().__init__(dim, out_dim=out_dim, w=w, time_varying=time_varying)
        self.last_filter = torch.nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(self.net(x))``."""
        # NOTE(review): relies on the parent MLP storing its layers in ``self.net``.
        pred = self.net(x)
        return self.last_filter(pred)


# Output directory for the learned-sigma experiment. This must live at module
# level, not inside the class body, so that ``savedir`` is the global the
# training cell reads when it saves plots and checkpoints.
savedir = os.path.join(os.getcwd(), "Results/SVFM")
Path(savedir).mkdir(parents=True, exist_ok=True)

In [6]:
%%time

# Define the output directory in this cell, as the other training cells do,
# so plots and checkpoints land in Results/SVFM regardless of which cell
# last assigned ``savedir``.
savedir = os.path.join(os.getcwd(), "Results/SVFM")
Path(savedir).mkdir(parents=True, exist_ok=True)

dim = 2
batch_size = 256
noise = 0.2  # passed as ``noise=`` to sample_moons for the target batches

model = MLP(dim=dim, time_varying=True)
# Learned scalar standard deviation; its square is the likelihood variance.
sigma = torch.nn.Parameter(torch.rand(1))
optimizer = torch.optim.Adam([*model.parameters(), sigma])
FM = CFM()
criterion = torch.nn.GaussianNLLLoss()


start = time.time()
for k in tqdm(range(20000)):
    optimizer.zero_grad()

    x0 = sample_8gaussians(batch_size)
    x1 = sample_moons(batch_size, noise=noise)

    t, xt, _ = FM.sample_location_and_conditional_flow(x0, x1)

    mu_theta = model(torch.cat([xt, t[:, None]], dim=-1))

    # Rebuilt every step on purpose: it depends on the trainable ``sigma``
    # and must be part of this iteration's autograd graph.
    var = torch.ones(batch_size, dim) * (sigma**2)

    loss = criterion(mu_theta, x1, var)

    loss.backward()
    optimizer.step()

    if (k + 1) % 5000 == 0:
        end = time.time()
        print(f"{k+1}: loss {loss.item():0.3f} time {(end - start):0.2f}")
        start = end

        with torch.no_grad():
            traj = trajectories(model, sample_8gaussians(1024), steps=100)
            plot_trajectories(traj=traj.cpu().numpy(), output=f"{savedir}/SVFM_{k+1}.png")
            evaluate(traj[-1].cpu(), sample_moons(1024))
            print(sigma)

torch.save(model, f"{savedir}/SVFM.pt")
torch.save(sigma, f"{savedir}/sigma.pt")

5000: loss 504.400 time 8.86


NameError: name 'torch_wrapper' is not defined