In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
import numpy as np
from numpy.typing import NDArray
from typing import Sequence, Any, Tuple, Union, List, Dict, Optional, TypeVar, Callable, Iterable, cast
from datetime import datetime, timedelta

In [None]:
import awkward as ak

# Load the raw detections. Later cells index this as data[frame, detection]
# and read the fields x, y, w, h, area, cX, cY of every detection record.
# NOTE(review): one list of detection records per video frame is presumed
# (the rendering cell zips it with the frames) — confirm against the file.
data = ak.from_parquet("detections.parquet")
display(data.typestr)

In [None]:
# keep only the centre coordinates (cX, cY) of every detection in every frame
detections_coords = data[:, :, ["cX", "cY"]]
# ak.unzip splits the record array into one jagged array per field:
# cXs[frame] / cYs[frame] hold the x / y centres of that frame's detections
cXs, cYs = ak.unzip(detections_coords)

In [None]:
# https://github.com/rlabbe/filterpy
import jax.numpy as jnp
from jax import random, vmap, jit, grad, value_and_grad
import numpy as np
import jax
import chex
from jaxtyping import Array, Shaped, Num, Int, Float, Bool, PyTree

In [None]:
@chex.dataclass
class Initiator:
    """Container for tracks that have been initiated but not yet confirmed."""
    # NOTE(review): not referenced anywhere else in the visible notebook;
    # track initiation is actually done by Tracker._tracks_from_past_measurements.
    # tentative tracks are temporary tracks maintained by the initiator that
    # have been initialized but not yet confirmed
    tentative_tracks: Num[Array, "... 2"]


In [None]:
import filterpy
from filterpy.kalman import KalmanFilter

# measurement z = [x y]
# state x = [x y dx/dt dy/dt]


# yapf: disable
def F_cv(dt: float|int):
    return np.array([[1, 0, dt, 0],
                     [0, 1, 0, dt],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])
# yapf: enable


# yapf: disable
def H_cv():
    """Measurement matrix that picks the position [x, y] out of the state [x, y, dx/dt, dy/dt]."""
    # a 2x4 matrix with ones on the main diagonal, integer-typed like the literal form
    return np.eye(2, 4, dtype=int)
# yapf: enable


# filterpy reference run: 4-D state [x, y, dx/dt, dy/dt], 2-D measurement [x, y]
kf = KalmanFilter(4, 2)
# time step between measurements
T = 1.0
kf.F = F_cv(T)
kf.H = H_cv()
# isotropic measurement noise (R) and process noise (Q)
kf.R = np.diag([0.75, 0.75])
kf.Q = np.diag([0.05, 0.05, 0.05, 0.05])
# a simple constant velocity model:
# start at the origin with an assumed initial velocity
# of 0.05 units per time step in both x and y
kf.x = np.array([0, 0, 0.05, 0.05])
# initial covariance; P is not set in this cell, so this is filterpy's default
display(kf.P)
kf.predict()
# after predict(), kf.x / kf.P hold the prior (predicted) belief
display(kf.x)
display(kf.P)

# kf.update([0.15, 0.15])
# after update(), kf.x / kf.P would hold the posterior belief

In [None]:
# project the predicted state into measurement space
# NOTE: a bare expression that is not the last line of a cell is not
# displayed, so the next line has no visible effect
kf.x_prior # predicted state
# predicted measurement H @ x (`@` is matrix multiplication, PEP 465)
# https://peps.python.org/pep-0465/
predicted_measurement = kf.H @ kf.x
display(predicted_measurement)
# the predicted measurement can be compared with the actual one using the
# Mahalanobis distance (or, more crudely, the Euclidean distance)
actual_measurement = np.array([0.12, 0.12])
kf.update(actual_measurement)
display(kf.x)
display(kf.P)
# Mahalanobis distance as reported by filterpy after update()
display(kf.mahalanobis)
# Idea: use the Mahalanobis distance as the assignment cost and solve the
# track-to-measurement matching with the Hungarian algorithm.
from scipy.optimize import linear_sum_assignment
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
# Once an object has been detected in enough successive frames, its track is
# moved into the confirmed tracks and removed from the tentative tracks.
# In effect this is two cascaded GNN (global nearest neighbour) stages:
# one for the tentative tracks and one for the confirmed tracks.

In [None]:
# Abstract the Kalman filter into a motion (state-transition) model
# and a measurement model.

