In [None]:
"""
MCMC-Based Particle Filter for Multi-Target Tracking
====================================================

This script demonstrates:
 1) Ground-truth generation for up to 5 targets (with births/deaths).
 2) Poisson clutter + target-origin measurements.
 3) An MCMC-based particle filter to estimate the (x, y) positions and (vx, vy) velocities
    of each target, and whether each target exists or not at each time step.
 4) Plots:
    - True tracks vs. estimated tracks (weighted average of the particles) in x-y plane.
    - Cardinality (true vs. estimated).
 5) A progress bar (via tqdm) to indicate the filtering progress.

"""

import numpy as np
import matplotlib.pyplot as plt
from numpy.random import rand, randn, choice
from scipy.stats import poisson, multivariate_normal
from tqdm import tqdm  # for progress bar

# ----------------------------------------------------------------------
# 1) GROUND-TRUTH GENERATION
# ----------------------------------------------------------------------
def generate_ground_truth(K=80,  # total number of timesteps
                          tau=3.0,
                          sigma_process=0.5,
                          scenario_params=None):
    """
    Generate ground-truth states for up to 5 targets, with known birth times
    and at least one known death time. Targets follow a near-constant velocity (NCV) model.

    Every state row uses the layout (x, vx, y, vy): positions at indices 0 and 2,
    velocities at indices 1 and 3. The NCV matrices below are built for exactly
    that layout (each axis evolves as [pos, vel] -> [pos + tau*vel, vel]).

    Timing rules (time index k = 0..K-1):
      - A target is alive at index k when ``k >= birth_time`` and
        (``death_time`` is None or ``k < death_time``). As a special case,
        ``birth_time <= 1`` also makes the target alive at index 0.
      - On the first index a target is alive it takes ``init_state`` exactly;
        afterwards it is propagated with the NCV model plus Gaussian noise.
      - A dead target has an all-zero state and ``e_true == 0``.

    Parameters
    ----------
    K : int
        Number of timesteps.
    tau : float
        Sampling interval (seconds).
    sigma_process : float
        Standard deviation of the process noise for NCV model (squared to form Q).
    scenario_params : dict
        Dictionary specifying each target's birth_time, death_time, and initial state,
        indexed by target number 0..Nmax-1.

    Returns
    -------
    ground_truth : dict
        Contains:
          x_true : array, shape=(K, Nmax, 4), the true states for each target
                   (x, vx, y, vy).
          e_true : array, shape=(K, Nmax), the existence indicators (0 or 1).
    """

    # Default scenario: 5 targets
    #   T1, T2, T3 born at k=1
    #   T4, T5 born at k=25
    #   T1 dies at k=50
    if scenario_params is None:
        scenario_params = {
            0: {'birth_time': 1,
                'death_time': 50,
                'init_state': np.array([500.0,   1.0,  500.0,   15.0])},
            1: {'birth_time': 1,
                'death_time': None,
                'init_state': np.array([1000.0, -1.0, 4000.0, -10.0])},
            2: {'birth_time': 1,
                'death_time': None,
                'init_state': np.array([4000.0,   5.0, 1000.0,   5.0])},
            3: {'birth_time': 25,
                'death_time': None,
                'init_state': np.array([4500.0,  -5.0, 4500.0,   5.0])},
            4: {'birth_time': 25,
                'death_time': None,
                'init_state': np.array([500.0,   10.0, 4500.0,  -15.0])}
        }

    Nmax = len(scenario_params)  # should be 5
    x_true = np.zeros((K, Nmax, 4))
    e_true = np.zeros((K, Nmax), dtype=int)

    # NCV transition for the (x, vx, y, vy) layout: one 2x2 block per axis.
    F_axis = np.array([[1.0, tau],
                       [0.0, 1.0]])
    A = np.kron(np.eye(2), F_axis)

    # Discretised white-acceleration process noise, one 2x2 block per axis.
    q = sigma_process**2
    Q_axis = np.array([[tau**3 / 3, tau**2 / 2],
                       [tau**2 / 2, tau]])
    Q = q * np.kron(np.eye(2), Q_axis)

    def _is_alive(params, k):
        """Return True if the target described by ``params`` exists at index k."""
        bt = params['birth_time']
        dt = params['death_time']
        if bt is None:
            return False
        # birth_time <= 1 means "present from the first sample" (index 0).
        born = k >= bt or (k == 0 and bt <= 1)
        return born and (dt is None or k < dt)

    for k in range(K):
        for n in range(Nmax):
            params = scenario_params[n]
            if not _is_alive(params, k):
                # Dead target: state and existence stay at their zero defaults.
                continue

            e_true[k, n] = 1
            if k == 0 or e_true[k - 1, n] == 0:
                # First step of life: start exactly at the scenario's initial state.
                x_true[k, n, :] = params['init_state']
            else:
                # Continue the NCV model from the previous state.
                mean = A @ x_true[k - 1, n, :]
                x_true[k, n, :] = np.random.multivariate_normal(mean, Q)

    return {'x_true': x_true, 'e_true': e_true}


# ----------------------------------------------------------------------
# 2) MEASUREMENT GENERATION
# ----------------------------------------------------------------------
def generate_measurements(ground_truth,
                          Lx=5000.0, 
                          Ly=5000.0,
                          LambdaC=20.0,  # Mean clutter rate (Poisson)
                          LambdaX=1.0,   # Mean detection rate per active target (Poisson)
                          Sigma_x=100.0, # Measurement variance (Sigma_x * I2)
                          seed=None):
    """
    Generate a set of measurements for each time step `k` based on the
    'association-free' model described in the article.

    Parameters
    ----------
    ground_truth : dict
        Dictionary containing at least:
          - ground_truth['x_true']: array of shape (K, Nmax, 4),
            where x_true[k, n, :] = (x_k,n, vx_k,n, y_k,n, vy_k,n).
          - ground_truth['e_true']: array of shape (K, Nmax),
            where e_true[k, n] = 1 if target `n` exists at time `k`, otherwise 0.
    Lx, Ly : float
        Dimensions of the surveillance region in meters.
    LambdaC : float
        Mean number of clutter (false alarms) per time step.
    LambdaX : float
        Mean number of detections per active target and per time step.
    Sigma_x : float
        Measurement variance in both x and y directions (covariance is Sigma_x * I2).
    seed : int or None
        Seed for the random number generator (for reproducibility).
        If None, results will be random.

    Returns
    -------
    measurements : list
        A list of length K.
        measurements[k] is a numpy array of shape (M_k, 2),
        where M_k is the total number of measurements at time step `k` (target + clutter).
    """
    if seed is not None:
        np.random.seed(seed)  # For optional reproducibility

    x_true = ground_truth['x_true']  # (K, Nmax, 4)
    e_true = ground_truth['e_true']  # (K, Nmax)
    K, Nmax, _ = x_true.shape

    # Measurement covariance matrix: Sigma_x * I2
    R_meas = Sigma_x * np.eye(2)

    # List to store measurements for each time step k
    measurements = []

    # Loop through each time step k
    for k in range(K):
        # -- 1) Identify active targets
        active_targets = np.where(e_true[k, :] == 1)[0]

        # -- 2) Generate target-originated measurements
        #     For each active target, draw the number of measurements from Poisson(LambdaX).
        target_meas_list = []
        for n in active_targets:
            # Number of measurements for target `n`
            count_target_n = poisson.rvs(LambdaX)
            if count_target_n > 0:
                # Mean of the Gaussian distribution for target `n`: (x_k,n, y_k,n)
                mean_xy = np.array([
                    x_true[k, n, 0],  # x_k,n
                    x_true[k, n, 2]   # y_k,n
                ])
                # Generate Gaussian measurements
                z_n = np.random.multivariate_normal(mean_xy, R_meas, size=count_target_n)
                target_meas_list.append(z_n)

        # -- 3) Generate clutter (false alarms)
        #     Draw the number of clutter points from Poisson(LambdaC).
        clutter_count = poisson.rvs(LambdaC)
        # Clutter measurements are uniformly distributed in [0, Lx] x [0, Ly].
        clutter_meas = np.column_stack([
            np.random.uniform(0, Lx, size=clutter_count),
            np.random.uniform(0, Ly, size=clutter_count)
        ])

        # -- 4) Combine target and clutter measurements
        if len(target_meas_list) > 0:
            all_target_meas = np.vstack(target_meas_list)  # Concatenate all target measurements
            all_meas = np.vstack([all_target_meas, clutter_meas])
        else:
            # If there are no active targets or Poisson(LambdaX)=0, only clutter measurements remain
            all_meas = clutter_meas

        measurements.append(all_meas)

    return measurements


