# Flow Matching Deep Dive

This notebook explores the Flow Matching component of FlowShield-UDRL in detail:

1. Theory: Continuous Normalizing Flows and Flow Matching
2. Vector field learning
3. ODE integration for sampling and density estimation
4. Visualization of learned flows
5. Conditional Flow Matching for $p(g|s)$
6. Out-of-distribution (OOD) detection

## Setup

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d import Axes3D

from src.models.safety.flow_matching import VectorFieldNetwork, FlowMatchingModel
from src.utils.seed import set_global_seed

# Fix the seed through the project helper, then prefer CUDA when it is available.
set_global_seed(42)
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using device: {device}')

## 1. Theory: Flow Matching

Flow Matching learns a **time-dependent vector field** $v_\theta(x, t)$ that transforms
a simple base distribution (e.g., Gaussian) into a complex target distribution.

### Key Concepts:

1. **Probability path**: $p_t(x)$ interpolates between base $p_0(x)$ and target $p_1(x)$

2. **Vector field**: $v_t(x)$ describes how samples flow from $t=0$ to $t=1$

3. **ODE**: Samples evolve according to $\frac{dx}{dt} = v_t(x)$

4. **Optimal Transport path** (Lipman et al., 2022):
   $$x_t = (1-t) x_0 + t x_1$$
   $$v_t(x_t \mid x_0, x_1) = x_1 - x_0$$

5. **Flow Matching loss**:
   $$\mathcal{L} = \mathbb{E}_{t, x_0, x_1} \|v_\theta(x_t, t) - (x_1 - x_0)\|^2$$

## 2. Create a Simple 2D Example

Let's start with a simple 2D example to understand the mechanics.

In [None]:
# Toy 2D target: an equally weighted mixture of three Gaussians
def sample_target(n_samples):
    """Draw ``n_samples`` 2D points from a three-component Gaussian mixture.

    Every point picks one of three fixed centers uniformly at random and
    adds isotropic Gaussian noise with standard deviation 0.5.

    Returns a tensor of shape ``(n_samples, 2)``.
    """
    mixture_means = torch.tensor([
        [-2.0, 2.0],
        [2.0, 2.0],
        [0.0, -2.0],
    ])
    n_components = mixture_means.shape[0]

    # The component assignment is drawn first, the per-sample noise second.
    component = torch.randint(0, n_components, (n_samples,))
    noise = torch.randn(n_samples, 2)
    return mixture_means[component] + 0.5 * noise

# Scatter 1000 draws from the target mixture on a square canvas
target_samples = sample_target(1000)

fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.5, s=10)
ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_title('Target Distribution (Mixture of Gaussians)')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Simple vector field network (unconditional)
class SimpleVectorField(nn.Module):
    """MLP that models an unconditional, time-dependent vector field v(x, t).

    The network input is the point ``x`` concatenated with one scalar time
    feature; the output has ``data_dim`` components.

    Args:
        data_dim: Dimensionality of the data points.
        hidden_dim: Width of each of the three hidden layers.
    """

    def __init__(self, data_dim, hidden_dim=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(data_dim + 1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, data_dim),
        )

    def forward(self, x, t):
        """Evaluate the vector field at points ``x`` and time(s) ``t``.

        Args:
            x: Tensor of shape ``(batch, data_dim)``.
            t: Time as a Python number, a 0-dim tensor (shared by the whole
                batch), a ``(batch,)`` tensor, or a ``(batch, 1)`` tensor.

        Returns:
            Tensor of shape ``(batch, data_dim)``.
        """
        # Accept plain Python scalars, and keep t on x's device/dtype so the
        # concatenation below cannot fail on a device or dtype mismatch.
        if isinstance(t, torch.Tensor):
            t = t.to(device=x.device, dtype=x.dtype)
        else:
            t = torch.tensor(t, dtype=x.dtype, device=x.device)

        if t.dim() == 0:
            # One shared time for the whole batch.
            t = t.expand(x.shape[0], 1)
        elif t.dim() == 1:
            # One time per sample: make it a column feature.
            t = t.unsqueeze(-1)
        return self.net(torch.cat([x, t], dim=-1))

