In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

In [7]:
# Standard VGG conv-trunk configurations, keyed by model name.
# Every VGG has five stages with fixed widths 64/128/256/512/512; a config only
# differs in how many 3x3 convs each stage stacks. Each stage is encoded as
# `n_convs` copies of its width (an int = one 3x3 conv with that many output
# channels) followed by 'M' (one 2x2 max-pool).
# In VGG_Net the trunk is followed by flatten -> FC 4096 -> FC 4096 -> FC num_classes.
VGG_types = {
    name: [
        token
        for width, n_convs in zip((64, 128, 256, 512, 512), convs_per_stage)
        for token in [width] * n_convs + ['M']
    ]
    for name, convs_per_stage in {
        'VGG11': (1, 1, 2, 2, 2),
        'VGG13': (2, 2, 2, 2, 2),
        'VGG16': (2, 2, 3, 3, 3),
        'VGG19': (2, 2, 4, 4, 4),
    }.items()
}

In [5]:
class VGG_Net(nn.Module):
    """VGG-style classifier: a conv/max-pool trunk followed by three fully connected layers.

    Args:
        in_channels: Number of channels of the input images (e.g. 3 for RGB).
        num_classes: Number of logits produced by the final linear layer.
        architecture: Sequence describing the trunk. Each int adds a 3x3,
            stride-1, padding-1 ``nn.Conv2d`` with that many output channels;
            the string ``'M'`` adds a 2x2, stride-2 ``nn.MaxPool2d``.

    Raises:
        ValueError: if ``architecture`` contains an entry that is neither an int nor ``'M'``.
    """

    def __init__(self, in_channels, num_classes, architecture):
        super().__init__()
        # Materialize once so a one-shot iterable can be read twice below.
        architecture = list(architecture)
        self.in_channels = in_channels
        self.conv_layers = self.create_conv_layers(architecture)
        # Width of the trunk output: the last conv's channel count (512 for every
        # standard VGG), or in_channels if the config has no convs at all.
        trunk_channels = next(
            (c for c in reversed(architecture) if isinstance(c, int)), in_channels
        )
        # Resize the trunk output to a fixed 7x7 grid so the classifier is not tied
        # to 224x224 inputs. When the map is already 7x7 (224x224 through five 'M'
        # stages) this is an exact identity. It has no parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.fc1 = nn.Linear(trunk_channels * 7 * 7, 4096)
        self.fc2 = nn.Linear(4096, 4096)
        self.fc3 = nn.Linear(4096, num_classes)

    def create_conv_layers(self, architecture):
        """Build the conv/pool trunk described by ``architecture`` as an ``nn.Sequential``.

        ReLUs are not stored here (``forward`` applies them after each conv), which
        keeps the Sequential indices, and therefore the state_dict keys
        (``conv_layers.0``, ``conv_layers.2``, ...), stable.

        Raises:
            ValueError: on an entry that is neither an int nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels
        for token in architecture:
            if isinstance(token, int):
                layers.append(nn.Conv2d(in_channels, out_channels=token, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
                in_channels = token
            elif token == 'M':
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
            else:
                # Previously any non-int silently became a max-pool, hiding typos in configs.
                raise ValueError(f"Unknown VGG layer spec {token!r}; expected an int or 'M'")
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: Image batch, assumed (batch, in_channels, H, W). H and W must be at
               least 2**(number of 'M' entries) so every max-pool has input to reduce.

        Returns:
            Logits of shape (batch, num_classes).
        """
        for layer in self.conv_layers:
            x = layer(x)
            # Activate only after convolutions, as in VGG. A ReLU after a max-pool of
            # already-ReLU'd activations changes nothing and just costs an extra pass.
            if isinstance(layer, nn.Conv2d):
                x = F.relu(x)

        x = self.avgpool(x)
        # flatten(1) keeps the batch dim and, unlike reshape(n, -1), also works for n == 0.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

        
        
        

In [6]:
# Smoke test: one random 224x224 RGB image through VGG-16 should produce
# one logit per class (1000 here).
model = VGG_Net(in_channels=3, num_classes=1000, architecture=VGG_types['VGG16'])
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():  # shape check only, so don't record an autograd graph
    logits = model(x)
print(logits.shape)
assert logits.shape == (1, 1000)

torch.Size([1, 1000])
