In [2]:
import torch

In [3]:
# 3x3 conv, stride 2, "same" padding ((3 - 1) // 2 == 1): maps 1 -> 16 channels
# and halves even spatial sizes (e.g. 28x28 -> 14x14).
net = torch.nn.Conv2d(
    in_channels=1,
    out_channels=16,
    kernel_size=3,
    stride=2,
    padding=(3 - 1) // 2,
)

In [4]:
# Dummy batch: one single-channel 28x28 sample drawn from a standard normal.
x = torch.randn((1, 1, 28, 28))
print(x.size())

torch.Size([1, 1, 28, 28])


In [5]:
# Forward pass through the stride-2 conv defined above; the spatial size is
# halved, so (1, 1, 28, 28) becomes (1, 16, 14, 14).
y = net(x)
print(y.size())

torch.Size([1, 16, 14, 14])


In [13]:
class ConvNet(torch.nn.Module):
    """Fully convolutional net: strided stem conv, ``n_blocks`` downsampling
    ``Block`` stages, then a 1x1 conv head producing a single output channel.

    Every stride-2 stage maps a spatial size ``s`` to ``ceil(s / 2)``, so an
    ``H x W`` input yields a 1-channel map of size
    ``ceil(H / 2**(n_blocks + 1)) x ceil(W / 2**(n_blocks + 1))``,
    or ``1 x 1`` when ``global_pool=True``.

    Args:
        channels_10: Output channels of the stem conv; each block doubles it.
        n_blocks: Number of downsampling ``Block`` stages.
        in_channels: Channels of the input tensor (default 3).
        global_pool: If True, append ``AdaptiveAvgPool2d(1)`` so the output is
            ``(B, 1, 1, 1)`` regardless of input size. Default False keeps the
            original spatial-map output and layer layout.
    """

    class Block(torch.nn.Module):
        """Three 3x3 conv + ReLU layers; only the first conv applies ``stride``.

        Args:
            in_channels: Channels of the block input.
            out_channels: Channels produced by every conv in the block.
            stride: Stride of the first conv (2 halves the spatial size,
                rounding up).
        """

        def __init__(self, in_channels, out_channels, stride):
            super().__init__()
            kernel_size = 3
            # "Same" padding for an odd kernel: size is preserved at stride 1.
            padding = (kernel_size - 1) // 2
            self.c1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
            self.c2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
            self.c3 = torch.nn.Conv2d(out_channels, out_channels, kernel_size, 1, padding)
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            x = self.relu(self.c1(x))
            x = self.relu(self.c2(x))
            x = self.relu(self.c3(x))
            return x

    def __init__(self, channels_10=64, n_blocks=4, in_channels=3, global_pool=False):
        super().__init__()
        cnn_layers = [
            # Large-kernel stem; padding 5 == (11 - 1) // 2 keeps it "same"-padded.
            torch.nn.Conv2d(in_channels, channels_10, kernel_size=11, stride=2, padding=5),
            torch.nn.ReLU(),
        ]
        c1 = channels_10
        for _ in range(n_blocks):
            c2 = c1 * 2
            cnn_layers.append(self.Block(c1, c2, stride=2))
            c1 = c2
        # 1x1 conv head collapses the channels to a single output map.
        cnn_layers.append(torch.nn.Conv2d(c1, 1, kernel_size=1))
        if global_pool:
            # Correct class name is AdaptiveAvgPool2d (AdaptiveAveragePool2d does not exist).
            cnn_layers.append(torch.nn.AdaptiveAvgPool2d(1))
        self.network = torch.nn.Sequential(*cnn_layers)

    def forward(self, x):
        """Apply the network to ``x`` of shape (B, in_channels, H, W)."""
        return self.network(x)

net = ConvNet(n_blocks=3)
x = torch.randn(1, 3, 64, 64)
print(net)

# Smoke-test the forward pass: the stem and the 3 blocks are 4 stride-2 stages,
# so the 64x64 input shrinks to a single-channel 4x4 map.
with torch.no_grad():
    out = net(x)
print(out.shape)
assert out.shape == (1, 1, 4, 4)

ConvNet(
  (network): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(2, 2), padding=(5, 5))
    (1): ReLU()
    (2): Block(
      (c1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (c2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (c3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
    )
    (3): Block(
      (c1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (c2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (c3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
    )
    (4): Block(
      (c1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (c2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (c3): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
    )
    (5): Conv2d(512, 1, ke