# Instantiate the 2D vector field on the selected device, plus its Adam optimizer
model = SimpleVectorField(data_dim=2, hidden_dim=128)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Flow Matching training: regress the network onto the straight-line velocity
n_epochs = 500
batch_size = 256
losses = []

for epoch in range(n_epochs):
    # Path endpoints: a data sample x_1 and a Gaussian noise sample x_0
    x1 = sample_target(batch_size).to(device)
    x0 = torch.randn_like(x1)

    # One time per sample, kept as a column so it broadcasts over features
    t = torch.rand(batch_size, device=device)
    t_col = t.unsqueeze(-1)

    # Point on the linear (optimal transport) path and its constant velocity
    xt = (1 - t_col) * x0 + t_col * x1
    target_v = x1 - x0

    # Mean squared error between predicted and target velocity
    pred_v = model(xt, t)
    loss = (pred_v - target_v).pow(2).mean()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    # Report every 100th iteration
    if (epoch + 1) % 100 == 0:
        print(f'Epoch {epoch+1}: Loss = {loss.item():.6f}')

# Loss curve on a logarithmic y-axis
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Flow Matching Loss')
ax.set_title('Training Progress')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.show()

## 3. Visualize the Learned Vector Field

In [None]:
def plot_vector_field(model, t, ax, n_grid=20, scale=0.3):
    """Draw the model's vector field at time ``t`` as a quiver plot on ``ax``.

    The field is evaluated on an ``n_grid`` x ``n_grid`` lattice covering
    [-5, 5]^2. The model is switched to eval mode and called without
    gradient tracking. ``quiver`` receives ``scale=1/scale``.
    """
    ticks = torch.linspace(-5, 5, n_grid)
    X, Y = torch.meshgrid(ticks, ticks, indexing='xy')

    # Flatten the lattice into a batch of points that all share the same time
    points = torch.stack([X.flatten(), Y.flatten()], dim=1).to(device)
    times = torch.full((points.shape[0],), t, device=device)

    model.eval()
    with torch.no_grad():
        velocity = model(points, times)

    # Bring each component back to the (n_grid, n_grid) layout of X and Y
    vx, vy = (
        velocity[:, k].reshape(n_grid, n_grid).cpu().numpy() for k in range(2)
    )

    ax.quiver(X.numpy(), Y.numpy(), vx, vy, alpha=0.7, scale=1/scale)
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_title(f't = {t:.2f}')
    ax.grid(True, alpha=0.3)

# One quiver panel per snapshot time
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

snapshot_times = [0.0, 0.25, 0.5, 0.75]
for ax, t in zip(axes, snapshot_times):
    plot_vector_field(model, t, ax)

fig.suptitle('Learned Vector Field at Different Times', fontsize=14)
fig.tight_layout()
plt.show()

## 4. Generate Samples via ODE Integration

