In [None]:
import geopandas as gpd
import pandas as pd
import copy
import os
import warnings
import random
from pprint import pprint
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from srai.datasets import AirbnbMulticityDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_regions
from srai.regionalizers import H3Regionalizer
import plotly.graph_objs as go
from shapely.geometry import LineString, Point
from shapely import from_geojson
import h3
from srai.h3 import h3_to_geoseries
import matplotlib.pyplot as plt

In [None]:
# H3 grid resolution shared by the embeddings file name and the trajectory hex mapping.
H3_RESOLUTION = 9
# Number of consecutive hex steps in each model input sequence.
seq_length = 8

In [None]:
# Precomputed hex embeddings for the chosen resolution; rows are looked up by H3 cell id in `embed`.
embeddings_file = os.path.join('output_data', f'embeddings_{H3_RESOLUTION}.parquet')
embeddings = pd.read_parquet(embeddings_file)
embeddings.shape

In [None]:
# Point-level trajectory data; rows are grouped by `trajectory_id` in the next section.
trajectories_file = os.path.join('output_data', 'geolife_mpd.parquet')
gdf = gpd.read_parquet(trajectories_file)
gdf.shape

# GDF AGGREGATION (one row per trajectory)

In [None]:
# A LineString needs at least two coordinates, so a trajectory with a single point
# would make the aggregation raise; drop such trajectories before grouping.
points_per_traj = gdf.groupby('trajectory_id')['trajectory_id'].transform('size')
gdf_agg = gdf[points_per_traj >= 2].groupby('trajectory_id').agg(
    {'geometry': LineString, 'date_str': list, 'speed': list, 'direction': list, 'altitude': list,
     'trajectory_id': 'first'})
gdf_agg.shape

In [None]:
# Peek at the first aggregated trajectories.
gdf_agg.head(n=5)

In [None]:
def get_hex_neigh(traj):
    """Walk a trajectory through the H3 grid and return the ring-1 disk of every visited cell.

    Consecutive coordinates are mapped to cells at ``H3_RESOLUTION`` and joined with
    ``h3.grid_path_cells``. Consecutive repeats are collapsed, so successive entries
    are distinct cells.

    Parameters
    ----------
    traj : shapely LineString whose coordinates are (lng, lat) pairs.

    Returns
    -------
    list of lists, one per visited cell. Each inner list holds the visited cell
    first, followed by the other cells of its ``grid_disk(cell, 1)``.
    """
    path = []   # visited cells, consecutive duplicates collapsed
    disks = []  # disks[k] is the neighbourhood of path[k]
    coords = list(traj.coords)
    for a, b in zip(coords, coords[1:]):
        start_hex = h3.latlng_to_cell(a[1], a[0], H3_RESOLUTION)
        end_hex = h3.latlng_to_cell(b[1], b[0], H3_RESOLUTION)
        if start_hex == end_hex:
            continue
        for cell in h3.grid_path_cells(start_hex, end_hex):
            if path and path[-1] == cell:
                continue
            path.append(cell)
            # get_y reads disk[0] as the step's own cell. grid_disk's output order is
            # not documented as centre-first, so put the centre first explicitly.
            disks.append([cell] + [n for n in h3.grid_disk(cell, 1) if n != cell])
    return disks

In [None]:
# One-hot lookup table: row k (k = 0..5) is the label vector for moving to neighbour k.
dummies = pd.get_dummies(list(range(1, 7)))

In [None]:
def get_y(traj, one_hot=None):
    """Encode each move of a hex trajectory as a one-hot neighbour direction.

    ``traj`` is a list of neighbourhood disks, with the step's own cell at index 0
    and its ring-1 neighbours after it. For each consecutive pair of disks, the
    label is row ``k`` of ``one_hot``. Here ``k`` is the position of the next
    step's cell among the current step's neighbours (``disk[1:]``).

    Parameters
    ----------
    traj : list of cell sequences, centre cell first.
    one_hot : DataFrame whose row ``k`` is the label vector for neighbour ``k``.
        Defaults to the notebook-level ``dummies`` table.

    Returns
    -------
    list with one 1-D array per move (``len(traj) - 1`` items), or ``None`` in
    either of these cases:
      - ``traj`` is empty;
      - some next cell is not exactly one of the current cell's neighbours, so the
        move has no defined label.
    """
    if not traj:
        return None
    if one_hot is None:
        one_hot = dummies
    y = []
    for current, nxt in zip(traj[:-1], traj[1:]):
        target = nxt[0]
        matches = [k for k, cell in enumerate(current[1:]) if cell == target]
        # An undefined move would otherwise yield a wrong or stale label; return None
        # so the trajectory is dropped by the later `notnull` filter.
        if len(matches) != 1 or matches[0] >= len(one_hot):
            return None
        y.append(one_hot.iloc[matches[0]].to_numpy())
    return y

In [None]:
# Preview one trajectory: its hex path (coloured by step order) drawn under the raw line.
ls = gdf_agg['geometry'].iloc[5]
points = ls.coords
h3_cells = []
coord_list = list(points)
for a, b in zip(coord_list, coord_list[1:]):
    start_hex = h3.latlng_to_cell(a[1], a[0], H3_RESOLUTION)
    end_hex = h3.latlng_to_cell(b[1], b[0], H3_RESOLUTION)
    if start_hex == end_hex:
        continue
    for h3_cell in h3.grid_path_cells(start_hex, end_hex):
        if h3_cells and h3_cells[-1] == h3_cell:
            continue
        h3_cells.append(h3_cell)
