In [3]:
import geopandas as gpd
import copy
import os
import warnings
import random
from pprint import pprint
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from srai.datasets import AirbnbMulticityDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_regions
from srai.regionalizers import H3Regionalizer
import plotly.graph_objs as go
from shapely.geometry import LineString, Point
from shapely import from_geojson
import h3
from srai.h3 import h3_to_geoseries

In [4]:
gdf = gpd.read_parquet('output_data/geolife_points.parquet')

In [39]:
gdf.shape

(49999, 12)

In [5]:
gdf.head()

Unnamed: 0_level_0,latitude,longitude,altitude,date,date_str,trajectory_id,mode,geometry,user_id,speed,timedelta,direction
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2008-12-27 07:26:04,39.969352,116.398589,492.0,39809.309769,2008-12-27,20081227072604,unknown,POINT (116.39859 39.96935),135,13.881385,NaT,160.099874
2008-12-27 07:26:06,39.969117,116.3987,492.0,39809.309792,2008-12-27,20081227072604,unknown,POINT (116.39870 39.96912),135,13.881385,0 days 00:00:02,160.099874
2008-12-27 07:26:07,39.969016,116.398799,492.0,39809.309803,2008-12-27,20081227072604,unknown,POINT (116.39880 39.96902),135,14.046284,0 days 00:00:01,143.085532
2008-12-27 07:26:08,39.968949,116.398868,491.0,39809.309815,2008-12-27,20081227072604,unknown,POINT (116.39887 39.96895),135,9.491683,0 days 00:00:01,141.717016
2008-12-27 07:26:09,39.968893,116.398905,491.0,39809.309826,2008-12-27,20081227072604,unknown,POINT (116.39890 39.96889),135,6.975264,0 days 00:00:01,153.143848


In [6]:
# print columns types
gdf.dtypes

latitude                 float64
longitude                float64
altitude                 float64
date                     float64
date_str                  object
trajectory_id             object
mode                      object
geometry                geometry
user_id                   object
speed                    float64
timedelta        timedelta64[ns]
direction                float64
dtype: object

In [30]:
# group by trajectory_id
# Collapse the per-point frame into one row per trajectory: the trajectory's
# points become a single LineString and every per-point attribute is collected
# into a list, in the current row order of `gdf`.
# NOTE(review): vertex order follows gdf's row order (its index is the
# timestamp) — this assumes gdf is already sorted by time; confirm.
# NOTE(review): shapely presumably rejects a LineString with fewer than 2
# points, so a single-point trajectory would make this aggregation raise — verify.
gdf2 = gdf.groupby('trajectory_id').agg({'geometry': LineString, 'date_str': list, 'speed': list, 'direction': list, 'altitude': list, 'trajectory_id': 'first'})

In [31]:
gdf2.head()

Unnamed: 0_level_0,geometry,date_str,speed,direction,altitude,trajectory_id
trajectory_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
20081227072604,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[2008-12-27, 2008-12-27, 2008-12-27, 2008-12-2...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604
20090101024458,"LINESTRING (116.33176 39.93721, 116.33181 39.9...","[2009-01-01, 2009-01-01, 2009-01-01, 2009-01-0...","[26.440512857858558, 26.440512857858558, 2.414...","[9.265462560995047, 9.265462560995047, 14.9894...","[492.0, 492.0, 492.0, 492.0, 492.0, 492.0, 492...",20090101024458
20090102043127,"LINESTRING (116.39653 39.96581, 116.39645 39.9...","[2009-01-02, 2009-01-02, 2009-01-02, 2009-01-0...","[4.726638177926822, 4.726638177926822, 9.32635...","[313.0537264342987, 313.0537264342987, 116.225...","[492.0, 492.0, 492.0, 492.0, 492.0, 492.0, 492...",20090102043127
20090103012134,"LINESTRING (116.39974 39.97429, 116.39959 39.9...","[2009-01-03, 2009-01-03, 2009-01-03, 2009-01-0...","[12.73039176894192, 12.73039176894192, 6.18924...","[268.9965774862792, 268.9965774862792, 287.822...","[492.0, 492.0, 492.0, 492.0, 491.0, 491.0, 491...",20090103012134
20090110011947,"LINESTRING (116.39828 39.97375, 116.39830 39.9...","[2009-01-10, 2009-01-10, 2009-01-10, 2009-01-1...","[22.682514166001585, 22.682514166001585, 2.014...","[3.010510124255461, 3.010510124255461, 257.213...","[54.0, 77.0, 99.0, 101.0, 93.0, 107.0, 114.0, ...",20090110011947


In [37]:
gdf2.shape

(9, 6)

In [32]:
def plot_trajectories(trajectories):
    """Plot one or more trajectories on an OpenStreetMap basemap with Plotly.

    Parameters
    ----------
    trajectories : iterable
        Each item is either a shapely ``LineString`` or a sequence of
        coordinate tuples whose first two entries are ``(lon, lat)``. Items may
        be mixed, and any iterable works (list, generator, or a pandas Series
        with an arbitrary index). An empty iterable renders an empty map.

    Returns
    -------
    None
        The figure is displayed via ``fig.show()``.
    """
    fig = go.Figure()
    for trajectory in trajectories:
        # Normalize every item on its own rather than inspecting only
        # ``trajectories[0]``. That lookup fails on empty input. On a pandas
        # Series it is a *label* lookup (positional fallback is deprecated), so
        # it breaks for e.g. a Series indexed by trajectory_id.
        points = list(trajectory.coords) if isinstance(trajectory, LineString) else trajectory
        fig.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=[point[0] for point in points],
            lat=[point[1] for point in points],
            marker={'size': 10}
        ))

    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_zoom=10,
        margin={"r": 0, "t": 0, "l": 0, "b": 0}
    )

    fig.show()

