# loss代码：
/opt/data/private/BlackBox/loss.py

In [6]:
# %%writefile loss.py
# 文件: loss.py
# 目的：实现 BlackBox 论文风格的攻击损失（兼容 GSE + TMM）
# 兼容路径: /opt/data/private/BlackBox/gse.py, /opt/data/private/BlackBox/tmm.py
# 设计原则：
#  - 用 GSE 提供的 decoder intermediate logits 作为检测打分来源
#  - 提供多种 layer 聚合策略（论文语义通常按层 avg）
#  - 提供 TV / L2 正则化（常用于补丁平滑与幅度控制）；保留 NPS 钩子以便外部实现
#  - 训练时，应只优化补丁参数（模型参数冻结），调用 loss.backward() 会把梯度传回 patch
import torch
import torch.nn.functional as F
from typing import Optional, Callable, List, Dict
from tmm import NestedTensor

class BlackBoxLoss:
    """
    BlackBox-style loss for adversarial patch optimization.

    The detection term is the probability assigned to ``target_class`` by the
    per-decoder-layer logits returned from ``gse``; minimizing it pushes that
    probability down. Optional regularizers are computed on the patch itself.

    Configurable components:
      - detection_weight: multiplier applied to the detection term.
      - tv_weight: multiplier of the total-variation regularizer.
      - l2_weight: multiplier of the mean-squared-value regularizer.
      - target_class: index of the class whose probability is suppressed.
      - layer_aggregation: how the layer axis is combined:
          * 'mean_logits'    - average logits over layers, then take the probability.
          * 'mean_prob'      - take the probability per layer, then average over layers.
          * 'per_layer_loss' - reduce every layer to a scalar, then average the scalars.
      - use_sigmoid_for_binary: when True and the logits have a single channel
        (C == 1) a sigmoid is used; in every other case softmax over classes.
      - reduction: 'mean' reduces with a mean, any other value with a sum.
      - nps_fn: optional callable ``patch -> scalar`` whose result is added
        (unweighted) to the total loss.
    """

    _LAYER_AGGREGATIONS = ('mean_logits', 'mean_prob', 'per_layer_loss')

    def __init__(
        self,
        gse,                         # GradientSelfEnsemble instance or any callable returning logits
        target_class: int = 1,
        detection_weight: float = 1.0,
        tv_weight: float = 1e-3,
        l2_weight: float = 1e-3,
        layer_aggregation: str = 'per_layer_loss',  # 'mean_logits' | 'mean_prob' | 'per_layer_loss'
        use_sigmoid_for_binary: bool = True,
        reduction: str = 'mean',
        nps_fn: Optional[Callable[[torch.Tensor], torch.Tensor]] = None,  # optional non-printability score
        device: Optional[torch.device] = None
    ):
        # Validate with a real exception: `assert` disappears under `python -O`.
        if layer_aggregation not in self._LAYER_AGGREGATIONS:
            raise ValueError(
                "layer_aggregation must be one of 'mean_logits','mean_prob','per_layer_loss'"
            )

        self.gse = gse
        self.target_class = target_class
        self.detection_weight = float(detection_weight)
        self.tv_weight = float(tv_weight)
        self.l2_weight = float(l2_weight)
        self.layer_aggregation = layer_aggregation
        self.use_sigmoid_for_binary = use_sigmoid_for_binary
        self.reduction = reduction
        self.nps_fn = nps_fn
        self.device = device if device is not None else self._infer_device(gse)

    @staticmethod
    def _infer_device(gse) -> torch.device:
        """Pick a device from ``gse.model`` parameters, then ``gse.device``, else CPU."""
        model = getattr(gse, 'model', None)
        if model is not None:
            first_param = next(iter(model.parameters()), None)
            if first_param is not None:
                return first_param.device
        gse_device = getattr(gse, 'device', None)
        return gse_device if gse_device is not None else torch.device('cpu')

    # --------------------------
    # detection loss helpers
    # --------------------------
    @staticmethod
    def _softmax_probs_over_classes(logits: torch.Tensor):
        # logits shape: [..., num_classes]
        return F.softmax(logits, dim=-1)

    @staticmethod
    def _sigmoid_probs_over_classes(logits: torch.Tensor):
        return torch.sigmoid(logits)

    def _target_probs(self, logits: torch.Tensor, target_class: int) -> torch.Tensor:
        """Map logits ``[..., C]`` to the target probability ``[...]``.

        The class axis is always the one that is removed, so every query is
        kept regardless of how many leading axes the input has.
        """
        if self.use_sigmoid_for_binary and logits.shape[-1] == 1:
            # Single-logit case: the only channel is treated as the positive class.
            return self._sigmoid_probs_over_classes(logits).squeeze(-1)
        return self._softmax_probs_over_classes(logits)[..., target_class]

    @staticmethod
    def _reduce(values: torch.Tensor, reduction: str) -> torch.Tensor:
        """Reduce to a scalar: mean for 'mean', sum for anything else."""
        return values.mean() if reduction == 'mean' else values.sum()

    def _compute_detection_loss_from_logits(
        self,
        logits_all_layers: torch.Tensor,
        target_class: Optional[int] = None,
        reduction: Optional[str] = None
    ) -> torch.Tensor:
        """
        Compute the scalar detection loss from stacked per-layer logits.

        Args:
            logits_all_layers: tensor of shape [L, B, Q, C] (L = decoder layers).
            target_class: class index to suppress; defaults to ``self.target_class``.
            reduction: 'mean' or sum-like; defaults to ``self.reduction``.

        Returns:
            Scalar tensor: aggregated target probability times ``detection_weight``.

        Raises:
            ValueError: if ``logits_all_layers`` is not 4-dimensional.
            RuntimeError: if ``self.layer_aggregation`` is not a known strategy.
        """
        if target_class is None:
            target_class = self.target_class
        if reduction is None:
            reduction = self.reduction
        if logits_all_layers.dim() != 4:
            raise ValueError(
                f"logits_all_layers must be [L,B,Q,C], got {logits_all_layers.dim()} dims"
            )

        if self.layer_aggregation == 'mean_logits':
            # Average logits across layers first, then take the probability: [B,Q].
            per_query = self._target_probs(logits_all_layers.mean(dim=0), target_class)
            loss = self._reduce(per_query, reduction)
        elif self.layer_aggregation == 'mean_prob':
            # Probability per layer [L,B,Q], then average over layers: [B,Q].
            per_query = self._target_probs(logits_all_layers, target_class).mean(dim=0)
            loss = self._reduce(per_query, reduction)
        elif self.layer_aggregation == 'per_layer_loss':
            # One scalar per layer, then the mean of those scalars.
            per_layer = self._target_probs(logits_all_layers, target_class)  # [L,B,Q]
            layer_losses = [self._reduce(layer, reduction) for layer in per_layer]
            loss = torch.stack(layer_losses).mean()
        else:
            raise RuntimeError("未知的 layer_aggregation 策略")

        # Minimizing this value lowers the model's confidence in target_class.
        return loss * self.detection_weight

    # --------------------------
    # regularizers
    # --------------------------
    @staticmethod
    def total_variation(image: torch.Tensor):
        """
        Anisotropic total variation of a 4-D tensor [B, C, H, W].

        Returns the mean absolute vertical difference plus the mean absolute
        horizontal difference (a scalar tensor).

        Raises:
            ValueError: if ``image`` is not 4-dimensional.
        """
        if image.dim() != 4:
            raise ValueError("total_variation expects [B,C,H,W]")
        dh = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
        dw = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
        return (dh.mean() + dw.mean())

    @staticmethod
    def l2_norm(image: torch.Tensor):
        # Mean of squared values over every element.
        return torch.mean(image.pow(2))

    # --------------------------
    # Main entry point: returns the total loss and each component
    # --------------------------
    def __call__(
        self,
        imgs: torch.Tensor,              # original batch; not used by the computation
        patched_imgs: torch.Tensor,      # patched images [B,3,H,W] or an already wrapped object
        patch_tensor: Optional[torch.Tensor] = None,  # patch [C,ph,pw] or [B,C,ph,pw] for TV/L2/NPS
        reduction: Optional[str] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the loss dictionary::

            {
              'total_loss': det + tv_weight * tv + l2_weight * l2 (+ nps if nps_fn is set),
              'det_loss':   weighted detection term,
              'tv_loss':    unweighted total variation of the patch (0 without a patch),
              'l2_loss':    unweighted mean squared patch value (0 without a patch),
              'nps_loss':   result of nps_fn (0 without nps_fn / patch, or if nps_fn fails)
            }

        Args:
            imgs: accepted for interface compatibility; unused.
            patched_imgs: a ``torch.Tensor`` is moved to ``self.device``, made
                contiguous and wrapped in ``NestedTensor``; any other object is
                handed to ``gse`` unchanged.
            patch_tensor: optional patch used for the regularizers; a 3-D
                tensor is given a leading batch axis.
            reduction: overrides ``self.reduction`` for the detection term.
        """
        if reduction is None:
            reduction = self.reduction

        if isinstance(patched_imgs, torch.Tensor):
            samples = NestedTensor(tensors=patched_imgs.to(self.device).contiguous())
        else:
            # Assume the caller already built an input the model accepts.
            samples = patched_imgs

        # 1) logits for every decoder layer, expected shape [L, B, Q, C]
        logits_all = self.gse(samples, return_all_layers=True).to(self.device)

        # 2) detection loss
        det_loss = self._compute_detection_loss_from_logits(
            logits_all, target_class=self.target_class, reduction=reduction
        )

        # 3) regularizers default to zero when no patch is supplied
        tv_loss = torch.tensor(0.0, device=self.device)
        l2_loss = torch.tensor(0.0, device=self.device)
        nps_loss = torch.tensor(0.0, device=self.device)

        if patch_tensor is not None:
            p_batch = patch_tensor.to(self.device)
            if p_batch.dim() == 3:
                p_batch = p_batch.unsqueeze(0)  # [C,H,W] -> [1,C,H,W]
            tv_loss = self.total_variation(p_batch)
            l2_loss = self.l2_norm(p_batch)

            if self.nps_fn is not None:
                try:
                    nps_loss = self.nps_fn(patch_tensor.to(self.device))
                except Exception as exc:
                    # Best effort: keep optimizing, but make the failure visible.
                    import warnings
                    warnings.warn(
                        f"nps_fn failed ({exc!r}); NPS term set to zero for this step",
                        RuntimeWarning,
                    )
                    nps_loss = torch.tensor(0.0, device=self.device)

        # 4) combine
        total_loss = det_loss + self.tv_weight * tv_loss + self.l2_weight * l2_loss
        if self.nps_fn is not None:
            total_loss = total_loss + nps_loss

        return {
            'total_loss': total_loss,
            'det_loss': det_loss,
            'tv_loss': tv_loss,
            'l2_loss': l2_loss,
            'nps_loss': nps_loss
        }
    
        

Writing loss.py


# 替换GSE代码：

In [4]:
# %%writefile gse.py
# /opt/data/private/BlackBox/gse.py
import torch
import torch.nn as nn
from typing import List, Optional

class GradientSelfEnsemble:
    """Collect per-decoder-layer class logits from a DETR-style model.

    Forward hooks are attached to every decoder layer for the duration of one
    forward pass; the captured layer outputs are brought to [B, Q, D] and
    passed through ``model.class_embed``. Inputs may be a list of image
    tensors, a batch tensor, or a NestedTensor-like object (anything with a
    ``tensors`` attribute).

    NOTE(review): the raw layer outputs are fed to ``class_embed`` directly;
    if the decoder applies a final norm to its intermediate outputs, these
    logits will differ from the model's own auxiliary outputs - confirm this
    is intended.
    """

    def __init__(self, model: nn.Module, device: Optional[torch.device] = None):
        self.model = model
        self.device = device if device is not None else next(model.parameters()).device
        self._last_captured_layers: List[torch.Tensor] = []
        self._hooks = []

    def _find_decoder_layers(self) -> List[nn.Module]:
        """Return the decoder layers found under ``model.transformer.decoder``.

        Looks for a ``layers`` attribute, then ``layer``, and finally falls
        back to direct children whose name contains 'layer'.

        Raises:
            RuntimeError: if no decoder layer can be located.
        """
        transformer = getattr(self.model, 'transformer', None)
        dec = getattr(transformer, 'decoder', None) if transformer is not None else None
        if dec is not None:
            # Prefer 'layers', then 'layer' (different DETR variants).
            for attr in ['layers', 'layer']:
                if hasattr(dec, attr):
                    layers = getattr(dec, attr)
                    # Normalize ModuleList / list / tuple / single module to a list.
                    return list(layers) if isinstance(layers, (nn.ModuleList, list, tuple)) else [layers]
            # Fallback: direct children whose name mentions 'layer'.
            decoder_layers = [
                child for name, child in dec.named_children() if 'layer' in name.lower()
            ]
            if decoder_layers:
                return decoder_layers
        raise RuntimeError(
            "无法在model.transformer.decoder中找到解码器层（检查'layers'或'layer'属性）。"
            "请使用官方DETR模型，或修改模型以暴露解码器层。"
        )

    def _clear_hooks(self):
        """Remove every registered hook and reset the capture buffer."""
        for h in self._hooks:
            try:
                h.remove()
            except Exception:
                # Best effort: a handle that cannot be removed must not block cleanup.
                pass
        self._hooks = []
        self._last_captured_layers = []

    def _register_decoder_hooks(self, decoder_layers: List[nn.Module]):
        """Attach a forward hook to each layer that appends its output to a shared list."""
        self._clear_hooks()
        captured = []

        def hook(module, inp, out):
            # Layers may return a tuple/list; the first element is taken as the output.
            captured.append(out[0] if isinstance(out, (tuple, list)) else out)

        for layer_module in decoder_layers:
            self._hooks.append(layer_module.register_forward_hook(hook))
        self._last_captured_layers = captured

    def _prepare_input(self, imgs_list):
        """Return ``(model_input, batch_size)`` for the supported input kinds.

        ``batch_size`` is None when it cannot be read from a NestedTensor-like
        object.

        Raises:
            RuntimeError: if the input cannot be interpreted.
        """
        if isinstance(imgs_list, list) and all(isinstance(x, torch.Tensor) for x in imgs_list):
            imgs_tensor = torch.stack(imgs_list, dim=0).to(self.device)  # [B,3,H,W]
            return imgs_tensor, imgs_tensor.shape[0]
        if isinstance(imgs_list, torch.Tensor):
            imgs_tensor = imgs_list.to(self.device)
            return imgs_tensor, imgs_tensor.shape[0]
        if hasattr(imgs_list, 'tensors'):
            # NestedTensor-like: passed to the model untouched.
            try:
                batch_size = imgs_list.tensors.shape[0]
            except Exception:
                batch_size = None
            return imgs_list, batch_size
        # Last resort: any iterable of tensors.
        try:
            imgs_tensor = torch.stack(list(imgs_list), dim=0).to(self.device)
        except Exception as e:
            raise RuntimeError("Unsupported imgs_list type for GSE: must be list/tensor/NestedTensor") from e
        return imgs_tensor, imgs_tensor.shape[0]

    def call_model_and_get_all_layer_logits(self, imgs_list, return_mean: bool = False):
        """
        Run the model once and return class logits for every decoder layer.

        Args:
            imgs_list: one of
               - list of torch.Tensor images (each [3,H,W]) -> stacked into [B,3,H,W]
               - torch.Tensor batch [B,3,H,W]
               - NestedTensor-like object exposing ``.tensors``
            return_mean: if True, return the mean over layers ([B,Q,C]);
                otherwise the stacked logits ([L,B,Q,C]).

        Raises:
            RuntimeError: unsupported input, no decoder layers found, nothing
                captured, or a captured output that is not a 3-D tensor.
        """
        decoder_layers = self._find_decoder_layers()
        model_input, batch_size = self._prepare_input(imgs_list)

        self._register_decoder_hooks(decoder_layers)
        captured = self._last_captured_layers
        try:
            self.model(model_input)
        finally:
            # Always detach the hooks, even when the forward pass raises, so a
            # failed call cannot leave hooks behind on the model.
            self._clear_hooks()

        if len(captured) == 0:
            raise RuntimeError("未捕获到解码器层输出")

        # Bring every captured output to [B,Q,D].
        normalized = []
        for t in captured:
            if not isinstance(t, torch.Tensor):
                raise RuntimeError("捕获到的解码器输出不是tensor类型")
            if t.dim() != 3:
                raise RuntimeError(f"解码器输出维度错误：预期3维，实际{t.dim()}维")
            # Heuristic: when the second axis equals the known batch size and the
            # first does not, the tensor is treated as [Q,B,D] and transposed.
            # Ambiguous when Q == B or when the batch size is unknown.
            if batch_size is not None and t.shape[1] == batch_size and t.shape[0] != batch_size:
                t_proc = t.transpose(0, 1).contiguous()  # -> [B,Q,D]
            else:
                t_proc = t.contiguous()
            normalized.append(t_proc)

        # Project every layer output to class logits [B,Q,C].
        logits_per_layer = []
        for layer_out in normalized:
            try:
                logits = self.model.class_embed(layer_out)
            except Exception:
                # Retry with the first two axes swapped for heads that expect [Q,B,D].
                logits = self.model.class_embed(layer_out.transpose(0, 1)).transpose(0, 1)
            logits_per_layer.append(logits)

        all_logits = torch.stack(logits_per_layer, dim=0)  # [L, B, Q, C]

        if return_mean:
            return all_logits.mean(dim=0)  # [B,Q,C]
        return all_logits  # [L,B,Q,C]

    def __call__(self, imgs_list, return_all_layers: bool = False):
        # return_all_layers True -> [L,B,Q,C]; False -> mean across layers [B,Q,C]
        return self.call_model_and_get_all_layer_logits(imgs_list, return_mean=(not return_all_layers))

Writing gse.py


# train.py测试pipeline

In [4]:
# %%writefile train.py
# 文件: /opt/data/private/BlackBox/train.py
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torchvision.transforms as T
from PIL import Image
import numpy as np

# 修改为你的工程模块导入路径（假设 train.py 与这些模块处于同一 package 目录）
from inria_dataloader import get_inria_dataloader
from tmm import TransformerMaskingMatrix, load_detr_r50
from gse import GradientSelfEnsemble
from loss import BlackBoxLoss

# -----------------------
# 配置（少量、用于 demo）
# -----------------------
# Filesystem layout: dataset and output directories live under ROOT.
ROOT = "/opt/data/private/BlackBox"
DATA_ROOT = os.path.join(ROOT, "data", "INRIAPerson")
SAVE_DIR = os.path.join(ROOT, "save", "demo")
os.makedirs(SAVE_DIR, exist_ok=True)

BATCH_SIZE = 2          # small batch for the demo
NUM_ITERS = 5           # number of optimization steps (kept short)
PATCH_SIZE = (300,300)  # patch size (ph, pw); the author notes this as the paper's example size
PATCH_INIT_STD = 0.1    # standard deviation of the initial patch noise
PATCH_POS = (100, 100)  # top-left corner (x_start, y_start) - demo placeholder
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# -----------------------
# 辅助函数
# -----------------------
def tensor_to_pil(img_tensor: torch.Tensor):
    """Convert a [3,H,W] or [B,3,H,W] tensor to a PIL image.

    For a 4-D input only the first item of the batch is converted. Values are
    clamped to [0, 1] before scaling to 8-bit, so the input is expected to be
    an un-normalized image; undo any Normalize transform before calling.
    """
    single = img_tensor[0] if img_tensor.dim() == 4 else img_tensor
    clipped = single.detach().cpu().clamp(0,1)
    # CHW -> HWC, then scale to 0..255 and truncate to uint8.
    hwc = clipped.numpy().transpose(1,2,0)
    pixels = (hwc * 255).astype(np.uint8)
    return Image.fromarray(pixels)

def apply_patch_to_image_batch(images: torch.Tensor, patch: torch.Tensor, pos=(100,100)):
    """
    Paste one patch at a fixed position into every image of a batch.

    The pasted region is defined by the patch geometry, not by its pixel
    values, so patch pixels equal to 0 are pasted too and still receive
    gradient.

    Args:
        images: [B,3,H,W].
        patch:  [1,3,ph,pw] or [3,ph,pw], in the same value scale as ``images``.
        pos: (x_start, y_start) of the patch's top-left corner.

    Returns:
        New tensor [B,3,H,W]; ``images`` is not modified.

    Raises:
        ValueError: if the patch shape is unsupported or the patch does not
            lie fully inside the image at ``pos``.
    """
    B, C, H, W = images.shape
    if patch.dim() == 4 and patch.shape[0] == 1:
        p = patch[0]
    elif patch.dim() == 3:
        p = patch
    else:
        raise ValueError("patch shape expect [1,3,ph,pw] or [3,ph,pw]")

    ph, pw = p.shape[1], p.shape[2]
    x0, y0 = pos
    if x0 < 0 or y0 < 0 or x0 + pw > W or y0 + ph > H:
        raise ValueError("patch does not fit in image at given pos")

    # Zero-pad the patch and a matching all-ones mask to the full image size.
    pad = (x0, W - x0 - pw, y0, H - y0 - ph)  # (left, right, top, bottom)
    padded = torch.nn.functional.pad(p, pad)                       # [3,H,W]
    mask = torch.nn.functional.pad(torch.ones_like(p[:1]), pad)    # [1,H,W]

    # Out-of-place blend; broadcasting expands patch and mask over the batch.
    return images * (1.0 - mask) + padded.unsqueeze(0) * mask

# -----------------------
# 加载数据（demo：只取 Train 中少量）
# -----------------------
print("加载 INRIAPerson dataloader...")
dataloader = get_inria_dataloader(DATA_ROOT, split="Train", batch_size=BATCH_SIZE, num_workers=0)

# -----------------------
# Load the model
# -----------------------
print("加载 DETR-R50 模型（demo load）...")
model = load_detr_r50()  # provided by tmm.py
# NOTE(review): the model is put in train() mode although its weights are
# frozen below; if it contains dropout the forward pass is stochastic - confirm
# this is intended.
model = model.to(DEVICE).train()
# Freeze the model parameters (only the patch is optimized)
for p in model.parameters():
    p.requires_grad = False

# -----------------------
# Initialize TMM (registers hooks on the model)
# -----------------------
tmm = TransformerMaskingMatrix(num_enc_layers=6, num_dec_layers=6, p_base=0.2, sampling_strategy='categorical', device=DEVICE)
tmm.register_hooks(model)
# Safety: start from an empty gradient history
tmm.reset_grad_history()

# -----------------------
# Initialize GSE
# -----------------------
gse = GradientSelfEnsemble(model=model, device=DEVICE)

# -----------------------
# Initialize the loss (BlackBoxLoss)
# -----------------------
loss_fn = BlackBoxLoss(gse=gse, target_class=1,
                       detection_weight=1.0,
                       tv_weight=1e-3,
                       l2_weight=0.0,            # L2 term disabled
                       layer_aggregation='per_layer_loss',
                       use_sigmoid_for_binary=False,  # DETR has multi-class logits
                       device=DEVICE)

# -----------------------
# Initialize the patch and the optimizer
# -----------------------
# The patch lives in the [0,1] color range (same as the dataset transform output)
ph, pw = PATCH_SIZE
patch = torch.randn(1, 3, ph, pw, device=DEVICE) * PATCH_INIT_STD + 0.5
patch = patch.clamp(0.0, 1.0)
# Turn the patch into the single trainable leaf tensor.
patch.requires_grad_(True)
optimizer = torch.optim.Adam([patch], lr=0.005)

# -----------------------
# Demo 训练 loop（短循环）
# -----------------------
print("开始 demo 训练循环（小规模）...")
it = 0
for epoch in range(1):
    for batch_idx, (imgs, boxes_list) in enumerate(dataloader):
        if it >= NUM_ITERS:
            break
        imgs = imgs.to(DEVICE)              # [B,3,H,W]
        # ensure in 0..1 range (dataset ToTensor should have done this)
        imgs = imgs.clamp(0,1)

        # 1) construct patched images by pasting the patch at a fixed position.
        # F.pad takes (left, right, top, bottom) for the last two axes.
        pad = (PATCH_POS[0], imgs.shape[-1]-PATCH_POS[0]-pw, PATCH_POS[1], imgs.shape[-2]-PATCH_POS[1]-ph)
        padded_patch = torch.nn.functional.pad(patch, pad)  # [1,3,H,W]
        # The mask marks the patch *region*. Deriving it from `padded_patch > 0`
        # would drop patch pixels that are exactly 0 (common after clamping),
        # letting the image show through and cutting their gradient.
        mask = torch.nn.functional.pad(torch.ones_like(patch[:, :1]), pad)  # [1,1,H,W]
        # non-inplace combination; the patch is expanded over the batch
        patched_imgs = imgs * (1.0 - mask) + padded_patch.expand(imgs.shape[0], -1, -1, -1) * mask

        # 2) forward + compute loss via GSE and loss_fn
        # NOTE: tmm hooks are active and will inject masks during model.forward
        loss_dict = loss_fn(imgs, patched_imgs, patch_tensor=patch)
        total_loss = loss_dict['total_loss']

        # 3) backward -> optimizer step
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # keep the patch inside the valid pixel range after the update
        with torch.no_grad():
            patch.clamp_(0.0, 1.0)

        # 4) save debug/visualization for the first sample in batch:
        # original, patched, and the patch itself
        orig = imgs[0].detach().cpu()
        patched0 = patched_imgs[0].detach().cpu()
        single_patch = patch[0].detach().cpu()

        save_image(orig, os.path.join(SAVE_DIR, f"iter_{it+1}_orig.png"))
        save_image(patched0, os.path.join(SAVE_DIR, f"iter_{it+1}_patched.png"))
        save_image(single_patch, os.path.join(SAVE_DIR, f"iter_{it+1}_patch.png"))

        # print status
        print(f"[iter {it+1}] total_loss={total_loss.item():.6f} | det_loss={loss_dict['det_loss'].item():.6f} | tv={loss_dict['tv_loss'].item():.6f} | l2={loss_dict['l2_loss'].item():.6f} | grad_history_layers={len(tmm.grad_history)}")

        it += 1

    if it >= NUM_ITERS:
        break

# -----------------------
# Remove the TMM hooks
# -----------------------
tmm.remove_hooks()
print("DEMO 完成。可视化图像已保存在：", SAVE_DIR)

加载 INRIAPerson dataloader...
加载 DETR-R50 模型（demo load）...


Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
开始 demo 训练循环（小规模）...
[iter 1] total_loss=0.058981 | det_loss=0.058755 | tv=0.225636 | l2=0.260113 | grad_history_layers=11
[iter 2] total_loss=0.016582 | det_loss=0.016356 | tv=0.225856 | l2=0.260133 | grad_history_layers=11
[iter 3] total_loss=0.023158 | det_loss=0.022932 | tv=0.226138 | l2=0.260166 | grad_history_layers=11
[iter 4] total_loss=0.060522 | det_loss=0.060296 | tv=0.226394 | l2=0.260196 | grad_history_layers=11
[iter 5] total_loss=0.009781 | det_loss=0.009554 | tv=0.226566 | l2=0.260223 | grad_history_layers=11
✅ TMM已移除所有hook
DEMO 完成。可视化图像已保存在： /opt/data/private/BlackBox/save/demo


# 添加功能GPT train：

# 替换dataloader：

In [3]:
# %%writefile inria_dataloader.py
# 文件名：inria_dataloader.py
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image
import json

class INRIAPersonDataset(Dataset):
    """Dataset over the pre-parsed INRIA Person annotations.

    Annotations are read from ``<data_root>/inria_<split>_annotations.json``.
    Each entry is expected to provide an ``image_path`` and a ``boxes`` field.
    """

    def __init__(self, data_root, split="Train", augment=True, disable_random_aug=False):
        """
        Args:
            data_root: directory containing the annotation JSON file.
            split: split name used in the annotation file name.
            augment: request random augmentation (only honored for "Train").
            disable_random_aug: force-disable random augmentation.
        """
        self.split = split
        self.augment = augment
        self.disable_random_aug = disable_random_aug
        # Load the parsed annotations; the context manager closes the file
        # deterministically instead of leaving the handle to the GC.
        annotation_path = os.path.join(data_root, f"inria_{split}_annotations.json")
        with open(annotation_path, "r", encoding="utf-8") as f:
            self.annotations = json.load(f)
        self.transform = self._build_transform()

    def _build_transform(self):
        """Build the image transform: optional random augmentation, then resize + ToTensor."""
        base_transform = [T.Resize((640, 640)),  # fixed model input size
                         T.ToTensor()]
        use_augmentation = (
            not self.disable_random_aug and self.augment and self.split == "Train"
        )
        if not use_augmentation:
            return T.Compose(base_transform)
        # NOTE(review): flip/rotation are applied to the image only; the boxes
        # returned by __getitem__ are not transformed accordingly.
        augment_transform = [
            T.RandomHorizontalFlip(p=0.5),
            T.RandomRotation(degrees=10),
            T.ColorJitter(brightness=0.2, contrast=0.2)
        ]
        return T.Compose(augment_transform + base_transform)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(img_tensor, boxes)`` for annotation ``idx``."""
        sample = self.annotations[idx]
        img = Image.open(sample["image_path"]).convert("RGB")
        img_tensor = self.transform(img)
        # Boxes are returned as stored in the annotation file (float32).
        boxes = torch.tensor(sample["boxes"], dtype=torch.float32)
        return img_tensor, boxes

def inria_collate_fn(batch):
    """Collate ``(image, boxes)`` samples: stack the images, keep boxes as a list.

    Defined at module level (instead of an inline lambda) so it can be pickled
    when DataLoader workers are started with the spawn method.
    """
    images = torch.stack([sample[0] for sample in batch])
    boxes = [sample[1] for sample in batch]
    return images, boxes


def get_inria_dataloader(data_root, split="Train", batch_size=8, num_workers=1, disable_random_aug=False):
    """Build a DataLoader over ``INRIAPersonDataset``.

    Augmentation and shuffling are enabled only for the "Train" split;
    ``disable_random_aug`` switches the random augmentation off even there.
    Memory pinning is enabled when CUDA is available.

    Returns:
        DataLoader yielding ``(images [B,3,H,W], list of per-image box tensors)``.
    """
    is_train = (split == "Train")
    dataset = INRIAPersonDataset(data_root, split=split, augment=is_train,
                                 disable_random_aug=disable_random_aug)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=is_train,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=inria_collate_fn
    )