In [None]:
def integrate_ode(model, x0, n_steps=100):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with the explicit Euler method.

    Args:
        model: ``nn.Module`` called as ``model(x, t)`` with ``t`` of shape
            ``(batch,)``; it is switched to eval mode.
        x0: Starting points, batch on the first dimension. Not modified.
        n_steps: Number of equal-size Euler steps; must be at least 1.

    Returns:
        Tensor of shape ``(n_steps + 1, *x0.shape)`` holding the state before
        the first step and after every step.

    Raises:
        ValueError: If ``n_steps`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f'n_steps must be >= 1, got {n_steps}')

    dt = 1.0 / n_steps
    x = x0.clone()
    # x is only ever rebound, never mutated in place, so storing references
    # is safe; torch.stack copies them into the result anyway.
    trajectory = [x]

    model.eval()
    with torch.no_grad():
        for i in range(n_steps):
            # The velocity is evaluated at the left endpoint of each interval.
            # The time tensor follows x's device/dtype rather than a global.
            t_tensor = torch.full(
                (x.shape[0],), i / n_steps, dtype=x.dtype, device=x.device
            )
            x = x + dt * model(x, t_tensor)
            trajectory.append(x)

    return torch.stack(trajectory, dim=0)

# Draw starting points from the standard normal base distribution
n_samples = 500
x0 = torch.randn(n_samples, 2)
x0 = x0.to(device)

# Push them through the learned flow with the default number of steps
trajectory = integrate_ode(model, x0)

print(f'Trajectory shape: {trajectory.shape}')  # [n_steps+1, n_samples, 2]

In [None]:
# Snapshots of the particle cloud at five points along the flow
fig, axes = plt.subplots(1, 5, figsize=(20, 4))

# Indices into the trajectory; with 100 steps these are t = 0, 0.25, 0.5, 0.75, 1.0
time_indices = [0, 25, 50, 75, 100]

for ax, t_idx in zip(axes, time_indices):
    samples = trajectory[t_idx].cpu().numpy()
    xs, ys = samples[:, 0], samples[:, 1]
    ax.scatter(xs, ys, alpha=0.5, s=10, c='blue')
    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_title(f't = {t_idx/100:.2f}')
    ax.grid(True, alpha=0.3)
    ax.set_aspect('equal')

# Overlay fresh target samples on the final panel for comparison
target_samples = sample_target(500)
final_ax = axes[-1]
final_ax.scatter(target_samples[:, 0], target_samples[:, 1], alpha=0.3, s=10, c='red', label='Target')
final_ax.legend()

fig.suptitle('Flow Evolution: Base → Target', fontsize=14)
fig.tight_layout()
plt.show()

In [None]:
# Trace individual particles from the base distribution to the target
n_traj = 50
x0 = torch.randn(n_traj, 2).to(device)
trajectory = integrate_ode(model, x0, n_steps=100)
traj_np = trajectory.cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 8))

# One thin polyline per particle (axis 1 of traj_np indexes particles)
for path in traj_np.transpose(1, 0, 2):
    ax.plot(path[:, 0], path[:, 1], alpha=0.3, linewidth=0.5)

# Highlight where each particle starts and where it ends
starts, ends = traj_np[0], traj_np[-1]
ax.scatter(starts[:, 0], starts[:, 1], c='blue', s=30, label='Start (t=0)', zorder=5)
ax.scatter(ends[:, 0], ends[:, 1], c='red', s=30, label='End (t=1)', zorder=5)

# Mark the mixture centers; only the first gets a legend label
centers = [[-2, 2], [2, 2], [0, -2]]
for k, (cx, cy) in enumerate(centers):
    ax.scatter(cx, cy, c='green', s=100, marker='x', linewidths=3, label='Target center' if k == 0 else '')

ax.set_xlim(-5, 5)
ax.set_ylim(-5, 5)
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Sample Trajectories from Base to Target')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

## 5. Density Estimation with Change of Variables

We can compute $\log p(x_1)$ using the change of variables formula:

$$\log p(x_1) = \log p(x_0) - \int_0^1 \text{tr}\left(\frac{\partial v_t}{\partial x}\right) dt$$

The Hutchinson trace estimator makes this tractable:

$$\text{tr}(J) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,I)} [\epsilon^\top J \epsilon]$$

In [None]:
def compute_log_prob(model, x1, n_steps=100, n_trace_samples=10):
    """Estimate log p_1(x1) by integrating the flow backwards from t=1 to t=0.

    Implements the instantaneous change-of-variables formula

        log p_1(x_1) = log p_0(x_0) - int_0^1 tr(dv_t/dx) dt

    with explicit Euler steps in reverse time and a Hutchinson estimate of
    the trace, so the result is a stochastic approximation.

    Args:
        model: Callable ``model(x, t)`` returning the velocity for points
            ``x`` of shape ``(batch, dim)`` and times ``t`` of shape ``(batch,)``.
        x1: Points at t=1 whose log-density is wanted, shape ``(batch, dim)``.
        n_steps: Number of Euler steps; must be at least 1.
        n_trace_samples: Noise vectors averaged per step for the trace
            estimate; must be at least 1.

    Returns:
        Tensor of shape ``(batch,)`` with the estimated log-densities
        (no gradient attached).

    Raises:
        ValueError: If ``n_steps`` or ``n_trace_samples`` is smaller than 1.
    """
    if n_steps < 1:
        raise ValueError(f'n_steps must be >= 1, got {n_steps}')
    if n_trace_samples < 1:
        raise ValueError(f'n_trace_samples must be >= 1, got {n_trace_samples}')

    dt = 1.0 / n_steps

    # The trace estimate needs autograd, so gradients are re-enabled here to
    # keep the function usable inside a surrounding torch.no_grad() block.
    with torch.enable_grad():
        x = x1.detach().clone().requires_grad_(True)
        div_integral = torch.zeros(x.shape[0], dtype=x.dtype, device=x.device)

        for i in range(n_steps, 0, -1):
            t_tensor = torch.full(
                (x.shape[0],), i / n_steps, dtype=x.dtype, device=x.device
            )
            v = model(x, t_tensor)

            # Hutchinson estimator: average eps^T J eps over noise vectors,
            # where the vector-Jacobian product comes from autograd.
            trace = torch.zeros_like(div_integral)
            for _ in range(n_trace_samples):
                eps = torch.randn_like(x)
                vjp = torch.autograd.grad(v, x, eps, retain_graph=True)[0]
                trace = trace + (eps * vjp).sum(dim=-1)
            trace = trace / n_trace_samples

            # Accumulate int tr(dv/dx) dt along the path
            div_integral = div_integral + dt * trace

            # Euler step backwards in time, starting a fresh autograd graph
            x = (x - dt * v).detach().requires_grad_(True)

    # Log-density of the recovered x_0 under the standard Gaussian base
    x_base = x.detach()
    log_p0 = -0.5 * (x_base ** 2).sum(dim=-1) - x_base.shape[-1] * 0.5 * np.log(2 * np.pi)

    # The divergence integral is SUBTRACTED: an expanding flow (positive
    # divergence) spreads mass out and lowers the density at x1.
    return log_p0 - div_integral.detach()

# Evaluate the estimated log-density on 200 fresh target samples
test_samples = sample_target(200)
test_samples = test_samples.to(device)
log_probs = compute_log_prob(model, test_samples)

print(f'Mean log prob: {log_probs.mean().item():.4f}')
print(f'Std log prob: {log_probs.std().item():.4f}')

In [None]:
# Visualize log probability landscape
n_grid = 50
x = torch.linspace(-5, 5, n_grid)
y = torch.linspace(-5, 5, n_grid)
# With indexing='xy', X[i, j] = x[j] and Y[i, j] = y[i]: rows follow y, columns follow x.
X, Y = torch.meshgrid(x, y, indexing='xy')
grid = torch.stack([X.flatten(), Y.flatten()], dim=1).to(device)

# Compute log probs (this may take a moment)
print('Computing log probabilities over grid...')
log_probs_grid = compute_log_prob(model, grid, n_steps=50, n_trace_samples=5)
# Reshaped array is indexed [row = y, col = x], the layout imshow expects.
log_probs_grid = log_probs_grid.reshape(n_grid, n_grid).cpu().numpy()

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Log probability
# No transpose: imshow draws rows along the vertical axis and columns along
# the horizontal axis, so transposing would swap x and y in the picture.
im1 = axes[0].imshow(
    log_probs_grid, origin='lower', aspect='equal',
    extent=[-5, 5, -5, 5],
    cmap='viridis'
)
plt.colorbar(im1, ax=axes[0], label='log p(x)')
axes[0].set_xlabel('x')
axes[0].set_ylabel('y')
axes[0].set_title('Log Probability Density')

# Probability (log values are clipped before exponentiating to bound the color range)
prob_grid = np.exp(np.clip(log_probs_grid, -20, 5))
im2 = axes[1].imshow(
    prob_grid, origin='lower', aspect='equal',
    extent=[-5, 5, -5, 5],
    cmap='hot'
)
plt.colorbar(im2, ax=axes[1], label='p(x)')
axes[1].set_xlabel('x')
axes[1].set_ylabel('y')
axes[1].set_title('Probability Density')

plt.tight_layout()
plt.show()

## 6. Conditional Flow Matching for p(g|s)

In FlowShield-UDRL, we condition the flow on the state $s$ to model $p(g|s)$.
This tells us what commands are achievable from a given state.

In [None]:
# Build the project's conditional Flow Matching model for the 2D toy problem
from src.models.safety.flow_matching import FlowMatchingModel

cond_model_kwargs = {
    'data_dim': 2,               # command dimension (horizon, return)
    'condition_dim': 2,          # state dimension
    'hidden_dim': 128,
    'n_layers': 4,
    'n_integration_steps': 100,
}
cond_model = FlowMatchingModel(**cond_model_kwargs)
cond_model = cond_model.to(device)

print(f'Parameters: {sum(p.numel() for p in cond_model.parameters()):,}')

In [None]:
# Synthetic data in which the command distribution depends on the state
def sample_state_conditioned_data(n_samples):
    """Generate ``n_samples`` synthetic (state, command) pairs.

    States are uniform on [-2, 2]^2. Each command is ``(horizon, return)``:
    the horizon grows linearly with the state's distance from the origin
    (noise std 2), and the return is a linear function of the two state
    coordinates (noise std 3).

    Returns:
        Tuple ``(states, commands)``, both of shape ``(n_samples, 2)``.
    """
    # Uniform states on the square [-2, 2]^2
    states = 4 * torch.rand(n_samples, 2) - 2
    state_x, state_y = states[:, 0], states[:, 1]
    radius = torch.norm(states, dim=-1)

    # Horizon noise is drawn before return noise
    horizon_noise = torch.randn(n_samples) * 2
    horizon = 10 + 5 * radius + horizon_noise

    return_noise = torch.randn(n_samples) * 3
    target_return = 20 - 3 * state_x + 2 * state_y + return_noise

    commands = torch.stack([horizon, target_return], dim=-1)
    return states, commands

# Draw the 5000-pair training set used by the conditional model
n_train = 5000
train_states, train_commands = sample_state_conditioned_data(n_train)

print(f'States shape: {train_states.shape}')
print(f'Commands shape: {train_commands.shape}')

In [None]:
# Three views of how commands relate to states
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
state_ax, command_ax, horizon_ax = axes

# Where the states lie
state_ax.scatter(train_states[:, 0], train_states[:, 1], alpha=0.3, s=10)
state_ax.set_xlabel('State x')
state_ax.set_ylabel('State y')
state_ax.set_title('State Distribution')

# Commands, colored by the first state coordinate
sc = command_ax.scatter(
    train_commands[:, 0], train_commands[:, 1],
    c=train_states[:, 0], cmap='coolwarm', alpha=0.5, s=10,
)
fig.colorbar(sc, ax=command_ax, label='State x')
command_ax.set_xlabel('Horizon')
command_ax.set_ylabel('Target Return')
command_ax.set_title('Commands (colored by state x)')

# Horizon against the state's distance from the origin
distances = torch.norm(train_states, dim=-1)
horizon_ax.scatter(distances, train_commands[:, 0], alpha=0.3, s=10)
horizon_ax.set_xlabel('Distance from center')
horizon_ax.set_ylabel('Horizon')
horizon_ax.set_title('Horizon vs State Distance')

fig.tight_layout()
plt.show()

In [None]:
# Fit the conditional model on random mini-batches of (command, state) pairs
optimizer = torch.optim.Adam(cond_model.parameters(), lr=1e-3)
batch_size = 256
n_epochs = 200

# Keep one copy of the full dataset on the training device
train_states_dev = train_states.to(device)
train_commands_dev = train_commands.to(device)

losses = []

for epoch in range(n_epochs):
    # Mini-batch sampled with replacement
    idx = torch.randint(0, train_states.shape[0], (batch_size,))
    states = train_states_dev[idx]
    commands = train_commands_dev[idx]

    # The loss comes from the project model: commands are the data, states the condition
    loss = cond_model.compute_loss(commands, states)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses.append(loss.item())

    # Report every 50th iteration
    if (epoch + 1) % 50 == 0:
        print(f'Epoch {epoch+1}: Loss = {loss.item():.6f}')

# Loss curve on a logarithmic y-axis
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(losses)
ax.set_xlabel('Epoch')
ax.set_ylabel('Flow Matching Loss')
ax.set_title('Conditional Flow Matching Training')
ax.set_yscale('log')
ax.grid(True, alpha=0.3)
plt.show()

In [None]:
# Compare generated commands with training commands for three probe states
test_states = torch.tensor([
    [-1.5, -1.5],  # bottom-left
    [0.0, 0.0],    # center
    [1.5, 1.5],    # top-right
]).to(device)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for state, ax in zip(test_states, axes):
    # 200 model samples, all conditioned on the same probe state
    state_batch = state.unsqueeze(0).expand(200, -1)
    with torch.no_grad():
        sampled_commands = cond_model.sample(state_batch)
    sampled_cpu = sampled_commands.cpu()

    # Reference: training commands whose state lies within 0.5 of the probe
    gap = torch.norm(train_states - state.cpu(), dim=-1)
    true_commands = train_commands[gap < 0.5]

    ax.scatter(sampled_cpu[:, 0], sampled_cpu[:, 1],
               alpha=0.5, s=20, c='blue', label='Sampled')
    ax.scatter(true_commands[:, 0], true_commands[:, 1],
               alpha=0.5, s=20, c='red', label='True')
    ax.set_xlabel('Horizon')
    ax.set_ylabel('Target Return')
    ax.set_title(f'State = ({state[0]:.1f}, {state[1]:.1f})')
    ax.legend()

fig.suptitle('Conditional Sampling: True vs Generated Commands', fontsize=14)
fig.tight_layout()
plt.show()

## 7. OOD Detection with Flow Matching

Commands with low probability under $p(g|s)$ are out-of-distribution (OOD)
and potentially dangerous.

In [None]:
# Compute log probability for a grid of commands conditioned on a state
test_state = torch.tensor([[0.0, 0.0]]).to(device)  # Center state

# Create command grid
h_range = torch.linspace(5, 25, 40)
r_range = torch.linspace(10, 30, 40)
# With indexing='xy', H[i, j] = h_range[j] and R[i, j] = r_range[i]:
# rows follow the return axis, columns follow the horizon axis.
H, R = torch.meshgrid(h_range, r_range, indexing='xy')
command_grid = torch.stack([H.flatten(), R.flatten()], dim=-1).to(device)

# Expand state to match
state_expanded = test_state.expand(len(command_grid), -1)

# Compute log probs
# NOTE(review): if log_prob needs autograd for its divergence term, it must
# re-enable gradients internally to work under no_grad — confirm in the model.
print('Computing log probabilities...')
with torch.no_grad():
    log_probs = cond_model.log_prob(command_grid, state_expanded)

# Indexed [row = return, col = horizon], the layout imshow expects.
log_probs_grid = log_probs.reshape(40, 40).cpu().numpy()

# Plot
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Log probability heatmap
# No transpose: imshow maps columns to the horizontal (horizon) axis and rows
# to the vertical (return) axis, matching the extent, labels and scatter overlay.
im1 = axes[0].imshow(
    log_probs_grid, origin='lower', aspect='auto',
    extent=[5, 25, 10, 30],
    cmap='viridis'
)
plt.colorbar(im1, ax=axes[0], label='log p(g|s)')
axes[0].set_xlabel('Horizon')
axes[0].set_ylabel('Target Return')
axes[0].set_title('Log Probability of Commands')

# OOD detection (threshold-based)
# The threshold is relative to this grid, so 10% of the cells are always flagged.
threshold = np.percentile(log_probs_grid, 10)  # 10th percentile as threshold
ood_grid = log_probs_grid < threshold

im2 = axes[1].imshow(
    (~ood_grid).astype(float), origin='lower', aspect='auto',
    extent=[5, 25, 10, 30],
    cmap='RdYlGn', vmin=0, vmax=1
)
axes[1].set_xlabel('Horizon')
axes[1].set_ylabel('Target Return')
axes[1].set_title('Safe (green) vs OOD (red) Commands')

# Add training data points (commands whose state lies within 0.5 of the origin)
nearby_mask = torch.norm(train_states, dim=-1) < 0.5
nearby_commands = train_commands[nearby_mask]
axes[0].scatter(nearby_commands[:, 0], nearby_commands[:, 1], c='red', s=10, alpha=0.3)

plt.suptitle('OOD Detection for State (0, 0)', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

Key takeaways about Flow Matching for FlowShield:

1. **Learns smooth probability flows**: Transforms Gaussian → complex distribution
2. **Efficient training**: Simple MSE loss on vector field predictions
3. **Density estimation**: Uses the change of variables formula (exact in principle; approximated here by Euler integration and a stochastic trace estimator)
4. **Conditional modeling**: $p(g|s)$ tells us achievable commands per state
5. **OOD detection**: Low probability ⟹ command is out-of-distribution

For more details, see our full FlowShield-UDRL implementation!