In [1]:
import os
import datetime
import numpy as np
import pandas as pd
import torch

In [2]:
def get_csv_files_from_dir(dir):
    """Return the paths of the regular ``.csv`` files directly inside ``dir``.

    Each returned path is ``dir`` joined with the entry name. Sub-directories
    are not descended into, entries that are not regular files are skipped
    (even if their name ends with ``.csv``), and the extension match is
    case-sensitive.

    The result is sorted. ``os.listdir`` yields entries in arbitrary order, so
    without sorting the order in which files are processed could differ
    between runs and machines.

    Args:
        dir: Path of the directory to scan.

    Returns:
        A sorted list of path strings; empty when no CSV file is found.
    """
    candidates = (os.path.join(dir, name) for name in os.listdir(dir))
    # Check the cheap suffix test first, then hit the filesystem.
    return sorted(
        path for path in candidates
        if path.endswith('.csv') and os.path.isfile(path)
    )

In [3]:
def transform_cp_data(data, error, velocity, acceleration):
    """Sort CP data, keep only the requested columns and drop incomplete rows.

    Rows are ordered by ``rs_id``, ``id`` and ``time`` and the index is reset
    before filtering. Because rows containing missing values are dropped
    afterwards, the returned index may have gaps.

    Args:
        data: DataFrame holding at least ``rs_id``, ``id``, ``time``, ``x``,
            ``y``, ``z`` plus every optional column group that is enabled.
        error: Truthy to keep ``x_err``, ``y_err``, ``z_err``.
        velocity: Truthy to keep ``v_x_est``, ``v_y_est``, ``v_z_est``.
        acceleration: Truthy to keep ``a_x_est``, ``a_y_est``, ``a_z_est``.

    Returns:
        A new DataFrame with the selected columns, in the order listed above,
        and no rows that have a missing value in any selected column.
    """
    optional_groups = (
        (error, ['x_err', 'y_err', 'z_err']),
        (velocity, ['v_x_est', 'v_y_est', 'v_z_est']),
        (acceleration, ['a_x_est', 'a_y_est', 'a_z_est']),
    )

    columns = ['rs_id', 'id', 'time', 'x', 'y', 'z']
    for enabled, names in optional_groups:
        if enabled:
            columns.extend(names)

    ordered = data.sort_values(by=['rs_id', 'id', 'time']).reset_index(drop=True)
    return ordered[columns].dropna()

In [4]:
def get_track(data, rs_id, id, t_min=None, t_max=None, drop_ids=True):
    """Select the rows of one track, optionally clipped to a time window.

    Args:
        data: DataFrame with ``rs_id``, ``id`` and ``time`` columns.
        rs_id: Value the ``rs_id`` column must equal.
        id: Value the ``id`` column must equal.
        t_min: If given, keep only rows with ``time >= t_min``.
        t_max: If given, keep only rows with ``time <= t_max``.
        drop_ids: When truthy, remove the ``rs_id`` and ``id`` columns from
            the result.

    Returns:
        A DataFrame with the matching rows in their original order and with
        their original index labels.
    """
    selected = (data['rs_id'] == rs_id) & (data['id'] == id)
    if t_min is not None:
        selected = selected & (data['time'] >= t_min)
    if t_max is not None:
        selected = selected & (data['time'] <= t_max)

    track = data[selected]
    return track.drop(columns=['rs_id', 'id']) if drop_ids else track

In [5]:
def get_tracks_timeranges_intersection(track_1, track_2):
    """Return the time span covered by both tracks.

    Args:
        track_1: DataFrame with a ``time`` column.
        track_2: DataFrame with a ``time`` column.

    Returns:
        A ``(t_min, t_max)`` tuple: the later of the two earliest times and
        the earlier of the two latest times. When the tracks do not overlap
        in time, ``t_min`` is greater than ``t_max``.
    """
    earliest = (track_1['time'].min(), track_2['time'].min())
    latest = (track_1['time'].max(), track_2['time'].max())
    return max(earliest), min(latest)

