In [None]:
import argparse
import ast
import os
import yaml
import csv
import pickle
import random
import numpy as np

import time
import datetime
import h5py
import pandas as pd
from pathlib import Path

import sys

import pneumapackage.iodata as rd
import pneumapackage.mapmatching as mm
from pneumapackage.__init__ import write_pickle, path_data

# Seed NumPy's global RNG so np.random-based draws (e.g. the random test split in
# gen_traintest_data) are repeatable. This does not seed Python's `random` module.
np.random.seed(10)

# --- Small dataset generation (read by gen_data) ---
# Set to True to write a reduced copy of `smallgen_file` into data_pneuma/trajectory_small.
newdata_generation = False
num_vechicles = 100  # number of leading CSV rows copied; row 0 counts towards this number
smallgen_file = '20181024_d1_0930_1000.csv'

# --- Train/test split generation (read by gen_traintest_data) ---
# Set to True to split `traintest_file` into data_pneuma/trajectory_train and trajectory_test.
traintest_generation = False
test_size = 90  # number of randomly drawn rows that go to the test file
traintest_file = '20181024_d1_1000_1030.csv'

# --- Reduced map generation (read by create_small_map) ---
# Set to True to write `smallmap_file` next to the full map in data_pneuma/map.
smallmap_generation = False
extend_map = True  # if True, the cropped map is re-selected on the idx_x/idx_y grid and widened
smallmap_file = 'drivable_map_d1_ex.csv'

# --- Dataset selection ---
# all_data_to_hdf reads the full folder by default, the small folder if `small_dataset`,
# and the test folder if `test_dataset` (the test flag wins when both are set).
# Trajectory_Dataloader also skips all training-set work when `test_dataset` is True.
small_dataset = False
test_dataset = True
file_group = '1024_d1_1000' # day_DSid_starttime of file we use

# --- Motorcycle-free dataset generation (read by no_motorcyle_csv) ---
# Set to True to write a copy of `nomotor_file` without motorcycle rows.
no_motorcycle_generation = False
nomotor_file = '20181024_d1_0900_0930.csv'

# Dataset generation
Create a smaller dataset in the directory data_pneuma/trajectory_small, with fewer vehicles

In [None]:
# create a smaller dataset in data_pneuma/trajectory_small from the first 'num_vechicles' rows
def gen_data():
    """Copy the first ``num_vechicles`` CSV rows of ``smallgen_file`` into the small-dataset folder.

    Reads ``data_pneuma/trajectory/<smallgen_file>`` and writes
    ``data_pneuma/trajectory_small/<smallgen_file>``. Row 0 of the source counts
    towards ``num_vechicles``, exactly as before.

    If the source has fewer rows than ``num_vechicles`` every row is copied
    (previously this raised ``IndexError``). Rows are streamed, so the source is
    never held in memory as a whole.

    Raises:
        OSError: if the source cannot be read or the destination cannot be written.
    """
    source_path = os.path.join('data_pneuma/trajectory', smallgen_file)  # current full DS file
    target_path = os.path.join('data_pneuma/trajectory_small', smallgen_file)  # new small DS file

    # Trajectory rows are very long; lift the csv module's per-field size limit.
    csv.field_size_limit(sys.maxsize)

    # newline='' is what the csv module expects for both reading and writing; without it
    # the writer emits an extra '\r' per row on platforms that translate line endings.
    with open(source_path, 'r', newline='') as source, open(target_path, 'w', newline='') as target:
        writer = csv.writer(target)
        for row_number, row in enumerate(csv.reader(source)):
            if row_number >= num_vechicles:
                break
            writer.writerow(row)


# Only regenerate the reduced dataset when enabled in the configuration cell.
if newdata_generation:
    gen_data()

In [None]:
# create one testing dataset with 'test_size' random vehicles
# and one training dataset with tot - 'test_size' vehicles
def gen_traintest_data():
    """Split ``traintest_file`` into a random test file and a complementary train file.

    Reads ``data_pneuma/trajectory/<traintest_file>``, draws ``test_size`` distinct
    row indices (never row 0) with ``np.random.choice`` and writes

    * ``data_pneuma/trajectory_test/<traintest_file>``: row 0 plus the drawn rows,
    * ``data_pneuma/trajectory_train/<traintest_file>``: row 0 plus all other rows,

    both in source order. The draw uses NumPy's global RNG, so it is governed by
    ``np.random.seed``.

    Rows stay plain Python lists. The previous ``np.array(list(reader))`` conversion
    fails on rows of different lengths in current NumPy and otherwise allocates a
    fixed-width string array sized by the longest field.

    Raises:
        ValueError: from ``np.random.choice`` when the file has fewer than
            ``test_size`` data rows.
        OSError: if a file cannot be read or written.
    """
    source_path = os.path.join('data_pneuma/trajectory', traintest_file)
    train_path = os.path.join('data_pneuma/trajectory_train', traintest_file)
    test_path = os.path.join('data_pneuma/trajectory_test', traintest_file)

    # Trajectory rows are very long; lift the csv module's per-field size limit.
    csv.field_size_limit(sys.maxsize)

    # Get all rows from the full DS
    with open(source_path, 'r', newline='') as source:
        rows = list(csv.reader(source))

    # Same draw as before: test_size distinct indices from 1..len(rows)-1 (row 0 is the header).
    drawn = np.random.choice(np.arange(1, len(rows)), test_size, replace=False)
    test_indices = {int(index) for index in drawn}

    # Row 0 (column names) goes to both outputs; every other row goes to exactly one.
    test_rows = [rows[0]]
    train_rows = [rows[0]]
    for index in range(1, len(rows)):
        if index in test_indices:
            test_rows.append(rows[index])
        else:
            train_rows.append(rows[index])

    # Create new train csv file
    with open(train_path, 'w', newline='') as train_file:
        csv.writer(train_file).writerows(train_rows)

    # Create new test csv file
    with open(test_path, 'w', newline='') as test_file:
        csv.writer(test_file).writerows(test_rows)
    
        
# Only regenerate the train/test split when enabled in the configuration cell.
if traintest_generation:
    gen_traintest_data()