# ----------------------------------------------------------------------
# 3) MCMC-BASED PARTICLE FILTER CLASS
# ----------------------------------------------------------------------
class MCMCParticleFilter:
    """
    MCMC-Based Particle Filter for a time-varying number of targets, as described
    in Section III–IV of the article.

    The targets are modeled with:
      - Existence variables e_{k,n} in {0,1} (Eq. (6))
      - Kinematics x_{k,n} = (x, vx, y, vy) (fixed dimension Nmax x 4)
      - Transition densities:
         * Birth:   Eq. (8)
         * Death:   Eq. (9)
         * Update:  Eq. (10)
      - Association-free measurement model: Eqs. (12)–(13)
      - MCMC approach: Eqs. (14)–(15) for refinement

    Every per-target state row uses the layout ``(x, vx, y, vy)``:
      - positions are at indices 0 and 2 (read by the measurement model,
        written by the birth sampler);
      - velocities are at indices 1 and 3.
    The NCV matrices ``self.A`` and ``self.Q`` are built for that same layout.

    Particles are plain dicts ``{'x': (Nmax, 4) array, 'e': (Nmax,) int array,
    'w': float log-weight}``.
    """

    def __init__(self,
                 Np=4000,
                 Nburn=1000,
                 thinning=6,
                 PB=0.01,       # Probability of birth, Eq. (6)
                 PD=0.01,       # Probability of death, Eq. (6)
                 sigma_process=0.5,  # process-noise std. dev. (squared to build Q)
                 tau=3.0,       # sampling interval
                 R=100.0,       # measurement noise covariance = 100 * I2
                 LambdaC=20.0,  # clutter rate
                 LambdaX=1.0,   # detection rate per active target
                 Lx=5000.0,
                 Ly=5000.0,
                 x_death=None,
                 Vmax=20.0,     # max speed for birth sampling
                 mcmc_step=0.05):  # std. dev. of the random-walk MCMC proposal
        """
        Initialize MCMC-based Particle Filter with default parameters from the article:
          - Np=4000 (num. particles)
          - Nburn=1000 (burn-in for MCMC)
          - thinning=6
          - PB, PD in [0,1]
          - sigma_process=0.5 (standard deviation; Q is scaled by sigma_process**2,
            matching generate_ground_truth)
          - R=100 => CovMeas = 100 * I2
          - LambdaC=20, LambdaX=1 => Poisson parameters for clutter/target measurements
          - Lx=5000, Ly=5000 => surveillance area
          - x_death: placeholder state for 'dead' target (Eq. (9))
          - Vmax=20 => uniform speed range in [-Vmax, Vmax] at birth (Eq. (8)).
          - mcmc_step=0.05 => per-component std. dev. of the Gaussian random-walk
            proposal used in the refinement step.
        """
        self.Np = Np
        self.Nburn = Nburn
        self.thinning = thinning
        self.mcmc_step = mcmc_step

        # Eq. (6) parameters for birth/death processes
        self.PB = PB  # Probability to become alive if previously dead
        self.PD = PD  # Probability to become dead if previously alive

        # NCV model parameters (Eq. (10) and (11))
        self.sigma_process = sigma_process
        self.tau = tau

        # Observation model parameters
        self.R = R
        self.LambdaC = LambdaC
        self.LambdaX = LambdaX

        # Surveillance area and speeds
        self.Lx = Lx
        self.Ly = Ly
        self.Vmax = Vmax

        # "Death" state for inactive targets (Eq. (9)); zeros by default.
        if x_death is None:
            self.x_death = np.zeros(4, dtype=float)
        else:
            self.x_death = np.asarray(x_death, dtype=float)

        # NCV state transition A_k,n (Eq. (11)) for the (x, vx, y, vy) layout:
        # each axis evolves as [pos, vel] -> [pos + tau*vel, vel].
        F_axis = np.array([[1.0, self.tau],
                           [0.0, 1.0]])
        self.A = np.kron(np.eye(2), F_axis)

        # Process noise Q_k,n (Eq. (11)): sigma^2 * blockdiag(Q_axis, Q_axis).
        q = self.sigma_process**2
        Q_axis = np.array([[self.tau**3 / 3, self.tau**2 / 2],
                           [self.tau**2 / 2, self.tau]])
        self.Q = q * np.kron(np.eye(2), Q_axis)

        # Measurement covariance = R * I2
        self.CovMeas = self.R * np.eye(2)

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 0: Initialize Particle Set
    # ----------------------------------------------------------------
    def init_particles(self, Nmax=5):
        r"""
        Create the initial particle set, approximating \(p(x_0, e_0)\).

        The article does not fully specify the initialisation, so every
        particle starts with all targets inactive (\(e_{0,n}=0\)). The
        kinematics are set to the death placeholder ``self.x_death`` (zeros
        unless the user supplied one) and the log-weight to 0.

        Returns
        -------
        particles : list of dict
            ``self.Np`` particles of the form ``{'x', 'e', 'w'}``.
        """
        return [
            {'x': np.tile(self.x_death, (Nmax, 1)),
             'e': np.zeros(Nmax, dtype=int),
             'w': 0.0}
            for _ in range(self.Np)
        ]

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 1: Sample Existence Variables (Eq. (6))
    # ----------------------------------------------------------------
    def sample_existence(self, e_prev):
        r"""
        Sample each target's existence variable \( e_{k,n} \) given
        \( e_{k-1,n} \), following Eq. (6):

        \[
          p(e_{k,n}=1 \mid e_{k-1,n}=1) = 1 - P_D, \quad
          p(e_{k,n}=1 \mid e_{k-1,n}=0) = P_B.
        \]

        Parameters
        ----------
        e_prev : ndarray of shape (Nmax,)
            Existence vector at the previous time step.

        Returns
        -------
        e_new : ndarray of shape (Nmax,), dtype int
            Sampled existence vector at the current time step.
        """
        e_prev = np.asarray(e_prev)
        u = np.random.rand(len(e_prev))
        # Alive targets survive when u >= PD; dead targets are born when u < PB.
        alive_next = np.where(e_prev == 1, u >= self.PD, u < self.PB)
        return alive_next.astype(int)

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 2: Sample States p(x_{k,n} | x_{k-1,n}, e_{k,n}, e_{k-1,n}), Eq. (7)
    # ----------------------------------------------------------------
    def sample_motion(self, x_prev, e_prev, e_curr):
        r"""
        Sample each target's kinematic state \(x_{k,n}\) given
        \( x_{k-1,n} \) and the existence variables \( e_{k,n}, e_{k-1,n}\) (Eq. (7)):

          - birth  \(\{e_{k,n}=1, e_{k-1,n}=0\}\): uniform position in the
            surveillance area, uniform velocity in [-Vmax, Vmax] (Eq. (8));
          - death  \(e_{k,n}=0\): the fixed placeholder ``self.x_death`` (Eq. (9));
          - update \(\{e_{k,n}=1, e_{k-1,n}=1\}\): \(\mathcal{N}(A x_{k-1,n}, Q)\) (Eq. (10)).

        Parameters
        ----------
        x_prev : ndarray of shape (Nmax, 4)
            States of the previous time step, layout (x, vx, y, vy).
        e_prev : ndarray of shape (Nmax,)
            Existence vector at the previous time step.
        e_curr : ndarray of shape (Nmax,)
            Existence vector at the current time step.

        Returns
        -------
        x_curr : ndarray of shape (Nmax, 4)
            Sampled states at the current time step.
        """
        Nmax = len(e_prev)
        x_curr = np.zeros((Nmax, 4), dtype=float)

        for n in range(Nmax):
            if (e_prev[n] == 0) and (e_curr[n] == 1):
                # Birth => p_b(x_{k,n}), Eq. (8). Draw order: x, y, vx, vy.
                x_birth = np.zeros(4, dtype=float)
                x_birth[0] = np.random.uniform(0, self.Lx)               # x
                x_birth[2] = np.random.uniform(0, self.Ly)               # y
                x_birth[1] = np.random.uniform(-self.Vmax, self.Vmax)    # vx
                x_birth[3] = np.random.uniform(-self.Vmax, self.Vmax)    # vy
                x_curr[n, :] = x_birth
            elif e_curr[n] == 0:
                # Death => p_d(x_{k,n}) = δ(x_death), Eq. (9)
                x_curr[n, :] = self.x_death
            else:
                # Update => p_u(x_{k,n} | x_{k-1,n}) = N(A x_{k-1,n}, Q), Eq. (10)
                mean = self.A @ x_prev[n, :]
                x_curr[n, :] = np.random.multivariate_normal(mean, self.Q)

        return x_curr

    # ----------------------------------------------------------------
    # LOG-LIKELIHOOD p(z_k | x_k), Eqs. (12)–(13)
    # ----------------------------------------------------------------
    def log_likelihood_measurements(self, z, x, e, LambdaC, LambdaX):
        r"""
        Compute \(\log p(z_k | x_k)\) under the association-free model (Eqs. (12)–(13)):

        \[
          p(z_k \mid x_k) = \frac{\exp(-\mu_k)}{M_k!} \prod_{m=1}^{M_k} \lambda(z_k^{(m)}),
          \quad \mu_k = \Lambda_C + \sum_{n \in \text{active}} \Lambda_X,
        \]
        \[
          \lambda(z) = \Lambda_C\, p_C(z) + \sum_{n \in \text{active}} \Lambda_X\, p_x(z \mid x_{k,n}),
        \]
        with \(p_C\) uniform on \([0,L_x]\times[0,L_y]\) and
        \(p_x(z \mid x_{k,n}) = \mathcal{N}(z; [x_{k,n}, y_{k,n}], \text{CovMeas})\).

        The intensities \(\lambda(z_m)\) are evaluated for all measurements at
        once, and the product is accumulated in the log domain.

        Parameters
        ----------
        z : ndarray, shape (M_k, 2)
            All measurements at time k (an empty array means no measurements).
        x : ndarray, shape (Nmax, 4)
            The states (x, vx, y, vy) for each target.
        e : ndarray, shape (Nmax,)
            Existence indicator for each target (0 or 1).
        LambdaC : float
            Mean clutter rate \(\Lambda_C\).
        LambdaX : float
            Mean detection rate per active target \(\Lambda_X\).

        Returns
        -------
        log_like : float
            The log of p(z_k | x_k), or the sentinel -1e16 when any
            \(\lambda(z_m)\) underflows below 1e-300.
        """
        z = np.asarray(z, dtype=float).reshape(-1, 2)
        e = np.asarray(e)

        N_active = np.sum(e)
        mu_k = LambdaC + N_active * LambdaX
        M_k = z.shape[0]

        # log(M_k!) as a sum of logs; the sum is empty (0) for M_k in {0, 1}.
        log_factorial_M = np.sum(np.log(np.arange(1, M_k + 1)))
        log_like = -mu_k - log_factorial_M
        if M_k == 0:
            return float(log_like)

        # Clutter contribution: LambdaC times the uniform density 1 / (Lx * Ly).
        clutter_density = 1.0 / (self.Lx * self.Ly)
        lam = np.full(M_k, LambdaC * clutter_density)

        # Add each active target's Gaussian intensity at every measurement.
        for n in np.flatnonzero(e == 1):
            mean_xy = np.array([x[n, 0], x[n, 2]], dtype=float)
            # atleast_1d: scipy returns a scalar when M_k == 1.
            lam += LambdaX * np.atleast_1d(
                multivariate_normal.pdf(z, mean=mean_xy, cov=self.CovMeas))

        # Any vanishing intensity => log -> -inf; use a large finite sentinel.
        if np.any(lam < 1e-300):
            return -1e16
        return float(log_like + np.sum(np.log(lam)))

    # ---------------------------------------------------------
    # Private helpers for the filtering loop
    # ---------------------------------------------------------
    def _predict_particle(self, particle):
        """Propagate one particle through the existence and motion priors (Eqs. (6)-(7))."""
        e_new = self.sample_existence(particle['e'])
        x_new = self.sample_motion(particle['x'], particle['e'], e_new)
        return {'x': x_new, 'e': e_new, 'w': particle['w']}

    def _refine_particle(self, particle, z_k):
        """
        Run the Metropolis refinement chain on one predicted particle, in place.

        Each of the ``Nburn + thinning`` sweeps proposes, for every alive
        target, a Gaussian random-walk move (std. ``mcmc_step``) of its state.
        The move is accepted with probability min(1, exp(ll_new - ll_cur)), so
        the chain targets the measurement likelihood only. Dead targets are
        skipped: after ``sample_motion`` they already hold ``x_death`` (the
        only value the original proposal could give them), and the likelihood
        ignores them. The current log-likelihood is cached between moves, and
        on exit the particle's log-weight is set to the chain's final value.
        """
        x_cur = particle['x'].copy()
        e_cur = particle['e'].copy()
        ll_cur = self.log_likelihood_measurements(
            z_k, x_cur, e_cur, self.LambdaC, self.LambdaX)
        alive = np.flatnonzero(e_cur == 1)

        for _ in range(self.Nburn + self.thinning):
            for n in alive:
                x_prop = x_cur.copy()
                x_prop[n, :] += self.mcmc_step * randn(4)
                ll_prop = self.log_likelihood_measurements(
                    z_k, x_prop, e_cur, self.LambdaC, self.LambdaX)
                log_ratio = ll_prop - ll_cur
                # Metropolis test, written so exp() is only taken of a value
                # <= 0; exp of a large positive log_ratio would overflow.
                if log_ratio >= 0.0 or np.random.rand() < np.exp(log_ratio):
                    x_cur, ll_cur = x_prop, ll_prop

        particle['x'] = x_cur
        particle['e'] = e_cur
        particle['w'] = ll_cur

    def _resample(self, particles):
        """Multinomially resample ``self.Np`` particles by log-weight; reset weights to 0."""
        log_w = np.array([p['w'] for p in particles], dtype=float)
        # Subtract the max before exponentiating to avoid underflow.
        w_lin = np.exp(log_w - np.max(log_w))
        w_lin /= np.sum(w_lin)

        indices = choice(self.Np, size=self.Np, replace=True, p=w_lin)
        return [{'x': particles[idx]['x'].copy(),
                 'e': particles[idx]['e'].copy(),
                 'w': 0.0}
                for idx in indices]

    # ---------------------------------------------------------
    # Main Filtering Loop (MCMC-based Particle Algorithm)
    # ---------------------------------------------------------
    def filter(self, measurements, Nmax=5):
        """
        Perform the MCMC-based particle filtering over a sequence of measurements
        according to the approach in Section III-B and Algorithm 1 of the paper.

        For each time step k = 0..K-1:
          (a/b) Prediction / joint draw: sample e_k from e_{k-1} (Eq. (6)) and
                x_k from x_{k-1} (Eq. (7)) for every particle.
          (c)   Refinement: a short Metropolis chain over each alive target's
                state (random-walk proposal scored by the likelihood only;
                existence flags are not moved).
          (d)   Weighting: each particle's log-weight becomes log p(z_k | x_k),
                Eqs. (12)–(13).
          (e)   Resampling (multinomial), then the cardinality statistics are
                computed from the resampled set.

        Parameters
        ----------
        measurements : list of ndarray
            ``measurements[k]`` is an (M_k, 2) array of measurements at time k.
        Nmax : int
            Maximum number of targets tracked by each particle.

        Returns
        -------
        all_particles : list of length K
            all_particles[k] is the final (resampled) particle set for time k.
        card_mean : array of shape (K,)
            Mean cardinality (number of alive targets) per time step.
        card_std : array of shape (K,)
            Standard deviation of the cardinality estimates.
        """
        K = len(measurements)
        # All particles start with every target in the "dead" state.
        particles_prev = self.init_particles(Nmax=Nmax)

        card_mean = np.zeros(K)
        card_std = np.zeros(K)
        all_particles = []

        for k in tqdm(range(K), desc="MCMC-PF Filtering"):
            z_k = measurements[k]

            # (a) & (b) Prediction step / joint draw (Eq. (14))
            predicted = [self._predict_particle(p) for p in particles_prev]

            # (c) & (d) MCMC refinement and likelihood weighting
            for particle in predicted:
                self._refine_particle(particle, z_k)

            # (e) Resampling
            new_particles = self._resample(predicted)

            # Cardinality estimate from the equally weighted resampled set
            card_est = [np.sum(p['e']) for p in new_particles]
            card_mean[k] = np.mean(card_est)
            card_std[k] = np.std(card_est)

            particles_prev = new_particles
            all_particles.append(new_particles)

        return all_particles, card_mean, card_std


