In [None]:
from IPython.display import HTML

# Inject MathJax (SVG renderer) into the notebook page; presumably needed so the
# LaTeX labels ($...$) used in the Plotly figures below render inline.
HTML(
    '<script src="https://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.7/MathJax.js'
    '?config=TeX-AMS-MML_SVG"></script>'
)

In [2]:
import torch
import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import torch.distributions as dist

from matplotlib import rcParams
rcParams.update({'font.family': 'serif'})  # serif fonts for every matplotlib figure

from rectified_flow.utils import set_seed
from rectified_flow.datasets.toy_gmm import TwoPointGMM, CircularGMM
from rectified_flow.rectified_flow import RectifiedFlow
from rectified_flow.models.toy_mlp import MLPVelocityConditioned, MLPVelocity
from rectified_flow.samplers import SDESampler, EulerSampler, CurvedEulerSampler, rf_samplers_dict

# Seed the RNGs once via the project helper (presumably torch/numpy/random),
# before any model initialization or sampling, so reruns are reproducible.
set_seed(0)

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [294]:
import plotly.graph_objects as go

def visualize_2d_trajectories_plotly(
    trajectories_list: list[torch.Tensor],
    D1_gt_samples: torch.Tensor = None,
    num_trajectories: int = 50,
    dimensions: tuple[int, int] = (0, 1),
    alpha_trajectories: float = 0.7,
    alpha_generated_points: float = 1.0,
    alpha_gt_points: float = 1.0,
    color_generated_points: str = 'red',
    color_gt_points: str = 'thistle',
    color_trajectories: str = 'lightsalmon',
    color_noise_points: str = 'royalblue',
    title: str = '2D Trajectories',
    x_range: tuple[float, float] = (-3.8, 17.4),
    y_range: tuple[float, float] = (-9.5, 9.5),
    num_grid_cells: int = 5,
):
    """
    Visualize the evolution of samples (trajectories) from the noise distribution (D0)
    to the target distribution (D1) using Plotly.

    Parameters:
        trajectories_list (list): torch.Tensors, one per recorded timestep, each indexed
                                  as [sample, coordinate]. The first entry is drawn as the
                                  initial noise and the last as the generated points.
        D1_gt_samples (torch.Tensor, optional): Samples from the target distribution (D1),
                                  indexed as [sample, coordinate].
        num_trajectories (int): How many individual trajectories (paths) to draw; clipped
                                to [0, number of samples].
        dimensions (tuple): The two coordinates to use as the x and y axes.
        alpha_trajectories (float): Opacity of the trajectory lines.
        alpha_generated_points (float): Opacity of the final generated points.
        alpha_gt_points (float): Opacity of the ground-truth points; also used for the
                                 initial-noise points.
        color_generated_points, color_gt_points, color_trajectories, color_noise_points (str):
                                 Plotly colors for the respective traces.
        title (str): Figure title, centered at the top.
        x_range, y_range (tuple): Fixed (low, high) axis windows.
        num_grid_cells (int): Number of equal grid cells along each axis.

    Returns:
        plotly.graph_objects.Figure: The figure (not shown; call ``fig.show()``).

    Raises:
        ValueError: If ``trajectories_list`` is empty.
    """
    if len(trajectories_list) == 0:
        raise ValueError("trajectories_list must contain at least one tensor")

    dim0, dim1 = dimensions
    fig = go.Figure()

    # Stack all timesteps into one CPU array indexed as [step, sample, coordinate].
    xtraj = torch.stack([traj.detach().cpu() for traj in trajectories_list]).numpy()

    # All paths go into a single line trace; a NaN after each path makes Plotly
    # lift the pen, which is much cheaper than one trace per trajectory.
    n_paths = max(0, min(num_trajectories, xtraj.shape[1]))
    all_line_x = []
    all_line_y = []
    for i in range(n_paths):
        all_line_x.extend(xtraj[:, i, dim0].tolist() + [np.nan])
        all_line_y.extend(xtraj[:, i, dim1].tolist() + [np.nan])

    fig.add_trace(
        go.Scatter(
            x=all_line_x,
            y=all_line_y,
            mode="lines",
            name="Trajectory",
            line=dict(dash="solid", color=color_trajectories, width=0.6),
            opacity=alpha_trajectories,
            showlegend=True,
            hoverinfo="skip",
        )
    )

    # Ground-truth target distribution (D1)
    if D1_gt_samples is not None:
        gt = D1_gt_samples.detach().cpu().numpy()
        fig.add_trace(
            go.Scatter(
                x=gt[:, dim0],
                y=gt[:, dim1],
                mode='markers',
                marker=dict(color=color_gt_points, size=6, opacity=alpha_gt_points),
                name=r'$\pi_1$ (target)',
                hovertemplate='(%{x:.3f}, %{y:.3f})<extra>target</extra>',
            )
        )

    # Initial noise (D0): the first recorded timestep
    fig.add_trace(
        go.Scatter(
            x=xtraj[0][:, dim0],
            y=xtraj[0][:, dim1],
            mode='markers',
            marker=dict(color=color_noise_points, size=6, opacity=alpha_gt_points),
            name=r'$\pi_0$ (initial noise)',
            hovertemplate='(%{x:.3f}, %{y:.3f})<extra>initial noise</extra>',
        )
    )

    # Generated points: the last recorded timestep
    fig.add_trace(
        go.Scatter(
            x=xtraj[-1][:, dim0],
            y=xtraj[-1][:, dim1],
            mode='markers',
            marker=dict(color=color_generated_points, size=6, opacity=alpha_generated_points),
            name='Generated',
            hovertemplate='(%{x:.3f}, %{y:.3f})<extra>generated</extra>',
        )
    )

    # Fixed axis windows with an evenly spaced grid and no tick labels, so figures
    # from different samplers are directly comparable.
    axis_style = dict(
        showgrid=True,
        griddash="solid",
        showticklabels=False,
        showline=True,
        zeroline=False,
        mirror=True,
    )
    x_lo, x_hi = x_range
    y_lo, y_hi = y_range
    fig.update_xaxes(
        range=[x_lo, x_hi],
        dtick=(x_hi - x_lo) / num_grid_cells,
        tick0=x_lo,
        **axis_style,
    )
    fig.update_yaxes(
        range=[y_lo, y_hi],
        dtick=(y_hi - y_lo) / num_grid_cells,
        tick0=y_lo,
        **axis_style,
    )

    fig.update_layout(
        template="plotly_white",
        margin=dict(l=10, r=10, t=25, b=10),
        title={
            "text": title,
            "font": {"size": 16},   # title font size
            "x": 0.5,               # horizontal position (0 = left, 1 = right)
            "xanchor": "center",    # anchor the title's center at x
            "y": 0.99,
            "yanchor": "top",       # anchor the title's top edge at y
        },
    )

    return fig

