# 🔥 Diffusion Models

In this notebook, we'll walk through the steps required to train your own diffusion model on the Oxford Flowers dataset.

The code is adapted from the excellent ['Denoising Diffusion Implicit Models' tutorial](https://keras.io/examples/generative/ddim/) created by András Béres, available on the Keras website.

In [None]:
import copy
import math

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

## 0. Parameters <a name="parameters"></a>

In [None]:
# --- data ---
IMAGE_SIZE: int = 64  # side length images are resized to
BATCH_SIZE: int = 64
DATASET_REPETITIONS: int = 5  # NOTE(review): not used by the cells shown here
LOAD_MODEL: bool = False  # NOTE(review): not used by the cells shown here

# --- model / sampling ---
NOISE_EMBEDDING_SIZE: int = 32  # width of the sinusoidal noise embedding
PLOT_DIFFUSION_STEPS: int = 20  # reverse-diffusion steps intended for sample plots

# --- optimization ---
EMA: float = 0.999  # decay of the exponential moving average of the weights
LEARNING_RATE: float = 1e-3
WEIGHT_DECAY: float = 1e-4
EPOCHS: int = 50

## 1. Prepare the Data

In [None]:
# Load the flowers dataset with ImageFolder (replace the path with your own copy).
# Resize the shorter side to IMAGE_SIZE and then centre-crop to a square, so
# photos keep their aspect ratio instead of being squashed by a direct
# (IMAGE_SIZE, IMAGE_SIZE) resize.
transform = T.Compose([
    T.Resize(IMAGE_SIZE),
    T.CenterCrop(IMAGE_SIZE),
    T.ToTensor(),  # float tensor (C, H, W) with values in [0, 1]
])

train_dataset = datasets.ImageFolder(
    root="/app/data/pytorch-challange-flower-dataset/dataset",
    transform=transform,
)

# drop_last keeps every batch at exactly BATCH_SIZE images.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)


### 1.1 Diffusion schedules <a name="diffusion_schedules"></a>

In [None]:
def linear_diffusion_schedule(diffusion_times):
    """DDPM-style linear beta schedule.

    Betas rise linearly from 1e-4 (t=0) to 2e-2 (t=1). The cumulative signal
    variance ``alpha_bar`` is a running product along dim 0, so this assumes
    ``diffusion_times`` is the full, ordered 1-D grid of steps (it is not
    meaningful for per-sample random times).

    Returns:
        ``(noise_rates, signal_rates)`` = ``(sqrt(1 - alpha_bar), sqrt(alpha_bar))``.
    """
    beta_start, beta_end = 0.0001, 0.02
    betas = beta_start + diffusion_times * (beta_end - beta_start)
    alpha_bars = (1 - betas).cumprod(dim=0)
    return (1 - alpha_bars).sqrt(), alpha_bars.sqrt()

def cosine_diffusion_schedule(diffusion_times):
    """Cosine schedule: signal = cos(t*pi/2), noise = sin(t*pi/2).

    Returns ``(noise_rates, signal_rates)``; their squares sum to 1.
    """
    angles = diffusion_times * math.pi / 2
    noise_rates, signal_rates = torch.sin(angles), torch.cos(angles)
    return noise_rates, signal_rates

def offset_cosine_diffusion_schedule(diffusion_times):
    """Cosine schedule whose signal rate is confined to [0.02, 0.95].

    The angle moves linearly from ``acos(0.95)`` at t=0 to ``acos(0.02)`` at
    t=1, so the signal never fully vanishes nor is ever noise-free.

    Returns:
        ``(noise_rates, signal_rates)`` = ``(sin(angle), cos(angle))``,
        shaped like ``diffusion_times``.
    """
    min_signal_rate = 0.02
    max_signal_rate = 0.95
    start_angle = torch.acos(torch.tensor(max_signal_rate))
    end_angle = torch.acos(torch.tensor(min_signal_rate))
    angle_span = end_angle - start_angle
    angles = start_angle + diffusion_times * angle_span
    return torch.sin(angles), torch.cos(angles)

In [None]:
# Evaluate every schedule on the same evenly spaced grid of normalised times
# t/T in [0, 1], so the curves can be compared in the plots below.
T_steps: int = 1000
diffusion_times = torch.linspace(0.0, 1.0, steps=T_steps)

linear_noise_rates, linear_signal_rates = linear_diffusion_schedule(diffusion_times)
cosine_noise_rates, cosine_signal_rates = cosine_diffusion_schedule(diffusion_times)
(
    offset_cosine_noise_rates,
    offset_cosine_signal_rates,
) = offset_cosine_diffusion_schedule(diffusion_times)

In [None]:
# Signal variance alpha_bar_t (= signal_rate**2) of each schedule against t/T.
plt.plot(diffusion_times, linear_signal_rates**2, label="linear")
plt.plot(diffusion_times, cosine_signal_rates**2, label="cosine")
plt.plot(diffusion_times, offset_cosine_signal_rates**2, label="offset_cosine")
plt.xlabel("t/T")
# Bar over alpha only; the subscript t sits outside it (standard notation).
plt.ylabel(r"$\bar{\alpha}_t$ (signal)")
plt.legend()
plt.show()



In [None]:

# Noise variance 1 - alpha_bar_t (= noise_rate**2) of each schedule against t/T.
plt.plot(diffusion_times, linear_noise_rates**2, label="linear")
plt.plot(diffusion_times, cosine_noise_rates**2, label="cosine")
plt.plot(diffusion_times, offset_cosine_noise_rates**2, label="offset_cosine")
plt.xlabel("t/T")
# Bar over alpha only; the subscript t sits outside it (standard notation).
plt.ylabel(r"$1-\bar{\alpha}_t$ (noise)")
plt.legend()
plt.show()

## 2. Noise embedding (sinusoidal) <a name="embedding"></a>

In [None]:
# %%
def sinusoidal_embedding(x, dim=NOISE_EMBEDDING_SIZE):
    """Encode a scalar (e.g. a noise variance) as sin/cos features.

    Uses ``dim // 2`` frequencies spaced geometrically from 1 to 1000; each
    frequency f contributes ``sin(2*pi*f*x)`` and ``cos(2*pi*f*x)``.

    Args:
        x: tensor whose last axis has size 1 (e.g. (B, 1, 1, 1)), or a Python
            number (the frequencies are then built on the CPU).
        dim: embedding width; an odd value yields ``dim - 1`` features.

    Returns:
        For tensor input, shape ``x.shape[:-1] + (2 * (dim // 2),)``: the
        features lie on the LAST axis (channels-last), so NCHW consumers must
        permute them to the channel axis.
    """
    if isinstance(x, torch.Tensor):
        device = x.device
    else:
        device = torch.device("cpu")
    log_frequencies = torch.linspace(
        math.log(1.0), math.log(1000.0), dim // 2, device=device
    )
    angles = 2 * math.pi * torch.exp(log_frequencies) * x
    return torch.cat([angles.sin(), angles.cos()], dim=-1)

## 3. Build the model <a name="build"></a>

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: BatchNorm -> 3x3 conv -> SiLU -> 3x3 conv, plus a skip path.

    Args:
        channels: number of input channels.
        out_channels: number of output channels (defaults to ``channels``).
            When it differs from ``channels`` the skip path is a 1x1
            convolution so the residual addition has matching shapes;
            otherwise the input is added unchanged.
    """

    def __init__(self, channels, out_channels=None):
        super().__init__()
        if out_channels is None:
            out_channels = channels
        # affine=False: normalise only, no learned scale/shift.
        self.norm = nn.BatchNorm2d(channels, affine=False)
        self.conv1 = nn.Conv2d(channels, out_channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.activation = nn.SiLU()
        self.skip = (
            nn.Identity()
            if out_channels == channels
            else nn.Conv2d(channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        residual = self.skip(x)
        x = self.activation(self.conv1(self.norm(x)))
        x = self.conv2(x)
        return x + residual


def _residual_stack(in_ch, out_ch, depth):
    """Build ``depth`` ResidualBlocks: in_ch -> out_ch first, then out_ch -> out_ch."""
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")
    return nn.Sequential(
        *[ResidualBlock(in_ch if i == 0 else out_ch, out_ch) for i in range(depth)]
    )


class DownBlock(nn.Module):
    """``depth`` residual blocks (in_ch -> out_ch) followed by 2x average pooling.

    ``forward`` returns ``(pooled, skip)``; ``skip`` is the pre-pooling
    feature map (``out_ch`` channels) for the matching UpBlock.
    """

    def __init__(self, in_ch, out_ch, depth):
        super().__init__()
        self.layers = _residual_stack(in_ch, out_ch, depth)
        self.pool = nn.AvgPool2d(2)

    def forward(self, x):
        skip = self.layers(x)
        return self.pool(skip), skip


class UpBlock(nn.Module):
    """Bilinear upsampling, skip concatenation, then ``depth`` residual blocks.

    ``in_ch`` must equal (channels of ``x``) + (channels of ``skip``); the
    output has ``out_ch`` channels.
    """

    def __init__(self, in_ch, out_ch, depth):
        super().__init__()
        self.layers = _residual_stack(in_ch, out_ch, depth)

    def forward(self, x, skip):
        # Upsample to the skip's exact spatial size (2x when sizes are even),
        # so the concatenation also lines up after odd-sized pooling.
        x = nn.functional.interpolate(
            x, size=skip.shape[-2:], mode="bilinear", align_corners=False
        )
        x = torch.cat([x, skip], dim=1)
        return self.layers(x)


class UNet(nn.Module):
    """Noise-predicting U-Net conditioned on the noise variance.

    Levels use ``base_ch``, ``2*base_ch`` and ``3*base_ch`` channels, with one
    skip connection per level.

    Args:
        in_ch: image channels (input and predicted-noise output).
        base_ch: width of the first level.
        noise_embedding_size: ``dim`` passed to ``sinusoidal_embedding``;
            defaults to ``NOISE_EMBEDDING_SIZE``.
    """

    def __init__(self, in_ch=3, base_ch=32, noise_embedding_size=None):
        super().__init__()
        if noise_embedding_size is None:
            noise_embedding_size = NOISE_EMBEDDING_SIZE
        self.noise_embedding_size = noise_embedding_size
        # sinusoidal_embedding emits dim // 2 sin + dim // 2 cos features.
        emb_ch = 2 * (noise_embedding_size // 2)
        c1, c2, c3 = base_ch, base_ch * 2, base_ch * 3

        self.init_conv = nn.Conv2d(in_ch, c1, kernel_size=1)
        # The first block sees the image features concatenated with the embedding.
        self.down1 = DownBlock(c1 + emb_ch, c1, 2)
        self.down2 = DownBlock(c1, c2, 2)
        self.down3 = DownBlock(c2, c3, 2)
        self.mid1 = ResidualBlock(c3)
        self.mid2 = ResidualBlock(c3)
        # Up-path input = upsampled features + skip from the same level.
        self.up3 = UpBlock(c3 + c3, c3, 2)
        self.up2 = UpBlock(c3 + c2, c2, 2)
        self.up1 = UpBlock(c2 + c1, c1, 2)
        self.final_conv = nn.Conv2d(c1, in_ch, kernel_size=1)
        # Zero-init the output projection so training starts from a zero
        # noise prediction (as in the Keras DDIM example this is adapted from).
        nn.init.zeros_(self.final_conv.weight)
        nn.init.zeros_(self.final_conv.bias)

    def forward(self, x, noise_emb):
        """Predict the noise contained in ``x``.

        Args:
            x: noisy images, shape (B, in_ch, H, W) with H, W >= 8.
            noise_emb: per-image noise variance (``noise_rates**2``), any
                shape with B elements (e.g. (B, 1, 1, 1)) or a single element
                shared by the whole batch. The parameter name is kept for
                backward compatibility; the sinusoidal embedding is computed here.

        Returns:
            Predicted noise with the same shape as ``x``.
        """
        x = self.init_conv(x)
        batch, _, height, width = x.shape
        variances = noise_emb.reshape(-1, 1, 1, 1)
        # sinusoidal_embedding is channels-last: (N, 1, 1, E) -> (N, E, 1, 1).
        emb = sinusoidal_embedding(variances, self.noise_embedding_size)
        emb = emb.permute(0, 3, 1, 2).to(x.dtype)
        # Broadcast the 1x1 embedding over every pixel (nearest upsampling).
        emb = emb.expand(batch, -1, height, width)
        x = torch.cat([x, emb], dim=1)

        x, skip1 = self.down1(x)
        x, skip2 = self.down2(x)
        x, skip3 = self.down3(x)
        x = self.mid1(x)
        x = self.mid2(x)
        x = self.up3(x, skip3)
        x = self.up2(x, skip2)
        x = self.up1(x, skip1)
        return self.final_conv(x)

In [None]:
class DiffusionModel(nn.Module):
    """DDIM-style diffusion model around a noise-predicting network.

    Keeps the trainable network (``unet``) plus an exponential-moving-average
    copy (``ema_unet``). The EMA parameters are updated outside this class
    (by the training loop); sampling uses the EMA copy.

    Args:
        unet: network called as ``unet(noisy_images, noise_variances)`` that
            must return a noise prediction shaped like ``noisy_images``.
        device: device both networks are moved to.
    """

    def __init__(self, unet, device=DEVICE):
        super().__init__()
        self.device = device
        self.unet = unet.to(device)
        # Deep-copy rather than building a fresh ``UNet()``, so the EMA copy
        # always matches the architecture/configuration of ``unet``. It is
        # never optimised directly, so it needs no gradients.
        self.ema_unet = copy.deepcopy(self.unet)
        self.ema_unet.requires_grad_(False)
        self.diffusion_schedule = offset_cosine_diffusion_schedule

    def denoise(self, x, noise_rates, signal_rates, training=True):
        """Split noisy images into predicted noise and predicted clean images.

        Args:
            x: noisy images, ``signal_rates * images + noise_rates * noise``.
            noise_rates, signal_rates: schedule values broadcastable to ``x``
                (e.g. shape (B, 1, 1, 1)).
            training: use the live ``unet`` when True (default, for the
                training loss) or the EMA copy when False (for sampling).

        Returns:
            ``(pred_noises, pred_images)``.
        """
        network = self.unet if training else self.ema_unet
        # The network is conditioned on the noise variance.
        pred_noises = network(x, noise_rates**2)
        # Invert x = s * x0 + n * eps for x0.
        pred_images = (x - noise_rates * pred_noises) / signal_rates
        return pred_noises, pred_images

    @torch.no_grad()
    def _sync_ema_buffers(self):
        """Copy buffers (e.g. BatchNorm running statistics) from ``unet``.

        The EMA update in the training loop only touches ``parameters()``, so
        without this the EMA copy's BatchNorm statistics would stay at their
        initial values and eval-mode sampling would normalise incorrectly.
        """
        for live_buf, ema_buf in zip(self.unet.buffers(), self.ema_unet.buffers()):
            ema_buf.copy_(live_buf)

    @torch.no_grad()
    def reverse_diffusion(self, initial_noise, diffusion_steps):
        """Run ``diffusion_steps`` deterministic DDIM steps from t=1 down to t=0.

        Args:
            initial_noise: starting images, shape (B, C, H, W).
            diffusion_steps: number of reverse steps (>= 1).

        Returns:
            The clean-image prediction of the final step. (Re-noising it to
            t=0 would leave visible noise: with the default offset-cosine
            schedule the noise rate at t=0 is sin(acos(0.95)) ~= 0.31.)

        Raises:
            ValueError: if ``diffusion_steps`` < 1.
        """
        if diffusion_steps < 1:
            raise ValueError(f"diffusion_steps must be >= 1, got {diffusion_steps}")
        noisy_images = initial_noise.to(self.device)
        num_images = noisy_images.size(0)
        step_size = 1.0 / diffusion_steps

        was_training = self.ema_unet.training
        self._sync_ema_buffers()
        self.ema_unet.eval()  # BatchNorm uses running statistics, not batch stats
        try:
            for step in range(diffusion_steps):
                diffusion_times = (
                    torch.ones((num_images, 1, 1, 1), device=self.device)
                    - step * step_size
                )
                noise_rates, signal_rates = self.diffusion_schedule(diffusion_times)
                pred_noises, pred_images = self.denoise(
                    noisy_images, noise_rates, signal_rates, training=False
                )
                # Re-noise the clean prediction to the next (smaller) time.
                next_noise_rates, next_signal_rates = self.diffusion_schedule(
                    diffusion_times - step_size
                )
                noisy_images = (
                    next_signal_rates * pred_images + next_noise_rates * pred_noises
                )
        finally:
            self.ema_unet.train(was_training)
        return pred_images

    def generate(self, num_images, diffusion_steps, initial_noise=None):
        """Sample images starting from Gaussian noise.

        Args:
            num_images: batch size when ``initial_noise`` is not given.
            diffusion_steps: number of reverse-diffusion steps.
            initial_noise: optional starting noise (B, 3, H, W); when given,
                ``num_images`` is ignored.

        Returns:
            Generated images clamped to [0, 1].
        """
        if initial_noise is None:
            initial_noise = torch.randn(
                (num_images, 3, IMAGE_SIZE, IMAGE_SIZE), device=self.device
            )
        gen_images = self.reverse_diffusion(initial_noise, diffusion_steps)
        # NOTE: assumes training images are in [0, 1] (no mean/std normalisation).
        return torch.clamp(gen_images, 0.0, 1.0)

In [None]:
# Build the live network and wrap it; DiffusionModel also holds the EMA copy.
unet = UNet()
ddm = DiffusionModel(unet)

# Only the live network is optimised; the EMA weights are updated by hand in
# the training loop.
optimizer = optim.AdamW(
    params=unet.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

# Mean absolute error between predicted and true noise.
loss_fn = nn.L1Loss()

In [None]:
for epoch in range(EPOCHS):
    # Make sure the live network is in training mode (BatchNorm batch stats).
    unet.train()
    running_loss, num_batches = 0.0, 0
    for batch, _ in tqdm(train_loader, desc=f"Epoch {epoch}"):
        batch = batch.to(DEVICE)
        noises = torch.randn_like(batch)
        # One uniform diffusion time per image, shaped (B, 1, 1, 1) so it
        # broadcasts over channels and pixels.
        diffusion_times = torch.rand((batch.size(0), 1, 1, 1), device=DEVICE)
        # Use the model's own schedule so training and sampling always agree.
        noise_rates, signal_rates = ddm.diffusion_schedule(diffusion_times)
        noisy_images = signal_rates * batch + noise_rates * noises
        pred_noises, _ = ddm.denoise(noisy_images, noise_rates, signal_rates)
        loss = loss_fn(pred_noises, noises)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # EMA of the weights: ema <- EMA * ema + (1 - EMA) * live.
        with torch.no_grad():
            for p, ema_p in zip(unet.parameters(), ddm.ema_unet.parameters()):
                ema_p.lerp_(p, 1 - EMA)

        # Accumulate on-device; converting once per epoch avoids a sync per step.
        running_loss += loss.detach()
        num_batches += 1

    # Report the epoch mean rather than the loss of the last batch only.
    if num_batches:
        print(f"Epoch {epoch} mean loss: {float(running_loss) / num_batches:.4f}")

## 6. Sampling <a name="sampling"></a>

In [None]:
# Sample a few images and show them side by side (values are in [0, 1]).
generated_images = ddm.generate(num_images=4, diffusion_steps=PLOT_DIFFUSION_STEPS)
num_shown = generated_images.size(0)
fig, axes = plt.subplots(1, num_shown, figsize=(3 * num_shown, 3))
for ax, image in zip(np.atleast_1d(axes), generated_images.cpu()):
    ax.imshow(image.permute(1, 2, 0).numpy())  # (C, H, W) -> (H, W, C)
    ax.axis("off")
plt.show()