Overwriting inria_dataloader.py


# 替换train（旧逻辑错误）：gpt

In [13]:
# /opt/data/private/BlackBox/train.py
import os
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image, draw_bounding_boxes
from torchvision.ops import box_convert, nms
from torch.nn.functional import interpolate
from PIL import Image
import numpy as np

from inria_dataloader import get_inria_dataloader
from tmm import TransformerMaskingMatrix, load_detr_r50, NestedTensor
from gse import GradientSelfEnsemble
from loss import BlackBoxLoss

# -----------------------
# Config (可调整)
# -----------------------
# Filesystem layout: dataset and output directories live under ROOT.
ROOT = "/opt/data/private/BlackBox"
DATA_ROOT = os.path.join(ROOT, "data", "INRIAPerson")
SAVE_DIR = os.path.join(ROOT, "save", "demo")
os.makedirs(SAVE_DIR, exist_ok=True)

BATCH_SIZE = 2
NUM_ITERS = 300

PATCH_BASE = 300               # initial global patch size (square)
PATCH_INIT_STD = 0.1
PATCH_RATIO = 0.15             # T-SEA recommended ratio
MIN_PATCH_PX = 16               # minimum patch side in px (for production use 1; debug: set to 8/16)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_INPUT_H, MODEL_INPUT_W = 640, 640  # dataloader resize (H, W)
TARGET_CLASS_IDX = 1           # class index used as the attack target
SCORE_THRESH = 0.5             # keep candidates above this
FALLBACK_TO_TOP = True
FALLBACK_SCORE_THRESH = 0.2    # only fallback top-1 if its score >= this
IOU_NMS_THRESH = 0.5           # NMS IoU threshold
MIN_BOX_SIDE = 5               # ignore boxes with width or height < this (px)