In [None]:
# create a csv file same as the original but without motorcycles
def no_motorcyle_csv():
    """Copy ``nomotor_file`` while dropping every row whose second field is a motorcycle.

    Reads ``data_pneuma/trajectory/<nomotor_file>`` (``;``-delimited) and writes
    ``data_pneuma/trajectory_nomotor/<nomotor_file>`` with the same delimiter.
    The first row (header) is always copied.

    Changes from the previous version:

    * the vehicle type is compared after stripping surrounding whitespace, so
      ``' Motorcycle'`` and ``'Motorcycle'`` are both removed;
    * rows with fewer than two fields (e.g. blank lines) are copied unchanged
      instead of raising ``IndexError``;
    * an empty source yields an empty output instead of raising ``IndexError``;
    * rows are streamed instead of being loaded into memory.

    Raises:
        OSError: if the source cannot be read or the destination cannot be written.
    """
    source_path = os.path.join('data_pneuma/trajectory', nomotor_file)  # current full DS file
    target_path = os.path.join('data_pneuma/trajectory_nomotor', nomotor_file)  # new DS file

    # Trajectory rows are very long; lift the csv module's per-field size limit.
    csv.field_size_limit(sys.maxsize)

    # newline='' is what the csv module expects for both reading and writing.
    with open(source_path, 'r', newline='') as source, open(target_path, 'w', newline='') as target:
        reader = csv.reader(source, delimiter=';')
        writer = csv.writer(target, delimiter=';')

        header = next(reader, None)
        if header is None:
            return  # empty source: nothing to copy
        writer.writerow(header)

        for row in reader:
            is_motorcycle = len(row) > 1 and row[1].strip() == 'Motorcycle'
            if not is_motorcycle:
                writer.writerow(row)
    

# Only regenerate the motorcycle-free dataset when enabled in the configuration cell.
if no_motorcycle_generation:
    no_motorcyle_csv()

# Trajectories creation
Create the HDF5 file with adjusted trajectories from the dataset, which will then be used to load the trajectories for further processing

In [None]:
# main: import every raw data group of the selected folder into the HDF5 store
def all_data_to_hdf():
    """Register the selected raw-data folder and make sure each of its groups is in HDF5.

    The folder is chosen from the configuration flags: the test folder when
    ``test_dataset`` is set, otherwise the small folder when ``small_dataset`` is
    set, otherwise the full trajectory folder. ``rd.initialize_data_paths`` is
    called on that folder and ``get_hdf_names`` is invoked once per key of the
    mapping it returns.
    """
    # test_dataset takes precedence over small_dataset, which takes precedence over the full set
    if test_dataset:
        raw_folder = 'data_pneuma/trajectory_test/'
    elif small_dataset:
        raw_folder = 'data_pneuma/trajectory_small/'
    else:
        raw_folder = 'data_pneuma/trajectory/'

    groups = rd.initialize_data_paths(raw_folder)

    for group_id in groups.keys():
        get_hdf_names(group_id)  # stores the group in HDF5 if it is not there yet

    print('All data stored in HDF5')

In [None]:
# helper: resolve the HDF5 keys of one data group, importing the group first if needed
def get_hdf_names(group_id):
    """Return the HDF5 keys ``(ids, original trajectories, adjusted trajectories)`` of a group.

    ``rd.get_group`` is asked for the group's base path first. If it raises
    ``KeyError``, ``rd.io_data`` is called instead and the second element of its
    result is used as the base path.

    Returns:
        tuple: ``(base + '/all_id', base + '/original_trajectories',
        base + '/adjusted_trajectories')``.
    """
    try:
        base_path = rd.get_group(group_id)
    except KeyError:
        # group unknown so far: let the reader import it, which also yields its path
        _, base_path = rd.io_data(group_id)

    suffixes = ('/all_id', '/original_trajectories', '/adjusted_trajectories')
    return tuple(base_path + suffix for suffix in suffixes)


# Runs unconditionally when the cell executes: import every group of the selected folder.
all_data_to_hdf()


# main: adjust raw trajectory data by group, store in hdf
def adjust_data_in_hdf():
    """Run ``adjust_traj`` for every group listed in the path dictionary and report the time.

    The group ids are the keys of ``rd.get_path_dict()['groups']``. The per-group
    result of ``adjust_traj`` is discarded; only its side effects matter here.
    """
    started = time.time()
    for group_id in rd.get_path_dict()['groups'].keys():
        print(f'Dataset: {group_id}')
        adjust_traj(group_id)
    elapsed = time.time() - started
    print(f'All datasets adjusted, took {elapsed} sec')
    
# helper: load the adjusted trajectories of one group, building them first when required
def adjust_traj(group_id, redo=False, bearing=True, resample=True, step=1000):
    """Return the adjusted trajectories of ``group_id`` as loaded by ``rd.get_from_hdf``.

    When ``redo`` is False the adjusted table is read directly; a ``KeyError``
    from that read triggers a rebuild. When ``redo`` is True the rebuild always
    runs. A rebuild loads the original trajectories, passes them to
    ``rd.new_dfs`` together with ``bearing``, ``resample`` and ``step``, and then
    reads the adjusted table.

    Args:
        group_id: key of the data group.
        redo: force the rebuild even if an adjusted table can be read.
        bearing, resample, step: forwarded unchanged to ``rd.new_dfs``.

    Returns:
        Whatever ``rd.get_from_hdf(..., result='list')`` returns for the adjusted table.
    """
    started = time.time()
    print('Start: …load trajectories (resampled)')
    hdf_path = rd.get_hdf_path()
    path_id, path_org, path_adj = get_hdf_names(group_id)

    def _read(table_key):
        # every read in this function uses the same store, id key and result format
        return rd.get_from_hdf(hdf_path, path_id, table_key, result='list')

    rebuild = redo
    if not rebuild:
        try:
            ldf = _read(path_adj)  # adjusted trajectories already stored
        except KeyError:
            rebuild = True

    if rebuild:
        print('Create adjusted trajectory table and save the file in HDF5')
        rd.new_dfs(_read(path_org), group_id, bearing=bearing, resample=resample, step=step)
        ldf = _read(path_adj)

    elapsed = time.time() - started
    print(f'Adjusted trajectories loaded, took {elapsed}')
    return ldf


# Runs unconditionally when the cell executes: adjust the trajectories of every stored group.
adjust_data_in_hdf()

