# Написанные человеком и проверенные эксперименты

## Вспомогательные функции

In [None]:
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import math
import numpy as np
import torch
from PIL import Image

import matplotlib.pyplot as plt


# -------------------------
# Image conversions
# -------------------------

def pil_to_np_bgr(pil: Image.Image) -> np.ndarray:
    """Convert a PIL image to a BGR array shaped (H, W, 3).

    Images that are not already in RGB mode are converted to RGB first.
    The channel axis is then reversed (RGB -> BGR) and the result is copied,
    so the returned array owns its memory. The original docstring states the
    array is meant to be fed to the Ultralytics predictor.
    """
    rgb_img = pil if pil.mode == "RGB" else pil.convert("RGB")
    # Reverse the last (channel) axis; copy() materialises the flipped view.
    return np.flip(np.asarray(rgb_img), axis=-1).copy()


def pil_to_torch_rgb01(pil: Image.Image) -> torch.Tensor:
    """Convert a PIL image to a float32 tensor shaped [3, H, W] scaled by 1/255.

    Non-RGB images are converted to RGB first. The tensor is created from a
    NumPy array, so it lives on the CPU.
    """
    rgb_img = pil if pil.mode == "RGB" else pil.convert("RGB")
    hwc = np.asarray(rgb_img).astype(np.float32) / 255.0
    # HWC -> CHW; contiguous() gives the permuted view its own dense layout.
    chw = torch.from_numpy(hwc).permute(2, 0, 1)
    return chw.contiguous()


# -------------------------
# Patch application
# -------------------------

def apply_patch_to_image(
    base_pil: Image.Image,
    patch_pil: Image.Image,
    position_xy: Tuple[int, int] = (0, 0),
) -> Tuple[Image.Image, Optional[torch.Tensor], int]:
    """Paste ``patch_pil`` onto a copy of ``base_pil`` at ``position_xy``.

    Both images are converted to RGB when needed. The patch is clipped to the
    bounds of the base image, so the offset may be negative or may push the
    patch partly outside the image.

    Returns:
        ``(patched_image, bbox_xyxy, area_px)`` where ``bbox_xyxy`` is a float32
        tensor ``[x1, y1, x2, y2]`` of the visible patch region in base-image
        pixels and ``area_px`` is its pixel area. When nothing of the patch is
        visible the result is ``(copy_of_base, None, 0)``.
    """
    base_rgb = base_pil if base_pil.mode == "RGB" else base_pil.convert("RGB")
    patch_rgb = patch_pil if patch_pil.mode == "RGB" else patch_pil.convert("RGB")

    base_w, base_h = base_rgb.size
    patch_w, patch_h = patch_rgb.size
    raw_x, raw_y = position_xy
    off_x, off_y = int(raw_x), int(raw_y)

    # Intersection of the patch rectangle with the base image.
    left = max(0, off_x)
    top = max(0, off_y)
    right = min(base_w, off_x + patch_w)
    bottom = min(base_h, off_y + patch_h)
    if right <= left or bottom <= top:
        return base_rgb.copy(), None, 0

    visible_w = right - left
    visible_h = bottom - top

    # Same region expressed in patch coordinates (drops the part that spills out).
    crop_left = left - off_x
    crop_top = top - off_y
    visible_patch = patch_rgb.crop(
        (crop_left, crop_top, crop_left + visible_w, crop_top + visible_h)
    )

    canvas = base_rgb.copy()
    canvas.paste(visible_patch, (left, top))

    bbox = torch.tensor(
        [float(v) for v in (left, top, right, bottom)],
        dtype=torch.float32,
    )
    return canvas, bbox, int(visible_w * visible_h)


def letterbox_pil(pil: Image.Image, imgsz: int, color: Tuple[int, int, int] = (114, 114, 114)) -> Image.Image:
    """Resize an image to fit an ``imgsz`` x ``imgsz`` square and pad the rest.

    The aspect ratio is preserved: the image is scaled (bilinear) so that its
    longer side fits ``imgsz``, then pasted, centred, onto a square canvas filled
    with ``color``. The original docstring says this is intended to match the
    Ultralytics letterbox; NOTE(review): exact pixel parity with the
    Ultralytics implementation is not verified here.
    """
    rgb_img = pil if pil.mode == "RGB" else pil.convert("RGB")
    src_w, src_h = rgb_img.size

    # One scale factor for both axes, limited by the tighter dimension.
    scale = min(imgsz / src_h, imgsz / src_w)
    dst_w = int(round(src_w * scale))
    dst_h = int(round(src_h * scale))

    canvas = Image.new("RGB", (imgsz, imgsz), color)
    # Truncating the half-difference puts any odd leftover pixel on the right/bottom.
    offset = (int((imgsz - dst_w) / 2), int((imgsz - dst_h) / 2))
    canvas.paste(rgb_img.resize((dst_w, dst_h), Image.BILINEAR), offset)
    return canvas

from pathlib import Path
import random
from typing import List


def collect_images(
    folder: Path,
    n: int,
    extensions: tuple[str, ...] = (".png", ".jpg", ".jpeg"),
    shuffle: bool = False,
    seed: int | None = 42,
) -> List[str]:
    """Return up to ``n`` image paths (as strings) found directly in ``folder``.

    Files are matched by lower-cased suffix against ``extensions`` and are
    always sorted by path before selection. ``Path.iterdir`` yields entries in
    arbitrary, filesystem-dependent order, so without sorting neither the
    plain prefix nor the seeded shuffle would be reproducible across machines.

    Args:
        folder: Directory to scan (non-recursive). A ``str`` is accepted too.
        n: Maximum number of paths to return.
        extensions: Allowed file suffixes, compared against the lower-cased suffix.
        shuffle: If True, shuffle the sorted list before taking the first ``n``.
        seed: Seed for the shuffle. ``None`` falls back to the global ``random``
            state, which makes the result non-deterministic.

    Raises:
        RuntimeError: If the folder contains no matching files.
    """
    folder = Path(folder)
    files = sorted(
        p for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in extensions
    )

    if not files:
        raise RuntimeError(f"В папке {folder} нет изображений")

    if shuffle:
        # A private generator keeps the selection reproducible without
        # re-seeding (and thereby disturbing) the process-wide RNG.
        rng = random.Random(seed) if seed is not None else random
        rng.shuffle(files)

    return [str(p) for p in files[:n]]

In [None]:
def _load_ultralytics_yolo(model_path: str):
    """Import Ultralytics lazily and build a ``YOLO`` object for ``model_path``.

    Raises:
        RuntimeError: If importing ``ultralytics`` fails for any reason; the
            original exception is chained as the cause.
    """
    try:
        from ultralytics import YOLO as yolo_cls
    except Exception as import_error:
        message = (
            "Не удалось импортировать ultralytics. Установите пакет ultralytics или убедитесь, что среда активна."
        )
        raise RuntimeError(message) from import_error
    else:
        return yolo_cls(model_path)


def _get_class_id_from_names(names: Dict[int, str], class_name: str) -> Optional[int]:
    """Look up a class id by name, ignoring case and surrounding whitespace.

    Returns the key (as ``int``) of the first entry in ``names`` whose value
    matches ``class_name``, or ``None`` when there is no match.
    """
    wanted = class_name.lower().strip()
    matches = (
        int(class_id)
        for class_id, label in names.items()
        if str(label).lower().strip() == wanted
    )
    return next(matches, None)


@torch.no_grad()
def yolo_predict_conf_scalar(
    yolo_model,
    pil_img: Image.Image,
    imgsz: int,
    target_class_id: Optional[int],
    conf: float = 0.001,
) -> Tuple[float, Any]:
    """Run ``yolo_model.predict`` on one image and reduce it to a single confidence.

    The scalar is the highest confidence among the returned boxes, restricted
    to ``target_class_id`` when it is given. It is ``0.0`` when there are no
    boxes, or no boxes of the target class.

    Returns:
        ``(confidence, result)`` where ``result`` is the first element of the
        list returned by ``predict``.
    """
    frame_bgr = pil_to_np_bgr(pil_img)
    res = yolo_model.predict(source=frame_bgr, imgsz=imgsz, conf=conf, verbose=False)[0]

    if res.boxes is None or len(res.boxes) == 0:
        return 0.0, res

    scores = res.boxes.conf.detach().cpu().numpy()
    labels = res.boxes.cls.detach().cpu().numpy().astype(int)

    if target_class_id is not None:
        scores = scores[labels == int(target_class_id)]

    # initial=0.0 makes the reduction well-defined (and 0.0) for an empty selection.
    return float(scores.max(initial=0.0)), res

In [None]:
def build_clean_and_patched_letterboxed(
    base_pil: Image.Image,
    patch_pil: Image.Image,
    cfg: ExpConfig,
) -> Tuple[Image.Image, Image.Image, Optional[torch.Tensor]]:
    """Build the letterboxed clean image and its patched counterpart.

    Depending on ``cfg.apply_patch_after_letterbox`` the patch is pasted either
    onto the letterboxed image or onto the original image, which is letterboxed
    afterwards.

    Returns:
        ``(clean_lb, patched_lb, bbox)``. ``bbox`` is the patch box in
        letterboxed coordinates, and is ``None`` when the patch was applied
        before letterboxing (the box is not re-projected in that mode).
    """
    clean_lb = letterbox_pil(base_pil, imgsz=cfg.imgsz)

    if not cfg.apply_patch_after_letterbox:
        # Patch in original-image coordinates, then letterbox the patched image.
        patched_orig, _, _ = apply_patch_to_image(base_pil, patch_pil, cfg.patch_xy)
        return clean_lb, letterbox_pil(patched_orig, imgsz=cfg.imgsz), None

    patched_lb, bbox_lb, _ = apply_patch_to_image(clean_lb, patch_pil, cfg.patch_xy)
    return clean_lb, patched_lb, bbox_lb

def pick_center_person_bbox_from_results(
    res,
    target_class_id: int,
    imgsz: int,
    strategy: str = "score",
    min_conf: float = 0.10,
    min_area_frac: float = 0.01,
    ar_range: Tuple[float, float] = (0.15, 1.20),
    w_conf: float = 1.0,
    w_area: float = 0.6,
    w_dist: float = 0.4,
    w_ar: float = 0.2,
):
    """Pick the most plausible target-class bbox from YOLO post-NMS detections.

    Why this exists:
      Center-closest can pick tiny false positives near center.

    Candidates of ``target_class_id`` are first filtered by confidence
    (``min_conf``), area (``min_area_frac`` of ``imgsz**2``) and width/height
    aspect ratio (``ar_range``). If nothing survives the filter, the
    highest-confidence candidate is returned regardless of ``strategy``.

    Strategies (applied to the filtered candidates):
      - 'center' : legacy center-closest (kept for debugging)
      - 'conf'   : highest confidence (tie-breaker: larger area)
      - 'area'   : largest area (tie-breaker: higher confidence)
      - 'score'  : weighted score combining confidence, area, center-distance
                   and an aspect-ratio prior (weights ``w_conf``, ``w_area``,
                   ``w_dist``, ``w_ar``)

    Args:
        res: Object exposing ``boxes.xyxy``, ``boxes.conf`` and ``boxes.cls``
            as torch tensors (assumed to be an Ultralytics result).
        target_class_id: Class id to keep.
        imgsz: Side of the square image the boxes refer to; used to normalise
            area and center distance.

    Returns:
        ``{"picked_bbox": [x1, y1, x2, y2], "picked_conf": float,
        "picked_idx": int}`` where ``picked_idx`` indexes the original
        detection list, or ``{}`` when there is no detection of the class.

    Raises:
        ValueError: If ``strategy`` is unknown (only checked when at least one
            candidate passes the filter).
    """
    if res is None or getattr(res, "boxes", None) is None or len(res.boxes) == 0:
        return {}

    xyxy = res.boxes.xyxy.detach().cpu().numpy()             # (N,4)
    conf = res.boxes.conf.detach().cpu().numpy()             # (N,)
    cls = res.boxes.cls.detach().cpu().numpy().astype(int)   # (N,)

    is_target = (cls == int(target_class_id))
    if not np.any(is_target):
        return {}

    idxs = np.flatnonzero(is_target)
    xyxy_p = xyxy[is_target]
    conf_p = conf[is_target]

    # geometry
    w = (xyxy_p[:, 2] - xyxy_p[:, 0]).clip(min=0.0)
    h = (xyxy_p[:, 3] - xyxy_p[:, 1]).clip(min=0.0)
    area = w * h

    # basic filtering to avoid tiny garbage boxes
    min_area = float(min_area_frac) * float(imgsz) * float(imgsz)
    keep = (conf_p >= float(min_conf)) & (area >= min_area)

    # aspect-ratio prior for a person box in full-body-ish imagery (very loose)
    ar = w / (h + 1e-9)  # width/height
    keep = keep & (ar >= float(ar_range[0])) & (ar <= float(ar_range[1]))

    if not np.any(keep):
        # Fallback: just keep the best-conf candidate (even if small)
        j = int(np.argmax(conf_p))
        return {
            "picked_bbox": [float(v) for v in xyxy_p[j].tolist()],
            "picked_conf": float(conf_p[j]),
            "picked_idx": int(idxs[j]),
        }

    xyxy_p = xyxy_p[keep]
    conf_p = conf_p[keep]
    area = area[keep]
    ar = ar[keep]
    idxs = idxs[keep]

    # center distance, normalised by the image side
    cx = 0.5 * (xyxy_p[:, 0] + xyxy_p[:, 2])
    cy = 0.5 * (xyxy_p[:, 1] + xyxy_p[:, 3])
    dx = (cx - (float(imgsz) * 0.5)) / float(imgsz)
    dy = (cy - (float(imgsz) * 0.5)) / float(imgsz)
    dist = np.sqrt(dx * dx + dy * dy)  # 0..~0.7

    strategy = str(strategy).lower().strip()

    if strategy == "center":
        j = int(np.argmin(dist))

    elif strategy == "conf":
        # lexsort sorts ascending with the LAST key as primary, so on negated
        # keys the best candidate (highest conf, then largest area) is first.
        j = int(np.lexsort((-area, -conf_p))[0])

    elif strategy == "area":
        # Primary: largest area; tie-breaker: higher confidence.
        j = int(np.lexsort((-conf_p, -area))[0])

    elif strategy == "score":
        # normalize area to [0,1] by image area; then compress with sqrt to reduce dominance
        area_n = np.sqrt(area / (float(imgsz) * float(imgsz) + 1e-9))

        # aspect ratio prior: soft penalty around ar_target, in [0,1], 0 is best.
        ar_target = 0.55
        ar_pen = np.clip(np.abs(ar - ar_target) / 0.65, 0.0, 1.0)

        # score: higher is better
        score = (
            float(w_conf) * conf_p
            + float(w_area) * area_n
            - float(w_dist) * dist
            - float(w_ar) * ar_pen
        )
        j = int(np.argmax(score))

    else:
        raise ValueError(f"Unknown strategy='{strategy}'")

    return {
        "picked_bbox": [float(v) for v in xyxy_p[j].tolist()],
        "picked_conf": float(conf_p[j]),
        "picked_idx": int(idxs[j]),
    }

## Параметры экспериментов

In [None]:
# We do NOT trust folder labels fully; we will oversample from folders and then
# select a balanced set by the *computed* success flag (drop > CFG.success_thresh).
N_FROM_EACH_DIR = 150   # how many files to pull from each folder initially
N_SUCCESS = 100         # how many SUCCESS examples to keep after labeling
N_FAIL = 100            # how many FAIL examples to keep after labeling
SEED = 17               # shared seed: image sampling here and the later rebalancing shuffle

# Alternative input locations (disabled):
# successful_dir = Path("../stats/attack_split/success")
# unsuccessful_dir = Path("../stats/attack_split/fail")

# Relative paths: resolved against the notebook's current working directory.
successful_dir = Path("successful_examples")
unsuccessful_dir = Path("unsuccessful_examples")

# Candidate pool: up to N_FROM_EACH_DIR files from each folder, "success" folder first.
# collect_images raises RuntimeError if either folder has no images.
image_paths = (
    collect_images(successful_dir, N_FROM_EACH_DIR, shuffle=True, seed=SEED)
    + collect_images(unsuccessful_dir, N_FROM_EACH_DIR, shuffle=True, seed=SEED)
)
# The adversarial patch image that is pasted onto every base image.
patch_path = "data/patch.png"

In [None]:
def _default_device() -> str:
    """Pick the default torch device name: CUDA, then Apple MPS, then CPU.

    MPS availability is queried through ``torch.backends.mps``; the lookup is
    guarded with ``getattr`` so a PyTorch build without that backend module
    falls back to CPU instead of raising.
    """
    if torch.cuda.is_available():
        return "cuda"
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return "mps"
    return "cpu"


@dataclass
class ExpConfig:
    """Settings shared by the experiment cells of this notebook."""

    # path to the weights (can be replaced with a local file)
    model_path: str = "yolo11s.pt"

    # input size (side of the square letterboxed image)
    imgsz: int = 640

    # device; resolved once, when the class is defined
    device: str = _default_device()

    # True: paste the patch onto the already letterboxed image (the most faithful mode)
    # False: paste onto the original image, then letterbox
    apply_patch_after_letterbox: bool = True

    # patch position (x, y) in the coordinates of the image it is pasted onto
    patch_xy: Tuple[int, int] = (0, 0)

    # target class (None => take the max over all classes)
    target_class_name: Optional[str] = "person"

    # attack success: the confidence scalar drops by more than this threshold
    success_thresh: float = 0.30

    # number of channels used in top-k metrics and in ablations
    topk_channels: int = 32

    # ROI scalar reduction for attribution: "max" | "logsumexp" | "topk_mean"
    roi_reduce: str = "logsumexp"

    # temperature for logsumexp (larger -> closer to max)
    roi_lse_temp: float = 20.0

    # k for topk_mean
    roi_topk: int = 10

    # --- ROI attribution layer selection ---
    # How many layers to keep for ROI-attribution (set 1 to mimic the previous "single good layer" behavior)
    roi_keep_topk_layers: int = 1

    # How many near-head conv layers to consider as candidates
    roi_candidate_last_n_layers: int = 8

    # How many examples from each group (success/fail) to use for scanning grad signal
    roi_scan_n_per_group: int = 2

    # Minimum median ma_abs.max required to accept a layer (0.0 keeps the best even if tiny)
    roi_min_grad_strength: float = 0.0

CFG = ExpConfig()

## Подсчет разности фич между атакованным и чистым примерами

С помощью forward-хуков снимаем выходы предпоследних (pre-head) слоёв модели и считаем:
- `F_l(clean)` и `F_l(patched)` для набора слоёв `l`
- `Δ_l = F_l(patched) - F_l(clean)`

In [None]:
import torch.nn as nn


def pick_default_conv_layers(model_torch: nn.Module, n_layers: int = 12) -> List[str]:
    """Return names of up to ``n_layers`` ``Conv2d`` modules spread evenly over the model.

    All ``nn.Conv2d`` modules are listed in ``named_modules()`` order and
    sampled at evenly spaced, rounded positions from the first to the last.
    Duplicate picks are dropped, so fewer names than requested may come back.

    Raises:
        RuntimeError: If the model contains no ``Conv2d`` module.
    """
    conv_names = [
        mod_name
        for mod_name, module in model_torch.named_modules()
        if isinstance(module, nn.Conv2d)
    ]
    if not conv_names:
        raise RuntimeError("Не удалось найти Conv2d модули в модели.")

    n_picks = min(n_layers, len(conv_names))
    positions = np.linspace(0, len(conv_names) - 1, num=n_picks).round().astype(int)

    # dict.fromkeys de-duplicates while keeping first-seen order.
    return list(dict.fromkeys(conv_names[int(pos)] for pos in positions))


def get_module_by_name(model_torch: nn.Module, name: str) -> nn.Module:
    """Return the submodule registered under ``name`` in ``named_modules()``.

    Raises:
        KeyError: If no module has that name.
    """
    found = dict(model_torch.named_modules()).get(name)
    if found is None:
        raise KeyError(f"Layer '{name}' not found.")
    return found


