In [13]:
import h5py
import numpy as np
import random

def copy_group(src, dst):
    """Recursively copy the contents of HDF5 group ``src`` into the existing group ``dst``.

    Attributes attached to ``src`` are copied onto ``dst``. Sub-groups are
    recreated in ``dst`` and filled recursively, so their attributes are kept too.
    Every other member (datasets) is copied with ``Group.copy``.

    Args:
        src: Source ``h5py.Group`` (or ``h5py.File``) to read from.
        dst: Destination ``h5py.Group``; expected not to already contain any of
            ``src``'s member names.
    """
    # ``dst`` is created empty by the caller (or by create_group below), so group-level
    # attributes would be lost unless they are copied explicitly.
    for attr_name, attr_value in src.attrs.items():
        dst.attrs[attr_name] = attr_value
    for key in src.keys():
        item = src[key]
        if isinstance(item, h5py.Group):
            copy_group(item, dst.create_group(key))
        else:
            # Group.copy duplicates the dataset together with its own attributes.
            src.copy(key, dst)

def _collect_unique_trajectories(trajectories, levels, max_samples, min_length, max_length):
    """Return ``{edge_sequence: key_path}`` for unique trajectories of acceptable length.

    ``key_path`` is a tuple of group names below ``trajectories``: ``(k,)`` when
    ``levels == 1`` and ``(k, k2)`` otherwise. Entries are inserted in the file's
    iteration order, and collection stops once ``max_samples`` entries are stored
    (``None`` means no limit).
    """
    if levels == 1:
        key_paths = ((k,) for k in trajectories)
    else:
        key_paths = ((k, k2) for k in trajectories for k2 in trajectories[k])

    unique = {}
    for key_path in key_paths:
        # Check the cap before reading. This keeps at most ``max_samples`` entries and,
        # because the generator is flat, also stops iterating outer groups in the
        # two-level layout.
        if max_samples is not None and len(unique) >= max_samples:
            break
        edge_idx = tuple(trajectories['/'.join(key_path)]['edge_idxs'][:])
        # Drop too short / too long sequences and exact duplicates of an earlier edge sequence.
        if not (min_length <= len(edge_idx) <= max_length) or edge_idx in unique:
            continue
        unique[edge_idx] = key_path
    return unique


def _copy_trajectory(src_file, dst_file, key_path):
    """Copy ``trajectories/<key_path>`` from ``src_file`` to the same path in ``dst_file``.

    Intermediate groups in ``dst_file`` are created on first use.
    """
    parent = dst_file['trajectories']
    for key in key_path[:-1]:
        parent = parent.require_group(key)
    src_group = src_file['trajectories/' + '/'.join(key_path)]
    copy_group(src_group, parent.create_group(key_path[-1]))


def split_data(input_path, output_path, levels=1, max_samples=100000, max_length=100, min_length=3,
               train_fraction=0.8, seed=None):
    """Split the trajectories of an HDF5 dataset into a train file and a test file.

    The input file must contain a ``graph`` group and a ``trajectories`` group.
    With ``levels == 1`` each child ``trajectories/<k>`` is one trajectory. With
    any other value, trajectories are read from ``trajectories/<k>/<k2>``. Every
    trajectory group must contain an ``edge_idxs`` dataset.

    Trajectories are filtered and deduplicated:

    - A trajectory is skipped if its edge sequence has fewer than ``min_length``
      or more than ``max_length`` edges.
    - A trajectory whose edge sequence equals one already collected is dropped.
    - At most ``max_samples`` trajectories are kept.

    The kept trajectories are shuffled. The first ``train_fraction`` of them go to
    ``<output_path>_train.h5`` and the rest to ``<output_path>_test.h5``. The
    ``graph`` group is copied into both outputs.

    Args:
        input_path: Path of the source HDF5 file.
        output_path: Filename prefix for the two output files.
        levels: 1 for the flat layout, anything else for the two-level layout.
        max_samples: Maximum number of unique trajectories to keep, or None for no limit.
        max_length: Maximum accepted number of edges (inclusive).
        min_length: Minimum accepted number of edges (inclusive).
        train_fraction: Fraction in [0, 1] of shuffled trajectories written to the train file.
        seed: Seed for a private RNG used for the shuffle. If None, the global
            ``random`` module state is used.

    Raises:
        ValueError: If ``train_fraction`` is outside [0, 1].
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise ValueError(f"train_fraction must be in [0, 1], got {train_fraction}")
    # Without a seed, keep using the global ``random`` state so callers that seed
    # ``random`` themselves still get the same split.
    rng = random if seed is None else random.Random(seed)

    with h5py.File(input_path, 'r') as src:
        with h5py.File(output_path + '_train.h5', 'w') as train_file, \
                h5py.File(output_path + '_test.h5', 'w') as test_file:
            for dst_file in (train_file, test_file):
                copy_group(src['graph'], dst_file.create_group('graph'))
                dst_file.create_group('trajectories')

            all_trajectories = _collect_unique_trajectories(
                src['trajectories'], levels, max_samples, min_length, max_length)

            edge_seqs = list(all_trajectories)
            rng.shuffle(edge_seqs)
            split_point = int(train_fraction * len(edge_seqs))

            for i, edge_idx in enumerate(edge_seqs):
                dst_file = train_file if i < split_point else test_file
                _copy_trajectory(src, dst_file, all_trajectories[edge_idx])

    print("Splitting complete.")

In [14]:
split_data(
    input_path='/ceph/hdd/students/yaro/new_format/geolife.h5',
    output_path='../datasets/geolife',
    levels=1,
    max_samples=1_000_000,
)

Splitting complete.


In [16]:
split_data(
    input_path='/ceph/hdd/students/yaro/new_format/tdrive.h5',
    output_path='../datasets/tdrive',
    levels=1,
    max_samples=1_000_000,
)

Splitting complete.


In [17]:
# pNEUMA trajectories are nested one level deeper (trajectories/<k>/<k2>), hence levels=2.
split_data(
    input_path='/ceph/hdd/students/yaro/pneuma/merged.h5',
    output_path='../datasets/pneuma',
    levels=2,
    max_samples=50_000,
)

Splitting complete.


In [6]:
import sys

# Put the parent directory ('../', resolved against the current working directory,
# presumably the repository root) first on the module search path, so the project's
# ``datasets`` package is found before any installed package of the same name.
sys.path.insert(0, '../')

from datasets import (
    GeoLifeTrajectoryDataset,
    TDriveTrajectoryDataset,
    PneumaTrajectoryDataset,
    Distance_evaluation,
)

In [15]:
# NOTE(review): split_data writes to the relative prefix '../datasets/geolife';
# confirm these absolute paths refer to the same files.
data = GeoLifeTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/geolife_train.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)
testdata = GeoLifeTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/geolife_test.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)

len(data), len(testdata)

(13869, 3468)

In [18]:
# NOTE(review): split_data writes to the relative prefix '../datasets/tdrive';
# confirm these absolute paths refer to the same files.
data = TDriveTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/tdrive_train.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)
testdata = TDriveTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/tdrive_test.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)

len(data), len(testdata)

(5728, 1433)

In [19]:
# NOTE(review): split_data writes to the relative prefix '../datasets/pneuma';
# confirm these absolute paths refer to the same files.
data = PneumaTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/pneuma_train.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)
testdata = PneumaTrajectoryDataset(
    "/ceph/hdd/students/weea/trajectory-prediction-on-graphs/datasets/pneuma_test.h5",
    n_samples=-1,
    min_trajectory_length=3,
    max_trajectory_length=10000,
)

len(data), len(testdata)

(8216, 2060)