In [3]:
"""
Simple deterministic tracker for discrete-spectrum branches.

- Builds a rectangular cost matrix between previous branches and current
  detections using Euclidean distances in ζ- and r-planes.
- Uses Hungarian algorithm (scipy.optimize.linear_sum_assignment) to solve
  the assignment.
- Creates new branches for unmatched detections and marks gaps for unmatched
  branches.
- Produces Plotly 2D and 3D visualizations.

Notes
-----
* The input file name is kept as-is.
* Comments are intentionally concise and user-facing.
"""

from __future__ import annotations

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.optimize import linear_sum_assignment

# ---------------------------------------------------------------------------
# 1) Input data (file name is not changed)
# ---------------------------------------------------------------------------
# Whitespace-separated table; its first line is skipped (presumably a header)
# and the column names below are used instead.
FILENAME = "NFT_DiscreteSpectrum.dat"
COLS = (
    "z x h Re_r Im_r T f "
    "abs_r abs_a abs_b abs_aprime "
    "ds_count energyRelativeError"
).split()

df = pd.read_csv(FILENAME, sep=r"\s+", skiprows=1, names=COLS)

# Expose the x / h columns as the real / imaginary parts of ζ.
df = df.assign(
    zeta_real=df["x"].astype(float),
    zeta_imag=df["h"].astype(float),
)
unique_z = np.sort(df["z"].unique())


