In [1]:
import geopandas as gpd
import pandas as pd
import copy
import os
import warnings
import random
from pprint import pprint
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from srai.datasets import AirbnbMulticityDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_regions
from srai.regionalizers import H3Regionalizer
import plotly.graph_objs as go
from shapely.geometry import LineString, Point
from shapely import from_geojson
import h3
from srai.h3 import h3_to_geoseries

In [2]:
# Load the smoothed GeoLife GPS points: one row per point, indexed by timestamp (see preview below).
gdf = gpd.read_parquet('geolife_points_smooth.parquet')

In [3]:
# (number of GPS points, number of columns)
gdf.shape

(149173, 12)

In [4]:
# Preview the point-level columns.
gdf.head()

Unnamed: 0_level_0,latitude,longitude,altitude,date,date_str,trajectory_id,mode,geometry,user_id,speed,timedelta,direction
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2007-04-12 10:18:53,39.975517,116.330283,351.049869,39184.42978,2007-04-12,20070412101853_0,unknown,POINT (116.33028 39.97552),161,0.317458,NaT,15.013958
2007-04-12 10:23:15,39.976233,116.330567,118.110236,39184.432813,2007-04-12,20070412101853_0,unknown,POINT (116.33057 39.97623),161,0.317458,0 days 00:01:22,20.964607
2007-04-12 10:23:25,39.97585,116.3304,114.829396,39184.432928,2007-04-12,20070412102325_0,unknown,POINT (116.33040 39.97585),161,0.24934,NaT,20.964677
2007-04-12 10:26:25,39.976233,116.330567,118.110236,39184.435012,2007-04-12,20070412102325_0,unknown,POINT (116.33057 39.97623),161,0.24934,0 days 00:01:22,20.964607
2007-04-13 10:56:48,39.976217,116.33015,0.0,39185.456111,2007-04-13,20070413105648_0,unknown,POINT (116.33015 39.97622),161,0.077967,NaT,99.264399


In [5]:
# gdf_exp = gdf[:10000].copy()

In [6]:
# Aggregate the point table into one row per trajectory: the points become a LineString and
# the per-point attributes become lists aligned with its vertices.
# groupby keeps the input row order inside each group, so sort by the time index first to
# guarantee chronological vertex order (stable sort keeps ties in their original order).
points_by_time = gdf.sort_index(kind="stable")
# shapely's LineString needs at least two coordinates; drop single-point trajectories instead
# of letting the aggregation raise.
multi_point_traj = points_by_time.groupby("trajectory_id").filter(lambda points: len(points) >= 2)
gdf_agg = multi_point_traj.groupby("trajectory_id").agg(
    {
        "geometry": LineString,
        "date_str": list,
        "speed": list,
        "direction": list,
        "altitude": list,
        # Also keep the id as a regular column: make_seq reads row["trajectory_id"].
        "trajectory_id": "first",
    }
)

In [7]:
# Preview: one row per trajectory with list-valued attribute columns.
gdf_agg.head()

Unnamed: 0_level_0,geometry,date_str,speed,direction,altitude,trajectory_id
trajectory_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
20070412101853_0,"LINESTRING (116.33028 39.97552, 116.33057 39.9...","[2007-04-12, 2007-04-12]","[0.3174578150582134, 0.3174578150582134]","[15.013957914870332, 20.964606804911114]","[351.049868766404, 118.110236220472]",20070412101853_0
20070412102325_0,"LINESTRING (116.33040 39.97585, 116.33057 39.9...","[2007-04-12, 2007-04-12]","[0.24933972268054067, 0.24933972268054067]","[20.964676832270413, 20.964606804911114]","[114.829396325459, 118.110236220472]",20070412102325_0
20070413105648_0,"LINESTRING (116.33015 39.97622, 116.33028 39.9...","[2007-04-13, 2007-04-13, 2007-04-13, 2007-04-1...","[0.07796726674907895, 0.07796726674907895, 0.0...","[99.26439871977664, 99.26439871977664, 213.973...","[0.0, 55.7742782152231, 55.7742782152231, 170....",20070413105648_0
20070413150314_0,"LINESTRING (116.31905 39.92687, 116.31551 39.9...","[2007-04-13, 2007-04-13, 2007-04-13, 2007-04-13]","[9.889579092536993, 9.889579092536993, 0.48658...","[269.68467918410283, 269.68467918410283, 336.9...","[252.624671916011, 252.624671916011, 262.46719...",20070413150314_0
20070414022829_0,"LINESTRING (116.31554 39.92693, 116.31654 39.9...","[2007-04-14, 2007-04-14, 2007-04-14, 2007-04-1...","[1.2964931291294477, 1.2964931291294477, 1.443...","[94.8087951122991, 104.61717890145906, 64.6302...","[242.782152230971, 193.569553805774, 147.63779...",20070414022829_0


