Here I rebuild the DDPM paper from scratch. The whole DDPM procedure can be thought of as demolishing a big building into bricks and then rebuilding the building from those bricks: we first gradually add noise to the image (the forward process), and the model then learns how to reverse that corruption and reconstruct the original image (the reverse process). Strictly speaking, DDPM is not a typical "diffusion model"; it can be understood as a (hierarchical) VAE.

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import init
import math
from einops import rearrange, repeat
from functools import partial # fancy way to wrap functions

The DDPM paper uses an architecture very similar to the one in the U-Net paper, with several modifications: the authors replaced the original double convolutions in each encoding step with the residual blocks introduced in the ResNet paper.

There are five components in DDPM model, they are:

1. Encoder blocks

2. Bottleneck blocks

3. Decoder blocks

4. Self Attention modules

5. Sinusoidal time embeddings

Here are the details of the model architecture: 
1. As in the original U-Net model, the encoder and the decoder of the DDPM model each have four levels, connected by a bottleneck.

Define the SinusodialPositionEncoding 

In [None]:
class SinusodialPositionEncoding(nn.Module):
    """Sinusoidal embedding of a 1-D batch of positions (e.g. diffusion timesteps).

    For a position ``p`` and frequency index ``i`` in ``[0, dim // 2)`` the
    embedding stores ``sin(p * w_i)`` at column ``2 * i`` and ``cos(p * w_i)``
    at column ``2 * i + 1``, where ``w_i = exp(-i * log(theta) / (dim // 2))``.
    """

    def __init__(self, 
                 dim, 
                 theta = 10000
        ):
        """
        dim: size of the produced embedding, must be even
        theta: base of the frequency schedule, as in 'attention is all you need'
        """
        super(SinusodialPositionEncoding, self).__init__()
        assert dim % 2 == 0, 'Dimension of input recommended to be an even number'
        self.dim = dim
        # Stored because positional_encoding() reads it when building frequencies.
        self.theta = theta

    def positional_encoding(self, x, position=None):
        """Return the ``(len(x), dim)`` sinusoidal embedding of ``x``.

        x: 1-D tensor of positions / timesteps, one per batch element.
        position: ignored; kept (now optional) only so existing two-argument
            calls keep working. The positions are always taken from ``x``.
        """
        half_dim = self.dim // 2
        log_step = math.log(self.theta) / half_dim
        # One frequency per sin/cos pair, decaying geometrically from 1.
        frequencies = torch.exp(
            torch.arange(half_dim, device=x.device, dtype=torch.float32) * -log_step
        )
        # Cast so integer timesteps multiply cleanly with float frequencies.
        angles = x[:, None].float() * frequencies[None, :]

        encoding = torch.zeros(len(x), self.dim, device=x.device)
        encoding[:, 0::2] = torch.sin(angles)
        encoding[:, 1::2] = torch.cos(angles)

        return encoding

    def forward(self, x):
        """Make the module callable: ``module(x)`` returns the embedding of ``x``."""
        return self.positional_encoding(x)

Define the Resnet Block

