In [6]:
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Optional, Set, Iterable
import math
import time

import os
import csv
import json
import copy
import matplotlib.pyplot as plt

import numpy as np
from scipy.special import gammaincc
from scipy.spatial.distance import cdist
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import connected_components

In [3]:
# ============================================================
# Configuration
# ============================================================

@dataclass
class ACLDConfig:
    """All tunable parameters of the ACLD simulation in one place.

    Decibel-valued fields are exposed in linear scale through the ``*_lin``
    properties; node counts and the ground-station index are derived from
    ``n_uavs`` and ``include_ground_station``.
    """

    # --- timing ---
    dt_ctrl: float = 0.05
    dt_sim: float = 0.05

    # --- swarm ---
    n_uavs: int = 150
    include_ground_station: bool = True
    ground_station_pos: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 0.0], dtype=float))

    # --- physical limits ---
    v_max: float = 23.0
    a_max: float = 60.0
    j_max: float = 100.0
    d_safe: float = 2.0  # collision-avoidance minimum separation (example)

    # --- channel (default UAV-UAV) ---
    m_fading: float = 2.5
    snr_0_db: float = 22.0
    d_0: float = 100.0
    path_loss_exp: float = 2.4
    gamma_0_db: float = 10.0
    p_min: float = 0.90

    # --- UAV-GS channel profile (paper says Free-space LoS for UAV-GS) ---
    gs_m_fading: float = 3.5
    gs_snr_0_db: float = 25.0
    gs_path_loss_exp: float = 2.0

    # --- adaptive gains ---
    alpha: float = 0.35
    beta: float = 0.10
    gamma: float = 0.20
    delta: float = 0.10
    zeta: float = 0.05

    # --- base range ---
    cthr_base: float = 100.0

    # --- numerical guards ---
    min_cthr: float = 5.0
    max_cthr_multiplier: float = 2.0
    lambda2_partition_threshold: float = 1e-3

    # --- energy (optional simple simulation bookkeeping) ---
    e_max: float = 1000.0
    e_min_frac: float = 0.15

    # --- jamming / SINR params ---
    # Controlled stress mode: set jnr_db = None (nominal), or e.g. 5 / 15 dB
    jnr_db: Optional[float] = None
    # For detailed jammer model (optional):
    jammer_tx_dbm: Optional[float] = None
    jammer_pos: Optional[np.ndarray] = None
    jammer_gain_db: float = 0.0
    rx_gain_db: float = 0.0
    noise_figure_db: float = 6.0
    bandwidth_hz: float = 20e6

    # --- simulation area ---
    density_scale_mode: str = "medium"  # "sparse" | "medium" | "dense"
    seed: int = 42

    @staticmethod
    def _db_to_lin(value_db: float) -> float:
        """Convert a decibel ratio to linear scale."""
        return 10 ** (value_db / 10.0)

    @property
    def snr_0_lin(self) -> float:
        """UAV-UAV reference SNR (``snr_0_db``) in linear scale."""
        return self._db_to_lin(self.snr_0_db)

    @property
    def gamma_0_lin(self) -> float:
        """Decoding SNR threshold (``gamma_0_db``) in linear scale."""
        return self._db_to_lin(self.gamma_0_db)

    @property
    def gs_snr_0_lin(self) -> float:
        """UAV-GS reference SNR (``gs_snr_0_db``) in linear scale."""
        return self._db_to_lin(self.gs_snr_0_db)

    @property
    def n_total(self) -> int:
        """Number of graph nodes: the UAVs plus one if the ground station is enabled."""
        return self.n_uavs + int(bool(self.include_ground_station))

    @property
    def gs_index(self) -> Optional[int]:
        """Node index of the ground station (placed after all UAVs), or None if disabled."""
        if self.include_ground_station:
            return self.n_uavs
        return None

    def density_half_side(self) -> float:
        """Half side length of the cubic initialisation volume.

        The length is ``cthr_base`` scaled by a factor chosen from
        ``density_scale_mode`` (case-insensitive). Any value other than
        "sparse" or "dense" uses the medium factor.
        """
        scale_by_mode = {"sparse": 1.5, "dense": 0.3}
        scale = scale_by_mode.get(self.density_scale_mode.lower(), 0.8)  # 0.8 == medium
        return scale * self.cthr_base

In [4]:
# ============================================================
# State
# ============================================================

@dataclass
class SwarmState:
    """Per-UAV kinematic and bookkeeping arrays of the swarm (one row per UAV)."""

    pos: np.ndarray          # (n,3)
    vel: np.ndarray          # (n,3)
    acc_cmd: np.ndarray      # (n,3)
    jerk: np.ndarray         # (n,3)
    energy: np.ndarray       # (n,)
    channel_headroom_db: np.ndarray  # (n,)

    @staticmethod
    def random_init(cfg: ACLDConfig, rng: np.random.Generator) -> "SwarmState":
        """Draw a random initial swarm state.

        Positions are uniform in a cube of half side ``cfg.density_half_side()``.
        Velocity, acceleration and jerk are Gaussian and then norm-clipped to
        their physical limits. Energy starts at ``cfg.e_max`` and channel
        headroom at zero.
        """
        count = cfg.n_uavs
        extent = cfg.density_half_side()
        shape = (count, 3)

        # The draw order (pos, vel, acc, jerk) determines the random stream.
        positions = rng.uniform(-extent, extent, size=shape)
        velocities = _clip_rows_norm(rng.normal(0.0, cfg.v_max / 4.0, size=shape), cfg.v_max)
        accelerations = _clip_rows_norm(rng.normal(0.0, cfg.a_max / 6.0, size=shape), cfg.a_max)
        jerks = _clip_rows_norm(rng.normal(0.0, cfg.j_max / 6.0, size=shape), cfg.j_max)

        return SwarmState(
            pos=positions.astype(float),
            vel=velocities.astype(float),
            acc_cmd=accelerations.astype(float),
            jerk=jerks.astype(float),
            energy=np.full(count, cfg.e_max, dtype=float),
            channel_headroom_db=np.full(count, 0.0, dtype=float),
        )

    def step_random_kinematics(self, cfg: ACLDConfig, rng: np.random.Generator) -> None:
        """
        Advance the state in place by one ``cfg.dt_sim`` step of bounded random motion.

        This is a placeholder Monte-Carlo motion model, not a flight controller.
        """
        step = cfg.dt_sim

        # Jerk performs a clipped random walk.
        perturbation = rng.normal(0.0, cfg.j_max / 20.0, size=self.jerk.shape)
        self.jerk = _clip_rows_norm(self.jerk + perturbation, cfg.j_max)

        # Integrate jerk -> acceleration -> velocity -> position, clipping each
        # quantity to its limit before it feeds the next one.
        self.acc_cmd = _clip_rows_norm(self.acc_cmd + self.jerk * step, cfg.a_max)
        self.vel = _clip_rows_norm(self.vel + self.acc_cmd * step, cfg.v_max)
        self.pos = self.pos + self.vel * step

        # Energy drain uses an arbitrary power proxy (not physically exact) and is floored at zero.
        speed_mag = np.linalg.norm(self.vel, axis=1)
        acc_mag = np.linalg.norm(self.acc_cmd, axis=1)
        drain_rate = 5.0 + 0.2 * speed_mag + 0.02 * acc_mag**2
        self.energy = np.maximum(0.0, self.energy - drain_rate * step)

    def copy(self) -> "SwarmState":
        """Return a new state whose arrays are independent copies of this one's."""
        array_fields = ("pos", "vel", "acc_cmd", "jerk", "energy", "channel_headroom_db")
        return SwarmState(**{name: getattr(self, name).copy() for name in array_fields})

In [5]:
# ============================================================
# Result containers
# ============================================================

@dataclass
class ACLDResult:
    """Everything produced by one call to ``ACLDCore.run_acld``."""

    # Predicted adjacency over all nodes (UAVs plus the ground station, if enabled), uint8 0/1.
    # "guaranteed" uses the inner threshold theta_minus, "viable" the outer theta_plus;
    # both also require the predicted success probability to reach cfg.p_min.
    A_pred_guaranteed: np.ndarray
    A_pred_viable: np.ndarray
    # Pairwise UAV-UAV distances at the one-step predicted and at the current positions.
    d_pred_uav: np.ndarray
    d_cur_uav: np.ndarray
    # Adaptive UAV-UAV range thresholds (theta_minus = 0.9 * theta_plus, zero diagonal).
    theta_minus_uav: np.ndarray
    theta_plus_uav: np.ndarray
    # Adaptive UAV-GS range thresholds, one per UAV; None when the ground station is disabled.
    theta_minus_gs: Optional[np.ndarray]
    theta_plus_gs: Optional[np.ndarray]
    # UAV-UAV link success probabilities at predicted and current distances (zero diagonal).
    s_pred_uav: np.ndarray
    s_cur_uav: np.ndarray
    # Node sets found on A_pred_guaranteed, and the node -> component-index map.
    components: List[Set[int]]
    comp_of: Dict[int, int]
    # Viable links in the risk band (theta_minus < d <= theta_plus), as (i, j) with i < j:
    # "inter" joins two different predicted components, "intra" lies inside one.
    risky_inter: List[Tuple[int, int]]
    risky_intra: List[Tuple[int, int]]
    # risky_inter split by whether the ground station is one of the endpoints.
    risky_s_uav: List[Tuple[int, int]]
    risky_uav_uav: List[Tuple[int, int]]
    # Second-smallest Laplacian eigenvalue of the current weighted graph.
    lambda2_current: float
    # Same quantity for the binary Laplacian of A_pred_viable (diagnostic only).
    lambda2_pred_viable: float
    # Symmetrised weight matrix of the current viable graph and its Laplacian (D - W).
    W_current: np.ndarray
    L_current: np.ndarray

In [None]:
# ============================================================
# Core ACLD
# ============================================================