# main: match the trajectories of every group to the map, store in hdf
# NOTE(review): the original comment says this step also generates the map files
# 'athensmap.dat' and 'athensmap.idx'; that is not visible in this notebook - confirm.
def match_all_data_in_hdf():
    """Run ``match_line_trajectories`` for every group in the path dictionary and report the time.

    The group ids are the keys of ``rd.get_path_dict()['groups']``. The matched
    trajectories returned per group are discarded here.
    """
    started = time.time()
    for group_id in rd.get_path_dict()['groups'].keys():
        print(f'Dataset: {group_id}')
        match_line_trajectories(group_id)
    elapsed = time.time() - started
    print(f'All datasets matched, took {elapsed} sec')


# helper: return the stored line-matched trajectories of a group, or run map-matching to create them
def match_line_trajectories(group_id, selection=None, reload=False, **kwargs):
    """Load (or compute and store) the line-matched trajectories of ``group_id``.

    Without ``reload`` the table ``<group>/mm_line_<step>ms`` is read through
    ``rd.get_from_hdf`` with ``selection`` passed as ``select_tr``. If that read
    raises ``KeyError`` or ``TypeError``, or if ``reload`` is True, map-matching
    is run instead: ``map_matching`` is called with ``**kwargs``, its tracks and
    network are wrapped in ``mm.TransformTrajectories``, and ``tracks_line`` is
    written to the same HDF5 key (replacing an existing table).

    Args:
        group_id: key of the data group.
        selection: optional query; applied via ``select_tr`` when reading, or via
            ``tracks_line.query`` on freshly matched data.
        reload: skip the stored table and always re-run map-matching.
        **kwargs: forwarded to ``map_matching``.

    Returns:
        The stored table, or the freshly computed ``tracks_line`` (filtered by
        ``selection`` when given).
    """
    started = time.time()
    print('Start: …load matched trajectories')
    hdf_path = rd.get_hdf_path()
    group_path = rd.get_path_dict()['groups'][group_id]

    run_matching = reload
    if not run_matching:
        try:
            step = rd.get_path_dict()['current_resample_step'][group_id]
            line_traj = rd.get_from_hdf(hdf_path,
                                        key_id=group_path + '/all_id',
                                        key_tr=group_path + f'/mm_line_{step}ms',
                                        result='df_all', select_tr=selection)
        except (KeyError, TypeError):
            # nothing usable stored yet: fall through to a fresh matching run
            run_matching = True

    if run_matching:
        match_result = map_matching(group_id=group_id, **kwargs)
        transformed = mm.TransformTrajectories(match_result['tracks'], match_result['network'])
        # the resample step is looked up after matching, as in the original flow
        step = rd.get_path_dict()['current_resample_step'][group_id]
        transformed.tracks_line.to_hdf(hdf_path, key=group_path + f'/mm_line_{step}ms',
                                       format='table', mode='a', append=False,
                                       data_columns=['time', 'u_match', 'v_match'])
        line_traj = transformed.tracks_line
        if selection is not None:
            line_traj = transformed.tracks_line.query(selection)

    elapsed = time.time() - started
    print(f'Matched trajectories loaded, took {elapsed} sec')
    return line_traj

# helper: do map-matching by group
def map_matching(group_id, traj_obj='/adjusted_trajectories', max_distance=10, rematch=False,
                      match_latlon=True, save_shp=False, path=path_data):
    """Return the map-matched trajectories of ``group_id`` together with the network object.

    The network is loaded with ``load_network``. A stored match is reused only if
    the HDF5 group carries an attribute ``tag_mapmatching_<step>`` equal to
    ``network_obj.mm_id[group_id]``; a missing attribute or key (``KeyError``) or
    a differing tag forces a re-match. Re-matching is also performed when the
    stored table ``<group>/mm_<step>ms`` cannot be read (``KeyError`` or
    ``TypeError``).

    A (re-)match reads the unmatched trajectories from ``<group><traj_obj>``,
    runs ``mm.MapMatching(...).match_variable_distance``, concatenates the
    result, writes it to ``<group>/mm_<step>ms``, stores a fresh timestamp tag
    both as HDF5 attribute and on the network object, and pickles the network
    via ``write_pickle``.

    Args:
        group_id: key of the data group.
        traj_obj: which trajectory table to match; must be
            ``'/original_trajectories'`` or ``'/adjusted_trajectories'``.
        max_distance: passed to ``mm.MapMatching`` as ``max_d``.
        rematch: force a new matching run.
        match_latlon: passed through to ``mm.MapMatching``.
        save_shp: if True, additionally export the used part of the network.
        path: base folder for the network pickle and the optional exports.

    Returns:
        dict: ``{'tracks': <matched DataFrame>, 'network': <network object>}``.

    Raises:
        ValueError: if ``traj_obj`` is not one of the two accepted table names.
    """
    tic = time.time()
    print('Start: …load map matching')
    # load network
    network_obj = load_network() # function defined in this notebook
    hdf_path = rd.get_hdf_path()
    group_path = rd.get_path_dict()['groups'][group_id]
    step = rd.get_path_dict()['current_resample_step'][group_id]
    if traj_obj not in ['/original_trajectories', '/adjusted_trajectories']:
        raise ValueError(f'traj_obj should be in ["/original_trajectories", "/adjusted_trajectories"]')
    # Decide whether the stored match is still valid: the tag saved on the HDF5 group
    # must equal the tag the network object holds for this group.
    try:
        with h5py.File(hdf_path, 'r') as s:
            tag_match = s[group_path].attrs[f'tag_mapmatching_{step}']
        if tag_match != network_obj.mm_id[group_id]:
            rematch = True
    except KeyError:
        # raised when the attribute (or the mm_id entry) does not exist yet
        rematch = True

    if not rematch:
        try:
            dfmatch_all = rd.get_from_hdf(hdf_path, key_id=group_path + '/all_id', key_tr=group_path + f'/mm_{step}ms',
                                          result='df_all')
        except (KeyError, TypeError):
            # Tags agree but the matched table could not be read: compute it now.
            # This branch repeats the matching steps of the `else` branch below.
            max_init_dist = int(max(network_obj.network_edges['length'])) + 1
            print(f'Initial distance: {max_init_dist} m, maximum distance (start): {max_distance} m')
            traj_unm = rd.get_from_hdf(hdf_path, key_id=group_path + '/all_id', key_tr=group_path + traj_obj,
                                       result='list')
            tmm = mm.MapMatching(traj_unm, network_obj, max_init=max_init_dist, max_d=max_distance,
                                 match_latlon=match_latlon)
            match_all = tmm.match_variable_distance(progress=False)
            dfmatch_all = pd.concat(match_all)
            step = rd.get_path_dict()['current_resample_step'][group_id]
            dfmatch_all.to_hdf(hdf_path, key=group_path + f'/mm_{step}ms', format='table', mode='a', append=False,
                               data_columns=['track_id', 'time', 'n1', 'n2', '_id'])
            # tag = current local time as an integer of the form YYYYMMDDHHMM
            dt = datetime.datetime.now()
            tag = int(dt.strftime('%Y%m%d%H%M'))
            with h5py.File(hdf_path, 'a') as s:
                s[group_path].attrs[f'tag_mapmatching_{step}'] = tag
            network_obj.add_mapmatch_tag(group_id, tag)
            write_pickle(network_obj, 'network', path)
    else:
        # Forced or required re-match: the initial search distance is the longest
        # network edge length (truncated to int) plus one.
        max_init_dist = int(max(network_obj.network_edges['length'])) + 1
        print(f'Initial distance: {max_init_dist} m, maximum distance (start): {max_distance} m')
        step = rd.get_path_dict()['current_resample_step'][group_id]
        traj_unm = rd.get_from_hdf(hdf_path, key_id=group_path + '/all_id', key_tr=group_path + traj_obj,
                                   result='list')
        tmm = mm.MapMatching(traj_unm, network_obj, max_init=max_init_dist, max_d=max_distance,
                             match_latlon=match_latlon)
        match_all = tmm.match_variable_distance(progress=False)
        dfmatch_all = pd.concat(match_all)
        step = rd.get_path_dict()['current_resample_step'][group_id]
        dfmatch_all.to_hdf(hdf_path, key=group_path + f'/mm_{step}ms', format='table', mode='a', append=False,
                           data_columns=['track_id', 'time', 'n1', 'n2', '_id'])
        # tag = current local time as an integer of the form YYYYMMDDHHMM
        dt = datetime.datetime.now()
        tag = int(dt.strftime('%Y%m%d%H%M'))
        with h5py.File(hdf_path, 'a') as s:
            s[group_path].attrs[f'tag_mapmatching_{step}'] = tag
        network_obj.add_mapmatch_tag(group_id, tag)
        write_pickle(network_obj, 'network', path)

    if save_shp:
        # Export only the edges whose '_id' occurs in the matched data; the 'edge'
        # column is left out of the file export.
        fn = path + f'/shapefiles/used_network_{group_id}_mm{step}'
        used_network = network_obj.network_edges[network_obj.network_edges['_id'].isin(dfmatch_all['_id'])]
        used_network.loc[:, ~used_network.columns.isin(['edge'])].to_file(filename=fn)
        Path(path + "/used_network").mkdir(parents=True, exist_ok=True)
        write_pickle(used_network, f'used_network_{group_id}_mm{step}', path=path + '/used_network')
        print('Shapefiles stored')

    toc = time.time()
    print(f'Map-matched trajectories loaded, took {toc - tic} sec')
    return {'tracks': dfmatch_all, 'network': network_obj}

