# In [5]:
import datetime
from itertools import tee

import numpy as np

from stonesoup.types.array import StateVector, CovarianceMatrix
from stonesoup.types.state import GaussianState

# --- Simulation parameters ---------------------------------------------------
# Time between consecutive simulation steps and the number of steps to run.
timestep_size = datetime.timedelta(seconds=5)
number_of_steps = 25

# Target birth rate and death probability passed to the ground-truth
# simulators.
birth_rate = 0.3
death_probability = 0.05

# Gaussian initial state: zero mean with a diagonal covariance.
initial_state_mean = StateVector([[0], [0], [0], [0]])
initial_state_covariance = CovarianceMatrix(np.diag([4, 0.5, 4, 0.5]))
initial_state = GaussianState(
    initial_state_mean,
    initial_state_covariance,
)

from stonesoup.models.transition.linear import (
    CombinedLinearGaussianTransitionModel, ConstantVelocity)

# Combine two separate ConstantVelocity models (each constructed with 0.05)
# into a single linear-Gaussian transition model.
transition_model = CombinedLinearGaussianTransitionModel(
    [ConstantVelocity(0.05) for _ in range(2)])

from stonesoup.simulator.simple import MultiTargetGroundTruthSimulator

# Settings shared by every ground-truth simulator.
groundtruth_sim_settings = dict(
    transition_model=transition_model,
    initial_state=initial_state,
    timestep=timestep_size,
    number_steps=number_of_steps,
    birth_rate=birth_rate,
    death_probability=death_probability,
    initial_number_targets=100,
)

# Three separate simulators, all built from the same settings.
groundtruth_sims = [
    MultiTargetGroundTruthSimulator(**groundtruth_sim_settings)
    for _ in range(3)
]

from stonesoup.simulator.simple import SimpleDetectionSimulator
from stonesoup.models.measurement.linear import LinearGaussian

# Measurement model: 4-dimensional state, measuring state indices 0 and 2,
# with a diagonal 2x2 measurement noise covariance.
measurement_model_covariance = 0.25 * np.eye(2)
measurement_model = LinearGaussian(
    ndim_state=4,
    mapping=[0, 2],
    noise_covar=measurement_model_covariance,
)

# Probability of detection
probability_detection = 0.9

# Clutter is generated uniformly in this area around the target
clutter_rate = 50
clutter_area = 500 * np.array([[-1, 1], [-1, 1]])

# Settings common to all detection simulators.
detection_sim_settings = dict(
    measurement_model=measurement_model,
    detection_probability=probability_detection,
    meas_range=clutter_area,
    clutter_rate=clutter_rate,
)

# One detection simulator wrapped around each ground-truth simulator.
detection_sims = [
    SimpleDetectionSimulator(groundtruth=gt_sim, **detection_sim_settings)
    for gt_sim in groundtruth_sims
]

# Split every detection simulator into 3 identical streams with tee (for GNN,
# k-D tree and TPR-Tree), then regroup so that sim_sets[k] holds the k-th
# stream of each simulator.
sim_sets = list(zip(*(tee(sim, 3) for sim in detection_sims)))

# tracker predictor and updater
from stonesoup.predictor.kalman import KalmanPredictor
from stonesoup.updater.kalman import KalmanUpdater

# initiator and deleter
from stonesoup.deleter.error import CovarianceBasedDeleter
from stonesoup.initiator.simple import MultiMeasurementInitiator

# tracker
from stonesoup.tracker.simple import MultiTargetTracker

# timer for comparing run time of k-d tree algorithm with linear search
import time as timer

# Kalman predictor and updater handed to the hypothesiser and data associator
predictor, updater = (
    KalmanPredictor(transition_model),
    KalmanUpdater(measurement_model),
)

# Hypothesiser scoring with the Mahalanobis measure
from stonesoup.hypothesiser.distance import DistanceHypothesiser
from stonesoup.measures import Mahalanobis

hypothesiser = DistanceHypothesiser(
    predictor,
    updater,
    measure=Mahalanobis(),
    missed_distance=3,
)

# KD-tree based data associator built on the hypothesiser above
from stonesoup.dataassociator.tree import DetectionKDTreeGNN2D

KD_data_associator = DetectionKDTreeGNN2D(
    hypothesiser=hypothesiser,
    predictor=predictor,
    updater=updater,
    number_of_neighbours=3,
    max_distance_covariance_multiplier=3,
)

# Elapsed perf_counter time of each k-d tree tracking run, one entry per
# simulation
run_times_KDTree = []

# Track every simulation in the first tee'd set, timing each complete run
# (component construction included)
for n, detection_sim in enumerate(sim_sets[0]):
    start_time = timer.perf_counter()

    # Kalman components for this run; the shared KD_data_associator is reused
    # as it is
    predictor = KalmanPredictor(transition_model)
    updater = KalmanUpdater(measurement_model)

    # Deleter configured with a covariance trace threshold
    covariance_limit_for_delete = 100
    deleter = CovarianceBasedDeleter(
        covar_trace_thresh=covariance_limit_for_delete)

    # Initiator configured with a prior state and a minimum number of points
    min_detections = 3
    s_prior_state = GaussianState(
        [[0], [0], [0], [0]],
        np.diag([0, 0.5, 0, 0.5]),
    )
    initiator = MultiMeasurementInitiator(
        prior_state=s_prior_state,
        measurement_model=measurement_model,
        deleter=deleter,
        data_associator=KD_data_associator,
        updater=updater,
        min_points=min_detections,
    )

    # Tracker fed by this run's detection stream
    tracker = MultiTargetTracker(
        initiator=initiator,
        deleter=deleter,
        detector=detection_sim,
        data_associator=KD_data_associator,
        updater=updater,
    )

    # Step through the tracker, accumulating everything seen in this run.
    # The sets are reset on each pass, so after the loop they hold the
    # results of the final run only.
    groundtruth, detections, tracks = set(), set(), set()
    for time, ctracks in tracker:
        groundtruth.update(groundtruth_sims[n].groundtruth_paths)
        detections.update(detection_sims[n].detections)
        tracks.update(ctracks)

    # Record how long this run took
    end_time = timer.perf_counter()
    run_time = end_time - start_time
    run_times_KDTree.append(run_time)
    
from stonesoup.plotter import Plotterly

# Plot the ground truth, detections and tracks collected by the final
# tracking run on a single figure, using state indices 0 and 2.
plotter = Plotterly()
for draw, items in (
        (plotter.plot_ground_truths, groundtruth),
        (plotter.plot_measurements, detections),
        (plotter.plot_tracks, tracks)):
    draw(items, mapping=[0, 2])
plotter.fig