# ----------------------------------------------------------------------
# 4) HELPER FUNCTION: TAKE THE MEAN OF PARTICLES
# ----------------------------------------------------------------------
def extract_average_particle_tracks(particle_set):
    """
    Compute the weighted average of the particle states (x) and
    the (fractional) existence (e) at a single time step.

    Log-weights are converted to normalised linear weights with the
    max-subtraction trick. If the largest log-weight is not finite (all
    particles at -inf, or a NaN/+inf present), uniform weights are used.

    Parameters
    ----------
    particle_set : list of dict
        Each dict has keys:
          'x': (Nmax, 4) - the kinematic state,
          'e': (Nmax,) in {0,1} - existence indicators,
          'w': float (the log-weight).

    Returns
    -------
    x_avg : numpy array, shape=(Nmax, 4)
        The weighted average of the kinematic states.
    e_avg : numpy array, shape=(Nmax,)
        The (fractional) weighted average of the existences, i.e. in [0,1].
        (If you want hard 0/1, you can threshold this later.)

    Raises
    ------
    ValueError
        If ``particle_set`` is empty.
    """
    n_particles = len(particle_set)
    if n_particles == 0:
        raise ValueError("particle_set must contain at least one particle")

    # 1) Gather log-weights
    log_w = np.array([p['w'] for p in particle_set], dtype=float)

    # 2) Convert to normalised linear weights safely
    w_max = np.max(log_w)
    if np.isfinite(w_max):
        # The max element maps to exp(0) = 1, so the sum is >= 1 (no div-by-zero).
        w_lin = np.exp(log_w - w_max)
        w_lin /= np.sum(w_lin)
    else:
        # Degenerate weights (all -inf, or NaN/+inf): fall back to uniform.
        w_lin = np.full(n_particles, 1.0 / n_particles)

    # 3) Weighted averages over the particle axis
    x_stack = np.stack([np.asarray(p['x'], dtype=float) for p in particle_set])  # (Np, Nmax, 4)
    e_stack = np.stack([np.asarray(p['e'], dtype=float) for p in particle_set])  # (Np, Nmax)
    x_avg = np.tensordot(w_lin, x_stack, axes=1)  # (Nmax, 4)
    e_avg = np.tensordot(w_lin, e_stack, axes=1)  # (Nmax,), in [0, 1]

    return x_avg, e_avg