In [8]:
# (number of trajectories, number of aggregated columns)
gdf_agg.shape

(1795, 6)

In [9]:
def plot_trajectories(trajectories, zoom=10):
    """Draw one or more trajectories on an OpenStreetMap basemap.

    Args:
        trajectories: iterable whose items are either shapely ``LineString``s or
            sequences of ``(lon, lat)`` pairs. Each item is converted independently,
            so mixed inputs and generators work. An empty iterable shows an empty map.
        zoom: initial map zoom level (default 10, as before).
    """
    fig = go.Figure()
    for trajectory in trajectories:
        # Convert per item instead of inspecting only trajectories[0], which failed on
        # empty inputs and non-indexable iterables.
        points = list(trajectory.coords) if isinstance(trajectory, LineString) else list(trajectory)
        fig.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=[point[0] for point in points],
            lat=[point[1] for point in points],
            marker={'size': 10},
        ))

    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_zoom=zoom,
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )

    fig.show()

In [10]:
# Pick the torch device. CPU is forced here, which matches the previous hard override of
# the auto-detected value. Set FORCE_CPU = False to use CUDA/MPS when available.
FORCE_CPU = True
if FORCE_CPU:
    device = "cpu"
elif torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
device

'cpu'

## Regiony H3 (rozdzielczość `HEX_RES`) obejmujące wszystkie punkty trajektorii

In [11]:
# H3 resolution used both for the regionalizer and for point-to-cell lookups (get_hex).
HEX_RES = 8

In [12]:
# Cover the area of all GPS points in gdf with H3 cells at resolution HEX_RES.
regionalizer = H3Regionalizer(resolution=HEX_RES)
regions = regionalizer.transform(gdf)

In [13]:
# Visual check of the H3 cells produced above.
plot_regions(regions)

In [14]:
# Load OSM features (from PBF extracts) for the area of `regions`, restricted to the
# tag filter used by Hex2Vec.
loader = OSMPbfLoader()
features = loader.load(regions, HEX2VEC_FILTER)

Finding matching extracts:   0%|          | 0/258 [00:00<?, ?it/s]

Filtering extracts:   0%|          | 0/258 [00:00<?, ?it/s]

  ].unary_union
python(1469) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1470) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1471) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1472) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1473) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1474) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1475) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1476) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1477) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1478) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(1479) M

In [17]:
# Relate regions and OSM features: which features intersect which region.
joiner = IntersectionJoiner()
joint = joiner.transform(regions, features)

In [18]:
# Train Hex2Vec embeddings (one vector per H3 region) from the OSM features joined to each region.
# Seed every RNG that may be involved so the embeddings, and everything built on them, are
# reproducible across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

neighbourhood = H3Neighbourhood(regions)
# Encoder layer widths; the last entry is the embedding dimension.
embedder_hidden_sizes = [150, 100, 50, 10]
embedder = Hex2VecEmbedder(embedder_hidden_sizes)
with warnings.catch_warnings():
    # Silence all library warnings emitted during training.
    warnings.simplefilter("ignore")
    embeddings = embedder.fit_transform(
        regions,
        features,
        joint,
        neighbourhood,
        trainer_kwargs={"max_epochs": 5, "accelerator": device},
        batch_size=100,
    )
embeddings.shape

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3545/3545 [00:00<00:00, 50455.59it/s]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 86.2 K
---------------------------------------
86.2 K    Trainable params
0         Non-trainable params
86.2 K    Total params
0.345     Total estimated model params size (MB)


