# Associating and Joining Trajectories

This notebook consumes the output of [./vectorized_filter.ipynb](./vectorized_filter.ipynb), as post-processed by [./lane_classification.ipynb](./lane_classification.ipynb).

In [None]:
# IPython: automatically reload imported modules (mode 2 = all modules) so
# edits under src/ are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Find the project root: walk upward from the current working directory until a
# directory containing ``.git`` is found. (The autoreload extension is already
# enabled in the previous cell.)
import os
from pathlib import Path

_cwd = Path(os.getcwd())
for ROOT in (_cwd, *_cwd.parents):
    if ROOT.joinpath(".git").exists():
        break
else:
    # Without this guard the search would spin forever at the filesystem root,
    # because Path("/").parent == Path("/").
    raise FileNotFoundError(
        f"No .git directory found in {_cwd} or any of its parent directories"
    )

# add the root to the python path so ``src.*`` imports resolve
import sys

sys.path.append(str(ROOT))

In [None]:
import dotenv
import polars as pl
from pomegranate.distributions import Normal
from pomegranate.gmm import GeneralMixtureModel


# Load environment variables from the .env file at the project root.
dotenv.load_dotenv(ROOT / ".env")

## Read in the DataFrame

In [None]:
# Lazily scan the lane-classified, IMM-filtered radar tracks.
_filtered_path = ROOT / "notebooks" / "clean_workflow" / "data" / "imm_filtered_lanes.parquet"
radar_df = pl.scan_parquet(_filtered_path)

In [None]:
# Lazily scan the raw processed radar data (source of columns such as the IP).
_raw_path = ROOT / "tmp" / "all_working_processed_1Lane.parquet"
raw_df = pl.scan_parquet(_raw_path)

### Add the IP Address (and Other Raw Columns Missing from the Filtered Data)

In [None]:
import polars as pl

# Bring in every column of the raw data (e.g. the radar IP address) that the
# filtered frame does not already carry, matched on (object_id, epoch_time).
# Columns are taken in raw_df's own order rather than from a set difference:
# Python salts str hashes per process, so iterating a set of column names can
# yield a different column order after every kernel restart.
_join_keys = ["object_id", "epoch_time"]
_existing_cols = set(radar_df.columns)
_extra_cols = [col for col in raw_df.columns if col not in _existing_cols]

radar_df = radar_df.join(
    raw_df.select(_extra_cols + _join_keys),
    on=_join_keys,
    how="inner",
).collect()

## Identifying Leader-Follower Pairs

### Identify Whether a Vehicle is Heading Towards or Away from the Radar

In [None]:
# Radial distance from the radar to the tracked point.
_distance = (
    (pl.col("f32_positionX_m") ** 2 + pl.col("f32_positionY_m") ** 2)
    .sqrt()
    .alias("distance")
)

# True where the distance to the radar is not increasing. The first sample of
# each object has a null diff and takes the following sample's value.
# NOTE(review): diff() is order-dependent; this assumes rows are ordered by
# time within each object_id. Confirm upstream.
_towards_radar = (
    (pl.col("distance").diff() <= 0)
    .backward_fill()
    .over("object_id")
    .alias("towards_radar")
)

# Project the raw front/back distances and vehicle length onto the s axis
# using cos(s_angle_diff).
_cos_angle = pl.col("s_angle_diff").cos()
_projected_lengths = [
    (pl.col(raw_name) * _cos_angle).alias(s_name)
    for raw_name, s_name in (
        ("f32_distanceToFront_m", "distanceToFront_s"),
        ("f32_distanceToBack_m", "distanceToBack_s"),
        ("f32_length_m", "length_s"),
    )
]

# One length per object: the median of its projected lengths.
_median_length = pl.col("length_s").median().over("object_id").alias("median_length_s")

# Bumper positions at +/- half the median length around the centroid.
# NOTE(review): "back" is at larger s and "front" at smaller s, which assumes
# travel toward decreasing s. Confirm for both directions.
_half_length = pl.col("median_length_s") / 2
_bumpers = [
    (pl.col("s_centroid") + _half_length).alias("backBumper_s"),
    (pl.col("s_centroid") - _half_length).alias("frontBumper_s"),
]

