In [None]:
import torch

class LinearNoiseScheduler:
    """Noise scheduler for a diffusion process with linearly spaced betas.

    The per-timestep quantities (betas, alphas, cumulative alpha products and
    their square roots) are computed once at construction time; the forward
    (`add_noise`) and reverse (`sample_prev_timestep`) steps are then simple
    table lookups followed by element-wise arithmetic.

    Args:
        num_steps: Number of diffusion timesteps (length of every table).
        beta_start: Value of beta at the first timestep.
        beta_end: Value of beta at the last timestep.
    """

    def __init__(self, num_steps, beta_start, beta_end):
        self.num_steps = num_steps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.step = 0

        # Pre-compute the schedule tables (each has shape (num_steps,)).
        self.betas = torch.linspace(beta_start, beta_end, num_steps)
        self.alphas = 1 - self.betas
        # \bar{\alpha}_t
        self.alpha_cum_prod = torch.cumprod(self.alphas, 0)
        # \sqrt{\bar{\alpha}_t}
        self.sqrt_alpha_cum_prod = torch.sqrt(self.alpha_cum_prod)
        # \sqrt{1 - \bar{\alpha}_t}
        self.sqrt_one_minus_alpha_cum_prod = torch.sqrt(1 - self.alpha_cum_prod)

    @staticmethod
    def _lookup(table, t, like):
        """Return ``table[t]`` on ``like``'s device, shaped to broadcast over ``like``.

        ``t`` may be a Python int, a 0-dim tensor (one timestep shared by the
        whole batch) or a 1-D tensor with one timestep per batch element. The
        result has shape ``(1, 1, ..., 1)`` or ``(B, 1, ..., 1)`` with the same
        number of dimensions as ``like``.
        """
        index = torch.as_tensor(t, dtype=torch.long, device=like.device)
        coeff = table.to(like.device)[index]
        # Leading axis is the batch axis; every remaining axis is a singleton
        # so that the coefficient broadcasts over the data dimensions.
        return coeff.reshape(-1, *([1] * (like.dim() - 1)))

    # forward process
    def add_noise(self, original, noise, t):
        """Diffuse ``original`` to timestep ``t`` in a single step.

        Computes ``sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps``.

        Args:
            original: Clean input ``x_0``; the first axis is the batch axis.
            noise: Noise ``eps`` with the same shape as ``original``.
            t: Timestep index. Either a scalar (int or 0-dim tensor) applied
                to the whole batch, or a 1-D tensor of per-sample timesteps
                whose length equals the batch size.

        Returns:
            The noised tensor ``x_t`` with the same shape as ``original``.
        """
        sqrt_alpha_bar = self._lookup(self.sqrt_alpha_cum_prod, t, original)
        sqrt_one_minus_alpha_bar = self._lookup(
            self.sqrt_one_minus_alpha_cum_prod, t, original
        )

        # \sqrt{\bar{\alpha}_t} * x_0 + \sqrt{1 - \bar{\alpha}_t} * \epsilon
        return sqrt_alpha_bar * original + sqrt_one_minus_alpha_bar * noise

    def sample_prev_timestep(self, xt, noise_pred, t):
        """Take one reverse step from ``x_t`` towards ``x_{t-1}``.

        Args:
            xt: Current noisy sample ``x_t``.
            noise_pred: Predicted noise for ``xt`` (same shape as ``xt``).
            t: Current timestep as a single value (int, 0-dim tensor or
                one-element tensor).

        Returns:
            A tuple ``(x_prev, x0)``. ``x_prev`` is the posterior mean when
            ``t == 0`` and the mean plus scaled Gaussian noise otherwise.
            ``x0`` is the estimate of the clean sample, clamped to [-1, 1]
            (presumably the data is scaled to that range - confirm with the
            caller).
        """
        # A plain int index keeps the lookups device-independent and yields
        # 0-dim coefficients that broadcast over any ``xt`` shape.
        step = int(t)
        device = xt.device
        beta_t = self.betas[step].to(device)
        alpha_t = self.alphas[step].to(device)
        alpha_bar_t = self.alpha_cum_prod[step].to(device)
        sqrt_alpha_bar_t = self.sqrt_alpha_cum_prod[step].to(device)
        sqrt_one_minus_alpha_bar_t = self.sqrt_one_minus_alpha_cum_prod[step].to(device)

        # x0 = (xt - \sqrt{1 - \bar{\alpha}_t} * \epsilon) / \sqrt{\bar{\alpha}_t}
        x0 = (xt - sqrt_one_minus_alpha_bar_t * noise_pred) / sqrt_alpha_bar_t
        x0 = torch.clamp(x0, -1, 1)

        # mean = (xt - \beta_t * \epsilon / \sqrt{1 - \bar{\alpha}_t}) / \sqrt{\alpha_t}
        mean = xt - (beta_t * noise_pred) / sqrt_one_minus_alpha_bar_t
        mean = mean / torch.sqrt(alpha_t)

        # The final step is deterministic: no noise is added at t == 0.
        if step == 0:
            return mean, x0

        # variance = \beta_t * (1 - \bar{\alpha}_{t-1}) / (1 - \bar{\alpha}_t)
        alpha_bar_prev = self.alpha_cum_prod[step - 1].to(device)
        variance = beta_t * (1 - alpha_bar_prev) / (1 - alpha_bar_t)
        sigma = torch.sqrt(variance)
        # Sample standard Gaussian noise directly on xt's device.
        z = torch.randn_like(xt)
        return mean + sigma * z, x0