# In [None]:
# =========================================================
# ESM vs DSM on 1D Gaussian: x ~ N(3, sigma_data^2)
# 目標：在「白板例子」上驗證 ESM 與 DSM 的對齊關係
# =========================================================
import math, random, numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt

# Seed every random number generator used in this script so that repeated
# runs draw the same samples.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Tensors created by this script are placed on this device; use
# torch.device("cuda") instead if a GPU is available.
device = torch.device("cpu")

# ----------------------------
# Data: 1D Gaussian  N(3, sigma_data^2)
# ----------------------------
mu_data = 3.0     # mean of the data distribution
sigma_data = 1.0  # std of the data distribution (1 is the usual whiteboard choice; change freely)

def sample_data(n):
    """Draw ``n`` samples from N(mu_data, sigma_data^2) as an ``(n, 1)`` tensor on ``device``."""
    standard_normal = torch.randn(n, 1, device=device)
    return mu_data + sigma_data * standard_normal

def true_score(x):
    """Return the analytic score of the data distribution at ``x``.

    For N(mu_data, sigma_data^2):
        s(x) = d/dx log p(x) = -(x - mu_data) / sigma_data^2
    """
    variance = sigma_data ** 2
    return -(x - mu_data) / variance

# ----------------------------
# Model: tiny MLP score(x, [log σ])
# ----------------------------
class Score1D(nn.Module):
    """Tiny MLP that maps a scalar input (optionally with log-sigma) to a scalar score.

    Args:
        hidden: Width of the two hidden layers.
        use_sigma: If True, the network input is ``[x, log(sigma)]`` (two
            features); otherwise it is ``x`` alone (one feature).

    All ``nn.Linear`` layers are initialised with Xavier-uniform weights and
    zero biases.
    """

    def __init__(self, hidden=32, use_sigma=False):
        super().__init__()
        self.use_sigma = use_sigma
        in_dim = 2 if use_sigma else 1
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.SiLU(),
            nn.Linear(hidden, hidden), nn.SiLU(),
            nn.Linear(hidden, 1),
        )
        for m in self.net:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x, sigma=None):
        """Evaluate the score network.

        Args:
            x: Input tensor; it is concatenated with log(sigma) along dim 1
                when ``use_sigma`` is True, so it must be 2-D in that case.
            sigma: Noise level. Ignored when ``use_sigma`` is False. When
                ``use_sigma`` is True it is required and may be a Python
                number or a tensor that can be expanded to ``x``'s shape.

        Returns:
            The output of the MLP applied to ``x`` (or to ``[x, log(sigma)]``).

        Raises:
            ValueError: If the model was built with ``use_sigma=True`` but no
                ``sigma`` is supplied.
        """
        if not self.use_sigma:
            # Unconditional model: any sigma passed in is deliberately ignored.
            return self.net(x)
        if sigma is None:
            # Without this check the 1-feature input would hit a 2-feature
            # Linear layer and fail with an opaque shape-mismatch error.
            raise ValueError(
                "Score1D was built with use_sigma=True; forward() requires a sigma value"
            )
        # Coerce numbers and tensors alike to x's dtype/device so that the
        # concatenation below never mixes dtypes or devices.
        sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
        h = torch.cat([x, torch.log(sigma).expand_as(x)], dim=1)
        return self.net(h)

# ----------------------------
# Losses
# ----------------------------
def esm_loss(model, n_batch=1024):
    """Explicit score matching loss against the known true score.

        L_ESM = E_x [ || s_theta(x) - s_true(x) ||^2 ]

    No Hutchinson estimator is needed here because the true score of this
    whiteboard example is available in closed form.

    Args:
        model: Callable invoked as ``model(x, sigma=None)``.
        n_batch: Number of data samples drawn for the Monte Carlo estimate.

    Returns:
        Scalar tensor holding the mean squared error over the batch.
    """
    samples = sample_data(n_batch)
    residual = model(samples, sigma=None) - true_score(samples)
    return residual.pow(2).mean()

def dsm_loss_fixed_sigma(model, sigma_noise=0.2, n_batch=1024, lambda_by_sigma=True, clip=50.0):
    """Denoising score matching loss at a single fixed noise level.

        y = x0 + sigma_noise * z,   z ~ N(0, 1)
        target = (x0 - y) / sigma_noise^2

    Args:
        model: Callable invoked as ``model(y, sigma=sigma_noise)``.
        sigma_noise: Standard deviation of the perturbation noise.
        n_batch: Number of data samples drawn.
        lambda_by_sigma: If truthy, multiply each squared error by
            ``sigma_noise**2`` (a common reweighting).
        clip: If not None, clamp the target to ``[-clip, clip]``.

    Returns:
        Scalar tensor holding the mean (optionally reweighted) squared error.
    """
    clean = sample_data(n_batch)
    noise = torch.randn_like(clean)
    noisy = clean + sigma_noise * noise

    variance = sigma_noise ** 2
    score_target = (clean - noisy) / variance
    if clip is not None:
        score_target = score_target.clamp(-clip, clip)

    sq_err = (model(noisy, sigma=sigma_noise) - score_target).pow(2)
    if lambda_by_sigma:
        sq_err = variance * sq_err
    return sq_err.mean()

