# Import Library

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
import math
import numpy as np


# Diffusion

## Extract

In [2]:
def extract(alphas, timestep, x_shape):
    """Pick per-sample schedule values and shape them for broadcasting.

    Gathers ``alphas[timestep]`` along dim 0, casts the result to float64,
    and reshapes it to ``[batch, 1, ..., 1]`` with ``len(x_shape)``
    dimensions, where ``batch`` is ``timestep.shape[0]``.

    Args:
        alphas: tensor of schedule values indexed by timestep
            (presumably 1-D of length T; confirm against callers).
        timestep: integer index tensor; its first dimension is the batch size.
        x_shape: shape of the tensor the result will be multiplied with;
            only its length is used.

    Returns:
        A float64 tensor of shape ``[batch] + [1] * (len(x_shape) - 1)``.
    """
    batch_size = timestep.shape[0]
    picked = alphas.gather(0, timestep)
    picked = picked.double().to(timestep.device)
    # One singleton axis for every non-batch dimension of x.
    singleton_axes = [1] * (len(x_shape) - 1)
    return picked.view([batch_size] + singleton_axes)


## Training

    betas   =   1 - alphas

- betas:
$$[\beta_1,\beta_2,...,\beta_T]$$
$$\beta_t = 1 - \alpha_t$$

- alphas:
$$ [\alpha_1, \alpha_2,...,\alpha_T] $$

- alphas_bar:
$$ [\alpha_1, \alpha_1 \alpha_2,..., \alpha_1 \alpha_2... \alpha_T]$$
$$[\bar\alpha_1, \bar\alpha_2,..., \bar\alpha_T]$$

- objective:
$$
   \arg \min_\theta \| \epsilon - \epsilon_\theta \| ^2
$$
- gradient:
$$
    \nabla_\theta \| \epsilon - \epsilon_\theta( \sqrt{\bar \alpha_t} x_0 + \sqrt{1 - \bar \alpha_t} \epsilon , t) \| ^2
$$

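The inputs of the U-Net are the noised sample $x_t$ and the timestep $t$:
$$x_t = \sqrt{\bar \alpha_t} x_0 + \sqrt{1 - \bar \alpha_t} \epsilon $$
$$t$$  
The regression target (compared against the U-Net output) is the noise:
$$\epsilon$$    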

In [3]:
class Training(nn.Module):
    """DDPM training objective: noise a clean batch and score the noise prediction.

    Args:
        model: callable invoked as ``model(x_t, t)``; its output is compared
            element-wise against the sampled noise.
        T: number of diffusion timesteps.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.

    Note the positional order here is ``(model, T, beta_1, beta_T)``.
    """

    def __init__(self, model, T, beta_1, beta_T):
        super().__init__()
        self.T = T
        self.model = model

        # Keep the schedule as buffers so it moves with the module
        # (``.to(device)``) and is saved in the state dict.
        # The schedule is computed in float64 for precision.
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double()
        )
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)

        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar)
        )
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar)
        )

    def forward(self, x_0):
        """Return the per-element squared error of the noise prediction.

        Args:
            x_0: batch of clean samples; the first dimension is the batch.

        Returns:
            ``F.mse_loss(model(x_t, t), noise, reduction='none')`` — an
            unreduced, element-wise loss tensor.
        """
        # One uniformly sampled timestep in [0, T) per batch element.
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)

        # ``extract`` returns float64; cast back to the data dtype so that
        # x_t is not silently promoted to float64 before reaching the model.
        signal_scale = extract(self.sqrt_alphas_bar, t, x_0.shape).to(x_0.dtype)
        noise_scale = extract(
            self.sqrt_one_minus_alphas_bar, t, x_0.shape
        ).to(x_0.dtype)
        x_t = signal_scale * x_0 + noise_scale * noise

        # reduction='none' keeps the element-wise squared error.
        loss = F.mse_loss(
            self.model(x_t, t), noise, reduction='none'
        )
        return loss

## Sampling

- alphas_bar_prev:
$$\bar\alpha_{t-1}$$
$$ [1, \alpha_1, \alpha_1 \alpha_2,..., \alpha_1 \alpha_2... \alpha_{T-1}]$$
$$[1, \bar\alpha_1, \bar\alpha_2,..., \bar\alpha_{T-1}]$$

objective:
$$x_0$$

- loop:
$$
x_{t-1} = \frac{x_t}{\sqrt{\alpha_t}} - \frac{1-\alpha_t}{\sqrt{\alpha_t} \sqrt{1- \bar \alpha_t}} \epsilon_\theta(x_t, t) + \sigma_t z
$$

where:
$$
    \sigma_t ^2 = \frac {1-\bar \alpha_{t-1}}{1-\bar \alpha_t} \beta_t
$$

In [4]:
class Sampling(nn.Module):
    """DDPM reverse-process sampler.

    Iterates ``x_{t-1} = coeff1[t] * x_t - coeff2[t] * eps_theta(x_t, t)
    + sigma_t * z`` from ``t = T-1`` down to ``t = 0``, with

    * ``coeff1[t] = 1 / sqrt(alpha_t)``
    * ``coeff2[t] = (1 - alpha_t) / (sqrt(alpha_t) * sqrt(1 - alpha_bar_t))``
    * ``sigma_t^2 = beta_t * (1 - alpha_bar_{t-1}) / (1 - alpha_bar_t)``

    Args:
        model: callable invoked as ``model(x_t, t)``; its output is used as
            the predicted noise and must have the same shape as ``x_t``.
        beta_1: first value of the linear beta schedule.
        beta_T: last value of the linear beta schedule.
        T: number of diffusion timesteps.

    Note the positional order here is ``(model, beta_1, beta_T, T)``, which
    differs from ``Training``.
    """

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # The schedule is computed in float64 for precision.
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        # F.pad's ``pad`` list is ordered from the last dimension backwards;
        # [1, 0] prepends one element (value 1) and appends none, giving
        # [1, alpha_bar_1, ..., alpha_bar_{T-1}] after trimming to length T.
        alphas_bar_prev = F.pad(alphas_bar, pad=[1, 0], value=1)[:T]

        # The posterior mean scales x_t by 1/sqrt(alpha_t) -- the per-step
        # alpha, not the cumulative product alpha_bar_t.
        self.register_buffer('coeff1', 1. / torch.sqrt(alphas))
        self.register_buffer(
            'coeff2', self.coeff1 * self.betas / torch.sqrt(1. - alphas_bar)
        )

        self.register_buffer(
            'posterior_var',
            self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar),
        )

    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        """Return the posterior mean of ``x_{t-1}`` given ``x_t`` and predicted noise.

        Args:
            x_t: current noisy batch.
            t: integer timestep index per batch element.
            eps: predicted noise; must have the same shape as ``x_t``.
        """
        assert x_t.shape == eps.shape
        # ``extract`` returns float64; cast to the sample dtype so x_t keeps
        # its dtype across the sampling loop.
        scale_x = extract(self.coeff1, t, x_t.shape).to(x_t.dtype)
        scale_eps = extract(self.coeff2, t, x_t.shape).to(x_t.dtype)
        return scale_x * x_t - scale_eps * eps

    def p_mean_variance(self, x_t, t):
        """Return ``(mean, std)`` of the reverse step at timestep ``t``.

        Despite the name, the second element is the standard deviation
        (square root of the posterior variance), not the variance.
        """
        var = extract(self.posterior_var, t, x_t.shape)
        std = torch.sqrt(var).to(x_t.dtype)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps)
        return xt_prev_mean, std

    def forward(self, x_T):
        """Run the full reverse process starting from ``x_T``.

        Args:
            x_T: starting batch (presumably standard Gaussian noise; the code
                does not check this).

        Returns:
            The final sample, clipped to ``[-1, 1]``.
        """
        x_t = x_T
        batch_size = x_T.shape[0]
        for time_step in reversed(range(self.T)):
            # No noise is added on the last step (t == 0).
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = torch.zeros_like(x_t)
            t = torch.full(
                (batch_size,), time_step, dtype=torch.long, device=x_t.device
            )
            mean, std = self.p_mean_variance(x_t, t)
            x_t = mean + noise * std
            assert not torch.isnan(x_t).any(), "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)
         