In [33]:
# Compute device for training. The original cell auto-detected an accelerator
# and then unconditionally overwrote the result with 'cpu'. The override is now
# an explicit switch: set FORCE_CPU = False to prefer CUDA, then Apple MPS.
FORCE_CPU = True

if FORCE_CPU:
    device = "cpu"
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cpu'

## Regiony H3 wyznaczone na podstawie punktów trajektorii (pierwotnie: 10 regionów dla 2 trajektorii)

In [13]:
HEX_RES = 9

In [14]:
# Cover the GPS points in `gdf` with H3 cells at resolution HEX_RES. The
# resulting `regions` frame is indexed by H3 cell id (index name `region_id`).
regionalizer = H3Regionalizer(resolution=HEX_RES)
regions = regionalizer.transform(gdf)

In [15]:
regions.index

Index(['89319c49827ffff', '8940b150dcfffff', '8940a676093ffff',
       '894095c9a1bffff', '89319c93247ffff', '8940a68a187ffff',
       '8931815554fffff', '89318264923ffff', '893183cac37ffff',
       '89319e7118bffff',
       ...
       '8940a6756cbffff', '89319c1b323ffff', '8940b301bb3ffff',
       '89409240617ffff', '89319e14347ffff', '894095c26a3ffff',
       '8940b385917ffff', '894092941afffff', '89319e702c7ffff',
       '894095c244fffff'],
      dtype='object', name='region_id', length=4960)

In [16]:
# Load OpenStreetMap features for the area covered by `regions`, filtered with
# the tag filter shipped for Hex2Vec (HEX2VEC_FILTER).
# NOTE(review): this step is memory-heavy. The recorded run hit a MemoryError
# and the loader retried with fewer rows per group.
loader = OSMPbfLoader()
features = loader.load(regions, HEX2VEC_FILTER)

  ].unary_union


Encountered MemoryError during operation. Retrying with lower number of rows per group (500000).


python(55166) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55167) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55168) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55169) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55170) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55171) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55172) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55174) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55175) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55176) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(55190) Malloc

In [17]:
# Build the region <-> OSM-feature mapping (by intersection). The Hex2Vec
# embedder below consumes it as `joint`.
joiner = IntersectionJoiner()
joint = joiner.transform(regions, features)

In [18]:
# getting embeddings using hex2vec
# Train a Hex2Vec encoder on regions + OSM features + H3 neighbourhood and embed
# every region. The last hidden size (10) is the embedding width, so the result
# has one 10-dimensional vector per region.
neighbourhood = H3Neighbourhood(regions)
embedder_hidden_sizes = [150, 100, 50, 10]
embedder = Hex2VecEmbedder(embedder_hidden_sizes)
# Suppress every warning emitted during training (this also hides useful ones).
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    embeddings = embedder.fit_transform(
        regions,
        features,
        joint,
        neighbourhood,
        # `device` comes from the device-selection cell earlier in the notebook.
        trainer_kwargs={"max_epochs": 5, "accelerator": device},
        batch_size=100,
    )