# ---------------------------------------------------------------------------
# 2) Group storage and parameters  -> wrapped into a function
# ---------------------------------------------------------------------------
def run_discrete_tracker_simple(
    df_in: pd.DataFrame,
    z_levels: np.ndarray,
    max_gap: int = 4,                 # how many consecutive layers can be missed
    K: float = 2.0,                   # adaptive gating: distance > K*v_avg*dz -> invalid
    alpha: float = 0.5,               # exponential smoothing of velocity magnitude
    big_cost: float = 1e6,            # "infinite" cost
    birth_threshold: float = 0.001,   # newborn gate (separately for ζ and r)
    weight_r: float = 0.9             # weight of r-coordinates in total cost
) -> dict[int, list[tuple[float, float, float, float, float, float]]]:
    """
    Run a simple deterministic tracker over all z-layers.

    Each value ``z`` of ``z_levels`` selects the rows of ``df_in`` with
    ``df_in["z"] == z``. Groups whose gap is below ``max_gap`` are matched to
    those detections with the Hungarian algorithm on the cost
    ``d_zeta + weight_r * d_r`` (distances to the linearly extrapolated
    position). Unmatched detections open new groups; unmatched groups gain
    one gap and stop receiving detections once the gap reaches ``max_gap``.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Must provide the columns "z", "zeta_real", "zeta_imag", "Re_r",
        "Im_r" and "abs_r".
    z_levels : numpy.ndarray
        Layers to process, in processing order.
    max_gap, K, alpha, big_cost, birth_threshold, weight_r
        Tracker tunings; see the inline comments in the signature.

    Returns
    -------
    dict
        Maps group_id -> list of tuples (z, zr, zi, rr, ri, |r|), in the
        order in which points were assigned.
    """
    groups: dict[int, list[tuple[float, float, float, float, float, float]]] = {}
    next_group_id = 0

    # Per-track state:
    # - last, prev for ζ: (z, zr, zi)
    # - last_r, prev_r for r: (z, rr, ri)
    # - v_avg, v_avg_r (smoothed speed magnitudes), gap_count, active flag
    group_info: dict[int, dict] = {}

    def init_group(g_id: int) -> None:
        """Initialize a new empty group and its service structure."""
        groups[g_id] = []
        group_info[g_id] = {
            "last": None,       # (z, zr, zi) for ζ
            "prev": None,
            "v_avg": 0.0,
            "last_r": None,     # (z, rr, ri) for r
            "prev_r": None,
            "v_avg_r": 0.0,
            "gap_count": 0,
            "active": True,
        }

    def add_point_to_group(
        g_id: int,
        z_val: float,
        zr: float,
        zi: float,
        rr: float,
        ri: float,
        ra: float,
    ) -> None:
        """Append a point to a group and update last/prev and speeds for ζ and r."""
        groups[g_id].append((z_val, zr, zi, rr, ri, ra))
        info = group_info[g_id]

        # Update ζ
        if info["last"] is not None:
            info["prev"] = info["last"]
        info["last"] = (z_val, zr, zi)

        # Update r
        if info["last_r"] is not None:
            info["prev_r"] = info["last_r"]
        info["last_r"] = (z_val, rr, ri)

        info["gap_count"] = 0
        info["active"] = True

        # Update v_avg for ζ (exponential smoothing of the speed magnitude)
        if info["prev"] is not None:
            z_prev, zr_prev, zi_prev = info["prev"]
            dz = z_val - z_prev
            if abs(dz) > 1e-15:
                vx = (zr - zr_prev) / dz
                vy = (zi - zi_prev) / dz
                v_mod = np.hypot(vx, vy)
                info["v_avg"] = (1.0 - alpha) * info["v_avg"] + alpha * v_mod

        # Update v_avg_r for r
        if info["prev_r"] is not None:
            z_prev, rr_prev, ri_prev = info["prev_r"]
            dz = z_val - z_prev
            if abs(dz) > 1e-15:
                vx_r = (rr - rr_prev) / dz
                vy_r = (ri - ri_prev) / dz
                v_mod_r = np.hypot(vx_r, vy_r)
                info["v_avg_r"] = (1.0 - alpha) * info["v_avg_r"] + alpha * v_mod_r

    def start_group(
        z_val: float, zr: float, zi: float, rr: float, ri: float, ra: float
    ) -> int:
        """Open a new group seeded with one detection and return its id."""
        nonlocal next_group_id
        g_id = next_group_id
        next_group_id += 1
        init_group(g_id)
        add_point_to_group(g_id, z_val, zr, zi, rr, ri, ra)
        return g_id

    def age_group(g_id: int) -> None:
        """Register one missed layer; deactivate the group once max_gap is hit."""
        info = group_info[g_id]
        info["gap_count"] += 1
        if info["gap_count"] >= max_gap:
            info["active"] = False

    def predict_position(g_id: int, z_next: float) -> tuple[float, float]:
        """
        Predict (zr_pred, zi_pred) at z_next using linear extrapolation in ζ.
        If there is no prev point, return last.
        """
        info = group_info[g_id]
        last_pt = info["last"]
        prev_pt = info["prev"]
        if last_pt is None:
            return (0.0, 0.0)
        if prev_pt is None:
            return (last_pt[1], last_pt[2])
        z_last, zr_last, zi_last = last_pt
        z_prev, zr_prev, zi_prev = prev_pt
        dz = z_last - z_prev
        if abs(dz) < 1e-15:
            return (zr_last, zi_last)
        vx = (zr_last - zr_prev) / dz
        vy = (zi_last - zi_prev) / dz
        dt = z_next - z_last
        return (zr_last + vx * dt, zi_last + vy * dt)

    def predict_position_r(g_id: int, z_next: float) -> tuple[float, float]:
        """
        Predict (rr_pred, ri_pred) at z_next using linear extrapolation in r.
        If there is no prev_r point, return last_r.
        """
        info = group_info[g_id]
        last_pt_r = info["last_r"]
        prev_pt_r = info["prev_r"]
        if last_pt_r is None:
            return (0.0, 0.0)
        if prev_pt_r is None:
            return (last_pt_r[1], last_pt_r[2])
        z_last, rr_last, ri_last = last_pt_r
        z_prev, rr_prev, ri_prev = prev_pt_r
        dz = z_last - z_prev
        if abs(dz) < 1e-15:
            return (rr_last, ri_last)
        vx_r = (rr_last - rr_prev) / dz
        vy_r = (ri_last - ri_prev) / dz
        dt = z_next - z_last
        return (rr_last + vx_r * dt, ri_last + vy_r * dt)

    # -----------------------------------------------------------------------
    # 3) Main loop over z-layers
    # -----------------------------------------------------------------------
    for z_val in z_levels:
        current_slice = df_in[df_in["z"] == z_val]
        curr_zr = current_slice["zeta_real"].values
        curr_zi = current_slice["zeta_imag"].values
        curr_rr = current_slice["Re_r"].values
        curr_ri = current_slice["Im_r"].values
        curr_ra = current_slice["abs_r"].values
        n_curr = len(curr_zr)

        # Groups that may still receive detections (gap_count < max_gap).
        # Empty on the first layer.
        candidate_groups = [
            g for g, info in group_info.items() if info["gap_count"] < max_gap
        ]

        if not candidate_groups:
            # Nothing to continue: every detection starts a new group.
            # Groups outside the candidate list already have
            # gap_count >= max_gap, so there is nothing to age; in particular
            # the groups born here must start with gap_count == 0.
            for j in range(n_curr):
                start_group(
                    z_val, curr_zr[j], curr_zi[j], curr_rr[j], curr_ri[j], curr_ra[j]
                )
            continue

        # Build cost matrix (n_prev x n_curr), one row per candidate group
        n_prev = len(candidate_groups)
        cost_matrix = np.zeros((n_prev, n_curr), dtype=float)

        for ip, g_id in enumerate(candidate_groups):
            info = group_info[g_id]
            zr_pred, zi_pred = predict_position(g_id, z_val)
            rr_pred, ri_pred = predict_position_r(g_id, z_val)

            last_pt = info["last"]
            z_last = last_pt[0] if last_pt is not None else z_val
            dz = abs(z_val - z_last)

            if info["prev"] is None or info["prev_r"] is None:
                # Newborn (single point): both distances must be within birth_threshold
                thresh_z = thresh_r = birth_threshold
            else:
                # Adaptive per-plane gates; disabled (inf) while speed or step is ~0
                v_avg_z = info["v_avg"]
                v_avg_r = info["v_avg_r"]
                thresh_z = K * v_avg_z * dz if (v_avg_z > 1e-15 and dz > 1e-15) else np.inf
                thresh_r = K * v_avg_r * dz if (v_avg_r > 1e-15 and dz > 1e-15) else np.inf

            # Distances from the prediction to every detection, in ζ- and r-planes
            d_z = np.hypot(zr_pred - curr_zr, zi_pred - curr_zi)
            d_r = np.hypot(rr_pred - curr_rr, ri_pred - curr_ri)
            gated_out = (d_z > thresh_z) | (d_r > thresh_r)
            cost_matrix[ip] = np.where(gated_out, big_cost, d_z + weight_r * d_r)

        row_ind, col_ind = linear_sum_assignment(cost_matrix)

        used_groups: set[int] = set()
        used_points: set[int] = set()

        # Apply assignments; pairs at big_cost were gated out and are ignored
        for r_idx, c_idx in zip(row_ind, col_ind):
            if cost_matrix[r_idx, c_idx] >= big_cost:
                continue
            g_id = candidate_groups[r_idx]
            add_point_to_group(
                g_id, z_val,
                curr_zr[c_idx], curr_zi[c_idx],
                curr_rr[c_idx], curr_ri[c_idx], curr_ra[c_idx],
            )
            used_groups.add(g_id)
            used_points.add(c_idx)

        # Unmatched detections -> new groups (born with gap_count == 0)
        for c_idx in range(n_curr):
            if c_idx not in used_points:
                start_group(
                    z_val,
                    curr_zr[c_idx], curr_zi[c_idx],
                    curr_rr[c_idx], curr_ri[c_idx], curr_ra[c_idx],
                )

        # Age unmatched candidates. Non-candidates already have
        # gap_count >= max_gap and are left as they are.
        for g_id in candidate_groups:
            if g_id not in used_groups:
                age_group(g_id)

    # Return raw trajectories
    return groups


