In [39]:
import geopandas as gpd
import pandas as pd
import copy
import os
import warnings
import random
from pprint import pprint
import geopandas as gpd
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from sklearn.model_selection import train_test_split

from srai.datasets import AirbnbMulticityDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders.osm_loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods.h3_neighbourhood import H3Neighbourhood
from srai.plotting import plot_regions
from srai.regionalizers import H3Regionalizer
import plotly.graph_objs as go
from shapely.geometry import LineString, Point
from shapely import from_geojson
import h3
from srai.h3 import h3_to_geoseries

In [2]:
# Load the preprocessed GeoLife GPS points: one row per fix, indexed by `time` (see output below).
# The path is relative to the notebook's working directory.
gdf = gpd.read_parquet('output_data/geolife_points.parquet')

In [3]:
gdf.shape

(49999, 12)

In [5]:
gdf.head()

Unnamed: 0_level_0,latitude,longitude,altitude,date,date_str,trajectory_id,mode,geometry,user_id,speed,timedelta,direction
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2008-12-27 07:26:04,39.969352,116.398589,492.0,39809.309769,2008-12-27,20081227072604,unknown,POINT (116.39859 39.96935),135,13.881385,NaT,160.099874
2008-12-27 07:26:06,39.969117,116.3987,492.0,39809.309792,2008-12-27,20081227072604,unknown,POINT (116.39870 39.96912),135,13.881385,0 days 00:00:02,160.099874
2008-12-27 07:26:07,39.969016,116.398799,492.0,39809.309803,2008-12-27,20081227072604,unknown,POINT (116.39880 39.96902),135,14.046284,0 days 00:00:01,143.085532
2008-12-27 07:26:08,39.968949,116.398868,491.0,39809.309815,2008-12-27,20081227072604,unknown,POINT (116.39887 39.96895),135,9.491683,0 days 00:00:01,141.717016
2008-12-27 07:26:09,39.968893,116.398905,491.0,39809.309826,2008-12-27,20081227072604,unknown,POINT (116.39890 39.96889),135,6.975264,0 days 00:00:01,153.143848


In [6]:
# gdf_exp = gdf[:10000].copy()

In [7]:
# One row per trajectory (groups sorted by trajectory_id): the Point geometries are joined into a
# single LineString in their existing row order within the group (presumably chronological — verify
# upstream), and the per-point attributes become aligned Python lists. trajectory_id is also kept as
# a regular column because make_seq reads row['trajectory_id'].
# NOTE(review): shapely's LineString needs >= 2 points, so a single-point trajectory would fail here.
gdf_agg = gdf.groupby('trajectory_id').agg({'geometry': LineString, 'date_str': list, 'speed': list, 'direction': list, 'altitude': list, 'trajectory_id': 'first'})

In [8]:
gdf_agg.head()

Unnamed: 0_level_0,geometry,date_str,speed,direction,altitude,trajectory_id
trajectory_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
20081227072604,"LINESTRING (116.39859 39.96935, 116.39870 39.9...","[2008-12-27, 2008-12-27, 2008-12-27, 2008-12-2...","[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...",20081227072604
20090101024458,"LINESTRING (116.33176 39.93721, 116.33181 39.9...","[2009-01-01, 2009-01-01, 2009-01-01, 2009-01-0...","[26.440512857858558, 26.440512857858558, 2.414...","[9.265462560995047, 9.265462560995047, 14.9894...","[492.0, 492.0, 492.0, 492.0, 492.0, 492.0, 492...",20090101024458
20090102043127,"LINESTRING (116.39653 39.96581, 116.39645 39.9...","[2009-01-02, 2009-01-02, 2009-01-02, 2009-01-0...","[4.726638177926822, 4.726638177926822, 9.32635...","[313.0537264342987, 313.0537264342987, 116.225...","[492.0, 492.0, 492.0, 492.0, 492.0, 492.0, 492...",20090102043127
20090103012134,"LINESTRING (116.39974 39.97429, 116.39959 39.9...","[2009-01-03, 2009-01-03, 2009-01-03, 2009-01-0...","[12.73039176894192, 12.73039176894192, 6.18924...","[268.9965774862792, 268.9965774862792, 287.822...","[492.0, 492.0, 492.0, 492.0, 491.0, 491.0, 491...",20090103012134
20090110011947,"LINESTRING (116.39828 39.97375, 116.39830 39.9...","[2009-01-10, 2009-01-10, 2009-01-10, 2009-01-1...","[22.682514166001585, 22.682514166001585, 2.014...","[3.010510124255461, 3.010510124255461, 257.213...","[54.0, 77.0, 99.0, 101.0, 93.0, 107.0, 114.0, ...",20090110011947


