# Denoising Diffusion Probabilistic Models

In this notebook, we'll implement the Denoising Diffusion Probabilistic Model (DDPM) proposed in the paper [Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2006.11239) from scratch using PyTorch. We'll implement the U-Net model, the forward and reverse diffusion processes, the training loop and the sampling process.

In [153]:
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
import math

In [154]:
def alpha_bar(beta):
    alpha = 1. - beta
    return alpha.cumprod(dim=0)

def prepare_batch(x: torch.Tensor, T: int, alpha_bar: torch.Tensor):
    """
    Prepare a batch for training by generating the noise and the noisy image.
    """
    batch_size = x.shape[0]
    t = torch.randint(0, T, (batch_size,))
    e = torch.randn_like(x)
    alpha_bar_t = alpha_bar(t)
    alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)
    noisy_images = alpha_bar_t.sqrt() * x + (1 - alpha_bar_t).sqrt() * e
    
    return (noisy_images, t), e

## UNet

We split the U-net into 3 blocks, Left, Middle and Right.

At the root of it, we have the conv block. Each layer in the left and the right blocks consists of a series of ConvBlocks followed by a maxpool or an upsample. 

The convblocks are resnets with group normalization and are imported from src.resnet.py. We changed the original resnets to 

Changes that need to be made to a regular U-Net for DDPM:
1. Swap batch norm with group norm 
2. Introduce an attention mechanism at each conv block in the down and up blocks
3. Create an embedding for the timestep
4. Develop a 2D attention mechanism based on the one used in the text transformer model in [aryamanpandya99/transformers](https://github.com/aryamanpandya99/transformers)


In [155]:
from src.resnet import ResBlock


class ConvBlock(nn.Module):
    def __init__(
            self, 
            in_channels: int, 
            out_channels: int, 
            num_layers: int, 
            num_groups: int = 1, 
            dropout: float = 0.2, 
            activation: nn.Module = nn.ReLU,
            timestep_emb_dim: int = None
            ):
        super(ConvBlock, self).__init__()
        convs = []
        convs.append(
            ResBlock(
                in_channels, 
                out_channels, 
                num_groups=num_groups, 
                dropout=dropout, 
                activation=activation,
                timestep_emb_dim=timestep_emb_dim
            )
        )
        
        for _ in range(num_layers-1):
            convs.append(
                ResBlock(
                    out_channels,
                    out_channels, 
                    num_groups=num_groups, 
                    dropout=dropout, 
                    activation=activation,
                    timestep_emb_dim=timestep_emb_dim
                )
            )

        self.convs = nn.ModuleList(convs)

    def forward(self, x, timestep_emb=None):
        for res_block in self.convs:
            x = res_block(x, timestep_emb)
        
        return x

In [156]:
def scaled_dot_product_attention(q, k, d_k, mask):
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is True:
        mask = torch.tril(torch.ones(scores.shape)).to(q.device)
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return nn.Softmax(-1)(scores)

class Attention(nn.Module):
    """
    Multihead attention.

    Differences from multihead attention for text:
    
    - we no longer need a d_model, the internal hidden size 
    is determined by the number of channels which is determined 
    by the convolutional layers leading up to the attention layer.

    - swap batch norm with group norm

    - we resize the image to one of shape: (batch_size, num_channels, height * width)
    so that we can perform multihead attention across the image. This closely 
    mirrors (batch_size, embed_dim, seq_len
    """
    def __init__(self,
                 d_k: int, 
                 dropout: float, 
                 num_heads: int, 
                 num_channels: int,
                 num_groups: int = 8,
                 mask: bool = False
                 ):
        super(Attention, self).__init__()
        self.d_k, self.num_heads = d_k, num_heads
        self.query_projection, self.key_projection, self.value_projection = (
            nn.Linear(num_channels, num_heads* d_k),
            nn.Linear(num_channels, num_heads* d_k), 
            nn.Linear(num_channels, num_heads*d_k)
        )
        self.layer_norm = nn.LayerNorm(num_channels)
        self.output_layer = nn.Linear(num_heads*d_k, num_channels)
        self.dropout = nn.Dropout(dropout)
        self.mask = mask
        self.num_channels = num_channels

    def forward(self, x, y = None):
        
        batch_size, n_channels, height, width = x.shape
        x = x.view(batch_size, n_channels, height * width).permute(0, 2, 1)
        residual = x
        x = self.layer_norm(x)
        
        if y is not None:
            k, q, v = y, x, y
        else:
            k, q, v = x, x, x
        
        k_len, q_len, v_len, batch_size = k.size(1), q.size(1), v.size(1),  q.size(0)
        
        k = self.key_projection(k).view(batch_size, k_len,  self.num_heads, self.d_k)
        q = self.query_projection(q).view(batch_size, q_len,  self.num_heads, self.d_k)
        v = self.value_projection(v).view(batch_size, v_len,  self.num_heads, self.d_k)
        
        attention = scaled_dot_product_attention(
            q.transpose(1, 2), 
            k.transpose(1, 2), 
            self.d_k, 
            self.mask
        )
        output = torch.matmul(attention, v.transpose(1, 2))
        output = self.output_layer(output.transpose(1, 2).contiguous().view(batch_size, q_len, -1))

        h = self.dropout(output) + residual

        h = h.permute(0, 2, 1).view(batch_size, n_channels, height, width)
        
        return h


In [157]:
example_image = torch.randn(2, 3, 128, 128)
conv_block = ConvBlock(3, 32, 2)
h_example_image = conv_block(example_image)

In [158]:
attention_block = Attention(d_k=64, dropout=0.1, num_heads=3, num_channels=32, num_groups=8)
attn_output = attention_block(h_example_image)
attn_output.shape

torch.Size([2, 32, 128, 128])

In [159]:
def timestep_encoding(curr_t: torch.Tensor, T: torch.Tensor, embedding_dim: int, n=10000, device: torch.device = "cpu"):
    """
    Naive sin/cosin positional embedding adapted for timestep embedding in DDPM
    """
    curr_t = curr_t / T # normalize the timestep to be between 0 and 1
    p = torch.zeros((curr_t.shape[-1], embedding_dim)).to(device) # initialize the positional embedding tensor

    m = torch.arange(int(embedding_dim/2)).to(device) # this is divided by two because we alternate between sin and cos
    denominators = torch.pow(n, (2*m/embedding_dim))  # compute the denominators for the sin and cos functions
    
    p[:, 0::2] = torch.sin(curr_t.unsqueeze(1) / denominators.unsqueeze(0))
    p[:, 1::2] = torch.cos(curr_t.unsqueeze(1) / denominators.unsqueeze(0))
    return p



class TimestepEmbedding(nn.Module):
    """
    Embeds the timestep into a higher dimensional space using a 2 layer MLP.
    """
    def __init__(self,
                in_channels: int, 
                embedding_dim: int, 
                activation: nn.Module = nn.ReLU
                ):
        """
        Args:
            in_channels: number of input channels
            embedding_dim: dimension of the embedding space
            activation: activation function
        """
        super(TimestepEmbedding, self).__init__()
        self.linear1 = nn.Linear(in_channels, embedding_dim)
        self.linear2 = nn.Linear(embedding_dim, embedding_dim)
        self.activation = activation()

    def forward(self, curr_t: torch.Tensor, T: torch.Tensor):
        print(curr_t.shape)
        x = self.linear1(curr_t)
        x = self.activation(x)
        x = self.linear2(x)
        x = self.activation(x)
        
        return x

In [160]:
class LeftBlock(nn.Module):
    """
    Downampling (left) side of the UNet.
    Excludes the bottom-most conv block.
    """
    def __init__(
            self, 
            in_channels: int, 
            filters: list[int], 
            num_layers: int, 
            has_attention: bool = False, 
            num_heads: int = 8, 
            dropout: float = 0.2, 
            timestep_emb_dim: int = None
            ):
        super(LeftBlock, self).__init__()
        
        self.has_attention = has_attention
        conv_blocks = [ConvBlock(in_channels, filters[0], num_layers, timestep_emb_dim=timestep_emb_dim)]
        attention_blocks = [
            Attention(
                d_k=64, 
                dropout=0.1, 
                num_heads=num_heads, 
                num_channels=filters[0], 
            )
        ] if has_attention else []
        
        for i in range(1, len(filters)):
            conv_blocks.append(ConvBlock(filters[i-1], filters[i], num_layers, timestep_emb_dim=timestep_emb_dim))
            if has_attention:
                attention_blocks.append(
                    Attention(
                        d_k=64, 
                        dropout=0.1, 
                        num_heads=num_heads, 
                        num_channels=filters[i], 
                    )
                )
        
        self.conv_blocks = nn.ModuleList(conv_blocks)
        self.attention_blocks = nn.ModuleList(attention_blocks)
        
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x, timestep_emb=None):
        residual_outputs = []
        for i, conv_block in enumerate(self.conv_blocks):
            x = conv_block(x, timestep_emb)
            if self.has_attention:
                x = self.attention_blocks[i](x)
            
            residual_outputs.append(x)
            x = self.maxpool(x)

        return residual_outputs, x