def sort_group_points(
    points: list[tuple[float, float, float, float, float, float]]
) -> tuple[np.ndarray, ...]:
    """
    Order a group's points by ascending z and split them into columns.

    Returns six 1-D float arrays ``(z, zr, zi, rr, ri, ra)``, all permuted
    by the same ``np.argsort`` of the z column.
    """
    table = np.array(points, dtype=object)
    order = np.argsort(table[:, 0].astype(float))
    return tuple(table[order, col].astype(float) for col in range(6))


# --- Run tracking and get groups ---
# Every keyword value below equals the function's default; the values are
# spelled out so the tunings used for this run are visible in one place.
groups = run_discrete_tracker_simple(
    df,
    unique_z,
    max_gap=4,
    K=2.0,
    alpha=0.5,
    big_cost=1e6,
    birth_threshold=0.001,
    weight_r=0.9,
)

# ---------------------------------------------------------------------------
# 4) Filter groups and build figures
# ---------------------------------------------------------------------------
min_group_length = 2   # minimal number of points per group
coord_precision = 6    # rounding used to drop duplicate points

# Kept groups, re-keyed consecutively from 0 in the iteration order of `groups`.
filtered_groups: dict[int, list] = {}
# Rounded (z, zr, zi) keys seen so far. Shared across ALL groups, so a point
# already seen in an earlier group is dropped from any later group. Keys are
# recorded even for groups that end up rejected by the final length check.
seen_points: set[tuple[float, float, float]] = set()

ind = 0
for g_id, pts in groups.items():
    # Length filter
    if len(pts) < min_group_length:
        continue

    # Deduplicate points by (z, zr, zi); the r-components are not part of the key
    unique_pts = []
    for pt in pts:
        pt_key = (
            round(pt[0], coord_precision),
            round(pt[1], coord_precision),
            round(pt[2], coord_precision),
        )
        if pt_key not in seen_points:
            seen_points.add(pt_key)
            unique_pts.append(pt)

    # Re-check length: deduplication may have shrunk the group
    if len(unique_pts) >= min_group_length:
        filtered_groups[ind] = unique_pts
        ind += 1

# ---------------------------------------------------------------------------
# 5) Plotly figures
# ---------------------------------------------------------------------------
colors = px.colors.qualitative.Vivid  # color palette (cycled via idx % len(colors))

# Initialize figures: five 2D plots against z and two 3D plots
figures: dict[str, go.Figure] = {
    "re_zeta": go.Figure(),
    "im_zeta": go.Figure(),
    "re_r": go.Figure(),
    "im_r": go.Figure(),
    "abs_r": go.Figure(),
    "3d_zeta": go.Figure(),
    "3d_r": go.Figure(),
}

# g_id here is the re-indexed key of filtered_groups, not the tracker's group id
for idx, (g_id, pts) in enumerate(filtered_groups.items()):
    z_vals, zr, zi, rr, ri, ra = sort_group_points(pts)
    color = colors[idx % len(colors)]

    # Common trace kwargs (x axis is z for every 2D plot)
    common_args = {
        "x": z_vals,
        "mode": "lines+markers",
        "marker": {"size": 3, "color": color},
        "line": {"color": color},
        "name": f"Branch {g_id}",
    }

    # 2D. |r| is recomputed from Re(r) and Im(r); the loaded abs_r column
    # (unpacked as `ra`) is not plotted.
    figures["re_zeta"].add_trace(go.Scatter(y=zr, **common_args))
    figures["im_zeta"].add_trace(go.Scatter(y=zi, **common_args))
    figures["re_r"].add_trace(go.Scatter(y=rr, **common_args))
    figures["im_r"].add_trace(go.Scatter(y=ri, **common_args))
    figures["abs_r"].add_trace(go.Scatter(y=np.sqrt(rr**2 + ri**2), **common_args))

    # 3D: x = real part, y = z, z = imaginary part
    for fig_name, coords in [("3d_zeta", (zr, zi)), ("3d_r", (rr, ri))]:
        figures[fig_name].add_trace(
            go.Scatter3d(
                x=coords[0],
                y=z_vals,
                z=coords[1],
                mode="lines+markers",
                marker={"size": 3, "color": color},
                line={"color": color},
                name=f"Branch {g_id}",
            )
        )

# Layout config: figure name -> (title, x title, y title, z title or None for 2D)
layout_config: dict[str, tuple[str, str, str, str | None]] = {
    "re_zeta": ("Re(ζ) vs z", "z", "Re(ζ)", None),
    "im_zeta": ("Im(ζ) vs z", "z", "Im(ζ)", None),
    "re_r": ("Re(r) vs z", "z", "Re(r)", None),
    "im_r": ("Im(r) vs z", "z", "Im(r)", None),
    "abs_r": ("|r| vs z", "z", "|r|", None),
    "3d_zeta": ("3D ζ-Space", "Re(ζ)", "z", "Im(ζ)"),
    "3d_r": ("3D r-Space", "Re(r)", "z", "Im(r)"),
}

for fig_name, fig in figures.items():
    title, x_title, y_title, z_title = layout_config[fig_name]
    # For 3D figures the scene axis titles are additionally set below.
    fig.update_layout(title=title, xaxis_title=x_title, yaxis_title=y_title)

    # Make 3D canvases large for high-quality exports
    if fig_name.startswith("3d"):
        fig.update_layout(
            autosize=False,
            width=1600,
            height=1000,
            margin=dict(l=10, r=10, t=40, b=10),
            scene=dict(
                xaxis_title=x_title,
                yaxis_title=y_title,
                zaxis_title=z_title,
            ),
        )

    fig.show()