# load the pickled network object from the data folder
def load_network(plot=False, save_to_shp=False):
    """Unpickle ``data_pneuma/network`` and return the resulting object.

    Args:
        plot: if True, call ``plot_network_lanes()`` on the loaded object.
        save_to_shp: if True, call ``save_graph_to_shp()`` on the loaded object.

    Returns:
        The unpickled network object.

    Note:
        ``pickle.load`` executes code embedded in the file; only use it on
        network files produced by this project.
    """
    network_file = 'data_pneuma/' + 'network'
    with open(network_file, 'rb') as handle:
        network = pickle.load(handle)

    if plot:
        network.plot_network_lanes()
    if save_to_shp:
        network.save_graph_to_shp()

    return network

# Runs unconditionally when the cell executes: map-match the trajectories of every stored group.
match_all_data_in_hdf()

# Small map generation
Create the csv file with the smaller map of current dataset (e.g. area 1 of pneuma dataset)

In [None]:
def create_small_map(extra_idx_x=7):
    """Write a cropped copy of the drivable map that covers the trajectories of ``file_group``.

    The bounding box is the min/max of the ``lon_1`` and ``lat_1`` columns of the
    group's ``/mm_line_1000ms`` table. All rows of
    ``data_pneuma/map/drivable_map.csv`` inside that box are kept. With
    ``extend_map`` enabled the selection is redone on the ``idx_x``/``idx_y``
    grid covering those rows, with ``extra_idx_x`` additional columns on the
    upper ``idx_x`` side. The result goes to ``data_pneuma/map/<smallmap_file>``.

    Changes from the previous version:

    * extrema are computed with vectorised NumPy reductions instead of Python's
      ``min``/``max`` over an array (much faster on long tables);
    * NaN coordinates are ignored when computing the bounding box; Python's
      ``min``/``max`` gave order-dependent results when NaN was present;
    * the previously hard-coded widening of 7 grid columns is a parameter.

    Args:
        extra_idx_x: number of extra ``idx_x`` columns kept beyond the maximum
            when ``extend_map`` is True. Defaults to 7, the former constant.

    Raises:
        ValueError: if ``extend_map`` is True and no map row falls inside the
            trajectory bounding box.
    """
    hdf_path = rd.get_hdf_path()  # path to the HDF file with all data
    group_path = rd.get_path_dict()['groups'][file_group]
    traj_obj = '/mm_line_1000ms'

    map_dir = 'data_pneuma/map'  # map directory
    fullmap_dir = os.path.join(map_dir, 'drivable_map.csv')  # current full map file
    smallmap_dir = os.path.join(map_dir, smallmap_file)  # new small map file

    # get trajectories data from the dataset we are using
    _, traj_df, _ = rd.get_from_hdf(hdf_path, key_id=group_path + '/all_id',
                                    key_tr=group_path + traj_obj, result='all')

    # bounding box of all trajectory points, ignoring missing coordinates
    lon = np.asarray(traj_df['lon_1'], dtype=float)
    lat = np.asarray(traj_df['lat_1'], dtype=float)
    min_lon, max_lon = np.nanmin(lon), np.nanmax(lon)
    min_lat, max_lat = np.nanmin(lat), np.nanmax(lat)

    # kept for parity with the other generators; it changes the csv module's global limit
    csv.field_size_limit(sys.maxsize)

    map_df = pd.read_csv(fullmap_dir)

    # smallest part of the map that includes all trajectories
    map_reduced = map_df[(map_df.lon >= min_lon) & (map_df.lon <= max_lon) &
                         (map_df.lat >= min_lat) & (map_df.lat <= max_lat)]

    # used for numerical map dimension reasons, to be used later in the network (same rows and cols)
    if extend_map:
        idx_x = map_reduced['idx_x'].to_numpy()
        idx_y = map_reduced['idx_y'].to_numpy()
        # np.min/np.max raise ValueError on an empty selection, as min()/max() did before
        min_x, max_x = np.min(idx_x), np.max(idx_x)
        min_y, max_y = np.min(idx_y), np.max(idx_y)

        map_reduced = map_df[(map_df.idx_x >= min_x) & (map_df.idx_x <= max_x + extra_idx_x) &
                             (map_df.idx_y >= min_y) & (map_df.idx_y <= max_y)]

    # write new csv file with small map (the DataFrame index is written as the first column)
    map_reduced.to_csv(smallmap_dir)
    
