In [1]:
from pickletools import optimize
import torch, torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from skimage import io, transform

from model import *
from dataset import filter_labels, DetectionDataset, Rescale, ToTensor, Normalise

# Default size (width, height in inches) for every matplotlib figure created below.
plt.rcParams['figure.figsize'] = [15,15]

## Load Dataset

In [2]:
class Pad(object):
    """
    Letterbox an image onto a fixed-size canvas.

    The image is resized with its aspect ratio preserved so that it fits
    inside ``output_size``; it is then centred on a canvas pre-filled with
    ``fill``. The canvas takes the dtype and channel layout of the resized
    image, so float (e.g. already normalised) images are not truncated to
    integers.

    Args:
        output_size (tuple or int): Desired canvas size. If tuple, it is read
            as ``(width, height)``. If int, a square canvas with that side
            length is used.
        fill (int or float): Value written to the padded border. Defaults to
            128.

    Note:
        ``categories`` and ``bboxes`` are returned unchanged.
        NOTE(review): if ``bboxes`` are pixel coordinates they will no longer
        line up with the resized and shifted image -- confirm against the
        dataset's box format.
        NOTE(review): ``cv2`` is not imported in this notebook's import cell;
        presumably it arrives via ``from model import *`` -- confirm.
    """
    def __init__(self, output_size, fill=128):
        assert isinstance(output_size, (int, tuple)) # make sure output size is EITHER int or tuple
        self.output_size = output_size
        self.fill = fill

    def __call__(self, sample):
        image, categories, bboxes = sample["image"], sample["categories"], sample["bboxes"]

        # An int means a square canvas; a tuple is (width, height).
        if isinstance(self.output_size, int):
            w = h = self.output_size
        else:
            w, h = self.output_size

        img_h, img_w = image.shape[:2]

        # One scale factor for both axes keeps the aspect ratio; taking the
        # smaller ratio guarantees the result fits inside the canvas.
        scale = min(w / img_w, h / img_h)
        new_w = max(1, int(img_w * scale))  # never ask cv2 for a zero-sized image
        new_h = max(1, int(img_h * scale))
        resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_CUBIC)

        # Build the canvas from the resized image's dtype and trailing dims:
        # an integer canvas would silently truncate float pixel values on
        # assignment, and a fixed channel count would reject non-RGB input.
        canvas = np.full((h, w) + resized_image.shape[2:], self.fill, dtype=resized_image.dtype)

        # Centre the resized image on the canvas.
        top = (h - new_h) // 2
        left = (w - new_w) // 2
        canvas[top:top + new_h, left:left + new_w] = resized_image

        return {"image": canvas, "categories": categories, "bboxes": bboxes}

In [29]:
# Parse the label file once; both dataset objects below share the result.
filtered_labels = filter_labels("det_train_shortened.json")

# Dataset built without any transform argument.
train_data = DetectionDataset(root_dir='images/', label_dict=filtered_labels)

# Per-sample transform pipeline, applied in the order listed.
# NOTE(review): Normalise runs before Pad here, so Pad's border value is
# written into an already-normalised image -- confirm this ordering is intended.
train_transform = transforms.Compose([
    Rescale(608),
    Normalise(0.5, 0.5),
    Pad((608,608)),
    ToTensor(),
])

# Same labels and image directory, but with the transform pipeline attached.
transformed_train_data = DetectionDataset(
    root_dir='images/',
    label_dict=filtered_labels,
    transform=train_transform,
)

# Single-sample, shuffled batches loaded in the main process (no workers).
train_loader = DataLoader(transformed_train_data, batch_size=1, shuffle=True, num_workers=0)

In [40]:
# Show the "categories" entry of the first transformed sample.
transformed_train_data[0]["categories"]

['traffic light',
 'traffic light',
 'traffic sign',
 'traffic sign',
 'car',
 'car',
 'traffic sign']

In [41]:
# Show the "categories" entry of the second transformed sample.
transformed_train_data[1]["categories"]

['traffic sign',
 'traffic sign',
 'traffic sign',
 'pedestrian',
 'pedestrian',
 'pedestrian']

## Define Network

In [19]:
# Instantiate Net from the config file at cfg/model.cfg.
# NOTE(review): Net is presumably provided by `from model import *` -- confirm.
net = Net(cfgfile="cfg/model.cfg")

In [20]:
class TestNet(nn.Module):
    """
    Small CNN: two 5x5 convolutions (each followed by ReLU and 2x2
    max-pooling) feeding three fully connected layers.

    Args:
        input_size (int or tuple): Spatial size of the input images, either a
            single side length or ``(height, width)``. Defaults to 32, which
            yields the 16*5*5 flattened feature size.
        num_classes (int): Number of values produced by the final layer.
            Defaults to 10.

    Raises:
        ValueError: If ``input_size`` is too small to survive both
            convolution/pooling stages.
    """

    def __init__(self, input_size=32, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)     # in 3, out 6, 5x5 kernel
        self.conv2 = nn.Conv2d(6, 16, 5)    # in 6, out 16, 5x5 kernel

        self.pool = nn.MaxPool2d(2, 2)      # 2x2 window, stride 2

        # Derive the flattened feature size from the input size instead of
        # hard-coding the 5x5 map that only a 32x32 input produces.
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        feat_h, feat_w = (self._feature_side(side) for side in input_size)
        if feat_h < 1 or feat_w < 1:
            raise ValueError(
                "input_size {} is too small for two 5x5 conv + 2x2 pool stages".format(input_size)
            )

        # an affine op: y = Wx + b
        self.fc1 = nn.Linear(16 * feat_h * feat_w, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    @staticmethod
    def _feature_side(side):
        """Return one spatial side length after both conv+pool stages."""
        for _ in range(2):
            # an unpadded 5x5 conv removes 4 pixels; 2x2 pooling floors the half
            side = (side - 4) // 2
        return side

    def forward(self, x):
        # max pooling over a (2,2) window
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dims except batch dim
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x) # no activation on final layer -> output
        return x
testnet = TestNet()

## Define Loss Function and Optimiser

In [21]:
# SGD with momentum over the parameters of `net`.
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
# TODO: placeholder -- `...` is Python's Ellipsis object, not a loss function;
# assign a real criterion before using it in training.
criterion = ...

## Train Network

In [30]:
# Flag forwarded to the network's forward call.
# NOTE(review): neither `net` nor the input batch is moved to the GPU in this
# cell -- confirm the network handles device placement itself when CUDA is True.
CUDA = torch.cuda.is_available()
for data in train_loader:
    # Look fields up by key rather than unpacking `data.values()` positionally,
    # so the loop does not depend on the order in which the transforms build
    # the sample dict.
    input_img = data["image"]
    cat, bboxes = data["categories"], data["bboxes"]

    optimizer.zero_grad()
    # Tensors track gradients themselves; the deprecated `Variable` wrapper
    # (never imported in this notebook) is not needed.
    outputs = net(input_img.float(), CUDA)
    # TODO: compute the loss from `outputs`, `cat` and `bboxes` with `criterion`,
    # then call `loss.backward()` and `optimizer.step()`; until then this loop
    # only runs forward passes.
    

In [36]:
# `...` keeps every leading dimension; index 0 is taken along the last axis.
outputs[...,0] # select first bbox attr from all bbox tensors

tensor([[ 20.6303,  20.2820,  21.7008,  ..., 603.5667, 603.3599, 603.8293]])

In [39]:
# Leftover scratch check: a bare comparison that assigns nothing; safe to delete.
0.6 == 1

False