In [4]:
"""
Kalman-based discrete-spectrum tracker with χ²-gating and Hungarian assignment.

Key additions vs the simple tracker:
- Adaptive Kalman filter with online tuning of R and Q using innovation statistics.
- χ² gating (Mahalanobis distance) to prune unlikely pairs (branch–detection).
- Birth confirmation: require M consecutive hits to confirm a new track.
- Continuation priority: optional cost shifts for continuous tracks; penalties
  for gaps and unconfirmed tracks.
- Whitening-based normalization across heterogeneous channels (ζ and r).

Notes
-----
* The input file name is kept as-is.
* Parameter values are preserved. Comments are English-only and for end users.
"""

from __future__ import annotations

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from scipy.optimize import linear_sum_assignment
from scipy.stats import chi2

# ---------------------------------------------------------------------------
# 1) Input data (file name is not changed)
# ---------------------------------------------------------------------------
FILENAME = (
    "SSFM_GLE_Conf=1_Cores=1_N=100000_M=8192_L=5000.00_T=1000.0_"
    "C=0.0_B=0.5_G=1.0_P=0.00000_tau=0.000_Chirp=0.0_omp=8_"
    "NFT_DiscreteSpectrum2.dat"
)
COLS = (
    "z x h Re_r Im_r T f "
    "abs_r abs_a abs_b abs_aprime "
    "ds_count energyRelativeError"
).split()
# Whitespace-separated table; the first line is skipped (presumably a header).
df = pd.read_csv(FILENAME, sep=r"\s+", skiprows=1, names=COLS)

# Same derived column names as the simple tracker: ζ real part from x,
# imaginary part from h.
df = df.assign(
    zeta_real=df["x"].astype(float),
    zeta_imag=df["h"].astype(float),
)
unique_z = np.sort(df["z"].unique())

# ---------------------------------------------------------------------------
# 2) Parameters and matrices
# ---------------------------------------------------------------------------
# State layout: positions [zr, zi, rr, ri] in entries 0:4, their velocities in 4:8.
STATE_DIM = 8
MEAS_DIM = 4  # y = [zr, zi, rr, ri]

# Working tunings (kept as provided)
# NOTE: apart from max_gap and eps_var, every value in this group and in the
# "New controls" group below is reassigned under "Overridden tuned values";
# only those later values take effect.
max_gap = 4
gate_p = 0.9
alpha_R = 0.5
alpha_Q = 0.001
eps_var = 1e-10

meas_w_zeta = 1.26
meas_w_r = 3.6
R_scale_zeta = 0.18
R_scale_r = 0.5
Q_pos_scale = 0.34
Q_vel_scale = 5.0
P0_pos_scale = 51.5
P0_vel_scale = 0.35
inv_reg = 1.35e-11

# New controls: birth confirmation and priorities
birth_confirm = 2           # consecutive hits required to confirm a track
new_track_penalty = 2.0     # additive cost penalty for unconfirmed tracks
gap_penalty = 0.25          # additive penalty per gap
cont_bonus = 1.0            # cost bonus (subtraction) for continuous tracks (gap == 0)

# Base covariances in unnormalized units (the tracker rescales them per
# channel by the squared whitening factors before use)
R_init = np.diag([1e-4, 1e-4, 1e-4, 1e-4]).astype(float)
Q_init = np.diag([1e-6, 1e-6, 1e-6, 1e-6, 1e-8, 1e-8, 1e-8, 1e-8]).astype(float)
P0 = np.diag([1e-3, 1e-3, 1e-3, 1e-3, 1e-2, 1e-2, 1e-2, 1e-2]).astype(float)

# Overridden tuned values (kept intact)
# NOTE(review): meas_w_zeta / meas_w_r scale both the innovation and its
# covariance in the tracker's gating, so they cancel out of the Mahalanobis
# distance; their tuned values likely have no effect.
gate_p = 0.997634510774889
alpha_R = 0.834963789425233
alpha_Q = 0.0015074814080782235
meas_w_zeta = 0.9627216831558346
meas_w_r = 0.13684269204623412
R_scale_zeta = 1.39407519019618
R_scale_r = 0.45258104260984716
Q_pos_scale = 0.011471121170672172
Q_vel_scale = 0.01692497814723947
P0_pos_scale = 0.0036664929339563805
P0_vel_scale = 0.16310269830426868
inv_reg = 5.406352247136302e-06
birth_confirm = 2
new_track_penalty = 6.110447779202005
gap_penalty = 0.06551007866202696
cont_bonus = 2.308143674570862

# Observation matrix: selects the four position entries of the state
H = np.zeros((MEAS_DIM, STATE_DIM), dtype=float)
H[0, 0] = 1.0  # zr
H[1, 1] = 1.0  # zi
H[2, 2] = 1.0  # rr
H[3, 3] = 1.0  # ri

# χ² gate for MEAS_DIM degrees of freedom.
# NOTE(review): run_discrete_tracker computes its own threshold from its
# gate_p argument; this module-level value is not referenced in the visible code.
gate_thresh = chi2.ppf(gate_p, df=MEAS_DIM)

# ---------------------------------------------------------------------------
# 3) Kalman filter + adaptations
# ---------------------------------------------------------------------------
def make_F(dt: float) -> np.ndarray:
    """
    Return the constant-velocity transition matrix for a step ``dt``.

    The state is ``[zr, zi, rr, ri, v_zr, v_zi, v_rr, v_ri]``: each position
    advances by ``dt`` times its velocity; velocities are carried over.
    """
    transition = np.eye(STATE_DIM, dtype=float)
    for pos in range(4):
        transition[pos, pos + 4] = dt
    return transition