# Only regenerate the reduced map when enabled in the configuration cell.
if smallmap_generation:
    create_small_map()

# Parameters loading

Load all the parameters from parser

In [None]:
def get_parser():
    """Build the argument parser holding every model/training option and its default.

    Boolean options are parsed with ``ast.literal_eval`` so that the strings
    ``'True'``/``'False'`` become real booleans.

    Returns:
        argparse.ArgumentParser: parser with all options registered in a fixed order.
    """
    literal = ast.literal_eval

    # (flag, add_argument keyword arguments) in registration order
    option_table = (
        ('--dataset', {'default': 'eth5'}),
        ('--save_dir', {}),
        ('--model_dir', {}),
        ('--config', {}),
        ('--using_cuda', {'default': False, 'type': literal}),
        ('--test_set', {'default': 'eth', 'type': str,
                        'help': 'Set this value to [eth, hotel, zara1, zara2, univ] for ETH-univ, ETH-hotel, UCY-zara01, UCY-zara02, UCY-univ'}),
        ('--base_dir', {'default': '.', 'help': 'Base directory including these scripts.'}),
        ('--save_base_dir', {'default': './output/', 'help': 'Directory for saving caches and models.'}),
        ('--phase', {'default': 'train', 'help': "Set this value to 'train' or 'test'"}),
        ('--train_model', {'default': 'star', 'help': 'Your model name'}),
        ('--load_model', {'default': None, 'type': str, 'help': 'load pretrained model for test or training'}),
        ('--model', {'default': 'star.STAR'}),
        ('--seq_length', {'default': 10, 'type': int}),
        ('--obs_length', {'default': 4, 'type': int}),
        ('--pred_length', {'default': 6, 'type': int}),
        ('--batch_around_ped', {'default': 128, 'type': int}),
        ('--batch_size', {'default': 8, 'type': int}),
        ('--test_batch_size', {'default': 4, 'type': int}),
        ('--show_step', {'default': 100, 'type': int}),
        ('--start_test', {'default': 10, 'type': int}),
        ('--sample_num', {'default': 20, 'type': int}),
        ('--num_epochs', {'default': 3, 'type': int}),
        ('--ifshow_detail', {'default': True, 'type': literal}),
        ('--ifsave_results', {'default': False, 'type': literal}),
        ('--randomRotate', {'default': False, 'type': literal,
                            'help': '=True:random rotation of each trajectory fragment'}),
        ('--neighbor_thred', {'default': 100, 'type': int}),
        ('--learning_rate', {'default': 0.0015, 'type': float}),
        ('--clip', {'default': 1, 'type': int}),
    )

    arg_parser = argparse.ArgumentParser(description='STAR')
    for flag, options in option_table:
        arg_parser.add_argument(flag, **options)
    return arg_parser


# Parse with an empty argument list: inside a notebook there is no command line to read,
# so every option starts at its default.
parser = get_parser()
p = parser.parse_args(args=[])

# Derive the output locations from the base directory, the test set and the model name.
p.save_dir = p.save_base_dir + str(p.test_set) + '/'
p.model_dir = p.save_dir + p.train_model + '/'
p.config = p.model_dir + '/config_' + p.phase + '.yaml'

In [None]:
def load_arg(p):
    """Load saved options from ``p.config`` and return a namespace that uses them as defaults.

    If the config file exists, its mapping is installed as defaults on the
    module-level ``parser`` and a freshly parsed namespace (empty argument list)
    is returned. Keys in the file that ``p`` does not have are reported with a
    ``WRONG ARG`` message but are still installed, as before.

    Changes from the previous version:

    * ``yaml.safe_load`` replaces ``yaml.load(f)``, which raises ``TypeError``
      on PyYAML >= 6 because the ``Loader`` argument became mandatory. The
      config written by ``save_arg`` holds only plain scalars, which the safe
      loader reads. NOTE(review): a hand-edited config with Python-specific
      YAML tags would now be rejected - confirm none exist.
    * an empty config file is treated as "no overrides" instead of crashing;
    * the no-op ``assert`` wrapped in a bare ``except`` is removed.

    Args:
        p: namespace providing ``config`` (path of the YAML file) and the set
            of known option names.

    Returns:
        argparse.Namespace with the stored values applied, or ``False`` if the
        config file does not exist.
    """
    if not os.path.exists(p.config):
        return False

    with open(p.config, 'r') as f:
        default_arg = yaml.safe_load(f) or {}

    known_options = vars(p).keys()
    for k in default_arg.keys():
        if k not in known_options:
            print('WRONG ARG: {}'.format(k))

    parser.set_defaults(**default_arg)
    return parser.parse_args(args=[])
    
def save_arg(args):
    """Write every attribute of ``args`` to the YAML file ``args.config``.

    The folder ``args.model_dir`` is created first (including parents) when it
    does not exist yet.
    """
    target_dir = args.model_dir
    if not os.path.exists(target_dir):
        os.makedirs(target_dir)

    with open(args.config, 'w') as config_file:
        yaml.dump(vars(args), config_file)


# Bootstrap the configuration: write the defaults on the first run, then always
# continue with the values stored in the config file.
if not load_arg(p):
    save_arg(p)

args = load_arg(p)

# Network input processing

Define classes and functions to process the input of our network

