Adapted from flow_matching/examples/standalone_flow_matching.ipynb from the Facebook Flow Matching github repository (https://github.com/facebookresearch/flow_matching/blob/main/examples/standalone_flow_matching.ipynb).

Attempts to turn a standard Gaussian distribution into a distribution with two sharp peaks at -1 and +1. The target distribution is not yet conditioned on anything (equivalent to a heads-or-tails game where the bet is always €1).



In [None]:
import torch
from torch import nn, Tensor
import numpy as np
import matplotlib.pyplot as plt


# Choose device

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
print('cuda available' if use_cuda else 'cuda not available')
# Optional experiment tracking (disabled):
#   mlflow.log_param('device', torch.cuda.get_device_name(device) if use_cuda else 'cpu')

# Define model

In [None]:
class Flow(nn.Module):
    """MLP velocity field with a midpoint-rule integration step.

    The network receives the time ``t`` and the state ``x_t`` concatenated
    along the last axis (``dim + 1`` input features) and returns a tensor
    with ``dim`` features.
    """

    def __init__(self, dim: int = 1, h: int = 64):
        """Build the network.

        Args:
            dim: Number of features of the state ``x_t``.
            h: Width of each of the three hidden layers.
        """
        super().__init__()
        # The "+ 1" input feature carries the time t.
        layers = [
            nn.Linear(dim + 1, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, h),
            nn.ELU(),
            nn.Linear(h, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, t: Tensor, x_t: Tensor) -> Tensor:
        """Evaluate the network on ``t`` and ``x_t`` joined along the last axis."""
        net_input = torch.cat([t, x_t], dim=-1)
        return self.net(net_input)

    def step(self, x_t: Tensor, t_start: Tensor, t_end: Tensor) -> Tensor:
        """Advance ``x_t`` from ``t_start`` to ``t_end`` with one midpoint step.

        Args:
            x_t: Current states; the first axis is the batch axis.
            t_start: Single-element tensor holding the start time. It is
                reshaped to ``(1, 1)`` and expanded to one row per sample.
            t_end: End time; must broadcast against a ``(batch, 1)`` tensor.

        Returns:
            The states after the step, same shape as ``x_t``.
        """
        batch_size = x_t.shape[0]
        t_start = t_start.view(1, 1).expand(batch_size, 1)
        dt = t_end - t_start

        # Half step using the slope at the start of the interval ...
        slope_start = self(x_t=x_t, t=t_start)
        x_mid = x_t + slope_start * dt / 2
        t_mid = t_start + dt / 2

        # ... then a full step using the slope evaluated at that midpoint.
        return x_t + dt * self(t=t_mid, x_t=x_mid)

In [None]:
# Instantiate the model on the selected device and report its size.
flow = Flow().to(device=device)
print(flow)

# Count every scalar entry across all parameter tensors.
n_params = 0
for weights in flow.parameters():
    n_params += weights.numel()
print('Total nr of parameters:', n_params)

# Training

In [None]:
# Start from a freshly initialised model, so re-running this cell retrains
# from scratch.
flow = Flow().to(device=device)

optimizer = torch.optim.Adam(flow.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for _ in range(10000):
    # Target samples: -1 or +1 with equal probability, shape (256, 1).
    x_1 = 2 * torch.randint(0, 2, (256, 1), dtype=torch.float32, device=device) - 1
    # Source samples: standard-normal noise with the same shape.
    x_0 = torch.randn_like(x_1)
    # One uniformly drawn time in [0, 1) per sample.
    t = torch.rand(x_1.shape[0], 1, device=device)

    # Point on the straight line from x_0 to x_1 at time t; the regression
    # target is that line's constant time derivative x_1 - x_0.
    x_t = (1 - t) * x_0 + t * x_1
    dx_t = x_1 - x_0

    optimizer.zero_grad()
    prediction = flow(t=t, x_t=x_t)
    loss = loss_fn(prediction, dx_t)
    loss.backward()
    optimizer.step()

# Sampling

In [None]:
# Draw samples by integrating the learned flow from t = 0 to t = 1, and show
# a histogram of the states every `plot_interval` steps.
x = torch.randn(10000, 1, device=device)
n_steps = 32
plot_interval = 8
fig, axes = plt.subplots(1, n_steps // plot_interval + 1, figsize=(15, 3), sharex=True, sharey=True)

time_steps = torch.linspace(0, 1.0, n_steps + 1, device=device)

bins = np.linspace(-4, 4, 50)

# First panel: the initial standard-normal samples.
axes[0].set_title(f't = {time_steps[0]:.2f}')
axes[0].hist(x[:, 0].cpu().numpy(), bins=bins)

for ax in axes:
    ax.set_yscale('log')

fig.suptitle(f'Histogram of the samples, {n_steps} steps')

# Sampling needs no gradients. Without no_grad() autograd would record a
# graph through every network evaluation of every step for all samples.
with torch.no_grad():
    for i in range(n_steps):
        x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1])

        if (i + 1) % plot_interval == 0:
            # Panel 0 holds the initial samples, so the k-th snapshot goes
            # into panel k.
            panel = axes[(i + 1) // plot_interval]
            panel.set_title(f't = {time_steps[i + 1]:.2f}')
            panel.hist(x[:, 0].cpu().numpy(), bins=bins)

plt.tight_layout()
plt.show()

In [None]:
# Normalised histogram of the current samples `x` on a narrower range.
plt.figure(figsize=(4, 3), dpi=200)
hist_ax = plt.gca()
bins = np.linspace(-2, 2, 50)
final_samples = x.cpu().detach()[:, 0]
hist_ax.hist(final_samples, bins=bins, density=True)
# The y-axis is linear here; use hist_ax.set_yscale('log') for a log view.
hist_ax.set_xlabel('x')
hist_ax.set_ylabel('Density')
plt.tight_layout()
plt.show()



In [None]:
plt.figure(figsize=(8, 3))

# Contour plot of the learned velocity field u(t, x) on a 100 x 100 grid
# covering t in [0, 1] and x in [-3, 3].
x_grid, t_grid = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(0, 1, 100))
# The network inputs must live on the same device as the model, and be
# float32 to match its parameters (the NumPy grids are float64).
x2 = torch.tensor(x_grid, dtype=torch.float32, device=device).reshape(-1, 1)
t2 = torch.tensor(t_grid, dtype=torch.float32, device=device).reshape(-1, 1)
with torch.no_grad():
    # Move the result back to the CPU before converting to NumPy.
    u = flow(t=t2, x_t=x2).cpu().numpy().reshape(100, 100)
print(u.min(), u.max())
cp = plt.contourf(t_grid, x_grid, u, cmap='coolwarm', levels=np.linspace(-15, 15, 13),
                  extend='both')
plt.xlabel('$t$')

# Add dots for the two target points at t = 1.
plt.scatter([1.0, 1.0], [-1, 1], color='black', s=100)

# Add a number of trajectories, integrated with a coarse step count.
n_steps = 8
N = 20
x = torch.randn(N, device=device).view(-1, 1)
plt.scatter([0.0] * N, x[:, 0].cpu().numpy(), color='black', s=5)
time_steps = torch.linspace(0, 1.0, n_steps + 1, device=device)

# Row i holds the positions of all N trajectories at time_steps[i].
x_arr = np.empty((n_steps + 1, N))
x_arr[0] = x[:, 0].cpu().numpy()
with torch.no_grad():
    for i in range(n_steps):
        x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1])
        x_arr[i + 1] = x[:, 0].cpu().numpy()
