In [1]:
from pickletools import optimize
import torch, torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import os
from skimage import io, transform
from tqdm import tqdm

from model import *
from utils import predict_transform, bbox_iou, prep_image
from dataset import filter_labels, bbox_anchorbox_iou, draw_bbox, centre_dims_to_corners, DetectionDataset, Pad, ToTensor, Normalise
from loss import Yolo_Loss

# Default every matplotlib figure in this notebook to a 15 x 15 inch canvas.
plt.rc('figure', figsize=[15, 15])

In [8]:
# --- Training configuration ---------------------------------------------------
# Three-value statistics handed to `Normalise`. The magnitudes suggest
# per-channel pixel statistics on a 0-255 scale -- TODO confirm channel order.
mean = [92.11938007161459, 102.83839236762152, 104.90335580512152]
std = [66.09941202519124, 70.6808655565459, 75.05305001603533]
bs = 2          # mini-batch size used by the DataLoader
n_epochs = 1    # number of full passes over the training set
log_every = 2   # report the loss every `log_every` mini-batches (independent of `bs`)

# Three groups of three (width, height) anchor pairs; presumably one group per
# entry of `grid_sizes` below -- verify against DetectionDataset.
anchors = np.array([
    [[116, 90], [156, 198], [373, 326]],
    [[30, 61], [62, 45], [59, 119]],
    [[10, 13], [16, 30], [33, 23]],
])

# --- Data -----------------------------------------------------------------------
# Custom dataset; each sample is passed through Normalise -> Pad(416) -> ToTensor.
transformed_train_data = DetectionDataset(
    label_dict="det_train_shortened.json",
    root_dir='images/',
    classes_file="data/bdd100k.names",
    grid_sizes=[13, 26, 52],
    anchors=anchors,
    transform=transforms.Compose([
        Normalise(mean=mean, std=std),
        Pad(416),
        ToTensor(),
    ]),
)

# Shuffled mini-batches, loaded in the main process (num_workers=0).
train_loader = DataLoader(
    transformed_train_data,
    batch_size=bs,
    shuffle=True,
    num_workers=0,
)

# --- Model, loss, optimiser -------------------------------------------------------
net = Net(cfgfile="cfg/model.cfg")
criterion = Yolo_Loss()
optimizer = optim.SGD(net.parameters(), lr=0.00001, momentum=0.9)

# NOTE(review): this flag is only forwarded to `net(...)`; neither the model nor
# the batches are moved to the GPU in this cell. Confirm that `Net` copes with
# CPU inputs when CUDA is True, or add explicit `.to(device)` calls.
CUDA = torch.cuda.is_available()

# --- Training loop ------------------------------------------------------------------
for epoch in range(n_epochs):

    running_loss = 0.0
    for i, data in enumerate(train_loader):
        # NOTE(review): relies on the batch dict yielding exactly two values in
        # (image, labels) order -- verify against DetectionDataset / ToTensor.
        input_img, labels = data.values()

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = net(input_img, CUDA)
        loss = criterion(outputs, labels).float()
        loss.backward()
        optimizer.step()

        # Report the mean loss over the last `log_every` mini-batches. The
        # interval used to be derived from the batch size and the summed loss
        # was printed, which made values incomparable across intervals.
        running_loss += loss.item()
        if (i + 1) % log_every == 0:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / log_every:.3f}')
            running_loss = 0.0

print("done")

Parameter containing:
tensor([[[[ 0.0303, -0.0097, -0.1086],
          [-0.0507,  0.0268,  0.0096],
          [-0.1256, -0.0698, -0.1709]],

         [[ 0.0198,  0.1382, -0.1768],
          [-0.0073,  0.1307, -0.0617],
          [-0.0413, -0.1404, -0.1739]],

         [[ 0.0270, -0.1620, -0.1199],
          [-0.1229, -0.0822,  0.1449],
          [ 0.0349,  0.1468,  0.1018]]],


        [[[ 0.0223,  0.1880,  0.0817],
          [-0.0908, -0.0621, -0.0895],
          [ 0.0217, -0.0406, -0.1096]],

         [[ 0.1629, -0.0730, -0.0651],
          [ 0.0962, -0.0679,  0.1563],
          [-0.1386,  0.1367, -0.1333]],

         [[-0.1129, -0.1369, -0.0020],
          [-0.0421, -0.1138,  0.0462],
          [ 0.0742, -0.1354, -0.0649]]],


        [[[ 0.0263,  0.1837,  0.0090],
          [ 0.1243, -0.1007, -0.1905],
          [-0.1278,  0.1818, -0.1705]],

         [[-0.0901,  0.0411,  0.0536],
          [ 0.0482, -0.1497, -0.1275],
          [-0.0984,  0.1146,  0.0966]],

         [[-0.0307, -0

In [11]:
# Print the first tensor yielded by net.parameters() so its values can be inspected.
print([*net.parameters()][0])

Parameter containing:
tensor([[[[ 1.0764e+00,  1.3696e+00,  7.4620e-01],
          [ 7.3623e-01,  5.7949e-01, -1.5573e-01],
          [-3.3049e-01,  1.1616e-01, -2.9785e-01]],

         [[ 4.2345e-01,  9.0908e-01, -2.5961e-02],
          [ 9.3436e-02,  5.6848e-03, -1.0097e+00],
          [-8.6361e-01, -6.2246e-01, -1.0940e+00]],

         [[ 3.2954e-01,  4.2341e-01,  6.6229e-02],
          [-1.1828e-01, -1.9132e-01, -5.6229e-01],
          [-7.4026e-01, -9.1591e-02, -4.6410e-01]]],


        [[[-1.9397e+00, -2.2089e+00, -2.4874e+00],
          [-1.8464e+00, -1.3654e+00, -1.3870e+00],
          [-2.2590e+00, -1.1503e+00, -9.0204e-01]],

         [[-1.2778e+00, -1.9352e+00, -2.1472e+00],
          [-1.1608e+00, -9.9436e-01, -9.0639e-01],
          [-1.7500e+00, -3.7051e-01, -4.7854e-01]],

         [[-9.5297e-01, -1.4002e+00, -1.5021e+00],
          [-8.4639e-01, -6.1926e-01, -6.2548e-01],
          [-6.7509e-01, -4.6164e-02,  1.0181e-01]]],


        [[[ 2.6078e-01,  1.3147e-01,  1.3857