<a href="https://colab.research.google.com/github/TimofeyKulakov/NeuralNets/blob/master/VGG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

# Model

In [None]:
def test_module(module, size = (1, 3, 224, 224)):
  """Run ``module`` once on a random input of shape ``size`` and print the output shape.

  Quick shape probe for notebook use. The module is invoked through
  ``module(...)`` rather than ``module.forward(...)`` so registered hooks run,
  inside ``torch.no_grad()`` so no autograd graph is built, and in eval mode so
  the probe does not update BatchNorm running statistics or apply dropout.
  Every submodule's original train/eval flag is restored afterwards, even if
  the forward pass raises.

  Args:
    module: the ``nn.Module`` to probe.
    size: shape of the ``torch.randn`` input; defaults to a batch of one
      3x224x224 tensor.
  """
  test = torch.randn(size)
  # Remember each submodule's own flag so mixed train/eval setups are restored exactly
  # (module.train(flag) would force one flag onto every submodule).
  saved_modes = [(m, m.training) for m in module.modules()]
  module.eval()
  try:
    with torch.no_grad():
      out = module(test)
  finally:
    for m, was_training in saved_modes:
      m.training = was_training
  print(out.shape)


# The name starts with "test_", but this is a notebook helper, not a pytest test.
test_module.__test__ = False

In [None]:
class conv_layer(nn.Module):
  """One VGG stage: ``num_convs`` x (Conv2d -> BatchNorm2d -> ReLU), then a 2x2 max-pool.

  Every convolution is 3x3 with padding 1, which keeps height and width unchanged.
  The exception is when ``add1dconv`` is true: the *last* convolution of the stage
  is then 1x1 with padding 0, which also keeps the size. The trailing
  ``MaxPool2d(2, stride=2)`` halves height and width, rounding down.

  Args:
    in_channels: channels of the input tensor (consumed by the first conv only).
    out_channels: channels produced by every conv in the stage.
    num_convs: number of Conv/BN/ReLU triples; must be >= 1.
    add1dconv: if true, make the final conv 1x1. Despite the name, this is a
      1x1 2-D convolution, not a 1-D one.

  Raises:
    ValueError: if ``num_convs`` < 1. Such a stage would contain no conv and
      would silently pass ``in_channels`` through instead of ``out_channels``.
  """
  def __init__(self, in_channels, out_channels, num_convs, add1dconv = False):
    super().__init__()
    if num_convs < 1:
      raise ValueError(f"num_convs must be >= 1, got {num_convs}")

    layers = []
    for i in range(num_convs):
      is_1x1 = add1dconv and i == num_convs - 1
      layers += [
        nn.Conv2d(in_channels if i == 0 else out_channels, out_channels,
                  kernel_size=1 if is_1x1 else 3,
                  stride=1,
                  padding=0 if is_1x1 else 1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
      ]
    # Kept as a ModuleList so the repr and state_dict keys ("mods.N.*") are unchanged.
    self.mods = nn.ModuleList(layers)
    self.pool = nn.MaxPool2d(2, stride = 2)

  def forward(self, x):
    for m in self.mods:
      x = m(x)
    return self.pool(x)

In [None]:
class VGG(nn.Module):
  """VGG-style classifier: a stack of ``conv_layer`` stages followed by a 3-layer MLP head.

  The first ``nn.Linear`` of the head has a fixed number of input features,
  derived from ``input_size``. For that reason ``forward`` bilinearly resizes
  any input whose spatial size differs from ``input_size`` x ``input_size``.

  Args:
    n_classes: number of output logits.
    in_channels: channels of the input images, fed to the first stage.
    n_convs_list: number of convolutions in each stage.
    out_channels: output channels of each stage; same length as ``n_convs_list``.
    conv1_layers_idxs: indices of stages whose last conv is 1x1. Each index is
      passed as ``add1dconv`` to ``conv_layer``.
    input_size: side length that inputs are resized to; defaults to 224.

  Raises:
    ValueError: if ``n_convs_list``/``out_channels`` are empty or differ in length,
      or if ``input_size`` is too small to survive every pooling stage.
  """
  def __init__(self, n_classes, in_channels=3, n_convs_list=(1, 1, 2, 2, 2),
               out_channels=(64, 128, 256, 512, 512), conv1_layers_idxs=(),
               input_size=224):
    super().__init__()
    if len(out_channels) == 0 or len(n_convs_list) != len(out_channels):
      raise ValueError(
        "n_convs_list and out_channels must be non-empty and of equal length, "
        f"got {len(n_convs_list)} and {len(out_channels)}")
    n_stages = len(out_channels)
    # Each stage ends in a 2x2 / stride-2 max-pool that floors odd sizes, so the
    # final feature map is input_size // 2**n_stages on each side.
    final_size = input_size // (2 ** n_stages)
    if final_size < 1:
      raise ValueError(
        f"input_size={input_size} is too small for {n_stages} pooling stages")
    self.input_size = input_size

    # Stage 0 consumes the image's channels (previously hard-coded to 3);
    # stage k > 0 consumes stage k-1's output channels.
    stage_in = [in_channels, *out_channels[:-1]]
    self.mods = nn.ModuleList(
      [conv_layer(c_in, c_out, n, k in conv1_layers_idxs)
       for k, (c_in, c_out, n) in enumerate(zip(stage_in, out_channels, n_convs_list))]
    )

    self.ff = nn.Sequential(
       nn.Flatten(),
       nn.Linear(final_size * final_size * out_channels[-1], 4096),
       nn.ReLU(),
       nn.Dropout(0.5),
       nn.Linear(4096, 4096),
       nn.ReLU(),
       nn.Dropout(0.5),
       nn.Linear(4096, n_classes),
    )

  def forward(self, x):
    # The head's first Linear expects a fixed spatial size, so resize when needed.
    target = (self.input_size, self.input_size)
    if tuple(x.shape[2:]) != target:
      x = F.interpolate(x, size=target, mode='bilinear', align_corners=False)

    for m in self.mods:
      x = m(x)

    return self.ff(x)


In [None]:
VGG(n_classes=6, in_channels=3, n_convs_list=[2, 2, 3, 3, 3], conv1_layers_idxs=[2, 3, 4])

VGG(
  (mods): ModuleList(
    (0): conv_layer(
      (mods): ModuleList(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
      )
      (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (1): conv_layer(
      (mods): ModuleList(
        (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU()
    

In [None]:
test_module(
  VGG(
    n_classes=1000,
    in_channels=3,
    n_convs_list=[2, 2, 3, 3, 3],
    conv1_layers_idxs=[2, 3, 4],
  )
)

torch.Size([1, 1000])
