# Diffusion from scratch

A diffusion model gradually corrupts data with noise and learns to reverse that corruption. Learning to denoise amounts to learning the structure of the data distribution. Starting from fresh random noise, the model can then generate new samples that follow the distribution it learned.

## Forward Process

The forward process is a Markov chain where each step adds noise to the data.

For a data point $x_0$ drawn from the data distribution $p$, each step $t = 1, \dots, T$ adds noise as follows:

$x_t = \sqrt{1 - \beta_t} \cdot x_{t-1} + \sqrt{\beta_t} \cdot \epsilon$

where:
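- $\beta_t \in (0, 1)$ is the noise variance at step $t$, taken from a predefined noise schedule $\beta_1, \dots, \beta_T$.
- $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is standard Gaussian noise, drawn fresh at every step.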
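To see what the single-step update does over many steps, the sketch below applies it $T$ times in sequence. It assumes a linear schedule with the same endpoints as the code cell further down ($\beta$ from `1e-4` to `0.02`, $T = 1000$), uses NumPy instead of PyTorch, and `forward_chain` is a hypothetical helper name. Starting from a fixed value, the result should be close to pure standard Gaussian noise.

```python
import numpy as np

# Assumed linear schedule (same endpoints as the Diffusion class below)
T = 1000
betas = np.linspace(1e-4, 0.02, T)

def forward_chain(x0, betas, rng):
    """Hypothetical helper: apply x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps step by step."""
    x = np.array(x0, dtype=float)
    for beta in betas:
        eps = rng.standard_normal(x.shape)  # fresh Gaussian noise at every step
        x = np.sqrt(1.0 - beta) * x + np.sqrt(beta) * eps
    return x

rng = np.random.default_rng(0)
x0 = np.full(10_000, 2.0)          # 10,000 copies of the same scalar data point
xT = forward_chain(x0, betas, rng)
print(xT.mean(), xT.std())         # mean near 0, std near 1: the data is essentially gone
```

Each step shrinks the signal by $\sqrt{1 - \beta_t}$ while adding just enough noise to keep the total variance from growing, so after $T$ steps almost nothing of $x_0$ remains.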

### Simplified Form:

$x_t = \alpha_t \cdot x_0 + \sigma_t \cdot \epsilon$

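where $\alpha_t$ and $\sigma_t$ are computed directly from the noise schedule by composing the $t$ single-step updates:

$\alpha_t = \sqrt{\prod_{s=1}^{t} (1 - \beta_s)}, \qquad \sigma_t = \sqrt{1 - \alpha_t^2}$

In the code below, $\alpha_t$ and $\sigma_t$ correspond to `sqrt(alpha_hat[t])` and `sqrt(1 - alpha_hat[t])`. Note that the code's `alpha` is the per-step factor $1 - \beta_t$, not $\alpha_t$ as used here.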
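The sketch below computes the closed-form coefficients as the square roots of the cumulative product of $1 - \beta_s$ and of one minus that product. It assumes the same linear schedule as the code cell ($\beta$ from `1e-4` to `0.02`, $T = 1000$) and uses 0-indexed steps like the code. It is written in NumPy rather than PyTorch for brevity.

```python
import numpy as np

# Assumed linear schedule, 0-indexed like the code cell
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)  # product of (1 - beta_s) for s <= t
alpha_t = np.sqrt(alpha_bar)         # signal coefficient
sigma_t = np.sqrt(1.0 - alpha_bar)   # noise coefficient

# Signal and noise variances always sum to 1 (the process is variance-preserving)
assert np.allclose(alpha_t**2 + sigma_t**2, 1.0)

for t in (0, 249, 499, 999):
    print(f"step {t + 1:4d}: alpha_t = {alpha_t[t]:.4f}, sigma_t = {sigma_t[t]:.4f}")
```

With this schedule, $\alpha_t$ starts just below 1 and ends well below 0.01. By the last step $x_t$ is therefore almost entirely noise.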

---

## Reparametrization

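We use the reparametrization trick to rewrite the forward process with respect to $x_0$. Each step scales the previous sample and adds independent Gaussian noise, and a sum of independent Gaussians is itself Gaussian. The $t$ steps therefore collapse into a single scaling of $x_0$ plus a single noise term, so $x_t$ can be computed directly from $x_0$ rather than from $x_{t-1}$.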

This is useful during training. For a randomly chosen timestep $t$, we can produce $x_t$ from $x_0$ in one step and train the neural network to predict the added noise, **without sequentially** computing the intermediate steps $x_1, \dots, x_{t-1}$. This cuts the cost of noising each training example from $t$ steps to one. Generating new samples with the reverse process is still sequential.
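A minimal sketch of one training-loss evaluation under these assumptions follows. It uses the standard "predict the noise" mean-squared-error objective of DDPM-style models, although the original notebook does not show its training loop. `predict_noise(x_t, t)` is a hypothetical stand-in for the neural network, and the schedule is the same linear one as the code cell. The key point is that each example jumps straight to its own random $t$ without a loop over earlier steps.

```python
import numpy as np

# Assumed linear schedule, 0-indexed like the code cell
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

def diffusion_loss(x0, predict_noise, rng):
    """Hypothetical helper: noise a batch in one shot and score the noise prediction."""
    n = x0.shape[0]
    t = rng.integers(1, T, size=n)  # one random timestep per example, range [1, T - 1]
    # reshape (n,) -> (n, 1, 1, 1) so the coefficients broadcast over image dimensions
    a = np.sqrt(alpha_bar[t]).reshape(-1, *([1] * (x0.ndim - 1)))
    s = np.sqrt(1.0 - alpha_bar[t]).reshape(a.shape)
    eps = rng.standard_normal(x0.shape)
    x_t = a * x0 + s * eps            # closed form: no loop over steps 1..t
    eps_hat = predict_noise(x_t, t)
    return np.mean((eps_hat - eps) ** 2)

def zero_model(x_t, t):
    """Baseline 'network' that always predicts zero noise."""
    return np.zeros_like(x_t)

rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(256, 3, 8, 8))  # toy "images" scaled to [-1, 1]
print(diffusion_loss(x0, zero_model, rng))    # about 1.0, the variance of eps
```

A trained network should push this loss well below the zero-prediction baseline of roughly 1.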

### Reparametrized Equation:

$x_t = \alpha_t \cdot x_0 + \sigma_t \cdot \epsilon$

where $\epsilon \sim \mathcal{N}(0, \mathbf{I})$ is the noise that the model learns to predict from $x_t$ and $t$.
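Because the reparametrized equation is invertible, a noise prediction also yields an estimate of the clean data: $\hat{x}_0 = (x_t - \sigma_t \hat{\epsilon}) / \alpha_t$. The sketch below checks this with the true noise, under the same assumed linear schedule. `estimate_x0` is a hypothetical helper name.

```python
import numpy as np

# Assumed linear schedule, 0-indexed like the code cell
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alpha_bar = np.cumprod(1.0 - betas)

def estimate_x0(x_t, eps_hat, t):
    """Hypothetical helper: invert x_t = alpha_t * x0 + sigma_t * eps for x0."""
    alpha_t = np.sqrt(alpha_bar[t])
    sigma_t = np.sqrt(1.0 - alpha_bar[t])
    return (x_t - sigma_t * eps_hat) / alpha_t

rng = np.random.default_rng(1)
x0 = rng.uniform(-1, 1, size=(4, 4))
t = 500
eps = rng.standard_normal(x0.shape)
x_t = np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

# With the true noise the inversion is exact (up to floating point)
print(np.allclose(estimate_x0(x_t, eps, t), x0))  # True
```

With an imperfect prediction, any error in $\hat{\epsilon}$ is multiplied by $\sigma_t / \alpha_t$, which becomes large at high $t$. This is one reason sampling proceeds in many small reverse steps rather than jumping to $\hat{x}_0$ at once.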


```python
import os
import torch
import torch.nn as nn
from tqdm import tqdm
from torch import optim
from matplotlib import pyplot as plt
import logging
import numpy as np

logging.basicConfig(format="%(asctime)s - %(levelname)s: %(message)s", level=logging.INFO, datefmt="%I:%M:%S")

# default values follow the DDPM paper (Ho et al., 2020): total_timesteps=1000, beta_start=1e-4, beta_end=0.02
class Diffusion:
    def __init__(self, total_timesteps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device="cpu"):
        self.total_timesteps = total_timesteps
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.img_size = img_size
        self.device = device

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta                        # per-step factor 1 - β_t
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)  # α̅_t: product of alphas up to step t

    def prepare_noise_schedule(self):
        # linear schedule: a 1-D tensor of evenly spaced betas from beta_start to beta_end
        # (a cosine schedule is often preferred because it adds noise more gradually)
        return torch.linspace(self.beta_start, self.beta_end, self.total_timesteps)

    # closed form: x_t = √(α̅ₜ) * x₀ + √(1 - α̅ₜ) * ε
    def noise_images(self, x, t):
        # [:, None, None, None] reshapes (n,) to (n, 1, 1, 1) so it broadcasts over (n, C, H, W)
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1. - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        # one random (0-indexed) timestep per image, in [1, total_timesteps - 1]
        return torch.randint(low=1, high=self.total_timesteps, size=(n,))

    def sample(self, model, n, labels, cfg_scale=3):
        logging.info(f"Sampling {n} new images....")
        model.eval()
        with torch.no_grad():
            # start from pure Gaussian noise (x_T, the noisiest state)
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            # reverse diffusion: iterate backwards over timesteps T-1, ..., 1
            for i in tqdm(reversed(range(1, self.total_timesteps)), position=0):
                t = (torch.ones(n) * i).long().to(self.device)
                predicted_noise = model(x, t, labels)
                if cfg_scale > 0:
                    # classifier-free guidance: lerp gives uncond + cfg_scale * (cond - uncond)
                    uncond_predicted_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_predicted_noise, predicted_noise, cfg_scale)
                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # no fresh noise on the final step, so the output is the denoised mean
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                # x_{t-1} = 1/√α_t * (x_t - (1 - α_t)/√(1 - α̅_t) * ε_θ(x_t, t)) + √β_t * z
                x = 1 / torch.sqrt(alpha) * (x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(beta) * noise
        model.train()
        # map from [-1, 1] back to 8-bit pixel values in [0, 255]
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        return x
```
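The reverse update in `Diffusion.sample` can be checked without a trained network. The sketch below re-implements the same loop in NumPy for 1-D data and drives it with an oracle noise predictor. Assumption (not from the original): every training example equals a constant `c`, so the exact noise prediction is $(x_t - \sqrt{\bar{\alpha}_t}\, c) / \sqrt{1 - \bar{\alpha}_t}$. `oracle_noise` and `reverse_chain` are hypothetical helper names. If the update rule and the timestep indexing are consistent, the samples should collapse onto `c`.

```python
import numpy as np

# Same linear schedule and 0-indexed arrays as the Diffusion class
T = 1000
beta = np.linspace(1e-4, 0.02, T)
alpha = 1.0 - beta
alpha_hat = np.cumprod(alpha)

def oracle_noise(x, i, c):
    """Exact noise prediction when all data equals the constant c (assumption)."""
    return (x - np.sqrt(alpha_hat[i]) * c) / np.sqrt(1.0 - alpha_hat[i])

def reverse_chain(n, c, rng):
    x = rng.standard_normal(n)            # start from pure noise x_T
    for i in range(T - 1, 0, -1):         # same index range as Diffusion.sample
        eps_hat = oracle_noise(x, i, c)
        noise = rng.standard_normal(n) if i > 1 else np.zeros(n)
        # same update as Diffusion.sample, written for 1-D arrays
        x = (x - (1 - alpha[i]) / np.sqrt(1 - alpha_hat[i]) * eps_hat) / np.sqrt(alpha[i])
        x = x + np.sqrt(beta[i]) * noise
    return x

rng = np.random.default_rng(0)
samples = reverse_chain(5_000, c=0.5, rng=rng)
print(samples.mean(), samples.std())      # mean near 0.5, small spread
```

In the real class, the trained model's `predicted_noise` takes the place of `oracle_noise`.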

## Recommended Resources

- [Diffusion Models Explained](https://www.youtube.com/watch?v=fbJac4qQy04&ab_channel=ComputerVisionwithH%C3%BCseyin%C3%96zdemir)
- [Diffusion Models Implementation](https://www.youtube.com/watch?v=TBCRlnwJtZU&ab_channel=Outlier)