Training: |                                                                                                   …

`Trainer.fit` stopped: `max_epochs=5` reached.


(3545, 10)

In [19]:
# Region embeddings: index region_id (H3 cell), one column per embedding dimension.
embeddings

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
8831aa55dbfffff,-0.440059,-0.585884,0.214865,0.308701,0.116530,0.227551,0.636538,0.163365,0.520458,0.201962
8831aa4edbfffff,-0.579493,-0.064501,-0.125170,0.875584,0.011042,0.486846,-0.151257,0.037258,0.239113,-0.179814
8831aa0287fffff,-0.341617,0.340432,-0.133646,0.739794,0.104246,0.496847,-0.517220,-0.128635,-0.263340,-0.224164
8831aac869fffff,-0.078744,0.389650,-0.102333,-0.451383,0.290360,-0.064100,-0.272063,-0.393590,-0.323008,0.223041
883181b051fffff,0.042980,0.869908,-0.270446,-0.249111,0.186374,-0.014133,-1.009785,-0.660062,-0.781883,-0.272112
...,...,...,...,...,...,...,...,...,...,...
8831852c11fffff,-0.505142,-0.203354,-0.297478,0.031365,0.368270,0.329756,0.281796,-0.230674,0.427896,0.355824
8831abdb51fffff,-0.222275,0.425558,-0.512584,0.209994,0.026025,0.292281,-0.689192,-0.366923,-0.233515,-0.246887
8831aa558bfffff,0.947742,0.449744,0.647888,0.116489,-0.818969,-0.258875,-0.254558,0.706311,-0.102929,-0.975531
8831aa4f0dfffff,-0.200948,-0.411213,0.189203,-0.129222,0.222507,-0.021656,0.535167,0.150476,0.224016,0.469305


## Przypisujemy każdy punkt trajektorii do heksagonu H3, w którym leży — np. trajektoria przechodząca przez 3 heksagony ma punkty przypisane do 3 różnych heksagonów

In [20]:
# Tag each GPS point with the id of the H3 cell that contains it (left join keeps every point).
# geopandas >= 1.0 names the joined column after the right frame's index name (here it would be
# "region_id"), while older versions always used "index_right". Naming the right index
# "h3_index" and renaming "index_right" yields an "h3_index" column under both behaviours.
joined_gdf = gpd.sjoin(gdf, regions.rename_axis("h3_index"), how="left").rename(
    columns={"index_right": "h3_index"}
)

In [21]:
# Preview: the point columns plus the h3_index of the containing cell.
joined_gdf.head()

Unnamed: 0_level_0,latitude,longitude,altitude,date,date_str,trajectory_id,mode,geometry,user_id,speed,timedelta,direction,h3_index
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
2007-04-12 10:18:53,39.975517,116.330283,351.049869,39184.42978,2007-04-12,20070412101853_0,unknown,POINT (116.33028 39.97552),161,0.317458,NaT,15.013958,8831aa50e3fffff
2007-04-12 10:23:15,39.976233,116.330567,118.110236,39184.432813,2007-04-12,20070412101853_0,unknown,POINT (116.33057 39.97623),161,0.317458,0 days 00:01:22,20.964607,8831aa508dfffff
2007-04-12 10:23:25,39.97585,116.3304,114.829396,39184.432928,2007-04-12,20070412102325_0,unknown,POINT (116.33040 39.97585),161,0.24934,NaT,20.964677,8831aa50e3fffff
2007-04-12 10:26:25,39.976233,116.330567,118.110236,39184.435012,2007-04-12,20070412102325_0,unknown,POINT (116.33057 39.97623),161,0.24934,0 days 00:01:22,20.964607,8831aa508dfffff
2007-04-13 10:56:48,39.976217,116.33015,0.0,39185.456111,2007-04-13,20070413105648_0,unknown,POINT (116.33015 39.97622),161,0.077967,NaT,99.264399,8831aa50e3fffff


In [22]:
# Attach each point's hexagon polygon. After the merge the point geometry is `geometry_x`
# and the region polygon is `geometry_y` (pandas merge suffixes).
region_polygons = regions[["geometry"]].rename_axis("h3_index").reset_index()
joined_gdf = joined_gdf.merge(region_polygons, on="h3_index", how="left")

