# Diffusion Models
## Train image generators in PyTorch

In [None]:
_ = !pip install torch
_ = !pip install torchvision
_ = !pip install matplotlib
_ = !pip install pillow

In [None]:
import os
import random
from typing import Tuple
import numpy as np
import matplotlib.pyplot as plt
from PIL import ImageColor, Image, ImageDraw
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import Adam
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder

In [None]:
# Number of synthetic training images to generate.
N_TRAIN = 300
# Side length (in pixels) of the square training images.
TRAIN_IMAGE_SIZE = 64
# Root folder handed to the data set loader.
DATA_DIR = "./data"
# The images themselves are stored in a "train" sub-folder of DATA_DIR.
TRAIN_DATA_DIR = os.path.join(DATA_DIR, "train")
# Create the folder (and any missing parents); do nothing if it already exists.
os.makedirs(TRAIN_DATA_DIR, exist_ok=True)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## Training Data
The functions below create images of colored circles with varying radii and positions and store them in a designated folder. These images will be used as training data for our image generator.

In [None]:
def draw_random_circle_image() -> Image.Image:
    """
    Draw a square image of a randomly colored circle on a white background.

    The image is TRAIN_IMAGE_SIZE x TRAIN_IMAGE_SIZE pixels. The circle has a
    random radius between 10 and 25 pixels and its center is shifted by up to
    10 pixels horizontally and vertically from the image center.

    Returns:
        The drawn RGB image.
    """
    image_size = (TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE)
    image = Image.new("RGB", image_size, "white")
    draw = ImageDraw.Draw(image)

    radius = random.randint(10, 25)
    x_shift = random.randint(-10, 10)
    y_shift = random.randint(-10, 10)
    center_x = x_shift + image_size[0] // 2
    center_y = y_shift + image_size[1] // 2

    # Exclude the color name "white" because it is the background color.
    color_names = [name for name in ImageColor.colormap if name != "white"]
    color = random.choice(color_names)

    bounding_box = (
        center_x - radius,
        center_y - radius,
        center_x + radius,
        center_y + radius,
    )
    draw.ellipse(bounding_box, fill=color)
    return image


def create_image_file_name(index: int) -> str:
    """
    Build the PNG file name for the training image with the given index.

    The index is left-padded with zeros until it is at least as wide as the
    decimal representation of N_TRAIN, then ".png" is appended.
    """
    width = len(str(N_TRAIN))
    return f"{index}".rjust(width, "0") + ".png"
    

# Generate N_TRAIN random circle images and save each one as a PNG file
# in TRAIN_DATA_DIR.
for i in range(N_TRAIN):
    name = create_image_file_name(i)
    save_path = os.path.join(TRAIN_DATA_DIR, name)
    image = draw_random_circle_image()
    image.save(save_path)

In [None]:
def plot_random_train_images():
    """
    Show a 2 x 5 grid of randomly chosen training images.

    Each grid cell picks a random index in [0, N_TRAIN - 1] (the same image
    may appear more than once) and displays the corresponding file from
    TRAIN_DATA_DIR.
    """
    n_rows = 2
    n_cols = 5

    fig, axes = plt.subplots(nrows=n_rows, ncols=n_cols, figsize=(15, 5))

    for ax in axes.flatten():
        index = random.randint(0, N_TRAIN - 1)
        image_path = os.path.join(TRAIN_DATA_DIR, create_image_file_name(index))
        # Close the file handle as soon as the pixels are handed to matplotlib.
        with Image.open(image_path) as image:
            ax.imshow(image)

# Display a random sample of the generated training images.
plot_random_train_images()

## Dataset for Training
The images can be organized into a data set, which makes them more convenient to work with during training. In addition, a data set offers a way to define a sequence of transformations to be applied to the images.

In our case, we need to scale the pixel values to the range -1 to 1. Right now, each color channel (RGB) of a pixel has a value between 0 and 255. The torchvision `ToTensor()` transform automatically rescales these values to the range 0 to 1. We therefore apply one more function, `t * 2 - 1`, to shift the values into the desired range of -1 to 1.

