In [3]:
import numpy as np
import os
import sys
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.functional as F
import cv2
import PIL.Image as Image
import kitti_object_tracking
import kitti_util_tracking
from collections import namedtuple
from image import *
import math
from torch.utils.data import DataLoader

In [4]:
# Dataset root directory. Set the KITTI_TRACKING_ROOT environment variable to
# point at a different location; otherwise the original local path is used.
train_root = os.environ.get('KITTI_TRACKING_ROOT', 'G:\\KITTI\\tracking')

In [5]:
# The kitti_object obtained here covers the entire training split
kitti_object = kitti_object_tracking.kitti_object_tracking(train_root, split='training')

思考一下dataset中需要准备什么数据：

- stereo图像：image_02与image_03 (check✔)
- 图像的calibration (check✔)
- 图像的labels (check✔)
- 图像尺寸处理 (check✔)
- 图像的heatmap（需要再生成✔）

getitem只能接受index，那么如何实现跨sequence读取图片呢？(✔)
如何将不同duration的图片划分为train set以及validation set呢？(✔)

- Loss function如何组织？
- 与此同时，dataset应该如何改进？
- depth是否能够进一步改进？

In [6]:
class tracking_dataset(data.Dataset):
    """KITTI tracking stereo dataset that builds CenterNet-style targets.

    Each sample is a dict holding the resized stereo pair (``image2``,
    ``image3``) plus fixed-size ground-truth arrays (heatmap, centre index,
    category, 3D dimensions, depth and bin-based orientation), each padded
    to ``max_objs`` object slots.

    The 21 sequences are split by fold: fold ``ki`` (1-based) holds out
    sequences ``3*(ki-1) .. 3*ki-1`` for validation and trains on the rest.
    """
    # Only three categories are tracked
    num_categories = 3
    # Network input resolution [height, width], chosen to fit the network architecture
    default_resolution = [384, 1280]
    # Resolution [height, width] of the output maps
    output_resolution = [96, 320]
    class_name = ['Pedestrian', 'Car', 'Cyclist']
    # ['Pedestrian', 'Car', 'Cyclist', 'Van', 'Truck',  'Person_sitting', 'Tram', 'Misc', 'DontCare']
    # Negative values are handled by their absolute value
    # 7, 8 and 9 are ignored
    cat_ids = {1: 1, 2: 2, 3: 3, 4: -2, 5: -2, 6: -1, 7: -9999, 8: -9999, 9: 0}
    # Maximum number of objects kept for one frame
    max_objs = 50

    def __init__(self, kitti_object, root_dir, ki, K, typ='train'):
        """
            kitti_object: KITTI tracking dataset object
            root_dir: data directory
            ki: index of the current fold (1-based)
            K: total number of folds (stored only; the fold size is fixed at 3 sequences)
            typ: dataset type, either 'train' or 'val' (anything other than 'train' selects val)
        """
        # Number of frames in each of the 21 sequences
        self.duration_frames = [154, 447, 233, 144, 314, 297, 270, 800, 390, 803, 294, 373, 78, 340, 106, 376, 209, 145, 339, 1059, 837]
        self.kitti_object = kitti_object
        self.root_dir = root_dir
        self.ki = ki
        self.K = K
        self.typ = typ
        # Sequences held out for validation
        self.val_durations = list(range(3 * (ki - 1), 3 * ki))
        # Sequences used for training
        self.tra_durations = [x for x in range(21) if x not in self.val_durations]
        # Dataset length = total frame count of the selected sequences
        selected = self.tra_durations if typ == 'train' else self.val_durations
        self.len = sum(self.duration_frames[x] for x in selected)
        # Per-sequence caches, filled lazily by get_inp
        self.calibrations = {}
        self.labelObjects = {}

    def __getitem__(self, index):
        """Build the sample dict for the flat dataset index ``index``."""
        sequence, index = self.get_sqAndIdx(index)
        image2, image3, calib, labels = self.get_inp(sequence, index)

        # PIL reports size as (width, height)
        width, height = image2.size
        out_h, out_w = self.output_resolution

        # Images are read through PIL and are used as-is (no BGR conversion)
        ret = {'image2': self._to_network_input(image2),
               'image3': self._to_network_input(image3)}
        self._init_ret(ret)
        num_objs = min(len(labels), self.max_objs)
        # Scale factors (x, y) from the original image to the output maps
        scale_out = np.array((out_w / width, out_h / height), dtype=np.float32)
        for i in range(num_objs):
            cat_id = labels[i].type
            if cat_id > self.num_categories or cat_id < -999:
                continue
            # Category 0 (DontCare in cat_ids) is documented as ignored; without
            # this guard it would be drawn into hm[-1], the last category.
            if cat_id == 0:
                continue
            cat_id = abs(cat_id)
            ret['cat'][i] = cat_id
            ret['mask'][i] = 1  # marks this object slot as valid
            # Project the 3D location into the image and rescale to the output map
            box_centerPoint = kitti_util_tracking.project_to_image(np.array(labels[i].t).reshape(1, 3), calib.P)
            box_centerPoint = np.array([box_centerPoint[0][0] * scale_out[0], box_centerPoint[0][1] * scale_out[1]], dtype=np.int64)
            # NOTE(review): a centre projected outside the image gives an index
            # outside [0, out_h * out_w) - confirm how consumers handle this.
            ret['ind'][i] = box_centerPoint[1] * out_w + box_centerPoint[0]
            # 3D size packed as (h, w, l)
            ret['dim'][i] = np.array([labels[i].h, labels[i].w, labels[i].l], dtype=np.float32)
            ret['dim_mask'][i] = 1
            ret['dep'][i] = labels[i].t[2]
            ret['dep_mask'][i] = 1

            # Heatmap: the Gaussian radius is derived from the scaled 2D box size
            box_2d = labels[i].box2d
            h, w = box_2d[3] - box_2d[1], box_2d[2] - box_2d[0]  # before scaling
            h, w = h * scale_out[1], w * scale_out[0]  # after scaling
            radius = gaussian_radius((math.ceil(h), math.ceil(w)))
            radius = max(0, int(radius))
            draw_umich_gaussian(ret['hm'][cat_id - 1], box_centerPoint, radius)

            # Orientation, bin based (see CenterNet): two overlapping bins with
            # centres at -pi/2 and +pi/2; the residual is the offset from the centre.
            ret['rot_mask'][i] = 1
            ry = labels[i].ry
            if ry < np.pi / 6. or ry > 5 * np.pi / 6.:
                ret['rotbin'][i, 0] = 1
                ret['rotres'][i, 0] = ry - (-0.5 * np.pi)
            if ry > -np.pi / 6. or ry < -5 * np.pi / 6.:
                ret['rotbin'][i, 1] = 1
                ret['rotres'][i, 1] = ry - (0.5 * np.pi)

        return ret

    def __len__(self):
        """Total number of frames in the selected split."""
        return self.len

    def _to_network_input(self, image):
        """Resize an image to ``default_resolution`` and return it as a CHW array."""
        net_h, net_w = self.default_resolution
        resized = cv2.resize(np.array(image), (net_w, net_h), interpolation=cv2.INTER_CUBIC)
        return resized.transpose(2, 0, 1)

    def _init_ret(self, ret):
        """Add zero-initialised ground-truth arrays to ``ret`` (modified in place)."""
        out_h, out_w = self.output_resolution
        ret['hm'] = np.zeros((self.num_categories, out_h, out_w), np.float32)
        ret['ind'] = np.zeros(self.max_objs, dtype=np.int64)
        ret['cat'] = np.zeros(self.max_objs, dtype=np.int64)
        ret['mask'] = np.zeros(self.max_objs, dtype=np.float32)
        ret['dim'] = np.zeros((self.max_objs, 3), dtype=np.float32)
        ret['dim_mask'] = np.zeros((self.max_objs, 3), dtype=np.float32)
        ret['dep'] = np.zeros(self.max_objs, dtype=np.float32)
        ret['dep_mask'] = np.zeros(self.max_objs, dtype=np.float32)
        ret['rotbin'] = np.zeros((self.max_objs, 2), dtype=np.int64)
        ret['rotres'] = np.zeros((self.max_objs, 2), dtype=np.float32)
        ret['rot_mask'] = np.zeros((self.max_objs), dtype=np.float32)

    def get_sqAndIdx(self, index):
        """
        Map a flat, 0-based dataset index to ``(sequence, frame index)``.

        Sequences of the active split are walked in order; the returned frame
        index is 0-based within its sequence.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < self.len:
            raise IndexError('index %d is out of range for a dataset of length %d' % (index, self.len))
        durations = self.tra_durations if self.typ == 'train' else self.val_durations
        for sequence in durations:
            n_frames = self.duration_frames[sequence]
            # Strict '<': frame indices within a sequence run from 0 to n_frames - 1
            if index < n_frames:
                return sequence, index
            index -= n_frames
        # Not reachable after the range check above; kept as a safety net
        raise IndexError('index could not be mapped to a sequence')

    def get_inp(self, sequence, index):
        """Load the stereo pair, calibration and labels for one frame.

        Calibration and label lists are fetched once per sequence and cached.
        """
        # Left and right images
        image2 = self.kitti_object.get_image2(sequence, index)
        image3 = self.kitti_object.get_image3(sequence, index)

        # Calibration (cached per sequence)
        if sequence not in self.calibrations:
            self.calibrations[sequence] = self.kitti_object.get_calibration(sequence)
        calib = self.calibrations[sequence]

        # Labels (cached per sequence, then filtered down to this frame)
        if sequence not in self.labelObjects:
            self.labelObjects[sequence] = self.kitti_object.get_label_objects(sequence)
        labels = [x for x in self.labelObjects[sequence] if x.frame_idx == index]

        return image2, image3, calib, labels

In [7]:
# Training dataset: fold ki=1, so sequences 0-2 are held out for validation
train_dataset = tracking_dataset(kitti_object, root_dir=train_root, ki=1, K=7, typ='train')

In [8]:
# Batches of 32 samples; shuffle is left at the DataLoader default, so frames come in dataset order
train_loader = DataLoader(train_dataset, batch_size=32)

In [13]:
# Sanity check on the first batch only: stack the two views along the
# channel axis and print the resulting shape.
for idx, val in enumerate(train_loader):
    image2, image3 = val['image2'], val['image3']
    stereo_img = torch.cat([image2, image3], dim=1)
    print(stereo_img.shape)
    break

torch.Size([32, 6, 384, 1280])


In [10]:
# Fetch a single sample (flat index 15) for inspection
frame1 = train_dataset[15]

In [11]:
# List the fields present in one sample dict
frame1.keys()

dict_keys(['image2', 'image3', 'hm', 'ind', 'cat', 'mask', 'dim', 'dim_mask', 'dep', 'dep_mask', 'rotbin', 'rotres', 'rot_mask'])

In [158]:
# Resized image_02 frame of the sample
image2 = frame1['image2']

In [159]:
image2.shape
# image size after resizing

(3, 384, 1280)

In [160]:
# Move channels last (CHW -> HWC) before handing the array to PIL
img = Image.fromarray(image2.transpose(1,2,0))

In [161]:
# Visual check of the resized image
img.show()

In [139]:
hm = frame1['hm']  # heatmap, one channel per category

In [140]:
# (num_categories, output height, output width)
hm.shape

(3, 96, 320)

In [141]:
# Channel 1 is the heatmap for category id 2 ('Car' in class_name); multiplied by 255 for viewing
heatmap = Image.fromarray(hm[1]*255)

In [142]:
# Visual check of the heatmap channel
heatmap.show()

In [143]:
# ind: flattened position (y * output_width + x) of each object centre on the output map
ind = frame1['ind']

In [144]:
ind  # only three objects in this frame

array([24223, 17327, 16698,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0,     0,     0,     0,     0,
           0,     0,     0,     0,     0], dtype=int64)

In [146]:
# cat: category id stored for each object slot
cat = frame1['cat']

In [148]:
cat  # object categories

array([2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0], dtype=int64)

In [149]:
# dim: 3D object size, stored as (h, w, l)
dim = frame1['dim']

In [151]:
# dep: z component of each object's 3D location
dep = frame1['dep']

In [152]:
dep  # in a sense this is not the true depth, but since it is only needed for association, an equivalent value is enough

array([ 9.269818, 34.40579 , 38.607296,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ,  0.      ,  0.      ,  0.      ,  0.      ,
        0.      ,  0.      ], dtype=float32)

In [14]:
# rot
# rotbin indicates which bin the vehicle's orientation falls into
rotbin = frame1['rotbin']

In [20]:
# One row per object slot, one column per orientation bin
rotbin

array([[1, 0],
       [0, 1],
       [0, 1],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0],
       [0, 0]], dtype=int64)

In [18]:
# rotres
# rotres stores the difference in radians between the angle and the centre of its bin
rotres = frame1['rotres']

In [21]:
# One row per object slot, one residual column per orientation bin
rotres

array([[-0.03889867,  0.        ],
       [ 0.        , -0.02543833],
       [ 0.        , -0.01674733],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,  0.        ],
       [ 0.        ,