# Reference: https://github.com/sisl/GaussianFilters.jl
# Control inputs / external effects are not modelled.
from jaxtyping import jaxtyped
from typeguard import typechecked


@chex.dataclass
class LinearMotionNoInputModel:
    """Linear motion model x_k = F x_{k-1} + w, w ~ N(0, Q), with no control input."""
    # state-transition matrix
    F: Num[Array, "n n"]
    # process-noise covariance
    Q: Num[Array, "n n"]


@chex.dataclass
class LinearMeasurementModel:
    """Linear measurement model z = H x + v, v ~ N(0, R)."""
    # observation matrix mapping an (n,) state to an (m,) measurement
    H: Num[Array, "m n"]
    # measurement-noise covariance
    R: Num[Array, "m m"]


# a single measurement vector z
Measurement = Num[Array, "m"]


# a Gaussian belief N(x, P) over the state
# NOTE(review): the jaxtyping field annotations on these chex dataclasses are
# presumably documentation only (the classes are not @jaxtyped) — confirm if
# runtime shape checks are expected.
@chex.dataclass
class GaussianState:
    """Gaussian belief over the state: mean ``x`` and covariance ``P``."""
    # state mean
    x: Num[Array, "n"]
    # state covariance
    P: Num[Array, "n n"]


@jaxtyped(typechecker=typechecked)
def _predict(
    state: GaussianState,
    motion_model: LinearMotionNoInputModel,
) -> GaussianState:
    """
    Kalman filter time update (prediction step) without control input.

    Computes the prior belief x_priori = F x and P_priori = F P F^T + Q.

    Parameters
    ----------
    state: current belief (x of shape (n,), P of shape (n, n))
    motion_model: transition matrix F (n, n) and process noise Q (n, n)

    Returns
    -------
    GaussianState holding the predicted (a priori) belief

    Raises
    ------
    AssertionError if the state, F and Q dimensions do not agree
    """
    F, Q = motion_model.F, motion_model.Q
    n_state = state.x.shape[0]
    assert n_state == F.shape[0], "state and transition model are not compatible"
    assert F.shape[1] == F.shape[0], "transition model is not square"
    assert Q.shape[0] == F.shape[0], "transition model and noise model are not compatible"
    predicted_cov = F @ state.P @ F.T + Q
    return GaussianState(x=F @ state.x, P=predicted_cov)


@chex.dataclass
class PosterioriResult:
    """Outcome of one Kalman measurement update, as produced by ``update``."""
    # updated (a posteriori) state
    state: GaussianState
    # innovation (pre-fit residual) y = z - H @ x_priori
    innovation: Num[Array, "m"]
    # measurement predicted from the updated state, H @ x_posteriori
    posteriori_measurement: Num[Array, "m"]
    # Mahalanobis distance sqrt(y^T S^-1 y) of the innovation: a scalar
    # (0-d array), not an (m,) vector
    mahalanobis_distance: Num[Array, ""]
    # NOTE: the post-fit residual z - H @ x_posteriori
    # (= z - posteriori_measurement) is not stored


@jaxtyped(typechecker=typechecked)
def update(
    measurement: Measurement,
    state: GaussianState,
    measure_model: LinearMeasurementModel,
) -> PosterioriResult:
    """
    Kalman filter measurement update (correction step).

    Parameters
    ----------
    measurement: z, shape (m,)
    state: predicted (a priori) belief, x of shape (n,), P of shape (n, n)
    measure_model: H of shape (m, n) and measurement noise R of shape (m, m)

    Returns
    -------
    PosterioriResult with the a posteriori state, the innovation
    y = z - H x, the measurement H x_posteriori, and the scalar
    Mahalanobis distance sqrt(y^T S^-1 y) where S = H P H^T + R.

    Raises
    ------
    AssertionError if the state, H, R and z dimensions do not agree
    """
    x = state.x
    P = state.P
    H = measure_model.H
    R = measure_model.R
    z = measurement
    assert x.shape[0] == H.shape[1], "state and measurement model are not compatible"
    assert R.shape[0] == R.shape[1], "measurement noise covariance is not square"
    assert H.shape[0] == R.shape[0], "measurement model and noise model are not compatible"
    # without this a length-1 z would silently broadcast against H @ x
    assert z.shape[0] == H.shape[0], "measurement and measurement model are not compatible"
    # innovation: the a priori measurement residual
    y = z - H @ x
    # innovation covariance
    S = H @ P @ H.T + R
    # Kalman gain K = P H^T S^-1, obtained by solving K S = P H^T instead of
    # forming S^-1 explicitly (better conditioned)
    PHt = P @ H.T
    K = jnp.linalg.solve(S.T, PHt.T).T
    # posteriori state
    x_posteriori = x + K @ y
    I = jnp.eye(P.shape[0])
    # Joseph form: keeps the posterior covariance symmetric positive
    # semi-definite even with round-off in K
    I_KH = I - K @ H
    P_posteriori = I_KH @ P @ I_KH.T + K @ R @ K.T
    posteriori_measurement = H @ x_posteriori
    # squared Mahalanobis distance y^T S^-1 y, again through a solve
    mahalanobis_sq = y @ jnp.linalg.solve(S, y)
    return PosterioriResult(
        state=GaussianState(x=x_posteriori, P=P_posteriori),
        innovation=y,
        posteriori_measurement=posteriori_measurement,
        mahalanobis_distance=jnp.sqrt(mahalanobis_sq),
    )