def kf_predict(x: np.ndarray, P: np.ndarray, Q: np.ndarray, dt: float):
    """
    Kalman prediction step with dt-aware process covariance scaling.

    ``dt`` is clamped to at least 1e-12 *before* the transition matrix is
    built, so the state propagation and the process-noise scaling always use
    the same step.

    Parameters
    ----------
    x : numpy.ndarray
        State vector (positions in x[0:4], velocities in x[4:8]).
    P : numpy.ndarray
        State covariance.
    Q : numpy.ndarray
        Base process covariance; its position block is scaled by ``dt**2``
        and its velocity block by ``dt``.
    dt : float
        Step in z.

    Returns
    -------
    (x_pred, P_pred)
    """
    dt = float(max(dt, 1e-12))
    F = make_F(dt)
    Qeff = Q.copy()
    Qeff[0:4, 0:4] *= dt**2
    Qeff[4:8, 4:8] *= dt
    x_pred = F @ x
    P_pred = F @ P @ F.T + Qeff
    return x_pred, P_pred


def kf_update(
    x_pred: np.ndarray, P_pred: np.ndarray, y: np.ndarray, R: np.ndarray
):
    """
    Kalman update step.

    Parameters
    ----------
    x_pred, P_pred : numpy.ndarray
        Predicted state and covariance.
    y : numpy.ndarray
        Measurement ``[zr, zi, rr, ri]`` in the same units as the state.
    R : numpy.ndarray
        Measurement covariance.

    Returns
    -------
    (x_upd, P_upd, innov, S)
        Updated state and covariance, the innovation ``y - H x_pred`` and its
        symmetrized covariance ``S = H P_pred H^T + R``.

    Raises
    ------
    numpy.linalg.LinAlgError
        If ``S`` is singular.
    """
    S = H @ P_pred @ H.T + R
    S = (S + S.T) * 0.5
    # Gain K = P H^T S^-1, obtained by solving S K^T = H P^T rather than
    # forming an explicit inverse (S is exactly symmetric after the line above).
    K = np.linalg.solve(S, H @ P_pred.T).T
    innov = y - (H @ x_pred)
    x_upd = x_pred + K @ innov
    P_upd = (np.eye(STATE_DIM) - K @ H) @ P_pred
    P_upd = (P_upd + P_upd.T) * 0.5
    return x_upd, P_upd, innov, S


def adapt_R(R_old: np.ndarray, innov: np.ndarray, S: np.ndarray, alpha=alpha_R):
    """
    Blend the measurement covariance toward the latest innovation outer product.

    R_blend = (1 - alpha) * R_old + alpha * (innov innov^T + eps_var * I);
    the result keeps the diagonal of R_blend and halves its off-diagonal entries.

    ``S`` is accepted for interface compatibility but is not used.
    NOTE(review): innov innov^T is a sample of S = H P H^T + R rather than of
    R alone, so this estimate is biased upward — confirm this is intended.
    """
    sample = np.outer(innov, innov) + eps_var * np.eye(MEAS_DIM)
    R_blend = (1.0 - alpha) * R_old + alpha * sample
    diagonal_only = np.diag(np.diag(R_blend))
    # Averaging with the diagonal part leaves the diagonal unchanged and halves
    # the cross-terms on every call, damping off-diagonal drift.
    return 0.5 * (R_blend + diagonal_only)


def adapt_Q(
    Q_old: np.ndarray,
    v_prev: np.ndarray | None,
    v_new: np.ndarray,
    dt: float,
    alpha=alpha_Q,
):
    """
    Heuristic process-noise adaptation from finite-difference accelerations.

    With ``a = (v_new - v_prev) / dt`` and ``a2 = max(a**2, eps_var)`` per
    channel, Q is blended (weight ``alpha``) toward
    ``diag(1e-3 * a2, 1e-1 * a2)``; diagonal entries are then floored at 1e-12.

    Returns
    -------
    (Q_new, v_new)
        ``Q_old`` is returned unchanged when there is no previous velocity
        or ``dt <= 0``.
    """
    if v_prev is None or dt <= 0:
        return Q_old, v_new

    accel = (v_new - v_prev) / max(dt, 1e-12)
    accel_sq = np.maximum(accel * accel, eps_var)
    target = np.diag(np.concatenate([(1e-3 * accel_sq)[:4], (1e-1 * accel_sq)[:4]]))
    blended = (1.0 - alpha) * Q_old + alpha * target

    # Floor the variances so the filter never becomes degenerate
    for k in range(STATE_DIM):
        blended[k, k] = max(blended[k, k], 1e-12)
    return blended, v_new


def add_obs_to_group(
    g_id: int, z_val: float, y_raw: np.ndarray, abs_r: float, groups: dict
) -> None:
    """
    Record one raw detection in ``groups[g_id]``.

    ``y_raw`` must unpack into exactly four values (zr, zi, rr, ri); the
    stored record is ``(z_val, zr, zi, rr, ri, abs_r)``. Raises KeyError if
    ``g_id`` has no list in ``groups`` yet.
    """
    re_zeta, im_zeta, re_r, im_r = y_raw
    history = groups[g_id]
    history.append((z_val, re_zeta, im_zeta, re_r, im_r, abs_r))


