# Associating and Joining Trajectories

This notebook relies on the output of [./vectorized_filter.ipynb](./vectorized_filter.ipynb) followed by [./lane_classification.ipynb](./lane_classification.ipynb); run those two notebooks, in that order, first.


In [None]:
# Enable IPython's autoreload extension so edits to imported project modules
# (e.g. the ``src.*`` helpers used below) are picked up without restarting the
# kernel; mode 2 reloads modules before executing each cell (see IPython docs).
%load_ext autoreload
%autoreload 2

In [None]:
# Locate the project root and make it importable. (The autoreload extension is
# already loaded by the cell above, so it is not loaded a second time here.)
import os
import sys
from pathlib import Path


def _find_project_root(start: Path) -> Path:
    """Return the closest directory at or above *start* that contains ``.git``.

    The search includes *start* itself and stops at the filesystem root, so it
    always terminates.

    Raises:
        FileNotFoundError: if neither *start* nor any of its parents has ``.git``.
    """
    for candidate in (start, *start.parents):
        if candidate.joinpath(".git").exists():
            return candidate
    raise FileNotFoundError(
        f"No .git entry found in {start} or any of its parent directories"
    )


# find the root of the project
ROOT = _find_project_root(Path(os.getcwd()))

# add the root to the python path (only once, so re-running this cell is harmless)
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))

In [None]:
import dotenv
import polars as pl
from pomegranate.distributions import Normal
from pomegranate.gmm import GeneralMixtureModel


# Read the project's .env file (at the repository root) into the environment.
dotenv.load_dotenv(dotenv_path=ROOT / ".env")

## Read in the DataFrame


In [None]:
# Lazily scan the filtered, lane-classified trajectories produced upstream.
radar_df = pl.scan_parquet(
    ROOT / "notebooks" / "clean_workflow" / "data" / "imm_filtered_lanes.parquet"
)

In [None]:
# Lazily scan the raw processed radar rows (source of columns such as ``ip``).
raw_df = pl.scan_parquet(
    ROOT
    / "notebooks"
    / "clean_workflow"
    / "data"
    / "all_working_processed_1Lane.parquet"
)

In [None]:
# Lazily scan the per-(ip, lane) positional offsets applied further below.
correction_df = pl.scan_parquet(
    ROOT / "notebooks" / "clean_workflow" / "data" / "offsets.parquet"
)

### Add in the IP Address


In [None]:
import polars as pl

# Attach every raw-data column the filtered frame does not already carry (e.g.
# the radar ``ip`` used for the positional correction below). Rows are matched
# on a hash of (epoch_time, object_id). ``sorted`` fixes the column order, which
# iterating a bare ``set`` would leave arbitrary.
_raw_only_cols = sorted(set(raw_df.columns) - set(radar_df.columns))
_epoch_object_hash = (
    pl.struct(["epoch_time", "object_id"]).hash(42).alias("epoch_object_hash")
)

radar_df = (
    radar_df.lazy()
    .sort(["epoch_time", "object_id"])
    .with_columns(_epoch_object_hash)
    .join(
        # Keep only the hash key plus the new columns: carrying object_id and
        # epoch_time through a join on the hash would add duplicate
        # ``object_id_right`` / ``epoch_time_right`` columns to the result.
        raw_df.with_columns(_epoch_object_hash).select(
            _raw_only_cols + ["epoch_object_hash"]
        ),
        on="epoch_object_hash",
        how="left",
    )
    # rows without a raw match inherit the object's most recent raw values
    .with_columns(pl.col(_raw_only_cols).forward_fill().over("object_id"))
    .collect()
)

In [None]:
# Number of distinct radar object tracks.
radar_df.get_column("object_id").n_unique()

## Correcting Positional Information


In [None]:
# Shared sub-expressions for the length projection and the bumper positions.
_cos_s_angle = pl.col("s_angle_diff").cos()
_half_median_length = pl.col("median_length_s") / 2

