In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===========================
# Diffusion schedule
# ===========================
T = 1000
beta_start = 1e-4
beta_end   = 2e-2
betas = torch.linspace(beta_start, beta_end, T, device=device)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)
sigmas = torch.sqrt(1.0 - alpha_bar)

# ===========================
# Data: mixture Gaussian -> point cloud
# ===========================
mu1 = torch.tensor([+2.0, 0.0, 0.0], device=device)
mu2 = torch.tensor([-2.0, 0.0, 0.0], device=device)

N_POINTS = 64

def sample_x0_cloud(batch_size: int, n_points: int = N_POINTS):
    B = batch_size
    mix = torch.bernoulli(0.5*torch.ones(B, n_points, device=device))  # (B,N)
    mu = torch.where(mix.unsqueeze(-1)==1, mu1, mu2)  # (B,N,3)
    eps = torch.randn(B, n_points, 3, device=device)
    return mu + eps


In [3]:
# ===========================
# Fourier time embedding
# ===========================
class FourierTimeEmbedding(nn.Module):
    def __init__(self, dim=64, max_freq=1000.0):
        super().__init__()
        self.freqs = torch.exp(
            torch.linspace(0, math.log(max_freq), dim//2)
        )

    def forward(self, t):
        t_norm = t.float() / float(T)
        args = t_norm.unsqueeze(-1) * self.freqs.to(t.device)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# ===========================
# EGNN Layer
# ===========================
class EGNNLayer(nn.Module):
    def __init__(self, time_dim=64, hidden_dim=64):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(1 + time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh()           # keep weights stable
        )

    def forward(self, x, t_emb):
        B, N, _ = x.shape

        diff = x.unsqueeze(2) - x.unsqueeze(1)           # (B,N,N,3)
        r2 = (diff**2).sum(dim=-1, keepdim=True)         # (B,N,N,1)

        te = t_emb.unsqueeze(1).unsqueeze(2).expand(B,N,N,-1)
        h = torch.cat([r2, te], dim=-1)

        w = self.edge_mlp(h)                             # (B,N,N,1)

        dx = (w * diff).sum(dim=2)                       # (B,N,3)

        return x + dx

# ===========================
# EGNN model: predicts x0 (not score)
# ===========================
class EGNN_x0(nn.Module):
    def __init__(self, n_layers=3, time_dim=64, hidden_dim=64):
        super().__init__()
        self.t_emb = FourierTimeEmbedding(time_dim)
        self.layers = nn.ModuleList(
            [EGNNLayer(time_dim=time_dim, hidden_dim=hidden_dim)
             for _ in range(n_layers)]
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 3)
        )

    def forward(self, x_t, t):
        B, N, _ = x_t.shape
        t_emb = self.t_emb(t)

        x = x_t
        for layer in self.layers:
            x = layer(x, t_emb)

        x_flat = x.reshape(B*N, 3)
        x0_flat = self.out_mlp(x_flat)
        return x0_flat.reshape(B, N, 3)

model = EGNN_x0().to(device)


In [5]:
# ===========================
# Loss: predict x0
# ===========================
def loss_x0(model, batch_size):
    x0 = sample_x0_cloud(batch_size)

    t = torch.randint(0, T, (batch_size,), device=device)
    alpha_bar_t = alpha_bar[t].view(-1,1,1)
    sigma_t     = sigmas[t].view(-1,1,1)

    eps = torch.randn_like(x0)
    x_t = torch.sqrt(alpha_bar_t)*x0 + sigma_t*eps

    x0_hat = model(x_t, t)

    return ((x0_hat - x0)**2).mean()

# ===========================
# Training loop
# ===========================
batch_size = 8
lr = 2e-4
num_steps = 10000

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=200,
    cooldown=100,
    min_lr=1e-8
)
loss_history = []

model.train()
for step in range(1, num_steps+1):
    loss = loss_x0(model, batch_size)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())

    if step % 200 == 0:
        print(f"step {step:5d} | loss={loss.item():.4f}")


step   200 | loss=82060.0859
step   400 | loss=13802.3799
step   600 | loss=2576.5325
step   800 | loss=4201.8862
step  1000 | loss=543.0429
step  1200 | loss=1893.8910
step  1400 | loss=2017.7432
step  1600 | loss=1149.2738
step  1800 | loss=19.5828
step  2000 | loss=17.8711
step  2200 | loss=98.6815
step  2400 | loss=10.8795
step  2600 | loss=188.3660
step  2800 | loss=1393.7628
step  3000 | loss=3816.8313
step  3200 | loss=617.4590
step  3400 | loss=6.9975
step  3600 | loss=9576.6611
step  3800 | loss=1398.7972
step  4000 | loss=12.6269
step  4200 | loss=1784.3511
step  4400 | loss=251.9515
step  4600 | loss=1251.1548
step  4800 | loss=3.9039
step  5000 | loss=3676.2754
step  5200 | loss=5.2457
step  5400 | loss=43.9872
step  5600 | loss=2.9570
step  5800 | loss=3.1693
step  6000 | loss=11.9074
step  6200 | loss=18.7120
step  6400 | loss=26.8876
step  6600 | loss=4.4952
step  6800 | loss=3.5921
step  7000 | loss=18.0147
step  7200 | loss=2.7119
step  7400 | loss=421.3731
step  7600 