In [None]:
def cv_model(
    v_x: float,
    v_y: float,
    dt: float,
    q: float,
    r: float,
) -> Tuple[
        LinearMotionNoInputModel,
        LinearMeasurementModel,
        GaussianState,
]:
    """
    Build a constant-velocity model (no control input) for the state
    [x, y, dx/dt, dy/dt] observed through its position [x, y].

    Args:
    v_x: initial velocity in x direction
    v_y: initial velocity in y direction
    dt: time interval
    q: process noise (isotropic, Q = q * I)
    r: measurement noise (isotropic, R = r * I)

    Returns:
    motion_model: motion model
    measure_model: measurement model
    state: initial state at the origin with covariance I
    """
    # yapf: disable
    transition = jnp.array([
        [1, 0, dt, 0],
        [0, 1, 0, dt],
        [0, 0, 1, 0],
        [0, 0, 0, 1],
    ])
    observation = jnp.array([
        [1, 0, 0, 0],
        [0, 1, 0, 0],
    ])
    # yapf: enable
    initial = GaussianState(x=jnp.array([0, 0, v_x, v_y]), P=jnp.eye(4))
    return (
        LinearMotionNoInputModel(F=transition, Q=q * jnp.eye(4)),
        LinearMeasurementModel(H=observation, R=r * jnp.eye(2)),
        initial,
    )

In [None]:
# the same numbers as the filterpy cell above: initial velocity 0.05,
# dt = 1.0, process noise q = 0.05, measurement noise r = 0.75
mo_model, me_model, st = cv_model(0.05, 0.05, 1.0, 0.05, 0.75)

# predict: prior belief after one time step
new_st = _predict(st, mo_model)
# update: correct the prior with the measurement [0.12, 0.12]
res = update(jnp.array([0.12, 0.12]), new_st, me_model)

In [None]:
# show the initial and predicted beliefs side by side as an awkward array
# built from each dataclass's field dict (__dict__)
# def to_ak_record(dict_like: Dict[str, Any] | Any) -> ak.Record:
#     return ak.Record(dict_like.__dict__)
s = ak.Array([st.__dict__, new_st.__dict__])
display(s)

In [None]:
@jit
def _pairwise_euclidean(x, y):
    """Jitted kernel of ``outer_distance``: res[i, j] = ||x[i] - y[j]||."""
    # broadcast (a, 1, d) against (1, b, d) to get all (a, b, d) differences
    diff = y[None, :, :] - x[:, None, :]
    return jnp.linalg.norm(diff, axis=-1)


@jaxtyped(typechecker=typechecked)
def outer_distance(x: Num[Array, "a d"], y: Num[Array,
                                                "b d"]) -> Num[Array, "a b"]:
    """
    Pairwise Euclidean distances between the rows of ``x`` and ``y``.

    Points may have any dimension ``d`` (shared by ``x`` and ``y``); the
    tracker uses ``d == 2``.

    Here's equivalent python code:

    ```python
    res = jnp.empty((x.shape[0], y.shape[0]))
    for i in range(x.shape[0]):
        for j in range(y.shape[0]):
            # res[i, j] = jnp.linalg.norm(x[i] - y[j])
            res = res.at[i, j].set(jnp.linalg.norm(x[i] - y[j]))
    return res
    ```

    See Also
    --------
    `outer product <https://en.wikipedia.org/wiki/Outer_product>`_
    """
    # the jitted kernel lives at module level: a jit-wrapped closure created
    # inside this function would be a new function object (and be re-traced)
    # on every call
    return _pairwise_euclidean(x, y)


