In [None]:
from IPython.display import HTML

# Inject the MathJax 2.7.7 script (TeX-AMS-MML_SVG config) from the cdnjs CDN into the
# notebook output. Presumably needed so the `$...$` labels in the Plotly figures below
# are typeset in the notebook view -- TODO confirm. Must stay the last expression of
# the cell so that the HTML object is actually displayed.
HTML('''<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js?config=TeX-AMS-MML_SVG"></script>''')

# Lady Windermere's Fan

In [None]:
import numpy as np
import plotly.graph_objects as go

def f(x):
    """Reference curve: an inverted Gaussian bump centred at x = 0.37, shifted up by 1.5.

    Evaluates ``-exp(-(x - 0.37)**2 / 2.5) + 1.5`` elementwise, so ``x`` may be a
    scalar or a NumPy array.
    """
    offset = x - 0.37
    bump = np.exp(-(offset**2 / 2.5))
    return -bump + 1.5

def f_prime(x):
    """Analytic derivative of ``f``: ``(2 / 2.5) * (x - 0.37) * exp(-(x - 0.37)**2 / 2.5)``.

    Works elementwise on scalars or NumPy arrays; it is zero at x = 0.37, where the
    Gaussian bump of ``f`` has its extremum.
    """
    offset = x - 0.37
    gaussian = np.exp(-(offset**2 / 2.5))
    return (2 / 2.5) * offset * gaussian

# Non-uniform time grid on [0, 1] at which the Euler steps are taken.
t_values = np.array([0.0, 0.126, 0.239, 0.352, 0.504, 0.695, 1.0])
# Coefficient of the feedback term -c * (y - f(t)) used in euler_method. With c < 0 the
# term has the same sign as the deviation y - f(t), so it pushes the iterate further
# away from the curve f.
c = -3.0

def euler_method(t_values, x_start, c):
    """Run explicit Euler steps along the grid points at or after ``x_start``.

    The integrated slope at a grid time ``t`` is ``f_prime(t) - c * (y - f(t))``, i.e.
    the slope of the reference curve plus a feedback term proportional to the current
    deviation from it.

    Args:
        t_values: 1-D NumPy array of grid times (boolean-mask indexing is used on it).
        x_start: start time; the initial value is ``f(x_start)``.
        c: feedback coefficient.

    Returns:
        ``(t_points, y_euler)``: the grid times ``>= x_start`` and the Euler iterates at
        those times, both as NumPy arrays. Two empty arrays are returned when no grid
        time is ``>= x_start``.
    """
    grid = t_values[t_values >= x_start]
    if len(grid) == 0:
        return np.array([]), np.array([])

    y_current = f(x_start)
    y_euler = [y_current]
    # Walk over consecutive grid pairs (t_k, t_{k+1}).
    for t_now, t_next in zip(grid[:-1], grid[1:]):
        step = t_next - t_now
        feedback = -c * (y_current - f(t_now))
        slope = f_prime(t_now) + feedback
        y_current = y_current + slope * step
        y_euler.append(y_current)
    return grid, np.array(y_euler)

fig = go.Figure()

# Densely sampled exact curve and the grid times from which Euler trajectories start.
x_exact = np.linspace(0, 1, 200)
y_exact = f(x_exact)
start_points = [0.695, 0.504, 0.352, 0.239, 0.126, 0.0]

# Integrate every trajectory exactly once; the results are reused for both the global
# y-extent and the per-grid-time maxima below (previously euler_method was re-run for
# every (grid time, start point) combination).
euler_trajectories = [euler_method(t_values, sp, c) for sp in start_points]

# Global vertical extent over the exact curve and all Euler trajectories.
all_y = [y_exact] + [ty for _, ty in euler_trajectories]
all_y_vals = np.concatenate(all_y)
y_min = np.min(all_y_vals)
y_max = np.max(all_y_vals)

# For each grid time: the highest value reached there by the exact curve or by any
# Euler trajectory that passes through that time (used as the top of the dotted guide).
all_lines_at_t = {}
for t_val in t_values:
    y_vals_at_t = [f(t_val)]
    for tx, ty in euler_trajectories:
        # Exact float comparison is intended: tx is a slice of t_values itself.
        hits = np.where(tx == t_val)[0]
        if len(hits) > 0:
            y_vals_at_t.append(ty[hits[0]])
    all_lines_at_t[t_val] = max(y_vals_at_t)

# Hover text shared by both traces: the coordinates, with the trace name in the side box.
named_point_hover = '(%{x:.3f}, %{y:.3f})<extra>%{fullData.name}</extra>'

