In [1]:
from matplotlib import pyplot as plt
import torch
import torch.nn as nn

import numpy as np

In [2]:
def generate_star(n_spikes=5, inner_radius=0.4, outer_radius=1.0, n_samples=1000, center=(0, 0)):
    points = []
    angle_step = np.pi / n_spikes

    # Generate the star's vertices
    vertices = []
    for i in range(2 * n_spikes):
        angle = i * angle_step
        radius = outer_radius if i % 2 == 0 else inner_radius

        x = radius * np.cos(angle) + center[0]
        y = radius * np.sin(angle) + center[1]
        vertices.append([x, y])
    vertices.append(vertices[0])

    # Sample points along the star's edges
    vertices = np.array(vertices)
    sampled_points = []

    for i in range(len(vertices) - 1):
        start_point = vertices[i]
        end_point = vertices[i + 1]

        # Interpolate points along the edge
        for t in np.linspace(0, 1, n_samples // (len(vertices) - 1)):
            point = (1 - t) * start_point + t * end_point
            sampled_points.append(point)

    return np.array(sampled_points)

In [3]:
class Flow(nn.Module):
  def __init__(self, dim: int = 2, hid: int = 64):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(dim+1, hid), nn.ELU(),
        nn.Linear(hid, hid), nn.ELU(),
        nn.Linear(hid, hid), nn.ELU(),
        nn.Linear(hid, dim)
    )

  def forward(self, t: torch.Tensor, x_t: torch.Tensor) -> torch.Tensor:
    return self.net(torch.cat([x_t, t], -1))

  def step(self, x_t: torch.Tensor, t_start: torch.Tensor, t_end: torch.Tensor):
    t_start = t_start.view(-1, 1).expand(x_t.shape[0], 1)
    delta_t = (t_end - t_start)
    t_mid = t_start + delta_t / 2

    v = self(t_mid, x_t)

    return x_t + delta_t * v

In [4]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
flow = Flow(2, 64).to(device)

optimizer = torch.optim.Adam(flow.parameters(), 1e-2)
loss_fn = nn.MSELoss()

for _ in range(2500):
  x_1 = torch.Tensor(generate_star(n_samples=1024)).to(device)
  x_0 = torch.randn_like(x_1)

  t = torch.rand(len(x_1), 1, device=device)
  x_t = (1 - t) * x_0 + t * x_1
  dx_t = x_1 - x_0

  pred = flow(x_t=x_t, t=t)
  optimizer.zero_grad()
  loss = loss_fn(pred, dx_t)
  loss.backward()
  optimizer.step()

  if _ % 100 == 0:
    print(f"Epoch: {_} Train Loss: {loss.item() / 256}")

Epoch: 0 Train Loss: 0.004799309652298689
Epoch: 100 Train Loss: 0.0032291733659803867
Epoch: 200 Train Loss: 0.0028470417018979788
Epoch: 300 Train Loss: 0.003056406741961837
Epoch: 400 Train Loss: 0.002728142077103257
Epoch: 500 Train Loss: 0.002993255387991667
Epoch: 600 Train Loss: 0.0029228758066892624
Epoch: 700 Train Loss: 0.0029074493795633316
Epoch: 800 Train Loss: 0.002721037482842803
Epoch: 900 Train Loss: 0.0028466179501265287
Epoch: 1000 Train Loss: 0.0029581687413156033
Epoch: 1100 Train Loss: 0.0031000052113085985
Epoch: 1200 Train Loss: 0.002721753902733326
Epoch: 1300 Train Loss: 0.0027628412935882807
Epoch: 1400 Train Loss: 0.0028022341430187225
Epoch: 1500 Train Loss: 0.002782669384032488
Epoch: 1600 Train Loss: 0.0027853669598698616
Epoch: 1700 Train Loss: 0.0027391063049435616
Epoch: 1800 Train Loss: 0.0027404131833463907
Epoch: 1900 Train Loss: 0.0026587527245283127
Epoch: 2000 Train Loss: 0.0027634501457214355
Epoch: 2100 Train Loss: 0.0028665941208601
Epoch: 220

In [7]:
from matplotlib import pyplot as plt

In [8]:
x = torch.randn(1500, 2, device=device)
n_steps = 10
time_steps = torch.linspace(0, 1., n_steps + 1, device=device)

ans = []
for i in range(n_steps):
  x = flow.step(x_t=x, t_start=time_steps[i], t_end=time_steps[i+1])
  ans.append(x)

In [21]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

def animate_2d_samples(sampled_steps, interval=200, save_path="flow_matching.gif"):
    fig, ax = plt.subplots()
    scat = ax.scatter([], [], s=10)

    def init():
        ax.set_xlim(-5, 5)
        ax.set_ylim(-5, 5)
        return scat,

    def update(frame):
        data = sampled_steps[frame].detach().cpu()  # [N, 2]
        scat.set_offsets(data.cpu().numpy())
        timesteps = frame * 0.1
        ax.set_title(f"Timestep {timesteps:.1f}")
        return scat,

    ani = animation.FuncAnimation(fig, update, frames=len(sampled_steps), init_func=init, blit=True, interval=interval)
    ani.save(save_path, writer='pillow')
    plt.close()

In [22]:
animate_2d_samples(ans, interval=500)