# Denoising Diffusion Implicit Models

In DDPM, we construct the forward (noising) process as a Markov chain in which we add an amount of noise determined by our beta schedule at every timestep. 

Since each step in the denoising process is tied to the previous state, we are subject to a large number of denoising steps to go from xT to x0. 

DDIM instead introduces a non-Markovian forward process that leads to the same surrogate objective function as DDPM, which lets sampling skip timesteps and use far fewer denoising steps.

Because the objective is the same, we can take a noise predictor pretrained for DDPM and plug it into the DDIM denoising (sampling) process.

In [2]:
import torch
from torch import nn

In [3]:
from src.ddpm import get_alpha_bar

## NOTE: alpha_bar from DDPM is referred to as alpha in DDIM

def denoise(
    x: torch.Tensor,
    model_pred: torch.Tensor,
    alpha_t: torch.Tensor,
    alpha_t_minus_1: torch.Tensor,
    sigma_t: torch.Tensor,
    t: int,
    z: torch.Tensor,
) -> torch.Tensor:
    """
    Apply one DDIM reverse (denoising) update to ``x``.

    Computes

        sqrt(a_prev) * (x - sqrt(1 - a_t) * eps) / sqrt(a_t)
        + sqrt(1 - a_prev - sigma_t^2) * eps
        + sigma_t * z            (only when t > 0)

    where ``eps`` is ``model_pred``, ``a_t`` is ``alpha_t`` and ``a_prev`` is
    ``alpha_t_minus_1`` (both in the DDIM convention, i.e. DDPM's alpha_bar).

    Args:
        x: the current noisy sample
        model_pred: the model's noise prediction for ``x``
        alpha_t: cumulative alpha at the current timestep
        alpha_t_minus_1: cumulative alpha at the timestep being stepped to
        sigma_t: standard deviation of the noise injected at this step
        t: the current timestep; no noise is added when ``t`` is 0
        z: the noise to inject, scaled by ``sigma_t``

    Returns:
        The sample at the previous (less noisy) timestep.
    """
    # Rescaled estimate of the clean image, taken to the previous noise level.
    term1 = (alpha_t_minus_1.sqrt() / alpha_t.sqrt()) * (x - (1 - alpha_t).sqrt() * model_pred)

    # Direction pointing back towards x_t. The radicand is negative whenever
    # sigma_t^2 > 1 - alpha_t_minus_1 (or by rounding when the two are equal);
    # clamp it so sqrt() cannot turn the whole sample into NaN.
    direction_scale = (1 - alpha_t_minus_1 - sigma_t.square()).clamp(min=0).sqrt()
    term2 = direction_scale * model_pred

    deterministic_term = term1 + term2

    # The last step (t == 0) returns the deterministic estimate only.
    if t > 0:
        return deterministic_term + sigma_t * z
    return deterministic_term


def sample(
    model: nn.Module,
    beta_schedule: torch.Tensor,
    step_size: int,
    T: int,
    device: torch.device,
    image_shape: tuple[int, int, int],
    num_samples: int = 16,
    using_diffusers: bool = False,
    eta: float = 1.0,
) -> torch.Tensor:
    """
    Sample images from the model with DDIM, visiting every ``step_size``-th timestep.

    Args:
        model: the noise-prediction model to sample from; it is switched to
            eval mode and called as ``model(x, timesteps)``
        beta_schedule: the beta values for each timestep
        step_size: the stride between visited timesteps (1 visits all of them)
        T: the number of diffusion steps
        device: the device to sample on
        image_shape: the shape of a single image (without the batch dimension)
        num_samples: the number of samples to generate
        using_diffusers: if True, the prediction is read from the ``.sample``
            attribute of the model output
        eta: scales the noise injected at each step; 0 gives deterministic
            DDIM sampling, 1 the fully stochastic variant

    Returns:
        A tensor of shape ``(num_samples, *image_shape)`` with the generated samples.

    Raises:
        ValueError: if ``eta`` is outside ``[0, 1]``.
    """
    if not 0.0 <= eta <= 1.0:
        raise ValueError(f"eta must be in [0, 1], got {eta}")

    x = torch.randn(num_samples, *image_shape, device=device)
    # NOTE(review): presumably the cumulative product of (1 - beta); confirm
    # against src.ddpm.get_alpha_bar.
    alpha_bar = get_alpha_bar(beta_schedule)
    # Timesteps are visited from high to low; the last one visited is t = 0.
    tau = reversed(range(0, T, step_size))

    model.eval()

    with torch.no_grad():
        for t in tau:
            time_input = torch.full((num_samples,), t, device=device)

            alpha_t = alpha_bar[t]
            t_prev = t - step_size
            if t_prev >= 0:
                alpha_t_minus_1 = alpha_bar[t_prev]
            else:
                # Stepping past the start of the chain targets the clean image,
                # whose cumulative alpha is 1. Indexing alpha_bar with a
                # negative t_prev would silently wrap to the noisy end.
                alpha_t_minus_1 = torch.ones_like(alpha_t)

            # DDIM sigma: depends on both visited timesteps, so it stays valid
            # for any step_size and keeps 1 - alpha_t_minus_1 - sigma_t^2 >= 0.
            # It is exactly 0 on the final step, where alpha_t_minus_1 == 1.
            variance = ((1 - alpha_t_minus_1) / (1 - alpha_t)) * (1 - alpha_t / alpha_t_minus_1)
            sigma_t = eta * variance.clamp(min=0).sqrt()

            z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)

            if using_diffusers:
                model_pred = model(x, time_input).sample
            else:
                model_pred = model(x, time_input)

            x = denoise(x, model_pred, alpha_t, alpha_t_minus_1, sigma_t, t, z)

    return x