# Exact solution
exact_curve = go.Scatter(
    x=x_exact,
    y=y_exact,
    mode='lines',
    line={'color': 'red', 'width': 2},
    name='Exact ODE solution',
    showlegend=False,
    hovertemplate=named_point_hover,
)
fig.add_trace(exact_curve)

# Points on the exact curve from which the Euler trajectories depart.
y_starts = list(map(f, start_points))
start_markers = go.Scatter(
    x=start_points,
    y=y_starts,
    mode='markers',
    name='Start point',
    marker={'color': 'red', 'size': 6},
    showlegend=False,
    hovertemplate=named_point_hover,
)
fig.add_trace(start_markers)

# Euler trajectory
# Styling shared by every step arrow; only the head/tail coordinates differ.
step_arrow_style = dict(
    xref="x", yref="y", axref="x", ayref="y",
    arrowhead=3,
    arrowsize=1.8,
    arrowwidth=1,
    arrowcolor="black",
    showarrow=True,
    text="",
)
for traj_number, sp in enumerate(start_points, start=1):
    tx, ty = euler_method(t_values, sp, c)
    polyline = go.Scatter(
        x=tx,
        y=ty,
        mode='lines',
        line={'color': 'black', 'width': 1.5},
        name=f'{traj_number}',
        showlegend=False,
        hovertemplate='(%{x:.3f}, %{y:.3f})<extra>Euler method traj %{fullData.name}</extra>',
    )
    fig.add_trace(polyline)
    # One arrow per Euler step, with its tail at (t_j, y_j) and its head at (t_{j+1}, y_{j+1}).
    for x_tail, y_tail, x_head, y_head in zip(tx[:-1], ty[:-1], tx[1:], ty[1:]):
        fig.add_annotation(
            x=x_head,
            y=y_head,
            ax=x_tail,
            ay=y_tail,
            **step_arrow_style,
        )

# t_i marker
guide_bottom = y_min - 0.005  # same height as the horizontal baseline
for i, t_val in enumerate(t_values):
    # Dotted vertical guide from the baseline up to the highest curve value at t_i.
    fig.add_shape(
        type="line",
        x0=t_val,
        x1=t_val,
        y0=guide_bottom,
        y1=all_lines_at_t[t_val],
        line={"dash": "dot", "color": "black", "width": 1},
        layer="below",
    )
    # LaTeX label t_i placed just under the baseline.
    fig.add_annotation(
        x=t_val,
        y=y_min - 0.007,
        text=f"$t_{i}$",
        showarrow=False,
        font={"size": 20},
        yshift=-10,
    )

# Hide the built-in axes; a hand-drawn frame is added below instead.
hidden_axis = dict(showline=False, showticklabels=False, zeroline=False)
fig.update_xaxes(**hidden_axis)
fig.update_yaxes(**hidden_axis)


frame_line = dict(color="black", width=1.5)

# Horizontal line at y_min
baseline_y = y_min - 0.005
fig.add_shape(
    type="line",
    x0=-0.05, x1=1.05,
    y0=baseline_y, y1=baseline_y,
    line=frame_line,
)

# Vertical line at x=0
fig.add_shape(
    type="line",
    x0=0, x1=0,
    y0=y_min - 0.01, y1=y_max + 0.01,
    line=frame_line,
)

# Label the initial value y_0 = f(0) just left of the vertical axis.
y0_val = f(0.0)
fig.add_annotation(
    x=0,
    y=y0_val,
    xref='x',
    yref='y',
    text="$y_0$",
    showarrow=False,
    font={"size": 20},
    xshift=-15,
)

figure_layout = dict(
    plot_bgcolor='white',
    # width=500,
    # height=300,
    autosize=True,  # Turn this on when saving to HTML
    margin={'l': 5, 'r': 5, 't': 5, 'b': 5},
    showlegend=False,
)
fig.update_layout(**figure_layout)

fig.show()

# Standalone HTML export; plotly.js and MathJax are both referenced from their CDNs.
export_options = dict(
    include_mathjax='cdn',
    full_html=True,
    include_plotlyjs="cdn",
)
fig.write_html('intro_euler_method.html', **export_options)

# Interpolation, Rectified Flow, Reflow

In [None]:
import torch
import numpy as np
import os
import sys
import matplotlib.pyplot as plt

import torch.distributions as dist

from rectified_flow.utils import set_seed
from rectified_flow.utils import visualize_2d_trajectories_plotly

