In [2]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Set, Iterable
import math
import time

import numpy as np
from scipy.special import gammaincc
from scipy.spatial.distance import cdist
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

In [3]:
# ============================================================
# Configuration
# ============================================================

@dataclass
class ACLDConfig:
    """Parameter bundle for the ACLD swarm simulation.

    Groups timing, swarm size, kinematic limits, channel-model settings,
    adaptive-threshold gains, numerical guards, energy bookkeeping, jamming
    options and the deployment-area density mode. Derived quantities
    (linear-scale SNR values, node counts, ground-station index) are exposed
    as read-only properties.
    """

    # --- timing ---
    dt_ctrl: float = 0.05
    dt_sim: float = 0.05

    # --- swarm ---
    n_uavs: int = 150
    include_ground_station: bool = True
    ground_station_pos: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0], dtype=float))

    # --- physical limits ---
    v_max: float = 23.0
    a_max: float = 60.0
    j_max: float = 100.0
    d_safe: float = 2.0  # collision-avoidance minimum separation (example)

    # --- channel (default UAV-UAV) ---
    m_fading: float = 2.5
    snr_0_db: float = 22.0
    d_0: float = 100.0
    path_loss_exp: float = 2.4
    gamma_0_db: float = 10.0
    p_min: float = 0.90

    # --- UAV-GS channel profile (paper says Free-space LoS for UAV-GS) ---
    gs_m_fading: float = 3.5
    gs_snr_0_db: float = 25.0
    gs_path_loss_exp: float = 2.0

    # --- adaptive gains ---
    alpha: float = 0.35
    beta: float = 0.10
    gamma: float = 0.20
    delta: float = 0.10
    zeta: float = 0.05

    # --- base range ---
    cthr_base: float = 100.0

    # --- numerical guards ---
    min_cthr: float = 5.0
    max_cthr_multiplier: float = 2.0
    lambda2_partition_threshold: float = 1e-3

    # --- energy (optional simple simulation bookkeeping) ---
    e_max: float = 1000.0
    e_min_frac: float = 0.15

    # --- jamming / SINR params ---
    # Controlled stress mode: set jnr_db = None (nominal), or e.g. 5 / 15 dB
    jnr_db: Optional[float] = None
    # For detailed jammer model (optional):
    jammer_tx_dbm: Optional[float] = None
    jammer_pos: Optional[np.ndarray] = None
    jammer_gain_db: float = 0.0
    rx_gain_db: float = 0.0
    noise_figure_db: float = 6.0
    bandwidth_hz: float = 20e6

    # --- simulation area ---
    density_scale_mode: str = "medium"  # "sparse" | "medium" | "dense"
    seed: int = 42

    @staticmethod
    def _from_db(value_db: float) -> float:
        """Convert a decibel quantity to linear scale."""
        return 10 ** (value_db / 10.0)

    @property
    def snr_0_lin(self) -> float:
        """``snr_0_db`` expressed on a linear scale."""
        return self._from_db(self.snr_0_db)

    @property
    def gamma_0_lin(self) -> float:
        """``gamma_0_db`` expressed on a linear scale."""
        return self._from_db(self.gamma_0_db)

    @property
    def gs_snr_0_lin(self) -> float:
        """``gs_snr_0_db`` expressed on a linear scale."""
        return self._from_db(self.gs_snr_0_db)

    @property
    def n_total(self) -> int:
        """Number of graph nodes: the UAVs plus one if the ground station is enabled."""
        if self.include_ground_station:
            return self.n_uavs + 1
        return self.n_uavs + 0

    @property
    def gs_index(self) -> Optional[int]:
        """Node index of the ground station (placed after the UAVs), or None when disabled."""
        if not self.include_ground_station:
            return None
        return self.n_uavs

    def density_half_side(self) -> float:
        """Half-side length of the cubic deployment volume.

        The length is ``cthr_base`` scaled by 1.5 ("sparse"), 0.3 ("dense")
        or 0.8 for "medium" and any other mode string (case-insensitive).
        """
        scale_by_mode = {"sparse": 1.5, "dense": 0.3}
        scale = scale_by_mode.get(self.density_scale_mode.lower(), 0.8)
        return scale * self.cthr_base

In [4]:
# ============================================================
# State
# ============================================================

@dataclass
class SwarmState:
    """Per-UAV kinematic and bookkeeping arrays for the whole swarm."""

    pos: np.ndarray          # (n,3)
    vel: np.ndarray          # (n,3)
    acc_cmd: np.ndarray      # (n,3)
    jerk: np.ndarray         # (n,3)
    energy: np.ndarray       # (n,)
    channel_headroom_db: np.ndarray  # (n,)

    @staticmethod
    def random_init(cfg: ACLDConfig, rng: np.random.Generator) -> "SwarmState":
        """Draw a random initial state for ``cfg.n_uavs`` UAVs.

        Positions are uniform inside a cube of half-side
        ``cfg.density_half_side()``; velocity, commanded acceleration and jerk
        are Gaussian and then norm-limited to their configured maxima. Energy
        starts at ``cfg.e_max`` and channel headroom at zero.
        """
        count = cfg.n_uavs
        extent = cfg.density_half_side()
        shape = (count, 3)

        # The draw order (pos, vel, acc, jerk) determines the random stream.
        positions = rng.uniform(-extent, extent, size=shape)
        velocities = _clip_rows_norm(rng.normal(0.0, cfg.v_max / 4.0, size=shape), cfg.v_max)
        accelerations = _clip_rows_norm(rng.normal(0.0, cfg.a_max / 6.0, size=shape), cfg.a_max)
        jerks = _clip_rows_norm(rng.normal(0.0, cfg.j_max / 6.0, size=shape), cfg.j_max)

        full_energy = np.full(count, cfg.e_max, dtype=float)
        zero_headroom = np.full(count, 0.0, dtype=float)

        return SwarmState(
            pos=positions.astype(float),
            vel=velocities.astype(float),
            acc_cmd=accelerations.astype(float),
            jerk=jerks.astype(float),
            energy=full_energy.astype(float),
            channel_headroom_db=zero_headroom.astype(float),
        )

    def step_random_kinematics(self, cfg: ACLDConfig, rng: np.random.Generator) -> None:
        """
        Advance the swarm by one ``cfg.dt_sim`` step with a bounded random walk.

        Jerk receives Gaussian noise, then acceleration, velocity and position
        are integrated in that order, each norm-limited to its configured
        maximum. This is a placeholder Monte-Carlo model, not a flight
        controller. ``channel_headroom_db`` is not modified here.
        """
        step = cfg.dt_sim

        # Perturb jerk, then integrate down the chain jerk -> acc -> vel -> pos.
        perturbation = rng.normal(0.0, cfg.j_max / 20.0, size=self.jerk.shape)
        self.jerk = _clip_rows_norm(self.jerk + perturbation, cfg.j_max)
        self.acc_cmd = _clip_rows_norm(self.acc_cmd + self.jerk * step, cfg.a_max)
        self.vel = _clip_rows_norm(self.vel + self.acc_cmd * step, cfg.v_max)
        self.pos = self.pos + self.vel * step

        # Energy drain uses an arbitrary power proxy (not physically exact);
        # the remaining energy is floored at zero.
        speed_mag = np.linalg.norm(self.vel, axis=1)
        acc_mag = np.linalg.norm(self.acc_cmd, axis=1)
        drain_rate = 5.0 + 0.2 * speed_mag + 0.02 * acc_mag**2
        self.energy = np.maximum(0.0, self.energy - drain_rate * step)

    def copy(self) -> "SwarmState":
        """Return a new state whose arrays are independent copies of this one's."""
        array_names = ("pos", "vel", "acc_cmd", "jerk", "energy", "channel_headroom_db")
        return SwarmState(**{name: getattr(self, name).copy() for name in array_names})