radar_df = (
    radar_df.lazy()
    # shift s_filt by the per-(ip, lane) offset; rows without an offset shift by 0
    .join(
        correction_df.select(["ip", "lane", "correction"]),
        on=["ip", "lane"],
        how="left",
    )
    .with_columns(
        (pl.col("s_filt") + pl.col("correction").fill_null(0)).alias("s_filt")
    )
    # straight-line range from the radar, computed from the x/y position
    .with_columns(
        (pl.col("f32_positionX_m").pow(2) + pl.col("f32_positionY_m").pow(2))
        .sqrt()
        .alias("distance")
    )
    # True where an object's range does not grow from one row to the next
    # (moving towards the radar); each object's first row, which has no
    # predecessor, takes the value of the following row
    .with_columns(
        pl.col("distance")
        .diff()
        .le(0)
        .backward_fill()
        .over("object_id")
        .alias("towards_radar")
    )
    # project the measured front/back distances and length onto s
    # by scaling with cos(s_angle_diff)
    .with_columns(
        (pl.col("f32_distanceToFront_m") * _cos_s_angle).alias("distanceToFront_s"),
        (pl.col("f32_distanceToBack_m") * _cos_s_angle).alias("distanceToBack_s"),
        (pl.col("f32_length_m") * _cos_s_angle).alias("length_s"),
    )
    # a single length per object: the median of its projected lengths
    .with_columns(
        pl.col("length_s").median().over("object_id").alias("median_length_s")
    )
    # NOTE: the radar is assumed to see the vehicle face nearest to it, but no
    # correction for that is applied yet -- the centroid is simply the
    # (offset-corrected) s_filt
    .with_columns(pl.col("s_filt").alias("s_centroid"))
    # bumpers sit half a median length on either side of the centroid
    .with_columns(
        (pl.col("s_centroid") + _half_median_length).alias("backBumper_s"),
        (pl.col("s_centroid") - _half_median_length).alias("frontBumper_s"),
    )
    .collect()
)

### Add a Unique Key Column for Each (Lane, Lane Index) Pair


In [None]:
# One hashed key per (lane, lane_index) combination.
_lane_key = pl.struct(["lane", "lane_index"]).hash()
radar_df = radar_df.with_columns(_lane_key.alias("lane_hash"))

### Create Leader Follower Pairs


In [None]:
from src.association.pipelines import build_match_df

# Columns handed to the leader/follower pairing.
_match_cols = [
    "object_id",
    "epoch_time",
    "s_centroid",
    "s_velocity_filt",
    "lane",
    "prediction",
    "lane_hash",
    "P",
    "d_filt",
    "d_velocity_filt",
    "frontBumper_s",
    "backBumper_s",
    "length_s",
]

matching_df = build_match_df(radar_df.select(_match_cols))

# Discard pairs in which both the follower row and its leader row are flagged
# ``prediction``; at least one side of a kept pair is unflagged.
matching_df = matching_df.filter(
    ~pl.col("prediction") | ~pl.col("prediction_leader")
)

In [None]:
# (rows, columns) of the candidate pair table
(matching_df.height, matching_df.width)

### Calculate the Mahalanobis Distance


In [None]:
from src.filters.fusion import (
    association_loglikelihood_distance,
    loglikelihood,
    mahalanobis_distance,
)
from scipy.stats import chi2
import torch

# Validation-gate threshold: the 0.99 quantile of a chi-squared distribution
# with 4 degrees of freedom.
# NOTE(review): later cells threshold ``association_distance_filt`` with scipy's
# ``chi`` (not ``chi2``) quantiles -- confirm which form of the distance each
# helper expects.
_gate_cutoff = chi2.ppf(0.99, 4)

matching_df = matching_df.pipe(
    mahalanobis_distance,
    cutoff=_gate_cutoff,
    gpu=True,
    batch_size=100_000,
)
matching_df = matching_df.pipe(association_loglikelihood_distance, gpu=True)

# TODO: a full association probability would be
#     p(a = b) = 1 - p(a != b) = 1 - (p(birth) + p(error) + ...)
# and, with a validation gate, would need normalising by the gate's area.
# A more detailed birth model could be used here; for now only the gate is
# relied on, so the log-likelihood step stays disabled:
# matching_df = matching_df.pipe(loglikelihood, gpu=True)

# hand cached GPU memory back after the batched computations
torch.cuda.empty_cache()

### Calculate the Headways and Find the Middle of Leader-Follower Pairs


In [None]:
from src.association.pipelines import calculate_match_indexes, pipe_gate_headway_calc

# Keep only pairs in which both follower and leader have |s_velocity_filt| > 2,
# then compute the match indexes.
_both_moving = (pl.col("s_velocity_filt").abs() > 2) & (
    pl.col("s_velocity_filt_leader").abs() > 2
)
matching_df = calculate_match_indexes(matching_df.filter(_both_moving))