embeddings.shape

100%|██████████| 4960/4960 [00:00<00:00, 61347.36it/s]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 54.1 K
---------------------------------------
54.1 K    Trainable params
0         Non-trainable params
54.1 K    Total params
0.216     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


(4960, 10)

In [19]:
embeddings

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
89319c49827ffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
8940b150dcfffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
8940a676093ffff,-0.093212,0.142198,-0.302980,0.088826,-0.153741,0.422962,0.778910,0.255544,-0.582528,0.267696
894095c9a1bffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
89319c93247ffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
...,...,...,...,...,...,...,...,...,...,...
894095c26a3ffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
8940b385917ffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
894092941afffff,0.186160,-0.367860,0.239409,-0.200836,0.359717,0.021071,0.036221,-0.221839,0.011196,0.120120
89319e702c7ffff,-0.219831,0.426545,-0.412229,0.210392,-0.400367,0.222616,0.474490,0.367993,-0.334369,0.037958


## Dla każdej sekwencji trajektorii tworzymy kopię dla każdego heksagonu, przez który przechodzi, np. trajektoria przecina 3 heksagony → otrzymujemy 3 kopie trajektorii, każdą z innym heksagonem

In [21]:
# Spatially join trajectory windows with the H3 regions. A window that touches
# k hexagons produces k rows; with how="left", windows without a match are kept.
# FIXME: `gdf_exp` is not defined anywhere in this notebook, so this cell fails
# on a fresh kernel. Build it from the windowed trajectories before this cell.
# NOTE(review): depending on the geopandas version, the right-index column may be
# named after the index name (`region_id`) instead of `index_right`. In that case
# this rename is a no-op and `h3_index` is missing — verify.
joined_gdf = gpd.sjoin(gdf_exp, regions, how="left")
joined_gdf.rename(columns={"index_right": "h3_index"}, inplace=True)

In [41]:
joined_gdf.head()

Unnamed: 0,geometry_x,speed,direction,altitude,trajectory_id,y,h3_index,geometry_y
0,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa550b3ffff,"POLYGON ((116.39991 39.96732, 116.39963 39.965..."
1,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa5554bffff,"POLYGON ((116.39870 39.96977, 116.39842 39.968..."
2,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa55097ffff,"POLYGON ((116.39545 39.96981, 116.39517 39.968..."
3,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa55543ffff,"POLYGON ((116.40075 39.97219, 116.40047 39.970..."
4,"LINESTRING (116.40103 39.96627, 116.40104 39.9...","[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...",20081227072604,POINT (116.401691 39.951529),8931aa55393ffff,"POLYGON ((116.40063 39.95269, 116.40035 39.951..."


In [40]:
# Attach each matched region's polygon. Both sides have a `geometry` column, so
# the result carries `geometry_x` (trajectory window) and `geometry_y` (region
# polygon), the names plot_both reads.
# NOTE(review): not idempotent. Re-running this cell merges again and appends
# another region-geometry column.
joined_gdf = joined_gdf.merge(
    gpd.GeoDataFrame(
        {"h3_index": regions.index, "geometry": regions.geometry}
    ),
    on="h3_index",
    how="left",
)

In [25]:
joined_gdf.shape, regions.shape, gdf_exp.shape

((11, 7), (10, 1), (2, 6))

In [42]:
def plot_both(row):
    """Draw a trajectory and one region outline on an OpenStreetMap basemap.

    ``row['geometry_x']`` must expose ``.coords`` (a trajectory LineString) and
    ``row['geometry_y']`` must expose ``.exterior.coords`` (a region Polygon).
    Each shape becomes its own marker+line trace, trajectory first, and the
    figure is shown immediately.
    """
    figure = go.Figure()
    shapes = (row['geometry_x'].coords, row['geometry_y'].exterior.coords)
    for shape_coords in shapes:
        vertices = list(shape_coords)
        figure.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=[vertex[0] for vertex in vertices],
            lat=[vertex[1] for vertex in vertices],
            marker={'size': 10},
        ))
    figure.update_layout(
        mapbox_style="open-street-map",
        mapbox_zoom=10,
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )
    figure.show()