In [None]:
class processor(object):
    """Holds the run configuration and the trajectory data loader built from it.

    Attributes are set in ``__init__``:
      args        the option namespace passed in
      dataloader  a ``Trajectory_Dataloader`` constructed from ``args``
    """

    def __init__(self, args):
        """Create the model directory and build the data loader.

        The directory ``args.model_dir`` is now created with
        ``os.makedirs(..., exist_ok=True)`` before the loader is built.
        ``os.mkdir`` failed with ``FileNotFoundError`` when a parent folder
        was missing, and the loader writes its cache files below
        ``args.save_dir``, so the folder tree has to exist first.
        NOTE(review): this relies on ``model_dir`` being located below
        ``save_dir``, as set up in the parameter cell - confirm for other callers.

        Args:
            args: option namespace; ``model_dir`` is read here, the remaining
                options are consumed by ``Trajectory_Dataloader``.
        """
        self.args = args

        os.makedirs(self.args.model_dir, exist_ok=True)

        self.dataloader = Trajectory_Dataloader(args)

In [None]:
class Trajectory_Dataloader():
    def __init__(self, args):

        self.args = args
        if self.args.dataset == 'eth5':

            #self.data_dirs = ['data_pneuma/trajectory_small/']
            self.data_dirs = ['data_pneuma/trajectory/']

            # Data directory where the pre-processed pickle file resides
            self.data_dir = 'data'
            skip = 1

            self.train_dir = self.data_dirs
            self.trainskip = skip
            self.test_dir = self.data_dirs
            self.testskip = skip
        else:
            raise NotImplementedError

        self.train_data_file = os.path.join(self.args.save_dir, "train_trajectories.cpkl")
        self.test_data_file = os.path.join(self.args.save_dir, "test_trajectories.cpkl")
        self.train_batch_cache = os.path.join(self.args.save_dir, "train_batch_cache.cpkl")
        self.test_batch_cache = os.path.join(self.args.save_dir, "test_batch_cache.cpkl")

        print("Creating pre-processed data from raw data.")
        if not test_dataset:
            self.traject_preprocess('train')
        self.traject_preprocess('test')
        print("Done.")

        # Load the processed data from the pickle file
        print("Preparing data batches.")
        
        # if we are using test dataset, skip train batches computation
        if not (os.path.exists(self.train_batch_cache)) and not test_dataset:
            self.frameped_dict, self.pedtraject_dict = self.load_dict(self.train_data_file)
            self.dataPreprocess('train')
            
            
        if not (os.path.exists(self.test_batch_cache)):
            self.test_frameped_dict, self.test_pedtraject_dict = self.load_dict(self.test_data_file)
            self.dataPreprocess('test')

        if not test_dataset:
            self.trainbatch, self.trainbatchnums, _, _ = self.load_cache(self.train_batch_cache) #!!
        self.testbatch, self.testbatchnums, _, _ = self.load_cache(self.test_batch_cache)
        print("Done.")

        if not test_dataset:
            print('Total number of training batches:', self.trainbatchnums)
        print('Total number of test batches:', self.testbatchnums)

    
    def load_cache(self, data_file):
        f = open(data_file, 'rb')
        raw_data = pickle.load(f)
        f.close()
        return raw_data
    
    def load_dict(self, data_file):
        f = open(data_file, 'rb')
        raw_data = pickle.load(f)
        f.close()

        frameped_dict = raw_data[0]
        pedtraject_dict = raw_data[1]

        return frameped_dict, pedtraject_dict
    
    
    # MODIFIED FUNCTION
    def traject_preprocess(self, setname):
        '''
        data_dirs : List of directories where raw data resides
        data_file : The file into which all the pre-processed data needs to be stored
        '''
        if setname == 'train':
            data_dirs = self.train_dir
            data_file = self.train_data_file
        else:
            data_dirs = self.test_dir
            data_file = self.test_data_file
        
        # vehicle's traj dataframe column number where timestamps, x, y are
        time_col = 19
        x_col = 6
        y_col = 7
        
        numFrame_data = []

        Pedlist_data = []
        frameped_dict = []
        pedtrajec_dict = []
        
        for seti, _ in enumerate(data_dirs):
            
            hdf_path = rd.get_hdf_path() #get path to HDF file with all data (pneuma_hdf)
            group_path = rd.get_path_dict()['groups'][file_group]
            traj_obj = '/mm_line_1000ms'  # matched trajectories
            
            # id_df = all ids of peds in the dataset
            # traj_list = list of list, one list with traj data for every ped
            id_df, _, traj_list = rd.get_from_hdf(hdf_path, key_id=group_path + '/all_id',
                                                  key_tr=group_path + traj_obj, result='all')
            
            Pedlist = np.arange(1, len(traj_list)+1)
            numPeds = len(Pedlist)
            # Add the list of frameIDs to the frameList_data
            Pedlist_data.append(Pedlist)

            numFrame_data.append([])
            frameped_dict.append({})
            pedtrajec_dict.append({})
            
            for ind, pedi in enumerate(Pedlist):
                if ind % 100 == 0:
                    print(ind, len(Pedlist))
                
                Dataped = np.array(traj_list[ind]) # array with traj data of ped
                FrameList = np.around(Dataped[:, time_col]/1000) #list of traj frames (from ms to s)
                
                if len(FrameList) < 2:
                    continue
                # Add number of frames of this trajectory
                numFrame_data[seti].append(len(FrameList))
                # Initialize the row of the numpy array
                Trajectories = []
                # For each frame for the current ped
                
                for fi, frame in enumerate(FrameList):
                    # Extract the x and y positions
                    current_x = Dataped[fi, x_col]
                    current_y = Dataped[fi, y_col]
                    # Add the pedID, x, y to the row of the numpy array
                    Trajectories.append([int(frame), current_x, current_y])
                    if int(frame) not in frameped_dict[seti]:
                        frameped_dict[seti][int(frame)] = []
                    frameped_dict[seti][int(frame)].append(pedi)
                pedtrajec_dict[seti][pedi] = np.array(Trajectories)
        
        f = open(data_file, "wb")
        pickle.dump((frameped_dict, pedtrajec_dict), f, protocol=2)
        f.close()
    
    
    def dataPreprocess(self, setname):
        '''
        Function to load the pre-processed data into the DataLoader object
        '''
        if setname == 'train':
            val_fraction = 0
            frameped_dict = self.frameped_dict
            pedtraject_dict = self.pedtraject_dict
            cachefile = self.train_batch_cache

        else:
            val_fraction = 0
            frameped_dict = self.test_frameped_dict
            pedtraject_dict = self.test_pedtraject_dict
            cachefile = self.test_batch_cache
        if setname != 'train':
            shuffle = False
        else:
            shuffle = True
        
        data_index = self.get_data_index(frameped_dict, setname, ifshuffle=shuffle)
        
        val_index = data_index[:, :int(data_index.shape[1] * val_fraction)]
        train_index = data_index[:, (int(data_index.shape[1] * val_fraction) + 1):]
        
        
        trainbatch = self.get_seq_from_index_balance(frameped_dict, pedtraject_dict, train_index, setname)
        valbatch = self.get_seq_from_index_balance(frameped_dict, pedtraject_dict, val_index, setname)
        
        trainbatchnums = len(trainbatch)
        valbatchnums = len(valbatch)

        f = open(cachefile, "wb")
        pickle.dump((trainbatch, trainbatchnums, valbatch, valbatchnums), f, protocol=2)
        f.close()
        
    def get_data_index(self, data_dict, setname, ifshuffle=True):
        '''
        Get the dataset sampling index.
        '''
        set_id = []
        frame_id_in_set = []
        total_frame = 0
        for seti, dict in enumerate(data_dict):
            frames = sorted(dict)
            maxframe = max(frames) - self.args.seq_length
            frames = [x for x in frames if not x > maxframe]
            total_frame += len(frames)
            set_id.extend(list(seti for i in range(len(frames))))
            frame_id_in_set.extend(list(frames[i] for i in range(len(frames))))

        all_frame_id_list = list(i for i in range(total_frame))

        data_index = np.concatenate((np.array([frame_id_in_set], dtype=int), np.array([set_id], dtype=int),
                                     np.array([all_frame_id_list], dtype=int)), 0)
        
        if ifshuffle:
            random.Random().shuffle(all_frame_id_list)
        data_index = data_index[:, all_frame_id_list]
        
        if setname == 'train':
            data_index = np.append(data_index, data_index[:, :self.args.batch_size], 1)
            
        return data_index
    
    # MODIFIED FUNCTION
    def get_seq_from_index_balance(self, frameped_dict, pedtraject_dict, data_index, setname):
        '''
        Query the trajectories fragments from data sampling index.
        Notes: Divide the scene if there are too many people; accumulate the scene if there are few people.
               This function takes less gpu memory.
        '''
        batch_data_mass = []
        batch_data = []
        Batch_id = []

        temp = self.args.batch_around_ped 
        if setname == 'train':
            skip = self.trainskip
        else:
            skip = self.testskip
        
        ped_cnt = 0
        last_frame = 0
        
        for i in range(data_index.shape[1]):
            if i % 100 == 0:
                print(i, '/', data_index.shape[1])
            cur_frame, cur_set, _ = data_index[:, i]
            framestart_pedi = set(frameped_dict[cur_set][cur_frame])
            try:
                frameend_pedi = set(frameped_dict[cur_set][cur_frame + self.args.seq_length * skip])
            except:
                continue
            present_pedi = framestart_pedi | frameend_pedi
            
            if (framestart_pedi & frameend_pedi).__len__() == 0:
                continue
            traject = ()
            IFfull = []
            
            for ped in present_pedi:
                # if there are missing frames in traj, skip the vehicle, as it causes errors in find_trajectory_fragment
                if (pedtraject_dict[cur_set][ped][-1, 0] - pedtraject_dict[cur_set][ped][0, 0]) \
                    != (pedtraject_dict[cur_set][ped].shape[0]-1)*skip :
                    continue
                
                cur_trajec, iffull, ifexistobs = self.find_trajectory_fragment(pedtraject_dict[cur_set][ped],
                                                                               cur_frame, self.args.seq_length, skip)
                if len(cur_trajec) == 0: 
                    continue
                if ifexistobs == False:
                    # Just ignore trajectories if their data don't exsist at the last obversed time step (easy for data shift)
                    continue
                if sum(cur_trajec[:, 0] > 0) < 5:
                    # filter trajectories have too few frame data
                    continue

                cur_trajec = (cur_trajec[:, 1:].reshape(-1, 1, 2),)
                traject = traject.__add__(cur_trajec)
                IFfull.append(iffull)
                
            if traject.__len__() < 1:
                continue
            if sum(IFfull) < 1:
                continue
            traject_batch = np.concatenate(traject, 1)
            batch_pednum = sum([i.shape[1] for i in batch_data]) + traject_batch.shape[1] 

            cur_pednum = traject_batch.shape[1]
            ped_cnt += cur_pednum
            batch_id = (cur_set, cur_frame,)
            
            if cur_pednum >= self.args.batch_around_ped * 2:
                # too many people in current scene
                # split the scene into two batches
                ind = traject_batch[self.args.obs_length - 1].argsort(0)
                cur_batch_data, cur_Batch_id = [], []
                Seq_batchs = [traject_batch[:, ind[:cur_pednum // 2, 0]], traject_batch[:, ind[cur_pednum // 2:, 0]]]
                for sb in Seq_batchs:
                    cur_batch_data.append(sb)
                    cur_Batch_id.append(batch_id)
                    cur_batch_data = self.massup_batch(cur_batch_data)
                    batch_data_mass.append((cur_batch_data, cur_Batch_id,))
                    cur_batch_data = []
                    cur_Batch_id = []
                    
                last_frame = i
            elif cur_pednum >= self.args.batch_around_ped:
                # good pedestrian numbers
                cur_batch_data, cur_Batch_id = [], []
                cur_batch_data.append(traject_batch)
                cur_Batch_id.append(batch_id)
                cur_batch_data = self.massup_batch(cur_batch_data)
                batch_data_mass.append((cur_batch_data, cur_Batch_id,))
                
                last_frame = i
            else:  # less pedestrian numbers <64
                # accumulate multiple framedata into a batch
                if batch_pednum > self.args.batch_around_ped:
                    # enough (accumulated) people in the scene
                    batch_data.append(traject_batch)
                    Batch_id.append(batch_id)

                    batch_data = self.massup_batch(batch_data) #!!
                    batch_data_mass.append((batch_data, Batch_id,))

                    last_frame = i
                    batch_data = []
                    Batch_id = []
                else:
                    batch_data.append(traject_batch)
                    Batch_id.append(batch_id)
                

        if last_frame < data_index.shape[1] - 1 and setname == 'test' and batch_pednum > 1:
            batch_data = self.massup_batch(batch_data) #!!
            batch_data_mass.append((batch_data, Batch_id,))
        self.args.batch_around_ped = temp
        
        return batch_data_mass
    
    def find_trajectory_fragment(self, trajectory, startframe, seq_length, skip):
        '''
        Query the trajectory fragment that falls inside a time window.

        The window consists of the `seq_length` time steps
        startframe, startframe + skip, ..., startframe + (seq_length - 1) * skip.
        Every trajectory row whose frame lies inside the window is copied to the
        result row of its own time step; time steps without data stay 0.

        Rows are placed individually by frame number, so trajectories that start
        late, end early, have missing frames, or do not overlap the window at all
        are handled without raising and without negative indices wrapping around.

        Args:
            trajectory: 2-D array, one row per observation with 3 columns;
                column 0 holds the frame number.
            startframe: frame number of the first time step of the window.
            seq_length: number of time steps in the window.
            skip: frame increment between two consecutive time steps.

        Returns:
            Tuple (return_trajec, iffull, ifexsitobs):
            return_trajec -- (seq_length, 3) array holding the fragment.
            iffull -- True when every time step of the window received data.
            ifexsitobs -- True when column 1 of the last observed time step
                (row self.args.obs_length - 1) is non-zero.
        '''
        return_trajec = np.zeros((seq_length, 3))
        endframe = startframe + seq_length * skip  # exclusive upper bound

        # Keep only observations inside [startframe, endframe).
        frames = trajectory[:, 0]
        in_window = (frames >= startframe) & (frames < endframe)
        if not in_window.any():
            # Nothing observed in this window: all-zero fragment, both flags off.
            return return_trajec, False, False

        candidate_seq = trajectory[in_window]

        # Time-step index of each kept row; always within 0 .. seq_length - 1
        # because of the window filter above.
        steps = ((candidate_seq[:, 0] - startframe) // skip).astype(int)
        return_trajec[steps, :3] = candidate_seq

        # Track which time steps were filled (a data row may itself be all zeros,
        # so the flag cannot be derived from the values).
        filled = np.zeros(seq_length, dtype=bool)
        filled[steps] = True

        ifexsitobs = bool(return_trajec[self.args.obs_length - 1, 1] != 0)
        iffull = bool(filled.all())

        return return_trajec, iffull, ifexsitobs
    
    def massup_batch(self, batch_data):
        '''
        Merge data fragments from different time windows into a single batch.

        Fragments are stacked along the pedestrian axis (axis 1). Neighbour
        information is computed per fragment with get_social_inputs_numpy and
        written block-diagonally, so pedestrians of different fragments are
        never marked as neighbours of each other.

        Args:
            batch_data: list of arrays, each with self.args.seq_length rows,
                one column per pedestrian and 2 coordinates.

        Returns:
            Tuple (nodes, seq_list, nei_list, nei_num, batch_pednum) where
            batch_pednum lists the pedestrian count of every fragment.
        '''
        seq_len = self.args.seq_length
        total_peds = sum(fragment.shape[1] for fragment in batch_data)

        # Lead with empty arrays so an empty batch_data still yields
        # correctly shaped (zero-pedestrian) outputs.
        node_parts = [np.zeros((seq_len, 0, 2))]
        seq_parts = [np.zeros((seq_len, 0))]
        nei_all = np.zeros((seq_len, total_peds, total_peds))
        nei_count_all = np.zeros((seq_len, total_peds))
        ped_counts = []

        lo = 0
        for fragment in batch_data:
            n_ped = fragment.shape[1]
            hi = lo + n_ped
            seq_mask, nei_mask, nei_count = self.get_social_inputs_numpy(fragment)

            node_parts.append(fragment)
            seq_parts.append(seq_mask)
            # Diagonal block belonging to this fragment.
            nei_all[:, lo:hi, lo:hi] = nei_mask
            nei_count_all[:, lo:hi] = nei_count

            ped_counts.append(n_ped)
            lo = hi

        nodes_all = np.concatenate(node_parts, 1)
        seq_all = np.concatenate(seq_parts, 1)
        return (nodes_all, seq_all, nei_all, nei_count_all, ped_counts)
    
    def get_social_inputs_numpy(self, inputnodes):
        '''
        Build the sequence list (where data exists) and the neighbour list
        (where neighbours exist) for one fragment.

        Args:
            inputnodes: array indexed as (time step, pedestrian, coordinate).
                A pedestrian counts as observed at a time step when its first
                coordinate is non-zero.

        Returns:
            Tuple (seq_list, nei_list, nei_num):
            seq_list[f, i] -- 1 where pedestrian i is observed in step f.
            nei_list[f, i, j] -- 1 where j is a neighbour of i in step f.
            nei_num[f, i] -- number of neighbours of i in step f.
        '''
        n_steps, n_peds = inputnodes.shape[0], inputnodes.shape[1]

        # Observation mask: first coordinate different from 0.
        observed = inputnodes[:, :, 0] != 0
        seq_list = np.zeros((n_steps, n_peds))
        seq_list[observed] = 1

        nei_list = np.zeros((n_steps, n_peds, n_peds))
        nei_num = np.zeros((n_steps, n_peds))

        for focal in range(n_peds):
            limit = self.args.neighbor_thred

            # Start from "every observed pedestrian except the focal one".
            candidates = seq_list.copy()
            candidates[:, focal] = 0
            base_count = candidates.sum(1)

            # Distance pruning only applies where both pedestrians are observed.
            both_seen = observed[:, focal, None] & observed
            offset = inputnodes[:, focal, None, :2] - inputnodes[:, :, :2]
            beyond = (np.abs(offset[:, :, 0]) > limit) | (np.abs(offset[:, :, 1]) > limit)
            too_far = both_seen & beyond

            candidates[too_far] = 0
            nei_list[:, focal, :] = candidates
            nei_num[:, focal] = base_count - too_far.sum(1)

        return seq_list, nei_list, nei_num
    
    def get_train_batch(self, idx):
        '''
        Return the (batch data, batch id) pair stored at position idx of
        self.trainbatch.
        '''
        nodes, ids = self.trainbatch[idx]
        return nodes, ids

In [None]:
# Instantiate `processor` with the notebook's `args` configuration object.
# NOTE(review): presumably the constructor runs the data preprocessing defined
# above -- confirm against processor.__init__.
trainer = processor(args)