In [None]:
# Last timestamp at which each radar object appears. This is computed once and
# joined twice: for the follower (``object_id``) and for its ``leader``.
_object_end_times = radar_df.group_by("object_id").agg(
    pl.col("epoch_time").max().alias("epoch_time_max")
)

valid_matches = (
    matching_df.pipe(pipe_gate_headway_calc, alpha=0.1)
    # .filter(
    #     (pl.col("inside_gate") > 0.5)
    # )
    .sort("epoch_time")
    .unnest("pair")
    # ``row_nr`` identifies each pair in the filtering cells below
    .with_row_count()
    .join(_object_end_times, on="object_id")
    .join(
        _object_end_times.rename({"epoch_time_max": "epoch_time_max_leader"}),
        left_on="leader",
        right_on="object_id",
    )
)

In [None]:
from scipy.stats import chi

In [None]:
# Histogram of association distances, restricted to values below 10.
_close_matches = valid_matches.filter(pl.col("association_distance_filt") < 10)
_close_matches.get_column("association_distance_filt").to_pandas().plot.hist(
    bins=100
)

In [None]:
# 0.99 quantile of scipy's chi distribution with 4 degrees of freedom
chi(df=4).ppf(0.99)

In [None]:
# Decide which matched pairs (rows of ``valid_matches``, identified by
# ``row_nr``) to keep. Each pair is melted into two rows -- one for the
# follower (``variable == "object_id"``) and one for the leader
# (``variable == "leader"``) -- with the participant's id in ``value``, so both
# participants can be filtered per object.
keep_rows = (
    valid_matches.melt(
        id_vars=[
            "epoch_time",
            "row_nr",
            "prediction",
            "prediction_leader",
            "epoch_time_max",
            "epoch_time_max_leader",
            "association_distance_filt",
        ],
        value_vars=[
            "object_id",
            "leader",
        ],
    )
    # gate on the association distance: 0.9875 quantile of scipy's ``chi``
    # distribution with 4 degrees of freedom
    .filter((pl.col("association_distance_filt") < chi.ppf(0.9875, 4)))
    # order each object's candidate associations in time; the cumsum below
    # depends on this ordering
    .sort("value", "epoch_time")
    .with_columns(
        # this participant's own prediction flag
        pl.when(pl.col("variable") == "object_id")
        .then(pl.col("prediction"))
        .otherwise(pl.col("prediction_leader"))
        .alias("prediction"),
        # this participant's last timestamp
        pl.when(pl.col("variable") == "object_id")
        .then(pl.col("epoch_time_max"))
        .otherwise(pl.col("epoch_time_max_leader"))
        .alias("my_end_time"),
        # the other participant's last timestamp
        pl.when(pl.col("variable") == "object_id")
        .then(pl.col("epoch_time_max_leader"))
        .otherwise(pl.col("epoch_time_max"))
        .alias("other_end_time"),
    )
    .drop(["epoch_time_max", "epoch_time_max_leader", "prediction_leader", "variable"])
    .with_columns(
        # running count, in time order, of this object's predicted rows
        pl.col("prediction").cumsum().over("value").alias("prediction_count"),
        # latest partner end time over this object's non-predicted rows
        # NOTE(review): not used by the filters below -- appears to be for
        # inspection only
        pl.col("other_end_time")
        .filter(~pl.col("prediction"))
        .max()
        .over("value")
        .alias("other_end_time_max"),
    )
    # drop every row from an object's second predicted association onward
    .filter((pl.col("prediction_count") <= 1))
    # keep a pair only if both its follower row and its leader row survived
    .filter(pl.col("row_nr").count().over("row_nr") > 1)
)

keep_rows.head(20)

In [None]:
# Retain only the pairs that survived the per-object filtering above.
_kept_row_ids = keep_rows.get_column("row_nr").unique()
valid_matches = valid_matches.filter(pl.col("row_nr").is_in(_kept_row_ids))
valid_matches = valid_matches.drop("row_nr")

### Build a Graph of Connected Vehicles


In [None]:
from src.association.pipelines import create_vehicle_ids

# Assign a ``vehicle_id`` to the radar rows using the retained pairs.
joined_df = create_vehicle_ids(radar_df, match_df=valid_matches)

### Mark Vehicle Group Ends


In [None]:
from src.plotting.time_space import plot_time_space
from datetime import timedelta
from src.radar import Filtering