m = h3_to_geoseries(h3_cells).reset_index().explore("index", tiles="CartoDB positron", opacity=0.4)
gpd.GeoSeries([ls]).explore(m=m)

In [None]:
# t = gdf_agg['geometry'].iloc[5]
# cells = get_hex_neigh(t)
# pprint(len(cells))
# y = get_y(cells)
# print(len(y))
# em = embed(cells)
# pprint(em[0])

In [None]:
def embed(traj, table=None):
    """Replace every cell id in every neighbourhood disk with its embedding row.

    Parameters
    ----------
    traj : list of cell sequences (one disk per trajectory step).
    table : DataFrame indexed by H3 cell id. Defaults to the notebook-level
        ``embeddings`` frame.

    Returns
    -------
    nested list mirroring ``traj`` in which each cell id is replaced by
    ``table.loc[cell].values``, or ``None`` if any cell is missing from ``table``.
    """
    if table is None:
        table = embeddings
    emb = []
    for disk in traj:
        rows = []
        for cell in disk:
            try:
                rows.append(table.loc[cell].values)
            except KeyError:
                # Cell has no embedding: the whole trajectory is unusable.
                return None
        emb.append(rows)
    return emb

In [None]:
# Per trajectory: neighbourhood disks along the hex path, one-hot move labels, and disk embeddings.
hex_neigh = gdf_agg['geometry'].apply(get_hex_neigh)
gdf_agg['hex_neigh'] = hex_neigh
gdf_agg['y'] = hex_neigh.apply(get_y)
gdf_agg['neigh_emb'] = hex_neigh.apply(embed)

In [None]:
# Keep trajectories that have embeddings and labels, and at least `seq_length` moves.
gdf_agg = (
    gdf_agg
    .loc[lambda d: d['neigh_emb'].notnull() & d['y'].notnull()]
    .loc[lambda d: d['y'].map(len) >= seq_length]
)

# SEQUENCE GENERATION

In [None]:
# def make_seq(row, seq_len=seq_length):
#     # seq without padding
#     for i in range(0, len(row['y']) - seq_len):
#         seq = []
#         for j in range(i, i+seq_len):
#             seq.append(np.array(row['neigh_emb'][j]).flatten().tolist())
#         yield seq, row['y'][i+seq_len]   

In [None]:
def make_pad_seq(row, seq_len=seq_length):
    """Yield left-zero-padded input windows and their next-move targets for one trajectory.

    For every step ``i``:
      - the input is the flattened neighbourhood embeddings of the (up to)
        ``seq_len`` most recent cells ending at cell ``i``, left-padded with zero
        vectors;
      - the target is ``row['y'][i]``, the label of the move out of cell ``i``.

    Parameters
    ----------
    row : mapping with ``'neigh_emb'`` (one embedding block per cell) and ``'y'``
        (one label per move, i.e. one fewer entry than ``'neigh_emb'``).
    seq_len : number of time steps in each yielded window.

    Yields
    ------
    ``(seq, target)`` where ``seq`` is a list of ``seq_len`` flat lists of floats.
    """
    # Flatten each cell's embedding block once instead of once per window.
    flat = [np.array(e).flatten().tolist() for e in row['neigh_emb']]
    if not flat:
        return
    empty = list(np.zeros(len(flat[0])))
    # y[i] is the move out of cell i, so every label, including the last one, has a
    # matching input window ending at flat[i].
    for i, target in enumerate(row['y']):
        window = flat[max(0, i - seq_len + 1):i + 1]
        seq = [empty] * (seq_len - len(window)) + window
        yield seq, target
        

In [None]:
# Expand every trajectory into its (sequence, target) training samples.
gdf_seq = gpd.GeoDataFrame(
    [sample for _, traj_row in gdf_agg.iterrows() for sample in make_pad_seq(traj_row)],
    columns=['seq', 'y'],
)

In [None]:
# Peek at the first generated samples.
gdf_seq.head(n=5)

In [None]:
# One column per time step; each cell holds that step's flattened neighbourhood embedding.
seq_expanded = pd.DataFrame(data=list(gdf_seq['seq']), index=gdf_seq.index)

In [None]:
# Peek at the per-time-step layout.
seq_expanded.head(n=5)

In [None]:
# Row-level 80/20 split with a fixed seed so the saved splits are reproducible.
# NOTE(review): windows from the same trajectory overlap, so a row-level split can
# place near-duplicate windows in both train and test. Consider splitting by
# trajectory instead.
X, y = seq_expanded, gdf_seq['y']
X_train, X_test, y_train, y_test = train_test_split(
    X, y, random_state=42, test_size=0.2,
)

# Saving the train/test data

In [None]:
data_path = 'input_data'
# exist_ok replaces the check-then-create pattern and is a no-op if the folder exists.
os.makedirs(data_path, exist_ok=True)

In [None]:
# One output folder per (resolution, sequence length) configuration, e.g. input_data/data_res9_seq8.
path = os.path.join(data_path, f'data_res{H3_RESOLUTION}_seq{seq_length}')
os.makedirs(path, exist_ok=True)

In [None]:
# Persist each split as a (features, labels) tuple of pandas objects.
# NOTE(review): these are not plain tensors, so loading them may require
# torch.load(..., weights_only=False) on PyTorch versions that default to
# weights-only loading. Confirm against the training code.
for file_name, split in (('train.pt', (X_train, y_train)), ('test.pt', (X_test, y_test))):
    torch.save(split, os.path.join(path, file_name))