In [9]:
gdf_agg.shape

(9, 6)

In [10]:
def plot_trajectories(trajectories, zoom=10):
    """Plot one or more trajectories as marker+line traces on an OpenStreetMap basemap.

    Parameters
    ----------
    trajectories : iterable
        Each element is either a shapely ``LineString`` or a sequence of
        ``(lon, lat, ...)`` coordinate tuples. Any iterable is accepted, including
        a pandas Series with a non-positional index (e.g. ``gdf_agg['geometry']``,
        indexed by trajectory_id), and an empty iterable (draws an empty map).
    zoom : int or float, optional
        Initial mapbox zoom level (default 10, as before).

    The map is centred on the mean of all plotted coordinates: without an
    explicit centre, a plain ``go.Figure`` opens at lon=0, lat=0, away from the data.
    """
    # Convert element by element: mixed inputs work, and we never index
    # `trajectories[0]`, which is a *label* lookup on a pandas Series.
    coord_lists = [
        list(trajectory.coords) if isinstance(trajectory, LineString) else list(trajectory)
        for trajectory in trajectories
    ]

    fig = go.Figure()
    all_lon, all_lat = [], []
    for coords in coord_lists:
        lon = [point[0] for point in coords]
        lat = [point[1] for point in coords]
        all_lon.extend(lon)
        all_lat.extend(lat)

        fig.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=lon,
            lat=lat,
            marker={'size': 10}
        ))

    layout = {
        "mapbox_style": "open-street-map",
        "mapbox_zoom": zoom,
        "margin": {"r": 0, "t": 0, "l": 0, "b": 0},
    }
    if all_lon:
        layout["mapbox_center"] = {
            "lon": sum(all_lon) / len(all_lon),
            "lat": sum(all_lat) / len(all_lat),
        }
    fig.update_layout(**layout)

    fig.show()

In [11]:
# Detect the best available accelerator, then optionally override it. CPU is forced here; the
# Hex2Vec trainer log in this run reports "GPU available: True (mps), used: False".
# NOTE(review): the reason for forcing CPU is not recorded — document it or set FORCE_CPU = False.
FORCE_CPU = True

detected_device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device = "cpu" if FORCE_CPU else detected_device
device

'cpu'

## Regiony H3 (rozdzielczość 9) obejmujące wszystkie punkty trajektorii

In [12]:
# H3 resolution shared by the regionalizer and by get_hex (point -> cell lookup); both must agree.
HEX_RES = 9

In [13]:
# Cover the GPS points with H3 cells at HEX_RES; `regions` is indexed by `region_id` (the H3 cell id).
regionalizer = H3Regionalizer(resolution=HEX_RES)
regions = regionalizer.transform(gdf)

In [14]:
regions.index

Index(['8940a63b463ffff', '89409015283ffff', '894092cc10fffff',
       '89319e00a07ffff', '89319ebb3b7ffff', '894095c2457ffff',
       '8940b315e83ffff', '89319e00aa3ffff', '89409249a87ffff',
       '893183c842bffff',
       ...
       '8940b377673ffff', '8940928028fffff', '8940928904fffff',
       '8940b36f16fffff', '8931822764fffff', '89319c92643ffff',
       '8940b314207ffff', '8940b150a73ffff', '8940b302b4fffff',
       '8940954909bffff'],
      dtype='object', name='region_id', length=4960)

In [15]:
# Load OpenStreetMap features for the area covered by `regions`, keeping only the tags in
# HEX2VEC_FILTER (these features are what the Hex2Vec embedder consumes below).
# The stray "].unary_union" line in the output appears to be a fragment of a library warning.
loader = OSMPbfLoader()
features = loader.load(regions, HEX2VEC_FILTER)

  ].unary_union


In [16]:
# Pair each OSM feature with every region it intersects; `joint` is passed to the embedder below.
joiner = IntersectionJoiner()
joint = joiner.transform(regions, features)

