# Conditional Flow Matching with Optimal Transport

In [1]:
import os

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd

from sklearn.datasets import make_moons

import torch
from torch import Tensor

from tqdm import tqdm
from typing import List
from zuko.utils import odeint
from simple_flow_matching.flow_matching import CNF, FlowMatchingLoss

## Training Data

In [3]:
# Draw the two-moons point cloud (class labels are not needed) and keep it
# as a float32 tensor for training.
points, _labels = make_moons(n_samples=4096, noise=0.05)
data = torch.from_numpy(points).float()

# Quick visual check of the training set.
moons_trace = go.Scatter(x=data[:, 0], y=data[:, 1], mode="markers")
fig = go.Figure()
fig.add_trace(moons_trace)
fig.update_layout(
    template="simple_white",
    title="Training Data",
    height=400,
    width=400,
)
# Each update_* call returns the figure, so the chain both labels the axes
# and leaves the figure as the cell's displayed value.
fig.update_xaxes(title="x").update_yaxes(title="y")

## Model Definition

In [4]:
# Hyperparameters of the vector-field network: the `hidden_features` list
# passed to CNF repeats `hidden_units` once per layer.
hidden_units = 256
n_layers = 3
hidden_features = [hidden_units] * n_layers

flow = CNF(data_dim=2, hidden_features=hidden_features)

# Loss object wrapping the flow, and the optimiser over the flow's parameters.
loss = FlowMatchingLoss(flow)
optimizer = torch.optim.AdamW(params=flow.parameters(), lr=1e-3)


## Training

In [35]:
batch_size = 256
max_steps = 2*4096
log_every = 500  # report the training loss every `log_every` optimisation steps

for step in tqdm(range(max_steps), ncols=88):

    # Draw a random mini-batch; indices are sampled with replacement.
    subset = torch.randint(0, len(data), (batch_size,))
    x = data[subset]

    optimizer.zero_grad()
    # Evaluate the loss exactly once and reuse it for both backprop and
    # logging. Calling `loss(x)` a second time just to print it costs an extra
    # forward pass (with a fresh autograd graph) and, if the loss draws random
    # samples internally (presumably it does -- confirm in FlowMatchingLoss),
    # would report a different value from the one that was optimised.
    batch_loss = loss(x)
    batch_loss.backward()
    optimizer.step()

    if step % log_every == 0:
        # tqdm.write prints above the progress bar instead of breaking it up.
        tqdm.write(str(batch_loss.item()))

  1%|▏                                               | 41/8192 [00:00<00:41, 197.97it/s]

0.9007209539413452


  6%|███                                            | 523/8192 [00:02<00:38, 199.83it/s]

0.8817497491836548


 13%|█████▊                                        | 1026/8192 [00:05<00:35, 200.15it/s]

0.9140656590461731


 19%|████████▌                                     | 1530/8192 [00:07<00:33, 200.11it/s]

0.9733107686042786


 25%|███████████▍                                  | 2034/8192 [00:10<00:30, 200.27it/s]

0.966995358467102


 31%|██████████████▏                               | 2537/8192 [00:12<00:28, 200.08it/s]

0.9570615291595459


 37%|█████████████████                             | 3041/8192 [00:15<00:25, 199.70it/s]

0.9377673864364624


 43%|███████████████████▊                          | 3523/8192 [00:17<00:23, 200.05it/s]

0.9452264308929443


 49%|██████████████████████▌                       | 4027/8192 [00:20<00:20, 200.12it/s]

1.0258429050445557


 55%|█████████████████████████▍                    | 4531/8192 [00:22<00:18, 199.96it/s]

0.9078284502029419


 61%|████████████████████████████▎                 | 5034/8192 [00:25<00:15, 200.08it/s]

1.0838289260864258


 68%|███████████████████████████████               | 5538/8192 [00:27<00:13, 199.95it/s]

0.9300416111946106


 74%|█████████████████████████████████▉            | 6041/8192 [00:30<00:10, 200.20it/s]

0.8661118149757385


 80%|████████████████████████████████████▋         | 6524/8192 [00:32<00:08, 200.21it/s]

0.9417881369590759


 86%|███████████████████████████████████████▍      | 7028/8192 [00:35<00:05, 200.02it/s]

0.9195681214332581


 92%|██████████████████████████████████████████▎   | 7531/8192 [00:37<00:03, 199.31it/s]

