# Using the Processed & Joined DataFrame to Identify Leader & Follower Pairs

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%load_ext autoreload
%autoreload 2

# find the root of the project
import os
from pathlib import Path
import polars as pl

# Locate the repository root: the nearest ancestor of the working directory
# that contains a `.git` entry. Ancestors are checked first (nearest first);
# the working directory itself is the last candidate so that launching the
# notebook from the repository root also works.
_cwd = Path(os.getcwd())
ROOT = next(
    (path for path in (*_cwd.parents, _cwd) if path.joinpath(".git").exists()),
    None,
)
if ROOT is None:
    # An unbounded `while ROOT = ROOT.parent` walk would spin forever here,
    # because the parent of the filesystem root is the filesystem root itself.
    raise FileNotFoundError(f"no '.git' entry found in or above {_cwd}")

# add the root to the python path
import sys

sys.path.append(str(ROOT))

## Read in the Processed Trajectory Data

In [None]:
# Eagerly load the merged trajectories and add `lane_hash`, a hash of each
# row's (lane, lane_index) pair.
merged_parquet = ROOT.joinpath(
    "notebooks", "clean_workflow", "data", "merged_trajectories.parquet"
)
trajectory_df = pl.read_parquet(merged_parquet)
trajectory_df = trajectory_df.with_columns(
    lane_hash=pl.struct(["lane", "lane_index"]).hash(),
)

In [None]:
# Display the first rows as a quick sanity check of the loaded data.
trajectory_df.head()

### Build a Match DataFrame

In [None]:
from src.association.pipelines import build_match_df

# Run the project's matching pipeline on the trajectories, then derive a
# time headway for every row as `s_gap` divided by `s_velocity_smooth`.
time_headway_expr = pl.col("s_gap") / pl.col("s_velocity_smooth")

matching_df = build_match_df(
    trajectory_df, s_thresh=200, object_id_col="vehicle_id", s_col="s_smooth"
)
matching_df = matching_df.with_columns(time_headway=time_headway_expr)

matching_df.head()

### Group the Trajectories and Rank Them Based on Un-broken Duration

In [None]:
# Headway statistics only use rows where the follower's |s_velocity_smooth|
# exceeds 5: `time_headway` is s_gap / s_velocity_smooth, so it is not
# meaningful when the velocity is close to zero.
is_moving = pl.col("s_velocity_smooth").abs() > 5
moving_headway = pl.col("time_headway").filter(is_moving)

pair_df = (
    matching_df.sort(["vehicle_id", "epoch_time"])
    .with_columns(
        # Per-vehicle running count of rows whose leader differs from the
        # previous row's leader. The backward fill makes the first row compare
        # against its own leader, so it never counts as a change.
        leader_change=(pl.col("leader") != pl.col("leader").shift(1).backward_fill())
        .cumsum()
        .over("vehicle_id")
    )
    .with_columns(
        # One hash per un-broken (follower, leader) episode.
        leader_follower_hash=pl.struct(
            ["vehicle_id", "leader", "leader_change"]
        ).hash(),
    )
    .group_by("leader_follower_hash")
    .agg(
        pl.col("vehicle_id").first().alias("vehicle_id"),
        pl.col("leader").first().alias("leader"),
        pl.col("leader_change").first().alias("leader_change"),
        pl.col("s_velocity_smooth").mean().alias("s_velocity_smooth"),
        moving_headway.mean().alias("time_headway"),
        moving_headway.min().alias("time_headway_min"),
        moving_headway.max().alias("time_headway_max"),
        # The gap statistics are reported with their sign flipped.
        (-1 * pl.col("s_gap").mean()).alias("s_gap"),
        (-1 * pl.col("s_gap").min()).alias("s_gap_min"),
        # Distance covered by the follower during the episode: displacement of
        # its position between the first and last row. (Subtracting the last
        # *gap* from the first *position* mixes two unrelated quantities.)
        (pl.col("s_smooth").first() - pl.col("s_smooth").last())
        .abs()
        .alias("s_distance"),
        pl.col("epoch_time").min().alias("start_time"),
        pl.col("epoch_time").max().alias("end_time"),
        pl.col("epoch_time").count().alias("num_points"),
    )
    .with_columns(
        # Total time in follower mode. The division by 1000 presumably
        # converts epoch milliseconds to seconds -- TODO confirm the unit.
        total_time=(pl.col("end_time") - pl.col("start_time"))
        / 1000,
    )
    .filter(
        # Keep only episodes that last longer than 30 (in `total_time` units),
        # whose minimum moving headway lies within [0.5, 5], and whose
        # follower covers more than 50 along `s_smooth`.
        (pl.col("total_time") > 30)
        & (pl.col("time_headway_min").is_between(0.5, 5))
        & (pl.col("s_distance") > 50)
    )
)

pair_df.head()

In [None]:
# Display the retained pairs ordered by their `total_time`.
pair_df.sort(pl.col("total_time"))

## Test Querying the DataFrame

In [None]:
# Open the merged trajectories lazily. Note that this rebinds `trajectory_df`,
# so from this cell on the name refers to a LazyFrame.
lazy_parquet = ROOT.joinpath(
    "notebooks", "clean_workflow", "data", "merged_trajectories.parquet"
)
trajectory_df = pl.scan_parquet(lazy_parquet)