@torch.no_grad()
def torch_preprocess_letterboxed(pil_img_lb: Image.Image, device: str, dtype: torch.dtype) -> torch.Tensor:
    """Turn a PIL image into a single-image batch tensor on ``device`` with ``dtype``.

    The image goes through ``pil_to_torch_rgb01`` and gets a leading batch
    axis before being moved/cast. No resizing is done here; the name suggests
    the caller passes an already letterboxed image.
    """
    batch = pil_to_torch_rgb01(pil_img_lb)[None]  # prepend the batch axis
    return batch.to(device=device, dtype=dtype, non_blocking=False)


def capture_activations_for_layers(
    model_torch: nn.Module,
    x_bchw: torch.Tensor,
    layer_names: List[str],
) -> Dict[str, torch.Tensor]:
    """Run one forward pass and collect the outputs of the named layers.

    Forward hooks are attached to each requested module, the model is run once
    on ``x_bchw`` and the hooks are always removed again, even when the
    forward pass raises. The pass runs under ``torch.no_grad()`` because the
    captured tensors are detached anyway.

    Args:
        model_torch: Model to run.
        x_bchw: Input batch; it is moved to the device/dtype of the model's
            first parameter (left untouched if the model has no parameters).
        layer_names: Names as listed by ``model_torch.named_modules()``.

    Returns:
        Mapping ``layer name -> detached CPU tensor``. Only 4-D outputs are
        stored; a layer whose output is not 4-D is absent from the result.

    Raises:
        KeyError: If a name in ``layer_names`` is not a module of the model
            (raised before any hook is registered).
    """
    # Resolve every layer up front: one pass over named_modules() instead of
    # one per layer, and no hook is left behind if a name turns out invalid.
    modules = dict(model_torch.named_modules())
    for ln in layer_names:
        if ln not in modules:
            raise KeyError(f"Layer '{ln}' not found.")

    acts: Dict[str, torch.Tensor] = {}

    def _mk_hook(name: str):
        def _hook(_m, _inp, out):
            # Some Ultralytics blocks can emit tuples/lists; take the first tensor.
            if isinstance(out, (list, tuple)):
                for o in out:
                    if isinstance(o, torch.Tensor):
                        out = o
                        break

            # We focus on spatial feature maps (B,C,H,W). Skip others.
            if isinstance(out, torch.Tensor) and out.ndim == 4:
                acts[name] = out.detach().cpu()
        return _hook

    # MPS requires input and weights to share device AND dtype.
    ref_param = next(model_torch.parameters(), None)
    if ref_param is not None:
        x_bchw = x_bchw.to(device=ref_param.device, dtype=ref_param.dtype)

    hooks = []
    try:
        for ln in layer_names:
            hooks.append(modules[ln].register_forward_hook(_mk_hook(ln)))
        with torch.no_grad():
            model_torch(x_bchw)
    finally:
        # Never leave hooks on the model: a stale hook would keep writing
        # into this call's dict on every later forward pass.
        for h in hooks:
            h.remove()

    return acts

def _top_level_model_index(name: str) -> Optional[int]:
    """Extract ``idx`` from a module name of the form ``model.<idx>...``.

    Returns ``None`` when the name does not start with ``model.`` or when the
    component after it is not an integer.
    """
    head, dot, tail = name.partition(".")
    if head != "model" or not dot:
        return None
    index_token = tail.split(".", 1)[0]
    try:
        return int(index_token)
    except ValueError:
        return None


def find_head_start_index(model_torch: nn.Module) -> int:
    """Heuristically locate the first top-level ``model.<idx>`` block of the head.

    Three attempts, in order:
      1. the lowest index of a module whose class name is a known head class;
      2. otherwise, the lowest index of a module whose name contains
         "dfl", "detect" or "decoder" (case-insensitive);
      3. otherwise, the highest top-level index found, or ``10**9`` when no
         module name follows the ``model.<idx>`` pattern.
    """
    head_class_names = {
        "Detect", "Segment", "Pose", "OBB", "Classify", "RTDETRDecoder", "WorldDetect",
    }
    name_tags = ("dfl", "detect", "decoder")
    named = list(model_torch.named_modules())

    def _lowest_index(module_names) -> Optional[int]:
        """Smallest parsable top-level index among the names, or None."""
        parsed = [i for i in map(_top_level_model_index, module_names) if i is not None]
        return min(parsed) if parsed else None

    # 1) Prefer class-name detection
    head_idx = _lowest_index(
        mod_name for mod_name, module in named
        if module.__class__.__name__ in head_class_names
    )

    # 2) Fallback: name heuristics
    if head_idx is None:
        head_idx = _lowest_index(
            mod_name for mod_name, _module in named
            if any(tag in mod_name.lower() for tag in name_tags)
        )

    # 3) Still unknown: fall back to the largest top-level index
    if head_idx is None:
        all_idxs = [
            i for i in (_top_level_model_index(mod_name) for mod_name, _module in named)
            if i is not None
        ]
        return max(all_idxs) if all_idxs else 10**9

    return int(head_idx)

### Запуск и подсчет нужных метрик по слоям

In [None]:
# Load the Ultralytics model wrapper from the weights named in the config.

yolo = _load_ultralytics_yolo(CFG.model_path)

# Put Ultralytics wrapper AND underlying torch module on the same device.
# On Apple Silicon, this ensures weights are moved to MPS as well.
try:
    yolo.to(CFG.device)
except Exception:
    # Some ultralytics versions may not expose .to(); we still move the torch model below.
    # NOTE(review): this also hides genuine device errors from the wrapper call.
    pass

# The raw torch module is used directly for the hook-based forward passes.
model_torch: nn.Module = yolo.model
model_torch = model_torch.to(CFG.device)
model_torch.eval()

# Keep a reference dtype for safe casting of inputs (MPS is strict about dtype/device).
_MODEL_DTYPE = next(model_torch.parameters()).dtype


# Class-id -> class-name mapping: prefer the wrapper's, fall back to the torch module's.
names = getattr(yolo, "names", None)
if names is None:
    names = getattr(model_torch, "names", {})

# Resolve the configured class name to an id; None means "max over all classes"
# downstream (either by configuration or because the name was not found).
if CFG.target_class_name is not None:
    target_class_id = _get_class_id_from_names(names, CFG.target_class_name)
    if target_class_id is None:
        print(f"[warn] target_class_name='{CFG.target_class_name}' не найден в model.names; использую max по всем классам")
else:
    target_class_id = None
# ---- Use all pre-head top-level blocks as layers ----
def pick_prehead_top_level_blocks(model_torch: nn.Module) -> List[str]:
    """List the names ``model.0`` .. ``model.<head_start - 1>`` of the pre-head blocks.

    ``head_start`` comes from ``find_head_start_index``. When the model's
    ``model`` attribute is a sequence (``nn.Sequential``, list or tuple), the
    index is capped at its length so that no non-existent block is named.
    """
    n_prehead = find_head_start_index(model_torch)

    top_level = getattr(model_torch, "model", None)
    if isinstance(top_level, (nn.Sequential, list, tuple)):
        n_prehead = min(n_prehead, len(top_level))

    return [f"model.{block_idx}" for block_idx in range(int(n_prehead))]

# Only used for the informational print below.
head_start_idx = find_head_start_index(model_torch)
# Choose which blocks to analyze.
# - Full pre-head profile: pick_prehead_top_level_blocks(model_torch)
# - Single block (requested): ["model.22"]
layer_names = ["model.22"]

print(f"Head starts at top-level model index: {head_start_idx}")
print(f"Selected blocks: {len(layer_names)} -> {layer_names}")

patch_pil = Image.open(patch_path).convert("RGB")

# One record per input image; see the keys of the dict appended below.
run_data: List[Dict[str, Any]] = []

for p in image_paths:
    base_pil = Image.open(p).convert("RGB")
    clean_lb, patched_lb, patch_bbox_lb = build_clean_and_patched_letterboxed(base_pil, patch_pil, CFG)

    # Scalar confidences through the Ultralytics predict API (clean vs. patched).
    conf_clean, res_clean = yolo_predict_conf_scalar(yolo, clean_lb, CFG.imgsz, target_class_id)
    conf_patch, res_patch = yolo_predict_conf_scalar(yolo, patched_lb, CFG.imgsz, target_class_id)

    # Target bbox picked from the CLEAN prediction; stays {} when there is no
    # target class id or no matching detection.
    gradcam_info = {}
    if target_class_id is not None:
        gradcam_info = pick_center_person_bbox_from_results(
            res_clean,
            target_class_id=int(target_class_id),
            imgsz=int(CFG.imgsz),
            strategy="score",      # also try: 'conf' or 'area'
            min_conf=0.10,
            min_area_frac=0.01,
        )

    # Second, hook-based pass through the raw torch module to capture features.
    x_clean = torch_preprocess_letterboxed(clean_lb, device=CFG.device, dtype=_MODEL_DTYPE)
    x_patch = torch_preprocess_letterboxed(patched_lb, device=CFG.device, dtype=_MODEL_DTYPE)

    acts_clean = capture_activations_for_layers(model_torch, x_clean, layer_names)
    acts_patch = capture_activations_for_layers(model_torch, x_patch, layer_names)

    # Delta = F(patched) - F(clean), only for layers captured in both passes.
    deltas: Dict[str, torch.Tensor] = {}
    for ln in layer_names:
        if ln in acts_clean and ln in acts_patch:
            deltas[ln] = (acts_patch[ln] - acts_clean[ln])

    drop = float(conf_clean - conf_patch)
    # Alternative criterion (disabled): clean above the threshold and patched below it.
    #success = bool(conf_clean >= CFG.success_thresh > conf_patch)
    # Active criterion: the confidence drop strictly exceeds the threshold.
    success = bool(drop > CFG.success_thresh)

    run_data.append(
        {
            "path": p,
            "clean_lb": clean_lb,
            "patched_lb": patched_lb,
            "patch_bbox_lb": patch_bbox_lb,
            "conf_clean": float(conf_clean),
            "conf_patch": float(conf_patch),
            "drop": drop,
            "success": success,
            "acts_clean": acts_clean,
            "acts_patch": acts_patch,
            "deltas": deltas,
            "res_clean": res_clean,
            "res_patch": res_patch,
            "gradcam_info": gradcam_info,
        }
    )

print("\nPer-image summary:")
for d in run_data:
    print(f"- {d['path']}: clean={d['conf_clean']:.3f} patched={d['conf_patch']:.3f} drop={d['drop']:.3f} success={d['success']}")


# --- Re-balance by computed success label (folder names may be wrong) ---
rng = np.random.default_rng(SEED)

succ_rd = [d for d in run_data if bool(d.get("success", False))]
fail_rd = [d for d in run_data if not bool(d.get("success", False))]

# shuffle to avoid ordering bias
rng.shuffle(succ_rd)
rng.shuffle(fail_rd)

# Keep at most N_SUCCESS / N_FAIL records; a group may end up smaller if the pool is short.
n_s = min(int(N_SUCCESS), len(succ_rd))
n_f = min(int(N_FAIL), len(fail_rd))

run_data_balanced = succ_rd[:n_s] + fail_rd[:n_f]
rng.shuffle(run_data_balanced)

print(f"\n[rebalance] Computed labels in pool: success={len(succ_rd)} fail={len(fail_rd)}")
print(f"[rebalance] Keeping: success={n_s} fail={n_f} (total={len(run_data_balanced)})")

# Overwrite run_data so all downstream cells use the balanced subset.
run_data = run_data_balanced

## Визуализация карт разности фич (Δ = F(patched) − F(clean))

Для каждого изображения строим 2D-карту разности на выбранном слое, схлопывая 512 каналов в один скаляр на пиксель.
Покажем две версии:
- **signed**: сохраняем знак (положительные/отрицательные изменения)
- **abs**: модуль разности

Далее можно варьировать способ схлопывания каналов (mean / l2 / max / topk_mean / pca1), чтобы уменьшить потерю информации.


In [None]:
from pathlib import Path
import torch.nn.functional as F
import matplotlib.patches as patches
from typing import List


def _reduce_channels(delta_chw: torch.Tensor, mode: str = "l2", topk: int = 32) -> torch.Tensor:
    """Collapse a CxHxW tensor to an HxW map.

    mode (case-insensitive, surrounding whitespace ignored):
      - 'mean'      : mean over channels (signed)
      - 'abs_mean'  : mean(|delta|) over channels
      - 'l2'        : sqrt(mean(delta^2)) over channels (non-negative, float32)
      - 'max'       : largest signed value over channels
      - 'topk_mean' : per position, signed mean of the k values with the
                      largest magnitude (k = topk clamped to [1, C])
      - 'pca1'      : projection of the spatially centered data onto its first
                      principal direction (signed; the sign of the direction
                      is whatever the SVD returns)

    The result stays on the device of ``delta_chw``.

    Raises:
        ValueError: If the input is not 3-D or ``mode`` is unknown.
    """
    if delta_chw.ndim != 3:
        raise ValueError(f"Expected CHW tensor, got {tuple(delta_chw.shape)}")

    n_ch, height, width = delta_chw.shape
    mode = mode.lower().strip()

    # Reductions that are a single expression over the channel axis.
    simple_reducers = {
        "mean": lambda t: t.mean(dim=0),
        "abs_mean": lambda t: t.abs().mean(dim=0),
        "l2": lambda t: (t.float().pow(2).mean(dim=0)).sqrt(),
        "max": lambda t: t.max(dim=0).values,
    }
    if mode in simple_reducers:
        return simple_reducers[mode](delta_chw)

    if mode == "topk_mean":
        k = int(min(max(topk, 1), n_ch))
        flat = delta_chw.reshape(n_ch, height * width)  # (C, HW), keeps it vectorized
        top_idx = torch.topk(flat.abs(), k=k, dim=0, largest=True, sorted=False).indices
        selected = flat.gather(0, top_idx)  # (k, HW), original signed values
        return selected.mean(dim=0).reshape(height, width)

    if mode == "pca1":
        # Rows are spatial positions, columns are channels: (HW, C).
        samples = delta_chw.reshape(n_ch, height * width).T.float()
        samples = samples - samples.mean(dim=0, keepdim=True)
        try:
            # First right-singular vector = first principal direction, shape (C,).
            direction = torch.linalg.svd(samples, full_matrices=False).Vh[0]
        except Exception:
            # Fallback: power iteration on the (C x C) covariance if SVD fails.
            cov = (samples.T @ samples) / max(samples.shape[0] - 1, 1)
            direction = torch.randn((n_ch,), device=cov.device, dtype=cov.dtype)
            direction = direction / (direction.norm() + 1e-12)
            for _ in range(15):
                direction = cov @ direction
                direction = direction / (direction.norm() + 1e-12)

        return (samples @ direction).reshape(height, width)

    raise ValueError(f"Unknown mode='{mode}'")


def _robust_norm_signed(hw: np.ndarray, q: float = 99.0) -> Tuple[np.ndarray, float]:
    """Scale a signed map into [-1, 1] by the ``q``-th percentile of its magnitude.

    Returns ``(normalized_map, scale)``; the scale is floored at 1e-8 so an
    all-zero map does not divide by zero.
    """
    scale = max(float(np.percentile(np.abs(hw), q)), 1e-8)
    return np.clip(hw / scale, -1.0, 1.0), scale


def _robust_norm_positive(hw: np.ndarray, q: float = 99.0) -> Tuple[np.ndarray, float]:
    """Scale a map into [0, 1] by its ``q``-th percentile; negatives clip to 0.

    Returns ``(normalized_map, scale)``; the scale is floored at 1e-8 so an
    all-zero map does not divide by zero.
    """
    scale = max(float(np.percentile(hw, q)), 1e-8)
    return np.clip(hw / scale, 0.0, 1.0), scale


In [None]:
from random import sample

def visualize_delta_maps(
    run_data: List[Dict[str, Any]],
    layer: str = "model.22",
    reduce_mode: str = "topk_mean",
    topk: int = 32,
    max_rows: int | None = None,
    save_dir: str | Path | None = None,
):
    """Show per-image delta maps overlaid on the patched image.

    Layout per row: patched | patched+signedΔ | patched+absΔ

    Args:
        run_data: Records with at least "patched_lb" and "deltas"; "path",
            "success", "drop", "conf_clean", "conf_patch" and "gradcam_info"
            are read with defaults when missing.
        layer: Key into each record's "deltas" dict.
        reduce_mode: Channel reduction passed to _reduce_channels.
        topk: k passed to _reduce_channels (used by 'topk_mean').
        max_rows: If given, only the first max_rows records are plotted.
        save_dir: If given, the two composited overlays of every record are
            also written there as PNG files (the directory is created).

    Raises:
        RuntimeError: If there is nothing to plot.
        ValueError: If a stored delta is not 4-D.

    Notes:
    - signed overlay uses a diverging cmap (seismic) with symmetric robust scaling.
    - abs overlay uses a positive cmap (viridis) with robust scaling.
    - alpha controls overlay transparency (hard-coded inside the function).
    - The bbox drawn on every panel is the one picked from the CLEAN prediction.
    """
    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    items = run_data if max_rows is None else run_data[: int(max_rows)]

    n = len(items)
    if n == 0:
        raise RuntimeError("run_data is empty")

    alpha = 1  # overlay opacity; at 1 the heatmap fully covers the underlying image

    def _draw_clean_bbox(ax, bbox_xyxy, success: bool):
        """Draw the predicted person bbox from clean image onto an axis."""
        if bbox_xyxy is None:
            return
        x1, y1, x2, y2 = [float(v) for v in bbox_xyxy]
        w = max(0.0, x2 - x1)
        h = max(0.0, y2 - y1)
        if w <= 0 or h <= 0:
            return
        # Lime edge for success, black edge for fail
        ec = "lime" if success else "black"
        rect = patches.Rectangle((x1, y1), w, h, linewidth=4.0, edgecolor=ec, facecolor="none")
        ax.add_patch(rect)


    fig, axes = plt.subplots(nrows=n, ncols=3, figsize=(15, 4.6 * n))
    # With a single row plt.subplots returns a 1-D array; restore the [row, col] indexing.
    if n == 1:
        axes = np.expand_dims(axes, axis=0)

    for i, d in enumerate(items):
        p = d.get("path", f"idx_{i}")
        success = d.get("success", False)
        drop = d.get("drop", 0.0)
        conf_c = d.get("conf_clean", 0.0)
        conf_p = d.get("conf_patch", 0.0)

        status = "SUCCESS" if success else "FAIL"
        # Predicted bbox of the person BEFORE attack (from clean prediction)
        bbox_xyxy = None
        gi = d.get("gradcam_info", {})
        if isinstance(gi, dict):
            bbox_xyxy = gi.get("picked_bbox", None)

        # Base (patched) image
        patched_img = d["patched_lb"]

        # --- Δ feature map ---
        delta = d["deltas"].get(layer, None)
        if delta is None:
            # No delta for this layer: show the plain patched image in all three panels.
            for j, title in enumerate(["patched", "patched + signed Δ", "patched + abs Δ"]):
                axes[i, j].imshow(patched_img)
                _draw_clean_bbox(axes[i, j], bbox_xyxy, success)
                axes[i, j].set_title(title + " (missing Δ)")
                axes[i, j].axis("off")
            # NOTE(review): set after axis("off"), so this label is probably not rendered.
            axes[i, 0].set_ylabel(
                f"{Path(p).name}\n{status} | clean={conf_c:.3f} patched={conf_p:.3f} drop={drop:.3f}",
                rotation=0,
                labelpad=60,
                va="center",
            )
            continue

        # delta: (B,C,H,W) -> (C,H,W); only the first batch element is used
        if delta.ndim == 4:
            delta_chw = delta[0]
        else:
            raise ValueError(f"Unexpected delta ndim={delta.ndim} for layer={layer}")

        # Reduce channels -> HxW feature-space map (typically ~20x20) and upsample to image resolution (e.g., 640x640)
        signed_hw_t = _reduce_channels(delta_chw, mode=reduce_mode, topk=topk).detach().cpu()  # torch [h,w]

        # Target size is the rendered patched image size
        if isinstance(patched_img, Image.Image):
            tgt_w, tgt_h = patched_img.size
        else:
            # assume numpy HWC
            tgt_h, tgt_w = np.asarray(patched_img).shape[:2]

        # Nearest-neighbour upsampling keeps the feature cells visible as blocks.
        if tuple(signed_hw_t.shape) != (tgt_h, tgt_w):
            signed_hw_t = F.interpolate(
                signed_hw_t[None, None, ...].float(),
                size=(int(tgt_h), int(tgt_w)),
                mode="nearest",
            )[0, 0]

        signed_hw = signed_hw_t.numpy()

        # abs map: if reduce_mode already non-negative, keep it; otherwise take abs
        if reduce_mode in {"l2", "abs_mean"}:
            abs_hw = signed_hw
        else:
            abs_hw = np.abs(signed_hw)

        # Per-image percentile scaling; the returned scales are not used afterwards.
        signed_n, s_signed = _robust_norm_signed(signed_hw, q=99.0)
        abs_n, s_abs = _robust_norm_positive(abs_hw, q=99.0)

        # --- Col 1: patched only ---
        axes[i, 0].imshow(patched_img)
        _draw_clean_bbox(axes[i, 0], bbox_xyxy, success)
        axes[i, 0].set_title(f"patched | {status}\nclean={conf_c:.3f} → patched={conf_p:.3f} (drop={drop:.3f})")
        axes[i, 0].axis("off")

        # Row badge: status text in the top-left corner of the first panel
        axes[i, 0].text(
            8,
            18,
            status,
            fontsize=14,
            fontweight="bold",
            color=("lime" if success else "red"),
            bbox=dict(facecolor="black", alpha=0.35, pad=3, edgecolor="none"),
        )


        # --- Col 2: patched + signed Δ ---
        axes[i, 1].imshow(patched_img)
        _draw_clean_bbox(axes[i, 1], bbox_xyxy, success)
        im_signed = axes[i, 1].imshow(signed_n, vmin=-1, vmax=1, cmap="seismic", alpha=alpha)
        axes[i, 1].set_title(f"patched + signed Δ | {reduce_mode}\n{status} | clean={conf_c:.3f} → patched={conf_p:.3f}")
        axes[i, 1].axis("off")
        plt.colorbar(im_signed, ax=axes[i, 1], fraction=0.046, pad=0.04)

        # --- Col 3: patched + abs Δ ---
        axes[i, 2].imshow(patched_img)
        _draw_clean_bbox(axes[i, 2], bbox_xyxy, success)
        im_abs = axes[i, 2].imshow(abs_n, vmin=0, vmax=1, cmap="viridis", alpha=alpha)
        axes[i, 2].set_title(f"patched + abs Δ | {reduce_mode}\n{status} | clean={conf_c:.3f} → patched={conf_p:.3f}")
        axes[i, 2].axis("off")
        plt.colorbar(im_abs, ax=axes[i, 2], fraction=0.046, pad=0.04)

        # NOTE(review): set after axis("off"), so this label is probably not rendered.
        axes[i, 0].set_ylabel(
            f"{Path(p).name}\n{status} | clean={conf_c:.3f} patched={conf_p:.3f} drop={drop:.3f}",
            rotation=0,
            labelpad=60,
            va="center",
        )

        if save_dir is not None:
            stem = Path(p).stem
            # Save overlays as images
            out_signed = (save_dir / f"{stem}__overlay_signed_{reduce_mode}.png")
            out_abs = (save_dir / f"{stem}__overlay_abs_{reduce_mode}.png")

            # Render and save (matplotlib imsave doesn't support alpha-compositing easily),
            # so we create composited RGB arrays.
            base = np.asarray(patched_img).astype(np.float32) / 255.0  # HWC RGB

            # signed overlay: map [-1,1] -> RGBA via cmap
            cmap_s = plt.get_cmap("seismic")
            rgba_s = cmap_s((signed_n + 1.0) * 0.5)  # HxWx4
            comp_s = (1 - alpha) * base + alpha * rgba_s[..., :3]
            comp_s = np.clip(comp_s, 0.0, 1.0)

            # abs overlay: [0,1] -> RGBA via cmap, blended the same way
            cmap_a = plt.get_cmap("viridis")
            rgba_a = cmap_a(abs_n)  # HxWx4
            comp_a = (1 - alpha) * base + alpha * rgba_a[..., :3]
            comp_a = np.clip(comp_a, 0.0, 1.0)

            plt.imsave(out_signed, comp_s)
            plt.imsave(out_abs, comp_a)

    plt.tight_layout()
    plt.show()


