In [4]:
import torch
import torch.nn as nn
from torchinfo import summary
import torch.nn.functional as F

In [41]:
class VGG16(nn.Module):
    """VGG-16 style classifier: 13 conv layers, 5 max-pools and 3 linear layers.

    Every convolution is 3x3 with ``padding=1`` and every ``MaxPool2d(2)``
    halves height and width. ``fc1`` takes ``7*7*512`` flattened features,
    i.e. the network expects 224x224 inputs (224 -> 112 -> 56 -> 28 -> 14 -> 7).

    Args:
        num_classes: Output width of the last linear layer ``fc3``.
            Defaults to 1000.
    """

    def __init__(self, num_classes=1000):
        super().__init__()
        # 1
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool1 = nn.MaxPool2d(2)  # 112 * 112
        # 2
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1)
        self.pool2 = nn.MaxPool2d(2)  # 56 * 56
        # 3
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1)
        self.conv7 = nn.Conv2d(256, 256, 3, padding=1)
        self.pool3 = nn.MaxPool2d(2)  # 28 * 28
        # 4
        self.conv8 = nn.Conv2d(256, 512, 3, padding=1)
        self.conv9 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv10 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool4 = nn.MaxPool2d(2)  # 14 * 14
        # 5
        self.conv11 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv12 = nn.Conv2d(512, 512, 3, padding=1)
        self.conv13 = nn.Conv2d(512, 512, 3, padding=1)
        self.pool5 = nn.MaxPool2d(2)  # 7 * 7
        # 6
        self.fc1 = nn.Linear(7*7*512, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        """Run the network and return softmax probabilities.

        Args:
            x: Input tensor of shape ``(batch, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(batch, num_classes)``; each row sums to 1
            because softmax is applied over dim 1. These are probabilities,
            not logits.
        """
        # 1
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool1(x)
        # 2
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = self.pool2(x)
        # 3
        x = F.relu(self.conv5(x))
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = self.pool3(x)
        # 4
        x = F.relu(self.conv8(x))
        x = F.relu(self.conv9(x))
        x = F.relu(self.conv10(x))
        x = self.pool4(x)
        # 5
        x = F.relu(self.conv11(x))
        x = F.relu(self.conv12(x))
        x = F.relu(self.conv13(x))
        x = self.pool5(x)
        # 6
        # Flatten everything except the batch axis. Unlike view(-1, 7*7*512),
        # this never folds a wrongly sized input into the batch dimension; a
        # bad input size surfaces as a shape error in fc1 instead.
        x = x.flatten(1)
        # F.dropout defaults to training=True, so the module's mode must be
        # passed explicitly or dropout would stay active after net.eval().
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu(self.fc2(x))
        return F.softmax(self.fc3(x), dim=1)

In [42]:
# Dummy 10 x 3 x 224 x 224 batch filled with ones, used to exercise the network.
data = torch.ones(10, 3, 224, 224)

In [43]:
# Instantiate the network with freshly initialised (untrained) weights.
net = VGG16()

In [44]:
# This forward pass is only a shape check, so disable autograd: otherwise the
# graph and its saved activations stay alive for as long as `output` does.
with torch.no_grad():
    output = net(data)

In [45]:
# Display the result shape: one row per input sample, one column per class.
output.shape

torch.Size([10, 1000])

In [46]:
# Per-layer output shapes and parameter counts for a batch of ten 3x224x224 inputs.
summary(net, input_size=(10, 3, 224, 224))

Layer (type:depth-idx)                   Output Shape              Param #
VGG16                                    [10, 1000]                --
├─Conv2d: 1-1                            [10, 64, 224, 224]        1,792
├─Conv2d: 1-2                            [10, 64, 224, 224]        36,928
├─MaxPool2d: 1-3                         [10, 64, 112, 112]        --
├─Conv2d: 1-4                            [10, 128, 112, 112]       73,856
├─Conv2d: 1-5                            [10, 128, 112, 112]       147,584
├─MaxPool2d: 1-6                         [10, 128, 56, 56]         --
├─Conv2d: 1-7                            [10, 256, 56, 56]         295,168
├─Conv2d: 1-8                            [10, 256, 56, 56]         590,080
├─Conv2d: 1-9                            [10, 256, 56, 56]         590,080
├─MaxPool2d: 1-10                        [10, 256, 28, 28]         --
├─Conv2d: 1-11                           [10, 512, 28, 28]         1,180,160
├─Conv2d: 1-12                           [10, 5