In [1]:
from torch_lib.Dataset import *
from torch_lib.Model import Model, OrientationLoss


import torch
import torch.nn as nn
from torch.autograd import Variable
from torchvision.models import vgg
from torch.utils import data

In [2]:
import os

In [3]:
# Directory handed to Dataset(...) as the training data root. The path is
# relative, so it resolves against the notebook's current working directory.
train_path = os.path.join(
    'dataset',
    '20220528_qd3dt_Center_2k/KITTI/detection/training/',
)

In [4]:
# Build the project's Dataset (brought in by the star import from
# torch_lib.Dataset) over the training directory configured in `train_path`.
dataset = Dataset(train_path)

In [6]:
# Sanity check: display a single parsed label record.
# NOTE(review): the keys look like an image id ('000001') and an object index
# within that image ('0') -- confirm against torch_lib.Dataset.
dataset.labels['000001']['0']

{'Class': 'Car',
 'Box_2D': [(322, 422), (434, 512)],
 'Dimensions': array([ 0.1492714,  0.0942542, -0.1817304]),
 'Alpha': 0.33,
 'Orientation': array([[ 0.        ,  0.        ],
        [ 0.32404303, -0.94604234]]),
 'Confidence': array([0., 1.]),
 'OrientationX': array([[-0.87461971, -0.48480962],
        [ 0.87461971,  0.48480962]]),
 'ConfidenceX': array([1., 1.]),
 'OrientationY': array([[0.        , 0.        ],
        [0.87461971, 0.48480962]]),
 'ConfidenceY': array([0., 1.]),
 'OrientationZ': array([[-0.75470958, -0.65605903],
        [ 0.75470958,  0.65605903]]),
 'ConfidenceZ': array([0., 1.]),
 'rot': [-1.064650843716541, 5.218534463463046, -0.855211333477221]}

In [7]:
# vars() returns the instance's own __dict__, so `attr` is a live mapping of the
# dataset's attributes (not a copy); it is only used for inspection below.
attr = vars(dataset)

In [8]:
# List the attribute names the Dataset instance exposes.
print(attr.keys())

dict_keys(['top_label_path', 'top_img_path', 'top_calib_path', 'proj_matrix', 'ids', 'num_images', 'bins', 'angle_bins', 'interval', 'overlap', 'bin_ranges', 'averages', 'object_list', 'labels', 'curr_id', 'curr_img'])


In [9]:
# Look up the entry stored for the 'Car' class in the dataset's `averages` object.
# NOTE(review): presumably the per-class mean of the 'Dimensions' labels -- confirm
# against the implementation of `averages` in torch_lib.
attr['averages'].get_item('Car')

array([0.36644973, 0.62636991, 0.27536507])

In [8]:
# Display the projection matrix held by the dataset.
attr['proj_matrix']

array([[208.,   0., 512.,   0.],
       [  0., 208., 512.,   0.],
       [  0.,   0.,   1.,   0.]], dtype=float32)

In [9]:
# Display the projection matrix held by the dataset.
# NOTE(review): this cell repeats the preceding one verbatim; one of the two can
# be dropped when the notebook is cleaned up.
attr['proj_matrix']

array([[208.,   0., 512.,   0.],
       [  0., 208., 512.,   0.],
       [  0.,   0.,   1.,   0.]], dtype=float32)

In [7]:
# Training hyperparameters.
# `epochs` is the last epoch number the training loop runs; `batch_size` is the
# number of samples per DataLoader batch.
epochs, batch_size = 100, 64
# Loss weights used by the training loop:
#   loss = alpha * dim_loss + (conf_loss + w * orient_loss)
alpha, w = 0.6, 0.4

In [8]:
# Keyword arguments forwarded to the DataLoader: batches of `batch_size` samples,
# reshuffled each time the loader is iterated, loaded by 6 worker processes.
params = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=6,
)

In [9]:
# Batched loader over the dataset, configured by `params`; the training loop
# iterates it once per epoch and unpacks each item as (batch, labels).
generator = data.DataLoader(dataset, **params)

In [25]:
# Pull one batch from a fresh iterator over the loader and display the
# 'Orientation' entry of its labels (element [1] of the yielded pair).
# The loader was built with shuffle=True, so every run shows a different batch.
next(iter(generator))[1]['Orientation']

