In [2]:
import os
import sys

# Directory added to the import path (presumably the local checkout that
# provides the `flow_matching` module imported further down — confirm).
# Set the SIMPLE_FLOW_MATCHING_DIR environment variable to point somewhere
# else; the original hard-coded location is kept as the fallback.
SCRIPT_DIR = os.environ.get(
    "SIMPLE_FLOW_MATCHING_DIR", "/home/abhutani/simple-flow-matching"
)
# Guard the append so re-running this cell does not pile up duplicate entries.
if SCRIPT_DIR not in sys.path:
    sys.path.append(SCRIPT_DIR)

import plotly.graph_objects as go
import plotly.express as px
import numpy as np
import pandas as pd

from sklearn.datasets import make_moons

import torch
from torch import Tensor

from tqdm import tqdm
from typing import List
from zuko.utils import odeint
from flow_matching import CNF, FlowMatchingLoss

## Training Data

In [4]:
# Draw the two-moons toy dataset and plot it.
# The seed is fixed so that Restart & Run All regenerates the same point cloud.
DATA_SEED = 0
data, _ = make_moons(4096, noise=0.05, random_state=DATA_SEED)
data = torch.from_numpy(data).float()  # float32 tensor used by the training cell

# Hand plotly plain numpy arrays, as the later plotting cells do.
data_xy = data.numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=data_xy[:, 0], y=data_xy[:, 1], mode="markers"))
fig.update_layout(template="simple_white", title="Training Data", height=400, width=400)
fig.update_xaxes(title="x")
fig.update_yaxes(title="y")

## Model Definition

In [5]:
# Architecture hyperparameters
hidden_units = 256
n_layers = 3

# Seed torch before the network is built so that weight initialisation, and the
# torch random draws made by later cells in a top-to-bottom run, are repeatable.
torch.manual_seed(0)

# Vector-field network over 2-D points with `n_layers` hidden layers of
# `hidden_units` units each.
flow = CNF(data_dim=2, hidden_features=[hidden_units]*n_layers)

# Training objective and optimizer
loss = FlowMatchingLoss(flow)
optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)


## Training

In [6]:
batch_size = 256
max_epochs = 4096

# Each iteration performs one optimizer update on a freshly drawn mini-batch;
# indices are sampled uniformly, so a row may appear more than once per batch.
for step in tqdm(range(max_epochs), ncols=88):
    optimizer.zero_grad()

    # pick random row indices and gather the matching points
    batch_idx = torch.randint(len(data), (batch_size,))
    batch = data[batch_idx]

    batch_loss = loss(batch)
    batch_loss.backward()
    optimizer.step()

100%|██████████████████████████████████████████████| 4096/4096 [00:14<00:00, 282.43it/s]


## Sampling and Visualizing Probability Paths

In [None]:
# Generate trajectories
#
# Push a cloud of standard-normal points through the learned ODE in
# `time_step`-sized chunks over [0, 1] and record the cloud after every chunk.
# NOTE(review): assumes integrating `flow` from t=0 to t=1 carries noise towards
# the data distribution — confirm against flow_matching.CNF.

start_pts = torch.randn(3000, 2)
pts = start_pts

time_step = 0.1
n_steps = int(round(1.0 / time_step))


def _snapshot(points, time):
    """Return a frame with columns x, y, Time, Idx for one point cloud.

    `Time` is the integration time the cloud corresponds to (rounded to two
    decimals); `Idx` is the row position of each point, so one particle can
    be followed across snapshots.
    """
    frame = pd.DataFrame(points.detach().numpy(), columns=["x", "y"])
    frame["Time"] = np.around(time, 2)
    # `.values` is a property; calling it raised TypeError before.
    frame["Idx"] = frame.index.values
    return frame


# Include the initial cloud so the animation has a true t=0 frame.
df_list = [_snapshot(start_pts, 0.0)]
for step in range(n_steps):
    start = step * time_step
    stop = (step + 1) * time_step
    pts = odeint(f=flow, x=pts, t0=start, t1=stop, phi=flow.parameters())
    # Label the snapshot with the time the points have reached, not the
    # time the integration chunk started from.
    df_list.append(_snapshot(pts, stop))

# Columns are already named by _snapshot; the old three-name reassignment
# did not match the four columns and raised ValueError.
full_df = pd.concat(df_list, ignore_index=True)

: 

: 

: 

In [35]:
# Animated 2-D histogram of the recorded point clouds, one frame per "Time"
# value, with the axes fixed to the bounding box of the initial samples.
noise_xy = start_pts.numpy()
final_xy = pts.detach().numpy()
x_lo, y_lo = noise_xy.min(axis=0)
x_hi, y_hi = noise_xy.max(axis=0)

fig = px.density_heatmap(
    full_df,
    x="x",
    y="y",
    animation_frame="Time",
    histnorm='probability',
    nbinsx=20,
    nbinsy=20,
    range_x=[x_lo, x_hi],
    range_y=[y_lo, y_hi],
    range_color=(0, 0.03),
)

# Overlay the first 800 points currently held in `pts` as small white markers.
overlay = go.Scatter(
    x=final_xy[:800, 0],
    y=final_xy[:800, 1],
    mode="markers",
    marker=dict(color="white", size=1),
)
fig.add_trace(overlay)

# Fill the plot background with the darkest Plasma colour.
fig.update_layout(
    title="Probability Density Path",
    plot_bgcolor=px.colors.sequential.Plasma[0],
    template="simple_white",
    width=400,
    height=400,
)

In [36]:
# Draw one segment per particle from its initial position in `start_pts` to
# its current position in `pts`, for the first (up to) 100 particles.
n_lines = min(100, len(start_pts))

# Convert once instead of on every loop iteration.
start_xy = start_pts.numpy()
end_xy = pts.detach().numpy()

fig = go.Figure()
for i in range(n_lines):
    fig.add_trace(
        go.Scatter(
            x=[start_xy[i, 0], end_xy[i, 0]],
            y=[start_xy[i, 1], end_xy[i, 1]],
            showlegend=False,  # one legend entry per segment is just noise
        )
    )
fig.update_layout(
    template="simple_white", title="Start-to-End Displacement", height=400, width=400
)
fig.update_xaxes(title="x")
fig.update_yaxes(title="y")
# Last expression of the cell, so the figure is actually rendered.
fig

In [37]:
# NOTE(review): called with no data or arguments, so this yields an empty
# scatter figure — looks like a leftover scratch cell; consider removing it.
px.scatter()