# -----------------------
# Helpers
# -----------------------
def detach_cpu(img: torch.Tensor):
    """Detach ``img`` from autograd, move it to the CPU and clamp it to [0, 1]."""
    on_cpu = img.detach().cpu()
    return torch.clamp(on_cpu, 0, 1)

def draw_boxes_on_tensor(img_tensor: torch.Tensor, boxes_xyxy_cpu: torch.Tensor):
    """
    Draw red boxes onto an image tensor.

    Args:
        img_tensor: [3,H,W] float image on CPU.
        boxes_xyxy_cpu: [N,4] CPU float boxes as xmin,ymin,xmax,ymax, or None.

    Returns:
        ``img_tensor`` itself when there is nothing to draw (no boxes, or none
        with positive width and height after clipping to the image); otherwise
        a new float image in 0..1 with the boxes drawn.
    """
    if boxes_xyxy_cpu is None or boxes_xyxy_cpu.numel() == 0:
        return img_tensor

    height, width = img_tensor.shape[1], img_tensor.shape[2]
    # Work on a copy so the caller's boxes stay untouched.
    clipped = boxes_xyxy_cpu.clone()
    clipped[:, [0,2]] = clipped[:, [0,2]].clamp(0, width-1)
    clipped[:, [1,3]] = clipped[:, [1,3]].clamp(0, height-1)

    # Keep only boxes that still have positive width and height.
    has_area = (clipped[:,2] > clipped[:,0]) & (clipped[:,3] > clipped[:,1])
    clipped = clipped[has_area]
    if clipped.shape[0] == 0:
        return img_tensor

    canvas = (img_tensor * 255).byte()
    drawn = draw_bounding_boxes(canvas, boxes=clipped.to(torch.int64), colors="red", width=2)
    return drawn.float() / 255.0