def get_estimated_tracks_over_time(all_particles):
    """
    Build per-step track estimates by averaging each time step's particle set.

    Each step is reduced with ``extract_average_particle_tracks``, and the
    results are stacked along a leading time axis.

    Parameters
    ----------
    all_particles : list of lists
        ``all_particles[k]`` is the particle set at time k: a list of dicts
        ``{'x': ..., 'e': ..., 'w': ...}``. It must be non-empty, because the
        number of targets is read from its first particle.

    Returns
    -------
    x_est : numpy array, shape=(K, Nmax, 4)
        Weighted-average states at each time k.
    e_est : numpy array, shape=(K, Nmax)
        Weighted-average existences (in [0,1]) at each time k.
        (Threshold afterwards if strict 0/1 values are needed.)
    """
    n_steps = len(all_particles)
    n_targets = all_particles[0][0]['x'].shape[0]

    x_est = np.zeros((n_steps, n_targets, 4), dtype=float)
    e_est = np.zeros((n_steps, n_targets), dtype=float)

    for k, particle_set in enumerate(all_particles):
        x_est[k], e_est[k] = extract_average_particle_tracks(particle_set)

    return x_est, e_est

# ----------------------------------------------------------------------
# 5) MAIN SIMULATION AND PLOTTING
# ----------------------------------------------------------------------
def main_simulation(K=80, do_plot=True, existence_threshold=0.5):
    """
    Run the complete demo: simulate, filter, estimate and optionally plot.

    Steps:
     1) Generate ground truth for up to 5 targets (births/deaths, near-constant velocity).
     2) Generate association-free measurements (Poisson target returns + Poisson clutter).
     3) Instantiate the MCMC-based particle filter with reduced demo settings
        (the article used Np=4000, Nburn=1000 for its final results).
     4) Run the filter over the whole measurement sequence.
     5) Extract per-time-step track estimates as the weighted average of each
        particle set (see get_estimated_tracks_over_time).
     6) Optionally plot:
        - true vs. estimated cardinality,
        - true vs. estimated x-y trajectories.

    Parameters
    ----------
    K : int
        Number of time steps to simulate.
    do_plot : bool
        If True, draw the cardinality and x-y track figures and call plt.show().
    existence_threshold : float
        A target slot is drawn as "estimated alive" at time k when its averaged
        existence e_est[k, n] is at least this value. The averaged existence is
        a fraction in [0, 1], so it has to be thresholded; comparing it with
        exactly 1.0 almost never succeeds because of floating-point summation.

    Returns
    -------
    ground_truth : dict
        Output of generate_ground_truth.
    measurements : list of arrays
        Output of generate_measurements.
    (card_mean, card_std) : tuple of arrays, shape (K,)
        Mean and standard deviation of the estimated cardinality.
    x_est : array, shape (K, Nmax, 4)
        Weighted-average states.
    e_est : array, shape (K, Nmax)
        Weighted-average existence fractions.
    """
    # 1) Ground truth (Eq. (10), Section IV)
    ground_truth = generate_ground_truth(K=K)

    # 2) Measurements (Eqs. (12)–(13), association-free model)
    measurements = generate_measurements(ground_truth)

    # 3) Particle filter with small settings so the demo finishes quickly
    pf = MCMCParticleFilter(Np=100,    # fewer particles for demo
                            Nburn=2,   # fewer burn-in steps
                            thinning=1,
                            PB=0.1,
                            PD=0.1)

    # 4) Run the particle filter
    all_particles, card_mean, card_std = pf.filter(measurements, Nmax=5)

    # 5) Weighted-average track estimates over time.
    # NOTE: the average includes particles in which a slot is dead (holding the
    # placeholder death state), so positions can be pulled toward that state
    # while the existence fraction is well below 1.
    x_est, e_est = get_estimated_tracks_over_time(all_particles)

    # 6) Optional plots
    if do_plot:
        e_true = ground_truth['e_true']  # shape (K, Nmax)
        x_true = ground_truth['x_true']  # shape (K, Nmax, 4)
        true_card = np.sum(e_true, axis=1)
        time_axis = np.arange(len(card_mean))

        # (A) Cardinality over time
        plt.figure(figsize=(8, 4))
        plt.plot(true_card, 'k-', label='True Cardinality')
        plt.plot(card_mean, 'r--', label='Estimated Mean Cardinality')
        plt.fill_between(
            time_axis,
            card_mean - card_std,
            card_mean + card_std,
            color='r',
            alpha=0.2,
            label='±1 std'
        )
        plt.xlabel('Time Step')
        plt.ylabel('Cardinality')
        plt.title('Cardinality: True vs. Estimated')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()

        # (B) True and estimated tracks in the x-y plane
        plt.figure(figsize=(6, 6))
        last_step = e_true.shape[0] - 1
        for n in range(e_true.shape[1]):
            alive_times = np.where(e_true[:, n] == 1)[0]
            if len(alive_times) == 0:
                continue

            # x and y coordinates of target n over its lifetime
            xvals = x_true[alive_times, n, 0]
            yvals = x_true[alive_times, n, 2]

            # First and last time indices at which the target exists
            k_start = alive_times[0]
            k_stop = alive_times[-1]

            plt.plot(xvals, yvals, '-o', label=f"True T{n+1}")

            # Black circle marks the birth position
            plt.plot(x_true[k_start, n, 0],
                     x_true[k_start, n, 2],
                     'ko',
                     markersize=8)

            # Black triangle marks the death position, only when the target
            # disappears before the final time step
            if k_stop < last_step:
                plt.plot(x_true[k_stop, n, 0],
                         x_true[k_stop, n, 2],
                         'k^',
                         markersize=8)

        # Estimated tracks: slots whose averaged existence passes the threshold
        for n in range(e_est.shape[1]):
            alive_times_est = np.flatnonzero(e_est[:, n] >= existence_threshold)
            if len(alive_times_est) > 0:
                plt.plot(
                    x_est[alive_times_est, n, 0],
                    x_est[alive_times_est, n, 2],
                    '--x',
                    label=f"Est T{n+1}"
                )

        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.xlim([0, 5000])
        plt.ylim([0, 5000])
        plt.title('Tracks in x-y Plane')
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

    return ground_truth, measurements, (card_mean, card_std), x_est, e_est


# ----------------------------------------------------------------------
# 6) RUN EXAMPLE
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Demo: a single filtering run over T = 100 time steps, with plots.
    T = 100
    outputs = main_simulation(K=T, do_plot=True)
    gt, meas, (c_mean, c_std), x_estimated, e_estimated = outputs

    # Adapt this block for repeated runs or additional performance metrics.

MCMC-PF Filtering:  57%|█████▋    | 57/100 [07:22<05:19,  7.43s/it]

In [None]:
"""
MCMC-Based Particle Filter for Multi-Target Tracking with Enhanced Plotting
===========================================================================
This script demonstrates:
 1) Ground-truth generation for up to 5 targets (with births/deaths).
 2) Poisson clutter + target-origin measurements.
 3) An MCMC-based particle filter to estimate the (x, y) positions and (vx, vy) velocities
    of each target, and whether each target exists or not at each time step.
 4) Enhanced Plots:
    - True tracks vs. estimated tracks in x-y plane.
    - Cardinality (true vs. estimated) over time.
    - Existence probabilities over time.
    - Root Mean Square Error (RMSE) computation.
    - Measurements and estimates at selected time steps.
 5) A progress bar (via tqdm) to indicate the filtering progress.
"""

import numpy as np
import matplotlib.pyplot as plt
from numpy.random import rand, randn, choice
from scipy.stats import poisson, multivariate_normal
from tqdm import tqdm  # for progress bar