class ACLDCore:
    def __init__(self, cfg: ACLDConfig):
        self.cfg = cfg

    # --------------------------
    # Geometry utilities
    # --------------------------
    def _pairwise_uav_dist(self, pos: np.ndarray) -> np.ndarray:
        return cdist(pos, pos, metric="euclidean")

    def _uav_gs_dist(self, pos: np.ndarray) -> np.ndarray:
        gs = self.cfg.ground_station_pos.reshape(1, 3)
        return np.linalg.norm(pos - gs, axis=1)

    def _dispersion_stats(self, d_uav: np.ndarray) -> Tuple[float, float]:
        n = d_uav.shape[0]
        iu = np.triu_indices(n, k=1)
        vals = d_uav[iu]
        if vals.size == 0:
            return 0.0, 1.0
        bar_d = float(np.mean(vals))
        sigma_d = float(np.std(vals))
        return sigma_d, max(bar_d, 1e-9)

    def _heading_diff_matrix(self, vel: np.ndarray) -> np.ndarray:
        """
        Absolute angle difference in [0, pi] using 3D vectors.
        If a vector norm is ~0, angle term defaults to 0 for that pair.
        """
        n = vel.shape[0]
        norms = np.linalg.norm(vel, axis=1)
        H = np.zeros((n, n), dtype=float)

        for i in range(n):
            for j in range(i + 1, n):
                ni, nj = norms[i], norms[j]
                if ni < 1e-9 or nj < 1e-9:
                    ang = 0.0
                else:
                    c = float(np.dot(vel[i], vel[j]) / (ni * nj))
                    c = np.clip(c, -1.0, 1.0)
                    ang = float(np.arccos(c))
                H[i, j] = H[j, i] = ang
        return H

    # --------------------------
    # Adaptive thresholds
    # --------------------------
    def adaptive_threshold_uav_uav(
        self,
        state: SwarmState,
        d_uav: np.ndarray,
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Returns cthr_ij, theta_minus_ij, theta_plus_ij
        """
        cfg = self.cfg
        n = cfg.n_uavs

        sigma_d, bar_d = self._dispersion_stats(d_uav)
        rel_speed = cdist(state.vel, state.vel, metric="euclidean")
        rel_acc = cdist(state.acc_cmd, state.acc_cmd, metric="euclidean")
        rel_jerk = cdist(state.jerk, state.jerk, metric="euclidean")
        dtheta = self._heading_diff_matrix(state.vel)

        term = (
            1.0
            + cfg.alpha * (sigma_d / bar_d)
            - cfg.beta * (rel_speed / (2.0 * cfg.v_max + 1e-12))
            - cfg.gamma * (dtheta / math.pi)
            - cfg.delta * (rel_acc / (2.0 * cfg.a_max + 1e-12))
            - cfg.zeta * (rel_jerk / (2.0 * cfg.j_max + 1e-12))
        )
        cthr = cfg.cthr_base * term

        cthr = np.clip(cthr, cfg.min_cthr, cfg.max_cthr_multiplier * cfg.cthr_base)
        np.fill_diagonal(cthr, 0.0)

        theta_plus = cthr
        theta_minus = 0.9 * cthr
        np.fill_diagonal(theta_minus, 0.0)
        np.fill_diagonal(theta_plus, 0.0)

        return cthr, theta_minus, theta_plus

    def adaptive_threshold_uav_gs(self, state: SwarmState, d_uav: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Eq. cthradp_gs in paper (no heading term).
        Returns cthr_is, theta_minus_is, theta_plus_is of shape (n_uavs,)
        """
        cfg = self.cfg
        sigma_d, bar_d = self._dispersion_stats(d_uav)

        speed = np.linalg.norm(state.vel, axis=1)
        acc = np.linalg.norm(state.acc_cmd, axis=1)
        jerk = np.linalg.norm(state.jerk, axis=1)

        term = (
            1.0
            + cfg.alpha * (sigma_d / bar_d)
            - cfg.beta * (speed / (cfg.v_max + 1e-12))
            - cfg.delta * (acc / (cfg.a_max + 1e-12))
            - cfg.zeta * (jerk / (cfg.j_max + 1e-12))
        )
        cthr = cfg.cthr_base * term
        cthr = np.clip(cthr, cfg.min_cthr, cfg.max_cthr_multiplier * cfg.cthr_base)
        theta_plus = cthr
        theta_minus = 0.9 * cthr
        return cthr, theta_minus, theta_plus

    # --------------------------
    # Channel model (Nakagami-m)
    # --------------------------
    def _nb_dbm(self) -> float:
        cfg = self.cfg
        return -174.0 + 10.0 * math.log10(cfg.bandwidth_hz) + cfg.noise_figure_db

    def _detailed_jnr_db_per_receiver(self, pos: np.ndarray) -> Optional[np.ndarray]:
        cfg = self.cfg
        if cfg.jammer_tx_dbm is None or cfg.jammer_pos is None:
            return None

        nb_dbm = self._nb_dbm()
        d = np.linalg.norm(pos - cfg.jammer_pos.reshape(1, 3), axis=1)
        d = np.maximum(d, 1.0)

        # Simple log-distance path loss (consistent enough for stress simulation)
        pl_db = 20.0 * np.log10(4 * np.pi * cfg.d_0 / 0.0517) + 10.0 * cfg.path_loss_exp * np.log10(d / cfg.d_0)
        p_rx_j_dbm = cfg.jammer_tx_dbm + cfg.jammer_gain_db + cfg.rx_gain_db - pl_db
        jnr_db = p_rx_j_dbm - nb_dbm
        return jnr_db

    def _receiver_jnr_lin(self, pos: np.ndarray) -> Optional[np.ndarray]:
        cfg = self.cfg

        if cfg.jnr_db is not None:
            # controlled receiver-referenced stress (paper-friendly)
            return np.full(cfg.n_uavs, 10 ** (cfg.jnr_db / 10.0), dtype=float)

        jnr_db_per_rx = self._detailed_jnr_db_per_receiver(pos)
        if jnr_db_per_rx is None:
            return None
        return 10 ** (jnr_db_per_rx / 10.0)

    def _nakagami_success_from_mean_snr(self, mean_snr: np.ndarray, m: float) -> np.ndarray:
        cfg = self.cfg
        x = m * (cfg.gamma_0_lin / np.maximum(mean_snr, 1e-12))
        # Γ(m,x)/Γ(m) = gammaincc(m,x)
        s = gammaincc(m, x)
        return np.clip(s, 0.0, 1.0)

    def success_probability_uav_uav(self, d_uav: np.ndarray, pos_for_jammer: Optional[np.ndarray] = None) -> np.ndarray:
        cfg = self.cfg
        d = np.maximum(d_uav, 1e-9)

        mean_snr = cfg.snr_0_lin * (cfg.d_0 / d) ** cfg.path_loss_exp

        # Jamming-aware SINR if configured
        if pos_for_jammer is None:
            pos_for_jammer = None

        jnr_lin_per_rx = None
        if pos_for_jammer is not None:
            jnr_lin_per_rx = self._receiver_jnr_lin(pos_for_jammer)

        if jnr_lin_per_rx is not None:
            # conservative link-level j_lin_ij = max(j_i, j_j)
            ji = jnr_lin_per_rx.reshape(-1, 1)
            jj = jnr_lin_per_rx.reshape(1, -1)
            j_lin_ij = np.maximum(ji, jj)
            mean_snr = mean_snr / (1.0 + j_lin_ij)

        s = self._nakagami_success_from_mean_snr(mean_snr, cfg.m_fading)
        np.fill_diagonal(s, 0.0)
        return s

    def success_probability_uav_gs(self, d_gs: np.ndarray, pos: np.ndarray) -> np.ndarray:
        cfg = self.cfg
        d = np.maximum(d_gs, 1e-9)

        mean_snr = cfg.gs_snr_0_lin * (cfg.d_0 / d) ** cfg.gs_path_loss_exp

        jnr_lin_per_rx = self._receiver_jnr_lin(pos)
        if jnr_lin_per_rx is not None:
            # receiver is UAV i for UAV-GS link
            mean_snr = mean_snr / (1.0 + jnr_lin_per_rx)

        s = self._nakagami_success_from_mean_snr(mean_snr, cfg.gs_m_fading)
        return np.clip(s, 0.0, 1.0)

    # --------------------------
    # Predicted adjacency / risk sets
    # --------------------------
    def _predict_positions_one_step(self, state: SwarmState) -> np.ndarray:
        return state.pos + state.vel * self.cfg.dt_ctrl

    def _build_pred_matrices(
        self, state: SwarmState
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
        """
        Returns:
          d_cur_uav, d_pred_uav,
          theta_minus_uav, theta_plus_uav,
          s_cur_uav, s_pred_uav,
          d_cur_gs, d_pred_gs, theta_minus_gs, theta_plus_gs
        """
        d_cur_uav = self._pairwise_uav_dist(state.pos)
        pred_pos = self._predict_positions_one_step(state)
        d_pred_uav = self._pairwise_uav_dist(pred_pos)

        _, theta_minus_uav, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)

        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)
        s_pred_uav = self.success_probability_uav_uav(d_pred_uav, pos_for_jammer=pred_pos)

        if self.cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            d_pred_gs = self._uav_gs_dist(pred_pos)
            _, theta_minus_gs, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
        else:
            d_cur_gs = d_pred_gs = theta_minus_gs = theta_plus_gs = None

        return (
            d_cur_uav, d_pred_uav,
            theta_minus_uav, theta_plus_uav,
            s_cur_uav, s_pred_uav,
            d_cur_gs, d_pred_gs, theta_minus_gs, theta_plus_gs
        )

    def _build_pred_adjacency(
        self,
        d_pred_uav: np.ndarray,
        theta_minus_uav: np.ndarray,
        theta_plus_uav: np.ndarray,
        s_pred_uav: np.ndarray,
        d_pred_gs: Optional[np.ndarray],
        theta_minus_gs: Optional[np.ndarray],
        theta_plus_gs: Optional[np.ndarray],
        s_pred_gs: Optional[np.ndarray],
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Build predicted guaranteed and viable adjacency on V' (including GS if enabled).
        """
        cfg = self.cfg
        n_uav = cfg.n_uavs
        n_total = cfg.n_total

        A_minus = np.zeros((n_total, n_total), dtype=np.uint8)
        A_plus = np.zeros((n_total, n_total), dtype=np.uint8)

        # UAV-UAV edges
        viable_uav = (d_pred_uav <= theta_plus_uav) & (s_pred_uav >= cfg.p_min)
        guar_uav = (d_pred_uav <= theta_minus_uav) & (s_pred_uav >= cfg.p_min)

        np.fill_diagonal(viable_uav, False)
        np.fill_diagonal(guar_uav, False)

        A_plus[:n_uav, :n_uav] = viable_uav.astype(np.uint8)
        A_minus[:n_uav, :n_uav] = guar_uav.astype(np.uint8)

        # UAV-GS edges
        if cfg.include_ground_station:
            assert d_pred_gs is not None and theta_plus_gs is not None and theta_minus_gs is not None and s_pred_gs is not None
            sidx = cfg.gs_index
            viable_gs = (d_pred_gs <= theta_plus_gs) & (s_pred_gs >= cfg.p_min)
            guar_gs = (d_pred_gs <= theta_minus_gs) & (s_pred_gs >= cfg.p_min)

            for i in range(n_uav):
                if viable_gs[i]:
                    A_plus[i, sidx] = 1
                    A_plus[sidx, i] = 1
                if guar_gs[i]:
                    A_minus[i, sidx] = 1
                    A_minus[sidx, i] = 1

        return A_minus, A_plus

    # --------------------------
    # Pseudocode-equivalent graph routines
    # --------------------------
    @staticmethod
    def _adj_list_from_adjacency(A: np.ndarray) -> List[List[int]]:
        n = A.shape[0]
        adj = []
        for i in range(n):
            adj.append(np.flatnonzero(A[i]).tolist())
        return adj

    @staticmethod
    def kneighbors(adj: List[List[int]], seed: int, k_max: int) -> Set[int]:
        seen: Set[int] = {seed}
        q: List[Tuple[int, int]] = [(seed, 0)]
        head = 0
        while head < len(q):
            u, depth = q[head]
            head += 1
            if depth == k_max:
                continue
            for v in adj[u]:
                if v not in seen:
                    seen.add(v)
                    q.append((v, depth + 1))
        return seen

    @classmethod
    def predict_components(cls, A_guaranteed: np.ndarray, k_max: Optional[int] = None) -> Tuple[List[Set[int]], Dict[int, int]]:
        n = A_guaranteed.shape[0]
        if k_max is None:
            k_max = n  # full BFS as in manuscript note

        adj = cls._adj_list_from_adjacency(A_guaranteed)

        unseen = set(range(n))
        comps: List[Set[int]] = []
        comp_of: Dict[int, int] = {}

        while unseen:
            i = next(iter(unseen))
            C = cls.kneighbors(adj, i, k_max=k_max)
            idx = len(comps)
            for u in C:
                comp_of[u] = idx
            comps.append(C)
            unseen -= C

        return comps, comp_of

    @staticmethod
    def catalogue_risky_links(
        A_pred_viable: np.ndarray,
        comp_of: Dict[int, int],
        cfg: ACLDConfig,
        d_pred_uav: np.ndarray,
        theta_minus_uav: np.ndarray,
        theta_plus_uav: np.ndarray,
        d_pred_gs: Optional[np.ndarray],
        theta_minus_gs: Optional[np.ndarray],
        theta_plus_gs: Optional[np.ndarray],
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
        """
        Risk-band = viable but not guaranteed (theta_minus < d <= theta_plus, and viable already ensured).
        Returns (inter, intra)
        """
        n_total = A_pred_viable.shape[0]
        n_uav = cfg.n_uavs
        sidx = cfg.gs_index

        inter: List[Tuple[int, int]] = []
        intra: List[Tuple[int, int]] = []

        for i in range(n_total):
            for j in range(i + 1, n_total):
                if A_pred_viable[i, j] == 0:
                    continue

                in_risk_band = False

                if i < n_uav and j < n_uav:
                    dij = d_pred_uav[i, j]
                    if theta_minus_uav[i, j] < dij <= theta_plus_uav[i, j]:
                        in_risk_band = True
                else:
                    # UAV-GS case
                    if sidx is None:
                        continue
                    if d_pred_gs is None or theta_minus_gs is None or theta_plus_gs is None:
                        continue
                    u = i if j == sidx else j if i == sidx else None
                    if u is not None and 0 <= u < n_uav:
                        du = d_pred_gs[u]
                        if theta_minus_gs[u] < du <= theta_plus_gs[u]:
                            in_risk_band = True

                if not in_risk_band:
                    continue

                if comp_of.get(i, -1) != comp_of.get(j, -1):
                    inter.append((i, j))
                else:
                    intra.append((i, j))

        return inter, intra

    @staticmethod
    def separate_ground_links(
        risky_inter: List[Tuple[int, int]],
        gs_index: Optional[int]
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, int]]]:
        if gs_index is None:
            return [], risky_inter.copy()

        s_uav: List[Tuple[int, int]] = []
        uav_uav: List[Tuple[int, int]] = []

        for i, j in risky_inter:
            if i == gs_index or j == gs_index:
                s_uav.append((i, j))
            else:
                uav_uav.append((i, j))
        return s_uav, uav_uav

    # --------------------------
    # Weighted graph / Laplacian / lambda2
    # --------------------------
    def _build_current_weighted_graph(self, state: SwarmState) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """
        Build W and L on current channel-viable graph G_viable(t), including GS if enabled.
        Uses geometric candidates + viability guard + directed row-stochastic inverse-square weights + symmetrization.
        """
        cfg = self.cfg
        n_uav = cfg.n_uavs
        n_total = cfg.n_total
        sidx = cfg.gs_index

        d_cur_uav = self._pairwise_uav_dist(state.pos)
        _, theta_minus_uav, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)
        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)

        d_cur_gs = theta_minus_gs = theta_plus_gs = s_cur_gs = None
        if cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            _, theta_minus_gs, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
            s_cur_gs = self.success_probability_uav_gs(d_cur_gs, state.pos)

        W_tilde = np.zeros((n_total, n_total), dtype=float)

        # UAV rows
        for i in range(n_uav):
            neighbors: List[Tuple[int, float]] = []

            # UAV-UAV viable neighbors
            for j in range(n_uav):
                if i == j:
                    continue
                if d_cur_uav[i, j] <= theta_plus_uav[i, j] and s_cur_uav[i, j] >= cfg.p_min:
                    neighbors.append((j, max(d_cur_uav[i, j], cfg.d_safe)))

            # UAV-GS viable neighbor
            if cfg.include_ground_station and d_cur_gs is not None and theta_plus_gs is not None and s_cur_gs is not None:
                if d_cur_gs[i] <= theta_plus_gs[i] and s_cur_gs[i] >= cfg.p_min:
                    neighbors.append((sidx, max(d_cur_gs[i], cfg.d_safe)))

            if len(neighbors) == 0:
                continue

            invsq = np.array([1.0 / (d * d) for _, d in neighbors], dtype=float)
            denom = float(np.sum(invsq))
            if denom <= 0:
                continue

            for (k, _d), w in zip(neighbors, invsq / denom):
                W_tilde[i, k] = float(w)

        # GS row (if present)
        if cfg.include_ground_station and sidx is not None and d_cur_gs is not None and theta_plus_gs is not None and s_cur_gs is not None:
            neighbors_gs: List[Tuple[int, float]] = []
            for i in range(n_uav):
                if d_cur_gs[i] <= theta_plus_gs[i] and s_cur_gs[i] >= cfg.p_min:
                    neighbors_gs.append((i, max(d_cur_gs[i], cfg.d_safe)))

            if neighbors_gs:
                invsq = np.array([1.0 / (d * d) for _, d in neighbors_gs], dtype=float)
                denom = float(np.sum(invsq))
                if denom > 0:
                    for (i, _d), w in zip(neighbors_gs, invsq / denom):
                        W_tilde[sidx, i] = float(w)

        # Symmetrize
        W = 0.5 * (W_tilde + W_tilde.T)
        np.fill_diagonal(W, 0.0)

        D = np.diag(np.sum(W, axis=1))
        L = D - W
        return W, D, L

    @staticmethod
    def algebraic_connectivity_lambda2(L: np.ndarray) -> float:
        n = L.shape[0]
        if n <= 1:
            return 0.0
        # Dense eigvalsh is acceptable up to n~1000 in this context
        evals = np.linalg.eigvalsh(L)
        evals = np.sort(np.real(evals))
        if len(evals) < 2:
            return 0.0
        return float(max(0.0, evals[1]))

    # --------------------------
    # Baselines
    # --------------------------
    @staticmethod
    def tarjan_bridges(A: np.ndarray) -> List[Tuple[int, int]]:
        n = A.shape[0]
        adj = ACLDCore._adj_list_from_adjacency(A)

        disc = [-1] * n
        low = [-1] * n
        parent = [-1] * n
        time_counter = 0
        bridges: List[Tuple[int, int]] = []

        def dfs(u: int) -> None:
            nonlocal time_counter
            disc[u] = low[u] = time_counter
            time_counter += 1

            for v in adj[u]:
                if disc[v] == -1:
                    parent[v] = u
                    dfs(v)
                    low[u] = min(low[u], low[v])

                    if low[v] > disc[u]:
                        bridges.append((u, v) if u < v else (v, u))
                elif v != parent[u]:
                    low[u] = min(low[u], disc[v])

        for i in range(n):
            if disc[i] == -1:
                dfs(i)

        bridges = sorted(set(bridges))
        return bridges

    @staticmethod
    def edge_betweenness_brandes_unweighted(A: np.ndarray) -> Dict[Tuple[int, int], float]:
        """
        Exact Brandes edge betweenness for unweighted undirected graph.
        """
        n = A.shape[0]
        adj = ACLDCore._adj_list_from_adjacency(A)
        eb: Dict[Tuple[int, int], float] = {}

        for s in range(n):
            stack: List[int] = []
            pred: List[List[int]] = [[] for _ in range(n)]
            sigma = np.zeros(n, dtype=float)
            dist = -np.ones(n, dtype=int)

            sigma[s] = 1.0
            dist[s] = 0

            # BFS
            q = [s]
            qh = 0
            while qh < len(q):
                v = q[qh]
                qh += 1
                stack.append(v)
                for w in adj[v]:
                    if dist[w] < 0:
                        q.append(w)
                        dist[w] = dist[v] + 1
                    if dist[w] == dist[v] + 1:
                        sigma[w] += sigma[v]
                        pred[w].append(v)

            delta = np.zeros(n, dtype=float)
            while stack:
                w = stack.pop()
                for v in pred[w]:
                    if sigma[w] > 0:
                        c = (sigma[v] / sigma[w]) * (1.0 + delta[w])
                    else:
                        c = 0.0
                    e = (v, w) if v < w else (w, v)
                    eb[e] = eb.get(e, 0.0) + c
                    delta[v] += c

        # undirected graph correction
        for e in list(eb.keys()):
            eb[e] *= 0.5
        return eb

    def baseline_tarjan_snapshot(self, state: SwarmState, with_guard: bool = True) -> List[Tuple[int, int]]:
        """
        TARJAN@t: snapshot bridges on current graph.
        If with_guard=False, geometry only (cthr+ band) without probabilistic guard.
        """
        cfg = self.cfg
        n_total = cfg.n_total
        n_uav = cfg.n_uavs
        sidx = cfg.gs_index

        d_cur_uav = self._pairwise_uav_dist(state.pos)
        _, _, theta_plus_uav = self.adaptive_threshold_uav_uav(state, d_cur_uav)
        s_cur_uav = self.success_probability_uav_uav(d_cur_uav, pos_for_jammer=state.pos)

        A = np.zeros((n_total, n_total), dtype=np.uint8)

        mask_uav = (d_cur_uav <= theta_plus_uav)
        if with_guard:
            mask_uav &= (s_cur_uav >= cfg.p_min)
        np.fill_diagonal(mask_uav, False)
        A[:n_uav, :n_uav] = mask_uav.astype(np.uint8)

        if cfg.include_ground_station:
            d_cur_gs = self._uav_gs_dist(state.pos)
            _, _, theta_plus_gs = self.adaptive_threshold_uav_gs(state, d_cur_uav)
            s_cur_gs = self.success_probability_uav_gs(d_cur_gs, state.pos)

            for i in range(n_uav):
                ok = d_cur_gs[i] <= theta_plus_gs[i]
                if with_guard:
                    ok = ok and (s_cur_gs[i] >= cfg.p_min)
                if ok:
                    A[i, sidx] = A[sidx, i] = 1

        return self.tarjan_bridges(A)

    # --------------------------
    # Main ACLD step
    # --------------------------
    def run_acld(self, state: SwarmState) -> ACLDResult:
        """Run one full ACLD evaluation on ``state`` and bundle the outputs.

        Pipeline: predicted distances and thresholds -> predicted guaranteed
        and viable adjacency -> components of the guaranteed graph -> risk-band
        links split into inter/intra and GS/UAV -> algebraic connectivity of
        the current weighted graph and of the predicted viable graph.
        """
        cfg = self.cfg

        matrices = self._build_pred_matrices(state)
        d_cur_uav, d_pred_uav = matrices[0], matrices[1]
        theta_minus_uav, theta_plus_uav = matrices[2], matrices[3]
        s_cur_uav, s_pred_uav = matrices[4], matrices[5]
        d_pred_gs, theta_minus_gs, theta_plus_gs = matrices[7], matrices[8], matrices[9]

        # Ground-station success probability at the predicted positions.
        if cfg.include_ground_station and d_pred_gs is not None:
            s_pred_gs = self.success_probability_uav_gs(
                d_pred_gs, self._predict_positions_one_step(state)
            )
        else:
            s_pred_gs = None

        A_pred_guaranteed, A_pred_viable = self._build_pred_adjacency(
            d_pred_uav=d_pred_uav,
            theta_minus_uav=theta_minus_uav,
            theta_plus_uav=theta_plus_uav,
            s_pred_uav=s_pred_uav,
            d_pred_gs=d_pred_gs,
            theta_minus_gs=theta_minus_gs,
            theta_plus_gs=theta_plus_gs,
            s_pred_gs=s_pred_gs,
        )

        components, comp_of = self.predict_components(A_pred_guaranteed, k_max=cfg.n_total)

        risky_inter, risky_intra = self.catalogue_risky_links(
            A_pred_viable=A_pred_viable,
            comp_of=comp_of,
            cfg=cfg,
            d_pred_uav=d_pred_uav,
            theta_minus_uav=theta_minus_uav,
            theta_plus_uav=theta_plus_uav,
            d_pred_gs=d_pred_gs,
            theta_minus_gs=theta_minus_gs,
            theta_plus_gs=theta_plus_gs,
        )
        risky_s_uav, risky_uav_uav = self.separate_ground_links(risky_inter, cfg.gs_index)

        # Connectivity of the current weighted graph (the degree matrix is not needed here).
        W_current, _, L_current = self._build_current_weighted_graph(state)
        lambda2_current = self.algebraic_connectivity_lambda2(L_current)

        # Binary Laplacian of the predicted viable graph, for diagnostics only.
        viable_float = A_pred_viable.astype(float)
        L_pred_viable = np.diag(np.sum(viable_float, axis=1)) - viable_float
        lambda2_pred_viable = self.algebraic_connectivity_lambda2(L_pred_viable)

        return ACLDResult(
            A_pred_guaranteed=A_pred_guaranteed,
            A_pred_viable=A_pred_viable,
            d_pred_uav=d_pred_uav,
            d_cur_uav=d_cur_uav,
            theta_minus_uav=theta_minus_uav,
            theta_plus_uav=theta_plus_uav,
            theta_minus_gs=theta_minus_gs,
            theta_plus_gs=theta_plus_gs,
            s_pred_uav=s_pred_uav,
            s_cur_uav=s_cur_uav,
            components=components,
            comp_of=comp_of,
            risky_inter=risky_inter,
            risky_intra=risky_intra,
            risky_s_uav=risky_s_uav,
            risky_uav_uav=risky_uav_uav,
            lambda2_current=lambda2_current,
            lambda2_pred_viable=lambda2_pred_viable,
            W_current=W_current,
            L_current=L_current,
        )

In [None]:
# ============================================================
# Simulation / Evaluation helpers
# ============================================================

def _clip_rows_norm(X: np.ndarray, max_norm: float) -> np.ndarray:
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    scale = np.ones_like(norms)
    mask = norms > max_norm
    scale[mask] = max_norm / np.maximum(norms[mask], 1e-12)
    return X * scale


def make_rng(seed: int) -> np.random.Generator:
    """Create a new NumPy random ``Generator`` initialised from ``seed``."""
    generator = np.random.default_rng(seed)
    return generator


@dataclass
class StepLog:
    """Per-step record produced by ``run_episode`` (one entry per detector invocation)."""

    # Simulation time of the step (starts at 0.0 and advances by cfg.dt_sim).
    t: float
    # lambda2_current reported by the detector result for this step.
    lambda2_current: float
    # True when lambda2_current is below cfg.lambda2_partition_threshold.
    partition_now: bool
    # Number of entries in the result's `components`.
    n_components_pred_guar: int
    # Number of entries in the result's `risky_inter` list.
    n_risky_inter: int
    # Number of entries in the result's `risky_intra` list.
    n_risky_intra: int
    # Number of entries in the result's `risky_s_uav` list.
    n_risky_s_uav: int
    # Number of entries in the result's `risky_uav_uav` list.
    n_risky_uav_uav: int
    # Wall-clock duration of the run_acld call, in milliseconds.
    runtime_ms: float


def run_episode(
    cfg: ACLDConfig,
    horizon_s: float = 10.0,
    seed: Optional[int] = None,
) -> Tuple[List[StepLog], List[ACLDResult]]:
    """
    Simulate one episode and run the detector once per simulation step.

    Args:
      cfg: configuration for the swarm and the detector.
      horizon_s: simulated duration; the number of steps is round(horizon_s / cfg.dt_sim).
      seed: RNG seed for the episode; ``cfg.seed`` is used when None.

    Returns:
      (logs, results): one ``StepLog`` and the matching detector result per step.
    """
    episode_seed = seed if seed is not None else cfg.seed
    rng = make_rng(episode_seed)
    state = SwarmState.random_init(cfg, rng)
    acld = ACLDCore(cfg)

    n_steps = int(round(horizon_s / cfg.dt_sim))
    logs: List[StepLog] = []
    results: List[ACLDResult] = []
    t_now = 0.0

    for _step in range(n_steps):
        # Only the detector call itself is timed.
        started = time.perf_counter()
        res = acld.run_acld(state)
        finished = time.perf_counter()

        entry = StepLog(
            t=t_now,
            lambda2_current=res.lambda2_current,
            partition_now=res.lambda2_current < cfg.lambda2_partition_threshold,
            n_components_pred_guar=len(res.components),
            n_risky_inter=len(res.risky_inter),
            n_risky_intra=len(res.risky_intra),
            n_risky_s_uav=len(res.risky_s_uav),
            n_risky_uav_uav=len(res.risky_uav_uav),
            runtime_ms=(finished - started) * 1000.0,
        )
        logs.append(entry)
        results.append(res)

        # Move the swarm forward before the next evaluation.
        state.step_random_kinematics(cfg, rng)
        t_now += cfg.dt_sim

    return logs, results


def summarize_episode(logs: List[StepLog]) -> Dict[str, float]:
    """
    Reduce per-step logs to episode-level statistics.

    Returns the mean of ``lambda2_current``, the count and rate of steps with
    ``partition_now`` set, and the median / 25th / 75th percentile of ``runtime_ms``.
    Every statistic is 0 when ``logs`` is empty.
    """
    if not logs:
        return {
            "mean_lambda2": 0.0,
            "partition_count": 0,
            "partition_rate": 0.0,
            "runtime_median_ms": 0.0,
            "runtime_iqr_low_ms": 0.0,
            "runtime_iqr_high_ms": 0.0,
        }

    lam = np.array([entry.lambda2_current for entry in logs], dtype=float)
    split = np.array([entry.partition_now for entry in logs], dtype=bool)
    rt_ms = np.array([entry.runtime_ms for entry in logs], dtype=float)

    summary: Dict[str, float] = {}
    summary["mean_lambda2"] = float(np.mean(lam))
    summary["partition_count"] = int(np.sum(split))
    summary["partition_rate"] = float(np.mean(split))
    summary["runtime_median_ms"] = float(np.median(rt_ms))
    summary["runtime_iqr_low_ms"] = float(np.percentile(rt_ms, 25))
    summary["runtime_iqr_high_ms"] = float(np.percentile(rt_ms, 75))
    return summary

In [None]:
# ============================================================
# Simple benchmark (runtime vs n)
# ============================================================

def benchmark_runtime_vs_n(
    base_cfg: ACLDConfig,
    sizes: Iterable[int] = (50, 150, 500, 1000),
    repeats: int = 200,
) -> List[Dict[str, float]]:
    """
    Time the detector for several swarm sizes.

    For each ``n`` in ``sizes`` a config equal to ``base_cfg`` except for ``n_uavs`` is
    built, a random swarm is initialised (seed ``cfg.seed + n``), and ``run_acld`` is
    timed ``repeats`` times with one kinematics step between calls.

    Returns one row per size with median, quartiles, mean and 95th percentile in ms.
    """
    table: List[Dict[str, float]] = []

    for n in sizes:
        cfg_kwargs = dict(base_cfg.__dict__)
        cfg_kwargs["n_uavs"] = int(n)
        cfg = ACLDConfig(**cfg_kwargs)

        rng = make_rng(cfg.seed + int(n))
        state = SwarmState.random_init(cfg, rng)
        acld = ACLDCore(cfg)

        samples_ms = np.empty(repeats, dtype=float)
        for rep in range(repeats):
            tic = time.perf_counter()
            acld.run_acld(state)
            toc = time.perf_counter()
            samples_ms[rep] = (toc - tic) * 1000.0
            # small random motion to avoid exact cache repetition
            state.step_random_kinematics(cfg, rng)

        table.append(
            {
                "n": int(n),
                "pairs": int(n * (n - 1) // 2),
                "median_ms": float(np.median(samples_ms)),
                "q1_ms": float(np.percentile(samples_ms, 25)),
                "q3_ms": float(np.percentile(samples_ms, 75)),
                "mean_ms": float(np.mean(samples_ms)),
                "p95_ms": float(np.percentile(samples_ms, 95)),
            }
        )

    return table

In [None]:
# ============================================================
# Optional: one-step evaluation helper for PEW/FP style labels
# ============================================================

def connectivity_components_from_adjacency(A: np.ndarray) -> Tuple[int, np.ndarray]:
    """
    Count the connected components of the undirected graph given by adjacency ``A``.

    ``A`` is cast to int8 before the search, so entries are interpreted as integer
    edge indicators. Returns ``(number_of_components, labels)`` where ``labels[i]`` is
    the component index of node ``i``.
    """
    sparse_adj = csr_matrix(A.astype(np.int8))
    count, labels = connected_components(csgraph=sparse_adj, directed=False, return_labels=True)
    return int(count), labels


def edge_deletion_disconnects(A: np.ndarray, edge: Tuple[int, int]) -> bool:
    """
    Tell whether removing ``edge`` from the undirected graph ``A`` increases the
    number of connected components.

    Returns False immediately when ``A[i, j]`` is 0 (the edge is not present).
    ``A`` itself is left untouched; the deletion is applied to a copy.
    """
    u, v = edge
    if A[u, v] == 0:
        return False

    pruned = A.copy()
    pruned[u, v] = 0
    pruned[v, u] = 0

    before = connectivity_components_from_adjacency(A)[0]
    after = connectivity_components_from_adjacency(pruned)[0]
    return after > before


def approximate_edge_f1_on_predicted_graph(
    A_pred_viable: np.ndarray,
    flagged_edges: List[Tuple[int, int]],
    sample_size: int = 200,
    rng: Optional[np.random.Generator] = None,
) -> Dict[str, float]:
    """
    Approximate edge-level F1 by sampling edges and testing edge deletion connectivity impact.

    At most ``sample_size`` edges of the graph (entries equal to 1 in the upper triangle)
    are drawn without replacement. An examined edge is "truly critical" when deleting it
    increases the number of connected components. Precision / recall / F1 are computed
    over the examined edges only: flagged edges that were not examined (or that are not
    edges of the graph at all) are left out of the score, because their ground truth is
    unknown and counting them as false positives would bias precision downwards.

    Args:
      A_pred_viable: square adjacency matrix; an edge exists where the entry equals 1.
      flagged_edges: edges reported as critical; orientation is ignored.
      sample_size: maximum number of edges to examine.
      rng: generator used for sampling; a fixed-seed generator (123) is used when None.

    Returns:
      dict with keys "precision", "recall", "f1" (all 0.0 when the graph has no edges).
    """
    if rng is None:
        rng = np.random.default_rng(123)

    # Upper-triangle scan: each undirected edge appears once as (i, j) with i < j,
    # in row-major order.
    rows, cols = np.nonzero(np.triu(np.asarray(A_pred_viable) == 1, k=1))
    edges = list(zip(rows.tolist(), cols.tolist()))
    if not edges:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}

    if len(edges) > sample_size:
        idx = rng.choice(len(edges), size=sample_size, replace=False)
        sampled = [edges[k] for k in idx]
    else:
        sampled = edges

    examined = set(sampled)
    true_critical = {e for e in sampled if edge_deletion_disconnects(A_pred_viable, e)}

    # Canonicalise flagged edges to (min, max) so orientation cannot cause a mismatch,
    # then keep only those whose ground truth was actually evaluated.
    flagged = {(int(min(i, j)), int(max(i, j))) for i, j in flagged_edges} & examined

    tp = len(true_critical & flagged)
    fp = len(flagged - true_critical)
    fn = len(true_critical - flagged)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

In [7]:
# ============================================================
# Paper-style experiment + plotting suite (ACLD full eval)
# ============================================================

@dataclass
class ScenarioSpec:
    """Description of one Monte-Carlo scenario consumed by ``run_scenario_mc``."""

    # Label written into the "scenario" column of every output row.
    name: str
    # Keyword overrides applied on top of the base config when the scenario config is cloned.
    cfg_overrides: Dict[str, object]
    # Episode length in seconds, forwarded to run_evaluated_episode.
    horizon_s: float = 10.0
    # One evaluated episode is run per seed.
    seeds: Tuple[int, ...] = (42, 43, 44, 45, 46)
    ebc_every_k_steps: int = 5  # exact EBC is expensive; evaluate every k-th step


def _clone_cfg(base_cfg: ACLDConfig, **overrides) -> ACLDConfig:
    """
    Build a new ``ACLDConfig`` from the instance attributes of ``base_cfg`` with
    ``overrides`` applied on top. ``base_cfg`` itself is not modified.
    """
    return ACLDConfig(**{**base_cfg.__dict__, **overrides})


def _ensure_dir(path: str) -> None:
    os.makedirs(path, exist_ok=True)


def _save_rows_csv(path: str, rows: List[Dict[str, object]]) -> None:
    """
    Write ``rows`` to a UTF-8 CSV file at ``path``, creating the parent directory first.

    The header is the sorted union of all keys found in ``rows``. With no rows an
    empty file is written.
    """
    _ensure_dir(os.path.dirname(path) or ".")

    # Columns are resolved before the file is opened; None marks the "no rows" case.
    columns = sorted({key for row in rows for key in row.keys()}) if rows else None

    with open(path, "w", newline="", encoding="utf-8") as handle:
        if columns is None:
            handle.write("")
        else:
            writer = csv.DictWriter(handle, fieldnames=columns)
            writer.writeheader()
            writer.writerows(rows)


def _save_json(path: str, obj: object) -> None:
    """
    Serialise ``obj`` as indented (2 spaces) UTF-8 JSON at ``path``; non-ASCII
    characters are written verbatim. The parent directory is created if needed.
    """
    parent = os.path.dirname(path) or "."
    _ensure_dir(parent)
    with open(path, "w", encoding="utf-8") as handle:
        json.dump(obj, handle, indent=2, ensure_ascii=False)


def _binary_confusion_metrics(y_true: List[int], y_pred: List[int]) -> Dict[str, float]:
    yt = np.array(y_true, dtype=int)
    yp = np.array(y_pred, dtype=int)
    if yt.size == 0:
        return dict(tp=0, fp=0, tn=0, fn=0, precision=0.0, recall=0.0, f1=0.0, fpr=0.0, tpr=0.0, acc=0.0)

    tp = int(np.sum((yt == 1) & (yp == 1)))
    fp = int(np.sum((yt == 0) & (yp == 1)))
    tn = int(np.sum((yt == 0) & (yp == 0)))
    fn = int(np.sum((yt == 1) & (yp == 0)))

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    fpr = fp / (fp + tn) if (fp + tn) else 0.0
    tpr = recall
    acc = (tp + tn) / (tp + tn + fp + fn) if (tp + tn + fp + fn) else 0.0

    return dict(tp=tp, fp=fp, tn=tn, fn=fn, precision=precision, recall=recall, f1=f1, fpr=fpr, tpr=tpr, acc=acc)


def _nanmean(vals: List[float]) -> float:
    if not vals:
        return float("nan")
    a = np.array(vals, dtype=float)
    if a.size == 0:
        return float("nan")
    return float(np.nanmean(a))


def _nanstd(vals: List[float]) -> float:
    if not vals:
        return float("nan")
    a = np.array(vals, dtype=float)
    if a.size == 0:
        return float("nan")
    return float(np.nanstd(a))


def build_current_viable_adjacency(acld: ACLDCore, state: SwarmState, with_guard: bool = True) -> np.ndarray:
    """
    Binary adjacency (uint8, n_total x n_total) of the current snapshot.

    A link exists when the current distance is within the upper adaptive threshold
    (theta_plus). With ``with_guard`` the link must additionally have a success
    probability of at least ``cfg.p_min``. UAV-GS links are added only when
    ``cfg.include_ground_station`` is set; the diagonal is always zero.
    """
    cfg = acld.cfg
    n_uav, n_total, gs = cfg.n_uavs, cfg.n_total, cfg.gs_index

    # --- UAV-UAV links ---
    dist_uu = acld._pairwise_uav_dist(state.pos)
    _lo, _mid, theta_plus_uu = acld.adaptive_threshold_uav_uav(state, dist_uu)
    succ_uu = acld.success_probability_uav_uav(dist_uu, pos_for_jammer=state.pos)

    uav_links = (dist_uu <= theta_plus_uu)
    if with_guard:
        uav_links &= (succ_uu >= cfg.p_min)
    np.fill_diagonal(uav_links, False)

    adjacency = np.zeros((n_total, n_total), dtype=np.uint8)
    adjacency[:n_uav, :n_uav] = uav_links.astype(np.uint8)

    if not cfg.include_ground_station:
        return adjacency

    # --- UAV-GS links ---
    dist_gs = acld._uav_gs_dist(state.pos)
    # NOTE(review): the GS threshold is computed from the UAV-UAV distance matrix,
    # not from dist_gs -- confirm this matches adaptive_threshold_uav_gs's contract.
    _lo_gs, _mid_gs, theta_plus_gs = acld.adaptive_threshold_uav_gs(state, dist_uu)
    succ_gs = acld.success_probability_uav_gs(dist_gs, state.pos)

    for uav in range(n_uav):
        linked = bool(dist_gs[uav] <= theta_plus_gs[uav])
        if linked and with_guard:
            linked = bool(succ_gs[uav] >= cfg.p_min)
        if linked:
            adjacency[uav, gs] = 1
            adjacency[gs, uav] = 1

    return adjacency


def binary_lambda2_from_adjacency(A: np.ndarray) -> float:
    """
    Build the Laplacian L = D - A from adjacency ``A`` (cast to float; D is the
    diagonal matrix of row sums) and return ``ACLDCore.algebraic_connectivity_lambda2(L)``.
    """
    adj = A.astype(float)
    laplacian = np.diag(np.sum(adj, axis=1)) - adj
    return ACLDCore.algebraic_connectivity_lambda2(laplacian)


def topk_ebc_edges(A: np.ndarray, k: int) -> List[Tuple[int, int]]:
    """
    Return up to ``k`` edges of ``A`` with the highest edge-betweenness score,
    highest first. Returns an empty list when ``k <= 0`` or when no scores are available.
    """
    if k <= 0:
        return []
    scores = ACLDCore.edge_betweenness_brandes_unweighted(A)
    if not scores:
        return []
    # Stable descending sort on the score; ties keep the order in which they were reported.
    by_score = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    top: List[Tuple[int, int]] = []
    for edge, _score in by_score[:k]:
        top.append(edge)
    return top


def _edge_f1_exact_vs_bridges(A_target: np.ndarray, flagged_edges: List[Tuple[int, int]]) -> Dict[str, float]:
    """
    Edge-level F1 where ground truth = bridges of target graph (exact).

    Flagged edges that are not edges of ``A_target`` (entry equal to 1, listed as
    (i, j) with i < j) are dropped before scoring. Returns precision, recall, f1
    and the raw tp / fp / fn counts.
    """
    bridges = set(ACLDCore.tarjan_bridges(A_target))

    n_nodes = A_target.shape[0]
    present = set()
    for i in range(n_nodes):
        for j in range(i + 1, n_nodes):
            if A_target[i, j] == 1:
                present.add((i, j))

    flagged = {edge for edge in flagged_edges if edge in present}

    tp = len(bridges & flagged)
    fp = len(flagged - bridges)
    fn = len(bridges - flagged)

    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}


def run_evaluated_episode(
    cfg: ACLDConfig,
    horizon_s: float = 10.0,
    seed: Optional[int] = None,
    ebc_every_k_steps: int = 5,
    compute_exact_edge_f1_for_n_leq: int = 220,
) -> Tuple[List[Dict[str, object]], Dict[str, object]]:
    """
    Runs one episode and evaluates one-step-ahead prediction/alerts vs actual next snapshot.

    At every step t the detector runs on the current state; the swarm is then advanced
    one step on a copy of the state, and the alerts / flagged edges from t are scored
    against the guarded viable graph at t+1.

    Args:
      cfg: configuration for the swarm, the detector and the thresholds.
      horizon_s: simulated duration; T = round(horizon_s / cfg.dt_sim), and T - 1
        transitions are evaluated.
      seed: episode seed; ``cfg.seed`` is used when None.
      ebc_every_k_steps: the EBC baseline is evaluated on every k-th step only
        (every step when k <= 1); skipped steps report NaN for its edge F1.
      compute_exact_edge_f1_for_n_leq: edge-level F1 is exact (bridge based) while
        ``cfg.n_total`` is at most this value and sampled otherwise.

    Returns:
      step_rows: list of per-step dicts
      episode_summary: dict
    """
    episode_seed = cfg.seed if seed is None else seed
    rng = make_rng(episode_seed)
    state = SwarmState.random_init(cfg, rng)
    acld = ACLDCore(cfg)

    T = int(round(horizon_s / cfg.dt_sim))
    step_rows: List[Dict[str, object]] = []

    # Edge sampling for the approximate F1 uses its own random stream. Drawing from the
    # simulation RNG would make the simulated trajectory depend on evaluation settings
    # (ebc_every_k_steps, number of sampled edges). The entropy is resolved once so the
    # stream stays consistent within the episode even when no seed was given.
    eval_entropy = np.random.SeedSequence(episode_seed).entropy
    use_exact_edge_f1 = cfg.n_total <= compute_exact_edge_f1_for_n_leq

    def _edge_f1_vs_next(A_next: np.ndarray, flagged: List[Tuple[int, int]], step_idx: int) -> float:
        """Edge-level F1 of ``flagged`` against the bridges of ``A_next``."""
        if use_exact_edge_f1:
            return float(_edge_f1_exact_vs_bridges(A_next, flagged)["f1"])
        # Re-create the generator per call from (entropy, step) so that every method
        # evaluated at this step is scored on the same edge sample.
        sample_rng = np.random.default_rng(
            np.random.SeedSequence(entropy=eval_entropy, spawn_key=(1, step_idx))
        )
        return float(
            approximate_edge_f1_on_predicted_graph(A_next, flagged, sample_size=200, rng=sample_rng)["f1"]
        )

    # For confusion metrics
    y_true_partition_next: List[int] = []
    yhat_acld: List[int] = []
    yhat_tarjan_guard: List[int] = []
    yhat_tarjan_geom: List[int] = []

    e1_acld_list: List[float] = []
    e1_tarjan_list: List[float] = []
    e1_ebc_list: List[float] = []

    for t_idx in range(max(0, T - 1)):
        # Derive the time stamp from the index to avoid accumulating float error.
        t_now = t_idx * cfg.dt_sim

        t0 = time.perf_counter()
        res = acld.run_acld(state)
        t1 = time.perf_counter()
        runtime_ms = (t1 - t0) * 1000.0

        A_now_viable = build_current_viable_adjacency(acld, state, with_guard=True)
        A_now_geom = build_current_viable_adjacency(acld, state, with_guard=False)

        # Baselines at t
        tarjan_guard_edges = ACLDCore.tarjan_bridges(A_now_viable)
        tarjan_geom_edges = ACLDCore.tarjan_bridges(A_now_geom)

        # EBC baseline (subsampled in time due to cost); it gets the same edge budget
        # as ACLD's inter-component risky list, but at least one edge.
        do_ebc = (ebc_every_k_steps <= 1) or (t_idx % ebc_every_k_steps == 0)
        ebc_edges = []
        if do_ebc:
            k_ebc = max(1, len(res.risky_inter))
            ebc_edges = topk_ebc_edges(A_now_viable, k=k_ebc)

        # Advance one step to obtain actual next label/graph
        next_state = state.copy()
        next_state.step_random_kinematics(cfg, rng)

        A_next_viable = build_current_viable_adjacency(acld, next_state, with_guard=True)
        n_comp_next, _ = connectivity_components_from_adjacency(A_next_viable)
        lambda2_next_bin = binary_lambda2_from_adjacency(A_next_viable)
        # Third element of the weighted-graph tuple is passed to the lambda2 routine as the Laplacian.
        L_next_weighted = acld._build_current_weighted_graph(next_state)[2]
        lambda2_next_weighted_val = ACLDCore.algebraic_connectivity_lambda2(L_next_weighted)

        # One-step partition label (binary)
        # Positive if actual next viable graph is disconnected (components > 1)
        y_true = 1 if n_comp_next > 1 else 0

        # ACLD alert rule (early warning + predicted partition diagnostics)
        acld_alert = int(
            (len(res.risky_inter) > 0)
            or (len(res.components) > 1)
            or (res.lambda2_pred_viable < cfg.lambda2_partition_threshold)
        )

        tarjan_guard_alert = int(len(tarjan_guard_edges) > 0)
        tarjan_geom_alert = int(len(tarjan_geom_edges) > 0)

        y_true_partition_next.append(y_true)
        yhat_acld.append(acld_alert)
        yhat_tarjan_guard.append(tarjan_guard_alert)
        yhat_tarjan_geom.append(tarjan_geom_alert)

        # Edge-level F1 vs next-step bridges (exact for moderate n, approximate for larger n)
        target_flagged_acld = sorted(set(res.risky_inter + res.risky_intra))
        f1_acld = _edge_f1_vs_next(A_next_viable, target_flagged_acld, t_idx)
        f1_tarjan = _edge_f1_vs_next(A_next_viable, tarjan_guard_edges, t_idx)
        f1_ebc = _edge_f1_vs_next(A_next_viable, ebc_edges, t_idx) if do_ebc else float("nan")

        e1_acld_list.append(f1_acld)
        e1_tarjan_list.append(f1_tarjan)
        if do_ebc:
            e1_ebc_list.append(f1_ebc)

        step_rows.append(
            {
                "t": t_now,
                "step_idx": t_idx,
                "runtime_ms": runtime_ms,
                "lambda2_current_weighted": float(res.lambda2_current),
                "lambda2_pred_viable_bin": float(res.lambda2_pred_viable),
                "lambda2_next_viable_bin": float(lambda2_next_bin),
                "lambda2_next_weighted": float(lambda2_next_weighted_val),
                "n_components_pred_guar": int(len(res.components)),
                "n_components_next_viable": int(n_comp_next),
                "n_risky_inter": int(len(res.risky_inter)),
                "n_risky_intra": int(len(res.risky_intra)),
                "n_risky_s_uav": int(len(res.risky_s_uav)),
                "n_risky_uav_uav": int(len(res.risky_uav_uav)),
                "tarjan_guard_bridges": int(len(tarjan_guard_edges)),
                "tarjan_geom_bridges": int(len(tarjan_geom_edges)),
                "ebc_k": int(len(ebc_edges)),
                "label_partition_next": int(y_true),
                "alert_acld": int(acld_alert),
                "alert_tarjan_guard": int(tarjan_guard_alert),
                "alert_tarjan_geom": int(tarjan_geom_alert),
                "edge_f1_acld_vs_next": f1_acld,
                "edge_f1_tarjan_vs_next": f1_tarjan,
                "edge_f1_ebc_vs_next": f1_ebc,
            }
        )

        # move to next state
        state = next_state

    m_acld = _binary_confusion_metrics(y_true_partition_next, yhat_acld)
    m_tg = _binary_confusion_metrics(y_true_partition_next, yhat_tarjan_guard)
    m_tgeom = _binary_confusion_metrics(y_true_partition_next, yhat_tarjan_geom)

    runtimes = [r["runtime_ms"] for r in step_rows]
    summary = {
        "n_steps_eval": int(len(step_rows)),
        "partition_next_rate": float(np.mean(y_true_partition_next)) if y_true_partition_next else 0.0,
        "runtime_median_ms": float(np.median(runtimes)) if runtimes else 0.0,
        "runtime_p95_ms": float(np.percentile(runtimes, 95)) if runtimes else 0.0,
        "mean_lambda2_current_weighted": float(np.mean([r["lambda2_current_weighted"] for r in step_rows])) if step_rows else 0.0,

        # ACLD confusion
        "acld_precision": m_acld["precision"],
        "acld_recall": m_acld["recall"],
        "acld_f1": m_acld["f1"],
        "acld_fpr": m_acld["fpr"],
        "acld_acc": m_acld["acc"],

        # Tarjan guard
        "tarjan_guard_precision": m_tg["precision"],
        "tarjan_guard_recall": m_tg["recall"],
        "tarjan_guard_f1": m_tg["f1"],
        "tarjan_guard_fpr": m_tg["fpr"],
        "tarjan_guard_acc": m_tg["acc"],

        # Tarjan geom-only
        "tarjan_geom_precision": m_tgeom["precision"],
        "tarjan_geom_recall": m_tgeom["recall"],
        "tarjan_geom_f1": m_tgeom["f1"],
        "tarjan_geom_fpr": m_tgeom["fpr"],
        "tarjan_geom_acc": m_tgeom["acc"],

        # Edge F1
        "edge_f1_acld_mean": _nanmean(e1_acld_list),
        "edge_f1_tarjan_mean": _nanmean(e1_tarjan_list),
        "edge_f1_ebc_mean": _nanmean(e1_ebc_list),
    }
    return step_rows, summary


def run_scenario_mc(
    base_cfg: ACLDConfig,
    spec: ScenarioSpec,
) -> Tuple[List[Dict[str, object]], List[Dict[str, object]], Dict[str, object]]:
    """
    Run one evaluated episode per seed of ``spec`` and aggregate the episode summaries.

    Returns:
      scenario_episode_rows (one row per seed),
      scenario_step_rows (all steps, tagged),
      aggregated_summary
    """
    cfg = _clone_cfg(base_cfg, **spec.cfg_overrides)

    # Columns shared by every output row of this scenario; a missing JNR is written
    # as the string "None".
    cfg_tags = {
        "n_uavs": cfg.n_uavs,
        "density_scale_mode": cfg.density_scale_mode,
        "jnr_db": cfg.jnr_db if cfg.jnr_db is not None else "None",
    }

    episode_rows: List[Dict[str, object]] = []
    all_step_rows: List[Dict[str, object]] = []

    for seed in spec.seeds:
        step_rows, ep_summary = run_evaluated_episode(
            cfg=cfg,
            horizon_s=spec.horizon_s,
            seed=seed,
            ebc_every_k_steps=spec.ebc_every_k_steps,
        )
        episode_rows.append({"scenario": spec.name, "seed": seed, **cfg_tags, **ep_summary})
        all_step_rows.extend(
            {**row, "scenario": spec.name, "seed": seed, **cfg_tags} for row in step_rows
        )

    # Aggregate across seeds
    def _mean_std(key: str) -> Tuple[float, float]:
        # `v == v` is False only for NaN, so NaN episode values are left out.
        finite = [float(row[key]) for row in episode_rows if key in row and row[key] == row[key]]
        if not finite:
            return float("nan"), float("nan")
        return float(np.mean(finite)), float(np.std(finite))

    detectors = ("acld", "tarjan_guard", "tarjan_geom")
    rates = ("precision", "recall", "f1", "fpr", "acc")
    keys_to_aggregate = (
        ["partition_next_rate", "runtime_median_ms", "runtime_p95_ms", "mean_lambda2_current_weighted"]
        + [f"{det}_{rate}" for det in detectors for rate in rates]
        + ["edge_f1_acld_mean", "edge_f1_tarjan_mean", "edge_f1_ebc_mean"]
    )

    agg: Dict[str, object] = {"scenario": spec.name, "n_episodes": len(episode_rows), **cfg_tags}
    for key in keys_to_aggregate:
        agg[f"{key}_mean"], agg[f"{key}_std"] = _mean_std(key)

    return episode_rows, all_step_rows, agg


# --------------------------
# Plotting helpers
# --------------------------

def _plot_timeseries_representative(step_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    """
    Save four time-series figures for one episode into ``outdir``:
    connectivity traces, risky-link counts, component counts with the next-step
    partition label, and per-step runtime. Does nothing when ``step_rows`` is empty.
    """
    if not step_rows:
        return

    def _column(key: str) -> np.ndarray:
        return np.array([row[key] for row in step_rows], dtype=float)

    # Extract every series up front.
    t = _column("t")
    lam = _column("lambda2_current_weighted")
    lam_pred = _column("lambda2_pred_viable_bin")
    lam_next = _column("lambda2_next_viable_bin")
    r_inter = _column("n_risky_inter")
    r_intra = _column("n_risky_intra")
    r_s = _column("n_risky_s_uav")
    r_uu = _column("n_risky_uav_uav")
    comp_pred = _column("n_components_pred_guar")
    comp_next = _column("n_components_next_viable")
    rt = _column("runtime_ms")
    y = _column("label_partition_next")

    def _finish(fig, ax, ylabel: str, title: str, filename: str, legend: bool = True) -> None:
        ax.set_xlabel("Time (s)")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if legend:
            ax.legend()
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, filename), dpi=180)
        plt.close(fig)

    # Figure 1: lambda2 trends
    fig, ax = plt.subplots(figsize=(10, 4.8))
    ax.plot(t, lam, label="λ2 current (weighted)")
    ax.plot(t, lam_pred, label="λ2 pred viable (binary)")
    ax.plot(t, lam_next, label="λ2 next viable (binary)")
    _finish(fig, ax, "Connectivity metric", "Connectivity traces over time", f"{prefix}_01_lambda2_timeseries.png")

    # Figure 2: risky link counts
    fig, ax = plt.subplots(figsize=(10, 4.8))
    ax.plot(t, r_inter, label="risky inter")
    ax.plot(t, r_intra, label="risky intra")
    ax.plot(t, r_s, label="risky S-UAV")
    ax.plot(t, r_uu, label="risky UAV-UAV")
    _finish(fig, ax, "Count", "Risky link counts over time", f"{prefix}_02_risky_link_counts.png")

    # Figure 3: component counts + partition label
    fig, ax = plt.subplots(figsize=(10, 4.8))
    ax.plot(t, comp_pred, label="pred guaranteed components")
    ax.plot(t, comp_next, label="next viable components")
    ax.step(t, y, where="post", label="partition label (next)", alpha=0.8)
    _finish(
        fig,
        ax,
        "Count / label",
        "Predicted vs actual next-step partition behavior",
        f"{prefix}_03_components_partition_label.png",
    )

    # Figure 4: runtime (single unlabeled curve, so no legend)
    fig, ax = plt.subplots(figsize=(10, 4.8))
    ax.plot(t, rt)
    _finish(fig, ax, "Runtime (ms)", "ACLD per-step runtime", f"{prefix}_04_runtime_timeseries.png", legend=False)


def _plot_runtime_scaling(bench_rows: List[Dict[str, float]], outdir: str, prefix: str = "fig") -> None:
    """
    Save the runtime-vs-swarm-size figure (median, inter-quartile band, 95th percentile)
    built from benchmark rows. Does nothing when ``bench_rows`` is empty.
    """
    if not bench_rows:
        return

    def _column(key: str) -> np.ndarray:
        return np.array([row[key] for row in bench_rows], dtype=float)

    n = _column("n")
    med = _column("median_ms")
    q1 = _column("q1_ms")
    q3 = _column("q3_ms")
    p95 = _column("p95_ms")

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(n, med, marker="o", label="Median")
    ax.fill_between(n, q1, q3, alpha=0.2, label="IQR")
    ax.plot(n, p95, marker="s", linestyle="--", label="P95")
    ax.set_xlabel("Number of UAVs")
    ax.set_ylabel("Runtime (ms)")
    ax.set_title("Runtime scaling vs swarm size")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"{prefix}_05_runtime_scaling.png"), dpi=180)
    plt.close(fig)


def _plot_scenario_bars(agg_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    """
    Save three grouped bar charts comparing methods across scenarios:
    partition-early-warning F1, false-positive rate, and edge-level F1.
    Metrics missing from a row are plotted as NaN. Does nothing when ``agg_rows`` is empty.
    """
    if not agg_rows:
        return

    names = [str(row["scenario"]) for row in agg_rows]
    x = np.arange(len(names), dtype=float)
    w = 0.23

    def _metric(key: str) -> np.ndarray:
        return np.array([float(row.get(key, np.nan)) for row in agg_rows], dtype=float)

    def _grouped_bars(series, ylabel: str, title: str, filename: str) -> None:
        # `series` holds three (metric key, legend label) pairs, drawn left / centre / right.
        values = [(_metric(key), label) for key, label in series]
        fig, ax = plt.subplots(figsize=(max(9, 1.2 * len(names)), 5.2))
        for offset, (heights, label) in zip((-w, 0.0, w), values):
            ax.bar(x + offset, heights, width=w, label=label)
        ax.set_xticks(x)
        ax.set_xticklabels(names, rotation=30, ha="right")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, axis="y", alpha=0.3)
        fig.tight_layout()
        fig.savefig(os.path.join(outdir, filename), dpi=180)
        plt.close(fig)

    _grouped_bars(
        [
            ("acld_f1_mean", "ACLD"),
            ("tarjan_guard_f1_mean", "Tarjan+guard"),
            ("tarjan_geom_f1_mean", "Tarjan geom-only"),
        ],
        "F1 (partition early warning)",
        "PEW F1 across scenarios",
        f"{prefix}_06_pew_f1_scenarios.png",
    )

    # FPR comparison
    _grouped_bars(
        [
            ("acld_fpr_mean", "ACLD"),
            ("tarjan_guard_fpr_mean", "Tarjan+guard"),
            ("tarjan_geom_fpr_mean", "Tarjan geom-only"),
        ],
        "False Positive Rate",
        "False-positive behavior across scenarios",
        f"{prefix}_07_fpr_scenarios.png",
    )

    # Edge F1 comparison
    _grouped_bars(
        [
            ("edge_f1_acld_mean_mean", "ACLD"),
            ("edge_f1_tarjan_mean_mean", "Tarjan+guard"),
            ("edge_f1_ebc_mean_mean", "EBC@t top-k"),
        ],
        "Edge-level F1 (vs next bridges)",
        "Edge-level critical-link prediction quality",
        f"{prefix}_08_edge_f1_scenarios.png",
    )


def _plot_runtime_boxplot_per_scenario(episode_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    """
    Save a box plot of per-episode median runtimes, one box per scenario (in order
    of first appearance). Does nothing when ``episode_rows`` is empty.
    """
    if not episode_rows:
        return

    grouped: Dict[str, List[float]] = {}
    for row in episode_rows:
        scenario = str(row["scenario"])
        if scenario not in grouped:
            grouped[scenario] = []
        grouped[scenario].append(float(row["runtime_median_ms"]))

    names = list(grouped)
    data = [grouped[name] for name in names]

    fig, ax = plt.subplots(figsize=(max(8, 1.1 * len(names)), 5))
    ax.boxplot(data, labels=names, showmeans=True)
    # Rotate the existing tick labels without changing tick positions.
    for tick_label in ax.get_xticklabels():
        tick_label.set_rotation(30)
        tick_label.set_ha("right")
    ax.set_ylabel("Episode median runtime (ms)")
    ax.set_title("Runtime distribution across scenarios/seeds")
    ax.grid(True, axis="y", alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"{prefix}_09_runtime_boxplot_scenarios.png"), dpi=180)
    plt.close(fig)


def _plot_jamming_sweep(jam_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    """
    Save the jamming stress-sweep figure: mean weighted lambda2, partition-next rate and
    ACLD early-warning F1 versus jammer-to-noise ratio.

    Rows are ordered by ``jnr_db`` ascending. The nominal (no jamming) row, stored as
    ``None`` or the string "None", is always placed first, including when other rows
    have a negative JNR. Does nothing when ``jam_rows`` is empty.
    """
    if not jam_rows:
        return

    def _jnr_key(row: Dict[str, object]) -> float:
        jnr = row["jnr_db"]
        # -inf (rather than a finite sentinel) keeps the nominal row ahead of any dB value.
        if jnr is None or jnr == "None":
            return float("-inf")
        return float(jnr)

    rows = sorted(jam_rows, key=_jnr_key)
    labels = [str(r["jnr_db"]) for r in rows]
    # Rows are placed at equally spaced categorical positions, not at their dB value.
    x = np.arange(len(rows), dtype=float)

    lam = np.array([float(r["mean_lambda2_current_weighted_mean"]) for r in rows], dtype=float)
    part = np.array([float(r["partition_next_rate_mean"]) for r in rows], dtype=float)
    acld_f1 = np.array([float(r["acld_f1_mean"]) for r in rows], dtype=float)

    plt.figure(figsize=(9, 5))
    plt.plot(x, lam, marker="o", label="mean λ2 (weighted)")
    plt.plot(x, part, marker="s", label="partition-next rate")
    plt.plot(x, acld_f1, marker="^", label="ACLD PEW F1")
    plt.xticks(x, labels)
    plt.xlabel("JNR (dB) [None=nominal]")
    plt.ylabel("Metric value")
    plt.title("Jamming stress sweep")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, f"{prefix}_10_jamming_sweep.png"), dpi=180)
    plt.close()


def _plot_density_sweep(density_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    """
    Save the density-sweep figure: partition-next rate, ACLD early-warning F1 and ACLD
    false-positive rate per density mode, ordered sparse -> medium -> dense (unknown
    modes last). Does nothing when ``density_rows`` is empty.
    """
    if not density_rows:
        return

    rank = {"sparse": 0, "medium": 1, "dense": 2}

    def _density_rank(row: Dict[str, object]) -> int:
        return rank.get(str(row["density_scale_mode"]), 999)

    rows = sorted(density_rows, key=_density_rank)
    names = [str(row["density_scale_mode"]) for row in rows]
    x = np.arange(len(names), dtype=float)

    def _column(key: str) -> np.ndarray:
        return np.array([float(row[key]) for row in rows], dtype=float)

    part = _column("partition_next_rate_mean")
    acld_f1 = _column("acld_f1_mean")
    fpr = _column("acld_fpr_mean")

    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(x, part, marker="o", label="partition-next rate")
    ax.plot(x, acld_f1, marker="s", label="ACLD PEW F1")
    ax.plot(x, fpr, marker="^", label="ACLD FPR")
    ax.set_xticks(x)
    ax.set_xticklabels(names)
    ax.set_xlabel("Density mode")
    ax.set_ylabel("Metric value")
    ax.set_title("Density sweep")
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(os.path.join(outdir, f"{prefix}_11_density_sweep.png"), dpi=180)
    plt.close(fig)


def _plot_ablation(ablation_rows: List[Dict[str, object]], outdir: str, prefix: str = "fig") -> None:
    if not ablation_rows:
        return
    names = [str(r["scenario"]) for r in ablation_rows]
    x = np.arange(len(names), dtype=float)
    acld_f1 = np.array([float(r["acld_f1_mean"]) for r in ablation_rows], dtype=float)
    acld_fpr = np.array([float(r["acld_fpr_mean"]) for r in ablation_rows], dtype=float)
    edge_f1 = np.array([float(r["edge_f1_acld_mean_mean"]) for r in ablation_rows], dtype=float)

    plt.figure(figsize=(max(9, 1.2 * len(names)), 5))
    plt.plot(x, acld_f1, marker="o", label="ACLD PEW F1")
    plt.plot(x, acld_fpr, marker="s", label="ACLD FPR")
    plt.plot(x, edge_f1, marker="^", label="ACLD edge F1")
    plt.xticks(x, names, rotation=30, ha="right")
    plt.ylabel("Metric")
    plt.title("Ablation study (adaptive terms)")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, f"{prefix}_12_ablation.png"), dpi=180)
    plt.close()


# --------------------------
# Full paper-like suite
# --------------------------

def run_full_paper_suite(
    outdir: str = "acld_paper_outputs",
    base_cfg: Optional[ACLDConfig] = None,
    fast_mode: bool = False,
) -> Dict[str, object]:
    """
    Run the whole experiment suite and write every artefact under ``outdir``.

    Stages, in execution order:
      1. representative episode  -> raw CSV/JSON + timeseries figures
      2. main scenario MC set    -> episode / step / aggregated CSVs + plots
      3. runtime scaling benchmark -> CSV + plot
      4. jamming sweep           -> CSV + plot
      5. density sweep           -> CSV + plot
      6. ablation study over the adaptive gains -> CSV + plot
      7. compact summary         -> ``suite_summary.json``

    Figures go to ``<outdir>/figures``, tables to ``<outdir>/tables`` and raw
    dumps to ``<outdir>/raw``.

    Args:
        outdir: root output directory.
        base_cfg: configuration every scenario starts from. When None, a
            150-UAV, ground-station, medium-density, no-jamming config with
            seed 42 is used.
        fast_mode: when True, use shorter horizons, 3 seeds instead of 5,
            20 benchmark repeats instead of 60, and benchmark sizes up to 300
            instead of 500.

    Returns:
        The summary dict that is also saved as ``suite_summary.json``.
    """
    figdir = os.path.join(outdir, "figures")
    tabdir = os.path.join(outdir, "tables")
    rawdir = os.path.join(outdir, "raw")
    for folder in (outdir, figdir, tabdir, rawdir):
        _ensure_dir(folder)

    if base_cfg is None:
        base_cfg = ACLDConfig(
            n_uavs=150,
            include_ground_station=True,
            density_scale_mode="medium",
            jnr_db=None,
            seed=42,
        )

    # Speed presets
    if fast_mode:
        rep_horizon, mc_horizon = 5.0, 6.0
        seeds_small = (42, 43, 44)
        repeats_bench, bench_sizes = 20, (50, 150, 300)
    else:
        rep_horizon, mc_horizon = 15.0, 12.0
        seeds_small = (42, 43, 44, 45, 46)
        repeats_bench, bench_sizes = 60, (50, 150, 300, 500)

    def _mc_spec(name, overrides):
        # Every Monte Carlo scenario shares the same horizon and seed set.
        return ScenarioSpec(name, overrides, horizon_s=mc_horizon, seeds=seeds_small)

    def _run_sweep(specs, csv_name, plot_fn):
        # Run each spec, keep only the aggregated row, then save and plot.
        collected = []
        for sweep_spec in specs:
            _, _, sweep_agg = run_scenario_mc(base_cfg, sweep_spec)
            collected.append(sweep_agg)
        _save_rows_csv(os.path.join(tabdir, csv_name), collected)
        plot_fn(collected, figdir)

    # 1) Representative episode (timeseries figures)
    rep_spec = ScenarioSpec(
        "representative_nominal",
        {"jnr_db": None, "density_scale_mode": "medium"},
        horizon_s=rep_horizon,
        seeds=(base_cfg.seed,),
        ebc_every_k_steps=5,
    )
    rep_ep_rows, rep_step_rows, rep_agg = run_scenario_mc(base_cfg, rep_spec)
    _save_rows_csv(os.path.join(rawdir, "representative_episode_summary.csv"), rep_ep_rows)
    _save_rows_csv(os.path.join(rawdir, "representative_episode_steps.csv"), rep_step_rows)
    _save_json(os.path.join(rawdir, "representative_episode_agg.json"), rep_agg)
    # Keep only the step rows of the base seed, ordered by time.
    rep_seed_rows = sorted(
        (row for row in rep_step_rows if int(row["seed"]) == int(base_cfg.seed)),
        key=lambda row: float(row["t"]),
    )
    _plot_timeseries_representative(rep_seed_rows, figdir)

    # 2) Main scenario set (nominal / jamming / density)
    main_plan = (
        ("nominal_medium", None, "medium"),
        ("jam_5dB_medium", 5.0, "medium"),
        ("jam_15dB_medium", 15.0, "medium"),
        ("nominal_sparse", None, "sparse"),
        ("nominal_dense", None, "dense"),
    )
    scenarios = [
        _mc_spec(name, {"jnr_db": jnr, "density_scale_mode": mode})
        for name, jnr, mode in main_plan
    ]

    all_episode_rows: List[Dict[str, object]] = []
    all_step_rows: List[Dict[str, object]] = []
    agg_rows: List[Dict[str, object]] = []
    for scenario in scenarios:
        episode_part, step_part, agg_part = run_scenario_mc(base_cfg, scenario)
        all_episode_rows += episode_part
        all_step_rows += step_part
        agg_rows.append(agg_part)

    _save_rows_csv(os.path.join(tabdir, "scenario_episode_metrics.csv"), all_episode_rows)
    _save_rows_csv(os.path.join(rawdir, "scenario_step_metrics.csv"), all_step_rows)
    _save_rows_csv(os.path.join(tabdir, "scenario_aggregated_summary.csv"), agg_rows)

    _plot_scenario_bars(agg_rows, figdir)
    _plot_runtime_boxplot_per_scenario(all_episode_rows, figdir)

    # 3) Runtime scaling benchmark
    bench_rows = benchmark_runtime_vs_n(base_cfg, sizes=bench_sizes, repeats=repeats_bench)
    _save_rows_csv(os.path.join(tabdir, "runtime_scaling.csv"), bench_rows)
    _plot_runtime_scaling(bench_rows, figdir)

    # NOTE(review): stages 4-6 re-simulate configurations already run in stage 2
    # (e.g. nominal/medium). Kept as-is so each table is produced independently.

    # 4) Jamming sweep
    def _jnr_tag(jnr):
        return "nominal" if jnr is None else f"{int(jnr)}dB"

    jam_specs = [
        _mc_spec(f"jnr_{_jnr_tag(jnr)}", {"jnr_db": jnr, "density_scale_mode": "medium"})
        for jnr in (None, 0.0, 5.0, 10.0, 15.0, 20.0)
    ]
    _run_sweep(jam_specs, "jamming_sweep_summary.csv", _plot_jamming_sweep)

    # 5) Density sweep
    density_specs = [
        _mc_spec(f"density_{mode}", {"jnr_db": None, "density_scale_mode": mode})
        for mode in ("sparse", "medium", "dense")
    ]
    _run_sweep(density_specs, "density_sweep_summary.csv", _plot_density_sweep)

    # 6) Ablation study (adaptive terms): each entry names the gains forced to 0.
    ablation_plan = (
        ("ablation_full", ()),
        ("ablation_no_alpha", ("alpha",)),
        ("ablation_no_beta", ("beta",)),
        ("ablation_no_gamma", ("gamma",)),
        ("ablation_no_delta", ("delta",)),
        ("ablation_no_zeta", ("zeta",)),
        ("ablation_fixed_cthr", ("alpha", "beta", "gamma", "delta", "zeta")),
    )
    ablation_specs = [
        _mc_spec(
            name,
            # Zeroed gains first, then the nominal settings (same key order as before).
            {**dict.fromkeys(zeroed, 0.0), "jnr_db": None, "density_scale_mode": "medium"},
        )
        for name, zeroed in ablation_plan
    ]
    _run_sweep(ablation_specs, "ablation_summary.csv", _plot_ablation)

    # 7) Compact summary JSON
    cfg_snapshot = {
        "n_uavs": base_cfg.n_uavs,
        "include_ground_station": base_cfg.include_ground_station,
        "density_scale_mode": base_cfg.density_scale_mode,
        "jnr_db": base_cfg.jnr_db,
        "seed": base_cfg.seed,
        "p_min": base_cfg.p_min,
    }
    summary = {
        "outdir": outdir,
        "fast_mode": fast_mode,
        "base_cfg": cfg_snapshot,
        "files": {
            "figures_dir": figdir,
            "tables_dir": tabdir,
            "raw_dir": rawdir,
        },
        "n_main_scenarios": len(scenarios),
        "n_jam_sweep_points": len(jam_specs),
        "n_density_sweep_points": len(density_specs),
        "n_ablations": len(ablation_specs),
    }
    _save_json(os.path.join(outdir, "suite_summary.json"), summary)
    return summary

In [None]:
# ============================================================
# Demo main
# ============================================================


if __name__ == "__main__":
    # Full experiment suite (quick test)
    summary = run_full_paper_suite(
        outdir="acld_paper_outputs",
        base_cfg=ACLDConfig(
            n_uavs=150,
            include_ground_station=True,
            density_scale_mode="medium",
            jnr_db=None,
            seed=42,
        ),
        fast_mode=True,   # test with True first, then switch to False
    )
    print("Suite summary:")
    print(json.dumps(summary, indent=2, ensure_ascii=False))

# NOTE(review): the triple-quoted string below is a disabled alternative demo
# (single-episode run plus a small runtime benchmark). It is never executed;
# it is only a bare string expression. As the last expression of a notebook
# cell it will presumably be echoed as the cell's output -- consider moving it
# into its own cell or a markdown note.
"""if __name__ == "__main__":
    # Example 1: single episode demo
    cfg = ACLDConfig(
        n_uavs=50,
        include_ground_station=True,
        density_scale_mode="medium",
        jnr_db=None,   # nominal
        seed=42,
    )

    logs, results = run_episode(cfg, horizon_s=2.0, seed=42)
    summary = summarize_episode(logs)

    print("=== ACLD Episode Summary ===")
    for k, v in summary.items():
        print(f"{k}: {v}")

    last = results[-1]
    print("\n=== Last-step diagnostics ===")
    print(f"lambda2_current      : {last.lambda2_current:.6f}")
    print(f"lambda2_pred_viable  : {last.lambda2_pred_viable:.6f}")
    print(f"pred components (guar): {len(last.components)}")
    print(f"risky_inter          : {len(last.risky_inter)}")
    print(f"risky_intra          : {len(last.risky_intra)}")
    print(f"risky_s_uav          : {len(last.risky_s_uav)}")
    print(f"risky_uav_uav        : {len(last.risky_uav_uav)}")

    # Example 2: runtime benchmark (small repeats for demo)
    print("\n=== Runtime benchmark (demo repeats=20) ===")
    bench_rows = benchmark_runtime_vs_n(cfg, sizes=(50, 150, 300), repeats=20)
    for row in bench_rows:
        print(row) """