In [None]:
def load_trajectories(
    follower_id: int,
    leader_id: int,
    trajectory_df: pl.LazyFrame,
) -> pl.DataFrame:
    """Collect the time-overlapping samples of a follower/leader pair.

    Only rows of the two vehicles are kept, restricted to the closed time
    window that starts at the later of their first timestamps and ends at the
    earlier of their last timestamps. In the result:

    * ``s_smooth`` is first replaced by ``max(s_smooth) - s_smooth`` and then
      lagged by one row within each vehicle; the first row of each vehicle is
      filled with its first position minus ``0.1`` times its first velocity.
    * ``sim_time`` is ``epoch_time`` relative to the earliest retained sample,
      divided by 1000.

    The rows of each vehicle are assumed to be ordered by ``epoch_time``; the
    one-row lag depends on that order.

    Args:
        follower_id: ``vehicle_id`` of the following vehicle.
        leader_id: ``vehicle_id`` of the leading vehicle.
        trajectory_df: Lazy trajectory table providing at least the columns
            ``vehicle_id``, ``epoch_time``, ``s_smooth`` and
            ``s_velocity_smooth``.

    Returns:
        The collected (eager) frame for the two vehicles.
    """
    epoch_time = pl.col("epoch_time")
    follower_rows = pl.col("vehicle_id") == follower_id
    leader_rows = pl.col("vehicle_id") == leader_id

    # min()/max() rather than first()/last(): the window bounds then do not
    # depend on how the rows happen to be ordered in the source.
    window_start = pl.max_horizontal(
        [
            epoch_time.filter(follower_rows).min(),
            epoch_time.filter(leader_rows).min(),
        ]
    )
    window_end = pl.min_horizontal(
        [
            epoch_time.filter(follower_rows).max(),
            epoch_time.filter(leader_rows).max(),
        ]
    )

    return (
        trajectory_df.filter(follower_rows | leader_rows)
        .filter(epoch_time.is_between(window_start, window_end))
        .with_columns(
            # Flip the position axis; the alias makes the overwrite of
            # `s_smooth` explicit instead of relying on implicit naming.
            (pl.col("s_smooth").max() - pl.col("s_smooth")).alias("s_smooth"),
            # Time since the earliest retained sample (min, not first row, so
            # the origin is zero regardless of row order).
            ((epoch_time - epoch_time.min()) / 1000).alias("sim_time"),
        )
        .with_columns(
            # Lag the position by one sample per vehicle.
            # NOTE(review): the fill value is computed on the already flipped
            # position; if `s_velocity_smooth` is the derivative of the
            # un-flipped position, the sign of the 0.1-step extrapolation
            # looks inverted -- confirm the intended value.
            pl.col("s_smooth")
            .shift_and_fill(
                pl.col("s_smooth").first() - pl.col("s_velocity_smooth").first() * 0.1,
            )
            .over("vehicle_id")
        )
        .collect()
    )


# Draw one random leader/follower pair and load both of its trajectories.
match_df = pair_df.sample(1)

sampled_follower = match_df["vehicle_id"][0]
sampled_leader = match_df["leader"][0]

test_df = load_trajectories(
    follower_id=sampled_follower,
    leader_id=sampled_leader,
    trajectory_df=trajectory_df,
)

In [None]:
import plotly.graph_objects as go

# Not referenced anywhere in the visible cells.
veh = 5100

# Position over time for the sampled pair: follower in blue, leader in red.
follower_df = test_df.filter(pl.col("vehicle_id") == match_df["vehicle_id"][0])
leader_df = test_df.filter(pl.col("vehicle_id") == match_df["leader"][0])

fig = go.Figure()

for trace_name, trace_color, trace_df in (
    ("Follower", "blue", follower_df),
    ("Leader", "red", leader_df),
):
    fig.add_trace(
        go.Scatter(
            x=trace_df["sim_time"],
            y=trace_df["s_smooth"],
            mode="markers",
            marker={"color": trace_color, "size": 2},
            name=trace_name,
        )
    )

fig.show()

In [None]:
import plotly.graph_objects as go

# Sign-flipped velocity over time for the sampled pair: follower in blue,
# leader in red. The ids of both vehicles are printed for reference.
follower_df = test_df.filter(pl.col("vehicle_id") == match_df["vehicle_id"][0])
print(match_df["vehicle_id"][0])

leader_df = test_df.filter(pl.col("vehicle_id") == match_df["leader"][0])
print(match_df["leader"][0])

fig = go.Figure()

for trace_name, trace_color, trace_df in (
    ("Follower", "blue", follower_df),
    ("Leader", "red", leader_df),
):
    fig.add_trace(
        go.Scatter(
            x=trace_df["sim_time"],
            y=trace_df["s_velocity_smooth"] * -1,
            mode="markers",
            marker={"color": trace_color, "size": 2},
            name=trace_name,
        )
    )

fig.show()

In [None]:
# NOTE(review): this call is unfinished -- no arguments are passed, so it
# cannot run as written. Supply the target table name and the database
# connection before executing this cell.
pair_df.write_database(
    
)