In [None]:
# Inject MathJax (SVG output) into the notebook page; the plotly figures further down use
# LaTeX trace names such as r'$\tau_t$'.
from IPython.display import HTML

MATHJAX_SRC = "https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_SVG"
HTML('<script src="' + MATHJAX_SRC + '"></script>')

In [12]:
import torch
import os
import sys
import numpy as np
import matplotlib.pyplot as plt
import warnings
import copy
import plotly.graph_objects as go

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.utils import match_dim_with_data
from rectified_flow.datasets.toy_gmm import TwoPointGMM

from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity

from rectified_flow.samplers import EulerSampler
from rectified_flow.flow_components.interpolation_convertor import AffineInterpConverter
from rectified_flow.flow_components.interpolation_solver import AffineInterp

# Run on the GPU when torch can see one; tensors and models below are moved to `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
from rectified_flow.datasets.toy_gmm import TwoPointGMM
from rectified_flow.flow_components.interpolation_solver import AffineInterp
from rectified_flow.utils import visualize_2d_trajectories_plotly

# Source pi_0: Gaussian centred at (-5, 0) with covariance 0.09 * I (std 0.3 per axis).
# Target pi_1: TwoPointGMM(x=3.37, y=3.69, std=0.3); `labels` are the per-sample labels
# returned by sample_with_labels.
set_seed(0)
n_samples = 50000
pi_0 = dist.MultivariateNormal(torch.tensor([-5., 0.], device=device), torch.eye(2, device=device) * 0.09)
pi_1 = TwoPointGMM(x=3.37, y=3.69, std=0.3)
D0 = pi_0.sample([n_samples])
D1, labels = pi_1.sample_with_labels([n_samples])

straight_interp = AffineInterp("straight")
spherical_interp = AffineInterp("spherical")

# Pair a random subset of (x_0, x_1) samples and trace both interpolations on a
# uniform time grid over [0, 1].
n_vis_pairs = 1000
n_time_points = 50

idx = torch.randperm(n_samples)[:n_vis_pairs]
x_0 = D0[idx]
x_1 = D1[idx]

print(x_0.shape)

straight_interp_list = []
spherical_interp_list = []

for t in np.linspace(0, 1, n_time_points):
    # forward returns the interpolated point and its time derivative; only the points are plotted.
    x_t_straight, dot_x_t_straight = straight_interp.forward(x_0, x_1, t)
    x_t_spherical, dot_x_t_spherical = spherical_interp.forward(x_0, x_1, t)
    straight_interp_list.append(x_t_straight)
    spherical_interp_list.append(x_t_spherical)

visualize_2d_trajectories_plotly(
    trajectories_dict={"Straight interp": straight_interp_list, "Spherical interp": spherical_interp_list},
    D1_gt_samples=D1[:5000],
    num_trajectories=50,
    title="Interpolated Trajectories Visualization",
)

In [None]:
import torch
import plotly.graph_objects as go

# Plot the time/scale matching between two affine interpolations:
# AffineInterpConverter.match_time_and_scale(source, target, t) returns (tau_t, omega_t)
# on the grid t. Change the two names below (e.g. "ddim" -> "spherical") to inspect
# another pair; the figure title and the output file name follow automatically.
source_interp_name = "straight"
target_interp_name = "spherical"
interp_display_names = {"straight": "Straight", "spherical": "Spherical", "ddim": "DDIM"}

t = torch.linspace(0, 1, 100)

title = (
    f"{interp_display_names.get(source_interp_name, source_interp_name)} to "
    f"{interp_display_names.get(target_interp_name, target_interp_name)} Transform"
)
matched_t, scaling_factor = AffineInterpConverter.match_time_and_scale(
    AffineInterp(source_interp_name),
    AffineInterp(target_interp_name),
    t,
)