plt.plot(time_steps.cpu().numpy(), x_arr, color='black', alpha=0.3)

plt.ylabel('$x_t$')
plt.colorbar(cp, label='$u_t(x_t)$')

# Title the current figure. The name `fig` still refers to the histogram
# figure created in the sampling cell, so `fig.suptitle` would retitle that
# earlier figure instead of this one.
plt.suptitle(f'Contour plot of the flow $u_t(x)$, {n_steps} steps')

# Wasserstein distance

In [None]:
from scipy.stats import wasserstein_distance

In [None]:
# Generate fresh samples with the trained flow and compare them with the
# two-point target distribution via the 1-D Wasserstein distance.
n_steps = 32
print('n_steps:', n_steps)

# Create the time grid and the initial noise on the model's device; CPU
# tensors would raise a device-mismatch error when the model is on CUDA.
time_steps = torch.linspace(0, 1.0, n_steps + 1, device=device)

n_samples = 10000
x = torch.randn(n_samples, 1, device=device)
# Pure inference: skip autograd bookkeeping.
with torch.no_grad():
    for i in range(n_steps):
        x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i + 1])

print(x.shape)

# Move to the CPU before converting to NumPy (required for CUDA tensors).
samples = x[:, 0].cpu().numpy()
# Target distribution: equal weight on -1 and +1.
real = np.array([-1, 1])

print('samples.shape:', samples.shape)
print('real.shape:', real.shape)

print('Wasserstein distance:', wasserstein_distance(samples, real))