In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from functools import partial


In [2]:
class GaussianFourierProjection(nn.Module):
    '''
    Gaussian random Fourier features used to encode the time step.

    The output has 2 * (embed_dim // 2) features: the sines of the projected input
    followed by the cosines.
    '''
    def __init__(self, embed_dim, scale=30.):
        super().__init__()

        # Frequencies are sampled once at construction and never trained.
        # Wrapping them in a Parameter with requires_grad=False keeps them in the
        # module's state_dict while excluding them from gradient updates.
        frequencies = torch.randn(embed_dim // 2) * scale
        self.W = nn.Parameter(frequencies, requires_grad=False)

    def forward(self, x):
        # Presumably x is a 1-D batch of time values, shape (batch,) — TODO confirm
        # against callers. A new axis is inserted at position 1 so every time value
        # is multiplied with every frequency.
        phase = x.unsqueeze(1) * self.W.unsqueeze(0) * 2 * np.pi
        return torch.cat((phase.sin(), phase.cos()), dim=-1)
    
class Dense(nn.Module):
    '''
    Fully connected layer whose output gets two trailing singleton axes.
    '''
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.dense = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        # The two extra trailing axes turn (..., output_dim) into (..., output_dim, 1, 1).
        # ScoreNet adds this result to convolution outputs, so the singleton axes let
        # the per-channel value broadcast over the two spatial dimensions.
        projected = self.dense(x)
        return projected.unsqueeze(-1).unsqueeze(-1)

class ScoreNet(nn.Module):
    '''
    Time-dependent score-based model built as a small U-Net.

    The encoder applies four unpadded 3x3 convolutions (strides 1, 2, 2, 2). The
    decoder mirrors them with transposed convolutions and concatenates the matching
    encoder activation (skip connection) before each of the last three up-sampling
    steps. A projection of the time embedding is added to the feature map at every
    stage except the final output layer.
    '''
    def __init__(self, marginal_prob_std, channels=(32, 64, 128, 256), embed_dim=256):
        ''' Init

        Args:
            - marginal_prob_std: function that takes t and returns the standard
                deviation of the perturbation kernel p_{0t}(x(t) | x(0))
            - channels: number of feature-map channels at each of the four
                resolutions. GroupNorm requires channels[0] to be divisible by
                both 4 and 32, and the remaining entries by 32.
            - embed_dim: dimensionality of the Gaussian random feature time embedding
        '''
        super().__init__()

        # Time embedding: Gaussian random features followed by a linear layer
        self.embed = nn.Sequential(GaussianFourierProjection(embed_dim=embed_dim),
                                   nn.Linear(embed_dim, embed_dim))

        # Encoding: resolution decreases (the input must have exactly 1 channel)
        self.conv1 = nn.Conv2d(1, channels[0], 3, stride=1, bias=False)
        self.dense1 = Dense(embed_dim, channels[0])
        self.gnorm1 = nn.GroupNorm(4, num_channels=channels[0])

        self.conv2 = nn.Conv2d(channels[0], channels[1], 3, stride=2, bias=False)
        self.dense2 = Dense(embed_dim, channels[1])
        self.gnorm2 = nn.GroupNorm(32, num_channels=channels[1])

        self.conv3 = nn.Conv2d(channels[1], channels[2], 3, stride=2, bias=False)
        self.dense3 = Dense(embed_dim, channels[2])
        self.gnorm3 = nn.GroupNorm(32, num_channels=channels[2])

        self.conv4 = nn.Conv2d(channels[2], channels[3], 3, stride=2, bias=False)
        self.dense4 = Dense(embed_dim, channels[3])
        self.gnorm4 = nn.GroupNorm(32, num_channels=channels[3])

        # Decoding: resolution increases.
        # Each GroupNorm must be sized by the channel count its transposed conv
        # produces, i.e. channels[k] (an int), not a list literal.
        self.tconv4 = nn.ConvTranspose2d(channels[3], channels[2], 3, stride=2, bias=False)
        self.dense5 = Dense(embed_dim, channels[2])
        self.tgnorm4 = nn.GroupNorm(32, num_channels=channels[2])

        # Input channels are doubled from here on because of the skip concatenation.
        self.tconv3 = nn.ConvTranspose2d(channels[2] + channels[2], channels[1], 3, stride=2, bias=False, output_padding=1)
        self.dense6 = Dense(embed_dim, channels[1])
        self.tgnorm3 = nn.GroupNorm(32, num_channels=channels[1])

        self.tconv2 = nn.ConvTranspose2d(channels[1] + channels[1], channels[0], 3, stride=2, bias=False, output_padding=1)
        self.dense7 = Dense(embed_dim, channels[0])
        self.tgnorm2 = nn.GroupNorm(32, num_channels=channels[0])

        self.tconv1 = nn.ConvTranspose2d(channels[0] + channels[0], 1, 3, stride=1)

        # The Swish activation function (a.k.a. SiLU): x * sigmoid(x)
        self.act = lambda x: x * torch.sigmoid(x)
        self.marginal_prob_std = marginal_prob_std

    def forward(self, x, t):
        ''' Estimate the score at perturbed input x and time t.

        Args:
            - x: batch of perturbed inputs with 1 channel; presumably
                (batch, 1, H, W). With a 28x28 input the spatial sizes go
                26 -> 12 -> 5 -> 2 and decode back to 28 so the skip connections
                line up; other sizes may not — TODO confirm the intended input size.
            - t: time values, presumably shape (batch,) — TODO confirm

        Returns:
            - network output divided by the marginal std at t, same shape as x
              when the skip connections line up
        '''
        # Embed time t with Gaussian random features
        embed = self.act(self.embed(t))

        # Encoding
        h1 = self.act(self.gnorm1(self.conv1(x) + self.dense1(embed)))
        h2 = self.act(self.gnorm2(self.conv2(h1) + self.dense2(embed)))
        h3 = self.act(self.gnorm3(self.conv3(h2) + self.dense3(embed)))
        h4 = self.act(self.gnorm4(self.conv4(h3) + self.dense4(embed)))

        # Decoding: each step concatenates the encoder activation of equal resolution
        h = self.act(self.tgnorm4(self.tconv4(h4) + self.dense5(embed)))
        h = self.act(self.tgnorm3(self.tconv3(torch.cat([h, h3], dim=1)) + self.dense6(embed)))
        # tconv2 (not tconv3) consumes the channels[1] + channels[1] concatenation
        h = self.act(self.tgnorm2(self.tconv2(torch.cat([h, h2], dim=1)) + self.dense7(embed)))
        h = self.tconv1(torch.cat([h, h1], dim=1))

        # Rescale the output by 1 / std(t), broadcast over channel and spatial axes
        h = h / self.marginal_prob_std(t)[:, None, None, None]
        return h

\begin{alignat}{1}
dx = \sigma^t dw, ~ t\in [0, 1]\\
p_{0t}(x(t)|x(0)) = N\left( x(t); x(0), \frac 1 {2\log\sigma} (\sigma^{2t} -1 )\mathbf{I}\right)
\end{alignat}


In [3]:
# Prefer Apple's Metal backend (the original target of this notebook), but fall
# back to CUDA or CPU so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# sigma of the SDE dx = sigma^t dw
sigma = 25.

def marginal_prob_std(t, sigma):
    '''Compute the standard deviation of $p_{0t}(x(t)|x(0))$

    Args:
        - t: time step vector (tensor, array-like or scalar)
        - sigma: the sigma of the SDE

    Returns:
        - standard deviation of transition probability distribution p_0t,
          as a tensor on the notebook-level `device`

    '''
    # as_tensor reuses `t` when it is already a tensor instead of
    # copy-constructing it, which torch.tensor(tensor) warns against.
    t = torch.as_tensor(t, device=device)
    # cf. Eq(2): sqrt((sigma^(2t) - 1) / (2 log sigma))
    return torch.sqrt((sigma ** (2 * t) - 1.) / 2. / np.log(sigma))

def diffusion_coeff(t, sigma):
    '''Compute the diffusion coeff g(t): sigma^t

    Args:
        - t: time step vector (tensor or scalar)
        - sigma: SDE sigma

    Returns:
        - g(t) as a tensor on the notebook-level `device`
    '''
    # as_tensor accepts both a plain number and an existing tensor without
    # copy-constructing the latter.
    return torch.as_tensor(sigma ** t, device=device)

In [4]:
# Bind the SDE's sigma once so that both helpers can be called with `t` only.
# The tuple of original functions is built before either name is rebound.
marginal_prob_std, diffusion_coeff = (
    partial(fn, sigma=sigma) for fn in (marginal_prob_std, diffusion_coeff)
)

### Objectives

$s_\theta$가 학습해야하는 파라미터의 손실(목적)함수는 아래와 같다.
$$
\tag{3} \theta^* = \operatorname*{arg\,min}_\theta \Bbb{E}_{t\sim \mathcal{U}(0,T)}\left[\lambda(t)\,\Bbb{E}_{x(0) \sim p_0} \Bbb{E}_{x(t) \sim p_{0t}(x(t)|x(0))}\left[\|s_\theta(x(t),t) - \nabla_{x(t)} \log p_{0t}(x(t) | x(0))\|_2^2\right]\right]
$$

이때, 우리가 설정한 SDE 기준으로 transition probability distribution $p_{0t}$는 아래와 같이 정규 분포 꼴로 나타나게 된다.
$$
\tag{4}
\begin{aligned}
p_{0t}(x(t)|x(0)) &= \mathcal{N}\left(x(t);x(0),\Sigma^2\right)\\
\text{where}~~~ \Sigma^2 &= \frac 1 {2\log\sigma}(\sigma^{2t}-1)\mathbf{I}\\
p_{0t}(x(t) | x(0)) &= \frac {1} {\sqrt{2\pi\Sigma^2}}\exp\left(-\frac {(x(t) - x(0))^2} {2\Sigma^2}\right)
\end{aligned}
$$

정규 분포에 로그를 씌우고, 미분하게 되면 아래와 같은 간단한 형태의 수식을 얻을 수 있다.

$$
\tag{5}
\begin{aligned}
\log p_{0t}(x(t) | x(0)) &= -\frac {(x(t) - x(0))^2} {2\Sigma^2} - \log \sqrt{2\pi\Sigma^2}\\
\nabla_{x(t)} \log p_{0t}(x(t)|x(0)) &= - \frac {x(t)-x(0)} {\Sigma^2}
\end{aligned}
$$

$\Sigma^2$를 대입해준다. (cf. Eq.4)
$$
\tag{6} \nabla_{x(t)} \log p_{0t}(x(t)|x(0)) = - \frac {x(t) - x(0)} {(\sigma^{2t} - 1)}2\log\sigma 
$$

In [5]:
# cf. Eq(3)
def loss_fn(model, x, marginal_prob_std, eps=1e-5):
    '''The loss function for training score-based generative models. (objectives)

    Args:
        - model: PyTorch model instance, called as model(x_t, t)
            (time-dependent, score-based model e.g. U-net)
        - x: A mini-batch of training data; indexed here as a 4-D tensor,
            presumably (batch, channel, height, width) — TODO confirm.
        - marginal_prob_std: function returning the standard deviation of the
            perturbed distribution at time t
        - eps: A tolerance value for numerical stability (smallest t sampled)

    Returns:
        - scalar tensor: the batch mean of the per-sample squared error
          ||std * score + z||^2 summed over axes 1, 2 and 3
    '''
    # One random t per sample, drawn uniformly.
    # eps rescales t ∈ [0, 1) to t ∈ [eps, 1) so that std(t) is never 0.
    random_t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps

    # Standard normal noise with the same shape as x
    z = torch.randn_like(x)

    # std of p_0t at the sampled t
    std = marginal_prob_std(random_t)

    # x(0) -> x(t)
    perturbed_x = x + z * std[:, None, None, None]

    # s_theta(x(t), t) ~= ∇ log p(x(t) | x(0)) = -z / std.
    # Multiplying the residual by std gives (score * std + z).
    score = model(perturbed_x, random_t)
    loss = torch.mean(torch.sum((score * std[:, None, None, None] + z) ** 2, dim=(1, 2, 3)))
    return loss
    


In [6]:
# Sanity check of the affine map loss_fn applies to its random times, using a
# large eps (1e-1) so the shift is visible: samples from [0, 1) land in [0.1, 1).
rt = torch.rand(5)
rt.mul(1. - 1e-1).add(1e-1), rt

(tensor([0.6054, 0.4844, 0.2440, 0.4463, 0.8068]),
 tensor([0.5615, 0.4271, 0.1600, 0.3848, 0.7853]))