# Quick run: plot the first 10 records of the (balanced) run_data
visualize_delta_maps(
    run_data[:10],
    layer="model.22",
    reduce_mode="mean",  # also try: 'mean', 'max', 'l2', 'abs_mean', 'topk_mean', 'pca1'
    topk=CFG.topk_channels,
    max_rows=None,
    save_dir=None,  # e.g.: "viz_delta_overlays/model22"
)


# Выводы из эксперимента

Синий цвет на signed-карте означает, что активация фичи упала (Δ < 0); предположительно, именно такое падение нам и нужно наблюдать внутри ROI.

In [None]:
import pandas as pd
import numpy as np
import torch

# --- gini для неотрицательных величин ---
def gini(x: np.ndarray, eps: float = 1e-12) -> float:
    """Return the Gini coefficient of ``x`` treated as non-negative values.

    The input is flattened, non-finite entries are dropped and negative
    entries are clipped to zero before the coefficient is computed.

    Args:
        x: Array-like of values; any shape.
        eps: Totals at or below this value are treated as "no mass".

    Returns:
        NaN when no finite value remains, 0.0 when the clipped total is at
        most ``eps``, otherwise the Gini coefficient.
    """
    values = np.asarray(x, dtype=np.float64).reshape(-1)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return float("nan")

    values = np.clip(values, 0.0, None)
    total = values.sum()
    if total <= eps:
        return 0.0

    ascending = np.sort(values)
    count = ascending.size
    # Rank weights n, n-1, ..., 1 paired with the ascending values.
    weights = count + 1 - np.arange(1, count + 1, dtype=np.float64)
    weighted_sum = np.sum(weights * ascending)
    return float(1.0 + 1.0 / count - 2.0 * weighted_sum / (count * total))

# bbox 640x640 -> (H,W) координаты фич
def xyxy_to_hw_roi(xyxy, H: int, W: int, imgsz: int = 640):
    """Project a pixel-space box onto an H x W grid as a pair of slices.

    Args:
        xyxy: (x1, y1, x2, y2) in a square ``imgsz`` x ``imgsz`` frame, or None.
        H: Number of grid rows.
        W: Number of grid columns.
        imgsz: Side length of the frame the box coordinates refer to.

    Returns:
        ``(slice(row_start, row_stop), slice(col_start, col_stop))`` covering
        every grid cell the clamped box touches, or None when the box is
        missing or empty after clamping.
    """
    if xyxy is None:
        return None

    side = float(imgsz)

    def _clamp_px(value):
        # Keep the coordinate inside the [0, imgsz] frame.
        return max(0.0, min(side, value))

    def _grid_span(lo, hi, cells):
        # Outward-rounded cell range, then forced into a valid slice range.
        scale = cells / side
        start = int(np.floor(lo * scale))
        stop = int(np.ceil(hi * scale))
        start = max(0, min(cells - 1, start))
        stop = max(1, min(cells, stop))
        return start, stop

    left, top, right, bottom = map(float, xyxy)
    left, right = _clamp_px(left), _clamp_px(right)
    top, bottom = _clamp_px(top), _clamp_px(bottom)
    if right <= left or bottom <= top:
        return None

    col_start, col_stop = _grid_span(left, right, W)
    row_start, row_stop = _grid_span(top, bottom, H)
    if col_stop <= col_start or row_stop <= row_start:
        return None
    return (slice(row_start, row_stop), slice(col_start, col_stop))

def compute_metrics(delta_bchw: torch.Tensor, reduce_mode_hw: str = "mean", topk: int = 32, imgsz: int = 640, bbox_xyxy=None):
    """Summarise a feature-delta tensor with scalar magnitude/structure metrics.

    Only the first batch element (``delta_bchw[0]``) is analysed.

    Args:
        delta_bchw: Feature difference of shape [B, C, H, W], or None.
        reduce_mode_hw: Mode forwarded to ``_reduce_channels`` to collapse the
            channel axis into an H x W map. For "l2" and "abs_mean" the map is
            used as-is (presumably already non-negative); for every other
            mode its absolute value is taken.
        topk: Number of highest-energy channels averaged for
            ``chan_energy_topk_mean`` (clamped to [1, C]); also forwarded to
            ``_reduce_channels``.
        imgsz: Side length of the frame in which ``bbox_xyxy`` is expressed.
        bbox_xyxy: Optional (x1, y1, x2, y2) box. When it maps to a non-empty
            grid ROI the ``roi_*`` metrics are filled in; otherwise they
            stay NaN.

    Returns:
        Dict of metric name -> float (plus integer "H", "W", "C"), or an
        empty dict when ``delta_bchw`` is None.
    """
    if delta_bchw is None:
        return {}

    d = delta_bchw[0].float()          # [C,H,W]
    C, H, W = d.shape

    # A) statistics over the full (C,H,W) delta
    mean_signed = float(d.mean().item())
    l2_rms = float((d.pow(2).mean()).sqrt().item())
    abs_mean = float(d.abs().mean().item())
    abs_max = float(d.abs().max().item())

    # Channel energies E_c = mean_{h,w}|delta_c|. detach() mirrors the HxW
    # path below so a tensor that still tracks gradients does not break
    # the numpy conversion.
    E = d.abs().mean(dim=(1, 2)).detach().cpu().numpy().astype(np.float64)  # (C,)
    k = int(min(max(topk, 1), C))
    topk_E = np.partition(E, -k)[-k:]
    chan_energy_topk_mean = float(np.mean(topk_E))
    chan_energy_gini = gini(E)

    # B) reduced HxW map built with the shared channel reducer
    hw_signed = _reduce_channels(d, mode=reduce_mode_hw, topk=topk).detach().cpu().numpy().astype(np.float64)
    hw_abs = hw_signed if reduce_mode_hw in {"l2", "abs_mean"} else np.abs(hw_signed)

    hw_abs_mean = float(hw_abs.mean())
    hw_abs_max = float(hw_abs.max())
    hw_gini = gini(hw_abs)

    p90 = float(np.percentile(hw_abs.reshape(-1), 90.0))
    hw_sparsity_p90 = float(np.mean(hw_abs >= p90)) if p90 > 0 else 0.0

    # C) ROI metrics (box projected onto the feature grid)
    roi_abs_mean = float("nan")
    roi_abs_out_mean = float("nan")
    roi_abs_ratio = float("nan")
    roi = xyxy_to_hw_roi(bbox_xyxy, H=H, W=W, imgsz=imgsz)
    if roi is not None:
        roi_map = hw_abs[roi]
        out_mask = np.ones((H, W), dtype=bool)
        out_mask[roi] = False
        out_map = hw_abs[out_mask]
        roi_abs_mean = float(np.mean(roi_map)) if roi_map.size else float("nan")
        roi_abs_out_mean = float(np.mean(out_map)) if out_map.size else float("nan")
        # The ratio stays NaN when nothing lies outside the ROI or the
        # outside mean is zero.
        if np.isfinite(roi_abs_out_mean) and roi_abs_out_mean > 0:
            roi_abs_ratio = float(roi_abs_mean / roi_abs_out_mean)

    return {
        "mean_signed": mean_signed,
        "l2_rms": l2_rms,
        "abs_mean": abs_mean,
        "abs_max": abs_max,
        "chan_energy_topk_mean": chan_energy_topk_mean,
        "chan_energy_gini": chan_energy_gini,
        "hw_abs_mean": hw_abs_mean,
        "hw_abs_max": hw_abs_max,
        "hw_gini": hw_gini,
        "hw_sparsity_p90": hw_sparsity_p90,
        "roi_abs_mean": roi_abs_mean,
        # Exposed so a NaN/huge ratio can be traced back to its denominator.
        "roi_abs_out_mean": roi_abs_out_mean,
        "roi_abs_ratio": roi_abs_ratio,
        "H": H, "W": W, "C": C,
    }

# Layer whose stored feature delta is summarised for every run_data entry.
LAYER = "model.22"

# One row per run_data entry: bookkeeping fields plus the delta metrics.
rows = []
for d in run_data:
    gi = d.get("gradcam_info", {})
    # ROI box for the roi_* metrics; None when gradcam_info is not a dict or has no picked box.
    bbox = gi.get("picked_bbox", None) if isinstance(gi, dict) else None

    # compute_metrics returns {} when the entry has no delta for LAYER,
    # so such rows only carry the bookkeeping fields.
    m = compute_metrics(
        d.get("deltas", {}).get(LAYER, None),
        reduce_mode_hw="mean",           # also worth trying: "topk_mean" / "l2" / "abs_mean"
        topk=int(CFG.topk_channels),
        imgsz=int(CFG.imgsz),
        bbox_xyxy=bbox,
    )

    rows.append({
        "path": d.get("path"),
        "name": Path(d.get("path","")).name,
        "success": bool(d.get("success", False)),
        "conf_clean": float(d.get("conf_clean", 0.0)),
        "conf_patch": float(d.get("conf_patch", 0.0)),
        "drop": float(d.get("drop", 0.0)),
        **m
    })

df = pd.DataFrame(rows)
df.head()

In [None]:
from sklearn.metrics import roc_auc_score, balanced_accuracy_score
import numpy as np
import pandas as pd


def roc_auc_for_metric(df: pd.DataFrame, metric: str, label_col: str = "success") -> tuple[float, int]:
    """Score how well a single metric column separates the two label classes.

    Rows whose metric value is non-finite are ignored.

    Args:
        df: Table holding the metric and the label column.
        metric: Name of the scalar metric column.
        label_col: Name of the boolean/integer label column.

    Returns:
        ``(auc, direction)`` with ``auc >= 0.5``.
        ``direction`` is +1 when larger metric values go with the positive
        label, -1 when smaller values do (the raw AUC is mirrored in that
        case), and 0 together with a NaN AUC when fewer than two classes
        remain.
    """
    labels = df[label_col].astype(int).to_numpy()
    values = df[metric].to_numpy().astype(float)

    keep = np.isfinite(values) & np.isfinite(labels)
    labels, values = labels[keep], values[keep]

    # AUC is undefined unless both classes are present.
    if np.unique(labels).size < 2:
        return float("nan"), 0

    raw_auc = float(roc_auc_score(labels, values))
    if raw_auc < 0.5:
        # Mirror the score and report that the metric is inversely related.
        return float(1.0 - raw_auc), -1
    return float(raw_auc), 1

def best_balanced_accuracy_for_metric(
    df: pd.DataFrame,
    metric: str,
    label_col: str = "success",
    *,
    direction: int | None = None,
) -> tuple[float, float, int]:
    """Best balanced accuracy achievable by thresholding a single scalar metric.

    Returns:
      (best_bal_acc, best_threshold, used_direction)

    used_direction:
      +1 => predict SUCCESS when x >= thr
      -1 => predict SUCCESS when x <= thr
    """
    y = df[label_col].astype(int).to_numpy()
    x = df[metric].to_numpy().astype(float)

    m = np.isfinite(x) & np.isfinite(y)
    x = x[m]
    y = y[m]

    if x.size == 0 or np.unique(y).size < 2:
        return float("nan"), float("nan"), 0

    if direction is None:
        _auc, direction = roc_auc_for_metric(df.loc[m], metric, label_col=label_col)

    direction = +1 if int(direction) >= 0 else -1

    xs = np.unique(x)
    if xs.size == 1:
        thr = float(xs[0])
        yhat = (x >= thr) if direction == +1 else (x <= thr)
        ba = float(balanced_accuracy_score(y, yhat.astype(int)))
        return ba, thr, int(direction)

    mids = (xs[:-1] + xs[1:]) * 0.5
    thr_candidates = np.concatenate(([xs[0] - 1e-12], mids, [xs[-1] + 1e-12]))

    best_ba = -1.0
    best_thr = float("nan")

    for thr in thr_candidates:
        yhat = (x >= thr) if direction == +1 else (x <= thr)
        ba = float(balanced_accuracy_score(y, yhat.astype(int)))
        if ba > best_ba:
            best_ba = ba
            best_thr = float(thr)

    return float(best_ba), float(best_thr), int(direction)

# Metric columns computed from the feature delta alone (no importance maps).
metric_cols_delta_only = [
    "mean_signed", "l2_rms", "abs_mean", "abs_max",
    "chan_energy_topk_mean", "chan_energy_gini",
    "hw_abs_mean", "hw_abs_max", "hw_gini", "hw_sparsity_p90",
    "roi_abs_mean", "roi_abs_ratio",
]

# Split the rows by the `success` flag for the per-class statistics below.
succ = df[df.success == True]
fail = df[df.success == False]

# One summary row per metric: per-class mean/std plus separability scores.
summary = []
for c in metric_cols_delta_only:
    xs = succ[c].to_numpy()
    xf = fail[c].to_numpy()

    # Reuse the AUC-derived direction so the threshold search is oriented the same way.
    auc, direction = roc_auc_for_metric(df, c, label_col="success")
    bacc, bthr, _ = best_balanced_accuracy_for_metric(df, c, label_col="success", direction=direction)

    summary.append({
        "metric": c,
        "mean_s": float(np.nanmean(xs)),
        "mean_f": float(np.nanmean(xf)),
        "std_s": float(np.nanstd(xs)),
        "std_f": float(np.nanstd(xf)),
        "best_bal_acc": bacc,
        "best_bal_thr": bthr,
        "roc_auc": auc,
        "auc_direction": direction,
    })

# Best-separating metrics (by balanced accuracy) come first.
summary_df_delta_only = pd.DataFrame(summary).sort_values("best_bal_acc", ascending=False)
summary_df_delta_only

In [None]:
import matplotlib.pyplot as plt
import numpy as np


def _violin_metric(ax, data_s, data_f, title: str):
    """Draw FAIL vs SUCCESS distributions as violins with medians and jittered points.

    A violin is used instead of a boxplot because it stays readable when the
    spread is tiny or the values are discrete (quantised).

    Args:
        ax: Matplotlib axes to draw on.
        data_s: Metric values of the SUCCESS group.
        data_f: Metric values of the FAIL group.
        title: Axes title.

    Non-finite values are dropped. FAIL is drawn at x=1 and SUCCESS at x=2
    (or at x=1 when FAIL is empty). With no finite data at all a placeholder
    text is shown and the axes are switched off.
    """
    def _finite_only(values):
        arr = np.asarray(values, dtype=float)
        return arr[np.isfinite(arr)]

    success_vals = _finite_only(data_s)
    fail_vals = _finite_only(data_f)

    if success_vals.size == 0 and fail_vals.size == 0:
        ax.text(0.5, 0.5, "no finite data", ha="center", va="center")
        ax.set_axis_off()
        return

    # Each entry: (tick position, values, tick label, scatter centre, median-line bounds).
    groups = []
    if fail_vals.size:
        groups.append((1, fail_vals, f"FAIL (n={fail_vals.size})", 1.0, (0.78, 1.22)))
    if success_vals.size:
        centre = 2.0 if fail_vals.size else 1.0
        groups.append(
            (
                2 if fail_vals.size else 1,
                success_vals,
                f"SUCCESS (n={success_vals.size})",
                centre,
                (centre - 0.22, centre + 0.22),
            )
        )

    tick_positions = [g[0] for g in groups]
    ax.violinplot(
        [g[1] for g in groups],
        positions=tick_positions,
        widths=0.8,
        showmeans=False,
        showmedians=True,
        showextrema=False,
    )

    # Jittered points plus a bold median line, so collapsed distributions stay visible.
    rng = np.random.default_rng(0)
    for _pos, vals, _label, centre, (line_lo, line_hi) in groups:
        jitter = (rng.random(vals.size) - 0.5) * 0.16
        ax.scatter(np.full(vals.size, centre) + jitter, vals, s=10, alpha=0.35)
        ax.hlines(np.median(vals), line_lo, line_hi, linewidth=3)

    ax.set_xticks(tick_positions)
    ax.set_xticklabels([g[2] for g in groups])
    ax.set_title(title)
    ax.grid(True, alpha=0.25)


