# Play with Trajectory_Dataloader class #2

In [36]:
import os
import pickle
import random
import time
import numpy as np
import torch

In [37]:
# Maps the test-set name given in args.test_set to its numeric id
# (ids follow the order of this list).
DATASET_NAME_TO_NUM = {
    name: num for num, name in enumerate(['eth', 'hotel', 'zara1', 'zara2', 'univ'])
}

In [92]:
class Object:
    """Empty attribute container used below as a stand-in for ``args`` and ``self``."""

# Hand-built replacement for the command-line arguments of the DataLoader.
args = Object()
args.dataset = 'eth5'
args.test_set = 'zara1'
args.save_dir = './output/' + args.test_set + "/"
args.seq_length = 20
args.batch_size = 8
# exist_ok=True avoids the race between a separate existence check and the
# directory creation (and is a no-op when the directory already exists).
os.makedirs(args.save_dir, exist_ok=True)

# Stand-in for the DataLoader instance so method bodies can run as notebook cells.
self = Object()
self.args = args


In [93]:
if self.args.dataset == 'eth5':

    self.data_dirs = ['../data/eth/univ', '../data/eth/hotel',
                      '../data/ucy/zara/zara01', '../data/ucy/zara/zara02',
                      '../data/ucy/univ/students001', '../data/ucy/univ/students003',
                      '../data/ucy/univ/uni_examples', '../data/ucy/zara/zara03']

    # Data directory where the pre-processed pickle file resides
    self.data_dir = './data'
    # Per-directory skip values, aligned index-for-index with self.data_dirs.
    skip = [6, 10, 10, 10, 10, 10, 10, 10]

    # Start with every directory; the held-out test directories are removed below.
    train_set = list(range(len(self.data_dirs)))

    # self.args.test_set is converted in place from a name to its numeric id.
    # An already-converted id is accepted too, so re-running this cell works.
    if self.args.test_set in DATASET_NAME_TO_NUM:
        self.args.test_set = DATASET_NAME_TO_NUM[self.args.test_set]
    elif self.args.test_set not in DATASET_NAME_TO_NUM.values():
        # Explicit raise instead of assert: asserts are stripped under `python -O`.
        raise ValueError('Unsupported dataset {}'.format(self.args.test_set))

    # 'univ' (id 4) covers two directories, students001 and students003
    # (data_dirs[4] and data_dirs[5]), so both are held out together.
    if self.args.test_set in (4, 5):
        self.test_set = [4, 5]
    else:
        self.test_set = [self.args.test_set]

    train_set = [x for x in train_set if x not in self.test_set]

    self.train_dir = [self.data_dirs[x] for x in train_set]
    self.test_dir = [self.data_dirs[x] for x in self.test_set]
    self.trainskip = [skip[x] for x in train_set]
    self.testskip = [skip[x] for x in self.test_set]
else:
    raise NotImplementedError('Unsupported dataset {}'.format(self.args.dataset))

self.train_data_file = os.path.join(self.args.save_dir, "train_trajectories.cpkl")
self.test_data_file = os.path.join(self.args.save_dir, "test_trajectories.cpkl")
self.train_batch_cache = os.path.join(self.args.save_dir, "train_batch_cache.cpkl")
self.test_batch_cache = os.path.join(self.args.save_dir, "test_batch_cache.cpkl")

## 1. dataPreprocess

In [94]:
setname = "train"  # split used by the cells below: 'train', anything else means test

In [95]:
# Pick the directories and trajectory pickle that belong to the chosen split.
data_dirs, data_file = (
    (self.train_dir, self.train_data_file) if setname == 'train'
    else (self.test_dir, self.test_data_file)
)


def load_dict(data_file):
    '''
    Load a trajectory pickle and return its first two elements.

    :param data_file: path to a pickle file (e.g. self.train_data_file or
        self.test_data_file).
    :returns: (frameped_dict, pedtraject_dict), i.e. elements 0 and 1 of the
        unpickled object.
    '''
    # NOTE: pickle.load can execute arbitrary code; only load trusted files.
    # The `with` block closes the file even if unpickling fails.
    with open(data_file, 'rb') as f:
        raw_data = pickle.load(f)

    frameped_dict = raw_data[0]
    pedtraject_dict = raw_data[1]

    return frameped_dict, pedtraject_dict