In [None]:
# Histogram of the raw pixel values of `image`, before any scaling.
# NOTE(review): `image` is a module-level variable; if the cells ran in order
# it is the last image produced by the generation loop - confirm.
pix = np.array(image)
plt.figure(figsize=(3, 3))
_ = plt.hist(pix.flatten())

In [None]:
def load_and_transform_dataset(data_dir: str = DATA_DIR) -> ImageFolder:
    """
    Load the images below `data_dir` as an ImageFolder data set.

    Every image is converted to a tensor with ToTensor() and then rescaled
    with t * 2 - 1.
    """
    to_tensor = transforms.ToTensor()
    rescale = transforms.Lambda(lambda t: (t * 2) - 1)
    pipeline = transforms.Compose([to_tensor, rescale])
    return ImageFolder(data_dir, transform=pipeline)

# Build the data set from the default data directory and display its summary.
train_dataset = load_and_transform_dataset()
train_dataset

In [None]:
# Inspect the first sample of the data set: its type, its shape and the
# distribution of its values after the transformations.
x_0, label_0 = train_dataset[0]
x_0 = x_0.to(device)
print(type(x_0))
print(x_0.shape)
# Move the tensor back to the CPU before converting it to NumPy for plotting.
_ = plt.hist(x_0.cpu().numpy().flatten())

In [None]:
def tensor_to_image(xt: torch.Tensor) -> Image.Image:
    """
    Revert the transformations on the tensor and return the corresponding Image.

    Args:
        xt: Image tensor of shape (C, H, W), or (1, C, H, W) for a batch of
            one, with values nominally in [-1, 1]. Values outside that range
            are clipped to the nearest valid pixel intensity.

    Returns:
        The PIL image built from the 8-bit pixel values.
    """
    xt = xt.detach().cpu()

    # Drop only the leading batch axis; a bare squeeze() would also remove a
    # channel, height or width axis that happens to have size 1.
    if xt.dim() == 4 and xt.shape[0] == 1:
        xt = xt.squeeze(0)

    # Undo the (t * 2) - 1 scaling and go back to the 0-255 intensity range.
    pixels = (xt + 1) / 2 * 255.
    # Noisy tensors can leave [-1, 1]. Clamp before the uint8 cast so such
    # values saturate instead of wrapping around, and round so that the
    # round trip image -> tensor -> image does not lose one intensity level
    # to truncation.
    pixels = pixels.clamp(0., 255.).round()
    # ToPILImage expects a NumPy array in (H, W, C) layout.
    pixels = pixels.permute(1, 2, 0).numpy().astype(np.uint8)
    return transforms.ToPILImage()(pixels)

# Convert the first training tensor back to an image and display it.
image = tensor_to_image(x_0)
plt.imshow(image)

## Apply Noise to Images
In this forward process, we add noise to an image over the course of many steps. The amount of noise added increases with each step (the so-called noise schedule). At the end of the process, given enough steps, the image approaches pure Gaussian (normally distributed) noise.

![noise equation](./noise_equation.png "Noise Equation")

In [None]:
# Number of steps in the noise process.
n_steps = 200
# Linear noise schedule: beta grows evenly from 0.0001 to 0.01 over n_steps.
betas = torch.linspace(start=0.0001, end=0.01, steps=n_steps, device=device)
alphas = 1. - betas
# Cumulative product of the alphas up to and including each step.
alphas_bar = torch.cumprod(alphas, axis=0)
# Per-step factors applied to the image and to the noise in apply_noise().
sqrt_alphas_bar = torch.sqrt(alphas_bar)
sqrt_one_minus_alphas_bar = torch.sqrt(1. - alphas_bar)