from rectified_flow.rectified_flow import RectifiedFlow

# Fix the RNG seed and pick the GPU when one is available.
set_seed(233)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

from rectified_flow.datasets.toy_gmm import TwoPointGMM

# Source (pi_0) and target (pi_1) toy distributions, placed at x = 0 and x = 15.
# Presumably each is a two-component Gaussian mixture with modes at y = +/-7.5 --
# TODO confirm against TwoPointGMM.
n_samples = 50000
pi_0 = TwoPointGMM(x=0.0, y=7.5, std=0.5, device=device)
pi_1 = TwoPointGMM(x=15.0, y=7.5, std=0.5, device=device)
D0 = pi_0.sample([n_samples])
# `labels` is returned alongside the samples but is not read anywhere below; the stray
# `labels.tolist()` call that only built and discarded a 50000-element list was removed.
D1, labels = pi_1.sample_with_labels([n_samples])

# Quick visual check of both sample sets.
plt.figure(figsize=(5, 5))
plt.title(r'Samples from $\pi_0$ and $\pi_1$')
plt.scatter(D0[:, 0].cpu(), D0[:, 1].cpu(), alpha=0.5, label=r'$\pi_0$')
plt.scatter(D1[:, 0].cpu(), D1[:, 1].cpu(), alpha=0.5, label=r'$\pi_1$')
plt.legend()

In [None]:
set_seed(233)

x_0 = pi_0.sample([300])

# Build two copies of the same 300 source samples: one with the second coordinate
# folded to be non-negative ("upper"), one with it folded to be non-positive ("lower").
second_coord_abs = x_0[:, 1].abs()

x_0_upper = x_0.clone()
x_0_upper[:, 1] = second_coord_abs

x_0_lower = x_0.clone()
x_0_lower[:, 1] = -second_coord_abs

In [None]:
x_1_upper = pi_1.sample([300])
x_1_lower = pi_1.sample([300])

# Straight-line interpolation x_t = (1 - t) * x_0 + t * x_1 at 51 evenly spaced times
# in [0, 1]; each list holds one tensor per time.
interp_times = np.linspace(0, 1, 51)
interp_upper = [(1 - t) * x_0_upper + t * x_1_upper for t in interp_times]
interp_lower = [(1 - t) * x_0_lower + t * x_1_lower for t in interp_times]

interpolation_sets = {
    "upper interpolation": interp_upper,
    "lower interpolation": interp_lower,
}
visualize_2d_trajectories_plotly(
    trajectories_dict=interpolation_sets,
    D1_gt_samples=torch.cat([x_1_upper, x_1_lower], dim=0),
    num_trajectories=120,
    title="Straight Interpolation",
)

In [3]:
# from rectified_flow.models.toy_mlp import MLPVelocity

# model = MLPVelocity(2, hidden_sizes = [64, 128, 128, 128, 64]).to(device)

# rectified_flow = RectifiedFlow(
#     data_shape=(2,),
#     velocity_field=model,
#     interp="straight",
#     source_distribution=pi_0,
#     train_time_distribution="lognormal",
#     device=device,
# )

In [38]:
# optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.03)

In [None]:
# batch_size = 128

# losses = []

# for step in range(10000):
# 	optimizer.zero_grad()
# 	idx = torch.randperm(n_samples)[:batch_size]
# 	# x_0 = D0[idx].to(device)
# 	# x_1 = D1[idx].to(device)

# 	x_0 = pi_0.sample([batch_size]).to(device)
# 	x_1 = pi_1.sample([batch_size]).to(device)
	
# 	loss = rectified_flow.get_loss(x_0, x_1)
# 	loss.backward()
# 	optimizer.step()
# 	losses.append(loss.item())

# 	if step % 1000 == 0:
# 		print(f"Epoch {step}, Loss: {loss.item()}")
		
# plt.plot(losses)

In [None]:
from rectified_flow.models.kernel_method import NadarayaWatson

# Kernel-based (Nadaraya-Watson) velocity field built directly from the D0 / D1
# samples, used instead of the commented-out MLP training cells above.
velocity = NadarayaWatson(
    pi_0_sample=D0,
    pi_1_sample=D1,
    sample_size=5000,
    bandwidth=0.5,
    use_dot_x_t=True,
)

print(velocity.interp.name)

# The flow reuses the interpolation object owned by the velocity estimator.
rectified_flow_config = dict(
    data_shape=(2,),
    velocity_field=velocity,
    interp=velocity.interp,
    source_distribution=pi_0,
    device=device,
)
rectified_flow = RectifiedFlow(**rectified_flow_config)