In [96]:
# Unpack the two dicts stored in the selected split's trajectory pickle.
(frameped_dict, pedtraject_dict) = load_dict(data_file)

In [97]:
def dataPreprocess(self, setname):
    '''
    Build the batches for one split and cache them to disk.

    Takes the frame/pedestrian dicts for ``setname`` from ``self`` and gets
    the sampling index from ``self.get_data_index``, shuffled only for the
    training split. The index columns are split into a validation part and a
    training part according to ``val_fraction``. Each part is turned into
    batches with ``self.get_seq_from_index_balance``. The result is pickled to
    the split's batch-cache file as
    ``(trainbatch, trainbatchnums, valbatch, valbatchnums)``.

    :param setname: 'train' selects the training split; any other value
        selects the test split.
    '''
    is_train = setname == 'train'
    if is_train:
        frameped_dict = self.frameped_dict
        pedtraject_dict = self.pedtraject_dict
        cachefile = self.train_batch_cache
    else:
        frameped_dict = self.test_frameped_dict
        pedtraject_dict = self.test_pedtraject_dict
        cachefile = self.test_batch_cache
    # Both splits currently use no validation part.
    val_fraction = 0

    # Only the training split is shuffled.
    data_index = self.get_data_index(frameped_dict, setname, ifshuffle=is_train)
    # Columns [0, n_val) go to validation and [n_val, end) to training, so no
    # sample falls between the two parts.
    n_val = int(data_index.shape[1] * val_fraction)
    val_index = data_index[:, :n_val]
    train_index = data_index[:, n_val:]
    trainbatch = self.get_seq_from_index_balance(frameped_dict, pedtraject_dict, train_index, setname)
    valbatch = self.get_seq_from_index_balance(frameped_dict, pedtraject_dict, val_index, setname)
    trainbatchnums = len(trainbatch)
    valbatchnums = len(valbatch)

    with open(cachefile, "wb") as f:
        pickle.dump((trainbatch, trainbatchnums, valbatch, valbatchnums), f, protocol=2)

## 2. get_data_index

In [141]:
def get_data_index(self, data_dict, setname, ifshuffle=True):
    '''
    Build the sampling index over all scenes of a split.

    :param data_dict: sequence of per-scene dicts whose keys are frame ids.
    :param setname: when 'train', the first ``self.args.batch_size`` columns
        are appended again at the end; otherwise nothing is appended.
    :param ifshuffle: shuffle the column order when True.
    :returns: int array with 3 rows:
        --> row 0: frame id within its scene
        --> row 1: index of the scene in ``data_dict``
        --> row 2: running frame id over all scenes (assigned before shuffling)
    '''
    frame_id_in_set = []  # candidate start frames of every scene, concatenated
    set_id = []  # scene index of each entry in frame_id_in_set

    for seti, scene in enumerate(data_dict):
        frames = sorted(scene)
        if not frames:
            # An empty scene contributes no start frames (max() would raise).
            continue
        # NOTE(review): this subtracts seq_length frame *ids*, not seq_length
        # time steps. When frame ids are spaced more than 1 apart, start frames
        # too close to the scene end are presumably still kept. Candidate fix
        # to confirm: max(frames) - self.args.seq_length * (frames[1] - frames[0]).
        maxframe = max(frames) - self.args.seq_length
        frames = [x for x in frames if x <= maxframe]
        frame_id_in_set.extend(frames)
        set_id.extend([seti] * len(frames))

    total_frame = len(frame_id_in_set)
    data_index = np.array([frame_id_in_set, set_id, list(range(total_frame))], dtype=int)

    order = list(range(total_frame))
    if ifshuffle:
        random.Random().shuffle(order)
    data_index = data_index[:, order]

    # To make full use of the data, append the first batch_size columns again at the end.
    if setname == 'train':
        data_index = np.append(data_index, data_index[:, :self.args.batch_size], 1)
    return data_index

In [142]:
# Unshuffled index for the selected split, so columns stay in scene/frame order.
data_index = get_data_index(self, data_dict=frameped_dict, setname=setname, ifshuffle=False)

In [145]:
# Show the sampling index as the cell output
# (rows: frame id in scene, scene index, running frame id).
data_index

array([[780, 786, 792, ..., 810, 816, 822],
       [  0,   0,   0, ...,   0,   0,   0],
       [  0,   1,   2, ...,   5,   6,   7]])