# --- Violins for the top metrics (summary_df_delta_only is sorted by best balanced accuracy) ---
topM = 20
best = summary_df_delta_only.head(topM)["metric"].tolist()

# Size the grid by the metrics that actually exist: with fewer than topM
# metrics, sizing by topM leaves most of the figure blank.
n_rows = max(1, (len(best) + 1) // 2)
# Index by metric once instead of re-indexing twice per subplot title.
metric_stats = summary_df_delta_only.set_index("metric")

fig = plt.figure(figsize=(14, 3.2 * n_rows))
for i, m in enumerate(best, 1):
    ax = plt.subplot(n_rows, 2, i)
    _violin_metric(
        ax,
        succ[m].to_numpy(),
        fail[m].to_numpy(),
        f"{m} | AUC={metric_stats.loc[m, 'roc_auc']:.3g} | bAcc={metric_stats.loc[m, 'best_bal_acc']:.3g}",
    )
plt.tight_layout()
plt.show()

# --- Bar chart of best balanced accuracy for all metrics ---
fig = plt.figure(figsize=(9.5, 0.42 * len(summary_df_delta_only) + 1.5))
# Ascending order puts the best metric at the top of the horizontal bar chart.
order = summary_df_delta_only.sort_values("best_bal_acc", ascending=True)
plt.barh(order["metric"], order["best_bal_acc"])
plt.xlabel("best_bal_acc (>=0.5; higher is better)")
plt.title("Best balanced accuracy for feature-delta metrics (no attribution)")
plt.grid(True, axis="x", alpha=0.25)
plt.tight_layout()
plt.show()

## Δ-only метрики (по тензору разности фич) — что означает каждая

Ниже `Δ = F(patched) − F(clean)` для выбранного слоя, где `Δ` имеет форму `(C, H, W)` (в нашем случае обычно `C=512`, `H=W≈20`).
Важно: эти метрики описывают **только величину/структуру изменения фич**, без сравнения с картой важности.

---

### `mean_signed`
**Среднее значение Δ по всем элементам** (с сохранением знака):
$
\text{mean\_signed}=\frac{1}{CHW}\sum_{c,h,w}\Delta_{c,h,w}
$
- Показывает: общий **сдвиг** фич в плюс/минус.
- Не показывает: локализацию и “энергию” (может быть близко к 0 при больших разнонаправленных изменениях).

---

### `l2_rms`
**RMS-энергия изменения** (аналог L2-нормы на элемент, без знака):
$
\text{l2\_rms}=\sqrt{\frac{1}{CHW}\sum_{c,h,w}\Delta_{c,h,w}^2}
$
- Показывает: **общую силу** изменения фич.
- Не показывает: где именно изменение произошло и в каком направлении (знак теряется).

---

### `abs_mean`
**Средний модуль изменения**:
$
\text{abs\_mean}=\frac{1}{CHW}\sum_{c,h,w}\left|\Delta_{c,h,w}\right|
$
- Показывает: “насколько в среднем шевельнули фичи”.
- Менее чувствителен к редким большим выбросам, чем `abs_max`.

---

### `abs_max`
**Максимальный модуль изменения**:
$
\text{abs\_max}=\max_{c,h,w}\left|\Delta_{c,h,w}\right|
$
- Показывает: наличие **очень сильного локального выброса** (хотя бы в одном канале/пикселе).
- Может быть нестабильным: одна “аномальная” ячейка доминирует метрику.

---

### `chan_energy_topk_mean`
Сначала считаем **энергию по каналам**:
$
E_c=\frac{1}{HW}\sum_{h,w}\left|\Delta_{c,h,w}\right|
$
Далее берём **top-k** каналов по \(E_c\) и усредняем:
$
\text{chan\_energy\_topk\_mean}=\frac{1}{k}\sum_{c \in \text{TopK}(E)} E_c
$
- Показывает: насколько сильно изменены **наиболее затронутые каналы**.
- Это “канальная концентрация” атаки: бьёт ли она в несколько каналов сильно.

---

### `chan_energy_gini`
**Коэффициент Джини** для распределения энергий по каналам \(E_c\) (неотрицательных).
- Показывает: **насколько неравномерно** распределено изменение по каналам.
  - ближе к 0 → изменение “размазано” по многим каналам;
  - ближе к 1 → изменение сосредоточено в малом числе каналов.
- Не показывает: абсолютную величину изменения (может быть высокий Джини при маленькой общей энергии).

---

### `hw_abs_mean`
Строим 2D-карту по пространству `H×W` (через `_reduce_channels`, у тебя здесь `reduce_mode_hw="mean"`):
$
\Delta_{hw} = \text{reduce}_c(\Delta_{c,h,w})
$
и берём среднее по пространству от модуля:
$
\text{hw\_abs\_mean}=\frac{1}{HW}\sum_{h,w}|\Delta_{hw}(h,w)|
$
- Показывает: “среднее изменение по клеткам” после сведения каналов к карте.
- Зависит от выбранного `reduce_mode_hw` (mean/l2/topk_mean/...).

---

### `hw_abs_max`
$
\text{hw\_abs\_max}=\max_{h,w}|\Delta_{hw}(h,w)|
$
- Показывает: самый сильный “пространственный” выброс после сведения каналов.
- Аналогично `abs_max`, но уже после collapse каналов.

---

### `hw_gini`
Коэффициент Джини по **пространственной карте** \( |\Delta_{hw}(h,w)| \).
- Показывает: **насколько локализовано** изменение по H×W:
  - низкий → равномерно по всей карте;
  - высокий → сосредоточено в малом числе клеток.

---

### `hw_sparsity_p90`
Доля клеток, чьё значение ≥ 90-перцентиля карты:
$
p90=\text{percentile}(|\Delta_{hw}|,90),\quad
\text{hw\_sparsity\_p90}=\frac{1}{HW}\sum_{h,w}\mathbf{1}\{|\Delta_{hw}(h,w)|\ge p90\}
$
- Показывает: “разреженность” по хвосту распределения.
- Для непрерывных значений часто будет около 0.10 по определению; информативность появляется, когда карта **квантизована/имеет плато/много нулей**.

---

### `roi_abs_mean`
Берём ROI на сетке фич (проекция bbox человека из clean в (H,W)) и считаем среднее по ROI:
$
\text{roi\_abs\_mean}=\frac{1}{|ROI|}\sum_{(h,w)\in ROI}|\Delta_{hw}(h,w)|
$
- Показывает: **насколько сильно изменились фичи внутри ROI объекта**.
- Не показывает: что происходит вне ROI (поэтому полезно сравнивать с `roi_abs_ratio`).

---

### `roi_abs_ratio`
$
\text{roi\_abs\_ratio}=\frac{\text{mean}(|\Delta_{hw}| \text{ inside ROI})}{\text{mean}(|\Delta_{hw}| \text{ outside ROI})}
$
- Показывает: “насколько атака сфокусирована на ROI”:
  - > 1 → изменение концентрируется в ROI сильнее, чем вне ROI;
  - < 1 → больше “шумит” вне ROI.
- Важно: это **отношение**, поэтому оно может быть большим и при малых абсолютных значениях (если вне ROI почти ноль).

# Смотрим на важность

In [None]:
# --- Minimal cell: visualize SSGrad-CAM (top-3 ROI logits) vs |grad*act| on model.22 ---
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from pathlib import Path
from typing import Dict, Any, Optional, Tuple

# -------------------
# Params (edit if needed)
# -------------------
TARGET_LAYER = "model.22"  # layer whose activations/gradients feed the importance maps
NUM_EXAMPLES = 10          # total images shown (tries to split equally success/fail)
TOPK = 3                   # top-k ROI cells (rank by sigmoid, aggregate logits)
IMG_SIZE = int(CFG.imgsz)  # side length used for ROI projection and heat-map upsampling

# -------------------
# Helpers (minimal)
# -------------------
def _normalize_11hw(m_11hw: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # m: (1,1,h,w) -> [0,1]
    m = torch.nan_to_num(m_11hw, nan=0.0, posinf=0.0, neginf=0.0)
    m = m - m.amin(dim=(2, 3), keepdim=True)
    m = m / m.amax(dim=(2, 3), keepdim=True).clamp(min=eps)
    return m

def _roi_mask_on_grid(h: int, w: int, imgsz: int, bbox_xyxy: Tuple[float, float, float, float], device, dtype):
    # mark a cell as inside ROI if its center lies within bbox; return (1,h,w) binary
    x1, y1, x2, y2 = [float(v) for v in bbox_xyxy]
    stride_h = float(imgsz) / float(h)
    stride_w = float(imgsz) / float(w)
    stride = 0.5 * (stride_h + stride_w)

    ys = (torch.arange(h, device=device, dtype=dtype) + 0.5) * stride
    xs = (torch.arange(w, device=device, dtype=dtype) + 0.5) * stride
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    m = (xx >= x1) & (xx <= x2) & (yy >= y1) & (yy <= y2)
    mask = m.to(dtype=dtype).unsqueeze(0)  # (1,h,w)

    # tiny bbox fallback: pick closest cell to bbox center
    if float(mask.sum().item()) < 1.0:
        cx, cy = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
        d2 = (xx - cx) ** 2 + (yy - cy) ** 2
        ij = torch.argmin(d2)
        mask = torch.zeros_like(mask)
        mask.view(-1)[ij] = 1.0
    return mask  # (1,h,w)

def _detect_forward_train_levels(model_torch: nn.Module, x_bchw: torch.Tensor):
    """Run the model with the module at "model.23" in train mode and return its per-level outputs.

    The module (the Detect head, per the original note) is switched to
    ``train()`` for the forward pass to avoid inference tensors in some
    Ultralytics builds. It is put back into ``eval()`` afterwards if that was
    its previous state.

    Returns:
        ``(pred_levels, reg_max, nc)``. ``pred_levels`` is the list/tuple of
        per-level outputs, ``reg_max`` and ``nc`` are read from the head
        (defaults 16 and 80).

    Raises:
        RuntimeError: If the forward output is not a non-empty list/tuple.
    """
    head = get_module_by_name(model_torch, "model.23")
    restore_eval = not bool(head.training)
    head.train()
    try:
        raw = model_torch(x_bchw)
    finally:
        if restore_eval:
            head.eval()

    # Some builds return (pred, x); the per-level list is then the second item.
    is_pred_and_levels = (
        isinstance(raw, (tuple, list)) and len(raw) == 2 and isinstance(raw[0], torch.Tensor)
    )
    levels = raw[1] if is_pred_and_levels else raw

    if not (isinstance(levels, (list, tuple)) and len(levels) != 0):
        raise RuntimeError(f"Unexpected Detect output type: {type(levels)}")

    return levels, int(getattr(head, "reg_max", 16)), int(getattr(head, "nc", 80))

def _person_logit_map_from_level(t: torch.Tensor, reg_max: int, nc: int, target_class_id: int) -> torch.Tensor:
    # t: (B,C,h,w) -> (B,h,w) person logits
    cls_logits = t[:, reg_max * 4: reg_max * 4 + nc, :, :]
    return cls_logits[:, int(target_class_id), :, :]

def target_topk_roi_logits_scalar(
    model_torch: nn.Module,
    x_bchw: torch.Tensor,
    target_class_id: int,
    imgsz: int,
    bbox_xyxy: Tuple[float, float, float, float],
    topk: int = 3,
) -> torch.Tensor:
    """Scalar target: mean logit of the top-k ROI cells of one class.

    For every 4-D level returned by the head, the class logit map is
    restricted to the cells whose centre lies inside ``bbox_xyxy``. The ROI
    cells of all levels are pooled and ranked by sigmoid probability, and the
    logits (not the probabilities) of the ``topk`` best cells are averaged.

    Raises:
        RuntimeError: If no level contributes any ROI cell.
    """
    levels, reg_max, nc = _detect_forward_train_levels(model_torch, x_bchw)

    per_level_logits = []
    for level in levels:
        if not isinstance(level, torch.Tensor) or level.ndim != 4:
            continue

        class_map = _person_logit_map_from_level(level, reg_max, nc, target_class_id)  # (B,h,w)
        grid_h, grid_w = class_map.shape[1], class_map.shape[2]
        roi = _roi_mask_on_grid(
            grid_h, grid_w, imgsz, bbox_xyxy, device=class_map.device, dtype=class_map.dtype
        )  # (1,h,w)
        inside = roi[0].reshape(-1) > 0
        if inside.sum().item() < 1:
            continue

        per_level_logits.append(class_map.reshape(class_map.shape[0], -1)[:, inside])  # (B, n_roi)

    if not per_level_logits:
        raise RuntimeError("No ROI cells found across pred levels.")

    roi_logits = torch.cat(per_level_logits, dim=1)  # (B, total_roi)
    roi_probs = roi_logits.sigmoid()                 # (B, total_roi)

    n_pick = int(min(int(topk), roi_probs.shape[1]))
    picked = torch.topk(roi_probs, k=n_pick, dim=1, largest=True, sorted=False).indices  # (B,k)
    picked_logits = torch.gather(roi_logits, dim=1, index=picked)                        # (B,k)
    return picked_logits.mean(dim=1).mean()  # scalar

def layer_act_and_grad_for_y(
    model_torch: nn.Module,
    x_bchw: torch.Tensor,
    target_layer_name: str,
    y_fn,                       # callable(model_torch, x) -> scalar tensor
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Capture a layer's activation and the gradient of a scalar target w.r.t. it.

    A forward hook on the layer named ``target_layer_name`` stores its 4-D
    output and registers a tensor hook that stores the gradient arriving at
    that output. ``y_fn(model_torch, x)`` is then evaluated and
    back-propagated.

    Args:
        model_torch: Model to run; its parameters' device/dtype are used for the input.
        x_bchw: Input tensor, moved to the model's device and dtype.
        target_layer_name: Name resolved through ``get_module_by_name``.
        y_fn: Callable ``(model, x) -> scalar tensor`` defining the target.

    Returns:
        ``(A, G, y)``: activation and gradient as detached float32 CPU
        tensors, and the target value as a Python float.

    Raises:
        RuntimeError: If the activation or its gradient could not be captured.
    """
    device = next(model_torch.parameters()).device
    dtype = next(model_torch.parameters()).dtype

    # NOTE(review): if .to() returns x_bchw itself (no conversion needed), the
    # requires_grad toggles below act on the caller's tensor - confirm this is acceptable.
    x = x_bchw.to(device=device, dtype=dtype)
    x.requires_grad_(True)

    layer = get_module_by_name(model_torch, target_layer_name)
    # "A" holds the captured activation, "G" the gradient w.r.t. it.
    buf: Dict[str, torch.Tensor] = {}

    def _hook(_m, _inp, out):
        # For list/tuple outputs, use the first tensor element.
        if isinstance(out, (list, tuple)):
            for o in out:
                if isinstance(o, torch.Tensor):
                    out = o
                    break
        if isinstance(out, torch.Tensor) and out.ndim == 4:
            buf["A"] = out
            # Best effort: keep .grad on the activation as a fallback gradient source.
            # NOTE(review): failures are swallowed silently here and below.
            try:
                out.retain_grad()
            except Exception:
                pass

            def _save_grad(g):
                # Store the gradient and pass it through unchanged.
                if isinstance(g, torch.Tensor) and g.ndim == 4:
                    buf["G"] = g
                return g

            try:
                out.register_hook(_save_grad)
            except Exception:
                pass

    h = layer.register_forward_hook(_hook)
    was_training = bool(model_torch.training)

    try:
        model_torch.zero_grad(set_to_none=True)
        # NOTE(review): train() switches every submodule to training mode; if the
        # model contains BatchNorm/Dropout layers this alters the forward pass and
        # may update running statistics - confirm this is intended.
        model_torch.train()  # keep autograd-friendly path

        y = y_fn(model_torch, x)
        y.backward()

        if "A" not in buf:
            raise RuntimeError("Failed to capture activation A.")
        if "G" not in buf:
            # Tensor hook did not fire: fall back to the retained .grad of the activation.
            Ag = getattr(buf["A"], "grad", None)
            if isinstance(Ag, torch.Tensor) and Ag.ndim == 4:
                buf["G"] = Ag
            else:
                raise RuntimeError("Failed to capture gradient G.")

        A = buf["A"].detach().cpu().to(torch.float32)
        G = buf["G"].detach().cpu().to(torch.float32)
        return A, G, float(y.detach().cpu().item())

    finally:
        # Always remove the hook; eval mode is restored only if the model was in eval before.
        h.remove()
        if not was_training:
            model_torch.eval()
        x.requires_grad_(False)

def ssgradcam_raw(A_bchw: torch.Tensor, G_bchw: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """SSGrad-CAM on the raw feature grid.

    Steps:
      w_k = spatial mean of G_k                  (channel weight)
      S_k = |G_k| / max(max_{h,w}|G_k|, eps)     (per-channel spatial gate)
      cam = ReLU( sum_k (w_k * A_k) * S_k )

    Returns:
        The map with one channel (keepdim), normalised to [0, 1] by
        ``_normalize_11hw``.
    """
    channel_weight = G_bchw.mean(dim=(2, 3), keepdim=True)  # (1,C,1,1)
    grad_mag = G_bchw.abs()
    spatial_gate = grad_mag / grad_mag.amax(dim=(2, 3), keepdim=True).clamp(min=eps)  # (1,C,h,w)
    gated = (channel_weight * A_bchw) * spatial_gate
    cam = F.relu(gated.sum(dim=1, keepdim=True))  # (1,1,h,w)
    return _normalize_11hw(cam, eps=eps)

def gradxact_raw(A_bchw: torch.Tensor, G_bchw: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Channel-summed |grad * act| map on the feature grid, normalised to [0, 1].

    Returns a tensor with one channel (keepdim), scaled by ``_normalize_11hw``.
    """
    contribution = G_bchw * A_bchw
    energy = contribution.abs().sum(dim=1, keepdim=True)
    return _normalize_11hw(energy, eps=eps)

def _overlay_pair(pil_img, m1_11hw, m2_11hw, title1, title2, suptitle="", imgsz=640):
    """Show two heat maps side by side, each overlaid on the same image.

    Args:
        pil_img: Background image (converted to RGB).
        m1_11hw: First heat map, shape (1,1,h,w); drawn with colour range [0, 1].
        m2_11hw: Second heat map, same layout.
        title1: Title of the left panel.
        title2: Title of the right panel.
        suptitle: Optional figure-level title; skipped when empty.
        imgsz: Side length the maps are upsampled to (nearest neighbour).
    """
    background = np.asarray(pil_img.convert("RGB"))

    panels = []
    for heat_11hw, caption in ((m1_11hw, title1), (m2_11hw, title2)):
        upsampled = F.interpolate(heat_11hw.detach().cpu(), size=(imgsz, imgsz), mode="nearest")
        panels.append((upsampled[0, 0].numpy(), caption))

    plt.figure(figsize=(14, 5))
    for column, (heat, caption) in enumerate(panels, start=1):
        plt.subplot(1, 2, column)
        plt.imshow(background)
        plt.imshow(heat, alpha=0.45, vmin=0.0, vmax=1.0)
        plt.title(caption)
        plt.axis("off")

    if suptitle:
        plt.suptitle(suptitle)

    plt.tight_layout()
    plt.show()

# -------------------
# Pick balanced examples (equal successes & fails)
# -------------------
# NOTE(review): these lists rebind the names `succ` / `fail`, which the summary
# cell defines as DataFrames; re-running that cell's plots after this one
# requires re-creating them first.
succ = [d for d in run_data if bool(d.get("success", False))]
fail = [d for d in run_data if not bool(d.get("success", False))]

# Take the same number from each group, limited by the smaller group
# (so nothing is shown when either group is empty).
n_each = max(1, NUM_EXAMPLES // 2)
n_each = min(n_each, len(succ), len(fail))
examples = succ[:n_each] + fail[:n_each]

print(f"[info] Using {len(examples)} examples: success={n_each}, fail={n_each}")

# -------------------
# Run + plot
# -------------------
for d in examples:
    gcinfo = d.get("gradcam_info", {})
    bbox = gcinfo.get("picked_bbox", None)
    if bbox is None:
        print("[skip] no picked_bbox for", d.get("path", "<?>"))
        continue

    x_clean = torch_preprocess_letterboxed(d["clean_lb"], device=CFG.device, dtype=_MODEL_DTYPE)

    # Scalar target for this example; it reads `bbox` from the current
    # iteration and is consumed immediately below.
    def _y_fn(model, x):
        return target_topk_roi_logits_scalar(
            model, x,
            target_class_id=int(target_class_id),
            imgsz=int(IMG_SIZE),
            bbox_xyxy=tuple(float(v) for v in bbox),
            topk=int(TOPK),
        )

    A, G, y = layer_act_and_grad_for_y(model_torch, x_clean, TARGET_LAYER, _y_fn)

    # Both maps are built from the same activation/gradient pair.
    ssg = ssgradcam_raw(A, G)
    gxa = gradxact_raw(A, G)

    title = (
        f"{Path(d['path']).name} | success={bool(d['success'])} | drop={float(d['drop']):.3f} | "
        f"clean={float(d['conf_clean']):.3f} patched={float(d['conf_patch']):.3f} | y(top-{TOPK} ROI logits)={y:.3f}"
    )
    print(title)

    _overlay_pair(
        d["clean_lb"],
        ssg, gxa,
        title1=f"SSGrad-CAM (top-{TOPK} ROI logits)",
        title2="|grad*act| (same y)",
        suptitle=title,
        imgsz=int(IMG_SIZE),
    )

## SSGrad-CAM vs Grad-CAM: математика и интуиция (в терминах нашего кода)

### Обозначения
- Выбираем слой (например, `model.22`) и фиксируем вход (clean).
- Активации слоя:  
  $
  A \in \mathbb{R}^{1\times C\times H\times W}
  $
- Градиент целевой скалярной функции по активациям:  
  $
  G = \frac{\partial y}{\partial A} \in \mathbb{R}^{1\times C\times H\times W}
  $
- Целевая функция `y` у нас **не “класс-скор всей картинки”**, а скаляр, построенный из логитов детекции человека **в ROI**:  
  мы берём по всем уровням головы logits карты для класса `person`, ограничиваемся ROI, выбираем top-K ячеек по `sigmoid(prob)`, и агрегируем **их logits** (обычно mean).  
  Формально можно писать:
  $
  y = \frac{1}{K}\sum_{i\in \text{TopK}(\sigma(L_{ROI}))} L_i
  $
  где \(L\) — логит-карта (по всем ROI-ячейкам в head features), \(\sigma\) — sigmoid; TopK выбирается по вероятностям, но агрегируются logits.

---

## Grad-CAM (классика)
1) Считаем **канальные веса** через spatial average pooling градиента:
$
w_k = \mathrm{GAP}(G_k) = \frac{1}{HW}\sum_{h=1}^{H}\sum_{w=1}^{W} G_{k,h,w}
$
где \(k\in\{1,\dots,C\}\).

2) Строим карту:
$
\mathrm{CAM}_{h,w} = \mathrm{ReLU}\Big(\sum_{k=1}^{C} w_k \, A_{k,h,w}\Big)
$

3) Нормируем в \([0,1]\) для визуализации.

### Интуиция Grad-CAM
- Градиент говорит, **какие каналы** важны для увеличения целевой функции \(y\) (через \(w_k\)).
- Активации \(A_{k,h,w}\) говорят, **где** эти каналы “светятся”.
- Итоговая карта: “где в пространстве находятся признаки, важные для \(y\)”.

Ключевое ограничение: вся пространственная структура градиента внутри канала **сжимается в одно число** \(w_k\).  
То есть Grad-CAM очень хорошо отвечает на “какие каналы важны”, но хуже — на “в каких точках карта действительно чувствительна”.

---

## SSGrad-CAM (как в нашем коде)
SSGrad добавляет **пространственный гейт** на основе модуля градиента.

1) Те же веса \(w_k\) (как в Grad-CAM):
$
w_k = \mathrm{GAP}(G_k)
$