In [None]:
from rectified_flow.samplers import EulerSampler
from rectified_flow.utils import visualize_2d_trajectories_plotly

euler_sampler_1rf_unconditional = EulerSampler(rectified_flow=rectified_flow, num_steps=1000)

print(x_0_upper.shape)

# Integrate the 1-rectified flow from the upper and then the lower start points; each
# result is read immediately after its own sample_loop call.
traj_upper = euler_sampler_1rf_unconditional.sample_loop(x_0=x_0_upper).trajectories
traj_lower = euler_sampler_1rf_unconditional.sample_loop(x_0=x_0_lower).trajectories

full_rf_trajectories = {"upper": traj_upper, "lower": traj_lower}
visualize_2d_trajectories_plotly(
    trajectories_dict=full_rf_trajectories,
    D1_gt_samples=D1[:1000],
    num_trajectories=150,
    title="Unconditional 1-Rectified Flow",
)

In [None]:
# Keep every 20th state (indices 0, 20, ..., 1000) of the sampled trajectories.
# Assumes each trajectory list has at least 1001 entries (num_steps=1000) -- TODO confirm.
kept_steps = range(0, 1001, 20)
rf_traj_upper = [traj_upper[step] for step in kept_steps]
rf_traj_lower = [traj_lower[step] for step in kept_steps]

visualize_2d_trajectories_plotly(
    trajectories_dict={"upper": rf_traj_upper, "lower": rf_traj_lower},
    D1_gt_samples=D1[:1000],
    num_trajectories=120,
    title="Unconditional 1-Rectified Flow",
)

In [38]:
# Coupled pairs for reflow: fresh source samples Z_0 and the end points Z_1 obtained by
# pushing them through the 1-rectified flow with a 100-step Euler run.
Z_0 = rectified_flow.sample_source_distribution(batch_size=50000)

reflow_pair_run = euler_sampler_1rf_unconditional.sample_loop(x_0=Z_0, num_steps=100)
Z_1 = reflow_pair_run.trajectories[-1]

# mask = (Z_0[:, 1] * Z_1[:, 1]) > 0

# Z_0 = Z_0[mask]
# Z_1 = Z_1[mask]

In [39]:
# Second-round velocity field: same kernel estimator, fitted on the (Z_0, Z_1) pairs
# produced by the first flow instead of on independent samples.
reflow_velocity = NadarayaWatson(
    pi_0_sample=Z_0,
    pi_1_sample=Z_1,
    sample_size=5000,
    bandwidth=0.5,
    use_dot_x_t=True,
)

reflow_config = dict(
    data_shape=(2,),
    velocity_field=reflow_velocity,
    interp=reflow_velocity.interp,
    source_distribution=pi_0,
    device=device,
)
reflow = RectifiedFlow(**reflow_config)

In [None]:
# import copy

# reflow = copy.deepcopy(rectified_flow)

# optimizer = torch.optim.AdamW(reflow.velocity_field.parameters(), lr=1e-3)
# batch_size = 1024

# losses = []

# for step in range(5000):
# 	optimizer.zero_grad()
# 	idx = torch.randperm(Z_0.shape[0])[:batch_size]
# 	x_0 = Z_0[idx]
# 	x_1 = Z_1[idx]
	
# 	x_0 = x_0.to(device)
# 	x_1 = x_1.to(device)
	
# 	loss = reflow.get_loss(x_0, x_1)
# 	loss.backward()
# 	optimizer.step()
# 	losses.append(loss.item())

# 	if step % 1000 == 0:
# 		print(f"Epoch {step}, Loss: {loss.item()}")
    
# plt.plot(losses)

In [None]:
# Sample from the reflow (second-round) model with the same 1000-step Euler scheme.
euler_sampler_2rf = EulerSampler(
    rectified_flow=reflow,
    num_steps=1000,
)

reflow_upper = euler_sampler_2rf.sample_loop(x_0=x_0_upper).trajectories
reflow_lower = euler_sampler_2rf.sample_loop(x_0=x_0_lower).trajectories

# mask_upper = reflow_upper[-1][:, 1] >= 6
# reflow_upper = [step[mask_upper] for step in reflow_upper]

# mask_lower = reflow_lower[-1][:, 1] <= -6
# reflow_lower = [step[mask_lower] for step in reflow_lower]