def detr_boxes_to_xyxy_pixel(pred_boxes):
    """
    Convert cx,cy,w,h boxes to xyxy pixel boxes on the CPU.

    Args:
        pred_boxes: [Q,4] boxes as cx,cy,w,h. When the largest value is at
            most 1.01 the boxes are treated as normalized and scaled by the
            model input size (MODEL_INPUT_W / MODEL_INPUT_H); otherwise they
            are taken as absolute pixel values.

    Returns:
        [Q,4] CPU tensor as xmin,ymin,xmax,ymax. The input is not modified.
    """
    boxes = pred_boxes.clone()
    looks_normalized = boxes.max() <= 1.01
    if looks_normalized:
        # x-like columns (cx, w) scale with the width, y-like (cy, h) with the height.
        scale = boxes.new_tensor([MODEL_INPUT_W, MODEL_INPUT_H, MODEL_INPUT_W, MODEL_INPUT_H])
        boxes = boxes * scale
    corners = box_convert(boxes, in_fmt='cxcywh', out_fmt='xyxy')
    return corners.cpu()

def paste_patch_centered(base_img: torch.Tensor, patch_tensor: torch.Tensor, center_xy: tuple):
    """
    Paste a patch onto an image so that the patch is centred on ``center_xy``.

    base_img: [3,H,W] float on device
    patch_tensor: [1,3,ph,pw] or [3,ph,pw] on same device
    center_xy: (cx, cy) pixel coords (float)
    returns new image [3,H,W] (device) with patch pasted (the input image is
    left untouched)

    The part of the patch that falls outside the image is cropped away; if
    nothing overlaps, a plain copy of the image is returned. Only patch values
    strictly greater than 0 replace the image; elsewhere the image shows through.
    Raises ValueError for any other patch shape.
    """
    if patch_tensor.dim() == 3:
        patch_chw = patch_tensor
    elif patch_tensor.dim() == 4 and patch_tensor.shape[0] == 1:
        patch_chw = patch_tensor[0]
    else:
        raise ValueError("invalid patch shape")

    patch_h, patch_w = patch_chw.shape[1], patch_chw.shape[2]
    img_h, img_w = base_img.shape[1], base_img.shape[2]

    # Top-left corner of the (unclipped) patch in image coordinates.
    left = int(round(center_xy[0])) - patch_w // 2
    top = int(round(center_xy[1])) - patch_h // 2

    # Destination rectangle clipped to the image bounds.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right = min(left + patch_w, img_w)
    dst_bottom = min(top + patch_h, img_h)
    vis_w = dst_right - dst_left
    vis_h = dst_bottom - dst_top
    if vis_w <= 0 or vis_h <= 0:
        return base_img.clone()

    # Matching source window inside the patch (shifted when the patch sticks
    # out past the left / top edge).
    src_left, src_top = dst_left - left, dst_top - top
    visible = patch_chw[:, src_top:src_top + vis_h, src_left:src_left + vis_w]

    # 1 where the patch is strictly positive, 0 where the image is kept.
    opaque = (visible > 0).float()

    result = base_img.clone()
    region = result[:, dst_top:dst_bottom, dst_left:dst_right]
    result[:, dst_top:dst_bottom, dst_left:dst_right] = region * (1.0 - opaque) + visible * opaque
    return result

# -----------------------
# Data and model init
# -----------------------
# Training data: the "Train" split, loaded in the main process (num_workers=0).
# disable_random_aug=True presumably switches off the loader's random
# augmentation -- confirm against inria_dataloader.
dataloader = get_inria_dataloader(DATA_ROOT, split="Train", batch_size=BATCH_SIZE, num_workers=0, disable_random_aug=True)

# Detector under attack: put in eval mode with every weight frozen, so the
# patch is the only tensor that can receive updates.
model = load_detr_r50().to(DEVICE)
model.eval()
for p in model.parameters():
    p.requires_grad = False

# TMM attaches its hooks to the model here; the training loop below removes
# and re-registers them around the "clean" detection passes.
tmm = TransformerMaskingMatrix(num_enc_layers=6, num_dec_layers=6, p_base=0.2, sampling_strategy='categorical', device=DEVICE)
tmm.register_hooks(model)
tmm.reset_grad_history()

# Loss: detection term + TV regulariser (weight 1e-3); the L2 term is disabled (weight 0).
gse = GradientSelfEnsemble(model=model, device=DEVICE)
loss_fn = BlackBoxLoss(gse=gse, target_class=TARGET_CLASS_IDX,
                       detection_weight=1.0, tv_weight=1e-3, l2_weight=0.0,
                       layer_aggregation='per_layer_loss', use_sigmoid_for_binary=False,
                       device=DEVICE)

# global square patch
# Initialised as Gaussian noise around 0.7, clamped to [0, 1]; requires_grad is
# switched on only after the clamp so that `patch` is a leaf tensor the
# optimizer can update.
patch = torch.randn(1, 3, PATCH_BASE, PATCH_BASE, device=DEVICE) * PATCH_INIT_STD + 0.7
patch = patch.clamp(0.0, 1.0)
patch.requires_grad_(True)
optimizer = torch.optim.Adam([patch], lr=0.005)

