In [7]:
from torch import nn
import torch


class VGG(nn.Module):
    """VGG-style CNN: a stack of VGG blocks followed by a three-layer MLP head.

    Args:
        architecture: iterable of ``(num_conv, (in_channels, out_channels))``
            entries, one per VGG block, in order. Each block's ``in_channels``
            must equal the channel count produced by the previous block.
        num_classes: output size of the final linear layer.

    The fully connected layers are ``nn.LazyLinear``, so their input sizes are
    only fixed on the first forward pass (e.g. via ``layer_summary``).

    Raises:
        ValueError: if consecutive blocks disagree on channel counts.
    """

    def __init__(self, architecture, num_classes):
        super().__init__()
        vgg_blocks = []
        prev_channels = None
        for num_conv, (in_channels, out_channels) in architecture:
            # Catch channel mismatches here rather than as a cryptic conv shape
            # error on the first forward pass.
            if prev_channels is not None and in_channels != prev_channels:
                raise ValueError(
                    f"VGG block expects {in_channels} input channels but the "
                    f"previous block outputs {prev_channels}")
            vgg_blocks.append(vgg_block(num_conv, in_channels, out_channels))
            # A block with no convs is just a pool, which keeps the channel count.
            prev_channels = out_channels if num_conv > 0 else in_channels
        self.net = nn.Sequential(
            *vgg_blocks,
            nn.Flatten(),
            nn.LazyLinear(4096), nn.ReLU(),
            nn.Dropout(0.5),
            nn.LazyLinear(4096), nn.ReLU(),
            nn.Dropout(0.5),
            nn.LazyLinear(num_classes))

    def forward(self, X):
        """Run ``X`` through the conv blocks and MLP head; returns unnormalized class scores."""
        return self.net(X)

    def layer_summary(self, X_shape):
        """Print the output shape of each top-level layer for a random input.

        Args:
            X_shape: shape of the random input tensor passed to ``torch.rand``.

        The forward pass also materializes any not-yet-initialized
        ``nn.LazyLinear`` layers.
        """
        X = torch.rand(*X_shape)
        # Shape probe only: skip recording an autograd graph.
        with torch.no_grad():
            for layer in self.net:
                X = layer(X)
                print(layer.__class__.__name__, "output shape: ", X.shape)

    @staticmethod
    def xavier_uniform(layer):
        """Xavier-uniform weight init for ``Conv2d``/``Linear``; bias set to 1e-5.

        Suitable for ``model.apply(VGG.xavier_uniform)``. The exact type check
        means ``nn.LazyLinear`` layers that have not been materialized yet
        (no forward pass so far) are skipped, so run a forward pass first.
        """
        if type(layer) in (nn.Conv2d, nn.Linear):
            torch.nn.init.xavier_uniform_(layer.weight)
            # Layers built with bias=False have no bias tensor to fill.
            if layer.bias is not None:
                torch.nn.init.constant_(layer.bias, 0.00001)


def vgg_block(num_conv, num_InChannels, num_OutChannels):
    """Build one VGG block: ``num_conv`` (3x3 conv + ReLU) pairs, then a 2x2 max-pool.

    Args:
        num_conv: number of convolution layers in the block (0 gives a bare pool).
        num_InChannels: channel count of the block's input.
        num_OutChannels: channel count produced by every conv in the block.

    Returns:
        ``nn.Sequential``. The convs preserve spatial size (padding=1); the
        pool halves height and width (rounding down).
    """
    layers = []
    in_channels = num_InChannels
    for _ in range(num_conv):
        layers.append(nn.Conv2d(in_channels, num_OutChannels, kernel_size=3, padding=1))
        layers.append(nn.ReLU())
        # Every conv after the first consumes the previous conv's output.
        in_channels = num_OutChannels
    layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
    return nn.Sequential(*layers)

In [23]:
# One entry per VGG block: [num_conv, (in_channels, out_channels)].
# NOTE(review): seven blocks means seven 2x2 pools, so a 224x224 input ends at
# 1x1. The last two pools round down (7 -> 3, 3 -> 1) and drop a row/column
# each time. Confirm this depth is intended.
arch = [
    [1, (1, 64)],
    [1, (64, 128)],
    [1, (128, 256)],
    [2, (256, 256)],
    [1, (256, 512)],
    [2, (512, 512)],
    [2, (512, 512)],
]

In [24]:
# 10 output classes; the LazyLinear head gets its input size on the first forward pass.
model = VGG(arch, num_classes=10)

In [25]:
# Push one random 1-channel 224x224 input through the net to print per-layer
# output shapes (this also materializes the LazyLinear layers).
model.layer_summary(X_shape=(1, 1, 224, 224))

Sequential output shape:  torch.Size([1, 64, 112, 112])
Sequential output shape:  torch.Size([1, 128, 56, 56])
Sequential output shape:  torch.Size([1, 256, 28, 28])
Sequential output shape:  torch.Size([1, 256, 14, 14])
Sequential output shape:  torch.Size([1, 512, 7, 7])
Sequential output shape:  torch.Size([1, 512, 3, 3])
Sequential output shape:  torch.Size([1, 512, 1, 1])
Flatten output shape:  torch.Size([1, 512])
Linear output shape:  torch.Size([1, 4096])
ReLU output shape:  torch.Size([1, 4096])
Dropout output shape:  torch.Size([1, 4096])
Linear output shape:  torch.Size([1, 4096])
ReLU output shape:  torch.Size([1, 4096])
Dropout output shape:  torch.Size([1, 4096])
Linear output shape:  torch.Size([1, 10])