def apply_noise(xt: torch.Tensor, t: torch.Tensor, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply noise to image tensors at the given time steps t according to the noise schedule beta.

    Computes sqrt(alphas_bar[t]) * xt + sqrt(1 - alphas_bar[t]) * noise, where
    noise is drawn from a standard normal distribution.

    Args:
        xt: Image tensors to add noise to. Callers in this file pass a
            (batch, C, H, W) tensor whose batch size equals t.shape[0].
        t: 1-D integer tensor with one time step index per image.
        device: Device on which the noise is drawn and the result is computed.

    Returns:
        A tuple (noisy_image, noise), both located on `device`.
    """
    xt = xt.to(device)
    noise = torch.randn_like(xt)

    # Look the coefficients up on the device that holds the schedule, so the
    # lookup also works when `t` lives on a different device.
    steps = t.to(sqrt_alphas_bar.device)
    # One coefficient per image, broadcast over the remaining axes.
    coefficient_shape = (t.shape[0], 1, 1, 1)
    sqrt_alphas_bar_t = sqrt_alphas_bar.gather(-1, steps).reshape(coefficient_shape).to(device)
    sqrt_one_minus_alphas_bar_t = sqrt_one_minus_alphas_bar.gather(-1, steps).reshape(coefficient_shape).to(device)

    noisy_image = sqrt_alphas_bar_t * xt + sqrt_one_minus_alphas_bar_t * noise

    return noisy_image, noise

In [None]:
# Noise 5 copies of the same image at 5 evenly spaced time steps between
# 0 and n_steps - 1.
x_0_batch = x_0.unsqueeze(0).repeat(5, 1, 1, 1)
t = torch.linspace(start=0, end=n_steps-1, steps=5, device=device).long()
x_noise_t, noise_t = apply_noise(xt=x_0_batch, t=t, device=device)
fig, axes = plt.subplots(nrows=1, ncols=6, figsize=(18, 3))
# Put the clean image first, followed by the 5 noisy versions.
x_all = torch.cat((x_0.unsqueeze(0), x_noise_t), dim=0)
for i, ax in enumerate(axes):
    image = tensor_to_image(x_all[i])
    ax.imshow(image)

## Representation and Model

### Positional Encoding
The images are well represented by the tensors we have transformed them into. 
For the time "positions" representing the step index in the noise sequence, we need an encoding which transforms sequential integers into fixed-dimensional vector representations. 
For details on that process, please see my video and notebook on that topic.

In [None]:
class PositionalEncoding(nn.Module):
    """
    Sinusoidal Position Encodings

    Column j of the encoding of position p is sin(p / n^(j / dim_out)) for
    even j and cos(p / n^(j / dim_out)) for odd j, with n = 10000.
    """

    def __init__(self, dim_out: int):
        super().__init__()
        self.dim_out = dim_out
        self.n = 10000

    def forward(self, pos: torch.Tensor):
        """
        Encode every position in `pos` as a vector of length dim_out.

        Args:
            pos: Tensor of positions; it is flattened, so any shape is accepted.

        Returns:
            Tensor of shape (pos.numel(), dim_out) on the same device as `pos`.
        """
        pos = pos.reshape(-1, 1)
        # Build everything on the input's device instead of relying on a
        # module-level global.
        columns = torch.arange(self.dim_out, dtype=torch.float32, device=pos.device)
        angles = pos / torch.pow(self.n, columns / self.dim_out)
        out = torch.zeros(pos.shape[0], self.dim_out, device=pos.device)
        # Slicing the full angle matrix keeps both halves the right width for
        # odd dim_out as well.
        out[:, 0::2] = torch.sin(angles[:, 0::2])
        out[:, 1::2] = torch.cos(angles[:, 1::2])
        return out

## Unet Model
The Unet model is U-shaped in the sense that it "down-samples" a tensor via convolutions and afterwards "up-samples" it via transposed convolutions (de-convolutions). The output of a Unet therefore has the same dimensions as the input. This makes it a suitable model for inferring a tensor of noise pixels for an input image tensor.
We add the time encoding vector in the process in order to include the information at which step in the noise process we are in.

In [None]:
class AbstractBlock(nn.Module):
    """
    Shared body of the Unet blocks.

    Applies a first convolution, adds a projected time encoding to the
    feature maps, applies a second convolution and finally a resampling
    transform. Each convolution is followed by ReLU and batch normalization.
    `conv1` and `transform` are None here and are assigned by the subclasses.
    """

    def __init__(self, in_channels: int, out_channels: int, time_encoding_dim: int):
        super().__init__()

        # Projects the time encoding to one value per output channel.
        self.time_linear = nn.Linear(time_encoding_dim, out_channels)

        # Placeholders; the concrete subclasses assign the actual layers.
        self.conv1 = None
        self.transform = None

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x, t):
        features = self.conv1(x)
        features = self.relu(features)
        features = self.batch_norm1(features)

        time_bias = self.relu(self.time_linear(t))
        # Append two singleton axes so the bias broadcasts over height and width.
        time_bias = time_bias[..., None, None]
        features = features + time_bias

        features = self.conv2(features)
        features = self.relu(features)
        features = self.batch_norm2(features)
        return self.transform(features)


class UpBlock(AbstractBlock):
    """
    Up-sampling block: the first convolution takes 2 * in_channels input
    channels and the final transform is a stride-2 transposed convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_encoding_dim: int):
        super().__init__(in_channels, out_channels, time_encoding_dim)

        # Twice the nominal input channels: Unet.forward concatenates the
        # skip connection onto the input along the channel axis.
        self.conv1 = nn.Conv2d(
            2 * in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.transform = nn.ConvTranspose2d(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1
        )


class DownBlock(AbstractBlock):
    """
    Down-sampling block: the first convolution maps in_channels to
    out_channels and the final transform is a stride-2 convolution.
    """

    def __init__(self, in_channels: int, out_channels: int, time_encoding_dim: int):
        super().__init__(in_channels, out_channels, time_encoding_dim)

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        self.transform = nn.Conv2d(
            out_channels, out_channels, kernel_size=4, stride=2, padding=1
        )

In [None]:
class Unet(nn.Module):
    """
    Unet architecture

    (Time) position encodings (32 dim)
    + an initial 3x3 convolution from 3 to 64 channels
    + 4 convolutional down-sampling blocks (64 -> 128 -> 256 -> 512 -> 1024 channels)
    + 4 convolutional up-sampling blocks (1024 -> 512 -> 256 -> 128 -> 64 channels)
    + a 1x1 output convolution from 64 to 3 channels
    """
    def __init__(self):
        super().__init__()
        rgb_channels = 3
        encoder_widths = (64, 128, 256, 512, 1024)
        decoder_widths = (1024, 512, 256, 128, 64)
        output_channels = 3
        time_dim = 32

        # Time step -> sinusoidal encoding -> learned projection.
        self.pos_linear = nn.Sequential(
            PositionalEncoding(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.ReLU(),
        ).to(device)

        self.conv0 = nn.Conv2d(
            in_channels=rgb_channels,
            out_channels=encoder_widths[0],
            kernel_size=3,
            stride=1,
            padding=1,
            device=device,
        )

        # Downsample: one block per pair of consecutive channel widths.
        down_blocks = [
            DownBlock(width_in, width_out, time_dim)
            for width_in, width_out in zip(encoder_widths, encoder_widths[1:])
        ]
        self.downs = nn.ModuleList(down_blocks).to(device)

        # Upsample: one block per pair of consecutive channel widths.
        up_blocks = [
            UpBlock(width_in, width_out, time_dim)
            for width_in, width_out in zip(decoder_widths, decoder_widths[1:])
        ]
        self.ups = nn.ModuleList(up_blocks).to(device)

        self.output = nn.Conv2d(decoder_widths[-1], output_channels, 1).to(device)

    def forward(self, x, timestep):
        time_encoding = self.pos_linear(timestep)
        hidden = self.conv0(x)

        # Keep the output of every down block as a skip connection.
        skips = []
        for down_block in self.downs:
            hidden = down_block(hidden, time_encoding)
            skips.append(hidden)

        # Consume the skip connections in reverse order, concatenating each
        # one onto the current features along the channel axis.
        for up_block in self.ups:
            skip = skips.pop()
            hidden = torch.cat((hidden, skip), dim=1)
            hidden = up_block(hidden, time_encoding)

        return self.output(hidden)

In [None]:
# Smoke test: pass a single image and a single time step through a freshly
# constructed (untrained) Unet and display the shape of the output.
x_0 = x_0.to(device)
x_0_un = x_0.unsqueeze(0)
pos = torch.tensor([5], device=device).unsqueeze(0)
unet = Unet()
unet(x_0_un, pos).shape

## Generate an Image
Given a model which can estimate the noise added at a given timestep, we can generate an image that is representative of the training set. We start from an image of pure noise, predict the noise it contains and subtract it. We repeat this process n_steps times and end up with a generated, de-noised image.

In [None]:
# Per-step factor 1 / sqrt(alpha_t) used when removing the predicted noise.
sqrt_one_over_alphas = torch.sqrt(1.0 / alphas)
# alphas_bar shifted by one step: drop the last entry and prepend 1.0.
alphas_bar_prev = F.pad(alphas_bar[:-1], (1, 0), value=1.0)
# Variance of the noise that is added back at each sampling step.
posterior_variance = betas * (1. - alphas_bar_prev) / (1. - alphas_bar)

@torch.no_grad()
def denoised_image_at_timestemp(model: nn.Module, x: torch.Tensor, t: int):
    """
    Perform one de-noising step at time step t.

    The model predicts the noise contained in x; that prediction is scaled
    and subtracted from x. For every step except t == 0, fresh noise scaled
    by the square root of the posterior variance at t is added back.

    Args:
        model: Noise prediction model called as model(x, time_step_tensor).
        x: Current noisy image tensor.
        t: Index of the current time step.

    Returns:
        The image tensor after this de-noising step.
    """
    step = torch.Tensor([t]).to(device)
    predicted_noise = model(x, step)

    scaled_noise = betas[t] * predicted_noise / sqrt_one_minus_alphas_bar[t]
    result = sqrt_one_over_alphas[t] * (x - scaled_noise)

    # The last step (t == 0) returns the estimate without adding noise.
    if t != 0:
        fresh_noise = torch.randn_like(x)
        result = result + torch.sqrt(posterior_variance[t]) * fresh_noise
    return result

In [None]:
@torch.no_grad()
def sequential_denoising(model: nn.Module):
    """
    Generate an image by de-noising random noise step by step and plot the progress.

    Starts from a random normal tensor of shape
    (1, 3, TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE), applies the de-noising step
    for t = n_steps - 1 down to 0, clamps the result to [-1, 1] after every
    step and shows the intermediate image in a new figure whenever t is a
    multiple of the display interval.

    Args:
        model: Noise prediction model passed on to the de-noising step.
    """
    img = torch.randn((1, 3, TRAIN_IMAGE_SIZE, TRAIN_IMAGE_SIZE), device=device)

    n_images = 6
    # Keep the interval at 1 or more so that very short schedules
    # (n_steps < n_images) do not cause a modulo-by-zero below.
    show_image_at = max(1, n_steps // n_images)

    for t in reversed(range(n_steps)):
        img = denoised_image_at_timestemp(model, img, t)
        img = torch.clamp(img, -1.0, 1.0)

        if t % show_image_at == 0:
            image = tensor_to_image(img)
            plt.figure()
            plt.imshow(image)

In [None]:
# Run the sampler with a freshly constructed, untrained model.
model = Unet().to(device)
sequential_denoising(model)

### Train the model
We apply a standard deep learning flow to train the model. We use the L1 loss and the Adam optimizer to tune the model weights. For details on the training flow, please see my video and notes on that topic.

In [None]:
# Training hyper-parameters.
epochs = 1000
learning_rate = 0.001
batch_size = 64
# Print the loss every `score_every` epochs.
score_every = 10

model = Unet().to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    for step, batch in enumerate(train_dataloader):

        # Each batch is (images, labels); only the images are used.
        x = batch[0]
        
        optimizer.zero_grad()
        
        # Draw one random time step in [0, n_steps) for every image in the batch.
        t = torch.randint(0, n_steps, (x.shape[0],), device=device).long()
        # Noise the images at those steps and let the model predict the noise.
        x_noise_t, noise_t = apply_noise(xt=x, t=t, device=device)
        noise_pred = model(x_noise_t, t)
        # L1 distance between the noise that was added and the predicted noise.
        loss = F.l1_loss(noise_t, noise_pred)
        
        loss.backward()
        optimizer.step()

        
        # Report the loss of the first batch of every `score_every`-th epoch.
        if epoch % score_every == 0 and step == 0:
            print(f"epoch {epoch} --- loss: {loss.item()} ")
            

# Generate an image with the trained model.
sequential_denoising(model)