In [2]:
import torch
from torch import nn
from typing import Tuple, Optional
from torch.nn import functional as F

In [13]:
def gather(consts: torch.Tensor, t: torch.Tensor, ndim: int = 4) -> torch.Tensor:
    """Look up per-timestep constants for each sample and reshape them for broadcasting.

    :param consts: 1-D tensor of per-timestep constants (e.g. ``alpha_bar``),
        indexed by timestep.
    :param t: ``torch.long`` tensor of timestep indices, one per sample.
    :param ndim: rank of the data tensor the result will be broadcast against.
        Defaults to 4, the layout used for image batches in this notebook.
    :return: tensor of shape ``(batch, 1, ..., 1)`` with ``ndim`` dimensions,
        whose entry ``i`` is ``consts[t[i]]``.
    :raises ValueError: if ``ndim < 1``.
    """
    if ndim < 1:
        raise ValueError(f"ndim must be >= 1, got {ndim}")
    c = consts.gather(-1, t)
    # Keep the batch axis and add singleton axes so the result broadcasts
    # against a (batch, ...) data tensor of rank `ndim`.
    return c.reshape(-1, *([1] * (ndim - 1)))

In [None]:
class DenoiseDiffusion:
    """DDPM (Ho et al., 2020) around a noise-prediction network.

    Holds a linear beta schedule and implements:
    - the forward (noising) process q(x_t | x_0),
    - one reverse sampling step p_theta(x_{t-1} | x_t),
    - the simplified training loss ||eps - eps_theta(x_t, t)||^2.
    """

    def __init__(self, eps_model: nn.Module, n_steps: int, device: torch.device):
        """
        :param eps_model: noise-prediction network (e.g. a U-Net); called as ``eps_model(xt, t)``.
        :param n_steps: total number of diffusion steps T.
        :param device: device on which the schedule tensors are stored.
        """
        self.eps_model = eps_model
        # Linear variance schedule beta_t from 1e-4 to 0.02.
        self.beta = torch.linspace(0.0001, 0.02, n_steps).to(device)
        self.alpha = 1. - self.beta
        # alpha_bar_t = prod_{s <= t} alpha_s (cumulative product over timesteps).
        self.alpha_bar = torch.cumprod(self.alpha, dim=0)
        self.n_steps = n_steps
        # Reverse-process variance sigma_t^2 is taken to be beta_t.
        self.sigma2 = self.beta

    def q_xt_x0(self, x0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parameters of the forward distribution q(x_t | x_0).

        :param x0: clean samples from the training data.
        :param t: ``torch.long`` timestep indices, one per sample.
        :return: ``(mean, var)`` where mean = sqrt(alpha_bar_t) * x0 and
            var = 1 - alpha_bar_t.
        """
        mean = gather(self.alpha_bar, t) ** 0.5 * x0
        var = 1 - gather(self.alpha_bar, t)
        return mean, var

    def q_sample(self, x0: torch.Tensor, t: torch.Tensor, eps: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sample x_t ~ q(x_t | x_0) via the reparameterization mean + sqrt(var) * eps.

        :param x0: clean samples.
        :param t: ``torch.long`` timestep indices, one per sample.
        :param eps: optional standard-normal noise shaped like ``x0``; drawn if omitted.
        :return: noised samples x_t, shaped like ``x0``.
        """
        if eps is None:
            eps = torch.randn_like(x0)

        mean, var = self.q_xt_x0(x0, t)
        return mean + (var ** 0.5) * eps

    def p_sample(self, xt: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Draw x_{t-1} ~ p_theta(x_{t-1} | x_t) (one reverse step, DDPM Algorithm 2).

        :param xt: batch of noisy samples at step ``t``.
        :param t: ``torch.long`` timestep indices, one per sample in ``xt``.
        :return: samples x_{t-1}. Samples whose ``t`` is 0 get the posterior mean
            with no added noise, since that is the final denoising step.
        """
        eps_theta = self.eps_model(xt, t)
        alpha_bar = gather(self.alpha_bar, t)
        alpha = gather(self.alpha, t)
        # mu_theta = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alpha_bar_t) * eps_theta)
        eps_coef = (1 - alpha) / (1 - alpha_bar) ** .5
        mean = 1 / (alpha ** 0.5) * (xt - eps_coef * eps_theta)
        var = gather(self.sigma2, t)

        eps = torch.randn(xt.shape, device=xt.device)
        # Algorithm 2 uses z = 0 on the last step; mask the noise per sample where t == 0.
        nonzero = (t > 0).to(xt.dtype).reshape(-1, *([1] * (xt.dim() - 1)))
        return mean + nonzero * (var ** .5) * eps

    def loss(self, x0: torch.Tensor, noise: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Simplified DDPM training loss: MSE between the true and the predicted noise.

        :param x0: batch of clean training samples.
        :param noise: optional standard-normal noise shaped like ``x0``; drawn if omitted.
        :return: scalar mean-squared-error loss.
        """
        batch_size = x0.shape[0]
        # Pick an independent uniformly random timestep for every sample in the batch.
        t = torch.randint(0, self.n_steps, (batch_size,), device=x0.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x0)

        xt = self.q_sample(x0, t, eps=noise)
        eps_theta = self.eps_model(xt, t)

        # Conventional (input, target) order; MSE is symmetric so the value is unchanged.
        return F.mse_loss(eps_theta, noise)

In [3]:
# One random timestep per sample: 16 integers drawn uniformly from [0, 1000).
t = torch.randint(low=0, high=1000, size=(16,), dtype=torch.long, device="cpu")

In [6]:
# Linear noise schedule over 1000 steps, same values as the DenoiseDiffusion schedule.
beta = torch.linspace(start=0.0001, end=0.02, steps=1000)
alpha = 1.0 - beta
alpha_bar = alpha.cumprod(dim=0)  # running product of alpha over timesteps

In [11]:
# Stand-in batch: 16 samples of shape 3x256x256 drawn from a standard normal.
x0 = torch.randn(16, 3, 256, 256)

In [15]:
# Forward-process mean sqrt(alpha_bar_t) * x0; the (16, 1, 1, 1) coefficients
# returned by gather broadcast over every element of each sample.
mean = x0 * gather(alpha_bar, t) ** 0.5
mean.shape

torch.Size([16, 3, 256, 256])