In [None]:
from dataclasses import dataclass
from typing import Generator, TypedDict
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
# https://github.com/google/jax/issues/10403
from scipy.optimize import linear_sum_assignment

# short alias for the awkward array type
# NOTE(review): appears unused in the visible notebook
AKArray = ak.Array

# register the JAX backend of awkward so awkward arrays can interoperate with JAX
ak.jax.register_and_check()  # type: ignore


@chex.dataclass
class Tracking:
    """One track managed by ``Tracker``."""
    # identifier handed out by the tracker when the track is created
    id: int
    # Gaussian belief over [x, y, dx/dt, dy/dt]
    state: GaussianState
    # number of time steps in which the track was matched with a measurement
    survived_time_steps: int
    # number of time steps in which the track had no matching measurement
    missed_time_steps: int


@dataclass
class TrackerParams:
    """Tuning parameters for ``Tracker.next_measurements``."""
    # time step between consecutive measurement batches
    dt: float = 1.0
    # tracks whose covariance trace exceeds this are deleted
    # NOTE(review): new tracks start with P = I (trace 4), so the default 4.0
    # is tight; the notebook run below uses 25.0
    cov_threshold: float = 4.0
    # gate: max Mahalanobis distance to associate a measurement with a tentative track
    tentative_mahalanobis_threshold: float = 10.0
    # gate: max Mahalanobis distance to associate a measurement with a confirmed track
    confirm_mahalanobis_threshold: float = 10.0
    # max Euclidean distance between unmatched measurements of consecutive
    # steps for them to spawn a new tentative track
    forming_tracks_euclidean_threshold: float = 25.0
    # a tentative track is confirmed once matched more than this many times
    survival_steps_threshold: int = 3