In [17]:
# Learn a Hex2Vec embedding per region from the OSM features joined to it.
# Encoder sizes 150 -> 100 -> 50 -> 10: each region gets a 10-dim vector (shape (4960, 10) below).
neighbourhood = H3Neighbourhood(regions)  # H3 cell adjacency over `regions`, passed to the embedder
embedder_hidden_sizes = [150, 100, 50, 10]
embedder = Hex2VecEmbedder(embedder_hidden_sizes)
# Silence all warnings, but only for the duration of the training call.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    embeddings = embedder.fit_transform(
        regions,
        features,
        joint,
        neighbourhood,
        # Short training run (5 epochs) on the `device` chosen earlier in the notebook.
        trainer_kwargs={"max_epochs": 5, "accelerator": device},
        batch_size=100,
    )
# NOTE(review): no random seed is set in the cells shown, so these embeddings (and everything
# downstream) will likely differ between runs.
embeddings.shape

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4960/4960 [00:00<00:00, 55558.74it/s]
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 54.1 K
---------------------------------------
54.1 K    Trainable params
0         Non-trainable params
54.1 K    Total params
0.216     Total estimated model params size (MB)


Training: |                                                                                                   …

`Trainer.fit` stopped: `max_epochs=5` reached.


(4960, 10)

In [18]:
embeddings

Unnamed: 0_level_0,0,1,2,3,4,5,6,7,8,9
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
8940a63b463ffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417
89409015283ffff,0.009939,-0.410331,0.402643,0.294530,-0.492534,-0.568063,-0.252152,-0.313117,-0.437425,-0.623071
894092cc10fffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417
89319e00a07ffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417
89319ebb3b7ffff,0.356493,-0.723548,0.727367,0.527343,-0.371635,-0.235114,-0.724987,-0.101245,-0.619012,-0.581032
...,...,...,...,...,...,...,...,...,...,...
89319c92643ffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417
8940b314207ffff,0.041535,-0.440938,0.424034,0.320857,-0.478545,-0.546044,-0.293037,-0.303098,-0.439275,-0.625058
8940b150a73ffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417
8940b302b4fffff,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.022630,0.360571,0.011850,0.290332,0.237417


## Każdemu punktowi trajektorii przypisujemy heksagon H3, w którym leży (np. trajektoria przechodząca przez 3 heksagony ma punkty z 3 różnymi wartościami `h3_index`)

In [19]:
# Attach to every GPS point the H3 region it falls in (left join: unmatched points keep NaN).
joined_gdf = gpd.sjoin(gdf, regions, how="left")
# The right-index column is called "index_right" in older geopandas, but geopandas >= 1.0 names it
# after the right frame's index ("region_id"). Normalise either spelling to "h3_index" so the merge
# below works on both versions (rename ignores the key that is absent).
joined_gdf = joined_gdf.rename(columns={"index_right": "h3_index", "region_id": "h3_index"})

In [20]:
joined_gdf.head()

Unnamed: 0_level_0,latitude,longitude,altitude,date,date_str,trajectory_id,mode,geometry,user_id,speed,timedelta,direction,h3_index
time,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1
2008-12-27 07:26:04,39.969352,116.398589,492.0,39809.309769,2008-12-27,20081227072604,unknown,POINT (116.39859 39.96935),135,13.881385,NaT,160.099874,8931aa55097ffff
2008-12-27 07:26:06,39.969117,116.3987,492.0,39809.309792,2008-12-27,20081227072604,unknown,POINT (116.39870 39.96912),135,13.881385,0 days 00:00:02,160.099874,8931aa5554bffff
2008-12-27 07:26:07,39.969016,116.398799,492.0,39809.309803,2008-12-27,20081227072604,unknown,POINT (116.39880 39.96902),135,14.046284,0 days 00:00:01,143.085532,8931aa5554bffff
2008-12-27 07:26:08,39.968949,116.398868,491.0,39809.309815,2008-12-27,20081227072604,unknown,POINT (116.39887 39.96895),135,9.491683,0 days 00:00:01,141.717016,8931aa5554bffff
2008-12-27 07:26:09,39.968893,116.398905,491.0,39809.309826,2008-12-27,20081227072604,unknown,POINT (116.39890 39.96889),135,6.975264,0 days 00:00:01,153.143848,8931aa5554bffff


In [21]:
# Add each point's hex polygon as a second geometry column. Both sides have a "geometry" column,
# so the result holds geometry_x (the GPS Point) and geometry_y (the hex Polygon) — the two
# columns plot_both reads.
# Note: merging on a column gives the result a fresh 0..n-1 index, so the time index is dropped.
joined_gdf = joined_gdf.merge(
    gpd.GeoDataFrame(
        {"h3_index": regions.index, "geometry": regions.geometry}
    ),
    on="h3_index",
    how="left",
)

