In [1]:
import torch
from torch import nn

In [2]:
class D_block(nn.Module):
        def __init__(self, channels, nc=3, kernel_size=4, strides=2,
                 padding=1, alpha=0.2, **kwargs):
            super(D_block, self).__init__(**kwargs)
            self.conv2d = nn.Conv2d(
                nc, channels, kernel_size, strides, padding, bias=False)
            self.batch_norm = nn.BatchNorm2d(channels)
            self.activation = nn.LeakyReLU(alpha, inplace=True)

        def forward(self, X):
            return self.activation(self.batch_norm(self.conv2d(X)))

In [3]:
x = torch.zeros((2, 3, 16, 16))
d_blk = D_block(20)
d_blk(x).shape

torch.Size([2, 20, 8, 8])

In [4]:
def Conv2D(channels, kernel_size, use_bias, nc=3):
    return nn.Conv2d(nc, channels, kernel_size=kernel_size, bias=use_bias)

In [5]:
n_D = 64
net_D = nn.Sequential(
    D_block(n_D),    # Output: (64, 32, 32)
    D_block(n_D*2),  # Output: (64 * 2, 16, 16)
    D_block(n_D*4),  # Output: (64 * 4, 8, 8)
    D_block(n_D*8),  # Output: (64 * 8, 4, 4)
    Conv2D(1, kernel_size=4, use_bias=False)  # Output: (1, 1, 1)
)

In [6]:
x = torch.zeros((1, 3, 64, 64))
net_D(x).shape

RuntimeError: Given groups=1, weight of size [128, 3, 4, 4], expected input[1, 64, 32, 32] to have 3 channels, but got 64 channels instead

https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html

padding_mode (string, optional) – 'zeros', 'reflect', 'replicate' or 'circular'. Default: 'zeros'

dilation (int or tuple, optional) – Spacing between kernel elements. Default: 1

groups (int, optional) – Number of blocked connections from input channels to output channels. Default: 1

bias (bool, optional) – If True, adds a learnable bias to the output. Default: True

In [7]:
n_D = 64
net_D1 = nn.Sequential(
    D_block(n_D))

In [8]:
x = torch.zeros((1, 3, 64, 64))
net_D1(x).shape

torch.Size([1, 64, 32, 32])

In [9]:
n_D = 64
net_D2 = nn.Sequential(
    D_block(n_D),    # Output: (64, 32, 32)
    D_block(n_D*2),  # Output: (64 * 2, 16, 16)
)

In [10]:
x = torch.zeros((1, 3, 64, 64))
net_D2(x).shape

RuntimeError: Given groups=1, weight of size [128, 3, 4, 4], expected input[1, 64, 32, 32] to have 3 channels, but got 64 channels instead