In [37]:
import geopandas as gpd
import pandas as pd
import copy
import os
import warnings
import random
from pprint import pprint
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from srai.datasets import AirbnbMulticityDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_regions
from srai.regionalizers import H3Regionalizer
import plotly.graph_objs as go
from shapely.geometry import LineString, Point
from shapely import from_geojson
import h3
from srai.h3 import h3_to_geoseries
import matplotlib.pyplot as plt
import pickle

In [38]:
# H3 grid resolution: selects the embeddings file, is used when indexing
# trajectory points into cells, and is part of the output folder name.
H3_RESOLUTION = 9
# Default window length (in steps) for the padded sequences; also part of
# the output folder name.
seq_length = 15

In [39]:
# Load the precomputed embedding table for the configured H3 resolution.
# Rows are later looked up by H3 cell id (see `embed`).
embeddings_file = os.path.join('output_data', f'embeddings_{H3_RESOLUTION}.parquet')
embeddings = pd.read_parquet(embeddings_file)
embeddings.shape

(221407, 10)

In [40]:
# Load the point-level trajectory table (one row per recorded point).
geolife_file = os.path.join('output_data', 'geolife_mpd2.parquet')
gdf = gpd.read_parquet(geolife_file)
gdf.shape

(1199437, 11)

# GDF AGGREGATION

In [41]:
# Collapse the point-level frame to one row per trajectory: the points of a
# trajectory are passed to LineString, the per-point attributes are gathered
# into lists, and the id is kept as a regular column as well as the index.
trajectory_aggregations = {
    'geometry': LineString,
    'date_str': list,
    'speed': list,
    'altitude': list,
    'trajectory_id': 'first',
}
gdf_agg = gdf.groupby('trajectory_id').agg(trajectory_aggregations)


In [42]:
# gdf_agg = gdf_agg.sample(frac=0.5, random_state=42)

In [43]:
# Number of aggregated trajectories (rows) and columns.
gdf_agg.shape

(15794, 5)

In [44]:
# Preview the first aggregated trajectories.
gdf_agg.head()

Unnamed: 0_level_0,geometry,date_str,speed,altitude,trajectory_id
trajectory_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
20070412093132,"LINESTRING (116.33038 39.97423, 116.33045 39.9...","[2007-04-12, 2007-04-12, 2007-04-12, 2007-04-1...","[0.022402120360618574, 1.193990445597205, 0.29...","[823.490813648294, 446.194225721785, 456.03674...",20070412093132
20070412102116,"LINESTRING (116.33007 39.97647, 116.33038 39.9...","[2007-04-12, 2007-04-12, 2007-04-12, 2007-04-1...","[1.7116285451939273, 0.11295117884677582, 0.00...","[173.884514435696, 229.658792650919, 114.82939...",20070412102116
20070412134621,"LINESTRING (116.33033 39.97537, 116.33050 39.9...","[2007-04-12, 2007-04-12, 2007-04-12, 2007-04-1...","[0.086345396783155, 0.086345396783155, 0.07207...","[396.981627296588, 22.9658792650919, 328.08398...",20070412134621
20070413005306,"LINESTRING (116.34504 39.96698, 116.34526 39.9...","[2007-04-13, 2007-04-13, 2007-04-13, 2007-04-1...","[0.9932525417946071, 0.9932525417946071, 2.939...","[157.48031496063, 150.918635170604, 180.446194...",20070413005306
20070413013238,"LINESTRING (116.26872 39.94579, 116.26881 39.9...","[2007-04-13, 2007-04-13, 2007-04-13, 2007-04-1...","[5.9428259332112985, 5.9428259332112985, 4.988...","[291.994750656168, 291.994750656168, 242.78215...",20070413013238


In [45]:
def get_hex_neigh(traj):
    """Convert a trajectory line into the ordered list of H3 disks it visits.

    Consecutive coordinate pairs are mapped to H3 cells at ``H3_RESOLUTION``.
    Whenever the two cells differ, every cell that ``h3.grid_path_cells``
    returns for that pair is visited in order, skipping a cell that equals
    the one visited immediately before it.

    Parameters
    ----------
    traj : object with a ``coords`` sequence of ``(lon, lat, ...)`` tuples
        Coordinates are read as index 0 = longitude, index 1 = latitude.

    Returns
    -------
    list
        One ``h3.grid_disk(cell, 1)`` result per visited cell, in visiting
        order. Empty when the trajectory never leaves its first cell or has
        fewer than two points.
    """
    visited = []  # centre cells, in visiting order
    disks = []    # k=1 disk around each visited cell
    coords = traj.coords
    for pos in range(1, len(coords)):
        prev_pt = coords[pos - 1]
        next_pt = coords[pos]
        src = h3.latlng_to_cell(prev_pt[1], prev_pt[0], H3_RESOLUTION)
        dst = h3.latlng_to_cell(next_pt[1], next_pt[0], H3_RESOLUTION)
        if src != dst:
            for cell in h3.grid_path_cells(src, dst):
                # The end of one segment is the start of the next, so the
                # same cell can show up twice in a row; keep it only once.
                if disks and visited[-1] == cell:
                    continue
                visited.append(cell)
                disks.append(h3.grid_disk(cell, 1))
    return disks

