# Diffusion models

why is the mean of the distribution multiplied by sqrt(1-beta) at each noise step?

so that the values don't become arbitrarily large but the mean stays close to zero

What is the meaning of beta?

The variance of the noise added at each step, i.e. how much noise is added

In [64]:
import os
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import reduce
from operator import mul
from typing import Any, Optional, Union
import matplotlib.pyplot as plt
import plotly.express as px
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import wandb

MAIN = __name__ == "__main__"

keyfile = "keystore.yaml"
if not wandb_key and os.path.exists(keyfile):
    import yaml
    keys = yaml.safe_load(open(keyfile,"r"))
    wandb_key = keys["wandb"]
os.environ["WANDB_API_KEY"] = wandb_key

NameError: name 'wandb_key' is not defined

In [65]:
def gradient_images(n_images: int, img_size: tuple[int, int, int]) -> t.Tensor:
    '''Generate n_images of img_size, each a color gradient
    '''
    (C, H, W) = img_size
    corners = t.randint(0, 255, (2, n_images, C))
    xs = t.linspace(0, W / (W + H), W)
    ys = t.linspace(0, H / (W + H), H)
    (x, y) = t.meshgrid(xs, ys, indexing="xy")
    grid = x + y
    grid = grid / grid[-1, -1]
    grid = repeat(grid, "h w -> b c h w", b=n_images, c=C)
    base = repeat(corners[0], "n c -> n c h w", h=H, w=W)
    ranges = repeat(corners[1] - corners[0], "n c -> n c h w", h=H, w=W)
    gradients = base + grid * ranges
    assert gradients.shape == (n_images, C, H, W)
    return gradients / 255

def plot_img(img: t.Tensor, title: Optional[str] = None) -> None:
    '''Plots a single image, with optional title.
    '''
    img = rearrange(img, "c h w -> h w c").clip(0, 1)
    img = (255 * img).to(t.uint8)
    fig = px.imshow(img, title=title)
    fig.update_layout(margin=dict(t=70 if title else 40, l=40, r=40, b=40))
    fig.show()

def plot_img_grid(imgs: t.Tensor, title: Optional[str] = None, cols: Optional[int] = None) -> None:
    '''Plots a grid of images, with optional title. Splits according to cols.
    '''
    b = imgs.shape[0]
    imgs = rearrange(imgs, "b c h w -> b h w c")
    imgs = (255 * imgs).to(t.uint8)
    if cols is None:
        cols = int(b**0.5) + 1
    fig = px.imshow(imgs, facet_col=0, facet_col_wrap=cols, title=title)
    for annotation in fig.layout.annotations:
        annotation["text"] = ""
    fig.show()

def plot_img_slideshow(imgs: t.Tensor, title: Optional[str] = None) -> None:
    '''Plots slideshow of images (useful for visualising denoising).
    '''
    imgs = rearrange(imgs, "b c h w -> b h w c")
    imgs = (255 * imgs).to(t.uint8)
    fig = px.imshow(imgs, animation_frame=0, title=title)
    fig.show()

if MAIN:
    print("A few samples from the input distribution: ")
    img_shape = (3, 16, 16)
    n_images = 5
    imgs = gradient_images(n_images, img_shape)
    for i in range(n_images):
        plot_img(imgs[i])

A few samples from the input distribution: 


In [66]:
def normalize_img(img: t.Tensor) -> t.Tensor:
    return img * 2 - 1

def denormalize_img(img: t.Tensor) -> t.Tensor:
    return ((img + 1) / 2).clamp(0, 1)

if MAIN:
    plot_img(imgs[0], "Original")
    plot_img(normalize_img(imgs[0]), "Normalized")
    plot_img(denormalize_img(normalize_img(imgs[0])), "Denormalized")

In [67]:
def linear_schedule(max_steps: int, min_noise: float = 0.0001, max_noise: float = 0.02) -> t.Tensor:
    '''
    Return the forward process variances as in the paper.

    max_steps: total number of steps of noise addition
    out: shape (step=max_steps, ) the amount of noise at each step
    '''
    return t.linspace(min_noise, max_noise, max_steps)

if MAIN:
    betas = linear_schedule(max_steps=200)