0.9049481153488159


 98%|█████████████████████████████████████████████ | 8034/8192 [00:40<00:00, 200.01it/s]

0.9707792401313782


100%|██████████████████████████████████████████████| 8192/8192 [00:40<00:00, 200.35it/s]


## Sampling and Visualizing Probability Paths

In [36]:
# Generate trajectories: push Gaussian samples through the learned flow from
# t=0 to t=1, recording a snapshot of every point at each time step.

start_pts = torch.randn(3000, 2)
pts = start_pts

time_step = 0.1
# Build the time grid with linspace so that it ends exactly at 1.0 and is not
# subject to the floating-point pitfalls of `np.arange` with a float step.
n_steps = int(round(1.0 / time_step))
time_grid = np.linspace(0.0, 1.0, n_steps + 1)


def _snapshot(points: Tensor, time: float) -> pd.DataFrame:
    """Return a frame with columns x, y, Time, Idx for `points` at `time`.

    `Idx` is the row position of each point, so the same sample can be
    followed across snapshots; `Time` is rounded to two decimals to give
    clean animation-frame labels.
    """
    frame = pd.DataFrame(points.detach().numpy(), columns=["x", "y"])
    frame["Time"] = np.around(time, 2)
    frame["Idx"] = frame.index.values
    return frame


# Record the initial state at t=0, then label each integrated state with the
# END of its interval. Labelling with the interval start would shift every
# snapshot by one step: the t=0 samples would never appear and the final
# (t=1) samples would be tagged 0.9.
df_list = [_snapshot(pts, time_grid[0])]

# Only samples are needed here (they are detached before use), so skip
# autograd bookkeeping during integration.
with torch.no_grad():
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        pts = odeint(f=flow, x=pts, t0=float(t0), t1=float(t1), phi=flow.parameters())
        df_list.append(_snapshot(pts, t1))

full_df = pd.concat(df_list)

In [37]:
# Animated 2-D histogram of the recorded samples, one frame per "Time" value.
# The axis ranges are fixed to the bounding box of the initial samples.
start_xy = start_pts.numpy()
final_xy = pts.detach().numpy()

x_bounds = [start_xy[:, 0].min(), start_xy[:, 0].max()]
y_bounds = [start_xy[:, 1].min(), start_xy[:, 1].max()]

fig = px.density_heatmap(
    full_df,
    x="x",
    y="y",
    range_x=x_bounds,
    range_y=y_bounds,
    nbinsx=20,
    nbinsy=20,
    histnorm='probability',
    range_color=(0, 0.03),
    animation_frame="Time",
)

# Overlay the first 800 integrated samples (`pts`) as small white markers.
samples_trace = go.Scatter(
    x=final_xy[:800, 0],
    y=final_xy[:800, 1],
    mode="markers",
    marker_color="white",
    marker_size=1,
)
fig.add_trace(samples_trace)

# Use the darkest Plasma colour as background so empty bins blend in.
fig.update_layout(
    title="Probability Density Path",
    plot_bgcolor=px.colors.sequential.Plasma[0],
    template="simple_white",
    width=400,
    height=400,
)

In [38]:
fig = go.Figure()
plasma = px.colors.sequential.Plasma


def _endpoint_trace(points, color):
    """Scatter trace of the first 800 rows of `points`, drawn in `color`."""
    return go.Scatter(
        x=points[:800, 0],
        y=points[:800, 1],
        showlegend=False,
        mode="markers",
        marker_color=color,
        marker_size=3,
    )


# One grey polyline per tracked sample (every 100th index), with markers
# coloured by the recorded time of each snapshot.
for idx in range(0, full_df.Idx.max(), 100):
    path = full_df.loc[full_df["Idx"] == idx]
    path_trace = go.Scatter(
        x=path["x"],
        y=path["y"],
        showlegend=False,
        line_color="grey",
        mode="lines+markers",
        marker_color=path["Time"],
    )
    fig.add_trace(path_trace)

# Integrated samples in the last Plasma colour, initial samples in the first.
fig.add_trace(_endpoint_trace(pts.detach().numpy(), plasma[-1]))
fig.add_trace(_endpoint_trace(start_pts.numpy(), plasma[0]))

fig.update_layout(
    title="Point Trajectories",
    template="simple_white",
    width=600,
    height=600,
)