class Tracker:
    """
    A simple cascaded GNN (global nearest neighbour) multi-target tracker.

    One call to :meth:`next_measurements` runs one tracking cycle:

    1. predict every confirmed and tentative track ``dt`` forward with a
       constant-velocity motion model;
    2. associate the measurements with the confirmed tracks first, then the
       leftovers with the tentative tracks (Hungarian assignment on the
       Mahalanobis distance, gated by a threshold);
    3. promote tentative tracks that have been matched often enough;
    4. pair still-unmatched measurements with the unmatched measurements of
       the previous call (Euclidean gating) to spawn new tentative tracks;
    5. drop tracks whose covariance trace exceeds a threshold.
    """
    # annotations only: every instance creates its own containers in
    # __init__ (class-level list defaults would be shared between instances)
    _last_measurements: Float[Array, "... 2"]
    _tentative_tracks: list[Tracking]
    _confirmed_tracks: list[Tracking]
    _last_id: int

    def __init__(self):
        # unmatched measurements of the previous call; a (0, 2) array keeps
        # the "... 2" shape contract from the start
        self._last_measurements = jnp.empty((0, 2), dtype=jnp.float32)
        self._tentative_tracks = []
        self._confirmed_tracks = []
        # id given to the next track that is created
        self._last_id = 0

    @staticmethod
    def _predict(tracks: list[Tracking], dt: float = 1.0):
        """
        Propagate every track ``dt`` forward in time.

        Returns new ``Tracking`` objects; the input tracks are left untouched.
        """
        # the motion model is identical for every track, so build it once
        motion_model = Tracker.motion_model(dt=dt)
        # NOTE: the bare name ``_predict`` below resolves to the module-level
        # Kalman prediction function, not to this static method
        return [
            Tracking(
                id=track.id,
                state=_predict(track.state, motion_model),
                survived_time_steps=track.survived_time_steps,
                missed_time_steps=track.missed_time_steps,
            ) for track in tracks
        ]

    @staticmethod
    def _data_associate_and_update(
            measurements: Float[Array, "... 2"],
            tracks: list[Tracking],
            distance_threshold: float = 3) -> Float[Array, "... 2"]:
        """
        Match tracks with measurements and update the matched tracks.

        The track states are assumed to be predicted already. Tracks and
        measurements are paired by the Hungarian algorithm on the Mahalanobis
        distance of the innovation; pairs farther apart than
        ``distance_threshold`` are rejected (gating).

        Parameters
        ----------
        [in] measurements: Float["a 2"]
        [in,out] tracks: Tracking["b"]
            matched tracks receive the posterior state and have
            ``survived_time_steps`` incremented; unmatched tracks have
            ``missed_time_steps`` incremented

        Returns
        ----------
        Float["... 2"] the measurements not matched to any track
        """
        if len(tracks) == 0:
            return measurements
        if measurements.shape[0] == 0:
            # nothing to associate: every track misses this step
            for track in tracks:
                track.missed_time_steps += 1
            return measurements

        measure_model = Tracker.measurement_model()
        # posteriors[i][j]: track i updated with measurement j, i.e. (b, a)
        posteriors = [[
            update(measurement, track.state, measure_model)
            for measurement in measurements
        ] for track in tracks]
        distances = np.array(
            [[float(post.mahalanobis_distance) for post in per_track]
             for per_track in posteriors],
            dtype=np.float32)

        track_idx, meas_idx = linear_sum_assignment(distances)
        # gating: reject assignments whose distance exceeds the threshold
        keep = ~(distances[track_idx, meas_idx] > distance_threshold)
        track_idx, meas_idx = track_idx[keep], meas_idx[keep]

        for i, j in zip(track_idx.tolist(), meas_idx.tolist()):
            track = tracks[i]
            track.state = posteriors[i][j].state
            track.survived_time_steps += 1

        matched_tracks = set(track_idx.tolist())
        for i, track in enumerate(tracks):
            if i not in matched_tracks:
                # only the miss counter changes; survived_time_steps is kept
                track.missed_time_steps += 1

        # drop the measurements that have been matched
        return jnp.delete(measurements,
                          jnp.asarray(meas_idx, dtype=jnp.int32),
                          axis=0)

    def _tracks_from_past_measurements(self,
                                       measurements: Float[Array, "... 2"],
                                       dt: float = 1.0,
                                       distance_threshold: float = 3.0):
        """
        Spawn tentative tracks by pairing ``measurements`` with the unmatched
        measurements of the previous call.

        Each pair closer than ``distance_threshold`` (Euclidean) becomes a new
        tentative track positioned at the current measurement, with the
        finite-difference velocity over ``dt`` and covariance I. Measurements
        that could not be paired are remembered for the next call.

        Note
        ----
        mutate self._tentative_tracks, self._last_measurements and self._last_id
        """
        if self._last_measurements.shape[0] == 0 or measurements.shape[0] == 0:
            # nothing to pair with: just remember the current measurements
            self._last_measurements = measurements
            return
        distances = np.asarray(
            outer_distance(self._last_measurements, measurements))
        prev_idx, cur_idx = linear_sum_assignment(distances)
        # gating: reject pairs that are too far apart
        keep = ~(distances[prev_idx, cur_idx] > distance_threshold)
        prev_idx, cur_idx = prev_idx[keep], cur_idx[keep]

        for i, j in zip(prev_idx.tolist(), cur_idx.tolist()):
            coord = measurements[j]
            vel = (coord - self._last_measurements[i]) / dt
            state = GaussianState(x=jnp.concatenate([coord, vel]),
                                  P=jnp.eye(4))
            self._tentative_tracks.append(
                Tracking(id=self._last_id,
                         state=state,
                         survived_time_steps=0,
                         missed_time_steps=0))
            self._last_id += 1
        # remember only the current measurements that did not start a track
        self._last_measurements = jnp.delete(
            measurements, jnp.asarray(cur_idx, dtype=jnp.int32), axis=0)

    def _transfer_tentative_to_confirmed(self,
                                         survival_steps_threshold: int = 3):
        """
        Promote tentative tracks matched more than ``survival_steps_threshold``
        times to confirmed tracks.

        Note
        ----
        mutate self._tentative_tracks and self._confirmed_tracks in place
        """
        # partition instead of popping while iterating (which skipped the
        # element following every removed one)
        def is_mature(track: Tracking) -> bool:
            return track.survived_time_steps > survival_steps_threshold

        promoted = [t for t in self._tentative_tracks if is_mature(t)]
        self._tentative_tracks[:] = [
            t for t in self._tentative_tracks if not is_mature(t)
        ]
        self._confirmed_tracks.extend(promoted)

    @staticmethod
    def _track_cov_deleter(tracks: list[Tracking], cov_threshold: float = 4.0):
        """
        delete tracks with covariance trace greater than threshold

        Parameters
        ----------
        [in,out] tracks: list[Tracking]
        cov_threshold: float
            the threshold of the covariance trace

        Note
        ----
        mutate tracks in place
        """
        # rebuild via slice assignment (same list object) instead of popping
        # while iterating, which skipped the element after each removal
        # https://numpy.org/doc/stable/reference/generated/numpy.trace.html
        tracks[:] = [
            track for track in tracks
            if not jnp.trace(track.state.P) > cov_threshold
        ]

    def next_measurements(self, measurements: Float[Array, "... 2"],
                          params: TrackerParams):
        """
        Run one tracking cycle on the measurements of the current time step.

        Parameters
        ----------
        measurements: Float["k 2"] measurements of the current time step
        params: TrackerParams thresholds and time step for this cycle
        """
        self._confirmed_tracks = self._predict(self._confirmed_tracks,
                                               params.dt)
        self._tentative_tracks = self._predict(self._tentative_tracks,
                                               params.dt)
        # confirmed tracks get first pick of the measurements
        left_ = self._data_associate_and_update(
            measurements, self._confirmed_tracks,
            params.confirm_mahalanobis_threshold)
        left = self._data_associate_and_update(
            left_, self._tentative_tracks,
            params.tentative_mahalanobis_threshold)
        self._transfer_tentative_to_confirmed(params.survival_steps_threshold)
        self._tracks_from_past_measurements(
            left, params.dt, params.forming_tracks_euclidean_threshold)
        self._track_cov_deleter(self._tentative_tracks, params.cov_threshold)
        self._track_cov_deleter(self._confirmed_tracks, params.cov_threshold)

    @property
    def confirmed_tracks(self):
        """The current confirmed tracks."""
        return self._confirmed_tracks

    @property
    def tentative_tracks(self):
        """The current tentative (not yet confirmed) tracks."""
        return self._tentative_tracks

    @staticmethod
    def motion_model(dt: float = 1,
                     q: float = 0.05) -> LinearMotionNoInputModel:
        """
        a constant velocity motion model for the state [x, y, dx/dt, dy/dt]
        with isotropic process noise Q = q * I
        """
        # yapf: disable
        F = jnp.array([[1, 0, dt, 0],
                            [0, 1, 0, dt],
                            [0, 0, 1, 0],
                            [0, 0, 0, 1]])
        # yapf: enable
        Q = q * jnp.eye(4)
        return LinearMotionNoInputModel(F=F, Q=Q)

    @staticmethod
    def measurement_model(r: float = 0.75) -> LinearMeasurementModel:
        """
        a position-only measurement model (observes [x, y] of the state)
        with isotropic measurement noise R = r * I
        """
        # yapf: disable
        H = jnp.array([[1, 0, 0, 0],
                            [0, 1, 0, 0]])
        # yapf: enable
        R = r * jnp.eye(2)
        return LinearMeasurementModel(H=H, R=R)