2) Строим **spatial scaling** (нормированный модуль градиента по каждому каналу):
$
S_{k,h,w} = \frac{|G_{k,h,w}|}{\max\big(\max_{h',w'}|G_{k,h',w'}|,\ \varepsilon\big)}
$
Здесь \(S_{k,h,w}\in[0,1]\), и он “внутри канала” подсвечивает места, где градиент большой.

3) Карта SSGrad:
$
\mathrm{SSCAM}_{h,w} = \mathrm{ReLU}\Big(\sum_{k=1}^{C} \big(w_k A_{k,h,w}\big)\cdot S_{k,h,w}\Big)
$

4) Нормируем в \([0,1]\).

### Интуиция SSGrad
- \(w_k\) отвечает за “важность канала” (как в Grad-CAM),
- \(S_{k,h,w}\) отвечает за “где внутри канала модель действительно **чувствительна** к \(y\)”,
- поэтому SSGrad — это Grad-CAM, **дополнительно подавленный там, где градиенты малы**.

Практический эффект:
- карты обычно становятся **более локализованными** и менее “размазанными”;
- SSGrad лучше отражает “что надо пошевелить, чтобы поменять \(y\)”, а не просто “что активно”.

---

## Чем это отличается от |grad * act|
В нашем коде есть ещё карта:
$
M_{h,w} = \sum_{k=1}^{C} |G_{k,h,w}\cdot A_{k,h,w}|
$
Это чистая “энергия чувствительности”: одновременно учитывает и силу активаций, и силу градиентов, без канальных весов и без ReLU по сумме.

Интуитивно:
- **|grad*act|** — “где одновременно много активации и большой градиент” (часто очень контрастно).
- **Grad-CAM** — “где активны важные каналы” (более гладко).
- **SSGrad-CAM** — “где активны важные каналы И где градиент внутри канала реально большой” (обычно более точечно).

---

## Важный нюанс: что именно мы “объясняем”
Поскольку \(y\) у нас = top-K ROI logits класса `person`, обе карты (Grad-CAM/SSGrad) объясняют **не весь детектор**, а именно:
- “какие области входа влияют на логиты человека **в ROI** по выбранной схеме агрегации”.

Поэтому эти карты корректно сравнивать с метриками типа “Δ в ROI / вне ROI” и “совпадение Δ с важностью”, но не стоит ожидать, что они всегда совпадут с NMS/финальным bbox после инференса.

In [None]:
from tqdm.auto import tqdm
import numpy as np

def compute_importance_tensors_for_all(
    run_data,
    model_torch,
    target_layer: str,
    target_class_id: int,
    imgsz: int,
    topk_roi: int = 3,
    force_recompute: bool = False,
):
    """Compute importance maps on the clean image of every example and store them in place.

    Keys written to each entry of ``run_data``:
      - imp_ssgrad_hw : (H,W) float32 in [0..1]   (SSGrad-CAM)
      - imp_gxa_hw    : (H,W) float32 in [0..1]   (channel-summed |grad*act|)
      - imp_gxa_chw   : (C,H,W) float32 >= 0      (per-channel |grad*act|)
      - imp_y_clean   : float, value of the ROI target function

    Entries without a picked bbox get None for the three maps and NaN for
    ``imp_y_clean``. Entries that already hold ``imp_gxa_chw`` and
    ``imp_ssgrad_hw`` are skipped unless ``force_recompute`` is set.
    """
    cache_keys = ("imp_gxa_chw", "imp_ssgrad_hw")

    for item in tqdm(run_data, desc="importance", total=len(run_data)):
        if not force_recompute and all(key in item for key in cache_keys):
            continue

        info = item.get("gradcam_info", {})
        bbox = info.get("picked_bbox", None) if isinstance(info, dict) else None
        if bbox is None:
            # No ROI -> no target function; record explicit placeholders.
            item["imp_ssgrad_hw"] = None
            item["imp_gxa_hw"] = None
            item["imp_gxa_chw"] = None
            item["imp_y_clean"] = float("nan")
            continue

        x_clean = torch_preprocess_letterboxed(item["clean_lb"], device=CFG.device, dtype=_MODEL_DTYPE)

        def _roi_objective(model, x):
            # Scalar target built from the top-k ROI logits of the target class.
            return target_topk_roi_logits_scalar(
                model, x,
                target_class_id=int(target_class_id),
                imgsz=int(imgsz),
                bbox_xyxy=tuple(float(v) for v in bbox),
                topk=int(topk_roi),
            )

        acts, grads, y_value = layer_act_and_grad_for_y(
            model_torch, x_clean, target_layer, _roi_objective
        )  # acts/grads: (1,C,H,W)

        ssgrad_map = ssgradcam_raw(acts, grads)   # (1,1,H,W) in [0,1]
        gxa_map = gradxact_raw(acts, grads)       # (1,1,H,W) in [0,1]
        gxa_per_channel = (grads[0] * acts[0]).abs().to(torch.float32)  # (C,H,W) >= 0

        item["imp_ssgrad_hw"] = ssgrad_map[0, 0].numpy().astype(np.float32)
        item["imp_gxa_hw"] = gxa_map[0, 0].numpy().astype(np.float32)
        item["imp_gxa_chw"] = gxa_per_channel.numpy().astype(np.float32)
        item["imp_y_clean"] = float(y_value)

# Run: fill the importance maps for every run_data entry (already-computed entries are skipped).
compute_importance_tensors_for_all(
    run_data=run_data,
    model_torch=model_torch,
    target_layer="model.22",
    target_class_id=int(target_class_id),
    imgsz=int(CFG.imgsz),
    topk_roi=3,
    force_recompute=False,
)

In [None]:
import numpy as np
import pandas as pd
from scipy.stats import spearmanr

