# Conditional Flow Matching with Optimal Transport

In [1]:
%load_ext autoreload
import os

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd

from sklearn.datasets import make_moons

import torch
from tqdm import tqdm
from typing import List
from zuko.utils import odeint
from simple_flow_matching.flow_matching import CNF, FlowMatchingLoss

## Training Data

In [2]:
# Draw a noisy two-moons point cloud and keep it as a float32 tensor for training.
moons_xy, _ = make_moons(n_samples=4096, noise=0.05)
data = torch.from_numpy(moons_xy).float()

# Scatter plot of the raw training samples.
fig = go.Figure(go.Scatter(x=data[:, 0], y=data[:, 1], mode="markers"))
fig.update_layout(
    template="simple_white",
    title="Training Data",
    height=400,
    width=400,
    xaxis_title="x",
    yaxis_title="y",
)

## Model Definition

In [23]:
%autoreload 2
# Vector field architecture hyperparameters:
# `n_layers` hidden layers, each `hidden_units` wide.
hidden_units = 256
n_layers = 3
hidden_features = [hidden_units for _ in range(n_layers)]
flow = CNF(data_dim=2, hidden_features=hidden_features)

# Training objective and optimiser (AdamW over the flow's parameters).
sigma_min = 1e-2
learning_rate = 1e-4
loss = FlowMatchingLoss(flow, sigma_min=sigma_min)
optimizer = torch.optim.AdamW(flow.parameters(), lr=learning_rate)


## Training

In [35]:
batch_size = 256
max_steps = 4096
log_every = 500  # report the loss every `log_every` optimisation steps

for step in tqdm(range(max_steps), ncols=88):

    # Sample a random mini-batch (with replacement) from the training set.
    subset = torch.randint(0, len(data), (batch_size,))
    x = data[subset]

    optimizer.zero_grad()
    # Evaluate the objective exactly once per step: the same value is used for
    # back-propagation and for logging, so the reported number is the loss that
    # actually produced this step's gradients and no extra forward pass is run.
    batch_loss = loss(x)
    batch_loss.backward()
    optimizer.step()

    if step % log_every == 0:
        # tqdm.write prints without breaking up the progress bar.
        tqdm.write(str(batch_loss.item()))

  2%|▋                                               | 62/4096 [00:00<00:13, 307.43it/s]

0.8789365291595459


 14%|██████▍                                        | 558/4096 [00:01<00:11, 308.77it/s]

0.874943733215332


 26%|███████████▊                                  | 1054/4096 [00:03<00:09, 308.51it/s]

0.9205570220947266


 38%|█████████████████▍                            | 1550/4096 [00:05<00:08, 308.55it/s]

1.0945568084716797


 50%|██████████████████████▉                       | 2047/4096 [00:06<00:06, 308.35it/s]

1.0253463983535767


 62%|████████████████████████████▌                 | 2544/4096 [00:08<00:05, 299.69it/s]

0.9859099984169006


 74%|██████████████████████████████████▏           | 3041/4096 [00:09<00:03, 308.62it/s]

0.9831485748291016


 86%|███████████████████████████████████████▋      | 3538/4096 [00:11<00:01, 308.60it/s]

0.9511550664901733


 99%|█████████████████████████████████████████████▎| 4035/4096 [00:13<00:00, 308.71it/s]

0.9964752197265625


100%|██████████████████████████████████████████████| 4096/4096 [00:13<00:00, 308.44it/s]


## Sampling and Visualizing Probability Paths

In [36]:
# Generate trajectories: push Gaussian samples through the learned flow from
# t=0 to t=1 and record the particle positions at every intermediate time.

start_pts = torch.randn(3000, 2)
pts = start_pts

time_step = 0.1
n_steps = int(round(1.0 / time_step))
# linspace gives exactly n_steps + 1 grid points ending at 1.0; np.arange with a
# fractional step can gain or lose an element through floating-point rounding.
time_grid = np.linspace(0.0, 1.0, n_steps + 1)


def _snapshot(points, time):
    """Return a DataFrame (x, y, Time, Idx) of `points` labelled with `time`."""
    frame = pd.DataFrame(points.detach().numpy(), columns=["x", "y"])
    frame["Time"] = np.around(time, 2)
    frame["Idx"] = frame.index.values  # per-particle id, stable across snapshots
    return frame


# Record the initial samples so the Time=0.0 frame really is the t=0 state.
df_list = [_snapshot(pts, time_grid[0])]

# Sampling only: no gradients are needed through the ODE solve.
with torch.no_grad():
    for t_begin, t_end in zip(time_grid[:-1], time_grid[1:]):
        pts = odeint(
            f=flow, x=pts, t0=float(t_begin), t1=float(t_end), phi=flow.parameters()
        )
        # After integrating over [t_begin, t_end] the particles are at t_end,
        # so that is the time this snapshot is labelled with.
        df_list.append(_snapshot(pts, t_end))

full_df = pd.concat(df_list, ignore_index=True)

In [37]:
# Animated 2-D histogram of the particle positions, one frame per recorded time.
start_xy = start_pts.numpy()
final_xy = pts.detach().numpy()

# Axis limits are fixed to the bounding box of the initial samples.
x_limits = [start_xy[:, 0].min(), start_xy[:, 0].max()]
y_limits = [start_xy[:, 1].min(), start_xy[:, 1].max()]

fig = px.density_heatmap(
    full_df,
    x="x",
    y="y",
    animation_frame="Time",
    range_x=x_limits,
    range_y=y_limits,
    nbinsx=20,
    nbinsy=20,
    histnorm='probability',
    range_color=(0, 0.03),
)

# Overlay the first 800 points of the final integrated sample as white dots.
final_scatter = go.Scatter(
    x=final_xy[:800, 0],
    y=final_xy[:800, 1],
    mode="markers",
    marker_color="white",
    marker_size=1,
)
fig.add_trace(final_scatter)

fig.update_layout(
    title="Probability Density Path",
    plot_bgcolor=px.colors.sequential.Plasma[0],
    template="simple_white",
    width=400,
    height=400,
)

In [48]:
fig = go.Figure()

# For every 100th particle, draw a grey segment joining its first and last
# recorded positions.
for idx in range(0, full_df.Idx.max(), 100):
    track = full_df[full_df.Idx == idx].sort_values("Time")
    endpoints = track.iloc[[0, -1]]
    fig.add_trace(go.Scatter(x=endpoints["x"].values,
                             y=endpoints["y"].values,
                             showlegend=False,
                             line_color="grey",
                             mode="lines+markers",
                             # One colour value per plotted point: the times of the
                             # two endpoints, not the whole Time column.
                             marker_color=endpoints["Time"].values))

# First 800 points of the final integrated sample.
end_xy = pts.detach().numpy()
fig.add_trace(
    go.Scatter(
        x=end_xy[:800, 0], y=end_xy[:800, 1], showlegend=False,
        mode="markers", marker_color=px.colors.sequential.Plasma[-1], marker_size=3
    )
)

# First 800 of the initial samples.
begin_xy = start_pts.numpy()
fig.add_trace(
    go.Scatter(
        x=begin_xy[:800, 0], y=begin_xy[:800, 1], showlegend=False,
        mode="markers", marker_color=px.colors.sequential.Plasma[0], marker_size=3
    )
)

fig.update_layout(
    title="Point Trajectories",
    template="simple_white", width=600, height=600,
)