# ---------------------------------------------------------------------------
# 4) Tracker runner (with whitening + birth confirmation + priorities)
# ---------------------------------------------------------------------------
def run_discrete_tracker(
    df_in: pd.DataFrame,
    z_levels: np.ndarray,
    max_gap: int = max_gap,
    gate_p: float = gate_p,
    alpha_R: float = alpha_R,
    alpha_Q: float = alpha_Q,
    meas_w_zeta: float = 1.0,
    meas_w_r: float = 1.0,
    R_scale_zeta: float = 1.0,
    R_scale_r: float = 1.0,
    Q_pos_scale: float = 1.0,
    Q_vel_scale: float = 1.0,
    P0_pos_scale: float = 1.0,
    P0_vel_scale: float = 1.0,
    max_cost: float = 1e12,
    inv_reg: float = 0.0,
    # New heuristics (tunable)
    birth_confirm: int = birth_confirm,
    new_track_penalty: float = new_track_penalty,
    gap_penalty: float = gap_penalty,
    cont_bonus: float = cont_bonus,
):
    """
    Run the Kalman-based tracker over all z-layers.

    Observations are whitened per channel with robust (MAD-based) scales.
    Each candidate track is predicted to the current layer, and
    track/detection pairs are gated with a χ² test on the Mahalanobis
    distance. The surviving pairs are assigned with the Hungarian algorithm.
    Unmatched detections start tentative tracks. A track is confirmed after
    ``birth_confirm`` consecutive hits. An unconfirmed track stops being
    tracked after ``birth_confirm`` consecutive misses; its points stay in
    the returned groups.

    Parameters
    ----------
    df_in : pandas.DataFrame
        Must provide "z", "zeta_real", "zeta_imag", "Re_r", "Im_r", "abs_r".
    z_levels : numpy.ndarray
        Layers to process, in order; each selects rows with ``df_in["z"] == z``.
    max_gap : int
        Tracks with ``gap >= max_gap`` no longer receive detections.
    gate_p : float
        χ² gate probability (MEAS_DIM degrees of freedom).
    alpha_R, alpha_Q : float
        Adaptation rates for R and Q.
    meas_w_zeta, meas_w_r : float
        Channel weights applied in the gating cost.
    R_scale_*, Q_*_scale, P0_*_scale : float
        Multipliers applied to a new track's R, Q and P.
    max_cost : float
        Cost assigned to gated-out pairs.
    inv_reg : float
        Ridge added to S before inversion in the gating cost.
    birth_confirm, new_track_penalty, gap_penalty, cont_bonus
        Track-confirmation rule and additive cost biases.

    Returns
    -------
    (groups, states_log)
        ``groups`` maps track id -> list of raw observations
        (z, zr, zi, rr, ri, |r|). ``states_log`` maps track id -> list of
        (z, velocity part of the state), in whitened units.
    """
    gate_thresh_local = chi2.ppf(gate_p, df=MEAS_DIM)

    # Robust per-channel scales (MAD) for whitening
    def _robust_scale(a: np.ndarray) -> float:
        a = np.asarray(a, float)
        med = np.median(a)
        mad = np.median(np.abs(a - med))
        s = 1.4826 * mad
        if not np.isfinite(s) or s < 1e-12:
            s = np.std(a)
        return float(s if s > 1e-12 else 1.0)

    szr = _robust_scale(df_in["zeta_real"].values)
    szi = _robust_scale(df_in["zeta_imag"].values)
    srr = _robust_scale(df_in["Re_r"].values)
    sri = _robust_scale(df_in["Im_r"].values)

    # Whitening vector
    w4 = np.array([1.0 / szr, 1.0 / szi, 1.0 / srr, 1.0 / sri], dtype=float)

    # Normalize base covariances (each velocity shares its channel's scale)
    R0n = np.diag((np.diag(R_init) * (w4**2)).astype(float))
    Q0n = Q_init.copy()
    P0n = P0.copy()
    for i in range(4):
        Q0n[i, i] *= (w4[i] ** 2)
        Q0n[i + 4, i + 4] *= (w4[i] ** 2)
        P0n[i, i] *= (w4[i] ** 2)
        P0n[i + 4, i + 4] *= (w4[i] ** 2)

    # Per-channel weights after whitening.
    # NOTE(review): W scales both the innovation and S in the gating below, so
    # innovw^T (W S W)^-1 innovw == innov^T S^-1 innov. The weights therefore
    # cancel out of the cost (up to rounding), and a zero weight makes S_w
    # singular. Confirm whether they were meant to scale only one side.
    W = np.diag(
        [float(meas_w_zeta), float(meas_w_zeta), float(meas_w_r), float(meas_w_r)]
    ).astype(float)

    # Local state
    groups_local: dict[int, list] = {}
    track_state: dict[int, dict] = {}  # per-track KF state and meta
    states_log: dict[int, list] = {}
    next_gid = 0

    def _apply_birth_scales(st: dict) -> None:
        """Apply the R/Q/P multipliers to a freshly created track."""
        st["R"][0, 0] *= R_scale_zeta
        st["R"][1, 1] *= R_scale_zeta
        st["R"][2, 2] *= R_scale_r
        st["R"][3, 3] *= R_scale_r
        for i in range(0, 4):
            st["Q"][i, i] *= Q_pos_scale
            st["P"][i, i] *= P0_pos_scale
        for i in range(4, 8):
            st["Q"][i, i] *= Q_vel_scale
            st["P"][i, i] *= P0_vel_scale

    def _start_track(z_val: float, y_norm: np.ndarray) -> int:
        """Create a track at rest at the whitened position y_norm; return its id."""
        nonlocal next_gid
        g = next_gid
        next_gid += 1
        x0 = np.zeros(STATE_DIM, float)
        x0[0:4] = y_norm
        st = dict(
            state=x0,
            P=P0n.copy(),
            R=R0n.copy(),
            Q=Q0n.copy(),
            last_z=z_val,
            gap=0,
            len=0,
            last_vel=None,
            hit_streak=0,
            miss_streak=0,
            confirmed=False,
        )
        _apply_birth_scales(st)
        track_state[g] = st
        groups_local[g] = []
        return g

    def _spawn(
        z_val: float, y_raw: np.ndarray, y_norm: np.ndarray, abs_r: float
    ) -> int:
        """Start a track from one detection and record it as the first hit."""
        g = _start_track(z_val, y_norm)
        add_obs_to_group(g, z_val, y_raw, abs_r, groups_local)
        st = track_state[g]
        st["len"] += 1
        st["hit_streak"] = 1
        if st["hit_streak"] >= birth_confirm:
            st["confirmed"] = True
        states_log.setdefault(g, []).append((z_val, st["state"][4:8].copy()))
        return g

    for z_val in z_levels:
        sl = df_in[df_in["z"] == z_val]
        zr = sl["zeta_real"].values.astype(float)
        zi = sl["zeta_imag"].values.astype(float)
        rr = sl["Re_r"].values.astype(float)
        ri = sl["Im_r"].values.astype(float)
        ra = sl["abs_r"].values.astype(float)

        Y_raw = np.stack([zr, zi, rr, ri], axis=1)         # raw observations
        Y_norm = (Y_raw * w4.reshape(1, 4)).astype(float)  # whitened
        n_curr = len(Y_norm)

        # Candidate tracks (not timed out); empty on the first layer
        cand = [g for g, st in track_state.items() if st["gap"] < max_gap]
        n_prev = len(cand)

        if n_prev == 0:
            # Snapshot the existing (timed-out) tracks before spawning, so only
            # they are aged. Tracks born in this layer start at gap 0, exactly
            # like births in the regular path below.
            timed_out = list(track_state)
            for j in range(n_curr):
                _spawn(z_val, Y_raw[j], Y_norm[j], ra[j])
            for g in timed_out:
                st = track_state[g]
                st["gap"] = min(st["gap"] + 1, max_gap + 1)
                st["miss_streak"] += 1
                st["hit_streak"] = 0
            continue

        # One prediction per candidate, reused for gating, update and misses
        preds: dict[int, tuple[np.ndarray, np.ndarray, float]] = {}
        for g in cand:
            st = track_state[g]
            dt = max(float(z_val - st["last_z"]), 1e-12)
            x_pred, P_pred = kf_predict(st["state"], st["P"], st["Q"], dt)
            preds[g] = (x_pred, P_pred, dt)

        # Cost matrix with χ²-gating and priority shifts
        cost = np.full((n_prev, n_curr), fill_value=max_cost, dtype=float)
        for ip, g in enumerate(cand):
            x_pred, P_pred, _ = preds[g]
            st = track_state[g]
            S = H @ P_pred @ H.T + st["R"]
            S = (S + S.T) * 0.5
            if inv_reg > 0.0:
                S = S + inv_reg * np.eye(MEAS_DIM)
            S_w = W @ S @ W
            Sinvw = np.linalg.inv(S_w)

            # Track-level bias: prefer continuous and confirmed tracks
            bias = 0.0
            if not st["confirmed"]:
                bias += new_track_penalty
            if st["gap"] > 0:
                bias += gap_penalty * float(st["gap"])
            else:
                bias -= cont_bonus

            for jc in range(n_curr):
                innovw = W @ (Y_norm[jc] - (H @ x_pred))
                m2 = float(innovw.T @ Sinvw @ innovw)
                if m2 <= gate_thresh_local:
                    # NOTE(review): clipping at 0 gives every in-gate pair with
                    # m2 <= -bias the same cost, so the solver cannot prefer
                    # the closer detection among them (linear_sum_assignment
                    # accepts negative costs). Kept: the tunings were fitted
                    # with this clip.
                    cost[ip, jc] = max(0.0, m2 + bias)

        row_ind, col_ind = linear_sum_assignment(cost)
        matched_prev: set[int] = set()
        matched_curr: set[int] = set()

        # Updates for assigned pairs
        for ip, jc in zip(row_ind, col_ind):
            if cost[ip, jc] >= 0.9 * max_cost:
                continue  # the solver was forced to use a gated-out pair
            g = cand[ip]
            st = track_state[g]
            x_pred, P_pred, dt = preds[g]

            x_upd, P_upd, innov, S = kf_update(x_pred, P_pred, Y_norm[jc], st["R"])
            R_new = adapt_R(st["R"], innov, S, alpha=alpha_R)

            v_new = x_upd[4:8].copy()
            Q_new, v_keep = adapt_Q(st["Q"], st["last_vel"], v_new, dt, alpha=alpha_Q)

            st["state"] = x_upd
            st["P"] = P_upd
            st["R"] = R_new
            st["Q"] = Q_new
            st["last_vel"] = v_keep
            st["last_z"] = z_val
            st["gap"] = 0
            st["len"] += 1
            st["hit_streak"] += 1
            st["miss_streak"] = 0
            if not st["confirmed"] and st["hit_streak"] >= birth_confirm:
                st["confirmed"] = True

            add_obs_to_group(g, z_val, Y_raw[jc], ra[jc], groups_local)
            states_log.setdefault(g, []).append((z_val, v_keep.copy()))
            matched_prev.add(g)
            matched_curr.add(jc)

        # New tracks from unmatched detections
        for jc in range(n_curr):
            if jc not in matched_curr:
                _spawn(z_val, Y_raw[jc], Y_norm[jc], ra[jc])

        # Process misses for unmatched candidates: coast on the prediction
        to_delete: list[int] = []
        for g in cand:
            if g in matched_prev:
                continue
            st = track_state[g]
            x_pred, P_pred, _ = preds[g]
            st["state"] = x_pred
            st["P"] = P_pred
            st["last_z"] = z_val
            st["gap"] = min(st["gap"] + 1, max_gap + 1)
            st["miss_streak"] += 1
            # A miss breaks the run of consecutive hits needed for confirmation
            st["hit_streak"] = 0
            states_log.setdefault(g, []).append((z_val, st["state"][4:8].copy()))
            # Drop weak temporary tracks that are not getting confirmed
            if (not st["confirmed"]) and (st["miss_streak"] >= birth_confirm):
                to_delete.append(g)

        for g in to_delete:
            track_state.pop(g, None)

    return groups_local, states_log


