# Imports

In [None]:
import os
from datetime import datetime

import geopandas as gpd
import movingpandas as mpd
import pandas as pd
import torch
from shapely.geometry import Point
from srai.datasets import PortoTaxiDataset
from srai.embedders import Hex2VecEmbedder
from srai.joiners import IntersectionJoiner
from srai.loaders import OSMPbfLoader
from srai.loaders.osm_loaders.filters import HEX2VEC_FILTER
from srai.neighbourhoods import H3Neighbourhood
from srai.regionalizers import H3Regionalizer
from srai.regionalizers import geocode_to_region_gdf
from tqdm import tqdm

In [None]:
# Select pyogrio as the I/O engine GeoPandas uses for reading/writing files.
gpd.options.io_engine = "pyogrio"

# Load Environment Variables

In [None]:
# Register the `dotenv` IPython extension so the %dotenv magic is available.
%load_ext dotenv

# NOTE(review): presumably loads variables from a local .env file into the
# process environment (HF_TOKEN is read from the environment afterwards).
%dotenv

In [None]:
# Token read from the HF_TOKEN environment variable (None when it is unset);
# it is passed to PortoTaxiDataset.load() when the dataset is downloaded.
hf_token = os.getenv("HF_TOKEN")

# Data Loading

In [None]:
# Switch between the full dataset and a small sample (the first `subset_size`
# rows) for quicker experimentation.
use_subset = True
subset_size = 1_000

# Feather files used as on-disk caches; the subset file name encodes its size.
gdf_porto_taxi_subset_path = os.path.join(
    "data", f"porto_taxi_subset_{subset_size}.feather"
)
gdf_porto_taxi_full_path = os.path.join("data", "porto_taxi.feather")

In [None]:
# Load the Porto taxi trips into `gdf_porto_taxi`, using Feather files as caches:
#   1. if a subset is requested and its cache exists, read it directly;
#   2. otherwise obtain the full dataset (from its cache, or by downloading it
#      and writing the cache);
#   3. if a subset is requested, keep the first `subset_size` rows and cache them.

# Writing a cache file fails when the target directory is missing, so create it.
os.makedirs(os.path.dirname(gdf_porto_taxi_full_path), exist_ok=True)

if use_subset and os.path.exists(gdf_porto_taxi_subset_path):
    gdf_porto_taxi = gpd.read_feather(gdf_porto_taxi_subset_path)
else:
    if os.path.exists(gdf_porto_taxi_full_path):
        gdf_porto_taxi = gpd.read_feather(gdf_porto_taxi_full_path)
    else:
        porto_taxi_dataset = PortoTaxiDataset()
        gdf_porto_taxi = porto_taxi_dataset.load(hf_token=hf_token)
        gdf_porto_taxi.to_feather(gdf_porto_taxi_full_path)

    if use_subset:
        gdf_porto_taxi = gdf_porto_taxi.head(subset_size)
        gdf_porto_taxi.to_feather(gdf_porto_taxi_subset_path)

In [None]:
# Remove, in place, the trip attributes that the rest of the notebook does not
# use (the later cells rely on trip_id, timestamp and geometry).
gdf_porto_taxi.drop(
    columns=[
        "taxi_id",
        "call_type",
        "origin_call",
        "origin_stand",
        "day_type",
        "travel_time_seconds",
    ],
    inplace=True,
)

# Convert LineString to Point

In [None]:
# Turn every trip (one LineString per row) into one record per coordinate.
# Each record copies the trip's attributes, replaces the geometry with a Point
# and gets a timestamp offset from the trip's start timestamp.

# Step added to the timestamp for each consecutive coordinate.
# NOTE(review): presumably the dataset's GPS sampling interval in seconds - confirm.
SAMPLING_INTERVAL_SECONDS = 15

exploded_rows = []

for _, row in tqdm(gdf_porto_taxi.iterrows(), total=len(gdf_porto_taxi)):
    # The trip attributes are identical for all of its points, so convert the
    # row to a dict once per trip instead of once per coordinate.
    trip_attributes = row.to_dict()
    start_timestamp = row.timestamp
    exploded_rows.extend(
        {
            **trip_attributes,
            "geometry": Point(xy),
            "timestamp": start_timestamp + step * SAMPLING_INTERVAL_SECONDS,
        }
        for step, xy in enumerate(row.geometry.coords)
    )

In [None]:
# Assemble the per-point records into a GeoDataFrame; coordinates are declared
# to be in EPSG:4326 (WGS 84 longitude/latitude).
gdf_porto_taxi_points = gpd.GeoDataFrame(exploded_rows, crs="EPSG:4326")

In [None]:
# Convert the epoch-second values to datetimes in a single vectorized call.
# The result is timezone-naive UTC, so it does not depend on the timezone of
# the machine running the notebook and cannot repeat or jump across a DST
# change, which would break the time ordering inside a trajectory.
gdf_porto_taxi_points["timestamp"] = pd.to_datetime(
    gdf_porto_taxi_points["timestamp"], unit="s"
)

In [None]:
# Preview the first rows of the per-point table.
gdf_porto_taxi_points.head()

### Restricting to Porto Area

In [None]:
# Geocode the place name into a region GeoDataFrame; it is used below as the
# spatial filter for the taxi points.
porto_area = geocode_to_region_gdf("Porto District, Portugal")

In [None]:
# Show the geocoded area on an interactive map as a visual sanity check.
porto_area.explore()

In [None]:
# Spatial join against the Porto area. With GeoPandas' defaults (inner join,
# "intersects" predicate) only points that fall inside the area are kept.
gdf_porto_taxi_points_inside_porto = gdf_porto_taxi_points.sjoin(porto_area)

In [None]:
# Plot the points that matched the Porto area on an interactive map.
gdf_porto_taxi_points_inside_porto.geometry.explore()