# -----------------------
# Training loop with NMS & fallback control
# -----------------------
# Demo optimisation loop. Each iteration:
#   A) runs the detector without TMM hooks to pick target boxes,
#   B) re-enables TMM and pastes the resized patch at the centre of every box,
#   C) evaluates the loss and takes one Adam step on the patch,
#   D) re-runs the detector without hooks on the patched batch for visualisation.
# Stops after NUM_ITERS iterations (NUM_ITERS, PATCH_RATIO and SAVE_DIR are
# defined outside this section).
print("Start training with NMS de-dup and controlled fallback...")
it = 0
for epoch in range(1):
    for batch_idx, (imgs, _) in enumerate(dataloader):
        if it >= NUM_ITERS:
            break
        imgs = imgs.to(DEVICE).clamp(0,1)
        B = imgs.shape[0]

        # --- STEP A: clean DETR detection (remove TMM hooks)
        tmm.remove_hooks()
        model.eval()
        with torch.no_grad():
            # NOTE(review): any exception from the plain-tensor call triggers the
            # NestedTensor retry, which can hide unrelated errors.
            try:
                det_out = model(imgs)
            except Exception:
                det_out = model(NestedTensor(imgs))

        batch_boxes_all = []  # list of CPU tensors per image
        for bi in range(B):
            logits = det_out['pred_logits'][bi]  # [Q,C]
            boxes = det_out['pred_boxes'][bi]    # [Q,4]
            probs = torch.softmax(logits, dim=-1)
            cls_scores = probs[..., TARGET_CLASS_IDX]  # [Q]

            # candidate indices above SCORE_THRESH
            keep_idx = (cls_scores > SCORE_THRESH).nonzero(as_tuple=False).squeeze(1) if (cls_scores > SCORE_THRESH).any() else torch.tensor([], dtype=torch.long, device=cls_scores.device)

            # fallback logic: only fallback if highest-scoring query >= FALLBACK_SCORE_THRESH
            if keep_idx.numel() == 0 and FALLBACK_TO_TOP:
                top_score, top_idx = torch.max(cls_scores, dim=0)
                if top_score.item() >= FALLBACK_SCORE_THRESH:
                    keep_idx = top_idx.unsqueeze(0)
                else:
                    keep_idx = torch.tensor([], dtype=torch.long, device=cls_scores.device)

            # No usable query for this image: record an empty box list so that
            # batch_boxes_all stays index-aligned with the batch.
            if keep_idx.numel() == 0:
                batch_boxes_all.append(torch.empty((0,4), dtype=torch.float32))
                continue

            sel_boxes = boxes[keep_idx]  # [K,4] (cxcywh, may be normalized)
            sel_scores = cls_scores[keep_idx].detach()  # keep on same device as boxes

            # convert to xyxy pixel coords on CPU for NMS (we can nms on device too)
            sel_xyxy = detr_boxes_to_xyxy_pixel(sel_boxes.detach().cpu())  # CPU
            # filter out tiny boxes before NMS
            widths = (sel_xyxy[:,2] - sel_xyxy[:,0])
            heights = (sel_xyxy[:,3] - sel_xyxy[:,1])
            large_mask = (widths >= MIN_BOX_SIDE) & (heights >= MIN_BOX_SIDE)
            if large_mask.sum() == 0:
                batch_boxes_all.append(torch.empty((0,4), dtype=torch.float32))
                continue
            sel_xyxy = sel_xyxy[large_mask]
            sel_scores_cpu = sel_scores.detach().cpu()[large_mask]

            # NMS: needs tensors on same device; we will run on CPU
            try:
                keep_nms = nms(sel_xyxy, sel_scores_cpu, IOU_NMS_THRESH)
            except Exception:
                # if nms complains about dtype/device, move to CPU
                # (both tensors are already on the CPU at this point, so this
                # retry repeats the same call)
                keep_nms = nms(sel_xyxy.cpu(), sel_scores_cpu.cpu(), IOU_NMS_THRESH)

            sel_xyxy_nms = sel_xyxy[keep_nms]
            batch_boxes_all.append(sel_xyxy_nms)  # CPU tensor

        # Save orig and orig boxes (visualize first sample)
        save_image(detach_cpu(imgs[0]), os.path.join(SAVE_DIR, f"iter_{it+1}_orig.png"))
        boxes0_cpu = batch_boxes_all[0] if len(batch_boxes_all)>0 else torch.empty((0,4))
        img_orig_v = draw_boxes_on_tensor(detach_cpu(imgs[0]), boxes0_cpu)
        save_image(img_orig_v, os.path.join(SAVE_DIR, f"iter_{it+1}_orig_boxes.png"))

        # --- STEP B: enable TMM and build patched images by pasting patch on all selected boxes
        tmm.register_hooks(model)
        patched = imgs.clone()
        for bi in range(B):
            sel_boxes_cpu = batch_boxes_all[bi]  # CPU [K,4]
            if sel_boxes_cpu.numel() == 0:
                continue
            sel_boxes_dev = sel_boxes_cpu.to(DEVICE)
            for box in sel_boxes_dev:
                xmin, ymin, xmax, ymax = box.tolist()
                box_w = max(int(xmax - xmin), 1)
                box_h = max(int(ymax - ymin), 1)
                # Patch side is PATCH_RATIO of the box's shorter side, but never
                # smaller than MIN_PATCH_PX.
                short = min(box_w, box_h)
                side = max(MIN_PATCH_PX, int(round(short * PATCH_RATIO)))
                side = max(1, side)
                patch_resized = interpolate(patch, size=(side, side), mode='bilinear', align_corners=False)
                cx = (xmin + xmax) / 2.0
                cy = (ymin + ymax) / 2.0
                # Writing the pasted image back into `patched` is what connects
                # the batch to `patch` for backpropagation.
                patched_img = paste_patch_centered(patched[bi], patch_resized, center_xy=(cx, cy))
                patched[bi] = patched_img

        # --- STEP C: compute loss and update patch
        loss_dict = loss_fn(imgs, patched, patch_tensor=patch)
        total_loss = loss_dict['total_loss']
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        # Keep the patch a valid image after the optimizer step.
        with torch.no_grad():
            patch.clamp_(0.0, 1.0)

        # Save patched and patch (first sample)
        save_image(detach_cpu(patched[0]), os.path.join(SAVE_DIR, f"iter_{it+1}_patched.png"))
        save_image(patch[0].detach().cpu(), os.path.join(SAVE_DIR, f"iter_{it+1}_patch.png"))

        # --- STEP D: optional clean DETR on patched image and save boxes for comparison
        # Note: `patched` was built in STEP B, i.e. with the patch values from
        # before this iteration's optimizer step.
        tmm.remove_hooks()
        model.eval()
        with torch.no_grad():
            try:
                det_out_p = model(patched)
            except Exception:
                det_out_p = model(NestedTensor(patched))

        # use first image predictions to visualize
        # (same threshold / fallback / tiny-box / NMS pipeline as STEP A)
        logits_p = det_out_p['pred_logits'][0]
        boxes_p = det_out_p['pred_boxes'][0]
        probs_p = torch.softmax(logits_p, dim=-1)
        cls_scores_p = probs_p[..., TARGET_CLASS_IDX]
        keep_idx_p = (cls_scores_p > SCORE_THRESH).nonzero(as_tuple=False).squeeze(1) if (cls_scores_p > SCORE_THRESH).any() else torch.tensor([], dtype=torch.long, device=cls_scores_p.device)
        if keep_idx_p.numel() == 0 and FALLBACK_TO_TOP:
            top_score_p, top_idx_p = torch.max(cls_scores_p, dim=0)
            if top_score_p.item() >= FALLBACK_SCORE_THRESH:
                keep_idx_p = top_idx_p.unsqueeze(0)
            else:
                keep_idx_p = torch.tensor([], dtype=torch.long, device=cls_scores_p.device)

        if keep_idx_p.numel() == 0:
            boxes_p_xyxy = torch.empty((0,4))
        else:
            sel_boxes_p = boxes_p[keep_idx_p]
            sel_xyxy_p = detr_boxes_to_xyxy_pixel(sel_boxes_p.detach().cpu())
            # filter tiny + NMS for patched boxes (cpu)
            widths_p = (sel_xyxy_p[:,2] - sel_xyxy_p[:,0])
            heights_p = (sel_xyxy_p[:,3] - sel_xyxy_p[:,1])
            large_mask_p = (widths_p >= MIN_BOX_SIDE) & (heights_p >= MIN_BOX_SIDE)
            if large_mask_p.sum() == 0:
                boxes_p_xyxy = torch.empty((0,4))
            else:
                sel_xyxy_p = sel_xyxy_p[large_mask_p]
                sel_scores_p = cls_scores_p[keep_idx_p].detach().cpu()[large_mask_p]
                keep_nms_p = nms(sel_xyxy_p, sel_scores_p, IOU_NMS_THRESH)
                boxes_p_xyxy = sel_xyxy_p[keep_nms_p]

        img_patched_v = draw_boxes_on_tensor(detach_cpu(patched[0]), boxes_p_xyxy)
        save_image(img_patched_v, os.path.join(SAVE_DIR, f"iter_{it+1}_patched_boxes.png"))

        print(f"[iter {it+1}] total_loss={total_loss.item():.6f} | det_loss={loss_dict['det_loss'].item():.6f} | tv={loss_dict['tv_loss'].item():.6f}")
        it += 1

    # The inner loop only breaks out of the current epoch; stop the outer loop too.
    if it >= NUM_ITERS:
        break

# cleanup
tmm.remove_hooks()
print("Done. Visuals saved to", SAVE_DIR)

Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
Start training with NMS de-dup and controlled fallback...
✅ TMM已移除所有hook
✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
✅ TMM已移除所有hook
[iter 1] total_loss=0.143936 | det_loss=0.143711 | tv=0.225462
✅ TMM已移除所有hook
✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
✅ TMM已移除所有hook
[iter 2] total_loss=0.101182 | det_loss=0.100963 | tv=0.219129
✅ TMM已移除所有hook
✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
✅ TMM已移除所有hook
[iter 3] total_loss=0.099098 | det_loss=0.098884 | tv=0.214126
✅ TMM已移除所有hook
✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
✅ TMM已移除所有hook
[iter 4] total_loss=0.061028 | det_loss=0.060819 | tv=0.209241
✅ TMM已移除所有hook
✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
✅ TMM已移除所有hook
[iter 5] total_loss=0.097850 | det_loss=0.097645 | tv=

# 正式train非demo：GPT

## 需要先修改tmm：

In [2]:
%%writefile tmm.py
import torch
import torch.nn as nn
from typing import List, Optional, Dict, Literal
from torch.hub import load_state_dict_from_url


class NestedTensor:
    """Batched image tensor bundled with a boolean mask.

    Exposes the plural ``tensors`` attribute plus ``mask``, ``decompose()`` and
    ``device``. When no mask is supplied an all-False mask of shape
    ``(tensors.shape[0], tensors.shape[2], tensors.shape[3])`` is created on
    the same device as ``tensors``.
    """

    def __init__(self, tensors: torch.Tensor, mask: Optional[torch.Tensor] = None):
        if mask is None:
            # Default mask: one boolean per spatial position, nothing masked.
            default_shape = (tensors.shape[0], tensors.shape[2], tensors.shape[3])
            mask = torch.zeros(default_shape, dtype=torch.bool, device=tensors.device)
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        """Return the ``(tensors, mask)`` pair."""
        return (self.tensors, self.mask)

    @property
    def device(self):
        """Device on which ``tensors`` lives."""
        return self.tensors.device


