# Flow Matching Experiments
This notebook explores the synthetic dataset and compares conditional mean models (MLP/Transformer) with a conditional flow matching model.

## 1. Setup


In [None]:
import math
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import cfm
from cfm import SyntheticDatasetConfig, SyntheticControlDataset, train_val_test_split
from cfm.evaluation import make_dataloader
from cfm.training import MeanModelTrainer, MeanModelConfig
from cfm.flow_matching import FlowMatchingTrainer, FlowMatchingConfig
from cfm.utils import set_seed


Set a seed for reproducibility.


In [None]:
# Global seed for the notebook so Restart & Run All reproduces the same training
# runs and samples (assumes cfm.utils.set_seed seeds numpy and torch — confirm).
SEED = 42
set_seed(SEED)


## 2. Generate a synthetic dataset


In [None]:
# Generator parameters for the synthetic control problem.
dataset_params = dict(
    n_samples=512,
    seq_len=32,
    control_dim=3,
    static_dim=2,
    noise_std=0.2,
    nonlinear_strength=0.4,
    regime_change_prob=0.12,
    regime_scale=0.7,
    seed=123,
)
config = SyntheticDatasetConfig(**dataset_params)
dataset = SyntheticControlDataset(config)

# The split uses its own seed (1234), distinct from the generator seed above.
train_ds, val_ds, test_ds = train_val_test_split(dataset, seed=1234)
tuple(len(split) for split in (train_ds, val_ds, test_ds))


Visualise the first trajectory in the dataset (`dataset[0]`) together with its dynamic controls.


In [None]:
# Top panel: target trajectory of the first sequence; bottom panel: each dynamic control.
sample = dataset[0]
time = np.arange(config.seq_len)
fig, axes = plt.subplots(2, 1, figsize=(10, 6), sharex=True)
target_ax, control_ax = axes

target_ax.plot(time, sample['targets'].numpy(), label='target')
target_ax.set(ylabel='Target')
target_ax.legend()

controls = sample['dynamic_controls']
for ctrl_idx in range(config.control_dim):
    control_ax.plot(time, controls[:, ctrl_idx].numpy(), label=f'ctrl {ctrl_idx}')
control_ax.set(xlabel='Time step', ylabel='Control value')
control_ax.legend(ncol=config.control_dim)

plt.tight_layout()
plt.show()


## 3. Data efficiency of conditional mean models


In [None]:
# Settings shared by the conditional-mean experiments below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # use a GPU when one is available
batch_size = 64
subset_sizes = [64, 128, 256, 384]  # number of training examples in each run of the sweep

def trainer_factory(model_type):
    """Build a new conditional-mean trainer for the requested architecture.

    ``model_type`` is passed straight to ``MeanModelConfig`` (this notebook uses
    ``'mlp'`` and ``'transformer'``). Sequence and control dimensions are taken
    from the dataset ``config``; training uses lr=3e-3 for 80 epochs with the
    notebook-level ``batch_size`` and ``device``.
    """
    hyperparams = dict(lr=3e-3, n_epochs=80, batch_size=batch_size, device=device)
    mean_config = MeanModelConfig(
        seq_len=config.seq_len,
        control_dim=config.control_dim,
        static_dim=config.static_dim,
        model_type=model_type,
        **hyperparams,
    )
    return MeanModelTrainer(mean_config)

def evaluate_subset(trainer, subset_size):
    """Train ``trainer`` on the first ``subset_size`` training examples and score it on the test split.

    Args:
        trainer: a trainer exposing ``fit(train_loader, val_loader)`` (returning a
            history dict with a ``'train'`` list) and ``predict(loader)``.
        subset_size: number of leading examples of ``train_ds`` to train on.

    Returns:
        ``(final_train_loss, test_mse)``: the last entry of ``history['train']``
        and the mean squared error of the test-set predictions.

    Raises:
        ValueError: if ``subset_size`` is not in ``[1, len(train_ds)]``, or if the
            number of predicted values differs from the number of target values.
    """
    if not 0 < subset_size <= len(train_ds):
        # Fail early with a clear message instead of an IndexError deep inside training.
        raise ValueError(
            f'subset_size must be in [1, {len(train_ds)}], got {subset_size}'
        )
    subset = torch.utils.data.Subset(train_ds, list(range(subset_size)))
    train_loader = make_dataloader(subset, batch_size=batch_size, shuffle=True)
    val_loader = make_dataloader(val_ds, batch_size=batch_size, shuffle=False)
    history = trainer.fit(train_loader, val_loader)

    test_loader = make_dataloader(test_ds, batch_size=batch_size, shuffle=False)
    # The trainer may run on CUDA; compare on CPU next to the dataset targets.
    preds = torch.as_tensor(trainer.predict(test_loader)).cpu()
    # Concatenating the per-sequence targets yields one flat vector in test-set
    # order. Flatten the predictions the same way, so per-sequence output such as
    # (n_test, seq_len) is compared element-wise instead of failing to broadcast.
    targets = torch.cat([test_ds[i]['targets'] for i in range(len(test_ds))], dim=0)
    if preds.numel() != targets.numel():
        raise ValueError(
            f'prediction/target size mismatch: {tuple(preds.shape)} vs {tuple(targets.shape)}'
        )
    mse = torch.mean((preds.reshape(-1) - targets.reshape(-1)) ** 2).item()
    return history['train'][-1], mse

# One row per subset size: final training loss and test MSE of the MLP baseline.
mlp_rows = []
for size in subset_sizes:
    # A new trainer is built for every subset size.
    trainer = trainer_factory('mlp')
    final_train_loss, test_mse = evaluate_subset(trainer, size)
    mlp_rows.append(
        {'subset_size': size, 'train_loss': final_train_loss, 'test_mse': test_mse}
    )
pd.DataFrame.from_records(mlp_rows)


Repeat with the Transformer baseline.


In [None]:
# Same sweep as the MLP cell, with the Transformer architecture.
transformer_rows = []
for size in subset_sizes:
    trainer = trainer_factory('transformer')
    train_loss, mse = evaluate_subset(trainer, size)
    transformer_rows.append(
        dict(subset_size=size, train_loss=train_loss, test_mse=mse)
    )
pd.DataFrame.from_records(transformer_rows)


## 4. Conditional flow matching


In [None]:
# Conditional flow matching model trained on the full training split.
flow_config = FlowMatchingConfig(
    seq_len=config.seq_len,
    control_dim=config.control_dim,
    static_dim=config.static_dim,
    batch_size=64,
    lr=2e-3,
    n_epochs=100,
    device=device,
)
flow_trainer = FlowMatchingTrainer(flow_config)

# Only the training loader is shuffled.
train_loader, val_loader = (
    make_dataloader(split, batch_size=flow_config.batch_size, shuffle=shuffle)
    for split, shuffle in ((train_ds, True), (val_ds, False))
)
history = flow_trainer.fit(train_loader, val_loader)
history['train'][-5:]  # last five entries of the training-loss history


Generate samples conditioned on the controls and compare to the ground truth.


In [None]:
# Draw one sample per sequence for the first test batch and overlay it on the ground truth.
test_loader = make_dataloader(test_ds, batch_size=8, shuffle=False)
batch = next(iter(test_loader))
generated = flow_trainer.sample(batch['dynamic_controls'], batch['static_controls'], n_steps=60)

# Show at most four sequences; a smaller batch must not cause an IndexError.
n_show = min(4, generated.shape[0])
# squeeze=False keeps `axes` 2-D even when n_show == 1.
fig, axes = plt.subplots(n_show, 1, figsize=(10, 3 * n_show), sharex=True, squeeze=False)
for i, ax in enumerate(axes[:, 0]):
    # detach() so .numpy() also works if the sampler returns a grad-tracking tensor.
    ax.plot(generated[i].detach().cpu().numpy(), label='generated')
    ax.plot(batch['targets'][i].cpu().numpy(), label='ground truth', linestyle='--')
    ax.set_ylabel('Target')
    ax.set_title(f'Test sequence {i}')
    ax.legend()
axes[-1, 0].set_xlabel('Time step')
fig.suptitle('Flow matching samples vs ground truth')
plt.tight_layout()
plt.show()


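Evaluate the reconstruction error of the flow matching samples: the mean squared error between one generated sample per test sequence and its ground-truth target. Each sample is a random draw rather than a conditional mean, so this number is not directly comparable to the test MSE of the mean models above.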


In [None]:
# MSE between one generated sample per test sequence and its ground-truth target.
# Re-seed first so that re-running this cell reproduces the same samples
# (presumes cfm.utils.set_seed seeds torch — confirm).
set_seed(42)
metrics = []  # per-batch MSE, kept for inspection
sq_err_total = 0.0
n_values = 0
for batch in test_loader:
    generated = flow_trainer.sample(batch['dynamic_controls'], batch['static_controls'], n_steps=60)
    # The sampler may return tensors on the training device; compare on CPU.
    sample_flat = generated.detach().cpu().reshape(-1)
    target_flat = batch['targets'].cpu().reshape(-1)
    sq_err = (sample_flat - target_flat) ** 2
    metrics.append(sq_err.mean().item())
    sq_err_total += sq_err.sum().item()
    n_values += sq_err.numel()
# Pool over all values instead of averaging batch means, so a smaller final
# batch is not over-weighted.
flow_test_mse = sq_err_total / n_values if n_values else float('nan')
flow_test_mse
