In [None]:
from abc import ABC, abstractmethod
import numpy as np
import sklearn.datasets as skd
from numpy.random import PCG64
from torch import nn, Tensor
import torch as torch
from tqdm import notebook
import unittest
%reload_ext ipython_unittest

# Flow Matching
Let $q$ be a complex target distribution over $\mathbb{R}^d$. Let $p$ be a simple distribution over the same space from which we can easily sample (for example $p = \mathcal{N}(0, I)$).

### Goal
We want to construct a probability path $p_t$ with $t \in [0,1]$ such that $p_0 = p$ and $p_1 = q$, where every $p_t$ is a valid probability density function (PDF).

Given a sample $X_0 \sim p$, the goal is to estimate the corresponding $X_1 \sim q$ using an Ordinary Differential Equation (ODE).
Let $u : [0, 1] \times \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a *velocity* field and $\psi : [0, 1] \times \mathbb{R}^d \rightarrow \mathbb{R}^d$ its corresponding *flow*, given by the ODE:
$$
\frac{d}{dt}\psi_t(x) = u_t(\psi_t(x))
$$
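where $\psi_t(x) = \psi(t, x)$ and $\psi_0(x) = x$. Integrating this ODE from $t = 0$ to $t = 1$ maps a sample $X_0 \sim p$ to $X_1 = \psi_1(X_0)$, which should be distributed according to $q$.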


In [None]:
def euler(x_t, h, direction):
    """Take one explicit Euler step: returns x_t + h * direction.

    Args:
        x_t: current state.
        h: step size.
        direction: derivative (velocity) evaluated at the current state.
    """
    step = h * direction
    return x_t + step


class InterpolationMethod(ABC):
    """Abstract interpolation scheme that defines the probability path between p and q.

    Subclasses implement `apply`, which maps a time `t` and endpoint samples
    `x_0 ~ p`, `x_1 ~ q` to the intermediate random variable X_t.
    """

    @abstractmethod
    def apply(self, t: Tensor, x_0: Tensor, x_1: Tensor) -> Tensor:
        """Returns the random variable X_t.

        Declared as an instance method (not a classmethod) so the abstract signature
        matches concrete implementations such as `OptimalTransport.apply`, which are
        called on instances.
        """
        raise NotImplementedError


class OptimalTransport(InterpolationMethod):
    """Straight-line interpolation (the L2 optimal-transport path).

    X_t = (1 - t) * x_0 + t * x_1, so X_0 = x_0 and X_1 = x_1. `t` broadcasts
    against the endpoints with the usual tensor broadcasting rules.
    """

    def apply(self, t, x_0, x_1):
        weight_start = 1 - t
        return weight_start * x_0 + t * x_1


class FlowMatching(nn.Module):
    """MLP that parametrizes the velocity field v_theta(t, x).

    The network consumes the concatenation [t, x] along the last dimension, so with the
    defaults `input_dim = 3` corresponds to a scalar time plus a 2-D point. It has one
    input layer and `n_layers - 1` further hidden layers, all of width `hidden_dim`
    with GELU activations, followed by a linear head of size `output_dim`.
    """

    def __init__(self, input_dim=3, hidden_dim=256, output_dim=2, n_layers=4, *args, **kwargs):
        super().__init__(*args, **kwargs)

        modules = [nn.Linear(input_dim, hidden_dim), nn.GELU()]
        modules += [
            module
            for _ in range(n_layers - 1)
            for module in (nn.Linear(hidden_dim, hidden_dim), nn.GELU())
        ]
        modules.append(nn.Linear(hidden_dim, output_dim))

        self.network = nn.Sequential(*modules)

    def forward(self, t, x):
        """Return v_theta(t, x); `t` and `x` are concatenated on their last dimension."""
        return self.network(torch.cat([t, x], dim=-1))


def sample_from_p(n):
    """Draw `n` samples from the 2-D source distribution p = N(0, I) using the global `RNG`.

    Returns a float32 tensor of shape (n, 2).
    """
    draws = RNG.multivariate_normal(mean=np.zeros(2), cov=np.eye(2), size=n)
    return torch.from_numpy(draws).float()


def sample_from_q(n, random_state=None):
    """Draw `n` samples from the target distribution q (a transformed two-moons set).

    Two-moons points from `sklearn.datasets.make_moons` are right-multiplied by the
    matrix [[2c, -s], [s, 2c]] (c = cos 45°, s = sin 45°) and then shifted by the
    global `SHIFT_TARGET_BY`.

    Note: despite the 45° angle, this is not a 45° rotation. A matrix of the form
    [[a, -b], [b, a]] is a rotation by atan(b / a) combined with uniform scaling by
    sqrt(a**2 + b**2). Here that is ≈26.6° and ≈1.58; right-multiplying row vectors
    applies the rotation clockwise.

    Args:
        n: number of samples.
        random_state: forwarded to `make_moons`. The default `None` keeps the original
            behaviour of drawing from NumPy's global random state; pass an int or a
            `np.random.RandomState` for reproducible targets.

    Returns:
        float32 tensor of shape (n, 2).
    """
    moons = torch.from_numpy(
        skd.make_moons(n_samples=n, noise=0.005, random_state=random_state)[0]
    ).float()
    angle_rad = 45 * np.pi / 180  # Convert angle to radians
    c, s = np.cos(angle_rad), np.sin(angle_rad)

    # Kept exactly as-is: changing these entries would change the target distribution.
    transform = torch.tensor([[2 * c, -s],
                              [s, 2 * c]], dtype=torch.float32)
    transformed = moons @ transform
    return transformed + torch.tensor(SHIFT_TARGET_BY)

In [None]:
%%unittest_main

# Test for the optimal transport.


class TestOptimalTransport(unittest.TestCase):
    """Checks X_t = (1 - t) * x_0 + t * x_1 for the OptimalTransport interpolation."""

    def test_simple(self):
        """1-D inputs with a per-element t: X_t walks linearly from 0 to 1."""
        x_0 = Tensor((0, 0, 0, 0, 0))
        x_1 = Tensor((1, 1, 1, 1, 1))
        t = Tensor((0, 0.25, 0.5, 0.75, 1))
        interpolation = OptimalTransport()
        actual = interpolation.apply(t, x_0, x_1)
        self.assertTrue(torch.equal(actual, Tensor([0, 0.25, 0.5, 0.75, 1])))

    def test_2d(self):
        """A 1-D t broadcasts along the last axis of 2-D endpoints."""
        x_0 = Tensor([[0, 0, 0, 0, 0],
                      [0, 0, 0, 0, 0]])
        x_1 = Tensor([[1, 1, 1, 1, 1],
                      [2, 2, 2, 2, 2]])
        t = Tensor([0, 0.25, 0.5, 0.75, 1])
        interpolation = OptimalTransport()
        actual = interpolation.apply(t, x_0, x_1)
        expected = Tensor([[0.0, 0.25, 0.5, 0.75, 1.0],
                           [0.0, 0.5, 1.0, 1.5, 2.0]])
        self.assertTrue(torch.equal(actual, expected))

    def test_batch_column_time(self):
        """A column t of shape (batch, 1) broadcasts over features, like `torch.rand(BATCH_SIZE, 1)` in training."""
        x_0 = torch.zeros(2, 2)
        x_1 = Tensor([[2, 4],
                      [6, 8]])
        t = Tensor([[0.0],
                    [0.5]])
        actual = OptimalTransport().apply(t, x_0, x_1)
        expected = Tensor([[0, 0],
                           [3, 4]])
        self.assertTrue(torch.equal(actual, expected))

In [None]:
SHIFT_TARGET_BY = [4, 4]  # translation added to every target sample in sample_from_q
SOLVER = euler  # ODE step function with signature (x_t, h, direction)

PATH_DESIGN = OptimalTransport()  # interpolation X_t between p and q used to build training pairs

SEED = 44
RNG = np.random.Generator(PCG64(seed=SEED))  # drives sample_from_p
# sample_from_q (sklearn make_moons without random_state), torch.rand / torch.randn and
# model weight initialisation draw from the global NumPy / torch generators instead of RNG;
# seed them as well so Restart & Run All reproduces the same results.
np.random.seed(SEED)
torch.manual_seed(SEED)
N_TEST_SAMPLES = 100


In [None]:
EPOCHS = 10000
BATCH_SIZE = 2000
LR = 1e-3

model = FlowMatching()  # Simple MLP approximating the velocity field v_theta(t, x).
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
loss = nn.MSELoss()

# Conditional flow-matching objective: regress v_theta(t, X_t) onto the velocity of the
# straight-line path, d/dt[(1 - t) x_0 + t x_1] = x_1 - x_0.
# NOTE: that regression target is specific to the linear (OptimalTransport) path; if
# PATH_DESIGN is switched to another interpolation, `difference` must be updated too.
for epoch in notebook.tqdm(range(EPOCHS)):
    x_0 = sample_from_p(BATCH_SIZE)
    x_1 = sample_from_q(BATCH_SIZE)
    t = torch.rand(BATCH_SIZE, 1)  # One t ~ U[0, 1) per sample, broadcast over coordinates.

    x_t = PATH_DESIGN.apply(t, x_0, x_1)  # Point on the configured probability path.

    # Get velocity prediction
    v_predicted = model(t, x_t)
    difference = x_1 - x_0

    optimizer.zero_grad()
    loss(v_predicted, difference).backward()
    optimizer.step()

In [None]:
trajectories = []  # state after each step (the initial sample itself is not stored)
INFERENCE_STEPS = 50
test_data = sample_from_p(1)
x_t = test_data
delta_t = 1 / INFERENCE_STEPS

# Integrate the learned ODE dx/dt = v_theta(t, x) from t = 0 to t = 1 for an unseen noise
# sample, using the configured SOLVER and the left-endpoint time of each step.
with torch.no_grad():
    for current_step in range(INFERENCE_STEPS):
        t = delta_t * current_step
        t_tensor = torch.full((x_t.shape[0], 1), t)
        v_pred = model(t_tensor.float(), x_t.float())
        x_t = SOLVER(x_t, delta_t, v_pred)
        trajectories.append(x_t)

## Visualizations (the following cells were mostly generated with the help of generative AI)

In [None]:
# Alias the trained flow-matching network under the name the visualisation cells use.
fm_model: nn.Module = model

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

# Draw N_SAMPLES points from each distribution for a static side-by-side view.
N_SAMPLES = 2000
initial_noise_samples = sample_from_p(N_SAMPLES).numpy()
target_data_samples = sample_from_q(N_SAMPLES).numpy()

fig, ax = plt.subplots(figsize=(8, 8))
ax.set_title("Noise and Target Distributions", fontsize=16)
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("y", fontsize=12)

# Source p (red) and target q (blue), drawn with identical marker styling.
ax.scatter(initial_noise_samples[:, 0], initial_noise_samples[:, 1],
           color='red', alpha=0.3, s=20, label=r'p ~ $\mathcal{N}\:(0, \text{I})$')
ax.scatter(target_data_samples[:, 0], target_data_samples[:, 1],
           color='blue', alpha=0.3, s=20, label=r'$q$')

# Fixed [-4, 8] window with a tick every 2 units, matching the later trajectory plots.
ax.set_xlim(-4, 8)
ax.set_ylim(-4, 8)
ticks = np.arange(-4, 9, 2)
ax.set_xticks(ticks)
ax.set_yticks(ticks)
ax.grid(True, which='both', linestyle='--', linewidth=0.5)

# set_aspect keeps the data limits above while forcing equal x/y scaling
# (plt.axis('equal') would instead adjust the limits).
ax.set_aspect('equal', adjustable='box')

ax.legend(loc='upper left', fontsize=12)
fig.savefig("noise_and_target.png", dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# This cell contains all the necessary components for the diffusion model.
import torch.nn as nn
import numpy as np
from tqdm.notebook import tqdm


class DiffusionScheduler:
    """Linear-beta DDPM noise schedule and its forward (noising) process.

    Args:
        num_timesteps: number of discrete diffusion steps T.
        beta_start, beta_end: endpoints of the linearly spaced beta_t schedule.
    """

    def __init__(self, num_timesteps=1000, beta_start=0.0001, beta_end=0.02):
        self.num_timesteps = num_timesteps
        self.betas = torch.linspace(beta_start, beta_end, num_timesteps)
        self.alphas = 1. - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)  # \bar{alpha}_t
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)

    def add_noise(self, x_start, t, noise):
        """Return x_t = sqrt(abar_t) * x_start + sqrt(1 - abar_t) * noise.

        Args:
            x_start: clean samples with the batch on the first dimension.
            t: integer timestep per sample, shape (batch,).
            noise: tensor with the same shape as x_start.
        """
        # Reshape the per-sample coefficients to (batch, 1, ..., 1) so they broadcast over
        # every non-batch dimension of x_start (a fixed view(-1, 1) only fits 2-D input).
        coeff_shape = (-1,) + (1,) * (x_start.dim() - 1)
        sqrt_alpha_t = self.sqrt_alphas_cumprod[t].view(coeff_shape)
        sqrt_one_minus_alpha_t = self.sqrt_one_minus_alphas_cumprod[t].view(coeff_shape)
        return sqrt_alpha_t * x_start + sqrt_one_minus_alpha_t * noise


# Noise-prediction network for the diffusion baseline.

class DenoiseModel(nn.Module):
    """MLP that predicts the noise added to x at diffusion step t (epsilon-prediction).

    The timestep is encoded with a sinusoidal embedding of size `time_emb_dim`, projected
    by `time_mlp` to `hidden_dim`, concatenated with x and passed through `network`.
    """

    def __init__(self, input_dim=2, time_emb_dim=32, hidden_dim=256):
        super().__init__()
        self.time_mlp = nn.Sequential(nn.Linear(time_emb_dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, hidden_dim))
        self.network = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, input_dim)
        )
        self.time_emb_dim = time_emb_dim

    def _sinusoidal_embedding(self, t):
        """Encode t (shape (batch,)) as [sin(t * w_k), cos(t * w_k)] with geometrically spaced w_k."""
        half_dim = self.time_emb_dim // 2
        scale = np.log(10000) / (half_dim - 1)
        # Create the frequencies on t's device so the model also works off-CPU.
        freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
        angles = t[:, None] * freqs[None, :]
        return torch.cat((angles.sin(), angles.cos()), dim=1)

    def forward(self, x, t):
        """Predict the noise for noisy samples x at timesteps t (shape (batch,))."""
        time_emb = self.time_mlp(self._sinusoidal_embedding(t))
        xt_input = torch.cat([x, time_emb], dim=-1)
        return self.network(xt_input)


# Training loop for the diffusion baseline (DDPM epsilon-prediction objective).
def train_diffusion_model(epochs=10000, batch_size=2048, lr=1e-3):
    """Train a DenoiseModel on samples from q and return it with its scheduler.

    Each step draws a fresh batch from `sample_from_q`, picks a random timestep per
    sample, noises the batch with the scheduler's forward process, and regresses the
    model's output onto the injected noise with an MSE loss.

    Args:
        epochs: number of optimisation steps (one fresh batch per step).
        batch_size: samples per step.
        lr: Adam learning rate (default equals the previously hard-coded value).

    Returns:
        (model, scheduler): the trained DenoiseModel and the DiffusionScheduler used.
    """
    print("--- Training Diffusion Model ---")
    model = DenoiseModel()
    scheduler = DiffusionScheduler()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()

    for _ in notebook.tqdm(range(epochs), desc="Training DM"):
        x_start = sample_from_q(batch_size)
        t = torch.randint(0, scheduler.num_timesteps, (batch_size,))
        noise = torch.randn_like(x_start)
        x_noisy = scheduler.add_noise(x_start, t, noise)
        noise_pred = model(x_noisy, t.float())
        loss = loss_fn(noise_pred, noise)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print("Diffusion training complete!\n")
    return model, scheduler

In [None]:
# Train the diffusion baseline with 3x the default epoch count of train_diffusion_model.
DIFFUSION_EPOCHS = 30000
diff_model, diff_scheduler = train_diffusion_model(DIFFUSION_EPOCHS)

In [None]:
# This function generates the high-quality "noise-to-data" (denoising) trajectory for the diffusion model.

@torch.no_grad()
def generate_diffusion_generative_trajectory(model, scheduler, initial_noise, num_frames):
    """Run the full DDPM reverse (denoising) chain and keep `num_frames` evenly spaced snapshots.

    Starting from `initial_noise`, every timestep from T-1 down to 0 applies the update
        x <- 1/sqrt(alpha_t) * (x - beta_t / sqrt(1 - abar_t) * eps_theta(x, t)) + sqrt(beta_t) * z
    with z ~ N(0, I). No noise is added on the final (t = 0) step.

    Returns:
        Tensor stacking `num_frames` states chosen from the T + 1 recorded ones
        (the initial noise plus one state per denoising step).
    """
    batch = initial_noise.shape[0]
    state = initial_noise.clone()
    history = [state.clone()]

    for step in reversed(range(scheduler.num_timesteps)):
        step_idx = torch.full((batch,), step, dtype=torch.long)
        eps = model(state, step_idx.float())

        alpha = scheduler.alphas[step_idx].view(-1, 1)
        alpha_bar = scheduler.alphas_cumprod[step_idx].view(-1, 1)
        beta = scheduler.betas[step_idx].view(-1, 1)

        denoised_mean = (1 / torch.sqrt(alpha)) * (state - (beta / torch.sqrt(1 - alpha_bar)) * eps)

        if step == 0:
            state = denoised_mean
        else:
            state = denoised_mean + torch.sqrt(beta) * torch.randn_like(state)

        history.append(state.clone())

    keep = np.linspace(0, len(history) - 1, num_frames, dtype=int)
    return torch.stack([history[k] for k in keep])

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import os

# --- 1. Setup the Simulation ---
# Hand-rolled N_STEPS-step Euler integration of the trained flow `fm_model` from one fixed
# start point; one PNG per step (plus one for the final state) is written to `output_folder`.
N_STEPS = 5
DT = 1.0 / N_STEPS  # Euler step size; N_STEPS steps cover t in [0, 1]

# Set a fixed starting point (one sample with two coordinates)
x_0_single = torch.tensor([[2.123, -1.9]])

# --- Create a directory to save the plots ---
output_folder = 'trajectory_steps'
os.makedirs(output_folder, exist_ok=True)

# --- 2. Run the N_STEPS-step generation process ---
x_current = x_0_single.clone()
x_history = [x_current.numpy()]  # positions at t = 0, DT, ..., 1 (N_STEPS + 1 entries)
v_history = []  # predicted velocity per step; NOTE(review): not read later in this cell
displacement_history = []  # v * DT per step, drawn as the red "next step" arrow

for i in range(N_STEPS):
    t_tensor = torch.tensor([[i * DT]])  # left-endpoint time of this step
    with torch.no_grad():
        v_current = fm_model(t_tensor, x_current)

    # Explicit Euler update: X_{t+DT} = X_t + v(t, X_t) * DT
    displacement = v_current * DT
    x_next = x_current + displacement

    v_history.append(v_current.numpy())
    displacement_history.append(displacement.numpy())
    x_history.append(x_next.numpy())

    x_current = x_next

# --- 3. Generate and Save Each Plot Individually ---
q_samples = sample_from_q(500).numpy()  # faint target backdrop shared by every plot
full_path = np.concatenate(x_history)  # all visited positions, one row per time point

# N_STEPS + 1 plots: one per intermediate step plus one for the final state
for i in range(N_STEPS + 1):
    # Create a new figure and axes for each individual plot
    fig, ax = plt.subplots(figsize=(8, 8))

    # Common plot configuration
    ax.set_xlim([-4, 8])
    ax.set_ylim([-4, 8])
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.scatter(q_samples[:, 0], q_samples[:, 1], alpha=0.1, color='blue')

    # Logic for intermediate steps vs. the final plot
    if i < N_STEPS:
        # Intermediate state at t = i * DT, together with the upcoming Euler step
        current_time = i * DT
        ax.set_title(f"At t = {current_time:.2f} (Planning Step {i + 1})", fontsize=14)

        # Plot trajectory so far, current position, and next step vector
        path_so_far = full_path[:i + 1]
        # At i == 0 the "path" is a single point, so it gets no legend entry
        traj_label = 'Trajectory So Far' if i > 0 else None
        ax.plot(path_so_far[:, 0], path_so_far[:, 1], 'k--', alpha=0.6, linewidth=1.5, label=traj_label)

        x_i = x_history[i].flatten()
        disp_i = displacement_history[i].flatten()
        ax.scatter(x_i[0], x_i[1], c='black', s=80, zorder=10, label=r'Current Position ($X_t$)')
        ax.quiver(x_i[0], x_i[1], disp_i[0], disp_i[1], color='red', scale_units='xy', scale=1, width=0.005,
                  label=r'Next Step ($v^{\theta}_t \cdot \Delta{t}$)')
    else:
        # This is the final plot (t=1.0)
        ax.set_title("Final Position (t = 1.00)", fontsize=14)

        # Plot full trajectory and final position
        ax.plot(full_path[:, 0], full_path[:, 1], 'k--', alpha=0.6, linewidth=1.5, label='Full Trajectory')
        x_final = x_history[-1].flatten()
        ax.scatter(x_final[0], x_final[1], c='blue', s=100, zorder=10, edgecolors='black',
                   label=r'Final Position ($X_1$)')

    ax.legend(loc='upper left')

    # Save the figure to the specified folder as step_1.png ... step_{N_STEPS + 1}.png
    filename = os.path.join(output_folder, f'step_{i + 1}.png')
    plt.savefig(filename, dpi=150, bbox_inches='tight')
    plt.close(fig)  # Close the figure to free up memory and prevent it from displaying in the notebook

print(f"✅ Saved {N_STEPS + 1} plots to the '{output_folder}' directory.")

In [None]:
# SINGLE FLOW MATCHING ANIMATION

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import display, Markdown

# --- 1. Global Configuration ---
N_ANIMATION_SAMPLES = 500  # particles simulated per animation
FRAMES = 120  # stored states per trajectory, including the starting noise
FILENAME_SINGLE = 'flow_matching_scaled_field_alt.mp4'
PLOT_MIN = -4  # shared axis limits for every animation panel
PLOT_MAX = 8
GRID_DENSITY = 18  # vector-field arrows per axis (GRID_DENSITY x GRID_DENSITY grid)


# --- 2. Trajectory Generation ---
@torch.no_grad()
def generate_flow_trajectory(model, initial_noise, num_frames):
    """Simulate particle flow from t=0 to t=1 with explicit Euler steps and return every frame.

    Args:
        model: velocity field called as model(t, x), with t of shape (batch, 1).
        initial_noise: starting samples X_0, shape (batch, dim).
        num_frames: number of stored states including X_0; the ODE is integrated with
            num_frames - 1 equal steps of size 1 / (num_frames - 1).

    Returns:
        numpy array of shape (num_frames, batch, dim).

    Raises:
        ValueError: if num_frames < 1.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    # Size the buffer from the input instead of hard-coding a 2-D data dimension.
    trajectory = torch.zeros((num_frames, *initial_noise.shape))
    x_t = initial_noise.clone()
    trajectory[0] = x_t
    if num_frames == 1:
        # Nothing to integrate; also avoids a 1 / 0 step size.
        return trajectory.numpy()

    dt = 1.0 / (num_frames - 1)
    for i in range(num_frames - 1):
        t = torch.full((initial_noise.shape[0], 1), i * dt)
        velocity = model(t.float(), x_t.float())
        x_t += velocity * dt  # Simple Euler integration (left-endpoint time)
        trajectory[i + 1] = x_t
    return trajectory.numpy()


# --- 3. Animation Helper Functions ---
def setup_animation_axis(ax, target_samples, particle_color, particle_label):
    """Prepare `ax` as an animation panel and return its per-frame artists.

    Draws the static target cloud, then creates an empty particle scatter and an empty
    title that the animation's update function fills in frame by frame.

    Returns:
        (scatter, title): the particle scatter artist and the title Text artist.
    """
    # Static framing shared by every animation panel.
    ax.set_xlim(PLOT_MIN, PLOT_MAX)
    ax.set_ylim(PLOT_MIN, PLOT_MAX)
    ax.set_aspect('equal', adjustable='box')
    ax.grid(True, linestyle='--', alpha=0.6)
    ax.set_xlabel("x coordinate", fontsize=12)
    ax.set_ylabel("y coordinate", fontsize=12)

    target_x, target_y = target_samples[:, 0], target_samples[:, 1]
    ax.scatter(target_x, target_y, alpha=0.15, s=15, color='blue', label='Target')

    particles = ax.scatter([], [], alpha=0.7, s=15, color=particle_color, label=particle_label)
    title_artist = ax.set_title('', fontsize=14)
    ax.legend(loc='upper left')
    return particles, title_artist


def setup_vector_field(ax, model, dt):
    """Draw the t = 0 velocity field on a regular grid and return (quiver, grid_points).

    The grid spans [PLOT_MIN, PLOT_MAX] on both axes with GRID_DENSITY points per axis.
    Arrows show v(0, x) * dt, drawn in data units (scale_units='xy', scale=1).
    """
    x_grid = torch.linspace(PLOT_MIN, PLOT_MAX, GRID_DENSITY)
    y_grid = torch.linspace(PLOT_MIN, PLOT_MAX, GRID_DENSITY)
    X_grid, Y_grid = torch.meshgrid(x_grid, y_grid, indexing='xy')
    grid_points = torch.stack([X_grid.reshape(-1), Y_grid.reshape(-1)], dim=-1).float()
    t0 = torch.full((grid_points.shape[0], 1), 0.0)
    # Pure inference: skip autograd instead of building a graph only to detach it.
    with torch.no_grad():
        v_field_0 = model(t0, grid_points) * dt
    quiver = ax.quiver(
        grid_points[:, 0], grid_points[:, 1], v_field_0[:, 0], v_field_0[:, 1],
        color='green', alpha=0.2, scale_units='xy', scale=1,
        width=0.003, headwidth=5, headlength=5
    )
    return quiver, grid_points


def update_vector_field(quiver, model, grid_points, t, dt):
    """Redraw the quiver arrows as v(t, x) * dt evaluated at every grid point.

    Args:
        quiver: quiver artist whose U/V components are replaced via set_UVC.
        model: velocity model called as model(t, x) with t of shape (num_points, 1).
        grid_points: arrow anchor positions, one row per point.
        t: scalar time.
        dt: factor applied to the velocities before drawing.
    """
    t_tensor = torch.full((grid_points.shape[0], 1), t)
    # Pure inference, run every frame: skip autograd bookkeeping entirely.
    with torch.no_grad():
        v_field = model(t_tensor.float(), grid_points.float()) * dt
    quiver.set_UVC(v_field[:, 0], v_field[:, 1])


def save_animation_to_file(fig, update_func, num_frames, filename):
    """Render `num_frames` frames of `update_func` on `fig` and save them to `filename` via ffmpeg.

    Saving is best-effort: any Exception raised while rendering or encoding is printed
    and swallowed so the rest of the notebook keeps running. The figure is always
    closed, even if saving is interrupted (e.g. KeyboardInterrupt).
    """
    print(f"Creating animation for {filename}... This may take a moment. ⏳")
    ani = animation.FuncAnimation(fig, update_func, frames=num_frames, blit=True, interval=50)
    try:
        ani.save(filename, writer='ffmpeg', fps=30, dpi=120)
        print(f"Animation saved successfully as '{filename}' ✅")
    except Exception as e:
        # NOTE: this also catches errors raised by update_func itself, not only a missing
        # ffmpeg, so check the printed message before following the hint.
        print(f"Error saving animation: {e}\nPlease ensure ffmpeg is installed.")
    finally:
        plt.close(fig)


# --- 4. Main Function for Single Animation ---
def create_single_animation(model, trajectory, target_samples, filename):
    """Animate one flow-matching trajectory together with its time-dependent vector field.

    Args:
        model: velocity model, used to redraw the vector field each frame.
        trajectory: per-frame particle positions (e.g. the output of generate_flow_trajectory).
        target_samples: samples of q drawn as a static backdrop.
        filename: output video path.
    """
    fig, ax = plt.subplots(figsize=(8, 8))
    step = 1.0 / (len(trajectory) - 1)
    particles, title = setup_animation_axis(ax, target_samples, 'red', r'$p_t(x)$ Samples')
    field, grid = setup_vector_field(ax, model, step)

    def draw_frame(frame_idx):
        t_now = frame_idx * step
        particles.set_offsets(trajectory[frame_idx])
        title.set_text(f'Flow Matching Transformation\nTime t = {t_now:.2f}')
        update_vector_field(field, model, grid, t_now, step)
        return particles, title, field

    save_animation_to_file(fig, draw_frame, num_frames=len(trajectory), filename=filename)

# --- 5. Execution for Single Animation ---
# `fm_model` is the trained flow-matching network (an alias of `model`); it is the same
# object the comparison animations use.
x_0_samples = sample_from_p(N_ANIMATION_SAMPLES)
x_1_samples = sample_from_q(N_ANIMATION_SAMPLES)
fm_trajectory = generate_flow_trajectory(fm_model, x_0_samples, num_frames=FRAMES)
create_single_animation(fm_model, fm_trajectory, x_1_samples, FILENAME_SINGLE)
display(Markdown(f"### [Download Animation]({FILENAME_SINGLE})"))

In [None]:
# --- 1. Reusable helper for holding the final frame ---
def add_pause_frames(trajectory, num_pause_frames):
    """Append `num_pause_frames` copies of the last frame so an animation lingers on its end state.

    Works for arrays of any rank whose first axis is time, e.g. (frames, batch, 2).
    """
    last_frame = trajectory[-1:]  # slice (not index) keeps the time axis with length 1
    pause_frames = np.repeat(last_frame, num_pause_frames, axis=0)
    return np.concatenate([trajectory, pause_frames], axis=0)


# --- 2. Main Comparison Animation Function ---
def create_comparison_animation(fm_model, fm_traj, dm_traj, target_samples, diff_scheduler, filename,
                                num_original_frames):
    """Creates a side-by-side animation comparing Flow Matching (left) and Diffusion (right).

    Args:
        fm_model: flow-matching velocity model, used to redraw the vector field.
        fm_traj, dm_traj: per-frame particle positions; dm_traj must have at least as many
            frames as fm_traj because both are indexed with the same frame number.
            Frames beyond `num_original_frames` (e.g. appended by add_pause_frames) reuse
            the labels and vector field of the last original frame.
        target_samples: samples of q shown in both panels.
        diff_scheduler: provides num_timesteps for the diffusion step labels.
        filename: output video path.
        num_original_frames: number of frames before any pause frames.

    Raises:
        ValueError: if dm_traj has fewer frames than fm_traj.
    """
    # Validate before drawing anything: an IndexError inside update() would otherwise be
    # caught by save_animation_to_file and reported as an ffmpeg problem.
    if len(dm_traj) < len(fm_traj):
        raise ValueError(
            f"dm_traj has {len(dm_traj)} frames but fm_traj has {len(fm_traj)}; "
            "both are indexed with the same frame number"
        )

    fig, (ax_fm, ax_dm) = plt.subplots(1, 2, figsize=(16, 8))
    fig.subplots_adjust(wspace=0.25)

    dt_fm = 1.0 / (num_original_frames - 1)
    # Diffusion step label for each original frame (T-1 down to 0); loop-invariant, so built once.
    dm_step_labels = np.linspace(diff_scheduler.num_timesteps - 1, 0, num_original_frames)

    # Setup Flow Matching (left) plot using helpers
    scatter_fm, title_fm = setup_animation_axis(ax_fm, target_samples, 'red', 'FM Samples')
    quiver_fm, grid_points_fm = setup_vector_field(ax_fm, fm_model, dt_fm)

    # Setup Diffusion (right) plot using helpers
    scatter_dm, title_dm = setup_animation_axis(ax_dm, target_samples, 'purple', 'DM Samples')

    def update(frame):
        # Pause frames past the original length keep the last original frame's labels/field.
        effective_frame = min(frame, num_original_frames - 1)

        # --- Update Flow Matching Plot (Left) ---
        current_t_fm = effective_frame * dt_fm
        scatter_fm.set_offsets(fm_traj[frame])
        title_fm.set_text(f'Flow Matching\nTime t = {current_t_fm:.2f}')
        update_vector_field(quiver_fm, fm_model, grid_points_fm, current_t_fm, dt_fm)

        # --- Update Diffusion Plot (Right) ---
        current_t_dm = int(dm_step_labels[effective_frame])
        scatter_dm.set_offsets(dm_traj[frame])
        title_dm.set_text(f'Diffusion\nStep t = {current_t_dm}')

        return scatter_fm, title_fm, quiver_fm, scatter_dm, title_dm

    # Generate and save using the helper
    save_animation_to_file(fig, update, num_frames=len(fm_traj), filename=filename)


# --- 3. Execution for Comparison Animation ---
FILENAME_COMPARISON = 'fm_vs_diffusion_generative.mp4'

# Both models start from the very same noise samples so the comparison is fair.
initial_noise = sample_from_p(N_ANIMATION_SAMPLES)
target_samples_viz = sample_from_q(N_ANIMATION_SAMPLES)

print("Generating trajectories for comparison animation...")
fm_trajectory = generate_flow_trajectory(fm_model, initial_noise, FRAMES)
dm_trajectory = generate_diffusion_generative_trajectory(
    diff_model, diff_scheduler, initial_noise, FRAMES
).numpy()

create_comparison_animation(
    fm_model, fm_trajectory, dm_trajectory, target_samples_viz,
    diff_scheduler, FILENAME_COMPARISON, num_original_frames=FRAMES,
)
display(Markdown(f"### [Download Animation]({FILENAME_COMPARISON})"))

In [None]:
import matplotlib.gridspec as gridspec

# --- 1. New Helper Functions for Path Tracing ---
def setup_trajectory_traces(ax, num_samples, color):
    """Create `num_samples` empty, faint line artists on `ax` (one per particle) for path tracing."""
    traces = []
    for _ in range(num_samples):
        line = ax.plot([], [], color=color, alpha=0.25, linewidth=0.7)[0]
        traces.append(line)
    return traces

def update_trajectory_traces(lines, full_trajectory, current_frame):
    """Point line i at particle i's path from frame 0 through `current_frame` (inclusive).

    `full_trajectory` is indexed as [frame, particle, coordinate].
    """
    history = full_trajectory[:current_frame + 1]
    for particle_idx, line in enumerate(lines):
        line.set_data(history[:, particle_idx, 0], history[:, particle_idx, 1])

# --- 2. Main Animation Function with Tracing ---
def create_comparison_animation_with_traces(fm_model, fm_traj, dm_traj, target_samples, diff_scheduler, filename):
    """Creates a side-by-side animation that traces each particle's path.

    Left panel: Flow Matching particles, their accumulated paths and the time-dependent
    vector field. Right panel: Diffusion particles and their accumulated paths.

    Args:
        fm_model: flow-matching velocity model, used to redraw the vector field.
        fm_traj, dm_traj: arrays indexed as [frame, particle, coordinate]; dm_traj must
            have at least as many frames and particles as fm_traj.
        target_samples: samples of q shown in both panels.
        diff_scheduler: provides num_timesteps for the diffusion step labels.
        filename: output video path.

    Raises:
        ValueError: if dm_traj has fewer frames or particles than fm_traj.
    """
    num_frames = fm_traj.shape[0]
    num_samples = fm_traj.shape[1]
    # Validate before drawing anything: an IndexError inside update() would otherwise be
    # caught by save_animation_to_file and reported as an ffmpeg problem.
    if dm_traj.shape[0] < num_frames or dm_traj.shape[1] < num_samples:
        raise ValueError(
            f"dm_traj has shape {tuple(dm_traj.shape)} but needs at least "
            f"{num_frames} frames and {num_samples} particles to match fm_traj"
        )

    fig, (ax_fm, ax_dm) = plt.subplots(1, 2, figsize=(16, 8), sharex=True, sharey=True)
    fig.subplots_adjust(wspace=0.25)

    dt_fm = 1.0 / (num_frames - 1)
    # Diffusion step label per frame (T-1 down to 0); loop-invariant, so built once.
    dm_step_labels = np.linspace(diff_scheduler.num_timesteps - 1, 0, num_frames)

    # Setup Flow Matching (left) plot using helpers
    scatter_fm, title_fm = setup_animation_axis(ax_fm, target_samples, 'red', 'Samples')
    quiver_fm, grid_points_fm = setup_vector_field(ax_fm, fm_model, dt_fm)
    fm_lines = setup_trajectory_traces(ax_fm, num_samples, 'red')

    # Setup Diffusion (right) plot using helpers
    scatter_dm, title_dm = setup_animation_axis(ax_dm, target_samples, 'purple', 'Samples')
    dm_lines = setup_trajectory_traces(ax_dm, num_samples, 'purple')

    def update(frame):
        # --- Update Titles and Scatter Positions ---
        current_t_fm = frame * dt_fm
        title_fm.set_text(f'Flow Matching\nTime t = {current_t_fm:.2f}')
        scatter_fm.set_offsets(fm_traj[frame])

        current_t_dm = int(dm_step_labels[frame])
        title_dm.set_text(f'Diffusion\nStep t = {current_t_dm}')
        scatter_dm.set_offsets(dm_traj[frame])

        # --- Update Vector Field and Traces ---
        update_vector_field(quiver_fm, fm_model, grid_points_fm, current_t_fm, dt_fm)
        update_trajectory_traces(fm_lines, fm_traj, frame)
        update_trajectory_traces(dm_lines, dm_traj, frame)

        # Return all artists that have been updated for blitting
        return (scatter_fm, title_fm, quiver_fm, scatter_dm, title_dm, *fm_lines, *dm_lines)

    # Generate and save using the main helper
    save_animation_to_file(fig, update, num_frames=num_frames, filename=filename)

# --- 3. Main Execution for Traced Comparison ---
FILENAME_TRACED = 'fm_vs_diffusion_traced.mp4'

# Fresh shared noise: every particle starts from the same X_0 in both models.
print("\nGenerating trajectories for traced animation...")
initial_noise = sample_from_p(N_ANIMATION_SAMPLES)
target_samples_viz = sample_from_q(N_ANIMATION_SAMPLES)

fm_trajectory = generate_flow_trajectory(fm_model, initial_noise, FRAMES)
dm_trajectory = generate_diffusion_generative_trajectory(
    diff_model, diff_scheduler, initial_noise, FRAMES
).numpy()

create_comparison_animation_with_traces(
    fm_model, fm_trajectory, dm_trajectory, target_samples_viz,
    diff_scheduler, FILENAME_TRACED,
)
display(Markdown(f"### [Download Animation]({FILENAME_TRACED})"))

In [None]:
# Re-encode the traced animation, holding its last frame for one extra second
# (tpad clones the final frame), and overwrite the output file if it exists (-y).
!ffmpeg -y -i fm_vs_diffusion_traced.mp4 -vf tpad=stop_mode=clone:stop_duration=1 long_fm_vs_diffusion_traced.mp4