In [23]:
# Sanity check: joined_gdf should keep one row per GPS point (same row count as gdf).
joined_gdf.shape, regions.shape, gdf.shape

((149173, 14), (3545, 1), (149173, 12))

In [24]:
def plot_both(row, zoom=10):
    """Plot a geometry together with the outline of its H3 region on an OSM basemap.

    Args:
        row: mapping with ``geometry_x`` (a shapely geometry exposing ``.coords``, e.g. the
            GPS point) and ``geometry_y`` (the region polygon, possibly missing).
        zoom: initial map zoom level (default 10, as before).
    """
    fig = go.Figure()

    def _add_trace(coords):
        # Add one markers+lines trace from an iterable of (lon, lat, ...) tuples.
        coords = list(coords)
        fig.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=[point[0] for point in coords],
            lat=[point[1] for point in coords],
            marker={'size': 10},
        ))

    _add_trace(row['geometry_x'].coords)

    region = row['geometry_y']
    # A left join leaves geometry_y missing (NaN/None) for points outside every region;
    # skip the outline instead of failing on `.exterior`.
    if getattr(region, 'exterior', None) is not None:
        _add_trace(region.exterior.coords)

    fig.update_layout(
        mapbox_style="open-street-map",
        mapbox_zoom=zoom,
        margin={"r": 0, "t": 0, "l": 0, "b": 0},
    )
    fig.show()

In [25]:
# Attach the region embedding to every point: inner join of the embeddings' `region_id` index
# with the points' `h3_index` column. Points whose cell has no embedding are dropped, and the
# result gets a fresh integer index (the time index is not kept).
merged_gdf = embeddings.merge(
    joined_gdf, how="inner", left_on="region_id", right_on="h3_index"
)

In [26]:
# Preview: embedding columns followed by the point and region columns.
merged_gdf.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,date_str,trajectory_id,mode,geometry_x,user_id,speed,timedelta,direction,h3_index,geometry_y
0,-0.440059,-0.585884,0.214865,0.308701,0.11653,0.227551,0.636538,0.163365,0.520458,0.201962,...,2008-06-20,20080619225306_0,unknown,POINT (116.47243 39.96379),167,11.114936,0 days 00:00:01,125.70475,8831aa55dbfffff,"POLYGON ((116.46820 39.96659, 116.46586 39.962..."
1,-0.579493,-0.064501,-0.12517,0.875584,0.011042,0.486846,-0.151257,0.037258,0.239113,-0.179814,...,2007-08-05,20070805013330_0,unknown,POINT (116.47425 39.80912),150,17.167552,0 days 00:01:54,242.922894,8831aa4edbfffff,"POLYGON ((116.47553 39.81312, 116.47320 39.809..."
2,-0.579493,-0.064501,-0.12517,0.875584,0.011042,0.486846,-0.151257,0.037258,0.239113,-0.179814,...,2009-05-16,20090515221447_1,unknown,POINT (116.47869 39.81121),41,22.321237,0 days 00:00:01,242.859099,8831aa4edbfffff,"POLYGON ((116.47553 39.81312, 116.47320 39.809..."
3,-0.579493,-0.064501,-0.12517,0.875584,0.011042,0.486846,-0.151257,0.037258,0.239113,-0.179814,...,2009-05-16,20090515221447_1,unknown,POINT (116.47598 39.80996),41,21.528096,0 days 00:00:01,238.735983,8831aa4edbfffff,"POLYGON ((116.47553 39.81312, 116.47320 39.809..."
4,-0.579493,-0.064501,-0.12517,0.875584,0.011042,0.486846,-0.151257,0.037258,0.239113,-0.179814,...,2009-05-16,20090515221447_1,unknown,POINT (116.47408 39.80880),41,23.195576,0 days 00:00:01,229.066825,8831aa4edbfffff,"POLYGON ((116.47553 39.81312, 116.47320 39.809..."


In [27]:
def get_hex(point, res=HEX_RES):
    """Return the id of the H3 cell (at resolution ``res``) that contains ``point``.

    ``point`` is indexed as (lon, lat, ...), like an item of shapely ``.coords``;
    h3 takes latitude first, hence the swap.
    """
    lon, lat = point[0], point[1]
    return h3.latlng_to_cell(lat, lon, res)

