In [132]:
import torch 
from torch import nn
from torchvision import models

In [133]:
class VGG_Net(nn.Module):
    """VGG-style image classifier supporting the 11/13/16/19-layer layouts.

    The network is a stack of 3x3 convolutions and 2x2 max-pools
    (``conv_layers``), an adaptive average pool to a 7x7 grid (``avgpool``)
    and a three-layer fully connected classifier (``fcs``).

    Args:
        in_channels: Number of channels of the input images.
        num_classes: Size of the final linear layer's output.
        architecture: One of the keys of ``ARCHITECTURES`` ('VGG_11',
            'VGG_13', 'VGG_16', 'VGG_19'). ``None`` selects
            ``DEFAULT_ARCHITECTURE`` ('VGG_19').

    Raises:
        ValueError: If ``architecture`` is neither ``None`` nor a known key.
    """

    # Layer plans: an int is the out_channels of a conv layer, 'M' is a max-pool.
    ARCHITECTURES = {
        'VGG_11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG_13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
        'VGG_16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
        'VGG_19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
    }

    # Used when no architecture name is given, so VGG_Net() keeps building VGG-19.
    DEFAULT_ARCHITECTURE = 'VGG_19'

    def __init__(self, in_channels=3, num_classes=1000, architecture=None):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes
        self.architecture = architecture

        self.conv_layers = self.create_conv_layers()

        # Resamples the conv output to 7x7 so the first Linear layer always
        # receives 512*7*7 features. With 224x224 inputs the five pooling
        # stages already yield 7x7 maps, so this leaves them unchanged.
        # The layer has no parameters, so state_dict keys are unaffected.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # NOTE(review): the VGG paper and torchvision use a dropout ratio of
        # 0.5; p=0.05 is kept as found - confirm it is intentional.
        self.fcs = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.Dropout(p=0.05),
            nn.ReLU(),

            nn.Linear(4096, 4096),
            nn.Dropout(p=0.05),
            nn.ReLU(),

            nn.Linear(4096, self.num_classes)
        )

    def forward(self, x):
        """Run the network on a batch of images.

        Args:
            x: Input batch; expected layout is (N, in_channels, H, W).

        Returns:
            Tensor of shape (N, num_classes) holding the raw outputs of the
            last linear layer (no softmax is applied).
        """
        x = self.conv_layers(x)
        x = self.avgpool(x)
        # Keep the batch dimension, collapse channels and spatial dims.
        x = x.flatten(start_dim=1)
        x = self.fcs(x)

        return x

    def create_arch(self, arch):
        """Translate a layer plan into an ``nn.Sequential``.

        Args:
            arch: Iterable of tokens. An ``int`` adds a 3x3 convolution
                (stride 1, padding 1) with that many output channels followed
                by ReLU; the string ``'M'`` adds a 2x2 max-pool with stride 2.

        Returns:
            ``nn.Sequential`` with the layers in plan order. The first
            convolution takes ``self.in_channels`` input channels.

        Raises:
            ValueError: If a token is neither an ``int`` nor ``'M'``.
        """
        in_channels = self.in_channels
        layers = []

        for token in arch:
            if isinstance(token, int):
                layers += [
                    nn.Conv2d(in_channels=in_channels, out_channels=token,
                              kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                ]
                # The next convolution consumes what this one produced.
                in_channels = token
            elif token == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                # Fail loudly instead of silently inserting a pooling layer.
                raise ValueError(
                    f"Unsupported layer token {token!r}; expected an int or 'M'"
                )

        return nn.Sequential(*layers)

    def create_conv_layers(self):
        """Build the convolutional stack for ``self.architecture``.

        Returns:
            ``nn.Sequential`` produced by ``create_arch`` for the selected plan.

        Raises:
            ValueError: If ``self.architecture`` is not ``None`` and not a
                key of ``ARCHITECTURES``.
        """
        name = self.DEFAULT_ARCHITECTURE if self.architecture is None else self.architecture

        if name not in self.ARCHITECTURES:
            # A misspelt name used to fall through to VGG-19 without warning.
            raise ValueError(
                f"Unknown architecture {self.architecture!r}; "
                f"expected one of {sorted(self.ARCHITECTURES)}"
            )

        return self.create_arch(self.ARCHITECTURES[name])


In [134]:
# Build the 'VGG_19' configuration for 3-channel inputs and 1000 output classes.
model = VGG_Net(in_channels=3, num_classes=1000, architecture='VGG_19')
# Bare last expression: the notebook renders the module's layer listing.
model

VGG_Net(
  (conv_layers): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (17): ReLU()
    (18): MaxPool2d(kernel_size=2, stride=2, pad

In [135]:
# Smoke test: push one random 3x224x224 image through the model and check
# that one score per class comes out.
image = torch.randn(1, 3, 224, 224)
print(image.shape)

# Only the output shape is inspected, so skip autograd bookkeeping.
with torch.no_grad():
    x = model(image)
print(x.shape)

# Fail the cell instead of relying on someone reading the printed shape.
assert x.shape == (image.shape[0], model.num_classes), "unexpected output shape"

torch.Size([1, 3, 224, 224])
torch.Size([1, 1000])


In [136]:
# torchvision's own VGG-19, created with default arguments; presumably shown
# here for a side-by-side comparison with VGG_Net's layer listing.
torch_vgg = models.vgg19()
# Bare last expression: the notebook renders the module's layer listing.
torch_vgg

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd