# Denoising Diffusion Probabilistic Models

In [2]:
import torch 
import torch.nn as nn

In [14]:
class ConvBlock(nn.Module):
    """Two stacked 2-D convolutions, each followed by a ReLU.

    Computes ``relu(conv2(relu(conv1(x))))``.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
        kernel_size: Kernel size used by both convolutions.
        stride: Stride used by both convolutions (it is applied twice).
        padding: Zero-padding used by both convolutions. The default ``0``
            keeps "valid" convolutions: with ``stride=1`` each conv trims
            ``kernel_size - 1`` pixels from every spatial dimension. Pass
            ``padding=kernel_size // 2`` (odd kernel, ``stride=1``) to keep
            the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=0):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # One parameter-free ReLU module is reused after both convolutions.
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        return self.relu(self.conv2(x))
    
class DownBlock(nn.Module):
    """Encoder stage: a chain of ConvBlocks, each followed by 2x2 max-pooling.

    Args:
        filters: Output channel count of each successive ConvBlock; block ``i``
            maps ``filters[i-1]`` (or ``in_channels`` for the first) channels
            to ``filters[i]``. Must be a non-empty sequence.
        in_channels: Channel count of the input tensor.

    ``forward`` returns ``(residual_outputs, x)``: the list of every
    ConvBlock's (un-pooled) output, and the tensor after the final pooling.
    """

    def __init__(self, filters, in_channels):
        super().__init__()
        # First block adapts the input width; the rest chain consecutive filters.
        head = ConvBlock(in_channels, filters[0])
        tail = [ConvBlock(c_in, c_out) for c_in, c_out in zip(filters, filters[1:])]

        self.conv_blocks = nn.Sequential(head, *tail)
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        skips = []
        for block in self.conv_blocks:
            features = block(x)
            # Keep the pre-pool features as the skip connection for this stage.
            skips.append(features)
            x = self.maxpool(features)
        return skips, x

class UpBlock(nn.Module):
    """Decoder stage: upsample, merge with an encoder skip tensor, then convolve.

    ``self.up`` doubles the spatial size and halves the channel count to
    ``in_channels // 2``. The skip tensor is concatenated onto it along the
    channel axis, and the result must have ``in_channels`` channels because
    that is the input width of the first ConvBlock.

    Args:
        filters: Output channel count of each successive ConvBlock (any
            non-empty sequence); block ``i`` maps ``filters[i-1]`` (or
            ``in_channels`` for the first) channels to ``filters[i]``.
        in_channels: Channel count of the tensor passed to ``forward`` as ``x``.
    """

    def __init__(self, filters, in_channels):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        # Same chaining scheme as DownBlock, so any number of stages is supported.
        blocks = [ConvBlock(in_channels, filters[0])]
        blocks += [ConvBlock(c_in, c_out) for c_in, c_out in zip(filters, filters[1:])]
        self.conv_blocks = nn.Sequential(*blocks)

    def forward(self, x, residual):
        """Upsample ``x``, concatenate the cropped ``residual``, and convolve.

        Args:
            x: Decoder input with ``in_channels`` channels.
            residual: Matching encoder skip tensor. Its channels plus
                ``in_channels // 2`` must equal ``in_channels``, and its
                spatial size must be at least that of the upsampled ``x``.

        Returns:
            Output of the last ConvBlock (``filters[-1]`` channels).
        """
        x = self.up(x)
        # Unpadded convolutions leave the skip tensor larger than the upsampled
        # tensor, so center-crop it to match before concatenating.
        residual = self._center_crop(residual, x.shape[-2:])
        x = torch.cat([residual, x], dim=1)
        return self.conv_blocks(x)

    @staticmethod
    def _center_crop(tensor, size):
        """Crop the last two dims of ``tensor`` symmetrically to ``size`` (h, w)."""
        height, width = size
        top = (tensor.shape[-2] - height) // 2
        left = (tensor.shape[-1] - width) // 2
        if top < 0 or left < 0:
            raise ValueError(
                f"skip tensor of spatial size {tuple(tensor.shape[-2:])} is smaller "
                f"than the upsampled tensor {tuple(size)}"
            )
        return tensor[..., top:top + height, left:left + width]



In [16]:
# Encoder with three ConvBlock stages (32, 64, 128 channels) for 3-channel input.
db = DownBlock(filters=[32, 64, 128], in_channels=3)

In [17]:
db  # show the module structure (printed as the cell output below)

DownBlock(
  (conv_blocks): Sequential(
    (0): ConvBlock(
      (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): ConvBlock(
      (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): ConvBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

In [18]:
# Forward a random 64x64 batch. With unpadded 3x3 convs and floor max-pooling,
# `a` is (residual_outputs, pooled) with residual shapes (1, 32, 60, 60),
# (1, 64, 26, 26), (1, 128, 9, 9) and pooled shape (1, 128, 4, 4).
a = db(torch.randn(size=(1, 3, 64, 64)))