In [28]:
def get_emb(point):
    """Return the embedding row of the H3 region containing ``point`` (lon, lat)."""
    region_id = get_hex(point)
    return embeddings.loc[region_id]

In [29]:
def get_emb_list(traj):
    """Return one region embedding per vertex of ``traj``, in vertex order."""
    return [get_emb(vertex) for vertex in traj.coords]

In [30]:
def make_seq(row, seq_len=200, step=None):
    """Cut one aggregated trajectory into fixed-length windows with a next-point target.

    Args:
        row: one row of ``gdf_agg``: a LineString ``geometry`` plus per-vertex lists
            ``speed``, ``direction``, ``altitude`` and a ``trajectory_id``.
        seq_len: number of vertices per window (must be >= 2, since a LineString needs
            at least two points).
        step: offset between consecutive window starts. Defaults to ``seq_len``, giving
            the original non-overlapping windows. A smaller value yields overlapping windows.

    Yields:
        dict with the window ``geometry`` (LineString), sliced ``speed``/``direction``/
        ``altitude`` lists, ``trajectory_id``, ``emb`` (per-vertex region embeddings) and
        ``y``: the Point immediately following the window. Windows without a following
        point are not emitted.

    Raises:
        ValueError: if ``seq_len < 2`` or ``step < 1``.
    """
    if seq_len < 2:
        raise ValueError(f"seq_len must be >= 2 to build a LineString, got {seq_len}")
    step = seq_len if step is None else step
    if step < 1:
        raise ValueError(f"step must be >= 1, got {step}")

    # Materialise the coordinates once instead of re-slicing the CoordinateSequence.
    coords = list(row['geometry'].coords)
    speed = row['speed']
    direction = row['direction']
    altitude = row['altitude']

    # The stop bound keeps start + seq_len a valid index, so a target point always exists.
    for start in range(0, len(coords) - seq_len, step):
        stop = start + seq_len
        window = LineString(coords[start:stop])  # built once, reused for 'emb'
        yield {
            'geometry': window,
            'speed': speed[start:stop],
            'direction': direction[start:stop],
            'altitude': altitude[start:stop],
            'trajectory_id': row['trajectory_id'],
            'emb': get_emb_list(window),
            'y': Point(coords[stop]),
        }

# SEQ_GDF

In [44]:
# Cut every aggregated trajectory into windows of SEQ_LEN points (see make_seq).
SEQ_LEN = 30
# Distinct loop names: the original comprehension reused `row` for both loops.
sequence_records = [
    window
    for _, trajectory_row in gdf_agg.iterrows()
    for window in make_seq(trajectory_row, seq_len=SEQ_LEN)
]
seq_gdf = gpd.GeoDataFrame(sequence_records, crs=gdf.crs)

In [45]:
# (number of windows, number of columns per window record)
seq_gdf.shape

(4099, 7)

In [46]:
# test = pd.DataFrame(seq_gdf['emb'][0])

In [47]:
# # group columns to list
# stest = test.transpose().groupby(level=0).apply(lambda x: x.values.tolist())
# stest = pd.DataFrame(stest).transpose()

In [48]:
def extend_emb(row):
    """Split a window's per-point embeddings into one list per embedding dimension.

    ``row['emb']`` is a list with one embedding vector (e.g. a pandas Series) per point.
    For every dimension ``i``, ``row[f'emb_{i}']`` is set to the list of that dimension's
    values across the points, in point order. The mutated ``row`` is returned
    (suitable for ``DataFrame.apply(..., axis=1)``). An empty ``emb`` leaves ``row`` unchanged.
    """
    point_vectors = row['emb']
    if len(point_vectors) == 0:
        return row
    # (n_points, n_dims) -> transpose -> one row per embedding dimension.
    # This replaces the previous per-row DataFrame/transpose/groupby round trip.
    matrix = np.vstack([np.asarray(vector, dtype=float) for vector in point_vectors])
    for i, dim_values in enumerate(matrix.T.tolist()):
        row[f'emb_{i}'] = dim_values
    return row