# === Run ===
# The tuned module-level values are passed explicitly because several keyword
# defaults of run_discrete_tracker (meas_w_*, *_scale, inv_reg) are neutral
# (1.0 / 0.0) rather than the tuned globals. max_cost keeps its default.
groups, states_log = run_discrete_tracker(
    df,
    unique_z,
    max_gap=max_gap,
    gate_p=gate_p,
    alpha_R=alpha_R,
    alpha_Q=alpha_Q,
    meas_w_zeta=meas_w_zeta,
    meas_w_r=meas_w_r,
    R_scale_zeta=R_scale_zeta,
    R_scale_r=R_scale_r,
    Q_pos_scale=Q_pos_scale,
    Q_vel_scale=Q_vel_scale,
    P0_pos_scale=P0_pos_scale,
    P0_vel_scale=P0_vel_scale,
    inv_reg=inv_reg,
    birth_confirm=birth_confirm,
    new_track_penalty=new_track_penalty,
    gap_penalty=gap_penalty,
    cont_bonus=cont_bonus,
)

# ---------------------------------------------------------------------------
# 5) Filtering and figures (same presentation as the simple tracker)
# ---------------------------------------------------------------------------
min_group_length = 2   # minimal number of points per group
coord_precision = 6    # decimals used to build the deduplication key