# ----------------------------------------------------------------------
# 1) GROUND-TRUTH GENERATION
# ----------------------------------------------------------------------
def generate_ground_truth(K=80,  # total number of timesteps
                          tau=3.0,
                          sigma_process=0.5,
                          scenario_params=None):
    """
    Generate ground-truth states for up to 5 targets, with known birth times
    and at least one known death time. Targets follow a near-constant velocity (NCV) model.

    State vectors are ordered (x, vx, y, vy). The NCV model therefore uses one
    independent [position, velocity] block per axis:

        A = blockdiag(F, F),       F  = [[1, tau], [0, 1]]
        Q = q * blockdiag(G, G),   G  = [[tau^3/3, tau^2/2], [tau^2/2, tau]],
                                   q  = sigma_process**2

    Target n exists at index k when ``k >= birth_time`` and
    ``(death_time is None or k < death_time)``. Index 0 is alive when
    ``birth_time <= 1``. A target takes its ``init_state`` at the first index
    where it exists. At every later index it is propagated by the NCV model,
    so a target born at k=1 is not re-initialised at index 1.

    Parameters
    ----------
    K : int
        Number of timesteps (K=0 returns empty arrays).
    tau : float
        Sampling interval (seconds).
    sigma_process : float
        Standard deviation of the process noise for NCV model.
    scenario_params : dict
        Dictionary specifying each target's birth_time, death_time, and initial state.

    Returns
    -------
    ground_truth : dict
        Contains:
          x_true : array, shape=(K, Nmax, 4), the true states for each target
                   (x, vx, y, vy).
          e_true : array, shape=(K, Nmax), the existence indicators (0 or 1).
    """

    # Default scenario: 5 targets
    #   T1, T2, T3 born at k=1
    #   T4, T5 born at k=25
    #   T1 dies at k=50
    if scenario_params is None:
        scenario_params = {
            0: {'birth_time': 1,
                'death_time': 50,
                'init_state': np.array([500.0,   1.0,  500.0,   15.0])},
            1: {'birth_time': 1,
                'death_time': None,
                'init_state': np.array([1000.0, -1.0, 4000.0, -10.0])},
            2: {'birth_time': 1,
                'death_time': None,
                'init_state': np.array([4000.0,   5.0, 1000.0,   5.0])},
            3: {'birth_time': 25,
                'death_time': None,
                'init_state': np.array([4500.0,  -5.0, 4500.0,   5.0])},
            4: {'birth_time': 25,
                'death_time': None,
                'init_state': np.array([500.0,   10.0, 4500.0,  -15.0])}
        }

    Nmax = len(scenario_params)  # 5 for the default scenario
    x_true = np.zeros((K, Nmax, 4))
    e_true = np.zeros((K, Nmax), dtype=int)
    if K == 0:
        return {'x_true': x_true, 'e_true': e_true}

    # NCV transition for ordering (x, vx, y, vy): x' = x + tau*vx, y' = y + tau*vy.
    F = np.array([[1.0, tau],
                  [0.0, 1.0]])
    A = np.kron(np.eye(2), F)
    # Matching process-noise covariance (per-axis white-acceleration model).
    G = np.array([[tau**3 / 3.0, tau**2 / 2.0],
                  [tau**2 / 2.0, tau]])
    Q = sigma_process**2 * np.kron(np.eye(2), G)

    # First timestep: targets with birth_time <= 1 start at their initial state.
    for n in range(Nmax):
        bt = scenario_params[n]['birth_time']
        if bt is not None and bt <= 1:
            x_true[0, n, :] = scenario_params[n]['init_state']
            e_true[0, n] = 1

    # Subsequent timesteps
    for k in range(1, K):
        for n in range(Nmax):
            bt = scenario_params[n]['birth_time']
            dt = scenario_params[n]['death_time']

            exists = bt is not None and k >= bt and (dt is None or k < dt)
            if not exists:
                # Arrays are zero-initialised: state 0, existence 0.
                continue

            if e_true[k - 1, n] == 0:
                # Target appears at this step: use its initial state.
                x_true[k, n, :] = scenario_params[n]['init_state']
            else:
                # Target already existed: propagate with the NCV model.
                mean = A @ x_true[k - 1, n, :]
                x_true[k, n, :] = np.random.multivariate_normal(mean, Q)
            e_true[k, n] = 1

    return {'x_true': x_true, 'e_true': e_true}

# ----------------------------------------------------------------------
# 2) MEASUREMENT GENERATION
# ----------------------------------------------------------------------
def generate_measurements(ground_truth,
                          Lx=5000.0,
                          Ly=5000.0,
                          LambdaC=20.0,  # Mean clutter rate (Poisson)
                          LambdaX=1.0,   # Mean detection rate per active target (Poisson)
                          Sigma_x=100.0, # Measurement variance (Sigma_x * I2)
                          seed=None):
    """
    Simulate association-free measurement sets (target returns + uniform clutter).

    At every time step k:
      * each existing target n emits Poisson(LambdaX) position measurements
        drawn from N([x_{k,n}, y_{k,n}], Sigma_x * I2);
      * Poisson(LambdaC) clutter points are drawn uniformly over
        [0, Lx] x [0, Ly].
    Target returns are stacked first (in target order), clutter last.

    Parameters
    ----------
    ground_truth : dict
        Must contain
          - 'x_true': array (K, Nmax, 4) with rows (x, vx, y, vy),
          - 'e_true': array (K, Nmax), 1 where the target exists.
    Lx, Ly : float
        Dimensions of the surveillance region in meters.
    LambdaC : float
        Mean number of clutter (false alarms) per time step.
    LambdaX : float
        Mean number of detections per active target and per time step.
    Sigma_x : float
        Measurement variance in both x and y directions (covariance Sigma_x * I2).
    seed : int or None
        When given, ``np.random.seed(seed)`` is called first for reproducibility.

    Returns
    -------
    measurements : list
        A list of length K; measurements[k] is an array of shape (M_k, 2)
        holding all M_k measurements (target + clutter) at time step k.
    """
    if seed is not None:
        np.random.seed(seed)

    positions = ground_truth['x_true']
    existence = ground_truth['e_true']
    num_steps, _, _ = positions.shape
    meas_cov = Sigma_x * np.eye(2)

    def _draw_step(k):
        # Target returns first, one Poisson count per existing target.
        chunks = []
        for n in np.flatnonzero(existence[k, :] == 1):
            n_returns = poisson.rvs(LambdaX)
            if n_returns > 0:
                centre = positions[k, n, [0, 2]]  # (x, y) position of target n
                chunks.append(
                    np.random.multivariate_normal(centre, meas_cov, size=n_returns)
                )

        # Then clutter, uniform over the surveillance region.
        n_clutter = poisson.rvs(LambdaC)
        clutter = np.column_stack((
            np.random.uniform(0, Lx, size=n_clutter),
            np.random.uniform(0, Ly, size=n_clutter),
        ))
        chunks.append(clutter)
        return np.vstack(chunks)

    return [_draw_step(k) for k in range(num_steps)]