In [None]:
# Anti-join: left-merge all points with the points found inside Porto. No `on`
# is given, so pandas joins on every column the two frames share, and
# indicator=True adds a `_merge` column telling where each row came from.
gdf_merged = gdf_porto_taxi_points.merge(
    gdf_porto_taxi_points_inside_porto, how="left", indicator=True
)

# Rows present only on the left side had no match inside the area.
df_porto_taxi_points_outside_porto = gdf_merged.loc[
    gdf_merged["_merge"].eq("left_only")
]

In [None]:
# Distinct ids of the trips that have at least one point outside the area.
trajectories_outside_porto = [
    *df_porto_taxi_points_outside_porto["trip_id"].unique()
]

In [None]:
# Discard every point of a trip that leaves the area, so that only trips lying
# entirely inside it remain.
gdf_porto_taxi_points = gdf_porto_taxi_points.loc[
    ~gdf_porto_taxi_points["trip_id"].isin(trajectories_outside_porto)
]

In [None]:
# Plot the remaining points to check visually that the filter worked.
gdf_porto_taxi_points.geometry.explore()

# Trajectory Collection

In [None]:
# Group the points into MovingPandas trajectories: `trip_id` identifies the
# trajectory a point belongs to and `timestamp` is the time column.
trajectory_collection = mpd.TrajectoryCollection(
    data=gdf_porto_taxi_points, traj_id_col="trip_id", t="timestamp"
)

In [None]:
# Plot all trajectories, styled by trip id; the legend is turned off.
trajectory_collection.plot(column="trip_id", legend=False, figsize=(16, 9))

In [None]:
# Pick the trajectory at index 1 as a running example and add a speed column
# expressed in km/h (overwrite=True is passed so an existing one is replaced).
single_trajectory = trajectory_collection.trajectories[1].add_speed(
    units=("km", "h"), overwrite=True
)

In [None]:
# Display the example trajectory in its to_traj_gdf() tabular form.
single_trajectory.to_traj_gdf()

In [None]:
# Interactive plot of the example trajectory, coloured by speed.
single_trajectory.hvplot(c="speed")

In [None]:
# Work on a copy of the example trajectory.
# NOTE(review): presumably so `single_trajectory` itself stays untouched.
single_trajectory_copy = single_trajectory.copy()

# Simplify the trajectory with Douglas-Peucker, recompute the speed in km/h on
# the simplified result and plot it coloured by speed.
# NOTE(review): the tolerance is presumably in CRS units (degrees) - confirm.
mpd.DouglasPeuckerGeneralizer(single_trajectory_copy).generalize(
    tolerance=0.0001
).add_speed(units=("km", "h"), overwrite=True).hvplot(c="speed")

In [None]:
# Work on a copy of the example trajectory.
# NOTE(review): presumably so `single_trajectory` itself stays untouched.
single_trajectory_copy = single_trajectory.copy()

# Clean the trajectory with a maximum speed of 120 km/h, recompute the speed
# on the cleaned result and plot it coloured by speed.
mpd.OutlierCleaner(single_trajectory_copy).clean(
    v_max=120, units=("km", "h")
).add_speed(units=("km", "h"), overwrite=True).hvplot(c="speed")

In [None]:
# Combined pipeline on a copy of the example trajectory:
# outlier cleaning first, then Douglas-Peucker simplification.
single_trajectory_copy = single_trajectory.copy()

# Step 1: clean the trajectory with a maximum speed of 120 km/h.
cleaned_trajectory = mpd.OutlierCleaner(single_trajectory_copy).clean(
    v_max=120, units=("km", "h")
)

# Step 2: simplify the cleaned trajectory.
generalized_trajectory = mpd.DouglasPeuckerGeneralizer(
    cleaned_trajectory
).generalize(tolerance=0.0001)

# Step 3: recompute the speed in km/h and plot the result coloured by speed.
generalized_trajectory.add_speed(units=("km", "h"), overwrite=True).hvplot(
    c="speed"
)

# Regionalizer

In [None]:
# Derive H3 regions (resolution 9) from the taxi points.
# NOTE(review): presumably one region per H3 cell touched by a point - confirm.
regionalizer = H3Regionalizer(resolution=9)
regions = regionalizer.transform(gdf_porto_taxi_points)

In [None]:
# Static plot of the generated regions.
regions.plot()

# Features from regions

In [None]:
# Load OpenStreetMap features for the regions, restricted to the tags listed
# in srai's HEX2VEC_FILTER.
loader = OSMPbfLoader()
features = loader.load(regions, HEX2VEC_FILTER)

In [None]:
# Preview the first loaded features.
features.head()

# Join Regions with Features

In [None]:
# Relate regions to features with srai's IntersectionJoiner; the result is
# passed to the embedder together with `regions` and `features`.
joiner = IntersectionJoiner()
joint = joiner.transform(regions, features)

In [None]:
# H3 neighbourhood built over the regions; passed to the embedder's
# fit_transform() below.
neighbourhood = H3Neighbourhood(regions)

In [None]:
# Layer sizes handed to the Hex2Vec embedder.
# NOTE(review): presumably the last entry (10) is the embedding size - confirm.
embedder_hidden_sizes = [150, 100, 50, 10]
embedder = Hex2VecEmbedder(embedder_hidden_sizes)

# Fit Embeddings

In [None]:
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Fit the embedder and compute the embeddings in one call; `trainer_kwargs`
# carries the number of epochs and the accelerator for the trainer.
embeddings = embedder.fit_transform(
    regions,
    features,
    joint,
    neighbourhood,
    batch_size=100,
    trainer_kwargs={"max_epochs": 15, "accelerator": device},
)

In [None]:
# Preview the first rows of the resulting embeddings.
embeddings.head()