In [8]:
#!/usr/bin/env python3
"""
APPROACH 2 — FULL END-TO-END IMPLEMENTATION (DEBUG/PRINT HEAVY)

What this script does (in order):

1) Build a gridworld with two corridors separated by a wall (two “gaps”).
2) Place two sensors: one near each gap. The *hidden sensor mode* m ∈ {0,1}
   chooses which sensor is the active hotspot (high detection probability region).
3) Robot receives noisy alarm measurements o_t ∈ {0,1}. Robot keeps a belief b_t(m)
   and updates it with Bayes' rule.

4) Define a small finite set of robot policies Π_R and sensor policies Π_S.
   Each policy is treated as a pure action in an empirical normal-form game.

5) Estimate empirical payoff matrices U_R, U_S by Monte Carlo rollouts.

6) Compute Larson Nash Bargaining Solution (NBS) over a joint distribution x over
   joint policy pairs (π_R, π_S) using projected gradient ascent on:
       g(x) = log(u_R^T x - d_R) + log(u_S^T x - d_S)

7) Extract marginals σ_R, σ_S from x, compute best responses:
   - Robot BR: risk-weighted A* path under expected risk map induced by σ_S.
   - Sensor BR: brute-force best fixed mode against σ_R (Monte Carlo).

8) Iterate steps 5-7 (outer loop), printing each step.

9) Final stage: “MultiNash-PF-like” multimodal trajectory demo:
   - Take top-k joint pairs under x
   - Convert discrete paths to continuous trajectories
   - Sample noisy “particles” around each ref trajectory
   - Refine with a cheap local smoother (stand-in for IPOPT refinement)
   - Cluster refined trajectories by discrete Fréchet distance
   - Print discovered “modes” (clusters)

Dependencies:
  - Python 3.9+
  - numpy

Run:
  python approach2_nbs_stealth.py

Suggested first run:
  python approach2_nbs_stealth.py --outer-iters 2 --rollouts-payoff 10 --rollouts-br 15

If you want VERY verbose per-step rollouts:
  python approach2_nbs_stealth.py --debug-one-rollout-per-pair

"""

from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Dict, Optional, Callable, Any
import argparse
import heapq
import math
import random
import statistics
import time
import sys

import numpy as np


# =============================================================================
# Small helper: consistent printing
# =============================================================================

def log(msg: str) -> None:
    """Write ``msg`` to stdout and flush right away so long runs show progress live."""
    print(msg, end="\n", flush=True)


# =============================================================================
# Grid World + Sensor Model
# =============================================================================

# A robot action is a one-step grid displacement.
Action = Tuple[int, int]  # (dx, dy)

# Named displacements. Note that "UP" *decreases* y: row y=0 is the first row
# printed by print_grid_ascii, so y grows downward on screen.
MOVES: Dict[str, Action] = {
    "UP": (0, -1),
    "DOWN": (0, 1),
    "LEFT": (-1, 0),
    "RIGHT": (1, 0),
    "STAY": (0, 0),  # no displacement
}


@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid map: size, endpoints and blocked cells."""
    width: int   # number of columns; in-bounds x satisfies 0 <= x < width
    height: int  # number of rows; in-bounds y satisfies 0 <= y < height
    start: Tuple[int, int]  # (x, y) cell the robot is placed on at reset
    goal: Tuple[int, int]   # (x, y) cell that ends an episode when reached
    obstacles: frozenset  # set of blocked cells (x,y)


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor model parameters shared by every sensor mode."""
    sensors: Tuple[Tuple[int, int], ...]  # sensor locations; mode=m chooses sensors[m] hotspot
    radius: int                            # Manhattan radius of hotspot region
    base_p: float                          # baseline detection probability
    hotspot_p: float                       # detection probability inside hotspot


class GridWorldStealthEnv:
    """
    A very small "stealth" grid environment.

    The robot moves on a 4-connected grid. After every move a detection event and a
    noisy binary alarm are sampled from the detection probability of the cell the
    robot ends up on.

    The hidden sensor "mode" is not part of the robot-visible state: it is a parameter
    written into ``sensor_mode`` by whoever drives the env (the sensor policy), while
    the robot maintains a belief over modes using the observation model exposed here.

    Args:
        grid: map description (size, start, goal, obstacles).
        sensor_cfg: sensor locations, hotspot radius and detection probabilities.
        fp: false-positive rate of the alarm channel, in [0, 1].
        fn: false-negative rate of the alarm channel, in [0, 1].
        max_steps: time step at which ``step`` reports ``done`` even without a
            detection or goal arrival. Defaults to 200, the value that used to be
            hard-coded inside ``step``.

    Raises:
        ValueError: if ``fp``/``fn`` are outside [0, 1] or ``max_steps`` < 1.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        max_steps: int = 200,
    ):
        self.grid = grid
        self.sensor_cfg = sensor_cfg

        # Observation noise:
        #   fp = false positive rate for alarm channel
        #   fn = false negative rate
        self.fp = float(fp)
        self.fn = float(fn)
        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")

        self.max_steps = int(max_steps)
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1.")

        # Private RNG so seeding the env does not disturb the global `random` state.
        self._rng = random.Random()
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Re-seed the env's private RNG (detection and alarm sampling)."""
        self._rng.seed(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """
        Start a new episode: robot back at ``grid.start``, clock and risk at zero.

        Returns:
            dict with the initial ``pos`` and ``t``.
        """
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = int(sensor_mode)
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """True when cell ``p`` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """True when cell ``p`` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """
        True detection probability at position p when active mode=mode.
        Mode selects which sensor center is "hot".

        Raises:
            IndexError: if ``mode`` does not index ``sensor_cfg.sensors``. Negative
                modes are rejected explicitly; plain tuple indexing would silently
                wrap them around to the last sensor.
        """
        sensors = self.sensor_cfg.sensors
        mode = int(mode)
        if not (0 <= mode < len(sensors)):
            raise IndexError(f"sensor mode {mode} out of range for {len(sensors)} sensor(s)")

        sx, sy = sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)  # Manhattan distance
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """
        Likelihood of the noisy alarm o∈{0,1} given the true detection probability.

        Probability alarm=1:
          p_alarm = p_true*(1-fn) + (1-p_true)*fp

        Raises:
            ValueError: if ``alarm`` is neither 0 nor 1.
        """
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """
        Execute robot action, update robot position, sample detection, and output alarm observation.

        Returns dict with:
          pos, t, alarm, p_true, detected, done
        """
        if self.detected:
            # absorbing terminal: no clock tick, no sampling
            return {
                "pos": self.pos,
                "t": self.t,
                "alarm": 1,
                "p_true": 1.0,
                "detected": True,
                "done": True,
            }

        self.t += 1

        # If invalid move or obstacle -> stay in place
        candidate = (self.pos[0] + int(a[0]), self.pos[1] + int(a[1]))
        if self.is_free(candidate):
            self.pos = candidate

        # true risk at new position
        p_true = self.true_detection_prob(self.pos, self.sensor_mode)
        self.total_true_risk += p_true

        # Draw order matters for seeded reproducibility: detection first, alarm second.
        if self._rng.random() < p_true:
            self.detected = True

        # Single source of truth for the alarm channel (same formula the belief uses).
        p_alarm = self.observation_prob(1, p_true)
        alarm = 1 if (self._rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {
            "pos": self.pos,
            "t": self.t,
            "alarm": alarm,
            "p_true": p_true,
            "detected": self.detected,
            "done": done,
        }


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """
    Create the two-corridor map: a full-height vertical wall in the middle column,
    pierced by two openings so the robot can cross through an upper or a lower gap.

    The robot starts near the bottom-left corner and must reach a goal near the
    top-right corner.

    Raises:
        RuntimeError: if the chosen size puts the start or goal on the wall.
    """
    wall_x = width // 2
    gap_ys = {2, 6}  # rows left open in the wall

    # Every wall-column cell except the two gaps is blocked.
    obstacles = {(wall_x, y) for y in range(height) if y not in gap_ys}

    start = (1, height - 2)
    goal = (width - 2, 1)

    if start in obstacles or goal in obstacles:
        raise RuntimeError("Start/goal ended up blocked; change map parameters.")

    return GridConfig(
        width=width,
        height=height,
        start=start,
        goal=goal,
        obstacles=frozenset(obstacles),
    )


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """
    Log the map one row per line.

    Legend (first match wins): ``R`` start, ``G`` goal, ``S`` sensor,
    ``#`` obstacle, ``.`` free cell.
    """
    blocked = set(grid.obstacles)
    sensor_cells = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        # Precedence is deliberate: endpoints hide sensors, sensors hide walls.
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in sensor_cells:
            return "S"
        if cell in blocked:
            return "#"
        return "."

    for y in range(grid.height):
        log("".join([glyph((x, y)) for x in range(grid.width)]))


# =============================================================================
# Robot belief over modes: exact Bayes filter for discrete mode set
# =============================================================================

class ModeBelief:
    """
    Categorical belief b(m) over the hidden sensor mode m ∈ {0,...,M-1}.

    The filter is exact because the mode set is finite:
        b_{t+1}(m) ∝ P(o_{t+1} | pos_{t+1}, m) * b_t(m)
    """

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """
        Args:
            M: number of modes (must be positive).
            init: optional nonnegative weights of length M; they are normalised,
                and an all-zero vector falls back to the uniform belief.

        Raises:
            ValueError: for non-positive M, wrong-length or negative ``init``.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be positive.")

        if init is None:
            self.b = np.full(self.M, 1.0 / self.M)
            return

        weights = np.asarray(init, dtype=float).reshape(-1)
        if weights.shape != (self.M,):
            raise ValueError("init belief has wrong shape.")
        if np.any(weights < 0):
            raise ValueError("init belief must be nonnegative.")

        total = float(weights.sum())
        if total > 0:
            self.b = weights / total
        else:
            self.b = np.full(self.M, 1.0 / self.M)

    def update(
        self,
        env: GridWorldStealthEnv,
        alarm: int,
        pos: Tuple[int, int],
        kappa: float = 1e-12,
        verbose: bool = False,
    ) -> None:
        """
        One Bayes step: b'(m) ∝ P(alarm | pos, m) * b(m).

        If the normaliser is non-finite or smaller than ``kappa`` the prior is kept
        unchanged instead of dividing by (almost) zero.
        """
        # Per-mode likelihood of the observed alarm at this position.
        likelihood = np.array(
            [
                env.observation_prob(alarm, env.true_detection_prob(pos, mode))
                for mode in range(self.M)
            ],
            dtype=float,
        )

        prior = self.b
        unnormalised = prior * likelihood
        Z = float(unnormalised.sum())

        usable = np.isfinite(Z) and not (Z < kappa)
        if usable:
            post = unnormalised / Z
        else:
            # If something goes numerically wrong, keep prior.
            if verbose:
                log(f"[BeliefUpdate] WARNING: normalization Z={Z}. Keeping prior.")
            post = prior

        if verbose:
            log(
                f"[BeliefUpdate] pos={pos} alarm={alarm} "
                f"prior={prior.round(3)} like={likelihood.round(3)} post={post.round(3)}"
            )

        self.b = post


# =============================================================================
# A* pathfinding (robot BR oracle uses this)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    verbose: bool = False,
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """
    Classic A* on a 4-neighbor grid with a custom step cost.

    Args:
        grid: object exposing ``width``, ``height`` and ``obstacles``.
        start: cell to start from.
        goal: cell to reach.
        step_cost: cost of moving ``frm -> to``. The Manhattan heuristic used here
            is only admissible when every step costs at least 1; cheaper steps can
            yield a sub-optimal path.
        verbose: log progress every 5000 expansions and on success.
        max_expansions: safety cap on the number of node expansions.

    Returns:
        List of cells ``[start, ..., goal]``.

    Raises:
        RuntimeError: if the goal is unreachable or the expansion cap is exceeded.
    """

    def h(p: Tuple[int, int]) -> float:
        # Manhattan heuristic
        return abs(p[0] - goal[0]) + abs(p[1] - goal[1])

    def neighbors(p: Tuple[int, int]):
        x, y = p
        for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            np_ = (x + dx, y + dy)
            if 0 <= np_[0] < grid.width and 0 <= np_[1] < grid.height and np_ not in grid.obstacles:
                yield np_

    # Heap entries are (f, g, cell); a cell may appear several times because we
    # push a fresh entry on every improvement instead of doing decrease-key.
    open_heap: List[Tuple[float, float, Tuple[int, int]]] = []
    heapq.heappush(open_heap, (h(start), 0.0, start))

    came_from: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    gscore: Dict[Tuple[int, int], float] = {start: 0.0}

    expansions = 0

    while open_heap:
        f, g, cur = heapq.heappop(open_heap)

        # Stale entry: a cheaper route to `cur` was pushed later and has already
        # been expanded. Re-expanding cannot improve any neighbor, so skip it and
        # do not count it against max_expansions.
        if g > gscore[cur]:
            continue

        expansions += 1

        if verbose and expansions % 5000 == 0:
            log(f"[A*] expansions={expansions} open={len(open_heap)} cur={cur} g={g:.2f} f={f:.2f}")

        if cur == goal:
            # Walk the parent links back to the start, then flip the order.
            path: List[Tuple[int, int]] = []
            node: Optional[Tuple[int, int]] = cur
            while node is not None:
                path.append(node)
                node = came_from[node]
            path.reverse()
            if verbose:
                log(f"[A*] SUCCESS path_len={len(path)} expansions={expansions}")
            return path

        if expansions > max_expansions:
            raise RuntimeError(f"A* exceeded max_expansions={max_expansions}; maybe unreachable?")

        for nb in neighbors(cur):
            tentative = gscore[cur] + float(step_cost(cur, nb))
            # The tiny tolerance avoids re-pushing on floating-point noise.
            if nb not in gscore or tentative < gscore[nb] - 1e-12:
                gscore[nb] = tentative
                came_from[nb] = cur
                heapq.heappush(open_heap, (tentative + h(nb), tentative, nb))

    raise RuntimeError("A* failed: goal unreachable.")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    verbose: bool = False,
) -> List[Tuple[int, int]]:
    """
    Plan start -> waypoint -> goal as two independent A* searches and stitch them.

    The waypoint appears exactly once in the returned path.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost, verbose=verbose)
    second_leg = astar_path(grid, waypoint, goal, step_cost, verbose=verbose)

    # The first leg ends on the waypoint and the second starts on it: keep one copy.
    stitched = list(first_leg[:-1])
    stitched.extend(second_leg)
    return stitched


# =============================================================================
# Policies (Robot + Sensor)
# =============================================================================

class RobotPolicy:
    """Base interface every robot controller implements."""
    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Called once per episode with the robot's start cell; default is a no-op."""
        return None

    def act(
        self,
        env: GridWorldStealthEnv,
        belief: ModeBelief,
        last_obs: Optional[int],
        verbose: bool = False,
    ) -> Action:
        """Return the (dx, dy) action for the current step; subclasses must override."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """
    Robot follows a fixed precomputed path (interpretable macro strategy).

    The policy keeps an index into ``path`` and emits the unit move toward the
    next cell. Once the last cell is reached it keeps returning STAY.
    """

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """
        Args:
            path: cells to visit in order; must contain at least start and goal.
            name: label used in logs.

        Raises:
            ValueError: if ``path`` has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("path must include start and goal.")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the cursor at ``start_pos`` if it lies on the path, else at index 0."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def act(
        self,
        env: GridWorldStealthEnv,
        belief: ModeBelief,
        last_obs: Optional[int],
        verbose: bool = False,
    ) -> Action:
        """
        Return the move from the robot's current cell toward the next path cell.

        ``belief`` is only used for logging; ``last_obs`` is ignored.
        """
        cur = env.pos
        last_index = len(self.path) - 1
        if self._idx >= last_index:
            return MOVES["STAY"]

        # If we somehow deviated, try to resync (robustness). Only cells at or
        # after the cursor are considered.
        if cur != self.path[self._idx]:
            try:
                self._idx = self.path.index(cur, self._idx)
            except ValueError:
                if verbose:
                    log(f"[{self.name}] WARNING: off-path at {cur}; STAY.")
                return MOVES["STAY"]
            # The resync may land on the final cell, where there is no next cell
            # to head for; without this guard the lookup below raises IndexError.
            if self._idx >= last_index:
                return MOVES["STAY"]

        nxt = self.path[self._idx + 1]
        # Clip so a non-adjacent next cell still yields a legal unit step.
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        self._idx += 1

        if verbose:
            log(f"[{self.name}] cur={cur} -> nxt={nxt} act={(dx, dy)} idx={self._idx} belief={belief.b.round(3)}")

        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """
    Baseline robot policy: uniformly random valid move (STAY included).

    Moves are drawn from the environment's own RNG when it has one, so a rollout
    seeded through ``env.seed(...)`` is reproducible for this policy too. The
    global ``random`` module is used only as a fallback.
    """

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Stateless policy: nothing to reset."""
        pass

    def act(
        self,
        env: GridWorldStealthEnv,
        belief: ModeBelief,
        last_obs: Optional[int],
        verbose: bool = False,
    ) -> Action:
        """Pick one of the moves that keep the robot on a free cell."""
        cands: List[Action] = []
        for a in [(1, 0), (-1, 0), (0, 1), (0, -1), (0, 0)]:
            np_ = (env.pos[0] + a[0], env.pos[1] + a[1])
            if env.is_free(np_):
                cands.append(a)

        # Share the env's seeded stream; the global module RNG is never seeded by
        # the rollout code and would make "seeded" episodes non-deterministic.
        rng = getattr(env, "_rng", None)
        if rng is None:
            rng = random

        a = rng.choice(cands) if cands else (0, 0)
        if verbose:
            log(f"[{self.name}] cur={env.pos} act={a} belief={belief.b.round(3)}")
        return a


class SensorPolicy:
    """Base interface for the sensor-side controller."""
    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Called once per episode; default is a no-op."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the active sensor mode for time step ``t``; subclasses must override."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor controller that keeps the hotspot on one sensor for the whole episode."""

    def __init__(self, mode: int, name: Optional[str] = None):
        """
        Args:
            mode: index of the sensor that stays hot.
            name: log label; an empty or missing name becomes ``S_Mode<mode>``.
        """
        self.mode = int(mode)
        if name:
            self.name = name
        else:
            self.name = f"S_Mode{mode}"

    def reset(self) -> None:
        """Stateless policy: nothing to reset."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the fixed mode regardless of ``t``."""
        return self.mode


# =============================================================================
# Rollout simulator (empirical payoff estimation)
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of a single rollout, as produced by rollout_episode."""
    steps: int               # number of env.step calls made in the episode
    reached_goal: bool       # robot's final cell equals the goal cell
    detected: bool           # a true detection event was sampled
    total_true_risk: float   # sum over steps of the true detection probability
    robot_utility: float     # -(steps + lambda_risk*risk + det_penalty if detected)
    sensor_utility: float    # det_penalty if detected + lambda_risk*risk - energy*steps


def rollout_episode(
    env: GridWorldStealthEnv,
    robot_policy: RobotPolicy,
    sensor_policy: SensorPolicy,
    M_modes: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    seed: Optional[int] = None,
    verbose: bool = False,
) -> EpisodeStats:
    """
    Simulate one episode of (robot_policy, sensor_policy) and score both players.

    Robot:
      cost_R = steps + lambda_risk * sum_t p_true(t) + det_penalty * 1{detected}
      utility_R = -cost_R

    Sensor:
      utility_S = det_penalty * 1{detected} + lambda_risk * sum_t p_true(t)
                  - sensor_energy_per_step * steps

    Args:
        env: environment to drive; it is re-seeded when ``seed`` is given and reset.
        robot_policy: controller queried once per step for an action.
        sensor_policy: controller queried once per step for the active mode.
        M_modes: number of sensor modes the robot's belief ranges over.
        max_steps: cap on the number of steps taken by this loop.
        lambda_risk: weight of accumulated true risk in both utilities.
        det_penalty: robot penalty / sensor reward for a detection.
        sensor_energy_per_step: per-step cost charged to the sensor.
        seed: optional seed for the env RNG.
        verbose: log every step.

    Returns:
        EpisodeStats for the episode.
    """
    if seed is not None:
        env.seed(seed)

    sensor_policy.reset()
    env.reset(sensor_mode=sensor_policy.select_mode(0))

    belief = ModeBelief(M_modes)  # uniform initial belief
    robot_policy.reset(env.pos)

    heavy_rule = "=" * 80
    if verbose:
        log(heavy_rule)
        log(f"[Rollout] START robot={robot_policy.name} sensor={sensor_policy.name} start={env.pos} goal={env.grid.goal}")
        log(f"[Rollout] belief0={belief.b.round(3)} over modes 0..{M_modes-1}")
        log(heavy_rule)

    total_risk = 0.0
    steps = 0
    last_alarm = None

    while steps < max_steps:
        # The sensor may switch mode at every step.
        env.sensor_mode = sensor_policy.select_mode(env.t)

        # The robot decides from its belief and the previous alarm.
        action = robot_policy.act(env, belief, last_alarm, verbose=verbose)

        out = env.step(action)
        steps += 1
        total_risk += out["p_true"]

        # Fold the new alarm into the belief.
        last_alarm = out["alarm"]
        belief.update(env, last_alarm, out["pos"], verbose=verbose)

        if verbose:
            log(f"[Rollout] t={out['t']:3d} pos={out['pos']} a={action} p_true={out['p_true']:.3f} "
                f"alarm={out['alarm']} detected={out['detected']} done={out['done']}")

        if out["done"]:
            break

    reached_goal = (env.pos == env.grid.goal)
    detected = env.detected

    # Both players share the detection and risk terms, with opposite signs.
    detection_term = det_penalty if detected else 0.0
    risk_term = lambda_risk * total_risk

    robot_utility = -(steps + risk_term + detection_term)
    sensor_utility = detection_term + risk_term - sensor_energy_per_step * steps

    if verbose:
        log("-" * 80)
        log(f"[Rollout] END steps={steps} reached_goal={reached_goal} detected={detected} total_risk={total_risk:.3f}")
        log(f"[Rollout] U_R={robot_utility:.3f} U_S={sensor_utility:.3f}")
        log(heavy_rule)

    return EpisodeStats(
        steps=steps,
        reached_goal=reached_goal,
        detected=detected,
        total_true_risk=total_risk,
        robot_utility=robot_utility,
        sensor_utility=sensor_utility,
    )


# =============================================================================
# Empirical payoff estimation: U_R, U_S matrices
# =============================================================================

def evaluate_payoff_matrices(
    env: GridWorldStealthEnv,
    robot_policies: List[RobotPolicy],
    sensor_policies: List[SensorPolicy],
    M_modes: int,
    N_rollouts: int = 30,
    base_seed: int = 123,
    debug_one_rollout_per_pair: bool = False,
    verbose: bool = True,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """
    Estimate empirical payoff matrices by Monte Carlo rollouts.

    Every (robot policy i, sensor policy j) pair is rolled out ``N_rollouts`` times
    with seeds ``base_seed + 100000*i + 1000*j + k`` and the utilities are averaged.

    Args:
        env: environment reused for every rollout (re-seeded and reset each time).
        robot_policies: the m robot pure strategies.
        sensor_policies: the n sensor pure strategies.
        M_modes: number of sensor modes for the robot's belief.
        N_rollouts: rollouts per pair; must be at least 1.
        base_seed: offset of the per-rollout seeds.
        debug_one_rollout_per_pair: with ``verbose``, additionally run one fully
            logged rollout per pair (same seed as that pair's first rollout).
        verbose: log a summary line per pair.

    Returns:
      U_R: (m,n)
      U_S: (m,n)
      diagnostics[(i,j)] = extra stats (det_rate, goal_rate, etc.)

    Raises:
        ValueError: if ``N_rollouts`` < 1. Previously this produced NaN means and
            then a ZeroDivisionError while building the diagnostics.
    """
    if N_rollouts < 1:
        raise ValueError(f"N_rollouts must be >= 1, got {N_rollouts}.")

    m = len(robot_policies)
    n = len(sensor_policies)

    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diagnostics: Dict[Tuple[int, int], Dict[str, float]] = {}

    if verbose:
        log("\n" + "#" * 80)
        log(f"[EvalPayoffs] m={m} robot policies, n={n} sensor policies, N_rollouts={N_rollouts}, base_seed={base_seed}")
        log("#" * 80)

    for i, rp in enumerate(robot_policies):
        for j, sp in enumerate(sensor_policies):
            # Disjoint seed range per pair, as long as N_rollouts <= 1000.
            pair_seed = base_seed + 100000 * i + 1000 * j

            if debug_one_rollout_per_pair and verbose:
                log("\n" + "-" * 80)
                log(f"[EvalPayoffs][DEBUG] One fully-verbose rollout for (R{i}:{rp.name}, S{j}:{sp.name})")
                rollout_episode(
                    env,
                    rp,
                    sp,
                    M_modes=M_modes,
                    seed=pair_seed,
                    verbose=True,
                )
                log("-" * 80 + "\n")

            r_utils: List[float] = []
            s_utils: List[float] = []
            detected_count = 0
            goal_count = 0
            steps_list: List[int] = []
            risk_list: List[float] = []

            for k in range(N_rollouts):
                stats = rollout_episode(env, rp, sp, M_modes=M_modes, seed=pair_seed + k, verbose=False)
                r_utils.append(stats.robot_utility)
                s_utils.append(stats.sensor_utility)
                detected_count += int(stats.detected)
                goal_count += int(stats.reached_goal)
                steps_list.append(stats.steps)
                risk_list.append(stats.total_true_risk)

            U_R[i, j] = float(np.mean(r_utils))
            U_S[i, j] = float(np.mean(s_utils))

            diagnostics[(i, j)] = {
                "det_rate": detected_count / N_rollouts,
                "goal_rate": goal_count / N_rollouts,
                "mean_steps": float(np.mean(steps_list)),
                "mean_risk": float(np.mean(risk_list)),
                "std_UR": float(np.std(r_utils)),
                "std_US": float(np.std(s_utils)),
            }

            if verbose:
                d = diagnostics[(i, j)]
                log(
                    f"[EvalPayoffs] (R{i}:{rp.name}, S{j}:{sp.name}) -> "
                    f"U_R={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                    f"U_S={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                    f"det%={d['det_rate']*100:5.1f} goal%={d['goal_rate']*100:5.1f} "
                    f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
                )

    if verbose:
        log("#" * 80 + "\n")

    return U_R, U_S, diagnostics


# =============================================================================
# Larson NBS meta-solver: projected gradient ascent on log Nash product
# =============================================================================

def project_to_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """
    Euclidean projection of ``v`` onto the scaled simplex
      { x >= 0, sum(x) = z }
    using the sort-and-threshold method.

    Falls back to the uniform point ``z/n`` when no support is found or the
    thresholded vector has no usable mass.

    Raises:
        ValueError: if ``z`` is not strictly positive.
    """
    if z <= 0:
        raise ValueError("z must be > 0")

    vec = np.asarray(v, dtype=float).reshape(-1)
    n = vec.size

    def uniform_point() -> np.ndarray:
        return np.full(n, z / n)

    # Sorted descending, with running sums, to locate the support size.
    descending = np.sort(vec)[::-1]
    running_sum = np.cumsum(descending)
    ranks = np.arange(1, n + 1)
    in_support = np.flatnonzero(descending * ranks > (running_sum - z))
    if in_support.size == 0:
        return uniform_point()

    last = int(in_support[-1])
    theta = (running_sum[last] - z) / (last + 1.0)

    shifted = np.maximum(vec - theta, 0.0)
    mass = float(shifted.sum())
    if (not np.isfinite(mass)) or mass <= 0:
        return uniform_point()

    # Final rescale removes round-off so the result sums to z.
    return shifted * (z / mass)


@dataclass
class NBSResult:
    """Output of solve_nbs_projected_gradient."""
    x: np.ndarray                    # final joint distribution over joint actions (flat vector)
    history: List[Dict[str, float]]  # one record per accepted iteration: iter, obj, gain_R, gain_S, alpha, delta_L1, grad_norm


def solve_nbs_projected_gradient(
    uR: np.ndarray,
    uS: np.ndarray,
    max_iters: int = 400,
    alpha: float = 0.5,
    alpha_min: float = 1e-6,
    tol: float = 1e-6,
    kappa: float = 1e-6,
    verbose: bool = True,
) -> NBSResult:
    """
    Solve NBS on joint distribution x over joint actions:

      maximize g(x) = log(uR^T x - dR) + log(uS^T x - dS)
      s.t. x in simplex

    Disagreement:
      dR = min(uR) - 1
      dS = min(uS) - 1

    Gradient:
      ∇g(x) = uR/(uR^T x - dR) + uS/(uS^T x - dS)

    Update:
      y = x + alpha * ∇g(x)
      x = Proj_simplex(y)

    Includes backtracking line search for robustness.

    Args:
        uR: robot payoffs, one entry per joint action (any shape; flattened).
        uS: sensor payoffs, same number of entries as ``uR`` after flattening.
        max_iters: maximum number of outer iterations.
        alpha: initial step size tried at every iteration.
        alpha_min: the line search gives up once the halved step drops below this.
        tol: stop when the L1 change of x between iterations is below this.
        kappa: floor applied to both gains before taking logs / dividing.
        verbose: log progress and the final top joint actions.

    Returns:
        NBSResult with the final ``x`` and the per-iteration history.

    Raises:
        ValueError: if the flattened payoffs differ in shape or have fewer than
            two entries.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape.")
    d = uR.size
    if d < 2:
        raise ValueError("Need at least 2 joint actions to compute NBS.")

    # Disagreement point sits one unit below each player's worst payoff, so for
    # any x on the simplex both gains are at least 1 (u^T x >= min(u)).
    dR = float(np.min(uR) - 1.0)
    dS = float(np.min(uS) - 1.0)

    x = np.full(d, 1.0 / d, dtype=float)  # start uniform

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        # Expected payoff of each player minus their disagreement value.
        gR = float(uR @ xv - dR)
        gS = float(uS @ xv - dS)
        return gR, gS

    def objective(xv: np.ndarray) -> float:
        # Log Nash product, with gains floored at kappa to keep the logs finite.
        gR, gS = gains(xv)
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        return float(np.log(gR) + np.log(gS))

    def grad(xv: np.ndarray) -> np.ndarray:
        # Gradient of the objective, using the same kappa floor as `objective`.
        gR, gS = gains(xv)
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        return (uR / gR) + (uS / gS)

    hist: List[Dict[str, float]] = []
    last_obj = objective(x)

    if verbose:
        gR, gS = gains(x)
        log("\n" + "#" * 80)
        log("[NBS] Starting projected gradient ascent")
        log(f"[NBS] d={d} joint actions | dR={dR:.3f}, dS={dS:.3f}")
        log(f"[NBS] init gains=(R:{gR:.3f}, S:{gS:.3f}) obj={last_obj:.6f}")
        log("#" * 80)

    for t in range(1, max_iters + 1):
        gvec = grad(x)
        gnorm = float(np.linalg.norm(gvec))
        if not np.isfinite(gnorm) or gnorm == 0.0:
            if verbose:
                log(f"[NBS] WARNING: bad gradient norm at iter={t}: {gnorm}")
            break

        # Backtracking line search: start from `alpha` every iteration and halve
        # until the projected point does not decrease the objective (up to 1e-12).
        a = alpha
        improved = False
        for _ in range(30):
            y = x + a * gvec
            x_new = project_to_simplex(y)
            obj_new = objective(x_new)
            if obj_new >= last_obj - 1e-12:
                improved = True
                break
            a *= 0.5
            if a < alpha_min:
                break

        if not improved:
            if verbose:
                log(f"[NBS] Line search failed at iter={t}. Stopping.")
            break

        # Accept the step; delta_l1 is measured before x is overwritten.
        delta_l1 = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last_obj = obj_new
        gR, gS = gains(x)

        # Only accepted iterations are recorded; all values are stored as floats.
        hist.append(
            {
                "iter": float(t),
                "obj": float(last_obj),
                "gain_R": float(gR),
                "gain_S": float(gS),
                "alpha": float(a),
                "delta_L1": float(delta_l1),
                "grad_norm": float(gnorm),
            }
        )

        # Log the first ten iterations, then every 25th.
        if verbose and (t <= 10 or t % 25 == 0):
            topk = np.argsort(-x)[:5]
            top_str = ", ".join([f"{idx}:{x[idx]:.3f}" for idx in topk])
            log(
                f"[NBS] iter={t:4d} obj={last_obj:.6f} gains(R={gR:.3f},S={gS:.3f}) "
                f"alpha={a:.3g} L1delta={delta_l1:.3g} top={top_str}"
            )

        # Converged: the iterate barely moved.
        if delta_l1 < tol:
            if verbose:
                log(f"[NBS] Converged at iter={t} (L1delta={delta_l1:.2e} < tol={tol})")
            break

    if verbose:
        log("-" * 80)
        log("[NBS] Finished. Top joint actions (index:prob):")
        for idx in np.argsort(-x)[:10]:
            log(f"  {idx:3d}: {x[idx]:.6f}")
        log(f"[NBS] sum(x)={x.sum():.6f} (should be 1.0)")
        log("#" * 80 + "\n")

    return NBSResult(x=x, history=hist)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """
    View the flat joint distribution as an (m, n) float matrix
    (rows = robot policies, columns = sensor policies).

    Raises:
        ValueError: if ``x`` does not hold exactly ``m * n`` entries.
    """
    flat = np.asarray(x, dtype=float).reshape(-1)
    expected = m * n
    if flat.size != expected:
        raise ValueError("x size mismatch.")
    return flat.reshape(m, n)


def marginals_from_joint(x: np.ndarray, m: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """
    Split a joint distribution over (robot, sensor) pairs into its two marginals.

    Returns:
        (sigma_R, sigma_S): length-m robot marginal and length-n sensor marginal,
        each renormalised to sum to 1 whenever its total mass is positive.
    """
    X = joint_to_matrix(x, m, n)

    def normalised(weights: np.ndarray) -> np.ndarray:
        total = weights.sum()
        return weights / total if total > 0 else weights

    sigma_R = normalised(X.sum(axis=1))  # collapse the sensor axis
    sigma_S = normalised(X.sum(axis=0))  # collapse the robot axis
    return sigma_R, sigma_S


# =============================================================================
# Best Response Oracles (tractable, no deep nets)
# =============================================================================

def build_initial_robot_policies(env: GridWorldStealthEnv) -> List[RobotPolicy]:
    """
    Seed the robot's policy set with four interpretable strategies, in this order:
      - R_Shortest: unit-cost shortest path
      - R_UpperCorridor: shortest path forced through the wall cell at y=2
      - R_LowerCorridor: shortest path forced through the wall cell at y=6
      - R_Random: random-move baseline
    """
    grid = env.grid

    def unit_cost(a, b):
        return 1.0

    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)

    # The corridor waypoints are the two gap cells in the middle wall column.
    wall_x = grid.width // 2
    via_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    policies: List[RobotPolicy] = [
        FixedPathPolicy(cells, label)
        for cells, label in (
            (shortest, "R_Shortest"),
            (via_upper, "R_UpperCorridor"),
            (via_lower, "R_LowerCorridor"),
        )
    ]
    policies.append(RandomPolicy("R_Random"))
    return policies


def build_initial_sensor_policies(M_modes: int) -> List[SensorPolicy]:
    """Return one fixed-mode sensor policy per mode, named ``S_Mode<m>``, in mode order."""
    policies: List[SensorPolicy] = []
    for mode in range(M_modes):
        policies.append(FixedModeSensorPolicy(mode, name=f"S_Mode{mode}"))
    return policies


def find_sensor_policy_by_mode(sensor_policies: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first fixed-mode policy in the list whose mode equals ``mode``, or None."""
    matches = (
        candidate
        for candidate in sensor_policies
        if isinstance(candidate, FixedModeSensorPolicy) and candidate.mode == mode
    )
    return next(matches, None)


def compute_expected_risk_map(
    env: GridWorldStealthEnv,
    sensor_policies: List[SensorPolicy],
    sigma_S: np.ndarray,
    M_modes: int,
    verbose: bool = True,
) -> np.ndarray:
    """
    Expected true detection probability of every cell under the sensor mixture.

    Each sensor policy is represented by the mode it selects at t=0 (the policies in
    this script are fixed-mode), so ``sigma_S`` induces a distribution over modes.

    Returns:
        (height, width) array indexed ``[y, x]``; obstacle cells hold NaN.

    Raises:
        ValueError: if ``sigma_S`` has a different length than ``sensor_policies``
            or a policy reports a mode outside ``0..M_modes-1``.
    """
    sigma_S = np.asarray(sigma_S, dtype=float).reshape(-1)
    if sigma_S.size != len(sensor_policies):
        raise ValueError("sigma_S length mismatch.")

    W, H = env.grid.width, env.grid.height

    # Fold the mixture over policies into a mixture over modes.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(sigma_S, sensor_policies):
        m = policy.select_mode(0)
        if not (0 <= m < M_modes):
            raise ValueError(f"sensor policy returned invalid mode {m}")
        mode_probs[m] += weight
    if mode_probs.sum() > 0:
        mode_probs = mode_probs / mode_probs.sum()

    # Start from "everything blocked" and fill in the free cells.
    risk = np.full((H, W), np.nan, dtype=float)
    for y in range(H):
        for x in range(W):
            cell = (x, y)
            if cell in env.grid.obstacles:
                continue
            expected = 0.0
            for m in range(M_modes):
                expected += mode_probs[m] * env.true_detection_prob(cell, m)
            risk[y, x] = expected

    if verbose:
        finite = risk[np.isfinite(risk)]
        log("\n" + "#" * 80)
        log("[RiskMap] Expected risk map computed from sensor mixture")
        log(f"[RiskMap] sigma_S={sigma_S.round(3)} -> mode_probs={mode_probs.round(3)}")
        log(f"[RiskMap] risk stats: min={finite.min():.3f} mean={finite.mean():.3f} max={finite.max():.3f}")
        log("#" * 80 + "\n")

    return risk


def robot_best_response_from_sigmaS(
    env: GridWorldStealthEnv,
    sensor_policies: List[SensorPolicy],
    sigma_S: np.ndarray,
    existing_robot_policies: List[RobotPolicy],
    M_modes: int,
    risk_weight: float = 12.0,
    verbose: bool = True,
) -> RobotPolicy:
    """
    Robot best response to the sensor mixture sigma_S.

    Steps:
      1. build the expected risk map induced by sigma_S;
      2. run A* from the grid start to the grid goal where entering a cell
         costs 1 + risk_weight * risk(cell);
      3. wrap the resulting path in a FixedPathPolicy.

    If an existing FixedPathPolicy already follows exactly that path, that
    object is returned instead of creating a duplicate.
    """
    expected_risk = compute_expected_risk_map(env, sensor_policies, sigma_S, M_modes=M_modes, verbose=verbose)

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        # Cost depends only on the destination cell; non-finite risk (walls)
        # is priced prohibitively high.
        col, row = to
        cell_risk = expected_risk[row, col]
        if np.isfinite(cell_risk):
            return 1.0 + risk_weight * float(cell_risk)
        return 1e9

    if verbose:
        log("[RobotBR] Running risk-weighted A* ...")
        log(f"[RobotBR] risk_weight={risk_weight}")

    path = astar_path(env.grid, env.grid.start, env.grid.goal, step_cost, verbose=verbose)

    # Reuse an existing policy when it already encodes this exact path.
    wanted = tuple(path)
    duplicate = next(
        (
            known
            for known in existing_robot_policies
            if isinstance(known, FixedPathPolicy) and tuple(known.path) == wanted
        ),
        None,
    )
    if duplicate is not None:
        if verbose:
            log(f"[RobotBR] BR path already exists as '{duplicate.name}'. Returning existing.")
        return duplicate

    name = f"R_BR_RiskAStar_w{risk_weight:.1f}_len{len(path)}"
    if verbose:
        log(f"[RobotBR] Created NEW robot policy: {name}")

    return FixedPathPolicy(path, name=name)


def sensor_best_response_from_sigmaR(
    env: GridWorldStealthEnv,
    robot_policies: List[RobotPolicy],
    sigma_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    N_rollouts: int = 30,
    base_seed: int = 999,
    verbose: bool = True,
) -> FixedModeSensorPolicy:
    """
    Sensor best response to the robot mixture sigma_R, by brute force.

    For every candidate mode inside [0, M_modes), run N_rollouts episodes in
    which the robot policy index is drawn from sigma_R, and average the sensor
    utility. The mode with the strictly highest mean wins (ties keep the
    earliest candidate).

    Raises:
      ValueError: if sigma_R and robot_policies differ in length.
      RuntimeError: if no candidate mode produced a usable estimate.
    """
    sigma_R = np.asarray(sigma_R, dtype=float).reshape(-1)
    n_robot = len(robot_policies)
    if sigma_R.size != n_robot:
        raise ValueError("sigma_R length mismatch.")

    # One generator for all policy-index draws, shared across candidate modes.
    rng = np.random.default_rng(base_seed)

    if verbose:
        log("\n" + "#" * 80)
        log("[SensorBR] Searching best-response fixed sensor mode against robot mixture")
        log(f"[SensorBR] sigma_R={sigma_R.round(3)}")
        log(f"[SensorBR] candidate_modes={candidate_modes}")
        log("#" * 80)

    best_mode: Optional[int] = None
    best_val = -1e18

    for mode in candidate_modes:
        if not (0 <= mode < M_modes):
            continue

        trial_policy = FixedModeSensorPolicy(mode, name=f"S_BR_Mode{mode}")
        utilities: List[float] = [
            rollout_episode(
                env,
                robot_policies[int(rng.choice(n_robot, p=sigma_R))],
                trial_policy,
                M_modes=M_modes,
                seed=base_seed + 10000 * mode + k,
                verbose=False,
            ).sensor_utility
            for k in range(N_rollouts)
        ]

        mean_u = float(np.mean(utilities))
        std_u = float(np.std(utilities))

        if verbose:
            log(f"[SensorBR] mode={mode} -> E[U_S]={mean_u:8.3f} (std={std_u:6.2f})")

        if mean_u > best_val:
            best_val, best_mode = mean_u, mode

    if best_mode is None:
        raise RuntimeError("No valid sensor mode found.")

    if verbose:
        log("-" * 80)
        log(f"[SensorBR] Best response mode={best_mode} with estimated E[U_S]={best_val:.3f}")
        log("#" * 80 + "\n")

    return FixedModeSensorPolicy(best_mode, name=f"S_BR_Mode{best_mode}")


# =============================================================================
# Outer loop: Empirical game + NBS + BR expansion
# =============================================================================

def run_approach2_pipeline(
    verbose: bool = True,
    outer_iters: int = 3,
    rollouts_payoff: int = 20,
    rollouts_br: int = 30,
    risk_weight_br: float = 12.0,
    M_modes: int = 2,
    seed: int = 0,
    debug_one_rollout_per_pair: bool = False,
) -> Dict[str, Any]:
    """
    Full pipeline:
      - build env
      - init Π_R, Π_S
      - repeat:
          evaluate payoffs
          solve NBS for x*
          compute marginals σ_R, σ_S
          compute BR policies and expand sets

    Both best responses of an iteration are computed against the policy sets
    that were evaluated in that iteration (the sets σ_R / σ_S are indexed by),
    even though the robot set may already have been expanded by the time the
    sensor best response runs.

    Returns:
      Dict with keys "env", "robot_policies", "sensor_policies", "x_star",
      "sigma_R", "sigma_S", "outer_history". The returned "x_star", "sigma_R"
      and "sigma_S" are indexed consistently with the returned policy lists:
      policies added by the final iteration's best responses were never
      evaluated, so they receive probability zero. All three are None when
      outer_iters < 1.
    """
    grid = build_two_corridor_grid()
    sensor_cfg = SensorConfig(
        sensors=((grid.width // 2, 2), (grid.width // 2, 6)),  # sensor centers near the two gaps
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10)
    env.seed(seed)

    if verbose:
        log("\n" + "=" * 100)
        log("[PIPELINE] Approach 2: Empirical POMDP Policy Game + Larson NBS + BR Expansion")
        log("=" * 100)
        log(f"[PIPELINE] Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log(f"[PIPELINE] Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
        log("[PIPELINE] ASCII map (R=start, G=goal, S=sensors, #=wall):")
        print_grid_ascii(grid, sensor_cfg)
        log("=" * 100 + "\n")

    robot_policies = build_initial_robot_policies(env)
    sensor_policies = build_initial_sensor_policies(M_modes)

    if verbose:
        log("[PIPELINE] Initial Robot Policies:")
        for i, p in enumerate(robot_policies):
            log(f"  R{i}: {p.name}")
        log("[PIPELINE] Initial Sensor Policies:")
        for j, p in enumerate(sensor_policies):
            log(f"  S{j}: {p.name}")
        log("")

    x_star: Optional[np.ndarray] = None
    sigma_R: Optional[np.ndarray] = None
    sigma_S: Optional[np.ndarray] = None
    outer_history: List[Dict[str, float]] = []

    # Matrix form and shape (|Pi_R|, |Pi_S|) of the most recent x*; needed
    # after the loop to re-align x* with the (possibly expanded) policy sets.
    X_solved: Optional[np.ndarray] = None
    solved_shape: Optional[Tuple[int, int]] = None

    for k in range(1, outer_iters + 1):
        if verbose:
            log("\n" + "=" * 100)
            log(f"[PIPELINE] OUTER ITERATION {k}/{outer_iters}")
            log("=" * 100)

        # 1) empirical payoffs
        U_R, U_S, diag = evaluate_payoff_matrices(
            env,
            robot_policies,
            sensor_policies,
            M_modes=M_modes,
            N_rollouts=rollouts_payoff,
            base_seed=1000 + 100 * k,
            debug_one_rollout_per_pair=debug_one_rollout_per_pair,
            verbose=verbose,
        )

        # 2) flatten payoffs into vectors over joint actions
        m = len(robot_policies)
        n = len(sensor_policies)
        # Snapshot of the robot set the payoffs (and hence sigma_R) refer to.
        # robot_policies may grow in step 5, before the sensor BR in step 6.
        evaluated_robot_policies = list(robot_policies)
        uR_vec = U_R.reshape(-1)  # length d=m*n
        uS_vec = U_S.reshape(-1)

        if verbose:
            log("[PIPELINE] Flattened payoff vectors for NBS:")
            log(f"[PIPELINE] m={m}, n={n}, d=m*n={m*n}")
            log(f"[PIPELINE] uR range=[{uR_vec.min():.3f}, {uR_vec.max():.3f}]")
            log(f"[PIPELINE] uS range=[{uS_vec.min():.3f}, {uS_vec.max():.3f}]")
            log("")

        # 3) solve NBS
        nbs_res = solve_nbs_projected_gradient(uR_vec, uS_vec, max_iters=300, alpha=0.5, tol=1e-6, verbose=verbose)
        x_star = nbs_res.x

        # 4) compute marginals
        X = joint_to_matrix(x_star, m, n)
        sigma_R, sigma_S = marginals_from_joint(x_star, m, n)
        X_solved = X
        solved_shape = (m, n)

        if verbose:
            log("[PIPELINE] NBS joint distribution x* reshaped as matrix X (rows=robot, cols=sensor):")
            log(np.array_str(X, precision=3, suppress_small=True))
            log(f"[PIPELINE] Robot marginal sigma_R: {sigma_R.round(3)}")
            log(f"[PIPELINE] Sensor marginal sigma_S: {sigma_S.round(3)}")
            log("")
            log("[PIPELINE] Top-5 joint policy pairs (i,j) by x* probability:")
            top_pairs = np.argsort(-x_star)[:5]
            for rank, flat_idx in enumerate(top_pairs, start=1):
                i, j = np.unravel_index(int(flat_idx), (m, n))
                log(f"  #{rank}: (R{i}:{robot_policies[i].name}, S{j}:{sensor_policies[j].name}) prob={X[i,j]:.4f}")
            log("")

        # 5) robot BR to sigma_S (sensor_policies is still the evaluated set here)
        br_robot = robot_best_response_from_sigmaS(
            env,
            sensor_policies,
            sigma_S,
            existing_robot_policies=robot_policies,
            M_modes=M_modes,
            risk_weight=risk_weight_br,
            verbose=verbose,
        )
        if br_robot not in robot_policies:
            robot_policies.append(br_robot)
            if verbose:
                log(f"[PIPELINE] Added NEW robot policy: {br_robot.name}")
        else:
            if verbose:
                log(f"[PIPELINE] Robot BR already in set: {br_robot.name}")

        # 6) sensor BR to sigma_R. Use the snapshot: sigma_R has one weight per
        #    *evaluated* robot policy, and the callee rejects a length mismatch.
        br_sensor = sensor_best_response_from_sigmaR(
            env,
            evaluated_robot_policies,
            sigma_R,
            candidate_modes=list(range(M_modes)),
            M_modes=M_modes,
            N_rollouts=rollouts_br,
            base_seed=2000 + 100 * k,
            verbose=verbose,
        )

        existing = find_sensor_policy_by_mode(sensor_policies, br_sensor.mode)
        if existing is None:
            sensor_policies.append(br_sensor)
            if verbose:
                log(f"[PIPELINE] Added NEW sensor policy: {br_sensor.name}")
        else:
            if verbose:
                log(f"[PIPELINE] Sensor BR mode already present as '{existing.name}'. Not adding duplicate.")

        outer_history.append(
            {
                "outer_iter": float(k),
                "num_robot": float(len(robot_policies)),
                "num_sensor": float(len(sensor_policies)),
                "max_joint_prob": float(np.max(x_star)),
                "nbs_obj_last": float(nbs_res.history[-1]["obj"]) if nbs_res.history else float("nan"),
            }
        )

    # The last iteration may have expanded the policy sets after x* was solved.
    # Re-align x* and its marginals with the returned policy lists by giving
    # the never-evaluated policies zero probability. The joint vector is
    # flattened row-major (robot index major), matching the reshape(-1) /
    # np.unravel_index(..., (m, n)) convention used above.
    if x_star is not None and X_solved is not None and solved_shape is not None:
        m_eval, n_eval = solved_shape
        m_final, n_final = len(robot_policies), len(sensor_policies)
        if (m_final, n_final) != (m_eval, n_eval):
            X_padded = np.zeros((m_final, n_final), dtype=float)
            X_padded[:m_eval, :n_eval] = np.asarray(X_solved, dtype=float).reshape(m_eval, n_eval)
            x_star = X_padded.reshape(-1)
            sigma_R = np.concatenate(
                [np.asarray(sigma_R, dtype=float).reshape(-1), np.zeros(m_final - m_eval, dtype=float)]
            )
            sigma_S = np.concatenate(
                [np.asarray(sigma_S, dtype=float).reshape(-1), np.zeros(n_final - n_eval, dtype=float)]
            )
            if verbose:
                log(
                    f"[PIPELINE] Padded x*/sigma_R/sigma_S with zeros: solved over {m_eval}x{n_eval}, "
                    f"final policy sets are {m_final}x{n_final}"
                )

    if verbose:
        log("\n" + "=" * 100)
        log("[PIPELINE] FINISHED OUTER LOOP")
        log("=" * 100)
        log(f"[PIPELINE] Final |Pi_R|={len(robot_policies)}, |Pi_S|={len(sensor_policies)}")
        log("[PIPELINE] Final Robot Policies:")
        for i, p in enumerate(robot_policies):
            log(f"  R{i}: {p.name}")
        log("[PIPELINE] Final Sensor Policies:")
        for j, p in enumerate(sensor_policies):
            log(f"  S{j}: {p.name}")
        log("=" * 100 + "\n")

    return {
        "env": env,
        "robot_policies": robot_policies,
        "sensor_policies": sensor_policies,
        "x_star": x_star,
        "sigma_R": sigma_R,
        "sigma_S": sigma_S,
        "outer_history": outer_history,
    }


# =============================================================================
# Final stage: MultiNash-PF-like multimodal trajectory demo (runnable stand-in)
# =============================================================================

def grid_path_to_continuous(path: List[Tuple[int, int]]) -> np.ndarray:
    """
    Convert a sequence of grid cells (x, y) into a float array of 2D points.

    Returns:
      Array of shape (len(path), 2). An empty path yields shape (0, 2) so the
      result is always two-dimensional.
    """
    points = np.array([[float(x), float(y)] for (x, y) in path], dtype=float)
    # np.array([]) is 1-D; force the (N, 2) layout even when N == 0.
    return points.reshape(-1, 2)


def discrete_frechet(P: np.ndarray, Q: np.ndarray) -> float:
    """
    Discrete Fréchet distance between two point sequences P and Q.
    Used for clustering trajectories into modes.

    Implemented as an iterative dynamic program (one row of the coupling table
    at a time), so sequence length is not limited by Python's recursion depth.

    Args:
      P: sequence of points, first axis indexes the points.
      Q: sequence of points with the same point dimensionality as P.

    Returns:
      The discrete Fréchet distance under the Euclidean point metric.

    Raises:
      ValueError: if either sequence is empty.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    m, n = P.shape[0], Q.shape[0]
    if m == 0 or n == 0:
        raise ValueError("discrete_frechet requires two non-empty point sequences.")

    # Flatten each point so scalar points and N-d points are handled uniformly.
    P_flat = P.reshape(m, -1)
    Q_flat = Q.reshape(n, -1)

    prev_row: List[float] = []
    for i in range(m):
        # Distances from P[i] to every point of Q.
        dists = np.linalg.norm(Q_flat - P_flat[i], axis=1).tolist()
        row: List[float] = [0.0] * n
        for j in range(n):
            d = dists[j]
            if i == 0 and j == 0:
                row[j] = d
            elif j == 0:
                row[j] = max(prev_row[0], d)
            elif i == 0:
                row[j] = max(row[j - 1], d)
            else:
                # Cheapest way to reach (i, j): advance on P, on both, or on Q.
                row[j] = max(min(prev_row[j], prev_row[j - 1], row[j - 1]), d)
        prev_row = row

    return float(prev_row[-1])


def smooth_trajectory(
    traj: np.ndarray,
    ref: np.ndarray,
    risk_map: np.ndarray,
    w_track: float = 1.0,
    w_smooth: float = 0.3,
    w_risk: float = 5.0,
    iters: int = 60,
    step: float = 0.25,
    verbose: bool = False,
) -> np.ndarray:
    """
    Cheap local refinement (stand-in for IPOPT):
      - endpoint-fixed gradient-like update combining:
          tracking term   (pull toward ref[t])
          smoothness term (pull toward the midpoint of the two neighbours)
          risk term       (forward finite difference of the risk map)

    Args:
      traj: initial trajectory, one 2D point (x, y) per row. Not modified.
      ref: reference trajectory, indexed with the same t as traj.
      risk_map: risk grid indexed [y, x]; non-finite entries count as risk 1.0.
      w_track, w_smooth, w_risk: weights of the three terms.
      iters: number of sweeps over the interior points.
      step: step size of each update.
      verbose: log an approximate objective every 20 sweeps and on the last.

    Returns:
      A new float array with the refined trajectory; first and last points
      are identical to those of the input.
    """
    # Always work on a float copy: an integer-dtype input would otherwise
    # truncate every fractional update when it is written back.
    traj = np.array(traj, dtype=float)
    ref = np.asarray(ref, dtype=float)
    L = traj.shape[0]
    H, W = risk_map.shape

    def risk_at(p: np.ndarray) -> float:
        # Nearest-cell lookup, clamped to the map bounds.
        x = int(round(float(p[0])))
        y = int(round(float(p[1])))
        x = max(0, min(W - 1, x))
        y = max(0, min(H - 1, y))
        r = risk_map[y, x]
        return 1.0 if (not np.isfinite(r)) else float(r)

    # Finite-difference probe offsets (one grid cell), hoisted out of the loops.
    eps = 1.0
    probe_x = np.array([eps, 0.0])
    probe_y = np.array([0.0, eps])

    for it in range(iters):
        # Points are updated in place, so later points see already-updated neighbours.
        for t in range(1, L - 1):
            # smoothness: pull toward midpoint of neighbors
            mid = 0.5 * (traj[t - 1] + traj[t + 1])
            grad_smooth = traj[t] - mid

            # tracking: pull toward reference
            grad_track = traj[t] - ref[t]

            # risk: approximate gradient by finite differences
            r0 = risk_at(traj[t])
            rx = (risk_at(traj[t] + probe_x) - r0) / eps
            ry = (risk_at(traj[t] + probe_y) - r0) / eps
            grad_risk = np.array([rx, ry])

            grad = w_smooth * grad_smooth + w_track * grad_track + w_risk * grad_risk
            traj[t] = traj[t] - step * grad

        if verbose and (it % 20 == 0 or it == iters - 1):
            # approximate objective value for debugging
            obj = 0.0
            for k in range(L):
                obj += w_track * float(np.sum((traj[k] - ref[k]) ** 2))
                if k > 0:
                    obj += w_smooth * float(np.sum((traj[k] - traj[k - 1]) ** 2))
                obj += w_risk * risk_at(traj[k])
            log(f"[Smooth] iter={it} approx_obj={obj:.3f}")

    return traj


def multimodal_trajectory_modes_demo(
    env: GridWorldStealthEnv,
    robot_policies: List[RobotPolicy],
    sensor_policies: List[SensorPolicy],
    x_star: np.ndarray,
    top_k_pairs: int = 4,
    particles_per_pair: int = 20,
    noise_std: float = 0.4,
    cluster_threshold: float = 1.0,
    verbose: bool = True,
) -> Dict[str, Any]:
    """
    Runnable “MultiNash-PF-like” demo: sample -> refine -> cluster.

    Pipeline:
      1) take the top-k joint pairs of x*
      2) turn the robot path of each pair into a continuous reference trajectory
         (pairs whose robot policy is not a FixedPathPolicy are skipped)
      3) draw Gaussian-perturbed particles around every reference, endpoints pinned
      4) refine each particle with smooth_trajectory
      5) greedily cluster the refined trajectories, best score first, by
         discrete Fréchet distance to each cluster's representative
      6) report every cluster as a “mode”

    Returns:
      Dict with keys "risk_map", "ref_meta", "refs", "all_trajs",
      "all_scores", "clusters", "reps".

    Raises:
      RuntimeError: if none of the top pairs has a FixedPathPolicy robot policy.
    """
    n_robot = len(robot_policies)
    n_sensor = len(sensor_policies)
    joint = joint_to_matrix(x_star, n_robot, n_sensor)
    _, sensor_mix = marginals_from_joint(x_star, n_robot, n_sensor)

    # Expected risk under the sensor marginal; drives the refinement objective.
    risk_map = compute_expected_risk_map(
        env,
        sensor_policies,
        sensor_mix,
        M_modes=len(env.sensor_cfg.sensors),
        verbose=verbose,
    )

    # Highest-probability joint pairs, as (robot index, sensor index).
    ranked_flat = np.argsort(-x_star)[:top_k_pairs]
    pairs = [np.unravel_index(int(flat), (n_robot, n_sensor)) for flat in ranked_flat]

    if verbose:
        log("\n" + "=" * 100)
        log("[TrajectoryModes] MultiNash-PF-like demo: sample -> refine -> cluster")
        log("=" * 100)
        for rank, (i, j) in enumerate(pairs, start=1):
            log(f"[TrajectoryModes] Top#{rank}: (R{i}:{robot_policies[i].name}, S{j}:{sensor_policies[j].name}) prob={joint[i,j]:.4f}")
        log("=" * 100 + "\n")

    # Reference trajectories exist only for path-based robot policies.
    refs: List[np.ndarray] = []
    ref_meta: List[Tuple[int, int, str, str, float]] = []
    for (i, j) in pairs:
        policy = robot_policies[i]
        if not isinstance(policy, FixedPathPolicy):
            continue
        refs.append(grid_path_to_continuous(policy.path))
        ref_meta.append((i, j, policy.name, sensor_policies[j].name, float(joint[i, j])))

    if not refs:
        raise RuntimeError("No FixedPathPolicy among top pairs; cannot build trajectories.")

    rng = np.random.default_rng(0)

    all_trajs: List[np.ndarray] = []
    all_scores: List[float] = []
    all_origin: List[int] = []

    n_rows, n_cols = risk_map.shape

    def risk_at(p: np.ndarray) -> float:
        # Nearest-cell lookup clamped to the map; walls (non-finite) count as 1.0.
        col = max(0, min(n_cols - 1, int(round(float(p[0])))))
        row = max(0, min(n_rows - 1, int(round(float(p[1])))))
        value = risk_map[row, col]
        return float(value) if np.isfinite(value) else 1.0

    def traj_objective(traj: np.ndarray, ref: np.ndarray) -> float:
        # Same three terms (and default weights) as the refinement step.
        w_track, w_smooth, w_risk = 1.0, 0.3, 5.0
        total = 0.0
        for t, point in enumerate(traj):
            total += w_track * float(np.sum((point - ref[t]) ** 2))
            if t > 0:
                total += w_smooth * float(np.sum((point - traj[t - 1]) ** 2))
            total += w_risk * risk_at(point)
        return float(total)

    if verbose:
        log("[TrajectoryModes] Sampling and refining particles...")
        log(f"[TrajectoryModes] particles_per_pair={particles_per_pair}, noise_std={noise_std}, cluster_threshold={cluster_threshold}")
        log("")

    for idx_ref, (ref, meta) in enumerate(zip(refs, ref_meta)):
        i, j, rn, sn, prob = meta
        if verbose:
            log(f"[TrajectoryModes] Ref#{idx_ref}: len={len(ref)} from (R{i}:{rn}, S{j}:{sn}) prob={prob:.4f}")

        for _ in range(particles_per_pair):
            particle = ref + rng.normal(0.0, noise_std, size=ref.shape)
            # Endpoints stay on the reference start and goal.
            particle[0] = ref[0]
            particle[-1] = ref[-1]

            refined = smooth_trajectory(particle, ref, risk_map, iters=60, verbose=False)

            all_trajs.append(refined)
            all_scores.append(traj_objective(refined, ref))
            all_origin.append(idx_ref)

    if verbose:
        log("\n[TrajectoryModes] Refinement complete.")
        log(f"[TrajectoryModes] Total refined trajectories = {len(all_trajs)}")
        log(f"[TrajectoryModes] Score stats: min={min(all_scores):.3f}, mean={statistics.mean(all_scores):.3f}, max={max(all_scores):.3f}")
        log("")

    # Greedy clustering, lowest score first: a trajectory joins the first
    # cluster whose representative is within the threshold, else starts a new one.
    clusters: List[List[int]] = []
    reps: List[int] = []
    for raw_idx in list(np.argsort(all_scores)):
        idx = int(raw_idx)
        for c_idx, rep_idx in enumerate(reps):
            if discrete_frechet(all_trajs[idx], all_trajs[rep_idx]) <= cluster_threshold:
                clusters[c_idx].append(idx)
                break
        else:
            clusters.append([idx])
            reps.append(idx)

    if verbose:
        log("#" * 80)
        log(f"[TrajectoryModes] Clustering complete: found {len(clusters)} modes (clusters)")
        for c_idx, members in enumerate(clusters):
            best = min(members, key=lambda member: all_scores[member])
            log(f"  Mode{c_idx}: size={len(members)} best_score={all_scores[best]:.3f} origin_ref={all_origin[best]}")
        log("#" * 80 + "\n")

    return {
        "risk_map": risk_map,
        "ref_meta": ref_meta,
        "refs": refs,
        "all_trajs": all_trajs,
        "all_scores": all_scores,
        "clusters": clusters,
        "reps": reps,
    }


# =============================================================================
# Main
# =============================================================================

def main(argv=None) -> None:  # <-- accept argv for notebooks
    """
    Command-line entry point: parse options, run the pipeline, then (unless
    --no-trajectory-demo is given) run the trajectory-modes demo on its result.

    Args:
      argv: argument list to parse; None makes argparse read sys.argv[1:].
    """
    # Options in declaration order (argparse shows them in this order in --help).
    option_specs = [
        ("--outer-iters", {"type": int, "default": 3}),
        ("--rollouts-payoff", {"type": int, "default": 20}),
        ("--rollouts-br", {"type": int, "default": 30}),
        ("--risk-weight-br", {"type": float, "default": 12.0}),
        ("--m-modes", {"type": int, "default": 2}),
        ("--seed", {"type": int, "default": 0}),
        ("--debug-one-rollout-per-pair", {"action": "store_true"}),
        ("--no-trajectory-demo", {"action": "store_true"}),
        ("--traj-top-k", {"type": int, "default": 4}),
        ("--traj-particles-per-pair", {"type": int, "default": 20}),
        ("--traj-noise-std", {"type": float, "default": 0.4}),
        ("--traj-cluster-threshold", {"type": float, "default": 1.0}),
    ]
    parser = argparse.ArgumentParser()
    for flag, settings in option_specs:
        parser.add_argument(flag, **settings)

    # <-- ignore ipykernel-injected args like --f=...json
    args, _unknown = parser.parse_known_args(args=argv)

    res = run_approach2_pipeline(
        verbose=True,
        outer_iters=args.outer_iters,
        rollouts_payoff=args.rollouts_payoff,
        rollouts_br=args.rollouts_br,
        risk_weight_br=args.risk_weight_br,
        M_modes=args.m_modes,
        seed=args.seed,
        debug_one_rollout_per_pair=args.debug_one_rollout_per_pair,
    )

    if args.no_trajectory_demo:
        return

    multimodal_trajectory_modes_demo(
        res["env"],
        res["robot_policies"],
        res["sensor_policies"],
        res["x_star"],
        top_k_pairs=args.traj_top_k,
        particles_per_pair=args.traj_particles_per_pair,
        noise_std=args.traj_noise_std,
        cluster_threshold=args.traj_cluster_threshold,
        verbose=True,
    )

# Script entry point. When the ipykernel module is loaded (i.e. inside a
# notebook), executing the cell only defines things; main() is not auto-run.
if __name__ == "__main__":
    if "ipykernel" not in sys.modules:
        main()


In [11]:
#!/usr/bin/env python3
"""approach2_robust.py

Robust, low-noise logging version of Approach 2.

Key upgrades vs the earlier script:
  1) Reproducible randomness everywhere (env + random policy) using numpy Generator.
  2) Logging is stage-based with levels (INFO/DEBUG) to avoid massive logs.
  3) More checks (shapes, simplex, NaNs, unreachable paths, invalid modes).
  4) NBS supports optional entropy regularization (keeps x* mixed if desired).
  5) TrajectoryModes demo ignores zero-prob pairs by default and can weight sampling by x*.

Run:
  python approach2_robust.py --outer-iters 3 --rollouts-payoff 20 --rollouts-br 30 --log-level INFO

If you want mixed x* (useful for “mode recovery”):
  python approach2_robust.py --entropy-tau 0.02

If you want to print one detailed rollout occasionally:
  python approach2_robust.py --debug-rollout-pair 2,1

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import argparse
import sys
import heapq
import math
import statistics

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """
    Tiny print-based logger with three verbosity levels.

    QUIET prints nothing, INFO prints banners and info messages, DEBUG
    additionally prints debug messages. The level name is case-insensitive.
    """

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        """Store the upper-cased level name in `level` and its rank in `k`.

        Raises:
          ValueError: if the level name is not one of LEVELS.
        """
        level = level.upper()
        if level not in self.LEVELS:
            raise ValueError(f"Unknown log level: {level}. Use QUIET/INFO/DEBUG")
        self.level = level
        self.k = self.LEVELS[level]

    def _emit(self, min_rank: int, *lines: Any) -> None:
        """Print each line when the configured rank is at least `min_rank`."""
        if self.k < min_rank:
            return
        for line in lines:
            print(line)

    def banner(self, title: str) -> None:
        """Print `title` framed by two 100-character rules (INFO and above)."""
        rule = "=" * 100
        self._emit(1, "\n" + rule, title, rule)

    def info(self, msg: str) -> None:
        """Print `msg` at INFO level and above."""
        self._emit(1, msg)

    def debug(self, msg: str) -> None:
        """Print `msg` at DEBUG level only."""
        self._emit(2, msg)


# =============================================================================
# Types
# =============================================================================

# A robot move expressed as a grid displacement: the environment's step() adds
# the first component to x and the second to y.
Action = Tuple[int, int]  # (dx, dy)


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid layout used by the environment and A*."""

    width: int  # number of columns; in-bounds cells satisfy 0 <= x < width
    height: int  # number of rows; in-bounds cells satisfy 0 <= y < height
    start: Tuple[int, int]  # (x, y) cell the robot is placed on by reset()
    goal: Tuple[int, int]  # (x, y) cell that ends an episode when reached
    obstacles: frozenset  # set[(x,y)] of blocked cells


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor parameters: one candidate hotspot center per sensor mode."""

    sensors: Tuple[Tuple[int, int], ...]  # centers; sensors[mode] is the active hotspot for that mode
    radius: int  # hotspot extent, compared against the Manhattan distance to the center
    base_p: float  # detection probability outside the active hotspot
    hotspot_p: float  # detection probability within `radius` of the active center


class GridWorldStealthEnv:
    """
    Grid world where sensor chooses a mode selecting an active hotspot sensor.

    Each step the robot moves (or stays if the target cell is blocked), is
    detected with the true detection probability of its new cell, and receives
    a noisy binary alarm whose likelihood is given by `observation_prob`.
    An episode is done on detection, on reaching the goal, or after
    `max_steps` steps.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        """
        Args:
          grid: grid layout (size, start, goal, obstacles).
          sensor_cfg: sensor centers, hotspot radius and detection probabilities.
          fp: false-positive rate of the alarm, in [0, 1].
          fn: false-negative rate of the alarm, in [0, 1].
          seed: seed of the environment's numpy Generator.
          max_steps: episode horizon; step() reports done once t reaches it.

        Raises:
          ValueError: on out-of-range probabilities, negative radius, no
            sensors, or max_steps < 1.
        """
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        self.max_steps = int(max_steps)
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(seed)
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the random generator with a fresh one seeded by `seed`."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """
        Start a new episode at the grid start with the given hidden sensor mode.

        Raises:
          ValueError: if sensor_mode does not index a configured sensor
            (fail here rather than on the first step()).
        """
        mode = int(sensor_mode)
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = mode
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """True when cell p lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """True when cell p is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """p_true(p | mode): hotspot_p within `radius` (Manhattan) of sensors[mode], else base_p."""
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        base = self.sensor_cfg.base_p
        hot = self.sensor_cfg.hotspot_p
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)
        return hot if d <= self.sensor_cfg.radius else base

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """P(o=alarm | p_true) with FP/FN noise."""
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """
        Apply move `a` = (dx, dy) and sample detection and alarm.

        Returns a dict with keys "pos", "t", "alarm", "p_true", "detected",
        "done". Once detected, further calls change nothing and return a
        terminal observation.
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1

        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = self.true_detection_prob(self.pos, self.sensor_mode)
        self.total_true_risk += p_true

        # detection event (first random draw of the step)
        if self.rng.random() < p_true:
            self.detected = True

        # noisy alarm (second random draw); same model the belief filter uses
        p_alarm = self.observation_prob(1, p_true)
        alarm = 1 if (self.rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """
    Build a grid split by a vertical wall at x = width // 2 with openings at
    rows 2 and 6. The start is (1, height - 2) and the goal is (width - 2, 1).

    Raises:
      RuntimeError: if the start or goal cell coincides with a wall cell.
    """
    wall_x = width // 2
    gap_rows = {2, 6}
    wall_cells = {(wall_x, row) for row in range(height) if row not in gap_rows}

    start, goal = (1, height - 2), (width - 2, 1)
    if start in wall_cells or goal in wall_cells:
        raise RuntimeError("Start/goal blocked unexpectedly.")

    return GridConfig(
        width=width,
        height=height,
        start=start,
        goal=goal,
        obstacles=frozenset(wall_cells),
    )


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """
    Print the grid row by row, one character per cell.

    Symbol precedence per cell: R (start), G (goal), S (sensor center),
    # (obstacle), . (free).
    """
    walls = set(grid.obstacles)
    sensor_cells = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in sensor_cells:
            return "S"
        return "#" if cell in walls else "."

    for y in range(grid.height):
        print("".join(glyph((x, y)) for x in range(grid.width)))


# =============================================================================
# Belief over sensor modes (exact Bayes filter)
# =============================================================================

class ModeBelief:
    """Categorical belief `b` over M hidden sensor modes, updated by Bayes' rule."""

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """
        Args:
          M: number of modes (>= 1).
          init: optional nonnegative prior of length M; it is normalized, and a
            prior with zero total mass falls back to uniform. None means uniform.

        Raises:
          ValueError: if M < 1, init has the wrong length, or init has a
            negative entry.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")

        uniform = np.full(self.M, 1.0 / self.M)
        if init is None:
            self.b = uniform
            return

        prior = np.asarray(init, dtype=float).reshape(-1)
        if prior.shape != (self.M,):
            raise ValueError("init shape mismatch")
        if np.any(prior < 0):
            raise ValueError("init must be nonnegative")
        mass = float(prior.sum())
        self.b = prior / mass if mass > 0 else uniform

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """
        Bayes update of the belief given alarm observation `alarm` at cell `pos`.

        The likelihood of each mode comes from the environment's detection and
        observation models. If the normalizer is non-finite or below `eps`,
        the belief is left unchanged.
        """
        like = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, mode)) for mode in range(self.M)],
            dtype=float,
        )
        post = self.b * like
        Z = float(post.sum())
        if np.isfinite(Z) and Z >= eps:
            self.b = post / Z
        # else: robust fallback, keep the previous belief


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Base class for robot policies; subclasses must implement `act`."""

    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Hook called with the episode's start cell; the base class does nothing."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Choose the next move; not implemented in the base class."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Open-loop policy that follows a fixed sequence of grid cells."""

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """
        Args:
          path: cells to visit in order; needs at least two entries.
          name: display name of the policy.

        Raises:
          ValueError: if the path has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0  # index into path of the cell the robot is expected to be on

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the cursor at the first occurrence of `start_pos` on the path, or at 0 if absent."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """
        Return the unit move toward the next path cell, or (0, 0) when the path
        is finished or the robot's position cannot be found at or after the cursor.
        """
        cur = env.pos
        last = len(self.path) - 1
        if self._idx >= last:
            return (0, 0)
        if cur != self.path[self._idx]:
            # resync if possible (searches forward from the cursor only)
            try:
                self._idx = self.path.index(cur, self._idx)
            except ValueError:
                return (0, 0)
            if self._idx >= last:
                # Resync landed on the final waypoint: there is no next cell,
                # and indexing path[_idx + 1] would raise IndexError.
                return (0, 0)
        nxt = self.path[self._idx + 1]
        # Clip each component to a single-cell move.
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        self._idx += 1
        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """Uniform random walk over legal moves, drawing from env.rng so runs are reproducible."""

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Stateless policy: nothing to reset."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the four axis moves and 'stay' whose target cell is free."""
        col, row = env.pos
        moves = [(1, 0), (-1, 0), (0, 1), (0, -1), (0, 0)]
        legal: List[Action] = [mv for mv in moves if env.is_free((col + mv[0], row + mv[1]))]
        if len(legal) == 0:
            return (0, 0)
        pick = int(env.rng.integers(0, len(legal)))
        return legal[pick]


class SensorPolicy:
    """Interface for sensor-side policies: maps a time step to a sensor mode."""

    # Human-readable label; concrete policies replace it with their own name.
    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Per-episode hook; the base class keeps no state, so it does nothing."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the mode index for step ``t``; the base class always raises ``NotImplementedError``."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that reports the same mode at every time step."""

    def __init__(self, mode: int, name: Optional[str] = None):
        self.mode = int(mode)
        # Fall back to an auto-generated label when no (non-empty) name is supplied.
        self.name = name if name else f"S_Mode{mode}"

    def reset(self) -> None:
        """Stateless policy: nothing to reset."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the fixed mode; ``t`` is ignored."""
        return self.mode


# =============================================================================
# A* (used by robot BR)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """A* search on the 4-connected grid with a Manhattan-distance heuristic.

    Args:
        grid: Provides ``width``, ``height`` and ``obstacles`` (cells that cannot be entered).
        start: Start cell ``(x, y)``.
        goal: Goal cell ``(x, y)``.
        step_cost: ``step_cost(frm, to)`` gives the cost of moving between adjacent cells.
        max_expansions: Cap on the number of heap pops before giving up.

    Returns:
        The list of cells from ``start`` to ``goal`` inclusive (``[start]`` if they coincide).

    Raises:
        RuntimeError: If the expansion cap is exceeded or the goal cannot be reached.
    """
    if start == goal:
        return [start]

    moves = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def manhattan(cell: Tuple[int, int]) -> float:
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    def enterable(cell: Tuple[int, int]) -> bool:
        return 0 <= cell[0] < grid.width and 0 <= cell[1] < grid.height and cell not in grid.obstacles

    # Heap entries are (f = g + h, g, cell); a one-element list is already a valid heap.
    frontier: List[Tuple[float, float, Tuple[int, int]]] = [(manhattan(start), 0.0, start)]
    parent: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    best_g: Dict[Tuple[int, int], float] = {start: 0.0}

    popped = 0
    while frontier:
        _, _, node = heapq.heappop(frontier)
        popped += 1

        if node == goal:
            # Follow parent links back to the start, then flip the order.
            route: List[Tuple[int, int]] = []
            walker: Optional[Tuple[int, int]] = node
            while walker is not None:
                route.append(walker)
                walker = parent[walker]
            return route[::-1]

        if popped > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        x, y = node
        for dx, dy in moves:
            nxt = (x + dx, y + dy)
            if not enterable(nxt):
                continue
            candidate = best_g[node] + float(step_cost(node, nxt))
            # Keep only first visits or strict improvements (beyond a 1e-12 tolerance).
            if nxt in best_g and not (candidate < best_g[nxt] - 1e-12):
                continue
            best_g[nxt] = candidate
            parent[nxt] = node
            heapq.heappush(frontier, (candidate + manhattan(nxt), candidate, nxt))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan ``start -> waypoint -> goal`` as two A* legs joined at the waypoint.

    The waypoint appears exactly once in the returned path. Errors raised by
    ``astar_path`` for either leg propagate unchanged.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    # The first leg ends on the waypoint and the second starts on it: drop one copy.
    head = first_leg[:-1]
    return head + second_leg


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of a single rollout, as filled in by ``rollout_episode``."""

    steps: int  # environment time counter (env.t) when the episode ended
    reached_goal: bool  # True if the final robot position equals the grid goal
    detected: bool  # final value of the environment's detected flag
    total_true_risk: float  # sum of the per-step "p_true" values returned by env.step
    U_R: float  # robot utility: minus (steps + weighted risk + detection penalty)
    U_S: float  # sensor utility: detection penalty + weighted risk - energy cost per step


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Simulate one seeded episode of ``robot`` against ``sensor`` and score it.

    The environment is seeded and reset, a fresh ``ModeBelief`` is created, and
    then for at most ``max_steps`` steps the sensor picks the mode, the robot
    acts, and the belief is updated with the resulting alarm. With
    ``step_debug`` enabled one line per step is printed.

    Utilities:
        U_R = -(steps + lambda_risk * total_risk + det_penalty if detected)
        U_S = (det_penalty if detected) + lambda_risk * total_risk
              - sensor_energy_per_step * steps
    """
    # Seed before anything else so the whole episode is reproducible.
    env.seed(seed)
    sensor.reset()
    initial_mode = sensor.select_mode(0)
    env.reset(sensor_mode=initial_mode)

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    total_risk = 0.0
    previous_alarm: Optional[int] = None

    for _step in range(max_steps):
        # The sensor commits to a mode first, then the robot moves.
        env.sensor_mode = sensor.select_mode(env.t)
        a = robot.act(env, belief, previous_alarm)
        out = env.step(a)

        alarm = out["alarm"]
        total_risk += out["p_true"]
        belief.update(env, alarm, out["pos"])
        previous_alarm = alarm

        if step_debug:
            print(f"[Step] t={out['t']:3d} pos={out['pos']} a={a} p_true={out['p_true']:.3f} alarm={out['alarm']} det={out['detected']} done={out['done']} b={belief.b.round(3)}")

        if out["done"]:
            break

    detected = bool(env.detected)
    steps = int(env.t)
    reached_goal = (env.pos == env.grid.goal)

    detection_term = det_penalty if detected else 0.0
    risk_term = lambda_risk * total_risk

    cost_R = steps + risk_term + detection_term
    U_R = -float(cost_R)
    U_S = float(detection_term + risk_term - sensor_energy_per_step * steps)

    return EpisodeStats(
        steps=steps,
        reached_goal=reached_goal,
        detected=detected,
        total_true_risk=float(total_risk),
        U_R=U_R,
        U_S=U_S,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Estimate the empirical payoff matrices by Monte Carlo rollouts.

    Every (robot policy i, sensor policy j) pair is rolled out ``rollouts``
    times with a distinct, deterministic seed per rollout.

    Args:
        env: Environment reused (re-seeded and reset) for each rollout.
        robots: Robot policies (rows of the payoff matrices).
        sensors: Sensor policies (columns of the payoff matrices).
        M_modes: Number of sensor modes, forwarded to ``rollout_episode``.
        rollouts: Rollouts per pair; must be >= 1.
        base_seed: Offset added to every rollout seed.
        log: Logger; the summary table is emitted when ``log.k >= 1``.
        debug_rollout_pair: If equal to ``(i, j)``, the first rollout of that
            pair is run with step-by-step printing.

    Returns:
        ``(U_R, U_S, diag)`` where ``U_R``/``U_S`` are ``(m, n)`` mean utilities
        and ``diag[(i, j)]`` holds det_rate, goal_rate, mean_steps, mean_risk,
        std_UR and std_US for that pair.

    Raises:
        ValueError: If ``rollouts < 1`` (means and rates would be undefined).
    """
    if rollouts < 1:
        raise ValueError(f"rollouts must be >= 1 (got {rollouts})")

    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    # Seed layout: base + stride_i * i + stride_j * j + k. The historical strides
    # (100000 / 1000) are kept whenever they are collision-free, so existing runs
    # reproduce exactly; they are widened only when rollouts > 1000 or the sensor
    # set is large enough that two (i, j, k) triples would share a seed.
    seed_stride_j = max(1000, rollouts)
    seed_stride_i = max(100000, seed_stride_j * n)

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            step_debug = (debug_rollout_pair == (i, j))
            episodes = []

            for k in range(rollouts):
                seed = base_seed + seed_stride_i * i + seed_stride_j * j + k
                episodes.append(
                    rollout_episode(env, rpol, spol, M_modes=M_modes, seed=seed, step_debug=step_debug)
                )
                # Step-by-step printing is wanted for the first rollout of the pair only.
                step_debug = False

            r_list = [ep.U_R for ep in episodes]
            s_list = [ep.U_S for ep in episodes]
            det = sum(int(ep.detected) for ep in episodes)
            goal = sum(int(ep.reached_goal) for ep in episodes)

            U_R[i, j] = float(np.mean(r_list))
            U_S[i, j] = float(np.mean(s_list))

            diag[(i, j)] = {
                "det_rate": det / rollouts,
                "goal_rate": goal / rollouts,
                "mean_steps": float(np.mean([ep.steps for ep in episodes])),
                "mean_risk": float(np.mean([ep.total_true_risk for ep in episodes])),
                "std_UR": float(np.std(r_list)),
                "std_US": float(np.std(s_list)),
            }

    # Compact summary at INFO level and above: one line for every (robot, sensor) pair.
    if log.k >= 1:
        log.info("[Eval] Compact payoff summary (showing all pairs but one-line each):")
        for i, rpol in enumerate(robots):
            for j, spol in enumerate(sensors):
                d = diag[(i, j)]
                log.info(
                    f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                    f"UR={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                    f"US={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                    f"det%={100*d['det_rate']:5.1f} goal%={100*d['goal_rate']:5.1f} "
                    f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
                )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Project ``v`` onto the scaled simplex ``{w : w >= 0, sum(w) = z}``.

    Sort-and-threshold projection; the input is flattened to 1-D first. When
    no valid threshold exists or the shifted vector has a non-finite or
    non-positive sum, the uniform vector ``z / size`` is returned instead.

    Raises:
        ValueError: If ``v`` is empty or ``z`` is not strictly positive.
    """
    vec = np.asarray(v, dtype=float).reshape(-1)
    size = vec.size
    if size == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    descending = np.sort(vec)[::-1]
    running_sum = np.cumsum(descending)
    ranks = np.arange(1, size + 1)

    # Indices where the sorted entry still exceeds the running threshold.
    active = np.flatnonzero(descending * ranks > (running_sum - z))
    if active.size == 0:
        return np.full_like(vec, z / size)

    last = int(active[-1])
    shift = (running_sum[last] - z) / (last + 1.0)
    clipped = np.maximum(vec - shift, 0.0)

    total = float(clipped.sum())
    if (not np.isfinite(total)) or total <= 0:
        return np.full_like(vec, z / size)
    # Rescale so the result sums to z despite rounding.
    return clipped * (z / total)


@dataclass
class NBSResult:
    """Output of ``solve_nbs``."""

    x: np.ndarray  # final joint distribution over flattened joint actions
    obj: float  # objective value at x (log-gains, plus entropy term when enabled)
    gains: Tuple[float, float]  # (robot gain, sensor gain) over the disagreement point
    support: int  # number of entries of x greater than 1e-6


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",  # or "uniform"
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Projected gradient ascent on:

        f(x) = log(uR^T x - dR) + log(uS^T x - dS) + tau * H(x)

    where H(x) = -sum x_i log x_i.

    disagreement:
      - minminus: dR=min(uR)-1, dS=min(uS)-1
      - uniform : dR=uR^T unif, dS=uS^T unif  (much less degenerate)

    Args:
        uR: Robot payoffs per joint action (flattened to 1-D).
        uS: Sensor payoffs per joint action (flattened to 1-D, same size as uR).
        log: Logger; per-iteration progress is printed only when ``log.k >= 2``.
        max_iters: Maximum number of outer ascent iterations.
        alpha: Initial step size tried in every iteration's backtracking search.
        tol_l1: Stop once the L1 change in x between iterations drops below this.
        kappa: Floor applied to both gains inside the objective and gradient.
        disagreement: How the disagreement point (dR, dS) is chosen (see above).
        entropy_tau: Weight tau of the entropy term.

    Returns:
        NBSResult with the final x, its objective value, the (unclamped) gains
        and the number of entries of x above 1e-6.

    Raises:
        ValueError: On shape mismatch, fewer than 2 joint actions, or an
            unknown ``disagreement`` value.
    """

    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    unif = np.full(d, 1.0 / d)

    disagreement = disagreement.lower().strip()
    if disagreement == "minminus":
        dR = float(np.min(uR) - 1.0)
        dS = float(np.min(uS) - 1.0)
    elif disagreement == "uniform":
        # IMPORTANT: this is an *outside option* baseline, not a security level.
        dR = float(uR @ unif)
        dS = float(uS @ unif)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    # The ascent always starts from the uniform joint distribution.
    x = unif.copy()

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        # Raw (unclamped) surplus of each player over its disagreement payoff.
        return float(uR @ xv - dR), float(uS @ xv - dS)

    def entropy(xv: np.ndarray) -> float:
        # stable entropy for simplex vector
        xx = np.clip(xv, 1e-12, 1.0)
        return float(-np.sum(xx * np.log(xx)))

    def obj(xv: np.ndarray) -> float:
        gR, gS = gains(xv)
        # Clamp at kappa so the logs stay finite when a gain is zero or negative.
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        return float(np.log(gR) + np.log(gS) + entropy_tau * entropy(xv))

    def grad(xv: np.ndarray) -> np.ndarray:
        gR, gS = gains(xv)
        # Same kappa floor as in obj(), which also bounds the gradient magnitude.
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        g = (uR / gR) + (uS / gS)
        if entropy_tau > 0:
            # grad of tau * (-sum x log x) is tau * (-(log x + 1))
            xx = np.clip(xv, 1e-12, 1.0)
            g += entropy_tau * (-(np.log(xx) + 1.0))
        return g

    last = obj(x)

    # Keep logs short: only print start + end + occasional progress if DEBUG.
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        g = grad(x)
        a = alpha
        improved = False
        # Backtracking: halve the step (at most 30 tries, or until it drops below
        # 1e-6) and accept the first projected point that does not lower the
        # objective by more than 1e-12.
        for _ in range(30):
            x_new = project_simplex(x + a * g)
            new_obj = obj(x_new)
            if new_obj >= last - 1e-12:
                improved = True
                break
            a *= 0.5
            if a < 1e-6:
                break
        if not improved:
            # No acceptable step found: keep the current x and stop.
            break

        delta = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last = new_obj

        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            gR, gS = gains(x)
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        # Converged: the accepted step barely moved x.
        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flat joint distribution of length ``m * n`` into an ``(m, n)`` float matrix.

    Raises:
        ValueError: If ``x`` does not contain exactly ``m * n`` entries.
    """
    flat = np.asarray(x, dtype=float).reshape(-1)
    if flat.size == m * n:
        return flat.reshape((m, n))
    raise ValueError("x size mismatch")


def marginals_from_joint(x: np.ndarray, m: int, n: int) -> Tuple[np.ndarray, np.ndarray]:
    """Row and column marginals of a joint distribution over ``m x n`` joint actions.

    Each marginal is renormalised to sum to one when its total is positive;
    an all-zero (or negative-total) marginal is returned as is.

    Raises:
        ValueError: If ``x`` does not contain exactly ``m * n`` entries.
    """
    flat = np.asarray(x, dtype=float).reshape(-1)
    if flat.size != m * n:
        raise ValueError("x size mismatch")
    table = flat.reshape((m, n))

    row_mass = table.sum(axis=1)
    col_mass = table.sum(axis=0)

    row_total = row_mass.sum()
    if row_total > 0:
        row_mass = row_mass / row_total
    col_total = col_mass.sum()
    if col_total > 0:
        col_mass = col_mass / col_total
    return row_mass, col_mass


# =============================================================================
# Best responses
# =============================================================================

def compute_expected_risk_map(env: GridWorldStealthEnv, sensors: List[SensorPolicy], sigma_S: np.ndarray, M_modes: int) -> np.ndarray:
    """Expected detection probability per cell under the sensor mixture ``sigma_S``.

    Each sensor policy is queried once with ``select_mode(0)`` and its weight is
    added to that mode; the resulting mode distribution (renormalised when its
    total is positive) weights ``env.true_detection_prob``.

    Returns:
        Array of shape ``(height, width)`` indexed ``[y, x]``; obstacle cells are NaN.

    Raises:
        ValueError: If ``sigma_S`` and ``sensors`` differ in length, or a
            policy reports a mode outside ``[0, M_modes)``.
    """
    weights = np.asarray(sigma_S, dtype=float).reshape(-1)
    if weights.size != len(sensors):
        raise ValueError("sigma_S length mismatch")

    # Collapse the distribution over sensor policies into one over modes.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(weights, sensors):
        m = policy.select_mode(0)
        if not (0 <= m < M_modes):
            raise ValueError(f"Sensor policy returned invalid mode {m}")
        mode_probs[m] += weight

    mass = mode_probs.sum()
    if mass > 0:
        mode_probs = mode_probs / mass

    n_rows, n_cols = env.grid.height, env.grid.width
    blocked = env.grid.obstacles

    # Start from all-NaN and fill in the free cells only.
    risk = np.full((n_rows, n_cols), np.nan, dtype=float)
    for y in range(n_rows):
        for x in range(n_cols):
            cell = (x, y)
            if cell in blocked:
                continue
            expected = 0.0
            for m in range(M_modes):
                expected += mode_probs[m] * env.true_detection_prob(cell, m)
            risk[y, x] = expected

    return risk


def robot_best_response(env: GridWorldStealthEnv, sensors: List[SensorPolicy], sigma_S: np.ndarray, robots: List[RobotPolicy], M_modes: int, risk_weight: float, log: Logger) -> RobotPolicy:
    """Robot best response: risk-weighted A* path from start to goal.

    The cost of entering a cell is ``1 + risk_weight * expected_risk``, using
    the risk map induced by ``sigma_S``. If a ``FixedPathPolicy`` in ``robots``
    already follows exactly that path it is returned; otherwise a new
    ``FixedPathPolicy`` is created (and NOT added to ``robots`` here).

    If A* raises, the first ``FixedPathPolicy`` whose name contains "Shortest"
    is returned, falling back to ``robots[0]``.
    """
    risk = compute_expected_risk_map(env, sensors, sigma_S, M_modes=M_modes)

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        # The cost is charged for the cell being entered; non-finite risk is near-forbidden.
        col, row = to
        cell_risk = risk[row, col]
        if np.isfinite(cell_risk):
            return 1.0 + risk_weight * float(cell_risk)
        return 1e9

    try:
        path = astar_path(env.grid, env.grid.start, env.grid.goal, step_cost)
    except Exception as e:
        log.info(f"[RobotBR] WARNING: A* failed: {e}. Returning existing shortest if present.")
        shortest = [p for p in robots if isinstance(p, FixedPathPolicy) and "Shortest" in p.name]
        return shortest[0] if shortest else robots[0]

    wanted = tuple(path)
    matches = (p for p in robots if isinstance(p, FixedPathPolicy) and tuple(p.path) == wanted)
    p = next(matches, None)
    if p is not None:
        log.info(f"[RobotBR] BR path already exists: {p.name}")
        return p

    newp = FixedPathPolicy(path, name=f"R_BR_RiskAStar_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] Added new robot policy: {newp.name}")
    return newp


def sensor_best_response(env: GridWorldStealthEnv, robots: List[RobotPolicy], sigma_R: np.ndarray, candidate_modes: List[int], M_modes: int, rollouts: int, base_seed: int, log: Logger) -> FixedModeSensorPolicy:
    """Sensor best response: the fixed mode with the highest Monte Carlo mean ``U_S``.

    For every candidate mode inside ``[0, M_modes)``, ``rollouts`` episodes are
    simulated; in each one the robot policy is sampled from ``sigma_R``.

    Raises:
        ValueError: If ``sigma_R`` and ``robots`` differ in length.
        RuntimeError: If no candidate mode produced a usable estimate.
    """
    sigma_R = np.asarray(sigma_R, dtype=float).reshape(-1)
    n_robots = len(robots)
    if n_robots != sigma_R.size:
        raise ValueError("sigma_R length mismatch")

    # One generator, shared across modes, drives the robot-policy sampling.
    rng = np.random.default_rng(base_seed)

    best_mode: Optional[int] = None
    best_val = -1e18

    valid_modes = [mode for mode in candidate_modes if 0 <= mode < M_modes]
    for mode in valid_modes:
        sp = FixedModeSensorPolicy(mode, name=f"S_BR_Mode{mode}")
        vals: List[float] = []
        for k in range(rollouts):
            rp = robots[int(rng.choice(n_robots, p=sigma_R))]
            st = rollout_episode(env, rp, sp, M_modes=M_modes, seed=base_seed + 10000 * mode + k)
            vals.append(st.U_S)
        mean_u = float(np.mean(vals))
        if log.k >= 2:
            log.debug(f"[SensorBR] mode={mode} E[US]={mean_u:.3f} std={float(np.std(vals)):.2f}")
        # Strict comparison: on ties the earlier candidate is kept.
        if mean_u > best_val:
            best_mode, best_val = mode, mean_u

    if best_mode is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] Best mode={best_mode} E[US]={best_val:.3f}")
    return FixedModeSensorPolicy(best_mode, name=f"S_BR_Mode{best_mode}")


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first ``FixedModeSensorPolicy`` in ``pols`` with the given mode, or ``None``."""
    matching = (
        candidate
        for candidate in pols
        if isinstance(candidate, FixedModeSensorPolicy) and candidate.mode == mode
    )
    return next(matching, None)


def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the starting policy sets for both players.

    Robots: the unit-cost shortest path, one path through the waypoint at row 2
    and one through the waypoint at row 6 of the middle column
    (``grid.width // 2``), plus a random walker. Sensors: one fixed-mode policy
    per mode index.
    """
    grid = env.grid

    def unit_cost(a: Tuple[int, int], b: Tuple[int, int]) -> float:
        return 1.0

    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)

    mid_x = grid.width // 2
    via_upper = astar_via_waypoint(grid, grid.start, (mid_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (mid_x, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = []
    for cells, label in (
        (shortest, "R_Shortest"),
        (via_upper, "R_UpperCorridor"),
        (via_lower, "R_LowerCorridor"),
    ):
        robots.append(FixedPathPolicy(cells, label))
    robots.append(RandomPolicy("R_Random"))

    sensors: List[SensorPolicy] = []
    for m in range(M_modes):
        sensors.append(FixedModeSensorPolicy(m, name=f"S_Mode{m}"))

    return robots, sensors


# =============================================================================
# Trajectory modes demo (reduced confusion)
# =============================================================================

def grid_path_to_continuous(path: List[Tuple[int, int]]) -> np.ndarray:
    """Convert a list of ``(x, y)`` grid cells into a float array with one ``[x, y]`` row per cell."""
    rows = []
    for cell_x, cell_y in path:
        rows.append([float(cell_x), float(cell_y)])
    return np.array(rows, dtype=float)


def discrete_frechet(P: np.ndarray, Q: np.ndarray) -> float:
    """Discrete Fréchet distance between two polylines (Euclidean point distance).

    Uses an iterative dynamic programme over the coupling table, so long curves
    do not hit Python's recursion limit (the recursion depth of a memoised
    recursive formulation grows with ``len(P) + len(Q)``).

    Args:
        P: Points of the first curve, one point per leading index.
        Q: Points of the second curve, same point dimensionality as ``P``.

    Returns:
        The discrete Fréchet distance as a Python float.

    Raises:
        ValueError: If either curve has no points.
    """
    P = np.asarray(P, dtype=float)
    Q = np.asarray(Q, dtype=float)
    if P.ndim == 0 or Q.ndim == 0 or P.shape[0] == 0 or Q.shape[0] == 0:
        raise ValueError("discrete_frechet requires two non-empty curves")
    m, n = P.shape[0], Q.shape[0]

    # One row per point (1-D inputs become single-coordinate points).
    P2 = P.reshape(m, -1)
    Q2 = Q.reshape(n, -1)

    # All pairwise point distances in one vectorised pass: D[i, j] = |P[i] - Q[j]|.
    D = np.linalg.norm(P2[:, None, :] - Q2[None, :, :], axis=2)

    # First row of the coupling table: the walk can only advance along Q.
    prev = np.maximum.accumulate(D[0]).tolist()

    for i in range(1, m):
        dist_row = D[i].tolist()
        # First column: the walk can only advance along P.
        cur = [max(prev[0], dist_row[0])]
        for j in range(1, n):
            # Best of the three predecessors (advance P, advance both, advance Q).
            reach = min(prev[j], prev[j - 1], cur[j - 1])
            cur.append(max(reach, dist_row[j]))
        prev = cur

    return float(prev[-1])


def smooth_trajectory(traj: np.ndarray, ref: np.ndarray, risk_map: np.ndarray, iters: int = 60, step: float = 0.25) -> np.ndarray:
    """Refine a trajectory with in-place-sweep gradient steps on tracking, smoothness and risk.

    Works on a copy of ``traj``. Interior points are updated one after another
    in each sweep (later points see earlier updates); afterwards both endpoints
    are set to the endpoints of ``ref``.

    Args:
        traj: Initial trajectory, one ``[x, y]`` point per row.
        ref: Reference trajectory with at least as many rows as ``traj``.
        risk_map: Risk grid indexed ``[y, x]``; non-finite entries count as risk 1.0.
        iters: Number of sweeps.
        step: Gradient step size.

    Returns:
        The refined copy of ``traj``.
    """
    path = traj.copy()
    n_points = path.shape[0]
    n_rows, n_cols = risk_map.shape

    def risk_at(p: np.ndarray) -> float:
        # Nearest grid cell, clamped to the map.
        col = max(0, min(n_cols - 1, int(round(float(p[0])))))
        row = max(0, min(n_rows - 1, int(round(float(p[1])))))
        value = risk_map[row, col]
        return float(value) if np.isfinite(value) else 1.0

    w_track, w_smooth, w_risk = 1.0, 0.3, 5.0

    # Forward finite differences over one grid cell.
    eps = 1.0
    shift_x = np.array([eps, 0.0])
    shift_y = np.array([0.0, eps])

    for _sweep in range(iters):
        for t in range(1, n_points - 1):
            neighbour_mid = 0.5 * (path[t - 1] + path[t + 1])
            grad_smooth = path[t] - neighbour_mid
            grad_track = path[t] - ref[t]

            base_risk = risk_at(path[t])
            d_risk_dx = (risk_at(path[t] + shift_x) - base_risk) / eps
            d_risk_dy = (risk_at(path[t] + shift_y) - base_risk) / eps
            grad_risk = np.array([d_risk_dx, d_risk_dy])

            grad = w_smooth * grad_smooth + w_track * grad_track + w_risk * grad_risk
            path[t] = path[t] - step * grad

    # Endpoints are pinned to the reference.
    path[0] = ref[0]
    path[-1] = ref[-1]
    return path


def trajectory_modes_demo(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    x_star: np.ndarray,
    log: Logger,
    top_k_pairs: int = 4,
    particles_per_pair: int = 20,
    noise_std: float = 0.4,
    cluster_threshold: float = 1.0,
    min_prob: float = 1e-6,
) -> None:
    """Sample, refine and cluster trajectories around the most probable joint pairs.

    Steps:
      1. Pick up to ``top_k_pairs`` joint actions of ``x_star`` with probability
         at least ``min_prob``.
      2. Use the path of every selected robot policy that is a
         ``FixedPathPolicy`` as a reference trajectory.
      3. Around each reference, draw ``particles_per_pair`` Gaussian-perturbed
         copies (endpoints pinned), refine each with ``smooth_trajectory`` and
         score it.
      4. Greedily cluster the refined trajectories, best score first, by
         discrete Fréchet distance to each cluster's representative.
      5. Log the size and best score of at most the first 10 clusters.

    Nothing is returned; results are only logged.

    NOTE(review): ``x_star`` must have exactly ``len(robots) * len(sensors)``
    entries, otherwise ``joint_to_matrix`` raises ValueError.
    """
    m, n = len(robots), len(sensors)
    X = joint_to_matrix(x_star, m, n)
    # NOTE: sigma_R is computed here but not used below.
    sigma_R, sigma_S = marginals_from_joint(x_star, m, n)

    # Number of modes is taken from the environment's sensor list.
    risk_map = compute_expected_risk_map(env, sensors, sigma_S, M_modes=len(env.sensor_cfg.sensors))

    # ONLY take positive-prob pairs (avoid the "prob=0 refs" confusion)
    flat = np.argsort(-x_star)
    pairs: List[Tuple[int, int]] = []
    for idx in flat:
        # Indices are sorted by decreasing probability, so the first one below
        # the threshold ends the scan.
        if x_star[idx] < min_prob:
            break
        pairs.append(tuple(np.unravel_index(int(idx), (m, n))))
        if len(pairs) >= top_k_pairs:
            break

    if not pairs:
        log.info("[TrajModes] No pairs above min_prob; skipping trajectory demo.")
        return

    log.banner("[TrajModes] Trajectory modes demo (positive-prob pairs only)")
    for rank, (i, j) in enumerate(pairs, start=1):
        log.info(f"  Top#{rank}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")

    # Build ref trajectories
    refs: List[np.ndarray] = []
    # NOTE: ref_w collects the pair probabilities but is not used below.
    ref_w: List[float] = []
    for (i, j) in pairs:
        rp = robots[i]
        # Only policies with an explicit path can serve as a reference.
        if not isinstance(rp, FixedPathPolicy):
            continue
        refs.append(grid_path_to_continuous(rp.path))
        ref_w.append(float(X[i, j]))

    if not refs:
        log.info("[TrajModes] No FixedPathPolicy among selected pairs; skipping.")
        return

    # Fixed seed: the particle sampling is deterministic across runs.
    rng = np.random.default_rng(0)
    all_trajs: List[np.ndarray] = []
    all_scores: List[float] = []

    H, W = risk_map.shape

    def risk_at(p: np.ndarray) -> float:
        # Risk of the nearest grid cell (clamped to the map); non-finite entries count as 1.0.
        x = int(round(float(p[0])))
        y = int(round(float(p[1])))
        x = max(0, min(W - 1, x))
        y = max(0, min(H - 1, y))
        r = risk_map[y, x]
        return 1.0 if (not np.isfinite(r)) else float(r)

    def objective(traj: np.ndarray, ref: np.ndarray) -> float:
        # Score (lower is better): squared tracking error + squared step lengths + risk.
        w_track, w_smooth, w_risk = 1.0, 0.3, 5.0
        obj = 0.0
        for t in range(traj.shape[0]):
            obj += w_track * float(np.sum((traj[t] - ref[t]) ** 2))
            if t > 0:
                obj += w_smooth * float(np.sum((traj[t] - traj[t - 1]) ** 2))
            obj += w_risk * risk_at(traj[t])
        return float(obj)

    log.info(f"[TrajModes] Sampling {particles_per_pair} particles per ref, noise_std={noise_std}")

    for ref in refs:
        for _ in range(particles_per_pair):
            traj0 = ref + rng.normal(0.0, noise_std, size=ref.shape)
            # Keep start and goal fixed before refinement.
            traj0[0] = ref[0]
            traj0[-1] = ref[-1]
            traj1 = smooth_trajectory(traj0, ref, risk_map)
            all_trajs.append(traj1)
            all_scores.append(objective(traj1, ref))

    # Greedy clustering by Fréchet (best-first)
    order = list(np.argsort(all_scores))
    clusters: List[List[int]] = []
    reps: List[int] = []

    def assign(idx: int) -> None:
        # Join the first cluster whose representative is within the threshold,
        # otherwise open a new cluster with this trajectory as representative.
        for c_idx, rep in enumerate(reps):
            if discrete_frechet(all_trajs[idx], all_trajs[rep]) <= cluster_threshold:
                clusters[c_idx].append(idx)
                return
        clusters.append([idx])
        reps.append(idx)

    for idx in order:
        assign(int(idx))

    log.info(f"[TrajModes] Done. clusters={len(clusters)} (threshold={cluster_threshold})")
    # Only the first 10 clusters are reported.
    for c_idx, members in enumerate(clusters[:10]):
        best = min(members, key=lambda ii: all_scores[ii])
        log.info(f"  Mode{c_idx}: size={len(members)} best_score={all_scores[best]:.3f}")


# =============================================================================
# Pipeline
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Run the full outer loop: payoff estimation -> NBS -> best responses -> set expansion.

    Per outer iteration:
      1. Estimate the empirical payoff matrices for the current policy sets.
      2. Solve the Nash bargaining problem for the joint distribution x*.
      3. Compute the robot and the sensor best response against the marginals
         of x*, both over the policy sets x* was solved for.
      4. Add the best responses to the policy sets if they are new.

    After the loop, the optional trajectory demo is run on the last x*
    together with the policy sets that x* is defined over.

    Args:
        args: Parsed CLI namespace as produced by ``parse_args``.
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    sensor_cfg = SensorConfig(
        sensors=((grid.width // 2, 2), (grid.width // 2, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    log.banner("[PIPELINE] Approach 2 (Robust): Empirical POMDP Policy Game + NBS + BR")
    log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
    log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
    if log.k >= 2:
        log.debug("ASCII map:")
        print_grid_ascii(grid, sensor_cfg)

    # The number of modes always follows the sensor list; the CLI value is only checked.
    M_modes = args.m_modes
    if M_modes != len(sensor_cfg.sensors):
        log.info(f"[WARN] m_modes={M_modes} but sensors listed={len(sensor_cfg.sensors)}. Using sensors count.")
        M_modes = len(sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    log.info("Initial robots: " + ", ".join([p.name for p in robots]))
    log.info("Initial sensors: " + ", ".join([p.name for p in sensors]))

    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        try:
            if len(parts) != 2:
                raise ValueError("expected exactly two comma-separated integers")
            debug_pair = (int(parts[0]), int(parts[1]))
        except ValueError:
            # Malformed debug input should not abort (or silently vanish from) a long run.
            log.info(f"[WARN] Ignoring malformed --debug-rollout-pair={args.debug_rollout_pair!r}; expected 'i,j'.")
        else:
            log.info(f"[DEBUG] Will print one step-by-step rollout for pair {debug_pair} (only once).")

    x_star = None
    # Policy sets that x_star is defined over (the live sets may grow afterwards).
    x_robots: List[RobotPolicy] = []
    x_sensors: List[SensorPolicy] = []

    for it in range(1, args.outer_iters + 1):
        log.banner(f"[PIPELINE] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        debug_pair = None  # only once

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            max_iters=400,
            alpha=0.5,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )
        x_star = nbs.x
        x_robots, x_sensors = list(robots), list(sensors)

        m, n = len(robots), len(sensors)
        X = joint_to_matrix(x_star, m, n)
        sigma_R, sigma_S = marginals_from_joint(x_star, m, n)

        # Compact report
        top = np.argsort(-x_star)[:min(5, x_star.size)]
        log.info("[NBS] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[NBS] sigma_R={sigma_R.round(3)}")
        log.info(f"[NBS] sigma_S={sigma_S.round(3)}")

        # Best responses. Both are computed BEFORE either policy set is expanded:
        # sigma_R / sigma_S are indexed by the current sets, and appending the
        # robot BR first would make sigma_R shorter than `robots`, so the sensor
        # BR would fail with "sigma_R length mismatch".
        br_r = robot_best_response(env, sensors, sigma_S, robots, M_modes=M_modes, risk_weight=args.risk_weight_br, log=log)
        br_s = sensor_best_response(
            env,
            robots,
            sigma_R,
            candidate_modes=list(range(M_modes)),
            M_modes=M_modes,
            rollouts=args.rollouts_br,
            base_seed=2000 + 100 * it,
            log=log,
        )

        # Expand the policy sets with whatever is new.
        if br_r not in robots:
            robots.append(br_r)
        if find_sensor_by_mode(sensors, br_s.mode) is None:
            sensors.append(br_s)
        else:
            log.info(f"[SensorBR] Mode {br_s.mode} already present; not adding duplicate.")

        log.info(f"[Sets] |Pi_R|={len(robots)} |Pi_S|={len(sensors)}")

    log.banner("[PIPELINE] Finished")
    log.info("Final robots: " + ", ".join([p.name for p in robots]))
    log.info("Final sensors: " + ", ".join([p.name for p in sensors]))

    if args.run_traj_demo and x_star is not None:
        # Pass the snapshot x_star was solved over: the live sets may have grown in
        # the last iteration, and then x_star would no longer reshape to (m, n).
        trajectory_modes_demo(
            env,
            x_robots,
            x_sensors,
            x_star,
            log=log,
            top_k_pairs=args.traj_top_k,
            particles_per_pair=args.traj_particles,
            noise_std=args.traj_noise_std,
            cluster_threshold=args.traj_cluster_threshold,
            min_prob=args.traj_min_prob,
        )


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Build the CLI parser and parse ``argv``.

    ``parse_known_args`` is used so that arguments injected by Jupyter/ipykernel
    (such as ``--f=...``) are returned as unknown instead of aborting.

    Returns:
        ``(namespace, unknown_args)``.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument options), registered in exactly this order.
    option_table = [
        ("--seed", dict(type=int, default=0)),
        ("--log-level", dict(type=str, default="INFO", choices=["QUIET", "INFO", "DEBUG"])),
        ("--outer-iters", dict(type=int, default=3)),
        ("--m-modes", dict(type=int, default=2)),
        ("--rollouts-payoff", dict(type=int, default=20)),
        ("--rollouts-br", dict(type=int, default=30)),
        ("--risk-weight-br", dict(type=float, default=12.0)),
        # NBS knobs
        ("--disagreement", dict(type=str, default="minminus", choices=["minminus", "uniform"])),
        ("--entropy-tau", dict(type=float, default=0.0)),
        # Debug: print ONE step-by-step rollout for (Ri,Sj)
        ("--debug-rollout-pair", dict(type=str, default="")),
        # Trajectory modes demo
        ("--run-traj-demo", dict(action="store_true")),
        ("--traj-top-k", dict(type=int, default=4)),
        ("--traj-particles", dict(type=int, default=20)),
        ("--traj-noise-std", dict(type=float, default=0.4)),
        ("--traj-cluster-threshold", dict(type=float, default=1.0)),
        ("--traj-min-prob", dict(type=float, default=1e-6)),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    namespace, leftovers = parser.parse_known_args(args=argv)
    return namespace, leftovers


def main(argv: Optional[List[str]] = None) -> None:
    """Parse the command line and run the pipeline.

    From a notebook, call ``main(argv=[])`` so that flags belonging to
    ipykernel are not picked up.
    """
    args, unknown = parse_args(argv=argv)

    # ipykernel injects flags such as "--f=...json"; those are ignored on purpose,
    # so the warning is only printed when running outside a notebook kernel.
    inside_notebook = "ipykernel" in sys.modules
    if unknown and not inside_notebook:
        print(f"[WARN] Ignoring unknown CLI args: {unknown}")

    # A whitespace-only debug pair counts as "not given".
    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""

    run_pipeline(args)


# Script entry point. The pipeline starts automatically only when this file is
# executed as a script AND no ipykernel module is loaded; in a notebook the
# cell merely defines the functions and main(argv=[...]) is called explicitly.
if __name__ == "__main__" and ("ipykernel" not in sys.modules):
    main()


In [15]:
#!/usr/bin/env python3
"""approach2_robust_correlated.py

Robust, reduced-log implementation of "Approach 2" on a stealth gridworld POMDP,
with a FIX for the conceptual mismatch:

  Old mismatch: solve a joint distribution x*(i,j) over (robot policy i, sensor policy j)
  but then compute best responses against the MARGINALS sigma_R, sigma_S.

  Fix here: treat x* as a CORRELATION DEVICE (mediator) and compute BRs against
  CONDITIONAL distributions:

      q_S(.|i) = X[i,:] / sigma_R[i]   (sensor conditional given robot recommendation i)
      q_R(.|j) = X[:,j] / sigma_S[j]   (robot conditional given sensor recommendation j)

  Then add the most profitable *deviation* policy (oracle) based on the most violated
  conditional recommendation.

This makes the PSRO-style expansion step consistent with a correlated-strategy viewpoint.

Also included:
  - Benchmark harness vs simple heuristics on the same game.
  - Saved plots (training curves + bar charts) for presentations.

Run (script):
  python approach2_robust_correlated.py --solver correlated --outer-iters 3 --save-plots

Run (notebook):
  main(argv=["--solver","both","--outer-iters","3","--save-plots"])  # ignore ipykernel args

Tips to encourage MULTIMODAL x* (useful for your "mode recovery"):
  --disagreement uniform --entropy-tau 0.02

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import argparse
import sys
import heapq
import os
import time

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """Tiny print-based logger with three verbosity tiers.

    QUIET (0) prints nothing, INFO (1) prints banners and info messages, and
    DEBUG (2) additionally prints debug messages. The numeric tier is exposed
    as ``self.k`` so callers can guard expensive message construction.
    """

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        """Store the (case-insensitive) level name and its numeric tier.

        Raises:
            ValueError: if ``level`` is not one of QUIET/INFO/DEBUG.
        """
        normalized = level.upper()
        tier = self.LEVELS.get(normalized)
        if tier is None:
            raise ValueError(f"Unknown log level: {normalized}. Use QUIET/INFO/DEBUG")
        self.level = normalized
        self.k = tier

    def banner(self, title: str) -> None:
        """Print ``title`` framed by two 100-character '=' rules (INFO and above)."""
        if self.k < 1:
            return
        rule = "=" * 100
        for line in (rule, title, rule):
            print(line)

    def info(self, msg: str) -> None:
        """Print ``msg`` when the level is INFO or DEBUG."""
        if self.k < 1:
            return
        print(msg)

    def debug(self, msg: str) -> None:
        """Print ``msg`` only when the level is DEBUG."""
        if self.k < 2:
            return
        print(msg)


# =============================================================================
# Types
# =============================================================================

# A robot move expressed as a grid displacement that the environment adds to
# the current position; the policies in this file return unit steps or (0, 0).
Action = Tuple[int, int]  # (dx, dy)


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid map shared by the environment and the planners."""

    width: int  # number of columns; in-bounds cells satisfy 0 <= x < width
    height: int  # number of rows; in-bounds cells satisfy 0 <= y < height
    start: Tuple[int, int]  # (x, y) cell the robot is placed on by reset()
    goal: Tuple[int, int]  # (x, y) cell that ends an episode when reached
    obstacles: frozenset  # blocked (x, y) cells; moves into them are rejected


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor layout and the two detection-probability levels."""

    sensors: Tuple[Tuple[int, int], ...]  # (x, y) sensor centers; mode m selects sensors[m]
    radius: int  # Manhattan radius of the hotspot around the active sensor
    base_p: float  # per-step detection probability outside the hotspot
    hotspot_p: float  # per-step detection probability inside the hotspot


class GridWorldStealthEnv:
    """Grid world with detection risk controlled by a hidden/selected sensor 'mode'.

    The active mode picks one sensor center from ``sensor_cfg.sensors``. Cells
    within Manhattan distance ``radius`` of that center have detection
    probability ``hotspot_p``; every other cell has ``base_p``. Each step
    samples a true detection event and, independently, a noisy alarm with
    false-positive rate ``fp`` and false-negative rate ``fn``.

    Args:
        grid: Map description (size, start, goal, obstacles).
        sensor_cfg: Sensor centers, hotspot radius and detection probabilities.
        fp: Probability of an alarm when no detection signal is present.
        fn: Probability of a missed alarm when a detection signal is present.
        seed: Seed for the environment's numpy Generator.
        max_steps: Episode horizon; ``step`` reports ``done`` once ``t`` reaches
            it. Defaults to 200, the value that used to be hard-coded.

    Raises:
        ValueError: if a probability is outside [0, 1], the radius is negative,
            no sensor is configured, or ``max_steps`` is below 1.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        self.max_steps = int(max_steps)

        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(int(seed))
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the random generator with a fresh one seeded by ``seed``."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """Start a new episode at ``grid.start`` with the given sensor mode.

        The random generator is NOT reseeded here; call ``seed`` for that.

        Returns:
            Dict with the initial ``pos`` and ``t`` (0).
        """
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = int(sensor_mode)
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """Return True when cell ``p`` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """Return True when ``p`` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """Detection probability at cell ``p`` if sensor ``mode`` is active.

        Raises:
            ValueError: if ``mode`` is not a valid index into the sensor list.
        """
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """Likelihood of observing ``alarm`` (0 or 1) given detection probability ``p_true``.

        Raises:
            ValueError: if ``alarm`` is neither 0 nor 1.
        """
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """Apply action ``a`` and sample detection and alarm for the new cell.

        A move into an obstacle or out of bounds leaves the robot in place but
        still consumes a time step. Once the robot has been detected, further
        calls return a terminal record without advancing time.

        Returns:
            Dict with ``pos``, ``t``, ``alarm``, ``p_true``, ``detected``, ``done``.
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1
        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = float(self.true_detection_prob(self.pos, self.sensor_mode))
        self.total_true_risk += p_true

        # First random draw: the true detection event.
        if self.rng.random() < p_true:
            self.detected = True

        # Second random draw: the noisy alarm. Reusing observation_prob keeps the
        # simulator and the likelihood used for belief updates on one formula.
        alarm = 1 if (self.rng.random() < self.observation_prob(1, p_true)) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """Build a grid split by a vertical wall with openings at rows 2 and 6.

    The wall sits in column ``width // 2``. The start is placed at
    ``(1, height - 2)`` and the goal at ``(width - 2, 1)``.

    Raises:
        RuntimeError: if the start or goal cell coincides with a wall cell.
    """
    wall_column = width // 2
    open_rows = (2, 6)
    wall_cells = {(wall_column, row) for row in range(height) if row not in open_rows}

    start_cell = (1, height - 2)
    goal_cell = (width - 2, 1)
    if not wall_cells.isdisjoint((start_cell, goal_cell)):
        raise RuntimeError("Start/goal blocked unexpectedly")

    return GridConfig(
        width=width,
        height=height,
        start=start_cell,
        goal=goal_cell,
        obstacles=frozenset(wall_cells),
    )


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """Print the map row by row using one character per cell.

    Legend (first match wins): ``R`` start, ``G`` goal, ``S`` sensor center,
    ``#`` obstacle, ``.`` free cell.
    """
    blocked = set(grid.obstacles)
    sensor_cells = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        # Priority order matters: start/goal override sensors, sensors override walls.
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in sensor_cells:
            return "S"
        if cell in blocked:
            return "#"
        return "."

    for y in range(grid.height):
        print("".join(glyph((x, y)) for x in range(grid.width)))


# =============================================================================
# Belief over modes
# =============================================================================

class ModeBelief:
    """Exact categorical belief over the discrete sensor modes 0..M-1.

    The probability vector is stored in ``self.b``.
    """

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """Create a belief over ``M`` modes.

        Args:
            M: Number of modes (must be >= 1).
            init: Optional non-negative weights of length ``M``; they are
                normalised, and an all-zero vector falls back to uniform.

        Raises:
            ValueError: if ``M`` < 1, or ``init`` has the wrong length or a
                negative entry.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")

        uniform = np.full(self.M, 1.0 / self.M)
        if init is None:
            self.b = uniform
            return

        prior = np.asarray(init, dtype=float).reshape(-1)
        if prior.shape != (self.M,):
            raise ValueError("init shape mismatch")
        if np.any(prior < 0):
            raise ValueError("init must be nonnegative")
        total = float(prior.sum())
        self.b = prior / total if total > 0 else uniform

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """Bayes update of the belief from one alarm reading taken at ``pos``.

        The likelihood of ``alarm`` under each mode comes from the environment's
        own detection and observation models. If the evidence (normaliser) is
        non-finite or below ``eps``, the belief is left untouched.
        """
        likelihood = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, mode)) for mode in range(self.M)],
            dtype=float,
        )
        unnormalised = self.b * likelihood
        evidence = float(unnormalised.sum())
        if np.isfinite(evidence) and evidence >= eps:
            self.b = unnormalised / evidence


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Interface for robot controllers: an optional per-episode reset plus ``act``."""

    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Per-episode hook; the base implementation does nothing."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Choose the next move; concrete policies must override this."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Open-loop policy that walks a precomputed list of grid cells.

    The policy keeps an index into ``path`` and emits the (clipped) difference
    between the next waypoint and the robot's current cell. It returns the
    no-op ``(0, 0)`` when the path is exhausted or the robot is on a cell that
    cannot be found further along the path.
    """

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """Store a copy of ``path``.

        Raises:
            ValueError: if the path has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the index at ``start_pos`` if it is on the path, else at the first cell."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the move towards the next waypoint, or (0, 0) when none is available."""
        cur = env.pos
        last_index = len(self.path) - 1
        if self._idx >= last_index:
            return (0, 0)
        if cur != self.path[self._idx]:
            # The robot is not where the index says; look for it further along the path.
            try:
                self._idx = self.path.index(cur, self._idx)
            except ValueError:
                return (0, 0)
            # Re-synchronising can land on the final waypoint, which has no
            # successor; without this guard the lookup below raised IndexError.
            if self._idx >= last_index:
                return (0, 0)
        nxt = self.path[self._idx + 1]
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        # Advance optimistically; the next call re-synchronises if the move did not happen.
        self._idx += 1
        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """Uniform random walk over feasible moves, driven by ``env.rng`` so runs are repeatable."""

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """No per-episode state to clear."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the four unit moves and 'stay' that lead to a free cell."""
        x, y = env.pos
        moves: List[Action] = [(1, 0), (-1, 0), (0, 1), (0, -1), (0, 0)]
        feasible = [move for move in moves if env.is_free((x + move[0], y + move[1]))]
        if len(feasible) == 0:
            return (0, 0)
        choice = int(env.rng.integers(0, len(feasible)))
        return feasible[choice]


class OnlineBeliefReplanPolicy(RobotPolicy):
    """POMDP-ish heuristic: replan each step using risk map induced by current belief b_t.

    Every call runs A* from the robot's current cell to the goal, where entering
    a cell costs ``1 + risk_weight * E_b[p_true(cell)]``, and executes only the
    first move of that plan.
    """

    def __init__(self, env: GridWorldStealthEnv, risk_weight: float = 12.0, name: str = "R_OnlineBeliefReplan"):
        self.env = env
        self.risk_weight = float(risk_weight)
        self.name = name
        # Not read anywhere in this class; kept so existing attribute access keeps working.
        self._cached_next: Optional[Tuple[int, int]] = None

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Clear per-episode state."""
        self._cached_next = None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the first move of a freshly planned risk-weighted path.

        Falls back to the no-op ``(0, 0)`` when the planner reports failure
        (``RuntimeError``) or the plan has no second cell. Other exceptions are
        deliberately NOT swallowed, so programming errors (for example a belief
        with more modes than the environment has sensors) surface instead of
        silently freezing the robot.
        """
        # Snapshot the belief once; the planner calls step_cost many times.
        mode_probs = [float(pm) for pm in belief.b]

        def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
            # Expected detection probability of the destination cell under the belief.
            expected_risk = 0.0
            for m, pm in enumerate(mode_probs):
                expected_risk += pm * float(env.true_detection_prob(to, m))
            return 1.0 + self.risk_weight * expected_risk

        # Plan from current pos to goal (one-step receding horizon).
        try:
            path = astar_path(env.grid, env.pos, env.grid.goal, step_cost)
        except RuntimeError:
            # astar_path signals "unreachable" / "budget exceeded" with RuntimeError.
            return (0, 0)

        if len(path) < 2:
            return (0, 0)
        nxt = path[1]
        dx = int(np.clip(nxt[0] - env.pos[0], -1, 1))
        dy = int(np.clip(nxt[1] - env.pos[1], -1, 1))
        return (dx, dy)


class SensorPolicy:
    """Interface for sensor controllers: an optional reset plus a per-step mode choice."""

    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Per-episode hook; the base implementation does nothing."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the sensor mode to use at time ``t``; concrete policies must override this."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that keeps the same mode for the whole episode."""

    def __init__(self, mode: int, name: Optional[str] = None):
        """Store the mode; an empty or missing ``name`` defaults to ``S_Mode<mode>``."""
        self.mode = int(mode)
        self.name = name if name else f"S_Mode{mode}"

    def reset(self) -> None:
        """Stateless: nothing to clear."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the fixed mode regardless of ``t``."""
        return self.mode


class AlternatingSensorPolicy(SensorPolicy):
    """Benchmark sensor heuristic that cycles through the modes: 0, 1, ..., M-1, 0, ..."""

    def __init__(self, M: int, name: str = "S_Alternate"):
        self.M = int(M)
        self.name = name

    def reset(self) -> None:
        """Stateless: the mode depends only on the time index."""
        return None

    def select_mode(self, t: int) -> int:
        """Return ``t`` modulo the number of modes."""
        mode = t % self.M
        return int(mode)


class RandomModeSensorPolicy(SensorPolicy):
    """Benchmark sensor that draws a uniformly random mode on every call.

    The policy owns a seeded numpy Generator. ``reset`` does not reseed it, so
    the random stream continues from one episode to the next.
    """

    def __init__(self, M: int, seed: int = 0, name: str = "S_RandomMode"):
        self.M = int(M)
        self.rng = np.random.default_rng(int(seed))
        self.name = name

    def reset(self) -> None:
        """No-op: the generator is intentionally left where it is."""
        return None

    def select_mode(self, t: int) -> int:
        """Return a random mode in ``[0, M)``; ``t`` is ignored."""
        drawn = self.rng.integers(0, self.M)
        return int(drawn)


# =============================================================================
# A* (used for planning)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """A* search on the 4-connected grid with a Manhattan-distance heuristic.

    Args:
        grid: Provides ``width``, ``height`` and the blocked ``obstacles``.
        start: Cell the search starts from.
        goal: Cell to reach.
        step_cost: ``step_cost(frm, to)`` is the cost of moving between two
            adjacent cells.
        max_expansions: Budget on the number of nodes popped from the queue.

    Returns:
        The list of cells from ``start`` to ``goal`` inclusive
        (``[start]`` when they coincide).

    Raises:
        RuntimeError: if the budget is exceeded or the goal cannot be reached.
    """
    if start == goal:
        return [start]

    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def heuristic(cell: Tuple[int, int]) -> float:
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    def successors(cell: Tuple[int, int]):
        cx, cy = cell
        for dx, dy in offsets:
            cand = (cx + dx, cy + dy)
            inside = 0 <= cand[0] < grid.width and 0 <= cand[1] < grid.height
            if inside and cand not in grid.obstacles:
                yield cand

    def backtrack(node: Optional[Tuple[int, int]]) -> List[Tuple[int, int]]:
        # Follow parent links back to the start, then flip the chain.
        chain: List[Tuple[int, int]] = []
        while node is not None:
            chain.append(node)
            node = parent[node]
        return chain[::-1]

    # Queue entries are (f = g + h, g, cell); ties fall through to g and then the cell.
    frontier: List[Tuple[float, float, Tuple[int, int]]] = [(heuristic(start), 0.0, start)]
    parent: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    best_g: Dict[Tuple[int, int], float] = {start: 0.0}

    popped = 0
    while frontier:
        cur = heapq.heappop(frontier)[2]
        popped += 1
        if cur == goal:
            return backtrack(cur)
        if popped > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        for nb in successors(cur):
            cand_g = best_g[cur] + float(step_cost(cur, nb))
            known = best_g.get(nb)
            # Small tolerance avoids re-queuing on floating-point noise.
            if known is None or cand_g < known - 1e-12:
                best_g[nb] = cand_g
                parent[nb] = cur
                heapq.heappush(frontier, (cand_g + heuristic(nb), cand_g, nb))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan ``start -> waypoint -> goal`` as two A* legs joined at the waypoint.

    The waypoint appears exactly once in the returned path.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    # Drop the waypoint from the first leg; the second leg starts with it.
    return [*first_leg[:-1], *second_leg]


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of one simulated episode as produced by rollout_episode."""

    steps: int  # environment time index at the end of the episode
    reached_goal: bool  # True if the robot's final cell is the goal
    detected: bool  # True if a detection event occurred
    total_true_risk: float  # sum of the per-step true detection probabilities
    U_R: float  # robot payoff: -(steps + weighted risk + detection penalty)
    U_S: float  # sensor payoff: detection penalty + weighted risk - energy cost per step


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Simulate one episode of ``robot`` against ``sensor`` and score both players.

    The environment is reseeded with ``seed``, both policies are reset, and the
    robot starts from a uniform belief over ``M_modes`` modes. Each step the
    sensor picks the mode, the robot acts, and the belief is updated from the
    returned alarm. The loop ends when the environment reports ``done`` or
    after ``max_steps`` iterations.

    Payoffs:
        U_R = -(steps + lambda_risk * risk + det_penalty if detected)
        U_S = (det_penalty if detected) + lambda_risk * risk
              - sensor_energy_per_step * steps

    With ``step_debug=True`` one line per step is printed.
    """
    # Order matters: seed first, then reset the sensor before asking for its initial mode.
    env.seed(seed)
    sensor.reset()
    env.reset(sensor_mode=sensor.select_mode(0))

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    accumulated_risk = 0.0
    previous_alarm: Optional[int] = None

    for _ in range(max_steps):
        env.sensor_mode = sensor.select_mode(env.t)
        a = robot.act(env, belief, previous_alarm)
        out = env.step(a)

        accumulated_risk += float(out["p_true"])
        belief.update(env, out["alarm"], out["pos"])
        previous_alarm = int(out["alarm"])

        if step_debug:
            print(
                f"[Step] t={out['t']:3d} pos={out['pos']} a={a} p_true={out['p_true']:.3f} "
                f"alarm={out['alarm']} det={out['detected']} done={out['done']} b={belief.b.round(3)}"
            )

        if out["done"]:
            break

    was_detected = bool(env.detected)
    n_steps = int(env.t)
    at_goal = env.pos == env.grid.goal

    detection_term = det_penalty if was_detected else 0.0
    robot_payoff = -float(n_steps + lambda_risk * accumulated_risk + detection_term)
    sensor_payoff = float(detection_term + lambda_risk * accumulated_risk - sensor_energy_per_step * n_steps)

    return EpisodeStats(
        steps=n_steps,
        reached_goal=at_goal,
        detected=was_detected,
        total_true_risk=float(accumulated_risk),
        U_R=robot_payoff,
        U_S=sensor_payoff,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Estimate the empirical payoff matrices by Monte Carlo rollouts.

    Every (robot i, sensor j) pair is simulated ``rollouts`` times with the
    deterministic seed ``base_seed + 100000 * i + 1000 * j + k``.

    Args:
        env: Environment shared by all rollouts (reseeded per rollout).
        robots: Robot policies (rows of the matrices).
        sensors: Sensor policies (columns of the matrices).
        M_modes: Number of sensor modes for the robot's belief.
        rollouts: Rollouts per pair; must be at least 1.
        base_seed: Offset for the per-rollout seeds.
        log: Logger used for the summary table.
        debug_rollout_pair: If it equals ``(i, j)``, the first rollout of that
            pair is printed step by step.

    Returns:
        ``(U_R, U_S, diag)`` where the matrices hold mean payoffs and ``diag``
        maps each pair to detection/goal rates, mean steps, mean risk and the
        payoff standard deviations.

    Raises:
        ValueError: if ``rollouts`` is below 1 (the averages would be undefined).
    """
    if rollouts < 1:
        raise ValueError(f"rollouts must be >= 1, got {rollouts}")

    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            trace_this_pair = (debug_rollout_pair == (i, j))

            episodes = [
                rollout_episode(
                    env,
                    rpol,
                    spol,
                    M_modes=M_modes,
                    seed=base_seed + 100000 * i + 1000 * j + k,
                    # Only the first rollout of the requested pair is traced.
                    step_debug=(trace_this_pair and k == 0),
                )
                for k in range(rollouts)
            ]

            r_list = [st.U_R for st in episodes]
            s_list = [st.U_S for st in episodes]

            U_R[i, j] = float(np.mean(r_list))
            U_S[i, j] = float(np.mean(s_list))

            diag[(i, j)] = {
                "det_rate": sum(int(st.detected) for st in episodes) / rollouts,
                "goal_rate": sum(int(st.reached_goal) for st in episodes) / rollouts,
                "mean_steps": float(np.mean([st.steps for st in episodes])),
                "mean_risk": float(np.mean([st.total_true_risk for st in episodes])),
                "std_UR": float(np.std(r_list)),
                "std_US": float(np.std(s_list)),
            }

    # log.info is already a no-op below INFO, so no extra level guard is needed.
    log.info("[Eval] Compact payoff summary:")
    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            d = diag[(i, j)]
            log.info(
                f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                f"UR={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                f"US={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                f"det%={100*d['det_rate']:5.1f} goal%={100*d['goal_rate']:5.1f} "
                f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
            )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Euclidean projection of ``v`` onto the simplex ``{w >= 0, sum(w) = z}``.

    Uses the sort-and-threshold scheme: find the shift ``theta`` such that
    clipping ``v - theta`` at zero gives total mass ``z``. Degenerate inputs
    (no valid threshold, or a non-finite / non-positive clipped mass) fall back
    to the uniform point ``z / n``.

    Raises:
        ValueError: if ``v`` is empty or ``z`` is not positive.
    """
    vec = np.asarray(v, dtype=float).reshape(-1)
    n = vec.size
    if n == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    descending = np.sort(vec)[::-1]
    running_sum = np.cumsum(descending)
    ranks = np.arange(1, n + 1)
    active = np.flatnonzero(descending * ranks > (running_sum - z))
    if active.size == 0:
        return np.full_like(vec, z / n)

    pivot = int(active[-1])
    theta = (running_sum[pivot] - z) / (pivot + 1.0)
    clipped = np.maximum(vec - theta, 0.0)
    mass = float(clipped.sum())
    if not np.isfinite(mass) or mass <= 0:
        return np.full_like(vec, z / n)
    # Final rescale removes rounding drift so the result sums to z.
    return clipped * (z / mass)


@dataclass
class NBSResult:
    """Output of solve_nbs."""

    x: np.ndarray  # final joint distribution over the flattened joint actions
    obj: float  # objective value at x (log-gains plus the optional entropy term)
    gains: Tuple[float, float]  # (robot, sensor) expected payoff minus the disagreement value
    support: int  # number of entries of x larger than 1e-6


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Nash bargaining over a joint distribution via projected gradient ascent.

    Maximises ``log(uR.x - dR) + log(uS.x - dS) + entropy_tau * H(x)`` over the
    probability simplex, starting from the uniform distribution.

    Args:
        uR, uS: Flattened payoff vectors of the two players (same length >= 2).
        log: Logger for progress messages.
        max_iters: Maximum number of outer ascent iterations.
        alpha: Initial step size of each backtracking line search.
        tol_l1: Stop when an accepted step moves ``x`` by less than this in L1.
        kappa: Floor applied to both gains before taking logs / dividing.
        disagreement: ``"minminus"`` (each player's minimum payoff minus 1) or
            ``"uniform"`` (payoff under the uniform joint distribution).
        entropy_tau: Weight of the entropy bonus; 0 disables it.

    Raises:
        ValueError: on shape mismatch, fewer than two joint actions, or an
            unknown ``disagreement`` rule.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    uniform = np.full(d, 1.0 / d)

    rule = disagreement.lower().strip()
    if rule == "uniform":
        dR, dS = float(uR @ uniform), float(uS @ uniform)
    elif rule == "minminus":
        dR, dS = float(np.min(uR) - 1.0), float(np.min(uS) - 1.0)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        return float(uR @ xv - dR), float(uS @ xv - dS)

    def floored_gains(xv: np.ndarray) -> Tuple[float, float]:
        # Keep both gains strictly positive so log and 1/gain stay finite.
        raw_R, raw_S = gains(xv)
        return max(raw_R, kappa), max(raw_S, kappa)

    def entropy(xv: np.ndarray) -> float:
        xx = np.clip(xv, 1e-12, 1.0)
        return float(-np.sum(xx * np.log(xx)))

    def obj(xv: np.ndarray) -> float:
        fR, fS = floored_gains(xv)
        return float(np.log(fR) + np.log(fS) + entropy_tau * entropy(xv))

    def grad(xv: np.ndarray) -> np.ndarray:
        fR, fS = floored_gains(xv)
        g = (uR / fR) + (uS / fS)
        if entropy_tau > 0:
            xx = np.clip(xv, 1e-12, 1.0)
            g += entropy_tau * (-(np.log(xx) + 1.0))
        return g

    x = uniform.copy()
    last = obj(x)
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        direction = grad(x)

        # Backtracking line search: halve the step until the objective does not decrease.
        step = alpha
        accepted = None
        for _ in range(30):
            trial = project_simplex(x + step * direction)
            trial_obj = obj(trial)
            if trial_obj >= last - 1e-12:
                accepted = (trial, trial_obj)
                break
            step *= 0.5
            if step < 1e-6:
                break
        if accepted is None:
            break

        x_new, new_obj = accepted
        delta = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last = new_obj

        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            gR, gS = gains(x)
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flat joint distribution into an ``m x n`` matrix (rows: robot, cols: sensor).

    If the total mass is non-finite or differs from 1 by more than 1e-6, the
    matrix is divided by ``max(total, 1e-12)``.

    Raises:
        ValueError: if ``x`` does not have exactly ``m * n`` entries.
    """
    flat = np.asarray(x, dtype=float).reshape(-1)
    if flat.size != m * n:
        raise ValueError("x size mismatch")

    X = flat.reshape((m, n))
    total = float(X.sum())
    if np.isfinite(total) and abs(total - 1.0) <= 1e-6:
        return X
    # Renormalize defensively
    return X / max(total, 1e-12)


def marginals_from_joint(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the row (robot) and column (sensor) marginals of a joint matrix.

    Each marginal is normalised to sum to 1 when its total is positive;
    otherwise it is returned as-is.
    """

    def normalised(weights: np.ndarray) -> np.ndarray:
        total = weights.sum()
        return weights / total if total > 0 else weights

    return normalised(X.sum(axis=1)), normalised(X.sum(axis=0))


def entropy_of_joint(X: np.ndarray) -> float:
    """Shannon entropy (in nats) of the joint matrix, with entries clipped to [1e-12, 1]."""
    probs = np.clip(X.reshape(-1), 1e-12, 1.0)
    surprisal_weighted = probs * np.log(probs)
    return float(-np.sum(surprisal_weighted))


# =============================================================================
# Best responses (marginal vs correlated)
# =============================================================================

def compute_expected_risk_map_from_policy_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi: np.ndarray,
    M_modes: int,
) -> np.ndarray:
    """Expected p_true(cell) under mixture over sensor POLICIES (not modes).

    Each sensor policy is represented by the mode it selects at t=0, which is
    exact for fixed-mode policies. The policy weights are accumulated per mode,
    normalised, and used to average the per-mode detection probabilities.

    Returns:
        Array of shape ``(height, width)`` indexed ``[y, x]``; obstacle cells
        hold NaN.

    Raises:
        ValueError: if ``pi`` and ``sensors`` differ in length or a policy
            reports a mode outside ``[0, M_modes)``.
    """
    weights = np.asarray(pi, dtype=float).reshape(-1)
    if weights.size != len(sensors):
        raise ValueError("mixture length mismatch")

    # Collapse the mixture over policies into a distribution over modes.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(weights, sensors):
        mode = int(policy.select_mode(0))
        if mode < 0 or mode >= M_modes:
            raise ValueError("invalid sensor mode")
        mode_probs[mode] += float(weight)
    total = mode_probs.sum()
    if total > 0:
        mode_probs = mode_probs / total

    H, W = env.grid.height, env.grid.width
    # Start from NaN everywhere and fill in only the free cells.
    risk = np.full((H, W), np.nan, dtype=float)
    for y, x in np.ndindex(H, W):
        cell = (x, y)
        if cell in env.grid.obstacles:
            continue
        expected = 0.0
        for mode in range(M_modes):
            expected += float(mode_probs[mode]) * float(env.true_detection_prob(cell, mode))
        risk[y, x] = float(expected)

    return risk


def plan_risk_weighted_path(env: GridWorldStealthEnv, risk_map: np.ndarray, risk_weight: float) -> List[Tuple[int, int]]:
    """Plan start-to-goal with A*, charging ``1 + risk_weight * risk`` per entered cell.

    ``risk_map`` is indexed ``[y, x]``. Cells whose risk is not finite (for
    example NaN-marked obstacles) are charged a prohibitive cost of 1e9.
    """
    weight = float(risk_weight)

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        cell_risk = risk_map[to[1], to[0]]
        if np.isfinite(cell_risk):
            return 1.0 + weight * float(cell_risk)
        return 1e9

    return astar_path(env.grid, env.grid.start, env.grid.goal, step_cost)


def robot_best_response_to_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi_S: np.ndarray,
    robots: List[RobotPolicy],
    M_modes: int,
    risk_weight: float,
    log: Logger,
    tag: str,
) -> RobotPolicy:
    """Robot oracle: risk-weighted A* path against the sensor mixture ``pi_S``.

    Returns an existing ``FixedPathPolicy`` from ``robots`` if it already
    follows the planned path, otherwise a new ``FixedPathPolicy`` (the caller
    decides whether to add it). If planning fails, a warning is logged and
    ``robots[0]`` is returned.
    """
    risk_map = compute_expected_risk_map_from_policy_mixture(env, sensors, pi_S, M_modes=M_modes)
    try:
        path = plan_risk_weighted_path(env, risk_map, risk_weight=risk_weight)
    except Exception as e:
        log.info(f"[RobotBR] WARNING A* failed ({tag}): {e}")
        return robots[0]

    # Reuse a policy that already encodes exactly this path.
    planned = tuple(path)
    duplicates = (p for p in robots if isinstance(p, FixedPathPolicy) and tuple(p.path) == planned)
    p = next(duplicates, None)
    if p is not None:
        log.info(f"[RobotBR] ({tag}) BR path already exists: {p.name}")
        return p

    newp = FixedPathPolicy(path, name=f"R_BR_{tag}_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] ({tag}) Added new robot policy: {newp.name}")
    return newp


def sensor_best_response_to_mixture(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    pi_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    tag: str,
) -> FixedModeSensorPolicy:
    """Sensor oracle: best fixed mode against the robot mixture ``pi_R`` (Monte Carlo).

    For every valid candidate mode, ``rollouts`` episodes are simulated with a
    robot policy sampled from ``pi_R`` each time; the mode with the highest
    mean sensor payoff wins (the earliest candidate wins ties).

    Raises:
        ValueError: if ``pi_R`` and ``robots`` differ in length.
        RuntimeError: if no candidate mode yields a usable estimate.
    """
    pi_R = np.asarray(pi_R, dtype=float).reshape(-1)
    if pi_R.size != len(robots):
        raise ValueError("pi_R length mismatch")

    # One sampler shared across all candidate modes, seeded for repeatability.
    rng = np.random.default_rng(int(base_seed))

    best_mode: Optional[int] = None
    best_val = -1e18

    for mode in candidate_modes:
        if mode < 0 or mode >= M_modes:
            continue
        candidate = FixedModeSensorPolicy(mode, name=f"S_BR_{tag}_Mode{mode}")

        vals: List[float] = []
        for k in range(rollouts):
            opponent = robots[int(rng.choice(len(robots), p=pi_R))]
            episode = rollout_episode(env, opponent, candidate, M_modes=M_modes, seed=base_seed + 10000 * mode + k)
            vals.append(episode.U_S)

        mean_u = float(np.mean(vals))
        if log.k >= 2:
            log.debug(f"[SensorBR] ({tag}) mode={mode} E[US]={mean_u:.3f} std={float(np.std(vals)):.2f}")
        if mean_u > best_val:
            best_val, best_mode = mean_u, mode

    if best_mode is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] ({tag}) Best mode={best_mode} E[US]={best_val:.3f}")
    return FixedModeSensorPolicy(best_mode, name=f"S_BR_{tag}_Mode{best_mode}")


def conditional_sensor_given_robot(X: np.ndarray, i: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over sensor policies conditional on robot recommendation ``i``.

    This is row ``i`` of the joint matrix divided by its mass; rows whose mass
    is at most ``eps`` yield the uniform distribution instead.
    """
    weights = np.asarray(X[i, :], dtype=float)
    mass = float(weights.sum())
    return np.full_like(weights, 1.0 / weights.size) if mass <= eps else weights / mass


def conditional_robot_given_sensor(X: np.ndarray, j: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over robot policies conditional on sensor recommendation ``j``.

    This is column ``j`` of the joint matrix divided by its mass; columns whose
    mass is at most ``eps`` yield the uniform distribution instead.
    """
    weights = np.asarray(X[:, j], dtype=float)
    mass = float(weights.sum())
    return np.full_like(weights, 1.0 / weights.size) if mass <= eps else weights / mass


def compute_ce_regrets(U_R: np.ndarray, U_S: np.ndarray, X: np.ndarray, eps: float = 1e-12) -> Dict[str, float]:
    """Conditional recommendation regrets (CE-style) computed on current meta-game.

    For robot (given recommendation i):
        regret_R(i) = max_{i'} E_{j~q(.|i)}[U_R(i',j) - U_R(i,j)]

    For sensor (given recommendation j):
        regret_S(j) = max_{j'} E_{i~q(.|j)}[U_S(i,j') - U_S(i,j)]

    Recommendations whose marginal probability is at most ``eps`` are skipped.

    Returns:
        Dict with ``max_regret_R``, ``max_regret_S``, ``mean_regret_R`` and
        ``mean_regret_S`` (each 0.0 when no recommendation qualifies).
    """
    m, n = U_R.shape
    assert U_S.shape == (m, n)
    assert X.shape == (m, n)

    sigma_R, sigma_S = marginals_from_joint(X)

    def deviation_gain(q: np.ndarray, payoff_by_action: np.ndarray, recommended: int) -> float:
        # Best expected payoff over all own actions minus that of the recommended one.
        followed = float(np.dot(q, payoff_by_action[recommended]))
        best = followed
        for alternative in payoff_by_action:
            value = float(np.dot(q, alternative))
            if value > best:
                best = value
        return best - followed

    # Robot actions are rows of U_R; sensor actions are columns of U_S (rows of U_S.T).
    reg_R = [
        deviation_gain(conditional_sensor_given_robot(X, i, eps=eps), U_R, i)
        for i in range(m)
        if not sigma_R[i] <= eps
    ]
    reg_S = [
        deviation_gain(conditional_robot_given_sensor(X, j, eps=eps), U_S.T, j)
        for j in range(n)
        if not sigma_S[j] <= eps
    ]

    return {
        "max_regret_R": float(max(reg_R) if reg_R else 0.0),
        "max_regret_S": float(max(reg_S) if reg_S else 0.0),
        "mean_regret_R": float(np.mean(reg_R) if reg_R else 0.0),
        "mean_regret_S": float(np.mean(reg_S) if reg_S else 0.0),
    }


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first ``FixedModeSensorPolicy`` in ``pols`` with the given mode, or None."""
    matching = (p for p in pols if isinstance(p, FixedModeSensorPolicy) and p.mode == mode)
    return next(matching, None)


# =============================================================================
# Policy initialization + evaluator for joint strategy
# =============================================================================

def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the starting policy sets for both players.

    Robots: the unit-cost shortest path, one path through each wall gap (rows 2
    and 6 of column ``width // 2``), a random walker, and the online
    belief-replanning heuristic. Sensors: one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        return 1.0

    gap_column = grid.width // 2
    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)
    via_upper_gap = astar_via_waypoint(grid, grid.start, (gap_column, 2), grid.goal, unit_cost)
    via_lower_gap = astar_via_waypoint(grid, grid.start, (gap_column, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = []
    robots.append(FixedPathPolicy(shortest, "R_Shortest"))
    robots.append(FixedPathPolicy(via_upper_gap, "R_UpperCorridor"))
    robots.append(FixedPathPolicy(via_lower_gap, "R_LowerCorridor"))
    robots.append(RandomPolicy("R_Random"))
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))

    sensors: List[SensorPolicy] = []
    for m in range(M_modes):
        sensors.append(FixedModeSensorPolicy(m, name=f"S_Mode{m}"))

    return robots, sensors


@dataclass
class StrategyEval:
    """Monte Carlo averages returned by evaluate_joint_strategy."""

    mean_U_R: float  # mean robot payoff over the sampled episodes
    mean_U_S: float  # mean sensor payoff over the sampled episodes
    det_rate: float  # fraction of episodes in which the robot was detected
    goal_rate: float  # fraction of episodes that ended on the goal cell
    mean_steps: float  # mean episode length in steps
    mean_risk: float  # mean accumulated true detection probability per episode


def evaluate_joint_strategy(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    X: np.ndarray,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Monte Carlo evaluation of a joint (correlated) strategy.

    For each episode a pair ``(i, j)`` is sampled from ``X`` (normalised to sum
    to 1) and ``robots[i]`` plays against ``sensors[j]`` with seed
    ``base_seed + k``.

    Args:
        env: Environment used for every rollout.
        robots: Robot policies, one per row of ``X``.
        sensors: Sensor policies, one per column of ``X``.
        X: Joint weights of shape ``(len(robots), len(sensors))``.
        M_modes: Number of sensor modes for the robot's belief.
        episodes: Number of episodes to simulate; must be at least 1.
        base_seed: Seed for pair sampling and offset for episode seeds.

    Raises:
        ValueError: if ``episodes`` < 1, ``X`` does not match the policy lists,
            or ``X`` has no positive finite total mass.
    """
    if episodes < 1:
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    X = np.asarray(X, dtype=float)
    expected_shape = (len(robots), len(sensors))
    if X.shape != expected_shape:
        # A mismatch would otherwise pair the wrong policies or fail with IndexError mid-run.
        raise ValueError(f"X has shape {X.shape}, expected {expected_shape}")

    m, n = X.shape
    probs = X.reshape(-1)
    total = float(probs.sum())
    if not np.isfinite(total) or total <= 0.0:
        raise ValueError("X must have positive, finite total mass")
    probs = probs / max(total, 1e-12)

    rng = np.random.default_rng(int(base_seed))

    UR: List[float] = []
    US: List[float] = []
    det = 0
    goal = 0
    steps_list: List[int] = []
    risk_list: List[float] = []

    for k in range(episodes):
        idx = int(rng.choice(m * n, p=probs))
        # divmod on the flat index matches row-major unravelling and yields plain ints.
        i, j = divmod(idx, n)
        seed = base_seed + k
        st = rollout_episode(env, robots[i], sensors[j], M_modes=M_modes, seed=seed)
        UR.append(st.U_R)
        US.append(st.U_S)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps_list.append(st.steps)
        risk_list.append(st.total_true_risk)

    return StrategyEval(
        mean_U_R=float(np.mean(UR)),
        mean_U_S=float(np.mean(US)),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean(steps_list)),
        mean_risk=float(np.mean(risk_list)),
    )


# =============================================================================
# Training loop (marginal vs correlated)
# =============================================================================

@dataclass
class TrainHistoryRow:
    """Per-outer-iteration diagnostics recorded by ``run_training`` (one CSV row each)."""

    outer_iter: int  # 1-based outer iteration index
    m: int  # number of robot policies AFTER this iteration's expansion
    n: int  # number of sensor policies AFTER this iteration's expansion
    nbs_obj: float  # objective value reported by the NBS solver (nbs.obj)
    entropy_X: float  # entropy_of_joint(X) for this iteration's joint matrix
    max_regret_R: float  # 'max_regret_R' entry from compute_ce_regrets
    max_regret_S: float  # 'max_regret_S' entry from compute_ce_regrets
    selfplay_UR: float  # mean robot utility from self-play under X
    selfplay_US: float  # mean sensor utility from self-play under X
    selfplay_det: float  # detection rate from self-play under X
    selfplay_goal: float  # goal-reaching rate from self-play under X
    seconds: float  # wall-clock duration of this outer iteration


@dataclass
class TrainResult:
    """Final state returned by ``run_training`` for one solver variant."""

    solver: str  # solver name that was passed to run_training
    env: GridWorldStealthEnv  # environment the training ran on
    robots: List[RobotPolicy]  # robot policy set after all expansions
    sensors: List[SensorPolicy]  # sensor policy set after all expansions
    # Joint matrix from the LAST outer iteration. It is solved before that
    # iteration's expansion step, so its shape may be smaller than
    # (len(robots), len(sensors)).
    X: np.ndarray
    history: List[TrainHistoryRow]  # one row per outer iteration, in order


def run_training(env: GridWorldStealthEnv, args: argparse.Namespace, solver: str, log: Logger) -> TrainResult:
    """Run the outer policy-expansion loop for one solver variant.

    Each outer iteration:
      1. estimates payoff matrices U_R, U_S over the current policy sets
         (``evaluate_payoffs``);
      2. solves the NBS over joint actions (``solve_nbs``) and reshapes the
         solution into the joint matrix X;
      3. logs top joint actions, marginals, CE-regrets, entropy and self-play
         statistics under X;
      4. computes best responses and appends new ones to the policy sets.

    With ``solver == "marginal"`` best responses are computed against the
    marginals of X and are always added (after de-duplication).  With
    ``solver == "correlated"`` they are computed against conditionals of X
    for the top ``args.cond_top_k`` recommendations of each player, and a
    deviation is added only when its estimated conditional gain exceeds
    ``args.add_threshold``.

    Args:
        env: environment used for every rollout.
        args: parsed CLI namespace; reads ``outer_iters``, ``rollouts_payoff``,
            ``rollouts_br``, ``risk_weight_br``, ``disagreement``,
            ``entropy_tau``, ``cond_top_k``, ``br_eval_rollouts``,
            ``add_threshold``, ``eval_episodes`` and ``debug_rollout_pair``.
        solver: ``"marginal"`` or ``"correlated"``.
        log: logger used for all progress output.

    Returns:
        ``TrainResult`` with the final policy sets, the last joint matrix X
        and the per-iteration history.

    Raises:
        ValueError: if ``solver`` is neither ``"marginal"`` nor ``"correlated"``.
        RuntimeError: if no outer iteration ran (``args.outer_iters`` < 1).
    """
    # Validate up front so a typo does not cost a full payoff evaluation first.
    if solver not in ("marginal", "correlated"):
        raise ValueError("solver must be marginal or correlated")

    t0_all = time.time()

    sensor_cfg = env.sensor_cfg
    M_modes = len(sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    # Optional: reduce initial set if you want smaller games.
    # (We keep it as-is for benchmarks.)

    if log.k >= 1:
        log.info(f"[{solver}] Initial robots: " + ", ".join([p.name for p in robots]))
        log.info(f"[{solver}] Initial sensors: " + ", ".join([p.name for p in sensors]))

    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        if len(parts) == 2:
            debug_pair = (int(parts[0]), int(parts[1]))
            log.info(f"[{solver}] Will print one step-by-step rollout for pair {debug_pair} (only once).")
        else:
            # Previously a malformed value was dropped without any feedback.
            log.info(
                f"[{solver}] Ignoring malformed --debug-rollout-pair={args.debug_rollout_pair!r} (expected 'i,j')."
            )

    history: List[TrainHistoryRow] = []
    X = None

    for it in range(1, args.outer_iters + 1):
        t0 = time.time()
        log.banner(f"[{solver}] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, _diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        # The step-by-step debug rollout is printed in the first iteration only.
        debug_pair = None

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )

        m, n = U_R.shape
        X = joint_to_matrix(nbs.x, m, n)
        sigma_R, sigma_S = marginals_from_joint(X)

        # Print top joint actions
        top = np.argsort(-X.reshape(-1))[:min(5, X.size)]
        log.info(f"[{solver}] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[{solver}] sigma_R={sigma_R.round(3)}")
        log.info(f"[{solver}] sigma_S={sigma_S.round(3)}")

        # Stability diagnostics
        regrets = compute_ce_regrets(U_R, U_S, X)
        ent = entropy_of_joint(X)

        # Self-play evaluation under joint X
        sp = evaluate_joint_strategy(
            env,
            robots,
            sensors,
            X,
            M_modes=M_modes,
            episodes=args.eval_episodes,
            base_seed=9000 + 100 * it,
        )

        log.info(
            f"[{solver}] CE-regrets: maxR={regrets['max_regret_R']:.3f} maxS={regrets['max_regret_S']:.3f} | "
            f"SelfPlay: UR={sp.mean_U_R:.2f} US={sp.mean_U_S:.2f} det%={100*sp.det_rate:.1f} goal%={100*sp.goal_rate:.1f} | "
            f"H(X)={ent:.3f}"
        )

        # Best-response expansion
        if solver == "marginal":
            br_r = robot_best_response_to_mixture(
                env,
                sensors,
                pi_S=sigma_S,
                robots=robots,
                M_modes=M_modes,
                risk_weight=args.risk_weight_br,
                log=log,
                tag="Marginal",
            )

            br_s = sensor_best_response_to_mixture(
                env,
                robots,
                pi_R=sigma_R,
                candidate_modes=list(range(M_modes)),
                M_modes=M_modes,
                rollouts=args.rollouts_br,
                base_seed=2000 + 100 * it,
                log=log,
                tag="Marginal",
            )

        else:  # solver == "correlated" (validated at function entry)
            # FIX: choose conditional mixtures q(.|i) and q(.|j)
            # We only check top-K recommendations to keep it tractable.
            topK = max(1, int(args.cond_top_k))

            # Robot: check the top-K robot recommendations by sigma_R
            cand_i = list(np.argsort(-sigma_R)[:topK])
            best_gain = 0.0
            # None means "no deviation worth adding"; no placeholder policy.
            br_r = None

            for i in cand_i:
                qS = conditional_sensor_given_robot(X, int(i))
                tag = f"Cond_i{i}"
                pol = robot_best_response_to_mixture(
                    env,
                    sensors,
                    pi_S=qS,
                    robots=robots,
                    M_modes=M_modes,
                    risk_weight=args.risk_weight_br,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement: E_q[U_R(pol,j)] - E_q[U_R(i,j)]
                # If pol is already in set, its payoff exists in U_R row of that policy.
                # Otherwise we simulate pol against each sensor policy j.
                if pol in robots:
                    ip = robots.index(pol)
                    dev = float(np.dot(qS, U_R[ip, :]))
                else:
                    # simulate quickly vs each sensor policy
                    dev_vals = []
                    for j in range(n):
                        vals = []
                        for kk in range(args.br_eval_rollouts):
                            seed = 777000 + 1000 * it + 100 * i + 10 * j + kk
                            st = rollout_episode(env, pol, sensors[j], M_modes=M_modes, seed=seed)
                            vals.append(st.U_R)
                        dev_vals.append(float(np.mean(vals)))
                    dev = float(np.dot(qS, np.asarray(dev_vals)))

                rec = float(np.dot(qS, U_R[i, :]))
                gain = dev - rec
                if gain > best_gain + 1e-9:
                    best_gain = gain
                    br_r = pol

            if br_r is not None and best_gain > args.add_threshold:
                log.info(f"[{solver}] Adding robot deviation with estimated conditional gain={best_gain:.3f}")
            else:
                log.info(f"[{solver}] No robot deviation above threshold (best_gain={best_gain:.3f}).")
                # Enforce the threshold: the policy used to be appended anyway.
                br_r = None

            # Sensor: check top-K sensor recommendations by sigma_S
            cand_j = list(np.argsort(-sigma_S)[:topK])
            best_gain_s = 0.0
            br_s = None

            for j in cand_j:
                qR = conditional_robot_given_sensor(X, int(j))
                tag = f"Cond_j{j}"
                polS = sensor_best_response_to_mixture(
                    env,
                    robots,
                    pi_R=qR,
                    candidate_modes=list(range(M_modes)),
                    M_modes=M_modes,
                    rollouts=args.rollouts_br,
                    base_seed=333000 + 1000 * it + 10 * j,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement for sensor
                # rec under recommendation j is E_q[U_S(i,j)]
                recS = float(np.dot(qR, U_S[:, j]))

                # dev under mode polS.mode: if already present, use its column.
                existing_col = None
                for jj, spj in enumerate(sensors):
                    if isinstance(spj, FixedModeSensorPolicy) and spj.mode == polS.mode:
                        existing_col = jj
                        break

                if existing_col is not None:
                    devS = float(np.dot(qR, U_S[:, existing_col]))
                else:
                    vals = []
                    for kk in range(args.br_eval_rollouts):
                        i_samp = int(np.random.default_rng(444 + kk).choice(len(robots), p=qR))
                        seed = 888000 + 1000 * it + 10 * j + kk
                        st = rollout_episode(env, robots[i_samp], polS, M_modes=M_modes, seed=seed)
                        vals.append(st.U_S)
                    devS = float(np.mean(vals))

                gainS = devS - recS
                if gainS > best_gain_s + 1e-9:
                    best_gain_s = gainS
                    br_s = polS

            if br_s is not None and best_gain_s > args.add_threshold:
                log.info(f"[{solver}] Adding sensor deviation with estimated conditional gain={best_gain_s:.3f}")
            else:
                log.info(f"[{solver}] No sensor deviation above threshold (best_gain={best_gain_s:.3f}).")
                br_s = None

        # Add to sets (dedupe)
        if br_r is not None and br_r not in robots:
            robots.append(br_r)

        if isinstance(br_s, FixedModeSensorPolicy):
            if find_sensor_by_mode(sensors, br_s.mode) is None:
                sensors.append(br_s)
            else:
                log.info(f"[{solver}] Sensor mode {br_s.mode} already present; not adding duplicate.")

        seconds = float(time.time() - t0)
        history.append(
            TrainHistoryRow(
                outer_iter=it,
                m=len(robots),
                n=len(sensors),
                nbs_obj=float(nbs.obj),
                entropy_X=float(ent),
                max_regret_R=float(regrets["max_regret_R"]),
                max_regret_S=float(regrets["max_regret_S"]),
                selfplay_UR=float(sp.mean_U_R),
                selfplay_US=float(sp.mean_U_S),
                selfplay_det=float(sp.det_rate),
                selfplay_goal=float(sp.goal_rate),
                seconds=seconds,
            )
        )

        log.info(f"[{solver}] Sets: |Pi_R|={len(robots)} |Pi_S|={len(sensors)} | iter_seconds={seconds:.2f}")

    if X is None:
        raise RuntimeError("Training produced no X")

    log.banner(f"[{solver}] Finished")
    log.info(f"[{solver}] Final robots: " + ", ".join([p.name for p in robots]))
    log.info(f"[{solver}] Final sensors: " + ", ".join([p.name for p in sensors]))
    log.info(f"[{solver}] Total time: {time.time()-t0_all:.2f}s")

    return TrainResult(solver=solver, env=env, robots=robots, sensors=sensors, X=X, history=history)


# =============================================================================
# Benchmarks + plotting
# =============================================================================

@dataclass
class BenchCell:
    """Benchmark result for one (robot policy, sensor policy) pair."""

    UR: float  # mean robot utility over the benchmark episodes
    US: float  # mean sensor utility over the benchmark episodes
    det: float  # fraction of episodes in which the robot was detected
    goal: float  # fraction of episodes in which the robot reached the goal


def run_policy_matrix_benchmark(
    env: GridWorldStealthEnv,
    robot_methods: List[RobotPolicy],
    sensor_methods: List[SensorPolicy],
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> Dict[Tuple[int, int], BenchCell]:
    """Roll out every robot policy against every sensor policy.

    Args:
        env: environment passed through to ``rollout_episode``.
        robot_methods: robot policies (row index ``i`` in the result keys).
        sensor_methods: sensor policies (column index ``j`` in the result keys).
        M_modes: number of sensor modes, forwarded to ``rollout_episode``.
        episodes: rollouts per pair; must be >= 1.
        base_seed: base of the per-rollout seed
            ``base_seed + 100000 * i + 1000 * j + k``.

    Returns:
        Mapping ``(i, j) -> BenchCell`` with mean utilities and
        detection/goal rates for that pair.

    Raises:
        ValueError: if ``episodes`` < 1 (previously a ZeroDivisionError for 0
            and meaningless negative rates for negative values).
    """
    if episodes <= 0:
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    res: Dict[Tuple[int, int], BenchCell] = {}
    for i, rp in enumerate(robot_methods):
        for j, sp in enumerate(sensor_methods):
            UR = []
            US = []
            det = 0
            goal = 0
            for k in range(episodes):
                seed = base_seed + 100000 * i + 1000 * j + k
                st = rollout_episode(env, rp, sp, M_modes=M_modes, seed=seed)
                UR.append(st.U_R)
                US.append(st.U_S)
                det += int(st.detected)
                goal += int(st.reached_goal)
            res[(i, j)] = BenchCell(
                UR=float(np.mean(UR)),
                US=float(np.mean(US)),
                det=float(det / episodes),
                goal=float(goal / episodes),
            )
    return res


def safe_makedirs(path: str) -> None:
    """Create directory ``path`` with any missing parents; an existing directory is fine."""
    os.makedirs(name=path, exist_ok=True)


def save_history_csv(hist: List[TrainHistoryRow], path: str) -> None:
    """Write the training history to ``path`` as CSV.

    The header is the list of ``TrainHistoryRow`` annotation names in
    declaration order; each history entry becomes one row.
    """
    import csv

    columns = [name for name in TrainHistoryRow.__annotations__]
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({col: getattr(entry, col) for col in columns} for entry in hist)


def plot_training_curves(results: List[TrainResult], outdir: str) -> None:
    """Save four training-curve figures (one line per solver) into ``outdir``.

    Files written: ``curve_nbs_obj.png``, ``curve_max_regrets.png``,
    ``curve_entropy_X.png`` and ``curve_selfplay_utils.png``.  The x axis of
    every figure is the outer iteration index.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    def _draw(field, label_of, **line_kwargs):
        # One line per training result for the given history field.
        for run in results:
            iters = [row.outer_iter for row in run.history]
            values = [getattr(row, field) for row in run.history]
            plt.plot(iters, values, label=label_of(run), **line_kwargs)

    def _finish(ylabel, filename):
        # Shared axis labelling, legend, save and close for the current figure.
        plt.xlabel("Outer iteration")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    # 1) NBS objective
    plt.figure()
    _draw("nbs_obj", lambda run: run.solver, marker="o")
    _finish("NBS objective", "curve_nbs_obj.png")

    # 2) Max conditional regrets (all robot curves first, then all sensor curves)
    plt.figure()
    _draw("max_regret_R", lambda run: f"{run.solver}: maxRegret_R", marker="o")
    _draw("max_regret_S", lambda run: f"{run.solver}: maxRegret_S", marker="x", linestyle="--")
    _finish("Max conditional regret", "curve_max_regrets.png")

    # 3) Entropy of X
    plt.figure()
    _draw("entropy_X", lambda run: run.solver, marker="o")
    _finish("Entropy H(X)", "curve_entropy_X.png")

    # 4) Self-play outcomes (all robot curves first, then all sensor curves)
    plt.figure()
    _draw("selfplay_UR", lambda run: f"{run.solver}: UR", marker="o")
    _draw("selfplay_US", lambda run: f"{run.solver}: US", marker="x", linestyle="--")
    _finish("Expected utility under X", "curve_selfplay_utils.png")


def plot_benchmark_bars(
    robot_names: List[str],
    sensor_names: List[str],
    bench: Dict[Tuple[int, int], BenchCell],
    outdir: str,
) -> None:
    """Save one bar chart per sensor for robot utility, then one per sensor for goal rate.

    ``bench`` is keyed by ``(robot_index, sensor_index)``.  Files are named
    ``bench_UR_vs_<sensor>.png`` and ``bench_goal_vs_<sensor>.png``.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    robot_slots = range(len(robot_names))

    # (BenchCell attribute, y-axis label, title prefix, file-name prefix)
    panels = (
        ("UR", "Robot utility", "Robot utility vs sensor=", "bench_UR_vs_"),
        ("goal", "Goal rate", "Robot goal rate vs sensor=", "bench_goal_vs_"),
    )

    # Metric is the outer loop so all utility charts are written before the
    # goal-rate charts.
    for attr, ylabel, title_prefix, file_prefix in panels:
        for j, sname in enumerate(sensor_names):
            plt.figure(figsize=(10, 4))
            heights = [getattr(bench[(i, j)], attr) for i in robot_slots]
            plt.bar(robot_slots, heights)
            plt.xticks(robot_slots, robot_names, rotation=35, ha="right")
            plt.ylabel(ylabel)
            plt.title(f"{title_prefix}{sname}")
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"{file_prefix}{sname}.png"), dpi=200)
            plt.close()


# =============================================================================
# Main pipeline wrapper
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Build the stealth grid game, train the requested solver(s), then save and benchmark.

    Steps: construct grid, sensors and environment; run ``run_training`` once
    per solver on a fresh environment; write history CSVs (and plots when
    requested); optionally benchmark heuristic robot policies against a fixed
    set of sensor heuristics.
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    wall_x = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((wall_x, 2), (wall_x, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    if log.k >= 1:
        log.banner("[PIPELINE] Stealth grid game")
        log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
        if log.k >= 2:
            log.debug("ASCII map:")
            print_grid_ascii(grid, sensor_cfg)

    solvers: List[str] = ["marginal", "correlated"] if args.solver == "both" else [args.solver]

    results: List[TrainResult] = []
    for solver_name in solvers:
        # Use a fresh env copy per solver (to avoid RNG coupling)
        solver_env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)
        results.append(run_training(solver_env, args, solver=solver_name, log=log))

    want_plots = bool(args.save_plots and args.results_dir)

    if args.results_dir:
        safe_makedirs(args.results_dir)
        for outcome in results:
            csv_path = os.path.join(args.results_dir, f"history_{outcome.solver}.csv")
            save_history_csv(outcome.history, csv_path)

    if want_plots:
        plot_training_curves(results, outdir=args.results_dir)

    # Benchmarks on the same game
    if not args.run_benchmarks:
        return

    log.banner("[BENCH] Heuristics on same game")
    M_modes = len(sensor_cfg.sensors)

    # Robot heuristics
    #  - fixed paths from initial set
    robots_init, _unused_sensors = build_initial_policies_for_bench(env, M_modes)

    # Sensor heuristics
    sensor_methods: List[SensorPolicy] = [
        FixedModeSensorPolicy(0, "S_Mode0"),
        FixedModeSensorPolicy(1, "S_Mode1"),
        AlternatingSensorPolicy(M_modes, "S_Alternate"),
        RandomModeSensorPolicy(M_modes, seed=args.seed + 123, name="S_RandomMode"),
    ]

    bench = run_policy_matrix_benchmark(
        env,
        robot_methods=robots_init,
        sensor_methods=sensor_methods,
        M_modes=M_modes,
        episodes=args.bench_episodes,
        base_seed=555000,
    )

    # Print a compact table
    robot_names = [policy.name for policy in robots_init]
    sensor_names = [policy.name for policy in sensor_methods]
    for j, sname in enumerate(sensor_names):
        log.info(f"[BENCH] Sensor={sname}")
        for i, rname in enumerate(robot_names):
            cell = bench[(i, j)]
            log.info(f"  Robot={rname:22s} UR={cell.UR:8.2f} goal%={100*cell.goal:5.1f} det%={100*cell.det:5.1f}")

    if want_plots:
        plot_benchmark_bars(robot_names, sensor_names, bench, outdir=args.results_dir)


# Helper: initial robot set for benchmark (without the solver-added BR policies)

def build_initial_policies_for_bench(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Build the heuristic robot policies and fixed-mode sensor policies used for benchmarking.

    Robot set: shortest path, upper/lower corridor paths, two static
    risk-weighted A* paths (uniform-belief risk and worst-case risk), the
    online belief-replanning policy and a random policy.  Sensor set: one
    fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(a, b):
        # Every move costs the same, so A* returns a shortest path.
        return 1.0

    p_short = astar_path(grid, grid.start, grid.goal, unit_cost)

    wall_x = grid.width // 2
    p_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    p_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    # Static risk-aware A* under UNIFORM mode belief
    uniform_mode = np.full(M_modes, 1.0 / M_modes)

    # Risk maps indexed [y, x]; obstacle cells keep NaN, free cells are filled.
    map_shape = (grid.height, grid.width)
    risk_uniform = np.full(map_shape, np.nan, dtype=float)
    risk_worst = np.full(map_shape, np.nan, dtype=float)
    for y, x in np.ndindex(*map_shape):
        if (x, y) in grid.obstacles:
            continue
        per_mode = [env.true_detection_prob((x, y), mode) for mode in range(M_modes)]
        risk_uniform[y, x] = float(np.dot(uniform_mode, per_mode))
        risk_worst[y, x] = float(np.max(per_mode))

    p_uniform = plan_risk_weighted_path(env, risk_uniform, risk_weight=12.0)
    p_worst = plan_risk_weighted_path(env, risk_worst, risk_weight=12.0)

    robots: List[RobotPolicy] = [
        FixedPathPolicy(p_short, "R_Shortest"),
        FixedPathPolicy(p_upper, "R_UpperCorridor"),
        FixedPathPolicy(p_lower, "R_LowerCorridor"),
        FixedPathPolicy(p_uniform, "R_RiskAStar_Uniform"),
        FixedPathPolicy(p_worst, "R_RiskAStar_WorstCase"),
        OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"),
        RandomPolicy("R_Random"),
    ]

    sensors: List[SensorPolicy] = []
    for mode in range(M_modes):
        sensors.append(FixedModeSensorPolicy(mode, name=f"S_Mode{mode}"))
    return robots, sensors


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the pipeline's command-line options.

    Uses ``parse_known_args`` so that arguments injected by a notebook kernel
    do not abort parsing; they are returned separately.

    Args:
        argv: argument list to parse, or ``None`` to read ``sys.argv[1:]``.

    Returns:
        ``(args, unknown)``: the parsed namespace and the list of
        unrecognised arguments.
    """
    p = argparse.ArgumentParser()

    p.add_argument("--seed", type=int, default=0)
    # type=str.upper makes the level case-insensitive, matching Logger, which
    # upper-cases its argument; without it "--log-level debug" is rejected.
    p.add_argument("--log-level", type=str.upper, default="INFO", choices=["QUIET", "INFO", "DEBUG"])

    p.add_argument("--solver", type=str, default="correlated", choices=["marginal", "correlated", "both"])

    p.add_argument("--outer-iters", type=int, default=3)

    p.add_argument("--rollouts-payoff", type=int, default=20)
    p.add_argument("--rollouts-br", type=int, default=30)
    p.add_argument("--risk-weight-br", type=float, default=12.0)

    # NBS knobs
    p.add_argument("--disagreement", type=str, default="minminus", choices=["minminus", "uniform"])
    p.add_argument("--entropy-tau", type=float, default=0.0)

    # Fix mismatch knobs
    p.add_argument("--cond-top-k", type=int, default=2, help="How many top recommendations to check for conditional BRs")
    p.add_argument("--br-eval-rollouts", type=int, default=8, help="Small evaluation rollouts for new deviations")
    p.add_argument("--add-threshold", type=float, default=0.25, help="Minimum estimated conditional gain to add a deviation policy")

    # Eval episodes under joint X (self-play)
    p.add_argument("--eval-episodes", type=int, default=60)

    # Debug
    p.add_argument("--debug-rollout-pair", type=str, default="", help="Print one step-by-step rollout for i,j")

    # Outputs
    p.add_argument("--results-dir", type=str, default="results", help="Directory for csv/plots")
    p.add_argument("--save-plots", action="store_true")

    # Benchmarks
    p.add_argument("--run-benchmarks", action="store_true")
    p.add_argument("--bench-episodes", type=int, default=80)

    args, unknown = p.parse_known_args(args=argv)
    return args, unknown


def main(argv: Optional[List[str]] = None) -> None:
    """Entry point: parse options, warn about unknown ones outside a notebook, run the pipeline."""
    args, unknown = parse_args(argv=argv)

    # Inside a Jupyter kernel the extra arguments come from ipykernel itself,
    # so the warning is only printed for plain script runs.
    inside_kernel = "ipykernel" in sys.modules
    if unknown and not inside_kernel:
        print(f"[WARN] Ignoring unknown CLI args: {unknown}")

    # Treat a whitespace-only debug pair as "not set".
    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""

    run_pipeline(args)


# Script entry point: run main() only when executed directly and not inside a
# Jupyter kernel (where "ipykernel" is present in sys.modules); in a notebook,
# call main(argv=[...]) explicitly instead.
if __name__ == "__main__" and ("ipykernel" not in sys.modules):
    main()


In [18]:
#!/usr/bin/env python3
"""approach2_robust_correlated.py

Robust, reduced-log implementation of "Approach 2" on a stealth gridworld POMDP,
with a FIX for the conceptual mismatch:

  Old mismatch: solve a joint distribution x*(i,j) over (robot policy i, sensor policy j)
  but then compute best responses against the MARGINALS sigma_R, sigma_S.

  Fix here: treat x* as a CORRELATION DEVICE (mediator) and compute BRs against
  CONDITIONAL distributions:

      q_S(.|i) = X[i,:] / sigma_R[i]   (sensor conditional given robot recommendation i)
      q_R(.|j) = X[:,j] / sigma_S[j]   (robot conditional given sensor recommendation j)

  Then add the most profitable *deviation* policy (oracle) based on the most violated
  conditional recommendation.

This makes the PSRO-style expansion step consistent with a correlated-strategy viewpoint.

Also included:
  - Benchmark harness vs simple heuristics on the same game.
  - Saved plots (training curves + bar charts) for presentations.

Run (script):
  python approach2_robust_correlated.py --solver correlated --outer-iters 3 --save-plots

Run (notebook):
  main(argv=["--solver","both","--outer-iters","3","--save-plots"])  # ignore ipykernel args

Tips to encourage MULTIMODAL x* (useful for your "mode recovery"):
  --disagreement uniform --entropy-tau 0.02

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import argparse
import sys
import heapq
import os
import time

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """Minimal print-based logger with three verbosity levels: QUIET < INFO < DEBUG."""

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        """Store the (case-insensitive) level name and its numeric rank ``k``."""
        normalized = level.upper()
        rank = self.LEVELS.get(normalized)
        if rank is None:
            raise ValueError(f"Unknown log level: {normalized}. Use QUIET/INFO/DEBUG")
        self.level = normalized
        self.k = rank

    def banner(self, title: str) -> None:
        """Print ``title`` framed by two 100-character rules (INFO and above)."""
        if self.k < 1:
            return
        rule = "=" * 100
        for line in (rule, title, rule):
            print(line)

    def info(self, msg: str) -> None:
        """Print ``msg`` at INFO level and above."""
        if self.k < 1:
            return
        print(msg)

    def debug(self, msg: str) -> None:
        """Print ``msg`` at DEBUG level only."""
        if self.k < 2:
            return
        print(msg)


# =============================================================================
# Types
# =============================================================================

# A robot move expressed as a grid displacement; GridWorldStealthEnv.step adds
# it component-wise to the current (x, y) position.
Action = Tuple[int, int]  # (dx, dy)


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid: size, start/goal cells and blocked cells."""

    width: int  # number of columns; valid x is 0 <= x < width
    height: int  # number of rows; valid y is 0 <= y < height
    start: Tuple[int, int]  # (x, y) cell the robot is placed on at reset
    goal: Tuple[int, int]  # (x, y) cell that ends the episode when reached
    obstacles: frozenset  # set of blocked (x, y) cells


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor layout and detection probabilities."""

    sensors: Tuple[Tuple[int, int], ...]  # (x, y) centre per mode; mode m uses sensors[m]
    radius: int  # Manhattan radius of the hotspot around the active sensor
    base_p: float  # per-step detection probability outside the hotspot
    hotspot_p: float  # per-step detection probability inside the hotspot


class GridWorldStealthEnv:
    """Grid world with detection risk controlled by a hidden/selected sensor 'mode'.

    The active mode selects one sensor centre; cells within Manhattan distance
    ``radius`` of it have detection probability ``hotspot_p``, all others
    ``base_p``.  Each step draws a true detection event and, independently, a
    noisy alarm observation with false-positive rate ``fp`` and false-negative
    rate ``fn``.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        """Validate the configuration, seed the RNG and reset to mode 0.

        Args:
            grid: grid layout.
            sensor_cfg: sensor centres and detection probabilities.
            fp: probability of an alarm when there is no true detection signal.
            fn: probability of a missed alarm given a true detection signal.
            seed: seed for the environment's random generator.
            max_steps: episode horizon; ``step`` reports ``done`` once
                ``t >= max_steps`` (default 200, the previous fixed value).

        Raises:
            ValueError: on out-of-range probabilities, negative radius, an
                empty sensor list or ``max_steps`` < 1.
        """
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        self.max_steps = int(max_steps)

        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(int(seed))
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the random generator with a fresh one seeded by ``seed``."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """Start a new episode at ``grid.start`` with the given active sensor mode.

        Returns:
            ``{"pos": start, "t": 0}``.

        Raises:
            ValueError: if ``sensor_mode`` is not a valid sensor index.  An
                invalid mode used to be accepted here and only fail on the
                first ``step``.
        """
        mode = int(sensor_mode)
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = mode
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """Return True if ``p`` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """Return True if ``p`` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """Per-step detection probability at cell ``p`` when ``mode`` is active.

        Raises:
            ValueError: if ``mode`` is not a valid sensor index.
        """
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)  # Manhattan distance to the active sensor
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """Likelihood of observing ``alarm`` (0 or 1) given true detection probability ``p_true``.

        Raises:
            ValueError: if ``alarm`` is neither 0 nor 1.
        """
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """Apply move ``a`` and sample detection and alarm for the resulting cell.

        A move into an obstacle or off the grid leaves the position unchanged.
        After a detection, further calls return a terminal record without
        advancing time or consuming random numbers.

        Returns:
            Dict with keys ``pos``, ``t``, ``alarm``, ``p_true``, ``detected``
            and ``done``.
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1
        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = float(self.true_detection_prob(self.pos, self.sensor_mode))
        self.total_true_risk += p_true

        # Draw order matters for reproducibility: detection first, then alarm.
        if self.rng.random() < p_true:
            self.detected = True

        # Same noise model the belief update uses, instead of a duplicated formula.
        p_alarm = self.observation_prob(1, p_true)
        alarm = 1 if (self.rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """Build a grid split by a vertical wall at ``width // 2`` with gaps at rows 2 and 6.

    The start is ``(1, height - 2)`` and the goal is ``(width - 2, 1)``.

    Raises:
        RuntimeError: if the start or the goal falls on a wall cell.
    """
    wall_x = width // 2
    gap_rows = (2, 6)
    wall = frozenset((wall_x, y) for y in range(height) if y not in gap_rows)

    start, goal = (1, height - 2), (width - 2, 1)
    if not wall.isdisjoint((start, goal)):
        raise RuntimeError("Start/goal blocked unexpectedly")

    return GridConfig(width=width, height=height, start=start, goal=goal, obstacles=wall)


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """Print the grid one row per line.

    Symbols, in order of precedence: ``R`` start, ``G`` goal, ``S`` sensor
    centre, ``#`` obstacle, ``.`` free cell.
    """
    blocked = set(grid.obstacles)
    centres = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        # Earlier checks win when a cell matches more than one category.
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in centres:
            return "S"
        return "#" if cell in blocked else "."

    for y in range(grid.height):
        print("".join(glyph((x, y)) for x in range(grid.width)))


# =============================================================================
# Belief over modes
# =============================================================================

class ModeBelief:
    """Categorical belief ``b`` over the discrete sensor modes 0..M-1, updated by Bayes' rule."""

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """Create a belief over ``M`` modes.

        ``init`` (optional) is normalised to sum to one; without it, or when
        its sum is not positive, the belief is uniform.

        Raises:
            ValueError: if ``M`` < 1, ``init`` has the wrong length, or
                ``init`` has a negative entry.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")

        uniform = np.full(self.M, 1.0 / self.M)
        if init is None:
            self.b = uniform
            return

        prior = np.asarray(init, dtype=float).reshape(-1)
        if prior.shape != (self.M,):
            raise ValueError("init shape mismatch")
        if np.any(prior < 0):
            raise ValueError("init must be nonnegative")
        mass = float(prior.sum())
        self.b = prior / mass if mass > 0 else uniform

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """Condition the belief on observing ``alarm`` at cell ``pos``.

        The belief is left unchanged when the normalising constant is
        non-finite or smaller than ``eps``.
        """
        likelihood = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, mode)) for mode in range(self.M)],
            dtype=float,
        )
        unnormalised = self.b * likelihood
        evidence = float(unnormalised.sum())
        if np.isfinite(evidence) and evidence >= eps:
            self.b = unnormalised / evidence


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Common interface of the robot policies.

    Subclasses implement :meth:`act`; :meth:`reset` is an optional hook that
    receives the robot's starting cell.
    """

    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Hook for per-episode state; the base implementation does nothing."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the action to take given the environment, the mode belief and the last alarm.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Robot policy that follows a precomputed sequence of grid cells.

    ``_idx`` points at the path cell the policy believes the robot occupies.
    Every call to :meth:`act` emits the unit move towards the following cell and
    advances the pointer optimistically. Whenever the pointer and ``env.pos``
    disagree (e.g. the previous move did not take effect) the pointer is
    re-synchronised with the robot's actual cell before the next move is chosen.
    """

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """Store a copy of ``path``.

        Raises:
            ValueError: If ``path`` has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point at the first occurrence of ``start_pos`` on the path (index 0 if absent)."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def _locate(self, cur: Tuple[int, int]) -> Optional[int]:
        """Return the path index matching ``cur``, or ``None`` if ``cur`` is off the path.

        Cells at or after the current pointer are preferred (the robot moved as
        planned or skipped ahead). Otherwise the closest earlier occurrence is
        used, which covers a move that was issued but did not take effect.
        """
        try:
            return self.path.index(cur, self._idx)
        except ValueError:
            pass
        for k in range(min(self._idx, len(self.path)) - 1, -1, -1):
            if self.path[k] == cur:
                return k
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the unit step towards the next path cell, or ``(0, 0)`` to stay.

        The policy stays put when the robot is on the final path cell or is not
        on the path at all. ``belief`` and ``last_obs`` are ignored.
        """
        cur = env.pos
        last = len(self.path) - 1
        idx = min(self._idx, last)

        if cur != self.path[idx]:
            found = self._locate(cur)
            if found is None:
                # Off the path: wait, and keep the pointer where it was.
                return (0, 0)
            idx = found

        self._idx = idx
        if idx >= last:
            # Already on the final cell; there is no "next" cell to index.
            return (0, 0)

        nxt = self.path[idx + 1]
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        # Optimistic advance; corrected on the next call if the move failed.
        self._idx = idx + 1
        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """Random walk over the currently free moves, drawing from ``env.rng``.

    Because the randomness comes from the environment's generator, the policy
    itself holds no random state.
    """

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Nothing to reset: the policy is stateless."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the four unit moves and "stay" whose target cell is free.

        Returns ``(0, 0)`` without consuming randomness when no move is free.
        """
        px, py = env.pos
        moves = ((1, 0), (-1, 0), (0, 1), (0, -1), (0, 0))
        feasible: List[Action] = [mv for mv in moves if env.is_free((px + mv[0], py + mv[1]))]
        if len(feasible) == 0:
            return (0, 0)
        pick = int(env.rng.integers(0, len(feasible)))
        return feasible[pick]


class OnlineBeliefReplanPolicy(RobotPolicy):
    """POMDP-ish heuristic: replan every step on the risk map implied by the current belief.

    At each call the expected detection probability of a cell is the
    belief-weighted average of ``env.true_detection_prob(cell, m)`` over modes,
    and the robot takes the first step of the A* path that minimises
    ``1 + risk_weight * expected_risk`` per entered cell.
    """

    def __init__(self, env: GridWorldStealthEnv, risk_weight: float = 12.0, name: str = "R_OnlineBeliefReplan"):
        # ``env`` is stored for reference only; planning uses the ``env`` passed to act().
        self.env = env
        self.risk_weight = float(risk_weight)
        self.name = name
        # Cleared on reset(); act() does not read or write it.
        self._cached_next: Optional[Tuple[int, int]] = None

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Clear per-episode state."""
        self._cached_next = None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Plan from ``env.pos`` to the goal and return the first move of that plan.

        Returns ``(0, 0)`` when the planner fails (``astar_path`` raises
        ``RuntimeError``) or when the plan has fewer than two cells. Any other
        exception is a programming error and is allowed to propagate.
        """
        mode_probs = [float(pm) for pm in belief.b]
        # A* asks for the cost of entering a cell once per incoming edge; the
        # expected risk only depends on the cell, so compute it once per cell.
        risk_cache: Dict[Tuple[int, int], float] = {}

        def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
            r = risk_cache.get(to)
            if r is None:
                r = 0.0
                for m, pm in enumerate(mode_probs):
                    r += pm * float(env.true_detection_prob(to, m))
                risk_cache[to] = r
            return 1.0 + self.risk_weight * r

        # Plan from current pos to goal (one-step receding horizon).
        try:
            path = astar_path(env.grid, env.pos, env.grid.goal, step_cost)
        except RuntimeError:
            # Unreachable goal or expansion budget exhausted: stay in place.
            return (0, 0)

        if len(path) < 2:
            return (0, 0)
        nxt = path[1]
        dx = int(np.clip(nxt[0] - env.pos[0], -1, 1))
        dy = int(np.clip(nxt[1] - env.pos[1], -1, 1))
        return (dx, dy)


class SensorPolicy:
    """Common interface of the sensor policies.

    Subclasses implement :meth:`select_mode`; :meth:`reset` is an optional hook.
    """

    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Hook for per-episode state; the base implementation does nothing."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the sensor mode to use at time step ``t``.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that reports the same mode at every time step."""

    def __init__(self, mode: int, name: Optional[str] = None):
        """Store the mode as an ``int``.

        Args:
            mode: Mode index; converted with ``int()``.
            name: Display name. When omitted (or empty) the name is derived
                from the stored integer mode, so it always matches ``self.mode``
                even if ``mode`` was passed as e.g. a float.
        """
        self.mode = int(mode)
        self.name = name or f"S_Mode{self.mode}"

    def reset(self) -> None:
        """Nothing to reset: the policy is stateless."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the fixed mode; ``t`` is ignored."""
        return self.mode


class AlternatingSensorPolicy(SensorPolicy):
    """Simple sensor heuristic for benchmarks: cycle through modes 0, 1, ..., M-1, 0, ..."""

    def __init__(self, M: int, name: str = "S_Alternate"):
        """Store the number of modes.

        Raises:
            ValueError: If ``M < 1``. Without this check the problem would only
                surface later as a ``ZeroDivisionError`` inside ``select_mode``.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")
        self.name = name

    def reset(self) -> None:
        """Nothing to reset: the mode depends only on the time step."""
        return None

    def select_mode(self, t: int) -> int:
        """Return ``t mod M``."""
        return int(t % self.M)


class RandomModeSensorPolicy(SensorPolicy):
    """Benchmark sensor: draw a uniformly random mode on every call.

    The policy owns a seeded numpy ``Generator``. ``reset`` leaves that
    generator untouched, so the sequence of modes continues across episodes and
    depends on how many draws were made before.
    """

    def __init__(self, M: int, seed: int = 0, name: str = "S_RandomMode"):
        """Store the number of modes and create the generator.

        Raises:
            ValueError: If ``M < 1``. Without this check the problem would only
                surface later, on the first call to ``select_mode``.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")
        self.rng = np.random.default_rng(int(seed))
        self.name = name

    def reset(self) -> None:
        """Does nothing; in particular the generator is not re-seeded."""
        return None

    def select_mode(self, t: int) -> int:
        """Return a mode drawn uniformly from ``0 .. M-1``; ``t`` is ignored."""
        return int(self.rng.integers(0, self.M))


# =============================================================================
# A* (used for planning)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """A* search on the 4-connected grid, returning the cell sequence from ``start`` to ``goal``.

    Args:
        grid: Object exposing ``width``, ``height`` and ``obstacles`` (cells that
            may not be entered).
        start: Start cell ``(x, y)``.
        goal: Goal cell ``(x, y)``.
        step_cost: ``step_cost(frm, to)`` is the cost of moving from ``frm`` to
            the adjacent cell ``to``. The heuristic is the Manhattan distance,
            so the result is only guaranteed to be cost-optimal when every step
            costs at least 1.
        max_expansions: Upper bound on the number of node expansions.

    Returns:
        The path including both endpoints; ``[start]`` when ``start == goal``.

    Raises:
        RuntimeError: If the goal is unreachable or the expansion budget is
            exceeded.
    """
    if start == goal:
        return [start]

    width, height = grid.width, grid.height
    # Membership tests sit in the innermost loop; make them O(1) even when the
    # grid stores its obstacles in a list or tuple.
    obstacles = grid.obstacles
    blocked = obstacles if isinstance(obstacles, (set, frozenset)) else set(obstacles)

    def h(p: Tuple[int, int]) -> float:
        return abs(p[0] - goal[0]) + abs(p[1] - goal[1])

    def neighbors(p: Tuple[int, int]):
        x, y = p
        for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            np_ = (x + dx, y + dy)
            if 0 <= np_[0] < width and 0 <= np_[1] < height and np_ not in blocked:
                yield np_

    # Heap entries are (f = g + h, g, cell).
    open_heap: List[Tuple[float, float, Tuple[int, int]]] = []
    heapq.heappush(open_heap, (h(start), 0.0, start))

    came: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    gscore: Dict[Tuple[int, int], float] = {start: 0.0}

    expansions = 0
    while open_heap:
        _, g_cur, cur = heapq.heappop(open_heap)
        if g_cur > gscore[cur]:
            # Stale entry: a cheaper route to ``cur`` was pushed (and, having a
            # smaller f, already expanded) after this one. Re-expanding it
            # cannot improve any neighbour, so skip it without charging the
            # expansion budget.
            continue
        expansions += 1
        if cur == goal:
            path: List[Tuple[int, int]] = []
            while cur is not None:
                path.append(cur)
                cur = came[cur]
            path.reverse()
            return path
        if expansions > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        for nb in neighbors(cur):
            tentative = g_cur + float(step_cost(cur, nb))
            # The small tolerance avoids re-pushing on floating-point noise.
            if (nb not in gscore) or (tentative < gscore[nb] - 1e-12):
                gscore[nb] = tentative
                came[nb] = cur
                heapq.heappush(open_heap, (tentative + h(nb), tentative, nb))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan ``start -> waypoint -> goal`` as two independent A* searches and join them.

    Errors raised by ``astar_path`` for either leg propagate to the caller.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    # The waypoint ends the first leg and starts the second; keep it only once.
    return [*first_leg[:-1], *second_leg]


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of a single rollout, as filled in by ``rollout_episode``."""

    steps: int  # env.t when the rollout stopped
    reached_goal: bool  # env.pos == env.grid.goal at the end of the rollout
    detected: bool  # bool(env.detected) at the end of the rollout
    total_true_risk: float  # sum of the per-step "p_true" values returned by env.step
    U_R: float  # robot utility: -(steps + lambda_risk * risk + detection penalty if detected)
    U_S: float  # sensor utility: detection penalty if detected + lambda_risk * risk - energy * steps


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Simulate one episode of ``robot`` against ``sensor`` and score it for both players.

    The environment is seeded with ``seed``, the sensor is reset and asked for
    its mode at ``t=0`` to initialise the environment, and the robot starts
    with a uniform belief over the ``M_modes`` modes. On every step the sensor
    mode is refreshed from ``sensor.select_mode(env.t)``, the robot acts, and
    the belief is updated with the resulting alarm.

    Args:
        max_steps: Maximum number of environment steps.
        lambda_risk: Weight of the accumulated true detection probability.
        det_penalty: Amount charged to the robot (and credited to the sensor)
            when the episode ends with ``env.detected`` set.
        sensor_energy_per_step: Per-step cost subtracted from the sensor utility.
        step_debug: Print one line per step.

    Returns:
        ``EpisodeStats`` with the episode outcome and both utilities.
    """
    env.seed(seed)
    sensor.reset()
    initial_mode = sensor.select_mode(0)
    env.reset(sensor_mode=initial_mode)

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    risk_sum = 0.0
    prev_alarm: Optional[int] = None

    for _ in range(max_steps):
        # The sensor may switch mode before the robot moves.
        env.sensor_mode = sensor.select_mode(env.t)
        action = robot.act(env, belief, prev_alarm)
        result = env.step(action)

        risk_sum += float(result["p_true"])
        belief.update(env, result["alarm"], result["pos"])
        prev_alarm = int(result["alarm"])

        if step_debug:
            print(
                f"[Step] t={result['t']:3d} pos={result['pos']} a={action} p_true={result['p_true']:.3f} "
                f"alarm={result['alarm']} det={result['detected']} done={result['done']} b={belief.b.round(3)}"
            )

        if result["done"]:
            break

    at_goal = (env.pos == env.grid.goal)
    was_detected = bool(env.detected)
    n_steps = int(env.t)

    # Zero-sum-like terms shared by both utilities.
    penalty = det_penalty if was_detected else 0.0
    weighted_risk = lambda_risk * risk_sum

    robot_utility = -float(n_steps + weighted_risk + penalty)
    sensor_utility = float(penalty + weighted_risk - sensor_energy_per_step * n_steps)

    return EpisodeStats(
        steps=n_steps,
        reached_goal=at_goal,
        detected=was_detected,
        total_true_risk=float(risk_sum),
        U_R=robot_utility,
        U_S=sensor_utility,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Estimate the empirical payoff matrices by Monte Carlo rollouts.

    Every (robot policy ``i``, sensor policy ``j``) pair is played ``rollouts``
    times; rollout ``k`` uses the seed ``base_seed + 100000*i + 1000*j + k``.

    Args:
        rollouts: Number of episodes per pair; must be at least 1.
        debug_rollout_pair: If equal to ``(i, j)``, the first rollout of that
            pair is printed step by step.

    Returns:
        ``(U_R, U_S, diag)`` where ``U_R[i, j]`` / ``U_S[i, j]`` are the mean
        utilities and ``diag[(i, j)]`` holds detection rate, goal rate, mean
        steps, mean risk and the standard deviations of both utilities.

    Raises:
        ValueError: If ``rollouts < 1`` (the averages would be undefined).
    """
    if rollouts <= 0:
        raise ValueError("rollouts must be >= 1")

    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            trace_pair = (debug_rollout_pair == (i, j))

            r_list: List[float] = []
            s_list: List[float] = []
            det = 0
            goal = 0
            steps_list: List[int] = []
            risk_list: List[float] = []

            for k in range(rollouts):
                seed = base_seed + 100000 * i + 1000 * j + k
                # Only the first rollout of the requested pair is traced.
                step_debug = trace_pair and k == 0
                st = rollout_episode(env, rpol, spol, M_modes=M_modes, seed=seed, step_debug=step_debug)
                r_list.append(st.U_R)
                s_list.append(st.U_S)
                det += int(st.detected)
                goal += int(st.reached_goal)
                steps_list.append(st.steps)
                risk_list.append(st.total_true_risk)

            U_R[i, j] = float(np.mean(r_list))
            U_S[i, j] = float(np.mean(s_list))

            diag[(i, j)] = {
                "det_rate": det / rollouts,
                "goal_rate": goal / rollouts,
                "mean_steps": float(np.mean(steps_list)),
                "mean_risk": float(np.mean(risk_list)),
                "std_UR": float(np.std(r_list)),
                "std_US": float(np.std(s_list)),
            }

    if log.k >= 1:
        log.info("[Eval] Compact payoff summary:")
        for i, rpol in enumerate(robots):
            for j, spol in enumerate(sensors):
                d = diag[(i, j)]
                log.info(
                    f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                    f"UR={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                    f"US={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                    f"det%={100*d['det_rate']:5.1f} goal%={100*d['goal_rate']:5.1f} "
                    f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
                )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Euclidean projection of ``v`` onto the simplex ``{w >= 0, sum(w) = z}``.

    Uses the sort-and-threshold scheme: find the shift ``theta`` such that
    ``max(v - theta, 0)`` sums to ``z``. The result is rescaled once more so it
    sums to ``z``, and the uniform point is returned whenever no valid
    threshold exists or the clipped vector has no finite positive mass (for
    instance when ``v`` contains NaN).

    Raises:
        ValueError: If ``v`` is empty or ``z <= 0``.
    """
    vec = np.asarray(v, dtype=float).reshape(-1)
    n = vec.size
    if n == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    def uniform_point() -> np.ndarray:
        return np.full_like(vec, z / n)

    descending = np.sort(vec)[::-1]
    running = np.cumsum(descending)
    ranks = np.arange(1, n + 1)
    active = np.flatnonzero(descending * ranks > (running - z))
    if active.size == 0:
        return uniform_point()

    last = int(active[-1])
    theta = (running[last] - z) / (last + 1.0)
    clipped = np.maximum(vec - theta, 0.0)
    mass = float(clipped.sum())
    if not np.isfinite(mass) or mass <= 0:
        return uniform_point()
    return clipped * (z / mass)


@dataclass
class NBSResult:
    """Result returned by ``solve_nbs``."""

    x: np.ndarray  # final joint distribution over the flattened joint actions
    obj: float  # objective value at x (log-gains plus the optional entropy term)
    gains: Tuple[float, float]  # (robot gain, sensor gain) = expected utility minus disagreement point
    support: int  # number of entries of x larger than 1e-6


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Nash bargaining over joint actions by projected gradient ascent.

    Maximises ``log(uR.x - dR) + log(uS.x - dS) + entropy_tau * H(x)`` over the
    probability simplex, starting from the uniform distribution. Both gains are
    clamped from below by ``kappa`` inside the objective and its gradient.

    Args:
        uR, uS: Utilities of the joint actions (flattened); same length, >= 2.
        log: Logger providing ``info``, ``debug`` and a verbosity level ``k``.
        max_iters: Maximum number of outer iterations.
        alpha: Initial step size of each backtracking line search.
        tol_l1: Stop when an accepted step moves ``x`` by less than this in L1.
        kappa: Lower clamp applied to both gains.
        disagreement: ``"minminus"`` (each player's minimum utility minus 1) or
            ``"uniform"`` (each player's utility under the uniform distribution).
        entropy_tau: Weight of the entropy bonus.

    Raises:
        ValueError: On mismatched shapes, fewer than two joint actions, or an
            unknown ``disagreement`` keyword.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    uniform = np.full(d, 1.0 / d)

    rule = disagreement.lower().strip()
    if rule == "minminus":
        dR, dS = float(np.min(uR) - 1.0), float(np.min(uS) - 1.0)
    elif rule == "uniform":
        dR, dS = float(uR @ uniform), float(uS @ uniform)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    def gains(p: np.ndarray) -> Tuple[float, float]:
        return float(uR @ p - dR), float(uS @ p - dS)

    def clamped_gains(p: np.ndarray) -> Tuple[float, float]:
        raw_R, raw_S = gains(p)
        return max(raw_R, kappa), max(raw_S, kappa)

    def objective(p: np.ndarray) -> float:
        cR, cS = clamped_gains(p)
        safe = np.clip(p, 1e-12, 1.0)
        ent = float(-np.sum(safe * np.log(safe)))
        return float(np.log(cR) + np.log(cS) + entropy_tau * ent)

    def gradient(p: np.ndarray) -> np.ndarray:
        cR, cS = clamped_gains(p)
        out = (uR / cR) + (uS / cS)
        if entropy_tau > 0:
            safe = np.clip(p, 1e-12, 1.0)
            out += entropy_tau * (-(np.log(safe) + 1.0))
        return out

    x = uniform.copy()
    last = objective(x)
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        direction = gradient(x)

        # Backtracking: halve the step until the objective does not decrease.
        step = alpha
        accepted = None
        for _ in range(30):
            trial = project_simplex(x + step * direction)
            trial_obj = objective(trial)
            if trial_obj >= last - 1e-12:
                accepted = (trial, trial_obj)
                break
            step *= 0.5
            if step < 1e-6:
                break
        if accepted is None:
            # No acceptable step: treat the current point as converged.
            break

        x_next, obj_next = accepted
        delta = float(np.linalg.norm(x_next - x, ord=1))
        x, last = x_next, obj_next

        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            gR, gS = gains(x)
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flattened joint distribution into an ``(m, n)`` matrix that sums to one.

    Rows index robot policies and columns index sensor policies (row-major
    flattening). Input that already sums to one within ``1e-6`` is returned as
    is; otherwise it is renormalised. If the total mass is non-finite or not
    positive there is nothing meaningful to renormalise, so the uniform joint
    distribution is returned instead of a NaN or unnormalised matrix.

    Raises:
        ValueError: If ``x`` does not have exactly ``m * n`` entries.
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size != m * n:
        raise ValueError("x size mismatch")
    X = x.reshape((m, n))
    if X.size == 0:
        return X

    s = float(X.sum())
    if not np.isfinite(s) or s <= 0.0:
        return np.full((m, n), 1.0 / X.size)
    if abs(s - 1.0) > 1e-6:
        # Renormalize defensively
        X = X / s
    return X


def marginals_from_joint(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the row (robot) and column (sensor) marginals of the joint matrix ``X``.

    Each marginal is normalised to sum to one when its total is positive and
    returned unnormalised otherwise.
    """

    def normalised(v: np.ndarray) -> np.ndarray:
        total = v.sum()
        return v / total if total > 0 else v

    return normalised(X.sum(axis=1)), normalised(X.sum(axis=0))


def entropy_of_joint(X: np.ndarray) -> float:
    """Shannon entropy (in nats) of the joint distribution ``X``.

    Entries are clipped to ``[1e-12, 1]`` before taking the logarithm, so zero
    probabilities contribute a negligible amount instead of NaN.
    """
    probs = np.clip(X.reshape(-1), 1e-12, 1.0)
    surprisal_weighted = probs * np.log(probs)
    return float(-np.sum(surprisal_weighted))


# =============================================================================
# Best responses (marginal vs correlated)
# =============================================================================

def compute_expected_risk_map_from_policy_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi: np.ndarray,
    M_modes: int,
) -> np.ndarray:
    """Expected true detection probability per cell under a mixture of sensor POLICIES.

    Each sensor policy is represented by the mode it selects at ``t=0`` (exact
    for ``FixedModeSensorPolicy``). The mixture weights are accumulated per
    mode and normalised when their total is positive.

    Returns:
        Array of shape ``(height, width)`` indexed as ``risk[y, x]``; obstacle
        cells hold NaN.

    Raises:
        ValueError: If ``pi`` and ``sensors`` differ in length, or a policy
            reports a mode outside ``0 .. M_modes-1``.
    """
    weights = np.asarray(pi, dtype=float).reshape(-1)
    if weights.size != len(sensors):
        raise ValueError("mixture length mismatch")

    # Collapse the mixture over policies into a distribution over modes.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(weights, sensors):
        mode = int(policy.select_mode(0))
        if mode < 0 or mode >= M_modes:
            raise ValueError("invalid sensor mode")
        mode_probs[mode] += float(weight)
    total = mode_probs.sum()
    if total > 0:
        mode_probs = mode_probs / total

    n_rows, n_cols = env.grid.height, env.grid.width
    risk = np.zeros((n_rows, n_cols), dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = (col, row)
            if cell in env.grid.obstacles:
                risk[row, col] = np.nan
                continue
            expected = 0.0
            for mode in range(M_modes):
                expected += float(mode_probs[mode]) * float(env.true_detection_prob(cell, mode))
            risk[row, col] = float(expected)

    return risk


def plan_risk_weighted_path(env: GridWorldStealthEnv, risk_map: np.ndarray, risk_weight: float) -> List[Tuple[int, int]]:
    """A* path from the grid's start to its goal under a risk-weighted step cost.

    Entering a cell costs ``1 + risk_weight * risk_map[y, x]``; cells whose
    risk is not finite (NaN marks obstacles in the risk map) cost ``1e9``.
    """

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        col, row = to
        cell_risk = risk_map[row, col]
        if np.isfinite(cell_risk):
            return 1.0 + float(risk_weight) * float(cell_risk)
        return 1e9

    grid = env.grid
    return astar_path(grid, grid.start, grid.goal, step_cost)


def robot_best_response_to_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi_S: np.ndarray,
    robots: List[RobotPolicy],
    M_modes: int,
    risk_weight: float,
    log: Logger,
    tag: str,
) -> RobotPolicy:
    """Robot best response to a sensor mixture: the risk-weighted A* path.

    The expected risk map induced by ``pi_S`` is planned on from start to goal.
    If an existing ``FixedPathPolicy`` in ``robots`` already follows exactly
    that path it is returned; otherwise a new ``FixedPathPolicy`` is created
    (it is not appended to ``robots`` here). If planning fails, a warning is
    logged and ``robots[0]`` is returned.
    """
    risk_map = compute_expected_risk_map_from_policy_mixture(env, sensors, pi_S, M_modes=M_modes)
    try:
        path = plan_risk_weighted_path(env, risk_map, risk_weight=risk_weight)
    except Exception as e:
        log.info(f"[RobotBR] WARNING A* failed ({tag}): {e}")
        return robots[0]

    wanted = tuple(path)
    known = next(
        (p for p in robots if isinstance(p, FixedPathPolicy) and tuple(p.path) == wanted),
        None,
    )
    if known is not None:
        log.info(f"[RobotBR] ({tag}) BR path already exists: {known.name}")
        return known

    fresh = FixedPathPolicy(path, name=f"R_BR_{tag}_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] ({tag}) Added new robot policy: {fresh.name}")
    return fresh


def sensor_best_response_to_mixture(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    pi_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    tag: str,
) -> FixedModeSensorPolicy:
    """Best fixed sensor mode against a robot mixture, using common random numbers (CRN).

    Payoff variance is large because detection is a rare/threshold event. If
    each candidate mode were scored on different random rollouts, noise alone
    could pick the wrong mode. Therefore the sampled robot indices AND the
    episode seeds (``base_seed + k``) are shared by all candidate modes.

    Candidate modes outside ``0 .. M_modes-1`` are skipped; ties keep the
    earliest candidate.

    Raises:
        ValueError: If ``pi_R`` and ``robots`` differ in length.
        RuntimeError: If no candidate produced a usable mean payoff.
    """
    pi_R = np.asarray(pi_R, dtype=float).reshape(-1)
    if pi_R.size != len(robots):
        raise ValueError("pi_R length mismatch")

    rng = np.random.default_rng(int(base_seed))

    # Drawn once so every mode faces the same sequence of robot policies ...
    robot_idxs = rng.choice(len(robots), size=rollouts, p=pi_R, replace=True)
    # ... and the same episode seeds.
    seeds = (int(base_seed) + np.arange(rollouts)).astype(int)

    winner: Optional[int] = None
    winner_value = -1e18

    for mode in candidate_modes:
        if not (0 <= mode < M_modes):
            continue
        candidate = FixedModeSensorPolicy(mode, name=f"S_BR_{tag}_Mode{mode}")

        payoffs: List[float] = [
            rollout_episode(env, robots[int(ridx)], candidate, M_modes=M_modes, seed=int(sd)).U_S
            for ridx, sd in zip(robot_idxs, seeds)
        ]

        mean_u = float(np.mean(payoffs))
        if log.k >= 2:
            log.debug(f"[SensorBR] ({tag}) mode={mode} E[US]={mean_u:.3f} std={float(np.std(payoffs)):.2f}")

        if mean_u > winner_value:
            winner_value, winner = mean_u, mode

    if winner is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] ({tag}) Best mode={winner} E[US]={winner_value:.3f}")
    return FixedModeSensorPolicy(winner, name=f"S_BR_{tag}_Mode{winner}")


def conditional_sensor_given_robot(X: np.ndarray, i: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over sensor policies conditional on robot recommendation ``i``.

    This is row ``i`` of ``X`` divided by its sum; a row whose mass is at most
    ``eps`` yields the uniform distribution instead.
    """
    weights = np.asarray(X[i, :], dtype=float)
    mass = float(weights.sum())
    degenerate = mass <= eps
    return np.full_like(weights, 1.0 / weights.size) if degenerate else weights / mass


def conditional_robot_given_sensor(X: np.ndarray, j: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over robot policies conditional on sensor recommendation ``j``.

    This is column ``j`` of ``X`` divided by its sum; a column whose mass is at
    most ``eps`` yields the uniform distribution instead.
    """
    weights = np.asarray(X[:, j], dtype=float)
    mass = float(weights.sum())
    degenerate = mass <= eps
    return np.full_like(weights, 1.0 / weights.size) if degenerate else weights / mass


def compute_ce_regrets(U_R: np.ndarray, U_S: np.ndarray, X: np.ndarray, eps: float = 1e-12) -> Dict[str, float]:
    """Conditional recommendation regrets (CE-style) on the current meta-game.

    Robot, given recommendation ``i`` and ``q = q(. | i)`` over sensor policies:
        regret_R(i) = max_{i'} E_{j~q}[U_R(i', j) - U_R(i, j)]

    Sensor, given recommendation ``j`` and ``q = q(. | j)`` over robot policies:
        regret_S(j) = max_{j'} E_{i~q}[U_S(i, j') - U_S(i, j)]

    Recommendations whose marginal probability is at most ``eps`` are skipped.

    Returns:
        Dict with ``max_regret_R``, ``max_regret_S``, ``mean_regret_R`` and
        ``mean_regret_S`` (each 0.0 when no recommendation qualified).
    """
    m, n = U_R.shape
    assert U_S.shape == (m, n)
    assert X.shape == (m, n)

    sigma_R, sigma_S = marginals_from_joint(X)

    robot_regrets = []
    for i in range(m):
        if sigma_R[i] <= eps:
            continue
        q = conditional_sensor_given_robot(X, i, eps=eps)
        followed = float(np.dot(q, U_R[i, :]))
        # Value of every alternative row under the same conditional belief;
        # starting the max at ``followed`` keeps the regret nonnegative.
        alternatives = [float(np.dot(q, U_R[alt, :])) for alt in range(m)]
        robot_regrets.append(max([followed] + alternatives) - followed)

    sensor_regrets = []
    for j in range(n):
        if sigma_S[j] <= eps:
            continue
        q = conditional_robot_given_sensor(X, j, eps=eps)
        followed = float(np.dot(q, U_S[:, j]))
        alternatives = [float(np.dot(q, U_S[:, alt])) for alt in range(n)]
        sensor_regrets.append(max([followed] + alternatives) - followed)

    return {
        "max_regret_R": float(max(robot_regrets) if robot_regrets else 0.0),
        "max_regret_S": float(max(sensor_regrets) if sensor_regrets else 0.0),
        "mean_regret_R": float(np.mean(robot_regrets) if robot_regrets else 0.0),
        "mean_regret_S": float(np.mean(sensor_regrets) if sensor_regrets else 0.0),
    }


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first ``FixedModeSensorPolicy`` in ``pols`` whose mode equals ``mode``, else ``None``."""
    matches = (pol for pol in pols if isinstance(pol, FixedModeSensorPolicy) and pol.mode == mode)
    return next(matches, None)


# =============================================================================
# Policy initialization + evaluator for joint strategy
# =============================================================================

def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the initial robot and sensor policy sets of the empirical game.

    Robots: the unit-cost shortest path, two unit-cost paths forced through the
    waypoints ``(width // 2, 2)`` and ``(width // 2, 6)`` (presumably the two
    gaps in the wall; confirm against the grid layout), a random walker and the
    online belief replanner. Sensors: one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        return 1.0

    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)

    wall_x = grid.width // 2
    via_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = [
        FixedPathPolicy(shortest, "R_Shortest"),
        FixedPathPolicy(via_upper, "R_UpperCorridor"),
        FixedPathPolicy(via_lower, "R_LowerCorridor"),
        RandomPolicy("R_Random"),
        OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"),
    ]
    sensors: List[SensorPolicy] = [
        FixedModeSensorPolicy(mode, name=f"S_Mode{mode}") for mode in range(M_modes)
    ]
    return robots, sensors


@dataclass
class StrategyEval:
    """Monte Carlo averages returned by ``evaluate_joint_strategy``."""

    mean_U_R: float  # mean robot utility over the evaluated episodes
    mean_U_S: float  # mean sensor utility over the evaluated episodes
    det_rate: float  # fraction of episodes that ended detected
    goal_rate: float  # fraction of episodes that ended on the goal cell
    mean_steps: float  # mean number of steps per episode
    mean_risk: float  # mean accumulated true risk per episode


def evaluate_joint_strategy(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    X: np.ndarray,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Self-play evaluation of the joint distribution ``X`` by Monte Carlo.

    For each episode a joint action ``(i, j)`` is sampled from ``X`` (rows are
    robot policies, columns sensor policies) with a generator seeded by
    ``base_seed``, and episode ``k`` is rolled out with seed ``base_seed + k``.

    Args:
        episodes: Number of episodes to average over; must be at least 1.

    Raises:
        ValueError: If ``episodes < 1`` (the averages would be undefined).
    """
    if episodes <= 0:
        raise ValueError("episodes must be >= 1")

    m, n = X.shape
    probs = X.reshape(-1)
    # Guard against tiny normalisation drift before sampling.
    probs = probs / max(float(probs.sum()), 1e-12)

    rng = np.random.default_rng(int(base_seed))

    UR: List[float] = []
    US: List[float] = []
    det = 0
    goal = 0
    steps_list: List[int] = []
    risk_list: List[float] = []

    for k in range(episodes):
        idx = int(rng.choice(m * n, p=probs))
        # Row-major flat index -> (robot index, sensor index).
        i, j = divmod(idx, n)
        st = rollout_episode(env, robots[i], sensors[j], M_modes=M_modes, seed=base_seed + k)
        UR.append(st.U_R)
        US.append(st.U_S)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps_list.append(st.steps)
        risk_list.append(st.total_true_risk)

    return StrategyEval(
        mean_U_R=float(np.mean(UR)),
        mean_U_S=float(np.mean(US)),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean(steps_list)),
        mean_risk=float(np.mean(risk_list)),
    )


# =============================================================================
# Training loop (marginal vs correlated)
# =============================================================================

@dataclass
class TrainHistoryRow:
    """Diagnostics for one outer training iteration.

    NOTE(review): the field descriptions below are inferred from the field
    names and from the quantities ``run_training`` computes per iteration;
    confirm against the code that constructs the rows.
    """

    outer_iter: int  # presumably the 1-based outer iteration number
    m: int  # presumably the number of robot policies in the meta-game
    n: int  # presumably the number of sensor policies in the meta-game
    nbs_obj: float  # presumably NBSResult.obj for this iteration
    entropy_X: float  # presumably entropy_of_joint(X)
    max_regret_R: float  # presumably compute_ce_regrets(...)["max_regret_R"]
    max_regret_S: float  # presumably compute_ce_regrets(...)["max_regret_S"]
    selfplay_UR: float  # presumably StrategyEval.mean_U_R from self-play
    selfplay_US: float  # presumably StrategyEval.mean_U_S from self-play
    selfplay_det: float  # presumably StrategyEval.det_rate from self-play
    selfplay_goal: float  # presumably StrategyEval.goal_rate from self-play
    seconds: float  # presumably wall-clock time spent in the iteration


@dataclass
class TrainResult:
    """Bundle returned by ``run_training``.

    NOTE(review): field descriptions are inferred from names and from how
    ``run_training`` uses the corresponding variables; confirm at the point
    where the result is constructed.
    """

    solver: str  # solver variant; run_training branches on "marginal" and "correlated"
    env: GridWorldStealthEnv  # environment the policies were trained in
    robots: List[RobotPolicy]  # presumably the final robot policy set
    sensors: List[SensorPolicy]  # presumably the final sensor policy set
    X: np.ndarray  # presumably the last joint distribution over (robot, sensor) pairs
    history: List[TrainHistoryRow]  # one row per outer iteration


def run_training(env: GridWorldStealthEnv, args: argparse.Namespace, solver: str, log: Logger) -> TrainResult:
    t0_all = time.time()

    grid = env.grid
    sensor_cfg = env.sensor_cfg
    M_modes = len(sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    # Optional: reduce initial set if you want smaller games.
    # (We keep it as-is for benchmarks.)

    if log.k >= 1:
        log.info(f"[{solver}] Initial robots: " + ", ".join([p.name for p in robots]))
        log.info(f"[{solver}] Initial sensors: " + ", ".join([p.name for p in sensors]))

    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        if len(parts) == 2:
            debug_pair = (int(parts[0]), int(parts[1]))
            log.info(f"[{solver}] Will print one step-by-step rollout for pair {debug_pair} (only once).")

    history: List[TrainHistoryRow] = []
    X = None

    for it in range(1, args.outer_iters + 1):
        t0 = time.time()
        log.banner(f"[{solver}] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, _diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        debug_pair = None

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )

        m, n = U_R.shape
        X = joint_to_matrix(nbs.x, m, n)
        sigma_R, sigma_S = marginals_from_joint(X)

        # Print top joint actions
        top = np.argsort(-X.reshape(-1))[:min(5, X.size)]
        log.info(f"[{solver}] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[{solver}] sigma_R={sigma_R.round(3)}")
        log.info(f"[{solver}] sigma_S={sigma_S.round(3)}")

        # Stability diagnostics
        regrets = compute_ce_regrets(U_R, U_S, X)
        ent = entropy_of_joint(X)

        # Self-play evaluation under joint X
        sp = evaluate_joint_strategy(
            env,
            robots,
            sensors,
            X,
            M_modes=M_modes,
            episodes=args.eval_episodes,
            base_seed=9000 + 100 * it,
        )

        log.info(
            f"[{solver}] CE-regrets: maxR={regrets['max_regret_R']:.3f} maxS={regrets['max_regret_S']:.3f} | "
            f"SelfPlay: UR={sp.mean_U_R:.2f} US={sp.mean_U_S:.2f} det%={100*sp.det_rate:.1f} goal%={100*sp.goal_rate:.1f} | "
            f"H(X)={ent:.3f}"
        )

        # Best-response expansion
        if solver == "marginal":
            br_r = robot_best_response_to_mixture(
                env,
                sensors,
                pi_S=sigma_S,
                robots=robots,
                M_modes=M_modes,
                risk_weight=args.risk_weight_br,
                log=log,
                tag="Marginal",
            )

            br_s = sensor_best_response_to_mixture(
                env,
                robots,
                pi_R=sigma_R,
                candidate_modes=list(range(M_modes)),
                M_modes=M_modes,
                rollouts=args.rollouts_br,
                base_seed=2000 + 100 * it,
                log=log,
                tag="Marginal",
            )

        elif solver == "correlated":
            # FIX: choose conditional mixtures q(.|i) and q(.|j)
            # We only check top-K recommendations to keep it tractable.
            topK = max(1, int(args.cond_top_k))

            # Robot: check the top-K robot recommendations by sigma_R
            cand_i = [int(i) for i in np.argsort(-sigma_R) if sigma_R[int(i)] > 1e-8][:topK]
            if not cand_i:
                cand_i = [int(i) for i in np.argsort(-sigma_R)[:topK]]
            best_gain = 0.0
            br_r = robots[0]

            for i in cand_i:
                qS = conditional_sensor_given_robot(X, int(i))
                tag = f"Cond_i{i}"
                pol = robot_best_response_to_mixture(
                    env,
                    sensors,
                    pi_S=qS,
                    robots=robots,
                    M_modes=M_modes,
                    risk_weight=args.risk_weight_br,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement: E_q[U_R(pol,j)] - E_q[U_R(i,j)]
                # If pol is already in set, its payoff exists in U_R row of that policy.
                # Otherwise we simulate pol against each sensor policy j.
                if pol in robots:
                    ip = robots.index(pol)
                    dev = float(np.dot(qS, U_R[ip, :]))
                else:
                    # simulate quickly vs each sensor policy
                    dev_vals = []
                    for j in range(n):
                        vals = []
                        for kk in range(args.br_eval_rollouts):
                            seed = 777000 + 1000 * it + 100 * i + 10 * j + kk
                            st = rollout_episode(env, pol, sensors[j], M_modes=M_modes, seed=seed)
                            vals.append(st.U_R)
                        dev_vals.append(float(np.mean(vals)))
                    dev = float(np.dot(qS, np.asarray(dev_vals)))

                rec = float(np.dot(qS, U_R[i, :]))
                gain = dev - rec
                if gain > best_gain + 1e-9:
                    best_gain = gain
                    br_r = pol

            if best_gain > args.add_threshold:
                if br_r in robots:
                    log.info(f"[{solver}] Best robot deviation already in set; est_gain={best_gain:.3f}")
                else:
                    log.info(f"[{solver}] Adding robot deviation; est_gain={best_gain:.3f}")
            else:
                log.info(f"[{solver}] No robot deviation above threshold (best_gain={best_gain:.3f}).")

            # Sensor: check top-K sensor recommendations by sigma_S
            cand_j = [int(j) for j in np.argsort(-sigma_S) if sigma_S[int(j)] > 1e-8][:topK]
            if not cand_j:
                cand_j = [int(j) for j in np.argsort(-sigma_S)[:topK]]
            best_gain_s = 0.0
            br_s = FixedModeSensorPolicy(0, name="S_dummy")

            for j in cand_j:
                qR = conditional_robot_given_sensor(X, int(j))
                tag = f"Cond_j{j}"
                polS = sensor_best_response_to_mixture(
                    env,
                    robots,
                    pi_R=qR,
                    candidate_modes=list(range(M_modes)),
                    M_modes=M_modes,
                    rollouts=args.rollouts_br,
                    base_seed=333000 + 1000 * it + 10 * j,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement for sensor
                # rec under recommendation j is E_q[U_S(i,j)]
                recS = float(np.dot(qR, U_S[:, j]))

                # dev under mode polS.mode: if already present, use its column.
                existing_col = None
                for jj, spj in enumerate(sensors):
                    if isinstance(spj, FixedModeSensorPolicy) and spj.mode == polS.mode:
                        existing_col = jj
                        break

                if existing_col is not None:
                    devS = float(np.dot(qR, U_S[:, existing_col]))
                else:
                    vals = []
                    for kk in range(args.br_eval_rollouts):
                        i_samp = int(np.random.default_rng(444 + kk).choice(len(robots), p=qR))
                        seed = 888000 + 1000 * it + 10 * j + kk
                        st = rollout_episode(env, robots[i_samp], polS, M_modes=M_modes, seed=seed)
                        vals.append(st.U_S)
                    devS = float(np.mean(vals))

                gainS = devS - recS
                if gainS > best_gain_s + 1e-9:
                    best_gain_s = gainS
                    br_s = polS

            if best_gain_s > args.add_threshold:
                if (isinstance(br_s, FixedModeSensorPolicy)) and (find_sensor_by_mode(sensors, br_s.mode) is not None):
                    log.info(f"[{solver}] Best sensor deviation already in set (mode={br_s.mode}); est_gain={best_gain_s:.3f}")
                    br_s = None
                else:
                    log.info(f"[{solver}] Adding sensor deviation; est_gain={best_gain_s:.3f}")
            else:
                log.info(f"[{solver}] No sensor deviation above threshold (best_gain={best_gain_s:.3f}).")
                br_s = None

        else:
            raise ValueError("solver must be marginal or correlated")

        # Add to sets (dedupe)
        if br_r not in robots:
            robots.append(br_r)

        if isinstance(br_s, FixedModeSensorPolicy):
            if find_sensor_by_mode(sensors, br_s.mode) is None:
                sensors.append(br_s)
            else:
                log.info(f"[{solver}] Sensor mode {br_s.mode} already present; not adding duplicate.")

        seconds = float(time.time() - t0)
        history.append(
            TrainHistoryRow(
                outer_iter=it,
                m=len(robots),
                n=len(sensors),
                nbs_obj=float(nbs.obj),
                entropy_X=float(ent),
                max_regret_R=float(regrets["max_regret_R"]),
                max_regret_S=float(regrets["max_regret_S"]),
                selfplay_UR=float(sp.mean_U_R),
                selfplay_US=float(sp.mean_U_S),
                selfplay_det=float(sp.det_rate),
                selfplay_goal=float(sp.goal_rate),
                seconds=seconds,
            )
        )

        log.info(f"[{solver}] Sets: |Pi_R|={len(robots)} |Pi_S|={len(sensors)} | iter_seconds={seconds:.2f}")

    if X is None:
        raise RuntimeError("Training produced no X")

    log.banner(f"[{solver}] Finished")
    log.info(f"[{solver}] Final robots: " + ", ".join([p.name for p in robots]))
    log.info(f"[{solver}] Final sensors: " + ", ".join([p.name for p in sensors]))
    log.info(f"[{solver}] Total time: {time.time()-t0_all:.2f}s")

    return TrainResult(solver=solver, env=env, robots=robots, sensors=sensors, X=X, history=history)


# =============================================================================
# Benchmarks + plotting
# =============================================================================

@dataclass
class BenchCell:
    """Aggregated benchmark statistics for one (robot policy, sensor policy) pair.

    Instances are produced by run_policy_matrix_benchmark, one per matrix cell.
    """

    UR: float  # mean of the per-episode U_R values
    US: float  # mean of the per-episode U_S values
    det: float  # fraction of episodes whose stats reported `detected`
    goal: float  # fraction of episodes whose stats reported `reached_goal`


def run_policy_matrix_benchmark(
    env: GridWorldStealthEnv,
    robot_methods: List[RobotPolicy],
    sensor_methods: List[SensorPolicy],
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> Dict[Tuple[int, int], BenchCell]:
    """Roll out every robot policy against every sensor policy and summarize.

    Args:
        env: environment handed to rollout_episode for every episode.
        robot_methods: robot policies; index i is the first key component.
        sensor_methods: sensor policies; index j is the second key component.
        M_modes: forwarded to rollout_episode.
        episodes: number of rollouts per (i, j) pair.
        base_seed: episode k of pair (i, j) uses seed
            base_seed + 100000 * i + 1000 * j + k.

    Returns:
        Mapping (i, j) -> BenchCell with mean utilities and detection/goal rates.
    """
    table: Dict[Tuple[int, int], BenchCell] = {}
    for i, rp in enumerate(robot_methods):
        for j, sp in enumerate(sensor_methods):
            first_seed = base_seed + 100000 * i + 1000 * j
            # One (U_R, U_S, detected, reached_goal) tuple per episode.
            samples = []
            for k in range(episodes):
                st = rollout_episode(env, rp, sp, M_modes=M_modes, seed=first_seed + k)
                samples.append((st.U_R, st.U_S, int(st.detected), int(st.reached_goal)))
            table[(i, j)] = BenchCell(
                UR=float(np.mean([s[0] for s in samples])),
                US=float(np.mean([s[1] for s in samples])),
                det=float(sum(s[2] for s in samples) / episodes),
                goal=float(sum(s[3] for s in samples) / episodes),
            )
    return table


def safe_makedirs(path: str) -> None:
    """Create directory `path` (including parents) if it does not exist yet.

    An empty `path` is treated as "the current directory" and is a no-op:
    os.makedirs("") raises FileNotFoundError, while the callers in this file
    build output names with os.path.join(path, ...), which already resolves an
    empty directory to the current working directory.

    Args:
        path: directory to create; may already exist.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def save_history_csv(hist: List[TrainHistoryRow], path: str) -> None:
    """Write the training history to `path` as CSV.

    The header row consists of the annotated field names of TrainHistoryRow
    (in declaration order); each history entry becomes one data row.
    """
    import csv

    columns = [name for name in TrainHistoryRow.__annotations__]
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({col: getattr(entry, col) for col in columns} for entry in hist)


def plot_training_curves(results: List[TrainResult], outdir: str) -> None:
    """Save four PNG training-curve figures (one line per solver) into `outdir`.

    Figures written: curve_nbs_obj.png, curve_max_regrets.png,
    curve_entropy_X.png and curve_selfplay_utils.png. Each figure plots a
    TrainHistoryRow field against the outer iteration for every result.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    def _save_curve(series, ylabel, filename):
        # `series` is a sequence of (history field, label template, plot kwargs);
        # every entry is drawn for all results before the next entry starts, so
        # the line order (and hence the color cycle) matches a hand-written loop.
        plt.figure()
        for field, label_template, line_kwargs in series:
            for res in results:
                iters = [row.outer_iter for row in res.history]
                values = [getattr(row, field) for row in res.history]
                plt.plot(iters, values, label=label_template.format(res.solver), **line_kwargs)
        plt.xlabel("Outer iteration")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    solid = {"marker": "o"}
    dashed = {"marker": "x", "linestyle": "--"}

    # 1) NBS objective
    _save_curve([("nbs_obj", "{}", solid)], "NBS objective", "curve_nbs_obj.png")

    # 2) Max conditional regrets
    _save_curve(
        [
            ("max_regret_R", "{}: maxRegret_R", solid),
            ("max_regret_S", "{}: maxRegret_S", dashed),
        ],
        "Max conditional regret",
        "curve_max_regrets.png",
    )

    # 3) Entropy of X
    _save_curve([("entropy_X", "{}", solid)], "Entropy H(X)", "curve_entropy_X.png")

    # 4) Self-play outcomes
    _save_curve(
        [
            ("selfplay_UR", "{}: UR", solid),
            ("selfplay_US", "{}: US", dashed),
        ],
        "Expected utility under X",
        "curve_selfplay_utils.png",
    )


def plot_benchmark_bars(
    robot_names: List[str],
    sensor_names: List[str],
    bench: Dict[Tuple[int, int], BenchCell],
    outdir: str,
) -> None:
    """Save per-sensor bar charts of robot utility and robot goal rate.

    For every sensor name two PNG files are written into `outdir`:
    bench_UR_vs_<sensor>.png and bench_goal_vs_<sensor>.png, each with one bar
    per robot. `bench` is indexed by (robot index, sensor index).
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    positions = range(len(robot_names))
    # (BenchCell field, y-axis label, title template, file-name template).
    # All utility charts are produced first, then all goal-rate charts.
    charts = (
        ("UR", "Robot utility", "Robot utility vs sensor={}", "bench_UR_vs_{}.png"),
        ("goal", "Goal rate", "Robot goal rate vs sensor={}", "bench_goal_vs_{}.png"),
    )
    for field, ylabel, title_template, file_template in charts:
        for j, sname in enumerate(sensor_names):
            plt.figure(figsize=(10, 4))
            heights = [getattr(bench[(i, j)], field) for i in positions]
            plt.bar(positions, heights)
            plt.xticks(positions, robot_names, rotation=35, ha="right")
            plt.ylabel(ylabel)
            plt.title(title_template.format(sname))
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, file_template.format(sname)), dpi=200)
            plt.close()


# =============================================================================
# Main pipeline wrapper
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Run training for the selected solver(s), then optional outputs/benchmarks.

    Steps:
      1. Build the two-corridor grid, the sensor configuration and an env.
      2. Train once per solver ("marginal", "correlated", or both), each on its
         own freshly seeded environment.
      3. If `args.results_dir` is set, write one history CSV per solver and,
         with `args.save_plots`, the training-curve figures.
      4. If `args.run_benchmarks` is set, evaluate the heuristic robot policies
         against a set of sensor heuristics, log a table and optionally save
         bar charts.

    Args:
        args: parsed CLI namespace (see parse_args).
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    sensor_cfg = SensorConfig(
        sensors=((grid.width // 2, 2), (grid.width // 2, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    if log.k >= 1:
        log.banner("[PIPELINE] Stealth grid game")
        log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
        if log.k >= 2:
            log.debug("ASCII map:")
            print_grid_ascii(grid, sensor_cfg)

    solvers: List[str]
    if args.solver == "both":
        solvers = ["marginal", "correlated"]
    else:
        solvers = [args.solver]

    results: List[TrainResult] = []
    for s in solvers:
        # Use a fresh env copy per solver (to avoid RNG coupling)
        env_s = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)
        results.append(run_training(env_s, args, solver=s, log=log))

    if args.results_dir:
        safe_makedirs(args.results_dir)
        for r in results:
            save_history_csv(r.history, os.path.join(args.results_dir, f"history_{r.solver}.csv"))

    if args.save_plots and args.results_dir:
        plot_training_curves(results, outdir=args.results_dir)

    # Benchmarks on the same game
    if args.run_benchmarks:
        log.banner("[BENCH] Heuristics on same game")
        M_modes = len(sensor_cfg.sensors)

        # Robot heuristics (fixed paths etc.) plus one fixed-mode sensor per mode,
        # named "S_Mode<m>".
        robots_init, sensors_init = build_initial_policies_for_bench(env, M_modes)

        # Sensor heuristics: every fixed mode (sized by M_modes rather than a
        # hard-coded pair), followed by the alternating and random baselines.
        sensor_methods: List[SensorPolicy] = [
            *sensors_init,
            AlternatingSensorPolicy(M_modes, "S_Alternate"),
            RandomModeSensorPolicy(M_modes, seed=args.seed + 123, name="S_RandomMode"),
        ]

        bench = run_policy_matrix_benchmark(
            env,
            robot_methods=robots_init,
            sensor_methods=sensor_methods,
            M_modes=M_modes,
            episodes=args.bench_episodes,
            base_seed=555000,
        )

        # Print a compact table
        robot_names = [p.name for p in robots_init]
        sensor_names = [p.name for p in sensor_methods]
        for j, sname in enumerate(sensor_names):
            log.info(f"[BENCH] Sensor={sname}")
            for i, rname in enumerate(robot_names):
                cell = bench[(i, j)]
                log.info(f"  Robot={rname:22s} UR={cell.UR:8.2f} goal%={100*cell.goal:5.1f} det%={100*cell.det:5.1f}")

        if args.save_plots and args.results_dir:
            plot_benchmark_bars(robot_names, sensor_names, bench, outdir=args.results_dir)


# Helper: initial robot set for benchmark (without the solver-added BR policies)

def build_initial_policies_for_bench(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Build the heuristic robot policies and fixed-mode sensors for benchmarks.

    Robot policies (in order): shortest path, path through the upper gap, path
    through the lower gap, risk-weighted A* under a uniform mode belief,
    risk-weighted A* under the per-cell worst-case mode, online belief
    replanning, and a random walker.

    Returns:
        (robots, sensors) where `sensors` holds one FixedModeSensorPolicy per
        mode in range(M_modes).
    """
    grid = env.grid

    def unit_cost(a, b):
        # Every move costs the same: plain shortest-path planning.
        return 1.0

    p_short = astar_path(grid, grid.start, grid.goal, unit_cost)

    wall_x = grid.width // 2
    p_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    p_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    # Static risk-aware A* under UNIFORM mode belief
    uniform_mode = np.full(M_modes, 1.0 / M_modes)

    # Risk maps are indexed [y, x]; obstacle cells stay NaN.
    shape = (grid.height, grid.width)
    risk_uniform = np.full(shape, np.nan, dtype=float)
    risk_worst = np.full(shape, np.nan, dtype=float)
    for y in range(grid.height):
        for x in range(grid.width):
            cell = (x, y)
            if cell in grid.obstacles:
                continue
            per_mode = [env.true_detection_prob(cell, m) for m in range(M_modes)]
            risk_uniform[y, x] = float(np.dot(uniform_mode, per_mode))
            risk_worst[y, x] = float(np.max(per_mode))

    p_uniform = plan_risk_weighted_path(env, risk_uniform, risk_weight=12.0)
    p_worst = plan_risk_weighted_path(env, risk_worst, risk_weight=12.0)

    robots: List[RobotPolicy] = []
    for path, label in (
        (p_short, "R_Shortest"),
        (p_upper, "R_UpperCorridor"),
        (p_lower, "R_LowerCorridor"),
        (p_uniform, "R_RiskAStar_Uniform"),
        (p_worst, "R_RiskAStar_WorstCase"),
    ):
        robots.append(FixedPathPolicy(path, label))
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))
    robots.append(RandomPolicy("R_Random"))

    sensors: List[SensorPolicy] = []
    for m in range(M_modes):
        sensors.append(FixedModeSensorPolicy(m, name=f"S_Mode{m}"))
    return robots, sensors


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse command-line options, tolerating unknown arguments.

    Args:
        argv: argument list to parse; None lets argparse read sys.argv.

    Returns:
        (namespace, unknown) as returned by ArgumentParser.parse_known_args.
    """
    parser = argparse.ArgumentParser()

    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ("--seed", dict(type=int, default=0)),
        ("--log-level", dict(type=str, default="INFO", choices=["QUIET", "INFO", "DEBUG"])),
        ("--solver", dict(type=str, default="correlated", choices=["marginal", "correlated", "both"])),
        ("--outer-iters", dict(type=int, default=3)),
        ("--rollouts-payoff", dict(type=int, default=20)),
        ("--rollouts-br", dict(type=int, default=30)),
        ("--risk-weight-br", dict(type=float, default=12.0)),
        # NBS knobs
        ("--disagreement", dict(type=str, default="minminus", choices=["minminus", "uniform"])),
        ("--entropy-tau", dict(type=float, default=0.0)),
        # Fix mismatch knobs
        ("--cond-top-k", dict(type=int, default=2, help="How many top recommendations to check for conditional BRs")),
        ("--br-eval-rollouts", dict(type=int, default=8, help="Small evaluation rollouts for new deviations")),
        ("--add-threshold", dict(type=float, default=0.25, help="Minimum estimated conditional gain to add a deviation policy")),
        # Eval episodes under joint X (self-play)
        ("--eval-episodes", dict(type=int, default=60)),
        # Debug
        ("--debug-rollout-pair", dict(type=str, default="", help="Print one step-by-step rollout for i,j")),
        # Outputs
        ("--results-dir", dict(type=str, default="results", help="Directory for csv/plots")),
        ("--save-plots", dict(action="store_true")),
        # Benchmarks
        ("--run-benchmarks", dict(action="store_true")),
        ("--bench-episodes", dict(type=int, default=80)),
    ]
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    namespace, leftovers = parser.parse_known_args(args=argv)
    return namespace, leftovers


def main(argv: Optional[List[str]] = None) -> None:
    """Entry point: parse arguments and run the pipeline.

    Unknown CLI arguments are reported with a warning unless the `ipykernel`
    module is loaded (i.e. when running inside a notebook kernel).
    """
    args, extras = parse_args(argv=argv)
    inside_notebook = "ipykernel" in sys.modules
    if extras and not inside_notebook:
        print(f"[WARN] Ignoring unknown CLI args: {extras}")
    # Collapse a whitespace-only debug pair to the empty string.
    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""
    run_pipeline(args)


# Run the CLI only when executed as a script; skipped when the `ipykernel`
# module is loaded, so executing this cell in a notebook does not start a run.
if __name__ == "__main__" and ("ipykernel" not in sys.modules):
    main()


In [21]:
#!/usr/bin/env python3
"""approach2_robust_correlated.py

Robust, reduced-log implementation of "Approach 2" on a stealth gridworld POMDP,
with a FIX for the conceptual mismatch:

  Old mismatch: solve a joint distribution x*(i,j) over (robot policy i, sensor policy j)
  but then compute best responses against the MARGINALS sigma_R, sigma_S.

  Fix here: treat x* as a CORRELATION DEVICE (mediator) and compute BRs against
  CONDITIONAL distributions:

      q_S(.|i) = X[i,:] / sigma_R[i]   (sensor conditional given robot recommendation i)
      q_R(.|j) = X[:,j] / sigma_S[j]   (robot conditional given sensor recommendation j)

  Then add the most profitable *deviation* policy (oracle) based on the most violated
  conditional recommendation.

This makes the PSRO-style expansion step consistent with a correlated-strategy viewpoint.

Also included:
  - Benchmark harness vs simple heuristics on the same game.
  - Saved plots (training curves + bar charts) for presentations.

Run (script):
  python approach2_robust_correlated.py --solver correlated --outer-iters 3 --save-plots

Run (notebook):
  main(argv=["--solver","both","--outer-iters","3","--save-plots"])  # ignore ipykernel args

Tips to encourage MULTIMODAL x* (useful for your "mode recovery"):
  --disagreement uniform --entropy-tau 0.02

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import argparse
import sys
import heapq
import os
import time

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """Minimal print-based logger with three verbosity levels.

    QUIET prints nothing, INFO prints banners and info messages, DEBUG
    additionally prints debug messages.
    """

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        """Store the (case-insensitive) level name and its numeric rank `k`."""
        level = level.upper()
        rank = self.LEVELS.get(level)
        if rank is None:
            raise ValueError(f"Unknown log level: {level}. Use QUIET/INFO/DEBUG")
        self.level = level
        self.k = rank

    def banner(self, title: str) -> None:
        """Print `title` framed by two 100-character '=' rules (INFO and above)."""
        if self.k < 1:
            return
        rule = "=" * 100
        print(rule)
        print(title)
        print(rule)

    def info(self, msg: str) -> None:
        """Print `msg` when the level is INFO or DEBUG."""
        if self.k < 1:
            return
        print(msg)

    def debug(self, msg: str) -> None:
        """Print `msg` only when the level is DEBUG."""
        if self.k < 2:
            return
        print(msg)


# =============================================================================
# Types
# =============================================================================

Action = Tuple[int, int]  # (dx, dy)


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid layout used by the environment and A*."""

    width: int  # number of columns; in-bounds x satisfies 0 <= x < width
    height: int  # number of rows; in-bounds y satisfies 0 <= y < height
    start: Tuple[int, int]  # (x, y) cell the robot is placed on at reset
    goal: Tuple[int, int]  # (x, y) cell whose arrival ends an episode
    obstacles: frozenset  # blocked (x, y) cells; membership-tested with tuples


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor parameters; the sensor mode indexes into `sensors`."""

    sensors: Tuple[Tuple[int, int], ...]  # (x, y) center of the hotspot for each mode
    radius: int  # Manhattan radius of the hotspot around the active center
    base_p: float  # detection probability outside the hotspot
    hotspot_p: float  # detection probability inside the hotspot (distance <= radius)


class GridWorldStealthEnv:
    """Grid world with detection risk controlled by a hidden/selected sensor 'mode'.

    The active mode selects one sensor center; cells within `radius` (Manhattan
    distance) of it have detection probability `hotspot_p`, all others `base_p`.
    Every step the robot may be detected and receives a noisy binary alarm.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        """Validate the configuration, seed the RNG and reset to mode 0.

        Args:
            grid: grid layout (size, start, goal, obstacles).
            sensor_cfg: sensor centers, hotspot radius and probabilities.
            fp: alarm false-positive rate, in [0, 1].
            fn: alarm false-negative rate, in [0, 1].
            seed: seed for the environment's numpy Generator.
            max_steps: step count at which an episode is reported as done
                (default 200).

        Raises:
            ValueError: if a probability is outside [0, 1], the radius is
                negative, no sensor is configured, or max_steps < 1.
        """
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        self.max_steps = int(max_steps)

        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(int(seed))
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the RNG with a fresh Generator seeded by `seed`."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """Start a new episode at the grid start with the given sensor mode.

        Returns:
            {"pos": start position, "t": 0}.
        """
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = int(sensor_mode)
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """Return True if cell `p` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """Return True if `p` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """Detection probability at cell `p` when sensor `mode` is active.

        Raises:
            ValueError: if `mode` is not a valid sensor index.
        """
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)  # Manhattan distance to the active sensor
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """Likelihood of observing `alarm` given true detection probability `p_true`.

        P(alarm=1) = p_true * (1 - fn) + (1 - p_true) * fp.

        Raises:
            ValueError: if `alarm` is neither 0 nor 1.
        """
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """Apply move `a` = (dx, dy), then sample detection and the alarm.

        A move into an obstacle or out of bounds leaves the position unchanged.
        Once the robot has been detected, further calls return a terminal
        record without advancing time or consuming random numbers.

        Returns:
            Dict with keys "pos", "t", "alarm", "p_true", "detected", "done".
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1
        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = float(self.true_detection_prob(self.pos, self.sensor_mode))
        self.total_true_risk += p_true

        # RNG draw order matters for reproducibility: detection first, alarm second.
        if self.rng.random() < p_true:
            self.detected = True

        # Same observation model the belief update uses for its likelihoods.
        p_alarm = self.observation_prob(1, p_true)
        alarm = 1 if (self.rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """Build a grid split by a vertical wall at x = width // 2 with gaps at y = 2 and y = 6.

    The start is (1, height - 2) and the goal is (width - 2, 1).

    Raises:
        RuntimeError: if the start or goal cell coincides with a wall cell.
    """
    wall_x = width // 2
    gap_rows = (2, 6)
    wall = frozenset((wall_x, y) for y in range(height) if y not in gap_rows)

    start = (1, height - 2)
    goal = (width - 2, 1)
    if any(cell in wall for cell in (start, goal)):
        raise RuntimeError("Start/goal blocked unexpectedly")

    return GridConfig(width=width, height=height, start=start, goal=goal, obstacles=wall)


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """Print the grid row by row as text.

    Symbols, in priority order: 'R' start, 'G' goal, 'S' sensor center,
    '#' obstacle, '.' free cell.
    """
    walls = set(grid.obstacles)
    sensor_cells = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        # Earlier checks win when a cell matches several categories.
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in sensor_cells:
            return "S"
        if cell in walls:
            return "#"
        return "."

    for y in range(grid.height):
        print("".join(glyph((x, y)) for x in range(grid.width)))


# =============================================================================
# Belief over modes
# =============================================================================

class ModeBelief:
    """Exact belief over discrete modes m in {0..M-1}."""

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """Create a belief over `M` modes.

        Args:
            M: number of modes (>= 1).
            init: optional nonnegative weights of length M; they are normalized.
                If omitted, or if the weights do not sum to a positive value,
                the belief is uniform.

        Raises:
            ValueError: if M < 1, the shape does not match, or a weight is negative.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")
        uniform = np.full(self.M, 1.0 / self.M)
        if init is None:
            self.b = uniform
            return

        weights = np.asarray(init, dtype=float).reshape(-1)
        if weights.shape != (self.M,):
            raise ValueError("init shape mismatch")
        if (weights < 0).any():
            raise ValueError("init must be nonnegative")
        total = float(weights.sum())
        self.b = weights / total if total > 0 else uniform

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """Bayes update of the belief from one alarm observation at `pos`.

        The likelihood of each mode is env.observation_prob(alarm, p) with
        p = env.true_detection_prob(pos, m). If the normalizer is non-finite or
        smaller than `eps`, the belief is left unchanged.
        """
        likelihood = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, m)) for m in range(self.M)],
            dtype=float,
        )
        unnormalized = self.b * likelihood
        evidence = float(unnormalized.sum())
        if np.isfinite(evidence) and evidence >= eps:
            self.b = unnormalized / evidence


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Base class for robot policies; subclasses implement `act`."""

    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Hook called before an episode; the base implementation does nothing."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the next move as (dx, dy); must be overridden."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Follow a precomputed list of cells, one step per call to `act`."""

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """Store a copy of `path`.

        Raises:
            ValueError: if the path has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0  # index of the path cell the robot is expected to occupy

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the cursor at the first occurrence of `start_pos` (or 0 if absent)."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the unit move toward the next path cell, or (0, 0) to stay.

        The robot stays put when the path is exhausted or when its current
        position cannot be found at or after the cursor.
        """
        cur = env.pos
        last = len(self.path) - 1
        if self._idx >= last:
            return (0, 0)
        if cur != self.path[self._idx]:
            # Off the expected cell: re-sync the cursor by searching forward.
            try:
                self._idx = self.path.index(cur, self._idx)
            except ValueError:
                return (0, 0)
            # The re-sync may land on the final cell; without this guard the
            # lookup of the next cell below would raise IndexError.
            if self._idx >= last:
                return (0, 0)
        nxt = self.path[self._idx + 1]
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        self._idx += 1
        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """Reproducible random policy using env.rng."""

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Nothing to reset; the policy keeps no state of its own."""
        return None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the moves (including staying) that lead to a free cell."""
        x, y = env.pos
        options: List[Action] = [
            (dx, dy)
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1), (0, 0))
            if env.is_free((x + dx, y + dy))
        ]
        if len(options) == 0:
            return (0, 0)
        choice = int(env.rng.integers(0, len(options)))
        return options[choice]


class OnlineBeliefReplanPolicy(RobotPolicy):
    """POMDP-ish heuristic: replan each step using risk map induced by current belief b_t."""

    def __init__(self, env: GridWorldStealthEnv, risk_weight: float = 12.0, name: str = "R_OnlineBeliefReplan"):
        """Store the risk weight; `env` is kept as an attribute for callers."""
        self.env = env
        self.risk_weight = float(risk_weight)
        self.name = name
        self._cached_next: Optional[Tuple[int, int]] = None

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Clear the cached next cell."""
        self._cached_next = None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Plan a risk-weighted A* path from the current cell and take its first move.

        The cost of entering a cell is 1 + risk_weight * E_b[detection prob],
        with the expectation taken over the current mode belief. Returns (0, 0)
        when the planner fails or the robot is already at the goal.
        """
        # Build a risk map from belief over modes.
        mode_probs = belief.b

        def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
            # expected risk at `to` under belief
            r = 0.0
            for m, pm in enumerate(mode_probs):
                r += float(pm) * float(env.true_detection_prob(to, m))
            return 1.0 + self.risk_weight * r

        # Plan from current pos to goal (one-step receding horizon).
        # Only planner failures (RuntimeError from astar_path) are treated as
        # "stay put"; other errors, e.g. a belief with more modes than sensors,
        # propagate instead of silently freezing the robot.
        try:
            path = astar_path(env.grid, env.pos, env.grid.goal, step_cost)
        except RuntimeError:
            return (0, 0)

        if len(path) < 2:
            return (0, 0)
        nxt = path[1]
        dx = int(np.clip(nxt[0] - env.pos[0], -1, 1))
        dy = int(np.clip(nxt[1] - env.pos[1], -1, 1))
        return (dx, dy)


class SensorPolicy:
    """Base class for sensor policies; subclasses implement `select_mode`."""

    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Hook called before an episode; the base implementation does nothing."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the sensor mode to use at time `t`; must be overridden."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that always selects the same mode."""

    def __init__(self, mode: int, name: Optional[str] = None):
        """Store the mode; a missing or empty name defaults to "S_Mode<mode>"."""
        self.mode = int(mode)
        self.name = name if name else f"S_Mode{mode}"

    def reset(self) -> None:
        """Nothing to reset."""
        return None

    def select_mode(self, t: int) -> int:
        """Return the fixed mode regardless of `t`."""
        return self.mode


class AlternatingSensorPolicy(SensorPolicy):
    """Simple sensor heuristic for benchmarks: alternate modes 0,1,0,1,..."""

    def __init__(self, M: int, name: str = "S_Alternate"):
        self.name = name
        self.M = int(M)

    def reset(self) -> None:
        """Nothing to reset; the mode depends only on the time step."""
        return None

    def select_mode(self, t: int) -> int:
        """Cycle through the M modes: the mode at time `t` is t modulo M."""
        remainder = t % self.M
        return int(remainder)


class RandomModeSensorPolicy(SensorPolicy):
    """Benchmark sensor: random mode each step (uses numpy Generator for reproducibility)."""

    def __init__(self, M: int, seed: int = 0, name: str = "S_RandomMode"):
        self.name = name
        self.M = int(M)
        self.rng = np.random.default_rng(int(seed))

    def reset(self) -> None:
        """No per-episode reset: the generator state carries over between episodes."""
        return None

    def select_mode(self, t: int) -> int:
        """Draw a mode uniformly from {0, ..., M-1}; `t` is ignored."""
        drawn = self.rng.integers(0, self.M)
        return int(drawn)


# =============================================================================
# A* (used for planning)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """Find a minimum-cost 4-connected path from ``start`` to ``goal`` with A*.

    Args:
        grid: Object exposing ``width``, ``height`` and an ``obstacles`` container
            of blocked ``(x, y)`` cells.
        start: Start cell ``(x, y)``.
        goal: Goal cell ``(x, y)``.
        step_cost: ``step_cost(frm, to)`` gives the cost of moving between two
            adjacent cells. The Manhattan heuristic used here only guarantees an
            optimal path when every step costs at least 1.
        max_expansions: Upper bound on the number of node expansions.

    Returns:
        The cells from ``start`` to ``goal`` inclusive (``[start]`` if they coincide).

    Raises:
        RuntimeError: If the expansion budget is exceeded or the goal is unreachable.
    """
    if start == goal:
        return [start]

    def h(p: Tuple[int, int]) -> float:
        return abs(p[0] - goal[0]) + abs(p[1] - goal[1])

    def neighbors(p: Tuple[int, int]):
        x, y = p
        for dx, dy in [(1, 0), (-1, 0), (0, 1), (0, -1)]:
            np_ = (x + dx, y + dy)
            if 0 <= np_[0] < grid.width and 0 <= np_[1] < grid.height and np_ not in grid.obstacles:
                yield np_

    # Heap entries are (f = g + h, g, cell).
    open_heap: List[Tuple[float, float, Tuple[int, int]]] = []
    heapq.heappush(open_heap, (h(start), 0.0, start))

    came: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    gscore: Dict[Tuple[int, int], float] = {start: 0.0}

    expansions = 0
    while open_heap:
        _, g_cur, cur = heapq.heappop(open_heap)
        if g_cur > gscore[cur]:
            # Stale entry: a cheaper route to `cur` was pushed later and, having a
            # smaller key, has already been expanded. Re-expanding could not
            # improve any neighbor, so skip it and do not charge the budget.
            continue
        expansions += 1
        if cur == goal:
            path: List[Tuple[int, int]] = []
            while cur is not None:
                path.append(cur)
                cur = came[cur]
            path.reverse()
            return path
        if expansions > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        for nb in neighbors(cur):
            tentative = gscore[cur] + float(step_cost(cur, nb))
            # The 1e-12 slack avoids re-pushing on floating-point noise.
            if (nb not in gscore) or (tentative < gscore[nb] - 1e-12):
                gscore[nb] = tentative
                came[nb] = cur
                heapq.heappush(open_heap, (tentative + h(nb), tentative, nb))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan ``start -> waypoint -> goal`` as two independent A* searches.

    Returns the concatenated path; the waypoint appears exactly once.
    Propagates any ``RuntimeError`` raised by ``astar_path``.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    # The first leg ends on the waypoint and the second starts on it: drop one copy.
    return first_leg[: len(first_leg) - 1] + second_leg if first_leg else first_leg[:-1] + second_leg


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Outcome of a single simulated episode, as produced by ``rollout_episode``."""

    steps: int  # env.t when the episode ended
    reached_goal: bool  # whether env.pos equals the grid goal at the end
    detected: bool  # env.detected at the end of the episode
    total_true_risk: float  # sum of the per-step "p_true" values reported by env.step
    U_R: float  # robot utility: -(steps + lambda_risk * risk + detection penalty)
    U_S: float  # sensor utility: detection penalty + lambda_risk * risk - energy cost * steps


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Simulate one episode of ``robot`` against ``sensor`` and score both players.

    Args:
        env: Environment; it is re-seeded and reset here, so its state is overwritten.
        robot: Robot policy; ``reset(env.pos)`` is called once, then ``act`` each step.
        sensor: Sensor policy; queried for the active mode before every step.
        M_modes: Number of sensor modes tracked by the robot's belief.
        seed: Seed handed to ``env.seed``.
        max_steps: Maximum number of environment steps.
        lambda_risk: Weight of accumulated true detection probability in both utilities.
        det_penalty: Robot penalty (and sensor reward) applied when detected.
        sensor_energy_per_step: Per-step cost subtracted from the sensor utility.
        step_debug: When true, print one line per environment step.

    Returns:
        An ``EpisodeStats`` with the episode outcome and both utilities.
    """
    # Order matters: seed first, then reset sensor and environment.
    env.seed(seed)
    sensor.reset()
    env.reset(sensor_mode=sensor.select_mode(0))

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    risk_sum = 0.0
    prev_alarm: Optional[int] = None

    for _ in range(max_steps):
        # The sensor picks its mode before the robot moves.
        env.sensor_mode = sensor.select_mode(env.t)
        a = robot.act(env, belief, prev_alarm)
        out = env.step(a)

        risk_sum += float(out["p_true"])
        belief.update(env, out["alarm"], out["pos"])
        prev_alarm = int(out["alarm"])

        if step_debug:
            print(
                f"[Step] t={out['t']:3d} pos={out['pos']} a={a} p_true={out['p_true']:.3f} "
                f"alarm={out['alarm']} det={out['detected']} done={out['done']} b={belief.b.round(3)}"
            )

        if out["done"]:
            break

    at_goal = (env.pos == env.grid.goal)
    was_detected = bool(env.detected)
    n_steps = int(env.t)
    penalty = det_penalty if was_detected else 0.0

    # Robot pays for time, exposure and detection; the sensor gains from exposure
    # and detection but pays an energy cost per step.
    robot_cost = n_steps + lambda_risk * risk_sum + penalty
    robot_utility = -float(robot_cost)
    sensor_utility = float(penalty + lambda_risk * risk_sum - sensor_energy_per_step * n_steps)

    return EpisodeStats(
        steps=n_steps,
        reached_goal=at_goal,
        detected=was_detected,
        total_true_risk=float(risk_sum),
        U_R=robot_utility,
        U_S=sensor_utility,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Estimate the empirical payoff matrices by Monte Carlo rollouts.

    Every (robot i, sensor j) pair is simulated ``rollouts`` times with seeds
    ``base_seed + 100000*i + 1000*j + k``.

    Args:
        env: Environment reused for all rollouts.
        robots: Robot policies (rows of the payoff matrices).
        sensors: Sensor policies (columns of the payoff matrices).
        M_modes: Number of sensor modes, forwarded to ``rollout_episode``.
        rollouts: Episodes per policy pair.
        base_seed: Base of the per-episode seed schedule.
        log: Logger; the compact summary is printed when ``log.k >= 1``.
        debug_rollout_pair: If given, the first rollout of that ``(i, j)`` pair is
            traced step by step.

    Returns:
        ``(U_R, U_S, diag)``: mean utilities per pair, plus per-pair diagnostics
        (detection/goal rates, mean steps/risk, utility standard deviations).
    """
    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            trace_pair = (debug_rollout_pair == (i, j))
            pair_seed = base_seed + 100000 * i + 1000 * j

            # Only the very first rollout of the traced pair is printed.
            episodes = [
                rollout_episode(
                    env,
                    rpol,
                    spol,
                    M_modes=M_modes,
                    seed=pair_seed + k,
                    step_debug=(trace_pair and k == 0),
                )
                for k in range(rollouts)
            ]

            r_list: List[float] = [ep.U_R for ep in episodes]
            s_list: List[float] = [ep.U_S for ep in episodes]
            steps_list: List[int] = [ep.steps for ep in episodes]
            risk_list: List[float] = [ep.total_true_risk for ep in episodes]
            det = sum(int(ep.detected) for ep in episodes)
            goal = sum(int(ep.reached_goal) for ep in episodes)

            U_R[i, j] = float(np.mean(r_list))
            U_S[i, j] = float(np.mean(s_list))

            diag[(i, j)] = {
                "det_rate": det / rollouts,
                "goal_rate": goal / rollouts,
                "mean_steps": float(np.mean(steps_list)),
                "mean_risk": float(np.mean(risk_list)),
                "std_UR": float(np.std(r_list)),
                "std_US": float(np.std(s_list)),
            }

    if log.k >= 1:
        log.info("[Eval] Compact payoff summary:")
        for i, rpol in enumerate(robots):
            for j, spol in enumerate(sensors):
                d = diag[(i, j)]
                log.info(
                    f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                    f"UR={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                    f"US={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                    f"det%={100*d['det_rate']:5.1f} goal%={100*d['goal_rate']:5.1f} "
                    f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
                )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Euclidean projection of ``v`` onto the simplex ``{w >= 0, sum(w) = z}``.

    Uses the sort-and-threshold method. Degenerate inputs (no valid threshold,
    or a non-finite / non-positive projected mass) fall back to the uniform
    point ``z / len(v)``.

    Raises:
        ValueError: If ``v`` is empty or ``z`` is not strictly positive.
    """
    v = np.asarray(v, dtype=float).reshape(-1)
    count = v.size
    if count == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    uniform = np.full(v.shape, z / count)

    descending = np.sort(v)[::-1]
    running = np.cumsum(descending)
    ranks = np.arange(1, count + 1)
    feasible = np.flatnonzero(descending * ranks > (running - z))
    if feasible.size == 0:
        return uniform

    # Largest rank that still satisfies the threshold condition.
    last = int(feasible[-1])
    theta = (running[last] - z) / (last + 1.0)
    shifted = np.maximum(v - theta, 0.0)
    mass = float(shifted.sum())
    if not np.isfinite(mass) or mass <= 0:
        return uniform
    # Rescale to remove residual floating-point drift in the total.
    return shifted * (z / mass)


@dataclass
class NBSResult:
    """Result of ``solve_nbs``: the final joint distribution and summary values."""

    x: np.ndarray  # final iterate over the flattened joint actions
    obj: float  # objective value at the last accepted iterate
    gains: Tuple[float, float]  # (uR @ x - dR, uS @ x - dS), not floored at kappa
    support: int  # number of entries of x greater than 1e-6


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Maximise the Nash bargaining objective over joint-action distributions.

    Projected gradient ascent with backtracking on
    ``log(uR.x - dR) + log(uS.x - dS) + entropy_tau * H(x)``, where both gains
    are floored at ``kappa`` before the logarithm.

    Args:
        uR: Robot payoffs, flattened over joint actions.
        uS: Sensor payoffs, same shape as ``uR``.
        log: Logger; per-iteration output is emitted when ``log.k >= 2``.
        max_iters: Maximum number of outer ascent iterations.
        alpha: Initial step size of every backtracking search.
        tol_l1: Stop when the L1 change of ``x`` falls below this value.
        kappa: Lower floor applied to both gains.
        disagreement: ``"minminus"`` (each player's minimum payoff minus 1) or
            ``"uniform"`` (payoff under the uniform joint distribution).
        entropy_tau: Weight of the entropy regulariser (0 disables its gradient).

    Returns:
        An ``NBSResult`` with the final distribution, objective, gains and support.

    Raises:
        ValueError: On shape mismatch, fewer than two joint actions, or an
            unknown ``disagreement`` rule.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    unif = np.full(d, 1.0 / d)

    rule = disagreement.lower().strip()
    if rule == "minminus":
        dR = float(np.min(uR) - 1.0)
        dS = float(np.min(uS) - 1.0)
    elif rule == "uniform":
        dR = float(uR @ unif)
        dS = float(uS @ unif)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        return float(uR @ xv - dR), float(uS @ xv - dS)

    def floored_gains(xv: np.ndarray) -> Tuple[float, float]:
        # Keep both logarithms finite even when a gain is non-positive.
        raw_R, raw_S = gains(xv)
        return max(raw_R, kappa), max(raw_S, kappa)

    def entropy(xv: np.ndarray) -> float:
        clipped = np.clip(xv, 1e-12, 1.0)
        return float(-np.sum(clipped * np.log(clipped)))

    def obj(xv: np.ndarray) -> float:
        fR, fS = floored_gains(xv)
        return float(np.log(fR) + np.log(fS) + entropy_tau * entropy(xv))

    def grad(xv: np.ndarray) -> np.ndarray:
        fR, fS = floored_gains(xv)
        g = (uR / fR) + (uS / fS)
        if entropy_tau > 0:
            clipped = np.clip(xv, 1e-12, 1.0)
            g += entropy_tau * (-(np.log(clipped) + 1.0))
        return g

    x = unif.copy()
    last = obj(x)
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        direction = grad(x)
        step = alpha
        accepted = None
        # Backtracking: halve the step until the objective does not decrease.
        for _ in range(30):
            candidate = project_simplex(x + step * direction)
            candidate_obj = obj(candidate)
            if candidate_obj >= last - 1e-12:
                accepted = (candidate, candidate_obj)
                break
            step *= 0.5
            if step < 1e-6:
                break
        if accepted is None:
            break

        x_new, new_obj = accepted
        delta = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last = new_obj

        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            gR, gS = gains(x)
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flat joint distribution into an ``(m, n)`` matrix, renormalising if needed.

    Args:
        x: Joint distribution over ``m * n`` joint actions, in row-major order.
        m: Number of rows (robot policies).
        n: Number of columns (sensor policies).

    Returns:
        The ``(m, n)`` matrix. If the total mass is off by more than 1e-6 the
        matrix is rescaled to sum to 1. If the mass is non-finite or
        non-positive (NaN/inf entries, all zeros), the uniform distribution is
        returned instead of propagating NaN or an unnormalised matrix.

    Raises:
        ValueError: If ``x`` does not contain exactly ``m * n`` entries.
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size != m * n:
        raise ValueError("x size mismatch")
    X = x.reshape((m, n))
    if X.size == 0:
        # Nothing to normalise, and a uniform fallback would divide by zero.
        return X
    s = float(X.sum())
    if not np.isfinite(s) or s <= 0.0:
        # Dividing by such a sum cannot yield a distribution; use the same
        # uniform fallback as the other simplex helpers in this file.
        return np.full((m, n), 1.0 / X.size)
    if abs(s - 1.0) > 1e-6:
        # Renormalize defensively
        X = X / max(s, 1e-12)
    return X


def marginals_from_joint(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the row (robot) and column (sensor) marginals of the joint matrix ``X``.

    Each marginal is rescaled to sum to 1 whenever its total mass is positive;
    otherwise it is returned as computed.
    """
    def _normalized(mass: np.ndarray) -> np.ndarray:
        total = mass.sum()
        if total > 0:
            return mass / total
        return mass

    sigma_R = _normalized(X.sum(axis=1))
    sigma_S = _normalized(X.sum(axis=0))
    return sigma_R, sigma_S


def entropy_of_joint(X: np.ndarray) -> float:
    """Shannon entropy (in nats) of the joint matrix ``X``.

    Entries are clipped to ``[1e-12, 1]`` before taking the logarithm so that
    zeros do not produce ``-inf``.
    """
    flat = X.reshape(-1)
    safe = np.clip(flat, 1e-12, 1.0)
    weighted_logs = safe * np.log(safe)
    return float(-np.sum(weighted_logs))


# =============================================================================
# Best responses (marginal vs correlated)
# =============================================================================

def compute_expected_risk_map_from_policy_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi: np.ndarray,
    M_modes: int,
) -> np.ndarray:
    """Expected true detection probability per cell under a mixture of sensor policies.

    Each sensor policy is summarised by the mode it selects at ``t = 0``; this
    is exact for fixed-mode policies. The policy weights are folded into a
    distribution over modes, and each free cell gets the mode-weighted average
    of ``env.true_detection_prob``.

    Args:
        env: Environment providing ``grid`` and ``true_detection_prob(cell, mode)``.
        sensors: Sensor policies the mixture is defined over.
        pi: Mixture weights, one per sensor policy.
        M_modes: Number of sensor modes.

    Returns:
        Array of shape ``(height, width)`` indexed ``[y, x]``; obstacle cells are NaN.

    Raises:
        ValueError: If ``pi`` and ``sensors`` differ in length, or a policy
            reports a mode outside ``[0, M_modes)``.
    """
    pi = np.asarray(pi, dtype=float).reshape(-1)
    if pi.size != len(sensors):
        raise ValueError("mixture length mismatch")

    # Collapse policy weights into mode weights.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(pi, sensors):
        mode = int(policy.select_mode(0))
        if mode < 0 or mode >= M_modes:
            raise ValueError("invalid sensor mode")
        mode_probs[mode] += float(weight)
    total = mode_probs.sum()
    if total > 0:
        mode_probs = mode_probs / total

    H, W = env.grid.height, env.grid.width
    risk = np.zeros((H, W), dtype=float)
    for y in range(H):
        for x in range(W):
            cell = (x, y)
            if cell in env.grid.obstacles:
                risk[y, x] = np.nan
                continue
            expected = 0.0
            for mode in range(M_modes):
                expected += float(mode_probs[mode]) * float(env.true_detection_prob(cell, mode))
            risk[y, x] = float(expected)

    return risk


def plan_risk_weighted_path(env: GridWorldStealthEnv, risk_map: np.ndarray, risk_weight: float) -> List[Tuple[int, int]]:
    """Plan a start-to-goal path whose step cost is ``1 + risk_weight * risk(cell)``.

    ``risk_map`` is indexed ``[y, x]``. Cells with a non-finite risk (e.g. the
    NaN used for obstacles) cost ``1e9`` to enter.
    """
    weight = float(risk_weight)

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        col, row = to
        cell_risk = risk_map[row, col]
        if np.isfinite(cell_risk):
            return 1.0 + weight * float(cell_risk)
        return 1e9

    return astar_path(env.grid, env.grid.start, env.grid.goal, step_cost)


def robot_best_response_to_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi_S: np.ndarray,
    robots: List[RobotPolicy],
    M_modes: int,
    risk_weight: float,
    log: Logger,
    tag: str,
) -> RobotPolicy:
    """Robot best response to a sensor mixture via risk-weighted A*.

    Builds the expected risk map induced by ``pi_S`` and plans a path on it.
    If a ``FixedPathPolicy`` with the same path is already in ``robots`` it is
    returned; otherwise a new ``FixedPathPolicy`` is created. ``robots`` itself
    is not modified. If planning fails, the failure is logged and ``robots[0]``
    is returned.
    """
    risk = compute_expected_risk_map_from_policy_mixture(env, sensors, pi_S, M_modes=M_modes)
    try:
        path = plan_risk_weighted_path(env, risk, risk_weight=risk_weight)
    except Exception as e:
        log.info(f"[RobotBR] WARNING A* failed ({tag}): {e}")
        return robots[0]

    wanted = tuple(path)
    # Reuse an existing policy that follows exactly this path.
    p = next(
        (cand for cand in robots if isinstance(cand, FixedPathPolicy) and tuple(cand.path) == wanted),
        None,
    )
    if p is not None:
        log.info(f"[RobotBR] ({tag}) BR path already exists: {p.name}")
        return p

    newp = FixedPathPolicy(path, name=f"R_BR_{tag}_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] ({tag}) Added new robot policy: {newp.name}")
    return newp


def sensor_best_response_to_mixture(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    pi_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    tag: str,
) -> FixedModeSensorPolicy:
    """Pick the best fixed sensor mode against a robot mixture (Monte Carlo, CRN).

    Payoff variance is high because detection is a rare, threshold-like event.
    Scoring each candidate mode on different random rollouts could select a
    "best" mode purely by noise. To avoid that, common random numbers are used:
    the sampled robot indices and the episode seeds are drawn once and shared
    by every candidate mode.

    Args:
        env: Environment used for the rollouts.
        robots: Robot policies the mixture is defined over.
        pi_R: Mixture weights over ``robots``.
        candidate_modes: Modes to evaluate; values outside ``[0, M_modes)`` are skipped.
        M_modes: Number of sensor modes.
        rollouts: Episodes per candidate mode.
        base_seed: Seeds both the robot sampling and the episode seed schedule.
        log: Logger; per-mode statistics are emitted when ``log.k >= 2``.
        tag: Label embedded in log lines and in the returned policy's name.

    Returns:
        A ``FixedModeSensorPolicy`` for the mode with the highest mean sensor utility.

    Raises:
        ValueError: If ``pi_R`` and ``robots`` differ in length.
        RuntimeError: If no candidate mode produced a usable estimate.
    """
    pi_R = np.asarray(pi_R, dtype=float).reshape(-1)
    if pi_R.size != len(robots):
        raise ValueError("pi_R length mismatch")

    rng = np.random.default_rng(int(base_seed))

    # Shared across modes: which robot plays each rollout ...
    robot_idxs = rng.choice(len(robots), size=rollouts, p=pi_R, replace=True)
    # ... and which seed drives each rollout.
    seeds = (int(base_seed) + np.arange(rollouts)).astype(int)

    best_mode: Optional[int] = None
    best_val = -1e18

    for mode in candidate_modes:
        if mode < 0 or mode >= M_modes:
            continue
        sp = FixedModeSensorPolicy(mode, name=f"S_BR_{tag}_Mode{mode}")

        vals: List[float] = [
            rollout_episode(env, robots[int(ridx)], sp, M_modes=M_modes, seed=int(ep_seed)).U_S
            for ridx, ep_seed in zip(robot_idxs, seeds)
        ]

        mean_u = float(np.mean(vals))
        if log.k >= 2:
            log.debug(f"[SensorBR] ({tag}) mode={mode} E[US]={mean_u:.3f} std={float(np.std(vals)):.2f}")

        if mean_u > best_val:
            best_val, best_mode = mean_u, mode

    if best_mode is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] ({tag}) Best mode={best_mode} E[US]={best_val:.3f}")
    return FixedModeSensorPolicy(best_mode, name=f"S_BR_{tag}_Mode{best_mode}")


def conditional_sensor_given_robot(X: np.ndarray, i: int, eps: float = 1e-12) -> np.ndarray:
    """Conditional distribution over sensor policies given robot recommendation ``i``.

    Normalises row ``i`` of ``X``; a row whose mass is at most ``eps`` yields
    the uniform distribution.
    """
    weights = np.asarray(X[i, :], dtype=float)
    mass = float(weights.sum())
    if mass <= eps:
        return np.full(weights.shape, 1.0 / weights.size)
    return weights / mass


def conditional_robot_given_sensor(X: np.ndarray, j: int, eps: float = 1e-12) -> np.ndarray:
    """Conditional distribution over robot policies given sensor recommendation ``j``.

    Normalises column ``j`` of ``X``; a column whose mass is at most ``eps``
    yields the uniform distribution.
    """
    weights = np.asarray(X[:, j], dtype=float)
    mass = float(weights.sum())
    if mass <= eps:
        return np.full(weights.shape, 1.0 / weights.size)
    return weights / mass


def compute_ce_regrets(U_R: np.ndarray, U_S: np.ndarray, X: np.ndarray, eps: float = 1e-12) -> Dict[str, float]:
    """Correlated-equilibrium style regrets of the joint distribution ``X``.

    For a robot recommendation ``i`` with conditional ``q(.|i)`` over sensors:
        regret_R(i) = max_{i'} E_{j~q(.|i)}[U_R(i',j) - U_R(i,j)]

    For a sensor recommendation ``j`` with conditional ``q(.|j)`` over robots:
        regret_S(j) = max_{j'} E_{i~q(.|j)}[U_S(i,j') - U_S(i,j)]

    Recommendations whose marginal probability is at most ``eps`` are ignored.

    Returns:
        Dict with the keys ``max_regret_R``, ``max_regret_S``,
        ``mean_regret_R`` and ``mean_regret_S`` (0.0 when nothing qualifies).
    """
    m, n = U_R.shape
    assert U_S.shape == (m, n)
    assert X.shape == (m, n)

    sigma_R, sigma_S = marginals_from_joint(X)

    reg_R = []
    for i in range(m):
        if sigma_R[i] <= eps:
            continue
        q = conditional_sensor_given_robot(X, i, eps=eps)
        followed = float(np.dot(q, U_R[i, :]))
        # Best deviation value; never worse than following the recommendation.
        deviated = max([followed] + [float(np.dot(q, U_R[alt, :])) for alt in range(m)])
        reg_R.append(deviated - followed)

    reg_S = []
    for j in range(n):
        if sigma_S[j] <= eps:
            continue
        q = conditional_robot_given_sensor(X, j, eps=eps)
        followed = float(np.dot(q, U_S[:, j]))
        deviated = max([followed] + [float(np.dot(q, U_S[:, alt])) for alt in range(n)])
        reg_S.append(deviated - followed)

    def _worst(values) -> float:
        return float(max(values) if values else 0.0)

    def _average(values) -> float:
        return float(np.mean(values) if values else 0.0)

    return {
        "max_regret_R": _worst(reg_R),
        "max_regret_S": _worst(reg_S),
        "mean_regret_R": _average(reg_R),
        "mean_regret_S": _average(reg_S),
    }


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first ``FixedModeSensorPolicy`` in ``pols`` with the given mode, else ``None``."""
    matches = (p for p in pols if isinstance(p, FixedModeSensorPolicy) and p.mode == mode)
    return next(matches, None)


# =============================================================================
# Policy initialization + evaluator for joint strategy
# =============================================================================

def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the starting policy sets for both players.

    Robots: three fixed paths (shortest; via ``(width // 2, 2)``; via
    ``(width // 2, 6)``), a random policy and an online belief-replanning
    policy. Sensors: one fixed-mode policy per mode in ``range(M_modes)``.
    """
    grid = env.grid

    def unit_cost(_frm, _to):
        # Plain shortest-path planning: every move costs the same.
        return 1.0

    p_short = astar_path(grid, grid.start, grid.goal, unit_cost)

    wall_x = grid.width // 2
    p_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    p_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = []
    robots.append(FixedPathPolicy(p_short, "R_Shortest"))
    robots.append(FixedPathPolicy(p_upper, "R_UpperCorridor"))
    robots.append(FixedPathPolicy(p_lower, "R_LowerCorridor"))
    robots.append(RandomPolicy("R_Random"))
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))

    sensors: List[SensorPolicy] = []
    for m in range(M_modes):
        sensors.append(FixedModeSensorPolicy(m, name=f"S_Mode{m}"))

    return robots, sensors


@dataclass
class StrategyEval:
    """Aggregate statistics returned by ``evaluate_joint_strategy``."""

    mean_U_R: float  # mean robot utility over the evaluated episodes
    mean_U_S: float  # mean sensor utility over the evaluated episodes
    det_rate: float  # fraction of episodes in which the robot was detected
    goal_rate: float  # fraction of episodes in which the robot reached the goal
    mean_steps: float  # mean episode length in steps
    mean_risk: float  # mean accumulated true detection probability per episode


def evaluate_joint_strategy(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    X: np.ndarray,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Self-play evaluation: sample policy pairs from ``X`` and roll them out.

    Args:
        env: Environment used for the rollouts.
        robots: Robot policies indexed by the rows of ``X``.
        sensors: Sensor policies indexed by the columns of ``X``.
        X: Joint distribution over policy pairs; renormalised before sampling.
        M_modes: Number of sensor modes.
        episodes: Number of episodes to simulate.
        base_seed: Seeds the pair sampler; episode ``k`` runs with ``base_seed + k``.

    Returns:
        A ``StrategyEval`` with the averaged outcome statistics.
    """
    m, n = X.shape
    flat = X.reshape(-1)
    probs = flat / max(float(flat.sum()), 1e-12)

    rng = np.random.default_rng(int(base_seed))

    outcomes = []
    for k in range(episodes):
        flat_idx = int(rng.choice(m * n, p=probs))
        # Row-major unravelling of the flat joint-action index.
        i, j = divmod(flat_idx, n)
        outcomes.append(
            rollout_episode(env, robots[i], sensors[j], M_modes=M_modes, seed=base_seed + k)
        )

    det = sum(int(ep.detected) for ep in outcomes)
    goal = sum(int(ep.reached_goal) for ep in outcomes)

    return StrategyEval(
        mean_U_R=float(np.mean([ep.U_R for ep in outcomes])),
        mean_U_S=float(np.mean([ep.U_S for ep in outcomes])),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean([ep.steps for ep in outcomes])),
        mean_risk=float(np.mean([ep.total_true_risk for ep in outcomes])),
    )


# =============================================================================
# Training loop (marginal vs correlated)
# =============================================================================

@dataclass
class TrainHistoryRow:
    """One row of training history, recorded per outer iteration by ``run_training``.

    The field order also defines the column order used by ``save_history_csv``.
    """

    outer_iter: int  # 1-based outer iteration index
    m: int  # number of robot policies after this iteration's expansion
    n: int  # number of sensor policies after this iteration's expansion
    nbs_obj: float  # NBS objective value reported by solve_nbs
    entropy_X: float  # entropy of the joint distribution X
    max_regret_R: float  # max conditional (CE-style) regret for the robot
    max_regret_S: float  # max conditional (CE-style) regret for the sensor
    selfplay_UR: float  # mean robot utility in self-play under X
    selfplay_US: float  # mean sensor utility in self-play under X
    selfplay_det: float  # detection rate in self-play under X
    selfplay_goal: float  # goal-reaching rate in self-play under X
    seconds: float  # wall-clock duration of the iteration


@dataclass
class TrainResult:
    """Final output of ``run_training`` for one solver variant."""

    solver: str  # solver name passed to run_training ("marginal" or "correlated")
    env: GridWorldStealthEnv  # environment the training ran on
    robots: List[RobotPolicy]  # robot policy set, including best responses added during training
    sensors: List[SensorPolicy]  # sensor policy set, including best responses added during training
    X: np.ndarray  # joint distribution from the last outer iteration (computed before that iteration's expansion)
    history: List[TrainHistoryRow]  # one row per outer iteration


def run_training(env: GridWorldStealthEnv, args: argparse.Namespace, solver: str, log: Logger) -> TrainResult:
    """Run the outer training loop: evaluate payoffs, solve NBS, expand policy sets.

    Each outer iteration:
      1. estimates the payoff matrices for the current policy sets,
      2. solves the Nash bargaining problem over joint actions,
      3. reports CE-style regrets and a self-play evaluation under the solution,
      4. computes best responses and appends new policies to the sets.

    Best responses are computed against the marginals of the joint distribution
    (``solver == "marginal"``) or against conditionals given the top-K
    recommendations (``solver == "correlated"``). In the correlated case a
    deviation is only added when its estimated gain exceeds ``args.add_threshold``.

    Args:
        env: Environment shared by all rollouts.
        args: Namespace providing ``outer_iters``, ``rollouts_payoff``,
            ``rollouts_br``, ``eval_episodes``, ``disagreement``, ``entropy_tau``,
            ``risk_weight_br``, ``cond_top_k``, ``br_eval_rollouts``,
            ``add_threshold`` and ``debug_rollout_pair`` (an ``"i,j"`` string or empty).
        solver: ``"marginal"`` or ``"correlated"``.
        log: Logger used for all progress output.

    Returns:
        A ``TrainResult``. Its ``X`` is the joint distribution of the last
        iteration, i.e. it covers the policy sets as they were before that
        iteration's expansion.

    Raises:
        ValueError: If ``solver`` is not one of the two supported names.
        RuntimeError: If no iteration ran (``args.outer_iters < 1``).
    """
    t0_all = time.time()

    sensor_cfg = env.sensor_cfg
    M_modes = len(sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    # Optional: reduce initial set if you want smaller games.
    # (We keep it as-is for benchmarks.)

    if log.k >= 1:
        log.info(f"[{solver}] Initial robots: " + ", ".join([p.name for p in robots]))
        log.info(f"[{solver}] Initial sensors: " + ", ".join([p.name for p in sensors]))

    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        if len(parts) == 2:
            debug_pair = (int(parts[0]), int(parts[1]))
            log.info(f"[{solver}] Will print one step-by-step rollout for pair {debug_pair} (only once).")

    history: List[TrainHistoryRow] = []
    X = None

    for it in range(1, args.outer_iters + 1):
        t0 = time.time()
        log.banner(f"[{solver}] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, _diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        # The step-by-step trace is only wanted in the first outer iteration.
        debug_pair = None

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )

        m, n = U_R.shape
        X = joint_to_matrix(nbs.x, m, n)
        sigma_R, sigma_S = marginals_from_joint(X)

        # Print top joint actions
        top = np.argsort(-X.reshape(-1))[:min(5, X.size)]
        log.info(f"[{solver}] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[{solver}] sigma_R={sigma_R.round(3)}")
        log.info(f"[{solver}] sigma_S={sigma_S.round(3)}")

        # Stability diagnostics
        regrets = compute_ce_regrets(U_R, U_S, X)
        ent = entropy_of_joint(X)

        # Self-play evaluation under joint X
        sp = evaluate_joint_strategy(
            env,
            robots,
            sensors,
            X,
            M_modes=M_modes,
            episodes=args.eval_episodes,
            base_seed=9000 + 100 * it,
        )

        log.info(
            f"[{solver}] CE-regrets: maxR={regrets['max_regret_R']:.3f} maxS={regrets['max_regret_S']:.3f} | "
            f"SelfPlay: UR={sp.mean_U_R:.2f} US={sp.mean_U_S:.2f} det%={100*sp.det_rate:.1f} goal%={100*sp.goal_rate:.1f} | "
            f"H(X)={ent:.3f}"
        )

        # Best-response expansion
        if solver == "marginal":
            br_r = robot_best_response_to_mixture(
                env,
                sensors,
                pi_S=sigma_S,
                robots=robots,
                M_modes=M_modes,
                risk_weight=args.risk_weight_br,
                log=log,
                tag="Marginal",
            )

            br_s = sensor_best_response_to_mixture(
                env,
                robots,
                pi_R=sigma_R,
                candidate_modes=list(range(M_modes)),
                M_modes=M_modes,
                rollouts=args.rollouts_br,
                base_seed=2000 + 100 * it,
                log=log,
                tag="Marginal",
            )

        elif solver == "correlated":
            # FIX: choose conditional mixtures q(.|i) and q(.|j)
            # We only check top-K recommendations to keep it tractable.
            topK = max(1, int(args.cond_top_k))

            # Robot: check the top-K robot recommendations by sigma_R
            cand_i = [int(i) for i in np.argsort(-sigma_R) if sigma_R[int(i)] > 1e-8][:topK]
            if not cand_i:
                cand_i = [int(i) for i in np.argsort(-sigma_R)[:topK]]
            best_gain = 0.0
            br_r = robots[0]

            for i in cand_i:
                qS = conditional_sensor_given_robot(X, int(i))
                tag = f"Cond_i{i}"
                pol = robot_best_response_to_mixture(
                    env,
                    sensors,
                    pi_S=qS,
                    robots=robots,
                    M_modes=M_modes,
                    risk_weight=args.risk_weight_br,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement: E_q[U_R(pol,j)] - E_q[U_R(i,j)]
                # If pol is already in set, its payoff exists in U_R row of that policy.
                # Otherwise we simulate pol against each sensor policy j.
                if pol in robots:
                    ip = robots.index(pol)
                    dev = float(np.dot(qS, U_R[ip, :]))
                else:
                    # simulate quickly vs each sensor policy
                    dev_vals = []
                    for j in range(n):
                        vals = []
                        for kk in range(args.br_eval_rollouts):
                            seed = 777000 + 1000 * it + 100 * i + 10 * j + kk
                            st = rollout_episode(env, pol, sensors[j], M_modes=M_modes, seed=seed)
                            vals.append(st.U_R)
                        dev_vals.append(float(np.mean(vals)))
                    dev = float(np.dot(qS, np.asarray(dev_vals)))

                rec = float(np.dot(qS, U_R[i, :]))
                gain = dev - rec
                if gain > best_gain + 1e-9:
                    best_gain = gain
                    br_r = pol

            if best_gain > args.add_threshold:
                if br_r in robots:
                    log.info(f"[{solver}] Best robot deviation already in set; est_gain={best_gain:.3f}")
                else:
                    log.info(f"[{solver}] Adding robot deviation; est_gain={best_gain:.3f}")
            else:
                log.info(f"[{solver}] No robot deviation above threshold (best_gain={best_gain:.3f}).")
                # Below the threshold nothing may be added. Without this reset a
                # deviation with a small positive gain would still be appended
                # below, contradicting the log line above (the sensor side
                # already resets br_s in the same situation).
                br_r = None

            # Sensor: check top-K sensor recommendations by sigma_S
            cand_j = [int(j) for j in np.argsort(-sigma_S) if sigma_S[int(j)] > 1e-8][:topK]
            if not cand_j:
                cand_j = [int(j) for j in np.argsort(-sigma_S)[:topK]]
            best_gain_s = 0.0
            br_s = FixedModeSensorPolicy(0, name="S_dummy")

            for j in cand_j:
                qR = conditional_robot_given_sensor(X, int(j))
                tag = f"Cond_j{j}"
                polS = sensor_best_response_to_mixture(
                    env,
                    robots,
                    pi_R=qR,
                    candidate_modes=list(range(M_modes)),
                    M_modes=M_modes,
                    rollouts=args.rollouts_br,
                    base_seed=333000 + 1000 * it + 10 * j,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement for sensor
                # rec under recommendation j is E_q[U_S(i,j)]
                recS = float(np.dot(qR, U_S[:, j]))

                # dev under mode polS.mode: if already present, use its column.
                existing_col = None
                for jj, spj in enumerate(sensors):
                    if isinstance(spj, FixedModeSensorPolicy) and spj.mode == polS.mode:
                        existing_col = jj
                        break

                if existing_col is not None:
                    devS = float(np.dot(qR, U_S[:, existing_col]))
                else:
                    vals = []
                    for kk in range(args.br_eval_rollouts):
                        i_samp = int(np.random.default_rng(444 + kk).choice(len(robots), p=qR))
                        seed = 888000 + 1000 * it + 10 * j + kk
                        st = rollout_episode(env, robots[i_samp], polS, M_modes=M_modes, seed=seed)
                        vals.append(st.U_S)
                    devS = float(np.mean(vals))

                gainS = devS - recS
                if gainS > best_gain_s + 1e-9:
                    best_gain_s = gainS
                    br_s = polS

            if best_gain_s > args.add_threshold:
                if (isinstance(br_s, FixedModeSensorPolicy)) and (find_sensor_by_mode(sensors, br_s.mode) is not None):
                    log.info(f"[{solver}] Best sensor deviation already in set (mode={br_s.mode}); est_gain={best_gain_s:.3f}")
                    br_s = None
                else:
                    log.info(f"[{solver}] Adding sensor deviation; est_gain={best_gain_s:.3f}")
            else:
                log.info(f"[{solver}] No sensor deviation above threshold (best_gain={best_gain_s:.3f}).")
                br_s = None

        else:
            raise ValueError("solver must be marginal or correlated")

        # Add to sets (dedupe); br_r / br_s are None when nothing should be added.
        if br_r is not None and br_r not in robots:
            robots.append(br_r)

        if isinstance(br_s, FixedModeSensorPolicy):
            if find_sensor_by_mode(sensors, br_s.mode) is None:
                sensors.append(br_s)
            else:
                log.info(f"[{solver}] Sensor mode {br_s.mode} already present; not adding duplicate.")

        seconds = float(time.time() - t0)
        history.append(
            TrainHistoryRow(
                outer_iter=it,
                m=len(robots),
                n=len(sensors),
                nbs_obj=float(nbs.obj),
                entropy_X=float(ent),
                max_regret_R=float(regrets["max_regret_R"]),
                max_regret_S=float(regrets["max_regret_S"]),
                selfplay_UR=float(sp.mean_U_R),
                selfplay_US=float(sp.mean_U_S),
                selfplay_det=float(sp.det_rate),
                selfplay_goal=float(sp.goal_rate),
                seconds=seconds,
            )
        )

        log.info(f"[{solver}] Sets: |Pi_R|={len(robots)} |Pi_S|={len(sensors)} | iter_seconds={seconds:.2f}")

    if X is None:
        raise RuntimeError("Training produced no X")

    log.banner(f"[{solver}] Finished")
    log.info(f"[{solver}] Final robots: " + ", ".join([p.name for p in robots]))
    log.info(f"[{solver}] Final sensors: " + ", ".join([p.name for p in sensors]))
    log.info(f"[{solver}] Total time: {time.time()-t0_all:.2f}s")

    return TrainResult(solver=solver, env=env, robots=robots, sensors=sensors, X=X, history=history)


# =============================================================================
# Benchmarks + plotting
# =============================================================================

@dataclass
class BenchCell:
    """Aggregated benchmark statistics for one (robot policy, sensor policy) pairing."""

    UR: float  # mean robot utility over the benchmark episodes
    US: float  # mean sensor utility over the benchmark episodes
    det: float  # fraction of episodes in which the robot was detected
    goal: float  # fraction of episodes in which the robot reached the goal


def run_policy_matrix_benchmark(
    env: GridWorldStealthEnv,
    robot_methods: List[RobotPolicy],
    sensor_methods: List[SensorPolicy],
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> Dict[Tuple[int, int], BenchCell]:
    """Roll out every robot policy against every sensor policy and average the results.

    Args:
        env: Environment shared by all rollouts.
        robot_methods: Robot policies (row index ``i`` of the result).
        sensor_methods: Sensor policies (column index ``j`` of the result).
        M_modes: Number of sensor modes, forwarded to ``rollout_episode``.
        episodes: Rollouts per (i, j) pairing.
        base_seed: Seed offset; episode ``k`` of pairing (i, j) uses
            ``base_seed + 100000 * i + 1000 * j + k``.

    Returns:
        Mapping ``(i, j) -> BenchCell`` with mean utilities and detection/goal rates.
    """
    table: Dict[Tuple[int, int], BenchCell] = {}
    for i, robot_policy in enumerate(robot_methods):
        for j, sensor_policy in enumerate(sensor_methods):
            pair_seed = base_seed + 100000 * i + 1000 * j
            # Run the episodes in order k = 0..episodes-1, one seed per episode.
            stats = [
                rollout_episode(env, robot_policy, sensor_policy, M_modes=M_modes, seed=pair_seed + k)
                for k in range(episodes)
            ]
            n_detected = sum(int(st.detected) for st in stats)
            n_reached = sum(int(st.reached_goal) for st in stats)
            table[(i, j)] = BenchCell(
                UR=float(np.mean([st.U_R for st in stats])),
                US=float(np.mean([st.U_S for st in stats])),
                det=float(n_detected / episodes),
                goal=float(n_reached / episodes),
            )
    return table


def safe_makedirs(path: str) -> None:
    """Create directory ``path`` (including parents) if it does not exist yet.

    An empty ``path`` denotes the current working directory, which always
    exists, so it is treated as a no-op instead of letting ``os.makedirs("")``
    raise ``FileNotFoundError``.

    Args:
        path: Directory to create; may be empty.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def save_history_csv(hist: List[TrainHistoryRow], path: str) -> None:
    """Write the training history to ``path`` as CSV, one row per outer iteration.

    The columns are the annotated fields of ``TrainHistoryRow``, in declaration order.
    """
    import csv

    columns = list(TrainHistoryRow.__annotations__)
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({col: getattr(entry, col) for col in columns} for entry in hist)


def plot_training_curves(results: List[TrainResult], outdir: str) -> None:
    """Save the per-iteration training curves of each solver run as PNG files.

    Four figures are written into ``outdir`` (created if missing):
    ``curve_nbs_obj.png``, ``curve_max_regrets.png``, ``curve_entropy_X.png``
    and ``curve_selfplay_utils.png``. Each figure has one line per entry of
    ``results``; the regret and utility figures have two lines per entry.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    def add_lines(field, make_label, **line_style):
        # One line per training result: history field `field` against outer iteration.
        for res in results:
            iters = [row.outer_iter for row in res.history]
            values = [getattr(row, field) for row in res.history]
            plt.plot(iters, values, label=make_label(res), **line_style)

    def finish(ylabel, filename):
        # Shared axis labelling, legend, save and close for every figure.
        plt.xlabel("Outer iteration")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    # 1) NBS objective
    plt.figure()
    add_lines("nbs_obj", lambda res: res.solver, marker="o")
    finish("NBS objective", "curve_nbs_obj.png")

    # 2) Max conditional regrets (robot solid, sensor dashed)
    plt.figure()
    add_lines("max_regret_R", lambda res: f"{res.solver}: maxRegret_R", marker="o")
    add_lines("max_regret_S", lambda res: f"{res.solver}: maxRegret_S", marker="x", linestyle="--")
    finish("Max conditional regret", "curve_max_regrets.png")

    # 3) Entropy of X
    plt.figure()
    add_lines("entropy_X", lambda res: res.solver, marker="o")
    finish("Entropy H(X)", "curve_entropy_X.png")

    # 4) Self-play outcomes (robot solid, sensor dashed)
    plt.figure()
    add_lines("selfplay_UR", lambda res: f"{res.solver}: UR", marker="o")
    add_lines("selfplay_US", lambda res: f"{res.solver}: US", marker="x", linestyle="--")
    finish("Expected utility under X", "curve_selfplay_utils.png")


def plot_benchmark_bars(
    robot_names: List[str],
    sensor_names: List[str],
    bench: Dict[Tuple[int, int], BenchCell],
    outdir: str,
) -> None:
    """Save one bar chart per sensor for robot utility and for robot goal rate.

    For sensor index ``j`` the bars show ``bench[(i, j)]`` for every robot ``i``.
    Files are named ``bench_UR_vs_<sensor>.png`` and ``bench_goal_vs_<sensor>.png``
    and written into ``outdir`` (created if missing).
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    robot_slots = range(len(robot_names))

    def draw(j, field, ylabel, title, filename):
        # One figure: value of `field` for every robot against sensor column j.
        plt.figure(figsize=(10, 4))
        heights = [getattr(bench[(i, j)], field) for i in robot_slots]
        plt.bar(robot_slots, heights)
        plt.xticks(robot_slots, robot_names, rotation=35, ha="right")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    # All utility charts first, then all goal-rate charts.
    for j, sname in enumerate(sensor_names):
        draw(j, "UR", "Robot utility", f"Robot utility vs sensor={sname}", f"bench_UR_vs_{sname}.png")

    for j, sname in enumerate(sensor_names):
        draw(j, "goal", "Goal rate", f"Robot goal rate vs sensor={sname}", f"bench_goal_vs_{sname}.png")


# =============================================================================
# Main pipeline wrapper
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Build the default stealth game, train the requested solver(s), and report.

    Steps, driven by ``args``:
      1. Build the two-corridor grid with one sensor at each corridor gap.
      2. Train with ``args.solver`` ("both" runs marginal, then correlated),
         using a fresh environment per solver.
      3. If ``args.results_dir`` is set, write ``history_<solver>.csv`` there,
         plus the training curves when ``args.save_plots`` is also set.
      4. If ``args.run_benchmarks`` is set, benchmark the initial robot
         heuristics against a suite of sensor heuristics and log a table
         (plus bar charts when plots are enabled).
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    gap_x = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((gap_x, 2), (gap_x, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )

    def make_env() -> GridWorldStealthEnv:
        return GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    env = make_env()

    if log.k >= 1:
        log.banner("[PIPELINE] Stealth grid game")
        log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
    if log.k >= 2:
        log.debug("ASCII map:")
        print_grid_ascii(grid, sensor_cfg)

    solvers: List[str] = ["marginal", "correlated"] if args.solver == "both" else [args.solver]

    # Use a fresh env copy per solver (to avoid RNG coupling)
    results: List[TrainResult] = [
        run_training(make_env(), args, solver=name, log=log) for name in solvers
    ]

    out_dir = args.results_dir
    if out_dir:
        safe_makedirs(out_dir)
        for res in results:
            save_history_csv(res.history, os.path.join(out_dir, f"history_{res.solver}.csv"))

    want_plots = bool(args.save_plots and out_dir)
    if want_plots:
        plot_training_curves(results, outdir=out_dir)

    if not args.run_benchmarks:
        return

    # Benchmarks on the same game
    log.banner("[BENCH] Heuristics on same game")
    M_modes = len(sensor_cfg.sensors)

    # Robot heuristics: the initial policy set, without solver-added best responses.
    robots_init, _ = build_initial_policies_for_bench(env, M_modes)

    # Sensor heuristics
    sensor_methods: List[SensorPolicy] = [
        FixedModeSensorPolicy(0, "S_Mode0"),
        FixedModeSensorPolicy(1, "S_Mode1"),
        AlternatingSensorPolicy(M_modes, "S_Alternate"),
        RandomModeSensorPolicy(M_modes, seed=args.seed + 123, name="S_RandomMode"),
    ]

    bench = run_policy_matrix_benchmark(
        env,
        robot_methods=robots_init,
        sensor_methods=sensor_methods,
        M_modes=M_modes,
        episodes=args.bench_episodes,
        base_seed=555000,
    )

    # Print a compact table, grouped by sensor.
    robot_names = [policy.name for policy in robots_init]
    sensor_names = [policy.name for policy in sensor_methods]
    for j, sname in enumerate(sensor_names):
        log.info(f"[BENCH] Sensor={sname}")
        for i, rname in enumerate(robot_names):
            cell = bench[(i, j)]
            log.info(f"  Robot={rname:22s} UR={cell.UR:8.2f} goal%={100*cell.goal:5.1f} det%={100*cell.det:5.1f}")

    if want_plots:
        plot_benchmark_bars(robot_names, sensor_names, bench, outdir=out_dir)


# Helper: initial robot set for benchmark (without the solver-added BR policies)

def build_initial_policies_for_bench(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Build the heuristic robot and sensor policies used as benchmark baselines.

    Robot policies: shortest path, paths forced through the upper and lower
    corridor gaps, risk-weighted A* under a uniform mode belief and under the
    per-cell worst case over modes, an online belief-replanning policy, and a
    random policy. Sensor policies: one fixed-mode policy per mode.

    Returns:
        ``(robots, sensors)`` as two lists.
    """
    grid = env.grid

    def unit_cost(a, b):
        return 1.0

    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)

    gap_x = grid.width // 2
    via_upper = astar_via_waypoint(grid, grid.start, (gap_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (gap_x, 6), grid.goal, unit_cost)

    # Static risk maps built from the true per-mode detection probabilities:
    # expectation under a UNIFORM mode belief, and the maximum over modes.
    belief = np.full(M_modes, 1.0 / M_modes)
    n_rows, n_cols = grid.height, grid.width
    risk_uniform = np.zeros((n_rows, n_cols), dtype=float)
    risk_worst = np.zeros((n_rows, n_cols), dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = (col, row)
            if cell in grid.obstacles:
                # Obstacle cells carry NaN in both maps.
                risk_uniform[row, col] = np.nan
                risk_worst[row, col] = np.nan
            else:
                per_mode = [env.true_detection_prob(cell, m) for m in range(M_modes)]
                risk_uniform[row, col] = float(np.dot(belief, per_mode))
                risk_worst[row, col] = float(np.max(per_mode))

    path_uniform = plan_risk_weighted_path(env, risk_uniform, risk_weight=12.0)
    path_worst = plan_risk_weighted_path(env, risk_worst, risk_weight=12.0)

    robots: List[RobotPolicy] = [
        FixedPathPolicy(shortest, "R_Shortest"),
        FixedPathPolicy(via_upper, "R_UpperCorridor"),
        FixedPathPolicy(via_lower, "R_LowerCorridor"),
        FixedPathPolicy(path_uniform, "R_RiskAStar_Uniform"),
        FixedPathPolicy(path_worst, "R_RiskAStar_WorstCase"),
        OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"),
        RandomPolicy("R_Random"),
    ]

    sensors: List[SensorPolicy] = []
    for m in range(M_modes):
        sensors.append(FixedModeSensorPolicy(m, name=f"S_Mode{m}"))
    return robots, sensors


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the command-line options of the training/benchmark pipeline.

    Unknown arguments are collected rather than rejected, so the parser
    tolerates extra arguments such as those injected by a notebook kernel.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        ``(args, unknown)``: the parsed namespace and the unrecognised arguments.
    """
    p = argparse.ArgumentParser()

    p.add_argument("--seed", type=int, default=0)
    # Upper-case before the `choices` check so "--log-level debug" is accepted,
    # matching Logger, which also normalises the level name with .upper().
    p.add_argument("--log-level", type=str.upper, default="INFO", choices=["QUIET", "INFO", "DEBUG"])

    p.add_argument("--solver", type=str, default="correlated", choices=["marginal", "correlated", "both"])

    p.add_argument("--outer-iters", type=int, default=3)

    p.add_argument("--rollouts-payoff", type=int, default=20)
    p.add_argument("--rollouts-br", type=int, default=30)
    p.add_argument("--risk-weight-br", type=float, default=12.0)

    # NBS knobs
    p.add_argument("--disagreement", type=str, default="minminus", choices=["minminus", "uniform"])
    p.add_argument("--entropy-tau", type=float, default=0.0)

    # Fix mismatch knobs
    p.add_argument("--cond-top-k", type=int, default=2, help="How many top recommendations to check for conditional BRs")
    p.add_argument("--br-eval-rollouts", type=int, default=8, help="Small evaluation rollouts for new deviations")
    p.add_argument("--add-threshold", type=float, default=0.25, help="Minimum estimated conditional gain to add a deviation policy")

    # Eval episodes under joint X (self-play)
    p.add_argument("--eval-episodes", type=int, default=60)

    # Debug
    p.add_argument("--debug-rollout-pair", type=str, default="", help="Print one step-by-step rollout for i,j")

    # Outputs
    p.add_argument("--results-dir", type=str, default="results", help="Directory for csv/plots")
    p.add_argument("--save-plots", action="store_true")

    # Benchmarks
    p.add_argument("--run-benchmarks", action="store_true")
    p.add_argument("--bench-episodes", type=int, default=80)

    args, unknown = p.parse_known_args(args=argv)
    return args, unknown


def main(argv: Optional[List[str]] = None) -> None:
    """Parse CLI arguments and run the pipeline.

    Unknown arguments are reported with a warning, except when running under
    an IPython kernel. A whitespace-only ``--debug-rollout-pair`` value is
    normalised to the empty string before the pipeline starts.
    """
    args, unknown = parse_args(argv=argv)

    inside_kernel = "ipykernel" in sys.modules
    if unknown and not inside_kernel:
        print(f"[WARN] Ignoring unknown CLI args: {unknown}")

    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""

    run_pipeline(args)


# Script entry point. The ipykernel check keeps main() from running automatically
# when this code is executed inside a notebook kernel; call main(argv=[...]) there.
if __name__ == "__main__" and ("ipykernel" not in sys.modules):
    main()


# =============================================================================
# EXPERIMENTS (fixed game set + baselines + tests + presentation plots)
# =============================================================================
#
# Keep the training code above exactly as-is.
# This section adds:
#   (A) sanity tests (quick asserts)
#   (B) a fixed game-set generator (deterministic)
#   (C) an experiment runner that compares solvers + baselines on the same games
#   (D) plots + CSV outputs for slides
#
# Notebook usage:
#   run_sanity_tests()
#   run_fixed_games_experiment(outdir="results_exp", n_games=6, seeds=[0,1,2])
#

from dataclasses import asdict


@dataclass(frozen=True)
class GameInstance:
    """One immutable game of the fixed experiment set: a grid plus its sensor setup."""

    game_id: str  # identifier such as "G00_extraObs0" (game index + extra-obstacle budget)
    grid: GridConfig  # grid layout: size, start, goal and obstacle cells
    sensor_cfg: SensorConfig  # sensor positions and detection parameters
    desc: str  # short human-readable description of the game


def _is_reachable(grid: GridConfig) -> bool:
    """Return True if A* finds a start->goal path on ``grid`` with unit step cost.

    Any exception raised by the search (capped at 250,000 expansions) is
    interpreted as "not reachable".
    """

    def unit_cost(a, b):
        return 1.0

    try:
        astar_path(grid, grid.start, grid.goal, step_cost=unit_cost, max_expansions=250_000)
    except Exception:
        return False
    return True


def build_fixed_game_set(
    n_games: int = 6,
    seed: int = 123,
    width: int = 15,
    height: int = 9,
    base_radius: int = 2,
    base_p: float = 0.02,
    hotspot_p: float = 0.60,
) -> List[GameInstance]:
    """Deterministic *fixed* game set for fair comparisons.

    We generate corridor-variant grids by adding a small number of extra obstacles
    (without breaking reachability), while keeping the base corridor wall.

    Args:
        n_games: Number of game instances to produce.
        seed: Seed of the obstacle-placement RNG.
        width: Grid width passed to ``build_two_corridor_grid``.
        height: Grid height passed to ``build_two_corridor_grid``.
        base_radius: Sensor radius stored in each game's ``SensorConfig``.
        base_p: ``base_p`` stored in each game's ``SensorConfig``.
        hotspot_p: ``hotspot_p`` stored in each game's ``SensorConfig``.

    Returns:
        ``n_games`` instances; game ``k`` uses the k-th obstacle budget (the
        last budget is reused once the budget list is exhausted).

    Raises:
        RuntimeError: If more than 4000 candidate grids were tried without
            collecting enough reachable games.

    Notes:
      - Uses a fixed RNG seed => same games every run.
      - Guarantees start->goal reachability (ignoring detection risk).
    """
    max_attempts = 4000

    rng = np.random.default_rng(int(seed))

    base_grid = build_two_corridor_grid(width=width, height=height)
    wall_x = width // 2

    # Keep the two sensor centers near the corridor gaps for interpretability.
    base_sensors = ((wall_x, 2), (wall_x, 6))

    # Candidate cells for extra obstacles (avoid start/goal/sensors and the wall column).
    # This list does not depend on the attempt, so it is built once; each attempt
    # shuffles a fresh copy in the same row-major starting order, which keeps the
    # RNG stream and the generated games identical to rebuilding it every time.
    base_candidates: List[Tuple[int, int]] = [
        (x, y)
        for y in range(height)
        for x in range(width)
        if x != wall_x
        and (x, y) not in base_grid.obstacles
        and (x, y) != base_grid.start
        and (x, y) != base_grid.goal
        and (x, y) not in base_sensors
    ]

    games: List[GameInstance] = []
    attempts = 0

    # Increasing difficulty: more extra obstacles.
    # (Chosen small so it stays tractable and doesn't accidentally block corridors.)
    obstacle_budgets = [0, 2, 4, 6, 8, 10]
    while len(games) < n_games:
        attempts += 1
        if attempts > max_attempts:
            raise RuntimeError("Could not generate enough solvable game instances")

        k = len(games)
        extra_obs_target = obstacle_budgets[min(k, len(obstacle_budgets) - 1)]

        candidates = list(base_candidates)
        rng.shuffle(candidates)
        extra = set(candidates[:extra_obs_target])

        grid = GridConfig(
            width=base_grid.width,
            height=base_grid.height,
            start=base_grid.start,
            goal=base_grid.goal,
            obstacles=frozenset(set(base_grid.obstacles) | extra),
        )
        if not _is_reachable(grid):
            # Rejected layout: retry the same game index with a new shuffle.
            continue

        sensor_cfg = SensorConfig(
            sensors=base_sensors,
            radius=int(base_radius),
            base_p=float(base_p),
            hotspot_p=float(hotspot_p),
        )

        game_id = f"G{k:02d}_extraObs{extra_obs_target}"
        desc = f"corridors + {extra_obs_target} extra obstacles"
        games.append(GameInstance(game_id=game_id, grid=grid, sensor_cfg=sensor_cfg, desc=desc))

    return games


# -----------------------------------------------------------------------------
# Sanity tests
# -----------------------------------------------------------------------------

def run_sanity_tests() -> None:
    """Fast, high-signal checks so you can trust comparisons.

    Builds the default two-corridor game and asserts basic invariants of the
    path planner, the belief update, the NBS solver, the joint/marginal/
    conditional helpers, the regret computation and a single rollout.

    Raises:
        AssertionError: If any invariant is violated.
    """
    print("[TEST] Running sanity tests...")

    grid = build_two_corridor_grid()
    sensor_cfg = SensorConfig(
        sensors=((grid.width // 2, 2), (grid.width // 2, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=0)

    # A* reachability
    p = astar_path(grid, grid.start, grid.goal, step_cost=lambda a, b: 1.0)
    assert p[0] == grid.start and p[-1] == grid.goal

    # Belief update preserves simplex
    b = ModeBelief(M=2)
    env.reset(sensor_mode=0)
    out = env.step((1, 0))  # one step with action (dx, dy) = (1, 0)
    b.update(env, out["alarm"], out["pos"])
    assert np.isfinite(b.b).all()
    assert abs(float(b.b.sum()) - 1.0) < 1e-6
    assert (b.b >= -1e-12).all()

    # NBS returns simplex
    uR = np.array([-1.0, -2.0, -3.0, -4.0])
    uS = np.array([1.0, 2.0, 3.0, 4.0])
    nbs = solve_nbs(uR, uS, log=Logger("QUIET"), max_iters=50)
    x = nbs.x
    assert np.isfinite(x).all()
    assert (x >= -1e-10).all()
    assert abs(float(x.sum()) - 1.0) < 1e-6

    # Joint/marginals/conditionals are sane
    X = joint_to_matrix(x, m=2, n=2)  # the 4-entry joint viewed as a 2x2 matrix
    sigma_R, sigma_S = marginals_from_joint(X)
    assert abs(float(sigma_R.sum()) - 1.0) < 1e-6
    assert abs(float(sigma_S.sum()) - 1.0) < 1e-6
    qS = conditional_sensor_given_robot(X, 0)
    qR = conditional_robot_given_sensor(X, 0)
    assert abs(float(qS.sum()) - 1.0) < 1e-6
    assert abs(float(qR.sum()) - 1.0) < 1e-6

    # CE regrets should be >= 0 (numerical)
    UR = np.random.default_rng(0).normal(size=(3, 2))
    US = np.random.default_rng(1).normal(size=(3, 2))
    Xr = np.full((3, 2), 1.0 / 6)  # uniform joint over the 3x2 policy pairs
    reg = compute_ce_regrets(UR, US, Xr)
    assert reg["max_regret_R"] >= -1e-9
    assert reg["max_regret_S"] >= -1e-9

    # Rollout returns finite stats
    robots, sensors = build_initial_policies(env, M_modes=2)
    st = rollout_episode(env, robots[0], sensors[0], M_modes=2, seed=0)
    assert np.isfinite(st.U_R) and np.isfinite(st.U_S)

    print("[TEST] All sanity tests passed ✅")


# -----------------------------------------------------------------------------
# Fair comparisons on the same fixed game set
# -----------------------------------------------------------------------------

@dataclass
class ExperimentRow:
    """One long-format measurement of the fixed-games experiment (one CSV row)."""

    alg: str  # algorithm label, "Solver:<solver>" or "Baseline:<policy name>"
    solver: str  # solver name, or "baseline" for untrained heuristic policies
    game_id: str  # id of the GameInstance the value was measured on
    seed: int  # experiment seed used for this run
    metric: str  # metric name, e.g. "robust_UR" or "runtime_s"
    value: float  # measured value of the metric


def _evaluate_robot_mixture_against_sensor(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensor: SensorPolicy,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Monte Carlo evaluation of a mixed robot strategy against one sensor policy.

    Each episode draws a robot policy index from ``sigma_R`` (renormalised to
    sum to one) with an RNG seeded by ``base_seed``, then rolls it out against
    ``sensor`` with rollout seed ``base_seed + k``.

    Returns:
        ``StrategyEval`` with mean utilities, detection/goal rates, mean step
        count and mean accumulated true risk over the episodes.
    """
    weights = np.asarray(sigma_R, dtype=float).reshape(-1)
    weights = weights / max(float(weights.sum()), 1e-12)

    picker = np.random.default_rng(int(base_seed))

    outcomes = []
    for k in range(episodes):
        # Sample which pure robot policy plays this episode, then roll it out.
        chosen = int(picker.choice(len(robots), p=weights))
        outcomes.append(
            rollout_episode(env, robots[chosen], sensor, M_modes=M_modes, seed=int(base_seed + k))
        )

    n_detected = sum(int(st.detected) for st in outcomes)
    n_reached = sum(int(st.reached_goal) for st in outcomes)

    return StrategyEval(
        mean_U_R=float(np.mean([st.U_R for st in outcomes])),
        mean_U_S=float(np.mean([st.U_S for st in outcomes])),
        det_rate=float(n_detected / episodes),
        goal_rate=float(n_reached / episodes),
        mean_steps=float(np.mean([st.steps for st in outcomes])),
        mean_risk=float(np.mean([st.total_true_risk for st in outcomes])),
    )


def _summarize_against_sensor_suite(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensors_suite: List[SensorPolicy],
    M_modes: int,
    episodes_per_sensor: int,
    base_seed: int,
) -> Dict[str, float]:
    """Evaluate a robot mixture against every sensor in a suite and aggregate.

    Sensor ``j`` is evaluated with base seed ``base_seed + 10000 * j``.

    Returns:
        Dict with ``mean_*`` (average over sensors) and ``robust_*`` (worst
        case over sensors) entries for robot utility, goal rate and detection
        rate. The worst case is the minimum for utility and goal rate and the
        maximum for detection rate.
    """
    per_sensor = [
        _evaluate_robot_mixture_against_sensor(
            env,
            robots,
            sigma_R=sigma_R,
            sensor=sensor_policy,
            M_modes=M_modes,
            episodes=episodes_per_sensor,
            base_seed=base_seed + 10000 * j,
        )
        for j, sensor_policy in enumerate(sensors_suite)
    ]

    utilities = [ev.mean_U_R for ev in per_sensor]
    goal_rates = [ev.goal_rate for ev in per_sensor]
    det_rates = [ev.det_rate for ev in per_sensor]

    summary: Dict[str, float] = {}
    summary["mean_UR"] = float(np.mean(utilities))
    summary["robust_UR"] = float(np.min(utilities))
    summary["mean_goal"] = float(np.mean(goal_rates))
    summary["robust_goal"] = float(np.min(goal_rates))
    summary["mean_det"] = float(np.mean(det_rates))
    # For detection, "worst" means highest.
    summary["robust_det"] = float(np.max(det_rates))
    return summary


def _write_experiment_csv(rows: List[ExperimentRow], path: str) -> None:
    """Write experiment rows to ``path`` as CSV, creating the parent directory.

    Columns: alg, solver, game_id, seed, metric, value (value with 8 decimals).
    """
    import csv

    header = ["alg", "solver", "game_id", "seed", "metric", "value"]
    parent_dir = os.path.dirname(path) or "."
    safe_makedirs(parent_dir)
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [row.alg, row.solver, row.game_id, row.seed, row.metric, f"{row.value:.8f}"]
            for row in rows
        )


def _group_stats(values: List[float]) -> Dict[str, float]:
    arr = np.asarray(values, dtype=float)
    return {
        "mean": float(np.mean(arr)) if arr.size else float("nan"),
        "std": float(np.std(arr)) if arr.size else float("nan"),
        "n": int(arr.size),
    }


def plot_experiment_summary(rows: List[ExperimentRow], outdir: str) -> None:
    """Creates high-signal graphs for your presentation.

    Writes into ``outdir`` (created if missing):
      - ``exp_bar_*.png``: per-algorithm mean with std error bars for robust UR,
        mean UR, robust goal rate, robust detection rate and runtime.
      - ``exp_scatter_robustUR_vs_regretR.png``: robust robot utility against
        max conditional robot regret, one point cloud per algorithm.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    # Metrics that are collected per algorithm.
    tracked = [
        "robust_UR",
        "mean_UR",
        "robust_goal",
        "mean_goal",
        "robust_det",
        "runtime_s",
        "entropy_X",
        "max_regret_R",
        "max_regret_S",
    ]

    algs = sorted({row.alg for row in rows})

    # metric -> alg -> values, in row order
    bucket: Dict[str, Dict[str, List[float]]] = {
        metric: {alg: [] for alg in algs} for metric in tracked
    }
    for row in rows:
        per_alg = bucket.get(row.metric)
        if per_alg is not None:
            per_alg[row.alg].append(float(row.value))

    def bar_with_std(metric: str, ylabel: str, filename: str) -> None:
        # Bar chart of the per-algorithm mean with std error bars.
        plt.figure(figsize=(10, 4))
        summaries = [_group_stats(bucket[metric][alg]) for alg in algs]
        heights = [s["mean"] for s in summaries]
        errors = [s["std"] for s in summaries]
        positions = np.arange(len(algs))
        plt.bar(positions, heights, yerr=errors, capsize=4)
        plt.xticks(positions, algs, rotation=25, ha="right")
        plt.ylabel(ylabel)
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    bar_specs = (
        ("robust_UR", "Robot utility (worst-case over sensor suite)", "exp_bar_robust_UR.png"),
        ("mean_UR", "Robot utility (mean over sensor suite)", "exp_bar_mean_UR.png"),
        ("robust_goal", "Goal rate (worst-case over sensor suite)", "exp_bar_robust_goal.png"),
        ("robust_det", "Detection rate (worst-case over sensor suite)", "exp_bar_robust_det.png"),
        ("runtime_s", "Runtime (seconds)", "exp_bar_runtime.png"),
    )
    for metric, ylabel, filename in bar_specs:
        bar_with_std(metric, ylabel, filename)

    # Scatter: robustness vs (robot) regret to show the stability/optimality tradeoff.
    # Algorithms lacking either metric contribute no points.
    plt.figure(figsize=(6, 5))
    for alg in algs:
        regrets = bucket["max_regret_R"][alg]
        robust_utils = bucket["robust_UR"][alg]
        if regrets and robust_utils:
            plt.scatter(regrets, robust_utils, label=alg)
    plt.xlabel("Max conditional regret (robot)")
    plt.ylabel("Robust robot utility")
    plt.legend()
    plt.tight_layout()
    plt.savefig(os.path.join(outdir, "exp_scatter_robustUR_vs_regretR.png"), dpi=200)
    plt.close()


def run_fixed_games_experiment(
    outdir: str = "results_exp",
    n_games: int = 6,
    seeds: Optional[List[int]] = None,
    outer_iters: int = 3,
    rollouts_payoff: int = 20,
    rollouts_br: int = 30,
    eval_episodes: int = 120,
    episodes_per_sensor: int = 80,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
    cond_top_k: int = 2,
    br_eval_rollouts: int = 8,
    add_threshold: float = 0.25,
    risk_weight_br: float = 12.0,
    include_solvers: Optional[List[str]] = None,
    include_baselines: bool = True,
    log_level: str = "QUIET",
) -> None:
    """Run a clean comparison on a deterministic game set.

    For every (game, seed) pair each solver in ``include_solvers`` is trained
    and its final robot marginal is evaluated against a suite of sensor
    heuristics; optionally every baseline robot heuristic is evaluated against
    the same kind of suite.

    Args:
        outdir: Output directory for the CSV and plots (created if missing).
        n_games: Number of fixed games to generate.
        seeds: Experiment seeds; defaults to ``[0, 1, 2]``.
        outer_iters, rollouts_payoff, rollouts_br, eval_episodes, disagreement,
        entropy_tau, cond_top_k, br_eval_rollouts, add_threshold, risk_weight_br:
            Forwarded to ``run_training`` through an ``argparse.Namespace``.
        episodes_per_sensor: Evaluation rollouts per sensor in the suite.
        include_solvers: Solver names to train; defaults to
            ``["correlated", "marginal"]``.
        include_baselines: Whether to also evaluate the untrained heuristics.
        log_level: Logger level used during training.

    What you get in outdir:
      - experiments.csv  (all raw numbers)
      - exp_bar_*.png    (summary bars)
      - exp_scatter_*.png

    Recommended presentation story:
      1) Robust UR and robust goal rate vs baselines
      2) Runtime vs baselines
      3) Scatter showing tradeoff: regret (stability) vs robust UR (performance)

    NOTE: NBS is not a Nash/CE solver; higher regret is expected. That's *part of the story*.
    """

    if seeds is None:
        seeds = [0, 1, 2]
    if include_solvers is None:
        include_solvers = ["correlated", "marginal"]

    safe_makedirs(outdir)

    games = build_fixed_game_set(n_games=n_games)

    rows: List[ExperimentRow] = []

    def make_env(game: GameInstance, seed: int) -> GridWorldStealthEnv:
        return GridWorldStealthEnv(game.grid, game.sensor_cfg, fp=0.05, fn=0.10, seed=seed)

    # Sensor suite for robust evaluation (adversarial-ish)
    # You can add Alternate/Random to show generalization, but the key is Mode0/Mode1.
    def make_sensor_suite(M_modes: int, seed0: int) -> List[SensorPolicy]:
        return [
            FixedModeSensorPolicy(0, "S_Mode0"),
            FixedModeSensorPolicy(1, "S_Mode1"),
            AlternatingSensorPolicy(M_modes, "S_Alternate"),
            RandomModeSensorPolicy(M_modes, seed=seed0 + 999, name="S_RandomMode"),
        ]

    def add_row(alg: str, solver: str, game_id: str, seed: int, metric: str, value: float) -> None:
        rows.append(
            ExperimentRow(alg=alg, solver=solver, game_id=game_id, seed=seed, metric=metric, value=float(value))
        )

    for g in games:
        M_modes = len(g.sensor_cfg.sensors)
        for s in seeds:
            # -----------------
            # (1) Run solvers
            # -----------------
            for solver in include_solvers:
                # Fresh env per solver, as run_pipeline does, so that one solver's
                # results cannot depend on whatever ran before it on a shared env.
                # NOTE(review): assumes the env may carry RNG state across rollouts - confirm.
                env_solver = make_env(g, s)

                # Build an args object compatible with run_training()
                args = argparse.Namespace(
                    seed=s,
                    log_level=log_level,
                    solver=solver,
                    outer_iters=int(outer_iters),
                    rollouts_payoff=int(rollouts_payoff),
                    rollouts_br=int(rollouts_br),
                    risk_weight_br=float(risk_weight_br),
                    disagreement=str(disagreement),
                    entropy_tau=float(entropy_tau),
                    cond_top_k=int(cond_top_k),
                    br_eval_rollouts=int(br_eval_rollouts),
                    add_threshold=float(add_threshold),
                    eval_episodes=int(eval_episodes),
                    debug_rollout_pair="",
                    results_dir="",
                    save_plots=False,
                    run_benchmarks=False,
                    bench_episodes=0,
                )

                log = Logger(log_level)
                # perf_counter is monotonic, unlike wall-clock time.time().
                t0 = time.perf_counter()
                tr = run_training(env_solver, args, solver=solver, log=log)
                runtime_s = float(time.perf_counter() - t0)

                # Extract sigma_R from final X
                sigma_R, _sigma_S = marginals_from_joint(tr.X)
                suite_summary = _summarize_against_sensor_suite(
                    env_solver,
                    robots=tr.robots,
                    sigma_R=sigma_R,
                    # Fresh suite per evaluation: the random-mode sensor is seeded at
                    # construction, so sharing one instance would couple evaluations.
                    sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                    M_modes=M_modes,
                    episodes_per_sensor=int(episodes_per_sensor),
                    base_seed=100_000 + 1000 * s,
                )

                # Also keep solver-internal diagnostics (from last history row)
                last = tr.history[-1]

                alg = f"Solver:{solver}"
                for k, v in suite_summary.items():
                    add_row(alg, solver, g.game_id, s, k, v)

                add_row(alg, solver, g.game_id, s, "runtime_s", runtime_s)
                add_row(alg, solver, g.game_id, s, "entropy_X", last.entropy_X)
                add_row(alg, solver, g.game_id, s, "max_regret_R", last.max_regret_R)
                add_row(alg, solver, g.game_id, s, "max_regret_S", last.max_regret_S)
                add_row(alg, solver, g.game_id, s, "selfplay_UR", last.selfplay_UR)
                add_row(alg, solver, g.game_id, s, "selfplay_US", last.selfplay_US)

            # -----------------
            # (2) Baseline robots (no training)
            # -----------------
            if include_baselines:
                # The baselines get their own env; the policies are built on it and
                # evaluated on it, since some of them keep a reference to the env.
                env_base = make_env(g, s)
                robots_base, _ = build_initial_policies_for_bench(env_base, M_modes=M_modes)

                # enumerate() gives the position directly; list.index() would rely
                # on policy equality and could select a different, equal-comparing entry.
                for idx, rp in enumerate(robots_base):
                    sigma = np.zeros(len(robots_base), dtype=float)
                    sigma[idx] = 1.0

                    suite_summary = _summarize_against_sensor_suite(
                        env_base,
                        robots=robots_base,
                        sigma_R=sigma,
                        sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                        M_modes=M_modes,
                        episodes_per_sensor=int(episodes_per_sensor),
                        base_seed=200_000 + 1000 * s,
                    )

                    alg = f"Baseline:{rp.name}"
                    for k, v in suite_summary.items():
                        add_row(alg, "baseline", g.game_id, s, k, v)

    # Save raw numbers
    csv_path = os.path.join(outdir, "experiments.csv")
    _write_experiment_csv(rows, csv_path)

    # Produce summary plots
    plot_experiment_summary(rows, outdir=outdir)

    print(f"[EXP] Done. Wrote: {csv_path}")
    print(f"[EXP] Plots saved to: {outdir}")



In [22]:
# Quick invariant checks first; an AssertionError here stops the cell before the experiment.
run_sanity_tests()

# Compare both solvers and the heuristic baselines on 6 fixed games x 3 seeds.
# Writes experiments.csv and the exp_*.png summary plots into `outdir`.
run_fixed_games_experiment(
    outdir="results_exp",
    n_games=6,
    seeds=[0,1,2],
    outer_iters=3,
    include_solvers=["correlated","marginal"],
    include_baselines=True,
)


[TEST] Running sanity tests...
[TEST] All sanity tests passed ✅
[EXP] Done. Wrote: results_exp\experiments.csv
[EXP] Plots saved to: results_exp


In [25]:
#!/usr/bin/env python3
"""approach2_robust_correlated.py

Robust, reduced-log implementation of "Approach 2" on a stealth gridworld POMDP,
with a FIX for the conceptual mismatch:

  Old mismatch: solve a joint distribution x*(i,j) over (robot policy i, sensor policy j)
  but then compute best responses against the MARGINALS sigma_R, sigma_S.

  Fix here: treat x* as a CORRELATION DEVICE (mediator) and compute BRs against
  CONDITIONAL distributions:

      q_S(.|i) = X[i,:] / sigma_R[i]   (sensor conditional given robot recommendation i)
      q_R(.|j) = X[:,j] / sigma_S[j]   (robot conditional given sensor recommendation j)

  Then add the most profitable *deviation* policy (oracle) based on the most violated
  conditional recommendation.

This makes the PSRO-style expansion step consistent with a correlated-strategy viewpoint.

Also included:
  - Benchmark harness vs simple heuristics on the same game.
  - Saved plots (training curves + bar charts) for presentations.

Run (script):
  python approach2_robust_correlated.py --solver correlated --outer-iters 3 --save-plots

Run (notebook):
  main(argv=["--solver","both","--outer-iters","3","--save-plots"])  # ignore ipykernel args

Tips to encourage MULTIMODAL x* (useful for your "mode recovery"):
  --disagreement uniform --entropy-tau 0.02

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import argparse
import sys
import heapq
import os
import time

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """Tiny print-based logger with three verbosity tiers.

    QUIET prints nothing, INFO prints banners and info lines, DEBUG
    additionally prints debug lines.
    """

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        # Level names are matched case-insensitively.
        normalized = level.upper()
        rank = self.LEVELS.get(normalized)
        if rank is None:
            raise ValueError(f"Unknown log level: {normalized}. Use QUIET/INFO/DEBUG")
        self.level = normalized
        self.k = rank

    def banner(self, title: str) -> None:
        """Print `title` framed by two 100-character rules (INFO and above)."""
        if self.k < 1:
            return
        rule = "=" * 100
        for line in (rule, title, rule):
            print(line)

    def info(self, msg: str) -> None:
        """Print `msg` unless the logger is QUIET."""
        if self.k < 1:
            return
        print(msg)

    def debug(self, msg: str) -> None:
        """Print `msg` only at DEBUG level."""
        if self.k < 2:
            return
        print(msg)


# =============================================================================
# Types
# =============================================================================

Action = Tuple[int, int]  # (dx, dy): per-step grid displacement added to the robot position; (0, 0) means stay


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of the grid map the robot moves on."""

    width: int  # number of columns; valid x coordinates are 0..width-1
    height: int  # number of rows; valid y coordinates are 0..height-1
    start: Tuple[int, int]  # (x, y) cell the robot is placed on at reset
    goal: Tuple[int, int]  # (x, y) cell that ends the episode when reached
    obstacles: frozenset  # set of blocked (x, y) cells


@dataclass(frozen=True)
class SensorConfig:
    """Immutable description of the sensors and their detection probabilities."""

    sensors: Tuple[Tuple[int, int], ...]  # sensor centres as (x, y); index = sensor mode
    radius: int  # Manhattan radius of the hotspot around the active sensor
    base_p: float  # per-step detection probability outside the hotspot
    hotspot_p: float  # per-step detection probability inside the hotspot


class GridWorldStealthEnv:
    """Grid world with detection risk controlled by a hidden/selected sensor 'mode'.

    The active mode picks one sensor centre; cells within `radius` (Manhattan
    distance) of it have detection probability `hotspot_p`, all others
    `base_p`. After every move the environment samples a detection event and
    a noisy binary alarm observation.

    Args:
        grid: Map description (size, start, goal, obstacles).
        sensor_cfg: Sensor centres and detection probabilities.
        fp: Alarm false-positive rate, in [0, 1].
        fn: Alarm false-negative rate, in [0, 1].
        seed: Seed for the internal numpy random generator.
        max_steps: Time horizon; `step` reports `done` once `t >= max_steps`.
            Defaults to 200, the horizon that used to be hard-coded.

    Raises:
        ValueError: If any probability is outside [0, 1], `radius` is
            negative, no sensor is given, or `max_steps` is smaller than 1.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        self.max_steps = int(max_steps)

        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(int(seed))
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the internal random generator with a freshly seeded one."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """Start a new episode at the grid start cell with the given sensor mode.

        Returns:
            Dict with the initial "pos" and "t" (always 0).
        """
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = int(sensor_mode)
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """Return True if cell `p` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """Return True if cell `p` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """Per-step detection probability at cell `p` when sensor `mode` is active.

        Raises:
            ValueError: If `mode` is not a valid sensor index.
        """
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)  # Manhattan distance to the active sensor
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """Likelihood of observing `alarm` given the true detection probability.

        The alarm fires with probability `p_true * (1 - fn) + (1 - p_true) * fp`.

        Raises:
            ValueError: If `alarm` is neither 0 nor 1.
        """
        p_alarm = p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """Apply action `a`, then sample detection and the alarm observation.

        A move into a blocked or out-of-bounds cell leaves the robot in place
        but still consumes a time step. Once the robot has been detected,
        further calls return a terminal record without advancing time.

        Returns:
            Dict with keys "pos", "t", "alarm", "p_true", "detected", "done".
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1
        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = float(self.true_detection_prob(self.pos, self.sensor_mode))
        self.total_true_risk += p_true

        # The detection draw must come before the alarm draw so that seeded
        # rollouts consume random numbers in the same order as before.
        if self.rng.random() < p_true:
            self.detected = True

        # Same noise model as the one used for belief updates.
        p_alarm = self.observation_prob(1, p_true)
        alarm = 1 if (self.rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """Build a grid split by a vertical wall at x = width // 2 with openings at rows 2 and 6.

    The start is (1, height - 2) and the goal is (width - 2, 1).

    Raises:
        RuntimeError: If the start or goal cell lands on the wall.
    """
    wall_x = width // 2
    open_rows = (2, 6)
    wall_cells = frozenset((wall_x, y) for y in range(height) if y not in open_rows)

    start, goal = (1, height - 2), (width - 2, 1)
    if any(cell in wall_cells for cell in (start, goal)):
        raise RuntimeError("Start/goal blocked unexpectedly")

    return GridConfig(width=width, height=height, start=start, goal=goal, obstacles=wall_cells)


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """Print the map one row per line.

    Symbols, in priority order: R = start, G = goal, S = sensor centre,
    # = obstacle, . = free cell.
    """
    blocked = set(grid.obstacles)
    sensor_cells = set(sensor_cfg.sensors)

    def glyph(cell: Tuple[int, int]) -> str:
        # Earlier checks win when a cell belongs to several categories.
        if cell == grid.start:
            return "R"
        if cell == grid.goal:
            return "G"
        if cell in sensor_cells:
            return "S"
        if cell in blocked:
            return "#"
        return "."

    for y in range(grid.height):
        print("".join(glyph((x, y)) for x in range(grid.width)))


# =============================================================================
# Belief over modes
# =============================================================================

class ModeBelief:
    """Exact categorical belief over the sensor modes 0..M-1, stored in `b`."""

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """Create the belief.

        Args:
            M: Number of modes (must be >= 1).
            init: Optional non-negative prior weights of length M. They are
                normalised; an all-zero prior falls back to uniform.

        Raises:
            ValueError: If M < 1, or `init` has the wrong length or a
                negative entry.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")

        uniform = np.full(self.M, 1.0 / self.M)
        if init is None:
            self.b = uniform
            return

        prior = np.asarray(init, dtype=float).reshape(-1)
        if prior.shape != (self.M,):
            raise ValueError("init shape mismatch")
        if np.any(prior < 0):
            raise ValueError("init must be nonnegative")
        mass = float(prior.sum())
        self.b = prior / mass if mass > 0 else uniform

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """Bayes update of the belief from one alarm observation made at `pos`.

        The belief is left untouched when the evidence (normalising constant)
        is non-finite or smaller than `eps`.
        """
        likelihood = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, mode)) for mode in range(self.M)],
            dtype=float,
        )
        unnormalized = self.b * likelihood
        evidence = float(unnormalized.sum())
        if np.isfinite(evidence) and not (evidence < eps):
            self.b = unnormalized / evidence


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Base class for robot policies; subclasses override `act` (and optionally `reset`)."""

    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Prepare for a new episode starting at `start_pos`. The base implementation does nothing."""
        pass

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the next (dx, dy) action. Must be implemented by subclasses."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Open-loop policy that walks along a precomputed sequence of cells."""

    def __init__(self, path: List[Tuple[int, int]], name: str):
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0  # index of the path cell the robot is believed to occupy

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the cursor at the first occurrence of `start_pos` on the path (0 if absent)."""
        self._idx = self.path.index(start_pos) if start_pos in self.path else 0

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the unit move towards the next path cell, or (0, 0) when there is nothing to follow.

        If the robot is not on the expected cell, the cursor is re-synchronised
        by searching forward along the path only; when the current cell is not
        found ahead, the policy stays in place.
        """
        here = env.pos
        if self._idx >= len(self.path) - 1:
            return (0, 0)

        if self.path[self._idx] != here:
            ahead = self.path[self._idx:]
            if here not in ahead:
                return (0, 0)
            self._idx += ahead.index(here)

        target = self.path[self._idx + 1]
        move = (
            int(np.clip(target[0] - here[0], -1, 1)),
            int(np.clip(target[1] - here[1], -1, 1)),
        )
        self._idx += 1
        return move


class RandomPolicy(RobotPolicy):
    """Uniformly random walk over feasible moves, drawing from env.rng so runs are seed-reproducible."""

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Stateless policy: nothing to reset."""
        pass

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the four unit moves and 'stay' whose target cell is free."""
        cx, cy = env.pos
        feasible: List[Action] = [
            (dx, dy)
            for dx, dy in ((1, 0), (-1, 0), (0, 1), (0, -1), (0, 0))
            if env.is_free((cx + dx, cy + dy))
        ]
        if not feasible:
            return (0, 0)
        pick = int(env.rng.integers(0, len(feasible)))
        return feasible[pick]


class OnlineBeliefReplanPolicy(RobotPolicy):
    """Receding-horizon heuristic: at every step, plan with A* on the belief-weighted risk and take one move."""

    def __init__(self, env: GridWorldStealthEnv, risk_weight: float = 12.0, name: str = "R_OnlineBeliefReplan"):
        self.env = env
        self.risk_weight = float(risk_weight)
        self.name = name
        self._cached_next: Optional[Tuple[int, int]] = None

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Drop any state left over from the previous episode."""
        self._cached_next = None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the first move of the cheapest path from the current cell to the goal.

        Entering a cell costs 1 + risk_weight * E_b[p_detect(cell)], where the
        expectation is over the current belief. Any planning failure results
        in staying in place.
        """
        weights = belief.b

        def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
            # Expected detection probability of the destination cell under the belief.
            expected_risk = sum(
                float(w) * float(env.true_detection_prob(to, mode))
                for mode, w in enumerate(weights)
            )
            return 1.0 + self.risk_weight * expected_risk

        try:
            route = astar_path(env.grid, env.pos, env.grid.goal, step_cost)
            if len(route) < 2:
                return (0, 0)
            here, target = env.pos, route[1]
            return (
                int(np.clip(target[0] - here[0], -1, 1)),
                int(np.clip(target[1] - here[1], -1, 1)),
            )
        except Exception:
            return (0, 0)


class SensorPolicy:
    """Base class for sensor policies; subclasses override `select_mode` (and optionally `reset`)."""

    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Prepare for a new episode. The base implementation does nothing."""
        pass

    def select_mode(self, t: int) -> int:
        """Return the sensor mode to activate at time step `t`. Must be implemented by subclasses."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that keeps one mode active for the whole episode."""

    def __init__(self, mode: int, name: Optional[str] = None):
        self.mode = int(mode)
        # An empty or missing name falls back to a label derived from the mode.
        self.name = name if name else f"S_Mode{mode}"

    def reset(self) -> None:
        """Stateless policy: nothing to reset."""
        pass

    def select_mode(self, t: int) -> int:
        """Return the fixed mode, ignoring the time step."""
        return self.mode


class AlternatingSensorPolicy(SensorPolicy):
    """Benchmark sensor that cycles through the modes in order: 0, 1, ..., M-1, 0, ..."""

    def __init__(self, M: int, name: str = "S_Alternate"):
        self.M = int(M)
        self.name = name

    def reset(self) -> None:
        """The active mode depends only on the time step, so there is nothing to reset."""
        pass

    def select_mode(self, t: int) -> int:
        """Return the position of time step `t` within the M-cycle."""
        cycle_pos = t % self.M
        return int(cycle_pos)


class RandomModeSensorPolicy(SensorPolicy):
    """Benchmark sensor that draws a fresh uniformly random mode on every call.

    The policy owns a seeded numpy Generator. `reset` does not re-seed it, so
    the random stream continues across episodes.
    """

    def __init__(self, M: int, seed: int = 0, name: str = "S_RandomMode"):
        self.M = int(M)
        self.rng = np.random.default_rng(int(seed))
        self.name = name

    def reset(self) -> None:
        """Nothing is reset; the generator state carries over."""
        pass

    def select_mode(self, t: int) -> int:
        """Return a mode sampled uniformly from 0..M-1 (the time step is ignored)."""
        drawn = self.rng.integers(0, self.M)
        return int(drawn)


# =============================================================================
# A* (used for planning)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """A* search on the 4-connected grid with a Manhattan-distance heuristic.

    Args:
        grid: Map providing `width`, `height` and `obstacles`.
        start: Cell to start from.
        goal: Cell to reach.
        step_cost: Cost of moving from the first cell to the second.
        max_expansions: Upper bound on the number of nodes popped from the queue.

    Returns:
        List of cells from `start` to `goal`, both inclusive.

    Raises:
        RuntimeError: If the expansion budget is exceeded or the goal is unreachable.
    """
    if start == goal:
        return [start]

    offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def manhattan_to_goal(cell: Tuple[int, int]) -> float:
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    # Queue entries are (f = g + h, g, cell).
    frontier: List[Tuple[float, float, Tuple[int, int]]] = [(manhattan_to_goal(start), 0.0, start)]
    parent: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    best_cost: Dict[Tuple[int, int], float] = {start: 0.0}

    popped = 0
    while frontier:
        node = heapq.heappop(frontier)[2]
        popped += 1

        if node == goal:
            # Walk the parent links back to the start, then flip the order.
            route: List[Tuple[int, int]] = []
            walker: Optional[Tuple[int, int]] = node
            while walker is not None:
                route.append(walker)
                walker = parent[walker]
            return route[::-1]

        if popped > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        x, y = node
        for dx, dy in offsets:
            nxt = (x + dx, y + dy)
            inside = 0 <= nxt[0] < grid.width and 0 <= nxt[1] < grid.height
            if not inside or nxt in grid.obstacles:
                continue
            candidate = best_cost[node] + float(step_cost(node, nxt))
            # Keep the known route unless the new one is better by more than a tiny tolerance.
            if nxt in best_cost and not (candidate < best_cost[nxt] - 1e-12):
                continue
            best_cost[nxt] = candidate
            parent[nxt] = node
            heapq.heappush(frontier, (candidate + manhattan_to_goal(nxt), candidate, nxt))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan start -> waypoint -> goal as two A* legs and join them.

    The waypoint appears only once in the result: it is dropped from the end
    of the first leg because the second leg starts with it.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    return [*first_leg[:-1], *second_leg]


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of a single rollout."""

    steps: int  # environment time steps elapsed when the episode ended
    reached_goal: bool  # robot was on the goal cell at the end
    detected: bool  # robot was detected during the episode
    total_true_risk: float  # sum of per-step true detection probabilities
    U_R: float  # robot utility (negative cost)
    U_S: float  # sensor utility


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Play one seeded episode of `robot` against `sensor` and return its statistics.

    Per step: the sensor picks the active mode, the robot acts on its current
    belief, the environment advances, and the belief is updated from the
    alarm observed at the new position.

    Payoffs:
        U_R = -(steps + lambda_risk * total_risk + det_penalty if detected)
        U_S = (det_penalty if detected) + lambda_risk * total_risk
              - sensor_energy_per_step * steps
    """
    env.seed(seed)
    sensor.reset()
    initial_mode = sensor.select_mode(0)
    env.reset(sensor_mode=initial_mode)

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    risk_sum = 0.0
    prev_alarm: Optional[int] = None

    for _ in range(max_steps):
        # The sensor chooses the mode for this step before the robot moves.
        env.sensor_mode = sensor.select_mode(env.t)
        action = robot.act(env, belief, prev_alarm)
        result = env.step(action)

        risk_sum += float(result["p_true"])
        belief.update(env, result["alarm"], result["pos"])
        prev_alarm = int(result["alarm"])

        if step_debug:
            print(
                f"[Step] t={result['t']:3d} pos={result['pos']} a={action} p_true={result['p_true']:.3f} "
                f"alarm={result['alarm']} det={result['detected']} done={result['done']} b={belief.b.round(3)}"
            )

        if result["done"]:
            break

    at_goal = (env.pos == env.grid.goal)
    caught = bool(env.detected)
    n_steps = int(env.t)

    penalty = det_penalty if caught else 0.0
    robot_cost = n_steps + lambda_risk * risk_sum + penalty
    U_R = -float(robot_cost)
    U_S = float(penalty + lambda_risk * risk_sum - sensor_energy_per_step * n_steps)

    return EpisodeStats(
        steps=n_steps,
        reached_goal=at_goal,
        detected=caught,
        total_true_risk=float(risk_sum),
        U_R=U_R,
        U_S=U_S,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Monte Carlo estimate of the empirical payoff matrices for every policy pair.

    Each pair (i, j) is rolled out `rollouts` times with the seeds
    base_seed + 100000 * i + 1000 * j + k.

    Args:
        debug_rollout_pair: If given, the first rollout of this (i, j) pair
            prints a per-step trace.

    Returns:
        (U_R, U_S, diag): mean robot and sensor payoffs as (m, n) arrays, and
        a per-pair dict of diagnostics (detection/goal rates, mean steps,
        mean risk, payoff standard deviations).
    """
    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            trace_this_pair = (debug_rollout_pair == (i, j))
            pair_seed = base_seed + 100000 * i + 1000 * j

            episodes: List[EpisodeStats] = []
            for k in range(rollouts):
                # Only the very first rollout of the selected pair is traced.
                trace = trace_this_pair and k == 0
                episodes.append(
                    rollout_episode(env, rpol, spol, M_modes=M_modes, seed=pair_seed + k, step_debug=trace)
                )

            robot_payoffs = [ep.U_R for ep in episodes]
            sensor_payoffs = [ep.U_S for ep in episodes]
            n_detected = sum(int(ep.detected) for ep in episodes)
            n_goal = sum(int(ep.reached_goal) for ep in episodes)

            U_R[i, j] = float(np.mean(robot_payoffs))
            U_S[i, j] = float(np.mean(sensor_payoffs))

            diag[(i, j)] = {
                "det_rate": n_detected / rollouts,
                "goal_rate": n_goal / rollouts,
                "mean_steps": float(np.mean([ep.steps for ep in episodes])),
                "mean_risk": float(np.mean([ep.total_true_risk for ep in episodes])),
                "std_UR": float(np.std(robot_payoffs)),
                "std_US": float(np.std(sensor_payoffs)),
            }

    if log.k >= 1:
        log.info("[Eval] Compact payoff summary:")
        for i, rpol in enumerate(robots):
            for j, spol in enumerate(sensors):
                stats = diag[(i, j)]
                log.info(
                    f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                    f"UR={U_R[i,j]:8.3f}±{stats['std_UR']:.2f} | "
                    f"US={U_S[i,j]:8.3f}±{stats['std_US']:.2f} | "
                    f"det%={100*stats['det_rate']:5.1f} goal%={100*stats['goal_rate']:5.1f} "
                    f"steps={stats['mean_steps']:.1f} risk={stats['mean_risk']:.2f}"
                )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Project `v` onto the simplex {w : w >= 0, sum(w) = z} using the sort-and-threshold method.

    The input is flattened to 1-D. A uniform vector is returned when no valid
    threshold can be found or the thresholded vector has non-finite or
    non-positive mass.

    Raises:
        ValueError: If `v` is empty or `z` is not positive.
    """
    vec = np.asarray(v, dtype=float).reshape(-1)
    n = vec.size
    if n == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    descending = np.sort(vec)[::-1]
    running_sum = np.cumsum(descending)
    ranks = np.arange(1, n + 1)

    # Indices whose sorted value stays positive after the threshold shift.
    active = np.flatnonzero(descending * ranks > (running_sum - z))
    if active.size == 0:
        return np.full_like(vec, z / n)

    last = int(active[-1])
    theta = (running_sum[last] - z) / (last + 1.0)
    shifted = np.maximum(vec - theta, 0.0)

    mass = float(shifted.sum())
    if not np.isfinite(mass) or mass <= 0:
        return np.full_like(vec, z / n)
    # Rescale so the result sums to z despite rounding.
    return shifted * (z / mass)


@dataclass
class NBSResult:
    """Outcome of one bargaining-solver run."""

    x: np.ndarray  # final joint distribution over joint actions (flat vector on the simplex)
    obj: float  # objective value at x
    gains: Tuple[float, float]  # (robot, sensor) expected utility minus the disagreement value
    support: int  # number of entries of x larger than 1e-6


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Maximise the Nash bargaining objective over joint distributions by projected gradient ascent.

    The objective is
        log(uR.x - dR) + log(uS.x - dS) + entropy_tau * H(x)
    where x lies on the probability simplex and both gains are clamped from
    below at `kappa` before the logarithm is taken.

    Args:
        uR: Robot payoffs, flattened to one entry per joint action.
        uS: Sensor payoffs, same shape as `uR`.
        log: Logger used for progress output.
        max_iters: Maximum number of outer ascent iterations.
        alpha: Initial step size tried in every iteration.
        tol_l1: Stop once the L1 change of x between iterations drops below this.
        kappa: Lower clamp applied to the gains inside objective and gradient.
        disagreement: "minminus" uses (min payoff - 1) per player; "uniform"
            uses each player's expected payoff under the uniform joint.
        entropy_tau: Weight of the entropy bonus (0 disables it).

    Returns:
        NBSResult with the final x, its objective value, gains and support size.

    Raises:
        ValueError: If the payoff shapes differ, there are fewer than two
            joint actions, or `disagreement` is unknown.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    unif = np.full(d, 1.0 / d)

    disagreement = disagreement.lower().strip()
    if disagreement == "minminus":
        # Strictly below every payoff, so both gains are positive for any x.
        dR = float(np.min(uR) - 1.0)
        dS = float(np.min(uS) - 1.0)
    elif disagreement == "uniform":
        # With this choice the gains are exactly zero at the uniform starting
        # point, so the kappa clamp is active on the first iteration.
        dR = float(uR @ unif)
        dS = float(uS @ unif)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    x = unif.copy()

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        # Unclamped expected gains over the disagreement point.
        return float(uR @ xv - dR), float(uS @ xv - dS)

    def entropy(xv: np.ndarray) -> float:
        # Shannon entropy; entries are clipped away from 0 to keep log finite.
        xx = np.clip(xv, 1e-12, 1.0)
        return float(-np.sum(xx * np.log(xx)))

    def obj(xv: np.ndarray) -> float:
        gR, gS = gains(xv)
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        return float(np.log(gR) + np.log(gS) + entropy_tau * entropy(xv))

    def grad(xv: np.ndarray) -> np.ndarray:
        # Gradient of obj with the gains treated as clamped constants.
        gR, gS = gains(xv)
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        g = (uR / gR) + (uS / gS)
        if entropy_tau > 0:
            xx = np.clip(xv, 1e-12, 1.0)
            g += entropy_tau * (-(np.log(xx) + 1.0))
        return g

    last = obj(x)
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        g = grad(x)
        a = alpha
        improved = False
        # Backtracking: halve the step until the objective does not decrease
        # (at most 30 tries, or until the step falls below 1e-6).
        for _ in range(30):
            x_new = project_simplex(x + a * g)
            new_obj = obj(x_new)
            if new_obj >= last - 1e-12:
                improved = True
                break
            a *= 0.5
            if a < 1e-6:
                break
        if not improved:
            # No acceptable step found: keep the current x and stop.
            break

        delta = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last = new_obj

        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            gR, gS = gains(x)
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flat joint distribution into an (m, n) matrix that sums to 1.

    Rows index robot policies and columns index sensor policies (row-major
    order of `x`).

    Args:
        x: Flat vector of length m * n.
        m: Number of rows.
        n: Number of columns.

    Returns:
        (m, n) array. A vector whose total deviates from 1 is renormalised.
        When the total is non-finite or not positive there is nothing
        meaningful to rescale, so the uniform distribution is returned. This
        mirrors the fallback used by `project_simplex` and avoids handing an
        all-zero or NaN matrix to downstream sampling code.

    Raises:
        ValueError: If `x` does not have exactly m * n entries.
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size != m * n:
        raise ValueError("x size mismatch")
    X = x.reshape((m, n))
    if X.size == 0:
        # Empty matrix: nothing to normalise.
        return X

    s = float(X.sum())
    if not np.isfinite(s) or s <= 0.0:
        return np.full((m, n), 1.0 / X.size)
    if abs(s - 1.0) > 1e-6:
        # Renormalize defensively
        X = X / s
    return X


def marginals_from_joint(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the (row, column) marginals of the joint matrix, each normalised when its mass is positive."""

    def _normalized(vec: np.ndarray) -> np.ndarray:
        total = vec.sum()
        return vec / total if total > 0 else vec

    row_marginal = _normalized(X.sum(axis=1))
    col_marginal = _normalized(X.sum(axis=0))
    return row_marginal, col_marginal


def entropy_of_joint(X: np.ndarray) -> float:
    """Shannon entropy (in nats) of the joint matrix, with entries clipped to [1e-12, 1] before the log."""
    probs = np.clip(X.reshape(-1), 1e-12, 1.0)
    weighted_log = probs * np.log(probs)
    return float(-np.sum(weighted_log))


# =============================================================================
# Best responses (marginal vs correlated)
# =============================================================================

def compute_expected_risk_map_from_policy_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi: np.ndarray,
    M_modes: int,
) -> np.ndarray:
    """Expected detection probability per cell under a mixture over sensor POLICIES.

    Every sensor policy is represented by the mode it selects at t=0, which
    is exact for FixedModeSensorPolicy. The mixture weights are folded into a
    distribution over modes (normalised when its mass is positive).

    Returns:
        (height, width) array indexed as [y, x]; obstacle cells hold NaN.

    Raises:
        ValueError: If `pi` and `sensors` differ in length, or a policy
            reports a mode outside 0..M_modes-1.
    """
    weights = np.asarray(pi, dtype=float).reshape(-1)
    if weights.size != len(sensors):
        raise ValueError("mixture length mismatch")

    # Fold policy weights into a distribution over modes.
    mode_mass = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(weights, sensors):
        mode = int(policy.select_mode(0))
        if not (0 <= mode < M_modes):
            raise ValueError("invalid sensor mode")
        mode_mass[mode] += float(weight)
    total_mass = mode_mass.sum()
    if total_mass > 0:
        mode_mass = mode_mass / total_mass

    n_rows, n_cols = env.grid.height, env.grid.width
    risk = np.zeros((n_rows, n_cols), dtype=float)
    for y in range(n_rows):
        for x in range(n_cols):
            cell = (x, y)
            if cell in env.grid.obstacles:
                risk[y, x] = np.nan
            else:
                risk[y, x] = float(
                    sum(
                        float(mode_mass[mode]) * float(env.true_detection_prob(cell, mode))
                        for mode in range(M_modes)
                    )
                )

    return risk


def plan_risk_weighted_path(env: GridWorldStealthEnv, risk_map: np.ndarray, risk_weight: float) -> List[Tuple[int, int]]:
    """A* path from the grid start to the goal where entering a cell costs 1 + risk_weight * risk.

    Cells whose risk is not finite (obstacles are stored as NaN) get a
    prohibitive cost of 1e9.
    """

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        col, row = to
        cell_risk = risk_map[row, col]  # the map is indexed [y, x]
        if np.isfinite(cell_risk):
            return 1.0 + float(risk_weight) * float(cell_risk)
        return 1e9

    grid = env.grid
    return astar_path(grid, grid.start, grid.goal, step_cost)


def robot_best_response_to_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi_S: np.ndarray,
    robots: List[RobotPolicy],
    M_modes: int,
    risk_weight: float,
    log: Logger,
    tag: str,
) -> RobotPolicy:
    """Robot best response to a sensor mixture: the risk-weighted A* path as a fixed-path policy.

    Returns the existing FixedPathPolicy from `robots` when one already
    follows the same path, otherwise a new policy (the caller decides whether
    to add it). If planning fails, the first robot policy is returned.
    """
    risk = compute_expected_risk_map_from_policy_mixture(env, sensors, pi_S, M_modes=M_modes)
    try:
        path = plan_risk_weighted_path(env, risk, risk_weight=risk_weight)
    except Exception as e:
        log.info(f"[RobotBR] WARNING A* failed ({tag}): {e}")
        return robots[0]

    # Reuse a known policy rather than adding a duplicate path.
    wanted = tuple(path)
    existing = next(
        (p for p in robots if isinstance(p, FixedPathPolicy) and tuple(p.path) == wanted),
        None,
    )
    if existing is not None:
        log.info(f"[RobotBR] ({tag}) BR path already exists: {existing.name}")
        return existing

    fresh = FixedPathPolicy(path, name=f"R_BR_{tag}_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] ({tag}) Added new robot policy: {fresh.name}")
    return fresh


def sensor_best_response_to_mixture(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    pi_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    tag: str,
) -> FixedModeSensorPolicy:
    """Best fixed sensor mode against a robot mixture, estimated with common random numbers (CRN).

    Sensor payoffs are noisy because detection is a threshold event. Scoring
    each candidate mode on different random rollouts could therefore crown
    the wrong mode by chance. To make the comparison fair, every candidate is
    evaluated on the same sampled robot indices and the same episode seeds.

    Raises:
        ValueError: If `pi_R` and `robots` differ in length.
        RuntimeError: If no candidate mode is valid.
    """
    pi_R = np.asarray(pi_R, dtype=float).reshape(-1)
    if pi_R.size != len(robots):
        raise ValueError("pi_R length mismatch")

    rng = np.random.default_rng(int(base_seed))

    # Shared across candidates: which robot plays each rollout, and with which seed.
    robot_idxs = rng.choice(len(robots), size=rollouts, p=pi_R, replace=True)
    seeds = (int(base_seed) + np.arange(rollouts)).astype(int)

    best_mode: Optional[int] = None
    best_val = -1e18

    for mode in candidate_modes:
        if not (0 <= mode < M_modes):
            continue
        probe = FixedModeSensorPolicy(mode, name=f"S_BR_{tag}_Mode{mode}")

        payoffs: List[float] = [
            rollout_episode(env, robots[int(r_idx)], probe, M_modes=M_modes, seed=int(ep_seed)).U_S
            for r_idx, ep_seed in zip(robot_idxs, seeds)
        ]

        mean_u = float(np.mean(payoffs))
        if log.k >= 2:
            log.debug(f"[SensorBR] ({tag}) mode={mode} E[US]={mean_u:.3f} std={float(np.std(payoffs)):.2f}")

        # Strict comparison: the earliest candidate wins ties.
        if mean_u > best_val:
            best_val, best_mode = mean_u, mode

    if best_mode is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] ({tag}) Best mode={best_mode} E[US]={best_val:.3f}")
    return FixedModeSensorPolicy(best_mode, name=f"S_BR_{tag}_Mode{best_mode}")


def conditional_sensor_given_robot(X: np.ndarray, i: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over sensor policies given robot recommendation `i`: row i of X, normalised.

    A row whose mass is at most `eps` yields the uniform distribution.
    """
    weights = np.asarray(X[i], dtype=float)
    mass = float(weights.sum())
    if mass <= eps:
        uniform_p = 1.0 / weights.size
        return np.full_like(weights, uniform_p)
    return weights / mass


def conditional_robot_given_sensor(X: np.ndarray, j: int, eps: float = 1e-12) -> np.ndarray:
    """Distribution over robot policies given sensor recommendation `j`: column j of X, normalised.

    A column whose mass is at most `eps` yields the uniform distribution.
    """
    weights = np.asarray(X[:, j], dtype=float)
    mass = float(weights.sum())
    if mass <= eps:
        uniform_p = 1.0 / weights.size
        return np.full_like(weights, uniform_p)
    return weights / mass


def compute_ce_regrets(U_R: np.ndarray, U_S: np.ndarray, X: np.ndarray, eps: float = 1e-12) -> Dict[str, float]:
    """Correlated-equilibrium style regrets of the joint X on the current meta-game.

    For a robot recommendation i with positive marginal probability:
        regret_R(i) = max_{i'} E_{j~q(.|i)}[U_R(i',j) - U_R(i,j)]

    For a sensor recommendation j with positive marginal probability:
        regret_S(j) = max_{j'} E_{i~q(.|j)}[U_S(i,j') - U_S(i,j)]

    Recommendations whose marginal is at most `eps` are skipped.

    Returns:
        Dict with "max_regret_R", "max_regret_S", "mean_regret_R" and
        "mean_regret_S" (0.0 when no recommendation qualifies).
    """
    m, n = U_R.shape
    assert U_S.shape == (m, n)
    assert X.shape == (m, n)

    sigma_R, sigma_S = marginals_from_joint(X)

    robot_regrets: List[float] = []
    for i in range(m):
        if sigma_R[i] <= eps:
            continue
        q = conditional_sensor_given_robot(X, i, eps=eps)
        followed = float(np.dot(q, U_R[i, :]))
        # Value of the best deviation; following the recommendation is the baseline.
        best = max(followed, *(float(np.dot(q, U_R[alt, :])) for alt in range(m)))
        robot_regrets.append(best - followed)

    sensor_regrets: List[float] = []
    for j in range(n):
        if sigma_S[j] <= eps:
            continue
        q = conditional_robot_given_sensor(X, j, eps=eps)
        followed = float(np.dot(q, U_S[:, j]))
        best = max(followed, *(float(np.dot(q, U_S[:, alt])) for alt in range(n)))
        sensor_regrets.append(best - followed)

    return {
        "max_regret_R": float(max(robot_regrets) if robot_regrets else 0.0),
        "max_regret_S": float(max(sensor_regrets) if sensor_regrets else 0.0),
        "mean_regret_R": float(np.mean(robot_regrets) if robot_regrets else 0.0),
        "mean_regret_S": float(np.mean(sensor_regrets) if sensor_regrets else 0.0),
    }


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first FixedModeSensorPolicy in `pols` with the given mode, or None if there is none."""
    matches = (p for p in pols if isinstance(p, FixedModeSensorPolicy) and p.mode == mode)
    return next(matches, None)


# =============================================================================
# Policy initialization + evaluator for joint strategy
# =============================================================================

def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the starting policy sets of the meta-game.

    Robots: three fixed paths (shortest, via the wall opening at row 2, via
    the opening at row 6), a random walker and the online belief replanner.
    Sensors: one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(a: Tuple[int, int], b: Tuple[int, int]) -> float:
        # Pure shortest-path planning: every move costs the same.
        return 1.0

    wall_x = grid.width // 2
    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)
    via_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = [
        FixedPathPolicy(shortest, "R_Shortest"),
        FixedPathPolicy(via_upper, "R_UpperCorridor"),
        FixedPathPolicy(via_lower, "R_LowerCorridor"),
        RandomPolicy("R_Random"),
        OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"),
    ]
    sensors: List[SensorPolicy] = []
    for mode in range(M_modes):
        sensors.append(FixedModeSensorPolicy(mode, name=f"S_Mode{mode}"))
    return robots, sensors


@dataclass
class StrategyEval:
    """Monte Carlo averages obtained by playing a joint strategy."""

    mean_U_R: float  # mean robot utility over the evaluated episodes
    mean_U_S: float  # mean sensor utility over the evaluated episodes
    det_rate: float  # fraction of episodes in which the robot was detected
    goal_rate: float  # fraction of episodes that ended on the goal cell
    mean_steps: float  # mean episode length in steps
    mean_risk: float  # mean accumulated true detection probability per episode


def evaluate_joint_strategy(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    X: np.ndarray,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Estimate outcome statistics when policy pairs are drawn from the joint ``X``.

    For each episode a pair ``(i, j)`` is sampled with probability proportional
    to ``X[i, j]`` and ``robots[i]`` is rolled out against ``sensors[j]``.

    Args:
        env: Environment passed through to ``rollout_episode``.
        robots: Robot policies, indexed by the rows of ``X``.
        sensors: Sensor policies, indexed by the columns of ``X``.
        X: 2-D array of joint weights; it is re-normalised here.
        M_modes: Number of sensor modes, passed through to ``rollout_episode``.
        episodes: Number of rollouts; must be at least 1.
        base_seed: Seeds the pair-sampling RNG; episode ``k`` uses seed ``base_seed + k``.

    Returns:
        A ``StrategyEval`` with means and rates over the sampled episodes.

    Raises:
        ValueError: If ``episodes`` is not positive.
    """
    if episodes <= 0:
        # Previously this surfaced later as an opaque ZeroDivisionError.
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    m, n = X.shape

    # The solver output may carry tiny negative entries from floating-point
    # round-off; Generator.choice rejects any negative probability, so clip
    # before normalising. Non-negative inputs are left unchanged.
    probs = np.clip(np.asarray(X, dtype=float).reshape(-1), 0.0, None)
    probs = probs / max(float(probs.sum()), 1e-12)

    rng = np.random.default_rng(int(base_seed))

    UR = []
    US = []
    det = 0
    goal = 0
    steps_list = []
    risk_list = []

    for k in range(episodes):
        idx = int(rng.choice(m * n, p=probs))
        i, j = (int(v) for v in np.unravel_index(idx, (m, n)))
        seed = base_seed + k
        st = rollout_episode(env, robots[i], sensors[j], M_modes=M_modes, seed=seed)
        UR.append(st.U_R)
        US.append(st.U_S)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps_list.append(st.steps)
        risk_list.append(st.total_true_risk)

    return StrategyEval(
        mean_U_R=float(np.mean(UR)),
        mean_U_S=float(np.mean(US)),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean(steps_list)),
        mean_risk=float(np.mean(risk_list)),
    )


# =============================================================================
# Training loop (marginal vs correlated)
# =============================================================================

@dataclass
class TrainHistoryRow:
    """Diagnostics recorded by ``run_training`` for one outer iteration.

    The field order also defines the CSV column order written by
    ``save_history_csv``, which reads ``__annotations__``.
    """

    outer_iter: int  # 1-based outer-iteration index
    m: int  # size of the robot policy set after this iteration's expansion step
    n: int  # size of the sensor policy set after this iteration's expansion step
    nbs_obj: float  # objective value reported by solve_nbs (nbs.obj)
    entropy_X: float  # value returned by entropy_of_joint(X)
    max_regret_R: float  # "max_regret_R" entry of compute_ce_regrets
    max_regret_S: float  # "max_regret_S" entry of compute_ce_regrets
    selfplay_UR: float  # mean robot utility from evaluate_joint_strategy
    selfplay_US: float  # mean sensor utility from evaluate_joint_strategy
    selfplay_det: float  # detection rate from evaluate_joint_strategy
    selfplay_goal: float  # goal-reach rate from evaluate_joint_strategy
    seconds: float  # wall-clock duration of the iteration


@dataclass
class TrainResult:
    """Bundle returned by ``run_training`` for one solver run."""

    solver: str  # solver label the run was started with ("marginal" or "correlated")
    env: GridWorldStealthEnv  # environment the run was trained on
    robots: List[RobotPolicy]  # robot policy set after the last iteration
    sensors: List[SensorPolicy]  # sensor policy set after the last iteration
    # Joint distribution from the last outer iteration. It is computed before that
    # iteration's expansion step, so policies appended in the final iteration have
    # no row/column in X.
    X: np.ndarray
    history: List[TrainHistoryRow]  # one row per outer iteration, in order


def run_training(env: GridWorldStealthEnv, args: argparse.Namespace, solver: str, log: Logger) -> TrainResult:
    """Run the outer policy-expansion loop for one solver variant.

    Each outer iteration:
      1. estimates the payoff matrices for the current policy sets (``evaluate_payoffs``),
      2. solves for a joint distribution over policy pairs (``solve_nbs``),
      3. logs diagnostics (top joint actions, marginals, CE regrets, entropy, self-play),
      4. computes best responses and appends new policies to the sets.

    With ``solver="marginal"`` the best responses are computed against the
    marginals of the joint distribution. With ``solver="correlated"`` they are
    computed against the conditionals of the top ``args.cond_top_k``
    recommendations, and a deviation is only added when its estimated gain
    exceeds ``args.add_threshold``.

    Args:
        env: Environment used for all rollouts.
        args: Parsed options. Fields read here: ``debug_rollout_pair``,
            ``outer_iters``, ``rollouts_payoff``, ``disagreement``, ``entropy_tau``,
            ``eval_episodes``, ``risk_weight_br``, ``rollouts_br``, ``cond_top_k``,
            ``br_eval_rollouts``, ``add_threshold``.
        solver: ``"marginal"`` or ``"correlated"``.
        log: Logger used for all progress output.

    Returns:
        A ``TrainResult`` holding the final policy sets, the last joint matrix
        and the per-iteration history.

    Raises:
        ValueError: If ``solver`` is not a known variant (raised at the
            best-response step of the first iteration).
        RuntimeError: If no outer iteration ran, so no joint matrix exists.
    """
    t0_all = time.time()

    M_modes = len(env.sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    # Optional: reduce initial set if you want smaller games.
    # (We keep it as-is for benchmarks.)

    if log.k >= 1:
        log.info(f"[{solver}] Initial robots: " + ", ".join([p.name for p in robots]))
        log.info(f"[{solver}] Initial sensors: " + ", ".join([p.name for p in sensors]))

    # Optional "i,j" pair for which a single step-by-step rollout is printed.
    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        if len(parts) == 2:
            debug_pair = (int(parts[0]), int(parts[1]))
            log.info(f"[{solver}] Will print one step-by-step rollout for pair {debug_pair} (only once).")

    history: List[TrainHistoryRow] = []
    X = None

    for it in range(1, args.outer_iters + 1):
        t0 = time.time()
        log.banner(f"[{solver}] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, _diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        # The debug rollout is only requested for the first payoff evaluation.
        debug_pair = None

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )

        m, n = U_R.shape
        X = joint_to_matrix(nbs.x, m, n)
        sigma_R, sigma_S = marginals_from_joint(X)

        # Print top joint actions
        top = np.argsort(-X.reshape(-1))[:min(5, X.size)]
        log.info(f"[{solver}] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[{solver}] sigma_R={sigma_R.round(3)}")
        log.info(f"[{solver}] sigma_S={sigma_S.round(3)}")

        # Stability diagnostics
        regrets = compute_ce_regrets(U_R, U_S, X)
        ent = entropy_of_joint(X)

        # Self-play evaluation under joint X
        sp = evaluate_joint_strategy(
            env,
            robots,
            sensors,
            X,
            M_modes=M_modes,
            episodes=args.eval_episodes,
            base_seed=9000 + 100 * it,
        )

        log.info(
            f"[{solver}] CE-regrets: maxR={regrets['max_regret_R']:.3f} maxS={regrets['max_regret_S']:.3f} | "
            f"SelfPlay: UR={sp.mean_U_R:.2f} US={sp.mean_U_S:.2f} det%={100*sp.det_rate:.1f} goal%={100*sp.goal_rate:.1f} | "
            f"H(X)={ent:.3f}"
        )

        # Best-response expansion
        if solver == "marginal":
            br_r = robot_best_response_to_mixture(
                env,
                sensors,
                pi_S=sigma_S,
                robots=robots,
                M_modes=M_modes,
                risk_weight=args.risk_weight_br,
                log=log,
                tag="Marginal",
            )

            br_s = sensor_best_response_to_mixture(
                env,
                robots,
                pi_R=sigma_R,
                candidate_modes=list(range(M_modes)),
                M_modes=M_modes,
                rollouts=args.rollouts_br,
                base_seed=2000 + 100 * it,
                log=log,
                tag="Marginal",
            )

        elif solver == "correlated":
            # FIX: choose conditional mixtures q(.|i) and q(.|j)
            # We only check top-K recommendations to keep it tractable.
            topK = max(1, int(args.cond_top_k))

            # Robot: check the top-K robot recommendations by sigma_R
            cand_i = [int(i) for i in np.argsort(-sigma_R) if sigma_R[int(i)] > 1e-8][:topK]
            if not cand_i:
                cand_i = [int(i) for i in np.argsort(-sigma_R)[:topK]]
            best_gain = 0.0
            br_r = robots[0]

            for i in cand_i:
                qS = conditional_sensor_given_robot(X, int(i))
                tag = f"Cond_i{i}"
                pol = robot_best_response_to_mixture(
                    env,
                    sensors,
                    pi_S=qS,
                    robots=robots,
                    M_modes=M_modes,
                    risk_weight=args.risk_weight_br,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement: E_q[U_R(pol,j)] - E_q[U_R(i,j)]
                # If pol is already in set, its payoff exists in U_R row of that policy.
                # Otherwise we simulate pol against each sensor policy j.
                if pol in robots:
                    ip = robots.index(pol)
                    dev = float(np.dot(qS, U_R[ip, :]))
                else:
                    # simulate quickly vs each sensor policy
                    dev_vals = []
                    for j in range(n):
                        vals = []
                        for kk in range(args.br_eval_rollouts):
                            seed = 777000 + 1000 * it + 100 * i + 10 * j + kk
                            st = rollout_episode(env, pol, sensors[j], M_modes=M_modes, seed=seed)
                            vals.append(st.U_R)
                        dev_vals.append(float(np.mean(vals)))
                    dev = float(np.dot(qS, np.asarray(dev_vals)))

                rec = float(np.dot(qS, U_R[i, :]))
                gain = dev - rec
                if gain > best_gain + 1e-9:
                    best_gain = gain
                    br_r = pol

            if best_gain > args.add_threshold:
                if br_r in robots:
                    log.info(f"[{solver}] Best robot deviation already in set; est_gain={best_gain:.3f}")
                else:
                    log.info(f"[{solver}] Adding robot deviation; est_gain={best_gain:.3f}")
            else:
                log.info(f"[{solver}] No robot deviation above threshold (best_gain={best_gain:.3f}).")
                # Discard the candidate, mirroring the sensor branch below. Without
                # this, a deviation with 0 < gain <= threshold was still appended
                # to `robots` even though the log said it was rejected.
                br_r = None

            # Sensor: check top-K sensor recommendations by sigma_S
            cand_j = [int(j) for j in np.argsort(-sigma_S) if sigma_S[int(j)] > 1e-8][:topK]
            if not cand_j:
                cand_j = [int(j) for j in np.argsort(-sigma_S)[:topK]]
            best_gain_s = 0.0
            br_s = FixedModeSensorPolicy(0, name="S_dummy")

            for j in cand_j:
                qR = conditional_robot_given_sensor(X, int(j))
                tag = f"Cond_j{j}"
                polS = sensor_best_response_to_mixture(
                    env,
                    robots,
                    pi_R=qR,
                    candidate_modes=list(range(M_modes)),
                    M_modes=M_modes,
                    rollouts=args.rollouts_br,
                    base_seed=333000 + 1000 * it + 10 * j,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement for sensor
                # rec under recommendation j is E_q[U_S(i,j)]
                recS = float(np.dot(qR, U_S[:, j]))

                # dev under mode polS.mode: if already present, use its column.
                existing_col = None
                for jj, spj in enumerate(sensors):
                    if isinstance(spj, FixedModeSensorPolicy) and spj.mode == polS.mode:
                        existing_col = jj
                        break

                if existing_col is not None:
                    devS = float(np.dot(qR, U_S[:, existing_col]))
                else:
                    vals = []
                    for kk in range(args.br_eval_rollouts):
                        i_samp = int(np.random.default_rng(444 + kk).choice(len(robots), p=qR))
                        seed = 888000 + 1000 * it + 10 * j + kk
                        st = rollout_episode(env, robots[i_samp], polS, M_modes=M_modes, seed=seed)
                        vals.append(st.U_S)
                    devS = float(np.mean(vals))

                gainS = devS - recS
                if gainS > best_gain_s + 1e-9:
                    best_gain_s = gainS
                    br_s = polS

            if best_gain_s > args.add_threshold:
                if (isinstance(br_s, FixedModeSensorPolicy)) and (find_sensor_by_mode(sensors, br_s.mode) is not None):
                    log.info(f"[{solver}] Best sensor deviation already in set (mode={br_s.mode}); est_gain={best_gain_s:.3f}")
                    br_s = None
                else:
                    log.info(f"[{solver}] Adding sensor deviation; est_gain={best_gain_s:.3f}")
            else:
                log.info(f"[{solver}] No sensor deviation above threshold (best_gain={best_gain_s:.3f}).")
                br_s = None

        else:
            raise ValueError("solver must be marginal or correlated")

        # Add to sets (dedupe); a None best response means "nothing to add".
        if br_r is not None and br_r not in robots:
            robots.append(br_r)

        if isinstance(br_s, FixedModeSensorPolicy):
            if find_sensor_by_mode(sensors, br_s.mode) is None:
                sensors.append(br_s)
            else:
                log.info(f"[{solver}] Sensor mode {br_s.mode} already present; not adding duplicate.")

        seconds = float(time.time() - t0)
        history.append(
            TrainHistoryRow(
                outer_iter=it,
                m=len(robots),
                n=len(sensors),
                nbs_obj=float(nbs.obj),
                entropy_X=float(ent),
                max_regret_R=float(regrets["max_regret_R"]),
                max_regret_S=float(regrets["max_regret_S"]),
                selfplay_UR=float(sp.mean_U_R),
                selfplay_US=float(sp.mean_U_S),
                selfplay_det=float(sp.det_rate),
                selfplay_goal=float(sp.goal_rate),
                seconds=seconds,
            )
        )

        log.info(f"[{solver}] Sets: |Pi_R|={len(robots)} |Pi_S|={len(sensors)} | iter_seconds={seconds:.2f}")

    if X is None:
        raise RuntimeError("Training produced no X")

    log.banner(f"[{solver}] Finished")
    log.info(f"[{solver}] Final robots: " + ", ".join([p.name for p in robots]))
    log.info(f"[{solver}] Final sensors: " + ", ".join([p.name for p in sensors]))
    log.info(f"[{solver}] Total time: {time.time()-t0_all:.2f}s")

    return TrainResult(solver=solver, env=env, robots=robots, sensors=sensors, X=X, history=history)


# =============================================================================
# Benchmarks + plotting
# =============================================================================

@dataclass
class BenchCell:
    """Benchmark summary for one (robot policy, sensor policy) pairing."""

    UR: float  # mean robot utility over the benchmark episodes
    US: float  # mean sensor utility over the benchmark episodes
    det: float  # fraction of episodes in which the robot was detected
    goal: float  # fraction of episodes in which the robot reached the goal


def run_policy_matrix_benchmark(
    env: GridWorldStealthEnv,
    robot_methods: List[RobotPolicy],
    sensor_methods: List[SensorPolicy],
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> Dict[Tuple[int, int], BenchCell]:
    """Roll out every robot policy against every sensor policy.

    Pair ``(i, j)`` runs ``episodes`` rollouts with seeds
    ``base_seed + 100000 * i + 1000 * j + k`` for ``k`` in ``range(episodes)``.

    Returns:
        A dict keyed by ``(robot_index, sensor_index)`` holding a ``BenchCell``
        with mean utilities and detection / goal rates for that pairing.
    """
    table: Dict[Tuple[int, int], BenchCell] = {}
    for i, robot_policy in enumerate(robot_methods):
        for j, sensor_policy in enumerate(sensor_methods):
            pair_seed = base_seed + 100000 * i + 1000 * j
            outcomes = [
                rollout_episode(env, robot_policy, sensor_policy, M_modes=M_modes, seed=pair_seed + k)
                for k in range(episodes)
            ]
            n_detected = sum(int(o.detected) for o in outcomes)
            n_goal = sum(int(o.reached_goal) for o in outcomes)
            table[(i, j)] = BenchCell(
                UR=float(np.mean([o.U_R for o in outcomes])),
                US=float(np.mean([o.U_S for o in outcomes])),
                det=float(n_detected / episodes),
                goal=float(n_goal / episodes),
            )
    return table


def safe_makedirs(path: str) -> None:
    """Create directory ``path`` (including parents) if it does not exist yet.

    An empty ``path`` is treated as the current directory and is a no-op;
    ``os.makedirs("")`` would otherwise raise ``FileNotFoundError``.
    """
    if not path:
        return
    os.makedirs(path, exist_ok=True)


def save_history_csv(hist: List[TrainHistoryRow], path: str) -> None:
    """Write the training history to ``path`` as CSV.

    The header is the ``TrainHistoryRow`` field names in declaration order,
    followed by one line per entry in ``hist``.
    """
    import csv

    columns = list(TrainHistoryRow.__annotations__)
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {column: getattr(entry, column) for column in columns}
            for entry in hist
        )


def plot_training_curves(results: List[TrainResult], outdir: str) -> None:
    """Save four per-iteration training-curve figures into ``outdir``.

    Files written: ``curve_nbs_obj.png``, ``curve_max_regrets.png``,
    ``curve_entropy_X.png`` and ``curve_selfplay_utils.png``. Each figure
    overlays one line per entry of ``results`` for each plotted field.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    def draw_figure(series, ylabel, filename):
        # `series` is a list of (history field, label function, extra plot kwargs).
        plt.figure()
        for field, make_label, line_style in series:
            for r in results:
                xs = [h.outer_iter for h in r.history]
                ys = [getattr(h, field) for h in r.history]
                plt.plot(xs, ys, label=make_label(r), **line_style)
        plt.xlabel("Outer iteration")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    solid = {"marker": "o"}
    dashed = {"marker": "x", "linestyle": "--"}

    # 1) NBS objective
    draw_figure(
        [("nbs_obj", lambda r: r.solver, solid)],
        "NBS objective",
        "curve_nbs_obj.png",
    )

    # 2) Max conditional regrets
    draw_figure(
        [
            ("max_regret_R", lambda r: f"{r.solver}: maxRegret_R", solid),
            ("max_regret_S", lambda r: f"{r.solver}: maxRegret_S", dashed),
        ],
        "Max conditional regret",
        "curve_max_regrets.png",
    )

    # 3) Entropy of X
    draw_figure(
        [("entropy_X", lambda r: r.solver, solid)],
        "Entropy H(X)",
        "curve_entropy_X.png",
    )

    # 4) Self-play outcomes
    draw_figure(
        [
            ("selfplay_UR", lambda r: f"{r.solver}: UR", solid),
            ("selfplay_US", lambda r: f"{r.solver}: US", dashed),
        ],
        "Expected utility under X",
        "curve_selfplay_utils.png",
    )


def plot_benchmark_bars(
    robot_names: List[str],
    sensor_names: List[str],
    bench: Dict[Tuple[int, int], BenchCell],
    outdir: str,
) -> None:
    """Save per-sensor bar charts of robot utility and goal rate into ``outdir``.

    For each sensor name two files are written, ``bench_UR_vs_<sensor>.png`` and
    ``bench_goal_vs_<sensor>.png``, with one bar per robot. ``bench`` is indexed
    by ``(robot_index, sensor_index)``.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    robot_slots = range(len(robot_names))

    def draw_bars(sensor_idx, cell_attr, ylabel, title, filename):
        plt.figure(figsize=(10, 4))
        heights = [getattr(bench[(i, sensor_idx)], cell_attr) for i in robot_slots]
        plt.bar(robot_slots, heights)
        plt.xticks(robot_slots, robot_names, rotation=35, ha="right")
        plt.ylabel(ylabel)
        plt.title(title)
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    # For each sensor, bar chart of robot UR
    for j, sname in enumerate(sensor_names):
        draw_bars(j, "UR", "Robot utility", f"Robot utility vs sensor={sname}", f"bench_UR_vs_{sname}.png")

    # For each sensor, bar chart of robot goal rate
    for j, sname in enumerate(sensor_names):
        draw_bars(j, "goal", "Goal rate", f"Robot goal rate vs sensor={sname}", f"bench_goal_vs_{sname}.png")


# =============================================================================
# Main pipeline wrapper
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Build the default game, train the requested solver(s), then save and benchmark.

    Steps, in order:
      1. build the two-corridor grid, the sensor configuration and an environment;
      2. run ``run_training`` once per requested solver, each on a fresh environment;
      3. if ``args.results_dir`` is set, write one history CSV per solver
         (and the training curves when ``args.save_plots`` is set);
      4. if ``args.run_benchmarks`` is set, benchmark heuristic robot policies
         against a fixed suite of sensor policies and log / plot the table.
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    middle_column = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((middle_column, 2), (middle_column, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )

    def make_env():
        # Every environment in the pipeline shares the same construction settings.
        return GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    env = make_env()

    if log.k >= 1:
        log.banner("[PIPELINE] Stealth grid game")
        log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
        if log.k >= 2:
            log.debug("ASCII map:")
            print_grid_ascii(grid, sensor_cfg)

    solver_names: List[str] = ["marginal", "correlated"] if args.solver == "both" else [args.solver]

    # Use a fresh env copy per solver (to avoid RNG coupling)
    results: List[TrainResult] = [
        run_training(make_env(), args, solver=name, log=log)
        for name in solver_names
    ]

    if args.results_dir:
        safe_makedirs(args.results_dir)
        for r in results:
            save_history_csv(r.history, os.path.join(args.results_dir, f"history_{r.solver}.csv"))
        if args.save_plots:
            plot_training_curves(results, outdir=args.results_dir)

    # Benchmarks on the same game
    if not args.run_benchmarks:
        return

    log.banner("[BENCH] Heuristics on same game")
    M_modes = len(sensor_cfg.sensors)

    # Robot heuristics: the benchmark robot set; its sensor list is not used here.
    bench_robots, _bench_sensors = build_initial_policies_for_bench(env, M_modes)

    # Sensor heuristics
    sensor_methods: List[SensorPolicy] = [
        FixedModeSensorPolicy(0, "S_Mode0"),
        FixedModeSensorPolicy(1, "S_Mode1"),
        AlternatingSensorPolicy(M_modes, "S_Alternate"),
        RandomModeSensorPolicy(M_modes, seed=args.seed + 123, name="S_RandomMode"),
    ]

    bench = run_policy_matrix_benchmark(
        env,
        robot_methods=bench_robots,
        sensor_methods=sensor_methods,
        M_modes=M_modes,
        episodes=args.bench_episodes,
        base_seed=555000,
    )

    # Print a compact table
    robot_names = [policy.name for policy in bench_robots]
    sensor_names = [policy.name for policy in sensor_methods]
    for j, sname in enumerate(sensor_names):
        log.info(f"[BENCH] Sensor={sname}")
        for i, rname in enumerate(robot_names):
            cell = bench[(i, j)]
            log.info(f"  Robot={rname:22s} UR={cell.UR:8.2f} goal%={100*cell.goal:5.1f} det%={100*cell.det:5.1f}")

    if args.save_plots and args.results_dir:
        plot_benchmark_bars(robot_names, sensor_names, bench, outdir=args.results_dir)


# Helper: initial robot set for benchmark (without the solver-added BR policies)

def build_initial_policies_for_bench(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Build the heuristic robot set (and fixed-mode sensors) used for benchmarking.

    On top of the three unit-cost A* paths, two static risk-aware paths are
    planned with ``plan_risk_weighted_path``: one on the risk map averaged
    uniformly over the modes and one on the per-cell maximum over the modes.
    Obstacle cells are NaN in both risk maps.

    Returns:
        ``(robots, sensors)`` where ``sensors`` holds one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(a, b):
        return 1.0

    path_shortest = astar_path(grid, grid.start, grid.goal, unit_cost)

    middle_column = grid.width // 2
    path_upper = astar_via_waypoint(grid, grid.start, (middle_column, 2), grid.goal, unit_cost)
    path_lower = astar_via_waypoint(grid, grid.start, (middle_column, 6), grid.goal, unit_cost)

    # Static risk-aware A* under UNIFORM mode belief
    mode_weights = np.full(M_modes, 1.0 / M_modes)

    # Risk maps start as all-NaN; only free cells receive a value, so obstacle
    # cells stay NaN.
    n_rows, n_cols = grid.height, grid.width
    risk_uniform = np.full((n_rows, n_cols), np.nan, dtype=float)
    risk_worst = np.full((n_rows, n_cols), np.nan, dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = (col, row)
            if cell in grid.obstacles:
                continue
            per_mode = [env.true_detection_prob(cell, mode) for mode in range(M_modes)]
            risk_uniform[row, col] = float(np.dot(mode_weights, per_mode))
            risk_worst[row, col] = float(np.max(per_mode))

    path_uniform = plan_risk_weighted_path(env, risk_uniform, risk_weight=12.0)
    path_worst = plan_risk_weighted_path(env, risk_worst, risk_weight=12.0)

    robots: List[RobotPolicy] = []
    robots.append(FixedPathPolicy(path_shortest, "R_Shortest"))
    robots.append(FixedPathPolicy(path_upper, "R_UpperCorridor"))
    robots.append(FixedPathPolicy(path_lower, "R_LowerCorridor"))
    robots.append(FixedPathPolicy(path_uniform, "R_RiskAStar_Uniform"))
    robots.append(FixedPathPolicy(path_worst, "R_RiskAStar_WorstCase"))
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))
    robots.append(RandomPolicy("R_Random"))

    sensors: List[SensorPolicy] = []
    for mode_idx in range(M_modes):
        sensors.append(FixedModeSensorPolicy(mode_idx, name=f"S_Mode{mode_idx}"))

    return robots, sensors


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the pipeline's command-line options.

    Unrecognised arguments are not an error; they are returned separately so
    the caller can decide what to do with them (e.g. notebook kernels inject
    their own arguments).

    Args:
        argv: Argument list to parse; ``None`` lets argparse read ``sys.argv``.

    Returns:
        ``(namespace, unknown)`` where ``unknown`` lists the unrecognised arguments.
    """
    option_table = [
        ("--seed", dict(type=int, default=0)),
        ("--log-level", dict(type=str, default="INFO", choices=["QUIET", "INFO", "DEBUG"])),
        ("--solver", dict(type=str, default="correlated", choices=["marginal", "correlated", "both"])),
        ("--outer-iters", dict(type=int, default=3)),
        ("--rollouts-payoff", dict(type=int, default=20)),
        ("--rollouts-br", dict(type=int, default=30)),
        ("--risk-weight-br", dict(type=float, default=12.0)),
        # NBS knobs
        ("--disagreement", dict(type=str, default="minminus", choices=["minminus", "uniform"])),
        ("--entropy-tau", dict(type=float, default=0.0)),
        # Fix mismatch knobs
        ("--cond-top-k", dict(type=int, default=2, help="How many top recommendations to check for conditional BRs")),
        ("--br-eval-rollouts", dict(type=int, default=8, help="Small evaluation rollouts for new deviations")),
        ("--add-threshold", dict(type=float, default=0.25, help="Minimum estimated conditional gain to add a deviation policy")),
        # Eval episodes under joint X (self-play)
        ("--eval-episodes", dict(type=int, default=60)),
        # Debug
        ("--debug-rollout-pair", dict(type=str, default="", help="Print one step-by-step rollout for i,j")),
        # Outputs
        ("--results-dir", dict(type=str, default="results", help="Directory for csv/plots")),
        ("--save-plots", dict(action="store_true")),
        # Benchmarks
        ("--run-benchmarks", dict(action="store_true")),
        ("--bench-episodes", dict(type=int, default=80)),
    ]

    parser = argparse.ArgumentParser()
    for flag, spec in option_table:
        parser.add_argument(flag, **spec)

    namespace, leftovers = parser.parse_known_args(args=argv)
    return namespace, leftovers


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: parse the options and run the pipeline.

    Unknown arguments trigger a warning, except when running under an IPython
    kernel. A whitespace-only ``--debug-rollout-pair`` value is normalised to
    the empty string.
    """
    args, unknown = parse_args(argv=argv)

    inside_kernel = "ipykernel" in sys.modules
    if unknown and not inside_kernel:
        print(f"[WARN] Ignoring unknown CLI args: {unknown}")

    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""

    run_pipeline(args)


# Run the CLI only when executed as a script, never from inside an IPython kernel.
if __name__ == "__main__":
    if "ipykernel" not in sys.modules:
        main()


# =============================================================================
# EXPERIMENTS (fixed game set + baselines + tests + presentation plots)
# =============================================================================
#
# Keep the training code above exactly as-is.
# This section adds:
#   (A) sanity tests (quick asserts)
#   (B) a fixed game-set generator (deterministic)
#   (C) an experiment runner that compares solvers + baselines on the same games
#   (D) plots + CSV outputs for slides
#
# Notebook usage:
#   run_sanity_tests()
#   run_fixed_games_experiment(outdir="results_exp", n_games=6, seeds=[0,1,2])
#

from dataclasses import asdict


@dataclass(frozen=True)
class GameInstance:
    """One immutable game of the fixed experiment set."""

    game_id: str  # identifier such as "G00_extraObs0" (index plus obstacle budget)
    grid: GridConfig  # grid layout including any extra obstacles
    sensor_cfg: SensorConfig  # sensor placement and detection parameters
    desc: str  # short human-readable description of the game


def _is_reachable(grid: GridConfig) -> bool:
    """Report whether unit-cost A* can plan a path from ``grid.start`` to ``grid.goal``.

    Any exception raised by the planner (search capped at 250,000 expansions)
    is interpreted as "not reachable".
    """
    try:
        astar_path(grid, grid.start, grid.goal, step_cost=lambda a, b: 1.0, max_expansions=250_000)
    except Exception:
        return False
    return True


def build_fixed_game_set(
    n_games: int = 6,
    seed: int = 123,
    width: int = 15,
    height: int = 9,
    base_radius: int = 2,
    base_p: float = 0.02,
    hotspot_p: float = 0.60,
) -> List[GameInstance]:
    """Deterministic *fixed* game set for fair comparisons.

    We generate corridor-variant grids by adding a small number of extra obstacles
    (without breaking reachability), while keeping the base corridor wall.

    Notes:
      - Uses a fixed RNG seed => same games every run.
      - Guarantees start->goal reachability (ignoring detection risk).

    Args:
        n_games: Number of game instances to generate.
        seed: Seed of the RNG that shuffles the obstacle candidates.
        width: Grid width passed to ``build_two_corridor_grid``.
        height: Grid height passed to ``build_two_corridor_grid``.
        base_radius: Sensor radius used for every game.
        base_p: Sensor ``base_p`` used for every game.
        hotspot_p: Sensor ``hotspot_p`` used for every game.

    Returns:
        ``n_games`` instances; game ``k`` uses the ``k``-th obstacle budget
        (the last budget is reused once the list is exhausted).

    Raises:
        RuntimeError: If more than 4000 attempts are needed.
    """
    rng = np.random.default_rng(int(seed))

    base_grid = build_two_corridor_grid(width=width, height=height)
    wall_x = width // 2

    # Keep the two sensor centers near the corridor gaps for interpretability.
    base_sensors = ((wall_x, 2), (wall_x, 6))

    # Candidate cells for extra obstacles (avoid start/goal/sensors and the wall column).
    # This scan does not depend on the attempt, so it is done once, not per attempt.
    reserved = (base_grid.start, base_grid.goal) + tuple(base_sensors)
    base_candidates: List[Tuple[int, int]] = [
        (x, y)
        for y in range(height)
        for x in range(width)
        if x != wall_x
        and (x, y) not in base_grid.obstacles
        and (x, y) not in reserved
    ]

    games: List[GameInstance] = []
    attempts = 0
    max_attempts = 4000

    # Increasing difficulty: more extra obstacles.
    # (Chosen small so it stays tractable and doesn’t accidentally block corridors.)
    obstacle_budgets = [0, 2, 4, 6, 8, 10]
    while len(games) < n_games:
        attempts += 1
        if attempts > max_attempts:
            raise RuntimeError("Could not generate enough solvable game instances")

        k = len(games)
        extra_obs_target = obstacle_budgets[min(k, len(obstacle_budgets) - 1)]

        # Shuffle a fresh copy so every attempt starts from the same cell order;
        # this keeps the generated games identical to the per-attempt rescan.
        candidates = list(base_candidates)
        rng.shuffle(candidates)
        extra = set(candidates[:extra_obs_target])

        grid = GridConfig(
            width=base_grid.width,
            height=base_grid.height,
            start=base_grid.start,
            goal=base_grid.goal,
            obstacles=frozenset(set(base_grid.obstacles) | extra),
        )
        if not _is_reachable(grid):
            # Rejected layouts still consume an attempt and advance the RNG.
            continue

        sensor_cfg = SensorConfig(
            sensors=base_sensors,
            radius=int(base_radius),
            base_p=float(base_p),
            hotspot_p=float(hotspot_p),
        )

        game_id = f"G{k:02d}_extraObs{extra_obs_target}"
        desc = f"corridors + {extra_obs_target} extra obstacles"
        games.append(GameInstance(game_id=game_id, grid=grid, sensor_cfg=sensor_cfg, desc=desc))

    return games


# -----------------------------------------------------------------------------
# Sanity tests
# -----------------------------------------------------------------------------

def run_sanity_tests() -> None:
    """Quick assertion-based smoke checks of the core building blocks.

    Covers: A* endpoints, belief update staying on the simplex, the NBS solution
    being a distribution, marginals / conditionals summing to one, non-negative
    CE regrets, and a rollout returning finite utilities. Raises
    ``AssertionError`` on the first failed check.
    """
    print("[TEST] Running sanity tests...")

    def sums_to_one(vec) -> bool:
        return abs(float(vec.sum()) - 1.0) < 1e-6

    grid = build_two_corridor_grid()
    middle_column = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((middle_column, 2), (middle_column, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=0)

    # A* reachability
    path = astar_path(grid, grid.start, grid.goal, step_cost=lambda a, b: 1.0)
    assert path[0] == grid.start and path[-1] == grid.goal

    # Belief update preserves simplex
    belief = ModeBelief(M=2)
    env.reset(sensor_mode=0)
    step_out = env.step((1, 0))
    belief.update(env, step_out["alarm"], step_out["pos"])
    assert np.isfinite(belief.b).all()
    assert sums_to_one(belief.b)
    assert (belief.b >= -1e-12).all()

    # NBS returns simplex
    payoff_R = np.array([-1.0, -2.0, -3.0, -4.0])
    payoff_S = np.array([1.0, 2.0, 3.0, 4.0])
    solution = solve_nbs(payoff_R, payoff_S, log=Logger("QUIET"), max_iters=50)
    joint_vec = solution.x
    assert np.isfinite(joint_vec).all()
    assert (joint_vec >= -1e-10).all()
    assert sums_to_one(joint_vec)

    # Joint/marginals/conditionals are sane
    joint = joint_to_matrix(joint_vec, m=2, n=2)
    sigma_R, sigma_S = marginals_from_joint(joint)
    assert sums_to_one(sigma_R)
    assert sums_to_one(sigma_S)
    cond_S = conditional_sensor_given_robot(joint, 0)
    cond_R = conditional_robot_given_sensor(joint, 0)
    assert sums_to_one(cond_S)
    assert sums_to_one(cond_R)

    # CE regrets should be >= 0 (numerical)
    rand_UR = np.random.default_rng(0).normal(size=(3, 2))
    rand_US = np.random.default_rng(1).normal(size=(3, 2))
    uniform_joint = np.full((3, 2), 1.0 / 6)
    regrets = compute_ce_regrets(rand_UR, rand_US, uniform_joint)
    assert regrets["max_regret_R"] >= -1e-9
    assert regrets["max_regret_S"] >= -1e-9

    # Rollout returns finite stats
    robots, sensors = build_initial_policies(env, M_modes=2)
    stats = rollout_episode(env, robots[0], sensors[0], M_modes=2, seed=0)
    assert np.isfinite(stats.U_R) and np.isfinite(stats.U_S)

    print("[TEST] All sanity tests passed ✅")


# -----------------------------------------------------------------------------
# Fair comparisons on the same fixed game set
# -----------------------------------------------------------------------------

@dataclass
class ExperimentRow:
    """One experiment record: a single metric value for an (alg, solver, game, seed).

    NOTE(review): the code that fills these rows is outside this view; the field
    comments below are read off the names and should be confirmed there.
    """

    alg: str  # presumably the algorithm / baseline label
    solver: str  # presumably the solver variant (e.g. "marginal" / "correlated")
    game_id: str  # presumably a GameInstance.game_id
    seed: int  # presumably the run seed
    metric: str  # name of the metric stored in `value`
    value: float  # metric value


def _evaluate_robot_mixture_against_sensor(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensor: SensorPolicy,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Monte Carlo evaluation of a robot mixed strategy against one sensor policy.

    For each episode a robot policy index is sampled from ``sigma_R`` and one
    rollout is played against ``sensor``; per-episode statistics are averaged.

    Args:
        env: Environment used for the rollouts (re-seeded per episode by the rollout).
        robots: Candidate robot policies; ``sigma_R[i]`` is the weight of ``robots[i]``.
        sigma_R: Non-negative weights over ``robots``; renormalised to sum to 1.
        sensor: Sensor policy to play against.
        M_modes: Number of sensor modes, forwarded to ``rollout_episode``.
        episodes: Number of rollouts; must be >= 1.
        base_seed: Seeds the policy sampler; episode ``k`` uses seed ``base_seed + k``.

    Returns:
        StrategyEval with mean utilities, detection/goal rates, mean steps and mean risk.

    Raises:
        ValueError: if ``episodes`` < 1, or ``sigma_R`` has no positive finite mass.
    """
    # Fail early with a clear message instead of a ZeroDivisionError / NaN means below.
    if episodes <= 0:
        raise ValueError(f"episodes must be >= 1, got {episodes}")

    sigma_R = np.asarray(sigma_R, dtype=float).reshape(-1)
    total_mass = float(sigma_R.sum())
    if (not np.isfinite(total_mass)) or total_mass <= 0.0:
        raise ValueError("sigma_R must have positive, finite total mass")
    sigma_R = sigma_R / max(total_mass, 1e-12)

    rng = np.random.default_rng(int(base_seed))

    UR, US = [], []
    det = 0
    goal = 0
    steps_list = []
    risk_list = []

    for k in range(episodes):
        # One policy draw per episode, in episode order, so results are reproducible per base_seed.
        i = int(rng.choice(len(robots), p=sigma_R))
        st = rollout_episode(env, robots[i], sensor, M_modes=M_modes, seed=int(base_seed + k))
        UR.append(st.U_R)
        US.append(st.U_S)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps_list.append(st.steps)
        risk_list.append(st.total_true_risk)

    return StrategyEval(
        mean_U_R=float(np.mean(UR)),
        mean_U_S=float(np.mean(US)),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean(steps_list)),
        mean_risk=float(np.mean(risk_list)),
    )


def _summarize_against_sensor_suite(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensors_suite: List[SensorPolicy],
    M_modes: int,
    episodes_per_sensor: int,
    base_seed: int,
) -> Dict[str, float]:
    """Evaluate a robot mixture against every sensor in a suite and aggregate.

    Sensor ``j`` of the suite is evaluated with seed offset ``base_seed + 10000 * j``.
    The returned dict holds the mean over the suite and the worst case over the
    suite ("robust") for robot utility, goal rate and detection rate. Worst case
    is the minimum for utility and goal rate, and the maximum for detection rate.
    """
    per_sensor = [
        _evaluate_robot_mixture_against_sensor(
            env,
            robots,
            sigma_R=sigma_R,
            sensor=policy,
            M_modes=M_modes,
            episodes=episodes_per_sensor,
            base_seed=base_seed + 10000 * idx,
        )
        for idx, policy in enumerate(sensors_suite)
    ]

    utilities = [ev.mean_U_R for ev in per_sensor]
    goal_rates = [ev.goal_rate for ev in per_sensor]
    det_rates = [ev.det_rate for ev in per_sensor]

    summary: Dict[str, float] = {}
    summary["mean_UR"] = float(np.mean(utilities))
    summary["robust_UR"] = float(np.min(utilities))
    summary["mean_goal"] = float(np.mean(goal_rates))
    summary["robust_goal"] = float(np.min(goal_rates))
    summary["mean_det"] = float(np.mean(det_rates))
    # For detection, the worst case for the robot is the highest rate.
    summary["robust_det"] = float(np.max(det_rates))
    return summary


def _write_experiment_csv(rows: List[ExperimentRow], path: str) -> None:
    """Write experiment rows to ``path`` as CSV, creating the parent directory first.

    The file has a header line followed by one line per row; values are
    formatted with eight decimal places.
    """
    import csv

    header = ["alg", "solver", "game_id", "seed", "metric", "value"]
    parent_dir = os.path.dirname(path) or "."
    safe_makedirs(parent_dir)

    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [row.alg, row.solver, row.game_id, row.seed, row.metric, f"{row.value:.8f}"]
            for row in rows
        )


def _group_stats(values: List[float]) -> Dict[str, float]:
    arr = np.asarray(values, dtype=float)
    return {
        "mean": float(np.mean(arr)) if arr.size else float("nan"),
        "std": float(np.std(arr)) if arr.size else float("nan"),
        "n": int(arr.size),
    }


def plot_experiment_summary(rows: List[ExperimentRow], outdir: str) -> None:
    """Create *paper-style* figures from ExperimentRow logs.

    Why these figures are meaningful:
      - We plot performance *per game instance* (difficulty sweep) so you can see
        robustness trends, not just an average.
      - We summarize distributions across (game, seed) with boxplots.
      - We show trade-offs (goal vs detection, runtime vs performance) with scatter.

    Outputs (PNG + PDF), all written into ``outdir``:
      fig_robustUR_by_game, fig_goal_det_tradeoff, fig_runtime_vs_robustUR,
      fig_robust_goal_by_game, fig_robust_det_by_game,
      fig_box_robustUR, fig_box_goal, fig_box_det,
      fig_stability_regret_vs_robustUR

    Notes:
      - Uses matplotlib defaults (no manual colors), saves high DPI.
      - Boxplot tick labels are set via ``plt.xticks`` rather than the ``labels=``
        keyword of ``boxplot``, which newer matplotlib releases deprecate.
      - A legend is only drawn when at least one labelled artist exists.
      - Assumes metric semantics:
          robust_UR: higher is better
          robust_goal: higher is better
          robust_det: lower is better (worst-case detection rate)
          runtime_s, regrets: lower is better

    Args:
        rows: Raw experiment measurements (one metric value per row).
        outdir: Directory the figures are written to (created if needed).
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    # ----------------------------
    # Collect + aggregate utilities
    # ----------------------------
    algs = sorted({r.alg for r in rows})
    games = sorted({r.game_id for r in rows})

    # Map: (alg, game, metric) -> [values across seeds]
    bucket: Dict[Tuple[str, str, str], List[float]] = {}
    for r in rows:
        key = (r.alg, r.game_id, r.metric)
        bucket.setdefault(key, []).append(float(r.value))

    def get_vals(alg: str, game_id: str, metric: str) -> List[float]:
        return bucket.get((alg, game_id, metric), [])

    def values_of(alg: str, metric: str) -> List[float]:
        # All per-run values of one metric for one algorithm, in row order.
        return [float(r.value) for r in rows if r.alg == alg and r.metric == metric]

    def mean_ci95(vals: List[float]) -> Tuple[float, float]:
        # Mean and normal-approximation 95% half-width (0 for a single sample, NaN if empty).
        arr = np.asarray(vals, dtype=float)
        if arr.size == 0:
            return float("nan"), float("nan")
        mu = float(np.mean(arr))
        if arr.size == 1:
            return mu, 0.0
        se = float(np.std(arr, ddof=1) / np.sqrt(arr.size))
        return mu, 1.96 * se

    def legend_if_labeled(**kwargs: Any) -> None:
        # Avoid matplotlib's "no artists with labels" warning on empty plots.
        handles, _ = plt.gca().get_legend_handles_labels()
        if handles:
            plt.legend(**kwargs)

    def savefig(base: str) -> None:
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, base + ".png"), dpi=300)
        plt.savefig(os.path.join(outdir, base + ".pdf"))
        plt.close()

    # ----------------------------
    # (1) Metric vs game difficulty (line + CI)
    # ----------------------------
    def plot_by_game(metric: str, ylabel: str, base: str) -> None:
        plt.figure(figsize=(9.0, 4.2))
        x = np.arange(len(games))
        for alg in algs:
            ys = []
            es = []
            for gid in games:
                mu, ci = mean_ci95(get_vals(alg, gid, metric))
                ys.append(mu)
                es.append(ci)
            ys = np.asarray(ys, dtype=float)
            es = np.asarray(es, dtype=float)
            plt.plot(x, ys, marker="o", linewidth=2.0, label=alg)
            # CI band (skip if nan)
            ok = np.isfinite(ys) & np.isfinite(es)
            if np.any(ok):
                plt.fill_between(x[ok], (ys - es)[ok], (ys + es)[ok], alpha=0.15)
        plt.xticks(x, games, rotation=25, ha="right")
        plt.xlabel("Game instance (increasing obstacle perturbations)")
        plt.ylabel(ylabel)
        plt.grid(True, alpha=0.25)
        legend_if_labeled(ncol=2, fontsize=9)
        savefig(base)

    plot_by_game("robust_UR", "Robust robot utility (worst-case over sensor suite) ↑", "fig_robustUR_by_game")
    plot_by_game("robust_goal", "Robust goal rate (worst-case over sensor suite) ↑", "fig_robust_goal_by_game")
    plot_by_game("robust_det", "Worst-case detection rate over sensor suite ↓", "fig_robust_det_by_game")

    # ----------------------------
    # (2) Overall distribution boxplots (across games+seeds)
    # ----------------------------
    def plot_box(metric: str, ylabel: str, base: str) -> None:
        plt.figure(figsize=(9.0, 4.2))
        data = [values_of(alg, metric) for alg in algs]
        plt.boxplot(data, showfliers=False)
        # boxplot places box i at x = i + 1 by default; label those positions explicitly.
        plt.xticks(np.arange(1, len(algs) + 1), algs, rotation=25, ha="right")
        plt.ylabel(ylabel)
        plt.grid(True, axis="y", alpha=0.25)
        savefig(base)

    plot_box("robust_UR", "Robust robot utility (worst-case over sensor suite) ↑", "fig_box_robustUR")
    plot_box("robust_goal", "Robust goal rate (worst-case over sensor suite) ↑", "fig_box_goal")
    plot_box("robust_det", "Worst-case detection rate over sensor suite ↓", "fig_box_det")

    # ----------------------------
    # (3) Goal vs detection trade-off (Pareto-ish view)
    #     Each point is (game, mean over seeds).
    # ----------------------------
    plt.figure(figsize=(6.0, 5.2))
    for alg in algs:
        xs = []  # detection
        ys = []  # goal
        for gid in games:
            mu_det, _ = mean_ci95(get_vals(alg, gid, "robust_det"))
            mu_goal, _ = mean_ci95(get_vals(alg, gid, "robust_goal"))
            if np.isfinite(mu_det) and np.isfinite(mu_goal):
                xs.append(mu_det)
                ys.append(mu_goal)
        if xs and ys:
            plt.scatter(xs, ys, label=alg)
    plt.xlabel("Worst-case detection rate (lower is better)")
    plt.ylabel("Worst-case goal rate (higher is better)")
    plt.grid(True, alpha=0.25)
    legend_if_labeled(fontsize=9)
    savefig("fig_goal_det_tradeoff")

    # ----------------------------
    # (4) + (5) Per-run scatter of some x-metric against robust utility.
    #     Points are per (game, seed) run; an algorithm is skipped when it has
    #     no values for the x-metric or the two value lists differ in length.
    # ----------------------------
    def plot_vs_robust_ur(x_metric: str, xlabel: str, base: str) -> None:
        plt.figure(figsize=(6.0, 5.2))
        for alg in algs:
            xs = values_of(alg, x_metric)
            ys = values_of(alg, "robust_UR")
            if xs and ys and len(xs) == len(ys):
                plt.scatter(xs, ys, label=alg)
        plt.xlabel(xlabel)
        plt.ylabel("Robust robot utility ↑")
        plt.grid(True, alpha=0.25)
        legend_if_labeled(fontsize=9)
        savefig(base)

    # (4) Runtime vs robust utility (efficiency vs performance)
    plot_vs_robust_ur("runtime_s", "Runtime (seconds) ↓", "fig_runtime_vs_robustUR")
    # (5) Stability diagnostic: regret vs robust utility
    plot_vs_robust_ur("max_regret_R", "Max conditional regret (robot) ↓", "fig_stability_regret_vs_robustUR")


def run_fixed_games_experiment(
    outdir: str = "results_exp",
    n_games: int = 6,
    seeds: Optional[List[int]] = None,
    outer_iters: int = 3,
    rollouts_payoff: int = 20,
    rollouts_br: int = 30,
    eval_episodes: int = 120,
    episodes_per_sensor: int = 80,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
    cond_top_k: int = 2,
    br_eval_rollouts: int = 8,
    add_threshold: float = 0.25,
    risk_weight_br: float = 12.0,
    include_solvers: Optional[List[str]] = None,
    include_baselines: bool = True,
    log_level: str = "QUIET",
) -> None:
    """Run a clean comparison on a deterministic game set.

    For every (game, seed) pair each requested solver is trained and its robot
    mixture is evaluated against a suite of sensor policies; optionally every
    baseline robot policy is evaluated the same way. A fresh sensor suite is
    built for every evaluation so that stateful sensors (e.g. the random-mode
    sensor, which owns its own random generator) start from the same state for
    every algorithm.

    What you get in outdir:
      - experiments.csv          (all raw numbers, long format)
      - fig_*.png / fig_*.pdf    (figures written by plot_experiment_summary)

    Recommended presentation story:
      1) Robust UR and robust goal rate vs baselines
      2) Runtime vs baselines
      3) Scatter showing tradeoff: regret (stability) vs robust UR (performance)

    NOTE: NBS is not a Nash/CE solver; higher regret is expected. That's *part of the story*.

    Args:
        outdir: Output directory (created if needed).
        n_games: Number of game instances requested from build_fixed_game_set.
        seeds: Seeds to run per game; defaults to [0, 1, 2].
        outer_iters, rollouts_payoff, rollouts_br, eval_episodes, disagreement,
        entropy_tau, cond_top_k, br_eval_rollouts, add_threshold, risk_weight_br:
            Forwarded to run_training through an argparse.Namespace.
        episodes_per_sensor: Rollouts per sensor policy in the robust evaluation.
        include_solvers: Solver keys to train; defaults to ["correlated", "marginal"].
        include_baselines: Also evaluate each untrained baseline robot policy.
        log_level: Logger level used during training.
    """

    if seeds is None:
        seeds = [0, 1, 2]
    if include_solvers is None:
        include_solvers = ["correlated", "marginal"]

    safe_makedirs(outdir)

    games = build_fixed_game_set(n_games=n_games)

    rows: List[ExperimentRow] = []

    # Sensor suite for robust evaluation (adversarial-ish)
    # You can add Alternate/Random to show generalization, but the key is Mode0/Mode1.
    def make_sensor_suite(M_modes: int, seed0: int) -> List[SensorPolicy]:
        return [
            FixedModeSensorPolicy(0, "S_Mode0"),
            FixedModeSensorPolicy(1, "S_Mode1"),
            AlternatingSensorPolicy(M_modes, "S_Alternate"),
            RandomModeSensorPolicy(M_modes, seed=seed0 + 999, name="S_RandomMode"),
        ]

    def record(alg: str, solver: str, game_id: str, seed: int, metric: str, value: float) -> None:
        rows.append(
            ExperimentRow(alg=alg, solver=solver, game_id=game_id, seed=seed, metric=metric, value=float(value))
        )

    for g in games:
        for s in seeds:
            env = GridWorldStealthEnv(g.grid, g.sensor_cfg, fp=0.05, fn=0.10, seed=s)
            M_modes = len(g.sensor_cfg.sensors)

            # -----------------
            # (1) Run solvers
            # -----------------
            for solver in include_solvers:
                # Build an args object compatible with run_training()
                args = argparse.Namespace(
                    seed=s,
                    log_level=log_level,
                    solver=solver,
                    outer_iters=int(outer_iters),
                    rollouts_payoff=int(rollouts_payoff),
                    rollouts_br=int(rollouts_br),
                    risk_weight_br=float(risk_weight_br),
                    disagreement=str(disagreement),
                    entropy_tau=float(entropy_tau),
                    cond_top_k=int(cond_top_k),
                    br_eval_rollouts=int(br_eval_rollouts),
                    add_threshold=float(add_threshold),
                    eval_episodes=int(eval_episodes),
                    debug_rollout_pair="",
                    results_dir="",
                    save_plots=False,
                    run_benchmarks=False,
                    bench_episodes=0,
                )

                # perf_counter is monotonic, so the measured duration cannot be
                # distorted by wall-clock adjustments.
                t0 = time.perf_counter()
                log = Logger(log_level)
                tr = run_training(env, args, solver=solver, log=log)
                runtime_s = float(time.perf_counter() - t0)

                # Extract sigma_R from final X
                sigma_R, _sigma_S = marginals_from_joint(tr.X)
                suite_summary = _summarize_against_sensor_suite(
                    env,
                    robots=tr.robots,
                    sigma_R=sigma_R,
                    # Fresh suite per evaluation: identical sensor randomness for every algorithm.
                    sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                    M_modes=M_modes,
                    episodes_per_sensor=int(episodes_per_sensor),
                    base_seed=100_000 + 1000 * s,
                )

                alg = f"Solver:{solver}"
                for k, v in suite_summary.items():
                    record(alg, solver, g.game_id, s, k, v)

                record(alg, solver, g.game_id, s, "runtime_s", runtime_s)

                # Also keep solver-internal diagnostics (from last history row).
                # Skipped when training produced no history (e.g. zero outer iterations).
                if tr.history:
                    last = tr.history[-1]
                    record(alg, solver, g.game_id, s, "entropy_X", last.entropy_X)
                    record(alg, solver, g.game_id, s, "max_regret_R", last.max_regret_R)
                    record(alg, solver, g.game_id, s, "max_regret_S", last.max_regret_S)
                    record(alg, solver, g.game_id, s, "selfplay_UR", last.selfplay_UR)
                    record(alg, solver, g.game_id, s, "selfplay_US", last.selfplay_US)

            # -----------------
            # (2) Baseline robots (no training)
            # -----------------
            if include_baselines:
                robots_base, _ = build_initial_policies_for_bench(env, M_modes=M_modes)

                for idx, rp in enumerate(robots_base):
                    # Pure strategy: all probability mass on this baseline policy.
                    sigma = np.zeros(len(robots_base), dtype=float)
                    sigma[idx] = 1.0

                    suite_summary = _summarize_against_sensor_suite(
                        env,
                        robots=robots_base,
                        sigma_R=sigma,
                        sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                        M_modes=M_modes,
                        episodes_per_sensor=int(episodes_per_sensor),
                        base_seed=200_000 + 1000 * s,
                    )

                    alg = f"Baseline:{rp.name}"
                    for k, v in suite_summary.items():
                        record(alg, "baseline", g.game_id, s, k, v)

    # Save raw numbers
    csv_path = os.path.join(outdir, "experiments.csv")
    _write_experiment_csv(rows, csv_path)

    # Produce summary plots
    plot_experiment_summary(rows, outdir=outdir)

    print(f"[EXP] Done. Wrote: {csv_path}")
    print(f"[EXP] Plots saved to: {outdir}")



In [27]:
# Run the sanity checks first, then the full fixed-game comparison.
run_sanity_tests()

# Writes experiments.csv and the summary figures into `outdir`.
run_fixed_games_experiment(
    outdir="results_exp",
    n_games=6,
    seeds=[0,1,2,3,4],          # use more for a paper
    outer_iters=3,
    entropy_tau=0.02,           # IMPORTANT if you want multimodal x*
    disagreement="uniform",     # encourages mixture rather than collapsing
    log_level="QUIET",
)


[TEST] Running sanity tests...
[TEST] All sanity tests passed ✅


  plt.boxplot(data, labels=algs, showfliers=False)
  plt.boxplot(data, labels=algs, showfliers=False)
  plt.boxplot(data, labels=algs, showfliers=False)


[EXP] Done. Wrote: results_exp\experiments.csv
[EXP] Plots saved to: results_exp


In [32]:
#!/usr/bin/env python3
"""approach2_robust_correlated.py

Robust, reduced-log implementation of "Approach 2" on a stealth gridworld POMDP,
with a FIX for the conceptual mismatch:

  Old mismatch: solve a joint distribution x*(i,j) over (robot policy i, sensor policy j)
  but then compute best responses against the MARGINALS sigma_R, sigma_S.

  Fix here: treat x* as a CORRELATION DEVICE (mediator) and compute BRs against
  CONDITIONAL distributions:

      q_S(.|i) = X[i,:] / sigma_R[i]   (sensor conditional given robot recommendation i)
      q_R(.|j) = X[:,j] / sigma_S[j]   (robot conditional given sensor recommendation j)

  Then add the most profitable *deviation* policy (oracle) based on the most violated
  conditional recommendation.

This makes the PSRO-style expansion step consistent with a correlated-strategy viewpoint.

Also included:
  - Benchmark harness vs simple heuristics on the same game.
  - Saved plots (training curves + bar charts) for presentations.

Run (script):
  python approach2_robust_correlated.py --solver correlated --outer-iters 3 --save-plots

Run (notebook):
  main(argv=["--solver","both","--outer-iters","3","--save-plots"])  # ignore ipykernel args

Tips to encourage MULTIMODAL x* (useful for your "mode recovery"):
  --disagreement uniform --entropy-tau 0.02

"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple

import argparse
import sys
import heapq
import os
import time

import numpy as np


# =============================================================================
# Logging
# =============================================================================

class Logger:
    """Minimal print-based logger with three verbosity levels.

    QUIET prints nothing, INFO prints banners and info messages, DEBUG
    additionally prints debug messages.
    """

    LEVELS = {"QUIET": 0, "INFO": 1, "DEBUG": 2}

    def __init__(self, level: str = "INFO"):
        """Store the (case-insensitive) level name and its numeric rank.

        Raises:
            ValueError: if the level name is not one of QUIET/INFO/DEBUG.
        """
        level = level.upper()
        rank = self.LEVELS.get(level)
        if rank is None:
            raise ValueError(f"Unknown log level: {level}. Use QUIET/INFO/DEBUG")
        self.level = level
        self.k = rank

    def banner(self, title: str) -> None:
        """Print ``title`` framed by two 100-character rules (INFO and above)."""
        if self.k < 1:
            return
        rule = "=" * 100
        for line in (rule, title, rule):
            print(line)

    def info(self, msg: str) -> None:
        """Print ``msg`` at INFO level or above."""
        if self.k < 1:
            return
        print(msg)

    def debug(self, msg: str) -> None:
        """Print ``msg`` at DEBUG level only."""
        if self.k < 2:
            return
        print(msg)


# =============================================================================
# Types
# =============================================================================

Action = Tuple[int, int]  # (dx, dy)


# =============================================================================
# Environment (grid + hidden sensor mode + noisy alarm observation)
# =============================================================================

@dataclass(frozen=True)
class GridConfig:
    """Immutable description of a rectangular grid map with (x, y) cell coordinates."""
    width: int  # number of columns; valid x is 0 .. width - 1
    height: int  # number of rows; valid y is 0 .. height - 1
    start: Tuple[int, int]  # robot start cell
    goal: Tuple[int, int]  # robot goal cell
    obstacles: frozenset  # blocked (x, y) cells


@dataclass(frozen=True)
class SensorConfig:
    """Immutable sensor layout and detection probabilities."""
    sensors: Tuple[Tuple[int, int], ...]  # sensor centre cells; sensor mode m refers to sensors[m]
    radius: int  # Manhattan radius of the hotspot around the active sensor
    base_p: float  # per-step detection probability outside the hotspot
    hotspot_p: float  # per-step detection probability inside the hotspot


class GridWorldStealthEnv:
    """Grid world with detection risk controlled by a hidden/selected sensor 'mode'.

    The active sensor mode selects one sensor centre; cells within the
    configured Manhattan radius of it use ``hotspot_p`` as per-step detection
    probability, all other cells use ``base_p``. Each step also emits a noisy
    binary alarm whose probability depends on the true detection probability
    and the false-positive / false-negative rates.
    """

    def __init__(
        self,
        grid: GridConfig,
        sensor_cfg: SensorConfig,
        fp: float = 0.05,
        fn: float = 0.10,
        seed: int = 0,
        max_steps: int = 200,
    ):
        """Validate the configuration, seed the generator and reset the episode.

        Args:
            grid: Map description (size, start, goal, obstacles).
            sensor_cfg: Sensor centres, hotspot radius and detection probabilities.
            fp: False-positive rate of the alarm, in [0, 1].
            fn: False-negative rate of the alarm, in [0, 1].
            seed: Seed of the environment's random generator.
            max_steps: Step count at which ``step`` reports ``done`` (default 200,
                the previously hard-coded horizon).

        Raises:
            ValueError: on out-of-range probabilities, negative radius, an
                empty sensor list, or ``max_steps`` < 1.
        """
        self.grid = grid
        self.sensor_cfg = sensor_cfg
        self.fp = float(fp)
        self.fn = float(fn)
        self.max_steps = int(max_steps)

        if not (0.0 <= self.fp <= 1.0 and 0.0 <= self.fn <= 1.0):
            raise ValueError("fp and fn must be in [0,1].")
        if not (0.0 <= sensor_cfg.base_p <= 1.0 and 0.0 <= sensor_cfg.hotspot_p <= 1.0):
            raise ValueError("base_p and hotspot_p must be in [0,1].")
        if sensor_cfg.radius < 0:
            raise ValueError("radius must be >= 0")
        if len(sensor_cfg.sensors) == 0:
            raise ValueError("Need at least one sensor center.")
        if self.max_steps < 1:
            raise ValueError("max_steps must be >= 1")

        self.rng = np.random.default_rng(int(seed))
        self.reset(sensor_mode=0)

    def seed(self, seed: int) -> None:
        """Replace the random generator with a freshly seeded one."""
        self.rng = np.random.default_rng(int(seed))

    def reset(self, sensor_mode: int = 0) -> Dict[str, Any]:
        """Start a new episode at the grid start and return the initial ``pos``/``t``."""
        self.t = 0
        self.pos = self.grid.start
        self.sensor_mode = int(sensor_mode)
        self.detected = False
        self.total_true_risk = 0.0
        return {"pos": self.pos, "t": self.t}

    def in_bounds(self, p: Tuple[int, int]) -> bool:
        """True if cell ``p`` lies inside the grid rectangle."""
        x, y = p
        return 0 <= x < self.grid.width and 0 <= y < self.grid.height

    def is_free(self, p: Tuple[int, int]) -> bool:
        """True if cell ``p`` is inside the grid and not an obstacle."""
        return self.in_bounds(p) and (p not in self.grid.obstacles)

    def true_detection_prob(self, p: Tuple[int, int], mode: int) -> float:
        """Per-step detection probability at cell ``p`` when sensor ``mode`` is active.

        Raises:
            ValueError: if ``mode`` is not a valid sensor index.
        """
        if not (0 <= mode < len(self.sensor_cfg.sensors)):
            raise ValueError(f"mode {mode} out of range")
        sx, sy = self.sensor_cfg.sensors[mode]
        x, y = p
        d = abs(x - sx) + abs(y - sy)
        return self.sensor_cfg.hotspot_p if d <= self.sensor_cfg.radius else self.sensor_cfg.base_p

    def _alarm_prob(self, p_true: float) -> float:
        """Probability that the alarm fires, given the true detection probability."""
        return p_true * (1.0 - self.fn) + (1.0 - p_true) * self.fp

    def observation_prob(self, alarm: int, p_true: float) -> float:
        """Likelihood of observing ``alarm`` (0 or 1) given the true detection probability.

        Raises:
            ValueError: if ``alarm`` is neither 0 nor 1.
        """
        p_alarm = self._alarm_prob(p_true)
        if alarm == 1:
            return p_alarm
        if alarm == 0:
            return 1.0 - p_alarm
        raise ValueError("alarm must be 0 or 1")

    def step(self, a: Action) -> Dict[str, Any]:
        """Apply move ``a`` = (dx, dy), sample detection and alarm, and return the outcome.

        A move into a blocked or out-of-bounds cell leaves the robot in place.
        Once the robot has been detected, further calls change nothing and
        return a terminal outcome. The returned dict has keys ``pos``, ``t``,
        ``alarm``, ``p_true``, ``detected`` and ``done``.
        """
        if self.detected:
            return {"pos": self.pos, "t": self.t, "alarm": 1, "p_true": 1.0, "detected": True, "done": True}

        self.t += 1
        nx = self.pos[0] + int(a[0])
        ny = self.pos[1] + int(a[1])
        np_ = (nx, ny)
        if self.is_free(np_):
            self.pos = np_

        p_true = float(self.true_detection_prob(self.pos, self.sensor_mode))
        self.total_true_risk += p_true

        # Draw order matters for reproducibility: detection first, then the alarm.
        if self.rng.random() < p_true:
            self.detected = True

        p_alarm = self._alarm_prob(p_true)
        alarm = 1 if (self.rng.random() < p_alarm) else 0

        done = self.detected or (self.pos == self.grid.goal) or (self.t >= self.max_steps)
        return {"pos": self.pos, "t": self.t, "alarm": alarm, "p_true": p_true, "detected": self.detected, "done": done}


def build_two_corridor_grid(width: int = 15, height: int = 9) -> GridConfig:
    """Build a grid split by a vertical wall at ``width // 2`` with gaps at rows 2 and 6.

    The start is ``(1, height - 2)`` and the goal is ``(width - 2, 1)``.

    Raises:
        RuntimeError: if the start or goal cell coincides with a wall cell.
    """
    wall_x = width // 2
    gap_rows = (2, 6)
    wall_cells = frozenset((wall_x, row) for row in range(height) if row not in gap_rows)

    start, goal = (1, height - 2), (width - 2, 1)
    if (start in wall_cells) or (goal in wall_cells):
        raise RuntimeError("Start/goal blocked unexpectedly")

    return GridConfig(width=width, height=height, start=start, goal=goal, obstacles=wall_cells)


def print_grid_ascii(grid: GridConfig, sensor_cfg: SensorConfig) -> None:
    """Print the grid row by row as ASCII art.

    Symbols: ``R`` start, ``G`` goal, ``S`` sensor, ``#`` obstacle, ``.`` free.
    When several apply to one cell, the earliest in that list wins.
    """
    # Fill from lowest to highest priority so later writes override earlier ones.
    glyphs: Dict[Tuple[int, int], str] = {}
    for cell in grid.obstacles:
        glyphs[cell] = "#"
    for cell in sensor_cfg.sensors:
        glyphs[cell] = "S"
    glyphs[grid.goal] = "G"
    glyphs[grid.start] = "R"

    for y in range(grid.height):
        print("".join(glyphs.get((x, y), ".") for x in range(grid.width)))


# =============================================================================
# Belief over modes
# =============================================================================

class ModeBelief:
    """Probability vector over the ``M`` discrete sensor modes, updated by Bayes' rule."""

    def __init__(self, M: int, init: Optional[np.ndarray] = None):
        """Create the belief, uniform unless ``init`` supplies non-negative weights.

        ``init`` is normalised to sum to 1; if it sums to zero (or its sum is
        not positive) the uniform belief is used instead.

        Raises:
            ValueError: if ``M`` < 1, ``init`` has the wrong length, or
                ``init`` contains a negative entry.
        """
        self.M = int(M)
        if self.M <= 0:
            raise ValueError("M must be >= 1")

        if init is not None:
            weights = np.asarray(init, dtype=float).reshape(-1)
            if weights.shape != (self.M,):
                raise ValueError("init shape mismatch")
            if np.any(weights < 0):
                raise ValueError("init must be nonnegative")
            mass = float(weights.sum())
            if mass > 0:
                self.b = weights / mass
                return
        self.b = np.full(self.M, 1.0 / self.M)

    def update(self, env: GridWorldStealthEnv, alarm: int, pos: Tuple[int, int], eps: float = 1e-12) -> None:
        """Condition the belief on an ``alarm`` observation made at cell ``pos``.

        The likelihood of the observation under each mode comes from the
        environment's detection and observation models. If the normaliser is
        non-finite or below ``eps`` the belief is left unchanged.
        """
        likelihood = np.array(
            [env.observation_prob(alarm, env.true_detection_prob(pos, mode)) for mode in range(self.M)],
            dtype=float,
        )
        unnormalised = self.b * likelihood
        evidence = float(unnormalised.sum())
        if np.isfinite(evidence) and not (evidence < eps):
            self.b = unnormalised / evidence


# =============================================================================
# Policies
# =============================================================================

class RobotPolicy:
    """Base interface for robot policies; subclasses must implement ``act``."""
    name: str = "RobotPolicy"

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Hook for per-episode state; receives the start position. No-op by default."""
        pass

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the next (dx, dy) action given the environment, belief and last alarm."""
        raise NotImplementedError


class FixedPathPolicy(RobotPolicy):
    """Open-loop policy that follows a precomputed sequence of grid cells."""

    def __init__(self, path: List[Tuple[int, int]], name: str):
        """Store a copy of ``path`` (at least two cells) and the policy name.

        Raises:
            ValueError: if ``path`` has fewer than two cells.
        """
        if len(path) < 2:
            raise ValueError("Path must have >=2 states")
        self.path = list(path)
        self.name = str(name)
        self._idx = 0

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Point the cursor at the first occurrence of ``start_pos`` on the path (0 if absent)."""
        try:
            self._idx = self.path.index(start_pos)
        except ValueError:
            self._idx = 0

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Return the unit move towards the next path cell, or (0, 0) to stay.

        Stays put when the path is finished, or when the robot is off the
        expected cell and its position does not occur later on the path.
        """
        cur = env.pos
        last = len(self.path) - 1
        if self._idx >= last:
            return (0, 0)
        if cur != self.path[self._idx]:
            # Re-synchronise: look for the current cell at or after the cursor.
            try:
                self._idx = self.path.index(cur, self._idx)
            except ValueError:
                return (0, 0)
            # Re-sync may land on the final cell; there is no successor to move to
            # (indexing past the end would raise IndexError).
            if self._idx >= last:
                return (0, 0)
        nxt = self.path[self._idx + 1]
        dx = int(np.clip(nxt[0] - cur[0], -1, 1))
        dy = int(np.clip(nxt[1] - cur[1], -1, 1))
        self._idx += 1
        return (dx, dy)


class RandomPolicy(RobotPolicy):
    """Uniformly random legal move, drawn from the environment's generator for reproducibility."""

    def __init__(self, name: str = "R_Random"):
        self.name = name

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Nothing to reset: the policy keeps no per-episode state."""

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Pick uniformly among the moves (including staying) whose target cell is free."""
        x, y = env.pos
        moves = ((1, 0), (-1, 0), (0, 1), (0, -1), (0, 0))
        legal: List[Action] = [mv for mv in moves if env.is_free((x + mv[0], y + mv[1]))]
        if len(legal) == 0:
            return (0, 0)
        pick = int(env.rng.integers(0, len(legal)))
        return legal[pick]


class OnlineBeliefReplanPolicy(RobotPolicy):
    """POMDP-ish heuristic: replan each step using risk map induced by current belief b_t."""

    def __init__(self, env: GridWorldStealthEnv, risk_weight: float = 12.0, name: str = "R_OnlineBeliefReplan"):
        """Store the environment handle, the risk weight and the policy name.

        Args:
            env: Environment reference kept on the instance.
            risk_weight: Multiplier of the expected detection risk in the step cost.
            name: Policy name.
        """
        self.env = env
        self.risk_weight = float(risk_weight)
        self.name = name
        self._cached_next: Optional[Tuple[int, int]] = None

    def reset(self, start_pos: Tuple[int, int]) -> None:
        """Clear per-episode state."""
        self._cached_next = None

    def act(self, env: GridWorldStealthEnv, belief: ModeBelief, last_obs: Optional[int]) -> Action:
        """Plan a risk-weighted A* path to the goal and return its first move.

        The cost of stepping onto a cell is ``1 + risk_weight * expected risk``,
        where the expectation is over the current belief. Returns (0, 0) when
        planning fails (A* raises RuntimeError) or the path has no next cell.
        Other exceptions are deliberately not swallowed so that configuration
        errors surface instead of turning into a silent "stay".
        """
        # Build a risk map from belief over modes (snapshot as plain floats).
        mode_probs = [float(pm) for pm in belief.b]

        def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
            # Expected risk at `to` under the belief.
            r = 0.0
            for m, pm in enumerate(mode_probs):
                r += pm * float(env.true_detection_prob(to, m))
            return 1.0 + self.risk_weight * r

        # Plan from current pos to goal (one-step receding horizon)
        try:
            path = astar_path(env.grid, env.pos, env.grid.goal, step_cost)
        except RuntimeError:
            # Unreachable goal or expansion budget exceeded: best effort is to stay.
            return (0, 0)
        if len(path) < 2:
            return (0, 0)
        nxt = path[1]
        dx = int(np.clip(nxt[0] - env.pos[0], -1, 1))
        dy = int(np.clip(nxt[1] - env.pos[1], -1, 1))
        return (dx, dy)


class SensorPolicy:
    """Base interface for sensor policies; subclasses must implement ``select_mode``."""
    name: str = "SensorPolicy"

    def reset(self) -> None:
        """Hook for per-episode state. No-op by default."""
        pass

    def select_mode(self, t: int) -> int:
        """Return the sensor mode to activate at time step ``t``."""
        raise NotImplementedError


class FixedModeSensorPolicy(SensorPolicy):
    """Sensor policy that keeps one mode active for the whole episode."""

    def __init__(self, mode: int, name: Optional[str] = None):
        """Store the mode; the name defaults to ``S_Mode<mode>`` when none is given."""
        label = name if name else f"S_Mode{mode}"
        self.mode = int(mode)
        self.name = label

    def reset(self) -> None:
        """Nothing to reset: the policy is stateless."""

    def select_mode(self, t: int) -> int:
        """Return the fixed mode, independent of ``t``."""
        chosen = self.mode
        return chosen


class AlternatingSensorPolicy(SensorPolicy):
    """Benchmark sensor heuristic that cycles through the modes: 0, 1, ..., M-1, 0, ..."""

    def __init__(self, M: int, name: str = "S_Alternate"):
        """Store the number of modes ``M`` and the policy name."""
        self.name = name
        self.M = int(M)

    def reset(self) -> None:
        """Nothing to reset: the mode depends only on the time step."""

    def select_mode(self, t: int) -> int:
        """Return ``t`` modulo the number of modes."""
        phase = t % self.M
        return int(phase)


class RandomModeSensorPolicy(SensorPolicy):
    """Benchmark sensor that draws a uniformly random mode at every step.

    The policy owns a seeded numpy Generator. ``reset`` does not reseed it, so
    the mode sequence continues across episodes.
    """

    def __init__(self, M: int, seed: int = 0, name: str = "S_RandomMode"):
        """Store the number of modes and create the generator from ``seed``."""
        self.M = int(M)
        generator = np.random.default_rng(int(seed))
        self.rng = generator
        self.name = name

    def reset(self) -> None:
        """No per-episode state is reset; the generator keeps advancing."""

    def select_mode(self, t: int) -> int:
        """Draw a mode uniformly from ``0 .. M-1``; ``t`` is ignored."""
        drawn = self.rng.integers(0, self.M)
        return int(drawn)


# =============================================================================
# A* (used for planning)
# =============================================================================

def astar_path(
    grid: GridConfig,
    start: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
    max_expansions: int = 250_000,
) -> List[Tuple[int, int]]:
    """A* search on the 4-connected grid with a Manhattan-distance heuristic.

    Args:
        grid: Provides ``width``, ``height`` and the blocked ``obstacles`` cells.
        start: Start cell.
        goal: Goal cell.
        step_cost: Cost of moving from one cell to an adjacent cell.
        max_expansions: Upper bound on the number of nodes popped from the queue.

    Returns:
        The list of cells from ``start`` to ``goal`` inclusive (``[start]`` if they coincide).

    Raises:
        RuntimeError: if the expansion budget is exceeded or the goal is unreachable.
    """
    if start == goal:
        return [start]

    moves = ((1, 0), (-1, 0), (0, 1), (0, -1))

    def manhattan(cell: Tuple[int, int]) -> float:
        return abs(cell[0] - goal[0]) + abs(cell[1] - goal[1])

    def passable(cell: Tuple[int, int]) -> bool:
        inside = 0 <= cell[0] < grid.width and 0 <= cell[1] < grid.height
        return inside and cell not in grid.obstacles

    parent: Dict[Tuple[int, int], Optional[Tuple[int, int]]] = {start: None}
    best_cost: Dict[Tuple[int, int], float] = {start: 0.0}

    # Queue entries are (f = g + h, g, cell).
    frontier: List[Tuple[float, float, Tuple[int, int]]] = []
    heapq.heappush(frontier, (manhattan(start), 0.0, start))

    popped = 0
    while frontier:
        node = heapq.heappop(frontier)[2]
        popped += 1

        if node == goal:
            # Walk the parent links back to the start, then flip the order.
            route: List[Tuple[int, int]] = []
            step: Optional[Tuple[int, int]] = node
            while step is not None:
                route.append(step)
                step = parent[step]
            return route[::-1]

        if popped > max_expansions:
            raise RuntimeError("A* exceeded max expansions")

        x, y = node
        for dx, dy in moves:
            nxt = (x + dx, y + dy)
            if not passable(nxt):
                continue
            candidate = best_cost[node] + float(step_cost(node, nxt))
            # Only a strictly better cost (beyond a tiny tolerance) replaces a known one.
            if (nxt not in best_cost) or (candidate < best_cost[nxt] - 1e-12):
                best_cost[nxt] = candidate
                parent[nxt] = node
                heapq.heappush(frontier, (candidate + manhattan(nxt), candidate, nxt))

    raise RuntimeError("A* failed: unreachable goal")


def astar_via_waypoint(
    grid: GridConfig,
    start: Tuple[int, int],
    waypoint: Tuple[int, int],
    goal: Tuple[int, int],
    step_cost: Callable[[Tuple[int, int], Tuple[int, int]], float],
) -> List[Tuple[int, int]]:
    """Plan ``start -> waypoint -> goal`` as two independent A* legs.

    Returns the concatenation of both legs with the waypoint listed once.
    Any ``RuntimeError`` raised by ``astar_path`` propagates unchanged.
    """
    first_leg = astar_path(grid, start, waypoint, step_cost)
    second_leg = astar_path(grid, waypoint, goal, step_cost)
    # The first leg ends on the waypoint and the second starts on it;
    # drop the first copy so the cell is not duplicated.
    return [*first_leg[:-1], *second_leg]


# =============================================================================
# Rollouts + payoffs
# =============================================================================

@dataclass
class EpisodeStats:
    """Summary of one simulated episode, as filled in by ``rollout_episode``."""

    # Environment time counter (env.t) when the episode ended.
    steps: int
    # True when the final robot position equals the grid's goal cell.
    reached_goal: bool
    # Copy of the environment's `detected` flag at episode end.
    detected: bool
    # Sum over steps of the per-step true detection probability ("p_true").
    total_true_risk: float
    # Robot utility: -(steps + lambda_risk * total_true_risk + detection penalty).
    U_R: float
    # Sensor utility: detection penalty + lambda_risk * total_true_risk
    # - sensor_energy_per_step * steps.
    U_S: float


def rollout_episode(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
    lambda_risk: float = 1.0,
    det_penalty: float = 50.0,
    sensor_energy_per_step: float = 0.2,
    step_debug: bool = False,
) -> EpisodeStats:
    """Simulate a single robot-vs-sensor episode and score it for both players.

    The environment is seeded and reset with the sensor's mode for ``t=0``,
    a fresh ``ModeBelief`` is created, and then for at most ``max_steps``
    steps: the sensor picks a mode, the robot acts, the environment steps,
    and the belief is updated from the alarm observation.

    Args:
        env: environment to run in (it is seeded and reset here).
        robot: robot policy; ``reset(pos)`` and ``act(env, belief, last_alarm)`` are used.
        sensor: sensor policy; ``reset()`` and ``select_mode(t)`` are used.
        M_modes: number of sensor modes tracked by the belief.
        seed: value passed to ``env.seed``.
        max_steps: upper bound on environment steps.
        lambda_risk: weight on accumulated true detection probability.
        det_penalty: amount charged to the robot (credited to the sensor) on detection.
        sensor_energy_per_step: per-step cost subtracted from the sensor utility.
        step_debug: when true, print one line per step.

    Returns:
        An ``EpisodeStats`` with the episode summary and both utilities.
    """
    env.seed(seed)
    sensor.reset()
    starting_mode = sensor.select_mode(0)
    env.reset(sensor_mode=starting_mode)

    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    risk_sum = 0.0
    prev_alarm: Optional[int] = None  # the robot sees no alarm before its first move

    for _ in range(max_steps):
        # The sensor re-selects its mode before every robot action.
        env.sensor_mode = sensor.select_mode(env.t)
        a = robot.act(env, belief, prev_alarm)
        out = env.step(a)

        risk_sum += float(out["p_true"])
        belief.update(env, out["alarm"], out["pos"])
        prev_alarm = int(out["alarm"])

        if step_debug:
            print(
                f"[Step] t={out['t']:3d} pos={out['pos']} a={a} p_true={out['p_true']:.3f} "
                f"alarm={out['alarm']} det={out['detected']} done={out['done']} b={belief.b.round(3)}"
            )

        if out["done"]:
            break

    at_goal = (env.pos == env.grid.goal)
    was_detected = bool(env.detected)
    n_steps = int(env.t)

    # Shared terms of the two utilities.
    penalty = det_penalty if was_detected else 0.0
    weighted_risk = lambda_risk * risk_sum

    robot_cost = n_steps + weighted_risk + penalty
    robot_utility = -float(robot_cost)
    sensor_utility = float(penalty + weighted_risk - sensor_energy_per_step * n_steps)

    return EpisodeStats(
        steps=n_steps,
        reached_goal=at_goal,
        detected=was_detected,
        total_true_risk=float(risk_sum),
        U_R=robot_utility,
        U_S=sensor_utility,
    )


def evaluate_payoffs(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    debug_rollout_pair: Optional[Tuple[int, int]] = None,
) -> Tuple[np.ndarray, np.ndarray, Dict[Tuple[int, int], Dict[str, float]]]:
    """Estimate the empirical payoff matrices by Monte Carlo rollouts.

    Every (robot policy i, sensor policy j) pair is simulated ``rollouts``
    times with seed ``base_seed + 100000*i + 1000*j + k``.

    Args:
        env: environment shared by all rollouts.
        robots: robot policies (rows).
        sensors: sensor policies (columns).
        M_modes: number of sensor modes, forwarded to ``rollout_episode``.
        rollouts: episodes per pair.
        base_seed: seed offset.
        log: logger; ``info`` is always used, the summary table needs ``log.k >= 1``.
        debug_rollout_pair: optional ``(i, j)`` whose first rollout is printed step by step.

    Returns:
        ``(U_R, U_S, diag)`` where the matrices hold mean utilities and
        ``diag[(i, j)]`` holds detection/goal rates, mean steps, mean risk and
        the standard deviations of both utilities.
    """
    m, n = len(robots), len(sensors)
    U_R = np.zeros((m, n), dtype=float)
    U_S = np.zeros((m, n), dtype=float)
    diag: Dict[Tuple[int, int], Dict[str, float]] = {}

    log.info(f"[Eval] Estimating payoffs: m={m}, n={n}, rollouts={rollouts}, base_seed={base_seed}")

    for i, rpol in enumerate(robots):
        for j, spol in enumerate(sensors):
            trace_this_pair = (debug_rollout_pair == (i, j))
            pair_seed = base_seed + 100000 * i + 1000 * j

            # Only the very first rollout of the debug pair is traced.
            episodes = [
                rollout_episode(
                    env,
                    rpol,
                    spol,
                    M_modes=M_modes,
                    seed=pair_seed + k,
                    step_debug=(trace_this_pair and k == 0),
                )
                for k in range(rollouts)
            ]

            robot_utils = [ep.U_R for ep in episodes]
            sensor_utils = [ep.U_S for ep in episodes]
            n_detected = sum(int(ep.detected) for ep in episodes)
            n_goal = sum(int(ep.reached_goal) for ep in episodes)

            U_R[i, j] = float(np.mean(robot_utils))
            U_S[i, j] = float(np.mean(sensor_utils))

            diag[(i, j)] = {
                "det_rate": n_detected / rollouts,
                "goal_rate": n_goal / rollouts,
                "mean_steps": float(np.mean([ep.steps for ep in episodes])),
                "mean_risk": float(np.mean([ep.total_true_risk for ep in episodes])),
                "std_UR": float(np.std(robot_utils)),
                "std_US": float(np.std(sensor_utils)),
            }

    if log.k >= 1:
        log.info("[Eval] Compact payoff summary:")
        # diag was filled row-major, so iterating it reproduces the (i, j) order.
        for (i, j), d in diag.items():
            rpol, spol = robots[i], sensors[j]
            log.info(
                f"  (R{i}:{rpol.name}, S{j}:{spol.name}) "
                f"UR={U_R[i,j]:8.3f}±{d['std_UR']:.2f} | "
                f"US={U_S[i,j]:8.3f}±{d['std_US']:.2f} | "
                f"det%={100*d['det_rate']:5.1f} goal%={100*d['goal_rate']:5.1f} "
                f"steps={d['mean_steps']:.1f} risk={d['mean_risk']:.2f}"
            )

    return U_R, U_S, diag


# =============================================================================
# NBS solver (with optional entropy regularization)
# =============================================================================

def project_simplex(v: np.ndarray, z: float = 1.0) -> np.ndarray:
    """Euclidean projection of ``v`` onto the simplex ``{w >= 0, sum(w) = z}``.

    Uses the sort-and-threshold scheme: sort descending, find the last index
    whose entry stays positive after the shift, subtract that shift and clip.
    Falls back to the uniform point ``z / len(v)`` when no index qualifies or
    the clipped mass is not a positive finite number.

    Args:
        v: input vector (flattened to 1-D).
        z: total mass of the target simplex; must be positive.

    Raises:
        ValueError: for an empty vector or ``z <= 0``.
    """
    vec = np.asarray(v, dtype=float).reshape(-1)
    size = vec.size
    if size == 0:
        raise ValueError("Empty vector")
    if z <= 0:
        raise ValueError("z must be > 0")

    def uniform_point() -> np.ndarray:
        return np.full_like(vec, z / size)

    descending = np.sort(vec)[::-1]
    running_total = np.cumsum(descending)
    ranks = np.arange(1, size + 1)

    active = np.flatnonzero(descending * ranks > (running_total - z))
    if active.size == 0:
        return uniform_point()

    last = int(active[-1])
    shift = (running_total[last] - z) / (last + 1.0)
    clipped = np.maximum(vec - shift, 0.0)

    mass = float(clipped.sum())
    if np.isfinite(mass) and mass > 0:
        # Rescale so round-off cannot leave the total slightly off z.
        return clipped * (z / mass)
    return uniform_point()


@dataclass
class NBSResult:
    """Output of ``solve_nbs``."""

    # Final joint distribution over the flattened joint actions.
    x: np.ndarray
    # Objective value at `x` (log-gains plus the optional entropy term).
    obj: float
    # (robot gain, sensor gain) at `x`, i.e. expected payoff minus the
    # disagreement value, without the kappa clamp used inside the objective.
    gains: Tuple[float, float]
    # Number of entries of `x` larger than 1e-6.
    support: int


def solve_nbs(
    uR: np.ndarray,
    uS: np.ndarray,
    log: Logger,
    max_iters: int = 400,
    alpha: float = 0.5,
    tol_l1: float = 1e-6,
    kappa: float = 1e-6,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
) -> NBSResult:
    """Maximise the Nash-bargaining objective over the joint-action simplex.

    The objective is ``log(uR.x - dR) + log(uS.x - dS) + entropy_tau * H(x)``
    and is climbed by projected gradient ascent with a backtracking step size,
    starting from the uniform distribution.

    Args:
        uR: robot payoffs per joint action (flattened internally).
        uS: sensor payoffs per joint action (same shape as ``uR``).
        log: logger; ``info`` is always used, ``debug`` only when ``log.k >= 2``.
        max_iters: maximum number of outer gradient iterations.
        alpha: initial step size tried at every iteration.
        tol_l1: stop once the L1 change of ``x`` in one iteration is below this.
        kappa: floor applied to each gain before taking its log / reciprocal.
        disagreement: ``"minminus"`` (each player's minimum payoff minus 1) or
            ``"uniform"`` (each player's expected payoff under uniform play).
        entropy_tau: weight of the entropy regulariser; 0 disables it.

    Returns:
        ``NBSResult`` with the final ``x``, its objective, the unclamped gains
        and the number of entries above 1e-6.

    Raises:
        ValueError: on shape mismatch, fewer than two joint actions, or an
            unknown ``disagreement`` name.
    """
    uR = np.asarray(uR, dtype=float).reshape(-1)
    uS = np.asarray(uS, dtype=float).reshape(-1)
    if uR.shape != uS.shape:
        raise ValueError("uR and uS must have same shape")
    d = uR.size
    if d < 2:
        raise ValueError("Need >=2 joint actions")

    unif = np.full(d, 1.0 / d)

    # Disagreement point (dR, dS): the payoffs each player is compared against.
    disagreement = disagreement.lower().strip()
    if disagreement == "minminus":
        # Strictly below every attainable payoff, so gains start positive.
        dR = float(np.min(uR) - 1.0)
        dS = float(np.min(uS) - 1.0)
    elif disagreement == "uniform":
        # Gains are exactly zero at the uniform start and rely on the kappa floor.
        dR = float(uR @ unif)
        dS = float(uS @ unif)
    else:
        raise ValueError("disagreement must be 'minminus' or 'uniform'")

    x = unif.copy()

    def gains(xv: np.ndarray) -> Tuple[float, float]:
        # Raw (unclamped) surplus of each player over the disagreement point.
        return float(uR @ xv - dR), float(uS @ xv - dS)

    def entropy(xv: np.ndarray) -> float:
        # Shannon entropy with entries clipped away from 0 to keep log finite.
        xx = np.clip(xv, 1e-12, 1.0)
        return float(-np.sum(xx * np.log(xx)))

    def obj(xv: np.ndarray) -> float:
        gR, gS = gains(xv)
        # Clamp so the logs stay defined when a gain is zero or negative.
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        return float(np.log(gR) + np.log(gS) + entropy_tau * entropy(xv))

    def grad(xv: np.ndarray) -> np.ndarray:
        gR, gS = gains(xv)
        # Same clamp as in obj(), applied before dividing.
        gR = max(gR, kappa)
        gS = max(gS, kappa)
        g = (uR / gR) + (uS / gS)
        if entropy_tau > 0:
            xx = np.clip(xv, 1e-12, 1.0)
            g += entropy_tau * (-(np.log(xx) + 1.0))
        return g

    last = obj(x)
    log.info(f"[NBS] d={d} disagreement=({dR:.3f},{dS:.3f}) entropy_tau={entropy_tau:.3g}")

    for t in range(1, max_iters + 1):
        g = grad(x)
        a = alpha
        improved = False
        # Backtracking: halve the step (up to 30 times, or until it drops below
        # 1e-6) until the projected point does not lower the objective by more
        # than 1e-12.
        for _ in range(30):
            x_new = project_simplex(x + a * g)
            new_obj = obj(x_new)
            if new_obj >= last - 1e-12:
                improved = True
                break
            a *= 0.5
            if a < 1e-6:
                break
        if not improved:
            # No acceptable step size: keep the current x and stop.
            break

        delta = float(np.linalg.norm(x_new - x, ord=1))
        x = x_new
        last = new_obj

        # Trace the first five iterations and then every 25th.
        if log.k >= 2 and (t <= 5 or t % 25 == 0):
            gR, gS = gains(x)
            top = np.argsort(-x)[:5]
            top_str = ", ".join([f"{i}:{x[i]:.3f}" for i in top])
            log.debug(f"[NBS][it={t:3d}] obj={last:.6f} gains=({gR:.3f},{gS:.3f}) L1={delta:.2e} top={top_str}")

        if delta < tol_l1:
            break

    gR, gS = gains(x)
    support = int(np.sum(x > 1e-6))
    log.info(f"[NBS] done: obj={last:.6f} gains=({gR:.3f},{gS:.3f}) support={support}/{d}")

    return NBSResult(x=x, obj=float(last), gains=(float(gR), float(gS)), support=support)


def joint_to_matrix(x: np.ndarray, m: int, n: int) -> np.ndarray:
    """Reshape a flat joint distribution into an ``(m, n)`` matrix that sums to 1.

    Args:
        x: flat vector of length ``m * n`` (row-major: index ``i * n + j``).
        m: number of robot policies (rows).
        n: number of sensor policies (columns).

    Returns:
        The ``(m, n)`` matrix. It is returned as-is when its sum is within
        1e-6 of 1, rescaled by its sum otherwise, and replaced by the uniform
        matrix when the sum is non-finite or not positive (there is then no
        meaningful mass to rescale).

    Raises:
        ValueError: if ``x`` does not have exactly ``m * n`` entries.
    """
    x = np.asarray(x, dtype=float).reshape(-1)
    if x.size != m * n:
        raise ValueError("x size mismatch")
    X = x.reshape((m, n))
    if X.size == 0:
        # Nothing to normalise (and no uniform value to compute).
        return X

    s = float(X.sum())
    if not np.isfinite(s) or s <= 0.0:
        # Dividing by a NaN/inf/zero/negative total would give NaNs, zeros or
        # huge values; use the same uniform fallback as the other helpers.
        return np.full((m, n), 1.0 / (m * n))
    if abs(s - 1.0) > 1e-6:
        # Renormalize defensively
        X = X / s
    return X


def marginals_from_joint(X: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Return the row and column marginals ``(sigma_R, sigma_S)`` of a joint matrix.

    Each marginal is rescaled to sum to 1 when its total is positive and is
    returned unscaled otherwise.
    """

    def unit_mass(vec: np.ndarray) -> np.ndarray:
        total = vec.sum()
        return vec / total if total > 0 else vec

    row_sums = X.sum(axis=1)  # robot side
    col_sums = X.sum(axis=0)  # sensor side
    return unit_mass(row_sums), unit_mass(col_sums)


def entropy_of_joint(X: np.ndarray) -> float:
    """Shannon entropy (natural log) of the flattened joint matrix.

    Entries are clipped to ``[1e-12, 1]`` first so that zeros do not produce
    ``log(0)``.
    """
    probs = np.clip(X.reshape(-1), 1e-12, 1.0)
    log_probs = np.log(probs)
    return float(-np.sum(probs * log_probs))


# =============================================================================
# Best responses (marginal vs correlated)
# =============================================================================

def compute_expected_risk_map_from_policy_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi: np.ndarray,
    M_modes: int,
) -> np.ndarray:
    """Per-cell expected detection probability under a mixture of sensor policies.

    Each sensor policy is represented by the mode it returns for ``t=0``
    (exact for fixed-mode policies). The mixture weights are accumulated per
    mode, normalised when their total is positive, and used to average
    ``env.true_detection_prob(cell, mode)``.

    Args:
        env: environment providing ``grid`` and ``true_detection_prob``.
        sensors: sensor policies the mixture ranges over.
        pi: mixture weights, one per sensor policy.
        M_modes: number of sensor modes.

    Returns:
        Array of shape ``(height, width)`` indexed ``[y, x]``; obstacle cells
        hold ``NaN``.

    Raises:
        ValueError: if ``pi`` and ``sensors`` differ in length, or a policy
            reports a mode outside ``0 .. M_modes - 1``.
    """
    weights = np.asarray(pi, dtype=float).reshape(-1)
    if weights.size != len(sensors):
        raise ValueError("mixture length mismatch")

    # Collapse the policy mixture into a distribution over modes.
    mode_probs = np.zeros(M_modes, dtype=float)
    for weight, policy in zip(weights, sensors):
        mode = int(policy.select_mode(0))
        if mode < 0 or mode >= M_modes:
            raise ValueError("invalid sensor mode")
        mode_probs[mode] += float(weight)
    total = mode_probs.sum()
    if total > 0:
        mode_probs = mode_probs / total

    grid = env.grid
    n_rows, n_cols = grid.height, grid.width

    def expected_risk(cell: Tuple[int, int]) -> float:
        acc = 0.0
        for mode in range(M_modes):
            acc += float(mode_probs[mode]) * float(env.true_detection_prob(cell, mode))
        return float(acc)

    risk = np.zeros((n_rows, n_cols), dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = (col, row)  # cells are (x, y); the array is indexed [y, x]
            risk[row, col] = np.nan if cell in grid.obstacles else expected_risk(cell)

    return risk


def plan_risk_weighted_path(env: GridWorldStealthEnv, risk_map: np.ndarray, risk_weight: float) -> List[Tuple[int, int]]:
    """A* path from the grid's start to its goal, penalising risky cells.

    Entering a cell costs ``1 + risk_weight * risk_map[y, x]``; cells whose
    risk is not finite (e.g. NaN) cost ``1e9`` instead.
    """

    def step_cost(frm: Tuple[int, int], to: Tuple[int, int]) -> float:
        col, row = to
        cell_risk = risk_map[row, col]
        if np.isfinite(cell_risk):
            return 1.0 + float(risk_weight) * float(cell_risk)
        return 1e9

    grid = env.grid
    return astar_path(grid, grid.start, grid.goal, step_cost)


def robot_best_response_to_mixture(
    env: GridWorldStealthEnv,
    sensors: List[SensorPolicy],
    pi_S: np.ndarray,
    robots: List[RobotPolicy],
    M_modes: int,
    risk_weight: float,
    log: Logger,
    tag: str,
) -> RobotPolicy:
    """Robot best response to a sensor mixture: a risk-weighted A* path.

    Builds the expected risk map induced by ``pi_S`` over ``sensors`` and
    plans through it.

    Returns:
        * ``robots[0]`` when planning fails (a warning is logged);
        * the existing ``FixedPathPolicy`` in ``robots`` with the same path,
          if there is one;
        * otherwise a new ``FixedPathPolicy`` (the caller decides whether to
          add it to ``robots``; this function does not modify the list).
    """
    risk_map = compute_expected_risk_map_from_policy_mixture(env, sensors, pi_S, M_modes=M_modes)
    try:
        path = plan_risk_weighted_path(env, risk_map, risk_weight=risk_weight)
    except Exception as e:
        log.info(f"[RobotBR] WARNING A* failed ({tag}): {e}")
        return robots[0]

    wanted = tuple(path)
    p = next(
        (cand for cand in robots if isinstance(cand, FixedPathPolicy) and tuple(cand.path) == wanted),
        None,
    )
    if p is not None:
        log.info(f"[RobotBR] ({tag}) BR path already exists: {p.name}")
        return p

    newp = FixedPathPolicy(path, name=f"R_BR_{tag}_w{risk_weight:.1f}_len{len(path)}")
    log.info(f"[RobotBR] ({tag}) Added new robot policy: {newp.name}")
    return newp


def sensor_best_response_to_mixture(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    pi_R: np.ndarray,
    candidate_modes: List[int],
    M_modes: int,
    rollouts: int,
    base_seed: int,
    log: Logger,
    tag: str,
) -> FixedModeSensorPolicy:
    """Pick the fixed sensor mode with the highest mean utility against ``pi_R``.

    Common random numbers: one set of sampled robot indices and one set of
    episode seeds is drawn up front and reused for every candidate mode, so
    the comparison between modes is not driven by differing rollout noise.

    Args:
        env: environment used for the rollouts.
        robots: robot policies the mixture ranges over.
        pi_R: mixture weights, one per robot policy.
        candidate_modes: modes to try; values outside ``0 .. M_modes - 1`` are skipped.
        M_modes: number of sensor modes.
        rollouts: episodes per candidate mode.
        base_seed: seeds both the robot-index sampler and the episode seeds.
        log: logger (``debug`` lines need ``log.k >= 2``).
        tag: label embedded in policy names and log lines.

    Returns:
        A new ``FixedModeSensorPolicy`` for the winning mode.

    Raises:
        ValueError: if ``pi_R`` and ``robots`` differ in length.
        RuntimeError: if no candidate mode was valid.
    """
    pi_R = np.asarray(pi_R, dtype=float).reshape(-1)
    if pi_R.size != len(robots):
        raise ValueError("pi_R length mismatch")

    rng = np.random.default_rng(int(base_seed))

    # Drawn once, shared by all candidate modes.
    robot_idxs = rng.choice(len(robots), size=rollouts, p=pi_R, replace=True)
    seeds = (int(base_seed) + np.arange(rollouts)).astype(int)

    best_mode: Optional[int] = None
    best_val = -1e18

    for mode in candidate_modes:
        if not (0 <= mode < M_modes):
            continue
        sp = FixedModeSensorPolicy(mode, name=f"S_BR_{tag}_Mode{mode}")

        vals: List[float] = [
            rollout_episode(env, robots[int(idx)], sp, M_modes=M_modes, seed=int(ep_seed)).U_S
            for idx, ep_seed in zip(robot_idxs, seeds)
        ]
        mean_u = float(np.mean(vals))

        if log.k >= 2:
            log.debug(f"[SensorBR] ({tag}) mode={mode} E[US]={mean_u:.3f} std={float(np.std(vals)):.2f}")

        # Strict comparison: on ties the earlier candidate wins.
        if mean_u > best_val:
            best_val, best_mode = mean_u, mode

    if best_mode is None:
        raise RuntimeError("No valid sensor BR mode found")

    log.info(f"[SensorBR] ({tag}) Best mode={best_mode} E[US]={best_val:.3f}")
    return FixedModeSensorPolicy(best_mode, name=f"S_BR_{tag}_Mode{best_mode}")


def conditional_sensor_given_robot(X: np.ndarray, i: int, eps: float = 1e-12) -> np.ndarray:
    """Conditional distribution over sensor policies given robot recommendation ``i``.

    This is row ``i`` of ``X`` divided by its sum; when that sum is at most
    ``eps`` the uniform distribution is returned instead.
    """
    weights = np.asarray(X[i, :], dtype=float)
    mass = float(weights.sum())
    return np.full_like(weights, 1.0 / weights.size) if mass <= eps else weights / mass


def conditional_robot_given_sensor(X: np.ndarray, j: int, eps: float = 1e-12) -> np.ndarray:
    """Conditional distribution over robot policies given sensor recommendation ``j``.

    This is column ``j`` of ``X`` divided by its sum; when that sum is at most
    ``eps`` the uniform distribution is returned instead.
    """
    weights = np.asarray(X[:, j], dtype=float)
    mass = float(weights.sum())
    return np.full_like(weights, 1.0 / weights.size) if mass <= eps else weights / mass


def compute_ce_regrets(U_R: np.ndarray, U_S: np.ndarray, X: np.ndarray, eps: float = 1e-12) -> Dict[str, float]:
    """Conditional (correlated-equilibrium style) regrets on the current meta-game.

    Robot, recommended ``i`` with conditional ``q(.|i)`` over sensor policies:
        regret_R(i) = max_{i'} E_{j~q(.|i)}[U_R(i', j) - U_R(i, j)]

    Sensor, recommended ``j`` with conditional ``q(.|j)`` over robot policies:
        regret_S(j) = max_{j'} E_{i~q(.|j)}[U_S(i, j') - U_S(i, j)]

    Recommendations whose marginal probability is at most ``eps`` are left out.

    Returns:
        Dict with ``max_regret_R``, ``max_regret_S``, ``mean_regret_R`` and
        ``mean_regret_S`` (each 0.0 when no recommendation qualified).
    """
    m, n = U_R.shape
    assert U_S.shape == (m, n)
    assert X.shape == (m, n)

    sigma_R, sigma_S = marginals_from_joint(X)

    robot_regrets: List[float] = []
    for i in range(m):
        if sigma_R[i] <= eps:
            continue
        q = conditional_sensor_given_robot(X, i, eps=eps)
        followed = float(np.dot(q, U_R[i, :]))
        deviations = [float(np.dot(q, U_R[alt, :])) for alt in range(m)]
        # Listing `followed` first keeps the regret non-negative.
        robot_regrets.append(max([followed, *deviations]) - followed)

    sensor_regrets: List[float] = []
    for j in range(n):
        if sigma_S[j] <= eps:
            continue
        q = conditional_robot_given_sensor(X, j, eps=eps)
        followed = float(np.dot(q, U_S[:, j]))
        deviations = [float(np.dot(q, U_S[:, alt])) for alt in range(n)]
        sensor_regrets.append(max([followed, *deviations]) - followed)

    return {
        "max_regret_R": float(max(robot_regrets) if robot_regrets else 0.0),
        "max_regret_S": float(max(sensor_regrets) if sensor_regrets else 0.0),
        "mean_regret_R": float(np.mean(robot_regrets) if robot_regrets else 0.0),
        "mean_regret_S": float(np.mean(sensor_regrets) if sensor_regrets else 0.0),
    }


def find_sensor_by_mode(pols: List[SensorPolicy], mode: int) -> Optional[FixedModeSensorPolicy]:
    """Return the first ``FixedModeSensorPolicy`` in ``pols`` with the given mode, else ``None``."""
    matches = (p for p in pols if isinstance(p, FixedModeSensorPolicy) and p.mode == mode)
    return next(matches, None)


# =============================================================================
# Policy initialization + evaluator for joint strategy
# =============================================================================

def build_initial_policies(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Create the starting policy sets for both players.

    Robots: the unit-cost shortest path, two paths forced through the
    waypoints ``(width // 2, 2)`` and ``(width // 2, 6)``, a random policy and
    an online belief-replanning policy. Sensors: one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(a, b):
        return 1.0

    mid_x = grid.width // 2
    shortest = astar_path(grid, grid.start, grid.goal, unit_cost)
    via_upper = astar_via_waypoint(grid, grid.start, (mid_x, 2), grid.goal, unit_cost)
    via_lower = astar_via_waypoint(grid, grid.start, (mid_x, 6), grid.goal, unit_cost)

    robots: List[RobotPolicy] = []
    robots.append(FixedPathPolicy(shortest, "R_Shortest"))
    robots.append(FixedPathPolicy(via_upper, "R_UpperCorridor"))
    robots.append(FixedPathPolicy(via_lower, "R_LowerCorridor"))
    robots.append(RandomPolicy("R_Random"))
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))

    sensors: List[SensorPolicy] = []
    for mode in range(M_modes):
        sensors.append(FixedModeSensorPolicy(mode, name=f"S_Mode{mode}"))

    return robots, sensors


@dataclass
class StrategyEval:
    """Averages over the episodes simulated by ``evaluate_joint_strategy``."""

    # Mean robot utility.
    mean_U_R: float
    # Mean sensor utility.
    mean_U_S: float
    # Fraction of episodes in which the robot was detected.
    det_rate: float
    # Fraction of episodes that ended with the robot on the goal cell.
    goal_rate: float
    # Mean episode length in steps.
    mean_steps: float
    # Mean accumulated true detection probability per episode.
    mean_risk: float


def evaluate_joint_strategy(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sensors: List[SensorPolicy],
    X: np.ndarray,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    """Monte Carlo evaluation of play under the joint distribution ``X``.

    For each of ``episodes`` episodes a joint pair ``(i, j)`` is sampled from
    ``X`` (renormalised) and ``robots[i]`` plays against ``sensors[j]`` with
    seed ``base_seed + k``.

    Returns:
        ``StrategyEval`` with mean utilities, detection/goal rates, mean steps
        and mean accumulated risk.
    """
    m, n = X.shape
    flat = X.reshape(-1)
    probs = flat / max(float(flat.sum()), 1e-12)

    rng = np.random.default_rng(int(base_seed))

    played = []
    for k in range(episodes):
        flat_idx = int(rng.choice(m * n, p=probs))
        # Row-major layout: flat_idx == i * n + j.
        i, j = divmod(flat_idx, n)
        played.append(
            rollout_episode(env, robots[i], sensors[j], M_modes=M_modes, seed=base_seed + k)
        )

    n_detected = sum(int(ep.detected) for ep in played)
    n_goal = sum(int(ep.reached_goal) for ep in played)

    return StrategyEval(
        mean_U_R=float(np.mean([ep.U_R for ep in played])),
        mean_U_S=float(np.mean([ep.U_S for ep in played])),
        det_rate=float(n_detected / episodes),
        goal_rate=float(n_goal / episodes),
        mean_steps=float(np.mean([ep.steps for ep in played])),
        mean_risk=float(np.mean([ep.total_true_risk for ep in played])),
    )


# =============================================================================
# Training loop (marginal vs correlated)
# =============================================================================

@dataclass
class TrainHistoryRow:
    """One outer-iteration record appended by ``run_training``."""

    # 1-based outer iteration number.
    outer_iter: int
    # Number of robot policies after this iteration's expansion step.
    m: int
    # Number of sensor policies after this iteration's expansion step.
    n: int
    # Objective value reported by the NBS solver for this iteration.
    nbs_obj: float
    # Entropy of the joint distribution X.
    entropy_X: float
    # Largest conditional regret of the robot / of the sensor.
    max_regret_R: float
    max_regret_S: float
    # Self-play averages under X: utilities, detection rate, goal rate.
    selfplay_UR: float
    selfplay_US: float
    selfplay_det: float
    selfplay_goal: float
    # Wall-clock duration of the iteration in seconds.
    seconds: float


@dataclass
class TrainResult:
    """Everything ``run_training`` hands back to its caller."""

    # Solver name the run was started with ("marginal" or "correlated").
    solver: str
    # Environment the run was performed on.
    env: GridWorldStealthEnv
    # Final policy sets, including policies added during training.
    robots: List[RobotPolicy]
    sensors: List[SensorPolicy]
    # Joint distribution from the last outer iteration. It was computed before
    # that iteration's expansion, so its shape can be smaller than
    # (len(robots), len(sensors)).
    X: np.ndarray
    # One TrainHistoryRow per outer iteration.
    history: List[TrainHistoryRow]


def run_training(env: GridWorldStealthEnv, args: argparse.Namespace, solver: str, log: Logger) -> TrainResult:
    """Run the outer policy-expansion loop with an NBS meta-solver.

    Each outer iteration: estimate payoff matrices by rollouts, solve the
    Nash-bargaining problem over joint actions, report diagnostics (CE-style
    regrets, entropy, self-play), compute best responses and add them to the
    policy sets.

    Args:
        env: environment to train on; ``env.grid`` and ``env.sensor_cfg`` are read.
        args: parsed options. Fields read here: ``debug_rollout_pair``,
            ``outer_iters``, ``rollouts_payoff``, ``disagreement``,
            ``entropy_tau``, ``eval_episodes``, ``risk_weight_br``,
            ``rollouts_br`` and, for the correlated solver, ``cond_top_k``,
            ``br_eval_rollouts`` and ``add_threshold``.
        solver: ``"marginal"`` (best responses to the marginals of X) or
            ``"correlated"`` (best responses to the conditionals of X, added
            only when the estimated gain exceeds ``args.add_threshold``).
        log: logger used for all output.

    Returns:
        ``TrainResult`` with the final policy sets, the last joint
        distribution and the per-iteration history.

    Raises:
        ValueError: for an unknown ``solver``.
        RuntimeError: if no outer iteration ran (``args.outer_iters < 1``).
    """
    t0_all = time.time()

    grid = env.grid
    sensor_cfg = env.sensor_cfg
    M_modes = len(sensor_cfg.sensors)

    robots, sensors = build_initial_policies(env, M_modes=M_modes)

    # Optional: reduce initial set if you want smaller games.
    # (We keep it as-is for benchmarks.)

    if log.k >= 1:
        log.info(f"[{solver}] Initial robots: " + ", ".join([p.name for p in robots]))
        log.info(f"[{solver}] Initial sensors: " + ", ".join([p.name for p in sensors]))

    # "i,j" selects one (robot, sensor) pair whose first rollout is traced.
    debug_pair = None
    if args.debug_rollout_pair:
        parts = args.debug_rollout_pair.split(",")
        if len(parts) == 2:
            debug_pair = (int(parts[0]), int(parts[1]))
            log.info(f"[{solver}] Will print one step-by-step rollout for pair {debug_pair} (only once).")

    history: List[TrainHistoryRow] = []
    X = None

    for it in range(1, args.outer_iters + 1):
        t0 = time.time()
        log.banner(f"[{solver}] Outer iter {it}/{args.outer_iters}")

        U_R, U_S, _diag = evaluate_payoffs(
            env,
            robots,
            sensors,
            M_modes=M_modes,
            rollouts=args.rollouts_payoff,
            base_seed=1000 + 100 * it,
            log=log,
            debug_rollout_pair=debug_pair,
        )
        # The step-by-step trace is only wanted in the first outer iteration.
        debug_pair = None

        # Solve NBS over joint actions
        uR = U_R.reshape(-1)
        uS = U_S.reshape(-1)

        nbs = solve_nbs(
            uR,
            uS,
            log=log,
            disagreement=args.disagreement,
            entropy_tau=args.entropy_tau,
        )

        m, n = U_R.shape
        X = joint_to_matrix(nbs.x, m, n)
        sigma_R, sigma_S = marginals_from_joint(X)

        # Print top joint actions
        top = np.argsort(-X.reshape(-1))[:min(5, X.size)]
        log.info(f"[{solver}] Top joint actions:")
        for k, idx in enumerate(top, start=1):
            i, j = np.unravel_index(int(idx), (m, n))
            log.info(f"  #{k}: (R{i}:{robots[i].name}, S{j}:{sensors[j].name}) prob={X[i,j]:.4f}")
        log.info(f"[{solver}] sigma_R={sigma_R.round(3)}")
        log.info(f"[{solver}] sigma_S={sigma_S.round(3)}")

        # Stability diagnostics
        regrets = compute_ce_regrets(U_R, U_S, X)
        ent = entropy_of_joint(X)

        # Self-play evaluation under joint X
        sp = evaluate_joint_strategy(
            env,
            robots,
            sensors,
            X,
            M_modes=M_modes,
            episodes=args.eval_episodes,
            base_seed=9000 + 100 * it,
        )

        log.info(
            f"[{solver}] CE-regrets: maxR={regrets['max_regret_R']:.3f} maxS={regrets['max_regret_S']:.3f} | "
            f"SelfPlay: UR={sp.mean_U_R:.2f} US={sp.mean_U_S:.2f} det%={100*sp.det_rate:.1f} goal%={100*sp.goal_rate:.1f} | "
            f"H(X)={ent:.3f}"
        )

        # Best-response expansion.
        # br_r / br_s hold the policy to add this iteration, or None for "add nothing".
        br_r: Optional[RobotPolicy]
        if solver == "marginal":
            br_r = robot_best_response_to_mixture(
                env,
                sensors,
                pi_S=sigma_S,
                robots=robots,
                M_modes=M_modes,
                risk_weight=args.risk_weight_br,
                log=log,
                tag="Marginal",
            )

            br_s = sensor_best_response_to_mixture(
                env,
                robots,
                pi_R=sigma_R,
                candidate_modes=list(range(M_modes)),
                M_modes=M_modes,
                rollouts=args.rollouts_br,
                base_seed=2000 + 100 * it,
                log=log,
                tag="Marginal",
            )

        elif solver == "correlated":
            # FIX: choose conditional mixtures q(.|i) and q(.|j)
            # We only check top-K recommendations to keep it tractable.
            topK = max(1, int(args.cond_top_k))

            # Robot: check the top-K robot recommendations by sigma_R
            cand_i = [int(i) for i in np.argsort(-sigma_R) if sigma_R[int(i)] > 1e-8][:topK]
            if not cand_i:
                cand_i = [int(i) for i in np.argsort(-sigma_R)[:topK]]
            best_gain = 0.0
            br_r = robots[0]

            for i in cand_i:
                qS = conditional_sensor_given_robot(X, int(i))
                tag = f"Cond_i{i}"
                pol = robot_best_response_to_mixture(
                    env,
                    sensors,
                    pi_S=qS,
                    robots=robots,
                    M_modes=M_modes,
                    risk_weight=args.risk_weight_br,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement: E_q[U_R(pol,j)] - E_q[U_R(i,j)]
                # If pol is already in set, its payoff exists in U_R row of that policy.
                # Otherwise we simulate pol against each sensor policy j.
                if pol in robots:
                    ip = robots.index(pol)
                    dev = float(np.dot(qS, U_R[ip, :]))
                else:
                    # simulate quickly vs each sensor policy
                    dev_vals = []
                    for j in range(n):
                        vals = []
                        for kk in range(args.br_eval_rollouts):
                            seed = 777000 + 1000 * it + 100 * i + 10 * j + kk
                            st = rollout_episode(env, pol, sensors[j], M_modes=M_modes, seed=seed)
                            vals.append(st.U_R)
                        dev_vals.append(float(np.mean(vals)))
                    dev = float(np.dot(qS, np.asarray(dev_vals)))

                rec = float(np.dot(qS, U_R[i, :]))
                gain = dev - rec
                if gain > best_gain + 1e-9:
                    best_gain = gain
                    br_r = pol

            if best_gain > args.add_threshold:
                if br_r in robots:
                    log.info(f"[{solver}] Best robot deviation already in set; est_gain={best_gain:.3f}")
                else:
                    log.info(f"[{solver}] Adding robot deviation; est_gain={best_gain:.3f}")
            else:
                log.info(f"[{solver}] No robot deviation above threshold (best_gain={best_gain:.3f}).")
                # Below the threshold nothing may be added. Without clearing
                # br_r, a new policy with a small positive gain would still be
                # appended below, contradicting the message just logged.
                # This mirrors the sensor branch, which clears br_s.
                br_r = None

            # Sensor: check top-K sensor recommendations by sigma_S
            cand_j = [int(j) for j in np.argsort(-sigma_S) if sigma_S[int(j)] > 1e-8][:topK]
            if not cand_j:
                cand_j = [int(j) for j in np.argsort(-sigma_S)[:topK]]
            best_gain_s = 0.0
            br_s = FixedModeSensorPolicy(0, name="S_dummy")

            for j in cand_j:
                qR = conditional_robot_given_sensor(X, int(j))
                tag = f"Cond_j{j}"
                polS = sensor_best_response_to_mixture(
                    env,
                    robots,
                    pi_R=qR,
                    candidate_modes=list(range(M_modes)),
                    M_modes=M_modes,
                    rollouts=args.rollouts_br,
                    base_seed=333000 + 1000 * it + 10 * j,
                    log=log,
                    tag=tag,
                )

                # Estimate conditional improvement for sensor
                # rec under recommendation j is E_q[U_S(i,j)]
                recS = float(np.dot(qR, U_S[:, j]))

                # dev under mode polS.mode: if already present, use its column.
                existing_col = None
                for jj, spj in enumerate(sensors):
                    if isinstance(spj, FixedModeSensorPolicy) and spj.mode == polS.mode:
                        existing_col = jj
                        break

                if existing_col is not None:
                    devS = float(np.dot(qR, U_S[:, existing_col]))
                else:
                    vals = []
                    for kk in range(args.br_eval_rollouts):
                        i_samp = int(np.random.default_rng(444 + kk).choice(len(robots), p=qR))
                        seed = 888000 + 1000 * it + 10 * j + kk
                        st = rollout_episode(env, robots[i_samp], polS, M_modes=M_modes, seed=seed)
                        vals.append(st.U_S)
                    devS = float(np.mean(vals))

                gainS = devS - recS
                if gainS > best_gain_s + 1e-9:
                    best_gain_s = gainS
                    br_s = polS

            if best_gain_s > args.add_threshold:
                if (isinstance(br_s, FixedModeSensorPolicy)) and (find_sensor_by_mode(sensors, br_s.mode) is not None):
                    log.info(f"[{solver}] Best sensor deviation already in set (mode={br_s.mode}); est_gain={best_gain_s:.3f}")
                    br_s = None
                else:
                    log.info(f"[{solver}] Adding sensor deviation; est_gain={best_gain_s:.3f}")
            else:
                log.info(f"[{solver}] No sensor deviation above threshold (best_gain={best_gain_s:.3f}).")
                br_s = None

        else:
            raise ValueError("solver must be marginal or correlated")

        # Add to sets (dedupe)
        if br_r is not None and br_r not in robots:
            robots.append(br_r)

        if isinstance(br_s, FixedModeSensorPolicy):
            if find_sensor_by_mode(sensors, br_s.mode) is None:
                sensors.append(br_s)
            else:
                log.info(f"[{solver}] Sensor mode {br_s.mode} already present; not adding duplicate.")

        seconds = float(time.time() - t0)
        history.append(
            TrainHistoryRow(
                outer_iter=it,
                m=len(robots),
                n=len(sensors),
                nbs_obj=float(nbs.obj),
                entropy_X=float(ent),
                max_regret_R=float(regrets["max_regret_R"]),
                max_regret_S=float(regrets["max_regret_S"]),
                selfplay_UR=float(sp.mean_U_R),
                selfplay_US=float(sp.mean_U_S),
                selfplay_det=float(sp.det_rate),
                selfplay_goal=float(sp.goal_rate),
                seconds=seconds,
            )
        )

        log.info(f"[{solver}] Sets: |Pi_R|={len(robots)} |Pi_S|={len(sensors)} | iter_seconds={seconds:.2f}")

    if X is None:
        raise RuntimeError("Training produced no X")

    log.banner(f"[{solver}] Finished")
    log.info(f"[{solver}] Final robots: " + ", ".join([p.name for p in robots]))
    log.info(f"[{solver}] Final sensors: " + ", ".join([p.name for p in sensors]))
    log.info(f"[{solver}] Total time: {time.time()-t0_all:.2f}s")

    return TrainResult(solver=solver, env=env, robots=robots, sensors=sensors, X=X, history=history)


# =============================================================================
# Benchmarks + plotting
# =============================================================================

@dataclass
class BenchCell:
    """Aggregate outcome of one (robot method, sensor method) benchmark pairing."""

    # Mean robot utility over the benchmark episodes.
    UR: float
    # Mean sensor utility over the benchmark episodes.
    US: float
    # Fraction of episodes in which the robot was detected.
    det: float
    # Fraction of episodes in which the robot reached the goal.
    goal: float


def run_policy_matrix_benchmark(
    env: GridWorldStealthEnv,
    robot_methods: List[RobotPolicy],
    sensor_methods: List[SensorPolicy],
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> Dict[Tuple[int, int], BenchCell]:
    """Benchmark every robot method against every sensor method.

    Pair ``(i, j)`` is simulated ``episodes`` times with seed
    ``base_seed + 100000*i + 1000*j + k``.

    Returns:
        Dict mapping ``(i, j)`` to a ``BenchCell`` of mean utilities and
        detection/goal rates.
    """
    table: Dict[Tuple[int, int], BenchCell] = {}
    for i, rp in enumerate(robot_methods):
        for j, sp in enumerate(sensor_methods):
            pair_seed = base_seed + 100000 * i + 1000 * j
            runs = [
                rollout_episode(env, rp, sp, M_modes=M_modes, seed=pair_seed + k)
                for k in range(episodes)
            ]
            n_detected = sum(int(run.detected) for run in runs)
            n_goal = sum(int(run.reached_goal) for run in runs)
            table[(i, j)] = BenchCell(
                UR=float(np.mean([run.U_R for run in runs])),
                US=float(np.mean([run.U_S for run in runs])),
                det=float(n_detected / episodes),
                goal=float(n_goal / episodes),
            )
    return table


def safe_makedirs(path: str) -> None:
    """Create directory ``path`` (and missing parents); an existing directory is not an error."""
    os.makedirs(name=path, exist_ok=True)


def save_history_csv(hist: List[TrainHistoryRow], path: str) -> None:
    """Dump training-history rows to ``path`` as CSV.

    The header is the list of annotated fields of ``TrainHistoryRow``, in
    declaration order; each history row contributes one CSV line.
    """
    import csv

    columns = list(TrainHistoryRow.__annotations__.keys())
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows(
            {col: getattr(entry, col) for col in columns} for entry in hist
        )


def plot_training_curves(results: List[TrainResult], outdir: str) -> None:
    """Write four per-iteration training figures (PNG, dpi=200) into ``outdir``.

    Figures: NBS objective, max conditional regrets (robot and sensor),
    entropy of X, and self-play utilities. Each figure has one line per
    training result and history attribute.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    def _draw_series(attr: str, label_of, **line_style) -> None:
        # One line per result: outer iteration on x, the history attribute on y.
        for res in results:
            iters = [row.outer_iter for row in res.history]
            series = [getattr(row, attr) for row in res.history]
            plt.plot(iters, series, label=label_of(res), **line_style)

    def _finish(ylabel: str, filename: str) -> None:
        # Shared axis labelling, legend, layout and save for every figure.
        plt.xlabel("Outer iteration")
        plt.ylabel(ylabel)
        plt.legend()
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, filename), dpi=200)
        plt.close()

    solid = {"marker": "o"}
    dashed = {"marker": "x", "linestyle": "--"}

    # 1) NBS objective
    plt.figure()
    _draw_series("nbs_obj", lambda res: res.solver, **solid)
    _finish("NBS objective", "curve_nbs_obj.png")

    # 2) Max conditional regrets (all robot curves first, then all sensor curves)
    plt.figure()
    _draw_series("max_regret_R", lambda res: f"{res.solver}: maxRegret_R", **solid)
    _draw_series("max_regret_S", lambda res: f"{res.solver}: maxRegret_S", **dashed)
    _finish("Max conditional regret", "curve_max_regrets.png")

    # 3) Entropy of X
    plt.figure()
    _draw_series("entropy_X", lambda res: res.solver, **solid)
    _finish("Entropy H(X)", "curve_entropy_X.png")

    # 4) Self-play outcomes
    plt.figure()
    _draw_series("selfplay_UR", lambda res: f"{res.solver}: UR", **solid)
    _draw_series("selfplay_US", lambda res: f"{res.solver}: US", **dashed)
    _finish("Expected utility under X", "curve_selfplay_utils.png")


def plot_benchmark_bars(
    robot_names: List[str],
    sensor_names: List[str],
    bench: Dict[Tuple[int, int], BenchCell],
    outdir: str,
) -> None:
    """Save per-sensor bar charts of robot utility and robot goal rate.

    For every sensor name two PNGs are written to ``outdir``:
    ``bench_UR_vs_<sensor>.png`` and ``bench_goal_vs_<sensor>.png``, each with
    one bar per robot. ``bench`` is indexed by ``(robot_idx, sensor_idx)``.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    robot_idx = range(len(robot_names))

    def _bars_per_sensor(field: str, ylabel: str, title_lead: str) -> None:
        # ``field`` is both the BenchCell attribute and the filename tag.
        for j, sname in enumerate(sensor_names):
            plt.figure(figsize=(10, 4))
            heights = [getattr(bench[(i, j)], field) for i in robot_idx]
            plt.bar(robot_idx, heights)
            plt.xticks(robot_idx, robot_names, rotation=35, ha="right")
            plt.ylabel(ylabel)
            plt.title(f"{title_lead} vs sensor={sname}")
            plt.tight_layout()
            plt.savefig(os.path.join(outdir, f"bench_{field}_vs_{sname}.png"), dpi=200)
            plt.close()

    # All utility charts first, then all goal-rate charts.
    _bars_per_sensor("UR", "Robot utility", "Robot utility")
    _bars_per_sensor("goal", "Goal rate", "Robot goal rate")


# =============================================================================
# Main pipeline wrapper
# =============================================================================

def run_pipeline(args: argparse.Namespace) -> None:
    """Build the default stealth game, train the requested solver(s), export results.

    Steps: construct grid, sensors and env; run ``run_training`` once per
    selected solver; write history CSVs (and training plots when requested)
    into ``args.results_dir``; optionally benchmark the initial robot
    heuristics against a fixed suite of sensor policies.
    """
    log = Logger(args.log_level)

    grid = build_two_corridor_grid()
    mid_x = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((mid_x, 2), (mid_x, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )

    def _new_env() -> GridWorldStealthEnv:
        return GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=args.seed)

    env = _new_env()

    if log.k >= 1:
        log.banner("[PIPELINE] Stealth grid game")
        log.info(f"Grid: {grid.width}x{grid.height} start={grid.start} goal={grid.goal} obstacles={len(grid.obstacles)}")
        log.info(f"Sensors: {sensor_cfg.sensors} radius={sensor_cfg.radius} base_p={sensor_cfg.base_p} hotspot_p={sensor_cfg.hotspot_p}")
        if log.k >= 2:
            log.debug("ASCII map:")
            print_grid_ascii(grid, sensor_cfg)

    solvers: List[str] = ["marginal", "correlated"] if args.solver == "both" else [args.solver]

    # Each solver trains on its own env instance (to avoid RNG coupling).
    results: List[TrainResult] = [
        run_training(_new_env(), args, solver=name, log=log) for name in solvers
    ]

    if args.results_dir:
        safe_makedirs(args.results_dir)
        for trained in results:
            csv_path = os.path.join(args.results_dir, f"history_{trained.solver}.csv")
            save_history_csv(trained.history, csv_path)

    if args.save_plots and args.results_dir:
        plot_training_curves(results, outdir=args.results_dir)

    # Benchmarks on the same game
    if not args.run_benchmarks:
        return

    log.banner("[BENCH] Heuristics on same game")
    M_modes = len(sensor_cfg.sensors)

    # Robot heuristics: the initial policy set, without solver-added policies.
    robots_init, _sensors_init = build_initial_policies_for_bench(env, M_modes)

    # Sensor heuristics
    sensor_methods: List[SensorPolicy] = [
        FixedModeSensorPolicy(0, "S_Mode0"),
        FixedModeSensorPolicy(1, "S_Mode1"),
        AlternatingSensorPolicy(M_modes, "S_Alternate"),
        RandomModeSensorPolicy(M_modes, seed=args.seed + 123, name="S_RandomMode"),
    ]

    bench = run_policy_matrix_benchmark(
        env,
        robot_methods=robots_init,
        sensor_methods=sensor_methods,
        M_modes=M_modes,
        episodes=args.bench_episodes,
        base_seed=555000,
    )

    # Print a compact table, one section per sensor policy.
    robot_names = [policy.name for policy in robots_init]
    sensor_names = [policy.name for policy in sensor_methods]
    for col, sname in enumerate(sensor_names):
        log.info(f"[BENCH] Sensor={sname}")
        for row, rname in enumerate(robot_names):
            cell = bench[(row, col)]
            log.info(f"  Robot={rname:22s} UR={cell.UR:8.2f} goal%={100*cell.goal:5.1f} det%={100*cell.det:5.1f}")

    if args.save_plots and args.results_dir:
        plot_benchmark_bars(robot_names, sensor_names, bench, outdir=args.results_dir)


# Helper: initial robot set for benchmark (without the solver-added BR policies)

def build_initial_policies_for_bench(env: GridWorldStealthEnv, M_modes: int) -> Tuple[List[RobotPolicy], List[SensorPolicy]]:
    """Assemble the baseline robot heuristics and the fixed-mode sensor policies.

    Robot baselines: shortest path, two waypoint-forced corridor paths, two
    static risk-weighted paths (uniform-belief and worst-case risk maps), an
    online belief-replanning policy, and a random policy.

    Returns:
        ``(robots, sensors)`` where ``sensors`` holds one fixed-mode policy per mode.
    """
    grid = env.grid

    def unit_cost(a, b):
        return 1.0

    p_short = astar_path(grid, grid.start, grid.goal, unit_cost)

    wall_x = grid.width // 2
    p_upper = astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, unit_cost)
    p_lower = astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, unit_cost)

    # Static risk-aware A* under UNIFORM mode belief
    uniform_mode = np.full(M_modes, 1.0 / M_modes)

    # Risk maps: obstacle cells stay NaN, free cells get the per-mode detection
    # probabilities averaged under the uniform belief / maximised over modes.
    shape = (grid.height, grid.width)
    risk_uniform = np.full(shape, np.nan, dtype=float)
    risk_worst = np.full(shape, np.nan, dtype=float)
    for y, x in np.ndindex(*shape):
        if (x, y) in grid.obstacles:
            continue
        per_mode = [env.true_detection_prob((x, y), m) for m in range(M_modes)]
        risk_uniform[y, x] = float(np.dot(uniform_mode, per_mode))
        risk_worst[y, x] = float(np.max(per_mode))

    p_uniform = plan_risk_weighted_path(env, risk_uniform, risk_weight=12.0)
    p_worst = plan_risk_weighted_path(env, risk_worst, risk_weight=12.0)

    fixed_paths = [
        (p_short, "R_Shortest"),
        (p_upper, "R_UpperCorridor"),
        (p_lower, "R_LowerCorridor"),
        (p_uniform, "R_RiskAStar_Uniform"),
        (p_worst, "R_RiskAStar_WorstCase"),
    ]
    robots: List[RobotPolicy] = [FixedPathPolicy(path, label) for path, label in fixed_paths]
    robots.append(OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan"))
    robots.append(RandomPolicy("R_Random"))

    sensors: List[SensorPolicy] = []
    for mode in range(M_modes):
        sensors.append(FixedModeSensorPolicy(mode, name=f"S_Mode{mode}"))
    return robots, sensors


# =============================================================================
# CLI
# =============================================================================

def parse_args(argv: Optional[List[str]] = None) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the pipeline's command-line options.

    Uses ``parse_known_args`` so unrecognised arguments are returned instead of
    aborting.

    Returns:
        ``(args, unknown)``: the parsed namespace and the list of leftover tokens.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    add("--seed", type=int, default=0)
    add("--log-level", type=str, default="INFO", choices=["QUIET", "INFO", "DEBUG"])

    add("--solver", type=str, default="correlated", choices=["marginal", "correlated", "both"])

    add("--outer-iters", type=int, default=3)

    add("--rollouts-payoff", type=int, default=20)
    add("--rollouts-br", type=int, default=30)
    add("--risk-weight-br", type=float, default=12.0)

    # NBS knobs
    add("--disagreement", type=str, default="minminus", choices=["minminus", "uniform"])
    add("--entropy-tau", type=float, default=0.0)

    # Fix mismatch knobs
    add("--cond-top-k", type=int, default=2, help="How many top recommendations to check for conditional BRs")
    add("--br-eval-rollouts", type=int, default=8, help="Small evaluation rollouts for new deviations")
    add("--add-threshold", type=float, default=0.25, help="Minimum estimated conditional gain to add a deviation policy")

    # Eval episodes under joint X (self-play)
    add("--eval-episodes", type=int, default=60)

    # Debug
    add("--debug-rollout-pair", type=str, default="", help="Print one step-by-step rollout for i,j")

    # Outputs
    add("--results-dir", type=str, default="results", help="Directory for csv/plots")
    add("--save-plots", action="store_true")

    # Benchmarks
    add("--run-benchmarks", action="store_true")
    add("--bench-episodes", type=int, default=80)

    return parser.parse_known_args(args=argv)


def main(argv: Optional[List[str]] = None) -> None:
    """CLI entry point: parse arguments, warn about unknown ones, run the pipeline.

    The unknown-argument warning is suppressed when ``ipykernel`` is loaded.
    """
    args, extras = parse_args(argv=argv)
    inside_kernel = "ipykernel" in sys.modules
    if extras and not inside_kernel:
        print(f"[WARN] Ignoring unknown CLI args: {extras}")
    # Collapse a whitespace-only debug pair to the empty string.
    if not args.debug_rollout_pair.strip():
        args.debug_rollout_pair = ""
    run_pipeline(args)


# Auto-run the pipeline only for plain script execution. When "ipykernel" is
# loaded the guard is false, so nothing runs automatically and main() must be
# called explicitly.
if __name__ == "__main__" and ("ipykernel" not in sys.modules):
    main()


# =============================================================================
# EXPERIMENTS (fixed game set + baselines + tests + presentation plots)
# =============================================================================
#
# Keep the training code above exactly as-is.
# This section adds:
#   (A) sanity tests (quick asserts)
#   (B) a fixed game-set generator (deterministic)
#   (C) an experiment runner that compares solvers + baselines on the same games
#   (D) plots + CSV outputs for slides
#   (E) **3 core GAME figures** for your report: risk maps + belief trace + tradeoff scatter
#
# Notebook usage:
#   run_sanity_tests()
#   make_game_report_figures(outdir="results_game_figs")
#   run_fixed_games_experiment(outdir="results_exp", n_games=6, seeds=[0,1,2])
#

from dataclasses import asdict


@dataclass(frozen=True)
class GameInstance:
    """One immutable game of the fixed experiment set (grid layout plus sensor setup)."""
    game_id: str  # identifier, e.g. "G00_extraObs0" as produced by build_fixed_game_set
    grid: GridConfig  # map layout: size, start, goal, obstacles
    sensor_cfg: SensorConfig  # sensor positions, radius and detection probabilities
    desc: str  # short human-readable description of the instance


def _is_reachable(grid: GridConfig) -> bool:
    """Return True when unit-cost A* gets from ``grid.start`` to ``grid.goal``.

    Any exception raised by ``astar_path`` (capped at 250,000 expansions) is
    treated as "not reachable".
    """
    def _unit_cost(a, b):
        return 1.0

    try:
        astar_path(grid, grid.start, grid.goal, step_cost=_unit_cost, max_expansions=250_000)
    except Exception:
        return False
    return True


def build_fixed_game_set(
    n_games: int = 6,
    seed: int = 123,
    width: int = 15,
    height: int = 9,
    base_radius: int = 2,
    base_p: float = 0.02,
    hotspot_p: float = 0.60,
) -> List[GameInstance]:
    """Generate a reproducible list of corridor-variant games.

    Starting from the two-corridor base grid, game ``k`` receives a growing
    budget of randomly placed extra obstacles (0, 2, 4, ... capped at 10).
    A candidate layout is kept only if start->goal remains reachable
    (detection risk is ignored for that check).

    Notes:
      - The RNG is seeded with ``seed`` => same games on every run.
      - Gives up with RuntimeError after 4000 sampling attempts in total.
    """
    rng = np.random.default_rng(int(seed))

    base_grid = build_two_corridor_grid(width=width, height=height)
    wall_x = width // 2

    # Sensor centres stay next to the two corridor gaps for interpretability.
    base_sensors = ((wall_x, 2), (wall_x, 6))

    def _may_block(cell: Tuple[int, int]) -> bool:
        # Extra obstacles never land on existing obstacles, start/goal,
        # sensor cells, or anywhere in the wall column.
        if cell in base_grid.obstacles:
            return False
        if cell == base_grid.start or cell == base_grid.goal:
            return False
        if cell in base_sensors:
            return False
        return cell[0] != wall_x

    # Row-major list of eligible cells; every attempt shuffles a fresh copy so
    # the RNG stream is consumed exactly once per attempt.
    free_cells: List[Tuple[int, int]] = [
        (x, y) for y in range(height) for x in range(width) if _may_block((x, y))
    ]

    # Increasing difficulty: more extra obstacles.
    obstacle_budgets = [0, 2, 4, 6, 8, 10]

    games: List[GameInstance] = []
    attempts = 0
    while len(games) < n_games:
        attempts += 1
        if attempts > 4000:
            raise RuntimeError("Could not generate enough solvable game instances")

        game_idx = len(games)
        budget = obstacle_budgets[min(game_idx, len(obstacle_budgets) - 1)]

        shuffled = list(free_cells)
        rng.shuffle(shuffled)
        extra = set(shuffled[:budget])

        grid = GridConfig(
            width=base_grid.width,
            height=base_grid.height,
            start=base_grid.start,
            goal=base_grid.goal,
            obstacles=frozenset(set(base_grid.obstacles) | extra),
        )
        if not _is_reachable(grid):
            continue

        sensor_cfg = SensorConfig(
            sensors=base_sensors,
            radius=int(base_radius),
            base_p=float(base_p),
            hotspot_p=float(hotspot_p),
        )
        games.append(
            GameInstance(
                game_id=f"G{game_idx:02d}_extraObs{budget}",
                grid=grid,
                sensor_cfg=sensor_cfg,
                desc=f"corridors + {budget} extra obstacles",
            )
        )

    return games


# -----------------------------------------------------------------------------
# Sanity tests
# -----------------------------------------------------------------------------

def run_sanity_tests() -> None:
    """Quick assertion-based smoke checks of the core building blocks.

    Covers: A* endpoints, belief update staying on the simplex, the NBS solver
    returning a distribution, joint/marginal/conditional normalisation,
    non-negative CE regrets, and finite utilities from one rollout.
    """
    print("[TEST] Running sanity tests...")

    def _sums_to_one(vec) -> bool:
        return abs(float(vec.sum()) - 1.0) < 1e-6

    def _unit_cost(a, b):
        return 1.0

    grid = build_two_corridor_grid()
    mid_x = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((mid_x, 2), (mid_x, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=0)

    # A* reachability
    path = astar_path(grid, grid.start, grid.goal, step_cost=_unit_cost)
    assert path[0] == grid.start and path[-1] == grid.goal

    # Belief update preserves simplex
    belief = ModeBelief(M=2)
    env.reset(sensor_mode=0)
    step_out = env.step((1, 0))
    belief.update(env, step_out["alarm"], step_out["pos"])
    assert np.isfinite(belief.b).all()
    assert _sums_to_one(belief.b)
    assert (belief.b >= -1e-12).all()

    # NBS returns simplex
    uR = np.array([-1.0, -2.0, -3.0, -4.0])
    uS = np.array([1.0, 2.0, 3.0, 4.0])
    joint_vec = solve_nbs(uR, uS, log=Logger("QUIET"), max_iters=50).x
    assert np.isfinite(joint_vec).all()
    assert (joint_vec >= -1e-10).all()
    assert _sums_to_one(joint_vec)

    # Joint/marginals/conditionals are sane
    X = joint_to_matrix(joint_vec, m=2, n=2)
    sigma_R, sigma_S = marginals_from_joint(X)
    assert _sums_to_one(sigma_R)
    assert _sums_to_one(sigma_S)
    cond_S = conditional_sensor_given_robot(X, 0)
    cond_R = conditional_robot_given_sensor(X, 0)
    assert _sums_to_one(cond_S)
    assert _sums_to_one(cond_R)

    # CE regrets should be >= 0 (numerical)
    UR = np.random.default_rng(0).normal(size=(3, 2))
    US = np.random.default_rng(1).normal(size=(3, 2))
    uniform_joint = np.full((3, 2), 1.0 / 6)
    regrets = compute_ce_regrets(UR, US, uniform_joint)
    assert regrets["max_regret_R"] >= -1e-9
    assert regrets["max_regret_S"] >= -1e-9

    # Rollout returns finite stats
    robots, sensors = build_initial_policies(env, M_modes=2)
    stats = rollout_episode(env, robots[0], sensors[0], M_modes=2, seed=0)
    assert np.isfinite(stats.U_R) and np.isfinite(stats.U_S)

    print("[TEST] All sanity tests passed ✅")


# -----------------------------------------------------------------------------
# Fair comparisons on the same fixed game set
# -----------------------------------------------------------------------------

@dataclass
class ExperimentRow:
    """One long-format measurement of the fixed-games experiment (one metric value)."""
    alg: str  # display label, e.g. "Solver:<name>" or "Baseline:<policy name>"
    solver: str  # solver name, or "baseline" for heuristic robot policies
    game_id: str  # identifier of the game instance the value was measured on
    seed: int  # experiment seed of the run
    metric: str  # metric name, e.g. "robust_UR" or "runtime_s"
    value: float  # metric value


def _evaluate_robot_mixture_against_sensor(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensor: SensorPolicy,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> StrategyEval:
    sigma_R = np.asarray(sigma_R, dtype=float).reshape(-1)
    sigma_R = sigma_R / max(float(sigma_R.sum()), 1e-12)

    rng = np.random.default_rng(int(base_seed))

    UR, US = [], []
    det = 0
    goal = 0
    steps_list = []
    risk_list = []

    for k in range(episodes):
        i = int(rng.choice(len(robots), p=sigma_R))
        st = rollout_episode(env, robots[i], sensor, M_modes=M_modes, seed=int(base_seed + k))
        UR.append(st.U_R)
        US.append(st.U_S)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps_list.append(st.steps)
        risk_list.append(st.total_true_risk)

    return StrategyEval(
        mean_U_R=float(np.mean(UR)),
        mean_U_S=float(np.mean(US)),
        det_rate=float(det / episodes),
        goal_rate=float(goal / episodes),
        mean_steps=float(np.mean(steps_list)),
        mean_risk=float(np.mean(risk_list)),
    )


def _summarize_against_sensor_suite(
    env: GridWorldStealthEnv,
    robots: List[RobotPolicy],
    sigma_R: np.ndarray,
    sensors_suite: List[SensorPolicy],
    M_modes: int,
    episodes_per_sensor: int,
    base_seed: int,
) -> Dict[str, float]:
    """Evaluate a robot mixture against every sensor in a suite and aggregate.

    Sensor ``j`` is evaluated with seed offset ``base_seed + 10000 * j``.

    Returns:
        Mean and "robust" (worst case over the suite) values of robot utility,
        goal rate and detection rate. Worst case is the minimum for utility
        and goal rate, the maximum for detection rate.
    """
    evals = [
        _evaluate_robot_mixture_against_sensor(
            env,
            robots,
            sigma_R=sigma_R,
            sensor=sensor_policy,
            M_modes=M_modes,
            episodes=episodes_per_sensor,
            base_seed=base_seed + 10000 * idx,
        )
        for idx, sensor_policy in enumerate(sensors_suite)
    ]

    utilities = [ev.mean_U_R for ev in evals]
    goal_rates = [ev.goal_rate for ev in evals]
    det_rates = [ev.det_rate for ev in evals]

    summary: Dict[str, float] = {}
    summary["mean_UR"] = float(np.mean(utilities))
    summary["robust_UR"] = float(np.min(utilities))
    summary["mean_goal"] = float(np.mean(goal_rates))
    summary["robust_goal"] = float(np.min(goal_rates))
    summary["mean_det"] = float(np.mean(det_rates))
    summary["robust_det"] = float(np.max(det_rates))
    return summary


def _write_experiment_csv(rows: List[ExperimentRow], path: str) -> None:
    """Write experiment rows to ``path`` as CSV, creating the parent directory.

    Columns: alg, solver, game_id, seed, metric, value (value with 8 decimals).
    """
    import csv

    parent = os.path.dirname(path) or "."
    safe_makedirs(parent)

    header = ["alg", "solver", "game_id", "seed", "metric", "value"]
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(
            [row.alg, row.solver, row.game_id, row.seed, row.metric, f"{row.value:.8f}"]
            for row in rows
        )


def _group_stats(values: List[float]) -> Dict[str, float]:
    arr = np.asarray(values, dtype=float)
    return {
        "mean": float(np.mean(arr)) if arr.size else float("nan"),
        "std": float(np.std(arr)) if arr.size else float("nan"),
        "n": int(arr.size),
    }


def plot_experiment_summary(rows: List[ExperimentRow], outdir: str) -> None:
    """Create *paper-style* figures from ExperimentRow logs.

    Every figure is written into ``outdir`` as both PNG (dpi=300) and PDF:
      - per-game line plots (mean with 95% CI band) of the three robust metrics,
      - box plots of the same metrics pooled over games and seeds,
      - a goal-rate vs detection-rate trade-off scatter (one point per game),
      - scatters of robust utility against runtime and against robot regret.

    Args:
        rows: Long-format experiment measurements.
        outdir: Output directory; created if missing.
    """
    import matplotlib.pyplot as plt

    safe_makedirs(outdir)

    algs = sorted({r.alg for r in rows})
    games = sorted({r.game_id for r in rows})

    # (alg, game_id, metric) -> values over seeds
    bucket: Dict[Tuple[str, str, str], List[float]] = {}
    for r in rows:
        key = (r.alg, r.game_id, r.metric)
        bucket.setdefault(key, []).append(float(r.value))

    def get_vals(alg: str, game_id: str, metric: str) -> List[float]:
        return bucket.get((alg, game_id, metric), [])

    def pooled_vals(alg: str, metric: str) -> List[float]:
        # All values of one metric for one algorithm, in row order.
        return [float(r.value) for r in rows if r.alg == alg and r.metric == metric]

    def mean_ci95(vals: List[float]) -> Tuple[float, float]:
        # Normal-approximation half-width; 0 for a single sample, NaN when empty.
        arr = np.asarray(vals, dtype=float)
        if arr.size == 0:
            return float("nan"), float("nan")
        mu = float(np.mean(arr))
        if arr.size == 1:
            return mu, 0.0
        se = float(np.std(arr, ddof=1) / np.sqrt(arr.size))
        return mu, 1.96 * se

    def savefig(base: str) -> None:
        plt.tight_layout()
        plt.savefig(os.path.join(outdir, base + ".png"), dpi=300)
        plt.savefig(os.path.join(outdir, base + ".pdf"))
        plt.close()

    def plot_by_game(metric: str, ylabel: str, base: str) -> None:
        plt.figure(figsize=(9.0, 4.2))
        x = np.arange(len(games))
        for alg in algs:
            ys = []
            es = []
            for gid in games:
                mu, ci = mean_ci95(get_vals(alg, gid, metric))
                ys.append(mu)
                es.append(ci)
            ys = np.asarray(ys, dtype=float)
            es = np.asarray(es, dtype=float)
            plt.plot(x, ys, marker="o", linewidth=2.0, label=alg)
            ok = np.isfinite(ys) & np.isfinite(es)
            if np.any(ok):
                plt.fill_between(x[ok], (ys - es)[ok], (ys + es)[ok], alpha=0.15)
        plt.xticks(x, games, rotation=25, ha="right")
        plt.xlabel("Game instance (increasing obstacle perturbations)")
        plt.ylabel(ylabel)
        plt.grid(True, alpha=0.25)
        plt.legend(ncol=2, fontsize=9)
        savefig(base)

    plot_by_game("robust_UR", "Robust robot utility (worst-case over sensor suite) ↑", "fig_robustUR_by_game")
    plot_by_game("robust_goal", "Robust goal rate (worst-case over sensor suite) ↑", "fig_robust_goal_by_game")
    plot_by_game("robust_det", "Worst-case detection rate over sensor suite ↓", "fig_robust_det_by_game")

    def plot_box(metric: str, ylabel: str, base: str) -> None:
        plt.figure(figsize=(9.0, 4.2))
        data = [pooled_vals(alg, metric) for alg in algs]
        # The ``labels=`` keyword of boxplot is deprecated since Matplotlib 3.9
        # (renamed ``tick_labels``). Labelling the default box positions 1..N
        # through xticks works on every version.
        plt.boxplot(data, showfliers=False)
        plt.xticks(np.arange(1, len(algs) + 1), algs, rotation=25, ha="right")
        plt.ylabel(ylabel)
        plt.grid(True, axis="y", alpha=0.25)
        savefig(base)

    plot_box("robust_UR", "Robust robot utility (worst-case over sensor suite) ↑", "fig_box_robustUR")
    plot_box("robust_goal", "Robust goal rate (worst-case over sensor suite) ↑", "fig_box_goal")
    plot_box("robust_det", "Worst-case detection rate over sensor suite ↓", "fig_box_det")

    # Trade-off scatter: one point per (algorithm, game), averaged over seeds.
    plt.figure(figsize=(6.0, 5.2))
    for alg in algs:
        xs = []
        ys = []
        for gid in games:
            mu_det, _ = mean_ci95(get_vals(alg, gid, "robust_det"))
            mu_goal, _ = mean_ci95(get_vals(alg, gid, "robust_goal"))
            if np.isfinite(mu_det) and np.isfinite(mu_goal):
                xs.append(mu_det)
                ys.append(mu_goal)
        if xs and ys:
            plt.scatter(xs, ys, label=alg)
    plt.xlabel("Worst-case detection rate (lower is better)")
    plt.ylabel("Worst-case goal rate (higher is better)")
    plt.grid(True, alpha=0.25)
    plt.legend(fontsize=9)
    savefig("fig_goal_det_tradeoff")

    def plot_metric_vs_robust_ur(x_metric: str, xlabel: str, base: str) -> None:
        # Values are paired positionally in row order; an algorithm is skipped
        # when it lacks either metric or the two counts differ.
        plt.figure(figsize=(6.0, 5.2))
        for alg in algs:
            xs = pooled_vals(alg, x_metric)
            ys = pooled_vals(alg, "robust_UR")
            if xs and ys and len(xs) == len(ys):
                plt.scatter(xs, ys, label=alg)
        plt.xlabel(xlabel)
        plt.ylabel("Robust robot utility ↑")
        plt.grid(True, alpha=0.25)
        plt.legend(fontsize=9)
        savefig(base)

    plot_metric_vs_robust_ur("runtime_s", "Runtime (seconds) ↓", "fig_runtime_vs_robustUR")
    plot_metric_vs_robust_ur("max_regret_R", "Max conditional regret (robot) ↓", "fig_stability_regret_vs_robustUR")


def run_fixed_games_experiment(
    outdir: str = "results_exp",
    n_games: int = 6,
    seeds: Optional[List[int]] = None,
    outer_iters: int = 3,
    rollouts_payoff: int = 20,
    rollouts_br: int = 30,
    eval_episodes: int = 120,
    episodes_per_sensor: int = 80,
    disagreement: str = "minminus",
    entropy_tau: float = 0.0,
    cond_top_k: int = 2,
    br_eval_rollouts: int = 8,
    add_threshold: float = 0.25,
    risk_weight_br: float = 12.0,
    include_solvers: Optional[List[str]] = None,
    include_baselines: bool = True,
    log_level: str = "QUIET",
) -> None:
    """Run a clean comparison on a deterministic game set.

    For every (game, seed) each requested solver is trained and its robot
    marginal is evaluated against a fixed sensor suite; optionally every
    baseline robot policy is evaluated the same way. All measurements go to
    ``<outdir>/experiments.csv`` and the summary figures are written to ``outdir``.

    Each evaluated algorithm gets its own environment instance and its own
    sensor-suite instances, mirroring ``run_pipeline``'s fresh env per solver
    (to avoid RNG coupling). This keeps one algorithm's results independent of
    which others ran before it.

    Args:
        outdir: Output directory for the CSV and plots.
        n_games: Number of game instances to generate.
        seeds: Experiment seeds; defaults to ``[0, 1, 2]``.
        outer_iters, rollouts_payoff, rollouts_br, eval_episodes, disagreement,
        entropy_tau, cond_top_k, br_eval_rollouts, add_threshold, risk_weight_br:
            Training options forwarded to ``run_training`` via a Namespace.
        episodes_per_sensor: Evaluation rollouts per sensor in the suite.
        include_solvers: Solver names; defaults to ``["correlated", "marginal"]``.
        include_baselines: Whether to evaluate the heuristic robot policies.
        log_level: Logger level used during training.
    """

    if seeds is None:
        seeds = [0, 1, 2]
    if include_solvers is None:
        include_solvers = ["correlated", "marginal"]

    safe_makedirs(outdir)

    games = build_fixed_game_set(n_games=n_games)

    rows: List[ExperimentRow] = []

    def make_env(game: GameInstance, seed0: int) -> GridWorldStealthEnv:
        return GridWorldStealthEnv(game.grid, game.sensor_cfg, fp=0.05, fn=0.10, seed=seed0)

    def make_sensor_suite(M_modes: int, seed0: int) -> List[SensorPolicy]:
        return [
            FixedModeSensorPolicy(0, "S_Mode0"),
            FixedModeSensorPolicy(1, "S_Mode1"),
            AlternatingSensorPolicy(M_modes, "S_Alternate"),
            RandomModeSensorPolicy(M_modes, seed=seed0 + 999, name="S_RandomMode"),
        ]

    def record(alg: str, solver: str, game_id: str, seed0: int, metric: str, value: float) -> None:
        rows.append(ExperimentRow(alg=alg, solver=solver, game_id=game_id, seed=seed0, metric=metric, value=float(value)))

    for g in games:
        M_modes = len(g.sensor_cfg.sensors)
        for s in seeds:
            for solver in include_solvers:
                # Fresh env per solver so training/evaluation of one solver
                # cannot influence the next one through shared env state.
                env = make_env(g, s)

                args = argparse.Namespace(
                    seed=s,
                    log_level=log_level,
                    solver=solver,
                    outer_iters=int(outer_iters),
                    rollouts_payoff=int(rollouts_payoff),
                    rollouts_br=int(rollouts_br),
                    risk_weight_br=float(risk_weight_br),
                    disagreement=str(disagreement),
                    entropy_tau=float(entropy_tau),
                    cond_top_k=int(cond_top_k),
                    br_eval_rollouts=int(br_eval_rollouts),
                    add_threshold=float(add_threshold),
                    eval_episodes=int(eval_episodes),
                    debug_rollout_pair="",
                    results_dir="",
                    save_plots=False,
                    run_benchmarks=False,
                    bench_episodes=0,
                )

                t0 = time.time()
                log = Logger(log_level)
                tr = run_training(env, args, solver=solver, log=log)
                runtime_s = float(time.time() - t0)

                sigma_R, _sigma_S = marginals_from_joint(tr.X)
                suite_summary = _summarize_against_sensor_suite(
                    env,
                    robots=tr.robots,
                    sigma_R=sigma_R,
                    # Fresh suite: the seeded random-mode sensor may keep RNG
                    # state across evaluations (not verifiable from here).
                    sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                    M_modes=M_modes,
                    episodes_per_sensor=int(episodes_per_sensor),
                    base_seed=100_000 + 1000 * s,
                )

                last = tr.history[-1]

                alg = f"Solver:{solver}"
                for k, v in suite_summary.items():
                    record(alg, solver, g.game_id, s, k, v)

                record(alg, solver, g.game_id, s, "runtime_s", runtime_s)
                record(alg, solver, g.game_id, s, "entropy_X", last.entropy_X)
                record(alg, solver, g.game_id, s, "max_regret_R", last.max_regret_R)
                record(alg, solver, g.game_id, s, "max_regret_S", last.max_regret_S)
                record(alg, solver, g.game_id, s, "selfplay_UR", last.selfplay_UR)
                record(alg, solver, g.game_id, s, "selfplay_US", last.selfplay_US)

            if include_baselines:
                # Baselines get their own env as well; the policies built here
                # may hold a reference to it, so it is also used to evaluate them.
                env_base = make_env(g, s)
                robots_base, _ = build_initial_policies_for_bench(env_base, M_modes=M_modes)

                for idx, rp in enumerate(robots_base):
                    # One-hot mixture selecting exactly this baseline policy.
                    sigma = np.zeros(len(robots_base), dtype=float)
                    sigma[idx] = 1.0

                    suite_summary = _summarize_against_sensor_suite(
                        env_base,
                        robots=robots_base,
                        sigma_R=sigma,
                        sensors_suite=make_sensor_suite(M_modes=M_modes, seed0=s),
                        M_modes=M_modes,
                        episodes_per_sensor=int(episodes_per_sensor),
                        base_seed=200_000 + 1000 * s,
                    )

                    alg = f"Baseline:{rp.name}"
                    for k, v in suite_summary.items():
                        record(alg, "baseline", g.game_id, s, k, v)

    csv_path = os.path.join(outdir, "experiments.csv")
    _write_experiment_csv(rows, csv_path)
    plot_experiment_summary(rows, outdir=outdir)

    print(f"[EXP] Done. Wrote: {csv_path}")
    print(f"[EXP] Plots saved to: {outdir}")


# -----------------------------------------------------------------------------
# (E) 3 CORE GAME FIGURES FOR YOUR REPORT
# -----------------------------------------------------------------------------
# These are NOT "training" plots. They are about the GAME itself:
#   1) How the map + sensor modes create different risk landscapes and chokepoints
#   2) How partial observability drives BELIEF updates during a rollout
#   3) The fundamental tradeoff (goal vs detection vs time) across robot behaviors

@dataclass
class Trace:
    """Step-by-step record of a single rollout, as produced by rollout_trace()."""
    pos: List[Tuple[int, int]]  # robot positions, starting with the position before the first step
    alarm: List[int]  # alarm observation per step
    p_true: List[float]  # "p_true" value reported by env.step per step
    b0: List[float]  # belief in mode 0 (initial belief first, then after each update)
    b1: List[float]  # belief in mode 1 (0.0 when there is only one mode)
    detected: bool  # env.detected at the end of the rollout
    reached_goal: bool  # whether the final position equals the grid goal


def rollout_trace(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    seed: int,
    max_steps: int = 200,
) -> Trace:
    """Run one episode (like rollout_episode()) and keep its full history.

    Args:
        env: Environment to roll out in; it is re-seeded and reset here.
        robot: Robot policy queried for an action at every step.
        sensor: Sensor policy queried for the active mode at every step.
        M_modes: Number of sensor modes used to size the belief.
        seed: Seed passed to ``env.seed``.
        max_steps: Upper bound on the number of environment steps.

    Returns:
        A Trace holding positions, alarms, true risk values, the first two
        belief components over time, and the final detected / goal flags.
    """
    # Initialisation order: seed -> sensor -> environment -> belief -> robot.
    env.seed(seed)
    sensor.reset()
    env.reset(sensor_mode=sensor.select_mode(0))
    belief = ModeBelief(M_modes)
    robot.reset(env.pos)

    def _belief_snapshot():
        """Return (b[0], b[1]) with fallbacks when fewer than two modes exist."""
        first = float(belief.b[0]) if M_modes > 0 else 1.0
        second = float(belief.b[1]) if M_modes > 1 else 0.0
        return first, second

    visited = [env.pos]
    start_b0, start_b1 = _belief_snapshot()
    belief0 = [start_b0]
    belief1 = [start_b1]
    alarms = []
    risks = []

    # The robot sees no alarm before its first action.
    prev_alarm: Optional[int] = None

    for _step_index in range(max_steps):
        # The sensor may switch mode at every time step.
        env.sensor_mode = sensor.select_mode(env.t)
        action = robot.act(env, belief, prev_alarm)
        step = env.step(action)

        # Bayes update uses the raw alarm and the position reached by this step.
        belief.update(env, step["alarm"], step["pos"])
        prev_alarm = int(step["alarm"])

        visited.append(step["pos"])
        alarms.append(prev_alarm)
        risks.append(float(step["p_true"]))
        now_b0, now_b1 = _belief_snapshot()
        belief0.append(now_b0)
        belief1.append(now_b1)

        if step["done"]:
            break

    at_goal = (env.pos == env.grid.goal)
    was_detected = bool(env.detected)
    return Trace(
        pos=visited,
        alarm=alarms,
        p_true=risks,
        b0=belief0,
        b1=belief1,
        detected=was_detected,
        reached_goal=at_goal,
    )


def compute_mode_risk_map(env: GridWorldStealthEnv, mode: int) -> np.ndarray:
    """Tabulate the true detection probability of every free cell for one mode.

    Args:
        env: Environment providing ``grid`` (height, width, obstacles) and
            ``true_detection_prob((x, y), mode)``.
        mode: Sensor mode to evaluate.

    Returns:
        Float array of shape (height, width) indexed as ``[y, x]``; obstacle
        cells are left as NaN.
    """
    n_rows, n_cols = env.grid.height, env.grid.width
    blocked = env.grid.obstacles

    # Start from all-NaN so obstacle cells never receive a value.
    risk = np.full((n_rows, n_cols), np.nan, dtype=float)
    for row in range(n_rows):
        for col in range(n_cols):
            cell = (col, row)  # grid cells are addressed as (x, y)
            if cell not in blocked:
                risk[row, col] = float(env.true_detection_prob(cell, mode))
    return risk


def _plot_grid_overlay(ax, env: GridWorldStealthEnv) -> None:
    """Draw the static map elements on ``ax``: obstacles, start, goal, sensors.

    Also sets the axis limits (y axis inverted so row 0 is at the top), an
    equal aspect ratio, one tick per grid cell and a faint grid.

    Args:
        ax: Matplotlib axes to draw on.
        env: Environment providing ``grid`` (obstacles, start, goal, width,
            height) and ``sensor_cfg.sensors``.
    """
    # Import the submodule explicitly. The previous
    # ``__import__("matplotlib").patches`` returned the top-level package and
    # only worked if some other code had already imported matplotlib.patches.
    from matplotlib.patches import Rectangle

    # obstacles as black unit squares centred on their cell coordinates
    for (x, y) in env.grid.obstacles:
        ax.add_patch(Rectangle((x - 0.5, y - 0.5), 1, 1, fill=True, color="black", alpha=0.95))

    # start/goal/sensors
    sx, sy = env.grid.start
    gx, gy = env.grid.goal
    ax.scatter([sx], [sy], marker="s", s=80, edgecolor="k")
    ax.text(sx + 0.2, sy + 0.2, "Start", fontsize=9)

    ax.scatter([gx], [gy], marker="*", s=140, edgecolor="k")
    ax.text(gx + 0.2, gy + 0.2, "Goal", fontsize=9)

    for k, (cx, cy) in enumerate(env.sensor_cfg.sensors):
        ax.scatter([cx], [cy], marker="^", s=110, edgecolor="k")
        ax.text(cx + 0.2, cy + 0.2, f"S{k}", fontsize=9)

    ax.set_xlim(-0.5, env.grid.width - 0.5)
    ax.set_ylim(env.grid.height - 0.5, -0.5)  # origin at top
    ax.set_aspect("equal")
    ax.set_xticks(range(env.grid.width))
    ax.set_yticks(range(env.grid.height))
    ax.grid(True, alpha=0.2)


def _plot_path(ax, path: List[Tuple[int, int]], label: str) -> None:
    xs = [p[0] for p in path]
    ys = [p[1] for p in path]
    ax.plot(xs, ys, linewidth=2.2, label=label)


@dataclass
class PairEval:
    """Aggregate outcome statistics for one (robot, sensor) pair, as computed by eval_pair_basic()."""
    goal_rate: float  # fraction of episodes whose stats had reached_goal set
    det_rate: float  # fraction of episodes whose stats had detected set
    mean_steps: float  # mean of the per-episode step counts


def eval_pair_basic(
    env: GridWorldStealthEnv,
    robot: RobotPolicy,
    sensor: SensorPolicy,
    M_modes: int,
    episodes: int,
    base_seed: int,
) -> PairEval:
    """Estimate goal rate, detection rate and mean episode length for one pair.

    Runs ``episodes`` rollouts via rollout_episode(), seeding episode ``k``
    with ``base_seed + k``.

    Args:
        env: Environment used for every rollout.
        robot: Robot policy under evaluation.
        sensor: Sensor policy the robot is evaluated against.
        M_modes: Number of sensor modes, forwarded to rollout_episode().
        episodes: Number of rollouts to average over; must be >= 1.
        base_seed: Seed of the first rollout.

    Returns:
        PairEval with the fraction of episodes that reached the goal, the
        fraction that were detected, and the mean number of steps.

    Raises:
        ValueError: If ``episodes`` is not a positive integer. Previously 0
            raised an opaque ZeroDivisionError and negative values silently
            produced -0.0 rates and a NaN mean.
    """
    n_episodes = int(episodes)
    if n_episodes <= 0:
        raise ValueError(f"episodes must be a positive integer, got {episodes!r}")

    det = 0
    goal = 0
    steps = []
    for k in range(n_episodes):
        st = rollout_episode(env, robot, sensor, M_modes=M_modes, seed=base_seed + k)
        det += int(st.detected)
        goal += int(st.reached_goal)
        steps.append(int(st.steps))
    return PairEval(
        goal_rate=float(goal / n_episodes),
        det_rate=float(det / n_episodes),
        mean_steps=float(np.mean(steps)),
    )


def make_game_report_figures(
    outdir: str = "results_game_figs",
    seed: int = 0,
    episodes_per_point: int = 200,
) -> None:
    """Render the three report figures that illustrate the GAME mechanics.

    Each figure is written to ``outdir`` as both PNG (300 dpi) and PDF.

    Figures:
      Fig1: Mode0/Mode1 risk maps with the shortest / upper-gap / lower-gap paths
      Fig2: belief trace, alarms and true risk along one rollout against Mode 1
      Fig3: detection-rate vs goal-rate scatter of the baseline robots per sensor mode

    Args:
        outdir: Output directory (created through safe_makedirs).
        seed: Seed for the environment and for the Fig2 rollout.
        episodes_per_point: Rollouts per (robot, sensor) point in Fig3.
    """
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    safe_makedirs(outdir)

    def _save_and_close(figure, png_name, pdf_name):
        """Write ``figure`` as PNG and PDF into ``outdir``, then release it."""
        figure.savefig(os.path.join(outdir, png_name), dpi=300)
        figure.savefig(os.path.join(outdir, pdf_name))
        plt.close(figure)

    def _unit_cost(a, b):
        """Uniform step cost, so A* returns plain shortest paths."""
        return 1.0

    # ------------------------------
    # Canonical game: two-corridor grid, one sensor next to each gap
    # ------------------------------
    grid = build_two_corridor_grid()
    wall_x = grid.width // 2
    sensor_cfg = SensorConfig(
        sensors=((wall_x, 2), (wall_x, 6)),
        radius=2,
        base_p=0.02,
        hotspot_p=0.60,
    )
    env = GridWorldStealthEnv(grid, sensor_cfg, fp=0.05, fn=0.10, seed=seed)
    M_modes = len(sensor_cfg.sensors)

    # Reference paths drawn on both risk maps (label, path), in drawing order.
    labelled_paths = [
        ("Shortest", astar_path(grid, grid.start, grid.goal, _unit_cost)),
        ("Upper gap", astar_via_waypoint(grid, grid.start, (wall_x, 2), grid.goal, _unit_cost)),
        ("Lower gap", astar_via_waypoint(grid, grid.start, (wall_x, 6), grid.goal, _unit_cost)),
    ]

    # ------------------------------
    # FIG 1: risk landscapes + chokepoint paths
    # ------------------------------
    risk_maps = [compute_mode_risk_map(env, mode=0), compute_mode_risk_map(env, mode=1)]
    panel_titles = ["Mode 0 risk map (hotspot at S0)", "Mode 1 risk map (hotspot at S1)"]

    # Shared colour scale so both panels are directly comparable.
    vmin = float(np.nanmin(risk_maps))
    vmax = float(np.nanmax(risk_maps))

    fig, axes = plt.subplots(1, 2, figsize=(10.8, 4.6), constrained_layout=True)

    heatmap = None
    for panel, risk_map, title in zip(axes, risk_maps, panel_titles):
        heatmap = panel.imshow(risk_map, vmin=vmin, vmax=vmax)
        _plot_grid_overlay(panel, env)
        for path_label, path_cells in labelled_paths:
            _plot_path(panel, path_cells, path_label)
        panel.set_title(title)
        panel.legend(fontsize=8, loc="lower left")

    # The colourbar is tied to the last (Mode 1) image.
    cbar = fig.colorbar(heatmap, ax=axes.ravel().tolist(), shrink=0.95)
    cbar.set_label("True detection probability p_true(x)")

    _save_and_close(fig, "fig1_riskmaps_paths.png", "fig1_riskmaps_paths.pdf")

    # ------------------------------
    # FIG 2: belief evolution on a rollout (partial observability)
    # ------------------------------
    replan_robot = OnlineBeliefReplanPolicy(env, risk_weight=12.0, name="R_OnlineBeliefReplan")
    mode1_sensor = FixedModeSensorPolicy(1, name="S_Mode1")

    tr = rollout_trace(env, replan_robot, mode1_sensor, M_modes=M_modes, seed=seed)

    n_steps = len(tr.pos) - 1
    state_times = np.arange(n_steps + 1)     # beliefs include the initial state
    step_times = np.arange(1, n_steps + 1)   # alarms / p_true exist per step only

    fig, axes = plt.subplots(2, 1, figsize=(10.0, 6.2), sharex=True, constrained_layout=True)
    belief_ax = axes[0]
    signal_ax = axes[1]

    # (a) belief over modes
    belief_ax.plot(state_times, tr.b0, linewidth=2.2, label="Belief P(mode=0)")
    belief_ax.plot(state_times, tr.b1, linewidth=2.2, label="Belief P(mode=1)")
    belief_ax.set_ylabel("Belief")
    belief_ax.set_ylim(-0.02, 1.02)
    belief_ax.yaxis.set_major_locator(MaxNLocator(6))
    belief_ax.grid(True, alpha=0.25)
    belief_ax.legend(loc="best")
    if tr.detected:
        title_end = "DETECTED"
    elif tr.reached_goal:
        title_end = "GOAL"
    else:
        title_end = "TIMEOUT"
    belief_ax.set_title(f"Belief update under noisy alarms (true mode=1) — episode ends: {title_end}")

    # (b) alarm + true risk
    signal_ax.plot(step_times, tr.p_true, linewidth=2.0, label="True p_true at visited cell")
    # alarm events become faint vertical markers at their (1-based) step index
    alarm_ts = []
    for step_no, alarm_bit in enumerate(tr.alarm, start=1):
        if alarm_bit == 1:
            alarm_ts.append(step_no)
    if alarm_ts:
        signal_ax.vlines(
            alarm_ts,
            ymin=min(tr.p_true + [0.0]),
            ymax=max(tr.p_true + [1.0]),
            alpha=0.25,
            label="Alarm=1",
        )
    signal_ax.set_xlabel("Time step")
    signal_ax.set_ylabel("Signal")
    signal_ax.grid(True, alpha=0.25)
    signal_ax.legend(loc="best")

    _save_and_close(fig, "fig2_belief_trace.png", "fig2_belief_trace.pdf")

    # ------------------------------
    # FIG 3: tradeoff scatter: goal vs detection (and time)
    # ------------------------------
    robots_base, _ = build_initial_policies_for_bench(env, M_modes=M_modes)
    sensors_eval = [FixedModeSensorPolicy(0, "S_Mode0"), FixedModeSensorPolicy(1, "S_Mode1")]

    fig, axes = plt.subplots(1, 2, figsize=(11.0, 4.8), sharey=True, constrained_layout=True)

    for j, sp in enumerate(sensors_eval):
        panel = axes[j]

        # One (evaluation, robot name) entry per baseline robot, in order.
        points = []
        for i, robot_policy in enumerate(robots_base):
            ev = eval_pair_basic(
                env,
                robot=robot_policy,
                sensor=sp,
                M_modes=M_modes,
                episodes=int(episodes_per_point),
                base_seed=50_000 + 10_000 * j + 1000 * i,
            )
            points.append((ev, robot_policy.name))

        xs = [ev.det_rate for ev, _name in points]
        ys = [ev.goal_rate for ev, _name in points]
        # marker size encodes mean steps (bigger = slower)
        ss = [25.0 + 2.0 * ev.mean_steps for ev, _name in points]

        panel.scatter(xs, ys, s=ss, alpha=0.9)
        for (ev, name) in points:
            panel.annotate(name, (ev.det_rate, ev.goal_rate), textcoords="offset points", xytext=(6, 4), fontsize=8)

        panel.set_title(f"Tradeoff vs {sp.name} (true mode fixed)")
        panel.set_xlabel("Detection rate (lower is better)")
        panel.grid(True, alpha=0.25)

    axes[0].set_ylabel("Goal rate (higher is better)")

    _save_and_close(fig, "fig3_tradeoff_scatter.png", "fig3_tradeoff_scatter.pdf")

    print(f"[GAME-FIGS] Saved 3 report figures to: {outdir}")
    print("[GAME-FIGS] Files: fig1_riskmaps_paths.*, fig2_belief_trace.*, fig3_tradeoff_scatter.*")


In [34]:
# Driver cell: run the pipeline entry point with an empty argument list
# (presumably an empty argv, i.e. default settings; confirm against main()),
# then render the three game-mechanics figures into results_game_figs/.
main([]) 
make_game_report_figures(outdir="results_game_figs", seed=0, episodes_per_point=200)


[PIPELINE] Stealth grid game
Grid: 15x9 start=(1, 7) goal=(13, 1) obstacles=7
Sensors: ((7, 2), (7, 6)) radius=2 base_p=0.02 hotspot_p=0.6
[correlated] Initial robots: R_Shortest, R_UpperCorridor, R_LowerCorridor, R_Random, R_OnlineBeliefReplan
[correlated] Initial sensors: S_Mode0, S_Mode1
[correlated] Outer iter 1/3
[Eval] Estimating payoffs: m=5, n=2, rollouts=20, base_seed=1100
[Eval] Compact payoff summary:
  (R0:R_Shortest, S0:S_Mode0) UR= -60.414±2.50 | US=  49.254±0.41 | det%=100.0 goal%=  0.0 steps=9.3 risk=1.11
  (R0:R_Shortest, S1:S_Mode1) UR= -41.321±21.09 | US=  25.061±25.61 | det%= 55.0 goal%= 45.0 steps=13.6 risk=0.27
  (R1:R_UpperCorridor, S0:S_Mode0) UR= -60.196±1.83 | US=  49.156±0.29 | det%=100.0 goal%=  0.0 steps=9.2 risk=1.00
  (R1:R_UpperCorridor, S1:S_Mode1) UR= -27.545±15.94 | US=   9.845±22.66 | det%= 25.0 goal%= 75.0 steps=14.8 risk=0.30
  (R2:R_LowerCorridor, S0:S_Mode0) UR= -61.704±3.33 | US=  48.924±0.38 | det%=100.0 goal%=  0.0 steps=10.7 risk=1.05
  (R2:R