# Kept groups, re-keyed consecutively from 0 in the iteration order of `groups`.
filtered_groups: dict[int, list] = {}
# Rounded (z, zr, zi) keys shared across ALL groups: a point already seen in
# an earlier group is dropped from later ones. Keys are recorded even for
# groups rejected by the final length check.
seen_points: set[tuple[float, float, float]] = set()
ind = 0
for g_id, pts in groups.items():
    if len(pts) < min_group_length:
        continue
    unique_pts = []
    for pt in pts:
        pt_key = (
            round(pt[0], coord_precision),
            round(pt[1], coord_precision),
            round(pt[2], coord_precision),
        )
        if pt_key not in seen_points:
            seen_points.add(pt_key)
            unique_pts.append(pt)
    # Re-check length: deduplication may have shrunk the group
    if len(unique_pts) >= min_group_length:
        filtered_groups[ind] = unique_pts
        ind += 1


def sort_group_points(
    points: list[tuple[float, float, float, float, float, float]]
) -> tuple[np.ndarray, ...]:
    """
    Order a branch's points by z and split them into float columns.

    Returns ``(z, zr, zi, rr, ri, ra)``: six 1-D float arrays that share the
    z-ascending order produced by ``np.argsort``.
    """
    rows = np.array(points, dtype=object)
    by_z = np.argsort(rows[:, 0].astype(float))
    ordered = rows[by_z]
    z_col, zr_col, zi_col, rr_col, ri_col, ra_col = (
        ordered[:, k].astype(float) for k in range(6)
    )
    return z_col, zr_col, zi_col, rr_col, ri_col, ra_col


colors = px.colors.qualitative.Vivid  # palette, cycled via idx % len(colors)

# Five 2D plots against z and two 3D plots
figures: dict[str, go.Figure] = {
    "re_zeta": go.Figure(),
    "im_zeta": go.Figure(),
    "re_r": go.Figure(),
    "im_r": go.Figure(),
    "abs_r": go.Figure(),
    "3d_zeta": go.Figure(),
    "3d_r": go.Figure(),
}

# g_id here is the re-indexed key of filtered_groups, not the tracker's track id
for idx, (g_id, pts) in enumerate(filtered_groups.items()):
    z_vals, zr, zi, rr, ri, ra = sort_group_points(pts)
    color = colors[idx % len(colors)]

    # Shared 2D trace kwargs (x axis is z)
    common_args = {
        "x": z_vals,
        "mode": "lines+markers",
        "marker": {"size": 3, "color": color},
        "line": {"color": color},
        "name": f"Branch {g_id}",
    }

    # |r| is recomputed from Re(r) and Im(r); the loaded abs_r column (`ra`) is not plotted
    figures["re_zeta"].add_trace(go.Scatter(y=zr, **common_args))
    figures["im_zeta"].add_trace(go.Scatter(y=zi, **common_args))
    figures["re_r"].add_trace(go.Scatter(y=rr, **common_args))
    figures["im_r"].add_trace(go.Scatter(y=ri, **common_args))
    figures["abs_r"].add_trace(go.Scatter(y=np.sqrt(rr**2 + ri**2), **common_args))

    # 3D: x = real part, y = z, z = imaginary part
    for fig_name, coords in [("3d_zeta", (zr, zi)), ("3d_r", (rr, ri))]:
        figures[fig_name].add_trace(
            go.Scatter3d(
                x=coords[0],
                y=z_vals,
                z=coords[1],
                mode="lines+markers",
                marker={"size": 3, "color": color},
                line={"color": color},
                name=f"Branch {g_id}",
            )
        )

# figure name -> (title, x title, y title, z title or None for 2D)
layout_config: dict[str, tuple[str, str, str, str | None]] = {
    "re_zeta": ("Re(ζ) vs z", "z", "Re(ζ)", None),
    "im_zeta": ("Im(ζ) vs z", "z", "Im(ζ)", None),
    "re_r": ("Re(r) vs z", "z", "Re(r)", None),
    "im_r": ("Im(r) vs z", "z", "Im(r)", None),
    "abs_r": ("|r| vs z", "z", "|r|", None),
    "3d_zeta": ("3D ζ-Space", "Re(ζ)", "z", "Im(ζ)"),
    "3d_r": ("3D r-Space", "Re(r)", "z", "Im(r)"),
}

for fig_name, fig in figures.items():
    title, x_title, y_title, z_title = layout_config[fig_name]
    fig.update_layout(title=title, xaxis_title=x_title, yaxis_title=y_title)
    # 3D figures: large fixed canvas plus scene axis titles
    if fig_name.startswith("3d"):
        fig.update_layout(
            autosize=False,
            width=1600,
            height=1000,
            margin=dict(l=10, r=10, t=40, b=10),
            scene=dict(xaxis_title=x_title, yaxis_title=y_title, zaxis_title=z_title)
        )
    fig.show()