In [10]:
def generate_siamese_data_from_cp_data_file(file, track_length, error, velocity, acceleration):
    """Cut pairs of fixed-length track windows out of one CP-data CSV file.

    WORK IN PROGRESS: the function still returns *empty* tensors. It reads the
    file, enumerates the tracks, and for the first pair of tracks that come
    from different ``rs_id`` values it slices aligned windows of
    ``track_length`` rows, walking backwards from the end of the shared time
    range. The windows are not yet merged into the result because nothing in
    this notebook defines how a pair should be labelled.

    Compared with the previous revision of this cell, this version fixes:
      * a ``NameError`` (``cur_t_max`` was read, ``current_t_max`` was set),
      * an unfinished ``if`` statement that made the cell a ``SyntaxError``,
      * a ``while`` loop whose bound was never updated (no termination).

    Args:
        file: Path of the CSV file to read.
        track_length: Number of rows in each track window.
        error: Passed to ``transform_cp_data``; adds 3 columns per row.
        velocity: Passed to ``transform_cp_data``; adds 3 columns per row.
        acceleration: Passed to ``transform_cp_data``; adds 3 columns per row.

    Returns:
        ``(x, y)`` with shapes ``(0, 2, track_length, row_length)`` and
        ``(0, 1)``, where ``row_length`` is 4 plus 3 per enabled column group.
    """
    row_length = 4 + int(error) * 3 + int(velocity) * 3 + int(acceleration) * 3
    x, y = torch.empty((0, 2, track_length, row_length)), torch.empty((0, 1))

    cp_data = pd.read_csv(file)
    cp_data = transform_cp_data(cp_data, error, velocity, acceleration)

    track_ids = [
        (rs_id, id)
        for rs_id in cp_data['rs_id'].unique()
        for id in cp_data[cp_data['rs_id'] == rs_id]['id'].unique()
    ]
    n = len(track_ids)

    for i in range(n):
        for j in range(i + 1, n):
            rs_id_1, id_1 = track_ids[i]
            rs_id_2, id_2 = track_ids[j]

            # Only tracks with different rs_id values are paired.
            if rs_id_1 == rs_id_2:
                continue

            track_1 = get_track(cp_data, rs_id_1, id_1)
            track_2 = get_track(cp_data, rs_id_2, id_2)

            t_min, t_max = get_tracks_timeranges_intersection(track_1, track_2)

            track_parts_pairs = torch.empty((0, 2, track_length, row_length))
            current_t_max = t_max

            # Loop-invariant: every timestamp of either track, used below to
            # find the next (strictly earlier) window end.
            all_times = pd.concat([track_1['time'], track_2['time']])

            # Non-overlapping tracks give t_min > t_max, so the loop is skipped.
            while current_t_max > t_min:
                # Last `track_length` rows of each track ending at the bound.
                # NOTE(review): rows are not clipped at t_min, so a window may
                # start before the shared time range - confirm this is wanted.
                track_1_part = track_1[track_1['time'] <= current_t_max].tail(track_length)
                track_2_part = track_2[track_2['time'] <= current_t_max].tail(track_length)

                # Stop as soon as either track cannot fill a whole window.
                if len(track_1_part) < track_length or len(track_2_part) < track_length:
                    break

                pair = torch.stack([
                    torch.tensor(track_1_part.to_numpy(), dtype=track_parts_pairs.dtype),
                    torch.tensor(track_2_part.to_numpy(), dtype=track_parts_pairs.dtype),
                ]).unsqueeze(0)
                track_parts_pairs = torch.cat((track_parts_pairs, pair), 0)

                # Step to the latest sample strictly before this window so the
                # bound always decreases and the loop is guaranteed to end.
                window_start = min(track_1_part['time'].iloc[0], track_2_part['time'].iloc[0])
                earlier_times = all_times[all_times < window_start]
                if earlier_times.empty:
                    break
                current_t_max = earlier_times.max()

            # NOTE(review): debugging short-circuit kept from the original
            # cell - only the first eligible pair of tracks is processed.
            break
        break
        # TODO: once pair labels are defined, drop the two `break`s above and
        # append `track_parts_pairs` (and matching labels) to `x` and `y`.

    return x, y