# ----------------------------------------------------------------------
# 3) MCMC-BASED PARTICLE FILTER CLASS
# ----------------------------------------------------------------------
class MCMCParticleFilter:
    """
    MCMC-Based Particle Filter for a time-varying number of targets, as described
    in Section III–IV of the article.

    The targets are modeled with:
      - Existence variables e_{k,n} in {0,1} (Eq. (6))
      - Kinematics x_{k,n} = (x, vx, y, vy) (fixed dimension Nmax x 4)
      - Transition densities:
         * Birth:   Eq. (8)
         * Death:   Eq. (9)
         * Update:  Eq. (10)
      - Association-free measurement model: Eqs. (12)–(13)
      - MCMC approach: Eqs. (14)–(15) for refinement

    State vectors are ordered (x, vx, y, vy) throughout this class.
    Positions are read from indices 0 and 2 and velocities from indices 1 and 3.
    The NCV transition matrix ``A`` and process-noise covariance ``Q`` are
    built for that ordering, with one [position, velocity] block per axis.
    """

    def __init__(self,
                 Np=4000,
                 Nburn=1000,
                 thinning=6,
                 PB=0.01,       # Probability of birth, Eq. (6)
                 PD=0.01,       # Probability of death, Eq. (6)
                 sigma_process=0.5,  # process-noise std. dev. (q = sigma_process**2)
                 tau=3.0,       # sampling interval
                 R=100.0,       # measurement noise covariance = R * I2
                 LambdaC=20.0,  # clutter rate
                 LambdaX=1.0,   # detection rate per active target
                 Lx=5000.0,
                 Ly=5000.0,
                 x_death=None,
                 Vmax=20.0,     # max speed for birth sampling
                 mcmc_step=0.05):  # std. dev. of the MCMC random-walk proposal
        """
        Initialize the filter. Defaults follow the article:
          - Np=4000 (number of particles)
          - Nburn=1000 (MCMC burn-in iterations per particle)
          - thinning=6 (extra MCMC iterations after burn-in; only the final
            state of each chain is kept)
          - PB, PD in [0,1]: birth / death probabilities of Eq. (6)
          - sigma_process: process-noise standard deviation; the NCV covariance
            is scaled by sigma_process**2, the same convention as the
            ground-truth generator
          - tau: sampling interval (seconds)
          - R=100 => measurement covariance CovMeas = R * I2
          - LambdaC=20, LambdaX=1 => Poisson rates for clutter / per-target returns
          - Lx=5000, Ly=5000 => surveillance area [0, Lx] x [0, Ly]
          - x_death: placeholder state for a 'dead' target (Eq. (9)); zeros by default
          - Vmax=20 => birth velocities uniform in [-Vmax, Vmax] (Eq. (8))
          - mcmc_step=0.05 => std. dev. of the Gaussian random-walk proposal
            used by the MCMC refinement step
        """
        self.Np = Np
        self.Nburn = Nburn
        self.thinning = thinning

        # Eq. (6) parameters for birth/death processes
        self.PB = PB  # Probability to become alive if previously dead
        self.PD = PD  # Probability to become dead if previously alive

        # NCV model parameters (Eqs. (10)–(11))
        self.sigma_process = sigma_process
        self.tau = tau

        # Observation model parameters
        self.R = R
        self.LambdaC = LambdaC
        self.LambdaX = LambdaX

        # Surveillance area and speeds
        self.Lx = Lx
        self.Ly = Ly
        self.Vmax = Vmax

        # Random-walk step used in the MCMC refinement
        self.mcmc_step = mcmc_step

        # "Death" state for inactive targets (Eq. (9))
        if x_death is None:
            self.x_death = np.zeros(4, dtype=float)
        else:
            self.x_death = np.asarray(x_death, dtype=float)

        # NCV transition for ordering (x, vx, y, vy) (Eq. (11)):
        #   A = [[1, tau, 0,   0 ],
        #        [0,  1,  0,   0 ],
        #        [0,  0,  1,  tau],
        #        [0,  0,  0,   1 ]]
        F = np.array([[1.0, self.tau],
                      [0.0, 1.0]])
        self.A = np.kron(np.eye(2), F)

        # Process noise Q (Eq. (11)): q * blockdiag(G, G), q = sigma_process**2
        q = self.sigma_process ** 2
        G = np.array([[self.tau**3 / 3.0, self.tau**2 / 2.0],
                      [self.tau**2 / 2.0, self.tau]])
        self.Q = q * np.kron(np.eye(2), G)

        # Measurement covariance = R * I2
        self.CovMeas = self.R * np.eye(2)

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 0: Initialize Particle Set
    # ----------------------------------------------------------------
    def init_particles(self, Nmax=5):
        r"""
        This method creates the initial particle set, approximating
        the distribution \(p(x_0, e_0)\).

        The article (Algorithm 1, Step 0) does not fully specify how
        to initialize, and the initial time step is often scenario-dependent.
        Here every particle starts with all targets inactive
        (\(e_{0,n}=0\)), zero kinematics and log-weight 0.
        Adjust this to your prior knowledge if needed.
        """
        return [
            {'x': np.zeros((Nmax, 4), dtype=float),
             'e': np.zeros(Nmax, dtype=int),
             'w': 0.0}
            for _ in range(self.Np)
        ]

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 1: Sample Existence Variables (Eq. (6))
    # ----------------------------------------------------------------
    def sample_existence(self, e_prev):
        r"""
        Samples each target's existence variable \( e_{k,n} \) given
        \( e_{k-1,n} \). This follows Eq. (6) in the paper:

        \[
          p(e_{k,n}=1 \mid e_{k-1,n}=1) = 1 - P_D, \quad
          p(e_{k,n}=0 \mid e_{k-1,n}=1) = P_D,
        \]
        \[
          p(e_{k,n}=1 \mid e_{k-1,n}=0) = P_B, \quad
          p(e_{k,n}=0 \mid e_{k-1,n}=0) = 1 - P_B.
        \]

        Parameters
        ----------
        e_prev : ndarray of shape (Nmax,)
            Existence vector at the previous time step.

        Returns
        -------
        e_new : ndarray of shape (Nmax,), dtype int
            Sampled existence vector at the current time step.
        """
        e_prev = np.asarray(e_prev)
        u = np.random.rand(len(e_prev))
        # Alive targets survive when u >= PD; dead targets are born when u < PB.
        e_new = np.where(e_prev == 1, u >= self.PD, u < self.PB)
        return e_new.astype(int)

    # ----------------------------------------------------------------
    # ALGORITHM 2, STEP 2: Sample States p(x_{k,n} | x_{k-1,n}, e_{k,n}, e_{k-1,n}), Eq. (7)
    # ----------------------------------------------------------------
    def sample_motion(self, x_prev, e_prev, e_curr):
        r"""
        Samples each target's kinematic state \(x_{k,n}\) given
        \( x_{k-1,n} \) and the existence variables \( e_{k,n}, e_{k-1,n}\).

        From Eq. (7):
        \[
          p(x_{k,n} \mid x_{k-1,n}, e_{k,n}, e_{k-1,n}) =
            \begin{cases}
              p_b(x_{k,n}) & \text{if } \{e_{k,n}=1, e_{k-1,n}=0\}, \\
              p_d(x_{k,n}) & \text{if } e_{k,n}=0, \\
              p_u(x_{k,n} \mid x_{k-1,n}) & \text{if } \{e_{k,n}=1, e_{k-1,n}=1\}.
            \end{cases}
        \]

        where
         - \(p_b\) is the birth density (Eq. (8)): uniform position in the
           surveillance area, uniform velocity in [-Vmax, Vmax],
         - \(p_d\) is the death density (Eq. (9)): a point mass at x_death,
         - \(p_u\) is the near-constant velocity update (Eq. (10)): N(A x, Q).

        Parameters
        ----------
        x_prev : ndarray of shape (Nmax, 4)
            States (x, vx, y, vy) of the previous time step.
        e_prev : ndarray of shape (Nmax,)
            Existence vector at the previous time step.
        e_curr : ndarray of shape (Nmax,)
            Existence vector at the current time step.

        Returns
        -------
        x_curr : ndarray of shape (Nmax, 4)
            Sampled states at the current time step.
        """
        Nmax = len(e_prev)
        x_curr = np.zeros((Nmax, 4), dtype=float)

        for n in range(Nmax):
            if (e_prev[n] == 0) and (e_curr[n] == 1):
                # Birth => p_b(x_{k,n}), Eq. (8)
                x_birth = np.zeros(4, dtype=float)
                x_birth[0] = np.random.uniform(0, self.Lx)               # x
                x_birth[2] = np.random.uniform(0, self.Ly)               # y
                x_birth[1] = np.random.uniform(-self.Vmax, self.Vmax)    # vx
                x_birth[3] = np.random.uniform(-self.Vmax, self.Vmax)    # vy
                x_curr[n, :] = x_birth

            elif e_curr[n] == 0:
                # Death => p_d(x_{k,n}) = δ(x_death), Eq. (9)
                x_curr[n, :] = self.x_death

            else:
                # Update => p_u(x_{k,n} | x_{k-1,n}) = N(A x_{k-1,n}, Q), Eq. (10)
                mean = self.A @ x_prev[n, :]
                x_curr[n, :] = np.random.multivariate_normal(mean, self.Q)

        return x_curr

    # ----------------------------------------------------------------
    # LOG-LIKELIHOOD p(z_k | x_k), Eqs. (12)–(13)
    # ----------------------------------------------------------------
    def log_likelihood_measurements(self, z, x, e, LambdaC, LambdaX):
        r"""
        Computes \(\log p(z_k | x_k)\) under the association-free model (Eqs. (12)–(13)):

        \[
          p(z_k \mid x_k) \;=\;
           \frac{\exp(-\mu_k)}{M_k!}
           \prod_{m=1}^{M_k} \lambda(z_k^{(m)})  \;,\quad
          \text{where } \mu_k = \Lambda_C + \sum_{n \in \text{active}} \Lambda_X,
        \]
        \[
          \lambda(z) = \Lambda_C\, p_C(z) \;+\;
                        \sum_{n \in \text{active}} \bigl[\Lambda_X\, p_x(z \mid x_{k,n})\bigr],
        \]
        with
         - \(p_C(z)\) uniform in \([0,L_x]\times [0,L_y]\),
         - \(p_x(z \mid x_{k,n}) = \mathcal{N}(z; [x_{k,n}, y_{k,n}], \text{CovMeas})\).

        The intensities for all (measurement, active target) pairs are
        evaluated at once with NumPy broadcasting instead of one scipy pdf
        call per pair.

        Parameters
        ----------
        z : ndarray, shape (M_k, 2)
            All measurements at time k.
        x : ndarray, shape (Nmax, 4)
            The states (x, vx, y, vy) for each target.
        e : ndarray, shape (Nmax,)
            Existence indicator for each target (0 or 1).
        LambdaC : float
            Mean clutter rate \(\Lambda_C\).
        LambdaX : float
            Mean detection rate per active target \(\Lambda_X\).

        Returns
        -------
        log_like : float
            The log of p(z_k | x_k); -1e16 if some measurement has an
            intensity below 1e-300 (treated as impossible).
        """
        e = np.asarray(e)
        z = np.asarray(z, dtype=float).reshape(-1, 2)

        # mu_k = LambdaC + N_active * LambdaX
        mu_k = LambdaC + np.sum(e) * LambdaX
        M_k = z.shape[0]

        # -mu_k - log(M_k!)   (the empty sum gives log(0!) = 0)
        log_like = -mu_k - np.sum(np.log(np.arange(1, M_k + 1)))
        if M_k == 0:
            return log_like

        # Clutter contribution LambdaC * pC(z), identical for every measurement
        clutter_density = 1.0 / (self.Lx * self.Ly)
        clutter_term = LambdaC * clutter_density

        # Gaussian pdf N(z; mean, CovMeas) for every (measurement, active target)
        means = np.asarray(x, dtype=float)[e == 1][:, [0, 2]]   # (N_active, 2)
        cov = np.asarray(self.CovMeas, dtype=float)
        cov_inv = np.linalg.inv(cov)
        norm_const = 1.0 / (2.0 * np.pi * np.sqrt(np.linalg.det(cov)))
        diff = z[:, None, :] - means[None, :, :]                 # (M_k, N_active, 2)
        maha = np.einsum('mni,ij,mnj->mn', diff, cov_inv, diff)  # squared Mahalanobis
        target_term = LambdaX * norm_const * np.exp(-0.5 * maha).sum(axis=1)

        lam = clutter_term + target_term                         # (M_k,)
        # If any lambda(z_m) ~ 0 the log would be -inf: use a large finite penalty.
        if np.any(lam < 1e-300):
            return -1e16
        return log_like + np.sum(np.log(lam))

    # ---------------------------------------------------------
    # Main Filtering Loop (MCMC-based Particle Algorithm)
    # ---------------------------------------------------------
    def filter(self, measurements, Nmax=5):
        """
        Perform the MCMC-based particle filtering over a sequence of measurements
        according to the approach in Section III-B and Algorithm 1 of the paper.

        Steps, for k = 0..K-1:
          (a) Prediction / joint draw (Eq. (14)): sample e_k from e_{k-1}
              (Eq. (6)) and x_k from x_{k-1} (Eq. (7)).
          (b) Refinement (simplified Eq. (15)): for each particle, run a
              Metropolis-Hastings chain of Nburn + thinning sweeps. Each sweep
              perturbs every active target's state with a Gaussian random walk
              (std. dev. mcmc_step) and accepts with probability
              min(1, p(z_k | x_prop) / p(z_k | x_cur)).
              NOTE: only the measurement likelihood enters the ratio (the
              motion prior is not included) and existence variables are not
              changed here. This is a simplification of the article's q3 proposal.
          (c) Weighting: the particle log-weight is log p(z_k | x_k)
              (Eqs. (12)–(13)) of the final chain state.
          (d) Multinomial resampling; log-weights are reset to 0.
          (e) Cardinality statistics from the resampled set.

        Returns
        -------
        all_particles : list of length K
            all_particles[k] is the final particle set for time k.
        card_mean : array of shape (K,)
            Mean cardinality (number of alive targets) per time step.
        card_std : array of shape (K,)
            Standard deviation of the cardinality estimates.
        """
        K = len(measurements)
        # All particles start with every target inactive (see init_particles).
        particles_prev = self.init_particles(Nmax=Nmax)

        card_mean = np.zeros(K)
        card_std = np.zeros(K)
        all_particles = []

        for k in tqdm(range(K), desc="MCMC-PF Filtering"):
            z_k = measurements[k]

            # ---- (a) Prediction step & joint draw (Eq. (14)) ----
            predicted = []
            for particle in particles_prev:
                e_old = particle['e']
                e_new = self.sample_existence(e_old)
                x_new = self.sample_motion(particle['x'], e_old, e_new)
                predicted.append({'x': x_new, 'e': e_new, 'w': particle['w']})

            # ---- (b)+(c) MCMC refinement and weighting ----
            for particle in predicted:
                x_cur = particle['x'].copy()
                e_cur = particle['e'].copy()
                # Cache the likelihood of the current chain state; it only
                # changes when a proposal is accepted.
                ll_cur = self.log_likelihood_measurements(
                    z_k, x_cur, e_cur, self.LambdaC, self.LambdaX)

                # Dead targets do not influence the likelihood and always hold
                # x_death, so only active targets are perturbed.
                active = np.flatnonzero(e_cur == 1)
                for _ in range(self.Nburn + self.thinning):
                    for n in active:
                        x_prop = x_cur.copy()
                        x_prop[n, :] = x_cur[n, :] + self.mcmc_step * randn(4)
                        ll_prop = self.log_likelihood_measurements(
                            z_k, x_prop, e_cur, self.LambdaC, self.LambdaX)
                        # Symmetric proposal => MH ratio is the likelihood ratio.
                        # Clipping the exponent at 0 avoids overflow; ratio >= 1
                        # is always accepted either way.
                        if rand() < np.exp(min(0.0, ll_prop - ll_cur)):
                            x_cur = x_prop
                            ll_cur = ll_prop

                particle['x'] = x_cur
                particle['e'] = e_cur
                particle['w'] = ll_cur

            # ---- (d) Resampling ----
            weights_log = np.array([p['w'] for p in predicted], dtype=float)
            weights_lin = np.exp(weights_log - np.max(weights_log))
            weights_lin /= np.sum(weights_lin)

            indices = choice(self.Np, size=self.Np, replace=True, p=weights_lin)
            new_particles = [
                {'x': predicted[idx]['x'].copy(),
                 'e': predicted[idx]['e'].copy(),
                 'w': 0.0}
                for idx in indices
            ]

            # ---- (e) Cardinality estimation ----
            card_est = np.array([np.sum(p['e']) for p in new_particles])
            card_mean[k] = np.mean(card_est)
            card_std[k] = np.std(card_est)

            # Move to next time step
            particles_prev = new_particles
            all_particles.append(new_particles)

        return all_particles, card_mean, card_std

