# Denoising Diffusion Probabilistic Models

In this notebook, we'll implement the Denoising Diffusion Probabilistic Model (DDPM) proposed in the paper [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) from scratch using PyTorch. We'll implement the U-Net model, the forward and reverse diffusion processes, the training loop and the sampling process.

In [163]:
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
import math

## UNet

The U-net is implemented as in the original paper. It's been divided into blocks to make it more modular and easier to work with. The "downblock" consists of a series of convolutional layers, each followed by a maxpooling layer. The "upblock" consists of a convolutional layer followed by a transpose convolutional layer. The bottom-most conv block is implemented separately and is different from the rest of the blocks.

In [164]:
class ConvBlock(nn.Module):
    """
    Convolutional block with two convolutional layers.
    The first conv layer changes the number of channels
    The second maintains the number of channels
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))
    
class DownBlock(nn.Module):
    """
    Downampling (left) side of the UNet.
    Excludes the bottom-most conv block.
    """
    def __init__(self, filters, in_channels):
        super(DownBlock, self).__init__()
        conv_blocks = [ConvBlock(in_channels, filters[0])]
        for i in range(1, len(filters)):
            conv_blocks.append(ConvBlock(filters[i-1], filters[i]))

        self.conv_blocks = nn.Sequential(*conv_blocks)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual_outputs = []
        for conv_block in self.conv_blocks:
            x = conv_block(x)
            residual_outputs.append(x)
            x = self.maxpool(x)

        return residual_outputs, x

class UpBlock(nn.Module):
    """
    Upsampling (right) side of the UNet.
    """
    def __init__(self, filters):
        super(UpBlock, self).__init__()
        layers = []
        for i in range(len(filters) - 2):
            layers.append(
                nn.Sequential(
                    ConvBlock(filters[i], filters[i+1]), 
                    nn.ConvTranspose2d(filters[i+1], filters[i+1]//2, 2, stride=2)
                )
            )
        
        layers.append(ConvBlock(filters[-2], filters[-1]))
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x, residual_outputs):
        for i in range(len(self.layers)):
            print(f"i: {i}")
            residual = residual_outputs[-(i+1)]
            _, _, h, w = x.shape
            residual = residual[:, :, :h, :w]
            print(f"x: {x.shape}, residual: {residual.shape}")
            x = torch.cat([x, residual], dim=1)
            x = self.layers[i](x)
            print(f"x: {x.shape}")
        
        return x


In [165]:
db = DownBlock([32, 64, 128], 3)

In [166]:
db

DownBlock(
  (conv_blocks): Sequential(
    (0): ConvBlock(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [167]:
a = db(torch.randn(1, 3, 128, 128))
a[1].shape

torch.Size([1, 128, 12, 12])

In [168]:
for residual_output in a[0]:
    print(residual_output.shape)
    print("-"*10)

print(a[1].shape)

torch.Size([1, 32, 124, 124])
----------
torch.Size([1, 64, 58, 58])
----------
torch.Size([1, 128, 25, 25])
----------
torch.Size([1, 128, 12, 12])


In [169]:
bottom_conv = nn.Sequential(ConvBlock(128, 256), nn.ConvTranspose2d(256, 128, 2, stride=2))
bottom_conv(a[1]).shape

torch.Size([1, 128, 16, 16])

In [170]:
up = UpBlock([256, 128, 64, 32])



In [171]:
for layer in up.layers:
    print(layer)
    print("-"*10)

Sequential(
  (0): ConvBlock(
    (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  )
  (1): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
)
----------
Sequential(
  (0): ConvBlock(
    (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  )
  (1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
)
----------
ConvBlock(
  (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
)
----------


In [172]:
bottom_out = bottom_conv(a[1])
bottom_out.shape


torch.Size([1, 128, 16, 16])

In [173]:
up_out = up(bottom_out, a[0])

i: 0
x: torch.Size([1, 128, 16, 16]), residual: torch.Size([1, 128, 16, 16])
x: torch.Size([1, 64, 24, 24])
i: 1
x: torch.Size([1, 64, 24, 24]), residual: torch.Size([1, 64, 24, 24])
x: torch.Size([1, 32, 40, 40])
i: 2
x: torch.Size([1, 32, 40, 40]), residual: torch.Size([1, 32, 40, 40])
x: torch.Size([1, 32, 36, 36])


In [174]:
class UNet(nn.Module):
    def __init__(self, down_filters, in_channels):
        super(UNet, self).__init__()
        self.down_filters = down_filters
        self.down_block = DownBlock(down_filters, in_channels)
        
        # the bottom-most conv block is different from the rest of the blocks
        # in that it doesn't contain a maxpool and upsamples without a residual connection
        self.bottom_conv = nn.Sequential(
            ConvBlock(down_filters[-1], down_filters[-1]*2), 
            nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
        )
        
        self.up_filters = [down_filters[-1]*2]
        self.up_filters.extend(reversed(down_filters))
        self.up_block = UpBlock(self.up_filters)

    def forward(self, x):
        residual_outputs, down_output = self.down_block(x)
        bottom_output = self.bottom_conv(down_output)
        return self.up_block(bottom_output, residual_outputs)

In [175]:
u_net = UNet([32, 64, 128], 3)
a = u_net(torch.randn(1, 3, 128, 128))
a.shape

i: 0
x: torch.Size([1, 128, 16, 16]), residual: torch.Size([1, 128, 16, 16])
x: torch.Size([1, 64, 24, 24])
i: 1
x: torch.Size([1, 64, 24, 24]), residual: torch.Size([1, 64, 24, 24])
x: torch.Size([1, 32, 40, 40])
i: 2
x: torch.Size([1, 32, 40, 40]), residual: torch.Size([1, 32, 40, 40])
x: torch.Size([1, 32, 36, 36])


torch.Size([1, 32, 36, 36])

In [180]:
def alpha_bar(beta):
    alpha = 1. - beta
    return alpha.cumprod(dim=0)

def prepare_batch(x: torch.Tensor, T: int, alpha_bar: torch.Tensor):
    """
    Prepare a batch for training by generating the noise and the noisy image.
    """
    batch_size = x.shape[0]
    t = torch.randint(0, T, (batch_size,))
    e = torch.randn_like(x)
    alpha_bar_t = alpha_bar(t)
    alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)
    noisy_images = alpha_bar_t.sqrt() * x + (1 - alpha_bar_t).sqrt() * e
    
    return (noisy_images, t), e


In [185]:
x = torch.randn(15, 3, 128, 128)
T = 1000
batch_size = x.shape[0]
t = torch.randint(0, T, (batch_size,))
e = torch.randn_like(x)
alpha_bar_t = alpha_bar(t)
print(alpha_bar_t.shape)
alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)

torch.Size([15])


In [184]:
alpha_bar_t.shape

torch.Size([15, 1, 1, 1])

In [176]:
def validate_model(model, valid_loader, loss_fn, all_valid_loss):
    model.eval()
    valid_loss = []
    
    with torch.no_grad():
        for batch in valid_loader:
            x, y = batch
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            valid_loss.append(loss.item())
    
    all_valid_loss.append(sum(valid_loss) / len(valid_loss))

def plot_loss(all_train_loss, all_valid_loss):
    plt.figure(figsize=(10, 5))
    plt.plot(all_train_loss, label='Training Loss')
    plt.plot(all_valid_loss, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def train_model(model: nn.Module,
                optim: torch.optim.Optimizer,
                loss_fn,
                train_loader,
                valid_loader,
                scheduler,
                epochs=10,
                valid_every=1
                ):

    all_train_loss = []
    all_valid_loss = []
    
    for epoch in range(epochs):
        model.train()
        train_loss = []
        
        for batch in train_loader:
            x, _ = batch
            (noisy_images, t), e = prepare_batch(x, T, alpha_bar_t)
            optim.zero_grad()
            y_pred = model(noisy_images)
            loss = loss_fn(y_pred, e)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
        
        all_train_loss.append(sum(train_loss) / len(train_loss))
        scheduler.step()
        
        if epoch % valid_every == 0:
            validate_model(model, valid_loader, loss_fn, all_valid_loss)
            print(
                f"Epoch {epoch}, Train Loss: {sum(train_loss) / len(train_loss)}, "
                f"Valid Loss: {all_valid_loss[-1]}"
            )
    
    plot_loss(all_train_loss, all_valid_loss)


Changes that need to be made to the U-Net for DDPM:
1. Swap batch norm with group norm 
2. Introduce an attention mechanism at each conv block in the down and up blocks
3. Create an embedding for the timestep

Next, we'll implement these missing features

In [None]:
def scaled_dot_product_attention(q, k, d_k, mask):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is True:
        mask = torch.tril(torch.ones(scores.shape)).to(q.device)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return nn.Softmax(-1)(scores)

class Attention(nn.Module):
    """
    Multihead attention class implementation. Can act as self-attention (default, y is None)
    or cross attention if y not none
    """
    def __init__(self, d_k, d_model, d_v, dropout, mask) -> None:
        super(Attention, self).__init__()
        self.d_k, self.d_v, self.d_model = d_k, d_v, d_model
        self.query_layer, self.key_layer, self.value_layer = (
            nn.Linear(d_model, d_k), 
            nn.Linear(d_model, d_k), 
            nn.Linear(d_model, d_v)
        )
        self.layer_norm = nn.LayerNorm(d_model)
        self.concat_projection = nn.Linear(d_v, d_model)
        self.dropout = nn.Dropout(dropout)
        self.mask = mask

    def forward(self, x, y = None):
        residual = x
        x = self.layer_norm(x)
        if y is not None:
            k, q, v = y, x, y
        else:
            k, q, v = x, x, x
        
        k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
        k = self.key_layer(k).view(batch_size, k_len,  self.d_k)
        q = self.query_layer(q).view(batch_size, q_len,  self.d_k)
        v = self.value_layer(v).view(batch_size, v_len,  self.d_v)
        attention = scaled_dot_product_attention(q.transpose(1, 2), k.transpose(1, 2), self.d_k, self.mask)
        output = torch.matmul(attention, v.transpose(1, 2))
        output = self.concat_projection(output.transpose(1, 2).contiguous().view(batch_size, q_len, -1))
        
        return self.dropout(output) + residual