In [7]:
import matplotlib.pyplot as plt

def plot_loss_curve(loss_list):
    """
    loss_list: Python list or Tensor containing loss values for each step
    """
    plt.figure(figsize=(8,4))
    plt.plot(loss_list, linewidth=1.5)
    plt.xlabel("Training step")
    plt.ylabel("Loss (DSM)")
    plt.title("DSM Training Loss Curve")
    plt.grid(True)
    plt.tight_layout()
    plt.show()

loss_history.pop([1:10])
plot_loss_curve(loss_history)

SyntaxError: invalid syntax (2337073829.py, line 16)

## 試著做做看 validation loss

In [8]:
@torch.no_grad()
def val_loss_x0(model, batch_size):
    x0 = sample_x0_cloud(batch_size)
    t = torch.randint(int(0.1*T), T, (batch_size,), device=device)

    alpha_bar_t = alpha_bar[t].view(-1,1,1)
    sigma_t     = sigmas[t].view(-1,1,1)

    eps = torch.randn_like(x0)
    x_t = torch.sqrt(alpha_bar_t)*x0 + sigma_t*eps

    x0_hat = model(x_t, t)

    return ((x0_hat - x0)**2).mean().item()

if step % 500 == 0:
    vloss = val_loss_x0(model, batch_size)
    print(f"[VAL] step {step} | val_loss={vloss:.4f}")


[VAL] step 10000 | val_loss=39.5570


## Reverse Diffusion

In [9]:
@torch.no_grad()
def reverse_sample(model, num_samples=200, n_points=N_POINTS):
    """
    反向 diffusion → 生成點雲
    產生 shape = (num_samples, n_points, 3)
    """
    model.eval()

    # 先從純 Gaussian 噪聲開始
    x_t = torch.randn(num_samples, n_points, 3, device=device)

    for t in range(T-1, -1, -1):
        t_batch = torch.full((num_samples,), t, device=device, dtype=torch.long)

        # predict x0
        x0_hat = model(x_t, t_batch)   # (B,N,3)

        alpha_bar_t   = alpha_bar[t]
        if t > 0:
            alpha_bar_prev = alpha_bar[t-1]
        else:
            alpha_bar_prev = torch.tensor(1.0, device=device)

        # reconstruction step
        mean = torch.sqrt(alpha_bar_prev) * x0_hat
        var  = 1 - alpha_bar_prev

        if t > 0:
            noise = torch.randn_like(x_t)
            x_t = mean + torch.sqrt(var) * noise
        else:
            x_t = mean

    return x_t.cpu()


## Visualization

In [10]:
import matplotlib.pyplot as plt

def plot_scatter_2d(samples):
    # samples shape = (B, N, 3)
    pts = samples.reshape(-1, 3)
    x = pts[:,0]
    y = pts[:,1]

    plt.figure(figsize=(5,5))
    plt.scatter(x, y, s=5, alpha=0.5)
    plt.title("Reverse Diffusion: 2D Scatter (x-y plane)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.grid(True)
    plt.show()


In [11]:
def plot_histogram_x(samples):
    pts = samples.reshape(-1, 3)
    x = pts[:,0]

    plt.figure(figsize=(6,4))
    plt.hist(x, bins=50, density=True, alpha=0.7)
    plt.title("1D Histogram of x (should show bimodal mixture)")
    plt.xlabel("x")
    plt.ylabel("density")
    plt.grid(True)
    plt.show()


In [12]:
import seaborn as sns
import numpy as np

def plot_kde_heatmap(samples):
    pts = samples.reshape(-1, 3)
    x = pts[:,0]
    y = pts[:,1]

    plt.figure(figsize=(6,5))
    sns.kdeplot(x=x, y=y, fill=True, cmap="viridis", thresh=0.05, levels=30)
    plt.title("2D KDE Heatmap (Density of Generated Samples)")
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()


In [13]:
from mpl_toolkits.mplot3d import Axes3D

def plot_scatter_3d(samples):
    pts = samples.reshape(-1, 3)
    x = pts[:,0]
    y = pts[:,1]
    z = pts[:,2]

    fig = plt.figure(figsize=(6,6))
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(x, y, z, s=5, alpha=0.4)

    ax.set_title("3D Scatter of Generated Samples")
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    ax.set_zlabel("z")

    plt.show()


In [16]:
samples = reverse_sample(model, num_samples=200, n_points=32)

plot_scatter_2d(samples)
#plot_histogram_x(samples)
#plot_kde_heatmap(samples)
#plot_scatter_3d(samples)


KeyboardInterrupt: 