def dsm_loss_multisigma(model, n_batch=1024, low=5e-3, high=0.2, lambda_rule="sigma2", clip=50.0):
    """Denoising score matching loss with a per-sample noise level.

    Each sample gets its own sigma drawn log-uniformly: sigma ~ LogU(low, high).

    Args:
        model: Callable invoked as ``model(y, sigma=sigmas)`` where ``sigmas``
            has shape ``(n_batch, 1)``.
        n_batch: Number of data samples drawn.
        low: Lower bound of the noise level range.
        high: Upper bound of the noise level range.
        lambda_rule: ``"sigma2"`` multiplies each squared error by its
            ``sigma**2``; any other value leaves the errors unweighted.
        clip: If not None, clamp the target to ``[-clip, clip]``.

    Returns:
        Scalar tensor holding the mean (optionally reweighted) squared error.
    """
    clean = sample_data(n_batch)

    # Uniform in log-space: exp(log(low) + u * (log(high) - log(low))), shape (B, 1).
    unit = torch.rand(n_batch, 1, device=device)
    sigmas = torch.exp(unit * (math.log(high) - math.log(low)) + math.log(low))

    noise = torch.randn_like(clean)
    noisy = clean + sigmas * noise
    score_target = (clean - noisy) / (sigmas ** 2)
    if clip is not None:
        score_target = score_target.clamp(-clip, clip)

    sq_err = (model(noisy, sigma=sigmas) - score_target).pow(2)
    if lambda_rule == "sigma2":
        sq_err = (sigmas ** 2) * sq_err
    return sq_err.mean()

# ----------------------------
# Train loops（會印每步 loss）
# ----------------------------
def train_loop(model, loss_fn, steps=200, batch=1024, lr=1e-3, tag="ESM"):
    """Train ``model`` with Adam and return the per-step loss history.

    Args:
        model: Module to optimise; it is switched to train mode.
        loss_fn: Callable invoked as ``loss_fn(model, n_batch=batch)`` that
            returns a scalar loss tensor.
        steps: Number of optimisation steps.
        batch: Batch size forwarded to ``loss_fn``.
        lr: Adam learning rate.
        tag: Label printed in the progress lines.

    Returns:
        1-D float NumPy array holding the loss value of every step.
    """
    log_every = 20
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    losses = []
    for t in range(1, steps + 1):
        loss = loss_fn(model, n_batch=batch)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        # Cap the global gradient norm at 1.0 before the parameter update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        opt.step()
        losses.append(loss.item())
        if t % log_every != 0:
            continue
        print(f"[{tag} {t:04d}] loss={loss.item():.6f}, lr={opt.param_groups[0]['lr']:.2e}")
    return np.asarray(losses, dtype=float)

# ----------------------------
# 實驗：白板版 ESM vs DSM
# ----------------------------
# Training budget used by the experiments below.
STEPS = 200
BATCH = 1024

# ESM: learn s(x) ≈ -(x-3)/sigma_data^2 by regressing onto the true score.
esm_net = Score1D(hidden=32, use_sigma=False).to(device)
# train_loop calls loss_fn(model, n_batch=batch), which matches esm_loss's
# signature, so the function is passed directly.
esm_curve = train_loop(esm_net, esm_loss, steps=STEPS, batch=BATCH, lr=1e-3, tag="ESM")

# DSM with multiple noise levels: the loss is unweighted for the first
# REWEIGHT_AFTER steps and sigma^2-weighted afterwards.
REWEIGHT_AFTER = 120

dsm_net = Score1D(hidden=32, use_sigma=True).to(device)
dsm_net.train()  # same mode train_loop sets for the ESM model
# Build the optimizer once, up front, rather than lazily inside the loop.
opt = torch.optim.Adam(dsm_net.parameters(), lr=3e-3)
dsm_curve = []
for t in range(1, STEPS+1):
    rule = "none" if t <= REWEIGHT_AFTER else "sigma2"
    loss = dsm_loss_multisigma(dsm_net, n_batch=BATCH, low=5e-3, high=0.2,
                               lambda_rule=rule, clip=50.0)
    opt.zero_grad(set_to_none=True)
    loss.backward()
    # Cap the global gradient norm at 1.0 before the parameter update.
    torch.nn.utils.clip_grad_norm_(dsm_net.parameters(), 1.0)
    opt.step()
    dsm_curve.append(loss.item())
    if t % 20 == 0:
        print(f"[DSM  {t:04d}] loss={loss.item():.6f}, rule={rule}, lr={opt.param_groups[0]['lr']:.2e}")
dsm_curve = np.array(dsm_curve, dtype=float)