In [57]:
# Attach each region's embedding to every joined trajectory row. The join key on
# the left is the `region_id` index level of `embeddings`. The inner join drops
# rows whose h3_index has no embedding, and the result gets a fresh integer index.
merged_gdf = embeddings.merge(
    joined_gdf, how="inner", left_on="region_id", right_on="h3_index"
)

In [58]:
merged_gdf.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,geometry_x,speed,direction,altitude,trajectory_id,y,h3_index,geometry_y
0,-0.018811,-0.109211,-0.041467,0.047751,0.015788,-0.054962,0.013393,0.038641,-0.003545,0.021667,"LINESTRING (116.40103 39.96627, 116.40104 39.9...","[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...",20081227072604,POINT (116.401691 39.951529),8931aa55393ffff,"POLYGON ((116.40063 39.95269, 116.40035 39.951..."
1,-0.057792,-0.156122,-0.091577,0.087942,0.041663,-0.052072,0.038837,0.034544,0.038947,-0.008255,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa55097ffff,"POLYGON ((116.39545 39.96981, 116.39517 39.968..."
2,0.027579,-0.079779,-0.049456,0.030406,0.008061,-0.024665,-0.005402,0.045945,-0.040692,0.025304,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,POINT (116.401025 39.966273),8931aa550b3ffff,"POLYGON ((116.39991 39.96732, 116.39963 39.965..."
3,0.027579,-0.079779,-0.049456,0.030406,0.008061,-0.024665,-0.005402,0.045945,-0.040692,0.025304,"LINESTRING (116.40103 39.96627, 116.40104 39.9...","[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...",20081227072604,POINT (116.401691 39.951529),8931aa550b3ffff,"POLYGON ((116.39991 39.96732, 116.39963 39.965..."
4,-0.098882,-0.181297,-0.073567,0.072909,0.032905,-0.097502,0.078073,0.052253,0.048962,0.013693,"LINESTRING (116.40103 39.96627, 116.40104 39.9...","[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...",20081227072604,POINT (116.401691 39.951529),8931aa55033ffff,"POLYGON ((116.40027 39.96001, 116.39999 39.958..."


In [254]:
def get_hex(point, res = HEX_RES):
    """Return the id of the H3 cell containing ``point`` at resolution ``res``.

    ``point`` is read in shapely coordinate order: ``point[0]`` is longitude
    and ``point[1]`` is latitude. ``h3.latlng_to_cell`` takes latitude first,
    so the two are swapped here.
    """
    lon, lat = point[0], point[1]
    return h3.latlng_to_cell(lat, lon, res)

In [255]:
def get_emb(point):
    """Return the row of ``embeddings`` for the H3 cell that contains ``point``.

    Raises ``KeyError`` (from ``.loc``) when that cell has no embedding.
    """
    cell_id = get_hex(point)
    return embeddings.loc[cell_id]

In [256]:
def get_emb_list(traj):
    """Return one region embedding per vertex of ``traj``, in vertex order."""
    return [get_emb(vertex) for vertex in traj.coords]

In [257]:
def make_seq(row, seq_len=200):
    """Split one trajectory row into consecutive, non-overlapping windows.

    ``row`` provides ``geometry`` (a LineString), the per-point lists ``speed``,
    ``direction`` and ``altitude``, and ``trajectory_id``.

    Yields one dict per window of ``seq_len`` points with the keys:

    - ``geometry``: the window as a LineString.
    - ``speed``, ``direction``, ``altitude``: the matching list slices.
    - ``trajectory_id``: taken from the row.
    - ``emb``: per-point region embeddings, from get_emb_list.
    - ``y``: the first point after the window, as a Point.

    Windows start at 0, seq_len, 2*seq_len, ... and are emitted only while a
    following point exists to serve as ``y``. Points after the last emitted
    target are therefore unused.
    """
    coords = row['geometry'].coords
    speed = row['speed']
    direction = row['direction']
    altitude = row['altitude']
    for start in range(0, len(coords) - seq_len, seq_len):
        stop = start + seq_len
        window = LineString(coords[start:stop])
        yield {
            'geometry': window,
            'speed': speed[start:stop],
            'direction': direction[start:stop],
            'altitude': altitude[start:stop],
            'trajectory_id': row['trajectory_id'],
            'emb': get_emb_list(window),
            'y': Point(coords[stop]),
        }