# Give plotly host-side NumPy arrays rather than relying on its handling of torch tensors;
# detach/cpu make this safe whatever device or autograd state the outputs have.
t_np = t.numpy()
matched_t_np = torch.as_tensor(matched_t).detach().cpu().numpy()
scaling_factor_np = torch.as_tensor(scaling_factor).detach().cpu().numpy()

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=t_np,
        y=matched_t_np,
        mode='lines',
        name=r'$\tau_t$',
        hovertemplate=r't=%{x:.3f}, tau_t=%{y:.3f}<extra></extra>'
    )
)
fig.add_trace(
    go.Scatter(
        x=t_np,
        y=scaling_factor_np,
        mode='lines',
        name=r'$\omega_t$',
        hovertemplate=r't=%{x:.3f}, omega_t=%{y:.3f}<extra></extra>'
    )
)

fig.update_xaxes(
    range=[0, 1],
    title_text='t',
    tickmode='array',
    tickvals=[0., 0.2, 0.4, 0.6, 0.8, 1.0],
    ticktext=['0.0', '0.2', '0.4', '0.6', '0.8', '1.0'],
    showline=False,
    zeroline=False
)
# NOTE(review): the y range assumes tau_t and omega_t stay within [0, 1] for the chosen
# pair; widen it if another pair produces larger values.
fig.update_yaxes(
    range=[0, 1.01],
    title_text='Value',
    tickmode='array',
    tickvals=[0.2, 0.4, 0.6, 0.8, 1.0],
    ticktext=['0.2', '0.4', '0.6', '0.8', '1.0'],
    showline=False,
    zeroline=False,
    mirror=False
)

fig.update_layout(
    title={
        'text': title,      # title text
        'x': 0.5,           # horizontal position in paper coordinates (0..1); 0.5 = centre
        'xanchor': 'center'
    },
    autosize=True,
    legend=dict(
        x=0.80,
        y=0.20,
        xanchor='left',
        yanchor='top',
        bgcolor='rgba(255, 255, 255, 0.8)',
        bordercolor='gray',
        borderwidth=1.1,
    ),
    margin=dict(l=5, r=5, t=35, b=5)
)

fig.show()

# MathJax is loaded from the CDN so the LaTeX trace names render in the standalone file.
fig.write_html(
    f"tau_{source_interp_name}_{target_interp_name}.html",
    full_html=True,
    include_plotlyjs="cdn",
    include_mathjax='cdn',
)

In [None]:
def rf_trainer(rectified_flow, label="loss", batch_size=1024, num_steps=5000, lr=1e-2, log_every=1000):
    """Train ``rectified_flow.velocity_field`` with AdamW and plot its loss curve.

    At every step a fresh, independent minibatch is drawn from the notebook-level
    ``pi_0`` and ``pi_1``, moved to the global ``device``, and scored with
    ``rectified_flow.get_loss``. The per-step losses are drawn on the current
    matplotlib axes (so successive calls overlay on the same figure).

    Args:
        rectified_flow: flow whose ``velocity_field`` parameters are optimised in place.
        label: legend label for this loss curve.
        batch_size: number of (x_0, x_1) pairs per step.
        num_steps: number of optimisation steps (each step is one minibatch, not an epoch).
        lr: AdamW learning rate.
        log_every: print the loss every ``log_every`` steps; 0 or None disables printing.
    """
    model = rectified_flow.velocity_field
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    losses = []
    for step in range(num_steps):
        optimizer.zero_grad()
        # Independent coupling: both marginals are resampled separately every step.
        x_0 = pi_0.sample([batch_size]).to(device)
        x_1 = pi_1.sample([batch_size]).to(device)

        loss = rectified_flow.get_loss(x_0, x_1)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        if log_every and step % log_every == 0:
            print(f"Step {step}, Loss: {loss.item()}")

    plt.plot(losses, label=label)
    plt.legend()

from rectified_flow.models.toy_mlp import MLPVelocity

# Both flows use the same MLP architecture and seed, so the only difference between
# them is the interpolation they are trained with.
mlp_hidden_sizes = [64, 128, 128, 128, 64]


def build_toy_rf(interp, seed=0):
    """Build a 2-D RectifiedFlow with a freshly initialised MLPVelocity for ``interp``.

    ``set_seed(seed)`` is called immediately before the network is constructed, so
    flows built with the same seed should start from the same initial weights
    (assuming set_seed seeds torch's RNG — confirm in rectified_flow.utils).
    """
    set_seed(seed)
    return RectifiedFlow(
        data_shape=(2,),
        # A fresh list per model, matching the original separate literals.
        velocity_field=MLPVelocity(2, hidden_sizes=list(mlp_hidden_sizes)).to(device),
        interp=interp,
        source_distribution=pi_0,
        device=device,
    )