# ----------------------------
# Cross-evaluation（在白板例子上檢查「接近」）
# ----------------------------
# Evaluation only: keep both the sampling and the model forward pass inside
# no_grad so no autograd graph is built for the 4000-sample batch.
with torch.no_grad():
    X_eval = sample_data(4000)

    # 1) ESM native: squared error of the ESM model against the true score.
    native_esm = ((esm_net(X_eval) - true_score(X_eval))**2).mean().item()

# 2) DSM native（在 σ→小 的 regime）
def dsm_eval_near_zero(model, x, sigmas=(1e-2, 2e-2, 5e-2)):
    """Average unweighted, unclipped DSM squared error over several noise levels.

    For each ``s`` in ``sigmas`` the input is perturbed as ``y = x + s * z``
    with ``z ~ N(0, 1)``, the target is ``(x - y) / s^2``, and the model is
    called as ``model(y, sigma=s)``.

    Args:
        model: Sigma-conditioned score model.
        x: Clean evaluation samples.
        sigmas: Non-empty iterable of noise levels to evaluate.

    Returns:
        Python float: mean over ``sigmas`` of the per-level mean squared error.
    """
    vals = []
    # The result is reduced to Python floats, so gradient tracking would only
    # cost memory and time.
    with torch.no_grad():
        for s in sigmas:
            z = torch.randn_like(x)
            y = x + s * z
            target = (x - y) / (s * s)
            pred = model(y, sigma=s)
            vals.append(((pred - target) ** 2).mean().item())
    return float(sum(vals) / len(vals))

# DSM loss of the DSM-trained model on the evaluation batch at small noise levels.
native_dsm_small = dsm_eval_near_zero(model=dsm_net, x=X_eval)

# 3) DSM 模型的 ESM（用很小 σ，期望接近）
def esm_of_dsm(model, x, tiny_sigma=1e-2):
    """ESM error of a sigma-conditioned model evaluated at a tiny noise level.

    The model's output at ``sigma=tiny_sigma`` is used as an approximation of
    the data score s(x) and compared with the true score.

    Args:
        model: Sigma-conditioned score model, called as ``model(x, sigma=tiny_sigma)``.
        x: Clean evaluation samples.
        tiny_sigma: Noise level passed to the model.

    Returns:
        Python float: mean squared error against ``true_score(x)``.
    """
    # The result is reduced to a Python float, so no autograd graph is needed.
    with torch.no_grad():
        s_pred = model(x, sigma=tiny_sigma)
        return ((s_pred - true_score(x)) ** 2).mean().item()

# ESM error of the DSM-trained model, queried at sigma = 1e-2.
cross_esm_on_dsm = esm_of_dsm(model=dsm_net, x=X_eval, tiny_sigma=1e-2)

# 4) ESM 模型的 DSM（小 σ）
def dsm_of_esm(model, x, sigmas=(1e-2, 2e-2, 5e-2)):
    """Average unweighted DSM squared error of an unconditional score model.

    For each ``s`` in ``sigmas`` the input is perturbed as ``y = x + s * z``
    with ``z ~ N(0, 1)`` and the target is ``(x - y) / s^2``. The model is
    called as ``model(y, sigma=None)`` because the ESM model takes no sigma.

    Args:
        model: Unconditional score model.
        x: Clean evaluation samples.
        sigmas: Non-empty iterable of noise levels to evaluate.

    Returns:
        Python float: mean over ``sigmas`` of the per-level mean squared error.
    """
    vals = []
    # The result is reduced to Python floats, so gradient tracking would only
    # cost memory and time.
    with torch.no_grad():
        for s in sigmas:
            z = torch.randn_like(x)
            y = x + s * z
            target = (x - y) / (s * s)
            pred = model(y, sigma=None)  # the ESM model does not take sigma
            vals.append(((pred - target) ** 2).mean().item())
    return float(sum(vals) / len(vals))

# DSM loss of the ESM-trained model on the evaluation batch (default small sigmas).
cross_dsm_on_esm = dsm_of_esm(esm_net, X_eval)

# Print the two native and two cross evaluations, six decimals each.
print("\n=== Evaluation on 1D Gaussian ===")
print(f"Native ESM (to true score):        {native_esm:.6f}")
# The tuple in the label is written out literally here, not read from the function default.
print(f"Native DSM (σ in { (1e-2,2e-2,5e-2) }): {native_dsm_small:.6f}")
print(f"ESM of DSM-model (σ≈1e-2):         {cross_esm_on_dsm:.6f}")
print(f"DSM of ESM-model (avg small σ):    {cross_dsm_on_esm:.6f}")

# ----------------------------
# 圖：訓練曲線
# ----------------------------
# Draw both per-step training-loss curves in one figure.
plt.figure(figsize=(6, 4))
plt.plot(esm_curve, label="ESM train loss")
plt.plot(dsm_curve, label="DSM train loss (multi-σ)")

# Axis labels and title.
plt.xlabel("Step")
plt.ylabel("Loss")
plt.title("Training curves on 1D Gaussian")

# Legend, layout, display.
plt.legend()
plt.tight_layout()
plt.show()