In [3]:
import matplotlib.pyplot as plt
import seaborn as sns
import polars as pl
import numpy as np
from numpy.typing import NDArray
from typing import Sequence, Any, Tuple, Union, List, Dict, Optional, TypeVar, Callable, Iterable, cast
from datetime import datetime, timedelta

In [4]:
from stonesoup.models.transition.linear import CombinedLinearGaussianTransitionModel, \
                                               ConstantVelocity
from stonesoup.types.groundtruth import GroundTruthPath, GroundTruthState
from stonesoup.types.detection import TrueDetection
from stonesoup.types.detection import Clutter
from stonesoup.models.measurement.linear import LinearGaussian
from stonesoup.types.detection import Detection
from stonesoup.plotter import AnimatedPlotterly
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater
from stonesoup.hypothesiser.distance import DistanceHypothesiser
from stonesoup.measures import Mahalanobis
from stonesoup.dataassociator.neighbour import GlobalNearestNeighbour, GNNWith2DAssignment
from stonesoup.dataassociator.probability import PDAHypothesiser, JPDA
from stonesoup.deleter.error import CovarianceBasedDeleter
from stonesoup.types.state import GaussianState
from stonesoup.initiator.simple import MultiMeasurementInitiator

In [5]:
import awkward as ak

# Load the detections. The outer dimension presumably indexes video frames, each
# holding a variable-length list of bounding-box records (schema printed below).
DETECTIONS_PATH = "detections.parquet"
data = ak.from_parquet(DETECTIONS_PATH)
display(data.typestr)

'795 * var * {x: int64, y: int64, w: int64, h: int64, area: float64, cX: int64, cY: int64}'

In [6]:
# Stone Soup reference pipeline.
# Two 1-D ConstantVelocity models combined give the state [x, vx, y, vy];
# `mapping=(0, 2)` therefore picks the two position components for measurement.
#
# A later cell redefines `GaussianState` as a chex dataclass, which would make
# this cell fail on re-execution; import Stone Soup's class under an alias.
from stonesoup.types.state import GaussianState as StoneSoupGaussianState

transition_model = CombinedLinearGaussianTransitionModel(
    [ConstantVelocity(0.05), ConstantVelocity(0.05)])
measurement_model = LinearGaussian(ndim_state=4,
                                   mapping=(0, 2),
                                   noise_covar=np.diag([0.75, 0.75]))
predictor = KalmanPredictor(transition_model)
updater = KalmanUpdater(measurement_model)

# Hypotheses are scored by Mahalanobis distance; the missed-detection
# hypothesis is given distance 3.
hypothesiser = DistanceHypothesiser(predictor,
                                    updater,
                                    measure=Mahalanobis(),
                                    missed_distance=3)
data_associator = GlobalNearestNeighbour(hypothesiser)

deleter = CovarianceBasedDeleter(covar_trace_thresh=4)
initiator = MultiMeasurementInitiator(
    prior_state=StoneSoupGaussianState([[0], [0], [0], [0]],
                                       np.diag([0, 1, 0, 1])),
    measurement_model=measurement_model,
    deleter=deleter,
    data_associator=data_associator,
    updater=updater,
    min_points=2,
)

In [7]:
# Keep only the (cX, cY) fields of every detection record -- presumably the
# bounding-box centroid -- while leaving the two outer list dimensions intact.
detections_coords = data[:, :, ["cX", "cY"]]
# TODO: use `ak.unzip(detections_coords)` to get separate x and y arrays.
# display(detections_coords)

In [8]:
# https://github.com/rlabbe/filterpy
import jax.numpy as jnp
from jax import random, vmap, jit, grad, value_and_grad
import numpy as np
import jax
import chex
from jaxtyping import Array, Shaped, Num, Int, Float, Bool

In [9]:
@chex.dataclass
class Initiator:
    """State held by a track initiator before tracks are confirmed."""

    # tentative tracks are temporary tracks maintained by the initiator that
    # have been initialized but not yet confirmed
    # NOTE(review): annotated as a stack of 2-vectors, presumably (x, y)
    # positions -- confirm once this class is actually used.
    tentative_tracks: Num[Array, "... 2"]