In [None]:
n_samples = 50000

# Source pi_0: standard 2-D Gaussian created directly on `device`.
pi0 = dist.MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

# Target pi_1: the project's two-point Gaussian mixture.
pi1 = TwoPointGMM(x=15.0, y=7.5, std=0.5)

D0 = pi0.sample((n_samples,))
D1 = pi1.sample([n_samples])

# Scatter both point clouds (first two coordinates) on shared axes.
plt.scatter(*D0[:, :2].cpu().T, alpha=0.5, label=r'Initial distribution $\pi_0$', color='royalblue')
plt.scatter(*D1[:, :2].cpu().T, alpha=0.5, label=r'Target distribution $\pi_1$', color='thistle')
plt.legend()
plt.title("Initial and Target Distributions")
plt.show()

In [None]:
# Velocity network: an MLP on 2-D inputs.
model = MLPVelocity(
    dim=2,
    hidden_sizes=[128, 128, 128],
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

batch_size = 1024
num_steps = 5000
log_every = 1000  # print the current loss every `log_every` steps

rectflow = RectifiedFlow(
    # NOTE(review): this passes a single *sample* from pi1 (a coordinate tensor),
    # not a shape such as (2,). Kept unchanged; confirm against RectifiedFlow's
    # `data_shape` contract.
    data_shape=pi1.sample([1]).squeeze(0),
    velocity_field=model,
    interp="straight",
    source_distribution=pi0,
    device=device,
)

# Training loss recorded at every step
losses = []

for step in range(num_steps):
    optimizer.zero_grad()

    # Fresh, independently drawn endpoints for every step
    X_0 = pi0.sample([batch_size]).to(device)
    X_1 = pi1.sample([batch_size]).to(device)

    # Random times, the interpolated point X_t and its time derivative (the target)
    t = rectflow.sample_train_time(batch_size)
    Xt, dot_Xt = rectflow.interp(X_0, X_1, t)

    # Regress the predicted velocity onto the interpolation's derivative (MSE)
    v_t = rectflow.get_velocity(Xt, t.squeeze())
    loss = torch.mean((dot_Xt - v_t) ** 2)
    loss.backward()
    optimizer.step()

    loss_value = loss.item()
    losses.append(loss_value)
    if step % log_every == 0:
        print(f"Step {step}, Loss: {loss_value:.4f}")

plt.figure(figsize=(4, 3))
plt.plot(losses)
plt.title("Training Loss")
plt.xlabel("Training Step")
plt.ylabel("Loss")
plt.show()  # render here instead of leaking the ylabel Text repr as cell output

In [None]:
# Integrate the learned velocity field with 100 Euler steps for 1000 fresh noise
# samples; the sampler records every intermediate state in `.trajectories`.
e_sampler = EulerSampler(
    rectified_flow=rectflow,
    num_steps=100,
    num_samples=1000,
)
e_sampler.sample_loop()

# Overlay 200 of the recorded paths on 1000 ground-truth target samples.
fig = visualize_2d_trajectories_plotly(
    e_sampler.trajectories,
    D1_gt_samples=D1[:1000],
    num_trajectories=200,
).update_layout(width=800, height=600)
fig.show()

In [None]:
from rectified_flow.samplers import SDESampler, EulerSampler

seed = 2

# Deterministic sampler: Euler integration of the learned velocity field.
euler_sampler = EulerSampler(
    rectified_flow=rectflow,  # trained in the cells above
    num_steps=100,
    num_samples=1000,
)
euler_sampler.sample_loop(seed=seed)

# Stochastic (SDE) sampler with noise_scale=2 and noise_decay_rate=0.
sde_sampler = SDESampler(
    rectified_flow=rectflow,
    num_steps=100,
    num_samples=1000,
    noise_scale=2,
    noise_decay_rate=0,
)
sde_sampler.sample_loop(seed=seed)


def _show_and_export(sampler, title, html_path):
    """Plot `sampler`'s trajectories over 1000 target samples, display it, and save it as HTML.

    The HTML loads plotly.js and MathJax from a CDN. Returns the Plotly figure.
    """
    fig = visualize_2d_trajectories_plotly(
        sampler.trajectories,
        D1[:1000],
        num_trajectories=500,
        title=title,
    )
    fig.update_layout(autosize=True, showlegend=False)
    fig.show()
    fig.write_html(
        html_path,
        full_html=True,
        include_plotlyjs="cdn",
        include_mathjax='cdn',
    )
    return fig


_show_and_export(euler_sampler, "Deterministic Sampler", '../diffusion_deterministic_single.html')
fig = _show_and_export(sde_sampler, "Stochastic Sampler", '../diffusion_stochastic_single.html')

In [None]:
from plotly.subplots import make_subplots

seed = 3
# SDE noise scales compared left to right (low -> high). The same seed is reused
# for every panel.
noise_scales = [0.5, 3, 5, 100]

fig = make_subplots(
    rows=1,
    cols=len(noise_scales),
    column_widths=[0.235] * len(noise_scales),
    horizontal_spacing=0.02,
)

# Axis styling shared by every panel: fixed window, 5 grid cells per axis, no tick labels.
fig.update_xaxes(
    range=[-3.1, 16.9],
    showgrid=True,
    griddash="solid",
    showticklabels=False,
    showline=True,
    zeroline=False,
    mirror=True,
    dtick=4,
    tick0=-3.1,
)

fig.update_yaxes(
    range=[-9.5, 9.5],
    showgrid=True,
    griddash="solid",
    showticklabels=False,
    showline=True,
    zeroline=False,
    mirror=True,
    dtick=3.8,
    tick0=-9.5,
)

for col, scale in enumerate(noise_scales, start=1):
    sde_sampler = SDESampler(
        rectified_flow=rectflow,
        num_steps=100,
        num_samples=500,
        noise_scale=scale,
        noise_decay_rate=0,
    )
    sde_sampler.sample_loop(seed=seed)

    scale_fig = visualize_2d_trajectories_plotly(
        sde_sampler.trajectories,
        D1[:1000],  # target samples from the data cell
        num_trajectories=200,
    )

    # Only the traces are reused; each single figure's own layout/title is discarded.
    for trace in scale_fig.data:
        fig.add_trace(trace, row=1, col=col)


fig.update_layout(
    autosize=True,
    showlegend=False,
    template="plotly_white",
    margin=dict(l=10, r=10, t=35, b=10),
    annotations=[
        # Horizontal "low -> high" arrow above the panels. With axref/ayref="pixel",
        # ax/ay are the tail's pixel offset from the head: 620 px left, 0 px down.
        # NOTE: the fixed pixel length does not follow autosize width changes.
        dict(
            x=0.92,
            y=1.07,
            xref="paper",
            yref="paper",
            axref="pixel",
            ayref="pixel",
            ax=-620.,
            ay=0,
            arrowhead=3,
            arrowwidth=2,
            arrowcolor="steelblue",
            text="",
            showarrow=True,
        ),
        dict(
            x=0.5,
            y=1.22,  # higher than the arrow (y=1.07) so the text sits above the line
            xref="paper",
            yref="paper",
            text="Noise Scale Low to High",
            showarrow=False,
            font=dict(size=16),
        ),
    ],
)

fig.show()

fig.write_html(
    '../diffusion_noise_scales_4pics.html',
    full_html=True,
    include_plotlyjs="cdn",
    include_mathjax='cdn',
)

In [None]:
# Defined here (same values as the score-function plot below) so this cell also
# runs on a fresh kernel; previously it relied on a variable set further down.
noise_scale = 2
noise_decay_rate = 0

sde_sampler = SDESampler(
    rectified_flow=rectflow,
    num_steps=100,
    num_samples=500,
    noise_scale=noise_scale,
    noise_decay_rate=noise_decay_rate,
)
sde_sampler.sample_loop(seed=1)

# Compare how many states were recorded with how many time points the sampler reports.
print(len(sde_sampler.trajectories), len(sde_sampler.time_points))
print(sde_sampler.time_points)

In [None]:
import plotly.graph_objects as go
import torch
import numpy as np

def _select_highlight_indices(xts, show_margin=False):
    """Choose the sample indices whose trajectories are highlighted and annotated.

    Parameters:
        xts (torch.Tensor): CPU tensor of stacked states indexed as [step, sample, coordinate].
        show_margin (bool): If True, pick samples by their y-coordinate late in sampling
            (steps -30 and -20). Otherwise pick by the y-coordinate summed over all steps.

    Returns:
        torch.Tensor: 1-D tensor of up to four sample indices.
    """
    if show_margin:
        by_y = torch.argsort(xts[-30, :, 1])
        near_minus_2 = torch.argsort((xts[-20, :, 1] + 2) ** 2)
        near_plus_2 = torch.argsort((xts[-20, :, 1] - 2) ** 2)
        return torch.cat([
            by_y[:1].view(-1),         # lowest y at step -30
            by_y[-1:].view(-1),        # highest y at step -30
            near_minus_2[2].view(-1),  # third-closest to y = -2 at step -20
            near_plus_2[1].view(-1),   # second-closest to y = +2 at step -20
        ])

    total_y = xts[:, :, 1].sum(dim=0)
    order = torch.argsort(total_y)
    sorted_y = total_y[order]
    # Boolean masks keep the ascending order. When no sample qualifies, [:1] / [-1:]
    # give empty tensors (the old loops produced None, which broke torch.cat).
    return torch.cat([
        order[-1:],               # largest summed y
        order[sorted_y > 0][:1],  # smallest positive summed y
        order[:1],                # smallest summed y
        order[sorted_y < 0][-1:], # largest negative summed y
    ])


def plot_arrows_plotly(sampler, show_margin=False, html_path='../diffusion_score_function.html'):
    """Plot SDE trajectories with score-function direction arrows along a few highlighted paths.

    The background comes from `visualize_2d_trajectories_plotly` (200 paths over 500 target
    samples). Up to four paths are redrawn in blue, and every 10th step on them carries a
    unit-length red arrow pointing along the (1 - t)-scaled score
    `rectflow.get_score_function_from_velocity(...)`.

    Relies on the notebook globals `rectflow`, `D1`, `go` and `visualize_2d_trajectories_plotly`.

    Parameters:
        sampler: A sampler whose `sample_loop` has run. Its `trajectories` and
            `time_grid` are read.
        show_margin (bool): Path-selection rule; see `_select_highlight_indices`.
        html_path (str or None): Where to write the standalone HTML figure; None skips writing.
    """
    # Plotting data lives on the CPU: `.numpy()` rejects CUDA tensors, and the CPU
    # index tensor built below must index CPU data.
    xts = torch.stack([x.detach().cpu() for x in sampler.trajectories])  # [step, sample, 2]

    # (1 - t)-scaled score at each recorded state, stopping once t reaches 1.
    scaled_scores = []
    for xt, t in zip(sampler.trajectories, sampler.time_grid):
        if t == 1:
            break
        t = rectflow.match_dim_with_data(t, xt.shape)
        vt = rectflow.get_velocity(xt, t.squeeze())
        st = rectflow.get_score_function_from_velocity(xt, vt, t)
        # Scale the score by (1 - t) to avoid its explosion as t -> 1.
        scaled_scores.append((st * (1 - t)).detach().cpu())
    sts = torch.stack(scaled_scores)  # [step, sample, 2]; may be shorter than xts

    idx = _select_highlight_indices(xts, show_margin)

    # Arrow anchors: every 10th step from step 10 on, highlighted paths only. Both tensors
    # are cut to a common length so each position lines up with its own score.
    n_common = min(len(xts), len(sts))
    xs_flat = xts[10:n_common:10, idx, :].reshape(-1, 2)
    ss_flat = sts[10:n_common:10, idx, :].reshape(-1, 2)

    fig = visualize_2d_trajectories_plotly(
        sampler.trajectories, D1[:500], num_trajectories=200,
        alpha_trajectories=0.5, alpha_generated_points=1, alpha_gt_points=1,
    )

    # Highlighted trajectories as one NaN-separated line trace.
    all_line_x = []
    all_line_y = []
    for i in idx.tolist():
        all_line_x.extend(xts[:, i, 0].tolist() + [np.nan])
        all_line_y.extend(xts[:, i, 1].tolist() + [np.nan])

    fig.add_trace(
        go.Scatter(
            x=all_line_x,
            y=all_line_y,
            mode="lines",
            name="Trajectory",
            line=dict(dash="solid", color='blue', width=1),
            opacity=1,
            showlegend=True,
            hovertemplate='(%{x:.3f}, %{y:.3f})<extra>trajectory</extra>',
        )
    )

    # One unit-length arrow (in data units) per anchor: only the score's direction is shown.
    arrows = []
    for x0, y0, u, v in zip(xs_flat[:, 0].numpy(), xs_flat[:, 1].numpy(),
                            ss_flat[:, 0].numpy(), ss_flat[:, 1].numpy()):
        magnitude = np.sqrt(u ** 2 + v ** 2)
        if magnitude == 0:
            continue
        arrows.append(go.layout.Annotation(
            x=x0 + u / magnitude,
            y=y0 + v / magnitude,
            xref="x", yref="y",
            text="",
            showarrow=True,
            axref="x", ayref="y",
            ax=x0,
            ay=y0,
            arrowhead=3,
            arrowwidth=2,
            arrowcolor='rgb(255,51,0)',
        ))

    # Set the per-point arrows BEFORE adding the legend arrow: update_layout(annotations=...)
    # merges into existing annotations, so calling it afterwards overwrote the
    # "Score Function" label with the first arrow's properties.
    fig.update_layout(annotations=arrows)

    # Legend-style arrow with a label, placed in the right margin outside the plot.
    fig.add_annotation(
        x=1.02,
        y=0.72,
        xref="paper",
        yref="paper",
        text="Score Function",
        showarrow=True,
        arrowhead=3,
        arrowwidth=1.5,
        arrowcolor="red",
        ax=66.0,
        ay=0,
        axref="pixel",
        ayref="pixel",
    )

    fig.update_xaxes(
        range=[-3.4, 16.9],
        showgrid=True,
        griddash="solid",
        showticklabels=False,
        showline=True,
        zeroline=False,
        mirror=True,
        dtick=4.06,
        tick0=-3.4,
    )

    fig.update_yaxes(
        range=[-9.5, 9.5],
        showgrid=True,
        griddash="solid",
        showticklabels=False,
        showline=True,
        zeroline=False,
        mirror=True,
        dtick=3.8,
        tick0=-9.5,
    )

    fig.update_layout(
        title={
            "text": r"$\textrm{Score Function } \nabla \log \rho_t(x) \textrm{ in SDE Trajectories}$",
            "font": {"size": 20},   # title font size
            "x": 0.5,               # horizontal position (0 = left, 1 = right)
            "y": 0.96,
            "xanchor": "center",    # anchor the title's center at x
            "yanchor": "top",       # anchor the title's top edge at y
        },
        template="plotly_white",
        margin=dict(l=10, r=120, t=30, b=10),  # wide right margin for the legend arrow
        autosize=True,
    )

    fig.show()

    if html_path is not None:
        fig.write_html(
            html_path,
            full_html=True,
            include_plotlyjs="cdn",
            include_mathjax='cdn',
        )

noise_scale = 2
noise_decay_rate = 0

# A fresh SDE run (seed 13) whose trajectories carry the score-function arrows.
sde_sampler = SDESampler(
    rectified_flow=rectflow,
    num_steps=100,
    num_samples=500,
    noise_scale=noise_scale,
    noise_decay_rate=noise_decay_rate,
)
sde_sampler.sample_loop(seed=13)
plot_arrows_plotly(sde_sampler)