# Denoising Diffusion Probabilistic Models

In [126]:
import torch 
import torch.nn as nn

In [127]:
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super(ConvBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride)
    
    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))
    
class DownBlock(nn.Module):
    def __init__(self, filters, in_channels):
        super(DownBlock, self).__init__()
        conv_blocks = [ConvBlock(in_channels, filters[0])]
        for i in range(1, len(filters)):
            conv_blocks.append(ConvBlock(filters[i-1], filters[i]))

        self.conv_blocks = nn.Sequential(*conv_blocks)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        residual_outputs = []
        for conv_block in self.conv_blocks:
            x = conv_block(x)
            residual_outputs.append(x)
            x = self.maxpool(x)

        return residual_outputs, x

class UpBlock(nn.Module):
    def __init__(self, filters):
        super(UpBlock, self).__init__()
        layers = []
        for i in range(len(filters) - 2):
            layers.append(
                nn.Sequential(
                    ConvBlock(filters[i], filters[i+1]), 
                    nn.ConvTranspose2d(filters[i+1], filters[i+1]//2, 2, stride=2)
                )
            )
        
        layers.append(ConvBlock(filters[-2], filters[-1]))
        self.layers = nn.Sequential(*layers)
    
    def forward(self, x, residual_outputs):
        for i in range(len(self.layers)):
            print(f"i: {i}")
            residual = residual_outputs[-(i+1)]
            _, _, h, w = x.shape
            residual = residual[:, :, :h, :w]
            print(f"x: {x.shape}, residual: {residual.shape}")
            x = torch.cat([x, residual], dim=1)
            x = self.layers[i](x)
            print(f"x: {x.shape}")
        
        return x


In [128]:
db = DownBlock([32, 64, 128], 3)

In [129]:
db

DownBlock(
  (conv_blocks): Sequential(
    (0): ConvBlock(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [130]:
a = db(torch.randn(1, 3, 128, 128))
a[1].shape

torch.Size([1, 128, 12, 12])

In [131]:
for residual_output in a[0]:
    print(residual_output.shape)
    print("-"*10)

print(a[1].shape)

torch.Size([1, 32, 124, 124])
----------
torch.Size([1, 64, 58, 58])
----------
torch.Size([1, 128, 25, 25])
----------
torch.Size([1, 128, 12, 12])


In [132]:
bottom_conv = nn.Sequential(ConvBlock(128, 256), nn.ConvTranspose2d(256, 128, 2, stride=2))
bottom_conv(a[1]).shape

torch.Size([1, 128, 16, 16])

In [133]:
up = UpBlock([256, 128, 64, 32])



In [134]:
for layer in up.layers:
    print(layer)
    print("-"*10)

Sequential(
  (0): ConvBlock(
    (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
  )
  (1): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
)
----------
Sequential(
  (0): ConvBlock(
    (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
  )
  (1): ConvTranspose2d(64, 32, kernel_size=(2, 2), stride=(2, 2))
)
----------
ConvBlock(
  (conv1): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1))
  (relu): ReLU(inplace=True)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
)
----------


In [135]:
bottom_out = bottom_conv(a[1])
bottom_out.shape


torch.Size([1, 128, 16, 16])

In [136]:
up_out = up(bottom_out, a[0])

i: 0
x: torch.Size([1, 128, 16, 16]), residual: torch.Size([1, 128, 16, 16])
x: torch.Size([1, 64, 24, 24])
i: 1
x: torch.Size([1, 64, 24, 24]), residual: torch.Size([1, 64, 24, 24])
x: torch.Size([1, 32, 40, 40])
i: 2
x: torch.Size([1, 32, 40, 40]), residual: torch.Size([1, 32, 40, 40])
x: torch.Size([1, 32, 36, 36])
