In [1]:
%matplotlib inline

from config import configs

import os
import io
import numpy as np
import matplotlib
import timeit
import random
import cv2
import deepdish as dd

import matplotlib.pyplot as plt

from pprint import pprint
from datetime import datetime
from pixor_targets import PIXORTargets
from skimage.transform import resize
from scipy.interpolate import griddata
from core.kitti import KITTI
from pixor_utils.model_utils import load_model, save_model
from data_utils.training_gen import TrainingGenerator
from data_utils.generator import Generator, KITTIGen
from tt import bev
from data_utils.generator import KITTIGen
from pixor_utils.post_processing import nms_bev
from test_utils.unittest import test_pc_encoder, test_target_encoder
from encoding_utils.pointcloud_encoder import OccupancyCuboidKITTI

In [2]:
DS_DIR = os.path.expanduser(configs['dataset_path'])

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = configs["gpu_id"]
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 

In [3]:
# Point Cloud Encoder
INPUT_SHAPE = configs['input_shape']

# Training
BATCH_SIZE = configs['hyperparams']['batch_size']
LEARNING_RATE = configs['hyperparams']['lr']
EPOCHS = configs['hyperparams']['epochs']
NUM_THREADS = configs['hyperparams']['num_threads']
MAX_Q_SIZE = configs['hyperparams']['max_q_size']

In [4]:
kitti = KITTI(DS_DIR, configs['training_target'])

train_ids = kitti.get_ids('train')
val_ids = kitti.get_ids('val')
micro_ids = kitti.get_ids('micro')

In [5]:
pc_encoder = OccupancyCuboidKITTI(0, 70, -40, 40, -1, 3, [0.1, 0.1, 0.4]) # 0.1, 0.1, 0.4

In [6]:
target_encoder = PIXORTargets(shape=(200, 175), 
                              stats=dd.io.load('kitti_stats/stats.h5'),
                              P_WIDTH=70, P_HEIGHT=80, P_DEPTH=4, 
                              subsampling_factor=(0.8, 1.2))

In [7]:
train_gen = KITTIGen(kitti, train_ids, BATCH_SIZE, pc_encoder=pc_encoder, target_encoder=target_encoder)

In [8]:
for batch in train_gen:
    bev, depth_map, intensity_map, height_map, (obj, geo) = batch
    print('bev.shape          ', bev.shape)
    print('depth_map.shape    ', depth_map.shape)
    print('intensity_map.shape', intensity_map.shape)
    print('height_map.shape   ', height_map.shape)
    print('obj.shape          ', obj.shape)
    print('geo.shape          ', geo.shape)
    print('------------------------')


bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
-----------------

bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
------------------------
bev.shape           (2, 800, 700, 10)
depth_map.shape     (2, 375, 1242, 3)
intensity_map.shape (2, 375, 1242, 3)
height_map.shape    (2, 375, 1242, 3)
obj.shape           (2, 200, 175, 2)
geo.shape           (2, 200, 175, 12)
-----------------

KeyboardInterrupt: 