# Interpolations and Samplers

This section introduces several interpolation schemes and shows how the `RectifiedFlow` class handles them automatically through a single interface. We then explore the samplers included in the codebase.

If you're new to the Rectified Flow framework, we suggest starting with `train_2d_toys.ipynb` to get a better understanding.

**Interpolation**

Recall that given observed samples $X_0 \sim \pi_0$ and $X_1 \sim \pi_1$, the interpolation $X_t$ is defined as:

$$
X_t = \alpha_t \cdot X_1 + \beta_t \cdot X_0,
$$

where $\alpha_t$ and $\beta_t$ are time-dependent functions satisfying:
$$
\alpha_0 = \beta_1 = 0 \quad \text{and} \quad \alpha_1 = \beta_0 = 1.
$$
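As a quick numerical check of the definition, the sketch below evaluates $X_t$ for one illustrative pair of 2D points, using plain Python lists in place of tensors. It assumes the convention $X_t = \alpha_t X_1 + \beta_t X_0$, under which the boundary conditions above place $X_t$ at $X_0$ when $t = 0$ and at $X_1$ when $t = 1$, together with the straight schedule $\alpha_t = t$, $\beta_t = 1 - t$. The helper names `straight` and `interpolate` are hypothetical and not part of the `rectified_flow` package.

```python
# Convention assumed here: X_t = alpha_t * X_1 + beta_t * X_0.
def straight(t):
    """Straight schedule: returns (alpha_t, beta_t) = (t, 1 - t)."""
    return t, 1.0 - t


def interpolate(x0, x1, t, schedule=straight):
    """Return X_t for 2D points given as lists of floats."""
    alpha, beta = schedule(t)
    return [alpha * b + beta * a for a, b in zip(x0, x1)]


x0 = [0.5, -1.0]  # illustrative sample from pi_0
x1 = [15.0, 2.0]  # illustrative sample from pi_1

print(interpolate(x0, x1, 0.0))  # [0.5, -1.0]  (equals X_0)
print(interpolate(x0, x1, 0.5))  # [7.75, 0.5]  (midpoint)
print(interpolate(x0, x1, 1.0))  # [15.0, 2.0]  (equals X_1)
```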


**Velocity Field**

The *rectified flow* induced by the pair $(X_0, X_1)$ is driven by the velocity field

$$
v(z, t) = \mathbb{E}[\dot{X}_t \mid X_t = z], \qquad \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\beta}_t X_0.
$$

Equivalently, $v$ is the function that minimizes the expected squared error between the interpolation velocity $\dot{X}_t$ and a candidate velocity field evaluated at $X_t$:

$$
v = \arg\min_{v} \int_0^1 \mathbb{E}\left[\left\| \dot{\alpha}_t X_1 + \dot{\beta}_t X_0 - v(X_t, t) \right\|^2 \right] \mathrm{d}t.
$$
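The conditional expectation matters where interpolation paths cross. The hypothetical sketch below uses the straight schedule ($\alpha_t = t$, $\beta_t = 1 - t$, so $\dot{X}_t = X_1 - X_0$) and two made-up pairs whose paths meet at $z = (0, 0)$ when $t = 0.5$. At that point, the squared-error objective is smaller for the average of the two targets than for either target alone; that average is exactly $\mathbb{E}[\dot{X}_t \mid X_t = z]$ for this two-pair distribution. Plain Python lists stand in for tensors, and all helper names are illustrative.

```python
# Straight schedule under X_t = alpha_t * X_1 + beta_t * X_0 with alpha_t = t, beta_t = 1 - t,
# so the regression target dX_t/dt = alpha_dot_t * X_1 + beta_dot_t * X_0 is X_1 - X_0.
def straight_interp(x0, x1, t):
    return [t * b + (1.0 - t) * a for a, b in zip(x0, x1)]


def straight_velocity_target(x0, x1):
    return [b - a for a, b in zip(x0, x1)]


def mean_squared_error(targets, candidate):
    """Average of ||target - candidate||^2 over the given targets."""
    return sum(
        sum((u - c) ** 2 for u, c in zip(target, candidate)) for target in targets
    ) / len(targets)


# Two hypothetical pairs whose straight paths cross at z = (0, 0) when t = 0.5.
pairs = [([-1.0, 0.0], [1.0, 0.0]), ([1.0, 0.0], [-1.0, 0.0])]
t = 0.5
assert all(straight_interp(x0, x1, t) == [0.0, 0.0] for x0, x1 in pairs)

targets = [straight_velocity_target(x0, x1) for x0, x1 in pairs]
conditional_mean = [sum(col) / len(targets) for col in zip(*targets)]

print(targets)                                        # [[2.0, 0.0], [-2.0, 0.0]]
print(mean_squared_error(targets, [2.0, 0.0]))        # 8.0  (follow one pair's direction)
print(mean_squared_error(targets, conditional_mean))  # 4.0  (the conditional mean, [0.0, 0.0])
```

Because the two paths pass through $z$ in opposite directions, the minimizing velocity at the crossing point is zero.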

In the 2D toy example, we previously used the `straight` interpolation, where $\alpha_t = t$ and $\beta_t = 1 - t$. However, $\alpha_t$ and $\beta_t$ can be **any** differentiable time-dependent functions, as long as they satisfy $\alpha_0 = \beta_1 = 0$ and $\alpha_1 = \beta_0 = 1$. Let's explore some alternative interpolation methods and see how they influence the behavior of the rectified flow.
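To compare schedules, the sketch below checks the boundary conditions and evaluates points along each path for $X_0 = (1, 0)$ and $X_1 = (0, 1)$, assuming the convention $X_t = \alpha_t X_1 + \beta_t X_0$. Besides `straight`, it includes one curved alternative, $\alpha_t = \sin(\pi t / 2)$ and $\beta_t = \cos(\pi t / 2)$, often called spherical interpolation. Since $\alpha_t^2 + \beta_t^2 = 1$, the path between these two orthonormal endpoints follows a quarter circle instead of the straight chord. The dictionary `SCHEDULES` and its keys are hypothetical and are not assumed to match the strings accepted by `RectifiedFlow(interp=...)`.

```python
import math

# Each schedule maps t to (alpha_t, beta_t), with X_t = alpha_t * X_1 + beta_t * X_0.
SCHEDULES = {
    "straight": lambda t: (t, 1.0 - t),
    # Assumed example of a curved alternative: the quarter-circle ("spherical") schedule.
    "spherical": lambda t: (math.sin(math.pi * t / 2), math.cos(math.pi * t / 2)),
}


def satisfies_boundary(schedule, tol=1e-12):
    """Check alpha_0 = beta_1 = 0 and alpha_1 = beta_0 = 1 up to floating-point tolerance."""
    a0, b0 = schedule(0.0)
    a1, b1 = schedule(1.0)
    checks = [(a0, 0.0), (b1, 0.0), (a1, 1.0), (b0, 1.0)]
    return all(abs(value - target) < tol for value, target in checks)


def path(schedule, x0, x1, num_points=5):
    """Points X_t on the interpolation path for t evenly spaced in [0, 1]."""
    points = []
    for k in range(num_points):
        t = k / (num_points - 1)
        a, b = schedule(t)
        points.append([a * q + b * p for p, q in zip(x0, x1)])
    return points


x0, x1 = [1.0, 0.0], [0.0, 1.0]
for name, schedule in SCHEDULES.items():
    midpoint = path(schedule, x0, x1, num_points=3)[1]
    print(name, satisfies_boundary(schedule), [round(c, 4) for c in midpoint])
# straight True [0.5, 0.5]
# spherical True [0.7071, 0.7071]
```

A schedule such as `lambda t: (t, t)` would fail `satisfies_boundary`, because it does not start at $X_0$.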

In [None]:
import torch
import os
import sys
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.datasets.toy_gmm import TwoPointGMM

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity

set_seed(0)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
n_samples = 50000
pi_0 = dist.MultivariateNormal(torch.zeros(2, device=device), torch.eye(2, device=device))
pi_1 = TwoPointGMM(x=15.0, y=2, std=0.3)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])

plt.scatter(D0[:, 0].cpu().numpy(), D0[:, 1].cpu().numpy(), alpha=0.5, label='D0')
plt.scatter(D1[:, 0].cpu().numpy(), D1[:, 1].cpu().numpy(), alpha=0.5, label='D1')
plt.legend()
plt.xlim(-5, 18)
plt.ylim(-5, 5)
plt.gca().set_aspect('equal', adjustable='box')
plt.show()

In [None]:
model = MLPVelocity(2, hidden_sizes=[128, 128, 128]).to(device)

rectified_flow = RectifiedFlow(
    data_shape=(2,),
    velocity_field=model,
    interp="straight",
    source_distribution=pi_0,
    device=device,
)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
batch_size = 1024

losses = []

for step in range(1000):
    optimizer.zero_grad()

    # Draw a random mini-batch of (X_0, X_1) pairs from the pre-sampled datasets.
    idx = torch.randperm(n_samples)[:batch_size]
    x_0 = D0[idx].to(device)
    x_1 = D1[idx].to(device)

    # Rectified-flow regression loss on this mini-batch.
    loss = rectified_flow.get_loss(x_0, x_1)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if step % 200 == 0:
        print(f"Step {step}, Loss: {loss.item():.4f}")

plt.plot(losses)
plt.xlabel("Step")
plt.ylabel("Loss")
plt.show()
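The loop above calls `rectified_flow.get_loss(x_0, x_1)`, whose implementation is not shown here. As a rough guide, the hypothetical sketch below estimates the objective from the velocity-field section on one mini-batch in plain Python. It assumes $t \sim \mathrm{Uniform}(0, 1)$ drawn independently per pair, the straight schedule with target $X_1 - X_0$, and a squared error summed over coordinates and averaged over the batch; the library's actual time sampling, weighting, and reduction may differ.

```python
import random


def straight_rf_loss(velocity_fn, x0_batch, x1_batch, rng):
    """Hypothetical Monte Carlo estimate of the rectified-flow objective on one mini-batch.

    Assumes t ~ Uniform(0, 1) independently per pair, the straight schedule
    X_t = t * X_1 + (1 - t) * X_0 with target X_1 - X_0, and a squared error
    summed over coordinates and averaged over the batch.
    """
    total = 0.0
    for x0, x1 in zip(x0_batch, x1_batch):
        t = rng.random()
        xt = [t * b + (1.0 - t) * a for a, b in zip(x0, x1)]
        target = [b - a for a, b in zip(x0, x1)]
        pred = velocity_fn(xt, t)
        total += sum((u - p) ** 2 for u, p in zip(target, pred))
    return total / len(x0_batch)


def zero_field(z, t):
    """A deliberately untrained velocity field that always predicts zero."""
    return [0.0 for _ in z]


rng = random.Random(0)
x0_batch = [[0.0, 0.0], [1.0, -1.0]]
x1_batch = [[15.0, 2.0], [15.0, -2.0]]
print(straight_rf_loss(zero_field, x0_batch, x1_batch, rng))  # 213.0
```

With the zero field the loss does not depend on the sampled $t$; it is simply the mean of $\|X_1 - X_0\|^2$ over the batch, $(229 + 197) / 2 = 213$.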

In [4]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

# Integrate the learned velocity field with 100 Euler steps for 500 samples.
euler_sampler_1rf_unconditional = EulerSampler(
    rectified_flow=rectified_flow,
    num_steps=100,
    num_samples=500,
)

euler_sampler_1rf_unconditional.sample_loop(seed=0)

# Length of the recorded trajectory list.
print(len(euler_sampler_1rf_unconditional.trajectories))

# Only a 1-rectified flow has been trained here, so plot its full trajectory under one key.
visualize_2d_trajectories_plotly(
    trajectories_dict={"1rf": euler_sampler_1rf_unconditional.trajectories},
    D1_gt_samples=D1[:1000],
    num_trajectories=200,
    title="Unconditional 1-Rectified Flow",
)
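`EulerSampler` discretizes the ODE $\mathrm{d}Z_t = v(Z_t, t)\,\mathrm{d}t$ from $t = 0$ to $t = 1$. The plain-Python sketch below shows the explicit Euler update $Z_{k+1} = Z_k + \Delta t \, v(Z_k, t_k)$. It is not the library's implementation, and whether the library records exactly `num_steps + 1` states is not assumed. To make the result checkable, it uses a hypothetical closed-form field: the exact rectified-flow velocity for the straight schedule when $\pi_1$ is a single point $c$, namely $v(z, t) = (c - z)/(1 - t)$.

```python
def euler_sample(velocity_fn, z0, num_steps):
    """Integrate dZ_t/dt = v(Z_t, t) from t = 0 to t = 1 with explicit Euler steps.

    Returns every state [Z_{t_0}, ..., Z_{t_N}] with t_k = k / num_steps.
    """
    dt = 1.0 / num_steps
    z = list(z0)
    trajectory = [list(z)]
    for k in range(num_steps):
        t = k * dt
        v = velocity_fn(z, t)
        z = [zi + dt * vi for zi, vi in zip(z, v)]  # Z_{k+1} = Z_k + dt * v(Z_k, t_k)
        trajectory.append(list(z))
    return trajectory


# Hypothetical field: exact straight-schedule velocity when pi_1 is the point c.
# The Euler loop never evaluates t = 1, so the division is safe.
c = [15.0, 2.0]


def point_mass_velocity(z, t):
    return [(ci - zi) / (1.0 - t) for zi, ci in zip(z, c)]


traj = euler_sample(point_mass_velocity, [0.0, 0.0], num_steps=100)
print(len(traj))                        # 101
print([round(x, 6) for x in traj[-1]])  # [15.0, 2.0]
```

Every trajectory of this field is a straight line traversed at constant speed, so Euler integration reaches $c$ up to floating-point error for any `num_steps`, even a single step. Curved trajectories generally need more steps for the same accuracy.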