In [5]:
# ============================================================
# Result containers
# ============================================================

@dataclass
class ACLDResult:
    """Everything produced by one call to ``ACLDCore.run_acld``."""

    # Predicted "guaranteed" adjacency: predicted distance <= theta_minus and
    # success probability >= p_min (uint8, includes the GS node when enabled).
    A_pred_guaranteed: np.ndarray
    # Predicted "viable" adjacency: predicted distance <= theta_plus and
    # success probability >= p_min (uint8, includes the GS node when enabled).
    A_pred_viable: np.ndarray
    # Pairwise UAV distances at the one-step predicted positions.
    d_pred_uav: np.ndarray
    # Pairwise UAV distances at the current positions.
    d_cur_uav: np.ndarray
    # Adaptive UAV-UAV thresholds: theta_minus is 0.9 * theta_plus.
    theta_minus_uav: np.ndarray
    theta_plus_uav: np.ndarray
    # Adaptive UAV-GS thresholds, one per UAV; None when the GS is disabled.
    theta_minus_gs: Optional[np.ndarray]
    theta_plus_gs: Optional[np.ndarray]
    # UAV-UAV link success probabilities at predicted / current distances.
    s_pred_uav: np.ndarray
    s_cur_uav: np.ndarray
    # Node sets found by BFS on the guaranteed adjacency, and node -> set index.
    components: List[Set[int]]
    comp_of: Dict[int, int]
    # Viable links in the risk band (theta_minus < d <= theta_plus) whose
    # endpoints lie in different (inter) or the same (intra) component.
    risky_inter: List[Tuple[int, int]]
    risky_intra: List[Tuple[int, int]]
    # Split of risky_inter: links touching the GS index vs. all others.
    risky_s_uav: List[Tuple[int, int]]
    risky_uav_uav: List[Tuple[int, int]]
    # Second-smallest Laplacian eigenvalue (floored at 0) of the current
    # weighted graph and of the binary predicted-viable graph.
    lambda2_current: float
    lambda2_pred_viable: float
    # Symmetrized weight matrix and Laplacian of the current viable graph.
    W_current: np.ndarray
    L_current: np.ndarray

In [None]:
# ============================================================
# Core ACLD
# ============================================================