In [278]:
# One row per fixed-length window, gathered across every trajectory in gdf2.
seq_gdf2 = gpd.GeoDataFrame(
    [window for _, trajectory in gdf2.iterrows() for window in make_seq(trajectory)]
)

In [279]:
seq_gdf2.shape

(246, 7)

In [280]:
seq_gdf2.head()

Unnamed: 0,geometry,speed,direction,altitude,trajectory_id,emb,y
0,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604,"[[-0.29703888297080994, 0.5368309617042542, -0...",POINT (116.401025 39.966273)
1,"LINESTRING (116.40103 39.96627, 116.40104 39.9...","[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...",20081227072604,"[[-0.259323388338089, 0.4811534583568573, -0.1...",POINT (116.401691 39.951529)
2,"LINESTRING (116.40169 39.95153, 116.40169 39.9...","[1.6048093521818334, 1.8321806157672382, 1.963...","[194.33402454321873, 180.66545941872383, 177.5...","[183.0, 172.0, 172.0, 169.0, 169.0, 167.0, 166...",20081227072604,"[[-0.3428332805633545, 0.645214855670929, -0.2...",POINT (116.402274 39.939697)
3,"LINESTRING (116.40227 39.93970, 116.40227 39.9...","[3.337850361530505, 4.233099950289504, 3.28550...","[176.3441484920835, 184.61419885801112, 175.54...","[103.0, 102.0, 100.0, 97.0, 94.0, 96.0, 106.0,...",20081227072604,"[[-0.2439761757850647, 0.5035545229911804, -0....",POINT (116.392419 39.93942)
4,"LINESTRING (116.39242 39.93942, 116.39235 39.9...","[3.270758161746067, 5.942572253366157, 2.73726...","[247.99470744245548, 290.8689995274607, 272.33...","[156.0, 152.0, 148.0, 144.0, 138.0, 138.0, 138...",20081227072604,"[[-0.21514934301376343, 0.4306723177433014, -0...",POINT (116.389321 39.936959)


In [281]:
test = pd.DataFrame(seq_gdf2['emb'][0])

In [282]:
test

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
8931aa55097ffff,-0.297039,0.536831,-0.149996,0.254206,-0.532098,-0.659755,-1.129411,0.171043,0.756493,-0.577874
8931aa5554bffff,-0.297899,0.530968,-0.171187,0.243807,-0.510270,-0.576723,-0.956570,0.175417,0.641116,-0.520225
8931aa5554bffff,-0.297899,0.530968,-0.171187,0.243807,-0.510270,-0.576723,-0.956570,0.175417,0.641116,-0.520225
8931aa5554bffff,-0.297899,0.530968,-0.171187,0.243807,-0.510270,-0.576723,-0.956570,0.175417,0.641116,-0.520225
8931aa5554bffff,-0.297899,0.530968,-0.171187,0.243807,-0.510270,-0.576723,-0.956570,0.175417,0.641116,-0.520225
...,...,...,...,...,...,...,...,...,...,...
8931aa550b3ffff,-0.259323,0.481153,-0.170086,0.199609,-0.446381,-0.433880,-0.654971,0.185986,0.455642,-0.404490
8931aa550b3ffff,-0.259323,0.481153,-0.170086,0.199609,-0.446381,-0.433880,-0.654971,0.185986,0.455642,-0.404490
8931aa550b3ffff,-0.259323,0.481153,-0.170086,0.199609,-0.446381,-0.433880,-0.654971,0.185986,0.455642,-0.404490
8931aa550b3ffff,-0.259323,0.481153,-0.170086,0.199609,-0.446381,-0.433880,-0.654971,0.185986,0.455642,-0.404490


In [283]:
# group columns to list
# Reshape the (points x embedding-dims) frame into a single row whose cell j is
# a nested list [[value of dimension j for every point]]. These are the same
# per-dimension lists that extend_emb stores as emb_j.
stest = test.transpose().groupby(level=0).apply(lambda x: x.values.tolist())
stest = pd.DataFrame(stest).transpose()

In [284]:
stest