In [23]:
joined_gdf.shape, regions.shape, gdf.shape

((49999, 14), (4960, 1), (49999, 12))

In [24]:
def plot_both(row, zoom=10):
    """Draw one geometry together with the outline of its H3 region on an OpenStreetMap basemap.

    Parameters
    ----------
    row : pandas Series or mapping
        Must contain ``'geometry_x'``, a shapely geometry exposing ``.coords``
        (in ``joined_gdf`` / ``merged_gdf`` this is the GPS Point), and
        ``'geometry_y'``, a shapely Polygon (the hex region) whose exterior ring is drawn.
    zoom : int or float, optional
        Initial mapbox zoom level (default 10, as before).

    The map is centred on the mean of the plotted coordinates: without an
    explicit centre, a plain ``go.Figure`` opens at lon=0, lat=0, away from the data.
    """
    traj = row['geometry_x']
    region = row['geometry_y']

    fig = go.Figure()
    all_lon, all_lat = [], []
    # First the geometry itself, then the hex outline — both with the same marker+line style.
    for coords in (traj.coords, region.exterior.coords):
        lon = [point[0] for point in coords]
        lat = [point[1] for point in coords]
        all_lon.extend(lon)
        all_lat.extend(lat)
        fig.add_trace(go.Scattermapbox(
            mode="markers+lines",
            lon=lon,
            lat=lat,
            marker={'size': 10}
        ))

    layout = {
        "mapbox_style": "open-street-map",
        "mapbox_zoom": zoom,
        "margin": {"r": 0, "t": 0, "l": 0, "b": 0},
    }
    if all_lon:
        layout["mapbox_center"] = {
            "lon": sum(all_lon) / len(all_lon),
            "lat": sum(all_lat) / len(all_lat),
        }
    fig.update_layout(**layout)
    fig.show()

In [25]:
# Attach each point's region embedding (columns 0..9). Inner join: points whose h3_index has no
# embedding row would be dropped. Rows follow the left (embeddings) key order — grouped by region
# rather than by time, as the output below shows.
# NOTE(review): merged_gdf is only displayed; none of the later cells shown here use it.
merged_gdf = embeddings.merge(
    joined_gdf, how="inner", left_on="region_id", right_on="h3_index"
)

In [26]:
merged_gdf.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,date_str,trajectory_id,mode,geometry_x,user_id,speed,timedelta,direction,h3_index,geometry_y
0,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.02263,0.360571,0.01185,0.290332,0.237417,...,2009-01-24,20090124065103,unknown,POINT (113.92332 31.04352),135,36.104048,0 days 00:00:01,191.53685,8940a63b463ffff,"POLYGON ((113.92356 31.04256, 113.92316 31.044..."
1,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.02263,0.360571,0.01185,0.290332,0.237417,...,2009-01-24,20090124065103,unknown,POINT (113.92324 31.04321),135,35.620712,0 days 00:00:01,192.164489,8940a63b463ffff,"POLYGON ((113.92356 31.04256, 113.92316 31.044..."
2,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.02263,0.360571,0.01185,0.290332,0.237417,...,2009-01-24,20090124065103,unknown,POINT (113.92316 31.04289),135,36.302351,0 days 00:00:01,191.319993,8940a63b463ffff,"POLYGON ((113.92356 31.04256, 113.92316 31.044..."
3,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.02263,0.360571,0.01185,0.290332,0.237417,...,2009-01-24,20090124065103,unknown,POINT (113.92309 31.04257),135,35.797773,0 days 00:00:01,191.792453,8940a63b463ffff,"POLYGON ((113.92356 31.04256, 113.92316 31.044..."
4,-0.159533,0.311754,-0.282978,-0.219009,0.122352,0.02263,0.360571,0.01185,0.290332,0.237417,...,2009-01-24,20090124065103,unknown,POINT (113.92301 31.04225),135,36.340493,0 days 00:00:01,191.613836,8940a63b463ffff,"POLYGON ((113.92356 31.04256, 113.92316 31.044..."