# ----------------------------------------------------------------------
# 4) HELPER FUNCTIONS FOR ESTIMATION AND PLOTTING
# ----------------------------------------------------------------------
def extract_average_particle_tracks(particle_set):
    """
    Compute the weighted average of the particle states (x) and
    the (fractional) existence (e) at a single time step.

    Log-weights are normalised with the max-shift (log-sum-exp) trick. If no
    finite maximum log-weight exists (all -inf, any +inf, or any NaN), the
    normalised weights would be NaN, so uniform weights are used instead.

    Parameters
    ----------
    particle_set : list of dict
        Non-empty list. Each dict has keys:
          'x': (Nmax, D) - the kinematic state (D = 4 for (x, vx, y, vy)),
          'e': (Nmax,) in {0,1} - existence indicators,
          'w': float (the log-weight).

    Returns
    -------
    x_avg : numpy array, shape=(Nmax, D)
        The weighted average of the kinematic states.
    e_avg : numpy array, shape=(Nmax,)
        The (fractional) weighted average of the existences, i.e. in [0,1].
        (If you want hard 0/1, you can threshold this later.)

    Raises
    ------
    ValueError
        If particle_set is empty.
    """
    if len(particle_set) == 0:
        raise ValueError("particle_set must contain at least one particle")

    # 1) Gather log-weights in a NumPy array
    log_w = np.array([p['w'] for p in particle_set], dtype=float)

    # 2) Convert to normalised linear weights.
    #    np.max propagates NaN, so a finite maximum implies no NaN entries.
    w_max = np.max(log_w)
    if np.isfinite(w_max):
        # Shifting by the max avoids exp() overflow/underflow; the largest
        # term becomes exp(0) = 1, so the sum is always >= 1.
        w_lin = np.exp(log_w - w_max)
        w_lin /= np.sum(w_lin)
    else:
        # Without the fallback every weight would become NaN (-inf - -inf, inf - inf).
        w_lin = np.full(log_w.shape, 1.0 / log_w.size)

    # 3) Stack particles along a new leading axis and contract it with the weights.
    x_stack = np.stack([np.asarray(p['x'], dtype=float) for p in particle_set])  # (Np, Nmax, D)
    e_stack = np.stack([np.asarray(p['e'], dtype=float) for p in particle_set])  # (Np, Nmax)

    x_avg = np.tensordot(w_lin, x_stack, axes=1)  # Weighted mean, (Nmax, D)
    e_avg = np.tensordot(w_lin, e_stack, axes=1)  # Weighted mean in [0,1], (Nmax,)

    return x_avg, e_avg

def get_estimated_tracks_over_time(all_particles):
    """
    Build per-time-step track estimates by averaging each time step's particles.

    Each entry of `all_particles` is reduced with
    `extract_average_particle_tracks`, and the results are stacked along a
    leading time axis.

    Parameters
    ----------
    all_particles : list of lists
        all_particles[k] is the particle set at time k,
        i.e. a list of dicts [{'x':..., 'e':..., 'w':...}, ...].

    Returns
    -------
    x_est : numpy array, shape=(K, Nmax, 4)
        Weighted-average states at each time k.
    e_est : numpy array, shape=(K, Nmax)
        Weighted-average existences (in [0,1]) at each time k.
        (You can threshold later if you need strict 0/1.)
    """
    n_steps = len(all_particles)
    # The slot count is read from the very first particle of the first time step.
    n_slots = all_particles[0][0]['x'].shape[0]

    states = np.zeros((n_steps, n_slots, 4), dtype=float)
    exists = np.zeros((n_steps, n_slots), dtype=float)

    for step, particle_set in enumerate(all_particles):
        mean_state, mean_exist = extract_average_particle_tracks(particle_set)
        states[step] = mean_state
        exists[step] = mean_exist

    return states, exists

# ----------------------------------------------------------------------
# 5) MAIN SIMULATION AND ENHANCED PLOTTING
# ----------------------------------------------------------------------
def _plot_cardinality(true_card, card_mean, card_std, K):
    """Plot true vs. estimated (mean and +/-1 std band) cardinality over K steps."""
    plt.figure(figsize=(12, 6))
    plt.plot(true_card, 'k-', label='True Cardinality')
    plt.plot(card_mean, 'r--', label='Estimated Mean Cardinality')
    plt.fill_between(
        range(K),
        card_mean - card_std,
        card_mean + card_std,
        color='r',
        alpha=0.2,
        label='±1 std'
    )
    plt.xlabel('Time Step')
    plt.ylabel('Cardinality')
    plt.title('Cardinality: True vs. Estimated Over Time')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def _plot_xy_tracks(x, alive, K, title, track_fmt, birth_fmt, death_fmt, label_prefix):
    """Plot each slot's (x, y) path over the steps where `alive[:, n]` is True, marking births and deaths."""
    plt.figure(figsize=(6, 6))
    for n in range(alive.shape[1]):
        alive_times = np.where(alive[:, n])[0]
        if len(alive_times) == 0:
            continue

        # Columns 0 and 2 of the state are the x and y positions.
        plt.plot(x[alive_times, n, 0], x[alive_times, n, 2], track_fmt,
                 label=f"{label_prefix} T{n+1}")

        k_start = alive_times[0]
        k_stop = alive_times[-1]
        plt.plot(x[k_start, n, 0], x[k_start, n, 2], birth_fmt, markersize=8)
        # A track that is still alive at the final step did not die.
        if k_stop < (K - 1):
            plt.plot(x[k_stop, n, 0], x[k_stop, n, 2], death_fmt, markersize=8)

    plt.xlabel('x [m]')
    plt.ylabel('y [m]')
    plt.xlim([0, 5000])
    plt.ylim([0, 5000])
    plt.title(title)
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()