In [46]:
# One-hot lookup table with six rows (default index 0-5); `get_y` selects a
# row from it by index label and uses that row as the target vector.
dummies = pd.get_dummies([1, 2, 3, 4, 5, 6])

In [47]:
def get_y(traj):
    """Build the one-hot "next move" label for every step of a trajectory.

    For step ``i`` the label encodes where the centre of step ``i + 1`` sits
    among the neighbours of step ``i``.

    Parameters
    ----------
    traj : list of sequences of H3 cell ids
        One entry per visited cell (the notebook passes the output of
        ``get_hex_neigh``). Element 0 of an entry is treated as the visited
        cell and the remaining elements as its neighbours.
        NOTE(review): this assumes ``h3.grid_disk`` lists the origin first -
        confirm against the h3 documentation.

    Returns
    -------
    list or None
        ``len(traj) - 1`` rows of the module-level ``dummies`` table, or
        ``None`` when ``traj`` is empty or when any step cannot be labelled.
        The notebook drops rows whose label list is ``None``.
    """
    if not traj:
        return None
    y = []
    for step in range(len(traj) - 1):
        next_cell = traj[step + 1][0]
        candidates = np.array(traj[step])
        # Position of the next cell inside the current entry, shifted by one
        # so that the neighbours (positions 1..6) map to labels 0..5.
        idx = np.where(candidates == next_cell)[0] - 1
        if idx.size == 0:
            # The next cell is not among the current neighbours, so there is
            # no valid label; reject the trajectory instead of emitting a
            # wrong or undefined target.
            return None
        try:
            label = dummies.loc[idx].values[0]
        except KeyError:
            # Index outside the table, e.g. -1 when the next cell equals the
            # current centre.
            return None
        y.append(label)
    return y

In [48]:
# t = gdf_agg['geometry'].iloc[5]
# cells = get_hex_neigh(t)
# pprint(len(cells))
# y = get_y(cells)
# print(len(y))
# em = embed(cells)
# pprint(em[0])

In [49]:
def embed(traj):
    """Replace every H3 cell id in a trajectory by its embedding vector.

    Parameters
    ----------
    traj : iterable of iterables of H3 cell ids
        One inner iterable per trajectory step.

    Returns
    -------
    list or None
        Same nesting as ``traj`` with each cell id replaced by
        ``embeddings.loc[cell].values``. ``None`` as soon as one cell has no
        row in the module-level ``embeddings`` table, which marks the whole
        trajectory as unusable.

    Only a missing key is treated as "no embedding"; any other error
    (for example the ``embeddings`` table not being loaded) propagates
    instead of silently turning every trajectory into ``None``.
    """
    emb = []
    for disk in traj:
        vectors = []
        for cell in disk:
            try:
                vectors.append(embeddings.loc[cell].values)
            except KeyError:
                return None
        emb.append(vectors)
    return emb

In [None]:
# Derive, per trajectory: the visited H3 disks, the one-hot next-move labels
# and the embeddings of every cell in those disks. Labels and embeddings are
# both computed from the disks, so the disks are built first.
hex_neigh = gdf_agg['geometry'].apply(get_hex_neigh)
gdf_agg['hex_neigh'] = hex_neigh
gdf_agg['y'] = hex_neigh.apply(get_y)
gdf_agg['neigh_emb'] = hex_neigh.apply(embed)

In [None]:
# Keep only trajectories for which both the embeddings and the labels could
# be built (`embed` / `get_y` return None otherwise).
has_embedding = gdf_agg['neigh_emb'].notnull()
has_label = gdf_agg['y'].notnull()
gdf_agg = gdf_agg[has_embedding & has_label]

# Train test split

In [None]:
# Hold out 20% of the trajectories for testing. The split is done per
# trajectory, before windows are generated, and random_state fixes the draw.
train, test = train_test_split(gdf_agg, test_size=0.2, random_state=42)

In [None]:
# Row/column count of the training split.
# NOTE(review): the trailing 6246 looks like a row count noted from some
# earlier run - confirm it is still accurate or remove it.
train.shape # 6246

# SEQUENCE GENERATION

In [None]:
# def make_seq(row, seq_len=seq_length):
#     # seq without padding
#     for i in range(0, len(row['y']) - seq_len):
#         seq = []
#         for j in range(i, i+seq_len):
#             seq.append(np.array(row['neigh_emb'][j]).flatten().tolist())
#         yield seq, row['y'][i+seq_len]   

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm

