In [None]:
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from helper import accumulate

In [None]:
class DenoisingProcess(nn.Module):
    """Forward noising, reverse denoising and training loss of a diffusion model.

    The noise schedule is fixed at construction time: ``T`` betas linearly
    spaced between 1e-4 and 0.02, with ``alphas = 1 - betas`` and
    ``alphas_bar`` the cumulative product of ``alphas``.

    NOTE(review): ``accumulate`` comes from the project-local ``helper``
    module. Every call here uses it as ``accumulate(schedule, t)`` and treats
    the result as a per-sample coefficient that broadcasts against the image
    tensor -- confirm against the helper's implementation.
    """

    def __init__(
        self,
        alphas: torch.Tensor,
        betas: torch.Tensor,
        epsilon: nn.Module,
        T: int,
        sigma_sqaured: torch.Tensor,
        device: torch.device,
    ):
        """Build the noise schedule and store the noise-prediction model.

        Args:
            alphas: Accepted for interface compatibility; not used. The
                schedule is always recomputed from the linear betas below.
            betas: Accepted for interface compatibility; not used.
            epsilon: Noise-prediction model, invoked as ``epsilon(x_t, t)``.
            T: Number of diffusion steps (length of the schedule).
            sigma_sqaured: Variance schedule read by ``denoising_sample``.
            device: Device on which the schedule tensors are created.
        """
        super().__init__()
        # The device must be stored before anything reads it: nn.Module raises
        # AttributeError for attributes that have not been assigned yet.
        self.device = device
        self.T = T
        self.epsilon = epsilon
        self.sigma_sqaured = sigma_sqaured
        self.betas = torch.linspace(0.0001, 0.02, T, device=device)
        self.alphas = 1 - self.betas
        self.alphas_bar = torch.cumprod(self.alphas, dim=0)

    def reparameterization(self, x_0: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the mean and variance used to noise ``x_0`` up to step ``t``.

        Args:
            x_0: Original image tensor.
            t: Timesteps to noise to (``loss`` passes one long index per
                batch element).

        Returns:
            ``(mean, variance)`` where ``mean = sqrt(alpha_bar_t) * x_0`` and
            ``variance = 1 - alpha_bar_t``.
        """
        # Select the coefficient for step t first, then scale the image; the
        # whole schedule must not be multiplied into x_0 directly.
        alpha_bar = accumulate(self.alphas_bar, t)
        mean = alpha_bar ** 0.5 * x_0
        variance = 1 - alpha_bar
        return mean, variance

    def sampling(self, x_0: torch.Tensor, t: torch.Tensor, epsilon: Optional[torch.Tensor] = None):
        """Draw a noised sample ``x_t`` from ``x_0``.

        Args:
            x_0: Original image tensor.
            t: Timesteps to noise to.
            epsilon: Standard-normal noise shaped like ``x_0``; drawn with
                ``torch.randn_like`` when omitted.

        Returns:
            ``mean + sqrt(variance) * epsilon`` with ``mean`` and ``variance``
            taken from ``reparameterization``.
        """
        if epsilon is None:
            epsilon = torch.randn_like(x_0)
        mean, variance = self.reparameterization(x_0, t)
        return mean + (variance ** 0.5) * epsilon

    def denoising_sample(self, x_t: torch.Tensor, t: torch.Tensor):
        """Take one reverse step from ``x_t`` using the noise-prediction model.

        Args:
            x_t: Noised image tensor at step ``t``.
            t: Current timesteps; passed both to the model and to
                ``accumulate``.

        Returns:
            ``mean + sqrt(sigma_sqaured_t) * z`` with fresh standard-normal
            ``z`` and
            ``mean = (x_t - (1 - alpha_t) / sqrt(1 - alpha_bar_t) * eps_theta) / sqrt(alpha_t)``.
        """
        epsilon_theta = self.epsilon(x_t, t)
        alpha_bar = accumulate(self.alphas_bar, t)
        alpha_t = accumulate(self.alphas, t)
        # Coefficient applied to the predicted noise.
        noise_coef = (1 - alpha_t) / (1 - alpha_bar) ** 0.5
        mean = 1 / (alpha_t ** 0.5) * (x_t - noise_coef * epsilon_theta)
        variance = accumulate(self.sigma_sqaured, t)
        epsilon = torch.randn_like(x_t, device=self.device)
        return mean + variance ** 0.5 * epsilon

    def loss(self, x_0: torch.Tensor, noise: Optional[torch.Tensor] = None):
        """Return the MSE between the predicted noise and the noise that was added.

        Args:
            x_0: Batch of original images; ``x_0.shape[0]`` is the batch size.
            noise: Noise to add; drawn with ``torch.randn_like`` when omitted.

        Returns:
            Scalar ``F.mse_loss(epsilon(x_t, t), noise)`` for one uniformly
            drawn timestep per batch element.
        """
        batch_size = x_0.shape[0]
        # One random timestep in [0, T) per sample.
        t = torch.randint(0, self.T, (batch_size,), device=self.device, dtype=torch.long)
        if noise is None:
            noise = torch.randn_like(x_0)
        x_t = self.sampling(x_0, t, noise)
        epsilon_theta = self.epsilon(x_t, t)
        return F.mse_loss(epsilon_theta, noise)