In [1]:
import torch
import torch.nn as nn

In [2]:
class VGG16(nn.Module):
    """VGG-16-style network: 13 conv/BatchNorm/ReLU layers with 5 max-pools, then an FC head.

    The head ends in a linear layer with ``num_outputs`` raw values
    (default 5 - P, x, y, w, h).

    Args:
        img_size: Height and width of the (square) input images, in pixels. The five
            ``MaxPool2d(2, 2)`` layers floor-halve it five times, so ``img_size // 32``
            must be at least 1.
        in_channels: Number of channels in the input images (default 3).
        num_outputs: Number of values produced by the final linear layer (default 5).

    Raises:
        ValueError: If ``img_size`` is too small to survive the five pooling stages.
    """

    def __init__(self, img_size, in_channels=3, num_outputs=5):
        super().__init__()

        self.img_size = img_size

        self.conv_layers = nn.ModuleList([])

        layers_with_maxpool = [2, 4, 7, 10, 13]
        kernel_count = {1: 64, 2: 64, 3: 128, 4: 128, 5: 256, 6: 256,
                        7: 256, 8: 512, 9: 512, 10: 512, 11: 512, 12: 512, 13: 512}

        # padding=1 3x3 convs keep the spatial size; each MaxPool2d(2, 2) floor-halves it,
        # and repeated floor-halving equals a single floor division by 2**pools.
        downsample = 2 ** len(layers_with_maxpool)
        feature_map_size = int(img_size) // downsample
        if feature_map_size < 1:
            raise ValueError(
                f"img_size must be at least {downsample} to survive "
                f"{len(layers_with_maxpool)} 2x2 max-pools, got {img_size}"
            )

        # Each conv consumes the previous layer's output channels (the image channels for
        # the first layer), so channel counts line up at every width change.
        prev_channels = in_channels
        for layer in range(1, 14):
            cur_kernel_count = kernel_count[layer]
            self.conv_layers.append(nn.Conv2d(prev_channels, cur_kernel_count, kernel_size=3, padding=1))
            self.conv_layers.append(nn.BatchNorm2d(cur_kernel_count))
            self.conv_layers.append(nn.ReLU(inplace=True))
            prev_channels = cur_kernel_count

            if layer in layers_with_maxpool:
                self.conv_layers.append(nn.MaxPool2d(2, 2))

        # Flattened conv output = channels * height * width; must be an int for nn.Linear.
        flat_features = prev_channels * feature_map_size * feature_map_size

        self.fc_layers = nn.ModuleList([])
        neuron_count = {14: flat_features, 15: 4096}

        for layer in range(14, 16):
            cur_neuron_count = neuron_count[layer]
            self.fc_layers.append(nn.Dropout(0.5))
            self.fc_layers.append(nn.Linear(cur_neuron_count, 4096))
            self.fc_layers.append(nn.ReLU(inplace=True))

        self.fc_layers.append(nn.Linear(4096, num_outputs))  # default 5 - P, x, y, w, h

    def forward(self, out):
        """Run the conv stack, flatten per sample, then run the FC head.

        Args:
            out: Input batch of shape (batch, in_channels, H, W). The first FC layer's size
                is fixed from ``img_size``, so H and W are expected to equal ``img_size``.

        Returns:
            Tensor of shape (batch, num_outputs).
        """
        for cur_layer in self.conv_layers:
            out = cur_layer(out)

        # Keep the batch dimension; collapse channels and spatial dims into one vector.
        out = torch.flatten(out, 1)

        for cur_layer in self.fc_layers:
            out = cur_layer(out)

        return out