straight_rf = build_toy_rf(straight_interp)
spherical_rf = build_toy_rf(spherical_interp)

In [None]:
# Train both flows from the same seed; rf_trainer draws both loss curves on the
# current matplotlib figure so they can be compared directly.
set_seed(0)
rf_trainer(rectified_flow=straight_rf, label="straight interp")

set_seed(0)
rf_trainer(rectified_flow=spherical_rf, label="spherical interp")

plt.title("Training loss")
plt.xlabel("step")
plt.ylabel("loss")
plt.show()  # render the loss curves here, before the plotly figure below

num_samples = 250
num_steps = 100

# Sample both flows with identical sampler settings and seed.
euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps, num_samples=num_samples)
euler_sampler_straight.sample_loop(seed=0)

euler_sampler_spherical = EulerSampler(spherical_rf, num_steps=num_steps, num_samples=num_samples)
euler_sampler_spherical.sample_loop(seed=0)

# Independently trained Straight RF v.s. Spherical RF (no conversion involved in this cell).
visualize_2d_trajectories_plotly(
    trajectories_dict={
        "Straight": euler_sampler_straight.trajectories,
        "Spherical": euler_sampler_spherical.trajectories,
    },
    D1_gt_samples=D1[:1000],
    num_trajectories=50,
    title="Euler Sampler",
)

In [63]:
# Turn the trained straight RF into an RF that uses the spherical interpolation via
# AffineInterpConverter (no training happens in this cell).
target_interp = AffineInterp("spherical")
straight_to_spherical_converter = AffineInterpConverter(straight_rf, target_interp)
convert_spherical_rf = straight_to_spherical_converter.transform_rectified_flow()

In [None]:
# Sample the original straight RF and its spherical conversion from the same seed and
# measure the mean squared gap between their final samples.
# Try different num_steps, e.g. [5, 10, 50, 100, 500]
num_samples = 250
num_steps = 100

euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)
euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)

euler_sampler_converted_spherical = EulerSampler(
    convert_spherical_rf,
    num_steps=num_steps,
    num_samples=num_samples,
)
euler_sampler_converted_spherical.sample_loop(seed=0)

final_straight = euler_sampler_straight.trajectories[-1]
final_converted = euler_sampler_converted_spherical.trajectories[-1]
mse = ((final_straight - final_converted) ** 2).mean()
print(mse)

# Straight Converted to Spherical v.s. Original Straight RF — zoom in to compare.
visualize_2d_trajectories_plotly(
    trajectories_dict={
        "Straight": euler_sampler_straight.trajectories,
        "Spherical": euler_sampler_converted_spherical.trajectories,
    },
    D1_gt_samples=D1[:1000],
    num_trajectories=50,
    title="Euler Sampler:    Straight Converted to Spherical    v.s.   Straight RF",
)