tensor([[[ 0.0000,  0.0000],
         [ 0.2571, -0.9664]],

        [[-0.0216, -0.9998],
         [ 0.0216,  0.9998]],

        [[ 0.0000,  0.0000],
         [ 0.1708,  0.9853]],

        [[ 0.0000,  0.0000],
         [ 0.4350, -0.9004]],

        [[-0.0899,  0.9960],
         [ 0.0899, -0.9960]],

        [[ 0.0000,  0.0000],
         [ 0.2280, -0.9737]],

        [[ 0.0000,  0.0000],
         [ 0.3523, -0.9359]],

        [[-0.0200,  0.9998],
         [ 0.0200, -0.9998]],

        [[ 0.0000,  0.0000],
         [ 0.4364,  0.8998]],

        [[ 0.0000,  0.0000],
         [ 0.5810, -0.8139]],

        [[ 0.0000,  0.0000],
         [ 0.3523, -0.9359]],

        [[-0.0316, -0.9995],
         [ 0.0316,  0.9995]],

        [[ 0.0000,  0.0000],
         [ 0.5985,  0.8011]],

        [[ 0.0000,  0.0000],
         [ 0.2860, -0.9582]],

        [[ 0.0000,  0.0000],
         [ 0.4169, -0.9090]],

        [[ 0.0000,  0.0000],
         [ 0.3444,  0.9388]],

        [[ 0.0000,  0.0000],
         [ 

In [27]:
# Loss terms combined by the training loop:
#   - orientation: project-provided OrientationLoss, called as a plain function
#   - dimensions:  mean-squared error
#   - confidence:  cross-entropy
orient_loss_func = OrientationLoss
dim_loss_func = nn.MSELoss().cuda()
conf_loss_func = nn.CrossEntropyLoss().cuda()

# Network: the feature extractor of a pretrained VGG-19 (batch-norm variant) is
# handed to the project's Model, which is then moved to the GPU.
my_vgg = vgg.vgg19_bn(pretrained=True)
model = Model(features=my_vgg.features)
model = model.cuda()

# Plain SGD with momentum over every parameter the model exposes.
opt_SGD = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)

In [28]:
# Directory prefix for checkpoints; the trailing slash matters because the
# training loop builds file names by appending to this string.
model_path = 'weights/'
# Epoch offset for the training loop, which starts at first_epoch + 1.
first_epoch = 0
# Not referenced by the code visible in this notebook.
# NOTE(review): looks like a placeholder for resuming from a saved checkpoint -- confirm.
latest_model = None

In [29]:
# Number of batches the loader yields per epoch; used as the denominator in the
# training progress log. It is read from the DataLoader itself so that a final,
# partially filled batch is counted too -- flooring len(dataset) / batch_size
# misses it whenever the dataset size is not a multiple of batch_size.
total_num_batches = len(generator)

In [30]:
# Display the batch count that the training progress log uses as its denominator.
total_num_batches

68

In [35]:
# Train the model for epochs first_epoch + 1 .. epochs (inclusive), printing a
# progress line every 20 batches and writing a checkpoint into `model_path`
# after every 10th epoch.

# Create the checkpoint directory before any training happens, so the first save
# cannot fail on a missing directory after ten epochs of work.
os.makedirs(model_path, exist_ok=True)

for epoch in range(first_epoch + 1, epochs + 1):
    for curr_batch, (local_batch, local_labels) in enumerate(generator):

        # Ground-truth targets for this batch, cast and moved to the GPU.
        truth_orient = local_labels['Orientation'].float().cuda()
        truth_conf = local_labels['Confidence'].long().cuda()
        truth_dim = local_labels['Dimensions'].float().cuda()

        local_batch = local_batch.float().cuda()
        [orient, conf, dim] = model(local_batch)

        # The orientation loss receives the un-reduced confidence labels, so it
        # has to be computed before truth_conf is overwritten below.
        orient_loss = orient_loss_func(orient, truth_orient, truth_conf)
        dim_loss = dim_loss_func(dim, truth_dim)

        # Reduce the confidence labels to the index of their largest entry along
        # dim 1; that index is the target passed to the cross-entropy loss.
        truth_conf = torch.max(truth_conf, dim=1)[1]
        conf_loss = conf_loss_func(conf, truth_conf)

        loss_theta = conf_loss + w * orient_loss
        loss = alpha * dim_loss + loss_theta

        opt_SGD.zero_grad()
        loss.backward()
        opt_SGD.step()

        # Progress line on batches 0, 20, 40, ... of each epoch. Note that the
        # value labelled "loss_orient" is loss_theta (confidence loss plus the
        # weighted orientation loss), not orient_loss on its own.
        if curr_batch % 20 == 0:
            print("--- epoch %s | batch %s/%s --- [loss: %s] [loss_dim: %s] [loss_orient: %s]" \
                  %(epoch, curr_batch, total_num_batches, round(loss.item(),4),round(dim_loss.item(),4),round(loss_theta.item(),4)))

    # save after every 10 epochs
    if epoch % 10 == 0:
        name = os.path.join(model_path, 'epoch_%s.pkl' % epoch)
        print("====================")
        print ("Done with epoch %s!" % epoch)
        print ("Saving weights as %s ..." % name)
        # 'loss' holds the loss tensor of the last batch of this epoch.
        torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': opt_SGD.state_dict(),
                'loss': loss
                }, name)
        print("====================")

--- epoch 1 | batch 0/68 --- [loss: 0.5713] [loss_dim: 0.0249] [loss_orient: 0.5563]
--- epoch 1 | batch 20/68 --- [loss: 0.5266] [loss_dim: 0.0207] [loss_orient: 0.5142]
--- epoch 1 | batch 40/68 --- [loss: 0.5209] [loss_dim: 0.0222] [loss_orient: 0.5076]
--- epoch 1 | batch 60/68 --- [loss: 0.4525] [loss_dim: 0.0208] [loss_orient: 0.44]
--- epoch 2 | batch 0/68 --- [loss: 0.4939] [loss_dim: 0.0175] [loss_orient: 0.4834]
--- epoch 2 | batch 20/68 --- [loss: 0.4049] [loss_dim: 0.0167] [loss_orient: 0.3949]
--- epoch 2 | batch 40/68 --- [loss: 0.3974] [loss_dim: 0.0169] [loss_orient: 0.3873]
--- epoch 2 | batch 60/68 --- [loss: 0.4153] [loss_dim: 0.0157] [loss_orient: 0.4059]
--- epoch 3 | batch 0/68 --- [loss: 0.3513] [loss_dim: 0.0149] [loss_orient: 0.3423]
--- epoch 3 | batch 20/68 --- [loss: 0.3099] [loss_dim: 0.015] [loss_orient: 0.3008]
--- epoch 3 | batch 40/68 --- [loss: 0.3358] [loss_dim: 0.0125] [loss_orient: 0.3283]
--- epoch 3 | batch 60/68 --- [loss: 0.2474] [loss_dim: 0.01

KeyboardInterrupt: 