Unnamed: 0,0,1,2,3,4,5,6,7,8,9
0,"[[-0.29703888297080994, -0.297899454832077, -0...","[[0.5368309617042542, 0.5309680104255676, 0.53...","[[-0.14999550580978394, -0.17118650674819946, ...","[[0.25420618057250977, 0.24380697309970856, 0....","[[-0.5320983529090881, -0.5102698802947998, -0...","[[-0.659754753112793, -0.5767227411270142, -0....","[[-1.1294108629226685, -0.9565703868865967, -0...","[[0.1710430383682251, 0.17541703581809998, 0.1...","[[0.756493330001831, 0.6411163806915283, 0.641...","[[-0.5778744220733643, -0.520225465297699, -0...."


In [285]:
def extend_emb(row):
    """Expand a window's per-point embeddings into one column per dimension.

    ``row['emb']`` is a sequence with one embedding per trajectory point (each a
    pandas Series, as produced by get_emb_list). The result is a copy of
    ``row`` with extra entries ``emb_0``, ``emb_1``, ..., where ``emb_j`` is the
    list of dimension ``j`` across all points, in point order.

    Dimensions are numbered by position, so the embedding column labels need not
    be the integers 0..n-1. The input ``row`` is not modified. Intended for
    ``DataFrame.apply(extend_emb, axis=1)``.
    """
    row = row.copy()
    # (n_points, n_dims) matrix; column j holds dimension j for every point.
    matrix = pd.DataFrame(list(row['emb'])).to_numpy()
    for dim in range(matrix.shape[1]):
        row[f'emb_{dim}'] = matrix[:, dim].tolist()
    return row

In [286]:
seq_gdf2 = seq_gdf2.apply(extend_emb, axis=1)

In [287]:
seq_gdf2['emb_0'][0][0]

-0.29703888297080994

In [288]:
# Expose each window's vertices as plain lon/lat lists and the target point as a
# [lon, lat] pair. The coordinates are read straight from the LineStrings rather
# than first writing lists into the active 'geometry' column. Doing that
# triggered a geopandas warning, because the column would no longer hold
# geometries. The 'geometry' column itself is left untouched here.
window_coords = seq_gdf2['geometry'].apply(lambda line: list(line.coords))
seq_gdf2['lon'] = window_coords.apply(lambda pts: [pt[0] for pt in pts])
seq_gdf2['lat'] = window_coords.apply(lambda pts: [pt[1] for pt in pts])
seq_gdf2['y'] = seq_gdf2['y'].apply(lambda point: [point.x, point.y])


  seq_gdf2['geometry'] = seq_gdf2['geometry'].apply(lambda x: list(x.coords))


In [290]:
# Drop columns that are not model inputs. Reassign instead of using
# inplace=True so the data lineage stays explicit in the notebook.
col = ['trajectory_id', 'emb', 'geometry']
seq_gdf2 = seq_gdf2.drop(columns=col)

In [28]:
# Features: every remaining column except the target. Target `y`: the [lon, lat]
# of the point that follows each window.
# NOTE(review): the split is random over windows. Windows from the same
# trajectory can land in both train and test, including adjacent windows where
# one window's target is the next window's first point. Consider splitting by
# trajectory instead.
X = seq_gdf2.drop('y', axis=1)
y = seq_gdf2['y']

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

NameError: name 'seq_gdf2' is not defined

In [None]:
# Data loaders.
# Every feature cell holds a per-point list for one window, and every target is
# a [lon, lat] pair. DataLoader's default_collate rejects numpy arrays of dtype
# object (what X.values yields for list cells), so stack everything into float32
# tensors first.
def _to_tensor_dataset(features, targets):
    """Build a TensorDataset of (seq_len, n_features) inputs and (2,) targets.

    Parameters
    ----------
    features : pandas.DataFrame
        One row per window; every cell is a sequence of the same length.
    targets : pandas.Series
        One ``[lon, lat]`` pair per row of ``features``.
    """
    inputs = np.stack([
        np.stack([np.asarray(cell, dtype=np.float32) for cell in window], axis=-1)
        for window in features.to_numpy()
    ])
    labels = np.asarray(targets.tolist(), dtype=np.float32)
    return torch.utils.data.TensorDataset(torch.from_numpy(inputs), torch.from_numpy(labels))


train_dataloader = torch.utils.data.DataLoader(
    _to_tensor_dataset(X_train, y_train),
    batch_size=32,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),  # reproducible shuffling
)
test_dataloader = torch.utils.data.DataLoader(
    _to_tensor_dataset(X_test, y_test),
    batch_size=32,
    shuffle=False,
)