In [68]:
def q_forward_slow(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
    '''Return the input image with num_steps iterations of noise added according to schedule.
    x: shape (channels, height, width)
    betas: shape (T, ) with T >= num_steps

    out: shape (channels, height, width)
    '''
    for i in range(num_steps):
        # apply sample x_t ~ N(sqrt(1-beta_t) x_{t-1}, beta_t I))
        x = t.sqrt(1-betas[i]) * x + t.randn_like(x) * t.sqrt(betas[i])
    return x

if MAIN:
    x = normalize_img(gradient_images(1, (3, 16, 16))[0])
    for n in [1, 10, 50, 200]:
        xt = q_forward_slow(x, n, betas)
        plot_img(denormalize_img(xt), f"Equation 2 after {n} step(s)")
    plot_img(denormalize_img(t.randn_like(xt)), "Random Gaussian noise")

In [69]:


def q_forward_fast(x: t.Tensor, num_steps: int, betas: t.Tensor) -> t.Tensor:
    '''Equivalent to Equation 2 but without a for loop.'''
    # sample x_t from N(sqrt(alpha_dash_t x_0, (1-alpha_dash_t) I))
    alpha_t = 1-betas[:num_steps]
    alpha_dash_t = alpha_t.prod(dim=0)
    x = t.sqrt(alpha_dash_t) * x + t.randn_like(x) * t.sqrt(1-alpha_dash_t)
    return x

if MAIN:
    for n in [1, 10, 50, 200]:
        xt = q_forward_fast(x, n, betas)
        plot_img(denormalize_img(xt), f"Equation 4 after {n} steps")


Compare That the slow and the fast function look similar

## Noise Scheduler

In [70]:
class NoiseSchedule(nn.Module):
    betas: t.Tensor
    alphas: t.Tensor
    alpha_bars: t.Tensor

    def __init__(self, max_steps: int, device: Union[t.device, str]) -> None:
        super().__init__()
        self.max_steps = max_steps
        self.device = device
        # register buffers for betas, alphas, alpha_bars
        betas = t.zeros(max_steps)
        betas = linear_schedule(max_steps)
        self.register_buffer("betas", betas)
        alphas = 1 - betas
        self.register_buffer("alphas", alphas)
        alpha_bars = alphas.cumprod(dim=0)
        self.register_buffer("alpha_bars", alpha_bars)
        # move to device
        self.to(device)


    @t.inference_mode()
    def beta(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the beta(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        return self.betas[num_steps]

    @t.inference_mode()
    def alpha(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the alphas(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        # self.alphas = 1-self.betas
        return self.alphas[num_steps]

    @t.inference_mode()
    def alpha_bar(self, num_steps: Union[int, t.Tensor]) -> t.Tensor:
        '''
        Returns the alpha_bar(s) corresponding to a given number of noise steps
        num_steps: int or int tensor of shape (batch_size,)
        Returns a tensor of shape (batch_size,), where batch_size is one if num_steps is an int
        '''
        # self.alpha_bars = self.alphas.prod(dim=0)
        return self.alpha_bars[num_steps]

    def __len__(self) -> int:
        return self.max_steps

In [71]:
def noise_img(
    img: t.Tensor, noise_schedule: NoiseSchedule, max_steps: Optional[int] = None
) -> tuple[t.Tensor, t.Tensor, t.Tensor]:
    '''
    Adds a uniform random number of steps of noise to each image in img.

    img: An image tensor of shape (B, C, H, W)
    noise_schedule: The NoiseSchedule to follow
    max_steps: if provided, only perform the first max_steps of the schedule

    Returns a tuple composed of:
    num_steps: an int tensor of shape (B,) of the number of steps of noise added to each image
    noise: the unscaled, standard Gaussian noise added to each image, a tensor of shape (B, C, H, W)
    noised: the final noised image, a tensor of shape (B, C, H, W)
    '''
    num_images = img.shape[0]
    num_steps = t.randint(0, max_steps, (num_images,), device=img.device)
    # get alpha_bars
    alpha_bar = noise_schedule.alpha_bar(num_steps)
    # use the fast noising equaiton on the images
    # t.sqrt(alpha_dash_t) * x + t.randn_like(x) * t.sqrt(1-alpha_dash_t)
    noise = t.randn_like(img)
    noised_img = t.sqrt(alpha_bar) * img + noise * t.sqrt(1-alpha_bar)
    return num_steps, noise, noised_img

if MAIN:
    noise_schedule = NoiseSchedule(max_steps=200, device="cpu")
    img = gradient_images(1, (3, 16, 16))
    (num_steps, noise, noised) = noise_img(normalize_img(img), noise_schedule, max_steps=10)
    plot_img(img[0], "Gradient")
    plot_img(noise[0], "Applied Unscaled Noise")
    plot_img(denormalize_img(noised[0]), "Gradient with Noise Applied")

In [72]:
def reconstruct(noisy_img: t.Tensor, noise: t.Tensor, num_steps: t.Tensor, noise_schedule: NoiseSchedule) -> t.Tensor:
    '''
    Subtract the scaled noise from noisy_img to recover the original image. We'll later use this with the model's output to log reconstructions during training. We'll use a different method to sample images once the model is trained.

    Returns img, a tensor with shape (B, C, H, W)
    '''
    # x_0 = (x_t- sqrt(1-alpha_bar)*noise)/sqrt(alpha_bar)
    alpha_bar = noise_schedule.alpha_bar(num_steps)
    img = noisy_img/t.sqrt(alpha_bar) - noise * t.sqrt((1-alpha_bar)/alpha_bar)
    return img

if MAIN:
    reconstructed = reconstruct(noised, noise, num_steps, noise_schedule)
    denorm = denormalize_img(reconstructed)
    plot_img(img[0], "Original Gradient")
    plot_img(denorm[0], "Reconstruction")
    plot_img(denormalize_img(noised[0]), "Noised image")
    t.testing.assert_close(denorm, img)

In [111]:
from typing import Tuple

@dataclass
class DiffusionArgs():
    lr: float = 0.001
    image_shape: Tuple = (3, 4, 5)
    epochs: int = 10
    max_steps: int = 100
    batch_size: int = 128
    seconds_between_image_logs: int = 10
    n_images_per_log: int = 3
    n_images: int = 50000
    n_eval_images: int = 1000
    cuda: bool = True
    track: bool = True

class DiffusionModel(nn.Module, ABC):
    image_shape: Tuple[int, ...]
    noise_schedule: Optional[NoiseSchedule]

    @abstractmethod
    def forward(self, images: t.Tensor, num_steps: t.Tensor) -> t.Tensor:
        ...

@dataclass(frozen=True)
class TinyDiffuserConfig:
    image_shape: Tuple[int, ...] = (3, 4, 5)
    hidden_size: int = 128
    max_steps: int = 100

class TinyDiffuser(DiffusionModel):
    def __init__(self, config: TinyDiffuserConfig):
        '''
        A toy diffusion model composed of an MLP (Linear, ReLU, Linear)
        '''
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.image_shape = config.image_shape
        self.noise_schedule = None
        self.max_steps = config.max_steps
        self.device = "cpu"

        input_size = t.prod(t.tensor(self.image_shape)) + 1 # plus 1 for the timestep
        output_size = t.prod(t.tensor(self.image_shape))
        self.mlp = nn.Sequential(
            nn.Linear(input_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, output_size),
        )

    def forward(self, images: t.Tensor, num_steps: t.Tensor) -> t.Tensor:
        '''
        Given a batch of images and noise steps applied, attempt to predict the noise that was applied.
        images: tensor of shape (B, C, H, W)
        num_steps: tensor of shape (B,)

        Returns
        noise_pred: tensor of shape (B, C, H, W)
        '''
        # flatten the images
        images_flat = images.view(images.shape[0], -1)
        # add the timestep
        images_flat = t.cat([images_flat, num_steps.view(-1, 1)], dim=1)
        # predict the noise
        noise_pred_flat = self.mlp(images_flat)
        # reshape the noise
        noise_pred = noise_pred_flat.view(images.shape)
        return noise_pred

if MAIN:
    image_shape = (3, 4, 5)
    n_images = 5
    imgs = gradient_images(n_images, image_shape)
    n_steps = t.zeros(imgs.size(0))
    model_config = TinyDiffuserConfig(image_shape, hidden_size=16)
    model = TinyDiffuser(model_config)
    out = model(imgs, n_steps)
    plot_img(out[0].detach(), "Noise prediction of untrained model")

AttributeError: 'TinyDiffuserConfig' object has no attribute 'device'

## Training Loop

In [109]:
def log_images(
    img: t.Tensor, noised: t.Tensor, noise: t.Tensor, noise_pred: t.Tensor, reconstructed: t.Tensor, num_images: int = 3
) -> list[wandb.Image]:
    '''
    Convert tensors to a format suitable for logging to Weights and Biases. Returns an image with the ground truth in the upper row, and model reconstruction on the bottom row. Left is the noised image, middle is noise, and reconstructed image is in the rightmost column.
    '''
    actual = t.cat((noised, noise, img), dim=-1)
    pred = t.cat((noised, noise_pred, reconstructed), dim=-1)
    log_img = t.cat((actual, pred), dim=-2)
    images = [wandb.Image(i) for i in log_img[:num_images]]
    return images


def train(
    model: DiffusionModel, config_dict: dict[str, Any], trainset: TensorDataset, testset: Optional[TensorDataset] = None
) -> DiffusionModel:
    wandb.init(project="diffusion_models", config=config_dict, mode="disabled")
    config = wandb.config
    print(f"Training with config: {config}")
    model.train()
    optimizer = t.optim.Adam(model.parameters(), lr=config.lr)
    noise_schedule = NoiseSchedule(config.max_steps, device=config.device)
    model.noise_schedule = noise_schedule
    for epoch in range(config.epochs):
        for i, (batch) in enumerate(trainset):
            optimizer.zero_grad()

            batch = t.stack(batch, dim=0)
            
            num_steps, noise, noised_img = noise_img(batch, noise_schedule, config.max_steps)
            # get the predicted noise and then calculate loss
            noise_pred = model(noised_img, num_steps)
            # print(noise_pred.shape, noise.shape, noised_img.shape)
            # noise - noise_pred * (sqrt(alpha_bar))
            loss = t.mean((noise - noise_pred * noised_img) ** 2)
            # calculate gradients
            loss.backward()
            # update weights
            optimizer.step()
            
            wandb.log({"loss": loss.item()})

    
    if testset is not None:
        model.eval()
        with t.no_grad():
            for i, (batch) in enumerate(testset):
                batch = t.stack(batch, dim=0)
                num_steps, noise, noised_img = noise_img(batch, noise_schedule, config.max_steps)
                noise_pred = model(noised_img, num_steps)
                reconstructed = noised_img - noise_pred
                log_images(batch, noised_img, noise, noise_pred, reconstructed, num_images=config.n_images_to_log)
                wandb.log({"images": log_images(batch, noised_img, noise, noise_pred, reconstructed, num_images=config.n_images_to_log)})
    return model


if MAIN:
    config: dict[str, Any] = dict(
        lr=0.001,
        image_shape=(3, 4, 5),
        hidden_size=128,
        epochs=2,
        max_steps=100,
        batch_size=128,
        img_log_interval=200,
        n_images_to_log=3,
        n_images=50000,
        n_eval_images=1000,
        device=t.device("cuda") if t.cuda.is_available() else t.device("cpu"),
    )
    images = normalize_img(gradient_images(config["n_images"], config["image_shape"]))
    train_set = TensorDataset(images)
    images = normalize_img(gradient_images(config["n_eval_images"], config["image_shape"]))
    test_set = TensorDataset(images)
    model_config = TinyDiffuserConfig(config["image_shape"], config["hidden_size"], config["max_steps"])
    model = TinyDiffuser(model_config).to(config["device"])
    model = train(model, config, train_set, test_set)

Training with config: {'lr': 0.001, 'image_shape': [3, 4, 5], 'hidden_size': 128, 'epochs': 2, 'max_steps': 100, 'batch_size': 128, 'img_log_interval': 200, 'n_images_to_log': 3, 'n_images': 50000, 'n_eval_images': 1000, 'device': 'cpu'}


## Sampling from the diffusion model

In [130]:
def sample(model: DiffusionModel, n_samples: int, return_all_steps: bool = False) -> Union[t.Tensor, list[t.Tensor]]:
    '''
    Sample, following Algorithm 2 in the DDPM paper

    model: The trained noise-predictor
    n_samples: The number of samples to generate
    return_all_steps: if true, return a list of the reconstructed tensors generated at each step, rather than just the final reconstructed image tensor.

    out: shape (B, C, H, W), the denoised images
            or (T, B, C, H, W), if return_all_steps=True (where ith element is batched result of (i+1) steps of sampling)
    '''
    schedule = model.noise_schedule
    assert schedule is not None
    num_steps = schedule.max_steps
    # sample noise
    image = t.randn((n_samples, *model.image_shape), device=model.device)
    if return_all_steps:
        image_list = [image]
    # for each timestep from the schedule, get the noise from the model and reconstruct the image
    # reverse the for loop so that the first step is the last step in the schedule
    for i in reversed(range(num_steps)):
        # sample z if i is larger 0
        z = t.randn_like(image) if i > 0 else t.zeros_like(image)
        beta = schedule.beta(i)
        alpha = schedule.alpha(i)
        alpha_bar = schedule.alpha_bar(i)
        i = t.tensor([i]*n_samples, device=model.device)
        # print(i.shape, i, image.shape)
        noise = model(image, i) # epsilon_theta


        image = 1/t.sqrt(alpha) * (image - (1-alpha)/t.sqrt(1-alpha_bar)*noise) + t.sqrt(beta)*z
        if return_all_steps:
            image_list.append(image)
    if return_all_steps: # not working atm
        return image_list
    else:
        return image


if MAIN:
    print("Generating multiple images")
    #assert isinstance(model, DiffusionModel)
    with t.inference_mode():
        samples = sample(model, 5)
    for s in samples:
        plot_img(denormalize_img(s).cpu())
if MAIN:
    print("Printing sequential denoising")
    #assert isinstance(model, DiffusionModel)
    with t.inference_mode():
        samples = sample(model, 1, return_all_steps=True)
    for (i, s) in enumerate(samples):
        if i % (len(samples) // 20) == 0:
            plot_img(denormalize_img(s[0]), f"Step {i}")

Generating multiple images


Printing sequential denoising


In [115]:
isinstance(model, DiffusionModel)

False

In [116]:
type(model)

__main__.TinyDiffuser