# Experiment 2 — Flow matching / measure transport (§5.2)

Validates Corollary 3 (L¹ convergence, rate O(√h)) for the continuity
equation driven by a flow-matching velocity field.

In [None]:
import os, sys

def find_repo_root(candidates):
    """Return the first directory in *candidates* that contains an ``rnode`` dir.

    Returns None when no candidate qualifies. Whether the directory is already
    on ``sys.path`` is deliberately NOT part of the search, so an
    already-configured root is still recognised as the root.
    """
    for cand in candidates:
        if os.path.isdir(os.path.join(cand, "rnode")):
            return cand
    return None


# Allow running the notebook either from the repository root or from one
# directory below it (e.g. an experiments/notebooks folder).
_cwd = os.getcwd()
_root = find_repo_root((_cwd, os.path.abspath(os.path.join(_cwd, ".."))))
if _root is not None and _root not in sys.path:
    sys.path.insert(0, _root)

import torch, numpy as np, torch.nn as nn, random
from rnode.flow import Flow
from rnode.data import sample_initial_density, sample_target_density

## Configuration

In [None]:
QUICK = False  # Set True for a fast test run

CFG = dict(
    seed=42, dim=2, hidden=64,
    T=1.0, dt=0.01,
    n_epochs=10000 if not QUICK else 1500,
    batch_x0=250, batch_x1=100, lr=1e-2,
    extent=[-3, 8, -3, 5], grid_res=300,
    n_mesh=200, n_seeds=30,
    n_batches_default=3,
    n_real_viz=20 if not QUICK else 5,
    dt_conv=0.001,
    n_real_conv=20 if not QUICK else 6,
    hist_res=50,
    n_real_scheme=20 if not QUICK else 6,
    n_bench_trials=10 if not QUICK else 3,
)


def n_time_steps(T, dt):
    """Return the number of steps of size *dt* needed to reach time *T*.

    Uses round() instead of int(): ``T / dt`` is a float quotient, and
    truncation silently drops a step whenever it lands just below an integer
    (e.g. ``0.3 / 0.1 == 2.9999999999999996``).
    """
    return int(round(T / dt))


CFG["n_steps"] = n_time_steps(CFG["T"], CFG["dt"])
CFG["n_steps_conv"] = n_time_steps(CFG["T"], CFG["dt_conv"])

# Put outputs/ next to the repository root (the directory that holds `rnode`),
# so the location is the same whether the notebook is started from the repo
# root or from a sub-directory. Falls back to the previous ../outputs.
_cwd = os.getcwd()
_parent = os.path.abspath(os.path.join(_cwd, ".."))
_repo = next(
    (d for d in (_cwd, _parent) if os.path.isdir(os.path.join(d, "rnode"))),
    _parent,
)
OUT = os.path.join(_repo, "outputs")
os.makedirs(OUT, exist_ok=True)

# Seed every RNG used in this notebook (torch, NumPy, stdlib random).
torch.manual_seed(CFG["seed"])
np.random.seed(CFG["seed"])
random.seed(CFG["seed"])

## Train flow network

In [None]:
# Train the velocity field flow(t, x) with the flow-matching regression:
# for x0 ~ initial density, x1 ~ target density and t ~ U(0, 1), regress
# flow(t, (1 - t) * x0 + t * x1) onto the straight-line velocity x1 - x0.
flow = Flow(dim=CFG["dim"], hidden=CFG["hidden"])
opt = torch.optim.Adam(flow.parameters(), lr=CFG["lr"])
loss_fn = nn.MSELoss()

# Progress-print interval. A fixed 2000 never fires in QUICK mode (1500 epochs);
# n_epochs // 5 yields the same 2000 for the full 10000-epoch run.
log_every = CFG.get("log_every", max(1, CFG["n_epochs"] // 5))
loss = None

for ep in range(CFG["n_epochs"]):
    x0 = torch.tensor(sample_initial_density(CFG["batch_x0"]), dtype=torch.float32)
    x1 = torch.tensor(sample_target_density(CFG["batch_x1"]), dtype=torch.float32)
    # Pair each target sample with a uniformly chosen initial sample
    # (np.random.choice samples with replacement by default).
    idx = np.random.choice(CFG["batch_x0"], CFG["batch_x1"])
    x0m = x0[idx]

    # Linear interpolant between the paired samples at random times t.
    t = torch.rand(CFG["batch_x1"], 1)
    xt = (1 - t) * x0m + t * x1

    opt.zero_grad()
    loss = loss_fn(flow(t, xt), x1 - x0m)
    loss.backward()
    opt.step()

    if (ep + 1) % log_every == 0:
        print(f"  Epoch {ep+1}/{CFG['n_epochs']}  loss={loss.item():.4f}")

# Loss of the last minibatch only (noisy; not an average over training).
if loss is not None:
    print(f"Final loss: {loss.item():.4f}")
else:
    print("Final loss: n/a (n_epochs == 0)")

## Generate figures

In [None]:
# Explicit imports instead of `import *`: a wildcard would also pull in every
# public module-level name of exp2_plots and could silently shadow this
# notebook's own globals (e.g. CFG, OUT, flow).
from experiments.exp2_plots import (
    fig1_data,
    fig2_density,
    fig3_comparison,
    fig4_convergence,
    fig5_scheme_convergence,
    fig6_variance,
    fig7_benchmark,
)

# Figure 1 is called with the config and output directory only; it does not
# receive the trained `flow`.
fig1_data(CFG, OUT)

In [None]:
# Figure 2. The four returned values are passed unchanged to fig3_comparison
# in the following cell, so this cell must run before it.
# NOTE(review): names suggest initial/final densities and their contour
# levels — confirm against exp2_plots.
d_ini, d_fin, lvl_i, lvl_f = fig2_density(flow, CFG, OUT)

In [None]:
# Figure 3: consumes the four values returned by fig2_density in the previous
# cell together with the trained flow.
fig3_comparison(flow, d_ini, d_fin, lvl_i, lvl_f, CFG, OUT)

In [None]:
# Figure 4. The result is bound to `conv` for interactive inspection; no later
# cell in this notebook reads it.
# NOTE(review): presumably the Corollary 3 convergence study driven by
# CFG["dt_conv"] / CFG["n_real_conv"] — confirm in exp2_plots.
conv = fig4_convergence(flow, CFG, OUT)

In [None]:
# Figure 5. The result is bound to `sch_conv`; no later cell in this notebook
# reads it. NOTE(review): presumably uses CFG["n_real_scheme"] — confirm.
sch_conv = fig5_scheme_convergence(flow, CFG, OUT)

In [None]:
# Figure 6. Any return value is discarded.
fig6_variance(flow, CFG, OUT)

In [None]:
# Figure 7. The result is bound to `bench`; no later cell in this notebook
# reads it. NOTE(review): presumably uses CFG["n_bench_trials"] — confirm.
bench = fig7_benchmark(flow, CFG, OUT)

## Save model

In [None]:
# Persist only the trained parameters (state_dict), not the module object.
model_path = os.path.join(OUT, "flow_model.pth")
torch.save(flow.state_dict(), model_path)
print("Done. All figures saved to", OUT)