In [27]:
def get_hex(point, res = HEX_RES):
    """Return the id of the H3 cell (at resolution ``res``) that contains ``point``.

    ``point`` is indexed as ``(lon, lat, ...)`` — shapely's x/y order — while
    h3 takes latitude first, hence the swap in the call below.
    """
    lon, lat = point[0], point[1]
    return h3.latlng_to_cell(lat, lon, res)

In [28]:
def get_emb(point):
    """Return the Hex2Vec embedding row of the H3 region containing ``point``.

    Looks the cell up in the module-level ``embeddings`` frame with ``.loc``,
    which raises ``KeyError`` if that cell has no embedding.
    """
    region_id = get_hex(point)
    return embeddings.loc[region_id]

In [29]:
def get_emb_list(traj):
    """Return one region embedding (as given by ``get_emb``) per coordinate of ``traj``, in order."""
    return [get_emb(coord) for coord in traj.coords]

In [30]:
def make_seq(row, seq_len=200, step=None):
    """Cut one aggregated trajectory into fixed-length windows, each with a next-point target.

    Parameters
    ----------
    row : pandas Series or mapping (a row of ``gdf_agg``)
        Needs ``'geometry'`` (a LineString), the per-point lists ``'speed'``,
        ``'direction'`` and ``'altitude'`` aligned with the geometry's
        coordinates, and ``'trajectory_id'``.
    seq_len : int, optional
        Number of points per window (default 200). Must be >= 2 because each
        window is stored as a LineString.
    step : int, optional
        Offset between consecutive window starts. Defaults to ``seq_len``
        (non-overlapping windows, the original behaviour); a smaller value
        yields overlapping windows and therefore more samples.

    Yields
    ------
    dict
        The window's geometry and feature slices, the trajectory id, the
        per-point region embeddings (``'emb'``, from ``get_emb_list``) and the
        target ``'y'``: the Point immediately after the window.

    Raises
    ------
    ValueError
        If ``seq_len < 2`` or ``step < 1``.
    """
    if seq_len < 2:
        raise ValueError(f"seq_len must be >= 2 to form a LineString, got {seq_len}")
    if step is None:
        step = seq_len
    if step < 1:
        raise ValueError(f"step must be a positive integer, got {step}")

    coords = list(row['geometry'].coords)
    speed = row['speed']
    direction = row['direction']
    altitude = row['altitude']

    # The upper bound keeps start + seq_len <= len(coords) - 1, so the target
    # point coords[start + seq_len] always exists.
    for start in range(0, len(coords) - seq_len, step):
        end = start + seq_len
        window = LineString(coords[start:end])  # built once, reused for 'geometry' and 'emb'
        yield {
            'geometry': window,
            'speed': speed[start:end],
            'direction': direction[start:end],
            'altitude': altitude[start:end],
            'trajectory_id': row['trajectory_id'],
            'emb': get_emb_list(window),
            'y': Point(coords[end]),
        }

# SEQ_GDF

In [116]:
# One record per fixed-length window: flatten the per-trajectory generators from make_seq.
# Distinct loop names avoid the original's shadowing of `row` inside the nested generator.
seq_records = [
    window
    for _, trajectory_row in gdf_agg.iterrows()
    for window in make_seq(trajectory_row)
]
# Declare the geometry column and CRS up front instead of mutating .crs afterwards.
seq_gdf = gpd.GeoDataFrame(seq_records, geometry="geometry", crs=gdf.crs)

In [117]:
seq_gdf.shape

(246, 7)

In [118]:
# test = pd.DataFrame(seq_gdf['emb'][0])

In [119]:
# # group columns to list
# stest = test.transpose().groupby(level=0).apply(lambda x: x.values.tolist())
# stest = pd.DataFrame(stest).transpose()

In [120]:
def extend_emb(row):
    """Unpack a window's per-point embeddings into one list column per embedding dimension.

    ``row['emb']`` is a list with one embedding row (a pandas Series) per point
    of the window, as built by ``get_emb_list``. For every embedding dimension
    ``i`` — taken positionally, in the embeddings' column order — this sets
    ``row[f'emb_{i}']`` to the list of that dimension's values across the
    window's points. Returns the (modified) row, for use with
    ``DataFrame.apply(extend_emb, axis=1)``.
    """
    # Rows = points, columns = embedding dimensions (aligned on the Series labels).
    per_point = pd.DataFrame(list(row['emb']))
    # Transpose positionally so each dimension becomes a list of Python floats. Unlike the old
    # groupby/label lookup (emb[i]), this does not require the columns to be labelled 0..n-1.
    for i, values in enumerate(per_point.to_numpy().T.tolist()):
        row[f'emb_{i}'] = values
    return row