def make_pad_seq(row, seq_len=seq_length, batch_size=256):
    """Yield batches of zero-padded history windows with their labels.

    For every label ``row['y'][i]`` one sample is produced: the flattened
    embeddings of steps ``max(0, i - seq_len + 1) .. i``, right-aligned in a
    window of ``seq_len`` slots whose unused leading slots hold a zero vector.
    Every label yields a sample, including the last one of the trajectory.

    Parameters
    ----------
    row : mapping with keys ``'neigh_emb'`` and ``'y'``
        ``row['neigh_emb'][i]`` is the (nested) embedding of step ``i`` and
        ``row['y'][i]`` the label belonging to that step.
    seq_len : int
        Number of slots in each window.
    batch_size : int
        Maximum number of samples per yielded batch.

    Yields
    ------
    list of (list, label)
        Up to ``batch_size`` ``(window, label)`` tuples; the final batch may
        be shorter. Nothing is yielded when there are no labels.
    """
    labels = row['y']
    # Flatten each step once up front; windows then reference these lists
    # instead of re-flattening the same step for every window it occurs in.
    steps = [np.array(step).flatten().tolist()
             for step in row['neigh_emb'][:len(labels)]]
    if not steps:
        return

    # Zero vector with the same length as one flattened step, used as padding.
    empty = list(np.zeros_like(steps[0]))
    batch = []

    for i in range(len(labels)):
        window = steps[max(0, i - seq_len + 1):i + 1]
        # Left-pad so the most recent step always sits in the last slot.
        seq = [empty] * (seq_len - len(window)) + window
        batch.append((seq, labels[i]))

        # When batch size is reached, yield the batch
        if len(batch) == batch_size:
            yield batch
            batch = []

    # Yield any remaining items in the batch
    if batch:
        yield batch

In [None]:
# Turn every training trajectory into padded (seq, y) samples and collect
# them in one flat frame.
batch_size = 128
all_batches = []

train_rows = tqdm(train.iterrows(), total=len(train), desc="Processing Rows")
for _, row in train_rows:
    for batch in make_pad_seq(row, seq_len=seq_length, batch_size=batch_size):
        all_batches += batch

train_seq = pd.DataFrame(all_batches, columns=['seq', 'y'])

In [None]:
# Turn every test trajectory into padded (seq, y) samples and collect them
# in one flat frame.
batch_size = 128
all_batches = []

test_rows = tqdm(test.iterrows(), total=len(test), desc="Processing Rows")
for _, row in test_rows:
    for batch in make_pad_seq(row, seq_len=seq_length, batch_size=batch_size):
        all_batches += batch

test_seq = pd.DataFrame(all_batches, columns=['seq', 'y'])

In [None]:
# Shapes of the generated (seq, y) sample frames.
# NOTE(review): the trailing comment presumably records the shapes seen with
# H3 resolution 8 - confirm.
train_seq.shape, test_seq.shape # hex8: ((275885, 2), (71234, 2))

In [None]:
# draw 200k samples
train_seq = train_seq.sample(200000, random_state=42)
test_draw_size = len(train_seq) * 0.2
test_seq = test_seq.sample(int(test_draw_size), random_state=42)

In [None]:
# Shapes of both sample frames after the down-sampling step.
train_seq.shape, test_seq.shape

In [None]:
# Spread each training window over columns (one column per window slot),
# keeping the row index of train_seq.
train_expanded = pd.DataFrame(train_seq['seq'].tolist(), index=train_seq.index)

In [None]:
# Spread each test window over columns (one column per window slot),
# keeping the row index of test_seq.
test_expanded = pd.DataFrame(test_seq['seq'].tolist(), index=test_seq.index)

In [None]:
# Plain arrays for the model: features from the expanded frames, targets
# from the 'y' column of the sample frames.
X_train = train_expanded.values
y_train = train_seq['y'].values
X_test = test_expanded.values
y_test = test_seq['y'].values

# Saving the train and test data

In [None]:
# Root folder for the prepared model inputs. exist_ok=True creates it when
# missing and is a no-op otherwise, without a separate existence check that
# could race with another process.
data_path = 'input_data'
os.makedirs(data_path, exist_ok=True)

In [None]:
# Sub-folder named after the H3 resolution and window length so that runs
# with different settings do not overwrite each other. exist_ok=True makes
# the creation safe to repeat.
path = os.path.join(data_path, f'data_res{H3_RESOLUTION}_seq{seq_length}')
os.makedirs(path, exist_ok=True)

In [None]:
# Serialise the test split as a single (X_test, y_test) tuple.
with open(os.path.join(path, 'test.pkl'), 'wb') as f:
    pickle.dump((X_test, y_test), f)

In [None]:
# Serialise the training split as a single (X_train, y_train) tuple.
with open(os.path.join(path, 'train.pkl'), 'wb') as f:
    pickle.dump((X_train, y_train), f)