In [2]:
import torch 
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchvision.datasets as datasets
import torchvision.transforms as transforms


In [20]:
# Layer configurations for the VGG family.
# Each int is the output-channel count of a 3x3 conv layer; 'M' marks a
# 2x2 / stride-2 max-pool. The number in each name counts the conv layers
# plus the three fully-connected layers of the classifier head.
VGG_types = dict(
    VGG16=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 'M',
           512, 512, 512, 'M',
           512, 512, 512, 'M'],
    VGG19=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M',
           512, 512, 512, 512, 'M'],
    VGG11=[64, 'M',
           128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
    VGG13=[64, 64, 'M',
           128, 128, 'M',
           256, 256, 'M',
           512, 512, 'M',
           512, 512, 'M'],
)

In [25]:
class VGG_net(nn.Module):
    """VGG-style image classifier built from a layer-configuration list.

    The convolutional trunk is generated from ``architecture`` by
    ``create_conv_layers``; the classifier head is three fully-connected
    layers (4096 -> 4096 -> ``num_classes``). Unlike the original VGG, every
    conv is followed by BatchNorm2d.

    Args:
        in_channels: Number of channels in the input images.
        num_classes: Number of output logits.
        architecture: Either a key of the module-level ``VGG_types`` dict
            (e.g. ``'VGG16'``) or an explicit configuration list of ints
            (conv output channels) and ``'M'`` (max-pool). Defaults to
            ``'VGG19'``, the configuration this class previously hard-coded.
        dropout: Dropout probability applied after each of the two hidden
            fully-connected layers.
    """

    def __init__(self, in_channels=3, num_classes=1000, architecture='VGG19', dropout=0.5):
        super().__init__()

        self.in_channels = in_channels
        self.num_classes = num_classes

        # Accept a named preset or a raw configuration list.
        if isinstance(architecture, str):
            architecture = VGG_types[architecture]
        self.conv_layers = self.create_conv_layers(architecture)

        # Channels leaving the trunk: width of the last conv, or the input
        # channels if the configuration contains no conv at all.
        trunk_channels = next(
            (spec for spec in reversed(architecture) if isinstance(spec, int)),
            in_channels,
        )

        # Pool to a fixed 7x7 grid so the head's input size does not depend on
        # the image resolution. For 224x224 inputs to a five-pool config the
        # trunk already yields 7x7, where this is an identity op; it has no
        # parameters, so state_dict keys are unchanged.
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        self.fc = nn.Sequential(
            nn.Linear(trunk_channels * 7 * 7, 4096),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            # Previously p=0 (a no-op); VGG applies the same dropout to both
            # hidden fully-connected layers.
            nn.Dropout(p=dropout),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return logits of shape (batch, num_classes).

        Assumes ``x`` is an image batch shaped (batch, in_channels, H, W).
        """
        features = self.conv_layers(x)
        features = self.avgpool(features)
        features = torch.flatten(features, 1)
        return self.fc(features)

    def create_conv_layers(self, architecture):
        """Build the convolutional trunk from a VGG configuration list.

        Each int ``c`` adds a 3x3 / stride-1 / padding-1 Conv2d with ``c``
        output channels, followed by BatchNorm2d and ReLU. Each ``'M'`` adds
        a 2x2 / stride-2 MaxPool2d.

        Raises:
            ValueError: if an entry is neither an int nor ``'M'``.
        """
        layers = []
        in_channels = self.in_channels

        for spec in architecture:
            if isinstance(spec, int):
                layers += [
                    nn.Conv2d(in_channels=in_channels, out_channels=spec,
                              kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)),
                    nn.BatchNorm2d(spec),  # Not in original VGG Net
                    nn.ReLU(),
                ]
                in_channels = spec
            elif spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2)))
            else:
                # Unknown entries used to be skipped silently, hiding typos.
                raise ValueError(f"Unknown VGG layer spec {spec!r}; expected an int or 'M'")

        return nn.Sequential(*layers)
                

In [26]:
# Seed before building the model and dummy input so the randomly initialised
# weights and the random batch are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

model = VGG_net(in_channels=3, num_classes=1000)
# Dummy batch: one 3-channel 224x224 image, the standard VGG input size.
x = torch.randn(1, 3, 224, 224)

In [27]:
# Shape sanity check only, so skip autograd bookkeeping.
with torch.no_grad():
    logits = model(x)
assert logits.shape == (x.shape[0], model.num_classes)
logits.shape

torch.Size([1, 1000])
