This notebook takes processed pickle files and builds a dataset for trajectory prediction.
In particular, given a history length and a prediction horizon, it generates all viable slices (snippets) of each full trajectory. It goes through all pickle files and writes the snippets as a TFRecord dataset split for k-fold cross-validation.

**Note**: This is just for reference.  Contact Vijay/Xu (emails in the README.md file) if you need access to the raw pickle data.

In [None]:
import glob
import numpy as np
import scipy.io as sio
import pickle
import matplotlib.pyplot as plt
import pdb
from datetime import datetime
import sklearn.utils as sku
from tqdm import tqdm

import traceback
from collections import defaultdict

import os
# GPU selection: number devices by PCI bus id and expose only device 0.
# Set here, ahead of the project imports below (presumably so any GPU framework
# they load picks these values up -- confirm).
os.environ.update(
    CUDA_DEVICE_ORDER="PCI_BUS_ID",  # see issue #152
    CUDA_VISIBLE_DEVICES="0",
)

from pkl_reader import *
from tfrecord_utils import write_tfrecord

from utils import get_parking_lot_image_hist

In [None]:
# CONFIG: constants controlling dataset generation.

# --- cross-validation / shuffling ---
num_folds_cv = 5    # number of k-fold CV splits; values <= 0 are meant to provide
                    # only the full set of data (check the writer cell handles it)
shuffle = True      # shuffle all snippets jointly (using `seed`) before splitting
seed = 0

# --- trajectory pruning (passed to extract_full_trajectory) ---
prune_start = True          # drop the stationary part of the ego trajectory at the start
prune_end = True            # drop the stationary part of the ego trajectory at the end
min_vel_thresh = 0.01       # ego speed (m/s) above which the ego counts as moving
exclude_collisions = True   # intended: an empty trajectory is returned on collision

# --- snippet geometry ---
Nhist = 5           # timesteps of motion history used as model input
Npred = 20          # timesteps in the prediction horizon
Nskip = 5           # stride of the sliding window used to pick snippets
dt = 0.1            # time step (s) of the full ego trajectory that the N* values count
ego_trans = True    # True: express snippets in the ego frame;
                    # False: use the global map frame for all snippets

In [None]:
# Locate every processed demonstration file under file_prefix.
save_ext = 'pkl'
file_prefix = '/media/data/carla_parking_data_bkcp/bags/'  # trailing '/' is relied on when output paths are built
search_str = file_prefix + '*.' + save_ext
# glob.glob returns entries in filesystem-dependent order. Sort so that the
# seeded shuffle, and therefore the contents of every CV fold, is reproducible
# across machines and runs.
files_to_process = sorted(glob.glob(search_str))
print('Found %d files to read: %s' % (len(files_to_process), files_to_process))

In [None]:
num_tfrecords = 0
skipped_files = 0
skipped_users = defaultdict(lambda: 0)  # skipped-file count per file-name token (see except branch)

# Full dataset for all files_to_process. The five *_combined lists are kept
# index-aligned (entry k of each belongs to the same snippet): they are shuffled
# together below and sliced with the same indices when the folds are written.
features_combined = []
features_global_combined = []
labels_combined = []
goal_snpts_combined = []
static_objs_combined = []

# Overwritten by every file read, so after the loop these hold the last file's
# values (presumably identical across files -- TODO confirm).
parking_lot = None
ego_dims    = None

for file in files_to_process:
    if save_ext != 'pkl':
        # NotImplemented is a constant, not an exception class; calling it
        # raises an unrelated TypeError. NotImplementedError is the intended one.
        raise NotImplementedError('Invalid Extension')
    # `with` closes the handle promptly instead of leaking one per file.
    # NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
    with open(file, 'rb') as pkl_file:
        res_dict = pickle.load(pkl_file)

    parking_lot = res_dict['parking_lot']
    ego_dims = res_dict['ego_dimensions']

    try:
        # Goal extraction is part of parsing this demonstration, so a failure
        # here skips the file like the other parse errors do.
        goals = extract_goals(res_dict)
        # Explicit raises rather than `assert`, which is stripped under `python -O`.
        if goals.shape[0] != 32:
            raise ValueError("Invalid goal shape.")
        if len(res_dict['vehicle_object_lists'][0]) != 56:
            raise ValueError("Wrong number of static vehicles.")

        # parse one demonstration
        ego_trajectory, start_ind, switch_ind, end_ind, goal_ind = \
             extract_full_trajectory(res_dict, goals, prune_start, prune_end, \
                                     min_vel_thresh, exclude_collisions)

        # labels_global is returned by the helper but not stored in the dataset.
        features, features_global, labels, labels_global, goal_snpts = \
            get_ego_trajectory_prediction_snippets(ego_trajectory, start_ind, switch_ind, end_ind, goal_ind, \
                                           goals, Nhist, Npred, Nskip, dt, ego_frame=ego_trans)

        # Read every input before any list is extended. A missing key raised
        # after three lists had been extended would leave the *_combined lists
        # misaligned and break the joint shuffle.
        static_object_list = res_dict['static_object_list']
    except Exception as e:
        print(file, e)
        skipped_files += 1
        traceback.print_exc()
        # Key by the second '_'-separated token of the basename (presumably the
        # user id -- verify against the file naming scheme). Splitting the
        # basename, not the full path, keeps the index independent of how many
        # underscores file_prefix contains. A short name must not raise here and
        # abort the whole run.
        name_tokens = os.path.basename(file).split('_')
        skipped_users[name_tokens[1] if len(name_tokens) > 1 else '<unknown>'] += 1
    else:
        features_combined.extend(features)
        features_global_combined.extend(features_global)
        labels_combined.extend(labels)
        goal_snpts_combined.extend(goal_snpts)
        # One reference to this file's static objects per snippet.
        static_objs_combined.extend([static_object_list] * len(features))