class TransformerMaskingMatrix(nn.Module):
    """Hook-based input masking for the encoder/decoder layers of a transformer.

    ``register_hooks`` attaches two hooks to each of the first
    ``num_enc_layers`` encoder layers and ``num_dec_layers`` decoder layers of
    ``model.transformer``:

    * a forward pre-hook that replaces the layer's first positional input with
      a randomly masked copy (each element is zeroed with probability
      ``p_base``); gradients still flow through the surviving elements;
    * a full backward hook that stores ``|grad|`` of the layer input in
      ``grad_history`` under the key ``"enc_<i>"`` / ``"dec_<i>"``.

    NOTE(review): ``sampling_strategy`` is validated and printed and
    ``_categorical_mask_sampling`` is available, but the forward pre-hook does
    not consult either of them, nor ``grad_history``; the mask it applies is an
    independent uniform draw per element. Confirm whether gradient-guided
    masking was intended.
    """

    def __init__(
        self,
        num_enc_layers: int = 6,
        num_dec_layers: int = 6,
        p_base: float = 0.2,
        sampling_strategy: Literal['categorical', 'bernoulli'] = 'categorical',
        device: Optional[torch.device] = None
    ):
        """Store the configuration.

        Raises:
            ValueError: if ``sampling_strategy`` is not 'categorical' or
                'bernoulli', or if ``p_base`` is not a probability in [0, 1].
        """
        super().__init__()
        self.num_enc_layers = num_enc_layers
        self.num_dec_layers = num_dec_layers
        self.p_base = p_base
        self.sampling_strategy = sampling_strategy
        self.device = device if device is not None else torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if self.sampling_strategy not in ['categorical', 'bernoulli']:
            raise ValueError(f"采样策略仅支持'categorical'和'bernoulli'，当前为{self.sampling_strategy}")
        if not 0.0 <= self.p_base <= 1.0:
            # p_base is used both as a drop probability and as a fraction of
            # elements to mask, so anything outside [0, 1] is meaningless.
            raise ValueError(f"p_base must be within [0, 1], got {self.p_base}")

        # layer key -> absolute gradient of that layer's input (detached copy)
        self.grad_history: Dict[str, torch.Tensor] = {}
        # handles of every hook currently installed by this object
        self.hooks: List[torch.utils.hooks.RemovableHandle] = []

    def _categorical_mask_sampling(self, grad_abs: torch.Tensor) -> torch.Tensor:
        """Sample a 0/1 mask that zeroes ``p_base`` of the elements.

        Elements are drawn without replacement with probability proportional
        to ``grad_abs`` (uniformly when the gradient sum is ~0). At least one
        element is masked for non-empty input. Returns a tensor of ones/zeros
        with the shape of ``grad_abs``.
        """
        grad_flat = grad_abs.flatten()
        num_elements = grad_flat.numel()
        if num_elements == 0:
            # Nothing to sample from; multinomial would reject an empty distribution.
            return torch.ones_like(grad_abs)

        total_grad = grad_flat.sum()
        if total_grad < 1e-8:
            prob_dist = torch.ones_like(grad_flat) / num_elements
        else:
            prob_dist = grad_flat / total_grad

        num_to_mask = min(num_elements, max(1, int(self.p_base * num_elements)))
        if int((prob_dist > 0).sum()) < num_to_mask:
            # Sampling without replacement needs at least num_to_mask non-zero
            # categories; give zero-gradient elements a tiny weight so the
            # remaining picks are spread uniformly over them.
            prob_dist = prob_dist + torch.finfo(prob_dist.dtype).eps

        indices = torch.multinomial(prob_dist, num_to_mask, replacement=False)
        mask_flat = torch.ones_like(grad_flat)
        mask_flat.scatter_(0, indices, 0.0)

        return mask_flat.view(grad_abs.shape).contiguous()

    def _apply_mask_to_input(self, input_tensor: torch.Tensor, layer_key: str) -> torch.Tensor:
        """Return a copy of ``input_tensor`` with random elements set to zero.

        Each element is kept when a fresh uniform sample exceeds ``p_base``.
        Supports 4-D input (masked per sample over the flattened spatial
        positions and channels) and 3-D input (masked per slice along dim 0).
        ``layer_key`` is accepted for the hook interface but is not used.

        Raises:
            ValueError: for inputs that are neither 3-D nor 4-D.
        """
        input_dim = input_tensor.dim()
        dtype, device = input_tensor.dtype, input_tensor.device

        if input_dim == 4:  # (B, C, H, W)
            B, C, H, W = input_tensor.shape
            input_seq = input_tensor.flatten(2).permute(2, 0, 1).contiguous()  # (seq_len, B, C)
            seq_len = input_seq.shape[0]
            masked_chunks = []
            for b in range(B):
                # One independent mask per sample, drawn in batch order.
                keep = (torch.rand(seq_len, 1, C, device=device) > self.p_base).to(dtype)
                masked_chunks.append(input_seq[:, b:b + 1, :] * keep)
            masked_seq = torch.cat(masked_chunks, dim=1)  # (seq_len, B, C)
            return masked_seq.permute(1, 2, 0).reshape(B, C, H, W).contiguous()

        if input_dim == 3:
            # NOTE(review): the original comment labelled this layout (B, S, C);
            # the mask is element-wise, so the result does not depend on which
            # axis is the batch axis -- confirm the layout against the model.
            dim0, dim1, dim2 = input_tensor.shape
            masked_slices = []
            for i in range(dim0):
                keep = (torch.rand(dim1, dim2, device=device) > self.p_base).to(dtype)
                masked_slices.append(input_tensor[i] * keep)
            return torch.stack(masked_slices, dim=0)

        raise ValueError(f"不支持的输入维度：{input_dim}")

    def _register_layer_hooks(self, layers: nn.ModuleList, prefix: str, num_layers: Optional[int] = None):
        """Install the forward-pre and backward hooks on ``layers``.

        Args:
            layers: layers to hook, in order.
            prefix: key prefix ("enc" / "dec") for ``grad_history`` entries.
            num_layers: hook only the first ``num_layers`` layers; ``None``
                hooks all of them.
        """
        for layer_idx, layer in enumerate(layers[:num_layers]):
            layer_key = f"{prefix}_{layer_idx}"

            # `key=layer_key` binds the current key at definition time; a plain
            # closure would see only the last loop value.
            def backward_hook(module, grad_in, grad_out, key=layer_key):
                if grad_in[0] is not None:
                    # Store a detached copy; this does not affect backpropagation.
                    self.grad_history[key] = grad_in[0].abs().detach().clone().contiguous()

            def forward_hook(module, args, key=layer_key):
                # Replace only the first positional argument; the rest pass through.
                input_tensor = args[0]
                return (self._apply_mask_to_input(input_tensor, key),) + args[1:]

            self.hooks.append(layer.register_full_backward_hook(backward_hook))
            self.hooks.append(layer.register_forward_pre_hook(forward_hook))

    def register_hooks(self, model: nn.Module):
        """(Re)install the hooks on ``model.transformer``.

        Existing hooks are removed first. A wrapper exposing the real model as
        ``.module`` is unwrapped. Only the first ``num_enc_layers`` /
        ``num_dec_layers`` layers are hooked, so the count printed below
        matches what is actually installed.
        """
        self.remove_hooks()
        base_model = getattr(model, 'module', model)

        assert hasattr(base_model, "transformer"), "模型必须包含transformer属性"
        assert len(base_model.transformer.encoder.layers) >= self.num_enc_layers, "encoder层数不足"
        assert len(base_model.transformer.decoder.layers) >= self.num_dec_layers, "decoder层数不足"

        self._register_layer_hooks(
            base_model.transformer.encoder.layers, prefix="enc", num_layers=self.num_enc_layers
        )
        self._register_layer_hooks(
            base_model.transformer.decoder.layers, prefix="dec", num_layers=self.num_dec_layers
        )

        print(f"✅ TMM已注册{len(self.hooks)}个hook（{self.num_enc_layers} encoder + {self.num_dec_layers} decoder）")
        print(f"✅ 采样策略：{self.sampling_strategy}（符合论文设置）")

    def remove_hooks(self):
        """Detach every hook installed by this object (safe to call repeatedly)."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []
        print("✅ TMM已移除所有hook")

    def reset_grad_history(self):
        """Forget all stored input gradients."""
        self.grad_history.clear()

    def forward(self, *args, **kwargs):
        """Not supported: masking is injected through hooks, not by calling the module."""
        raise NotImplementedError("TMM通过register_hooks()注入掩码，无需调用forward")


def load_detr_r50(device: Optional[torch.device] = None):
    """Build DETR-R50 via torch.hub, load the released weights and freeze them.

    Args:
        device: where to place the model. Defaults to CUDA when available and
            the CPU otherwise, so the function also works on CPU-only hosts.

    Returns:
        The model on ``device`` with ``requires_grad=False`` on every
        parameter. As before, the model is left in ``train()`` mode; callers
        that want inference behaviour must call ``model.eval()`` themselves.
    """
    model = torch.hub.load(
        "facebookresearch/detr:main",
        "detr_resnet50",
        pretrained=False,
        force_reload=False
    )

    weight_url = "https://dl.fbaipublicfiles.com/detr/detr-r50-e632da11.pth"
    # Load onto the CPU first so the checkpoint can be read without a GPU;
    # the model is moved to its final device below.
    checkpoint = load_state_dict_from_url(weight_url, progress=True, map_location="cpu")
    model.load_state_dict(checkpoint["model"])

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device).train()
    for param in model.parameters():
        param.requires_grad = False  # freeze the detector; only the patch is optimised

    return model


def run_blackbox_whitebox_demo():
    """Smoke-test that gradients reach the patch through the TMM-hooked DETR.

    Loads the detector, installs the TMM hooks, then runs five Adam steps on a
    random 300x300 patch blended into a random 800x800 image, minimising the
    mean sigmoid score of class index 1. Prints the loss and the number of
    recorded gradient entries per step.
    """
    # 1. Load the (frozen) detector and work on whatever device it lives on.
    print("正在加载DETR-R50模型...")
    model = load_detr_r50()
    print("✅ DETR-R50模型加载完成")
    device = next(model.parameters()).device

    # 2. Install the TMM hooks.
    tmm = TransformerMaskingMatrix(
        num_enc_layers=6,
        num_dec_layers=6,
        p_base=0.2,
        sampling_strategy='categorical',
        device=device
    )
    tmm.register_hooks(model)

    # 3. Patch (the only tensor that requires grad) and its optimizer.
    patch = torch.randn(1, 3, 300, 300, device=device, requires_grad=True)
    optimizer = torch.optim.Adam([patch], lr=0.005)

    # 4. Synthetic input image; no gradient is needed for it.
    img = torch.randn(1, 3, 800, 800, device=device)

    # The paste region never changes, so build its mask once: 1 inside the
    # 300x300 window at rows/cols 100..399, 0 elsewhere.
    mask = torch.zeros_like(img)
    mask[:, :, 100:400, 100:400] = 1.0

    # 5. Optimisation loop.
    num_steps = 5
    for step in range(num_steps):
        optimizer.zero_grad()
        tmm.reset_grad_history()

        # Zero-pad the patch to the image size (100 px left/top, 400 px
        # right/bottom) so it lines up with the mask window.
        padded_patch = torch.nn.functional.pad(patch, (100, 400, 100, 400))

        # Differentiable blend: the result depends on `patch`, so backward()
        # reaches it without any extra copies.
        patched_img = img * (1 - mask) + padded_patch * mask

        outputs = model(NestedTensor(tensors=patched_img))

        # Loss: mean sigmoid confidence of class index 1 (the target class);
        # minimising it lowers that class's score.
        pred_logits = outputs['pred_logits']
        person_confidence = torch.sigmoid(pred_logits[..., 1]).mean()
        loss = person_confidence

        loss.backward()
        optimizer.step()

        # Keep the patch inside a fixed value range; the bounds presumably
        # correspond to ImageNet-normalised pixel limits -- TODO confirm.
        with torch.no_grad():
            patch.clamp_(-2.1179, 2.6400)

        print(f"迭代{step+1}/5 | 行人置信度损失: {loss.item():.4f} | 梯度历史数: {len(tmm.grad_history)}")

    tmm.remove_hooks()
    print("\n✅ 白盒实验核心流程验证完成（梯度传播正常）")


# Script entry point: seed the CPU and CUDA random generators, then run the demo.
if __name__ == "__main__":
    torch.manual_seed(42)
    torch.cuda.manual_seed(42)
    run_blackbox_whitebox_demo()


Overwriting tmm.py


In [3]:
# %%writefile train.py
# /opt/data/private/BlackBox/train.py
import os
import math
import random
import torch
from torch.utils.data import DataLoader
from torchvision.utils import save_image, draw_bounding_boxes
from torchvision.ops import box_convert, nms
from torch.nn.functional import interpolate
from PIL import Image
import numpy as np
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

from inria_dataloader import get_inria_dataloader
from tmm import TransformerMaskingMatrix, load_detr_r50, NestedTensor
from gse import GradientSelfEnsemble
from loss import BlackBoxLoss

# -----------------------
# Config (adjustable)
# -----------------------
ROOT = "/opt/data/private/BlackBox"
DATA_ROOT = os.path.join(ROOT, "data", "INRIAPerson")
SAVE_DIR = os.path.join(ROOT, "save", "demo")
# Created at import time so every later save_image / torch.save call has a target directory.
os.makedirs(SAVE_DIR, exist_ok=True)

# training params
BATCH_SIZE = 8
NUM_EPOCHS = 10           # can be increased as needed (original note: the paper uses more iterations)
NUM_WORKERS = 4

# patch params
PATCH_SIDE = 300          # fixed global patch side (original note: strictly aligned with the paper)
PATCH_INIT_STD = 0.5
MIN_PATCH_PX = 16         # fallback minimum when resizing
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model / detection params
MODEL_INPUT_H, MODEL_INPUT_W = 640, 640  # dataloader resize (H, W)
TARGET_CLASS_IDX = 1           # column of the softmax class scores that is attacked / used to pick boxes
SCORE_THRESH = 0.5             # queries scoring above this are kept as targets
FALLBACK_TO_TOP = True         # if no query passes SCORE_THRESH, consider the single best-scoring query
FALLBACK_SCORE_THRESH = 0.2    # ...but only when its score is at least this
IOU_NMS_THRESH = 0.5           # IoU threshold passed to nms
MIN_BOX_SIDE = 5               # boxes narrower or shorter than this (px) are dropped

# loss weights (original note: defaults based on the paper's settings; adjustable)
DETECTION_WEIGHT = 1.0
TV_WEIGHT = 1e-3
NPS_WEIGHT = 0.0               # NOTE(review): not passed to BlackBoxLoss in the visible script

# EoT / augmentation switches (simple implementation)
USE_EOT = True
EOT_SCALE = (0.9, 1.1)         # range of the random resize factor
EOT_ROT_DEG = (-10, 10)        # range of the random rotation angle, degrees
EOT_BRIGHT = (0.9, 1.1)        # range of the brightness multiplier
EOT_CONTRAST = (0.9, 1.1)      # range of the contrast multiplier

# reproducibility
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# -----------------------
# Helpers
# -----------------------
def detach_cpu(img: torch.Tensor):
    """Return a detached CPU copy of ``img`` with values limited to [0, 1]."""
    on_cpu = img.detach().cpu()
    return torch.clamp(on_cpu, 0, 1)

def draw_boxes_on_tensor(img_tensor: torch.Tensor, boxes_xyxy_cpu: torch.Tensor):
    """Draw red xyxy boxes on a CPU image tensor of shape [3,H,W].

    Boxes are clipped to the image; boxes without area after clipping are
    skipped. If there is nothing to draw (None, empty, or all boxes skipped)
    the input tensor itself is returned, otherwise a new float tensor
    (values / 255). The caller's box tensor is left unchanged.
    """
    if boxes_xyxy_cpu is None:
        return img_tensor
    if boxes_xyxy_cpu.numel() == 0:
        return img_tensor

    max_y = img_tensor.shape[1] - 1
    max_x = img_tensor.shape[2] - 1

    # Clip a private copy of the boxes to the valid pixel range.
    in_bounds = boxes_xyxy_cpu.clone()
    in_bounds[:, [0, 2]] = in_bounds[:, [0, 2]].clamp(0, max_x)
    in_bounds[:, [1, 3]] = in_bounds[:, [1, 3]].clamp(0, max_y)

    wider_than_zero = in_bounds[:, 2] > in_bounds[:, 0]
    taller_than_zero = in_bounds[:, 3] > in_bounds[:, 1]
    in_bounds = in_bounds[wider_than_zero & taller_than_zero]
    if in_bounds.shape[0] == 0:
        return img_tensor

    # draw_bounding_boxes is fed a uint8 image and integer boxes here.
    as_uint8 = (img_tensor * 255).byte()
    annotated = draw_bounding_boxes(as_uint8, boxes=in_bounds.to(torch.int64), colors="red", width=2)
    return annotated.float() / 255.0

def detr_boxes_to_xyxy_pixel(pred_boxes):
    """
    Convert DETR-style boxes to corner format in pixel coordinates.

    pred_boxes: [Q,4] cx,cy,w,h (normalized 0..1 or absolute)
    returns [Q,4] xyxy in pixel coords (CPU tensor)

    Boxes are treated as normalized when their largest value is <= 1.01 and are
    then scaled by MODEL_INPUT_W / MODEL_INPUT_H; otherwise they are assumed to
    be in pixels already. An empty input yields an empty [0,4] CPU tensor
    (``max()`` on an empty tensor would raise). The input is not modified.
    """
    if pred_boxes.numel() == 0:
        return pred_boxes.new_zeros((0, 4)).cpu()

    pb = pred_boxes.clone()
    # Heuristic: only normalized coordinates can have a global maximum this small.
    if pb.max() <= 1.01:
        pb[:, 0] = pb[:, 0] * MODEL_INPUT_W
        pb[:, 1] = pb[:, 1] * MODEL_INPUT_H
        pb[:, 2] = pb[:, 2] * MODEL_INPUT_W
        pb[:, 3] = pb[:, 3] * MODEL_INPUT_H
    xyxy = box_convert(pb, in_fmt='cxcywh', out_fmt='xyxy')
    return xyxy.cpu()

def eot_transform_patch(patch_tensor: torch.Tensor):
    """
    Apply a random EoT-style perturbation to the patch.

    patch_tensor: [1,3,H,W] on DEVICE
    Applies, in this order, a random resize (factor from EOT_SCALE), a random
    rotation about the centre (angle from EOT_ROT_DEG, corners filled with 0)
    and a random intensity scaling (product of factors drawn from EOT_BRIGHT
    and EOT_CONTRAST), then clamps to [0, 1].
    returns transformed patch [1,3,h2,w2], where h2/w2 are the input H/W times
    the sampled scale. When USE_EOT is false the input is returned unchanged.

    All steps are tensor operations applied directly to ``patch_tensor`` (no
    detach or PIL round-trip), so the result stays connected to the patch.
    """
    if not USE_EOT:
        return patch_tensor

    _, _, height, width = patch_tensor.shape

    # random scale, relative to the size of the tensor actually passed in
    scale = float(np.random.uniform(EOT_SCALE[0], EOT_SCALE[1]))
    new_h = max(1, int(round(height * scale)))
    new_w = max(1, int(round(width * scale)))
    p = interpolate(patch_tensor, size=(new_h, new_w), mode='bilinear', align_corners=False)

    # random rotation around the centre (no translation, no shear)
    angle = float(np.random.uniform(EOT_ROT_DEG[0], EOT_ROT_DEG[1]))
    p = TF.affine(
        p,
        angle=angle,
        translate=[0, 0],
        scale=1.0,
        shear=[0.0, 0.0],
        interpolation=InterpolationMode.BILINEAR,
        fill=0
    )

    # brightness & contrast: both are applied as plain multiplicative factors
    b = float(np.random.uniform(EOT_BRIGHT[0], EOT_BRIGHT[1]))
    c = float(np.random.uniform(EOT_CONTRAST[0], EOT_CONTRAST[1]))
    p = torch.clamp((p * c) * b, 0.0, 1.0)
    return p

def paste_patch_via_mask(base_img: torch.Tensor, patch_tensor: torch.Tensor, center_xy: tuple):
    """
    Paste a patch onto an image, centred on ``center_xy``, by mask fusion.

    base_img: [3,H,W] float on device
    patch_tensor: [1,3,ph,pw] or [3,ph,pw] on same device
    center_xy: (cx, cy) pixel coords (float)
    returns new image [3,H,W] (device): ``base * (1 - mask) + patch * mask``,
    where ``mask`` is 1 on the rectangle covered by the patch (after clipping to
    the image) and 0 elsewhere. The input image is not modified, and the
    result stays differentiable with respect to the patch. If the patch lies
    completely outside the image, a copy of the image is returned.
    Raises ValueError for any other patch shape.
    """
    if patch_tensor.dim() == 3:
        patch_chw = patch_tensor
    elif patch_tensor.dim() == 4 and patch_tensor.shape[0] == 1:
        patch_chw = patch_tensor[0]
    else:
        raise ValueError("invalid patch shape")

    patch_h, patch_w = patch_chw.shape[1], patch_chw.shape[2]
    img_h, img_w = base_img.shape[1], base_img.shape[2]

    # Top-left corner of the unclipped patch in image coordinates.
    left = int(round(center_xy[0])) - patch_w // 2
    top = int(round(center_xy[1])) - patch_h // 2

    # Destination rectangle, clipped to the image.
    dst_left, dst_top = max(left, 0), max(top, 0)
    dst_right = min(left + patch_w, img_w)
    dst_bottom = min(top + patch_h, img_h)
    vis_w = dst_right - dst_left
    vis_h = dst_bottom - dst_top
    if vis_w <= 0 or vis_h <= 0:
        return base_img.clone()

    # Part of the patch that is actually visible inside the image.
    src_left, src_top = dst_left - left, dst_top - top
    visible = patch_chw[:, src_top:src_top + vis_h, src_left:src_left + vis_w]

    # Full-size canvas holding the visible patch at its destination.
    canvas = torch.zeros_like(base_img)
    canvas[:, dst_top:dst_bottom, dst_left:dst_right] = visible

    # Binary mask of the pasted rectangle.
    region_mask = torch.zeros_like(base_img)
    region_mask[:, dst_top:dst_bottom, dst_left:dst_right] = 1.0

    return base_img * (1.0 - region_mask) + canvas * region_mask

# -----------------------
# Data and model init
# -----------------------
# Training data: the "Train" split with NUM_WORKERS loader workers.
# disable_random_aug=False presumably leaves the loader's random augmentation
# switched on -- confirm against inria_dataloader.
dataloader = get_inria_dataloader(DATA_ROOT, split="Train", batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, disable_random_aug=False)
print("Train dataset size:", len(dataloader.dataset))

# Detector under attack: eval mode, all weights frozen, so the patch is the
# only tensor that can receive updates.
model = load_detr_r50().to(DEVICE)
model.eval()
for p in model.parameters():
    p.requires_grad = False

# TMM: register once and keep enabled for entire training (original note: strictly per the paper)
tmm = TransformerMaskingMatrix(num_enc_layers=6, num_dec_layers=6, p_base=0.2, sampling_strategy='categorical', device=DEVICE)
tmm.register_hooks(model)
tmm.reset_grad_history()

# GSE and loss
# Detection term + TV regulariser; the L2 term is disabled (weight 0).
gse = GradientSelfEnsemble(model=model, device=DEVICE)
loss_fn = BlackBoxLoss(gse=gse, target_class=TARGET_CLASS_IDX,
                       detection_weight=DETECTION_WEIGHT, tv_weight=TV_WEIGHT, l2_weight=0.0,
                       layer_aggregation='per_layer_loss', use_sigmoid_for_binary=False,
                       device=DEVICE)

# initialize patch (requires_grad=True)
# Gaussian noise around 0.7, clamped to [0, 1]; requires_grad is switched on
# only after the clamp so that `patch` is a leaf tensor the optimizer can update.
patch = torch.randn(1, 3, PATCH_SIDE, PATCH_SIDE, device=DEVICE) * PATCH_INIT_STD + 0.7
patch = patch.clamp(0.0, 1.0)
patch.requires_grad_(True)
optimizer = torch.optim.Adam([patch], lr=0.005)

print("Start training (TMM enabled during all forwards). Saving to:", SAVE_DIR)

# -----------------------
# Training loop (epoch)
# -----------------------
global_step = 0
for epoch in range(NUM_EPOCHS):
    model.eval()  # model remains eval (weights frozen)
    tmm.reset_grad_history()  # clear grad history at epoch start for stability

    for batch_idx, (imgs, _) in enumerate(dataloader):
        imgs = imgs.to(DEVICE).clamp(0,1)
        B = imgs.shape[0]

        # --- STEP: obtain detections with TMM active (we keep hooks active per-paper)
        with torch.no_grad():
            try:
                det_out = model(imgs)
            except Exception:
                det_out = model(NestedTensor(imgs))

        batch_boxes_all = []
        for bi in range(B):
            logits = det_out['pred_logits'][bi]  # [Q,C]
            boxes = det_out['pred_boxes'][bi]    # [Q,4]
            probs = torch.softmax(logits, dim=-1)
            cls_scores = probs[..., TARGET_CLASS_IDX]  # [Q]

            keep_idx = (cls_scores > SCORE_THRESH).nonzero(as_tuple=False).squeeze(1) if (cls_scores > SCORE_THRESH).any() else torch.tensor([], dtype=torch.long, device=cls_scores.device)
            if keep_idx.numel() == 0 and FALLBACK_TO_TOP:
                top_score, top_idx = torch.max(cls_scores, dim=0)
                if top_score.item() >= FALLBACK_SCORE_THRESH:
                    keep_idx = top_idx.unsqueeze(0)
                else:
                    keep_idx = torch.tensor([], dtype=torch.long, device=cls_scores.device)

            if keep_idx.numel() == 0:
                batch_boxes_all.append(torch.empty((0,4), dtype=torch.float32))
                continue

            sel_boxes = boxes[keep_idx]
            sel_scores = cls_scores[keep_idx].detach()

            sel_xyxy = detr_boxes_to_xyxy_pixel(sel_boxes.detach().cpu())
            widths = (sel_xyxy[:,2] - sel_xyxy[:,0])
            heights = (sel_xyxy[:,3] - sel_xyxy[:,1])
            large_mask = (widths >= MIN_BOX_SIDE) & (heights >= MIN_BOX_SIDE)
            if large_mask.sum() == 0:
                batch_boxes_all.append(torch.empty((0,4), dtype=torch.float32))
                continue
            sel_xyxy = sel_xyxy[large_mask]
            sel_scores_cpu = sel_scores.detach().cpu()[large_mask]

            try:
                keep_nms = nms(sel_xyxy, sel_scores_cpu, IOU_NMS_THRESH)
            except Exception:
                keep_nms = nms(sel_xyxy.cpu(), sel_scores_cpu.cpu(), IOU_NMS_THRESH)
            sel_xyxy_nms = sel_xyxy[keep_nms]
            batch_boxes_all.append(sel_xyxy_nms)

        # --- STEP: build patched images (gradient-preserving fusion)
        # keep TMM hooks enabled during forward of patched imgs (already registered)
        patched = imgs.clone()
        for bi in range(B):
            sel_boxes_cpu = batch_boxes_all[bi]  # CPU [K,4]
            if sel_boxes_cpu.numel() == 0:
                continue
            sel_boxes_dev = sel_boxes_cpu.to(DEVICE)
            for box in sel_boxes_dev:
                xmin, ymin, xmax, ymax = box.tolist()
                box_w = max(int(xmax - xmin), 1)
                box_h = max(int(ymax - ymin), 1)
                short = min(box_w, box_h)
                # use fixed patch size but optionally scale a bit relative to short side
                scale = float(np.clip(short / PATCH_SIDE, 0.5, 2.0))  # relative scale
                # ensure at least MIN_PATCH_PX
                side = max(MIN_PATCH_PX, int(round(PATCH_SIDE * scale)))
                # apply EoT transforms to patch (returns [1,3,side,side])
                patch_to_paste = eot_transform_patch(patch)
                # resize transformed patch to desired side
                patch_resized = interpolate(patch_to_paste, size=(side, side), mode='bilinear', align_corners=False)
                cx = (xmin + xmax) / 2.0
                cy = (ymin + ymax) / 2.0
                # fusion (non-inplace, gradient-preserving)
                patched[bi] = paste_patch_via_mask(patched[bi], patch_resized, center_xy=(cx, cy))

        # --- STEP: compute loss and update patch
        # ensure tmm.grad_history is available for GSE/TMM internal use
        loss_dict = loss_fn(imgs, patched, patch_tensor=patch)
        total_loss = loss_dict['total_loss']
        optimizer.zero_grad()
        total_loss.backward()

        # debug: inspect patch grad
        if patch.grad is None:
            print(f"[epoch {epoch+1} batch {batch_idx}] WARNING: patch.grad is None")
            grad_norm = None
        else:
            grad_norm = patch.grad.detach().cpu().norm().item()

        optimizer.step()
        with torch.no_grad():
            patch.clamp_(0.0, 1.0)

        # debug print: reduce each loss term to a plain number (tensor -> .item(),
        # anything else -> float(); a missing key counts as 0.0), then log one line per batch
        det_loss_v, tv_loss_v, nps_loss_v = (
            term.item() if isinstance(term, torch.Tensor) else float(term)
            for term in (
                loss_dict.get(name, 0.0) for name in ('det_loss', 'tv_loss', 'nps_loss')
            )
        )
        print(
            f"[epoch {epoch+1} batch {batch_idx}] total_loss={total_loss.item():.6f}"
            f" | det_loss={det_loss_v:.6f} | tv={tv_loss_v:.6f} | nps={nps_loss_v:.6f}"
            f" | grad_norm={grad_norm} | selected_counts={[b.shape[0] for b in batch_boxes_all]}"
        )

        # every 200th step, write the first clean image, its patched counterpart and
        # the current patch to SAVE_DIR; the step counter advances on every batch
        if global_step % 200 == 0:
            step_snapshots = (
                (detach_cpu(imgs[0]), f"step_{global_step}_orig.png"),
                (detach_cpu(patched[0]), f"step_{global_step}_patched.png"),
                (patch[0].detach().cpu(), f"step_{global_step}_patch.png"),
            )
            for snap_img, snap_name in step_snapshots:
                save_image(snap_img, os.path.join(SAVE_DIR, snap_name))
        global_step += 1

    # end of epoch: take one detached CPU copy of the patch and write it out
    # both as a PNG image and as a serialized tensor
    epoch_patch = patch[0].detach().cpu()
    save_image(epoch_patch, os.path.join(SAVE_DIR, f"epoch_{epoch+1}_patch.png"))
    torch.save(epoch_patch, os.path.join(SAVE_DIR, f"epoch_{epoch+1}_patch.pt"))
    print(f"Epoch {epoch+1} saved patch snapshot.")

# cleanup & final save: call tmm.remove_hooks(), then write the finished patch
# to SAVE_DIR as a PNG image and as a serialized tensor
tmm.remove_hooks()
final_patch = patch[0].detach().cpu()
save_image(final_patch, os.path.join(SAVE_DIR, "final_patch.png"))
torch.save(final_patch, os.path.join(SAVE_DIR, "final_patch.pt"))
print("Done. Final patch saved to", SAVE_DIR)


Train dataset size: 614


Using cache found in /root/.cache/torch/hub/facebookresearch_detr_main


✅ TMM已移除所有hook
✅ TMM已注册24个hook（6 encoder + 6 decoder）
✅ 采样策略：categorical（符合论文设置）
Start training (TMM enabled during all forwards). Saving to: /opt/data/private/BlackBox/save/demo
[epoch 1 batch 0] total_loss=0.039853 | det_loss=0.039079 | tv=0.774092 | nps=0.000000 | grad_norm=0.005961076822131872 | selected_counts=[3, 1, 8, 3, 9, 1, 5, 12]
[epoch 1 batch 1] total_loss=0.066802 | det_loss=0.066030 | tv=0.771549 | nps=0.000000 | grad_norm=0.00907645933330059 | selected_counts=[6, 1, 1, 1, 5, 17, 1, 2]
[epoch 1 batch 2] total_loss=0.036438 | det_loss=0.035668 | tv=0.770258 | nps=0.000000 | grad_norm=0.00730087049305439 | selected_counts=[1, 3, 7, 3, 3, 10, 1, 5]
[epoch 1 batch 3] total_loss=0.037939 | det_loss=0.037169 | tv=0.769276 | nps=0.000000 | grad_norm=0.007412371225655079 | selected_counts=[1, 1, 1, 10, 4, 4, 8, 1]
[epoch 1 batch 4] total_loss=0.060026 | det_loss=0.059258 | tv=0.768370 | nps=0.000000 | grad_norm=0.007400457747280598 | selected_counts=[5, 8, 3, 3, 2, 3, 7, 0]
[epo