class RightBlock(nn.Module):
    """
    Upsampling (right) side of the UNet.
    """
    def __init__(
            self, 
            filters: list[int], 
            num_layers: int, 
            has_attention: bool = False, 
            num_heads: int = 8, 
            dropout: float = 0.2,
            timestep_emb_dim: int = None,
            ):
        super(RightBlock, self).__init__()
        self.has_attention = has_attention

        conv_layers = []
        upsample_layers = []
        attention_layers = []
        
        for i in range(len(filters) - 2):
            conv_layers.append(
                ConvBlock(filters[i], filters[i+1], num_layers, timestep_emb_dim=timestep_emb_dim)
            )
            upsample_layers.append(
                nn.ConvTranspose2d(filters[i+1], filters[i+1]//2, 2, stride=2)
            )

            if has_attention:
                attention_layers.append(
                    Attention(d_k=64, dropout=0.1, num_heads=num_heads, num_channels=filters[i+1]//2)
                )
        conv_layers.append(
            ConvBlock(filters[-2], filters[-1], num_layers, timestep_emb_dim=timestep_emb_dim)
        )

        self.conv_layers = nn.ModuleList(conv_layers)
        self.attention_layers = nn.ModuleList(attention_layers)
        self.upsample_layers = nn.ModuleList(upsample_layers)
    
    def forward(self, x, residual_outputs, timestep_emb=None):
        for i in range(len(self.conv_layers)):
            residual = residual_outputs[-(i+1)]
            _, _, h, w = x.shape
            residual = residual[:, :, :h, :w]

            x = torch.cat([x, residual], dim=1)
            x = self.conv_layers[i](x, timestep_emb)

            if i < len(self.upsample_layers):
                x = self.upsample_layers[i](x)

                if self.has_attention:
                    x = self.attention_layers[i](x)
        
        return x


In [161]:
db = LeftBlock(in_channels=3, filters=[32, 64, 128], num_layers=2, has_attention=True, num_heads=3, timestep_emb_dim=128)

In [162]:
timesteps = torch.randint(0, 1000, (2,))
t_encoded = timestep_encoding(timesteps, 1000, 32)
t_encoded.shape # N x 32

time_embedding_layer = TimestepEmbedding(in_channels=32, embedding_dim=128, activation=nn.ReLU) # N x 32 -> N x 128 (embedding_dim)
t_embedded = time_embedding_layer(curr_t=t_encoded, T=1000)

torch.Size([2, 32])


In [163]:
a = db(torch.randn(2, 3, 128, 128), timestep_emb=t_embedded)
a[1].shape

torch.Size([2, 128, 16, 16])

In [164]:
for residual_output in a[0]:
    print(residual_output.shape)
    print("-"*10)

print(a[1].shape)

torch.Size([2, 32, 128, 128])
----------
torch.Size([2, 64, 64, 64])
----------
torch.Size([2, 128, 32, 32])
----------
torch.Size([2, 128, 16, 16])


In [165]:
bottom_conv = nn.Sequential(ResBlock(128, 256), nn.ConvTranspose2d(256, 128, 2, stride=2))
bottom_conv(a[1]).shape

torch.Size([2, 128, 32, 32])

In [166]:
up = RightBlock(filters=[256, 128, 64, 32], num_layers=2, has_attention=True, num_heads=3, timestep_emb_dim=128)

In [167]:
bottom_out = bottom_conv(a[1])
bottom_out.shape


torch.Size([2, 128, 32, 32])

In [168]:
up_out = up(bottom_out, a[0], t_embedded)
up_out.shape

torch.Size([2, 32, 128, 128])

In [169]:
class UNet(nn.Module):
    """
    UNet model for DDPM.
    """
    def __init__(self,
                 down_filters: list[int], 
                 in_channels: int, 
                 num_layers: int, 
                 has_attention: bool = False, 
                 num_heads: int = 8,
                 diffusion_steps: int = None,
                 num_groups: int = 8,
                 activation: nn.Module = nn.ReLU,
                 timestep_emb_dim: int = None
                ):
        super(UNet, self).__init__()
        self.T = diffusion_steps
        self.down_filters = down_filters

        self.time_embed_dim = down_filters[0] * 4 
        
        if self.T is not None:
            self.timestep_embedding = TimestepEmbedding(
                in_channels=self.down_filters[0], 
                embedding_dim=self.time_embed_dim, 
                activation=activation, 
            )
        
        
        self.left_block = LeftBlock(
            filters=down_filters, 
            num_layers=num_layers, 
            in_channels=in_channels, 
            has_attention=has_attention, 
            num_heads=num_heads
        )
        
        # the bottom-most (middle) conv block 
        if has_attention:
            self.middle_conv = ConvBlock(
                down_filters[-1], 
                down_filters[-1]*2, 
                num_layers, 
                timestep_emb_dim=timestep_emb_dim
            )
            self.middle_attention = Attention(d_k=64, dropout=0.1, num_heads=num_heads, num_channels=down_filters[-1]*2)
            self.middle_upsample = nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
        else:
            self.middle_conv = ConvBlock(
                down_filters[-1], 
                down_filters[-1]*2, 
                num_layers, 
                timestep_emb_dim=timestep_emb_dim
            )
            self.middle_attention = None
            self.middle_upsample = nn.ConvTranspose2d(down_filters[-1]*2, down_filters[-1], 2, stride=2)
        
        self.up_filters = [down_filters[-1]*2]
        self.up_filters.extend(reversed(down_filters))
        self.right_block = RightBlock(
            filters=self.up_filters, 
            num_layers=num_layers, 
            has_attention=has_attention, 
            num_heads=num_heads,
            timestep_emb_dim=timestep_emb_dim
        )

        self.group_norm = nn.GroupNorm(num_groups=in_channels, num_channels=in_channels)
        self.conv_out = nn.Conv2d(down_filters[0], in_channels, 3, padding=1)
    
    def forward(self, x, t):
        if self.T is not None:
            print(t.shape)
            t_encoded = timestep_encoding(t, self.T, self.down_filters[0], n = 4000, device=x.device)
            print(t_encoded.shape)
            t_emb = self.timestep_embedding(curr_t=t_encoded, T=self.T)
            t_emb = t_emb.view(-1, self.time_embed_dim, 1, 1)
        else: 
            t_emb = None

        x = self.group_norm(x)
        print("group norm complete")
        residual_outputs, down_output = self.left_block(x, t_emb)
        print("left block complete")
        bottom_output = self.middle_conv(down_output, t_emb)
        if self.middle_attention is not None:
            bottom_output = self.middle_attention(bottom_output)
        print("middle attention complete")
        bottom_output = self.middle_upsample(bottom_output)

        right_out = self.right_block(bottom_output, residual_outputs, t_emb)
        print("right block complete")
        output = self.conv_out(right_out)
        print("conv out complete")
        return output


In [170]:
u_net = UNet(down_filters=[32, 64, 128], in_channels=3, num_layers=2, has_attention=True, num_heads=3, diffusion_steps=1000)

In [171]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

num_params = count_parameters(u_net)
print(f"The UNet model has {num_params:,} trainable parameters.")


The UNet model has 4,442,319 trainable parameters.


In [172]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

sample_inputs = torch.randn(16, 3, 32, 32).to(device)
u_net.to(device)
sample_outputs = u_net(sample_inputs, torch.randint(0, 1000, (15,)).to(device))
sample_outputs.shape

torch.Size([15])
torch.Size([15, 32])
torch.Size([15, 32])
group norm complete
left block complete
middle attention complete
right block complete
conv out complete


torch.Size([16, 3, 32, 32])

In [173]:
x = torch.randn(15, 3, 128, 128)
T = 1000
batch_size = x.shape[0]
t = torch.randint(0, T, (batch_size,))
e = torch.randn_like(x)
alpha_bar_t = alpha_bar(t)
print(alpha_bar_t.shape)
alpha_bar_t = alpha_bar_t.view(-1, 1, 1, 1)

torch.Size([15])


In [174]:
alpha_bar_t.shape

torch.Size([15, 1, 1, 1])

In [175]:
def validate_model(model, valid_loader, loss_fn, all_valid_loss):
    model.eval()
    valid_loss = []
    
    with torch.no_grad():
        for batch in valid_loader:
            x, y = batch
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            valid_loss.append(loss.item())
    
    all_valid_loss.append(sum(valid_loss) / len(valid_loss))

def plot_loss(all_train_loss, all_valid_loss):
    plt.figure(figsize=(10, 5))
    plt.plot(all_train_loss, label='Training Loss')
    plt.plot(all_valid_loss, label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()


def train_model(model: nn.Module,
                optim: torch.optim.Optimizer,
                loss_fn: nn.Module,
                train_loader: torch.utils.data.DataLoader,
                valid_loader: torch.utils.data.DataLoader,
                scheduler: torch.optim.lr_scheduler._LRScheduler,
                epochs: int = 10,
                valid_every: int = 1
                ):

    all_train_loss = []
    all_valid_loss = []
    
    for epoch in range(epochs):
        model.train()
        train_loss = []
        
        for batch in train_loader:
            x, _ = batch
            (noisy_images, t), e = prepare_batch(x, T, alpha_bar_t)
            optim.zero_grad()
            y_pred = model(noisy_images)
            loss = loss_fn(y_pred, e)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
        
        all_train_loss.append(sum(train_loss) / len(train_loss))
        scheduler.step()
        
        if epoch % valid_every == 0:
            validate_model(model, valid_loader, loss_fn, all_valid_loss)
            print(
                f"Epoch {epoch}, Train Loss: {sum(train_loss) / len(train_loss)}, "
                f"Valid Loss: {all_valid_loss[-1]}"
            )
    
    plot_loss(all_train_loss, all_valid_loss)