# Time-space plot of a 5-minute window (40-45 minutes after the first joined
# timestamp) for lane index 1 of the lanes whose name contains "WBL1",
# coloured by vehicle_id.
_window_start = joined_df["epoch_time"].min() + timedelta(minutes=40)
_window_end = _window_start + timedelta(minutes=5)

plot_df = joined_df.filter(
    pl.col("epoch_time").is_between(_window_start, _window_end)
    & pl.col("lane").str.contains("WBL1")
    & (pl.col("lane_index") == 1)
).pipe(Filtering.add_cst_timezone)

fig = plot_time_space(
    plot_df,
    hoverdata="object_id",
    vehicle_col="vehicle_id",
    s_col="s_centroid",
    markers=True,
)

fig.show()

## Join Trajectories


In [None]:
# How many rows share each (vehicle_id, epoch_time) slot: a tally of slot sizes
# together with each size's share of all slots.
_rows_per_slot = joined_df.group_by(["vehicle_id", "epoch_time"]).count()
_rows_per_slot.get_column("count").value_counts().with_columns(
    (pl.col("counts") / pl.col("counts").sum()).alias("percent")
).sort("count")

### Create a 4D Data Frame of Vehicle States

The dimensions are:
- Time
- Vehicle
- Measurement (up to 3 per vehicle and time step)
- X dimension


In [None]:
# ci_df = joined_df.filter(~pl.col("prediction"))

In [None]:
# Build the per-vehicle measurement table passed to ``batch_join``:
#   * ``time_index`` -- 0-based position of each epoch_time within its vehicle
#   * ``vehicle_time_index_int`` -- rank of a row within its (vehicle, time)
#     slot; unflagged (prediction == False) rows rank first, at most 3 are kept
#   * ``time_diff`` -- per-object difference between consecutive epoch_time
#     values, divided by 1000 and cast to float (presumably seconds if
#     epoch_time is in ms -- confirm)
ci_df = joined_df.lazy()
# optionally drop predicted rows first:
# ci_df = ci_df.filter(~pl.col("prediction"))

_vehicle_time_index = (
    ci_df.select(["vehicle_id", "epoch_time"])
    .unique()
    .sort(["vehicle_id", "epoch_time"])
    .with_columns(
        pl.col("epoch_time").cumcount().over("vehicle_id").alias("time_index")
    )
)

ci_df = (
    ci_df.sort(["vehicle_id", "epoch_time"])
    .join(_vehicle_time_index, on=["vehicle_id", "epoch_time"])
    .with_columns(
        # min() rather than first(): does not depend on row order after the join
        pl.col("epoch_time").min().over("vehicle_id").alias("vehicle_start_time"),
    )
    # Unflagged rows sort first. The trailing keys break ties deterministically,
    # so the same <= 3 rows per (vehicle, time) slot are kept on every run.
    .sort(["prediction", "vehicle_start_time", "vehicle_id", "object_id"])
    .with_columns(
        pl.col("object_id")
        .cumcount()
        .over(["vehicle_id", "time_index"])
        .alias("vehicle_time_index_int")
    )
    .filter(pl.col("vehicle_time_index_int") < 3)
    .sort("epoch_time")
    .with_columns(
        (pl.col("epoch_time").diff() / 1000)
        .cast(float)
        .over("object_id")
        .fill_null(0)
        .alias("time_diff")
    )
    # ``vehicle_start_time`` is the only helper column to remove; the
    # previously listed ``time_ind`` / ``vehicle_ind`` never exist in this frame
    .drop(["vehicle_start_time"])
    .collect(streaming=True)
)

### Join using CI & then RTS Smooth


In [None]:
from src.filters.fusion import batch_join, rts_smooth

In [None]:
# Merge the per-slot measurements with batch_join's "CI" method, batch size 5_000.
merged_df = ci_df.pipe(batch_join, method="CI", batch_size=5_000)

In [None]:
# RTS-smooth the merged trajectories on the GPU, batch size 10_000.
merged_df = merged_df.pipe(rts_smooth, gpu=True, batch_size=10_000)

In [None]:
from src.plotting.time_space import plot_time_space
from datetime import timedelta
from src.radar import Filtering


# Time-space plot of the smoothed trajectories over a 5-minute window (40-45
# minutes after the first *joined* timestamp, i.e. the same window as the
# earlier unmerged plot) for lane index 1 of the lanes whose name contains "WBL1".
_window_start = joined_df["epoch_time"].min() + timedelta(minutes=40)
_window_end = _window_start + timedelta(minutes=5)