radar_df = (
    radar_df.with_columns(_distance)
    .with_columns(_towards_radar)
    .with_columns(_projected_lengths)
    .with_columns(_median_length)
    # The radar presumably sees the vehicle face closest to it. No correction
    # toward the true centroid is applied yet: s_filt is used as-is (TODO).
    .with_columns(pl.col("s_filt").alias("s_centroid"))
    .with_columns(_bumpers)
)

### Add a Unique Column for Lane - Lane Index

In [None]:
# Single hashable key identifying a (lane, lane_index) combination.
_lane_key_hash = pl.struct(["lane", "lane_index"]).hash().alias("lane_hash")
radar_df = radar_df.with_columns(_lane_key_hash)

### Create Leader Follower Pairs

In [None]:
from src.association.pipelines import build_match_df

# Build the candidate leader/follower pair table from the radar tracks.
matching_df = radar_df.pipe(build_match_df)

### Calculate the Mahalanobis Distance

In [None]:
from src.filters.fusion import mahalanobis_distance, loglikelihood
from scipy.stats import chi2


# Score every candidate leader/follower pair:
#   1. Mahalanobis distance between the paired tracks, gated at the 90%
#      quantile of a chi-squared distribution with 4 degrees of freedom
#      (presumably the dimension of the compared state; confirm in
#      src.filters.fusion.mahalanobis_distance).
#   2. Association log-likelihood from ``loglikelihood``.
# Both steps are called with gpu=True.
matching_df = (
    matching_df.pipe(
        mahalanobis_distance,
        cutoff=chi2.ppf(0.90, 4),
        gpu=True,
    )
    # In principle: p(a = b) = 1 - p(a != b) = 1 - (p(birth) + p(error) + p(...)),
    # where the last term is not yet specified.
    # With a validation gate, the likelihood must be normalized by the gate's
    # area (volume). A more detailed birth model could be used here; for now we
    # rely on the gate alone.
    .pipe(loglikelihood, gpu=True)
)

### Calculate the Headways and Find the Middle of Leader-Follower Pairs

In [None]:
from src.association.pipelines import calculate_match_indexes, pipe_gate_headway_calc

# Only keep pairs where both follower and leader move faster than 5 in
# |s_velocity_filt| (presumably m/s; confirm units).
_min_abs_speed = 5
_both_moving = (pl.col("s_velocity_filt").abs() > _min_abs_speed) & (
    pl.col("s_velocity_filt_leader").abs() > _min_abs_speed
)
matching_df = calculate_match_indexes(matching_df.filter(_both_moving))

In [None]:
# Select the leader/follower matches that are accepted for joining trajectories.
valid_matches = (
    # presumably adds inside_gate / headway / headway_std; see pipe_gate_headway_calc
    matching_df.pipe(pipe_gate_headway_calc)
    # accept a pair when inside_gate exceeds 0.5, OR when the headway is very
    # short (< 0.5) and stable (std < 0.1)
    .filter(
        (pl.col("inside_gate") > 0.5)
        | ((pl.col("headway") < 0.5) & (pl.col("headway_std") < 0.1))
    )
    # chronological order so the running counts below follow time
    .sort("epoch_time")
    # expand the "pair" struct into top-level columns (leader, prediction, ...)
    .unnest("pair")
    # 1-based running count of matches per follower (object_id) and per leader.
    # NOTE(review): relies on cumcount() counting from 0, as in the pre-0.20
    # polars API this notebook uses.
    .with_columns(
        (pl.col("object_id").cumcount() + 1).over("object_id").alias("following_count"),
        (pl.col("object_id").cumcount() + 1).over("leader").alias("leader_count"),
    )
    .sort("object_id")
    # keep all non-prediction matches; keep prediction matches only if they are
    # among the first two matches for both the follower and the leader
    .filter(
        (
            ~pl.col("prediction")
            | ((pl.col("following_count") < 3) & (pl.col("leader_count") < 3))
        )
    )
)

### Build a Graph of Connected Vehicles

In [None]:
from src.association.pipelines import create_vehicle_ids

# Link radar objects connected by accepted matches into shared vehicle_ids.
joined_df = create_vehicle_ids(radar_df, match_df=valid_matches)

In [None]:
# Sanity check: every radar object_id must belong to exactly one vehicle_id.
_vehicles_per_object = joined_df.group_by("object_id").agg(
    pl.col("vehicle_id").n_unique().alias("vehicle_count")
)
assert (_vehicles_per_object["vehicle_count"] == 1).all()

