In [None]:
import torch
import torch.nn as nn
from torch.hub import load_state_dict_from_url
from typing import Union, List, Dict, Any, cast

In [None]:
# Checkpoint download URLs, looked up by the `arch` string that `_vgg`
# receives when it is asked for pretrained weights.
model_urls = dict(
    vgg11='https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    vgg13='https://download.pytorch.org/models/vgg13-c768596a.pth',
    vgg16='https://download.pytorch.org/models/vgg16-397923af.pth',
    vgg19='https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
    vgg11_bn='https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
    vgg13_bn='https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
    vgg16_bn='https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
    vgg19_bn='https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
)

In [None]:
# Layer recipes consumed by `stack_layers`: an integer is the output-channel
# count of a 3x3 convolution, and 'M' is a 2x2 max-pooling step.
# Each row below is one group of convolutions ending in a pooling step.
cfgs: Dict[str, List[Union[str, int]]] = dict(
    A=[
        64, 'M',
        128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    B=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 'M',
        512, 512, 'M',
        512, 512, 'M',
    ],
    D=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 'M',
        512, 512, 512, 'M',
        512, 512, 512, 'M',
    ],
    E=[
        64, 64, 'M',
        128, 128, 'M',
        256, 256, 256, 256, 'M',
        512, 512, 512, 512, 'M',
        512, 512, 512, 512, 'M',
    ],
)

In [None]:
def stack_layers(cfg: List[Union[str, int]], batch_norm: bool = False, in_channels: int = 3) -> nn.Sequential:
  """Build the convolutional feature extractor described by ``cfg``.

  Args:
    cfg: Layer specification, read left to right. The string ``'M'`` adds a
      2x2, stride-2 max-pooling layer; any other entry is used as the
      output-channel count of a 3x3, padding-1 convolution followed by ReLU.
    batch_norm: If True, insert ``nn.BatchNorm2d`` between every convolution
      and its ReLU.
    in_channels: Channel count expected by the first convolution. Defaults
      to 3, the value that used to be hard-coded.

  Returns:
    An ``nn.Sequential`` containing the layers in ``cfg`` order (empty when
    ``cfg`` is empty).
  """
  layers: List[nn.Module] = []

  for v in cfg:
    if v == 'M':
      layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
      continue

    out_channels = cast(int, v)
    layers.append(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1))
    if batch_norm:
      layers.append(nn.BatchNorm2d(out_channels))
    layers.append(nn.ReLU(inplace=True))
    # The next convolution consumes what this one produced.
    in_channels = out_channels

  return nn.Sequential(*layers)

In [None]:
class VGG(nn.Module):
  """VGG classifier: a feature extractor, 7x7 adaptive pooling and a 3-layer MLP head.

  Args:
    features: Module applied first to the input. The first linear layer
      expects ``512 * 7 * 7`` inputs, so after pooling to 7x7 the feature
      map must have 512 channels.
    num_classes: Size of the final linear layer's output.
    init_weights: If True, run ``_initialize_weights`` on every submodule
      at construction time.
  """

  def __init__(self,
               features: nn.Module,
               num_classes: int = 1000,
               init_weights: bool = True):
    super().__init__()
    self.features = features
    self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
    self.classifier = nn.Sequential(
        nn.Linear(512 * 7 * 7, 4096),
        nn.ReLU(True),
        nn.Dropout(),

        nn.Linear(4096, 4096),
        nn.ReLU(True),
        nn.Dropout(),

        nn.Linear(4096, num_classes)
    )

    if init_weights:
      self._initialize_weights()

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Run features -> adaptive average pool -> flatten -> classifier."""
    x = self.features(x)
    # Adaptive average pooling always yields a 7 x 7 spatial map.
    x = self.avgpool(x)
    # Collapse every dimension after the first (batch) one:
    # (batch, channels, height, width) -> (batch, channels * height * width)
    x = torch.flatten(x, 1)
    x = self.classifier(x)
    return x

  def _initialize_weights(self) -> None:
    """Re-initialise all Conv2d, BatchNorm2d and Linear submodules in place."""
    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
          nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.BatchNorm2d):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)
      elif isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, 0, 0.01)
        nn.init.constant_(m.bias, 0)

In [None]:
def _vgg(arch: str, cfg: str, batch_norm: bool, pretrained: bool, progress: bool, **kwargs: Any) -> VGG:
  """Assemble a VGG model and optionally load downloaded weights into it.

  Args:
    arch: Key into ``model_urls``; only consulted when ``pretrained`` is True.
    cfg: Key into ``cfgs`` selecting the layer recipe for the feature extractor.
    batch_norm: Forwarded to ``stack_layers``.
    pretrained: If True, force ``init_weights=False`` and load the state
      dict fetched from ``model_urls[arch]``.
    progress: Forwarded to ``load_state_dict_from_url``.
    **kwargs: Extra keyword arguments passed through to ``VGG``.

  Returns:
    The constructed ``VGG`` instance.
  """
  if pretrained:
    # Weights are about to be overwritten, so skip the random initialisation.
    kwargs['init_weights'] = False

  feature_extractor = stack_layers(cfgs[cfg], batch_norm=batch_norm)
  net = VGG(feature_extractor, **kwargs)
  if not pretrained:
    return net

  weights = load_state_dict_from_url(model_urls[arch], progress=progress)
  net.load_state_dict(weights)
  return net