def _safe_l2(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """Euclidean (L2) distance between flattened arrays, ignoring non-finite entries."""
    a = np.asarray(a, dtype=np.float64).reshape(-1)
    b = np.asarray(b, dtype=np.float64).reshape(-1)
    m = np.isfinite(a) & np.isfinite(b)
    a = a[m]; b = b[m]
    if a.size == 0:
        return float("nan")
    return float(np.linalg.norm(a - b))


def _safe_l2_rel(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """Globally-normalized L2: ||a-b|| / (||a|| + ||b||)."""
    a = np.asarray(a, dtype=np.float64).reshape(-1)
    b = np.asarray(b, dtype=np.float64).reshape(-1)
    m = np.isfinite(a) & np.isfinite(b)
    a = a[m]; b = b[m]
    if a.size == 0:
        return float("nan")
    na = float(np.linalg.norm(a))
    nb = float(np.linalg.norm(b))
    den = max(na + nb, eps)
    return float(np.linalg.norm(a - b) / den)

def _safe_cosine(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> float:
    """Return the cosine similarity of two flattened arrays.

    Positions where either value is non-finite are ignored. Returns NaN when
    nothing finite remains or when either vector's norm is at most ``eps``.
    """
    u = np.asarray(a, dtype=np.float64).ravel()
    v = np.asarray(b, dtype=np.float64).ravel()
    keep = np.isfinite(u) & np.isfinite(v)
    u, v = u[keep], v[keep]
    if u.size == 0:
        return float("nan")
    norm_u = float(np.linalg.norm(u))
    norm_v = float(np.linalg.norm(v))
    if min(norm_u, norm_v) <= eps:
        return float("nan")
    return float(np.dot(u, v) / (norm_u * norm_v))

def _topq_energy_frac(delta_hw: np.ndarray, imp_hw: np.ndarray, q: float = 10.0, eps: float = 1e-12) -> float:
    """Fraction of |delta| energy that falls in the top-q% most important cells.

    Args:
        delta_hw: delta map; its absolute value is used as "energy".
        imp_hw: importance map of the same shape; negative values are clipped to 0.
        q: size of the "most important" set, in percent of cells.
        eps: total energy at or below this value is treated as zero.

    Returns:
        sum(|delta| over cells with importance >= the (100-q)th percentile)
        divided by sum(|delta|); 0.0 when the total energy is <= eps; NaN when
        there are no finite cells.

    Raises:
        ValueError: if the two maps have different shapes.
    """
    d = np.abs(np.asarray(delta_hw, dtype=np.float64))
    w = np.clip(np.asarray(imp_hw, dtype=np.float64), 0.0, None)
    if d.shape != w.shape:
        raise ValueError(f"shape mismatch: delta {d.shape} vs imp {w.shape}")

    d = d.reshape(-1)
    w = w.reshape(-1)
    # Drop non-finite cells, as the sibling _safe_* helpers do. Without this a
    # single NaN in the importance map turns the percentile threshold into NaN,
    # the selection mask becomes empty and the function silently reports 0.0.
    finite = np.isfinite(d) & np.isfinite(w)
    d = d[finite]
    w = w[finite]
    if w.size == 0:
        return float("nan")

    thr = np.percentile(w, 100.0 - float(q))
    num = float(d[w >= thr].sum())
    den = float(d.sum())
    return float(num / den) if den > eps else 0.0

def _jaccard_topk(a: np.ndarray, b: np.ndarray, k: int = 32) -> float:
    """Jaccard index between the index sets of the k largest entries of a and b.

    Both inputs are flattened. ``k`` is clamped to ``[1, min(len(a), len(b))]``.
    Returns NaN if either input is empty.
    """
    u = np.asarray(a, dtype=np.float64).ravel()
    v = np.asarray(b, dtype=np.float64).ravel()
    if min(u.size, v.size) == 0:
        return float("nan")
    k = int(min(max(k, 1), u.size, v.size))
    top_u = set(np.argpartition(u, -k)[-k:].tolist())
    top_v = set(np.argpartition(v, -k)[-k:].tolist())
    n_union = len(top_u | top_v)
    if not n_union:
        return float("nan")
    return float(len(top_u & top_v) / n_union)

def compute_delta_importance_metrics(
    delta_bchw: torch.Tensor,
    imp_ssgrad_hw,
    imp_gxa_hw,
    imp_gxa_chw,
    reduce_mode_hw: str = "l2",
    topk: int = 32,
    topq: float = 10.0,
) -> dict:
    """Compare a feature delta against importance maps and return scalar metrics.

    Only the first batch element of ``delta_bchw`` is used. Its absolute value
    is compared to the full (C,H,W) importance tensor, and a channel-reduced
    map (via ``_reduce_channels``, defined elsewhere in the file) is compared
    to the two HxW importance maps.

    Args:
        delta_bchw: delta tensor indexed as [batch, C, H, W], or None.
        imp_ssgrad_hw: HxW importance map, or None.
        imp_gxa_hw: HxW importance map, or None.
        imp_gxa_chw: (C,H,W) importance tensor, or None.
        reduce_mode_hw: mode forwarded to ``_reduce_channels``.
        topk: forwarded to ``_reduce_channels`` and used as k for the
            channel-level Jaccard.
        topq: percent of cells used by ``_topq_energy_frac``.

    Returns:
        ``{}`` if ``delta_bchw`` is None; otherwise a dict of floats in which
        metrics for a missing importance input are NaN.

    Raises:
        ValueError: if ``imp_gxa_chw`` does not match the delta's (C,H,W) shape.
    """
    if delta_bchw is None:
        return {}

    nan = float("nan")
    feat_chw = delta_bchw[0].float()
    n_chan, _h, _w = feat_chw.shape  # unpacking also enforces a rank-3 tensor

    # Element-wise |delta| over the full tensor.
    abs_chw = feat_chw.abs().detach().cpu().numpy().astype(np.float32)

    # Channel-reduced magnitude map used against the HxW importance maps.
    mag_hw = _reduce_channels(feat_chw, mode=reduce_mode_hw, topk=topk)
    mag_hw = mag_hw.detach().cpu().numpy().astype(np.float32)
    if reduce_mode_hw not in {"l2", "abs_mean"}:
        # Other reduce modes are not treated as magnitudes; take |.| explicitly.
        mag_hw = np.abs(mag_hw)

    metrics = {}

    # --- HxW comparisons: same four metrics for each importance source ---
    for tag, imp_hw in (("ssgrad", imp_ssgrad_hw), ("gxa", imp_gxa_hw)):
        if imp_hw is None:
            metrics[f"cos_hw_{tag}"] = nan
            metrics[f"topq_energy_frac_{tag}"] = nan
            metrics[f"l2_hw_{tag}"] = nan
            metrics[f"l2rel_hw_{tag}"] = nan
            continue
        metrics[f"cos_hw_{tag}"] = _safe_cosine(mag_hw, imp_hw)
        metrics[f"topq_energy_frac_{tag}"] = _topq_energy_frac(mag_hw, imp_hw, q=topq)
        metrics[f"l2_hw_{tag}"] = _safe_l2(mag_hw, imp_hw)
        metrics[f"l2rel_hw_{tag}"] = _safe_l2_rel(mag_hw, imp_hw)

    # --- CxHxW comparison against the full grad*act tensor ---
    if imp_gxa_chw is None:
        metrics.update(
            cos_chw_gxa=nan,
            spearman_chan_energy_gxa=nan,
            jaccard_topk_chan_gxa=nan,
            l2_chw_gxa=nan,
            l2rel_chw_gxa=nan,
        )
        return metrics

    gxa = np.asarray(imp_gxa_chw, dtype=np.float32)
    if gxa.shape != abs_chw.shape:
        raise ValueError(f"gxa_chw shape {gxa.shape} != delta_chw shape {abs_chw.shape}")

    metrics["cos_chw_gxa"] = _safe_cosine(abs_chw, gxa)
    metrics["l2_chw_gxa"] = _safe_l2(abs_chw, gxa)
    metrics["l2rel_chw_gxa"] = _safe_l2_rel(abs_chw, gxa)

    # Per-channel mean energy, then rank agreement and top-k channel overlap.
    chan_delta = abs_chw.reshape(n_chan, -1).mean(axis=1)
    chan_imp = gxa.reshape(n_chan, -1).mean(axis=1)

    rho = spearmanr(chan_delta, chan_imp).correlation
    metrics["spearman_chan_energy_gxa"] = nan if rho is None else float(rho)
    metrics["jaccard_topk_chan_gxa"] = _jaccard_topk(chan_delta, chan_imp, k=topk)

    return metrics


# Add the new delta-vs-importance columns to df.
LAYER = "model.22"
new_rows = []
for d in run_data:
    # One metrics dict per run_data entry; an entry without a delta for LAYER
    # yields {} and therefore NaN in every new column.
    nm = compute_delta_importance_metrics(
        delta_bchw=d.get("deltas", {}).get(LAYER, None),
        imp_ssgrad_hw=d.get("imp_ssgrad_hw", None),
        imp_gxa_hw=d.get("imp_gxa_hw", None),
        imp_gxa_chw=d.get("imp_gxa_chw", None),
        reduce_mode_hw="l2",
        topk=int(CFG.topk_channels),
        topq=10.0,
    )
    new_rows.append(nm)

new_df = pd.DataFrame(new_rows)
# NOTE(review): Series assignment aligns on the index. new_df has a default
# 0..N-1 index, so this presumably assumes df is indexed 0..N-1 in run_data
# order — confirm, otherwise values are silently misaligned or become NaN.
for c in new_df.columns:
    df[c] = new_df[c]

print("Added:", list(new_df.columns))
df[["success", "drop"] + list(new_df.columns)].head(10)

In [None]:
from sklearn.metrics import roc_auc_score, balanced_accuracy_score
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

def roc_auc_for_metric(df: pd.DataFrame, metric: str, label_col: str = "success") -> tuple[float, int]:
    """Direction-agnostic ROC-AUC of one scalar column against a binary label.

    Rows with a non-finite metric value are ignored.

    Returns:
        ``(auc, direction)`` with ``auc >= 0.5``:
          direction = +1: larger metric -> more likely SUCCESS
          direction = -1: smaller metric -> more likely SUCCESS
        ``(nan, 0)`` when fewer than two classes remain after filtering.
    """
    labels = df[label_col].astype(int).to_numpy()
    scores = df[metric].to_numpy().astype(float)

    finite = np.isfinite(scores) & np.isfinite(labels)
    scores, labels = scores[finite], labels[finite]
    if np.unique(labels).size < 2:
        return float("nan"), 0

    raw = float(roc_auc_score(labels, scores))
    if raw < 0.5:
        # Flip so the reported AUC is always >= 0.5; the sign records the flip.
        return float(1.0 - raw), -1
    return float(raw), 1

def best_balanced_accuracy_for_metric(
    df: pd.DataFrame,
    metric: str,
    label_col: str = "success",
    *,
    direction: int | None = None,
) -> tuple[float, float, int]:
    """Best balanced accuracy achievable by thresholding a single scalar metric.

    Returns:
      (best_bal_acc, best_threshold, used_direction)

    used_direction:
      +1 => predict SUCCESS when x >= thr
      -1 => predict SUCCESS when x <= thr
    """
    y = df[label_col].astype(int).to_numpy()
    x = df[metric].to_numpy().astype(float)

    m = np.isfinite(x) & np.isfinite(y)
    x = x[m]
    y = y[m]

    if x.size == 0 or np.unique(y).size < 2:
        return float("nan"), float("nan"), 0

    if direction is None:
        _auc, direction = roc_auc_for_metric(df.loc[m], metric, label_col=label_col)

    direction = +1 if int(direction) >= 0 else -1

    xs = np.unique(x)
    if xs.size == 1:
        thr = float(xs[0])
        yhat = (x >= thr) if direction == +1 else (x <= thr)
        ba = float(balanced_accuracy_score(y, yhat.astype(int)))
        return ba, thr, int(direction)

    mids = (xs[:-1] + xs[1:]) * 0.5
    thr_candidates = np.concatenate(([xs[0] - 1e-12], mids, [xs[-1] + 1e-12]))

    best_ba = -1.0
    best_thr = float("nan")

    for thr in thr_candidates:
        yhat = (x >= thr) if direction == +1 else (x <= thr)
        ba = float(balanced_accuracy_score(y, yhat.astype(int)))
        if ba > best_ba:
            best_ba = ba
            best_thr = float(thr)

    return float(best_ba), float(best_thr), int(direction)

# --- Only the new metrics: delta vs importance (SSGrad / grad*act) ---
metric_cols = [
    # SSGrad (HxW)
    "cos_hw_ssgrad",
    "topq_energy_frac_ssgrad",
    "l2_hw_ssgrad",
    "l2rel_hw_ssgrad",

    # grad*act (HxW)
    "cos_hw_gxa",
    "topq_energy_frac_gxa",
    "l2_hw_gxa",
    "l2rel_hw_gxa",

    # grad*act (CxHxW)
    "cos_chw_gxa",
    "l2_chw_gxa",
    "l2rel_chw_gxa",
    "spearman_chan_energy_gxa",
    "jaccard_topk_chan_gxa",
]

# Split rows by attack outcome for the per-class statistics below.
succ = df[df.success == True]
fail = df[df.success == False]

# One summary row per metric: class-wise mean/std (NaN-aware), ROC-AUC with
# its direction, and the best balanced accuracy reachable with one threshold.
summary = []
for c in metric_cols:
    if c not in df.columns:
        continue
    xs = succ[c].to_numpy()
    xf = fail[c].to_numpy()
    auc, direction = roc_auc_for_metric(df, c, label_col="success")
    # Reuse the AUC direction so the threshold rule and the AUC agree.
    bacc, bthr, _ = best_balanced_accuracy_for_metric(df, c, label_col="success", direction=direction)

    summary.append({
        "metric": c,
        "mean_s": float(np.nanmean(xs)),
        "mean_f": float(np.nanmean(xf)),
        "std_s": float(np.nanstd(xs)),
        "std_f": float(np.nanstd(xf)),
        "best_bal_acc": bacc,
        "best_bal_thr": bthr,
        "roc_auc": auc,
        "auc_direction": direction,
    })

# Most discriminative metrics (by ROC-AUC) first.
summary_df = pd.DataFrame(summary).sort_values("roc_auc", ascending=False)
summary_df

In [None]:
import matplotlib.pyplot as plt

def _metric_family(m: str) -> str:
    """Map a metric column name to the importance family it belongs to.

    Returns "SSGrad (HxW)", "grad*act (HxW)", "grad*act (CxHxW)", or the
    generic "Δ vs importance" for names that match none of them.
    """
    if m.endswith("_ssgrad"):
        return "SSGrad (HxW)"
    if m.endswith("_gxa"):
        # Classify every *_gxa metric, not only the cosine ones: previously
        # l2_hw_gxa, l2rel_hw_gxa, topq_energy_frac_gxa, l2_chw_gxa and
        # l2rel_chw_gxa fell through to the generic label.
        if "_chw_" in m or m.startswith(("spearman", "jaccard")):
            return "grad*act (CxHxW)"
        if "_hw_" in m or m.startswith("topq_energy_frac"):
            return "grad*act (HxW)"
    return "Δ vs importance"

import numpy as np

def _violin_metric(ax, data_s, data_f, title: str):
    """Draw FAIL vs SUCCESS distributions of one metric on ``ax``.

    Each non-empty group gets a violin, a jittered scatter of the raw points
    and a thick line at its median. Non-finite values are dropped first. FAIL
    is drawn at x=1; SUCCESS at x=2, or at x=1 when there is no FAIL data.
    If neither group has finite data the axes are replaced by a text note.
    """
    ok_vals = np.asarray(data_s, dtype=float)
    ko_vals = np.asarray(data_f, dtype=float)
    ok_vals = ok_vals[np.isfinite(ok_vals)]
    ko_vals = ko_vals[np.isfinite(ko_vals)]

    if not (ok_vals.size or ko_vals.size):
        ax.text(0.5, 0.5, "no finite data", ha="center", va="center")
        ax.set_axis_off()
        return

    # (values, tick label, x position) for every group that has data.
    groups = []
    if ko_vals.size:
        groups.append((ko_vals, f"FAIL (n={ko_vals.size})", 1))
    if ok_vals.size:
        groups.append((ok_vals, f"SUCCESS (n={ok_vals.size})", 2 if ko_vals.size else 1))

    series = [g[0] for g in groups]
    labels = [g[1] for g in groups]
    positions = [g[2] for g in groups]

    ax.violinplot(
        series,
        positions=positions,
        widths=0.8,
        showmeans=False,
        showmedians=True,
        showextrema=False,
    )

    # Fixed seed keeps the jitter reproducible; FAIL draws from the RNG first.
    rng = np.random.default_rng(0)
    if ko_vals.size:
        jitter = (rng.random(ko_vals.size) - 0.5) * 0.16
        ax.scatter(np.full(ko_vals.size, 1.0) + jitter, ko_vals, s=10, alpha=0.35)
        ax.hlines(np.median(ko_vals), 0.78, 1.22, linewidth=3)

    if ok_vals.size:
        center = 2.0 if ko_vals.size else 1.0
        jitter = (rng.random(ok_vals.size) - 0.5) * 0.16
        ax.scatter(np.full(ok_vals.size, center) + jitter, ok_vals, s=10, alpha=0.35)
        ax.hlines(np.median(ok_vals), center - 0.22, center + 0.22, linewidth=3)

    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.set_title(title)
    ax.grid(True, alpha=0.25)

topM = 12  # maximum number of metrics to draw

# Pick up to topM metrics ranked by best_bal_acc (descending). When
# metric_cols exists only the metrics listed there are eligible; otherwise
# every row of summary_df is.
if "metric_cols" in globals():
    allowed = set(metric_cols)
    order_df = summary_df[summary_df["metric"].isin(allowed)].sort_values("best_bal_acc", ascending=False)
    best = order_df.head(topM)["metric"].tolist()
else:
    best = summary_df.sort_values("best_bal_acc", ascending=False).head(topM)["metric"].tolist()

succ = df[df.success == True]
fail = df[df.success == False]

# Two-column grid of violin plots, one subplot per selected metric.
fig = plt.figure(figsize=(14, 3.2 * ((len(best) + 1) // 2)))
for i, m in enumerate(best, 1):
    ax = plt.subplot((len(best) + 1) // 2, 2, i)
    # NOTE(review): fam is computed but not used in the title below.
    fam = _metric_family(m)
    _violin_metric(
        ax,
        succ[m].to_numpy(),
        fail[m].to_numpy(),
        f"{m} | AUC={summary_df.set_index('metric').loc[m, 'roc_auc']:.3g} | bAcc={summary_df.set_index('metric').loc[m, 'best_bal_acc']:.3g}",
    )
plt.tight_layout()
plt.show()

# Horizontal bar chart of best_bal_acc for every metric in summary_df
# (sorted ascending so the best metric ends up at the top).
fig = plt.figure(figsize=(9.5, 0.42 * len(summary_df) + 1.5))
order = summary_df.sort_values("best_bal_acc", ascending=True)
plt.barh(order["metric"], order["best_bal_acc"])
plt.xlabel("best_bal_acc (>=0.5; higher is better)")
plt.title("best_bal_acc for Δ vs importance metrics")
plt.grid(True, axis="x", alpha=0.25)
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# -------------------------
# Helpers: top-K energy stats
# -------------------------

def _topk_indices(x_flat: np.ndarray, k: int):
    """Return the indices of the k largest entries of the flattened input.

    ``k`` is clamped to ``[1, n]``. The indices come from ``np.argpartition``
    and are not sorted by value. An empty input gives an empty int64 array.
    """
    flat = np.asarray(x_flat, dtype=np.float64).ravel()
    if flat.size == 0:
        return np.array([], dtype=np.int64)
    take = int(min(max(int(k), 1), flat.size))
    return np.argpartition(flat, -take)[-take:]

def _energy_frac_in_topk(delta_abs_flat: np.ndarray, imp_flat: np.ndarray, k: int, eps: float = 1e-12):
    """Share of delta energy that lands on the k most important elements.

    frac   = sum(|Δ| over topK(imp)) / sum(|Δ|)
    enrich = frac / (K/N)

    Positions where either input is non-finite are removed first; N and the
    clamped K refer to what remains.

    Returns:
        ``(frac, enrich)``; ``(nan, nan)`` if nothing finite remains and
        ``(0.0, 0.0)`` if the total energy is at most ``eps``.
    """
    energy = np.asarray(delta_abs_flat, dtype=np.float64).ravel()
    weight = np.asarray(imp_flat, dtype=np.float64).ravel()
    finite = np.isfinite(energy) & np.isfinite(weight)
    energy, weight = energy[finite], weight[finite]

    total_cells = energy.size
    if total_cells == 0:
        return float("nan"), float("nan")

    total_energy = float(energy.sum())
    if total_energy <= eps:
        return 0.0, 0.0

    chosen = _topk_indices(weight, k)
    frac = float(float(energy[chosen].sum()) / total_energy)

    # Expected share under a uniform spread is K/N (same clamping as the helper).
    k_eff = int(min(max(int(k), 1), total_cells))
    uniform_share = float(k_eff / total_cells)
    return frac, float(frac / max(uniform_share, eps))

def _jaccard_topk_sets(a_flat: np.ndarray, b_flat: np.ndarray, k: int):
    """Jaccard index between the top-k index sets of two flattened arrays.

    Positions where either value is non-finite are dropped from both arrays
    before ranking, so indices refer to the filtered arrays. ``k`` is clamped
    to ``[1, n]``. Returns NaN if nothing finite remains.
    """
    u = np.asarray(a_flat, dtype=np.float64).ravel()
    v = np.asarray(b_flat, dtype=np.float64).ravel()
    finite = np.isfinite(u) & np.isfinite(v)
    u, v = u[finite], v[finite]
    if u.size == 0:
        return float("nan")
    k = int(min(max(int(k), 1), u.size))
    top_u = set(_topk_indices(u, k).tolist())
    top_v = set(_topk_indices(v, k).tolist())
    n_union = len(top_u | top_v)
    if not n_union:
        return float("nan")
    return float(len(top_u & top_v) / n_union)

# -------------------------
# Compute per-image metrics for multiple K
# -------------------------
assert "run_data" in globals() and "df" in globals(), "Need run_data and df."

LAYER = "model.22"

# For CHW: 512*20*20 ~ 204_800 elements in total -> K can be larger
K_LIST_CHW = [200, 500, 1000, 2000, 5000]

# For HW: only 20*20 = 400 elements in total -> smaller K
K_LIST_HW  = [10, 25, 50, 100, 200]

# One dict of metrics per run_data entry; entries without a delta for LAYER
# contribute an empty dict (NaN in every column of the resulting frame).
rows = []
for d in run_data:
    delta = d.get("deltas", {}).get(LAYER, None)
    if delta is None:
        rows.append({})
        continue

    d_chw = delta[0].float()  # (C,H,W)
    C, H, W = d_chw.shape
    delta_abs_chw = d_chw.abs().detach().cpu().numpy().astype(np.float32)
    delta_abs_flat = delta_abs_chw.reshape(-1)

    # Importance tensors (already computed earlier)
    imp_chw     = d.get("imp_gxa_chw", None)     # (C,H,W) >=0
    imp_gxa_hw  = d.get("imp_gxa_hw", None)      # (H,W)   in [0,1]
    imp_ss_hw   = d.get("imp_ssgrad_hw", None)   # (H,W)   in [0,1]

    out = {}

    # ---- grad*act full (CxHxW): energy in top-K important features ----
    if imp_chw is not None:
        imp_flat = np.asarray(imp_chw, dtype=np.float32).reshape(-1)
        for k in K_LIST_CHW:
            frac, enrich = _energy_frac_in_topk(delta_abs_flat, imp_flat, k=k)
            out[f"focus_frac__gxa_chw_k{k}"] = frac
            out[f"enrichE_impTopK_gxa_chw_k{k}"] = enrich
            # Overlap between the K largest |delta| elements and the K most important ones.
            out[f"overlap_jacc__gxa_chw_k{k}"] = _jaccard_topk_sets(delta_abs_flat, imp_flat, k=k)
    else:
        for k in K_LIST_CHW:
            out[f"focus_frac__gxa_chw_k{k}"] = float("nan")
            out[f"enrichE_impTopK_gxa_chw_k{k}"] = float("nan")
            out[f"overlap_jacc__gxa_chw_k{k}"] = float("nan")

    # ---- collapsed HxW: compare |Δ|_hw to HxW importance ----
    # magnitude map for Δ on HxW grid (same reducer as elsewhere)
    delta_hw = _reduce_channels(d_chw, mode="l2", topk=int(CFG.topk_channels)).detach().cpu().numpy().astype(np.float32)
    delta_hw_flat = delta_hw.reshape(-1)  # non-negative

    if imp_gxa_hw is not None:
        imp_flat = np.asarray(imp_gxa_hw, dtype=np.float32).reshape(-1)
        for k in K_LIST_HW:
            frac, enrich = _energy_frac_in_topk(delta_hw_flat, imp_flat, k=k)
            out[f"focus_frac__gxa_hw_k{k}"] = frac
            out[f"enrichE_impTopK_gxa_hw_k{k}"] = enrich
    else:
        for k in K_LIST_HW:
            out[f"focus_frac__gxa_hw_k{k}"] = float("nan")
            out[f"enrichE_impTopK_gxa_hw_k{k}"] = float("nan")

    if imp_ss_hw is not None:
        imp_flat = np.asarray(imp_ss_hw, dtype=np.float32).reshape(-1)
        for k in K_LIST_HW:
            frac, enrich = _energy_frac_in_topk(delta_hw_flat, imp_flat, k=k)
            out[f"focus_frac__ssgrad_hw_k{k}"] = frac
            out[f"enrichE_impTopK_ssgrad_hw_k{k}"] = enrich
    else:
        for k in K_LIST_HW:
            out[f"focus_frac__ssgrad_hw_k{k}"] = float("nan")
            out[f"enrichE_impTopK_ssgrad_hw_k{k}"] = float("nan")

    rows.append(out)

new_topk_df = pd.DataFrame(rows)
# NOTE(review): assignment aligns on the index; presumably df has the default
# 0..N-1 index in run_data order — confirm.
for c in new_topk_df.columns:
    df[c] = new_topk_df[c]

print("Added TOP-K energy columns:")
print(list(new_topk_df.columns))

# -------------------------
# Quick visualization: frac/enrich vs success for a few K
# -------------------------
def _plot_scatter(df, xcol, ycol, title):
    """Scatter ``ycol`` against ``xcol`` in a new figure, split by ``df.success``.

    FAIL rows are drawn first and SUCCESS rows second, each with its own
    legend entry. The figure is shown immediately.
    """
    plt.figure(figsize=(6.2, 4.6))
    for flag, legend_name in ((False, "FAIL"), (True, "SUCCESS")):
        subset = df[df.success == flag]
        plt.scatter(subset[xcol], subset[ycol], s=18, alpha=0.55, label=legend_name)
    plt.xlabel(xcol)
    plt.ylabel(ycol)
    plt.title(title)
    plt.grid(True, alpha=0.25)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Total delta energy over the full CHW tensor (L1 norm = sum of |delta|).
# Computed only once; a later re-run of the cell reuses the existing column.
if "delta_energy_l1" not in df.columns:
    energies = []
    for d in run_data:
        delta = d.get("deltas", {}).get(LAYER, None)
        if delta is None:
            energies.append(float("nan"))
        else:
            dd = delta[0].float().abs().detach().cpu().numpy()
            energies.append(float(dd.sum()))
    df["delta_energy_l1"] = energies

# Scatter total energy against the focus fraction / enrichment for a few K.
# Columns that were not produced are skipped.
for k in [500, 2000, 5000]:
    colf = f"focus_frac__gxa_chw_k{k}"
    cole = f"enrichE_impTopK_gxa_chw_k{k}"
    if colf in df.columns:
        _plot_scatter(df, "delta_energy_l1", colf, f"grad*act CHW: frac energy in top-{k} important features")
    if cole in df.columns:
        _plot_scatter(df, "delta_energy_l1", cole, f"grad*act CHW: enrichment (top-{k})")

In [None]:
# --- RAW L2 energy inside top-K important features (NO normalization) + ROC-AUC for all K=2000 metrics ---

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, balanced_accuracy_score

K = 2000
LAYER = "model.22"

# ---------------------------
# Compute raw L2 inside/outside top-K important (CHW)
# ---------------------------
# One value per run_data entry; NaN when the delta or the importance tensor
# is missing, or when no element is finite in both.
l2_in_list = []
l2_out_list = []

for d in run_data:
    delta = d.get("deltas", {}).get(LAYER, None)
    imp_chw = d.get("imp_gxa_chw", None)

    if (delta is None) or (imp_chw is None):
        l2_in_list.append(float("nan"))
        l2_out_list.append(float("nan"))
        continue

    d_chw = delta[0].float()
    delta_abs_flat = d_chw.abs().detach().cpu().numpy().astype(np.float32).reshape(-1)
    imp_flat = np.asarray(imp_chw, dtype=np.float32).reshape(-1)

    # Keep only positions that are finite in both arrays.
    m = np.isfinite(delta_abs_flat) & np.isfinite(imp_flat)
    dflat = delta_abs_flat[m]
    wflat = imp_flat[m]
    n = dflat.size
    if n == 0:
        l2_in_list.append(float("nan"))
        l2_out_list.append(float("nan"))
        continue

    k_eff = int(min(max(int(K), 1), n))
    # NOTE(review): _topk_idx is not defined in the cells shown here; the top-K
    # helper defined above is _topk_indices. Confirm that _topk_idx exists
    # elsewhere in the notebook, otherwise this line raises NameError.
    idx = _topk_idx(wflat, k_eff)  # top-K by importance

    # Boolean mask: True for the top-K important elements, False for the rest.
    mask = np.zeros(n, dtype=bool)
    mask[idx] = True

    din = dflat[mask]
    dout = dflat[~mask]

    l2_in_list.append(float(np.linalg.norm(din)) if din.size else float("nan"))
    l2_out_list.append(float(np.linalg.norm(dout)) if dout.size else float("nan"))

col_l2_in  = f"l2_in_topk__gxa_chw_k{K}"
col_l2_out = f"l2_out_topk__gxa_chw_k{K}"
df[col_l2_in] = l2_in_list
df[col_l2_out] = l2_out_list

# Create the ratio column before df is sliced into succ/fail below, so that
# both slices contain it. Division warnings are silenced; a zero outside-norm
# gives inf/NaN.
col_l2_ratio = f"l2_ratio_in_out__gxa_chw_k{K}"
with np.errstate(divide="ignore", invalid="ignore"):
    df[col_l2_ratio] = df[col_l2_in] / df[col_l2_out]

print("Added:", col_l2_in, col_l2_out, col_l2_ratio)
print(df[["success", "drop", col_l2_in, col_l2_out, col_l2_ratio]].head(12))


# ---------------------------
# Violin plots (3 side-by-side)
# ---------------------------
succ = df[df.success == True]
fail = df[df.success == False]

# Binary labels (1 = success), read as a global by _bacc_str below.
y_bin = df["success"].astype(int).to_numpy()

# --- best balanced accuracy for a single metric (try both directions) ---
def _best_bacc_any_direction(y, x):
    """Best single-threshold balanced accuracy, trying both rule directions.

    Samples with a non-finite value are ignored. Thresholds are the midpoints
    between consecutive distinct values plus one just outside each end (or the
    single value itself when all values are equal).

    Returns:
        ``(best_bal_acc, direction)`` where +1 means "predict 1 when x >= thr"
        and -1 means "predict 1 when x <= thr"; +1 wins ties.
        ``(nan, 0)`` with fewer than 5 usable samples or a single class.
    """
    labels = np.asarray(y, dtype=int)
    values = np.asarray(x, dtype=float)
    usable = np.isfinite(values) & np.isfinite(labels)
    if usable.sum() < 5 or np.unique(labels[usable]).size < 2:
        return float("nan"), 0

    labels = labels[usable]
    values = values[usable]
    levels = np.unique(values)

    if levels.size > 1:
        midpoints = (levels[:-1] + levels[1:]) * 0.5
        cuts = np.concatenate(([levels[0] - 1e-12], midpoints, [levels[-1] + 1e-12]))
    else:
        cuts = np.array([levels[0]], dtype=float)

    def _score(cut, sign):
        predicted = (values >= cut) if sign >= 0 else (values <= cut)
        return float(balanced_accuracy_score(labels, predicted.astype(int)))

    # -1.0 is the same "nothing found yet" floor the explicit loop would use.
    upward = max([-1.0] + [_score(cut, +1) for cut in cuts])
    downward = max([-1.0] + [_score(cut, -1) for cut in cuts])

    if upward >= downward:
        return float(upward), +1
    return float(downward), -1


def _bacc_str(colname: str) -> str:
    """Return a ``" | bAcc=..."`` title suffix for a df column.

    Uses the globals ``df`` and ``y_bin``. Returns an empty string when the
    column is absent or no finite balanced accuracy can be computed.
    """
    if colname in df.columns:
        score, _direction = _best_bacc_any_direction(y_bin, df[colname].to_numpy(dtype=float))
        if np.isfinite(score):
            return f" | bAcc={score:.3g}"
    return ""


def _violin_fail_succ(ax, values_s, values_f, title, ylabel=""):
    """Draw FAIL vs SUCCESS violins with jittered points and median bars.

    Non-finite values are dropped. FAIL sits at x=1; SUCCESS at x=2, or at
    x=1 when there is no FAIL data. With no finite data at all the axes are
    replaced by a text note. ``ylabel`` is applied only when non-empty.
    """
    ok_vals = np.asarray(values_s, float)
    ko_vals = np.asarray(values_f, float)
    ok_vals = ok_vals[np.isfinite(ok_vals)]
    ko_vals = ko_vals[np.isfinite(ko_vals)]

    # (values, x position, tick label) for every non-empty group.
    groups = []
    if ko_vals.size:
        groups.append((ko_vals, 1, f"FAIL (n={ko_vals.size})"))
    if ok_vals.size:
        groups.append((ok_vals, 2 if ko_vals.size else 1, f"SUCCESS (n={ok_vals.size})"))

    series = [g[0] for g in groups]
    pos = [g[1] for g in groups]
    labels = [g[2] for g in groups]

    if not series:
        ax.text(0.5, 0.5, "no finite data", ha="center", va="center")
        ax.set_axis_off()
        return

    ax.violinplot(series, positions=pos, widths=0.8, showmeans=False, showmedians=True, showextrema=False)

    # Fixed seed keeps the jitter reproducible; FAIL draws from the RNG first.
    rng = np.random.default_rng(0)
    if ko_vals.size:
        jitter = (rng.random(ko_vals.size) - 0.5) * 0.16
        ax.scatter(np.full(ko_vals.size, 1.0) + jitter, ko_vals, s=10, alpha=0.35)
        ax.hlines(np.median(ko_vals), 0.78, 1.22, linewidth=3)
    if ok_vals.size:
        center = 2.0 if ko_vals.size else 1.0
        jitter = (rng.random(ok_vals.size) - 0.5) * 0.16
        ax.scatter(np.full(ok_vals.size, center) + jitter, ok_vals, s=10, alpha=0.35)
        ax.hlines(np.median(ok_vals), center - 0.22, center + 0.22, linewidth=3)

    ax.set_xticks(pos)
    ax.set_xticklabels(labels)
    ax.set_title(title)
    if ylabel:
        ax.set_ylabel(ylabel)
    ax.grid(True, alpha=0.25)


# Three panels side by side: L2 inside the top-K set, L2 outside it, and
# their ratio. Each title carries the best balanced accuracy for that column.
fig, axs = plt.subplots(1, 3, figsize=(19.5, 4.8), constrained_layout=True)

_violin_fail_succ(
    axs[0],
    succ[col_l2_in].to_numpy(),
    fail[col_l2_in].to_numpy(),
    f"RAW L2 inside top-K important (K={K})" + _bacc_str(col_l2_in),
    col_l2_in,
)
_violin_fail_succ(
    axs[1],
    succ[col_l2_out].to_numpy(),
    fail[col_l2_out].to_numpy(),
    f"RAW L2 outside top-K important (K={K})" + _bacc_str(col_l2_out),
    col_l2_out,
)
_violin_fail_succ(
    axs[2],
    succ[col_l2_ratio].to_numpy(),
    fail[col_l2_ratio].to_numpy(),
    f"L2_in / L2_out (K={K})" + _bacc_str(col_l2_ratio),
    col_l2_ratio,
)

plt.show()


# ---------------------------
# ROC-AUC for ALL metrics at K=2000
# ---------------------------
def _roc_auc_dir(y, x):
    """Direction-agnostic ROC-AUC of scores ``x`` against binary labels ``y``.

    Args:
        y: array-like of binary labels.
        x: array-like of scores, same length as ``y``.

    Returns:
        ``(auc, direction)`` with ``auc >= 0.5``; direction is +1 when larger
        scores go with label 1 and -1 when the raw AUC was below 0.5 and has
        been flipped. ``(nan, 0)`` with fewer than 5 finite samples or a
        single class.
    """
    # Coerce to arrays, as the sibling helpers do, so lists and Series work
    # with the boolean-mask indexing below.
    y = np.asarray(y)
    x = np.asarray(x, dtype=float)
    m = np.isfinite(x) & np.isfinite(y)
    if m.sum() < 5 or np.unique(y[m]).size < 2:
        return float("nan"), 0
    auc_raw = float(roc_auc_score(y[m], x[m]))
    if auc_raw >= 0.5:
        return auc_raw, +1
    return 1.0 - auc_raw, -1


# --- Helper: best balanced accuracy by thresholding ---
def _best_bacc_dir(y, x, direction: int):
    """Best balanced accuracy from thresholding ``x`` in a fixed direction.

    ``direction >= 0`` predicts 1 when ``x >= thr``; otherwise when
    ``x <= thr``. Non-finite samples are ignored. Candidate thresholds are the
    midpoints between consecutive distinct values plus one just outside each
    end; the first candidate reaching the best score is kept.

    Returns:
        ``(best_bal_acc, best_threshold)``; ``(nan, nan)`` with fewer than 5
        usable samples or a single class.
    """
    labels = np.asarray(y, dtype=int)
    values = np.asarray(x, dtype=float)
    usable = np.isfinite(values) & np.isfinite(labels)
    if usable.sum() < 5 or np.unique(labels[usable]).size < 2:
        return float("nan"), float("nan")

    labels = labels[usable]
    values = values[usable]
    upward = int(direction) >= 0

    def _score(cut):
        predicted = (values >= cut) if upward else (values <= cut)
        return float(balanced_accuracy_score(labels, predicted.astype(int)))

    levels = np.unique(values)
    if levels.size == 1:
        # All values equal: the only possible cut is that value.
        only_cut = float(levels[0])
        return _score(only_cut), only_cut

    midpoints = (levels[:-1] + levels[1:]) * 0.5
    cuts = np.concatenate(([levels[0] - 1e-12], midpoints, [levels[-1] + 1e-12]))

    top_score = -1.0
    top_cut = float("nan")
    for cut in cuts:
        current = _score(cut)
        if current > top_score:
            top_score, top_cut = current, float(cut)

    return float(top_score), float(top_cut)


k = K
# Candidate columns for this K. Columns that are absent from df are skipped
# in the loop below (the focus_ratio_mean / focus_diff_mean columns are not
# created in this cell).
metrics_k2000 = [
    f"focus_frac__gxa_chw_k{k}",
    f"enrichE_impTopK_gxa_chw_k{k}",
    f"overlap_jacc__gxa_chw_k{k}",
    f"focus_ratio_mean__gxa_chw_k{k}",
    f"focus_diff_mean__gxa_chw_k{k}",
    f"l2_in_topk__gxa_chw_k{k}",
    f"l2_out_topk__gxa_chw_k{k}",
    f"l2_ratio_in_out__gxa_chw_k{k}",
]

y = df["success"].astype(int).to_numpy()

# One row per available metric: AUC with direction, best balanced accuracy
# with its threshold, and class-wise mean/std/count over finite values.
rows = []
for mname in metrics_k2000:
    if mname not in df.columns:
        continue
    x = df[mname].to_numpy(dtype=float)
    auc, direction = _roc_auc_dir(y, x)
    bacc, bthr = _best_bacc_dir(y, x, direction)

    xs = succ[mname].to_numpy(dtype=float) if mname in succ.columns else np.array([])
    xf = fail[mname].to_numpy(dtype=float) if mname in fail.columns else np.array([])
    xs = xs[np.isfinite(xs)]
    xf = xf[np.isfinite(xf)]

    rows.append({
        "metric": mname,
        "roc_auc": float(auc),
        "auc_direction": int(direction),
        "best_bal_acc": float(bacc),
        "best_bal_thr": float(bthr),
        "mean_s": float(np.nanmean(xs)) if xs.size else float("nan"),
        "mean_f": float(np.nanmean(xf)) if xf.size else float("nan"),
        "std_s": float(np.nanstd(xs)) if xs.size else float("nan"),
        "std_f": float(np.nanstd(xf)) if xf.size else float("nan"),
        "n_s": int(xs.size),
        "n_f": int(xf.size),
    })

auc_df_k2000 = pd.DataFrame(rows).sort_values("roc_auc", ascending=False)


# ---------------------------
# Side-by-side: ROC-AUC barh and best balanced accuracy barh
# ---------------------------
# Each panel drops rows whose own score is non-finite and sorts ascending so
# that the best metric is drawn at the top of the horizontal bar chart.
_v = auc_df_k2000.copy()
_v = _v[np.isfinite(_v["roc_auc"].to_numpy(dtype=float))]
_v = _v.sort_values("roc_auc", ascending=True)

_vb = auc_df_k2000.copy()
_vb = _vb[np.isfinite(_vb["best_bal_acc"].to_numpy(dtype=float))]
_vb = _vb.sort_values("best_bal_acc", ascending=True)

# Figure height grows with the number of bars.
h = 0.45 * max(len(_v), len(_vb), 1) + 1.8
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(21.0, h), constrained_layout=True)

# ROC-AUC (left)
ax1.barh(_v["metric"], _v["roc_auc"])
ax1.set_xlabel("ROC-AUC (>=0.5; higher is better)")
ax1.set_title("ROC-AUC for K=2000 metrics")
ax1.grid(True, axis="x", alpha=0.25)
# Annotate each bar with its value and an arrow for the AUC direction.
for i, (auc, direc) in enumerate(zip(_v["roc_auc"], _v["auc_direction"])):
    arrow = "↑" if int(direc) >= 0 else "↓"
    ax1.text(float(auc) + 0.003, i, f"{auc:.3f} {arrow}", va="center", fontsize=10)
xmax1 = float(np.nanmax(_v["roc_auc"].to_numpy(dtype=float))) if len(_v) else 1.0
ax1.set_xlim(0.45, min(1.0, xmax1 + 0.06))

# best bAcc (right)
ax2.barh(_vb["metric"], _vb["best_bal_acc"])
ax2.set_xlabel("best_bal_acc (>=0.5; higher is better)")
ax2.set_title("Best balanced accuracy for K=2000 metrics")
ax2.grid(True, axis="x", alpha=0.25)
# Annotate each bar with its value and the threshold rule that achieved it.
for i, (ba, thr, direc) in enumerate(zip(_vb["best_bal_acc"], _vb["best_bal_thr"], _vb["auc_direction"])):
    rule = ">=" if int(direc) >= 0 else "<="
    ax2.text(float(ba) + 0.003, i, f"{ba:.3f} | thr {rule} {thr:.3g}", va="center", fontsize=10)
xmax2 = float(np.nanmax(_vb["best_bal_acc"].to_numpy(dtype=float))) if len(_vb) else 1.0
ax2.set_xlim(0.45, min(1.0, xmax2 + 0.10))

plt.show()

# Метрики для анализа Δ-фич и (не)успешности атаки

Ниже собраны все метрики, которые мы считаем в ноутбуке, с кратким описанием **как именно они вычисляются** и **что должны отражать**.
Обозначения:

- Есть чистое изображение `x` и изображение с патчем `x̃`.
- На фиксированном слое детектора получаем фичи `F(x)` и `F(x̃)`.
- **Δ-фичи**: `Δ = F(x̃) - F(x)`, тензор размера `(C,H,W)` (в нашем случае обычно `C=512`, `H=W≈20`).
- `|Δ|` — поэлементный модуль.
- `success ∈ {0,1}` — успешность атаки (по выбранному критерию).
- `importance` — карта “важности фич” для детекции объекта (считали двумя методами):
  - `SSGrad`: **только H×W** (каналы неизбежно схлопнуты).
  - `GradAct` (grad*act): доступна и **H×W**, и **полная C×H×W**.

---

## 1) Базовые метрики только по Δ (без importance)

### `mean_signed`
**Как считаем:** среднее значение Δ по всем элементам:

$$
\text{mean\_signed} = \mathrm{mean}(\Delta)
$$

**Что показывает:** общий “сдвиг знака” (например, систематическое уменьшение/увеличение активаций).
Часто малоинформативно, потому что положительные и отрицательные изменения могут взаимно компенсироваться.

### `l2_rms`
**Как считаем:** RMS по всем элементам Δ:
$
\text{l2\_rms} = \sqrt{\mathrm{mean}(\Delta^2)}
$
**Что показывает:** общую мощность изменения фич (масштаб атаки в фич-пространстве).

### `abs_mean`, `abs_max`
**Как считаем:**
$
\text{abs\_mean} = \mathrm{mean}(|\Delta|), \quad
\text{abs\_max} = \max(|\Delta|)
$
**Что показывает:** среднюю и пиковую величину изменения.

### Канальные метрики (по каналам C)
Пусть:
$
E_c = \mathrm{mean}_{h,w}(|\Delta_{c,h,w}|)
$
это энергия изменения в канале `c`.

- `chan_energy_topk_mean` — среднее по top-k каналов по `E_c`.
- `chan_energy_gini` — **Gini** по вектору `E` (неотрицательный).

**Gini (для неотрицательных):** мера неравномерности распределения энергии по каналам.
**Интерпретация:** высокий Gini → атака “концентрирует” изменение в небольшом числе каналов; низкий → изменения размазаны равномернее.

### Метрики на H×W-карте Δ
Мы строим “magnitude карту” на сетке H×W, схлопывая каналы, напр. через `reduce_mode="l2"`:
$
\Delta_{hw} = \mathrm{reduce}_C(\Delta) \in \mathbb{R}^{H\times W},\ \Delta_{hw}\ge 0
$
Далее считаем:
- `hw_abs_mean`, `hw_abs_max`
- `hw_gini` — Gini по `Δ_hw`
- `hw_sparsity_p90` — доля ячеек, попавших в верхние 10% по `Δ_hw`

### ROI-метрики (по bbox из clean-предсказания)
bbox человека (из clean) переводим в ROI на сетке H×W. Затем:
- `roi_abs_mean` — среднее `Δ_hw` внутри ROI
- `roi_abs_ratio` — отношение среднего внутри ROI к среднему вне ROI

**Смысл:** если успех атаки связан с попаданием изменений в область объекта, то `roi_abs_ratio`/`roi_abs_mean` могут расти для успешных атак.

---

## 2) Метрики Δ vs importance (сопоставление с “важными фичами”)

Мы сравниваем `|Δ|` с `importance` в двух представлениях:

### H×W (SSGrad и GradAct collapsed)
Есть:
- `Δ_hw` (например reduce=`l2`)
- `imp_hw` (важность на H×W)

Метрики:

#### `cos_hw_ssgrad`, `cos_hw_gxa`
**Как считаем:** cosine similarity между `Δ_hw` и `imp_hw`:
$
\cos = \frac{\langle \Delta_{hw},\ imp_{hw}\rangle}{\|\Delta_{hw}\|\cdot\|imp_{hw}\|}
$
**Что показывает:** насколько форма карты Δ похожа на карту важности.

#### `l2_hw_*`, `l2rel_hw_*`
**Как считаем:**
- `l2`: $\|\Delta_{hw} - imp_{hw}\|_2$
- `l2rel`: $\frac{\|\Delta_{hw}-imp_{hw}\|_2}{\|\Delta_{hw}\|_2 + \|imp_{hw}\|_2}$

**Что показывает:** “расхождение” Δ и importance.

#### `topq_energy_frac_*` (обычно q=10%)
**Как считаем:** доля энергии `|Δ_hw|` в верхних q% наиболее важных ячеек:
$
\text{topq\_energy\_frac} = \frac{\sum_{(h,w)\in \text{top-}q\%(imp)} |\Delta_{hw}|}{\sum_{h,w} |\Delta_{hw}|}
$
**Что показывает:** концентрирует ли атака энергию именно в наиболее важных местах.

### C×H×W (GradAct full)
Есть полный тензор важности `imp_chw ≥ 0` и `|Δ|_chw`.

#### `cos_chw_gxa`
Cosine similarity на векторизации всех элементов:
$
\cos = \frac{\langle |\Delta|,\ imp\rangle}{\||\Delta|\|\cdot\|imp\|}
$

#### `l2_chw_gxa`, `l2rel_chw_gxa`
L2 и относительная L2 (как выше, но по CHW).

#### `spearman_chan_energy_gxa`
Берём вектор энергий по каналам:
$
E^\Delta_c = \mathrm{mean}_{h,w}(|\Delta_{c,h,w}|), \quad
E^{imp}_c = \mathrm{mean}_{h,w}(imp_{c,h,w})
$
Считаем коэффициент Спирмена (ранговую корреляцию) между этими двумя векторами.

**Смысл:** совпадает ли ранжирование “важных каналов” и “каналов, по которым бьёт патч”.

#### `jaccard_topk_chan_gxa`
Берём top-k каналов по `E^\Delta` и по `E^{imp}` и считаем Jaccard:
$
J = \frac{|TopK(E^\Delta)\cap TopK(E^{imp})|}{|TopK(E^\Delta)\cup TopK(E^{imp})|}
$
**Смысл:** доля совпадения самых “энергичных” и самых “важных” каналов.

---

## 3) Top-K фокус на важных фичах (feature-focus experiments)

Здесь мы выбираем **top-K элементов importance** и смотрим, сколько энергии Δ попало именно туда.
Работает на:
- CHW: (очень много элементов, K может быть 2000+)
- HW: (всего 400 элементов, K меньше)

### Ключевая идея
Пусть `S = TopK(imp)` — множество индексов K самых важных элементов.

Тогда:

### `focus_frac__*__*__kK`
**Как считаем:**
$
\text{focus\_frac} = \frac{\sum_{i \in S} |\Delta_i|}{\sum_i |\Delta_i|}
$
**Что показывает:** какая доля энергии атаки попала в важные фичи (масштаб-независимо).

### `focus_enrich__*__*__kK`
**Как считаем:** нормируем на “ожидаемую долю при равномерном попадании” `K/N`:
$
\text{focus\_enrich} = \frac{\text{focus\_frac}}{K/N}
$
**Интерпретация:**
- ~1 → попадание как случайное
- >1 → атака концентрируется в важных
- <1 → атакует “неважные”

### `overlap_jacc__delta_vs_imp__...__kK` (CHW GradAct)
Берём `TopK(|Δ|)` и `TopK(imp)` и считаем Jaccard:
$
J = \frac{|TopK(|\Delta|)\cap TopK(imp)|}{|TopK(|\Delta|)\cup TopK(imp)|}
$
**Смысл:** совпадают ли “самые изменённые” элементы с “самыми важными”.

---

## 4) “Только важные фичи” без общей энергии: внутри vs снаружи (mean-contrast)

Мы сравниваем **средний уровень |Δ|** внутри top-K(imp) и снаружи.

Пусть `S = TopK(imp)`, тогда:
$
\mu_{in}=\mathrm{mean}_{i\in S}(|\Delta_i|), \quad \mu_{out}=\mathrm{mean}_{i\notin S}(|\Delta_i|)
$

### `focus_ratio_mean__...__kK`
$
\text{ratio} = \frac{\mu_{in}}{\mu_{out}}
$
**Смысл:** насколько сильнее “дрожат” важные фичи относительно неважных.

### `focus_diff_mean__...__kK`
$
\text{diff} = \mu_{in}-\mu_{out}
$
**Смысл:** абсолютная разница уровня изменения.

---

## 5) RAW L2 энергия Δ в важных фичах (без нормировок)

Это именно “энергия атаки в выбранных важных координатах”, а не сравнение Δ и importance.

Пусть `S = TopK(imp)`:

### `l2_in_topk__...__kK`
$
\text{L2\_in} = \|\ |\Delta|_{S}\ \|_2
$

### `l2_out_topk__...__kK`
$
\text{L2\_out} = \|\ |\Delta|_{\neg S}\ \|_2
$

### `l2_ratio_in_out__...__kK`
$
\text{L2\_ratio} = \frac{\text{L2\_in}}{\text{L2\_out}}
$

**Смысл:** сколько “квадратичной энергии” Δ оказалось в важной области относительно неважной.

---

## 6) ROC-AUC и направление (auc_direction)

Для каждой метрики `m` мы считаем ROC-AUC по бинарной метке `success`.
Так как нам важно сравнивать “качество разделения”, мы всегда приводим AUC к ≥ 0.5:

- `auc_direction = +1`: **большие значения метрики → более вероятный SUCCESS**
- `auc_direction = -1`: метрика работает “в обратную сторону” (SUCCESS при меньших значениях), поэтому мы выводим `1 - AUC_raw`

Это удобно для сортировки метрик по “разделяющей способности”, но важно помнить направление при интерпретации.

---

## 7) Как читать метрики в контексте гипотезы

Гипотеза: *успешная атака отличается тем, что энергия Δ лучше “попадает” в важные фичи (или ROI) для детекции.*

Ожидаемые паттерны при подтверждении гипотезы:
- выше `focus_frac`, `focus_enrich`
- выше `focus_ratio_mean` / `focus_diff_mean`
- выше `l2_in_topk` и/или выше `l2_ratio_in_out`
- выше `topq_energy_frac_*`
- выше `overlap_jacc_*` и/или `jaccard_topk_chan_gxa`
- выше `roi_abs_ratio`

Если “всё плохо” и AUC близко к 0.5, это означает, что:
- либо патч меняет фичи схожим образом и в success и в fail,
- либо текущая importance-карта не соответствует реальным причинным “важным” фичам (ошибка выбора target/ROI/критерия),
- либо успех определяется не попаданием в важные фичи на данном слое, а динамикой на других слоях/в голове/в NMS/в bbox selection.

# Взвешенные дельты

В дальнейших расчётах участвуют тензоры дельт от наложения патча, поэлементно умноженные на тензоры важностей соответствующих фич: $\Delta \odot imp$.

In [None]:
# --- Δ overlays + grad*act + weighted(Δ⊙imp) with 5 columns ---
# Cols:
# 1) patched
# 2) patched + signed Δ (seismic, symmetric)
# 3) patched + signed grad*act (seismic, symmetric around 0)
# 4) patched + signed (Δ ⊙ importance), overlay drawn at alpha=1.0 (fully opaque) + bbox
# 5) patched + signed (Δ ⊙ importance) with custom alpha ONLY for this column + bbox

from pathlib import Path
from typing import Any, Dict, List
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def _robust_norm_signed(x: np.ndarray, q: float = 99.0, eps: float = 1e-12):
    """Scale `x` symmetrically into [-1, 1] and return ``(normalized, scale)``.

    The scale is the q-th percentile of ``|x|`` taken over the finite entries
    only, floored at `eps`; when `x` has no finite entry the scale is 1.0.
    The input is cast to float32, divided by the scale and clipped to [-1, 1].
    Non-finite entries are not removed from the output, only excluded from
    the percentile.
    """
    values = np.asarray(x, dtype=np.float32)
    finite_mask = np.isfinite(values)

    if finite_mask.any():
        scale = float(np.percentile(np.abs(values[finite_mask]), q))
    else:
        scale = 1.0
    # Floor the scale so an all-zero map does not divide by zero.
    scale = max(scale, eps)

    normalized = np.clip(values / scale, -1.0, 1.0)
    return normalized, scale


def _reduce_channels(t_chw: torch.Tensor, mode: str = "topk_mean", topk: int = 32) -> torch.Tensor:
    """
    Reduce a (C,H,W) tensor to a single (H,W) map by collapsing the channel axis.

    Supported modes:
      - "l2":        sqrt(mean_c x^2)                       (non-negative)
      - "abs_mean":  mean_c |x|                             (non-negative)
      - "topk_mean": mean of the top-k |x| across channels, multiplied by
                     sign(mean_c x) at each pixel
      - "topk_sum":  sum of the top-k |x| across channels, multiplied by
                     sign(mean_c x) at each pixel

    Args:
        t_chw: tensor with exactly three dimensions (C,H,W).
        mode: one of the modes listed above.
        topk: number of channels kept by the "topk_*" modes; clamped to [1, C].

    Returns:
        Tensor of shape (H,W).

    Raises:
        ValueError: if `t_chw` is not 3-D, or if `mode` is not supported.
            An unknown mode used to fall through silently to "topk_sum".
    """
    if t_chw.ndim != 3:
        raise ValueError(f"_reduce_channels expects (C,H,W), got {tuple(t_chw.shape)}")

    supported = ("l2", "abs_mean", "topk_mean", "topk_sum")
    if mode not in supported:
        raise ValueError(f"_reduce_channels: unknown mode {mode!r}, expected one of {supported}")

    if mode == "l2":
        return torch.sqrt(torch.mean(t_chw * t_chw, dim=0))
    if mode == "abs_mean":
        return torch.mean(t_chw.abs(), dim=0)

    # top-k by magnitude across channels; the sign is taken per (h,w) from the
    # channel mean (a visualization heuristic, not an exact signed reduction).
    n_channels = t_chw.shape[0]
    k = int(min(max(int(topk), 1), n_channels))
    topv, _ = torch.topk(t_chw.abs(), k=k, dim=0, largest=True, sorted=False)  # (k,H,W)
    base = topv.mean(dim=0) if mode == "topk_mean" else topv.sum(dim=0)
    sgn = torch.sign(t_chw.mean(dim=0))  # (H,W)
    return base * sgn


def _draw_clean_bbox(ax, bbox_xyxy, success: bool):
    """Outline the box ``(x1, y1, x2, y2)`` on `ax`; do nothing when it is None.

    The outline is lime when `success` is truthy and red otherwise. Width and
    height are floored at 0, so an inverted box collapses to a line or point.
    Coordinates are passed to matplotlib unchanged.
    NOTE(review): presumably these are pixel coordinates of the displayed
    (letterboxed) image — confirm against the caller.
    """
    if bbox_xyxy is None:
        return

    left, top, right, bottom = (float(v) for v in bbox_xyxy)
    box_w = max(0.0, right - left)
    box_h = max(0.0, bottom - top)
    outline_color = "lime" if success else "red"

    # Imported lazily, only when there is something to draw.
    from matplotlib.patches import Rectangle

    ax.add_patch(
        Rectangle(
            (left, top),
            box_w,
            box_h,
            linewidth=2.0,
            edgecolor=outline_color,
            facecolor="none",
        )
    )


def visualize_delta_maps_5cols(
    run_data: List[Dict[str, Any]],
    layer: str = "model.22",
    reduce_mode: str = "topk_mean",
    topk: int = 32,
    max_rows: int | None = None,
    alpha: float = 0.45,
    alpha_weighted: float = 0.25,   # ONLY for col 5
    save_dir: str | Path | None = None,
):
    """
    Per-row layout (5 columns):
      1) patched
      2) patched + signed Δ
      3) patched + signed grad*act (imp_gxa_hw, sign restored)
      4) patched + signed (Δ ⊙ importance), BUT alpha=0.0 (hidden)  [still draws bbox]
      5) patched + signed (Δ ⊙ importance) with alpha=alpha_weighted

    Requirements for `run_data` entries:
      - d["patch_lb"] or d["patched_lb"] or equivalent PIL image (letterboxed)
      - d["gradcam_info"]["picked_bbox"] for bbox overlay (optional)
      - d["success"], d["conf_clean"], d["conf_patch"], d["drop"]
      - d["deltas"][layer] as torch.Tensor (1,C,H,W) OR (C,H,W) (we handle both)
      - d["imp_gxa_hw"] as (H,W) float array in [0,1] (used for col 3, sign restored)
      - d["imp_gxa_chw"] as (C,H,W) float array >=0 (used for weighted product)
    """

    # choose subset
    rows = run_data if (max_rows is None) else run_data[: int(max_rows)]
    n = len(rows)
    if n == 0:
        print("[warn] No rows to visualize.")
        return

    fig, axes = plt.subplots(nrows=n, ncols=5, figsize=(25, 4.6 * n))
    if n == 1:
        axes = np.expand_dims(axes, axis=0)

    if save_dir is not None:
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

    for i, d in enumerate(rows):
        # ---- basic fields ----
        p = d.get("path", f"row{i}")
        success = bool(d.get("success", False))
        status = "SUCCESS" if success else "FAIL"
        conf_c = float(d.get("conf_clean", float("nan")))
        conf_p = float(d.get("conf_patch", float("nan")))
        drop = float(d.get("drop", float("nan")))

        gcinfo = d.get("gradcam_info", {}) if isinstance(d.get("gradcam_info", {}), dict) else {}
        bbox_xyxy = gcinfo.get("picked_bbox", None)

        # patched image (try a few common keys)
        patched_img = d.get("patch_lb", None)
        if patched_img is None:
            patched_img = d.get("patched_lb", None)
        if patched_img is None:
            patched_img = d.get("clean_lb", None)

        if patched_img is None:
            # nothing to draw
            for j in range(5):
                axes[i, j].set_axis_off()
                axes[i, j].set_title("missing image")
            continue

        img_np = np.asarray(patched_img.convert("RGB"))
        tgt_h, tgt_w = img_np.shape[0], img_np.shape[1]

        # delta (prefer stored torch tensor)
        delta = d.get("deltas", {}).get(layer, None)
        if delta is None:
            # still show image-only
            for j, title in enumerate(
                ["patched", "patched + signed Δ", "patched + signed grad*act", "patched + signed (Δ⊙imp) alpha=0", "patched + signed (Δ⊙imp)"]
            ):
                axes[i, j].imshow(img_np)
                _draw_clean_bbox(axes[i, j], bbox_xyxy, success)
                axes[i, j].set_title(title + " (missing Δ)")
                axes[i, j].axis("off")
            continue

        # normalize delta shape to (C,H,W) torch
        if isinstance(delta, torch.Tensor):
            if delta.ndim == 4:
                delta_chw = delta[0].detach().cpu()
            elif delta.ndim == 3:
                delta_chw = delta.detach().cpu()
            else:
                raise ValueError(f"Unexpected delta tensor shape: {tuple(delta.shape)}")
        else:
            # allow numpy
            delta_arr = np.asarray(delta)
            if delta_arr.ndim == 4:
                delta_arr = delta_arr[0]
            if delta_arr.ndim != 3:
                raise ValueError(f"Unexpected delta array shape: {tuple(delta_arr.shape)}")
            delta_chw = torch.as_tensor(delta_arr, dtype=torch.float32)

        C, H, W = tuple(delta_chw.shape)

        # --- signed Δ map (H,W) then upsample ---
        signed_hw_t = _reduce_channels(delta_chw.float(), mode=reduce_mode, topk=topk)  # (H,W) signed
        if tuple(signed_hw_t.shape) != (tgt_h, tgt_w):
            signed_hw_t = F.interpolate(signed_hw_t[None, None, ...], size=(tgt_h, tgt_w), mode="nearest")[0, 0]
        signed_hw = signed_hw_t.numpy()
        signed_n, _ = _robust_norm_signed(signed_hw, q=99.0)

        # --- signed grad*act map from imp_gxa_hw (H,W in [0,1]) ---
        gxa_hw = d.get("imp_gxa_hw", None)
        if gxa_hw is not None:
            gxa_hw_t = torch.as_tensor(np.asarray(gxa_hw), dtype=torch.float32)
            # restore sign from Δ map (same grid) BEFORE upsampling
            sign_hw_small = torch.sign(_reduce_channels(delta_chw.float(), mode=reduce_mode, topk=topk))
            # align shapes (should already be H,W)
            if tuple(gxa_hw_t.shape) != tuple(sign_hw_small.shape):
                raise ValueError(f"imp_gxa_hw shape {tuple(gxa_hw_t.shape)} != sign_hw shape {tuple(sign_hw_small.shape)}")
            gxa_signed_small = gxa_hw_t * sign_hw_small  # signed grad*act proxy
            if tuple(gxa_signed_small.shape) != (tgt_h, tgt_w):
                gxa_signed_up = F.interpolate(gxa_signed_small[None, None, ...], size=(tgt_h, tgt_w), mode="nearest")[0, 0]
            else:
                gxa_signed_up = gxa_signed_small
            gxa_signed = gxa_signed_up.numpy()
            gxa_n, _ = _robust_norm_signed(gxa_signed, q=99.0)
        else:
            gxa_n = None

        # --- signed (Δ ⊙ importance) computed on full (C,H,W) ---
        imp_chw = d.get("imp_gxa_chw", None)
        if imp_chw is not None:
            imp_chw_t = torch.as_tensor(np.asarray(imp_chw), dtype=torch.float32)
            if tuple(imp_chw_t.shape) != tuple(delta_chw.shape):
                raise ValueError(f"imp_gxa_chw shape {tuple(imp_chw_t.shape)} != delta_chw shape {tuple(delta_chw.shape)}")

            weighted_chw = delta_chw.float() * imp_chw_t  # KEEP SIGN ✅
            weighted_hw_t = _reduce_channels(weighted_chw, mode=reduce_mode, topk=topk).detach().cpu()  # signed (H,W)

            if tuple(weighted_hw_t.shape) != (tgt_h, tgt_w):
                weighted_hw_t = F.interpolate(weighted_hw_t[None, None, ...], size=(tgt_h, tgt_w), mode="nearest")[0, 0]
            weighted_hw = weighted_hw_t.numpy()
            weighted_n, _ = _robust_norm_signed(weighted_hw, q=99.0)
        else:
            weighted_n = None

        # -------------------------
        # Plot: 5 columns
        # -------------------------
        # Col 1: patched
        axes[i, 0].imshow(img_np)
        _draw_clean_bbox(axes[i, 0], bbox_xyxy, success)
        axes[i, 0].set_title(f"patched | {status}\nclean={conf_c:.3f} → patched={conf_p:.3f} (drop={drop:.3f})")
        axes[i, 0].axis("off")

        # Col 2: signed Δ
        axes[i, 1].imshow(img_np)
        _draw_clean_bbox(axes[i, 1], bbox_xyxy, success)
        im2 = axes[i, 1].imshow(signed_n, vmin=-1, vmax=1, cmap="seismic", alpha=alpha)
        axes[i, 1].set_title(f"patched + signed Δ | {reduce_mode}\n{status}")
        axes[i, 1].axis("off")
        plt.colorbar(im2, ax=axes[i, 1], fraction=0.046, pad=0.04)

        # Col 3: signed grad*act
        axes[i, 2].imshow(img_np)
        _draw_clean_bbox(axes[i, 2], bbox_xyxy, success)
        if gxa_n is None:
            axes[i, 2].set_title("patched + signed grad*act (missing)")
        else:
            im3 = axes[i, 2].imshow(gxa_n, vmin=-1, vmax=1, cmap="seismic", alpha=alpha)
            axes[i, 2].set_title(f"patched + signed grad*act\n{status}")
            plt.colorbar(im3, ax=axes[i, 2], fraction=0.046, pad=0.04)
        axes[i, 2].axis("off")

        # Col 4: signed (Δ⊙imp) alpha=0.0
        axes[i, 3].imshow(img_np)
        _draw_clean_bbox(axes[i, 3], bbox_xyxy, success)
        if weighted_n is None:
            axes[i, 3].set_title("patched + signed (Δ⊙imp) (missing) | alpha=0")
        else:
            axes[i, 3].imshow(weighted_n, vmin=-1, vmax=1, cmap="seismic", alpha=1.0)  # alpha=0 ✅
            axes[i, 3].set_title(f"patched + signed (Δ⊙imp) | alpha=0\n{status}")
        axes[i, 3].axis("off")

        # Col 5: signed (Δ⊙imp) custom alpha
        axes[i, 4].imshow(img_np)
        _draw_clean_bbox(axes[i, 4], bbox_xyxy, success)
        if weighted_n is None:
            axes[i, 4].set_title("patched + signed (Δ⊙imp) (missing)")
        else:
            im5 = axes[i, 4].imshow(weighted_n, vmin=-1, vmax=1, cmap="seismic", alpha=float(alpha_weighted))
            axes[i, 4].set_title(f"patched + signed (Δ⊙imp) | alpha={alpha_weighted:g}\n{status}")
            plt.colorbar(im5, ax=axes[i, 4], fraction=0.046, pad=0.04)
        axes[i, 4].axis("off")

        # Row label
        axes[i, 0].set_ylabel(
            f"{Path(str(p)).name}\n{status} | clean={conf_c:.3f} patched={conf_p:.3f} drop={drop:.3f}",
            rotation=0,
            labelpad=70,
            va="center",
        )

        # Optional save: save composite RGBs for cols 2/3/5 (col 4 is intentionally alpha=0)
        if save_dir is not None:
            stem = Path(str(p)).stem
            base = img_np.astype(np.float32) / 255.0

            def _save_signed_overlay(arr_n, out_path: Path, a: float):
                cmap = plt.get_cmap("seismic")
                rgba = cmap((arr_n + 1.0) * 0.5)
                comp = (1 - a) * base + a * rgba[..., :3]
                comp = np.clip(comp, 0.0, 1.0)
                plt.imsave(out_path, comp)

            out2 = save_dir / f"{stem}__overlay_signedDelta_{reduce_mode}.png"
            _save_signed_overlay(signed_n, out2, float(alpha))

            if gxa_n is not None:
                out3 = save_dir / f"{stem}__overlay_signedGradAct.png"
                _save_signed_overlay(gxa_n, out3, float(alpha))

            if weighted_n is not None:
                out5 = save_dir / f"{stem}__overlay_signedWeighted_{reduce_mode}_a{alpha_weighted:g}.png"
                _save_signed_overlay(weighted_n, out5, float(alpha_weighted))

    plt.tight_layout()
    plt.show()


# Example usage: plot the first 10 entries of `examples` for layer "model.22".
# NOTE: alpha=1 draws the Δ and grad*act overlays (cols 2-3) fully opaque.
visualize_delta_maps_5cols(
    run_data=examples,  # or run_data
    layer="model.22",
    reduce_mode="topk_mean",
    topk=32,
    max_rows=10,
    alpha=1,
    alpha_weighted=0.22,
    save_dir=None,
)