In [None]:
def gen_measurements() -> Generator[Float[Array, "... 2"], None, None]:
    """
    Yield one (k, 2) array of detection centres (cX, cY) per frame.

    The coordinates are cast to float32 so every yielded array matches its
    ``Float[Array, "... 2"]`` annotation (the source cX/cY fields may be
    integer-typed). A frame without detections yields a (0, 2) array.
    """
    for m_cXs, m_cYs in zip(cXs, cYs):
        xs = jnp.asarray(m_cXs.to_numpy(), dtype=jnp.float32)
        ys = jnp.asarray(m_cYs.to_numpy(), dtype=jnp.float32)
        yield jnp.column_stack([xs, ys])

In [None]:
tracker = Tracker()

# per-frame snapshots of the track lists, consumed by the rendering cell
# NOTE(review): "tenative" is a typo for "tentative"; the name is kept
# because the rendering cell reads it
tenative_histories: list[list[Tracking]] = []
confirmed_histories: list[list[Tracking]] = []

params = TrackerParams(
    cov_threshold=25.0,
    tentative_mahalanobis_threshold=50.0,
    confirm_mahalanobis_threshold=25.0,
    forming_tracks_euclidean_threshold=20,
    dt=1.0,
    survival_steps_threshold=6,
)

for measurement in gen_measurements():
    # NOTE(review): gen_measurements already yields jax arrays, so this is a copy
    m = jnp.array(measurement)
    tracker.next_measurements(m, params)
    # list.copy() is shallow; the snapshots share Tracking objects with the
    # tracker (Tracker._predict builds fresh Tracking objects each step)
    tenative_histories.append(tracker._tentative_tracks.copy())
    confirmed_histories.append(tracker._confirmed_tracks.copy())

