In [48]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt


In [49]:
# VGG feature-extractor layouts, keyed by model name. Each int is the output
# channel count of a 3x3 conv (followed by ReLU); 'm' marks a 2x2, stride-2
# max-pool. The four entries are the plain-conv VGG-11/13/16/19 layouts
# (configs A, B, D, E of Simonyan & Zisserman, 2014); the LRN and 1x1-conv
# variants are intentionally left out. One line per pooling stage below.
configs = {
  'VGG11': [
    64, 'm',
    128, 'm',
    256, 256, 'm',
    512, 512, 'm',
    512, 512, 'm',
  ],
  'VGG13': [
    64, 64, 'm',
    128, 128, 'm',
    256, 256, 'm',
    512, 512, 'm',
    512, 512, 'm',
  ],
  'VGG16': [
    64, 64, 'm',
    128, 128, 'm',
    256, 256, 256, 'm',
    512, 512, 512, 'm',
    512, 512, 512, 'm',
  ],
  'VGG19': [
    64, 64, 'm',
    128, 128, 'm',
    256, 256, 256, 256, 'm',
    512, 512, 512, 512, 'm',
    512, 512, 512, 512, 'm',
  ],
}


In [66]:
class VGG(nn.Module):
  """VGG network: stacked 3x3 conv/ReLU/max-pool stages plus a 3-layer MLP head.

  Args:
    feats: layout of the feature extractor. Each int is the output channel
      count of a 3x3 conv (stride 1, padding 1) followed by ReLU; the string
      'm' inserts a 2x2, stride-2 max-pool. Any iterable is accepted.
    classes: number of output logits.
    in_channels: channel count of the input images (3 for RGB).
    hidden: width of the two hidden fully-connected layers.
    dropout: dropout probability after each hidden fully-connected layer.

  forward(x) returns raw, unbounded logits of shape (batch, classes), suitable
  for nn.CrossEntropyLoss or softmax.
  """

  def __init__(self, feats, classes=1000, in_channels=3, hidden=4096, dropout=0.5):
    super().__init__()
    # Materialize once: the layout is walked twice (conv stack + head width),
    # which would silently see an empty second pass for a generator.
    feats = list(feats)
    self.net = self.make_layers(feats, in_channels)
    # Channels leaving the conv stack (512 for every standard config), derived
    # from the layout so custom layouts still get a correctly sized head.
    out_ch = self._last_channels(feats, in_channels)
    self.classifier = nn.Sequential(
      # Fixed 7x7 grid, so the head's input size does not depend on the
      # input resolution.
      nn.AdaptiveAvgPool2d((7, 7)),
      nn.Flatten(),
      nn.Linear(out_ch * 7 * 7, hidden),
      nn.ReLU(),
      nn.Dropout(dropout),
      nn.Linear(hidden, hidden),
      nn.ReLU(),
      nn.Dropout(dropout),
      # Deliberately no activation after the last layer: a ReLU here would
      # clamp every negative logit to 0 and cripple softmax/cross-entropy.
      nn.Linear(hidden, classes),
    )
    self.init_weights()

  def forward(self, x):
    """Run the conv feature extractor, then the classifier head; returns logits."""
    x = self.net(x)
    return self.classifier(x)

  def init_weights(self):
    """Initialise every Conv2d and Linear layer of the model in place.

    Convs: Kaiming-normal (fan_out, ReLU gain). Linear: N(0, 0.01). Biases: 0.
    """
    for module in self.modules():
      if isinstance(module, nn.Conv2d):
        # A fixed N(0, 0.01) on every conv shrinks activations layer after
        # layer in this BN-free stack, making deep VGGs hard to train from
        # scratch; He init keeps the forward variance roughly constant.
        nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
        nn.init.zeros_(module.bias)
      elif isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, 0, 0.01)
        nn.init.zeros_(module.bias)

  def make_layers(self, layer_info, in_ch=3):
    """Build the conv feature extractor described by `layer_info`.

    Args:
      layer_info: iterable of ints (3x3 conv + ReLU with that many output
        channels) and 'm' (2x2, stride-2 max-pool).
      in_ch: channel count of the tensor fed to the first conv.

    Returns:
      nn.Sequential of the layers in order.
    """
    layers = []
    for item in layer_info:
      if item == 'm':
        layers.append(nn.MaxPool2d(2, 2))
      else:
        layers.append(nn.Conv2d(in_ch, item, 3, 1, 1))
        layers.append(nn.ReLU())
        in_ch = item
    return nn.Sequential(*layers)

  @staticmethod
  def _last_channels(layer_info, in_ch):
    """Channel count after `layer_info` is applied (in_ch when it has no convs)."""
    for item in reversed(layer_info):
      if item != 'm':
        return item
    return in_ch

In [67]:
# Seed PyTorch's global RNG so the random weight initialisation (and any later
# torch.rand calls) is reproducible under Restart & Run All.
torch.manual_seed(0)
# Example: build VGG-11 from its layout in `configs`, with the default 1000-way head.
vgg11 = VGG(configs['VGG11'])

[64, 'm', 128, 'm', 256, 256, 'm', 512, 512, 'm', 512, 512, 'm']


In [68]:
# Show the layer tree; nn.Module's repr lists every registered sub-module.
print(repr(vgg11))

VGG(
  (net): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU()
    (10): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (11): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (12): ReLU()
    (13): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (14): ReLU()
    (15): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (16): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))


In [70]:
# Shape smoke test: a batch of 8 RGB 224x224 images should yield 8 x 1000 logits.
# no_grad skips building the autograd graph, which would otherwise keep every
# intermediate activation alive just so we can read .shape.
with torch.no_grad():
  logits = vgg11(torch.rand(8, 3, 224, 224))
assert logits.shape == (8, 1000), logits.shape
logits.shape

torch.Size([8, 1000])