In [37]:
import filterpy
from filterpy.kalman import KalmanFilter

# measurement z = [x, y]
# state       x = [x, y, dx/dt, dy/dt]


# yapf: disable
def F_cv(dt: float|int):
    return np.array([[1, 0, dt, 0],
                     [0, 1, 0, dt],
                     [0, 0, 1, 0],
                     [0, 0, 0, 1]])
# yapf: enable


# yapf: disable
def H_cv():
    """Measurement matrix reading the position (x, y) out of [x, y, dx/dt, dy/dt]."""
    return np.eye(2, 4, dtype=int)
# yapf: enable


# filterpy reference filter: 4-dim state [x, y, dx/dt, dy/dt], 2-dim measurement [x, y].
kf = KalmanFilter(4, 2)
T = 1.0  # time step
kf.F = F_cv(T)
kf.H = H_cv()
kf.R = np.diag([0.75, 0.75])  # measurement-noise covariance
# Process noise: a plain isotropic diagonal.
# NOTE(review): the Stone Soup cell uses ConstantVelocity(0.05) instead; the two
# Q matrices presumably differ -- confirm before comparing the pipelines.
kf.Q = np.diag([0.05, 0.05, 0.05, 0.05])
# A simple constant-velocity model: hypothesise an initial velocity of
# 0.05 units per dt in both the x and y directions.
kf.x = np.array([0, 0, 0.05, 0.05])
display(kf.P)  # P is not set here; the output below shows filterpy's initial value
kf.predict()
# after predict(), kf.x holds the a-priori (predicted) state
display(kf.x)
display(kf.P)

# kf.update([0.15, 0.15])
# after update(), kf.x holds the a-posteriori state

array([[1., 0., 0., 0.],
       [0., 1., 0., 0.],
       [0., 0., 1., 0.],
       [0., 0., 0., 1.]])

array([0.05, 0.05, 0.05, 0.05])

array([[2.05, 0.  , 1.  , 0.  ],
       [0.  , 2.05, 0.  , 1.  ],
       [1.  , 0.  , 1.05, 0.  ],
       [0.  , 1.  , 0.  , 1.05]])

In [31]:
# Predicted measurement: project the predicted (prior) state into measurement
# space. Use `x_prior` explicitly -- once `kf.update` below has run, `kf.x` is
# the posterior, so re-running this cell with `kf.x` would project the wrong state.
# `@` is the matrix-multiplication operator: https://peps.python.org/pep-0465/
predicted_measurement = kf.H @ kf.x_prior
display(predicted_measurement)

# Fuse an actual measurement. The Mahalanobis distance between the measurement
# and the prediction is displayed afterwards (`kf.mahalanobis`); a plain
# Euclidean distance would be the simpler alternative.
actual_measurement = np.array([0.12, 0.12])
kf.update(actual_measurement)
display(kf.x)
display(kf.P)
display(kf.mahalanobis)

# Next: use the Mahalanobis distance as the association cost and solve the
# assignment with the Hungarian algorithm.
from scipy.optimize import linear_sum_assignment
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.linear_sum_assignment.html
# When a tentative track is detected in successive frames, move it into the
# confirmed tracks and remove it from the tentative tracks. That is effectively
# two GNNs -- one for tentative tracks, one for confirmed tracks (a cascaded GNN).

array([0.05, 0.05])

array([0.10125, 0.10125, 0.075  , 0.075  ])

array([[0.54910714, 0.        , 0.26785714, 0.        ],
       [0.        , 0.54910714, 0.        , 0.26785714],
       [0.26785714, 0.        , 0.69285714, 0.        ],
       [0.        , 0.26785714, 0.        , 0.69285714]])

0.059160797830996155

In [52]:
# abstract out the Kalman filter
# by motion model (state)
# and measurement model

# https://github.com/sisl/GaussianFilters.jl
# not consider Input/External effect
from jaxtyping import jaxtyped
from typeguard import typechecked


@chex.dataclass
class LinearMotionNoInputModel:
    """Linear-Gaussian motion model without control input.

    Used by ``predict`` as x_prior = F @ x and P_prior = F @ P @ F.T + Q.
    """

    F: Num[Array, "n n"]  # state-transition matrix
    Q: Num[Array, "n n"]  # process-noise covariance