In [98]:
def visualize_2d_trajectories_plotly_static(
    trajectories_dict: dict[str, list[torch.Tensor]],
    D1_gt_samples: torch.Tensor = None,
    num_trajectories: int = 50,
    markersize: int = 3,
    dimensions: list[int] = [0, 1],
    alpha_trajectories: float = 0.5,
    alpha_particles: float = 0.8,
    alpha_gt_points: float = 1.0,
    show_legend: bool = True,
    title: str = "2D Trajectories Visualization",
    x_range: "tuple[float, float] | None" = (-5.8, 4.5),
    y_range: "tuple[float, float] | None" = (-4.8, 4.8),
    output_html: "str | None" = "./convert_10step_spherical.html",
):
    """Render 2-D particle trajectories as one static plotly figure (no slider/frames).

    Traces, in drawing order:
      * "GT Data": ``D1_gt_samples`` in red (if given);
      * "Source Dist": time step 0 of the first entry of ``trajectories_dict`` in blue;
      * per entry: one line trace connecting the first ``num_trajectories`` particles
        across time, then one marker trace with every particle at every time step
        except step 0.

    Args:
        trajectories_dict: name -> list of per-time-step tensors, each indexed
            ``[particle, coordinate]``; they are stacked to ``[time, particle, coordinate]``.
        D1_gt_samples: optional ground-truth samples indexed ``[sample, coordinate]``.
        num_trajectories: number of particles (from index 0) that get a connecting line.
        markersize: marker size for all scatter traces.
        dimensions: the two coordinate indices plotted on x and y.
        alpha_trajectories: opacity of the line traces.
        alpha_particles: opacity of the per-time-step markers.
        alpha_gt_points: opacity of the GT and source markers.
        show_legend: layout-level legend switch. NOTE: every trace is added with
            ``showlegend=False``, so the legend has no entries either way.
        title: centred figure title; pass ``""`` or None to omit it.
        x_range, y_range: fixed axis ranges; None derives the range from the plotted
            points padded by 2% of their span. Defaults are the previous fixed ranges.
        output_html: path of the standalone HTML export (plotly.js from CDN); None skips it.
    """
    import torch
    import numpy as np
    import plotly.graph_objects as go

    def _to_numpy(x):
        # Host-side float32 copy, detached from any autograd graph.
        return x.clone().to(torch.float32).detach().cpu().numpy()

    def _padded_range(values):
        # [min, max] of the values padded by 2% of the span (or by 1.0 when the span is 0).
        lo, hi = float(np.min(values)), float(np.max(values))
        pad = 0.02 * (hi - lo) if hi > lo else 1.0
        return [lo - pad, hi + pad]

    dim0, dim1 = dimensions
    point_hover = '(%{x:.3f}, %{y:.3f})<extra>%{fullData.name}</extra>'

    # Per-entry styling, cycled by position in trajectories_dict. Blue-ish entries are
    # left out, presumably to avoid clashing with the blue source points.
    particle_colors = [
        "#FF69B4", "#7B68EE", "#FF8C00", "#32CD32", "#4169E1",
        "#FF4500", "#9932CC", "#ADFF2F", "#FFD700",
    ]
    trajectory_colors = [
        "#E09CAF", "#B494E1", "#E5B680", "#82C9A1", "#92BCD5",
        "#E68FA2", "#A98FC8", "#E5C389", "#A0C696",
    ]
    marker_list = [
        "circle", "cross", "x", "square", "star", "diamond",
        "triangle-up", "triangle-down", "hexagram",
    ]

    # name -> array indexed [time, particle, coordinate], in dict order.
    trajectory_data = {
        name: np.stack([_to_numpy(traj) for traj in traj_list])
        for name, traj_list in trajectories_dict.items()
    }

    fig = go.Figure()
    all_x, all_y = [], []  # every plotted coordinate, used for the automatic axis range

    # Ground-truth samples (red).
    if D1_gt_samples is not None:
        gt_points = _to_numpy(D1_gt_samples)
        all_x.extend(gt_points[:, dim0])
        all_y.extend(gt_points[:, dim1])
        fig.add_trace(
            go.Scatter(
                x=gt_points[:, dim0],
                y=gt_points[:, dim1],
                mode="markers",
                name="GT Data",
                marker=dict(size=markersize, opacity=alpha_gt_points, color="red"),
                showlegend=False,
                hovertemplate=point_hover,
            )
        )

    # Source points (blue): time step 0 of the first trajectory set, taken explicitly
    # rather than from a leftover loop variable; skipped when there are no trajectories.
    if trajectory_data:
        source_points = next(iter(trajectory_data.values()))[0]
        fig.add_trace(
            go.Scatter(
                x=source_points[:, dim0],
                y=source_points[:, dim1],
                mode="markers",
                name="Source Dist",
                marker=dict(size=markersize, opacity=alpha_gt_points, color="blue"),
                showlegend=False,
                hovertemplate=point_hover,
            )
        )

    for set_idx, (trajectory_name, xtraj) in enumerate(trajectory_data.items()):
        particle_color = particle_colors[set_idx % len(particle_colors)]
        trajectory_color = trajectory_colors[set_idx % len(trajectory_colors)]
        marker_symbol = marker_list[set_idx % len(marker_list)]

        all_x.extend(xtraj[:, :, dim0].ravel())
        all_y.extend(xtraj[:, :, dim1].ravel())

        # 1) A single line trace for the first `num_trajectories` particles; a NaN after
        #    each particle's path breaks the polyline so paths are not joined together.
        line_x, line_y = [], []
        for particle in range(min(num_trajectories, xtraj.shape[1])):
            line_x.extend(xtraj[:, particle, dim0].tolist() + [np.nan])
            line_y.extend(xtraj[:, particle, dim1].tolist() + [np.nan])

        fig.add_trace(
            go.Scatter(
                x=line_x,
                y=line_y,
                mode="lines",
                name=str(trajectory_name),
                line=dict(dash="solid", color=trajectory_color, width=1.5),
                opacity=alpha_trajectories,
                showlegend=False,
                hovertemplate='(%{x:.3f}, %{y:.3f})<extra>%{fullData.name} Traj</extra>',
            )
        )

        # 2) Markers for every particle at every time step after step 0, in one trace.
        fig.add_trace(
            go.Scatter(
                x=xtraj[1:, :, dim0].ravel(),
                y=xtraj[1:, :, dim1].ravel(),
                mode="markers",
                name=str(trajectory_name),
                marker=dict(size=markersize, color=particle_color, symbol=marker_symbol),
                opacity=alpha_particles,
                showlegend=False,
                hovertemplate=point_hover,
            )
        )

    # Axis ranges: fixed when given, otherwise derived from the data (plotly autoranges
    # if there is nothing to derive from).
    if x_range is None and all_x:
        x_range = _padded_range(all_x)
    if y_range is None and all_y:
        y_range = _padded_range(all_y)

    axis_style = dict(
        showgrid=True,
        gridcolor="white",
        gridwidth=1,
        griddash="dot",
        showticklabels=False,
        showline=False,
        zeroline=False,
        mirror=False,
    )
    xaxis_kwargs = dict(axis_style, dtick=2.0)
    if x_range is not None:
        xaxis_kwargs["range"] = list(x_range)
    yaxis_kwargs = dict(axis_style, dtick=1.0)
    if y_range is not None:
        yaxis_kwargs["range"] = list(y_range)
    fig.update_xaxes(**xaxis_kwargs)
    fig.update_yaxes(**yaxis_kwargs)

    # Static layout only: no sliders, frames or buttons.
    layout_kwargs = dict(
        margin=dict(l=10, r=10, t=10, b=10),
        showlegend=show_legend,
        autosize=True,
    )
    if title:
        layout_kwargs["title"] = {"text": title, "x": 0.5, "xanchor": "center"}
        layout_kwargs["margin"]["t"] = 40  # leave room for the title
    fig.update_layout(**layout_kwargs)

    fig.show()

    if output_html is not None:
        fig.write_html(
            output_html,
            full_html=True,
            include_plotlyjs="cdn",
        )