In [4]:
class ResnetBlock(nn.Module):
    """Two-convolution residual block with optional timestep conditioning.

    Each stage is ``Conv2d -> SiLU -> GroupNorm``. Between the two stages a
    linearly projected time embedding is added per channel. The block input is
    added back at the end through a skip connection (a 1x1 convolution when
    the channel count changes, identity otherwise).
    """

    def __init__(self, 
                 in_channels, 
                 out_channels, 
                 time_channels,
                 kernel_size=3, 
                 padding_size=1, 
                 n_group=32, 
                 dropout=True, 
                 time_embedding=True
        ):
        """
        in_channels: channels of the input feature map
        out_channels: channels of the output feature map (must be divisible by n_group)
        time_channels: size of the time-embedding vector passed to forward()
        kernel_size / padding_size: geometry of the two main convolutions
        n_group: number of groups used by both GroupNorm layers
        dropout: whether to apply Dropout(0.1) after the second stage
        time_embedding: whether a supplied time embedding is added to the features
        """
        super(ResnetBlock, self).__init__()
        self.in_dim = in_channels
        self.out_dim = out_channels
        self.time_channels = time_channels
        self.dropout = dropout
        self.time_embedding = time_embedding

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding_size)
        self.activation1 = nn.SiLU()
        self.norm1 = nn.GroupNorm(num_groups=n_group, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=kernel_size, padding=padding_size)
        self.activation2 = nn.SiLU()
        self.norm2 = nn.GroupNorm(num_groups=n_group, num_channels=out_channels)
        # Registered once as a submodule so that .eval() / .train() toggles it.
        self.drop = nn.Dropout(0.1) if dropout else nn.Identity()
        # Keyword argument on purpose: the 4th positional argument of Conv2d is
        # the stride, so passing ``1, 0`` positionally would request stride 0.
        self.res_connection = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1)
            if in_channels != out_channels else nn.Identity()
        )
        # The time embedding is a (batch, time_channels) vector, so it is
        # projected with a Linear layer rather than a convolution.
        self.time_emb = nn.Linear(time_channels, out_channels)

    def forward(self, x, time_emb=None):
        """Apply the block.

        x: feature map of shape (batch, in_channels, height, width).
        time_emb: optional (batch, time_channels) embedding. It is used only
            when the block was built with ``time_embedding=True``; when it is
            omitted the conditioning step is skipped so ``block(x)`` still works.
        Returns a tensor of shape (batch, out_channels, height, width) for the
        default kernel_size / padding_size.
        """
        # Take the skip path from the untouched input, before any transformation.
        residual = self.res_connection(x)

        h = self.conv1(x)
        h = self.activation1(h)
        h = self.norm1(h)
        if self.time_embedding and time_emb is not None:
            # (batch, out_channels) -> (batch, out_channels, 1, 1) for broadcasting.
            h = h + self.time_emb(time_emb)[:, :, None, None]
        h = self.conv2(h)
        h = self.activation2(h)
        h = self.norm2(h)
        h = self.drop(h)
        return h + residual

Define the Attention Block with Flatten Attention