In [121]:
# Unpack each window's list of per-point embeddings into emb_0..emb_N list columns (one per dimension).
seq_gdf = seq_gdf.apply(extend_emb, axis=1)

In [122]:
# Flatten each window's LineString into parallel lon/lat lists (model inputs) and the target Point
# into [lon, lat]. The LineString column is left as-is (a later cell drops it) instead of being
# overwritten with plain coordinate lists; writing non-geometry values into the geometry column
# appears to be what triggered the warning printed under this cell in the original run.
seq_gdf['lon'] = seq_gdf['geometry'].apply(lambda line: [coord[0] for coord in line.coords])
seq_gdf['lat'] = seq_gdf['geometry'].apply(lambda line: [coord[1] for coord in line.coords])
seq_gdf['y'] = seq_gdf['y'].apply(lambda point: [point.x, point.y])


  seq_gdf['geometry'] = seq_gdf['geometry'].apply(lambda x: list(x.coords))


In [123]:
# Drop columns the model does not consume: the id, the raw list of embedding Series (already
# unpacked into emb_0..emb_N) and the window geometry (represented by the lon/lat lists).
col = ['trajectory_id', 'emb', 'geometry']
seq_gdf.drop(col, axis=1, inplace=True)

In [126]:
seq_gdf.head()

Unnamed: 0,speed,direction,altitude,y,emb_0,emb_1,emb_2,emb_3,emb_4,emb_5,emb_6,emb_7,emb_8,emb_9,lon,lat
0,"[13.881384728094615, 13.881384728094615, 14.04...","[160.09987375998526, 160.09987375998526, 143.0...","[492.0, 492.0, 492.0, 491.0, 491.0, 491.0, 491...","[116.401025, 39.966273]","[0.7119172811508179, 0.6457410454750061, 0.645...","[-0.7525131106376648, -0.7173189520835876, -0....","[0.5731177926063538, 0.5326927304267883, 0.532...","[0.5906577110290527, 0.5527756810188293, 0.552...","[0.47214147448539734, 0.42731523513793945, 0.4...","[1.133920431137085, 1.0219310522079468, 1.0219...","[-1.0184216499328613, -0.9298431873321533, -0....","[0.3922121524810791, 0.34579455852508545, 0.34...","[-0.4229356050491333, -0.401667058467865, -0.4...","[0.19938382506370544, 0.16551509499549866, 0.1...","[116.398589, 116.3987, 116.398799, 116.398868,...","[39.969352, 39.969117, 39.969016, 39.968949, 3..."
1,"[9.084598332238107, 9.097043835047296, 9.54958...","[171.92195610859665, 171.3912319062665, 167.64...","[219.0, 216.0, 215.0, 213.0, 210.0, 209.0, 209...","[116.401691, 39.951529]","[0.4736114740371704, 0.4736114740371704, 0.473...","[-0.6809360980987549, -0.6809360980987549, -0....","[0.5464069247245789, 0.5464069247245789, 0.546...","[0.4949587881565094, 0.4949587881565094, 0.494...","[0.07365410029888153, 0.07365410029888153, 0.0...","[0.43190819025039673, 0.43190819025039673, 0.4...","[-0.7212070226669312, -0.7212070226669312, -0....","[0.12159909307956696, 0.12159909307956696, 0.1...","[-0.4230879545211792, -0.4230879545211792, -0....","[-0.14892378449440002, -0.14892378449440002, -...","[116.401025, 116.401041, 116.401065, 116.40108...","[39.966273, 39.966192, 39.966108, 39.96603, 39..."
2,"[1.6048093521818334, 1.8321806157672382, 1.963...","[194.33402454321873, 180.66545941872383, 177.5...","[183.0, 172.0, 172.0, 169.0, 169.0, 167.0, 166...","[116.402274, 39.939697]","[0.904712975025177, 0.904712975025177, 0.90471...","[-0.9398432374000549, -0.9398432374000549, -0....","[0.7874607443809509, 0.7874607443809509, 0.787...","[0.7536092400550842, 0.7536092400550842, 0.753...","[0.39632901549339294, 0.39632901549339294, 0.3...","[1.081040859222412, 1.081040859222412, 1.08104...","[-1.191597819328308, -1.191597819328308, -1.19...","[0.4578467011451721, 0.4578467011451721, 0.457...","[-0.5059295892715454, -0.5059295892715454, -0....","[0.06625568866729736, 0.06625568866729736, 0.0...","[116.401691, 116.40169, 116.401693, 116.401711...","[39.951529, 39.951463, 39.95141, 39.95137, 39...."
3,"[3.337850361530505, 4.233099950289504, 3.28550...","[176.3441484920835, 184.61419885801112, 175.54...","[103.0, 102.0, 100.0, 97.0, 94.0, 96.0, 106.0,...","[116.392419, 39.93942]","[0.44317102432250977, 0.44317102432250977, 0.4...","[-0.4134470820426941, -0.4134470820426941, -0....","[0.24752885103225708, 0.24752885103225708, 0.2...","[0.29616084694862366, 0.29616084694862366, 0.2...","[0.4406273365020752, 0.4406273365020752, 0.440...","[0.9112904667854309, 0.9112904667854309, 0.911...","[-0.6926287412643433, -0.6926287412643433, -0....","[0.34209734201431274, 0.34209734201431274, 0.3...","[-0.2924352288246155, -0.2924352288246155, -0....","[0.252744197845459, 0.252744197845459, 0.25274...","[116.402274, 116.402266, 116.402272, 116.40227...","[39.939697, 39.939621, 39.939562, 39.939503, 3..."
4,"[3.270758161746067, 5.942572253366157, 2.73726...","[247.99470744245548, 290.8689995274607, 272.33...","[156.0, 152.0, 148.0, 144.0, 138.0, 138.0, 138...","[116.389321, 39.936959]","[0.263872891664505, 0.263872891664505, 0.26387...","[-0.39086949825286865, -0.39086949825286865, -...","[0.21079397201538086, 0.21079397201538086, 0.2...","[0.28500741720199585, 0.28500741720199585, 0.2...","[0.2509629726409912, 0.2509629726409912, 0.250...","[0.5820782780647278, 0.5820782780647278, 0.582...","[-0.4978726804256439, -0.4978726804256439, -0....","[0.13569317758083344, 0.13569317758083344, 0.1...","[-0.2643700838088989, -0.2643700838088989, -0....","[0.09379686415195465, 0.09379686415195465, 0.0...","[116.392419, 116.392354, 116.39229, 116.392226...","[39.93942, 39.939439, 39.939441, 39.939442, 39..."