In [None]:
# Few-step comparison: sample the original straight RF and its spherical conversion from
# the same seed with only a handful of Euler steps.
# Try different num_steps, e.g. [5, 10, 50, 100, 500]
num_samples = 1000
num_steps = 4

euler_sampler_straight = EulerSampler(straight_rf, num_steps=num_steps)
euler_sampler_straight.sample_loop(seed=0, num_samples=num_samples)

euler_sampler_converted_spherical = EulerSampler(convert_spherical_rf, num_steps=num_steps, num_samples=num_samples)
euler_sampler_converted_spherical.sample_loop(seed=0)

# Mean squared gap between the final samples of the two samplers (same seed for both).
mse = torch.mean((euler_sampler_straight.trajectories[-1] - euler_sampler_converted_spherical.trajectories[-1])**2)
print(f"MSE(straight, converted spherical) @ {num_steps} steps: {mse.item():.3e}")

# Static export of the converted spherical RF only. To overlay the original flow, add
# "Straight RF": euler_sampler_straight.trajectories to trajectories_dict.
visualize_2d_trajectories_plotly_static(
    trajectories_dict={
        "Spherical RF": euler_sampler_converted_spherical.trajectories,
    },
    D1_gt_samples=D1[:1000],
    num_trajectories=100,
)