In [49]:
# Expand each window's per-point embedding vectors into emb_0..emb_{d-1} list columns.
seq_gdf = seq_gdf.apply(extend_emb, axis=1)

In [50]:
# Flatten the window geometries into plain per-step coordinate lists for the model.
# The coordinates are read from the LineStrings; the active geometry column is no longer
# overwritten with Python lists (geopandas warned about that). The geometry column itself is
# dropped later in the notebook.
window_coords = seq_gdf['geometry'].apply(lambda line: list(line.coords))
seq_gdf['lon'] = window_coords.apply(lambda pts: [pt[0] for pt in pts])
seq_gdf['lat'] = window_coords.apply(lambda pts: [pt[1] for pt in pts])
# Target: the point right after each window, as [lon, lat].
seq_gdf['y'] = seq_gdf['y'].apply(lambda pt: [pt.x, pt.y])


  seq_gdf['geometry'] = seq_gdf['geometry'].apply(lambda x: list(x.coords))


In [51]:
# Drop bookkeeping columns; what remains are per-step feature lists plus the target `y`.
col = ['trajectory_id', 'emb', 'geometry']
# Once the geometry column is gone nothing spatial is left, so hold a plain DataFrame.
# Reassigning (instead of inplace=True) makes the cell's effect explicit.
seq_gdf = pd.DataFrame(seq_gdf.drop(columns=col))

In [52]:
# Preview the final model table: list-valued feature columns and the [lon, lat] target.
seq_gdf.head()

Unnamed: 0,speed,direction,altitude,y,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,lon,lat
0,"[1.9352457750899619, 1.9352457750899619, 0.894...","[355.52583819652267, 355.52583819652267, 56.89...","[688.976377952756, 688.976377952756, 688.97637...","[116.33364290439265, 39.9873397183338]","[0.6838635206222534, 1.0367484092712402, 1.036...","[-0.2248058319091797, 0.19317646324634552, 0.1...","[0.4852322041988373, 0.744053304195404, 0.7440...","[0.0734969973564148, 0.1836797297000885, 0.183...","[-0.700340747833252, -0.7748642563819885, -0.7...","[-0.23136866092681885, -0.36025261878967285, -...","[0.22274862229824066, -0.033303387463092804, -...","[0.6439347267150879, 0.8542038202285767, 0.854...","[0.08180934190750122, -0.21048569679260254, -0...","[-0.4374580681324005, -0.717753529548645, -0.7...","[116.3203436324123, 116.32030114464935, 116.32...","[39.9291209303895, 39.929871495701605, 39.9300..."
1,"[1.1239279560300646, 1.1239279560300646, 1.705...","[106.74764671955836, 89.99993047235131, 26.012...","[646.325459317585, 646.325459317585, 646.32545...","[116.31728186678397, 39.965828968334925]","[0.9132871627807617, 0.9132871627807617, 0.683...","[0.420354425907135, 0.420354425907135, -0.2248...","[0.5214747190475464, 0.5214747190475464, 0.485...","[0.028186190873384476, 0.028186190873384476, 0...","[-0.748537540435791, -0.748537540435791, -0.70...","[-0.2477864921092987, -0.2477864921092987, -0....","[-0.2597447335720062, -0.2597447335720062, 0.2...","[0.6203569769859314, 0.6203569769859314, 0.643...","[-0.12753674387931824, -0.12753674387931824, 0...","[-0.8862132430076599, -0.8862132430076599, -0....","[116.31575873004054, 116.31640344579755, 116.3...","[39.92610805798446, 39.92602371185887, 39.9263..."
2,"[0.06563349478627034, 0.06563349478627034, 8.3...","[119.21604933526879, 61.466344718072435, 332.2...","[203.412073490814, 236.220472440945, 269.02887...","[116.31988947986567, 39.92852736551485]","[0.7315746545791626, 0.7315746545791626, 0.731...","[-0.09455197304487228, -0.09455197304487228, -...","[0.5651118755340576, 0.5651118755340576, 0.565...","[0.08710302412509918, 0.08710302412509918, 0.0...","[-0.553844153881073, -0.553844153881073, -0.55...","[-0.28093671798706055, -0.28093671798706055, -...","[0.18330274522304535, 0.18330274522304535, 0.1...","[0.7112036943435669, 0.7112036943435669, 0.711...","[-0.0829773023724556, -0.0829773023724556, -0....","[-0.3958691954612732, -0.3958691954612732, -0....","[116.3276166256488, 116.32788640672595, 116.32...","[39.97429973930824, 39.97442150088496, 39.9744..."
3,"[0.37238197926286115, 0.37238197926286115, 0.4...","[23.648012741736125, 23.648012741736125, 303.1...","[164.041994750656, 285.433070866142, 187.00787...","[116.3073927411942, 39.98057102693695]","[1.6098276376724243, 1.6098276376724243, 1.609...","[0.14544053375720978, 0.14544053375720978, 0.1...","[0.1702210009098053, 0.1702210009098053, 0.170...","[-0.37662407755851746, -0.37662407755851746, -...","[-1.4968574047088623, -1.4968574047088623, -1....","[-0.5751749873161316, -0.5751749873161316, -0....","[-0.1121007427573204, -0.1121007427573204, -0....","[0.7581111192703247, 0.7581111192703247, 0.758...","[0.19667595624923706, 0.19667595624923706, 0.1...","[-1.0017015933990479, -1.0017015933990479, -1....","[116.33141685274326, 116.33149042087513, 116.3...","[39.976804029957734, 39.97691108087521, 39.977..."
4,"[0.5742396563635449, 0.5742396563635449, 0.611...","[277.43540358738346, 277.43540358738346, 188.7...","[108.267716535433, 108.267716535433, 108.26771...","[116.30093577822865, 39.968393183800075]","[0.9284706115722656, 0.9284706115722656, 0.928...","[0.3758922815322876, 0.3758922815322876, 0.375...","[0.30908721685409546, 0.30908721685409546, 0.3...","[-0.06933453679084778, -0.06933453679084778, -...","[-0.8942544460296631, -0.8942544460296631, -0....","[-0.23269900679588318, -0.23269900679588318, -...","[-0.28407737612724304, -0.28407737612724304, -...","[0.537669837474823, 0.537669837474823, 0.53766...","[0.10160468518733978, 0.10160468518733978, 0.1...","[-0.9181151390075684, -0.9181151390075684, -0....","[116.32833480683128, 116.32819935115883, 116.3...","[39.981171013599514, 39.98117835026088, 39.980..."


