<a href="https://colab.research.google.com/github/Siarzis/custom-ai/blob/main/custom_diffusion_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Import Libraries

In [2]:
import torch

Define the denoising diffusion probabilistic model (DDPM) class

In [None]:
# Inherit from PyTorch's `nn.Module` (referenced as `torch.nn.Module` because only
# `torch` is imported) so the model can be moved between devices and called like
# any other module.
class custom_DDPM(torch.nn.Module):
    """Denoising diffusion probabilistic model (DDPM) wrapper.

    Holds a linear beta noise schedule (as in Ho et al.) together with the
    derived ``alphas`` and cumulative ``alpha_bars``. It applies the closed-form
    forward (noising) process in :meth:`forward` and delegates the reverse
    (denoising) step to the wrapped ``network`` in :meth:`backward`.

    Args:
        network: Neural network used for the reverse process; it is called as
            ``network(x, t)``.
        min_beta: Noise variance of the first diffusion step.
        max_beta: Noise variance of the last diffusion step.
        n_steps: Number of diffusion steps (length of the schedule).
        device: Device for the network and schedule tensors; ``None`` uses
            PyTorch's default device.
        image_shape: Shape of a single image; stored for callers, not used here.
    """

    def __init__(
        self,
        network,
        min_beta=10 ** -4,
        max_beta=0.02,
        n_steps=200,
        device=None,
        image_shape=(1, 28, 28),
    ):
        # nn.Module.__init__ must run before any sub-module or buffer is
        # assigned; otherwise `self.network = ...` raises AttributeError.
        super().__init__()
        self.n_steps = n_steps
        self.device = device
        self.image_shape = image_shape
        self.network = network.to(device)

        # Linearly spaced betas [min_beta, ..., max_beta]; Ho et al. use
        # 0.0001 to 0.02. beta_t is the variance of the noise added at step t.
        betas = torch.linspace(min_beta, max_beta, n_steps, device=device)
        # alpha_t = 1 - beta_t: the fraction of signal variance kept at step t.
        alphas = 1 - betas
        # alpha_bar_t = prod_{s<=t} alpha_s, computed in one O(n) pass on the
        # same device as the other schedule tensors.
        alpha_bars = torch.cumprod(alphas, dim=0)

        # Non-persistent buffers move with the module on `.to()`/`.cuda()` but
        # are not written to the state dict (they are derived from the args).
        self.register_buffer("betas", betas, persistent=False)
        self.register_buffer("alphas", alphas, persistent=False)
        self.register_buffer("alpha_bars", alpha_bars, persistent=False)

    def forward(self, x0, t, eta=None):
        """Sample x_t from q(x_t | x_0) in closed form.

        Computes ``sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eta``.

        Args:
            x0: Clean input batch whose first dimension is the batch, e.g.
                ``(batch, channels, height, width)``.
            t: Time-step index for each sample (shape ``(batch,)``) or a
                single index shared by the whole batch.
            eta: Noise with the same shape as ``x0``. If ``None``, standard
                normal noise is drawn on ``x0``'s device.

        Returns:
            The noisy batch ``x_t``, with the same shape as ``x0``.
        """
        a_bar = self.alpha_bars[t]

        if eta is None:
            eta = torch.randn(x0.shape, device=x0.device)

        # Reshape to (batch or 1, 1, ..., 1) so a_bar broadcasts over every
        # non-batch dimension of x0, whatever its rank.
        a_bar = a_bar.reshape(-1, *([1] * (x0.dim() - 1)))
        return a_bar.sqrt() * x0 + (1 - a_bar).sqrt() * eta

    def backward(self, x, t):
        """Apply the reverse-process network to ``x`` at time step ``t``.

        Returns whatever ``network(x, t)`` returns. Note that this method is
        unrelated to autograd's ``Tensor.backward``.
        """
        return self.network(x, t)