In [89]:
# Features: every remaining column (15 list-valued columns: speed, direction, altitude,
# emb_0..emb_9, lon, lat); target: the next point as [lon, lat].
X = seq_gdf.drop('y', axis=1)
y = seq_gdf['y']

# NOTE(review): rows are windows, not trajectories — windows of the same trajectory can land in
# both train and test (trajectory_id was dropped earlier), which likely makes test scores optimistic.
# A grouped split would need trajectory_id kept until here.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
X_train.shape, X_test.shape, y_train.shape, y_test.shape

((196, 15), (50, 15), (196,), (50,))

In [130]:
# Persist both splits. torch.save pickles the pandas objects, so loading them back with torch.load
# needs weights_only=False on recent PyTorch versions — only load files from a trusted source.
torch.save((X_train, y_train), 'train_data.pt')
torch.save((X_test, y_test), 'test_data.pt')

In [127]:
class SequenceDataset(torch.utils.data.Dataset):
    """Map-style dataset over a frame whose cells are per-window feature lists.

    Item ``idx`` takes row ``idx`` of ``X`` (positionally) and stacks its
    list-valued cells into a float32 tensor of shape
    ``(n_feature_columns, window_length)``; the lists must all have the same
    length or ``torch.tensor`` raises. The matching entry of ``y`` (here a
    ``[lon, lat]`` list) is returned as a float32 tensor.
    """

    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # Positional access: the shuffled index left by train_test_split is irrelevant.
        feature_lists = list(self.X.iloc[idx].values)
        target = self.y.iloc[idx]
        return (
            torch.tensor(feature_lists, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32),
        )

In [128]:
# Batch size shared by both loaders. The train loader shuffles with a seeded generator so the batch
# order is reproducible across runs (consistent with random_state=42 used for the split).
BATCH_SIZE = 32
LOADER_SEED = 42

train_dataset = SequenceDataset(X_train, y_train)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(LOADER_SEED),
)

test_dataset = SequenceDataset(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)