In [5]:
class FocusedLinearAttention(nn.Module):
    """Multi-head focused linear attention with a depthwise-convolution branch.

    Queries and keys go through ReLU, are divided by a learned softplus scale
    and are sharpened by raising them to ``focusing_factor`` while keeping
    their original norm. Attention is then computed either as ``q (k^T v)`` or
    as ``(q k^T) v``, whichever needs fewer multiplications. A depthwise
    convolution over the values is added to the attention output.
    """

    def __init__(self, 
                 dim, 
                 num_heads=8, 
                 bias=False, 
                 scale=0, 
                 attn_drop=0.2, 
                 proj_drop=0.2,
                 focusing_factor=3, 
                 kernel_size=5
        ):
        """
        dim: channel dimension of the tokens; must be divisible by num_heads
        num_heads: number of attention heads
        bias: whether the q / k / v projections use a bias
        scale: accepted for signature compatibility; the scale actually used
            is the learned parameter ``self.scale``
        attn_drop: rate of ``self.attn_drop`` (created but not applied in forward)
        proj_drop: dropout rate applied after the output projection
        focusing_factor: exponent applied to the ReLU-ed queries and keys
        kernel_size: kernel size of the depthwise convolution
        """
        super(FocusedLinearAttention, self).__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        self.kernel = nn.ReLU()
        head_dim = dim // num_heads
        
        # requires bias = False when using linear projection
        self.q = nn.Linear(dim, dim, bias=bias)
        self.k = nn.Linear(dim, dim, bias=bias)
        self.v = nn.Linear(dim, dim, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Focusing on higher cosine similarity pairs
        self.focusing_factor = focusing_factor
        # Depthwise Convolution, padding = kernel_size//2 to make sure the image with the same size after convolution
        self.dwc = nn.Conv2d(in_channels=head_dim, out_channels=head_dim, kernel_size=kernel_size,
                             groups=head_dim, padding=kernel_size // 2)
        # Rank 3 so that it broadcasts against (batch, tokens, dim) without
        # adding an extra leading axis to q and k.
        self.scale = nn.Parameter(torch.zeros(size=(1, 1, dim)))

    def forward(self, x):
        """Apply attention.

        x: either a feature map of shape (batch, dim, height, width) or a token
            sequence of shape (batch, tokens, dim). For a token sequence the
            number of tokens must be a perfect square, because the depthwise
            convolution needs a 2-D layout.
        Returns a tensor with the same shape as ``x``.
        Raises ValueError for any other rank or a non-square token count.
        """
        is_feature_map = x.dim() == 4
        if is_feature_map:
            batch, channels, height, width = x.shape
            # (b, c, h, w) -> (b, h*w, c): the linears act on the channel axis.
            x = x.flatten(2).transpose(1, 2)
        elif x.dim() == 3:
            batch, n_tokens, channels = x.shape
            height = width = math.isqrt(n_tokens)
            if height * width != n_tokens:
                raise ValueError(
                    f"token count {n_tokens} is not a perfect square; "
                    "pass a (batch, dim, height, width) tensor instead"
                )
        else:
            raise ValueError(f"expected a 3-D or 4-D input, got {x.dim()}-D")
        n_tokens = height * width
        heads = self.num_heads
        head_dim = channels // heads

        q = self.kernel(self.q(x)) + 1e-6
        k = self.kernel(self.k(x)) + 1e-6
        v = self.v(x)

        scale = F.softplus(self.scale)
        q = q / scale
        k = k / scale
        # Remember the norms so the power below changes direction, not magnitude.
        q_norm = q.norm(dim=-1, keepdim=True)
        k_norm = k.norm(dim=-1, keepdim=True)
        q = q ** self.focusing_factor
        k = k ** self.focusing_factor
        q = (q / q.norm(dim=-1, keepdim=True)) * q_norm
        k = (k / k.norm(dim=-1, keepdim=True)) * k_norm

        # Rearrange into multi-head dimension, each head will have C/H dimensions:
        # (b, n, heads*d) -> (b*heads, n, d)
        def split_heads(t):
            t = t.reshape(batch, n_tokens, heads, head_dim).permute(0, 2, 1, 3)
            return t.reshape(batch * heads, n_tokens, head_dim)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)
        i, j, c, d = q.shape[-2], k.shape[-2], k.shape[-1], v.shape[-1]

        # Per-query normaliser that replaces the softmax denominator.
        z = 1 / (torch.einsum("b i c, b c -> b i", q, k.sum(dim=1)) + 1e-6)
        # Using Linear Attention Mechanism here to get O(N) complexity
        if i * j * (c + d) > c * d * (i + j):
            kv = torch.einsum("b j c, b j d -> b c d", k, v)
            out = torch.einsum("b i c, b c d, b i -> b i d", q, kv, z)
        else:
            qk = torch.einsum("b i c, b j c -> b i j", q, k)
            out = torch.einsum("b i j, b j d, b i -> b i d", qk, v, z)

        # Depthwise convolution over the values, laid out as a 2-D map
        # (row-major, matching the flatten above).
        feature_map = v.transpose(1, 2).reshape(batch * heads, head_dim, height, width)
        feature_map = self.dwc(feature_map).flatten(2).transpose(1, 2)
        # Adding the convolved values back onto the attention output
        out = out + feature_map

        # (b*heads, n, d) -> (b, n, heads*d)
        out = out.reshape(batch, heads, n_tokens, head_dim).permute(0, 2, 1, 3)
        out = out.reshape(batch, n_tokens, channels)

        out = self.proj(out)
        out = self.proj_drop(out)

        if is_feature_map:
            out = out.transpose(1, 2).reshape(batch, channels, height, width)
        return out

Define the Attention Block with Self-attention

In [4]:
class Origin_Attention(nn.Module):
    """Standard multi-head softmax self-attention over the pixels of a feature map.

    Every spatial location is treated as one token. Queries, keys and values
    come from a single linear projection, and the result is projected back to
    ``input_channels`` and reshaped to the input layout.
    """

    def __init__(self, 
                 dim, 
                 input_channels,
                 heads=8,
                 scale=True,
                 dropout=0.1,
                 n_group=32,
                 bias=False
        ):
        """
        dim: total inner attention width (split evenly across the heads)
        input_channels: channels of the incoming / outgoing feature map
        heads: number of attention heads; must divide dim
        scale: whether to divide the scores by sqrt(head_dim)
        dropout: dropout rate applied to the attention weights
        n_group: accepted for signature compatibility; not used by this module
        bias: whether the fused q/k/v projection uses a bias
        """
        super(Origin_Attention, self).__init__()
        if dim % heads != 0:
            raise ValueError(f"dim {dim} should be divided by heads {heads}.")
        self.dim = dim
        self.input_channels = input_channels
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = scale  

        self.qkv = nn.Linear(input_channels, dim*3, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.output = nn.Linear(dim, input_channels)

    def forward(self, x):
        """Apply self-attention.

        x: tensor of shape (batch, input_channels, height, width).
        Returns a tensor of the same shape.
        """
        b, c, h, w = x.shape
        n = h * w
        # reshape from (batch_size, channel, height, width) to (batch_size, height*width, channel)
        tokens = x.flatten(2).transpose(1, 2)

        q, k, v = torch.chunk(self.qkv(tokens), 3, dim=-1)

        # (b, n, heads*d) -> (b, heads, n, d)
        def split_heads(t):
            return t.reshape(b, n, self.heads, self.head_dim).transpose(1, 2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scores are (b, heads, query, key), so the softmax below normalises
        # over the keys of each query.
        attn = torch.einsum('b h i d, b h j d -> b h i j', q, k)
        if self.scale:
            attn = attn / (self.head_dim) ** 0.5
        attn = attn.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
        # (b, heads, n, d) -> (b, n, heads*d) so the output projection sees all heads.
        out = out.transpose(1, 2).reshape(b, n, self.dim)
        out = self.output(out)
        # Back to the (batch, channel, height, width) layout of the input.
        return out.transpose(1, 2).reshape(b, c, h, w)

Define the downsample Encoder Blocks

In [3]:
class Encoders(nn.Module):
    """Downsampling step: a 3x3 convolution with stride 2 that keeps the channel count."""

    def __init__(self,in_channels):
        """
        in_channels: number of channels of both the input and the output
        """
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    @staticmethod
    def initialize(x):
        """Xavier-uniform initialise ``x.weight`` and zero ``x.bias`` in place."""
        nn.init.xavier_uniform_(x.weight)
        nn.init.zeros_(x.bias)

    def forward(self,x):
        """Return the strided convolution of ``x``."""
        return self.conv(x)

Define the upsample Decoder Blocks

In [None]:
class Decoders(nn.Module):
    """Upsampling step: nearest-neighbour 2x interpolation followed by a 3x3 convolution."""

    def __init__(self, in_channels):
        """
        in_channels: number of channels of both the input and the output
        """
        super().__init__()
        self.in_channels = in_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    @staticmethod
    def initialize(x):
        """Xavier-uniform initialise ``x.weight`` and zero ``x.bias`` in place."""
        nn.init.xavier_uniform_(x.weight)
        nn.init.zeros_(x.bias)

    def forward(self, x):
        """Double the spatial size of ``x``, then apply the convolution."""
        upsampled = F.interpolate(x, scale_factor=2, mode='nearest')
        return self.conv(upsampled)

Define the UNet used in the model

In [None]:
class UNet(nn.Module):
    def __init__(self, 
                 in_channels,
                 out_channels,
                 init_dim,
                 model_dim,
                 theta=10000,
                 attn_heads=8,
                 resnet_blocks=8,
                 dim_mults=(1, 2, 4, 8)
        ):
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.init_dim = init_dim
        self.model_dim = model_dim
        
        self.positional_encoding = SinusodialPositionEncoding(dim=model_dim, theta=theta)
        self.attention = FocusedLinearAttention(dim=model_dim, num_heads=attn_heads)
        self.encoders = nn.ModuleList([])
        self.decoders = nn.ModuleList([])
        
        self.init_conv = nn.Conv2d(in_channels, init_dim, 7, padding=3)
        dim = [init_dim, *map(lambda m:init_dim*m, dim_mults)]
        dim_list = zip(dim[:-1], dim[1:])
        
        whole = len(dim_list)
        last = int >= (whole - 1)

        res_block = partial(ResnetBlock, n_group = resnet_blocks)
        time_dim = dim*4 # time embedding metioned in the DDPM paper
        
        self.time_embedding = nn.Sequential(
            self.positional_encoding(),
            nn.Linear(dim, time_dim),
            nn.GELU(),
            nn.Linear(time_dim, time_dim)
        )
        
        # define the encoder part of the model  
        for i, (dim_in, dim_out) in enumerate(dim_list):
            self.encoders.append(nn.ModuleList([
                res_block(dim_in, dim_in, time_dim),
                res_block(dim_in, dim_in, time_dim),
                Decoders(dim_in) if not last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
            ]))

        # define the bottleneck part of the model
        bottleneck_dim = dim[-1]
        self.bottleneck = nn.Sequential(
            res_block(bottleneck_dim, bottleneck_dim, time_dim),
            res_block(bottleneck_dim, bottleneck_dim, time_dim) 
        )

        # define the decoder part of the model
        for i, (dim_in, dim_out) in reversed(list(enumerate(dim_list))):
            self.decoders.append(nn.ModuleList([
                res_block(dim_in + dim_out, dim_out, time_dim),
                res_block(dim_in + dim_out, dim_out, time_dim),
                Decoders(dim_in) if not last else nn.Conv2d(dim_in, dim_out, 3, padding=1)
            ]))