In [None]:
# Sweep the number of Euler steps and record, for each, the MSE between the final samples
# of the straight RF and its spherical conversion (both sampled with seed=0).
# Reuses euler_sampler_straight / euler_sampler_converted_spherical from the cell above;
# num_steps and num_samples are passed explicitly on every sample_loop call.
num_samples = 5000
mse_list = []
steps_list = []

# Step schedule, denser at small step counts:
#   1..10 in steps of 1, then 20..100 in steps of 10, then 120..500 in steps of 20.
step_schedule = [*range(1, 11), *range(20, 101, 10), *range(120, 501, 20)]

for num_steps in step_schedule:
    euler_sampler_straight.sample_loop(seed=0, num_steps=num_steps, num_samples=num_samples)
    euler_sampler_converted_spherical.sample_loop(seed=0, num_steps=num_steps, num_samples=num_samples)
    mse = torch.mean((euler_sampler_straight.trajectories[-1] - euler_sampler_converted_spherical.trajectories[-1])**2)
    mse_list.append(mse.item())  # plain Python float for plotting
    steps_list.append(num_steps)
    print(num_steps)

# The results are plotted in the next cell.

In [None]:
title = "MSE vs. Number of Steps"

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=steps_list,
        y=mse_list,
        mode='lines+markers',
        name='MSE',
        marker=dict(size=5),
        hovertemplate='( %{x} Steps, MSE=%{y:.5f} )<extra></extra>'
    )
)

# X axis: pad both ends (-10 on the left, +20 past the largest step count) so edge points
# don't touch the frame, with a tick every 100 steps. Derived from steps_list so another
# sweep still fits; for the 1..500 sweep this is range [-10, 520], ticks 100..500.
max_step = max(steps_list, default=500)
x_tick_vals = list(range(100, max_step + 1, 100))
fig.update_xaxes(
    title_text='num_steps',
    range=[-10, max_step + 20],
    tickmode='array',
    tickvals=x_tick_vals,
    ticktext=[str(v) for v in x_tick_vals],
    showline=False,
    zeroline=False
)

# Y axis on a log scale, with ticks at the decades 1e-1 .. 1e-5.
fig.update_yaxes(
    type='log',
    title_text='MSE (log scale)',
    tickmode='array',
    tickvals=[1e-1, 1e-2, 1e-3, 1e-4, 1e-5],
    ticktext=['1e-1', '1e-2', '1e-3', '1e-4', '1e-5'],
    showline=False,
    zeroline=False,
    mirror=False
)

fig.update_layout(
    title={
        'text': title,
        'x': 0.5,       # centre the title horizontally
        'xanchor': 'center'
    },
    autosize=True,
    margin=dict(l=0, r=0, t=30, b=0)
)

fig.show()

fig.write_html(
    'interp_mse_step.html',
    full_html=True,
    include_plotlyjs="cdn",
    include_mathjax='cdn',
)

In [None]:
# Spherical RF obtained by converting the trained straight RF v.s. the RF trained directly
# with the spherical interpolation, both sampled from the same seed.
# Try different num_steps, e.g. [5, 10, 50, 100, 500]
num_samples = 250
num_steps = 100
euler_sampler_converted_spherical = EulerSampler(convert_spherical_rf, num_steps=num_steps, num_samples=num_samples)
euler_sampler_converted_spherical.sample_loop(seed=0)

euler_sampler_spherical = EulerSampler(spherical_rf, num_steps=num_steps)
euler_sampler_spherical.sample_loop(seed=0, num_samples=num_samples)

mse = torch.mean((euler_sampler_spherical.trajectories[-1] - euler_sampler_converted_spherical.trajectories[-1])**2)
print(f"MSE(trained spherical, converted spherical) @ {num_steps} steps: {mse.item():.3e}")

# Zoom in to compare the two sets of trajectories.
visualize_2d_trajectories_plotly(
    trajectories_dict={
        "RF trained from spherical": euler_sampler_spherical.trajectories,
        "RF trained from straight": euler_sampler_converted_spherical.trajectories,
    },
    D1_gt_samples=D1[:5000],
    num_trajectories=50,
    title="Convert Straight to Spherical RF v.s. Spherical RF",
)