class ACLDCore:
    def __init__(self, cfg: ACLDConfig):
        """Keep a reference to the configuration that the other methods read via ``self.cfg``."""
        self.cfg = cfg

    # --------------------------
    # Geometry utilities
    # --------------------------
    def _pairwise_uav_dist(self, pos: np.ndarray) -> np.ndarray:
        return cdist(pos, pos, metric="euclidean")

    def _uav_gs_dist(self, pos: np.ndarray) -> np.ndarray:
        gs = self.cfg.ground_station_pos.reshape(1, 3)
        return np.linalg.norm(pos - gs, axis=1)

    def _dispersion_stats(self, d_uav: np.ndarray) -> Tuple[float, float]:
        n = d_uav.shape[0]
        iu = np.triu_indices(n, k=1)
        vals = d_uav[iu]
        if vals.size == 0:
            return 0.0, 1.0
        bar_d = float(np.mean(vals))
        sigma_d = float(np.std(vals))
        return sigma_d, max(bar_d, 1e-9)

    def _heading_diff_matrix(self, vel: np.ndarray) -> np.ndarray:
        """
        Absolute angle difference in [0, pi] using 3D vectors.
        If a vector norm is ~0, angle term defaults to 0 for that pair.
        """
        n = vel.shape[0]
        norms = np.linalg.norm(vel, axis=1)
        H = np.zeros((n, n), dtype=float)

        for i in range(n):
            for j in range(i + 1, n):
                ni, nj = norms[i], norms[j]
                if ni < 1e-9 or nj < 1e-9:
                    ang = 0.0
                else:
                    c = float(np.dot(vel[i], vel[j]) / (ni * nj))
                    c = np.clip(c, -1.0, 1.0)
                    ang = float(np.arccos(c))
                H[i, j] = H[j, i] = ang
        return H

    # --------------------------
    # Adaptive thresholds
    # --------------------------
    def adaptive_threshold_uav_uav(
        self,
        state: SwarmState,
        d_uav: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Returns cthr_ij, theta_minus_ij, theta_plus_ij
        """
        cfg = self.cfg
        n = cfg.n_uavs

        sigma_d, bar_d = self._dispersion_stats(d_uav)
        rel_speed = cdist(state.vel, state.vel, metric="euclidean")
        rel_acc = cdist(state.acc_cmd, state.acc_cmd, metric="euclidean")
        rel_jerk = cdist(state.jerk, state.jerk, metric="euclidean")
        dtheta = self._heading_diff_matrix(state.vel)

        term = (
            1.0
            + cfg.alpha * (sigma_d / bar_d)
            - cfg.beta * (rel_speed / (2.0 * cfg.v_max + 1e-12))
            - cfg.gamma * (dtheta / math.pi)
            - cfg.delta * (rel_acc / (2.0 * cfg.a_max + 1e-12))
            - cfg.zeta * (rel_jerk / (2.0 * cfg.j_max + 1e-12))
        )
        cthr = cfg.cthr_base * term

        cthr = np.clip(cthr, cfg.min_cthr, cfg.max_cthr_multiplier * cfg.cthr_base)
        np.fill_diagonal(cthr, 0.0)

        theta_plus = cthr
        theta_minus = 0.9 * cthr
        np.fill_diagonal(theta_minus, 0.0)
        np.fill_diagonal(theta_plus, 0.0)

        return cthr, theta_minus, theta_plus

    def adaptive_threshold_uav_gs(self, state: SwarmState, d_uav: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Eq. cthradp_gs in paper (no heading term).
        Returns cthr_is, theta_minus_is, theta_plus_is of shape (n_uavs,)
        """
        cfg = self.cfg
        sigma_d, bar_d = self._dispersion_stats(d_uav)

        speed = np.linalg.norm(state.vel, axis=1)
        acc = np.linalg.norm(state.acc_cmd, axis=1)
        jerk = np.linalg.norm(state.jerk, axis=1)

        term = (
            1.0
            + cfg.alpha * (sigma_d / bar_d)
            - cfg.beta * (speed / (cfg.v_max + 1e-12))
            - cfg.delta * (acc / (cfg.a_max + 1e-12))
            - cfg.zeta * (jerk / (cfg.j_max + 1e-12))
        )
        cthr = cfg.cthr_base * term
        cthr = np.clip(cthr, cfg.min_cthr, cfg.max_cthr_multiplier * cfg.cthr_base)
        theta_plus = cthr
        theta_minus = 0.9 * cthr
        return cthr, theta_minus, theta_plus

    # --------------------------
    # Channel model (Nakagami-m)
    # --------------------------
    def _nb_dbm(self) -> float:
        cfg = self.cfg
        return -174.0 + 10.0 * math.log10(cfg.bandwidth_hz) + cfg.noise_figure_db

    def _detailed_jnr_db_per_receiver(self, pos: np.ndarray) -> Optional[np.ndarray]:
        cfg = self.cfg
        if cfg.jammer_tx_dbm is None or cfg.jammer_pos is None:
            return None

        nb_dbm = self._nb_dbm()
        d = np.linalg.norm(pos - cfg.jammer_pos.reshape(1, 3), axis=1)
        d = np.maximum(d, 1.0)

        # Simple log-distance path loss (consistent enough for stress simulation)
        pl_db = 20.0 * np.log10(4 * np.pi * cfg.d_0 / 0.0517) + 10.0 * cfg.path_loss_exp * np.log10(d / cfg.d_0)
        p_rx_j_dbm = cfg.jammer_tx_dbm + cfg.jammer_gain_db + cfg.rx_gain_db - pl_db
        jnr_db = p_rx_j_dbm - nb_dbm
        return jnr_db

    def _receiver_jnr_lin(self, pos: np.ndarray) -> Optional[np.ndarray]:
        cfg = self.cfg

        if cfg.jnr_db is not None:
            # controlled receiver-referenced stress (paper-friendly)
            return np.full(cfg.n_uavs, 10 ** (cfg.jnr_db / 10.0), dtype=float)

        jnr_db_per_rx = self._detailed_jnr_db_per_receiver(pos)
        if jnr_db_per_rx is None:
            return None
        return 10 ** (jnr_db_per_rx / 10.0)

    def _nakagami_success_from_mean_snr(self, mean_snr: np.ndarray, m: float) -> np.ndarray:
        cfg = self.cfg
        x = m * (cfg.gamma_0_lin / np.maximum(mean_snr, 1e-12))
        # Γ(m,x)/Γ(m) = gammaincc(m,x)
        s = gammaincc(m, x)
        return np.clip(s, 0.0, 1.0)

    def success_probability_uav_uav(self, d_uav: np.ndarray, pos_for_jammer: Optional[np.ndarray] = None) -> np.ndarray:
        cfg = self.cfg
        d = np.maximum(d_uav, 1e-9)

        mean_snr = cfg.snr_0_lin * (cfg.d_0 / d) ** cfg.path_loss_exp

        # Jamming-aware SINR if configured
        if pos_for_jammer is None:
            pos_for_jammer = None

        jnr_lin_per_rx = None
        if pos_for_jammer is not None:
            jnr_lin_per_rx = self._receiver_jnr_lin(pos_for_jammer)

        if jnr_lin_per_rx is not None:
            # conservative link-level j_lin_ij = max(j_i, j_j)
            ji = jnr_lin_per_rx.reshape(-1, 1)
            jj = jnr_lin_per_rx.reshape(1, -1)
            j_lin_ij = np.maximum(ji, jj)
            mean_snr = mean_snr / (1.0 + j_lin_ij)

        s = self._nakagami_success_from_mean_snr(mean_snr, cfg.m_fading)
        np.fill_diagonal(s, 0.0)
        return s

    def success_probability_uav_gs(self, d_gs: np.ndarray, pos: np.ndarray) -> np.ndarray:
        cfg = self.cfg
        d = np.maximum(d_gs, 1e-9)

        mean_snr = cfg.gs_snr_0_lin * (cfg.d_0 / d) ** cfg.gs_path_loss_exp

        jnr_lin_per_rx = self._receiver_jnr_lin(pos)
        if jnr_lin_per_rx is not None:
            # receiver is UAV i for UAV-GS link
            mean_snr = mean_snr / (1.0 + jnr_lin_per_rx)

        s = self._nakagami_success_from_mean_snr(mean_snr, cfg.gs_m_fading)
        return np.clip(s, 0.0, 1.0)

    # --------------------------
    # Predicted adjacency / risk sets
    # --------------------------
    def _predict_positions_one_step(self, state: SwarmState) -> np.ndarray:
        return state.pos + state.vel * self.cfg.dt_ctrl

    def _build_pred_matrices(
        self, state: SwarmState
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Returns:
          d_cur_uav, d_pred_uav,
          theta_minus_uav, theta_plus_uav,
          s_cur_uav, s_pred_uav,
          d_cur_gs, d_pred_gs, theta_minus_gs, theta_plus_gs
        """
        d_cur_uav = self._pairwise_uav_dist(state.pos)
        pred_pos = self._predict_positions_one_step(state)
        d_pred_uav = self._pairwise_uav_dist(pred_pos)

        _, theta_minus_uav, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)

        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)
        s_pred_uav = self.success_probability_uav_uav(d_pred_uav, pos_for_jammer=pred_pos)

        if self.cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            d_pred_gs = self._uav_gs_dist(pred_pos)
            _, theta_minus_gs, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
        else:
            d_cur_gs = d_pred_gs = theta_minus_gs = theta_plus_gs = None

        return (
            d_cur_uav, d_pred_uav,
            theta_minus_uav, theta_plus_uav,
            s_cur_uav, s_pred_uav,
            d_cur_gs, d_pred_gs, theta_minus_gs, theta_plus_gs
        )

    def _build_pred_adjacency(
        self,
        d_pred_uav: np.ndarray,
        theta_minus_uav: np.ndarray,
        theta_plus_uav: np.ndarray,
        s_pred_uav: np.ndarray,
        d_pred_gs: Optional[np.ndarray],
        theta_minus_gs: Optional[np.ndarray],
        theta_plus_gs: Optional[np.ndarray],
        s_pred_gs: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build predicted guaranteed and viable adjacency on V' (including GS if enabled).
        """
        cfg = self.cfg
        n_uav = cfg.n_uavs
        n_total = cfg.n_total

        A_minus = np.zeros((n_total, n_total), dtype=np.uint8)
        A_plus = np.zeros((n_total, n_total), dtype=np.uint8)

        # UAV-UAV edges
        viable_uav = (d_pred_uav <= theta_plus_uav) & (s_pred_uav >= cfg.p_min)
        guar_uav = (d_pred_uav <= theta_minus_uav) & (s_pred_uav >= cfg.p_min)

        np.fill_diagonal(viable_uav, False)
        np.fill_diagonal(guar_uav, False)

        A_plus[:n_uav, :n_uav] = viable_uav.astype(np.uint8)
        A_minus[:n_uav, :n_uav] = guar_uav.astype(np.uint8)

        # UAV-GS edges
        if cfg.include_ground_station:
            assert d_pred_gs is not None and theta_plus_gs is not None and theta_minus_gs is not None and s_pred_gs is not None
            sidx = cfg.gs_index
            viable_gs = (d_pred_gs <= theta_plus_gs) & (s_pred_gs >= cfg.p_min)
            guar_gs = (d_pred_gs <= theta_minus_gs) & (s_pred_gs >= cfg.p_min)

            for i in range(n_uav):
                if viable_gs[i]:
                    A_plus[i, sidx] = 1
                    A_plus[sidx, i] = 1
                if guar_gs[i]:
                    A_minus[i, sidx] = 1
                    A_minus[sidx, i] = 1

        return A_minus, A_plus

    # --------------------------
    # Pseudocode-equivalent graph routines
    # --------------------------
    @staticmethod
    def _adj_list_from_adjacency(A: np.ndarray) -> List[List[int]]:
        n = A.shape[0]
        adj = []
        for i in range(n):
            adj.append(np.flatnonzero(A[i]).tolist())
        return adj

    @staticmethod
    def kneighbors(adj: List[List[int]], seed: int, k_max: int) -> Set[int]:
        seen: Set[int] = {seed}
        q: List[Tuple[int, int]] = [(seed, 0)]
        head = 0
        while head < len(q):
            u, depth = q[head]
            head += 1
            if depth == k_max:
                continue
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    q.append((v, depth + 1))
        return seen

    @classmethod
    def predict_components(cls, A_guaranteed: np.ndarray, k_max: Optional[int] = None) -> Tuple[List[Set[int]], Dict[int, int]]:
        n = A_guaranteed.shape[0]
        if k_max is None:
            k_max = n  # full BFS as in manuscript note

        adj = cls._adj_list_from_adjacency(A_guaranteed)

        unseen = set(range(n))
        comps: List[Set[int]] = []
        comp_of: Dict[int, int] = {}

        while unseen:
            i = next(iter(unseen))
            C = cls.kneighbors(adj, i, k_max=k_max)
            idx = len(comps)
            for u in C:
                comp_of[u] = idx
            comps.append(C)
            unseen -= C

        return comps, comp_of

    @staticmethod
    def catalogue_risky_links(
        A_pred_viable: np.ndarray,
        comp_of: Dict[int, int],
        cfg: ACLDConfig,
        d_pred_uav: np.ndarray,
        theta_minus_uav: np.ndarray,
        theta_plus_uav: np.ndarray,
        d_pred_gs: Optional[np.ndarray],
        theta_minus_gs: Optional[np.ndarray],
        theta_plus_gs: Optional[np.ndarray],
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
        """
        Risk-band = viable but not guaranteed (theta_minus < d <= theta_plus, and viable already ensured).
        Returns (inter, intra)
        """
        n_total = A_pred_viable.shape[0]
        n_uav = cfg.n_uavs
        sidx = cfg.gs_index

        inter: List[Tuple[int, int]] = []
        intra: List[Tuple[int, int]] = []

        for i in range(n_total):
            for j in range(i + 1, n_total):
                if A_pred_viable[i, j] == 0:
                    continue

                in_risk_band = False

                if i < n_uav and j < n_uav:
                    dij = d_pred_uav[i, j]
                    if theta_minus_uav[i, j] < dij <= theta_plus_uav[i, j]:
                        in_risk_band = True
                else:
                    # UAV-GS case
                    if sidx is None:
                        continue
                    if d_pred_gs is None or theta_minus_gs is None or theta_plus_gs is None:
                        continue
                    u = i if j == sidx else j if i == sidx else None
                    if u is not None and 0 <= u < n_uav:
                        du = d_pred_gs[u]
                        if theta_minus_gs[u] < du <= theta_plus_gs[u]:
                            in_risk_band = True

                if not in_risk_band:
                    continue

                if comp_of.get(i, -1) != comp_of.get(j, -1):
                    inter.append((i, j))
                else:
                    intra.append((i, j))

        return inter, intra

    @staticmethod
    def separate_ground_links(
        risky_inter: List[Tuple[int, int]],
        gs_index: Optional[int]
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
        if gs_index is None:
            return [], risky_inter.copy()

        s_uav: List[Tuple[int, int]] = []
        uav_uav: List[Tuple[int, int]] = []

        for i, j in risky_inter:
            if i == gs_index or j == gs_index:
                s_uav.append((i, j))
            else:
                uav_uav.append((i, j))
        return s_uav, uav_uav

    # --------------------------
    # Weighted graph / Laplacian / lambda2
    # --------------------------
    def _build_current_weighted_graph(self, state: SwarmState) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Build W, D and L on current channel-viable graph G_viable(t), including GS if enabled.
        Uses geometric candidates + viability guard + directed row-stochastic inverse-square weights + symmetrization.

        A neighbour is viable when its current distance is within theta_plus
        and its success probability is at least cfg.p_min. Each node's
        outgoing weights are proportional to 1/d^2 (with d floored at
        cfg.d_safe) and normalised to sum to one; nodes without viable
        neighbours keep an all-zero row.

        Returns (W, D, L): the symmetrized weight matrix, the diagonal degree
        matrix of its row sums, and the Laplacian D - W, all of size
        cfg.n_total.
        """
        cfg = self.cfg
        n_uav = cfg.n_uavs
        n_total = cfg.n_total
        sidx = cfg.gs_index

        # Current-state distances, thresholds and success probabilities.
        d_cur_uav = self._pairwise_uav_dist(state.pos)
        _, theta_minus_uav, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)
        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)

        d_cur_gs = theta_minus_gs = theta_plus_gs = s_cur_gs = None
        if cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            _, theta_minus_gs, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
            s_cur_gs = self.success_probability_uav_gs(d_cur_gs, state.pos)

        # Directed (not yet symmetric) weights; row i holds node i's weights.
        W_tilde = np.zeros((n_total, n_total), dtype=float)

        # UAV rows
        for i in range(n_uav):
            # (neighbour index, distance floored at d_safe)
            neighbors: List[Tuple[int, float]] = []

            # UAV-UAV viable neighbors
            for j in range(n_uav):
                if i == j:
                    continue
                if d_cur_uav[i, j] <= theta_plus_uav[i, j] and s_cur_uav[i, j] >= cfg.p_min:
                    neighbors.append((j, max(d_cur_uav[i, j], cfg.d_safe)))

            # UAV-GS viable neighbor
            if cfg.include_ground_station and d_cur_gs is not None and theta_plus_gs is not None and s_cur_gs is not None:
                if d_cur_gs[i] <= theta_plus_gs[i] and s_cur_gs[i] >= cfg.p_min:
                    neighbors.append((sidx, max(d_cur_gs[i], cfg.d_safe)))

            # Isolated node: leave its row at zero.
            if len(neighbors) == 0:
                continue

            # Inverse-square weights normalised so the row sums to one.
            invsq = np.array([1.0 / (d * d) for _, d in neighbors], dtype=float)
            denom = float(np.sum(invsq))
            if denom <= 0:
                continue

            for (k, _d), w in zip(neighbors, invsq / denom):
                W_tilde[i, k] = float(w)

        # GS row (if present)
        if cfg.include_ground_station and sidx is not None and d_cur_gs is not None and theta_plus_gs is not None and s_cur_gs is not None:
            neighbors_gs: List[Tuple[int, float]] = []
            for i in range(n_uav):
                if d_cur_gs[i] <= theta_plus_gs[i] and s_cur_gs[i] >= cfg.p_min:
                    neighbors_gs.append((i, max(d_cur_gs[i], cfg.d_safe)))

            if neighbors_gs:
                invsq = np.array([1.0 / (d * d) for _, d in neighbors_gs], dtype=float)
                denom = float(np.sum(invsq))
                if denom > 0:
                    for (i, _d), w in zip(neighbors_gs, invsq / denom):
                        W_tilde[sidx, i] = float(w)

        # Symmetrize
        W = 0.5 * (W_tilde + W_tilde.T)
        np.fill_diagonal(W, 0.0)

        # Degree matrix from row sums, then the combinatorial Laplacian.
        D = np.diag(np.sum(W, axis=1))
        L = D - W
        return W, D, L

    @staticmethod
    def algebraic_connectivity_lambda2(L: np.ndarray) -> float:
        n = L.shape[0]
        if n <= 1:
            return 0.0
        # Dense eigvalsh is acceptable up to n~1000 in this context
        evals = np.linalg.eigvalsh(L)
        evals = np.sort(np.real(evals))
        if len(evals) < 2:
            return 0.0
        return float(max(0.0, evals[1]))

    # --------------------------
    # Baselines
    # --------------------------
    @staticmethod
    def tarjan_bridges(A: np.ndarray) -> List[Tuple[int, int]]:
        n = A.shape[0]
        adj = ACLDCore._adj_list_from_adjacency(A)

        disc = [-1] * n
        low = [-1] * n
        parent = [-1] * n
        time_counter = 0
        bridges: List[Tuple[int, int]] = []

        def dfs(u: int) -> None:
            nonlocal time_counter
            disc[u] = low[u] = time_counter
            time_counter += 1

            for v in adj[u]:
                if disc[v] == -1:
                    parent[v] = u
                    dfs(v)
                    low[u] = min(low[u], low[v])

                    if low[v] > disc[u]:
                        bridges.append((u, v) if u < v else (v, u))
                elif v != parent[u]:
                    low[u] = min(low[u], disc[v])

        for i in range(n):
            if disc[i] == -1:
                dfs(i)

        bridges = sorted(set(bridges))
        return bridges

    @staticmethod
    def edge_betweenness_brandes_unweighted(A: np.ndarray) -> Dict[Tuple[int, int], float]:
        """
        Exact Brandes edge betweenness for unweighted undirected graph.
        """
        n = A.shape[0]
        adj = ACLDCore._adj_list_from_adjacency(A)
        eb: Dict[Tuple[int, int], float] = {}

        for s in range(n):
            stack: List[int] = []
            pred: List[List[int]] = [[] for _ in range(n)]
            sigma = np.zeros(n, dtype=float)
            dist = -np.ones(n, dtype=int)

            sigma[s] = 1.0
            dist[s] = 0

            # BFS
            q = [s]
            qh = 0
            while qh < len(q):
                v = q[qh]
                qh += 1
                stack.append(v)
                for w in adj[v]:
                    if dist[w] < 0:
                        q.append(w)
                        dist[w] = dist[v] + 1
                    if dist[w] == dist[v] + 1:
                        sigma[w] += sigma[v]
                        pred[w].append(v)

            delta = np.zeros(n, dtype=float)
            while stack:
                w = stack.pop()
                for v in pred[w]:
                    if sigma[w] > 0:
                        c = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                    else:
                        c = 0.0
                    e = (v, w) if v < w else (w, v)
                    eb[e] = eb.get(e, 0.0) + c
                    delta[v] += c

        # undirected graph correction
        for e in list(eb.keys()):
            eb[e] *= 0.5
        return eb

    def baseline_tarjan_snapshot(self, state: SwarmState, with_guard: bool = True) -> List[Tuple[int, int]]:
        """
        TARJAN@t: snapshot bridges on current graph.
        If with_guard=False, geometry only (cthr+ band) without probabilistic guard.
        """
        cfg = self.cfg
        n_total = cfg.n_total
        n_uav = cfg.n_uavs
        sidx = cfg.gs_index

        d_cur_uav = self._pairwise_uav_dist(state.pos)
        _, _, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)
        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)

        A = np.zeros((n_total, n_total), dtype=np.uint8)

        mask_uav = (d_cur_uav <= theta_plus_uav)
        if with_guard:
            mask_uav &= (s_cur_uav >= cfg.p_min)
        np.fill_diagonal(mask_uav, False)
        A[:n_uav, :n_uav] = mask_uav.astype(np.uint8)

        if cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            _, _, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
            s_cur_gs = self.success_probability_uav_gs(d_cur_gs, state.pos)

            for i in range(n_uav):
                ok = d_cur_gs[i] <= theta_plus_gs[i]
                if with_guard:
                    ok = ok and (s_cur_gs[i] >= cfg.p_min)
                if ok:
                    A[i, sidx] = A[sidx, i] = 1

        return self.tarjan_bridges(A)

    # --------------------------
    # Main ACLD step
    # --------------------------
    def run_acld(self, state: SwarmState) -> ACLDResult:
        """Run one full ACLD evaluation on ``state`` and bundle the outputs.

        Steps: build current/predicted distances, thresholds and success
        probabilities; derive the predicted guaranteed and viable adjacency;
        group nodes on the guaranteed graph; catalogue risk-band links and
        split off those touching the ground station; finally compute the
        algebraic connectivity of the current weighted graph and, as a
        diagnostic, of the binary predicted-viable graph.
        """
        cfg = self.cfg

        (
            dist_now, dist_next,
            theta_lo, theta_hi,
            succ_now, succ_next,
            _gs_dist_now, gs_dist_next, gs_theta_lo, gs_theta_hi,
        ) = self._build_pred_matrices(state)

        # GS success probabilities at the predicted positions (GS enabled only).
        gs_succ_next = None
        if cfg.include_ground_station and gs_dist_next is not None:
            gs_succ_next = self.success_probability_uav_gs(
                gs_dist_next, self._predict_positions_one_step(state)
            )

        adj_guaranteed, adj_viable = self._build_pred_adjacency(
            d_pred_uav=dist_next,
            theta_minus_uav=theta_lo,
            theta_plus_uav=theta_hi,
            s_pred_uav=succ_next,
            d_pred_gs=gs_dist_next,
            theta_minus_gs=gs_theta_lo,
            theta_plus_gs=gs_theta_hi,
            s_pred_gs=gs_succ_next,
        )

        groups, group_of = self.predict_components(adj_guaranteed, k_max=cfg.n_total)

        links_inter, links_intra = self.catalogue_risky_links(
            A_pred_viable=adj_viable,
            comp_of=group_of,
            cfg=cfg,
            d_pred_uav=dist_next,
            theta_minus_uav=theta_lo,
            theta_plus_uav=theta_hi,
            d_pred_gs=gs_dist_next,
            theta_minus_gs=gs_theta_lo,
            theta_plus_gs=gs_theta_hi,
        )
        links_gs, links_uav = self.separate_ground_links(links_inter, cfg.gs_index)

        weights_now, _degree_now, laplacian_now = self._build_current_weighted_graph(state)
        lambda2_now = self.algebraic_connectivity_lambda2(laplacian_now)

        # predicted viable Laplacian (binary) for diagnostic only
        degree_pred = np.diag(np.sum(adj_viable.astype(float), axis=1))
        laplacian_pred = degree_pred - adj_viable.astype(float)
        lambda2_pred = self.algebraic_connectivity_lambda2(laplacian_pred)

        return ACLDResult(
            A_pred_guaranteed=adj_guaranteed,
            A_pred_viable=adj_viable,
            d_pred_uav=dist_next,
            d_cur_uav=dist_now,
            theta_minus_uav=theta_lo,
            theta_plus_uav=theta_hi,
            theta_minus_gs=gs_theta_lo,
            theta_plus_gs=gs_theta_hi,
            s_pred_uav=succ_next,
            s_cur_uav=succ_now,
            components=groups,
            comp_of=group_of,
            risky_inter=links_inter,
            risky_intra=links_intra,
            risky_s_uav=links_gs,
            risky_uav_uav=links_uav,
            lambda2_current=lambda2_now,
            lambda2_pred_viable=lambda2_pred,
            W_current=weights_now,
            L_current=laplacian_now,
        )

In [None]:
# ============================================================
# Simulation / Evaluation helpers
# ============================================================

def _clip_rows_norm(X: np.ndarray, max_norm: float) -> np.ndarray:
    """Rescale the rows of ``X`` whose Euclidean norm exceeds ``max_norm``.

    Rows with norm at or below ``max_norm`` are multiplied by 1 (left as is);
    every other row is multiplied by ``max_norm / norm``. A new array is
    returned; ``X`` itself is not modified.
    """
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    too_long = row_norms > max_norm

    # Start from the identity factor and overwrite only the offending rows.
    # The 1e-12 floor keeps the division away from a zero denominator.
    factors = np.ones_like(row_norms)
    np.divide(max_norm, np.maximum(row_norms, 1e-12), out=factors, where=too_long)
    return X * factors


def make_rng(seed: int) -> np.random.Generator:
    """Return a new NumPy ``Generator`` created by ``np.random.default_rng(seed)``."""
    generator = np.random.default_rng(seed)
    return generator


@dataclass
class StepLog:
    """Diagnostics recorded for one simulation step by ``run_episode``."""

    # Simulation clock at this step: starts at 0.0 and grows by cfg.dt_sim per step.
    t: float
    # ``lambda2_current`` reported by the ACLD result for this step.
    lambda2_current: float
    # True when lambda2_current is below cfg.lambda2_partition_threshold.
    partition_now: bool
    # Number of entries in the result's ``components`` list.
    n_components_pred_guar: int
    # Sizes of the result's ``risky_inter`` / ``risky_intra`` link lists.
    n_risky_inter: int
    n_risky_intra: int
    # Sizes of the ``risky_s_uav`` / ``risky_uav_uav`` lists (the split of
    # ``risky_inter`` produced by ``separate_ground_links``).
    n_risky_s_uav: int
    n_risky_uav_uav: int
    # Wall-clock duration of the ``run_acld`` call, in milliseconds.
    runtime_ms: float


def run_episode(
    cfg: ACLDConfig,
    horizon_s: float = 10.0,
    seed: Optional[int] = None,
) -> Tuple[List[StepLog], List[ACLDResult]]:
    """Roll out one random-kinematics episode, running ACLD at every step.

    Args:
        cfg: Run configuration; this function reads ``dt_sim``, ``seed`` and
            ``lambda2_partition_threshold`` from it and passes it on to the
            state and the ACLD core.
        horizon_s: Episode length; the number of steps is
            ``round(horizon_s / cfg.dt_sim)``.
        seed: RNG seed. ``None`` falls back to ``cfg.seed``.

    Returns:
        ``(logs, results)``: one ``StepLog`` and the matching ACLD result per
        step, in time order.
    """
    # An explicit seed wins; cfg.seed is only consulted when none is given.
    effective_seed = seed if seed is not None else cfg.seed
    rng = make_rng(effective_seed)
    state = SwarmState.random_init(cfg, rng)
    detector = ACLDCore(cfg)

    n_steps = int(round(horizon_s / cfg.dt_sim))
    logs: List[StepLog] = []
    results: List[ACLDResult] = []

    sim_time = 0.0
    for _step in range(n_steps):
        # Only the ACLD call itself is timed, not logging or state propagation.
        tic = time.perf_counter()
        outcome = detector.run_acld(state)
        elapsed_ms = (time.perf_counter() - tic) * 1000.0

        entry = StepLog(
            t=sim_time,
            lambda2_current=outcome.lambda2_current,
            partition_now=outcome.lambda2_current < cfg.lambda2_partition_threshold,
            n_components_pred_guar=len(outcome.components),
            n_risky_inter=len(outcome.risky_inter),
            n_risky_intra=len(outcome.risky_intra),
            n_risky_s_uav=len(outcome.risky_s_uav),
            n_risky_uav_uav=len(outcome.risky_uav_uav),
            runtime_ms=elapsed_ms,
        )
        logs.append(entry)
        results.append(outcome)

        # Move the swarm forward before the next evaluation.
        state.step_random_kinematics(cfg, rng)
        sim_time += cfg.dt_sim

    return logs, results


def summarize_episode(logs: List[StepLog]) -> Dict[str, float]:
    """Collapse a list of per-step logs into episode-level statistics.

    Returns a dict with the mean of ``lambda2_current``, the count and rate of
    steps flagged ``partition_now``, and the median / 25th / 75th percentile
    of ``runtime_ms``. Float statistics are reported as ``0.0`` when there is
    nothing to aggregate; ``partition_count`` is an ``int``.
    """

    def _stat(values: np.ndarray, reducer) -> float:
        # Guard against reducing an empty array (which would yield NaN).
        return float(reducer(values)) if len(values) else 0.0

    lambda2_series = np.array([entry.lambda2_current for entry in logs], dtype=float)
    partition_flags = np.array([entry.partition_now for entry in logs], dtype=bool)
    runtime_series = np.array([entry.runtime_ms for entry in logs], dtype=float)

    summary = {}
    summary["mean_lambda2"] = _stat(lambda2_series, np.mean)
    summary["partition_count"] = int(np.sum(partition_flags))
    summary["partition_rate"] = _stat(partition_flags, np.mean)
    summary["runtime_median_ms"] = _stat(runtime_series, np.median)
    summary["runtime_iqr_low_ms"] = _stat(runtime_series, lambda v: np.percentile(v, 25))
    summary["runtime_iqr_high_ms"] = _stat(runtime_series, lambda v: np.percentile(v, 75))
    return summary


# ============================================================
# Simple benchmark (runtime vs n)
# ============================================================

def benchmark_runtime_vs_n(
    base_cfg: ACLDConfig,
    sizes: Iterable[int] = (50, 150, 500, 1000),
    repeats: int = 200,
) -> List[Dict[str, float]]:
    """Time ``ACLDCore.run_acld`` for several swarm sizes.

    For every ``n`` in ``sizes`` a copy of ``base_cfg`` with ``n_uavs = n`` is
    built, a random swarm is initialised from the seed ``cfg.seed + n``, and
    ``run_acld`` is timed ``repeats`` times with one random kinematics step
    between consecutive measurements.

    Args:
        base_cfg: Template configuration; it is not modified.
        sizes: Swarm sizes (number of UAVs) to benchmark.
        repeats: Timed calls per size; must be at least 1.

    Returns:
        One dict per size with keys ``n``, ``pairs`` (``n * (n - 1) // 2``),
        ``median_ms``, ``q1_ms``, ``q3_ms``, ``mean_ms`` and ``p95_ms``.

    Raises:
        ValueError: If ``repeats`` is smaller than 1 (the statistics would be
            undefined on an empty sample).
    """
    # Function-scope import: the notebook's import cell only brings in
    # ``dataclass`` and ``field`` from ``dataclasses``.
    from dataclasses import replace

    if repeats < 1:
        raise ValueError(f"repeats must be >= 1, got {repeats}")

    rows: List[Dict[str, float]] = []

    for n in sizes:
        n_uavs = int(n)
        # replace() forwards only the declared init fields and re-runs
        # __init__/__post_init__. Splatting ``__dict__`` into the constructor
        # raises TypeError as soon as the instance carries any attribute that
        # is not an init field (e.g. one derived in __post_init__).
        cfg = replace(base_cfg, n_uavs=n_uavs)
        rng = make_rng(cfg.seed + n_uavs)
        state = SwarmState.random_init(cfg, rng)
        acld = ACLDCore(cfg)

        times_ms = np.zeros(repeats, dtype=float)
        for k in range(repeats):
            t0 = time.perf_counter()
            _ = acld.run_acld(state)
            t1 = time.perf_counter()
            times_ms[k] = (t1 - t0) * 1000.0
            # small random motion to avoid exact cache repetition
            state.step_random_kinematics(cfg, rng)

        rows.append(
            {
                "n": n_uavs,
                "pairs": n_uavs * (n_uavs - 1) // 2,
                "median_ms": float(np.median(times_ms)),
                "q1_ms": float(np.percentile(times_ms, 25)),
                "q3_ms": float(np.percentile(times_ms, 75)),
                "mean_ms": float(np.mean(times_ms)),
                "p95_ms": float(np.percentile(times_ms, 95)),
            }
        )

    return rows


# ============================================================
# Optional: one-step evaluation helper for PEW/FP style labels
# ============================================================

def connectivity_components_from_adjacency(A: np.ndarray) -> Tuple[int, np.ndarray]:
    """Count the connected components of the undirected graph described by ``A``.

    Any non-zero entry ``A[i, j]`` is treated as an edge, so boolean, 0/1 and
    weighted adjacency matrices are all accepted.

    Args:
        A: Square adjacency matrix.

    Returns:
        ``(n_components, labels)`` where ``labels[k]`` is the component index
        of node ``k`` as assigned by ``scipy.sparse.csgraph.connected_components``.
    """
    # Binarise on "non-zero" before the int8 cast: casting the raw values
    # truncates fractional weights to 0 and wraps large ones, which silently
    # removes edges from the graph.
    structure = (np.asarray(A) != 0).astype(np.int8)
    G = csr_matrix(structure)
    n_comp, labels = connected_components(G, directed=False, return_labels=True)
    return int(n_comp), labels


def edge_deletion_disconnects(A: np.ndarray, edge: Tuple[int, int]) -> bool:
    """Tell whether removing ``edge`` raises the number of connected components.

    Returns ``False`` right away when ``A[i, j]`` is 0 (no such edge). The
    input matrix is left untouched; the deletion is applied to a copy, in both
    directions.
    """
    u, v = edge
    if A[u, v] == 0:
        return False

    baseline, _ = connectivity_components_from_adjacency(A)

    pruned = A.copy()
    # Clear (u, v) and (v, u) in one shot.
    pruned[[u, v], [v, u]] = 0
    remaining, _ = connectivity_components_from_adjacency(pruned)

    return remaining > baseline


def approximate_edge_f1_on_predicted_graph(
    A_pred_viable: np.ndarray,
    flagged_edges: List[Tuple[int, int]],
    sample_size: int = 200,
    rng: Optional[np.random.Generator] = None,
) -> Dict[str, float]:
    """
    Approximate edge-level F1 by sampling edges and testing edge deletion connectivity impact.

    An edge of ``A_pred_viable`` (entry equal to 1, upper triangle) is
    "truly critical" when ``edge_deletion_disconnects`` reports that removing
    it increases the number of connected components. At most ``sample_size``
    edges are tested; ``flagged_edges`` is then scored against that ground
    truth.

    Args:
        A_pred_viable: Square adjacency matrix of the predicted viable graph.
        flagged_edges: Edges reported as risky, as ``(i, j)`` pairs in either
            orientation.
        sample_size: Maximum number of graph edges to test.
        rng: Generator used for sampling; defaults to a generator seeded
            with 123.

    Returns:
        Dict with ``precision``, ``recall`` and ``f1`` (all ``0.0`` when the
        graph has no edges or a denominator is zero).
    """
    if rng is None:
        rng = np.random.default_rng(123)

    n = A_pred_viable.shape[0]
    edges = [(i, j) for i in range(n) for j in range(i + 1, n) if A_pred_viable[i, j] == 1]
    if not edges:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    if len(edges) > sample_size:
        idx = rng.choice(len(edges), size=sample_size, replace=False)
        sampled = [edges[k] for k in idx]
    else:
        sampled = edges

    true_critical = set(e for e in sampled if edge_deletion_disconnects(A_pred_viable, e))

    # Edges are keyed as (low, high); normalise the flagged pairs the same way
    # so that a link reported as (j, i) still matches.
    flagged = {(min(int(a), int(b)), max(int(a), int(b))) for a, b in flagged_edges}
    # Graph edges that were not sampled have no ground truth, so they must not
    # be scored: counting them as false positives deflates precision whenever
    # sampling is active. Without sampling this set is empty.
    unsampled = set(edges) - set(sampled)
    flagged -= unsampled

    tp = len(true_critical & flagged)
    fp = len(flagged - true_critical)
    fn = len(true_critical - flagged)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# ============================================================
# Demo main
# ============================================================

# Demo driver: runs one short episode, prints its summary and the last step's
# diagnostics, then runs a reduced runtime benchmark. Everything is printed to
# stdout; nothing is returned or saved.
if __name__ == "__main__":
    # Example 1: single episode demo
    cfg = ACLDConfig(
        n_uavs=50,
        include_ground_station=True,
        density_scale_mode="medium",
        jnr_db=None,   # nominal
        seed=42,
    )

    # The explicit seed argument takes precedence over cfg.seed (both are 42 here).
    logs, results = run_episode(cfg, horizon_s=2.0, seed=42)
    summary = summarize_episode(logs)

    print("=== ACLD Episode Summary ===")
    for k, v in summary.items():
        print(f"{k}: {v}")

    # Inspect the ACLD result of the final simulated step.
    last = results[-1]
    print("\n=== Last-step diagnostics ===")
    print(f"lambda2_current      : {last.lambda2_current:.6f}")
    print(f"lambda2_pred_viable  : {last.lambda2_pred_viable:.6f}")
    print(f"pred components (guar): {len(last.components)}")
    print(f"risky_inter          : {len(last.risky_inter)}")
    print(f"risky_intra          : {len(last.risky_intra)}")
    print(f"risky_s_uav          : {len(last.risky_s_uav)}")
    print(f"risky_uav_uav        : {len(last.risky_uav_uav)}")

    # Example 2: runtime benchmark (small repeats for demo)
    # Reuses cfg as the template; only n_uavs is varied per size.
    print("\n=== Runtime benchmark (demo repeats=20) ===")
    bench_rows = benchmark_runtime_vs_n(cfg, sizes=(50, 150, 300), repeats=20)
    for row in bench_rows:
        print(row) 