In [None]:
import cv2 as cv
import cv2
from cv2.typing import MatLike
from loguru import logger

@dataclass
class CapProps:
    """Basic properties of an opened video source, as queried from OpenCV."""
    # frame size in pixels
    width: int
    height: int
    # frames per second
    fps: float
    # total number of frames
    # NOTE(review): may be meaningless for live sources (e.g. a camera index)
    frame_count: Optional[int] = None

def fourcc(*args: str) -> int:
    """Pack codec characters (e.g. ``fourcc(*"mp4v")``) into an OpenCV FOURCC code."""
    make_code = cv2.VideoWriter_fourcc  # type: ignore
    return make_code(*args)

def video_cap(
        src: str | int) -> Tuple[Generator[MatLike, None, None], CapProps]:
    """
    Open a video source and return a lazy frame generator plus its properties.

    Parameters
    ----------
    src: file path / URL, or device index, passed to ``cv2.VideoCapture``

    Returns
    -------
    (frames, props): ``frames`` yields decoded frames until the source is
    exhausted; ``props`` holds width, height, fps and frame count.

    Note
    ----
    The capture is released when the generator finishes, raises, or is closed
    or garbage-collected after it has started. If the generator is never
    started, its ``finally`` clause does not run.
    """
    cap = cv2.VideoCapture(src)
    props = CapProps(
        width=int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        height=int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        fps=float(cap.get(cv2.CAP_PROP_FPS)),
        frame_count=int(cap.get(cv2.CAP_PROP_FRAME_COUNT)),
    )

    def gen():
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                yield frame
        finally:
            # also runs when the consumer stops early (e.g. zip() truncation,
            # an exception, or frames.close()), so the capture is not leaked
            cap.release()

    return gen(), props

In [None]:
# lazily decoded input frames and their properties (size, fps, frame count)
frames, props = video_cap("PETS09-S2L1-raw.mp4")
# colour output video with the input's frame rate and size, MPEG-4 ("mp4v") codec
writer = cv2.VideoWriter("PETS09-S2L1-tracking.mp4",
                         fourcc(*"mp4v"),
                         props.fps, (props.width, props.height),
                         isColor=True)

class RawDataDict(TypedDict):
    """Schema of one raw detection record (the fields read by the rendering loop)."""
    # NOTE(review): presumably mirrors the columns of detections.parquet;
    # not used as an annotation in the visible code — confirm
    # bounding box: corner (x, y) and extent (w, h) in pixels
    x: int
    y: int
    w: int
    h: int
    # presumably the detection (blob) area — confirm
    area: float
    # detection centre; these are the coordinates fed to the tracker
    cX: int
    cY: int

display(props)

try:
    # seeded palette: a given track id gets the same colour on every run
    colors = np.random.default_rng(0).integers(0, 255, size=(1024, 3))
    for frame, tentative_tracks, confirmed_tracks, raws in zip(
            frames, tenative_histories, confirmed_histories, data): # type: ignore
        # raw detection boxes in green
        for raw in raws:
            x, y, w, h = raw.x, raw.y, raw.w, raw.h
            cv.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 2)
        # tentative tracks are not drawn; confirmed tracks are drawn as a
        # filled dot at the estimated position, one colour per track id
        for track in confirmed_tracks:
            x, y = track.state.x[:2]
            # track ids grow without bound: wrap them into the fixed-size
            # palette instead of raising IndexError past 1024 tracks
            color = tuple(colors[track.id % len(colors)].tolist())
            cv.circle(frame, (int(x), int(y)), 5, color, -1)
        writer.write(frame)
except Exception as e:
    logger.exception(e)
finally:
    # close the frame generator (runs its cleanup if it has started) and
    # finalize the output video
    frames.close()
    writer.release()