In [53]:
# Features: every remaining list-valued column; target: [lon, lat] of the point after each window.
X = seq_gdf.drop('y', axis=1)
y = seq_gdf['y']

# NOTE(review): this is a random split over windows, not over trajectories. Windows of one
# trajectory can land in both splits, and (with make_seq's default stride) a window's target is
# the first point of the next window. Test inputs may therefore contain train targets (leakage).
# Consider a group split on trajectory_id (that column is dropped before this cell).
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((3279, 15), (820, 15), (3279,), (820,))

In [54]:
# Persist the train/test splits. torch.save pickles the pandas objects as-is (they are not tensors).
# NOTE(review): newer PyTorch releases default torch.load(..., weights_only=True), which refuses
# pickled pandas objects. Load these with weights_only=False (trusted files only), or store them
# with pandas' own to_pickle/to_parquet instead.
torch.save((X_train, y_train), 'train_data.pt')
torch.save((X_test, y_test), 'test_data.pt')

In [42]:
class SequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset over a table of list-valued features.

    ``X`` is a DataFrame whose cells hold equal-length lists (one column per feature);
    ``y`` is a Series of per-row targets. Item ``idx`` is a pair of float32 tensors:
    the stacked feature lists, shaped (num_columns, list_length), and the target.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        feature_lists = list(self.X.iloc[idx].to_numpy())
        target = self.y.iloc[idx]
        return (
            torch.tensor(feature_lists, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32),
        )

In [43]:
BATCH_SIZE = 32
LOADER_SEED = 42  # fixes the shuffle order of the training batches

train_dataset = SequenceDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Dedicated seeded generator: reproducible batch order without touching the global RNG.
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

test_dataset = SequenceDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)