@chex.dataclass
class LinearMeasurementModel:
    """Linear-Gaussian measurement model.

    Used by ``update``: predicted measurement H @ x, innovation covariance
    S = H @ P @ H.T + R.
    """

    H: Num[Array, "m n"]  # maps an n-dim state to an m-dim measurement
    R: Num[Array, "m m"]  # measurement-noise covariance


# A single measurement vector z of length m (compared against H @ x in `update`).
Measurement = Num[Array, "m"]


# a belief of Gaussian
@chex.dataclass
class GaussianState:
    """Gaussian belief over the state: mean ``x`` and covariance ``P``.

    NOTE: this name shadows ``stonesoup.types.state.GaussianState`` imported
    at the top of the notebook.
    """

    x: Num[Array, "n"]  # state mean
    P: Num[Array, "n n"]  # state covariance


@jaxtyped(typechecker=typechecked)
def predict(
    state: GaussianState,
    motion_model: LinearMotionNoInputModel,
) -> GaussianState:
    """Kalman prediction step (no control input).

    Propagates the belief through the linear motion model:
        x_priori = F @ x
        P_priori = F @ P @ F.T + Q

    Args:
        state: current belief, mean ``x`` (length n) and covariance ``P`` (n x n).
        motion_model: transition matrix ``F`` (n x n) and process noise ``Q`` (n x n).

    Returns:
        The a-priori (predicted) belief.

    Raises:
        AssertionError: if the state and model shapes are incompatible.
    """
    x = state.x
    P = state.P
    F = motion_model.F
    Q = motion_model.Q
    n = x.shape[0]
    assert F.shape[0] == F.shape[1], "transition model is not square"
    # F @ x requires F's column count to equal the state dimension.
    assert F.shape[1] == n, "state and transition model are not compatible"
    assert P.shape == (n, n), "state covariance does not match the state"
    assert Q.shape == F.shape, "transition model and noise model are not compatible"
    x_priori = F @ x
    P_priori = F @ P @ F.T + Q
    return GaussianState(x=x_priori, P=P_priori)


@chex.dataclass
class PosterioriResult:
    """Output of one Kalman update step."""

    # updated (a-posteriori) belief
    state: GaussianState
    # innovation (pre-fit residual): y = z - H @ x_priori
    innovation: Num[Array, "m"]
    # measurement predicted from the updated state: H @ x_posteriori
    # (the post-fit residual would be z - posteriori_measurement)
    posteriori_measurement: Num[Array, "m"]
    # sqrt(y.T @ inv(S) @ y) is a scalar, so this is a 0-d array, not length m
    mahalanobis_distance: Num[Array, ""]


@jaxtyped(typechecker=typechecked)
def update(
    measurement: Measurement,
    state: GaussianState,
    measure_model: LinearMeasurementModel,
) -> PosterioriResult:
    """Kalman update (correction) step.

    Args:
        measurement: observed measurement ``z`` (length m).
        state: a-priori belief, mean ``x`` (length n) and covariance ``P`` (n x n).
        measure_model: measurement matrix ``H`` (m x n) and noise ``R`` (m x m).

    Returns:
        PosterioriResult holding the updated belief, the innovation
        ``y = z - H @ x``, the measurement predicted from the updated state,
        and the Mahalanobis distance ``sqrt(y.T @ inv(S) @ y)``.

    Raises:
        AssertionError: if the state and model shapes are incompatible.
    """
    x = state.x
    P = state.P
    H = measure_model.H
    R = measure_model.R
    m, n = H.shape
    assert x.shape[0] == n, "state and measurement model are not compatible"
    assert R.shape == (m, m), "measurement noise covariance must be m x m (m = rows of H)"
    z = measurement
    # innovation: the a-priori measurement residual
    y = z - H @ x
    # innovation covariance
    S = H @ P @ H.T + R
    # Kalman gain K = P @ H.T @ inv(S), obtained from a linear solve
    # (K.T = solve(S.T, (P @ H.T).T)) instead of an explicit matrix inverse.
    PHt = P @ H.T
    K = jnp.linalg.solve(S.T, PHt.T).T
    # posteriori state
    x_posteriori = x + K @ y
    # Joseph-form covariance update: keeps P symmetric positive semi-definite
    # even when K is not exactly the optimal gain.
    I_KH = jnp.eye(P.shape[0]) - K @ H
    P_posteriori = I_KH @ P @ I_KH.T + K @ R @ K.T
    posteriori_state = GaussianState(x=x_posteriori, P=P_posteriori)
    posteriori_measurement = H @ x_posteriori
    # y.T @ inv(S) @ y, again via a solve
    mahalanobis_distance = jnp.sqrt(y @ jnp.linalg.solve(S, y))
    return PosterioriResult(
        state=posteriori_state,
        innovation=y,
        posteriori_measurement=posteriori_measurement,
        mahalanobis_distance=mahalanobis_distance,
    )