In [7]:
def generate_siamese_data_from_cp_data_dir(dir, track_length, error, velocity, acceleration):
    """Concatenate the siamese tensors produced for every CSV file in ``dir``.

    Each file returned by ``get_csv_files_from_dir`` is announced on stdout
    and handed to ``generate_siamese_data_from_cp_data_file``; the per-file
    tensors are appended along the first dimension.

    Args:
        dir: Directory that holds the CSV files.
        track_length: Number of rows in each track window.
        error: Forwarded to the per-file generator; adds 3 columns per row.
        velocity: Forwarded to the per-file generator; adds 3 columns per row.
        acceleration: Forwarded to the per-file generator; adds 3 columns per row.

    Returns:
        ``(x, y)`` where ``x`` has shape ``(N, 2, track_length, row_length)``
        and ``y`` has shape ``(N, 1)``; ``N`` is 0 when no file contributes.
    """
    enabled_groups = int(error) + int(velocity) + int(acceleration)
    row_length = 4 + 3 * enabled_groups

    x = torch.empty((0, 2, track_length, row_length))
    y = torch.empty((0, 1))

    for csv_path in get_csv_files_from_dir(dir):
        print(f'- file {csv_path}')
        file_x, file_y = generate_siamese_data_from_cp_data_file(
            csv_path, track_length, error, velocity, acceleration
        )
        x = torch.cat((x, file_x), 0)
        y = torch.cat((y, file_y), 0)

    return x, y

In [8]:
# Input directory; the driver loop reads its 'train' and 'test' sub-directories.
cp_data_dir = 'CP_data'
# Output directory; only referenced by the (currently commented-out) save step.
siamese_data_dir = 'siamese_data'

# Number of rows in every track window.
track_length = 32

# Optional column groups to keep (each adds three columns per row).
error = False
velocity = False
acceleration = False

In [12]:
# Generate the siamese tensors for each dataset split under cp_data_dir.
for data_usage_aim in ('train', 'test'):
    cp_split_dir = f'{cp_data_dir}/{data_usage_aim}'
    print(f'Processing {cp_split_dir}')

    x, y = generate_siamese_data_from_cp_data_dir(
        cp_split_dir, track_length, error, velocity, acceleration
    )

    # Saving is disabled for now; the intended logic is kept for reference:
    # if x.shape[0] != 0:
    #     print(f'Saving to {siamese_data_dir}/{data_usage_aim} ({x.shape[0]} rows)')
    #     torch.save(x, f'{siamese_data_dir}/{data_usage_aim}/x.pt')
    #     torch.save(y, f'{siamese_data_dir}/{data_usage_aim}/y.pt')
    # else:
    #     print(f'Nothing to save to {siamese_data_dir}/{data_usage_aim}')

    # Blank line between splits in the cell output.
    print()

Processing CP_data/train
- file CP_data/train/data_1403232125.csv
Tracks rs0id0 and rs1id0, shapes - (712, 4) and (1416, 4), time range - (626810, 873140)
Tracks rs0id0 and rs2id1, shapes - (712, 4) and (1409, 4), time range - (626810, 872000)
Tracks rs0id0 and rs3id2, shapes - (712, 4) and (1406, 4), time range - (626810, 872150)
Tracks rs0id0 and rs4id3, shapes - (712, 4) and (1397, 4), time range - (626810, 871180)
Tracks rs0id1 and rs1id0, shapes - (711, 4) and (1416, 4), time range - (627310, 873140)
Tracks rs0id1 and rs2id1, shapes - (711, 4) and (1409, 4), time range - (627310, 872000)
Tracks rs0id1 and rs3id2, shapes - (711, 4) and (1406, 4), time range - (627310, 872150)
Tracks rs0id1 and rs4id3, shapes - (711, 4) and (1397, 4), time range - (627310, 871180)
Tracks rs0id2 and rs1id0, shapes - (710, 4) and (1416, 4), time range - (627810, 873140)
Tracks rs0id2 and rs2id1, shapes - (710, 4) and (1409, 4), time range - (627810, 872000)
Tracks rs0id2 and rs3id2, shapes - (710, 4) 