plot_df = (
    merged_df.filter(
        pl.col("epoch_time").is_between(_window_start, _window_end)
        & pl.col("lane").str.contains("WBL1")
        & (pl.col("lane_index") == 1)
    )
    .pipe(Filtering.add_cst_timezone)
    .sort("epoch_time")
)

fig = plot_time_space(
    plot_df,
    hoverdata="vehicle_id",
    vehicle_col="vehicle_id",
    markers=True,
    s_col="s_smooth",
)

fig.show()

### Plot Velocity and Position


In [None]:
import plotly.graph_objects as go

# Compare one vehicle's smoothed (``s_velocity_smooth``) and ``ci_s_velocity``
# velocities with the ``s_velocity`` of each radar object joined into it.
# All velocities are plotted with their sign flipped (* -1).
veh = 4759

plot_df = merged_df.filter(pl.col("vehicle_id") == veh)

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=plot_df["epoch_time"],
        y=plot_df["s_velocity_smooth"] * -1,
        mode="markers",
        marker=dict(color="green", size=2),
        name="s_velocity_smooth",
    )
)


fig.add_trace(
    go.Scatter(
        x=plot_df["epoch_time"],
        y=plot_df["ci_s_velocity"] * -1,
        mode="markers",
        marker=dict(color="blue", size=2),
        # legend label must match the plotted column
        name="ci_s_velocity",
    )
)


# one red trace per contributing radar object, in order of first appearance
individual_traj = joined_df.filter(pl.col("vehicle_id") == veh).sort("epoch_time")

for v, v_df in individual_traj.group_by("object_id", maintain_order=True):
    fig.add_trace(
        go.Scatter(
            x=v_df["epoch_time"],
            y=v_df["s_velocity"] * -1,
            mode="markers",
            marker=dict(color="red", size=2),
            name=f"{v}",
        )
    )


fig.show()

In [None]:
import plotly.graph_objects as go

# Compare one vehicle's smoothed position (``s_smooth``) with the ``s_filt`` of
# each radar object joined into it. Positions are plotted with their sign
# flipped (* -1) on a reversed y axis.
veh = 5100

plot_df = merged_df.filter(pl.col("vehicle_id") == veh)

fig = go.Figure()

fig.add_trace(
    go.Scatter(
        x=plot_df["epoch_time"],
        y=plot_df["s_smooth"] * -1,
        mode="markers",
        marker=dict(color="blue", size=2),
        # legend label must match the plotted column
        name="s_smooth",
    )
)


# one red trace per contributing radar object, in order of first appearance
individual_traj = joined_df.filter(pl.col("vehicle_id") == veh).sort("epoch_time")

for v, v_df in individual_traj.group_by("object_id", maintain_order=True):
    fig.add_trace(
        go.Scatter(
            x=v_df["epoch_time"],
            y=v_df["s_filt"] * -1,
            mode="markers",
            marker=dict(color="red", size=2),
            name=f"{v}",
        )
    )


# reverse the y axis
fig.update_yaxes(autorange="reversed")


fig.show()

### Saving the Filtered Trajectory Database


In [None]:
# Smoothed trajectories plus the mean ``length_s`` of the joined rows that share
# the same vehicle, lane, lane index and time.
_save_keys = ["vehicle_id", "lane", "lane_index", "epoch_time"]
_mean_lengths = joined_df.group_by(_save_keys).agg(
    pl.col("length_s").mean().alias("length_s"),
)
save_df = merged_df.select(
    ["vehicle_id", "lane", "lane_index", "s_smooth", "s_velocity_smooth", "epoch_time"]
).join(_mean_lengths, on=_save_keys)

In [None]:
# Persist the merged trajectories ordered by vehicle, then time.
_merged_path = (
    ROOT / "notebooks" / "clean_workflow" / "data" / "merged_trajectories.parquet"
)
_ordered = save_df.sort(["vehicle_id", "epoch_time"])
_ordered.write_parquet(_merged_path, compression_level=10)

In [None]:
# Persist the object -> vehicle lookup (distinct (object_id, vehicle_id) pairs).
_id_map = joined_df.select(["object_id", "vehicle_id"]).unique()
_id_map.write_parquet(
    ROOT / "notebooks" / "clean_workflow" / "data" / "vehicle_id_map.parquet"
)