In [54]:
def cv_model(
    v_x: float,
    v_y: float,
    dt: float,
    q: float,
    r: float,
    p: float = 1.0,
) -> Tuple[
        LinearMotionNoInputModel,
        LinearMeasurementModel,
        GaussianState,
]:
    """
    Create a constant-velocity model with no control input.

    The state is [x, y, dx/dt, dy/dt]; the measurement observes [x, y].

    Args:
        v_x: initial velocity in x direction
        v_y: initial velocity in y direction
        dt: time interval between steps
        q: process-noise variance (Q = q * I_4)
        r: measurement-noise variance (R = r * I_2)
        p: initial state variance (P = p * I_4); default 1.0 gives the identity

    Returns:
        motion_model: motion model (F, Q)
        measure_model: measurement model (H, R)
        state: initial state at the origin with velocity (v_x, v_y)
    """
    # Float literals keep F, H and x floating-point even when dt / v_x / v_y
    # are passed as Python ints (otherwise they become integer arrays).
    # yapf: disable
    F = jnp.array([[1.0, 0.0,  dt, 0.0],
                   [0.0, 1.0, 0.0,  dt],
                   [0.0, 0.0, 1.0, 0.0],
                   [0.0, 0.0, 0.0, 1.0]])
    H = jnp.array([[1.0, 0.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0, 0.0]])
    # yapf: enable
    Q = q * jnp.eye(4)
    R = r * jnp.eye(2)
    P = p * jnp.eye(4)
    motion_model = LinearMotionNoInputModel(F=F, Q=Q)
    measure_model = LinearMeasurementModel(H=H, R=R)
    state = GaussianState(x=jnp.array([0.0, 0.0, v_x, v_y]), P=P)
    return motion_model, measure_model, state

In [36]:
# Run the JAX filter with the same settings as the filterpy cell
# (initial velocity 0.05, dt = 1.0, q = 0.05, r = 0.75, P = identity);
# the outputs below match filterpy's up to float32 rounding.
# NOTE(review): this cell ran as In [36], before the cells defining
# predict/update/cv_model (In [52], In [54]) -- re-run in order to confirm.
mo_model, me_model, st = cv_model(0.05, 0.05, 1.0, 0.05, 0.75)

# predict
new_st = predict(st, mo_model)
# update with the same measurement used in the filterpy cell
res = update(jnp.array([0.12, 0.12]), new_st, me_model)
display(new_st)
display(res)

GaussianState(x=Array([0.05, 0.05, 0.05, 0.05], dtype=float32), P=Array([[2.05, 0.  , 1.  , 0.  ],
       [0.  , 2.05, 0.  , 1.  ],
       [1.  , 0.  , 1.05, 0.  ],
       [0.  , 1.  , 0.  , 1.05]], dtype=float32))

PosterioriResult(state=GaussianState(x=Array([0.10124999, 0.10124999, 0.075     , 0.075     ], dtype=float32), P=Array([[0.54910713, 0.        , 0.26785713, 0.        ],
       [0.        , 0.54910713, 0.        , 0.26785713],
       [0.26785713, 0.        , 0.6928571 , 0.        ],
       [0.        , 0.26785713, 0.        , 0.6928571 ]], dtype=float32)), innovation=Array([0.06999999, 0.06999999], dtype=float32), posteriori_measurement=Array([0.10124999, 0.10124999], dtype=float32), mahalanobis_distance=Array(0.05916079, dtype=float32))

