In [2]:
import torch
import torch.nn as nn 



In [19]:
class VGG16(nn.Module):
    def __init__(self, in_channels, n_class): 
        super(VGG16, self).__init__()
        self.arch = [64, 64, 'M',
                     128, 128, 'M',
                     256, 256, 256, 256, 'M',
                     512, 512, 512, 512, 'M',
                     512, 512, 512, 512, 'M']
        self.n_class = n_class
        self.in_channels = in_channels

        
        self.darknet = self.make_layers(self.arch)
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512*7*7, 4096), 
            nn.ReLU(inplace=True), 
            nn.Dropout2d(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True), 
            nn.Dropout2d(),
            nn.Linear(4096, self.n_class))
    
    def make_layers(self, arch, batch_norm=False): 
        layer = []
        in_channel = self.in_channels
        for v in arch:
            if v == 'M': 
                layer += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else: 
                conv2d = nn.Conv2d(in_channels=in_channel, out_channels=v, kernel_size=3, padding=1, stride=1)
                if batch_norm:
                    layer += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
                else: 
                    layer += [conv2d, nn.ReLU(inplace=True)]
                in_channel = v 
        return(nn.Sequential(*layer))

    def forward(self, x):
        x = self.darknet(x)
        x = self.avgpool(x)
        x = self.classifier(x)
        return x 
        


    
        
    

In [20]:
model = VGG16(in_channels=3, n_class=10)
print(model)

VGG16(
  (darknet): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), pad