In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===========================
# Diffusion schedule
# ===========================
T = 1000
beta_start = 1e-4
beta_end   = 2e-2
betas = torch.linspace(beta_start, beta_end, T, device=device)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)
sigmas = torch.sqrt(1.0 - alpha_bar)

# ===========================
# Data: mixture Gaussian -> point cloud
# ===========================
mu1 = torch.tensor([+2.0, 0.0, 0.0], device=device)
mu2 = torch.tensor([-2.0, 0.0, 0.0], device=device)

N_POINTS = 64

def sample_x0_cloud(batch_size: int, n_points: int = N_POINTS):
    B = batch_size
    mix = torch.bernoulli(0.5*torch.ones(B, n_points, device=device))  # (B,N)
    mu = torch.where(mix.unsqueeze(-1)==1, mu1, mu2)  # (B,N,3)
    eps = torch.randn(B, n_points, 3, device=device)
    return mu + eps

# ===========================
# Fourier time embedding
# ===========================
class FourierTimeEmbedding(nn.Module):
    def __init__(self, dim=64, max_freq=1000.0):
        super().__init__()
        self.freqs = torch.exp(
            torch.linspace(0, math.log(max_freq), dim//2)
        )

    def forward(self, t):
        t_norm = t.float() / float(T)
        args = t_norm.unsqueeze(-1) * self.freqs.to(t.device)
        return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

# ===========================
# EGNN Layer
# ===========================
class EGNNLayer(nn.Module):
    def __init__(self, time_dim=64, hidden_dim=64):
        super().__init__()
        self.edge_mlp = nn.Sequential(
            nn.Linear(1 + time_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 1),
            nn.Tanh()           # keep weights stable
        )

    def forward(self, x, t_emb):
        B, N, _ = x.shape

        diff = x.unsqueeze(2) - x.unsqueeze(1)           # (B,N,N,3)
        r2 = (diff**2).sum(dim=-1, keepdim=True)         # (B,N,N,1)

        te = t_emb.unsqueeze(1).unsqueeze(2).expand(B,N,N,-1)
        h = torch.cat([r2, te], dim=-1)

        w = self.edge_mlp(h)                             # (B,N,N,1)

        dx = (w * diff).sum(dim=2)                       # (B,N,3)

        return x + dx

# ===========================
# EGNN model: predicts x0 (not score)
# ===========================
class EGNN_x0(nn.Module):
    def __init__(self, n_layers=3, time_dim=64, hidden_dim=64):
        super().__init__()
        self.t_emb = FourierTimeEmbedding(time_dim)
        self.layers = nn.ModuleList(
            [EGNNLayer(time_dim=time_dim, hidden_dim=hidden_dim)
             for _ in range(n_layers)]
        )
        self.out_mlp = nn.Sequential(
            nn.Linear(3, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, 3)
        )

    def forward(self, x_t, t):
        B, N, _ = x_t.shape
        t_emb = self.t_emb(t)

        x = x_t
        for layer in self.layers:
            x = layer(x, t_emb)

        x_flat = x.reshape(B*N, 3)
        x0_flat = self.out_mlp(x_flat)
        return x0_flat.reshape(B, N, 3)

model = EGNN_x0().to(device)

# ===========================
# Loss: predict x0
# ===========================
def loss_x0(model, batch_size):
    x0 = sample_x0_cloud(batch_size)

    t = torch.randint(0, T, (batch_size,), device=device)
    alpha_bar_t = alpha_bar[t].view(-1,1,1)
    sigma_t     = sigmas[t].view(-1,1,1)

    eps = torch.randn_like(x0)
    x_t = torch.sqrt(alpha_bar_t)*x0 + sigma_t*eps

    x0_hat = model(x_t, t)

    return ((x0_hat - x0)**2).mean()

# ===========================
# Training loop
# ===========================
batch_size = 8
lr = 2e-4
num_steps = 10000

optimizer = optim.Adam(model.parameters(), lr=lr)
loss_history = []

model.train()
for step in range(1, num_steps+1):
    loss = loss_x0(model, batch_size)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())

    if step % 200 == 0:
        print(f"step {step:5d} | loss={loss.item():.4f}")


step   200 | loss=22976.5449
step   400 | loss=29453.3379
step   600 | loss=6475.1079
step   800 | loss=2678.5947
step  1000 | loss=42146.5430
step  1200 | loss=2042.0707
step  1400 | loss=78.6997
step  1600 | loss=133.2355
step  1800 | loss=381.1270
step  2000 | loss=30.1496
step  2200 | loss=122.8817
step  2400 | loss=1051.9303
step  2600 | loss=304.2815
step  2800 | loss=58.0293
step  3000 | loss=165.8822
step  3200 | loss=58.8817
step  3400 | loss=71.2147
step  3600 | loss=7.5602
step  3800 | loss=11.5947
step  4000 | loss=10.4281
step  4200 | loss=20.7799
step  4400 | loss=232.3970
step  4600 | loss=91.5541
step  4800 | loss=159.9243
step  5000 | loss=101.6416
step  5200 | loss=34.0025
step  5400 | loss=16.4752
step  5600 | loss=19.3085
step  5800 | loss=11.2452
step  6000 | loss=10.1382
step  6200 | loss=16.9152
step  6400 | loss=6.0926
step  6600 | loss=1156.8262
step  6800 | loss=3.6356
step  7000 | loss=3.6834
step  7200 | loss=3.5911
step  7400 | loss=2.9371
step  7600 | loss

## 試著做做看 validation loss

In [2]:
@torch.no_grad()
def val_loss_x0(model, batch_size):
    x0 = sample_x0_cloud(batch_size)
    t = torch.randint(int(0.1*T), T, (batch_size,), device=device)

    alpha_bar_t = alpha_bar[t].view(-1,1,1)
    sigma_t     = sigmas[t].view(-1,1,1)

    eps = torch.randn_like(x0)
    x_t = torch.sqrt(alpha_bar_t)*x0 + sigma_t*eps

    x0_hat = model(x_t, t)

    return ((x0_hat - x0)**2).mean().item()

if step % 500 == 0:
    vloss = val_loss_x0(model, batch_size)
    print(f"[VAL] step {step} | val_loss={vloss:.4f}")


[VAL] step 10000 | val_loss=2.7006