def main_simulation(K=80, do_plot=True, Nmax=5):
    """
    Executes the complete simulation with enhanced plotting:
     1) Generate ground truth (births/deaths, near-constant velocity).
     2) Generate measurements (association-free, Poisson-based).
     3) Instantiate the MCMC-based Particle Filter with chosen parameters (Np, Nburn, etc.).
     4) Perform filtering over the measurement sequence (Algorithm 1) with
        `Nmax` target slots.
     5) Extract weighted-average estimates over time.
     6) Optionally plot and report:
        - True vs. estimated cardinality over time.
        - True target trajectories (black circle = birth, black triangle = death).
        - Estimated target trajectories (red x = birth, red triangle = death).
        - Position RMSE (printed; reported as undefined when no slot matches).
        - Existence probabilities over time.
        - Measurements and estimates at selected time steps.

    Parameters
    ----------
    K : int
        Number of time steps.
    do_plot : bool
        Whether to produce the plots and print the RMSE.
    Nmax : int
        Maximum number of target slots tracked by the filter (default 5).

    Returns
    -------
    tuple
        (ground_truth, measurements, (card_mean, card_std), x_est, e_est)
    """
    # 1) Generate ground truth (Eq. (10), Section IV)
    ground_truth = generate_ground_truth(K=K)

    # 2) Generate measurements (Eqs. (12)–(13), association-free model)
    measurements = generate_measurements(ground_truth)

    # 3) Instantiate MCMC-based Particle Filter with smaller defaults
    #    (the article used Np=4000, Nburn=1000, etc. for final results,
    #     but here we reduce them for a quicker demo).
    pf = MCMCParticleFilter(Np=100,    # fewer particles for demo
                            Nburn=2,  # fewer burn-in steps
                            thinning=1,
                            PB=0.1,
                            PD=0.1)

    # 4) Run the particle filter
    all_particles, card_mean, card_std = pf.filter(measurements, Nmax=Nmax)

    # 5) Extract track estimates (weighted-average tracks) over time
    x_est, e_est = get_estimated_tracks_over_time(all_particles)

    # 6) (Optional) Plot results
    if do_plot:
        e_true = ground_truth['e_true']  # shape (K, Nmax_true)
        x_true = ground_truth['x_true']  # shape (K, Nmax_true, 4)
        true_card = np.sum(e_true, axis=1)

        # (A) Cardinality over time
        _plot_cardinality(true_card, card_mean, card_std, K)

        # (B) True trajectories in the x-y plane
        _plot_xy_tracks(x_true, e_true == 1, K,
                        'True Target Trajectories in x-y Plane',
                        '-o', 'ko', 'k^', 'True')

        # (C) Estimated trajectories (existence probability > 0.5)
        _plot_xy_tracks(x_est, e_est > 0.5, K,
                        'Estimated Target Trajectories in x-y Plane',
                        '--x', 'rx', 'r^', 'Estimated')

        # (D) RMSE; compute_rmse returns None when no slot is both truly
        #     present and estimated present, which cannot be formatted as a float.
        rmse = compute_rmse(x_true, x_est, e_true, e_est)
        if rmse is None:
            print('RMSE undefined: no time step has a target that is both present and estimated present.')
        else:
            print(f'RMSE over all time steps and targets: {rmse:.2f} meters')

        # (E) Existence probabilities for every filter slot
        plot_existence_probabilities(e_true, e_est, K, e_est.shape[1])

        # (F) Measurements and estimates at selected time steps. The ground
        #     truth and the filter may use different slot counts, so only the
        #     slots both have are passed.
        n_common = min(e_true.shape[1], e_est.shape[1])
        selected_time_steps = [24, 25, 50, 51, 75]
        plot_measurements_and_estimates(measurements, x_true, x_est, e_true, e_est,
                                        selected_time_steps, K, n_common)

    return ground_truth, measurements, (card_mean, card_std), x_est, e_est

# ----------------------------------------------------------------------
# 6) HELPER FUNCTIONS FOR ENHANCED PLOTTING
# ----------------------------------------------------------------------
def compute_rmse(x_true, x_est, e_true, e_est):
    """
    Compute the Root Mean Square Error (RMSE) between true and estimated positions.

    Only (time, target) slots where the target truly exists (e_true == 1)
    AND is estimated to exist (e_est > 0.5) contribute. The error is the
    Euclidean distance between the (x, y) positions, i.e. state columns 0 and 2.

    Parameters
    ----------
    x_true : ndarray, shape=(K, Nmax, 4)
        True states (x, vx, y, vy).
    x_est : ndarray, shape=(K', Nmax', 4) with K' >= K, Nmax' >= Nmax
        Estimated states. Only the leading (K, Nmax) slots are used.
    e_true : ndarray, shape=(K, Nmax)
        True existence indicators.
    e_est : ndarray, shape=(K', Nmax')
        Estimated existence indicators/probabilities. Only the leading (K, Nmax) slots are used.

    Returns
    -------
    rmse : float or None
        Root Mean Square Error, or None if no slot satisfies both existence
        conditions (so the RMSE is undefined).
    """
    x_true = np.asarray(x_true, dtype=float)
    n_steps, n_slots = x_true.shape[:2]
    x_est = np.asarray(x_est, dtype=float)[:n_steps, :n_slots]
    e_true = np.asarray(e_true)[:n_steps, :n_slots]
    e_est = np.asarray(e_est)[:n_steps, :n_slots]

    matched = (e_true == 1) & (e_est > 0.5)
    if not np.any(matched):
        return None

    pos_err = x_true[..., [0, 2]] - x_est[..., [0, 2]]  # (K, Nmax, 2)
    sq_err = np.sum(pos_err ** 2, axis=-1)[matched]      # squared distances
    return float(np.sqrt(np.mean(sq_err)))

def plot_existence_probabilities(e_true, e_est, K, Nmax):
    """
    Plot the estimated existence probability of each target over time,
    overlaid with the true existence indicator for comparison.

    Each target's estimate is a solid line. The matching ground truth is a
    dashed step line in the same colour. It is drawn only for targets that
    have a ground-truth column.

    Parameters
    ----------
    e_true : ndarray, shape=(K, Nmax_true)
        True existence indicators.
    e_est : ndarray, shape=(K, Nmax)
        Estimated existence probabilities.
    K : int
        Number of timesteps.
    Nmax : int
        Number of estimated target slots to plot.
    """
    steps = np.arange(K)
    n_true_slots = np.shape(e_true)[1]

    plt.figure(figsize=(12, 6))
    for n in range(Nmax):
        (est_line,) = plt.plot(steps, e_est[:, n], label=f'Target {n+1} Existence Probability')
        if n < n_true_slots:
            plt.step(steps, e_true[:, n], where='post', linestyle='--',
                     color=est_line.get_color(), alpha=0.6,
                     label=f'Target {n+1} True Existence')
    plt.ylim(-0.05, 1.05)
    plt.xlabel('Time Step')
    plt.ylabel('Existence Probability')
    plt.title('Existence Probabilities Over Time')
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.show()

def plot_measurements_and_estimates(measurements, x_true, x_est, e_true, e_est, selected_time_steps, K, Nmax):
    """
    Plot measurements, true positions and estimated positions, one figure per selected time step.

    Every figure gets its own complete legend. Time steps outside [0, K) are
    reported and skipped.

    Parameters
    ----------
    measurements : list of array-like
        Measurements at each time step. The first two columns are plotted as (x, y).
    x_true : ndarray, shape=(K, Nmax, 4)
        True states.
    x_est : ndarray, shape=(K, Nmax, 4)
        Estimated states.
    e_true : ndarray, shape=(K, Nmax)
        True existence indicators.
    e_est : ndarray, shape=(K, Nmax)
        Estimated existence indicators.
    selected_time_steps : list of int
        Zero-based time steps to plot.
    K : int
        Total number of time steps.
    Nmax : int
        Number of targets.
    """
    for k in selected_time_steps:
        if k < 0 or k >= K:
            print(f"Time step {k} is out of range. Skipping.")
            continue
        plt.figure(figsize=(8, 8))
        plt.xlim(0, 5000)
        plt.ylim(0, 5000)
        plt.xlabel('x [m]')
        plt.ylabel('y [m]')
        plt.title(f'Measurements and Estimates at Time Step {k+1}')

        # Plot measurements. atleast_2d turns a single 1-D measurement into one row.
        # NOTE(review): this assumes each row is one measurement with (x, y) in columns 0 and 1.
        z_k = np.atleast_2d(np.asarray(measurements[k], dtype=float))
        if z_k.size > 0:
            plt.scatter(z_k[:, 0], z_k[:, 1], c='k', marker='.', label='Measurements')

        # Plot true target positions. Each figure is separate, so every figure labels its own entries.
        for n in range(Nmax):
            if e_true[k, n] == 1:
                plt.scatter(x_true[k, n, 0], x_true[k, n, 2],
                            marker='o', edgecolors='b', facecolors='none', label=f'True T{n+1}')

        # Plot estimated target positions (existence probability > 0.5).
        for n in range(Nmax):
            if e_est[k, n] > 0.5:
                plt.scatter(x_est[k, n, 0], x_est[k, n, 2],
                            marker='x', color='r', label=f'Est T{n+1}')

        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.show()

# ----------------------------------------------------------------------
# 7) RUN EXAMPLE
# ----------------------------------------------------------------------
if __name__ == "__main__":
    # Demonstration: one filtering run over T = 100 time steps, with plots.
    T = 100
    results = main_simulation(K=T, do_plot=True)
    gt, meas, (c_mean, c_std), x_estimated, e_estimated = results

    # Extend as needed, e.g. repeat over several runs or add further metrics.