### Mark Vehicle Group Ends

In [None]:
from src.plotting.time_space import plot_time_space
from datetime import timedelta
from src.radar import Filtering


# Plot a 5-minute window (minutes 40-45 after the first timestamp) of lane
# EBL1, lane index 1, with tracks grouped by the associated vehicle_id.
_t0 = joined_df["epoch_time"].min()
_window_start = _t0 + timedelta(minutes=40)
_window_end = _t0 + timedelta(minutes=45)

plot_df = joined_df.filter(
    pl.col("epoch_time").is_between(_window_start, _window_end)
    & (pl.col("lane").str.contains("EBL1"))
    & (pl.col("lane_index") == 1)
    # optional debug filters:
    # (pl.col("vehicle_id") == pl.lit(15420721423209556182))
    # pl.col('object_id').is_in([254, 147,])
).pipe(Filtering.add_cst_timezone)

fig = plot_time_space(
    plot_df,
    hoverdata="object_id",
    vehicle_col="vehicle_id",
    s_col="s_centroid",
    markers=True,
)

fig.show()

## Calculating the Positional Error as a Function of Radar Pair (by IP Address)

In [None]:
# Positional disagreement between radars that see the same vehicle at the same
# instant. Within each (epoch_time, vehicle_id) group, rows are ordered by IP
# string, so the reference is the lexicographically smallest IP. This exploits
# the assumption that, for the EB lanes, the radar IPs (ending 136 -> 147) are
# in road order. s_error = s_centroid(reference radar) - s_centroid(this radar).
e_df = (
    # only instants where the vehicle is observed by more than one track
    joined_df.filter(pl.count().over(["epoch_time", "vehicle_id"]) > 1)
    .sort(["epoch_time", "ip"])
    .group_by(["epoch_time", "vehicle_id"])
    .agg(
        (pl.col("s_centroid").first() - pl.col("s_centroid")).alias("s_error"),
        pl.col("ip").first().alias("first_ip"),
        pl.col("ip"),
        pl.col("object_id"),
        pl.col("object_id").first().alias("first_object_id"),
        # reference position in 5-unit bins
        (pl.col("s_centroid").first() // 5).alias("s_pos_binned"),
        pl.col("lane").first().alias("lane"),
    )
    .explode(["ip", "s_error", "object_id"])
    # Drop the reference radar's own rows (their s_error is identically 0)
    # BEFORE trimming outliers. Otherwise those structural zeros take part in
    # the quantile computation and bias the trim bounds toward 0.
    .filter(pl.col("ip") != pl.col("first_ip"))
    # trim the 5% tails of the cross-radar error distribution
    .filter(
        pl.col("s_error").is_between(
            pl.col("s_error").quantile(0.05), pl.col("s_error").quantile(0.95)
        )
    )
)

#### Visualizing the Error

In [None]:
# # plot the error vs. s
# import plotly.graph_objects as go
# from plotly.subplots import make_subplots

# # make subplots
# fig = make_subplots(
#     rows=4,
#     cols=1,
#     shared_xaxes=False,
#     vertical_spacing=0.05,
#     specs=[[{"type": "scatter"}], [{"type": "scatter"}], [{"type": "scatter"}], [{"type": "scatter"}]],
# )

# for i, (r1, r2) in enumerate([("137", "141"), ("136", "137"), ("141", "142"), ("142", "146")]):
#     plot_df = (
#         e_df.filter(pl.col("first_ip").str.contains(r1) & pl.col("ip").str.contains(r2))
#         .group_by("s_pos_binned")
#         .agg(
#             pl.col("s_error").mean().alias("s_error_avg"),
#             pl.col("s_error").quantile(0.95).alias("s_error_q95"),
#             pl.col("s_error").quantile(0.05).alias("s_error_q05"),
#             pl.col("s_error").quantile(0.5).alias("s_error_q50"),
#             pl.col("s_error").count().alias("count"),
#         )
#         .sort("s_pos_binned", descending=False)
#         .with_columns(pl.col("s_pos_binned") * 5.0)
#     )

#     fig.add_trace(
#         go.Scatter(
#             x=plot_df["s_pos_binned"],
#             y=plot_df["s_error_q50"],
#             name=f"{r1} -> {r2}",
#         ),
#         row=1 + i,
#         col=1,
#     )

#     fig.add_trace(
#         go.Scatter(
#             x=plot_df["s_pos_binned"],
#             y=plot_df["s_error_q95"],
#             # make it a filled area
#             line=dict(
#                 color="rgba(255,255,255,0)",
#             ),
#             showlegend=False,
#         ),
#         row=1 + i,
#         col=1,
#     )

#     fig.add_trace(
#         go.Scatter(
#             x=plot_df["s_pos_binned"],
#             y=plot_df["s_error_q05"],
#             fill="tonexty",
#             showlegend=False,
#         ),
#         row=1 + i,
#         col=1,
#     )

# # make the x-axis descending
# fig.update_xaxes(autorange="reversed")

# fig.show()

### Creating a Correction DataFrame

In [None]:
# Adjacent radar pairs, keyed by the last three characters of the IP:
# upstream radar -> next radar in the chain.
keep_pairs = {
    136: 137,
    137: 141,
    141: 142,
    142: 146,
    146: 147,
}


def _ip_suffix(ips: pl.Series) -> pl.Series:
    """Return the last three characters of each IP string, cast to int."""
    return ips.str.slice(-3).cast(int)


# Error statistics per (reference ip, other ip, lane).
_pair_stats = e_df.group_by(["first_ip", "ip", "lane"]).agg(
    pl.col("s_error").mean().alias("mean_s_error"),
    pl.col("s_error").median().alias("median_s_error"),
    pl.col("s_error").std().alias("std_s_error"),
    pl.col("s_error").count().alias("count"),
)

correction_df = (
    _pair_stats.with_columns(
        pl.col(["first_ip", "ip"])
        .map_batches(_ip_suffix)
        .map_alias(lambda name: f"{name}_int")
    )
    # keep only the adjacent links listed in keep_pairs
    .filter(pl.col("ip_int") == pl.col("first_ip_int").map_dict(keep_pairs))
    .sort(["first_ip", "ip"])
    # Chaining the median offsets along the radar sequence gives each radar's
    # offset relative to the first radar of the chain, per lane.
    # NOTE: assumes every link exists in every lane. A missing link silently
    # shifts all downstream corrections.
    .with_columns(pl.col("median_s_error").cumsum().over("lane").alias("correction"))
)


correction_df.sort("lane").head(10)
# Earlier pairing candidate, kept for reference:
# keep_pairs = [[136, 147], [147, 136], [136, 254], [254, 136], [147, 254], [254, 147]]

In [None]:
# Attach the per-(ip, lane) offset. Rows with no offset (e.g. the reference
# radar) are left uncorrected, i.e. offset 0.
_offsets = correction_df.select(["ip", "lane", "correction"])
joined_df = joined_df.join(_offsets, on=["ip", "lane"], how="left")
joined_df = joined_df.with_columns(
    (pl.col("s_centroid") + pl.col("correction").fill_null(0)).alias("corrected_s")
)

In [None]:
from src.plotting.time_space import plot_time_space
from datetime import timedelta
from src.radar import Filtering


# Plot a 5-minute window (minutes 40-45 after the first timestamp) of lane
# WBL1, lane index 1, using the offset-corrected position.
_t0 = joined_df["epoch_time"].min()
_window_start = _t0 + timedelta(minutes=40)
_window_end = _t0 + timedelta(minutes=45)

plot_df = joined_df.filter(
    pl.col("epoch_time").is_between(_window_start, _window_end)
    & (pl.col("lane").str.contains("WBL1"))
    & (pl.col("lane_index") == 1)
    # optional debug filters:
    # (pl.col("vehicle_id") == pl.lit(15420721423209556182))
    # pl.col('object_id').is_in([254, 147,])
).pipe(Filtering.add_cst_timezone)

fig = plot_time_space(
    plot_df,
    hoverdata="object_id",
    vehicle_col="vehicle_id",
    s_col="corrected_s",
    markers=True,
)

fig.show()

In [None]:
# Persist the per-radar, per-lane offsets for downstream notebooks.
_offsets_path = ROOT.joinpath("notebooks/clean_workflow/data/offsets.parquet")
correction_df.write_parquet(_offsets_path)