In [59]:
# Show the raw field dictionary of the initial state.
vars(st)

{'x': Array([0.  , 0.  , 0.05, 0.05], dtype=float32),
 'P': Array([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.]], dtype=float32)}

In [61]:
# Pack the field dictionaries of the initial and predicted states into one
# awkward array of records.
s = ak.Array([vars(state) for state in (st, new_st)])
display(s)

In [66]:
# display(s.layout)
# v = ak.Array([])
# v.layout = s.layout
# display(v.typestr)

'2 * {x: var * float64, P: var * var * float64}'

In [73]:
from dataclasses import dataclass

# Short alias for the awkward high-level array type, used in annotations below.
AKArray = ak.Array

# register the JAX backend
# (presumably required before `ak.Array(..., backend="jax")` is used below)
ak.jax.register_and_check()  # type: ignore

# https://awkward-array.org/doc/main/user-guide/how-to-create-records.html
class Track(ak.Record):
    """Intended schema of one track: state mean, covariance and age counters.

    NOTE(review): these annotations only document the fields. Nothing in the
    visible code registers this class in ``ak.behavior`` or creates arrays
    with ``with_name="Track"``, so awkward will not yet produce ``Track``
    instances -- confirm the intended usage.
    """

    x: Num[Array, "4"]  # state mean
    P: Num[Array, "4 4"]  # state covariance
    survived_time_steps: int
    missed_time_steps: int


class Tracker:
    """Skeleton of a two-stage (tentative -> confirmed) multi-target tracker.

    Track sets are stored as awkward arrays on the JAX backend.
    """

    # [n {x: Num[Array, "4"], P: Num[Array, "4 4"], survived_time_steps:int, missed_time_steps:int}]
    _tentative_tracks: AKArray
    # [m {x: Num[Array, "4"], P: Num[Array, "4 4"]}]
    _confirmed_tracks: AKArray

    def __init__(self) -> None:
        # Create the track storage per instance. As class attributes, these
        # arrays were built once at class definition and shared by every
        # Tracker instance.
        # NOTE(review): `with_name` names record arrays; on these empty,
        # untyped arrays it presumably has no effect -- confirm.
        self._tentative_tracks = ak.Array([],
                                          with_name="tentative_tracks",
                                          backend="jax")
        self._confirmed_tracks = ak.Array([],
                                          with_name="confirmed_tracks",
                                          backend="jax")

    def next_measurements(self, measurements: Float[Array, "... 2"]):
        """Consume the next batch of (x, y) measurements. Not implemented yet."""
        ...

    @property
    def tentative_tracks(self) -> AKArray:
        """Tracks that have been initiated but not yet confirmed."""
        return self._tentative_tracks

    @property
    def confirmed_tracks(self) -> AKArray:
        """Tracks that have been confirmed."""
        return self._confirmed_tracks

    @staticmethod
    def motion_model(dt: float = 1.0,
                     q: float = 0.05) -> LinearMotionNoInputModel:
        """Constant-velocity motion model for the state [x, y, dx/dt, dy/dt].

        Float literals keep ``F`` floating-point even for an integer ``dt``.
        Process noise is ``Q = q * I_4``.
        """
        # yapf: disable
        F = jnp.array([[1.0, 0.0,  dt, 0.0],
                       [0.0, 1.0, 0.0,  dt],
                       [0.0, 0.0, 1.0, 0.0],
                       [0.0, 0.0, 0.0, 1.0]])
        # yapf: enable
        Q = q * jnp.eye(4)
        return LinearMotionNoInputModel(F=F, Q=Q)

    @staticmethod
    def measurement_model(r: float = 0.75) -> LinearMeasurementModel:
        """Measurement model observing the position (x, y), noise ``R = r * I_2``."""
        # float [[1, 0, 0, 0], [0, 1, 0, 0]]
        H = jnp.eye(2, 4)
        R = r * jnp.eye(2)
        return LinearMeasurementModel(H=H, R=R)


# Instantiate a tracker; its tentative and confirmed track arrays start empty.
tracker = Tracker()
