# 比較 DSM loss 跟 ESM loss

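* ***DSM Loss***: 先加噪 $y=x+\sigma z,\ z\sim \mathcal{N}(0,I)$, 再最小化 $$L_{DSM}(\theta)=\mathbb{E}\Big[\lambda(\sigma)\,\Big\|s_{\theta}(y,\sigma)-\frac{x-y}{\sigma^{2}}\Big\|^2\Big],$$ 其中回歸目標 $\frac{x-y}{\sigma^{2}}=-\frac{z}{\sigma}=\nabla_y\log q_\sigma(y\mid x)$ 是高斯加噪核的 score; 權重這邊用 $\lambda(\sigma)=\sigma^2$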
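* ***ESM Loss***: $$L_{ESM}(\theta)=\mathbb{E}_{x\sim p}\Big[\frac{1}{2}\|s_{\theta}(x)\|^2+\text{div}_{x}\,s_{\theta}(x)\Big],$$ 其中散度 $\text{div}_{x}\,s_{\theta}(x)=\text{tr}(J)$, $J=\partial s_{\theta}/\partial x$ 為 Jacobian。程式以 $\text{tr}(J)=\mathbb{E}_v[v^{T}Jv]$ 來估計 (v 為 Rademacher 隨機向量, 滿足 $\mathbb{E}[vv^{T}]=I$), 這個估計法稱為 **Hutchinson's trick**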


## DATASET
* 產生 2D 的高斯混合 (Gaussian mixture) 樣本, 兩個 normal 成分各佔一半:
  * 中心 $(0,0)$; covariance matrix $\begin{bmatrix}1 & 0.3\\ 0.3 & 1.2\end{bmatrix}$
  * 中心 $(3.3,3.3)$; covariance matrix $\begin{bmatrix}0.6 & -0.2\\ -0.2 & 0.8\end{bmatrix}$
* 建立一個小型的 MLP score network (兩層 SiLU 隱藏層)
* 計算 ESM 與 DSM loss 在未訓練模型上的單次數值 (不做訓練)

## Simulation

In [1]:
import torch, math, random, numpy as np
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Seed every RNG this notebook draws from (Python, NumPy, torch) so that the
# sampled data, the network initialisations and the loss estimates are
# reproducible.  The three generators are independent, so their seeding order
# does not matter.  All computation runs on the CPU.
SEED = 42
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(SEED)
del _seed_rng
device = torch.device("cpu")

In [7]:
def sample_gmm(n=8192):
    """Sample ``n`` 2-D points from a two-component Gaussian mixture.

    The first ``n // 2`` rows are drawn from N(mean=(0, 0),
    cov=[[1.0, 0.3], [0.3, 1.2]]); the remaining ``n - n // 2`` rows come from
    N(mean=(3.3, 3.3), cov=[[0.6, -0.2], [-0.2, 0.8]]).  The two groups are
    concatenated in that order (not shuffled) and moved to the module-level
    ``device``.
    """
    first_count = n // 2
    # (sample count, mean, covariance) for each component, in draw order.
    components = (
        (
            first_count,
            torch.tensor([0.0, 0.0]),
            torch.tensor([[1.0, 0.3], [0.3, 1.2]], dtype=torch.float32),
        ),
        (
            n - first_count,
            torch.tensor([3.3, 3.3]),
            torch.tensor([[0.6, -0.2], [-0.2, 0.8]], dtype=torch.float32),
        ),
    )
    chunks = []
    for count, mean, cov in components:
        # Reparameterisation: if e ~ N(0, I) and cov = C C^T (Cholesky),
        # then e @ C^T + mean ~ N(mean, cov).
        chol = torch.linalg.cholesky(cov)
        chunks.append(torch.randn(count, 2) @ chol.T + mean)
    return torch.cat(chunks, dim=0).to(device)

# Fixed evaluation batch of 4096 mixture samples.  It is flagged (in place) as
# requiring grad because the ESM objective differentiates the score w.r.t. x.
X = sample_gmm(n=4096)
X.requires_grad_(True)

In [9]:
class ScoreMLP(nn.Module):
    """Small MLP score network: Linear -> SiLU -> Linear -> SiLU -> Linear.

    Weights are Xavier-uniform initialised and biases are zeroed.  When
    ``use_sigma`` is True the network is conditioned on the noise level by
    appending ``log(sigma)`` to the input as one extra feature column.

    Args:
        in_dim: dimensionality of the data; also the output (score) dimension.
        hidden: width of both hidden layers.
        use_sigma: whether ``forward`` takes a noise level ``sigma``.
    """

    def __init__(self, in_dim, hidden, use_sigma=True):
        super().__init__()
        self.use_sigma = use_sigma
        # One extra input column carries log(sigma) for the noise-conditional model.
        eff_in = in_dim + (1 if use_sigma else 0)
        self.fc1 = nn.Linear(eff_in, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, in_dim)

        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_uniform_(layer.weight, gain=1.0)
            nn.init.zeros_(layer.bias)

    def forward(self, x, sigma=None):
        """Return the predicted score for every row of ``x``.

        Args:
            x: batch of points, one sample per row (the dim=1 concatenation
                assumes a 2-D ``(batch, in_dim)`` tensor).
            sigma: noise level.  It may be a Python number, a 0-d tensor, or a
                per-sample tensor of shape ``(batch,)`` or ``(batch, 1)``.  It is
                cast to ``x``'s dtype/device.  Required when ``use_sigma`` is
                True; ignored otherwise.

        Raises:
            ValueError: if the model was built with ``use_sigma=True`` and no
                ``sigma`` is given (``fc1`` expects the extra log-sigma column).
        """
        if self.use_sigma:
            if sigma is None:
                raise ValueError(
                    "ScoreMLP was built with use_sigma=True; forward() requires sigma"
                )
            # as_tensor also converts a tensor sigma of another dtype (e.g. float64),
            # which would otherwise break the concatenation / fc1 matmul.
            sigma = torch.as_tensor(sigma, dtype=x.dtype, device=x.device)
            # reshape(-1, 1) turns a scalar into (1, 1) and a per-sample (B,) into
            # (B, 1); both then expand to one log-sigma entry per row.
            log_sigma = torch.log(sigma).reshape(-1, 1).expand(x.shape[0], 1)
            h = torch.cat([x, log_sigma], dim=1)
        else:
            h = x
        h = F.silu(self.fc1(h))
        h = F.silu(self.fc2(h))
        return self.fc3(h)
    
# The ESM network scores clean data, so it takes no noise-level input; the DSM
# network is conditioned on sigma.  They are built in this order (ESM first)
# because layer initialisation draws from the global torch RNG.
model_ESM, model_DSM = (
    ScoreMLP(in_dim=2, hidden=64, use_sigma=conditioned).to(device)
    for conditioned in (False, True)
)


In [11]:
def ESM_loss(model, x, num_trace_samples=1):
    """Score-matching objective with a Hutchinson estimate of the divergence.

    Returns the batch mean of ``0.5 * ||s(x)||^2 + tr(ds/dx)``, where
    ``s = model(x, sigma=None)``.  The Jacobian trace is estimated as the
    average of ``v^T J v`` over ``num_trace_samples`` Rademacher vectors
    ``v`` (entries +1/-1, so ``E[v v^T] = I``).

    Args:
        model: callable ``model(x, sigma=None)`` returning a tensor shaped like ``x``.
        x: data batch, one sample per row.  If it does not already require
            grad, a detached copy that does is used, so the caller's tensor is
            left unchanged.
        num_trace_samples: number of Hutchinson probe vectors; must be >= 1.

    Returns:
        Scalar tensor.  ``create_graph=True`` keeps the divergence term
        differentiable, so the result can be backpropagated.

    Raises:
        ValueError: if ``num_trace_samples < 1`` (it would otherwise divide 0 by 0).
    """
    if num_trace_samples < 1:
        raise ValueError(f"num_trace_samples must be >= 1, got {num_trace_samples}")
    if not x.requires_grad:
        # Differentiate w.r.t. a fresh leaf instead of flipping the caller's flag in place.
        x = x.detach().requires_grad_(True)
    s = model(x, sigma=None)
    sq = 0.5 * (s * s).sum(dim=1)
    div_est_total = torch.zeros_like(sq)

    for _ in range(num_trace_samples):
        # Rademacher probe: every entry is +1 or -1 with probability 1/2.
        v = torch.empty_like(x).bernoulli_(0.5).mul_(2).sub_(1)
        # d/dx sum(s * v) = J^T v, so (J^T v) . v = v^T J v for every sample.
        (grad_x,) = torch.autograd.grad((s * v).sum(), x, create_graph=True)
        div_est_total = div_est_total + (grad_x * v).sum(dim=1)
    div_est = div_est_total / float(num_trace_samples)

    return (sq + div_est).mean()

def DSM_loss(model, x, sigma=0.2, lambda_by_sigma=True):
    """Denoising score-matching loss at a single noise level ``sigma``.

    Perturbs ``y = x + sigma * z`` with ``z ~ N(0, I)`` (drawn with
    ``torch.randn_like(x)``).  It then regresses ``model(y, sigma)`` onto the
    Gaussian-kernel score ``(x - y) / sigma^2``.

    Args:
        model: callable ``model(y, sigma=...)`` returning a tensor shaped like ``y``.
            It receives ``sigma`` as a 0-d tensor with ``x``'s dtype and device.
        x: clean data batch, one sample per row.
        sigma: scalar noise level; must be a positive, finite number.
        lambda_by_sigma: if True, weight the squared error by ``sigma^2``;
            otherwise the weight is 1.

    Returns:
        Scalar tensor: the batch mean of ``lambda * ||pred - target||^2``.

    Raises:
        ValueError: if ``sigma`` is not positive and finite.  ``sigma = 0`` would
            make the target 0/0 and silently produce a NaN loss.
    """
    sigma_value = float(sigma)
    # Chained comparison is False for NaN as well, so NaN is rejected too.
    if not 0.0 < sigma_value < float("inf"):
        raise ValueError(f"sigma must be a positive finite number, got {sigma!r}")
    # Match the data dtype so that float64 inputs get a float64 sigma as well.
    sigma = torch.tensor(sigma_value, dtype=x.dtype, device=x.device)
    z = torch.randn_like(x)
    y = x + sigma * z
    target = (x - y) / (sigma * sigma)
    pred = model(y.detach(), sigma=sigma)
    per = pred - target
    per = (per * per).sum(dim=1)
    if lambda_by_sigma:
        per = (sigma * sigma) * per

    return per.mean()