print("Num of skipped files due to exception: ", skipped_files)
print("Num of skipped files by user: ", dict(skipped_users))

if shuffle:
    # Shuffle all five lists with one permutation so that they stay aligned.
    features_combined, features_global_combined, labels_combined, goal_snpts_combined, static_objs_combined = \
        sku.shuffle(features_combined, features_global_combined, labels_combined, goal_snpts_combined, static_objs_combined, random_state=seed)

In [None]:
# Write the shuffled snippets to .tfrecord files: one group per CV fold, or a
# single "full" group when num_folds_cv <= 0. Each group is split into files of
# at most batch_size snippets.
N_instances = len(features_combined)

# Snippets written per .tfrecord file (TODO: tune).
batch_size = 100

# Each entry: (log label, file-name stem, start, end), where [start, end) indexes
# the shuffled *_combined lists.
if num_folds_cv > 0:
    # Fold sizes differ by at most one: the first N_instances % num_folds_cv
    # folds receive one extra instance.
    splits = (N_instances // num_folds_cv) * np.ones(num_folds_cv)
    splits[:N_instances % num_folds_cv] += 1

    # np.int was deprecated in NumPy 1.20 and removed in 1.24; builtin int is
    # the documented replacement.
    ind_limits = np.cumsum(splits).astype(int)
    ind_starts = [0] + [int(x) for x in ind_limits[:-1]]
    fold_specs = [('Fold ' + str(i), 'dataset_fold_' + str(i), ind_starts[i], int(ind_limits[i]))
                  for i in range(num_folds_cv)]
else:
    # No CV split requested: provide the full set of data as a single group.
    fold_specs = [('Full', 'dataset_full', 0, N_instances)]

for fold_label, fold_stem, ind_start, ind_end in fold_specs:
    print(fold_label, ind_start, ind_end)

    # Write tfrecords in batches covering [ind_start, ind_end).
    for batch_ind, j_min in enumerate(range(ind_start, ind_end, batch_size)):
        j_max = min(j_min + batch_size, ind_end)

        print('Batch', batch_ind, j_min, j_max)

        # Reassigned on every batch; later cells inspect the last batch written.
        img_hists_batch = np.array(
            [get_parking_lot_image_hist(parking_lot,
                                        static_objs_combined[k],
                                        features_global_combined[k],
                                        ego_dims, resize_factor=0.5) for k in range(j_min, j_max)])

        print('Img Hist Shape', img_hists_batch.shape)

        # e.g. <file_prefix>dataset_fold_<i>_<batch>.tfrecord
        file_location = file_prefix + fold_stem + '_' + str(batch_ind) + '.tfrecord'

        print('Saving to ', file_location)

        write_tfrecord(features_combined[j_min:j_max],
                       img_hists_batch,
                       labels_combined[j_min:j_max],
                       goal_snpts_combined[j_min:j_max],
                       file_location, {})

In [None]:

# For visualization/debugging: show up to five snippets from the LAST batch
# written above. img_hists_batch is reassigned on every batch, so these are not
# the first snippets of the dataset.
n_show = min(5, len(img_hists_batch))  # the last batch may hold fewer than 5 snippets
n_frames = img_hists_batch.shape[1]    # one subplot per history frame (axis 1)
for k in range(n_show):
    print('Instance', k)
    plt.figure(figsize=(10, 10), dpi=500, facecolor='w', edgecolor='k')
    for i in range(n_frames):
        plt.subplot(1, n_frames, i+1)
        plt.imshow(img_hists_batch[k, i, :, :, :])
    plt.tight_layout()
    plt.show()

# Disabled debugging snippet, kept as a string. NOTE(review): read_tfrecord is
# not imported by name in the import cell; confirm its source before enabling.
'''
# For debugging: see an arbitrary snippet from the generated tfrecord.
image, feature, label, goal, count = read_tfrecord([file_location])

plt.figure(figsize=(10, 10), dpi=160, facecolor='w', edgecolor='k')
for i in range(Nhist):
    plt.subplot(1, Nhist, i+1)
    plt.imshow(image.numpy()[0, i, :, :, :])
plt.tight_layout()
plt.show()
'''

In [None]:
# Notebook display: shape of the last image-history batch built by the writer cell.
img_hists_batch.shape
# Disabled resize experiment. NOTE(review): `Image` (presumably PIL.Image) is not
# imported anywhere in this notebook; add the import before enabling.
# img_test = Image.fromarray(img_hists_batch[0, 0, :, :, :])
# img_test = img_test.resize((100, 325))
# img_test = np.asarray(img_test)

# plt.figure(figsize=(10, 10), dpi=100, facecolor='w', edgecolor='k')
# plt.axis('equal')
# plt.imshow(img_test)
# plt.show()