In [5]:
import datetime
import os
import random

import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm

In [2]:
class VGG16(nn.Module):
    """VGG-16 style network with batch normalisation and a small regression head.

    The convolutional part is 13 blocks of ``Conv2d(3x3, padding=1) ->
    BatchNorm2d -> ReLU``, with a 2x2 max-pool after blocks 2, 4, 7, 10 and 13.
    The head is two ``Dropout -> Linear(4096) -> ReLU`` blocks followed by a
    final ``Linear`` producing ``num_outputs`` values.

    Args:
        img_size: Side length of the square input images. The first fully
            connected layer is sized for ``img_size x img_size`` inputs.
        in_channels: Number of channels of the input images.
        num_outputs: Size of the final linear layer's output.

    Raises:
        ValueError: If ``img_size`` is smaller than 32, because the five
            max-pools would reduce the feature map to nothing.
    """

    def __init__(self, img_size, in_channels=3, num_outputs=5):
        super().__init__()

        if img_size < 2 ** 5:
            raise ValueError(
                f"img_size must be at least 32 to survive five 2x2 max-pools, got {img_size}"
            )

        self.img_size = img_size

        self.conv_layers = nn.ModuleList([])

        layers_with_maxpool = [2, 4, 7, 10, 13]
        # Output channel count of each conv layer, keyed by 1-based layer index.
        kernel_count = {1: 64, 2: 64, 3: 128, 4: 128, 5: 256, 6: 256,
                        7: 256, 8: 512, 9: 512, 10: 512, 11: 512, 12: 512, 13: 512}

        # Each conv layer consumes what the previous one produced; the first
        # one consumes the image channels.
        prev_channel_count = in_channels
        for layer in range(1, 14):
            cur_kernel_count = kernel_count[layer]
            self.conv_layers.append(
                nn.Conv2d(prev_channel_count, cur_kernel_count, kernel_size=3, padding=1)
            )
            self.conv_layers.append(nn.BatchNorm2d(cur_kernel_count))
            self.conv_layers.append(nn.ReLU(inplace=True))

            if layer in layers_with_maxpool:
                self.conv_layers.append(nn.MaxPool2d(2, 2))

            prev_channel_count = cur_kernel_count

        # The padded 3x3 convs keep the spatial size; each of the five
        # max-pools halves it (rounding down). The flattened feature vector is
        # therefore channels * side * side, not channels * side.
        feature_side = int(img_size) // 2 ** 5
        flat_feature_count = prev_channel_count * feature_side * feature_side

        self.fc_layers = nn.ModuleList([])
        # Input width of each hidden fully connected layer, keyed by layer index.
        neuron_count = {14: flat_feature_count, 15: 4096}

        for layer in range(14, 16):
            cur_neuron_count = neuron_count[layer]
            self.fc_layers.append(nn.Dropout(0.5))
            self.fc_layers.append(nn.Linear(int(cur_neuron_count), 4096))
            self.fc_layers.append(nn.ReLU(inplace=True))

        self.fc_layers.append(nn.Linear(4096, num_outputs))  # default 5 - P, x, y, w, h

    def forward(self, out):
        """Run a batch of images through the network.

        Args:
            out: Input batch of shape ``(N, in_channels, img_size, img_size)``.

        Returns:
            Tensor of shape ``(N, num_outputs)``.
        """
        for cur_layer in self.conv_layers:
            out = cur_layer(out)

        # Keep the batch dimension, flatten channels and spatial dims together.
        out = torch.flatten(out, 1)

        for cur_layer in self.fc_layers:
            out = cur_layer(out)

        return out

In [3]:
# Build the network with img_size=256.
model = VGG16(img_size=256)

In [6]:
# The model's final layer emits five continuous values (P, x, y, w, h), so a
# regression criterion is used. Cross-entropy would apply a softmax across
# those five outputs and treat them as competing class scores.
# NOTE(review): assumes the `box` targets are float tensors with the same
# shape as the model output - confirm against the dataset.
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def save_img(img, pred, epoch, box_fmt="xywh"):
    """Draw a predicted box on an image and save the result under ``./results``.

    Args:
        img: Image tensor of shape ``(C, H, W)``, or a batch ``(B, C, H, W)``
            in which case only the first sample is drawn.
        pred: Prediction for the image. The last four values are read as the
            box, so both a 5-value ``(P, x, y, w, h)`` vector and a bare
            4-value box work. May carry a leading batch dimension, in which
            case the first row is used.
        epoch: Epoch number embedded in the output file name.
        box_fmt: Layout of the four box values, one of the formats accepted by
            ``torchvision.ops.box_convert`` (``"xyxy"``, ``"xywh"``,
            ``"cxcywh"``).
            NOTE(review): the default assumes (x, y) is the top-left corner
            and all values are in pixels - confirm against the dataset.

    Returns:
        Path of the PNG file that was written.
    """
    # Drop autograd history and move to host memory before drawing.
    img = img.detach().cpu()
    pred = pred.detach().cpu()

    if img.dim() == 4:
        img = img[0]
    if pred.dim() > 1:
        pred = pred[0]

    if img.is_floating_point():
        # NOTE(review): assumes float images are scaled to [0, 1] - confirm.
        img = (img.clamp(0, 1) * 255).to(torch.uint8)

    # draw_bounding_boxes wants an (N, 4) tensor in (xmin, ymin, xmax, ymax).
    box = pred.reshape(-1)[-4:].float().unsqueeze(0)
    box = torchvision.ops.box_convert(box, in_fmt=box_fmt, out_fmt="xyxy")
    # An untrained network can predict negative sizes; order the corners so
    # the box is well-formed before it is drawn.
    top_left = torch.minimum(box[:, :2], box[:, 2:])
    bottom_right = torch.maximum(box[:, :2], box[:, 2:])
    box = torch.cat([top_left, bottom_right], dim=1)

    drawn = torchvision.utils.draw_bounding_boxes(img, box, colors='red')
    pil_image = torchvision.transforms.ToPILImage()(drawn)

    os.makedirs("./results", exist_ok=True)
    # The extension is required: PIL infers the file format from it.
    image_path = f"./results/epoch_{epoch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}.png"
    pil_image.save(image_path)
    return image_path

In [None]:
epochs = 10

for epoch in range(epochs):
    model.train()  # make sure dropout and batch-norm are in training mode
    epoch_loss = 0.0

    # One batch per epoch is picked at random for a visual check. The counter
    # below counts batches, so the index is drawn from the number of batches;
    # randrange excludes the upper bound, so the chosen index always occurs.
    num_batches = len(dataloader)
    img_ind_tosave = random.randrange(num_batches) if num_batches else None

    for img_ind, sample in enumerate(pbar := tqdm(dataloader)):
        img, box = sample['img'], sample['box']

        optimizer.zero_grad()
        pred = model(img)
        loss = loss_fn(pred, box)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        pbar.set_postfix(loss=loss.item())

        if img_ind == img_ind_tosave:
            # Save the first sample of the chosen batch; detach so the saved
            # prediction does not keep the autograd graph alive.
            save_img(img[0], pred[0].detach(), epoch)

    # max(..., 1) avoids a ZeroDivisionError when the dataloader is empty.
    print(f"Epoch: {epoch}\tLoss: {epoch_loss / max(num_batches, 1)}")