# The title previously read "1-Rectified Flow" (copied from the first-round cell) even
# though these trajectories come from the reflow model.
visualize_2d_trajectories_plotly(
    trajectories_dict={"upper": reflow_upper, "lower": reflow_lower},
    D1_gt_samples=D1[:1000],
    num_trajectories=100,
    title="Unconditional 2-Rectified Flow",
)

In [None]:
# Keep every 20th state (indices 0, 20, ..., 1000) of the reflow trajectories.
# Assumes each trajectory list has at least 1001 entries (num_steps=1000) -- TODO confirm.
reflow_traj_upper = []
reflow_traj_lower = []

for idx in range(0, 1001, 20):
    reflow_traj_upper.append(reflow_upper[idx])
    reflow_traj_lower.append(reflow_lower[idx])

# The title previously read "1-Rectified Flow" (copied from the first-round cell) even
# though these trajectories come from the reflow model.
visualize_2d_trajectories_plotly(
    trajectories_dict={"upper": reflow_traj_upper, "lower": reflow_traj_lower},
    D1_gt_samples=D1[:1000],
    num_trajectories=120,
    title="Unconditional 2-Rectified Flow",
)

In [74]:
def visualize_3_plots_side_by_side(
    # First plot parameters
    trajectories_dict_1: dict[str, list[torch.Tensor]],
    D1_gt_samples_1: torch.Tensor,
    num_trajectories_1: int,
    # Second plot parameters
    trajectories_dict_2: dict[str, list[torch.Tensor]],
    D1_gt_samples_2: torch.Tensor,
    num_trajectories_2: int,
    # Third plot parameters
    trajectories_dict_3: dict[str, list[torch.Tensor]],
    D1_gt_samples_3: torch.Tensor,
    num_trajectories_3: int,
    # Common parameters
    dimensions: list[int] = [0, 1],
    alpha_trajectories: float = 0.7,
    alpha_particles: float = 0.8,
    alpha_gt_points: float = 0.8,
    markersize: int = 4,
    range_x=[-1.8, 16.8],
    range_y=[-9.3, 9.3],
    caption1=r"$\text{Linear Interpolation } X_t$",
    caption2=r"$\text{Rectified Flow } Z_t$",
    caption3=r"$\text{Straightened Rectified Flow}$",
):
    """Render three animated 2-D trajectory plots side by side and export them to HTML.

    Each panel shows, for every named trajectory set, the trajectory lines, the initial
    points, optional ground-truth samples, and a set of moving particles that is
    animated over the time steps with a shared slider and a "Play" button.

    Args:
        trajectories_dict_1, trajectories_dict_2, trajectories_dict_3: per panel, a
            mapping from a trajectory-set name to a list of tensors, one per time step;
            each list is stacked into an array of shape [time_steps, batch_size, dimension].
        D1_gt_samples_1, D1_gt_samples_2, D1_gt_samples_3: per panel, ground-truth
            samples drawn as static markers; ``None`` skips them.
        num_trajectories_1, num_trajectories_2, num_trajectories_3: per panel, the
            maximum number of trajectory lines drawn (the first ones in batch order).
        dimensions: the two coordinate indices used as the x and y axes.
        alpha_trajectories: opacity of the trajectory lines.
        alpha_particles: opacity of the moving particles in the animation frames.
        alpha_gt_points: opacity of the initial points and of the ground-truth samples.
        markersize: marker size for all point traces.
        range_x, range_y: axis ranges applied to all three panels.
        caption1, caption2, caption3: caption text placed under each panel.

    Returns:
        None. The figure is displayed with ``fig.show()`` and written to
        ``./intro_rf_three_in_one.html``.
    """
    import plotly.graph_objects as go
    from plotly.subplots import make_subplots
    import numpy as np
    import torch

    # Helper to process a single set of trajectories.
    # `hover_label` is the text shown in the hover box of the moving particles.
    def process_trajectories(trajectories_dict, D1_gt_samples, num_trajectories, dim0, dim1, hover_label):
        if D1_gt_samples is not None:
            D1_gt_samples = D1_gt_samples.clone().to(torch.float32).cpu().detach().numpy()

        # Prepare color mappings
        particle_colors = [
            "RoyalBlue",
            "OrangeRed",
            "#1E90FF",
            "#FF69B4",
            "#7B68EE",
            "#FF8C00",
            "#32CD32",
            "#4169E1",
            "#FF4500",
            "#9932CC",
            "#ADFF2F",
            "#FFD700",
        ]

        trajectory_colors = [
            "DodgerBlue",
            "Tomato",
            "#8ABDE5",
            "#E09CAF",
            "#B494E1",
            "#E5B680",
            "#82C9A1",
            "#92BCD5",
            "#E68FA2",
            "#A98FC8",
            "#E5C389",
            "#A0C696",
        ]

        marker_list = [
            "circle", "circle", "x", "square", "star", "diamond",
            "triangle-up", "triangle-down", "hexagram"
        ]

        # Assign colors/markers by position of the set in the dict, cycling if needed.
        trajectory_names = list(trajectories_dict.keys())
        colors = {}
        for i, name in enumerate(trajectory_names):
            colors[name] = {
                "particle_color": particle_colors[i % len(particle_colors)],
                "trajectory_color": trajectory_colors[i % len(trajectory_colors)],
                "marker": marker_list[i % len(marker_list)],
            }

        # Process trajectories and store data
        trajectory_data = {}
        max_time_steps = 0

        for trajectory_name, traj_list in trajectories_dict.items():
            xtraj_list = [
                traj.clone().to(torch.float32).detach().cpu().numpy() for traj in traj_list
            ]
            xtraj = np.stack(xtraj_list)  # [time_steps, batch_size, dimension]
            trajectory_data[trajectory_name] = xtraj
            max_time_steps = max(max_time_steps, xtraj.shape[0])

        # Build static traces
        static_traces = []
        particle_traces_info = []
        # Position of the next trace within this panel's trace list.
        current_trace_index = 0

        for trajectory_name, xtraj in trajectory_data.items():
            particle_color = colors[trajectory_name]["particle_color"]
            trajectory_color = colors[trajectory_name]["trajectory_color"]
            marker_symbol = colors[trajectory_name]["marker"]
            num_points = xtraj.shape[1]
            indices = np.arange(min(num_trajectories, num_points))

            # Plot lines for trajectories: all lines of one set go into a single trace,
            # separated by NaN so that consecutive lines are not joined.
            all_line_x = []
            all_line_y = []

            for i in indices:
                line_x = xtraj[:, i, dim0]
                line_y = xtraj[:, i, dim1]
                all_line_x.extend(line_x.tolist() + [np.nan])
                all_line_y.extend(line_y.tolist() + [np.nan])

            static_traces.append(
                go.Scatter(
                    x=all_line_x,
                    y=all_line_y,
                    mode="lines",
                    line=dict(dash="solid", color=trajectory_color, width=0.6),
                    opacity=alpha_trajectories,
                    showlegend=False,
                    hoverinfo="skip",
                )
            )
            current_trace_index += 1

            # Plot initial points
            static_traces.append(
                go.Scatter(
                    x=xtraj[0, :, dim0],
                    y=xtraj[0, :, dim1],
                    mode="markers",
                    marker=dict(
                        size=markersize,
                        opacity=alpha_gt_points,
                        color="blue",
                        symbol=marker_symbol
                    ),
                    showlegend=False,
                    hovertemplate='(%{x:.3f}, %{y:.3f})<extra>Initial noise</extra>'
                )
            )
            current_trace_index += 1

            # Particle traces info for frames; trace_index is filled in once the
            # particle trace has actually been appended below.
            particle_traces_info.append(
                {
                    "trajectory_name": trajectory_name,
                    "x": xtraj[0, :, dim0],
                    "y": xtraj[0, :, dim1],
                    "particle_color": particle_color,
                    "marker_symbol": marker_symbol,
                    "trace_index": None,
                }
            )

        # Ground truth samples
        if D1_gt_samples is not None:
            static_traces.append(
                go.Scatter(
                    x=D1_gt_samples[:, dim0],
                    y=D1_gt_samples[:, dim1],
                    mode="markers",
                    marker=dict(size=markersize, opacity=alpha_gt_points, color="DarkOrchid"),
                    showlegend=False,
                    hovertemplate='(%{x:.3f}, %{y:.3f})<extra>Target data</extra>'
                )
            )
            current_trace_index += 1

        # Add a trace for each trajectory set for the moving particles
        for info in particle_traces_info:
            static_traces.append(
                go.Scatter(
                    x=info["x"],
                    y=info["y"],
                    mode="markers",
                    marker=dict(size=markersize, color=info["particle_color"], symbol=info["marker_symbol"]),
                    showlegend=False,
                    hovertemplate='(%{x:.3f}, %{y:.3f})'+f'<extra>{hover_label}</extra>'
                )
            )
            info["trace_index"] = current_trace_index
            current_trace_index += 1

        return static_traces, trajectory_data, {info["trajectory_name"]: info["trace_index"] for info in particle_traces_info}, max_time_steps

    # Dimensions
    dim0, dim1 = dimensions

    # Process each plot
    static_traces_1, trajectory_data_1, particle_trace_indices_1, max_steps_1 = process_trajectories(
        trajectories_dict_1, D1_gt_samples_1, num_trajectories_1, dim0, dim1, hover_label="Interpolation"
    )
    static_traces_2, trajectory_data_2, particle_trace_indices_2, max_steps_2 = process_trajectories(
        trajectories_dict_2, D1_gt_samples_2, num_trajectories_2, dim0, dim1, hover_label="Rectified Flow"
    )
    static_traces_3, trajectory_data_3, particle_trace_indices_3, max_steps_3 = process_trajectories(
        trajectories_dict_3, D1_gt_samples_3, num_trajectories_3, dim0, dim1, hover_label="Reflow",
    )

    # Animate only as many steps as the shortest of the three panels provides.
    max_time_steps = min(max_steps_1, max_steps_2, max_steps_3)

    # Create figure with 3 subplots
    fig = make_subplots(rows=1, cols=3, horizontal_spacing=0.01)

    # Add static traces to each subplot
    for trace in static_traces_1:
        fig.add_trace(trace, row=1, col=1)
    for trace in static_traces_2:
        fig.add_trace(trace, row=1, col=2)
    for trace in static_traces_3:
        fig.add_trace(trace, row=1, col=3)

    # NOTE: `alpha_particles` is taken from the argument; it used to be overwritten
    # with a hard-coded 0.8 here, which silently ignored the caller's value.

    def frame_data_for_trajectory(trajectory_data, particle_trace_indices, t):
        # Particle positions of one panel at time step t, as (panel-local trace index, trace).
        frame_data = []
        for trajectory_name, xtraj in trajectory_data.items():
            if t < xtraj.shape[0]:
                x = xtraj[t, :, dim0]
                y = xtraj[t, :, dim1]
                trace_index = particle_trace_indices[trajectory_name]
                frame_data.append((trace_index, go.Scatter(
                    x=x,
                    y=y,
                    mode="markers",
                    marker=dict(size=markersize, opacity=alpha_particles),
                    showlegend=False,
                )))
        return frame_data

    num_traces_subplot_1 = len(static_traces_1)
    num_traces_subplot_2 = len(static_traces_2)
    num_traces_subplot_3 = len(static_traces_3)

    # Traces were added panel by panel, so a panel-local index becomes a figure-wide
    # index by adding the number of traces of all preceding panels.
    offset_2 = num_traces_subplot_1
    offset_3 = num_traces_subplot_1 + num_traces_subplot_2

    def adjust_trace_indices(frame_data_list, offset):
        adjusted = []
        for (trace_index, scatter_obj) in frame_data_list:
            adjusted.append((trace_index + offset, scatter_obj))
        return adjusted

    frames = []
    for t in range(max_time_steps):
        fd1 = frame_data_for_trajectory(trajectory_data_1, particle_trace_indices_1, t)
        fd2 = frame_data_for_trajectory(trajectory_data_2, particle_trace_indices_2, t)
        fd3 = frame_data_for_trajectory(trajectory_data_3, particle_trace_indices_3, t)

        fd2 = adjust_trace_indices(fd2, offset_2)
        fd3 = adjust_trace_indices(fd3, offset_3)

        combined = fd1 + fd2 + fd3
        combined_sorted = sorted(combined, key=lambda x: x[0])

        frame_trace_indices = [c[0] for c in combined_sorted]
        frame_data_traces = [c[1] for c in combined_sorted]

        frames.append(go.Frame(data=frame_data_traces, name=str(t), traces=frame_trace_indices))

    # Create slider steps
    slider_steps = []
    for t in range(max_time_steps):
        step = dict(
            method="animate",
            args=[
                [str(t)],
                dict(
                    mode="immediate",
                    frame=dict(duration=0, redraw=True),
                    transition=dict(duration=0),
                ),
            ],
            label=str(t),
        )
        slider_steps.append(step)

    # Create sliders and buttons
    sliders = [
        dict(
            active=0,
            currentvalue={"prefix": "Step: "},
            pad={"t": 0, "b": 0, "l": 0, "r": 0},
            steps=slider_steps,
            x=0.5, xanchor="center",
            y=0.23, yanchor="top",
            font=dict(size=12)
        )
    ]

    updatemenus = [
        {
            "type": "buttons",
            "x": 0.5,
            "y": 0.25,
            "xanchor": "center",
            "yanchor": "top",
            "font": dict(size=12),
            "buttons": [
                {
                    "label": "Play",
                    "method": "animate",
                    "args": [
                        None,
                        {
                            "frame": {"duration": 500, "redraw": False},
                            "transition": {"duration": 400, "easing": "quadratic-in-out"},
                            "fromcurrent": True,
                            "mode": "immediate",
                            "loop": True,
                        },
                    ],
                },
            ],
        }
    ]

    # Update layout: no grid line, no x/y axis, no ticks
    for i in range(1, 4):
        fig.update_xaxes(
            visible=False,
            showgrid=False,
            zeroline=False,
            showline=False,
            range=range_x,
            row=1, col=i
        )
        # The plots occupy only the upper 70% of the figure height; the space below
        # is used by the captions, the Play button and the slider.
        fig.update_yaxes(
            domain=[0.3, 1.0],
            visible=False,
            showgrid=False,
            zeroline=False,
            showline=False,
            range=range_y,
            row=1, col=i
        )

    fig.update_layout(
        sliders=sliders,
        updatemenus=updatemenus,
        showlegend=False,
        # height=450,
        # width=950,
        autosize=True, # Turn this on when saving to HTML
        margin=dict(l=0, r=5, t=0, b=5),
        template="plotly_white",
    )

    # After layout update, we can get domain info
    # xaxis1.domain, xaxis2.domain, xaxis3.domain and yaxis1.domain
    xaxis1_domain = fig.layout.xaxis.domain
    xaxis2_domain = fig.layout.xaxis2.domain
    xaxis3_domain = fig.layout.xaxis3.domain
    yaxis1_domain = fig.layout.yaxis.domain

    x_mid_1 = (xaxis1_domain[0] + xaxis1_domain[1]) / 2
    x_mid_2 = (xaxis2_domain[0] + xaxis2_domain[1]) / 2
    x_mid_3 = (xaxis3_domain[0] + xaxis3_domain[1]) / 2

    # Place each caption directly under its subplot, centred on the subplot's x-domain
    # (paper coordinates). The y-domain of the subplots was set to [0.3, 1.0] above, so
    # the caption's top edge is anchored just above the lower edge of that domain and
    # the text hangs below the plot area, above the Play button and the slider.
    y_caption = yaxis1_domain[0] + 0.01

    fig.add_annotation(
        x=x_mid_1, y=y_caption,
        xanchor="center", yanchor="top",
        text=caption1,
        showarrow=False,
        font=dict(size=16),
        xref='paper', yref='paper'
    )
    fig.add_annotation(
        x=x_mid_2, y=y_caption,
        xanchor="center", yanchor="top",
        text=caption2,
        showarrow=False,
        font=dict(size=16),
        xref='paper', yref='paper'
    )
    fig.add_annotation(
        x=x_mid_3, y=y_caption,
        xanchor="center", yanchor="top",
        text=caption3,
        showarrow=False,
        font=dict(size=16),
        xref='paper', yref='paper'
    )

    # Add frames
    fig.frames = frames

    fig.show()

    # Standalone HTML export; plotly.js and MathJax are both referenced from their CDNs.
    fig.write_html(
        './intro_rf_three_in_one.html',
        full_html=True,
        include_plotlyjs="cdn",
        include_mathjax='cdn',
    )


In [None]:
# All three panels are compared against the same target samples, so build them once.
shared_gt_samples = torch.cat([x_1_upper, x_1_lower], dim=0)

interpolation_panel = {"upper interpolation": interp_upper, "lower interpolation": interp_lower}
rectified_flow_panel = {"upper rf": rf_traj_upper, "lower rf": rf_traj_lower}
reflow_panel = {"upper reflow": reflow_traj_upper, "lower reflow": reflow_traj_lower}

visualize_3_plots_side_by_side(
    # Panel 1: straight interpolation
    trajectories_dict_1=interpolation_panel,
    D1_gt_samples_1=shared_gt_samples,
    num_trajectories_1=150,
    # Panel 2: 1-rectified flow
    trajectories_dict_2=rectified_flow_panel,
    D1_gt_samples_2=shared_gt_samples,
    num_trajectories_2=150,
    # Panel 3: reflow
    trajectories_dict_3=reflow_panel,
    D1_gt_samples_3=shared_gt_samples,
    num_trajectories_3=150,
    # Shared styling
    dimensions=[0, 1],
    markersize=4,
    alpha_trajectories=0.7,
    alpha_particles=0.7,
    alpha_gt_points=0.7,
)