In [30]:
import torch
from torch import nn
from typing import Tuple


def vgg_block(n_convs: int, in_ch: int, out_ch: int) -> nn.Sequential:
    """Build one VGG block: ``n_convs`` conv+ReLU pairs followed by a 2x2 max-pool.

    Every convolution is 3x3 with padding 1, so it keeps the spatial size; the
    trailing ``MaxPool2d(kernel_size=2, stride=2)`` halves it (rounding down).

    Args:
        n_convs: Number of Conv2d+ReLU pairs (0 yields a block with only the pool).
        in_ch: Channels entering the first convolution.
        out_ch: Channels produced by every convolution in the block.

    Returns:
        An ``nn.Sequential`` containing the convolutions, activations and pool.
    """
    # Channel count at each stage boundary: in_ch -> out_ch -> out_ch -> ...
    channels = [in_ch] + [out_ch] * n_convs
    stages = []
    for c_in, c_out in zip(channels[:-1], channels[1:]):
        stages += [nn.Conv2d(c_in, c_out, kernel_size=3, padding=1), nn.ReLU()]
    return nn.Sequential(*stages, nn.MaxPool2d(kernel_size=2, stride=2))


def vgg(
    conv_arch: Tuple[Tuple[int, int], ...],
    in_channels: int = 1,
    in_dims: Tuple[int, int] = (244, 244),
    n_classes: int = 10,
    hidden_dim: int = 4096,
    dropout: float = 0.5,
) -> nn.Sequential:
    """Build a VGG-style network: a stack of VGG blocks plus a 3-layer MLP head.

    Args:
        conv_arch: One ``(num_convs, out_channels)`` pair per VGG block. Each
            block ends in a 2x2 / stride-2 max-pool, which floors the spatial
            size by 2.
        in_channels: Channels of the input image.
        in_dims: ``(height, width)`` of the input image; only used to size the
            first fully-connected layer. NOTE(review): the canonical VGG input
            is 224x224, so 244 may be a typo. It is kept for backward
            compatibility; with a 5-block arch both sizes give a 7x7 map.
        n_classes: Number of output logits.
        hidden_dim: Width of the two hidden fully-connected layers.
        dropout: Dropout probability applied after each hidden layer.

    Returns:
        An ``nn.Sequential`` mapping ``(N, in_channels, H, W)`` inputs (with
        ``(H, W) == in_dims``) to ``(N, n_classes)`` logits.

    Raises:
        ValueError: If the pooling stages shrink ``in_dims`` to zero, which
            would otherwise build an unusable ``Linear(0, hidden_dim)`` layer.
    """
    conv_blks = []
    height, width = in_dims
    # The convolutional part
    for num_convs, out_channels in conv_arch:
        conv_blks.append(vgg_block(num_convs, in_channels, out_channels))
        in_channels = out_channels
        # Mirrors MaxPool2d(kernel_size=2, stride=2, ceil_mode=False) at the
        # end of each block: odd sizes are rounded down.
        height, width = height // 2, width // 2

    if height < 1 or width < 1:
        raise ValueError(
            f"{len(conv_blks)} pooling stages reduce in_dims={tuple(in_dims)} to "
            f"({height}, {width}); use a larger input or fewer blocks"
        )

    # `in_channels` now holds the channel count of the last block (or the
    # input's, if conv_arch is empty), so this is safe for an empty arch.
    flat_features = in_channels * height * width

    return nn.Sequential(
        *conv_blks,
        nn.Flatten(),
        # The fully-connected part
        nn.Linear(flat_features, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_dim, n_classes),
    )

(122, 122)
(61, 61)
(30, 30)
(15, 15)
(7, 7)


Sequential(
  (0): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (1): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (2): Sequential(
    (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (3): Sequential(
    (0): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (4):

In [36]:
# Shape walk-through: push one random image through each top-level layer of
# the network and report the output shape after every layer. Only shapes are
# inspected, so autograd is disabled to avoid retaining activations of the
# large 500x500 input for a backward pass that never happens.
torch.manual_seed(0)  # reproducible random input and weight initialisation
dims = (500, 500)
X = torch.randn(size=(1, 3, dims[0], dims[1]))
conv_arch = ((1, 64), (1, 128), (2, 256), (2, 512), (2, 512))
net = vgg(conv_arch, in_channels=3, in_dims=dims)
with torch.no_grad():
    for blk in net:
        X = blk(X)
        print(blk.__class__.__name__, 'output shape:\t', X.shape)

(250, 250)
(125, 125)
(62, 62)
(31, 31)
(15, 15)
Sequential output shape:	 torch.Size([1, 64, 250, 250])
Sequential output shape:	 torch.Size([1, 128, 125, 125])
Sequential output shape:	 torch.Size([1, 256, 62, 62])
Sequential output shape:	 torch.Size([1, 512, 31, 31])
Sequential output shape:	 torch.Size([1, 512, 15, 15])
Flatten output shape:	 torch.Size([1, 115200])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 4096])
ReLU output shape:	 torch.Size([1, 4096])
Dropout output shape:	 torch.Size([1, 4096])
Linear output shape:	 torch.Size([1, 10])
