In [None]:
import torch
import torch.nn as nn

In [None]:
class VGG16(nn.Module):
    """VGG-16-style CNN with batch normalisation and a 5-value output head.

    The network is 13 convolutional layers (each 3x3 conv -> BatchNorm ->
    ReLU, with a 2x2 max-pool after layers 2, 4, 7, 10 and 13) followed by
    three fully connected layers.  The last layer emits 5 raw values per
    sample (P, x, y, w, h); no activation is applied to them here.

    Args:
        img_size: Side length, in pixels, of the square 3-channel input
            images.  It must be at least 32 so that the five 2x2 poolings
            leave a non-empty feature map.

    Raises:
        ValueError: If ``img_size`` is too small to survive the poolings.
    """

    # (out_channels, followed_by_max_pool) for conv layers 1..13, in order.
    _CONV_CFG = (
        (64, False), (64, True),
        (128, False), (128, True),
        (256, False), (256, False), (256, True),
        (512, False), (512, False), (512, True),
        (512, False), (512, False), (512, True),
    )

    def __init__(self, img_size):
        super().__init__()

        self.img_size = img_size

        # Every max-pool halves (and floors) each spatial dimension, so the
        # final feature map side is img_size // 2**n_pools.
        n_pools = sum(1 for _, pooled in self._CONV_CFG if pooled)
        feature_side = img_size // 2 ** n_pools
        if feature_side < 1:
            raise ValueError(
                f"img_size must be at least {2 ** n_pools}, got {img_size}"
            )

        # Build the conv stack from the config so that each layer's input
        # width is always the previous layer's output width.
        conv_modules = []
        in_channels = 3
        for out_channels, pooled in self._CONV_CFG:
            conv_modules += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
            if pooled:
                conv_modules.append(nn.MaxPool2d(2, 2))
            in_channels = out_channels
        self.conv_layers = nn.ModuleList(conv_modules)

        # Flattened size = channels * height * width of the last feature map.
        flat_features = in_channels * feature_side * feature_side

        self.fc_layers = nn.ModuleList([
            # layer 14
            nn.Dropout(0.5),
            nn.Linear(flat_features, 4096),
            nn.ReLU(inplace=True),

            # layer 15
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),

            # layer 16
            nn.Linear(4096, 5),  # 5 - P, x, y, w, h
        ])

    def forward(self, out):
        """Run the network on a batch of images.

        Args:
            out: Input tensor of shape ``(N, 3, img_size, img_size)``.

        Returns:
            Tensor of shape ``(N, 5)`` holding the raw outputs of the last
            linear layer.
        """
        for cur_layer in self.conv_layers:
            out = cur_layer(out)

        # Keep the batch dimension, collapse channels and spatial dims.
        out = torch.flatten(out, 1)

        for cur_layer in self.fc_layers:
            out = cur_layer(out)

        return out
        

